From 3300064b1f18de0046936c9036482e885a1080d5 Mon Sep 17 00:00:00 2001 From: Tom Bojer Date: Sat, 14 Jun 2025 00:03:16 +0200 Subject: [PATCH] feat: add graceful shutdown using signals --- cmd/record.go | 7 ++++- cmd/replay.go | 7 ++++- go.mod | 1 + go.sum | 2 ++ internal/config/config.go | 9 ++++++ internal/record/record.go | 40 ++++++++---------------- internal/record/recording_https_proxy.go | 27 +++++++++++++--- internal/replay/replay.go | 26 +++++++-------- internal/replay/replay_http_server.go | 26 +++++++++++++-- samples/google-genai.yml | 3 +- 10 files changed, 96 insertions(+), 52 deletions(-) diff --git a/cmd/record.go b/cmd/record.go index cc86fe7..8e2c856 100644 --- a/cmd/record.go +++ b/cmd/record.go @@ -17,7 +17,9 @@ package cmd import ( "os" + "os/signal" "strings" + "syscall" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/record" @@ -33,6 +35,9 @@ var recordCmd = &cobra.Command{ Long: `Runs test-server in record mode, all request will be proxies to the target server, and all requests and responses will be recorded.`, Run: func(cmd *cobra.Command, args []string) { + ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL) + defer stop() + config, err := config.ReadConfig(cfgFile) if err != nil { panic(err) @@ -44,7 +49,7 @@ target server, and all requests and responses will be recorded.`, panic(err) } - err = record.Record(config, recordingDir, redactor) + err = record.Record(ctx, config, recordingDir, redactor) if err != nil { panic(err) } diff --git a/cmd/replay.go b/cmd/replay.go index a5c1103..a2c91f7 100644 --- a/cmd/replay.go +++ b/cmd/replay.go @@ -18,7 +18,9 @@ package cmd import ( "os" + "os/signal" "strings" + "syscall" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" @@ -37,6 +39,9 @@ It listens on the configured source ports and returns recorded responses when it finds a matching request. Returns a 404 error if no matching recording is found.`, Run: func(cmd *cobra.Command, args []string) { + ctx, stop := signal.NotifyContext(cmd.Context(), os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT, syscall.SIGKILL) + defer stop() + config, err := config.ReadConfig(cfgFile) if err != nil { panic(err) @@ -48,7 +53,7 @@ recording is found.`, panic(err) } - err = replay.Replay(config, replayRecordingDir, redactor) + err = replay.Replay(ctx, config, replayRecordingDir, redactor) if err != nil { panic(err) } diff --git a/go.mod b/go.mod index 74faf83..42a0b5b 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/spf13/afero v1.14.0 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.10.0 + golang.org/x/sync v0.12.0 gopkg.in/yaml.v2 v2.4.0 ) diff --git a/go.sum b/go.sum index 00fd811..df434e6 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/config/config.go b/internal/config/config.go index 5c0803c..db816e4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -32,6 +32,9 @@ type EndpointConfig struct { Health string `yaml:"health"` RedactRequestHeaders []string `yaml:"redact_request_headers"` ResponseHeaderReplacements []HeaderReplacement `yaml:"response_header_replacements"` + // ShutdownTimeoutSeconds is the time in seconds to wait for the server to shutdown gracefully. + // default is 10 seconds. + ShutdownTimeoutSeconds int64 `yaml:"shutdown_timeout_seconds"` } type HeaderReplacement struct { @@ -60,5 +63,11 @@ func ReadConfigWithFs(fs afero.Fs, filename string) (*TestServerConfig, error) { return nil, fmt.Errorf("failed parsing %s: %w", filename, err) } + for i, ep := range config.Endpoints { + if ep.ShutdownTimeoutSeconds <= 1 { + config.Endpoints[i].ShutdownTimeoutSeconds = 10 // default to 10 seconds + } + } + return config, nil } diff --git a/internal/record/record.go b/internal/record/record.go index 91fa154..44cb2de 100644 --- a/internal/record/record.go +++ b/internal/record/record.go @@ -17,52 +17,38 @@ limitations under the License. package record import ( + "context" "fmt" "os" - "sync" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" + "golang.org/x/sync/errgroup" ) -func Record(cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error { +func Record(ctx context.Context, cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error { // Create recording directory if it doesn't exist if err := os.MkdirAll(recordingDir, 0755); err != nil { return fmt.Errorf("failed to create recording directory: %w", err) } fmt.Printf("Recording to directory: %s\n", recordingDir) - var wg sync.WaitGroup - errChan := make(chan error, len(cfg.Endpoints)) + errGroup, errCtx := errgroup.WithContext(ctx) // Start a proxy for each endpoint for _, endpoint := range cfg.Endpoints { - wg.Add(1) - go func(ep config.EndpointConfig) { - defer wg.Done() - - fmt.Printf("Starting server for %v\n", endpoint) - proxy := NewRecordingHTTPSProxy(&endpoint, recordingDir, redactor) - err := proxy.Start() - + ep := endpoint + errGroup.Go(func() error { + fmt.Printf("Starting server for %v\n", ep) + proxy := NewRecordingHTTPSProxy(&ep, recordingDir, redactor) + err := proxy.Start(errCtx) if err != nil { - errChan <- fmt.Errorf("proxy error for %s:%d: %w", + return fmt.Errorf("proxy error for %s:%d: %w", ep.TargetHost, ep.TargetPort, err) } - }(endpoint) - } - - // Wait for all proxies to complete (they shouldn't unless there's an error) - go func() { - wg.Wait() - close(errChan) - }() - - // Return the first error encountered, if any - for err := range errChan { - return err + return nil + }) } - // Block forever (or until interrupted) - select {} + return errGroup.Wait() } diff --git a/internal/record/recording_https_proxy.go b/internal/record/recording_https_proxy.go index 796cd84..be870f3 100644 --- a/internal/record/recording_https_proxy.go +++ b/internal/record/recording_https_proxy.go @@ -18,12 +18,14 @@ package record import ( "bytes" + "context" "fmt" "io" "net/http" "os" "path/filepath" "regexp" + "time" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" @@ -51,15 +53,33 @@ func (r *RecordingHTTPSProxy) ResetChain() { r.prevRequestSHA = store.HeadSHA } -func (r *RecordingHTTPSProxy) Start() error { +func (r *RecordingHTTPSProxy) Start(ctx context.Context) error { addr := fmt.Sprintf(":%d", r.config.SourcePort) server := &http.Server{ Addr: addr, Handler: http.HandlerFunc(r.handleRequest), } - if err := server.ListenAndServe(); err != nil { - panic(err) + + errCh := make(chan error, 1) + + go func() { + errCh <- server.ListenAndServe() + }() + + select { + case <-ctx.Done(): + fmt.Printf("Context cancelled, shutting down server on %s\n", addr) + cancelCtx, cancel := context.WithTimeout(context.Background(), time.Duration(r.config.ShutdownTimeoutSeconds)*time.Second) + defer cancel() + if err := server.Shutdown(cancelCtx); err != nil { + return fmt.Errorf("failed to shutdown server: %w", err) + } + case err := <-errCh: + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("failed to start recording HTTPS proxy: %w", err) + } } + return nil } @@ -91,7 +111,6 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } err = r.recordResponse(resp, reqHash, respBody) - if err != nil { fmt.Printf("Error recording response: %v\n", err) http.Error(w, fmt.Sprintf("Error recording response: %v", err), http.StatusInternalServerError) diff --git a/internal/replay/replay.go b/internal/replay/replay.go index 502642b..315c577 100644 --- a/internal/replay/replay.go +++ b/internal/replay/replay.go @@ -17,15 +17,17 @@ limitations under the License. package replay import ( + "context" "fmt" "os" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" + "golang.org/x/sync/errgroup" ) // Replay serves recorded responses for HTTP requests -func Replay(cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error { +func Replay(ctx context.Context, cfg *config.TestServerConfig, recordingDir string, redactor *redact.Redact) error { // Validate recording directory exists if _, err := os.Stat(recordingDir); os.IsNotExist(err) { return fmt.Errorf("recording directory does not exist: %s", recordingDir) @@ -33,26 +35,20 @@ func Replay(cfg *config.TestServerConfig, recordingDir string, redactor *redact. fmt.Printf("Replaying from directory: %s\n", recordingDir) - // Start a server for each endpoint - errChan := make(chan error, len(cfg.Endpoints)) + errGroup, errCtx := errgroup.WithContext(ctx) for _, endpoint := range cfg.Endpoints { - go func(ep config.EndpointConfig) { + ep := endpoint // Capture range variable + errGroup.Go(func() error { server := NewReplayHTTPServer(&endpoint, recordingDir, redactor) - err := server.Start() + err := server.Start(errCtx) if err != nil { - errChan <- fmt.Errorf("replay error for %s:%d: %w", + return fmt.Errorf("replay error for %s:%d: %w", ep.TargetHost, ep.TargetPort, err) } - }(endpoint) + return nil + }) } - // Return the first error encountered, if any - select { - case err := <-errChan: - return err - default: - // Block forever (or until interrupted) - select {} - } + return errGroup.Wait() } diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index bf4bedd..62c1fb0 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -17,10 +17,12 @@ limitations under the License. package replay import ( + "context" "fmt" "net/http" "os" "path/filepath" + "time" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" @@ -43,15 +45,33 @@ func NewReplayHTTPServer(cfg *config.EndpointConfig, recordingDir string, redact } } -func (r *ReplayHTTPServer) Start() error { +func (r *ReplayHTTPServer) Start(ctx context.Context) error { addr := fmt.Sprintf(":%d", r.config.SourcePort) server := &http.Server{ Addr: addr, Handler: http.HandlerFunc(r.handleRequest), } - if err := server.ListenAndServe(); err != nil { - panic(err) + + errCh := make(chan error, 1) + + go func() { + errCh <- server.ListenAndServe() + }() + + select { + case <-ctx.Done(): + fmt.Printf("Shutting down server on %s\n", addr) + cancelCtx, cancel := context.WithTimeout(context.Background(), time.Duration(r.config.ShutdownTimeoutSeconds)*time.Second) + defer cancel() + if err := server.Shutdown(cancelCtx); err != nil { + return fmt.Errorf("failed to shutdown server: %w", err) + } + case err := <-errCh: + if err != nil && err != http.ErrServerClosed { + return fmt.Errorf("failed to start replay server: %w", err) + } } + return nil } diff --git a/samples/google-genai.yml b/samples/google-genai.yml index fa361a7..7491205 100644 --- a/samples/google-genai.yml +++ b/samples/google-genai.yml @@ -11,6 +11,7 @@ endpoints: - header: X-Goog-Upload-Url regex: "^https://generativelanguage.googleapis.com/" replace: "http://localhost:1443" + shutdown_timeout_seconds: 5 - target_host: us-central1-aiplatform.googleapis.com target_port: 443 source_port: 1444 @@ -19,4 +20,4 @@ endpoints: redact_request_headers: - X-Goog-Api-Key - Authorization - + shutdown_timeout_seconds: 5