diff --git a/gum/cli.py b/gum/cli.py index 2dbc2f9..8f1b301 100644 --- a/gum/cli.py +++ b/gum/cli.py @@ -25,6 +25,11 @@ def parse_args(): action=QueryAction, help='Query the GUM with an optional query string', ) + parser.add_argument( + '--recent', '-r', + action='store_true', + help='List the most recent propositions instead of running BM25 search', + ) parser.add_argument('--limit', '-l', type=int, help='Limit the number of results', default=10) parser.add_argument('--model', '-m', type=str, help='Model to use') @@ -61,12 +66,25 @@ async def main(): min_batch_size = args.min_batch_size or int(os.getenv('MIN_BATCH_SIZE', '5')) max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '15')) - # you need one or the other - if user_name is None and args.query is None: - print("Please provide a user name (as an argument, -u, or as an env variable) or a query (as an argument, -q)") + # you need one of: user_name for listening mode, --query, or --recent + if user_name is None and args.query is None and not getattr(args, 'recent', False): + print("Please provide a user name (-u), a query (-q), or use --recent to list latest propositions") return - if args.query is not None: + if getattr(args, 'recent', False): + gum_instance = gum(user_name or os.getenv('USER_NAME') or 'default', model) + await gum_instance.connect_db() + props = await gum_instance.recent(limit=args.limit) + print(f"\nRecent {len(props)} propositions:") + for p in props: + print(f"\nProposition: {p.text}") + if p.reasoning: + print(f"Reasoning: {p.reasoning}") + if p.confidence is not None: + print(f"Confidence: {p.confidence:.2f}") + print(f"Created At: {p.created_at}") + print("-" * 80) + elif args.query is not None: gum_instance = gum(user_name, model) await gum_instance.connect_db() result = await gum_instance.query(args.query, limit=args.limit) diff --git a/gum/db_utils.py b/gum/db_utils.py index 7d43a02..168cd3b 100644 --- a/gum/db_utils.py +++ b/gum/db_utils.py @@ -243,5 +243,76 @@ async def get_related_observations( .order_by(Observation.created_at.desc()) .limit(limit) ) + result = await session.execute(stmt) + return result.scalars().all() + + +async def get_recent_propositions( + session: AsyncSession, + *, + limit: int = 10, + start_time: datetime | None = None, + end_time: datetime | None = None, + include_observations: bool = False, +) -> List[Proposition]: + """Fetch the most recent propositions ordered by created_at desc. + + Args: + session: Active async DB session + limit: Max number of propositions to return + start_time: Optional lower bound for created_at + end_time: Optional upper bound for created_at (defaults to now) + include_observations: Whether to eager-load related observations + + Returns: + List[Proposition]: Most recent propositions + """ + + if end_time is None: + end_time = datetime.now(timezone.utc) + if start_time is not None and start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=timezone.utc) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=timezone.utc) + + stmt = ( + select(Proposition) + .where(Proposition.created_at <= end_time) + .order_by(Proposition.created_at.desc()) + .limit(limit) + ) + if start_time is not None: + stmt = stmt.where(Proposition.created_at >= start_time) + if include_observations: + stmt = stmt.options(selectinload(Proposition.observations)) + + result = await session.execute(stmt) + return result.scalars().all() + + +async def get_recent_observations( + session: AsyncSession, + *, + limit: int = 10, + start_time: datetime | None = None, + end_time: datetime | None = None, +) -> List[Observation]: + """Fetch the most recent observations ordered by created_at desc.""" + if end_time is None: + end_time = datetime.now(timezone.utc) + if start_time is not None and start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=timezone.utc) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=timezone.utc) + + stmt = ( + select(Observation) + .where(Observation.created_at <= end_time) + .order_by(Observation.created_at.desc()) + .limit(limit) + ) + if start_time is not None: + stmt = stmt.where(Observation.created_at >= start_time) + result = await session.execute(stmt) return result.scalars().all() \ No newline at end of file diff --git a/gum/gum.py b/gum/gum.py index 74345bc..af59c2f 100644 --- a/gum/gum.py +++ b/gum/gum.py @@ -20,6 +20,8 @@ from .db_utils import ( get_related_observations, search_propositions_bm25, + get_recent_propositions, + get_recent_observations, ) from .models import Observation, Proposition, init_db from .observers import Observer @@ -650,3 +652,37 @@ async def query( start_time=start_time, end_time=end_time, ) + + async def recent( + self, + *, + limit: int = 10, + start_time: datetime | None = None, + end_time: datetime | None = None, + include_observations: bool = False, + ) -> list[Proposition]: + """Return the most recent propositions ordered by created_at descending.""" + async with self._session() as session: + return await get_recent_propositions( + session, + limit=limit, + start_time=start_time, + end_time=end_time, + include_observations=include_observations, + ) + + async def recent_observations( + self, + *, + limit: int = 10, + start_time: datetime | None = None, + end_time: datetime | None = None, + ) -> list[Observation]: + """Return the most recent observations ordered by created_at descending.""" + async with self._session() as session: + return await get_recent_observations( + session, + limit=limit, + start_time=start_time, + end_time=end_time, + )