Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions gum/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ def parse_args():
action=QueryAction,
help='Query the GUM with an optional query string',
)
parser.add_argument(
'--recent', '-r',
action='store_true',
help='List the most recent propositions instead of running BM25 search',
)

parser.add_argument('--limit', '-l', type=int, help='Limit the number of results', default=10)
parser.add_argument('--model', '-m', type=str, help='Model to use')
Expand Down Expand Up @@ -61,12 +66,25 @@ async def main():
min_batch_size = args.min_batch_size or int(os.getenv('MIN_BATCH_SIZE', '5'))
max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '15'))

# you need one or the other
if user_name is None and args.query is None:
print("Please provide a user name (as an argument, -u, or as an env variable) or a query (as an argument, -q)")
# you need one of: user_name for listening mode, --query, or --recent
if user_name is None and args.query is None and not getattr(args, 'recent', False):
print("Please provide a user name (-u), a query (-q), or use --recent to list latest propositions")
return

if args.query is not None:
if getattr(args, 'recent', False):
gum_instance = gum(user_name or os.getenv('USER_NAME') or 'default', model)
await gum_instance.connect_db()
props = await gum_instance.recent(limit=args.limit)
print(f"\nRecent {len(props)} propositions:")
for p in props:
print(f"\nProposition: {p.text}")
if p.reasoning:
print(f"Reasoning: {p.reasoning}")
if p.confidence is not None:
print(f"Confidence: {p.confidence:.2f}")
print(f"Created At: {p.created_at}")
print("-" * 80)
elif args.query is not None:
gum_instance = gum(user_name, model)
await gum_instance.connect_db()
result = await gum_instance.query(args.query, limit=args.limit)
Expand Down
71 changes: 71 additions & 0 deletions gum/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,76 @@ async def get_related_observations(
.order_by(Observation.created_at.desc())
.limit(limit)
)
result = await session.execute(stmt)
return result.scalars().all()


async def get_recent_propositions(
session: AsyncSession,
*,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
include_observations: bool = False,
) -> List[Proposition]:
"""Fetch the most recent propositions ordered by created_at desc.

Args:
session: Active async DB session
limit: Max number of propositions to return
start_time: Optional lower bound for created_at
end_time: Optional upper bound for created_at (defaults to now)
include_observations: Whether to eager-load related observations

Returns:
List[Proposition]: Most recent propositions
"""

if end_time is None:
end_time = datetime.now(timezone.utc)
if start_time is not None and start_time.tzinfo is None:
start_time = start_time.replace(tzinfo=timezone.utc)
if end_time.tzinfo is None:
end_time = end_time.replace(tzinfo=timezone.utc)

stmt = (
select(Proposition)
.where(Proposition.created_at <= end_time)
.order_by(Proposition.created_at.desc())
.limit(limit)
)
if start_time is not None:
stmt = stmt.where(Proposition.created_at >= start_time)
if include_observations:
stmt = stmt.options(selectinload(Proposition.observations))

result = await session.execute(stmt)
return result.scalars().all()


async def get_recent_observations(
session: AsyncSession,
*,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> List[Observation]:
"""Fetch the most recent observations ordered by created_at desc."""
if end_time is None:
end_time = datetime.now(timezone.utc)
if start_time is not None and start_time.tzinfo is None:
start_time = start_time.replace(tzinfo=timezone.utc)
if end_time.tzinfo is None:
end_time = end_time.replace(tzinfo=timezone.utc)

stmt = (
select(Observation)
.where(Observation.created_at <= end_time)
.order_by(Observation.created_at.desc())
.limit(limit)
)
if start_time is not None:
stmt = stmt.where(Observation.created_at >= start_time)

result = await session.execute(stmt)
return result.scalars().all()
36 changes: 36 additions & 0 deletions gum/gum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from .db_utils import (
get_related_observations,
search_propositions_bm25,
get_recent_propositions,
get_recent_observations,
)
from .models import Observation, Proposition, init_db
from .observers import Observer
Expand Down Expand Up @@ -650,3 +652,37 @@ async def query(
start_time=start_time,
end_time=end_time,
)

async def recent(
self,
*,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
include_observations: bool = False,
) -> list[Proposition]:
"""Return the most recent propositions ordered by created_at descending."""
async with self._session() as session:
return await get_recent_propositions(
session,
limit=limit,
start_time=start_time,
end_time=end_time,
include_observations=include_observations,
)

async def recent_observations(
self,
*,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> list[Observation]:
"""Return the most recent observations ordered by created_at descending."""
async with self._session() as session:
return await get_recent_observations(
session,
limit=limit,
start_time=start_time,
end_time=end_time,
)