Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions internal/cursor/cursor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Package cursor implements logic for paging.
package cursor

import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
)

// Cursor represents pagination state with offset for next offset.
//
// We want to follow a pattern defined in MCP specification:
// https://modelcontextprotocol.io/specification/2025-11-25/server/utilities/pagination
type Cursor struct {
Offset int32 `json:"offset"`
}

// New creates and validates a new Cursor.
func New(offset int32) (*Cursor, error) {
cursor := &Cursor{
Offset: offset,
}

if err := cursor.validate(); err != nil {
return nil, err
}

return cursor, nil
}

// Encode serializes the cursor to a Base64-encoded string.
func (c *Cursor) Encode() (string, error) {
if err := c.validate(); err != nil {
return "", err
}

jsonBytes, err := json.Marshal(c)
if err != nil {
return "", fmt.Errorf("failed to marshal cursor: %w", err)
}

encoded := base64.StdEncoding.EncodeToString(jsonBytes)

return encoded, nil
}

// Decode deserializes a Base64-encoded string to a Cursor.
func Decode(encoded string) (*Cursor, error) {
if encoded == "" {
return nil, errors.New("encoded cursor cannot be empty")
}

decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
}

var cursor Cursor
if err := json.Unmarshal(decoded, &cursor); err != nil {
return nil, fmt.Errorf("invalid cursor format: %w", err)
}

if err := cursor.validate(); err != nil {
return nil, err
}

return &cursor, nil
}

// GetOffset returns offset that can be used for API call.
func (c *Cursor) GetOffset() int32 {
return c.Offset
}

// GetNextCursor returns cursor for the next offset.
func (c *Cursor) GetNextCursor(limit int32) *Cursor {
if limit < 0 || c.Offset+limit < 0 {
limit = 0
}

return &Cursor{
Offset: c.Offset + limit,
}
}

// validate checks if the cursor has valid values.
func (c *Cursor) validate() error {
if c.Offset < 0 {
return errors.New("offset must be non-negative")
}

return nil
}
149 changes: 149 additions & 0 deletions internal/cursor/cursor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package cursor

import (
"encoding/base64"
"math"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNew_ValidInput(t *testing.T) {
tests := map[string]struct {
offset int32
}{
"offset 0": {
offset: 0,
},
"offset 5": {
offset: 5,
},
"large values": {
offset: 1000000,
},
}

for testName, testCase := range tests {
t.Run(testName, func(t *testing.T) {
cursor, err := New(testCase.offset)
require.NoError(t, err)
require.NotNil(t, cursor)
assert.Equal(t, testCase.offset, cursor.Offset)
})
}
}

func TestNew_InvalidInput(t *testing.T) {
cursor, err := New(-1)
require.Error(t, err)
assert.Nil(t, cursor)
assert.Contains(t, err.Error(), "offset must be non-negative")
}

func TestDecode_Success(t *testing.T) {
original := &Cursor{Offset: 1}
encoded, err := original.Encode()
require.NoError(t, err)

decoded, err := Decode(encoded)
require.NoError(t, err)
require.NotNil(t, decoded)
assert.Equal(t, original.Offset, decoded.Offset)
}

func TestDecode_InvalidInput(t *testing.T) {
tests := map[string]struct {
encoded string
expectedError string
}{
"empty string": {
encoded: "",
expectedError: "encoded cursor cannot be empty",
},
"invalid base64": {
encoded: "not-base64!@#$%",
expectedError: "invalid base64 encoding",
},
"invalid json": {
encoded: base64.StdEncoding.EncodeToString([]byte("not json")),
expectedError: "invalid cursor format",
},
"valid json but invalid cursor - negative offset": {
encoded: base64.StdEncoding.EncodeToString([]byte(`{"offset":-1}`)),
expectedError: "offset must be non-negative",
},
}

for testName, testCase := range tests {
t.Run(testName, func(t *testing.T) {
decoded, err := Decode(testCase.encoded)
require.Error(t, err)
assert.Nil(t, decoded)
assert.Contains(t, err.Error(), testCase.expectedError)
})
}
}

func TestEncodeDecode_RoundTrip(t *testing.T) {
tests := map[string]struct {
offset int32
}{
"zero offset": {
offset: 0,
},
"non-zero offset": {
offset: 5,
},
"large offset": {
offset: 10000,
},
}

for testName, testCase := range tests {
t.Run(testName, func(t *testing.T) {
original, err := New(testCase.offset)
require.NoError(t, err)

encoded, err := original.Encode()
require.NoError(t, err)
assert.NotEmpty(t, encoded)

decoded, err := Decode(encoded)
require.NoError(t, err)
require.NotNil(t, decoded)

assert.Equal(t, original.Offset, decoded.Offset)
})
}
}

func TestEncode_InvalidInput(t *testing.T) {
cursor := &Cursor{Offset: -1}
encoded, err := cursor.Encode()

require.Error(t, err)
assert.Empty(t, encoded)
assert.Contains(t, err.Error(), "offset must be non-negative")
}

func TestGetOffset(t *testing.T) {
cursor := &Cursor{Offset: 1}
assert.Equal(t, cursor.Offset, cursor.GetOffset())
}

func TestGetNextCursor(t *testing.T) {
cursor := &Cursor{Offset: 0}

cursorStep1 := cursor.GetNextCursor(10)
assert.Equal(t, int32(10), cursorStep1.GetOffset())

cursorStep2 := cursorStep1.GetNextCursor(5)
assert.Equal(t, int32(15), cursorStep2.GetOffset())

cursorNegativeLimit := cursorStep2.GetNextCursor(-1)
assert.Equal(t, int32(15), cursorNegativeLimit.GetOffset(), "negative limit should not change offset")

cursorOverflow := cursorStep2.GetNextCursor(math.MaxInt32)
assert.Equal(t, int32(15), cursorOverflow.GetOffset(), "overflow paging should not change offset")
}
54 changes: 39 additions & 15 deletions internal/toolsets/vulnerability/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import (
v1 "github.com/stackrox/rox/generated/api/v1"
"github.com/stackrox/stackrox-mcp/internal/client"
"github.com/stackrox/stackrox-mcp/internal/client/auth"
"github.com/stackrox/stackrox-mcp/internal/cursor"
"github.com/stackrox/stackrox-mcp/internal/logging"
"github.com/stackrox/stackrox-mcp/internal/toolsets"
)

const (
defaultLimit = 50
maximumLimit = 200.0
defaultLimit = 100
)

type filterPlatformType string
Expand All @@ -34,8 +34,7 @@ type getDeploymentsForCVEInput struct {
FilterClusterID string `json:"filterClusterId,omitempty"`
FilterNamespace string `json:"filterNamespace,omitempty"`
FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"`
Offset int32 `json:"offset,omitempty"`
Limit int32 `json:"limit,omitempty"`
Cursor string `json:"cursor,omitempty"`
}

func (input *getDeploymentsForCVEInput) validate() error {
Expand All @@ -57,6 +56,7 @@ type DeploymentResult struct {
// getDeploymentsForCVEOutput defines the output structure for get_deployments_for_cve tool.
type getDeploymentsForCVEOutput struct {
Deployments []DeploymentResult `json:"deployments"`
NextCursor string `json:"nextCursor"`
}

// getDeploymentsForCVETool implements the get_deployments_for_cve tool.
Expand Down Expand Up @@ -118,14 +118,7 @@ func getDeploymentsForCVEInputSchema() *jsonschema.Schema {
filterPlatformPlatform,
}

schema.Properties["offset"].Description = "Pagination offset (default: 0)"
schema.Properties["offset"].Default = toolsets.MustJSONMarshal(0)
schema.Properties["limit"].Minimum = jsonschema.Ptr(0.0)

schema.Properties["limit"].Description = "Pagination limit: minimum: 1, maximum: 200 (default: 50)"
schema.Properties["limit"].Default = toolsets.MustJSONMarshal(defaultLimit)
schema.Properties["limit"].Minimum = jsonschema.Ptr(1.0)
schema.Properties["limit"].Maximum = jsonschema.Ptr(maximumLimit)
schema.Properties["cursor"].Description = "Cursor for next page provided by server"

return schema
}
Expand Down Expand Up @@ -160,6 +153,21 @@ func buildQuery(input getDeploymentsForCVEInput) string {
return strings.Join(queryParts, "+")
}

func getCursor(input *getDeploymentsForCVEInput) (*cursor.Cursor, error) {
if input.Cursor == "" {
startCursor, err := cursor.New(0)

return startCursor, errors.Wrap(err, "error creating starting cursor")
}

currCursor, err := cursor.Decode(input.Cursor)
if err != nil {
return nil, errors.Wrap(err, "error decoding cursor")
}

return currCursor, nil
}

// handle is the handler for get_deployments_for_cve tool.
func (t *getDeploymentsForCVETool) handle(
ctx context.Context,
Expand All @@ -171,6 +179,11 @@ func (t *getDeploymentsForCVETool) handle(
return nil, nil, err
}

currCursor, err := getCursor(&input)
if err != nil {
return nil, nil, err
}

conn, err := t.client.ReadyConn(ctx)
if err != nil {
return nil, nil, errors.Wrap(err, "unable to connect to server")
Expand All @@ -182,8 +195,8 @@ func (t *getDeploymentsForCVETool) handle(
listReq := &v1.RawQuery{
Query: buildQuery(input),
Pagination: &v1.Pagination{
Offset: input.Offset,
Limit: input.Limit,
Offset: currCursor.GetOffset(),
Limit: defaultLimit + 1,
},
}

Expand All @@ -202,8 +215,19 @@ func (t *getDeploymentsForCVETool) handle(
})
}

// We always fetch limit+1 - if we do not have one additional element we can end paging.
if len(deployments) <= defaultLimit {
return nil, &getDeploymentsForCVEOutput{Deployments: deployments}, nil
}

nextCursorStr, err := currCursor.GetNextCursor(defaultLimit).Encode()
if err != nil {
return nil, nil, errors.Wrap(err, "unable to create next cursor")
}

output := &getDeploymentsForCVEOutput{
Deployments: deployments,
Deployments: deployments[:len(deployments)-1],
NextCursor: nextCursorStr,
}

return nil, output, nil
Expand Down
Loading
Loading