Mirror of https://gitgud.io/reanon/nonono.git (synced 2026-05-11 11:40:12 -07:00)

Commit: claude tokenizer+cache working
@@ -286,6 +286,14 @@ type Config = {
googleSheetsSpreadsheetId?: string;
/** Whether to periodically check keys for usage and validity. */
checkKeys: boolean;
/**
* Whether to use remote API token counting endpoints for Anthropic, AWS
* Bedrock, and GCP Vertex AI. When enabled, the proxy will use the provider's
* token counting API to get accurate token counts for prompts (including
* images). Output tokens are always counted from API responses when available.
* Falls back to local tokenization if remote counting fails.
*/
useRemoteTokenCounting: boolean;
/** Whether to publicly show total token costs on the info page. */
showTokenCosts: boolean;
/**

@@ -567,6 +575,7 @@ export const config: Config = {
),
logLevel: getEnvWithDefault("LOG_LEVEL", "info"),
checkKeys: getEnvWithDefault("CHECK_KEYS", !isDev),
useRemoteTokenCounting: getEnvWithDefault("USE_REMOTE_TOKEN_COUNTING", true),
showTokenCosts: getEnvWithDefault("SHOW_TOKEN_COSTS", false),
allowAwsLogging: getEnvWithDefault("ALLOW_AWS_LOGGING", false),
promptLogging: getEnvWithDefault("PROMPT_LOGGING", false),
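For illustration, a hypothetical operator override of the two new flags (variable names match the getEnvWithDefault calls above; defaults are USE_REMOTE_TOKEN_COUNTING=true and SHOW_TOKEN_COSTS=false; exact boolean parsing of the env strings depends on getEnvWithDefault):

// Hypothetical overrides set before the config module is loaded
process.env.USE_REMOTE_TOKEN_COUNTING = "false"; // intended to skip provider token-counting APIs and use local tokenizers only
process.env.SHOW_TOKEN_COSTS = "true";           // intended to publish total token costs on the public info page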
@@ -196,8 +196,11 @@ function setAnthropicBetaHeader(req: Request) {
betaHeaders.push("extended-cache-ttl-2025-04-11");
}
// Add 1M context beta header for Claude Sonnet 4 if context > 200k tokens
if (model?.includes("claude-sonnet-4") && req.promptTokens && req.outputTokens) {
// Add 1M context beta header for Claude Sonnet 4/Opus 4 if context > 200k tokens
const supportsBigContext =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (supportsBigContext && req.promptTokens && req.outputTokens) {
const contextTokens = req.promptTokens + req.outputTokens;
if (contextTokens > 200000) {
betaHeaders.push("context-1m-2025-08-07");

@@ -211,8 +214,8 @@ function setAnthropicBetaHeader(req: Request) {
}
/**
* Adds web search tool for Claude-3.5 and Claude-3.7 models when enable_web_search is true
*
* Adds web search tool for Claude-3.5, Claude-3.7, Claude-4, and Claude-4.5 models when enable_web_search is true
*
* Supports all optional parameters documented in the Claude API:
* - max_uses: Limit the number of searches per request
* - allowed_domains: Only include results from these domains

@@ -323,7 +326,9 @@ const textToChatPreprocessor = createPreprocessorMiddleware(
*/
const preprocessAnthropicTextRequest: RequestHandler = (req, res, next) => {
const model = req.body.model;
const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4Model =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.startsWith("claude-3") || isClaude4Model) {
textToChatPreprocessor(req, res, next);
} else {

@@ -356,7 +361,9 @@ const oaiToChatPreprocessor = createPreprocessorMiddleware(
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
maybeReassignModel(req);
const model = req.body.model;
const isClaude4 = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4 =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.includes("claude-3") || isClaude4) {
oaiToChatPreprocessor(req, res, next);
} else {
@@ -102,7 +102,9 @@ const textToChatPreprocessor = createPreprocessorMiddleware(
*/
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
const model = req.body.model;
const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4Model =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.includes("claude-3") || isClaude4Model) {
textToChatPreprocessor(req, res, next);
} else {

@@ -126,7 +128,9 @@ const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
const model = req.body.model;
const isClaude4Model = model?.includes("claude-sonnet-4") || model?.includes("claude-opus-4");
const isClaude4Model =
model?.includes("claude-sonnet-4") ||
model?.includes("claude-opus-4");
if (model?.includes("claude-3") || isClaude4Model) {
oaiToAwsChatPreprocessor(req, res, next);
} else {
@@ -32,8 +32,9 @@ export const addKey: ProxyReqMutator = (manager) => {
if (inboundApi === outboundApi) {
// Pass streaming information for GPT-5 models that require verified keys for streaming
// Pass request body for cache-aware key selection (Anthropic, AWS, GCP)
const isStreaming = body.stream === true;
assignedKey = keyPool.get(body.model, service, needsMultimodal, isStreaming);
assignedKey = keyPool.get(body.model, service, needsMultimodal, isStreaming, body);
} else {
switch (outboundApi) {
// If we are translating between API formats we may need to select a model

@@ -45,7 +46,7 @@ export const addKey: ProxyReqMutator = (manager) => {
case "mistral-ai":
case "mistral-text":
case "google-ai":
assignedKey = keyPool.get(body.model, service);
assignedKey = keyPool.get(body.model, service, undefined, undefined, body);
break;
case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
@@ -24,7 +24,8 @@ const AMZ_HOST =
export const signAwsRequest: ProxyReqMutator = async (manager) => {
const req = manager.request;
const { model, stream } = req.body;
const key = keyPool.get(model, "aws") as AwsBedrockKey;
// Pass request body to enable cache-aware key selection
const key = keyPool.get(model, "aws", undefined, stream, req.body) as AwsBedrockKey;
manager.setKey(key);
let system = req.body.system ?? "";

@@ -50,6 +51,8 @@ export const signAwsRequest: ProxyReqMutator = async (manager) => {
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
// with the headers generated by the SDK.
// Note: AWS Bedrock uses body parameters (anthropic_version) instead of headers
// for versioning, so we don't include anthropic-beta or anthropic-version headers.
const newRequest = new HttpRequest({
method: "POST",
protocol: "https:",

@@ -130,6 +133,9 @@ function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
.parse(req.body);
break;
case "anthropic-chat":
// Preserve anthropic_version if user provided it (for beta features)
const userAnthropicVersion = (req.body as any).anthropic_version;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,

@@ -140,11 +146,14 @@ function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
top_p: true,
tools: true,
tool_choice: true,
thinking: true
thinking: true,
cache_control: true
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "bedrock-2023-05-31";
// Use user-provided version or default to bedrock-2023-05-31
strippedParams.anthropic_version = userAnthropicVersion || "bedrock-2023-05-31";
break;
case "mistral-ai":
strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
@@ -20,7 +20,7 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
}
const { model } = req.body;
const key: GcpKey = keyPool.get(model, "gcp") as GcpKey;
const key: GcpKey = keyPool.get(model, "gcp", undefined, undefined, req.body) as GcpKey;
if (!key.accessToken || Date.now() > key.accessTokenExpiresAt) {
const [token, durationSec] = await refreshGcpAccessToken(key);

@@ -38,6 +38,9 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
// TODO: This should happen in transform-outbound-payload.ts
// TODO: Support tools
// Preserve anthropic_version if user provided it (for beta features)
const userAnthropicVersion = (req.body as any).anthropic_version;
let strippedParams: Record<string, unknown>;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,

@@ -50,11 +53,14 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
stream: true,
tools: true,
tool_choice: true,
thinking: true
thinking: true,
cache_control: true
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "vertex-2023-10-16";
// Use user-provided version or default to vertex-2023-10-16
strippedParams.anthropic_version = userAnthropicVersion || "vertex-2023-10-16";
const credential = await getCredentialsFromGcpKey(key);

@@ -63,6 +69,8 @@ export const signGcpRequest: ProxyReqMutator = async (manager) => {
// stream adapter selects the correct transformer.
manager.setHeader("anthropic-version", "2023-06-01");
// GCP Vertex AI uses body parameter anthropic_version (set above) for versioning,
// not anthropic-beta header. Beta features are enabled through body params.
manager.setSignedRequest({
method: "POST",
protocol: "https:",
@@ -7,21 +7,50 @@ import {
AnthropicChatMessage,
flattenAnthropicMessages,
} from "../../../../shared/api-schemas/anthropic";
import {
MistralAIChatMessage,
import {
MistralAIChatMessage,
ContentItem,
isMistralVisionModel
isMistralVisionModel
} from "../../../../shared/api-schemas/mistral-ai";
import { isGrokVisionModel } from "../../../../shared/api-schemas/xai";
import { keyPool } from "../../../../shared/key-management";
import { config } from "../../../../config";
/**
* Given a request with an already-transformed body, counts the number of
* tokens and assigns the count to the request.
*
* If remote token counting is enabled, we temporarily get a key from the pool
* to use the remote API, then clear it. The actual key assignment happens
* later in the mutators after the request is dequeued.
*/
export const countPromptTokens: RequestPreprocessor = async (req) => {
const service = req.outboundApi;
let result;
// For remote token counting, temporarily get a key from the pool
// We don't permanently assign it - that happens in mutators after dequeue
// IMPORTANT: Don't pass req.body here to avoid premature cache fingerprinting
// Cache fingerprinting should only happen during actual key assignment in mutators,
// after all preprocessing (including model name transformations) is complete
let tempKey;
const hadKey = !!req.key;
if (config.useRemoteTokenCounting && !req.key) {
try {
tempKey = keyPool.get(req.body.model, req.service, undefined, undefined);
req.key = tempKey; // Temporarily assign for token counting
req.log.debug(
{ keyHash: tempKey.hash, service: req.service },
"Temporarily assigned key for remote token counting"
);
} catch (error) {
req.log.debug(
{ error: (error as Error).message },
"Could not get key for remote token counting, will use local tokenizer"
);
}
}
switch (service) {
case "openai": {
req.outputTokens = req.body.max_completion_tokens || req.body.max_tokens;

@@ -129,4 +158,14 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
req.log.debug({ result: result }, "Counted prompt tokens.");
req.tokenizerInfo = req.tokenizerInfo ?? {};
req.tokenizerInfo = { ...req.tokenizerInfo, ...result };
// Clear the temporary key if we assigned one for token counting
// The real key will be assigned later in mutators after dequeue
if (tempKey && !hadKey) {
delete req.key;
req.log.debug(
{ keyHash: tempKey.hash },
"Cleared temporary key after token counting"
);
}
};
@@ -121,13 +121,13 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 200000;
} else if (model.match(/^claude-3/)) {
modelMax = 200000;
} else if (model.match(/^claude-(?:sonnet|opus)-4/)) {
} else if (model.match(/^claude-(?:sonnet|opus)-4(?:-5)?/)) {
modelMax = 1000000;
} else if (model.match(/^gemini-/)) {
modelMax = 1024000;
} else if (model.match(/^anthropic\.claude-3/)) {
modelMax = 200000;
} else if (model.match(/^anthropic\.claude-(?:sonnet|opus)-4/)) {
} else if (model.match(/^anthropic\.claude-(?:sonnet|opus)-4(?:-5)?/)) {
modelMax = 1000000;
} else if (model.match(/^anthropic\.claude-v2:\d/)) {
modelMax = 200000;
@@ -1094,15 +1094,130 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
try {
assertJsonResponse(body);
const service = req.outboundApi;
const completion = getCompletionFromBody(req, body);
const tokens = await countTokens({ req, completion, service });
if (req.service === "openai" || req.service === "azure" || req.service === "deepseek" || req.service === "glm" || req.service === "cohere" || req.service === "qwen") {
// O1 consumes (a significant amount of) invisible tokens for the chain-
// of-thought reasoning. We have no way to count these other than to check
// the response body.
tokens.reasoning_tokens =
body.usage?.completion_tokens_details?.reasoning_tokens;
// Try to get token counts from the API response first
let tokens: { token_count: number; tokenizer: string; reasoning_tokens?: number } | null = null;
// Anthropic API returns usage data in the response
if (service === "anthropic-chat" && body.usage) {
tokens = {
token_count: body.usage.output_tokens || 0,
tokenizer: "anthropic-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from Anthropic API response"
);
// Sanity check: if request had cache_control, expect cache metrics
if (req.body.system || req.body.tools || req.body.messages) {
const hasCacheControl = checkForCacheControl(req.body);
if (hasCacheControl) {
const cacheRead = body.usage.cache_read_input_tokens || 0;
const cacheCreation = body.usage.cache_creation_input_tokens || 0;
if (cacheRead === 0 && cacheCreation === 0) {
req.log.error(
{ keyHash: req.key?.hash, usage: body.usage },
"CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from Anthropic API"
);
}
}
}
}
// AWS Bedrock returns usage data in the response (same format as Anthropic)
else if (req.service === "aws" && service === "anthropic-chat" && body.usage) {
tokens = {
token_count: body.usage.output_tokens || 0,
tokenizer: "aws-bedrock-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from AWS Bedrock API response"
);
// Sanity check: if request had cache_control, expect cache metrics
if (req.body.system || req.body.tools || req.body.messages) {
const hasCacheControl = checkForCacheControl(req.body);
if (hasCacheControl) {
const cacheRead = body.usage.cache_read_input_tokens || 0;
const cacheCreation = body.usage.cache_creation_input_tokens || 0;
if (cacheRead === 0 && cacheCreation === 0) {
req.log.error(
{ keyHash: req.key?.hash, usage: body.usage },
"CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from AWS Bedrock API"
);
}
}
}
}
// GCP Vertex AI returns usage data in the response
// For Anthropic models, GCP returns Anthropic format (usage.output_tokens)
// For Gemini models, GCP returns GCP format (usageMetadata.candidatesTokenCount)
else if (req.service === "gcp") {
if (service === "anthropic-chat" && body.usage?.output_tokens) {
tokens = {
token_count: body.usage.output_tokens || 0,
tokenizer: "gcp-anthropic-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from GCP Vertex AI (Anthropic format)"
);
// Sanity check: if request had cache_control, expect cache metrics
if (req.body.system || req.body.tools || req.body.messages) {
const hasCacheControl = checkForCacheControl(req.body);
if (hasCacheControl) {
const cacheRead = body.usage.cache_read_input_tokens || 0;
const cacheCreation = body.usage.cache_creation_input_tokens || 0;
if (cacheRead === 0 && cacheCreation === 0) {
req.log.error(
{ keyHash: req.key?.hash, usage: body.usage },
"CACHE SANITY CHECK FAILED: Request had cache_control but received NO cache metrics from GCP Vertex AI API"
);
}
}
}
} else if (body.usageMetadata) {
tokens = {
token_count: body.usageMetadata.candidatesTokenCount || 0,
tokenizer: "gcp-vertex-api",
};
req.log.debug(
{ service, outputTokens: tokens.token_count, usageMetadata: body.usageMetadata },
"Got output token count from GCP Vertex AI (Gemini format)"
);
}
}
// OpenAI and similar services return usage data
else if (body.usage?.completion_tokens) {
tokens = {
token_count: body.usage.completion_tokens,
tokenizer: "api-usage-data",
};
if (req.service === "openai" || req.service === "azure" || req.service === "deepseek" || req.service === "glm" || req.service === "cohere" || req.service === "qwen") {
// O1 consumes (a significant amount of) invisible tokens for the chain-
// of-thought reasoning. We have no way to count these other than to check
// the response body.
tokens.reasoning_tokens =
body.usage?.completion_tokens_details?.reasoning_tokens;
}
req.log.debug(
{ service, outputTokens: tokens.token_count, usage: body.usage },
"Got output token count from API usage data"
);
}
// Fall back to local tokenization if no usage data is available
if (!tokens) {
const completion = getCompletionFromBody(req, body);
tokens = await countTokens({ req, completion, service });
req.log.debug(
{ service, outputTokens: tokens.token_count },
"Counted output tokens locally (no API usage data)"
);
}
req.log.debug(
@@ -1193,6 +1308,35 @@ function getAwsErrorType(header: string | string[] | undefined) {
return val || String(header);
}
function checkForCacheControl(body: any): boolean {
// Check tools
if (Array.isArray(body.tools)) {
for (const tool of body.tools) {
if (tool.cache_control) return true;
}
}
// Check system blocks
if (Array.isArray(body.system)) {
for (const block of body.system) {
if (block.cache_control) return true;
}
}
// Check message content blocks
if (Array.isArray(body.messages)) {
for (const message of body.messages) {
if (Array.isArray(message.content)) {
for (const block of message.content) {
if (block.cache_control) return true;
}
}
}
}
return false;
}
function assertJsonResponse(body: any): asserts body is Record<string, any> {
if (typeof body !== "object") {
throw new Error(`Expected response to be an object, got ${typeof body}`);
@@ -8,7 +8,12 @@ export type AnthropicChatCompletionResponse = {
model: string;
stop_reason: string | null;
stop_sequence: string | null;
usage: { input_tokens: number; output_tokens: number };
usage: {
input_tokens: number;
output_tokens: number;
cache_creation_input_tokens?: number;
cache_read_input_tokens?: number;
};
};
/**

@@ -43,6 +48,24 @@ export function mergeEventsForAnthropicChat(
acc.content[0].text += event.choices[0].delta.content;
}
// OpenAI events may include usage data (extended by our transformer)
// Usage tokens should be set to the latest value, not accumulated
// (they represent total counts, not deltas)
if ((event as any).usage) {
if ((event as any).usage.input_tokens !== undefined) {
acc.usage.input_tokens = (event as any).usage.input_tokens;
}
if ((event as any).usage.output_tokens !== undefined) {
acc.usage.output_tokens = (event as any).usage.output_tokens;
}
if ((event as any).usage.cache_creation_input_tokens !== undefined) {
acc.usage.cache_creation_input_tokens = (event as any).usage.cache_creation_input_tokens;
}
if ((event as any).usage.cache_read_input_tokens !== undefined) {
acc.usage.cache_read_input_tokens = (event as any).usage.cache_read_input_tokens;
}
}
return acc;
}, merged);
return merged;
@@ -10,6 +10,11 @@ export type OpenAiChatCompletionResponse = {
finish_reason: string | null;
index: number;
}[];
usage?: {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
};
/**

@@ -52,6 +57,17 @@ export function mergeEventsForOpenAIChat(
acc.choices[0].message.content += event.choices[0].delta.content;
}
// Accumulate usage data from events (OpenAI may send this in the final event)
if ((event as any).usage) {
if (!acc.usage) {
acc.usage = {};
}
const usage = (event as any).usage;
if (usage.prompt_tokens) acc.usage.prompt_tokens = usage.prompt_tokens;
if (usage.completion_tokens) acc.usage.completion_tokens = usage.completion_tokens;
if (usage.total_tokens) acc.usage.total_tokens = usage.total_tokens;
}
return acc;
}, merged);
return merged;
@@ -50,12 +50,32 @@ export class SSEStreamAdapter extends Transform {
const event = Buffer.from(bytes, "base64").toString("utf8");
const eventObj = JSON.parse(event);
// AWS Bedrock includes usage metrics in the event stream headers
// Extract and attach them to the event object for downstream processing
const invocationMetrics = headers["amazon-bedrock-invocationMetrics"];
if (invocationMetrics?.value) {
try {
const metricsStr = typeof invocationMetrics.value === 'string'
? invocationMetrics.value
: JSON.stringify(invocationMetrics.value);
const metricsObj = JSON.parse(metricsStr);
eventObj["amazon-bedrock-invocationMetrics"] = metricsObj;
} catch (e) {
this.log.warn(
{ invocationMetrics: invocationMetrics.value },
"Failed to parse AWS invocationMetrics"
);
}
}
const eventWithMetrics = JSON.stringify(eventObj);
if ("completion" in eventObj) {
return ["event: completion", `data: ${event}`].join(`\n`);
return ["event: completion", `data: ${eventWithMetrics}`].join(`\n`);
} else if (eventObj.type) {
return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
return [`event: ${eventObj.type}`, `data: ${eventWithMetrics}`].join(`\n`);
} else {
return `data: ${event}`;
return `data: ${eventWithMetrics}`;
}
}
// noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected
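For reference, an illustrative (not verbatim) shape of a decoded Bedrock stream event after the adapter above has attached the metrics; the metric field names match the camelCase keys read by the transformer later in this commit, while the values here are made up:

const exampleEventObj = {
  type: "message_delta",
  delta: { stop_reason: "end_turn" },
  usage: { output_tokens: 211 },
  // attached from the event stream headers by the adapter above
  "amazon-bedrock-invocationMetrics": {
    inputTokenCount: 1842,
    outputTokenCount: 211,
    cacheReadInputTokenCount: 1536,
    cacheWriteInputTokenCount: 0,
  },
};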
@@ -22,8 +22,70 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
return { position: -1 };
}
// Try to extract usage data from message_start and message_delta events
// Also check for AWS Bedrock invocationMetrics
let usageData: {
input_tokens?: number;
output_tokens?: number;
cache_creation_input_tokens?: number;
cache_read_input_tokens?: number;
} | undefined;
try {
const parsed = JSON.parse(rawEvent.data);
if (parsed.type === "message_start" && parsed.message?.usage) {
usageData = {
input_tokens: parsed.message.usage.input_tokens,
output_tokens: parsed.message.usage.output_tokens,
cache_creation_input_tokens: parsed.message.usage.cache_creation_input_tokens,
cache_read_input_tokens: parsed.message.usage.cache_read_input_tokens,
};
} else if (parsed.type === "message_delta" && parsed.delta?.usage) {
usageData = {
output_tokens: parsed.delta.usage.output_tokens,
cache_creation_input_tokens: parsed.delta.usage.cache_creation_input_tokens,
cache_read_input_tokens: parsed.delta.usage.cache_read_input_tokens,
};
}
// AWS Bedrock includes usage in amazon-bedrock-invocationMetrics
// AWS uses camelCase field names (cacheReadInputTokenCount, cacheWriteInputTokenCount)
else if (parsed["amazon-bedrock-invocationMetrics"]) {
const metrics = parsed["amazon-bedrock-invocationMetrics"];
usageData = {
input_tokens: metrics.inputTokenCount,
output_tokens: metrics.outputTokenCount,
// Map AWS camelCase metric names to Anthropic snake_case
cache_read_input_tokens: metrics.cacheReadInputTokenCount,
cache_creation_input_tokens: metrics.cacheWriteInputTokenCount,
};
log.debug(
{ metrics, usageData },
"Extracted usage from AWS invocationMetrics"
);
}
} catch (e) {
// Ignore parsing errors
}
const deltaEvent = asAnthropicChatDelta(rawEvent);
if (!deltaEvent) {
// If we have usage data but no delta, still emit an event with usage
if (usageData) {
const usageEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: {},
finish_reason: null,
},
],
usage: usageData,
};
return { position: -1, event: usageEvent };
}
return { position: -1 };
}

@@ -39,6 +101,7 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
finish_reason: null,
},
],
...(usageData && { usage: usageData }),
};
return { position: -1, event: newEvent };
@@ -15,6 +15,11 @@ type GoogleAIStreamEvent = {
tokenCount?: number;
safetyRatings: { category: string; probability: string }[];
}[];
usageMetadata?: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
};
/**

@@ -49,7 +54,7 @@ export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
content = content.replace(/^(.*?): /, "").trim();
}
const newEvent = {
const newEvent: any = {
id: "goo-" + params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),

@@ -63,6 +68,19 @@ export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
],
};
// Extract usage metadata from GCP/Vertex AI responses
if (completionEvent.usageMetadata) {
newEvent.usage = {
prompt_tokens: completionEvent.usageMetadata.promptTokenCount,
completion_tokens: completionEvent.usageMetadata.candidatesTokenCount,
total_tokens: completionEvent.usageMetadata.totalTokenCount,
};
log.debug(
{ usageMetadata: completionEvent.usageMetadata },
"Extracted usage from GCP usageMetadata"
);
}
return { position: -1, event: newEvent };
};
@@ -5,6 +5,11 @@ import { logger } from "../../../logger";
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
import { AnthropicKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
import {
generateCacheFingerprint,
recordCacheUsage,
getCachedKeyHash,
} from "../cache-tracker";
export type AnthropicKeyUpdate = Omit<
Partial<AnthropicKey>,

@@ -136,7 +141,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(rawModel: string) {
public get(rawModel: string, _streaming?: boolean, requestBody?: any) {
this.log.debug({ model: rawModel }, "Selecting key");
const needsMultimodal = rawModel.endsWith("-multimodal");

@@ -152,7 +157,44 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
);
}
// Generate cache fingerprint if request body contains cache_control
const cacheFingerprint = requestBody
? generateCacheFingerprint(requestBody)
: null;
// Try to get cached key if we have a fingerprint
let preferredKeyHash: string | null = null;
let matchedFingerprint: string | null = null;
if (cacheFingerprint) {
const cacheResult = getCachedKeyHash(cacheFingerprint);
if (cacheResult) {
preferredKeyHash = cacheResult.keyHash;
matchedFingerprint = cacheResult.matchedFingerprint;
// Check if the cached key is still available
const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
if (cachedKey) {
this.log.debug(
{
requestedModel: rawModel,
cacheFingerprint,
keyHash: preferredKeyHash,
},
"Using cached key for prompt caching optimization"
);
} else {
// Cached key no longer available
preferredKeyHash = null;
matchedFingerprint = null;
this.log.debug(
{ cacheFingerprint, keyHash: preferredKeyHash },
"Cached key not available, selecting new key"
);
}
}
}
// Select a key, from highest priority to lowest priority:
// 0. Cache affinity (if we have a cached key preference)
// 1. Keys which are not rate limit locked
// 2. Keys with the highest tier
// 3. Keys which are not pozzed

@@ -161,6 +203,12 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
// Highest priority: cache affinity
if (preferredKeyHash) {
if (a.hash === preferredKeyHash) return -1;
if (b.hash === preferredKeyHash) return 1;
}
const aLockoutPeriod = getKeyLockout(a);
const bLockoutPeriod = getKeyLockout(b);

@@ -183,6 +231,13 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
// Record cache usage for future requests
// Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint
if (cacheFingerprint) {
recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
}
return { ...selectedKey };
}
@@ -7,6 +7,11 @@ import { findByAnthropicId } from "../../claude-models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker";
import {
generateCacheFingerprint,
recordCacheUsage,
getCachedKeyHash,
} from "../cache-tracker";
// AwsBedrockKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface AwsBedrockKey extends Key {

@@ -90,14 +95,14 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: string) {
public get(model: string, _streaming?: boolean, requestBody?: any) {
let neededVariantId = model;
// This function accepts both Anthropic/Mistral IDs and AWS IDs.
// Generally all AWS model IDs are supersets of the original vendor IDs.
// Claude 2 is the only model that breaks this convention; Anthropic calls
// it claude-2 but AWS calls it claude-v2.
if (model.includes("claude-2")) neededVariantId = "claude-v2";
// For Claude models, try to resolve aliases to AWS model IDs
if (model.includes("claude") && !model.includes("anthropic.")) {
const claudeMapping = findByAnthropicId(model);

@@ -105,7 +110,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
neededVariantId = claudeMapping.awsId;
}
}
const neededFamily = getAwsBedrockModelFamily(model);
const availableKeys = this.keys.filter((k) => {

@@ -122,6 +127,42 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
);
});
// Generate cache fingerprint if request body contains cache_control
const cacheFingerprint = requestBody
? generateCacheFingerprint(requestBody)
: null;
// Try to get cached key if we have a fingerprint
let preferredKeyHash: string | null = null;
let matchedFingerprint: string | null = null;
if (cacheFingerprint) {
const cacheResult = getCachedKeyHash(cacheFingerprint);
if (cacheResult) {
preferredKeyHash = cacheResult.keyHash;
matchedFingerprint = cacheResult.matchedFingerprint;
// Check if the cached key is still available
const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
if (cachedKey) {
this.log.debug(
{
requestedModel: model,
cacheFingerprint,
keyHash: preferredKeyHash,
},
"Using cached key for prompt caching optimization"
);
} else {
// Cached key no longer available
preferredKeyHash = null;
matchedFingerprint = null;
this.log.debug(
{ cacheFingerprint, keyHash: preferredKeyHash },
"Cached key not available, selecting new key"
);
}
}
}
this.log.debug(
{
requestedModel: model,

@@ -129,6 +170,8 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
selectedFamily: neededFamily,
totalKeys: this.keys.length,
availableKeys: availableKeys.length,
cacheFingerprint,
hasCachedKey: !!preferredKeyHash,
},
"Selecting AWS key"
);
@@ -140,22 +183,36 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
}
/**
* Comparator for prioritizing keys on inference profile compatibility.
* Requests made via inference profiles have higher rate limits so we want
* to use keys with compatible inference profiles first.
* Comparator for prioritizing keys based on:
* 1. Cache affinity (if we have a cached key preference)
* 2. Inference profile compatibility
*/
const hasInferenceProfile = (
a: AwsBedrockKey,
b: AwsBedrockKey
) => {
const keyComparator = (a: AwsBedrockKey, b: AwsBedrockKey) => {
// Highest priority: cache affinity
if (preferredKeyHash) {
if (a.hash === preferredKeyHash) return -1;
if (b.hash === preferredKeyHash) return 1;
}
// Second priority: inference profile compatibility
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
return aMatch - bMatch;
const profileDiff = bMatch - aMatch;
if (profileDiff !== 0) return profileDiff;
return 0;
};
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
const selectedKey = prioritizeKeys(availableKeys, keyComparator)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
// Record cache usage for future requests
// Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint
if (cacheFingerprint) {
recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
}
return { ...selectedKey };
}
@@ -0,0 +1,457 @@
import crypto from "crypto";
import { logger } from "../../logger";
/**
* Deterministic JSON stringify that sorts object keys to ensure consistent hashing.
*/
function deterministicStringify(obj: any): string {
if (obj === null || obj === undefined) return String(obj);
if (typeof obj !== "object") return JSON.stringify(obj);
if (Array.isArray(obj)) return `[${obj.map(deterministicStringify).join(",")}]`;
const keys = Object.keys(obj).sort();
const pairs = keys.map(k => `"${k}":${deterministicStringify(obj[k])}`);
return `{${pairs.join(",")}}`;
}
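A small illustration of why the key sort matters (hypothetical values): two objects that differ only in property order stringify identically, so they hash identically.

const a = deterministicStringify({ role: "user", content: "hi" });
const b = deterministicStringify({ content: "hi", role: "user" });
// a === b === '{"content":"hi","role":"user"}'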
/**
* Universal cache tracker for all providers (Anthropic, AWS, GCP).
*
* Tracks which keys have cached which prompt prefixes to optimize prompt caching.
* Each API key has its own cache, so routing requests with identical cacheable
* content to the same key maximizes cache hits.
*
* Cache rules (per Anthropic/Bedrock/Vertex docs):
* - 100% identical prompt segments required for cache hit
* - Default 5-minute TTL, refreshed on each use
* - Optional 1-hour TTL
* - Cache becomes available after first response begins
* - Minimum 1024-2048 tokens (model-dependent)
*/
interface CacheEntry {
keyHash: string;
lastUsed: number;
hitCount: number;
ttl: number;
// Store all prefix fingerprints for hierarchical matching
prefixFingerprints?: string[];
}
const log = logger.child({ module: "cache-tracker" });
// Maps cache fingerprints to the key that has cached that content
const cacheMap = new Map<string, CacheEntry>();
// Default TTLs in milliseconds
const TTL_5_MINUTES = 5 * 60 * 1000;
const TTL_1_HOUR = 60 * 60 * 1000;
/**
* Generates a fingerprint for cacheable content in a request.
* The fingerprint includes content up to and including the LAST cache_control
* breakpoint, following Anthropic's hierarchy: tools → system → messages.
*
* Key insight for handling ever-growing prompts:
* By fingerprinting only up to the last cache_control marker, we ensure that
* new content added AFTER that marker (e.g., new user messages in a conversation)
* won't change the fingerprint. This enables cache hits as conversations grow.
*
* How Anthropic's cache matching works:
* - You place ONE cache_control at the end of your static/stable content
* - The API automatically checks ~20 blocks BEFORE that breakpoint for cache hits
* - It uses the longest matching prefix automatically
* - You don't need multiple breakpoints - one at the end is sufficient
*
* Example conversation flow:
* Request 1: [system prompt + cache_control] → Creates cache, fingerprint = "abc123"
* Request 2: [system prompt + cache_control] + [user msg] + [assistant msg] + [user msg]
* → Same fingerprint "abc123" → Cache HIT (new messages after breakpoint ignored)
* Request 3: [MODIFIED system prompt + cache_control] + messages
* → Different fingerprint "def456" → Cache MISS (prefix changed)
*/
export function generateCacheFingerprint(body: any): string | null {
if (!body) {
return null;
}
const parts: any[] = [];
let hasCacheControl = false;
const cacheBreakpoints: number[] = []; // Indices of cache_control markers
const ttls: number[] = []; // TTL for each breakpoint
// 1. Process tools if present (tools come first in hierarchy)
if (body.tools && Array.isArray(body.tools)) {
for (let i = 0; i < body.tools.length; i++) {
const tool = body.tools[i];
// Only include the stable parts of the tool definition in the fingerprint
// Exclude cache_control as it's metadata, not part of the tool definition
const { cache_control, ...toolWithoutCache } = tool;
parts.push({ type: "tool", tool: toolWithoutCache });
if (cache_control) {
hasCacheControl = true;
cacheBreakpoints.push(parts.length - 1);
ttls.push(parseTTL(cache_control.ttl));
}
}
}
// 2. Process system prompt if present
if (body.system) {
if (typeof body.system === "string") {
// Normalize string system to same structure as array blocks for consistent fingerprinting
parts.push({ type: "system_block", block_type: "text", text: body.system });
} else if (Array.isArray(body.system)) {
for (const block of body.system) {
// System blocks can be:
// - {type: "text", text: "...", cache_control?: ...}
// - {type: "image", source: {...}, cache_control?: ...}
const { cache_control, type, ...rest } = block;
// Create a consistent fingerprint structure
const contentPart: any = { type: "system_block", block_type: type };
switch (type) {
case "text":
contentPart.text = (rest as any).text;
break;
case "image":
// Hash image data for fingerprint
if ((rest as any).source?.data) {
contentPart.image_hash = crypto
.createHash("sha256")
.update((rest as any).source.data)
.digest("hex")
.slice(0, 16);
}
contentPart.media_type = (rest as any).source?.media_type;
break;
default:
// For unknown block types, include all remaining fields
Object.assign(contentPart, rest);
}
parts.push(contentPart);
if (cache_control) {
hasCacheControl = true;
cacheBreakpoints.push(parts.length - 1);
ttls.push(parseTTL(cache_control.ttl));
}
}
}
}
// 3. Process messages
if (body.messages && Array.isArray(body.messages)) {
for (const message of body.messages) {
if (typeof message.content === "string") {
parts.push({
type: "message",
role: message.role,
content: message.content,
});
} else if (Array.isArray(message.content)) {
// For multimodal content, include each block
for (const block of message.content) {
// CRITICAL: Exclude cache_control metadata from fingerprinting
// Only the actual content should affect the fingerprint
const { cache_control, ...blockWithoutCache } = block;
const contentPart: any = {
type: "message_block",
role: message.role,
block_type: blockWithoutCache.type,
};
// Include essential identifying info without full data
// (full data would make fingerprints too large)
switch (blockWithoutCache.type) {
case "text":
contentPart.text = blockWithoutCache.text;
break;
case "image":
// Hash image data for fingerprint
if (blockWithoutCache.source?.data) {
contentPart.image_hash = crypto
.createHash("sha256")
.update(blockWithoutCache.source.data)
.digest("hex")
.slice(0, 16);
}
contentPart.media_type = blockWithoutCache.source?.media_type;
break;
case "tool_use":
// Don't include tool_id as it's often randomly generated
// Only include tool name and input which are stable
contentPart.tool_name = blockWithoutCache.name;
contentPart.tool_input = blockWithoutCache.input;
break;
case "tool_result":
// Don't include tool_use_id as it's randomly generated
// Only include the actual result content and error status
contentPart.is_error = blockWithoutCache.is_error;
if (typeof blockWithoutCache.content === "string") {
contentPart.content = blockWithoutCache.content;
} else if (Array.isArray(blockWithoutCache.content)) {
// tool_result content can be an array of content blocks
contentPart.content = blockWithoutCache.content;
}
break;
}
parts.push(contentPart);
if (cache_control) {
hasCacheControl = true;
cacheBreakpoints.push(parts.length - 1);
ttls.push(parseTTL(cache_control.ttl));
}
}
}
}
}
// No caching if no cache_control directives present
if (!hasCacheControl || cacheBreakpoints.length === 0) {
return null;
}
const maxTTL = Math.max(...ttls);
// Generate individual hashes for each part (tool, system block, message)
// This allows prefix matching even when conversations grow
const partHashes: string[] = [];
for (const part of parts) {
const partHash = crypto
.createHash("sha256")
.update(deterministicStringify(part))
.digest("hex")
.slice(0, 8); // Shorter hash per part
partHashes.push(partHash);
}
// Generate fingerprints for each cache_control position
// Fingerprint is the concatenation of individual part hashes up to the breakpoint
const prefixFingerprints: string[] = [];
for (const breakpoint of cacheBreakpoints) {
// Concatenate part hashes up to and including the breakpoint
const fingerprint = partHashes.slice(0, breakpoint + 1).join("");
log.trace(
{
fingerprint: fingerprint.substring(0, 32) + "...",
breakpoint,
prefixPartsCount: breakpoint + 1,
},
"Generated cache fingerprint"
);
prefixFingerprints.push(fingerprint);
}
// Return the deepest (last) fingerprint as the primary one
const primaryFingerprint = prefixFingerprints[prefixFingerprints.length - 1];
// Store ALL prefix fingerprints in the cache map, not just the primary one
// This allows future requests to match against any prefix level
for (let i = 0; i < prefixFingerprints.length; i++) {
const fp = prefixFingerprints[i];
const existing = cacheMap.get(fp);
if (!existing) {
cacheMap.set(fp, {
keyHash: "",
lastUsed: 0,
hitCount: 0,
ttl: maxTTL,
prefixFingerprints: prefixFingerprints.slice(0, i + 1),
});
}
}
return primaryFingerprint;
}
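To illustrate the growth behavior described in the doc comment above (hypothetical request bodies in the Anthropic Messages format; only the prefix up to the cache_control breakpoint contributes to the primary fingerprint):

const first = generateCacheFingerprint({
  system: [{ type: "text", text: "You are a helpful assistant.", cache_control: { type: "ephemeral" } }],
  messages: [{ role: "user", content: "Hello" }],
});
const second = generateCacheFingerprint({
  system: [{ type: "text", text: "You are a helpful assistant.", cache_control: { type: "ephemeral" } }],
  messages: [
    { role: "user", content: "Hello" },
    { role: "assistant", content: "Hi there!" },
    { role: "user", content: "Tell me more" },
  ],
});
// first === second: the breakpoint sits on the system block, so the extra
// messages appended after it do not change the primary fingerprint.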
function parseTTL(ttl?: string): number {
if (ttl === "1h") return TTL_1_HOUR;
return TTL_5_MINUTES; // Default or "5m"
}
/**
* Records that a key was used for a request with the given cache fingerprint.
*/
export function recordCacheUsage(fingerprint: string, keyHash: string): void {
if (!fingerprint) return;
const entry = cacheMap.get(fingerprint);
if (!entry) {
// Shouldn't happen, but handle gracefully
cacheMap.set(fingerprint, {
keyHash,
lastUsed: Date.now(),
hitCount: 1,
ttl: TTL_5_MINUTES,
});
log.trace({ fingerprint: fingerprint.substring(0, 32) + "...", keyHash }, "New cache entry recorded");
return;
}
const now = Date.now();
if (entry.keyHash === keyHash) {
// Same key - likely cache hit
entry.lastUsed = now;
entry.hitCount++;
log.trace(
{ fingerprint: fingerprint.substring(0, 32) + "...", keyHash, hitCount: entry.hitCount },
"Cache usage recorded (likely cache hit)"
);
} else if (entry.keyHash === "") {
// First use of this fingerprint
entry.keyHash = keyHash;
entry.lastUsed = now;
entry.hitCount = 1;
log.debug({ fingerprint: fingerprint.substring(0, 32) + "...", keyHash }, "First cache usage for fingerprint");
} else {
// Different key - cache miss, reset tracking
log.debug(
{ fingerprint: fingerprint.substring(0, 32) + "...", oldKey: entry.keyHash, newKey: keyHash },
"Cache key changed (will cause cache miss)"
);
entry.keyHash = keyHash;
entry.lastUsed = now;
entry.hitCount = 1;
}
}
/**
* Gets the key hash that has cached the longest matching prefix for the given
* fingerprint or any of its sub-prefixes.
*
* This is crucial for handling moving cache breakpoints:
* - If the current request has fingerprints ["fp1", "fp2", "fp3"]
* - We search backwards from "fp3" → "fp2" → "fp1"
* - Return the key that cached the longest available prefix
*
* Returns null if no cached key exists or all caches have expired.
*/
export function getCachedKeyHash(fingerprint: string): { keyHash: string; matchedFingerprint: string } | null {
if (!fingerprint) return null;
const now = Date.now();
// First, try exact match on the primary (deepest) fingerprint
const primaryEntry = cacheMap.get(fingerprint);
if (primaryEntry && primaryEntry.keyHash) {
const age = now - primaryEntry.lastUsed;
if (age <= primaryEntry.ttl) {
log.trace(
{
fingerprint: fingerprint.substring(0, 32) + "...",
keyHash: primaryEntry.keyHash,
age,
hitCount: primaryEntry.hitCount,
matchType: "exact",
},
"Cache entry found (exact match)"
);
return { keyHash: primaryEntry.keyHash, matchedFingerprint: fingerprint };
} else {
log.trace({ fingerprint: fingerprint.substring(0, 32) + "...", age, ttl: primaryEntry.ttl }, "Cache entry expired");
cacheMap.delete(fingerprint);
}
}
// If no exact match, search all cache entries for prefix matches
// Since fingerprints are now concatenated part hashes, we can check if one is a prefix of another
let bestMatch: { keyHash: string; matchedFingerprint: string } | null = null;
let longestMatchLength = 0;
for (const [cachedFp, entry] of cacheMap.entries()) {
if (!entry.keyHash) continue;
const age = now - entry.lastUsed;
if (age > entry.ttl) {
cacheMap.delete(cachedFp);
continue;
}
// Check if the cached fingerprint is a prefix of the current fingerprint
// (cached request had fewer parts, current request has grown)
if (fingerprint.startsWith(cachedFp) && cachedFp.length > longestMatchLength) {
bestMatch = { keyHash: entry.keyHash, matchedFingerprint: cachedFp };
longestMatchLength = cachedFp.length;
log.trace(
{
requestFingerprint: fingerprint.substring(0, 32) + "...",
matchedFingerprint: cachedFp.substring(0, 32) + "...",
keyHash: entry.keyHash,
matchType: "prefix",
},
"Cache entry found (prefix match)"
);
}
// Also check if the current fingerprint is a prefix of the cached one
// (current request has fewer parts, cached had more)
if (cachedFp.startsWith(fingerprint) && fingerprint.length > longestMatchLength) {
bestMatch = { keyHash: entry.keyHash, matchedFingerprint: cachedFp };
longestMatchLength = fingerprint.length;
log.trace(
{
requestFingerprint: fingerprint.substring(0, 32) + "...",
matchedFingerprint: cachedFp.substring(0, 32) + "...",
keyHash: entry.keyHash,
matchType: "prefix_reverse",
},
"Cache entry found (reverse prefix match)"
);
}
}
return bestMatch || null;
}
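Taken together with recordCacheUsage, the intended flow looks roughly like this sketch (requestBody and selectedKeyHash are hypothetical placeholders; the key providers above do the equivalent on every get() call):

const fp = generateCacheFingerprint(requestBody); // null if the body has no cache_control
if (fp) {
  const hit = getCachedKeyHash(fp);               // { keyHash, matchedFingerprint } | null
  // ...prefer hit.keyHash when sorting available keys...
  recordCacheUsage(hit ? hit.matchedFingerprint : fp, selectedKeyHash);
}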
/**
* Clears expired cache entries periodically to prevent memory leaks.
*/
export function cleanupExpiredCaches(): void {
const now = Date.now();
let cleaned = 0;
for (const [fingerprint, entry] of cacheMap.entries()) {
const age = now - entry.lastUsed;
if (age > entry.ttl) {
cacheMap.delete(fingerprint);
cleaned++;
}
}
if (cleaned > 0) {
log.debug(
{ cleaned, remaining: cacheMap.size },
"Cleaned up expired cache entries"
);
}
}
/**
* Returns cache statistics for monitoring.
*/
export function getCacheStats() {
return {
totalEntries: cacheMap.size,
entries: Array.from(cacheMap.entries()).map(([fp, entry]) => ({
fingerprint: fp,
keyHash: entry.keyHash,
age: Date.now() - entry.lastUsed,
hitCount: entry.hitCount,
ttl: entry.ttl,
})),
};
}
// Run cleanup every minute
setInterval(cleanupExpiredCaches, 60 * 1000);
@@ -6,6 +6,11 @@ import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker";
import {
  generateCacheFingerprint,
  recordCacheUsage,
  getCachedKeyHash,
} from "../cache-tracker";

// GcpKeyUsage is removed, tokenUsage from base Key interface will be used.
export interface GcpKey extends Key {
@@ -90,7 +95,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(model: string) {
  public get(model: string, _streaming?: boolean, requestBody?: any) {
    const neededFamily = getGcpModelFamily(model);

    // this is a horrible mess
@@ -115,6 +120,42 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
      );
    });

    // Generate cache fingerprint if request body contains cache_control
    const cacheFingerprint = requestBody
      ? generateCacheFingerprint(requestBody)
      : null;

    // Try to get cached key if we have a fingerprint
    let preferredKeyHash: string | null = null;
    let matchedFingerprint: string | null = null;
    if (cacheFingerprint) {
      const cacheResult = getCachedKeyHash(cacheFingerprint);
      if (cacheResult) {
        preferredKeyHash = cacheResult.keyHash;
        matchedFingerprint = cacheResult.matchedFingerprint;
        // Check if the cached key is still available
        const cachedKey = availableKeys.find((k) => k.hash === preferredKeyHash);
        if (cachedKey) {
          this.log.debug(
            {
              requestedModel: model,
              cacheFingerprint,
              keyHash: preferredKeyHash,
            },
            "Using cached key for prompt caching optimization"
          );
        } else {
          // Cached key no longer available; log the stale key hash before clearing it
          this.log.debug(
            { cacheFingerprint, keyHash: preferredKeyHash },
            "Cached key not available, selecting new key"
          );
          preferredKeyHash = null;
          matchedFingerprint = null;
        }
      }
    }

    this.log.debug(
      {
        model,
@@ -124,6 +165,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
        needsSonnet35,
        availableKeys: availableKeys.length,
        totalKeys: this.keys.length,
        cacheFingerprint,
        hasCachedKey: !!preferredKeyHash,
      },
      "Selecting GCP key"
    );
@@ -134,9 +177,28 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
      );
    }

    const selectedKey = prioritizeKeys(availableKeys)[0];
    /**
     * Comparator for prioritizing keys based on cache affinity.
     */
    const keyComparator = (a: GcpKey, b: GcpKey) => {
      // Highest priority: cache affinity
      if (preferredKeyHash) {
        if (a.hash === preferredKeyHash) return -1;
        if (b.hash === preferredKeyHash) return 1;
      }
      return 0;
    };

    const selectedKey = prioritizeKeys(availableKeys, keyComparator)[0];
    selectedKey.lastUsed = Date.now();
    this.throttle(selectedKey.hash);

    // Record cache usage for future requests
    // Use matchedFingerprint if we had a cache hit, otherwise use the current fingerprint
    if (cacheFingerprint) {
      recordCacheUsage(matchedFingerprint || cacheFingerprint, selectedKey.hash);
    }

    return { ...selectedKey };
  }

@@ -61,7 +61,7 @@ for service-agnostic functionality.
export interface KeyProvider<T extends Key = Key> {
  readonly service: LLMService;
  init(): void;
  get(model: string, streaming?: boolean): T;
  get(model: string, streaming?: boolean, requestBody?: any): T;
  list(): Omit<T, "key">[];
  disable(key: T): void;
  update(hash: string, update: Partial<T>): void;

@@ -55,15 +55,15 @@ export class KeyPool {
    this.scheduleRecheck();
  }

  public get(model: string, service?: LLMService, multimodal?: boolean, streaming?: boolean): Key {
  public get(model: string, service?: LLMService, multimodal?: boolean, streaming?: boolean, requestBody?: any): Key {
    // hack for some claude requests needing keys with particular permissions
    // even though they use the same models as the non-multimodal requests
    if (multimodal) {
      model += "-multimodal";
    }

    const queryService = service || this.getServiceForModel(model);
    return this.getKeyProvider(queryService).get(model, streaming);
    return this.getKeyProvider(queryService).get(model, streaming, requestBody);
  }

  public list(): Omit<Key, "key">[] {

@@ -0,0 +1,66 @@
import { AnthropicChatMessage } from "../api-schemas";
import { getAxiosInstance } from "../network";
import { logger } from "../../logger";
import { AnthropicKey } from "../key-management/anthropic/provider";

const log = logger.child({ module: "tokenizer", service: "anthropic-remote" });

export interface AnthropicTokenCountRequest {
  model: string;
  messages: AnthropicChatMessage[];
  system?: string | Array<{ type: string; text: string }>;
  tools?: unknown[];
}

export interface AnthropicTokenCountResponse {
  input_tokens: number;
}

/**
 * Counts tokens using Anthropic's remote token counting API endpoint.
 * https://docs.claude.com/en/docs/build-with-claude/token-counting
 */
export async function countTokensRemote(
  request: AnthropicTokenCountRequest,
  key: AnthropicKey
): Promise<{ token_count: number; tokenizer: string }> {
  const axios = getAxiosInstance();

  try {
    const response = await axios.post<AnthropicTokenCountResponse>(
      "https://api.anthropic.com/v1/messages/count_tokens",
      request,
      {
        headers: {
          "x-api-key": key.key,
          "anthropic-version": "2023-06-01",
          "content-type": "application/json",
        },
        timeout: 5000, // 5 second timeout
      }
    );

    log.debug(
      {
        model: request.model,
        input_tokens: response.data.input_tokens,
      },
      "Counted tokens via Anthropic API"
    );

    return {
      token_count: response.data.input_tokens,
      tokenizer: "anthropic-remote-api",
    };
  } catch (error: any) {
    log.warn(
      {
        error: error.message,
        status: error.response?.status,
        data: error.response?.data,
      },
      "Failed to count tokens via Anthropic API, will fall back to local tokenizer"
    );
    throw error;
  }
}
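A minimal sketch of how this helper might be invoked, assuming an already-loaded AnthropicKey; the model ID and message content are placeholders, and the proxy itself builds the real request from req.body inside countTokens():

import { countTokensRemote } from "./anthropic-remote";
import { AnthropicKey } from "../key-management/anthropic/provider";

// Illustrative only; error handling is omitted because countTokensRemote already
// logs and rethrows, letting the caller fall back to the local tokenizer.
async function exampleAnthropicCount(key: AnthropicKey) {
  const { token_count, tokenizer } = await countTokensRemote(
    {
      model: "claude-3-5-sonnet-20241022",
      messages: [{ role: "user", content: [{ type: "text", text: "Hello, Claude" }] }],
    },
    key
  );
  return { token_count, tokenizer };
}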
@@ -0,0 +1,117 @@
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import { getAxiosInstance } from "../network";
import { logger } from "../../logger";
import { AwsBedrockKey } from "../key-management/aws/provider";

const log = logger.child({ module: "tokenizer", service: "aws-remote" });

const AMZ_HOST =
  process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";

export interface AwsTokenCountRequest {
  input: {
    invokeModel: {
      body: string; // base64-encoded JSON string
    };
  };
}

export interface AwsTokenCountResponse {
  inputTokens: number;
}

type Credential = {
  accessKeyId: string;
  secretAccessKey: string;
  region: string;
};

function getCredentialParts(key: AwsBedrockKey): Credential {
  const [accessKeyId, secretAccessKey, region] = key.key.split(":");

  if (!accessKeyId || !secretAccessKey || !region) {
    throw new Error("AWS_CREDENTIALS isn't correctly formatted");
  }

  return { accessKeyId, secretAccessKey, region };
}

async function sign(request: HttpRequest, credential: Credential) {
  const { accessKeyId, secretAccessKey, region } = credential;

  const signer = new SignatureV4({
    sha256: Sha256,
    credentials: { accessKeyId, secretAccessKey },
    region,
    service: "bedrock",
  });

  return signer.sign(request);
}

/**
 * Counts tokens using AWS Bedrock's remote token counting API endpoint.
 * https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_CountTokens.html
 */
export async function countTokensRemote(
  modelId: string,
  request: AwsTokenCountRequest,
  key: AwsBedrockKey
): Promise<{ token_count: number; tokenizer: string }> {
  const axios = getAxiosInstance();
  const credential = getCredentialParts(key);
  const host = AMZ_HOST.replace("%REGION%", credential.region);

  // Create the HTTP request to sign
  const httpRequest = new HttpRequest({
    method: "POST",
    protocol: "https:",
    hostname: host,
    path: `/model/${modelId}/count-tokens`,
    headers: {
      ["Host"]: host,
      ["content-type"]: "application/json",
    },
    body: JSON.stringify(request),
  });

  try {
    // Sign the request using AWS Signature V4
    const signedRequest = await sign(httpRequest, credential);

    // Make the request
    const response = await axios.post<AwsTokenCountResponse>(
      `https://${host}${signedRequest.path}`,
      signedRequest.body,
      {
        headers: signedRequest.headers as Record<string, string>,
        timeout: 5000, // 5 second timeout
      }
    );

    log.debug(
      {
        modelId,
        input_tokens: response.data.inputTokens,
      },
      "Counted tokens via AWS Bedrock API"
    );

    return {
      token_count: response.data.inputTokens,
      tokenizer: "aws-bedrock-remote-api",
    };
  } catch (error: any) {
    log.warn(
      {
        error: error.message,
        status: error.response?.status,
        data: error.response?.data,
      },
      "Failed to count tokens via AWS Bedrock API, will fall back to local tokenizer"
    );
    throw error;
  }
}
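A minimal sketch of the CountTokens envelope this helper expects, assuming Bedrock's Anthropic message format; the model ID, max_tokens, and message are placeholders, and the proxy assembles the real body from req.body in the countTokens hunk further below:

import { countTokensRemote, AwsTokenCountRequest } from "./aws-remote";
import { AwsBedrockKey } from "../key-management/aws/provider";

// Illustrative only: Bedrock's CountTokens API takes the same JSON you would send
// to InvokeModel, base64-encoded inside input.invokeModel.body.
async function exampleBedrockCount(key: AwsBedrockKey) {
  const invokeBody = {
    anthropic_version: "bedrock-2023-05-31",
    max_tokens: 1, // placeholder; any valid InvokeModel body works
    messages: [{ role: "user", content: [{ type: "text", text: "Hello" }] }],
  };
  const request: AwsTokenCountRequest = {
    input: {
      invokeModel: {
        body: Buffer.from(JSON.stringify(invokeBody), "utf-8").toString("base64"),
      },
    },
  };
  const { token_count } = await countTokensRemote(
    "anthropic.claude-3-5-sonnet-20241022-v2:0", // placeholder Bedrock model ID
    request,
    key
  );
  return token_count;
}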
@@ -26,7 +26,7 @@ export async function getTokenCount(
    return getTokenCountForMessages(prompt);
  }

  if (prompt.length > 800000) {
  if (prompt.length > 5000000) {
    throw new Error("Content is too large to tokenize.");
  }

@@ -59,7 +59,7 @@ async function getTokenCountForMessages({
      switch (part.type) {
        case "text":
          const { text } = part;
          if (text.length > 800000 || numTokens > 200000) {
          if (text.length > 5000000 || numTokens > 1200000) {
            throw new Error("Text content is too large to tokenize.");
          }
          numTokens += encoder.encode(text.normalize("NFKC"), "all").length;

@@ -0,0 +1,104 @@
import { getAxiosInstance } from "../network";
import { logger } from "../../logger";
import { GcpKey } from "../key-management/gcp/provider";
import { getCredentialsFromGcpKey, refreshGcpAccessToken } from "../key-management/gcp/oauth";

const log = logger.child({ module: "tokenizer", service: "gcp-remote" });

export interface GcpTokenCountRequest {
  contents: Array<{
    role: string;
    parts: Array<{
      text?: string;
      inline_data?: {
        mime_type: string;
        data: string;
      };
    }>;
  }>;
  systemInstruction?: {
    parts: Array<{
      text: string;
    }>;
  };
  tools?: unknown[];
}

export interface GcpTokenCountResponse {
  totalTokens: number;
  totalBillableCharacters?: number;
  promptTokensDetails?: Array<{
    modality: string;
    tokenCount: number;
  }>;
}

/**
 * Counts tokens using GCP Vertex AI's remote token counting API endpoint.
 * https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/count-tokens
 */
export async function countTokensRemote(
  model: string,
  request: GcpTokenCountRequest,
  key: GcpKey
): Promise<{ token_count: number; tokenizer: string }> {
  const axios = getAxiosInstance();

  // Ensure we have a valid access token
  const now = Date.now();
  if (!key.accessToken || now >= key.accessTokenExpiresAt) {
    const [token, expiresIn] = await refreshGcpAccessToken(key);
    key.accessToken = token;
    key.accessTokenExpiresAt = now + expiresIn * 1000;
  }

  const { projectId, region } = await getCredentialsFromGcpKey(key);

  // Extract just the model name (e.g., "gemini-1.5-pro" from full model ID)
  // GCP model IDs are typically like "gemini-1.5-pro-001" or just "gemini-1.5-pro"
  let modelName = model;
  if (model.includes("/")) {
    // Handle full resource names like "projects/.../models/gemini-1.5-pro"
    modelName = model.split("/").pop()!;
  }

  const url = `https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${region}/publishers/google/models/${modelName}:countTokens`;

  try {
    const response = await axios.post<GcpTokenCountResponse>(
      url,
      request,
      {
        headers: {
          "Authorization": `Bearer ${key.accessToken}`,
          "content-type": "application/json",
        },
        timeout: 5000, // 5 second timeout
      }
    );

    log.debug(
      {
        model: modelName,
        total_tokens: response.data.totalTokens,
        details: response.data.promptTokensDetails,
      },
      "Counted tokens via GCP Vertex AI API"
    );

    return {
      token_count: response.data.totalTokens,
      tokenizer: "gcp-vertex-ai-remote-api",
    };
  } catch (error: any) {
    log.warn(
      {
        error: error.message,
        status: error.response?.status,
        data: error.response?.data,
      },
      "Failed to count tokens via GCP Vertex AI API, will fall back to local tokenizer"
    );
    throw error;
  }
}
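A minimal sketch of the Vertex AI countTokens payload this helper accepts, following the GcpTokenCountRequest interface above; the model name and prompt text are placeholders, and note that the shared tokenizer does not wire this helper up yet and still falls back to local counting for GCP:

import { countTokensRemote, GcpTokenCountRequest } from "./gcp-remote";
import { GcpKey } from "../key-management/gcp/provider";

// Illustrative only; access token refresh and URL construction happen inside
// countTokensRemote, so the caller just supplies the contents array.
async function exampleVertexCount(key: GcpKey) {
  const request: GcpTokenCountRequest = {
    contents: [{ role: "user", parts: [{ text: "Hello, Gemini" }] }],
    systemInstruction: { parts: [{ text: "You are a helpful assistant." }] },
  };
  const { token_count, tokenizer } = await countTokensRemote("gemini-1.5-pro", request, key);
  return { token_count, tokenizer };
}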
@@ -21,6 +21,17 @@ import {
  MistralAIChatMessage,
  OpenAIChatMessage,
} from "../api-schemas";
import { countTokensRemote as countAnthropicTokensRemote } from "./anthropic-remote";
import { countTokensRemote as countAwsTokensRemote } from "./aws-remote";
import { countTokensRemote as countGcpTokensRemote } from "./gcp-remote";
import { AnthropicKey } from "../key-management/anthropic/provider";
import { AwsBedrockKey } from "../key-management/aws/provider";
import { GcpKey } from "../key-management/gcp/provider";
import { logger } from "../../logger";
import { config } from "../../config";
import { findByAnthropicId } from "../claude-models";

const log = logger.child({ module: "tokenizer" });

export async function init() {
  initClaude();
@@ -99,6 +110,87 @@ export async function countTokens({
  completion,
}: TokenCountRequest): Promise<TokenCountResult> {
  const time = process.hrtime();

  // Try the provider's remote token counting API first, but only when the feature
  // is enabled, the request has a key attached, and we are counting a prompt.
  if (config.useRemoteTokenCounting && prompt && req.key) {
    try {
      switch (service) {
        case "anthropic-chat": {
          if (req.service === "anthropic" && req.key.service === "anthropic") {
            const result = await countAnthropicTokensRemote(
              {
                model: req.body.model,
                messages: prompt.messages,
                system: prompt.system,
                tools: req.body.tools,
              },
              req.key as AnthropicKey
            );
            return { ...result, tokenization_duration_ms: getElapsedMs(time) };
          }
          break;
        }
        case "anthropic-text": {
          // Anthropic's count_tokens endpoint doesn't support text completions;
          // fall through to the local tokenizer.
          break;
        }
      }

      // For AWS Bedrock services, use AWS token counting
      if (req.service === "aws" && req.key.service === "aws") {
        if (service === "anthropic-chat") {
          // Convert Anthropic model ID to AWS Bedrock model ID
          const anthropicModelId = req.body.model;
          const claudeMapping = findByAnthropicId(anthropicModelId);
          const awsModelId = claudeMapping?.awsId || anthropicModelId;

          // Build the request body in Anthropic format - must include all required fields
          const bodyObj: any = {
            messages: req.body.messages,
            max_tokens: req.body.max_tokens,
            anthropic_version: req.body.anthropic_version || "bedrock-2023-05-31",
          };
          if (req.body.system) bodyObj.system = req.body.system;
          if (req.body.tools) bodyObj.tools = req.body.tools;
          if (req.body.tool_choice) bodyObj.tool_choice = req.body.tool_choice;

          // AWS expects the body as a base64-encoded string
          const bodyJson = JSON.stringify(bodyObj);
          const bodyBase64 = Buffer.from(bodyJson, "utf-8").toString("base64");

          const result = await countAwsTokensRemote(
            awsModelId,
            {
              input: {
                invokeModel: {
                  body: bodyBase64,
                },
              },
            },
            req.key as AwsBedrockKey
          );
          return { ...result, tokenization_duration_ms: getElapsedMs(time) };
        }
      }

      // For GCP Vertex AI services, use GCP token counting
      if (req.service === "gcp" && req.key.service === "gcp") {
        // GCP uses a different format, would need transformation
        // For now, fall through to local counting
      }
    } catch (error) {
      // Fall through to local tokenization
      log.debug(
        { error: (error as Error).message, service },
        "Remote token counting failed, using local tokenizer"
      );
    }
  }

  // Fall back to local tokenization
  switch (service) {
    case "anthropic-chat":
    case "anthropic-text":