Skip to content

Commit 3354dc6

Browse files
Api Contract Test + Fixes (#32)
* Implemented API Contract test
* Enhanced contract test with content checking
* Fixed API-level bugs:
  - Using default index config (None) would fail
  - `contents` was not supported in upsert
  - `include` was ignored in query when limiting fields to return
* Implemented API contract testing workflows
* Added query_contents to API contract test
* Potential fix for code scanning alert no. 5: Workflow does not contain permissions

  Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
* Linting & formatting
* Updated API Contract check workflow

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
1 parent 0be0cda commit 3354dc6

File tree

5 files changed

+1165
-9
lines changed

5 files changed

+1165
-9
lines changed
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
# Workflow: API Contract Test
#
# On pull requests, detects whether tests/test_api_contract.py was modified
# and, if so, posts a single "request changes" review asking the author to
# justify the change to the public API contract. On pushes to main only the
# checkout step runs, because both later steps are gated on `pull_request`.
name: API Contract Test

permissions:
  contents: read
  # Required so the workflow token can create a pull request review.
  pull-requests: write

on:
  pull_request:
    types: [opened, synchronize, reopened]
  push:
    branches: [main]

jobs:
  check-contract:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0  # Fetch all history for git diff

      - name: Check for Contract Changes
        if: github.event_name == 'pull_request'
        id: check_changes
        env:
          # Passed through the environment instead of being interpolated
          # into the script text.
          BASE_SHA: ${{ github.event.pull_request.base.sha }}
        run: |
          # Check if the API contract test file was modified.
          #
          # The pipeline is tested directly in the `if` condition: run steps
          # execute under `bash -e`, so a standalone `grep -q` that finds no
          # match would abort the script with exit code 1 before any later
          # `$?` check could run, failing the step on every PR that leaves
          # the contract file untouched.
          #
          # -F/-x match the path as a fixed, whole-line string rather than
          # as an unanchored regular expression.
          if git diff --name-only "${BASE_SHA}..HEAD" | grep -qxF "tests/test_api_contract.py"; then
            echo "contract_changed=true" >> "$GITHUB_OUTPUT"
            echo "API contract test file was modified in this PR"
          else
            echo "contract_changed=false" >> "$GITHUB_OUTPUT"
            echo "API contract test file was not modified"
          fi

      - name: Request Changes if Contract Changed
        if: github.event_name == 'pull_request' && steps.check_changes.outputs.contract_changed == 'true'
        uses: actions/github-script@v6
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const body = `## ⚠️ API Contract Change Detected

            This PR modifies the public API contract of the CyborgDB Python SDK.

            **Please provide an explanation for this change:**
            - Why is this change necessary?
            - Is this a breaking change or backward compatible?
            - Have you updated the documentation?

            **This review must be dismissed or addressed before the PR can be merged.**`;

            // Check if we already have a review requesting changes.
            // Paginate so an earlier bot review is still found on PRs whose
            // review list spans more than one page.
            const reviews = await github.paginate(github.rest.pulls.listReviews, {
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: context.issue.number,
            });

            // `user` and `body` are guarded because either may be missing
            // (for example a deleted account or a review with no text).
            const existingReview = reviews.find(review =>
              review.user &&
              review.user.type === 'Bot' &&
              (review.body || '').includes('API Contract Change Detected')
            );

            if (!existingReview) {
              // Create a review requesting changes
              await github.rest.pulls.createReview({
                owner: context.repo.owner,
                repo: context.repo.repo,
                pull_number: context.issue.number,
                body: body,
                event: 'REQUEST_CHANGES'
              });
            }

.github/workflows/test.yml

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,25 @@ jobs:
152152
fi
153153
154154
# Run all tests (full version tests will run, lite tests will also work)
155-
pytest tests/ -v --cov=cyborgdb --cov-report=term-missing --cov-append
155+
pytest tests/ -v --cov=cyborgdb --cov-report=term-missing --cov-append --ignore=tests/test_api_contract.py
156156
157-
# Stop standard server
158-
kill $(cat server-standard.pid) || true
157+
# Store the server pid for later cleanup
158+
echo "Standard server PID: $(cat server-standard.pid)"
159+
160+
- name: Check API Contract
161+
run: |
162+
echo "=== Checking API Contract ==="
163+
# Run the API contract test separately to avoid affecting coverage of other tests
164+
pytest tests/test_api_contract.py -v
165+
166+
- name: Kill Server
167+
run: |
168+
if [ -f server-lite.pid ]; then
169+
kill $(cat server-lite.pid) || true
170+
fi
171+
if [ -f server-standard.pid ]; then
172+
kill $(cat server-standard.pid) || true
173+
fi
159174
160175
- name: Upload coverage reports
161176
uses: codecov/codecov-action@v3

cyborgdb/client/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ def create_index(
162162
# Convert binary key to hex string
163163
key_hex = binascii.hexlify(index_key).decode("ascii")
164164

165+
if index_config is None:
166+
index_config = IndexIVFFlatModel() # Default config
167+
165168
# Create an IndexConfig instance with the appropriate model
166169
index_config_obj = IndexConfig(index_config)
167170

cyborgdb/client/encrypted_index.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
This module provides the EncryptedIndex class for interacting with encrypted vector indexes in CyborgDB.
55
"""
66

7+
import base64
8+
import binascii
9+
import logging
710
from typing import Dict, List, Optional, Union, Any
811
import json
912
import numpy as np
10-
import logging
11-
import binascii
1213

1314
# Import the OpenAPI generated client
1415
try:
@@ -24,6 +25,7 @@
2425
from cyborgdb.openapi_client.models.query_request import QueryRequest
2526
from cyborgdb.openapi_client.models.list_ids_request import ListIDsRequest
2627
from cyborgdb.openapi_client.models.request import Request
28+
from cyborgdb.openapi_client.models import Contents
2729
except ImportError:
2830
raise ImportError(
2931
"Failed to import openapi_client. Make sure the OpenAPI client library is properly installed."
@@ -312,7 +314,23 @@ def upsert(
312314
item["vector"] = item_dict["vector"]
313315

314316
if "contents" in item_dict:
315-
item["contents"] = item_dict["contents"]
317+
contents_value = item_dict["contents"]
318+
319+
# Convert bytes to base64 string for JSON serialization
320+
if isinstance(contents_value, bytes):
321+
# Convert bytes to base64 string
322+
contents_value = base64.b64encode(contents_value).decode(
323+
"utf-8"
324+
)
325+
elif isinstance(contents_value, bytearray):
326+
# Convert bytearray to base64 string
327+
contents_value = base64.b64encode(
328+
bytes(contents_value)
329+
).decode("utf-8")
330+
# If it's already a string, use as-is
331+
332+
# Contents model accepts string or bytearray
333+
item["contents"] = Contents(contents_value)
316334

317335
if "metadata" in item_dict:
318336
# Convert dict metadata to JSON string if needed
@@ -494,6 +512,12 @@ def query(
494512
response_text = raw_response.data.decode("utf-8")
495513
response_json = json.loads(response_text)
496514

515+
# Determine include filtering strategy
516+
include_all = (
517+
include is None
518+
) # None means include everything server returns
519+
include_set = set(include) if include else set()
520+
497521
# Process the results as plain dictionaries
498522
results = []
499523
if "results" in response_json:
@@ -506,23 +530,37 @@ def query(
506530
query_items = []
507531
for item in query_result:
508532
result_item = {"id": item["id"]}
533+
534+
# Always include distance if present (core part of query results)
509535
if "distance" in item:
510536
result_item["distance"] = item["distance"]
511-
if "metadata" in item:
537+
538+
# Check metadata against include list
539+
if "metadata" in item and (
540+
include_all or "metadata" in include_set
541+
):
512542
result_item["metadata"] = item["metadata"]
543+
513544
query_items.append(result_item)
514545
results.append(query_items)
515546
else:
516547
# It's a flat list (single query results)
517548
query_items = []
518549
for item in response_json["results"]:
519550
result_item = {"id": item["id"]}
551+
552+
# Always include distance if present (core part of query results)
520553
if "distance" in item:
521554
result_item["distance"] = item["distance"]
522-
if "metadata" in item:
555+
556+
# Check metadata against include list
557+
if "metadata" in item and (
558+
include_all or "metadata" in include_set
559+
):
523560
result_item["metadata"] = item["metadata"]
561+
524562
query_items.append(result_item)
525-
results.append(query_items)
563+
results = query_items
526564

527565
return results
528566
except Exception as e:

0 commit comments

Comments
 (0)