diff --git a/.gitignore b/.gitignore index c15e9c2..048d87d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,24 +1,26 @@ -# Environment +# Environment and local data .env -data -examples +.env.* +.python-version skypilot.yaml +data/ +examples/ test_gum.py -# Python -__pycache__ +# Python bytecode and caches +__pycache__/ *.py[cod] *$py.class *.so -.Python +pip-wheel-metadata/ +__pypackages__/ + +# Build and distribution artifacts build/ -develop-eggs/ dist/ downloads/ eggs/ .eggs/ -lib/ -lib64/ parts/ sdist/ var/ @@ -27,90 +29,68 @@ wheels/ .installed.cfg *.egg -# Virtual Environment +# Virtual environments +.venv/ venv/ +venv.bak/ env/ -ENV/ -.env/ -.venv/ env.bak/ -venv.bak/ +ENV/ -# IDE +# IDE and OS noise .idea/ .vscode/ *.swp *.swo .DS_Store +Thumbs.db -# Testing +# Testing and coverage .coverage -htmlcov/ -.pytest_cache/ -.tox/ -.nox/ .coverage.* .cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ - -# Jupyter Notebook -.ipynb_checkpoints - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# Unit test / coverage reports htmlcov/ -.tox/ .nox/ -.coverage -.coverage.* -.cache +.pytest_cache/ +.tox/ nosetests.xml coverage.xml +pytest-debug.log *.cover *.py,cover .hypothesis/ -.pytest_cache/ -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# mypy +# Type checker caches .mypy_cache/ .dmypy.json dmypy.json - -# Pyre type checker .pyre/ - -# pytype static type analyzer .pytype/ +# Jupyter notebooks +.ipynb_checkpoints/ + # Cython debug symbols -cython_debug/ \ No newline at end of file +cython_debug/ + +# Recorder artifacts +records/ +**/records/ +screenshots/ +**/screenshots/ +workflow.json +workflow.txt +*.db +*.db-* +*.sqlite +*.sqlite3 +*.log +logs/ +.tmp/ +tmp/ + +# Bundled apps and installers +*.app +*.dmg +*.whl + diff --git a/README.md b/README.md index 30fd0e6..6ed8f51 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,50 @@ General User Models learn about you by observing any interaction you have with y ## Documentation -**Please go here for documentation on setting up and using GUMs: [https://generalusermodels.github.io/gum/](https://generalusermodels.github.io/gum/)** +**Full setup and usage docs live here: [https://generalusermodels.github.io/gum/](https://generalusermodels.github.io/gum/)** + +## Record and Induce Human Workflows + +This repository also contains a macOS recorder (`record/`) and induction utilities (`induce/`) for capturing and processing human-computer interaction traces. + +### Record Human Computer-Use Activities + +Install the recording tool: + +```bash +cd record +pip install -e . +``` + +Follow the [instructions](record/instructions.pdf) to configure the required system settings. + +Run the recorder CLI directly from the repo (no install required): + +```bash +python -m record.gum +``` + +If you've installed the package with `pip install -e .`, you can instead invoke the console script: + +```bash +gum +``` + +Both commands launch the macOS recorder and begin logging activities. 
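+
+If you want to script a capture session instead of leaving the CLI running in a terminal, you can drive the same entry point from Python. The sketch below is not part of the package; it only assumes the `python -m record.gum` command shown above (run from the repository root) and that the recorder shuts down cleanly on Ctrl-C/SIGINT, as it would in a terminal session:
+
+```python
+import signal
+import subprocess
+import sys
+import time
+
+# Launch the recorder exactly as `python -m record.gum` would (run from the repo root).
+recorder = subprocess.Popen([sys.executable, "-m", "record.gum"])
+
+try:
+    time.sleep(60)  # capture for about a minute; adjust to your session length
+finally:
+    # Stop the recorder the same way Ctrl-C would (assumes SIGINT shuts it down cleanly).
+    recorder.send_signal(signal.SIGINT)
+    try:
+        recorder.wait(timeout=30)
+    except subprocess.TimeoutExpired:
+        recorder.terminate()
+```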
+ +### Induce Human Workflows + +Install dependencies and run the induction pipeline against a directory that contains recorded sessions (`{data_dir}` defaults to `~/Downloads/records`): + +```bash +cd ../induce +pip install -r requirements.txt +python get_human_trajectory.py --data_dir {data_dir} +python segment.py --data_dir {data_dir} +python induce.py --data_dir {data_dir} --auto +``` + +`get_human_trajectory.py` merges duplicate actions, `segment.py` detects state transitions, and `induce.py` performs semantic-based segment merging. The resulting workflow artifacts are saved to `{data_dir}/workflow.json` and `{data_dir}/workflow.txt`. ## Contributing diff --git a/gum/batcher.py b/gum/batcher.py deleted file mode 100644 index 6a941e6..0000000 --- a/gum/batcher.py +++ /dev/null @@ -1,99 +0,0 @@ -import asyncio -import logging -from datetime import datetime, timezone -from typing import List, Optional, Dict, Any -from pathlib import Path -import uuid -from persistqueue import Queue - -class ObservationBatcher: - """A persistent queue for batching observations to reduce API calls.""" - - def __init__(self, data_directory: str, min_batch_size: int = 5, max_batch_size: int = 50): - self.data_directory = Path(data_directory) - self.min_batch_size = min_batch_size - self.max_batch_size = max_batch_size - - # Create persistent queue backed by SQLite - queue_dir = self.data_directory / "batches" - queue_dir.mkdir(parents=True, exist_ok=True) - self._queue = Queue(path=str(queue_dir / "queue")) - - self._batch_ready_event = asyncio.Event() - self.logger = logging.getLogger("gum.batcher") - - async def start(self): - """Start the batching system.""" - self.logger.info(f"Started batcher with {self._queue.qsize()} items in queue") - - if self.should_process_batch(): - self._batch_ready_event.set() - - async def stop(self): - """Stop the batching system.""" - self.logger.info("Stopped batcher") - - def push(self, observer_name: str, content: str, content_type: str) -> str: - """Push an observation onto the queue. - - Args: - observer_name: Name of the observer - content: Observation content - content_type: Type of content - - Returns: - str: Observation ID - """ - observation_id = str(uuid.uuid4()) - observation_dict = { - 'id': observation_id, - 'observer_name': observer_name, - 'content': content, - 'content_type': content_type, - 'timestamp': datetime.now(timezone.utc).isoformat() - } - - # Add to queue - automatically persisted by persist-queue - self._queue.put(observation_dict) - self.logger.debug(f"Pushed observation {observation_id} to queue (size: {self._queue.qsize()})") - - # Signal that a batch is ready if we've reached minimum size - if self.should_process_batch(): - self._batch_ready_event.set() - - return observation_id - - def size(self) -> int: - """Get the current size of the queue.""" - return self._queue.qsize() - - def should_process_batch(self) -> bool: - """Check if the batch should be processed based on minimum batch size.""" - return self._queue.qsize() >= self.min_batch_size - - def pop_batch(self, batch_size: Optional[int] = None) -> List[Dict[str, Any]]: - """Pop a batch of observations from the front of the queue (FIFO). - - Args: - batch_size: Number of items to pop. 
Defaults to max_batch_size - - Returns: - List of observation dictionaries popped from queue - """ - batch_size = batch_size or self.max_batch_size - - batch = [] - for _ in range(min(batch_size, self._queue.qsize())): - batch.append(self._queue.get_nowait()) - - if batch: - self.logger.debug(f"Popped batch of {len(batch)} observations (queue size: {self._queue.qsize()})") - - if not self.should_process_batch(): - self._batch_ready_event.clear() - - return batch - - async def wait_for_batch_ready(self): - """Wait for a batch to be ready for processing.""" - await self._batch_ready_event.wait() \ No newline at end of file diff --git a/gum/cli.py b/gum/cli.py deleted file mode 100644 index 2dbc2f9..0000000 --- a/gum/cli.py +++ /dev/null @@ -1,100 +0,0 @@ -from dotenv import load_dotenv, find_dotenv -load_dotenv(find_dotenv(usecwd=True)) - -import os -import argparse -import asyncio -import shutil -from gum import gum -from gum.observers import Screen - -class QueryAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - if values is None: - setattr(namespace, self.dest, '') - else: - setattr(namespace, self.dest, values) - -def parse_args(): - parser = argparse.ArgumentParser(description='GUM - A Python package with command-line interface') - parser.add_argument('--user-name', '-u', type=str, help='The user name to use') - - parser.add_argument( - '--query', '-q', - nargs='?', - action=QueryAction, - help='Query the GUM with an optional query string', - ) - - parser.add_argument('--limit', '-l', type=int, help='Limit the number of results', default=10) - parser.add_argument('--model', '-m', type=str, help='Model to use') - parser.add_argument('--reset-cache', action='store_true', help='Reset the GUM cache and exit') # Add this line - - # Batching configuration arguments - parser.add_argument('--min-batch-size', type=int, help='Minimum number of observations to trigger batch processing') - parser.add_argument('--max-batch-size', type=int, help='Maximum number of observations per batch') - - args = parser.parse_args() - - if not hasattr(args, 'query'): - args.query = None - - return args - -async def main(): - args = parse_args() - - # Handle --reset-cache before anything else - if getattr(args, 'reset_cache', False): - cache_dir = os.path.expanduser('~/.cache/gum/') - if os.path.exists(cache_dir): - shutil.rmtree(cache_dir) - print(f"Deleted cache directory: {cache_dir}") - else: - print(f"Cache directory does not exist: {cache_dir}") - return - - model = args.model or os.getenv('MODEL_NAME') or 'gpt-4o-mini' - user_name = args.user_name or os.getenv('USER_NAME') - - # Batching configuration - follow same pattern as other args - min_batch_size = args.min_batch_size or int(os.getenv('MIN_BATCH_SIZE', '5')) - max_batch_size = args.max_batch_size or int(os.getenv('MAX_BATCH_SIZE', '15')) - - # you need one or the other - if user_name is None and args.query is None: - print("Please provide a user name (as an argument, -u, or as an env variable) or a query (as an argument, -q)") - return - - if args.query is not None: - gum_instance = gum(user_name, model) - await gum_instance.connect_db() - result = await gum_instance.query(args.query, limit=args.limit) - - # confidences / propositions / number of items returned - print(f"\nFound {len(result)} results:") - for prop, score in result: - print(f"\nProposition: {prop.text}") - if prop.reasoning: - print(f"Reasoning: {prop.reasoning}") - if prop.confidence is not None: - print(f"Confidence: 
{prop.confidence:.2f}") - print(f"Relevance Score: {score:.2f}") - print("-" * 80) - else: - print(f"Listening to {user_name} with model {model}") - - async with gum( - user_name, - model, - Screen(model), - min_batch_size=min_batch_size, - max_batch_size=max_batch_size - ) as gum_instance: - await asyncio.Future() # run forever (Ctrl-C to stop) - -def cli(): - asyncio.run(main()) - -if __name__ == '__main__': - cli() \ No newline at end of file diff --git a/gum/db_utils.py b/gum/db_utils.py deleted file mode 100644 index 7d43a02..0000000 --- a/gum/db_utils.py +++ /dev/null @@ -1,247 +0,0 @@ -# db_utils.py - -from __future__ import annotations - -import math -import re -from datetime import datetime, timezone -from typing import List - -import numpy as np -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity - -from sqlalchemy import ( - MetaData, - Table, - select, - literal_column, - text, - func, -) - -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from .models import ( - Observation, - Proposition, - observation_proposition, -) - -# Constants -K_DECAY = 2.0 # decay rate for recency adjustment -LAMBDA = 0.5 # trade-off for MMR - -def build_fts_query(raw: str, mode: str = "OR") -> str: - tokens = re.findall(r"\w+", raw.lower()) - if not tokens: - return "" - if mode == "PHRASE": - return f'"{" ".join(tokens)}"' - elif mode == "OR": - return " OR ".join(tokens) - else: # implicit AND - return " ".join(tokens) - - - - -async def search_propositions_bm25( - session: AsyncSession, - user_query: str, - *, - limit: int = 3, - mode: str = "OR", - start_time: datetime | None = None, - end_time: datetime | None = None, - include_observations: bool = True, - enable_decay: bool = True, - enable_mmr: bool = True, -) -> list[tuple["Proposition", float]]: - - q = build_fts_query(user_query, mode) - has_query = bool(q) - - # -------------------------------------------------------- - # 1 Build candidate list - # -------------------------------------------------------- - candidate_pool = limit * 10 if enable_mmr else limit - - if has_query: - fts_prop = Table("propositions_fts", MetaData()) - - if include_observations: - # --- 1-a-1 WITH observations -------------------- - fts_obs = Table("observations_fts", MetaData()) - - bm25_p = literal_column("bm25(propositions_fts)").label("score") - bm25_o = literal_column("bm25(observations_fts)").label("score") - - sub_p = ( - select(Proposition.id.label("pid"), bm25_p) - .select_from( - fts_prop.join( - Proposition, - literal_column("propositions_fts.rowid") == Proposition.id, - ) - ) - .where(text("propositions_fts MATCH :q")) - ) - - sub_o = ( - select(observation_proposition.c.proposition_id.label("pid"), bm25_o) - .select_from( - fts_obs - .join( - Observation, - literal_column("observations_fts.rowid") == Observation.id, - ) - .join( - observation_proposition, - observation_proposition.c.observation_id == Observation.id, - ) - ) - .where(text("observations_fts MATCH :q")) - ) - - union_sub = sub_p.union_all(sub_o).subquery() - - best_scores = ( - select( - union_sub.c.pid, - func.min(union_sub.c.score).label("bm25"), - ) - .group_by(union_sub.c.pid) - .subquery() - ) - else: - # --- 1-a-2 WITHOUT observations ----------------- - best_scores = ( - select( - Proposition.id.label("pid"), - literal_column("bm25(propositions_fts)").label("bm25"), - ) - .select_from( - fts_prop.join( - Proposition, - literal_column("propositions_fts.rowid") == 
Proposition.id, - ) - ) - .where(text("propositions_fts MATCH :q")) - .subquery() - ) - - stmt = ( - select(Proposition, best_scores.c.bm25) - .join(best_scores, best_scores.c.pid == Proposition.id) - .order_by(best_scores.c.bm25.asc()) # smallest→best - ) - else: - # --- 1-b No user query ------------------------------ - stmt = ( - select(Proposition, literal_column("0.0").label("bm25")) - .order_by(Proposition.created_at.desc()) - ) - - # -------------------------------------------------------- - # 2 Time filtering & eager-load - # -------------------------------------------------------- - if end_time is None: - end_time = datetime.now(timezone.utc) - if start_time is not None and start_time.tzinfo is None: - start_time = start_time.replace(tzinfo=timezone.utc) - if end_time.tzinfo is None: - end_time = end_time.replace(tzinfo=timezone.utc) - - if start_time is not None: - stmt = stmt.where(Proposition.created_at >= start_time) - stmt = stmt.where(Proposition.created_at <= end_time) - - if include_observations: - stmt = stmt.options(selectinload(Proposition.observations)) - - stmt = stmt.limit(candidate_pool) - - # -------------------------------------------------------- - # 3 Execute & score - # -------------------------------------------------------- - bind = {"q": q} if has_query else {} - rows = (await session.execute(stmt, bind)).all() - if not rows: - return [] - - # --- 3-a. Calculate initial scores --- - initial_scores: list[float] = [] - now = datetime.now(timezone.utc) - for prop, raw_score in rows: - relevance_score = -raw_score if has_query else 0.0 - gamma = 0.0 - if enable_decay: - dt = prop.created_at.replace(tzinfo=timezone.utc) - age_days = max((now - dt).total_seconds() / 86_400, 0.0) - alpha = prop.decay if prop.decay is not None else 0.0 - gamma = -alpha * K_DECAY * age_days - - score = relevance_score * math.exp(gamma) - initial_scores.append(score) - - final_scores_np = np.array(initial_scores) - min_score = np.min(final_scores_np) - max_score = np.max(final_scores_np) - - if max_score > min_score: - final_scores_np = (final_scores_np - min_score) / (max_score - min_score) - else: - final_scores_np = np.full_like(final_scores_np, 0.5) - - final_scores = final_scores_np.tolist() - - if enable_mmr and len(rows) > 1: - docs: list[str] = [] - for p, _ in rows: - doc_parts = [p.text, p.reasoning] - if include_observations and p.observations: - obs_concat = " ".join(o.content for o in list(p.observations)[:10]) - doc_parts.append(obs_concat) - docs.append(" ".join(doc_parts)) - - vecs = TfidfVectorizer().fit_transform(docs) - - selected_idxs = [] - mmr_scores = np.array(final_scores) - - while len(selected_idxs) < min(limit, len(rows)): - if not selected_idxs: - idx = int(np.argmax(mmr_scores)) - else: - sims = cosine_similarity(vecs, vecs[selected_idxs]).max(axis=1) - mmr = LAMBDA * mmr_scores - (1 - LAMBDA) * sims - mmr[selected_idxs] = -np.inf - idx = int(np.argmax(mmr)) - - selected_idxs.append(idx) - else: - idxs = np.argsort(final_scores)[::-1][:limit] - selected_idxs = idxs.tolist() - - result = [(rows[i][0], final_scores[i]) for i in selected_idxs] - return result - -async def get_related_observations( - session: AsyncSession, - proposition_id: int, - *, # Force keyword arguments for optional parameters - limit: int = 5, -) -> List[Observation]: - - stmt = ( - select(Observation) - .join(observation_proposition) - .join(Proposition) - .where(Proposition.id == proposition_id) - .order_by(Observation.created_at.desc()) - .limit(limit) - ) - result = await 
session.execute(stmt) - return result.scalars().all() \ No newline at end of file diff --git a/gum/gum.py b/gum/gum.py deleted file mode 100644 index b4ef53a..0000000 --- a/gum/gum.py +++ /dev/null @@ -1,651 +0,0 @@ -# gum.py - -from __future__ import annotations - -import asyncio -import json -import logging -import os -from uuid import uuid4 -from contextlib import asynccontextmanager -from datetime import datetime, timezone -from typing import Callable, List -from .models import observation_proposition -import traceback - -from openai import AsyncOpenAI -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import insert - -from .db_utils import ( - get_related_observations, - search_propositions_bm25, -) -from .models import Observation, Proposition, init_db -from .observers import Observer -from .schemas import ( - PropositionItem, - PropositionSchema, - RelationSchema, - Update, - get_schema, - AuditSchema -) -from gum.prompts.gum import AUDIT_PROMPT, PROPOSE_PROMPT, REVISE_PROMPT, SIMILAR_PROMPT -from .batcher import ObservationBatcher - -class gum: - """A class for managing general user models. - - This class provides functionality for observing user behavior, generating and managing - propositions about user behavior, and maintaining relationships between observations - and propositions. - - Args: - user_name (str): The name of the user being modeled. - *observers (Observer): Variable number of observer instances to track user behavior. - propose_prompt (str, optional): Custom prompt for proposition generation. - similar_prompt (str, optional): Custom prompt for similarity analysis. - revise_prompt (str, optional): Custom prompt for proposition revision. - audit_prompt (str, optional): Custom prompt for auditing. - data_directory (str, optional): Directory for storing data. Defaults to "~/.cache/gum". - db_name (str, optional): Name of the database file. Defaults to "gum.db". - - verbosity (int, optional): Logging verbosity level. Defaults to logging.INFO. - audit_enabled (bool, optional): Whether to enable auditing. Defaults to False. 
- """ - - def __init__( - self, - user_name: str, - model: str, - *observers: Observer, - propose_prompt: str | None = None, - similar_prompt: str | None = None, - revise_prompt: str | None = None, - audit_prompt: str | None = None, - data_directory: str = "~/.cache/gum", - db_name: str = "gum.db", - verbosity: int = logging.INFO, - audit_enabled: bool = False, - api_base: str | None = None, - api_key: str | None = None, - min_batch_size: int = 5, - max_batch_size: int = 50, - ): - # basic paths - data_directory = os.path.expanduser(data_directory) - os.makedirs(data_directory, exist_ok=True) - - # runtime - self.user_name = user_name - self.observers: list[Observer] = list(observers) - self.model = model - self.audit_enabled = audit_enabled - - # batching configuration - self.min_batch_size = min_batch_size - self.max_batch_size = max_batch_size - - # logging - self.logger = logging.getLogger("gum") - self.logger.setLevel(verbosity) - if not self.logger.handlers: - h = logging.StreamHandler() - h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) - self.logger.addHandler(h) - - # prompts - self.propose_prompt = propose_prompt or PROPOSE_PROMPT - self.similar_prompt = similar_prompt or SIMILAR_PROMPT - self.revise_prompt = revise_prompt or REVISE_PROMPT - self.audit_prompt = audit_prompt or AUDIT_PROMPT - - self.client = AsyncOpenAI( - base_url=api_base or os.getenv("GUM_LM_API_BASE"), - api_key=api_key or os.getenv("GUM_LM_API_KEY") or os.getenv("OPENAI_API_KEY") or "None" - ) - - self.engine = None - self.Session = None - self._db_name = db_name - self._data_directory = data_directory - - # Initialize batcher if enabled - self.batcher = ObservationBatcher( - data_directory=data_directory, - min_batch_size=min_batch_size, - max_batch_size=max_batch_size - ) - - self._loop_task: asyncio.Task | None = None - self._batch_task: asyncio.Task | None = None - self._batch_processing_lock = asyncio.Lock() - self.update_handlers: list[Callable[[Observer, Update], None]] = [] - - def start_update_loop(self): - """Start the asynchronous update loop for processing observer updates.""" - if self._loop_task is None: - self._loop_task = asyncio.create_task(self._update_loop()) - - # Start batch processing if enabled - if self._batch_task is None: - self._batch_task = asyncio.create_task(self._batch_processing_loop()) - - async def stop_update_loop(self): - """Stop the asynchronous update loop and clean up resources.""" - if self._loop_task: - self._loop_task.cancel() - try: - await self._loop_task - except asyncio.CancelledError: - pass - self._loop_task = None - - # Stop batch processing if enabled - if self._batch_task: - self._batch_task.cancel() - try: - await self._batch_task - except asyncio.CancelledError: - pass - self._batch_task = None - - if self.batcher: - await self.batcher.stop() - - async def connect_db(self): - """Initialize the database connection if not already connected.""" - if self.engine is None: - self.engine, self.Session = await init_db( - self._db_name, self._data_directory - ) - - async def __aenter__(self): - """Async context manager entry point. - - Returns: - gum: The instance of the gum class. - """ - await self.connect_db() - self.start_update_loop() - - # Start batcher if enabled - if self.batcher: - await self.batcher.start() - - return self - - async def __aexit__(self, exc_type, exc, tb): - """Async context manager exit point. - - Args: - exc_type: The type of exception if any. - exc: The exception instance if any. - tb: The traceback if any. 
- """ - await self.stop_update_loop() - - # stop observers - for obs in self.observers: - await obs.stop() - - async def _update_loop(self): - """Efficiently wait for any observer to produce an Update and dispatch it. - - This method continuously monitors all observers for updates and processes them - through the semaphore-guarded handler. - """ - while True: - gets = { - asyncio.create_task(obs.update_queue.get()): obs - for obs in self.observers - } - - done, _ = await asyncio.wait( - gets.keys(), return_when=asyncio.FIRST_COMPLETED - ) - - for fut in done: - upd: Update = fut.result() - obs = gets[fut] - - asyncio.create_task(self._default_handler(obs, upd)) - - async def _batch_processing_loop(self): - """Process batched observations when minimum batch size is reached.""" - while True: - # Wait for batch to be ready (event-driven, no polling!) - await self.batcher.wait_for_batch_ready() - - # Use lock to ensure batch processing runs synchronously - async with self._batch_processing_lock: - batch = self.batcher.pop_batch() - self.logger.info(f"Processing batch of {len(batch)} observations") - await self._process_batch(batch) - - async def _process_batch(self, batched_observations): - """Process a batch of observations together to reduce API calls.""" - - # Combine all observations into a single content for analysis - combined_content = [] - observation_ids = [] - - for obs in batched_observations: - combined_content.append(f"[{obs['observer_name']}] {obs['content']}") - observation_ids.append(obs['id']) - - combined_text = "\n\n".join(combined_content) - - # Create a combined update - combined_update = Update( - content=combined_text, - content_type="input_text" - ) - - try: - async with self._session() as session: - # Create observations in database - observations = [] - for obs in batched_observations: - observation = Observation( - observer_name=obs['observer_name'], - content=obs['content'], - content_type=obs['content_type'], - ) - session.add(observation) - observations.append(observation) - - await session.flush() - - # Process the combined content - pool = await self._generate_and_search(session, combined_update) - identical, similar, different = await self._filter_propositions(pool) - - self.logger.info("Applying proposition updates for batch...") - await self._handle_identical(session, identical, observations) - await self._handle_similar(session, similar, observations) - await self._handle_different(session, different, observations) - - # Observations are already removed from queue by pop_batch() - self.logger.info(f"Completed processing batch of {len(batched_observations)} observations") - - except Exception as e: - self.logger.error(f"Error processing batch: {e}") - self.logger.error(f"Traceback: {traceback.format_exc()}") - self.logger.error(f"Batch size: {len(batched_observations)}") - if batched_observations: - self.logger.error(f"First observation type: {type(batched_observations[0])}") - self.logger.error(f"First observation: {batched_observations[0]}") - # Put failed items back in queue for retry - for obs in batched_observations: - self.batcher.push(obs['observer_name'], obs['content'], obs['content_type']) - - async def _construct_propositions(self, update: Update) -> list[PropositionItem]: - """Generate propositions from an update. - - Args: - update (Update): The update to generate propositions from. - - Returns: - list[PropositionItem]: List of generated propositions. 
- """ - prompt = ( - self.propose_prompt.replace("{user_name}", self.user_name) - .replace("{inputs}", update.content) - ) - - schema = PropositionSchema.model_json_schema() - rsp = await self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - response_format=get_schema(schema), - ) - - return json.loads(rsp.choices[0].message.content)["propositions"] - - async def _build_relation_prompt(self, all_props) -> str: - """Build a prompt for analyzing relationships between propositions. - - Args: - all_props: List of propositions to analyze. - - Returns: - str: The formatted prompt for relationship analysis. - """ - blocks = [ - f"[id={p['id']}] {p['proposition']}\n Reasoning: {p['reasoning']}" - for p in all_props - ] - body = "\n\n".join(blocks) - return self.similar_prompt.replace("{body}", body) - - async def _filter_propositions( - self, rel_props: list[Proposition] - ) -> tuple[list[Proposition], list[Proposition], list[Proposition]]: - """Filter propositions into identical, similar, and unrelated groups. - - Args: - rel_props (list[Proposition]): List of propositions to filter. - - Returns: - tuple[list[Proposition], list[Proposition], list[Proposition]]: Three lists containing - identical, similar, and unrelated propositions respectively. - """ - if not rel_props: - return [], [], [] - - payload = [ - {"id": p.id, "proposition": p.text, "reasoning": p.reasoning or ""} - for p in rel_props - ] - prompt_text = await self._build_relation_prompt(payload) - - rsp = await self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt_text}], - response_format=get_schema(RelationSchema.model_json_schema()), - ) - - data = RelationSchema.model_validate_json(rsp.choices[0].message.content) - - id_to_prop = {p.id: p for p in rel_props} - ident, sim, unrel = set(), set(), set() - - for r in data.relations: - if r.label == "IDENTICAL": - ident.add(r.source) - ident.update(r.target or []) - elif r.label == "SIMILAR": - sim.add(r.source) - sim.update(r.target or []) - else: - unrel.add(r.source) - - # only keep IDs we actually know about - valid_ids = set(id_to_prop.keys()) - ident &= valid_ids - sim &= valid_ids - unrel &= valid_ids - - return ( - [id_to_prop[i] for i in ident], - [id_to_prop[i] for i in sim - ident], - [id_to_prop[i] for i in unrel - ident - sim], - ) - - async def _build_revision_body( - self, similar: List[Proposition], related_obs: List[Observation] - ) -> str: - """Build the body text for proposition revision. - - Args: - similar (List[Proposition]): List of similar propositions. - related_obs (List[Observation]): List of related observations. - - Returns: - str: The formatted body text for revision. - """ - blocks = [ - f"Proposition {idx}: {p.text}\nReasoning: {p.reasoning}" - for idx, p in enumerate(similar, 1) - ] - if related_obs: - blocks.append("\nSupporting observations:") - blocks.extend(f"- {o.content}" for o in related_obs[:10]) - return "\n".join(blocks) - - async def _revise_propositions( - self, - related_obs: list[Observation], - similar_cluster: list[Proposition], - ) -> list[dict]: - """Revise propositions based on related observations and similar propositions. - - Args: - related_obs (list[Observation]): List of related observations. - similar_cluster (list[Proposition]): List of similar propositions. - - Returns: - list[dict]: List of revised propositions. 
- """ - body = await self._build_revision_body(similar_cluster, related_obs) - prompt = self.revise_prompt.replace("{body}", body) - rsp = await self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - response_format=get_schema(PropositionSchema.model_json_schema()), - ) - return json.loads(rsp.choices[0].message.content)["propositions"] - - async def _generate_and_search( - self, session: AsyncSession, update: Update - ) -> list[Proposition]: - - drafts_raw = await self._construct_propositions(update) - drafts: list[Proposition] = [] - pool: dict[int, Proposition] = {} - - for itm in drafts_raw: - draft = Proposition( - text=itm["proposition"], - reasoning=itm["reasoning"], - confidence=itm.get("confidence"), - decay=itm.get("decay"), - revision_group=str(uuid4()), - version=1, - ) - drafts.append(draft) - - # search existing persisted props - with session.no_autoflush: - hits = await search_propositions_bm25( - session, f"{draft.text}\n{draft.reasoning}", mode="OR", - include_observations=False, - enable_mmr=False, - enable_decay=True - ) - - for prop, _score in hits: - pool[prop.id] = prop - - session.add_all(drafts) - await session.flush() - - for draft in drafts: - pool[draft.id] = draft - - return list(pool.values()) - - async def _handle_identical( - self, session, identical: list[Proposition], observations: list[Observation] - ) -> None: - for p in identical: - for obs in observations: - await self._attach_obs_if_missing(p, obs, session) - - async def _handle_similar( - self, - session: AsyncSession, - similar: list[Proposition], - observations: list[Observation], - ) -> None: - - if not similar: - return - - # Collect all observations from similar propositions - rel_obs = { - o - for p in similar - for o in await get_related_observations(session, p.id) - } - # Add all the batched observations - rel_obs.update(observations) - - # Generate revised propositions - revised_items = await self._revise_propositions(list(rel_obs), similar) - - # Delete all old similar propositions - for prop in similar: - await session.delete(prop) - - # Create new propositions to replace them - revision_group = str(uuid4()) - for item in revised_items: - new_prop = Proposition( - text=item["proposition"], - reasoning=item["reasoning"], - confidence=item.get("confidence"), - decay=item.get("decay"), - version=1, # Start fresh with version 1 - revision_group=revision_group, - observations=rel_obs, - ) - session.add(new_prop) - - await session.flush() - - async def _handle_different( - self, session, different: list[Proposition], observations: list[Observation] - ) -> None: - for p in different: - for obs in observations: - await self._attach_obs_if_missing(p, obs, session) - - async def _handle_audit(self, obs: Observation) -> bool: - if not self.audit_enabled: - return False - - hits = await self.query(obs.content, limit=10, mode="OR") - - if not hits: - past_interaction = "*None*" - else: - ctx_chunks: list[str] = [] - async with self._session() as session: - for prop, score in hits: - chunk = [f"• {prop.text}"] - if prop.reasoning: - chunk.append(f" Reasoning: {prop.reasoning}") - if prop.confidence is not None: - chunk.append(f" Confidence: {prop.confidence}") - chunk.append(f" Relevance Score: {score:.2f}") - - obs_list = await get_related_observations(session, prop.id) - if obs_list: - chunk.append(" Supporting Observations:") - for rel_obs in obs_list: - preview = rel_obs.content.replace("\n", " ")[:120] - chunk.append(f" - 
[{rel_obs.observer_name}] {preview}") - - ctx_chunks.append("\n".join(chunk)) - - past_interaction = "\n\n".join(ctx_chunks) - - prompt = ( - self.audit_prompt - .replace("{past_interaction}", past_interaction) - .replace("{user_input}", obs.content) - .replace("{user_name}", self.user_name) - ) - - rsp = await self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - response_format=get_schema(AuditSchema.model_json_schema()), - temperature=0.0, - ) - decision = json.loads(rsp.choices[0].message.content) - - if not decision["transmit_data"]: - self.logger.warning( - "Audit blocked transmission (data_type=%s, subject=%s)", - decision["data_type"], - decision["subject"], - ) - return True - - return False - - async def _default_handler(self, observer: Observer, update: Update) -> None: - self.logger.info(f"Processing update from {observer.name}") - - # add to batch - observation_id = self.batcher.push( - observer_name=observer.name, - content=update.content, - content_type=update.content_type - ) - self.logger.info(f"Added observation {observation_id} to queue (size: {self.batcher.size()})") - - @asynccontextmanager - async def _session(self): - async with self.Session() as s: - async with s.begin(): - yield s - - @staticmethod - async def _attach_obs_if_missing(prop: Proposition, obs: Observation, session): - await session.execute( - insert(observation_proposition) - .prefix_with("OR IGNORE") - .values(observation_id=obs.id, proposition_id=prop.id) - ) - prop.updated_at = datetime.now(timezone.utc) - - def add_observer(self, observer: Observer): - """Add an observer to track user behavior. - - Args: - observer (Observer): The observer to add. - """ - self.observers.append(observer) - - def remove_observer(self, observer: Observer): - """Remove an observer from tracking. - - Args: - observer (Observer): The observer to remove. - """ - if observer in self.observers: - self.observers.remove(observer) - - def register_update_handler(self, fn: Callable[[Observer, Update], None]): - """Register a custom update handler function. - - Args: - fn (Callable[[Observer, Update], None]): The handler function to register. - """ - self.update_handlers.append(fn) - - async def query( - self, - user_query: str, - *, - limit: int = 3, - mode: str = "OR", - start_time: datetime | None = None, - end_time: datetime | None = None, - ) -> list[tuple[Proposition, float]]: - """Query the database for propositions matching the user query. - - Args: - user_query (str): The query string to search for. - limit (int, optional): Maximum number of results to return. Defaults to 3. - mode (str, optional): Search mode ("OR" or "AND"). Defaults to "OR". - start_time (datetime, optional): Start time for filtering results. Defaults to None. - end_time (datetime, optional): End time for filtering results. Defaults to None. - - Returns: - list[tuple[Proposition, float]]: List of tuples containing propositions and their relevance scores. - """ - async with self._session() as session: - return await search_propositions_bm25( - session, - user_query, - limit=limit, - mode=mode, - start_time=start_time, - end_time=end_time, - ) diff --git a/gum/observers/__init__.py b/gum/observers/__init__.py deleted file mode 100644 index cbb8926..0000000 --- a/gum/observers/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -Observer module for GUM - General User Models. - -This module provides observer classes for different types of user interactions. 
-""" - -from .observer import Observer -from .screen import Screen - -__all__ = ["Observer", "Screen"] \ No newline at end of file diff --git a/gum/observers/observer.py b/gum/observers/observer.py deleted file mode 100644 index dab8022..0000000 --- a/gum/observers/observer.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations -from abc import ABC, abstractmethod -from typing import Optional -import asyncio - -class Observer(ABC): - """Base class for all observers in the GUM system. - - This abstract base class defines the interface for all observers that monitor user behavior. - Observers are responsible for collecting data about user interactions and sending updates - through an asynchronous queue. - - Args: - name (Optional[str]): A custom name for the observer. If not provided, the class name will be used. - - Attributes: - update_queue (asyncio.Queue): Queue for sending updates to the main GUM system. - _name (str): The name of the observer. - _running (bool): Flag indicating if the observer is currently running. - _task (Optional[asyncio.Task]): Background task handle for the observer's worker. - """ - - def __init__(self, name: Optional[str] = None) -> None: - self.update_queue = asyncio.Queue() - self._name = name or self.__class__.__name__ - - # running flag + background task handle - self._running = True - self._task: asyncio.Task | None = asyncio.create_task(self._worker_wrapper()) - - # ─────────────────────────────── abstract worker - @abstractmethod - async def _worker(self) -> None: # subclasses override - """Main worker method that must be implemented by subclasses. - - This method should contain the main logic for the observer, such as monitoring - user interactions or collecting data. It runs in a background task and should - continue running until the observer is stopped. - """ - pass - - # wrapper plugs running flag + exception handling - async def _worker_wrapper(self) -> None: - """Wrapper for the worker method that handles exceptions and cleanup. - - This method ensures proper cleanup of resources when the worker stops, - whether due to normal termination or an exception. - """ - try: - await self._worker() - except asyncio.CancelledError: - pass - except Exception as exc: - raise - finally: - self._running = False - - # ─────────────────────────────── public API - @property - def name(self) -> str: - """Get the name of the observer. - - Returns: - str: The observer's name. - """ - return self._name - - async def get_update(self): - """Get the next update from the queue if available. - - Returns: - Optional[Update]: The next update from the queue, or None if the queue is empty. - """ - try: - return self.update_queue.get_nowait() - except asyncio.QueueEmpty: - return None - - async def stop(self) -> None: - """Stop the observer and clean up resources. - - This method cancels the worker task and drains the update queue. 
- """ - if self._task and not self._task.done(): - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - # unblock any awaiters - while not self.update_queue.empty(): - self.update_queue.get_nowait() diff --git a/gum/observers/screen.py b/gum/observers/screen.py deleted file mode 100644 index 726a449..0000000 --- a/gum/observers/screen.py +++ /dev/null @@ -1,432 +0,0 @@ -from __future__ import annotations -############################################################################### -# Imports # -############################################################################### - -# — Standard library — -import base64 -import logging -import os -import time -from collections import deque -from typing import Any, Dict, Iterable, List, Optional - -import asyncio - -# — Third-party — -import mss -import Quartz -from PIL import Image -from pynput import mouse # still synchronous -from shapely.geometry import box -from shapely.ops import unary_union - -# — Local — -from .observer import Observer -from ..schemas import Update - -# — OpenAI async client — -from openai import AsyncOpenAI - -# — Local — -from gum.prompts.screen import TRANSCRIPTION_PROMPT, SUMMARY_PROMPT - -############################################################################### -# Window‑geometry helpers # -############################################################################### - - -def _get_global_bounds() -> tuple[float, float, float, float]: - """Return a bounding box enclosing **all** physical displays. - - Returns - ------- - (min_x, min_y, max_x, max_y) tuple in Quartz global coordinates. - """ - err, ids, cnt = Quartz.CGGetActiveDisplayList(16, None, None) - if err != Quartz.kCGErrorSuccess: # pragma: no cover (defensive) - raise OSError(f"CGGetActiveDisplayList failed: {err}") - - min_x = min_y = float("inf") - max_x = max_y = -float("inf") - for did in ids[:cnt]: - r = Quartz.CGDisplayBounds(did) - x0, y0 = r.origin.x, r.origin.y - x1, y1 = x0 + r.size.width, y0 + r.size.height - min_x, min_y = min(min_x, x0), min(min_y, y0) - max_x, max_y = max(max_x, x1), max(max_y, y1) - return min_x, min_y, max_x, max_y - - -def _get_visible_windows() -> List[tuple[dict, float]]: - """List *onscreen* windows with their visible‑area ratio. - - Each tuple is ``(window_info_dict, visible_ratio)`` where *visible_ratio* - is in ``[0.0, 1.0]``. Internal system windows (Dock, WindowServer, …) are - ignored. 
- """ - _, _, _, gmax_y = _get_global_bounds() - - opts = ( - Quartz.kCGWindowListOptionOnScreenOnly - | Quartz.kCGWindowListOptionIncludingWindow - ) - wins = Quartz.CGWindowListCopyWindowInfo(opts, Quartz.kCGNullWindowID) - - occupied = None # running union of opaque regions above the current window - result: list[tuple[dict, float]] = [] - - for info in wins: - owner = info.get("kCGWindowOwnerName", "") - if owner in ("Dock", "WindowServer", "Window Server"): - continue - - bounds = info.get("kCGWindowBounds", {}) - x, y, w, h = ( - bounds.get("X", 0), - bounds.get("Y", 0), - bounds.get("Width", 0), - bounds.get("Height", 0), - ) - if w <= 0 or h <= 0: - continue # hidden or minimised - - inv_y = gmax_y - y - h # Quartz→Shapely Y‑flip - poly = box(x, inv_y, x + w, inv_y + h) - if poly.is_empty: - continue - - visible = poly if occupied is None else poly.difference(occupied) - if not visible.is_empty: - ratio = visible.area / poly.area - result.append((info, ratio)) - occupied = poly if occupied is None else unary_union([occupied, poly]) - - return result - - -def _is_app_visible(names: Iterable[str]) -> bool: - """Return *True* if **any** window from *names* is at least partially visible.""" - targets = set(names) - return any( - info.get("kCGWindowOwnerName", "") in targets and ratio > 0 - for info, ratio in _get_visible_windows() - ) - -############################################################################### -# Screen observer # -############################################################################### - -class Screen(Observer): - """Observer that captures and analyzes screen content around user interactions. - - This observer captures screenshots before and after user interactions (mouse movements, - clicks, and scrolls) and uses GPT-4 Vision to analyze the content. It can also take - periodic screenshots and skip captures when certain applications are visible. - - Args: - screenshots_dir (str, optional): Directory to store screenshots. Defaults to "~/.cache/gum/screenshots". - skip_when_visible (Optional[str | list[str]], optional): Application names to skip when visible. - Defaults to None. - transcription_prompt (Optional[str], optional): Custom prompt for transcribing screenshots. - Defaults to None. - summary_prompt (Optional[str], optional): Custom prompt for summarizing screenshots. - Defaults to None. - model_name (str, optional): GPT model to use for vision analysis. Defaults to "gpt-4o-mini". - history_k (int, optional): Number of recent screenshots to keep in history. Defaults to 10. - debug (bool, optional): Enable debug logging. Defaults to False. - - Attributes: - _CAPTURE_FPS (int): Frames per second for screen capture. - _DEBOUNCE_SEC (int): Seconds to wait before processing an interaction. - _MON_START (int): Index of first real display in mss. - """ - - _CAPTURE_FPS: int = 10 - _DEBOUNCE_SEC: int = 2 - _MON_START: int = 1 # first real display in mss - - # ─────────────────────────────── construction - def __init__( - self, - model_name: str = "gpt-4o-mini", - screenshots_dir: str = "~/.cache/gum/screenshots", - skip_when_visible: Optional[str | list[str]] = None, - transcription_prompt: Optional[str] = None, - summary_prompt: Optional[str] = None, - history_k: int = 10, - debug: bool = False, - api_key: str | None = None, - api_base: str | None = None, - ) -> None: - """Initialize the Screen observer. - - Args: - screenshots_dir (str, optional): Directory to store screenshots. Defaults to "~/.cache/gum/screenshots". 
- skip_when_visible (Optional[str | list[str]], optional): Application names to skip when visible. - Defaults to None. - transcription_prompt (Optional[str], optional): Custom prompt for transcribing screenshots. - Defaults to None. - summary_prompt (Optional[str], optional): Custom prompt for summarizing screenshots. - Defaults to None. - model_name (str, optional): GPT model to use for vision analysis. Defaults to "gpt-4o-mini". - history_k (int, optional): Number of recent screenshots to keep in history. Defaults to 10. - debug (bool, optional): Enable debug logging. Defaults to False. - """ - self.screens_dir = os.path.abspath(os.path.expanduser(screenshots_dir)) - os.makedirs(self.screens_dir, exist_ok=True) - - self._guard = {skip_when_visible} if isinstance(skip_when_visible, str) else set(skip_when_visible or []) - - self.transcription_prompt = transcription_prompt or TRANSCRIPTION_PROMPT - self.summary_prompt = summary_prompt or SUMMARY_PROMPT - self.model_name = model_name - - self.debug = debug - - # state shared with worker - self._frames: Dict[int, Any] = {} - self._frame_lock = asyncio.Lock() - - self._history: deque[str] = deque(maxlen=max(0, history_k)) - self._pending_event: Optional[dict] = None - self._debounce_handle: Optional[asyncio.TimerHandle] = None - self.client = AsyncOpenAI( - # try the class, then the env for screen, then the env for gum - base_url=api_base or os.getenv("SCREEN_LM_API_BASE") or os.getenv("GUM_LM_API_BASE"), - - # try the class, then the env for screen, then the env for GUM, then none - api_key=api_key or os.getenv("SCREEN_LM_API_KEY") or os.getenv("GUM_LM_API_KEY") or os.getenv("OPENAI_API_KEY") or "None" - ) - - # call parent - super().__init__() - - # ─────────────────────────────── tiny sync helpers - @staticmethod - def _mon_for(x: float, y: float, mons: list[dict]) -> Optional[int]: - """Find which monitor contains the given coordinates. - - Args: - x (float): X coordinate. - y (float): Y coordinate. - mons (list[dict]): List of monitor information dictionaries. - - Returns: - Optional[int]: Monitor index if found, None otherwise. - """ - for idx, m in enumerate(mons, 1): - if m["left"] <= x < m["left"] + m["width"] and m["top"] <= y < m["top"] + m["height"]: - return idx - return None - - @staticmethod - def _encode_image(img_path: str) -> str: - """Encode an image file as base64. - - Args: - img_path (str): Path to the image file. - - Returns: - str: Base64 encoded image data. - """ - with open(img_path, "rb") as fh: - return base64.b64encode(fh.read()).decode() - - # ─────────────────────────────── OpenAI Vision (async) - async def _call_gpt_vision(self, prompt: str, img_paths: list[str]) -> str: - """Call GPT Vision API to analyze images. - - Args: - prompt (str): Prompt to guide the analysis. - img_paths (list[str]): List of image paths to analyze. - - Returns: - str: GPT's analysis of the images. - """ - content = [ - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}, - } - for encoded in (await asyncio.gather( - *[asyncio.to_thread(self._encode_image, p) for p in img_paths] - )) - ] - content.append({"type": "text", "text": prompt}) - - rsp = await self.client.chat.completions.create( - model=self.model_name, - messages=[{"role": "user", "content": content}], - response_format={"type": "text"}, - ) - return rsp.choices[0].message.content - - # ─────────────────────────────── I/O helpers - async def _save_frame(self, frame, tag: str) -> str: - """Save a frame as a JPEG image. 
- - Args: - frame: Frame data to save. - tag (str): Tag to include in the filename. - - Returns: - str: Path to the saved image. - """ - ts = f"{time.time():.5f}" - path = os.path.join(self.screens_dir, f"{ts}_{tag}.jpg") - await asyncio.to_thread( - Image.frombytes("RGB", (frame.width, frame.height), frame.rgb).save, - path, - "JPEG", - quality=70, - ) - return path - - async def _process_and_emit(self, before_path: str, after_path: str) -> None: - """Process screenshots and emit an update. - - Args: - before_path (str): Path to the "before" screenshot. - after_path (str | None): Path to the "after" screenshot, if any. - """ - # chronology: append 'before' first (history order == real order) - self._history.append(before_path) - prev_paths = list(self._history) - - # async OpenAI calls - try: - transcription = await self._call_gpt_vision(self.transcription_prompt, [before_path, after_path]) - except Exception as exc: # pragma: no cover - transcription = f"[transcription failed: {exc}]" - - prev_paths.append(before_path) - prev_paths.append(after_path) - try: - summary = await self._call_gpt_vision(self.summary_prompt, prev_paths) - except Exception as exc: # pragma: no cover - summary = f"[summary failed: {exc}]" - - txt = (transcription + summary).strip() - await self.update_queue.put(Update(content=txt, content_type="input_text")) - - # ─────────────────────────────── skip guard - def _skip(self) -> bool: - """Check if capture should be skipped based on visible applications. - - Returns: - bool: True if capture should be skipped, False otherwise. - """ - return _is_app_visible(self._guard) if self._guard else False - - # ─────────────────────────────── main async worker - async def _worker(self) -> None: # overrides base class - """Main worker method that captures and processes screenshots. 
- - This method runs in a background task and handles: - - Mouse event monitoring - - Screen capture - - Periodic screenshots - - Image processing and analysis - """ - log = logging.getLogger("Screen") - if self.debug: - logging.basicConfig(level=logging.INFO, format="%(asctime)s [Screen] %(message)s", datefmt="%H:%M:%S") - else: - log.addHandler(logging.NullHandler()) - log.propagate = False - - CAP_FPS = self._CAPTURE_FPS - DEBOUNCE = self._DEBOUNCE_SEC - - loop = asyncio.get_running_loop() - - # ------------------------------------------------------------------ - # All calls to mss / Quartz are wrapped in `to_thread` - # ------------------------------------------------------------------ - with mss.mss() as sct: - mons = sct.monitors[self._MON_START:] - - # ---- mouse callbacks (pynput is sync → schedule into loop) ---- - def schedule_event(x: float, y: float, typ: str): - asyncio.run_coroutine_threadsafe(mouse_event(x, y, typ), loop) - - listener = mouse.Listener( - on_move=lambda x, y: schedule_event(x, y, "move"), - on_click=lambda x, y, btn, prs: schedule_event(x, y, "click") if prs else None, - on_scroll=lambda x, y, dx, dy: schedule_event(x, y, "scroll"), - ) - listener.start() - - # ---- nested helper inside the async context ---- - async def flush(): - """Process pending event and emit update.""" - if self._pending_event is None: - return - if self._skip(): - self._pending_event = None - return - - ev = self._pending_event - aft = await asyncio.to_thread(sct.grab, mons[ev["mon"] - 1]) - - bef_path = await self._save_frame(ev["before"], "before") - aft_path = await self._save_frame(aft, "after") - await self._process_and_emit(bef_path, aft_path) - - log.info(f"{ev['type']} captured on monitor {ev['mon']}") - self._pending_event = None - - def debounce_flush(): - """Schedule flush as a task.""" - asyncio.create_task(flush()) - - # ---- mouse event reception ---- - async def mouse_event(x: float, y: float, typ: str): - """Handle mouse events. - - Args: - x (float): X coordinate. - y (float): Y coordinate. - typ (str): Event type ("move", "click", or "scroll"). 
- """ - idx = self._mon_for(x, y, mons) - log.info( - f"{typ:<6} @({x:7.1f},{y:7.1f}) → mon={idx} {'(guarded)' if self._skip() else ''}" - ) - if self._skip() or idx is None: - return - - # lazily grab before-frame - if self._pending_event is None: - async with self._frame_lock: - bf = self._frames.get(idx) - if bf is None: - return - self._pending_event = {"type": typ, "mon": idx, "before": bf} - - # reset debounce timer - if self._debounce_handle: - self._debounce_handle.cancel() - self._debounce_handle = loop.call_later(DEBOUNCE, debounce_flush) - - # ---- main capture loop ---- - log.info(f"Screen observer started — guarding {self._guard or '∅'}") - - while self._running: # flag from base class - t0 = time.time() - - # refresh 'before' buffers - for idx, m in enumerate(mons, 1): - frame = await asyncio.to_thread(sct.grab, m) - async with self._frame_lock: - self._frames[idx] = frame - - # fps throttle - dt = time.time() - t0 - await asyncio.sleep(max(0, (1 / CAP_FPS) - dt)) - - # shutdown - listener.stop() - if self._debounce_handle: - self._debounce_handle.cancel() diff --git a/gum/prompts/__init__.py b/gum/prompts/__init__.py deleted file mode 100644 index 0519ecb..0000000 --- a/gum/prompts/__init__.py +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/gum/prompts/gum.py b/gum/prompts/gum.py deleted file mode 100644 index 0704a78..0000000 --- a/gum/prompts/gum.py +++ /dev/null @@ -1,213 +0,0 @@ -AUDIT_PROMPT = """You are a data privacy compliance assistant for a large language model (LLM). - -Here are some past interactions {user_name} had with an LLM - -## Past Interactions - -{past_interaction} - -## Task - -{user_name} currently is looking at the following: - -User Input ---- -{user_input} ---- - -Given {user_name}'s input, analyze and respond in structured JSON format with the following fields: - -1. `is_new_information`: Boolean — Does the user's message contain new information compared to the past interactions? -2. `data_type`: String — What type of data is being disclosed (e.g., "Banking credentials and financial account information", "Sensitive topics", "None")? -3. `subject`: String — Who is the primary subject of the disclosed data? -4. `recipient`: String — Who or what is the recipient of the information (e.g., "An AI model that provides conversational assistance")? -5. `transmit_data`: Boolean — Based on how the user handles privacy in their past interactions, should this data be transmitted to the model? - -Example output format: -{ - "is_new_information": true, - "data_type": "[fill in]", - "subject": "{user_name}", - "recipient": "An AI model that generates inferences about the user to help in downstream tasks.", - "transmit_data": true -}""" - - -PROPOSE_PROMPT = """You are a helpful assistant tasked with analyzing user behavior based on transcribed activity. - -# Analysis - -Using a transcription of {user_name}'s activity, analyze {user_name}'s current activities, behavior, and preferences. Draw insightful, concrete conclusions. - -To support effective information retrieval (e.g., using BM25), your analysis must **explicitly identify and refer to specific named entities** mentioned in the transcript. This includes applications, websites, documents, people, organizations, tools, and any other proper nouns. Avoid general summaries—**use exact names** wherever possible, even if only briefly referenced. 
- -Consider these points in your analysis: - -- What specific tasks or goals is {user_name} actively working towards, as evidenced by named files, apps, platforms, or individuals? -- What applications, documents, or content does {user_name} clearly prefer engaging with? Identify them by name. -- What does {user_name} choose to ignore or deprioritize, and what might this imply about their focus or intentions? -- What are the strengths or weaknesses in {user_name}’s behavior or tools? Cite relevant named entities or resources. - -Provide detailed, concrete explanations for each inference. **Support every claim with specific references to named entities in the transcript.** - -## Evaluation Criteria - -For each proposition you generate, evaluate its strength using two scales: - -### 1. Confidence Scale - -Rate your confidence based on how clearly the evidence supports your claim. Consider: - -- **Direct Evidence**: Is there direct interaction with a specific, named entity (e.g., opened “Notion,” responded to “Slack” from “Alex”)? -- **Relevance**: Is the evidence clearly tied to the proposition? -- **Engagement Level**: Was the interaction meaningful or sustained? - -Score: **1 (weak support)** to **10 (explicit, strong support)**. High scores require specific named references. - -### 2. Decay Scale - -Rate how long the proposition is likely to stay relevant. Consider: - -- **Urgency**: Does the task or interest have clear time pressure? -- **Durability**: Will this matter 24 hours later or more? - -Score: **1 (short-lived)** to **10 (long-lasting insight or pattern)**. - -# Input - -Below is a set of transcribed actions and interactions that {user_name} has performed: - -## User Activity Transcriptions - -{inputs} - -# Task - -Generate **at least 5 distinct, well-supported propositions** about {user_name}, each grounded in the transcript. - -Be conservative in your confidence estimates. Just because an application appears on {user_name}'s screen does not mean they have deeply engaged with it. They may have only glanced at it for a second, making it difficult to draw strong conclusions. - -Assign high confidence scores (e.g., 8-10) only when the transcriptions provide explicit, direct evidence that {user_name} is actively engaging with the content in a meaningful way. Keep in mind that that the content on the screen is what the user is viewing. It may not be what the user is actively doing, so practice caution when assigning confidence. - -Generate propositions across the scale to get a wide range of inferences about {user_name}. - -Return your results in this exact JSON format: - -{ - "propositions": [ - { - "proposition": "[Insert your proposition here]", - "reasoning": "[Provide detailed evidence from specific parts of the transcriptions to clearly justify this proposition. Refer explicitly to named entities where applicable.]", - "confidence": "[Confidence score (1–10)]", - "decay": "[Decay score (1–10)]" - }, - ... - ] -}""" - -REVISE_PROMPT = """You are an expert analyst. A cluster of similar propositions are shown below, followed by their supporting observations. - -Your job is to produce a **final set** of propositions that is clear, non-redundant, and captures everything about the user, {user_name}. - -To support information retrieval (e.g., with BM25), you must **explicitly identify and preserve all named entities** from the input wherever possible. 
These may include applications, websites, documents, people, organizations, tools, or any other specific proper nouns mentioned in the original propositions or their evidence. - -You MAY: - -- **Edit** a proposition for clarity, precision, or brevity. -- **Merge** propositions that convey the same meaning. -- **Split** a proposition that contains multiple distinct claims. -- **Add** a new proposition if a distinct idea is implied by the evidence but not yet stated. -- **Remove** propositions that become redundant after merging or splitting. - -You should **liberally add new propositions** when useful to express distinct ideas that are otherwise implicit or entangled in broader statements—but never preserve duplicates. - -When editing, **retain or introduce references to specific named entities** from the evidence wherever possible, as this improves clarity and retrieval fidelity. - -Edge cases to handle: - -- **Contradictions** – If two propositions conflict, keep the one with stronger supporting evidence, or merge them into a conditional statement. Lower the confidence score of weaker or uncertain claims. -- **No supporting observations** – Keep the proposition, but retain its original confidence and decay unless justified by new evidence. -- **Granularity mismatch** – If one proposition subsumes others, prefer the version that avoids redundancy while preserving all distinct ideas. -- **Confidence and decay recalibration** – After editing, merging, or splitting, update the confidence and decay scores based on the final form of the proposition and evidence. - -General guidelines: - -- Keep each proposition clear and concise (typically 1–2 sentences). -- Maintain all meaningful content from the originals. -- Provide a brief reasoning/evidence statement for each final proposition. -- Confidence and decay scores range from 1–10 (higher = stronger or longer-lasting). - -## Evaluation Criteria - -For each proposition you revise, evaluate its strength using two scales: - -### 1. Confidence Scale - -Rate your confidence in the proposition based on how directly and clearly it is supported by the evidence. Consider: - -- **Direct Evidence**: Is the claim directly supported by clear, named interactions in the observations? -- **Relevance**: Is the evidence closely tied to the proposition? -- **Completeness**: Are key details present and unambiguous? -- **Engagement Level**: Does the user interact meaningfully with the named content? - -Score: **1 (weak/assumed)** to **10 (explicitly demonstrated)**. High scores require direct and strong evidence from the observations. - -### 2. Decay Scale - -Rate how long the insight is likely to remain relevant. Consider: - -- **Immediacy**: Is the activity time-sensitive? -- **Durability**: Will the proposition remain true over time? - -Score: **1 (short-lived)** to **10 (long-term relevance or behavioral pattern)**. - -# Input - -{body} - -# Output - -Assign high confidence scores (e.g., 8-10) only when the transcriptions provide explicit, direct evidence that {user_name} is actively engaging with the content in a meaningful way. Keep in mind that that the input is what the {user_name} is viewing. It may not be what the {user_name} is actively doing, so practice caution when assigning confidence. - -Return **only** JSON in the following format: - -{ - "propositions": [ - { - "proposition": "", - "reasoning": "", - "confidence": , - "decay": - }, - ... - ] -}""" - -SIMILAR_PROMPT = """You will label sets of propositions based on how similar they are to eachother. 
- -# Propositions - -{body} - -# Task - -Use exactly these labels: - -(A) IDENTICAL – The propositions say practically the same thing. -(B) SIMILAR – The propositions relate to a similar idea or topic. -(C) UNRELATED – The propositions are fundamentally different. - -Always refer to propositions by their numeric IDs. - -Return **only** JSON in the following format: - -{ - "relations": [ - { - "source": , - "label": "IDENTICAL" | "SIMILAR" | "UNRELATED", - "target": [, ...] // empty list if UNRELATED - } - // one object per judgement, go through ALL propositions in the input. - ] -}""" \ No newline at end of file diff --git a/gum/prompts/screen.py b/gum/prompts/screen.py deleted file mode 100644 index effcaa5..0000000 --- a/gum/prompts/screen.py +++ /dev/null @@ -1,15 +0,0 @@ -TRANSCRIPTION_PROMPT = """Transcribe in markdown ALL the content from the screenshots of the user's screen. - -NEVER SUMMARIZE ANYTHING. You must transcribe everything EXACTLY, word for word, but don't repeat yourself. - -ALWAYS include all the application names, file paths, and website URLs in your transcript. - -Create a FINAL structured markdown transcription.""" - -SUMMARY_PROMPT = """Provide a detailed description of the actions occuring across the provided images. The images are in the order they were taken. - -Include as much relevant detail as possible, but remain concise. - -Generate a handful of bullet points and reference *specific* actions the user is taking. - -Keep in mind that that the content on the screen is what the user is viewing. It may not be what the user is actively doing or what they believe, so practice caution when making assumptions.""" \ No newline at end of file diff --git a/induce/get_human_trajectory.py b/induce/get_human_trajectory.py new file mode 100644 index 0000000..7146b51 --- /dev/null +++ b/induce/get_human_trajectory.py @@ -0,0 +1,356 @@ +"""Load the process human trajectory from the database.""" + +import os +import shutil +import argparse + +import pandas as pd +from sqlalchemy import create_engine, text +from utils import ( + is_click_action, is_keyboard_action, is_scroll_action, + get_key_input, compose_key_input +) +from language import ActionNode, SequenceNode + +# %% Action Processing + +def load_actions_from_db(log_dir: str, db_path: str) -> list[str]: + """Load the actions from the database. """ + db_path = os.path.expanduser(os.path.join(log_dir, db_path)) + engine = create_engine(f"sqlite:///{db_path}") + + with engine.connect() as connection: + query = text("SELECT * from observations") + df = pd.read_sql_query(query, connection) + return df["content"].to_list() + +def hotkey_in_action(action: str) -> bool: + """Check if the action contains a hotkey.""" + return any(hotkey in action for hotkey in [".cmd", ".enter", ".tab", ".up", ".down"]) + + +def trigger_close_buffer(action: str, buffer_actions: list[str], enable_hotkey: bool = False) -> bool: + """Time to close the buffer: + - Current buffer is non-empty + - Next new key/scroll action is different from the last action in the buffer. + """ + if len(buffer_actions) == 0: + return False + if is_keyboard_action(buffer_actions[-1]) and (not is_keyboard_action(action)): + return True + if is_scroll_action(buffer_actions[-1]) and (not is_scroll_action(action)): + return True + if enable_hotkey and is_keyboard_action(action) and hotkey_in_action(action): + return True + return False + + +def trigger_add_buffer(action: str, buffer_actions: list[str]) -> bool: + """Should add the new action to the buffer. 
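In other words, the action is appended to the buffer only when both conditions below hold; the second condition is satisfied by either (i) or (ii).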
+ - Is keyboard or scroll action + - (i) buffer is empty; (ii) last action in buffer is the same type as the new action. + """ + if not (is_keyboard_action(action) or is_scroll_action(action)): + return False + if len(buffer_actions) == 0: + return True + if is_keyboard_action(action) and is_keyboard_action(buffer_actions[-1]): + # print(f"Event 1: {action} | {buffer_actions[-1]}") + return True + if is_scroll_action(action) and is_scroll_action(buffer_actions[-1]): + return True + return False + + +def merge_actions(actions: list[str], enable_hotkey: bool = False) -> list[str]: + """Merge adjacent keyboard and scrolling actions into a single action.""" + original_actions, merged_actions = [], [] + buffer_actions = [] + for action in actions: + close_buffer_flag = trigger_close_buffer(action, buffer_actions, enable_hotkey=enable_hotkey) + if close_buffer_flag: + if buffer_actions and is_keyboard_action(buffer_actions[0]): # keypress buffer + assert all([is_keyboard_action(action) for action in buffer_actions]) + original_actions.append({"before": buffer_actions[0], "after": buffer_actions[-1]}) + + buffer_values = [get_key_input(action) for action in buffer_actions] + keyboard_input = compose_key_input(buffer_values) + merged_actions.append(f"key_press('{keyboard_input}')") + # print("[KeyPress] :", merged_actions[-1]) + elif buffer_actions and is_scroll_action(buffer_actions[0]): # scroll buffer + assert all([is_scroll_action(action) for action in buffer_actions]) + for ba in buffer_actions: + if len(merged_actions) == 0 or ba != merged_actions[-1]: + original_actions.append({"before": ba, "after": ba}) + merged_actions.append(ba) + # print("[Scroll] :", merged_actions[-1]) + buffer_actions = [] + + add_buffer_flag = trigger_add_buffer(action, buffer_actions) + if add_buffer_flag: + buffer_actions.append(action) + else: + merged_actions.append(action) + original_actions.append({"before": action, "after": action}) + + return original_actions, merged_actions + + +# %% State + +def find_screenshot(screenshot_paths: list[str], action: str, suffix: str) -> tuple[str, list[str]]: + """Find the screenshot path for the given action and suffix. + Return the screenshot path and the remaining screenshot paths.""" + for i, sp in enumerate(screenshot_paths): + if action in sp and sp.endswith(suffix): + return screenshot_paths[i], screenshot_paths[: i] + screenshot_paths[i+1:] + return None, screenshot_paths + + +def get_states(actions: list[str], screenshot_dir: str, is_windows: bool = False) -> list[dict[str, str]]: + """Get before/after states (screenshots) associate with each action. 
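The screenshot filenames this lookup assumes have the form `<timestamp>_<action>_<tag>.jpg`, where the tag is `before`/`after` for non-keyboard actions (e.g., clicks and scrolls) and `first`/`final` for keyboard actions; `find_screenshot` matches files by these suffixes.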
+ """ + screenshot_paths = sorted(os.listdir(screenshot_dir), key=lambda x: x.split('_')[0]) # sort by timestamp + screenshot_paths = [os.path.join(screenshot_dir, p) for p in screenshot_paths] + + states = [] + for action_dict in actions: + # print(action_dict) + suffix_before = "_first.jpg" if is_keyboard_action(action_dict["before"]) else "_before.jpg" + before_path, screenshot_paths = find_screenshot(screenshot_paths, action_dict["before"], suffix_before) + + suffix_after = "_final.jpg" if is_keyboard_action(action_dict["after"]) else "_after.jpg" + after_path, screenshot_paths = find_screenshot(screenshot_paths, action_dict["after"], suffix_after) + state = {"before": before_path, "after": after_path} + states.append(state) + print(state) + print('-'*20) + + return states + + +def adjust_states(actions: list[str], states: list[dict]) -> list[dict]: + """Adjust the states to reflect more accurate changes.""" + adjusted_states = [] + for i, (action, state) in enumerate(zip(actions, states)): + if (i == 0) or is_keyboard_action(action): + before_state = state["before"] + else: + before_state = states[i-1]["after"] + + if is_keyboard_action(action) and (i < len(actions) - 1): + after_state = states[i+1]["before"] + else: + after_state = state.get("after", state["before"]) + + adjusted_states.append({"before": before_state, "after": after_state}) + + return adjusted_states + + +# %% Time + +def parse_screenshot_path(path: str) -> tuple[str, str]: + """Parse the screenshot path into action and timestamp.""" + parts = path.split('/')[-1].split('_') + timestamp = parts[0] + if "key" in parts: + action = '_'.join(parts[1:]).rstrip(".jpg") + tag = "before" + else: + action = '_'.join(parts[1:-1]) + tag = parts[-1].split('.')[0] + return {"timestamp": timestamp, "action": action, "tag": tag} + + +# %% Merge Click Actions + +def parse_click_coords(action: str) -> tuple[float, float]: + """Parse the coordinates from the action.""" + x, y = action.split('(')[1].split(')')[0].split(',') + return float(x), float(y) + +def is_double_click(step_1: ActionNode, step_2: ActionNode, time_threshold: float = 0.5, distance_threshold: float = 10) -> bool: + """Check if the two click actions constitute a double click.""" + if not (is_click_action(step_1.action) and is_click_action(step_2.action)): + return False + if step_2.time.diff > time_threshold: + return False + + x1, y1 = parse_click_coords(step_1.action) + x2, y2 = parse_click_coords(step_2.action) + dx, dy = x2 - x1, y2 - y1 + distance = (dx * dx + dy * dy) ** 0.5 + return distance < distance_threshold + +def merge_double_clicks(node_list: list[ActionNode]) -> list[ActionNode]: + """Merge double clicks into a single click.""" + merged_node_list = [] + i, N = 0, len(node_list) - 1 + while i < N: + step, next_step = node_list[i], node_list[i+1] + if is_double_click(step, next_step): + coords_str = '(' + step.action.split('(')[1] + merged_action = "double_click" + coords_str + + data = { + "action": merged_action, + "state": { + "before": step.state.before, + "after": next_step.state.after, + }, + "time": { + "before": step.time.before, + "after": next_step.time.after, + "range": step.time.range + next_step.time.range, + "diff": step.time.diff + next_step.time.diff + } + } + merged_node_list.append(ActionNode.from_json(data=data)) + i += 2 + else: + merged_node_list.append(step) + i += 1 + print(f"Double clicks merged: #{len(node_list)} -> #{len(merged_node_list)} steps.") + return merged_node_list + + +# %% Time + +def parse_time_from_path(path: str) -> 
float: + """Parse the time from the path.""" + return float(path.split('/')[-1].split('_')[0]) + + +def measure_time_from_states(states: list[dict]) -> list[dict]: + """Measure the time from the states.""" + time_list = [] + for i, state in enumerate(states): + # calculate time range + try: + before_time = parse_time_from_path(state["before"]) + after_time = parse_time_from_path(state.get("after", state["before"])) + except: + print(f"Error parsing time from path: {state}") + before_time = 0 + after_time = 0 + time_range = after_time - before_time + + # calculate time diff + if i == 0: + time_diff = 0 + else: + try: + last_time = parse_time_from_path(states[i-1].get("after", states[i-1]["before"])) + time_diff = before_time - last_time + except: + print(f"Error parsing time from path: {states[i-1]}") + time_diff = 0 + + time_list.append({ + "before": before_time, "after": after_time, + "range": time_range, "diff": time_diff, + }) + return time_list + + +def transfer_valid_states(node_list: list[ActionNode], src_suffix: str, dst_suffix: str): + """Trasfer valid states to the new directory.""" + for i, node in enumerate(node_list): + before_path, after_path = node.state.before, node.state.after + if before_path is not None: + dst_before_path = before_path.replace(src_suffix, dst_suffix) + if os.path.exists(before_path): + shutil.move(before_path, dst_before_path) + node_list[i].state.before = dst_before_path + + if after_path is not None: + dst_after_path = after_path.replace(src_suffix, dst_suffix) + if os.path.exists(after_path): + shutil.move(after_path, dst_after_path) + node_list[i].state.after = dst_after_path + + return node_list + + +# %% Main +def main(): + actions = load_actions_from_db(args.data_dir, args.db_path) + print(f"Loaded {len(actions)} actions from the database.") + original_actions, actions = merge_actions(actions, enable_hotkey=args.enable_hotkey) + states = get_states(original_actions, args.screenshot_dir) + time_list = measure_time_from_states(states) + assert len(actions) == len(states) == len(time_list) + # trajectory = [{"action": a, "time": t, "state": s} for a,s,t in zip(actions, states, time_list)] + print(f"Original trajectory: #{len(actions)} steps.") + + # prune the consecutive actions without before+after states, in the beginning and end of the trajectory + # find the first action with not-None before+after states + first_idx = 0 + for i, (a, s, t) in enumerate(zip(actions, states, time_list)): + if s["before"] is not None and s["after"] is not None: + if states[i+1]["before"] is None or states[i+1]["after"] is None: + first_idx = i + break + actions = actions[first_idx:] + states = states[first_idx:] + time_list = time_list[first_idx:] + # find the last action with not-None before+after states + last_idx = len(actions) - 1 + for i in range(len(actions)-1, -1, -1): + if states[i]["before"] is not None and states[i]["after"] is not None: + if states[i-1]["before"] is None or states[i-1]["after"] is None: + last_idx = i + break + actions = actions[:last_idx+1] + states = states[:last_idx+1] + time_list = time_list[:last_idx+1] + print(f"Pruned trajectory: #{last_idx+1-first_idx} steps.") + + if args.adjust_states: + states = adjust_states(actions, states) + adjusted_time_list = measure_time_from_states(states) + for time_dict, adjusted in zip(time_list, adjusted_time_list): + time_dict["range"] = adjusted["range"] + assert len(actions) == len(states) == len(time_list) + + node_list = [ActionNode(action=a, state=s, time=t) for a, s, t in zip(actions, states, 
time_list)] + + # organize trajectory + if args.merge_double_clicks: + node_list = merge_double_clicks(node_list) + + if args.transfer_valid_states: + src_suffix = args.screenshot_dir.split('/')[-1] + dst_suffix = args.state_dir.split('/')[-1] + node_list = transfer_valid_states(node_list, src_suffix, dst_suffix) + + print(f"Saving trajectory of #{len(node_list)} steps to {args.data_dir}...") + traj_dir = args.data_dir.replace("/records", "") + traj_path = os.path.join(traj_dir, "processed_trajectory.json") + root = SequenceNode(nodes=node_list) + root.to_json(traj_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, required=True, + help="The directory containing the raw trajectory data.") + parser.add_argument("--db_path", type=str, default="actions.db") + parser.add_argument("--screenshot_dir", type=str, default="screenshots") + + parser.add_argument("--enable_hotkey", action="store_true", help="Enable hotkey in the action.") + parser.add_argument("--adjust_states", action="store_true", help="Adjust the states to reflect more accurate changes.") + parser.add_argument("--merge_double_clicks", action="store_true", help="Identify double clicks and merge them into a single action.") + parser.add_argument("--transfer_valid_states", action="store_true", help="Transfer valid states to the new directory.") + parser.add_argument("--state_dir", type=str, default="states") + parser.add_argument("--verbose", action="store_true") + + args = parser.parse_args() + + args.data_dir = os.path.join(args.data_dir, "records") + args.screenshot_dir = os.path.join(args.data_dir, args.screenshot_dir) + + args.adjust_states, args.merge_double_clicks = True, True + + main() diff --git a/induce/induce.py b/induce/induce.py new file mode 100644 index 0000000..bef0eaa --- /dev/null +++ b/induce/induce.py @@ -0,0 +1,188 @@ +import os +import json +import argparse +from utils import call_openai +from language import ActionNode, SequenceNode + + +def get_node_list(data_dir: str) -> list[ActionNode | SequenceNode]: + node_paths = [f for f in os.listdir(data_dir) if f.endswith(".json")] + node_paths.sort(key=lambda x: int(x.split('.')[0])) + node_paths = [os.path.join(data_dir, f) for f in node_paths] + assert len(node_paths) > 0, f"No node files found in {data_dir}" + node_list = [] + for np in node_paths: + np_data = json.load(open(np)) + if np_data["node_type"] == "action": + node_list.append(ActionNode.from_json(data=np_data)) + elif np_data["node_type"] == "sequence": + node_list.append(SequenceNode.from_json(data=np_data)) + else: + raise ValueError(f"Unknown node type: {np_data['node_type']}") + return node_list + +def get_step_goals(node_list: list[ActionNode | SequenceNode]) -> str: + goals = [] + for i, node in enumerate(node_list): + if node.status.value == "failure": + goals.append(f"[{i}] (Attempted Failure) {node.goal}") + elif node.goal is None: + continue + else: + goals.append(f"[{i}] {node.goal}") + return '\n'.join(goals) + + +# %% Induce and Parse Workflow + +def get_workflow(text: str, verbose: bool = True) -> list[str]: + """Induce the workflow from the step goals, by adopting or merging the steps. + Args: + text: The step goals. + verbose: Whether to print the workflow. + Returns: + The workflow steps. 
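As a hypothetical example of the expected format: input goals such as `[0] Open the Documents folder in Finder` and `[1] Open the Finance folder` may come back from the model as one merged line like `[0-1] Navigate to the Finance folder inside Documents`.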
+ """ + prompt = open(os.path.join(args.prompt_dir, "induce.txt")).read() + workflow = call_openai(prompt=prompt, content=text) + workflow = workflow.strip('```').strip('\n').strip() + if verbose: print(workflow) + workflow_steps = workflow.split('\n') + workflow_steps = [ws for ws in workflow_steps if ws.startswith('[')] + return workflow_steps + + +def parse_step(step: str) -> dict: + index_text, desc_text = step.split(']') + index_text = index_text.strip().lstrip('[').rstrip(']') + if '-' in index_text: + s, e = index_text.split('-') + s = int(s.strip()) + e = int(e.strip()) + else: + s = e = int(index_text.strip()) + + desc_text = desc_text.strip() + return {"index": (s, e), "goal": desc_text} + + +def parse_workflow(workflow_steps: list[str], node_list: list[ActionNode | SequenceNode], verbose: bool = True) -> SequenceNode: + workflow_root = SequenceNode(nodes=[]) + if verbose: print("Parsing workflow steps...") + for step in workflow_steps: + wdict = parse_step(step) + s, e = wdict["index"] # inclusive at both ends + if s == e: # a single action/sequence node + w_node = node_list[s] + if w_node.node_type.value == "sequence": + w_node.get_status() + else: # sequence node + w_node = SequenceNode(nodes=node_list[s:e+1]) + w_node.goal = wdict["goal"] + w_node.get_status() + if w_node.status.value == "failure": + w_node.goal = w_node.goal + workflow_root.nodes.append(w_node) + if verbose: + print(f"{w_node.node_type} | {w_node.goal} | {w_node.status}") + + return workflow_root + + +# %% Main + +def one_pass(input_dir: str, output_dir: str): + input_dir = os.path.join(args.data_dir, input_dir) + output_dir = os.path.join(args.data_dir, output_dir) + os.makedirs(output_dir, exist_ok=True) + + node_list = get_node_list(input_dir) + step_goals = get_step_goals(node_list) + if args.verbose: print("Original step goals:\n", step_goals) + + workflow_steps = get_workflow(step_goals) + + # save workflow plain text + output_path = os.path.join(output_dir, f"workflow.txt") + with open(output_path, 'w') as fw: + fw.write('\n'.join(workflow_steps)) + + # save workflow json + workflow_root = None + while workflow_root is None: + try: + workflow_root = parse_workflow(workflow_steps, node_list, args.verbose) + print(f"Parsed workflow: {len(workflow_root.nodes)}") + except Exception as e: + print(f"Error parsing workflow: {e}") + print("Please try again.") + workflow_steps = get_workflow(step_goals) + + for i, node in enumerate(workflow_root.nodes): + output_path = os.path.join(output_dir, f"{i}.json") + node.to_json(output_path) + + final = input("Output the final workflow? 
(Y/n)") + if final.lower() != "n": + output_path = os.path.join(args.data_dir, f"workflow.json") + print(f"Outputting the final workflow to {output_path} ...") + workflow_root.to_json(output_path) + + return workflow_root + + +def decide_next_dir(input_dir: str, output_dir: str) -> tuple[str, str]: + input_index = input_dir.lstrip('nodes').rstrip('/') + if input_index == '': input_index = 0 + else: input_index = int(input_index) + output_index = output_dir.lstrip('nodes').rstrip('/') + output_index = int(output_index) + print(f"Input index: {input_index}, Output index: {output_index}") + return f"nodes{input_index + 1}", f"nodes{output_index + 1}" + +def is_close_enough(last_root: SequenceNode, curr_root: SequenceNode) -> bool: + return len(last_root.nodes) <= (len(curr_root.nodes) + 2) + + +def auto_iterate(): + node_list = get_node_list(os.path.join(args.data_dir, args.input_dir)) + last_root = SequenceNode(nodes=node_list) + close_enough = False + i_iter, max_iter = 0, 5 + while i_iter < max_iter: + curr_root = one_pass(args.input_dir, args.output_dir) + close_enough = is_close_enough(last_root, curr_root) + if close_enough: break + last_root = curr_root + i_iter += 1 + args.input_dir, args.output_dir = decide_next_dir(args.input_dir, args.output_dir) + + output_path = os.path.join(args.data_dir, f"workflow.json") + curr_root.to_json(output_path) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, required=True, + help="The directory containing the nodes folder.") + parser.add_argument("--input_dir", type=str, default="nodes", + help="The directory containing the input nodes.") + parser.add_argument("--output_dir", type=str, default="nodes1", + help="The directory to save the merged nodes.") + + parser.add_argument("--model_name", type=str, + default="litellm/neulab/claude-3-5-sonnet-20241022", + help="The model name to use for the LLM.") + parser.add_argument("--prompt_dir", type=str, default="prompts", + help="The directory containing the prompts.") + + parser.add_argument("--auto", action="store_true", help="If automatically iterate and terminate workflow induction.") + parser.add_argument("--verbose", action="store_true", help="Print details.") + + args = parser.parse_args() + + if args.auto: + auto_iterate() + else: + one_pass(args.input_dir, args.output_dir) diff --git a/induce/language.py b/induce/language.py new file mode 100644 index 0000000..db57244 --- /dev/null +++ b/induce/language.py @@ -0,0 +1,419 @@ +import os +import enum +import json +from utils import is_keyboard_action, encode_image, call_openai + +from openai import OpenAI +client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + + +MAX_DIFF = 100000.0 + +class NodeType(enum.Enum): + ACTION = "action" + SEQUENCE = "sequence" + +# %% Action Node + +class Time: + def __init__(self, before: str, after: str, range: float = None, diff: float = None): + """Initialize the time object associated with an action. 
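Here `before`/`after` are the timestamps parsed from the before/after screenshot paths, `range` is `after - before` for this action, and `diff` is the gap since the previous action's `after` timestamp (these values are produced by `measure_time_from_states` in get_human_trajectory.py).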
+ """ + self.before = before + self.after = after + self.range = range + self.diff = diff + + @classmethod + def from_json(cls, path: str = None, data: dict = None): + if path is not None: + data = json.load(open(path)) + elif data is None: + return None + return cls(**data) + + def to_json(self, path: str = None): + data = { + "before": self.before, "after": self.after, + "range": self.range, "diff": self.diff + } + if path is not None: + json.dump(data, open(path, 'w')) + return data + + def get_time(self, reverse: bool = False) -> str: + if reverse: return self.after if (self.after is not None) else self.before + else: return self.before if (self.before is not None) else self.after + + +class State: + def __init__(self, before: str, after: str, diff_score: float = None): + """Initialize the state object associated with an action. + Args: + before: The screenshot path of the state before the action. + after: The screenshot path of the state after the action. + diff_score: The MSE difference score between the before and after states. + """ + self.before = before + self.after = after + self.diff_score = diff_score + + @classmethod + def from_json(cls, path: str = None, data: dict = None): + if path is not None: + data = json.load(open(path)) + elif data is None: + raise ValueError("Either path or data must be provided.") + return cls(before=data["before"], after=data["after"], diff_score=data.get("diff_score", None)) + + def to_json(self, path: str = None): + data = {"before": self.before, "after": self.after, "diff_score": self.diff_score} + if path is not None: + json.dump(data, open(path, 'w')) + return data + + def get_state(self, reverse: bool = False) -> str: + if reverse: return self.after if (self.after is not None) else self.before + else: return self.before if (self.before is not None) else self.after + + +class ActionNode: + def __init__(self, action: str, state: dict | State, goal: str = None, time: dict | Time = None): + self.node_type = NodeType.ACTION + self.length = 1 + self.action = action + self.state = state if isinstance(state, State) else State(**state) + self.goal = goal + if time is None: + self.time = None + else: + self.time = time if isinstance(time, Time) else Time(**time) + self.status = SequenceStatus.SUCCESS + + def __str__(self): + return f"ActionNode(action={self.action}, state={self.state}, description={self.description})" + + def get_semantic_repr(self): + if self.goal is not None: return self.goal + else: self.action + + def get_num_actions(self): + return 1 + + def get_goal(self, model_name: str = None): + """Verbalize the `goal` of the `action`.""" + content = get_action_content(self, add_state=False) + prompt = "Your task is to summarize the goal in a short sentence, given the action and the state." + \ + "Do not include prefix like 'the goal is'. Do not include action-specific details like the coordinates." + goal = call_openai(prompt=prompt, content=content) + self.goal = goal + + @classmethod + def from_json(cls, path: str = None, data: dict = None): + if path is not None: + data = json.load(open(path)) + elif data is None: + raise ValueError("Either path or data must be provided.") + state = State.from_json(data=data["state"]) + time = Time.from_json(data=data.get("time", None)) + return cls(action=data["action"], state=state, goal=data.get("goal", None), time=time) + + def to_json(self, path: str = None): + """Save the action node to a JSON file. + Args: + path: The path to save the action node. 
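Returns the JSON-serializable dict for this node; when `path` is given, the same dict is also written to that file.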
+ """ + data = { + "node_type": self.node_type.value, + "action": self.action, + "state": self.state.to_json(), + "goal": self.goal, + "time": self.time.to_json() if self.time is not None else None + } + if path is not None: + json.dump(data, open(path, 'w')) + return data + +# %% Sequence Node + +class SequenceStatus(enum.Enum): + SUCCESS = "success" + FAILURE = "failure" + UNKNOWN = "unknown" + + +class SequenceNode: + def __init__( + self, + nodes: list, + goal: str = None, + status: SequenceStatus = None, + ): + self.node_type = NodeType.SEQUENCE + self.nodes = nodes + self.length = len(self.nodes) + + self.goal = goal + if status is None: + self.status = SequenceStatus.UNKNOWN + else: + for v in SequenceStatus: + if v.value == status: + self.status = v + else: + self.status = SequenceStatus.UNKNOWN + + def __str__(self): + return f"SequenceNode(goal={self.goal}, nnodes={len(self.nodes)}, status={self.status})" + + def get_semantic_repr(self): + if self.goal is not None: return self.goal + else: + subgoals = [n.get_semantic_repr() for n in self.nodes] + return '\n'.join([sg for sg in subgoals if sg is not None]) + + def get_num_actions(self): + num_actions = 0 + for n in self.nodes: + num_actions += n.get_num_actions() + return num_actions + + def annotate( + self, + prompt_path: str = "prompts/annotate_node.txt", + model_name: str = "litellm/neulab/claude-3-5-sonnet-20241022", + bucket_size: int = 20, + verbose: bool = True, + ): + """Group adjacent `nodes` into a `SequenceNode`.""" + total_nodes = len(self.nodes) + num_buckets = (total_nodes + bucket_size - 1) // bucket_size + if verbose: print(f"Total nodes: {total_nodes} | Num buckets: {num_buckets}") + + prompt = open(prompt_path).read() + + new_nodes = [] + for i in range(num_buckets): + print(f"Bucket {i}...") + nodes = self.nodes[i*bucket_size:(i+1)*bucket_size] + content = get_nodes_content(nodes, add_state=True) + response = call_openai(prompt=prompt, content=content) + chunks = parse_annotation(response, nodes) + while chunks is None: + print("Failed to parse annotation. 
Please try again.") + response = call_openai(prompt=prompt, content=content) + chunks = parse_annotation(response, nodes) + # print("Chunks: ", chunks) + chunks = validate_chunks(chunks, nodes) + + for c in chunks: + new_node = get_new_node(nodes[c["start"]:c["end"]+1]) + new_node.goal = c["goal"] + new_nodes.append(new_node) + if verbose: print(f"Bucket {i} done: {len(new_nodes)} new nodes..") + + self.nodes = new_nodes + + # update the goal and status + self.get_goal(model_name=model_name) + self.get_status(model_name=model_name) + + def get_goal(self, prompt_path: str = "prompts/get_node_goal.txt", model_name: str = "litellm/neulab/claude-3-5-sonnet-20241022"): + """Summarize the `goal` from the `nodes`.""" + subgoals = [f"[{i}] {n.get_semantic_repr()}" for i, n in enumerate(self.nodes)] + content = [{"type": "text", "text": '\n'.join(subgoals)}] + prompt = open(prompt_path).read() + goal = call_openai(prompt=prompt, content=content) + self.goal = goal + + def get_status(self, prompt_path: str = "prompts/get_node_status.txt", model_name: str = "litellm/neulab/claude-3-5-sonnet-20241022"): + """Get the success/failure status of the `nodes` in achieving the `goal`.""" + prompt = open(prompt_path).read() + content = get_nodes_content(self.nodes, add_state=True) + response = call_openai(prompt=prompt, content=content) + self.status = SequenceStatus.SUCCESS if "YES" in response else SequenceStatus.FAILURE + + @classmethod + def from_json(cls, path: str = None, data: dict = None): + if path is not None: + # print(f"Loading sequence node from {type(path)} | {path}...") + data = json.load(open(path, 'r')) + elif data is None: + raise ValueError("Either path or data must be provided.") + + nodes = [] + for node in data["nodes"]: + if node["node_type"] == NodeType.ACTION.value: + nodes.append(ActionNode.from_json(data=node)) + elif node["node_type"] == NodeType.SEQUENCE.value: + nodes.append(cls.from_json(data=node)) + return cls(nodes=nodes, goal=data.get("goal", None), status=data.get("status", None)) + + def to_json(self, path: str = None): + """Save the sequence node to a JSON file. + Args: + path: The path to save the sequence node. 
+ """ + data = { + "node_type": self.node_type.value, + "nodes": [node.to_json() for node in self.nodes], + "goal": self.goal, + "status": self.status.value + } + if path is not None: + json.dump(data, open(path, 'w')) + return data + +# %% Utility functions + +def get_new_node(action_node_list: list[ActionNode]) -> ActionNode | SequenceNode: + """Get a new node from a segment.""" + if len(action_node_list) == 1: + return action_node_list[0] + else: + assert len(action_node_list) > 1, f"Length is {len(action_node_list)}" + return SequenceNode(nodes=action_node_list) + +def merge_nodes(node_list: list[ActionNode | SequenceNode]) -> SequenceNode: + merged_nodes = [] + for i, node in enumerate(node_list): + if isinstance(node, ActionNode): + merged_nodes.append(node) + elif isinstance(node, SequenceNode): + merged_nodes.extend(node.nodes) + return SequenceNode(nodes=merged_nodes) + + +# %% Annotate Sequence Node +def parse_chunk(chunk: str) -> dict: + """Parse chunk info dict from string.""" + s = chunk.index('[') + e = chunk.index(']', s+1) + text = chunk[e+1:].strip() + print(chunk[s+1:e]) + if '-' not in chunk[s+1:e]: + s, e = chunk[s+1:e], chunk[s+1:e] + else: + s, e = chunk[s+1:e].split('-') + s, e = int(s.strip()), int(e.strip()) + return {"start": s, "end": e, "length": e - s + 1, "goal": text} + +def parse_annotation(annotation: str, node_list: list[ActionNode]) -> list[ActionNode | SequenceNode]: + print("Annotation: ", annotation) + index = annotation.find('[') + chunks = [s.strip() for s in annotation[index:].split('\n') if s.strip()] + chunks = [c for c in chunks if c.startswith('[')] + parsed_chunks = [] + for c in chunks: + try: + cp = parse_chunk(c) + parsed_chunks.append(cp) + except: + print(f"Failed to parse chunk: {c}") + return parsed_chunks + + +# %% Validate Chunks + +def remove_empty_chunks(chunks: list[dict], node_list: list[ActionNode]) -> list[dict]: + """Remove chunks with no steps.""" + chunk_index = None + for i, c in enumerate(chunks): + actions = node_list[c["start"]:c["end"]+1] + if len(actions) == 0: + chunk_index = i + break + if len(actions) < c["length"]: + chunks[i]["end"] = c["start"] + len(actions) - 1 + chunks[i]["length"] = len(actions) + if chunk_index is not None: + chunks = chunks[:chunk_index] + print(f"Removed chunk {chunk_index} because it has no steps.") + return chunks + +def validate_chunks(chunks: list[dict], node_list: list[ActionNode]) -> list[dict]: + """Validate the chunks contain the same number of ActionNodes.""" + chunks = remove_empty_chunks(chunks, node_list) + print("Non-Empty Chunks: ", chunks) + total_steps = sum([c["length"] for c in chunks]) + print(f"Total Chunk Steps: {total_steps} vs Trajectory Length:{len(node_list)}") + if total_steps < len(node_list): # add the remaining steps to the last chunk + if len(chunks) == 0: + chunks.append({"start": 0, "end": len(node_list) - 1, "length": len(node_list), "goal": None}) + else: + chunks[-1]["end"] = len(node_list) - 1 + chunks[-1]["length"] = len(node_list) - chunks[-1]["start"] + total_steps = sum([c["length"] for c in chunks]) + if total_steps != len(node_list): + print(f"[WARNING] Total Chunk Steps: {chunks} vs Trajectory Length:{len(node_list)}") + return chunks + +# %% Input Node to LLM + +def get_action_content(action_node: ActionNode, add_state: bool = False) -> list[dict]: + """Get the content of the step.""" + content = [] + if add_state: + if is_keyboard_action(action_node.action): + image_path = action_node.state.get_state() + if image_path is not None: + image_url = 
encode_image(image_path, return_url=True) + content.append({"type": "image_url", "image_url": {"url": image_url}}) + else: + image_path = action_node.state.get_state() + if image_path is not None: + image_url = encode_image(image_path, return_url=True) + content.append({"type": "image_url", "image_url": {"url": image_url}}) + + text = action_node.action + if action_node.goal is not None: + text += f" ({action_node.goal})" + content.append({"type": "text", "text": text}) + return content + +def get_nodes_content(node_list: list[ActionNode | SequenceNode], add_state: bool = False) -> list[dict]: + """Get the content of the sequence. Only add state for the last action, if set up.""" + content = [] + for n in node_list[:-1]: + if isinstance(n, ActionNode): + content.extend(get_action_content(n, add_state=False)) + else: + content.extend(get_nodes_content(n.nodes, add_state=False)) + + if isinstance(node_list[-1], ActionNode): + content.extend(get_action_content(node_list[-1], add_state=add_state)) + else: + content.extend(get_nodes_content(node_list[-1].nodes, add_state=add_state)) + return content + + +# %% Get First/Last Action + +def get_first_action(node: ActionNode | SequenceNode) -> ActionNode: + if isinstance(node, ActionNode): + return node + else: + return get_first_action(node.nodes[0]) + +def get_last_action(node: ActionNode | SequenceNode) -> ActionNode: + if isinstance(node, ActionNode): + return node + else: + return get_last_action(node.nodes[-1]) + + + +def viz_node(node, indent_level: int = 0): + content = f"[{indent_level}]" + " " * indent_level + "Type: " + node.node_type.value + " | " + if isinstance(node, ActionNode) and node.action is not None: + content += "Action: " + node.action + " | " + else: + content += "Action: None | " + if node.goal is not None: + content += "Goal: " + node.goal.split("\n")[0] + else: + content += "Goal: None" + print(content) + if node.node_type == NodeType.SEQUENCE: + for child in node.nodes: + viz_node(child, indent_level + 1) \ No newline at end of file diff --git a/induce/prompts/annotate_node.txt b/induce/prompts/annotate_node.txt new file mode 100644 index 0000000..3d8107e --- /dev/null +++ b/induce/prompts/annotate_node.txt @@ -0,0 +1,18 @@ +You task is to decompose the computer-use activities into chunks. Each chunk should be driven by a meaningful subtask objective (e.g., "enter document title", "find the expense data in folders"). Describe the subtask objective in a natural language sentence. +You want to ensure you output as few chunks as possible. + +For each output chunk, note the start and end index in the input, as follows: +## Example input +[0] click_left(34.1, 329.1) +[1] ... ... + +## Example Output +[0-3] open the finance expense validation instruction in Google Docs +[4-5] Navigating into the Documents folder + + +If multiple consecutive chunks have similar objectives, you should merge them. Make sure the output chunks do not have overlapping steps. +Strictly follow the format of the example output. +The computer screenshots are only about the task content, no personal or sensitive information is included. +Start indexing from 0. +Mention the specific software or tool used in the action. \ No newline at end of file diff --git a/induce/prompts/get_node_goal.txt b/induce/prompts/get_node_goal.txt new file mode 100644 index 0000000..b743d70 --- /dev/null +++ b/induce/prompts/get_node_goal.txt @@ -0,0 +1,9 @@ +Your task is to summarize the goal in a short sentence, given a sequence of subgoals. 
+
+## Example input
+[1] Navigating into the Documents folder
+[2] Navigating into the Finance folder
+[3] Find the `expenses.csv` file.
+
+## Example output
+Navigate through the Documents then Finance folders to find the `expenses.csv` file.
diff --git a/induce/prompts/get_node_status.txt b/induce/prompts/get_node_status.txt
new file mode 100644
index 0000000..568cf2a
--- /dev/null
+++ b/induce/prompts/get_node_status.txt
@@ -0,0 +1,2 @@
+Your task is to determine if the computer-use action sequence roughly achieves the goal.
+Output "YES" if the goal is achieved, otherwise output "NO".
\ No newline at end of file
diff --git a/induce/prompts/induce.txt b/induce/prompts/induce.txt
new file mode 100644
index 0000000..b0709cb
--- /dev/null
+++ b/induce/prompts/induce.txt
@@ -0,0 +1,24 @@
+Your task is to summarize the general workflow from the provided task-solving steps.
+
+For example, if the steps are:
+
+```
+[1] Creates a new empty Google Sheet named "april-attendance-data" in chrome browser.
+[2] Scrolls down the google sheet to view more rows.
+[3] Scrolls down and then clicks on the top row.
+[4] Enters the text "Editor" into the select cell in Google Sheet.
+[5] Copies the text "Editor" into the select cell in Google Sheet.
+[6] Downloads the Google Sheet as a csv file.
+```
+
+The workflow should be:
+```
+[1-1] Create a new empty Google Sheet named "april-attendance-data" in chrome browser.
+[2-5] View and edit new data in the Google Sheet.
+[6-6] Downloads the Google Sheet as a csv file.
+```
+Each step in the workflow should be (i) a previous step, or (ii) a summary of multiple consecutive steps.
+If there are repetitive patterns across steps, for example repeatedly editing the same document, combine them into a single step. Summarize the actions concisely, but keep necessary details such as the file name or function tested. Summarize the shared final goal of the steps instead of simply concatenating them.
+Try segmenting the steps into semantically related chunks, then summarize each chunk into a single step. Pay attention to the order and indices of the steps; they must exactly match the input step indices.
+Do not reorder steps in the workflow.
+Merge document navigation steps into a single higher-level step. Keep coding-related steps as separate steps.
diff --git a/induce/requirements.txt b/induce/requirements.txt
new file mode 100644
index 0000000..7c9a0ba
--- /dev/null
+++ b/induce/requirements.txt
@@ -0,0 +1,2 @@
+pandas
+opencv-python
\ No newline at end of file
diff --git a/induce/segment.py b/induce/segment.py
new file mode 100644
index 0000000..84ba240
--- /dev/null
+++ b/induce/segment.py
@@ -0,0 +1,476 @@
+"""Segment trajectory into (raw) nodes based on state similarity.
+ - Use MSE to measure state similarity, split if beyond a threshold.
+ - If specified, use neural similarity to re-merge nodes if the software is the same.
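By default the MSE split threshold is 8000.0, and during re-merging two adjacent segments are combined only if one of them is shorter than 3 steps and the neural check judges their screens to show the same software (see the argparse defaults at the bottom of this file).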
+""" + +import os +import cv2 +import argparse +import numpy as np +from utils import encode_image, call_openai +from language import ( + ActionNode, SequenceNode, get_new_node, merge_nodes, + get_first_action, get_last_action, viz_node, +) + +# %% State Similarity +MAX_DIFF = 100000.0 + +def mse(image_path1: str | None, image_path2: str | None) -> float: + """Calculate the mean squared error between two images.""" + # print(f"Calculating MSE between {image_path1} and {image_path2}") + if image_path1 is None or image_path2 is None: + return MAX_DIFF + image1 = cv2.imread(image_path1) + image2 = cv2.imread(image_path2) + if image1.shape != image2.shape: + # print(f"Image shapes do not match: {image1.shape} != {image2.shape}") + return MAX_DIFF + err = np.sum((image1.astype("float") - image2.astype("float")) ** 2) + err /= float(image1.shape[0] * image1.shape[1]) + return err + + +PROMPT = """Your task is to determine if the two computer screens focus on the same software. +For each screen, first identify the software it is focused on in the front, e.g., Google Chrome, VSCode, Finder, etc. +Then, compare the software on the two screens. If they are the same, output 'YES'. Otherwise, output 'NO'.""" + +def neural(image1: str | None, image2: str | None) -> float: + """Calculate the similarity between two images using LLM.""" + if image1 is None or image2 is None: + return MAX_DIFF + image1_url = encode_image(image1, return_url=True) + image2_url = encode_image(image2, return_url=True) + content = [ + {"type": "image_url", "image_url": {"url": image1_url},}, + {"type": "image_url", "image_url": {"url": image2_url},} + ] + response = call_openai(PROMPT, content) + return 0.0 if "YES" in response else 1.0 + + +SIM_FUNC = {"mse": mse, "neural": neural} + +def get_state_similarity( + curr_node: ActionNode | SequenceNode, + last_node: ActionNode | SequenceNode, + method: str = "mse", +) -> float: + """Calculate the similarity between the current and last state of the trajectory.""" + curr_action = get_first_action(curr_node) + curr_state_path = curr_action.state.get_state(reverse=True) + last_action = get_last_action(last_node) + last_state_path = last_action.state.get_state(reverse=True) + diff_score = SIM_FUNC[method](curr_state_path, last_state_path) + if diff_score is None: + print(f"Diff score is None for {curr_state_path} and {last_state_path}") + return diff_score + + +def measure_state_diffs(path: str, verbose: bool = False) -> SequenceNode: + """Measure the state similarity between consecutive action nodes in the trajectory.""" + output_path = path.replace(".json", "_mse.json") + if os.path.exists(output_path): + root_node = SequenceNode.from_json(output_path) + if verbose: + print(f"Loaded trajectory with mse diff scores from {output_path}") + return root_node + + # measure similarity scores + print(f"Measuring state diffs for {path}...") + root_node = SequenceNode.from_json(path) + root_node.nodes[0].state.diff_score = 0.0 + for i, action_node in enumerate(root_node.nodes[1:]): + diff_score = get_state_similarity( + curr_node=action_node, + last_node=root_node.nodes[i-1], + method="mse", + ) + root_node.nodes[i+1].state.diff_score = diff_score + + # save the trajectory with diff scores + root_node.to_json(output_path) + if verbose: + print(f"Saved trajectory with mse diff scores to {output_path}") + return root_node + + +# %% Segmentation + +def segment_per_step( + root_node: SequenceNode, + threshold: float = 10000.0, + verbose: bool = False +) -> list[SequenceNode | ActionNode]: + 
"""Segment the trajectory at actions with above-threshold state differences.""" + segments, curr_segment = [], [] + for i, action_node in enumerate(root_node.nodes): + # clean the current segment if a new high-diff step is found + if (action_node.state.diff_score > threshold) and len(curr_segment) > 0: + segments.append(get_new_node(curr_segment)) + curr_segment = [] + # otherwise, add the step to the current segment + curr_segment.append(action_node) + + # add the last segment if it exists + if curr_segment: segments.append(get_new_node(curr_segment)) + if verbose: + print(f"Found {len(segments)} segments via mse diff threshold {threshold}.") + return segments + + +def find_below_threshold_ranges(scores: list[float], threshold: float = 8000.0, min_steps: int = 3) -> list[tuple[int, int]]: + """ + Find all ranges (start_index, end_index) where all scores are below threshold. + + Args: + scores: List of numeric scores + threshold: Threshold value to compare against + + Returns: + List of tuples (start_index, end_index) where all scores in the range + [start_index, end_index] are below threshold + """ + ranges = [] + start = None + + for i, score in enumerate(scores): + if score < threshold: + if start is None: + start = i + else: + if start is not None: + ranges.append((start, i - 1)) + start = None + + # Handle case where the last range extends to the end + if start is not None: + ranges.append((start, len(scores) - 1)) + + ranges = [r for r in ranges if (r[1] - r[0] + 1)>= min_steps] + return ranges + +def get_intervals( + ranges: list[tuple[int, int]], + root_node: SequenceNode, + threshold: float = 10000.0, + min_steps: int = 5, + verbose: bool = False +) -> list[tuple[int, int]]: + intervals = [] + + # add the first range + s, e = ranges[0] + if s <= min_steps: + ranges = [(0, e)] + ranges[1: ] + else: + intervals.append((0, s-1)) + if verbose: + print(f"Found {len(ranges)} ranges: ", ranges) + # cont = input("Continue? (Y/n)") + + # add the rest of the ranges + i, L = 0, len(ranges) + while i < (L - 1): + # calculate gap with next range + curr_range = ranges[i] + gap = (ranges[i][1]+1, ranges[i+1][0]-1) # both sides inclusive + step_diff = gap[1] - curr_range[1] + if step_diff >= min_steps: # add as two ranges + intervals.append(curr_range) + intervals.append(gap) + if verbose: + print(f"Added range {curr_range} and gap {gap} (step_diff: {step_diff})") + # cont = input("Continue? (Y/n)") + else: # merge as one range + intervals.append((curr_range[0], gap[1])) + if verbose: + print(f"Added merged range {intervals[-1]} (curr_range: {curr_range}) (gap: {gap}) (step_diff: {step_diff})") + # cont = input("Continue? (Y/n)") + i += 1 + + # after the loop + if i != L - 1: + print("Suggestion: use `--default_segment` to use the default segmentation method. Do you want to run it now? 
(y/n)") + cont = input() + if cont == "y": + segments = segment_per_step(root_node, threshold, verbose) + return segments + else: + raise ValueError("i != L - 1" + f"(i: {i}, L-1: {L-1})") + + if ranges[i][1] < root_node.length-1: + curr_range = ranges[i] + gap = (ranges[i][1]+1, root_node.length-1) + step_diff = gap[0] - curr_range[1] + if step_diff >= min_steps: + intervals.append(curr_range) + intervals.append(gap) + else: + intervals.append((curr_range[0], gap[1])) + else: + assert ranges[i][1] == root_node.length-1 + intervals.append(ranges[i]) + return intervals + + +def segment_by_ranges( + root_node: SequenceNode, + threshold: float = 10000.0, + min_steps: int = 5, + verbose: bool = False, +) -> list[ActionNode | SequenceNode]: + scores = [a.state.diff_score for a in root_node.nodes] + ranges = find_below_threshold_ranges(scores, threshold, min_steps) # inclusive at both ends + intervals = get_intervals(ranges, root_node, threshold, min_steps, verbose) + if verbose: + print(f"Found {len(intervals)} segments via mse diff threshold {threshold}: {intervals}") + if min_steps == 1: + split_intervals = [] + for (s, e) in intervals: + if all([root_node.nodes[i].state.diff_score == MAX_DIFF for i in range(s, e+1)]): + split_intervals.extend([(m, m) for m in range(s, e+1)]) + else: + split_intervals.append((s, e)) + if verbose: + print(f"Split {len(intervals)} segments via MaxDiff into: {len(split_intervals)}") + intervals = split_intervals + + segments = [ + get_new_node(root_node.nodes[i1:i2+1]) + for (i1, i2) in intervals + ] + return segments + + +def get_ipython_segments_0(nodes: list[ActionNode]) -> list[ActionNode | SequenceNode]: + """Get the ipython segments from the nodes.""" + ipython_segments = [] + curr_segment = [] + for node in nodes: + assert isinstance(node, ActionNode) + if ("run_ipython" in node.action) and (len(curr_segment) > 0): + ipython_segments.append(get_new_node(curr_segment)) + curr_segment = [] + curr_segment.append(node) + if len(curr_segment) > 0: + ipython_segments.append(get_new_node(curr_segment)) + return ipython_segments + +def get_ipython_segments(nodes: list[ActionNode]) -> list[ActionNode | SequenceNode]: + """Get the ipython segments from the nodes.""" + ipython_indices = [i for i, node in enumerate(nodes) if ("run_ipython" in node.action)] + if len(ipython_indices) == 0: + return [get_new_node(nodes)] + ipython_segments = [] + if ipython_indices[0] > 0: + ipython_segments.append(get_new_node(nodes[:ipython_indices[0]])) + for i in range(len(ipython_indices)-1): + ipython_segments.append(nodes[ipython_indices[i]]) + if ipython_indices[i]+1 < ipython_indices[i+1]: + ipython_segments.append(get_new_node(nodes[ipython_indices[i]+1:ipython_indices[i+1]])) + ipython_segments.append(nodes[ipython_indices[-1]]) + if ipython_indices[-1] < len(nodes)-1: + ipython_segments.append(get_new_node(nodes[ipython_indices[-1]+1:])) + return ipython_segments + +def segment_at_ipython(segments: list[ActionNode | SequenceNode], verbose: bool = False) -> list[ActionNode | SequenceNode]: + """Segment the trajectory at actions with above-threshold state differences.""" + ipython_segments = [] + for seg in segments: + if isinstance(seg, ActionNode): + ipython_segments.append(seg) + elif isinstance(seg, SequenceNode): + ipython_segments.extend(get_ipython_segments(seg.nodes)) + else: + raise ValueError(f"Unknown segment type: {type(seg)}") + if verbose: + print(f"Found {len(ipython_segments)} ipython segments from {len(segments)} segments.") + return ipython_segments + +# %% 
Remerge with LLM + +def remerge_segments( + segments: list[ActionNode | SequenceNode], + threshold: float = 3.0, + verbose: bool = False +) -> list[ActionNode | SequenceNode]: + merged_segments = [] + i, L = 0, len(segments) + while i < (L-1): + seg, next_seg = segments[i], segments[i+1] + if min(seg.length, next_seg.length) < threshold: + diff_score = get_state_similarity( + curr_node=next_seg, + last_node=seg, + method="neural", + ) + if diff_score == 0.0: # same software + merged_segments.append(merge_nodes([seg, next_seg])) + i += 2 + if verbose: + print(f"Merged segment #{i-1} and #{i} with {len(seg) + len(next_seg)} steps.") + continue + merged_segments.append(seg) + i += 1 + if i == L-1: + merged_segments.append(segments[i]) + return merged_segments + + +def remerge_segments_iterative( + segments: list[ActionNode | SequenceNode], + threshold: float = 3.0, + verbose: bool = False, +) -> list[ActionNode | SequenceNode]: + segment_lengths = [seg.length for seg in segments] + while min(segment_lengths) < threshold: + segments = remerge_segments(segments, threshold, verbose) + if segment_lengths == [seg.length for seg in segments]: + break + segment_lengths = [seg.length for seg in segments] + return segments + + +# %% Split Sequence Nodes + +def split_sequence_node(node: SequenceNode) -> list[ActionNode | SequenceNode]: + """Split the sequence node into multiple nodes.""" + subgoals = [n.goal for n in node.nodes] + if len(subgoals) == 1: + # assert node.goal == subgoals[0], f"Node goal: {node.goal} != Subgoal: {subgoals[0]}" + return node.nodes + if len(subgoals) > 20: + return [node] + response = call_openai( + prompt="Is the goal a simple composition of subgoals? Only output 'YES' or 'NO'.", + content=f"Goal: {node.goal}\n\nSubgoals:\n" + "\n".join(subgoals) + ) + if "YES" in response: + print(f"Split sequence node {node.goal} into {len(node.nodes)} nodes:", subgoals) + # cont = input("Continue? (Y/n)") + return node.nodes + else: + print(f"Not splitting sequence node {node.goal} into {len(node.nodes)} nodes:", subgoals) + # cont = input("Continue? (y/N)") + return [node] + if cont.strip().lower() == "y": + return node.nodes + else: + return [node] + + +# %% Main + +def main(): + # measure state diffs + root_node = measure_state_diffs(args.trajectory_path, args.verbose) + # segment the trajectory + if args.default_segment: + segments = segment_per_step(root_node, args.threshold, args.verbose) + else: + segments = segment_by_ranges(root_node, args.threshold, min_steps=args.min_steps, verbose=args.verbose) + + segments = segment_at_ipython(segments, args.verbose) + if len(segments) == 0: segments = [root_node] + + # semantically re-merge the segments if specified (based on semantic adjacent state similarity) + if args.do_remerge: # list[ActionNode | SequenceNode] + segments = remerge_segments_iterative(segments, args.remerge_threshold, args.verbose) + + segments = segments[1:] + print(f"Found {len(segments)} segments: ", [s.get_num_actions() for s in segments]) # each segment is a ActionNode or list[ActionNode] + # cont = input("Continue? 
(Y/n)") + # segment sequence nodes with compositional goals (based on LLM-annotated goals) + def process_and_save_node(node: ActionNode | SequenceNode, i: int, save: bool = True) -> tuple[list[ActionNode | SequenceNode], int]: + if isinstance(node, ActionNode): + node.get_goal() + print(f"[{node.node_type.value}] Goal: {node.goal}") + if save: + node_path = os.path.join(args.output_dir, f"{i}.json") + node.to_json(node_path) + return [node], i+1 + elif isinstance(node, SequenceNode): + node.annotate(model_name="gpt-4o", verbose=args.verbose) + split_nodes = split_sequence_node(node) + for n in split_nodes: + print(f"[{n.node_type.value}] Goal: {n.goal}") + if save: + node_path = os.path.join(args.output_dir, f"{i}.json") + n.to_json(node_path) + i += 1 + return split_nodes, i + + + node_list = [] + for node in segments[: 2]: + nodes, i = process_and_save_node(node, 0, save=False) + node_list.extend(nodes) + + def merge_nodes_keep_goal(node1: ActionNode | SequenceNode, node2: ActionNode | SequenceNode) -> SequenceNode: + if node1.node_type.value == "action": + if node2.node_type.value == "action": + node = get_new_node([node1, node2]) + return [node] + else: + node2.nodes = [node1] + node2.nodes + return [node2] + elif node1.node_type.value == "sequence": + node2.nodes = node1.nodes + node2.nodes + return [node2] + else: + raise ValueError(f"Unknown node type: {node1.node_type.value}") + + if node_list[0].get_num_actions() == 1: + node_list = merge_nodes_keep_goal(node_list[0], node_list[1]) + for i, n in enumerate(node_list): + print(f"[{n.node_type.value}] Goal: {n.goal}") + node_path = os.path.join(args.output_dir, f"{i}.json") + n.to_json(node_path) + + i += 1 + print("Resume node processing from index", i) + for node in segments[2:]: + nodes, i = process_and_save_node(node, i, save=True) + node_list.extend(nodes) + + # examine the segments + if args.verbose: + viz_idx = int(input(f"Enter the index of the node to visualize (0-{len(node_list)-1}): ")) + while 0 <= viz_idx < len(node_list): + viz_node(node_list[viz_idx]) + viz_idx = int(input(f"Enter the index of the node to visualize (0-{len(node_list)-1}): ")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, required=True, + help="The directory containing the trajectory data.") + parser.add_argument("--trajectory_name", type=str, default="processed_trajectory.json", + help="The name of the trajectory data file.") + parser.add_argument("--output_dir", type=str, default="nodes", + help="The directory to save the segmented nodes.") + + # mse + parser.add_argument("--threshold", type=float, default=8000.0, + help="State MSE difference threshold for segmentation.") + parser.add_argument("--min_steps", type=int, default=5, help="Minimum number of steps for a segment.") + parser.add_argument("--default_segment", action="store_true", + help="If use default segmentation method: split at high-diff steps; otherwise, identify low-diff ranges.") + + # neural + parser.add_argument("--do_remerge", action="store_true", + help="If re-merge segments via neural similarity.") + parser.add_argument("--remerge_threshold", type=float, default=3.0, + help="Maximum number of steps for triggering re-merging via neural similarity.") + + # debug + parser.add_argument("--verbose", action="store_true", help="Print details.") + + args = parser.parse_args() + + args.trajectory_path = os.path.join(args.data_dir, args.trajectory_name) + args.output_dir = os.path.join(args.data_dir, args.output_dir) + 
os.makedirs(args.output_dir, exist_ok=True) + + main() diff --git a/induce/utils.py b/induce/utils.py new file mode 100644 index 0000000..4c59887 --- /dev/null +++ b/induce/utils.py @@ -0,0 +1,73 @@ +import time +import base64 + +def encode_image(img_path: str, return_url: bool = False) -> str: + """Encode the image to base64.""" + with open(img_path, "rb") as fh: + img = base64.b64encode(fh.read()).decode() + if return_url: + return f"data:image/jpeg;base64,{img}" + return img + + +def is_keyboard_action(action: str) -> bool: + """Check if the action is a keyboard action.""" + return "press" in action + +def is_click_action(action: str) -> bool: + return "click" in action + +def is_scroll_action(action: str) -> bool: + return "scroll" in action + +# %% Keyboard Input +def get_key_input(action: str) -> str: + """Parse the key input from the action.""" + if '(' in action and ')' in action: + kin = action.split('(')[1].split(')')[0] + else: + kin = action + kin = kin.replace("'", "").strip() + + if kin == "Key.space": + return " " + elif kin == "Key.shift": + return "" # upper/lower case already applied to characters + elif kin == "Key.backspace": + return kin + elif kin.startswith("Key."): # shift/ctrl/alt/cmd + return kin + '+' + else: + return kin + +def compose_key_input(input_list: list[str]) -> str: + """Compose the key input from the actions.""" + composed_input_list = [] + for il in input_list: + if il == "Key.backspace" and len(composed_input_list) > 0: + composed_input_list[-1] = composed_input_list[-1][:-1] + else: + composed_input_list.append(il) + return "".join(composed_input_list) + + +# %% LLM +import os +import openai +from openai import OpenAI + +def call_openai(prompt: str, content = None, model_name: str = "gpt-4o") -> str: + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + try: + response = client.chat.completions.create( + model=model_name, + messages=[ + {"role": "system", "content": prompt}, + {"role": "user", "content": content}, + ], + temperature=0.0, + ) + return response.choices[0].message.content + except Exception as e: + print(f"Error calling {model_name}: {type(e)}") + return "" diff --git a/record/MANIFEST.in b/record/MANIFEST.in new file mode 100644 index 0000000..83c6386 --- /dev/null +++ b/record/MANIFEST.in @@ -0,0 +1,5 @@ +graft gum/prompts/gum +include gum/app/macos/build-macos.sh +include gum/app/macos/apple_events.entitlements.plist +include gum/app/macos/VERSION +include gum/app/macos/Gum Recorder.spec diff --git a/record/README.md b/record/README.md new file mode 100644 index 0000000..f72039d --- /dev/null +++ b/record/README.md @@ -0,0 +1,149 @@ +# Human Activity Recording Tool + +## Installation + +Install from source for now. As of now, we've only tested MacOS: + +```bash +pip install -e . +``` + +For memory monitoring capabilities: +```bash +pip install -e .[monitoring] +``` + +Make sure to enable recording on your Mac: go to System Preferences -> Privacy & Security -> Accessibility, allow recording for the app that you use to edit the code, e.g., vscode. + +## Standalone App (currently only on MacOS) + +We provide a build script that creates a `.app` you can double‑click. + +1) Build (developer step): + +```bash +# From record/ directory +./gum/app/macos/build-macos.sh record +``` + +This produces `gum/app/macos/dist/Gum Recorder.app`. 
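+
+The build script also accepts optional positional arguments: a version string and/or the name of the conda environment to build with. If you omit the version, the script auto-increments the value stored in `gum/app/macos/VERSION`; the environment defaults to one named `record`. As a sketch (the version `0.2.0` and the environment name `my-build-env` below are placeholders, not values the repo ships with):
+
+```bash
+# From record/ directory: build a pinned version with a specific conda env
+./gum/app/macos/build-macos.sh 0.2.0 my-build-env
+```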
+
+2) First run permissions (user step):
+- Open `gum/app/macos/dist/Gum Recorder.app` (right‑click → Open the first time if Gatekeeper warns)
+- Grant permissions when prompted:
+  - Privacy & Security → Screen Recording → enable “Gum Recorder”
+  - Privacy & Security → Accessibility → enable “Gum Recorder”
+  - Privacy & Security → Input Monitoring → enable “Gum Recorder”
+
+The app saves data under `~/Downloads/records`.
+
+## Usage
+
+### CLI (Terminal)
+
+1. Grant your terminal application (Terminal, iTerm2, etc.) Accessibility and Input Monitoring permissions in **System Settings → Privacy & Security**.
+2. Activate the environment where you installed the package and run:
+
+```bash
+gum --user-name "your-name"
+```
+
+   The CLI defaults to the pynput keyboard backend so it works inside a terminal window without a Cocoa run loop. Data and screenshots go to `~/Downloads/records` unless you override `--data-directory` / `--screenshots-dir`.
+
+3. You can also invoke the entrypoint directly:
+
+```bash
+python -m gum.cli.main --debug
+```
+
+   Use `GUM_DISABLE_KEYBOARD=1` if you need to launch without keyboard logging (for example while debugging permissions).
+
+### macOS note: GUM_DISABLE_KEYBOARD
+
+On macOS, AppKit requires keyboard event monitors to be registered from the main thread. Our GUI app (Tk-based) sets up a main-thread shim that captures AppKit key events and forwards them into the background recorder thread. To avoid registering a second keyboard listener in the background thread (which can fail due to macOS security constraints), the GUI sets the environment variable `GUM_DISABLE_KEYBOARD=1` before starting the background recorder. The observer (`Screen`) checks this variable and skips starting its keyboard backend when the shim is active.
+
+- When running the GUI app: the main-thread shim is used; `GUM_DISABLE_KEYBOARD=1` is set automatically.
+- When running the CLI: no shim is present; leave the variable unset so the keyboard backend runs normally.
+
+Only set `GUM_DISABLE_KEYBOARD=1` if you explicitly want to disable the keyboard backend (e.g., to debug permissions) and rely on the GUI shim, or run without keyboard capture.
+
+### Scroll Filtering Options
+
+To reduce unnecessary scroll logging, you can configure scroll filtering parameters:
+
+```bash
+# More aggressive filtering (fewer scroll events logged)
+gum --scroll-debounce 1.0 \
+    --scroll-min-distance 10.0 \
+    --scroll-max-frequency 5 \
+    --scroll-session-timeout 3.0
+
+# Less filtering (more scroll events logged)
+gum --scroll-debounce 0.2 \
+    --scroll-min-distance 2.0 \
+    --scroll-max-frequency 20 \
+    --scroll-session-timeout 1.0
+```
+
+**Scroll filtering parameters:**
+- `--scroll-debounce`: Minimum time between scroll events (default: 0.5 seconds)
+- `--scroll-min-distance`: Minimum scroll distance to log (default: 5.0 pixels)
+- `--scroll-max-frequency`: Maximum scroll events per second (default: 10)
+- `--scroll-session-timeout`: Scroll session timeout (default: 2.0 seconds)
+
+## Troubleshooting
+
+### Process Killed After 30 Screenshots (Mac M3)
+
+If the process gets killed after approximately 30 screenshots on Mac M3, this is likely due to memory pressure.
The tool has been optimized to address this issue: + +**Recent fixes include:** +- Reduced capture frequency from 10 FPS to 5 FPS (3 FPS on high-DPI displays) +- Lower JPEG quality (70% instead of 90%) to reduce file sizes +- Explicit memory cleanup every 30 frames (20 frames on high-DPI displays) +- Proper disposal of old frame objects +- Custom thread pool to prevent thread pool exhaustion +- Better error handling for MSS operations +- Automatic detection of high-DPI displays with conservative settings +- **Scroll filtering**: Reduces unnecessary scroll event logging with configurable debouncing, distance thresholds, and frequency limits + +**Additional issues addressed:** +- **Thread pool exhaustion**: Limited thread pool size to 4 workers +- **MSS memory leaks**: Added proper resource cleanup and error handling +- **High-DPI display pressure**: Automatic detection and reduced capture frequency +- **Concurrent file I/O**: Better coordination of file operations +- **Apple Silicon optimization**: Specific handling for ARM64 architecture + +**To diagnose system issues:** +```bash +# Run diagnostic tool before starting gum +python diagnose_memory.py +``` + +**To monitor memory usage:** +```bash +# In a separate terminal +python memory_monitor.py +``` + +**Additional recommendations:** +1. Close unnecessary applications while recording +2. Ensure you have at least 4GB of free RAM +3. If issues persist, try running with debug mode: + ```bash + gum --debug + ``` +4. For high-DPI displays, the tool automatically uses more conservative settings +5. Consider running on a single monitor if using multiple high-resolution displays + +### Memory Monitoring + +To track memory usage during recording, install the monitoring dependencies and run the memory monitor in a separate terminal: + +```bash +# Terminal 1: Run the recording tool +gum + +# Terminal 2: Monitor memory usage +python memory_monitor.py +``` diff --git a/gum/__init__.py b/record/gum/__init__.py similarity index 86% rename from gum/__init__.py rename to record/gum/__init__.py index 18340b1..c7b24d1 100644 --- a/gum/__init__.py +++ b/record/gum/__init__.py @@ -4,7 +4,7 @@ A Python package for managing user feedback and interactions. 
""" -__version__ = "0.1.2" +__version__ = "0.1.0" from .gum import gum diff --git a/record/gum/__main__.py b/record/gum/__main__.py new file mode 100644 index 0000000..9c0d3ae --- /dev/null +++ b/record/gum/__main__.py @@ -0,0 +1,7 @@ +from dotenv import load_dotenv +load_dotenv() + +from .cli import main + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/record/gum/app/__init__.py b/record/gum/app/__init__.py new file mode 100644 index 0000000..cbc1b2c --- /dev/null +++ b/record/gum/app/__init__.py @@ -0,0 +1,8 @@ +"""Platform-specific application support for Gum.""" + +from .macos import AppleUIInspector, check_automation_permission_granted + +__all__ = [ + "AppleUIInspector", + "check_automation_permission_granted", +] diff --git a/record/gum/app/macos/Gum Recorder.spec b/record/gum/app/macos/Gum Recorder.spec new file mode 100644 index 0000000..352f3ed --- /dev/null +++ b/record/gum/app/macos/Gum Recorder.spec @@ -0,0 +1,77 @@ +# -*- mode: python ; coding: utf-8 -*- +import os +import sys +from pathlib import Path +from PyInstaller.utils.hooks import collect_data_files +from PyInstaller.utils.hooks import collect_submodules + +SPEC_PATH = Path(globals().get('__file__', sys.argv[0])).resolve() +APP_DIR = SPEC_PATH.parent +PROJECT_ROOT = APP_DIR.parent.parent + +APP_NAME = os.environ.get("PYINSTALLER_APP_NAME", "Gum Recorder") +BUNDLE_IDENTIFIER = os.environ.get("PYINSTALLER_BUNDLE_IDENTIFIER", "com.local.gumrecorder") + +datas = [] +hiddenimports = ['Quartz', 'AppKit', 'Foundation', 'dotenv', 'gum.observers.base.observer', 'gum.observers.base.screen', 'gum.observers.base.keyboard', 'gum.observers.base.mouse', 'gum.observers.base.screenshots', 'gum.observers.macos.keyboard', 'gum.observers.macos.mouse', 'gum.observers.macos.screenshots', 'gum.observers.macos.app_and_browser_logging', 'gum.observers.fallback.keyboard', 'gum.observers.fallback.mouse', 'gum.observers.fallback.screenshots'] +datas += collect_data_files('shapely') +hiddenimports += collect_submodules('sqlalchemy') +hiddenimports += collect_submodules('gum.observers') +hiddenimports += collect_submodules('gum.cli') +hiddenimports += collect_submodules('sqlalchemy_utils') +hiddenimports += collect_submodules('pydantic') +hiddenimports += collect_submodules('aiosqlite') +hiddenimports += collect_submodules('shapely') +hiddenimports += collect_submodules('pynput') +hiddenimports += collect_submodules('mss') +# Include tkinter since the macOS app uses Tk for the UI +hiddenimports += collect_submodules('tkinter') + + +a = Analysis( + [str(APP_DIR / 'app_entry.py')], + pathex=[str(PROJECT_ROOT)], + binaries=[], + datas=datas, + hiddenimports=hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, + optimize=0, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name=APP_NAME, + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=True, + upx_exclude=[], + name=APP_NAME, +) +app = BUNDLE( + coll, + name=f'{APP_NAME}.app', + icon=None, + bundle_identifier=BUNDLE_IDENTIFIER, +) diff --git a/record/gum/app/macos/VERSION b/record/gum/app/macos/VERSION new file mode 100644 index 0000000..b4ae2bd --- /dev/null +++ b/record/gum/app/macos/VERSION @@ -0,0 +1 @@ +0.0.60 diff --git 
a/record/gum/app/macos/__init__.py b/record/gum/app/macos/__init__.py new file mode 100644 index 0000000..95816dd --- /dev/null +++ b/record/gum/app/macos/__init__.py @@ -0,0 +1,8 @@ +"""macOS-specific application entrypoints and utilities for Gum.""" + +from ...observers.macos import AppleUIInspector, check_automation_permission_granted + +__all__ = [ + "AppleUIInspector", + "check_automation_permission_granted", +] diff --git a/record/gum/app/macos/app_entry.py b/record/gum/app/macos/app_entry.py new file mode 100644 index 0000000..bbc46ac --- /dev/null +++ b/record/gum/app/macos/app_entry.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import os +import sys +import plistlib +import threading +import logging + +# Ensure package is importable when frozen +if getattr(sys, "frozen", False): + sys.path.insert(0, os.path.dirname(sys.executable)) + +if __package__ in (None, ""): + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + if project_root not in sys.path: + sys.path.insert(0, project_root) + +try: + from ...cli import main + from ...cli import BackgroundRecorder +except ImportError: + from gum.cli import main + from gum.cli import BackgroundRecorder + +def _detect_app_version() -> str: + # Priority: env override, Info.plist in bundled app, fallback + env_v = os.environ.get("GUM_APP_VERSION") + if env_v: + return env_v + try: + if getattr(sys, "frozen", False): + plist_path = os.path.abspath(os.path.join(os.path.dirname(sys.executable), "..", "Info.plist")) + if os.path.exists(plist_path): + with open(plist_path, "rb") as f: + data = plistlib.load(f) + v = data.get("CFBundleShortVersionString") or data.get("CFBundleVersion") + if isinstance(v, str) and v.strip(): + return v.strip() + except Exception: # Catches any error during version detection from Info.plist + pass + return "0.0" + +APP_VERSION = _detect_app_version() + + +def run(): + try: + import tkinter as tk + from tkinter import messagebox + from tkinter import filedialog + except Exception: # Catches tkinter import failures (fallback to CLI mode) + os.environ.setdefault("PYTHONASYNCIODEBUG", "0") + try: + main() + except KeyboardInterrupt: # Catches Ctrl+C interruption during CLI fallback + pass + return + + try: + from . 
import app_entry_utils as ui # type: ignore + except ImportError: # Catches relative import failures (fallback to absolute import) + from gum.app.macos import app_entry_utils as ui # type: ignore + + # ------- Settings and defaults ------- + app_support_root = os.path.expanduser("~/Library/Application Support") + default_output_dir = os.path.join(app_support_root, "Gum Recorder") + + settings_path = os.path.join(default_output_dir, "settings.json") + + settings = ui.load_settings(settings_path, default_output_dir) + output_dir = settings.get("output_dir", default_output_dir) + screenshots_dir = os.path.join(output_dir, "screenshots") + kb_recorder = ui.MainThreadKeyboardRecorderShim() + output_dir_display_var = None + output_path_tooltip = None + + def open_output_folder() -> None: + os.makedirs(output_dir, exist_ok=True) + ui.open_uri(output_dir) + + def choose_output_folder() -> None: + nonlocal output_dir, screenshots_dir, output_dir_display_var, output_path_tooltip + if recording_active["value"]: + messagebox.showinfo("Recording active", "Stop recording before changing the output folder.") + return + selected = filedialog.askdirectory(title="Choose Output Folder", initialdir=output_dir) + if not selected: + return + new_dir = os.path.abspath(os.path.expanduser(selected)) + try: + os.makedirs(new_dir, exist_ok=True) + # Preflight write + test_path = os.path.join(new_dir, ".write_test") + with open(test_path, "w") as f: + f.write("ok") + os.remove(test_path) + except Exception as e: # Catches directory creation/write permission failures + messagebox.showerror("Cannot use folder", f"{new_dir}\nError: {e}") + return + output_dir = new_dir + screenshots_dir = os.path.join(output_dir, "screenshots") + settings["output_dir"] = output_dir + ui.save_settings(settings_path, settings) + status_var.set(f"Output folder set to: {output_dir}") + if output_dir_display_var is not None: + output_dir_display_var.set(ui.format_output_dir(output_dir)) + if output_path_tooltip is not None: + output_path_tooltip.text = os.path.abspath(output_dir) + + start_after_ui = {"value": False} + recording_active = {"value": False} + keyboard_enabled = {"value": False} + + # ------- Permission checks (onboarding) ------- + def show_onboarding_if_needed(root_window) -> None: + if settings.get("onboarding_done", False): + return + + import tkinter as tk + + top = tk.Toplevel(root_window) + top.title("Setup Permissions") + top.grab_set() + top.transient(root_window) + + def status_to_text(flag: bool | None) -> str: + if flag is True: + return "✅ Granted" + if flag is False: + return "❌ Not granted" + return "ℹ️ Unknown" + + sr_status = tk.StringVar(value=status_to_text(ui.check_screen_recording_granted())) + ax_status = tk.StringVar(value=status_to_text(ui.check_accessibility_granted())) + im_status = tk.StringVar(value=status_to_text(ui.check_input_monitoring_granted())) + au_status = tk.StringVar(value=status_to_text(ui.check_automation_granted())) + + row = 0 + tk.Label(top, text="Gum Recorder needs these permissions:", font=("Helvetica", 12, "bold")).grid(row=row, column=0, columnspan=3, sticky="w", padx=12, pady=(12, 8)) + row += 1 + + tk.Label(top, text="Screen Recording:").grid(row=row, column=0, sticky="w", padx=12, pady=6) + tk.Label(top, textvariable=sr_status).grid(row=row, column=1, sticky="w", padx=8) + tk.Button(top, text="Open", command=lambda: (ui.request_screen_recording_access(), sr_status.set(status_to_text(ui.check_screen_recording_granted())))).grid(row=row, column=2, padx=12) + row += 1 + + 
tk.Label(top, text="Accessibility:").grid(row=row, column=0, sticky="w", padx=12, pady=6) + tk.Label(top, textvariable=ax_status).grid(row=row, column=1, sticky="w", padx=8) + tk.Button(top, text="Open", command=lambda: (ui.prompt_accessibility_access(), ax_status.set(status_to_text(ui.check_accessibility_granted())))).grid(row=row, column=2, padx=12) + row += 1 + + tk.Label(top, text="Input Monitoring (Keyboard):").grid(row=row, column=0, sticky="w", padx=12, pady=6) + tk.Label(top, textvariable=im_status).grid(row=row, column=1, sticky="w", padx=8) + tk.Button(top, text="Open", command=lambda: (ui.prompt_input_monitoring_access(), im_status.set(status_to_text(ui.check_input_monitoring_granted())))).grid(row=row, column=2, padx=12) + row += 1 + + tk.Label(top, text="Browser URL Capture:").grid(row=row, column=0, sticky="w", padx=12, pady=6) + tk.Label(top, textvariable=au_status).grid(row=row, column=1, sticky="w", padx=8) + tk.Button( + top, + text="Enable Browser URLs", + command=lambda: (ui.prompt_automation_access(), au_status.set(status_to_text(ui.check_automation_granted()))), + ).grid(row=row, column=2, padx=12) + row += 1 + + tk.Label( + top, + text=( + "Click Enable Browser URLs and approve the prompts for Safari, Chrome, Brave, Edge, or Arc (if installed). " + "macOS will then open Privacy & Security → Automation so you can confirm Gum Recorder is allowed. " + "If you're running from Terminal, the prompt may mention 'osascript' instead of Gum Recorder." + ), + wraplength=360, + justify="left", + ).grid(row=row, column=0, columnspan=3, sticky="w", padx=12, pady=(0, 10)) + row += 1 + + def finish_onboarding(): + settings["onboarding_done"] = True + ui.save_settings(settings_path, settings) + top.destroy() + + tk.Button(top, text="Continue", command=finish_onboarding).grid(row=row, column=2, sticky="e", padx=12, pady=(8, 12)) + + def ensure_dirs() -> tuple[bool, str | None]: + try: + os.makedirs(screenshots_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + # Preflight writability + test_path = os.path.join(output_dir, ".write_test") + with open(test_path, "w") as f: + f.write("ok") + os.remove(test_path) + return True, None + except Exception as e: # Catches directory creation/write test failures + return False, str(e) + + def permissions_blocking_message() -> str | None: + sr = ui.check_screen_recording_granted() + ax = ui.check_accessibility_granted() + if sr is False or ax is False: + return "Screen Recording and Accessibility must be granted before starting. Open the permissions via the buttons, grant access, then relaunch the app." + return None + + def start_recording() -> None: + if recording_active["value"]: + return + + block_msg = permissions_blocking_message() + if block_msg: + messagebox.showwarning("Permissions required", block_msg) + return + + ok, err = ensure_dirs() + if not ok: + messagebox.showerror("Cannot write to output folder", f"Directory: {output_dir}\nError: {err}\n\nChoose a different folder or fix permissions.") + return + + # Allow the main-thread shim to handle keyboard events when available; + # fall back to the screen observer's internal backends otherwise. 
+ shim_running = False + try: + shim_running = kb_recorder.start() + except Exception: # Catches keyboard shim startup failures + logging.getLogger("GumUI").exception("Keyboard shim failed to start") + shim_running = False + + input_monitoring_granted = ui.check_input_monitoring_granted() is True + if shim_running and input_monitoring_granted: + os.environ["GUM_DISABLE_KEYBOARD"] = "1" + else: + os.environ.pop("GUM_DISABLE_KEYBOARD", None) + if shim_running and not input_monitoring_granted: + logging.getLogger("GumUI").info( + "Input Monitoring permission missing; keeping background keyboard listener active as fallback" + ) + + try: + BackgroundRecorder.start( + user_name="anonymous", + data_directory=output_dir, + screenshots_dir=screenshots_dir, + debug=False, + ) + recording_active["value"] = True + keyboard_enabled["value"] = shim_running + if shim_running: + status_var.set("Recording… (Click Stop to pause)") + else: + status_var.set("Recording… (Keyboard fallback active; click Stop to pause)") + start_btn.config(state=tk.DISABLED) + stop_btn.config(state=tk.NORMAL) + if not shim_running: + logging.getLogger("GumUI").info("Keyboard shim unavailable; relying on background listener") + except Exception as e: # Catches BackgroundRecorder startup failures + if shim_running: + try: + kb_recorder.stop() + except Exception: # Catches keyboard shim cleanup failures + logging.getLogger("GumUI").exception("Failed to stop keyboard shim after startup error") + os.environ.pop("GUM_DISABLE_KEYBOARD", None) + messagebox.showerror("Failed to start", str(e)) + + def stop_recording() -> None: + if not recording_active["value"]: + return + try: + BackgroundRecorder.stop() + kb_recorder.stop() + os.environ.pop("GUM_DISABLE_KEYBOARD", None) + recording_active["value"] = False + keyboard_enabled["value"] = False + status_var.set("Stopped. Press Start to record again.") + start_btn.config(state=tk.NORMAL) + stop_btn.config(state=tk.DISABLED) + # No-op; no background keyboard recorder + except Exception as e: # Catches BackgroundRecorder stop failures + messagebox.showerror("Failed to stop", str(e)) + + def quit_app() -> None: + os._exit(0) + + root = tk.Tk() + root.title(f"Gum Recorder v{APP_VERSION}") + root.geometry("720x520") + root.minsize(560, 420) + root.resizable(True, True) + + frame = tk.Frame(root, padx=16, pady=16) + frame.pack(fill=tk.BOTH, expand=True) + + title = tk.Label(frame, text=f"Gum Recorder v{APP_VERSION}", font=("Helvetica", 16, "bold")) + title.pack(anchor="w") + + status_var = tk.StringVar(value="Ready. 
Grant permissions below, then Start Recording.") + status = tk.Label(frame, textvariable=status_var) + status.pack(anchor="w", pady=(8, 12)) + + controls = tk.Frame(frame) + controls.pack(anchor="w") + start_btn = tk.Button(controls, text="Start Recording", width=18, command=start_recording) + start_btn.grid(row=0, column=0, padx=(0, 8)) + stop_btn = tk.Button(controls, text="Stop Recording", width=18, command=stop_recording, state=tk.DISABLED) + stop_btn.grid(row=0, column=1) + + # Permissions block + btns = tk.Frame(frame) + btns.pack(anchor="w", pady=(12, 8)) + + # Helper for status text + def _status_to_text(flag: bool | None) -> str: + if flag is True: + return "✅ Granted" + if flag is False: + return "❌ Not granted" + return "ℹ️ Unknown" + + sr_status_var = tk.StringVar(value=_status_to_text(ui.check_screen_recording_granted())) + ax_status_var = tk.StringVar(value=_status_to_text(ui.check_accessibility_granted())) + im_status_var = tk.StringVar(value=_status_to_text(ui.check_input_monitoring_granted())) + au_status_var = tk.StringVar(value=_status_to_text(ui.check_automation_granted())) + + tk.Label(btns, text="Screen Recording:").grid(row=0, column=0, sticky="w", padx=(0, 8), pady=(0, 8)) + tk.Label(btns, textvariable=sr_status_var).grid(row=0, column=1, sticky="w", padx=(0, 12)) + sr_open_btn = tk.Button(btns, text="Open", command=ui.open_screen_recording_settings) + sr_open_btn.grid(row=0, column=2, padx=(0, 8), pady=(0, 8)) + sr_help_btn = tk.Button(btns, text="How to grant", command=ui.open_screen_recording_help) + sr_help_btn.grid(row=0, column=3, padx=(0, 8), pady=(0, 8)) + + tk.Label(btns, text="Accessibility:").grid(row=1, column=0, sticky="w", padx=(0, 8), pady=(0, 8)) + tk.Label(btns, textvariable=ax_status_var).grid(row=1, column=1, sticky="w", padx=(0, 12)) + ax_open_btn = tk.Button(btns, text="Open", command=ui.open_accessibility_settings) + ax_open_btn.grid(row=1, column=2, padx=(0, 8), pady=(0, 8)) + ax_help_btn = tk.Button(btns, text="How to grant", command=ui.open_accessibility_help) + ax_help_btn.grid(row=1, column=3, padx=(0, 8), pady=(0, 8)) + + tk.Label(btns, text="Input Monitoring:").grid(row=2, column=0, sticky="w", padx=(0, 8), pady=(0, 8)) + tk.Label(btns, textvariable=im_status_var).grid(row=2, column=1, sticky="w", padx=(0, 12)) + im_open_btn = tk.Button(btns, text="Open", command=ui.prompt_input_monitoring_access) + im_open_btn.grid(row=2, column=2, padx=(0, 8), pady=(0, 8)) + im_help_btn = tk.Button(btns, text="How to grant", command=ui.open_input_monitoring_help) + im_help_btn.grid(row=2, column=3, padx=(0, 8), pady=(0, 8)) + + tk.Label(btns, text="Browser URL Capture:").grid(row=3, column=0, sticky="w", padx=(0, 8), pady=(0, 8)) + tk.Label(btns, textvariable=au_status_var).grid(row=3, column=1, sticky="w", padx=(0, 12)) + au_open_btn = tk.Button(btns, text="Enable Browser URLs", command=ui.prompt_automation_access) + au_open_btn.grid(row=3, column=2, padx=(0, 8), pady=(0, 8)) + au_help_btn = tk.Button(btns, text="How to grant", command=ui.open_automation_help) + au_help_btn.grid(row=3, column=3, padx=(0, 8), pady=(0, 8)) + + def refresh_permission_status(): + sr_status_var.set(_status_to_text(ui.check_screen_recording_granted())) + ax_status_var.set(_status_to_text(ui.check_accessibility_granted())) + im_status_var.set(_status_to_text(ui.check_input_monitoring_granted())) + au_status_var.set(_status_to_text(ui.check_automation_granted())) + + tk.Button(btns, text="Refresh", command=refresh_permission_status).grid(row=0, column=4, rowspan=4, padx=(12, 
0)) + + # Attach tooltips to all permission helper buttons with no delay for quicker guidance + ui.Tooltip(sr_help_btn, ui.PERMISSION_TOOLTIPS["screen"], delay=0) + ui.Tooltip(ax_help_btn, ui.PERMISSION_TOOLTIPS["accessibility"], delay=0) + ui.Tooltip(im_help_btn, ui.PERMISSION_TOOLTIPS["input"], delay=0) + ui.Tooltip(au_help_btn, ui.PERMISSION_TOOLTIPS["automation"], delay=0) + + folder_row = tk.Frame(frame) + folder_row.pack(fill=tk.X, pady=(12, 8)) + folder_row.columnconfigure(1, weight=1) + + tk.Label(folder_row, text="Current Output Path:").grid(row=0, column=0, sticky="w", padx=(0, 8), pady=(0, 8)) + + output_dir_display_var = tk.StringVar(value=ui.format_output_dir(output_dir)) + output_dir_dropdown = tk.Menubutton( + folder_row, + textvariable=output_dir_display_var, + relief=tk.RAISED, + indicatoron=True, + anchor="w", + ) + output_dir_dropdown.grid(row=0, column=1, sticky="we", pady=(0, 8)) + + output_dropdown_menu = tk.Menu(output_dir_dropdown, tearoff=0) + output_dropdown_menu.add_command(label="Open Output Folder", command=open_output_folder) + output_dropdown_menu.add_separator() + output_dropdown_menu.add_command(label="Change Output Folder…", command=choose_output_folder) + output_dir_dropdown.configure(menu=output_dropdown_menu) + + output_path_tooltip = ui.Tooltip(output_dir_dropdown, os.path.abspath(output_dir), delay=0) + + # Main-thread pump for any queued tasks from keyboard recorder + # Main-thread pump for keyboard events + def _pump_keyboard() -> None: + kb_recorder.pump_main_thread_tasks() + root.after(50, _pump_keyboard) + _pump_keyboard() + + # Periodically refresh permission statuses + def periodic_refresh(): + try: + refresh_permission_status() + except Exception: # Catches permission status refresh failures + pass + finally: + root.after(3000, periodic_refresh) + periodic_refresh() + + def schedule_screen_preflight() -> None: + if ui.check_screen_recording_granted() is True: + return + + def _prime() -> None: + try: + ui.prime_screen_recording_permission() + except Exception: # Catches screen recording permission preflight failures + pass + + threading.Thread(name="ScreenRecordingPreflight", target=_prime, daemon=True).start() + + def schedule_accessibility_preflight() -> None: + if ui.check_accessibility_granted() is True: + return + + def _prime() -> None: + try: + ui.prime_accessibility_permission() + except Exception: # Catches accessibility permission preflight failures + pass + + threading.Thread(name="AccessibilityPreflight", target=_prime, daemon=True).start() + + def schedule_input_preflight() -> None: + if ui.check_input_monitoring_granted() is True: + return + + def _prime() -> None: + try: + ui.prime_input_monitoring_permission() + except Exception: # Catches input monitoring permission preflight failures + pass + + threading.Thread(name="InputMonitoringPreflight", target=_prime, daemon=True).start() + + def schedule_automation_preflight() -> None: + if ui.check_automation_granted() is True: + return + + def _prime() -> None: + try: + ui.prime_automation_permissions() + except Exception: # Catches automation permission preflight failures + pass + + threading.Thread(name="AutomationPreflight", target=_prime, daemon=True).start() + + root.after(200, schedule_screen_preflight) + root.after(300, schedule_accessibility_preflight) + root.after(400, schedule_input_preflight) + root.after(500, schedule_automation_preflight) + + # Show onboarding if first run + show_onboarding_if_needed(root) + + tk.Button(frame, text="Quit", 
command=quit_app).pack(anchor="e") + + try: + root.mainloop() + finally: + pass + + +if __name__ == "__main__": + run() diff --git a/record/gum/app/macos/app_entry_utils.py b/record/gum/app/macos/app_entry_utils.py new file mode 100644 index 0000000..0be83fc --- /dev/null +++ b/record/gum/app/macos/app_entry_utils.py @@ -0,0 +1,713 @@ +from __future__ import annotations + +import json +import logging +import os +import plistlib +import subprocess +import sys +import tkinter as tk +from pathlib import Path +from typing import Any + +import threading +from collections import deque + +from ...observers.macos import AppleUIInspector, check_automation_permission_granted +from ...observers.macos.keyboard import event_token_from_nsevent, register_appkit_key_monitors, remove_appkit_monitors +from ...cli.background import BackgroundRecorder +from ...observers.constants import KEYBOARD_PUMP_INTERVAL_MS, PERMISSION_REFRESH_INTERVAL_MS + +try: # AppKit is only available on macOS + from AppKit import NSEvent, NSEventMaskKeyDown, NSEventMaskKeyUp +except Exception: # pragma: no cover - AppKit unavailable outside macOS + NSEvent = None + NSEventMaskKeyDown = 0 + NSEventMaskKeyUp = 0 + + +DEFAULT_SETTINGS = { + "output_dir": os.path.join( + os.path.expanduser("~/Library/Application Support"), + "Gum Recorder", + ), + "onboarding_done": False, +} + + +class MainThreadKeyboardRecorderShim: + """We need a small shim running in the main thread to forward keyboard events to the recorder + because on MacOS there is a security sandbox that prevents background threads from + listening to keyboard events. + """ + + def __init__(self) -> None: + self._monitors: list[Any] = [] + self._queue: deque[tuple[str, str]] = deque() + self._lock = threading.Lock() + self._running = False + + @staticmethod + def _event_token(ev) -> str: + # Delegate to shared backend helper to avoid duplication + try: + return event_token_from_nsevent(ev) + except Exception: + return "KEY:unknown" + + def start(self) -> bool: + if self._running: + return True + if NSEvent is None: + logging.getLogger("GumUI").warning("AppKit unavailable; falling back to background keyboard monitor") + return False + + def _enqueue(token: str, kind: str) -> None: + if not token: + return + with self._lock: + self._queue.append((token, kind)) + + def on_down(ev): + _enqueue(self._event_token(ev), "press") + + def on_up(ev): + _enqueue(self._event_token(ev), "release") + + try: + self._monitors = register_appkit_key_monitors(on_down, on_up) + self._running = True + except Exception: + self._monitors = [] + self._running = False + logging.getLogger("GumUI").warning( + "Failed to start main-thread keyboard monitor; falling back to background listener", + exc_info=True, + ) + + return self._running + + def stop(self) -> None: + if not self._running: + return + try: + remove_appkit_monitors(self._monitors) + finally: + self._monitors = [] + self._running = False + with self._lock: + self._queue.clear() + + def pump_main_thread_tasks(self) -> None: + if not self._running: + return + events: list[tuple[str, str]] + with self._lock: + if not self._queue: + return + events = list(self._queue) + self._queue.clear() + for token, kind in events: + BackgroundRecorder.post_key_event(token, "press" if kind == "press" else "release") + + +class Tooltip: + """Simple tooltip helper for Tk widgets.""" + + def __init__(self, widget: tk.Widget, text: str, delay: int = 400) -> None: + self.widget = widget + self.text = text + self.delay = delay + self._after_id: int | None = None + 
self._tip: tk.Toplevel | None = None
+        widget.bind("<Enter>", self._enter)
+        widget.bind("<Leave>", self._leave)
+        widget.bind("<ButtonPress>", self._leave)
+
+    def _enter(self, _event=None) -> None:
+        self._schedule()
+
+    def _leave(self, _event=None) -> None:
+        self._unschedule()
+        self._hide()
+
+    def _schedule(self) -> None:
+        self._unschedule()
+        self._after_id = self.widget.after(self.delay, self._show)
+
+    def _unschedule(self) -> None:
+        if self._after_id:
+            try:
+                self.widget.after_cancel(self._after_id)
+            except Exception:  # pragma: no cover - best effort cleanup
+                pass
+            self._after_id = None
+
+    def _show(self) -> None:
+        if self._tip or not self.text:
+            return
+        try:
+            x, y, _, _ = self.widget.bbox("insert") if hasattr(self.widget, "bbox") else (0, 0, 0, 0)
+        except Exception:
+            x, y = 0, 0
+        x += self.widget.winfo_rootx() + 20
+        y += self.widget.winfo_rooty() + 20
+        tip = tk.Toplevel(self.widget)
+        tip.wm_overrideredirect(True)
+        tip.wm_geometry(f"+{x}+{y}")
+        label = tk.Label(
+            tip,
+            text=self.text,
+            justify="left",
+            background="#ffffe0",
+            relief="solid",
+            borderwidth=1,
+            wraplength=320,
+        )
+        label.pack(ipadx=6, ipady=3)
+        self._tip = tip
+
+    def _hide(self) -> None:
+        if self._tip:
+            try:
+                self._tip.destroy()
+            except Exception:  # pragma: no cover - best effort cleanup
+                pass
+            self._tip = None
+
+
+PERMISSION_TOOLTIPS: dict[str, str] = {
+    "screen": (
+        "macOS requires the Screen Recording permission so Gum Recorder can capture "
+        "the pixels on your display. Click 'Open' to open the System dialog and click the '+' button. "
+        "This will give you a file dropdown: find and select Gum Recorder to add it if it is not already listed. "
+        "If it is already listed and you still don't see the permission granted, click 'Refresh' to update the list or file a ticket."
+    ),
+    "accessibility": (
+        "macOS requires the Accessibility permission so Gum Recorder can observe UI events, including keystrokes. "
+        "Click 'Open' to open the System dialog and click the '+' button. "
+        "This will give you a file dropdown: find and select Gum Recorder to add it if it is not already listed. "
+        "If it is already listed and you still don't see the permission granted, click 'Refresh' to update the list or file a ticket."
+    ),
+    "input": (
+        "The Input Monitoring permission is required for listening to keyboard events while recording. "
+        "Click 'Open' to open the System dialog and click the '+' button. "
+        "This will give you a file dropdown: find and select Gum Recorder to add it if it is not already listed. "
+        "If it is already listed and you still don't see the permission granted, click 'Refresh' to update the list or file a ticket."
+    ),
+    "automation": (
+        "Browser URL capture relies on macOS Automation (Apple Events). After you click Enable Browser URLs, approve the prompts "
+        "and make sure Gum Recorder — or 'osascript' if you launched from Terminal — stays enabled under Privacy & Security → Automation."
+ ), +} + + +BROWSER_AUTOMATION_SCRIPTS: list[tuple[str, str, str]] = [ + ("Safari", "Safari.app", 'tell application "Safari" to return name'), + ("Google Chrome", "Google Chrome.app", 'tell application "Google Chrome" to return name'), + ("Chromium", "Chromium.app", 'tell application "Chromium" to return name'), + ("Brave Browser", "Brave Browser.app", 'tell application "Brave Browser" to return name'), + ("Microsoft Edge", "Microsoft Edge.app", 'tell application "Microsoft Edge" to return name'), + ("Arc", "Arc.app", 'tell application "Arc" to return name'), +] + + +AUTOMATION_APP_LOCATIONS = ["/Applications", os.path.expanduser("~/Applications")] + + +def load_settings(settings_path: str, default_output_dir: str) -> dict[str, Any]: + try: + with open(settings_path, "r") as fh: + data = json.load(fh) + if "output_dir" not in data: + data["output_dir"] = default_output_dir + if "onboarding_done" not in data: + data["onboarding_done"] = False + return data + except Exception: + return { + "output_dir": default_output_dir, + "onboarding_done": False, + } + + +def save_settings(settings_path: str, data: dict[str, Any]) -> None: + try: + os.makedirs(os.path.dirname(settings_path), exist_ok=True) + with open(settings_path, "w") as fh: + json.dump(data, fh, indent=2) + except Exception: # pragma: no cover - best effort persistence + pass + + +def format_output_dir(path: str, max_len: int = 64) -> str: + normalized = os.path.abspath(os.path.expanduser(path)) + if len(normalized) <= max_len: + return normalized + head = max_len // 2 - 1 + tail = max_len - head - 1 + return normalized[:head] + "…" + normalized[-tail:] + + +# ─────────────────────────────── macOS helpers +def open_uri(uri: str) -> None: + try: + subprocess.Popen(["open", uri]) + except Exception: # pragma: no cover - best effort + pass + + +def open_screen_recording_settings() -> None: + open_uri("x-apple.systempreferences:com.apple.preference.security?Privacy_ScreenCapture") + + +def open_accessibility_settings() -> None: + open_uri("x-apple.systempreferences:com.apple.preference.security?Privacy_Accessibility") + + +def open_keyboard_monitoring_settings() -> None: + open_uri("x-apple.systempreferences:com.apple.preference.security?Privacy_ListenEvent") + + +def open_automation_settings() -> None: + open_uri("x-apple.systempreferences:com.apple.preference.security?Privacy_Automation") + + +def open_screen_recording_help() -> None: + open_uri("https://support.apple.com/guide/mac-help/allow-apps-to-use-screen-and-audio-recording-mchl592e5686/26/mac/26") + + +def open_accessibility_help() -> None: + open_uri("https://support.apple.com/guide/mac-help/allow-accessibility-apps-to-access-your-mac-mh43185/26/mac/26") + + +def open_input_monitoring_help() -> None: + open_uri("https://support.apple.com/guide/mac-help/control-access-to-input-monitoring-on-mac-mchl4cedafb6/26/mac/26") + + +def open_automation_help() -> None: + open_uri("https://support.apple.com/guide/mac-help/allow-apps-to-control-your-mac-mchl30d1931e/mac") + + +def check_screen_recording_granted() -> bool | None: + try: + import Quartz + + if hasattr(Quartz, "CGPreflightScreenCaptureAccess"): + return bool(Quartz.CGPreflightScreenCaptureAccess()) + except Exception: + return None + return None + + +def request_screen_recording_access() -> None: + _request_screen_recording_access(open_settings=True) + + +def _request_screen_recording_access(open_settings: bool) -> bool: + triggered = False + try: + import Quartz + + request = getattr(Quartz, "CGRequestScreenCaptureAccess", 
None) + if callable(request): + try: + request() + except TypeError: + request(None) + triggered = True + except Exception: # Catches screen recording permission failures + pass + + if open_settings: + open_screen_recording_settings() + + return triggered + + +def prime_screen_recording_permission(logger: logging.Logger | None = None) -> bool: + logger = logger or logging.getLogger("gum.ui.screen_preflight") + status = check_screen_recording_granted() + if status is True: + logger.debug("Screen Recording permission already granted; skipping preflight") + return False + + triggered = _request_screen_recording_access(open_settings=False) + if triggered: + logger.debug("Triggered Screen Recording permission prompt via preflight") + else: + logger.debug("Unable to trigger Screen Recording preflight (API unavailable or call failed)") + return triggered + + +def check_accessibility_granted() -> bool | None: + try: + import Quartz + + if hasattr(Quartz, "AXIsProcessTrusted"): + return bool(Quartz.AXIsProcessTrusted()) + if hasattr(Quartz, "AXIsProcessTrustedWithOptions"): + return bool(Quartz.AXIsProcessTrustedWithOptions(None)) + except Exception: + return None + return None + + +def prompt_accessibility_access() -> None: + _prompt_accessibility_access(open_settings=True) + + +def _prompt_accessibility_access(open_settings: bool) -> bool: + triggered = False + try: + import Quartz + + if hasattr(Quartz, "AXIsProcessTrustedWithOptions"): + from Foundation import NSDictionary + + options = {"kAXTrustedCheckOptionPrompt": True} + Quartz.AXIsProcessTrustedWithOptions( + NSDictionary.dictionaryWithDictionary_(options) + ) + triggered = True + except Exception: # Catches accessibility permission failures + pass + + if open_settings: + open_accessibility_settings() + + return triggered + + +def prime_accessibility_permission(logger: logging.Logger | None = None) -> bool: + logger = logger or logging.getLogger("gum.ui.accessibility_preflight") + status = check_accessibility_granted() + if status is True: + logger.debug("Accessibility permission already granted; skipping preflight") + return False + + triggered = _prompt_accessibility_access(open_settings=False) + if triggered: + logger.debug("Triggered Accessibility permission prompt via preflight") + else: + logger.debug("Unable to trigger Accessibility preflight (API unavailable or call failed)") + return triggered + + +def check_input_monitoring_granted() -> bool | None: + try: + import Quartz + + if hasattr(Quartz, "CGPreflightListenEventAccess"): + return bool(Quartz.CGPreflightListenEventAccess()) + + mask = (1 << Quartz.kCGEventKeyDown) | (1 << Quartz.kCGEventKeyUp) + tap = Quartz.CGEventTapCreate( + Quartz.kCGSessionEventTap, + Quartz.kCGHeadInsertEventTap, + Quartz.kCGEventTapOptionListenOnly, + mask, + None, + None, + ) + if tap is None: + return False + try: + Quartz.CFRelease(tap) + except Exception: # Catches input monitoring permission preflight failures + pass + return True + except Exception: + return None + + +def prompt_input_monitoring_access() -> None: + _request_input_monitoring_access(open_settings=True) + + +def _request_input_monitoring_access(open_settings: bool) -> bool: + triggered = False + try: + import Quartz + + request = getattr(Quartz, "CGRequestListenEventAccess", None) + if callable(request): + try: + request() + except TypeError: + request(None) + triggered = True + except Exception: # Catches input monitoring permission preflight failures + pass + + if open_settings: + open_keyboard_monitoring_settings() + + return 
triggered + + +def prime_input_monitoring_permission(logger: logging.Logger | None = None) -> bool: + logger = logger or logging.getLogger("gum.ui.input_preflight") + status = check_input_monitoring_granted() + if status is True: + logger.debug("Input Monitoring permission already granted; skipping preflight") + return False + + triggered = _request_input_monitoring_access(open_settings=False) + if triggered: + logger.debug("Triggered Input Monitoring permission prompt via preflight") + else: + logger.debug("Unable to trigger Input Monitoring preflight (API unavailable or call failed)") + return triggered + + +def check_automation_granted() -> bool | None: + return check_automation_permission_granted() + + +def prompt_automation_access() -> None: + inspector = None + running_snapshot: list[tuple[str, str, str]] = [] + status_before = check_automation_permission_granted(force_refresh=True) + + if AppleUIInspector is not None: + try: + inspector = AppleUIInspector(logging.getLogger("gum.ui.automation_prompt")) + running_snapshot = list(inspector.running_browser_applications()) + except Exception: # Catches automation permission preflight failures + inspector = None + running_snapshot = [] + + used_new_path = False + if inspector is not None: + try: + used_new_path = inspector.prime_automation_for_running_browsers() + if used_new_path: + status_check = check_automation_permission_granted(force_refresh=True) + if status_check is not True: + used_new_path = False + except Exception: # Catches automation permission preflight failures + used_new_path = False + + if not used_new_path: + _legacy_trigger_automation_scripts() + + status_after = check_automation_permission_granted(force_refresh=True) + notes = _collect_automation_guidance(status_before, status_after, inspector, running_snapshot) + if status_after is not True and notes: + _show_automation_guidance(notes) + + open_automation_settings() + + +def _legacy_trigger_automation_scripts() -> None: + def _app_installed(bundle_name: str) -> bool: + for base in AUTOMATION_APP_LOCATIONS: + if os.path.exists(os.path.join(base, bundle_name)): + return True + return False + + any_triggered = False + for _label, bundle, script in BROWSER_AUTOMATION_SCRIPTS: + if not _app_installed(bundle): + continue + try: + subprocess.run( + ["osascript", "-e", script], + capture_output=True, + text=True, + timeout=1.5, + ) + any_triggered = True + except Exception: # Catches automation permission preflight failures + continue + + if not any_triggered: + try: + subprocess.run( + ["osascript", "-e", 'tell application "System Events" to return 1'], + capture_output=True, + text=True, + timeout=1.5, + ) + except Exception: # Catches automation permission preflight failures + pass + + +def prime_automation_permissions(logger: logging.Logger | None = None) -> None: + if AppleUIInspector is None: + _legacy_trigger_automation_scripts() + return + try: + inspector = AppleUIInspector(logger or logging.getLogger("gum.ui.automation_preflight")) + attempted = inspector.prime_automation_for_running_browsers() + if not attempted and check_automation_permission_granted(force_refresh=True) is not True: + _legacy_trigger_automation_scripts() + except Exception: # Catches automation permission preflight failures + _legacy_trigger_automation_scripts() + + +def _collect_automation_guidance(status_before: bool | None, status_after: bool | None, inspector: AppleUIInspector | None, running_snapshot: list[tuple[str, str, str]]) -> list[str]: + notes: list[str] = [] + + sig_info = 
_analyze_code_signature() + binary_path = sig_info.get("binary") + bundle_path = sig_info.get("bundle") + signed = sig_info.get("signed") + has_entitlement = sig_info.get("has_automation_entitlement") + codesign_error = sig_info.get("codesign_error") + + if getattr(sys, "frozen", False): + target_path = bundle_path or binary_path + if signed is False: + if target_path: + notes.append( + f"Code-sign the app at {target_path} (ad-hoc signing works) so macOS can request Automation. Example: `codesign --force --deep --options runtime --sign - '{target_path}'`." + ) + else: + notes.append("Code-sign the running app with the Hardened Runtime and Apple Events entitlement so macOS can show the Automation prompt.") + elif signed and not has_entitlement: + notes.append( + "The current build is signed without the com.apple.security.automation.apple-events entitlement. Add that entitlement and rebuild/sign before retrying." + ) + elif codesign_error: + notes.append(codesign_error) + else: + notes.append( + "You're running the recorder from source. macOS will attribute the Automation request to `osascript`. Approve 'osascript' under Privacy & Security → Automation or build a signed app when you're ready." + ) + + quarantine_path = bundle_path or (binary_path if getattr(sys, "frozen", False) else None) + if quarantine_path and _has_quarantine_attribute(quarantine_path): + notes.append(f"Remove the quarantine attribute with `xattr -dr com.apple.quarantine '{quarantine_path}'` and relaunch.") + + running = list(running_snapshot) + if not running and inspector is not None: + try: + running = list(inspector.running_browser_applications()) + except Exception: # Catches automation permission preflight failures + running = [] + if not running: + notes.append("Open Safari, Chrome, Brave, Edge, or Arc and leave a window active before enabling browser URLs.") + + if status_after is False or status_before is False: + bundle_id = _detect_bundle_identifier(bundle_path) + if not getattr(sys, "frozen", False): + target = "com.apple.osascript" + else: + target = bundle_id or "Gum Recorder" + notes.append(f"macOS has a previous Automation denial recorded. Run `tccutil reset AppleEvents {target}` and relaunch the app.") + + return _dedupe_preserve_order(notes) + + +def _show_automation_guidance(notes: list[str]) -> None: + if not notes: + return + try: + from tkinter import messagebox + except Exception: + logging.getLogger("gum.ui.automation_prompt").info("Automation guidance: %s", " | ".join(notes)) + return + + body = "\n\n".join(notes) + message = "Automation permission is still unavailable. Try the steps below:\n\n" + body + messagebox.showinfo("Enable Browser URLs", message) + + +def _analyze_code_signature() -> dict[str, Any]: + binary_path, bundle_path = _current_binary_paths() + info: dict[str, Any] = { + "binary": binary_path, + "bundle": bundle_path, + "signed": None, + "has_automation_entitlement": None, + "codesign_error": None, + } + + if not getattr(sys, "frozen", False) or binary_path is None: + return info + + try: + proc = subprocess.run( + ["codesign", "--display", "--entitlements", "-", str(binary_path)], + capture_output=True, + text=True, + timeout=3, + ) + except FileNotFoundError: + info["codesign_error"] = "The `codesign` tool is not available in PATH; unable to verify entitlements." 
+ return info + except Exception as exc: + info["codesign_error"] = f"Failed to inspect code signature: {exc}" + return info + + if proc.returncode != 0: + info["signed"] = False + info["codesign_error"] = (proc.stderr or proc.stdout or "Code signature check failed").strip() + return info + + info["signed"] = True + ent_blob = (proc.stdout or proc.stderr or "").lower() + info["has_automation_entitlement"] = "com.apple.security.automation.apple-events" in ent_blob + return info + + +def _current_binary_paths() -> tuple[Path | None, Path | None]: + try: + if getattr(sys, "frozen", False): + binary = Path(sys.executable).resolve() + else: + binary = Path(__file__).resolve() + except Exception: + binary = None + + bundle = None + if binary is not None: + for parent in binary.parents: + if parent.suffix == ".app": + bundle = parent + break + return binary, bundle + + +def _has_quarantine_attribute(path: Path) -> bool: + try: + proc = subprocess.run( + ["xattr", "-p", "com.apple.quarantine", str(path)], + capture_output=True, + text=True, + timeout=1.0, + ) + except FileNotFoundError: + return False + except Exception: + return False + return proc.returncode == 0 + + +def _detect_bundle_identifier(bundle_path: Path | None) -> str | None: + if bundle_path: + plist_path = bundle_path / "Contents" / "Info.plist" + try: + with open(plist_path, "rb") as fh: + data = plistlib.load(fh) + bundle_id = data.get("CFBundleIdentifier") + if isinstance(bundle_id, str) and bundle_id.strip(): + return bundle_id.strip() + except Exception: + pass + env_bundle = os.environ.get("GUM_BUNDLE_ID") + if env_bundle: + return env_bundle + return None + + +def _dedupe_preserve_order(items: list[str]) -> list[str]: + seen: set[str] = set() + result: list[str] = [] + for item in items: + if item and item not in seen: + seen.add(item) + result.append(item) + return result diff --git a/record/gum/app/macos/apple_events.entitlements.plist b/record/gum/app/macos/apple_events.entitlements.plist new file mode 100644 index 0000000..5062340 --- /dev/null +++ b/record/gum/app/macos/apple_events.entitlements.plist @@ -0,0 +1,8 @@ + + + + + com.apple.security.automation.apple-events + + + diff --git a/record/gum/app/macos/build-macos.sh b/record/gum/app/macos/build-macos.sh new file mode 100755 index 0000000..2026f04 --- /dev/null +++ b/record/gum/app/macos/build-macos.sh @@ -0,0 +1,200 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Resolve project root (directory of this script) +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +cd "$SCRIPT_DIR" +# Find the nearest parent directory containing setup.py or pyproject.toml +PROJECT_ROOT="$SCRIPT_DIR" +while true; do + if [ -f "$PROJECT_ROOT/pyproject.toml" ] || [ -f "$PROJECT_ROOT/setup.py" ]; then + break + fi + PARENT="$(dirname "$PROJECT_ROOT")" + if [ "$PARENT" = "$PROJECT_ROOT" ]; then + echo "Could not locate a Python project root (setup.py or pyproject.toml)." >&2 + exit 1 + fi + PROJECT_ROOT="$PARENT" +done + +# Args: +# $1: version (optional) OR env name if not a version +# $2: env name (optional) + +# Load previous version if present +VERSION_FILE="${SCRIPT_DIR}/VERSION" +PREV_VERSION="" +if [ -f "$VERSION_FILE" ]; then + PREV_VERSION="$(cat "$VERSION_FILE" | tr -d '\n' | tr -d '\r')" +fi + +is_version() { + [[ "$1" =~ ^[0-9]+(\.[0-9]+)*$ ]] +} + +increment_version() { + local v="$1" + IFS='.' 
read -r -a parts <<< "$v" + local n=${#parts[@]} + if [ $n -eq 0 ]; then + echo "0.1" + return + fi + local last_index=$((n-1)) + local last=${parts[$last_index]} + if [[ "$last" =~ ^[0-9]+$ ]]; then + parts[$last_index]=$((last+1)) + else + # Fallback: append .1 + parts+=("1") + fi + (IFS='.'; echo "${parts[*]}") +} + +RAW1="${1:-}" +RAW2="${2:-}" + +VERSION="" +CONDA_ENV_NAME="record" + +if [ -n "$RAW1" ]; then + if is_version "$RAW1"; then + VERSION="$RAW1" + if [ -n "$RAW2" ]; then + CONDA_ENV_NAME="$RAW2" + fi + else + CONDA_ENV_NAME="$RAW1" + if [ -n "$RAW2" ] && is_version "$RAW2"; then + VERSION="$RAW2" + fi + fi +fi + +if [ -z "$VERSION" ]; then + if [ -n "$PREV_VERSION" ] && is_version "$PREV_VERSION"; then + VERSION="$(increment_version "$PREV_VERSION")" + else + VERSION="0.1" + fi +fi + +echo "$VERSION" > "$VERSION_FILE" + +APP_BASENAME="Gum Recorder" +APP_NAME="${APP_BASENAME} v${VERSION}" +IDENTIFIER="com.local.gumrecorder.v${VERSION}" + +echo "==> Building ${APP_NAME} (${IDENTIFIER}) using env: ${CONDA_ENV_NAME}" + +echo "==> Initializing conda and activating environment: ${CONDA_ENV_NAME}" +if command -v conda >/dev/null 2>&1; then + # shellcheck disable=SC1091 + source "$(conda info --base)/etc/profile.d/conda.sh" + if conda activate "${CONDA_ENV_NAME}" 2>/dev/null; then + echo "Activated env by name: ${CONDA_ENV_NAME}" + else + # Try to find env path from 'conda env list' + ENV_PATH="$(conda env list | awk -v tgt="${CONDA_ENV_NAME}" '$0 ~ "/envs/"tgt"$" {print $NF}')" + if [ -z "${ENV_PATH}" ]; then + # Try matching any path that ends with the target name + ENV_PATH="$(conda env list | awk -v tgt="/${CONDA_ENV_NAME}$" '$NF ~ tgt {print $NF}')" + fi + if [ -n "${ENV_PATH}" ] && conda activate "${ENV_PATH}" 2>/dev/null; then + echo "Activated env by path: ${ENV_PATH}" + else + echo "Conda environment '${CONDA_ENV_NAME}' not found. Create it or pass a different name/version as arguments." >&2 + conda info --envs || true + exit 1 + fi + fi +else + echo "Conda not found on PATH. Please install Miniconda/Anaconda and create 'record' env." >&2 + exit 1 +fi + +echo "==> Installing/Updating build dependencies" +python -m pip install --upgrade pip wheel +python -m pip install --upgrade pyinstaller + +echo "==> Ensuring package is installable" +pushd "$PROJECT_ROOT" >/dev/null +python -m pip install -e . 
+popd >/dev/null + +echo "==> Building macOS .app with PyInstaller" +export PYINSTALLER_APP_NAME="${APP_NAME}" +export PYINSTALLER_BUNDLE_IDENTIFIER="${IDENTIFIER}" +pyinstaller --noconfirm "${SCRIPT_DIR}/Gum Recorder.spec" +unset PYINSTALLER_APP_NAME +unset PYINSTALLER_BUNDLE_IDENTIFIER + +APP_PATH="${SCRIPT_DIR}/dist/${APP_NAME}.app" +INFO_PLIST="${APP_PATH}/Contents/Info.plist" + +# Post-process Info.plist with version keys +if [ -f "$INFO_PLIST" ]; then + if /usr/libexec/PlistBuddy -c 'Print :CFBundleShortVersionString' "$INFO_PLIST" >/dev/null 2>&1; then + /usr/libexec/PlistBuddy -c "Set :CFBundleShortVersionString ${VERSION}" "$INFO_PLIST" || true + else + /usr/libexec/PlistBuddy -c "Add :CFBundleShortVersionString string ${VERSION}" "$INFO_PLIST" || true + fi + if /usr/libexec/PlistBuddy -c 'Print :CFBundleVersion' "$INFO_PLIST" >/dev/null 2>&1; then + /usr/libexec/PlistBuddy -c "Set :CFBundleVersion ${VERSION}" "$INFO_PLIST" || true + else + /usr/libexec/PlistBuddy -c "Add :CFBundleVersion string ${VERSION}" "$INFO_PLIST" || true + fi + # Ensure Automation usage description for Apple Events (required on macOS for URL retrieval via AppleScript) + if /usr/libexec/PlistBuddy -c 'Print :NSAppleEventsUsageDescription' "$INFO_PLIST" >/dev/null 2>&1; then + /usr/libexec/PlistBuddy -c "Set :NSAppleEventsUsageDescription Uses Apple Events to read the active browser tab URL you are viewing." "$INFO_PLIST" || true + else + /usr/libexec/PlistBuddy -c "Add :NSAppleEventsUsageDescription string Uses Apple Events to read the active browser tab URL you are viewing." "$INFO_PLIST" || true + fi +fi + +# Reset TCC permissions for this app bundle identifier (best effort, non-fatal) +echo "==> Resetting TCC permissions for ${IDENTIFIER} (Accessibility, InputMonitoring, ScreenCapture, AppleEvents)" +if command -v tccutil >/dev/null 2>&1; then + tccutil reset Accessibility "${IDENTIFIER}" || true + tccutil reset InputMonitoring "${IDENTIFIER}" || true + tccutil reset ScreenCapture "${IDENTIFIER}" || true + tccutil reset AppleEvents "${IDENTIFIER}" || true +else + echo "tccutil not found; skipping TCC reset" +fi + +# Prepare entitlements for Apple Events automation (always generated so we can ad-hoc sign if needed) +ENTITLEMENTS_PLIST="${SCRIPT_DIR}/apple_events.entitlements.plist" +cat > "${ENTITLEMENTS_PLIST}" <<'EOF' + + + + + com.apple.security.automation.apple-events + + + +EOF + +SIGN_IDENTITY_DEFAULT="Mac Developer" +SIGN_IDENTITY="${CODESIGN_IDENTITY:-${SIGN_IDENTITY_DEFAULT}}" +SELECTED_IDENTITY="${SIGN_IDENTITY}" +if [ "${SIGN_IDENTITY}" != "-" ] && ! security find-identity -v -p codesigning 2>/dev/null | grep -q "${SIGN_IDENTITY}"; then + echo "No code signing identity '${SIGN_IDENTITY}' found; defaulting to ad-hoc signing." + SELECTED_IDENTITY="-" +fi + +echo "==> Code signing app with identity: ${SELECTED_IDENTITY} (Apple Events entitlement)" +if [ "${SELECTED_IDENTITY}" = "-" ]; then + codesign --deep --force --entitlements "${ENTITLEMENTS_PLIST}" --sign "${SELECTED_IDENTITY}" "${APP_PATH}" +else + codesign --deep --force --options runtime --entitlements "${ENTITLEMENTS_PLIST}" --sign "${SELECTED_IDENTITY}" "${APP_PATH}" +fi +echo "==> Verifying code signature" +codesign --verify --deep --strict --verbose=2 "${APP_PATH}" || true + +echo "==> Build complete" +echo "Open the app at: ${APP_PATH}" +echo "Version: ${VERSION}" +echo "Note: On first run, macOS will ask for Screen Recording, Accessibility, Input Monitoring, and Automation permissions." 
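For reference, a usage sketch of the build script above, run from `record/gum/app/macos/`; the version, env name, and signing identity shown are illustrative values, not requirements:

```bash
# Build "Gum Recorder v0.2" into the conda env "record" (both arguments optional, order-insensitive)
./build-macos.sh 0.2 record

# Omit the version to auto-increment from the VERSION file (falls back to 0.1 on a first build)
./build-macos.sh record

# Use a real signing identity instead of the ad-hoc fallback
CODESIGN_IDENTITY="Developer ID Application: Example" ./build-macos.sh
```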
diff --git a/record/gum/cli/__init__.py b/record/gum/cli/__init__.py new file mode 100644 index 0000000..cb1c8a8 --- /dev/null +++ b/record/gum/cli/__init__.py @@ -0,0 +1,12 @@ +"""Command-line interface utilities for the Gum recorder.""" + +from .background import BackgroundRecorder +from .main import main, _run_cli + +__all__ = ["BackgroundRecorder", "main", "cli_main"] + + + +def cli_main(): + import asyncio + asyncio.run(_run_cli()) diff --git a/record/gum/cli/background.py b/record/gum/cli/background.py new file mode 100644 index 0000000..08e82ec --- /dev/null +++ b/record/gum/cli/background.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import asyncio +import logging +import threading +from typing import Optional + +from ..gum import gum as GumApp +from ..observers import AppAndBrowserInspector, Screen + + +class BackgroundRecorder: + """Controller to run/stop the recorder in a background thread.""" + + _thread: Optional[threading.Thread] = None + _loop: Optional[asyncio.AbstractEventLoop] = None + _stop_event: Optional[asyncio.Event] = None + _running: bool = False + _screen: Optional[Screen] = None + + @classmethod + def is_running(cls) -> bool: + return cls._running + + @classmethod + def start( + cls, + user_name: str = "anonymous", + data_directory: str = "~/Downloads/records", + screenshots_dir: str = "~/Downloads/records/screenshots", + debug: bool = False, + scroll_debounce: float = 0.5, + scroll_min_distance: float = 5.0, + scroll_max_frequency: int = 10, + scroll_session_timeout: float = 2.0, + ) -> None: + if cls._running: + return + + cls._stop_event = asyncio.Event() + cls._loop = asyncio.new_event_loop() + + async def _run(): + inspector_logger = logging.getLogger("gum.inspector.background") + app_inspector = AppAndBrowserInspector(inspector_logger) + screen_observer = Screen( + screenshots_dir=screenshots_dir, + debug=debug, + keystroke_log_path=f"{data_directory}/keystrokes.log", + keyboard_timeout=2.0, + keyboard_sample_interval_sec=0.25, + scroll_debounce_sec=scroll_debounce, + scroll_min_distance=scroll_min_distance, + scroll_max_frequency=scroll_max_frequency, + scroll_session_timeout=scroll_session_timeout, + app_inspector=app_inspector, + ) + BackgroundRecorder._screen = screen_observer + async with GumApp( + user_name, + screen_observer, + data_directory=data_directory, + app_and_browser_inspector=app_inspector, + ): + await cls._stop_event.wait() + + def _thread_target(): + assert cls._loop is not None + asyncio.set_event_loop(cls._loop) + try: + cls._loop.run_until_complete(_run()) + finally: + try: + pending = asyncio.all_tasks(loop=cls._loop) + for task in pending: + task.cancel() + cls._loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + except Exception: + pass + cls._loop.stop() + cls._loop.close() + cls._loop = None + cls._stop_event = None + cls._running = False + cls._screen = None + + cls._thread = threading.Thread(target=_thread_target, daemon=True) + cls._running = True + cls._thread.start() + + @classmethod + def stop(cls) -> None: + if not cls._running: + return + if cls._loop and cls._stop_event: + def _set_event(): + assert cls._stop_event is not None + cls._stop_event.set() + try: + cls._loop.call_soon_threadsafe(_set_event) + except RuntimeError: + pass + if cls._thread and cls._thread.is_alive(): + cls._thread.join(timeout=5) + cls._thread = None + cls._running = False + + @classmethod + def post_key_event(cls, token: str, event_type: str) -> None: + """Forward a key token event from the main thread 
to the Screen observer.""" + if not cls._running or cls._loop is None or cls._screen is None: + return + + async def _dispatch(): + try: + await cls._screen.handle_key_token_event(token, event_type) + except Exception: + pass + + try: + asyncio.run_coroutine_threadsafe(_dispatch(), cls._loop) + except Exception: + pass diff --git a/record/gum/cli/main.py b/record/gum/cli/main.py new file mode 100644 index 0000000..4d8e025 --- /dev/null +++ b/record/gum/cli/main.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import argparse +import asyncio +import logging + +from dotenv import load_dotenv + +from ..gum import gum as GumApp +from ..observers import AppAndBrowserInspector, Screen + +load_dotenv() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="GUM - A Python package with command-line interface") + parser.add_argument("--user-name", "-u", type=str, default="anonymous", help="The user name to use") + parser.add_argument("--debug", "-d", action="store_true", help="Enable debug mode") + + # Output directories + parser.add_argument( + "--data-directory", + type=str, + default="~/Downloads/records", + help="Directory for database and logs (default: ~/Downloads/records)", + ) + parser.add_argument( + "--screenshots-dir", + type=str, + default="~/Downloads/records/screenshots", + help="Directory to save screenshots (default: ~/Downloads/records/screenshots)", + ) + + # Scroll filtering options + parser.add_argument( + "--scroll-debounce", + type=float, + default=0.5, + help="Minimum time between scroll events (seconds, default: 0.5)", + ) + parser.add_argument( + "--scroll-min-distance", + type=float, + default=5.0, + help="Minimum scroll distance to log (pixels, default: 5.0)", + ) + parser.add_argument( + "--scroll-max-frequency", + type=int, + default=10, + help="Maximum scroll events per second (default: 10)", + ) + parser.add_argument( + "--scroll-session-timeout", + type=float, + default=2.0, + help="Scroll session timeout (seconds, default: 2.0)", + ) + + return parser.parse_args() + + +async def _run_cli() -> None: + args = parse_args() + print(f"User Name: {args.user_name}") + + inspector_logger = logging.getLogger("gum.inspector") + app_inspector = AppAndBrowserInspector(inspector_logger) + + screen_observer = Screen( + screenshots_dir=args.screenshots_dir, + debug=args.debug, + scroll_debounce_sec=args.scroll_debounce, + scroll_min_distance=args.scroll_min_distance, + scroll_max_frequency=args.scroll_max_frequency, + scroll_session_timeout=args.scroll_session_timeout, + app_inspector=app_inspector, + ) + + async with GumApp( + args.user_name, + screen_observer, + data_directory=args.data_directory, + app_and_browser_inspector=app_inspector, + ): + await asyncio.Future() + + +def main() -> None: + asyncio.run(_run_cli()) + + +if __name__ == "__main__": + main() diff --git a/record/gum/db_utils.py b/record/gum/db_utils.py new file mode 100644 index 0000000..1f80eb6 --- /dev/null +++ b/record/gum/db_utils.py @@ -0,0 +1,164 @@ +# db_utils.py + +from __future__ import annotations + +import math +from datetime import datetime, timezone +import re +from typing import List + +import numpy as np +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity + +from sqlalchemy import MetaData, Table, literal_column, select, text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from .models import Observation, Proposition, proposition_parent + +def 
build_fts_query(raw: str, mode: str = "OR") -> str: + tokens = re.findall(r"\w+", raw.lower()) + if not tokens: + return "" + if mode == "PHRASE": + return f'"{" ".join(tokens)}"' + elif mode == "OR": + return " OR ".join(tokens) + else: # implicit AND + return " ".join(tokens) + +def _has_child_subquery() -> select: + return ( + select(literal_column("1")) + .select_from(proposition_parent) + .where(proposition_parent.c.parent_id == Proposition.id) + .exists() + ) + +# constants +K_DECAY = 2.0 # decay rate for recency adjustment +LAMBDA = 0.5 # trade-off for MMR + +async def search_propositions_bm25( + session: AsyncSession, + user_query: str, + *, + limit: int = 3, + mode: str = "OR", + start_time: datetime | None = None, + end_time: datetime | None = None, +) -> list[tuple[Proposition, float]]: + """ + Args: + session: AsyncSession for database operations + user_query: Search query string + limit: Maximum number of results to return + mode: Search mode ("AND", "OR", or "PHRASE") + start_time: Start of time range (UTC, inclusive) + end_time: End of time range (UTC, inclusive, defaults to now) + """ + q = build_fts_query(user_query, mode) + if not q: + return [] + + candidate_pool = max(limit * 10, limit) + + fts = Table("propositions_fts", MetaData()) + bm25_col = literal_column("bm25(propositions_fts)").label("bm25") + join_cond = literal_column("propositions_fts.rowid") == Proposition.id + has_child = _has_child_subquery() + + # Set default end_time to now if not provided + if end_time is None: + end_time = datetime.now(timezone.utc) + + # Ensure both times are timezone-aware + if start_time is not None and start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=timezone.utc) + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=timezone.utc) + + stmt = ( + select(Proposition, bm25_col) + .select_from(fts.join(Proposition, join_cond)) + .where(text("propositions_fts MATCH :q")) + .where(~has_child) + ) + + # Add time range filtering + if start_time is not None: + stmt = stmt.where(Proposition.created_at >= start_time) + stmt = stmt.where(Proposition.created_at <= end_time) + + stmt = ( + stmt.order_by(bm25_col) + .options(selectinload(Proposition.observations)) + .limit(candidate_pool) + ) + + raw = await session.execute(stmt, {"q": q}) + rows = raw.all() + if not rows: + return [] + + now = datetime.now(timezone.utc) + rel_scores: List[float] = [] + for prop, raw_score in rows: + + dt = prop.created_at + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + + age_days = max((now - dt).total_seconds() / 86400, 0.0) + + alpha = prop.decay if prop.decay is not None else 0.0 + gamma = math.exp(-alpha * K_DECAY * age_days) + + r_eff = -raw_score * gamma + rel_scores.append(r_eff) + + docs = [f"{p.text} {p.reasoning}" for p, _ in rows] + vecs = TfidfVectorizer().fit_transform(docs) + + # MMR selection + selected_idxs: List[int] = [] + final_scores: List[float] = [] + + while len(selected_idxs) < min(limit, len(rows)): + if not selected_idxs: + idx = int(np.argmax(rel_scores)) + selected_idxs.append(idx) + final_scores.append(rel_scores[idx]) + continue + + sims = cosine_similarity(vecs, vecs[selected_idxs]).max(axis=1) + mmr_scores = (LAMBDA * np.array(rel_scores) + - (1 - LAMBDA) * sims) + + # never pick twice + mmr_scores[selected_idxs] = -np.inf + + idx = int(np.argmax(mmr_scores)) + selected_idxs.append(idx) + final_scores.append(float(mmr_scores[idx])) + + return [(rows[i][0], final_scores[pos]) + for pos, i in enumerate(selected_idxs)] + +async def 
get_related_observations( + session: AsyncSession, + proposition_id: int, + *, + limit: int = 5, +) -> List[Observation]: + + stmt = ( + select(Observation) + .join(Observation.propositions) + .where(Proposition.id == proposition_id) + .order_by(Observation.created_at.desc()) + .limit(limit) # Use the limit parameter + ) + result = await session.execute(stmt) + return result.scalars().all() \ No newline at end of file diff --git a/record/gum/gum.py b/record/gum/gum.py new file mode 100644 index 0000000..8c981d9 --- /dev/null +++ b/record/gum/gum.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import asyncio +import logging +import os +from contextlib import asynccontextmanager +from datetime import datetime +from typing import Callable +from logging.handlers import RotatingFileHandler + +from .models import Observation, init_db +from .observers import Observer +from .schemas import Update +from .observers import AppAndBrowserInspector + +class gum: + def __init__( + self, + user_name: str, + *observers: Observer, + data_directory: str = "~/Downloads/records", + db_name: str = "actions.db", + max_concurrent_updates: int = 4, + verbosity: int = logging.INFO, + app_and_browser_inspector: AppAndBrowserInspector | None = None, + ): + # basic paths + data_directory = os.path.expanduser(data_directory) + os.makedirs(data_directory, exist_ok=True) + + # runtime + self.user_name = user_name + self.observers: list[Observer] = list(observers) + + # logging + self.logger = logging.getLogger("gum") + self.logger.setLevel(verbosity) + if not self.logger.handlers: + h = logging.StreamHandler() + h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) + self.logger.addHandler(h) + try: + file_handler = RotatingFileHandler( + os.path.join(data_directory, "gum.log"), maxBytes=2_000_000, backupCount=3 + ) + file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) + self.logger.addHandler(file_handler) + except Exception: # if fails, keep console logging only + pass + + self.engine = None + self.Session = None + self._db_name = db_name + self._data_directory = data_directory + + self._update_sem = asyncio.Semaphore(max_concurrent_updates) + self._tasks: set[asyncio.Task] = set() + self._loop_task: asyncio.Task | None = None + self.update_handlers: list[Callable[[Observer, Update], None]] = [] + self._app_and_browser_inspector = app_and_browser_inspector or AppAndBrowserInspector(self.logger) + + def start_update_loop(self): + if self._loop_task is None: + self._loop_task = asyncio.create_task(self._update_loop()) + + async def stop_update_loop(self): + if self._loop_task: + self._loop_task.cancel() + try: + await self._loop_task + except asyncio.CancelledError: + pass + self._loop_task = None + + async def connect_db(self): + if self.engine is None: + self.engine, self.Session = await init_db( + self._db_name, self._data_directory + ) + + async def __aenter__(self): + await self.connect_db() + try: + await asyncio.to_thread(self._app_and_browser_inspector.prime_automation_for_running_browsers) + except AttributeError: + await asyncio.get_running_loop().run_in_executor( + None, self._app_and_browser_inspector.prime_automation_for_running_browsers + ) + self.start_update_loop() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.stop_update_loop() + if self._tasks: + await asyncio.gather(*self._tasks, return_exceptions=True) + for obs in self.observers: + await obs.stop() + + async def _update_loop(self): + """ + Wait 
for *any* observer to produce an Update and + dispatch it through the semaphore-guarded handler. + """ + while True: + gets = { + asyncio.create_task(obs.update_queue.get()): obs + for obs in self.observers + } + + done, _ = await asyncio.wait( + gets.keys(), return_when=asyncio.FIRST_COMPLETED + ) + + for fut in done: + upd: Update = fut.result() + obs = gets[fut] + + t = asyncio.create_task(self._run_with_gate(obs, upd)) + self._tasks.add(t) + + async def _run_with_gate(self, observer: Observer, update: Update): + """Wrapper that enforces max_concurrent_updates.""" + async with self._update_sem: + try: + await self._default_handler(observer, update) + finally: + self._tasks.discard(asyncio.current_task()) + + async def _handle_audit(self, obs: Observation) -> bool: + return False + + async def _default_handler(self, observer: Observer, update: Update) -> None: + self.logger.info(f"Processing update from {observer.name}") + self.logger.info(f"Content ({update.content_type}): {update.content[:10]}") + + async with self._session() as session: + app_name = update.app_name or self._app_and_browser_inspector.get_frontmost_app_name() + browser_url = update.browser_url + if app_name and browser_url is None: + browser_url = self._app_and_browser_inspector.get_browser_url(app_name) + if app_name: + self.logger.debug("Active app resolved to '%s'%s", app_name, f" (url={browser_url})" if browser_url else "") + + now = datetime.now().astimezone() + + observation = Observation( + observer_name=observer.name, + content=update.content, + content_type=update.content_type, + app_name=app_name, + browser_url=browser_url, + created_at=now, + updated_at=now, + ) + + if await self._handle_audit(observation): + return + + session.add(observation) + await session.flush() + + @asynccontextmanager + async def _session(self): + async with self.Session() as s: + async with s.begin(): + yield s + + def add_observer(self, observer: Observer): + self.observers.append(observer) + + def remove_observer(self, observer: Observer): + if observer in self.observers: + self.observers.remove(observer) + + def register_update_handler(self, fn: Callable[[Observer, Update], None]): + self.update_handlers.append(fn) diff --git a/gum/models.py b/record/gum/models.py similarity index 58% rename from gum/models.py rename to record/gum/models.py index df903ae..f6eed3d 100644 --- a/gum/models.py +++ b/record/gum/models.py @@ -30,11 +30,6 @@ from sqlalchemy.sql import func class Base(AsyncAttrs, DeclarativeBase): - """Base class for all database models. - - This class provides the foundation for all SQLAlchemy models in the application, - including async support and declarative base functionality. - """ pass observation_proposition = Table( @@ -54,30 +49,34 @@ class Base(AsyncAttrs, DeclarativeBase): ), ) - +proposition_parent = Table( + "proposition_parent", + Base.metadata, + Column( + "child_id", + Integer, + ForeignKey("propositions.id", ondelete="CASCADE"), + primary_key=True, + ), + Column( + "parent_id", + Integer, + ForeignKey("propositions.id", ondelete="CASCADE"), + primary_key=True, + ), +) class Observation(Base): - """Represents an observation of user behavior. - - This model stores observations made by various observers about user behavior, - including the content of the observation and metadata about when and how it was made. - - Attributes: - id (int): Primary key for the observation. - observer_name (str): Name of the observer that made this observation. - content (str): The actual content of the observation. 
- content_type (str): Type of content (e.g., 'text', 'image', etc.). - created_at (datetime): When the observation was created. - updated_at (datetime): When the observation was last updated. - propositions (set[Proposition]): Set of propositions related to this observation. - """ __tablename__ = "observations" id: Mapped[int] = mapped_column(primary_key=True) observer_name: Mapped[str] = mapped_column(String(100), nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False) content_type: Mapped[str] = mapped_column(String(50), nullable=False) + # New optional metadata fields for each action/observation + app_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + browser_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) created_at: Mapped[str] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False @@ -99,33 +98,10 @@ class Observation(Base): ) def __repr__(self) -> str: - """String representation of the observation. - - Returns: - str: A string representation showing the observation ID and observer name. - """ return f"" class Proposition(Base): - """Represents a proposition about user behavior. - - This model stores propositions generated from observations, including the proposition - text, reasoning behind it, and metadata about its creation and relationships. - - Attributes: - id (int): Primary key for the proposition. - text (str): The actual proposition text. - reasoning (str): The reasoning behind this proposition. - confidence (Optional[int]): Confidence level in this proposition. - decay (Optional[int]): Decay factor for this proposition. - created_at (datetime): When the proposition was created. - updated_at (datetime): When the proposition was last updated. - revision_group (str): Group identifier for related proposition revisions. - version (int): Version number of this proposition. - - observations (set[Observation]): Set of observations related to this proposition. - """ __tablename__ = "propositions" id: Mapped[int] = mapped_column(primary_key=True) @@ -147,7 +123,15 @@ class Proposition(Base): revision_group: Mapped[str] = mapped_column(String(36), nullable=False, index=True) version: Mapped[int] = mapped_column(Integer, server_default="1", nullable=False) - + parents: Mapped[set["Proposition"]] = relationship( + "Proposition", + secondary=proposition_parent, + primaryjoin=id == proposition_parent.c.child_id, + secondaryjoin=id == proposition_parent.c.parent_id, + backref="children", + collection_class=set, + lazy="selectin", + ) observations: Mapped[set[Observation]] = relationship( "Observation", @@ -159,11 +143,6 @@ class Proposition(Base): ) def __repr__(self) -> str: - """String representation of the proposition. - - Returns: - str: A string representation showing the proposition ID and a preview of its text. - """ preview = (self.text[:27] + "…") if len(self.text) > 30 else self.text return f"" @@ -171,14 +150,7 @@ def __repr__(self) -> str: FTS_TOKENIZER = "porter ascii" def create_fts_table(conn) -> None: - """Create FTS5 virtual table and triggers for proposition search. - - This function creates a full-text search table for propositions and sets up - triggers to maintain the search index as propositions are modified. - - Args: - conn: SQLite database connection. 
- """ + """Create FTS5 virtual table + triggers on first run.""" exists = conn.execute( sql_text( "SELECT 1 FROM sqlite_master WHERE type='table' AND name='propositions_fts'" @@ -245,61 +217,6 @@ def create_fts_table(conn) -> None: ) ) -def create_observations_fts(conn) -> None: - """Create FTS5 virtual table and triggers for observation search. - - This function creates a full-text search table for observations and sets up - triggers to maintain the search index as observations are modified. - - Args: - conn: SQLite database connection. - """ - exists = conn.execute(sql_text( - "SELECT 1 FROM sqlite_master " - "WHERE type='table' AND name='observations_fts'" - )).fetchone() - if exists: - return # already present - - conn.execute(sql_text(f""" - CREATE VIRTUAL TABLE observations_fts - USING fts5( - content, - content='observations', - content_rowid='id', - tokenize='{FTS_TOKENIZER}' - ); - """)) - conn.execute(sql_text(""" - CREATE TRIGGER observations_ai - AFTER INSERT ON observations BEGIN - INSERT INTO observations_fts(rowid, content) - VALUES (new.id, new.content); - END; - """)) - conn.execute(sql_text(""" - CREATE TRIGGER observations_ad - AFTER DELETE ON observations BEGIN - INSERT INTO observations_fts(observations_fts, rowid, content) - VALUES ('delete', old.id, old.content); - END; - """)) - conn.execute(sql_text(""" - CREATE TRIGGER observations_au - AFTER UPDATE ON observations BEGIN - INSERT INTO observations_fts(observations_fts, rowid, content) - VALUES ('delete', old.id, old.content); - INSERT INTO observations_fts(rowid, content) - VALUES (new.id, new.content); - END; - """)) - # back-fill the index - conn.execute(sql_text(""" - INSERT INTO observations_fts(rowid, content) - SELECT id, content FROM observations; - """)) - - async def init_db( db_path: str = "gum.db", db_directory: Optional[str] = None, @@ -326,7 +243,15 @@ async def init_db( await conn.run_sync(Base.metadata.create_all) await conn.run_sync(create_fts_table) - await conn.run_sync(create_observations_fts) + # Ensure new columns exist when upgrading an existing DB without migrations + def _ensure_observation_columns(sync_conn): + info = sync_conn.execute(sql_text("PRAGMA table_info(observations)")).fetchall() + cols = {row[1] for row in info} + if "app_name" not in cols: + sync_conn.execute(sql_text("ALTER TABLE observations ADD COLUMN app_name VARCHAR(100)")) + if "browser_url" not in cols: + sync_conn.execute(sql_text("ALTER TABLE observations ADD COLUMN browser_url TEXT")) + await conn.run_sync(_ensure_observation_columns) Session = async_sessionmaker( engine, diff --git a/record/gum/observers/__init__.py b/record/gum/observers/__init__.py new file mode 100644 index 0000000..332baec --- /dev/null +++ b/record/gum/observers/__init__.py @@ -0,0 +1,80 @@ +"""Observer package orchestrating platform-specific implementations.""" + +from __future__ import annotations + +import sys +from typing import Callable + +from .base import Observer +from .base.screen import Screen as BaseScreen +from .base.keyboard import KeyboardBackend +from .base.mouse import MouseBackend +from .base.screenshots import ScreenshotBackend +from .fallback import ( + FallbackAppAndBrowserInspector, + FallbackScreenshotBackend, + PynputKeyboardBackend, + PynputMouseBackend, + fallback_check_automation_permission_granted, +) + +AppAndBrowserInspector = FallbackAppAndBrowserInspector +check_automation_permission_granted = fallback_check_automation_permission_granted +keyboard_factory: Callable[[], KeyboardBackend] = lambda: 
PynputKeyboardBackend() +mouse_factory: Callable[[], MouseBackend] = lambda: PynputMouseBackend() +screenshot_factory: Callable[[], ScreenshotBackend] = FallbackScreenshotBackend +visibility_guard = None + +if sys.platform == "darwin": # macOS + try: + from .macos.keyboard import MacKeyboardBackend + from .macos.mouse import MacMouseBackend + from .macos.screenshots import MacScreenshotBackend, is_app_visible + from .macos.app_and_browser_logging import MacOSAppAndBrowserInspector + from .macos.app_and_browser_logging import check_automation_permission_granted as mac_check_automation_permission_granted + except Exception: + AppAndBrowserInspector = FallbackAppAndBrowserInspector + else: + AppAndBrowserInspector = MacOSAppAndBrowserInspector + check_automation_permission_granted = mac_check_automation_permission_granted + + keyboard_factory = lambda: MacKeyboardBackend() + mouse_factory = lambda: MacMouseBackend() + screenshot_factory = MacScreenshotBackend + visibility_guard = is_app_visible +elif sys.platform.startswith("win"): + try: + from .windows.keyboard import WindowsKeyboardBackend + from .windows.mouse import WindowsMouseBackend + from .windows.screenshots import WindowsScreenshotBackend + from .windows.app_and_browser_logging import WindowsAppAndBrowserInspector + from .windows.app_and_browser_logging import ( + check_automation_permission_granted as windows_check_automation_permission_granted, + ) + except Exception: + check_automation_permission_granted = fallback_check_automation_permission_granted + else: + keyboard_factory = lambda: WindowsKeyboardBackend() + mouse_factory = lambda: WindowsMouseBackend() + screenshot_factory = WindowsScreenshotBackend + AppAndBrowserInspector = WindowsAppAndBrowserInspector + check_automation_permission_granted = windows_check_automation_permission_granted +elif "check_automation_permission_granted" not in globals(): + check_automation_permission_granted = fallback_check_automation_permission_granted + + +class Screen(BaseScreen): + """Concrete screen observer wired with the detected platform backends.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__( + keyboard_backend=keyboard_factory(), + mouse_backend=mouse_factory(), + screenshot_backend_factory=screenshot_factory, + visibility_guard=visibility_guard, + *args, + **kwargs, + ) + + +__all__ = ["Observer", "Screen", "AppAndBrowserInspector", "check_automation_permission_granted"] diff --git a/record/gum/observers/base/__init__.py b/record/gum/observers/base/__init__.py new file mode 100644 index 0000000..303eeb2 --- /dev/null +++ b/record/gum/observers/base/__init__.py @@ -0,0 +1,19 @@ +"""Shared observer building blocks.""" + +from .observer import Observer +from .keyboard import KeyboardBackend, KeyDispatch, NullKeyboardBackend +from .mouse import MouseBackend, NullMouseBackend +from .screenshots import MssScreenshotBackend, ScreenshotBackend +from .screen import Screen + +__all__ = [ + "Observer", + "KeyboardBackend", + "KeyDispatch", + "NullKeyboardBackend", + "MouseBackend", + "NullMouseBackend", + "ScreenshotBackend", + "MssScreenshotBackend", + "Screen", +] diff --git a/record/gum/observers/base/keyboard.py b/record/gum/observers/base/keyboard.py new file mode 100644 index 0000000..441eca0 --- /dev/null +++ b/record/gum/observers/base/keyboard.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Callable, Protocol + +KeyDispatch = Callable[[str, str], None] + + +class KeyboardBackend(Protocol): + """Abstract interface for keyboard event 
backends.""" + + def start(self, dispatch: KeyDispatch) -> None: + """Begin emitting keyboard events via *dispatch* (token, action).""" + + def stop(self) -> None: + """Stop emitting keyboard events and release any resources.""" + + +class NullKeyboardBackend: + """Keyboard backend stub that never emits events.""" + + def start(self, dispatch: KeyDispatch) -> None: # pragma: no cover - no behaviour to test + return + + def stop(self) -> None: # pragma: no cover - no behaviour to test + return diff --git a/record/gum/observers/base/mouse.py b/record/gum/observers/base/mouse.py new file mode 100644 index 0000000..9a0f2c9 --- /dev/null +++ b/record/gum/observers/base/mouse.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Callable, Protocol + +MouseClickHandler = Callable[[float, float, str], None] +MouseScrollHandler = Callable[[float, float, float, float], None] + + +class MouseBackend(Protocol): + """Abstract interface for mouse event sources.""" + + def start(self, on_click: MouseClickHandler, on_scroll: MouseScrollHandler) -> None: + """Start delivering events towards the provided callbacks.""" + + def stop(self) -> None: + """Stop delivering mouse events and release any resources.""" + + +class NullMouseBackend: + """Mouse backend stub that never emits events.""" + + def start(self, on_click: MouseClickHandler, on_scroll: MouseScrollHandler) -> None: # pragma: no cover + return + + def stop(self) -> None: # pragma: no cover + return diff --git a/record/gum/observers/base/observer.py b/record/gum/observers/base/observer.py new file mode 100644 index 0000000..583ecb9 --- /dev/null +++ b/record/gum/observers/base/observer.py @@ -0,0 +1,51 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import Optional +import asyncio + +class Observer(ABC): + + def __init__(self, name: Optional[str] = None) -> None: + self.update_queue = asyncio.Queue() + self._name = name or self.__class__.__name__ + + self._running = True + self._task: asyncio.Task | None = asyncio.create_task(self._worker_wrapper()) + + @abstractmethod + async def _worker(self) -> None: + pass + + async def _worker_wrapper(self) -> None: + try: + await self._worker() + except asyncio.CancelledError: + pass + except Exception as exc: + raise + finally: + self._running = False + + # ─────────────────────────────── public API + @property + def name(self) -> str: + return self._name + + async def get_update(self): + """Return an Update if immediately available, else None (non-blocking).""" + try: + return self.update_queue.get_nowait() + except asyncio.QueueEmpty: + return None + + async def stop(self) -> None: + """Cancel the worker task and drain the queue.""" + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + # unblock any awaiters + while not self.update_queue.empty(): + self.update_queue.get_nowait() diff --git a/record/gum/observers/base/screen.py b/record/gum/observers/base/screen.py new file mode 100644 index 0000000..5362edc --- /dev/null +++ b/record/gum/observers/base/screen.py @@ -0,0 +1,632 @@ +from __future__ import annotations + +import asyncio +import gc +import logging +import os +import time +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence + +from PIL import Image, ImageDraw +from pynput import mouse + +from .observer import Observer 
+from ..constants import ( + CAPTURE_FPS_DEFAULT, + MEMORY_CLEANUP_INTERVAL_DEFAULT, + MON_START_INDEX, + KEYBOARD_TIMEOUT_DEFAULT, + KEYBOARD_SAMPLE_INTERVAL_DEFAULT, + SCROLL_DEBOUNCE_SEC_DEFAULT, + SCROLL_MIN_DISTANCE_DEFAULT, + SCROLL_MAX_FREQUENCY_DEFAULT, + SCROLL_SESSION_TIMEOUT_DEFAULT, +) +from .keyboard import KeyDispatch, KeyboardBackend, NullKeyboardBackend +from .mouse import MouseBackend, NullMouseBackend +from .screenshots import ScreenshotBackend +from ...schemas import Update + +VisibilityGuard = Callable[[Iterable[str]], bool] + + +class Screen(Observer): + """Platform-agnostic core for observing screen, keyboard, and mouse activity.""" + + _CAPTURE_FPS: int = CAPTURE_FPS_DEFAULT + _PERIODIC_SEC: int = 30 + _DEBOUNCE_SEC: int = 1 + _MON_START: int = MON_START_INDEX + _MEMORY_CLEANUP_INTERVAL: int = MEMORY_CLEANUP_INTERVAL_DEFAULT + _MAX_WORKERS: int = 4 + + _SCROLL_DEBOUNCE_SEC: float = SCROLL_DEBOUNCE_SEC_DEFAULT + _SCROLL_MIN_DISTANCE: float = SCROLL_MIN_DISTANCE_DEFAULT + _SCROLL_MAX_FREQUENCY: int = SCROLL_MAX_FREQUENCY_DEFAULT + _SCROLL_SESSION_TIMEOUT: float = SCROLL_SESSION_TIMEOUT_DEFAULT + + def __init__( + self, + keyboard_backend: Optional[KeyboardBackend], + mouse_backend: Optional[MouseBackend], + screenshot_backend_factory: Callable[[], ScreenshotBackend], + app_inspector: Optional[Any] = None, + visibility_guard: Optional[VisibilityGuard] = None, + *, + screenshots_dir: str = "~/Downloads/records/screenshots", + skip_when_visible: Optional[str | Sequence[str]] = None, + history_k: int = 10, + debug: bool = False, + keyboard_timeout: float = KEYBOARD_TIMEOUT_DEFAULT, + keystroke_log_path: Optional[str] = None, + keyboard_sample_interval_sec: float = KEYBOARD_SAMPLE_INTERVAL_DEFAULT, + scroll_debounce_sec: float = SCROLL_DEBOUNCE_SEC_DEFAULT, + scroll_min_distance: float = SCROLL_MIN_DISTANCE_DEFAULT, + scroll_max_frequency: int = SCROLL_MAX_FREQUENCY_DEFAULT, + scroll_session_timeout: float = SCROLL_SESSION_TIMEOUT_DEFAULT, + ) -> None: + self._keyboard_backend = keyboard_backend or NullKeyboardBackend() + self._mouse_backend = mouse_backend or NullMouseBackend() + self._screenshot_backend_factory = screenshot_backend_factory + self._app_inspector = app_inspector + self._visibility_guard = visibility_guard + + self.screens_dir = os.path.abspath(os.path.expanduser(screenshots_dir)) + os.makedirs(self.screens_dir, exist_ok=True) + + if isinstance(skip_when_visible, str): + self._guard = {skip_when_visible} + else: + self._guard = set(skip_when_visible or []) + + self.debug = debug + self._thread_pool = ThreadPoolExecutor(max_workers=self._MAX_WORKERS) + + # Scroll filtering configuration + self._scroll_debounce_sec = scroll_debounce_sec + self._scroll_min_distance = scroll_min_distance + self._scroll_max_frequency = scroll_max_frequency + self._scroll_session_timeout = scroll_session_timeout + + # Frame buffers shared with worker + self._frames: Dict[int, Any] = {} + self._frame_lock = asyncio.Lock() + + self._history: deque[str] = deque(maxlen=max(0, history_k)) + self._pending_event: Optional[dict] = None + self._debounce_handle: Optional[asyncio.TimerHandle] = None + + # keyboard activity tracking + self._key_activity_start: Optional[float] = None + self._key_activity_timeout: float = keyboard_timeout + self._key_screenshots: List[str] = [] + self._key_activity_lock = asyncio.Lock() + self._last_key_screenshot_time: Optional[float] = None + self._last_key_position: Optional[tuple[float, float, int]] = None + self._keystroke_log_path: Optional[str] = ( + 
os.path.abspath(os.path.expanduser(keystroke_log_path)) if keystroke_log_path else None + ) + self._keyboard_sample_interval_sec: float = max(0.0, float(keyboard_sample_interval_sec)) + + # scroll activity tracking + self._scroll_last_time: Optional[float] = None + self._scroll_last_position: Optional[tuple[float, float]] = None + self._scroll_session_start: Optional[float] = None + self._scroll_event_count: int = 0 + self._scroll_lock = asyncio.Lock() + + # Keyboard backend bookkeeping + self._keyboard_dispatch: Optional[KeyDispatch] = None + + # Mouse backend bookkeeping + self._mouse_started: bool = False + + super().__init__() + + if self._detect_high_dpi(): + self._CAPTURE_FPS = 3 + self._MEMORY_CLEANUP_INTERVAL = 20 + if self.debug: + logging.getLogger("Screen").info("High-DPI display detected, using conservative settings") + + @staticmethod + def _mon_for(x: float, y: float, mons: Sequence[dict]) -> Optional[int]: + for idx, m in enumerate(mons, 1): + if m["left"] <= x < m["left"] + m["width"] and m["top"] <= y < m["top"] + m["height"]: + return idx + return 1 + + async def _run_in_thread(self, func, *args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self._thread_pool, lambda: func(*args, **kwargs)) + + def _detect_high_dpi(self) -> bool: + try: + backend = self._screenshot_backend_factory() + with backend as sct: + for monitor in sct.monitors[self._MON_START:]: + if monitor.get("width", 0) > 2560 or monitor.get("height", 0) > 1600: + return True + except Exception: + return False + return False + + def _should_log_scroll(self, x: float, y: float, dx: float, dy: float) -> bool: + current_time = time.time() + + if ( + self._scroll_session_start is None + or current_time - self._scroll_session_start > self._scroll_session_timeout + ): + self._scroll_session_start = current_time + self._scroll_event_count = 0 + self._scroll_last_position = (x, y) + self._scroll_last_time = current_time + return True + + if self._scroll_last_time is not None and current_time - self._scroll_last_time < self._scroll_debounce_sec: + return False + + if self._scroll_last_position is not None: + distance = ((x - self._scroll_last_position[0]) ** 2 + (y - self._scroll_last_position[1]) ** 2) ** 0.5 + if distance < self._scroll_min_distance: + return False + + self._scroll_event_count += 1 + session_duration = current_time - self._scroll_session_start + if session_duration > 0: + frequency = self._scroll_event_count / session_duration + if frequency > self._scroll_max_frequency: + return False + + self._scroll_last_position = (x, y) + self._scroll_last_time = current_time + return True + + def _snapshot_app_context(self, cursor: Optional[tuple[float, float]] = None) -> tuple[Optional[str], Optional[str]]: + inspector = getattr(self, "_app_inspector", None) + if inspector is None: + return None, None + app_name: Optional[str] = None + browser_url: Optional[str] = None + + try: + app_name = inspector.get_frontmost_app_name() + except Exception: + app_name = None + + if not app_name and cursor and hasattr(inspector, "app_at_point"): + try: + name_at_point, _ = inspector.app_at_point(*cursor) + except Exception: + name_at_point = None + if name_at_point: + app_name = name_at_point + + if app_name: + try: + browser_url = inspector.get_browser_url(app_name) + except Exception: + browser_url = None + + return app_name, browser_url + + async def _cleanup_key_screenshots(self) -> None: + if len(self._key_screenshots) <= 2: + return + to_delete = self._key_screenshots[1:-1] + 
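+        # Keep only the first and last screenshots of the keystroke burst; the intermediates are deleted below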
self._key_screenshots = [self._key_screenshots[0], self._key_screenshots[-1]] + + for path in to_delete: + try: + await self._run_in_thread(os.remove, path) + if self.debug: + logging.getLogger("Screen").info(f"Deleted intermediate screenshot: {path}") + except OSError: + pass + + async def _save_frame( + self, + frame, + x: float, + y: float, + tag: str, + box_color: str = "red", + box_width: int = 10, + scale: float = 1.0, + ) -> str: + ts = f"{time.time():.5f}" + path = os.path.join(self.screens_dir, f"{ts}_{tag}.jpg") + image = Image.frombytes("RGB", (frame.width, frame.height), frame.rgb) + draw = ImageDraw.Draw(image) + sx = max(1.0, float(scale)) + x = int(x * sx) + y = int(y * sx) + x1, x2 = max(0, x - 30), min(frame.width, x + 30) + y1, y2 = max(0, y - 20), min(frame.height, y + 20) + draw.rectangle([x1, y1, x2, y2], outline=box_color, width=box_width) + + await self._run_in_thread( + image.save, + path, + "JPEG", + quality=70, + optimize=True, + ) + + del draw + del image + return path + + async def _process_and_emit( + self, + before_path: str, + after_path: str | None, + action: Optional[str], + ev: Optional[dict], + ) -> None: + if ev is None: + return + cursor = ev.get("cursor") if isinstance(ev, dict) else None + app_name, browser_url = self._snapshot_app_context(cursor) + mon = ev.get("mon") if isinstance(ev, dict) else None + mon_sfx = f"@mon{mon}" if mon is not None else "" + if action and "scroll" in action: + scroll_info = ev.get("scroll", (0, 0)) + step = f"scroll{mon_sfx}({ev['position'][0]:.1f}, {ev['position'][1]:.1f}, dx={scroll_info[0]:.2f}, dy={scroll_info[1]:.2f})" + await self.update_queue.put( + Update(content=step, content_type="input_text", app_name=app_name, browser_url=browser_url) + ) + elif action and "click" in action: + step = f"{action}{mon_sfx}({ev['position'][0]:.1f}, {ev['position'][1]:.1f})" + await self.update_queue.put( + Update(content=step, content_type="input_text", app_name=app_name, browser_url=browser_url) + ) + elif action: + step = f"{action}{mon_sfx}({ev.get('text', '')})" + await self.update_queue.put( + Update(content=step, content_type="input_text", app_name=app_name, browser_url=browser_url) + ) + + async def stop(self) -> None: + await super().stop() + try: + self._keyboard_backend.stop() + except Exception: + pass + try: + self._mouse_backend.stop() + except Exception: + pass + + async with self._frame_lock: + for frame in self._frames.values(): + if frame is not None: + del frame + self._frames.clear() + await self._run_in_thread(gc.collect) + self._thread_pool.shutdown(wait=True) + + def _skip(self) -> bool: + if not self._guard or self._visibility_guard is None: + return False + try: + return self._visibility_guard(self._guard) + except Exception: + return False + + async def _worker(self) -> None: + log = logging.getLogger("Screen") + if self.debug: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [Screen] %(message)s", + datefmt="%H:%M:%S", + ) + else: + log.addHandler(logging.NullHandler()) + log.propagate = False + + cap_fps = self._CAPTURE_FPS + loop = asyncio.get_running_loop() + + with self._screenshot_backend_factory() as capture: + mons = capture.monitors[self._MON_START:] + + async def flush_pending() -> None: + if self._pending_event is None: + return + if self._skip(): + self._pending_event = None + return + + ev = self._pending_event + mon = mons[ev["mon"] - 1] + try: + aft = await self._run_in_thread(capture.grab, mon) + except Exception as e: + if self.debug: + log.error(f"Failed to capture 
after frame: {e}") + self._pending_event = None + return + + try: + scale = float(aft.width) / float(mon.get("width", aft.width) or aft.width) + if scale <= 0: + scale = 1.0 + except Exception: + scale = 1.0 + + bef_path = await self._save_frame( + ev["before"], + ev["position"][0], + ev["position"][1], + f"{ev['type']}_before@mon{ev['mon']}", + scale=scale, + ) + aft_path = await self._save_frame( + aft, + ev["position"][0], + ev["position"][1], + f"{ev['type']}_after@mon{ev['mon']}", + scale=scale, + ) + await self._process_and_emit(bef_path, aft_path, ev["type"], ev) + log.info(f"{ev['type']} captured on monitor {ev['mon']}") + self._pending_event = None + + async def scroll_event(x: float, y: float, dx: float, dy: float): + async with self._scroll_lock: + if not self._should_log_scroll(x, y, dx, dy): + if self.debug: + log.info(f"Scroll filtered out: dx={dx:.2f}, dy={dy:.2f}") + return + idx = self._mon_for(x, y, mons) + if idx is None: + return + mon = mons[idx - 1] + x_local = x - mon["left"] + y_local = y - mon["top"] + log.info( + f"scroll @({x_local:7.1f},{y_local:7.1f}) dx={dx:.2f} dy={dy:.2f} → mon={idx}" + ) + scroll_magnitude = (dx ** 2 + dy ** 2) ** 0.5 + if scroll_magnitude < 1.0: + if self.debug: + log.info(f"Scroll too small: magnitude={scroll_magnitude:.2f}") + return + if self._skip(): + return + async with self._frame_lock: + fr = self._frames.get(idx) + if fr is None: + return + try: + scale = float(fr.width) / float(mon.get("width", fr.width) or fr.width) + if scale <= 0: + scale = 1.0 + except Exception: + scale = 1.0 + + self._pending_event = { + "type": "scroll", + "position": (x_local, y_local), + "mon": idx, + "before": fr, + "scale": scale, + "scroll": (dx, dy), + "cursor": (x, y), + } + await flush_pending() + + async def mouse_event(x: float, y: float, typ: str): + idx = self._mon_for(x, y, mons) + if idx is None: + return + mon = mons[idx - 1] + x_local = x - mon["left"] + y_local = y - mon["top"] + log.info( + f"{typ:<6} @({x_local:7.1f},{y_local:7.1f}) → mon={idx} {'(guarded)' if self._skip() else ''}" + ) + if self._skip(): + return + async with self._frame_lock: + fr = self._frames.get(idx) + if fr is None: + return + try: + scale = float(fr.width) / float(mon.get("width", fr.width) or fr.width) + if scale <= 0: + scale = 1.0 + except Exception: + scale = 1.0 + + self._pending_event = { + "type": typ, + "position": (x_local, y_local), + "mon": idx, + "before": fr, + "scale": scale, + "cursor": (x, y), + } + await flush_pending() + + def schedule_mouse_click(x: float, y: float, typ: str) -> None: + asyncio.run_coroutine_threadsafe(mouse_event(x, y, typ), loop) + + def schedule_mouse_scroll(x: float, y: float, dx: float, dy: float) -> None: + asyncio.run_coroutine_threadsafe(scroll_event(x, y, dx, dy), loop) + + try: + self._mouse_backend.start(schedule_mouse_click, schedule_mouse_scroll) + self._mouse_started = True + except Exception as exc: + self._mouse_started = False + log.warning(f"Failed to start mouse backend: {exc}") + + async def key_token_event(tok: str, typ: str): + controller = mouse.Controller() + x, y = controller.position + idx = self._mon_for(x, y, mons) + if idx is None: + return + mon = mons[idx - self._MON_START] + x_local = x - mon["left"] + y_local = y - mon["top"] + if self._keystroke_log_path: + try: + os.makedirs(os.path.dirname(self._keystroke_log_path), exist_ok=True) + with open(self._keystroke_log_path, "a") as f: + f.write(f"{datetime.now().isoformat()}\t{typ}\t{tok}\n") + except Exception: + pass + + step = 
f"key_{typ}@mon{idx}({tok})" + app_name, browser_url = self._snapshot_app_context((x, y)) + await self.update_queue.put( + Update( + content=step, + content_type="input_text", + app_name=app_name, + browser_url=browser_url, + ) + ) + + async with self._key_activity_lock: + current_time = time.time() + try: + async with self._frame_lock: + fr = self._frames.get(idx) + scale = float(fr.width) / float(mon.get("width", fr.width) or fr.width) if fr else 1.0 + except Exception: + scale = 1.0 + + if ( + self._key_activity_start is None + or current_time - self._key_activity_start > self._key_activity_timeout + ): + self._key_activity_start = current_time + self._key_screenshots = [] + if fr is not None: + screenshot_path = await self._save_frame(fr, x_local, y_local, f"{step}_first", scale=scale) + self._key_screenshots.append(screenshot_path) + self._last_key_screenshot_time = current_time + self._last_key_position = (x_local, y_local, idx) + else: + should_sample = ( + self._last_key_screenshot_time is None + or (current_time - self._last_key_screenshot_time) >= self._keyboard_sample_interval_sec + ) + if should_sample and fr is not None: + screenshot_path = await self._save_frame(fr, x_local, y_local, f"{step}_intermediate", scale=scale) + self._key_screenshots.append(screenshot_path) + self._last_key_screenshot_time = current_time + self._last_key_position = (x_local, y_local, idx) + + if len(self._key_screenshots) > 2: + asyncio.create_task(self._cleanup_key_screenshots()) + + async def _handle_key_token_event(tok: str, typ: str): + await key_token_event(tok, typ) + + self.handle_key_token_event = _handle_key_token_event # type: ignore[attr-defined] + + def schedule_key_token_event(tok: str, typ: str) -> None: + asyncio.run_coroutine_threadsafe(key_token_event(tok, typ), loop) + + # a bit hacky, but this helps us run the keyboard in the main thread when running the app + # (otherwise, on macOS running in the background will crash...) 
+ disable_keyboard_env = os.environ.get("GUM_DISABLE_KEYBOARD", "").strip() == "1" + if not disable_keyboard_env: + try: + self._keyboard_backend.start(schedule_key_token_event) + except Exception as exc: + log.warning(f"Keyboard backend unavailable: {exc}") + else: + log.info("Keyboard backend disabled via GUM_DISABLE_KEYBOARD=1; using main-thread shim events only") + + log.info(f"Screen observer started — guarding {self._guard or '∅'}") + frame_count = 0 + + while self._running: + t0 = time.time() + + for idx, mon in enumerate(mons, 1): + old_frame = None + async with self._frame_lock: + old_frame = self._frames.get(idx) + try: + frame = await self._run_in_thread(capture.grab, mon) + except Exception as e: + if self.debug: + log.error(f"Failed to capture frame: {e}") + continue + + async with self._frame_lock: + self._frames[idx] = frame + + if old_frame is not None: + del old_frame + + frame_count += 1 + if frame_count % self._MEMORY_CLEANUP_INTERVAL == 0: + await self._run_in_thread(gc.collect) + + current_time = time.time() + if ( + self._key_activity_start is not None + and current_time - self._key_activity_start > self._key_activity_timeout + and len(self._key_screenshots) >= 1 + ): + async with self._key_activity_lock: + try: + if self._last_key_position is not None: + lx, ly, lidx = self._last_key_position + mon = mons[lidx - self._MON_START] + async with self._frame_lock: + fr = self._frames.get(lidx) + scale = ( + float(fr.width) / float(mon.get("width", fr.width) or fr.width) + if fr + else 1.0 + ) + if fr is not None: + final_path = await self._save_frame(fr, lx, ly, f"key_final@mon{lidx}", scale=scale) + self._key_screenshots.append(final_path) + except Exception: + pass + await self._cleanup_key_screenshots() + self._key_activity_start = None + self._key_screenshots = [] + self._last_key_screenshot_time = None + self._last_key_position = None + + dt = time.time() - t0 + await asyncio.sleep(max(0, (1 / cap_fps) - dt)) + + try: + await flush_pending() + finally: + try: + self._keyboard_backend.stop() + except Exception: + pass + try: + self._mouse_backend.stop() + except Exception: + pass + + if self._key_activity_start is not None and len(self._key_screenshots) > 1: + async with self._key_activity_lock: + last_path = self._key_screenshots[-1] + final_path = last_path.replace("_intermediate", "_final") + try: + await self._run_in_thread(os.rename, last_path, final_path) + log.info(f"Final keyboard session cleanup, renamed: {final_path}") + except OSError: + pass + await self._cleanup_key_screenshots() diff --git a/record/gum/observers/base/screenshots.py b/record/gum/observers/base/screenshots.py new file mode 100644 index 0000000..a57cb2d --- /dev/null +++ b/record/gum/observers/base/screenshots.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Protocol, Sequence + +import mss + + +class ScreenshotBackend(Protocol): + """Abstract interface for grabbing screen frames.""" + + def __enter__(self) -> "ScreenshotBackend": + ... + + def __exit__(self, exc_type, exc, tb) -> None: + ... + + @property + def monitors(self) -> Sequence[dict]: + ... + + def grab(self, monitor: dict): + ... 
+ + +class MssScreenshotBackend: + """Minimal wrapper around ``mss.mss`` exposing the backend protocol.""" + + def __init__(self, **kwargs) -> None: + self._kwargs = kwargs + self._mss: mss.mss | None = None + + def __enter__(self) -> "MssScreenshotBackend": + self._mss = mss.mss(**self._kwargs) + return self + + def __exit__(self, exc_type, exc, tb) -> None: + if self._mss is not None: + close = getattr(self._mss, "close", None) + if callable(close): + close() + self._mss = None + + @property + def monitors(self) -> Sequence[dict]: + if self._mss is None: + raise RuntimeError("Screenshot backend not initialised; use as a context manager") + return self._mss.monitors + + def grab(self, monitor: dict): + if self._mss is None: + raise RuntimeError("Screenshot backend not initialised; use as a context manager") + return self._mss.grab(monitor) diff --git a/record/gum/observers/constants.py b/record/gum/observers/constants.py new file mode 100644 index 0000000..af7d906 --- /dev/null +++ b/record/gum/observers/constants.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +# Shared defaults used by observer backends and platform integrations + +# Screen capture and housekeeping +CAPTURE_FPS_DEFAULT: int = 5 +MEMORY_CLEANUP_INTERVAL_DEFAULT: int = 30 +MON_START_INDEX: int = 1 + +# Keyboard activity sampling +KEYBOARD_TIMEOUT_DEFAULT: float = 2.0 +KEYBOARD_SAMPLE_INTERVAL_DEFAULT: float = 0.25 + +# Scroll filtering +SCROLL_DEBOUNCE_SEC_DEFAULT: float = 0.5 +SCROLL_MIN_DISTANCE_DEFAULT: float = 5.0 +SCROLL_MAX_FREQUENCY_DEFAULT: int = 10 +SCROLL_SESSION_TIMEOUT_DEFAULT: float = 2.0 + +# macOS app UI intervals (milliseconds) +KEYBOARD_PUMP_INTERVAL_MS: int = 50 +PERMISSION_REFRESH_INTERVAL_MS: int = 3000 + + diff --git a/record/gum/observers/fallback/__init__.py b/record/gum/observers/fallback/__init__.py new file mode 100644 index 0000000..63c0327 --- /dev/null +++ b/record/gum/observers/fallback/__init__.py @@ -0,0 +1,17 @@ +"""Platform-agnostic observer helpers and fallbacks.""" + +from .keyboard import PynputKeyboardBackend +from .mouse import PynputMouseBackend +from .screenshots import FallbackScreenshotBackend +from .app_and_browser_logging import ( + FallbackAppAndBrowserInspector, + check_automation_permission_granted as fallback_check_automation_permission_granted, +) + +__all__ = [ + "PynputKeyboardBackend", + "PynputMouseBackend", + "FallbackScreenshotBackend", + "FallbackAppAndBrowserInspector", + "fallback_check_automation_permission_granted", +] diff --git a/record/gum/observers/fallback/app_and_browser_logging.py b/record/gum/observers/fallback/app_and_browser_logging.py new file mode 100644 index 0000000..32f615b --- /dev/null +++ b/record/gum/observers/fallback/app_and_browser_logging.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Optional + + +class FallbackAppAndBrowserInspector: # pragma: no cover - placeholder behaviour + """Fallback stub used on platforms without UI automation support.""" + + def __init__(self, logger) -> None: + self.logger = logger + self.last_frontmost_bundle_id: Optional[str] = None + + def get_frontmost_app_name(self) -> Optional[str]: + return None + + def get_browser_url(self, app_name: Optional[str]) -> Optional[str]: + return None + + def snapshot_running_browsers(self): # type: ignore[override] + return [] + + def prime_automation_for_running_browsers(self) -> bool: + return False + + +def check_automation_permission_granted(force_refresh: bool = False) -> Optional[bool]: + """Always returns *None* to indicate 
unknown automation status.""" + return None + + +__all__ = ["FallbackAppAndBrowserInspector", "check_automation_permission_granted"] diff --git a/record/gum/observers/fallback/keyboard.py b/record/gum/observers/fallback/keyboard.py new file mode 100644 index 0000000..b32f438 --- /dev/null +++ b/record/gum/observers/fallback/keyboard.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from typing import Optional + +from pynput import keyboard + +from ..base.keyboard import KeyDispatch, KeyboardBackend + + +def _token_from_key(key: keyboard.Key | keyboard.KeyCode) -> Optional[str]: + try: + if isinstance(key, keyboard.KeyCode): + if key.char: + return f"TEXT:{key.char}" + if key.vk is not None: + return f"VK:{key.vk}" + elif isinstance(key, keyboard.Key): + name = getattr(key, "name", None) + if name: + return f"KEY:{name}" + return f"KEY:{str(key)}" + except Exception: + return None + return None + + +class PynputKeyboardBackend(KeyboardBackend): + """Simple keyboard listener that reports events via a callback.""" + + def __init__(self) -> None: + self._listener: Optional[keyboard.Listener] = None + self._dispatch: Optional[KeyDispatch] = None + + def start(self, dispatch: KeyDispatch) -> None: + self._dispatch = dispatch + if self._listener is not None: + return + + def _emit(event_key, event_type: str) -> None: + token = _token_from_key(event_key) + if not token or self._dispatch is None: + return + try: + self._dispatch(token, event_type) + except Exception: + pass + + self._listener = keyboard.Listener( + on_press=lambda key: _emit(key, "press"), + on_release=lambda key: _emit(key, "release"), + ) + self._listener.start() + + def stop(self) -> None: + if self._listener is None: + return + try: + self._listener.stop() + finally: + self._listener = None + self._dispatch = None diff --git a/record/gum/observers/fallback/mouse.py b/record/gum/observers/fallback/mouse.py new file mode 100644 index 0000000..f5a6569 --- /dev/null +++ b/record/gum/observers/fallback/mouse.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Optional + +from pynput import mouse as pynput_mouse + +from ..base.mouse import MouseBackend, MouseClickHandler, MouseScrollHandler + + +class PynputMouseBackend(MouseBackend): + """Mouse backend built on ``pynput``'s global listener.""" + + def __init__(self) -> None: + self._listener: Optional[pynput_mouse.Listener] = None + + def start(self, on_click: MouseClickHandler, on_scroll: MouseScrollHandler) -> None: + if self._listener is not None: + return + + def handle_click(x, y, button, pressed): + if not pressed: + return + name = getattr(button, "name", str(button)) + on_click(x, y, f"click_{name}") + + self._listener = pynput_mouse.Listener( + on_click=handle_click, + on_scroll=lambda x, y, dx, dy: on_scroll(x, y, dx, dy), + ) + self._listener.start() + + def stop(self) -> None: + if self._listener is None: + return + try: + self._listener.stop() + finally: + self._listener = None diff --git a/record/gum/observers/fallback/screenshots.py b/record/gum/observers/fallback/screenshots.py new file mode 100644 index 0000000..dafc010 --- /dev/null +++ b/record/gum/observers/fallback/screenshots.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from ..base.screenshots import MssScreenshotBackend + + +class FallbackScreenshotBackend(MssScreenshotBackend): + """Default cross-platform screenshot backend using ``mss``.""" + + def __init__(self) -> None: + super().__init__() diff --git a/record/gum/observers/macos/__init__.py 
b/record/gum/observers/macos/__init__.py new file mode 100644 index 0000000..f940463 --- /dev/null +++ b/record/gum/observers/macos/__init__.py @@ -0,0 +1,19 @@ +"""macOS-specific observer implementations.""" + +from .keyboard import AppKitKeyboardBackend, MacKeyboardBackend +from .mouse import MacMouseBackend +from .screenshots import MacScreenshotBackend, is_app_visible +from .app_and_browser_logging import MacOSAppAndBrowserInspector, check_automation_permission_granted + +AppleUIInspector = MacOSAppAndBrowserInspector + +__all__ = [ + "AppKitKeyboardBackend", + "MacKeyboardBackend", + "MacMouseBackend", + "MacScreenshotBackend", + "is_app_visible", + "MacOSAppAndBrowserInspector", + "AppleUIInspector", + "check_automation_permission_granted", +] diff --git a/record/gum/observers/macos/app_and_browser_logging.py b/record/gum/observers/macos/app_and_browser_logging.py new file mode 100644 index 0000000..aa15e1f --- /dev/null +++ b/record/gum/observers/macos/app_and_browser_logging.py @@ -0,0 +1,1056 @@ +from __future__ import annotations + +import logging +import subprocess +import time +from dataclasses import dataclass, field +from typing import Optional + +try: + import Quartz +except Exception: + Quartz = None + +try: + from AppKit import NSWorkspace, NSRunningApplication +except Exception: + NSWorkspace = None + NSRunningApplication = None + + +__all__ = [ + "MacOSAppAndBrowserInspector", + "check_automation_permission_granted", +] + + +SYSTEM_WINDOW_OWNERS = {"dock", "windowserver", "window server"} +OVERLAY_WINDOW_OWNERS = { + "control center", + "notification center", + "notificationcentre", + "notificationcentreui", + "screen recording", + "screenrecording", + "screenrecordingindicator", + "screenrecordingui", + "screencapture", + "screenshotui", + "siri", +} + + +def _visible_tab_jxa_script(app_title: str, tab_accessor: str) -> str: + """Generate a JavaScript for Automation (JXA) script to get the URL of the foremost visible tab. + + This function creates a robust JXA script that finds the most prominent browser window + and extracts the URL from its active tab. It's used for browsers that support JXA + automation and provides more reliable tab detection than simple AppleScript. 
+ + The script handles various edge cases: + - Multiple windows with different visibility states + - Minimized or hidden windows + - Different tab accessor methods (activeTab vs currentTab) + + Args: + app_title: The application name (e.g., "Safari", "Google Chrome") + tab_accessor: The method to access the active tab (e.g., "activeTab", "currentTab") + + Returns: + A complete JXA script as a string + """ + return ( + "jxa:(() => {\n" + f" const app = Application(\"{app_title}\");\n" + " if (!app.running()) { return null; }\n" + " const wins = app.windows();\n" + " if (!wins.length) { return null; }\n" + " const ordered = wins\n" + " .map(win => ({\n" + " win,\n" + " order: (() => {\n" + " try { return win.index(); } catch (err) { return Number.POSITIVE_INFINITY; }\n" + " })(),\n" + " }))\n" + " .sort((a, b) => a.order - b.order);\n" + " for (const entry of ordered) {\n" + " const win = entry.win;\n" + " try { if (win.miniaturized && win.miniaturized()) { continue; } } catch (err) {}\n" + " try { if (win.visible && !win.visible()) { continue; } } catch (err) {}\n" + f" if (typeof win.{tab_accessor} !== 'function') {{ continue; }}\n" + " let tab = null;\n" + f" try {{ tab = win.{tab_accessor}(); }} catch (err) {{ tab = null; }}\n" + " if (!tab) { continue; }\n" + " try {\n" + " if (typeof tab.url === 'function') { return tab.url(); }\n" + " if (typeof tab.URL === 'function') { return tab.URL(); }\n" + " } catch (err) {}\n" + " }\n" + " return null;\n" + "})();" + ) + + +def _chromium_jxa_script(app_title: str) -> str: + """Generate a JXA script for Chromium-based browsers (Chrome, Edge, Brave, etc.). + + Chromium-based browsers use 'activeTab' to access the current tab, which is different + from Safari's 'currentTab' method. This function creates the appropriate script + for this browser family. + + Args: + app_title: The application name (e.g., "Google Chrome", "Microsoft Edge") + + Returns: + A JXA script configured for Chromium-based browsers + """ + return _visible_tab_jxa_script(app_title, "activeTab") + + +def _safari_jxa_script(app_title: str) -> str: + """Generate a JXA script for Safari-based browsers. + + Safari uses 'currentTab' to access the current tab, which is different from + Chromium-based browsers. This function creates the appropriate script for + Safari and Safari Technology Preview. + + Args: + app_title: The application name (e.g., "Safari", "Safari Technology Preview") + + Returns: + A JXA script configured for Safari-based browsers + """ + return _visible_tab_jxa_script(app_title, "currentTab") + + +def _focused_app_via_accessibility() -> tuple[Optional[str], Optional[str], Optional[int]]: + """Attempt to resolve the focused app using the accessibility API + (i.e. the app that is currently in the user's foreground, on the monitor they are currently using... this is tricky...). + + This function is the primary method for detecting which application currently has + keyboard focus on macOS. It uses the accessibility API to get the most accurate + information about the focused application, which is crucial for understanding + user context and activity tracking. + + The function implements a multi-layered approach with extensive error handling: + + 1. **System-wide Element Creation**: Creates a system-wide accessibility element + that can query the focused application. This can fail if accessibility + permissions are not granted. + + 2. **Focused Application Query**: Uses AXUIElementCopyAttributeValue to get the + currently focused application reference. 
Handles different API signatures + across macOS versions (3-parameter vs 2-parameter versions). + + 3. **Process ID Extraction**: Gets the process ID from the application reference, + which is needed for further app information retrieval. + + 4. **App Information via NSRunningApplication**: Uses AppKit to get detailed + app information (name, bundle ID) from the process ID. This is the most + reliable method when available. + + 5. **Window List Fallback**: If AppKit fails, falls back to querying the window + list to find the app name by matching the process ID to window owners. + + The extensive exception handling ensures the function never crashes, even when: + - Accessibility permissions are denied + - Apps quit or change focus during execution + - Different macOS versions have API variations + - System is in various states (sleep, locked, etc.) + + Returns: + Tuple of (app_name, bundle_id, pid): + - app_name: Human-readable application name (e.g., "Safari", "Google Chrome") + - bundle_id: Application bundle identifier (e.g., "com.apple.safari") + - pid: Process ID of the application + + All values may be None if detection fails at any stage. + """ + + if Quartz is None: + return None, None, None + + required = ( + "AXUIElementCreateSystemWide", + "AXUIElementCopyAttributeValue", + "AXUIElementGetPid", + "kAXFocusedApplicationAttribute", + "kAXErrorSuccess", + ) + if any(not hasattr(Quartz, attr) for attr in required): + return None, None, None + + try: + system = Quartz.AXUIElementCreateSystemWide() + except Exception: # Catches accessibility permission denied, system-level failures + return None, None, None + + if system is None: + return None, None, None + + try: + result = Quartz.AXUIElementCopyAttributeValue( + system, + Quartz.kAXFocusedApplicationAttribute, + None, + ) + except TypeError: # Catches API signature mismatch (3-param vs 2-param versions) + try: + result = Quartz.AXUIElementCopyAttributeValue( + system, + Quartz.kAXFocusedApplicationAttribute, + ) + except Exception: # Catches any failure in the 2-parameter fallback + return None, None, None + except Exception: # Catches accessibility API failures, permission issues + return None, None, None + + if isinstance(result, tuple): + app_ref, error = result + if error not in (None, Quartz.kAXErrorSuccess): + return None, None, None + else: + app_ref = result + + if app_ref is None: + return None, None, None + + try: + pid = Quartz.AXUIElementGetPid(app_ref) + except Exception: # Catches invalid app reference, terminated processes + pid = None + + if not pid: + return None, None, None + + name: Optional[str] = None + bundle: Optional[str] = None + + if NSRunningApplication is not None: + try: + app = NSRunningApplication.runningApplicationWithProcessIdentifier_(pid) + except Exception: # Catches invalid PID, app quit, AppKit failures + app = None + if app is not None: + bundle_candidate = (app.bundleIdentifier() or "").strip() or None + name_candidate = (app.localizedName() or "").strip() or None + bundle = bundle_candidate or bundle + name = name_candidate or bundle_candidate or name + if name: + return name, bundle, pid + + try: + opts = ( + Quartz.kCGWindowListOptionOnScreenOnly + | Quartz.kCGWindowListOptionIncludingWindow + ) + wins = Quartz.CGWindowListCopyWindowInfo(opts, Quartz.kCGNullWindowID) or [] + for info in wins: + if info.get("kCGWindowOwnerPID") == pid: + owner = (info.get("kCGWindowOwnerName") or "").strip() + bundle_candidate = (info.get("kCGWindowOwnerBundleIdentifier") or "").strip() or None + if owner: + 
return owner, (bundle_candidate or bundle), pid + except Exception: # Catches window list API failures, system state issues + pass + + return None, bundle, pid + + +BROWSER_OSA_SCRIPTS: dict[str, tuple[str, ...]] = { + "safari": ( + _safari_jxa_script("Safari"), + 'tell application "Safari" to if (exists front document) then get URL of front document', + 'tell application "Safari" to if (count of documents) > 0 then get URL of front document', + 'jxa:(() => {\n const safari = Application("Safari");\n if (!safari.running()) { return null; }\n const docs = safari.documents();\n return docs.length ? docs[0].url() : null;\n})();', + ), + "safari technology preview": ( + _safari_jxa_script("Safari Technology Preview"), + 'tell application "Safari Technology Preview" to if (exists front document) then get URL of front document', + 'jxa:(() => {\n const safari = Application("Safari Technology Preview");\n if (!safari.running()) { return null; }\n const docs = safari.documents();\n return docs.length ? docs[0].url() : null;\n})();', + ), + "google chrome": ( + _chromium_jxa_script("Google Chrome"), + 'tell application "Google Chrome" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "google chrome beta": ( + _chromium_jxa_script("Google Chrome Beta"), + 'tell application "Google Chrome Beta" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "google chrome dev": ( + _chromium_jxa_script("Google Chrome Dev"), + 'tell application "Google Chrome Dev" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "google chrome canary": ( + _chromium_jxa_script("Google Chrome Canary"), + 'tell application "Google Chrome Canary" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "chromium": ( + _chromium_jxa_script("Chromium"), + 'tell application "Chromium" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "brave browser": ( + _chromium_jxa_script("Brave Browser"), + 'tell application "Brave Browser" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "brave browser beta": ( + _chromium_jxa_script("Brave Browser Beta"), + 'tell application "Brave Browser Beta" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "brave browser nightly": ( + _chromium_jxa_script("Brave Browser Nightly"), + 'tell application "Brave Browser Nightly" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "microsoft edge": ( + _chromium_jxa_script("Microsoft Edge"), + 'tell application "Microsoft Edge" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "microsoft edge beta": ( + _chromium_jxa_script("Microsoft Edge Beta"), + 'tell application "Microsoft Edge Beta" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "microsoft edge dev": ( + _chromium_jxa_script("Microsoft Edge Dev"), + 'tell application "Microsoft Edge Dev" to if (count of windows) > 0 then get URL of active tab of front window', + ), + "arc": ( + _chromium_jxa_script("Arc"), + 'tell application "Arc"\nif (count of windows) = 0 then return missing value\nreturn URL of active tab of front window\nend tell', + 'tell application "Arc"\nif (count of windows) = 0 then return missing value\ntell front window to tell active tab to return URL\nend tell', + 'jxa:(() => {\n const arc = Application("Arc");\n const wins = arc.windows();\n if (!wins.length) { return null; }\n const front = 
arc.frontWindow();\n if (!front) { return null; }\n const tab = front.activeTab();\n return tab ? tab.url() : null;\n})();', + ), +} + +BROWSER_BUNDLE_IDS: dict[str, str] = { + "com.apple.safari": "safari", + "com.apple.safaritechnologypreview": "safari technology preview", + "com.google.chrome": "google chrome", + "com.google.chrome.beta": "google chrome beta", + "com.google.chrome.dev": "google chrome dev", + "com.google.chrome.canary": "google chrome canary", + "org.chromium.chromium": "chromium", + "com.brave.browser": "brave browser", + "com.brave.browser.beta": "brave browser beta", + "com.brave.browser.nightly": "brave browser nightly", + "com.microsoft.edgemac": "microsoft edge", + "com.microsoft.edgemac.beta": "microsoft edge beta", + "com.microsoft.edgemac.dev": "microsoft edge dev", + "company.thebrowser.browser": "arc", +} + + + +AUTOMATION_PERMISSION_CACHE_TTL = 3.0 +_AUTOMATION_PERMISSION_CACHE: tuple[float, bool | None] | None = None +_AUTOMATION_DENIED_SUBSTRINGS = ( + "-1743", + "not authorized to send apple events", + "not authorised to send apple events", + "errae eventnotpermitted", +) + +_NULL_URL_VALUES = { + "", + "null", + "missing value", + "undefined", +} + + +def _ax_copy_attribute(element, attribute): + """Safely copy an attribute value from an accessibility element. + + This helper function provides robust access to accessibility element attributes + with proper error handling for different API versions and failure modes. + It's used throughout the accessibility-based app detection system. + + Args: + element: The accessibility element to query + attribute: The attribute name to retrieve + + Returns: + The attribute value, or None if retrieval fails + """ + if Quartz is None: + return None + try: + value = Quartz.AXUIElementCopyAttributeValue(element, attribute, None) + except TypeError: + try: + value = Quartz.AXUIElementCopyAttributeValue(element, attribute) + except Exception: + return None + except Exception: + return None + if isinstance(value, tuple): + candidate, error = value + if error not in (None, getattr(Quartz, "kAXErrorSuccess", None)): + return None + value = candidate + return value + + +def _browser_url_via_accessibility(pid: Optional[int]) -> Optional[str]: + """Extract browser URL using accessibility API as a fallback method. + + This function provides an alternative method to get browser URLs when AppleScript + automation fails or is not available. It uses the accessibility API to find the + focused window and extract URL information from browser applications. 
+ + This is particularly useful for: + - Browsers that don't support AppleScript automation + - Systems where automation permissions are denied + - Fallback when primary URL extraction methods fail + + Args: + pid: Process ID of the browser application + + Returns: + The current URL from the browser, or None if extraction fails + """ + if Quartz is None or not pid: + return None + try: + app_elem = Quartz.AXUIElementCreateApplication(pid) + except Exception: + return None + if app_elem is None: + return None + + window = _ax_copy_attribute(app_elem, "AXFocusedWindow") + if window is None: + return None + + for attr in ("AXURL", "AXDocument"): + value = _ax_copy_attribute(window, attr) + if value is None: + continue + if isinstance(value, str): + if value: + return value + continue + try: + string_value = str(value) + except Exception: + continue + if string_value: + return string_value + return None + + +@dataclass +class MacOSAppAndBrowserInspector: + """macOS-specific implementation for detecting applications and extracting browser URLs. + + This class provides comprehensive app and browser detection capabilities for macOS, + enabling the recording system to understand what applications users are interacting + with and what web pages they're viewing. It's a critical component of the user + activity tracking system. + + Key capabilities: + - Detect the currently focused/active application + - Extract URLs from browser applications using AppleScript automation + - Fallback to accessibility API when automation is not available + - Handle automation permission management and error reporting + - Support for major browsers (Safari, Chrome, Edge, Brave, Arc, etc.) + + The class maintains state to optimize performance and provide consistent results + across multiple queries, including caching of browser URLs and tracking of + permission issues. + """ + logger: logging.Logger + last_frontmost_bundle_id: Optional[str] = None + last_frontmost_pid: Optional[int] = None + last_browser_urls: dict[str, str] = field(default_factory=dict) + unknown_browser_apps: set[str] = field(default_factory=set) + browser_script_failures: set[tuple[str, str, int]] = field(default_factory=set) + automation_denied_for: set[str] = field(default_factory=set) + + def app_at_point(self, x: float, y: float) -> tuple[Optional[str], Optional[str]]: + """Determine which application owns the window at the given screen coordinates. + + This method is used to identify which app the user is interacting with when they + click or interact at specific screen coordinates. It's part of the app detection + system that helps track user activity across different applications. 
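+
+        Example (illustrative sketch; ``inspector`` is an instance of this class):
+            name, bundle = inspector.app_at_point(512.0, 384.0)
+            # e.g. ("Safari", "com.apple.Safari"), or (None, None) when no regular
+            # window covers that point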
+ + Args: + x: Screen X coordinate + y: Screen Y coordinate + + Returns: + Tuple of (app_name, bundle_id) for the app at the given point, or (None, None) if not found + """ + if Quartz is None: + return None, None + + try: + wins = Quartz.CGWindowListCopyWindowInfo( + Quartz.kCGWindowListOptionOnScreenOnly | Quartz.kCGWindowListOptionIncludingWindow, + Quartz.kCGNullWindowID, + ) or [] + except Exception: + return None, None + + for info in wins: + owner = info.get("kCGWindowOwnerName") or "" + owner_lower = owner.lower() + if not owner or owner_lower in SYSTEM_WINDOW_OWNERS or owner_lower in OVERLAY_WINDOW_OWNERS: + continue + + bounds = info.get("kCGWindowBounds") or {} + left = int(bounds.get("X", 0) or 0) + top = int(bounds.get("Y", 0) or 0) + width = int(bounds.get("Width", 0) or 0) + height = int(bounds.get("Height", 0) or 0) + if width <= 0 or height <= 0: + continue + + if not (left <= int(x) < left + width and top <= int(y) < top + height): + continue + + bundle = (info.get("kCGWindowOwnerBundleIdentifier") or "").strip() or None + if bundle: + self.last_frontmost_bundle_id = bundle.lower() + pid_candidate = info.get("kCGWindowOwnerPID") + if pid_candidate is not None: + try: + self.last_frontmost_pid = int(pid_candidate) + except Exception: + self.last_frontmost_pid = None + return owner, bundle + + return None, None + + def get_frontmost_app_name(self) -> Optional[str]: + """Get the name of the currently focused/active application. + + This is the primary method for determining which app the user is currently using. + It tries multiple approaches in order of reliability: + 1. Accessibility API (most accurate for focused apps) + 2. NSWorkspace frontmostApplication (AppKit fallback) + 3. Window list analysis (finds largest visible window) + + The result is used throughout the system to track which application context + user interactions are happening in, enabling proper app-specific logging. 
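+
+        Example (illustrative sketch; ``inspector`` is an instance of this class):
+            app_name = inspector.get_frontmost_app_name()
+            # e.g. "Safari"; the bundle id and pid are also cached on the inspector
+            # for later browser URL lookups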
+ + Returns: + The name of the frontmost application, or None if detection fails + """ + name, bundle, pid = _focused_app_via_accessibility() + if bundle: + self.last_frontmost_bundle_id = bundle.lower() + if pid: + self.last_frontmost_pid = pid + if name: + return name + + if NSWorkspace is not None: + try: + app = NSWorkspace.sharedWorkspace().frontmostApplication() + if app is not None: + bundle_id = (app.bundleIdentifier() or "").strip() or None + self.last_frontmost_bundle_id = bundle_id.lower() if bundle_id else None + name = (app.localizedName() or bundle_id or "").strip() + if name: + return name + except Exception: + self.last_frontmost_bundle_id = None + + if Quartz is None: + self.last_frontmost_bundle_id = None + return None + + try: + opts = ( + Quartz.kCGWindowListOptionOnScreenOnly + | Quartz.kCGWindowListOptionIncludingWindow + ) + wins = Quartz.CGWindowListCopyWindowInfo(opts, Quartz.kCGNullWindowID) or [] + except Exception: + self.last_frontmost_bundle_id = None + return None + + self.last_frontmost_bundle_id = None + self.last_frontmost_pid = None + + def record_owner(owner: str, info: dict) -> str: + bundle_candidate = (info.get("kCGWindowOwnerBundleIdentifier") or "").strip() or None + pid_candidate = info.get("kCGWindowOwnerPID") + if bundle_candidate: + self.last_frontmost_bundle_id = bundle_candidate.lower() + if pid_candidate is not None: + try: + self.last_frontmost_pid = int(pid_candidate) + except Exception: + self.last_frontmost_pid = None + return owner + + for info in wins: + owner = info.get("kCGWindowOwnerName") or "" + owner_lower = owner.lower() + if not owner or owner_lower in SYSTEM_WINDOW_OWNERS or owner_lower in OVERLAY_WINDOW_OWNERS: + continue + + layer = info.get("kCGWindowLayer", 0) + if layer != 0: + continue + + alpha = info.get("kCGWindowAlpha", 1) + if alpha == 0: + continue + + bounds = info.get("kCGWindowBounds") or {} + width = max(int(bounds.get("Width", 0) or 0), 0) + height = max(int(bounds.get("Height", 0) or 0), 0) + if width < 32 or height < 32: + continue + + return record_owner(owner, info) + + candidate_name: Optional[str] = None + candidate_info: Optional[dict] = None + candidate_area = 0 + + for info in wins: + owner = info.get("kCGWindowOwnerName") or "" + owner_lower = owner.lower() + if not owner or owner_lower in SYSTEM_WINDOW_OWNERS or owner_lower in OVERLAY_WINDOW_OWNERS: + continue + + bounds = info.get("kCGWindowBounds") or {} + width = max(int(bounds.get("Width", 0) or 0), 0) + height = max(int(bounds.get("Height", 0) or 0), 0) + area = width * height + if area <= 0: + continue + + if area > candidate_area: + candidate_name = owner + candidate_info = info + candidate_area = area + + if candidate_name and candidate_info: + return record_owner(candidate_name, candidate_info) + + for info in wins: + owner = info.get("kCGWindowOwnerName") or "" + owner_lower = owner.lower() + if owner and owner_lower not in SYSTEM_WINDOW_OWNERS and owner_lower not in OVERLAY_WINDOW_OWNERS: + return record_owner(owner, info) + + return None + + def get_browser_url(self, app_name: Optional[str]) -> Optional[str]: + """Get the current URL from a browser application. + + This method is the core of browser URL tracking functionality. It determines + what webpage the user is currently viewing in their browser, which is crucial + for understanding the context of user interactions (e.g., what site they're + browsing when they click or type). + + The method tries multiple approaches: + 1. 
AppleScript automation (primary method for most browsers) + 2. Accessibility API fallback (for browsers that don't support AppleScript) + 3. Cached results (for performance) + + Args: + app_name: Name of the browser application to query + + Returns: + The current URL being viewed, or None if not available or not a browser + """ + if not app_name: + return None + + log = self.logger + key = self._resolve_browser_key(app_name) + if key is None: + signature = f"{app_name}|{self.last_frontmost_bundle_id or ''}" + if signature not in self.unknown_browser_apps: + self.unknown_browser_apps.add(signature) + log.info( + "No browser URL mapping available for '%s' (bundle=%s)", + app_name, + self.last_frontmost_bundle_id, + ) + return None + + result = self._run_browser_scripts(app_name, key) + if result: + self.last_browser_urls[key] = result + return result + + fallback = _browser_url_via_accessibility(self.last_frontmost_pid) + if fallback: + self.last_browser_urls[key] = fallback + return fallback + + return self.last_browser_urls.get(key) + + def _run_browser_scripts(self, app_name: str, key: str) -> Optional[str]: + """Execute AppleScript commands to extract the current URL from a browser. + + This method handles the low-level execution of AppleScript commands that query + browser applications for their current URL. It's a critical component of the + browser URL tracking system, handling multiple script formats (AppleScript and JXA) + and managing automation permissions. + + The method includes comprehensive error handling for: + - Automation permission denials + - Script execution failures + - Timeout handling + - Different browser script formats + + Args: + app_name: Human-readable name of the browser application + key: Internal key identifying the browser type (e.g., 'safari', 'google chrome') + + Returns: + The current URL from the browser, or None if extraction fails + """ + log = self.logger + scripts = BROWSER_OSA_SCRIPTS.get(key, ()) + if not scripts: + if key not in self.unknown_browser_apps: + self.unknown_browser_apps.add(key) + log.info("No AppleScript candidates registered for '%s'", key) + return None + + try: + for script in scripts: + cmd = ["osascript"] + if script.startswith("jxa:"): + body = script[4:] + cmd.extend(["-l", "JavaScript", "-e", body]) + else: + cmd.extend(["-e", script]) + + out = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=0.75, + check=False, + ) + stdout = (out.stdout or "").strip() + if out.returncode == 0 and stdout and stdout.strip().lower() not in _NULL_URL_VALUES: + _update_automation_cache(True) + self.last_browser_urls[key] = stdout + return stdout + + failure_sig = (key, script, out.returncode) + if failure_sig not in self.browser_script_failures: + self.browser_script_failures.add(failure_sig) + stderr = (out.stderr or "").strip() + if _automation_denied_from_stderr(stderr): + _update_automation_cache(False) + if key not in self.automation_denied_for: + self.automation_denied_for.add(key) + log.warning( + "Automation permission denied when querying '%s'. " + "Enable Gum Recorder (or osascript) in System Settings → Privacy & Security → Automation.", + app_name, + ) + return None + level = logging.WARNING if out.returncode not in (0, -128) else logging.DEBUG + log.log( + level, + "Browser URL script failed for '%s'. 
key=%s exit=%s stderr=%s", + app_name, + key, + out.returncode, + stderr, + ) + return None + except FileNotFoundError: + log.warning("'osascript' binary not found when attempting to read browser URL") + return None + except subprocess.TimeoutExpired: + log.debug("Timed out while fetching browser URL for '%s'", app_name) + return None + except Exception as exc: + log.debug("Unexpected error fetching browser URL for '%s': %s", app_name, exc) + return None + + def running_browser_applications(self) -> list[tuple[str, str, str]]: + """Get a list of currently running browser applications. + + This method scans the system for running browser applications that are known + to the system. It's used for automation permission testing and for determining + which browsers are available for URL tracking. + + The method identifies browsers by their bundle identifiers and returns + standardized information about each running browser. + + Returns: + List of tuples containing (app_name, browser_key, bundle_id) for each running browser + """ + if NSWorkspace is None: + return [] + try: + workspace = NSWorkspace.sharedWorkspace() + running = workspace.runningApplications() + except Exception: + return [] + + seen: set[str] = set() + result: list[tuple[str, str, str]] = [] + for app in running: + try: + bundle_id = (app.bundleIdentifier() or "").strip().lower() + if not bundle_id: + continue + key = BROWSER_BUNDLE_IDS.get(bundle_id) + if not key or key in seen: + continue + name = app.localizedName() or bundle_id + seen.add(key) + result.append((name, key, bundle_id)) + except Exception: + continue + return result + + def prime_automation_for_running_browsers(self) -> bool: + """Test automation permissions for running browsers to detect permission issues early. + + This method proactively tests AppleScript automation permissions by attempting + to query each running browser for its URL. This helps identify permission issues + before the user starts interacting with browsers, providing better error messages + and user experience. + + The method is typically called during system initialization to: + 1. Detect if automation permissions are granted + 2. Cache permission status for performance + 3. Provide early warning about permission issues + + Returns: + True if automation testing was attempted, False if no browsers were running + """ + log = self.logger + granted = check_automation_permission_granted() + if granted: + log.debug("Automation permission already granted; skipping browser preflight") + return True + + attempted = False + for app_name, key, bundle_id in self.running_browser_applications(): + attempted = True + previous_bundle = self.last_frontmost_bundle_id + try: + self.last_frontmost_bundle_id = bundle_id + self._run_browser_scripts(app_name, key) + finally: + self.last_frontmost_bundle_id = previous_bundle + + if not attempted: + frontmost = self.get_frontmost_app_name() + if frontmost: + key = self._resolve_browser_key(frontmost) + if key: + self._run_browser_scripts(frontmost, key) + attempted = True + else: + log.debug("Frontmost app '%s' is not a known browser for automation preflight", frontmost) + else: + log.debug("No running browsers detected when attempting automation preflight") + + return attempted + + def _resolve_browser_key(self, app_name: str) -> Optional[str]: + """Convert an application name to a standardized browser key for script lookup. + + This method normalizes application names to match the keys used in the + BROWSER_OSA_SCRIPTS dictionary. 
It handles various naming variations and + also checks bundle identifiers as a fallback method for identification. + + The normalization process: + 1. Converts to lowercase and removes ".app" suffix + 2. Checks against known browser names + 3. Falls back to bundle identifier matching + 4. Handles variations like "Google Chrome Browser" -> "google chrome" + + Args: + app_name: The application name to resolve + + Returns: + The standardized browser key, or None if not a known browser + """ + normalized = " ".join(app_name.lower().replace(".app", "").split()) + if normalized in BROWSER_OSA_SCRIPTS: + return normalized + + bundle = (self.last_frontmost_bundle_id or "").strip().lower() + if bundle: + if bundle in BROWSER_BUNDLE_IDS: + return BROWSER_BUNDLE_IDS[bundle] + parts = bundle.split('.') + while len(parts) > 2: + parts = parts[:-1] + candidate = '.'.join(parts) + if candidate in BROWSER_BUNDLE_IDS: + return BROWSER_BUNDLE_IDS[candidate] + + if normalized.endswith(" browser"): + simplified = normalized[:-8] + if simplified in BROWSER_OSA_SCRIPTS: + return simplified + + return None + + + + + +#### +## four helper functions for checking permission status to help more naive users understand whether their app has permissions. + +def _update_automation_cache(value: bool | None) -> None: + """Update the global automation permission cache with a new value and timestamp. + + This function manages the caching of automation permission status to avoid + repeatedly testing permissions, which can be expensive and slow. The cache + includes both the permission status and a timestamp for TTL (time-to-live) + management. + + The cache is used to optimize performance by: + - Avoiding redundant permission checks within a short time window + - Providing quick responses for repeated permission queries + - Reducing the overhead of AppleScript execution for permission testing + + Args: + value: The automation permission status to cache: + - True: Automation permissions are granted + - False: Automation permissions are denied + - None: Permission status is unknown/undetermined + """ + global _AUTOMATION_PERMISSION_CACHE + _AUTOMATION_PERMISSION_CACHE = (time.monotonic(), value) + + +def _read_automation_cache(force_refresh: bool = False) -> tuple[bool | None, bool]: + """Read the automation permission status from the cache with TTL validation. + + This function retrieves the cached automation permission status if it's still + valid (within the TTL window). It provides a way to avoid expensive permission + checks when the status is already known and recent. + + The cache has a TTL (time-to-live) of AUTOMATION_PERMISSION_CACHE_TTL seconds + to balance performance with accuracy, since permission status can change + (e.g., user grants/revokes permissions). + + Args: + force_refresh: If True, ignore cache and return (None, False) to force + a fresh permission check + + Returns: + Tuple of (cached_value, cache_hit): + - cached_value: The cached permission status (True/False/None) or None if cache miss + - cache_hit: Boolean indicating whether a valid cached value was found + """ + if force_refresh: + return None, False + cache = _AUTOMATION_PERMISSION_CACHE + if cache is None: + return None, False + ts, value = cache + if time.monotonic() - ts <= AUTOMATION_PERMISSION_CACHE_TTL: + return value, True + return None, False + + +def _automation_denied_from_stderr(stderr: str) -> bool: + """Detect if AppleScript automation was denied based on stderr output. 
+ + This function analyzes the stderr output from AppleScript execution to determine + if the failure was due to automation permission denial. macOS returns specific + error codes and messages when automation is not permitted, which this function + recognizes. + + The function checks for several indicators of permission denial: + - Error code -1743 (standard macOS automation denial error) + - Various text patterns indicating permission issues + - Localized error messages in different languages + + This detection is crucial for providing appropriate user feedback and avoiding + repeated failed attempts when permissions are known to be denied. + + Args: + stderr: The stderr output from AppleScript execution + + Returns: + True if the stderr indicates automation permission was denied, False otherwise + """ + if not stderr: + return False + if "-1743" in stderr: + return True + lowered = stderr.lower() + return any(token in lowered for token in _AUTOMATION_DENIED_SUBSTRINGS) + + +def check_automation_permission_granted(force_refresh: bool = False) -> bool | None: + """Check if AppleScript automation permissions are granted for browser applications. + + This function determines whether the system has been granted automation permissions + to control browser applications via AppleScript. It's a critical check for the + browser URL tracking functionality, as automation permissions are required to + extract URLs from browsers. + + The function uses a multi-step approach: + + 1. **Cache Check**: First checks if permission status is cached and still valid + to avoid expensive repeated checks. + + 2. **Browser Detection**: Scans for currently running browser applications that + support automation. + + 3. **Permission Testing**: Attempts to execute simple AppleScript commands on + each running browser to test if automation is permitted. + + 4. **Result Caching**: Caches the result to optimize future checks. + + The function handles various scenarios: + - No browsers running: Returns None (unknown status) + - All browsers deny automation: Returns False + - At least one browser allows automation: Returns True + - Mixed results: Returns None (inconclusive) + + This check is typically performed during system initialization to provide early + feedback about permission issues and optimize subsequent browser URL queries. 
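+
+    Example (illustrative sketch of branching on the tri-state result):
+        granted = check_automation_permission_granted()
+        if granted is False:
+            # denied: direct the user to System Settings → Privacy & Security → Automation
+            ...
+        elif granted is None:
+            # inconclusive, e.g. no supported browsers are currently running
+            ...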
+ + Args: + force_refresh: If True, bypass cache and perform a fresh permission check + + Returns: + - True: Automation permissions are granted (at least one browser allows it) + - False: Automation permissions are denied (all browsers deny it) + - None: Permission status is unknown (no browsers running or inconclusive results) + """ + cached_value, cache_hit = _read_automation_cache(force_refresh) + if cache_hit: + return cached_value + + inspector = MacOSAppAndBrowserInspector(logging.getLogger("gum.automation_probe")) + running = inspector.running_browser_applications() + + denied_detected = False + attempted = False + + for app_name, key, bundle_id in running: + attempted = True + inspector.last_frontmost_bundle_id = bundle_id + result = inspector._run_browser_scripts(app_name, key) + if result is not None: + return True + + cached_value, cache_hit = _read_automation_cache() + if cache_hit and cached_value is False: + denied_detected = True + break + + if denied_detected: + return False + + if attempted: + _update_automation_cache(None) + return None diff --git a/record/gum/observers/macos/keyboard.py b/record/gum/observers/macos/keyboard.py new file mode 100644 index 0000000..6a60fc1 --- /dev/null +++ b/record/gum/observers/macos/keyboard.py @@ -0,0 +1,197 @@ +from __future__ import annotations + +import logging +import os +from typing import List, Optional + +try: # pragma: no cover - AppKit only available on macOS + from AppKit import ( + NSEvent, + NSEventMaskKeyDown, + NSEventMaskKeyUp, + ) +except Exception: # pragma: no cover - enables import on non-macOS platforms + NSEvent = None + NSEventMaskKeyDown = 0 + NSEventMaskKeyUp = 0 + +try: # pragma: no cover - optional accessibility check + import Quartz +except Exception: # pragma: no cover + Quartz = None + +from ..base.keyboard import KeyDispatch, KeyboardBackend +from ..fallback.keyboard import PynputKeyboardBackend + +log = logging.getLogger("Screen.MacKeyboard") + + +def _token_from_event(event) -> tuple[str, Optional[str]]: + tok: Optional[str] + try: + txt = event.characters() or event.charactersIgnoringModifiers() or "" + except Exception: + txt = "" + try: + vk: Optional[int] = int(event.keyCode()) + except Exception: + vk = None + + if txt: + tok = f"TEXT:{txt}" + elif vk is not None: + tok = f"VK:{vk}" + else: + tok = "KEY:unknown" + return tok, txt + + +def event_token_from_nsevent(event) -> str: + """Public helper to derive our key token string from an AppKit NSEvent. + + Centralizes token generation so other components (e.g., GUI main-thread shim) + can reuse identical logic without duplication. + """ + tok, _ = _token_from_event(event) + return tok + + +def register_appkit_key_monitors(on_key_down, on_key_up) -> list[object]: + """Register global AppKit key down/up monitors and return monitor handles. + + Raises RuntimeError if AppKit is unavailable. 
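+
+    Example (illustrative sketch; the handlers receive AppKit NSEvent objects):
+        monitors = register_appkit_key_monitors(
+            lambda event: print("down", event_token_from_nsevent(event)),
+            lambda event: print("up", event_token_from_nsevent(event)),
+        )
+        # later, on shutdown:
+        remove_appkit_monitors(monitors)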
+    """
+    if NSEvent is None:
+        raise RuntimeError("AppKit is not available on this platform")
+    monitors: list[object] = [
+        NSEvent.addGlobalMonitorForEventsMatchingMask_handler_(NSEventMaskKeyDown, on_key_down),
+        NSEvent.addGlobalMonitorForEventsMatchingMask_handler_(NSEventMaskKeyUp, on_key_up),
+    ]
+    return monitors
+
+
+def remove_appkit_monitors(monitors: list[object]) -> None:
+    if not monitors or NSEvent is None:
+        return
+    for monitor in monitors:
+        try:
+            NSEvent.removeMonitor_(monitor)
+        except Exception:
+            pass
+
+
+class AppKitKeyboardBackend(KeyboardBackend):
+    """Keyboard backend built on AppKit global event monitoring."""
+
+    def __init__(self) -> None:
+        self._monitors: List[object] = []
+        self._dispatch: Optional[KeyDispatch] = None
+
+    @staticmethod
+    def supported() -> bool:
+        return NSEvent is not None
+
+    def start(self, dispatch: KeyDispatch) -> None:
+        if NSEvent is None:
+            raise RuntimeError("AppKit is not available on this platform")
+        if self._monitors:
+            return
+
+        self._dispatch = dispatch
+
+        def emit(event, kind: str) -> None:
+            if self._dispatch is None:
+                return
+            try:
+                tok, _ = _token_from_event(event)
+                if tok:
+                    self._dispatch(tok, kind)
+            except Exception as exc:
+                log.debug(f"Failed to dispatch AppKit {kind} event: {exc}")
+
+        try:
+            self._monitors = register_appkit_key_monitors(lambda ev: emit(ev, "press"), lambda ev: emit(ev, "release"))
+        except Exception as exc:
+            self.stop()
+            raise RuntimeError(f"Failed to register AppKit monitors: {exc}") from exc
+
+        log.info("Keyboard monitoring enabled (AppKit)")
+
+    def stop(self) -> None:
+        if not self._monitors or NSEvent is None:
+            self._monitors = []
+            return
+        try:
+            remove_appkit_monitors(self._monitors)
+        finally:
+            self._monitors = []
+            self._dispatch = None
+
+
+class MacKeyboardBackend(KeyboardBackend):
+    """macOS keyboard backend that prefers AppKit but falls back to pynput."""
+
+    def __init__(self) -> None:
+        self._appkit = AppKitKeyboardBackend() if AppKitKeyboardBackend.supported() else None
+        self._pynput = PynputKeyboardBackend()
+        self._active: KeyboardBackend | None = None
+
+    @staticmethod
+    def _has_appkit_access() -> bool:
+        if Quartz is None:
+            return False
+        try:
+            # CGPreflightListenEventAccess takes no arguments; it reports whether this
+            # process already has input-monitoring access for global key monitors.
+            preflight = getattr(Quartz, "CGPreflightListenEventAccess", None)
+            if preflight is None:
+                return False
+            return bool(preflight())
+        except Exception:
+            return False
+
+    def start(self, dispatch: KeyDispatch) -> None:
+        prefer_raw = os.environ.get("GUM_KEYBOARD_BACKEND", "auto")
+        prefer = (prefer_raw or "auto").strip().lower()
+        if prefer not in {"auto", "appkit", "pynput"}:
+            prefer = "auto"
+
+        attempted_appkit = False
+        if self._appkit is not None and prefer != "pynput":
+            attempted_appkit = True
+            has_access = self._has_appkit_access()
+            if not has_access and prefer == "auto":
+                log.debug(
+                    "AppKit keyboard backend preflight denied accessibility permission; using pynput instead"
+                )
+            else:
+                if not has_access:
+                    log.warning(
+                        "AppKit keyboard backend requested without accessibility permission; attempting anyway"
+                    )
+                try:
+                    self._appkit.start(dispatch)
+                except Exception as exc:
+                    log.warning(f"AppKit keyboard backend unavailable, falling back to pynput: {exc}")
+                    self._appkit.stop()
+                else:
+                    self._active = self._appkit
+                    return
+
+        if prefer == "appkit":
+            log.warning("AppKit keyboard backend requested but unavailable; using pynput fallback")
+        elif prefer == "pynput":
+            log.info("Keyboard monitoring forced to pynput backend")
+ + self._pynput.start(dispatch) + self._active = self._pynput + if attempted_appkit and prefer != "pynput": + log.info("Keyboard monitoring enabled (pynput fallback)") + else: + log.info("Keyboard monitoring enabled (pynput)") + + def stop(self) -> None: + try: + if self._active is not None: + self._active.stop() + finally: + self._active = None diff --git a/record/gum/observers/macos/mouse.py b/record/gum/observers/macos/mouse.py new file mode 100644 index 0000000..ba534b9 --- /dev/null +++ b/record/gum/observers/macos/mouse.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from ..base.mouse import MouseBackend, MouseClickHandler, MouseScrollHandler +from ..fallback.mouse import PynputMouseBackend + + +class MacMouseBackend(MouseBackend): + """macOS mouse backend delegating to the shared pynput implementation.""" + + def __init__(self) -> None: + self._delegate = PynputMouseBackend() + + def start(self, on_click: MouseClickHandler, on_scroll: MouseScrollHandler) -> None: + self._delegate.start(on_click, on_scroll) + + def stop(self) -> None: + self._delegate.stop() diff --git a/record/gum/observers/macos/screenshots.py b/record/gum/observers/macos/screenshots.py new file mode 100644 index 0000000..c7cefb1 --- /dev/null +++ b/record/gum/observers/macos/screenshots.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Iterable, List, Tuple + +import Quartz +from shapely.geometry import box +from shapely.ops import unary_union + +from ..base.screenshots import MssScreenshotBackend + +__all__ = [ + "MacScreenshotBackend", + "is_app_visible", +] + + +class MacScreenshotBackend(MssScreenshotBackend): + """macOS-specific hook point for screenshot capture.""" + + # Currently inherits behaviour from the generic MSS backend. + # Reserved for future macOS tweaks (e.g. Retina scaling adjustments). 
+ pass + + +def _get_global_bounds() -> Tuple[float, float, float, float]: + err, ids, cnt = Quartz.CGGetActiveDisplayList(16, None, None) + if err != Quartz.kCGErrorSuccess: + raise OSError(f"CGGetActiveDisplayList failed: {err}") + + min_x = min_y = float("inf") + max_x = max_y = -float("inf") + for did in ids[:cnt]: + r = Quartz.CGDisplayBounds(did) + x0, y0 = r.origin.x, r.origin.y + x1, y1 = x0 + r.size.width, y0 + r.size.height + min_x, min_y = min(min_x, x0), min(min_y, y0) + max_x, max_y = max(max_x, x1), max(max_y, y1) + return min_x, min_y, max_x, max_y + + +def _get_visible_windows() -> List[tuple[dict, float]]: + _, _, _, gmax_y = _get_global_bounds() + + opts = ( + Quartz.kCGWindowListOptionOnScreenOnly + | Quartz.kCGWindowListOptionIncludingWindow + ) + wins = Quartz.CGWindowListCopyWindowInfo(opts, Quartz.kCGNullWindowID) + + occupied = None + result: List[tuple[dict, float]] = [] + + for info in wins: + owner = info.get("kCGWindowOwnerName", "") + if owner in ("Dock", "WindowServer", "Window Server"): + continue + + bounds = info.get("kCGWindowBounds", {}) + x, y, w, h = ( + bounds.get("X", 0), + bounds.get("Y", 0), + bounds.get("Width", 0), + bounds.get("Height", 0), + ) + if w <= 0 or h <= 0: + continue + + inv_y = gmax_y - y - h + poly = box(x, inv_y, x + w, inv_y + h) + if poly.is_empty: + continue + + visible = poly if occupied is None else poly.difference(occupied) + if not visible.is_empty: + ratio = visible.area / poly.area + result.append((info, ratio)) + occupied = poly if occupied is None else unary_union([occupied, poly]) + + return result + + +def is_app_visible(names: Iterable[str]) -> bool: + targets = set(names) + return any( + info.get("kCGWindowOwnerName", "") in targets and ratio > 0 + for info, ratio in _get_visible_windows() + ) diff --git a/record/gum/observers/utils.py b/record/gum/observers/utils.py new file mode 100644 index 0000000..085e330 --- /dev/null +++ b/record/gum/observers/utils.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import os +import shutil +from typing import Optional + +from pydrive.auth import GoogleAuth +from pydrive.drive import GoogleDrive + + +def initialize_google_drive(client_secrets_path: Optional[str] = None) -> GoogleDrive: + """Initialise Google Drive authentication with an optional secrets path.""" + gauth = GoogleAuth() + + if client_secrets_path: + client_secrets_path = os.path.abspath(os.path.expanduser(client_secrets_path)) + if not os.path.exists(client_secrets_path): + raise FileNotFoundError(f"Client secrets file not found: {client_secrets_path}") + + temp_client_secrets = "client_secrets.json" + try: + shutil.copy2(client_secrets_path, temp_client_secrets) + gauth.LocalWebserverAuth() + finally: + try: + os.remove(temp_client_secrets) + except OSError: + pass + else: + gauth.LocalWebserverAuth() + + return GoogleDrive(gauth) + + +def list_folders(drive: GoogleDrive): + """List all folders in Google Drive to help find folder IDs.""" + folders = drive.ListFile({'q': "mimeType='application/vnd.google-apps.folder' and trashed=false"}).GetList() + for folder in folders: + print(f"Name: {folder['title']}, ID: {folder['id']}") + return folders + + +def find_folder_by_name(folder_name: str, drive: GoogleDrive): + """Find a Google Drive folder by its name and return its identifier.""" + folders = drive.ListFile({'q': f"mimeType='application/vnd.google-apps.folder' and title='{folder_name}' and trashed=false"}).GetList() + if folders: + return folders[0]['id'] + return None + + +def upload_file(path: str, 
drive_dir: str, drive_instance: GoogleDrive): + """Upload *path* to Google Drive folder *drive_dir* and remove the local copy.""" + upload_file = drive_instance.CreateFile({ + 'title': os.path.basename(path), + 'parents': [{'id': drive_dir}], + }) + upload_file.SetContentFile(path) + upload_file.Upload() + os.remove(path) diff --git a/record/gum/observers/windows/__init__.py b/record/gum/observers/windows/__init__.py new file mode 100644 index 0000000..a5ce3b4 --- /dev/null +++ b/record/gum/observers/windows/__init__.py @@ -0,0 +1,17 @@ +"""Windows-specific observer implementations (placeholder).""" + +from .keyboard import WindowsKeyboardBackend +from .mouse import WindowsMouseBackend +from .screenshots import WindowsScreenshotBackend +from .app_and_browser_logging import ( + WindowsAppAndBrowserInspector, + check_automation_permission_granted as windows_check_automation_permission_granted, +) + +__all__ = [ + "WindowsKeyboardBackend", + "WindowsMouseBackend", + "WindowsScreenshotBackend", + "WindowsAppAndBrowserInspector", + "windows_check_automation_permission_granted", +] diff --git a/record/gum/observers/windows/app_and_browser_logging.py b/record/gum/observers/windows/app_and_browser_logging.py new file mode 100644 index 0000000..266bc89 --- /dev/null +++ b/record/gum/observers/windows/app_and_browser_logging.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Optional + + +class WindowsAppAndBrowserInspector: + """Windows stub for UI automation hooks (not yet implemented).""" + + def __init__(self, logger) -> None: + self.logger = logger + self.last_frontmost_bundle_id: Optional[str] = None + + def get_frontmost_app_name(self) -> Optional[str]: + return None + + def snapshot_running_browsers(self): + return [] + + +def check_automation_permission_granted(force_refresh: bool = False) -> Optional[bool]: + return None + + +__all__ = ["WindowsAppAndBrowserInspector", "check_automation_permission_granted"] diff --git a/record/gum/observers/windows/keyboard.py b/record/gum/observers/windows/keyboard.py new file mode 100644 index 0000000..91f3a0a --- /dev/null +++ b/record/gum/observers/windows/keyboard.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ..fallback.keyboard import PynputKeyboardBackend as WindowsKeyboardBackend + +__all__ = ["WindowsKeyboardBackend"] diff --git a/record/gum/observers/windows/mouse.py b/record/gum/observers/windows/mouse.py new file mode 100644 index 0000000..dd7702d --- /dev/null +++ b/record/gum/observers/windows/mouse.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ..fallback.mouse import PynputMouseBackend as WindowsMouseBackend + +__all__ = ["WindowsMouseBackend"] diff --git a/record/gum/observers/windows/screenshots.py b/record/gum/observers/windows/screenshots.py new file mode 100644 index 0000000..fafb9a2 --- /dev/null +++ b/record/gum/observers/windows/screenshots.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from ..fallback.screenshots import FallbackScreenshotBackend as WindowsScreenshotBackend + +__all__ = ["WindowsScreenshotBackend"] diff --git a/gum/schemas.py b/record/gum/schemas.py similarity index 90% rename from gum/schemas.py rename to record/gum/schemas.py index a8cb54b..45c9e5f 100644 --- a/gum/schemas.py +++ b/record/gum/schemas.py @@ -33,13 +33,15 @@ class PropositionItem(BaseModel): class PropositionSchema(BaseModel): propositions: List[PropositionItem] = Field( ..., - description="Up to K propositions" + description="Up to five propositions" ) model_config = 
ConfigDict(extra="forbid") class Update(BaseModel): content: str = Field(..., description="The content of the update") content_type: Literal["input_text", "input_image"] = Field(..., description="The type of the update") + app_name: Optional[str] = Field(default=None, description="Active application name when the update was captured") + browser_url: Optional[str] = Field(default=None, description="Active browser URL when the update was captured") RelationLabel = Literal["IDENTICAL", "SIMILAR", "UNRELATED"] diff --git a/record/instructions.pdf b/record/instructions.pdf new file mode 100644 index 0000000..cb6b8c2 Binary files /dev/null and b/record/instructions.pdf differ diff --git a/record/pyproject.toml b/record/pyproject.toml new file mode 100644 index 0000000..16bd754 --- /dev/null +++ b/record/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/record/setup.py b/record/setup.py new file mode 100644 index 0000000..7a1bd95 --- /dev/null +++ b/record/setup.py @@ -0,0 +1,51 @@ +from setuptools import setup, find_packages + +setup( + name="gum", + version="0.1.0", + packages=find_packages(), + include_package_data=True, + install_requires=[ + # Core dependencies + "pillow", # For image processing + "mss", # For screen capture + "pynput", # For mouse/keyboard monitoring + "shapely", # For geometry operations + "pyobjc-framework-Quartz", # For macOS window management + "pyobjc-framework-Cocoa", # For AppKit (GUI, NSEvent) + "openai>=1.0.0", + "SQLAlchemy>=2.0.0", + "pydantic>=2.0.0", + "sqlalchemy-utils>=0.41.0", + "python-dotenv>=1.0.0", + "scikit-learn", + "aiosqlite", + "greenlet", + "PyYAML", # For Google Drive settings configuration + "PyDrive", # For Google Drive integration + # Google Drive API dependencies (optional, for advanced features) + "google-auth", # For Google Drive API authentication + "google-auth-oauthlib", # For OAuth flow + "google-auth-httplib2", # For HTTP requests + "google-api-python-client", # For Google Drive API + ], + extras_require={ + 'monitoring': [ + 'psutil', # For memory monitoring + ], + }, + entry_points={ + 'console_scripts': [ + 'gum=gum.cli:cli_main', + ], + }, + description="A Python package with command-line interface", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", +) \ No newline at end of file diff --git a/record/tests/test_cli.py b/record/tests/test_cli.py new file mode 100644 index 0000000..c2ae5b1 --- /dev/null +++ b/record/tests/test_cli.py @@ -0,0 +1,181 @@ +import argparse +import asyncio +import importlib +import tempfile +import unittest +from contextlib import asynccontextmanager +from unittest import mock + +from record.gum.schemas import Update +from record.gum.gum import gum as GumEngine + + +class _DummyScreen: + instance = None + + def __init__(self, *args, **kwargs): + type(self).instance = self + self.args = args + self.kwargs = kwargs + + +class _DummyGum: + instance = None + + def __init__(self, user_name, screen_observer, data_directory, app_and_browser_inspector, **kwargs): + type(self).instance = self + self.user_name = user_name + self.screen_observer = screen_observer + self.data_directory = data_directory + self.app_and_browser_inspector = app_and_browser_inspector + self.kwargs = kwargs + + 
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return False
+
+
+class _DummyInspector:
+    def __init__(self, front_name=None, front_url=None):
+        self.front_name = front_name
+        self.front_url = front_url
+        self.calls = []
+
+    def get_frontmost_app_name(self):
+        self.calls.append("front")
+        return self.front_name
+
+    def get_browser_url(self, app_name):
+        self.calls.append(("browser", app_name))
+        return self.front_url
+
+    def prime_automation_for_running_browsers(self):
+        return True
+
+    def running_browser_applications(self):
+        return []
+
+
+class _DummyObserver:
+    def __init__(self, name="Screen"):
+        self.name = name
+
+
+class _DummySession:
+    def __init__(self):
+        self.added = []
+        self.flushed = False
+
+    def add(self, observation):
+        self.added.append(observation)
+
+    async def flush(self):
+        self.flushed = True
+
+
+class GumCLITests(unittest.IsolatedAsyncioTestCase):
+    async def test_cli_shares_inspector_between_screen_and_gum(self):
+        cli_main = importlib.import_module("record.gum.cli.main")
+
+        def resolved_future():
+            loop = asyncio.get_running_loop()
+            fut = loop.create_future()
+            fut.set_result(None)
+            return fut
+
+        args = argparse.Namespace(
+            user_name="tester",
+            debug=False,
+            data_directory="/tmp",
+            screenshots_dir="/tmp/screens",
+            scroll_debounce=0.5,
+            scroll_min_distance=5.0,
+            scroll_max_frequency=10,
+            scroll_session_timeout=2.0,
+        )
+
+        with mock.patch.object(cli_main, "Screen", _DummyScreen), \
+             mock.patch.object(cli_main, "GumApp", _DummyGum), \
+             mock.patch.object(cli_main, "parse_args", return_value=args), \
+             mock.patch.object(cli_main.asyncio, "Future", side_effect=resolved_future):
+            await cli_main._run_cli()
+
+        screen_instance = _DummyScreen.instance
+        gum_instance = _DummyGum.instance
+        self.assertIsNotNone(screen_instance)
+        self.assertIsNotNone(gum_instance)
+        self.assertIs(
+            screen_instance.kwargs.get("app_inspector"),
+            gum_instance.app_and_browser_inspector,
+        )
+
+
+class GumDefaultHandlerTests(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        self.tempdir = tempfile.TemporaryDirectory()
+
+    async def asyncTearDown(self):
+        self.tempdir.cleanup()
+
+    async def _make_gum(self, inspector):
+        g = GumEngine(
+            "tester",
+            data_directory=self.tempdir.name,
+            app_and_browser_inspector=inspector,
+        )
+
+        def fake_session(self):
+            @asynccontextmanager
+            async def cm():
+                session = _DummySession()
+                self._test_session = session
+                yield session
+            return cm()
+
+        g._session = fake_session.__get__(g, type(g))
+        return g
+
+    async def test_default_handler_prefers_update_metadata(self):
+        inspector = _DummyInspector(front_name="Terminal", front_url="terminal://")
+        g = await self._make_gum(inspector)
+        observer = _DummyObserver()
+        update = Update(
+            content="key_press",
+            content_type="input_text",
+            app_name="Safari",
+            browser_url="https://example.com",
+        )
+
+        await g._default_handler(observer, update)
+
+        session = g._test_session
+        self.assertEqual(len(session.added), 1)
+        obs = session.added[0]
+        self.assertEqual(obs.app_name, "Safari")
+        self.assertEqual(obs.browser_url, "https://example.com")
+        self.assertEqual(inspector.calls, [])
+
+    async def test_default_handler_uses_inspector_when_metadata_missing(self):
+        inspector = _DummyInspector(front_name="Arc", front_url="https://arc.net")
+        g = await self._make_gum(inspector)
+        observer = _DummyObserver()
+        update = Update(
+            content="key_press",
+            content_type="input_text",
+        )
+
+        await g._default_handler(observer, update)
+
+        session = g._test_session
+        self.assertEqual(len(session.added), 1)
+        obs = session.added[0]
+        self.assertEqual(obs.app_name, "Arc")
+        self.assertEqual(obs.browser_url, "https://arc.net")
+        self.assertIn("front", inspector.calls)
+        self.assertIn(("browser", "Arc"), inspector.calls)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/record/tests/test_macos_inspector.py b/record/tests/test_macos_inspector.py
new file mode 100644
index 0000000..d56567b
--- /dev/null
+++ b/record/tests/test_macos_inspector.py
@@ -0,0 +1,55 @@
+import logging
+import unittest
+from unittest import mock
+
+import types
+
+from record.gum.observers.macos.app_and_browser_logging import MacOSAppAndBrowserInspector
+
+
+class MacInspectorTests(unittest.TestCase):
+    def setUp(self):
+        self.logger = logging.getLogger("test.macos.inspector")
+        self.inspector = MacOSAppAndBrowserInspector(self.logger)
+        self.inspector.last_frontmost_bundle_id = "company.thebrowser.browser"
+        self.inspector.last_frontmost_pid = 123
+
+    @mock.patch("record.gum.observers.macos.app_and_browser_logging._browser_url_via_accessibility")
+    @mock.patch.object(MacOSAppAndBrowserInspector, "_run_browser_scripts")
+    def test_accessibility_fallback_when_scripts_fail(self, run_scripts, fallback):
+        run_scripts.return_value = None
+        fallback.return_value = "https://arc.net"
+
+        url = self.inspector.get_browser_url("Arc")
+
+        self.assertEqual(url, "https://arc.net")
+        fallback.assert_called_once_with(123)
+
+    @mock.patch("record.gum.observers.macos.app_and_browser_logging._browser_url_via_accessibility")
+    @mock.patch.object(MacOSAppAndBrowserInspector, "_run_browser_scripts")
+    def test_cached_url_used_when_no_new_data(self, run_scripts, fallback):
+        run_scripts.return_value = None
+        fallback.return_value = None
+        self.inspector.last_browser_urls["arc"] = "https://cached.example"
+
+        url = self.inspector.get_browser_url("Arc")
+
+        self.assertEqual(url, "https://cached.example")
+        fallback.assert_called_once_with(123)
+
+    @mock.patch("record.gum.observers.macos.app_and_browser_logging.subprocess.run")
+    def test_run_browser_scripts_skips_nullish_responses(self, run_proc):
+        run_proc.side_effect = [
+            types.SimpleNamespace(returncode=0, stdout="missing value", stderr=""),
+            types.SimpleNamespace(returncode=0, stdout="null", stderr=""),
+            types.SimpleNamespace(returncode=0, stdout="https://arc.example", stderr=""),
+        ]
+
+        url = self.inspector._run_browser_scripts("Arc", "arc")
+
+        self.assertEqual(url, "https://arc.example")
+        self.assertEqual(run_proc.call_count, 3)
+
+
+if __name__ == "__main__":
+    unittest.main()
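
For context on the schema change above, here is a minimal sketch (not part of the patch) of how the new optional `app_name` and `browser_url` fields on `Update` are populated, assuming the `record` package is importable from the repository root; when the fields are omitted, the default handler falls back to the shared app/browser inspector, as exercised by the tests above.

```python
# Illustrative only: round-trip the new Update metadata fields.
from record.gum.schemas import Update

# Metadata captured at event time travels with the update itself.
update = Update(
    content="key_press",
    content_type="input_text",
    app_name="Safari",                  # optional; inspector fallback when omitted
    browser_url="https://example.com",  # optional; same fallback behaviour
)

# Pydantic v2 serialization of the enriched update.
print(update.model_dump())
```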