diff --git a/deploy/.env.example b/deploy/.env.example new file mode 100644 index 000000000..6b0f4de4f --- /dev/null +++ b/deploy/.env.example @@ -0,0 +1,3 @@ +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_REGION= diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 000000000..2e53469ce --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,10 @@ +``` +# First time setup +cd deploy +uv venv +source .venv/bin/activate +uv pip install -e . + +# Subsequent usage +python deploy/models/omniparser/deploy.py start +``` diff --git a/deploy/deploy/models/omniparser/.dockerignore b/deploy/deploy/models/omniparser/.dockerignore new file mode 100644 index 000000000..213bee701 --- /dev/null +++ b/deploy/deploy/models/omniparser/.dockerignore @@ -0,0 +1,20 @@ +__pycache__ +*.pyc +*.pyo +*.pyd +.Python +env +pip-log.txt +pip-delete-this-directory.txt +.tox +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.log +.pytest_cache +.env +.venv +.DS_Store diff --git a/deploy/deploy/models/omniparser/Dockerfile b/deploy/deploy/models/omniparser/Dockerfile new file mode 100644 index 000000000..f14ea7ac8 --- /dev/null +++ b/deploy/deploy/models/omniparser/Dockerfile @@ -0,0 +1,59 @@ +FROM nvidia/cuda:12.3.1-devel-ubuntu22.04 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + git-lfs \ + wget \ + libgl1 \ + libglib2.0-0 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && git lfs install + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p /opt/conda && \ + rm miniconda.sh +ENV PATH="/opt/conda/bin:$PATH" + +RUN conda create -n omni python=3.12 && \ + echo "source activate omni" > ~/.bashrc +ENV CONDA_DEFAULT_ENV=omni +ENV PATH="/opt/conda/envs/omni/bin:$PATH" + +WORKDIR /app + +RUN git clone https://github.com/microsoft/OmniParser.git && \ + cd OmniParser && \ + git lfs install && \ + git lfs pull + +WORKDIR /app/OmniParser + +RUN . 
def image_to_base64(image_path: str) -> str:
    """Read a file from disk and return its contents as base64 text.

    Args:
        image_path: Location of the image file to encode.

    Returns:
        str: The file's raw bytes, base64-encoded and decoded as UTF-8 text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def plot_results(
    original_image_path: str,
    som_image_base64: str,
    parsed_content_list: list[dict[str, list[float]]],
) -> None:
    """Draw every parsed element's box and label on the original image.

    Args:
        original_image_path: Path to the image that was sent for parsing.
        som_image_base64: Base64 encoded SOM image (accepted but not used here).
        parsed_content_list: Parsed items; each provides a "bbox" of four
            values that are scaled by the image width/height, and a "content"
            label.
    """
    picture = Image.open(original_image_path)
    img_w, img_h = picture.size
    canvas = ImageDraw.Draw(picture)
    padding = 2  # margin of the white box painted behind each label

    for entry in parsed_content_list:
        # bbox values are multiplied by the image size to get pixel corners.
        left, top, right, bottom = entry["bbox"]
        top_left = (int(left * img_w), int(top * img_h))
        bottom_right = (int(right * img_w), int(bottom * img_h))
        caption = entry["content"]

        # Element outline.
        canvas.rectangle([top_left, bottom_right], outline="red", width=2)

        # White backdrop so the caption stays readable over the screenshot.
        text_left, text_top, text_right, text_bottom = canvas.textbbox(
            top_left, caption
        )
        canvas.rectangle(
            [
                text_left - padding,
                text_top - padding,
                text_right + padding,
                text_bottom + padding,
            ],
            fill="white",
        )

        canvas.text(top_left, caption, fill="red")

    picture.show()


def parse_image(
    image_path: str,
    server_url: str,
) -> None:
    """Send an image to the OmniParser server and display the parsed result.

    Request and response errors are logged rather than raised; only a failure
    to read ``image_path`` propagates, because the file is read before the
    guarded section starts.

    Args:
        image_path: Path to the image file.
        server_url: Base URL of the OmniParser server (a trailing slash is
            tolerated).
    """
    base_url = server_url.rstrip("/")
    request_body = {"base64_image": image_to_base64(image_path)}

    try:
        # Probe first so an unreachable server is reported before the upload.
        requests.get(f"{base_url}/probe/").raise_for_status()
        logger.info("Server is available")

        reply = requests.post(f"{base_url}/parse/", json=request_body)
        reply.raise_for_status()
        parsed = reply.json()

        plot_results(
            image_path,
            parsed["som_image_base64"],
            parsed["parsed_content_list"],
        )

        logger.info(f"API Latency: {parsed['latency']:.2f} seconds")

    except requests.exceptions.ConnectionError:
        logger.error(f"Error: Could not connect to server at {base_url}")
        logger.error("Please check if the server is running and the URL is correct")
    except requests.exceptions.RequestException as request_error:
        logger.error(f"Error making request to API: {request_error}")
    except Exception as unexpected_error:
        logger.error(f"Error: {unexpected_error}")


def main() -> None:
    """Expose ``parse_image`` as a command-line interface via python-fire."""
    fire.Fire(parse_image)


if __name__ == "__main__":
    main()
def create_key_pair(
    key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH
) -> str | None:
    """Create an EC2 key pair and store its private key in a local file.

    Args:
        key_name: Name of the key pair to create in EC2.
        key_path: Path where the private key file is written.

    Returns:
        str | None: ``key_name`` if the key pair was created and saved, or
        None if the EC2 API call failed.

    Raises:
        OSError: If the private key file cannot be created or written (for
            example because a read-only file already exists at ``key_path``).
    """
    ec2_client = boto3.client("ec2", region_name=config.AWS_REGION)
    try:
        key_pair = ec2_client.create_key_pair(KeyName=key_name)
    except ClientError as e:
        logger.error(f"Error creating key pair: {e}")
        return None

    private_key = key_pair["KeyMaterial"]

    # Create the file owner-only from the start. A plain open(..., "w") would
    # create it with umask-default permissions, leaving the private key
    # readable by other users until the chmod below runs.
    fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as key_file:
        key_file.write(private_key)
    os.chmod(key_path, 0o400)  # Set read-only permissions

    logger.info(f"Key pair {key_name} created and saved to {key_path}")
    return key_name
def deploy_ec2_instance(
    ami: str = config.AWS_EC2_AMI,
    instance_type: str = config.AWS_EC2_INSTANCE_TYPE,
    project_name: str = config.PROJECT_NAME,
    key_name: str = config.AWS_EC2_KEY_NAME,
    disk_size: int = config.AWS_EC2_DISK_SIZE,
) -> tuple[str | None, str | None]:
    """Deploy a new EC2 instance or reuse an existing one.

    An instance tagged ``Name=project_name`` that is running, pending or
    stopped is reused (and brought to the running state) provided the local
    key file exists. Otherwise a security group and a fresh key pair are
    prepared and a new instance is launched.

    Args:
        ami: AMI ID to use for the instance.
        instance_type: EC2 instance type.
        project_name: Value of the ``Name`` tag used both to look up an
            existing instance and to tag a new one.
        key_name: Name of the key pair to use.
        disk_size: Size of the root volume in GB.

    Returns:
        tuple[str | None, str | None]: Instance ID and public IP if
        successful, ``(None, None)`` if the key file for an existing instance
        is missing or the security group / key pair could not be prepared.

    Raises:
        botocore.exceptions.ClientError: If starting or launching the
            instance is rejected by EC2 (not caught here).
    """
    # Pin the region explicitly, as create_key_pair() and
    # get_or_create_security_group_id() do; otherwise boto3 falls back to its
    # ambient default and the instance may be looked up or launched in a
    # different region than the key pair and security group.
    ec2 = boto3.resource("ec2", region_name=config.AWS_REGION)
    ec2_client = boto3.client("ec2", region_name=config.AWS_REGION)

    # Look up by the same tag value that new instances receive below.
    candidates = list(
        ec2.instances.filter(
            Filters=[
                {"Name": "tag:Name", "Values": [project_name]},
                {
                    "Name": "instance-state-name",
                    "Values": ["running", "pending", "stopped"],
                },
            ]
        )
    )

    # Prefer the first instance that is running or can simply be restarted;
    # fall back to the last one that is still booting.
    existing_instance = next(
        (
            candidate
            for candidate in candidates
            if candidate.state["Name"] in ("running", "stopped")
        ),
        candidates[-1] if candidates else None,
    )

    if existing_instance is not None:
        # Check for the key *before* starting anything: without it we cannot
        # connect, and starting a stopped GPU instance would only cost money.
        if not os.path.exists(config.AWS_EC2_KEY_PATH):
            logger.warning(
                f"Key file {config.AWS_EC2_KEY_PATH} not found for existing instance."
            )
            logger.warning(
                "You'll need to use the original key file to connect to this instance."
            )
            logger.warning(
                "Consider terminating the instance with 'deploy.py stop' and starting "
                "fresh."
            )
            return None, None

        state = existing_instance.state["Name"]
        if state == "running":
            logger.info(
                f"Instance already running: ID - {existing_instance.id}, "
                f"IP - {existing_instance.public_ip_address}"
            )
        else:
            if state == "stopped":
                logger.info(
                    f"Starting existing stopped instance: ID - {existing_instance.id}"
                )
                ec2_client.start_instances(InstanceIds=[existing_instance.id])
            else:
                logger.info(
                    f"Waiting for pending instance to start: "
                    f"ID - {existing_instance.id}"
                )
            # A pending/just-started instance may not have a public IP yet, so
            # wait and refresh before reading it.
            existing_instance.wait_until_running()
            existing_instance.reload()
            logger.info(
                f"Instance started: ID - {existing_instance.id}, "
                f"IP - {existing_instance.public_ip_address}"
            )
        return existing_instance.id, existing_instance.public_ip_address

    # No existing instance found, create new one with new key pair
    security_group_id = get_or_create_security_group_id()
    if not security_group_id:
        logger.error(
            "Unable to retrieve security group ID. Instance deployment aborted."
        )
        return None, None

    # Replace any stale local key file and remote key pair with a fresh pair.
    try:
        if os.path.exists(config.AWS_EC2_KEY_PATH):
            logger.info(f"Removing existing key file {config.AWS_EC2_KEY_PATH}")
            os.remove(config.AWS_EC2_KEY_PATH)

        try:
            ec2_client.delete_key_pair(KeyName=key_name)
            logger.info(f"Deleted existing key pair {key_name}")
        except ClientError as e:
            # Best effort: report the problem instead of hiding it, then let
            # create_key_pair() decide whether we can continue.
            logger.warning(f"Could not delete existing key pair {key_name}: {e}")

        if not create_key_pair(key_name):
            logger.error("Failed to create key pair")
            return None, None
    except Exception as e:
        logger.error(f"Error managing key pair: {e}")
        return None, None

    # Root volume definition for the new instance.
    ebs_config = {
        "DeviceName": "/dev/sda1",
        "Ebs": {
            "VolumeSize": disk_size,
            "VolumeType": "gp3",
            "DeleteOnTermination": True,
        },
    }

    new_instance = ec2.create_instances(
        ImageId=ami,
        MinCount=1,
        MaxCount=1,
        InstanceType=instance_type,
        KeyName=key_name,
        SecurityGroupIds=[security_group_id],
        BlockDeviceMappings=[ebs_config],
        TagSpecifications=[
            {
                "ResourceType": "instance",
                "Tags": [{"Key": "Name", "Value": project_name}],
            },
        ],
    )[0]

    # The public IP is only populated once the instance is running.
    new_instance.wait_until_running()
    new_instance.reload()
    logger.info(
        f"New instance created: ID - {new_instance.id}, "
        f"IP - {new_instance.public_ip_address}"
    )
    return new_instance.id, new_instance.public_ip_address
+ If None, a new instance will be deployed. + instance_ip: Optional IP address of an existing EC2 instance. + Required if instance_id is provided. + max_ssh_retries: Maximum number of SSH connection attempts. + Defaults to 20 attempts. + ssh_retry_delay: Delay in seconds between SSH connection attempts. + Defaults to 20 seconds. + max_cmd_retries: Maximum number of command execution retries. + Defaults to 20 attempts. + cmd_retry_delay: Delay in seconds between command execution retries. + Defaults to 30 seconds. + + Returns: + tuple[str | None, str | None]: A tuple containing: + - The instance ID (str) or None if configuration failed + - The instance's public IP address (str) or None if configuration failed + + Raises: + RuntimeError: If command execution fails + paramiko.SSHException: If SSH connection fails + Exception: For other unexpected errors during configuration + """ + if not instance_id: + ec2_instance_id, ec2_instance_ip = deploy_ec2_instance() + else: + ec2_instance_id = instance_id + ec2_instance_ip = instance_ip + + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + ssh_retries = 0 + while ssh_retries < max_ssh_retries: + try: + ssh_client.connect( + hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key + ) + break + except Exception as e: + ssh_retries += 1 + logger.error(f"SSH connection attempt {ssh_retries} failed: {e}") + if ssh_retries < max_ssh_retries: + logger.info(f"Retrying SSH connection in {ssh_retry_delay} seconds...") + time.sleep(ssh_retry_delay) + else: + logger.error("Maximum SSH connection attempts reached. 
Aborting.") + return None, None + + commands = [ + "sudo apt-get update", + "sudo apt-get install -y ca-certificates curl gnupg", + "sudo install -m 0755 -d /etc/apt/keyrings", + ( + "curl -fsSL https://download.docker.com/linux/ubuntu/gpg | " + "sudo dd of=/etc/apt/keyrings/docker.gpg" + ), + "sudo chmod a+r /etc/apt/keyrings/docker.gpg", + ( + 'echo "deb [arch="$(dpkg --print-architecture)" ' + "signed-by=/etc/apt/keyrings/docker.gpg] " + "https://download.docker.com/linux/ubuntu " + '"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | ' + "sudo tee /etc/apt/sources.list.d/docker.list > /dev/null" + ), + "sudo apt-get update", + ( + "sudo apt-get install -y docker-ce docker-ce-cli containerd.io " + "docker-buildx-plugin docker-compose-plugin" + ), + "sudo systemctl start docker", + "sudo systemctl enable docker", + "sudo usermod -a -G docker ${USER}", + "sudo docker system prune -af --volumes", + f"sudo docker rm -f {config.PROJECT_NAME}-container || true", + ] + + for command in commands: + logger.info(f"Executing command: {command}") + cmd_retries = 0 + while cmd_retries < max_cmd_retries: + stdin, stdout, stderr = ssh_client.exec_command(command) + exit_status = stdout.channel.recv_exit_status() + + if exit_status == 0: + logger.info("Command executed successfully") + break + else: + error_message = stderr.read() + if "Could not get lock" in str(error_message): + cmd_retries += 1 + logger.warning( + f"dpkg is locked, retrying in {cmd_retry_delay} seconds... 
" + f"Attempt {cmd_retries}/{max_cmd_retries}" + ) + time.sleep(cmd_retry_delay) + else: + logger.error( + f"Error in command: {command}, Exit Status: {exit_status}, " + f"Error: {error_message}" + ) + break + + ssh_client.close() + return ec2_instance_id, ec2_instance_ip + + +def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: + """Execute a command and handle its output safely.""" + logger.info(f"Executing: {command}") + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=config.COMMAND_TIMEOUT, + # get_pty=True + ) + + # Stream output in real-time + while not stdout.channel.exit_status_ready(): + if stdout.channel.recv_ready(): + try: + line = stdout.channel.recv(1024).decode("utf-8", errors="replace") + if line.strip(): # Only log non-empty lines + logger.info(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stdout: {e}") + + if stdout.channel.recv_stderr_ready(): + try: + line = stdout.channel.recv_stderr(1024).decode( + "utf-8", errors="replace" + ) + if line.strip(): # Only log non-empty lines + logger.error(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stderr: {e}") + + exit_status = stdout.channel.recv_exit_status() + + # Capture any remaining output + try: + remaining_stdout = stdout.read().decode("utf-8", errors="replace") + if remaining_stdout.strip(): + logger.info(remaining_stdout.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stdout: {e}") + + try: + remaining_stderr = stderr.read().decode("utf-8", errors="replace") + if remaining_stderr.strip(): + logger.error(remaining_stderr.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stderr: {e}") + + if exit_status != 0: + error_msg = f"Command failed with exit status {exit_status}: {command}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully executed: {command}") + + +class Deploy: + """Class handling deployment 
class Deploy:
    """Deployment operations for OmniParser on AWS EC2.

    The class is handed to ``fire.Fire``, so each public static method is
    available as a sub-command of ``deploy.py`` (``start``, ``status``,
    ``ssh``, ``stop``).
    """

    @staticmethod
    def start() -> None:
        """Start a new deployment of OmniParser on EC2.

        Configures (or creates) the instance, copies the Dockerfile and
        .dockerignore that sit next to this module, clones the OmniParser
        repository on the instance, builds and runs the Docker container,
        and waits until the server answers on ``/probe/``.

        Raises:
            RuntimeError: If no instance IP is available, a remote command
                fails, or the server never becomes responsive.
            Exception: Any other error raised during deployment is logged and
                re-raised.
        """
        try:
            _, instance_ip = configure_ec2_instance()
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not instance_ip:
                raise RuntimeError(f"invalid {instance_ip=}")

            # Trigger driver installation via login shell
            Deploy.ssh(non_interactive=True)

            # The build files live in the same directory as deploy.py.
            current_dir = os.path.dirname(os.path.abspath(__file__))

            for filename in ("Dockerfile", ".dockerignore"):
                filepath = os.path.join(current_dir, filename)
                if not os.path.exists(filepath):
                    logger.warning(f"File not found: {filepath}")
                    continue
                logger.info(f"Copying {filename} to instance...")
                subprocess.run(
                    [
                        "scp",
                        "-i",
                        config.AWS_EC2_KEY_PATH,
                        "-o",
                        "StrictHostKeyChecking=no",
                        filepath,
                        f"{config.AWS_EC2_USER}@{instance_ip}:~/{filename}",
                    ],
                    check=True,
                )

            # Connect to instance and execute commands
            key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH)
            ssh_client = paramiko.SSHClient()
            ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

            container = config.CONTAINER_NAME
            image = config.PROJECT_NAME

            try:
                logger.info(f"Connecting to {instance_ip}...")
                ssh_client.connect(
                    hostname=instance_ip,
                    username=config.AWS_EC2_USER,
                    pkey=key,
                    timeout=30,
                )

                setup_commands = [
                    "rm -rf OmniParser",  # Clean up any existing repo
                    f"git clone {config.REPO_URL}",
                    "cp Dockerfile .dockerignore OmniParser/",
                ]

                for command in setup_commands:
                    logger.info(f"Executing setup command: {command}")
                    execute_command(ssh_client, command)

                # NOTE: these must be f-strings; plain strings would send the
                # literal text "{config.CONTAINER_NAME}" to the remote shell.
                docker_commands = [
                    # Remove any existing container
                    f"sudo docker rm -f {container} || true",
                    # Remove any existing image
                    f"sudo docker rmi {image} || true",
                    # Build new image
                    (
                        "cd OmniParser && sudo docker build --progress=plain "
                        f"-t {image} ."
                    ),
                    # Run new container. The server inside the container
                    # listens on 8000 (see the Dockerfile CMD); publish it on
                    # the configured host port, which is the one opened in the
                    # security group and probed below.
                    (
                        f"sudo docker run -d -p {config.PORT}:8000 --gpus all "
                        f"--name {container} {image}"
                    ),
                ]

                for command in docker_commands:
                    logger.info(f"Executing Docker command: {command}")
                    execute_command(ssh_client, command)

                # Wait for container to start and check its logs
                logger.info("Waiting for container to start...")
                time.sleep(10)  # Give container time to start
                execute_command(ssh_client, f"sudo docker logs {container}")

                # Wait for server to become responsive
                logger.info("Waiting for server to become responsive...")
                max_retries = 30
                retry_delay = 10
                server_ready = False
                # -f makes curl exit non-zero on HTTP error statuses, so an
                # erroring server is not mistaken for a ready one.
                check_command = f"curl -sf http://localhost:{config.PORT}/probe/"

                for attempt in range(max_retries):
                    try:
                        execute_command(ssh_client, check_command)
                        server_ready = True
                        break
                    except Exception as e:
                        logger.warning(
                            f"Server not ready (attempt {attempt + 1}/{max_retries}): "
                            f"{e}"
                        )
                        if attempt < max_retries - 1:
                            logger.info(
                                f"Waiting {retry_delay} seconds before next attempt..."
                            )
                            time.sleep(retry_delay)

                if not server_ready:
                    raise RuntimeError("Server failed to start properly")

                # Final status check
                execute_command(ssh_client, f"sudo docker ps | grep {container}")

                server_url = f"http://{instance_ip}:{config.PORT}"
                logger.info(f"Deployment complete. Server running at: {server_url}")

                # Verify server is accessible from outside (best effort: a
                # failure here is reported but does not fail the deployment).
                try:
                    import requests

                    response = requests.get(f"{server_url}/probe/", timeout=10)
                    if response.status_code == 200:
                        logger.info("Server is accessible from outside!")
                    else:
                        logger.warning(
                            f"Server responded with status code: {response.status_code}"
                        )
                except Exception as e:
                    logger.warning(f"Could not verify external access: {e}")

            except Exception as e:
                logger.error(f"Error during deployment: {e}")
                # Get container logs for debugging
                try:
                    execute_command(ssh_client, f"sudo docker logs {container}")
                except Exception as exc:
                    logger.warning(f"{exc=}")
                raise

            finally:
                ssh_client.close()

        except Exception as e:
            logger.error(f"Deployment failed: {e}")
            if CLEANUP_ON_FAILURE:
                # Attempt cleanup on failure
                try:
                    Deploy.stop()
                except Exception as cleanup_error:
                    logger.error(f"Cleanup after failure also failed: {cleanup_error}")
            raise

        logger.info("Deployment completed successfully!")

    @staticmethod
    def status() -> None:
        """Log the ID, state and server URL of every instance of this project."""
        ec2 = boto3.resource("ec2", region_name=config.AWS_REGION)
        instances = ec2.instances.filter(
            Filters=[{"Name": "tag:Name", "Values": [config.PROJECT_NAME]}]
        )

        for instance in instances:
            public_ip = instance.public_ip_address
            if public_ip:
                server_url = f"http://{public_ip}:{config.PORT}"
                logger.info(
                    f"Instance ID: {instance.id}, State: {instance.state['Name']}, "
                    f"URL: {server_url}"
                )
            else:
                logger.info(
                    f"Instance ID: {instance.id}, State: {instance.state['Name']}, "
                    f"URL: Not available (no public IP)"
                )

    @staticmethod
    def ssh(non_interactive: bool = False) -> None:
        """SSH into the running instance.

        Args:
            non_interactive: If True, open a login shell that exits
                immediately (used to trigger login-time initialization
                scripts) instead of an interactive session.
        """
        # Get instance IP
        ec2 = boto3.resource("ec2", region_name=config.AWS_REGION)
        instances = ec2.instances.filter(
            Filters=[
                {"Name": "tag:Name", "Values": [config.PROJECT_NAME]},
                {"Name": "instance-state-name", "Values": ["running"]},
            ]
        )

        instance = next(iter(instances), None)
        if not instance:
            logger.error("No running instance found")
            return

        ip = instance.public_ip_address
        if not ip:
            logger.error("Instance has no public IP")
            return

        # Check if key file exists
        if not os.path.exists(config.AWS_EC2_KEY_PATH):
            logger.error(f"Key file not found: {config.AWS_EC2_KEY_PATH}")
            return

        base_command = [
            "ssh",
            "-i",
            config.AWS_EC2_KEY_PATH,
            "-o",
            "StrictHostKeyChecking=no",  # Automatically accept new host keys
        ]
        destination = f"{config.AWS_EC2_USER}@{ip}"

        if not non_interactive:
            # Argument list instead of a shell string: no quoting problems if
            # the key path contains spaces or shell metacharacters.
            ssh_command = [*base_command, destination]
            logger.info(f"Connecting with: {' '.join(ssh_command)}")
            subprocess.run(ssh_command)
            return

        # Simulate full login by forcing all initialization scripts
        ssh_command = [
            *base_command,
            "-o",
            "UserKnownHostsFile=/dev/null",  # Prevent writing to known_hosts
            "-tt",  # Force pseudo-terminal allocation
            destination,
            "bash --login -c 'exit'",  # Force full login shell and exit immediately
        ]
        try:
            subprocess.run(ssh_command, check=True)
        except subprocess.CalledProcessError as e:
            logger.error(f"SSH connection failed: {e}")

    @staticmethod
    def stop(
        project_name: str = config.PROJECT_NAME,
        security_group_name: str = config.AWS_EC2_SECURITY_GROUP,
    ) -> None:
        """Terminate the EC2 instances and delete the associated security group.

        Args:
            project_name (str): The project name used to tag the instance.
                Defaults to config.PROJECT_NAME.
            security_group_name (str): The name of the security group to delete.
                Defaults to config.AWS_EC2_SECURITY_GROUP.
        """
        ec2_resource = boto3.resource("ec2", region_name=config.AWS_REGION)
        ec2_client = boto3.client("ec2", region_name=config.AWS_REGION)

        # Every state except "terminated" still needs to be cleaned up.
        instances = ec2_resource.instances.filter(
            Filters=[
                {"Name": "tag:Name", "Values": [project_name]},
                {
                    "Name": "instance-state-name",
                    "Values": [
                        "pending",
                        "running",
                        "shutting-down",
                        "stopped",
                        "stopping",
                    ],
                },
            ]
        )

        for instance in instances:
            logger.info(f"Terminating instance: ID - {instance.id}")
            instance.terminate()
            # Wait so the security group is no longer in use when deleted below.
            instance.wait_until_terminated()
            logger.info(f"Instance {instance.id} terminated successfully.")

        # Delete security group
        try:
            ec2_client.delete_security_group(GroupName=security_group_name)
            logger.info(f"Deleted security group: {security_group_name}")
        except ClientError as e:
            if e.response["Error"]["Code"] == "InvalidGroup.NotFound":
                logger.info(
                    f"Security group {security_group_name} does not exist or already "
                    "deleted."
                )
            else:
                logger.error(f"Error deleting security group: {e}")
class RecognitionCriterion(BaseModel):
    """Criteria for recognizing a state"""
    # NOTE(review): the docstring above is left untouched on purpose. Pydantic
    # appears to copy a model's docstring into its generated JSON schema, and
    # that schema is embedded in the Gemini prompt -- confirm before rewording.

    # Which check to run; AbstractState._evaluate_criterion dispatches on it.
    type: Literal["window_title", "ui_element_present", "visual_template"]
    # Used by "window_title": substring looked up in the window title.
    pattern: Optional[str] = None
    # Used by "visual_template": image similarity must exceed this value;
    # AbstractState falls back to 0.8 when the key is absent.
    threshold: Optional[float] = None
    # Used by "ui_element_present": substring looked up in each parsed UI
    # element's description.
    element_descriptor: Optional[str] = None
class AbstractState:
    """Abstract state in the process graph, with its recognition logic.

    A state is described by a list of recognition criteria (plain dicts with
    a ``"type"`` key and type-specific fields) and, optionally, example
    screenshots used for visual matching.
    """

    # Cut-off for "visual_template" criteria that carry no usable threshold.
    DEFAULT_VISUAL_THRESHOLD = 0.8

    def __init__(self, name, description, recognition_criteria):
        """Create a state with a fresh random id and no example screenshots.

        Args:
            name: Short state name, used to look the state up in the graph.
            description: Human-readable description of the state.
            recognition_criteria: List of criterion dicts (see class docstring).
        """
        self.id = str(uuid.uuid4())
        self.name = name
        self.description = description
        self.recognition_criteria = recognition_criteria
        self.example_screenshots = []

    def match_rules(self, current_state, trajectory=None):
        """Return True when every recognition criterion holds for the state.

        A state without any criteria never rule-matches: with nothing to
        check, a vacuous "all criteria passed" would make it match every
        observed state.

        Args:
            current_state: Observed state exposing ``window_event``,
                ``visual_data`` and ``screenshot``.
            trajectory: Accepted for interface compatibility; not used.

        Returns:
            bool: True only if there is at least one criterion and all hold.
        """
        if not self.recognition_criteria:
            return False
        return all(
            self._evaluate_criterion(criterion, current_state)
            for criterion in self.recognition_criteria
        )

    def _evaluate_criterion(self, criterion, state):
        """Evaluate a single recognition criterion against the observed state.

        Criterion dicts may come from a pydantic ``model_dump()``, in which
        case unused fields are present with value ``None``. A missing, ``None``
        or empty field therefore means "this criterion cannot match" instead
        of raising.

        Args:
            criterion: Dict with a ``"type"`` key and type-specific fields.
            state: Observed state (see ``match_rules``).

        Returns:
            bool: Whether the criterion holds; False for unknown types.
        """
        criterion_type = criterion.get("type")

        if criterion_type == "window_title":
            pattern = criterion.get("pattern")
            window_event = state.window_event
            if not pattern or not window_event or not window_event.title:
                return False
            return pattern in window_event.title

        if criterion_type == "ui_element_present":
            descriptor = criterion.get("element_descriptor")
            if not descriptor or not state.visual_data:
                return False
            return any(
                # Elements without a description simply do not match.
                descriptor in (element.get("description") or "")
                for element in state.visual_data
            )

        if criterion_type == "visual_template":
            # Needs both stored examples and a current screenshot to compare.
            if not self.example_screenshots or not state.screenshot:
                return False
            threshold = criterion.get("threshold")
            if threshold is None:
                threshold = self.DEFAULT_VISUAL_THRESHOLD
            current_image = state.screenshot.image
            return any(
                vision.get_image_similarity(current_image, example)[0] > threshold
                for example in self.example_screenshots
            )

        return False

    def add_example(self, screenshot):
        """Add example screenshot for visual matching."""
        self.example_screenshots.append(screenshot)

    def to_dict(self):
        """Convert to dictionary representation (example screenshots excluded)."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "recognition_criteria": self.recognition_criteria,
        }
def to_model(self) -> ProcessAnalysis:
    """Convert the graph to its Pydantic ``ProcessAnalysis`` representation.

    States come from every ``AbstractState`` node. Transitions come from
    every edge whose endpoints are both ``AbstractState`` instances, with the
    stored condition attached when one exists for the edge. Loops and
    decision points are derived by the graph's own detection helpers.

    Returns:
        ProcessAnalysis: Serializable model of the graph. The process name
        is the first line of the graph description, or "Unnamed Process"
        when the description is empty.
    """
    states = [
        StateModel(
            name=state.name,
            description=state.description,
            recognition_criteria=[
                RecognitionCriterion(**criterion)
                for criterion in state.recognition_criteria
            ],
        )
        for state in self.get_abstract_states()
    ]

    transitions = []
    for from_state, action, to_state in self.edges:
        if not (
            isinstance(from_state, AbstractState)
            and isinstance(to_state, AbstractState)
        ):
            continue

        condition = self.conditions.get((from_state.id, action.id, to_state.id))
        transitions.append(
            Transition(
                # Transition declares these fields with the aliases "from" /
                # "to" and does not enable population by field name, so the
                # constructor only accepts the aliases. "from" is a keyword,
                # hence the dict unpacking.
                **{"from": from_state.name, "to": to_state.name},
                action=ActionModel(
                    name=action.name,
                    description=action.description,
                    parameters=ActionParameter(**action.parameters),
                ),
                condition=Condition(**condition) if condition is not None else None,
            )
        )

    # Build loops and decision points using a simple algorithm
    loops = self._detect_loops()
    decision_points = self._detect_decision_points()

    return ProcessAnalysis(
        process_name=self.description.split("\n")[0] if self.description else "Unnamed Process",
        description=self.description,
        states=states,
        transitions=transitions,
        loops=loops,
        decision_points=decision_points,
    )
def _detect_decision_points(self) -> List[DecisionPoint]:
    """Report every abstract state that has more than one outgoing edge.

    Returns:
        List[DecisionPoint]: One entry per such state, with one branch per
        outgoing edge whose destination is itself an abstract state. States
        left without any such branch are omitted.
    """
    names_by_id = {s.id: s.name for s in self.get_abstract_states()}

    # Destination ids grouped by source id, in edge order (duplicates kept).
    successors = {}
    for source, _, target in self.edges:
        successors.setdefault(source.id, []).append(target.id)

    found = []
    for source_id, target_ids in successors.items():
        if source_id not in names_by_id or len(target_ids) <= 1:
            continue

        options = [
            Branch(
                condition=f"Condition to go to {names_by_id[target_id]}",
                next_state=names_by_id[target_id]
            )
            for target_id in target_ids
            if target_id in names_by_id
        ]
        if not options:
            continue

        found.append(DecisionPoint(
            state=names_by_id[source_id],
            description=f"Decision point at {names_by_id[source_id]}",
            branches=options
        ))

    return found
a transition + if previous_state and latest_action: + # Check if transition already exists + transition_exists = False + for from_state, action, to_state in self.edges: + if (from_state.id == previous_state.id and + action.name == latest_action.name and + to_state.id == similar_state.id): + transition_exists = True + break + + if not transition_exists: + # Create a new abstract action from the latest action + action = AbstractAction( + name=latest_action.name, + description=f"Action {latest_action.name}", + parameters=self._extract_action_parameters(latest_action) + ) + + # Add the edge + self.add_edge(previous_state, action, similar_state) + + return similar_state + + def _calculate_state_similarity(self, observed_state, abstract_state): + """Calculate similarity between observed state and abstract state.""" + # Use rule-based matching first + if abstract_state.match_rules(observed_state): + return 0.9 # High confidence if rules match + + # Fall back to visual similarity if we have example screenshots + if abstract_state.example_screenshots and observed_state.screenshot: + visual_similarities = [ + vision.get_image_similarity(observed_state.screenshot.image, example)[0] + for example in abstract_state.example_screenshots + ] + return max(visual_similarities) if visual_similarities else 0.0 + + return 0.0 + + def _create_new_state_from_observation(self, observed_state): + """Create a new abstract state from an observed state.""" + # Generate a name for the state based on window title + name = "State_" + str(len(self.get_abstract_states()) + 1) + if observed_state.window_event and observed_state.window_event.title: + name = f"State_{observed_state.window_event.title[:20]}" + + # Create recognition criteria + criteria = [] + if observed_state.window_event and observed_state.window_event.title: + criteria.append({ + "type": "window_title", + "pattern": observed_state.window_event.title + }) + + if observed_state.visual_data: + # Add criteria based on visible UI 
class State:
    """Concrete snapshot of what was observed at one point during execution."""

    def __init__(self, screenshot, window_event, browser_event=None, visual_data=None):
        """Store the observation under a fresh random id.

        Args:
            screenshot: Screenshot captured for this state.
            window_event: Window event captured for this state.
            browser_event: Optional browser event.
            visual_data: Optional parsed UI elements; any falsy value is
                replaced by a new empty list.
        """
        identifier = uuid.uuid4()
        self.id = str(identifier)
        self.screenshot, self.window_event = screenshot, window_event
        self.browser_event = browser_event
        self.visual_data = visual_data if visual_data else []
def _find_matching_recording(self, task_description: str) -> int:
    """Find the recording whose task description best matches the given one.

    Similarity is whatever ``_calculate_text_similarity`` returns. Recordings
    without a task description are ignored for matching. When no recording
    scores above zero, the most recent recording (largest ``timestamp``) is
    used instead.

    Args:
        task_description: Natural-language description of the task to replay.

    Returns:
        int: Id of the selected recording.

    Raises:
        ValueError: If there are no recordings at all.
    """
    db_session = crud.get_new_session()
    recordings = crud.get_all_recordings(db_session)

    best_match = None
    # Start at zero rather than below it: a description with no similarity
    # to the task is not a match and must not pre-empt the most-recent
    # fallback below.
    highest_similarity = 0.0

    for recording in recordings:
        if not recording.task_description:
            continue

        similarity = self._calculate_text_similarity(
            task_description, recording.task_description
        )
        if similarity > highest_similarity:
            highest_similarity = similarity
            best_match = recording.id

    if best_match is not None:
        return best_match

    # No good match: fall back to the most recent recording.
    latest = max(recordings, key=lambda r: r.timestamp, default=None)
    if latest is None:
        raise ValueError("No recordings found in the database.")
    return latest.id
words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = words1.intersection(words2) + union = words1.union(words2) + + return len(intersection) / len(union) + + def _build_generalizable_process_graph(self, task_description): + """Build a generalizable process graph using multi-phase approach with MMMs.""" + # Get coalesced actions + processed_actions = self.recording.processed_action_events + + # Phase 1: Process Understanding - Analyze the entire workflow + process_model = self._analyze_entire_process(processed_actions, task_description) + + # Phase 2: Graph Construction - Build abstract graph from understanding + initial_graph = self._construct_abstract_graph(process_model) + + # Phase 3: Graph Validation - Test and refine by walking through recording + refined_graph = self._validate_and_refine_graph(initial_graph, processed_actions) + + return refined_graph + + def _select_representative_screenshots(self, action_events, max_images=10): + """Select representative screenshots from the action events.""" + if not action_events: + return [] + + # If few actions, use all screenshots + if len(action_events) <= max_images: + return [action.screenshot.image for action in action_events if action.screenshot] + + # Otherwise, select evenly spaced screenshots + step = len(action_events) // max_images + selected_actions = action_events[::step] + + # Add the last action if not included + if action_events[-1] not in selected_actions: + selected_actions.append(action_events[-1]) + + return [action.screenshot.image for action in selected_actions if action.screenshot] + + def _analyze_entire_process(self, actions, task_description): + """Have Gemini analyze the entire recording to understand the process structure.""" + key_screenshots = self._select_representative_screenshots(actions) + + # Generate schema JSON for the prompt + schema_json = ProcessAnalysis.model_json_schema() + + system_prompt = "You are an expert in understanding user 
def _fallback_process_analysis(self, actions, task_description, max_states=5):
    """Create a simple linear process model if all parsing fails.

    A sample of the recorded actions becomes a chain of states
    ``State_1 -> State_2 -> ...``. Each state is recognized by the window
    title of its action (when there is one), and each transition is labelled
    with the name of the action that starts the next state.

    Args:
        actions: Recorded action events, in order; may be empty.
        task_description: Task text, used only in the model description.
        max_states: Upper bound on the number of generated states; values
            below 1 are treated as 1.

    Returns:
        ProcessAnalysis: Linear model without loops or decision points.
    """
    logger.warning("Using fallback process analysis")

    # Ceiling division keeps the sample within max_states; a floor-based
    # stride would return up to about twice as many entries.
    step = max(1, math.ceil(len(actions) / max(1, max_states)))
    key_actions = actions[::step]

    states = []
    transitions = []

    for i, action in enumerate(key_actions):
        state_name = f"State_{i+1}"
        state_description = f"State after {action.name} action"

        # Create recognition criteria
        criteria = []
        if action.window_event and action.window_event.title:
            criteria.append({
                "type": "window_title",
                "pattern": action.window_event.title
            })

        states.append(StateModel(
            name=state_name,
            description=state_description,
            recognition_criteria=criteria
        ))

        # Create transition to next state
        if i < len(key_actions) - 1:
            next_action = key_actions[i+1]
            transitions.append(Transition(
                # Transition only accepts its aliases "from" / "to" at
                # construction; "from" is a keyword, hence the unpacking.
                **{"from": state_name, "to": f"State_{i+2}"},
                action=ActionModel(
                    name=next_action.name,
                    description=f"{next_action.name} action",
                    parameters=ActionParameter()
                )
            ))

    return ProcessAnalysis(
        process_name="Fallback Process",
        description=f"Fallback process for task: {task_description}",
        states=states,
        transitions=transitions,
        loops=[],
        decision_points=[]
    )
transition.condition.model_dump()) + + return graph + + def _validate_and_refine_graph(self, graph, actions): + """Test the graph against recorded actions and refine it with Gemini's help.""" + # Simulate walking through the recording using the graph + simulation_results = self._simulate_graph_execution(graph, actions) + + if simulation_results["success"]: + return graph + + # If simulation failed, ask Gemini to refine the graph + system_prompt = "You are an expert in refining process models." + prompt = f""" + The process graph failed to match the recording at these points: + {simulation_results["failures"]} + + Current graph: {graph.to_json()} + + Please refine the graph to better match the recorded process while + maintaining generalizability. Consider: + 1. Adding missing states or transitions + 2. Adjusting state recognition criteria + 3. Modifying action parameters + 4. Adding conditional logic + + RESPOND USING THE SAME JSON SCHEMA AS THE CURRENT GRAPH. + """ + + refinements_text = self.prompt_gemini(prompt, system_prompt, simulation_results["screenshots"]) + + try: + refinements_data = repair_loads(refinements_text) + refined_model = ProcessAnalysis(**refinements_data) + refined_graph = self._construct_abstract_graph(refined_model) + + # Check if refinement improved the simulation + new_failures = len(self._simulate_graph_execution(refined_graph, actions)["failures"]) + old_failures = len(simulation_results["failures"]) + + if new_failures < old_failures: + return refined_graph + return graph + + except Exception as e: + logger.error(f"Failed to parse graph refinements: {e}") + return graph + + def _simulate_graph_execution(self, graph, actions): + """Simulate executing the graph with the recorded actions.""" + failures = [] + screenshots = [] + current_state = None + + for i, action in enumerate(actions): + # If first action, find initial state + if i == 0: + state = State(action.screenshot, action.window_event, action.browser_event) + matched_state = None 
+ highest_similarity = 0.0 + + for abstract_state in graph.get_abstract_states(): + similarity = graph._calculate_state_similarity(state, abstract_state) + if similarity > highest_similarity: + highest_similarity = similarity + matched_state = abstract_state + + if highest_similarity < 0.7: + failures.append(f"Failed to match initial state at action {i}") + screenshots.append(action.screenshot.image) + + current_state = matched_state + continue + + # For subsequent actions, check if the graph has a transition + if current_state: + possible_actions = graph.get_possible_actions(current_state) + + # Check if any action matches the recorded action + action_match = False + for graph_action, next_state in possible_actions: + if graph_action.name == action.name: + action_match = True + current_state = next_state + break + + if not action_match: + failures.append(f"No matching action '{action.name}' from state '{current_state.name}' at action {i}") + screenshots.append(action.screenshot.image) + + return { + "success": len(failures) == 0, + "failures": failures, + "screenshots": screenshots[:5] # Limit to 5 screenshots for prompt size + } + + def get_next_action_event( + self, + screenshot: models.Screenshot, + window_event: models.WindowEvent, + ) -> models.ActionEvent: + """Determine next action using the process graph and runtime adaptation.""" + # Create current state representation + current_state = State( + screenshot=screenshot, + window_event=window_event + ) + + # Parse visual state with OmniParser + visual_data = self._parse_state_with_omniparser(screenshot.image) + current_state.visual_data = visual_data + + # Update graph with actual observed state + previous_abstract_state = self.current_abstract_state + latest_action = self.action_history[-1] if self.action_history else None + + self.current_abstract_state = self.process_graph.update_with_observation( + current_state, + previous_abstract_state, + latest_action + ) + + self.current_state = current_state + + # 
def _instantiate_abstract_action(self, abstract_action, current_state):
    """Turn an abstract graph action into a concrete ActionEvent.

    Mouse actions are aimed at the center of the first parsed UI element
    whose description contains the action's ``target_element``; when there is
    no target, or no element matches, Gemini is asked for coordinates.
    Keyboard actions carrying ``text_input`` become a "type" event. Any other
    action, and any error along the way, is delegated to
    ``_generate_action_with_gemini``.

    Args:
        abstract_action: Graph action exposing ``name`` and ``parameters``.
        current_state: Current concrete state (screenshot, window event and
            parsed ``visual_data``).

    Returns:
        The ActionEvent to perform.
    """
    try:
        params = abstract_action.parameters
        name = abstract_action.name

        if name in common.MOUSE_EVENTS:
            event = models.ActionEvent(
                name=name,
                screenshot=current_state.screenshot,
                window_event=current_state.window_event,
                recording=self.recording
            )

            wanted = params.get("target_element")
            if not wanted:
                # Nothing to look for: let Gemini decide where to click.
                return self._locate_target_with_gemini(None, name, current_state)

            hit = next(
                (
                    candidate
                    for candidate in current_state.visual_data
                    if wanted in candidate.get("description", "")
                ),
                None,
            )
            if not hit:
                # Target not among the parsed elements: ask Gemini instead.
                return self._locate_target_with_gemini(wanted, name, current_state)

            box = hit.get("bounds", {})
            # Center of the element's bounding box.
            center = (
                box.get("x", 0) + box.get("width", 0) / 2,
                box.get("y", 0) + box.get("height", 0) / 2,
            )
            event.mouse_x, event.mouse_y = center
            event.active_segment_description = wanted
            return event

        if name in common.KEYBOARD_EVENTS:
            event = models.ActionEvent(
                name=name,
                screenshot=current_state.screenshot,
                window_event=current_state.window_event,
                recording=self.recording
            )

            text = params.get("text_input")
            if text:
                # Text input is replayed as a "type" event built from a dict.
                event = models.ActionEvent.from_dict({
                    "name": "type",
                    "text": text
                })
                event.screenshot = current_state.screenshot
                event.window_event = current_state.window_event
                event.recording = self.recording

            return event

        # For other actions, use Gemini
        return self._generate_action_with_gemini(name)

    except Exception as e:
        logger.error(f"Error instantiating action: {e}")
        return self._generate_action_with_gemini()
description: brief description of what element is at these coordinates + """ + + result_text = self.prompt_gemini(prompt, system_prompt, [current_state.screenshot.image]) + + try: + # Parse the response + coord_data = repair_loads(result_text) + + # Create action event + action_event = models.ActionEvent( + name=action_name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording, + mouse_x=coord_data.get("x", 0), + mouse_y=coord_data.get("y", 0), + active_segment_description=coord_data.get("description", "") + ) + + return action_event + + except Exception as e: + logger.error(f"Error parsing coordinates: {e}") + + # Fallback: use center of screen + window_width = current_state.window_event.width if current_state.window_event else 800 + window_height = current_state.window_event.height if current_state.window_event else 600 + + return models.ActionEvent( + name=action_name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording, + mouse_x=window_width / 2, + mouse_y=window_height / 2, + active_segment_description="Center of screen (fallback)" + ) + + def _decide_between_actions(self, possible_actions, current_state): + """Use Gemini to decide between multiple possible actions.""" + system_prompt = "You are an expert in UI automation decision making." + + actions_list = [] + for i, (action, next_state) in enumerate(possible_actions): + actions_list.append({ + "id": i, + "name": action.name, + "description": action.description, + "parameters": action.parameters, + "next_state": next_state.name, + "next_state_description": next_state.description + }) + + prompt = f""" + Decide which action to take next based on the current state and task description. 
+ + Task description: {self.task_description} + Current state: {self.current_abstract_state.description if self.current_abstract_state else "Initial state"} + + Possible actions: + {json.dumps(actions_list, indent=2)} + + Respond with a JSON object containing: + 1. chosen_action_id: the ID of the chosen action (number) + 2. reasoning: brief explanation for your choice + """ + + result_text = self.prompt_gemini(prompt, system_prompt, [current_state.screenshot.image]) + + try: + result = repair_loads(result_text) + chosen_id = result.get("chosen_action_id", 0) + chosen_id = min(chosen_id, len(possible_actions) - 1) # Ensure valid index + + action, next_state = possible_actions[chosen_id] + return self._instantiate_abstract_action(action, current_state) + + except Exception as e: + logger.error(f"Error deciding between actions: {e}") + # Default to first action + action, next_state = possible_actions[0] + return self._instantiate_abstract_action(action, current_state) + + def _generate_action_with_gemini(self, suggested_action_name=None): + """Generate action with Gemini if graph doesn't provide one.""" + system_prompt = "You are an expert in UI automation." 
+ + trajectory = [] + for i, (state, action) in enumerate(self.state_action_history[-5:]): + trajectory.append({ + "step": i + 1, + "state": state.name if state else "Unknown", + "action": action.name if action else "None" + }) + + prompt = f""" + Generate the next action to perform based on: + + Task description: {self.task_description} + Recent trajectory: {json.dumps(trajectory, indent=2)} + {f'Suggested action type: {suggested_action_name}' if suggested_action_name else ''} + + Analyze the screenshot and respond with a JSON object for the next ActionEvent: + {{ + "name": "click|move|scroll|type", + "mouse_x": number, + "mouse_y": number, + "text": "text to type (for keyboard actions)", + "active_segment_description": "description of what's being clicked" + }} + + Only include relevant fields based on the action type. + """ + + result_text = self.prompt_gemini( + prompt, + system_prompt, + [self.current_state.screenshot.image] if self.current_state else [] + ) + + try: + action_dict = repair_loads(result_text) + action = models.ActionEvent.from_dict(action_dict) + + # Add missing context + action.screenshot = self.current_state.screenshot if self.current_state else None + action.window_event = self.current_state.window_event if self.current_state else None + action.recording = self.recording + + return action + + except Exception as e: + logger.error(f"Error generating action: {e}") + + # Create a fallback action - simple click in the center + window_width = self.current_state.window_event.width if self.current_state and self.current_state.window_event else 800 + window_height = self.current_state.window_event.height if self.current_state and self.current_state.window_event else 600 + + return models.ActionEvent( + name="click", + screenshot=self.current_state.screenshot if self.current_state else None, + window_event=self.current_state.window_event if self.current_state else None, + recording=self.recording, + mouse_x=window_width / 2, + mouse_y=window_height / 
2, + mouse_button_name="left", + active_segment_description="Center of screen (fallback)" + ) + + def prompt_gemini(self, prompt, system_prompt, images): + """Helper method to prompt Gemini with images.""" + from openadapt.drivers import google + return google.prompt( + prompt, + system_prompt=system_prompt, + images=images, + model_name="models/gemini-1.5-pro-latest" + ) + + def __del__(self): + """Clean up OmniParser service when done.""" + try: + self.omniparser_provider.stack.stop_service() + except: + pass \ No newline at end of file