From 22ebc39e9f9560f85e11fc638d3faefa86d1a560 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 18 Feb 2025 15:43:29 -0500 Subject: [PATCH 01/13] add working omniparser deploy.py, Dockerfile, pyproject.toml, README.md, .env.example, .dockerignore --- deploy/.env.example | 3 + deploy/README.md | 10 + deploy/deploy/models/omniparser/.dockerignore | 20 + deploy/deploy/models/omniparser/Dockerfile | 53 ++ deploy/deploy/models/omniparser/deploy.py | 547 ++++++++++++++++++ deploy/pyproject.toml | 20 + 6 files changed, 653 insertions(+) create mode 100644 deploy/.env.example create mode 100644 deploy/README.md create mode 100644 deploy/deploy/models/omniparser/.dockerignore create mode 100644 deploy/deploy/models/omniparser/Dockerfile create mode 100644 deploy/deploy/models/omniparser/deploy.py create mode 100644 deploy/pyproject.toml diff --git a/deploy/.env.example b/deploy/.env.example new file mode 100644 index 000000000..6b0f4de4f --- /dev/null +++ b/deploy/.env.example @@ -0,0 +1,3 @@ +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_REGION= diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 000000000..2e53469ce --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,10 @@ +``` +# First time setup +cd deploy +uv venv +source .venv/bin/activate +uv pip install -e . + +# Subsequent usage +python deploy/models/omniparser/deploy.py start +``` diff --git a/deploy/deploy/models/omniparser/.dockerignore b/deploy/deploy/models/omniparser/.dockerignore new file mode 100644 index 000000000..213bee701 --- /dev/null +++ b/deploy/deploy/models/omniparser/.dockerignore @@ -0,0 +1,20 @@ +__pycache__ +*.pyc +*.pyo +*.pyd +.Python +env +pip-log.txt +pip-delete-this-directory.txt +.tox +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.log +.pytest_cache +.env +.venv +.DS_Store diff --git a/deploy/deploy/models/omniparser/Dockerfile b/deploy/deploy/models/omniparser/Dockerfile new file mode 100644 index 000000000..d6b63e398 --- /dev/null +++ b/deploy/deploy/models/omniparser/Dockerfile @@ -0,0 +1,53 @@ +FROM nvidia/cuda:12.3.1-devel-ubuntu22.04 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + git-lfs \ + wget \ + libgl1 \ + libglib2.0-0 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && git lfs install + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p /opt/conda && \ + rm miniconda.sh +ENV PATH="/opt/conda/bin:$PATH" + +RUN conda create -n omni python=3.12 && \ + echo "source activate omni" > ~/.bashrc +ENV CONDA_DEFAULT_ENV=omni +ENV PATH="/opt/conda/envs/omni/bin:$PATH" + +WORKDIR /app + +RUN git clone https://github.com/microsoft/OmniParser.git && \ + cd OmniParser && \ + git lfs install && \ + git lfs pull + +WORKDIR /app/OmniParser + +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + pip uninstall -y opencv-python opencv-python-headless && \ + pip install --no-cache-dir opencv-python-headless==4.8.1.78 && \ + pip install -r requirements.txt && \ + pip install huggingface_hub fastapi uvicorn + +# Download V2 weights +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + mkdir -p /app/OmniParser/weights && \ + cd /app/OmniParser && \ + rm -rf weights/icon_detect weights/icon_caption weights/icon_caption_florence && \ + for folder in icon_caption icon_detect; do \ + huggingface-cli download microsoft/OmniParser-v2.0 --local-dir weights --repo-type model --include "$folder/*"; \ + done && \ + mv weights/icon_caption weights/icon_caption_florence + +CMD ["python3", "/app/OmniParser/omnitool/omniparserserver/omniparserserver.py", \ + "--som_model_path", "/app/OmniParser/weights/icon_detect/model.pt", \ + "--caption_model_path", "/app/OmniParser/weights/icon_caption_florence", \ + "--device", "cuda", \ + "--BOX_TRESHOLD", "0.05", \ + "--host", "0.0.0.0", \ + "--port", "8000"] diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py new file mode 100644 index 000000000..774f6e840 --- /dev/null +++ b/deploy/deploy/models/omniparser/deploy.py @@ -0,0 +1,547 @@ +import os +import subprocess +import time + +from botocore.exceptions import ClientError +from loguru import logger +from pydantic_settings import BaseSettings +import boto3 +import fire +import paramiko + +class Config(BaseSettings): + AWS_ACCESS_KEY_ID: str + AWS_SECRET_ACCESS_KEY: str + AWS_REGION: str + + PROJECT_NAME: str = "omniparser" + REPO_URL: str = "https://github.com/microsoft/OmniParser.git" + AWS_EC2_AMI: str = "ami-06835d15c4de57810" + AWS_EC2_DISK_SIZE: int = 128 # GB + AWS_EC2_INSTANCE_TYPE: str = "g4dn.xlarge" # (T4 16GB $0.526/hr x86_64) + AWS_EC2_USER: str = "ubuntu" + PORT: int = 8000 # FastAPI port + COMMAND_TIMEOUT: int = 600 # 10 minutes + + class Config: + env_file = ".env" + env_file_encoding = 'utf-8' + + @property + def AWS_EC2_KEY_NAME(self) -> str: + return f"{self.PROJECT_NAME}-key" + + @property + def AWS_EC2_KEY_PATH(self) -> str: + return f"./{self.AWS_EC2_KEY_NAME}.pem" + + @property + def AWS_EC2_SECURITY_GROUP(self) -> str: + return f"{self.PROJECT_NAME}-SecurityGroup" + +config = Config() + +def create_key_pair(key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH) -> str | None: + ec2_client = boto3.client('ec2', region_name=config.AWS_REGION) + try: + key_pair = ec2_client.create_key_pair(KeyName=key_name) + private_key = key_pair['KeyMaterial'] + + with open(key_path, "w") as key_file: + key_file.write(private_key) + os.chmod(key_path, 0o400) # Set read-only permissions + + logger.info(f"Key pair {key_name} created and saved to {key_path}") + return key_name + except ClientError as e: + logger.error(f"Error creating key pair: {e}") + return None + +def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: + ec2 = boto3.client('ec2', region_name=config.AWS_REGION) + + ip_permissions = [{ + 'IpProtocol': 'tcp', + 'FromPort': port, + 'ToPort': port, + 'IpRanges': [{'CidrIp': '0.0.0.0/0'}] + } for port in ports] + + try: + response = ec2.describe_security_groups(GroupNames=[config.AWS_EC2_SECURITY_GROUP]) + security_group_id = response['SecurityGroups'][0]['GroupId'] + logger.info(f"Security group '{config.AWS_EC2_SECURITY_GROUP}' already exists: {security_group_id}") + + for ip_permission in ip_permissions: + try: + ec2.authorize_security_group_ingress( + GroupId=security_group_id, + IpPermissions=[ip_permission] + ) + logger.info(f"Added inbound rule for port {ip_permission['FromPort']}") + except ClientError as e: + if e.response['Error']['Code'] == 'InvalidPermission.Duplicate': + logger.info(f"Rule for port {ip_permission['FromPort']} already exists") + else: + logger.error(f"Error adding rule for port {ip_permission['FromPort']}: {e}") + + return security_group_id + except ClientError as e: + if e.response['Error']['Code'] == 'InvalidGroup.NotFound': + try: + response = ec2.create_security_group( + GroupName=config.AWS_EC2_SECURITY_GROUP, + Description='Security group for OmniParser deployment', + TagSpecifications=[ + { + 'ResourceType': 'security-group', + 'Tags': [{'Key': 'Name', 'Value': config.PROJECT_NAME}] + } + ] + ) + security_group_id = response['GroupId'] + logger.info(f"Created security group '{config.AWS_EC2_SECURITY_GROUP}' with ID: {security_group_id}") + + ec2.authorize_security_group_ingress(GroupId=security_group_id, IpPermissions=ip_permissions) + logger.info(f"Added inbound rules for ports {ports}") + + return security_group_id + except ClientError as e: + logger.error(f"Error creating security group: {e}") + return None + else: + logger.error(f"Error describing security groups: {e}") + return None + +def deploy_ec2_instance( + ami: str = config.AWS_EC2_AMI, + instance_type: str = config.AWS_EC2_INSTANCE_TYPE, + project_name: str = config.PROJECT_NAME, + key_name: str = config.AWS_EC2_KEY_NAME, + disk_size: int = config.AWS_EC2_DISK_SIZE, +) -> tuple[str | None, str | None]: + ec2 = boto3.resource('ec2') + ec2_client = boto3.client('ec2') + + # Check for existing instances first + instances = ec2.instances.filter( + Filters=[ + {'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}, + {'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopped']} + ] + ) + + existing_instance = None + for instance in instances: + existing_instance = instance + if instance.state['Name'] == 'running': + logger.info(f"Instance already running: ID - {instance.id}, IP - {instance.public_ip_address}") + break + elif instance.state['Name'] == 'stopped': + logger.info(f"Starting existing stopped instance: ID - {instance.id}") + ec2_client.start_instances(InstanceIds=[instance.id]) + instance.wait_until_running() + instance.reload() + logger.info(f"Instance started: ID - {instance.id}, IP - {instance.public_ip_address}") + break + + # If we found an existing instance, ensure we have its key + if existing_instance: + if not os.path.exists(config.AWS_EC2_KEY_PATH): + logger.warning(f"Key file {config.AWS_EC2_KEY_PATH} not found for existing instance.") + logger.warning("You'll need to use the original key file to connect to this instance.") + logger.warning("Consider terminating the instance with 'deploy.py stop' and starting fresh.") + return None, None + return existing_instance.id, existing_instance.public_ip_address + + # No existing instance found, create new one with new key pair + security_group_id = get_or_create_security_group_id() + if not security_group_id: + logger.error("Unable to retrieve security group ID. Instance deployment aborted.") + return None, None + + # Create new key pair + try: + if os.path.exists(config.AWS_EC2_KEY_PATH): + logger.info(f"Removing existing key file {config.AWS_EC2_KEY_PATH}") + os.remove(config.AWS_EC2_KEY_PATH) + + try: + ec2_client.delete_key_pair(KeyName=key_name) + logger.info(f"Deleted existing key pair {key_name}") + except ClientError: + pass # Key pair doesn't exist, which is fine + + if not create_key_pair(key_name): + logger.error("Failed to create key pair") + return None, None + except Exception as e: + logger.error(f"Error managing key pair: {e}") + return None, None + + # Create new instance + ebs_config = { + 'DeviceName': '/dev/sda1', + 'Ebs': { + 'VolumeSize': disk_size, + 'VolumeType': 'gp3', + 'DeleteOnTermination': True + }, + } + + new_instance = ec2.create_instances( + ImageId=ami, + MinCount=1, + MaxCount=1, + InstanceType=instance_type, + KeyName=key_name, + SecurityGroupIds=[security_group_id], + BlockDeviceMappings=[ebs_config], + TagSpecifications=[ + { + 'ResourceType': 'instance', + 'Tags': [{'Key': 'Name', 'Value': project_name}] + }, + ] + )[0] + + new_instance.wait_until_running() + new_instance.reload() + logger.info(f"New instance created: ID - {new_instance.id}, IP - {new_instance.public_ip_address}") + return new_instance.id, new_instance.public_ip_address + +def configure_ec2_instance( + instance_id: str | None = None, + instance_ip: str | None = None, + max_ssh_retries: int = 20, + ssh_retry_delay: int = 20, + max_cmd_retries: int = 20, + cmd_retry_delay: int = 30, +) -> tuple[str | None, str | None]: + if not instance_id: + ec2_instance_id, ec2_instance_ip = deploy_ec2_instance() + else: + ec2_instance_id = instance_id + ec2_instance_ip = instance_ip + + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + ssh_retries = 0 + while ssh_retries < max_ssh_retries: + try: + ssh_client.connect(hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key) + break + except Exception as e: + ssh_retries += 1 + logger.error(f"SSH connection attempt {ssh_retries} failed: {e}") + if ssh_retries < max_ssh_retries: + logger.info(f"Retrying SSH connection in {ssh_retry_delay} seconds...") + time.sleep(ssh_retry_delay) + else: + logger.error("Maximum SSH connection attempts reached. Aborting.") + return None, None + + commands = [ + "sudo apt-get update", + "sudo apt-get install -y ca-certificates curl gnupg", + "sudo install -m 0755 -d /etc/apt/keyrings", + "curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo dd of=/etc/apt/keyrings/docker.gpg", + "sudo chmod a+r /etc/apt/keyrings/docker.gpg", + 'echo "deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu "$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null', + "sudo apt-get update", + "sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin", + "sudo systemctl start docker", + "sudo systemctl enable docker", + "sudo usermod -a -G docker ${USER}", + "sudo docker system prune -af --volumes", + f"sudo docker rm -f {config.PROJECT_NAME}-container || true", + ] + + for command in commands: + logger.info(f"Executing command: {command}") + cmd_retries = 0 + while cmd_retries < max_cmd_retries: + stdin, stdout, stderr = ssh_client.exec_command(command) + exit_status = stdout.channel.recv_exit_status() + + if exit_status == 0: + logger.info(f"Command executed successfully") + break + else: + error_message = stderr.read() + if "Could not get lock" in str(error_message): + cmd_retries += 1 + logger.warning(f"dpkg is locked, retrying in {cmd_retry_delay} seconds... Attempt {cmd_retries}/{max_cmd_retries}") + time.sleep(cmd_retry_delay) + else: + logger.error(f"Error in command: {command}, Exit Status: {exit_status}, Error: {error_message}") + break + + ssh_client.close() + return ec2_instance_id, ec2_instance_ip + +class Deploy: + @staticmethod + def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: + """Execute a command and handle its output safely.""" + logger.info(f"Executing: {command}") + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=config.COMMAND_TIMEOUT, + # get_pty=True + ) + + # Stream output in real-time + while not stdout.channel.exit_status_ready(): + if stdout.channel.recv_ready(): + try: + line = stdout.channel.recv(1024).decode('utf-8', errors='replace') + if line.strip(): # Only log non-empty lines + logger.info(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stdout: {e}") + + if stdout.channel.recv_stderr_ready(): + try: + line = stdout.channel.recv_stderr(1024).decode('utf-8', errors='replace') + if line.strip(): # Only log non-empty lines + logger.error(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stderr: {e}") + + exit_status = stdout.channel.recv_exit_status() + + # Capture any remaining output + try: + remaining_stdout = stdout.read().decode('utf-8', errors='replace') + if remaining_stdout.strip(): + logger.info(remaining_stdout.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stdout: {e}") + + try: + remaining_stderr = stderr.read().decode('utf-8', errors='replace') + if remaining_stderr.strip(): + logger.error(remaining_stderr.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stderr: {e}") + + if exit_status != 0: + error_msg = f"Command failed with exit status {exit_status}: {command}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully executed: {command}") + + @staticmethod + def start() -> None: + try: + instance_id, instance_ip = configure_ec2_instance() + assert instance_ip, f"invalid {instance_ip=}" + + # Trigger driver installation via login shell + Deploy.ssh(non_interactive=True) + + # Get the directory containing deploy.py + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Define files to copy + files_to_copy = { + "Dockerfile": os.path.join(current_dir, "Dockerfile"), + ".dockerignore": os.path.join(current_dir, ".dockerignore"), + } + + # Copy files to instance + for filename, filepath in files_to_copy.items(): + if os.path.exists(filepath): + logger.info(f"Copying {filename} to instance...") + subprocess.run([ + "scp", + "-i", config.AWS_EC2_KEY_PATH, + "-o", "StrictHostKeyChecking=no", + filepath, + f"{config.AWS_EC2_USER}@{instance_ip}:~/{filename}" + ], check=True) + else: + logger.warning(f"File not found: {filepath}") + + # Connect to instance and execute commands + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + logger.info(f"Connecting to {instance_ip}...") + ssh_client.connect( + hostname=instance_ip, + username=config.AWS_EC2_USER, + pkey=key, + timeout=30 + ) + + setup_commands = [ + "rm -rf OmniParser", # Clean up any existing repo + f"git clone {config.REPO_URL}", + "cp Dockerfile .dockerignore OmniParser/", + ] + + # Execute setup commands + for command in setup_commands: + logger.info(f"Executing setup command: {command}") + Deploy.execute_command(ssh_client, command) + + # Build and run Docker container + docker_commands = [ + # Remove any existing container + "sudo docker rm -f omniparser-container || true", + # Remove any existing image + "sudo docker rmi omniparser || true", + # Build new image + "cd OmniParser && sudo docker build --progress=plain -t omniparser .", + # Run new container + "sudo docker run -d -p 8000:8000 --gpus all --name omniparser-container omniparser" + ] + + # Execute Docker commands + for command in docker_commands: + logger.info(f"Executing Docker command: {command}") + Deploy.execute_command(ssh_client, command) + + # Wait for container to start and check its logs + logger.info("Waiting for container to start...") + time.sleep(10) # Give container time to start + Deploy.execute_command(ssh_client, "docker logs omniparser-container") + + # Wait for server to become responsive + logger.info("Waiting for server to become responsive...") + max_retries = 30 + retry_delay = 10 + server_ready = False + + for attempt in range(max_retries): + try: + # Check if server is responding + check_command = f"curl -s http://localhost:{config.PORT}/probe/" + Deploy.execute_command(ssh_client, check_command) + server_ready = True + break + except Exception as e: + logger.warning(f"Server not ready (attempt {attempt + 1}/{max_retries}): {e}") + if attempt < max_retries - 1: + logger.info(f"Waiting {retry_delay} seconds before next attempt...") + time.sleep(retry_delay) + + if not server_ready: + raise RuntimeError("Server failed to start properly") + + # Final status check + Deploy.execute_command(ssh_client, "docker ps | grep omniparser-container") + + server_url = f"http://{instance_ip}:{config.PORT}" + logger.info(f"Deployment complete. Server running at: {server_url}") + + # Verify server is accessible from outside + try: + import requests + response = requests.get(f"{server_url}/probe/", timeout=10) + if response.status_code == 200: + logger.info("Server is accessible from outside!") + else: + logger.warning(f"Server responded with status code: {response.status_code}") + except Exception as e: + logger.warning(f"Could not verify external access: {e}") + + except Exception as e: + logger.error(f"Error during deployment: {e}") + # Get container logs for debugging + try: + Deploy.execute_command(ssh_client, "docker logs omniparser-container") + except: + pass + raise + + finally: + ssh_client.close() + + except Exception as e: + logger.error(f"Deployment failed: {e}") + # Attempt cleanup on failure + try: + Deploy.stop() + except Exception as cleanup_error: + logger.error(f"Cleanup after failure also failed: {cleanup_error}") + raise + + logger.info("Deployment completed successfully!") + + @staticmethod + def status() -> None: + ec2 = boto3.resource('ec2') + instances = ec2.instances.filter( + Filters=[{'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}] + ) + + for instance in instances: + public_ip = instance.public_ip_address + if public_ip: + server_url = f"http://{public_ip}:{config.PORT}" + logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, URL: {server_url}") + else: + logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, URL: Not available (no public IP)") + + @staticmethod + def ssh(non_interactive: bool = False) -> None: + """SSH into the running instance.""" + # Get instance IP + ec2 = boto3.resource('ec2') + instances = ec2.instances.filter( + Filters=[ + {'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}, + {'Name': 'instance-state-name', 'Values': ['running']} + ] + ) + + instance = next(iter(instances), None) + if not instance: + logger.error("No running instance found") + return + + ip = instance.public_ip_address + if not ip: + logger.error("Instance has no public IP") + return + + # Check if key file exists + if not os.path.exists(config.AWS_EC2_KEY_PATH): + logger.error(f"Key file not found: {config.AWS_EC2_KEY_PATH}") + return + + if non_interactive: + # Simulate full login by forcing all initialization scripts + ssh_command = [ + "ssh", + "-o", "StrictHostKeyChecking=no", # Automatically accept new host keys + "-o", "UserKnownHostsFile=/dev/null", # Prevent writing to known_hosts + "-i", config.AWS_EC2_KEY_PATH, + f"{config.AWS_EC2_USER}@{ip}", + "-t", # Allocate a pseudo-terminal + "-tt", # Force pseudo-terminal allocation + "bash --login -c 'exit'" # Force a full login shell and exit immediately + ] + else: + # Build and execute SSH command + ssh_command = f"ssh -i {config.AWS_EC2_KEY_PATH} -o StrictHostKeyChecking=no {config.AWS_EC2_USER}@{ip}" + logger.info(f"Connecting with: {ssh_command}") + os.system(ssh_command) + return + + # Execute the SSH command for non-interactive mode + try: + subprocess.run(ssh_command, check=True) + except subprocess.CalledProcessError as e: + logger.error(f"SSH connection failed: {e}") + +if __name__ == "__main__": + fire.Fire(Deploy) diff --git a/deploy/pyproject.toml b/deploy/pyproject.toml new file mode 100644 index 000000000..224d4b956 --- /dev/null +++ b/deploy/pyproject.toml @@ -0,0 +1,20 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "deploy" +version = "0.1.0" +authors = [ + { name="Richard Abrich", email="richard@openadapt.ai" }, +] +description = "Deployment tools for OpenAdapt models" +requires-python = ">=3.10" +dependencies = [ + "boto3>=1.36.22", + "fire>=0.7.0", + "loguru>=0.7.0", + "paramiko>=3.5.1", + "pydantic>=2.10.6", + "pydantic-settings>=2.7.1", +] From 3e5251a80638c33e89df0e53dce7fe1c88170eff Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 18 Feb 2025 15:47:51 -0500 Subject: [PATCH 02/13] add Deploy.stop --- deploy/deploy/models/omniparser/deploy.py | 40 +++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index 774f6e840..9b4045d4b 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -543,5 +543,45 @@ def ssh(non_interactive: bool = False) -> None: except subprocess.CalledProcessError as e: logger.error(f"SSH connection failed: {e}") + @staticmethod + def stop( + project_name: str = config.PROJECT_NAME, + security_group_name: str = config.AWS_EC2_SECURITY_GROUP, + ) -> None: + """ + Terminates the EC2 instance and deletes the associated security group. + + Args: + project_name (str): The project name used to tag the instance. Defaults to config.PROJECT_NAME. + security_group_name (str): The name of the security group to delete. Defaults to config.AWS_EC2_SECURITY_GROUP. + """ + ec2_resource = boto3.resource('ec2') + ec2_client = boto3.client('ec2') + + # Terminate EC2 instances + instances = ec2_resource.instances.filter( + Filters=[ + {'Name': 'tag:Name', 'Values': [project_name]}, + {'Name': 'instance-state-name', 'Values': ['pending', 'running', 'shutting-down', 'stopped', 'stopping']} + ] + ) + + for instance in instances: + logger.info(f"Terminating instance: ID - {instance.id}") + instance.terminate() + instance.wait_until_terminated() + logger.info(f"Instance {instance.id} terminated successfully.") + + # Delete security group + try: + ec2_client.delete_security_group(GroupName=security_group_name) + logger.info(f"Deleted security group: {security_group_name}") + except ClientError as e: + if e.response['Error']['Code'] == 'InvalidGroup.NotFound': + logger.info(f"Security group {security_group_name} does not exist or already deleted.") + else: + logger.error(f"Error deleting security group: {e}") + + if __name__ == "__main__": fire.Fire(Deploy) From 96b0f5ae7565cd3585f35f7110d6411dfa65c571 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 18 Feb 2025 16:55:56 -0500 Subject: [PATCH 03/13] add client.py --- deploy/deploy/models/omniparser/client.py | 103 ++++++++++++++++++++++ deploy/pyproject.toml | 2 + 2 files changed, 105 insertions(+) create mode 100644 deploy/deploy/models/omniparser/client.py diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py new file mode 100644 index 000000000..852983a57 --- /dev/null +++ b/deploy/deploy/models/omniparser/client.py @@ -0,0 +1,103 @@ +import argparse +import base64 +import io +import requests + +from loguru import logger +from PIL import Image, ImageDraw, ImageFont + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Omniparser Client') + parser.add_argument('--image_path', type=str, required=True, help='Path to the image') + parser.add_argument('--server_url', type=str, required=True, + help='URL of the Omniparser server (e.g., http://localhost:8000)') + args = parser.parse_args() + return args + +def image_to_base64(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + +def base64_to_image(base64_string): + img_data = base64.b64decode(base64_string) + return Image.open(io.BytesIO(img_data)) + +def plot_results(original_image_path, som_image_base64, parsed_content_list): + # Open original image + image = Image.open(original_image_path) + width, height = image.size + + # Create drawable image + draw = ImageDraw.Draw(image) + + # Draw bounding boxes and labels + for item in parsed_content_list: + # Get normalized coordinates and convert to pixel coordinates + x1, y1, x2, y2 = item['bbox'] + x1 = int(x1 * width) + y1 = int(y1 * height) + x2 = int(x2 * width) + y2 = int(y2 * height) + + label = item['content'] + + # Draw rectangle + draw.rectangle([(x1, y1), (x2, y2)], outline='red', width=2) + + # Draw label background + text_bbox = draw.textbbox((x1, y1), label) + draw.rectangle([text_bbox[0]-2, text_bbox[1]-2, text_bbox[2]+2, text_bbox[3]+2], + fill='white') + + # Draw label text + draw.text((x1, y1), label, fill='red') + + # Show image + image.show() + +def main(): + args = parse_arguments() + + # Remove trailing slash from server_url if present + server_url = args.server_url.rstrip('/') + + # Convert image to base64 + base64_image = image_to_base64(args.image_path) + + # Prepare request + url = f"{server_url}/parse/" + payload = {"base64_image": base64_image} + + try: + # First, check if the server is available + probe_url = f"{server_url}/probe/" + probe_response = requests.get(probe_url) + probe_response.raise_for_status() + logger.info("Server is available") + + # Make request to API + response = requests.post(url, json=payload) + response.raise_for_status() + + # Parse response + result = response.json() + som_image_base64 = result['som_image_base64'] + parsed_content_list = result['parsed_content_list'] + + # Plot results + plot_results(args.image_path, som_image_base64, parsed_content_list) + + # Print latency + logger.info(f"API Latency: {result['latency']:.2f} seconds") + + except requests.exceptions.ConnectionError: + logger.error(f"Error: Could not connect to server at {server_url}") + logger.error("Please check if the server is running and the URL is correct") + except requests.exceptions.RequestException as e: + logger.error(f"Error making request to API: {e}") + except Exception as e: + logger.error(f"Error: {e}") + +if __name__ == "__main__": + main() + diff --git a/deploy/pyproject.toml b/deploy/pyproject.toml index 224d4b956..835b62424 100644 --- a/deploy/pyproject.toml +++ b/deploy/pyproject.toml @@ -15,6 +15,8 @@ dependencies = [ "fire>=0.7.0", "loguru>=0.7.0", "paramiko>=3.5.1", + "pillow>=11.1.0", "pydantic>=2.10.6", "pydantic-settings>=2.7.1", + "requests>=2.32.3", ] From 959c78b3cfab4b51a3d3090ec7badb443e80f719 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 18 Feb 2025 17:07:10 -0500 Subject: [PATCH 04/13] ruff --- deploy/deploy/models/omniparser/client.py | 77 +++-- deploy/deploy/models/omniparser/deploy.py | 344 +++++++++++++++------- 2 files changed, 276 insertions(+), 145 deletions(-) diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py index 852983a57..c19c04e56 100644 --- a/deploy/deploy/models/omniparser/client.py +++ b/deploy/deploy/models/omniparser/client.py @@ -4,92 +4,105 @@ import requests from loguru import logger -from PIL import Image, ImageDraw, ImageFont +from PIL import Image, ImageDraw + def parse_arguments(): - parser = argparse.ArgumentParser(description='Omniparser Client') - parser.add_argument('--image_path', type=str, required=True, help='Path to the image') - parser.add_argument('--server_url', type=str, required=True, - help='URL of the Omniparser server (e.g., http://localhost:8000)') + parser = argparse.ArgumentParser(description="Omniparser Client") + parser.add_argument( + "--image_path", type=str, required=True, help="Path to the image" + ) + parser.add_argument( + "--server_url", + type=str, + required=True, + help="URL of the Omniparser server (e.g., http://localhost:8000)", + ) args = parser.parse_args() return args + def image_to_base64(image_path): with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') + return base64.b64encode(image_file.read()).decode("utf-8") + def base64_to_image(base64_string): img_data = base64.b64decode(base64_string) return Image.open(io.BytesIO(img_data)) + def plot_results(original_image_path, som_image_base64, parsed_content_list): # Open original image image = Image.open(original_image_path) width, height = image.size - + # Create drawable image draw = ImageDraw.Draw(image) - + # Draw bounding boxes and labels for item in parsed_content_list: # Get normalized coordinates and convert to pixel coordinates - x1, y1, x2, y2 = item['bbox'] + x1, y1, x2, y2 = item["bbox"] x1 = int(x1 * width) y1 = int(y1 * height) x2 = int(x2 * width) y2 = int(y2 * height) - - label = item['content'] - + + label = item["content"] + # Draw rectangle - draw.rectangle([(x1, y1), (x2, y2)], outline='red', width=2) - + draw.rectangle([(x1, y1), (x2, y2)], outline="red", width=2) + # Draw label background text_bbox = draw.textbbox((x1, y1), label) - draw.rectangle([text_bbox[0]-2, text_bbox[1]-2, text_bbox[2]+2, text_bbox[3]+2], - fill='white') - + draw.rectangle( + [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2], + fill="white", + ) + # Draw label text - draw.text((x1, y1), label, fill='red') - + draw.text((x1, y1), label, fill="red") + # Show image image.show() + def main(): args = parse_arguments() - + # Remove trailing slash from server_url if present - server_url = args.server_url.rstrip('/') - + server_url = args.server_url.rstrip("/") + # Convert image to base64 base64_image = image_to_base64(args.image_path) - + # Prepare request url = f"{server_url}/parse/" payload = {"base64_image": base64_image} - + try: # First, check if the server is available probe_url = f"{server_url}/probe/" probe_response = requests.get(probe_url) probe_response.raise_for_status() logger.info("Server is available") - + # Make request to API response = requests.post(url, json=payload) response.raise_for_status() - + # Parse response result = response.json() - som_image_base64 = result['som_image_base64'] - parsed_content_list = result['parsed_content_list'] - + som_image_base64 = result["som_image_base64"] + parsed_content_list = result["parsed_content_list"] + # Plot results plot_results(args.image_path, som_image_base64, parsed_content_list) - + # Print latency logger.info(f"API Latency: {result['latency']:.2f} seconds") - + except requests.exceptions.ConnectionError: logger.error(f"Error: Could not connect to server at {server_url}") logger.error("Please check if the server is running and the URL is correct") @@ -98,6 +111,6 @@ def main(): except Exception as e: logger.error(f"Error: {e}") + if __name__ == "__main__": main() - diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index 9b4045d4b..f9c52cd7e 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -9,6 +9,7 @@ import fire import paramiko + class Config(BaseSettings): AWS_ACCESS_KEY_ID: str AWS_SECRET_ACCESS_KEY: str @@ -25,7 +26,7 @@ class Config(BaseSettings): class Config: env_file = ".env" - env_file_encoding = 'utf-8' + env_file_encoding = "utf-8" @property def AWS_EC2_KEY_NAME(self) -> str: @@ -39,13 +40,17 @@ def AWS_EC2_KEY_PATH(self) -> str: def AWS_EC2_SECURITY_GROUP(self) -> str: return f"{self.PROJECT_NAME}-SecurityGroup" + config = Config() -def create_key_pair(key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH) -> str | None: - ec2_client = boto3.client('ec2', region_name=config.AWS_REGION) + +def create_key_pair( + key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH +) -> str | None: + ec2_client = boto3.client("ec2", region_name=config.AWS_REGION) try: key_pair = ec2_client.create_key_pair(KeyName=key_name) - private_key = key_pair['KeyMaterial'] + private_key = key_pair["KeyMaterial"] with open(key_path, "w") as key_file: key_file.write(private_key) @@ -57,52 +62,69 @@ def create_key_pair(key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = con logger.error(f"Error creating key pair: {e}") return None -def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: - ec2 = boto3.client('ec2', region_name=config.AWS_REGION) - ip_permissions = [{ - 'IpProtocol': 'tcp', - 'FromPort': port, - 'ToPort': port, - 'IpRanges': [{'CidrIp': '0.0.0.0/0'}] - } for port in ports] +def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: + ec2 = boto3.client("ec2", region_name=config.AWS_REGION) + + ip_permissions = [ + { + "IpProtocol": "tcp", + "FromPort": port, + "ToPort": port, + "IpRanges": [{"CidrIp": "0.0.0.0/0"}], + } + for port in ports + ] try: - response = ec2.describe_security_groups(GroupNames=[config.AWS_EC2_SECURITY_GROUP]) - security_group_id = response['SecurityGroups'][0]['GroupId'] - logger.info(f"Security group '{config.AWS_EC2_SECURITY_GROUP}' already exists: {security_group_id}") + response = ec2.describe_security_groups( + GroupNames=[config.AWS_EC2_SECURITY_GROUP] + ) + security_group_id = response["SecurityGroups"][0]["GroupId"] + logger.info( + f"Security group '{config.AWS_EC2_SECURITY_GROUP}' already exists: " + f"{security_group_id}" + ) for ip_permission in ip_permissions: try: ec2.authorize_security_group_ingress( - GroupId=security_group_id, - IpPermissions=[ip_permission] + GroupId=security_group_id, IpPermissions=[ip_permission] ) logger.info(f"Added inbound rule for port {ip_permission['FromPort']}") except ClientError as e: - if e.response['Error']['Code'] == 'InvalidPermission.Duplicate': - logger.info(f"Rule for port {ip_permission['FromPort']} already exists") + if e.response["Error"]["Code"] == "InvalidPermission.Duplicate": + logger.info( + f"Rule for port {ip_permission['FromPort']} already exists" + ) else: - logger.error(f"Error adding rule for port {ip_permission['FromPort']}: {e}") + logger.error( + f"Error adding rule for port {ip_permission['FromPort']}: {e}" + ) return security_group_id except ClientError as e: - if e.response['Error']['Code'] == 'InvalidGroup.NotFound': + if e.response["Error"]["Code"] == "InvalidGroup.NotFound": try: response = ec2.create_security_group( GroupName=config.AWS_EC2_SECURITY_GROUP, - Description='Security group for OmniParser deployment', + Description="Security group for OmniParser deployment", TagSpecifications=[ { - 'ResourceType': 'security-group', - 'Tags': [{'Key': 'Name', 'Value': config.PROJECT_NAME}] + "ResourceType": "security-group", + "Tags": [{"Key": "Name", "Value": config.PROJECT_NAME}], } - ] + ], + ) + security_group_id = response["GroupId"] + logger.info( + f"Created security group '{config.AWS_EC2_SECURITY_GROUP}' " + f"with ID: {security_group_id}" ) - security_group_id = response['GroupId'] - logger.info(f"Created security group '{config.AWS_EC2_SECURITY_GROUP}' with ID: {security_group_id}") - ec2.authorize_security_group_ingress(GroupId=security_group_id, IpPermissions=ip_permissions) + ec2.authorize_security_group_ingress( + GroupId=security_group_id, IpPermissions=ip_permissions + ) logger.info(f"Added inbound rules for ports {ports}") return security_group_id @@ -113,6 +135,7 @@ def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str logger.error(f"Error describing security groups: {e}") return None + def deploy_ec2_instance( ami: str = config.AWS_EC2_AMI, instance_type: str = config.AWS_EC2_INSTANCE_TYPE, @@ -120,44 +143,62 @@ def deploy_ec2_instance( key_name: str = config.AWS_EC2_KEY_NAME, disk_size: int = config.AWS_EC2_DISK_SIZE, ) -> tuple[str | None, str | None]: - ec2 = boto3.resource('ec2') - ec2_client = boto3.client('ec2') + ec2 = boto3.resource("ec2") + ec2_client = boto3.client("ec2") # Check for existing instances first instances = ec2.instances.filter( Filters=[ - {'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}, - {'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopped']} + {"Name": "tag:Name", "Values": [config.PROJECT_NAME]}, + { + "Name": "instance-state-name", + "Values": ["running", "pending", "stopped"], + }, ] ) existing_instance = None for instance in instances: existing_instance = instance - if instance.state['Name'] == 'running': - logger.info(f"Instance already running: ID - {instance.id}, IP - {instance.public_ip_address}") + if instance.state["Name"] == "running": + logger.info( + f"Instance already running: ID - {instance.id}, " + f"IP - {instance.public_ip_address}" + ) break - elif instance.state['Name'] == 'stopped': + elif instance.state["Name"] == "stopped": logger.info(f"Starting existing stopped instance: ID - {instance.id}") ec2_client.start_instances(InstanceIds=[instance.id]) instance.wait_until_running() instance.reload() - logger.info(f"Instance started: ID - {instance.id}, IP - {instance.public_ip_address}") + logger.info( + f"Instance started: ID - {instance.id}, " + f"IP - {instance.public_ip_address}" + ) break # If we found an existing instance, ensure we have its key if existing_instance: if not os.path.exists(config.AWS_EC2_KEY_PATH): - logger.warning(f"Key file {config.AWS_EC2_KEY_PATH} not found for existing instance.") - logger.warning("You'll need to use the original key file to connect to this instance.") - logger.warning("Consider terminating the instance with 'deploy.py stop' and starting fresh.") + logger.warning( + f"Key file {config.AWS_EC2_KEY_PATH} not found for existing instance." + ) + logger.warning( + "You'll need to use the original key file to connect to this instance." + ) + logger.warning( + "Consider terminating the instance with 'deploy.py stop' and starting " + "fresh." + ) return None, None return existing_instance.id, existing_instance.public_ip_address # No existing instance found, create new one with new key pair security_group_id = get_or_create_security_group_id() if not security_group_id: - logger.error("Unable to retrieve security group ID. Instance deployment aborted.") + logger.error( + "Unable to retrieve security group ID. Instance deployment aborted." + ) return None, None # Create new key pair @@ -165,7 +206,7 @@ def deploy_ec2_instance( if os.path.exists(config.AWS_EC2_KEY_PATH): logger.info(f"Removing existing key file {config.AWS_EC2_KEY_PATH}") os.remove(config.AWS_EC2_KEY_PATH) - + try: ec2_client.delete_key_pair(KeyName=key_name) logger.info(f"Deleted existing key pair {key_name}") @@ -181,11 +222,11 @@ def deploy_ec2_instance( # Create new instance ebs_config = { - 'DeviceName': '/dev/sda1', - 'Ebs': { - 'VolumeSize': disk_size, - 'VolumeType': 'gp3', - 'DeleteOnTermination': True + "DeviceName": "/dev/sda1", + "Ebs": { + "VolumeSize": disk_size, + "VolumeType": "gp3", + "DeleteOnTermination": True, }, } @@ -199,17 +240,21 @@ def deploy_ec2_instance( BlockDeviceMappings=[ebs_config], TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [{'Key': 'Name', 'Value': project_name}] + "ResourceType": "instance", + "Tags": [{"Key": "Name", "Value": project_name}], }, - ] + ], )[0] new_instance.wait_until_running() new_instance.reload() - logger.info(f"New instance created: ID - {new_instance.id}, IP - {new_instance.public_ip_address}") + logger.info( + f"New instance created: ID - {new_instance.id}, " + f"IP - {new_instance.public_ip_address}" + ) return new_instance.id, new_instance.public_ip_address + def configure_ec2_instance( instance_id: str | None = None, instance_ip: str | None = None, @@ -231,7 +276,9 @@ def configure_ec2_instance( ssh_retries = 0 while ssh_retries < max_ssh_retries: try: - ssh_client.connect(hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key) + ssh_client.connect( + hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key + ) break except Exception as e: ssh_retries += 1 @@ -247,11 +294,23 @@ def configure_ec2_instance( "sudo apt-get update", "sudo apt-get install -y ca-certificates curl gnupg", "sudo install -m 0755 -d /etc/apt/keyrings", - "curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo dd of=/etc/apt/keyrings/docker.gpg", + ( + "curl -fsSL https://download.docker.com/linux/ubuntu/gpg | " + "sudo dd of=/etc/apt/keyrings/docker.gpg" + ), "sudo chmod a+r /etc/apt/keyrings/docker.gpg", - 'echo "deb [arch="$(dpkg --print-architecture)" signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu "$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null', + ( + 'echo "deb [arch="$(dpkg --print-architecture)" ' + "signed-by=/etc/apt/keyrings/docker.gpg] " + "https://download.docker.com/linux/ubuntu " + '"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | ' + "sudo tee /etc/apt/sources.list.d/docker.list > /dev/null" + ), "sudo apt-get update", - "sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin", + ( + "sudo apt-get install -y docker-ce docker-ce-cli containerd.io " + "docker-buildx-plugin docker-compose-plugin" + ), "sudo systemctl start docker", "sudo systemctl enable docker", "sudo usermod -a -G docker ${USER}", @@ -267,72 +326,81 @@ def configure_ec2_instance( exit_status = stdout.channel.recv_exit_status() if exit_status == 0: - logger.info(f"Command executed successfully") + logger.info("Command executed successfully") break else: error_message = stderr.read() if "Could not get lock" in str(error_message): cmd_retries += 1 - logger.warning(f"dpkg is locked, retrying in {cmd_retry_delay} seconds... Attempt {cmd_retries}/{max_cmd_retries}") + logger.warning( + f"dpkg is locked, retrying in {cmd_retry_delay} seconds... " + f"Attempt {cmd_retries}/{max_cmd_retries}" + ) time.sleep(cmd_retry_delay) else: - logger.error(f"Error in command: {command}, Exit Status: {exit_status}, Error: {error_message}") + logger.error( + f"Error in command: {command}, Exit Status: {exit_status}, " + f"Error: {error_message}" + ) break ssh_client.close() return ec2_instance_id, ec2_instance_ip + class Deploy: @staticmethod def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: """Execute a command and handle its output safely.""" logger.info(f"Executing: {command}") stdin, stdout, stderr = ssh_client.exec_command( - command, + command, timeout=config.COMMAND_TIMEOUT, # get_pty=True ) - + # Stream output in real-time while not stdout.channel.exit_status_ready(): if stdout.channel.recv_ready(): try: - line = stdout.channel.recv(1024).decode('utf-8', errors='replace') + line = stdout.channel.recv(1024).decode("utf-8", errors="replace") if line.strip(): # Only log non-empty lines logger.info(line.strip()) except Exception as e: logger.warning(f"Error decoding stdout: {e}") - + if stdout.channel.recv_stderr_ready(): try: - line = stdout.channel.recv_stderr(1024).decode('utf-8', errors='replace') + line = stdout.channel.recv_stderr(1024).decode( + "utf-8", errors="replace" + ) if line.strip(): # Only log non-empty lines logger.error(line.strip()) except Exception as e: logger.warning(f"Error decoding stderr: {e}") - + exit_status = stdout.channel.recv_exit_status() - + # Capture any remaining output try: - remaining_stdout = stdout.read().decode('utf-8', errors='replace') + remaining_stdout = stdout.read().decode("utf-8", errors="replace") if remaining_stdout.strip(): logger.info(remaining_stdout.strip()) except Exception as e: logger.warning(f"Error decoding remaining stdout: {e}") - + try: - remaining_stderr = stderr.read().decode('utf-8', errors='replace') + remaining_stderr = stderr.read().decode("utf-8", errors="replace") if remaining_stderr.strip(): logger.error(remaining_stderr.strip()) except Exception as e: logger.warning(f"Error decoding remaining stderr: {e}") - + if exit_status != 0: error_msg = f"Command failed with exit status {exit_status}: {command}" logger.error(error_msg) raise RuntimeError(error_msg) - + logger.info(f"Successfully executed: {command}") @staticmethod @@ -346,7 +414,7 @@ def start() -> None: # Get the directory containing deploy.py current_dir = os.path.dirname(os.path.abspath(__file__)) - + # Define files to copy files_to_copy = { "Dockerfile": os.path.join(current_dir, "Dockerfile"), @@ -357,13 +425,18 @@ def start() -> None: for filename, filepath in files_to_copy.items(): if os.path.exists(filepath): logger.info(f"Copying {filename} to instance...") - subprocess.run([ - "scp", - "-i", config.AWS_EC2_KEY_PATH, - "-o", "StrictHostKeyChecking=no", - filepath, - f"{config.AWS_EC2_USER}@{instance_ip}:~/{filename}" - ], check=True) + subprocess.run( + [ + "scp", + "-i", + config.AWS_EC2_KEY_PATH, + "-o", + "StrictHostKeyChecking=no", + filepath, + f"{config.AWS_EC2_USER}@{instance_ip}:~/{filename}", + ], + check=True, + ) else: logger.warning(f"File not found: {filepath}") @@ -371,14 +444,14 @@ def start() -> None: key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) ssh_client = paramiko.SSHClient() ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - + try: logger.info(f"Connecting to {instance_ip}...") ssh_client.connect( hostname=instance_ip, username=config.AWS_EC2_USER, pkey=key, - timeout=30 + timeout=30, ) setup_commands = [ @@ -399,9 +472,15 @@ def start() -> None: # Remove any existing image "sudo docker rmi omniparser || true", # Build new image - "cd OmniParser && sudo docker build --progress=plain -t omniparser .", + ( + "cd OmniParser && sudo docker build --progress=plain " + "-t omniparser ." + ), # Run new container - "sudo docker run -d -p 8000:8000 --gpus all --name omniparser-container omniparser" + ( + "sudo docker run -d -p 8000:8000 --gpus all --name " + "omniparser-container omniparser" + ), ] # Execute Docker commands @@ -428,28 +507,38 @@ def start() -> None: server_ready = True break except Exception as e: - logger.warning(f"Server not ready (attempt {attempt + 1}/{max_retries}): {e}") + logger.warning( + f"Server not ready (attempt {attempt + 1}/{max_retries}): " + f"{e}" + ) if attempt < max_retries - 1: - logger.info(f"Waiting {retry_delay} seconds before next attempt...") + logger.info( + f"Waiting {retry_delay} seconds before next attempt..." + ) time.sleep(retry_delay) if not server_ready: raise RuntimeError("Server failed to start properly") # Final status check - Deploy.execute_command(ssh_client, "docker ps | grep omniparser-container") - + Deploy.execute_command( + ssh_client, "docker ps | grep omniparser-container" + ) + server_url = f"http://{instance_ip}:{config.PORT}" logger.info(f"Deployment complete. Server running at: {server_url}") - + # Verify server is accessible from outside try: import requests + response = requests.get(f"{server_url}/probe/", timeout=10) if response.status_code == 200: logger.info("Server is accessible from outside!") else: - logger.warning(f"Server responded with status code: {response.status_code}") + logger.warning( + f"Server responded with status code: {response.status_code}" + ) except Exception as e: logger.warning(f"Could not verify external access: {e}") @@ -457,8 +546,11 @@ def start() -> None: logger.error(f"Error during deployment: {e}") # Get container logs for debugging try: - Deploy.execute_command(ssh_client, "docker logs omniparser-container") - except: + Deploy.execute_command( + ssh_client, "docker logs omniparser-container" + ) + except Exception as exc: + logger.warning(f"{exc=}") pass raise @@ -478,61 +570,73 @@ def start() -> None: @staticmethod def status() -> None: - ec2 = boto3.resource('ec2') + ec2 = boto3.resource("ec2") instances = ec2.instances.filter( - Filters=[{'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}] + Filters=[{"Name": "tag:Name", "Values": [config.PROJECT_NAME]}] ) for instance in instances: public_ip = instance.public_ip_address if public_ip: server_url = f"http://{public_ip}:{config.PORT}" - logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, URL: {server_url}") + logger.info( + f"Instance ID: {instance.id}, State: {instance.state['Name']}, " + f"URL: {server_url}" + ) else: - logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, URL: Not available (no public IP)") + logger.info( + f"Instance ID: {instance.id}, State: {instance.state['Name']}, " + f"URL: Not available (no public IP)" + ) @staticmethod def ssh(non_interactive: bool = False) -> None: """SSH into the running instance.""" # Get instance IP - ec2 = boto3.resource('ec2') + ec2 = boto3.resource("ec2") instances = ec2.instances.filter( Filters=[ - {'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}, - {'Name': 'instance-state-name', 'Values': ['running']} + {"Name": "tag:Name", "Values": [config.PROJECT_NAME]}, + {"Name": "instance-state-name", "Values": ["running"]}, ] ) - + instance = next(iter(instances), None) if not instance: logger.error("No running instance found") return - + ip = instance.public_ip_address if not ip: logger.error("Instance has no public IP") return - + # Check if key file exists if not os.path.exists(config.AWS_EC2_KEY_PATH): logger.error(f"Key file not found: {config.AWS_EC2_KEY_PATH}") return - + if non_interactive: # Simulate full login by forcing all initialization scripts ssh_command = [ "ssh", - "-o", "StrictHostKeyChecking=no", # Automatically accept new host keys - "-o", "UserKnownHostsFile=/dev/null", # Prevent writing to known_hosts - "-i", config.AWS_EC2_KEY_PATH, + "-o", + "StrictHostKeyChecking=no", # Automatically accept new host keys + "-o", + "UserKnownHostsFile=/dev/null", # Prevent writing to known_hosts + "-i", + config.AWS_EC2_KEY_PATH, f"{config.AWS_EC2_USER}@{ip}", "-t", # Allocate a pseudo-terminal "-tt", # Force pseudo-terminal allocation - "bash --login -c 'exit'" # Force a full login shell and exit immediately + "bash --login -c 'exit'", # Force full login shell and exit immediately ] else: # Build and execute SSH command - ssh_command = f"ssh -i {config.AWS_EC2_KEY_PATH} -o StrictHostKeyChecking=no {config.AWS_EC2_USER}@{ip}" + ssh_command = ( + f"ssh -i {config.AWS_EC2_KEY_PATH} -o StrictHostKeyChecking=no " + f"{config.AWS_EC2_USER}@{ip}" + ) logger.info(f"Connecting with: {ssh_command}") os.system(ssh_command) return @@ -552,17 +656,28 @@ def stop( Terminates the EC2 instance and deletes the associated security group. Args: - project_name (str): The project name used to tag the instance. Defaults to config.PROJECT_NAME. - security_group_name (str): The name of the security group to delete. Defaults to config.AWS_EC2_SECURITY_GROUP. + project_name (str): The project name used to tag the instance. + Defaults to config.PROJECT_NAME. + security_group_name (str): The name of the security group to delete. + Defaults to config.AWS_EC2_SECURITY_GROUP. """ - ec2_resource = boto3.resource('ec2') - ec2_client = boto3.client('ec2') + ec2_resource = boto3.resource("ec2") + ec2_client = boto3.client("ec2") # Terminate EC2 instances instances = ec2_resource.instances.filter( Filters=[ - {'Name': 'tag:Name', 'Values': [project_name]}, - {'Name': 'instance-state-name', 'Values': ['pending', 'running', 'shutting-down', 'stopped', 'stopping']} + {"Name": "tag:Name", "Values": [project_name]}, + { + "Name": "instance-state-name", + "Values": [ + "pending", + "running", + "shutting-down", + "stopped", + "stopping", + ], + }, ] ) @@ -577,8 +692,11 @@ def stop( ec2_client.delete_security_group(GroupName=security_group_name) logger.info(f"Deleted security group: {security_group_name}") except ClientError as e: - if e.response['Error']['Code'] == 'InvalidGroup.NotFound': - logger.info(f"Security group {security_group_name} does not exist or already deleted.") + if e.response["Error"]["Code"] == "InvalidGroup.NotFound": + logger.info( + f"Security group {security_group_name} does not exist or already " + "deleted." + ) else: logger.error(f"Error deleting security group: {e}") From 6290d17799e0d195d6faf31f87907ae364992164 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 18 Feb 2025 18:04:11 -0500 Subject: [PATCH 05/13] download OCR files during build; CLEANUP_ON_FAILURE = False --- deploy/deploy/models/omniparser/Dockerfile | 6 ++++++ deploy/deploy/models/omniparser/deploy.py | 14 +++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/deploy/deploy/models/omniparser/Dockerfile b/deploy/deploy/models/omniparser/Dockerfile index d6b63e398..f14ea7ac8 100644 --- a/deploy/deploy/models/omniparser/Dockerfile +++ b/deploy/deploy/models/omniparser/Dockerfile @@ -44,6 +44,12 @@ RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ done && \ mv weights/icon_caption weights/icon_caption_florence +# Pre-download OCR models during build +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + cd /app/OmniParser && \ + python3 -c "import easyocr; reader = easyocr.Reader(['en']); print('Downloaded EasyOCR model')" && \ + python3 -c "from paddleocr import PaddleOCR; ocr = PaddleOCR(lang='en', use_angle_cls=False, use_gpu=False, show_log=False); print('Downloaded PaddleOCR model')" + CMD ["python3", "/app/OmniParser/omnitool/omniparserserver/omniparserserver.py", \ "--som_model_path", "/app/OmniParser/weights/icon_detect/model.pt", \ "--caption_model_path", "/app/OmniParser/weights/icon_caption_florence", \ diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index f9c52cd7e..32ddb0e1e 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -10,6 +10,9 @@ import paramiko +CLEANUP_ON_FAILURE = False + + class Config(BaseSettings): AWS_ACCESS_KEY_ID: str AWS_SECRET_ACCESS_KEY: str @@ -559,11 +562,12 @@ def start() -> None: except Exception as e: logger.error(f"Deployment failed: {e}") - # Attempt cleanup on failure - try: - Deploy.stop() - except Exception as cleanup_error: - logger.error(f"Cleanup after failure also failed: {cleanup_error}") + if CLEANUP_ON_FAILURE: + # Attempt cleanup on failure + try: + Deploy.stop() + except Exception as cleanup_error: + logger.error(f"Cleanup after failure also failed: {cleanup_error}") raise logger.info("Deployment completed successfully!") From 944549f4227a2138a7273bf22625d90b4d1f986e Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 15:45:42 -0500 Subject: [PATCH 06/13] ruff --- deploy/deploy/models/omniparser/client.py | 31 +++++++++++++++++++++++ deploy/deploy/models/omniparser/deploy.py | 21 +++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py index c19c04e56..139855579 100644 --- a/deploy/deploy/models/omniparser/client.py +++ b/deploy/deploy/models/omniparser/client.py @@ -1,3 +1,5 @@ +"""Client module for interacting with the OmniParser server.""" + import argparse import base64 import io @@ -8,6 +10,11 @@ def parse_arguments(): + """Parse command line arguments. + + Returns: + argparse.Namespace: Parsed command line arguments + """ parser = argparse.ArgumentParser(description="Omniparser Client") parser.add_argument( "--image_path", type=str, required=True, help="Path to the image" @@ -23,16 +30,39 @@ def parse_arguments(): def image_to_base64(image_path): + """Convert an image file to base64 string. + + Args: + image_path: Path to the image file + + Returns: + str: Base64 encoded string of the image + """ with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") def base64_to_image(base64_string): + """Convert a base64 string to PIL Image. + + Args: + base64_string: Base64 encoded string of an image + + Returns: + Image.Image: PIL Image object + """ img_data = base64.b64decode(base64_string) return Image.open(io.BytesIO(img_data)) def plot_results(original_image_path, som_image_base64, parsed_content_list): + """Plot parsing results on the original image. + + Args: + original_image_path: Path to the original image + som_image_base64: Base64 encoded SOM image + parsed_content_list: List of parsed content with bounding boxes + """ # Open original image image = Image.open(original_image_path) width, height = image.size @@ -69,6 +99,7 @@ def plot_results(original_image_path, som_image_base64, parsed_content_list): def main(): + """Main entry point for the client application.""" args = parse_arguments() # Remove trailing slash from server_url if present diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index 32ddb0e1e..fe13d4268 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -1,3 +1,5 @@ +"""Deployment module for OmniParser on AWS EC2.""" + import os import subprocess import time @@ -14,6 +16,8 @@ class Config(BaseSettings): + """Configuration settings for deployment.""" + AWS_ACCESS_KEY_ID: str AWS_SECRET_ACCESS_KEY: str AWS_REGION: str @@ -28,19 +32,27 @@ class Config(BaseSettings): COMMAND_TIMEOUT: int = 600 # 10 minutes class Config: + """Pydantic configuration class.""" + env_file = ".env" env_file_encoding = "utf-8" @property def AWS_EC2_KEY_NAME(self) -> str: + """Get the EC2 key pair name.""" + return f"{self.PROJECT_NAME}-key" @property def AWS_EC2_KEY_PATH(self) -> str: + """Get the path to the EC2 key file.""" + return f"./{self.AWS_EC2_KEY_NAME}.pem" @property def AWS_EC2_SECURITY_GROUP(self) -> str: + """Get the EC2 security group name.""" + return f"{self.PROJECT_NAME}-SecurityGroup" @@ -50,6 +62,15 @@ def AWS_EC2_SECURITY_GROUP(self) -> str: def create_key_pair( key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH ) -> str | None: + """Create an EC2 key pair. + + Args: + key_name: Name of the key pair + key_path: Path where to save the key file + + Returns: + str | None: Key name if successful, None otherwise + """ ec2_client = boto3.client("ec2", region_name=config.AWS_REGION) try: key_pair = ec2_client.create_key_pair(KeyName=key_name) From b6ec99dabb966dc62e2f68a13398a46d6a25fdab Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 16:11:48 -0500 Subject: [PATCH 07/13] lint --- deploy/deploy/models/omniparser/client.py | 14 +- deploy/deploy/models/omniparser/deploy.py | 150 +++++++++++++--------- 2 files changed, 95 insertions(+), 69 deletions(-) diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py index 139855579..cddc17966 100644 --- a/deploy/deploy/models/omniparser/client.py +++ b/deploy/deploy/models/omniparser/client.py @@ -9,7 +9,7 @@ from PIL import Image, ImageDraw -def parse_arguments(): +def parse_arguments() -> argparse.Namespace: """Parse command line arguments. Returns: @@ -29,7 +29,7 @@ def parse_arguments(): return args -def image_to_base64(image_path): +def image_to_base64(image_path: str) -> str: """Convert an image file to base64 string. Args: @@ -42,7 +42,7 @@ def image_to_base64(image_path): return base64.b64encode(image_file.read()).decode("utf-8") -def base64_to_image(base64_string): +def base64_to_image(base64_string: str) -> Image.Image: """Convert a base64 string to PIL Image. Args: @@ -55,7 +55,11 @@ def base64_to_image(base64_string): return Image.open(io.BytesIO(img_data)) -def plot_results(original_image_path, som_image_base64, parsed_content_list): +def plot_results( + original_image_path: str, + som_image_base64: str, + parsed_content_list: list[dict[str, list[float]]], +) -> None: """Plot parsing results on the original image. Args: @@ -98,7 +102,7 @@ def plot_results(original_image_path, som_image_base64, parsed_content_list): image.show() -def main(): +def main() -> None: """Main entry point for the client application.""" args = parse_arguments() diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index fe13d4268..7ad09c113 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -88,6 +88,14 @@ def create_key_pair( def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: + """Get existing security group or create a new one. + + Args: + ports: List of ports to open in the security group + + Returns: + str | None: Security group ID if successful, None otherwise + """ ec2 = boto3.client("ec2", region_name=config.AWS_REGION) ip_permissions = [ @@ -167,6 +175,18 @@ def deploy_ec2_instance( key_name: str = config.AWS_EC2_KEY_NAME, disk_size: int = config.AWS_EC2_DISK_SIZE, ) -> tuple[str | None, str | None]: + """Deploy a new EC2 instance or return existing one. + + Args: + ami: AMI ID to use for the instance + instance_type: EC2 instance type + project_name: Name tag for the instance + key_name: Name of the key pair to use + disk_size: Size of the root volume in GB + + Returns: + tuple[str | None, str | None]: Instance ID and public IP if successful + """ ec2 = boto3.resource("ec2") ec2_client = boto3.client("ec2") @@ -371,64 +391,65 @@ def configure_ec2_instance( ssh_client.close() return ec2_instance_id, ec2_instance_ip +def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: + """Execute a command and handle its output safely.""" + logger.info(f"Executing: {command}") + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=config.COMMAND_TIMEOUT, + # get_pty=True + ) + + # Stream output in real-time + while not stdout.channel.exit_status_ready(): + if stdout.channel.recv_ready(): + try: + line = stdout.channel.recv(1024).decode("utf-8", errors="replace") + if line.strip(): # Only log non-empty lines + logger.info(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stdout: {e}") + + if stdout.channel.recv_stderr_ready(): + try: + line = stdout.channel.recv_stderr(1024).decode( + "utf-8", errors="replace" + ) + if line.strip(): # Only log non-empty lines + logger.error(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stderr: {e}") + + exit_status = stdout.channel.recv_exit_status() + + # Capture any remaining output + try: + remaining_stdout = stdout.read().decode("utf-8", errors="replace") + if remaining_stdout.strip(): + logger.info(remaining_stdout.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stdout: {e}") + + try: + remaining_stderr = stderr.read().decode("utf-8", errors="replace") + if remaining_stderr.strip(): + logger.error(remaining_stderr.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stderr: {e}") + + if exit_status != 0: + error_msg = f"Command failed with exit status {exit_status}: {command}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully executed: {command}") class Deploy: - @staticmethod - def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: - """Execute a command and handle its output safely.""" - logger.info(f"Executing: {command}") - stdin, stdout, stderr = ssh_client.exec_command( - command, - timeout=config.COMMAND_TIMEOUT, - # get_pty=True - ) - - # Stream output in real-time - while not stdout.channel.exit_status_ready(): - if stdout.channel.recv_ready(): - try: - line = stdout.channel.recv(1024).decode("utf-8", errors="replace") - if line.strip(): # Only log non-empty lines - logger.info(line.strip()) - except Exception as e: - logger.warning(f"Error decoding stdout: {e}") - - if stdout.channel.recv_stderr_ready(): - try: - line = stdout.channel.recv_stderr(1024).decode( - "utf-8", errors="replace" - ) - if line.strip(): # Only log non-empty lines - logger.error(line.strip()) - except Exception as e: - logger.warning(f"Error decoding stderr: {e}") - - exit_status = stdout.channel.recv_exit_status() - - # Capture any remaining output - try: - remaining_stdout = stdout.read().decode("utf-8", errors="replace") - if remaining_stdout.strip(): - logger.info(remaining_stdout.strip()) - except Exception as e: - logger.warning(f"Error decoding remaining stdout: {e}") - - try: - remaining_stderr = stderr.read().decode("utf-8", errors="replace") - if remaining_stderr.strip(): - logger.error(remaining_stderr.strip()) - except Exception as e: - logger.warning(f"Error decoding remaining stderr: {e}") - - if exit_status != 0: - error_msg = f"Command failed with exit status {exit_status}: {command}" - logger.error(error_msg) - raise RuntimeError(error_msg) - - logger.info(f"Successfully executed: {command}") + """Class handling deployment operations for OmniParser.""" @staticmethod def start() -> None: + """Start a new deployment of OmniParser on EC2.""" try: instance_id, instance_ip = configure_ec2_instance() assert instance_ip, f"invalid {instance_ip=}" @@ -487,7 +508,7 @@ def start() -> None: # Execute setup commands for command in setup_commands: logger.info(f"Executing setup command: {command}") - Deploy.execute_command(ssh_client, command) + execute_command(ssh_client, command) # Build and run Docker container docker_commands = [ @@ -510,12 +531,12 @@ def start() -> None: # Execute Docker commands for command in docker_commands: logger.info(f"Executing Docker command: {command}") - Deploy.execute_command(ssh_client, command) + execute_command(ssh_client, command) # Wait for container to start and check its logs logger.info("Waiting for container to start...") time.sleep(10) # Give container time to start - Deploy.execute_command(ssh_client, "docker logs omniparser-container") + execute_command(ssh_client, "docker logs omniparser-container") # Wait for server to become responsive logger.info("Waiting for server to become responsive...") @@ -527,7 +548,7 @@ def start() -> None: try: # Check if server is responding check_command = f"curl -s http://localhost:{config.PORT}/probe/" - Deploy.execute_command(ssh_client, check_command) + execute_command(ssh_client, check_command) server_ready = True break except Exception as e: @@ -545,9 +566,7 @@ def start() -> None: raise RuntimeError("Server failed to start properly") # Final status check - Deploy.execute_command( - ssh_client, "docker ps | grep omniparser-container" - ) + execute_command(ssh_client, "docker ps | grep omniparser-container") server_url = f"http://{instance_ip}:{config.PORT}" logger.info(f"Deployment complete. Server running at: {server_url}") @@ -570,9 +589,7 @@ def start() -> None: logger.error(f"Error during deployment: {e}") # Get container logs for debugging try: - Deploy.execute_command( - ssh_client, "docker logs omniparser-container" - ) + execute_command(ssh_client, "docker logs omniparser-container") except Exception as exc: logger.warning(f"{exc=}") pass @@ -595,6 +612,7 @@ def start() -> None: @staticmethod def status() -> None: + """Check the status of deployed instances.""" ec2 = boto3.resource("ec2") instances = ec2.instances.filter( Filters=[{"Name": "tag:Name", "Values": [config.PROJECT_NAME]}] @@ -616,7 +634,11 @@ def status() -> None: @staticmethod def ssh(non_interactive: bool = False) -> None: - """SSH into the running instance.""" + """SSH into the running instance. + + Args: + non_interactive: If True, run in non-interactive mode + """ # Get instance IP ec2 = boto3.resource("ec2") instances = ec2.instances.filter( From bfc49d21d7f7749794ee21566313614f5da3f405 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 16:17:42 -0500 Subject: [PATCH 08/13] ruff --- deploy/deploy/models/omniparser/deploy.py | 116 +++++++++++----------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index 7ad09c113..b4be073cb 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -89,10 +89,10 @@ def create_key_pair( def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: """Get existing security group or create a new one. - + Args: ports: List of ports to open in the security group - + Returns: str | None: Security group ID if successful, None otherwise """ @@ -176,14 +176,14 @@ def deploy_ec2_instance( disk_size: int = config.AWS_EC2_DISK_SIZE, ) -> tuple[str | None, str | None]: """Deploy a new EC2 instance or return existing one. - + Args: ami: AMI ID to use for the instance instance_type: EC2 instance type project_name: Name tag for the instance key_name: Name of the key pair to use disk_size: Size of the root volume in GB - + Returns: tuple[str | None, str | None]: Instance ID and public IP if successful """ @@ -391,61 +391,63 @@ def configure_ec2_instance( ssh_client.close() return ec2_instance_id, ec2_instance_ip + def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: - """Execute a command and handle its output safely.""" - logger.info(f"Executing: {command}") - stdin, stdout, stderr = ssh_client.exec_command( - command, - timeout=config.COMMAND_TIMEOUT, - # get_pty=True - ) - - # Stream output in real-time - while not stdout.channel.exit_status_ready(): - if stdout.channel.recv_ready(): - try: - line = stdout.channel.recv(1024).decode("utf-8", errors="replace") - if line.strip(): # Only log non-empty lines - logger.info(line.strip()) - except Exception as e: - logger.warning(f"Error decoding stdout: {e}") - - if stdout.channel.recv_stderr_ready(): - try: - line = stdout.channel.recv_stderr(1024).decode( - "utf-8", errors="replace" - ) - if line.strip(): # Only log non-empty lines - logger.error(line.strip()) - except Exception as e: - logger.warning(f"Error decoding stderr: {e}") - - exit_status = stdout.channel.recv_exit_status() - - # Capture any remaining output - try: - remaining_stdout = stdout.read().decode("utf-8", errors="replace") - if remaining_stdout.strip(): - logger.info(remaining_stdout.strip()) - except Exception as e: - logger.warning(f"Error decoding remaining stdout: {e}") - - try: - remaining_stderr = stderr.read().decode("utf-8", errors="replace") - if remaining_stderr.strip(): - logger.error(remaining_stderr.strip()) - except Exception as e: - logger.warning(f"Error decoding remaining stderr: {e}") - - if exit_status != 0: - error_msg = f"Command failed with exit status {exit_status}: {command}" - logger.error(error_msg) - raise RuntimeError(error_msg) - - logger.info(f"Successfully executed: {command}") + """Execute a command and handle its output safely.""" + logger.info(f"Executing: {command}") + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=config.COMMAND_TIMEOUT, + # get_pty=True + ) + + # Stream output in real-time + while not stdout.channel.exit_status_ready(): + if stdout.channel.recv_ready(): + try: + line = stdout.channel.recv(1024).decode("utf-8", errors="replace") + if line.strip(): # Only log non-empty lines + logger.info(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stdout: {e}") + + if stdout.channel.recv_stderr_ready(): + try: + line = stdout.channel.recv_stderr(1024).decode( + "utf-8", errors="replace" + ) + if line.strip(): # Only log non-empty lines + logger.error(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stderr: {e}") + + exit_status = stdout.channel.recv_exit_status() + + # Capture any remaining output + try: + remaining_stdout = stdout.read().decode("utf-8", errors="replace") + if remaining_stdout.strip(): + logger.info(remaining_stdout.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stdout: {e}") + + try: + remaining_stderr = stderr.read().decode("utf-8", errors="replace") + if remaining_stderr.strip(): + logger.error(remaining_stderr.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stderr: {e}") + + if exit_status != 0: + error_msg = f"Command failed with exit status {exit_status}: {command}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully executed: {command}") + class Deploy: - """Class handling deployment operations for OmniParser.""" + """Class handling deployment operations for OmniParser.""" @staticmethod def start() -> None: @@ -635,7 +637,7 @@ def status() -> None: @staticmethod def ssh(non_interactive: bool = False) -> None: """SSH into the running instance. - + Args: non_interactive: If True, run in non-interactive mode """ From bf9d548151cf5f115babea722bbc599e3865af7d Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 16:23:46 -0500 Subject: [PATCH 09/13] remove unused function --- deploy/deploy/models/omniparser/client.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py index cddc17966..ef639d168 100644 --- a/deploy/deploy/models/omniparser/client.py +++ b/deploy/deploy/models/omniparser/client.py @@ -2,7 +2,6 @@ import argparse import base64 -import io import requests from loguru import logger @@ -42,19 +41,6 @@ def image_to_base64(image_path: str) -> str: return base64.b64encode(image_file.read()).decode("utf-8") -def base64_to_image(base64_string: str) -> Image.Image: - """Convert a base64 string to PIL Image. - - Args: - base64_string: Base64 encoded string of an image - - Returns: - Image.Image: PIL Image object - """ - img_data = base64.b64decode(base64_string) - return Image.open(io.BytesIO(img_data)) - - def plot_results( original_image_path: str, som_image_base64: str, From f52a33e73e125c723f7d31ea730b8424233d6337 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 16:25:50 -0500 Subject: [PATCH 10/13] replace argparse with fire --- deploy/deploy/models/omniparser/client.py | 45 +++++++++-------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py index ef639d168..b765874da 100644 --- a/deploy/deploy/models/omniparser/client.py +++ b/deploy/deploy/models/omniparser/client.py @@ -1,33 +1,13 @@ """Client module for interacting with the OmniParser server.""" -import argparse import base64 +import fire import requests from loguru import logger from PIL import Image, ImageDraw -def parse_arguments() -> argparse.Namespace: - """Parse command line arguments. - - Returns: - argparse.Namespace: Parsed command line arguments - """ - parser = argparse.ArgumentParser(description="Omniparser Client") - parser.add_argument( - "--image_path", type=str, required=True, help="Path to the image" - ) - parser.add_argument( - "--server_url", - type=str, - required=True, - help="URL of the Omniparser server (e.g., http://localhost:8000)", - ) - args = parser.parse_args() - return args - - def image_to_base64(image_path: str) -> str: """Convert an image file to base64 string. @@ -88,15 +68,21 @@ def plot_results( image.show() -def main() -> None: - """Main entry point for the client application.""" - args = parse_arguments() +def parse_image( + image_path: str, + server_url: str, +) -> None: + """Parse an image using the OmniParser server. + Args: + image_path: Path to the image file + server_url: URL of the OmniParser server + """ # Remove trailing slash from server_url if present - server_url = args.server_url.rstrip("/") + server_url = server_url.rstrip("/") # Convert image to base64 - base64_image = image_to_base64(args.image_path) + base64_image = image_to_base64(image_path) # Prepare request url = f"{server_url}/parse/" @@ -119,7 +105,7 @@ def main() -> None: parsed_content_list = result["parsed_content_list"] # Plot results - plot_results(args.image_path, som_image_base64, parsed_content_list) + plot_results(image_path, som_image_base64, parsed_content_list) # Print latency logger.info(f"API Latency: {result['latency']:.2f} seconds") @@ -133,5 +119,10 @@ def main() -> None: logger.error(f"Error: {e}") +def main(): + """Main entry point for the client application.""" + fire.Fire(parse_image) + + if __name__ == "__main__": main() From 9a49bc2b3d0a769c4f6bdd59abbb959f999d37e9 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 16:34:11 -0500 Subject: [PATCH 11/13] config.CONTAINER_NAME --- deploy/deploy/models/omniparser/deploy.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index b4be073cb..ee6ea7a50 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -37,6 +37,11 @@ class Config: env_file = ".env" env_file_encoding = "utf-8" + @property + def CONTAINER_NAME(self) -> str: + """Get the container name.""" + return f"{self.PROJECT_NAME}-container" + @property def AWS_EC2_KEY_NAME(self) -> str: """Get the EC2 key pair name.""" @@ -515,18 +520,18 @@ def start() -> None: # Build and run Docker container docker_commands = [ # Remove any existing container - "sudo docker rm -f omniparser-container || true", + "sudo docker rm -f {config.CONTAINER_NAME} || true", # Remove any existing image - "sudo docker rmi omniparser || true", + "sudo docker rmi {config.PROJECT_NAME} || true", # Build new image ( "cd OmniParser && sudo docker build --progress=plain " - "-t omniparser ." + "-t {config.PROJECT_NAME} ." ), # Run new container ( "sudo docker run -d -p 8000:8000 --gpus all --name " - "omniparser-container omniparser" + "{config.CONTAINER_NAME} {config.PROJECT_NAME}" ), ] @@ -538,7 +543,7 @@ def start() -> None: # Wait for container to start and check its logs logger.info("Waiting for container to start...") time.sleep(10) # Give container time to start - execute_command(ssh_client, "docker logs omniparser-container") + execute_command(ssh_client, "docker logs {config.CONTAINER_NAME}") # Wait for server to become responsive logger.info("Waiting for server to become responsive...") @@ -568,7 +573,7 @@ def start() -> None: raise RuntimeError("Server failed to start properly") # Final status check - execute_command(ssh_client, "docker ps | grep omniparser-container") + execute_command(ssh_client, "docker ps | grep {config.CONTAINER_NAME}") server_url = f"http://{instance_ip}:{config.PORT}" logger.info(f"Deployment complete. Server running at: {server_url}") @@ -591,7 +596,7 @@ def start() -> None: logger.error(f"Error during deployment: {e}") # Get container logs for debugging try: - execute_command(ssh_client, "docker logs omniparser-container") + execute_command(ssh_client, "docker logs {config.CONTAINER_NAME}") except Exception as exc: logger.warning(f"{exc=}") pass From fac368bb5f4cfa9d669ba44cf94aee2aeb5dcab8 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Wed, 19 Feb 2025 16:37:05 -0500 Subject: [PATCH 12/13] lint --- deploy/deploy/models/omniparser/client.py | 2 +- deploy/deploy/models/omniparser/deploy.py | 36 +++++++++++++++++++---- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py index b765874da..c0cac4f49 100644 --- a/deploy/deploy/models/omniparser/client.py +++ b/deploy/deploy/models/omniparser/client.py @@ -119,7 +119,7 @@ def parse_image( logger.error(f"Error: {e}") -def main(): +def main() -> None: """Main entry point for the client application.""" fire.Fire(parse_image) diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py index ee6ea7a50..b951378bb 100644 --- a/deploy/deploy/models/omniparser/deploy.py +++ b/deploy/deploy/models/omniparser/deploy.py @@ -45,19 +45,16 @@ def CONTAINER_NAME(self) -> str: @property def AWS_EC2_KEY_NAME(self) -> str: """Get the EC2 key pair name.""" - return f"{self.PROJECT_NAME}-key" @property def AWS_EC2_KEY_PATH(self) -> str: """Get the path to the EC2 key file.""" - return f"./{self.AWS_EC2_KEY_NAME}.pem" @property def AWS_EC2_SECURITY_GROUP(self) -> str: """Get the EC2 security group name.""" - return f"{self.PROJECT_NAME}-SecurityGroup" @@ -312,6 +309,36 @@ def configure_ec2_instance( max_cmd_retries: int = 20, cmd_retry_delay: int = 30, ) -> tuple[str | None, str | None]: + """Configure an EC2 instance with necessary dependencies and Docker setup. + + This function either configures an existing EC2 instance specified by instance_id + and instance_ip, or deploys and configures a new instance. It installs Docker and + other required dependencies, and sets up the environment for running containers. + + Args: + instance_id: Optional ID of an existing EC2 instance to configure. + If None, a new instance will be deployed. + instance_ip: Optional IP address of an existing EC2 instance. + Required if instance_id is provided. + max_ssh_retries: Maximum number of SSH connection attempts. + Defaults to 20 attempts. + ssh_retry_delay: Delay in seconds between SSH connection attempts. + Defaults to 20 seconds. + max_cmd_retries: Maximum number of command execution retries. + Defaults to 20 attempts. + cmd_retry_delay: Delay in seconds between command execution retries. + Defaults to 30 seconds. + + Returns: + tuple[str | None, str | None]: A tuple containing: + - The instance ID (str) or None if configuration failed + - The instance's public IP address (str) or None if configuration failed + + Raises: + RuntimeError: If command execution fails + paramiko.SSHException: If SSH connection fails + Exception: For other unexpected errors during configuration + """ if not instance_id: ec2_instance_id, ec2_instance_ip = deploy_ec2_instance() else: @@ -706,8 +733,7 @@ def stop( project_name: str = config.PROJECT_NAME, security_group_name: str = config.AWS_EC2_SECURITY_GROUP, ) -> None: - """ - Terminates the EC2 instance and deletes the associated security group. + """Terminates the EC2 instance and deletes the associated security group. Args: project_name (str): The project name used to tag the instance. From 27d4d1dc50f40e339eaf770ce02acbe3e5ae723f Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Sat, 15 Mar 2025 18:26:45 -0400 Subject: [PATCH 13/13] Add ProcessGraphStrategy for task-based automation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This new strategy: 1. Uses OmniParser for parsing visual state and Gemini 2.0 for evaluation 2. Takes natural language task descriptions instead of recording IDs 3. Processes coalesced actions from events.py 4. Builds and maintains a process graph where states are nodes and actions are edges The graph is constructed before replay based on recording + task description, and is updated during replay based on observed states. Uses Pydantic for structured data handling and robust JSON parsing. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- openadapt/strategies/process_graph.py | 1214 +++++++++++++++++++++++++ 1 file changed, 1214 insertions(+) create mode 100644 openadapt/strategies/process_graph.py diff --git a/openadapt/strategies/process_graph.py b/openadapt/strategies/process_graph.py new file mode 100644 index 000000000..9af045b30 --- /dev/null +++ b/openadapt/strategies/process_graph.py @@ -0,0 +1,1214 @@ +"""Process graph-based replay strategy using OmniParser and Gemini 2.0. + +This strategy: +1. Uses OmniParser for parsing visual state and Gemini 2.0 for state evaluation +2. Takes natural language task descriptions instead of recording IDs +3. Processes coalesced actions from events.py +4. Builds and maintains a process graph G=(V,E) where: + - V represents States + - E represents Actions + - Graph is constructed before replay based on recording + task description + - Graph is updated during replay based on observed states +""" + +import json +import math +import time +import uuid +from typing import List, Optional, Dict, Union, Literal, Any +import numpy as np + +from pydantic import BaseModel, Field +from PIL import Image +from json_repair import repair_json, loads as repair_loads + +from openadapt import adapters, common, models, utils, vision +from openadapt.custom_logger import logger +from openadapt.db import crud +from openadapt.strategies.base import BaseReplayStrategy +from openadapter.providers.omniparser import OmniParserProvider + + +# Pydantic models for structured data +class RecognitionCriterion(BaseModel): + """Criteria for recognizing a state""" + type: Literal["window_title", "ui_element_present", "visual_template"] + pattern: Optional[str] = None + threshold: Optional[float] = None + element_descriptor: Optional[str] = None + + +class ActionParameter(BaseModel): + """Parameters for an action""" + target_element: Optional[str] = None + text_input: Optional[str] = None + click_type: Optional[Literal["single", "double", "right"]] = None + coordinate_type: Optional[Literal["absolute", "relative"]] = None + + +class ActionModel(BaseModel): + """Model for an action in the process""" + name: str + description: str + parameters: ActionParameter + + +class Condition(BaseModel): + """Condition for a transition""" + type: Literal["element_state", "data_value", "previous_action"] + description: str + + +class Transition(BaseModel): + """Transition between states""" + from_state: str = Field(..., alias="from") + to_state: str = Field(..., alias="to") + action: ActionModel + condition: Optional[Condition] = None + + +class Branch(BaseModel): + """Branch in a decision point""" + condition: str + next_state: str + + +class DecisionPoint(BaseModel): + """Decision point in the process""" + state: str + description: str + branches: List[Branch] + + +class Loop(BaseModel): + """Loop in the process""" + start_state: str + end_state: str + exit_condition: str + description: str + + +class StateModel(BaseModel): + """Model for a state in the process""" + name: str + description: str + recognition_criteria: List[RecognitionCriterion] + + +class ProcessAnalysis(BaseModel): + """Complete model of a process""" + process_name: str + description: str + states: List[StateModel] + transitions: List[Transition] + loops: List[Loop] + decision_points: List[DecisionPoint] + + +class StateTrajectoryEntry(BaseModel): + """Entry in the state trajectory""" + state_name: Optional[str] = None + action_name: Optional[str] = None + timestamp: float + + +class CurrentStateMatch(BaseModel): + """Result of matching current state to graph""" + matched_state_name: str + confidence: float + reasoning: str + + +class UIElement(BaseModel): + """UI element in the visual state""" + type: str + text: Optional[str] = None + bounds: Dict[str, int] + description: str + is_interactive: bool + + +class VisualState(BaseModel): + """Visual state representation""" + window_title: str + ui_elements: List[UIElement] + screenshot_timestamp: float + + +class AbstractState: + """Represents an abstract state in the process graph with recognition logic.""" + + def __init__(self, name, description, recognition_criteria): + self.id = str(uuid.uuid4()) + self.name = name + self.description = description + self.recognition_criteria = recognition_criteria + self.example_screenshots = [] + + def match_rules(self, current_state, trajectory=None): + """Apply rule-based matching using recognition criteria.""" + for criterion in self.recognition_criteria: + if not self._evaluate_criterion(criterion, current_state): + return False + return True + + def _evaluate_criterion(self, criterion, state): + """Evaluate a single recognition criterion against current state.""" + criterion_type = criterion["type"] + + if criterion_type == "window_title": + if not state.window_event or not state.window_event.title: + return False + return criterion["pattern"] in state.window_event.title + + elif criterion_type == "ui_element_present": + if not state.visual_data: + return False + return any( + criterion["element_descriptor"] in element["description"] + for element in state.visual_data + ) + + elif criterion_type == "visual_template": + # Match against example screenshots + if not self.example_screenshots: + return False + return any( + vision.get_image_similarity(state.screenshot.image, example)[0] > criterion.get("threshold", 0.8) + for example in self.example_screenshots + ) + + return False + + def add_example(self, screenshot): + """Add example screenshot for visual matching.""" + self.example_screenshots.append(screenshot) + + def to_dict(self): + """Convert to dictionary representation.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "recognition_criteria": self.recognition_criteria + } + + +class AbstractAction: + """Represents an abstract action with parameters to be instantiated.""" + + def __init__(self, name, description, parameters): + self.id = str(uuid.uuid4()) + self.name = name + self.description = description + self.parameters = parameters + + def to_dict(self): + """Convert to dictionary representation.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "parameters": self.parameters + } + + +class ProcessGraph: + """Enhanced process graph with abstract states and conditional transitions.""" + + def __init__(self): + self.nodes = set() + self.edges = [] + self.conditions = {} # Maps (from_state, action, to_state) to condition logic + self.description = "" + + def add_node(self, node): + """Add a node to the graph.""" + self.nodes.add(node) + + def add_edge(self, from_state, action, to_state): + """Add an edge to the graph.""" + self.add_node(from_state) + self.add_node(action) + self.add_node(to_state) + self.edges.append((from_state, action, to_state)) + + def add_condition(self, from_state, action, to_state, condition): + """Add a condition to an edge.""" + key = (from_state.id, action.id, to_state.id) + self.conditions[key] = condition + + def get_abstract_states(self): + """Get all abstract states in the graph.""" + return [node for node in self.nodes if isinstance(node, AbstractState)] + + def get_state_by_name(self, name): + """Find a state by name.""" + for node in self.nodes: + if isinstance(node, AbstractState) and node.name == name: + return node + return None + + def get_possible_actions(self, state): + """Get possible actions from a state, considering conditions.""" + possible_actions = [] + + for from_state, action, to_state in self.edges: + if from_state.id == state.id: + key = (from_state.id, action.id, to_state.id) + if key in self.conditions: + # For now, we include conditional actions + # In a full implementation, would need to evaluate conditions + possible_actions.append((action, to_state)) + else: + possible_actions.append((action, to_state)) + + return possible_actions + + def set_description(self, description): + """Set the overall description of the process.""" + self.description = description + + def get_description(self): + """Get the overall description of the process.""" + return self.description + + def to_model(self) -> ProcessAnalysis: + """Convert graph to Pydantic model for serialization.""" + states = [ + StateModel( + name=state.name, + description=state.description, + recognition_criteria=[ + RecognitionCriterion(**criterion) + for criterion in state.recognition_criteria + ] + ) + for state in self.get_abstract_states() + ] + + transitions = [] + for from_state, action, to_state in self.edges: + if isinstance(from_state, AbstractState) and isinstance(to_state, AbstractState): + key = (from_state.id, action.id, to_state.id) + transition = Transition( + from_state=from_state.name, + to_state=to_state.name, + action=ActionModel( + name=action.name, + description=action.description, + parameters=ActionParameter(**action.parameters) + ) + ) + if key in self.conditions: + transition.condition = Condition(**self.conditions[key]) + transitions.append(transition) + + # Build loops and decision points using a simple algorithm + loops = self._detect_loops() + decision_points = self._detect_decision_points() + + return ProcessAnalysis( + process_name=self.description.split("\n")[0] if self.description else "Unnamed Process", + description=self.description, + states=states, + transitions=transitions, + loops=loops, + decision_points=decision_points + ) + + def _detect_loops(self) -> List[Loop]: + """Simple loop detection algorithm.""" + loops = [] + # Map state names to IDs for easier lookup + state_id_to_name = {state.id: state.name for state in self.get_abstract_states()} + + # Find cycles in the graph using DFS + visited = set() + path = [] + + def dfs(node_id): + if node_id in path: + # Found a cycle + cycle_start = path.index(node_id) + cycle = path[cycle_start:] + # Only process if the cycle involves states (not just actions) + state_ids = [node_id for node_id in cycle if node_id in state_id_to_name] + if len(state_ids) > 1: + loops.append(Loop( + start_state=state_id_to_name[state_ids[0]], + end_state=state_id_to_name[state_ids[-1]], + exit_condition="Condition to exit loop", + description=f"Loop from {state_id_to_name[state_ids[0]]} to {state_id_to_name[state_ids[-1]]}" + )) + return + + if node_id in visited: + return + + visited.add(node_id) + path.append(node_id) + + # Find all outgoing edges + for from_state, _, to_state in self.edges: + if from_state.id == node_id: + dfs(to_state.id) + + path.pop() + + # Start DFS from each state + for state in self.get_abstract_states(): + dfs(state.id) + + return loops + + def _detect_decision_points(self) -> List[DecisionPoint]: + """Detect states with multiple outgoing transitions.""" + decision_points = [] + state_id_to_name = {state.id: state.name for state in self.get_abstract_states()} + + # Count outgoing edges for each state + outgoing_counts = {} + for from_state, _, to_state in self.edges: + if from_state.id not in outgoing_counts: + outgoing_counts[from_state.id] = [] + outgoing_counts[from_state.id].append(to_state.id) + + # States with multiple outgoing edges are decision points + for state_id, destinations in outgoing_counts.items(): + if state_id in state_id_to_name and len(destinations) > 1: + branches = [] + for dest_id in destinations: + if dest_id in state_id_to_name: + branches.append(Branch( + condition=f"Condition to go to {state_id_to_name[dest_id]}", + next_state=state_id_to_name[dest_id] + )) + + if branches: + decision_points.append(DecisionPoint( + state=state_id_to_name[state_id], + description=f"Decision point at {state_id_to_name[state_id]}", + branches=branches + )) + + return decision_points + + def to_json(self): + """Convert graph to JSON string.""" + return self.to_model().model_dump_json(indent=2) + + def update_with_observation(self, observed_state, previous_state, latest_action): + """Update graph with observed state during execution.""" + # Find abstract states that match the observed state + similar_state = None + highest_similarity = 0.0 + + for state in self.get_abstract_states(): + similarity = self._calculate_state_similarity(observed_state, state) + if similarity > highest_similarity: + highest_similarity = similarity + similar_state = state + + # Create a new state if no good match + if highest_similarity < 0.7: + similar_state = self._create_new_state_from_observation(observed_state) + + # If we have a previous state and action, create or update a transition + if previous_state and latest_action: + # Check if transition already exists + transition_exists = False + for from_state, action, to_state in self.edges: + if (from_state.id == previous_state.id and + action.name == latest_action.name and + to_state.id == similar_state.id): + transition_exists = True + break + + if not transition_exists: + # Create a new abstract action from the latest action + action = AbstractAction( + name=latest_action.name, + description=f"Action {latest_action.name}", + parameters=self._extract_action_parameters(latest_action) + ) + + # Add the edge + self.add_edge(previous_state, action, similar_state) + + return similar_state + + def _calculate_state_similarity(self, observed_state, abstract_state): + """Calculate similarity between observed state and abstract state.""" + # Use rule-based matching first + if abstract_state.match_rules(observed_state): + return 0.9 # High confidence if rules match + + # Fall back to visual similarity if we have example screenshots + if abstract_state.example_screenshots and observed_state.screenshot: + visual_similarities = [ + vision.get_image_similarity(observed_state.screenshot.image, example)[0] + for example in abstract_state.example_screenshots + ] + return max(visual_similarities) if visual_similarities else 0.0 + + return 0.0 + + def _create_new_state_from_observation(self, observed_state): + """Create a new abstract state from an observed state.""" + # Generate a name for the state based on window title + name = "State_" + str(len(self.get_abstract_states()) + 1) + if observed_state.window_event and observed_state.window_event.title: + name = f"State_{observed_state.window_event.title[:20]}" + + # Create recognition criteria + criteria = [] + if observed_state.window_event and observed_state.window_event.title: + criteria.append({ + "type": "window_title", + "pattern": observed_state.window_event.title + }) + + if observed_state.visual_data: + # Add criteria based on visible UI elements + for element in observed_state.visual_data[:3]: # Limit to a few key elements + if element.get("description"): + criteria.append({ + "type": "ui_element_present", + "element_descriptor": element["description"] + }) + + state = AbstractState( + name=name, + description=f"State with window title: {observed_state.window_event.title if observed_state.window_event else 'Unknown'}", + recognition_criteria=criteria + ) + + # Add screenshot as example for visual matching + if observed_state.screenshot: + state.add_example(observed_state.screenshot.image) + + self.add_node(state) + return state + + def _extract_action_parameters(self, action_event): + """Extract parameters from an action event.""" + parameters = {} + + if action_event.name in common.MOUSE_EVENTS: + parameters["target_element"] = action_event.active_segment_description + if "click" in action_event.name: + if "double" in action_event.name: + parameters["click_type"] = "double" + elif "right" in action_event.name: + parameters["click_type"] = "right" + else: + parameters["click_type"] = "single" + + elif action_event.name in common.KEYBOARD_EVENTS: + if action_event.text: + parameters["text_input"] = action_event.text + + return parameters + + +class State: + """Represents a concrete state during execution.""" + + def __init__(self, screenshot, window_event, browser_event=None, visual_data=None): + self.id = str(uuid.uuid4()) + self.screenshot = screenshot + self.window_event = window_event + self.browser_event = browser_event + self.visual_data = visual_data or [] + + +class ProcessGraphStrategy(BaseReplayStrategy): + """Strategy using process graphs, OmniParser and Gemini 2.0 Flash.""" + + def __init__( + self, + task_description: str, + recording_id: int = None, + ) -> None: + """Initialize with task description rather than recording ID.""" + # Find best matching recording if not provided + if not recording_id: + recording_id = self._find_matching_recording(task_description) + + db_session = crud.get_new_session() + self.recording = crud.get_recording(db_session, recording_id) + super().__init__(self.recording) + + self.task_description = task_description + + # Initialize OmniParser service + self.omniparser_provider = OmniParserProvider() + self._ensure_omniparser_running() + + # Initialize tracking + self.state_action_history = [] # List of (state, action) pairs + self.action_history = [] + self.current_state = None + self.current_abstract_state = None + + # Build graph before replay + self.process_graph = self._build_generalizable_process_graph(task_description) + + def _ensure_omniparser_running(self): + """Ensure OmniParser is running, deploying if necessary.""" + status = self.omniparser_provider.status() + if not status['services']: + logger.info("Deploying OmniParser...") + self.omniparser_provider.deploy() + self.omniparser_provider.stack.create_service() + + def _find_matching_recording(self, task_description: str) -> int: + """Find recording with most similar task description using vector similarity.""" + db_session = crud.get_new_session() + recordings = crud.get_all_recordings(db_session) + best_match = None + highest_similarity = -1 + + for recording in recordings: + if not recording.task_description: + continue + + similarity = self._calculate_text_similarity(task_description, recording.task_description) + if similarity > highest_similarity: + highest_similarity = similarity + best_match = recording.id + + if best_match is None: + # If no good match, use the most recent recording + recordings_sorted = sorted(recordings, key=lambda r: r.timestamp, reverse=True) + if recordings_sorted: + best_match = recordings_sorted[0].id + else: + raise ValueError("No recordings found in the database.") + + return best_match + + def _calculate_text_similarity(self, text1, text2): + """Calculate similarity between two text strings.""" + # Simple word overlap similarity + if not text1 or not text2: + return 0.0 + + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = words1.intersection(words2) + union = words1.union(words2) + + return len(intersection) / len(union) + + def _build_generalizable_process_graph(self, task_description): + """Build a generalizable process graph using multi-phase approach with MMMs.""" + # Get coalesced actions + processed_actions = self.recording.processed_action_events + + # Phase 1: Process Understanding - Analyze the entire workflow + process_model = self._analyze_entire_process(processed_actions, task_description) + + # Phase 2: Graph Construction - Build abstract graph from understanding + initial_graph = self._construct_abstract_graph(process_model) + + # Phase 3: Graph Validation - Test and refine by walking through recording + refined_graph = self._validate_and_refine_graph(initial_graph, processed_actions) + + return refined_graph + + def _select_representative_screenshots(self, action_events, max_images=10): + """Select representative screenshots from the action events.""" + if not action_events: + return [] + + # If few actions, use all screenshots + if len(action_events) <= max_images: + return [action.screenshot.image for action in action_events if action.screenshot] + + # Otherwise, select evenly spaced screenshots + step = len(action_events) // max_images + selected_actions = action_events[::step] + + # Add the last action if not included + if action_events[-1] not in selected_actions: + selected_actions.append(action_events[-1]) + + return [action.screenshot.image for action in selected_actions if action.screenshot] + + def _analyze_entire_process(self, actions, task_description): + """Have Gemini analyze the entire recording to understand the process structure.""" + key_screenshots = self._select_representative_screenshots(actions) + + # Generate schema JSON for the prompt + schema_json = ProcessAnalysis.model_json_schema() + + system_prompt = "You are an expert in understanding user interface workflows." + prompt = f""" + Analyze this UI automation sequence and identify: + 1. The high-level steps in the process + 2. Any repetitive patterns or loops + 3. Decision points where the workflow might branch + 4. The semantic meaning of each major state + + Task description: {task_description} + + RESPOND USING THE FOLLOWING JSON SCHEMA: + ```json + {json.dumps(schema_json, indent=2)} + ``` + + Your response must strictly follow this schema and be valid JSON. + """ + + process_analysis_text = self.prompt_gemini(prompt, system_prompt, key_screenshots) + + # Use json_repair for robust parsing + try: + # Direct parsing if possible + process_data = repair_loads(process_analysis_text) + process_model = ProcessAnalysis(**process_data) + return process_model + except Exception as e: + logger.warning(f"Initial JSON parsing failed: {e}") + + # Try to repair potentially broken JSON + try: + repaired_json = repair_json(process_analysis_text, ensure_ascii=False) + process_data = json.loads(repaired_json) + process_model = ProcessAnalysis(**process_data) + return process_model + except Exception as repair_e: + logger.error(f"JSON repair also failed: {repair_e}") + + # Last resort: try direct object return + try: + process_data = repair_json(process_analysis_text, return_objects=True) + process_model = ProcessAnalysis(**process_data) + return process_model + except Exception as final_e: + logger.error(f"All JSON parsing methods failed: {final_e}") + return self._fallback_process_analysis(actions, task_description) + + def _fallback_process_analysis(self, actions, task_description): + """Create a simple process model if all parsing fails.""" + logger.warning("Using fallback process analysis") + + # Create a simple linear process model + states = [] + transitions = [] + + # Create a state for each key action + key_actions = actions[::max(1, len(actions) // 5)] # At most 5 states + + for i, action in enumerate(key_actions): + state_name = f"State_{i+1}" + state_description = f"State after {action.name} action" + + # Create recognition criteria + criteria = [] + if action.window_event and action.window_event.title: + criteria.append({ + "type": "window_title", + "pattern": action.window_event.title + }) + + states.append(StateModel( + name=state_name, + description=state_description, + recognition_criteria=criteria + )) + + # Create transition to next state + if i < len(key_actions) - 1: + next_action = key_actions[i+1] + transitions.append(Transition( + from_state=state_name, + to_state=f"State_{i+2}", + action=ActionModel( + name=next_action.name, + description=f"{next_action.name} action", + parameters=ActionParameter() + ) + )) + + return ProcessAnalysis( + process_name="Fallback Process", + description=f"Fallback process for task: {task_description}", + states=states, + transitions=transitions, + loops=[], + decision_points=[] + ) + + def _construct_abstract_graph(self, process_model): + """Construct an abstract process graph based on the process understanding.""" + graph = ProcessGraph() + graph.set_description(process_model.description) + + # Create abstract state definitions based on process model + for state_def in process_model.states: + state = AbstractState( + name=state_def.name, + description=state_def.description, + recognition_criteria=[criterion.model_dump() for criterion in state_def.recognition_criteria] + ) + graph.add_node(state) + + # Create transitions with abstract actions + for transition in process_model.transitions: + from_state = graph.get_state_by_name(transition.from_state) + to_state = graph.get_state_by_name(transition.to_state) + + if from_state and to_state: + action = AbstractAction( + name=transition.action.name, + description=transition.action.description, + parameters=transition.action.parameters.model_dump() + ) + graph.add_edge(from_state, action, to_state) + + # Add conditional branches if present + if transition.condition: + graph.add_condition(from_state, action, to_state, transition.condition.model_dump()) + + return graph + + def _validate_and_refine_graph(self, graph, actions): + """Test the graph against recorded actions and refine it with Gemini's help.""" + # Simulate walking through the recording using the graph + simulation_results = self._simulate_graph_execution(graph, actions) + + if simulation_results["success"]: + return graph + + # If simulation failed, ask Gemini to refine the graph + system_prompt = "You are an expert in refining process models." + prompt = f""" + The process graph failed to match the recording at these points: + {simulation_results["failures"]} + + Current graph: {graph.to_json()} + + Please refine the graph to better match the recorded process while + maintaining generalizability. Consider: + 1. Adding missing states or transitions + 2. Adjusting state recognition criteria + 3. Modifying action parameters + 4. Adding conditional logic + + RESPOND USING THE SAME JSON SCHEMA AS THE CURRENT GRAPH. + """ + + refinements_text = self.prompt_gemini(prompt, system_prompt, simulation_results["screenshots"]) + + try: + refinements_data = repair_loads(refinements_text) + refined_model = ProcessAnalysis(**refinements_data) + refined_graph = self._construct_abstract_graph(refined_model) + + # Check if refinement improved the simulation + new_failures = len(self._simulate_graph_execution(refined_graph, actions)["failures"]) + old_failures = len(simulation_results["failures"]) + + if new_failures < old_failures: + return refined_graph + return graph + + except Exception as e: + logger.error(f"Failed to parse graph refinements: {e}") + return graph + + def _simulate_graph_execution(self, graph, actions): + """Simulate executing the graph with the recorded actions.""" + failures = [] + screenshots = [] + current_state = None + + for i, action in enumerate(actions): + # If first action, find initial state + if i == 0: + state = State(action.screenshot, action.window_event, action.browser_event) + matched_state = None + highest_similarity = 0.0 + + for abstract_state in graph.get_abstract_states(): + similarity = graph._calculate_state_similarity(state, abstract_state) + if similarity > highest_similarity: + highest_similarity = similarity + matched_state = abstract_state + + if highest_similarity < 0.7: + failures.append(f"Failed to match initial state at action {i}") + screenshots.append(action.screenshot.image) + + current_state = matched_state + continue + + # For subsequent actions, check if the graph has a transition + if current_state: + possible_actions = graph.get_possible_actions(current_state) + + # Check if any action matches the recorded action + action_match = False + for graph_action, next_state in possible_actions: + if graph_action.name == action.name: + action_match = True + current_state = next_state + break + + if not action_match: + failures.append(f"No matching action '{action.name}' from state '{current_state.name}' at action {i}") + screenshots.append(action.screenshot.image) + + return { + "success": len(failures) == 0, + "failures": failures, + "screenshots": screenshots[:5] # Limit to 5 screenshots for prompt size + } + + def get_next_action_event( + self, + screenshot: models.Screenshot, + window_event: models.WindowEvent, + ) -> models.ActionEvent: + """Determine next action using the process graph and runtime adaptation.""" + # Create current state representation + current_state = State( + screenshot=screenshot, + window_event=window_event + ) + + # Parse visual state with OmniParser + visual_data = self._parse_state_with_omniparser(screenshot.image) + current_state.visual_data = visual_data + + # Update graph with actual observed state + previous_abstract_state = self.current_abstract_state + latest_action = self.action_history[-1] if self.action_history else None + + self.current_abstract_state = self.process_graph.update_with_observation( + current_state, + previous_abstract_state, + latest_action + ) + + self.current_state = current_state + + # Find possible next actions in graph + possible_actions = self.process_graph.get_possible_actions(self.current_abstract_state) + + if not possible_actions: + # No actions available - either reached end state or unexpected state + if len(self.action_history) > 0: + # We've taken at least one action, so this might be the end + raise StopIteration("No further actions available in the process graph") + else: + # No actions taken yet - generate one with Gemini + next_action = self._generate_action_with_gemini() + self.action_history.append(next_action) + return next_action + + if len(possible_actions) == 1: + # Single clear action to take + action, next_state = possible_actions[0] + next_action = self._instantiate_abstract_action(action, current_state) + else: + # Multiple possible actions - use Gemini to decide + next_action = self._decide_between_actions(possible_actions, current_state) + + self.state_action_history.append((self.current_abstract_state, next_action)) + self.action_history.append(next_action) + + return next_action + + def _parse_state_with_omniparser(self, screenshot_image): + """Use OmniParser to parse the visual state.""" + try: + # Convert PIL Image to bytes + import io + img_byte_arr = io.BytesIO() + screenshot_image.save(img_byte_arr, format='PNG') + img_bytes = img_byte_arr.getvalue() + + # Call OmniParser API + result = self.omniparser_provider.parse_screenshot(img_bytes) + + # Transform the result into our expected format + ui_elements = [] + for element in result.get("elements", []): + ui_elements.append({ + "type": element.get("type", "unknown"), + "text": element.get("text", ""), + "bounds": element.get("bounds", {"x": 0, "y": 0, "width": 0, "height": 0}), + "description": element.get("description", ""), + "is_interactive": element.get("is_interactive", False) + }) + + return ui_elements + + except Exception as e: + logger.error(f"Error parsing state with OmniParser: {e}") + return [] + + def _instantiate_abstract_action(self, abstract_action, current_state): + """Convert abstract action to concrete ActionEvent based on current state.""" + try: + # Use parameters from abstract action if possible + params = abstract_action.parameters + + if abstract_action.name in common.MOUSE_EVENTS: + # Create a mouse action + action_event = models.ActionEvent( + name=abstract_action.name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording + ) + + # If we have a target element, find its coordinates + if params.get("target_element"): + target_element = None + for element in current_state.visual_data: + if params["target_element"] in element.get("description", ""): + target_element = element + break + + if target_element: + bounds = target_element.get("bounds", {}) + # Calculate center of element + center_x = bounds.get("x", 0) + bounds.get("width", 0) / 2 + center_y = bounds.get("y", 0) + bounds.get("height", 0) / 2 + + action_event.mouse_x = center_x + action_event.mouse_y = center_y + action_event.active_segment_description = params["target_element"] + else: + # If target not found, use Gemini to identify coordinates + action_event = self._locate_target_with_gemini( + params["target_element"], + abstract_action.name, + current_state + ) + else: + # Use Gemini to decide where to click + action_event = self._locate_target_with_gemini( + None, + abstract_action.name, + current_state + ) + + return action_event + + elif abstract_action.name in common.KEYBOARD_EVENTS: + # Create a keyboard action + action_event = models.ActionEvent( + name=abstract_action.name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording + ) + + if params.get("text_input"): + # For "type" action, convert to actual keypresses + action_event = models.ActionEvent.from_dict({ + "name": "type", + "text": params["text_input"] + }) + action_event.screenshot = current_state.screenshot + action_event.window_event = current_state.window_event + action_event.recording = self.recording + + return action_event + + else: + # For other actions, use Gemini + return self._generate_action_with_gemini(abstract_action.name) + + except Exception as e: + logger.error(f"Error instantiating action: {e}") + return self._generate_action_with_gemini() + + def _locate_target_with_gemini(self, target_description, action_name, current_state): + """Use Gemini to locate a target on the screen.""" + system_prompt = "You are an expert in UI automation and element identification." + prompt = f""" + Identify the coordinates to perform a {action_name} action. + + {f'The target is described as: {target_description}' if target_description else 'Find the most appropriate element to interact with based on the current state.'} + + Analyze the screenshot and provide the x,y coordinates where the action should be performed. + Respond with a JSON object containing: + 1. x: the x-coordinate (number) + 2. y: the y-coordinate (number) + 3. description: brief description of what element is at these coordinates + """ + + result_text = self.prompt_gemini(prompt, system_prompt, [current_state.screenshot.image]) + + try: + # Parse the response + coord_data = repair_loads(result_text) + + # Create action event + action_event = models.ActionEvent( + name=action_name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording, + mouse_x=coord_data.get("x", 0), + mouse_y=coord_data.get("y", 0), + active_segment_description=coord_data.get("description", "") + ) + + return action_event + + except Exception as e: + logger.error(f"Error parsing coordinates: {e}") + + # Fallback: use center of screen + window_width = current_state.window_event.width if current_state.window_event else 800 + window_height = current_state.window_event.height if current_state.window_event else 600 + + return models.ActionEvent( + name=action_name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording, + mouse_x=window_width / 2, + mouse_y=window_height / 2, + active_segment_description="Center of screen (fallback)" + ) + + def _decide_between_actions(self, possible_actions, current_state): + """Use Gemini to decide between multiple possible actions.""" + system_prompt = "You are an expert in UI automation decision making." + + actions_list = [] + for i, (action, next_state) in enumerate(possible_actions): + actions_list.append({ + "id": i, + "name": action.name, + "description": action.description, + "parameters": action.parameters, + "next_state": next_state.name, + "next_state_description": next_state.description + }) + + prompt = f""" + Decide which action to take next based on the current state and task description. + + Task description: {self.task_description} + Current state: {self.current_abstract_state.description if self.current_abstract_state else "Initial state"} + + Possible actions: + {json.dumps(actions_list, indent=2)} + + Respond with a JSON object containing: + 1. chosen_action_id: the ID of the chosen action (number) + 2. reasoning: brief explanation for your choice + """ + + result_text = self.prompt_gemini(prompt, system_prompt, [current_state.screenshot.image]) + + try: + result = repair_loads(result_text) + chosen_id = result.get("chosen_action_id", 0) + chosen_id = min(chosen_id, len(possible_actions) - 1) # Ensure valid index + + action, next_state = possible_actions[chosen_id] + return self._instantiate_abstract_action(action, current_state) + + except Exception as e: + logger.error(f"Error deciding between actions: {e}") + # Default to first action + action, next_state = possible_actions[0] + return self._instantiate_abstract_action(action, current_state) + + def _generate_action_with_gemini(self, suggested_action_name=None): + """Generate action with Gemini if graph doesn't provide one.""" + system_prompt = "You are an expert in UI automation." + + trajectory = [] + for i, (state, action) in enumerate(self.state_action_history[-5:]): + trajectory.append({ + "step": i + 1, + "state": state.name if state else "Unknown", + "action": action.name if action else "None" + }) + + prompt = f""" + Generate the next action to perform based on: + + Task description: {self.task_description} + Recent trajectory: {json.dumps(trajectory, indent=2)} + {f'Suggested action type: {suggested_action_name}' if suggested_action_name else ''} + + Analyze the screenshot and respond with a JSON object for the next ActionEvent: + {{ + "name": "click|move|scroll|type", + "mouse_x": number, + "mouse_y": number, + "text": "text to type (for keyboard actions)", + "active_segment_description": "description of what's being clicked" + }} + + Only include relevant fields based on the action type. + """ + + result_text = self.prompt_gemini( + prompt, + system_prompt, + [self.current_state.screenshot.image] if self.current_state else [] + ) + + try: + action_dict = repair_loads(result_text) + action = models.ActionEvent.from_dict(action_dict) + + # Add missing context + action.screenshot = self.current_state.screenshot if self.current_state else None + action.window_event = self.current_state.window_event if self.current_state else None + action.recording = self.recording + + return action + + except Exception as e: + logger.error(f"Error generating action: {e}") + + # Create a fallback action - simple click in the center + window_width = self.current_state.window_event.width if self.current_state and self.current_state.window_event else 800 + window_height = self.current_state.window_event.height if self.current_state and self.current_state.window_event else 600 + + return models.ActionEvent( + name="click", + screenshot=self.current_state.screenshot if self.current_state else None, + window_event=self.current_state.window_event if self.current_state else None, + recording=self.recording, + mouse_x=window_width / 2, + mouse_y=window_height / 2, + mouse_button_name="left", + active_segment_description="Center of screen (fallback)" + ) + + def prompt_gemini(self, prompt, system_prompt, images): + """Helper method to prompt Gemini with images.""" + from openadapt.drivers import google + return google.prompt( + prompt, + system_prompt=system_prompt, + images=images, + model_name="models/gemini-1.5-pro-latest" + ) + + def __del__(self): + """Clean up OmniParser service when done.""" + try: + self.omniparser_provider.stack.stop_service() + except: + pass \ No newline at end of file