diff --git a/.golangci.yml b/.golangci.yml index 6ba45d2..72a2098 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,11 +7,9 @@ linters: - asasalint - asciicheck - bodyclose - - contextcheck - copyloopvar - dupl - durationcheck - - err113 - errcheck - errorlint - exhaustive diff --git a/config.go b/config.go index a8bd058..f4f8e6b 100644 --- a/config.go +++ b/config.go @@ -7,8 +7,11 @@ import ( "net/http" "os" "text/template" + "time" ) +const JsonContentType = "application/json" + type Config struct { RequestIDHeader string `json:"request_id_header,omitempty" yaml:"request_id_header,omitempty"` Routes []Route `json:"routes" yaml:"routes"` @@ -21,15 +24,16 @@ type Route struct { } type Response struct { - Headers http.Header `json:"headers,omitempty" yaml:"headers,omitempty"` - Repeat *int `json:"repeat,omitempty" yaml:"repeat,omitempty"` - Body string `json:"body,omitempty" yaml:"body,omitempty"` - File string `json:"file,omitempty" yaml:"file,omitempty"` - Code int `json:"code,omitempty" yaml:"code,omitempty"` - IsJSON bool `json:"is_json,omitempty" yaml:"is_json,omitempty"` + Headers http.Header `json:"headers,omitempty" yaml:"headers,omitempty"` + Repeat *int `json:"repeat,omitempty" yaml:"repeat,omitempty"` + Body string `json:"body,omitempty" yaml:"body,omitempty"` + File string `json:"file,omitempty" yaml:"file,omitempty"` + Code int `json:"code,omitempty" yaml:"code,omitempty"` + IsJSON bool `json:"is_json,omitempty" yaml:"is_json,omitempty"` + Delay time.Duration `json:"delay,omitempty" yaml:"delay,omitempty"` } -func responsesWriter(responses []Response, log *slog.Logger) http.HandlerFunc { +func responsesWriter(responses []Response) http.HandlerFunc { var i int return func(writer http.ResponseWriter, request *http.Request) { for { @@ -69,14 +73,17 @@ func responsesWriter(responses []Response, log *slog.Logger) http.HandlerFunc { } if response.IsJSON { if writer.Header().Get("Content-Type") == "" { - writer.Header().Set("Content-Type", "application/json") + writer.Header().Set("Content-Type", JsonContentType) } } + if response.Delay > 0 { + time.Sleep(response.Delay) + } writer.WriteHeader(response.Code) if len(data) > 0 { if _, err := writer.Write(data); err != nil { - log.ErrorContext(request.Context(), "sending response failed", slog.String("error", err.Error())) + slog.Error("Sending response failed", slog.String("error", err.Error())) } } return diff --git a/config.schema.json b/config.schema.json index 78b9119..86bc727 100644 --- a/config.schema.json +++ b/config.schema.json @@ -63,6 +63,11 @@ "repeat": { "description": "the number of repeats. Infinity if no set. Zero to skip. Or an exact number of repeats.", "type": "integer" + }, + "delay": { + "description": "The delay before sending the response", + "default": "0ms", + "type": "string" } } }, diff --git a/example_config.yaml b/example_config.yaml index 8c7929e..e6d3e53 100644 --- a/example_config.yaml +++ b/example_config.yaml @@ -44,3 +44,7 @@ routes: repeat: 1 - code: 404 body: user "{{.PathValue "id"}}" not found + - pattern: GET /delay/1min + responses: + - code: 200 + delay: 1m diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..20094e1 --- /dev/null +++ b/integration_test.go @@ -0,0 +1,121 @@ +package main + +import ( + "context" + "fmt" + "net" + "net/http" + "strconv" + "testing" + "time" +) + +// TestRun_ServerIntegration starts the real HTTP server, exercises endpoints, then shuts it down. +func TestRun_ServerIntegration(t *testing.T) { + t.Parallel() + + // Pick a free port + lc := net.ListenConfig{} + ln, err := lc.Listen(t.Context(), "tcp", ":0") + if err != nil { + t.Fatalf("listen :0: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + if err := ln.Close(); err != nil { + t.Fatalf("close listener: %v", err) + } + + configContent := `routes: + - pattern: /hello + responses: + - code: 200 + body: Hello + - pattern: /json + responses: + - code: 201 + body: '{"ok":true}' + is_json: true +` + cfgPath := writeConfig(t, configContent) + ctx, cancel := context.WithCancel(t.Context()) + + done := make(chan struct{}) + go func() { + defer close(done) + err := run(ctx, []string{"-c", cfgPath, "-p", strconv.Itoa(port)}) + if err != nil { + t.Errorf("run: %v", err) + } + }() + + client := &http.Client{Timeout: 2 * time.Second} + base := fmt.Sprintf("http://localhost:%d", port) + + deadline := time.Now().Add(5 * time.Second) + for { + if time.Now().After(deadline) { + t.Fatalf("server did not start in time on %s", base) + } + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/hello", http.NoBody) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err := client.Do(req) + if err != nil { + time.Sleep(50 * time.Millisecond) + continue + } + _ = resp.Body.Close() + break + } + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/hello", http.NoBody) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("GET /hello: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("/hello expected 200 got %d", resp.StatusCode) + } + _ = resp.Body.Close() + + req, err = http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/json", http.NoBody) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err = client.Do(req) + if err != nil { + t.Fatalf("GET /json: %v", err) + } + if resp.StatusCode != http.StatusCreated { + t.Fatalf("/json expected 201 got %d", resp.StatusCode) + } + if ct := resp.Header.Get("Content-Type"); ct != JsonContentType { + t.Fatalf("expected application/json got %q", ct) + } + _ = resp.Body.Close() + + req, err = http.NewRequestWithContext(t.Context(), http.MethodGet, base+"/", http.NoBody) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err = client.Do(req) + if err != nil { + t.Fatalf("GET /: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("/ expected 404 got %d", resp.StatusCode) + } + _ = resp.Body.Close() + + cancel() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("server did not exit in time") + } +} diff --git a/log.go b/log.go index 711fc97..83b9494 100644 --- a/log.go +++ b/log.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" "net/http" + "os" ) type wrapper struct { @@ -29,7 +30,7 @@ func (w *wrapper) Write(b []byte) (int, error) { var _ http.ResponseWriter = &wrapper{} -func StructuredLogger(log *slog.Logger, reqIDHeader string, next http.HandlerFunc) http.HandlerFunc { +func StructuredLogger(reqIDHeader string, next http.HandlerFunc) http.HandlerFunc { return func(writer http.ResponseWriter, request *http.Request) { wr := &wrapper{writer: writer} next.ServeHTTP(wr, request) @@ -39,7 +40,7 @@ func StructuredLogger(log *slog.Logger, reqIDHeader string, next http.HandlerFun scheme = "https" } - log.LogAttrs(request.Context(), slog.LevelInfo, "request completed", + slog.LogAttrs(request.Context(), slog.LevelInfo, "request completed", slog.String("http_scheme", scheme), slog.String("http_proto", request.Proto), slog.String("http_method", request.Method), @@ -52,3 +53,11 @@ func StructuredLogger(log *slog.Logger, reqIDHeader string, next http.HandlerFun ) } } + +func InitLogger() { + slog.SetLogLoggerLevel(slog.LevelDebug) + logHandler := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + slog.SetDefault(slog.New(logHandler)) +} diff --git a/main.go b/main.go index 4394987..e1b23f9 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "net" "net/http" "os" "os/signal" @@ -17,41 +18,96 @@ import ( ) func main() { - log := slog.New(slog.NewJSONHandler(os.Stderr, nil)) - conf := flag.StringP("config", "c", "config.yaml", "config file") - port := flag.IntP("port", "p", 8080, "http port") - flag.Parse() + InitLogger() - if v := os.Getenv("CONFIG"); v != "" && !flag.Lookup("config").Changed { - conf = &v + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + go func() { + <-sig + cancel() + }() + + var args []string + if len(os.Args) > 1 { + args = os.Args[1:] + } + + if err := run(ctx, args); err != nil { + slog.Error("Exiting with error", slog.String("error", err.Error())) + _ = os.Stderr.Sync() // ensure all output is flushed before exiting + cancel() + os.Exit(1) //nolint:gocritic // cancel is called manually } - if v := os.Getenv("PORT"); v != "" && !flag.Lookup("port").Changed { +} + +func getConfig(args []string) (config *Config, retError error) { + fs := flag.NewFlagSet("mock-http-server", flag.ContinueOnError) + fs.SetOutput(os.Stderr) + var port int + if v, ok := os.LookupEnv("PORT"); ok { p, err := strconv.Atoi(v) if err != nil { - log.Error(fmt.Sprintf("wrong value %q of env variable PORT", v), slog.String("error", err.Error())) - os.Exit(1) + return nil, fmt.Errorf("wrong value %q of env variable PORT: %w", v, err) } - port = &p + port = p + } + + configPath := fs.StringP("config", "c", "config.yaml", "config file") + fsPort := fs.IntP("port", "p", 8080, "http port") + + if err := fs.Parse(args); err != nil { + return nil, fmt.Errorf("parsing commandline arguments failed: %w", err) + } + + if fs.Lookup("port").Changed { + port = *fsPort + } + + if v, ok := os.LookupEnv("CONFIG"); ok && !fs.Lookup("config").Changed { + *configPath = v } - f, err := os.Open(*conf) + f, err := os.Open(*configPath) if err != nil { - log.Error(fmt.Sprintf("wrong config file %q", *conf), slog.String("error", err.Error())) - os.Exit(1) + return nil, fmt.Errorf("wrong config file %q: %w", *configPath, err) } + defer func(f *os.File) { + if err := f.Close(); err != nil { + retError = errors.Join(retError, fmt.Errorf("failed to close config file %q: %w", *configPath, err)) + } + }(f) - var config Config + var conf Config d := yaml.NewDecoder(f) d.KnownFields(true) - if err := d.Decode(&config); err != nil { - log.Error("decoding config failed", slog.String("error", err.Error())) - os.Exit(1) + if err := d.Decode(&conf); err != nil { + return nil, fmt.Errorf("decoding config failed: %w", err) + } + if port != 0 { + conf.Port = port + } + if conf.Port == 0 { + conf.Port = *fsPort } - if config.Port != 0 { - port = &config.Port + if conf.RequestIDHeader == "" { + conf.RequestIDHeader = "X-Request-ID" } - if config.RequestIDHeader == "" { - config.RequestIDHeader = "X-Request-ID" + + return &conf, nil +} + +// run executes the main logic. When startServer is false it stops after constructing the server and +// returns immediately. +// Returns (exitCode, serverAddress, error) +func run( + ctx context.Context, + args []string, +) (retError error) { + config, err := getConfig(args) + if err != nil { + return err } mux := http.NewServeMux() @@ -63,53 +119,47 @@ func main() { if route.Pattern == "/" { isRootRegistered = true } - mux.HandleFunc( - route.Pattern, - StructuredLogger(log, config.RequestIDHeader, responsesWriter(route.Responses, log)), - ) + mux.HandleFunc(route.Pattern, StructuredLogger( + config.RequestIDHeader, + responsesWriter(route.Responses), + )) } if !isRootRegistered { - mux.HandleFunc("/", StructuredLogger(log, config.RequestIDHeader, http.NotFound)) + mux.HandleFunc("/", StructuredLogger(config.RequestIDHeader, http.NotFound)) } - server := &http.Server{ - Addr: fmt.Sprintf(":%d", *port), - Handler: mux, - ReadHeaderTimeout: 1 * time.Second, + addr := fmt.Sprintf(":%d", config.Port) + server := &http.Server{Addr: addr, Handler: mux, ReadHeaderTimeout: 1 * time.Second} + server.BaseContext = func(_ net.Listener) context.Context { + return ctx } - serverCtx, serverStopCtx := context.WithCancel(context.Background()) - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - go func() { - <-sig - - // Shutdown signal with a grace period of 30 seconds - shutdownCtx, shutdownCancelCtx := context.WithTimeout(serverCtx, 30*time.Second) - defer shutdownCancelCtx() + serverCtx, serverCancel := context.WithCancel(context.Background()) + defer serverCancel() + go func() { + defer serverCancel() + <-ctx.Done() + slog.Info("Shutdown signal received") + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer shutdownCancel() go func() { - <-shutdownCtx.Done() - log.Info("graceful shutdown") - if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) { - log.Error("graceful shutdown timed out.. forcing exit.") - os.Exit(1) + if err := server.Shutdown(shutdownCtx); err != nil { + if !errors.Is(err, context.DeadlineExceeded) { + retError = errors.Join(retError, fmt.Errorf("server shutdown failed: %w", err)) + } } + shutdownCancel() }() - - // Trigger graceful shutdown - if err := server.Shutdown(shutdownCtx); err != nil { - log.Error("graceful shutdown failed", slog.String("error", err.Error())) + <-shutdownCtx.Done() + if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) { + retError = errors.Join(retError, errors.New("server shutdown timed out")) } - serverStopCtx() }() - // Run the server - log.Info(fmt.Sprintf("Listen on http://localhost:%d", *port)) + slog.Info("Starting server", slog.String("addr", fmt.Sprintf("http://localhost:%d", config.Port))) if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Error("starting failed", slog.String("error", err.Error())) - os.Exit(1) + return fmt.Errorf("starting failed: %w", err) } - - // Wait for server context to be stopped <-serverCtx.Done() + return } diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..b151bda --- /dev/null +++ b/main_test.go @@ -0,0 +1,100 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// helper to write a temporary config file +func writeConfig(tb testing.TB, content string) string { + tb.Helper() + dir := tb.TempDir() + path := filepath.Join(dir, "config.yaml") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + tb.Fatalf("write config: %v", err) + } + return path +} + +func TestRun_ConfigFileNotFound(t *testing.T) { + t.Parallel() + + err := run(t.Context(), []string{"-c", "no_such_file.yaml"}) + if err == nil { + t.Fatalf("expected error") + } + if !strings.Contains(err.Error(), "no_such_file.yaml") { + t.Fatalf("expected config file name in error, got %v", err) + } +} + +func TestRun_InvalidPortEnv(t *testing.T) { + cfgPath := writeConfig(t, "routes: []\n") + t.Setenv("PORT", "abc") + _, err := getConfig([]string{"-c", cfgPath}) + if err == nil { + t.Fatalf("expected invalid port error") + } + if !strings.Contains(err.Error(), "PORT") { + t.Fatalf("expected PORT in error, got %v", err) + } +} + +func TestRun_EnvPortUsedWhenNoFlag(t *testing.T) { + cfgPath := writeConfig(t, "routes: []\n") + t.Setenv("PORT", "65001") + cfg, err := getConfig([]string{"-c", cfgPath}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Port != 65001 { + t.Fatalf("expected port 65001 got %d", cfg.Port) + } +} + +func TestRun_FlagPortOverridesEnv(t *testing.T) { + cfgPath := writeConfig(t, "routes: []\n") + t.Setenv("PORT", "63000") + cfg, err := getConfig([]string{"-c", cfgPath, "-p", "62000"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Port != 62000 { + t.Fatalf("expected port 62000 got %d", cfg.Port) + } +} + +func TestRun_ConfigPortOverridesFlagAndEnv(t *testing.T) { + cfgPath := writeConfig(t, `port: 64000 +routes: + - responses: + - body: ok + code: 200 +`) + t.Setenv("PORT", "63000") + cfg, err := getConfig( + []string{"-c", cfgPath, "-p", "62000"}, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Port != 62000 { + t.Fatalf("expected port 62000 got %d", cfg.Port) + } +} + +func TestRun_DecodeErrorUnknownField(t *testing.T) { + t.Parallel() + + cfgPath := writeConfig(t, "unknown_field: 1\nroutes: []\n") + _, err := getConfig([]string{"-c", cfgPath}) + if err == nil { + t.Fatalf("expected decode failure") + } + // yaml KnownFields error strings vary; just ensure mentions unknown or field + if !strings.Contains(err.Error(), "unknown") && !strings.Contains(err.Error(), "field") { + t.Fatalf("unexpected error message: %v", err) + } +} diff --git a/responses_writer_test.go b/responses_writer_test.go index acc6a9e..0e56ce5 100644 --- a/responses_writer_test.go +++ b/responses_writer_test.go @@ -13,8 +13,15 @@ import ( ) // helper to build logger writing into a buffer -func testLogger(buf io.Writer) *slog.Logger { - return slog.New(slog.NewJSONHandler(buf, nil)) +func setTestLogger(tb testing.TB, buf io.Writer) { + tb.Helper() + oldLogger := slog.Default() + tb.Cleanup(func() { + slog.SetDefault(oldLogger) + }) + + newLogger := slog.New(slog.NewJSONHandler(buf, nil)) + slog.SetDefault(newLogger) } func TestResponsesWriter_FileAndJSONHeader(t *testing.T) { @@ -29,14 +36,14 @@ func TestResponsesWriter_FileAndJSONHeader(t *testing.T) { resp := Response{File: fname, Code: 201, IsJSON: true, Headers: http.Header{"X-Test": {"yes"}}} rec := httptest.NewRecorder() - rw := responsesWriter([]Response{resp}, testLogger(io.Discard)) + rw := responsesWriter([]Response{resp}) r := httptest.NewRequest(http.MethodGet, "/", http.NoBody) rw(rec, r) if rec.Code != 201 { t.Fatalf("expected status 201, got %d", rec.Code) } - if got := rec.Header().Get("Content-Type"); got != "application/json" { + if got := rec.Header().Get("Content-Type"); got != JsonContentType { t.Fatalf("expected application/json, got %q", got) } if got := rec.Header().Get("X-Test"); got != "yes" { @@ -52,7 +59,7 @@ func TestResponsesWriter_RepeatLogic(t *testing.T) { repeat := 2 responses := []Response{{Body: "first", Code: 200, Repeat: &repeat}, {Body: "second", Code: 202}} - rw := responsesWriter(responses, testLogger(io.Discard)) + rw := responsesWriter(responses) // first call rec1 := httptest.NewRecorder() @@ -77,7 +84,7 @@ func TestResponsesWriter_RepeatLogic(t *testing.T) { func TestResponsesWriter_NotFoundWhenExhausted(t *testing.T) { t.Parallel() - rw := responsesWriter([]Response{}, testLogger(io.Discard)) + rw := responsesWriter([]Response{}) rec := httptest.NewRecorder() rw(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) if rec.Code != http.StatusNotFound { @@ -90,7 +97,7 @@ func TestResponsesWriter_FileReadError(t *testing.T) { responses := []Response{{File: "no_such_file", Code: 200}} rec := httptest.NewRecorder() - rw := responsesWriter(responses, testLogger(io.Discard)) + rw := responsesWriter(responses) rw(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) if rec.Code != http.StatusInternalServerError { t.Fatalf("expected 500, got %d", rec.Code) @@ -105,7 +112,7 @@ func TestResponsesWriter_JSONDoesNotOverrideExistingContentType(t *testing.T) { responses := []Response{{Body: "{}", Code: 200, IsJSON: true, Headers: http.Header{"Content-Type": {"text/plain"}}}} rec := httptest.NewRecorder() - rw := responsesWriter(responses, testLogger(io.Discard)) + rw := responsesWriter(responses) rw(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) if got := rec.Header().Get("Content-Type"); got != "text/plain" { t.Fatalf("expected text/plain kept, got %q", got) @@ -138,14 +145,14 @@ func TestStructuredLogger_BasicFields(t *testing.T) { t.Parallel() var buf strings.Builder - logger := testLogger(&buf) + setTestLogger(t, &buf) next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Handled", "1") w.WriteHeader(http.StatusNoContent) }) - h := StructuredLogger(logger, "X-Request-ID", next) + h := StructuredLogger("X-Request-ID", next) rec := httptest.NewRecorder() // Use relative URL so that when we manually set Host, the constructed URI is correct. // If we used an absolute URL, Host would be duplicated in the URI @@ -185,9 +192,9 @@ func TestResponsesWriter_JSONBodySetsContentType(t *testing.T) { responses := []Response{{Body: "{\"k\":1}", Code: 200, IsJSON: true}} rec := httptest.NewRecorder() - rw := responsesWriter(responses, testLogger(io.Discard)) + rw := responsesWriter(responses) rw(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) - if ct := rec.Header().Get("Content-Type"); ct != "application/json" { + if ct := rec.Header().Get("Content-Type"); ct != JsonContentType { t.Fatalf("expected application/json, got %q", ct) } } @@ -198,7 +205,7 @@ func TestResponsesWriter_RepeatZeroSkips(t *testing.T) { zero := 0 responses := []Response{{Body: "skip", Code: 200, Repeat: &zero}, {Body: "use", Code: 201}} rec := httptest.NewRecorder() - rw := responsesWriter(responses, testLogger(io.Discard)) + rw := responsesWriter(responses) rw(rec, httptest.NewRequest(http.MethodGet, "/", http.NoBody)) if rec.Code != 201 || strings.TrimSpace(rec.Body.String()) != "use" { t.Fatalf("expected second response, got %d %q", rec.Code, rec.Body.String())