khanon
2023-12-13 21:56:07 +00:00
parent 0d3682197c
commit fad16cc268
40 changed files with 588 additions and 357 deletions
+29 -4
@@ -2,7 +2,11 @@ import { Tiktoken } from "tiktoken/lite";
 import cl100k_base from "tiktoken/encoders/cl100k_base.json";
 import { logger } from "../../logger";
 import { libSharp } from "../file-storage";
-import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  OpenAIChatMessage,
+} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
 import { z } from "zod";

 const log = logger.child({ module: "tokenizer", service: "openai" });
 const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
@@ -29,11 +33,11 @@ export async function getTokenCount(
     return getTextTokenCount(prompt);
   }

-  const gpt4 = model.startsWith("gpt-4");
+  const oldFormatting = model.startsWith("turbo-0301");
   const vision = model.includes("vision");

-  const tokensPerMessage = gpt4 ? 3 : 4;
-  const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
+  const tokensPerMessage = oldFormatting ? 4 : 3;
+  const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
   let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
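A quick illustration of the new overhead arithmetic (a sketch, not part of the commit; it assumes the loop in getTokenCount() applies these constants once per message): every model except turbo-0301 now pays 3 tokens per message plus 1 extra when a name field is present, while turbo-0301 pays 4 per message and saves 1 because the name replaces the role.

// Hypothetical overhead for one message that has a name, before its content is encoded:
// modern models: 3 + 1    = 4 tokens of per-message overhead
// turbo-0301:    4 + (-1) = 3 tokens of per-message overhead
const overhead = (hasName: boolean, perMessage: number, perName: number) =>
  perMessage + (hasName ? perName : 0);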
@@ -228,3 +232,24 @@ export function getOpenAIImageCost(params: {
     token_count: Math.ceil(tokens),
   };
 }
+
+export function estimateGoogleAITokenCount(prompt: string | GoogleAIChatMessage[]) {
+  if (typeof prompt === "string") {
+    return getTextTokenCount(prompt);
+  }
+
+  const tokensPerMessage = 3;
+  let numTokens = 0;
+
+  for (const message of prompt) {
+    numTokens += tokensPerMessage;
+    numTokens += encoder.encode(message.parts[0].text).length;
+  }
+  numTokens += 3;
+
+  return {
+    tokenizer: "tiktoken (google-ai estimate)",
+    token_count: numTokens,
+  };
+}
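A usage sketch for the new estimator (hypothetical call; the only shape it assumes is that each GoogleAIChatMessage carries a parts array whose first element has a text string, which is exactly what the loop above reads):

const estimate = estimateGoogleAITokenCount([
  { role: "user", parts: [{ text: "Hello Gemini" }] },
]);
// -> { tokenizer: "tiktoken (google-ai estimate)", token_count: 6 + <encoded text length> }

Each message costs a flat 3 tokens plus its encoded text, and the trailing numTokens += 3 mirrors the reply-priming overhead used for OpenAI chat models.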
+10 -5
@@ -1,5 +1,8 @@
 import { Request } from "express";
-import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  OpenAIChatMessage,
+} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
 import { assertNever } from "../utils";
 import {
   init as initClaude,
@@ -9,6 +12,7 @@ import {
   init as initOpenAi,
   getTokenCount as getOpenAITokenCount,
   getOpenAIImageCost,
+  estimateGoogleAITokenCount,
 } from "./openai";
 import { APIFormat } from "../key-management";
@@ -24,8 +28,9 @@ type TokenCountRequest = { req: Request } & (
   | {
       prompt: string;
       completion?: never;
-      service: "openai-text" | "anthropic" | "google-palm";
+      service: "openai-text" | "anthropic" | "google-ai";
     }
+  | { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
   | { prompt?: never; completion: string; service: APIFormat }
   | { prompt?: never; completion?: never; service: "openai-image" }
 );
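With the widened union, a call site can now pass a structured Google AI prompt; a hypothetical invocation (assuming an Express req is in scope):

const result = await countTokens({
  req,
  prompt: [{ role: "user", parts: [{ text: "Hello" }] }],
  service: "google-ai",
});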
@@ -65,11 +70,11 @@ export async function countTokens({
         }),
         tokenization_duration_ms: getElapsedMs(time),
       };
-    case "google-palm":
-      // TODO: Can't find a tokenization library for PaLM. There is an API
+    case "google-ai":
+      // TODO: Can't find a tokenization library for Gemini. There is an API
       //  endpoint for it but it adds significant latency to the request.
       return {
-        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
+        ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
         tokenization_duration_ms: getElapsedMs(time),
       };
     default:
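One edge of the new fallback worth noting (an observation, not part of the diff): prompt ?? (completion || []) sends a string completion through the estimator's plain-text path, while a request with neither prompt nor completion degrades to an empty array, which reports only the flat 3-token priming overhead:

estimateGoogleAITokenCount("some completion text"); // string path: plain tiktoken count
estimateGoogleAITokenCount([]);                     // empty message list: token_count === 3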