diff --git a/build.sbt b/build.sbt index 76410da4..145647a1 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ lazy val root = tlCrossRootProject.aggregate(servlet, examples) val asyncHttpClientVersion = "2.12.3" val jettyVersion = "9.4.50.v20221201" -val http4sVersion = "0.23.17" +val http4sVersion = "0.23.19" val munitCatsEffectVersion = "1.0.7" val servletApiVersion = "3.1.0" diff --git a/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala index a726b135..f93996cb 100644 --- a/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala @@ -76,6 +76,8 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0. dispatcher.unsafeRunAndForget(result) } catch errorHandler(servletRequest, servletResponse).andThen(dispatcher.unsafeRunSync _) + private[this] val noopCancelToken = Some(F.unit) + private def handleRequest( ctx: AsyncContext, request: Request[F], @@ -88,7 +90,7 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0. val timeout = F.async[Response[F]](cb => - gate.complete(ctx.addListener(new AsyncTimeoutHandler(cb))).as(Option.empty[F[Unit]]) + gate.complete(ctx.addListener(new AsyncTimeoutHandler(cb))).as(noopCancelToken) ) val response = gate.get *> diff --git a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala index 0ed72287..6fab84f3 100644 --- a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala +++ b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala @@ -24,6 +24,7 @@ import fs2._ import org.http4s.internal.bug import org.log4s.getLogger +import java.util.concurrent.CancellationException import java.util.concurrent.atomic.AtomicReference import javax.servlet.ReadListener import javax.servlet.WriteListener @@ -115,12 +116,14 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl case object Init extends State case object Ready extends State case object Complete extends State + case object Canceled extends State sealed case class Errored(t: Throwable) extends State sealed case class Blocked(cb: Callback[Option[Chunk[Byte]]]) extends State val in = servletRequest.getInputStream val state = new AtomicReference[State](Init) + val cancelToken: F[Option[F[Unit]]] = F.pure(Some(F.delay(state.set(Canceled)))) def read(cb: Callback[Option[Chunk[Byte]]]): Unit = { val buf = new Array[Byte](chunkSize) @@ -141,7 +144,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl // This effect sets the callback and waits for the first bytes to read val registerRead = // Shift execution to a different EC - F.async_[Option[Chunk[Byte]]] { cb => + F.async[Option[Chunk[Byte]]] { cb => if (!state.compareAndSet(Init, Blocked(cb))) cb(Left(bug("Shouldn't have gotten here: I should be the first to set a state"))) else @@ -166,12 +169,13 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl } } ) + cancelToken } val readStream = Stream.eval(registerRead) ++ Stream .repeatEval( // perform the initial set then transition into normal read mode // Shift execution to a different EC - F.async_[Option[Chunk[Byte]]] { cb => + F.async[Option[Chunk[Byte]]] { cb => @tailrec def go(): Unit = state.get match { @@ -189,6 +193,8 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl case Complete => cb(rightNone) + case Canceled => cb(Left(new CancellationException("Servlet read was canceled"))) + case Errored(t) => cb(Left(t)) // This should never happen so throw a huge fit if it does. @@ -203,6 +209,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl cb(Left(bug("Should have left Init state by now"))) } go() + cancelToken } ) readStream.unNoneTerminate.flatMap(Stream.chunk) @@ -215,6 +222,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl sealed trait State case object Init extends State case object Ready extends State + case object Canceled extends State sealed case class Errored(t: Throwable) extends State sealed case class Blocked(cb: Callback[Chunk[Byte] => Unit]) extends State sealed case class AwaitingLastWrite(cb: Callback[Unit]) extends State @@ -227,6 +235,8 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl * fires. */ val state = new AtomicReference[State](Init) + val cancelToken: F[Option[F[Unit]]] = F.pure(Some(F.delay(state.set(Canceled)))) + @volatile var autoFlush = false val writeChunk = Right { (chunk: Chunk[Byte]) => @@ -263,16 +273,17 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl val awaitLastWrite = Stream.exec { // Shift execution to a different EC - F.async_[Unit] { cb => + F.async[Unit] { cb => state.getAndSet(AwaitingLastWrite(cb)) match { case Ready if out.isReady => cb(Right(())) case _ => () } + cancelToken } } val chunkHandler = - F.async_[Chunk[Byte] => Unit] { cb => + F.async[Chunk[Byte] => Unit] { cb => val blocked = Blocked(cb) state.getAndSet(blocked) match { case Ready if out.isReady => @@ -281,9 +292,13 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl case e @ Errored(t) => if (state.compareAndSet(blocked, e)) cb(Left(t)) + case Canceled => + if (state.compareAndSet(blocked, Canceled)) + cb(Left(new CancellationException("Servlet write was canceled"))) case _ => () } + cancelToken } def flushPrelude = diff --git a/servlet/src/test/scala/org/http4s/servlet/BlockingHttp4sServletSuite.scala b/servlet/src/test/scala/org/http4s/servlet/BlockingHttp4sServletSuite.scala index 79070f11..d4711b63 100644 --- a/servlet/src/test/scala/org/http4s/servlet/BlockingHttp4sServletSuite.scala +++ b/servlet/src/test/scala/org/http4s/servlet/BlockingHttp4sServletSuite.scala @@ -54,7 +54,7 @@ class BlockingHttp4sServletSuite extends CatsEffectSuite { private def get(serverPort: Int, path: String): IO[String] = Resource .make(IO.blocking(Source.fromURL(new URL(s"http://127.0.0.1:$serverPort/$path"))))(source => - IO(source.close()) + IO.blocking(source.close()) ) .use { source => IO.blocking(source.getLines().mkString) @@ -74,7 +74,7 @@ class BlockingHttp4sServletSuite extends CatsEffectSuite { Resource .make( IO.blocking(Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name)) - )(source => IO(source.close())) + )(source => IO.blocking(source.close())) .use { source => IO.blocking(source.getLines().mkString) } diff --git a/servlet/src/test/scala/org/http4s/servlet/RouterInServletSuite.scala b/servlet/src/test/scala/org/http4s/servlet/RouterInServletSuite.scala index 2e87cc97..e1a9ed94 100644 --- a/servlet/src/test/scala/org/http4s/servlet/RouterInServletSuite.scala +++ b/servlet/src/test/scala/org/http4s/servlet/RouterInServletSuite.scala @@ -128,7 +128,7 @@ class RouterInServletSuite extends CatsEffectSuite { private def get(serverPort: Int, path: String): IO[String] = Resource .make(IO.blocking(Source.fromURL(new URL(s"http://127.0.0.1:$serverPort/$path"))))(source => - IO.delay(source.close()) + IO.blocking(source.close()) ) .use { source => IO.blocking(source.getLines().mkString)