add token estimation to token counting for truncation #3286

Draft
wants to merge 2 commits into
base: master
104 changes: 31 additions & 73 deletions server/utils/helpers/chat/index.js
@@ -3,49 +3,13 @@ const { safeJsonParse } = require("../../http");
const { TokenManager } = require("../tiktoken");
const { convertToPromptHistory } = require("./responses");

/*
What is the message Array compressor?
TLDR: So anyway, i started blasting (your prompts & stuff)

messageArrayCompressor arose out of a need for users to be able to insert unlimited token prompts
and also maintain coherent history, system instructions and context, if applicable.

We took an opinionated approach that, after much back-testing, we found retains a highly coherent answer
under most conditions a user would encounter while using this specific system. While other systems may
use a more advanced model for compressing message history or simplify text through a recursive approach - ours is much simpler.

We "cannonball" the input.
Cannonball (verb): To ensure a prompt fits through a model window we blast a hole in the center of any inputs blocking our path to doing so.
This starts by dissecting the input into tokens and deleting from the middle out bi-directionally until the prompt window is satisfied.
You may think: "Doesn't this result in massive data loss?" - yes & no.
Under the use cases we expect the tool to be used for, which is mostly chatting with documents, we are able to use this approach with minimal blowback
on the quality of responses.

We accomplish this by taking a rate-limit approach that is proportional to the model capacity. Since we support more than OpenAI models, this needs to
be generic, and reliance on a "better summary" model is simply not a luxury we can afford. The added latency overhead during prompting is also unacceptable.
In general:
system: at best 15% of token capacity
history: at best 15% of token capacity
prompt: at best 70% of token capacity.

we handle overflows by taking an aggressive path for two main cases.

1. Very large user prompt
- Likely uninterested in context, history, or even system prompt. This is a "standalone" prompt that hijacks the whole thread.
- We run this prompt on its own since a prompt that is over 70% of context window certainly is standalone.

2. Context window is exceeded in regular use.
- We do not touch prompt since it is very likely to be <70% of window.
- We check that the system prompt is not outrageous - if it is, we cannonball it and keep the context if present.
- We check a sliding window of history, only allowing up to 15% of the history to pass through if it fits, with a
preference for recent history if we can cannonball to fit it, otherwise it is omitted.

We end up with a rather large prompt that fits through a given window with a lot of room for response in most use-cases.
We also take the approach that history is the least important and most flexible of the items in this array of responses.

There is a supplemental version of this function that also returns a formatted string for models like Claude-2
*/

/**
* Compresses the message array to fit within the prompt window limit via end-truncation.
* @param {Object} llm - The LLM object.
* @param {Object[]} messages - The messages to compress.
* @param {Object[]} rawHistory - The raw history of messages.
* @returns {Promise<Object[]>} The compressed messages.
*/
async function messageArrayCompressor(llm, messages = [], rawHistory = []) {
// assume the response will be at least 600 tokens. If the total prompt + reply is over we need to proactively
// run the compressor to ensure the prompt has enough space to reply.
@@ -68,7 +32,7 @@ async function messageArrayCompressor(llm, messages = [], rawHistory = []) {
return [
{
role: "user",
content: cannonball({
content: truncateContent({
input: user.content,
targetTokenSize: llm.promptWindowLimit() * 0.8,
tiktokenInstance: tokenManager,
@@ -94,7 +58,7 @@ async function messageArrayCompressor(llm, messages = [], rawHistory = []) {
// 25% of the system limit, we will cannonball it - this favors the context
// over the instruction from the user.
if (tokenManager.countFromString(prompt) >= llm.limits.system * 0.25) {
compressedPrompt = cannonball({
compressedPrompt = truncateContent({
input: prompt,
targetTokenSize: llm.limits.system * 0.25,
tiktokenInstance: tokenManager,
@@ -104,7 +68,7 @@ }
}

if (tokenManager.countFromString(context) >= llm.limits.system * 0.75) {
compressedContext = cannonball({
compressedContext = truncateContent({
input: context,
targetTokenSize: llm.limits.system * 0.75,
tiktokenInstance: tokenManager,
@@ -113,9 +77,8 @@
compressedContext = context;
}

system.content = `${compressedPrompt}${
compressedContext ? `\nContext: ${compressedContext}` : ""
}`;
system.content = `${compressedPrompt}${compressedContext ? `\nContext: ${compressedContext}` : ""
}`;
resolve(system);
});

@@ -158,15 +121,15 @@ async function messageArrayCompressor(llm, messages = [], rawHistory = []) {
// The math isn't perfect for tokens, so we have to add a fudge factor for safety.
const maxTargetSize = Math.floor(llm.limits.history / 2.2);
if (userTokens > maxTargetSize) {
user.content = cannonball({
user.content = truncateContent({
input: user.content,
targetTokenSize: maxTargetSize,
tiktokenInstance: tokenManager,
});
}

if (assistantTokens > maxTargetSize) {
assistant.content = cannonball({
assistant.content = truncateContent({
input: assistant.content,
targetTokenSize: maxTargetSize,
tiktokenInstance: tokenManager,
@@ -210,7 +173,7 @@ async function messageStringCompressor(llm, promptArgs = {}, rawHistory = []) {
// the token supply to reply with.
if (userPromptSize > llm.limits.user) {
return llm.constructPrompt({
userPrompt: cannonball({
userPrompt: truncateContent({
input: user,
targetTokenSize: llm.promptWindowLimit() * 0.8,
tiktokenInstance: tokenManager,
@@ -225,7 +188,7 @@
return;
}
resolve(
cannonball({
truncateContent({
input: system,
targetTokenSize: llm.limits.system,
tiktokenInstance: tokenManager,
@@ -272,15 +235,15 @@ async function messageStringCompressor(llm, promptArgs = {}, rawHistory = []) {
// The math isn't perfect for tokens, so we have to add a fudge factor for safety.
const maxTargetSize = Math.floor(llm.limits.history / 2.2);
if (userTokens > maxTargetSize) {
user.content = cannonball({
user.content = truncateContent({
input: user.content,
targetTokenSize: maxTargetSize,
tiktokenInstance: tokenManager,
});
}

if (assistantTokens > maxTargetSize) {
assistant.content = cannonball({
assistant.content = truncateContent({
input: assistant.content,
targetTokenSize: maxTargetSize,
tiktokenInstance: tokenManager,
@@ -309,9 +272,7 @@ async function messageStringCompressor(llm, promptArgs = {}, rawHistory = []) {
});
}

// Cannonball prompting: aka where we shoot a proportionally big cannonball through a proportional large prompt
// Nobody should be sending prompts this big, but there is no reason we shouldn't allow it if results are good even by doing it.
function cannonball({
function truncateContent({
input = "",
targetTokenSize = 0,
tiktokenInstance = null,
@@ -323,25 +284,22 @@ function cannonball({
const initialInputSize = tokenManager.countFromString(input);
if (initialInputSize < targetTokenSize) return input;

console.log("input", input.length);
if (input.length > TokenManager.MAX_STRING_LENGTH) {
const charsToTruncate = input.length - (targetTokenSize * TokenManager.TOKEN_CHAR_ESTIMATE); // approx number of chars to truncate
const truncatedInput = input.slice(0, (charsToTruncate * -1)) + truncText;
console.log(`[Content Truncated (estimated)] ${initialInputSize} input tokens, target: ${targetTokenSize} => ${tokenManager.countFromString(truncatedInput)} tokens.`);
return truncatedInput;
}

// if the delta is the token difference between where our prompt is in size
// and where we ideally need to land.
console.log("Truncating input via encoder method");
const delta = initialInputSize - targetTokenSize;
const tokenChunks = tokenManager.tokensFromString(input);
const middleIdx = Math.floor(tokenChunks.length / 2);

// middle truncate the text going left and right of midpoint
const leftChunks = tokenChunks.slice(0, middleIdx - Math.round(delta / 2));
const rightChunks = tokenChunks.slice(middleIdx + Math.round(delta / 2));
const truncatedText =
tokenManager.bytesFromTokens(leftChunks) +
truncText +
tokenManager.bytesFromTokens(rightChunks);

console.log(
`Cannonball results ${initialInputSize} -> ${tokenManager.countFromString(
truncatedText
)} tokens.`
);
const allowedTokens = tokenChunks.slice(0, delta * -1);
const truncatedText = tokenManager.bytesFromTokens(allowedTokens) + truncText;
console.log(`[Content Truncated (encoder)] ${initialInputSize} tokens, target: ${targetTokenSize} => ${allowedTokens.length} tokens.`);
return truncatedText;
}

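For readers skimming the diff, here is a minimal standalone sketch of the new character-based fallback in truncateContent, i.e. what happens when an input is too long to encode cheaply. The constants and the truncation marker below are assumptions that mirror the statics added to TokenManager in the tiktoken.js change further down; this is illustrative only, not the module's exported API.

// Illustrative sketch only - not the exported API of server/utils/helpers/chat/index.js.
// Assumed constants mirroring the new TokenManager statics.
const MAX_STRING_LENGTH = 400_000; // chars; above this, skip the encoder entirely
const TOKEN_CHAR_ESTIMATE = 3;     // assumed ~3 characters per token

function estimateTruncate(input, targetTokenSize, truncText = "[...truncated...]") {
  // Shorter inputs would go through the exact js-tiktoken encoder path instead.
  if (input.length <= MAX_STRING_LENGTH) return input;
  // Keep roughly targetTokenSize tokens' worth of characters from the start,
  // equivalent to input.slice(0, charsToTruncate * -1) in the diff above.
  const charsToKeep = targetTokenSize * TOKEN_CHAR_ESTIMATE;
  return input.slice(0, charsToKeep) + truncText;
}

// Example: a 600,000-char paste with a 4,096-token budget keeps ~12,288 chars plus the marker.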
9 changes: 8 additions & 1 deletion server/utils/helpers/tiktoken.js
@@ -12,6 +12,8 @@ const { getEncodingNameForModel, getEncoding } = require("js-tiktoken");
class TokenManager {
static instance = null;
static currentModel = null;
static MAX_STRING_LENGTH = 400_000; // 400k chars as a sanity limit for low-end devices
static TOKEN_CHAR_ESTIMATE = 3;

constructor(model = "gpt-3.5-turbo") {
if (TokenManager.instance && TokenManager.currentModel === model) {
@@ -57,7 +59,7 @@ class TokenManager {
}

/**
* Converts an array of tokens back to a string.
* Converts an array of token ids back to a string via the encoder module.
* @param {number[]} tokens
* @returns {string}
*/
@@ -72,6 +74,11 @@
* @returns {number}
*/
countFromString(input = "") {
if (input.length > TokenManager.MAX_STRING_LENGTH) {
this.log("estimating token count for performance...");
return Math.ceil(input.length / TokenManager.TOKEN_CHAR_ESTIMATE);
}

const tokens = this.tokensFromString(input);
return tokens.length;
}
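To make the guarded countFromString concrete, here is a short usage sketch. The require path is assumed relative to the repository root, and the exact count for the short string depends on the encoder; the point is that inputs over MAX_STRING_LENGTH never touch the encoder and are estimated from character length instead.

// Usage sketch (require path assumed relative to the repo root).
const { TokenManager } = require("./server/utils/helpers/tiktoken");

const tokenManager = new TokenManager("gpt-3.5-turbo");

// Normal inputs: exact count via the js-tiktoken encoder.
console.log(tokenManager.countFromString("hello world"));

// Inputs over MAX_STRING_LENGTH (400,000 chars): estimated count, no encoding cost.
const huge = "a".repeat(500_000);
console.log(tokenManager.countFromString(huge)); // Math.ceil(500000 / 3) === 166667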