
Commit 89d74f4

feat: adapt proxy
1 parent a70324a commit 89d74f4

File tree

12 files changed: +2541 −1111 lines changed

examples/01_intro_to_llmstudio copy.ipynb

Lines changed: 326 additions & 0 deletions
Large diffs are not rendered by default.

examples/01_intro_to_llmstudio_with_proxy.ipynb

Lines changed: 840 additions & 0 deletions
Large diffs are not rendered by default.

examples/03_langchain_integration.ipynb

Lines changed: 22 additions & 18 deletions
@@ -16,28 +16,32 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Running LLMstudio Engine on http://localhost:55189 Running LLMstudio Tracking on http://localhost:55190 \n",
-      "\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from llmstudio.llm.langchain import ChatLLMstudio\n",
     "from llmstudio import LLM"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "ename": "TypeError",
+     "evalue": "Can't instantiate abstract class LLMProxyProvider without an implementation for abstract method '_provider_config_name'",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "TypeError                                 Traceback (most recent call last)",
+      "Cell In[2], line 2\n      1 # llm = ChatLLMstudio(model_id='openai/gpt-3.5-turbo', temperature=0)\n----> 2 llm = ChatLLMstudio(model_id='vertexai/gemini-1.5-flash', temperature=0)\n",
+      "File ~/fun/LLMstudio/llmstudio/llm/langchain.py:33, in ChatLLMstudio.__init__(self, model_id, **kwargs)\n     31 def __init__(self, model_id: str, **kwargs):\n     32     super().__init__(model_id=model_id, **kwargs)\n---> 33     self.llm = LLM(model_id=self.model_id, **kwargs)\n",
+      "TypeError: Can't instantiate abstract class LLMProxyProvider without an implementation for abstract method '_provider_config_name'"
+     ]
+    }
+   ],
    "source": [
     "# llm = ChatLLMstudio(model_id='openai/gpt-3.5-turbo', temperature=0)\n",
     "llm = ChatLLMstudio(model_id='vertexai/gemini-1.5-flash', temperature=0)"
@@ -52,7 +56,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -62,16 +66,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "AIMessage(content='Hello! \\n\\nHow can I help you today? \\n', response_metadata={'token_usage': None, 'model_name': 'gemini-1.5-flash', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-5945f8bd-7151-4d85-bf90-3bcc1eaabc6c-0')"
+       "AIMessage(content='Hello! 👋 How can I help you today? 😊 \\n', response_metadata={'token_usage': None, 'model_name': 'gemini-1.5-flash', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-2e60f8e1-bba0-4cab-9e21-8ecdbe8d49a2-0')"
       ]
      },
-     "execution_count": 35,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -388,7 +392,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.12.2"
   }
  },
  "nbformat": 4,

examples/llm_proxy.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+from llmstudio.server import start_server
+start_server()
+
+from llmstudio.engine.provider import LLMProxyProvider
+
+
+llm = LLMProxyProvider(provider="openai", host="0.0.0.0", port="8001")
+
+result = llm.chat("What's your name", model="gpt-4o")
+print(result)
+
+import asyncio
+
+# stream
+print("\nasync stream")
+async def async_stream():
+
+    response_async = await llm.achat("What's your name", model="gpt-4o", is_stream=True)
+    async for p in response_async:
+        if "}" in p.chat_output:
+            p.chat_output
+        print("that: ",p.chat_output)
+        # pprint(p.choices[0].delta.content==p.chat_output)
+        # print("metrics: ", p.metrics)
+        # print(p)
+        if p.metrics:
+            print(p)
+asyncio.run(async_stream())
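The streaming portion of this example boils down to the loop below; a minimal sketch assuming, as the script does, that `achat(..., is_stream=True)` returns an async generator whose chunks expose `chat_output`, with `metrics` populated only on the final chunk:

```python
import asyncio

async def consume_stream(llm):
    # Request a streamed completion through the proxy provider.
    response = await llm.achat("What's your name", model="gpt-4o", is_stream=True)
    async for chunk in response:
        # Each chunk carries an increment of text in `chat_output`.
        print(chunk.chat_output, end="", flush=True)
        # The final chunk also carries aggregated metrics.
        if chunk.metrics:
            print("\nmetrics:", chunk.metrics)

# Usage, assuming `llm` is the LLMProxyProvider constructed above:
# asyncio.run(consume_stream(llm))
```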

libs/core/llmstudio_core/providers/provider.py

Lines changed: 8 additions & 4 deletions
@@ -73,7 +73,8 @@ async def achat(
         model: str,
         is_stream: Optional[bool] = False,
         retries: Optional[int] = 0,
-        parameters: Optional[dict] = {}
+        parameters: Optional[dict] = {},
+        **kwargs
     ) -> Coroutine[Any, Any, Union[ChatCompletionChunk, ChatCompletion]]:
         raise NotImplementedError("Providers needs to have achat method implemented.")
@@ -84,7 +85,8 @@ def chat(
         model: str,
         is_stream: Optional[bool] = False,
         retries: Optional[int] = 0,
-        parameters: Optional[dict] = {}
+        parameters: Optional[dict] = {},
+        **kwargs
     ) -> Union[ChatCompletionChunk, ChatCompletion]:
         raise NotImplementedError("Providers needs to have chat method implemented.")
@@ -103,7 +105,8 @@ async def achat(
         model: str,
         is_stream: Optional[bool] = False,
         retries: Optional[int] = 0,
-        parameters: Optional[dict] = {}
+        parameters: Optional[dict] = {},
+        **kwargs
     ):

         """Makes a chat connection with the provider's API"""
@@ -145,7 +148,8 @@ def chat(
         model: str,
         is_stream: Optional[bool] = False,
         retries: Optional[int] = 0,
-        parameters: Optional[dict] = {}
+        parameters: Optional[dict] = {},
+        **kwargs
     ):

         """Makes a chat connection with the provider's API"""

llmstudio/cli.py

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,6 @@
 import os
 import signal
+import threading

 import click

@@ -25,11 +26,14 @@ def server(ui):

     print("Servers are running. Press CTRL+C to stop.")

+    stop_event = threading.Event()
     try:
-        signal.pause()
+        stop_event.wait()  # Wait indefinitely until the event is set
     except KeyboardInterrupt:
         print("Shutting down servers...")


 if __name__ == "__main__":
-    main()
+    # main()
+    server()
+    print(4)
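`signal.pause()` exists only on Unix, and blocking on a `threading.Event` that is never set is the usual portable replacement, which is presumably the motivation for this change. The pattern in isolation:

```python
import threading

stop_event = threading.Event()
try:
    # Blocks indefinitely: nothing ever calls stop_event.set().
    # Unlike signal.pause(), this also exists on Windows.
    stop_event.wait()
except KeyboardInterrupt:
    print("Shutting down servers...")
```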

llmstudio/engine/__init__.py

Lines changed: 7 additions & 36 deletions
@@ -9,10 +9,11 @@
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel

 from llmstudio.config import ENGINE_HOST, ENGINE_PORT
-from llmstudio.engine.providers import *
+from llmstudio_core.providers import _load_engine_config
+from llmstudio_core.providers.provider import provider_registry

 ENGINE_BASE_ENDPOINT = "/api/engine"
 ENGINE_HEALTH_ENDPOINT = "/health"
@@ -47,38 +48,6 @@ class EngineConfig(BaseModel):
     providers: Dict[str, ProviderConfig]


-def _load_engine_config() -> EngineConfig:
-    default_config_path = Path(os.path.join(os.path.dirname(__file__), "config.yaml"))
-    local_config_path = Path(os.getcwd(), "config.yaml")
-
-    def _merge_configs(config1, config2):
-        for key in config2:
-            if key in config1:
-                if isinstance(config1[key], dict) and isinstance(config2[key], dict):
-                    _merge_configs(config1[key], config2[key])
-                elif isinstance(config1[key], list) and isinstance(config2[key], list):
-                    config1[key].extend(config2[key])
-                else:
-                    config1[key] = config2[key]
-            else:
-                config1[key] = config2[key]
-        return config1
-
-    try:
-        default_config_data = yaml.safe_load(default_config_path.read_text())
-        local_config_data = (
-            yaml.safe_load(local_config_path.read_text())
-            if local_config_path.exists()
-            else {}
-        )
-        config_data = _merge_configs(default_config_data, local_config_data)
-        return EngineConfig(**config_data)
-    except yaml.YAMLError as e:
-        raise RuntimeError(f"Error parsing YAML configuration: {e}")
-    except ValidationError as e:
-        raise RuntimeError(f"Error in configuration data: {e}")
-
-
 def create_engine_app(
     started_event: Event, config: EngineConfig = _load_engine_config()
 ) -> FastAPI:
@@ -124,9 +93,11 @@ def get_models(provider: Optional[str] = None):
 def create_chat_handler(provider_config):
     async def chat_handler(request: Request):
         """Endpoint for chat functionality."""
-        provider_class = provider_registry.get(f"{provider_config.name}Provider")
+        provider_class = provider_registry.get(f"{provider_config.name}".lower())
         provider_instance = provider_class(provider_config)
-        return await provider_instance.chat(await request.json())
+        request_dict = await request.json()
+        result = await provider_instance.achat(**request_dict)
+        return result

     return chat_handler
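The handler now looks providers up by lowercase name instead of the old `"<Name>Provider"` key and awaits the async `achat` with the unpacked request body. A minimal sketch of that registry-and-dispatch pattern, with hypothetical contents (the real `provider_registry` comes from llmstudio_core.providers.provider):

```python
import asyncio

provider_registry = {}

def register_provider(name):
    """Register a provider class under its lowercase name."""
    def decorator(cls):
        provider_registry[name.lower()] = cls
        return cls
    return decorator

@register_provider("OpenAI")
class OpenAIProvider:
    def __init__(self, config):
        self.config = config

    async def achat(self, chat_input, model, **kwargs):
        return {"provider": "openai", "model": model, "echo": chat_input}

async def chat_handler(request_dict, provider_name="OpenAI"):
    # Lowercased lookup mirrors provider_registry.get(f"{name}".lower()).
    provider_class = provider_registry.get(provider_name.lower())
    provider_instance = provider_class(config=None)
    # The raw JSON body is unpacked into achat, which is why the provider
    # signatures in provider.py gained **kwargs.
    return await provider_instance.achat(**request_dict)

print(asyncio.run(chat_handler({"chat_input": "hi", "model": "gpt-4o"})))
```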
