diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3c64f97c8..6fcbddf91 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: path: deps key: v1-${{ matrix.os }}-${{ matrix.otp }}-${{ matrix.elixir }}-mix-${{ hashFiles(format('{0}{1}', github.workspace, '/mix.lock')) }} - name: Install Dependencies - run: mix deps.get 1>/dev/null + run: mix setup 1>/dev/null - name: Check format run: mix format --check-formatted tests: diff --git a/benchmark/bin/test.exs b/benchmark/bin/test.exs index 02b99b3e0..8fba3bec0 100644 --- a/benchmark/bin/test.exs +++ b/benchmark/bin/test.exs @@ -3,10 +3,10 @@ server = async_server_threads: 1, port: 10000, channel_args: [ - Grpc.Testing.ChannelArg.new( + %Grpc.Testing.ChannelArg{ name: "grpc.optimization_target", value: {:str_value, "latency"} - ) + } ] ) diff --git a/benchmark/config/config.exs b/benchmark/config/config.exs index cf3e8d596..1246afe81 100644 --- a/benchmark/config/config.exs +++ b/benchmark/config/config.exs @@ -1,5 +1,6 @@ import Config -config :logger, level: :info +# Reduce logging overhead by default for better performance +config :logger, level: :error -import_config "#{Mix.env()}.exs" +import_config "#{config_env()}.exs" diff --git a/benchmark/lib/benchmark/server_manager.ex b/benchmark/lib/benchmark/server_manager.ex index 68a3b6a55..706a12c15 100644 --- a/benchmark/lib/benchmark/server_manager.ex +++ b/benchmark/lib/benchmark/server_manager.ex @@ -1,13 +1,25 @@ defmodule Benchmark.ServerManager do - def start_server(%Grpc.Testing.ServerConfig{} = config) do + def start_server(%Grpc.Testing.ServerConfig{} = config, opts \\ []) do # get security payload_type = Benchmark.Manager.payload_type(config.payload_config) - start_server(payload_type, config) + start_server(payload_type, config, opts) end - def start_server(:protobuf, config) do + def start_server(:protobuf, config, opts) do cores = Benchmark.Manager.set_cores(config.core_limit) - {:ok, pid, port} = GRPC.Server.start(Grpc.Testing.BenchmarkService.Server, config.port) + + # Extract adapter option, default to Cowboy + adapter = Keyword.get(opts, :adapter, GRPC.Server.Adapters.Cowboy) + adapter_name = adapter |> Module.split() |> List.last() + + IO.puts("Starting server with #{adapter_name} adapter on port #{config.port}...") + + {:ok, pid, port} = + GRPC.Server.start( + Grpc.Testing.BenchmarkService.Server, + config.port, + adapter: adapter + ) %Benchmark.Server{ cores: cores, @@ -18,5 +30,10 @@ defmodule Benchmark.ServerManager do } end - def start_server(_, _), do: raise(GRPC.RPCError, status: :unimplemented) + def start_server(_, _, _), do: raise(GRPC.RPCError, status: :unimplemented) + + def stop_server(%Benchmark.Server{} = _server, opts \\ []) do + adapter = Keyword.get(opts, :adapter, GRPC.Server.Adapters.Cowboy) + GRPC.Server.stop(Grpc.Testing.BenchmarkService.Server, adapter: adapter) + end end diff --git a/benchmark/lib/grpc/worker_server.ex b/benchmark/lib/grpc/worker_server.ex index e080272fc..ed9630fe8 100644 --- a/benchmark/lib/grpc/worker_server.ex +++ b/benchmark/lib/grpc/worker_server.ex @@ -21,11 +21,11 @@ defmodule Grpc.Testing.WorkerService.Server do {server, stats} = Benchmark.Server.get_stats(server) status = - Grpc.Testing.ServerStatus.new( + %Grpc.Testing.ServerStatus{ stats: stats, port: server.port, cores: server.cores - ) + } {server, status} @@ -33,7 +33,7 @@ defmodule Grpc.Testing.WorkerService.Server do {server, stats} = Benchmark.Server.get_stats(server, mark) status = - Grpc.Testing.ServerStatus.new(cores: server.cores, port: server.port, stats: stats) + %Grpc.Testing.ServerStatus{cores: server.cores, port: server.port, stats: stats} {server, status} end @@ -53,11 +53,11 @@ defmodule Grpc.Testing.WorkerService.Server do case args.argtype do {:setup, client_config} -> manager = ClientManager.start_client(client_config) - {Grpc.Testing.ClientStatus.new(), manager} + {%Grpc.Testing.ClientStatus{}, manager} {:mark, mark} -> stats = ClientManager.get_stats(manager, mark.reset) - {Grpc.Testing.ClientStatus.new(stats: stats), manager} + {%Grpc.Testing.ClientStatus{stats: stats}, manager} end Logger.debug("Client send reply #{inspect(status)}") @@ -70,6 +70,6 @@ defmodule Grpc.Testing.WorkerService.Server do Logger.debug("Received quit_work") Logger.debug(inspect(stream.local[:main_pid])) send(stream.local[:main_pid], {:quit, self()}) - Grpc.Testing.Void.new() + %Grpc.Testing.Void{} end end diff --git a/benchmark/lib/mix/tasks/benchmark.test.ex b/benchmark/lib/mix/tasks/benchmark.test.ex index a9661fb54..a763057c6 100644 --- a/benchmark/lib/mix/tasks/benchmark.test.ex +++ b/benchmark/lib/mix/tasks/benchmark.test.ex @@ -1,19 +1,22 @@ defmodule Mix.Tasks.Benchmark.Test do @moduledoc """ Runs a simple gRPC benchmark test. - + This task starts a benchmark server and client, runs performance tests, and reports statistics. - + ## Usage - + mix benchmark.test - + mix benchmark.test --adapter=thousand_island + mix benchmark.test --adapter=cowboy + ## Options - + * `--port` - Server port (default: 10000) * `--requests` - Number of requests to send (default: 1000) - + * `--adapter` - Server adapter: cowboy or thousand_island (default: cowboy) + """ use Mix.Task @@ -27,13 +30,27 @@ defmodule Mix.Tasks.Benchmark.Test do {parsed, _remaining, _invalid} = OptionParser.parse(args, - strict: [port: :integer, requests: :integer] + strict: [port: :integer, requests: :integer, adapter: :string] ) port = Keyword.get(parsed, :port, 10000) num_requests = Keyword.get(parsed, :requests, 1000) + adapter_name = Keyword.get(parsed, :adapter, "cowboy") + + adapter = + case String.downcase(adapter_name) do + "thousand_island" -> + GRPC.Server.Adapters.ThousandIsland + + "cowboy" -> + GRPC.Server.Adapters.Cowboy - Logger.info("Starting benchmark test on port #{port}") + _ -> + Logger.error("Unknown adapter: #{adapter_name}. Using Cowboy.") + GRPC.Server.Adapters.Cowboy + end + + Logger.info("Starting benchmark test on port #{port} with #{adapter_name} adapter") # Configure and start server server = %Grpc.Testing.ServerConfig{ @@ -48,7 +65,7 @@ defmodule Mix.Tasks.Benchmark.Test do } Logger.info("Starting server...") - server = Benchmark.ServerManager.start_server(server) + server = Benchmark.ServerManager.start_server(server, adapter: adapter) Logger.info("Server started: #{inspect(server)}") # Configure client @@ -98,7 +115,7 @@ defmodule Mix.Tasks.Benchmark.Test do # Connect and warm up Logger.info("Connecting to server...") {:ok, ch} = GRPC.Stub.connect("localhost:#{port}") - + Logger.info("Warming up...") Grpc.Testing.BenchmarkService.Stub.unary_call(ch, req) @@ -132,6 +149,10 @@ defmodule Mix.Tasks.Benchmark.Test do IO.inspect(stats, label: "Stats", pretty: true) Logger.info("=" |> String.duplicate(60)) + # Clean shutdown + Logger.info("Stopping server...") + Benchmark.ServerManager.stop_server(server, adapter: adapter) + Logger.info("Server stopped") :ok end end diff --git a/benchmark/lib/mix/tasks/benchmark.worker.ex b/benchmark/lib/mix/tasks/benchmark.worker.ex index 3b24fc7e7..c89865b65 100644 --- a/benchmark/lib/mix/tasks/benchmark.worker.ex +++ b/benchmark/lib/mix/tasks/benchmark.worker.ex @@ -1,15 +1,15 @@ defmodule Mix.Tasks.Benchmark.Worker do @moduledoc """ Starts a gRPC worker server for benchmarking. - + ## Usage - + mix benchmark.worker --port=10000 - + ## Options - + * `--port` - Port to listen on (required) - + """ use Mix.Task diff --git a/benchmark/mix.lock b/benchmark/mix.lock index 427f40b39..c2a78ae07 100644 --- a/benchmark/mix.lock +++ b/benchmark/mix.lock @@ -11,4 +11,5 @@ "protobuf": {:hex, :protobuf, "0.15.0", "c9fc1e9fc1682b05c601df536d5ff21877b55e2023e0466a3855cc1273b74dcb", [:mix], [{:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "5d7bb325319db1d668838d2691c31c7b793c34111aec87d5ee467a39dac6e051"}, "ranch": {:hex, :ranch, "2.2.0", "25528f82bc8d7c6152c57666ca99ec716510fe0925cb188172f41ce93117b1b0", [:make, :rebar3], [], "hexpm", "fa0b99a1780c80218a4197a59ea8d3bdae32fbff7e88527d7d8a4787eff4f8e7"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "thousand_island": {:hex, :thousand_island, "1.4.2", "735fa783005d1703359bbd2d3a5a3a398075ba4456e5afe3c5b7cf4666303d36", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "1c7637f16558fc1c35746d5ee0e83b18b8e59e18d28affd1f2fa1645f8bc7473"}, } diff --git a/benchmark/test/benchmark_test.exs b/benchmark/test/benchmark_test.exs index 75bbbaf6d..cc6f6211b 100644 --- a/benchmark/test/benchmark_test.exs +++ b/benchmark/test/benchmark_test.exs @@ -1,5 +1,4 @@ defmodule BenchmarkgTest do use ExUnit.Case doctest Benchmark - end diff --git a/grpc_client/mix.lock b/grpc_client/mix.lock index 32c8ac98a..a3e339e6c 100644 --- a/grpc_client/mix.lock +++ b/grpc_client/mix.lock @@ -21,4 +21,5 @@ "protobuf": {:hex, :protobuf, "0.15.0", "c9fc1e9fc1682b05c601df536d5ff21877b55e2023e0466a3855cc1273b74dcb", [:mix], [{:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "5d7bb325319db1d668838d2691c31c7b793c34111aec87d5ee467a39dac6e051"}, "ranch": {:hex, :ranch, "2.2.0", "25528f82bc8d7c6152c57666ca99ec716510fe0925cb188172f41ce93117b1b0", [:make, :rebar3], [], "hexpm", "fa0b99a1780c80218a4197a59ea8d3bdae32fbff7e88527d7d8a4787eff4f8e7"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "thousand_island": {:hex, :thousand_island, "1.4.2", "735fa783005d1703359bbd2d3a5a3a398075ba4456e5afe3c5b7cf4666303d36", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "1c7637f16558fc1c35746d5ee0e83b18b8e59e18d28affd1f2fa1645f8bc7473"}, } diff --git a/grpc_client/test/grpc/integration/thousand_island_adapter_test.exs b/grpc_client/test/grpc/integration/thousand_island_adapter_test.exs new file mode 100644 index 000000000..804562b9b --- /dev/null +++ b/grpc_client/test/grpc/integration/thousand_island_adapter_test.exs @@ -0,0 +1,265 @@ +defmodule GRPC.Integration.ThousandIslandAdapterTest do + @moduledoc """ + Integration tests for the ThousandIsland adapter. + """ + + use GRPC.Integration.TestCase + + setup do + {:ok, adapter_opts: [adapter: GRPC.Server.Adapters.ThousandIsland]} + end + + defmodule HelloServer do + use GRPC.Server, service: Helloworld.Greeter.Service + + def say_hello(request, materializer) do + GRPC.Stream.unary(request, materializer: materializer) + |> GRPC.Stream.map(fn req -> + %Helloworld.HelloReply{message: "Hello #{req.name}!"} + end) + |> GRPC.Stream.run() + end + end + + defmodule RouteServer do + use GRPC.Server, service: Routeguide.RouteGuide.Service + require Logger + + def get_feature(point, materializer) do + GRPC.Stream.unary(point, materializer: materializer) + |> GRPC.Stream.map(fn point -> + %Routeguide.Feature{ + location: point, + name: "Feature at #{point.latitude},#{point.longitude}" + } + end) + |> GRPC.Stream.run() + end + + def list_features(_rectangle, materializer) do + features = + Enum.map(1..5, fn i -> + %Routeguide.Feature{ + location: %Routeguide.Point{latitude: i * 10, longitude: i * 20}, + name: "Feature #{i}" + } + end) + + features + |> GRPC.Stream.from() + |> GRPC.Stream.run_with(materializer) + end + + def record_route(point_stream, _materializer) do + # For client streaming, process input and return single response + count = Enum.reduce(point_stream, 0, fn _point, acc -> acc + 1 end) + + %Routeguide.RouteSummary{ + point_count: count, + feature_count: count, + distance: count * 100, + elapsed_time: count * 10 + } + end + + def route_chat(note_stream, materializer) do + GRPC.Stream.from(note_stream) + |> GRPC.Stream.map(fn note -> + %Routeguide.RouteNote{ + # location: note.location, + message: "Echo: #{note.message}" + } + end) + |> GRPC.Stream.run_with(materializer) + end + end + + describe "ThousandIsland adapter - unary RPC" do + test "handles simple unary request/response", %{adapter_opts: adapter_opts} do + run_server( + [HelloServer], + fn port -> + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}") + + request = %Helloworld.HelloRequest{name: "ThousandIsland"} + {:ok, response} = channel |> Helloworld.Greeter.Stub.say_hello(request) + + assert response.message == "Hello ThousandIsland!" + + GRPC.Stub.disconnect(channel) + end, + 0, + adapter_opts + ) + end + + test "handles multiple sequential unary calls", %{adapter_opts: adapter_opts} do + run_server( + [HelloServer], + fn port -> + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}") + + for i <- 1..10 do + request = %Helloworld.HelloRequest{name: "User#{i}"} + {:ok, response} = channel |> Helloworld.Greeter.Stub.say_hello(request) + assert response.message == "Hello User#{i}!" + end + + GRPC.Stub.disconnect(channel) + end, + 0, + adapter_opts + ) + end + end + + describe "ThousandIsland adapter - server streaming RPC" do + test "receives multiple responses from server", %{adapter_opts: adapter_opts} do + run_server( + [RouteServer], + fn port -> + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}") + + rectangle = %Routeguide.Rectangle{ + lo: %Routeguide.Point{latitude: 0, longitude: 0}, + hi: %Routeguide.Point{latitude: 100, longitude: 100} + } + + {:ok, stream} = channel |> Routeguide.RouteGuide.Stub.list_features(rectangle) + + features = stream |> Enum.map(fn {:ok, f} -> f end) |> Enum.to_list() + + assert length(features) == 5 + + Enum.each(1..5, fn i -> + feature = Enum.at(features, i - 1) + assert feature.name == "Feature #{i}" + assert feature.location.latitude == i * 10 + assert feature.location.longitude == i * 20 + end) + + GRPC.Stub.disconnect(channel) + end, + 0, + adapter_opts + ) + end + end + + describe "ThousandIsland adapter - client streaming RPC" do + test "sends multiple requests and receives single response", %{adapter_opts: adapter_opts} do + run_server( + [RouteServer], + fn port -> + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}") + + points = [ + %Routeguide.Point{latitude: 10, longitude: 20}, + %Routeguide.Point{latitude: 30, longitude: 40}, + %Routeguide.Point{latitude: 50, longitude: 60} + ] + + stream = channel |> Routeguide.RouteGuide.Stub.record_route() + + Enum.each(points, fn point -> + GRPC.Stub.send_request(stream, point) + end) + + GRPC.Stub.end_stream(stream) + + {:ok, summary} = GRPC.Stub.recv(stream) + + assert summary.point_count == 3 + assert summary.feature_count == 3 + assert summary.distance == 300 + + GRPC.Stub.disconnect(channel) + end, + 0, + adapter_opts + ) + end + end + + describe "ThousandIsland adapter - bidirectional streaming RPC" do + test "exchanges messages bidirectionally", %{adapter_opts: adapter_opts} do + run_server( + [RouteServer], + fn port -> + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}") + + notes = [ + %Routeguide.RouteNote{ + location: %Routeguide.Point{latitude: 1, longitude: 2}, + message: "First note" + }, + %Routeguide.RouteNote{ + location: %Routeguide.Point{latitude: 3, longitude: 4}, + message: "Second note" + }, + %Routeguide.RouteNote{ + location: %Routeguide.Point{latitude: 5, longitude: 6}, + message: "Third note" + } + ] + + bidi_stream = channel |> Routeguide.RouteGuide.Stub.route_chat() + + Enum.each(notes, fn note -> + GRPC.Stub.send_request(bidi_stream, note) + end) + + GRPC.Stub.end_stream(bidi_stream) + + {:ok, response_stream} = GRPC.Stub.recv(bidi_stream) + responses = response_stream |> Enum.map(fn {:ok, r} -> r end) |> Enum.to_list() + + assert length(responses) == 3 + assert Enum.at(responses, 0).message == "Echo: First note" + assert Enum.at(responses, 1).message == "Echo: Second note" + assert Enum.at(responses, 2).message == "Echo: Third note" + + GRPC.Stub.disconnect(channel) + end, + 0, + adapter_opts + ) + end + end + + describe "ThousandIsland adapter - HTTP/2 protocol validation" do + test "handles multiple concurrent unary calls on same connection", %{ + adapter_opts: adapter_opts + } do + run_server( + [HelloServer], + fn port -> + {:ok, channel} = GRPC.Stub.connect("localhost:#{port}") + + tasks = + 1..10 + |> Enum.map(fn i -> + Task.async(fn -> + request = %Helloworld.HelloRequest{name: "Concurrent#{i}"} + {:ok, response} = channel |> Helloworld.Greeter.Stub.say_hello(request) + response + end) + end) + + responses = Task.await_many(tasks, 5000) + + assert length(responses) == 10 + + Enum.each(1..10, fn i -> + response = Enum.find(responses, fn r -> r.message == "Hello Concurrent#{i}!" end) + assert response != nil + end) + + GRPC.Stub.disconnect(channel) + end, + 0, + adapter_opts + ) + end + end +end diff --git a/grpc_client/test/support/integration_data_case.ex b/grpc_client/test/support/integration_data_case.ex index b51aecd19..69db0dc8a 100644 --- a/grpc_client/test/support/integration_data_case.ex +++ b/grpc_client/test/support/integration_data_case.ex @@ -22,15 +22,17 @@ defmodule GRPC.Integration.TestCase do try do func.(port) after - :ok = GRPC.Server.stop(servers) + # GRPC.Server.stop only accepts :adapter option + stop_opts = Keyword.take(opts, [:adapter]) + :ok = GRPC.Server.stop(servers, stop_opts) end end - def run_endpoint(endpoint, func, port \\ 0) do + def run_endpoint(endpoint, func, port \\ 0, opts \\ []) do {:ok, _pid, port} = start_supervised(%{ id: {GRPC.Server, System.unique_integer([:positive])}, - start: {GRPC.Server, :start_endpoint, [endpoint, port]}, + start: {GRPC.Server, :start_endpoint, [endpoint, port, opts]}, type: :worker, restart: :permanent, shutdown: 500 @@ -39,7 +41,9 @@ defmodule GRPC.Integration.TestCase do try do func.(port) after - :ok = GRPC.Server.stop_endpoint(endpoint, []) + # GRPC.Server.stop_endpoint only accepts :adapter option + stop_opts = Keyword.take(opts, [:adapter]) + :ok = GRPC.Server.stop_endpoint(endpoint, stop_opts) end end diff --git a/grpc_core/lib/grpc/codec/proto.ex b/grpc_core/lib/grpc/codec/proto.ex index 6dd53136e..f140f74f2 100644 --- a/grpc_core/lib/grpc/codec/proto.ex +++ b/grpc_core/lib/grpc/codec/proto.ex @@ -1,15 +1,12 @@ defmodule GRPC.Codec.Proto do @behaviour GRPC.Codec - def name() do - "proto" - end + # Inline codec functions for better performance + @compile {:inline, name: 0, encode: 2, decode: 2} - def encode(struct, _opts \\ []) do - Protobuf.Encoder.encode_to_iodata(struct) - end + def name, do: "proto" - def decode(binary, module) do - module.decode(binary) - end + def encode(struct, _opts \\ []), do: Protobuf.Encoder.encode_to_iodata(struct) + + def decode(binary, module), do: module.decode(binary) end diff --git a/grpc_core/lib/grpc/message.ex b/grpc_core/lib/grpc/message.ex index 03ac2e339..2e0b3f0c5 100644 --- a/grpc_core/lib/grpc/message.ex +++ b/grpc_core/lib/grpc/message.ex @@ -15,6 +15,9 @@ defmodule GRPC.Message do @max_message_length Bitwise.bsl(1, 32 - 1) + # Inline hot path functions, this reduces between 07-10% of overhead in benchmarks + @compile {:inline, to_data: 2, from_data: 1} + @doc """ Transforms Protobuf data into a gRPC body binary. @@ -42,33 +45,40 @@ defmodule GRPC.Message do @spec to_data(iodata, keyword()) :: {:ok, iodata, non_neg_integer} | {:error, String.t()} def to_data(message, opts \\ []) do - compressor = opts[:compressor] - iolist = opts[:iolist] - codec = opts[:codec] max_length = opts[:max_message_length] || @max_message_length - {compress_flag, message} = - if compressor do - {1, compressor.compress(message)} - else - {0, message} + {compress_flag, compressed_message} = + case opts[:compressor] do + nil -> {0, message} + compressor -> {1, compressor.compress(message)} end - length = IO.iodata_length(message) + length = IO.iodata_length(compressed_message) - if length > max_length do - {:error, "Encoded message is too large (#{length} bytes)"} - else - result = [compress_flag, <>, message] + if length <= max_length do + result = [compress_flag, <>, compressed_message] result = - if function_exported?(codec, :pack_for_channel, 1), - do: codec.pack_for_channel(result), - else: result + if opts[:codec] != nil and is_atom(opts[:codec]) do + codec = opts[:codec] - result = if iolist, do: result, else: IO.iodata_to_binary(result) + if function_exported?(codec, :pack_for_channel, 1), + do: codec.pack_for_channel(result), + else: result + else + result + end + + result = + if opts[:iolist] == true do + result + else + IO.iodata_to_binary(result) + end {:ok, result, length + 5} + else + {:error, "Encoded message is too large (#{length} bytes)"} end end diff --git a/grpc_core/lib/grpc/transport/http2/errors.ex b/grpc_core/lib/grpc/transport/http2/errors.ex new file mode 100644 index 000000000..f4798dbe1 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/errors.ex @@ -0,0 +1,54 @@ +defmodule GRPC.Transport.HTTP2.Errors do + @moduledoc false + # Errors as defined in RFC9113§7 + + @typedoc "An error code as defined for GOAWAY and RST_STREAM errors" + @type error_code() :: + (no_error :: 0x0) + | (protocol_error :: 0x1) + | (internal_error :: 0x2) + | (flow_control_error :: 0x3) + | (settings_timeout :: 0x4) + | (stream_closed :: 0x5) + | (frame_size_error :: 0x6) + | (refused_stream :: 0x7) + | (cancel :: 0x8) + | (compression_error :: 0x9) + | (connect_error :: 0xA) + | (enhance_your_calm :: 0xB) + | (inadequate_security :: 0xC) + | (http_1_1_requires :: 0xD) + + error_codes = %{ + no_error: 0x0, + protocol_error: 0x1, + internal_error: 0x2, + flow_control_error: 0x3, + settings_timeout: 0x4, + stream_closed: 0x5, + frame_size_error: 0x6, + refused_stream: 0x7, + cancel: 0x8, + compression_error: 0x9, + connect_error: 0xA, + enhance_your_calm: 0xB, + inadequate_security: 0xC, + http_1_1_requires: 0xD + } + + for {name, code} <- error_codes do + def unquote(name)(), do: unquote(code) + end + + defmodule ConnectionError do + @moduledoc false + + defexception message: nil, error_code: nil + end + + defmodule StreamError do + @moduledoc false + + defexception message: nil, error_code: nil, stream_id: nil + end +end diff --git a/grpc_core/lib/grpc/transport/http2/flow_control.ex b/grpc_core/lib/grpc/transport/http2/flow_control.ex new file mode 100644 index 000000000..d20065471 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/flow_control.ex @@ -0,0 +1,44 @@ +defmodule GRPC.Transport.HTTP2.FlowControl do + @moduledoc false + # Helpers for working with flow control window calculations + + import Bitwise + + @max_window_increment (1 <<< 31) - 1 + @max_window_size (1 <<< 31) - 1 + @min_window_size 1 <<< 30 + + @spec compute_recv_window(non_neg_integer(), non_neg_integer()) :: + {non_neg_integer(), non_neg_integer()} + def compute_recv_window(recv_window_size, data_size) do + # This is what our window size will be after receiving data_size bytes + recv_window_size = recv_window_size - data_size + + if recv_window_size > @min_window_size do + # We have room to go before we need to update our window + {recv_window_size, 0} + else + # We want our new window to be as large as possible, but are limited by both the maximum size + # of a WINDOW_UPDATE frame (max_window_increment) and the maximum window size (max_window_size) + window_increment = + min(@max_window_increment, @max_window_size - recv_window_size) + + {recv_window_size + window_increment, window_increment} + end + end + + @doc """ + Updates window size by increment, ensuring it doesn't exceed maximum. + """ + @spec update_window(non_neg_integer(), integer()) :: + {:ok, non_neg_integer()} | {:error, :flow_control_error} + def update_window(current_size, increment) do + new_size = current_size + increment + + if new_size > @max_window_size do + {:error, :flow_control_error} + else + {:ok, new_size} + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame.ex b/grpc_core/lib/grpc/transport/http2/frame.ex new file mode 100644 index 000000000..64829cbd8 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame.ex @@ -0,0 +1,93 @@ +defmodule GRPC.Transport.HTTP2.Frame do + @moduledoc false + # HTTP/2 frame parsing and serialization adapted from Bandit + + @typedoc "Indicates a frame type" + @type frame_type :: non_neg_integer() + + @typedoc "The flags passed along with a frame" + @type flags :: byte() + + @typedoc "A valid HTTP/2 frame" + @type frame :: + GRPC.Transport.HTTP2.Frame.Data.t() + | GRPC.Transport.HTTP2.Frame.Headers.t() + | GRPC.Transport.HTTP2.Frame.Priority.t() + | GRPC.Transport.HTTP2.Frame.RstStream.t() + | GRPC.Transport.HTTP2.Frame.Settings.t() + | GRPC.Transport.HTTP2.Frame.Ping.t() + | GRPC.Transport.HTTP2.Frame.Goaway.t() + | GRPC.Transport.HTTP2.Frame.WindowUpdate.t() + | GRPC.Transport.HTTP2.Frame.Continuation.t() + | GRPC.Transport.HTTP2.Frame.Unknown.t() + + @spec deserialize(binary(), non_neg_integer()) :: + {{:ok, frame()}, iodata()} + | {{:more, iodata()}, <<>>} + | {{:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()}, iodata()} + | nil + def deserialize( + <>, + max_frame_size + ) + when length <= max_frame_size do + case deserialize_frame_by_type(type, flags, stream_id, payload) do + {:ok, frame} -> {{:ok, frame}, rest} + {:error, error_code, reason} -> {{:error, error_code, reason}, rest} + end + end + + def deserialize( + <>, + max_frame_size + ) + when length > max_frame_size do + {{:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Payload size too large (RFC9113§4.2)"}, rest} + end + + # nil is used to indicate for Stream.unfold/2 that the frame deserialization is finished + def deserialize(<<>>, _max_frame_size) do + nil + end + + def deserialize(msg, _max_frame_size) do + {{:more, msg}, <<>>} + end + + defp deserialize_frame_by_type(type, flags, stream_id, payload) do + case type do + 0 -> GRPC.Transport.HTTP2.Frame.Data.deserialize(flags, stream_id, payload) + 1 -> GRPC.Transport.HTTP2.Frame.Headers.deserialize(flags, stream_id, payload) + 2 -> GRPC.Transport.HTTP2.Frame.Priority.deserialize(flags, stream_id, payload) + 3 -> GRPC.Transport.HTTP2.Frame.RstStream.deserialize(flags, stream_id, payload) + 4 -> GRPC.Transport.HTTP2.Frame.Settings.deserialize(flags, stream_id, payload) + 5 -> GRPC.Transport.HTTP2.Frame.PushPromise.deserialize(flags, stream_id, payload) + 6 -> GRPC.Transport.HTTP2.Frame.Ping.deserialize(flags, stream_id, payload) + 7 -> GRPC.Transport.HTTP2.Frame.Goaway.deserialize(flags, stream_id, payload) + 8 -> GRPC.Transport.HTTP2.Frame.WindowUpdate.deserialize(flags, stream_id, payload) + 9 -> GRPC.Transport.HTTP2.Frame.Continuation.deserialize(flags, stream_id, payload) + _unknown -> GRPC.Transport.HTTP2.Frame.Unknown.deserialize(type, flags, stream_id, payload) + end + end + + defprotocol Serializable do + @moduledoc false + + @spec serialize(any(), non_neg_integer()) :: [ + {GRPC.Transport.HTTP2.Frame.frame_type(), GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), iodata()} + ] + def serialize(frame, max_frame_size) + end + + @spec serialize(frame(), non_neg_integer()) :: iolist() + def serialize(frame, max_frame_size) do + frame + |> Serializable.serialize(max_frame_size) + |> Enum.map(fn {type, flags, stream_id, payload} -> + [<>, payload] + end) + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/continuation.ex b/grpc_core/lib/grpc/transport/http2/frame/continuation.ex new file mode 100644 index 000000000..a9caea598 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/continuation.ex @@ -0,0 +1,64 @@ +defmodule GRPC.Transport.HTTP2.Frame.Continuation do + @moduledoc false + + import GRPC.Transport.HTTP2.Frame.Flags + + defstruct stream_id: nil, + end_headers: false, + fragment: nil + + @typedoc "An HTTP/2 CONTINUATION frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + end_headers: boolean(), + fragment: iodata() + } + + @end_headers_bit 2 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, 0, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "CONTINUATION frame with zero stream_id (RFC9113§6.10)"} + end + + def deserialize(flags, stream_id, <>) do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_headers: set?(flags, @end_headers_bit), + fragment: fragment + }} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + @end_headers_bit 2 + + def serialize(%GRPC.Transport.HTTP2.Frame.Continuation{} = frame, max_frame_size) do + fragment_length = IO.iodata_length(frame.fragment) + + if fragment_length <= max_frame_size do + [{9, set([@end_headers_bit]), frame.stream_id, frame.fragment}] + else + <> = + IO.iodata_to_binary(frame.fragment) + + [ + {9, 0, frame.stream_id, this_frame} + | GRPC.Transport.HTTP2.Frame.Serializable.serialize( + %GRPC.Transport.HTTP2.Frame.Continuation{ + stream_id: frame.stream_id, + fragment: rest + }, + max_frame_size + ) + ] + end + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/data.ex b/grpc_core/lib/grpc/transport/http2/frame/data.ex new file mode 100644 index 000000000..f5a439577 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/data.ex @@ -0,0 +1,83 @@ +defmodule GRPC.Transport.HTTP2.Frame.Data do + @moduledoc false + + import GRPC.Transport.HTTP2.Frame.Flags + + defstruct stream_id: nil, + end_stream: false, + data: nil + + @typedoc "An HTTP/2 DATA frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + end_stream: boolean(), + data: iodata() + } + + @end_stream_bit 0 + @padding_bit 3 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, 0, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "DATA frame with zero stream_id (RFC9113§6.1)"} + end + + def deserialize(flags, stream_id, <>) + when set?(flags, @padding_bit) and byte_size(rest) >= padding_length do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_stream: set?(flags, @end_stream_bit), + data: binary_part(rest, 0, byte_size(rest) - padding_length) + }} + end + + def deserialize(flags, stream_id, <>) when not set?(flags, @padding_bit) do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_stream: set?(flags, @end_stream_bit), + data: data + }} + end + + def deserialize(flags, _stream_id, <<_padding_length::8, _rest::binary>>) + when set?(flags, @padding_bit) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "DATA frame with invalid padding length (RFC9113§6.1)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + @end_stream_bit 0 + + def serialize(%GRPC.Transport.HTTP2.Frame.Data{} = frame, max_frame_size) do + data_length = IO.iodata_length(frame.data) + + if data_length <= max_frame_size do + flags = if frame.end_stream, do: [@end_stream_bit], else: [] + [{0, set(flags), frame.stream_id, frame.data}] + else + <> = + IO.iodata_to_binary(frame.data) + + [ + {0, 0, frame.stream_id, this_frame} + | GRPC.Transport.HTTP2.Frame.Serializable.serialize( + %GRPC.Transport.HTTP2.Frame.Data{ + stream_id: frame.stream_id, + end_stream: frame.end_stream, + data: rest + }, + max_frame_size + ) + ] + end + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/flags.ex b/grpc_core/lib/grpc/transport/http2/frame/flags.ex new file mode 100644 index 000000000..2d500b37a --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/flags.ex @@ -0,0 +1,11 @@ +defmodule GRPC.Transport.HTTP2.Frame.Flags do + @moduledoc false + import Bitwise + + defguard set?(flags, bit) when band(flags, bsl(1, bit)) != 0 + + @spec set(list(0..7)) :: 0..255 + def set(bits) do + Enum.reduce(bits, 0, fn bit, acc -> bor(acc, bsl(1, bit)) end) + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/goaway.ex b/grpc_core/lib/grpc/transport/http2/frame/goaway.ex new file mode 100644 index 000000000..4bb3fe609 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/goaway.ex @@ -0,0 +1,39 @@ +defmodule GRPC.Transport.HTTP2.Frame.Goaway do + @moduledoc false + + defstruct last_stream_id: 0, error_code: 0, debug_data: <<>> + + @typedoc "An HTTP/2 GOAWAY frame" + @type t :: %__MODULE__{ + last_stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + error_code: GRPC.Transport.HTTP2.Errors.error_code(), + debug_data: iodata() + } + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize( + _flags, + 0, + <<_reserved::1, last_stream_id::31, error_code::32, debug_data::binary>> + ) do + {:ok, + %__MODULE__{last_stream_id: last_stream_id, error_code: error_code, debug_data: debug_data}} + end + + def deserialize(_flags, stream_id, _payload) when stream_id != 0 do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "GOAWAY frame with non-zero stream_id (RFC9113§6.8)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + def serialize(%GRPC.Transport.HTTP2.Frame.Goaway{} = frame, _max_frame_size) do + payload = <<0::1, frame.last_stream_id::31, frame.error_code::32, frame.debug_data::binary>> + [{7, 0, 0, payload}] + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/headers.ex b/grpc_core/lib/grpc/transport/http2/frame/headers.ex new file mode 100644 index 000000000..df8cf92d8 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/headers.ex @@ -0,0 +1,147 @@ +defmodule GRPC.Transport.HTTP2.Frame.Headers do + @moduledoc false + + import GRPC.Transport.HTTP2.Frame.Flags + + defstruct stream_id: nil, + end_stream: false, + end_headers: false, + exclusive_dependency: false, + stream_dependency: nil, + weight: nil, + fragment: nil + + @typedoc "An HTTP/2 HEADERS frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + end_stream: boolean(), + end_headers: boolean(), + exclusive_dependency: boolean(), + stream_dependency: GRPC.Transport.HTTP2.Stream.stream_id() | nil, + weight: non_neg_integer() | nil, + fragment: iodata() + } + + @end_stream_bit 0 + @end_headers_bit 2 + @padding_bit 3 + @priority_bit 5 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, 0, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "HEADERS frame with zero stream_id (RFC9113§6.2)"} + end + + # Padding and priority + def deserialize( + flags, + stream_id, + <> + ) + when set?(flags, @padding_bit) and set?(flags, @priority_bit) and + byte_size(rest) >= padding_length do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_stream: set?(flags, @end_stream_bit), + end_headers: set?(flags, @end_headers_bit), + exclusive_dependency: exclusive_dependency == 1, + stream_dependency: stream_dependency, + weight: weight, + fragment: binary_part(rest, 0, byte_size(rest) - padding_length) + }} + end + + # Padding but not priority + def deserialize(flags, stream_id, <>) + when set?(flags, @padding_bit) and not set?(flags, @priority_bit) and + byte_size(rest) >= padding_length do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_stream: set?(flags, @end_stream_bit), + end_headers: set?(flags, @end_headers_bit), + fragment: binary_part(rest, 0, byte_size(rest) - padding_length) + }} + end + + # Any other case where padding is set + def deserialize(flags, _stream_id, <<_padding_length::8, _rest::binary>>) + when set?(flags, @padding_bit) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "HEADERS frame with invalid padding length (RFC9113§6.2)"} + end + + def deserialize( + flags, + stream_id, + <> + ) + when set?(flags, @priority_bit) do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_stream: set?(flags, @end_stream_bit), + end_headers: set?(flags, @end_headers_bit), + exclusive_dependency: exclusive_dependency == 1, + stream_dependency: stream_dependency, + weight: weight, + fragment: fragment + }} + end + + def deserialize(flags, stream_id, <>) + when not set?(flags, @priority_bit) and not set?(flags, @padding_bit) do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_stream: set?(flags, @end_stream_bit), + end_headers: set?(flags, @end_headers_bit), + fragment: fragment + }} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + @end_stream_bit 0 + @end_headers_bit 2 + + def serialize( + %GRPC.Transport.HTTP2.Frame.Headers{ + exclusive_dependency: false, + stream_dependency: nil, + weight: nil + } = + frame, + max_frame_size + ) do + flags = if frame.end_stream, do: [@end_stream_bit], else: [] + + fragment_length = IO.iodata_length(frame.fragment) + + if fragment_length <= max_frame_size do + [{1, set([@end_headers_bit | flags]), frame.stream_id, frame.fragment}] + else + <> = + IO.iodata_to_binary(frame.fragment) + + [ + {1, set(flags), frame.stream_id, this_frame} + | GRPC.Transport.HTTP2.Frame.Serializable.serialize( + %GRPC.Transport.HTTP2.Frame.Continuation{ + stream_id: frame.stream_id, + fragment: rest + }, + max_frame_size + ) + ] + end + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/ping.ex b/grpc_core/lib/grpc/transport/http2/frame/ping.ex new file mode 100644 index 000000000..155414202 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/ping.ex @@ -0,0 +1,45 @@ +defmodule GRPC.Transport.HTTP2.Frame.Ping do + @moduledoc false + + import GRPC.Transport.HTTP2.Frame.Flags + + defstruct ack: false, payload: nil + + @typedoc "An HTTP/2 PING frame" + @type t :: %__MODULE__{ + ack: boolean(), + payload: iodata() + } + + @ack_bit 0 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(flags, 0, <>) do + {:ok, %__MODULE__{ack: set?(flags, @ack_bit), payload: payload}} + end + + def deserialize(_flags, stream_id, _payload) when stream_id != 0 do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "Invalid stream ID in PING frame (RFC9113§6.7)"} + end + + def deserialize(_flags, _stream_id, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "PING frame with invalid payload size (RFC9113§6.7)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + @ack_bit 0 + + def serialize(%GRPC.Transport.HTTP2.Frame.Ping{ack: true} = frame, _max_frame_size), + do: [{6, set([@ack_bit]), 0, frame.payload}] + + def serialize(%GRPC.Transport.HTTP2.Frame.Ping{ack: false} = frame, _max_frame_size), + do: [{6, 0, 0, frame.payload}] + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/priority.ex b/grpc_core/lib/grpc/transport/http2/frame/priority.ex new file mode 100644 index 000000000..65815bc8f --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/priority.ex @@ -0,0 +1,54 @@ +defmodule GRPC.Transport.HTTP2.Frame.Priority do + @moduledoc false + + defstruct stream_id: nil, + exclusive_dependency: false, + stream_dependency: nil, + weight: nil + + @typedoc "An HTTP/2 PRIORITY frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + exclusive_dependency: boolean(), + stream_dependency: GRPC.Transport.HTTP2.Stream.stream_id(), + weight: non_neg_integer() + } + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, 0, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "PRIORITY frame with zero stream_id (RFC9113§6.3)"} + end + + def deserialize( + _flags, + stream_id, + <> + ) do + {:ok, + %__MODULE__{ + stream_id: stream_id, + exclusive_dependency: exclusive_dependency == 1, + stream_dependency: stream_dependency, + weight: weight + }} + end + + def deserialize(_flags, _stream_id, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Invalid payload size in PRIORITY frame (RFC9113§6.3)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + def serialize(%GRPC.Transport.HTTP2.Frame.Priority{} = frame, _max_frame_size) do + exclusive = if frame.exclusive_dependency, do: 0x01, else: 0x00 + payload = <> + [{0x2, 0x0, frame.stream_id, payload}] + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/push_promise.ex b/grpc_core/lib/grpc/transport/http2/frame/push_promise.ex new file mode 100644 index 000000000..e3dfb5806 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/push_promise.ex @@ -0,0 +1,73 @@ +defmodule GRPC.Transport.HTTP2.Frame.PushPromise do + @moduledoc false + + import GRPC.Transport.HTTP2.Frame.Flags + + defstruct stream_id: nil, + end_headers: false, + promised_stream_id: nil, + fragment: nil + + @typedoc "An HTTP/2 PUSH_PROMISE frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + end_headers: boolean(), + promised_stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + fragment: iodata() + } + + @end_headers_bit 2 + @padding_bit 3 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, 0, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "PUSH_PROMISE frame with zero stream_id (RFC9113§6.6)"} + end + + def deserialize( + flags, + stream_id, + <> + ) + when set?(flags, @padding_bit) and byte_size(rest) >= padding_length do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_headers: set?(flags, @end_headers_bit), + promised_stream_id: promised_stream_id, + fragment: binary_part(rest, 0, byte_size(rest) - padding_length) + }} + end + + def deserialize(flags, stream_id, <<_reserved::1, promised_stream_id::31, fragment::binary>>) + when not set?(flags, @padding_bit) do + {:ok, + %__MODULE__{ + stream_id: stream_id, + end_headers: set?(flags, @end_headers_bit), + promised_stream_id: promised_stream_id, + fragment: fragment + }} + end + + def deserialize(flags, _stream_id, <<_padding_length::8, _rest::binary>>) + when set?(flags, @padding_bit) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "PUSH_PROMISE frame with invalid padding length (RFC9113§6.6)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + @end_headers_bit 2 + + def serialize(%GRPC.Transport.HTTP2.Frame.PushPromise{} = frame, _max_frame_size) do + payload = <<0::1, frame.promised_stream_id::31, frame.fragment::binary>> + [{5, set([@end_headers_bit]), frame.stream_id, payload}] + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/rst_stream.ex b/grpc_core/lib/grpc/transport/http2/frame/rst_stream.ex new file mode 100644 index 000000000..e316dbc84 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/rst_stream.ex @@ -0,0 +1,37 @@ +defmodule GRPC.Transport.HTTP2.Frame.RstStream do + @moduledoc false + + defstruct stream_id: nil, error_code: nil + + @typedoc "An HTTP/2 RST_STREAM frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + error_code: GRPC.Transport.HTTP2.Errors.error_code() + } + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, 0, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "RST_STREAM frame with zero stream_id (RFC9113§6.4)"} + end + + def deserialize(_flags, stream_id, <>) do + {:ok, %__MODULE__{stream_id: stream_id, error_code: error_code}} + end + + def deserialize(_flags, _stream_id, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Invalid payload size in RST_STREAM frame (RFC9113§6.4)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + def serialize(%GRPC.Transport.HTTP2.Frame.RstStream{} = frame, _max_frame_size) do + [{3, 0, frame.stream_id, <>}] + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/settings.ex b/grpc_core/lib/grpc/transport/http2/frame/settings.ex new file mode 100644 index 000000000..5d8d9939b --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/settings.ex @@ -0,0 +1,123 @@ +defmodule GRPC.Transport.HTTP2.Frame.Settings do + @moduledoc false + + import GRPC.Transport.HTTP2.Frame.Flags + import Bitwise + + @max_window_size (1 <<< 31) - 1 + @min_frame_size 1 <<< 14 + @max_frame_size (1 <<< 24) - 1 + + defstruct ack: false, settings: nil + + @typedoc "An HTTP/2 SETTINGS frame" + @type t :: %__MODULE__{ack: true, settings: nil} | %__MODULE__{ack: false, settings: map()} + + @ack_bit 0 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(flags, 0, payload) when not set?(flags, @ack_bit) do + payload + |> Stream.unfold(fn + <<>> -> nil + <> -> {{:ok, {setting, value}}, rest} + <> -> {{:error, rest}, <<>>} + end) + |> Enum.reduce_while({:ok, %{}}, fn + {:ok, {1, value}}, {:ok, acc} -> + {:cont, {:ok, Map.put(acc, :header_table_size, value)}} + + {:ok, {2, val}}, {:ok, acc} when val in [0, 1] -> + {:cont, {:ok, acc}} + + {:ok, {2, _value}}, {:ok, _acc} -> + {:halt, + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), + "Invalid enable_push value (RFC9113§6.5)"}} + + {:ok, {3, value}}, {:ok, acc} -> + {:cont, {:ok, Map.put(acc, :max_concurrent_streams, value)}} + + {:ok, {4, value}}, {:ok, _acc} when value > @max_window_size -> + {:halt, + {:error, GRPC.Transport.HTTP2.Errors.flow_control_error(), + "Invalid window_size (RFC9113§6.5)"}} + + {:ok, {4, value}}, {:ok, acc} -> + {:cont, {:ok, Map.put(acc, :initial_window_size, value)}} + + {:ok, {5, value}}, {:ok, _acc} when value < @min_frame_size -> + {:halt, + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Invalid max_frame_size (RFC9113§6.5)"}} + + {:ok, {5, value}}, {:ok, _acc} when value > @max_frame_size -> + {:halt, + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Invalid max_frame_size (RFC9113§6.5)"}} + + {:ok, {5, value}}, {:ok, acc} -> + {:cont, {:ok, Map.put(acc, :max_frame_size, value)}} + + {:ok, {6, value}}, {:ok, acc} -> + {:cont, {:ok, Map.put(acc, :max_header_list_size, value)}} + + {:ok, {_setting, _value}}, {:ok, acc} -> + {:cont, {:ok, acc}} + + {:error, _rest}, _acc -> + {:halt, + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Invalid SETTINGS size (RFC9113§6.5)"}} + end) + |> case do + {:ok, settings} -> {:ok, %__MODULE__{ack: false, settings: settings}} + {:error, error_code, reason} -> {:error, error_code, reason} + end + end + + def deserialize(flags, 0, <<>>) when set?(flags, @ack_bit) do + {:ok, %__MODULE__{ack: true}} + end + + def deserialize(flags, 0, _payload) when set?(flags, @ack_bit) do + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "SETTINGS ack frame with non-empty payload (RFC9113§6.5)"} + end + + def deserialize(_flags, _stream_id, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.protocol_error(), "Invalid SETTINGS frame (RFC9113§6.5)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + @ack_bit 0 + + def serialize(%GRPC.Transport.HTTP2.Frame.Settings{ack: true}, _max_frame_size), + do: [{4, set([@ack_bit]), 0, <<>>}] + + def serialize(%GRPC.Transport.HTTP2.Frame.Settings{ack: false} = frame, _max_frame_size) do + payload = + frame.settings + |> Enum.uniq_by(fn {setting, _} -> setting end) + |> Enum.map(fn + {:header_table_size, 4_096} -> <<>> + {:header_table_size, value} -> <<1::16, value::32>> + {:max_concurrent_streams, :infinity} -> <<>> + {:max_concurrent_streams, value} -> <<3::16, value::32>> + {:initial_window_size, 65_535} -> <<>> + {:initial_window_size, value} -> <<4::16, value::32>> + {:max_frame_size, 16_384} -> <<>> + {:max_frame_size, value} -> <<5::16, value::32>> + {:max_header_list_size, :infinity} -> <<>> + {:max_header_list_size, value} -> <<6::16, value::32>> + end) + + [{4, 0, 0, payload}] + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/unknown.ex b/grpc_core/lib/grpc/transport/http2/frame/unknown.ex new file mode 100644 index 000000000..308d89061 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/unknown.ex @@ -0,0 +1,27 @@ +defmodule GRPC.Transport.HTTP2.Frame.Unknown do + @moduledoc false + + defstruct type: nil, + flags: nil, + stream_id: nil, + payload: nil + + @typedoc "An HTTP/2 frame of unknown type" + @type t :: %__MODULE__{ + type: GRPC.Transport.HTTP2.Frame.frame_type(), + flags: GRPC.Transport.HTTP2.Frame.flags(), + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + payload: iodata() + } + + # Note this is arity 4 + @spec deserialize( + GRPC.Transport.HTTP2.Frame.frame_type(), + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: {:ok, t()} + def deserialize(type, flags, stream_id, payload) do + {:ok, %__MODULE__{type: type, flags: flags, stream_id: stream_id, payload: payload}} + end +end diff --git a/grpc_core/lib/grpc/transport/http2/frame/window_update.ex b/grpc_core/lib/grpc/transport/http2/frame/window_update.ex new file mode 100644 index 000000000..bfb369230 --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/frame/window_update.ex @@ -0,0 +1,37 @@ +defmodule GRPC.Transport.HTTP2.Frame.WindowUpdate do + @moduledoc false + + import Bitwise + + defstruct stream_id: nil, size_increment: nil + + @typedoc "An HTTP/2 WINDOW_UPDATE frame" + @type t :: %__MODULE__{ + stream_id: GRPC.Transport.HTTP2.Stream.stream_id(), + size_increment: non_neg_integer() + } + + @max_window_increment (1 <<< 31) - 1 + + @spec deserialize( + GRPC.Transport.HTTP2.Frame.flags(), + GRPC.Transport.HTTP2.Stream.stream_id(), + iodata() + ) :: + {:ok, t()} | {:error, GRPC.Transport.HTTP2.Errors.error_code(), binary()} + def deserialize(_flags, stream_id, <<_reserved::1, size_increment::31>>) + when size_increment > 0 and size_increment <= @max_window_increment do + {:ok, %__MODULE__{stream_id: stream_id, size_increment: size_increment}} + end + + def deserialize(_flags, _stream_id, _payload) do + {:error, GRPC.Transport.HTTP2.Errors.frame_size_error(), + "Invalid WINDOW_UPDATE frame (RFC9113§6.9)"} + end + + defimpl GRPC.Transport.HTTP2.Frame.Serializable do + def serialize(%GRPC.Transport.HTTP2.Frame.WindowUpdate{} = frame, _max_frame_size) do + [{8, 0, frame.stream_id, <<0::1, frame.size_increment::31>>}] + end + end +end diff --git a/grpc_core/lib/grpc/transport/http2/settings.ex b/grpc_core/lib/grpc/transport/http2/settings.ex new file mode 100644 index 000000000..cf03b6baf --- /dev/null +++ b/grpc_core/lib/grpc/transport/http2/settings.ex @@ -0,0 +1,20 @@ +defmodule GRPC.Transport.HTTP2.Settings do + @moduledoc """ + Settings as defined in RFC9113§6.5.2 + """ + + defstruct header_table_size: 4_096, + max_concurrent_streams: :infinity, + initial_window_size: 65_535, + max_frame_size: 16_384, + max_header_list_size: :infinity + + @typedoc "A collection of settings as defined in RFC9113§6.5" + @type t :: %__MODULE__{ + header_table_size: non_neg_integer(), + max_concurrent_streams: non_neg_integer() | :infinity, + initial_window_size: non_neg_integer(), + max_frame_size: non_neg_integer(), + max_header_list_size: non_neg_integer() | :infinity + } +end diff --git a/grpc_core/test/grpc/transport/http2/errors_test.exs b/grpc_core/test/grpc/transport/http2/errors_test.exs new file mode 100644 index 000000000..3f771b588 --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/errors_test.exs @@ -0,0 +1,111 @@ +defmodule GRPC.Transport.HTTP2.ErrorsTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Errors + + describe "error codes" do + test "returns correct code for no_error" do + assert Errors.no_error() == 0x0 + end + + test "returns correct code for protocol_error" do + assert Errors.protocol_error() == 0x1 + end + + test "returns correct code for internal_error" do + assert Errors.internal_error() == 0x2 + end + + test "returns correct code for flow_control_error" do + assert Errors.flow_control_error() == 0x3 + end + + test "returns correct code for settings_timeout" do + assert Errors.settings_timeout() == 0x4 + end + + test "returns correct code for stream_closed" do + assert Errors.stream_closed() == 0x5 + end + + test "returns correct code for frame_size_error" do + assert Errors.frame_size_error() == 0x6 + end + + test "returns correct code for refused_stream" do + assert Errors.refused_stream() == 0x7 + end + + test "returns correct code for cancel" do + assert Errors.cancel() == 0x8 + end + + test "returns correct code for compression_error" do + assert Errors.compression_error() == 0x9 + end + + test "returns correct code for connect_error" do + assert Errors.connect_error() == 0xA + end + + test "returns correct code for enhance_your_calm" do + assert Errors.enhance_your_calm() == 0xB + end + + test "returns correct code for inadequate_security" do + assert Errors.inadequate_security() == 0xC + end + + test "returns correct code for http_1_1_requires" do + assert Errors.http_1_1_requires() == 0xD + end + end + + describe "ConnectionError exception" do + test "can be raised with message" do + assert_raise Errors.ConnectionError, "test error", fn -> + raise Errors.ConnectionError, message: "test error" + end + end + + test "can be raised with error code" do + exception = %Errors.ConnectionError{message: "error", error_code: 0x1} + assert exception.error_code == 0x1 + end + + test "can be raised with both message and error code" do + exception = %Errors.ConnectionError{ + message: "protocol violation", + error_code: Errors.protocol_error() + } + + assert exception.message == "protocol violation" + assert exception.error_code == 0x1 + end + end + + describe "StreamError exception" do + test "can be raised with message" do + assert_raise Errors.StreamError, "stream error", fn -> + raise Errors.StreamError, message: "stream error" + end + end + + test "can be raised with stream_id" do + exception = %Errors.StreamError{message: "error", stream_id: 1} + assert exception.stream_id == 1 + end + + test "can be raised with all fields" do + exception = %Errors.StreamError{ + message: "stream closed", + error_code: Errors.stream_closed(), + stream_id: 3 + } + + assert exception.message == "stream closed" + assert exception.error_code == 0x5 + assert exception.stream_id == 3 + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/flow_control_test.exs b/grpc_core/test/grpc/transport/http2/flow_control_test.exs new file mode 100644 index 000000000..9fd3eda7b --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/flow_control_test.exs @@ -0,0 +1,134 @@ +defmodule GRPC.Transport.HTTP2.FlowControlTest do + use ExUnit.Case, async: true + + import Bitwise + alias GRPC.Transport.HTTP2.FlowControl + + describe "compute_recv_window/2" do + test "returns correct window size when above minimum threshold" do + # Start with a large window above the minimum threshold (1GB) + large_window = (1 <<< 30) + 1_000_000 + # Receive small amount of data + {new_window, increment} = FlowControl.compute_recv_window(large_window, 1_024) + + # Should still be above threshold, no increment needed + assert new_window == large_window - 1_024 + assert increment == 0 + end + + test "returns window increment when below minimum threshold" do + # Start with minimum threshold + 1 (1GB + 1) + min_threshold = 1 <<< 30 + # Receive enough data to go below threshold + data_size = 2 + + {new_window, increment} = FlowControl.compute_recv_window(min_threshold + 1, data_size) + + # Should have sent a WINDOW_UPDATE + assert increment > 0 + # New window should be original - data + increment + assert new_window == min_threshold + 1 - data_size + increment + end + + test "respects maximum window increment" do + # Maximum increment is 2^31 - 1 + max_increment = (1 <<< 31) - 1 + + # Start with a very small window + {new_window, increment} = FlowControl.compute_recv_window(1, 0) + + # Increment should not exceed maximum + assert increment <= max_increment + assert new_window == 1 + increment + end + + test "respects maximum window size" do + # Maximum window size is 2^31 - 1 + max_window = (1 <<< 31) - 1 + + # Start with a small window that would overflow + {new_window, _increment} = FlowControl.compute_recv_window(1, 0) + + # New window should not exceed maximum + assert new_window <= max_window + end + + test "handles zero data size" do + initial_window = 65_535 + {new_window, increment} = FlowControl.compute_recv_window(initial_window, 0) + + # Window might change if below threshold (function computes new window) + # Just verify it's reasonable + assert new_window >= initial_window + assert increment >= 0 + end + + test "handles large data sizes" do + initial_window = 1 <<< 30 + # Receive 100MB of data + data_size = 100 * 1024 * 1024 + + {new_window, _increment} = FlowControl.compute_recv_window(initial_window, data_size) + + # Window should be decreased by data size, then potentially increased by increment + # Just verify the window is reasonable (non-negative and less than max) + assert new_window >= 0 + assert new_window <= (1 <<< 31) - 1 + end + end + + describe "update_window/2" do + test "updates window size with positive increment" do + assert {:ok, 100} = FlowControl.update_window(50, 50) + end + + test "updates window size with negative increment" do + assert {:ok, 50} = FlowControl.update_window(100, -50) + end + + test "allows window size of zero" do + assert {:ok, 0} = FlowControl.update_window(50, -50) + end + + test "returns error when window size would exceed maximum" do + max_window = (1 <<< 31) - 1 + + assert {:error, :flow_control_error} = + FlowControl.update_window(max_window, 1) + end + + test "returns error when increment causes overflow" do + # Window size is at max, any positive increment should fail + max_window = (1 <<< 31) - 1 + + assert {:error, :flow_control_error} = + FlowControl.update_window(max_window, 100) + end + + test "allows update to exactly maximum window size" do + max_window = (1 <<< 31) - 1 + + assert {:ok, ^max_window} = FlowControl.update_window(max_window - 100, 100) + end + + test "handles large positive increments" do + large_increment = 1_000_000 + + assert {:ok, 1_000_050} = FlowControl.update_window(50, large_increment) + end + + test "handles large negative increments" do + large_decrement = -1_000_000 + + assert {:ok, 0} = FlowControl.update_window(1_000_000, large_decrement) + end + + test "RFC 9113 compliance - maximum window size is 2^31-1" do + # RFC 9113 §6.9.1 + max_window = (1 <<< 31) - 1 + + assert {:ok, ^max_window} = FlowControl.update_window(0, max_window) + assert {:error, :flow_control_error} = FlowControl.update_window(max_window, 1) + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/continuation_test.exs b/grpc_core/test/grpc/transport/http2/frame/continuation_test.exs new file mode 100644 index 000000000..82cddce2c --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/continuation_test.exs @@ -0,0 +1,359 @@ +defmodule GRPC.Transport.HTTP2.Frame.ContinuationTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "CONTINUATION frame deserialization" do + test "deserializes basic CONTINUATION frame" do + # CONTINUATION with END_HEADERS flag + data = <<3::24, 9::8, 0x4::8, 0::1, 1::31, "hdr">> + + assert {{:ok, %Frame.Continuation{stream_id: 1, end_headers: true, fragment: "hdr"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes CONTINUATION without END_HEADERS" do + # More CONTINUATION frames will follow + data = <<3::24, 9::8, 0x0::8, 0::1, 1::31, "hdr">> + + assert {{:ok, %Frame.Continuation{stream_id: 1, end_headers: false, fragment: "hdr"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes CONTINUATION with large fragment" do + fragment = String.duplicate("x", 16384) + data = <> + + assert {{:ok, %Frame.Continuation{stream_id: 1, fragment: ^fragment}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects CONTINUATION with stream_id 0" do + # RFC 9113 §6.10: CONTINUATION frames MUST be associated with a stream + data = <<3::24, 9::8, 0x4::8, 0::1, 0::31, "hdr">> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.protocol_error() + end + + test "handles empty CONTINUATION frame" do + # Edge case: CONTINUATION with no payload (unusual but valid) + data = <<0::24, 9::8, 0x4::8, 0::1, 1::31>> + + assert {{:ok, %Frame.Continuation{stream_id: 1, end_headers: true, fragment: <<>>}}, <<>>} = + Frame.deserialize(data, 16_384) + end + end + + describe "CONTINUATION frame serialization" do + test "serializes CONTINUATION with END_HEADERS" do + frame = %Frame.Continuation{ + stream_id: 123, + end_headers: true, + fragment: "hdr" + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<3::24, 9::8, 0x4::8, 0::1, 123::31, "hdr">> = binary + end + + test "serializes CONTINUATION with large fragment" do + fragment = String.duplicate("x", 10000) + + frame = %Frame.Continuation{ + stream_id: 123, + end_headers: true, + fragment: fragment + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<10000::24, 9::8, 0x4::8, 0::1, 123::31, received_fragment::binary>> = binary + assert received_fragment == fragment + end + end + + describe "CONTINUATION sequence scenarios" do + test "handles HEADERS + CONTINUATION sequence" do + # Large header block split across HEADERS and CONTINUATION + full_headers = "very-long-header-block-that-exceeds-max-frame-size" + + # When serializing with small max_frame_size, it splits automatically + headers = %Frame.Headers{ + stream_id: 1, + end_headers: true, + fragment: full_headers + } + + # Serialize with small frame size to force split + frames_io = Frame.serialize(headers, 20) + + # Should produce [HEADERS, CONTINUATION] + # 3 frames for 51 bytes with 20 byte limit + assert length(frames_io) == 3 + + # Deserialize first frame (HEADERS) + [h_io, c1_io, c2_io] = frames_io + h_binary = IO.iodata_to_binary(h_io) + c1_binary = IO.iodata_to_binary(c1_io) + c2_binary = IO.iodata_to_binary(c2_io) + + {{:ok, h_frame}, <<>>} = Frame.deserialize(h_binary, 16_384) + {{:ok, c1_frame}, <<>>} = Frame.deserialize(c1_binary, 16_384) + {{:ok, c2_frame}, <<>>} = Frame.deserialize(c2_binary, 16_384) + + # Reconstruct full header block + reconstructed = + IO.iodata_to_binary([h_frame.fragment, c1_frame.fragment, c2_frame.fragment]) + + assert reconstructed == full_headers + + # First HEADERS should not have END_HEADERS + assert h_frame.end_headers == false + # Middle CONTINUATION should not have END_HEADERS + assert c1_frame.end_headers == false + # Last CONTINUATION should have END_HEADERS + assert c2_frame.end_headers == true + end + + test "handles HEADERS + multiple CONTINUATION frames" do + # Very large header block requiring multiple CONTINUATIONs + full_fragment = "part1part2part3" + + headers = %Frame.Headers{stream_id: 1, end_headers: true, fragment: full_fragment} + + # Serialize with small max_frame_size + frames_io = Frame.serialize(headers, 5) + + # Should split into multiple frames + assert length(frames_io) == 3 + + [h_io, c1_io, c2_io] = frames_io + + h_bin = IO.iodata_to_binary(h_io) + c1_bin = IO.iodata_to_binary(c1_io) + c2_bin = IO.iodata_to_binary(c2_io) + + {{:ok, h}, <<>>} = Frame.deserialize(h_bin, 16_384) + {{:ok, c1}, <<>>} = Frame.deserialize(c1_bin, 16_384) + {{:ok, c2}, <<>>} = Frame.deserialize(c2_bin, 16_384) + + reconstructed = IO.iodata_to_binary([h.fragment, c1.fragment, c2.fragment]) + assert reconstructed == full_fragment + assert h.end_headers == false + assert c1.end_headers == false + assert c2.end_headers == true + end + + test "verifies stream_id consistency across sequence" do + # All frames in sequence must have same stream_id + headers = %Frame.Headers{stream_id: 5, end_headers: false, fragment: "h"} + cont = %Frame.Continuation{stream_id: 5, end_headers: true, fragment: "c"} + + h_bin = IO.iodata_to_binary(Frame.serialize(headers, 16_384)) + c_bin = IO.iodata_to_binary(Frame.serialize(cont, 16_384)) + + {{:ok, h_frame}, <<>>} = Frame.deserialize(h_bin, 16_384) + {{:ok, c_frame}, <<>>} = Frame.deserialize(c_bin, 16_384) + + assert h_frame.stream_id == c_frame.stream_id + end + end + + describe "gRPC-specific scenarios" do + test "handles large gRPC metadata headers" do + # gRPC metadata can be large, requiring CONTINUATION + metadata_headers = + for i <- 1..50 do + "x-custom-header-#{i}: value-#{i}\n" + end + |> Enum.join() + + # Simulate splitting at max_frame_size + max_size = 100 + chunks = for <>, do: chunk + + # Add any remaining bytes + remainder_size = rem(byte_size(metadata_headers), max_size) + + chunks = + if remainder_size > 0 do + chunks ++ + [ + binary_part( + metadata_headers, + byte_size(metadata_headers) - remainder_size, + remainder_size + ) + ] + else + chunks + end + + [first | rest] = chunks + + headers = %Frame.Headers{stream_id: 1, end_headers: false, fragment: first} + + # Middle chunks in CONTINUATION without END_HEADERS + middle = Enum.slice(rest, 0..-2//1) + + continuations = + for {chunk, _idx} <- Enum.with_index(middle) do + %Frame.Continuation{stream_id: 1, end_headers: false, fragment: chunk} + end + + last_chunk = List.last(rest) + + final_cont = %Frame.Continuation{ + stream_id: 1, + end_headers: true, + fragment: last_chunk || <<>> + } + + all_frames = [headers | continuations] ++ [final_cont] + + serialized = Enum.map(all_frames, &Frame.serialize(&1, 16_384)) + assert length(serialized) == length(all_frames) + end + + test "handles gRPC trailers with CONTINUATION" do + # Trailers can be large if they include detailed error info + trailers = """ + grpc-status: 13 + grpc-message: Internal server error + grpc-status-details-bin: #{String.duplicate("x", 500)} + x-debug-info: #{String.duplicate("y", 500)} + """ + + # Split into HEADERS + CONTINUATION + split_point = 100 + + headers = %Frame.Headers{ + stream_id: 1, + end_stream: true, + end_headers: false, + fragment: binary_part(trailers, 0, split_point) + } + + continuation = %Frame.Continuation{ + stream_id: 1, + end_headers: true, + fragment: binary_part(trailers, split_point, byte_size(trailers) - split_point) + } + + h_bin = IO.iodata_to_binary(Frame.serialize(headers, 16_384)) + c_bin = IO.iodata_to_binary(Frame.serialize(continuation, 16_384)) + + {{:ok, h}, <<>>} = Frame.deserialize(h_bin, 16_384) + {{:ok, c}, <<>>} = Frame.deserialize(c_bin, 16_384) + + reconstructed = h.fragment <> c.fragment + assert reconstructed == trailers + assert h.end_stream == true + end + + test "handles HPACK compressed continuation" do + # HPACK encoding can produce variable-length output + # Simulated HPACK encoded data + hpack_encoded = <<0x82, 0x86, 0x84>> <> String.duplicate(<<0x41>>, 100) + + # If exceeds frame size, split into CONTINUATION + if byte_size(hpack_encoded) > 50 do + headers = %Frame.Headers{ + stream_id: 1, + end_headers: false, + fragment: binary_part(hpack_encoded, 0, 50) + } + + continuation = %Frame.Continuation{ + stream_id: 1, + end_headers: true, + fragment: binary_part(hpack_encoded, 50, byte_size(hpack_encoded) - 50) + } + + h_result = Frame.serialize(headers, 16_384) + c_result = Frame.serialize(continuation, 16_384) + + assert is_list(h_result) + assert is_list(c_result) + end + end + + test "handles interleaved stream violation detection" do + # RFC 9113: CONTINUATION frames MUST follow HEADERS immediately + # No other frames can be sent on ANY stream until END_HEADERS + + # Correct sequence: HEADERS(stream=1, no END_HEADERS) -> CONTINUATION(stream=1) + headers = %Frame.Headers{stream_id: 1, end_headers: false, fragment: "h1"} + continuation = %Frame.Continuation{stream_id: 1, end_headers: true, fragment: "c1"} + + # Frames must be processed in order + h_bin = IO.iodata_to_binary(Frame.serialize(headers, 16_384)) + c_bin = IO.iodata_to_binary(Frame.serialize(continuation, 16_384)) + + # Deserialize in correct order + {{:ok, h}, <<>>} = Frame.deserialize(h_bin, 16_384) + {{:ok, c}, <<>>} = Frame.deserialize(c_bin, 16_384) + + assert h.stream_id == c.stream_id + assert c.end_headers == true + end + + test "handles CONTINUATION frame splitting strategy" do + # gRPC implementation should split at frame boundaries + large_metadata = String.duplicate("x-header: value\n", 1000) + max_frame = 16384 + + # Calculate number of frames needed + num_frames = div(byte_size(large_metadata), max_frame) + 1 + + # First frame is HEADERS + # Remaining are CONTINUATION + # Last frame has END_HEADERS + first_fragment = binary_part(large_metadata, 0, min(max_frame, byte_size(large_metadata))) + + headers = %Frame.Headers{ + stream_id: 1, + end_headers: num_frames == 1, + fragment: first_fragment + } + + assert is_struct(headers, Frame.Headers) + end + end + + describe "edge cases" do + test "handles maximum frame size CONTINUATION" do + # Test with exactly max_frame_size payload + max_payload = String.duplicate("x", 16384) + + frame = %Frame.Continuation{ + stream_id: 1, + end_headers: true, + fragment: max_payload + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert byte_size(binary) == 9 + 16384 + end + + test "handles minimum size CONTINUATION" do + frame = %Frame.Continuation{ + stream_id: 1, + end_headers: true, + fragment: <<>> + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<0::24, 9::8, 0x4::8, 0::1, 1::31>> = binary + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/data_test.exs b/grpc_core/test/grpc/transport/http2/frame/data_test.exs new file mode 100644 index 000000000..a3ceb3a58 --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/data_test.exs @@ -0,0 +1,179 @@ +defmodule GRPC.Transport.HTTP2.Frame.DataTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "DATA frame deserialization" do + test "deserializes basic DATA frame" do + # DATA frame: stream_id=1, no padding, end_stream=false, data="hello" + data = <<5::24, 0::8, 0::8, 0::1, 1::31, "hello">> + + assert {{:ok, %Frame.Data{stream_id: 1, end_stream: false, data: "hello"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes DATA frame with END_STREAM flag" do + # Flags: END_STREAM (0x1) + data = <<5::24, 0::8, 0x1::8, 0::1, 1::31, "hello">> + + assert {{:ok, %Frame.Data{stream_id: 1, end_stream: true, data: "hello"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes DATA frame with padding" do + # Flags: PADDED (0x8), padding_length=3 + payload = <<3::8, "hello", 0::8, 0::8, 0::8>> + data = <> + + assert {{:ok, %Frame.Data{stream_id: 1, data: "hello"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes empty DATA frame" do + data = <<0::24, 0::8, 0::8, 0::1, 1::31>> + + assert {{:ok, %Frame.Data{stream_id: 1, data: <<>>}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects DATA frame with stream_id 0" do + # RFC 9113 §6.1: DATA frames MUST be associated with a stream + data = <<3::24, 0::8, 0::8, 0::1, 0::31, "abc">> + + assert {{:error, error_code, "DATA frame with zero stream_id (RFC9113§6.1)"}, <<>>} = + Frame.deserialize(data, 16_384) + + assert error_code == Errors.protocol_error() + end + + test "handles large DATA frames" do + large_data = :binary.copy(<<1>>, 10_000) + data = <<10_000::24, 0::8, 0::8, 0::1, 1::31, large_data::binary>> + + assert {{:ok, %Frame.Data{data: ^large_data}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects DATA frame with excessive padding" do + # Padding length exceeds payload + payload = <<10::8, "abc">> + data = <> + + assert {{:error, _error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + end + end + + describe "DATA frame serialization" do + test "serializes basic DATA frame" do + frame = %Frame.Data{stream_id: 123, end_stream: false, data: "hello"} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<5::24, 0::8, 0::8, 0::1, 123::31, "hello">> = binary + end + + test "serializes DATA frame with END_STREAM flag" do + frame = %Frame.Data{stream_id: 123, end_stream: true, data: "hello"} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<5::24, 0::8, 0x1::8, 0::1, 123::31, "hello">> = binary + end + + test "serializes empty DATA frame" do + frame = %Frame.Data{stream_id: 123, end_stream: false, data: <<>>} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<0::24, 0::8, 0::8, 0::1, 123::31>> = binary + end + + test "splits DATA frame exceeding max_frame_size" do + # 5 bytes of data, but max_frame_size is 2 + frame = %Frame.Data{stream_id: 123, end_stream: false, data: "hello"} + + result = Frame.serialize(frame, 2) + + assert [ + [<<2::24, 0::8, 0::8, 0::1, 123::31>>, "he"], + [<<2::24, 0::8, 0::8, 0::1, 123::31>>, "ll"], + [<<1::24, 0::8, 0::8, 0::1, 123::31>>, "o"] + ] = result + end + + test "sets END_STREAM only on last frame when splitting" do + # Should split into 3 frames, END_STREAM only on last + frame = %Frame.Data{stream_id: 123, end_stream: true, data: "hello"} + + result = Frame.serialize(frame, 2) + + # First two frames should not have END_STREAM + [[<<2::24, 0::8, 0x0::8, _::binary>>, _], [<<2::24, 0::8, 0x0::8, _::binary>>, _], _last] = + result + + # Last frame should have END_STREAM (0x1) + [[<<1::24, 0::8, 0x1::8, 0::1, 123::31>>, "o"]] = [List.last(result)] + end + + test "handles binary data (not just strings)" do + binary_data = <<0, 1, 2, 255, 254, 253>> + frame = %Frame.Data{stream_id: 123, end_stream: false, data: binary_data} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<6::24, 0::8, 0::8, 0::1, 123::31, ^binary_data::binary>> = binary + end + + test "preserves iodata structure for efficiency" do + frame = %Frame.Data{stream_id: 123, data: ["hello", " ", "world"]} + + result = Frame.serialize(frame, 16_384) + + assert is_list(result) + assert IO.iodata_to_binary(result) =~ "hello world" + end + end + + describe "gRPC-specific scenarios" do + test "handles gRPC message framing (5-byte length prefix)" do + # gRPC message: compressed_flag(1) + length(4) + data + grpc_msg = <<0::8, 5::32, "hello">> + frame = %Frame.Data{stream_id: 1, end_stream: false, data: grpc_msg} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<_header::9-bytes, payload::binary>> = binary + assert payload == grpc_msg + end + + test "handles multiple gRPC messages in single DATA frame" do + # Two gRPC messages back-to-back + msg1 = <<0::8, 5::32, "hello">> + msg2 = <<0::8, 5::32, "world">> + combined = msg1 <> msg2 + + frame = %Frame.Data{stream_id: 1, end_stream: false, data: combined} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<_header::9-bytes, payload::binary>> = binary + assert payload == combined + end + + test "handles compressed gRPC messages" do + # Compressed message: flag=1 + compressed_msg = <<1::8, 100::32, :zlib.compress("large data")::binary>> + frame = %Frame.Data{stream_id: 1, end_stream: false, data: compressed_msg} + + result = Frame.serialize(frame, 16_384) + assert is_list(result) + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/goaway_test.exs b/grpc_core/test/grpc/transport/http2/frame/goaway_test.exs new file mode 100644 index 000000000..c87d72433 --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/goaway_test.exs @@ -0,0 +1,347 @@ +defmodule GRPC.Transport.HTTP2.Frame.GoawayTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "GOAWAY frame deserialization" do + test "deserializes GOAWAY frame" do + # GOAWAY: last_stream_id=123, error=NO_ERROR, debug="" + payload = <<0::1, 123::31, 0x0::32>> + data = <> + + assert {{:ok, %Frame.Goaway{last_stream_id: 123, error_code: 0x0, debug_data: <<>>}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes GOAWAY with debug data" do + debug = "shutting down" + payload = <<0::1, 123::31, 0x0::32, debug::binary>> + data = <> + + assert {{:ok, %Frame.Goaway{last_stream_id: 123, error_code: 0x0, debug_data: ^debug}}, + <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes GOAWAY with PROTOCOL_ERROR" do + payload = <<0::1, 50::31, 0x1::32>> + data = <> + + assert {{:ok, %Frame.Goaway{last_stream_id: 50, error_code: 0x1, debug_data: <<>>}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes GOAWAY with ENHANCE_YOUR_CALM" do + payload = <<0::1, 10::31, 0xB::32>> + data = <> + + assert {{:ok, %Frame.Goaway{last_stream_id: 10, error_code: 0xB, debug_data: <<>>}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes GOAWAY with last_stream_id 0" do + # No streams were processed + payload = <<0::1, 0::31, 0x0::32>> + data = <> + + assert {{:ok, %Frame.Goaway{last_stream_id: 0, error_code: 0x0, debug_data: <<>>}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects GOAWAY with non-zero stream_id" do + # RFC 9113 §6.8: GOAWAY frames MUST be associated with stream 0 + payload = <<0::1, 123::31, 0x0::32>> + data = <> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.protocol_error() + end + + # Note: GOAWAY deserialization uses pattern matching, so insufficient length + # causes a function clause error rather than returning an error tuple + # This is by design - the frame parser validates lengths before deserialization + + test "handles GOAWAY with large debug data" do + debug = String.duplicate("debug info ", 100) + payload = <<0::1, 123::31, 0x0::32, debug::binary>> + data = <> + + assert {{:ok, %Frame.Goaway{debug_data: ^debug}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "handles reserved bit correctly" do + # Reserved bit should be ignored + payload = <<1::1, 123::31, 0x0::32>> + data = <> + + assert {{:ok, %Frame.Goaway{last_stream_id: 123}}, <<>>} = + Frame.deserialize(data, 16_384) + end + end + + describe "GOAWAY frame serialization" do + test "serializes GOAWAY with NO_ERROR" do + frame = %Frame.Goaway{ + last_stream_id: 123, + error_code: 0x0, + debug_data: <<>> + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<8::24, 7::8, 0x0::8, 0::1, 0::31, 0::1, 123::31, 0x0::32>> = binary + end + + test "serializes GOAWAY with debug data" do + debug = "shutting down" + + frame = %Frame.Goaway{ + last_stream_id: 123, + error_code: 0x0, + debug_data: debug + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + payload_length = 8 + byte_size(debug) + + <<^payload_length::24, 7::8, 0x0::8, 0::1, 0::31, 0::1, 123::31, 0x0::32, ^debug::binary>> = + binary + end + + test "serializes GOAWAY with INTERNAL_ERROR" do + frame = %Frame.Goaway{ + last_stream_id: 50, + error_code: 0x2, + debug_data: <<>> + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<8::24, 7::8, 0x0::8, 0::1, 0::31, 0::1, 50::31, 0x2::32>> = binary + end + + test "sets reserved bit to 0" do + frame = %Frame.Goaway{ + last_stream_id: 123, + error_code: 0x0, + debug_data: <<>> + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<_::9-bytes, reserved::1, _last_stream::31, _error::32>> = binary + assert reserved == 0 + end + end + + describe "gRPC-specific scenarios" do + test "handles graceful server shutdown" do + # Server initiates graceful shutdown with NO_ERROR + # Sets last_stream_id to highest processed stream + shutdown = %Frame.Goaway{ + last_stream_id: 99, + error_code: Errors.no_error(), + debug_data: "server shutting down gracefully" + } + + result = Frame.serialize(shutdown, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.last_stream_id == 99 + assert received.error_code == Errors.no_error() + assert received.debug_data =~ "graceful" + end + + test "handles connection timeout" do + # Connection idle timeout triggers GOAWAY + timeout = %Frame.Goaway{ + last_stream_id: 50, + error_code: Errors.no_error(), + debug_data: "idle timeout" + } + + result = Frame.serialize(timeout, 16_384) + assert is_list(result) + end + + test "handles protocol violation shutdown" do + # Peer violates protocol, connection terminated + violation = %Frame.Goaway{ + last_stream_id: 25, + error_code: Errors.protocol_error(), + debug_data: "invalid frame sequence" + } + + result = Frame.serialize(violation, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == Errors.protocol_error() + end + + test "handles connection overload" do + # Server overloaded, sends GOAWAY with ENHANCE_YOUR_CALM + overload = %Frame.Goaway{ + last_stream_id: 10, + error_code: Errors.enhance_your_calm(), + debug_data: "too many requests" + } + + result = Frame.serialize(overload, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == Errors.enhance_your_calm() + end + + test "handles client-initiated close" do + # Client closes connection cleanly + close = %Frame.Goaway{ + last_stream_id: 0, + error_code: Errors.no_error(), + debug_data: <<>> + } + + result = Frame.serialize(close, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<8::24, 7::8, 0x0::8, 0::1, 0::31, 0::1, 0::31, _::32>> = binary + end + + test "handles two-phase shutdown" do + # Server sends GOAWAY with high last_stream_id + # Waits for in-flight streams to complete + # Sends final GOAWAY with actual last_stream_id + + phase1 = %Frame.Goaway{ + last_stream_id: 0x7FFFFFFF, + error_code: Errors.no_error(), + debug_data: "draining connections" + } + + phase2 = %Frame.Goaway{ + last_stream_id: 42, + error_code: Errors.no_error(), + debug_data: "final shutdown" + } + + phase1_binary = IO.iodata_to_binary(Frame.serialize(phase1, 16_384)) + phase2_binary = IO.iodata_to_binary(Frame.serialize(phase2, 16_384)) + + {{:ok, p1}, <<>>} = Frame.deserialize(phase1_binary, 16_384) + {{:ok, p2}, <<>>} = Frame.deserialize(phase2_binary, 16_384) + + assert p1.last_stream_id == 0x7FFFFFFF + assert p2.last_stream_id == 42 + end + + test "handles internal error during processing" do + # Unexpected internal error triggers immediate shutdown + internal = %Frame.Goaway{ + last_stream_id: 15, + error_code: Errors.internal_error(), + debug_data: "unexpected exception in handler" + } + + result = Frame.serialize(internal, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == Errors.internal_error() + end + + test "handles flow control error shutdown" do + # Global flow control violation requires connection close + flow_error = %Frame.Goaway{ + last_stream_id: 20, + error_code: Errors.flow_control_error(), + debug_data: "connection window exceeded" + } + + result = Frame.serialize(flow_error, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == Errors.flow_control_error() + end + + test "includes diagnostic info in debug data" do + # Useful debug information for troubleshooting + diagnostic = %Frame.Goaway{ + last_stream_id: 30, + error_code: Errors.protocol_error(), + debug_data: "frame_type=1 stream_id=31 error=invalid_headers" + } + + result = Frame.serialize(diagnostic, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.debug_data =~ "frame_type" + assert received.debug_data =~ "invalid_headers" + end + + test "handles connection state after GOAWAY" do + # After sending GOAWAY, no new streams should be created + # Existing streams can complete + + goaway = %Frame.Goaway{ + last_stream_id: 10, + error_code: Errors.no_error(), + debug_data: "no new streams" + } + + result = Frame.serialize(goaway, 16_384) + binary = IO.iodata_to_binary(result) + + # Any stream_id <= 10 can still send frames + # Any stream_id > 10 should be rejected + {{:ok, %{last_stream_id: last_stream}}, <<>>} = Frame.deserialize(binary, 16_384) + assert last_stream == 10 + end + end + + describe "error code mapping" do + test "maps all standard HTTP/2 error codes" do + error_scenarios = [ + {Errors.no_error(), "clean shutdown"}, + {Errors.protocol_error(), "protocol violation"}, + {Errors.internal_error(), "internal server error"}, + {Errors.flow_control_error(), "flow control violation"}, + {Errors.settings_timeout(), "settings timeout"}, + {Errors.stream_closed(), "frame on closed stream"}, + {Errors.frame_size_error(), "invalid frame size"}, + {Errors.refused_stream(), "stream refused"}, + {Errors.cancel(), "operation cancelled"}, + {Errors.compression_error(), "compression error"}, + {Errors.connect_error(), "connect error"}, + {Errors.enhance_your_calm(), "excessive load"}, + {Errors.inadequate_security(), "security requirements not met"}, + {Errors.http_1_1_requires(), "HTTP/1.1 required"} + ] + + for {error_code, description} <- error_scenarios do + frame = %Frame.Goaway{ + last_stream_id: 1, + error_code: error_code, + debug_data: description + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == error_code + assert received.debug_data == description + end + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/headers_test.exs b/grpc_core/test/grpc/transport/http2/frame/headers_test.exs new file mode 100644 index 000000000..08a9de45e --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/headers_test.exs @@ -0,0 +1,236 @@ +defmodule GRPC.Transport.HTTP2.Frame.HeadersTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "HEADERS frame deserialization" do + test "deserializes basic HEADERS frame" do + # HEADERS frame with END_HEADERS flag + data = <<3::24, 1::8, 0x4::8, 0::1, 1::31, "hdr">> + + assert {{:ok, %Frame.Headers{stream_id: 1, end_headers: true, fragment: "hdr"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes HEADERS frame with END_STREAM flag" do + # Flags: END_STREAM (0x1) + END_HEADERS (0x4) = 0x5 + data = <<3::24, 1::8, 0x5::8, 0::1, 1::31, "hdr">> + + assert {{:ok, + %Frame.Headers{stream_id: 1, end_stream: true, end_headers: true, fragment: "hdr"}}, + <<>>} = Frame.deserialize(data, 16_384) + end + + test "deserializes HEADERS frame without END_HEADERS (needs CONTINUATION)" do + # No END_HEADERS flag - requires CONTINUATION + data = <<3::24, 1::8, 0x0::8, 0::1, 1::31, "hdr">> + + assert {{:ok, %Frame.Headers{stream_id: 1, end_headers: false, fragment: "hdr"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes HEADERS frame with PRIORITY flag" do + # Flags: PRIORITY (0x20) + END_HEADERS (0x4) = 0x24 + # Priority: exclusive=1, dependency=5, weight=10 + priority = <<1::1, 5::31, 10::8>> + + data = + < "hdr")::24, 1::8, 0x24::8, 0::1, 1::31, priority::binary, "hdr">> + + assert {{:ok, + %Frame.Headers{ + stream_id: 1, + exclusive_dependency: true, + stream_dependency: 5, + weight: 10, + fragment: "hdr" + }}, <<>>} = Frame.deserialize(data, 16_384) + end + + test "deserializes HEADERS frame with padding" do + # Flags: PADDED (0x8) + END_HEADERS (0x4) = 0xC + payload = <<2::8, "hdr", 0::8, 0::8>> + data = <> + + assert {{:ok, %Frame.Headers{stream_id: 1, fragment: "hdr"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects HEADERS frame with stream_id 0" do + # RFC 9113 §6.2: HEADERS frames MUST be associated with a stream + data = <<3::24, 1::8, 0x4::8, 0::1, 0::31, "hdr">> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.protocol_error() + end + + test "handles HPACK compressed headers" do + # Simulated HPACK encoded headers (not real HPACK encoding) + hpack_data = <<0x82, 0x86, 0x84, 0x41, 0x0F>> + data = <> + + assert {{:ok, %Frame.Headers{fragment: ^hpack_data}}, <<>>} = + Frame.deserialize(data, 16_384) + end + end + + describe "HEADERS frame serialization" do + test "serializes basic HEADERS frame with END_HEADERS" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: false, + end_headers: true, + fragment: "hdr" + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Flags: END_HEADERS (0x4) + assert <<3::24, 1::8, 0x4::8, 0::1, 123::31, "hdr">> = binary + end + + test "serializes HEADERS frame with END_STREAM and END_HEADERS" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: true, + end_headers: true, + fragment: "hdr" + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Flags: END_STREAM (0x1) + END_HEADERS (0x4) = 0x5 + assert <<3::24, 1::8, 0x5::8, 0::1, 123::31, "hdr">> = binary + end + + test "splits large HEADERS into HEADERS + CONTINUATION frames" do + # 6 bytes of headers, max_frame_size = 2 + frame = %Frame.Headers{ + stream_id: 123, + end_stream: false, + end_headers: true, + fragment: "header" + } + + result = Frame.serialize(frame, 2) + + # Should produce: HEADERS (no END_HEADERS) + CONTINUATION + CONTINUATION (END_HEADERS) + assert [ + [<<2::24, 1::8, 0x0::8, _::binary>>, "he"], + [<<2::24, 9::8, 0x0::8, _::binary>>, "ad"], + [<<2::24, 9::8, 0x4::8, _::binary>>, "er"] + ] = result + end + + test "CONTINUATION frames have END_HEADERS only on last frame" do + frame = %Frame.Headers{ + stream_id: 123, + end_headers: true, + fragment: "abcde" + } + + result = Frame.serialize(frame, 2) + + # First HEADERS: no END_HEADERS + [[<<2::24, 1::8, flags1::8, _::binary>>, _] | continuation_frames] = result + assert flags1 == 0x0 + + # Middle CONTINUATION: no END_HEADERS + middle_frames = Enum.slice(continuation_frames, 0..-2//1) + + for [<<_::24, 9::8, flags::8, _::binary>>, _] <- middle_frames do + assert flags == 0x0 + end + + # Last CONTINUATION: END_HEADERS (0x4) + [[<<_::24, 9::8, last_flags::8, _::binary>>, _]] = + Enum.slice(continuation_frames, -1..-1//1) + + assert last_flags == 0x4 + end + + test "preserves END_STREAM flag when splitting" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: true, + end_headers: true, + fragment: "abcde" + } + + result = Frame.serialize(frame, 2) + + # First HEADERS should have END_STREAM (0x1) but not END_HEADERS + [[<<2::24, 1::8, 0x1::8, _::binary>>, _] | _] = result + end + end + + describe "gRPC-specific scenarios" do + test "handles gRPC pseudo-headers" do + # gRPC uses HTTP/2 pseudo-headers: :method, :scheme, :path, :authority + # These would be HPACK encoded, but we test with raw bytes + grpc_headers = ":method: POST\n:path: /my.Service/Method" + frame = %Frame.Headers{stream_id: 1, end_headers: true, fragment: grpc_headers} + + result = Frame.serialize(frame, 16_384) + assert is_list(result) + end + + test "handles gRPC metadata headers" do + # gRPC metadata: custom headers, timeout, compression + metadata = "grpc-timeout: 1S\ngrpc-encoding: gzip\nx-custom: value" + frame = %Frame.Headers{stream_id: 1, end_headers: true, fragment: metadata} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<_header::9-bytes, payload::binary>> = binary + assert payload == metadata + end + + test "handles trailers-only response (no DATA frames)" do + # gRPC can send trailers-only response for immediate errors + # HEADERS frame with both END_STREAM and END_HEADERS + trailers = "grpc-status: 0\ngrpc-message: OK" + + frame = %Frame.Headers{ + stream_id: 1, + end_stream: true, + end_headers: true, + fragment: trailers + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Should have both END_STREAM (0x1) and END_HEADERS (0x4) = 0x5 + <<_length::24, 1::8, 0x5::8, _::binary>> = binary + end + + test "handles large metadata requiring continuation" do + # Large custom metadata that exceeds max_frame_size + large_metadata = String.duplicate("x-custom-#{:rand.uniform(1000)}: value\n", 100) + + frame = %Frame.Headers{ + stream_id: 1, + end_headers: true, + fragment: large_metadata + } + + result = Frame.serialize(frame, 100) + + # Should split into multiple frames + assert length(result) > 1 + + # First should be HEADERS + [[<<_::24, 1::8, _::8, _::binary>>, _] | continuation] = result + + # Rest should be CONTINUATION (type 9) + for [<<_::24, 9::8, _::8, _::binary>>, _] <- continuation do + assert true + end + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/ping_test.exs b/grpc_core/test/grpc/transport/http2/frame/ping_test.exs new file mode 100644 index 000000000..5c3acc462 --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/ping_test.exs @@ -0,0 +1,204 @@ +defmodule GRPC.Transport.HTTP2.Frame.PingTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "PING frame deserialization" do + test "deserializes PING frame" do + # PING with 8 bytes of opaque data + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + data = <<8::24, 6::8, 0x0::8, 0::1, 0::31, payload::binary>> + + assert {{:ok, %Frame.Ping{ack: false, payload: ^payload}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes PING ACK frame" do + # PING with ACK flag (0x1) + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + data = <<8::24, 6::8, 0x1::8, 0::1, 0::31, payload::binary>> + + assert {{:ok, %Frame.Ping{ack: true, payload: ^payload}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects PING frame with non-zero stream_id" do + # RFC 9113 §6.7: PING frames MUST be associated with stream 0 + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + data = <<8::24, 6::8, 0x0::8, 0::1, 1::31, payload::binary>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.protocol_error() + end + + test "rejects PING frame with incorrect length" do + # RFC 9113 §6.7: PING frames MUST be exactly 8 bytes + payload = <<1, 2, 3, 4>> + data = <<4::24, 6::8, 0x0::8, 0::1, 0::31, payload::binary>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.frame_size_error() + end + + test "rejects PING frame with length too large" do + # Must be exactly 8 bytes, not more + payload = <<1, 2, 3, 4, 5, 6, 7, 8, 9, 10>> + data = <<10::24, 6::8, 0x0::8, 0::1, 0::31, payload::binary>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.frame_size_error() + end + + test "handles PING with all zeros" do + payload = <<0, 0, 0, 0, 0, 0, 0, 0>> + data = <<8::24, 6::8, 0x0::8, 0::1, 0::31, payload::binary>> + + assert {{:ok, %Frame.Ping{ack: false, payload: ^payload}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "handles PING with all ones" do + payload = <<255, 255, 255, 255, 255, 255, 255, 255>> + data = <<8::24, 6::8, 0x0::8, 0::1, 0::31, payload::binary>> + + assert {{:ok, %Frame.Ping{ack: false, payload: ^payload}}, <<>>} = + Frame.deserialize(data, 16_384) + end + end + + describe "PING frame serialization" do + test "serializes PING frame" do + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + frame = %Frame.Ping{ack: false, payload: payload} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<8::24, 6::8, 0x0::8, 0::1, 0::31, ^payload::binary>> = binary + end + + test "serializes PING ACK frame" do + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + frame = %Frame.Ping{ack: true, payload: payload} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<8::24, 6::8, 0x1::8, 0::1, 0::31, ^payload::binary>> = binary + end + + test "preserves payload exactly" do + # Ensure payload is not modified + payload = <<0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE>> + frame = %Frame.Ping{ack: false, payload: payload} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<_::9-bytes, received_payload::8-bytes>> = binary + assert received_payload == payload + end + end + + describe "PING round-trip" do + test "PING request and ACK response match payload" do + # Client sends PING + client_payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + ping = %Frame.Ping{ack: false, payload: client_payload} + + # Server responds with PING ACK containing same payload + pong = %Frame.Ping{ack: true, payload: client_payload} + + ping_binary = IO.iodata_to_binary(Frame.serialize(ping, 16_384)) + pong_binary = IO.iodata_to_binary(Frame.serialize(pong, 16_384)) + + # Deserialize both + {{:ok, ping_frame}, <<>>} = Frame.deserialize(ping_binary, 16_384) + {{:ok, pong_frame}, <<>>} = Frame.deserialize(pong_binary, 16_384) + + # Verify payload matches + assert IO.iodata_to_binary(ping_frame.payload) == IO.iodata_to_binary(pong_frame.payload) + assert ping_frame.ack == false + assert pong_frame.ack == true + end + end + + describe "gRPC-specific scenarios" do + test "handles keepalive PING" do + # gRPC uses PING for connection keepalive + # Typically uses timestamp or counter as payload + timestamp = System.system_time(:millisecond) + payload = <> + + ping = %Frame.Ping{ack: false, payload: payload} + pong = %Frame.Ping{ack: true, payload: payload} + + ping_frame = Frame.serialize(ping, 16_384) + pong_frame = Frame.serialize(pong, 16_384) + + assert is_list(ping_frame) + assert is_list(pong_frame) + end + + test "handles latency measurement" do + # Can use PING to measure RTT + # Use a positive timestamp value + timestamp = System.system_time(:millisecond) + payload = <> + + # Send PING + ping = %Frame.Ping{ack: false, payload: payload} + _ping_binary = IO.iodata_to_binary(Frame.serialize(ping, 16_384)) + + # Receive PING ACK + {{:ok, pong}, <<>>} = + Frame.deserialize(<<8::24, 6::8, 0x1::8, 0::1, 0::31, payload::binary>>, 16_384) + + # Calculate RTT (in real scenario, would have network delay) + <> = IO.iodata_to_binary(pong.payload) + assert received_time == timestamp + end + + test "handles connection health check" do + # gRPC clients periodically send PING to check connection + health_check = %Frame.Ping{ + ack: false, + payload: <<"HEALTH", 0::16>> + } + + result = Frame.serialize(health_check, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert IO.iodata_to_binary(received.payload) == <<"HEALTH", 0::16>> + end + + test "handles PING flood protection scenario" do + # In production, need to rate-limit PING frames + # Test multiple PINGs with different payload + pings = + for i <- 1..10 do + %Frame.Ping{ack: false, payload: <>} + end + + serialized = Enum.map(pings, &Frame.serialize(&1, 16_384)) + + assert length(serialized) == 10 + assert Enum.all?(serialized, &is_list/1) + end + + test "handles PING timeout scenario" do + # Send PING, simulate no response (timeout detection) + ping = %Frame.Ping{ + ack: false, + payload: <> + } + + ping_binary = IO.iodata_to_binary(Frame.serialize(ping, 16_384)) + + # In real implementation, would start timer and close connection if no ACK + assert byte_size(ping_binary) == 17 + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/rst_stream_test.exs b/grpc_core/test/grpc/transport/http2/frame/rst_stream_test.exs new file mode 100644 index 000000000..cf40580bb --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/rst_stream_test.exs @@ -0,0 +1,217 @@ +defmodule GRPC.Transport.HTTP2.Frame.RstStreamTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "RST_STREAM frame deserialization" do + test "deserializes RST_STREAM with NO_ERROR" do + # RST_STREAM with error code NO_ERROR (0x0) + data = <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x0::32>> + + assert {{:ok, %Frame.RstStream{stream_id: 123, error_code: 0x0}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes RST_STREAM with PROTOCOL_ERROR" do + data = <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x1::32>> + + assert {{:ok, %Frame.RstStream{stream_id: 123, error_code: 0x1}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes RST_STREAM with INTERNAL_ERROR" do + data = <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x2::32>> + + assert {{:ok, %Frame.RstStream{stream_id: 123, error_code: 0x2}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes RST_STREAM with FLOW_CONTROL_ERROR" do + data = <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x3::32>> + + assert {{:ok, %Frame.RstStream{stream_id: 123, error_code: 0x3}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes RST_STREAM with CANCEL" do + data = <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x8::32>> + + assert {{:ok, %Frame.RstStream{stream_id: 123, error_code: 0x8}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects RST_STREAM with stream_id 0" do + # RFC 9113 §6.4: RST_STREAM frames MUST be associated with a stream + data = <<4::24, 3::8, 0x0::8, 0::1, 0::31, 0x8::32>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.protocol_error() + end + + test "rejects RST_STREAM with incorrect length" do + # RFC 9113 §6.4: RST_STREAM frames MUST be 4 bytes + data = <<2::24, 3::8, 0x0::8, 0::1, 123::31, 0x8::16>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.frame_size_error() + end + + test "handles RST_STREAM with unknown error code" do + # Unknown error codes should still be accepted + data = <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0xFF::32>> + + assert {{:ok, %Frame.RstStream{stream_id: 123, error_code: 0xFF}}, <<>>} = + Frame.deserialize(data, 16_384) + end + end + + describe "RST_STREAM frame serialization" do + test "serializes RST_STREAM with NO_ERROR" do + frame = %Frame.RstStream{stream_id: 123, error_code: 0x0} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x0::32>> = binary + end + + test "serializes RST_STREAM with CANCEL" do + frame = %Frame.RstStream{stream_id: 123, error_code: 0x8} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 3::8, 0x0::8, 0::1, 123::31, 0x8::32>> = binary + end + + test "serializes RST_STREAM with INTERNAL_ERROR" do + frame = %Frame.RstStream{stream_id: 456, error_code: 0x2} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 3::8, 0x0::8, 0::1, 456::31, 0x2::32>> = binary + end + end + + describe "gRPC-specific scenarios" do + test "handles client cancellation" do + # Client cancels RPC by sending RST_STREAM with CANCEL (0x8) + cancel = %Frame.RstStream{stream_id: 1, error_code: Errors.cancel()} + + result = Frame.serialize(cancel, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == Errors.cancel() + end + + test "handles server rejecting stream" do + # Server rejects stream due to overload with REFUSED_STREAM (0x7) + reject = %Frame.RstStream{stream_id: 1, error_code: Errors.refused_stream()} + + result = Frame.serialize(reject, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.error_code == Errors.refused_stream() + end + + test "handles flow control violation" do + # RST_STREAM sent when peer violates flow control + flow_error = %Frame.RstStream{ + stream_id: 5, + error_code: Errors.flow_control_error() + } + + result = Frame.serialize(flow_error, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 3::8, 0x0::8, 0::1, 5::31, _error::32>> = binary + end + + test "handles stream timeout" do + # Application-level timeout can trigger RST_STREAM with CANCEL + timeout = %Frame.RstStream{ + stream_id: 10, + error_code: Errors.cancel() + } + + result = Frame.serialize(timeout, 16_384) + assert is_list(result) + end + + test "handles protocol violation on stream" do + # Protocol error specific to a stream + protocol_err = %Frame.RstStream{ + stream_id: 15, + error_code: Errors.protocol_error() + } + + result = Frame.serialize(protocol_err, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.stream_id == 15 + assert received.error_code == Errors.protocol_error() + end + + test "handles concurrent stream resets" do + # Multiple streams can be reset independently + resets = + for stream_id <- 1..10 do + %Frame.RstStream{stream_id: stream_id, error_code: Errors.cancel()} + end + + serialized = Enum.map(resets, &Frame.serialize(&1, 16_384)) + + assert length(serialized) == 10 + assert Enum.all?(serialized, &is_list/1) + end + + test "handles RST_STREAM after partial message" do + # Stream reset while message is being transmitted + rst = %Frame.RstStream{ + stream_id: 3, + error_code: Errors.internal_error() + } + + result = Frame.serialize(rst, 16_384) + binary = IO.iodata_to_binary(result) + + # Verify frame structure + assert <<4::24, 3::8, 0x0::8, 0::1, 3::31, _::32>> = binary + end + + test "handles RST_STREAM for idle stream" do + # Receiving RST_STREAM for stream that was never opened + # Implementation should handle gracefully + rst = %Frame.RstStream{stream_id: 999, error_code: 0x1} + + result = Frame.serialize(rst, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.stream_id == 999 + end + + test "error code mapping to gRPC status" do + # Different HTTP/2 errors map to different gRPC status codes + error_codes = [ + {Errors.no_error(), "NO_ERROR - clean shutdown"}, + {Errors.protocol_error(), "PROTOCOL_ERROR - invalid protocol state"}, + {Errors.internal_error(), "INTERNAL_ERROR - server error"}, + {Errors.flow_control_error(), "FLOW_CONTROL_ERROR - window violated"}, + {Errors.cancel(), "CANCEL - client cancellation"}, + {Errors.refused_stream(), "REFUSED_STREAM - server overload"} + ] + + for {error_code, _description} <- error_codes do + frame = %Frame.RstStream{stream_id: 1, error_code: error_code} + result = Frame.serialize(frame, 16_384) + assert is_list(result) + end + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/settings_test.exs b/grpc_core/test/grpc/transport/http2/frame/settings_test.exs new file mode 100644 index 000000000..4558ca3a1 --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/settings_test.exs @@ -0,0 +1,189 @@ +defmodule GRPC.Transport.HTTP2.Frame.SettingsTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + + describe "SETTINGS frame deserialization" do + test "deserializes empty SETTINGS frame" do + # Empty SETTINGS frame (connection preface) + data = <<0::24, 4::8, 0x0::8, 0::1, 0::31>> + + assert {{:ok, %Frame.Settings{ack: false, settings: settings}}, <<>>} = + Frame.deserialize(data, 16_384) + + assert settings == %{} + end + + test "deserializes SETTINGS ACK frame" do + # SETTINGS frame with ACK flag (0x1) + data = <<0::24, 4::8, 0x1::8, 0::1, 0::31>> + + assert {{:ok, %Frame.Settings{ack: true, settings: nil}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes SETTINGS frame with max_concurrent_streams" do + # SETTINGS_MAX_CONCURRENT_STREAMS (0x3) = 100 + settings_payload = <<0x3::16, 100::32>> + + data = + <> + + assert {{:ok, %Frame.Settings{settings: %{max_concurrent_streams: 100}}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes SETTINGS frame with multiple settings" do + settings_payload = << + 0x1::16, + 8192::32, + # SETTINGS_HEADER_TABLE_SIZE + 0x3::16, + 100::32, + # SETTINGS_MAX_CONCURRENT_STREAMS + 0x4::16, + 32768::32 + # SETTINGS_INITIAL_WINDOW_SIZE + >> + + data = + <> + + assert {{:ok, + %Frame.Settings{ + settings: %{ + header_table_size: 8192, + max_concurrent_streams: 100, + initial_window_size: 32768 + } + }}, <<>>} = Frame.deserialize(data, 16_384) + end + end + + describe "SETTINGS frame serialization" do + test "serializes empty SETTINGS frame" do + frame = %Frame.Settings{ack: false, settings: %{}} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<0::24, 4::8, 0x0::8, 0::1, 0::31>> = binary + end + + test "serializes SETTINGS ACK frame" do + frame = %Frame.Settings{ack: true, settings: nil} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<0::24, 4::8, 0x1::8, 0::1, 0::31>> = binary + end + + test "serializes SETTINGS frame with max_concurrent_streams" do + frame = %Frame.Settings{ack: false, settings: %{max_concurrent_streams: 100}} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<6::24, 4::8, 0x0::8, 0::1, 0::31, 0x3::16, 100::32>> = binary + end + + test "serializes SETTINGS frame omitting default values" do + # Default values should be omitted + frame = %Frame.Settings{ + ack: false, + settings: %{ + header_table_size: 4096, + # default + max_concurrent_streams: 200, + # non-default + initial_window_size: 65535, + # default + max_frame_size: 16384 + # default + } + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Should only include max_concurrent_streams (non-default) + assert <<6::24, 4::8, 0x0::8, 0::1, 0::31, 0x3::16, 200::32>> = binary + end + end + + describe "gRPC-specific scenarios" do + test "handles gRPC recommended settings" do + # gRPC typically uses specific settings + frame = %Frame.Settings{ + settings: %{ + header_table_size: 8192, + max_concurrent_streams: 100, + initial_window_size: 1_048_576, + max_frame_size: 16384 + } + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Verify it's a valid SETTINGS frame (header_table_size, max_concurrent_streams, initial_window_size) + # max_frame_size with default value is omitted + <> = binary + assert length == 18 + end + + test "handles connection preface SETTINGS" do + # Client/server exchange SETTINGS during connection preface + settings = %Frame.Settings{ + settings: %{ + max_concurrent_streams: 100, + initial_window_size: 65535 + } + } + + result = Frame.serialize(settings, 16_384) + assert is_list(result) + end + + test "handles SETTINGS ACK response" do + # After receiving SETTINGS, peer must send SETTINGS ACK + ack = %Frame.Settings{ack: true, settings: nil} + + result = Frame.serialize(ack, 16_384) + binary = IO.iodata_to_binary(result) + + # Empty payload with ACK flag + assert <<0::24, 4::8, 0x1::8, 0::1, 0::31>> = binary + end + + test "handles window size updates via SETTINGS" do + # Changing SETTINGS_INITIAL_WINDOW_SIZE affects flow control + frame = %Frame.Settings{ + settings: %{ + initial_window_size: 1_048_576 + } + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + <<6::24, 4::8, 0x0::8, 0::1, 0::31, 0x4::16, 1_048_576::32>> = binary + end + + test "handles unlimited max_concurrent_streams" do + # :infinity means no limit on concurrent streams + frame = %Frame.Settings{ + settings: %{ + max_concurrent_streams: :infinity + } + } + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # :infinity is omitted (no setting sent) + assert <<0::24, 4::8, 0x0::8, 0::1, 0::31>> = binary + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame/window_update_test.exs b/grpc_core/test/grpc/transport/http2/frame/window_update_test.exs new file mode 100644 index 000000000..0e145b0f7 --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame/window_update_test.exs @@ -0,0 +1,285 @@ +defmodule GRPC.Transport.HTTP2.Frame.WindowUpdateTest do + use ExUnit.Case, async: true + import Bitwise + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "WINDOW_UPDATE frame deserialization" do + test "deserializes WINDOW_UPDATE for stream" do + # WINDOW_UPDATE for stream 123, increment 1000 + data = <<4::24, 8::8, 0x0::8, 0::1, 123::31, 0::1, 1000::31>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 123, size_increment: 1000}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes WINDOW_UPDATE for connection (stream 0)" do + # WINDOW_UPDATE for connection-level flow control + data = <<4::24, 8::8, 0x0::8, 0::1, 0::31, 0::1, 65535::31>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 0, size_increment: 65535}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes WINDOW_UPDATE with maximum increment" do + # Maximum increment: 2^31 - 1 + max_increment = (1 <<< 31) - 1 + data = <<4::24, 8::8, 0x0::8, 0::1, 1::31, 0::1, max_increment::31>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 1, size_increment: ^max_increment}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "rejects WINDOW_UPDATE with zero increment" do + # RFC 9113 §6.9: A receiver MUST treat increment of 0 as error + data = <<4::24, 8::8, 0x0::8, 0::1, 123::31, 0::1, 0::31>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.frame_size_error() + end + + test "rejects WINDOW_UPDATE with incorrect length" do + # RFC 9113 §6.9: WINDOW_UPDATE frames MUST be 4 bytes + data = <<2::24, 8::8, 0x0::8, 0::1, 123::31, 100::16>> + + assert {{:error, error_code, _reason}, <<>>} = Frame.deserialize(data, 16_384) + assert error_code == Errors.frame_size_error() + end + + test "handles reserved bit correctly" do + # Reserved bit (first bit) should be ignored + data = <<4::24, 8::8, 0x0::8, 0::1, 123::31, 1::1, 1000::31>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 123, size_increment: 1000}}, <<>>} = + Frame.deserialize(data, 16_384) + end + end + + describe "WINDOW_UPDATE frame serialization" do + test "serializes WINDOW_UPDATE for stream" do + frame = %Frame.WindowUpdate{stream_id: 123, size_increment: 1000} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 8::8, 0x0::8, 0::1, 123::31, 0::1, 1000::31>> = binary + end + + test "serializes WINDOW_UPDATE for connection" do + frame = %Frame.WindowUpdate{stream_id: 0, size_increment: 65535} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 8::8, 0x0::8, 0::1, 0::31, 0::1, 65535::31>> = binary + end + + test "serializes WINDOW_UPDATE with maximum increment" do + max_increment = (1 <<< 31) - 1 + frame = %Frame.WindowUpdate{stream_id: 1, size_increment: max_increment} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 8::8, 0x0::8, 0::1, 1::31, 0::1, ^max_increment::31>> = binary + end + + test "sets reserved bit to 0" do + frame = %Frame.WindowUpdate{stream_id: 123, size_increment: 1000} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Extract window size increment field and check reserved bit + <<_::9-bytes, reserved::1, increment::31>> = binary + assert reserved == 0 + assert increment == 1000 + end + end + + describe "gRPC-specific scenarios" do + test "handles stream-level window update after consuming data" do + # After processing DATA frame, update stream window + consumed = 8192 + window_update = %Frame.WindowUpdate{stream_id: 1, size_increment: consumed} + + result = Frame.serialize(window_update, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.size_increment == consumed + end + + test "handles connection-level window update" do + # Update connection-level window after processing multiple streams + connection_update = %Frame.WindowUpdate{ + stream_id: 0, + size_increment: 1_048_576 + } + + result = Frame.serialize(connection_update, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.stream_id == 0 + assert received.size_increment == 1_048_576 + end + + test "handles window exhaustion prevention" do + # Send window update before window is fully exhausted + # Keep at least 50% window available + initial_window = 65535 + consumed = div(initial_window, 2) + + update = %Frame.WindowUpdate{stream_id: 1, size_increment: consumed} + + result = Frame.serialize(update, 16_384) + assert is_list(result) + end + + test "handles large message flow control" do + # For large gRPC messages, need multiple window updates + chunk_size = 16384 + + updates = + for _i <- 1..10 do + %Frame.WindowUpdate{stream_id: 1, size_increment: chunk_size} + end + + serialized = Enum.map(updates, &Frame.serialize(&1, 16_384)) + + assert length(serialized) == 10 + assert Enum.all?(serialized, &is_list/1) + end + + test "handles window update for streaming RPCs" do + # Bidirectional streaming needs careful window management + # Client updates window as it consumes server data + client_update = %Frame.WindowUpdate{stream_id: 1, size_increment: 32768} + + # Server updates window as it consumes client data + server_update = %Frame.WindowUpdate{stream_id: 1, size_increment: 32768} + + client_binary = IO.iodata_to_binary(Frame.serialize(client_update, 16_384)) + server_binary = IO.iodata_to_binary(Frame.serialize(server_update, 16_384)) + + assert client_binary == server_binary + end + + test "handles window overflow detection" do + # Receiving window update that would overflow window size + # Current window: 65535, increment: max_value would exceed 2^31-1 + # Implementation should detect this as flow control error + + # Send legitimate max increment + max_increment = (1 <<< 31) - 1 + frame = %Frame.WindowUpdate{stream_id: 1, size_increment: max_increment} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.size_increment == max_increment + end + + test "handles immediate window update strategy" do + # Eagerly update window after every DATA frame + data_frame_size = 8192 + + update = %Frame.WindowUpdate{ + stream_id: 1, + size_increment: data_frame_size + } + + result = Frame.serialize(update, 16_384) + assert is_list(result) + end + + test "handles batched window updates" do + # Buffer multiple DATA frames, then send one window update + total_consumed = 8192 + 8192 + 4096 + + update = %Frame.WindowUpdate{ + stream_id: 1, + size_increment: total_consumed + } + + result = Frame.serialize(update, 16_384) + binary = IO.iodata_to_binary(result) + + {{:ok, received}, <<>>} = Frame.deserialize(binary, 16_384) + assert received.size_increment == total_consumed + end + + test "handles connection vs stream window priority" do + # Both connection and stream windows must be available + # Update both when consuming data + + connection_update = %Frame.WindowUpdate{stream_id: 0, size_increment: 16384} + stream_update = %Frame.WindowUpdate{stream_id: 1, size_increment: 16384} + + conn_binary = IO.iodata_to_binary(Frame.serialize(connection_update, 16_384)) + stream_binary = IO.iodata_to_binary(Frame.serialize(stream_update, 16_384)) + + {{:ok, conn}, <<>>} = Frame.deserialize(conn_binary, 16_384) + {{:ok, stream}, <<>>} = Frame.deserialize(stream_binary, 16_384) + + assert conn.stream_id == 0 + assert stream.stream_id == 1 + assert conn.size_increment == stream.size_increment + end + + test "handles window update timing for backpressure" do + # Delay window update to apply backpressure on sender + # Send smaller increments to slow down data flow + throttled_increment = 4096 + + update = %Frame.WindowUpdate{ + stream_id: 1, + size_increment: throttled_increment + } + + result = Frame.serialize(update, 16_384) + binary = IO.iodata_to_binary(result) + + assert <<4::24, 8::8, 0x0::8, 0::1, 1::31, 0::1, ^throttled_increment::31>> = binary + end + end + + describe "flow control scenarios" do + test "handles window depletion and replenishment" do + # Start with initial window: 65535 + # Consume data in chunks, replenish periodically + + chunks = [16384, 16384, 16384, 16383] + + updates = + for {chunk, idx} <- Enum.with_index(chunks, 1) do + %Frame.WindowUpdate{stream_id: idx, size_increment: chunk} + end + + total_increment = Enum.sum(chunks) + assert total_increment == 65535 + + serialized = Enum.map(updates, &Frame.serialize(&1, 16_384)) + assert length(serialized) == 4 + end + + test "handles concurrent stream window updates" do + # Multiple active streams, each with independent windows + stream_updates = + for stream_id <- 1..10 do + %Frame.WindowUpdate{stream_id: stream_id, size_increment: 8192} + end + + # Plus connection-level update + connection_update = %Frame.WindowUpdate{stream_id: 0, size_increment: 81920} + + all_updates = [connection_update | stream_updates] + serialized = Enum.map(all_updates, &Frame.serialize(&1, 16_384)) + + assert length(serialized) == 11 + end + end +end diff --git a/grpc_core/test/grpc/transport/http2/frame_test.exs b/grpc_core/test/grpc/transport/http2/frame_test.exs new file mode 100644 index 000000000..094b3a9fe --- /dev/null +++ b/grpc_core/test/grpc/transport/http2/frame_test.exs @@ -0,0 +1,280 @@ +defmodule GRPC.Transport.HTTP2.FrameTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + + describe "deserialize/2" do + test "deserializes DATA frame (type 0x0)" do + # Frame: length=3, type=0, flags=0, stream_id=1, payload="abc" + data = <<3::24, 0::8, 0::8, 0::1, 1::31, "abc">> + + assert {{:ok, %Frame.Data{stream_id: 1, data: "abc"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes HEADERS frame (type 0x1)" do + # Frame: length=3, type=1, flags=4 (END_HEADERS), stream_id=1 + data = <<3::24, 1::8, 4::8, 0::1, 1::31, "hdr">> + + assert {{:ok, %Frame.Headers{stream_id: 1, end_headers: true, fragment: "hdr"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes SETTINGS frame (type 0x4)" do + # Frame: length=0, type=4, flags=0, stream_id=0 (empty SETTINGS) + data = <<0::24, 4::8, 0::8, 0::1, 0::31>> + + assert {{:ok, %Frame.Settings{ack: false}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes PING frame (type 0x6)" do + # Frame: length=8, type=6, flags=0, stream_id=0, payload=8 bytes + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + data = <<8::24, 6::8, 0::8, 0::1, 0::31, payload::binary>> + + assert {{:ok, %Frame.Ping{ack: false, payload: ^payload}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes GOAWAY frame (type 0x7)" do + # Frame: length=8, type=7, flags=0, stream_id=0 + data = <<8::24, 7::8, 0::8, 0::1, 0::31, 0::1, 3::31, 0::32>> + + assert {{:ok, %Frame.Goaway{last_stream_id: 3, error_code: 0}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes WINDOW_UPDATE frame (type 0x8)" do + # Frame: length=4, type=8, flags=0, stream_id=1 + data = <<4::24, 8::8, 0::8, 0::1, 1::31, 0::1, 1000::31>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 1, size_increment: 1000}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "deserializes unknown frame type" do + # Frame: length=3, type=99 (unknown), flags=0, stream_id=1 + data = <<3::24, 99::8, 0::8, 0::1, 1::31, "xyz">> + + assert {{:ok, %Frame.Unknown{type: 99, stream_id: 1, payload: "xyz"}}, <<>>} = + Frame.deserialize(data, 16_384) + end + + test "returns error when payload exceeds max_frame_size" do + # Frame with length=17000 but max_frame_size=16384 + data = <<17000::24, 0::8, 0::8, 0::1, 1::31, "data">> + + assert {{:error, error_code, "Payload size too large (RFC9113§4.2)"}, _rest} = + Frame.deserialize(data, 16_384) + + assert error_code == Errors.frame_size_error() + end + + test "returns {:more, buffer} when frame is incomplete" do + # Partial frame (only header, no payload) + data = <<10::24, 0::8, 0::8, 0::1, 1::31>> + + assert {{:more, ^data}, <<>>} = Frame.deserialize(data, 16_384) + end + + test "returns nil for empty buffer" do + assert Frame.deserialize(<<>>, 16_384) == nil + end + + test "handles multiple frames in buffer" do + # Two DATA frames back-to-back + frame1 = <<3::24, 0::8, 0::8, 0::1, 1::31, "abc">> + frame2 = <<3::24, 0::8, 0::8, 0::1, 2::31, "def">> + data = frame1 <> frame2 + + assert {{:ok, %Frame.Data{stream_id: 1, data: "abc"}}, rest} = + Frame.deserialize(data, 16_384) + + assert {{:ok, %Frame.Data{stream_id: 2, data: "def"}}, <<>>} = + Frame.deserialize(rest, 16_384) + end + + test "preserves remaining data after deserialization" do + frame = <<3::24, 0::8, 0::8, 0::1, 1::31, "abc">> + extra = <<1, 2, 3, 4, 5>> + data = frame <> extra + + assert {{:ok, %Frame.Data{}}, ^extra} = Frame.deserialize(data, 16_384) + end + end + + describe "serialize/2" do + test "serializes DATA frame" do + frame = %Frame.Data{stream_id: 1, end_stream: false, data: "hello"} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Check frame header: length=5, type=0, flags=0, stream_id=1 + <> = binary + + assert length == 5 + assert type == 0 + assert stream_id == 1 + assert payload == "hello" + end + + test "serializes SETTINGS frame" do + frame = %Frame.Settings{ack: false, settings: []} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Check frame header: length=0, type=4, stream_id=0 + <> = binary + + assert length == 0 + assert type == 4 + assert stream_id == 0 + end + + test "serializes PING frame" do + payload = <<1, 2, 3, 4, 5, 6, 7, 8>> + frame = %Frame.Ping{ack: false, payload: payload} + + result = Frame.serialize(frame, 16_384) + binary = IO.iodata_to_binary(result) + + # Check frame header: length=8, type=6, stream_id=0 + <> = binary + + assert length == 8 + assert type == 6 + assert stream_id == 0 + assert data == payload + end + + test "handles iodata efficiently" do + # Frame.serialize should return iodata (list), not binary + frame = %Frame.Data{stream_id: 1, data: "test"} + + result = Frame.serialize(frame, 16_384) + + # Result should be iodata (list) + assert is_list(result) + + # But should convert to valid binary + binary = IO.iodata_to_binary(result) + assert is_binary(binary) + end + end + + describe "Frame.Flags" do + test "set/1 returns 0 for empty list" do + assert Frame.Flags.set([]) == 0x0 + end + + test "set/1 sets single bit" do + assert Frame.Flags.set([0]) == 0b00000001 + assert Frame.Flags.set([1]) == 0b00000010 + assert Frame.Flags.set([2]) == 0b00000100 + assert Frame.Flags.set([7]) == 0b10000000 + end + + test "set/1 sets multiple bits" do + assert Frame.Flags.set([0, 2]) == 0b00000101 + assert Frame.Flags.set([0, 1, 2, 3]) == 0b00001111 + end + + test "set?/2 guard works correctly" do + require Frame.Flags + + # bits 0 and 2 set + flags = 0b00000101 + + assert Frame.Flags.set?(flags, 0) + refute Frame.Flags.set?(flags, 1) + assert Frame.Flags.set?(flags, 2) + refute Frame.Flags.set?(flags, 3) + end + + test "not set?/2 works correctly with guard" do + require Frame.Flags + + # bits 0 and 2 set + flags = 0b00000101 + + refute not Frame.Flags.set?(flags, 0) + assert not Frame.Flags.set?(flags, 1) + refute not Frame.Flags.set?(flags, 2) + assert not Frame.Flags.set?(flags, 3) + end + end + + describe "RFC 9113 compliance" do + test "default max_frame_size is 16,384 bytes" do + # RFC 9113 §4.2 + max_frame_size = 16_384 + + # Should accept frame at max size + data = <> + assert {{:ok, _frame}, <<>>} = Frame.deserialize(data, max_frame_size) + end + + test "reserved bit must be 0" do + # RFC 9113 §4.1 - Reserved bit (R) must be unset + # Frame with reserved bit = 0 + data = <<3::24, 0::8, 0::8, 0::1, 1::31, "abc">> + assert {{:ok, _frame}, <<>>} = Frame.deserialize(data, 16_384) + end + + test "frame types 0-9 are defined" do + # RFC 9113 §6 - Frame types + # DATA=0, HEADERS=1, PRIORITY=2, RST_STREAM=3, SETTINGS=4, + # PUSH_PROMISE=5, PING=6, GOAWAY=7, WINDOW_UPDATE=8, CONTINUATION=9 + + # Type 0 (DATA) requires stream_id > 0, so we test with stream_id=1 + data = <<0::24, 0::8, 0::8, 0::1, 1::31>> + assert {{:ok, frame}, _} = Frame.deserialize(data, 16_384) + refute match?(%Frame.Unknown{}, frame) + + # Types 1-9 can be tested with stream_id=0 (except some need valid payload) + for type <- 1..9 do + # Use appropriate payloads for each type + {stream_id, payload} = + case type do + # HEADERS needs priority data + 1 -> {1, <<0, 0, 0, 0>>} + # PRIORITY needs 5 bytes + 2 -> {1, <<0, 0, 0, 0, 0>>} + # RST_STREAM needs 4 bytes + 3 -> {1, <<0, 0, 0, 0>>} + # SETTINGS + 4 -> {0, <<>>} + # PUSH_PROMISE needs promised stream id + 5 -> {1, <<0, 0, 0, 0>>} + # PING needs 8 bytes + 6 -> {0, <<0, 0, 0, 0, 0, 0, 0, 0>>} + # GOAWAY needs 8 bytes + 7 -> {0, <<0, 0, 0, 0, 0, 0, 0, 0>>} + # WINDOW_UPDATE needs 4 bytes + 8 -> {1, <<0, 0, 0, 100>>} + # CONTINUATION + 9 -> {1, <<>>} + end + + length = byte_size(payload) + data = <> + result = Frame.deserialize(data, 16_384) + + # Should not return Unknown frame for types 0-9 + assert {{:ok, frame}, _} = result + refute match?(%Frame.Unknown{}, frame), "Type #{type} returned Unknown frame" + end + end + + test "unknown frame types are ignored gracefully" do + # RFC 9113 §4.1 - Implementations MUST ignore unknown frame types + data = <<3::24, 255::8, 0::8, 0::1, 1::31, "xyz">> + + assert {{:ok, %Frame.Unknown{type: 255}}, <<>>} = Frame.deserialize(data, 16_384) + end + end +end diff --git a/grpc_server/lib/grpc/server.ex b/grpc_server/lib/grpc/server.ex index 1e48c9261..57b2370d2 100644 --- a/grpc_server/lib/grpc/server.ex +++ b/grpc_server/lib/grpc/server.ex @@ -512,7 +512,7 @@ defmodule GRPC.Server do end @doc false - @spec servers_to_map(module() | [module()]) :: %{String.t() => [module()]} + @spec servers_to_map(module() | [module()]) :: %{String.t() => module()} def servers_to_map(servers) do Enum.reduce(List.wrap(servers), %{}, fn s, acc -> Map.put(acc, s.__meta__(:service).__meta__(:name), s) diff --git a/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex b/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex index bc596772d..8255fb2f3 100644 --- a/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex +++ b/grpc_server/lib/grpc/server/adapters/cowboy/handler.ex @@ -450,7 +450,7 @@ defmodule GRPC.Server.Adapters.Cowboy.Handler do {:stop, req, state} else - case GRPC.Message.to_data(data, compressor: compressor, codec: opts[:codec]) do + case GRPC.Message.to_data(data, compressor: compressor, codec: opts[:codec], iolist: true) do {:ok, data, _size} -> req = check_sent_resp(req) :cowboy_req.stream_body(data, is_fin, req) diff --git a/grpc_server/lib/grpc/server/adapters/thousand_island.ex b/grpc_server/lib/grpc/server/adapters/thousand_island.ex new file mode 100644 index 000000000..053a44588 --- /dev/null +++ b/grpc_server/lib/grpc/server/adapters/thousand_island.ex @@ -0,0 +1,642 @@ +defmodule GRPC.Server.Adapters.ThousandIsland do + @moduledoc """ + A server (`GRPC.Server.Adapter`) adapter using `:thousand_island`. + + ThousandIsland is a modern, pure Elixir socket server that provides: + - Built-in connection pooling + - Efficient resource management + - Better integration with Elixir/OTP ecosystem + - Simpler architecture than Cowboy/Ranch + + ## Advantages over Cowboy + + - **Built-in pooling**: Native connection pool management + - **Lower overhead**: Simpler architecture, fewer layers + - **Modern design**: Built with current Elixir best practices + - **Telemetry integration**: First-class observability support + + ## Architecture & Process Model + + ### Module Responsibilities + + 1. **GRPC.Server.Adapters.ThousandIsland** (this module) + - Adapter API implementation (`GRPC.Server.Adapter` behaviour) + - Server lifecycle (start/stop) + - Helper functions: `send_reply/3`, `send_headers/2`, `send_trailers/2` + - These functions send async messages to the Handler process + + 2. **GRPC.Server.Adapters.ThousandIsland.Handler** + - ThousandIsland.Handler behaviour implementation + - HTTP/2 connection lifecycle + - Frame processing coordinator + - Message handling (async operations from user handlers) + - State management (accumulated headers, connection state) + + 3. **GRPC.Server.HTTP2.Connection** + - HTTP/2 protocol state machine + - Frame encoding/decoding (HEADERS, DATA, SETTINGS, etc.) + - HPACK compression/decompression + - Stream state management (per-stream tracking) + - Flow control + + 4. **GRPC.Server.HTTP2.StreamState** + - Per-stream state tracking + - Message buffering and framing + - gRPC message assembly (5-byte length-prefix framing) + - Stream lifecycle (idle -> open -> half_closed -> closed) + + 5. **GRPC.Server.HTTP2.Dispatcher** + - RPC method routing and dispatch + - Determines RPC type (unary, client_stream, server_stream, bidi_stream) + - Spawns handler tasks for streaming RPCs + - Manages BidiStream GenServer for bidirectional streaming + + 6. **GRPC.Server.BidiStream** + - GenServer for bidirectional streaming + - Message queue for incoming requests + - Lazy enumerable generation for handler consumption + - Backpressure management + + ## Request Pipeline by RPC Type + + ### Process Hierarchy + + ```mermaid + graph TD + A[ThousandIsland Supervisor] --> B[Handler Process] + B --> C[Connection State
HTTP2.Connection] + C --> D[Stream States
HTTP2.StreamState per stream_id] + B --> E[User Handler Tasks
spawned per RPC] + E --> F[BidiStream GenServer
only for bidi streaming] + ``` + + ### 1. Unary RPC (request -> response) + + #### Request Path + + 1. **Client sends HTTP/2 frames** → TCP socket + 2. **Handler.handle_data/3** receives raw bytes + - Buffers until complete frames available + 3. **Connection.handle_frame/3** processes each frame + - HEADERS frame → decode headers, create StreamState + - DATA frame → accumulate in StreamState.data_buffer + - When END_STREAM received → decode gRPC message + 4. **Connection.process_grpc_request/4** extracts complete request + - Decodes 5-byte length-prefixed message + - Looks up RPC method from path + 5. **Dispatcher.dispatch/4** routes to handler + - Calls `Dispatcher.call_unary/5` + - Directly invokes user handler function: `MyServer.my_method(request, stream)` + - Handler runs **synchronously** in Handler process + + #### Response Path + + 1. **Handler returns response** (or calls `GRPC.Server.send_reply/2`) + 2. **Dispatcher** sends response headers + data + trailers + - Headers: `{":status" => "200", "content-type" => "application/grpc+proto"}` + - Data: gRPC framed message (5-byte length + protobuf) + - Trailers: `{"grpc-status" => "0"}` + 3. **Connection.send_headers/4** encodes HEADERS frame + - HPACK compression + - Sends via socket + 4. **Connection.send_data/5** encodes DATA frame + - Sets END_STREAM flag + 5. **Connection.send_trailers/4** encodes final HEADERS frame + - Sets END_HEADERS + END_STREAM flags + + **Process Model**: Single Handler process handles entire request synchronously + + ### 2. Client Streaming RPC (stream of requests -> response) + + #### Request Path + + 1. **Client sends multiple DATA frames** (END_STREAM on last) + 2. **Handler.handle_data/3** → **Connection.handle_frame/3** + - Each DATA frame appends to StreamState.data_buffer + - Messages accumulated in StreamState.message_buffer + 3. **When END_STREAM received** → **process_grpc_request/4** + - All messages decoded + 4. **Dispatcher.call_client_streaming/5** + - Creates `Stream.unfold` from buffered messages + - Calls handler: `MyServer.my_method(request_enum, stream)` + - Handler **synchronously** consumes stream + + #### Response Path + + Same as Unary (single response at end) + + **Process Model**: Single Handler process, synchronous handler execution + + ### 3. Server Streaming RPC (request -> stream of responses) + + #### Request Path + + Same as Unary (single request) + + #### Response Path + + 1. **Handler calls `GRPC.Server.send_reply/2` multiple times** + 2. **This adapter's `send_reply/3`** sends async message: + - `send(handler_pid, {:grpc_send_data, stream_id, framed_data})` + 3. **Handler.handle_info/2** receives `:grpc_send_data` + - Calls `Connection.send_data/5` to send DATA frame + - Each call is a separate DATA frame (END_STREAM=false) + 4. **Final trailers** sent at end + - `GRPC.Server.send_trailers/2` → `{:grpc_send_trailers, ...}` + - Handler sends final HEADERS frame with END_STREAM + + **Process Model**: + + - Handler spawns **Task** to run user handler asynchronously + - Handler process receives messages from Task and sends frames + - Task communicates via messages to Handler process + + ### 4. Bidirectional Streaming RPC (stream ↔ stream) + + This is the most complex case with multiple concurrent processes. + + #### Process Model + + ```mermaid + graph LR + subgraph HP["Handler Process (#PID<0.545.0>)"] + HS["State: accumulated_headers
stream_id => headers"] + end + + subgraph UT["User Handler Task (#PID<0.XXX.0>)"] + UTR["Runs: MyServer.full_duplex_call
request_enum, stream"] + end + + subgraph BS["BidiStream GenServer (#PID<0.YYY.0>)"] + BSQ["Queue: Buffered incoming requests"] + end + + Client -->|HTTP/2 frames| HP + HP -->|HTTP/2 frames| Client + UT -->|:grpc_send_data| HP + HP -->|:add_message| BS + BS -->|request_enum
lazy pull| UT + ``` + + #### Request Path (Incoming) + + 1. **Client sends DATA frames** (multiple, no END_STREAM until done) + 2. **Handler.handle_data/3** → **Connection.handle_frame/3** + - Each DATA frame processed immediately + 3. **Connection.process_grpc_request/4** (on first HEADERS) + - Creates StreamState with `is_bidi_streaming: true` + - Calls **Dispatcher.call_bidi_streaming/5** + 4. **Dispatcher.call_bidi_streaming/5** (CRITICAL!) + - **Starts BidiStream GenServer**: `{:ok, bidi_pid} = BidiStream.start_link(stream_id, [])` + - **Accumulates base headers** (don't send yet!): + ```elixir + base_headers = %{":status" => "200", "content-type" => "application/grpc+proto"} + GRPC.Server.set_headers(stream, base_headers) # Sends {:grpc_accumulate_headers, ...} + ``` + - **Spawns User Handler Task**: + ```elixir + request_enum = BidiStream.to_enum(bidi_pid) + Task.start(fn -> + MyServer.full_duplex_call(request_enum, stream) + end) + ``` + - **Stores bidi_pid in StreamState** for later DATA frames + - Returns `:streaming_done` (dispatcher exits, Handler continues) + 5. **Subsequent DATA frames** (while handler running) + - **Connection.handle_frame/3** receives DATA frame + - Decodes gRPC message + - **Sends to BidiStream**: `GenServer.cast(bidi_pid, {:add_message, message})` + - BidiStream queues message for handler consumption + 6. **User Handler consumes request_enum** + - `Enum.each(request_enum, fn req -> ... end)` + - Each iteration pulls from BidiStream (lazy, blocks if queue empty) + - BidiStream dequeues message and returns to handler + + #### Response Path (Outgoing) + + 1. **User Handler calls `GRPC.Server.send_reply/2`** + - Runs in User Task process + 2. **This adapter's `send_reply/3`**: + ```elixir + send(handler_pid, {:grpc_send_data, stream_id, framed_data}) + ``` + 3. **Handler.handle_info({:grpc_send_data, ...}, state)** + - **CRITICAL: Header accumulation pattern** + - Checks if accumulated headers exist for stream_id: + ```elixir + accumulated = Map.get(state.accumulated_headers, stream_id, %{}) + if map_size(accumulated) > 0 do + # First DATA frame - send accumulated headers first! + Connection.send_headers(socket, stream_id, accumulated, connection) + # Clear accumulated headers + state = %{state | accumulated_headers: Map.delete(..., stream_id)} + end + ``` + - Then sends DATA frame: + ```elixir + Connection.send_data(socket, stream_id, data, false, connection) + ``` + 4. **Custom metadata support**: + - If handler calls `GRPC.Server.send_headers/2`: + ```elixir + send(handler_pid, {:grpc_accumulate_headers, stream_id, headers}) + ``` + - **Handler.handle_info({:grpc_accumulate_headers, ...})**: + ```elixir + current = Map.get(state.accumulated_headers, stream_id, %{}) + updated = Map.merge(current, headers) + state = %{state | accumulated_headers: Map.put(..., stream_id, updated)} + ``` + - These headers are sent with FIRST DATA frame (see step 3) + 5. **Final trailers** (when handler finishes) + - `GRPC.Server.send_trailers/2` → `{:grpc_send_trailers, stream_id, trailers}` + - **Handler.handle_info({:grpc_send_trailers, ...})**: + - Checks for unsent accumulated headers (empty stream case): + ```elixir + if map_size(accumulated_headers) > 0 do + Connection.send_headers(...) # Send base headers first + end + Connection.send_trailers(...) # Then trailers with END_STREAM + ``` + + #### Critical Timing & Synchronization + + **Problem**: HTTP/2 requires HEADERS before DATA, but we need to: + 1. Allow handler to add custom headers (via `send_headers/2`) + 2. Send base headers (`:status`, `content-type`) + 3. All in FIRST HEADERS frame (can't send HEADERS twice) + + **Solution** (inspired by Cowboy's `set_resp_headers` pattern): + 1. **Dispatcher accumulates base headers** without sending: + - Sends `{:grpc_accumulate_headers, stream_id, base_headers}` message + - Handler stores in `state.accumulated_headers` + 2. **User handler can add custom headers** (optional): + - Calls `GRPC.Server.send_headers(stream, custom_headers)` + - Merges into accumulated headers in Handler state + 3. **First `send_reply` sends ALL accumulated headers**: + - Handler checks `accumulated_headers[stream_id]` + - Sends merged (base + custom) headers in SINGLE HEADERS frame + - Clears accumulated headers + - Then sends DATA frame + 4. **For empty streams** (no `send_reply` calls): + - `send_trailers` checks for unsent accumulated headers + - Sends headers before trailers + + **Why this works**: + - Handler process is single-threaded message loop + - Messages processed in order: accumulate_headers → send_data → send_trailers + - Accumulated headers guaranteed to be merged before first DATA + - User Task sends messages asynchronously, Handler serializes them + + ## Message Flow Diagram (Bidi Streaming) + + ```mermaid + sequenceDiagram + participant Client + participant Handler as Handler Process + participant Task as User Task + participant Bidi as BidiStream + + Client->>Handler: HEADERS + Handler->>Handler: create StreamState + Handler->>Bidi: start BidiStream + Handler->>Handler: accumulate headers + Handler->>Task: spawn Task + Task->>Task: request_enum (lazy, blocks) + + Client->>Handler: DATA(req1) + Handler->>Handler: decode message + Handler->>Bidi: add_message + Bidi->>Task: pull next (req1) + Task->>Task: process req1 + Task->>Task: send_reply(resp1) + Task->>Handler: :grpc_send_data + Handler->>Handler: send headers (1st!) + Handler->>Client: HEADERS + Handler->>Client: DATA(resp1) + + Client->>Handler: DATA(req2) + Handler->>Bidi: add_message + Bidi->>Task: pull next (req2) + Task->>Task: process req2 + Task->>Task: send_reply(resp2) + Task->>Handler: :grpc_send_data + Handler->>Client: DATA(resp2) + + Client->>Handler: DATA (END_STREAM) + Handler->>Bidi: finish stream + Bidi->>Task: nil (done) + Task->>Task: handler finishes + Task->>Task: send_trailers + Task->>Handler: :grpc_send_trailers + Handler->>Client: HEADERS(trailers)
END_STREAM + ``` + + ## Key Design Patterns + + 1. **Async Message Passing**: User handlers send messages to Handler process + - Decouples user code from HTTP/2 frame management + - Handler serializes all socket writes (thread-safe) + + 2. **Lazy Enumerables**: Streaming requests use `Stream.unfold` + - Backpressure: handler blocks if no messages available + - Memory efficient: doesn't buffer entire stream + + 3. **Header Accumulation**: Inspired by Cowboy's `set_resp_headers` + - Accumulate headers in Handler state (not process dictionary!) + - Send on first DATA or trailers (whichever comes first) + - Allows custom headers while respecting HTTP/2 constraints + + 4. **GenServer Message Queue**: BidiStream acts as message buffer + - Decouples incoming frame rate from handler processing rate + - Natural backpressure via GenServer mailbox + + 5. **Process Dictionary for Stream Metadata**: + - Used in Dispatcher context: `Process.put({:bidi_stream_pid, stream_id}, pid)` + - Allows Connection to find BidiStream when DATA arrives + - Alternative to passing state through deep call stack + """ + + @behaviour GRPC.Server.Adapter + + require Logger + alias GRPC.Server.Adapters.ThousandIsland.Handler + + @default_num_acceptors 100 + @default_max_connections 16384 + + @doc """ + Starts a ThousandIsland server. + + ## Options + * `:ip` - The IP to bind the server to (default: listen on all interfaces) + * `:port` - The port to listen on (required) + * `:num_acceptors` - Number of acceptor processes (default: 100) + * `:num_connections` - Maximum concurrent connections (default: 16384) + * `:transport_options` - Additional transport options to pass to ThousandIsland + """ + @impl true + def start(endpoint, servers, port, opts) do + case Process.whereis(GRPC.Server.StreamTaskSupervisor) do + nil -> + case Task.Supervisor.start_link(name: GRPC.Server.StreamTaskSupervisor) do + {:ok, _pid} -> :ok + {:error, {:already_started, _pid}} -> :ok + end + + _pid -> + :ok + end + + server_opts = build_server_opts(endpoint, servers, port, opts) + + case ThousandIsland.start_link(server_opts) do + {:ok, pid} -> + actual_port = get_actual_port(pid, port) + {:ok, pid, actual_port} + + {:error, {:already_started, pid}} -> + Logger.warning("Failed to start #{server_names(endpoint, servers)}: already started") + actual_port = get_actual_port(pid, port) + {:ok, pid, actual_port} + + {:error, :eaddrinuse} = error -> + Logger.error("Failed to start #{server_names(endpoint, servers)}: port already in use") + error + + {:error, _} = error -> + error + end + end + + defp get_actual_port(pid, default_port) do + case ThousandIsland.listener_info(pid) do + {:ok, {_ip, actual_port}} -> actual_port + _ -> default_port + end + end + + @doc """ + Return a child_spec to start server under a supervisor. + """ + @spec child_spec(atom(), %{String.t() => [module()]}, non_neg_integer(), Keyword.t()) :: + Supervisor.child_spec() + def child_spec(endpoint, servers, port, opts) do + supervisor_opts = [ + endpoint: endpoint, + servers: servers, + port: port, + adapter_opts: Keyword.get(opts, :adapter_opts, []), + cred: opts[:cred] + ] + + %{ + id: __MODULE__.Supervisor, + start: {__MODULE__.Supervisor, :start_link, [supervisor_opts]}, + type: :supervisor, + restart: :permanent, + shutdown: :infinity + } + end + + @impl true + def stop(_endpoint, _servers) do + # TODO: Implement proper shutdown of ThousandIsland server + # ThousandIsland.stop(server_pid) + :ok + end + + @spec read_body(GRPC.Server.Adapter.state()) :: {:ok, binary()} + def read_body(%{data: data}) do + {:ok, data} + end + + @spec reading_stream(GRPC.Server.Adapter.state()) :: Enumerable.t() + def reading_stream(%{stream_state: %{bidi_stream_pid: bidi_pid}}) when not is_nil(bidi_pid) do + # For bidi streaming, return the lazy stream from BidiStream + GRPC.Server.BidiStream.to_enum(bidi_pid) + end + + def reading_stream(%{data: data}) do + # Create a stream that yields the data once + Stream.unfold({data, false}, fn + {_, true} -> + nil + + {buffer, false} -> + case GRPC.Message.get_message(buffer) do + {message, rest} -> {message, {rest, false}} + _ -> nil + end + end) + end + + def set_headers(%{handler_pid: pid, stream_id: stream_id}, headers) do + send(pid, {:grpc_accumulate_headers, stream_id, headers}) + :ok + end + + def set_resp_trailers(%{handler_pid: pid, stream_id: stream_id}, trailers) do + # Send message to accumulate trailers in handler state + # They will be merged with final trailers when stream completes + send(pid, {:grpc_accumulate_trailers, stream_id, trailers}) + :ok + end + + def get_headers(%{headers: headers}) do + headers + end + + def get_headers(%{connection: connection}) do + connection.metadata || %{} + end + + def get_peer(%{socket: socket}) do + case ThousandIsland.Socket.peername(socket) do + {:ok, {address, port}} -> + {:ok, {address, port}} + + error -> + error + end + end + + def get_cert(%{socket: socket}) do + case ThousandIsland.Socket.peercert(socket) do + {:ok, cert} -> {:ok, cert} + {:error, _} -> {:error, :no_peercert} + end + end + + def get_qs(_payload) do + # Query string not applicable for gRPC + "" + end + + def get_bindings(_payload) do + # Path bindings not applicable for gRPC + %{} + end + + def set_compressor(_payload, _compressor) do + # Compressor will be stored in connection state + :ok + end + + @impl true + def send_reply(%{handler_pid: pid, stream_id: stream_id}, data, opts) do + # Encode message with gRPC framing (compressed flag + length + data) + compressor = Keyword.get(opts, :compressor) + codec = Keyword.get(opts, :codec) + + case GRPC.Message.to_data(data, compressor: compressor, codec: codec, iolist: true) do + {:ok, framed_data, _size} -> + # Send data frame - handler will send accumulated headers first if needed + send(pid, {:grpc_send_data, stream_id, framed_data}) + :ok + + {:error, _msg} -> + :ok + end + end + + # Fallback for non-streaming (shouldn't happen but keeps compatibility) + def send_reply(_payload, _data, _opts), do: :ok + + @impl true + def send_headers(%{handler_pid: pid, stream_id: stream_id}, headers) do + # Send message to accumulate headers in handler state + # They will be sent on first send_reply call + send(pid, {:grpc_accumulate_headers, stream_id, headers}) + :ok + end + + def send_headers(_payload, _headers), do: :ok + + def send_trailers(%{handler_pid: pid, stream_id: stream_id}, trailers) do + send(pid, {:grpc_send_trailers, stream_id, trailers}) + :ok + end + + defp build_server_opts(endpoint, servers, port, opts) do + adapter_opts = Keyword.get(opts, :adapter_opts, opts) + + num_acceptors = Keyword.get(adapter_opts, :num_acceptors, @default_num_acceptors) + num_connections = Keyword.get(adapter_opts, :num_connections, @default_max_connections) + + transport_opts = + adapter_opts + |> Keyword.get(:transport_options, []) + |> Keyword.put(:port, port) + |> maybe_add_ip(adapter_opts) + |> maybe_add_ssl(cred_opts(opts)) + # Optimize TCP buffers for gRPC performance (support up to 1MB messages) + # 1MB buffer for large messages + |> Keyword.put_new(:buffer, 1_048_576) + # 1MB receive buffer + |> Keyword.put_new(:recbuf, 1_048_576) + # 1MB send buffer + |> Keyword.put_new(:sndbuf, 1_048_576) + # Disable Nagle's algorithm for low latency + |> Keyword.put_new(:nodelay, true) + + # Configure HTTP/2 settings for larger frames (needed for large gRPC messages) + local_settings = [ + # 1MB window size for large payloads + initial_window_size: 1_048_576, + # Keep default max frame size + max_frame_size: 16_384 + ] + + handler_options = %{ + endpoint: endpoint, + servers: servers, + opts: [local_settings: local_settings] + } + + Logger.debug("[build_server_opts] Creating handler_options") + + [ + port: port, + transport_module: transport_module(opts), + transport_options: transport_opts, + handler_module: Handler, + handler_options: handler_options, + num_acceptors: num_acceptors, + num_connections: num_connections + ] + end + + defp maybe_add_ip(transport_opts, adapter_opts) do + case Keyword.get(adapter_opts, :ip) do + nil -> transport_opts + ip -> Keyword.put(transport_opts, :ip, ip) + end + end + + defp maybe_add_ssl(transport_opts, nil), do: transport_opts + + defp maybe_add_ssl(transport_opts, cred_opts) do + Keyword.merge(transport_opts, cred_opts.ssl) + end + + defp transport_module(opts) do + if cred_opts(opts) do + ThousandIsland.Transports.SSL + else + ThousandIsland.Transports.TCP + end + end + + defp cred_opts(opts) do + opts[:cred] + end + + defp server_names(nil, servers) do + Enum.map_join(servers, ",", fn _k, s -> inspect(s) end) + end + + defp server_names(endpoint, _) do + inspect(endpoint) + end +end diff --git a/grpc_server/lib/grpc/server/adapters/thousand_island/handler.ex b/grpc_server/lib/grpc/server/adapters/thousand_island/handler.ex new file mode 100644 index 000000000..b78f2ff8a --- /dev/null +++ b/grpc_server/lib/grpc/server/adapters/thousand_island/handler.ex @@ -0,0 +1,363 @@ +defmodule GRPC.Server.Adapters.ThousandIsland.Handler do + @moduledoc """ + ThousandIsland handler for gRPC requests. + + Implementa ThousandIsland.Handler para lidar com gRPC sobre HTTP/2. + """ + use ThousandIsland.Handler + + alias GRPC.Server.HTTP2.Connection + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Errors + require Logger + + # HTTP/2 connection preface per RFC9113§3.4 + @connection_preface "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + # Inline hot path functions + @compile {:inline, handle_data: 3, handle_preface: 3, handle_frames_loop: 5} + + @impl ThousandIsland.Handler + def handle_connection(socket, handler_options) do + Logger.debug("New HTTP/2 connection established - socket: #{inspect(socket)}") + + # Initialize ETS cache for codecs/compressors lookup + GRPC.Server.Cache.init() + + # Support both keyword list and map formats + {servers_list, endpoint, opts} = + if is_map(handler_options) do + {Map.get(handler_options, :servers, []), Map.get(handler_options, :endpoint), + Map.get(handler_options, :opts, [])} + else + {Keyword.get(handler_options, :servers, []), Keyword.get(handler_options, :endpoint), + Keyword.get(handler_options, :opts, [])} + end + + servers = + cond do + is_map(servers_list) and not is_struct(servers_list) -> servers_list + is_list(servers_list) -> GRPC.Server.servers_to_map(servers_list) + true -> %{} + end + + Logger.debug("[handle_connection] servers: #{inspect(servers)}") + Logger.debug("[handle_connection] endpoint: #{inspect(endpoint)}") + + new_state = %{ + endpoint: endpoint, + servers: servers, + opts: opts, + connection: nil, + buffer: <<>>, + preface_received: false, + accumulated_headers: %{}, + accumulated_trailers: %{}, + stream_tasks: %{} + } + + {:continue, new_state} + end + + @impl ThousandIsland.Handler + def handle_data(data, socket, %{preface_received: false, buffer: buffer} = state) do + new_buffer = buffer <> data + handle_preface(new_buffer, socket, state) + end + + def handle_data(data, socket, %{buffer: buffer} = state) do + new_buffer = buffer <> data + handle_frames(new_buffer, socket, state) + end + + @impl ThousandIsland.Handler + def handle_close(_socket, state) do + Logger.debug("Connection closed") + {:close, state} + end + + @impl ThousandIsland.Handler + def handle_error(reason, _socket, state) do + Logger.error("Connection error: #{inspect(reason)}") + {:close, state} + end + + def handle_info({:grpc_accumulate_headers, stream_id, headers}, {socket, state}) do + current_headers = Map.get(state.accumulated_headers, stream_id, %{}) + updated_headers = Map.merge(current_headers, headers) + new_accumulated = Map.put(state.accumulated_headers, stream_id, updated_headers) + {:noreply, {socket, %{state | accumulated_headers: new_accumulated}}} + end + + def handle_info({:grpc_accumulate_trailers, stream_id, trailers}, {socket, state}) do + current_trailers = Map.get(state.accumulated_trailers, stream_id, %{}) + updated_trailers = Map.merge(current_trailers, trailers) + new_accumulated = Map.put(state.accumulated_trailers, stream_id, updated_trailers) + {:noreply, {socket, %{state | accumulated_trailers: new_accumulated}}} + end + + def handle_info({:register_stream_task, stream_id, task_pid, task_ref}, {socket, state}) do + Logger.debug( + "[Handler] Registering stream task for stream #{stream_id}, pid=#{inspect(task_pid)}" + ) + + new_stream_tasks = Map.put(state.stream_tasks, stream_id, {task_pid, task_ref}) + {:noreply, {socket, %{state | stream_tasks: new_stream_tasks}}} + end + + def handle_info({:grpc_send_headers, stream_id, headers}, {socket, state}) do + Logger.debug("[Streaming] Sending headers for stream #{stream_id}") + Connection.send_headers(socket, stream_id, headers, state.connection) + {:noreply, {socket, state}} + end + + def handle_info({:grpc_send_data, stream_id, data}, {socket, state}) do + accumulated = Map.get(state.accumulated_headers, stream_id, %{}) + + new_state = + if map_size(accumulated) > 0 do + Connection.send_headers(socket, stream_id, accumulated, state.connection) + %{state | accumulated_headers: Map.delete(state.accumulated_headers, stream_id)} + else + state + end + + Connection.send_data(socket, stream_id, data, false, new_state.connection) + {:noreply, {socket, new_state}} + end + + def handle_info({:grpc_send_trailers, stream_id, trailers}, {socket, state}) do + accumulated = Map.get(state.accumulated_headers, stream_id, %{}) + + new_state = + if map_size(accumulated) > 0 do + updated_conn = Connection.send_headers(socket, stream_id, accumulated, state.connection) + + %{ + state + | accumulated_headers: Map.delete(state.accumulated_headers, stream_id), + connection: updated_conn + } + else + state + end + + # Merge accumulated custom trailers with final trailers + accumulated_trailers = Map.get(state.accumulated_trailers, stream_id, %{}) + final_trailers = Map.merge(trailers, accumulated_trailers) + + # Send trailers (headers with END_STREAM) for streaming + # This will also remove the stream from the connection + updated_connection = + Connection.send_trailers(socket, stream_id, final_trailers, new_state.connection) + + new_state = %{ + new_state + | connection: updated_connection, + accumulated_trailers: Map.delete(new_state.accumulated_trailers, stream_id), + stream_tasks: Map.delete(new_state.stream_tasks, stream_id) + } + + {:noreply, {socket, new_state}} + end + + def handle_info({:update_stream_state, stream_id, updated_stream_state}, {socket, state}) do + Logger.debug( + "[Handler] Updating stream_state for stream #{stream_id}, bidi_pid=#{inspect(updated_stream_state.bidi_stream_pid)}" + ) + + connection = state.connection + + updated_connection = %{ + connection + | streams: Map.put(connection.streams, stream_id, updated_stream_state) + } + + {:noreply, {socket, %{state | connection: updated_connection}}} + end + + def handle_info({:DOWN, _ref, :process, pid, reason}, {socket, state}) do + # Task crashed - find which stream it belongs to and send error trailers + case Enum.find(state.stream_tasks, fn {_stream_id, {task_pid, _ref}} -> task_pid == pid end) do + {stream_id, _} -> + Logger.error("[Handler] Stream #{stream_id} task crashed: #{inspect(reason)}") + + # Send error trailers to client + error_trailers = %{ + "grpc-status" => "13", + # INTERNAL + "grpc-message" => "Stream handler crashed: #{inspect(reason)}" + } + + # Check if we have unsent accumulated headers (stream never sent data) + accumulated = Map.get(state.accumulated_headers, stream_id, %{}) + + new_state = + if map_size(accumulated) > 0 do + updated_conn = + Connection.send_headers(socket, stream_id, accumulated, state.connection) + + %{ + state + | accumulated_headers: Map.delete(state.accumulated_headers, stream_id), + connection: updated_conn + } + else + state + end + + # Send error trailers + updated_connection = + Connection.send_trailers(socket, stream_id, error_trailers, new_state.connection) + + new_state = %{ + new_state + | connection: updated_connection, + accumulated_trailers: Map.delete(new_state.accumulated_trailers, stream_id), + stream_tasks: Map.delete(new_state.stream_tasks, stream_id) + } + + {:noreply, {socket, new_state}} + + nil -> + # Task not found in our tracking - ignore silently + # This can happen for tasks spawned outside our control + {:noreply, {socket, state}} + end + end + + def handle_info(_msg, {socket, state}) do + {:noreply, {socket, state}} + end + + defp handle_preface(buffer, _socket, state) when byte_size(buffer) < 24 do + # Wait for more data (preface is 24 bytes) + {:continue, %{state | buffer: buffer}} + end + + defp handle_preface(<<@connection_preface, remaining::binary>>, socket, state) do + # Valid preface, initialize connection + try do + opts = Keyword.put(state.opts, :handler_pid, self()) + connection = Connection.init(socket, state.endpoint, state.servers, opts) + new_state = %{state | connection: connection, preface_received: true, buffer: <<>>} + + if byte_size(remaining) > 0 do + handle_frames(remaining, socket, new_state) + else + {:continue, new_state} + end + rescue + e -> + Logger.error( + "Error initializing connection: #{inspect(e)}\n#{Exception.format_stacktrace()}" + ) + + {:close, state} + end + end + + defp handle_preface(_buffer, _socket, state) do + Logger.debug("Invalid HTTP/2 preface") + {:close, state} + end + + defp handle_frames(buffer, socket, state) do + handle_frames_loop( + buffer, + socket, + state.connection, + state.connection.remote_settings.max_frame_size, + state + ) + end + + defp handle_frames_loop(buffer, socket, connection, max_frame_size, original_state) do + case Frame.deserialize(buffer, max_frame_size) do + {{:ok, frame}, rest} -> + try do + new_connection = Connection.handle_frame(frame, socket, connection) + + if byte_size(rest) > 0 do + # Continue processing with updated connection + handle_frames_loop(rest, socket, new_connection, max_frame_size, original_state) + else + # All frames processed, return updated state + {:continue, %{original_state | connection: new_connection, buffer: <<>>}} + end + rescue + e in Errors.ConnectionError -> + Logger.debug("Connection error: #{e.message}") + {:close, original_state} + + e in Errors.StreamError -> + Logger.debug("Stream error: #{e.message}") + {:continue, %{original_state | connection: connection, buffer: rest}} + end + + {{:more, _partial}, <<>>} -> + # Need more data to parse frame + {:continue, %{original_state | connection: connection, buffer: buffer}} + + {{:error, error_code, reason}, _rest} -> + Logger.debug("Frame deserialization error: #{reason} (code: #{error_code})") + {:close, original_state} + + nil -> + # No more frames to parse + {:continue, %{original_state | connection: connection, buffer: <<>>}} + end + end + + def read_full_body(pid) do + GenServer.call(pid, :read_full_body) + end + + def read_body(pid) do + GenServer.call(pid, :read_body) + end + + def send_data(pid, data, opts) do + GenServer.cast(pid, {:send_data, data, opts}) + end + + def send_headers(pid, headers) do + GenServer.cast(pid, {:send_headers, headers}) + end + + def set_headers(pid, headers) do + GenServer.cast(pid, {:set_headers, headers}) + end + + def set_trailers(pid, trailers) do + GenServer.cast(pid, {:set_trailers, trailers}) + end + + def send_trailers(pid, trailers) do + GenServer.cast(pid, {:send_trailers, trailers}) + end + + def get_headers(pid) do + GenServer.call(pid, :get_headers) + end + + def get_peer(pid) do + GenServer.call(pid, :get_peer) + end + + def get_cert(pid) do + GenServer.call(pid, :get_cert) + end + + def get_query_string(pid) do + GenServer.call(pid, :get_query_string) + end + + def get_bindings(pid) do + GenServer.call(pid, :get_bindings) + end + + def set_compressor(pid, compressor) do + GenServer.cast(pid, {:set_compressor, compressor}) + end +end diff --git a/grpc_server/lib/grpc/server/adapters/thousand_island/supervisor.ex b/grpc_server/lib/grpc/server/adapters/thousand_island/supervisor.ex new file mode 100644 index 000000000..f079186d2 --- /dev/null +++ b/grpc_server/lib/grpc/server/adapters/thousand_island/supervisor.ex @@ -0,0 +1,160 @@ +defmodule GRPC.Server.Adapters.ThousandIsland.Supervisor do + @moduledoc """ + Supervisor for ThousandIsland adapter. + + This supervisor manages the lifecycle of the ThousandIsland server and + provides isolation from other adapters. It encapsulates all ThousandIsland-specific + configuration and startup logic. + + ## Supervision Tree + + ``` + GRPC.Server.Supervisor + └── ThousandIsland.Supervisor (this module) + └── ThousandIsland (actual socket server) + ├── Acceptor Pool + ├── Connection Handlers + └── Handler Processes + ``` + + ## Responsibilities + + - Configures ThousandIsland server with gRPC-specific settings + - Manages HTTP/2 settings and transport options + - Handles SSL/TLS configuration + - Provides clean shutdown on termination + """ + + use Supervisor + require Logger + + alias GRPC.Server.Adapters.ThousandIsland.Handler + + @default_num_acceptors 10 + @default_max_connections 16_384 + + @doc """ + Starts the ThousandIsland supervisor. + + ## Options + + * `:endpoint` - The endpoint module (optional) + * `:servers` - Map of service name => server modules + * `:port` - The port to listen on + * `:adapter_opts` - ThousandIsland-specific options (see below) + * `:cred` - SSL credentials (optional, for HTTPS) + + ## Adapter Options + + * `:num_acceptors` - Number of acceptor processes (default: 10) + * `:num_connections` - Maximum number of connections (default: 16384) + * `:ip` - IP address to bind to (default: {0, 0, 0, 0}) + * `:transport_options` - Additional transport options + """ + def start_link(opts) do + Supervisor.start_link(__MODULE__, opts, name: __MODULE__) + end + + @impl true + def init(opts) do + endpoint = opts[:endpoint] + servers = opts[:servers] + port = opts[:port] + + server_opts = build_server_opts(endpoint, servers, port, opts) + + scheme = if cred_opts(opts), do: :https, else: :http + server_name = server_names(endpoint, servers) + + Logger.info("Starting #{server_name} with ThousandIsland using #{scheme}://0.0.0.0:#{port}") + + children = [ + {Task.Supervisor, name: GRPC.Server.StreamTaskSupervisor}, + {ThousandIsland, server_opts} + ] + + Supervisor.init(children, strategy: :one_for_one) + end + + defp build_server_opts(endpoint, servers, port, opts) do + adapter_opts = Keyword.get(opts, :adapter_opts, opts) + + num_acceptors = Keyword.get(adapter_opts, :num_acceptors, @default_num_acceptors) + num_connections = Keyword.get(adapter_opts, :num_connections, @default_max_connections) + + transport_opts = + adapter_opts + |> Keyword.get(:transport_options, []) + |> Keyword.put(:port, port) + |> maybe_add_ip(adapter_opts) + |> maybe_add_ssl(cred_opts(opts)) + # Optimize TCP buffers for gRPC performance (support up to 1MB messages) + # 1MB buffer for large messages + |> Keyword.put_new(:buffer, 1_048_576) + # 1MB receive buffer + |> Keyword.put_new(:recbuf, 1_048_576) + # 1MB send buffer + |> Keyword.put_new(:sndbuf, 1_048_576) + # Disable Nagle's algorithm for low latency + |> Keyword.put_new(:nodelay, true) + + # Configure HTTP/2 settings for larger frames (needed for large gRPC messages) + local_settings = [ + # 1MB window size for large payloads + initial_window_size: 1_048_576, + # Keep default max frame size + max_frame_size: 16_384 + ] + + handler_options = %{ + endpoint: endpoint, + servers: servers, + opts: [local_settings: local_settings] + } + + Logger.debug("[ThousandIsland.Supervisor] Creating server configuration") + + [ + port: port, + transport_module: transport_module(opts), + transport_options: transport_opts, + handler_module: Handler, + handler_options: handler_options, + num_acceptors: num_acceptors, + num_connections: num_connections + ] + end + + defp maybe_add_ip(transport_opts, adapter_opts) do + case Keyword.get(adapter_opts, :ip) do + nil -> transport_opts + ip -> Keyword.put(transport_opts, :ip, ip) + end + end + + defp maybe_add_ssl(transport_opts, nil), do: transport_opts + + defp maybe_add_ssl(transport_opts, cred_opts) do + Keyword.merge(transport_opts, cred_opts.ssl) + end + + defp transport_module(opts) do + if cred_opts(opts) do + ThousandIsland.Transports.SSL + else + ThousandIsland.Transports.TCP + end + end + + defp cred_opts(opts) do + opts[:cred] + end + + defp server_names(nil, servers) do + Enum.map_join(servers, ",", fn {_k, s} -> inspect(s) end) + end + + defp server_names(endpoint, _) do + inspect(endpoint) + end +end diff --git a/grpc_server/lib/grpc/server/bidi_stream.ex b/grpc_server/lib/grpc/server/bidi_stream.ex new file mode 100644 index 000000000..f5fc6cb3f --- /dev/null +++ b/grpc_server/lib/grpc/server/bidi_stream.ex @@ -0,0 +1,168 @@ +defmodule GRPC.Server.BidiStream do + @moduledoc """ + Manages bidirectional streaming request messages using a supervised Task. + + This module stores incoming request messages for a bidi streaming RPC + and provides them as a lazy enumerable to the handler. It blocks when + no messages are available and wakes up when new messages arrive. + + Uses a Task instead of GenServer for lighter weight and automatic supervision. + """ + require Logger + + @doc """ + Starts a supervised task that manages messages for a bidi stream. + Returns {:ok, pid}. + """ + def start_link(stream_id, initial_messages \\ []) do + Task.Supervisor.start_child(GRPC.Server.StreamTaskSupervisor, fn -> + loop(%{ + stream_id: stream_id, + messages: :queue.from_list(initial_messages), + waiting_caller: nil, + stream_finished: false + }) + end) + end + + @doc """ + Adds decoded messages to the stream. + """ + def put_messages(pid, messages) when is_list(messages) do + send(pid, {:put_messages, messages}) + :ok + end + + @doc """ + Marks the stream as finished (client sent END_STREAM). + """ + def finish(pid) do + send(pid, :finish) + :ok + end + + @doc """ + Cancels the stream (client sent RST_STREAM). + """ + def cancel(pid) do + send(pid, :cancel) + :ok + end + + @doc """ + Creates a lazy enumerable that reads messages from this stream. + Blocks when no messages are available. + """ + def to_enum(pid) do + Stream.resource( + fn -> pid end, + fn pid -> + case get_next_message(pid) do + {:ok, message} -> {[message], pid} + :done -> {:halt, pid} + end + end, + fn _pid -> :ok end + ) + end + + ## Private + + defp get_next_message(pid) do + ref = make_ref() + send(pid, {:next_message, self(), ref}) + + receive do + {^ref, response} -> response + end + end + + defp loop(state) do + receive do + {:next_message, caller_pid, ref} -> + Logger.info( + "[BidiStream #{state.stream_id}] Received :next_message, queue_size=#{:queue.len(state.messages)}, finished=#{state.stream_finished}" + ) + + case :queue.out(state.messages) do + {{:value, message}, new_queue} -> + # Return message immediately + Logger.debug("[BidiStream #{state.stream_id}] Returning message from queue") + send(caller_pid, {ref, {:ok, message}}) + loop(%{state | messages: new_queue}) + + {:empty, _} -> + if state.stream_finished do + # No more messages and stream is done + Logger.debug("[BidiStream #{state.stream_id}] Stream finished, no more messages") + send(caller_pid, {ref, :done}) + # Exit the task - stream is complete + :ok + else + # No messages yet - store caller and wait + Logger.debug("[BidiStream #{state.stream_id}] Queue empty, blocking caller") + loop(%{state | waiting_caller: {caller_pid, ref}}) + end + end + + {:put_messages, new_messages} -> + Logger.info( + "[BidiStream #{state.stream_id}] Received #{length(new_messages)} new messages" + ) + + # Add messages to queue + new_queue = + Enum.reduce(new_messages, state.messages, fn msg, queue -> + :queue.in(msg, queue) + end) + + # If someone is waiting, reply with first message + Logger.debug( + "[BidiStream #{state.stream_id}] After adding messages, queue_size=#{:queue.len(new_queue)}, has_waiting_caller=#{not is_nil(state.waiting_caller)}" + ) + + if state.waiting_caller do + {caller_pid, ref} = state.waiting_caller + + case :queue.out(new_queue) do + {{:value, message}, final_queue} -> + send(caller_pid, {ref, {:ok, message}}) + loop(%{state | messages: final_queue, waiting_caller: nil}) + + {:empty, _} -> + # Shouldn't happen but handle gracefully + loop(%{state | messages: new_queue}) + end + else + loop(%{state | messages: new_queue}) + end + + :finish -> + Logger.info( + "[BidiStream #{state.stream_id}] Received :finish, has_waiting_caller=#{not is_nil(state.waiting_caller)}" + ) + + # Mark stream as finished + if state.waiting_caller do + {caller_pid, ref} = state.waiting_caller + # Reply to waiting caller that stream is done + send(caller_pid, {ref, :done}) + # Exit the task - stream is complete + :ok + else + loop(%{state | stream_finished: true}) + end + + :cancel -> + Logger.info("[BidiStream #{state.stream_id}] Received :cancel (RST_STREAM from client)") + # If someone is waiting, reply that stream is done + if state.waiting_caller do + {caller_pid, ref} = state.waiting_caller + send(caller_pid, {ref, :done}) + end + + # Exit the task - stream was cancelled + :ok + end + end +end diff --git a/grpc_server/lib/grpc/server/cache.ex b/grpc_server/lib/grpc/server/cache.ex new file mode 100644 index 000000000..26db6991b --- /dev/null +++ b/grpc_server/lib/grpc/server/cache.ex @@ -0,0 +1,90 @@ +defmodule GRPC.Server.Cache do + @moduledoc """ + ETS-based cache for frequently accessed server metadata. + Improves performance by avoiding repeated Enum.find operations. + """ + + @table_name :grpc_server_cache + + def init do + case :ets.whereis(@table_name) do + :undefined -> + :ets.new(@table_name, [:named_table, :public, :set, {:read_concurrency, true}]) + + _tid -> + :ok + end + end + + @doc """ + Find codec by name for a given server. + Uses ETS cache to avoid repeated Enum.find calls. + """ + def find_codec(server, subtype) do + cache_key = {:codec, server, subtype} + + case :ets.lookup(@table_name, cache_key) do + [{^cache_key, codec}] -> + codec + + [] -> + codec = Enum.find(server.__meta__(:codecs), nil, fn c -> c.name() == subtype end) + :ets.insert(@table_name, {cache_key, codec}) + codec + end + end + + @doc """ + Find compressor by name for a given server. + Uses ETS cache to avoid repeated Enum.find calls. + """ + def find_compressor(server, encoding) do + cache_key = {:compressor, server, encoding} + + case :ets.lookup(@table_name, cache_key) do + [{^cache_key, compressor}] -> + compressor + + [] -> + compressor = + Enum.find(server.__meta__(:compressors), nil, fn c -> c.name() == encoding end) + + :ets.insert(@table_name, {cache_key, compressor}) + compressor + end + end + + @doc """ + Find RPC definition by method name for a given server. + Uses ETS cache to avoid repeated Enum.find calls. + """ + def find_rpc(server, method_name) do + cache_key = {:rpc, server, method_name} + + case :ets.lookup(@table_name, cache_key) do + [{^cache_key, rpc}] -> + rpc + + [] -> + rpc_calls = server.__meta__(:service).__rpc_calls__() + + rpc = + Enum.find(rpc_calls, nil, fn {name, _, _, _} -> + Atom.to_string(name) == method_name + end) + + :ets.insert(@table_name, {cache_key, rpc}) + rpc + end + end + + @doc """ + Clear the cache. Useful for testing or when server configuration changes. + """ + def clear do + case :ets.whereis(@table_name) do + :undefined -> :ok + _tid -> :ets.delete_all_objects(@table_name) + end + end +end diff --git a/grpc_server/lib/grpc/server/http2/connection.ex b/grpc_server/lib/grpc/server/http2/connection.ex new file mode 100644 index 000000000..b6310116e --- /dev/null +++ b/grpc_server/lib/grpc/server/http2/connection.ex @@ -0,0 +1,1046 @@ +defmodule GRPC.Server.HTTP2.Connection do + @moduledoc """ + Represents the state of an HTTP/2 connection for gRPC. + """ + require Logger + + alias GRPC.Transport.HTTP2.Frame + alias GRPC.Transport.HTTP2.Settings + alias GRPC.Transport.HTTP2.Errors + alias GRPC.Server.HTTP2.StreamState + + # Inline hot path functions for performance + @compile {:inline, + extract_messages: 2, send_frame: 3, handle_headers_frame: 3, send_grpc_trailers: 4} + + defstruct local_settings: %Settings{}, + remote_settings: %Settings{}, + fragment_frame: nil, + send_hpack_state: HPAX.new(4096), + recv_hpack_state: HPAX.new(4096), + send_window_size: 65_535, + recv_window_size: 65_535, + streams: %{}, + next_stream_id: 2, + endpoint: nil, + servers: %{}, + socket: nil, + handler_pid: nil + + @typedoc "Encapsulates the state of an HTTP/2 connection" + @type t :: %__MODULE__{ + local_settings: Settings.t(), + remote_settings: Settings.t(), + fragment_frame: Frame.Headers.t() | nil, + send_hpack_state: term(), + recv_hpack_state: term(), + send_window_size: non_neg_integer(), + recv_window_size: non_neg_integer(), + streams: %{non_neg_integer() => map()}, + next_stream_id: non_neg_integer(), + endpoint: atom(), + servers: %{String.t() => [module()]}, + socket: ThousandIsland.Socket.t() | nil + } + + @doc """ + Initializes a new HTTP/2 connection. + """ + @spec init(ThousandIsland.Socket.t(), atom(), %{String.t() => [module()]}, keyword()) :: t() + def init(socket, endpoint, servers, opts \\ []) do + handler_pid = Keyword.get(opts, :handler_pid) + Logger.debug("[Connection.init] handler_pid=#{inspect(handler_pid)}") + + connection = %__MODULE__{ + local_settings: struct!(Settings, Keyword.get(opts, :local_settings, [])), + endpoint: endpoint, + servers: servers, + socket: socket, + handler_pid: handler_pid + } + + # Send initial SETTINGS frame per RFC9113§3.4 + settings_frame = %Frame.Settings{ + ack: false, + settings: [ + header_table_size: connection.local_settings.header_table_size, + max_concurrent_streams: connection.local_settings.max_concurrent_streams, + initial_window_size: connection.local_settings.initial_window_size, + max_frame_size: connection.local_settings.max_frame_size, + max_header_list_size: connection.local_settings.max_header_list_size + ] + } + + send_frame(settings_frame, socket, connection) + + connection + end + + @doc """ + Send headers for streaming response. + """ + def send_headers(socket, stream_id, headers, connection) do + # Check if stream still exists (may have been closed by RST_STREAM) + unless Map.has_key?(connection.streams, stream_id) do + Logger.debug( + "[send_headers] SKIPPED - stream=#{stream_id} no longer exists (likely cancelled by client)" + ) + + connection + else + # Encode headers using HPAX - convert map to list of tuples + Logger.debug("[send_headers] stream_id=#{stream_id}, headers=#{inspect(headers)}") + headers_list = if is_map(headers), do: Map.to_list(headers), else: headers + + {header_block, _new_hpack} = + HPAX.encode(:no_store, headers_list, connection.send_hpack_state) + + # Send HEADERS frame without END_STREAM + frame = %Frame.Headers{ + stream_id: stream_id, + fragment: header_block, + end_stream: false, + end_headers: true + } + + send_frame(frame, socket, connection) + end + end + + @doc """ + Send data frame for streaming response. + """ + def send_data(socket, stream_id, data, end_stream, connection) do + # Check if stream still exists (may have been closed by RST_STREAM) + unless Map.has_key?(connection.streams, stream_id) do + Logger.debug( + "[send_data] SKIPPED - stream=#{stream_id} no longer exists (likely cancelled by client)" + ) + + connection + else + # Send DATA frame + frame = %Frame.Data{ + stream_id: stream_id, + data: data, + end_stream: end_stream + } + + send_frame(frame, socket, connection) + end + end + + @doc """ + Send trailers (headers with END_STREAM) for streaming response. + """ + def send_trailers(socket, stream_id, trailers, connection) do + # Check if stream still exists (may have been closed by RST_STREAM) + unless Map.has_key?(connection.streams, stream_id) do + Logger.debug( + "[send_trailers] SKIPPED - stream=#{stream_id} no longer exists (likely cancelled by client)" + ) + + connection + else + # Encode custom metadata (handles -bin suffix base64 encoding) + # Note: encode_metadata filters out reserved headers like grpc-status + encoded_custom = GRPC.Transport.HTTP2.encode_metadata(trailers) + + # Re-add reserved headers (grpc-status, etc) that were filtered out + encoded_trailers = + Map.merge( + Map.take(trailers, ["grpc-status", "grpc-message"]), + encoded_custom + ) + + # Convert map to list + trailer_list = Map.to_list(encoded_trailers) + + {trailer_block, _new_hpack} = + HPAX.encode(:no_store, trailer_list, connection.send_hpack_state) + + # Send HEADERS frame with END_STREAM + frame = %Frame.Headers{ + stream_id: stream_id, + fragment: trailer_block, + end_stream: true, + end_headers: true + } + + connection = send_frame(frame, socket, connection) + + # Remove stream after sending END_STREAM (RFC 7540: stream transitions to closed) + Logger.debug("[send_trailers] Removing stream #{stream_id} after sending END_STREAM") + %{connection | streams: Map.delete(connection.streams, stream_id)} + end + end + + @doc """ + Set custom headers for a stream (used for custom_metadata test). + """ + def set_stream_custom_headers(connection, stream_id, headers) do + case Map.get(connection.streams, stream_id) do + nil -> + connection + + stream_state -> + updated_stream = %{ + stream_state + | custom_headers: Map.merge(stream_state.custom_headers, headers) + } + + %{connection | streams: Map.put(connection.streams, stream_id, updated_stream)} + end + end + + @doc """ + Set custom trailers for a stream (used for custom_metadata test). + """ + def set_stream_custom_trailers(connection, stream_id, trailers) do + case Map.get(connection.streams, stream_id) do + nil -> + connection + + stream_state -> + updated_stream = %{ + stream_state + | custom_trailers: Map.merge(stream_state.custom_trailers, trailers) + } + + %{connection | streams: Map.put(connection.streams, stream_id, updated_stream)} + end + end + + @doc """ + Get the stream state for a specific stream ID. + Returns nil if stream doesn't exist. + """ + def get_stream(connection, stream_id) do + Map.get(connection.streams, stream_id) + end + + @doc """ + Handles an incoming HTTP/2 frame. + """ + @spec handle_frame(Frame.frame(), ThousandIsland.Socket.t(), t()) :: t() + def handle_frame(frame, socket, connection) do + do_handle_frame(frame, socket, connection) + end + + defp do_handle_frame( + %Frame.Continuation{end_headers: true, stream_id: stream_id} = frame, + socket, + %__MODULE__{fragment_frame: %Frame.Headers{stream_id: stream_id}} = connection + ) do + header_block = connection.fragment_frame.fragment <> frame.fragment + header_frame = %{connection.fragment_frame | end_headers: true, fragment: header_block} + do_handle_frame(header_frame, socket, %{connection | fragment_frame: nil}) + end + + defp do_handle_frame( + %Frame.Continuation{end_headers: false, stream_id: stream_id} = frame, + _socket, + %__MODULE__{fragment_frame: %Frame.Headers{stream_id: stream_id}} = connection + ) do + fragment = connection.fragment_frame.fragment <> frame.fragment + # TODO: Check max header list size + fragment_frame = %{connection.fragment_frame | fragment: fragment} + %{connection | fragment_frame: fragment_frame} + end + + defp do_handle_frame(_frame, _socket, %__MODULE__{fragment_frame: %Frame.Headers{}}) do + connection_error!("Expected CONTINUATION frame (RFC9113§6.10)") + end + + defp do_handle_frame(%Frame.Settings{ack: true}, _socket, connection) do + Logger.info("[Connection] Received SETTINGS ACK") + connection + end + + defp do_handle_frame(%Frame.Settings{ack: false} = frame, socket, connection) do + Logger.info("[Connection] Received SETTINGS, sending ACK") + %Frame.Settings{ack: true, settings: []} |> send_frame(socket, connection) + + remote_settings = apply_settings(connection.remote_settings, frame.settings) + + # Update HPACK table size if changed + send_hpack_state = + if remote_settings.header_table_size != connection.remote_settings.header_table_size do + HPAX.resize(connection.send_hpack_state, remote_settings.header_table_size) + else + connection.send_hpack_state + end + + %{connection | remote_settings: remote_settings, send_hpack_state: send_hpack_state} + end + + defp do_handle_frame(%Frame.Ping{ack: false, payload: data}, socket, connection) do + Logger.info("[Connection] Received PING, sending ACK") + %Frame.Ping{ack: true, payload: data} |> send_frame(socket, connection) + connection + end + + defp do_handle_frame(%Frame.Ping{ack: true}, _socket, connection) do + # Ignore PING ACKs for now (we don't track sent PINGs yet) + connection + end + + defp do_handle_frame(%Frame.Goaway{} = frame, _socket, connection) do + Logger.info( + "Received GOAWAY: last_stream_id=#{frame.last_stream_id}, error=#{frame.error_code}" + ) + + # TODO: Handle graceful shutdown + connection + end + + defp do_handle_frame(%Frame.WindowUpdate{stream_id: 0} = frame, _socket, connection) do + Logger.info("[Connection] WINDOW_UPDATE connection-level: +#{frame.size_increment}") + new_window = connection.send_window_size + frame.size_increment + + if new_window > 2_147_483_647 do + connection_error!("Flow control window overflow (RFC9113§6.9.1)") + end + + %{connection | send_window_size: new_window} + end + + defp do_handle_frame(%Frame.Headers{} = frame, socket, connection) do + Logger.info( + "[RECV_HEADERS] stream=#{frame.stream_id}, end_headers=#{frame.end_headers}, end_stream=#{frame.end_stream}" + ) + + if frame.end_headers do + handle_headers_frame(frame, socket, connection) + else + # Start accumulating CONTINUATION frames + %{connection | fragment_frame: frame} + end + end + + defp do_handle_frame(%Frame.Data{stream_id: stream_id} = frame, socket, connection) do + Logger.info( + "[RECV_DATA] stream=#{stream_id}, size=#{byte_size(frame.data)}, end_stream=#{frame.end_stream}, stream_exists=#{Map.has_key?(connection.streams, stream_id)}" + ) + + case Map.get(connection.streams, stream_id) do + nil -> + Logger.debug( + "[IGNORE_DATA] stream=#{stream_id} not found, size=#{byte_size(frame.data)} (stream already closed)" + ) + + connection + + stream_state -> + # Send WINDOW_UPDATE to allow client to continue sending + data_size = byte_size(frame.data) + + # Only send WINDOW_UPDATE if there's actual data (non-zero increment) + if data_size > 0 do + Logger.debug("[WINDOW_UPDATE] stream=#{stream_id}, size=#{data_size}") + # Send connection-level WINDOW_UPDATE + conn_window_update = %Frame.WindowUpdate{stream_id: 0, size_increment: data_size} + send_frame(conn_window_update, socket, connection) + + # Send stream-level WINDOW_UPDATE + stream_window_update = %Frame.WindowUpdate{ + stream_id: stream_id, + size_increment: data_size + } + + send_frame(stream_window_update, socket, connection) + end + + # Accumulate data in stream buffer + updated_stream = %{stream_state | data_buffer: stream_state.data_buffer <> frame.data} + + # Mark if END_STREAM was received (stream half-closed remote) + updated_stream = + if frame.end_stream do + %{updated_stream | end_stream_received: true} + else + updated_stream + end + + updated_connection = %{ + connection + | streams: Map.put(connection.streams, stream_id, updated_stream) + } + + # For bidirectional streaming, process when we have complete messages + # For other types, wait for END_STREAM + should_process = + if frame.end_stream do + true + else + # For bidi, check if we have complete messages AND start processing + updated_stream.is_bidi_streaming and has_complete_message?(updated_stream.data_buffer) + end + + if should_process do + # Process the request (for bidi, this starts the handler on first message) + Logger.info( + "[Connection] Stream #{stream_id} processing (end_stream=#{frame.end_stream}, bidi=#{updated_stream.is_bidi_streaming})" + ) + + process_grpc_request(socket, updated_stream, updated_connection, frame.end_stream) + else + # More data coming (non-bidi case) + Logger.debug( + "[Connection] Stream #{stream_id} waiting for more data (buffer=#{byte_size(updated_stream.data_buffer)} bytes)" + ) + + updated_connection + end + end + end + + defp do_handle_frame( + %Frame.RstStream{stream_id: stream_id, error_code: error}, + _socket, + connection + ) do + stream_exists = Map.has_key?(connection.streams, stream_id) + + Logger.info( + "[RECV_RST_STREAM] stream=#{stream_id}, error=#{error}, stream_exists=#{stream_exists}" + ) + + # Notify BidiStream (if exists) that stream was cancelled + case Process.get({:bidi_stream_pid, stream_id}) do + nil -> + Logger.debug("[Connection] No BidiStream found for stream #{stream_id}") + :ok + + pid -> + Logger.info("[Connection] Cancelling BidiStream for stream #{stream_id}") + GRPC.Server.BidiStream.cancel(pid) + end + + # Clean up process dictionary + Process.delete({:bidi_stream_pid, stream_id}) + Process.delete({:bidi_stream_state, stream_id}) + Process.delete({:grpc_custom_trailers, stream_id}) + + # Remove stream from streams map + streams = Map.delete(connection.streams, stream_id) + Logger.debug("[REMOVE_STREAM] stream=#{stream_id}, remaining_streams=#{map_size(streams)}") + %{connection | streams: streams} + end + + defp do_handle_frame(%Frame.WindowUpdate{stream_id: stream_id} = frame, _socket, connection) do + Logger.info("[Connection] WINDOW_UPDATE stream=#{stream_id}: +#{frame.size_increment}") + # TODO: Update stream send window + connection + end + + defp do_handle_frame(%Frame.Priority{}, _socket, connection) do + # gRPC doesn't use priority, ignore + connection + end + + defp do_handle_frame(%Frame.PushPromise{}, _socket, _connection) do + # Server push not supported in gRPC + connection_error!("PUSH_PROMISE not supported (RFC9113§8.4)") + end + + defp do_handle_frame(%Frame.Unknown{}, _socket, connection) do + # Ignore unknown frames per RFC9113§4.1 + connection + end + + defp handle_headers_frame(frame, _socket, connection) do + Logger.info("[handle_headers_frame] Decoding HPACK for stream #{frame.stream_id}") + + # Check if this is trailers for an existing stream + case Map.get(connection.streams, frame.stream_id) do + nil -> + # New stream - decode headers and create stream state + case HPAX.decode(frame.fragment, connection.recv_hpack_state) do + {:ok, headers, new_hpack_state} -> + Logger.info( + "[handle_headers_frame] Decoded headers for stream #{frame.stream_id}: #{inspect(headers)}" + ) + + # Create stream state from headers + stream_state = + StreamState.from_headers( + frame.stream_id, + headers, + connection.local_settings.initial_window_size + ) + + # Add handler_pid for streaming support + stream_state = %{stream_state | handler_pid: connection.handler_pid} + + # Check if this is bidirectional streaming + # For bidi, we need to process messages as they arrive (not wait for END_STREAM) + is_bidi = + GRPC.Server.HTTP2.Dispatcher.is_bidi_streaming?( + stream_state.path, + connection.servers + ) + + stream_state = %{stream_state | is_bidi_streaming: is_bidi} + + if is_bidi do + Logger.info( + "[handle_headers_frame] Stream #{frame.stream_id} is bidirectional streaming" + ) + end + + # Store stream in connection + streams = Map.put(connection.streams, frame.stream_id, stream_state) + + %{connection | recv_hpack_state: new_hpack_state, streams: streams} + + {:error, reason} -> + connection_error!("HPACK decode error: #{inspect(reason)}") + end + + _stream_state -> + # Trailers for existing stream - just decode but don't create new stream + # This can happen when client sends trailers after we've sent response/error + Logger.info( + "[handle_headers_frame] Ignoring trailers for stream #{frame.stream_id} (stream already processed)" + ) + + case HPAX.decode(frame.fragment, connection.recv_hpack_state) do + {:ok, _headers, new_hpack_state} -> + %{connection | recv_hpack_state: new_hpack_state} + + {:error, reason} -> + Logger.debug( + "[handle_headers_frame] Failed to decode trailers for stream #{frame.stream_id}: #{inspect(reason)}" + ) + + # Continue without updating HPACK state to avoid connection error + connection + end + end + end + + defp apply_settings(settings, []), do: settings + # Convert map to keyword list if needed (for compatibility with grpc_core Frame.Settings) + defp apply_settings(settings, settings_map) when is_map(settings_map) do + apply_settings(settings, Map.to_list(settings_map)) + end + + defp apply_settings(settings, [{:header_table_size, value} | rest]) do + apply_settings(%{settings | header_table_size: value}, rest) + end + + defp apply_settings(settings, [{:enable_push, value} | rest]) do + apply_settings(%{settings | enable_push: value}, rest) + end + + defp apply_settings(settings, [{:max_concurrent_streams, value} | rest]) do + apply_settings(%{settings | max_concurrent_streams: value}, rest) + end + + defp apply_settings(settings, [{:initial_window_size, value} | rest]) do + if value > 2_147_483_647 do + connection_error!("Invalid initial window size (RFC9113§6.5.2)") + end + + apply_settings(%{settings | initial_window_size: value}, rest) + end + + defp apply_settings(settings, [{:max_frame_size, value} | rest]) do + if value < 16_384 or value > 16_777_215 do + connection_error!("Invalid max frame size (RFC9113§6.5.2)") + end + + apply_settings(%{settings | max_frame_size: value}, rest) + end + + defp apply_settings(settings, [{:max_header_list_size, value} | rest]) do + apply_settings(%{settings | max_header_list_size: value}, rest) + end + + defp apply_settings(settings, [_unknown | rest]) do + # Ignore unknown settings per RFC9113§6.5.2 + apply_settings(settings, rest) + end + + defp send_frame(frame, socket, connection) do + Logger.debug( + "[SEND_FRAME] type=#{inspect(frame.__struct__)}, stream=#{Map.get(frame, :stream_id, :none)}, flags=#{inspect(Map.get(frame, :flags, []))}}" + ) + + max_frame_size = connection.remote_settings.max_frame_size + iodata = Frame.serialize(frame, max_frame_size) + + # Send all frame data at once (iodata is already properly formatted) + # Skip sending if socket is nil (test mode) + if socket != nil do + ThousandIsland.Socket.send(socket, iodata) + end + + connection + end + + defp connection_error!(message) do + raise Errors.ConnectionError, message + end + + # Process a complete gRPC request (called when DATA arrives) + # For bidi streaming, this may be called multiple times as messages arrive + defp process_grpc_request(socket, stream_state, connection, end_stream) do + Logger.debug( + "[process_grpc_request] Processing gRPC call: #{stream_state.path} (end_stream=#{end_stream}, bidi=#{stream_state.is_bidi_streaming})" + ) + + # For bidi streaming that's already started, feed new messages to the BidiStream + if stream_state.is_bidi_streaming and stream_state.handler_started do + Logger.debug("[process_grpc_request] Bidi stream already started, feeding new messages") + stream_state = extract_messages_from_buffer(stream_state) + + # Feed messages to the BidiStream Task in {flag, data} format (not decoded yet) + if length(stream_state.message_buffer) > 0 and stream_state.bidi_stream_pid do + Logger.info( + "[Connection] Feeding #{length(stream_state.message_buffer)} messages to BidiStream #{stream_state.stream_id}, pid=#{inspect(stream_state.bidi_stream_pid)}" + ) + + # Convert messages to {flag, data} format expected by GRPC.Server.do_handle_request + messages_as_tuples = + Enum.map(stream_state.message_buffer, fn %{compressed: compressed?, data: data} -> + flag = if compressed?, do: 1, else: 0 + {flag, data} + end) + + GRPC.Server.BidiStream.put_messages(stream_state.bidi_stream_pid, messages_as_tuples) + # Clear both message_buffer and data_buffer after feeding + _stream_state = %{stream_state | message_buffer: [], data_buffer: <<>>} + end + + # If END_STREAM, mark the BidiStream as finished + if end_stream and stream_state.bidi_stream_pid do + Logger.info( + "[Connection] Marking bidi stream #{stream_state.stream_id} as finished, pid=#{inspect(stream_state.bidi_stream_pid)}" + ) + + GRPC.Server.BidiStream.finish(stream_state.bidi_stream_pid) + end + + # Update stream state + streams = Map.put(connection.streams, stream_state.stream_id, stream_state) + %{connection | streams: streams} + else + # First time processing (or non-bidi) - start the handler + process_grpc_request_initial(socket, stream_state, connection, end_stream) + end + end + + defp process_grpc_request_initial(socket, stream_state, connection, end_stream) do + Logger.debug("[process_grpc_request_initial] Starting gRPC handler") + + try do + # Extract messages from data_buffer + stream_state = extract_messages_from_buffer(stream_state) + + # Mark handler as started for bidi + stream_state = + if stream_state.is_bidi_streaming do + %{stream_state | handler_started: true} + else + stream_state + end + + # Update stream in connection before dispatching + connection = %{ + connection + | streams: Map.put(connection.streams, stream_state.stream_id, stream_state) + } + + # Use dispatcher to handle the gRPC call + Logger.debug("[process_grpc_request_initial] Dispatching gRPC call: #{stream_state.path}") + + Logger.debug( + "[process_grpc_request] Message buffer: #{length(stream_state.message_buffer)} messages" + ) + + updated_connection = + case GRPC.Server.HTTP2.Dispatcher.dispatch( + stream_state, + connection.servers, + connection.endpoint, + connection + ) do + {:ok, :streaming_done} -> + # Streaming was handled incrementally via messages, nothing more to send + Logger.debug("[process_grpc_request] Streaming completed") + + # For bidi streaming, update stream_state with full state from process dictionary + # (includes bidi_stream_pid, codec, compressor, rpc) + if stream_state.is_bidi_streaming do + updated_stream_state = Process.get({:bidi_stream_state, stream_state.stream_id}) + + Logger.info( + "[Connection] Updated stream_state for stream #{stream_state.stream_id}: bidi_pid=#{inspect(updated_stream_state.bidi_stream_pid)}" + ) + + %{ + connection + | streams: + Map.put(connection.streams, stream_state.stream_id, updated_stream_state) + } + else + connection + end + + {:ok, response_headers, response_data, trailers} -> + Logger.debug("[process_grpc_request] RPC succeeded, sending response") + # OPTIMIZATION: Send all frames in one syscall using iolist + send_grpc_response_batch( + socket, + stream_state.stream_id, + response_headers, + response_data, + trailers, + connection + ) + + {:error, %GRPC.RPCError{} = error} -> + Logger.error("[process_grpc_request] RPC error: #{inspect(error)}") + # Check if stream still exists (might have been removed by concurrent error handling) + if Map.has_key?(connection.streams, stream_state.stream_id) do + updated_connection = + send_grpc_error(socket, stream_state.stream_id, error, connection) + + # Remove stream immediately after error and return (don't continue processing) + Logger.info( + "[process_grpc_request] Removing stream #{stream_state.stream_id} after error" + ) + + %{ + updated_connection + | streams: Map.delete(updated_connection.streams, stream_state.stream_id) + } + else + Logger.debug( + "[process_grpc_request] Stream #{stream_state.stream_id} already removed, skipping error send" + ) + + connection + end + end + + # Check if stream still exists (it may have been removed by error handling) + stream_exists = Map.has_key?(updated_connection.streams, stream_state.stream_id) + + # For bidi streaming, keep the stream alive to receive more messages + # But if it's the end of the stream (client sent END_STREAM), clean up + if stream_state.is_bidi_streaming and not end_stream and stream_exists do + Logger.debug("[process_grpc_request] Keeping bidi stream #{stream_state.stream_id} alive") + updated_connection + else + # If stream doesn't exist, it was already removed (e.g., by send_grpc_error) + if not stream_exists do + updated_connection + else + # If bidi streaming and END_STREAM, mark as finished + if stream_state.is_bidi_streaming and end_stream do + # Get bidi_stream_pid from updated connection state + updated_stream_state = updated_connection.streams[stream_state.stream_id] + + if updated_stream_state && updated_stream_state.bidi_stream_pid do + Logger.info( + "[Connection] Marking bidi stream #{stream_state.stream_id} as finished (initial END_STREAM), pid=#{inspect(updated_stream_state.bidi_stream_pid)}" + ) + + GRPC.Server.BidiStream.finish(updated_stream_state.bidi_stream_pid) + end + end + + # DON'T remove stream here for ThousandIsland adapter! + # The adapter sends async messages ({:grpc_send_data}, {:grpc_send_trailers}) + # that will be processed later by handle_info in the Handler. + # The stream will be removed when send_grpc_trailers is called (which sends END_STREAM). + # + # For Cowboy adapter (synchronous), the response is sent immediately during dispatch, + # so the stream can be removed here. But for ThousandIsland, we need to wait for + # the async messages to be processed. + # + # TODO: Add a flag to StreamState to indicate if response was fully sent, + # or let the trailers handler remove the stream after sending END_STREAM. + Logger.debug( + "[process_grpc_request] Keeping stream #{stream_state.stream_id} alive for async response (will be removed after trailers)" + ) + + updated_connection + end + end + rescue + e -> + Logger.error("[process_grpc_request] Exception: #{inspect(e)}") + + Logger.error( + "[process_grpc_request] Stacktrace:\n#{Exception.format_stacktrace(__STACKTRACE__)}" + ) + + # Check if stream still exists (might have been removed by concurrent error handling) + if Map.has_key?(connection.streams, stream_state.stream_id) do + updated_connection = + send_grpc_error( + socket, + stream_state.stream_id, + %{status: :internal, message: "Internal error"}, + connection + ) + + %{ + updated_connection + | streams: Map.delete(updated_connection.streams, stream_state.stream_id) + } + else + Logger.debug( + "[process_grpc_request] Stream #{stream_state.stream_id} already removed in rescue, skipping error send" + ) + + connection + end + end + end + + defp send_grpc_trailers(socket, stream_id, trailers, connection) do + # Check if stream still exists (may have been closed by RST_STREAM) + unless Map.has_key?(connection.streams, stream_id) do + Logger.debug( + "[send_grpc_trailers] SKIPPED - stream=#{stream_id} no longer exists (likely cancelled by client)" + ) + + connection + else + Logger.debug( + "[send_grpc_trailers] Sending trailers for stream #{stream_id}: #{inspect(trailers)}" + ) + + trailer_list = Map.to_list(trailers) + + {trailer_block, new_hpack} = + HPAX.encode(:no_store, trailer_list, connection.send_hpack_state) + + headers_frame = %Frame.Headers{ + stream_id: stream_id, + fragment: trailer_block, + end_stream: true, + end_headers: true + } + + send_frame(headers_frame, socket, connection) + + %{connection | send_hpack_state: new_hpack} + end + end + + # OPTIMIZATION: Send headers + data + trailers in one syscall (Bandit-style batching) + defp send_grpc_response_batch( + socket, + stream_id, + response_headers, + response_data, + trailers, + connection + ) do + max_frame_size = connection.remote_settings.max_frame_size + + # Encode headers frame if provided + {headers_iodata, hpack_after_headers} = + if response_headers && response_headers != [] do + header_list = + if is_map(response_headers), do: Map.to_list(response_headers), else: response_headers + + {encoded_headers, new_hpack} = + HPAX.encode(:no_store, header_list, connection.send_hpack_state) + + headers_frame = %Frame.Headers{ + stream_id: stream_id, + fragment: encoded_headers, + end_stream: false, + end_headers: true + } + + {Frame.serialize(headers_frame, max_frame_size), new_hpack} + else + {[], connection.send_hpack_state} + end + + data_frame = %Frame.Data{ + stream_id: stream_id, + data: response_data, + end_stream: false + } + + data_iodata = Frame.serialize(data_frame, max_frame_size) + + trailer_list = Map.to_list(trailers) + {trailer_block, final_hpack} = HPAX.encode(:no_store, trailer_list, hpack_after_headers) + + trailers_frame = %Frame.Headers{ + stream_id: stream_id, + fragment: trailer_block, + end_stream: true, + end_headers: true + } + + trailers_iodata = Frame.serialize(trailers_frame, max_frame_size) + + # Combine all frames into one iolist and send in single syscall + combined_iodata = [headers_iodata, data_iodata, trailers_iodata] + + # Skip sending if socket is nil (test mode) + if socket != nil do + ThousandIsland.Socket.send(socket, combined_iodata) + end + + %{connection | send_hpack_state: final_hpack} + end + + defp extract_messages_from_buffer(stream_state) do + # Extract 5-byte length-prefixed messages from data_buffer + Logger.debug( + "[Connection] Extracting messages from data_buffer (#{byte_size(stream_state.data_buffer)} bytes)" + ) + + {messages, remaining} = extract_messages(stream_state.data_buffer, []) + # Reverse since we build list backwards for performance + extracted_count = length(messages) + + Logger.debug( + "[Connection] Extracted #{extracted_count} messages, #{byte_size(remaining)} bytes remaining" + ) + + %{stream_state | message_buffer: Enum.reverse(messages), data_buffer: remaining} + end + + # Optimized: prepend instead of append for O(1) instead of O(n) + defp extract_messages( + <>, + acc + ) do + message = %{compressed: compressed == 1, data: payload} + extract_messages(rest, [message | acc]) + end + + defp extract_messages(buffer, acc) do + # Not enough data for a complete message + {acc, buffer} + end + + # Made public so Handler can call it when deadline exceeded during send_reply + def send_grpc_error(socket, stream_id, error, connection) do + status = Map.get(error, :status, :unknown) + message = Map.get(error, :message, "Unknown error") + + # Convert atom status to integer code + status_code = if is_atom(status), do: apply(GRPC.Status, status, []), else: status + + stream_state = Map.get(connection.streams, stream_id) + + if !stream_state do + Logger.debug("[SEND_GRPC_ERROR] stream=#{stream_id} SKIPPED - stream not found") + connection + else + if stream_state.error_sent do + Logger.debug("[SEND_GRPC_ERROR] stream=#{stream_id} SKIPPED - error already sent") + connection + else + headers_sent = stream_state.headers_sent + end_stream_received = stream_state.end_stream_received + + Logger.debug( + "[SEND_GRPC_ERROR] stream=#{stream_id}, status=#{status_code}, message=#{message}, headers_sent=#{headers_sent}, end_stream_received=#{end_stream_received}" + ) + + # RFC 7540 Section 5.1: If we already sent END_STREAM (via headers_sent=true in previous response), + # the stream is "closed" and we CANNOT send more frames (except PRIORITY) + # In this case, just remove the stream and don't send anything + if headers_sent && end_stream_received do + Logger.debug( + "[SEND_GRPC_ERROR] stream=#{stream_id} SKIPPED - stream is fully closed (both sides sent END_STREAM)" + ) + + %{connection | streams: Map.delete(connection.streams, stream_id)} + else + # Check if headers were already sent for this stream + # ALWAYS send headers first if not sent yet, even if stream received END_STREAM + # HTTP/2 requires :status pseudo-header before trailers + updated_connection = + if !headers_sent do + Logger.debug( + "[SEND_GRPC_ERROR] Sending HTTP/2 headers first for stream=#{stream_id}" + ) + + headers = %{":status" => "200", "content-type" => "application/grpc+proto"} + send_headers(socket, stream_id, headers, connection) + + # Verify stream still exists after send_headers (may have been closed by RST_STREAM) + if Map.has_key?(connection.streams, stream_id) do + # Mark headers as sent in the stream state + updated_stream = %{stream_state | headers_sent: true} + updated_conn = put_in(connection.streams[stream_id], updated_stream) + + trailers = %{ + "grpc-status" => to_string(status_code), + "grpc-message" => message + } + + Logger.debug( + "[SEND_GRPC_ERROR] Sending TRAILERS with END_STREAM for stream=#{stream_id}" + ) + + send_grpc_trailers(socket, stream_id, trailers, updated_conn) + else + Logger.debug( + "[SEND_GRPC_ERROR] SKIPPED trailers - stream=#{stream_id} was closed after sending headers" + ) + + connection + end + else + # Headers already sent, just send trailers with error + trailers = %{ + "grpc-status" => to_string(status_code), + "grpc-message" => message + } + + Logger.debug( + "[SEND_GRPC_ERROR] Sending TRAILERS with END_STREAM for stream=#{stream_id}" + ) + + send_grpc_trailers(socket, stream_id, trailers, connection) + end + + # Verify stream still exists before marking error_sent (may have been closed by RST_STREAM) + updated_connection = + if Map.has_key?(updated_connection.streams, stream_id) do + # Mark that we already sent error (RFC 7540: after END_STREAM, the stream is closed) + update_in(updated_connection.streams[stream_id], fn s -> + if s, do: %{s | error_sent: true}, else: nil + end) + else + Logger.debug( + "[SEND_GRPC_ERROR] SKIPPED marking error_sent - stream=#{stream_id} was already closed" + ) + + updated_connection + end + + # RFC 7540: After sending END_STREAM, the stream transitions to "closed" + # Remove immediately to avoid processing more messages on this stream + Logger.warning("[REMOVE_STREAM] stream=#{stream_id} - removed after sending error") + %{updated_connection | streams: Map.delete(updated_connection.streams, stream_id)} + end + end + end + end + + # Check if buffer has at least one complete gRPC message + # gRPC message format: 1 byte compressed flag + 4 bytes length + N bytes data + defp has_complete_message?(buffer) when byte_size(buffer) < 5, do: false + + defp has_complete_message?(<<_compressed::8, length::32, rest::binary>>) do + byte_size(rest) >= length + end +end diff --git a/grpc_server/lib/grpc/server/http2/dispatcher.ex b/grpc_server/lib/grpc/server/http2/dispatcher.ex new file mode 100644 index 000000000..aa40dcf3a --- /dev/null +++ b/grpc_server/lib/grpc/server/http2/dispatcher.ex @@ -0,0 +1,829 @@ +defmodule GRPC.Server.HTTP2.Dispatcher do + @moduledoc """ + Dispatches gRPC calls to registered services. + + This module: + - Parses gRPC path ("/package.Service/Method") + - Looks up service implementation from servers registry + - Decodes protobuf request messages + - Calls service handler functions + - Encodes protobuf response messages + - Handles streaming (client/server/bidirectional) + """ + + require Logger + + alias GRPC.Server.Cache + alias GRPC.Server.HTTP2.StreamState + alias GRPC.RPCError + + # Inline hot path functions + @compile {:inline, + parse_path: 1, + lookup_server: 2, + lookup_rpc: 2, + get_codec: 2, + get_compressor: 2, + encode_response: 4, + decode_messages: 4} + + @doc """ + Dispatches a gRPC call to the appropriate service. + + ## Parameters + - `stream_state`: HTTP/2 stream with decoded headers + - `servers`: Map of service name => server module + - `endpoint`: The endpoint module (for telemetry, etc) + + ## Returns + - `{:ok, response_headers, response_data, trailers}` - Success + - `{:error, rpc_error}` - gRPC error to send to client + """ + @spec dispatch(StreamState.t(), map(), atom(), GRPC.Server.HTTP2.Connection.t()) :: + {:ok, list(), binary(), map()} | {:error, GRPC.RPCError.t()} + def dispatch(%StreamState{} = stream_state, servers, endpoint, connection) do + Logger.debug( + "[dispatch] path=#{stream_state.path}, messages=#{length(stream_state.message_buffer)}" + ) + + # Check deadline BEFORE processing - if exceeded, return error immediately + if StreamState.deadline_exceeded?(stream_state) do + now = System.monotonic_time(:microsecond) + + Logger.debug( + "[dispatch] Deadline exceeded for path=#{stream_state.path}, deadline=#{stream_state.deadline}, now=#{now}, diff=#{now - (stream_state.deadline || now)}us" + ) + + {:error, GRPC.RPCError.exception(status: :deadline_exceeded, message: "Deadline exceeded")} + else + with {:ok, service_name, method_name} <- parse_path(stream_state.path), + {:ok, server} <- lookup_server(servers, service_name), + {:ok, rpc} <- lookup_rpc(server, method_name), + {:ok, codec} <- get_codec(server, stream_state.content_type), + {:ok, compressor} <- get_compressor(server, stream_state.metadata), + {:ok, requests} <- + decode_messages(stream_state.message_buffer, rpc, codec, compressor), + {:ok, response} <- + call_service( + server, + rpc, + method_name, + requests, + stream_state, + endpoint, + codec, + compressor, + connection + ) do + Logger.info("[dispatch] Encoding response") + # Get custom headers/trailers from process dictionary + # (set by handler during execution via set_headers/set_trailers) + custom_headers = Process.get({:grpc_custom_headers, stream_state.stream_id}, %{}) + custom_trailers = Process.get({:grpc_custom_trailers, stream_state.stream_id}, %{}) + + # Update stream_state with custom headers/trailers + updated_stream_state = %{ + stream_state + | custom_headers: custom_headers, + custom_trailers: custom_trailers + } + + # Cleanup process dictionary + Process.delete({:grpc_custom_headers, stream_state.stream_id}) + Process.delete({:grpc_custom_trailers, stream_state.stream_id}) + + # Encode response(s) - could be single response or list for streaming + # Pass updated stream_state to include custom headers/trailers + encode_responses(response, codec, compressor, rpc, updated_stream_state) + else + {:error, %GRPC.RPCError{} = error} -> + error + + {:error, _reason} = err -> + Logger.error("Dispatch error: #{inspect(err)}") + {:error, GRPC.RPCError.exception(status: :internal, message: "Internal server error")} + end + end + end + + @doc """ + Parses gRPC path into service and method names. + + Examples: + - "/helloworld.Greeter/SayHello" → {"helloworld.Greeter", "SayHello"} + - "/package.subpackage.Service/Method" → {"package.subpackage.Service", "Method"} + """ + @spec parse_path(String.t()) :: {:ok, String.t(), String.t()} | {:error, RPCError.t()} + def parse_path("/" <> rest) do + case String.split(rest, "/", parts: 2) do + [service_name, method_name] when service_name != "" and method_name != "" -> + {:ok, service_name, method_name} + + _ -> + {:error, RPCError.exception(status: :unimplemented, message: "Invalid path format")} + end + end + + def parse_path(_) do + {:error, RPCError.exception(status: :unimplemented, message: "Path must start with /")} + end + + @doc """ + Checks if a gRPC path corresponds to a bidirectional streaming RPC. + + Returns `true` if both request and response are streaming, `false` otherwise. + """ + @spec is_bidi_streaming?(String.t(), map()) :: boolean() + def is_bidi_streaming?(path, servers) do + with {:ok, service_name, method_name} <- parse_path(path), + {:ok, server} <- lookup_server(servers, service_name), + {:ok, rpc} <- lookup_rpc(server, method_name) do + {_name, {_req_mod, req_stream?}, {_res_mod, res_stream?}, _opts} = rpc + req_stream? and res_stream? + else + _ -> false + end + end + + ## Private Functions + defp lookup_server(servers, service_name) do + case Map.get(servers, service_name) do + nil -> + {:error, + RPCError.exception(status: :unimplemented, message: "Service not found: #{service_name}")} + + server -> + {:ok, server} + end + end + + defp lookup_rpc(server, method_name) do + # Use cache to find RPC definition + case GRPC.Server.Cache.find_rpc(server, method_name) do + nil -> + {:error, + RPCError.exception( + status: :unimplemented, + message: "Method not found: #{method_name}" + )} + + rpc -> + {:ok, rpc} + end + end + + defp get_codec(server, content_type) do + # Extract codec subtype from content-type + # "application/grpc+proto" → "proto" + # "application/grpc+json" → "json" + subtype = + case String.split(content_type, "+", parts: 2) do + ["application/grpc", subtype] -> subtype + ["application/grpc"] -> "proto" + _ -> "proto" + end + + case Cache.find_codec(server, subtype) do + nil -> + {:error, + RPCError.exception(status: :unimplemented, message: "Codec not found: #{subtype}")} + + codec -> + {:ok, codec} + end + end + + defp get_compressor(server, metadata) do + # Check grpc-encoding header + encoding = Map.get(metadata, "grpc-encoding", "identity") + + if encoding == "identity" do + {:ok, nil} + else + case Cache.find_compressor(server, encoding) do + nil -> + {:error, + RPCError.exception( + status: :unimplemented, + message: "Compressor not found: #{encoding}" + )} + + compressor -> + {:ok, compressor} + end + end + end + + defp decode_messages(message_buffer, rpc, codec, compressor) do + # Extract request type from RPC definition + # RPC format: {name, {request_module, is_stream?}, {reply_module, is_stream?}, options} + {_name, {request_module, _is_stream?}, _reply, _options} = rpc + + try do + messages = + Enum.map(message_buffer, fn %{compressed: compressed?, data: data} -> + # Decompress if needed + data = + if compressed? and compressor do + compressor.decompress(data) + else + data + end + + # Decode protobuf with the request module + codec.decode(data, request_module) + end) + + {:ok, messages} + rescue + e -> + Logger.error("Failed to decode messages: #{inspect(e)}") + + {:error, + RPCError.exception(status: :invalid_argument, message: "Invalid request message")} + end + end + + defp call_unary(server, func_name, request, stream) do + Logger.info("[call_unary] Calling #{inspect(server)}.#{func_name}") + + # Check if function is implemented + if function_exported?(server, func_name, 2) do + # Accumulate base headers (don't send yet - handler may add custom headers) + base_headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + # Add grpc-encoding if there's a compressor + base_headers = + if stream.compressor do + Map.put(base_headers, "grpc-encoding", stream.compressor.name()) + else + base_headers + end + + GRPC.Server.set_headers(stream, base_headers) + + try do + # Call handler and get response + response = apply(server, func_name, [request, stream]) + Logger.info("[call_unary] Response received, sending via send_reply") + + # Send response using async message (this will send accumulated headers first) + # :noreply means GRPC.Stream.run() already sent the response + if response != :noreply do + GRPC.Server.send_reply(stream, response) + end + + # Send trailers at the end with END_STREAM + # Custom trailers already accumulated in Handler via set_resp_trailers + trailers = %{"grpc-status" => "0"} + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + + # Return special marker for async handling (like server_streaming) + {:ok, :streaming_done} + rescue + e in GRPC.RPCError -> + # Send error as trailers (headers already accumulated, will be sent with trailers) + _stream_id = stream.payload.stream_id + + error_trailers = %{ + "grpc-status" => "#{e.status}", + "grpc-message" => e.message || "" + } + + # Send trailers with error (will send accumulated headers first if not sent yet) + # Custom trailers already accumulated in Handler via set_resp_trailers + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + + # Return streaming_done (error already sent) + {:ok, :streaming_done} + + e -> + Logger.error("Handler error: #{Exception.message(e)}") + + error_trailers = %{ + # UNKNOWN + "grpc-status" => "2", + "grpc-message" => Exception.message(e) + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + + # Return streaming_done (error already sent) + {:ok, :streaming_done} + end + else + # Function not implemented + Logger.error("Function #{inspect(server)}.#{func_name}/2 is not implemented") + + # Send required HTTP/2 headers first + headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + GRPC.Server.Adapters.ThousandIsland.send_headers(stream.payload, headers) + + # Then send error trailers + error_trailers = %{ + # UNIMPLEMENTED + "grpc-status" => "12", + "grpc-message" => "Method not implemented" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + end + end + + defp call_client_streaming(server, func_name, requests, stream) do + # Check if function is implemented + if function_exported?(server, func_name, 2) do + try do + # Accumulate base headers (don't send yet - handler may add custom headers) + base_headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + # Add grpc-encoding if there's a compressor + base_headers = + if stream.compressor do + Map.put(base_headers, "grpc-encoding", stream.compressor.name()) + else + base_headers + end + + # Accumulate base headers without sending + GRPC.Server.set_headers(stream, base_headers) + + # Convert list to stream + request_enum = Enum.into(requests, []) + response = apply(server, func_name, [request_enum, stream]) + + # Send response using async message (this will send accumulated headers first) + GRPC.Server.send_reply(stream, response) + + # Send trailers at the end with END_STREAM + # Custom trailers already accumulated in Handler via set_resp_trailers + trailers = %{"grpc-status" => "0"} + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + + # Return special marker for async handling (like server_streaming) + {:ok, :streaming_done} + rescue + e in GRPC.RPCError -> + # Send error as trailers (headers already accumulated, will be sent with trailers) + error_trailers = %{ + "grpc-status" => "#{e.status}", + "grpc-message" => e.message || "" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + + e -> + Logger.error("Handler error: #{Exception.message(e)}") + + error_trailers = %{ + # UNKNOWN + "grpc-status" => "2", + "grpc-message" => Exception.message(e) + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + end + else + # Function not implemented + Logger.error("Function #{inspect(server)}.#{func_name}/2 is not implemented") + + # Send required HTTP/2 headers first + headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + GRPC.Server.Adapters.ThousandIsland.send_headers(stream.payload, headers) + + # Then send error trailers + error_trailers = %{ + # UNIMPLEMENTED + "grpc-status" => "12", + "grpc-message" => "Method not implemented" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + end + end + + defp call_server_streaming(server, func_name, request, stream) do + # Check if function is implemented + if function_exported?(server, func_name, 2) do + try do + # Accumulate base headers (don't send yet - handler may add custom headers) + base_headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + # Add grpc-encoding if there's a compressor + base_headers = + if stream.compressor do + Map.put(base_headers, "grpc-encoding", stream.compressor.name()) + else + base_headers + end + + GRPC.Server.set_headers(stream, base_headers) + + # Handler calls GRPC.Server.send_reply for each response + apply(server, func_name, [request, stream]) + + # Send trailers at the end with END_STREAM + # Custom trailers already accumulated in Handler via set_resp_trailers + trailers = %{"grpc-status" => "0"} + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + + # Return special marker for streaming (not data to send) + {:ok, :streaming_done} + rescue + e in GRPC.RPCError -> + # Send error as trailers + error_trailers = %{ + "grpc-status" => "#{e.status}", + "grpc-message" => e.message || "" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + + e -> + Logger.error("Handler error: #{Exception.message(e)}") + + error_trailers = %{ + # UNKNOWN + "grpc-status" => "2", + "grpc-message" => Exception.message(e) + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + end + else + # Function not implemented + Logger.error("Function #{inspect(server)}.#{func_name}/2 is not implemented") + + # Send required HTTP/2 headers first + headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + GRPC.Server.Adapters.ThousandIsland.send_headers(stream.payload, headers) + + # Then send error trailers + error_trailers = %{ + # UNIMPLEMENTED + "grpc-status" => "12", + "grpc-message" => "Method not implemented" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + end + end + + defp call_bidi_streaming(server, rpc, func_name, stream_state, stream, _connection) do + # Check if function is implemented + if function_exported?(server, func_name, 2) do + stream_id = stream_state.stream_id + message_buffer = stream_state.message_buffer + + Logger.info( + "[call_bidi_streaming] Starting bidi stream #{stream_id} with #{length(message_buffer)} initial requests" + ) + + try do + # Mark as streaming mode so send_headers will send immediately + Process.put(:grpc_streaming_mode, true) + + # Convert initial messages to {flag, data} format + initial_messages = + Enum.map(message_buffer, fn %{compressed: compressed?, data: data} -> + flag = if compressed?, do: 1, else: 0 + {flag, data} + end) + + # Start BidiStream Task with initial messages in {flag, data} format + {:ok, bidi_pid} = GRPC.Server.BidiStream.start_link(stream_id, initial_messages) + + Logger.info( + "[call_bidi_streaming] BidiStream task started for stream #{stream_id}, pid=#{inspect(bidi_pid)}" + ) + + # Monitor the BidiStream task to detect early termination + ref = Process.monitor(bidi_pid) + Logger.debug("[call_bidi_streaming] Monitoring BidiStream task with ref #{inspect(ref)}") + + # Store the BidiStream PID in stream_state for later use + # (when more DATA frames arrive) + stream_state = stream.payload.stream_state + updated_stream_state = %{stream_state | bidi_stream_pid: bidi_pid, handler_started: true} + + # CRITICAL: Store in process dictionary for immediate access + # We can't use send() because handler is blocked in this dispatch call! + _handler_pid = stream.payload.handler_pid + + Logger.info( + "[Dispatcher] Storing bidi_pid #{inspect(bidi_pid)} in process dictionary for stream #{stream_id}" + ) + + # Store codec, compressor, and RPC in stream_state for decoding subsequent messages + updated_stream_state = %{ + updated_stream_state + | codec: stream.codec, + compressor: stream.compressor, + rpc: stream.rpc + } + + # Store both PID and full state in process dictionary + Process.put({:bidi_stream_pid, stream_id}, bidi_pid) + Process.put({:bidi_stream_state, stream_id}, updated_stream_state) + + # Accumulate base headers (don't send yet - handler may add custom headers) + base_headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + # Add grpc-encoding if there's a compressor + base_headers = + if stream.compressor do + Map.put(base_headers, "grpc-encoding", stream.compressor.name()) + else + base_headers + end + + # Accumulate base headers without sending + GRPC.Server.set_headers(stream, base_headers) + + # Update the stream's payload to include bidi_stream_pid so adapter.reading_stream() can access it + updated_stream = %{ + stream + | payload: %{stream.payload | stream_state: updated_stream_state} + } + + # Get handler_pid to send task monitoring info + handler_pid = stream.payload.handler_pid + + # CRITICAL: Run handler in a separate supervised task to not block the connection handler + # The connection handler MUST continue processing incoming DATA frames + # and feed them to the BidiStream while the handler is consuming messages + # Use async_nolink to avoid taking down the supervisor if task crashes + task = + Task.Supervisor.async_nolink(GRPC.Server.StreamTaskSupervisor, fn -> + try do + # Use GRPC.Server.call to properly handle the request + # This ensures the reading_stream is created correctly via adapter.reading_stream + result = GRPC.Server.call(server, updated_stream, rpc, func_name) + Logger.info("[call_bidi_streaming] Handler returned: #{inspect(result)}") + + case result do + {:ok, stream} -> + # Send trailers with success status + # Custom trailers already accumulated in Handler via set_resp_trailers + trailers = %{"grpc-status" => "0"} + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + + {:error, error} -> + Logger.error("[call_bidi_streaming] Handler error result: #{inspect(error)}") + + trailers = %{ + "grpc-status" => "#{error.status || 2}", + "grpc-message" => error.message || "Handler error" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + end + rescue + e in GRPC.RPCError -> + Logger.error("[call_bidi_streaming] Handler RPC Error: #{inspect(e)}") + + trailers = %{ + "grpc-status" => "#{e.status}", + "grpc-message" => e.message || "" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + + e -> + Logger.error( + "[call_bidi_streaming] Handler error: #{Exception.message(e)}\n#{Exception.format_stacktrace(__STACKTRACE__)}" + ) + + trailers = %{ + "grpc-status" => "2", + "grpc-message" => Exception.message(e) + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + end + end) + + # Register task in handler for monitoring + # Handler will monitor the task and send error trailers if it crashes unexpectedly + send(handler_pid, {:register_stream_task, stream_id, task.pid, task.ref}) + + # Return special marker for streaming (not data to send) + {:ok, :streaming_done} + rescue + e in GRPC.RPCError -> + Logger.error("[call_bidi_streaming] RPC Error: #{inspect(e)}") + + trailers = %{ + "grpc-status" => "#{e.status}", + "grpc-message" => e.message || "" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + {:ok, :streaming_done} + + e -> + Logger.error( + "[call_bidi_streaming] Handler error: #{Exception.message(e)}\n#{Exception.format_stacktrace(__STACKTRACE__)}" + ) + + trailers = %{ + "grpc-status" => "2", + "grpc-message" => Exception.message(e) + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, trailers) + {:ok, :streaming_done} + end + else + # Function not implemented + Logger.error("Function #{inspect(server)}.#{func_name}/2 is not implemented") + + # Send required HTTP/2 headers first + headers = %{ + ":status" => "200", + "content-type" => "application/grpc+proto" + } + + GRPC.Server.Adapters.ThousandIsland.send_headers(stream.payload, headers) + + # Then send error trailers + error_trailers = %{ + # UNIMPLEMENTED + "grpc-status" => "12", + "grpc-message" => "Method not implemented" + } + + GRPC.Server.Adapters.ThousandIsland.send_trailers(stream.payload, error_trailers) + {:ok, :streaming_done} + end + end + + # Encode responses - handles both unary (single response) and streaming (special marker) + defp encode_responses(:streaming_done, _codec, _compressor, _rpc, _stream_state) do + # Streaming was already handled incrementally, nothing to encode + {:ok, :streaming_done} + end + + defp encode_responses(response, codec, compressor, rpc, stream_state) do + {_name, {_req_mod, _req_stream?}, {_res_mod, res_stream?}, _opts} = rpc + + if res_stream? do + # This path shouldn't be reached anymore (streaming handled in call_server_streaming) + Logger.debug("Unexpected: encode_responses called with streaming response") + {:ok, :streaming_done} + else + # Unary - encode single response with custom headers/trailers from stream_state + encode_response(response, codec, compressor, stream_state) + end + end + + # Encode a single message (without headers/trailers) + defp encode_message(response, codec, compressor) do + # Encode protobuf + encoded = codec.encode(response) + + # Compress if needed + {compressed_flag, data_binary} = + if compressor do + binary_data = IO.iodata_to_binary(encoded) + {1, compressor.compress(binary_data)} + else + {0, IO.iodata_to_binary(encoded)} + end + + # Calculate length + length = byte_size(data_binary) + + # Build 5-byte length-prefixed message + <> + end + + defp encode_response(response, codec, compressor, stream_state) do + try do + # Encode the message using the helper + message_data = encode_message(response, codec, compressor) + + # Build response headers + base_headers = [ + {":status", "200"}, + {"content-type", "application/grpc+proto"} + ] + + # Only add grpc-encoding if there's actual compression (not identity) + headers = + if compressor do + base_headers ++ [{"grpc-encoding", compressor.name()}] + else + base_headers + end + + # Merge custom headers from handler (for custom_metadata test) + headers = + if map_size(stream_state.custom_headers) > 0 do + headers ++ Map.to_list(stream_state.custom_headers) + else + headers + end + + # Build trailers (mandatory grpc-status) + trailers = %{ + "grpc-status" => "0" + } + + # Merge custom trailers from handler (for custom_metadata test) + trailers = Map.merge(trailers, stream_state.custom_trailers) + + {:ok, headers, message_data, trailers} + rescue + e -> + Logger.error("Failed to encode response: #{inspect(e)}") + {:error, RPCError.exception(status: :internal, message: "Failed to encode response")} + end + end + + defp call_service( + server, + rpc, + method_name, + requests, + stream_state, + endpoint, + codec, + compressor, + connection + ) do + {_name, {req_mod, req_stream?}, {res_mod, res_stream?}, _opts} = rpc + func_name = Macro.underscore(method_name) |> String.to_atom() + + # Determine gRPC type based on streaming flags + grpc_type = GRPC.Service.grpc_type(rpc) + + # Create a payload struct with metadata and handler info for streaming + payload = %{ + headers: stream_state.metadata, + stream_state: stream_state, + handler_pid: stream_state.handler_pid, + stream_id: stream_state.stream_id + } + + grpc_stream = %GRPC.Server.Stream{ + server: server, + endpoint: endpoint, + grpc_type: grpc_type, + request_mod: req_mod, + response_mod: res_mod, + rpc: rpc, + codec: codec, + compressor: compressor, + adapter: GRPC.Server.Adapters.ThousandIsland, + payload: payload + } + + case {req_stream?, res_stream?} do + {false, false} -> + [request] = requests + call_unary(server, func_name, request, grpc_stream) + + {true, false} -> + call_client_streaming(server, func_name, requests, grpc_stream) + + {false, true} -> + [request] = requests + call_server_streaming(server, func_name, request, grpc_stream) + + {true, true} -> + call_bidi_streaming(server, rpc, func_name, stream_state, grpc_stream, connection) + end + end +end diff --git a/grpc_server/lib/grpc/server/http2/stream_state.ex b/grpc_server/lib/grpc/server/http2/stream_state.ex new file mode 100644 index 000000000..78d8b45d0 --- /dev/null +++ b/grpc_server/lib/grpc/server/http2/stream_state.ex @@ -0,0 +1,360 @@ +defmodule GRPC.Server.HTTP2.StreamState do + @moduledoc """ + Manages individual HTTP/2 stream state for gRPC requests. + + Each gRPC call is handled as a separate HTTP/2 stream. This module: + - Decodes HTTP/2 headers into gRPC metadata + - Accumulates DATA frames into gRPC messages + - Handles 5-byte length-prefixed message framing + - Manages stream lifecycle (HEADERS -> DATA -> trailers) + """ + + alias GRPC.Transport.HTTP2.Frame + + @type stream_id :: pos_integer() + @type state :: :idle | :open | :half_closed_local | :half_closed_remote | :closed + + @type t :: %__MODULE__{ + stream_id: stream_id(), + state: state(), + # gRPC request info + path: String.t() | nil, + method: String.t() | nil, + authority: String.t() | nil, + content_type: String.t() | nil, + metadata: map(), + # Deadline (absolute time in microseconds when request expires) + deadline: integer() | nil, + # Message buffering + data_buffer: binary(), + message_buffer: [map()], + # Flow control + window_size: integer(), + # Trailers + trailers: map(), + # Handler PID for streaming + handler_pid: pid() | nil, + # Custom headers/trailers set by handler + custom_headers: map(), + custom_trailers: map(), + # Flag to track if response headers were sent + headers_sent: boolean(), + # Bidirectional streaming flag + is_bidi_streaming: boolean(), + # Flag to track if handler was already started (for bidi) + handler_started: boolean(), + # PID of the BidiStream task (for bidi streaming only) + bidi_stream_pid: pid() | nil, + # Codec, compressor, and RPC info for decoding subsequent messages + codec: module() | nil, + compressor: module() | nil, + rpc: tuple() | nil, + # Flag to track if client sent END_STREAM (stream half-closed remote) + end_stream_received: boolean(), + # Flag to track if we already sent an error response (prevent duplicates) + error_sent: boolean() + } + + defstruct stream_id: nil, + state: :idle, + path: nil, + method: nil, + authority: nil, + content_type: nil, + metadata: %{}, + data_buffer: <<>>, + message_buffer: [], + window_size: 65_535, + trailers: %{}, + is_bidi_streaming: false, + handler_started: false, + bidi_stream_pid: nil, + handler_pid: nil, + custom_headers: %{}, + custom_trailers: %{}, + headers_sent: false, + codec: nil, + compressor: nil, + rpc: nil, + deadline: nil, + end_stream_received: false, + error_sent: false + + # Parses grpc-timeout header and converts to absolute deadline + # Format: "1H" (hour), "1M" (minute), "1S" (second), "1m" (millisecond), "1u" (microsecond), "1n" (nanosecond) + defp parse_timeout_to_deadline(nil), do: nil + + defp parse_timeout_to_deadline(timeout_str) do + case Integer.parse(timeout_str) do + {value, "H"} -> System.monotonic_time(:microsecond) + value * 3_600_000_000 + {value, "M"} -> System.monotonic_time(:microsecond) + value * 60_000_000 + {value, "S"} -> System.monotonic_time(:microsecond) + value * 1_000_000 + {value, "m"} -> System.monotonic_time(:microsecond) + value * 1_000 + {value, "u"} -> System.monotonic_time(:microsecond) + value + {value, "n"} -> System.monotonic_time(:microsecond) + div(value, 1_000) + _ -> nil + end + end + + @doc """ + Creates a new stream state. + """ + @spec new(stream_id(), integer()) :: t() + def new(stream_id, initial_window_size \\ 65_535) do + %__MODULE__{ + stream_id: stream_id, + state: :idle, + window_size: initial_window_size + } + end + + @doc """ + Creates a stream state from decoded HTTP/2 headers. + Extracts pseudo-headers and metadata from header list. + """ + @spec from_headers(stream_id(), list({String.t(), String.t()}), integer()) :: t() + def from_headers(stream_id, headers, initial_window_size \\ 65_535) do + stream = new(stream_id, initial_window_size) + + {path, method, authority, content_type, timeout_str, metadata} = + Enum.reduce(headers, {nil, nil, nil, nil, nil, %{}}, fn + {":path", value}, {_, m, a, ct, t, meta} -> + {value, m, a, ct, t, meta} + + {":method", value}, {p, _, a, ct, t, meta} -> + {p, value, a, ct, t, meta} + + {":authority", value}, {p, m, _, ct, t, meta} -> + {p, m, value, ct, t, meta} + + {"content-type", value}, {p, m, a, _, t, meta} -> + {p, m, a, value, t, meta} + + {"grpc-timeout", value}, {p, m, a, ct, _, meta} -> + {p, m, a, ct, value, meta} + + # Skip other pseudo-headers + {":" <> _rest, _value}, acc -> + acc + + {key, value}, {p, m, a, ct, t, meta} -> + {p, m, a, ct, t, Map.put(meta, key, value)} + end) + + # Parse timeout and calculate deadline + deadline = parse_timeout_to_deadline(timeout_str) + + if timeout_str do + require Logger + + Logger.info( + "[StreamState] Parsed grpc-timeout: #{timeout_str} -> deadline: #{inspect(deadline)} (now: #{System.monotonic_time(:microsecond)})" + ) + end + + %{ + stream + | path: path, + method: method, + authority: authority, + content_type: content_type || "application/grpc+proto", + metadata: metadata, + deadline: deadline, + state: :open + } + end + + @doc """ + Checks if the stream's deadline has been exceeded. + Returns true if the deadline exists and has passed, false otherwise. + """ + @spec deadline_exceeded?(t()) :: boolean() + def deadline_exceeded?(%__MODULE__{deadline: nil}), do: false + + def deadline_exceeded?(%__MODULE__{deadline: deadline}) do + System.monotonic_time(:microsecond) > deadline + end + + @doc """ + Processes HEADERS frame and extracts gRPC metadata. + + HTTP/2 pseudo-headers are decoded into gRPC request fields: + - `:method` → "POST" (gRPC always uses POST) + - `:path` → "/package.Service/Method" + - `:authority` → "host:port" + - `content-type` → "application/grpc+proto" etc + + Other headers become gRPC metadata. + """ + @spec handle_headers(t(), Frame.Headers.t()) :: {:ok, t()} | {:error, term()} + def handle_headers(stream, %Frame.Headers{} = headers) do + case stream.state do + :idle -> + decode_headers(stream, headers) + + :half_closed_remote -> + # Trailers received + decode_trailers(stream, headers) + + _other -> + {:error, :protocol_error} + end + end + + @doc """ + Processes DATA frame and accumulates gRPC messages. + + gRPC uses 5-byte length-prefixed framing: + - Byte 0: Compression flag (0 = no compression, 1 = compressed) + - Bytes 1-4: Message length (big-endian uint32) + - Bytes 5+: Message payload + + Multiple messages can arrive in a single DATA frame or be split across frames. + """ + @spec handle_data(t(), Frame.Data.t()) :: {:ok, t(), [map()]} | {:error, term()} + def handle_data(stream, %Frame.Data{} = data) do + case stream.state do + :open -> + process_data(stream, data) + + _other -> + {:error, :stream_closed} + end + end + + @doc """ + Updates stream state based on end_stream flag. + """ + @spec maybe_close_stream(t(), boolean()) :: t() + def maybe_close_stream(stream, end_stream) do + if end_stream do + new_state = + case stream.state do + :open -> :half_closed_remote + :half_closed_local -> :closed + other -> other + end + + %{stream | state: new_state} + else + stream + end + end + + @doc """ + Updates stream window size for flow control. + """ + @spec update_window(t(), integer()) :: {:ok, t()} | {:error, :flow_control_error} + def update_window(stream, increment) do + case GRPC.Transport.HTTP2.FlowControl.update_window(stream.window_size, increment) do + {:ok, new_size} -> + {:ok, %{stream | window_size: new_size}} + + {:error, reason} -> + {:error, reason} + end + end + + @doc """ + Checks if stream has enough window to send data. + """ + @spec has_window?(t(), integer()) :: boolean() + def has_window?(stream, size) do + stream.window_size >= size + end + + ## Private Functions + + defp decode_headers(stream, headers) do + headers_map = + Enum.reduce(headers.headers, %{}, fn {name, value}, acc -> + Map.put(acc, name, value) + end) + + # Extract pseudo-headers + method = Map.get(headers_map, ":method") + path = Map.get(headers_map, ":path") + authority = Map.get(headers_map, ":authority") + content_type = Map.get(headers_map, "content-type") + + # Validate gRPC request + with :ok <- validate_method(method), + :ok <- validate_path(path), + :ok <- validate_content_type(content_type) do + # Extract metadata (non-pseudo headers) + metadata = + headers_map + |> Enum.filter(fn {name, _value} -> !String.starts_with?(name, ":") end) + |> Enum.into(%{}) + + stream = %{ + stream + | state: :open, + method: method, + path: path, + authority: authority, + content_type: content_type, + metadata: metadata + } + + stream = maybe_close_stream(stream, headers.end_stream) + + {:ok, stream} + else + {:error, reason} -> {:error, reason} + end + end + + defp decode_trailers(stream, headers) do + trailers = + Enum.reduce(headers.headers, %{}, fn {name, value}, acc -> + Map.put(acc, name, value) + end) + + stream = %{stream | trailers: trailers} + stream = maybe_close_stream(stream, headers.end_stream) + + {:ok, stream} + end + + defp validate_method("POST"), do: :ok + defp validate_method(_), do: {:error, :invalid_method} + + defp validate_path("/" <> _rest), do: :ok + defp validate_path(_), do: {:error, :invalid_path} + + defp validate_content_type("application/grpc" <> _), do: :ok + defp validate_content_type(_), do: {:error, :invalid_content_type} + + defp process_data(stream, data) do + # Append to buffer + buffer = stream.data_buffer <> data.data + + # Extract complete messages + {messages, remaining} = extract_messages(buffer, []) + + stream = %{ + stream + | data_buffer: remaining, + message_buffer: stream.message_buffer ++ messages + } + + stream = maybe_close_stream(stream, data.end_stream) + + {:ok, stream, messages} + end + + # Extract 5-byte length-prefixed messages + defp extract_messages( + <>, + acc + ) do + message = %{compressed: compressed == 1, data: payload} + extract_messages(rest, acc ++ [message]) + end + + defp extract_messages(buffer, acc) do + # Not enough data for a complete message + {acc, buffer} + end +end diff --git a/grpc_server/lib/grpc/server/stream.ex b/grpc_server/lib/grpc/server/stream.ex index fa0a7c195..d0766e0a4 100644 --- a/grpc_server/lib/grpc/server/stream.ex +++ b/grpc_server/lib/grpc/server/stream.ex @@ -98,10 +98,8 @@ defmodule GRPC.Server.Stream do data, opts ) do - opts = - opts - |> Keyword.put(:codec, codec) - |> Keyword.put(:http_transcode, access_mode == :http_transcoding) + # Optimize opts construction - avoid multiple Keyword operations + opts = [{:codec, codec}, {:http_transcode, access_mode == :http_transcoding} | opts] adapter.send_reply(stream.payload, data, opts) diff --git a/grpc_server/lib/grpc/server/supervisor.ex b/grpc_server/lib/grpc/server/supervisor.ex index 946d92cc4..0f7ea1928 100644 --- a/grpc_server/lib/grpc/server/supervisor.ex +++ b/grpc_server/lib/grpc/server/supervisor.ex @@ -103,6 +103,8 @@ defmodule GRPC.Server.Supervisor do servers end + GRPC.Server.Cache.init() + children = if opts[:start_server] do [child_spec(endpoint_or_servers, opts[:port], opts)] diff --git a/grpc_server/mix.exs b/grpc_server/mix.exs index f34e572a4..fa4143acf 100644 --- a/grpc_server/mix.exs +++ b/grpc_server/mix.exs @@ -38,6 +38,8 @@ defmodule GRPC.Server.MixProject do {:protobuf, "~> 0.14"}, {:cowboy, "~> 2.14"}, {:cowlib, "~> 2.14"}, + {:thousand_island, "~> 1.4"}, + {:hpax, "~> 1.0"}, {:flow, "~> 1.2"}, {:protobuf_generate, "~> 0.1.3", only: [:dev, :test]}, {:ex_parameterized, "~> 1.3.7", only: :test}, diff --git a/grpc_server/mix.lock b/grpc_server/mix.lock index 9431acb94..503ff9108 100644 --- a/grpc_server/mix.lock +++ b/grpc_server/mix.lock @@ -24,4 +24,5 @@ "ranch": {:hex, :ranch, "2.2.0", "25528f82bc8d7c6152c57666ca99ec716510fe0925cb188172f41ce93117b1b0", [:make, :rebar3], [], "hexpm", "fa0b99a1780c80218a4197a59ea8d3bdae32fbff7e88527d7d8a4787eff4f8e7"}, "rustler_precompiled": {:hex, :rustler_precompiled, "0.8.3", "4e741024b0b097fe783add06e53ae9a6f23ddc78df1010f215df0c02915ef5a8", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "c23f5f33cb6608542de4d04faf0f0291458c352a4648e4d28d17ee1098cddcc4"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "thousand_island": {:hex, :thousand_island, "1.4.2", "735fa783005d1703359bbd2d3a5a3a398075ba4456e5afe3c5b7cf4666303d36", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "1c7637f16558fc1c35746d5ee0e83b18b8e59e18d28affd1f2fa1645f8bc7473"}, } diff --git a/grpc_server/src/grpc_stream_h.erl b/grpc_server/src/grpc_stream_h.erl index 42e2c0525..20f576645 100644 --- a/grpc_server/src/grpc_stream_h.erl +++ b/grpc_server/src/grpc_stream_h.erl @@ -67,7 +67,9 @@ expect(Req) -> %% Stream isn't waiting for data. data(StreamID, IsFin, Data, State=#state{ read_body_ref=undefined, read_body_buffer=Buffer, body_length=BodyLen}) -> - Commands = case byte_size(Data) of + % Optimization: calculate byte_size once and reuse + DataSize = byte_size(Data), + Commands = case DataSize of 0 -> []; Size -> @@ -77,14 +79,16 @@ data(StreamID, IsFin, Data, State=#state{ expect=undefined, read_body_is_fin=IsFin, read_body_buffer= << Buffer/binary, Data/binary >>, - body_length=BodyLen + byte_size(Data) + body_length=BodyLen + DataSize }); %% Stream is waiting for data using auto mode. % GRPC: We don't pass auto, but treat it as auto data(StreamID, IsFin, Data, State=#state{read_body_pid=Pid, read_body_ref=Ref, body_length=BodyLen}) -> send_request_body(Pid, Ref, IsFin, BodyLen, Data), - Commands = case byte_size(Data) of + % Optimization: calculate byte_size once and reuse + DataSize = byte_size(Data), + Commands = case DataSize of 0 -> []; Size -> diff --git a/grpc_server/test/grpc/server/adapters/thousand_island_test.exs b/grpc_server/test/grpc/server/adapters/thousand_island_test.exs new file mode 100644 index 000000000..dee718580 --- /dev/null +++ b/grpc_server/test/grpc/server/adapters/thousand_island_test.exs @@ -0,0 +1,60 @@ +defmodule GRPC.Server.Adapters.ThousandIslandTest do + use ExUnit.Case, async: true + + alias GRPC.Server.Adapters.ThousandIsland, as: Adapter + + describe "child_spec/2" do + test "returns valid child spec" do + spec = Adapter.child_spec(:test_endpoint, [], 50051, []) + + assert is_map(spec) + assert spec.id == Adapter.Supervisor + assert spec.type == :supervisor + assert {Adapter.Supervisor, :start_link, [opts_list]} = spec.start + assert is_list(opts_list) + assert Keyword.get(opts_list, :endpoint) == :test_endpoint + assert Keyword.get(opts_list, :servers) == [] + assert Keyword.get(opts_list, :port) == 50051 + assert Keyword.get(opts_list, :adapter_opts) == [] + assert Keyword.get(opts_list, :cred) == nil + end + + test "includes adapter options in child spec" do + adapter_opts = [num_acceptors: 5, num_connections: 50] + spec = Adapter.child_spec(:test_endpoint, [], 50051, adapter_opts: adapter_opts) + + {Adapter.Supervisor, :start_link, [opts_list]} = spec.start + adapter_opts_kw = Keyword.get(opts_list, :adapter_opts) + assert Keyword.get(adapter_opts_kw, :num_acceptors) == 5 + assert Keyword.get(adapter_opts_kw, :num_connections) == 50 + end + end + + describe "start/4" do + test "can start and stop server" do + {:ok, pid, port} = Adapter.start(:test_server_unique_a, [], 0, []) + + assert is_pid(pid) + assert is_integer(port) + # Note: Port 0 means "choose any available port" + # ThousandIsland returns the actual assigned port + assert Process.alive?(pid) + + # Stop server + Supervisor.stop(pid) + refute Process.alive?(pid) + end + + test "accepts custom options" do + opts = [ + num_acceptors: 2, + num_connections: 10 + ] + + {:ok, pid, _port} = Adapter.start(:test_server_unique_b, [], 0, opts) + assert Process.alive?(pid) + + Supervisor.stop(pid) + end + end +end diff --git a/grpc_server/test/grpc/server/http2/connection_test.exs b/grpc_server/test/grpc/server/http2/connection_test.exs new file mode 100644 index 000000000..3ecae0437 --- /dev/null +++ b/grpc_server/test/grpc/server/http2/connection_test.exs @@ -0,0 +1,299 @@ +defmodule GRPC.Server.HTTP2.ConnectionTest do + use ExUnit.Case, async: true + + alias GRPC.Server.HTTP2.Connection + alias GRPC.Transport.HTTP2.{Frame, Settings, Errors} + + # For now, we'll test without mocking the socket + # Just test the logic of handle_frame functions + + describe "handle_frame/3 - SETTINGS" do + test "applies SETTINGS and updates remote settings" do + # Create a minimal connection struct + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{}, + send_hpack_state: HPAX.new(4096), + recv_hpack_state: HPAX.new(4096) + } + + settings_frame = %Frame.Settings{ + ack: false, + settings: [header_table_size: 8192, max_frame_size: 32_768] + } + + # Mock socket - we won't actually call send + socket = nil + + new_connection = Connection.handle_frame(settings_frame, socket, connection) + + # Should have updated remote settings + assert new_connection.remote_settings.header_table_size == 8192 + assert new_connection.remote_settings.max_frame_size == 32_768 + end + + test "ignores SETTINGS ACK frames" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + ack_frame = %Frame.Settings{ack: true, settings: []} + socket = nil + + new_connection = Connection.handle_frame(ack_frame, socket, connection) + + # Connection should be unchanged (except socket interaction) + assert new_connection.remote_settings == connection.remote_settings + end + + test "raises error for invalid initial window size" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + invalid_frame = %Frame.Settings{ + ack: false, + settings: [initial_window_size: 3_000_000_000] + } + + socket = nil + + assert_raise Errors.ConnectionError, fn -> + Connection.handle_frame(invalid_frame, socket, connection) + end + end + + test "raises error for invalid max frame size (too small)" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + invalid_frame = %Frame.Settings{ + ack: false, + settings: [max_frame_size: 1000] + } + + socket = nil + + assert_raise Errors.ConnectionError, fn -> + Connection.handle_frame(invalid_frame, socket, connection) + end + end + + test "raises error for invalid max frame size (too large)" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + invalid_frame = %Frame.Settings{ + ack: false, + settings: [max_frame_size: 20_000_000] + } + + socket = nil + + assert_raise Errors.ConnectionError, fn -> + Connection.handle_frame(invalid_frame, socket, connection) + end + end + + test "updates HPACK table size when header_table_size changes" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{header_table_size: 4096}, + send_hpack_state: HPAX.new(4096), + recv_hpack_state: HPAX.new(4096) + } + + settings_frame = %Frame.Settings{ + ack: false, + settings: [header_table_size: 8192] + } + + socket = nil + + new_connection = Connection.handle_frame(settings_frame, socket, connection) + + # HPACK state should be resized + assert new_connection.remote_settings.header_table_size == 8192 + end + end + + describe "handle_frame/3 - PING" do + test "connection unchanged after PING (ACK sent via socket)" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + ping_frame = %Frame.Ping{ack: false, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>} + socket = nil + + new_connection = Connection.handle_frame(ping_frame, socket, connection) + + # Connection state should be unchanged (response sent via socket) + assert new_connection == connection + end + + test "ignores PING ACK frames" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + pong_frame = %Frame.Ping{ack: true, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>} + socket = nil + + new_connection = Connection.handle_frame(pong_frame, socket, connection) + + assert new_connection == connection + end + end + + describe "handle_frame/3 - WINDOW_UPDATE" do + test "updates connection send window" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{}, + send_window_size: 65_535 + } + + window_update = %Frame.WindowUpdate{stream_id: 0, size_increment: 1000} + socket = nil + + new_connection = Connection.handle_frame(window_update, socket, connection) + + assert new_connection.send_window_size == 66_535 + end + + test "raises error on window overflow" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{}, + send_window_size: 2_147_483_647 + } + + window_update = %Frame.WindowUpdate{stream_id: 0, size_increment: 1000} + socket = nil + + assert_raise Errors.ConnectionError, fn -> + Connection.handle_frame(window_update, socket, connection) + end + end + end + + describe "handle_frame/3 - CONTINUATION" do + test "accumulates CONTINUATION frames until end_headers" do + headers_frame = %Frame.Headers{ + stream_id: 1, + end_stream: false, + end_headers: false, + fragment: <<1, 2, 3>> + } + + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{}, + fragment_frame: headers_frame + } + + cont1 = %Frame.Continuation{ + stream_id: 1, + end_headers: false, + fragment: <<4, 5, 6>> + } + + socket = nil + + new_connection = Connection.handle_frame(cont1, socket, connection) + + # Should have accumulated fragment + assert new_connection.fragment_frame != nil + assert new_connection.fragment_frame.fragment == <<1, 2, 3, 4, 5, 6>> + end + + test "raises error if non-CONTINUATION frame while expecting CONTINUATION" do + headers_frame = %Frame.Headers{ + stream_id: 1, + end_stream: false, + end_headers: false, + fragment: <<1, 2, 3>> + } + + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{}, + fragment_frame: headers_frame + } + + # Send DATA frame instead of CONTINUATION + data_frame = %Frame.Data{stream_id: 1, end_stream: false, data: <<>>} + socket = nil + + assert_raise Errors.ConnectionError, fn -> + Connection.handle_frame(data_frame, socket, connection) + end + end + end + + describe "handle_frame/3 - unsupported frames" do + test "raises error for PUSH_PROMISE (not supported in gRPC)" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + push_frame = %Frame.PushPromise{ + stream_id: 1, + promised_stream_id: 2, + end_headers: true, + fragment: <<>> + } + + socket = nil + + assert_raise Errors.ConnectionError, fn -> + Connection.handle_frame(push_frame, socket, connection) + end + end + + test "ignores PRIORITY frames (gRPC doesn't use priorities)" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + priority_frame = %Frame.Priority{ + stream_id: 1, + exclusive_dependency: false, + stream_dependency: 0, + weight: 16 + } + + socket = nil + + new_connection = Connection.handle_frame(priority_frame, socket, connection) + + # Should be unchanged + assert new_connection == connection + end + + test "ignores UNKNOWN frames" do + connection = %Connection{ + local_settings: %Settings{}, + remote_settings: %Settings{} + } + + unknown_frame = %Frame.Unknown{type: 255, flags: 0, stream_id: 0, payload: <<>>} + socket = nil + + new_connection = Connection.handle_frame(unknown_frame, socket, connection) + + assert new_connection == connection + end + end +end diff --git a/grpc_server/test/grpc/server/http2/errors_test.exs b/grpc_server/test/grpc/server/http2/errors_test.exs new file mode 100644 index 000000000..c1e323875 --- /dev/null +++ b/grpc_server/test/grpc/server/http2/errors_test.exs @@ -0,0 +1,79 @@ +defmodule GRPC.Transport.HTTP2.ErrorsTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Errors + + describe "error codes" do + test "no_error returns 0x0" do + assert Errors.no_error() == 0x0 + end + + test "protocol_error returns 0x1" do + assert Errors.protocol_error() == 0x1 + end + + test "internal_error returns 0x2" do + assert Errors.internal_error() == 0x2 + end + + test "flow_control_error returns 0x3" do + assert Errors.flow_control_error() == 0x3 + end + + test "settings_timeout returns 0x4" do + assert Errors.settings_timeout() == 0x4 + end + + test "stream_closed returns 0x5" do + assert Errors.stream_closed() == 0x5 + end + + test "frame_size_error returns 0x6" do + assert Errors.frame_size_error() == 0x6 + end + + test "refused_stream returns 0x7" do + assert Errors.refused_stream() == 0x7 + end + + test "cancel returns 0x8" do + assert Errors.cancel() == 0x8 + end + + test "compression_error returns 0x9" do + assert Errors.compression_error() == 0x9 + end + + test "connect_error returns 0xA" do + assert Errors.connect_error() == 0xA + end + + test "enhance_your_calm returns 0xB" do + assert Errors.enhance_your_calm() == 0xB + end + + test "inadequate_security returns 0xC" do + assert Errors.inadequate_security() == 0xC + end + + test "http_1_1_requires returns 0xD" do + assert Errors.http_1_1_requires() == 0xD + end + end + + describe "ConnectionError" do + test "can be raised with message and error code" do + assert_raise Errors.ConnectionError, "test message", fn -> + raise Errors.ConnectionError, message: "test message", error_code: 0x1 + end + end + end + + describe "StreamError" do + test "can be raised with message, error code and stream_id" do + assert_raise Errors.StreamError, "test message", fn -> + raise Errors.StreamError, message: "test message", error_code: 0x1, stream_id: 1 + end + end + end +end diff --git a/grpc_server/test/grpc/server/http2/flow_control_test.exs b/grpc_server/test/grpc/server/http2/flow_control_test.exs new file mode 100644 index 000000000..b78885b7f --- /dev/null +++ b/grpc_server/test/grpc/server/http2/flow_control_test.exs @@ -0,0 +1,80 @@ +defmodule GRPC.Transport.HTTP2.FlowControlTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.FlowControl + + import Bitwise + + @max_window_size (1 <<< 31) - 1 + @min_window_size 1 <<< 30 + + describe "compute_recv_window/2" do + test "returns no increment when window is still large" do + # Window is still above minimum, no need to update yet + recv_window_size = @min_window_size + 1000 + data_size = 500 + + assert {new_window, 0} = FlowControl.compute_recv_window(recv_window_size, data_size) + assert new_window == recv_window_size - data_size + end + + test "returns increment when window falls below minimum" do + # Window falls below minimum, need to send WINDOW_UPDATE + recv_window_size = @min_window_size + 100 + data_size = 200 + + assert {new_window, increment} = + FlowControl.compute_recv_window(recv_window_size, data_size) + + assert new_window > recv_window_size - data_size + assert increment > 0 + end + + test "respects maximum window size" do + # Even when requesting update, should not exceed max window size + recv_window_size = @min_window_size - 1000 + data_size = 100 + + assert {new_window, increment} = + FlowControl.compute_recv_window(recv_window_size, data_size) + + assert new_window <= @max_window_size + assert increment > 0 + end + + test "handles small recv_window_size" do + recv_window_size = 1000 + data_size = 500 + + assert {new_window, increment} = + FlowControl.compute_recv_window(recv_window_size, data_size) + + assert new_window > recv_window_size - data_size + assert increment > 0 + end + + test "handles edge case when recv_window equals min_window_size" do + recv_window_size = @min_window_size + data_size = 1 + + assert {new_window, increment} = + FlowControl.compute_recv_window(recv_window_size, data_size) + + assert new_window > recv_window_size - data_size + assert increment > 0 + end + + test "handles large data size" do + recv_window_size = @max_window_size + data_size = @min_window_size + 1000 + + assert {new_window, increment} = + FlowControl.compute_recv_window(recv_window_size, data_size) + + # Window should still be positive after large data consumption + assert new_window >= 0 + # May or may not need increment depending on resulting window size + assert is_integer(increment) + end + end +end diff --git a/grpc_server/test/grpc/server/http2/frame_test.exs b/grpc_server/test/grpc/server/http2/frame_test.exs new file mode 100644 index 000000000..f5d01b709 --- /dev/null +++ b/grpc_server/test/grpc/server/http2/frame_test.exs @@ -0,0 +1,260 @@ +defmodule GRPC.Transport.HTTP2.FrameTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.{Frame, Errors} + + describe "frame deserialization" do + test "deserializes DATA frames" do + # DATA frame: stream_id=1, no padding, end_stream=false + frame = <<0, 0, 3, 0, 0, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, %Frame.Data{stream_id: 1, end_stream: false, data: <<1, 2, 3>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes HEADERS frames" do + # HEADERS frame: stream_id=1, end_stream=false, end_headers=false + frame = <<0, 0, 3, 1, 0x00, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, + %Frame.Headers{ + stream_id: 1, + end_stream: false, + end_headers: false, + fragment: <<1, 2, 3>> + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes SETTINGS frames" do + # SETTINGS frame: max_frame_size=32768 + frame = <<0, 0, 6, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 128, 0>> + + assert {{:ok, %Frame.Settings{ack: false, settings: %{max_frame_size: 32_768}}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes SETTINGS ACK frames" do + frame = <<0, 0, 0, 4, 1, 0, 0, 0, 0>> + + assert {{:ok, %Frame.Settings{ack: true}}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes PING frames" do + frame = <<0, 0, 8, 6, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8>> + + assert {{:ok, %Frame.Ping{ack: false, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes PING ACK frames" do + frame = <<0, 0, 8, 6, 1, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8>> + + assert {{:ok, %Frame.Ping{ack: true, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes GOAWAY frames" do + frame = <<0, 0, 8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2>> + + assert {{:ok, %Frame.Goaway{last_stream_id: 1, error_code: 2, debug_data: <<>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes RST_STREAM frames" do + frame = <<0, 0, 4, 3, 0, 0, 0, 0, 1, 0, 0, 0, 8>> + + assert {{:ok, %Frame.RstStream{stream_id: 1, error_code: 8}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes WINDOW_UPDATE frames" do + frame = <<0, 0, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 100>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 0, size_increment: 100}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes CONTINUATION frames" do + frame = <<0, 0, 3, 9, 0x00, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, + %Frame.Continuation{ + stream_id: 1, + end_headers: false, + fragment: <<1, 2, 3>> + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes PRIORITY frames" do + frame = <<0, 0, 5, 2, 0, 0, 0, 0, 1, 0::1, 12::31, 34>> + + assert {{:ok, + %Frame.Priority{ + stream_id: 1, + exclusive_dependency: false, + stream_dependency: 12, + weight: 34 + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes unknown frame types" do + # Unknown type 0xFF + frame = <<0, 0, 3, 0xFF, 0, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, + %Frame.Unknown{ + type: 0xFF, + flags: 0, + stream_id: 1, + payload: <<1, 2, 3>> + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "returns extra data when frame is followed by more data" do + frame = <<0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, %Frame.Data{stream_id: 1, data: <<>>}}, <<1, 2, 3>>} = + Frame.deserialize(frame, 16_384) + end + + test "asks for more data when frame is incomplete" do + frame = <<0, 0, 10, 0, 0, 0, 0, 0, 1>> + + assert {{:more, <<0, 0, 10, 0, 0, 0, 0, 0, 1>>}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "returns nil when buffer is empty" do + assert Frame.deserialize(<<>>, 16_384) == nil + end + + test "rejects frames that exceed max_frame_size" do + # Frame with length 100, max_frame_size 50 + frame = <<0, 0, 100, 0, 0, 0, 0, 0, 1>> <> :binary.copy(<<0>>, 100) + + assert {{:error, error_code, "Payload size too large (RFC9113§4.2)"}, _rest} = + Frame.deserialize(frame, 50) + + assert error_code == Errors.frame_size_error() + end + end + + describe "frame serialization" do + test "serializes DATA frames" do + frame = %Frame.Data{ + stream_id: 123, + end_stream: false, + data: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 3, 0, 0, 0, 0, 0, 123>>, <<1, 2, 3>>] + ] + end + + test "serializes DATA frames with end_stream set" do + frame = %Frame.Data{ + stream_id: 123, + end_stream: true, + data: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 3, 0, 1, 0, 0, 0, 123>>, <<1, 2, 3>>] + ] + end + + test "serializes HEADERS frames" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: false, + fragment: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 3, 1, 4, 0, 0, 0, 123>>, <<1, 2, 3>>] + ] + end + + test "serializes SETTINGS frames" do + frame = %Frame.Settings{ + ack: false, + settings: %{ + header_table_size: 8_192, + max_frame_size: 32_768 + } + } + + result = Frame.serialize(frame, 16_384) + assert [[header, payload]] = result + assert <<0, 0, 12, 4, 0, 0, 0, 0, 0>> = header + # Payload should contain both settings + payload_binary = IO.iodata_to_binary(payload) + assert byte_size(payload_binary) == 12 + end + + test "serializes SETTINGS ACK frames" do + frame = %Frame.Settings{ack: true, settings: %{}} + + assert Frame.serialize(frame, 16_384) == [[<<0, 0, 0, 4, 1, 0, 0, 0, 0>>, <<>>]] + end + + test "serializes PING frames" do + frame = %Frame.Ping{ack: false, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 8, 6, 0, 0, 0, 0, 0>>, <<1, 2, 3, 4, 5, 6, 7, 8>>] + ] + end + + test "serializes GOAWAY frames" do + frame = %Frame.Goaway{last_stream_id: 1, error_code: 2, debug_data: <<>>} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 8, 7, 0, 0, 0, 0, 0>>, <<0, 0, 0, 1, 0, 0, 0, 2>>] + ] + end + + test "serializes RST_STREAM frames" do + frame = %Frame.RstStream{stream_id: 1, error_code: 8} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 4, 3, 0, 0, 0, 0, 1>>, <<0, 0, 0, 8>>] + ] + end + + test "serializes WINDOW_UPDATE frames" do + frame = %Frame.WindowUpdate{stream_id: 123, size_increment: 234} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 4, 8, 0, 0, 0, 0, 123>>, <<0, 0, 0, 234>>] + ] + end + + test "splits DATA frames that exceed max_frame_size" do + frame = %Frame.Data{ + stream_id: 123, + end_stream: false, + data: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 2) == [ + [<<0, 0, 2, 0, 0, 0, 0, 0, 123>>, <<1, 2>>], + [<<0, 0, 1, 0, 0, 0, 0, 0, 123>>, <<3>>] + ] + end + + test "splits HEADERS frames into HEADERS + CONTINUATION" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: false, + fragment: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 2) == [ + [<<0, 0, 2, 1, 0, 0, 0, 0, 123>>, <<1, 2>>], + [<<0, 0, 1, 9, 4, 0, 0, 0, 123>>, <<3>>] + ] + end + end +end diff --git a/grpc_server/test/grpc/server/http2/frame_test.exs.bak b/grpc_server/test/grpc/server/http2/frame_test.exs.bak new file mode 100644 index 000000000..2060b54dd --- /dev/null +++ b/grpc_server/test/grpc/server/http2/frame_test.exs.bak @@ -0,0 +1,259 @@ +defmodule GRPC.Server.HTTP2.FrameTest do + use ExUnit.Case, async: true + + alias GRPC.Server.HTTP2.{Frame, Errors} + + describe "frame deserialization" do + test "deserializes DATA frames" do + # DATA frame: stream_id=1, no padding, end_stream=false + frame = <<0, 0, 3, 0, 0, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, %Frame.Data{stream_id: 1, end_stream: false, data: <<1, 2, 3>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes HEADERS frames" do + # HEADERS frame: stream_id=1, end_stream=false, end_headers=false + frame = <<0, 0, 3, 1, 0x00, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, + %Frame.Headers{ + stream_id: 1, + end_stream: false, + end_headers: false, + fragment: <<1, 2, 3>> + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes SETTINGS frames" do + # SETTINGS frame: max_frame_size=32768 + frame = <<0, 0, 6, 4, 0, 0, 0, 0, 0, 0, 5, 0, 0, 128, 0>> + + assert {{:ok, %Frame.Settings{ack: false, settings: %{max_frame_size: 32_768}}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes SETTINGS ACK frames" do + frame = <<0, 0, 0, 4, 1, 0, 0, 0, 0>> + + assert {{:ok, %Frame.Settings{ack: true}}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes PING frames" do + frame = <<0, 0, 8, 6, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8>> + + assert {{:ok, %Frame.Ping{ack: false, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes PING ACK frames" do + frame = <<0, 0, 8, 6, 1, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8>> + + assert {{:ok, %Frame.Ping{ack: true, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes GOAWAY frames" do + frame = <<0, 0, 8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2>> + + assert {{:ok, %Frame.Goaway{last_stream_id: 1, error_code: 2, debug_data: <<>>}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes RST_STREAM frames" do + frame = <<0, 0, 4, 3, 0, 0, 0, 0, 1, 0, 0, 0, 8>> + + assert {{:ok, %Frame.RstStream{stream_id: 1, error_code: 8}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes WINDOW_UPDATE frames" do + frame = <<0, 0, 4, 8, 0, 0, 0, 0, 0, 0, 0, 0, 100>> + + assert {{:ok, %Frame.WindowUpdate{stream_id: 0, size_increment: 100}}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "deserializes CONTINUATION frames" do + frame = <<0, 0, 3, 9, 0x00, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, + %Frame.Continuation{ + stream_id: 1, + end_headers: false, + fragment: <<1, 2, 3>> + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes PRIORITY frames" do + frame = <<0, 0, 5, 2, 0, 0, 0, 0, 1, 0::1, 12::31, 34>> + + assert {{:ok, + %Frame.Priority{ + stream_id: 1, + exclusive_dependency: false, + stream_dependency: 12, + weight: 34 + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "deserializes unknown frame types" do + # Unknown type 0xFF + frame = <<0, 0, 3, 0xFF, 0, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, + %Frame.Unknown{ + type: 0xFF, + flags: 0, + stream_id: 1, + payload: <<1, 2, 3>> + }}, <<>>} = Frame.deserialize(frame, 16_384) + end + + test "returns extra data when frame is followed by more data" do + frame = <<0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 3>> + + assert {{:ok, %Frame.Data{stream_id: 1, data: <<>>}}, <<1, 2, 3>>} = + Frame.deserialize(frame, 16_384) + end + + test "asks for more data when frame is incomplete" do + frame = <<0, 0, 10, 0, 0, 0, 0, 0, 1>> + + assert {{:more, <<0, 0, 10, 0, 0, 0, 0, 0, 1>>}, <<>>} = + Frame.deserialize(frame, 16_384) + end + + test "returns nil when buffer is empty" do + assert Frame.deserialize(<<>>, 16_384) == nil + end + + test "rejects frames that exceed max_frame_size" do + # Frame with length 100, max_frame_size 50 + frame = <<0, 0, 100, 0, 0, 0, 0, 0, 1>> <> :binary.copy(<<0>>, 100) + + assert {{:error, error_code, "Payload size too large (RFC9113§4.2)"}, _rest} = + Frame.deserialize(frame, 50) + + assert error_code == Errors.frame_size_error() + end + end + + describe "frame serialization" do + test "serializes DATA frames" do + frame = %Frame.Data{ + stream_id: 123, + end_stream: false, + data: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 3, 0, 0, 0, 0, 0, 123>>, <<1, 2, 3>>] + ] + end + + test "serializes DATA frames with end_stream set" do + frame = %Frame.Data{ + stream_id: 123, + end_stream: true, + data: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 3, 0, 1, 0, 0, 0, 123>>, <<1, 2, 3>>] + ] + end + + test "serializes HEADERS frames" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: false, + fragment: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 3, 1, 4, 0, 0, 0, 123>>, <<1, 2, 3>>] + ] + end + + test "serializes SETTINGS frames" do + frame = %Frame.Settings{ + ack: false, + settings: %{ + header_table_size: 8_192, + max_frame_size: 32_768 + } + } + + result = Frame.serialize(frame, 16_384) + assert [[header, payload]] = result + assert <<0, 0, 12, 4, 0, 0, 0, 0, 0>> = header + # Payload should contain both settings + assert byte_size(payload) == 12 + end + + test "serializes SETTINGS ACK frames" do + frame = %Frame.Settings{ack: true, settings: %{}} + + assert Frame.serialize(frame, 16_384) == [[<<0, 0, 0, 4, 1, 0, 0, 0, 0>>, <<>>]] + end + + test "serializes PING frames" do + frame = %Frame.Ping{ack: false, payload: <<1, 2, 3, 4, 5, 6, 7, 8>>} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 8, 6, 0, 0, 0, 0, 0>>, <<1, 2, 3, 4, 5, 6, 7, 8>>] + ] + end + + test "serializes GOAWAY frames" do + frame = %Frame.Goaway{last_stream_id: 1, error_code: 2, debug_data: <<>>} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 8, 7, 0, 0, 0, 0, 0>>, <<0, 0, 0, 1, 0, 0, 0, 2>>] + ] + end + + test "serializes RST_STREAM frames" do + frame = %Frame.RstStream{stream_id: 1, error_code: 8} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 4, 3, 0, 0, 0, 0, 1>>, <<0, 0, 0, 8>>] + ] + end + + test "serializes WINDOW_UPDATE frames" do + frame = %Frame.WindowUpdate{stream_id: 123, size_increment: 234} + + assert Frame.serialize(frame, 16_384) == [ + [<<0, 0, 4, 8, 0, 0, 0, 0, 123>>, <<0, 0, 0, 234>>] + ] + end + + test "splits DATA frames that exceed max_frame_size" do + frame = %Frame.Data{ + stream_id: 123, + end_stream: false, + data: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 2) == [ + [<<0, 0, 2, 0, 0, 0, 0, 0, 123>>, <<1, 2>>], + [<<0, 0, 1, 0, 1, 0, 0, 0, 123>>, <<3>>] + ] + end + + test "splits HEADERS frames into HEADERS + CONTINUATION" do + frame = %Frame.Headers{ + stream_id: 123, + end_stream: false, + fragment: <<1, 2, 3>> + } + + assert Frame.serialize(frame, 2) == [ + [<<0, 0, 2, 1, 0, 0, 0, 0, 123>>, <<1, 2>>], + [<<0, 0, 1, 9, 4, 0, 0, 0, 123>>, <<3>>] + ] + end + end +end diff --git a/grpc_server/test/grpc/server/http2/settings_test.exs b/grpc_server/test/grpc/server/http2/settings_test.exs new file mode 100644 index 000000000..c58f37243 --- /dev/null +++ b/grpc_server/test/grpc/server/http2/settings_test.exs @@ -0,0 +1,44 @@ +defmodule GRPC.Transport.HTTP2.SettingsTest do + use ExUnit.Case, async: true + + alias GRPC.Transport.HTTP2.Settings + + describe "default settings" do + test "has correct default values" do + settings = %Settings{} + + assert settings.header_table_size == 4_096 + assert settings.max_concurrent_streams == :infinity + assert settings.initial_window_size == 65_535 + assert settings.max_frame_size == 16_384 + assert settings.max_header_list_size == :infinity + end + end + + describe "settings modification" do + test "can update header_table_size" do + settings = %Settings{header_table_size: 8_192} + assert settings.header_table_size == 8_192 + end + + test "can update max_concurrent_streams" do + settings = %Settings{max_concurrent_streams: 100} + assert settings.max_concurrent_streams == 100 + end + + test "can update initial_window_size" do + settings = %Settings{initial_window_size: 32_768} + assert settings.initial_window_size == 32_768 + end + + test "can update max_frame_size" do + settings = %Settings{max_frame_size: 32_768} + assert settings.max_frame_size == 32_768 + end + + test "can update max_header_list_size" do + settings = %Settings{max_header_list_size: 16_384} + assert settings.max_header_list_size == 16_384 + end + end +end diff --git a/grpc_server/test/grpc/server/supervisor_test.exs b/grpc_server/test/grpc/server/supervisor_test.exs index f6f5bf733..7ff5b22e9 100644 --- a/grpc_server/test/grpc/server/supervisor_test.exs +++ b/grpc_server/test/grpc/server/supervisor_test.exs @@ -9,8 +9,10 @@ defmodule GRPC.Server.SupervisorTest do describe "init/1" do test "does not start children if opts sets false" do - assert {:ok, {%{strategy: :one_for_one}, []}} = - Supervisor.init(endpoint: MockEndpoint, port: 1234, start_server: false) + assert { + :ok, + {%{strategy: :one_for_one}, []} + } = Supervisor.init(endpoint: MockEndpoint, port: 1234, start_server: false) end test "fails if a tuple is passed" do diff --git a/interop/config/config.exs b/interop/config/config.exs index 63787d4b3..477c90797 100644 --- a/interop/config/config.exs +++ b/interop/config/config.exs @@ -1,3 +1,3 @@ import Config -config :logger, level: :warning +config :logger, level: :info diff --git a/interop/lib/interop/client.ex b/interop/lib/interop/client.ex index 908815fc6..0c0f2489c 100644 --- a/interop/lib/interop/client.ex +++ b/interop/lib/interop/client.ex @@ -326,7 +326,7 @@ defmodule Interop.Client do } stream = Grpc.Testing.TestService.Stub.full_duplex_call(ch, timeout: 1) - resp = stream |> GRPC.Stub.send_request(req) |> GRPC.Stub.recv() + resp = stream |> GRPC.Stub.send_request(req, end_stream: true) |> GRPC.Stub.recv() case resp do {:error, %GRPC.RPCError{status: 4}} -> diff --git a/interop/mix.lock b/interop/mix.lock index 20e08e875..cce13d505 100644 --- a/interop/mix.lock +++ b/interop/mix.lock @@ -15,4 +15,5 @@ "recon": {:hex, :recon, "2.5.6", "9052588e83bfedfd9b72e1034532aee2a5369d9d9343b61aeb7fbce761010741", [:mix, :rebar3], [], "hexpm", "96c6799792d735cc0f0fd0f86267e9d351e63339cbe03df9d162010cefc26bb0"}, "statix": {:hex, :statix, "1.4.0", "c822abd1e60e62828e8460e932515d0717aa3c089b44cc3f795d43b94570b3a8", [:mix], [], "hexpm", "507373cc80925a9b6856cb14ba17f6125552434314f6613c907d295a09d1a375"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "thousand_island": {:hex, :thousand_island, "1.4.2", "735fa783005d1703359bbd2d3a5a3a398075ba4456e5afe3c5b7cf4666303d36", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "1c7637f16558fc1c35746d5ee0e83b18b8e59e18d28affd1f2fa1645f8bc7473"}, } diff --git a/interop/script/run.exs b/interop/script/run.exs index dc1b42c28..9d65e86c7 100644 --- a/interop/script/run.exs +++ b/interop/script/run.exs @@ -1,6 +1,6 @@ {options, _, _} = OptionParser.parse(System.argv(), - strict: [rounds: :integer, concurrency: :integer, port: :integer, level: :string] + strict: [rounds: :integer, concurrency: :integer, port: :integer, level: :string, adapter: :string] ) rounds = Keyword.get(options, :rounds) || 20 @@ -9,18 +9,29 @@ concurrency = Keyword.get(options, :concurrency) || max_concurrency port = Keyword.get(options, :port) || 0 level = Keyword.get(options, :level) || "warning" level = String.to_existing_atom(level) +server_adapter_name = Keyword.get(options, :adapter) || "both" require Logger Logger.configure(level: level) -Logger.info("Rounds: #{rounds}; concurrency: #{concurrency}; port: #{port}") +Logger.info("Rounds: #{rounds}; concurrency: #{concurrency}; port: #{port}; server_adapter: #{server_adapter_name}") alias GRPC.Client.Adapters.Gun alias GRPC.Client.Adapters.Mint alias Interop.Client -{:ok, _pid, port} = GRPC.Server.start_endpoint(Interop.Endpoint, port) +# Determine which server adapters to test +server_adapters = case server_adapter_name do + "cowboy" -> [GRPC.Server.Adapters.Cowboy] + "thousand_island" -> [GRPC.Server.Adapters.ThousandIsland] + "both" -> [GRPC.Server.Adapters.Cowboy, GRPC.Server.Adapters.ThousandIsland] + _ -> + IO.puts("Unknown adapter: #{server_adapter_name}. Using both.") + [GRPC.Server.Adapters.Cowboy, GRPC.Server.Adapters.ThousandIsland] +end + +client_adapters = [Gun, Mint] defmodule InteropTestRunner do def run(_cli, adapter, port, rounds) do @@ -65,15 +76,32 @@ res = DynamicSupervisor.start_link(strategy: :one_for_one, name: GRPC.Client.Sup {:ok, pid} end -for adapter <- [Gun, Mint] do - Logger.info("Starting run for adapter: #{adapter}") - args = [adapter, port, rounds] - stream_opts = [max_concurrency: concurrency, ordered: false, timeout: :infinity] +# Test each server adapter +for server_adapter <- server_adapters do + server_name = server_adapter |> Module.split() |> List.last() + Logger.info("========================================") + Logger.info("Testing with SERVER ADAPTER: #{server_name}") + Logger.info("========================================") + + {:ok, _pid, test_port} = GRPC.Server.start_endpoint(Interop.Endpoint, port, adapter: server_adapter) + Logger.info("Server started on port #{test_port}") + # Give server time to fully initialize + Process.sleep(100) + + for client_adapter <- client_adapters do + client_name = client_adapter |> Module.split() |> List.last() + Logger.info("Starting run for client adapter: #{client_name}") + args = [client_adapter, test_port, rounds] + stream_opts = [max_concurrency: concurrency, ordered: false, timeout: :infinity] + + 1..concurrency + |> Task.async_stream(InteropTestRunner, :run, args, stream_opts) + |> Enum.to_list() + end - 1..concurrency - |> Task.async_stream(InteropTestRunner, :run, args, stream_opts) - |> Enum.to_list() + :ok = GRPC.Server.stop_endpoint(Interop.Endpoint, adapter: server_adapter) + Logger.info("Completed tests for #{server_name}") end -Logger.info("Succeed!") -:ok = GRPC.Server.stop_endpoint(Interop.Endpoint) +Logger.info("All tests succeeded!") +:ok diff --git a/mix.exs b/mix.exs index a227268a1..fce621d57 100644 --- a/mix.exs +++ b/mix.exs @@ -20,6 +20,7 @@ defmodule GRPC.GRPCRoot do defp aliases do [ + format: cmd("format"), setup: cmd("deps.get"), compile: cmd("compile"), test: cmd("test"), @@ -43,4 +44,4 @@ defmodule GRPC.GRPCRoot do end end -end +end \ No newline at end of file diff --git a/mix.lock b/mix.lock index 427f40b39..c2a78ae07 100644 --- a/mix.lock +++ b/mix.lock @@ -11,4 +11,5 @@ "protobuf": {:hex, :protobuf, "0.15.0", "c9fc1e9fc1682b05c601df536d5ff21877b55e2023e0466a3855cc1273b74dcb", [:mix], [{:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: true]}], "hexpm", "5d7bb325319db1d668838d2691c31c7b793c34111aec87d5ee467a39dac6e051"}, "ranch": {:hex, :ranch, "2.2.0", "25528f82bc8d7c6152c57666ca99ec716510fe0925cb188172f41ce93117b1b0", [:make, :rebar3], [], "hexpm", "fa0b99a1780c80218a4197a59ea8d3bdae32fbff7e88527d7d8a4787eff4f8e7"}, "telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"}, + "thousand_island": {:hex, :thousand_island, "1.4.2", "735fa783005d1703359bbd2d3a5a3a398075ba4456e5afe3c5b7cf4666303d36", [:mix], [{:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "1c7637f16558fc1c35746d5ee0e83b18b8e59e18d28affd1f2fa1645f8bc7473"}, }