package ai.stainless.micronaut.jupyter.kernel;

import com.twosigma.beakerx.handler.Handler;
import com.twosigma.beakerx.kernel.Config;
import com.twosigma.beakerx.kernel.KernelFunctionality;
import com.twosigma.beakerx.kernel.KernelSockets;
import com.twosigma.beakerx.kernel.SocketCloseAction;
import com.twosigma.beakerx.kernel.msg.JupyterMessages;
import com.twosigma.beakerx.message.Header;
import com.twosigma.beakerx.message.Message;
import com.twosigma.beakerx.message.MessageSerializer;
import com.twosigma.beakerx.security.HashedMessageAuthenticationCode;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zeromq.ZFrame;
import org.zeromq.ZMQ;
import org.zeromq.ZMsg;

/**
 * ZeroMQ socket layer for a Jupyter kernel.
 *
 * <p>Binds the heartbeat (REP), iopub (PUB), and the control, stdin and shell (ROUTER) sockets
 * described by a {@link Config}. Every outgoing message is signed and every incoming one is verified
 * with the connection key. Incoming traffic is dispatched from the polling loop in {@link #run()}.
 *
 * <p>The loop polls with a timeout, so a {@link #shutdown()} requested from another thread is
 * noticed within about one poll interval. When the loop ends, for any reason, the supplied
 * {@link SocketCloseAction} is invoked. Every socket and the ZeroMQ context are then closed.
 */
public class ClosableKernelSocketsZMQ extends KernelSockets {
    public static final Logger logger = LoggerFactory.getLogger(ClosableKernelSocketsZMQ.class);

    /** Frame that separates routing identities from the signed message parts. */
    public static final String DELIM = "<IDS|MSG>";

    private static final byte[] DELIM_BYTES = DELIM.getBytes(StandardCharsets.UTF_8);

    /** Frames that must follow the delimiter: signature, header, parent header, metadata, content. */
    private static final int SIGNED_FRAME_COUNT = 5;

    /** How long one poll blocks before the loop re-checks the shutdown flag. */
    private static final long POLL_TIMEOUT_MS = 1000L;

    // Poller indices; these MUST match the registration order in configureSockets().
    private static final int CONTROL_POLL_INDEX = 0;
    private static final int HEARTBEAT_POLL_INDEX = 1;
    private static final int SHELL_POLL_INDEX = 2;
    private static final int STDIN_POLL_INDEX = 3;

    private final KernelFunctionality kernel;
    private final SocketCloseAction closeAction;
    private final HashedMessageAuthenticationCode hmac;
    private final ZMQ.Context context = ZMQ.context(1);
    private ZMQ.Socket heartbeatSocket;
    private ZMQ.Socket controlSocket;
    private ZMQ.Socket shellSocket;
    private ZMQ.Socket iopubSocket;
    private ZMQ.Socket stdinSocket;
    private ZMQ.Poller sockets;

    /**
     * Set by {@link #shutdown()}, which is public and may run on another thread, and read by the
     * polling loop. It is volatile so the loop is guaranteed to observe the update.
     */
    private volatile boolean shutdownSystem = false;

    /**
     * Creates and binds all kernel sockets.
     *
     * @param kernelFunctionality kernel used to look up a handler for each shell message
     * @param config transport, host, ports and signing key of the connection
     * @param socketCloseAction callback invoked when the socket loop shuts down
     */
    public ClosableKernelSocketsZMQ(KernelFunctionality kernelFunctionality, Config config, SocketCloseAction socketCloseAction) {
        this.closeAction = socketCloseAction;
        this.kernel = kernelFunctionality;
        this.hmac = new HashedMessageAuthenticationCode(config.getKey());
        configureSockets(config);
    }

    /**
     * Binds every socket and registers the inbound ones with the poller.
     *
     * <p>If any bind fails, the sockets created so far and the context are closed before rethrowing.
     */
    private void configureSockets(Config config) {
        String address = config.getTransport() + "://" + config.getHost();
        try {
            this.heartbeatSocket = getNewSocket(ZMQ.REP, config.getHeartbeat(), address, this.context);
            this.iopubSocket = getNewSocket(ZMQ.PUB, config.getIopub(), address, this.context);
            this.controlSocket = getNewSocket(ZMQ.ROUTER, config.getControl(), address, this.context);
            this.stdinSocket = getNewSocket(ZMQ.ROUTER, config.getStdin(), address, this.context);
            this.shellSocket = getNewSocket(ZMQ.ROUTER, config.getShell(), address, this.context);
            this.sockets = new ZMQ.Poller(4);
            // Registration order defines the *_POLL_INDEX constants above.
            this.sockets.register(this.controlSocket, ZMQ.Poller.POLLIN);
            this.sockets.register(this.heartbeatSocket, ZMQ.Poller.POLLIN);
            this.sockets.register(this.shellSocket, ZMQ.Poller.POLLIN);
            this.sockets.register(this.stdinSocket, ZMQ.Poller.POLLIN);
        } catch (RuntimeException e) {
            closeSockets();
            throw e;
        }
    }

    /** Publishes the given messages on the iopub socket. */
    public void publish(List<Message> list) {
        sendMsg(this.iopubSocket, list);
    }

    /** Sends a single message on the shell socket. */
    public void send(Message message) {
        sendMsg(this.shellSocket, Collections.singletonList(message));
    }

    /**
     * Sends an input request on the stdin socket and waits for the reply.
     *
     * <p>NOTE(review): the polling loop in {@link #run()} also reads the stdin socket. If this is
     * called from a thread other than the loop thread, both may compete for the reply. Confirm
     * callers.
     *
     * @return the {@code value} field of the reply's content
     */
    public String sendStdIn(Message message) {
        sendMsg(this.stdinSocket, Collections.singletonList(message));
        return handleStdIn();
    }

    /**
     * Serializes, signs and sends each message as a multipart ZeroMQ message.
     *
     * <p>The frame layout is: identities, {@link #DELIM}, HMAC signature, header, parent header,
     * metadata, content, then any binary buffers. Messages are dropped once shutdown was requested.
     */
    private synchronized void sendMsg(ZMQ.Socket socket, List<Message> list) {
        if (isShutdown()) {
            return;
        }
        for (Message message : list) {
            String header = MessageSerializer.toJson(message.getHeader());
            String parentHeader = MessageSerializer.toJson(message.getParentHeader());
            String metadata = MessageSerializer.toJson(message.getMetadata());
            String content = MessageSerializer.toJson(message.getContent());
            String signature = this.hmac.sign(Arrays.asList(header, parentHeader, metadata, content));

            ZMsg zMsg = new ZMsg();
            // On ROUTER sockets the leading identity frames select the peer that receives the message.
            message.getIdentities().forEach(zMsg::add);
            zMsg.add(DELIM);
            zMsg.add(signature.getBytes(StandardCharsets.UTF_8));
            zMsg.add(header.getBytes(StandardCharsets.UTF_8));
            zMsg.add(parentHeader.getBytes(StandardCharsets.UTF_8));
            zMsg.add(metadata.getBytes(StandardCharsets.UTF_8));
            zMsg.add(content.getBytes(StandardCharsets.UTF_8));
            message.getBuffers().forEach(zMsg::add);
            zMsg.send(socket);
        }
    }

    /**
     * Receives one multipart message, verifies its delimiter and signature, and parses it.
     *
     * <p>Every frame before the delimiter is kept as a routing identity so that replies can be routed
     * back. Frames after the content frame are ignored.
     *
     * @throws MalformedMessageException if the delimiter is missing, frames are missing, or the
     *     signature does not match
     * @throws IllegalStateException if no message could be received
     */
    @SuppressWarnings("unchecked") // parsed metadata/content maps are raw LinkedHashMaps
    private Message readMessage(ZMQ.Socket socket) {
        ZMsg zMsg = ZMsg.recvMsg(socket);
        if (zMsg == null) {
            // recvMsg returns null when the receive is interrupted instead of yielding a message.
            throw new IllegalStateException("No message received from socket");
        }
        try {
            ZFrame[] frames = zMsg.toArray(new ZFrame[zMsg.size()]);
            int delimIndex = findDelimiter(frames);
            int framesAfterDelim = frames.length - delimIndex - 1;
            if (framesAfterDelim < SIGNED_FRAME_COUNT) {
                throw new MalformedMessageException("Expected " + SIGNED_FRAME_COUNT
                        + " frames after the delimiter but got " + framesAfterDelim);
            }
            byte[] signature = frames[delimIndex + 1].getData();
            byte[] header = frames[delimIndex + 2].getData();
            byte[] parentHeader = frames[delimIndex + 3].getData();
            byte[] metadata = frames[delimIndex + 4].getData();
            byte[] content = frames[delimIndex + 5].getData();
            verifySignatures(signature, header, parentHeader, metadata, content);

            Message message = new Message(parse(header, Header.class));
            for (int i = 0; i < delimIndex; i++) {
                byte[] identity = frames[i].getData();
                if (identity != null) {
                    message.getIdentities().add(identity);
                }
            }
            message.setParentHeader(parse(parentHeader, Header.class));
            message.setMetadata(parse(metadata, LinkedHashMap.class));
            message.setContent(parse(content, LinkedHashMap.class));
            return message;
        } finally {
            zMsg.destroy();
        }
    }

    /**
     * Returns the index of the first {@link #DELIM} frame.
     *
     * @throws MalformedMessageException if no frame equals the delimiter
     */
    private static int findDelimiter(ZFrame[] frames) {
        for (int i = 0; i < frames.length; i++) {
            if (Arrays.equals(frames[i].getData(), DELIM_BYTES)) {
                return i;
            }
        }
        throw new MalformedMessageException("Delimiter <IDS|MSG> not found");
    }

    /**
     * Polls the control, heartbeat, shell and stdin sockets until {@link #shutdown()} is requested.
     *
     * <p>Each poll handles at most one ready socket, checked in the order control, heartbeat, shell,
     * stdin. Malformed or wrongly signed messages are logged and discarded, so a single bad message
     * cannot stop the kernel. An {@link Error} is logged and ends the loop; other exceptions
     * propagate. In every case the close action runs and all sockets are closed.
     */
    public void run() {
        try {
            while (!isShutdown()) {
                this.sockets.poll(POLL_TIMEOUT_MS);
                handleReadySocket();
            }
        } catch (Error e) {
            logger.error("Kernel socket loop stopped by an unrecoverable error", e);
        } finally {
            close();
        }
    }

    /** Handles the first socket reported ready by the last poll, if any. */
    private void handleReadySocket() {
        try {
            if (isControlMsg()) {
                handleControlMsg();
            } else if (isHeartbeatMsg()) {
                handleHeartbeat();
            } else if (isShellMsg()) {
                handleShell();
            } else if (isStdinMsg()) {
                handleStdIn();
            }
        } catch (MalformedMessageException e) {
            logger.warn("Discarding malformed kernel message: {}", e.getMessage());
        }
    }

    /** Reads one stdin message and returns the {@code value} field of its content. */
    private String handleStdIn() {
        return (String) readMessage(this.stdinSocket).getContent().get("value");
    }

    /** Reads one shell message and passes it to the kernel's handler for its type, if one exists. */
    private void handleShell() {
        Message request = readMessage(this.shellSocket);
        Handler handler = this.kernel.getHandler(request.type());
        if (handler != null) {
            handler.handle(request);
        }
    }

    /** Echoes the received heartbeat payload back to the sender. */
    private void handleHeartbeat() {
        byte[] ping = this.heartbeatSocket.recv(0);
        // recv returns null when nothing was received; there is then nothing to answer.
        if (ping != null) {
            this.heartbeatSocket.send(ping);
        }
    }

    /**
     * Reads one control message. On a shutdown request, sends the reply and then requests shutdown.
     * Other control messages are ignored.
     */
    private void handleControlMsg() {
        Message request = readMessage(this.controlSocket);
        if (request.getHeader().getTypeEnum() != JupyterMessages.SHUTDOWN_REQUEST) {
            logger.debug("Ignoring control message of type {}", request.type());
            return;
        }
        Message reply = new Message(new Header(JupyterMessages.SHUTDOWN_REPLY, request.getHeader().getSession()));
        // The control socket is a ROUTER. Without the requester's identity frames the reply is
        // unroutable and ZeroMQ silently drops it.
        reply.getIdentities().addAll(request.getIdentities());
        reply.setParentHeader(request.getHeader());
        reply.setContent(request.getContent());
        sendMsg(this.controlSocket, Collections.singletonList(reply));
        shutdown();
    }

    /**
     * Creates a socket of the given ZeroMQ type and binds it to {@code address:port}.
     *
     * <p>The socket is closed again if the bind throws.
     */
    private ZMQ.Socket getNewSocket(int type, int port, String address, ZMQ.Context context) {
        ZMQ.Socket socket = context.socket(type);
        try {
            socket.bind(address + ":" + port);
        } catch (RuntimeException e) {
            socket.close();
            throw e;
        }
        return socket;
    }

    /** Runs the close action, then always closes the sockets, even if the action throws. */
    private void close() {
        try {
            this.closeAction.close();
        } finally {
            closeSockets();
        }
    }

    /** Closes every created socket and then the context, logging rather than aborting on failures. */
    private void closeSockets() {
        closeSocket(this.shellSocket, "shell");
        closeSocket(this.controlSocket, "control");
        closeSocket(this.iopubSocket, "iopub");
        closeSocket(this.stdinSocket, "stdin");
        closeSocket(this.heartbeatSocket, "heartbeat");
        try {
            this.context.close();
        } catch (Exception e) {
            logger.warn("Failed to close ZeroMQ context", e);
        }
    }

    /** Closes {@code socket} if it was created, logging any failure. */
    private static void closeSocket(ZMQ.Socket socket, String name) {
        if (socket == null) {
            return;
        }
        try {
            socket.close();
        } catch (Exception e) {
            logger.warn("Failed to close {} socket", name, e);
        }
    }

    /**
     * Checks that {@code signature} is the HMAC of the four signed frames.
     *
     * @throws MalformedMessageException if the signature is absent or does not match
     */
    private void verifySignatures(byte[] signature, byte[] header, byte[] parentHeader, byte[] metadata, byte[] content) {
        String expected = this.hmac.signBytes(new ArrayList<>(Arrays.asList(header, parentHeader, metadata, content)));
        // Constant-time comparison: response timing must not reveal how much of a forged signature matched.
        if (signature == null || !MessageDigest.isEqual(signature, expected.getBytes(StandardCharsets.UTF_8))) {
            throw new MalformedMessageException("Signatures do not match.");
        }
    }

    private boolean isStdinMsg() {
        return this.sockets.pollin(STDIN_POLL_INDEX);
    }

    private boolean isShellMsg() {
        return this.sockets.pollin(SHELL_POLL_INDEX);
    }

    private boolean isHeartbeatMsg() {
        return this.sockets.pollin(HEARTBEAT_POLL_INDEX);
    }

    private boolean isControlMsg() {
        return this.sockets.pollin(CONTROL_POLL_INDEX);
    }

    /**
     * Requests the polling loop to stop. This may be called from any thread.
     *
     * <p>The loop notices the flag within one poll interval and then closes all sockets. Messages
     * sent after this call are dropped.
     */
    public void shutdown() {
        logger.debug("kernel shutdown");
        this.shutdownSystem = true;
    }

    private boolean isShutdown() {
        return this.shutdownSystem;
    }

    /** Parses UTF-8 JSON bytes into {@code cls}; returns null when {@code bArr} is null. */
    private <T> T parse(byte[] bArr, Class<T> cls) {
        if (bArr != null) {
            return (T) MessageSerializer.parse(new String(bArr, StandardCharsets.UTF_8), cls);
        }
        return null;
    }

    /** Raised when an incoming message violates the wire format or fails signature verification. */
    private static final class MalformedMessageException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        MalformedMessageException(String message) {
            super(message);
        }
    }
}
