diff --git a/src/config.ts b/src/config.ts
index c6d3025..4ad9f0f 100644
--- a/src/config.ts
+++ b/src/config.ts
@@ -286,6 +286,14 @@ type Config = {
   googleSheetsSpreadsheetId?: string;
   /** Whether to periodically check keys for usage and validity. */
   checkKeys: boolean;
+  /**
+   * Whether to use remote API token counting endpoints for Anthropic, AWS
+   * Bedrock, and GCP Vertex AI. When enabled, the proxy will use the provider's
+   * token counting API to get accurate token counts for prompts (including
+   * images). Output tokens are always counted from API responses when available.
+   * Falls back to local tokenization if remote counting fails.
+   */
+  useRemoteTokenCounting: boolean;
   /** Whether to publicly show total token costs on the info page. */
   showTokenCosts: boolean;
   /**
@@ -567,6 +575,7 @@ export const config: Config = {
   ),
   logLevel: getEnvWithDefault("LOG_LEVEL", "info"),
   checkKeys: getEnvWithDefault("CHECK_KEYS", !isDev),
+  useRemoteTokenCounting: getEnvWithDefault("USE_REMOTE_TOKEN_COUNTING", true),
   showTokenCosts: getEnvWithDefault("SHOW_TOKEN_COSTS", false),
   allowAwsLogging: getEnvWithDefault("ALLOW_AWS_LOGGING", false),
   promptLogging: getEnvWithDefault("PROMPT_LOGGING", false),
diff --git a/src/proxy/anthropic.ts b/src/proxy/anthropic.ts
index 6ac8b88..c1e037d 100644
--- a/src/proxy/anthropic.ts
+++ b/src/proxy/anthropic.ts
@@ -196,8 +196,11 @@ function setAnthropicBetaHeader(req: Request) {
     betaHeaders.push("extended-cache-ttl-2025-04-11");
   }
 
-  // Add 1M context beta header for Claude Sonnet 4 if context > 200k tokens
-  if (model?.includes("claude-sonnet-4") && req.promptTokens && req.outputTokens) {
+  // Add 1M context beta header for Claude Sonnet 4/Opus 4 if context > 200k tokens
+  const supportsBigContext =
+    model?.includes("claude-sonnet-4") ||
+    model?.includes("claude-opus-4");
+  if (supportsBigContext && req.promptTokens && req.outputTokens) {
     const contextTokens = req.promptTokens + req.outputTokens;
     if (contextTokens > 200000) {
       betaHeaders.push("context-1m-2025-08-07");
@@ -211,8 +214,8 @@
 }
 
 /**
- * Adds web search tool for Claude-3.5 and Claude-3.7 models when enable_web_search is true
- *
+ * Adds web search tool for Claude-3.5, Claude-3.7, Claude-4, and Claude-4.5 models when enable_web_search is true
+ *
  * Supports all optional parameters documented in the Claude API:
  * - max_uses: Limit the number of searches per request
  * - allowed_domains: Only include results from these domains
@@ -323,7 +326,9 @@ const textToChatPreprocessor = createPreprocessorMiddleware(
  */
 const preprocessAnthropicTextRequest: RequestHandler = (req, res, next) => {
   const model = req.body.model;
-  const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
+  const isClaude4Model =
+    model?.includes("claude-sonnet-4") ||
+    model?.includes("claude-opus-4");
   if (model?.startsWith("claude-3") || isClaude4Model) {
     textToChatPreprocessor(req, res, next);
   } else {
@@ -356,7 +361,9 @@ const oaiToChatPreprocessor = createPreprocessorMiddleware(
 const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
   maybeReassignModel(req);
   const model = req.body.model;
-  const isClaude4 = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
+  const isClaude4 =
+    model?.includes("claude-sonnet-4") ||
+    model?.includes("claude-opus-4");
   if (model?.includes("claude-3") || isClaude4) {
     oaiToChatPreprocessor(req, res, next);
   } else {
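Note: the claude-sonnet-4/claude-opus-4 predicate above is repeated verbatim at several call sites in this patch (twice more in aws-claude.ts below). A shared helper would keep the copies from drifting; a minimal sketch, assuming a new export in src/shared/models.ts that is not part of this diff:

    // Hypothetical shared predicate mirroring the inline checks in this patch.
    export function isClaude4Family(model?: string): boolean {
      return (
        !!model &&
        (model.includes("claude-sonnet-4") || model.includes("claude-opus-4"))
      );
    }
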
diff --git a/src/proxy/aws-claude.ts b/src/proxy/aws-claude.ts
index f8caf7c..f2e5c12 100644
--- a/src/proxy/aws-claude.ts
+++ b/src/proxy/aws-claude.ts
@@ -102,7 +102,9 @@ const textToChatPreprocessor = createPreprocessorMiddleware(
  */
 const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
   const model = req.body.model;
-  const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
+  const isClaude4Model =
+    model?.includes("claude-sonnet-4") ||
+    model?.includes("claude-opus-4");
   if (model?.includes("claude-3") || isClaude4Model) {
     textToChatPreprocessor(req, res, next);
   } else {
@@ -126,7 +128,9 @@ const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
  */
 const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
   const model = req.body.model;
-  const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
+  const isClaude4Model =
+    model?.includes("claude-sonnet-4") ||
+    model?.includes("claude-opus-4");
   if (model?.includes("claude-3") || isClaude4Model) {
     oaiToAwsChatPreprocessor(req, res, next);
   } else {
diff --git a/src/proxy/middleware/request/mutators/add-key.ts b/src/proxy/middleware/request/mutators/add-key.ts
index 08c710f..8719d18 100644
--- a/src/proxy/middleware/request/mutators/add-key.ts
+++ b/src/proxy/middleware/request/mutators/add-key.ts
@@ -32,8 +32,9 @@ export const addKey: ProxyReqMutator = (manager) => {
   if (inboundApi === outboundApi) {
     // Pass streaming information for GPT-5 models that require verified keys for streaming
+    // Pass request body for cache-aware key selection (Anthropic, AWS, GCP)
     const isStreaming = body.stream === true;
-    assignedKey = keyPool.get(body.model, service, needsMultimodal, isStreaming);
+    assignedKey = keyPool.get(body.model, service, needsMultimodal, isStreaming, body);
   } else {
     switch (outboundApi) {
       // If we are translating between API formats we may need to select a model
@@ -45,7 +46,7 @@ export const addKey: ProxyReqMutator = (manager) => {
       case "mistral-ai":
       case "mistral-text":
       case "google-ai":
-        assignedKey = keyPool.get(body.model, service);
+        assignedKey = keyPool.get(body.model, service, undefined, undefined, body);
         break;
       case "openai-text":
         assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
diff --git a/src/proxy/middleware/request/mutators/sign-aws-request.ts b/src/proxy/middleware/request/mutators/sign-aws-request.ts
index 86c3f2c..ae15541 100644
--- a/src/proxy/middleware/request/mutators/sign-aws-request.ts
+++ b/src/proxy/middleware/request/mutators/sign-aws-request.ts
@@ -24,7 +24,8 @@ const AMZ_HOST =
 export const signAwsRequest: ProxyReqMutator = async (manager) => {
   const req = manager.request;
   const { model, stream } = req.body;
-  const key = keyPool.get(model, "aws") as AwsBedrockKey;
+  // Pass request body to enable cache-aware key selection
+  const key = keyPool.get(model, "aws", undefined, stream, req.body) as AwsBedrockKey;
   manager.setKey(key);
 
   let system = req.body.system ?? "";
@@ -50,6 +51,8 @@ export const signAwsRequest: ProxyReqMutator = async (manager) => {
 
   // Uses the AWS SDK to sign a request, then modifies our HPM proxy request
   // with the headers generated by the SDK.
+  // Note: AWS Bedrock uses body parameters (anthropic_version) instead of headers
+  // for versioning, so we don't include anthropic-beta or anthropic-version headers.
   const newRequest = new HttpRequest({
     method: "POST",
     protocol: "https:",
@@ -130,6 +133,9 @@ function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
       .parse(req.body);
       break;
     case "anthropic-chat":
+      // Preserve anthropic_version if user provided it (for beta features)
+      const userAnthropicVersion = (req.body as any).anthropic_version;
+
       strippedParams = AnthropicV1MessagesSchema.pick({
         messages: true,
         system: true,
@@ -140,11 +146,14 @@ function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
         top_p: true,
         tools: true,
         tool_choice: true,
-        thinking: true
+        thinking: true,
+        cache_control: true
       })
       .strip()
       .parse(req.body);
-      strippedParams.anthropic_version = "bedrock-2023-05-31";
+
+      // Use user-provided version or default to bedrock-2023-05-31
+      strippedParams.anthropic_version = userAnthropicVersion || "bedrock-2023-05-31";
       break;
     case "mistral-ai":
       strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
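Note: the signer above (and the Vertex signer below) now forwards a client-supplied anthropic_version instead of always hard-coding the provider default. A hedged example of a request body that exercises the override; the model ID and version string are illustrative only, and "bedrock-2023-05-31" still applies when the field is omitted:

    const body = {
      model: "anthropic.claude-sonnet-4-20250514-v1:0", // example ID
      anthropic_version: "bedrock-2023-05-31",          // forwarded as-is when present
      max_tokens: 1024,
      messages: [{ role: "user", content: "Hello" }],
    };
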
diff --git a/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts b/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts
index 3820591..9f43fe0 100644
--- a/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts
+++ b/src/proxy/middleware/request/mutators/sign-vertex-ai-request.ts
@@ -20,7 +20,7 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
   }
 
   const { model } = req.body;
-  const key: GcpKey = keyPool.get(model, "gcp") as GcpKey;
+  const key: GcpKey = keyPool.get(model, "gcp", undefined, undefined, req.body) as GcpKey;
 
   if (!key.accessToken || Date.now() > key.accessTokenExpiresAt) {
     const [token, durationSec] = await refreshGcpAccessToken(key);
@@ -38,6 +38,9 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
   // TODO: This should happen in transform-outbound-payload.ts
   // TODO: Support tools
+  // Preserve anthropic_version if user provided it (for beta features)
+  const userAnthropicVersion = (req.body as any).anthropic_version;
+
   let strippedParams: Record<string, unknown>;
   strippedParams = AnthropicV1MessagesSchema.pick({
     messages: true,
@@ -50,11 +53,14 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
     stream: true,
     tools: true,
     tool_choice: true,
-    thinking: true
+    thinking: true,
+    cache_control: true
   })
   .strip()
   .parse(req.body);
-  strippedParams.anthropic_version = "vertex-2023-10-16";
+
+  // Use user-provided version or default to vertex-2023-10-16
+  strippedParams.anthropic_version = userAnthropicVersion || "vertex-2023-10-16";
 
   const credential = await getCredentialsFromGcpKey(key);
 
@@ -63,6 +69,8 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
   // stream adapter selects the correct transformer.
   manager.setHeader("anthropic-version", "2023-06-01");
 
+  // GCP Vertex AI uses body parameter anthropic_version (set above) for versioning,
+  // not anthropic-beta header. Beta features are enabled through body params.
   manager.setSignedRequest({
     method: "POST",
     protocol: "https:",
diff --git a/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts b/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts
index c675f82..a784f3a 100644
--- a/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts
+++ b/src/proxy/middleware/request/preprocessors/count-prompt-tokens.ts
@@ -7,21 +7,50 @@ import {
   AnthropicChatMessage,
   flattenAnthropicMessages,
 } from "../../../../shared/api-schemas/anthropic";
-import {
-  MistralAIChatMessage,
+import {
+  MistralAIChatMessage,
   ContentItem,
-  isMistralVisionModel
+  isMistralVisionModel
 } from "../../../../shared/api-schemas/mistral-ai";
 import { isGrokVisionModel } from "../../../../shared/api-schemas/xai";
+import { keyPool } from "../../../../shared/key-management";
+import { config } from "../../../../config";
 
 /**
  * Given a request with an already-transformed body, counts the number of
  * tokens and assigns the count to the request.
+ *
+ * If remote token counting is enabled, we temporarily get a key from the pool
+ * to use the remote API, then clear it. The actual key assignment happens
+ * later in the mutators after the request is dequeued.
  */
 export const countPromptTokens: RequestPreprocessor = async (req) => {
   const service = req.outboundApi;
   let result;
 
+  // For remote token counting, temporarily get a key from the pool.
+  // We don't permanently assign it - that happens in mutators after dequeue.
+  // IMPORTANT: Don't pass req.body here to avoid premature cache fingerprinting.
+  // Cache fingerprinting should only happen during actual key assignment in mutators,
+  // after all preprocessing (including model name transformations) is complete.
+  let tempKey;
+  const hadKey = !!req.key;
+  if (config.useRemoteTokenCounting && !req.key) {
+    try {
+      tempKey = keyPool.get(req.body.model, req.service, undefined, undefined);
+      req.key = tempKey; // Temporarily assign for token counting
+      req.log.debug(
+        { keyHash: tempKey.hash, service: req.service },
+        "Temporarily assigned key for remote token counting"
+      );
+    } catch (error) {
+      req.log.debug(
+        { error: (error as Error).message },
+        "Could not get key for remote token counting, will use local tokenizer"
+      );
+    }
+  }
+
   switch (service) {
     case "openai": {
       req.outputTokens = req.body.max_completion_tokens || req.body.max_tokens;
@@ -129,4 +158,14 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
   req.log.debug({ result: result }, "Counted prompt tokens.");
   req.tokenizerInfo = req.tokenizerInfo ?? {};
   req.tokenizerInfo = { ...req.tokenizerInfo, ...result };
+
+  // Clear the temporary key if we assigned one for token counting.
+  // The real key will be assigned later in mutators after dequeue.
+  if (tempKey && !hadKey) {
+    delete req.key;
+    req.log.debug(
+      { keyHash: tempKey.hash },
+      "Cleared temporary key after token counting"
+    );
+  }
 };
diff --git a/src/proxy/middleware/request/preprocessors/validate-context-size.ts b/src/proxy/middleware/request/preprocessors/validate-context-size.ts
index 4c12938..f9ebb10 100644
--- a/src/proxy/middleware/request/preprocessors/validate-context-size.ts
+++ b/src/proxy/middleware/request/preprocessors/validate-context-size.ts
@@ -121,13 +121,13 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
     modelMax = 200000;
   } else if (model.match(/^claude-3/)) {
     modelMax = 200000;
-  } else if (model.match(/^claude-(?:sonnet|opus)-4/)) {
+  } else if (model.match(/^claude-(?:sonnet|opus)-4(?:-5)?/)) {
     modelMax = 1000000;
   } else if (model.match(/^gemini-/)) {
     modelMax = 1024000;
   } else if (model.match(/^anthropic\.claude-3/)) {
     modelMax = 200000;
-  } else if (model.match(/^anthropic\.claude-(?:sonnet|opus)-4/)) {
+  } else if (model.match(/^anthropic\.claude-(?:sonnet|opus)-4(?:-5)?/)) {
     modelMax = 1000000;
   } else if (model.match(/^anthropic\.claude-v2:\d/)) {
     modelMax = 200000;
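Note: countPromptTokens assigns a key only for the duration of the remote call and then deletes it, but the clearing code at the bottom of the preprocessor does not run if the tokenizer throws mid-switch. A try/finally wrapper would close that gap; a sketch under that assumption (withTemporaryKey is illustrative and not part of this diff):

    async function withTemporaryKey<T>(
      req: { key?: { hash: string } },
      getKey: () => { hash: string },
      fn: () => Promise<T>
    ): Promise<T> {
      const hadKey = !!req.key;
      if (!hadKey) req.key = getKey();
      try {
        return await fn();
      } finally {
        if (!hadKey) delete req.key; // cleared even if fn() throws
      }
    }
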
diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts
index aed2cf8..e07b154 100644
--- a/src/proxy/middleware/response/index.ts
+++ b/src/proxy/middleware/response/index.ts
@@ -1094,15 +1094,130 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
   try {
     assertJsonResponse(body);
     const service = req.outboundApi;
-    const completion = getCompletionFromBody(req, body);
-    const tokens = await countTokens({ req, completion, service });
-
-    if (req.service === "openai" || req.service === "azure" || req.service === "deepseek" || req.service === "glm" || req.service === "cohere" || req.service === "qwen") {
-      // O1 consumes (a significant amount of) invisible tokens for the chain-
-      // of-thought reasoning. We have no way to count these other than to check
-      // the response body.
-      tokens.reasoning_tokens =
-        body.usage?.completion_tokens_details?.reasoning_tokens;
+
+    // Try to get token counts from the API response first
+    let tokens: { token_count: number; tokenizer: string; reasoning_tokens?: number } | null = null;
+
+    // Anthropic API returns usage data in the response
+    if (req.service === "anthropic" && service === "anthropic-chat" && body.usage) {
+      tokens = {
+        token_count: body.usage.output_tokens || 0,
+        tokenizer: "anthropic-api",
+      };
+      req.log.debug(
+        { service, outputTokens: tokens.token_count, usage: body.usage },
+        "Got output token count from Anthropic API response"
+      );
+
+      // Sanity check: if request had cache_control, expect cache metrics
+      if (req.body.system || req.body.tools || req.body.messages) {
+        const hasCacheControl = checkForCacheControl(req.body);
+        if (hasCacheControl) {
+          const cacheRead = body.usage.cache_read_input_tokens || 0;
+          const cacheCreation = body.usage.cache_creation_input_tokens || 0;
+          if (cacheRead === 0 && cacheCreation === 0) {
+            req.log.error(
+              { keyHash: req.key?.hash, usage: body.usage },
+              "CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from Anthropic API"
+            );
+          }
+        }
+      }
+    }
+    // AWS Bedrock returns usage data in the response (same format as Anthropic)
+    else if (req.service === "aws" && service === "anthropic-chat" && body.usage) {
+      tokens = {
+        token_count: body.usage.output_tokens || 0,
+        tokenizer: "aws-bedrock-api",
+      };
+      req.log.debug(
+        { service, outputTokens: tokens.token_count, usage: body.usage },
+        "Got output token count from AWS Bedrock API response"
+      );
+
+      // Sanity check: if request had cache_control, expect cache metrics
+      if (req.body.system || req.body.tools || req.body.messages) {
+        const hasCacheControl = checkForCacheControl(req.body);
+        if (hasCacheControl) {
+          const cacheRead = body.usage.cache_read_input_tokens || 0;
+          const cacheCreation = body.usage.cache_creation_input_tokens || 0;
+          if (cacheRead === 0 && cacheCreation === 0) {
+            req.log.error(
+              { keyHash: req.key?.hash, usage: body.usage },
+              "CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from AWS Bedrock API"
+            );
+          }
+        }
+      }
+    }
+    // GCP Vertex AI returns usage data in the response.
+    // For Anthropic models, GCP returns Anthropic format (usage.output_tokens).
+    // For Gemini models, GCP returns GCP format (usageMetadata.candidatesTokenCount).
+    else if (req.service === "gcp") {
+      if (service === "anthropic-chat" && body.usage?.output_tokens) {
+        tokens = {
+          token_count: body.usage.output_tokens || 0,
+          tokenizer: "gcp-anthropic-api",
+        };
+        req.log.debug(
+          { service, outputTokens: tokens.token_count, usage: body.usage },
+          "Got output token count from GCP Vertex AI (Anthropic format)"
+        );
+
+        // Sanity check: if request had cache_control, expect cache metrics
+        if (req.body.system || req.body.tools || req.body.messages) {
+          const hasCacheControl = checkForCacheControl(req.body);
+          if (hasCacheControl) {
+            const cacheRead = body.usage.cache_read_input_tokens || 0;
+            const cacheCreation = body.usage.cache_creation_input_tokens || 0;
+            if (cacheRead === 0 && cacheCreation === 0) {
+              req.log.error(
+                { keyHash: req.key?.hash, usage: body.usage },
+                "CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from GCP Vertex AI API"
+              );
+            }
+          }
+        }
+      } else if (body.usageMetadata) {
+        tokens = {
+          token_count: body.usageMetadata.candidatesTokenCount || 0,
+          tokenizer: "gcp-vertex-api",
+        };
+        req.log.debug(
+          { service, outputTokens: tokens.token_count, usageMetadata: body.usageMetadata },
+          "Got output token count from GCP Vertex AI (Gemini format)"
+        );
+      }
+    }
+    // OpenAI and similar services return usage data
+    else if (body.usage?.completion_tokens) {
+      tokens = {
+        token_count: body.usage.completion_tokens,
+        tokenizer: "api-usage-data",
+      };
+
+      if (req.service === "openai" || req.service === "azure" || req.service === "deepseek" || req.service === "glm" || req.service === "cohere" || req.service === "qwen") {
+        // O1 consumes (a significant amount of) invisible tokens for the chain-
+        // of-thought reasoning. We have no way to count these other than to check
+        // the response body.
+        tokens.reasoning_tokens =
+          body.usage?.completion_tokens_details?.reasoning_tokens;
+      }
+
+      req.log.debug(
+        { service, outputTokens: tokens.token_count, usage: body.usage },
+        "Got output token count from API usage data"
+      );
+    }
+
+    // Fall back to local tokenization if no usage data is available
+    if (!tokens) {
+      const completion = getCompletionFromBody(req, body);
+      tokens = await countTokens({ req, completion, service });
+      req.log.debug(
+        { service, outputTokens: tokens.token_count },
+        "Counted output tokens locally (no API usage data)"
+      );
     }
 
     req.log.debug(
@@ -1193,6 +1308,35 @@ function getAwsErrorType(header: string | string[] | undefined) {
   return val || String(header);
 }
 
+function checkForCacheControl(body: any): boolean {
+  // Check tools
+  if (Array.isArray(body.tools)) {
+    for (const tool of body.tools) {
+      if (tool.cache_control) return true;
+    }
+  }
+
+  // Check system blocks
+  if (Array.isArray(body.system)) {
+    for (const block of body.system) {
+      if (block.cache_control) return true;
+    }
+  }
+
+  // Check message content blocks
+  if (Array.isArray(body.messages)) {
+    for (const message of body.messages) {
+      if (Array.isArray(message.content)) {
+        for (const block of message.content) {
+          if (block.cache_control) return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
 function assertJsonResponse(body: any): asserts body is Record<string, any> {
   if (typeof body !== "object") {
     throw new Error(`Expected response to be an object, got ${typeof body}`);
diff --git a/src/proxy/middleware/response/streaming/aggregators/anthropic-chat.ts b/src/proxy/middleware/response/streaming/aggregators/anthropic-chat.ts
index a0f67c0..07ca45c 100644
--- a/src/proxy/middleware/response/streaming/aggregators/anthropic-chat.ts
+++ b/src/proxy/middleware/response/streaming/aggregators/anthropic-chat.ts
@@ -8,7 +8,12 @@ export type AnthropicChatCompletionResponse = {
   model: string;
   stop_reason: string | null;
   stop_sequence: string | null;
-  usage: { input_tokens: number; output_tokens: number };
+  usage: {
+    input_tokens: number;
+    output_tokens: number;
+    cache_creation_input_tokens?: number;
+    cache_read_input_tokens?: number;
+  };
 };
 
 /**
@@ -43,6 +48,24 @@ export function mergeEventsForAnthropicChat(
       acc.content[0].text += event.choices[0].delta.content;
     }
 
+    // OpenAI events may include usage data (extended by our transformer).
+    // Usage tokens should be set to the latest value, not accumulated
+    // (they represent total counts, not deltas).
+    if ((event as any).usage) {
+      if ((event as any).usage.input_tokens !== undefined) {
+        acc.usage.input_tokens = (event as any).usage.input_tokens;
+      }
+      if ((event as any).usage.output_tokens !== undefined) {
+        acc.usage.output_tokens = (event as any).usage.output_tokens;
+      }
+      if ((event as any).usage.cache_creation_input_tokens !== undefined) {
+        acc.usage.cache_creation_input_tokens = (event as any).usage.cache_creation_input_tokens;
+      }
+      if ((event as any).usage.cache_read_input_tokens !== undefined) {
+        acc.usage.cache_read_input_tokens = (event as any).usage.cache_read_input_tokens;
+      }
+    }
+
     return acc;
   }, merged);
   return merged;
diff --git a/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts b/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts
index f1a1bd4..cd2c6fa 100644
--- a/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts
+++ b/src/proxy/middleware/response/streaming/aggregators/openai-chat.ts
@@ -10,6 +10,11 @@ export type OpenAiChatCompletionResponse = {
     finish_reason: string | null;
     index: number;
   }[];
+  usage?: {
+    prompt_tokens?: number;
+    completion_tokens?: number;
+    total_tokens?: number;
+  };
 };
 
 /**
@@ -52,6 +57,17 @@ export function mergeEventsForOpenAIChat(
       acc.choices[0].message.content += event.choices[0].delta.content;
     }
 
+    // Capture usage data from events (OpenAI may send this in the final event);
+    // later events overwrite earlier values rather than accumulating.
+    if ((event as any).usage) {
+      if (!acc.usage) {
+        acc.usage = {};
+      }
+      const usage = (event as any).usage;
+      if (usage.prompt_tokens) acc.usage.prompt_tokens = usage.prompt_tokens;
+      if (usage.completion_tokens) acc.usage.completion_tokens = usage.completion_tokens;
+      if (usage.total_tokens) acc.usage.total_tokens = usage.total_tokens;
+    }
+
     return acc;
   }, merged);
   return merged;
diff --git a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
index 8c24458..2f79571 100644
--- a/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
+++ b/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
@@ -50,12 +50,32 @@ export class SSEStreamAdapter extends Transform {
         const event = Buffer.from(bytes, "base64").toString("utf8");
         const eventObj = JSON.parse(event);
 
+        // AWS Bedrock includes usage metrics in the event stream headers.
+        // Extract and attach them to the event object for downstream processing.
+        const invocationMetrics = headers["amazon-bedrock-invocationMetrics"];
+        if (invocationMetrics?.value) {
+          try {
+            const metricsStr = typeof invocationMetrics.value === "string"
+              ? invocationMetrics.value
+              : JSON.stringify(invocationMetrics.value);
+            const metricsObj = JSON.parse(metricsStr);
+            eventObj["amazon-bedrock-invocationMetrics"] = metricsObj;
+          } catch (e) {
+            this.log.warn(
+              { invocationMetrics: invocationMetrics.value },
+              "Failed to parse AWS invocationMetrics"
+            );
+          }
+        }
+
+        const eventWithMetrics = JSON.stringify(eventObj);
+
         if ("completion" in eventObj) {
-          return ["event: completion", `data: ${event}`].join(`\n`);
+          return ["event: completion", `data: ${eventWithMetrics}`].join(`\n`);
         } else if (eventObj.type) {
-          return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
+          return [`event: ${eventObj.type}`, `data: ${eventWithMetrics}`].join(`\n`);
         } else {
-          return `data: ${event}`;
+          return `data: ${eventWithMetrics}`;
         }
       }
       // noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected
diff --git a/src/proxy/middleware/response/streaming/transformers/anthropic-chat-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/anthropic-chat-to-openai.ts
index 583477a..3343143 100644
--- a/src/proxy/middleware/response/streaming/transformers/anthropic-chat-to-openai.ts
+++ b/src/proxy/middleware/response/streaming/transformers/anthropic-chat-to-openai.ts
@@ -22,8 +22,70 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
     return { position: -1 };
   }
 
+  // Try to extract usage data from message_start and message_delta events.
+  // Also check for AWS Bedrock invocationMetrics.
+  let usageData: {
+    input_tokens?: number;
+    output_tokens?: number;
+    cache_creation_input_tokens?: number;
+    cache_read_input_tokens?: number;
+  } | undefined;
+  try {
+    const parsed = JSON.parse(rawEvent.data);
+    if (parsed.type === "message_start" && parsed.message?.usage) {
+      usageData = {
+        input_tokens: parsed.message.usage.input_tokens,
+        output_tokens: parsed.message.usage.output_tokens,
+        cache_creation_input_tokens: parsed.message.usage.cache_creation_input_tokens,
+        cache_read_input_tokens: parsed.message.usage.cache_read_input_tokens,
+      };
+    } else if (parsed.type === "message_delta" && parsed.delta?.usage) {
+      usageData = {
+        output_tokens: parsed.delta.usage.output_tokens,
+        cache_creation_input_tokens: parsed.delta.usage.cache_creation_input_tokens,
+        cache_read_input_tokens: parsed.delta.usage.cache_read_input_tokens,
+      };
+    }
+    // AWS Bedrock includes usage in amazon-bedrock-invocationMetrics.
+    // AWS uses camelCase field names (cacheReadInputTokenCount, cacheWriteInputTokenCount).
+    else if (parsed["amazon-bedrock-invocationMetrics"]) {
+      const metrics = parsed["amazon-bedrock-invocationMetrics"];
+      usageData = {
+        input_tokens: metrics.inputTokenCount,
+        output_tokens: metrics.outputTokenCount,
+        // Map AWS camelCase to Anthropic snake_case
+        cache_read_input_tokens: metrics.cacheReadInputTokenCount,
+        cache_creation_input_tokens: metrics.cacheWriteInputTokenCount,
+      };
+      log.debug(
+        { metrics, usageData },
+        "Extracted usage from AWS invocationMetrics"
+      );
+    }
+  } catch (e) {
+    // Ignore parsing errors
+  }
+
   const deltaEvent = asAnthropicChatDelta(rawEvent);
   if (!deltaEvent) {
+    // If we have usage data but no delta, still emit an event with usage
+    if (usageData) {
+      const usageEvent = {
+        id: params.fallbackId,
+        object: "chat.completion.chunk" as const,
+        created: Date.now(),
+        model: params.fallbackModel,
+        choices: [
+          {
+            index: 0,
+            delta: {},
+            finish_reason: null,
+          },
+        ],
+        usage: usageData,
+      };
+      return { position: -1, event: usageEvent };
+    }
     return { position: -1 };
   }
 
@@ -39,6 +101,7 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
         finish_reason: null,
       },
     ],
+    ...(usageData && { usage: usageData }),
   };
 
   return { position: -1, event: newEvent };
diff --git a/src/proxy/middleware/response/streaming/transformers/google-ai-to-openai.ts b/src/proxy/middleware/response/streaming/transformers/google-ai-to-openai.ts
index b60151a..8b21aab 100644
--- a/src/proxy/middleware/response/streaming/transformers/google-ai-to-openai.ts
+++ b/src/proxy/middleware/response/streaming/transformers/google-ai-to-openai.ts
@@ -15,6 +15,11 @@ type GoogleAIStreamEvent = {
     tokenCount?: number;
     safetyRatings: { category: string; probability: string }[];
   }[];
+  usageMetadata?: {
+    promptTokenCount: number;
+    candidatesTokenCount: number;
+    totalTokenCount: number;
+  };
 };
 
 /**
@@ -49,7 +54,7 @@ export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
     content = content.replace(/^(.*?): /, "").trim();
   }
 
-  const newEvent = {
+  const newEvent: any = {
     id: "goo-" + params.fallbackId,
     object: "chat.completion.chunk" as const,
     created: Date.now(),
@@ -63,6 +68,19 @@ export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
     ],
   };
 
+  // Extract usage metadata from GCP/Vertex AI responses
+  if (completionEvent.usageMetadata) {
+    newEvent.usage = {
+      prompt_tokens: completionEvent.usageMetadata.promptTokenCount,
+      completion_tokens: completionEvent.usageMetadata.candidatesTokenCount,
+      total_tokens: completionEvent.usageMetadata.totalTokenCount,
+    };
+    log.debug(
+      { usageMetadata: completionEvent.usageMetadata },
+      "Extracted usage from GCP usageMetadata"
+    );
+  }
+
   return { position: -1, event: newEvent };
 };
diff --git a/src/shared/key-management/anthropic/provider.ts b/src/shared/key-management/anthropic/provider.ts
index 0d18f70..4fb82e3 100644
--- a/src/shared/key-management/anthropic/provider.ts
+++ b/src/shared/key-management/anthropic/provider.ts
@@ -5,6 +5,11 @@ import { logger } from "../../../logger";
 import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
 import { AnthropicKeyChecker } from "./checker";
 import { PaymentRequiredError } from "../../errors";
+import {
+  generateCacheFingerprint,
+  recordCacheUsage,
+  getCachedKeyHash,
+} from "../cache-tracker";
 
 export type AnthropicKeyUpdate = Omit<
   Partial<AnthropicKey>,
@@ -136,7 +141,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
     return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
   }
 
-  public get(rawModel: string) {
+  public get(rawModel: string, _streaming?: boolean, requestBody?: any) {
     this.log.debug({ model: rawModel }, "Selecting key");
     const needsMultimodal = rawModel.endsWith("-multimodal");
 
@@ -152,7 +157,44 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
       );
     }
 
+    // Generate cache fingerprint if request body contains cache_control
+    const cacheFingerprint = requestBody
+      ? generateCacheFingerprint(requestBody)
+      : null;
+
+    // Try to get cached key if we have a fingerprint
+    let preferredKeyHash: string | null = null;
+    let matchedFingerprint: string | null = null;
+    if (cacheFingerprint) {
+      const cacheResult = getCachedKeyHash(cacheFingerprint);
+      if (cacheResult) {
+        preferredKeyHash = cacheResult.keyHash;
+        matchedFingerprint = cacheResult.matchedFingerprint;
+        // Check if the cached key is still available
+        const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
+        if (cachedKey) {
+          this.log.debug(
+            {
+              requestedModel: rawModel,
+              cacheFingerprint,
+              keyHash: preferredKeyHash,
+            },
+            "Using cached key for prompt caching optimization"
+          );
+        } else {
+          // Cached key no longer available
+          preferredKeyHash = null;
+          matchedFingerprint = null;
+          this.log.debug(
+            { cacheFingerprint, keyHash: cacheResult.keyHash },
+            "Cached key not available, selecting new key"
+          );
+        }
+      }
+    }
+
     // Select a key, from highest priority to lowest priority:
+    // 0. Cache affinity (if we have a cached key preference)
     // 1. Keys which are not rate limit locked
     // 2. Keys with the highest tier
     // 3. Keys which are not pozzed
@@ -161,6 +203,12 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
     const now = Date.now();
 
     const keysByPriority = availableKeys.sort((a, b) => {
+      // Highest priority: cache affinity
+      if (preferredKeyHash) {
+        if (a.hash === preferredKeyHash) return -1;
+        if (b.hash === preferredKeyHash) return 1;
+      }
+
       const aLockoutPeriod = getKeyLockout(a);
       const bLockoutPeriod = getKeyLockout(b);
 
@@ -183,6 +231,13 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
     const selectedKey = keysByPriority[0];
     selectedKey.lastUsed = now;
     this.throttle(selectedKey.hash);
+
+    // Record cache usage for future requests.
+    // Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint.
+    if (cacheFingerprint) {
+      recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
+    }
+
     return { ...selectedKey };
   }
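Note: the affinity comparator above pins the preferred key to the front of the sort and leaves every other comparison untouched, so cache affinity deliberately outranks the lockout/tier/pozzed rules that follow it. Since Array.prototype.sort is stable (ES2019+), the relative order of the remaining keys is preserved. A standalone illustration with made-up hashes:

    const keys = [{ hash: "k1" }, { hash: "k2" }, { hash: "k3" }];
    const preferredKeyHash = "k2";
    keys.sort((a, b) => {
      if (a.hash === preferredKeyHash) return -1;
      if (b.hash === preferredKeyHash) return 1;
      return 0;
    });
    // keys is now [k2, k1, k3]; k1 and k3 keep their original relative order
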
diff --git a/src/shared/key-management/aws/provider.ts b/src/shared/key-management/aws/provider.ts
index 80ba47e..6856059 100644
--- a/src/shared/key-management/aws/provider.ts
+++ b/src/shared/key-management/aws/provider.ts
@@ -7,6 +7,11 @@ import { findByAnthropicId } from "../../claude-models";
 import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
 import { prioritizeKeys } from "../prioritize-keys";
 import { AwsKeyChecker } from "./checker";
+import {
+  generateCacheFingerprint,
+  recordCacheUsage,
+  getCachedKeyHash,
+} from "../cache-tracker";
 
 // AwsBedrockKeyUsage is removed, tokenUsage from base Key interface will be used.
 export interface AwsBedrockKey extends Key {
@@ -90,14 +95,14 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
     return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
   }
 
-  public get(model: string) {
+  public get(model: string, _streaming?: boolean, requestBody?: any) {
     let neededVariantId = model;
     // This function accepts both Anthropic/Mistral IDs and AWS IDs.
     // Generally all AWS model IDs are supersets of the original vendor IDs.
     // Claude 2 is the only model that breaks this convention; Anthropic calls
     // it claude-2 but AWS calls it claude-v2.
     if (model.includes("claude-2")) neededVariantId = "claude-v2";
-    
+
     // For Claude models, try to resolve aliases to AWS model IDs
     if (model.includes("claude") && !model.includes("anthropic.")) {
       const claudeMapping = findByAnthropicId(model);
@@ -105,7 +110,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
         neededVariantId = claudeMapping.awsId;
       }
     }
-    
+
     const neededFamily = getAwsBedrockModelFamily(model);
 
     const availableKeys = this.keys.filter((k) => {
@@ -122,6 +127,42 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
       );
     });
 
+    // Generate cache fingerprint if request body contains cache_control
+    const cacheFingerprint = requestBody
+      ? generateCacheFingerprint(requestBody)
+      : null;
+
+    // Try to get cached key if we have a fingerprint
+    let preferredKeyHash: string | null = null;
+    let matchedFingerprint: string | null = null;
+    if (cacheFingerprint) {
+      const cacheResult = getCachedKeyHash(cacheFingerprint);
+      if (cacheResult) {
+        preferredKeyHash = cacheResult.keyHash;
+        matchedFingerprint = cacheResult.matchedFingerprint;
+        // Check if the cached key is still available
+        const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
+        if (cachedKey) {
+          this.log.debug(
+            {
+              requestedModel: model,
+              cacheFingerprint,
+              keyHash: preferredKeyHash,
+            },
+            "Using cached key for prompt caching optimization"
+          );
+        } else {
+          // Cached key no longer available
+          preferredKeyHash = null;
+          matchedFingerprint = null;
+          this.log.debug(
+            { cacheFingerprint, keyHash: cacheResult.keyHash },
+            "Cached key not available, selecting new key"
+          );
+        }
+      }
+    }
+
     this.log.debug(
       {
         requestedModel: model,
@@ -129,6 +170,8 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
         selectedFamily: neededFamily,
         totalKeys: this.keys.length,
         availableKeys: availableKeys.length,
+        cacheFingerprint,
+        hasCachedKey: !!preferredKeyHash,
       },
       "Selecting AWS key"
     );
@@ -140,22 +183,36 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
     }
 
     /**
-     * Comparator for prioritizing keys on inference profile compatibility.
-     * Requests made via inference profiles have higher rate limits so we want
-     * to use keys with compatible inference profiles first.
+     * Comparator for prioritizing keys based on:
+     * 1. Cache affinity (if we have a cached key preference)
+     * 2. Inference profile compatibility
      */
-    const hasInferenceProfile = (
-      a: AwsBedrockKey,
-      b: AwsBedrockKey
-    ) => {
+    const keyComparator = (a: AwsBedrockKey, b: AwsBedrockKey) => {
+      // Highest priority: cache affinity
+      if (preferredKeyHash) {
+        if (a.hash === preferredKeyHash) return -1;
+        if (b.hash === preferredKeyHash) return 1;
+      }
+
+      // Second priority: inference profile compatibility
       const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
       const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
-      return aMatch - bMatch;
+      const profileDiff = bMatch - aMatch;
+      if (profileDiff !== 0) return profileDiff;
+
+      return 0;
     };
 
-    const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
+    const selectedKey = prioritizeKeys(availableKeys, keyComparator)[0];
     selectedKey.lastUsed = Date.now();
     this.throttle(selectedKey.hash);
+
+    // Record cache usage for future requests.
+    // Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint.
+    if (cacheFingerprint) {
+      recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
+    }
+
     return { ...selectedKey };
   }
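Note: the fingerprint implemented below hashes only the content up to the last cache_control breakpoint, so appending conversation turns after the breakpoint leaves the fingerprint unchanged. A sketch of that property, assuming the exports defined in cache-tracker.ts:

    import { generateCacheFingerprint } from "./cache-tracker";

    const base = {
      system: [
        { type: "text", text: "You are a helper.", cache_control: { type: "ephemeral" } },
      ],
      messages: [{ role: "user", content: "Hi" }],
    };
    const grown = {
      ...base,
      messages: [
        ...base.messages,
        { role: "assistant", content: "Hello!" },
        { role: "user", content: "Another turn" },
      ],
    };
    // Both calls hash only up to the system block (the last breakpoint), so:
    console.log(
      generateCacheFingerprint(base) === generateCacheFingerprint(grown) // true
    );
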
diff --git a/src/shared/key-management/cache-tracker.ts b/src/shared/key-management/cache-tracker.ts
new file mode 100644
index 0000000..2571d45
--- /dev/null
+++ b/src/shared/key-management/cache-tracker.ts
@@ -0,0 +1,457 @@
+import crypto from "crypto";
+import { logger } from "../../logger";
+
+/**
+ * Deterministic JSON stringify that sorts object keys to ensure consistent hashing.
+ */
+function deterministicStringify(obj: any): string {
+  if (obj === null || obj === undefined) return String(obj);
+  if (typeof obj !== "object") return JSON.stringify(obj);
+  if (Array.isArray(obj)) return `[${obj.map(deterministicStringify).join(",")}]`;
+
+  const keys = Object.keys(obj).sort();
+  const pairs = keys.map(k => `"${k}":${deterministicStringify(obj[k])}`);
+  return `{${pairs.join(",")}}`;
+}
+
+/**
+ * Universal cache tracker for all providers (Anthropic, AWS, GCP).
+ *
+ * Tracks which keys have cached which prompt prefixes to optimize prompt caching.
+ * Each API key has its own cache, so routing requests with identical cacheable
+ * content to the same key maximizes cache hits.
+ *
+ * Cache rules (per Anthropic/Bedrock/Vertex docs):
+ * - 100% identical prompt segments required for cache hit
+ * - Default 5-minute TTL, refreshed on each use
+ * - Optional 1-hour TTL
+ * - Cache becomes available after first response begins
+ * - Minimum 1024-2048 tokens (model-dependent)
+ */
+
+interface CacheEntry {
+  keyHash: string;
+  lastUsed: number;
+  hitCount: number;
+  ttl: number;
+  // Store all prefix fingerprints for hierarchical matching
+  prefixFingerprints?: string[];
+}
+
+const log = logger.child({ module: "cache-tracker" });
+
+// Maps cache fingerprints to the key that has cached that content
+const cacheMap = new Map<string, CacheEntry>();
+
+// Default TTLs in milliseconds
+const TTL_5_MINUTES = 5 * 60 * 1000;
+const TTL_1_HOUR = 60 * 60 * 1000;
+
+/**
+ * Generates a fingerprint for cacheable content in a request.
+ * The fingerprint includes content up to and including the LAST cache_control
+ * breakpoint, following Anthropic's hierarchy: tools → system → messages.
+ *
+ * Key insight for handling ever-growing prompts:
+ * By fingerprinting only up to the last cache_control marker, we ensure that
+ * new content added AFTER that marker (e.g., new user messages in a conversation)
+ * won't change the fingerprint. This enables cache hits as conversations grow.
+ *
+ * How Anthropic's cache matching works:
+ * - You place ONE cache_control at the end of your static/stable content
+ * - The API automatically checks ~20 blocks BEFORE that breakpoint for cache hits
+ * - It uses the longest matching prefix automatically
+ * - You don't need multiple breakpoints - one at the end is sufficient
+ *
+ * Example conversation flow:
+ * Request 1: [system prompt + cache_control] → Creates cache, fingerprint = "abc123"
+ * Request 2: [system prompt + cache_control] + [user msg] + [assistant msg] + [user msg]
+ *   → Same fingerprint "abc123" → Cache HIT (new messages after breakpoint ignored)
+ * Request 3: [MODIFIED system prompt + cache_control] + messages
+ *   → Different fingerprint "def456" → Cache MISS (prefix changed)
+ */
+export function generateCacheFingerprint(body: any): string | null {
+  if (!body) {
+    return null;
+  }
+
+  const parts: any[] = [];
+  let hasCacheControl = false;
+  const cacheBreakpoints: number[] = []; // Indices of cache_control markers
+  const ttls: number[] = []; // TTL for each breakpoint
+
+  // 1. Process tools if present (tools come first in hierarchy)
+  if (body.tools && Array.isArray(body.tools)) {
+    for (let i = 0; i < body.tools.length; i++) {
+      const tool = body.tools[i];
+      // Only include the stable parts of the tool definition in the fingerprint.
+      // Exclude cache_control as it's metadata, not part of the tool definition.
+      const { cache_control, ...toolWithoutCache } = tool;
+      parts.push({ type: "tool", tool: toolWithoutCache });
+      if (cache_control) {
+        hasCacheControl = true;
+        cacheBreakpoints.push(parts.length - 1);
+        ttls.push(parseTTL(cache_control.ttl));
+      }
+    }
+  }
+
+  // 2. Process system prompt if present
+  if (body.system) {
+    if (typeof body.system === "string") {
+      // Normalize string system to same structure as array blocks for consistent fingerprinting
+      parts.push({ type: "system_block", block_type: "text", text: body.system });
+    } else if (Array.isArray(body.system)) {
+      for (const block of body.system) {
+        // System blocks can be:
+        // - {type: "text", text: "...", cache_control?: ...}
+        // - {type: "image", source: {...}, cache_control?: ...}
+        const { cache_control, type, ...rest } = block;
+
+        // Create a consistent fingerprint structure
+        const contentPart: any = { type: "system_block", block_type: type };
+
+        switch (type) {
+          case "text":
+            contentPart.text = (rest as any).text;
+            break;
+          case "image":
+            // Hash image data for fingerprint
+            if ((rest as any).source?.data) {
+              contentPart.image_hash = crypto
+                .createHash("sha256")
+                .update((rest as any).source.data)
+                .digest("hex")
+                .slice(0, 16);
+            }
+            contentPart.media_type = (rest as any).source?.media_type;
+            break;
+          default:
+            // For unknown block types, include all remaining fields
+            Object.assign(contentPart, rest);
+        }
+
+        parts.push(contentPart);
+
+        if (cache_control) {
+          hasCacheControl = true;
+          cacheBreakpoints.push(parts.length - 1);
+          ttls.push(parseTTL(cache_control.ttl));
+        }
+      }
+    }
+  }
+
+  // 3. Process messages
+  if (body.messages && Array.isArray(body.messages)) {
+    for (const message of body.messages) {
+      if (typeof message.content === "string") {
+        parts.push({
+          type: "message",
+          role: message.role,
+          content: message.content,
+        });
+      } else if (Array.isArray(message.content)) {
+        // For multimodal content, include each block
+        for (const block of message.content) {
+          // CRITICAL: Exclude cache_control metadata from fingerprinting.
+          // Only the actual content should affect the fingerprint.
+          const { cache_control, ...blockWithoutCache } = block;
+
+          const contentPart: any = {
+            type: "message_block",
+            role: message.role,
+            block_type: blockWithoutCache.type,
+          };
+
+          // Include essential identifying info without full data
+          // (full data would make fingerprints too large)
+          switch (blockWithoutCache.type) {
+            case "text":
+              contentPart.text = blockWithoutCache.text;
+              break;
+            case "image":
+              // Hash image data for fingerprint
+              if (blockWithoutCache.source?.data) {
+                contentPart.image_hash = crypto
+                  .createHash("sha256")
+                  .update(blockWithoutCache.source.data)
+                  .digest("hex")
+                  .slice(0, 16);
+              }
+              contentPart.media_type = blockWithoutCache.source?.media_type;
+              break;
+            case "tool_use":
+              // Don't include tool_id as it's often randomly generated.
+              // Only include tool name and input which are stable.
+              contentPart.tool_name = blockWithoutCache.name;
+              contentPart.tool_input = blockWithoutCache.input;
+              break;
+            case "tool_result":
+              // Don't include tool_use_id as it's randomly generated.
+              // Only include the actual result content and error status.
+              contentPart.is_error = blockWithoutCache.is_error;
+              if (typeof blockWithoutCache.content === "string") {
+                contentPart.content = blockWithoutCache.content;
+              } else if (Array.isArray(blockWithoutCache.content)) {
+                // tool_result content can be an array of content blocks
+                contentPart.content = blockWithoutCache.content;
+              }
+              break;
+          }
+
+          parts.push(contentPart);
+
+          if (cache_control) {
+            hasCacheControl = true;
+            cacheBreakpoints.push(parts.length - 1);
+            ttls.push(parseTTL(cache_control.ttl));
+          }
+        }
+      }
+    }
+  }
+
+  // No caching if no cache_control directives present
+  if (!hasCacheControl || cacheBreakpoints.length === 0) {
+    return null;
+  }
+
+  const maxTTL = Math.max(...ttls);
+
+  // Generate individual hashes for each part (tool, system block, message).
+  // This allows prefix matching even when conversations grow.
+  const partHashes: string[] = [];
+  for (const part of parts) {
+    const partHash = crypto
+      .createHash("sha256")
+      .update(deterministicStringify(part))
+      .digest("hex")
+      .slice(0, 8); // Shorter hash per part
+    partHashes.push(partHash);
+  }
+
+  // Generate fingerprints for each cache_control position.
+  // Fingerprint is the concatenation of individual part hashes up to the breakpoint.
+  const prefixFingerprints: string[] = [];
+
+  for (const breakpoint of cacheBreakpoints) {
+    // Concatenate part hashes up to and including the breakpoint
+    const fingerprint = partHashes.slice(0, breakpoint + 1).join("");
+
+    log.trace(
+      {
+        fingerprint: fingerprint.substring(0, 32) + "...",
+        breakpoint,
+        prefixPartsCount: breakpoint + 1,
+      },
+      "Generated cache fingerprint"
+    );
+
+    prefixFingerprints.push(fingerprint);
+  }
+
+  // Return the deepest (last) fingerprint as the primary one
+  const primaryFingerprint = prefixFingerprints[prefixFingerprints.length - 1];
+
+  // Store ALL prefix fingerprints in the cache map, not just the primary one.
+  // This allows future requests to match against any prefix level.
+  for (let i = 0; i < prefixFingerprints.length; i++) {
+    const fp = prefixFingerprints[i];
+    const existing = cacheMap.get(fp);
+
+    if (!existing) {
+      cacheMap.set(fp, {
+        keyHash: "",
+        lastUsed: 0,
+        hitCount: 0,
+        ttl: maxTTL,
+        prefixFingerprints: prefixFingerprints.slice(0, i + 1),
+      });
+    }
+  }
+
+  return primaryFingerprint;
+}
+
+function parseTTL(ttl?: string): number {
+  if (ttl === "1h") return TTL_1_HOUR;
+  return TTL_5_MINUTES; // Default or "5m"
+}
+
+/**
+ * Records that a key was used for a request with the given cache fingerprint.
+ */
+export function recordCacheUsage(fingerprint: string, keyHash: string): void {
+  if (!fingerprint) return;
+
+  const entry = cacheMap.get(fingerprint);
+  if (!entry) {
+    // Shouldn't happen, but handle gracefully
+    cacheMap.set(fingerprint, {
+      keyHash,
+      lastUsed: Date.now(),
+      hitCount: 1,
+      ttl: TTL_5_MINUTES,
+    });
+    log.trace({ fingerprint: fingerprint.substring(0, 32) + "...", keyHash }, "New cache entry recorded");
+    return;
+  }
+
+  const now = Date.now();
+
+  if (entry.keyHash === keyHash) {
+    // Same key - likely cache hit
+    entry.lastUsed = now;
+    entry.hitCount++;
+    log.trace(
+      { fingerprint: fingerprint.substring(0, 32) + "...", keyHash, hitCount: entry.hitCount },
+      "Cache usage recorded (likely cache hit)"
+    );
+  } else if (entry.keyHash === "") {
+    // First use of this fingerprint
+    entry.keyHash = keyHash;
+    entry.lastUsed = now;
+    entry.hitCount = 1;
+    log.debug({ fingerprint: fingerprint.substring(0, 32) + "...", keyHash }, "First cache usage for fingerprint");
+  } else {
+    // Different key - cache miss, reset tracking
+    log.debug(
+      { fingerprint: fingerprint.substring(0, 32) + "...", oldKey: entry.keyHash, newKey: keyHash },
+      "Cache key changed (will cause cache miss)"
+    );
+    entry.keyHash = keyHash;
+    entry.lastUsed = now;
+    entry.hitCount = 1;
+  }
+}
+
+/**
+ * Gets the key hash that has cached the longest matching prefix for the given
+ * fingerprint or any of its sub-prefixes.
+ *
+ * This is crucial for handling moving cache breakpoints:
+ * - If the current request has fingerprints ["fp1", "fp2", "fp3"]
+ * - We search backwards from "fp3" → "fp2" → "fp1"
+ * - Return the key that cached the longest available prefix
+ *
+ * Returns null if no cached key exists or all caches have expired.
+ */
+export function getCachedKeyHash(fingerprint: string): { keyHash: string; matchedFingerprint: string } | null {
+  if (!fingerprint) return null;
+
+  const now = Date.now();
+
+  // First, try exact match on the primary (deepest) fingerprint
+  const primaryEntry = cacheMap.get(fingerprint);
+  if (primaryEntry && primaryEntry.keyHash) {
+    const age = now - primaryEntry.lastUsed;
+    if (age <= primaryEntry.ttl) {
+      log.trace(
+        {
+          fingerprint: fingerprint.substring(0, 32) + "...",
+          keyHash: primaryEntry.keyHash,
+          age,
+          hitCount: primaryEntry.hitCount,
+          matchType: "exact",
+        },
+        "Cache entry found (exact match)"
+      );
+      return { keyHash: primaryEntry.keyHash, matchedFingerprint: fingerprint };
+    } else {
+      log.trace({ fingerprint: fingerprint.substring(0, 32) + "...", age, ttl: primaryEntry.ttl }, "Cache entry expired");
+      cacheMap.delete(fingerprint);
+    }
+  }
+
+  // If no exact match, search all cache entries for prefix matches.
+  // Since fingerprints are concatenated part hashes, we can check if one is a prefix of another.
+  let bestMatch: { keyHash: string; matchedFingerprint: string } | null = null;
+  let longestMatchLength = 0;
+
+  for (const [cachedFp, entry] of cacheMap.entries()) {
+    if (!entry.keyHash) continue;
+
+    const age = now - entry.lastUsed;
+    if (age > entry.ttl) {
+      cacheMap.delete(cachedFp);
+      continue;
+    }
+
+    // Check if the cached fingerprint is a prefix of the current fingerprint
+    // (cached request had fewer parts, current request has grown)
+    if (fingerprint.startsWith(cachedFp) && cachedFp.length > longestMatchLength) {
+      bestMatch = { keyHash: entry.keyHash, matchedFingerprint: cachedFp };
+      longestMatchLength = cachedFp.length;
+      log.trace(
+        {
+          requestFingerprint: fingerprint.substring(0, 32) + "...",
+          matchedFingerprint: cachedFp.substring(0, 32) + "...",
+          keyHash: entry.keyHash,
+          matchType: "prefix",
+        },
+        "Cache entry found (prefix match)"
+      );
+    }
+
+    // Also check if the current fingerprint is a prefix of the cached one
+    // (current request has fewer parts, cached had more)
+    if (cachedFp.startsWith(fingerprint) && fingerprint.length > longestMatchLength) {
+      bestMatch = { keyHash: entry.keyHash, matchedFingerprint: cachedFp };
+      longestMatchLength = fingerprint.length;
+      log.trace(
+        {
+          requestFingerprint: fingerprint.substring(0, 32) + "...",
+          matchedFingerprint: cachedFp.substring(0, 32) + "...",
+          keyHash: entry.keyHash,
+          matchType: "prefix_reverse",
+        },
+        "Cache entry found (reverse prefix match)"
+      );
+    }
+  }
+
+  return bestMatch || null;
+}
+
+/**
+ * Clears expired cache entries periodically to prevent memory leaks.
+ */
+export function cleanupExpiredCaches(): void {
+  const now = Date.now();
+  let cleaned = 0;
+
+  for (const [fingerprint, entry] of cacheMap.entries()) {
+    const age = now - entry.lastUsed;
+    if (age > entry.ttl) {
+      cacheMap.delete(fingerprint);
+      cleaned++;
+    }
+  }
+
+  if (cleaned > 0) {
+    log.debug(
+      { cleaned, remaining: cacheMap.size },
+      "Cleaned up expired cache entries"
+    );
+  }
+}
+
+/**
+ * Returns cache statistics for monitoring.
+ */
+export function getCacheStats() {
+  return {
+    totalEntries: cacheMap.size,
+    entries: Array.from(cacheMap.entries()).map(([fp, entry]) => ({
+      fingerprint: fp,
+      keyHash: entry.keyHash,
+      age: Date.now() - entry.lastUsed,
+      hitCount: entry.hitCount,
+      ttl: entry.ttl,
+    })),
+  };
+}
+
+// Run cleanup every minute
+setInterval(cleanupExpiredCaches, 60 * 1000);
\ No newline at end of file
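Note: all three providers drive the tracker with the same three calls. The protocol, reduced to its essentials, under the assumption that the real providers interleave their own priority rules where noted (a sketch, not the shipped code):

    import {
      generateCacheFingerprint,
      getCachedKeyHash,
      recordCacheUsage,
    } from "./cache-tracker";

    function pickKey<K extends { hash: string }>(body: unknown, available: K[]): K {
      const fp = generateCacheFingerprint(body);
      const hit = fp ? getCachedKeyHash(fp) : null;
      const preferred = hit ? available.find((k) => k.hash === hit.keyHash) : undefined;
      const selected = preferred ?? available[0]; // providers sort by priority here
      if (fp) {
        const usedHit = preferred ? hit : null;
        recordCacheUsage(usedHit?.matchedFingerprint ?? fp, selected.hash);
      }
      return selected;
    }
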
diff --git a/src/shared/key-management/gcp/provider.ts b/src/shared/key-management/gcp/provider.ts
index a5a1879..d13e72a 100644
--- a/src/shared/key-management/gcp/provider.ts
+++ b/src/shared/key-management/gcp/provider.ts
@@ -6,6 +6,11 @@ import { GcpModelFamily, getGcpModelFamily } from "../../models";
 import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
 import { prioritizeKeys } from "../prioritize-keys";
 import { GcpKeyChecker } from "./checker";
+import {
+  generateCacheFingerprint,
+  recordCacheUsage,
+  getCachedKeyHash,
+} from "../cache-tracker";
 
 // GcpKeyUsage is removed, tokenUsage from base Key interface will be used.
 export interface GcpKey extends Key {
@@ -90,7 +95,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
     return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
   }
 
-  public get(model: string) {
+  public get(model: string, _streaming?: boolean, requestBody?: any) {
     const neededFamily = getGcpModelFamily(model);
 
     // this is a horrible mess
@@ -115,6 +120,42 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
       );
     });
 
+    // Generate cache fingerprint if request body contains cache_control
+    const cacheFingerprint = requestBody
+      ? generateCacheFingerprint(requestBody)
+      : null;
+
+    // Try to get cached key if we have a fingerprint
+    let preferredKeyHash: string | null = null;
+    let matchedFingerprint: string | null = null;
+    if (cacheFingerprint) {
+      const cacheResult = getCachedKeyHash(cacheFingerprint);
+      if (cacheResult) {
+        preferredKeyHash = cacheResult.keyHash;
+        matchedFingerprint = cacheResult.matchedFingerprint;
+        // Check if the cached key is still available
+        const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
+        if (cachedKey) {
+          this.log.debug(
+            {
+              requestedModel: model,
+              cacheFingerprint,
+              keyHash: preferredKeyHash,
+            },
+            "Using cached key for prompt caching optimization"
+          );
+        } else {
+          // Cached key no longer available
+          preferredKeyHash = null;
+          matchedFingerprint = null;
+          this.log.debug(
+            { cacheFingerprint, keyHash: cacheResult.keyHash },
+            "Cached key not available, selecting new key"
+          );
+        }
+      }
+    }
+
     this.log.debug(
       {
         model,
@@ -124,6 +165,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
         needsSonnet35,
         availableKeys: availableKeys.length,
         totalKeys: this.keys.length,
+        cacheFingerprint,
+        hasCachedKey: !!preferredKeyHash,
       },
       "Selecting GCP key"
    );
@@ -134,9 +177,28 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
       );
     }
 
-    const selectedKey = prioritizeKeys(availableKeys)[0];
+    /**
+     * Comparator for prioritizing keys based on cache affinity.
+     */
+    const keyComparator = (a: GcpKey, b: GcpKey) => {
+      // Highest priority: cache affinity
+      if (preferredKeyHash) {
+        if (a.hash === preferredKeyHash) return -1;
+        if (b.hash === preferredKeyHash) return 1;
+      }
+      return 0;
+    };
+
+    const selectedKey = prioritizeKeys(availableKeys, keyComparator)[0];
     selectedKey.lastUsed = Date.now();
     this.throttle(selectedKey.hash);
+
+    // Record cache usage for future requests.
+    // Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint.
+    if (cacheFingerprint) {
+      recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
+    }
+
     return { ...selectedKey };
   }
diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts
index e02957d..8be3f29 100644
--- a/src/shared/key-management/index.ts
+++ b/src/shared/key-management/index.ts
@@ -61,7 +61,7 @@ for service-agnostic functionality.
 export interface KeyProvider<T extends Key = Key> {
   readonly service: LLMService;
   init(): void;
-  get(model: string, streaming?: boolean): T;
+  get(model: string, streaming?: boolean, requestBody?: any): T;
   list(): Omit<T, "key">[];
   disable(key: T): void;
   update(hash: string, update: Partial<T>): void;
diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts
index 9ce2178..a226fcf 100644
--- a/src/shared/key-management/key-pool.ts
+++ b/src/shared/key-management/key-pool.ts
@@ -55,15 +55,15 @@ export class KeyPool {
     this.scheduleRecheck();
   }
 
-  public get(model: string, service?: LLMService, multimodal?: boolean, streaming?: boolean): Key {
+  public get(model: string, service?: LLMService, multimodal?: boolean, streaming?: boolean, requestBody?: any): Key {
     // hack for some claude requests needing keys with particular permissions
     // even though they use the same models as the non-multimodal requests
     if (multimodal) {
       model += "-multimodal";
     }
-    
+
     const queryService = service || this.getServiceForModel(model);
-    return this.getKeyProvider(queryService).get(model, streaming);
+    return this.getKeyProvider(queryService).get(model, streaming, requestBody);
   }
 
   public list(): Omit<Key, "key">[] {
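Note: keyPool.get has grown to five positional parameters, with callers passing undefined for slots they don't use. An illustrative call (the model ID is an example; an options object would be a possible future cleanup):

    const key = keyPool.get(
      "claude-sonnet-4-20250514", // model (example ID)
      "anthropic",                // service
      undefined,                  // multimodal
      true,                       // streaming
      req.body                    // request body, enables cache fingerprinting
    );
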
diff --git a/src/shared/tokenization/anthropic-remote.ts b/src/shared/tokenization/anthropic-remote.ts
new file mode 100644
index 0000000..9cc736d
--- /dev/null
+++ b/src/shared/tokenization/anthropic-remote.ts
@@ -0,0 +1,66 @@
+import { AnthropicChatMessage } from "../api-schemas";
+import { getAxiosInstance } from "../network";
+import { logger } from "../../logger";
+import { AnthropicKey } from "../key-management/anthropic/provider";
+
+const log = logger.child({ module: "tokenizer", service: "anthropic-remote" });
+
+export interface AnthropicTokenCountRequest {
+  model: string;
+  messages: AnthropicChatMessage[];
+  system?: string | Array<{ type: string; text: string }>;
+  tools?: unknown[];
+}
+
+export interface AnthropicTokenCountResponse {
+  input_tokens: number;
+}
+
+/**
+ * Counts tokens using Anthropic's remote token counting API endpoint.
+ * https://docs.claude.com/en/docs/build-with-claude/token-counting
+ */
+export async function countTokensRemote(
+  request: AnthropicTokenCountRequest,
+  key: AnthropicKey
+): Promise<{ token_count: number; tokenizer: string }> {
+  const axios = getAxiosInstance();
+
+  try {
+    const response = await axios.post<AnthropicTokenCountResponse>(
+      "https://api.anthropic.com/v1/messages/count_tokens",
+      request,
+      {
+        headers: {
+          "x-api-key": key.key,
+          "anthropic-version": "2023-06-01",
+          "content-type": "application/json",
+        },
+        timeout: 5000, // 5 second timeout
+      }
+    );
+
+    log.debug(
+      {
+        model: request.model,
+        input_tokens: response.data.input_tokens,
+      },
+      "Counted tokens via Anthropic API"
+    );
+
+    return {
+      token_count: response.data.input_tokens,
+      tokenizer: "anthropic-remote-api",
+    };
+  } catch (error: any) {
+    log.warn(
+      {
+        error: error.message,
+        status: error.response?.status,
+        data: error.response?.data,
+      },
+      "Failed to count tokens via Anthropic API, will fall back to local tokenizer"
+    );
+    throw error;
+  }
+}
\ No newline at end of file
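A sketch of the call-and-fallback pattern this module is designed for; the model ID and `localClaudeTokenCount` helper are illustrative stand-ins, not part of this diff:

import { countTokensRemote } from "./anthropic-remote";
import { AnthropicKey } from "../key-management/anthropic/provider";

// Hypothetical local fallback; the real proxy uses the Claude tokenizer
// in ./claude.ts instead.
declare function localClaudeTokenCount(messages: unknown[]): number;

async function countPromptTokens(key: AnthropicKey): Promise<number> {
  const request = {
    model: "claude-sonnet-4-20250514", // illustrative model ID
    messages: [{ role: "user" as const, content: "Hello, world" }],
  };
  try {
    // Remote counting prices images and tools too, so prefer it.
    const { token_count } = await countTokensRemote(request as any, key);
    return token_count;
  } catch {
    // countTokensRemote logs and rethrows, so callers can fall back locally.
    return localClaudeTokenCount(request.messages);
  }
}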
diff --git a/src/shared/tokenization/aws-remote.ts b/src/shared/tokenization/aws-remote.ts
new file mode 100644
index 0000000..55b2175
--- /dev/null
+++ b/src/shared/tokenization/aws-remote.ts
@@ -0,0 +1,117 @@
+import { Sha256 } from "@aws-crypto/sha256-js";
+import { SignatureV4 } from "@smithy/signature-v4";
+import { HttpRequest } from "@smithy/protocol-http";
+import { getAxiosInstance } from "../network";
+import { logger } from "../../logger";
+import { AwsBedrockKey } from "../key-management/aws/provider";
+
+const log = logger.child({ module: "tokenizer", service: "aws-remote" });
+
+const AMZ_HOST =
+  process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
+
+export interface AwsTokenCountRequest {
+  input: {
+    invokeModel: {
+      body: string; // base64-encoded JSON string
+    };
+  };
+}
+
+export interface AwsTokenCountResponse {
+  inputTokens: number;
+}
+
+type Credential = {
+  accessKeyId: string;
+  secretAccessKey: string;
+  region: string;
+};
+
+function getCredentialParts(key: AwsBedrockKey): Credential {
+  const [accessKeyId, secretAccessKey, region] = key.key.split(":");
+
+  if (!accessKeyId || !secretAccessKey || !region) {
+    throw new Error("AWS_CREDENTIALS isn't correctly formatted");
+  }
+
+  return { accessKeyId, secretAccessKey, region };
+}
+
+async function sign(request: HttpRequest, credential: Credential) {
+  const { accessKeyId, secretAccessKey, region } = credential;
+
+  const signer = new SignatureV4({
+    sha256: Sha256,
+    credentials: { accessKeyId, secretAccessKey },
+    region,
+    service: "bedrock",
+  });
+
+  return signer.sign(request);
+}
+
+/**
+ * Counts tokens using AWS Bedrock's remote token counting API endpoint.
+ * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html
+ */
+export async function countTokensRemote(
+  modelId: string,
+  request: AwsTokenCountRequest,
+  key: AwsBedrockKey
+): Promise<{ token_count: number; tokenizer: string }> {
+  const axios = getAxiosInstance();
+  const credential = getCredentialParts(key);
+  const host = AMZ_HOST.replace("%REGION%", credential.region);
+
+  // Create the HTTP request to sign
+  const httpRequest = new HttpRequest({
+    method: "POST",
+    protocol: "https:",
+    hostname: host,
+    path: `/model/${modelId}/count-tokens`,
+    headers: {
+      ["Host"]: host,
+      ["content-type"]: "application/json",
+    },
+    body: JSON.stringify(request),
+  });
+
+  try {
+    // Sign the request using AWS Signature V4
+    const signedRequest = await sign(httpRequest, credential);
+
+    // Make the request
+    const response = await axios.post<AwsTokenCountResponse>(
+      `https://${host}${signedRequest.path}`,
+      signedRequest.body,
+      {
+        headers: signedRequest.headers as Record<string, string>,
+        timeout: 5000, // 5 second timeout
+      }
+    );
+
+    log.debug(
+      {
+        modelId,
+        input_tokens: response.data.inputTokens,
+      },
+      "Counted tokens via AWS Bedrock API"
+    );
+
+    return {
+      token_count: response.data.inputTokens,
+      tokenizer: "aws-bedrock-remote-api",
+    };
+  } catch (error: any) {
+    log.warn(
+      {
+        error: error.message,
+        status: error.response?.status,
+        data: error.response?.data,
+      },
+      "Failed to count tokens via AWS Bedrock API, will fall back to local tokenizer"
+    );
+    throw error;
+  }
+}
\ No newline at end of file
diff --git a/src/shared/tokenization/claude.ts b/src/shared/tokenization/claude.ts
index 6e1bf0a..cf08179 100644
--- a/src/shared/tokenization/claude.ts
+++ b/src/shared/tokenization/claude.ts
@@ -26,7 +26,7 @@ export async function getTokenCount(
     return getTokenCountForMessages(prompt);
   }
 
-  if (prompt.length > 800000) {
+  if (prompt.length > 5000000) {
     throw new Error("Content is too large to tokenize.");
   }
 
@@ -59,7 +59,7 @@ async function getTokenCountForMessages({
       switch (part.type) {
         case "text":
           const { text } = part;
-          if (text.length > 800000 || numTokens > 200000) {
+          if (text.length > 5000000 || numTokens > 1200000) {
            throw new Error("Text content is too large to tokenize.");
          }
          numTokens += encoder.encode(text.normalize("NFKC"), "all").length;
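The raised guards track the new 1M-token context window rather than the old 200k one. A rough sanity check, assuming roughly four characters per token as a conservative bound (the true ratio varies with content):

const MAX_PROMPT_TOKENS = 1_200_000; // new numTokens guard
const APPROX_CHARS_PER_TOKEN = 4;    // assumption for this estimate only

// ~4.8M characters plausibly fits under the new 5M character guard, so the
// character check no longer rejects prompts that the token check would allow.
console.log(MAX_PROMPT_TOKENS * APPROX_CHARS_PER_TOKEN); // 4800000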
diff --git a/src/shared/tokenization/gcp-remote.ts b/src/shared/tokenization/gcp-remote.ts
new file mode 100644
index 0000000..087a7c2
--- /dev/null
+++ b/src/shared/tokenization/gcp-remote.ts
@@ -0,0 +1,104 @@
+import { getAxiosInstance } from "../network";
+import { logger } from "../../logger";
+import { GcpKey } from "../key-management/gcp/provider";
+import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "../key-management/gcp/oauth";
+
+const log = logger.child({ module: "tokenizer", service: "gcp-remote" });
+
+export interface GcpTokenCountRequest {
+  contents: Array<{
+    role: string;
+    parts: Array<{
+      text?: string;
+      inline_data?: {
+        mime_type: string;
+        data: string;
+      };
+    }>;
+  }>;
+  systemInstruction?: {
+    parts: Array<{
+      text: string;
+    }>;
+  };
+  tools?: unknown[];
+}
+
+export interface GcpTokenCountResponse {
+  totalTokens: number;
+  totalBillableCharacters?: number;
+  promptTokensDetails?: Array<{
+    modality: string;
+    tokenCount: number;
+  }>;
+}
+
+/**
+ * Counts tokens using GCP Vertex AI's remote token counting API endpoint.
+ * https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/count-tokens
+ */
+export async function countTokensRemote(
+  model: string,
+  request: GcpTokenCountRequest,
+  key: GcpKey
+): Promise<{ token_count: number; tokenizer: string }> {
+  const axios = getAxiosInstance();
+
+  // Ensure we have a valid access token
+  const now = Date.now();
+  if (!key.accessToken || now >= key.accessTokenExpiresAt) {
+    const [token, expiresIn] = await refreshGcpAccessToken(key);
+    key.accessToken = token;
+    key.accessTokenExpiresAt = now + expiresIn * 1000;
+  }
+
+  const { projectId, region } = await getCredentialsFromGcpKey(key);
+
+  // Extract just the model name (e.g., "gemini-1.5-pro" from full model ID).
+  // GCP model IDs are typically like "gemini-1.5-pro-001" or just "gemini-1.5-pro".
+  let modelName = model;
+  if (model.includes("/")) {
+    // Handle full resource names like "projects/.../models/gemini-1.5-pro"
+    modelName = model.split("/").pop()!;
+  }
+
+  const url = `https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${region}/publishers/google/models/${modelName}:countTokens`;
+
+  try {
+    const response = await axios.post<GcpTokenCountResponse>(
+      url,
+      request,
+      {
+        headers: {
+          "Authorization": `Bearer ${key.accessToken}`,
+          "content-type": "application/json",
+        },
+        timeout: 5000, // 5 second timeout
+      }
+    );
+
+    log.debug(
+      {
+        model: modelName,
+        total_tokens: response.data.totalTokens,
+        details: response.data.promptTokensDetails,
+      },
+      "Counted tokens via GCP Vertex AI API"
+    );
+
+    return {
+      token_count: response.data.totalTokens,
+      tokenizer: "gcp-vertex-ai-remote-api",
+    };
+  } catch (error: any) {
+    log.warn(
+      {
+        error: error.message,
+        status: error.response?.status,
+        data: error.response?.data,
+      },
+      "Failed to count tokens via GCP Vertex AI API, will fall back to local tokenizer"
+    );
+    throw error;
+  }
+}
\ No newline at end of file
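For Gemini-style bodies the countTokens payload maps over directly. A minimal sketch using the interfaces defined above (the model name and prompt are illustrative):

import { countTokensRemote, GcpTokenCountRequest } from "./gcp-remote";
import { GcpKey } from "../key-management/gcp/provider";

async function countGeminiTokens(key: GcpKey): Promise<number> {
  // Build a Vertex-format request; Anthropic-on-Vertex bodies would need a
  // transformation step first, which is why the tokenizer below falls back
  // to local counting for GCP.
  const request: GcpTokenCountRequest = {
    contents: [{ role: "user", parts: [{ text: "Hello, world" }] }],
    systemInstruction: { parts: [{ text: "You are a helpful assistant." }] },
  };
  const { token_count } = await countTokensRemote("gemini-1.5-pro", request, key);
  return token_count;
}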
diff --git a/src/shared/tokenization/tokenizer.ts b/src/shared/tokenization/tokenizer.ts
index b3b2457..b8d805b 100644
--- a/src/shared/tokenization/tokenizer.ts
+++ b/src/shared/tokenization/tokenizer.ts
@@ -21,6 +21,17 @@ import {
   MistralAIChatMessage,
   OpenAIChatMessage,
 } from "../api-schemas";
+import { countTokensRemote as countAnthropicTokensRemote } from "./anthropic-remote";
+import { countTokensRemote as countAwsTokensRemote } from "./aws-remote";
+import { countTokensRemote as countGcpTokensRemote } from "./gcp-remote";
+import { AnthropicKey } from "../key-management/anthropic/provider";
+import { AwsBedrockKey } from "../key-management/aws/provider";
+import { GcpKey } from "../key-management/gcp/provider";
+import { logger } from "../../logger";
+import { config } from "../../config";
+import { findByAnthropicId } from "../claude-models";
+
+const log = logger.child({ module: "tokenizer" });
 
 export async function init() {
   initClaude();
@@ -99,6 +110,87 @@ export async function countTokens({
   completion,
 }: TokenCountRequest): Promise<TokenCountResult> {
   const time = process.hrtime();
+
+  // Try the remote APIs first, but only when enabled, a key is assigned, and
+  // we are counting a prompt rather than a completion.
+  if (config.useRemoteTokenCounting && prompt && req.key) {
+    try {
+      switch (service) {
+        case "anthropic-chat": {
+          if (req.service === "anthropic" && req.key.service === "anthropic") {
+            const result = await countAnthropicTokensRemote(
+              {
+                model: req.body.model,
+                messages: prompt.messages,
+                system: prompt.system,
+                tools: req.body.tools,
+              },
+              req.key as AnthropicKey
+            );
+            return { ...result, tokenization_duration_ms: getElapsedMs(time) };
+          }
+          break;
+        }
+        case "anthropic-text":
+          // Anthropic's API has no counting endpoint for text completion
+          // prompts, so fall through to the local tokenizer.
+          break;
+      }
+
+      // For AWS Bedrock services, use AWS token counting
+      if (req.service === "aws" && req.key.service === "aws") {
+        if (service === "anthropic-chat") {
+          // Convert Anthropic model ID to AWS Bedrock model ID
+          const anthropicModelId = req.body.model;
+          const claudeMapping = findByAnthropicId(anthropicModelId);
+          const awsModelId = claudeMapping?.awsId || anthropicModelId;
+
+          // Build the request body in Anthropic format; it must include every
+          // field that InvokeModel would require.
+          const bodyObj: any = {
+            messages: req.body.messages,
+            max_tokens: req.body.max_tokens,
+            anthropic_version: req.body.anthropic_version || "bedrock-2023-05-31",
+          };
+          if (req.body.system) bodyObj.system = req.body.system;
+          if (req.body.tools) bodyObj.tools = req.body.tools;
+          if (req.body.tool_choice) bodyObj.tool_choice = req.body.tool_choice;
+
+          // AWS expects the body as a base64-encoded string
+          const bodyJson = JSON.stringify(bodyObj);
+          const bodyBase64 = Buffer.from(bodyJson, "utf-8").toString("base64");
+
+          const result = await countAwsTokensRemote(
+            awsModelId,
+            {
+              input: {
+                invokeModel: {
+                  body: bodyBase64,
+                },
+              },
+            },
+            req.key as AwsBedrockKey
+          );
+          return { ...result, tokenization_duration_ms: getElapsedMs(time) };
+        }
+      }
+
+      // For GCP Vertex AI services, remote counting would require transforming
+      // the request into the Vertex content format first, so fall through to
+      // local counting for now.
+      if (req.service === "gcp" && req.key.service === "gcp") {
+        // Intentionally unhandled; see gcp-remote.ts for the endpoint client.
+      }
+    } catch (error) {
+      // Fall through to local tokenization
+      log.debug(
+        { error: (error as Error).message, service },
+        "Remote token counting failed, using local tokenizer"
+      );
+    }
+  }
+
+  // Fall back to local tokenization
   switch (service) {
     case "anthropic-chat":
     case "anthropic-text":
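Usage note: the Bedrock path is the only one that wraps the provider body before sending. A standalone sketch of that base64 round trip (field values are illustrative):

// Build the same Anthropic-format body the tokenizer assembles above, then
// base64-encode it the way the Bedrock CountTokens API expects.
const bodyObj = {
  messages: [{ role: "user", content: "Hello, world" }],
  max_tokens: 1024,
  anthropic_version: "bedrock-2023-05-31",
};

const bodyBase64 = Buffer.from(JSON.stringify(bodyObj), "utf-8").toString("base64");

const countTokensPayload = {
  input: { invokeModel: { body: bodyBase64 } },
};

// Decoding reverses the wrap, which is handy when debugging signed requests:
const roundTripped = JSON.parse(
  Buffer.from(countTokensPayload.input.invokeModel.body, "base64").toString("utf-8")
);
console.log(roundTripped.anthropic_version); // "bedrock-2023-05-31"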