| 
 | 1 | +from dataclasses import dataclass  | 
 | 2 | +from enum import Enum  | 
 | 3 | +from typing import List, Optional  | 
 | 4 | + | 
 | 5 | +from pydantic import BaseModel  | 
 | 6 | +from typing_extensions import TypedDict  | 
 | 7 | + | 
 | 8 | + | 
 | 9 | +class Document(TypedDict):  | 
 | 10 | +    title: str  | 
 | 11 | +    text: str  | 
 | 12 | + | 
 | 13 | + | 
 | 14 | +class Role(Enum):  | 
 | 15 | +    system = "system"  | 
 | 16 | +    user = "user"  | 
 | 17 | +    assistant = "assistant"  | 
 | 18 | + | 
 | 19 | + | 
 | 20 | +@dataclass  | 
 | 21 | +class Message:  | 
 | 22 | +    role: Role  | 
 | 23 | +    content: str  | 
 | 24 | + | 
 | 25 | + | 
 | 26 | +class Chat:  | 
 | 27 | +    def __init__(  | 
 | 28 | +        self,  | 
 | 29 | +        system_msg: Optional[str] = None,  | 
 | 30 | +        tools: Optional[List[BaseModel]] = None,  | 
 | 31 | +        documents: Optional[List[Document]] = None,  | 
 | 32 | +        history: List[Message] = [],  | 
 | 33 | +    ):  | 
 | 34 | +        self.history = history  | 
 | 35 | +        self.system = system_msg  | 
 | 36 | +        self.tools = tools  | 
 | 37 | +        self.documents = documents  | 
 | 38 | + | 
 | 39 | +    @property  | 
 | 40 | +    def trimmed_history(self):  | 
 | 41 | +        return self.history  | 
 | 42 | + | 
 | 43 | +    def __add__(self, other: Message):  | 
 | 44 | +        history = self.history  | 
 | 45 | +        history.append(other)  | 
 | 46 | +        return Chat(self.system, self.tools, self.documents, history=history)  | 
 | 47 | + | 
 | 48 | +    def __iadd__(self, other: Message):  | 
 | 49 | +        self.history.append(other)  | 
 | 50 | + | 
 | 51 | +    def __getitem__(self, key):  | 
 | 52 | +        if isinstance(key, int):  | 
 | 53 | +            return self.history[key]  | 
 | 54 | +        else:  | 
 | 55 | +            raise KeyError()  | 
 | 56 | + | 
 | 57 | +    def render(self, model_name: str):  | 
 | 58 | +        """Render the conversation using the model's chat template.  | 
 | 59 | +
  | 
 | 60 | +        TODO: Do this ourselves.  | 
 | 61 | +
  | 
 | 62 | +        Parameters  | 
 | 63 | +        ----------  | 
 | 64 | +        model_name  | 
 | 65 | +            The name of the model whose chat template we need to use.  | 
 | 66 | +
  | 
 | 67 | +        """  | 
 | 68 | +        from transformers import AutoTokenizer  | 
 | 69 | + | 
 | 70 | +        conversation = []  | 
 | 71 | +        if self.system is not None:  | 
 | 72 | +            conversation.append({"role": "system", "content": self.system})  | 
 | 73 | +        for message in self.trimmed_history:  | 
 | 74 | +            conversation.append({"role": message.role, "content": message.content})  | 
 | 75 | + | 
 | 76 | +        self.tokenizer = AutoTokenizer.from_pretrained(model_name)  | 
 | 77 | + | 
 | 78 | +        return self.tokenizer.apply_chat_template(  | 
 | 79 | +            conversation, self.tools, self.documents  | 
 | 80 | +        )  | 
0 commit comments