Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion cmd/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package cmd

import (
"os"
"os/signal"
"strings"
"syscall"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/record"
Expand All @@ -33,6 +35,9 @@ var recordCmd = &cobra.Command{
Long: `Runs test-server in record mode, all request will be proxies to the
target server, and all requests and responses will be recorded.`,
Run: func(cmd *cobra.Command, args []string) {
ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL)
defer stop()

config, err := config.ReadConfig(cfgFile)
if err != nil {
panic(err)
Expand All @@ -44,7 +49,7 @@ target server, and all requests and responses will be recorded.`,
panic(err)
}

err = record.Record(config, recordingDir, redactor)
err = record.Record(ctx, config, recordingDir, redactor)
if err != nil {
panic(err)
}
Expand Down
7 changes: 6 additions & 1 deletion cmd/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package cmd

import (
"os"
"os/signal"
"strings"
"syscall"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
Expand All @@ -37,6 +39,9 @@ It listens on the configured source ports and returns recorded responses
when it finds a matching request. Returns a 404 error if no matching
recording is found.`,
Run: func(cmd *cobra.Command, args []string) {
ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL)
defer stop()

config, err := config.ReadConfig(cfgFile)
if err != nil {
panic(err)
Expand All @@ -48,7 +53,7 @@ recording is found.`,
panic(err)
}

err = replay.Replay(config, replayRecordingDir, redactor)
err = replay.Replay(ctx, config, replayRecordingDir, redactor)
if err != nil {
panic(err)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/spf13/afero v1.14.0
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
golang.org/x/sync v0.12.0
gopkg.in/yaml.v2 v2.4.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
Expand Down
9 changes: 9 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ type EndpointConfig struct {
Health string `yaml:"health"`
RedactRequestHeaders []string `yaml:"redact_request_headers"`
ResponseHeaderReplacements []HeaderReplacement `yaml:"response_header_replacements"`
// ShutdownTimeoutSeconds is the time in seconds to wait for the server to shutdown gracefully.
// default is 10 seconds.
ShutdownTimeoutSeconds int64 `yaml:"shutdown_timeout_seconds"`
}

type HeaderReplacement struct {
Expand Down Expand Up @@ -60,5 +63,11 @@ func ReadConfigWithFs(fs afero.Fs, filename string) (*TestServerConfig, error) {
return nil, fmt.Errorf("failed parsing %s: %w", filename, err)
}

for i, ep := range config.Endpoints {
if ep.ShutdownTimeoutSeconds <= 1 {
config.Endpoints[i].ShutdownTimeoutSeconds = 10 // default to 10 seconds
}
}

return config, nil
}
40 changes: 13 additions & 27 deletions internal/record/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,38 @@ limitations under the License.
package record

import (
"context"
"fmt"
"os"
"sync"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
"golang.org/x/sync/errgroup"
)

func Record(cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error {
func Record(ctx context.Context, cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error {
// Create recording directory if it doesn't exist
if err := os.MkdirAll(recordingDir, 0755); err != nil {
return fmt.Errorf("failed to create recording directory: %w", err)
}

fmt.Printf("Recording to directory: %s\n", recordingDir)
var wg sync.WaitGroup
errChan := make(chan error, len(cfg.Endpoints))
errGroup, errCtx := errgroup.WithContext(ctx)

// Start a proxy for each endpoint
for _, endpoint := range cfg.Endpoints {
wg.Add(1)
go func(ep config.EndpointConfig) {
defer wg.Done()

fmt.Printf("Starting server for %v\n", endpoint)
proxy := NewRecordingHTTPSProxy(&endpoint, recordingDir, redactor)
err := proxy.Start()

ep := endpoint
errGroup.Go(func() error {
fmt.Printf("Starting server for %v\n", ep)
proxy := NewRecordingHTTPSProxy(&ep, recordingDir, redactor)
err := proxy.Start(errCtx)
if err != nil {
errChan <- fmt.Errorf("proxy error for %s:%d: %w",
return fmt.Errorf("proxy error for %s:%d: %w",
ep.TargetHost, ep.TargetPort, err)
}
}(endpoint)
}

// Wait for all proxies to complete (they shouldn't unless there's an error)
go func() {
wg.Wait()
close(errChan)
}()

// Return the first error encountered, if any
for err := range errChan {
return err
return nil
})
}

// Block forever (or until interrupted)
select {}
return errGroup.Wait()
}
27 changes: 23 additions & 4 deletions internal/record/recording_https_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package record

import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"time"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
Expand Down Expand Up @@ -51,15 +53,33 @@ func (r *RecordingHTTPSProxy) ResetChain() {
r.prevRequestSHA = store.HeadSHA
}

func (r *RecordingHTTPSProxy) Start() error {
func (r *RecordingHTTPSProxy) Start(ctx context.Context) error {
addr := fmt.Sprintf(":%d", r.config.SourcePort)
server := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(r.handleRequest),
}
if err := server.ListenAndServe(); err != nil {
panic(err)

errCh := make(chan error, 1)

go func() {
errCh <- server.ListenAndServe()
}()

select {
case <-ctx.Done():
fmt.Printf("Context cancelled, shutting down server on %s\n", addr)
cancelCtx, cancel := context.WithTimeout(context.Background(), time.Duration(r.config.ShutdownTimeoutSeconds)*time.Second)
defer cancel()
if err := server.Shutdown(cancelCtx); err != nil {
return fmt.Errorf("failed to shutdown server: %w", err)
}
case err := <-errCh:
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("failed to start recording HTTPS proxy: %w", err)
}
}

return nil
}

Expand Down Expand Up @@ -91,7 +111,6 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
}

err = r.recordResponse(resp, reqHash, respBody)

if err != nil {
fmt.Printf("Error recording response: %v\n", err)
http.Error(w, fmt.Sprintf("Error recording response: %v", err), http.StatusInternalServerError)
Expand Down
26 changes: 11 additions & 15 deletions internal/replay/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,38 @@ limitations under the License.
package replay

import (
"context"
"fmt"
"os"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
"golang.org/x/sync/errgroup"
)

// Replay serves recorded responses for HTTP requests
func Replay(cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error {
func Replay(ctx context.Context, cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error {
// Validate recording directory exists
if _, err := os.Stat(recordingDir); os.IsNotExist(err) {
return fmt.Errorf("recording directory does not exist: %s", recordingDir)
}

fmt.Printf("Replaying from directory: %s\n", recordingDir)

// Start a server for each endpoint
errChan := make(chan error, len(cfg.Endpoints))
errGroup, errCtx := errgroup.WithContext(ctx)

for _, endpoint := range cfg.Endpoints {
go func(ep config.EndpointConfig) {
ep := endpoint // Capture range variable
errGroup.Go(func() error {
server := NewReplayHTTPServer(&endpoint, recordingDir, redactor)
err := server.Start()
err := server.Start(errCtx)
if err != nil {
errChan <- fmt.Errorf("replay error for %s:%d: %w",
return fmt.Errorf("replay error for %s:%d: %w",
ep.TargetHost, ep.TargetPort, err)
}
}(endpoint)
return nil
})
}

// Return the first error encountered, if any
select {
case err := <-errChan:
return err
default:
// Block forever (or until interrupted)
select {}
}
return errGroup.Wait()
}
26 changes: 23 additions & 3 deletions internal/replay/replay_http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ limitations under the License.
package replay

import (
"context"
"fmt"
"net/http"
"os"
"path/filepath"
"time"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
Expand All @@ -43,15 +45,33 @@ func NewReplayHTTPServer(cfg *config.EndpointConfig, recordingDir string, redact
}
}

func (r *ReplayHTTPServer) Start() error {
func (r *ReplayHTTPServer) Start(ctx context.Context) error {
addr := fmt.Sprintf(":%d", r.config.SourcePort)
server := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(r.handleRequest),
}
if err := server.ListenAndServe(); err != nil {
panic(err)

errCh := make(chan error, 1)

go func() {
errCh <- server.ListenAndServe()
}()

select {
case <-ctx.Done():
fmt.Printf("Shutting down server on %s\n", addr)
cancelCtx, cancel := context.WithTimeout(context.Background(), time.Duration(r.config.ShutdownTimeoutSeconds)*time.Second)
defer cancel()
if err := server.Shutdown(cancelCtx); err != nil {
return fmt.Errorf("failed to shutdown server: %w", err)
}
case err := <-errCh:
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("failed to start replay server: %w", err)
}
}

return nil
}

Expand Down
3 changes: 2 additions & 1 deletion samples/google-genai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ endpoints:
- header: X-Goog-Upload-Url
regex: "^https://generativelanguage.googleapis.com/"
replace: "http://localhost:1443"
shutdown_timeout_seconds: 5
- target_host: us-central1-aiplatform.googleapis.com
target_port: 443
source_port: 1444
Expand All @@ -19,4 +20,4 @@ endpoints:
redact_request_headers:
- X-Goog-Api-Key
- Authorization

shutdown_timeout_seconds: 5