claude tokenizer+cache working

This commit is contained in:
reanon
2025-10-02 12:57:06 +02:00
parent 3247698173
commit b16bb6d17d
25 changed files with 1427 additions and 56 deletions
+9
View File
@@ -286,6 +286,14 @@ type Config = {
googleSheetsSpreadsheetId?: string;
/** Whether to periodically check keys for usage and validity. */
checkKeys: boolean;
/**
* Whether to use remote API token counting endpoints for Anthropic, AWS
* Bedrock, and GCP Vertex AI. When enabled, the proxy will use the provider's
* token counting API to get accurate token counts for prompts (including
* images). Output tokens are always counted from API responses when available.
* Falls back to local tokenization if remote counting fails.
*/
useRemoteTokenCounting: boolean;
/** Whether to publicly show total token costs on the info page. */
showTokenCosts: boolean;
/**
@@ -567,6 +575,7 @@ export const config: Config = {
),
logLevel: getEnvWithDefault("LOG_LEVEL", "info"),
checkKeys: getEnvWithDefault("CHECK_KEYS", !isDev),
useRemoteTokenCounting: getEnvWithDefault("USE_REMOTE_TOKEN_COUNTING", true),
showTokenCosts: getEnvWithDefault("SHOW_TOKEN_COSTS", false),
allowAwsLogging: getEnvWithDefault("ALLOW_AWS_LOGGING", false),
promptLogging: getEnvWithDefault("PROMPT_LOGGING", false),
+13 -6
View File
@@ -196,8 +196,11 @@ function setAnthropicBetaHeader(req: Request) {
betaHeaders.push("extended-cache-ttl-2025-04-11");
}
// Add 1M context beta header for Claude Sonnet 4 if context > 200k tokens
if (model?.includes("claude-sonnet-4") && req.promptTokens && req.outputTokens) {
// Add 1M context beta header for Claude Sonnet 4/Opus 4 if context > 200k tokens
const supportsBigContext =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (supportsBigContext && req.promptTokens && req.outputTokens) {
const contextTokens = req.promptTokens + req.outputTokens;
if (contextTokens > 200000) {
betaHeaders.push("context-1m-2025-08-07");
@@ -211,8 +214,8 @@ function setAnthropicBetaHeader(req: Request) {
}
/**
* Adds web search tool for Claude-3.5 and Claude-3.7 models when enable_web_search is true
*
* Adds web search tool for Claude-3.5, Claude-3.7, Claude-4, and Claude-4.5 models when enable_web_search is true
*
* Supports all optional parameters documented in the Claude API:
* - max_uses: Limit the number of searches per request
* - allowed_domains: Only include results from these domains
@@ -323,7 +326,9 @@ const textToChatPreprocessor = createPreprocessorMiddleware(
*/
const preprocessAnthropicTextRequest: RequestHandler = (req, res, next) => {
const model = req.body.model;
const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4Model =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.startsWith("claude-3") || isClaude4Model) {
textToChatPreprocessor(req, res, next);
} else {
@@ -356,7 +361,9 @@ const oaiToChatPreprocessor = createPreprocessorMiddleware(
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
maybeReassignModel(req);
const model = req.body.model;
const isClaude4 = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4 =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.includes("claude-3") || isClaude4) {
oaiToChatPreprocessor(req, res, next);
} else {
+6 -2
View File
@@ -102,7 +102,9 @@ const textToChatPreprocessor = createPreprocessorMiddleware(
*/
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
const model = req.body.model;
const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4Model =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.includes("claude-3") || isClaude4Model) {
textToChatPreprocessor(req, res, next);
} else {
@@ -126,7 +128,9 @@ const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
const model = req.body.model;
const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4Model =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.includes("claude-3") || isClaude4Model) {
oaiToAwsChatPreprocessor(req, res, next);
} else {
@@ -32,8 +32,9 @@ export const addKey: ProxyReqMutator = (manager) => {
if (inboundApi === outboundApi) {
// Pass streaming information for GPT-5 models that require verified keys for streaming
// Pass request body for cache-aware key selection (Anthropic, AWS, GCP)
const isStreaming = body.stream === true;
assignedKey = keyPool.get(body.model, service, needsMultimodal, isStreaming);
assignedKey = keyPool.get(body.model, service, needsMultimodal, isStreaming, body);
} else {
switch (outboundApi) {
// If we are translating between API formats we may need to select a model
@@ -45,7 +46,7 @@ export const addKey: ProxyReqMutator = (manager) => {
case "mistral-ai":
case "mistral-text":
case "google-ai":
assignedKey = keyPool.get(body.model, service);
assignedKey = keyPool.get(body.model, service, undefined, undefined, body);
break;
case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
@@ -24,7 +24,8 @@ const AMZ_HOST =
export const signAwsRequest: ProxyReqMutator = async (manager) => {
const req = manager.request;
const { model, stream } = req.body;
const key = keyPool.get(model, "aws") as AwsBedrockKey;
// Pass request body to enable cache-aware key selection
const key = keyPool.get(model, "aws", undefined, stream, req.body) as AwsBedrockKey;
manager.setKey(key);
let system = req.body.system ?? "";
@@ -50,6 +51,8 @@ export const signAwsRequest: ProxyReqMutator = async (manager) => {
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
// with the headers generated by the SDK.
// Note: AWS Bedrock uses body parameters (anthropic_version) instead of headers
// for versioning, so we don't include anthropic-beta or anthropic-version headers.
const newRequest = new HttpRequest({
method: "POST",
protocol: "https:",
@@ -130,6 +133,9 @@ function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
.parse(req.body);
break;
case "anthropic-chat":
// Preserve anthropic_version if user provided it (for beta features)
const userAnthropicVersion = (req.body as any).anthropic_version;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,
@@ -140,11 +146,14 @@ function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
top_p: true,
tools: true,
tool_choice: true,
thinking: true
thinking: true,
cache_control: true
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "bedrock-2023-05-31";
// Use user-provided version or default to bedrock-2023-05-31
strippedParams.anthropic_version = userAnthropicVersion || "bedrock-2023-05-31";
break;
case "mistral-ai":
strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
@@ -20,7 +20,7 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
}
const { model } = req.body;
const key: GcpKey = keyPool.get(model, "gcp") as GcpKey;
const key: GcpKey = keyPool.get(model, "gcp", undefined, undefined, req.body) as GcpKey;
if (!key.accessToken || Date.now() > key.accessTokenExpiresAt) {
const [token, durationSec] = await refreshGcpAccessToken(key);
@@ -38,6 +38,9 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
// TODO: This should happen in transform-outbound-payload.ts
// TODO: Support tools
// Preserve anthropic_version if user provided it (for beta features)
const userAnthropicVersion = (req.body as any).anthropic_version;
let strippedParams: Record<string, unknown>;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
@@ -50,11 +53,14 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
stream: true,
tools: true,
tool_choice: true,
thinking: true
thinking: true,
cache_control: true
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "vertex-2023-10-16";
// Use user-provided version or default to vertex-2023-10-16
strippedParams.anthropic_version = userAnthropicVersion || "vertex-2023-10-16";
const credential = await getCredentialsFromGcpKey(key);
@@ -63,6 +69,8 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
// stream adapter selects the correct transformer.
manager.setHeader("anthropic-version", "2023-06-01");
// GCP Vertex AI uses body parameter anthropic_version (set above) for versioning,
// not anthropic-beta header. Beta features are enabled through body params.
manager.setSignedRequest({
method: "POST",
protocol: "https:",
@@ -7,21 +7,50 @@ import {
AnthropicChatMessage,
flattenAnthropicMessages,
} from "../../../../shared/api-schemas/anthropic";
import {
MistralAIChatMessage,
import {
MistralAIChatMessage,
ContentItem,
isMistralVisionModel
isMistralVisionModel
} from "../../../../shared/api-schemas/mistral-ai";
import { isGrokVisionModel } from "../../../../shared/api-schemas/xai";
import { keyPool } from "../../../../shared/key-management";
import { config } from "../../../../config";
/**
* Given a request with an already-transformed body, counts the number of
* tokens and assigns the count to the request.
*
* If remote token counting is enabled, we temporarily get a key from the pool
* to use the remote API, then clear it. The actual key assignment happens
* later in the mutators after the request is dequeued.
*/
export const countPromptTokens: RequestPreprocessor = async (req) => {
const service = req.outboundApi;
let result;
// For remote token counting, temporarily get a key from the pool
// We don't permanently assign it - that happens in mutators after dequeue
// IMPORTANT: Don't pass req.body here to avoid premature cache fingerprinting
// Cache fingerprinting should only happen during actual key assignment in mutators,
// after all preprocessing (including model name transformations) is complete
let tempKey;
const hadKey = !!req.key;
if (config.useRemoteTokenCounting && !req.key) {
try {
tempKey = keyPool.get(req.body.model, req.service, undefined, undefined);
req.key = tempKey; // Temporarily assign for token counting
req.log.debug(
{ keyHash: tempKey.hash, service: req.service },
"Temporarily assigned key for remote token counting"
);
} catch (error) {
req.log.debug(
{ error: (error as Error).message },
"Could not get key for remote token counting, will use local tokenizer"
);
}
}
switch (service) {
case "openai": {
req.outputTokens = req.body.max_completion_tokens || req.body.max_tokens;
@@ -129,4 +158,14 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
req.log.debug({ result: result }, "Counted prompt tokens.");
req.tokenizerInfo = req.tokenizerInfo ?? {};
req.tokenizerInfo = { ...req.tokenizerInfo, ...result };
// Clear the temporary key if we assigned one for token counting
// The real key will be assigned later in mutators after dequeue
if (tempKey && !hadKey) {
delete req.key;
req.log.debug(
{ keyHash: tempKey.hash },
"Cleared temporary key after token counting"
);
}
};
@@ -121,13 +121,13 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 200000;
} else if (model.match(/^claude-3/)) {
modelMax = 200000;
} else if (model.match(/^claude-(?:sonnet|opus)-4/)) {
} else if (model.match(/^claude-(?:sonnet|opus)-4(?:-5)?/)) {
modelMax = 1000000;
} else if (model.match(/^gemini-/)) {
modelMax = 1024000;
} else if (model.match(/^anthropic\.claude-3/)) {
modelMax = 200000;
} else if (model.match(/^anthropic\.claude-(?:sonnet|opus)-4/)) {
} else if (model.match(/^anthropic\.claude-(?:sonnet|opus)-4(?:-5)?/)) {
modelMax = 1000000;
} else if (model.match(/^anthropic\.claude-v2:\d/)) {
modelMax = 200000;
+153 -9
View File
@@ -1094,15 +1094,130 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
try {
assertJsonResponse(body);
const service = req.outboundApi;
const completion = getCompletionFromBody(req, body);
const tokens = await countTokens({ req, completion, service });
if (req.service === "openai" || req.service === "azure" || req.service === "deepseek" || req.service === "glm" || req.service === "cohere" || req.service === "qwen") {
// O1 consumes (a significant amount of) invisible tokens for the chain-
// of-thought reasoning. We have no way to count these other than to check
// the response body.
tokens.reasoning_tokens =
body.usage?.completion_tokens_details?.reasoning_tokens;
// Try to get token counts from the API response first
let tokens: { token_count: number; tokenizer: string; reasoning_tokens?: number } | null = null;
// Anthropic API returns usage data in the response
if (service === "anthropic-chat" && body.usage) {
tokens = {
token_count: body.usage.output_tokens || 0,
tokenizer: "anthropic-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from Anthropic API response"
);
// Sanity check: if request had cache_control, expect cache metrics
if (req.body.system || req.body.tools || req.body.messages) {
const hasCacheControl = checkForCacheControl(req.body);
if (hasCacheControl) {
const cacheRead = body.usage.cache_read_input_tokens || 0;
const cacheCreation = body.usage.cache_creation_input_tokens || 0;
if (cacheRead === 0 && cacheCreation === 0) {
req.log.error(
{ keyHash: req.key?.hash, usage: body.usage },
"CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from Anthropic API"
);
}
}
}
}
// AWS Bedrock returns usage data in the response (same format as Anthropic)
else if (req.service === "aws" && service === "anthropic-chat" && body.usage) {
tokens = {
token_count: body.usage.output_tokens || 0,
tokenizer: "aws-bedrock-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from AWS Bedrock API response"
);
// Sanity check: if request had cache_control, expect cache metrics
if (req.body.system || req.body.tools || req.body.messages) {
const hasCacheControl = checkForCacheControl(req.body);
if (hasCacheControl) {
const cacheRead = body.usage.cache_read_input_tokens || 0;
const cacheCreation = body.usage.cache_creation_input_tokens || 0;
if (cacheRead === 0 && cacheCreation === 0) {
req.log.error(
{ keyHash: req.key?.hash, usage: body.usage },
"CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from AWS Bedrock API"
);
}
}
}
}
// GCP Vertex AI returns usage data in the response
// For Anthropic models, GCP returns Anthropic format (usage.output_tokens)
// For Gemini models, GCP returns GCP format (usageMetadata.candidatesTokenCount)
else if (req.service === "gcp") {
if (service === "anthropic-chat" && body.usage?.output_tokens) {
tokens = {
token_count: body.usage.output_tokens || 0,
tokenizer: "gcp-anthropic-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from GCP Vertex AI (Anthropic format)"
);
// Sanity check: if request had cache_control, expect cache metrics
if (req.body.system || req.body.tools || req.body.messages) {
const hasCacheControl = checkForCacheControl(req.body);
if (hasCacheControl) {
const cacheRead = body.usage.cache_read_input_tokens || 0;
const cacheCreation = body.usage.cache_creation_input_tokens || 0;
if (cacheRead === 0 && cacheCreation === 0) {
req.log.error(
{ keyHash: req.key?.hash, usage: body.usage },
"CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from GCP Vertex AI API"
);
}
}
}
} else if (body.usageMetadata) {
tokens = {
token_count: body.usageMetadata.candidatesTokenCount || 0,
tokenizer: "gcp-vertex-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usageMetadata: body.usageMetadata },
"Got output token count from GCP Vertex AI (Gemini format)"
);
}
}
// OpenAI and similar services return usage data
else if (body.usage?.completion_tokens) {
tokens = {
token_count: body.usage.completion_tokens,
tokenizer: "api-usage-data",
};
if (req.service === "openai" || req.service === "azure" || req.service === "deepseek" || req.service === "glm" || req.service === "cohere" || req.service === "qwen") {
// O1 consumes (a significant amount of) invisible tokens for the chain-
// of-thought reasoning. We have no way to count these other than to check
// the response body.
tokens.reasoning_tokens =
body.usage?.completion_tokens_details?.reasoning_tokens;
}
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from API usage data"
);
}
// Fall back to local tokenization if no usage data is available
if (!tokens) {
const completion = getCompletionFromBody(req, body);
tokens = await countTokens({ req, completion, service });
req.log.debug(
{ service, outputTokens: tokens.token_count },
"Counted output tokens locally (no API usage data)"
);
}
req.log.debug(
@@ -1193,6 +1308,35 @@ function getAwsErrorType(header: string | string[] | undefined) {
return val || String(header);
}
/**
 * Returns true when a cache_control breakpoint appears anywhere in the
 * request body: on a tool definition, on a system block, or on a message
 * content block. Used to sanity-check that the provider returned cache
 * metrics for requests that asked for prompt caching.
 */
function checkForCacheControl(body: any): boolean {
const toolList = Array.isArray(body.tools) ? body.tools : [];
if (toolList.some((tool: any) => tool.cache_control)) return true;

const systemBlocks = Array.isArray(body.system) ? body.system : [];
if (systemBlocks.some((block: any) => block.cache_control)) return true;

const messageList = Array.isArray(body.messages) ? body.messages : [];
return messageList.some(
(message: any) =>
Array.isArray(message.content) &&
message.content.some((block: any) => block.cache_control)
);
}
function assertJsonResponse(body: any): asserts body is Record<string, any> {
if (typeof body !== "object") {
throw new Error(`Expected response to be an object, got ${typeof body}`);
@@ -8,7 +8,12 @@ export type AnthropicChatCompletionResponse = {
model: string;
stop_reason: string | null;
stop_sequence: string | null;
usage: { input_tokens: number; output_tokens: number };
usage: {
input_tokens: number;
output_tokens: number;
cache_creation_input_tokens?: number;
cache_read_input_tokens?: number;
};
};
/**
@@ -43,6 +48,24 @@ export function mergeEventsForAnthropicChat(
acc.content[0].text += event.choices[0].delta.content;
}
// OpenAI events may include usage data (extended by our transformer)
// Usage tokens should be set to the latest value, not accumulated
// (they represent total counts, not deltas)
if ((event as any).usage) {
if ((event as any).usage.input_tokens !== undefined) {
acc.usage.input_tokens = (event as any).usage.input_tokens;
}
if ((event as any).usage.output_tokens !== undefined) {
acc.usage.output_tokens = (event as any).usage.output_tokens;
}
if ((event as any).usage.cache_creation_input_tokens !== undefined) {
acc.usage.cache_creation_input_tokens = (event as any).usage.cache_creation_input_tokens;
}
if ((event as any).usage.cache_read_input_tokens !== undefined) {
acc.usage.cache_read_input_tokens = (event as any).usage.cache_read_input_tokens;
}
}
return acc;
}, merged);
return merged;
@@ -10,6 +10,11 @@ export type OpenAiChatCompletionResponse = {
finish_reason: string | null;
index: number;
}[];
usage?: {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
};
/**
@@ -52,6 +57,17 @@ export function mergeEventsForOpenAIChat(
acc.choices[0].message.content += event.choices[0].delta.content;
}
// Accumulate usage data from events (OpenAI may send this in the final event)
if ((event as any).usage) {
if (!acc.usage) {
acc.usage = {};
}
const usage = (event as any).usage;
if (usage.prompt_tokens) acc.usage.prompt_tokens = usage.prompt_tokens;
if (usage.completion_tokens) acc.usage.completion_tokens = usage.completion_tokens;
if (usage.total_tokens) acc.usage.total_tokens = usage.total_tokens;
}
return acc;
}, merged);
return merged;
@@ -50,12 +50,32 @@ export class SSEStreamAdapter extends Transform {
const event = Buffer.from(bytes, "base64").toString("utf8");
const eventObj = JSON.parse(event);
// AWS Bedrock includes usage metrics in the event stream headers
// Extract and attach them to the event object for downstream processing
const invocationMetrics = headers["amazon-bedrock-invocationMetrics"];
if (invocationMetrics?.value) {
try {
const metricsStr = typeof invocationMetrics.value === 'string'
? invocationMetrics.value
: JSON.stringify(invocationMetrics.value);
const metricsObj = JSON.parse(metricsStr);
eventObj["amazon-bedrock-invocationMetrics"] = metricsObj;
} catch (e) {
this.log.warn(
{ invocationMetrics: invocationMetrics.value },
"Failed to parse AWS invocationMetrics"
);
}
}
const eventWithMetrics = JSON.stringify(eventObj);
if ("completion" in eventObj) {
return ["event: completion", `data: ${event}`].join(`\n`);
return ["event: completion", `data: ${eventWithMetrics}`].join(`\n`);
} else if (eventObj.type) {
return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
return [`event: ${eventObj.type}`, `data: ${eventWithMetrics}`].join(`\n`);
} else {
return `data: ${event}`;
return `data: ${eventWithMetrics}`;
}
}
// noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected
@@ -22,8 +22,70 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
return { position: -1 };
}
// Try to extract usage data from message_start and message_delta events
// Also check for AWS Bedrock invocationMetrics
let usageData: {
input_tokens?: number;
output_tokens?: number;
cache_creation_input_tokens?: number;
cache_read_input_tokens?: number;
} | undefined;
try {
const parsed = JSON.parse(rawEvent.data);
if (parsed.type === "message_start" && parsed.message?.usage) {
usageData = {
input_tokens: parsed.message.usage.input_tokens,
output_tokens: parsed.message.usage.output_tokens,
cache_creation_input_tokens: parsed.message.usage.cache_creation_input_tokens,
cache_read_input_tokens: parsed.message.usage.cache_read_input_tokens,
};
} else if (parsed.type === "message_delta" && parsed.delta?.usage) {
usageData = {
output_tokens: parsed.delta.usage.output_tokens,
cache_creation_input_tokens: parsed.delta.usage.cache_creation_input_tokens,
cache_read_input_tokens: parsed.delta.usage.cache_read_input_tokens,
};
}
// AWS Bedrock includes usage in amazon-bedrock-invocationMetrics
// AWS uses PascalCase field names (CacheReadInputTokens, CacheWriteInputTokens)
else if (parsed["amazon-bedrock-invocationMetrics"]) {
const metrics = parsed["amazon-bedrock-invocationMetrics"];
usageData = {
input_tokens: metrics.inputTokenCount,
output_tokens: metrics.outputTokenCount,
// Map AWS PascalCase to Anthropic snake_case
cache_read_input_tokens: metrics.cacheReadInputTokenCount,
cache_creation_input_tokens: metrics.cacheWriteInputTokenCount,
};
log.debug(
{ metrics, usageData },
"Extracted usage from AWS invocationMetrics"
);
}
} catch (e) {
// Ignore parsing errors
}
const deltaEvent = asAnthropicChatDelta(rawEvent);
if (!deltaEvent) {
// If we have usage data but no delta, still emit an event with usage
if (usageData) {
const usageEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: {},
finish_reason: null,
},
],
usage: usageData,
};
return { position: -1, event: usageEvent };
}
return { position: -1 };
}
@@ -39,6 +101,7 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
finish_reason: null,
},
],
...(usageData && { usage: usageData }),
};
return { position: -1, event: newEvent };
@@ -15,6 +15,11 @@ type GoogleAIStreamEvent = {
tokenCount?: number;
safetyRatings: { category: string; probability: string }[];
}[];
usageMetadata?: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
};
/**
@@ -49,7 +54,7 @@ export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
content = content.replace(/^(.*?): /, "").trim();
}
const newEvent = {
const newEvent: any = {
id: "goo-" + params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
@@ -63,6 +68,19 @@ export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
],
};
// Extract usage metadata from GCP/Vertex AI responses
if (completionEvent.usageMetadata) {
newEvent.usage = {
prompt_tokens: completionEvent.usageMetadata.promptTokenCount,
completion_tokens: completionEvent.usageMetadata.candidatesTokenCount,
total_tokens: completionEvent.usageMetadata.totalTokenCount,
};
log.debug(
{ usageMetadata: completionEvent.usageMetadata },
"Extracted usage from GCP usageMetadata"
);
}
return { position: -1, event: newEvent };
};
@@ -5,6 +5,11 @@ import { logger } from "../../../logger";
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
import { AnthropicKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
import {
generateCacheFingerprint,
recordCacheUsage,
getCachedKeyHash,
} from "../cache-tracker";
export type AnthropicKeyUpdate = Omit<
Partial<AnthropicKey>,
@@ -136,7 +141,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(rawModel: string) {
public get(rawModel: string, _streaming?: boolean, requestBody?: any) {
this.log.debug({ model: rawModel }, "Selecting key");
const needsMultimodal = rawModel.endsWith("-multimodal");
@@ -152,7 +157,44 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
);
}
// Generate cache fingerprint if request body contains cache_control
const cacheFingerprint = requestBody
? generateCacheFingerprint(requestBody)
: null;
// Try to get cached key if we have a fingerprint
let preferredKeyHash: string | null = null;
let matchedFingerprint: string | null = null;
if (cacheFingerprint) {
const cacheResult = getCachedKeyHash(cacheFingerprint);
if (cacheResult) {
preferredKeyHash = cacheResult.keyHash;
matchedFingerprint = cacheResult.matchedFingerprint;
// Check if the cached key is still available
const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
if (cachedKey) {
this.log.debug(
{
requestedModel: rawModel,
cacheFingerprint,
keyHash: preferredKeyHash,
},
"Using cached key for prompt caching optimization"
);
} else {
// Cached key no longer available
preferredKeyHash = null;
matchedFingerprint = null;
this.log.debug(
{ cacheFingerprint, keyHash: preferredKeyHash },
"Cached key not available, selecting new key"
);
}
}
}
// Select a key, from highest priority to lowest priority:
// 0. Cache affinity (if we have a cached key preference)
// 1. Keys which are not rate limit locked
// 2. Keys with the highest tier
// 3. Keys which are not pozzed
@@ -161,6 +203,12 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
// Highest priority: cache affinity
if (preferredKeyHash) {
if (a.hash === preferredKeyHash) return -1;
if (b.hash === preferredKeyHash) return 1;
}
const aLockoutPeriod = getKeyLockout(a);
const bLockoutPeriod = getKeyLockout(b);
@@ -183,6 +231,13 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
// Record cache usage for future requests
// Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint
if (cacheFingerprint) {
recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
}
return { ...selectedKey };
}
+69 -12
View File
@@ -7,6 +7,11 @@ import { findByAnthropicId } from "../../claude-models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker";
import {
generateCacheFingerprint,
recordCacheUsage,
getCachedKeyHash,
} from "../cache-tracker";
// AwsBedrockKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface AwsBedrockKey extends Key {
@@ -90,14 +95,14 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: string) {
public get(model: string, _streaming?: boolean, requestBody?: any) {
let neededVariantId = model;
// This function accepts both Anthropic/Mistral IDs and AWS IDs.
// Generally all AWS model IDs are supersets of the original vendor IDs.
// Claude 2 is the only model that breaks this convention; Anthropic calls
// it claude-2 but AWS calls it claude-v2.
if (model.includes("claude-2")) neededVariantId = "claude-v2";
// For Claude models, try to resolve aliases to AWS model IDs
if (model.includes("claude") && !model.includes("anthropic.")) {
const claudeMapping = findByAnthropicId(model);
@@ -105,7 +110,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
neededVariantId = claudeMapping.awsId;
}
}
const neededFamily = getAwsBedrockModelFamily(model);
const availableKeys = this.keys.filter((k) => {
@@ -122,6 +127,42 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
);
});
// Generate cache fingerprint if request body contains cache_control
const cacheFingerprint = requestBody
? generateCacheFingerprint(requestBody)
: null;
// Try to get cached key if we have a fingerprint
let preferredKeyHash: string | null = null;
let matchedFingerprint: string | null = null;
if (cacheFingerprint) {
const cacheResult = getCachedKeyHash(cacheFingerprint);
if (cacheResult) {
preferredKeyHash = cacheResult.keyHash;
matchedFingerprint = cacheResult.matchedFingerprint;
// Check if the cached key is still available
const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
if (cachedKey) {
this.log.debug(
{
requestedModel: model,
cacheFingerprint,
keyHash: preferredKeyHash,
},
"Using cached key for prompt caching optimization"
);
} else {
// Cached key no longer available
preferredKeyHash = null;
matchedFingerprint = null;
this.log.debug(
{ cacheFingerprint, keyHash: preferredKeyHash },
"Cached key not available, selecting new key"
);
}
}
}
this.log.debug(
{
requestedModel: model,
@@ -129,6 +170,8 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
selectedFamily: neededFamily,
totalKeys: this.keys.length,
availableKeys: availableKeys.length,
cacheFingerprint,
hasCachedKey: !!preferredKeyHash,
},
"Selecting AWS key"
);
@@ -140,22 +183,36 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
}
/**
* Comparator for prioritizing keys on inference profile compatibility.
* Requests made via inference profiles have higher rate limits so we want
* to use keys with compatible inference profiles first.
* Comparator for prioritizing keys based on:
* 1. Cache affinity (if we have a cached key preference)
* 2. Inference profile compatibility
*/
const hasInferenceProfile = (
a: AwsBedrockKey,
b: AwsBedrockKey
) => {
const keyComparator = (a: AwsBedrockKey, b: AwsBedrockKey) => {
// Highest priority: cache affinity
if (preferredKeyHash) {
if (a.hash === preferredKeyHash) return -1;
if (b.hash === preferredKeyHash) return 1;
}
// Second priority: inference profile compatibility
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
return aMatch - bMatch;
const profileDiff = bMatch - aMatch;
if (profileDiff !== 0) return profileDiff;
return 0;
};
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
const selectedKey = prioritizeKeys(availableKeys, keyComparator)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
// Record cache usage for future requests
// Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint
if (cacheFingerprint) {
recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
}
return { ...selectedKey };
}
+457
View File
@@ -0,0 +1,457 @@
import crypto from "crypto";
import { logger } from "../../logger";
/**
 * Deterministic JSON stringify that sorts object keys to ensure consistent hashing.
 *
 * Unlike JSON.stringify, key order is canonicalized so two structurally equal
 * objects always produce the same string (and therefore the same hash).
 * null and undefined render as "null"/"undefined".
 */
function deterministicStringify(obj: any): string {
  if (obj === null || obj === undefined) return String(obj);
  if (typeof obj !== "object") return JSON.stringify(obj);
  if (Array.isArray(obj)) return `[${obj.map(deterministicStringify).join(",")}]`;
  const keys = Object.keys(obj).sort();
  // JSON.stringify the key itself so quotes/backslashes/control characters in
  // property names are escaped; raw `"${k}"` interpolation produced malformed
  // output for such keys, which could make distinct objects hash identically.
  const pairs = keys.map(
    (k) => `${JSON.stringify(k)}:${deterministicStringify(obj[k])}`
  );
  return `{${pairs.join(",")}}`;
}
/**
* Universal cache tracker for all providers (Anthropic, AWS, GCP).
*
* Tracks which keys have cached which prompt prefixes to optimize prompt caching.
* Each API key has its own cache, so routing requests with identical cacheable
* content to the same key maximizes cache hits.
*
* Cache rules (per Anthropic/Bedrock/Vertex docs):
* - 100% identical prompt segments required for cache hit
* - Default 5-minute TTL, refreshed on each use
* - Optional 1-hour TTL
* - Cache becomes available after first response begins
* - Minimum 1024-2048 tokens (model-dependent)
*/
/** Tracks which API key holds a provider-side prompt cache for a fingerprint. */
interface CacheEntry {
  // Hash of the key that created the cache; "" for placeholder entries until
  // recordCacheUsage assigns a real key.
  keyHash: string;
  // Timestamp (ms) of last use; the entry expires when now - lastUsed > ttl.
  lastUsed: number;
  // Number of consecutive uses served by the same key.
  hitCount: number;
  // Entry lifetime in milliseconds (5-minute default or 1-hour).
  ttl: number;
  // Store all prefix fingerprints for hierarchical matching
  prefixFingerprints?: string[];
}
const log = logger.child({ module: "cache-tracker" });
// Maps cache fingerprints to the key that has cached that content
const cacheMap = new Map<string, CacheEntry>();
// Default TTLs in milliseconds
const TTL_5_MINUTES = 5 * 60 * 1000;
const TTL_1_HOUR = 60 * 60 * 1000;
/**
* Generates a fingerprint for cacheable content in a request.
* The fingerprint includes content up to and including the LAST cache_control
* breakpoint, following Anthropic's hierarchy: tools → system → messages.
*
* Key insight for handling ever-growing prompts:
* By fingerprinting only up to the last cache_control marker, we ensure that
* new content added AFTER that marker (e.g., new user messages in a conversation)
* won't change the fingerprint. This enables cache hits as conversations grow.
*
* How Anthropic's cache matching works:
* - You place ONE cache_control at the end of your static/stable content
* - The API automatically checks ~20 blocks BEFORE that breakpoint for cache hits
* - It uses the longest matching prefix automatically
* - You don't need multiple breakpoints - one at the end is sufficient
*
* Example conversation flow:
* Request 1: [system prompt + cache_control] → Creates cache, fingerprint = "abc123"
* Request 2: [system prompt + cache_control] + [user msg] + [assistant msg] + [user msg]
* → Same fingerprint "abc123" → Cache HIT (new messages after breakpoint ignored)
* Request 3: [MODIFIED system prompt + cache_control] + messages
* → Different fingerprint "def456" → Cache MISS (prefix changed)
*/
export function generateCacheFingerprint(body: any): string | null {
  if (!body) {
    return null;
  }
  // Ordered list of normalized "parts" (tools, system blocks, message blocks)
  // that contribute to the fingerprint.
  const parts: any[] = [];
  let hasCacheControl = false;
  const cacheBreakpoints: number[] = []; // Indices of cache_control markers
  const ttls: number[] = []; // TTL for each breakpoint
  // 1. Process tools if present (tools come first in hierarchy)
  if (body.tools && Array.isArray(body.tools)) {
    for (let i = 0; i < body.tools.length; i++) {
      const tool = body.tools[i];
      // Only include the stable parts of the tool definition in the fingerprint
      // Exclude cache_control as it's metadata, not part of the tool definition
      const { cache_control, ...toolWithoutCache } = tool;
      parts.push({ type: "tool", tool: toolWithoutCache });
      if (cache_control) {
        hasCacheControl = true;
        cacheBreakpoints.push(parts.length - 1);
        ttls.push(parseTTL(cache_control.ttl));
      }
    }
  }
  // 2. Process system prompt if present
  if (body.system) {
    if (typeof body.system === "string") {
      // Normalize string system to same structure as array blocks for consistent fingerprinting
      parts.push({ type: "system_block", block_type: "text", text: body.system });
    } else if (Array.isArray(body.system)) {
      for (const block of body.system) {
        // System blocks can be:
        // - {type: "text", text: "...", cache_control?: ...}
        // - {type: "image", source: {...}, cache_control?: ...}
        const { cache_control, type, ...rest } = block;
        // Create a consistent fingerprint structure
        const contentPart: any = { type: "system_block", block_type: type };
        switch (type) {
          case "text":
            contentPart.text = (rest as any).text;
            break;
          case "image":
            // Hash image data for fingerprint
            if ((rest as any).source?.data) {
              contentPart.image_hash = crypto
                .createHash("sha256")
                .update((rest as any).source.data)
                .digest("hex")
                .slice(0, 16);
            }
            contentPart.media_type = (rest as any).source?.media_type;
            break;
          default:
            // For unknown block types, include all remaining fields
            Object.assign(contentPart, rest);
        }
        parts.push(contentPart);
        if (cache_control) {
          hasCacheControl = true;
          cacheBreakpoints.push(parts.length - 1);
          ttls.push(parseTTL(cache_control.ttl));
        }
      }
    }
  }
  // 3. Process messages
  if (body.messages && Array.isArray(body.messages)) {
    for (const message of body.messages) {
      if (typeof message.content === "string") {
        parts.push({
          type: "message",
          role: message.role,
          content: message.content,
        });
      } else if (Array.isArray(message.content)) {
        // For multimodal content, include each block
        for (const block of message.content) {
          // CRITICAL: Exclude cache_control metadata from fingerprinting
          // Only the actual content should affect the fingerprint
          const { cache_control, ...blockWithoutCache } = block;
          const contentPart: any = {
            type: "message_block",
            role: message.role,
            block_type: blockWithoutCache.type,
          };
          // Include essential identifying info without full data
          // (full data would make fingerprints too large)
          switch (blockWithoutCache.type) {
            case "text":
              contentPart.text = blockWithoutCache.text;
              break;
            case "image":
              // Hash image data for fingerprint
              if (blockWithoutCache.source?.data) {
                contentPart.image_hash = crypto
                  .createHash("sha256")
                  .update(blockWithoutCache.source.data)
                  .digest("hex")
                  .slice(0, 16);
              }
              contentPart.media_type = blockWithoutCache.source?.media_type;
              break;
            case "tool_use":
              // Don't include tool_id as it's often randomly generated
              // Only include tool name and input which are stable
              contentPart.tool_name = blockWithoutCache.name;
              contentPart.tool_input = blockWithoutCache.input;
              break;
            case "tool_result":
              // Don't include tool_use_id as it's randomly generated
              // Only include the actual result content and error status
              contentPart.is_error = blockWithoutCache.is_error;
              if (typeof blockWithoutCache.content === "string") {
                contentPart.content = blockWithoutCache.content;
              } else if (Array.isArray(blockWithoutCache.content)) {
                // tool_result content can be an array of content blocks
                contentPart.content = blockWithoutCache.content;
              }
              break;
          }
          parts.push(contentPart);
          if (cache_control) {
            hasCacheControl = true;
            cacheBreakpoints.push(parts.length - 1);
            ttls.push(parseTTL(cache_control.ttl));
          }
        }
      }
    }
  }
  // No caching if no cache_control directives present
  if (!hasCacheControl || cacheBreakpoints.length === 0) {
    return null;
  }
  // NOTE(review): a single TTL (the longest requested across breakpoints) is
  // applied to every prefix entry rather than the per-breakpoint TTL — confirm
  // this is intentional.
  const maxTTL = Math.max(...ttls);
  // Generate individual hashes for each part (tool, system block, message)
  // This allows prefix matching even when conversations grow
  const partHashes: string[] = [];
  for (const part of parts) {
    const partHash = crypto
      .createHash("sha256")
      .update(deterministicStringify(part))
      .digest("hex")
      .slice(0, 8); // Shorter hash per part
    partHashes.push(partHash);
  }
  // Generate fingerprints for each cache_control position
  // Fingerprint is the concatenation of individual part hashes up to the breakpoint
  const prefixFingerprints: string[] = [];
  for (const breakpoint of cacheBreakpoints) {
    // Concatenate part hashes up to and including the breakpoint
    const fingerprint = partHashes.slice(0, breakpoint + 1).join("");
    log.trace(
      {
        fingerprint: fingerprint.substring(0, 32) + "...",
        breakpoint,
        prefixPartsCount: breakpoint + 1,
      },
      "Generated cache fingerprint"
    );
    prefixFingerprints.push(fingerprint);
  }
  // Return the deepest (last) fingerprint as the primary one
  const primaryFingerprint = prefixFingerprints[prefixFingerprints.length - 1];
  // Store ALL prefix fingerprints in the cache map, not just the primary one
  // This allows future requests to match against any prefix level
  for (let i = 0; i < prefixFingerprints.length; i++) {
    const fp = prefixFingerprints[i];
    const existing = cacheMap.get(fp);
    if (!existing) {
      // Placeholder entry (keyHash "") — recordCacheUsage fills in the real
      // key hash once a key has been selected for this fingerprint.
      cacheMap.set(fp, {
        keyHash: "",
        lastUsed: 0,
        hitCount: 0,
        ttl: maxTTL,
        prefixFingerprints: prefixFingerprints.slice(0, i + 1),
      });
    }
  }
  return primaryFingerprint;
}
/** Maps a cache_control ttl string to milliseconds; anything except "1h" gets the 5-minute default. */
function parseTTL(ttl?: string): number {
  return ttl === "1h" ? TTL_1_HOUR : TTL_5_MINUTES;
}
/**
 * Records that a key was used for a request with the given cache fingerprint.
 */
export function recordCacheUsage(fingerprint: string, keyHash: string): void {
  if (!fingerprint) return;

  const shortFp = fingerprint.substring(0, 32) + "...";
  const existing = cacheMap.get(fingerprint);

  if (!existing) {
    // Shouldn't happen, but handle gracefully
    cacheMap.set(fingerprint, {
      keyHash,
      lastUsed: Date.now(),
      hitCount: 1,
      ttl: TTL_5_MINUTES,
    });
    log.trace({ fingerprint: shortFp, keyHash }, "New cache entry recorded");
    return;
  }

  const now = Date.now();

  if (existing.keyHash === keyHash) {
    // Same key - likely cache hit
    existing.lastUsed = now;
    existing.hitCount += 1;
    log.trace(
      { fingerprint: shortFp, keyHash, hitCount: existing.hitCount },
      "Cache usage recorded (likely cache hit)"
    );
    return;
  }

  if (existing.keyHash === "") {
    // First use of this fingerprint
    log.debug({ fingerprint: shortFp, keyHash }, "First cache usage for fingerprint");
  } else {
    // Different key - cache miss, reset tracking
    log.debug(
      { fingerprint: shortFp, oldKey: existing.keyHash, newKey: keyHash },
      "Cache key changed (will cause cache miss)"
    );
  }
  existing.keyHash = keyHash;
  existing.lastUsed = now;
  existing.hitCount = 1;
}
/**
 * Gets the key hash that has cached the longest matching prefix for the given
 * fingerprint or any of its sub-prefixes.
 *
 * This is crucial for handling moving cache breakpoints:
 * - If the current request has fingerprints ["fp1", "fp2", "fp3"]
 * - We search backwards from "fp3" → "fp2" → "fp1"
 * - Return the key that cached the longest available prefix
 *
 * Returns null if no cached key exists or all caches have expired.
 */
export function getCachedKeyHash(fingerprint: string): { keyHash: string; matchedFingerprint: string } | null {
  if (!fingerprint) return null;
  const now = Date.now();
  // First, try exact match on the primary (deepest) fingerprint.
  // keyHash may be "" for placeholder entries; those are not usable matches.
  const primaryEntry = cacheMap.get(fingerprint);
  if (primaryEntry && primaryEntry.keyHash) {
    const age = now - primaryEntry.lastUsed;
    if (age <= primaryEntry.ttl) {
      log.trace(
        {
          fingerprint: fingerprint.substring(0, 32) + "...",
          keyHash: primaryEntry.keyHash,
          age,
          hitCount: primaryEntry.hitCount,
          matchType: "exact",
        },
        "Cache entry found (exact match)"
      );
      return { keyHash: primaryEntry.keyHash, matchedFingerprint: fingerprint };
    } else {
      // Expired exact match is pruned immediately.
      log.trace({ fingerprint: fingerprint.substring(0, 32) + "...", age, ttl: primaryEntry.ttl }, "Cache entry expired");
      cacheMap.delete(fingerprint);
    }
  }
  // If no exact match, search all cache entries for prefix matches
  // Since fingerprints are now concatenated part hashes, we can check if one is a prefix of another
  let bestMatch: { keyHash: string; matchedFingerprint: string } | null = null;
  let longestMatchLength = 0;
  // Deleting entries while iterating a Map is safe in JS; expired entries are
  // pruned opportunistically during the scan.
  for (const [cachedFp, entry] of cacheMap.entries()) {
    if (!entry.keyHash) continue;
    const age = now - entry.lastUsed;
    if (age > entry.ttl) {
      cacheMap.delete(cachedFp);
      continue;
    }
    // Check if the cached fingerprint is a prefix of the current fingerprint
    // (cached request had fewer parts, current request has grown)
    if (fingerprint.startsWith(cachedFp) && cachedFp.length > longestMatchLength) {
      bestMatch = { keyHash: entry.keyHash, matchedFingerprint: cachedFp };
      longestMatchLength = cachedFp.length;
      log.trace(
        {
          requestFingerprint: fingerprint.substring(0, 32) + "...",
          matchedFingerprint: cachedFp.substring(0, 32) + "...",
          keyHash: entry.keyHash,
          matchType: "prefix",
        },
        "Cache entry found (prefix match)"
      );
    }
    // Also check if the current fingerprint is a prefix of the cached one
    // (current request has fewer parts, cached had more)
    if (cachedFp.startsWith(fingerprint) && fingerprint.length > longestMatchLength) {
      bestMatch = { keyHash: entry.keyHash, matchedFingerprint: cachedFp };
      longestMatchLength = fingerprint.length;
      log.trace(
        {
          requestFingerprint: fingerprint.substring(0, 32) + "...",
          matchedFingerprint: cachedFp.substring(0, 32) + "...",
          keyHash: entry.keyHash,
          matchType: "prefix_reverse",
        },
        "Cache entry found (reverse prefix match)"
      );
    }
  }
  return bestMatch || null;
}
/**
 * Clears expired cache entries periodically to prevent memory leaks.
 */
export function cleanupExpiredCaches(): void {
  const now = Date.now();
  // Collect expired fingerprints first, then delete them.
  const expired: string[] = [];
  for (const [fingerprint, entry] of cacheMap.entries()) {
    if (now - entry.lastUsed > entry.ttl) {
      expired.push(fingerprint);
    }
  }
  for (const fingerprint of expired) {
    cacheMap.delete(fingerprint);
  }
  if (expired.length > 0) {
    log.debug(
      { cleaned: expired.length, remaining: cacheMap.size },
      "Cleaned up expired cache entries"
    );
  }
}
/**
 * Returns cache statistics for monitoring.
 *
 * Ages are computed from a single timestamp so the snapshot is internally
 * consistent (the original called Date.now() once per entry, skewing ages
 * across a large map).
 */
export function getCacheStats() {
  const now = Date.now();
  return {
    totalEntries: cacheMap.size,
    entries: Array.from(cacheMap.entries()).map(([fp, entry]) => ({
      fingerprint: fp,
      keyHash: entry.keyHash,
      age: now - entry.lastUsed,
      hitCount: entry.hitCount,
      ttl: entry.ttl,
    })),
  };
}
// Run cleanup every minute. unref() the timer so this housekeeping interval
// does not keep the Node.js process alive on shutdown.
const cleanupTimer = setInterval(cleanupExpiredCaches, 60 * 1000);
cleanupTimer.unref?.();
+64 -2
View File
@@ -6,6 +6,11 @@ import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker";
import {
generateCacheFingerprint,
recordCacheUsage,
getCachedKeyHash,
} from "../cache-tracker";
// GcpKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface GcpKey extends Key {
@@ -90,7 +95,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: string) {
public get(model: string, _streaming?: boolean, requestBody?: any) {
const neededFamily = getGcpModelFamily(model);
// this is a horrible mess
@@ -115,6 +120,42 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
);
});
// Generate cache fingerprint if request body contains cache_control
const cacheFingerprint = requestBody
? generateCacheFingerprint(requestBody)
: null;
// Try to get cached key if we have a fingerprint
let preferredKeyHash: string | null = null;
let matchedFingerprint: string | null = null;
if (cacheFingerprint) {
const cacheResult = getCachedKeyHash(cacheFingerprint);
if (cacheResult) {
preferredKeyHash = cacheResult.keyHash;
matchedFingerprint = cacheResult.matchedFingerprint;
// Check if the cached key is still available
const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
if (cachedKey) {
this.log.debug(
{
requestedModel: model,
cacheFingerprint,
keyHash: preferredKeyHash,
},
"Using cached key for prompt caching optimization"
);
} else {
// Cached key no longer available
preferredKeyHash = null;
matchedFingerprint = null;
this.log.debug(
{ cacheFingerprint, keyHash: preferredKeyHash },
"Cached key not available, selecting new key"
);
}
}
}
this.log.debug(
{
model,
@@ -124,6 +165,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
needsSonnet35,
availableKeys: availableKeys.length,
totalKeys: this.keys.length,
cacheFingerprint,
hasCachedKey: !!preferredKeyHash,
},
"Selecting GCP key"
);
@@ -134,9 +177,28 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
);
}
const selectedKey = prioritizeKeys(availableKeys)[0];
/**
* Comparator for prioritizing keys based on cache affinity.
*/
const keyComparator = (a: GcpKey, b: GcpKey) => {
// Highest priority: cache affinity
if (preferredKeyHash) {
if (a.hash === preferredKeyHash) return -1;
if (b.hash === preferredKeyHash) return 1;
}
return 0;
};
const selectedKey = prioritizeKeys(availableKeys, keyComparator)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
// Record cache usage for future requests
// Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint
if (cacheFingerprint) {
recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
}
return { ...selectedKey };
}
+1 -1
View File
@@ -61,7 +61,7 @@ for service-agnostic functionality.
export interface KeyProvider<T extends Key = Key> {
readonly service: LLMService;
init(): void;
get(model: string, streaming?: boolean): T;
get(model: string, streaming?: boolean, requestBody?: any): T;
list(): Omit<T, "key">[];
disable(key: T): void;
update(hash: string, update: Partial<T>): void;
+3 -3
View File
@@ -55,15 +55,15 @@ export class KeyPool {
this.scheduleRecheck();
}
public get(model: string, service?: LLMService, multimodal?: boolean, streaming?: boolean): Key {
public get(model: string, service?: LLMService, multimodal?: boolean, streaming?: boolean, requestBody?: any): Key {
// hack for some claude requests needing keys with particular permissions
// even though they use the same models as the non-multimodal requests
if (multimodal) {
model += "-multimodal";
}
const queryService = service || this.getServiceForModel(model);
return this.getKeyProvider(queryService).get(model, streaming);
return this.getKeyProvider(queryService).get(model, streaming, requestBody);
}
public list(): Omit<Key, "key">[] {
@@ -0,0 +1,66 @@
import { AnthropicChatMessage } from "../api-schemas";
import { getAxiosInstance } from "../network";
import { logger } from "../../logger";
import { AnthropicKey } from "../key-management/anthropic/provider";
const log = logger.child({ module: "tokenizer", service: "anthropic-remote" });
/** Request payload for Anthropic's /v1/messages/count_tokens endpoint. */
export interface AnthropicTokenCountRequest {
  model: string;
  messages: AnthropicChatMessage[];
  // System prompt may be a plain string or an array of text blocks.
  system?: string | Array<{ type: string; text: string }>;
  // Tool definitions are forwarded because they contribute to the prompt's token count.
  tools?: unknown[];
}
/** Response shape: the API returns only the prompt's input token count. */
export interface AnthropicTokenCountResponse {
  input_tokens: number;
}
/**
 * Counts tokens using Anthropic's remote token counting API endpoint.
 * https://docs.claude.com/en/docs/build-with-claude/token-counting
 */
export async function countTokensRemote(
  request: AnthropicTokenCountRequest,
  key: AnthropicKey
): Promise<{ token_count: number; tokenizer: string }> {
  const axios = getAxiosInstance();
  const url = "https://api.anthropic.com/v1/messages/count_tokens";
  const requestConfig = {
    headers: {
      "x-api-key": key.key,
      "anthropic-version": "2023-06-01",
      "content-type": "application/json",
    },
    timeout: 5000, // 5 second timeout
  };

  try {
    const { data } = await axios.post<AnthropicTokenCountResponse>(
      url,
      request,
      requestConfig
    );
    log.debug(
      { model: request.model, input_tokens: data.input_tokens },
      "Counted tokens via Anthropic API"
    );
    return {
      token_count: data.input_tokens,
      tokenizer: "anthropic-remote-api",
    };
  } catch (error: any) {
    // Log and rethrow; the caller falls back to local tokenization.
    log.warn(
      {
        error: error.message,
        status: error.response?.status,
        data: error.response?.data,
      },
      "Failed to count tokens via Anthropic API, will fall back to local tokenizer"
    );
    throw error;
  }
}
+117
View File
@@ -0,0 +1,117 @@
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import { getAxiosInstance } from "../network";
import { logger } from "../../logger";
import { AwsBedrockKey } from "../key-management/aws/provider";
const log = logger.child({ module: "tokenizer", service: "aws-remote" });
// Bedrock runtime endpoint template; %REGION% is substituted per key.
// Overridable via the AMZ_HOST environment variable (e.g. for proxies).
const AMZ_HOST =
  process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
/** Request payload for Bedrock's CountTokens API (InvokeModel input variant). */
export interface AwsTokenCountRequest {
  input: {
    invokeModel: {
      body: string; // base64-encoded JSON string
    };
  };
}
/** Response shape: Bedrock returns only the prompt's input token count. */
export interface AwsTokenCountResponse {
  inputTokens: number;
}
// Parsed form of the "accessKeyId:secretAccessKey:region" key string.
type Credential = {
  accessKeyId: string;
  secretAccessKey: string;
  region: string;
};
/** Splits the stored "accessKeyId:secretAccessKey:region" triple into its parts. */
function getCredentialParts(key: AwsBedrockKey): Credential {
  const segments = key.key.split(":");
  const [accessKeyId, secretAccessKey, region] = segments;
  const isComplete = Boolean(accessKeyId && secretAccessKey && region);
  if (!isComplete) {
    throw new Error("AWS_CREDENTIALS isn't correctly formatted");
  }
  return { accessKeyId, secretAccessKey, region };
}
/** Signs an HTTP request with AWS Signature V4 for the Bedrock service. */
async function sign(request: HttpRequest, credential: Credential) {
  const signer = new SignatureV4({
    sha256: Sha256,
    credentials: {
      accessKeyId: credential.accessKeyId,
      secretAccessKey: credential.secretAccessKey,
    },
    region: credential.region,
    service: "bedrock",
  });
  return signer.sign(request);
}
/**
 * Counts tokens using AWS Bedrock's remote token counting API endpoint.
 * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html
 *
 * @param modelId - Bedrock model ID; may contain ":" or "/" when it is an ARN
 *   or inference profile ID, so it is percent-encoded into the request path.
 * @param request - CountTokens payload wrapping a base64-encoded InvokeModel body.
 * @param key - Proxy key holding "accessKeyId:secretAccessKey:region".
 * @throws Rethrows any request error after logging; callers fall back to local tokenization.
 */
export async function countTokensRemote(
  modelId: string,
  request: AwsTokenCountRequest,
  key: AwsBedrockKey
): Promise<{ token_count: number; tokenizer: string }> {
  const axios = getAxiosInstance();
  const credential = getCredentialParts(key);
  const host = AMZ_HOST.replace("%REGION%", credential.region);
  // Create the HTTP request to sign. The model ID must be percent-encoded:
  // ARNs and inference profile IDs contain ":" and "/", and an unencoded path
  // would be malformed and break the SigV4 signature.
  const httpRequest = new HttpRequest({
    method: "POST",
    protocol: "https:",
    hostname: host,
    path: `/model/${encodeURIComponent(modelId)}/count-tokens`,
    headers: {
      ["Host"]: host,
      ["content-type"]: "application/json",
    },
    body: JSON.stringify(request),
  });
  try {
    // Sign the request using AWS Signature V4
    const signedRequest = await sign(httpRequest, credential);
    // Make the request using the signed path so signature and URL agree
    const response = await axios.post<AwsTokenCountResponse>(
      `https://${host}${signedRequest.path}`,
      signedRequest.body,
      {
        headers: signedRequest.headers as Record<string, string>,
        timeout: 5000, // 5 second timeout
      }
    );
    log.debug(
      {
        modelId,
        input_tokens: response.data.inputTokens,
      },
      "Counted tokens via AWS Bedrock API"
    );
    return {
      token_count: response.data.inputTokens,
      tokenizer: "aws-bedrock-remote-api",
    };
  } catch (error: any) {
    log.warn(
      {
        error: error.message,
        status: error.response?.status,
        data: error.response?.data,
      },
      "Failed to count tokens via AWS Bedrock API, will fall back to local tokenizer"
    );
    throw error;
  }
}
+2 -2
View File
@@ -26,7 +26,7 @@ export async function getTokenCount(
return getTokenCountForMessages(prompt);
}
if (prompt.length > 800000) {
if (prompt.length > 5000000) {
throw new Error("Content is too large to tokenize.");
}
@@ -59,7 +59,7 @@ async function getTokenCountForMessages({
switch (part.type) {
case "text":
const { text } = part;
if (text.length > 800000 || numTokens > 200000) {
if (text.length > 5000000 || numTokens > 1200000) {
throw new Error("Text content is too large to tokenize.");
}
numTokens += encoder.encode(text.normalize("NFKC"), "all").length;
+104
View File
@@ -0,0 +1,104 @@
import { getAxiosInstance } from "../network";
import { logger } from "../../logger";
import { GcpKey } from "../key-management/gcp/provider";
import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "../key-management/gcp/oauth";
const log = logger.child({ module: "tokenizer", service: "gcp-remote" });
/** Request payload for Vertex AI's countTokens endpoint (Gemini content format). */
export interface GcpTokenCountRequest {
  // Conversation turns; each part is either text or base64 inline data.
  contents: Array<{
    role: string;
    parts: Array<{
      text?: string;
      inline_data?: {
        mime_type: string;
        data: string;
      };
    }>;
  }>;
  // Optional system prompt, expressed as text parts.
  systemInstruction?: {
    parts: Array<{
      text: string;
    }>;
  };
  // Tool declarations; included because they contribute to the token count.
  tools?: unknown[];
}
/** Response from countTokens; per-modality detail fields are optional. */
export interface GcpTokenCountResponse {
  totalTokens: number;
  totalBillableCharacters?: number;
  promptTokensDetails?: Array<{
    modality: string;
    tokenCount: number;
  }>;
}
/**
 * Counts tokens using GCP Vertex AI's remote token counting API endpoint.
 * https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/count-tokens
 */
export async function countTokensRemote(
  model: string,
  request: GcpTokenCountRequest,
  key: GcpKey
): Promise<{ token_count: number; tokenizer: string }> {
  const axios = getAxiosInstance();

  // Refresh the access token when it is missing or past its expiry; the
  // refreshed token is cached back onto the key object.
  const now = Date.now();
  const tokenIsStale = !key.accessToken || now >= key.accessTokenExpiresAt;
  if (tokenIsStale) {
    const [token, expiresIn] = await refreshGcpAccessToken(key);
    key.accessToken = token;
    key.accessTokenExpiresAt = now + expiresIn * 1000;
  }

  const { projectId, region } = await getCredentialsFromGcpKey(key);

  // Reduce full resource names like "projects/.../models/gemini-1.5-pro" to
  // the bare model name expected by the publisher endpoint.
  const modelName = model.includes("/") ? model.split("/").pop()! : model;

  const url =
    `https://${region}-aiplatform.googleapis.com/v1` +
    `/projects/${projectId}/locations/${region}` +
    `/publishers/google/models/${modelName}:countTokens`;

  try {
    const { data } = await axios.post<GcpTokenCountResponse>(url, request, {
      headers: {
        "Authorization": `Bearer ${key.accessToken}`,
        "content-type": "application/json",
      },
      timeout: 5000, // 5 second timeout
    });
    log.debug(
      {
        model: modelName,
        total_tokens: data.totalTokens,
        details: data.promptTokensDetails,
      },
      "Counted tokens via GCP Vertex AI API"
    );
    return {
      token_count: data.totalTokens,
      tokenizer: "gcp-vertex-ai-remote-api",
    };
  } catch (error: any) {
    // Log and rethrow; the caller falls back to local tokenization.
    log.warn(
      {
        error: error.message,
        status: error.response?.status,
        data: error.response?.data,
      },
      "Failed to count tokens via GCP Vertex AI API, will fall back to local tokenizer"
    );
    throw error;
  }
}
+92
View File
@@ -21,6 +21,17 @@ import {
MistralAIChatMessage,
OpenAIChatMessage,
} from "../api-schemas";
import { countTokensRemote as countAnthropicTokensRemote } from "./anthropic-remote";
import { countTokensRemote as countAwsTokensRemote } from "./aws-remote";
import { countTokensRemote as countGcpTokensRemote } from "./gcp-remote";
import { AnthropicKey } from "../key-management/anthropic/provider";
import { AwsBedrockKey } from "../key-management/aws/provider";
import { GcpKey } from "../key-management/gcp/provider";
import { logger } from "../../logger";
import { config } from "../../config";
import { findByAnthropicId } from "../claude-models";
const log = logger.child({ module: "tokenizer" });
export async function init() {
initClaude();
@@ -99,6 +110,87 @@ export async function countTokens({
completion,
}: TokenCountRequest): Promise<TokenCountResult> {
const time = process.hrtime();
// For prompt counting, try remote APIs first (only when enabled, have a key, and counting prompts)
if (config.useRemoteTokenCounting && prompt && req.key) {
try {
switch (service) {
case "anthropic-chat": {
if (req.service === "anthropic" && req.key.service === "anthropic") {
const result = await countAnthropicTokensRemote(
{
model: req.body.model,
messages: prompt.messages,
system: prompt.system,
tools: req.body.tools,
},
req.key as AnthropicKey
);
return { ...result, tokenization_duration_ms: getElapsedMs(time) };
}
break;
}
case "anthropic-text": {
if (req.service === "anthropic" && req.key.service === "anthropic") {
// Anthropic's API doesn't support text completion counting, fall through to local
break;
}
break;
}
}
// For AWS Bedrock services, use AWS token counting
if (req.service === "aws" && req.key.service === "aws") {
if (service === "anthropic-chat") {
// Convert Anthropic model ID to AWS Bedrock model ID
const anthropicModelId = req.body.model;
const claudeMapping = findByAnthropicId(anthropicModelId);
const awsModelId = claudeMapping?.awsId || anthropicModelId;
// Build the request body in Anthropic format - must include all required fields
const bodyObj: any = {
messages: req.body.messages,
max_tokens: req.body.max_tokens,
anthropic_version: req.body.anthropic_version || "bedrock-2023-05-31",
};
if (req.body.system) bodyObj.system = req.body.system;
if (req.body.tools) bodyObj.tools = req.body.tools;
if (req.body.tool_choice) bodyObj.tool_choice = req.body.tool_choice;
// AWS expects the body as a base64-encoded string
const bodyJson = JSON.stringify(bodyObj);
const bodyBase64 = Buffer.from(bodyJson, "utf-8").toString("base64");
const result = await countAwsTokensRemote(
awsModelId,
{
input: {
invokeModel: {
body: bodyBase64,
},
},
},
req.key as AwsBedrockKey
);
return { ...result, tokenization_duration_ms: getElapsedMs(time) };
}
}
// For GCP Vertex AI services, use GCP token counting
if (req.service === "gcp" && req.key.service === "gcp") {
// GCP uses a different format, would need transformation
// For now, fall through to local counting
}
} catch (error) {
// Fall through to local tokenization
log.debug(
{ error: (error as Error).message, service },
"Remote token counting failed, using local tokenizer"
);
}
}
// Fall back to local tokenization
switch (service) {
case "anthropic-chat":
case "anthropic-text":