improves OpenAI token counting accuracy

nai-degen
2023-06-02 01:57:55 -05:00
parent 0064fd4f3a
commit 4341dc5961
3 changed files with 118 additions and 31 deletions
+3 -3
@@ -59,7 +59,7 @@ export async function requestTokenCount({
     throw new Error("Claude tokenizer is not initialized");
   }
 
-  log.debug({ requestId, prompt: prompt.length }, "Requesting token count");
+  log.debug({ requestId, chars: prompt.length }, "Requesting token count");
   await socket.send(["tokenize", requestId, prompt]);
   log.debug({ requestId }, "Waiting for socket response");
@@ -75,11 +75,11 @@
     setTimeout(() => {
       if (pendingRequests.has(requestId)) {
         pendingRequests.delete(requestId);
-        const err = "Tokenizer took too long to respond";
+        const err = "Tokenizer deadline exceeded";
         log.warn({ requestId }, err);
         reject(new Error(err));
       }
-    }, 500);
+    }, 250); // TODO: make this configurable; some really slow VMs might need more time
   });
 }
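
Not part of the commit: the TODO above asks for a configurable deadline. A minimal sketch, assuming a hypothetical TOKENIZER_TIMEOUT_MS environment variable that neither this commit nor the repo defines:

// Hypothetical follow-up; TOKENIZER_TIMEOUT_MS is an invented env var.
const TOKENIZER_TIMEOUT_MS = Number(process.env.TOKENIZER_TIMEOUT_MS ?? 250);

setTimeout(() => {
  if (pendingRequests.has(requestId)) {
    pendingRequests.delete(requestId);
    const err = "Tokenizer deadline exceeded";
    log.warn({ requestId, timeoutMs: TOKENIZER_TIMEOUT_MS }, err);
    reject(new Error(err));
  }
}, TOKENIZER_TIMEOUT_MS);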
+42 -3
@@ -12,7 +12,46 @@ export function init() {
   return true;
 }
 
-export function getTokenCount(text: string) {
-  const tokens = encoder.encode(text);
-  return tokens.length;
+// Implementation based on, and tested against:
+// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+export function getTokenCount(messages: any[], model: string) {
+  const gpt4 = model.startsWith("gpt-4");
+  const tokensPerMessage = gpt4 ? 3 : 4;
+  const tokensPerName = gpt4 ? 1 : -1; // turbo omits the role if a name is present
+
+  let numTokens = 0;
+  for (const message of messages) {
+    numTokens += tokensPerMessage;
+    for (const key of Object.keys(message)) {
+      const value = message[key];
+      // Bail out if we get a huge message or exceed the token limit, to
+      // prevent DoS. 100k tokens allows for future 100k-context GPT-4
+      // models; 250k characters is just a sanity check.
+      if (value.length > 250000 || numTokens > 100000) {
+        numTokens = 100000;
+        return {
+          tokenizer: "tiktoken (prompt length limit exceeded)",
+          token_count: numTokens,
+        };
+      }
+      numTokens += encoder.encode(value).length;
+      if (key === "name") {
+        numTokens += tokensPerName;
+      }
+    }
+  }
+  numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
+  return { tokenizer: "tiktoken", token_count: numTokens };
 }
+
+export type OpenAIPromptMessage = {
+  name?: string;
+  content: string;
+  role: string;
+};
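
Illustration only, not part of the diff: a quick check of the new per-message accounting against the cookbook recipe, assuming init() has already created the encoder.

const result = getTokenCount(
  [
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", name: "example-user", content: "Hello!" },
  ],
  "gpt-3.5-turbo"
);
// Turbo charges 4 overhead tokens per message (gpt-4 charges 3), subtracts 1
// when a "name" key is present (gpt-4 adds 1), encodes every value, and adds
// 3 tokens for the primed reply.
console.log(result); // { tokenizer: "tiktoken", token_count: ... }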
+73 -25
@@ -1,6 +1,5 @@
 import { Request } from "express";
 import { config } from "../config";
-import { AIService } from "../key-management";
 import { logger } from "../logger";
 import {
   init as initIpc,
@@ -9,6 +8,7 @@ import {
 import {
   init as initEncoder,
   getTokenCount as getOpenAITokenCount,
+  OpenAIPromptMessage,
 } from "./openai";
 
 let canTokenizeClaude = false;
@@ -28,38 +28,86 @@ export async function init() {
   }
 }
 
+type TokenCountResult = {
+  token_count: number;
+  tokenizer: string;
+  tokenization_duration_ms: number;
+};
+
+type TokenCountRequest = {
+  req: Request;
+} & (
+  | { prompt: string; service: "anthropic" }
+  | { prompt: OpenAIPromptMessage[]; service: "openai" }
+);
+
 export async function countTokens({
   req,
-  prompt,
   service,
-}: {
-  req: Request;
-  prompt: string;
-  service: AIService;
-}) {
-  if (service === "anthropic") {
-    if (!canTokenizeClaude) return guesstimateClaudeTokenCount(prompt);
-    try {
-      return await requestClaudeTokenCount({
-        requestId: String(req.id),
-        prompt: prompt,
-      });
-    } catch (e) {
-      req.log.error("Failed to tokenize with claude_tokenizer", e);
-      return guesstimateClaudeTokenCount(prompt);
-    }
-  }
-  if (service === "openai") {
-    // All OpenAI models we support use the same tokenizer currently
-    return getOpenAITokenCount(prompt);
+  prompt,
+}: TokenCountRequest): Promise<TokenCountResult> {
+  const time = process.hrtime();
+  switch (service) {
+    case "anthropic":
+      if (!canTokenizeClaude) {
+        const result = guesstimateTokens(prompt);
+        return {
+          token_count: result,
+          tokenizer: "guesstimate (claude-ipc disabled)",
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      }
+      // If the prompt is absolutely massive (possibly malicious), don't even try
+      if (prompt.length > 500000) {
+        return {
+          token_count: guesstimateTokens(prompt),
+          tokenizer: "guesstimate (prompt too long)",
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      }
+      try {
+        const result = await requestClaudeTokenCount({
+          requestId: String(req.id),
+          prompt,
+        });
+        return {
+          token_count: result,
+          tokenizer: "claude-ipc",
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      } catch (e: any) {
+        req.log.error("Failed to tokenize with claude_tokenizer", e);
+        const result = guesstimateTokens(prompt);
+        return {
+          token_count: result,
+          tokenizer: `guesstimate (claude-ipc failed: ${e.message})`,
+          tokenization_duration_ms: getElapsedMs(time),
+        };
+      }
+    case "openai": {
+      const result = getOpenAITokenCount(prompt, req.body.model);
+      return {
+        ...result,
+        tokenization_duration_ms: getElapsedMs(time),
+      };
+    }
+    default:
+      throw new Error(`Unknown service: ${service}`);
   }
 }
 
-function guesstimateClaudeTokenCount(prompt: string) {
+function getElapsedMs(time: [number, number]) {
+  const diff = process.hrtime(time);
+  return diff[0] * 1000 + diff[1] / 1e6;
+}
+
+function guesstimateTokens(prompt: string) {
   // From Anthropic's docs:
   // The maximum length of prompt that Claude can see is its context window.
   // Claude's context window is currently ~6500 words / ~8000 tokens /
   // ~28000 Unicode characters.
-  // We'll round up to ~0.3 tokens per character
-  return Math.ceil(prompt.length * 0.3);
+  // That suggests ~0.286 tokens per character, but in practice it seems to be
+  // a substantial underestimate in some cases, so round up further.
+  return Math.ceil(prompt.length * 0.325);
 }
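
Two notes on the new countTokens, with an illustration that is not part of the commit. First, the TokenCountRequest discriminated union ties the prompt type to the service, so the compiler rejects mismatched pairs:

// Sketch only; assumes the imports and exports of the file above.
async function demo(req: Request) {
  // OK: anthropic takes a flat string prompt.
  const a = await countTokens({ req, service: "anthropic", prompt: "Hi" });

  // OK: openai takes structured chat messages.
  const b = await countTokens({
    req,
    service: "openai",
    prompt: [{ role: "user", content: "Hi" }],
  });

  // Type error: a string prompt is not assignable when service is "openai".
  // const bad = await countTokens({ req, service: "openai", prompt: "Hi" });
}

Second, on the guesstimate multiplier: Anthropic's figures work out to 8000 tokens / 28000 characters ≈ 0.286 tokens per character, and the commit rounds up to 0.325 so that the fallback overestimates rather than undercounts.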