Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0.
val ctx = servletRequest.startAsync()
ctx.setTimeout(asyncTimeoutMillis)
// Must be done on the container thread for Tomcat's sake when using async I/O.
val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _
val bodyWriter = servletIo.initWriter(servletResponse)
val result = F
.attempt(
toRequest(servletRequest).fold(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class BlockingHttp4sServlet[F[_]] private (
): Unit = {
val result = F
.defer {
val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _
val bodyWriter = servletIo.initWriter(servletResponse)

val render = toRequest(servletRequest).fold(
onParseFailure(_, servletResponse, bodyWriter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ abstract class Http4sServlet[F[_]](
uri = uri,
httpVersion = version,
headers = toHeaders(req),
body = servletIo.requestBody(req, dispatcher),
body = servletIo.reader(req),
attributes = attributes,
)

Expand Down
154 changes: 10 additions & 144 deletions servlet/src/main/scala/org/http4s/servlet/ServletIo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,30 @@ package servlet

import cats.effect._
import cats.effect.std.Dispatcher
import cats.effect.std.Queue
import cats.syntax.all._
import fs2._
import org.http4s.internal.bug
import org.log4s.getLogger

import java.util.Arrays
import java.util.concurrent.atomic.AtomicReference
import javax.servlet.ReadListener
import javax.servlet.WriteListener
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import scala.annotation.nowarn
import scala.annotation.tailrec

/** Determines the mode of I/O used for reading request bodies and writing response bodies.
*/
sealed abstract class ServletIo[F[_]: Async] {
protected[servlet] val F: Async[F] = Async[F]

@deprecated("Prefer requestBody, which has access to a Dispatcher", "0.23.12")
protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F]

@nowarn("cat=deprecation")
/** An alias for [[reader]]. In the future, this will be optimized with
* the dispatcher.
*
* @param dispatcher currently ignored
*/
def requestBody(
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
Expand All @@ -52,10 +52,13 @@ sealed abstract class ServletIo[F[_]: Async] {
}

/** May install a listener on the servlet response. */
@deprecated("Prefer bodyWriter, which has access to a Dispatcher", "0.23.12")
protected[servlet] def initWriter(servletResponse: HttpServletResponse): BodyWriter[F]

@nowarn("cat=deprecation")
/** An alias for [[initWriter]]. In the future, this will be
* optimized with the dispatcher.
*
* @param dispatcher currently ignored
*/
def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])(
response: Response[F]
): F[Unit] = {
Expand Down Expand Up @@ -206,73 +209,6 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
}
}

/* The queue implementation is influenced by ideas in jetty4s
* https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala
*/
override def requestBody(
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
): Stream[F, Byte] = {
sealed trait Read
final case class Bytes(chunk: Chunk[Byte]) extends Read
case object End extends Read
final case class Error(t: Throwable) extends Read

Stream.eval(F.delay(servletRequest.getInputStream)).flatMap { in =>
Stream.eval(Queue.bounded[F, Read](4)).flatMap { q =>
val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener {
var buf: Array[Byte] = _
unsafeReplaceBuffer()

def unsafeReplaceBuffer() =
buf = new Array[Byte](chunkSize)

def onDataAvailable(): Unit = {
def loopIfReady =
F.delay(in.isReady()).flatMap {
case true => go
case false => F.unit
}

def go: F[Unit] =
F.delay(in.read(buf)).flatMap {
case len if len == chunkSize =>
// We used the whole buffer. Replace it new before next read.
q.offer(Bytes(Chunk.array(buf))) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady
case len if len >= 0 =>
// Got a partial chunk. Copy it, and reuse the current buffer.
q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady
case _ =>
F.unit
}

unsafeRunAndForget(go)
}

def onAllDataRead(): Unit =
unsafeRunAndForget(q.offer(End))

def onError(t: Throwable): Unit =
unsafeRunAndForget(q.offer(Error(t)))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError { case t => F.delay(logger.error(t)("Error in servlet read listener")) }
)
})))

def pullBody: Pull[F, Byte, Unit] =
Pull.eval(q.take).flatMap {
case Bytes(chunk) => Pull.output(chunk) >> pullBody
case End => Pull.done
case Error(t) => Pull.raiseError[F](t)
}

pullBody.stream.concurrently(readBody)
}
}
}

override protected[servlet] def initWriter(
servletResponse: HttpServletResponse
): BodyWriter[F] = {
Expand Down Expand Up @@ -367,74 +303,4 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
.drain
}
}

/* The queue implementation is influenced by ideas in jetty4s
* https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala
*/
override def bodyWriter(
servletResponse: HttpServletResponse,
dispatcher: Dispatcher[F],
)(response: Response[F]): F[Unit] = {
sealed trait Write
final case class Bytes(chunk: Chunk[Byte]) extends Write
case object End extends Write
case object Init extends Write

val autoFlush = response.isChunked

F.delay(servletResponse.getOutputStream).flatMap { out =>
Queue.bounded[F, Write](4).flatMap { q =>
Deferred[F, Either[Throwable, Unit]].flatMap { done =>
val writeBody = F.delay(out.setWriteListener(new WriteListener {
def onWritePossible(): Unit = {
def loopIfReady = F.delay(out.isReady()).flatMap {
case true => go
case false => F.unit
}

def flush =
if (autoFlush) {
F.delay(out.isReady()).flatMap {
case true => F.delay(out.flush()) >> loopIfReady
case false => F.unit
}
} else
loopIfReady

def go: F[Unit] =
q.take.flatMap {
case Bytes(slice: Chunk.ArraySlice[_]) =>
F.delay(
out.write(slice.values.asInstanceOf[Array[Byte]], slice.offset, slice.length)
) >> flush
case Bytes(chunk) =>
F.delay(out.write(chunk.toArray)) >> flush
case End =>
F.delay(out.flush()) >> done.complete(Either.unit).attempt.void
case Init =>
if (autoFlush) flush else go
}

unsafeRunAndForget(go)
}
def onError(t: Throwable): Unit =
unsafeRunAndForget(done.complete(Left(t)))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError { case t => F.delay(logger.error(t)("Error in servlet write listener")) }
)
}))

val writes = Stream.emit(Init) ++ response.body.chunks.map(Bytes(_)) ++ Stream.emit(End)

Stream
.eval(writeBody >> done.get.rethrow)
.mergeHaltL(writes.foreach(q.offer))
.compile
.drain
}
}
}
}
}