4 changes: 3 additions & 1 deletion spacy_llm/models/rest/base.py
@@ -1,7 +1,7 @@
 import abc
 import time
 from enum import Enum
-from typing import Any, Callable, Dict, Iterable, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional

 import requests  # type: ignore
 from requests import ConnectTimeout, ReadTimeout
@@ -34,6 +34,7 @@ def __init__(
         interval: float,
         max_request_time: float,
         context_length: Optional[int],
+        conversational_history: Optional[List[Dict[str, str]]] = None,
     ):
         """Initializes new instance of REST-based model.
         name (str): Model name.
@@ -60,6 +61,7 @@ def __init__(
         self._max_request_time = max_request_time
         self._credentials = self.credentials
         self._context_length = context_length
+        self._conversational_history = conversational_history

         assert self._max_tries >= 1
         assert self._interval > 0
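For reference, a minimal sketch of the value this new parameter expects: a list of ChatML-style turns in the OpenAI chat format. The concrete roles and contents below are hypothetical:

    # Hypothetical conversational_history value (ChatML-style turns):
    history = [
        {"role": "system", "content": "You are an annotation assistant."},
        {"role": "user", "content": "Reply with labels only."},
        {"role": "assistant", "content": "Understood."},
    ]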
3 changes: 2 additions & 1 deletion spacy_llm/models/rest/openai/model.py
@@ -125,8 +125,9 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:

             else:
                 for prompt in prompts_for_doc:
+                    user_message = [{"role": "user", "content": prompt}]
                     responses = _request(
-                        {"messages": [{"role": "user", "content": prompt}]}
+                        {"messages": self._conversational_history + user_message if self._conversational_history else user_message}
                     )
                     if "error" in responses:
                         return responses["error"]
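In effect, the request body now prepends any configured history ahead of the current user turn. A short illustration of that composition, mirroring the names in the diff (the values are hypothetical):

    # Hypothetical payload composition matching the change above:
    conversational_history = [{"role": "system", "content": "Be terse."}]
    user_message = [{"role": "user", "content": "Label this text."}]
    payload = {
        "messages": conversational_history + user_message
        if conversational_history
        else user_message
    }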
37 changes: 36 additions & 1 deletion spacy_llm/models/rest/openai/registry.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 from confection import SimpleFrozenDict

@@ -21,6 +21,41 @@
     endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
     """

@registry.llm_models("spacy.GPT-4.v4")
def openai_gpt_4_v3(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
name: str = "gpt-4",
strict: bool = OpenAI.DEFAULT_STRICT,
max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
interval: float = OpenAI.DEFAULT_INTERVAL,
max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
endpoint: Optional[str] = None,
context_length: Optional[int] = None,
conversational_history: Optional[List[Dict[str,str]]] = None
) -> OpenAI:
"""Returns OpenAI instance for 'gpt-4' model using REST to prompt API.

config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (str): Model name to use. Can be any model name supported by the OpenAI API - e. g. 'gpt-4',
"gpt-4-1106-preview", ....
conversational_history ( Optional[List[Dict[str,str]]]): Optional conversational history to be provided in the ChatML approach,
with the User/Assistant streams to be appended before main prompt request
RETURNS (OpenAI): OpenAI instance for 'gpt-4' model.

DOCS: https://spacy.io/api/large-language-models#models
"""
return OpenAI(
name=name,
endpoint=endpoint or Endpoints.CHAT.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
conversational_history=conversational_history
)


@registry.llm_models("spacy.GPT-4.v3")
def openai_gpt_4_v3(
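Note that the new registration is renamed here to openai_gpt_4_v4: reusing the openai_gpt_4_v3 name would let the later v3 definition shadow it at module level. Assuming the renamed v4 factory above, a minimal usage sketch (hypothetical history values; requires a valid OPENAI_API_KEY in the environment):

    from spacy_llm.registry import registry

    # Resolve the v4 factory added in this PR and seed it with ChatML history:
    gpt4 = registry.llm_models.get("spacy.GPT-4.v4")(
        conversational_history=[
            {"role": "system", "content": "You are a terse annotator."},
            {"role": "user", "content": "Reply with labels only."},
            {"role": "assistant", "content": "Understood."},
        ]
    )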
Expand Down