/*
 * This file is part of packetevents - https://github.com/retrooper/packetevents
 * Copyright (C) 2021 retrooper and contributors
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package com.github.retrooper.packetevents.wrapper;

import com.github.retrooper.packetevents.PacketEvents;
import com.github.retrooper.packetevents.event.impl.PacketReceiveEvent;
import com.github.retrooper.packetevents.event.impl.PacketSendEvent;
import com.github.retrooper.packetevents.manager.server.ServerVersion;
import com.github.retrooper.packetevents.netty.buffer.ByteBufAbstract;
import com.github.retrooper.packetevents.util.AdventureSerializer;
import com.github.retrooper.packetevents.protocol.entity.data.EntityData;
import com.github.retrooper.packetevents.protocol.entity.data.EntityDataType;
import com.github.retrooper.packetevents.protocol.entity.data.EntityDataTypes;
import com.github.retrooper.packetevents.protocol.entity.villager.VillagerData;
import com.github.retrooper.packetevents.protocol.item.ItemStack;
import com.github.retrooper.packetevents.protocol.item.type.ItemType;
import com.github.retrooper.packetevents.protocol.item.type.ItemTypes;
import com.github.retrooper.packetevents.protocol.nbt.NBTCompound;
import com.github.retrooper.packetevents.protocol.nbt.codec.NBTCodec;
import com.github.retrooper.packetevents.protocol.packettype.PacketTypeCommon;
import com.github.retrooper.packetevents.protocol.player.ClientVersion;
import com.github.retrooper.packetevents.protocol.player.GameMode;
import com.github.retrooper.packetevents.resources.ResourceLocation;
import com.github.retrooper.packetevents.util.StringUtil;
import com.github.retrooper.packetevents.util.Vector3i;
import net.kyori.adventure.text.Component;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;

/**
 * Base class of every packet wrapper.
 * <p>
 * A wrapper holds the {@link ByteBufAbstract} containing the packet payload, the packet id and the
 * client/server versions the payload is encoded for. Subclasses override {@link #readData()},
 * {@link #readData(PacketWrapper)} and {@link #writeData()} and build on the typed read/write helpers
 * below (VarInts, strings, item stacks, NBT, entity metadata, ...). Several helpers choose their wire
 * format based on {@link #serverVersion}.
 * <p>
 * NOTE: the event-based constructors call the overridable {@link #readData()} /
 * {@link #readData(PacketWrapper)} from within the constructor. Subclass field initializers run after
 * the super constructor returns, so any subclass field with an explicit initializer will overwrite the
 * value {@code readData} assigned to it.
 *
 * @param <T> the concrete wrapper type. An instance of it (the event's last used wrapper) is passed to
 *            {@link #readData(PacketWrapper)} so already-decoded state can be copied.
 */
public class PacketWrapper<T extends PacketWrapper> {
    public final ByteBufAbstract buffer;
    protected ClientVersion clientVersion;
    protected ServerVersion serverVersion;
    private int packetID;
    private boolean hasPreparedForSending;

    // Limit returned by getMaxMessageLength() for 1.13+ servers.
    private static final int MODERN_MESSAGE_LENGTH = 262144;
    // Limit returned by getMaxMessageLength() for servers older than 1.13.
    private static final int LEGACY_MESSAGE_LENGTH = 32767;

    /**
     * Creates a wrapper around an existing buffer without reading anything from it.
     *
     * @param clientVersion the client version the payload is encoded for
     * @param serverVersion the server version the payload is encoded for
     * @param buffer        the buffer to read from / write to
     * @param packetID      the packet id (written by {@link #prepareForSend()} and {@link #resetByteBuf()})
     */
    public PacketWrapper(ClientVersion clientVersion, ServerVersion serverVersion, ByteBufAbstract buffer, int packetID) {
        this.clientVersion = clientVersion;
        this.serverVersion = serverVersion;
        this.buffer = buffer;
        this.packetID = packetID;
    }

    /**
     * Creates a wrapper for an incoming packet and decodes it immediately.
     * This is equivalent to {@code this(event, true)}.
     *
     * @param event the receive event supplying versions, buffer and packet id
     */
    public PacketWrapper(PacketReceiveEvent event) {
        this(event, true);
    }

    /**
     * Creates a wrapper for an incoming packet.
     * <p>
     * When {@code readData} is true, the payload is decoded as follows:
     * <ul>
     *     <li>For a cloned event, {@link #readData()} runs and the buffer's reader index is then restored.</li>
     *     <li>Otherwise, if the event has no last used wrapper yet, this wrapper registers itself as such
     *     and calls {@link #readData()}.</li>
     *     <li>Otherwise, {@link #readData(PacketWrapper)} is called with the event's last used wrapper.</li>
     * </ul>
     *
     * @param event    the receive event supplying versions, buffer and packet id
     * @param readData whether to decode the payload now
     */
    @SuppressWarnings("unchecked") // the last used wrapper is assumed to be of the same concrete type T
    public PacketWrapper(PacketReceiveEvent event, boolean readData) {
        this.clientVersion = event.getClientVersion();
        this.serverVersion = event.getServerVersion();
        this.buffer = event.getByteBuf();
        this.packetID = event.getPacketId();
        if (readData) {
            if (event.isCloned()) {
                // Decode without consuming: rewind so the bytes can be read again afterwards.
                int bufferIndex = getBuffer().readerIndex();
                readData();
                getBuffer().readerIndex(bufferIndex);
            } else {
                if (event.getLastUsedWrapper() == null) {
                    event.setLastUsedWrapper(this);
                    readData();
                } else {
                    readData((T) event.getLastUsedWrapper());
                }
            }
        }
    }

    /**
     * Creates a wrapper for an outgoing packet and decodes it immediately.
     * This uses the same cloned / last-used-wrapper logic as
     * {@link #PacketWrapper(PacketReceiveEvent, boolean)} with {@code readData == true}.
     *
     * @param event the send event supplying versions, buffer and packet id
     */
    @SuppressWarnings("unchecked") // the last used wrapper is assumed to be of the same concrete type T
    public PacketWrapper(PacketSendEvent event) {
        this.clientVersion = event.getClientVersion();
        this.serverVersion = event.getServerVersion();
        this.buffer = event.getByteBuf();
        this.packetID = event.getPacketId();
        if (event.isCloned()) {
            // Decode without consuming: rewind so the bytes can be read again afterwards.
            int bufferIndex = getBuffer().readerIndex();
            readData();
            getBuffer().readerIndex(bufferIndex);
        } else {
            if (event.getLastUsedWrapper() == null) {
                event.setLastUsedWrapper(this);
                readData();
            } else {
                readData((T) event.getLastUsedWrapper());
            }
        }
    }

    /**
     * Creates a wrapper for a new packet.
     * It uses the server version reported by the API and a fresh buffer from the netty manager.
     *
     * @param packetID      the packet id
     * @param clientVersion the client version to encode for
     */
    public PacketWrapper(int packetID, ClientVersion clientVersion) {
        this(clientVersion, PacketEvents.getAPI().getServerManager().getVersion(),
                PacketEvents.getAPI().getNettyManager().buffer(), packetID);
    }

    /**
     * Creates a wrapper for a new packet with an unknown client version.
     * It uses the server version reported by the API and a fresh buffer from the netty manager.
     *
     * @param packetID the packet id
     */
    public PacketWrapper(int packetID) {
        this(ClientVersion.UNKNOWN,
                PacketEvents.getAPI().getServerManager().getVersion(),
                PacketEvents.getAPI().getNettyManager().buffer(), packetID);
    }

    /**
     * Creates a wrapper for a new packet of the given type.
     *
     * @param packetType the packet type whose id is used
     */
    public PacketWrapper(PacketTypeCommon packetType) {
        this(packetType.getId());
    }

    /**
     * Wraps an arbitrary buffer so the read/write helpers can be used on it.
     * The wrapper has an unknown client version, the API's server version and packet id {@code -1}.
     *
     * @param byteBuf the buffer to wrap
     * @return a plain wrapper around {@code byteBuf}
     */
    public static PacketWrapper<?> createUniversalPacketWrapper(ByteBufAbstract byteBuf) {
        return new PacketWrapper(ClientVersion.UNKNOWN, PacketEvents.getAPI().getServerManager().getVersion(), byteBuf, -1);
    }

    /**
     * Writes the packet id (as a VarInt) followed by {@link #writeData()}.
     * This happens at most once; later calls do nothing until the flag is reset via
     * {@link #setHasPrepareForSending(boolean)}.
     */
    public final void prepareForSend() {
        if (!hasPreparedForSending) {
            writeVarInt(packetID);
            writeData();
            hasPreparedForSending = true;
        }
    }

    /**
     * Decodes the packet fields from {@link #buffer}. The default implementation does nothing.
     */
    public void readData() {

    }

    /**
     * Copies already-decoded fields from another wrapper of the same type.
     * The default implementation does nothing.
     *
     * @param wrapper the previously used wrapper for the same event
     */
    public void readData(T wrapper) {

    }

    /**
     * Encodes the packet fields into {@link #buffer}. The default implementation does nothing.
     */
    public void writeData() {

    }

    /**
     * @return whether {@link #prepareForSend()} has already written this packet
     */
    public boolean hasPreparedForSending() {
        return hasPreparedForSending;
    }

    /**
     * Overrides the "already prepared" flag consulted by {@link #prepareForSend()}.
     *
     * @param hasPreparedForSending the new flag value
     */
    public void setHasPrepareForSending(boolean hasPreparedForSending) {
        this.hasPreparedForSending = hasPreparedForSending;
    }

    public ClientVersion getClientVersion() {
        return clientVersion;
    }

    public void setClientVersion(ClientVersion clientVersion) {
        this.clientVersion = clientVersion;
    }

    public ServerVersion getServerVersion() {
        return serverVersion;
    }

    public void setServerVersion(ServerVersion serverVersion) {
        this.serverVersion = serverVersion;
    }

    public ByteBufAbstract getBuffer() {
        return buffer;
    }

    public int getPacketId() {
        return packetID;
    }

    public void setPacketId(int packetID) {
        this.packetID = packetID;
    }

    /**
     * @return the maximum string length used for components: 262144 on 1.13+ servers, 32767 otherwise
     */
    public int getMaxMessageLength() {
        return serverVersion.isNewerThanOrEquals(ServerVersion.V_1_13) ? MODERN_MESSAGE_LENGTH : LEGACY_MESSAGE_LENGTH;
    }

    /**
     * Clears the buffer and writes the packet id (as a VarInt) back into it.
     */
    public void resetByteBuf() {
        buffer.clear();
        writeVarInt(packetID);
    }

    public byte readByte() {
        return buffer.readByte();
    }

    public void writeByte(int value) {
        buffer.writeByte(value);
    }

    /**
     * @return the next byte interpreted as unsigned (0-255)
     */
    public short readUnsignedByte() {
        return (short) (readByte() & 255);
    }

    /**
     * @return {@code true} if the next byte is non-zero
     */
    public boolean readBoolean() {
        return readByte() != 0;
    }

    /**
     * Writes {@code 1} for {@code true} and {@code 0} for {@code false}.
     */
    public void writeBoolean(boolean value) {
        writeByte(value ? 1 : 0);
    }

    public int readInt() {
        return buffer.readInt();
    }

    public void writeInt(int value) {
        buffer.writeInt(value);
    }

    /**
     * Reads a VarInt.
     * The encoding uses 7 data bits per byte, least-significant group first, and sets the high bit on
     * every byte except the last.
     *
     * @return the decoded int
     * @throws RuntimeException if more than 5 bytes are used
     */
    public int readVarInt() {
        int result = 0;
        byte i;
        int j = 0;
        do {
            i = buffer.readByte();
            result |= (i & Byte.MAX_VALUE) << j++ * 7;
            if (j > 5)
                throw new RuntimeException("VarInt too big");
        } while ((i & 0x80) == 128);
        return result;
    }

    /**
     * Writes a VarInt: 7 bits per byte, least-significant group first, high bit set on continuation bytes.
     * Negative values always take 5 bytes.
     */
    public void writeVarInt(int value) {
        while ((value & -128) != 0) {
            buffer.writeByte(value & Byte.MAX_VALUE | 128);
            value >>>= 7;
        }

        buffer.writeByte(value);
    }

    /**
     * Reads a VarInt entry count followed by that many key/value pairs.
     *
     * @param keyFunction   reads one key from this wrapper
     * @param valueFunction reads one value from this wrapper
     * @return a new mutable {@link HashMap}
     */
    public <K, V> Map<K, V> readMap(Function<PacketWrapper<?>, K> keyFunction, Function<PacketWrapper<?>, V> valueFunction) {
        int size = readVarInt();
        // The size comes from the wire. Cap the pre-sizing hint so a bogus value cannot force a huge
        // table allocation; the map still grows as needed. A negative size still makes HashMap throw.
        Map<K, V> map = new HashMap<>(Math.min(size, buffer.readableBytes()));
        for (int i = 0; i < size; i++) {
            K key = keyFunction.apply(this);
            V value = valueFunction.apply(this);
            map.put(key, value);
        }
        return map;
    }

    /**
     * Writes the entry count as a VarInt followed by each key/value pair, in the map's iteration order.
     *
     * @param map           the map to write
     * @param keyConsumer   writes one key to this wrapper
     * @param valueConsumer writes one value to this wrapper
     */
    public <K, V> void writeMap(Map<K, V> map, BiConsumer<PacketWrapper<?>, K> keyConsumer, BiConsumer<PacketWrapper<?>, V> valueConsumer) {
        writeVarInt(map.size());
        for (Map.Entry<K, V> entry : map.entrySet()) {
            keyConsumer.accept(this, entry.getKey());
            valueConsumer.accept(this, entry.getValue());
        }
    }

    /**
     * Reads three VarInts: villager type id, profession id and level.
     */
    public VillagerData readVillagerData() {
        int villagerTypeId = readVarInt();
        int villagerProfessionId = readVarInt();
        int level = readVarInt();
        return new VillagerData(villagerTypeId, villagerProfessionId, level);
    }

    /**
     * Writes villager type id, profession id and level as three VarInts.
     */
    public void writeVillagerData(VillagerData data) {
        writeVarInt(data.getType().getId());
        writeVarInt(data.getProfession().getId());
        writeVarInt(data.getLevel());
    }

    /**
     * Reads an item stack.
     * <ul>
     *     <li>1.13.2+: boolean "present" flag, then VarInt item id, byte amount, NBT.</li>
     *     <li>Older: short item id (negative means empty), byte amount, short legacy data, NBT.</li>
     * </ul>
     *
     * @return the item stack, or {@link ItemStack#EMPTY} when absent or when the id is negative
     */
    @NotNull
    public ItemStack readItemStack() {
        boolean v1_13_2 = serverVersion.isNewerThanOrEquals(ServerVersion.V_1_13_2);
        if (v1_13_2) {
            if (!readBoolean()) {
                return ItemStack.EMPTY;
            }
        }
        int typeID = v1_13_2 ? readVarInt() : readShort();
        if (typeID < 0) {
            return ItemStack.EMPTY;
        }
        ItemType type = ItemTypes.getById(typeID);
        int amount = readByte();
        int legacyData = v1_13_2 ? -1 : readShort();
        NBTCompound nbt = readNBT();
        return ItemStack.builder()
                .type(type)
                .amount(amount)
                .nbt(nbt)
                .legacyData(legacyData)
                .build();
    }

    /**
     * Writes an item stack in the format matching {@link #readItemStack()}.
     * A {@code null} stack is written as {@link ItemStack#EMPTY}.
     *
     * @param itemStack the stack to write; may be {@code null}
     */
    public void writeItemStack(ItemStack itemStack) {
        if (itemStack == null) {
            itemStack = ItemStack.EMPTY;
        }
        boolean v1_13_2 = serverVersion.isNewerThanOrEquals(ServerVersion.V_1_13_2);
        if (v1_13_2) {
            // With the "present" flag format, a true flag must be followed by id, amount and NBT.
            // A stack without a type therefore has to be written as absent rather than as an id of -1.
            if (itemStack.getType() == null || ItemStack.EMPTY.equals(itemStack)) {
                writeBoolean(false);
            } else {
                writeBoolean(true);
                writeVarInt(itemStack.getType().getId());
                writeByte(itemStack.getAmount());
                writeNBT(itemStack.getNBT());
            }
        } else {
            int typeID;
            if (itemStack.getType() == null || itemStack.getAmount() == -1) {
                typeID = -1;
            } else {
                typeID = itemStack.getType().getId();
            }
            writeShort(typeID);
            // A negative id marks an empty slot; no further fields follow it.
            if (typeID >= 0) {
                writeByte(itemStack.getAmount());
                writeShort(itemStack.getLegacyData());
                writeNBT(itemStack.getNBT());
            }
        }
    }

    /**
     * Reads an NBT compound using the codec for the current server version.
     */
    public NBTCompound readNBT() {
        return NBTCodec.readNBT(buffer, serverVersion);
    }

    /**
     * Writes an NBT compound using the codec for the current server version.
     */
    public void writeNBT(NBTCompound nbt) {
        NBTCodec.writeNBT(buffer, serverVersion, nbt);
    }

    /**
     * Reads a string of at most 32767 characters.
     *
     * @see #readString(int)
     */
    public String readString() {
        return readString(32767);
    }

    /**
     * Reads a VarInt byte length followed by that many UTF-8 bytes.
     *
     * @param maxLen maximum number of characters; the encoded length may be at most {@code 4 * maxLen} bytes
     * @return the decoded string
     * @throws RuntimeException if the byte length is negative or too large, or if the decoded string
     *                          exceeds {@code maxLen} characters
     */
    public String readString(int maxLen) {
        int j = readVarInt();
        // Widen to long so very large maxLen values cannot overflow into a negative limit.
        long maxBytes = (long) maxLen * 4;
        if (j > maxBytes) {
            throw new RuntimeException("The received encoded string buffer length is longer than maximum allowed (" + j + " > " + maxBytes + ")");
        } else if (j < 0) {
            throw new RuntimeException("The received encoded string buffer length is less than zero! Weird string!");
        } else {
            String s = buffer.toString(buffer.readerIndex(), j, StandardCharsets.UTF_8);
            buffer.readerIndex(buffer.readerIndex() + j);
            if (s.length() > maxLen) {
                throw new RuntimeException("The received string length is longer than maximum allowed (" + s.length() + " > " + maxLen + ")");
            } else {
                return s;
            }
        }
    }

    /**
     * Writes a string, truncated to at most 32767 characters.
     */
    public void writeString(String s) {
        writeString(s, 32767);
    }

    /**
     * Writes a string, truncated to at most {@code maxLen} characters.
     */
    public void writeString(String s, int maxLen) {
        writeString(s, maxLen, true);
    }

    /**
     * Writes a VarInt byte length followed by the string's UTF-8 bytes.
     * <p>
     * NOTE(review): with {@code substr == false} the limit is compared against the encoded UTF-8 byte
     * count. {@link #readString(int)}, by contrast, limits characters and allows up to 4 bytes per
     * character.
     *
     * @param s      the string to write
     * @param maxLen the length limit
     * @param substr {@code true} to truncate {@code s} via {@link StringUtil#maximizeLength};
     *               {@code false} to throw when the encoded string is too long
     * @throws IllegalStateException if {@code substr} is false and the encoded length exceeds {@code maxLen}
     */
    public void writeString(String s, int maxLen, boolean substr) {
        if (substr) {
            s = StringUtil.maximizeLength(s, maxLen);
        }
        byte[] bytes = s.getBytes(StandardCharsets.UTF_8);
        if (!substr && bytes.length > maxLen) {
            throw new IllegalStateException("String too big (was " + bytes.length + " bytes encoded, max " + maxLen + ")");
        } else {
            writeVarInt(bytes.length);
            buffer.writeBytes(bytes);
        }
    }

    /**
     * Reads a JSON string of up to {@link #getMaxMessageLength()} characters and parses it as a component.
     */
    public Component readComponent() {
        return AdventureSerializer.parseComponent(readString(getMaxMessageLength()));
    }

    /**
     * Serializes a component to JSON and writes it.
     * Uses the same version-dependent limit as {@link #readComponent()}, so long components are not cut
     * at 32767 characters on servers that accept more.
     */
    public void writeComponent(Component component) {
        writeString(AdventureSerializer.toJson(component), getMaxMessageLength());
    }

    /**
     * Reads a string of at most {@code maxLen} characters as a {@link ResourceLocation}.
     */
    public ResourceLocation readIdentifier(int maxLen) {
        return new ResourceLocation(readString(maxLen));
    }

    /**
     * Reads a string of at most 32767 characters as a {@link ResourceLocation}.
     */
    public ResourceLocation readIdentifier() {
        return readIdentifier(32767);
    }

    /**
     * Writes the identifier's string form, truncated to {@code maxLen} characters.
     */
    public void writeIdentifier(ResourceLocation identifier, int maxLen) {
        writeString(identifier.toString(), maxLen);
    }

    /**
     * Writes the identifier's string form, truncated to 32767 characters.
     */
    public void writeIdentifier(ResourceLocation identifier) {
        writeIdentifier(identifier, 32767);
    }

    public int readUnsignedShort() {
        return buffer.readUnsignedShort();
    }

    public short readShort() {
        return buffer.readShort();
    }

    public void writeShort(int value) {
        buffer.writeShort(value);
    }

    /**
     * Reads a variable-length short.
     * An unsigned short carries the low 15 bits. If its top bit is set, one extra unsigned byte follows
     * and supplies bits 15-22.
     */
    public int readVarShort() {
        int low = buffer.readUnsignedShort();
        int high = 0;
        if ((low & 0x8000) != 0) {
            low = low & 0x7FFF;
            high = buffer.readUnsignedByte();
        }
        return ((high & 0xFF) << 15) | low;
    }

    /**
     * Writes a variable-length short in the format read by {@link #readVarShort()}.
     * Only bits 0-22 of {@code value} are encoded.
     */
    public void writeVarShort(int value) {
        int low = value & 0x7FFF;
        int high = (value & 0x7F8000) >> 15;
        if (high != 0) {
            low = low | 0x8000;
        }
        buffer.writeShort(low);
        if (high != 0) {
            buffer.writeByte(high);
        }
    }

    public long readLong() {
        return buffer.readLong();
    }

    /**
     * Reads a VarLong: 7 data bits per byte, least-significant group first, high bit set on every byte
     * except the last.
     *
     * @return the decoded long
     * @throws RuntimeException if more than 10 bytes are used
     */
    public long readVarLong() {
        long value = 0;
        int size = 0;
        int b;
        do {
            b = this.readByte();
            value |= (long) (b & 0x7F) << (size++ * 7);
            // A 64-bit value needs at most 10 groups of 7 bits; more means malformed input.
            if (size > 10) {
                throw new RuntimeException("VarLong too big");
            }
        } while ((b & 0x80) == 0x80);
        return value;
    }

    public void writeLong(long value) {
        buffer.writeLong(value);
    }

    /**
     * Writes a VarLong in the format read by {@link #readVarLong()}.
     */
    public void writeVarLong(long l) {
        while ((l & ~0x7F) != 0) {
            this.writeByte((int) (l & 0x7F) | 0x80);
            l >>>= 7;
        }

        this.writeByte((int) l);
    }

    public float readFloat() {
        return buffer.readFloat();
    }

    public void writeFloat(float value) {
        buffer.writeFloat(value);
    }

    public double readDouble() {
        return buffer.readDouble();
    }

    public void writeDouble(double value) {
        buffer.writeDouble(value);
    }

    /**
     * Reads exactly {@code size} raw bytes (no length prefix).
     */
    public byte[] readBytes(int size) {
        byte[] bytes = new byte[size];
        buffer.readBytes(bytes);
        return bytes;
    }

    /**
     * Writes the raw bytes of {@code array} (no length prefix).
     */
    public void writeBytes(byte[] array) {
        buffer.writeBytes(array);
    }

    /**
     * Reads a VarInt length followed by that many bytes.
     *
     * @param maxLength the maximum accepted length
     * @throws RuntimeException          if the length exceeds {@code maxLength}
     * @throws IndexOutOfBoundsException if the length is negative or exceeds the readable bytes
     */
    public byte[] readByteArray(int maxLength) {
        int len = readVarInt();
        if (len > maxLength) {
            throw new RuntimeException("The received byte array length is longer than maximum allowed (" + len + " > " + maxLength + ")");
        }
        return readBytes(checkReadableLength(len));
    }

    /**
     * Reads a VarInt length followed by that many bytes.
     *
     * @throws IndexOutOfBoundsException if the length is negative or exceeds the readable bytes
     */
    public byte[] readByteArray() {
        int len = readVarInt();
        return readBytes(checkReadableLength(len));
    }

    /**
     * Validates a length prefix read from the buffer before an array of that size is allocated.
     * This stops a hostile or corrupt prefix from triggering a huge allocation.
     *
     * @param len the length read from the buffer
     * @return {@code len}, if it is non-negative and no larger than the readable bytes
     * @throws IndexOutOfBoundsException otherwise
     */
    private int checkReadableLength(int len) {
        int readable = buffer.readableBytes();
        if (len < 0 || len > readable) {
            throw new IndexOutOfBoundsException("The received byte array length " + len + " is invalid (" + readable + " readable bytes)");
        }
        return len;
    }

    /**
     * Writes {@code array.length} as a VarInt followed by the bytes.
     */
    public void writeByteArray(byte[] array) {
        writeVarInt(array.length);
        writeBytes(array);
    }

    /**
     * Reads a VarInt count followed by that many VarInts.
     *
     * @throws IllegalStateException if the count exceeds the number of readable bytes
     *                               (every VarInt takes at least one byte)
     */
    public int[] readVarIntArray() {
        int readableBytes = buffer.readableBytes();
        int size = readVarInt();
        if (size > readableBytes) {
            throw new IllegalStateException("VarIntArray with size " + size + " is bigger than allowed " + readableBytes);
        }

        int[] array = new int[size];
        for (int i = 0; i < size; i++) {
            array[i] = readVarInt();
        }
        return array;
    }

    /**
     * Writes {@code array.length} as a VarInt followed by each element as a VarInt.
     */
    public void writeVarIntArray(int[] array) {
        writeVarInt(array.length);
        for (int i : array) {
            writeVarInt(i);
        }
    }

    /**
     * Reads exactly {@code size} longs (no length prefix).
     */
    public long[] readLongArray(int size) {
        long[] array = new long[size];

        for (int i = 0; i < array.length; i++) {
            array[i] = readLong();
        }
        return array;
    }

    /**
     * Reads exactly {@code size} raw bytes (no length prefix).
     */
    public byte[] readByteArrayOfSize(int size) {
        byte[] array = new byte[size];
        buffer.readBytes(array);
        return array;
    }

    /**
     * Writes the raw bytes of {@code array} (no length prefix).
     */
    public void writeByteArrayOfSize(byte[] array) {
        buffer.writeBytes(array);
    }

    /**
     * Reads exactly {@code size} VarInts (no length prefix).
     */
    public int[] readVarIntArrayOfSize(int size) {
        int[] array = new int[size];
        for (int i = 0; i < array.length; i++) {
            array[i] = readVarInt();
        }
        return array;
    }

    /**
     * Writes each element as a VarInt (no length prefix).
     */
    public void writeVarIntArrayOfSize(int[] array) {
        for (int i : array) {
            writeVarInt(i);
        }
    }

    /**
     * Reads a VarInt count followed by that many longs.
     *
     * @throws IllegalStateException if the count exceeds the readable bytes divided by 8
     */
    public long[] readLongArray() {
        int readableBytes = buffer.readableBytes() / 8;
        int size = readVarInt();
        if (size > readableBytes) {
            throw new IllegalStateException("LongArray with size " + size + " is bigger than allowed " + readableBytes);
        }
        long[] array = new long[size];

        for (int i = 0; i < array.length; i++) {
            array[i] = readLong();
        }
        return array;
    }

    /**
     * Writes {@code array.length} as a VarInt followed by each long.
     */
    public void writeLongArray(long[] array) {
        writeVarInt(array.length);
        for (long l : array) {
            writeLong(l);
        }
    }

    /**
     * Reads a UUID as two longs: most significant bits first, then least significant bits.
     */
    public UUID readUUID() {
        long mostSigBits = readLong();
        long leastSigBits = readLong();
        return new UUID(mostSigBits, leastSigBits);
    }

    /**
     * Writes a UUID as two longs: most significant bits first, then least significant bits.
     */
    public void writeUUID(UUID uuid) {
        writeLong(uuid.getMostSignificantBits());
        writeLong(uuid.getLeastSignificantBits());
    }

    /**
     * Reads a block position packed into one long. Unpacking is delegated to {@link Vector3i}
     * for the current server version.
     */
    public Vector3i readBlockPosition() {
        long val = readLong();
        return new Vector3i(val, serverVersion);
    }

    /**
     * Writes a block position packed into one long by {@link Vector3i} for the current server version.
     */
    public void writeBlockPosition(Vector3i pos) {
        long val = pos.getSerializedPosition(serverVersion);
        writeLong(val);
    }

    /**
     * Reads a game mode from a single byte id.
     */
    public GameMode readGameMode() {
        return GameMode.getById(readByte());
    }

    /**
     * Writes a game mode as a single byte id; {@code null} is written as {@code -1}.
     */
    public void writeGameMode(@Nullable GameMode mode) {
        int id = mode == null ? -1 : mode.getId();
        writeByte(id);
    }

    /**
     * Reads an entity metadata list.
     * <ul>
     *     <li>1.9+: repeated (unsigned byte index, type id, value), terminated by index 255. The type id
     *     is a VarInt on 1.10+ and an unsigned byte on 1.9.</li>
     *     <li>Older: repeated (packed byte, value), terminated by byte 127. The packed byte holds the type
     *     id in its top 3 bits and the index in its low 5 bits.</li>
     * </ul>
     * Each value is decoded by its {@link EntityDataType}'s deserializer.
     */
    public List<EntityData> readEntityMetadata() {
        List<EntityData> list = new ArrayList<>();
        if (serverVersion.isNewerThanOrEquals(ServerVersion.V_1_9)) {
            boolean v1_10 = serverVersion.isNewerThanOrEquals(ServerVersion.V_1_10);
            short index;
            while ((index = readUnsignedByte()) != 255) {
                int typeID = v1_10 ? readVarInt() : readUnsignedByte();
                EntityDataType<?> type = EntityDataTypes.getById(typeID);
                Object value = type.getDataDeserializer().apply(this);
                list.add(new EntityData(index, type, value));
            }
        } else {
            for (byte data = readByte(); data != 127; data = readByte()) {
                int typeID = (data & 224) >> 5;
                int index = data & 31;
                EntityDataType<?> type = EntityDataTypes.getById(typeID);
                Object value = type.getDataDeserializer().apply(this);
                EntityData entityData = new EntityData(index, type, value);
                list.add(entityData);
            }
        }
        return list;
    }

    /**
     * Writes an entity metadata list in the format read by {@link #readEntityMetadata()},
     * including the version-specific terminator.
     */
    public void writeEntityMetadata(List<EntityData> list) {
        if (serverVersion.isNewerThanOrEquals(ServerVersion.V_1_9)) {
            boolean v1_10 = serverVersion.isNewerThanOrEquals(ServerVersion.V_1_10);
            for (EntityData entityData : list) {
                writeByte(entityData.getIndex());
                if (v1_10) {
                    writeVarInt(entityData.getType().getId());
                } else {
                    writeByte(entityData.getType().getId());
                }
                entityData.getType().getDataSerializer().accept(this, entityData.getValue());
            }
            writeByte(255); // End of metadata array
        } else {
            for (EntityData entityData : list) {
                int typeID = entityData.getType().getId();
                int index = entityData.getIndex();
                // Pack the type id into the top 3 bits and the index into the low 5 bits.
                int data = (typeID << 5 | index & 31) & 255;
                writeByte(data);
                entityData.getType().getDataSerializer().accept(this, entityData.getValue());
            }
            writeByte(127); // End of metadata array
        }
    }
}
