diff --git a/_examples/auto_params/main.go b/_examples/auto_params/main.go index 3151dd7..709c00a 100644 --- a/_examples/auto_params/main.go +++ b/_examples/auto_params/main.go @@ -15,6 +15,11 @@ type SearchInput struct { Age int `query:"age" validate:"omitempty,min=0,max=120"` Active bool `query:"active"` MinPrice float64 `query:"minPrice" validate:"omitempty,min=0"` + + // Pointer types - automatically optional and nullable + Category *string `query:"category"` + MaxResults *int `query:"maxResults"` + IncludeInactive *bool `query:"includeInactive"` } type SearchOutput struct { @@ -23,12 +28,13 @@ type SearchOutput struct { } type User struct { - ID int `json:"id"` - Name string `json:"name"` - Email string `json:"email"` - Age int `json:"age"` - Active bool `json:"active"` - Price float64 `json:"price"` + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + Age int `json:"age"` + Active bool `json:"active"` + Price float64 `json:"price"` + Category *string `json:"category,omitempty"` } type ErrorResponse struct { @@ -44,12 +50,13 @@ func main() { fiberoapi.Get(oapi, "/users/:name", func(c *fiber.Ctx, input SearchInput) (SearchOutput, ErrorResponse) { // Simulate search results user := User{ - ID: 1, - Name: input.Name, - Email: input.Email, - Age: input.Age, - Active: input.Active, - Price: input.MinPrice, + ID: 1, + Name: input.Name, + Email: input.Email, + Age: input.Age, + Active: input.Active, + Price: input.MinPrice, + Category: input.Category, // Pointer field - can be nil } return SearchOutput{ @@ -86,7 +93,8 @@ func main() { fmt.Println("\n📖 Documentation disponible sur http://localhost:3000/docs") fmt.Println("📊 Spec OpenAPI JSON sur http://localhost:3000/openapi.json") - fmt.Println("🧪 Test de l'endpoint : http://localhost:3000/users/john?email=john@example.com&age=25&active=true&minPrice=10.5") + fmt.Println("🧪 Test de l'endpoint : http://localhost:3000/users/john?email=john@example.com&age=25&active=true&minPrice=10.5&category=electronics&maxResults=10") + fmt.Println("🔧 Paramètres optionnels (pointeurs) : category, maxResults, includeInactive") log.Fatal(app.Listen(":3000")) } diff --git a/auth.go b/auth.go index fa11547..ab8a78f 100644 --- a/auth.go +++ b/auth.go @@ -176,7 +176,7 @@ func validateResourceAccess(c *fiber.Ctx, authCtx *AuthContext, input interface{ inputValue := reflect.ValueOf(input) inputType := reflect.TypeOf(input) - if inputType.Kind() == reflect.Ptr { + if isPointerType(inputType) { inputValue = inputValue.Elem() inputType = inputType.Elem() } diff --git a/auto_params_test.go b/auto_params_test.go index aa6e57e..57912a5 100644 --- a/auto_params_test.go +++ b/auto_params_test.go @@ -236,3 +236,61 @@ func TestNoParametersWhenNoStruct(t *testing.T) { _, hasParams := getOp["parameters"] assert.False(t, hasParams, "Should not have parameters when no input struct") } + +func TestAutoParamsPointerTypesInline(t *testing.T) { + app := fiber.New() + oapi := New(app) + + type PointerTestInput struct { + ID string `path:"id" validate:"required"` + OptionalName *string `query:"optionalName"` + RequiredName string `query:"requiredName" validate:"required"` + OmitEmpty string `query:"omitEmpty" validate:"omitempty"` + } + + type PointerTestOutput struct { + Message string `json:"message"` + } + + type PointerTestError struct { + Code int `json:"code"` + } + + Get(oapi, "/pointer/:id", func(c *fiber.Ctx, input PointerTestInput) (PointerTestOutput, PointerTestError) { + return PointerTestOutput{Message: "ok"}, PointerTestError{} + }, OpenAPIOptions{ + OperationID: "testPointerTypes", + Summary: "Test pointer types in parameters", + }) + + spec := oapi.GenerateOpenAPISpec() + paths := spec["paths"].(map[string]interface{}) + pointerPath := paths["/pointer/{id}"].(map[string]interface{}) + getOp := pointerPath["get"].(map[string]interface{}) + parameters := getOp["parameters"].([]map[string]interface{}) + + // Should have 4 parameters + assert.Len(t, parameters, 4, "Should have 4 parameters") + + paramMap := make(map[string]map[string]interface{}) + for _, param := range parameters { + if name, ok := param["name"].(string); ok { + paramMap[name] = param + } + } + + // Check pointer type is optional and nullable + optionalNameParam := paramMap["optionalName"] + assert.False(t, optionalNameParam["required"].(bool), "Pointer types should be optional by default") + if schema, ok := optionalNameParam["schema"].(map[string]interface{}); ok { + assert.True(t, schema["nullable"].(bool), "Pointer types should be nullable") + } + + // Check required field + requiredNameParam := paramMap["requiredName"] + assert.True(t, requiredNameParam["required"].(bool), "Fields with validate:required should be required") + + // Check omitempty field + omitEmptyParam := paramMap["omitEmpty"] + assert.False(t, omitEmptyParam["required"].(bool), "Fields with omitempty should be optional") +} diff --git a/common.go b/common.go index 91f53a9..900ae40 100644 --- a/common.go +++ b/common.go @@ -196,7 +196,7 @@ func validatePathParams[T any](path string) error { inputType := reflect.TypeOf(zero) // If the type is a pointer, get the element type - if inputType != nil && inputType.Kind() == reflect.Ptr { + if inputType != nil && isPointerType(inputType) { inputType = inputType.Elem() } @@ -263,9 +263,7 @@ func extractParametersFromStruct(inputType reflect.Type) []map[string]interface{ } // Handle pointer types - if inputType.Kind() == reflect.Ptr { - inputType = inputType.Elem() - } + inputType = dereferenceType(inputType) // Only process struct types if inputType.Kind() != reflect.Struct { @@ -282,6 +280,9 @@ func extractParametersFromStruct(inputType reflect.Type) []map[string]interface{ // Process path parameters if pathTag := field.Tag.Get("path"); pathTag != "" { + // Path parameters are always required regardless of type or validation tags. + // This follows OpenAPI 3.0 specification where path parameters must be required, + // and is enforced here by explicitly setting "required": true at line 289. param := map[string]interface{}{ "name": pathTag, "in": "path", @@ -294,7 +295,8 @@ func extractParametersFromStruct(inputType reflect.Type) []map[string]interface{ // Process query parameters if queryTag := field.Tag.Get("query"); queryTag != "" { - required := isFieldRequired(field) + // Query parameters use specialized logic based on type and validation tags + required := isQueryFieldRequired(field) param := map[string]interface{}{ "name": queryTag, "in": "query", @@ -322,14 +324,46 @@ func getFieldDescription(field reflect.StructField, defaultDesc string) string { return fmt.Sprintf("%s: %s", defaultDesc, field.Name) } -// isFieldRequired checks if a field is required based on validation tags -func isFieldRequired(field reflect.StructField) bool { +// isPointerType checks if a reflect.Type is a pointer type +func isPointerType(t reflect.Type) bool { + return t.Kind() == reflect.Ptr +} + +// isPointerField checks if a reflect.StructField is a pointer type +func isPointerField(field reflect.StructField) bool { + return isPointerType(field.Type) +} + +// dereferenceType removes pointer indirection from a type +func dereferenceType(t reflect.Type) reflect.Type { + if isPointerType(t) { + return t.Elem() + } + return t +} + +// isQueryFieldRequired checks if a query parameter field is required +// Query parameters have different logic than path parameters: +// - Path parameters are always required (handled separately) +// - Pointer types (*string, *int, etc.) are optional by default +// - Non-pointer types are optional by default unless explicitly marked as required +// - Fields with "omitempty" are optional +// - Fields with "required" are required +func isQueryFieldRequired(field reflect.StructField) bool { validateTag := field.Tag.Get("validate") - if validateTag == "" { + + // If it's a pointer type, it's optional by default (unless explicitly required) + if isPointerField(field) { + return strings.Contains(validateTag, "required") + } + + // For non-pointer types in query parameters: + // - If has omitempty, it's optional + if strings.Contains(validateTag, "omitempty") { return false } - // Check for required validation + // Check for explicit required validation return strings.Contains(validateTag, "required") } @@ -337,10 +371,9 @@ func isFieldRequired(field reflect.StructField) bool { func getSchemaForType(t reflect.Type) map[string]interface{} { schema := make(map[string]interface{}) - // Handle pointer types - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + // Handle pointer types - preserve original to detect nullability, then dereference for type checking + originalType := t + t = dereferenceType(t) switch t.Kind() { case reflect.String: @@ -372,6 +405,11 @@ func getSchemaForType(t reflect.Type) map[string]interface{} { schema["type"] = "string" } + // If the original type was a pointer, indicate it's nullable + if isPointerType(originalType) { + schema["nullable"] = true + } + return schema } diff --git a/fiberoapi.go b/fiberoapi.go index fa61bfd..d507dda 100644 --- a/fiberoapi.go +++ b/fiberoapi.go @@ -241,10 +241,7 @@ func (o *OApiApp) GenerateOpenAPISpec() map[string]interface{} { // Add request body schema for POST/PUT methods if op.Method == "POST" || op.Method == "PUT" || op.Method == "PATCH" { if op.InputType != nil { - inputType := op.InputType - if inputType.Kind() == reflect.Ptr { - inputType = inputType.Elem() - } + inputType := dereferenceType(op.InputType) var schemaRef map[string]interface{} @@ -274,10 +271,7 @@ func (o *OApiApp) GenerateOpenAPISpec() map[string]interface{} { // Success response (200) if op.OutputType != nil { - outputType := op.OutputType - if outputType.Kind() == reflect.Ptr { - outputType = outputType.Elem() - } + outputType := dereferenceType(op.OutputType) var schemaRef map[string]interface{} @@ -341,9 +335,7 @@ func collectAllTypes(t reflect.Type, collected map[string]reflect.Type) { } // Handle pointers - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) typeName := getTypeName(t) if typeName == "" { @@ -418,9 +410,7 @@ func shouldGenerateSchemaForType(t reflect.Type) bool { } // Handle pointers - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) switch t.Kind() { case reflect.Struct: @@ -428,19 +418,13 @@ func shouldGenerateSchemaForType(t reflect.Type) bool { return t.NumField() > 0 case reflect.Map: // Generate schema for maps with complex value types - valueType := t.Elem() - if valueType.Kind() == reflect.Ptr { - valueType = valueType.Elem() - } + valueType := dereferenceType(t.Elem()) return valueType.Kind() == reflect.Struct || valueType.Kind() == reflect.Map || valueType.Kind() == reflect.Slice case reflect.Slice: // Generate schema for slices of complex types - elemType := t.Elem() - if elemType.Kind() == reflect.Ptr { - elemType = elemType.Elem() - } + elemType := dereferenceType(t.Elem()) return elemType.Kind() == reflect.Struct || elemType.Kind() == reflect.Map case reflect.Interface: @@ -474,9 +458,7 @@ func getTypeName(t reflect.Type) string { } // Handle pointers - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) // Handle different kinds of types switch t.Kind() { @@ -517,9 +499,7 @@ func getSimpleTypeName(t reflect.Type) string { } // Handle pointers - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) switch t.Kind() { case reflect.String: @@ -557,9 +537,7 @@ func generateSchema(t reflect.Type) map[string]interface{} { } // Handle pointers - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) schema := make(map[string]interface{}) @@ -681,9 +659,7 @@ func generateFieldSchema(t reflect.Type) map[string]interface{} { schema := make(map[string]interface{}) // Handle pointers - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) switch t.Kind() { case reflect.String: @@ -806,9 +782,7 @@ func isEmptyStruct(t reflect.Type) bool { return true } - if t.Kind() == reflect.Ptr { - t = t.Elem() - } + t = dereferenceType(t) return t.Kind() == reflect.Struct && t.NumField() == 0 }