diff --git a/src/server.ts b/src/server.ts
index be0aa13..0fa653f 100644
--- a/src/server.ts
+++ b/src/server.ts
@@ -12,6 +12,7 @@ import { handleInfoPage } from "./info-page";
 import { logQueue } from "./prompt-logging";
 import { start as startRequestQueue } from "./proxy/queue";
 import { init as initUserStore } from "./proxy/auth/user-store";
+import { init as initTokenizers } from "./tokenization";
 import { checkOrigin } from "./proxy/check-origin";
 
 const PORT = config.port;
@@ -99,6 +100,8 @@ async function start() {
 
   keyPool.init();
 
+  await initTokenizers();
+
   if (config.gatekeeper === "user_token") {
     await initUserStore();
   }
diff --git a/src/tokenization/index.ts b/src/tokenization/index.ts
new file mode 100644
index 0000000..2d07f5e
--- /dev/null
+++ b/src/tokenization/index.ts
@@ -0,0 +1 @@
+export { init, countTokens } from "./tokenizer";
diff --git a/src/tokenization/openai.ts b/src/tokenization/openai.ts
new file mode 100644
index 0000000..e226e82
--- /dev/null
+++ b/src/tokenization/openai.ts
@@ -0,0 +1,18 @@
+import { Tiktoken } from "tiktoken/lite";
+import cl100k_base from "tiktoken/encoders/cl100k_base.json";
+
+let encoder: Tiktoken;
+
+export function init() {
+  encoder = new Tiktoken(
+    cl100k_base.bpe_ranks,
+    cl100k_base.special_tokens,
+    cl100k_base.pat_str
+  );
+  return true;
+}
+
+export function getTokenCount(text: string) {
+  const tokens = encoder.encode(text);
+  return tokens.length;
+}
diff --git a/src/tokenization/tokenizer.ts b/src/tokenization/tokenizer.ts
new file mode 100644
index 0000000..cf55659
--- /dev/null
+++ b/src/tokenization/tokenizer.ts
@@ -0,0 +1,65 @@
+import { Request } from "express";
+import { config } from "../config";
+import { AIService } from "../key-management";
+import { logger } from "../logger";
+import {
+  init as initIpc,
+  requestTokenCount as requestClaudeTokenCount,
+} from "./claude-ipc";
+import {
+  init as initEncoder,
+  getTokenCount as getOpenAITokenCount,
+} from "./openai";
+
+let canTokenizeClaude = false;
+let canTokenizeOpenAI = false;
+
+export async function init() {
+  if (config.anthropicKey) {
+    canTokenizeClaude = await initIpc();
+    if (!canTokenizeClaude) {
+      logger.warn(
+        "Anthropic key is set, but tokenizer is not available. Claude prompts will use a naive estimate for token count."
+      );
+    }
+  }
+  if (config.openaiKey) {
+    canTokenizeOpenAI = initEncoder();
+  }
+}
+
+export async function countTokens({
+  req,
+  prompt,
+  service,
+}: {
+  req: Request;
+  prompt: string;
+  service: AIService;
+}) {
+  if (service === "anthropic") {
+    if (!canTokenizeClaude) return guesstimateClaudeTokenCount(prompt);
+    try {
+      return await requestClaudeTokenCount({
+        requestId: String(req.id),
+        prompt: prompt,
+      });
+    } catch (e) {
+      req.log.error("Failed to tokenize with claude_tokenizer", e);
+      return guesstimateClaudeTokenCount(prompt);
+    }
+  }
+  if (service === "openai") {
+    // All OpenAI models we support use the same tokenizer currently
+    return getOpenAITokenCount(prompt);
+  }
+}
+
+function guesstimateClaudeTokenCount(prompt: string) {
+  // From Anthropic's docs:
+  // The maximum length of prompt that Claude can see is its context window.
+  // Claude's context window is currently ~6500 words / ~8000 tokens /
+  // ~28000 Unicode characters.
+  // We'll round up to ~0.3 tokens per character
+  return Math.ceil(prompt.length * 0.3);
+}