diff --git a/gum/batcher.py b/gum/batcher.py
index af7add3..644d4e8 100644
--- a/gum/batcher.py
+++ b/gum/batcher.py
@@ -1,55 +1,35 @@
-import asyncio
-import json
 import logging
-import os
-from datetime import datetime, timezone, timedelta
-from typing import List, Dict, Any, Optional
-from dataclasses import dataclass, asdict
+from datetime import datetime, timezone
+from typing import List, Optional, Dict, Any
 from pathlib import Path
-
-@dataclass
-class BatchedObservation:
-    """Represents a batched observation waiting for processing."""
-    id: str
-    observer_name: str
-    content: str
-    content_type: str
-    timestamp: datetime
-    processed: bool = False
+import uuid
+from persistqueue import Queue
 
 
 class ObservationBatcher:
-    """Handles batching of observations to reduce API calls."""
+    """A persistent queue for batching observations to reduce API calls."""
 
-    def __init__(self, data_directory: str, batch_interval_hours: float = 1, max_batch_size: int = 50):
+    def __init__(self, data_directory: str, batch_interval_minutes: float = 2, max_batch_size: int = 50):
         self.data_directory = Path(data_directory)
-        self.batch_interval_hours = batch_interval_hours
+        self.batch_interval_minutes = batch_interval_minutes
         self.max_batch_size = max_batch_size
-        self.batch_file = self.data_directory / "batches" / "pending_observations.json"
-        self.batch_file.parent.mkdir(exist_ok=True)
+
+        # Create persistent queue backed by SQLite
+        queue_dir = self.data_directory / "batches"
+        queue_dir.mkdir(parents=True, exist_ok=True)
+        self._queue = Queue(path=str(queue_dir / "queue"))
 
         self.logger = logging.getLogger("gum.batcher")
-        self._pending_observations: List[BatchedObservation] = []
-        self._batch_task: Optional[asyncio.Task] = None
 
     async def start(self):
         """Start the batching system."""
-        self._load_pending_observations()
-        self._batch_task = asyncio.create_task(self._batch_loop())
-        self.logger.info(f"Started batcher with {len(self._pending_observations)} pending observations")
+        self.logger.info(f"Started batcher with {self._queue.qsize()} items in queue")
 
     async def stop(self):
         """Stop the batching system."""
-        if self._batch_task:
-            self._batch_task.cancel()
-            try:
-                await self._batch_task
-            except asyncio.CancelledError:
-                pass
-        self._save_pending_observations()
         self.logger.info("Stopped batcher")
 
-    def add_observation(self, observer_name: str, content: str, content_type: str) -> str:
-        """Add an observation to the batch queue.
+    def push(self, observer_name: str, content: str, content_type: str) -> str:
+        """Push an observation onto the queue.
 
         Args:
             observer_name: Name of the observer
@@ -59,116 +39,42 @@ def add_observation(self, observer_name: str, content: str, content_type: str) -
         Returns:
             str: Observation ID
         """
-        import uuid
-
-        observation = BatchedObservation(
-            id=str(uuid.uuid4()),
-            observer_name=observer_name,
-            content=content,
-            content_type=content_type,
-            timestamp=datetime.now(timezone.utc)
-        )
-
-        self._pending_observations.append(observation)
-        self.logger.debug(f"Added observation {observation.id} to batch (total: {len(self._pending_observations)})")
+        observation_id = str(uuid.uuid4())
+        observation_dict = {
+            'id': observation_id,
+            'observer_name': observer_name,
+            'content': content,
+            'content_type': content_type,
+            'timestamp': datetime.now(timezone.utc).isoformat()
+        }
 
-        # Save immediately to prevent data loss
-        self._save_pending_observations()
+        # Add to queue - automatically persisted by persist-queue
+        self._queue.put(observation_dict)
+        self.logger.debug(f"Pushed observation {observation_id} to queue (size: {self._queue.qsize()})")
 
-        return observation.id
+        return observation_id
 
-    def get_pending_count(self) -> int:
-        """Get the number of pending observations."""
-        return len([obs for obs in self._pending_observations if not obs.processed])
+    def size(self) -> int:
+        """Get the current size of the queue."""
+        return self._queue.qsize()
 
-    def get_batch(self, max_size: Optional[int] = None) -> List[BatchedObservation]:
-        """Get a batch of unprocessed observations.
+    def pop_batch(self, batch_size: Optional[int] = None) -> List[Dict[str, Any]]:
+        """Pop a batch of observations from the front of the queue (FIFO).
 
         Args:
-            max_size: Maximum number of observations to return
+            batch_size: Number of items to pop. Defaults to max_batch_size
 
         Returns:
-            List of batched observations
+            List of observation dictionaries popped from queue
         """
-        unprocessed = [obs for obs in self._pending_observations if not obs.processed]
-        max_size = max_size or self.max_batch_size
-        return unprocessed[:max_size]
+        batch_size = batch_size or self.max_batch_size
 
-    def mark_processed(self, observation_ids: List[str]):
-        """Mark observations as processed.
+        batch = []
+        for _ in range(min(batch_size, self._queue.qsize())):
+            batch.append(self._queue.get_nowait())
 
-        Args:
-            observation_ids: List of observation IDs to mark as processed
-        """
-        for obs in self._pending_observations:
-            if obs.id in observation_ids:
-                obs.processed = True
-
-        # Remove processed observations older than 24 hours
-        cutoff_time = datetime.now(timezone.utc) - timedelta(hours=24)
-        self._pending_observations = [
-            obs for obs in self._pending_observations
-            if not obs.processed or obs.timestamp > cutoff_time
-        ]
+        if batch:
+            self.logger.debug(f"Popped batch of {len(batch)} observations (queue size: {self._queue.qsize()})")
 
-        self._save_pending_observations()
-        self.logger.debug(f"Marked {len(observation_ids)} observations as processed")
-
-    async def _batch_loop(self):
-        """Main batching loop that processes observations periodically."""
-        while True:
-            try:
-                # Wait for the batch interval
-                await asyncio.sleep(self.batch_interval_hours * 3600)
-
-                # Get pending observations
-                batch = self.get_batch()
-                if batch:
-                    self.logger.info(f"Processing batch of {len(batch)} observations")
-                    # Signal that we have a batch ready
-                    # The main GUM class will handle the actual processing
-                    # For now, just log that we have a batch
-                    self.logger.info(f"Batch ready with {len(batch)} observations")
-                else:
-                    self.logger.debug("No observations to process in this batch")
-
-            except asyncio.CancelledError:
-                break
-            except Exception as e:
-                self.logger.error(f"Error in batch loop: {e}")
-                await asyncio.sleep(60)  # Wait a minute before retrying
-
-    def _load_pending_observations(self):
-        """Load pending observations from disk."""
-        if self.batch_file.exists():
-            try:
-                with open(self.batch_file, 'r') as f:
-                    data = json.load(f)
-                self._pending_observations = [
-                    BatchedObservation(**obs_data)
-                    for obs_data in data
-                ]
-                # Convert timestamp strings back to datetime objects
-                for obs in self._pending_observations:
-                    if isinstance(obs.timestamp, str):
-                        obs.timestamp = datetime.fromisoformat(obs.timestamp.replace('Z', '+00:00'))
-            except Exception as e:
-                self.logger.error(f"Error loading pending observations: {e}")
-                self._pending_observations = []
-        else:
-            self._pending_observations = []
-
-    def _save_pending_observations(self):
-        """Save pending observations to disk."""
-        try:
-            # Convert datetime objects to ISO format strings
-            data = []
-            for obs in self._pending_observations:
-                obs_dict = asdict(obs)
-                obs_dict['timestamp'] = obs.timestamp.isoformat()
-                data.append(obs_dict)
-
-            with open(self.batch_file, 'w') as f:
-                json.dump(data, f, indent=2)
-        except Exception as e:
-            self.logger.error(f"Error saving pending observations: {e}")
\ No newline at end of file
+        return batch
+        
\ No newline at end of file
diff --git a/gum/cli.py b/gum/cli.py
index 5955598..b637326 100644
--- a/gum/cli.py
+++ b/gum/cli.py
@@ -31,8 +31,7 @@ def parse_args():
     parser.add_argument('--reset-cache', action='store_true', help='Reset the GUM cache and exit')  # Add this line
 
     # Batching configuration arguments
-    parser.add_argument('--use-batched-client', action='store_true', help='Enable batched client processing')
-    parser.add_argument('--batch-interval-hours', type=float, help='Hours between batch processing')
+    parser.add_argument('--batch-interval-minutes', type=float, help='Minutes between batch processing')
     parser.add_argument('--max-batch-size', type=int, help='Maximum number of observations per batch')
 
     args = parser.parse_args()
@@ -58,10 +57,8 @@ async def main():
     model = args.model or os.getenv('MODEL_NAME') or 'gpt-4o-mini'
     user_name = args.user_name or os.getenv('USER_NAME')
 
-    # Batching configuration - follow same pattern as other args
-    use_batched_client = args.use_batched_client or os.getenv('USE_BATCHED_CLIENT', 'false').lower() == 'true'
-
-    batch_interval_hours = args.batch_interval_hours or float(os.getenv('BATCH_INTERVAL_HOURS', '1'))
+    # Batching configuration - follow same pattern as other args
+    batch_interval_minutes = args.batch_interval_minutes or float(os.getenv('BATCH_INTERVAL_MINUTES', '2'))
     max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '50'))
 
     # you need one or the other
@@ -86,17 +83,12 @@ async def main():
         print("-" * 80)
     else:
         print(f"Listening to {user_name} with model {model}")
-        if use_batched_client:
-            print(f"Batching enabled: processing every {batch_interval_hours} hours (max {max_batch_size} observations per batch)")
-        else:
-            print("Batching disabled: processing observations immediately")
 
     async with gum(
         user_name,
         model,
         Screen(model),
-        use_batched_client=use_batched_client,
-        batch_interval_hours=batch_interval_hours,
+        batch_interval_minutes=batch_interval_minutes,
         max_batch_size=max_batch_size
     ) as gum_instance:
         await asyncio.Future()  # run forever (Ctrl-C to stop)
diff --git a/gum/db_utils.py b/gum/db_utils.py
index 026ff7c..7d43a02 100644
--- a/gum/db_utils.py
+++ b/gum/db_utils.py
@@ -26,7 +26,6 @@ from .models import (
     Observation,
     Proposition,
-    proposition_parent,
     observation_proposition,
 )
 
 
@@ -45,13 +44,7 @@ def build_fts_query(raw: str, mode: str = "OR") -> str:
     else:               # implicit AND
         return " ".join(tokens)
 
-def _has_child_subquery() -> select:
-    return (
-        select(literal_column("1"))
-        .select_from(proposition_parent)
-        .where(proposition_parent.c.parent_id == Proposition.id)
-        .exists()
-    )
+
 
 
 async def search_propositions_bm25(
@@ -74,7 +67,6 @@ async def search_propositions_bm25(
     #  1  Build candidate list
     # --------------------------------------------------------
     candidate_pool = limit * 10 if enable_mmr else limit
-    has_child = _has_child_subquery()
 
     if has_query:
         fts_prop = Table("propositions_fts", MetaData())
@@ -143,14 +135,12 @@ async def search_propositions_bm25(
         stmt = (
             select(Proposition, best_scores.c.bm25)
             .join(best_scores, best_scores.c.pid == Proposition.id)
-            .where(~has_child)
             .order_by(best_scores.c.bm25.asc())      # smallest→best
         )
     else:
         # --- 1-b  No user query ------------------------------
         stmt = (
             select(Proposition, literal_column("0.0").label("bm25"))
-            .where(~has_child)
             .order_by(Proposition.created_at.desc())
         )
diff --git a/gum/gum.py b/gum/gum.py
index 051aa5d..c81198a 100644
--- a/gum/gum.py
+++ b/gum/gum.py
@@ -11,6 +11,7 @@ from datetime import datetime, timezone
 from typing import Callable, List
 
 from .models import observation_proposition
+import traceback
 
 from openai import AsyncOpenAI
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -49,7 +50,7 @@ class gum:
         audit_prompt (str, optional): Custom prompt for auditing.
         data_directory (str, optional): Directory for storing data. Defaults to "~/.cache/gum".
         db_name (str, optional): Name of the database file. Defaults to "gum.db".
-        max_concurrent_updates (int, optional): Maximum number of concurrent updates. Defaults to 4.
+        verbosity (int, optional): Logging verbosity level. Defaults to logging.INFO.
         audit_enabled (bool, optional): Whether to enable auditing. Defaults to False.
""" @@ -65,13 +66,11 @@ def __init__( audit_prompt: str | None = None, data_directory: str = "~/.cache/gum", db_name: str = "gum.db", - max_concurrent_updates: int = 4, verbosity: int = logging.INFO, audit_enabled: bool = False, api_base: str | None = None, api_key: str | None = None, - use_batched_client: bool = True, - batch_interval_hours: float = 1, + batch_interval_minutes: float = 2, max_batch_size: int = 50, ): # basic paths @@ -85,8 +84,7 @@ def __init__( self.audit_enabled = audit_enabled # batching configuration - self.use_batched_client = use_batched_client - self.batch_interval_hours = batch_interval_hours + self.batch_interval_minutes = batch_interval_minutes self.max_batch_size = max_batch_size # logging @@ -114,19 +112,15 @@ def __init__( self._data_directory = data_directory # Initialize batcher if enabled - if self.use_batched_client: - self.batcher = ObservationBatcher( - data_directory=data_directory, - batch_interval_hours=batch_interval_hours, - max_batch_size=max_batch_size - ) - else: - self.batcher = None + self.batcher = ObservationBatcher( + data_directory=data_directory, + batch_interval_minutes=batch_interval_minutes, + max_batch_size=max_batch_size + ) - self._update_sem = asyncio.Semaphore(max_concurrent_updates) - self._tasks: set[asyncio.Task] = set() self._loop_task: asyncio.Task | None = None self._batch_task: asyncio.Task | None = None + self._batch_processing_lock = asyncio.Lock() self.update_handlers: list[Callable[[Observer, Update], None]] = [] def start_update_loop(self): @@ -135,7 +129,7 @@ def start_update_loop(self): self._loop_task = asyncio.create_task(self._update_loop()) # Start batch processing if enabled - if self.use_batched_client and self.batcher and self._batch_task is None: + if self._batch_task is None: self._batch_task = asyncio.create_task(self._batch_processing_loop()) async def stop_update_loop(self): @@ -192,10 +186,6 @@ async def __aexit__(self, exc_type, exc, tb): """ await self.stop_update_loop() - # wait for any in-flight handlers - if self._tasks: - await asyncio.gather(*self._tasks, return_exceptions=True) - # stop observers for obs in self.observers: await obs.stop() @@ -220,21 +210,22 @@ async def _update_loop(self): upd: Update = fut.result() obs = gets[fut] - t = asyncio.create_task(self._run_with_gate(obs, upd)) - self._tasks.add(t) + asyncio.create_task(self._default_handler(obs, upd)) async def _batch_processing_loop(self): """Process batched observations periodically to reduce API calls.""" while True: try: # Wait for the batch interval - await asyncio.sleep(self.batch_interval_hours * 3600) + await asyncio.sleep(self.batch_interval_minutes * 60) # Get pending observations - batch = self.batcher.get_batch() + batch = self.batcher.pop_batch() if batch: self.logger.info(f"Processing batch of {len(batch)} observations") - await self._process_batch(batch) + # Use lock to ensure batch processing runs synchronously + async with self._batch_processing_lock: + await self._process_batch(batch) else: self.logger.debug("No observations to process in this batch") @@ -256,15 +247,15 @@ async def _process_batch(self, batched_observations): observation_ids = [] for obs in batched_observations: - combined_content.append(f"[{obs.observer_name}] {obs.content}") - observation_ids.append(obs.id) + combined_content.append(f"[{obs['observer_name']}] {obs['content']}") + observation_ids.append(obs['id']) combined_text = "\n\n".join(combined_content) # Create a combined update combined_update = Update( content=combined_text, - 
content_type="text" + content_type="input_text" ) try: @@ -273,9 +264,9 @@ async def _process_batch(self, batched_observations): observations = [] for obs in batched_observations: observation = Observation( - observer_name=obs.observer_name, - content=obs.content, - content_type=obs.content_type, + observer_name=obs['observer_name'], + content=obs['content'], + content_type=obs['content_type'], ) session.add(observation) observations.append(observation) @@ -283,43 +274,27 @@ async def _process_batch(self, batched_observations): await session.flush() # Process the combined content - pool = await self._generate_and_search(session, combined_update, observations[0]) - - if pool: - self.logger.info(f"Linking batch observations to {len(pool)} candidate propositions.") - for prop in pool: - for obs in observations: - await self._attach_obs_if_missing(prop, obs, session) - await session.flush() - + pool = await self._generate_and_search(session, combined_update) identical, similar, different = await self._filter_propositions(pool) self.logger.info("Applying proposition updates for batch...") - await self._handle_identical(session, identical, observations[0]) - await self._handle_similar(session, similar, observations[0]) - await self._handle_different(session, different, observations[0]) - - # Mark observations as processed - self.batcher.mark_processed(observation_ids) + await self._handle_identical(session, identical, observations) + await self._handle_similar(session, similar, observations) + await self._handle_different(session, different, observations) + # Observations are already removed from queue by pop_batch() self.logger.info(f"Completed processing batch of {len(batched_observations)} observations") except Exception as e: self.logger.error(f"Error processing batch: {e}") - # Don't mark as processed if there was an error - - async def _run_with_gate(self, observer: Observer, update: Update): - """Wrapper that enforces max_concurrent_updates. - - Args: - observer (Observer): The observer that generated the update. - update (Update): The update to process. - """ - async with self._update_sem: - try: - await self._default_handler(observer, update) - finally: - self._tasks.discard(asyncio.current_task()) + self.logger.error(f"Traceback: {traceback.format_exc()}") + self.logger.error(f"Batch size: {len(batched_observations)}") + if batched_observations: + self.logger.error(f"First observation type: {type(batched_observations[0])}") + self.logger.error(f"First observation: {batched_observations[0]}") + # Put failed items back in queue for retry + for obs in batched_observations: + self.batcher.push(obs['observer_name'], obs['content'], obs['content_type']) async def _construct_propositions(self, update: Update) -> list[PropositionItem]: """Generate propositions from an update. 
@@ -459,7 +434,7 @@ async def _revise_propositions(
         return json.loads(rsp.choices[0].message.content)["propositions"]
 
     async def _generate_and_search(
-        self, session: AsyncSession, update: Update, obs: Observation
+        self, session: AsyncSession, update: Update
     ) -> list[Proposition]:
         drafts_raw = await self._construct_propositions(update)
 
@@ -482,7 +457,7 @@ async def _generate_and_search(
             hits = await search_propositions_bm25(
                 session, f"{draft.text}\n{draft.reasoning}",
                 mode="OR", include_observations=False,
-                enable_mmr=True,
+                enable_mmr=False,
                 enable_decay=True
             )
 
@@ -498,58 +473,60 @@ async def _generate_and_search(
         return list(pool.values())
 
     async def _handle_identical(
-        self, session, identical: list[Proposition], obs: Observation
+        self, session, identical: list[Proposition], observations: list[Observation]
     ) -> None:
         for p in identical:
-            await self._attach_obs_if_missing(p, obs, session)
+            for obs in observations:
+                await self._attach_obs_if_missing(p, obs, session)
 
     async def _handle_similar(
         self,
         session: AsyncSession,
         similar: list[Proposition],
-        obs: Observation,
+        observations: list[Observation],
     ) -> None:
         if not similar:
             return
 
+        # Collect all observations from similar propositions
         rel_obs = {
             o
             for p in similar
             for o in await get_related_observations(session, p.id)
         }
-        rel_obs.add(obs)
+        # Add all the batched observations
+        rel_obs.update(observations)
 
+        # Generate revised propositions
         revised_items = await self._revise_propositions(list(rel_obs), similar)
 
-        newest_version = max(p.version for p in similar)
-        parent_groups = {p.revision_group for p in similar}
-        if len(parent_groups) == 1:
-            revision_group = parent_groups.pop()
-        else:
-            revision_group = uuid4().hex
-
-        new_children: list[Proposition] = []
+
+        # Delete all old similar propositions
+        for prop in similar:
+            await session.delete(prop)
+
+        # Create new propositions to replace them
+        revision_group = str(uuid4())
         for item in revised_items:
-            child = Proposition(
+            new_prop = Proposition(
                 text=item["proposition"],
                 reasoning=item["reasoning"],
                 confidence=item.get("confidence"),
                 decay=item.get("decay"),
-                version=newest_version + 1,
+                version=1,  # Start fresh with version 1
                 revision_group=revision_group,
                 observations=rel_obs,
-                parents=set(similar),
             )
-            session.add(child)
-            new_children.append(child)
+            session.add(new_prop)
 
         await session.flush()
 
     async def _handle_different(
-        self, session, different: list[Proposition], obs: Observation
+        self, session, different: list[Proposition], observations: list[Observation]
    ) -> None:
         for p in different:
-            await self._attach_obs_if_missing(p, obs, session)
+            for obs in observations:
+                await self._attach_obs_if_missing(p, obs, session)
 
     async def _handle_audit(self, obs: Observation) -> bool:
         if not self.audit_enabled:
@@ -609,45 +586,13 @@ async def _handle_audit(self, obs: Observation) -> bool:
 
     async def _default_handler(self, observer: Observer, update: Update) -> None:
         self.logger.info(f"Processing update from {observer.name}")
 
-        # If batching is enabled, add to batch instead of processing immediately
-        if self.use_batched_client and self.batcher:
-            observation_id = self.batcher.add_observation(
-                observer_name=observer.name,
-                content=update.content,
-                content_type=update.content_type
-            )
-            self.logger.info(f"Added observation {observation_id} to batch (pending: {self.batcher.get_pending_count()})")
-            return
-
-        # Original processing logic for non-batched mode
-        async with self._session() as session:
-            observation = Observation(
-                observer_name=observer.name,
-                content=update.content,
-                content_type=update.content_type,
-            )
-
-            if await self._handle_audit(observation):
-                return
-
-            session.add(observation)
-            await session.flush()  # Observation gets its ID
-
-            pool = await self._generate_and_search(session, update, observation)
-
-            if pool:
-                self.logger.info(f"Linking observation to {len(pool)} candidate propositions.")
-                for prop in pool:
-                    await self._attach_obs_if_missing(prop, observation, session)
-                await session.flush()
-
-            identical, similar, different = await self._filter_propositions(pool)
-
-            self.logger.info("Applying proposition updates...")
-            await self._handle_identical(session, identical, observation)
-            await self._handle_similar(session, similar, observation)
-            await self._handle_different(session, different, observation)
-            self.logger.info("Completed processing update")
+        # add to batch
+        observation_id = self.batcher.push(
+            observer_name=observer.name,
+            content=update.content,
+            content_type=update.content_type
+        )
+        self.logger.info(f"Added observation {observation_id} to queue (size: {self.batcher.size()})")
 
     @asynccontextmanager
     async def _session(self):
diff --git a/gum/models.py b/gum/models.py
index 73d8907..df903ae 100644
--- a/gum/models.py
+++ b/gum/models.py
@@ -54,22 +54,7 @@ class Base(AsyncAttrs, DeclarativeBase):
     ),
 )
 
-proposition_parent = Table(
-    "proposition_parent",
-    Base.metadata,
-    Column(
-        "child_id",
-        Integer,
-        ForeignKey("propositions.id", ondelete="CASCADE"),
-        primary_key=True,
-    ),
-    Column(
-        "parent_id",
-        Integer,
-        ForeignKey("propositions.id", ondelete="CASCADE"),
-        primary_key=True,
-    ),
-)
+
 
 
 class Observation(Base):
@@ -138,7 +123,7 @@ class Proposition(Base):
         updated_at (datetime): When the proposition was last updated.
         revision_group (str): Group identifier for related proposition revisions.
         version (int): Version number of this proposition.
-        parents (set[Proposition]): Set of parent propositions.
+        observations (set[Observation]): Set of observations related to this proposition.
     """
 
     __tablename__ = "propositions"
@@ -162,15 +147,7 @@ class Proposition(Base):
     revision_group: Mapped[str] = mapped_column(String(36), nullable=False, index=True)
     version: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False)
 
-    parents: Mapped[set["Proposition"]] = relationship(
-        "Proposition",
-        secondary=proposition_parent,
-        primaryjoin=id == proposition_parent.c.child_id,
-        secondaryjoin=id == proposition_parent.c.parent_id,
-        backref="children",
-        collection_class=set,
-        lazy="selectin",
-    )
+
 
     observations: Mapped[set[Observation]] = relationship(
         "Observation",
diff --git a/gum/prompts/gum.py b/gum/prompts/gum.py
index 08f8593..0704a78 100644
--- a/gum/prompts/gum.py
+++ b/gum/prompts/gum.py
@@ -83,7 +83,7 @@
 # Task
 
-Generate **5 distinct, well-supported propositions** about {user_name}, each grounded in the transcript.
+Generate **at least 5 distinct, well-supported propositions** about {user_name}, each grounded in the transcript.
 
 Be conservative in your confidence estimates. Just because an application appears on {user_name}'s screen does not mean they have deeply engaged with it. They may have only glanced at it for a second, making it difficult to draw strong conclusions.
diff --git a/gum/schemas.py b/gum/schemas.py
index abaf4e6..a8cb54b 100644
--- a/gum/schemas.py
+++ b/gum/schemas.py
@@ -33,7 +33,7 @@ class PropositionItem(BaseModel):
 class PropositionSchema(BaseModel):
     propositions: List[PropositionItem] = Field(
         ...,
-        description="Up to five propositions"
+        description="Up to K propositions"
     )
 
     model_config = ConfigDict(extra="forbid")
diff --git a/pyproject.toml b/pyproject.toml
index de96629..32de5e1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "gum-ai"
-version = "0.1.10"
+version = "0.1.11"
 readme = "README.md"
 authors = [{ name = "Omar Shaikh", email = "oshaikh13@gmail.com" }]
 license = {text = "MIT"}
@@ -22,6 +22,7 @@ dependencies = [
     "scikit-learn",
     "aiosqlite",
     "greenlet",
+    "persist-queue",
     "mkdocs>=1.5.0",
     "mkdocs-material>=9.0.0",
     "mkdocstrings>=0.24.0",
diff --git a/setup.py b/setup.py
index 1a37865..d9dfb94 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="gum",
-    version="0.1.10",
+    version="0.1.11",
     packages=find_packages(),
     install_requires=[
         # Core dependencies
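Usage sketch (illustrative only, not part of the patch): the reworked ObservationBatcher wraps a persist-queue Queue, so observers call push() for each observation and the batch loop drains the queue with pop_batch(). The data directory, observer name, and content values below are made-up examples.

```python
import asyncio
from gum.batcher import ObservationBatcher

async def demo():
    # Hypothetical scratch directory; gum's real default lives under ~/.cache/gum
    batcher = ObservationBatcher(
        data_directory="/tmp/gum-demo",
        batch_interval_minutes=2,
        max_batch_size=50,
    )
    await batcher.start()

    # Observers push raw observations; persist-queue writes each item to disk
    obs_id = batcher.push(
        observer_name="Screen",
        content="user switched to VS Code",
        content_type="input_text",
    )
    print(f"queued {obs_id}, queue size = {batcher.size()}")

    # The processing loop later pops up to max_batch_size items in FIFO order
    for obs in batcher.pop_batch():
        print(obs["observer_name"], obs["content"])

    await batcher.stop()

asyncio.run(demo())
```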