diff --git a/src/tokenization/claude-ipc.ts b/src/tokenization/claude-ipc.ts
index 4b8de23..0b21338 100644
--- a/src/tokenization/claude-ipc.ts
+++ b/src/tokenization/claude-ipc.ts
@@ -59,7 +59,7 @@ export async function requestTokenCount({
     throw new Error("Claude tokenizer is not initialized");
   }
 
-  log.debug({ requestId, prompt: prompt.length }, "Requesting token count");
+  log.debug({ requestId, chars: prompt.length }, "Requesting token count");
   await socket.send(["tokenize", requestId, prompt]);
   log.debug({ requestId }, "Waiting for socket response");
 
@@ -75,11 +75,11 @@ export async function requestTokenCount({
     setTimeout(() => {
       if (pendingRequests.has(requestId)) {
         pendingRequests.delete(requestId);
-        const err = "Tokenizer took too long to respond";
+        const err = "Tokenizer deadline exceeded";
         log.warn({ requestId }, err);
         reject(new Error(err));
       }
-    }, 500);
+    }, 250); // TODO: make this timeout configurable; very slow hosts may need more time
   });
 }
diff --git a/src/tokenization/openai.ts b/src/tokenization/openai.ts
index e226e82..6218a92 100644
--- a/src/tokenization/openai.ts
+++ b/src/tokenization/openai.ts
@@ -12,7 +12,43 @@ export function init() {
   return true;
 }
 
-export function getTokenCount(text: string) {
-  const tokens = encoder.encode(text);
-  return tokens.length;
+// Implementation based on and tested against:
+// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+
+export function getTokenCount(messages: any[], model: string) {
+  const gpt4 = model.startsWith("gpt-4");
+
+  const tokensPerMessage = gpt4 ? 3 : 4;
+  const tokensPerName = gpt4 ? 1 : -1; // turbo omits the role if a name is present
+
+  let numTokens = 0;
+
+  for (const message of messages) {
+    numTokens += tokensPerMessage;
+    for (const [key, value] of Object.entries(message)) {
+      // Bail out on huge messages or runaway token counts to prevent DoS:
+      // 100k tokens allows for future 100k-context GPT-4 models, and 250k
+      // characters is just a sanity check.
+      if (value.length > 250000 || numTokens > 100000) {
+        numTokens = 100000;
+        return {
+          tokenizer: "tiktoken (prompt length limit exceeded)",
+          token_count: numTokens,
+        };
+      }
+
+      numTokens += encoder.encode(value).length;
+      if (key === "name") {
+        numTokens += tokensPerName;
+      }
+    }
+  }
+  numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
+  return { tokenizer: "tiktoken", token_count: numTokens };
 }
+
+export type OpenAIPromptMessage = {
+  name?: string;
+  content: string;
+  role: string;
+};
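
Aside: a rough usage sketch of the new getTokenCount() signature, not part of the patch. It assumes init() has already been called so this module's tiktoken encoder is loaded, and the messages are purely illustrative:

import { init, getTokenCount } from "./openai";

init();
const result = getTokenCount(
  [
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", name: "example-user", content: "Hello!" },
  ],
  "gpt-3.5-turbo"
);
// Each message costs 4 overhead tokens on turbo (3 on gpt-4), plus the
// encoded role/name/content values; "name" nets -1 on turbo because the
// role is omitted when a name is present. A final +3 primes the reply
// with <|start|>assistant<|message|>.
console.log(result); // => { tokenizer: "tiktoken", token_count: <number> }
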
diff --git a/src/tokenization/tokenizer.ts b/src/tokenization/tokenizer.ts
index cf55659..d3ece2b 100644
--- a/src/tokenization/tokenizer.ts
+++ b/src/tokenization/tokenizer.ts
@@ -1,6 +1,5 @@
 import { Request } from "express";
 import { config } from "../config";
-import { AIService } from "../key-management";
 import { logger } from "../logger";
 import {
   init as initIpc,
@@ -9,6 +8,7 @@
 import {
   init as initEncoder,
   getTokenCount as getOpenAITokenCount,
+  OpenAIPromptMessage,
 } from "./openai";
 
 let canTokenizeClaude = false;
@@ -28,38 +28,88 @@ export async function init() {
   }
 }
 
+type TokenCountResult = {
+  token_count: number;
+  tokenizer: string;
+  tokenization_duration_ms: number;
+};
+type TokenCountRequest = {
+  req: Request;
+} & (
+  | { prompt: string; service: "anthropic" }
+  | { prompt: OpenAIPromptMessage[]; service: "openai" }
+);
 export async function countTokens({
   req,
-  prompt,
   service,
-}: {
-  req: Request;
-  prompt: string;
-  service: AIService;
-}) {
-  if (service === "anthropic") {
-    if (!canTokenizeClaude) return guesstimateClaudeTokenCount(prompt);
-    try {
-      return await requestClaudeTokenCount({
-        requestId: String(req.id),
-        prompt: prompt,
-      });
-    } catch (e) {
-      req.log.error("Failed to tokenize with claude_tokenizer", e);
-      return guesstimateClaudeTokenCount(prompt);
-    }
-  }
-  if (service === "openai") {
-    // All OpenAI models we support use the same tokenizer currently
-    return getOpenAITokenCount(prompt);
+  prompt,
+}: TokenCountRequest): Promise<TokenCountResult> {
+  const time = process.hrtime();
+
+  switch (service) {
+    case "anthropic":
+      if (!canTokenizeClaude) {
+        const result = guesstimateTokens(prompt);
+        return {
+          token_count: result,
+          tokenizer: "guesstimate (claude-ipc disabled)",
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      }
+
+      // If the prompt is absolutely massive (possibly malicious), don't even try
+      if (prompt.length > 500000) {
+        return {
+          token_count: guesstimateTokens(prompt),
+          tokenizer: "guesstimate (prompt too long)",
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      }
+
+      try {
+        const result = await requestClaudeTokenCount({
+          requestId: String(req.id),
+          prompt,
+        });
+        return {
+          token_count: result,
+          tokenizer: "claude-ipc",
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      } catch (e: any) {
+        req.log.error("Failed to tokenize with claude_tokenizer", e);
+        const result = guesstimateTokens(prompt);
+        return {
+          token_count: result,
+          tokenizer: `guesstimate (claude-ipc failed: ${e.message})`,
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      }
+
+    case "openai": {
+      const result = getOpenAITokenCount(prompt, req.body.model);
+      return {
+        ...result,
+        tokenization_duration_ms: getElapsedMs(time),
+      };
+    }
+
+    default:
+      throw new Error(`Unknown service: ${service}`);
   }
 }
 
-function guesstimateClaudeTokenCount(prompt: string) {
+function getElapsedMs(time: [number, number]) {
+  const diff = process.hrtime(time);
+  return diff[0] * 1000 + diff[1] / 1e6;
+}
+
+function guesstimateTokens(prompt: string) {
   // From Anthropic's docs:
   // The maximum length of prompt that Claude can see is its context window.
   // Claude's context window is currently ~6500 words / ~8000 tokens /
   // ~28000 Unicode characters.
-  // We'll round up to ~0.3 tokens per character
-  return Math.ceil(prompt.length * 0.3);
+  // That suggests ~0.28 tokens per character, but in practice it seems to be
+  // a substantial underestimate in some cases, so round up a bit.
+  return Math.ceil(prompt.length * 0.325);
 }
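
Aside: a sketch of how the new discriminated-union request type is meant to be consumed downstream, not part of the patch; the middleware function and its name are hypothetical:

import { Request } from "express";
import { countTokens } from "./tokenizer";

// `service` narrows the type of `prompt`: a raw string for "anthropic",
// an OpenAIPromptMessage[] for "openai", so mixing them up is now a
// compile-time error rather than a silent mis-count.
async function logPromptTokens(req: Request) {
  const result = await countTokens({
    req,
    service: "openai",
    prompt: req.body.messages,
  });
  // => { token_count, tokenizer: "tiktoken", tokenization_duration_ms }
  req.log.info(result, "Counted prompt tokens");
}
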