diff --git a/package-lock.json b/package-lock.json index 93101b1..6103dd2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -19,6 +19,7 @@ "pino": "^8.11.0", "pino-http": "^8.3.3", "showdown": "^2.1.0", + "tiktoken": "^1.0.10", "uuid": "^9.0.0", "zlib": "^1.0.5", "zod": "^3.21.4" @@ -3854,6 +3855,11 @@ "real-require": "^0.2.0" } }, + "node_modules/tiktoken": { + "version": "1.0.10", + "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.10.tgz", + "integrity": "sha512-gF8ndTCNu7WcRFbl1UUWaFIB4CTXmHzS3tRYdyUYF7x3C6YR6Evoao4zhKDmWIwv2PzNbzoQMV8Pxt+17lEDbA==" + }, "node_modules/tmp": { "version": "0.2.1", "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.1.tgz", diff --git a/package.json b/package.json index c866f8a..e5f3549 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "pino": "^8.11.0", "pino-http": "^8.3.3", "showdown": "^2.1.0", + "tiktoken": "^1.0.10", "uuid": "^9.0.0", "zlib": "^1.0.5", "zod": "^3.21.4" diff --git a/src/admin/routes.ts b/src/admin/routes.ts index d718faf..82fb95e 100644 --- a/src/admin/routes.ts +++ b/src/admin/routes.ts @@ -1,4 +1,4 @@ -import { RequestHandler, Router } from "express"; +import express, { RequestHandler, Router } from "express"; import { config } from "../config"; import { usersRouter } from "./users"; @@ -32,5 +32,9 @@ const auth: RequestHandler = (req, res, next) => { }; adminRouter.use(auth); +adminRouter.use( + express.json({ limit: "20mb" }), + express.urlencoded({ extended: true, limit: "20mb" }) +); adminRouter.use("/users", usersRouter); export { adminRouter }; diff --git a/src/config.ts b/src/config.ts index 8d0ab06..929f8a9 100644 --- a/src/config.ts +++ b/src/config.ts @@ -63,6 +63,20 @@ type Config = { maxIpsPerUser: number; /** Per-IP limit for requests per minute to OpenAI's completions endpoint. */ modelRateLimit: number; + /** + * For OpenAI, the maximum number of context tokens (prompt + max output) a + * user can request before their request is rejected. + * Context limits can help prevent excessive spend. + * Defaults to 0, which means no limit beyond OpenAI's stated maximums. + */ + maxContextTokensOpenAI: number; + /** + * For Anthropic, the maximum number of context tokens a user can request. + * Claude context limits can prevent requests from tying up concurrency slots + * for too long, which can lengthen queue times for other users. + * Defaults to 0, which means no limit beyond Anthropic's stated maximums. + */ + maxContextTokensAnthropic: number; /** For OpenAI, the maximum number of sampled tokens a user can request. */ maxOutputTokensOpenAI: number; /** For Anthropic, the maximum number of sampled tokens a user can request. */ @@ -140,6 +154,11 @@ export const config: Config = { firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined), firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined), modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4), + maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0), + maxContextTokensAnthropic: getEnvWithDefault( + "MAX_CONTEXT_TOKENS_ANTHROPIC", + 0 + ), maxOutputTokensOpenAI: getEnvWithDefault("MAX_OUTPUT_TOKENS_OPENAI", 300), maxOutputTokensAnthropic: getEnvWithDefault( "MAX_OUTPUT_TOKENS_ANTHROPIC", diff --git a/src/info-page.ts b/src/info-page.ts index 51cdaa0..c1ac69a 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -199,7 +199,7 @@ Logs are anonymous and do not contain IP addresses or timestamps. 
[You can see t } if (config.queueMode !== "none") { - const waits = []; + const waits: string[] = []; infoBody += `\n## Estimated Wait Times\nIf the AI is busy, your prompt will processed when a slot frees up.`; if (config.openaiKey) { diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts index 334a41b..d8b2233 100644 --- a/src/proxy/anthropic.ts +++ b/src/proxy/anthropic.ts @@ -13,7 +13,6 @@ import { createPreprocessorMiddleware, finalizeBody, languageFilter, - limitOutputTokens, removeOriginHeaders, } from "./middleware/request"; import { @@ -76,7 +75,6 @@ const rewriteAnthropicRequest = ( addKey, addAnthropicPreamble, languageFilter, - limitOutputTokens, blockZoomerOrigins, removeOriginHeaders, finalizeBody, @@ -108,10 +106,16 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async ( body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`; } - if (!req.originalUrl.includes("/v1/complete") && !req.originalUrl.includes("/complete")) { + if (req.inboundApi === "openai") { req.log.info("Transforming Anthropic response to OpenAI format"); body = transformAnthropicResponse(body); } + + // TODO: Remove once tokenization is stable + if (req.debug) { + body.proxy_tokenizer_debug_info = req.debug; + } + res.status(200).json(body); }; diff --git a/src/proxy/kobold.ts b/src/proxy/kobold.ts index efaed7a..3fcdf46 100644 --- a/src/proxy/kobold.ts +++ b/src/proxy/kobold.ts @@ -13,7 +13,6 @@ import { createPreprocessorMiddleware, finalizeBody, languageFilter, - limitOutputTokens, transformKoboldPayload, } from "./middleware/request"; import { @@ -45,7 +44,6 @@ const rewriteRequest = ( addKey, transformKoboldPayload, languageFilter, - limitOutputTokens, finalizeBody, ]; diff --git a/src/proxy/middleware/common.ts b/src/proxy/middleware/common.ts index 7c70655..d4266ef 100644 --- a/src/proxy/middleware/common.ts +++ b/src/proxy/middleware/common.ts @@ -45,6 +45,9 @@ export function writeErrorResponse( res.write(`data: [DONE]\n\n`); res.end(); } else { + if (req.debug) { + errorPayload.error.proxy_tokenizer_debug_info = req.debug; + } res.status(statusCode).json(errorPayload); } } @@ -86,7 +89,7 @@ export const handleInternalError = ( } else { writeErrorResponse(req, res, 500, { error: { - type: "proxy_rewriter_error", + type: "proxy_internal_error", proxy_note: `Reverse proxy encountered an error before it could reach the upstream API.`, message: err.message, stack: err.stack, diff --git a/src/proxy/middleware/request/add-key.ts b/src/proxy/middleware/request/add-key.ts index 08dbab3..df06858 100644 --- a/src/proxy/middleware/request/add-key.ts +++ b/src/proxy/middleware/request/add-key.ts @@ -41,8 +41,6 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => { // For such cases, ignore the requested model entirely. if (req.inboundApi === "openai" && req.outboundApi === "anthropic") { req.log.debug("Using an Anthropic key for an OpenAI-compatible request"); - // We don't assign the model here, that will happen when transforming the - // request body. 
     assignedKey = keyPool.get("claude-v1");
   } else {
     assignedKey = keyPool.get(req.body.model);
   }
diff --git a/src/proxy/middleware/request/check-context-size.ts b/src/proxy/middleware/request/check-context-size.ts
new file mode 100644
index 0000000..20ad478
--- /dev/null
+++ b/src/proxy/middleware/request/check-context-size.ts
@@ -0,0 +1,135 @@
+import { Request } from "express";
+import { z } from "zod";
+import { config } from "../../../config";
+import { countTokens } from "../../../tokenization";
+import { RequestPreprocessor } from ".";
+
+const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
+const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
+
+/**
+ * Claude models don't throw an error if you exceed the token limit and
+ * instead just become extremely slow and provide schizo output. To be safe,
+ * we will only allow 95% of the stated limit, which also accounts for our
+ * tokenization being slightly different than Anthropic's.
+ */
+const CLAUDE_TOKEN_LIMIT_ADJUSTMENT = 0.95;
+
+/**
+ * Assigns `req.promptTokens` and `req.outputTokens` based on the request body
+ * and outbound API format, which combined determine the size of the context.
+ * If the context is too large, an error is thrown.
+ * This preprocessor should run after any preprocessor that transforms the
+ * request body.
+ */
+export const checkContextSize: RequestPreprocessor = async (req) => {
+  let prompt;
+
+  switch (req.outboundApi) {
+    case "openai":
+      req.outputTokens = req.body.max_tokens;
+      prompt = req.body.messages;
+      break;
+    case "anthropic":
+      req.outputTokens = req.body.max_tokens_to_sample;
+      prompt = req.body.prompt;
+      break;
+    default:
+      throw new Error(`Unknown outbound API: ${req.outboundApi}`);
+  }
+
+  const result = await countTokens({ req, prompt, service: req.outboundApi });
+  req.promptTokens = result.token_count;
+
+  // TODO: Remove once token counting is stable
+  req.log.debug({ result: result }, "Counted prompt tokens.");
+  req.debug = req.debug ?? {};
+  req.debug = { ...req.debug, ...result };
+
+  maybeReassignModel(req);
+  validateContextSize(req);
+};
+
+function validateContextSize(req: Request) {
+  assertRequestHasTokenCounts(req);
+  const promptTokens = req.promptTokens;
+  const outputTokens = req.outputTokens;
+  const contextTokens = promptTokens + outputTokens;
+  const model = req.body.model;
+
+  const proxyMax =
+    (req.outboundApi === "openai" ? OPENAI_MAX_CONTEXT : CLAUDE_MAX_CONTEXT) ||
+    Number.MAX_SAFE_INTEGER;
+  let modelMax = 0;
+
+  if (model.match(/gpt-3.5/)) {
+    modelMax = 4096;
+  } else if (model.match(/gpt-4-32k/)) {
+    modelMax = 32768;
+  } else if (model.match(/gpt-4/)) {
+    modelMax = 8192;
+  } else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?(?:-100k)/)) {
+    modelMax = 100000 * CLAUDE_TOKEN_LIMIT_ADJUSTMENT;
+  } else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?$/)) {
+    modelMax = 9000 * CLAUDE_TOKEN_LIMIT_ADJUSTMENT;
+  } else if (model.match(/claude-2/)) {
+    modelMax = 100000 * CLAUDE_TOKEN_LIMIT_ADJUSTMENT;
+  } else {
+    // Don't really want to throw here because I don't want to have to update
+    // this ASAP every time a new model is released.
+    req.log.warn({ model }, "Unknown model, using 100k token limit.");
+    modelMax = 100000;
+  }
+
+  const finalMax = Math.min(proxyMax, modelMax);
+  z.number()
+    .int()
+    .max(finalMax, {
+      message: `Your request exceeds the context size limit for this model or proxy. 
(max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`, + }) + .parse(contextTokens); + + req.log.debug( + { promptTokens, outputTokens, contextTokens, modelMax, proxyMax }, + "Prompt size validated" + ); + + req.debug.prompt_tokens = promptTokens; + req.debug.max_model_tokens = modelMax; + req.debug.max_proxy_tokens = proxyMax; +} + +function assertRequestHasTokenCounts( + req: Request +): asserts req is Request & { promptTokens: number; outputTokens: number } { + z.object({ + promptTokens: z.number().int().min(1), + outputTokens: z.number().int().min(1), + }) + .nonstrict() + .parse(req); +} + +/** + * For OpenAI-to-Anthropic requests, users can't specify the model, so we need + * to pick one based on the final context size. Ideally this would happen in + * the `transformOutboundPayload` preprocessor, but we don't have the context + * size at that point (and need a transformed body to calculate it). + */ +function maybeReassignModel(req: Request) { + if (req.inboundApi !== "openai" || req.outboundApi !== "anthropic") { + return; + } + + const bigModel = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k"; + const contextSize = req.promptTokens! + req.outputTokens!; + + if (contextSize > 8500) { + req.log.debug( + { model: bigModel, contextSize }, + "Using Claude 100k model for OpenAI-to-Anthropic request" + ); + req.body.model = bigModel; + } + // Small model is the default already set in `transformOutboundPayload` +} diff --git a/src/proxy/middleware/request/index.ts b/src/proxy/middleware/request/index.ts index eee478f..4f61405 100644 --- a/src/proxy/middleware/request/index.ts +++ b/src/proxy/middleware/request/index.ts @@ -4,6 +4,7 @@ import type { ProxyReqCallback } from "http-proxy"; // Express middleware (runs before http-proxy-middleware, can be async) export { createPreprocessorMiddleware } from "./preprocess"; +export { checkContextSize } from "./check-context-size"; export { setApiFormat } from "./set-api-format"; export { transformOutboundPayload } from "./transform-outbound-payload"; @@ -14,7 +15,6 @@ export { blockZoomerOrigins } from "./block-zoomer-origins"; export { finalizeBody } from "./finalize-body"; export { languageFilter } from "./language-filter"; export { limitCompletions } from "./limit-completions"; -export { limitOutputTokens } from "./limit-output-tokens"; export { removeOriginHeaders } from "./remove-origin-headers"; export { transformKoboldPayload } from "./transform-kobold-payload"; diff --git a/src/proxy/middleware/request/limit-output-tokens.ts b/src/proxy/middleware/request/limit-output-tokens.ts deleted file mode 100644 index 09e9475..0000000 --- a/src/proxy/middleware/request/limit-output-tokens.ts +++ /dev/null @@ -1,46 +0,0 @@ -import { Request } from "express"; -import { config } from "../../../config"; -import { isCompletionRequest } from "../common"; -import { ProxyRequestMiddleware } from "."; - -/** Enforce a maximum number of tokens requested from the model. */ -export const limitOutputTokens: ProxyRequestMiddleware = (_proxyReq, req) => { - // TODO: do all of this shit in the zod validator - if (isCompletionRequest(req)) { - const requestedMax = Number.parseInt(getMaxTokensFromRequest(req)); - const apiMax = - req.outboundApi === "openai" - ? 
config.maxOutputTokensOpenAI - : config.maxOutputTokensAnthropic; - let maxTokens = requestedMax; - - if (typeof requestedMax !== "number") { - maxTokens = apiMax; - } - - maxTokens = Math.min(maxTokens, apiMax); - if (req.outboundApi === "openai") { - req.body.max_tokens = maxTokens; - } else if (req.outboundApi === "anthropic") { - req.body.max_tokens_to_sample = maxTokens; - } - - if (requestedMax !== maxTokens) { - req.log.info( - { requestedMax, configMax: apiMax, final: maxTokens }, - "Limiting user's requested max output tokens" - ); - } - } -}; - -function getMaxTokensFromRequest(req: Request) { - switch (req.outboundApi) { - case "anthropic": - return req.body?.max_tokens_to_sample; - case "openai": - return req.body?.max_tokens; - default: - throw new Error(`Unknown service: ${req.outboundApi}`); - } -} diff --git a/src/proxy/middleware/request/preprocess.ts b/src/proxy/middleware/request/preprocess.ts index 2915e7f..ed2db95 100644 --- a/src/proxy/middleware/request/preprocess.ts +++ b/src/proxy/middleware/request/preprocess.ts @@ -1,6 +1,11 @@ import { RequestHandler } from "express"; import { handleInternalError } from "../common"; -import { RequestPreprocessor, setApiFormat, transformOutboundPayload } from "."; +import { + RequestPreprocessor, + checkContextSize, + setApiFormat, + transformOutboundPayload, +} from "."; /** * Returns a middleware function that processes the request body into the given @@ -13,6 +18,7 @@ export const createPreprocessorMiddleware = ( const preprocessors: RequestPreprocessor[] = [ setApiFormat(apiFormat), transformOutboundPayload, + checkContextSize, ...(additionalPreprocessors ?? []), ]; diff --git a/src/proxy/middleware/request/transform-outbound-payload.ts b/src/proxy/middleware/request/transform-outbound-payload.ts index c96a745..7585702 100644 --- a/src/proxy/middleware/request/transform-outbound-payload.ts +++ b/src/proxy/middleware/request/transform-outbound-payload.ts @@ -1,8 +1,12 @@ import { Request } from "express"; import { z } from "zod"; +import { config } from "../../../config"; +import { OpenAIPromptMessage } from "../../../tokenization"; import { isCompletionRequest } from "../common"; import { RequestPreprocessor } from "."; -// import { countTokens } from "../../../tokenization"; + +const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic; +const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI; // https://console.anthropic.com/docs/api/reference#-v1-complete const AnthropicV1CompleteSchema = z.object({ @@ -11,7 +15,10 @@ const AnthropicV1CompleteSchema = z.object({ required_error: "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?", }), - max_tokens_to_sample: z.coerce.number(), + max_tokens_to_sample: z.coerce + .number() + .int() + .transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)), stop_sequences: z.array(z.string()).optional(), stream: z.boolean().optional().default(false), temperature: z.coerce.number().optional().default(1), @@ -32,6 +39,8 @@ const OpenAIV1ChatCompletionSchema = z.object({ { required_error: "No prompt found. Are you sending an Anthropic-formatted request to the OpenAI endpoint?", + invalid_type_error: + "Messages were not formatted correctly. 
Refer to the OpenAI Chat API documentation for more information.",
       }
     ),
   temperature: z.number().optional().default(1),
@@ -45,7 +54,12 @@
     .optional(),
   stream: z.boolean().optional().default(false),
   stop: z.union([z.string(), z.array(z.string())]).optional(),
-  max_tokens: z.coerce.number().optional(),
+  max_tokens: z.coerce
+    .number()
+    .int()
+    .optional()
+    .default(16)
+    .transform((v) => Math.min(v, OPENAI_OUTPUT_MAX)),
   frequency_penalty: z.number().optional().default(0),
   presence_penalty: z.number().optional().default(0),
   logit_bias: z.any().optional(),
@@ -63,7 +77,6 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
   }
 
   if (sameService) {
-    // Just validate, don't transform.
     const validator =
       req.outboundApi === "openai"
         ? OpenAIV1ChatCompletionSchema
@@ -76,11 +89,12 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
       );
       throw result.error;
     }
+    req.body = result.data;
     return;
   }
 
   if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
-    req.body = openaiToAnthropic(req.body, req);
+    req.body = await openaiToAnthropic(req.body, req);
     return;
   }
 
@@ -89,7 +103,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
   );
 };
 
-function openaiToAnthropic(body: any, req: Request) {
+async function openaiToAnthropic(body: any, req: Request) {
   const result = OpenAIV1ChatCompletionSchema.safeParse(body);
   if (!result.success) {
     req.log.error(
@@ -107,37 +121,7 @@ function openaiToAnthropic(body: any, req: Request) {
   req.headers["anthropic-version"] = "2023-01-01";
 
   const { messages, ...rest } = result.data;
-  const prompt =
-    result.data.messages
-      .map((m) => {
-        let role: string = m.role;
-        if (role === "assistant") {
-          role = "Assistant";
-        } else if (role === "system") {
-          role = "System";
-        } else if (role === "user") {
-          role = "Human";
-        }
-        // https://console.anthropic.com/docs/prompt-design
-        // `name` isn't supported by Anthropic but we can still try to use it.
-        return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
-          m.content
-        }`;
-      })
-      .join("") + "\n\nAssistant: ";
-
-  // No longer defaulting to `claude-v1.2` because it seems to be in the process
-  // of being deprecated. `claude-v1` is the new default.
-  // If you have keys that can still use `claude-v1.2`, you can set the
-  // CLAUDE_BIG_MODEL and CLAUDE_SMALL_MODEL environment variables in your .env
-  // file.
-
-  const CLAUDE_BIG = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k";
-  const CLAUDE_SMALL = process.env.CLAUDE_SMALL_MODEL || "claude-v1";
-
-  // TODO: Finish implementing tokenizer for more accurate model selection.
-  // This currently uses _character count_, not token count.
-  const model = prompt.length > 25000 ? CLAUDE_BIG : CLAUDE_SMALL;
+  const prompt = openAIMessagesToClaudePrompt(messages);
 
   let stops = rest.stop
     ? Array.isArray(rest.stop)
@@ -154,9 +138,35 @@ function openaiToAnthropic(body: any, req: Request) {
 
   return {
     ...rest,
-    model,
+    // Model may be overridden in `check-context-size.ts` to avoid having
+    // a circular dependency (`check-context-size.ts` needs an already-
+    // transformed request body to count tokens, but this function would like
+    // to know the count to select a model). 
+ model: process.env.CLAUDE_SMALL_MODEL || "claude-v1", prompt: prompt, max_tokens_to_sample: rest.max_tokens, stop_sequences: stops, }; } + +export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) { + return ( + messages + .map((m) => { + let role: string = m.role; + if (role === "assistant") { + role = "Assistant"; + } else if (role === "system") { + role = "System"; + } else if (role === "user") { + role = "Human"; + } + // https://console.anthropic.com/docs/prompt-design + // `name` isn't supported by Anthropic but we can still try to use it. + return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${ + m.content + }`; + }) + .join("") + "\n\nAssistant:" + ); +} diff --git a/src/proxy/openai.ts b/src/proxy/openai.ts index 1c8832e..0f33122 100644 --- a/src/proxy/openai.ts +++ b/src/proxy/openai.ts @@ -14,7 +14,6 @@ import { finalizeBody, languageFilter, limitCompletions, - limitOutputTokens, removeOriginHeaders, } from "./middleware/request"; import { @@ -93,7 +92,6 @@ const rewriteRequest = ( const rewriterPipeline = [ addKey, languageFilter, - limitOutputTokens, limitCompletions, blockZoomerOrigins, removeOriginHeaders, @@ -125,6 +123,11 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async ( body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`; } + // TODO: Remove once tokenization is stable + if (req.debug) { + body.proxy_tokenizer_debug_info = req.debug; + } + res.status(200).json(body); }; diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index ed4f982..bf1f867 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -10,10 +10,18 @@ import { kobold } from "./kobold"; import { openai } from "./openai"; import { anthropic } from "./anthropic"; -const router = express.Router(); - -router.use(gatekeeper); -router.use("/kobold", kobold); -router.use("/openai", openai); -router.use("/anthropic", anthropic); -export { router as proxyRouter }; +const proxyRouter = express.Router(); +proxyRouter.use( + express.json({ limit: "1536kb" }), + express.urlencoded({ extended: true, limit: "1536kb" }) +); +proxyRouter.use(gatekeeper); +proxyRouter.use((req, _res, next) => { + req.startTime = Date.now(); + req.retryCount = 0; + next(); +}); +proxyRouter.use("/kobold", kobold); +proxyRouter.use("/openai", openai); +proxyRouter.use("/anthropic", anthropic); +export { proxyRouter as proxyRouter }; diff --git a/src/server.ts b/src/server.ts index be0aa13..79d9cde 100644 --- a/src/server.ts +++ b/src/server.ts @@ -12,6 +12,7 @@ import { handleInfoPage } from "./info-page"; import { logQueue } from "./prompt-logging"; import { start as startRequestQueue } from "./proxy/queue"; import { init as initUserStore } from "./proxy/auth/user-store"; +import { init as initTokenizers } from "./tokenization"; import { checkOrigin } from "./proxy/check-origin"; const PORT = config.port; @@ -47,25 +48,16 @@ app.use( }) ); -app.get("/health", (_req, res) => res.sendStatus(200)); -app.use((req, _res, next) => { - req.startTime = Date.now(); - req.retryCount = 0; - next(); -}); -app.use(cors()); -app.use( - express.json({ limit: "10mb" }), - express.urlencoded({ extended: true, limit: "10mb" }) -); - // TODO: Detect (or support manual configuration of) whether the app is behind // a load balancer/reverse proxy, which is necessary to determine request IP // addresses correctly. 
app.set("trust proxy", true); -// routes +app.get("/health", (_req, res) => res.sendStatus(200)); +app.use(cors()); app.use(checkOrigin); + +// routes app.get("/", handleInfoPage); app.use("/admin", adminRouter); app.use("/proxy", proxyRouter); @@ -99,6 +91,8 @@ async function start() { keyPool.init(); + await initTokenizers(); + if (config.gatekeeper === "user_token") { await initUserStore(); } diff --git a/src/tokenization/claude.ts b/src/tokenization/claude.ts new file mode 100644 index 0000000..309ceee --- /dev/null +++ b/src/tokenization/claude.ts @@ -0,0 +1,34 @@ +// For now this is just using the GPT vocabulary, even though Claude has a +// different one. Token counts won't be perfect so this just provides +// a rough estimate. +// +// TODO: use huggingface tokenizers instead of openai's tiktoken library since +// that should support the vocabulary file Anthropic provides. + +import { Tiktoken } from "tiktoken/lite"; +import cl100k_base from "tiktoken/encoders/cl100k_base.json"; + +let encoder: Tiktoken; + +export function init() { + encoder = new Tiktoken( + cl100k_base.bpe_ranks, + cl100k_base.special_tokens, + cl100k_base.pat_str + ); + return true; +} + +export function getTokenCount(prompt: string, _model: string) { + if (prompt.length > 250000) { + return { + tokenizer: "tiktoken (prompt length limit exceeded)", + token_count: 100000, + }; + } + + return { + tokenizer: "tiktoken (cl100k_base)", + token_count: encoder.encode(prompt).length, + }; +} diff --git a/src/tokenization/index.ts b/src/tokenization/index.ts new file mode 100644 index 0000000..df7dc7d --- /dev/null +++ b/src/tokenization/index.ts @@ -0,0 +1,2 @@ +export { OpenAIPromptMessage } from "./openai"; +export { init, countTokens } from "./tokenizer"; diff --git a/src/tokenization/openai.ts b/src/tokenization/openai.ts new file mode 100644 index 0000000..fe83b35 --- /dev/null +++ b/src/tokenization/openai.ts @@ -0,0 +1,57 @@ +import { Tiktoken } from "tiktoken/lite"; +import cl100k_base from "tiktoken/encoders/cl100k_base.json"; + +let encoder: Tiktoken; + +export function init() { + encoder = new Tiktoken( + cl100k_base.bpe_ranks, + cl100k_base.special_tokens, + cl100k_base.pat_str + ); + return true; +} + +// Tested against: +// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + +export function getTokenCount(messages: any[], model: string) { + const gpt4 = model.startsWith("gpt-4"); + + const tokensPerMessage = gpt4 ? 3 : 4; + const tokensPerName = gpt4 ? 
1 : -1; // turbo omits role if name is present
+
+  let numTokens = 0;
+
+  for (const message of messages) {
+    numTokens += tokensPerMessage;
+    for (const key of Object.keys(message)) {
+      {
+        const value = message[key];
+        // Break if we get a huge message or exceed the token limit to prevent DoS
+        // 100k tokens allows for future 100k GPT-4 models and 250k characters is
+        // just a sanity check
+        if (value.length > 250000 || numTokens > 100000) {
+          numTokens = 100000;
+          return {
+            tokenizer: "tiktoken (prompt length limit exceeded)",
+            token_count: numTokens,
+          };
+        }
+
+        numTokens += encoder.encode(message[key]).length;
+        if (key === "name") {
+          numTokens += tokensPerName;
+        }
+      }
+    }
+  }
+  numTokens += 3; // every reply is primed with <|start|>assistant<|message|>
+  return { tokenizer: "tiktoken", token_count: numTokens };
+}
+
+export type OpenAIPromptMessage = {
+  name?: string;
+  content: string;
+  role: string;
+};
diff --git a/src/tokenization/tokenizer.ts b/src/tokenization/tokenizer.ts
new file mode 100644
index 0000000..3f04dc9
--- /dev/null
+++ b/src/tokenization/tokenizer.ts
@@ -0,0 +1,58 @@
+import { Request } from "express";
+import { config } from "../config";
+import {
+  init as initClaude,
+  getTokenCount as getClaudeTokenCount,
+} from "./claude";
+import {
+  init as initOpenAi,
+  getTokenCount as getOpenAITokenCount,
+  OpenAIPromptMessage,
+} from "./openai";
+
+export async function init() {
+  if (config.anthropicKey) {
+    initClaude();
+  }
+  if (config.openaiKey) {
+    initOpenAi();
+  }
+}
+
+type TokenCountResult = {
+  token_count: number;
+  tokenizer: string;
+  tokenization_duration_ms: number;
+};
+type TokenCountRequest = {
+  req: Request;
+} & (
+  | { prompt: string; service: "anthropic" }
+  | { prompt: OpenAIPromptMessage[]; service: "openai" }
+);
+export async function countTokens({
+  req,
+  service,
+  prompt,
+}: TokenCountRequest): Promise<TokenCountResult> {
+  const time = process.hrtime();
+  switch (service) {
+    case "anthropic":
+      return {
+        ...getClaudeTokenCount(prompt, req.body.model),
+        tokenization_duration_ms: getElapsedMs(time),
+      };
+    case "openai":
+      return {
+        ...getOpenAITokenCount(prompt, req.body.model),
+        tokenization_duration_ms: getElapsedMs(time),
+      };
+    default:
+      throw new Error(`Unknown service: ${service}`);
+  }
+}
+
+function getElapsedMs(time: [number, number]) {
+  const diff = process.hrtime(time);
+  return diff[0] * 1000 + diff[1] / 1e6;
+}
diff --git a/src/types/custom.d.ts b/src/types/custom.d.ts
index e81bd1b..6708482 100644
--- a/src/types/custom.d.ts
+++ b/src/types/custom.d.ts
@@ -18,6 +18,10 @@ declare global {
       onAborted?: () => void;
       proceed: () => void;
       heartbeatInterval?: NodeJS.Timeout;
+      promptTokens?: number;
+      outputTokens?: number;
+      // TODO: remove later
+      debug: Record<string, any>;
     }
   }
 }
diff --git a/tsconfig.json b/tsconfig.json
index 6789a3a..13a4926 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -9,7 +9,9 @@
     "skipLibCheck": true,
     "skipDefaultLibCheck": true,
     "outDir": "build",
-    "sourceMap": true
+    "sourceMap": true,
+    "resolveJsonModule": true,
+    "useUnknownInCatchVariables": false
   },
   "include": ["src"],
   "exclude": ["node_modules"],