improves OpenAI token counting accuracy
@@ -59,7 +59,7 @@ export async function requestTokenCount({
    throw new Error("Claude tokenizer is not initialized");
  }

  log.debug({ requestId, prompt: prompt.length }, "Requesting token count");
  log.debug({ requestId, chars: prompt.length }, "Requesting token count");
  await socket.send(["tokenize", requestId, prompt]);

  log.debug({ requestId }, "Waiting for socket response");
@@ -75,11 +75,11 @@ export async function requestTokenCount({
    setTimeout(() => {
      if (pendingRequests.has(requestId)) {
        pendingRequests.delete(requestId);
        const err = "Tokenizer took too long to respond";
        const err = "Tokenizer deadline exceeded";
        log.warn({ requestId }, err);
        reject(new Error(err));
      }
    }, 500);
    }, 250); // TODO: make this configurable, some really crappy VMs might need more time
  });
}
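
As a sketch of the TODO above, the 250ms deadline could be read from the environment rather than hard-coded; the TOKENIZER_TIMEOUT_MS variable name is an assumption for illustration, not something this commit adds:

// Hypothetical: let slow hosts raise the tokenizer deadline, defaulting to the 250ms used above.
const TOKENIZER_TIMEOUT_MS = Number(process.env.TOKENIZER_TIMEOUT_MS ?? 250);

setTimeout(() => {
  if (pendingRequests.has(requestId)) {
    pendingRequests.delete(requestId);
    reject(new Error("Tokenizer deadline exceeded"));
  }
}, TOKENIZER_TIMEOUT_MS);
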
@@ -12,7 +12,46 @@ export function init() {
  return true;
}

export function getTokenCount(text: string) {
  const tokens = encoder.encode(text);
  return tokens.length;
}
// Implementation based on and tested against:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
export function getTokenCount(messages: any[], model: string) {
  const gpt4 = model.startsWith("gpt-4");

  const tokensPerMessage = gpt4 ? 3 : 4;
  const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present

  let numTokens = 0;

  for (const message of messages) {
    numTokens += tokensPerMessage;
    for (const key of Object.keys(message)) {
      {
        const value = message[key];
        // Break if we get a huge message or exceed the token limit to prevent DoS
        // 100k tokens allows for future 100k GPT-4 models and 250k characters is
        // just a sanity check
        if (value.length > 250000 || numTokens > 100000) {
          numTokens = 100000;
          return {
            tokenizer: "tiktoken (prompt length limit exceeded)",
            token_count: numTokens,
          };
        }

        numTokens += encoder.encode(message[key]).length;
        if (key === "name") {
          numTokens += tokensPerName;
        }
      }
    }
  }
  numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
  return { tokenizer: "tiktoken", token_count: numTokens };
}

export type OpenAIPromptMessage = {
  name?: string;
  content: string;
  role: string;
};
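
A rough usage sketch of the new message-based counter, assuming init() has already loaded the encoder; the messages below are made up for illustration:

const messages: OpenAIPromptMessage[] = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", name: "example_user", content: "Hello!" },
];

// Every field value is run through tiktoken; on gpt-4 each message adds 3
// overhead tokens, a name field adds 1, and the reply priming adds 3.
const { tokenizer, token_count } = getTokenCount(messages, "gpt-4");
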
@@ -1,6 +1,5 @@
import { Request } from "express";
import { config } from "../config";
import { AIService } from "../key-management";
import { logger } from "../logger";
import {
  init as initIpc,
@@ -9,6 +8,7 @@ import {
import {
  init as initEncoder,
  getTokenCount as getOpenAITokenCount,
  OpenAIPromptMessage,
} from "./openai";

let canTokenizeClaude = false;
@@ -28,38 +28,86 @@ export async function init() {
  }
}

type TokenCountResult = {
  token_count: number;
  tokenizer: string;
  tokenization_duration_ms: number;
};
type TokenCountRequest = {
  req: Request;
} & (
  | { prompt: string; service: "anthropic" }
  | { prompt: OpenAIPromptMessage[]; service: "openai" }
);
export async function countTokens({
  req,
  prompt,
  service,
}: {
  req: Request;
  prompt: string;
  service: AIService;
}) {
  if (service === "anthropic") {
    if (!canTokenizeClaude) return guesstimateClaudeTokenCount(prompt);
    try {
      return await requestClaudeTokenCount({
        requestId: String(req.id),
        prompt: prompt,
      });
    } catch (e) {
      req.log.error("Failed to tokenize with claude_tokenizer", e);
      return guesstimateClaudeTokenCount(prompt);
    }
  }
  if (service === "openai") {
    // All OpenAI models we support use the same tokenizer currently
    return getOpenAITokenCount(prompt);
  prompt,
}: TokenCountRequest): Promise<TokenCountResult> {
  const time = process.hrtime();

  switch (service) {
    case "anthropic":
      if (!canTokenizeClaude) {
        const result = guesstimateTokens(prompt);
        return {
          token_count: result,
          tokenizer: "guesstimate (claude-ipc disabled)",
          tokenization_duration_ms: getElapsedMs(time),
        };
      }

      // If the prompt is absolutely massive (possibly malicious) don't even try
      if (prompt.length > 500000) {
        return {
          token_count: guesstimateTokens(JSON.stringify(prompt)),
          tokenizer: "guesstimate (prompt too long)",
          tokenization_duration_ms: getElapsedMs(time),
        };
      }

      try {
        const result = await requestClaudeTokenCount({
          requestId: String(req.id),
          prompt,
        });
        return {
          token_count: result,
          tokenizer: "claude-ipc",
          tokenization_duration_ms: getElapsedMs(time),
        };
      } catch (e: any) {
        req.log.error("Failed to tokenize with claude_tokenizer", e);
        const result = guesstimateTokens(prompt);
        return {
          token_count: result,
          tokenizer: `guesstimate (claude-ipc failed: ${e.message})`,
          tokenization_duration_ms: getElapsedMs(time),
        };
      }

    case "openai":
      const result = getOpenAITokenCount(prompt, req.body.model);
      return {
        ...result,
        tokenization_duration_ms: getElapsedMs(time),
      };
    default:
      throw new Error(`Unknown service: ${service}`);
  }
}
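
For illustration, hypothetical call sites showing how the TokenCountRequest union ties the prompt shape to the service; req is whatever Express request is in scope, and the OpenAI branch also reads req.body.model:

const openaiCount = await countTokens({
  req,
  service: "openai",
  prompt: [{ role: "user", content: "Hi there" }],
});

const claudeCount = await countTokens({
  req,
  service: "anthropic",
  prompt: "\n\nHuman: Hi there\n\nAssistant:",
});

// Both resolve to { token_count, tokenizer, tokenization_duration_ms }.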

function guesstimateClaudeTokenCount(prompt: string) {
function getElapsedMs(time: [number, number]) {
  const diff = process.hrtime(time);
  return diff[0] * 1000 + diff[1] / 1e6;
}
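
(process.hrtime returns a [seconds, nanoseconds] tuple, so a diff of [1, 500000] works out to 1 * 1000 + 500000 / 1e6 = 1000.5 ms.)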

function guesstimateTokens(prompt: string) {
  // From Anthropic's docs:
  // The maximum length of prompt that Claude can see is its context window.
  // Claude's context window is currently ~6500 words / ~8000 tokens /
  // ~28000 Unicode characters.
  // We'll round up to ~0.3 tokens per character
  return Math.ceil(prompt.length * 0.3);
  // This suggests 0.28 tokens per character but in practice this seems to be
  // a substantial underestimate in some cases.
  return Math.ceil(prompt.length * 0.325);
}
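
Worked through with the figures in the comment: ~8000 tokens over ~28000 characters is roughly 0.286 tokens per character, so the old 0.3 multiplier guessed a 28000-character prompt at 8400 tokens, while the new 0.325 multiplier guesses 9100, deliberately erring high for the prompts where the real tokenizer runs above the average.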