diff --git a/src/main/java/io/wispforest/owo/client/screens/ScreenInternals.java b/src/main/java/io/wispforest/owo/client/screens/ScreenInternals.java index bcdf33c14..99cb03513 100644 --- a/src/main/java/io/wispforest/owo/client/screens/ScreenInternals.java +++ b/src/main/java/io/wispforest/owo/client/screens/ScreenInternals.java @@ -1,8 +1,11 @@ package io.wispforest.owo.client.screens; +import io.wispforest.endec.StructEndec; import io.wispforest.endec.impl.StructEndecBuilder; import io.wispforest.owo.Owo; import io.wispforest.endec.Endec; +import io.wispforest.owo.network.OwoHandshake; +import io.wispforest.owo.ops.TextOps; import io.wispforest.owo.serialization.CodecUtils; import io.wispforest.owo.serialization.endec.MinecraftEndecs; import io.wispforest.owo.util.pond.OwoScreenHandlerExtension; @@ -11,16 +14,27 @@ import net.fabricmc.fabric.api.client.networking.v1.ClientPlayNetworking; import net.fabricmc.fabric.api.client.screen.v1.ScreenEvents; import net.fabricmc.fabric.api.networking.v1.PayloadTypeRegistry; +import net.fabricmc.fabric.api.networking.v1.ServerConfigurationNetworking; import net.fabricmc.fabric.api.networking.v1.ServerPlayNetworking; import net.minecraft.client.gui.screen.ingame.ScreenHandlerProvider; import net.minecraft.network.PacketByteBuf; import net.minecraft.network.packet.CustomPayload; +import net.minecraft.registry.Registries; +import net.minecraft.screen.ScreenHandler; +import net.minecraft.screen.ScreenHandlerType; +import net.minecraft.server.network.ServerPlayerEntity; +import net.minecraft.text.Text; +import net.minecraft.text.Texts; import net.minecraft.util.Identifier; import org.jetbrains.annotations.ApiStatus; +import java.util.*; + @ApiStatus.Internal public class ScreenInternals { public static final Identifier SYNC_PROPERTIES = Identifier.of("owo", "sync_screen_handler_properties"); + public static final Identifier HANDSHAKE_REQUEST = Identifier.of("owo", "request_screen_handler_messages"); + public static final Identifier HANDSHAKE_RESPONSE = Identifier.of("owo", "response_screen_handler_messages"); public static void init() { var localPacketCodec = CodecUtils.toPacketCodec(LocalPacket.ENDEC); @@ -29,6 +43,9 @@ public static void init() { PayloadTypeRegistry.playC2S().register(LocalPacket.ID, localPacketCodec); PayloadTypeRegistry.playS2C().register(SyncPropertiesPacket.ID, CodecUtils.toPacketCodec(SyncPropertiesPacket.ENDEC)); + PayloadTypeRegistry.playS2C().register(HandshakeRequest.ID, CodecUtils.toPacketCodec(HandshakeRequest.ENDEC)); + PayloadTypeRegistry.playC2S().register(HandshakeResponse.ID, CodecUtils.toPacketCodec(HandshakeResponse.ENDEC)); + ServerPlayNetworking.registerGlobalReceiver(LocalPacket.ID, (payload, context) -> { var screenHandler = context.player().currentScreenHandler; @@ -39,6 +56,22 @@ public static void init() { ((OwoScreenHandlerExtension) screenHandler).owo$handlePacket(payload, false); }); + + ServerPlayNetworking.registerGlobalReceiver(HandshakeResponse.ID, (payload, context) -> { + var screenHandler = context.player().currentScreenHandler; + + if (screenHandler == null) { + Owo.LOGGER.error("[ScreenHandlerHandshake] Received handshake response for null ScreenHandler"); + return; + } + + if (!payload.type().equals(screenHandler.getType())) { + Owo.LOGGER.error("[ScreenHandlerHandshake] Received handshake response packet for different ScreenHandler type: [Expected Type: {}, Current Type: {}]", payload.type(), screenHandler.getType()); + return; + } + + ((OwoScreenHandlerExtension) screenHandler).owo$verifyData(context.player(), payload.messageNames()); + }); } public record LocalPacket(int packetId, PacketByteBuf payload) implements CustomPayload { @@ -68,6 +101,48 @@ public Id getId() { } } + public static void attemptHandshake(ScreenHandlerType type, ServerPlayerEntity player) { + if (type == null) return; + + if (ServerPlayNetworking.canSend(player, OwoHandshake.OFF_CHANNEL_ID)) { + Owo.LOGGER.info("[ScreenHandlerHandshake] Handshake disabled by client, skipping"); + return; + } + + try { + ServerPlayNetworking.send(player, new HandshakeRequest(type)); + } catch (Exception e) { + Owo.LOGGER.error("[ScreenHandlerHandshake] Unable to Handshake check handler as getting the type encountered an error: ", e); + } + } + + private record HandshakeRequest(ScreenHandlerType type) implements CustomPayload { + public static final Id ID = new Id<>(HANDSHAKE_REQUEST); + public static final Endec ENDEC = StructEndecBuilder.of( + MinecraftEndecs.ofRegistry(Registries.SCREEN_HANDLER).fieldOf("type", HandshakeRequest::type), + HandshakeRequest::new + ); + + @Override + public Id getId() { + return ID; + } + } + + private record HandshakeResponse(ScreenHandlerType type, LinkedHashSet messageNames) implements CustomPayload { + public static final CustomPayload.Id ID = new CustomPayload.Id<>(HANDSHAKE_RESPONSE); + public static final Endec ENDEC = StructEndecBuilder.of( + MinecraftEndecs.ofRegistry(Registries.SCREEN_HANDLER).fieldOf("type", HandshakeResponse::type), + Endec.STRING.listOf().xmap(LinkedHashSet::new, ArrayList::new).fieldOf("message_names", HandshakeResponse::messageNames), + HandshakeResponse::new + ); + + @Override + public CustomPayload.Id getId() { + return ID; + } + } + @Environment(EnvType.CLIENT) public static class Client { public static void init() { @@ -97,6 +172,22 @@ public static void init() { ((OwoScreenHandlerExtension) screenHandler).owo$readPropertySync(payload); }); + + ClientPlayNetworking.registerGlobalReceiver(HandshakeRequest.ID, (payload, context) -> { + var screenHandler = context.player().currentScreenHandler; + + if (screenHandler == null) { + Owo.LOGGER.error("[ScreenHandlerHandshake] Received handshake request packet for null ScreenHandler"); + return; + } + + if (!payload.type().equals(screenHandler.getType())) { + Owo.LOGGER.error("[ScreenHandlerHandshake] Received handshake request packet for different ScreenHandler type: [Expected Type: {}, Current Type: {}]", payload.type(), screenHandler.getType()); + return; + } + + context.responseSender().sendPacket(new HandshakeResponse(screenHandler.getType(), ((OwoScreenHandlerExtension) screenHandler).owo$gatherMessageNames())); + }); } } } diff --git a/src/main/java/io/wispforest/owo/client/screens/ScreenhandlerMessageData.java b/src/main/java/io/wispforest/owo/client/screens/ScreenhandlerMessageData.java index 23a6d1318..c0d478d2f 100644 --- a/src/main/java/io/wispforest/owo/client/screens/ScreenhandlerMessageData.java +++ b/src/main/java/io/wispforest/owo/client/screens/ScreenhandlerMessageData.java @@ -6,4 +6,8 @@ import java.util.function.Consumer; @ApiStatus.Internal -public record ScreenhandlerMessageData(int id, boolean clientbound, Endec endec, Consumer handler) {} +public record ScreenhandlerMessageData(int id, Class messageClass, boolean clientbound, Endec endec, Consumer handler) { + public String messageName() { + return messageClass.getSimpleName(); + } +} diff --git a/src/main/java/io/wispforest/owo/mixin/ScreenHandlerMixin.java b/src/main/java/io/wispforest/owo/mixin/ScreenHandlerMixin.java index 6b3d7385b..048f84448 100644 --- a/src/main/java/io/wispforest/owo/mixin/ScreenHandlerMixin.java +++ b/src/main/java/io/wispforest/owo/mixin/ScreenHandlerMixin.java @@ -2,12 +2,15 @@ import io.wispforest.endec.SerializationContext; import io.wispforest.endec.impl.ReflectiveEndecBuilder; +import io.wispforest.owo.Owo; import io.wispforest.owo.client.screens.OwoScreenHandler; import io.wispforest.owo.client.screens.ScreenInternals; import io.wispforest.owo.client.screens.ScreenhandlerMessageData; import io.wispforest.owo.client.screens.SyncedProperty; import io.wispforest.owo.network.NetworkException; import io.wispforest.endec.Endec; +import io.wispforest.owo.network.OwoHandshake; +import io.wispforest.owo.network.OwoNetChannel; import io.wispforest.owo.serialization.RegistriesAttribute; import io.wispforest.owo.serialization.endec.MinecraftEndecs; import io.wispforest.owo.util.pond.OwoScreenHandlerExtension; @@ -19,9 +22,15 @@ import net.minecraft.entity.player.PlayerEntity; import net.minecraft.network.packet.CustomPayload; import net.minecraft.screen.ScreenHandler; +import net.minecraft.screen.ScreenHandlerSyncHandler; import net.minecraft.screen.ScreenHandlerType; import net.minecraft.server.network.ServerPlayerEntity; +import net.minecraft.text.Text; +import net.minecraft.text.Texts; +import org.apache.commons.lang3.stream.Streams; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.spongepowered.asm.mixin.Final; import org.spongepowered.asm.mixin.Mixin; import org.spongepowered.asm.mixin.Shadow; import org.spongepowered.asm.mixin.Unique; @@ -29,17 +38,21 @@ import org.spongepowered.asm.mixin.injection.Inject; import org.spongepowered.asm.mixin.injection.callback.CallbackInfo; -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.function.Consumer; +import java.util.stream.Collectors; @Mixin(ScreenHandler.class) public abstract class ScreenHandlerMixin implements OwoScreenHandler, OwoScreenHandlerExtension { @Shadow private boolean disableSync; + @Shadow + public abstract ScreenHandlerType getType(); + + @Shadow + @Final + private @Nullable ScreenHandlerType type; private final List> owo$properties = new ArrayList<>(); private final Map, ScreenhandlerMessageData> owo$messages = new LinkedHashMap<>(); @@ -75,7 +88,7 @@ public PlayerEntity player() { public void addServerboundMessage(Class messageClass, Endec endec, Consumer handler) { int id = this.owo$serverboundMessages.size(); - var messageData = new ScreenhandlerMessageData<>(id, false, endec, handler); + var messageData = new ScreenhandlerMessageData<>(id, messageClass, false, endec, handler); this.owo$serverboundMessages.add(messageData); if (this.owo$messages.put(messageClass, messageData) != null) { @@ -87,7 +100,7 @@ public void addServerboundMessage(Class messageClass, Ende public void addClientboundMessage(Class messageClass, Endec endec, Consumer handler) { int id = this.owo$clientboundMessages.size(); - var messageData = new ScreenhandlerMessageData<>(id, true, endec, handler); + var messageData = new ScreenhandlerMessageData<>(id, messageClass, true, endec, handler); this.owo$clientboundMessages.add(messageData); if (this.owo$messages.put(messageClass, messageData) != null) { @@ -116,13 +129,13 @@ public void sendMessage(@NotNull R message) { if (messageData.clientbound()) { if (!(this.owo$player instanceof ServerPlayerEntity serverPlayer)) { - throw new NetworkException("Tried to send clientbound message on the server"); + throw new NetworkException("Tried to send clientbound message on the server: [Type: " + message.getClass().getSimpleName() + "]"); } ServerPlayNetworking.send(serverPlayer, packet); } else { if (!this.owo$player.getWorld().isClient) { - throw new NetworkException("Tried to send serverbound message on the client"); + throw new NetworkException("Tried to send serverbound message on the client: [Type: " + message.getClass().getSimpleName() + "]"); } this.owo$sendToServer(packet); @@ -138,12 +151,45 @@ public void sendMessage(@NotNull R message) { @Override @SuppressWarnings({"rawtypes", "unchecked"}) public void owo$handlePacket(ScreenInternals.LocalPacket packet, boolean clientbound) { - ScreenhandlerMessageData messageData = (clientbound ? this.owo$clientboundMessages : this.owo$serverboundMessages).get(packet.packetId()); + var messages = (clientbound ? this.owo$clientboundMessages : this.owo$serverboundMessages); + + if (packet.packetId() < 0 || packet.packetId() >= messages.size()) { + throw new NetworkException("Unable to handle packet as it was not properly registered on the [" + (clientbound ? "CLIENT" : "SERVER") + "]"); + } + + ScreenhandlerMessageData messageData = messages.get(packet.packetId()); var ctx = SerializationContext.attributes(RegistriesAttribute.of(this.owo$player.getRegistryManager())); messageData.handler().accept(packet.payload().read(ctx, messageData.endec())); } + @Inject(method = "updateSyncHandler", at = @At("HEAD")) + private void compareHandlersNetworking(ScreenHandlerSyncHandler handler, CallbackInfo ci) { + if (!(player() instanceof ServerPlayerEntity serverPlayer)) return; + + ScreenInternals.attemptHandshake(this.type, serverPlayer); + } + + @Override + public LinkedHashSet owo$gatherMessageNames() { + return Streams.of(this.owo$clientboundMessages, this.owo$serverboundMessages) + .flatMap(Collection::stream) + .map(ScreenhandlerMessageData::messageName) + .collect(Collectors.toCollection(LinkedHashSet::new)); + } + + @Override + public void owo$verifyData(ServerPlayerEntity player, Set clientMessageNames) { + var errorMessage = new StringBuilder(); + + if (OwoHandshake.checkForMismatchStrIds("screen_handler_messages", clientMessageNames, owo$gatherMessageNames(), errorMessage)) return; + + player.closeHandledScreen(); + + player.sendMessage(Texts.join(List.of(Owo.PREFIX, Text.of("Unable to open screen as there was a message mismatch:")), Text.empty())); + player.sendMessage(Text.of(errorMessage.toString())); + } + @Override public SyncedProperty createProperty(Class clazz, Endec endec, T initial) { var prop = new SyncedProperty<>(this.owo$properties.size(), endec, initial, (ScreenHandler)(Object) this); diff --git a/src/main/java/io/wispforest/owo/network/OwoHandshake.java b/src/main/java/io/wispforest/owo/network/OwoHandshake.java index e4f486839..da2957855 100644 --- a/src/main/java/io/wispforest/owo/network/OwoHandshake.java +++ b/src/main/java/io/wispforest/owo/network/OwoHandshake.java @@ -16,6 +16,7 @@ import net.fabricmc.fabric.api.client.networking.v1.ClientConfigurationConnectionEvents; import net.fabricmc.fabric.api.client.networking.v1.ClientConfigurationNetworking; import net.fabricmc.fabric.api.client.networking.v1.ClientPlayConnectionEvents; +import net.fabricmc.fabric.api.client.networking.v1.ClientPlayNetworking; import net.fabricmc.fabric.api.networking.v1.PayloadTypeRegistry; import net.fabricmc.fabric.api.networking.v1.ServerConfigurationConnectionEvents; import net.fabricmc.fabric.api.networking.v1.ServerConfigurationNetworking; @@ -29,7 +30,6 @@ import net.minecraft.text.MutableText; import net.minecraft.text.Text; import net.minecraft.util.Identifier; -import net.minecraft.util.Pair; import org.jetbrains.annotations.ApiStatus; import java.util.HashMap; @@ -37,6 +37,7 @@ import java.util.Map; import java.util.Set; import java.util.function.ToIntFunction; +import java.util.stream.Collectors; @ApiStatus.Internal public final class OwoHandshake { @@ -47,7 +48,7 @@ public final class OwoHandshake { public static final Identifier CHANNEL_ID = Identifier.of("owo", "handshake"); public static final Identifier OFF_CHANNEL_ID = Identifier.of("owo", "handshake_off"); - private static final boolean ENABLED = System.getProperty("owo.handshake.enabled") != null ? Boolean.getBoolean("owo.handshake.enabled") : Owo.DEBUG; + public static final boolean ENABLED = System.getProperty("owo.handshake.enabled") != null ? Boolean.getBoolean("owo.handshake.enabled") : Owo.DEBUG; private static boolean HANDSHAKE_REQUIRED = false; private static boolean QUERY_RECEIVED = false; @@ -82,6 +83,9 @@ public static void requireHandshake() { if (!ENABLED) { PayloadTypeRegistry.configurationS2C().register(HandshakeOff.ID, PacketCodec.unit(new HandshakeOff())); ClientConfigurationNetworking.registerGlobalReceiver(HandshakeOff.ID, (payload, context) -> {}); + + PayloadTypeRegistry.playC2S().register(HandshakeOff.ID, PacketCodec.unit(new HandshakeOff())); + ClientPlayNetworking.registerGlobalReceiver(HandshakeOff.ID, (payload, context) -> {}); } ClientConfigurationNetworking.registerGlobalReceiver(HandshakeRequest.ID, OwoHandshake::syncClient); @@ -190,23 +194,7 @@ private static Set filterOptionalServices(Map boolean verifyReceivedHashes(String serviceNamePlural, Map clientMap, Map serverMap, ToIntFunction hashFunction, StringBuilder disconnectMessage) { - boolean isAllGood = true; - - if (!clientMap.keySet().equals(serverMap.keySet())) { - isAllGood = false; - - var leftovers = findCollisions(clientMap.keySet(), serverMap.keySet()); - - if (!leftovers.getLeft().isEmpty()) { - disconnectMessage.append("server is missing ").append(serviceNamePlural).append(":\n"); - leftovers.getLeft().forEach(identifier -> disconnectMessage.append("§7").append(identifier).append("§r\n")); - } - - if (!leftovers.getRight().isEmpty()) { - disconnectMessage.append("client is missing ").append(serviceNamePlural).append(":\n"); - leftovers.getRight().forEach(identifier -> disconnectMessage.append("§7").append(identifier).append("§r\n")); - } - } + boolean isAllGood = checkForMismatchIds(serviceNamePlural, clientMap.keySet(), serverMap.keySet(), disconnectMessage); boolean hasMismatchedHashes = false; for (var entry : clientMap.entrySet()) { @@ -240,9 +228,38 @@ private static Map formatHashes(Map valu return hashes; } - private static Pair, Set> findCollisions(Set first, Set second) { - var firstLeftovers = new HashSet(); - var secondLeftovers = new HashSet(); + private record PacketMismatches(Set missingServerPackets, Set missingClientPackets) {} + + private static boolean checkForMismatchIds(String serviceNamePlural, Set clientIds, Set serverIds, StringBuilder disconnectMessage) { + return checkForMismatchStrIds(serviceNamePlural, + clientIds.stream().map(Identifier::toString).collect(Collectors.toSet()), + serverIds.stream().map(Identifier::toString).collect(Collectors.toSet()), + disconnectMessage); + } + + public static boolean checkForMismatchStrIds(String serviceNamePlural, Set clientIds, Set serverIds, StringBuilder disconnectMessage) { + if (!clientIds.equals(serverIds)) { + var mismatches = findCollisions(clientIds, serverIds); + + if (!mismatches.missingServerPackets().isEmpty()) { + disconnectMessage.append("server is missing ").append(serviceNamePlural).append(":\n"); + mismatches.missingServerPackets().forEach(identifier -> disconnectMessage.append("§7").append(identifier).append("§r\n")); + } + + if (!mismatches.missingClientPackets().isEmpty()) { + disconnectMessage.append("client is missing ").append(serviceNamePlural).append(":\n"); + mismatches.missingClientPackets().forEach(identifier -> disconnectMessage.append("§7").append(identifier).append("§r\n")); + } + + return false; + } + + return true; + } + + private static PacketMismatches findCollisions(Set first, Set second) { + var firstLeftovers = new HashSet(); + var secondLeftovers = new HashSet(); first.forEach(identifier -> { if (!second.contains(identifier)) firstLeftovers.add(identifier); @@ -252,7 +269,7 @@ private static Pair, Set> findCollisions(Set(firstLeftovers, secondLeftovers); + return new PacketMismatches(firstLeftovers, secondLeftovers); } private static int hashChannel(OwoNetChannel channel) { diff --git a/src/main/java/io/wispforest/owo/util/pond/OwoScreenHandlerExtension.java b/src/main/java/io/wispforest/owo/util/pond/OwoScreenHandlerExtension.java index 5149bf905..e7077f738 100644 --- a/src/main/java/io/wispforest/owo/util/pond/OwoScreenHandlerExtension.java +++ b/src/main/java/io/wispforest/owo/util/pond/OwoScreenHandlerExtension.java @@ -2,6 +2,10 @@ import io.wispforest.owo.client.screens.ScreenInternals; import net.minecraft.entity.player.PlayerEntity; +import net.minecraft.server.network.ServerPlayerEntity; + +import java.util.LinkedHashSet; +import java.util.Set; public interface OwoScreenHandlerExtension { void owo$attachToPlayer(PlayerEntity player); @@ -9,4 +13,8 @@ public interface OwoScreenHandlerExtension { void owo$readPropertySync(ScreenInternals.SyncPropertiesPacket packet); void owo$handlePacket(ScreenInternals.LocalPacket packet, boolean clientbound); + + void owo$verifyData(ServerPlayerEntity player, Set clientMessageNames); + + LinkedHashSet owo$gatherMessageNames(); }