diff --git a/.gitignore b/.gitignore index 8e3515c..4b6991f 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,7 @@ run_experiments/ generated/ runs/ testst/ +test/ # System files .DS_Store @@ -175,7 +176,6 @@ flask_test.log # Temporary files *.tmp *.temp - # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore diff --git a/config.template.toml b/config.template.toml index 8af5589..4c035cb 100644 --- a/config.template.toml +++ b/config.template.toml @@ -29,3 +29,21 @@ llm_api_key = "" # Budget allocation preference for pipeline modules: balanced, write-heavy, # think-heavy, or review-heavy #budget_preference = "balanced" + +#################################### MCP #################################### +[mcp.servers] + +[mcp.servers.code_search] +command = "python" +args = ["-m", "tiny_scientist.mcp.code_search_server"] +cwd = "." + +[mcp.servers.paper_search] +command = "python" +args = ["-m", "tiny_scientist.mcp.paper_search_server"] +cwd = "." + +[mcp.servers.drawer] +command = "python" +args = ["-m", "tiny_scientist.mcp.drawer_server"] +cwd = "." diff --git a/poetry.lock b/poetry.lock index efb661c..5288e0b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,8 @@ +<<<<<<< HEAD # This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +======= +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +>>>>>>> origin/mcp-2.0 [[package]] name = "aider-chat" @@ -363,7 +367,7 @@ description = "Timeout context manager for asyncio programs" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, @@ -1065,7 +1069,7 @@ description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" groups = ["main", "dev", "test"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -1104,6 +1108,30 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fastmcp" +version = "1.0" +description = "A more ergonomic interface for MCP servers" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "fastmcp-1.0-py3-none-any.whl", hash = "sha256:88f0c5acc2af06f22cf46dd26c1a1c4c54f1479ef09e5f871fdfbade6defe3a6"}, + {file = "fastmcp-1.0.tar.gz", hash = "sha256:202f454e82cb68460a2b7372f975901e78e03b27734ce3a16c4d1d3e3cdbc519"}, +] + +[package.dependencies] +httpx = ">=0.26.0" +mcp = ">=1.0.0,<2.0.0" +pydantic = ">=2.5.3,<3.0.0" +pydantic-settings = ">=2.6.1" +python-dotenv = ">=1.0.1" +typer = ">=0.9.0" + +[package.extras] +dev = ["copychat (>=0.5.2)", "ipython (>=8.12.3)", "pdbpp (>=0.10.3)", "pre-commit", "pyright (>=1.1.389)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.23.5)", "pytest-flakefinder", "pytest-xdist (>=3.6.1)", "ruff"] +tests = ["pre-commit", "pyright (>=1.1.389)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.23.5)", "pytest-flakefinder", "pytest-xdist (>=3.6.1)", "ruff"] + [[package]] name = "filelock" version = "3.18.0" @@ -1712,6 +1740,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.1" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37"}, + {file = "httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e"}, +] + [[package]] name = "huggingface-hub" version = "0.30.2" @@ -1906,6 +1946,22 @@ qtconsole = ["qtconsole"] test = ["pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.22)", "pandas", "pickleshare", "pytest (<7.1)", "pytest-asyncio (<0.22)", "testpath", "trio"] +[[package]] +name = "isort" +version = "6.0.1" +description = "A Python utility / library to sort Python imports." +optional = false +python-versions = ">=3.9.0" +groups = ["dev"] +files = [ + {file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"}, + {file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"}, +] + +[package.extras] +colors = ["colorama"] +plugins = ["setuptools"] + [[package]] name = "itsdangerous" version = "2.2.0" @@ -2421,6 +2477,35 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.10.1" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.10.1-py3-none-any.whl", hash = "sha256:4d08301aefe906dce0fa482289db55ce1db831e3e67212e65b5e23ad8454b3c5"}, + {file = "mcp-1.10.1.tar.gz", hash = "sha256:aaa0957d8307feeff180da2d9d359f2b801f35c0c67f1882136239055ef034c2"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = ">=2.7.2,<3.0.0" +pydantic-settings = ">=2.5.2" +python-multipart = ">=0.0.9" +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +uvicorn = {version = ">=0.23.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.12.4)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -3572,6 +3657,30 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pydantic-settings" +version = "2.10.1" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pydantic_settings-2.10.1-py3-none-any.whl", hash = "sha256:a60952460b99cf661dc25c29c0ef171721f98bfcb52ef8d9ea4c943d7c8cc796"}, + {file = "pydantic_settings-2.10.1.tar.gz", hash = "sha256:06f0062169818d0f5524420a360d632d5857b83cffd4d42fe29597807a1614ee"}, +] + +[package.dependencies] +pydantic = ">=2.7.0" +python-dotenv = ">=0.21.0" +typing-inspection = ">=0.4.0" + +[package.extras] +aws-secrets-manager = ["boto3 (>=1.35.0)", "boto3-stubs[secretsmanager]"] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] +gcp-secret-manager = ["google-cloud-secret-manager (>=2.23.1)"] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pydub" version = "0.25.1" @@ -3780,6 +3889,7 @@ files = [ cli = ["click (>=5.0)"] [[package]] +<<<<<<< HEAD name = "python-engineio" version = "4.12.2" description = "Engine.IO server and client for Python" @@ -3803,10 +3913,16 @@ docs = ["sphinx"] name = "python-socketio" version = "5.13.0" description = "Socket.IO server and client for Python" +======= +name = "python-multipart" +version = "0.0.20" +description = "A streaming multipart parser for Python" +>>>>>>> origin/mcp-2.0 optional = false python-versions = ">=3.8" groups = ["main"] files = [ +<<<<<<< HEAD {file = "python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf"}, {file = "python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029"}, ] @@ -3820,6 +3936,12 @@ asyncio-client = ["aiohttp (>=3.4)"] client = ["requests (>=2.21.0)", "websocket-client (>=0.54.0)"] docs = ["sphinx"] +======= + {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, + {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, +] + +>>>>>>> origin/mcp-2.0 [[package]] name = "pywin32" version = "310" @@ -4341,6 +4463,34 @@ files = [ [package.dependencies] pyasn1 = ">=0.1.3" +[[package]] +name = "ruff" +version = "0.12.2" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be"}, + {file = "ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e"}, + {file = "ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9"}, + {file = "ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04"}, + {file = "ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342"}, + {file = "ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a"}, + {file = "ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639"}, + {file = "ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12"}, + {file = "ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e"}, +] + [[package]] name = "scipy" version = "1.13.1" @@ -4733,6 +4883,7 @@ files = [ catalogue = ">=2.0.3,<2.1.0" [[package]] +<<<<<<< HEAD name = "sseclient-py" version = "1.7.2" description = "SSE client for Python" @@ -4744,6 +4895,28 @@ files = [ {file = "sseclient_py-1.7.2-py2.py3-none-any.whl", hash = "sha256:a758653b13b78df42cdb696740635a26cb72ad433b75efb68dbbb163d099b6a9"}, ] +======= +name = "sse-starlette" +version = "2.3.6" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-2.3.6-py3-none-any.whl", hash = "sha256:d49a8285b182f6e2228e2609c350398b2ca2c36216c2675d875f81e93548f760"}, + {file = "sse_starlette-2.3.6.tar.gz", hash = "sha256:0382336f7d4ec30160cf9ca0518962905e1b69b72d6c1c995131e0a703b436e3"}, +] + +[package.dependencies] +anyio = ">=4.7.0" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio,examples] (>=2.0.41)", "starlette (>=0.41.3)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + +>>>>>>> origin/mcp-2.0 [[package]] name = "stack-data" version = "0.6.3" @@ -4765,6 +4938,43 @@ pure-eval = "*" tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] +<<<<<<< HEAD +======= +name = "starlette" +version = "0.47.1" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527"}, + {file = "starlette-0.47.1.tar.gz", hash = "sha256:aef012dd2b6be325ffa16698f9dc533614fb1cebd593a906b90dc1025529a79b"}, +] + +[package.dependencies] +anyio = ">=3.6.2,<5" +typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} + +[package.extras] +full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] + +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + +[[package]] +>>>>>>> origin/mcp-2.0 name = "thinc" version = "8.3.4" description = "A refreshing functional take on deep learning, compatible with your favorite libraries" @@ -4975,7 +5185,7 @@ description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["main", "dev", "test"] -markers = "python_version < \"3.11\"" +markers = "python_version == \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -5381,6 +5591,27 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvicorn" +version = "0.35.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "sys_platform != \"emscripten\"" +files = [ + {file = "uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a"}, + {file = "uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4) ; sys_platform == \"win32\"", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.15.1) ; sys_platform != \"win32\" and sys_platform != \"cygwin\" and platform_python_implementation != \"PyPy\"", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "virtualenv" version = "20.30.0" @@ -5812,5 +6043,10 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" +<<<<<<< HEAD python-versions = ">=3.9, <3.12" content-hash = "a6fae8faef3c2efeeb68d656c734a25b6a0355dc777ff1be6cf944c19a793536" +======= +python-versions = ">=3.10, <3.12" +content-hash = "4011c12d0bf99a57e589406755e1821ce2e6efeff776a355390434da1b61d804" +>>>>>>> origin/mcp-2.0 diff --git a/pyproject.toml b/pyproject.toml index 33fb162..97e3fe8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tiny-scientist" -version = "0.1.0" +version = "0.1.1" description = "A lightweight framework for building research agents" authors = ["Haofei Yu "] license = "Apache 2.0 License" @@ -33,6 +33,10 @@ together = "*" flask = "^3.0.0" flask-cors = ">=4,<7" flask-socketio = "^5.5.1" +fastmcp = "*" +mcp = "*" +httpx = "*" +docker = "^7.1.0" [tool.poetry.group.dev.dependencies] pre-commit = "*" @@ -41,6 +45,8 @@ types-setuptools = "*" types-pyyaml = "^6.0.12.20250402" types-requests = "^2.31" types-toml = "^0.10" +isort = "^6.0.1" +ruff = "^0.12.2" [tool.poetry.group.test.dependencies] pytest = "*" diff --git a/tiny_scientist/coder.py b/tiny_scientist/coder.py index 4199473..e08f195 100644 --- a/tiny_scientist/coder.py +++ b/tiny_scientist/coder.py @@ -30,6 +30,7 @@ def __init__( chat_history: Optional[str] = None, auto_install: bool = True, cost_tracker: Optional[BudgetChecker] = None, + mcp_client: Any = None, ): """Initialize the ExperimentCoder with configuration and Aider setup.""" self.client, self.model = create_client(model) @@ -40,6 +41,7 @@ def __init__( self.auto_install = auto_install self.config = Config() self.cost_tracker = cost_tracker or BudgetChecker() + self.mcp_client = mcp_client # Load prompts self.prompts = self.config.prompt_template.coder_prompt @@ -77,9 +79,28 @@ def run( ) -> Tuple[bool, str, Optional[str]]: # Ensure a clean slate for every run print(f"[System] Cleaning experiment directory: {self.output_dir}") - if osp.exists(self.output_dir): - shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir) + + # Save current working directory and switch to parent directory to avoid deletion issues + original_cwd = os.getcwd() + safe_cwd = osp.dirname(osp.abspath(self.output_dir)) + + try: + # Switch to safe directory before cleaning + os.chdir(safe_cwd) + + if osp.exists(self.output_dir): + shutil.rmtree(self.output_dir) + os.makedirs(self.output_dir) + + finally: + # Restore original working directory if it still exists, otherwise use safe directory + try: + os.chdir(original_cwd) + except (FileNotFoundError, OSError): + print( + f"[System] Original working directory no longer exists, staying in {safe_cwd}" + ) + os.chdir(safe_cwd) fnames = [ osp.join(self.output_dir, "experiment.py"), osp.join(self.output_dir, "notes.txt"), diff --git a/tiny_scientist/mcp/code_search_server.py b/tiny_scientist/mcp/code_search_server.py new file mode 100644 index 0000000..3917a34 --- /dev/null +++ b/tiny_scientist/mcp/code_search_server.py @@ -0,0 +1,212 @@ +import json +import os +import re +from typing import Any, Dict, List, Optional + +import httpx +import spacy +import toml +from mcp.server.fastmcp import FastMCP + +# Initialize FastMCP server +mcp = FastMCP("code_search") + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "../..", "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + +# GitHub API configuration +GITHUB_API_BASE = "https://api.github.com" +GITHUB_TOKEN = config["core"].get("github_token", None) + + +async def make_github_request( + url: str, params: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """Make a request to the GitHub API with proper error handling.""" + headers = {"Accept": "application/vnd.github.v3+json"} + if GITHUB_TOKEN: + headers["Authorization"] = f"token {GITHUB_TOKEN}" + + async with httpx.AsyncClient() as client: + try: + response = await client.get( + url, headers=headers, params=params, timeout=30.0 + ) + response.raise_for_status() + result: Dict[str, Any] = response.json() + return result + except Exception as e: + print(f"GitHub API request failed: {e}") + return None + + +def format_github_repo_query( + idea: Dict[str, Any], max_terms: int = 6, max_query_length: int = 250 +) -> str: + """Format a research idea into a GitHub search query.""" + title = idea.get("Title", "") + experiment = idea.get("Experiment", "") + combined_text = f"{title}. {experiment}" + + try: + nlp = spacy.load("en_core_web_sm") + doc = nlp(combined_text) + candidates = set() + + # Extract short noun phrases + for chunk in doc.noun_chunks: + phrase = chunk.text.strip().lower() + if 1 <= len(phrase.split()) <= 4: + candidates.add(phrase) + + # Add important standalone nouns and proper nouns + for token in doc: + if token.pos_ in {"NOUN", "PROPN"} and len(token.text) > 2: + candidates.add(token.text.lower()) + + # Clean and deduplicate + seen = set() + keywords = [] + for kw in candidates: + cleaned = re.sub(r"[^\w\s]", "", kw) + if cleaned not in seen: + seen.add(cleaned) + keywords.append(cleaned) + if len(keywords) >= max_terms: + break + + # Build query string + quoted_keywords = [f'"{kw}"' if " " in kw else kw for kw in keywords] + base_query = " ".join(quoted_keywords) + suffix = " in:file language:python" + full_query = f"{base_query} {suffix}" + + # Truncate if needed + if len(full_query) > max_query_length: + full_query = f"{' '.join(quoted_keywords[:max_terms//2])} {suffix}" + + return full_query + except Exception: + # Fallback to simple keyword extraction + return f"{title} {experiment} language:python" + + +def extract_github_repo_info(repos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Extract relevant information from GitHub repository search results.""" + return [ + { + "name": repo["name"], + "owner": repo["owner"]["login"], + "stars": repo["stargazers_count"], + "forks": repo["forks_count"], + "url": repo["html_url"], + "description": repo["description"] or "No description provided.", + } + for repo in repos + ] + + +def extract_github_code_info( + code_results: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: + """Extract relevant information from GitHub code search results.""" + return [ + { + "file_name": item["name"], + "repository": item["repository"]["full_name"], + "url": item["html_url"], + } + for item in code_results + ] + + +@mcp.tool() +async def search_github_repositories(query: str, result_limit: int = 10) -> str: + """Search GitHub repositories. + + Args: + query: Search query string or JSON string containing research idea + result_limit: Maximum number of results to return (default: 10) + """ + print(f"[GitHub API] Searching repositories with query: {query}") + + # Try to parse as JSON (research idea format) + try: + idea = json.loads(query) + if isinstance(idea, dict) and any(k in idea for k in ["Title", "Experiment"]): + formatted_query = format_github_repo_query(idea) + print(f"[GitHub API] Formatted query from idea: {formatted_query}") + else: + formatted_query = query + except (json.JSONDecodeError, TypeError): + formatted_query = query + + url = f"{GITHUB_API_BASE}/search/repositories" + params = { + "q": formatted_query, + "sort": "stars", + "order": "desc", + "per_page": min(result_limit, 100), + } + + data = await make_github_request(url, params) + if not data or "items" not in data: + return json.dumps( + {"error": "Unable to fetch repositories or no repositories found."} + ) + + repos = extract_github_repo_info(data["items"]) + + # Format results for return + results = {} + for i, repo in enumerate(repos): + results[str(i)] = { + "title": repo["name"], + "source": repo["url"], + "info": f"Stars: {repo['stars']}, Owner: {repo['owner']}", + "description": repo["description"], + } + + return json.dumps(results, indent=2) + + +@mcp.tool() +async def search_github_code(query: str, result_limit: int = 10) -> str: + """Search GitHub code files. + + Args: + query: Search query string + result_limit: Maximum number of results to return (default: 10) + """ + print(f"[GitHub API] Searching code with query: {query}") + + url = f"{GITHUB_API_BASE}/search/code" + params = { + "q": query, + "sort": "indexed", + "order": "desc", + "per_page": min(result_limit, 100), + } + + data = await make_github_request(url, params) + if not data or "items" not in data: + return json.dumps({"error": "Unable to fetch code results or no code found."}) + + code_results = extract_github_code_info(data["items"]) + + # Format results for return + results = {} + for i, code in enumerate(code_results): + results[str(i)] = { + "title": code["file_name"], + "source": code["url"], + "info": f"Repository: {code['repository']}", + } + + return json.dumps(results, indent=2) + + +if __name__ == "__main__": + # Initialize and run the server + mcp.run(transport="stdio") diff --git a/tiny_scientist/mcp/drawer_server.py b/tiny_scientist/mcp/drawer_server.py new file mode 100644 index 0000000..ac3e972 --- /dev/null +++ b/tiny_scientist/mcp/drawer_server.py @@ -0,0 +1,246 @@ +import json +import os +import re +from importlib import resources +from typing import Any, Dict, Optional + +import fitz +import httpx +import toml +from mcp.server.fastmcp import FastMCP + +from tiny_scientist.configs import Config + +# Initialize FastMCP server +mcp = FastMCP("drawer") + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "../..", "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + +# LLM configuration +LLM_MODEL = config["core"].get("model", "gpt-4o-mini") +LLM_API_KEY = config["core"].get("llm_api_key", "") +LLM_TEMPERATURE = config["core"].get("temperature", 0.75) + +# Load prompt templates from the configs module +prompt_config = Config() +prompts = prompt_config.prompt_template.drawer_prompt + + +def escape_curly_braces(text: str) -> str: + """Escape curly braces in text to prevent format string issues.""" + return re.sub(r"({|})", r"{{\1}}", text) + + +def extract_pdf_text_from_resource(package: str, filename: str) -> str: + """Extract text from a PDF resource file.""" + with resources.files(package).joinpath(filename).open("rb") as f: + doc = fitz.open(stream=f.read(), filetype="pdf") + extracted = [page.get_text().strip() for page in doc] + return "\n\n".join(extracted) + + +def get_section_prompts(section_name: str, section_text: str) -> str: + """Get section-specific prompts.""" + section_prompt = prompts.section_prompt[section_name].format( + section_text=section_text + ) + return section_prompt + + +async def make_llm_request(prompt: str, system_message: str) -> Optional[str]: + """Make a request to the LLM API.""" + headers = { + "Authorization": f"Bearer {LLM_API_KEY}", + "Content-Type": "application/json", + } + + data = { + "model": LLM_MODEL, + "messages": [ + {"role": "system", "content": system_message}, + {"role": "user", "content": prompt}, + ], + "temperature": LLM_TEMPERATURE, + } + + async with httpx.AsyncClient() as client: + try: + response = await client.post( + "https://api.openai.com/v1/chat/completions", + headers=headers, + json=data, + timeout=60.0, + ) + response.raise_for_status() + result = response.json() + content = result["choices"][0]["message"]["content"] + return content if isinstance(content, str) else None + except Exception as e: + print(f"LLM API request failed: {e}") + return None + + +def extract_diagram_data(response: str) -> Dict[str, Any]: + """Extract diagram data from LLM response.""" + result = {"summary": "", "svg": "", "full_response": response} + + try: + parsed = json.loads(response) + summary = parsed["summary"] + svg = parsed["svg"] + except json.JSONDecodeError: + svg_match = re.search(r"", response, re.DOTALL) + svg = svg_match.group(0) if svg_match else "" + summary = ( + re.sub(r"", "", response, flags=re.DOTALL) + .strip() + .split("\n")[0] + ) + + if "" in svg: + result["summary"] = summary + result["svg"] = clean_svg(svg) + else: + print("[ERROR] SVG missing or too short.") + return result + + +def clean_svg(svg: str) -> str: + """Clean and format SVG content.""" + # Strip any outer code block delimiters + svg = svg.strip() + svg = re.sub(r"^```(?:svg)?", "", svg) + svg = re.sub(r"```$", "", svg) + + # Replace problematic ampersands + svg = svg.replace("&", "&") + + # Ensure no double XML declarations + svg = re.sub(r"<\?xml.*?\?>", "", svg, count=1) + + # Remove extra whitespace lines + svg = "\n".join([line for line in svg.splitlines() if line.strip()]) + + return svg.strip() + + +# Initialize system prompt with sample data +def initialize_system_prompt() -> str: + """Initialize the system prompt with sample data.""" + try: + method_sample_raw = extract_pdf_text_from_resource( + "tiny_scientist.fewshot_sample", "framework.pdf" + ) + result_sample_raw = extract_pdf_text_from_resource( + "tiny_scientist.fewshot_sample", "result.pdf" + ) + + method_sample = escape_curly_braces(method_sample_raw) + result_sample = escape_curly_braces(result_sample_raw) + + return prompts.diagram_system_prompt.format( + method_sample=method_sample, + result_sample=result_sample, + ) + except Exception as e: + print(f"[WARNING] Failed to load sample data: {e}") + return "You are a diagram generation assistant. Generate SVG diagrams based on research paper sections." + + +SYSTEM_PROMPT = initialize_system_prompt() + + +@mcp.tool() +async def generate_diagram(section_name: str, section_content: str) -> str: + """Generate an SVG diagram for a research paper section. + + Args: + section_name: Name of the paper section (e.g., "Method", "Results") + section_content: Content of the section to visualize + """ + print(f"[Drawer] Generating diagram for section: {section_name}") + + if not section_content.strip(): + return json.dumps({"error": "Section content cannot be empty"}) + + # Get section-specific prompts + section_prompt = get_section_prompts(section_name, section_content) + + # Generate diagram using LLM + llm_response = await make_llm_request(section_prompt, SYSTEM_PROMPT) + + if not llm_response: + return json.dumps({"error": "Failed to generate diagram from LLM"}) + + # Extract diagram data + diagram = extract_diagram_data(llm_response) + + # Format response + result = { + "diagram": { + "summary": diagram.get("summary", ""), + "svg": diagram.get("svg", ""), + } + } + + return json.dumps(result, indent=2) + + +@mcp.tool() +async def validate_svg(svg_content: str) -> str: + """Validate and clean SVG content. + + Args: + svg_content: SVG content to validate and clean + """ + print("[Drawer] Validating and cleaning SVG content") + + if not svg_content.strip(): + return json.dumps({"error": "SVG content cannot be empty"}) + + try: + cleaned_svg = clean_svg(svg_content) + + # Basic validation - check if it looks like valid SVG + if "" in cleaned_svg: + result = { + "valid": True, + "cleaned_svg": cleaned_svg, + "message": "SVG is valid and has been cleaned", + } + else: + result = { + "valid": False, + "cleaned_svg": "", + "message": "SVG appears to be invalid or incomplete", + } + + return json.dumps(result, indent=2) + except Exception as e: + return json.dumps( + { + "valid": False, + "cleaned_svg": "", + "message": f"Error validating SVG: {str(e)}", + } + ) + + +@mcp.tool() +async def get_supported_sections() -> str: + """Get list of supported section types for diagram generation.""" + supported_sections = list(prompts.section_prompt.keys()) + + result = { + "supported_sections": supported_sections, + "description": "These are the section types that have specialized prompts for diagram generation", + } + + return json.dumps(result, indent=2) + + +if __name__ == "__main__": + # Initialize and run the server + mcp.run(transport="stdio") diff --git a/tiny_scientist/mcp/paper_search_server.py b/tiny_scientist/mcp/paper_search_server.py new file mode 100644 index 0000000..8522ef1 --- /dev/null +++ b/tiny_scientist/mcp/paper_search_server.py @@ -0,0 +1,239 @@ +import asyncio +import json +import os +from typing import Any, Dict, List, Optional + +import httpx +import toml +from mcp.server.fastmcp import FastMCP + +# Initialize FastMCP server +mcp = FastMCP("paper_search") + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "../..", "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + +# Semantic Scholar API configuration +S2_API_BASE = "https://api.semanticscholar.org/graph/v1" +S2_API_KEY = config["core"].get("s2_api_key", None) +SEARCH_ENGINE = config["core"].get("engine", "semanticscholar") + +# Debug: Print configuration status +print(f"[Paper Search] Config path: {config_path}") +print(f"[Paper Search] Config exists: {os.path.exists(config_path)}") +print(f"[Paper Search] API Key configured: {'Yes' if S2_API_KEY else 'No'}") +print(f"[Paper Search] Search engine: {SEARCH_ENGINE}") + + +async def make_s2_request( + url: str, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, +) -> Optional[Dict[str, Any]]: + """Make a request to the Semantic Scholar API with proper error handling.""" + default_headers = {} + + # Temporarily disable API key due to invalid key issue + # TODO: Update with a valid API key when available + use_api_key = False # Set to True when you have a valid API key + + if S2_API_KEY and use_api_key: + default_headers["X-API-KEY"] = S2_API_KEY + print(f"[Paper Search] Using API key: {S2_API_KEY[:10]}...") + else: + print("[Paper Search] Using unauthenticated access (rate limited)") + + if headers: + default_headers.update(headers) + + async with httpx.AsyncClient() as client: + try: + response = await client.get( + url, headers=default_headers, params=params, timeout=30.0 + ) + print(f"[Paper Search] Response status: {response.status_code}") + response.raise_for_status() + result: Dict[str, Any] = response.json() + if result.get("data"): + print(f"[Paper Search] Found {len(result['data'])} papers") + return result + except Exception as e: + print(f"[Paper Search] Semantic Scholar API request failed: {e}") + if hasattr(e, "response"): + print( + f"[Paper Search] Response text: {e.response.text if e.response else 'No response'}" + ) + return None + + +async def make_openalex_request( + query: str, result_limit: int +) -> Optional[List[Dict[str, Any]]]: + """Make a request to OpenAlex API.""" + try: + import pyalex + from pyalex import Works + + mail = os.environ.get("OPENALEX_MAIL_ADDRESS") + if mail: + pyalex.config.email = mail + else: + print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better API access") + + works = Works().search(query).get(per_page=result_limit) + if not works: + return None + + return [extract_openalex_work_info(work) for work in works] + except ImportError: + print("[ERROR] pyalex not installed, falling back to Semantic Scholar") + return None + except Exception as e: + print(f"OpenAlex API request failed: {e}") + return None + + +def extract_openalex_work_info( + work: Dict[str, Any], max_abstract_length: int = 1000 +) -> Dict[str, str]: + """Extract relevant information from OpenAlex work data.""" + venue = next( + (loc["source"]["display_name"] for loc in work["locations"] if loc["source"]), + "Unknown", + ) + + authors_list = [author["author"]["display_name"] for author in work["authorships"]] + authors = ( + " and ".join(authors_list) + if len(authors_list) < 20 + else f"{authors_list[0]} et al." + ) + + abstract = work.get("abstract", "") + if len(abstract) > max_abstract_length: + print(f"[WARNING] {work['title']}: Abstract is too long, truncating.") + abstract = abstract[:max_abstract_length] + + return { + "title": work["title"], + "authors": authors, + "venue": venue, + "year": str(work.get("publication_year", "Unknown")), + "abstract": abstract, + "citationCount": str(work.get("cited_by_count", 0)), + } + + +@mcp.tool() +async def search_papers(query: str, result_limit: int = 3) -> str: + """Search for academic papers using Semantic Scholar or OpenAlex. + + Args: + query: Search query string for papers + result_limit: Maximum number of papers to return (default: 3) + """ + print(f"[Paper Search] Searching for papers with query: {query}") + + if not query: + return json.dumps({"error": "No query provided"}) + + papers = None + + if SEARCH_ENGINE == "semanticscholar": + print(f"(Semantic Scholar API) Searching for papers with query: {query}") + papers = await search_semanticscholar(query, result_limit) + elif SEARCH_ENGINE == "openalex": + print(f"(OpenAlex API) Searching for papers with query: {query}") + papers = await make_openalex_request(query, result_limit) + else: + return json.dumps({"error": f"Unsupported search engine: {SEARCH_ENGINE}"}) + + if not papers: + return json.dumps({"error": "No papers found or API error"}) + + # Format papers and fetch bibtex for Semantic Scholar results + results = {} + for paper in papers: + paper_id = paper.get("paperId", None) + bibtex = "N/A" + + if SEARCH_ENGINE == "semanticscholar" and paper_id: + bibtex = await fetch_bibtex(paper_id) + + if bibtex and bibtex != "N/A": + title = paper.get("title", "Unknown Title") + results[title] = {"title": title, "bibtex": bibtex} + + return json.dumps(results, indent=2) + + +async def search_semanticscholar( + query: str, result_limit: int +) -> Optional[List[Dict[str, Any]]]: + """Search Semantic Scholar for papers.""" + params = { + "query": query, + "limit": result_limit, + "fields": "title,authors,venue,year,abstract,citationStyles,citationCount,paperId", + } + + url = f"{S2_API_BASE}/paper/search" + data = await make_s2_request(url, params) + + if not data or not data.get("total"): + return None + + # Add a small delay to be respectful to the API + await asyncio.sleep(8.0) + result = data.get("data") + return result if isinstance(result, list) else None + + +@mcp.tool() +async def fetch_bibtex(paper_id: str) -> str: + """Fetch BibTeX citation for a paper by its Semantic Scholar ID. + + Args: + paper_id: Semantic Scholar paper ID + """ + print(f"[Paper Search] Fetching BibTeX for paper ID: {paper_id}") + + url = f"{S2_API_BASE}/paper/{paper_id}" + params = {"fields": "citationStyles"} + + data = await make_s2_request(url, params) + if not data: + return "N/A" + + citation_styles = data.get("citationStyles", {}) + bibtex = citation_styles.get("bibtex", "N/A") + return bibtex if isinstance(bibtex, str) else "N/A" + + +@mcp.tool() +async def get_paper_details(paper_id: str) -> str: + """Get detailed information about a paper by its Semantic Scholar ID. + + Args: + paper_id: Semantic Scholar paper ID + """ + print(f"[Paper Search] Getting details for paper ID: {paper_id}") + + url = f"{S2_API_BASE}/paper/{paper_id}" + params = { + "fields": "title,authors,venue,year,abstract,citationCount,citationStyles" + } + + data = await make_s2_request(url, params) + if not data: + return json.dumps({"error": "Paper not found or API error"}) + + return json.dumps(data, indent=2) + + +# Import asyncio at the end to avoid issues + +if __name__ == "__main__": + # Initialize and run the server + mcp.run(transport="stdio") diff --git a/tiny_scientist/mcp/tool.py b/tiny_scientist/mcp/tool.py new file mode 100644 index 0000000..0dc9592 --- /dev/null +++ b/tiny_scientist/mcp/tool.py @@ -0,0 +1,466 @@ +import abc +import json +import os +import re +import time +from importlib import resources +from typing import Any, Dict, List, Optional, cast + +import fitz +import requests +import toml +from rich import print + +from ..budget_checker import BudgetChecker +from ..configs import Config +from ..utils.error_handler import api_calling_error_exponential_backoff +from ..utils.llm import create_client, get_response_from_llm + +# Load config +config_path = os.path.join(os.path.dirname(__file__), "config.toml") +config = toml.load(config_path) if os.path.exists(config_path) else {"core": {}} + + +class BaseTool(abc.ABC): + def __init__(self, cost_tracker: Optional[BudgetChecker] = None) -> None: + self.cost_tracker = cost_tracker or BudgetChecker() + self.github_token = config["core"].get("github_token", None) + + @abc.abstractmethod + def run(self, query: str) -> Dict[str, Dict[str, str]]: + pass + + +class CodeSearchTool(BaseTool): + def __init__(self) -> None: + super().__init__() + + def run( + self, query: str, search_type: str = "repositories" + ) -> Dict[str, Dict[str, str]]: + print(f"[github API calling] Searching for code with query: {query}") + results = {} + + try: + idea = json.loads(query) + if isinstance(idea, dict) and any( + k in idea for k in ["Title", "Experiment"] + ): + query = self.format_github_repo_query(idea) + print(f"[github API calling] Formatted query from idea: {query}") + except (json.JSONDecodeError, TypeError): + pass + + repos = self._search_github(query=query, search_type=search_type) + + if repos: + for i, repo in enumerate(repos): + results[str(i)] = { + "title": repo["name"], + "source": repo["url"], + "info": f"Stars: {repo['stars']}", + } + + self.cost_tracker.report() + return results + + def format_github_repo_query( + self, idea: Dict[str, Any], max_terms: int = 6, max_query_length: int = 250 + ) -> str: + import re + + import spacy + + title = idea.get("Title", "") + experiment = idea.get("Experiment", "") + combined_text = f"{title}. {experiment}" + + nlp = spacy.load("en_core_web_sm") + doc = nlp(combined_text) + candidates = set() + + # Extract short noun phrases + for chunk in doc.noun_chunks: + phrase = chunk.text.strip().lower() + if 1 <= len(phrase.split()) <= 4: + candidates.add(phrase) + + # Add important standalone nouns and proper nouns + for token in doc: + if token.pos_ in {"NOUN", "PROPN"} and len(token.text) > 2: + candidates.add(token.text.lower()) + + # Clean and deduplicate + seen = set() + keywords = [] + for kw in candidates: + cleaned = re.sub(r"[^\w\s]", "", kw) + if cleaned not in seen: + seen.add(cleaned) + keywords.append(cleaned) + if len(keywords) >= max_terms: + break + + # Build query string + quoted_keywords = [f'"{kw}"' if " " in kw else kw for kw in keywords] + base_query = " ".join(quoted_keywords) + suffix = " in:file language:python" + full_query = f"{base_query} {suffix}" + + # Truncate if needed + if len(full_query) > max_query_length: + full_query = f"{' '.join(quoted_keywords[:max_terms//2])} {suffix}" + + return full_query + + def _search_github( + self, query: str, search_type: str, result_limit: int = 10 + ) -> Optional[List[Dict[str, Any]]]: + if search_type not in ["repositories", "code"]: + raise ValueError("search_type must be either 'repositories' or 'code'.") + + url = f"https://api.github.com/search/{search_type}" + headers = ( + {"Authorization": f"token {self.github_token}"} if self.github_token else {} + ) + + params = { + "q": query, + "sort": "stars" if search_type == "repositories" else "indexed", + "order": "desc", + "per_page": result_limit, + } + + response = requests.get(url, headers=headers, params=params) + print( + f"GitHub {search_type.capitalize()} Response Status Code: {response.status_code}" + ) + response.raise_for_status() + + results = response.json() + if "items" not in results: + return None + + return ( + self._extract_github_repo_info(results["items"]) + if search_type == "repositories" + else self._extract_github_code_info(results["items"]) + ) + + @staticmethod + def _extract_github_repo_info(repos: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return [ + { + "name": repo["name"], + "owner": repo["owner"]["login"], + "stars": repo["stargazers_count"], + "forks": repo["forks_count"], + "url": repo["html_url"], + "description": repo["description"] or "No description provided.", + } + for repo in repos + ] + + @staticmethod + def _extract_github_code_info( + code_results: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + return [ + { + "file_name": item["name"], + "repository": item["repository"]["full_name"], + "url": item["html_url"], + } + for item in code_results + ] + + +class PaperSearchTool(BaseTool): + def __init__(self, s2_api_key: Optional[str] = None) -> None: + super().__init__() + self.s2_api_key = ( + s2_api_key + or os.environ.get("S2_API_KEY") + or config["core"].get("s2_api_key") + ) + + # Set default engine if not configured + self.engine = config["core"].get("engine", "semanticscholar") + + def run(self, query: str) -> Dict[str, Dict[str, str]]: + results = {} + papers = self.search_for_papers(query) + + if papers: + for i, paper in enumerate(papers): + paper_id = paper.get("paperId", None) + bibtex = self.fetch_bibtex(paper_id) if paper_id else "N/A" + + if not bibtex or bibtex == "N/A": + continue + + results[paper["title"]] = {"title": paper["title"], "bibtex": bibtex} + + self.cost_tracker.report() + return results + + def search_for_papers( + self, query: str, result_limit: int = 3 + ) -> Optional[List[Dict[str, Any]]]: + if not query: + return None + + if self.engine == "semanticscholar": + print( + f"(semantic scholar API calling) Searching for papers with query: {query}" + ) + return self._search_semanticscholar(query, result_limit) + elif self.engine == "openalex": + print(f"(openalex API calling) Searching for papers with query: {query}") + return self._search_openalex(query, result_limit) + else: + raise NotImplementedError(f"{self.engine=} not supported!") + + @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) + def _search_semanticscholar( + self, query: str, result_limit: int + ) -> Optional[List[Dict[str, Any]]]: + params: Dict[str, str | int] = { + "query": query, + "limit": result_limit, + "fields": "title,authors,venue,year,abstract,citationStyles,citationCount,paperId", + } + + headers = {"X-API-KEY": self.s2_api_key} if self.s2_api_key else {} + rsp = requests.get( + "https://api.semanticscholar.org/graph/v1/paper/search", + headers=headers, + params=params, + ) + rsp.raise_for_status() + + results = rsp.json() + if not results.get("total"): + return None + + time.sleep(1.0) + return cast(Optional[List[Dict[str, Any]]], results.get("data")) + + def _search_openalex( + self, query: str, result_limit: int + ) -> Optional[List[Dict[str, Any]]]: + import pyalex + from pyalex import Works + + mail = os.environ.get("OPENALEX_MAIL_ADDRESS") + if mail: + pyalex.config.email = mail + else: + print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better API access") + + works = Works().search(query).get(per_page=result_limit) + if not works: + return None + + return [self._extract_work_info(work) for work in works] + + @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) + def fetch_bibtex(self, paper_id: str) -> Any: + headers = {"X-API-KEY": self.s2_api_key} if self.s2_api_key else {} + rsp = requests.get( + f"https://api.semanticscholar.org/graph/v1/paper/{paper_id}", + headers=headers, + params={"fields": "citationStyles"}, + ) + rsp.raise_for_status() + citation_styles = rsp.json().get("citationStyles", {}) + return citation_styles.get("bibtex", "N/A") + + @staticmethod + def _extract_work_info( + work: Dict[str, Any], max_abstract_length: int = 1000 + ) -> Dict[str, str]: + venue = next( + ( + loc["source"]["display_name"] + for loc in work["locations"] + if loc["source"] + ), + "Unknown", + ) + + authors_list = [ + author["author"]["display_name"] for author in work["authorships"] + ] + authors = ( + " and ".join(authors_list) + if len(authors_list) < 20 + else f"{authors_list[0]} et al." + ) + + abstract = work.get("abstract", "") + if len(abstract) > max_abstract_length: + print(f"[WARNING] {work['title']}: Abstract is too long, truncating.") + abstract = abstract[:max_abstract_length] + + return { + "title": work["title"], + "authors": authors, + "venue": venue, + "year": work.get("publication_year", "Unknown"), + "abstract": abstract, + "citationCount": work.get("cited_by_count", 0), + } + + +class DrawerTool(BaseTool): + def __init__( + self, + model: Any, + prompt_template_dir: Optional[str] = None, + temperature: float = 0.75, + ): + super().__init__() + self.client, self.model = create_client(model) + self.temperature = temperature + + # Load prompt templates using Config + self.config = Config(prompt_template_dir) + self.prompts = self.config.prompt_template.drawer_prompt + + def escape_curly_braces(text: str) -> str: + return re.sub(r"({|})", r"{{\1}}", text) + + def extract_pdf_text_from_resource(package: str, filename: str) -> str: + with resources.files(package).joinpath(filename).open("rb") as f: + doc = fitz.open(stream=f.read(), filetype="pdf") + extracted = [page.get_text().strip() for page in doc] + return "\n\n".join(extracted) + + method_sample_raw = extract_pdf_text_from_resource( + "tiny_scientist.fewshot_sample", "framework.pdf" + ) + result_sample_raw = extract_pdf_text_from_resource( + "tiny_scientist.fewshot_sample", "result.pdf" + ) + + method_sample = escape_curly_braces(method_sample_raw) + result_sample = escape_curly_braces(result_sample_raw) + + self.system_prompts = self.prompts.diagram_system_prompt.format( + method_sample=method_sample, + result_sample=result_sample, + ) + + self.dir_path = os.path.dirname(os.path.realpath(__file__)) + + def run(self, query: str) -> Dict[str, Dict[str, str]]: + try: + query_dict = json.loads(query) + section_name = query_dict.get("section_name") + section_content = query_dict.get("section_content") + except (json.JSONDecodeError, TypeError, AttributeError): + raise ValueError( + "Expected query to be a JSON string with 'section_name' and 'section_content'." + ) + + diagram = self.draw_diagram( + section_name=section_name, section_content=section_content + ) + + results = {} + if diagram: + results["diagram"] = { + "summary": diagram.get("summary", ""), + "svg": diagram.get("svg", ""), + } + self.cost_tracker.report() + return results + + def draw_diagram( + self, + section_name: str, + section_content: str, + msg_history: Optional[List[Dict[str, Any]]] = None, + return_msg_history: bool = False, + ) -> Any: + # Use default system prompt if none provided + section_prompt = self._get_section_prompts(section_name, section_content) + + diagram, updated_msg_history = self._generate_diagram( + section_prompt, self.system_prompts, msg_history + ) + + return (diagram, updated_msg_history) if return_msg_history else diagram + + def _get_section_prompts(self, section_name: str, section_text: str) -> str: + section_prompt = self.prompts.section_prompt[section_name].format( + section_text=section_text + ) + + return section_prompt + + @api_calling_error_exponential_backoff(retries=5, base_wait_time=2) + def _generate_diagram( + self, + section_prompt: str, + drawer_system_prompt: str, + msg_history: Optional[List[Dict[str, Any]]], + ) -> tuple[Dict[str, Any], List[Dict[str, Any]]]: + # Ensure msg_history is a list + msg_history = msg_history or [] + + # Generate diagram + llm_response, msg_history = get_response_from_llm( + section_prompt, + model=self.model, + client=self.client, + system_message=drawer_system_prompt, + msg_history=msg_history, + temperature=self.temperature, + cost_tracker=self.cost_tracker, + task_name="generate_diagram", + ) + + diagram = self._extract_diagram(llm_response) + return diagram, msg_history + + def _extract_diagram(self, response: str) -> Dict[str, Any]: + result = {"summary": "", "svg": "", "full_response": response} + + try: + parsed = json.loads(response) + summary = parsed["summary"] + svg = parsed["svg"] + except json.JSONDecodeError: + svg_match = re.search(r"", response, re.DOTALL) + svg = svg_match.group(0) if svg_match else "" + summary = ( + re.sub(r"", "", response, flags=re.DOTALL) + .strip() + .split("\n")[0] + ) + + if "" in svg: + result["summary"] = summary + result["svg"] = self._clean_svg(svg) + else: + print("[ERROR] SVG missing or too short.") + return result + + def _clean_svg(self, svg: str) -> str: + # Strip any outer code block delimiters + svg = svg.strip() + svg = re.sub(r"^```(?:svg)?", "", svg) + svg = re.sub(r"```$", "", svg) + + # Replace problematic ampersands + svg = svg.replace("&", "&") + + # Ensure no double XML declarations + svg = re.sub(r"<\?xml.*?\?>", "", svg, count=1) + + # Remove extra whitespace lines + svg = "\n".join([line for line in svg.splitlines() if line.strip()]) + + return svg.strip() diff --git a/tiny_scientist/reviewer.py b/tiny_scientist/reviewer.py index 59aec92..fe3cefb 100644 --- a/tiny_scientist/reviewer.py +++ b/tiny_scientist/reviewer.py @@ -28,6 +28,7 @@ def __init__( pre_reflection_threshold: float = 0.5, post_reflection_threshold: float = 0.8, s2_api_key: Optional[str] = None, + mcp_client: Any = None, ): self.tools = tools self.num_reviews = num_reviews @@ -35,7 +36,11 @@ def __init__( self.client, self.model = create_client(model) self.temperature = temperature self.config = Config(prompt_template_dir) - self.searcher: BaseTool = PaperSearchTool(s2_api_key=s2_api_key) + self.mcp_client = mcp_client + # Use MCP searcher if available, otherwise fallback to traditional searcher + self.searcher: Optional[BaseTool] = ( + PaperSearchTool(s2_api_key=s2_api_key) if not mcp_client else None + ) self._query_cache: Dict[str, List[Dict[str, Any]]] = {} self.last_related_works_string = "" self.cost_tracker = cost_tracker or BudgetChecker() @@ -139,8 +144,89 @@ def _get_related_works(self, query: str) -> str: if query in self._query_cache: related_papers = self._query_cache[query] else: - results_dict = self.searcher.run(query) - related_papers = list(results_dict.values()) + if self.mcp_client: + # Use MCP client for paper search + import asyncio + + from .utils.mcp_client import search_papers + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_search() -> Optional[str]: + """Run the async search function in a new event loop.""" + return asyncio.run(search_papers(query, self.mcp_client)) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_search) + results_json = future.result(timeout=30.0) # Add timeout + + if results_json: + import json + + results_dict = json.loads(results_json) + if results_dict: + # Convert MCP format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # MCP format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # MCP doesn't return author info + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}", + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue", + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + except Exception as e: + print( + f"[WARNING] MCP search failed, falling back to traditional search: {e}" + ) + if self.searcher: + results_dict = self.searcher.run(query) + related_papers = list(results_dict.values()) + else: + related_papers = [] + else: + # Use traditional searcher + if self.searcher: + results_dict = self.searcher.run(query) + if results_dict: + # Convert traditional format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # Traditional format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # Traditional tool doesn't return author info either + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}", + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue", + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + self._query_cache[query] = related_papers if related_papers else [] if related_papers: diff --git a/tiny_scientist/scientist.py b/tiny_scientist/scientist.py index 9b74352..80c1408 100644 --- a/tiny_scientist/scientist.py +++ b/tiny_scientist/scientist.py @@ -1,3 +1,4 @@ +import datetime import os from typing import Any, Dict, List, Optional, Tuple, Union @@ -10,6 +11,7 @@ from .safety_checker import SafetyChecker from .thinker import Thinker from .utils.input_formatter import InputFormatter +from .utils.mcp_client import MCPClient from .writer import Writer @@ -23,16 +25,30 @@ def __init__( budget: Optional[float] = None, enable_safety_check: bool = True, budget_preference: Optional[str] = None, + use_mcp: bool = True, ): self.model = model - self.output_dir = output_dir + self.base_output_dir = output_dir # Store user's base directory + + # Create a unique experiment directory with timestamp + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + self.experiment_dir = os.path.join(output_dir, f"experiment_{timestamp}") + + # Ensure the experiment directory exists + os.makedirs(self.experiment_dir, exist_ok=True) + print(f"๐Ÿ”ฌ Created experiment directory: {self.experiment_dir}") + self.template = template self.prompt_template_dir = prompt_template_dir self.input_formatter = InputFormatter() self.enable_safety_check = enable_safety_check + self.use_mcp = use_mcp self.cost = 0.0 + # Initialize MCP client if enabled + self.mcp_client = MCPClient() if use_mcp else None + # Naive budget split modules = ["safety_checker", "thinker", "coder", "writer", "reviewer"] per_module_budget = budget / len(modules) if budget else None @@ -85,9 +101,10 @@ def __init__( else None ) + # Use the unique experiment directory for all modules self.thinker = Thinker( model=model, - output_dir=output_dir, + output_dir=self.experiment_dir, prompt_template_dir=prompt_template_dir, tools=[], iter_num=3, @@ -96,23 +113,26 @@ def __init__( enable_ethical_defense=False, enable_safety_check=enable_safety_check, cost_tracker=BudgetChecker(budget=allocation.get("thinker")), + mcp_client=self.mcp_client, ) self.coder = Coder( model=model, - output_dir=output_dir, + output_dir=self.experiment_dir, prompt_template_dir=prompt_template_dir, max_iters=4, max_runs=3, cost_tracker=BudgetChecker(budget=allocation.get("coder")), + mcp_client=self.mcp_client, ) self.writer = Writer( model=model, - output_dir=output_dir, + output_dir=self.experiment_dir, prompt_template_dir=prompt_template_dir, template=template, cost_tracker=BudgetChecker(budget=allocation.get("writer")), + mcp_client=self.mcp_client, ) self.reviewer = Reviewer( @@ -120,8 +140,35 @@ def __init__( prompt_template_dir=prompt_template_dir, tools=[], cost_tracker=BudgetChecker(budget=allocation.get("reviewer")), + mcp_client=self.mcp_client, ) + async def initialize_mcp(self) -> None: + """Initialize MCP servers.""" + if self.mcp_client: + print("๐Ÿ”ง Initializing MCP servers...") + results = await self.mcp_client.start_all_servers() + for server_name, success in results.items(): + if success: + print(f"โœ… MCP server '{server_name}' started successfully") + else: + print(f"โŒ Failed to start MCP server '{server_name}'") + + async def cleanup_mcp(self) -> None: + """Clean up MCP servers.""" + if self.mcp_client: + print("๐Ÿงน Shutting down MCP servers...") + await self.mcp_client.stop_all_servers() + + async def __aenter__(self) -> "TinyScientist": + """Async context manager entry.""" + await self.initialize_mcp() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.cleanup_mcp() + def think( self, intent: str, num_ideas: int = 1, pdf_content: Optional[str] = None ) -> Union[List[Dict[str, Any]], Dict[str, Any]]: @@ -158,11 +205,13 @@ def code( print(f"โŒ Experiment failed. Please check {exp_path} for details.") if error_details: print(f"Error details: {error_details}") - return status, exp_path + return status, self.experiment_dir - def write(self, idea: Dict[str, Any], experiment_dir: str) -> str: + def write(self, idea: Dict[str, Any], experiment_dir: Optional[str] = None) -> str: print("๐Ÿ“ Writing paper...") - pdf_path, paper_name = self.writer.run(idea=idea, experiment_dir=experiment_dir) + # Use the internal experiment directory if no specific directory is provided + exp_dir = experiment_dir if experiment_dir is not None else self.experiment_dir + pdf_path, paper_name = self.writer.run(idea=idea, experiment_dir=exp_dir) print( f"Check the generated paper named as {paper_name} and saved at {pdf_path}" ) diff --git a/tiny_scientist/thinker.py b/tiny_scientist/thinker.py index 4dfad65..33cc86c 100644 --- a/tiny_scientist/thinker.py +++ b/tiny_scientist/thinker.py @@ -30,8 +30,10 @@ def __init__( prompt_template_dir: Optional[str] = None, cost_tracker: Optional[BudgetChecker] = None, enable_safety_check: bool = False, + enable_ethical_defense: bool = False, pre_reflection_threshold: float = 0.5, post_reflection_threshold: float = 0.8, + mcp_client: Any = None, ): self.tools = tools self.iter_num = iter_num @@ -39,12 +41,21 @@ def __init__( self.output_dir = output_dir self.temperature = temperature self.config = Config(prompt_template_dir) - self.searcher = PaperSearchTool() + self.mcp_client = mcp_client + # Use MCP searcher if available, otherwise fallback to traditional searcher + self.searcher: Optional[PaperSearchTool] = ( + PaperSearchTool() if not mcp_client else None + ) self.search_papers = search_papers self.generate_exp_plan = generate_exp_plan self.prompts = self.config.prompt_template.thinker_prompt self.intent = "" self._query_cache: Dict[str, List[Dict[str, Any]]] = {} + self.cost_tracker = cost_tracker or BudgetChecker() + self.enable_safety_check = enable_safety_check + self.enable_ethical_defense = enable_ethical_defense + self.pre_reflection_threshold = pre_reflection_threshold + self.post_reflection_threshold = post_reflection_threshold # Enhanced criteria system from TinyScientistUI self.default_system_prompt = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field. @@ -551,8 +562,92 @@ def _get_related_works(self, query: str) -> str: print("โœ… Using cached query results") else: print(f"Searching for papers with query: {query}") - results_dict = self.searcher.run(query) - related_papers = list(results_dict.values()) if results_dict else [] + + if self.mcp_client: + # Use MCP client for paper search + import asyncio + + from .utils.mcp_client import search_papers + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_search() -> Optional[str]: + """Run the async search function in a new event loop.""" + return asyncio.run(search_papers(query, self.mcp_client)) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_search) + results_json = future.result(timeout=30.0) # Add timeout + + if results_json: + import json + + results_dict = json.loads(results_json) + if results_dict: + # Convert MCP format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # MCP format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # MCP doesn't return author info + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}", + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue", + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + except Exception as e: + print( + f"[WARNING] MCP search failed, falling back to traditional search: {e}" + ) + if self.searcher: + results_dict = self.searcher.run(query) + related_papers = ( + list(results_dict.values()) if results_dict else [] + ) + else: + related_papers = [] + else: + # Use traditional searcher + if self.searcher: + results_dict = self.searcher.run(query) + if results_dict: + # Convert traditional format to expected format + related_papers = [] + for title, paper_data in results_dict.items(): + if isinstance(paper_data, dict): + # Traditional format: {"title": ..., "bibtex": ...} + paper = { + "title": paper_data.get("title", title), + "source": "Unknown authors", # Traditional tool doesn't return author info either + "info": f"BibTeX available: {paper_data.get('bibtex', 'N/A') != 'N/A'}", + } + else: + # Fallback if unexpected format + paper = { + "title": title, + "source": "Unknown authors", + "info": "Unknown venue", + } + related_papers.append(paper) + else: + related_papers = [] + else: + related_papers = [] + self._query_cache[query] = related_papers if related_papers: diff --git a/tiny_scientist/utils/mcp_client.py b/tiny_scientist/utils/mcp_client.py new file mode 100644 index 0000000..073f20e --- /dev/null +++ b/tiny_scientist/utils/mcp_client.py @@ -0,0 +1,476 @@ +import json +import os +import subprocess +from typing import Any, Dict, List, Optional + +import toml +from rich import print + + +class MCPClient: + """Client for managing and communicating with MCP servers.""" + + def __init__(self, config_path: Optional[str] = None): + """Initialize MCP client with configuration. + + Args: + config_path: Path to configuration file containing MCP server settings + """ + self.config_path = config_path or self._get_default_config_path() + self.config = self._load_config() + self.servers: Dict[str, subprocess.Popen[str]] = {} + self.server_configs = self.config.get("mcp", {}).get("servers", {}) + + def _get_default_config_path(self) -> str: + """Get default config path.""" + this_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + return os.path.join(this_dir, "config.toml") + + def _load_config(self) -> Dict[str, Any]: + """Load configuration from TOML file.""" + try: + with open(self.config_path, "r") as f: + return toml.load(f) + except FileNotFoundError: + print(f"[WARNING] Config file not found: {self.config_path}") + return {} + except Exception as e: + print(f"[ERROR] Failed to load config: {e}") + return {} + + async def start_server(self, server_name: str) -> bool: + """Start a specific MCP server. + + Args: + server_name: Name of the server to start + + Returns: + bool: True if server started successfully + """ + if server_name in self.servers: + print(f"[MCP] Server {server_name} is already running") + return True + + server_config = self.server_configs.get(server_name) + if not server_config: + print(f"[ERROR] No configuration found for server: {server_name}") + return False + + try: + command = server_config.get("command", "") + args = server_config.get("args", []) + working_dir = server_config.get("cwd") + + # Build full command + full_command = [command] + args + + # Start the server process + process = subprocess.Popen( + full_command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=working_dir, + text=True, + ) + + self.servers[server_name] = process + + # Perform MCP initialization handshake + init_success = await self._initialize_server(server_name) + if not init_success: + await self.stop_server(server_name) + return False + + print(f"[MCP] Started server: {server_name}") + return True + + except Exception as e: + print(f"[ERROR] Failed to start server {server_name}: {e}") + return False + + async def _initialize_server(self, server_name: str) -> bool: + """Initialize MCP server with proper handshake. + + Args: + server_name: Name of the server to initialize + + Returns: + bool: True if initialization successful + """ + if server_name not in self.servers: + return False + + try: + process = self.servers[server_name] + + # Send initialize request + init_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "tiny-scientist-mcp-client", + "version": "1.0.0", + }, + }, + } + + request_json = json.dumps(init_request) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return False + process.stdin.write(request_json) + process.stdin.flush() + + # Read initialization response + if process.stdout is None: + print(f"[ERROR] No stdout available for server {server_name}") + return False + response_line = process.stdout.readline() + if not response_line: + print(f"[ERROR] No initialization response from {server_name}") + return False + + response = json.loads(response_line.strip()) + + # Check for initialization success + if "error" in response: + print(f"[ERROR] Server initialization failed: {response['error']}") + return False + + # Send initialized notification + initialized_notification = { + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + } + + notification_json = json.dumps(initialized_notification) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return False + process.stdin.write(notification_json) + process.stdin.flush() + + return True + + except Exception as e: + print(f"[ERROR] Failed to initialize server {server_name}: {e}") + return False + + async def stop_server(self, server_name: str) -> bool: + """Stop a specific MCP server. + + Args: + server_name: Name of the server to stop + + Returns: + bool: True if server stopped successfully + """ + if server_name not in self.servers: + print(f"[WARNING] Server {server_name} is not running") + return True + + try: + process = self.servers[server_name] + process.terminate() + + # Wait for process to terminate + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + del self.servers[server_name] + print(f"[MCP] Stopped server: {server_name}") + return True + + except Exception as e: + print(f"[ERROR] Failed to stop server {server_name}: {e}") + return False + + async def start_all_servers(self) -> Dict[str, bool]: + """Start all configured MCP servers. + + Returns: + Dict mapping server names to success status + """ + results = {} + for server_name in self.server_configs.keys(): + results[server_name] = await self.start_server(server_name) + return results + + async def stop_all_servers(self) -> Dict[str, bool]: + """Stop all running MCP servers. + + Returns: + Dict mapping server names to success status + """ + results = {} + for server_name in list(self.servers.keys()): + results[server_name] = await self.stop_server(server_name) + return results + + async def call_tool( + self, server_name: str, tool_name: str, **kwargs: Any + ) -> Optional[str]: + """Call a tool on a specific MCP server. + + Args: + server_name: Name of the server to call + tool_name: Name of the tool to call + **kwargs: Tool parameters + + Returns: + Tool response as string, or None if error + """ + if server_name not in self.servers: + print(f"[ERROR] Server {server_name} is not running") + return None + + try: + process = self.servers[server_name] + + # Create tool call request + request = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": {"name": tool_name, "arguments": kwargs}, + } + + # Send request to server + request_json = json.dumps(request) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return None + process.stdin.write(request_json) + process.stdin.flush() + + # Read response + if process.stdout is None: + print(f"[ERROR] No stdout available for server {server_name}") + return None + response_line = process.stdout.readline() + if not response_line: + print(f"[ERROR] No response from server {server_name}") + return None + + response = json.loads(response_line.strip()) + + # Check for errors + if "error" in response: + print(f"[ERROR] Tool call failed: {response['error']}") + return None + + # Extract result + result = response.get("result", {}) + if isinstance(result, dict) and "content" in result: + content = result["content"][0].get("text", "") + return content if isinstance(content, str) else str(content) + elif isinstance(result, str): + return result + else: + return json.dumps(result) + + except Exception as e: + print(f"[ERROR] Failed to call tool {tool_name} on {server_name}: {e}") + return None + + async def get_available_tools( + self, server_name: str + ) -> Optional[List[Dict[str, Any]]]: + """Get list of available tools from a server. + + Args: + server_name: Name of the server to query + + Returns: + List of tool definitions, or None if error + """ + if server_name not in self.servers: + print(f"[ERROR] Server {server_name} is not running") + return None + + try: + process = self.servers[server_name] + + # Create list tools request + request = {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}} + + # Send request to server + request_json = json.dumps(request) + "\n" + if process.stdin is None: + print(f"[ERROR] No stdin available for server {server_name}") + return None + process.stdin.write(request_json) + process.stdin.flush() + + # Read response + if process.stdout is None: + print(f"[ERROR] No stdout available for server {server_name}") + return None + response_line = process.stdout.readline() + if not response_line: + print(f"[ERROR] No response from server {server_name}") + return None + + response = json.loads(response_line.strip()) + + # Check for errors + if "error" in response: + print(f"[ERROR] Failed to list tools: {response['error']}") + return None + + # Extract tools + result = response.get("result", {}) + tools = result.get("tools", []) + return tools if isinstance(tools, list) else [] + + except Exception as e: + print(f"[ERROR] Failed to get tools from {server_name}: {e}") + return None + + def is_server_running(self, server_name: str) -> bool: + """Check if a server is currently running. + + Args: + server_name: Name of the server to check + + Returns: + True if server is running + """ + if server_name not in self.servers: + return False + + process = self.servers[server_name] + return process.poll() is None + + def get_running_servers(self) -> List[str]: + """Get list of currently running servers. + + Returns: + List of server names + """ + return [name for name in self.servers.keys() if self.is_server_running(name)] + + async def health_check(self) -> Dict[str, bool]: + """Perform health check on all configured servers. + + Returns: + Dict mapping server names to health status + """ + results = {} + for server_name in self.server_configs.keys(): + if self.is_server_running(server_name): + # Try to get tools as a health check + tools = await self.get_available_tools(server_name) + results[server_name] = tools is not None + else: + results[server_name] = False + return results + + async def __aenter__(self) -> "MCPClient": + """Async context manager entry.""" + await self.start_all_servers() + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + """Async context manager exit.""" + await self.stop_all_servers() + + +# Convenience functions for common operations +async def search_github_code( + query: str, client: MCPClient, result_limit: int = 10 +) -> Optional[str]: + """Search GitHub code using MCP client. + + Args: + query: Search query + client: MCP client instance + result_limit: Maximum results to return + + Returns: + Search results as JSON string + """ + if not client.is_server_running("code_search"): + await client.start_server("code_search") + + return await client.call_tool( + "code_search", "search_github_code", query=query, result_limit=result_limit + ) + + +async def search_github_repositories( + query: str, client: MCPClient, result_limit: int = 10 +) -> Optional[str]: + """Search GitHub repositories using MCP client. + + Args: + query: Search query or JSON research idea + client: MCP client instance + result_limit: Maximum results to return + + Returns: + Search results as JSON string + """ + if not client.is_server_running("code_search"): + await client.start_server("code_search") + + return await client.call_tool( + "code_search", + "search_github_repositories", + query=query, + result_limit=result_limit, + ) + + +async def search_papers( + query: str, client: MCPClient, result_limit: int = 3 +) -> Optional[str]: + """Search papers using MCP client. + + Args: + query: Search query + client: MCP client instance + result_limit: Maximum results to return + + Returns: + Search results as JSON string + """ + if not client.is_server_running("paper_search"): + await client.start_server("paper_search") + + return await client.call_tool( + "paper_search", "search_papers", query=query, result_limit=result_limit + ) + + +async def generate_diagram( + section_name: str, section_content: str, client: MCPClient +) -> Optional[str]: + """Generate diagram using MCP client. + + Args: + section_name: Name of the paper section + section_content: Content of the section + client: MCP client instance + + Returns: + Diagram data as JSON string + """ + if not client.is_server_running("drawer"): + await client.start_server("drawer") + + return await client.call_tool( + "drawer", + "generate_diagram", + section_name=section_name, + section_content=section_content, + ) diff --git a/tiny_scientist/writer.py b/tiny_scientist/writer.py index 4312cc8..fb3fd4e 100644 --- a/tiny_scientist/writer.py +++ b/tiny_scientist/writer.py @@ -35,13 +35,22 @@ def __init__( prompt_template_dir: Optional[str] = None, cost_tracker: Optional[BudgetChecker] = None, s2_api_key: Optional[str] = None, + mcp_client: Any = None, ) -> None: self.client, self.model = create_client(model) self.output_dir = output_dir self.template = template self.temperature = temperature - self.searcher: BaseTool = PaperSearchTool(s2_api_key=s2_api_key) - self.drawer: BaseTool = DrawerTool(model, prompt_template_dir, temperature) + self.mcp_client = mcp_client + # Fallback to traditional tools if MCP is not available + self.searcher: Optional[BaseTool] = ( + PaperSearchTool(s2_api_key=s2_api_key) if not mcp_client else None + ) + self.drawer: Optional[BaseTool] = ( + DrawerTool(model, prompt_template_dir, temperature) + if not mcp_client + else None + ) self.formatter: BaseOutputFormatter self.config = Config(prompt_template_dir) if self.template == "acl": @@ -162,10 +171,55 @@ def _generate_diagram_for_section(self) -> None: for section in ["Method", "Experimental_Setup", "Results"]: content = self.generated_sections[section] try: - query = json.dumps( - {"section_name": section, "section_content": content} - ) - diagram_result = self.drawer.run(query) + if self.mcp_client: + # Use MCP client for diagram generation + import asyncio + + from .utils.mcp_client import generate_diagram + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_diagram() -> Optional[str]: + """Run the async diagram function in a new event loop.""" + return asyncio.run( + generate_diagram(section, content, self.mcp_client) + ) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_diagram) + results_json = future.result( + timeout=60.0 + ) # Longer timeout for diagram generation + + if results_json: + import json + + diagram_result = json.loads(results_json) + else: + diagram_result = {} + except Exception as e: + print( + f"[WARNING] MCP diagram generation failed, falling back to traditional drawer: {e}" + ) + if self.drawer: + query = json.dumps( + {"section_name": section, "section_content": content} + ) + diagram_result = self.drawer.run(query) + else: + diagram_result = {} + else: + # Use traditional drawer + if self.drawer: + query = json.dumps( + {"section_name": section, "section_content": content} + ) + diagram_result = self.drawer.run(query) + else: + diagram_result = {} if diagram_result and "diagram" in diagram_result: diagram = diagram_result["diagram"] @@ -266,12 +320,16 @@ def _write_section( elif section == "Analysis": # For non-experimental papers, use the research plan content research_plan = idea.get("ResearchPlan", experiment) + approach = idea.get("Approach", "No approach specified") section_prompt = self.prompts.section_prompt.get( section, self.prompts.section_prompt.get("Results", "") ).format( section_tips=self.prompts.section_tips.get( section, self.prompts.section_tips.get("Results", "") ), + problem=idea["Problem"], # Add the required problem field + approach=approach, # Add the required approach field + research_plan=research_plan, # Add the required research_plan field experiment=research_plan, baseline_results=baseline_result, experiment_results=experiment_result, @@ -339,20 +397,113 @@ def _search_reference(self, paper_list: List[str]) -> Dict[str, Any]: for paper_name in paper_list: try: - result = self.searcher.run(paper_name) + print(f"[Writer] Searching for paper: {paper_name}") + + if self.mcp_client: + # Use MCP client for paper search + import asyncio + + from .utils.mcp_client import search_papers + + try: + # Handle async function call properly to avoid event loop conflicts + import concurrent.futures + + def run_async_search() -> Optional[str]: + """Run the async search function in a new event loop.""" + return asyncio.run( + search_papers(paper_name, self.mcp_client) + ) + + # Always use ThreadPoolExecutor to avoid event loop conflicts + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_async_search) + results_json = future.result(timeout=30.0) # Add timeout + + print( + f"[Writer] MCP raw result: {results_json[:200] if results_json else 'None'}..." + ) + + if results_json: + import json + + result = json.loads(results_json) + print( + f"[Writer] MCP JSON parsing successful: {type(result)}" + ) + + # Check MCP return format and handle errors + if isinstance(result, dict): + if "error" in result: + print( + f"[Writer] MCP returned error: {result['error']}" + ) + result = {} # Convert error to empty result + else: + # Validate and convert data format + formatted_result = {} + for title, meta in result.items(): + if isinstance(meta, dict) and "bibtex" in meta: + # MCP format correct: {title: {title: "...", bibtex: "..."}} + formatted_result[title] = meta + print( + f"[Writer] Valid format paper: {title}" + ) + elif isinstance(meta, str): + print( + f"[Writer] Invalid format, skipping: {title} -> {meta}" + ) + else: + print( + f"[Writer] Unknown format, skipping: {title} -> {type(meta)}" + ) + result = formatted_result + else: + print("[Writer] MCP returned empty result") + result = {} + + except Exception as e: + print( + f"[Writer] MCP search failed, falling back to traditional search: {e}" + ) + if self.searcher: + result = self.searcher.run(paper_name) + print( + f"[Writer] Traditional search result: {type(result)}, length: {len(result) if result else 0}" + ) + else: + result = {} + else: + # Use traditional searcher + print("[Writer] Using traditional searcher") + if self.searcher: + result = self.searcher.run(paper_name) + print( + f"[Writer] Traditional search result: {type(result)}, length: {len(result) if result else 0}" + ) + else: + result = {} + # Process search results if result: + print(f"[Writer] Found search results, count: {len(result)}") if paper_name in result: results_dict[paper_name] = result[paper_name] + print(f"[Writer] Exact match: {paper_name}") else: + # Use first result first_key = next(iter(result)) results_dict[first_key] = result[first_key] + print(f"[Writer] Using first result: {first_key}") + else: + print(f"[Writer] No papers found for: {paper_name}") time.sleep(1.0) except Exception as e: - print(f"[ERROR] While processing '{paper_name}': {e}") + print(f"[Writer] Error while processing '{paper_name}': {e}") traceback.print_exc() + print(f"[Writer] Search completed, found {len(results_dict)} papers total") return results_dict def _write_related_work(self, idea: Dict[str, Any]) -> None: @@ -384,7 +535,21 @@ def _write_related_work(self, idea: Dict[str, Any]) -> None: ) for title, meta in paper_source.items(): - match = re.search(r"@\w+\{(.+?),", meta.get("bibtex", "")) + # Ensure meta is a dictionary before accessing 'bibtex' + if isinstance(meta, dict): + bibtex = meta.get("bibtex", "") + elif isinstance(meta, str): + print( + f"[Writer] Warning: meta is string for {title}, skipping citation replacement" + ) + continue + else: + print( + f"[Writer] Warning: unexpected meta type {type(meta)} for {title}, skipping citation replacement" + ) + continue + + match = re.search(r"@\w+\{(.+?),", bibtex) if match: try: bibtex_key = match.group(1) @@ -396,7 +561,7 @@ def _write_related_work(self, idea: Dict[str, Any]) -> None: relatedwork_content, ) except Exception: - print(f"[ERROR] Failed to replace citation for title: {title}") + print(f"[Writer] Failed to replace citation for title: {title}") traceback.print_exc() self.generated_sections["Related_Work"] = relatedwork_content @@ -541,7 +706,21 @@ def _add_citations(self, idea: Dict[str, Any]) -> None: ) for title, meta in paper_source.items(): - match = re.search(r"@\w+\{(.+?),", meta.get("bibtex", "")) + # Ensure meta is a dictionary before accessing 'bibtex' + if isinstance(meta, dict): + bibtex = meta.get("bibtex", "") + elif isinstance(meta, str): + print( + f"[Writer] Warning: meta is string for {title}, skipping citation replacement" + ) + continue + else: + print( + f"[Writer] Warning: unexpected meta type {type(meta)} for {title}, skipping citation replacement" + ) + continue + + match = re.search(r"@\w+\{(.+?),", bibtex) if match: bibtex_key = match.group(1) escaped_title = re.escape(title) @@ -554,5 +733,5 @@ def _add_citations(self, idea: Dict[str, Any]) -> None: self.generated_sections[section] = refined_section except Exception: - print(f"[ERROR] Failed to add citations to section: {section}") + print(f"[Writer] Failed to add citations to section: {section}") traceback.print_exc()