diff --git a/oapi_validate.go b/oapi_validate.go index 9e8b126..4bcd795 100644 --- a/oapi_validate.go +++ b/oapi_validate.go @@ -8,6 +8,7 @@ package nethttpmiddleware import ( + "context" "errors" "fmt" "log" @@ -21,8 +22,58 @@ import ( ) // ErrorHandler is called when there is an error in validation +// +// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence. +// +// Deprecated: it's recommended you migrate to the ErrorHandlerWithOpts, as it provides more control over how to handle an error that occurs, including giving direct access to the `error` itself. There are no plans to remove this method. type ErrorHandler func(w http.ResponseWriter, message string, statusCode int) +// ErrorHandlerWithOpts is called when there is an error in validation, with more information about the `error` that occurred and which request is currently being processed. +// +// If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence. +// +// NOTE that this should ideally be used instead of ErrorHandler +type ErrorHandlerWithOpts func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts ErrorHandlerOpts) + +// ErrorHandlerOpts contains additional options that are passed to the `ErrorHandlerWithOpts` function in the case of an error being returned by the middleware +type ErrorHandlerOpts struct { + // Error is the underlying error that triggered this error handler to be executed. + // + // Known error types: + // + // - `*openapi3filter.SecurityRequirementsError` - if the `AuthenticationFunc` has failed to authenticate the request + // - `*openapi3filter.RequestError` - if a bad request has been made + // + // Additionally, if you have set `openapi3filter.Options#MultiError`: + // + // - `openapi3.MultiError` (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) + Error error + + // StatusCode indicates the HTTP Status Code that the OpenAPI validation middleware _suggests_ is returned to the user. + // + // NOTE that this is very much a suggestion, and can be overridden if you believe you have a better approach. + StatusCode int + + // MatchedRoute is the underlying path that this request is being matched against. + // + // This is the route according to the OpenAPI validation middleware, and can be used in addition to/instead of the `http.Request` + // + // NOTE that this will be nil if there is no matched route (i.e. a request has been sent to an endpoint not in the OpenAPI spec) + MatchedRoute *ErrorHandlerOptsMatchedRoute +} + +type ErrorHandlerOptsMatchedRoute struct { + // Route indicates the Route that this error is received by. + // + // This can be used in addition to/instead of the `http.Request`. + Route *routers.Route + + // PathParams are any path parameters that are determined from the request. + // + // This can be used in addition to/instead of the `http.Request`. + PathParams map[string]string +} + // MultiErrorHandler is called when the OpenAPI filter returns an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) type MultiErrorHandler func(openapi3.MultiError) (int, error) @@ -32,11 +83,21 @@ type Options struct { Options openapi3filter.Options // ErrorHandler is called when a validation error occurs. // + // If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence. + // // If not provided, `http.Error` will be called ErrorHandler ErrorHandler + + // ErrorHandlerWithOpts is called when there is an error in validation. + // + // If both an `ErrorHandlerWithOpts` and `ErrorHandler` are set, the `ErrorHandlerWithOpts` takes precedence. + ErrorHandlerWithOpts ErrorHandlerWithOpts + // MultiErrorHandler is called when there is an openapi3.MultiError (https://pkg.go.dev/github.com/getkin/kin-openapi/openapi3#MultiError) returned by the `openapi3filter`. // // If not provided `defaultMultiErrorHandler` will be used. + // + // Does not get called when using `ErrorHandlerWithOpts` MultiErrorHandler MultiErrorHandler // SilenceServersWarning allows silencing a warning for https://github.com/deepmap/oapi-codegen/issues/882 that reports when an OpenAPI spec has `spec.Servers != nil` SilenceServersWarning bool @@ -62,24 +123,96 @@ func OapiRequestValidatorWithOptions(spec *openapi3.T, options *Options) func(ne return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - - // validate request - if statusCode, err := validateRequest(r, router, options); err != nil { - if options != nil && options.ErrorHandler != nil { - options.ErrorHandler(w, err.Error(), statusCode) - } else { - http.Error(w, err.Error(), statusCode) - } - return + if options == nil { + performRequestValidationForErrorHandler(next, w, r, router, options, http.Error) + } else if options.ErrorHandlerWithOpts != nil { + performRequestValidationForErrorHandlerWithOpts(next, w, r, router, options) + } else if options.ErrorHandler != nil { + performRequestValidationForErrorHandler(next, w, r, router, options, options.ErrorHandler) + } else { + // NOTE that this shouldn't happen, but let's be sure that we always end up calling the default error handler if no other handler is defined + performRequestValidationForErrorHandler(next, w, r, router, options, http.Error) } - - // serve - next.ServeHTTP(w, r) }) } } +func performRequestValidationForErrorHandler(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options, errorHandler ErrorHandler) { + // validate request + statusCode, err := validateRequest(r, router, options) + if err == nil { + // serve + next.ServeHTTP(w, r) + return + } + + errorHandler(w, err.Error(), statusCode) +} + +// Note that this is an inline-and-modified version of `validateRequest`, with a simplified control flow and providing full access to the `error` for the `ErrorHandlerWithOpts` function. +func performRequestValidationForErrorHandlerWithOpts(next http.Handler, w http.ResponseWriter, r *http.Request, router routers.Router, options *Options) { + // Find route + route, pathParams, err := router.FindRoute(r) + if err != nil { + errOpts := ErrorHandlerOpts{ + // MatchedRoute will be nil, as we've not matched a route we know about + Error: err, + StatusCode: http.StatusNotFound, + } + + options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts) + return + } + + errOpts := ErrorHandlerOpts{ + MatchedRoute: &ErrorHandlerOptsMatchedRoute{ + Route: route, + PathParams: pathParams, + }, + // other options will be added before executing + } + + // Validate request + requestValidationInput := &openapi3filter.RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + } + + if options != nil { + requestValidationInput.Options = &options.Options + } + + err = openapi3filter.ValidateRequest(r.Context(), requestValidationInput) + if err == nil { + // it's a valid request, so serve it + next.ServeHTTP(w, r) + return + } + + switch e := err.(type) { + case openapi3.MultiError: + errOpts.Error = e + errOpts.StatusCode = determineStatusCodeForMultiError(e) + case *openapi3filter.RequestError: + // We've got a bad request + errOpts.Error = e + errOpts.StatusCode = http.StatusBadRequest + case *openapi3filter.SecurityRequirementsError: + errOpts.Error = e + errOpts.StatusCode = http.StatusUnauthorized + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + // return http.StatusInternalServerError, + errOpts.Error = fmt.Errorf("error validating route: %w", e) + errOpts.StatusCode = http.StatusUnauthorized + } + + options.ErrorHandlerWithOpts(r.Context(), w, r, errOpts) +} + // validateRequest is called from the middleware above and actually does the work // of validating a request. func validateRequest(r *http.Request, router routers.Router, options *Options) (int, error) { @@ -147,3 +280,35 @@ func getMultiErrorHandlerFromOptions(options *Options) MultiErrorHandler { func defaultMultiErrorHandler(me openapi3.MultiError) (int, error) { return http.StatusBadRequest, me } + +func determineStatusCodeForMultiError(errs openapi3.MultiError) int { + numRequestErrors := 0 + numSecurityRequirementsErrors := 0 + + for _, err := range errs { + switch err.(type) { + case *openapi3filter.RequestError: + numRequestErrors++ + case *openapi3filter.SecurityRequirementsError: + numSecurityRequirementsErrors++ + default: + // if we have /any/ unknown error types, we should suggest returning an HTTP 500 Internal Server Error + return http.StatusInternalServerError + } + } + + if numRequestErrors > 0 && numSecurityRequirementsErrors > 0 { + return http.StatusInternalServerError + } + + if numRequestErrors > 0 { + return http.StatusBadRequest + } + + if numSecurityRequirementsErrors > 0 { + return http.StatusUnauthorized + } + + // we shouldn't hit this, but to be safe, return an HTTP 500 Internal Server Error if we don't have any cases above + return http.StatusInternalServerError +} diff --git a/oapi_validate_example_test.go b/oapi_validate_example_test.go index d0cab9a..b905350 100644 --- a/oapi_validate_example_test.go +++ b/oapi_validate_example_test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/http/httptest" + "reflect" "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" @@ -330,3 +331,490 @@ paths: // Received an HTTP 400 response. Expected HTTP 400 // Response body: This was rewritten by the ErrorHandler } + +func ExampleOapiRequestValidatorWithOptions_withErrorHandlerWithOpts() { + rawSpec := ` +openapi: "3.0.0" +info: + version: 1.0.0 + title: TestServer +servers: + - url: http://example.com/ +paths: + /resource: + post: + operationId: createResource + responses: + '204': + description: No content + requestBody: + required: true + content: + application/json: + schema: + properties: + id: + type: string + minLength: 100 + name: + type: string + enum: + - Marcin + additionalProperties: false + /protected_resource: + get: + operationId: getProtectedResource + security: + - BearerAuth: + - someScope + - BasicAuth: [] + responses: + '204': + description: no content +components: + securitySchemes: + BearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + BasicAuth: + type: http + scheme: basic +` + + must := func(err error) { + if err != nil { + panic(err) + } + } + + use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler { + var s http.Handler + s = r + + for _, mw := range middlewares { + s = mw(s) + } + + return s + } + + logResponseBody := func(rr *httptest.ResponseRecorder) { + if rr.Result().Body != nil { + data, _ := io.ReadAll(rr.Result().Body) + if len(data) > 0 { + fmt.Printf("Response body: %s", data) + } + } + } + + spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec)) + must(err) + + // NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec + // See also: Options#SilenceServersWarning + spec.Servers = nil + + router := http.NewServeMux() + + router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("%s /resource was called\n", r.Method) + + if r.Method == http.MethodPost { + w.WriteHeader(http.StatusNoContent) + return + } + + w.WriteHeader(http.StatusMethodNotAllowed) + }) + + router.HandleFunc("/protected_resource", func(w http.ResponseWriter, r *http.Request) { + // NOTE that we're setting up our `authenticationFunc` (below) to /never/ allow any requests in - so if we get a response from this endpoint, our `authenticationFunc` hasn't correctly worked + + if r.Method == http.MethodGet { + w.WriteHeader(http.StatusNoContent) + return + } + + w.WriteHeader(http.StatusMethodNotAllowed) + }) + + authenticationFunc := func(ctx context.Context, ai *openapi3filter.AuthenticationInput) error { + fmt.Printf("`AuthenticationFunc` was called for securitySchemeName=%s\n", ai.SecuritySchemeName) + return fmt.Errorf("this check always fails - don't let anyone in!") + } + + errorHandlerFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) { + err := opts.Error + + if opts.MatchedRoute == nil { + fmt.Printf("ErrorHandlerWithOpts: An HTTP %d was returned by the middleware with error message: %s\n", opts.StatusCode, err.Error()) + + // NOTE that you may want to override the default (an HTTP 400 Bad Request) to an HTTP 404 Not Found (or maybe an HTTP 405 Method Not Allowed, depending on what the requested resource was) + http.Error(w, fmt.Sprintf("No route was found (according to ErrorHandlerWithOpts), and we changed the HTTP status code to %d", http.StatusNotFound), http.StatusNotFound) + return + } + + switch e := err.(type) { + case *openapi3filter.SecurityRequirementsError: + out := fmt.Sprintf("A SecurityRequirementsError was returned when attempting to authenticate the request to %s %s against %d Security Schemes: %s\n", opts.MatchedRoute.Route.Method, opts.MatchedRoute.Route.Path, len(e.SecurityRequirements), e.Error()) + for _, sr := range e.SecurityRequirements { + for k, v := range sr { + out += fmt.Sprintf("- %s: %v\n", k, v) + } + } + + fmt.Printf("ErrorHandlerWithOpts: %s\n", out) + + http.Error(w, "You're not allowed!", opts.StatusCode) + return + case *openapi3filter.RequestError: + out := fmt.Sprintf("A RequestError was returned when attempting to validate the request to %s %s: %s\n", opts.MatchedRoute.Route.Method, opts.MatchedRoute.Route.Path, e.Error()) + + if e.RequestBody != nil { + out += "This operation has a request body, which was " + if !e.RequestBody.Required { + out += "not " + } + out += "required\n" + } + + if childErr := e.Unwrap(); childErr != nil { + out += "There was a child error, which was " + switch e := childErr.(type) { + case *openapi3.SchemaError: + out += "a SchemaError, which failed to validate on the " + e.SchemaField + " field" + default: + out += "an unknown type (" + reflect.TypeOf(e).String() + ")" + } + } + + fmt.Printf("ErrorHandlerWithOpts: %s\n", out) + + http.Error(w, "A bad request was made - but I'm not going to tell you where or how", opts.StatusCode) + return + } + + http.Error(w, err.Error(), opts.StatusCode) + } + + // create middleware + mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{ + Options: openapi3filter.Options{ + AuthenticationFunc: authenticationFunc, + }, + ErrorHandlerWithOpts: errorHandlerFunc, + }) + + // then wire it in + server := use(router, mw) + + // ================================================================================ + fmt.Println("# A request that is malformed is rejected with HTTP 400 Bad Request (with no request body), and is then logged by the ErrorHandlerWithOpts") + + req, err := http.NewRequest(http.MethodPost, "/resource", nil) + must(err) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 400\n", rr.Code) + logResponseBody(rr) + fmt.Println() + + // ================================================================================ + fmt.Println("# A request that is malformed is rejected with HTTP 400 Bad Request (with an invalid request body), and is then logged by the ErrorHandlerWithOpts") + + body := map[string]string{ + "id": "not-long-enough", + } + + data, err := json.Marshal(body) + must(err) + + req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader(data)) + must(err) + req.Header.Set("Content-Type", "application/json") + + rr = httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 400\n", rr.Code) + logResponseBody(rr) + fmt.Println() + + // ================================================================================ + fmt.Println("# A request that to an unknown path is rejected with HTTP 404 Not Found, and is then logged by the ErrorHandlerWithOpts") + + req, err = http.NewRequest(http.MethodGet, "/not-a-real-path", nil) + must(err) + + rr = httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 404\n", rr.Code) + logResponseBody(rr) + fmt.Println() + + // ================================================================================ + fmt.Println("# A request to an authenticated endpoint must go through an `AuthenticationFunc`, and if it fails, an HTTP 401 is returned") + + req, err = http.NewRequest(http.MethodGet, "/protected_resource", nil) + must(err) + + rr = httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 401\n", rr.Code) + logResponseBody(rr) + fmt.Println() + + // Output: + // # A request that is malformed is rejected with HTTP 400 Bad Request (with no request body), and is then logged by the ErrorHandlerWithOpts + // ErrorHandlerWithOpts: A RequestError was returned when attempting to validate the request to POST /resource: request body has an error: value is required but missing + // This operation has a request body, which was required + // There was a child error, which was an unknown type (*errors.errorString) + // Received an HTTP 400 response. Expected HTTP 400 + // Response body: A bad request was made - but I'm not going to tell you where or how + // + // # A request that is malformed is rejected with HTTP 400 Bad Request (with an invalid request body), and is then logged by the ErrorHandlerWithOpts + // ErrorHandlerWithOpts: A RequestError was returned when attempting to validate the request to POST /resource: request body has an error: doesn't match schema: Error at "/id": minimum string length is 100 + // Schema: + // { + // "minLength": 100, + // "type": "string" + // } + // + // Value: + // "not-long-enough" + // + // This operation has a request body, which was required + // There was a child error, which was a SchemaError, which failed to validate on the minLength field + // Received an HTTP 400 response. Expected HTTP 400 + // Response body: A bad request was made - but I'm not going to tell you where or how + // + // # A request that to an unknown path is rejected with HTTP 404 Not Found, and is then logged by the ErrorHandlerWithOpts + // ErrorHandlerWithOpts: An HTTP 404 was returned by the middleware with error message: no matching operation was found + // Received an HTTP 404 response. Expected HTTP 404 + // Response body: No route was found (according to ErrorHandlerWithOpts), and we changed the HTTP status code to 404 + // + // # A request to an authenticated endpoint must go through an `AuthenticationFunc`, and if it fails, an HTTP 401 is returned + // `AuthenticationFunc` was called for securitySchemeName=BearerAuth + // `AuthenticationFunc` was called for securitySchemeName=BasicAuth + // ErrorHandlerWithOpts: A SecurityRequirementsError was returned when attempting to authenticate the request to GET /protected_resource against 2 Security Schemes: security requirements failed: this check always fails - don't let anyone in! | this check always fails - don't let anyone in! + // - BearerAuth: [someScope] + // - BasicAuth: [] + // + // Received an HTTP 401 response. Expected HTTP 401 + // Response body: You're not allowed! +} + +func ExampleOapiRequestValidatorWithOptions_withErrorHandlerWithOptsAndMultiError() { + rawSpec := ` +openapi: "3.0.0" +info: + version: 1.0.0 + title: TestServer +servers: + - url: http://example.com/ +paths: + /resource: + post: + operationId: createResource + responses: + '204': + description: No content + requestBody: + required: true + content: + application/json: + schema: + properties: + id: + type: string + minLength: 100 + name: + type: string + enum: + - Marcin + additionalProperties: false +` + + must := func(err error) { + if err != nil { + panic(err) + } + } + + use := func(r *http.ServeMux, middlewares ...func(next http.Handler) http.Handler) http.Handler { + var s http.Handler + s = r + + for _, mw := range middlewares { + s = mw(s) + } + + return s + } + + logResponseBody := func(rr *httptest.ResponseRecorder) { + if rr.Result().Body != nil { + data, _ := io.ReadAll(rr.Result().Body) + if len(data) > 0 { + fmt.Printf("Response body: %s", data) + } + } + } + + spec, err := openapi3.NewLoader().LoadFromData([]byte(rawSpec)) + must(err) + + // NOTE that we need to make sure that the `Servers` aren't set, otherwise the OpenAPI validation middleware will validate that the `Host` header (of incoming requests) are targeting known `Servers` in the OpenAPI spec + // See also: Options#SilenceServersWarning + spec.Servers = nil + + router := http.NewServeMux() + + router.HandleFunc("/resource", func(w http.ResponseWriter, r *http.Request) { + fmt.Printf("%s /resource was called\n", r.Method) + + if r.Method == http.MethodPost { + w.WriteHeader(http.StatusNoContent) + return + } + + w.WriteHeader(http.StatusMethodNotAllowed) + }) + + errorHandlerFunc := func(ctx context.Context, w http.ResponseWriter, r *http.Request, opts middleware.ErrorHandlerOpts) { + err := opts.Error + + if opts.MatchedRoute == nil { + fmt.Printf("ErrorHandlerWithOpts: An HTTP %d was returned by the middleware with error message: %s\n", opts.StatusCode, err.Error()) + + // NOTE that you may want to override the default (an HTTP 400 Bad Request) to an HTTP 404 Not Found (or maybe an HTTP 405 Method Not Allowed, depending on what the requested resource was) + http.Error(w, fmt.Sprintf("No route was found (according to ErrorHandlerWithOpts), and we changed the HTTP status code to %d", http.StatusNotFound), http.StatusNotFound) + return + } + + switch e := err.(type) { + // NOTE that when it's a MultiError, there's more work needed here + case openapi3.MultiError: + var re *openapi3filter.RequestError + if e.As(&re) { + out := fmt.Sprintf("A MultiError was encountered, which contained a RequestError: %s", re) + + if re.Err != nil { + out += ", which inside it has a error of type (" + reflect.TypeOf(e).String() + ")" + } + + fmt.Printf("ErrorHandlerWithOpts: %s\n", out) + + http.Error(w, "There was a bad request", opts.StatusCode) + return + } + + var se *openapi3filter.SecurityRequirementsError + if e.As(&se) { + out := fmt.Sprintf("A MultiError was encountered, which contained a SecurityRequirementsError: %s", re) + + if len(se.Errors) > 0 { + out += fmt.Sprintf(", which contains %d child errors", len(se.Errors)) + } + + fmt.Printf("ErrorHandlerWithOpts: %s\n", out) + + http.Error(w, "There was an unauthorized request", opts.StatusCode) + return + } + } + + http.Error(w, err.Error(), opts.StatusCode) + } + + // create middleware + mw := middleware.OapiRequestValidatorWithOptions(spec, &middleware.Options{ + Options: openapi3filter.Options{ + // make sure that multiple errors in a given request are returned + MultiError: true, + }, + ErrorHandlerWithOpts: errorHandlerFunc, + }) + + // then wire it in + server := use(router, mw) + + // ================================================================================ + fmt.Println("# A request that is malformed is rejected with HTTP 400 Bad Request (with no request body), and is then logged by the ErrorHandlerWithOpts") + + req, err := http.NewRequest(http.MethodPost, "/resource", nil) + must(err) + req.Header.Set("Content-Type", "application/json") + + rr := httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 400\n", rr.Code) + logResponseBody(rr) + fmt.Println() + + // ================================================================================ + fmt.Println("# A request that is malformed is rejected with HTTP 400 Bad Request (with an invalid request body, with multiple issues), and is then logged by the ErrorHandlerWithOpts") + + body := map[string]string{ + "id": "not-long-enough", + "name": "Jamie", + } + + data, err := json.Marshal(body) + must(err) + + req, err = http.NewRequest(http.MethodPost, "/resource", bytes.NewReader(data)) + must(err) + req.Header.Set("Content-Type", "application/json") + + rr = httptest.NewRecorder() + + server.ServeHTTP(rr, req) + + fmt.Printf("Received an HTTP %d response. Expected HTTP 400\n", rr.Code) + logResponseBody(rr) + fmt.Println() + + // Output: + // # A request that is malformed is rejected with HTTP 400 Bad Request (with no request body), and is then logged by the ErrorHandlerWithOpts + // ErrorHandlerWithOpts: A MultiError was encountered, which contained a RequestError: request body has an error: value is required but missing, which inside it has a error of type (openapi3.MultiError) + // Received an HTTP 400 response. Expected HTTP 400 + // Response body: There was a bad request + // + // # A request that is malformed is rejected with HTTP 400 Bad Request (with an invalid request body, with multiple issues), and is then logged by the ErrorHandlerWithOpts + // ErrorHandlerWithOpts: A MultiError was encountered, which contained a RequestError: request body has an error: doesn't match schema: Error at "/id": minimum string length is 100 + // Schema: + // { + // "minLength": 100, + // "type": "string" + // } + // + // Value: + // "not-long-enough" + // | Error at "/name": value is not one of the allowed values ["Marcin"] + // Schema: + // { + // "enum": [ + // "Marcin" + // ], + // "type": "string" + // } + // + // Value: + // "Jamie" + // , which inside it has a error of type (openapi3.MultiError) + // Received an HTTP 400 response. Expected HTTP 400 + // Response body: There was a bad request +} diff --git a/oapi_validate_test.go b/oapi_validate_test.go new file mode 100644 index 0000000..c07b17c --- /dev/null +++ b/oapi_validate_test.go @@ -0,0 +1,65 @@ +package nethttpmiddleware + +import ( + "fmt" + "testing" + + "github.com/getkin/kin-openapi/openapi3filter" +) + +func Test_determineStatusCodeForMultiError(t *testing.T) { + t.Run("returns HTTP 400 Bad Request when only `RequestError`s", func(t *testing.T) { + errs := []error{ + &openapi3filter.RequestError{}, + &openapi3filter.RequestError{}, + } + + expected := 400 + actual := determineStatusCodeForMultiError(errs) + + if expected != actual { + t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual) + } + }) + + t.Run("returns HTTP 401 Unauthorized when only `SecurityRequirementsError`s", func(t *testing.T) { + errs := []error{ + &openapi3filter.SecurityRequirementsError{}, + &openapi3filter.SecurityRequirementsError{}, + } + + expected := 401 + actual := determineStatusCodeForMultiError(errs) + + if expected != actual { + t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual) + } + }) + + t.Run("returns HTTP 500 Internal Server Error when mixed error types", func(t *testing.T) { + errs := []error{ + &openapi3filter.RequestError{}, + &openapi3filter.SecurityRequirementsError{}, + } + + expected := 500 + actual := determineStatusCodeForMultiError(errs) + + if expected != actual { + t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual) + } + }) + + t.Run("returns HTTP 500 Internal Server Error when unknown error type(s) are seen", func(t *testing.T) { + errs := []error{ + fmt.Errorf("this isn't a known error type"), + } + + expected := 500 + actual := determineStatusCodeForMultiError(errs) + + if expected != actual { + t.Errorf("Expected an HTTP %d to be returned, but received %d", expected, actual) + } + }) +}