Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ lazy val root = tlCrossRootProject.aggregate(servlet, examples)

val asyncHttpClientVersion = "2.12.3"
val jettyVersion = "9.4.50.v20221201"
val http4sVersion = "0.23.17"
val http4sVersion = "0.23.19"
val munitCatsEffectVersion = "1.0.7"
val servletApiVersion = "3.1.0"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0.
dispatcher.unsafeRunAndForget(result)
} catch errorHandler(servletRequest, servletResponse).andThen(dispatcher.unsafeRunSync _)

private[this] val noopCancelToken = Some(F.unit)

private def handleRequest(
ctx: AsyncContext,
request: Request[F],
Expand All @@ -88,7 +90,7 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0.

val timeout =
F.async[Response[F]](cb =>
gate.complete(ctx.addListener(new AsyncTimeoutHandler(cb))).as(Option.empty[F[Unit]])
gate.complete(ctx.addListener(new AsyncTimeoutHandler(cb))).as(noopCancelToken)
)
val response =
gate.get *>
Expand Down
23 changes: 19 additions & 4 deletions servlet/src/main/scala/org/http4s/servlet/ServletIo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import fs2._
import org.http4s.internal.bug
import org.log4s.getLogger

import java.util.concurrent.CancellationException
import java.util.concurrent.atomic.AtomicReference
import javax.servlet.ReadListener
import javax.servlet.WriteListener
Expand Down Expand Up @@ -115,12 +116,14 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
case object Init extends State
case object Ready extends State
case object Complete extends State
case object Canceled extends State
sealed case class Errored(t: Throwable) extends State
sealed case class Blocked(cb: Callback[Option[Chunk[Byte]]]) extends State

val in = servletRequest.getInputStream

val state = new AtomicReference[State](Init)
val cancelToken: F[Option[F[Unit]]] = F.pure(Some(F.delay(state.set(Canceled))))

def read(cb: Callback[Option[Chunk[Byte]]]): Unit = {
val buf = new Array[Byte](chunkSize)
Expand All @@ -141,7 +144,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
// This effect sets the callback and waits for the first bytes to read
val registerRead =
// Shift execution to a different EC
F.async_[Option[Chunk[Byte]]] { cb =>
F.async[Option[Chunk[Byte]]] { cb =>
if (!state.compareAndSet(Init, Blocked(cb)))
cb(Left(bug("Shouldn't have gotten here: I should be the first to set a state")))
else
Expand All @@ -166,12 +169,13 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
}
}
)
cancelToken
}

val readStream = Stream.eval(registerRead) ++ Stream
.repeatEval( // perform the initial set then transition into normal read mode
// Shift execution to a different EC
F.async_[Option[Chunk[Byte]]] { cb =>
F.async[Option[Chunk[Byte]]] { cb =>
@tailrec
def go(): Unit =
state.get match {
Expand All @@ -189,6 +193,8 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl

case Complete => cb(rightNone)

case Canceled => cb(Left(new CancellationException("Servlet read was canceled")))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is ever going to be read, and we could just do a no-op cancel token (which is effectively how it has worked for years), but this at least gets us to a terminal state. The same is true on the write side.


case Errored(t) => cb(Left(t))

// This should never happen so throw a huge fit if it does.
Expand All @@ -203,6 +209,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
cb(Left(bug("Should have left Init state by now")))
}
go()
cancelToken
}
)
readStream.unNoneTerminate.flatMap(Stream.chunk)
Expand All @@ -215,6 +222,7 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
sealed trait State
case object Init extends State
case object Ready extends State
case object Canceled extends State
sealed case class Errored(t: Throwable) extends State
sealed case class Blocked(cb: Callback[Chunk[Byte] => Unit]) extends State
sealed case class AwaitingLastWrite(cb: Callback[Unit]) extends State
Expand All @@ -227,6 +235,8 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
* fires.
*/
val state = new AtomicReference[State](Init)
val cancelToken: F[Option[F[Unit]]] = F.pure(Some(F.delay(state.set(Canceled))))

@volatile var autoFlush = false

val writeChunk = Right { (chunk: Chunk[Byte]) =>
Expand Down Expand Up @@ -263,16 +273,17 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl

val awaitLastWrite = Stream.exec {
// Shift execution to a different EC
F.async_[Unit] { cb =>
F.async[Unit] { cb =>
state.getAndSet(AwaitingLastWrite(cb)) match {
case Ready if out.isReady => cb(Right(()))
case _ => ()
}
cancelToken
}
}

val chunkHandler =
F.async_[Chunk[Byte] => Unit] { cb =>
F.async[Chunk[Byte] => Unit] { cb =>
val blocked = Blocked(cb)
state.getAndSet(blocked) match {
case Ready if out.isReady =>
Expand All @@ -281,9 +292,13 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
case e @ Errored(t) =>
if (state.compareAndSet(blocked, e))
cb(Left(t))
case Canceled =>
if (state.compareAndSet(blocked, Canceled))
cb(Left(new CancellationException("Servlet write was canceled")))
case _ =>
()
}
cancelToken
}

def flushPrelude =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class BlockingHttp4sServletSuite extends CatsEffectSuite {
private def get(serverPort: Int, path: String): IO[String] =
Resource
.make(IO.blocking(Source.fromURL(new URL(s"http://127.0.0.1:$serverPort/$path"))))(source =>
IO(source.close())
IO.blocking(source.close())
)
.use { source =>
IO.blocking(source.getLines().mkString)
Expand All @@ -74,7 +74,7 @@ class BlockingHttp4sServletSuite extends CatsEffectSuite {
Resource
.make(
IO.blocking(Source.fromInputStream(conn.getInputStream, StandardCharsets.UTF_8.name))
)(source => IO(source.close()))
)(source => IO.blocking(source.close()))
.use { source =>
IO.blocking(source.getLines().mkString)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class RouterInServletSuite extends CatsEffectSuite {
private def get(serverPort: Int, path: String): IO[String] =
Resource
.make(IO.blocking(Source.fromURL(new URL(s"http://127.0.0.1:$serverPort/$path"))))(source =>
IO.delay(source.close())
IO.blocking(source.close())
)
.use { source =>
IO.blocking(source.getLines().mkString)
Expand Down