diff --git a/core/constants/src/mill/constants/InputPumper.java b/core/constants/src/mill/constants/InputPumper.java index bc016b89302a..2eb006747739 100644 --- a/core/constants/src/mill/constants/InputPumper.java +++ b/core/constants/src/mill/constants/InputPumper.java @@ -4,7 +4,7 @@ import java.io.OutputStream; import java.util.function.Supplier; -/// A `Runnable` that reads from `src` and writes to `dest`. +/** A `Runnable` that reads from `src` and writes to `dest`. */ public class InputPumper implements Runnable { private final Supplier src0; private final Supplier dest0; @@ -15,7 +15,7 @@ public class InputPumper implements Runnable { /// and there is nothing to read, [it can unnecessarily delay the JVM exit by 350ms]( /// // https://stackoverflow.com/questions/48951611/blocking-on-stdin-makes-java-process-take-350ms-more-to-exit) - private final Boolean checkAvailable; + private final boolean checkAvailable; public InputPumper( Supplier src, Supplier dest, Boolean checkAvailable) { @@ -28,29 +28,29 @@ public InputPumper( @Override public void run() { - InputStream src = src0.get(); - OutputStream dest = dest0.get(); + var src = src0.get(); + var dest = dest0.get(); - byte[] buffer = new byte[1024]; + var buffer = new byte[1024 /* 1kb */]; try { while (running) { if (checkAvailable && src.available() == 0) //noinspection BusyWait Thread.sleep(1); else { - int n; + int bytesRead; try { - n = src.read(buffer); + bytesRead = src.read(buffer); } catch (Exception e) { - n = -1; + bytesRead = -1; } - if (n == -1) running = false; - else if (n == 0) + if (bytesRead == -1) running = false; + else if (bytesRead == 0) //noinspection BusyWait Thread.sleep(1); else { try { - dest.write(buffer, 0, n); + dest.write(buffer, 0, bytesRead); dest.flush(); } catch (java.io.IOException e) { running = false; diff --git a/core/constants/src/mill/constants/ProxyStream.java b/core/constants/src/mill/constants/ProxyStream.java index 64f4b2a8a6f1..d5970e9668c1 100644 --- a/core/constants/src/mill/constants/ProxyStream.java +++ b/core/constants/src/mill/constants/ProxyStream.java @@ -1,9 +1,9 @@ package mill.constants; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; +import java.io.*; import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; /// Logic to capture a pair of streams (typically stdout and stderr), combining /// them into a single stream, and splitting it back into two streams later while @@ -15,114 +15,194 @@ /// the form: /// ``` /// 1 byte n bytes -/// | header | body | +/// | header | frame | /// ``` /// /// Where header is a single byte of the form: /// -/// - header more than 0 indicating that this packet is for the `OUT` stream -/// - header less than 0 indicating that this packet is for the `ERR` stream -/// - abs(header) indicating the length of the packet body, in bytes -/// - header == 0 indicating the end of the stream +/// - [#HEADER_STREAM_OUT]/[#HEADER_STREAM_ERR] respectively indicating that this packet is for +/// the `OUT`/`ERR` +/// stream, and it will be followed by 4 bytes for the length of the body and then the body. +/// - [#HEADER_STREAM_OUT_SINGLE_BYTE]/[#HEADER_STREAM_ERR_SINGLE_BYTE] respectively indicating +/// that this packet is +/// for the `OUT`/`ERR` stream, and it will be followed by a single byte for the body +/// - [#HEADER_HEARTBEAT] indicating that this packet is a heartbeat and will be ignored +/// - [#HEADER_END] indicating the end of the stream /// /// /// Writes to either of the two `Output`s are synchronized on the shared /// `destination` stream, ensuring that they always arrive complete and without /// interleaving. On the other side, a `Pumper` reads from the combined /// stream, forwards each packet to its respective destination stream, or terminates -/// when it hits a packet with `header == 0` +/// when it hits a packet with [#HEADER_END]. public class ProxyStream { + private static final int MAX_CHUNK_SIZE = 4 * 1024; // 4kb - public static final int OUT = 1; - public static final int ERR = -1; - public static final int END = 0; - public static final int HEARTBEAT = 127; + // The values are picked to make it a bit easier to spot when debugging the hex dump. - private static boolean clientHasClosedConnection(SocketException e) { - var message = e.getMessage(); - return message != null && message.contains("Broken pipe"); + /** The header for the output stream */ + private static final byte HEADER_STREAM_OUT = 26; // 0x1A + + // bincompat forwarder + @SuppressWarnings("unused") + public static final int OUT = HEADER_STREAM_OUT; + + /** The header for the output stream when a single byte is sent. */ + private static final byte HEADER_STREAM_OUT_SINGLE_BYTE = 27; // 0x1B, B as in BYTE + + /** The header for the error stream */ + private static final byte HEADER_STREAM_ERR = 42; // 0x2A + + // bincompat forwarder + @SuppressWarnings("unused") + public static final int ERR = HEADER_STREAM_ERR; + + /** The header for the error stream when a single byte is sent. */ + private static final byte HEADER_STREAM_ERR_SINGLE_BYTE = 43; // 0x2B, B as in BYTE + + /** A heartbeat packet to keep the connection alive. */ + private static final byte HEADER_HEARTBEAT = 123; // 0x7B, B as in BEAT + + // bincompat forwarder + @SuppressWarnings("unused") + public static final int HEARTBEAT = HEADER_HEARTBEAT; + + /** Indicates the end of the connection. */ + private static final byte HEADER_END = 126; // 0x7E, E as in END + + // bincompat forwarder + @SuppressWarnings("unused") + public static final int END = HEADER_END; + + public enum StreamType { + /** The output stream */ + OUT(ProxyStream.HEADER_STREAM_OUT, ProxyStream.HEADER_STREAM_OUT_SINGLE_BYTE), + /** The error stream */ + ERR(ProxyStream.HEADER_STREAM_ERR, ProxyStream.HEADER_STREAM_ERR_SINGLE_BYTE); + public final byte header, headerSingleByte; + + StreamType(byte header, byte headerSingleByte) { + this.header = header; + this.headerSingleByte = headerSingleByte; + } } public static void sendEnd(OutputStream out, int exitCode) throws IOException { synchronized (out) { try { - out.write(ProxyStream.END); - out.write(exitCode); + var buffer = new byte[5]; + ByteBuffer.wrap(buffer) + .order(ByteOrder.BIG_ENDIAN) + .put(ProxyStream.HEADER_END) + .putInt(exitCode); + out.write(buffer); out.flush(); } catch (SocketException e) { // If the client has already closed the connection, we don't really care about sending the // exit code to it. - if (!clientHasClosedConnection(e)) throw e; + if (!SocketUtil.clientHasClosedConnection(e)) throw e; } } } public static void sendHeartbeat(OutputStream out) throws IOException { synchronized (out) { - out.write(ProxyStream.HEARTBEAT); + out.write(ProxyStream.HEADER_HEARTBEAT); out.flush(); } } public static class Output extends java.io.OutputStream { - private final java.io.OutputStream destination; - private final int key; + /** + * Object used for synchronization so that our writes wouldn't interleave. + *

+ * We can't use {@link #destination} because it's a private object that we create here and {@link #sendEnd} + * and {@link #sendHeartbeat} use a different object. + **/ + private final java.io.OutputStream synchronizer; + private final DataOutputStream destination; + private final StreamType streamType; + + public Output(java.io.OutputStream out, StreamType streamType) { + this.synchronizer = out; + this.destination = new DataOutputStream(out); + this.streamType = streamType; + } + + // bincompat forwarder public Output(java.io.OutputStream out, int key) { - this.destination = out; - this.key = key; + this(out, key == OUT ? StreamType.OUT : StreamType.ERR); } @Override public void write(int b) throws IOException { - synchronized (destination) { - destination.write(key); + synchronized (synchronizer) { + destination.write(streamType.headerSingleByte); destination.write(b); } } @Override public void write(byte[] b) throws IOException { - if (b.length > 0) { - synchronized (destination) { + switch (b.length) { + case 0: + return; + case 1: + write(b[0]); + break; + default: write(b, 0, b.length); - } + break; } } @Override - public void write(byte[] b, int off, int len) throws IOException { - - synchronized (destination) { - int i = 0; - while (i < len && i + off < b.length) { - int chunkLength = Math.min(len - i, 126); - if (chunkLength > 0) { - destination.write(chunkLength * key); - destination.write(b, off + i, Math.min(b.length - off - i, chunkLength)); - i += chunkLength; - } + public void write(byte[] sourceBuffer, int offset, int len) throws IOException { + // Validate arguments once at the beginning, which is cleaner + // and standard practice for public methods. + if (sourceBuffer == null) throw new NullPointerException("byte array is null"); + if (offset < 0 || offset > sourceBuffer.length) + throw new IndexOutOfBoundsException("Write offset out of range: " + offset); + if (len < 0) throw new IndexOutOfBoundsException("Write length is negative: " + len); + if (offset + len > sourceBuffer.length) + throw new IndexOutOfBoundsException("Write goes beyond end of buffer: offset=" + offset + + ", len=" + len + ", end=" + (offset + len) + " > " + sourceBuffer.length); + + synchronized (synchronizer) { + var bytesRemaining = len; + var currentOffset = offset; + + while (bytesRemaining > 0) { + var chunkSize = Math.min(bytesRemaining, MAX_CHUNK_SIZE); + + destination.writeByte(streamType.header); + destination.writeInt(chunkSize); + destination.write(sourceBuffer, currentOffset, chunkSize); + + bytesRemaining -= chunkSize; + currentOffset += chunkSize; } } } @Override public void flush() throws IOException { - synchronized (destination) { + synchronized (synchronizer) { destination.flush(); } } @Override public void close() throws IOException { - synchronized (destination) { + synchronized (synchronizer) { destination.close(); } } } public static class Pumper implements Runnable { - private final InputStream src; + private final DataInputStream src; private final OutputStream destOut; private final OutputStream destErr; private final Object synchronizer; @@ -130,7 +210,7 @@ public static class Pumper implements Runnable { public Pumper( InputStream src, OutputStream destOut, OutputStream destErr, Object synchronizer) { - this.src = src; + this.src = new DataInputStream(src); this.destOut = destOut; this.destErr = destErr; this.synchronizer = synchronizer; @@ -140,7 +220,13 @@ public Pumper(InputStream src, OutputStream destOut, OutputStream destErr) { this(src, destOut, destErr, new Object()); } - public void preRead(InputStream src) {} + protected void preRead(DataInputStream src) {} + + @Deprecated(forRemoval = true, since = "1.0.4") + public void preRead(InputStream src) { + if (src instanceof DataInputStream) preRead((DataInputStream) src); + else throw new UnsupportedOperationException("preRead(InputStream) is deprecated"); + } public void write(OutputStream dest, byte[] buffer, int length) throws IOException { dest.write(buffer, 0, length); @@ -148,54 +234,40 @@ public void write(OutputStream dest, byte[] buffer, int length) throws IOExcepti @Override public void run() { - - byte[] buffer = new byte[1024]; - while (true) { - try { + var buffer = new byte[MAX_CHUNK_SIZE]; + try { + readLoop: + while (true) { this.preRead(src); - int header = src.read(); - // -1 means socket was closed, 0 means a ProxyStream.END was sent. Note - // that only header values > 0 represent actual data to read: - // - sign((byte)header) represents which stream the data should be sent to - // - abs((byte)header) represents the length of the data to read and send - if (header == -1) break; - else if (header == END) { - exitCode = src.read(); - break; - } else if (header == HEARTBEAT) continue; - else { - int stream = (byte) header > 0 ? 1 : -1; - int quantity0 = (byte) header; - int quantity = Math.abs(quantity0); - int offset = 0; - int delta = -1; - while (offset < quantity) { - this.preRead(src); - delta = src.read(buffer, offset, quantity - offset); - if (delta == -1) { - break; - } else { - offset += delta; - } - } - - if (delta != -1) { - synchronized (synchronizer) { - switch (stream) { - case ProxyStream.OUT: - this.write(destOut, buffer, offset); - break; - case ProxyStream.ERR: - this.write(destErr, buffer, offset); - break; - } - } - } + var header = src.readByte(); + + switch (header) { + case HEADER_END: + exitCode = src.readInt(); + break readLoop; + case HEADER_HEARTBEAT: + continue; + case HEADER_STREAM_OUT: + pumpData(buffer, false, destOut); + break; + case HEADER_STREAM_OUT_SINGLE_BYTE: + pumpData(buffer, true, destOut); + break; + case HEADER_STREAM_ERR: + pumpData(buffer, false, destErr); + break; + case HEADER_STREAM_ERR_SINGLE_BYTE: + pumpData(buffer, true, destErr); + break; + default: + throw new IllegalStateException("Unexpected header: " + header); } - } catch (IOException e) { - // This happens when the upstream pipe was closed - break; } + } catch (EOFException ignored) { + // This is a normal and expected way for the loop to terminate + // when the other side closes the connection. + } catch (IOException ignored) { + // This happens when the upstream pipe was closed } try { @@ -207,6 +279,35 @@ else if (header == END) { } } + private void pumpData(byte[] buffer, boolean singleByte, OutputStream stream) + throws IOException { + var quantity = singleByte ? 1 : src.readInt(); + + if (quantity > buffer.length) { + // Handle error: received chunk is larger than buffer + throw new IOException("Received chunk of size " + quantity + + " is larger than buffer of size " + buffer.length); + } + + var totalBytesRead = 0; + var bytesReadThisIteration = -1; + while (totalBytesRead < quantity) { + this.preRead(src); + bytesReadThisIteration = src.read(buffer, totalBytesRead, quantity - totalBytesRead); + if (bytesReadThisIteration == -1) { + break; + } else { + totalBytesRead += bytesReadThisIteration; + } + } + + if (bytesReadThisIteration != -1) { + synchronized (synchronizer) { + this.write(stream, buffer, totalBytesRead); + } + } + } + public void flush() throws IOException { synchronized (synchronizer) { destOut.flush(); diff --git a/core/constants/test/src/mill/client/ProxyStreamTests.java b/core/constants/test/src/mill/client/ProxyStreamTests.java index 9db033e521bd..cdaa732d0491 100644 --- a/core/constants/test/src/mill/client/ProxyStreamTests.java +++ b/core/constants/test/src/mill/client/ProxyStreamTests.java @@ -18,7 +18,7 @@ public void test() throws Exception { // are likely sizes to have bugs since we write data in chunks of size 127 int[] interestingLengths = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 100, 126, 127, 128, 129, 130, 253, 254, 255, - 256, 257, 1000, 2000, 4000, 8000 + 256, 257, 1000, 2000, 4000, 8000, 16000, 32000, 64000, 128000, 256000, }; byte[] interestingBytes = { -1, -127, -126, -120, -100, -80, -60, -40, -20, -10, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 10, @@ -56,8 +56,8 @@ public void test0(byte[] outData, byte[] errData, int repeats, boolean gracefulE pipedInputStream.connect(pipedOutputStream); - ProxyStream.Output srcOut = new ProxyStream.Output(pipedOutputStream, ProxyStream.OUT); - ProxyStream.Output srcErr = new ProxyStream.Output(pipedOutputStream, ProxyStream.ERR); + var srcOut = new ProxyStream.Output(pipedOutputStream, ProxyStream.StreamType.OUT); + var srcErr = new ProxyStream.Output(pipedOutputStream, ProxyStream.StreamType.ERR); // Capture both the destOut/destErr from the pumper, as well as the destCombined // to ensure the individual streams contain the right data and combined stream diff --git a/core/internal/src/mill/internal/PromptLogger.scala b/core/internal/src/mill/internal/PromptLogger.scala index 8040837460e3..29658235d8ae 100644 --- a/core/internal/src/mill/internal/PromptLogger.scala +++ b/core/internal/src/mill/internal/PromptLogger.scala @@ -286,9 +286,9 @@ private[mill] object PromptLogger { // `ProxyStream`, as we need to preserve the ordering of writes to each individual // stream, and also need to know when *both* streams are quiescent so that we can // print the prompt at the bottom - val pipe = new PipeStreams() - val proxyOut = new ProxyStream.Output(pipe.output, ProxyStream.OUT) - val proxyErr: ProxyStream.Output = new ProxyStream.Output(pipe.output, ProxyStream.ERR) + val pipe = PipeStreams() + val proxyOut = ProxyStream.Output(pipe.output, ProxyStream.StreamType.OUT) + val proxyErr = ProxyStream.Output(pipe.output, ProxyStream.StreamType.ERR) val proxySystemStreams = new SystemStreams( new PrintStream(proxyOut), new PrintStream(proxyErr), @@ -330,7 +330,7 @@ private[mill] object PromptLogger { private var lastCharWritten = 0.toChar // Make sure we synchronize everywhere - override def preRead(src: InputStream): Unit = synchronizer.synchronized { + override protected def preRead(src: DataInputStream): Unit = synchronizer.synchronized { if ( enableTicker && diff --git a/core/internal/test/src/mill/internal/PromptLoggerTests.scala b/core/internal/test/src/mill/internal/PromptLoggerTests.scala index 0843dcb55601..536652f7da3d 100644 --- a/core/internal/test/src/mill/internal/PromptLoggerTests.scala +++ b/core/internal/test/src/mill/internal/PromptLoggerTests.scala @@ -9,8 +9,8 @@ object PromptLoggerTests extends TestSuite { def setup(now: () => Long, terminfoPath: os.Path) = { val baos = new ByteArrayOutputStream() - val baosOut = new PrintStream(new ProxyStream.Output(baos, ProxyStream.OUT)) - val baosErr = new PrintStream(new ProxyStream.Output(baos, ProxyStream.ERR)) + val baosOut = new PrintStream(new ProxyStream.Output(baos, ProxyStream.StreamType.OUT)) + val baosErr = new PrintStream(new ProxyStream.Output(baos, ProxyStream.StreamType.ERR)) val promptLogger = new PromptLogger( colored = false, enableTicker = true, diff --git a/libs/daemon/client/test/src/mill/client/ClientTests.java b/libs/daemon/client/test/src/mill/client/ClientTests.java index db8fbd6f2276..2e46188c1f4a 100644 --- a/libs/daemon/client/test/src/mill/client/ClientTests.java +++ b/libs/daemon/client/test/src/mill/client/ClientTests.java @@ -1,11 +1,9 @@ package mill.client; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.OutputStream; import java.util.*; import mill.constants.ProxyStream; import org.junit.Test; @@ -127,8 +125,8 @@ public void proxyInputOutputStreams(byte[] samples1, byte[] samples2, int chunkM throws Exception { ByteArrayOutputStream pipe = new ByteArrayOutputStream(); - OutputStream src1 = new ProxyStream.Output(pipe, ProxyStream.OUT); - OutputStream src2 = new ProxyStream.Output(pipe, ProxyStream.ERR); + var src1 = new ProxyStream.Output(pipe, ProxyStream.StreamType.OUT); + var src2 = new ProxyStream.Output(pipe, ProxyStream.StreamType.ERR); Random random = new Random(31337); @@ -154,7 +152,7 @@ public void proxyInputOutputStreams(byte[] samples1, byte[] samples2, int chunkM ProxyStream.Pumper pumper = new ProxyStream.Pumper(new ByteArrayInputStream(bytes), dest1, dest2); pumper.run(); - assertTrue(Arrays.equals(samples1, dest1.toByteArray())); - assertTrue(Arrays.equals(samples2, dest2.toByteArray())); + assertArrayEquals(samples1, dest1.toByteArray()); + assertArrayEquals(samples2, dest2.toByteArray()); } } diff --git a/libs/daemon/server/src/mill/server/ProxyStreamServer.scala b/libs/daemon/server/src/mill/server/ProxyStreamServer.scala index 0f5bd36a18c2..ec369825e480 100644 --- a/libs/daemon/server/src/mill/server/ProxyStreamServer.scala +++ b/libs/daemon/server/src/mill/server/ProxyStreamServer.scala @@ -67,12 +67,12 @@ abstract class ProxyStreamServer(args: Server.Args) extends Server(args) { self ): PreHandleConnectionData = { val stdout = new PrintStream( - new ProxyStream.Output(connectionData.serverToClient, ProxyStream.OUT), + new ProxyStream.Output(connectionData.serverToClient, ProxyStream.StreamType.OUT), true ) val stderr = new PrintStream( - new ProxyStream.Output(connectionData.serverToClient, ProxyStream.ERR), + new ProxyStream.Output(connectionData.serverToClient, ProxyStream.StreamType.ERR), true )