
Commit b77aff2

feat: adapt proxy
1 parent 89d74f4 commit b77aff2

File tree: 8 files changed (+270, -696 lines)


examples/01_intro_to_llmstudio.ipynb

Lines changed: 28 additions & 87 deletions
Large diffs are not rendered by default.

examples/01_intro_to_llmstudio_with_proxy.ipynb

Lines changed: 162 additions & 565 deletions
Large diffs are not rendered by default.

examples/llm_proxy.py

Lines changed: 23 additions & 8 deletions
@@ -6,23 +6,38 @@
 
 llm = LLMProxyProvider(provider="openai", host="0.0.0.0", port="8001")
 
-result = llm.chat("What's your name", model="gpt-4o")
+result = llm.chat("Write a paragraph about space", model="gpt-4o")
 print(result)
 
+
+response = llm.chat("Write a paragraph about space", model="gpt-4o", is_stream=True)
+for i, chunk in enumerate(response):
+    if i % 20 == 0:
+        print("\n")
+    if not chunk.metrics:
+        print(chunk.chat_output, end="", flush=True)
+    else:
+        print("\n\n## Metrics:")
+        print(chunk.metrics)
+
+
 import asyncio
 
 # stream
 print("\nasync stream")
 async def async_stream():
 
-    response_async = await llm.achat("What's your name", model="gpt-4o", is_stream=True)
-    async for p in response_async:
-        if "}" in p.chat_output:
-            p.chat_output
-        print("that: ", p.chat_output)
+    response_async = await llm.achat("Write a paragraph about space", model="gpt-4o", is_stream=False)
+    print(response_async)
+
+    response_async_stream = await llm.achat("Write a paragraph about space", model="gpt-4o", is_stream=True)
+    async for p in response_async_stream:
+
         # pprint(p.choices[0].delta.content==p.chat_output)
         # print("metrics: ", p.metrics)
         # print(p)
-        if p.metrics:
-            print(p)
+        if not p.metrics:
+            print(p.chat_output, end="", flush=True)
+        else:
+            print(p.metrics)
 asyncio.run(async_stream())
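Note that this commit also reworks LLMProxyProvider to take a ProxyConfig object (see llmstudio/engine/provider.py below), so the untouched host=/port= call at the top of this example would presumably need the same treatment. A minimal sketch of what that would look like, with illustrative host/port values and assuming a proxy is already running:

```python
# Sketch only, not part of the commit: constructing the proxy provider via the
# new ProxyConfig introduced in llmstudio/engine/provider.py.
from llmstudio.engine.provider import LLMProxyProvider, ProxyConfig

llm = LLMProxyProvider(
    provider="openai",
    proxy_config=ProxyConfig(host="0.0.0.0", port="8001"),  # illustrative address
)
result = llm.chat("Write a paragraph about space", model="gpt-4o")
print(result)
```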

libs/core/llmstudio_core/__init__.py

Lines changed: 4 additions & 5 deletions
@@ -7,7 +7,7 @@
 _engine_config = _load_engine_config()
 
 
-def LLM(provider: str, api_key: Optional[str] = None, **kwargs) -> BaseProvider:
+def LLMCore(provider: str, api_key: Optional[str] = None, **kwargs) -> BaseProvider:
     """
     Factory method to create an instance of a provider.
 
@@ -35,7 +35,7 @@ def LLM(provider: str, api_key: Optional[str] = None, **kwargs) -> BaseProvider:
 load_dotenv()
 
 def test_stuff(provider, model, api_key, **kwargs):
-    llm = LLM(provider=provider, api_key=api_key, **kwargs)
+    llm = LLMCore(provider=provider, api_key=api_key, **kwargs)
 
     latencies = {}
     chat_request = {
@@ -76,9 +76,8 @@ async def async_stream():
 
        response_async = await llm.achat(**chat_request)
        async for p in response_async:
-            if "}" in p.chat_output:
-                p.chat_output
-            print("that: ", p.chat_output)
+            if not p.metrics:
+                print("that: ", p.chat_output_stream)
            # pprint(p.choices[0].delta.content==p.chat_output)
            # print("metrics: ", p.metrics)
            # print(p)
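The rename frees the LLM name for the proxy-aware wrapper in llmstudio/llm/__init__.py (further down in this diff), while direct callers switch to LLMCore. A minimal usage sketch, assuming llmstudio_core is installed and OPENAI_API_KEY is set in the environment:

```python
# Sketch: direct use of the renamed core factory (was LLM, now LLMCore).
from llmstudio_core import LLMCore

llm = LLMCore(provider="openai")  # api_key can also be passed explicitly
result = llm.chat("Say hello", model="gpt-4o")
print(result)  # per this diff, the full text lands in chat_output for non-stream calls
```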

libs/core/llmstudio_core/providers/provider.py

Lines changed: 4 additions & 2 deletions
@@ -229,7 +229,8 @@ async def ahandle_response(
                 if isinstance(request.chat_input, str)
                 else request.chat_input[-1]["content"]
             ),
-            "chat_output": chat_output if chat_output else "",
+            "chat_output": None,
+            "chat_output_stream": chat_output if chat_output else "",
             "context": (
                 [{"role": "user", "content": request.chat_input}]
                 if isinstance(request.chat_input, str)
@@ -276,7 +277,8 @@ async def ahandle_response(
                 if isinstance(request.chat_input, str)
                 else request.chat_input[-1]["content"]
             ),
-            "chat_output": "" if request.is_stream else output_string,
+            "chat_output": output_string,
+            "chat_output_stream": "",
             "context": (
                 [{"role": "user", "content": request.chat_input}]
                 if isinstance(request.chat_input, str)
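The two ahandle_response hunks split streamed deltas out of chat_output: chunks emitted while streaming now carry the incremental text in a new chat_output_stream field (with chat_output set to None), and the closing payload carries the full output_string in chat_output with an empty chat_output_stream. A minimal consumer sketch along the lines of the updated test_stuff helper above; the field names come from this diff, the provider, model, and API-key-from-environment setup are assumptions:

```python
# Sketch of consuming an async stream under the new field split.
import asyncio
from llmstudio_core import LLMCore

llm = LLMCore(provider="openai")

async def main():
    response = await llm.achat("Write a paragraph about space",
                               model="gpt-4o", is_stream=True)
    async for p in response:
        if not p.metrics:
            print(p.chat_output_stream, end="", flush=True)   # incremental delta
        else:
            print("\n\nfull output:", p.chat_output)          # final payload
            print("metrics:", p.metrics)

asyncio.run(main())
```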

llmstudio/engine/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -96,7 +96,13 @@ async def chat_handler(request: Request):
         provider_class = provider_registry.get(f"{provider_config.name}".lower())
         provider_instance = provider_class(provider_config)
         request_dict = await request.json()
+
         result = await provider_instance.achat(**request_dict)
+        if request_dict.get("is_stream", False):
+            async def result_generator():
+                async for chunk in result:
+                    yield json.dumps(chunk.dict())
+            return StreamingResponse(result_generator(), media_type="application/json")
         return result
 
     return chat_handler
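With this change the engine's chat endpoint wraps streamed results in a FastAPI StreamingResponse, emitting each chunk as a JSON-serialized dict; the proxy client turns each one back into a ChatCompletionChunk. A rough client-side sketch of the same decoding loop that generate_chat uses in llmstudio/engine/provider.py; the route and payload shape are assumptions for illustration, not taken from the commit:

```python
# Sketch only: consuming the streaming chat endpoint over HTTP.
import json
import requests
from openai.types.chat import ChatCompletionChunk

resp = requests.post(
    "http://0.0.0.0:8001/api/engine/chat/openai",   # assumed proxy route
    json={"chat_input": "Say hello", "model": "gpt-4o", "is_stream": True},
    stream=True,
)
for raw in resp.iter_content(chunk_size=None):
    if raw:
        chunk = ChatCompletionChunk(**json.loads(raw.decode("utf-8")))
        print(chunk)
```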

llmstudio/engine/provider.py

Lines changed: 26 additions & 13 deletions
@@ -1,27 +1,40 @@
 import asyncio
-from typing import Any, Coroutine, Dict, List, Union
+import json
+from typing import Any, Coroutine, Dict, List, Optional, Union
 
+from pydantic import BaseModel
 import requests
-from libs.core.llmstudio_core.providers.provider import BaseProvider, ProviderABC
+from llmstudio_core.providers.provider import ProviderABC
 from llmstudio.server import is_server_running
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from tqdm.asyncio import tqdm_asyncio
 
-from llmstudio.config import ENGINE_HOST, ENGINE_PORT
 from llmstudio.llm.semaphore import DynamicSemaphore
 
+
+class ProxyConfig(BaseModel):
+    host: Optional[str] = None
+    port: Optional[str] = None
+    url: Optional[str] = None
+    username: Optional[str] = None
+    password: Optional[str] = None
+    def __init__(self, **data):
+        super().__init__(**data)
+        if (self.host is None and self.port is None) and self.url is None:
+            raise ValueError("Either both 'host' and 'port' must be provided, or 'url' must be specified.")
+
+
 class LLMProxyProvider(ProviderABC):
     def __init__(self, provider: str,
-                 host: str,
-                 port: str,
-                 **kwargs):
+                 proxy_config: ProxyConfig):
         self.provider = provider
-        self.engine_host = host
-        self.engine_port = port
-        if is_server_running(host=host, port=port):
-            print(f"Connected to LLMStudio Proxy @ {host}:{port}")
+
+        self.engine_host = proxy_config.host
+        self.engine_port = proxy_config.port
+        if is_server_running(host=self.engine_host, port=self.engine_port):
+            print(f"Connected to LLMStudio Proxy @ {self.engine_host}:{self.engine_port}")
         else:
-            raise Exception(f"LLMStudio Proxy is not running @ {host}:{port}")
+            raise Exception(f"LLMStudio Proxy is not running @ {self.engine_host}:{self.engine_port}")
 
     @staticmethod
     def _provider_config_name():
@@ -59,7 +72,7 @@ def chat(self, chat_input: str,
     def generate_chat(self, response):
         for chunk in response.iter_content(chunk_size=None):
             if chunk:
-                yield chunk.decode("utf-8")
+                yield ChatCompletionChunk(**json.loads(chunk.decode("utf-8")))
 
     async def achat(self, chat_input: Any,
                     model: str,
@@ -233,4 +246,4 @@ async def async_stream(self, model:str, chat_input: str, retries: int, parameter
 
         for chunk in response.iter_content(chunk_size=None):
             if chunk:
-                yield chunk.decode("utf-8")
+                yield ChatCompletionChunk(**json.loads(chunk.decode("utf-8")))
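LLMProxyProvider now receives its connection details through ProxyConfig, which requires either host and port or a url. A short sketch of the new contract, with illustrative values and assuming a proxy is listening on the given address:

```python
# Sketch: the new ProxyConfig-based construction. Values are illustrative.
from llmstudio.engine.provider import LLMProxyProvider, ProxyConfig

config = ProxyConfig(host="0.0.0.0", port="8001")  # or ProxyConfig(url="http://my-proxy:8001")
llm = LLMProxyProvider(provider="openai", proxy_config=config)  # raises if no proxy is running there

# Leaving out both host/port and url trips the validation added in this commit.
try:
    ProxyConfig(username="svc-user")
except ValueError as err:
    print(err)
```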

llmstudio/llm/__init__.py

Lines changed: 17 additions & 16 deletions
@@ -1,43 +1,44 @@
 from typing import Any, Coroutine, Optional
-from llmstudio_core import LLM as LLM_factory
+from llmstudio_core import LLMCore
 from llmstudio_core.providers.provider import ProviderABC
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from pydantic import BaseModel
 
-from llmstudio.engine.provider import LLMProxyProvider
+from llmstudio.engine.provider import LLMProxyProvider, ProxyConfig
 from llmstudio.tracking.database import create_tracking_engine
 from llmstudio.tracking.logs import crud, schemas
 
-from sqlalchemy.orm import declarative_base, sessionmaker
-
-
-class ProxyConfig(BaseModel):
-    host: str
-    port: int
-    username: Optional[str] = None
-    password: Optional[str] = None
+from sqlalchemy.orm import sessionmaker
 
 class TrackingConfig(BaseModel):
-    database_uri: str
+    database_uri: Optional[str] = None
+    host: Optional[str] = None
+    port: Optional[int] = None
+    url: Optional[str] = None
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if not ((self.host and self.port) or self.url or self.database_uri):
+            raise ValueError("You must provide either both 'host' and 'port', or 'url', or 'database_uri'.")
 
 
 class LLM(ProviderABC):
 
 
     def __init__(self,
                  provider: str,
+                 api_key: Optional[str] = None,
                  proxy_config: Optional[ProxyConfig] = None,
                  tracking_config: Optional[TrackingConfig] = None,
                  **kwargs):
 
         if proxy_config is not None:
             self._provider = LLMProxyProvider(provider=provider,
-                                              host=proxy_config.host,
-                                              port=proxy_config.port,
-                                              **kwargs
-                                              )
+                                              proxy_config=proxy_config)
         else:
-            self._provider = LLM_factory(provider, **kwargs)
+            self._provider = LLMCore(provider=provider,
+                                     api_key=api_key,
+                                     **kwargs)
 
         self._session_local = None
         if tracking_config is not None:
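Taken together, the reworked LLM wrapper routes through LLMProxyProvider when a proxy_config is supplied and falls back to LLMCore (with an explicit api_key) otherwise, while TrackingConfig now accepts a database_uri, a host/port pair, or a url. A hedged usage sketch with illustrative values, assuming a local proxy and a SQLite tracking database:

```python
# Sketch of the reworked top-level wrapper; addresses and keys are placeholders.
from llmstudio.llm import LLM, TrackingConfig
from llmstudio.engine.provider import ProxyConfig

# Direct path: no proxy, requests go through LLMCore with an explicit api_key.
llm_direct = LLM(provider="openai", api_key="sk-...")

# Proxied path with request tracking.
llm_proxied = LLM(
    provider="openai",
    proxy_config=ProxyConfig(host="0.0.0.0", port="8001"),
    tracking_config=TrackingConfig(database_uri="sqlite:///llmstudio.db"),
)
print(llm_proxied.chat("Write a paragraph about space", model="gpt-4o"))
```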
