diff --git a/internal/cursor/cursor.go b/internal/cursor/cursor.go new file mode 100644 index 0000000..b0b771c --- /dev/null +++ b/internal/cursor/cursor.go @@ -0,0 +1,94 @@ +// Package cursor implements logic for paging. +package cursor + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" +) + +// Cursor represents pagination state with offset for next offset. +// +// We want to follow a pattern defined in MCP specification: +// https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/pagination +type Cursor struct { + Offset int32 `json:"offset"` +} + +// New creates and validates a new Cursor. +func New(offset int32) (*Cursor, error) { + cursor := &Cursor{ + Offset: offset, + } + + if err := cursor.validate(); err != nil { + return nil, err + } + + return cursor, nil +} + +// Encode serializes the cursor to a Base64-encoded string. +func (c *Cursor) Encode() (string, error) { + if err := c.validate(); err != nil { + return "", err + } + + jsonBytes, err := json.Marshal(c) + if err != nil { + return "", fmt.Errorf("failed to marshal cursor: %w", err) + } + + encoded := base64.StdEncoding.EncodeToString(jsonBytes) + + return encoded, nil +} + +// Decode deserializes a Base64-encoded string to a Cursor. +func Decode(encoded string) (*Cursor, error) { + if encoded == "" { + return nil, errors.New("encoded cursor cannot be empty") + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("invalid base64 encoding: %w", err) + } + + var cursor Cursor + if err := json.Unmarshal(decoded, &cursor); err != nil { + return nil, fmt.Errorf("invalid cursor format: %w", err) + } + + if err := cursor.validate(); err != nil { + return nil, err + } + + return &cursor, nil +} + +// GetOffset returns offset that can be used for API call. +func (c *Cursor) GetOffset() int32 { + return c.Offset +} + +// GetNextCursor returns cursor for the next offset. +func (c *Cursor) GetNextCursor(limit int32) *Cursor { + if limit < 0 || c.Offset+limit < 0 { + limit = 0 + } + + return &Cursor{ + Offset: c.Offset + limit, + } +} + +// validate checks if the cursor has valid values. +func (c *Cursor) validate() error { + if c.Offset < 0 { + return errors.New("offset must be non-negative") + } + + return nil +} diff --git a/internal/cursor/cursor_test.go b/internal/cursor/cursor_test.go new file mode 100644 index 0000000..74f6928 --- /dev/null +++ b/internal/cursor/cursor_test.go @@ -0,0 +1,149 @@ +package cursor + +import ( + "encoding/base64" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNew_ValidInput(t *testing.T) { + tests := map[string]struct { + offset int32 + }{ + "offset 0": { + offset: 0, + }, + "offset 5": { + offset: 5, + }, + "large values": { + offset: 1000000, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + cursor, err := New(testCase.offset) + require.NoError(t, err) + require.NotNil(t, cursor) + assert.Equal(t, testCase.offset, cursor.Offset) + }) + } +} + +func TestNew_InvalidInput(t *testing.T) { + cursor, err := New(-1) + require.Error(t, err) + assert.Nil(t, cursor) + assert.Contains(t, err.Error(), "offset must be non-negative") +} + +func TestDecode_Success(t *testing.T) { + original := &Cursor{Offset: 1} + encoded, err := original.Encode() + require.NoError(t, err) + + decoded, err := Decode(encoded) + require.NoError(t, err) + require.NotNil(t, decoded) + assert.Equal(t, original.Offset, decoded.Offset) +} + +func TestDecode_InvalidInput(t *testing.T) { + tests := map[string]struct { + encoded string + expectedError string + }{ + "empty string": { + encoded: "", + expectedError: "encoded cursor cannot be empty", + }, + "invalid base64": { + encoded: "not-base64!@#$%", + expectedError: "invalid base64 encoding", + }, + "invalid json": { + encoded: base64.StdEncoding.EncodeToString([]byte("not json")), + expectedError: "invalid cursor format", + }, + "valid json but invalid cursor - negative offset": { + encoded: base64.StdEncoding.EncodeToString([]byte(`{"offset":-1}`)), + expectedError: "offset must be non-negative", + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + decoded, err := Decode(testCase.encoded) + require.Error(t, err) + assert.Nil(t, decoded) + assert.Contains(t, err.Error(), testCase.expectedError) + }) + } +} + +func TestEncodeDecode_RoundTrip(t *testing.T) { + tests := map[string]struct { + offset int32 + }{ + "zero offset": { + offset: 0, + }, + "non-zero offset": { + offset: 5, + }, + "large offset": { + offset: 10000, + }, + } + + for testName, testCase := range tests { + t.Run(testName, func(t *testing.T) { + original, err := New(testCase.offset) + require.NoError(t, err) + + encoded, err := original.Encode() + require.NoError(t, err) + assert.NotEmpty(t, encoded) + + decoded, err := Decode(encoded) + require.NoError(t, err) + require.NotNil(t, decoded) + + assert.Equal(t, original.Offset, decoded.Offset) + }) + } +} + +func TestEncode_InvalidInput(t *testing.T) { + cursor := &Cursor{Offset: -1} + encoded, err := cursor.Encode() + + require.Error(t, err) + assert.Empty(t, encoded) + assert.Contains(t, err.Error(), "offset must be non-negative") +} + +func TestGetOffset(t *testing.T) { + cursor := &Cursor{Offset: 1} + assert.Equal(t, cursor.Offset, cursor.GetOffset()) +} + +func TestGetNextCursor(t *testing.T) { + cursor := &Cursor{Offset: 0} + + cursorStep1 := cursor.GetNextCursor(10) + assert.Equal(t, int32(10), cursorStep1.GetOffset()) + + cursorStep2 := cursorStep1.GetNextCursor(5) + assert.Equal(t, int32(15), cursorStep2.GetOffset()) + + cursorNegativeLimit := cursorStep2.GetNextCursor(-1) + assert.Equal(t, int32(15), cursorNegativeLimit.GetOffset(), "negative limit should not change offset") + + cursorOverflow := cursorStep2.GetNextCursor(math.MaxInt32) + assert.Equal(t, int32(15), cursorOverflow.GetOffset(), "overflow paging should not change offset") +} diff --git a/internal/toolsets/vulnerability/tools.go b/internal/toolsets/vulnerability/tools.go index 70b1660..09946fc 100644 --- a/internal/toolsets/vulnerability/tools.go +++ b/internal/toolsets/vulnerability/tools.go @@ -11,13 +11,13 @@ import ( v1 "github.com/stackrox/rox/generated/api/v1" "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/client/auth" + "github.com/stackrox/stackrox-mcp/internal/cursor" "github.com/stackrox/stackrox-mcp/internal/logging" "github.com/stackrox/stackrox-mcp/internal/toolsets" ) const ( - defaultLimit = 50 - maximumLimit = 200.0 + defaultLimit = 100 ) type filterPlatformType string @@ -34,8 +34,7 @@ type getDeploymentsForCVEInput struct { FilterClusterID string `json:"filterClusterId,omitempty"` FilterNamespace string `json:"filterNamespace,omitempty"` FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"` - Offset int32 `json:"offset,omitempty"` - Limit int32 `json:"limit,omitempty"` + Cursor string `json:"cursor,omitempty"` } func (input *getDeploymentsForCVEInput) validate() error { @@ -57,6 +56,7 @@ type DeploymentResult struct { // getDeploymentsForCVEOutput defines the output structure for get_deployments_for_cve tool. type getDeploymentsForCVEOutput struct { Deployments []DeploymentResult `json:"deployments"` + NextCursor string `json:"nextCursor"` } // getDeploymentsForCVETool implements the get_deployments_for_cve tool. @@ -118,14 +118,7 @@ func getDeploymentsForCVEInputSchema() *jsonschema.Schema { filterPlatformPlatform, } - schema.Properties["offset"].Description = "Pagination offset (default: 0)" - schema.Properties["offset"].Default = toolsets.MustJSONMarshal(0) - schema.Properties["limit"].Minimum = jsonschema.Ptr(0.0) - - schema.Properties["limit"].Description = "Pagination limit: minimum: 1, maximum: 200 (default: 50)" - schema.Properties["limit"].Default = toolsets.MustJSONMarshal(defaultLimit) - schema.Properties["limit"].Minimum = jsonschema.Ptr(1.0) - schema.Properties["limit"].Maximum = jsonschema.Ptr(maximumLimit) + schema.Properties["cursor"].Description = "Cursor for next page provided by server" return schema } @@ -160,6 +153,21 @@ func buildQuery(input getDeploymentsForCVEInput) string { return strings.Join(queryParts, "+") } +func getCursor(input *getDeploymentsForCVEInput) (*cursor.Cursor, error) { + if input.Cursor == "" { + startCursor, err := cursor.New(0) + + return startCursor, errors.Wrap(err, "error creating starting cursor") + } + + currCursor, err := cursor.Decode(input.Cursor) + if err != nil { + return nil, errors.Wrap(err, "error decoding cursor") + } + + return currCursor, nil +} + // handle is the handler for get_deployments_for_cve tool. func (t *getDeploymentsForCVETool) handle( ctx context.Context, @@ -171,6 +179,11 @@ func (t *getDeploymentsForCVETool) handle( return nil, nil, err } + currCursor, err := getCursor(&input) + if err != nil { + return nil, nil, err + } + conn, err := t.client.ReadyConn(ctx) if err != nil { return nil, nil, errors.Wrap(err, "unable to connect to server") @@ -182,8 +195,8 @@ func (t *getDeploymentsForCVETool) handle( listReq := &v1.RawQuery{ Query: buildQuery(input), Pagination: &v1.Pagination{ - Offset: input.Offset, - Limit: input.Limit, + Offset: currCursor.GetOffset(), + Limit: defaultLimit + 1, }, } @@ -202,8 +215,19 @@ func (t *getDeploymentsForCVETool) handle( }) } + // We always fetch limit+1 - if we do not have one additional element we can end paging. + if len(deployments) <= defaultLimit { + return nil, &getDeploymentsForCVEOutput{Deployments: deployments}, nil + } + + nextCursorStr, err := currCursor.GetNextCursor(defaultLimit).Encode() + if err != nil { + return nil, nil, errors.Wrap(err, "unable to create next cursor") + } + output := &getDeploymentsForCVEOutput{ - Deployments: deployments, + Deployments: deployments[:len(deployments)-1], + NextCursor: nextCursorStr, } return nil, output, nil diff --git a/internal/toolsets/vulnerability/tools_test.go b/internal/toolsets/vulnerability/tools_test.go index beca9aa..a8446ee 100644 --- a/internal/toolsets/vulnerability/tools_test.go +++ b/internal/toolsets/vulnerability/tools_test.go @@ -11,6 +11,7 @@ import ( "github.com/stackrox/rox/generated/storage" "github.com/stackrox/stackrox-mcp/internal/client" "github.com/stackrox/stackrox-mcp/internal/config" + "github.com/stackrox/stackrox-mcp/internal/cursor" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" @@ -220,7 +221,7 @@ func TestHandle_MissingCVE(t *testing.T) { func TestHandle_WithPagination(t *testing.T) { mockService := &mockDeploymentService{ - deployments: getTestDeployments(5), + deployments: getTestDeployments(defaultLimit + 1), } grpcServer, listener := setupMockDeploymentServer(mockService) @@ -230,12 +231,17 @@ func TestHandle_WithPagination(t *testing.T) { tool, ok := NewGetDeploymentsForCVETool(testClient).(*getDeploymentsForCVETool) require.True(t, ok) + currCursor, err := cursor.New(2) + require.NoError(t, err) + + currCursorStr, err := currCursor.Encode() + require.NoError(t, err) + ctx := context.Background() req := &mcp.CallToolRequest{} input := getDeploymentsForCVEInput{ CVEName: "CVE-2021-44228", - Offset: 3, - Limit: 19, + Cursor: currCursorStr, } result, output, err := tool.handle(ctx, req, input) @@ -244,9 +250,14 @@ func TestHandle_WithPagination(t *testing.T) { require.NotNil(t, output) assert.Nil(t, result) - assert.Len(t, output.Deployments, 5) - assert.Equal(t, int32(3), mockService.lastCallOffset) - assert.Equal(t, int32(19), mockService.lastCallLimit) + assert.Len(t, output.Deployments, defaultLimit) + assert.Equal(t, int32(2), mockService.lastCallOffset) + assert.Equal(t, int32(defaultLimit+1), mockService.lastCallLimit) + + nextCursor := currCursor.GetNextCursor(defaultLimit) + returnedCursor, err := cursor.Decode(output.NextCursor) + require.NoError(t, err) + assert.Equal(t, nextCursor.GetOffset(), returnedCursor.GetOffset()) } func TestHandle_EmptyResults(t *testing.T) { @@ -265,7 +276,6 @@ func TestHandle_EmptyResults(t *testing.T) { req := &mcp.CallToolRequest{} input := getDeploymentsForCVEInput{ CVEName: "CVE-9999-99999", - Limit: defaultLimit, } result, output, err := tool.handle(ctx, req, input) @@ -292,7 +302,6 @@ func TestHandle_ListDeploymentsError(t *testing.T) { req := &mcp.CallToolRequest{} input := getDeploymentsForCVEInput{ CVEName: "CVE-2021-44228", - Limit: defaultLimit, } result, output, err := tool.handle(ctx, req, input)