178 changes: 42 additions & 136 deletions gum/batcher.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,35 @@
import asyncio
import json
import logging
import os
from datetime import datetime, timezone, timedelta
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
from pathlib import Path

@dataclass
class BatchedObservation:
"""Represents a batched observation waiting for processing."""
id: str
observer_name: str
content: str
content_type: str
timestamp: datetime
processed: bool = False
import uuid
from persistqueue import Queue

class ObservationBatcher:
"""Handles batching of observations to reduce API calls."""
"""A persistent queue for batching observations to reduce API calls."""

def __init__(self, data_directory: str, batch_interval_hours: float = 1, max_batch_size: int = 50):
def __init__(self, data_directory: str, batch_interval_minutes: float = 2, max_batch_size: int = 50):
self.data_directory = Path(data_directory)
self.batch_interval_hours = batch_interval_hours
self.batch_interval_minutes = batch_interval_minutes
self.max_batch_size = max_batch_size
self.batch_file = self.data_directory / "batches" / "pending_observations.json"
self.batch_file.parent.mkdir(exist_ok=True)

# Create persistent queue backed by SQLite
queue_dir = self.data_directory / "batches"
queue_dir.mkdir(parents=True, exist_ok=True)
self._queue = Queue(path=str(queue_dir / "queue"))

self.logger = logging.getLogger("gum.batcher")
self._pending_observations: List[BatchedObservation] = []
self._batch_task: Optional[asyncio.Task] = None

async def start(self):
"""Start the batching system."""
self._load_pending_observations()
self._batch_task = asyncio.create_task(self._batch_loop())
self.logger.info(f"Started batcher with {len(self._pending_observations)} pending observations")
self.logger.info(f"Started batcher with {self._queue.qsize()} items in queue")

async def stop(self):
"""Stop the batching system."""
if self._batch_task:
self._batch_task.cancel()
try:
await self._batch_task
except asyncio.CancelledError:
pass
self._save_pending_observations()
self.logger.info("Stopped batcher")

def add_observation(self, observer_name: str, content: str, content_type: str) -> str:
"""Add an observation to the batch queue.
def push(self, observer_name: str, content: str, content_type: str) -> str:
"""Push an observation onto the queue.

Args:
observer_name: Name of the observer
Expand All @@ -59,116 +39,42 @@ def add_observation(self, observer_name: str, content: str, content_type: str) -
Returns:
str: Observation ID
"""
import uuid

observation = BatchedObservation(
id=str(uuid.uuid4()),
observer_name=observer_name,
content=content,
content_type=content_type,
timestamp=datetime.now(timezone.utc)
)

self._pending_observations.append(observation)
self.logger.debug(f"Added observation {observation.id} to batch (total: {len(self._pending_observations)})")
observation_id = str(uuid.uuid4())
observation_dict = {
'id': observation_id,
'observer_name': observer_name,
'content': content,
'content_type': content_type,
'timestamp': datetime.now(timezone.utc).isoformat()
}

# Save immediately to prevent data loss
self._save_pending_observations()
# Add to queue - automatically persisted by persist-queue
self._queue.put(observation_dict)
self.logger.debug(f"Pushed observation {observation_id} to queue (size: {self._queue.qsize()})")

return observation.id
return observation_id

def get_pending_count(self) -> int:
"""Get the number of pending observations."""
return len([obs for obs in self._pending_observations if not obs.processed])
def size(self) -> int:
"""Get the current size of the queue."""
return self._queue.qsize()

def get_batch(self, max_size: Optional[int] = None) -> List[BatchedObservation]:
"""Get a batch of unprocessed observations.
def pop_batch(self, batch_size: Optional[int] = None) -> List[Dict[str, Any]]:
"""Pop a batch of observations from the front of the queue (FIFO).

Args:
max_size: Maximum number of observations to return
batch_size: Number of items to pop. Defaults to max_batch_size

Returns:
List of batched observations
List of observation dictionaries popped from queue
"""
unprocessed = [obs for obs in self._pending_observations if not obs.processed]
max_size = max_size or self.max_batch_size
return unprocessed[:max_size]
batch_size = batch_size or self.max_batch_size

def mark_processed(self, observation_ids: List[str]):
"""Mark observations as processed.
batch = []
for _ in range(min(batch_size, self._queue.qsize())):
batch.append(self._queue.get_nowait())

Args:
observation_ids: List of observation IDs to mark as processed
"""
for obs in self._pending_observations:
if obs.id in observation_ids:
obs.processed = True

# Remove processed observations older than 24 hours
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=24)
self._pending_observations = [
obs for obs in self._pending_observations
if not obs.processed or obs.timestamp > cutoff_time
]
if batch:
self.logger.debug(f"Popped batch of {len(batch)} observations (queue size: {self._queue.qsize()})")

self._save_pending_observations()
self.logger.debug(f"Marked {len(observation_ids)} observations as processed")

async def _batch_loop(self):
"""Main batching loop that processes observations periodically."""
while True:
try:
# Wait for the batch interval
await asyncio.sleep(self.batch_interval_hours * 3600)

# Get pending observations
batch = self.get_batch()
if batch:
self.logger.info(f"Processing batch of {len(batch)} observations")
# Signal that we have a batch ready
# The main GUM class will handle the actual processing
# For now, just log that we have a batch
self.logger.info(f"Batch ready with {len(batch)} observations")
else:
self.logger.debug("No observations to process in this batch")

except asyncio.CancelledError:
break
except Exception as e:
self.logger.error(f"Error in batch loop: {e}")
await asyncio.sleep(60) # Wait a minute before retrying

def _load_pending_observations(self):
"""Load pending observations from disk."""
if self.batch_file.exists():
try:
with open(self.batch_file, 'r') as f:
data = json.load(f)
self._pending_observations = [
BatchedObservation(**obs_data)
for obs_data in data
]
# Convert timestamp strings back to datetime objects
for obs in self._pending_observations:
if isinstance(obs.timestamp, str):
obs.timestamp = datetime.fromisoformat(obs.timestamp.replace('Z', '+00:00'))
except Exception as e:
self.logger.error(f"Error loading pending observations: {e}")
self._pending_observations = []
else:
self._pending_observations = []

def _save_pending_observations(self):
"""Save pending observations to disk."""
try:
# Convert datetime objects to ISO format strings
data = []
for obs in self._pending_observations:
obs_dict = asdict(obs)
obs_dict['timestamp'] = obs.timestamp.isoformat()
data.append(obs_dict)

with open(self.batch_file, 'w') as f:
json.dump(data, f, indent=2)
except Exception as e:
self.logger.error(f"Error saving pending observations: {e}")
return batch
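
For context, a minimal usage sketch of the new queue-backed API. This is not part of the PR: the data directory, observer name, and content values below are placeholders, and it assumes the persist-queue package is installed.

import asyncio
from gum.batcher import ObservationBatcher

async def main():
    # Placeholder data directory; the queue lives under <dir>/batches/queue.
    batcher = ObservationBatcher("./gum_data", batch_interval_minutes=2, max_batch_size=50)
    await batcher.start()

    # Each push is written through to the SQLite-backed queue immediately.
    obs_id = batcher.push("Screen", "user switched to the terminal", "text")
    print(f"queued {obs_id}, queue size = {batcher.size()}")

    # A consumer drains up to max_batch_size observations per call (FIFO).
    for obs in batcher.pop_batch():
        print(obs["observer_name"], obs["content_type"], obs["timestamp"])

    await batcher.stop()

asyncio.run(main())

Note that the periodic _batch_loop is gone: start() no longer schedules a task, so whatever consumes pop_batch() must now drive its own schedule (presumably the main GUM class).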

16 changes: 4 additions & 12 deletions gum/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def parse_args():
parser.add_argument('--reset-cache', action='store_true', help='Reset the GUM cache and exit') # Add this line

# Batching configuration arguments
parser.add_argument('--use-batched-client', action='store_true', help='Enable batched client processing')
parser.add_argument('--batch-interval-hours', type=float, help='Hours between batch processing')
parser.add_argument('--batch-interval-minutes', type=float, help='Minutes between batch processing')
parser.add_argument('--max-batch-size', type=int, help='Maximum number of observations per batch')

args = parser.parse_args()
Expand All @@ -58,10 +57,8 @@ async def main():
model = args.model or os.getenv('MODEL_NAME') or 'gpt-4o-mini'
user_name = args.user_name or os.getenv('USER_NAME')

# Batching configuration - follow same pattern as other args
use_batched_client = args.use_batched_client or os.getenv('USE_BATCHED_CLIENT', 'false').lower() == 'true'

batch_interval_hours = args.batch_interval_hours or float(os.getenv('BATCH_INTERVAL_HOURS', '1'))
# Batching configuration - follow same pattern as other args
batch_interval_minutes = args.batch_interval_minutes or float(os.getenv('BATCH_INTERVAL_MINUTES', '2'))
max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '50'))

# you need one or the other
Expand All @@ -86,17 +83,12 @@ async def main():
print("-" * 80)
else:
print(f"Listening to {user_name} with model {model}")
if use_batched_client:
print(f"Batching enabled: processing every {batch_interval_hours} hours (max {max_batch_size} observations per batch)")
else:
print("Batching disabled: processing observations immediately")

async with gum(
user_name,
model,
Screen(model),
use_batched_client=use_batched_client,
batch_interval_hours=batch_interval_hours,
batch_interval_minutes=batch_interval_minutes,
max_batch_size=max_batch_size
) as gum_instance:
await asyncio.Future() # run forever (Ctrl-C to stop)
Expand Down
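
A small illustrative snippet (not from the PR) of the flag/env/default resolution pattern used above, with names taken from the diff:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--batch-interval-minutes', type=float)
parser.add_argument('--max-batch-size', type=int)
args = parser.parse_args(['--batch-interval-minutes', '5'])

# CLI flag wins, then the environment variable, then the hardcoded default.
# Caveat of `or`: an explicit 0 on the command line also falls through to the env/default.
batch_interval_minutes = args.batch_interval_minutes or float(os.getenv('BATCH_INTERVAL_MINUTES', '2'))
max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '50'))
print(batch_interval_minutes, max_batch_size)  # 5.0 50 when MAX_BATCH_SIZE is unset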
12 changes: 1 addition & 11 deletions gum/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from .models import (
Observation,
Proposition,
proposition_parent,
observation_proposition,
)

Expand All @@ -45,13 +44,7 @@ def build_fts_query(raw: str, mode: str = "OR") -> str:
else: # implicit AND
return " ".join(tokens)

def _has_child_subquery() -> select:
return (
select(literal_column("1"))
.select_from(proposition_parent)
.where(proposition_parent.c.parent_id == Proposition.id)
.exists()
)



async def search_propositions_bm25(
Expand All @@ -74,7 +67,6 @@ async def search_propositions_bm25(
# 1 Build candidate list
# --------------------------------------------------------
candidate_pool = limit * 10 if enable_mmr else limit
has_child = _has_child_subquery()

if has_query:
fts_prop = Table("propositions_fts", MetaData())
Expand Down Expand Up @@ -143,14 +135,12 @@ async def search_propositions_bm25(
stmt = (
select(Proposition, best_scores.c.bm25)
.join(best_scores, best_scores.c.pid == Proposition.id)
.where(~has_child)
.order_by(best_scores.c.bm25.asc()) # smallest→best
)
else:
# --- 1-b No user query ------------------------------
stmt = (
select(Proposition, literal_column("0.0").label("bm25"))
.where(~has_child)
.order_by(Proposition.created_at.desc())
)

Expand Down
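
For reviewers unfamiliar with the deleted helper: it built a correlated EXISTS filter so that only leaf propositions (those never referenced as a parent_id) were returned. A self-contained sketch of that old filter, with the tables reduced to what the diff shows (the child_id column is an assumption, not confirmed by this PR):

from sqlalchemy import Column, Integer, MetaData, Table, literal_column, select

metadata = MetaData()
propositions = Table("propositions", metadata, Column("id", Integer, primary_key=True))
proposition_parent = Table(
    "proposition_parent", metadata,
    Column("parent_id", Integer),  # referenced by the deleted helper
    Column("child_id", Integer),   # assumed second column of the association table
)

# EXISTS(SELECT 1 FROM proposition_parent WHERE parent_id = propositions.id)
has_child = (
    select(literal_column("1"))
    .select_from(proposition_parent)
    .where(proposition_parent.c.parent_id == propositions.c.id)
    .exists()
)

leaf_only = select(propositions).where(~has_child)  # old behavior
all_props = select(propositions)                    # new behavior: no leaf filter

With the filter removed, search_propositions_bm25 now ranks parent propositions alongside leaves in both the FTS and the no-query paths.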