Compare commits
24 commits · main...aws-mistral
| SHA1 |
|---|
| cfc1290f83 |
| 14f228f666 |
| d264fdd573 |
| 9c3e345720 |
| 37c421bb45 |
| 6c5fed90e2 |
| 9479fa4ab0 |
| e145f5757e |
| 2fe6e07cf5 |
| bc340c1be6 |
| 45c5d3d338 |
| 3032ae3198 |
| 49a89122f5 |
| 2d8e1dac13 |
| 9e5a660ef5 |
| 6cf8c09fad |
| dc1b573020 |
| 3ff771d945 |
| 985035fe80 |
| 442f9529de |
| 598ac8e4e1 |
| 750dbee483 |
| a2d64e281e |
| c6467b02f3 |
package-lock.json (Generated)
+14 -5

@@ -11,6 +11,7 @@
       "dependencies": {
         "@anthropic-ai/tokenizer": "^0.0.4",
         "@aws-crypto/sha256-js": "^5.2.0",
+        "@huggingface/jinja": "^0.3.0",
         "@node-rs/argon2": "^1.8.3",
         "@smithy/eventstream-codec": "^2.1.3",
         "@smithy/eventstream-serde-node": "^2.1.3",
@@ -18,7 +19,7 @@
         "@smithy/signature-v4": "^2.1.3",
         "@smithy/types": "^2.10.1",
         "@smithy/util-utf8": "^2.1.1",
-        "axios": "^1.3.5",
+        "axios": "^1.7.4",
         "better-sqlite3": "^10.0.0",
         "check-disk-space": "^3.4.0",
         "cookie-parser": "^1.4.6",
@@ -866,6 +867,14 @@
         "node": ">=6"
       }
     },
+    "node_modules/@huggingface/jinja": {
+      "version": "0.3.0",
+      "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz",
+      "integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==",
+      "engines": {
+        "node": ">=18"
+      }
+    },
     "node_modules/@isaacs/cliui": {
       "version": "8.0.2",
       "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz",
@@ -1887,11 +1896,11 @@
       }
     },
     "node_modules/axios": {
-      "version": "1.6.1",
-      "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
-      "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
+      "version": "1.7.4",
+      "resolved": "https://registry.npmjs.org/axios/-/axios-1.7.4.tgz",
+      "integrity": "sha512-DukmaFRnY6AzAALSH4J2M3k6PkaC+MfaAGdEERRWcC9q3/TWQwLpHR8ZRLKTdQ3aBDL64EdluRDjJqKw+BPZEw==",
       "dependencies": {
-        "follow-redirects": "^1.15.0",
+        "follow-redirects": "^1.15.6",
         "form-data": "^4.0.0",
         "proxy-from-env": "^1.1.0"
       }
package.json
+2 -1

@@ -20,6 +20,7 @@
   "dependencies": {
     "@anthropic-ai/tokenizer": "^0.0.4",
     "@aws-crypto/sha256-js": "^5.2.0",
+    "@huggingface/jinja": "^0.3.0",
     "@node-rs/argon2": "^1.8.3",
     "@smithy/eventstream-codec": "^2.1.3",
     "@smithy/eventstream-serde-node": "^2.1.3",
@@ -27,7 +28,7 @@
     "@smithy/signature-v4": "^2.1.3",
     "@smithy/types": "^2.10.1",
     "@smithy/util-utf8": "^2.1.1",
-    "axios": "^1.3.5",
+    "axios": "^1.7.4",
     "better-sqlite3": "^10.0.0",
     "check-disk-space": "^3.4.0",
     "cookie-parser": "^1.4.6",
@@ -0,0 +1,118 @@
// uses the aws sdk to sign a request, then uses axios to send it to the bedrock REST API manually
import axios from "axios";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";

const AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID!;
const AWS_SECRET_ACCESS_KEY = process.env.AWS_SECRET_ACCESS_KEY!;

// Copied from amazon bedrock docs

// List models
// ListFoundationModels
// Service: Amazon Bedrock
// List of Bedrock foundation models that you can use. For more information, see Foundation models in the
// Bedrock User Guide.
// Request Syntax
// GET /foundation-models?
// byCustomizationType=byCustomizationType&byInferenceType=byInferenceType&byOutputModality=byOutputModality&byProvider=byProvider
// HTTP/1.1
// URI Request Parameters
// The request uses the following URI parameters.
// byCustomizationType (p. 38)
// List by customization type.
// Valid Values: FINE_TUNING
// byInferenceType (p. 38)
// List by inference type.
// Valid Values: ON_DEMAND | PROVISIONED
// byOutputModality (p. 38)
// List by output modality type.
// Valid Values: TEXT | IMAGE | EMBEDDING
// byProvider (p. 38)
// A Bedrock model provider.
// Pattern: ^[a-z0-9-]{1,63}$
// Request Body
// The request does not have a request body

// Run inference on a text model
// Send an invoke request to run inference on a Titan Text G1 - Express model. We set the accept
// parameter to accept any content type in the response.
// POST https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke
// -H accept: */*
// -H content-type: application/json
// Payload
// {"inputText": "Hello world"}
// Example response
// Response for the above request.
// -H content-type: application/json
// Payload
// <the model response>

const AMZ_REGION = "us-east-1";
const AMZ_HOST = "invoke-bedrock.us-east-1.amazonaws.com";

async function listModels() {
  const httpRequest = new HttpRequest({
    method: "GET",
    protocol: "https:",
    hostname: AMZ_HOST,
    path: "/foundation-models",
    headers: { ["Host"]: AMZ_HOST },
  });

  const signedRequest = await signRequest(httpRequest);
  const response = await axios.get(
    `https://${signedRequest.hostname}${signedRequest.path}`,
    { headers: signedRequest.headers }
  );
  console.log(response.data);
}

async function invokeModel() {
  const model = "anthropic.claude-v1";
  const httpRequest = new HttpRequest({
    method: "POST",
    protocol: "https:",
    hostname: AMZ_HOST,
    path: `/model/${model}/invoke`,
    headers: {
      ["Host"]: AMZ_HOST,
      ["accept"]: "*/*",
      ["content-type"]: "application/json",
    },
    body: JSON.stringify({
      temperature: 0.5,
      prompt: "\n\nHuman:Hello world\n\nAssistant:",
      max_tokens_to_sample: 10,
    }),
  });
  console.log("httpRequest", httpRequest);

  const signedRequest = await signRequest(httpRequest);
  const response = await axios.post(
    `https://${signedRequest.hostname}${signedRequest.path}`,
    signedRequest.body,
    { headers: signedRequest.headers }
  );
  console.log(response.status);
  console.log(response.headers);
  console.log(response.data);
  console.log("full url", response.request.res.responseUrl);
}

async function signRequest(request: HttpRequest) {
  const signer = new SignatureV4({
    sha256: Sha256,
    credentials: {
      accessKeyId: AWS_ACCESS_KEY_ID,
      secretAccessKey: AWS_SECRET_ACCESS_KEY,
    },
    region: AMZ_REGION,
    service: "bedrock",
  });
  return await signer.sign(request, { signingDate: new Date() });
}

// listModels();
// invokeModel();
config.ts
+8 -25

@@ -428,31 +428,10 @@ export const config: Config = {
     ["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
     400
   ),
-  allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
-    "turbo",
-    "gpt4",
-    "gpt4-32k",
-    "gpt4-turbo",
-    "gpt4o",
-    "claude",
-    "claude-opus",
-    "gemini-flash",
-    "gemini-pro",
-    "gemini-ultra",
-    "mistral-tiny",
-    "mistral-small",
-    "mistral-medium",
-    "mistral-large",
-    "aws-claude",
-    "aws-claude-opus",
-    "gcp-claude",
-    "gcp-claude-opus",
-    "azure-turbo",
-    "azure-gpt4",
-    "azure-gpt4-32k",
-    "azure-gpt4-turbo",
-    "azure-gpt4o",
-  ]),
+  allowedModelFamilies: getEnvWithDefault(
+    "ALLOWED_MODEL_FAMILIES",
+    getDefaultModelFamilies()
+  ),
   rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
   rejectMessage: getEnvWithDefault(
     "REJECT_MESSAGE",
@@ -801,3 +780,7 @@ function parseCsv(val: string): string[] {
   const matches = val.match(regex) || [];
   return matches.map((item) => item.replace(/^"|"$/g, "").trim());
 }
+
+function getDefaultModelFamilies(): ModelFamily[] {
+  return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
+}
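Note: with this change, newly added model families (such as the aws-mistral ones below) are enabled by default without touching the allow-list. A runnable sketch of the new filter, using a trimmed-down stand-in for the real MODEL_FAMILIES list:

```ts
// Stand-in for the real MODEL_FAMILIES constant (the actual list lives in the
// shared model definitions); only here to make the sketch self-contained.
const MODEL_FAMILIES = ["turbo", "gpt4", "dall-e", "aws-claude", "aws-mistral-small"] as const;
type ModelFamily = (typeof MODEL_FAMILIES)[number];

function getDefaultModelFamilies(): ModelFamily[] {
  // Everything except image generation is allowed by default.
  return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
}

console.log(getDefaultModelFamilies());
// -> ["turbo", "gpt4", "aws-claude", "aws-mistral-small"]
```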
+5 -1

@@ -29,6 +29,10 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
   "mistral-large": "Mistral Large",
   "aws-claude": "AWS Claude (Sonnet)",
   "aws-claude-opus": "AWS Claude (Opus)",
+  "aws-mistral-tiny": "AWS Mistral 7B",
+  "aws-mistral-small": "AWS Mistral Nemo",
+  "aws-mistral-medium": "AWS Mistral Medium",
+  "aws-mistral-large": "AWS Mistral Large",
   "gcp-claude": "GCP Claude (Sonnet)",
   "gcp-claude-opus": "GCP Claude (Opus)",
   "azure-turbo": "Azure GPT-3.5 Turbo",
@@ -41,7 +45,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {

 const converter = new showdown.Converter();
 const customGreeting = fs.existsSync("greeting.md")
-  ? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
+  ? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>`
   : "";
 let infoPageHtml: string | undefined;
 let infoPageLastUpdated = 0;
add-v1.ts (new file)
@@ -0,0 +1,9 @@
import { NextFunction, Request, Response } from "express";

export function addV1(req: Request, res: Response, next: NextFunction) {
  // Clients don't consistently use the /v1 prefix so we'll add it for them.
  if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
    req.url = `/v1${req.url}`;
  }
  next();
}
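A quick sketch of what addV1 does when mounted the way the AWS router below mounts it (standalone app for illustration; port and handler are hypothetical):

```ts
import express from "express";
import { addV1 } from "./add-v1"; // import path as used by the new router below

const app = express();
app.use("/aws/claude", addV1, (req, res) => {
  // POST /aws/claude/chat/completions    -> req.url === "/v1/chat/completions"
  // POST /aws/claude/v1/chat/completions -> req.url === "/v1/chat/completions"
  res.json({ rewrittenUrl: req.url });
});
app.listen(3000);
```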
aws-claude.ts (new file)
@@ -0,0 +1,253 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  createPreprocessorMiddleware,
  signAwsRequest,
  finalizeSignedRequest,
  createOnProxyReqHandler,
} from "./middleware/request";
import {
  ProxyResHandlerWithBody,
  createOnProxyResHandler,
} from "./middleware/response";
import {
  transformAnthropicChatResponseToAnthropicText,
  transformAnthropicChatResponseToOpenAI,
} from "./anthropic";

/** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  let newBody = body;
  switch (`${req.inboundApi}<-${req.outboundApi}`) {
    case "openai<-anthropic-text":
      req.log.info("Transforming Anthropic Text back to OpenAI format");
      newBody = transformAwsTextResponseToOpenAI(body, req);
      break;
    case "openai<-anthropic-chat":
      req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
      newBody = transformAnthropicChatResponseToOpenAI(body);
      break;
    case "anthropic-text<-anthropic-chat":
      req.log.info("Transforming AWS Anthropic Chat back to Text format");
      newBody = transformAnthropicChatResponseToAnthropicText(body);
      break;
  }

  // AWS does not always confirm the model in the response, so we have to add it
  if (!newBody.model && req.body.model) {
    newBody.model = req.body.model;
  }

  res.status(200).json({ ...newBody, proxy: body.proxy });
};

/**
 * Transforms a model response from the Anthropic API to match those from the
 * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
 * is only used for non-streaming requests as streaming requests are handled
 * on-the-fly.
 */
function transformAwsTextResponseToOpenAI(
  awsBody: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  return {
    id: "aws-" + v4(),
    object: "chat.completion",
    created: Date.now(),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: {
          role: "assistant",
          content: awsBody.completion?.trim(),
        },
        finish_reason: awsBody.stop_reason,
        index: 0,
      },
    ],
  };
}

const awsClaudeProxy = createQueueMiddleware({
  beforeProxy: signAwsRequest,
  proxyMiddleware: createProxyMiddleware({
    target: "bad-target-will-be-rewritten",
    router: ({ signedRequest }) => {
      if (!signedRequest) throw new Error("Must sign request before proxying");
      return `${signedRequest.protocol}//${signedRequest.hostname}`;
    },
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
      proxyRes: createOnProxyResHandler([awsResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const nativeTextPreprocessor = createPreprocessorMiddleware(
  { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
  { afterTransform: [maybeReassignModel] }
);

const textToChatPreprocessor = createPreprocessorMiddleware(
  { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
  { afterTransform: [maybeReassignModel] }
);

/**
 * Routes text completion prompts to aws anthropic-chat if they need translation
 * (claude-3 based models do not support the old text completion endpoint).
 */
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
  if (req.body.model?.includes("claude-3")) {
    textToChatPreprocessor(req, res, next);
  } else {
    nativeTextPreprocessor(req, res, next);
  }
};

const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
  { inApi: "openai", outApi: "anthropic-text", service: "aws" },
  { afterTransform: [maybeReassignModel] }
);

const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
  { inApi: "openai", outApi: "anthropic-chat", service: "aws" },
  { afterTransform: [maybeReassignModel] }
);

/**
 * Routes an OpenAI prompt to either the legacy Claude text completion endpoint
 * or the new Claude chat completion endpoint, based on the requested model.
 */
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
  if (req.body.model?.includes("claude-3")) {
    oaiToAwsChatPreprocessor(req, res, next);
  } else {
    oaiToAwsTextPreprocessor(req, res, next);
  }
};

const awsClaudeRouter = Router();
// Native(ish) Anthropic text completion endpoint.
awsClaudeRouter.post(
  "/v1/complete",
  ipLimiter,
  preprocessAwsTextRequest,
  awsClaudeProxy
);
// Native Anthropic chat completion endpoint.
awsClaudeRouter.post(
  "/v1/messages",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
    { afterTransform: [maybeReassignModel] }
  ),
  awsClaudeProxy
);

// OpenAI-to-AWS Anthropic compatibility endpoint.
awsClaudeRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  preprocessOpenAICompatRequest,
  awsClaudeProxy
);

/**
 * Tries to deal with:
 * - frontends sending AWS model names even when they want to use the OpenAI-
 *   compatible endpoint
 * - frontends sending Anthropic model names that AWS doesn't recognize
 * - frontends sending OpenAI model names because they expect the proxy to
 *   translate them
 *
 * If client sends AWS model ID it will be used verbatim. Otherwise, various
 * strategies are used to try to map a non-AWS model name to AWS model ID.
 */
function maybeReassignModel(req: Request) {
  const model = req.body.model;

  // If it looks like an AWS model, use it as-is
  if (model.includes("anthropic.claude")) {
    return;
  }

  // Anthropic model names can look like:
  // - claude-v1
  // - claude-2.1
  // - claude-3-5-sonnet-20240620-v1:0
  const pattern =
    /^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
  const match = model.match(pattern);

  // If there's no match, fallback to Claude v2 as it is most likely to be
  // available on AWS.
  if (!match) {
    req.body.model = `anthropic.claude-v2:1`;
    return;
  }

  const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;

  if (instant) {
    req.body.model = "anthropic.claude-instant-v1";
    return;
  }

  const ver = minor ? `${major}.${minor}` : major;
  switch (ver) {
    case "1":
    case "1.0":
      req.body.model = "anthropic.claude-v1";
      return;
    case "2":
    case "2.0":
      req.body.model = "anthropic.claude-v2";
      return;
    case "3":
    case "3.0":
      if (name.includes("opus")) {
        req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
      } else if (name.includes("haiku")) {
        req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
      } else {
        req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
      }
      return;
    case "3.5":
      req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
      return;
  }

  // Fallback to Claude 2.1
  req.body.model = `anthropic.claude-v2:1`;
  return;
}

export const awsClaude = awsClaudeRouter;
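The model-mapping rules above are easiest to see with concrete inputs. Below is a self-contained copy of the same regex and switch, rewritten as a pure function instead of mutating req.body, with the outcomes the code above produces:

```ts
// Distilled copy of maybeReassignModel above, for illustration only.
function mapToAwsModelId(model: string): string {
  if (model.includes("anthropic.claude")) return model;
  const pattern =
    /^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
  const match = model.match(pattern);
  if (!match) return "anthropic.claude-v2:1";
  const [, , instant, , major, , minor, , name] = match;
  if (instant) return "anthropic.claude-instant-v1";
  const ver = minor ? `${major}.${minor}` : major;
  switch (ver) {
    case "1": case "1.0": return "anthropic.claude-v1";
    case "2": case "2.0": return "anthropic.claude-v2";
    case "3": case "3.0":
      if (name?.includes("opus")) return "anthropic.claude-3-opus-20240229-v1:0";
      if (name?.includes("haiku")) return "anthropic.claude-3-haiku-20240307-v1:0";
      return "anthropic.claude-3-sonnet-20240229-v1:0";
    case "3.5": return "anthropic.claude-3-5-sonnet-20240620-v1:0";
  }
  return "anthropic.claude-v2:1";
}

console.log(mapToAwsModelId("claude-instant-v1"));          // anthropic.claude-instant-v1
console.log(mapToAwsModelId("claude-3-opus-20240229"));     // anthropic.claude-3-opus-20240229-v1:0
console.log(mapToAwsModelId("claude-3-5-sonnet-20240620")); // anthropic.claude-3-5-sonnet-20240620-v1:0
console.log(mapToAwsModelId("claude-2.1"));                 // anthropic.claude-v2:1 (2.1 has no switch case; fallback)
console.log(mapToAwsModelId("gpt-4o"));                     // anthropic.claude-v2:1 (no regex match; fallback)
```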
aws-mistral.ts (new file)
@@ -0,0 +1,110 @@
import { Request } from "express";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
import {
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeSignedRequest,
  signAwsRequest,
} from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai";

const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  let newBody = body;
  if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") {
    newBody = transformMistralTextToMistralChat(body);
  }
  // AWS does not always confirm the model in the response, so we have to add it
  if (!newBody.model && req.body.model) {
    newBody.model = req.body.model;
  }

  res.status(200).json({ ...newBody, proxy: body.proxy });
};

const awsMistralProxy = createQueueMiddleware({
  beforeProxy: signAwsRequest,
  proxyMiddleware: createProxyMiddleware({
    target: "bad-target-will-be-rewritten",
    router: ({ signedRequest }) => {
      if (!signedRequest) throw new Error("Must sign request before proxying");
      return `${signedRequest.protocol}//${signedRequest.hostname}`;
    },
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
      proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
      error: handleProxyError,
    },
  }),
});

function maybeReassignModel(req: Request) {
  const model = req.body.model;

  // If it looks like an AWS model, use it as-is
  if (model.startsWith("mistral.")) {
    return;
  }
  // Mistral 7B Instruct
  else if (model.includes("7b")) {
    req.body.model = "mistral.mistral-7b-instruct-v0:2";
  }
  // Mistral 8x7B Instruct
  else if (model.includes("8x7b")) {
    req.body.model = "mistral.mixtral-8x7b-instruct-v0:1";
  }
  // Mistral Large (Feb 2024)
  else if (model.includes("large-2402")) {
    req.body.model = "mistral.mistral-large-2402-v1:0";
  }
  // Mistral Large 2 (July 2024)
  else if (model.includes("large")) {
    req.body.model = "mistral.mistral-large-2407-v1:0";
  }
  // Mistral Small (Feb 2024)
  else if (model.includes("small")) {
    req.body.model = "mistral.mistral-small-2402-v1:0";
  } else {
    throw new Error(
      `Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`
    );
  }
}

const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
  { inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" },
  {
    beforeTransform: [detectMistralInputApi],
    afterTransform: [maybeReassignModel],
  }
);

const awsMistralRouter = Router();
awsMistralRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  nativeMistralChatPreprocessor,
  awsMistralProxy
);

export const awsMistral = awsMistralRouter;
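For reference, the substring-based mapping above behaves like this (self-contained sketch; note the branch ordering):

```ts
// Distilled copy of the Mistral name mapping above, for illustration only.
function mapToBedrockMistralId(model: string): string {
  if (model.startsWith("mistral.")) return model;
  if (model.includes("7b")) return "mistral.mistral-7b-instruct-v0:2";
  if (model.includes("8x7b")) return "mistral.mixtral-8x7b-instruct-v0:1";
  if (model.includes("large-2402")) return "mistral.mistral-large-2402-v1:0";
  if (model.includes("large")) return "mistral.mistral-large-2407-v1:0";
  if (model.includes("small")) return "mistral.mistral-small-2402-v1:0";
  throw new Error(`Can't map '${model}' to a supported AWS model ID`);
}

console.log(mapToBedrockMistralId("mistral-large-latest"));
// -> "mistral.mistral-large-2407-v1:0"
console.log(mapToBedrockMistralId("open-mixtral-8x7b"));
// -> "mistral.mistral-7b-instruct-v0:2" (the string "8x7b" also contains
//    "7b", so the 7B branch matches first and the 8x7B branch never runs)
```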
+52 -314

@@ -1,337 +1,75 @@
-import { Request, RequestHandler, Response, Router } from "express";
-import { createProxyMiddleware } from "http-proxy-middleware";
-import { v4 } from "uuid";
+/* Shared code between AWS Claude and AWS Mistral endpoints. */
+
+import { Request, Response, Router } from "express";
 import { config } from "../config";
-import { logger } from "../logger";
-import { createQueueMiddleware } from "./queue";
-import { ipLimiter } from "./rate-limit";
-import { handleProxyError } from "./middleware/common";
-import {
-  createPreprocessorMiddleware,
-  signAwsRequest,
-  finalizeSignedRequest,
-  createOnProxyReqHandler,
-} from "./middleware/request";
-import {
-  ProxyResHandlerWithBody,
-  createOnProxyResHandler,
-} from "./middleware/response";
-import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic";
-import { sendErrorToClient } from "./middleware/response/error-generator";
+import { addV1 } from "./add-v1";
+import { awsClaude } from "./aws-claude";
+import { awsMistral } from "./aws-mistral";
+import { AwsBedrockKey, keyPool } from "../shared/key-management";

-const LATEST_AWS_V2_MINOR_VERSION = "1";
-
-let modelsCache: any = null;
-let modelsCacheTime = 0;
-
-const getModelsResponse = () => {
-  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
-    return modelsCache;
-  }
+const awsRouter = Router();
+awsRouter.get(["/:vendor?/v1/models", "/:vendor?/models"], handleModelsRequest);
+awsRouter.use("/claude", addV1, awsClaude);
+awsRouter.use("/mistral", addV1, awsMistral);
+
+const MODELS_CACHE_TTL = 10000;
+let modelsCache: Record<string, any> = {};
+let modelsCacheTime: Record<string, number> = {};
+function handleModelsRequest(req: Request, res: Response) {
   if (!config.awsCredentials) return { object: "list", data: [] };

+  const vendor = req.params.vendor?.length
+    ? req.params.vendor === "claude"
+      ? "anthropic"
+      : req.params.vendor
+    : "all";
+
+  const cacheTime = modelsCacheTime[vendor] || 0;
+  if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
+    return res.json(modelsCache[vendor]);
+  }
+
+  const availableModelIds = new Set<string>();
+  for (const key of keyPool.list()) {
+    if (key.isDisabled || key.service !== "aws") continue;
+    (key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id));
+  }
+
   // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
-  const variants = [
+  const models = [
     "anthropic.claude-v2",
     "anthropic.claude-v2:1",
     "anthropic.claude-3-haiku-20240307-v1:0",
     "anthropic.claude-3-sonnet-20240229-v1:0",
     "anthropic.claude-3-5-sonnet-20240620-v1:0",
     "anthropic.claude-3-opus-20240229-v1:0",
-  ];
-
-  const models = variants.map((id) => ({
+    "mistral.mistral-7b-instruct-v0:2",
+    "mistral.mixtral-8x7b-instruct-v0:1",
+    "mistral.mistral-large-2402-v1:0",
+    "mistral.mistral-large-2407-v1:0",
+    "mistral.mistral-small-2402-v1:0",
+  ]
+    .filter((id) => availableModelIds.has(id))
+    .map((id) => {
+      const vendor = id.match(/^(.*)\./)?.[1];
+      return {
         id,
         object: "model",
         created: new Date().getTime(),
-        owned_by: "anthropic",
+        owned_by: vendor,
         permission: [],
-        root: "claude",
+        root: vendor,
         parent: null,
-  }));
-
-  modelsCache = { object: "list", data: models };
-  modelsCacheTime = new Date().getTime();
-
-  return modelsCache;
-};
-
-const handleModelRequest: RequestHandler = (_req, res) => {
-  res.status(200).json(getModelsResponse());
-};
-
-/** Only used for non-streaming requests. */
-const awsResponseHandler: ProxyResHandlerWithBody = async (
-  _proxyRes,
-  req,
-  res,
-  body
-) => {
-  if (typeof body !== "object") {
-    throw new Error("Expected body to be an object");
-  }
-
-  let newBody = body;
-  switch (`${req.inboundApi}<-${req.outboundApi}`) {
-    case "openai<-anthropic-text":
-      req.log.info("Transforming Anthropic Text back to OpenAI format");
-      newBody = transformAwsTextResponseToOpenAI(body, req);
-      break;
-    case "openai<-anthropic-chat":
-      req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
-      newBody = transformAnthropicChatResponseToOpenAI(body);
-      break;
-    case "anthropic-text<-anthropic-chat":
-      req.log.info("Transforming AWS Anthropic Chat back to Text format");
-      newBody = transformAnthropicChatResponseToAnthropicText(body);
-      break;
-  }
-
-  // AWS does not always confirm the model in the response, so we have to add it
-  if (!newBody.model && req.body.model) {
-    newBody.model = req.body.model;
-  }
-
-  res.status(200).json({ ...newBody, proxy: body.proxy });
-};
-
-/**
- * Transforms a model response from the Anthropic API to match those from the
- * OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
- * is only used for non-streaming requests as streaming requests are handled
- * on-the-fly.
- */
-function transformAwsTextResponseToOpenAI(
-  awsBody: Record<string, any>,
-  req: Request
-): Record<string, any> {
-  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
-  return {
-    id: "aws-" + v4(),
-    object: "chat.completion",
-    created: Date.now(),
-    model: req.body.model,
-    usage: {
-      prompt_tokens: req.promptTokens,
-      completion_tokens: req.outputTokens,
-      total_tokens: totalTokens,
-    },
-    choices: [
-      {
-        message: {
-          role: "assistant",
-          content: awsBody.completion?.trim(),
-        },
-        finish_reason: awsBody.stop_reason,
-        index: 0,
-      },
-    ],
-  };
-}
-
-const awsProxy = createQueueMiddleware({
-  beforeProxy: signAwsRequest,
-  proxyMiddleware: createProxyMiddleware({
-    target: "bad-target-will-be-rewritten",
-    router: ({ signedRequest }) => {
-      if (!signedRequest) throw new Error("Must sign request before proxying");
-      return `${signedRequest.protocol}//${signedRequest.hostname}`;
-    },
-    changeOrigin: true,
-    selfHandleResponse: true,
-    logger,
-    on: {
-      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
-      proxyRes: createOnProxyResHandler([awsResponseHandler]),
-      error: handleProxyError,
-    },
-  }),
-});
-
-const nativeTextPreprocessor = createPreprocessorMiddleware(
-  { inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
-  { afterTransform: [maybeReassignModel] }
-);
-
-const textToChatPreprocessor = createPreprocessorMiddleware(
-  { inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
-  { afterTransform: [maybeReassignModel] }
-);
-
-/**
- * Routes text completion prompts to aws anthropic-chat if they need translation
- * (claude-3 based models do not support the old text completion endpoint).
- */
-const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
-  if (req.body.model?.includes("claude-3")) {
-    textToChatPreprocessor(req, res, next);
-  } else {
-    nativeTextPreprocessor(req, res, next);
-  }
-};
-
-const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
-  { inApi: "openai", outApi: "anthropic-text", service: "aws" },
-  { afterTransform: [maybeReassignModel] }
-);
-
-const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
-  { inApi: "openai", outApi: "anthropic-chat", service: "aws" },
-  { afterTransform: [maybeReassignModel] }
-);
-
-/**
- * Routes an OpenAI prompt to either the legacy Claude text completion endpoint
- * or the new Claude chat completion endpoint, based on the requested model.
- */
-const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
-  if (req.body.model?.includes("claude-3")) {
-    oaiToAwsChatPreprocessor(req, res, next);
-  } else {
-    oaiToAwsTextPreprocessor(req, res, next);
-  }
-};
-
-const awsRouter = Router();
-awsRouter.get("/v1/models", handleModelRequest);
-// Native(ish) Anthropic text completion endpoint.
-awsRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy);
-// Native Anthropic chat completion endpoint.
-awsRouter.post(
-  "/v1/messages",
-  ipLimiter,
-  createPreprocessorMiddleware(
-    { inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
-    { afterTransform: [maybeReassignModel] }
-  ),
-  awsProxy
-);
-// Temporary force-Claude3 endpoint
-awsRouter.post(
-  "/v1/sonnet/:action(complete|messages)",
-  ipLimiter,
-  handleCompatibilityRequest,
-  createPreprocessorMiddleware({
-    inApi: "anthropic-text",
-    outApi: "anthropic-chat",
-    service: "aws",
-  }),
-  awsProxy
-);
-
-// OpenAI-to-AWS Anthropic compatibility endpoint.
-awsRouter.post(
-  "/v1/chat/completions",
-  ipLimiter,
-  preprocessOpenAICompatRequest,
-  awsProxy
-);
-
-/**
- * Tries to deal with:
- * - frontends sending AWS model names even when they want to use the OpenAI-
- *   compatible endpoint
- * - frontends sending Anthropic model names that AWS doesn't recognize
- * - frontends sending OpenAI model names because they expect the proxy to
- *   translate them
- *
- * If client sends AWS model ID it will be used verbatim. Otherwise, various
- * strategies are used to try to map a non-AWS model name to AWS model ID.
- */
-function maybeReassignModel(req: Request) {
-  const model = req.body.model;
-
-  // If it looks like an AWS model, use it as-is
-  if (model.includes("anthropic.claude")) {
-    return;
-  }
-
-  // Anthropic model names can look like:
-  // - claude-v1
-  // - claude-2.1
-  // - claude-3-5-sonnet-20240620-v1:0
-  const pattern =
-    /^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
-  const match = model.match(pattern);
-
-  // If there's no match, fallback to Claude v2 as it is most likely to be
-  // available on AWS.
-  if (!match) {
-    req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
-    return;
-  }
-
-  const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
-
-  if (instant) {
-    req.body.model = "anthropic.claude-instant-v1";
-    return;
-  }
-
-  const ver = minor ? `${major}.${minor}` : major;
-  switch (ver) {
-    case "1":
-    case "1.0":
-      req.body.model = "anthropic.claude-v1";
-      return;
-    case "2":
-    case "2.0":
-      req.body.model = "anthropic.claude-v2";
-      return;
-    case "3":
-    case "3.0":
-      if (name.includes("opus")) {
-        req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
-      } else if (name.includes("haiku")) {
-        req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
-      } else {
-        req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
-      }
-      return;
-    case "3.5":
-      req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
-      return;
-  }
-
-  // Fallback to Claude 2.1
-  req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
-  return;
-}
-
-export function handleCompatibilityRequest(
-  req: Request,
-  res: Response,
-  next: any
-) {
-  const action = req.params.action;
-  const alreadyInChatFormat = Boolean(req.body.messages);
-  const compatModel = "anthropic.claude-3-5-sonnet-20240620-v1:0";
-  req.log.info(
-    { inputModel: req.body.model, compatModel, alreadyInChatFormat },
-    "Handling AWS compatibility request"
-  );
-
-  if (action === "messages" || alreadyInChatFormat) {
-    return sendErrorToClient({
-      req,
-      res,
-      options: {
-        title: "Unnecessary usage of compatibility endpoint",
-        message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/aws/claude\` proxy endpoint instead.`,
-        format: "unknown",
-        statusCode: 400,
-        reqId: req.id,
-        obj: {
-          requested_endpoint: "/aws/claude/sonnet",
-          correct_endpoint: "/aws/claude",
-        },
-      },
-    });
-  }
-
-  req.body.model = compatModel;
-  next();
-}
+      };
+    });
+
+  modelsCache[vendor] = {
+    object: "list",
+    data: models.filter((m) => vendor === "all" || m.root === vendor),
+  };
+  modelsCacheTime[vendor] = new Date().getTime();
+
+  return res.json(modelsCache[vendor]);
+}

 export const aws = awsRouter;
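Sketch of how a client might exercise the new vendor-scoped model listing; the base URL is a placeholder for wherever the proxy's /aws router is actually mounted:

```ts
// Placeholder base URL; substitute the deployment's real /aws mount point.
const BASE = "http://localhost:7860/proxy/aws";

async function listAwsModels() {
  // The "/:vendor?/v1/models" route means all three paths hit
  // handleModelsRequest; "claude" is normalized to the "anthropic" vendor.
  for (const path of ["/v1/models", "/claude/v1/models", "/mistral/v1/models"]) {
    const res = await fetch(`${BASE}${path}`);
    const body = (await res.json()) as { data: { id: string }[] };
    console.log(path, body.data.map((m) => m.id));
  }
}

listAwsModels().catch(console.error);
```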
@@ -221,9 +221,12 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
   switch (format) {
     case "openai":
     case "mistral-ai":
-      // Can be null if the model wants to invoke tools rather than return a
-      // completion.
-      return body.choices[0].message.content || "";
+      // Few possible values:
+      // - choices[0].message.content
+      // - choices[0].message with no content if model is invoking a tool
+      return body.choices?.[0]?.message?.content || "";
+    case "mistral-text":
+      return body.outputs?.[0]?.text || "";
     case "openai-text":
       return body.choices[0].text;
     case "anthropic-chat":
@@ -260,22 +263,22 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
   }
 }

-export function getModelFromBody(req: Request, body: Record<string, any>) {
+export function getModelFromBody(req: Request, resBody: Record<string, any>) {
   const format = req.outboundApi;
   switch (format) {
     case "openai":
     case "openai-text":
+      return resBody.model;
     case "mistral-ai":
-      return body.model;
+    case "mistral-text":
+    case "openai-image":
+    case "google-ai":
+      // These formats don't have a model in the response body.
+      return req.body.model;
     case "anthropic-chat":
     case "anthropic-text":
       // Anthropic confirms the model in the response, but AWS Claude doesn't.
-      return body.model || req.body.model;
-    case "google-ai":
-      // Google doesn't confirm the model in the response.
-      return req.body.model;
+      return resBody.model || req.body.model;
     default:
       assertNever(format);
   }
@@ -38,7 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
     // translation now reassigns the model earlier in the request pipeline.
     case "anthropic-text":
     case "anthropic-chat":
-      assignedKey = keyPool.get("claude-v1", service, needsMultimodal);
+    case "mistral-ai":
+    case "mistral-text":
+    case "google-ai":
+      assignedKey = keyPool.get(body.model, service);
       break;
     case "openai-text":
       assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
@@ -47,10 +50,8 @@
       assignedKey = keyPool.get("dall-e-3", service);
       break;
     case "openai":
-    case "google-ai":
-    case "mistral-ai":
       throw new Error(
-        `add-key should not be called for outbound API ${outboundApi}`
+        `Outbound API ${outboundApi} is not supported for ${inboundApi}`
       );
     default:
       assertNever(outboundApi);
@@ -86,7 +86,7 @@ async function executePreprocessors(
       const msg = error?.issues
         ?.map((issue: ZodIssue) => issue.message)
         .join("; ");
-      req.log.info(msg, "Prompt validation failed.");
+      req.log.warn({ issues: msg }, "Prompt validation failed.");
     } else {
       req.log.error(error, "Error while executing request preprocessor");
     }
@@ -2,7 +2,6 @@ import { RequestPreprocessor } from "../index";
 import { countTokens } from "../../../../shared/tokenization";
 import { assertNever } from "../../../../shared/utils";
 import {
-  AnthropicChatMessage,
   GoogleAIChatMessage,
   MistralAIChatMessage,
   OpenAIChatMessage,
@@ -50,9 +49,11 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
       result = await countTokens({ req, prompt, service });
       break;
     }
-    case "mistral-ai": {
+    case "mistral-ai":
+    case "mistral-text": {
       req.outputTokens = req.body.max_tokens;
-      const prompt: MistralAIChatMessage[] = req.body.messages;
+      const prompt: string | MistralAIChatMessage[] =
+        req.body.messages ?? req.body.prompt;
       result = await countTokens({ req, prompt, service });
       break;
     }
@@ -56,8 +56,6 @@ function getPromptFromRequest(req: Request) {
   switch (service) {
     case "anthropic-chat":
       return flattenAnthropicMessages(body.messages);
-    case "anthropic-text":
-      return body.prompt;
     case "openai":
     case "mistral-ai":
       return body.messages
@@ -72,8 +70,10 @@
         return `${msg.role}: ${text}`;
       })
       .join("\n\n");
+    case "anthropic-text":
     case "openai-text":
     case "openai-image":
+    case "mistral-text":
       return body.prompt;
     case "google-ai":
       return body.prompt.text;
@@ -1,4 +1,4 @@
-import express from "express";
+import express, { Request } from "express";
 import { Sha256 } from "@aws-crypto/sha256-js";
 import { SignatureV4 } from "@smithy/signature-v4";
 import { HttpRequest } from "@smithy/protocol-http";
@@ -8,6 +8,10 @@ import {
 } from "../../../../shared/api-schemas";
 import { keyPool } from "../../../../shared/key-management";
 import { RequestPreprocessor } from "../index";
+import {
+  AWSMistralV1ChatCompletionsSchema,
+  AWSMistralV1TextCompletionsSchema,
+} from "../../../../shared/api-schemas/mistral-ai";

 const AMZ_HOST =
   process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
@@ -29,38 +33,6 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
     req.body.prompt = preamble + req.body.prompt;
   }

-  // AWS uses mostly the same parameters as Anthropic, with a few removed params
-  // and much stricter validation on unused parameters. Rather than treating it
-  // as a separate schema we will use the anthropic ones and strip the unused
-  // parameters.
-  // TODO: This should happen in transform-outbound-payload.ts
-  let strippedParams: Record<string, unknown>;
-  if (req.outboundApi === "anthropic-chat") {
-    strippedParams = AnthropicV1MessagesSchema.pick({
-      messages: true,
-      system: true,
-      max_tokens: true,
-      stop_sequences: true,
-      temperature: true,
-      top_k: true,
-      top_p: true,
-    })
-      .strip()
-      .parse(req.body);
-    strippedParams.anthropic_version = "bedrock-2023-05-31";
-  } else {
-    strippedParams = AnthropicV1TextSchema.pick({
-      prompt: true,
-      max_tokens_to_sample: true,
-      stop_sequences: true,
-      temperature: true,
-      top_k: true,
-      top_p: true,
-    })
-      .strip()
-      .parse(req.body);
-  }
-
   const credential = getCredentialParts(req);
   const host = AMZ_HOST.replace("%REGION%", credential.region);
   // AWS only uses 2023-06-01 and does not actually check this header, but we
@@ -78,7 +50,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
       ["Host"]: host,
       ["content-type"]: "application/json",
     },
-    body: JSON.stringify(strippedParams),
+    body: JSON.stringify(applyAwsStrictValidation(req)),
   });

   if (stream) {
@@ -128,3 +100,48 @@ async function sign(request: HttpRequest, credential: Credential) {

   return signer.sign(request);
 }
+
+function applyAwsStrictValidation(req: Request): unknown {
+  // AWS uses vendor API formats but imposes additional (more strict) validation
+  // rules, namely that extraneous parameters are not allowed. We will validate
+  // using the vendor's zod schema but apply `.strip` to ensure that any
+  // extraneous parameters are removed.
+  let strippedParams: Record<string, unknown> = {};
+  switch (req.outboundApi) {
+    case "anthropic-text":
+      strippedParams = AnthropicV1TextSchema.pick({
+        prompt: true,
+        max_tokens_to_sample: true,
+        stop_sequences: true,
+        temperature: true,
+        top_k: true,
+        top_p: true,
+      })
+        .strip()
+        .parse(req.body);
+      break;
+    case "anthropic-chat":
+      strippedParams = AnthropicV1MessagesSchema.pick({
+        messages: true,
+        system: true,
+        max_tokens: true,
+        stop_sequences: true,
+        temperature: true,
+        top_k: true,
+        top_p: true,
+      })
+        .strip()
+        .parse(req.body);
+      strippedParams.anthropic_version = "bedrock-2023-05-31";
+      break;
+    case "mistral-ai":
+      strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
+      break;
+    case "mistral-text":
+      strippedParams = AWSMistralV1TextCompletionsSchema.parse(req.body);
+      break;
+    default:
+      throw new Error("Unexpected outbound API for AWS.");
+  }
+  return strippedParams;
+}
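The `.pick(...).strip().parse(...)` pattern above is what enforces AWS's stricter validation. A standalone sketch of the same idea, using an illustrative schema rather than the proxy's real one:

```ts
import { z } from "zod";

// Illustrative stand-in for AnthropicV1TextSchema; the real schema lives in
// shared/api-schemas.
const TextSchema = z.object({
  prompt: z.string(),
  max_tokens_to_sample: z.number(),
  temperature: z.number().optional(),
  metadata: z.any().optional(),
});

// Keep only the fields Bedrock accepts and drop everything else.
const strippedParams = TextSchema.pick({
  prompt: true,
  max_tokens_to_sample: true,
  temperature: true,
})
  .strip() // unknown keys are silently removed rather than rejected
  .parse({
    prompt: "\n\nHuman: Hello\n\nAssistant:",
    max_tokens_to_sample: 256,
    temperature: 0.7,
    metadata: { user_id: "abc" }, // extraneous for Bedrock; stripped out
  });

console.log(strippedParams);
// -> { prompt: "...", max_tokens_to_sample: 256, temperature: 0.7 }
```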
transform-outbound-payload.ts
@@ -1,3 +1,4 @@
+import { Request } from "express";
 import {
   API_REQUEST_VALIDATORS,
   API_REQUEST_TRANSFORMERS,
@@ -12,29 +13,23 @@ import { RequestPreprocessor } from "../index";

 /** Transforms an incoming request body to one that matches the target API. */
 export const transformOutboundPayload: RequestPreprocessor = async (req) => {
-  const sameService = req.inboundApi === req.outboundApi;
   const alreadyTransformed = req.retryCount > 0;
   const notTransformable =
     !isTextGenerationRequest(req) && !isImageGenerationRequest(req);

   if (alreadyTransformed || notTransformable) return;

-  // TODO: this should be an APIFormatTransformer
-  if (req.inboundApi === "mistral-ai") {
-    const messages = req.body.messages;
-    req.body.messages = fixMistralPrompt(messages);
-    req.log.info(
-      { old: messages.length, new: req.body.messages.length },
-      "Fixed Mistral prompt"
-    );
-  }
+  applyMistralPromptFixes(req);

-  if (sameService) {
+  // Native prompts are those which were already provided by the client in the
+  // target API format. We don't need to transform them.
+  const isNativePrompt = req.inboundApi === req.outboundApi;
+  if (isNativePrompt) {
     const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body);
     if (!result.success) {
       req.log.warn(
         { issues: result.error.issues, body: req.body },
-        "Request validation failed"
+        "Native prompt request validation failed."
       );
       throw result.error;
     }
@@ -42,11 +37,12 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
     return;
   }

+  // Prompt requires translation from one API format to another.
   const transformation = `${req.inboundApi}->${req.outboundApi}` as const;
   const transFn = API_REQUEST_TRANSFORMERS[transformation];

   if (transFn) {
-    req.log.info({ transformation }, "Transforming request");
+    req.log.info({ transformation }, "Transforming request...");
     req.body = await transFn(req);
     return;
   }
@@ -55,3 +51,34 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
     `${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.`
   );
 };
+
+// handles weird cases that don't fit into our abstractions
+function applyMistralPromptFixes(req: Request): void {
+  if (req.inboundApi === "mistral-ai") {
+    // Mistral Chat is very similar to OpenAI but not identical and many clients
+    // don't properly handle the differences. We will try to validate the
+    // mistral prompt and try to fix it if it fails. It will be re-validated
+    // after this function returns.
+    const result = API_REQUEST_VALIDATORS["mistral-ai"].parse(req.body);
+    req.body.messages = fixMistralPrompt(result.messages);
+    req.log.info(
+      { n: req.body.messages.length, prev: result.messages.length },
+      "Applied Mistral chat prompt fixes."
+    );
+
+    // If the prompt relies on `prefix: true` for the last message, we need to
+    // convert it to a text completions request because Mistral support for
+    // this feature is limited (and completely broken on AWS Mistral).
+    const { messages } = req.body;
+    const lastMessage = messages && messages[messages.length - 1];
+    if (lastMessage && lastMessage.role === "assistant") {
+      // enable prefix if client forgot, otherwise the template will insert an
+      // eos token which is very unlikely to be what the client wants.
+      lastMessage.prefix = true;
+      req.outboundApi = "mistral-text";
+      req.log.info(
+        "Native Mistral chat prompt relies on assistant message prefix. Converting to text completions request."
+      );
+    }
+  }
+}
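To make the `prefix: true` rerouting above concrete, a minimal sketch of the decision (types simplified; the real code operates on an Express request):

```ts
// Simplified sketch of the reroute decision in applyMistralPromptFixes above.
type MistralMessage = {
  role: "user" | "assistant" | "system";
  content: string;
  prefix?: boolean;
};

function shouldConvertToTextCompletion(messages: MistralMessage[]): boolean {
  const last = messages[messages.length - 1];
  // A trailing assistant message means the client wants the model to continue
  // that message (a prefill), which AWS Mistral's chat template handles badly.
  return !!last && last.role === "assistant";
}

const prompt: MistralMessage[] = [
  { role: "user", content: "Write a haiku about autumn." },
  { role: "assistant", content: "Golden leaves drift down" }, // prefill
];
console.log(shouldConvertToTextCompletion(prompt));
// -> true, so outboundApi becomes "mistral-text"
```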
@@ -38,6 +38,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
       proxyMax = GOOGLE_AI_MAX_CONTEXT;
       break;
     case "mistral-ai":
+    case "mistral-text":
       proxyMax = MISTRAL_AI_MAX_CONTENT;
       break;
     case "openai-image":
@@ -28,6 +28,7 @@ export const validateVision: RequestPreprocessor = async (req) => {
     case "anthropic-text":
     case "google-ai":
     case "mistral-ai":
+    case "mistral-text":
     case "openai-image":
     case "openai-text":
       return;
@@ -189,6 +189,11 @@ export function buildSpoofedCompletion({
         },
       ],
     };
+    case "mistral-text":
+      return {
+        outputs: [{ text: content, stop_reason: title }],
+        model,
+      };
     case "openai-text":
       return {
         id: "error-" + id,
@@ -267,6 +272,11 @@ export function buildSpoofedSSE({
         choices: [{ delta: { content }, index: 0, finish_reason: title }],
       };
       break;
+    case "mistral-text":
+      event = {
+        outputs: [{ text: content, stop_reason: title }],
+      };
+      break;
     case "openai-text":
       event = {
         id: "cmpl-" + id,
@@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
 const pipelineAsync = promisify(pipeline);

 /**
- * `handleStreamedResponse` consumes and transforms a streamed response from the
- * upstream service, forwarding events to the client in their requested format.
+ * `handleStreamedResponse` consumes a streamed response from the upstream API,
+ * decodes chunk-by-chunk into a stream of events, transforms those events into
+ * the client's requested format, and forwards the result to the client.
  *
  * After the entire stream has been consumed, it resolves with the full response
  * body so that subsequent middleware in the chain can process it as if it were
- * a non-streaming response.
+ * a non-streaming response (to count output tokens, track usage, etc).
  *
- * In the event of an error, the request's streaming flag is unset and the non-
- * streaming response handler is called instead.
- *
- * If the error is retryable, that handler will re-enqueue the request and also
- * reset the streaming flag. Unfortunately the streaming flag is set and unset
- * in multiple places, so it's hard to keep track of.
+ * In the event of an error, the request's streaming flag is unset and the
+ * request is bounced back to the non-streaming response handler. If the error
+ * is retryable, that handler will re-enqueue the request and also reset the
+ * streaming flag. Unfortunately the streaming flag is set and unset in multiple
+ * places, so it's hard to keep track of.
 */
 export const handleStreamedResponse: RawResponseBodyHandler = async (
   proxyRes,
@@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
     logger: req.log,
   };

-  // Decoder turns the raw response stream into a stream of events in some
-  // format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
-  // While the request is streaming, aggregator collects all events so that we
-  // can compile them into a single response object and publish that to the
-  // remaining middleware. Because we have an OpenAI transformer for every
-  // supported format, EventAggregator always consumes OpenAI events so that we
-  // only have to write one aggregator (OpenAI input) for each output format.
-  const aggregator = new EventAggregator(req);
-
+  // Decoder reads from the raw response buffer and produces a stream of
+  // discrete events in some format (text/event-stream, vnd.amazon.event-stream,
+  // streaming JSON, etc).
   const decoder = getDecoder({ ...streamOptions, input: proxyRes });
-  // Adapter transforms the decoded events into server-sent events.
+  // Adapter consumes the decoded events and produces server-sent events so we
+  // have a standard event format for the client and to translate between API
+  // message formats.
   const adapter = new SSEStreamAdapter(streamOptions);
+  // Aggregator compiles all events into a single response object.
+  const aggregator = new EventAggregator({ format: req.outboundApi });
+  // Transformer converts server-sent events from one vendor's API message
+  // format to another.
   const transformer = new SSEMessageTransformer({
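The stage ordering described in those comments (decode, then adapt, then transform, then aggregate) is just a chain of Node transform streams. A toy sketch of the shape, with pass-through stand-ins for the real stages:

```ts
import { pipeline } from "stream/promises";
import { PassThrough, Transform } from "stream";

// Stand-ins for getDecoder / SSEStreamAdapter / SSEMessageTransformer; each
// real stage rewrites events, these just tag the chunk so the flow is visible.
const stage = (name: string) =>
  new Transform({
    transform(chunk, _enc, cb) {
      cb(null, `[${name}] ${chunk.toString()}`);
    },
  });

async function demo() {
  const upstream = new PassThrough(); // raw bytes from the proxied response
  const client = new PassThrough();   // what ultimately reaches the client
  client.on("data", (d) => process.stdout.write(d.toString() + "\n"));

  const run = pipeline(
    upstream,
    stage("decode"),
    stage("adapt"),
    stage("transform"),
    client
  );
  upstream.end("data: {...}");
  await run;
}

demo().catch(console.error);
```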
@@ -11,7 +11,8 @@ import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
 import {
   AnthropicChatMessage,
-  flattenAnthropicMessages, GoogleAIChatMessage,
+  flattenAnthropicMessages,
+  GoogleAIChatMessage,
   MistralAIChatMessage,
   OpenAIChatMessage,
 } from "../../../shared/api-schemas";
@@ -76,6 +77,8 @@ const getPromptForRequest = (
     case "anthropic-chat":
       return { system: req.body.system, messages: req.body.messages };
     case "openai-text":
+    case "anthropic-text":
+    case "mistral-text":
       return req.body.prompt;
     case "openai-image":
       return {
@@ -85,8 +88,6 @@
         quality: req.body.quality,
         revisedPrompt: responseBody.data[0].revised_prompt,
       };
-    case "anthropic-text":
-      return req.body.prompt;
     case "google-ai":
       return { contents: req.body.contents };
     default:
@@ -113,9 +114,7 @@ const flattenMessages = (
   if (isGoogleAIChatPrompt(val)) {
     return val.contents
       .map(({ parts, role }) => {
-        const text = parts
-          .map((p) => p.text)
-          .join("\n");
+        const text = parts.map((p) => p.text).join("\n");
         return `${role}: ${text}`;
       })
       .join("\n");
@@ -143,11 +142,7 @@
 function isGoogleAIChatPrompt(
   val: unknown
 ): val is { contents: GoogleAIChatMessage[] } {
-  return (
-    typeof val === "object" &&
-    val !== null &&
-    "contents" in val
-  );
+  return typeof val === "object" && val !== null && "contents" in val;
 }

 function isAnthropicChatPrompt(
@@ -0,0 +1,39 @@
import { OpenAIChatCompletionStreamEvent } from "../index";

export type MistralChatCompletionResponse = {
  choices: {
    index: number;
    message: { role: string; content: string };
    finish_reason: string | null;
  }[];
};

/**
 * Given a list of OpenAI chat completion events, compiles them into a single
 * finalized Mistral chat completion response so that non-streaming middleware
 * can operate on it as if it were a blocking response.
 */
export function mergeEventsForMistralChat(
  events: OpenAIChatCompletionStreamEvent[]
): MistralChatCompletionResponse {
  let merged: MistralChatCompletionResponse = {
    choices: [
      { index: 0, message: { role: "", content: "" }, finish_reason: "" },
    ],
  };
  merged = events.reduce((acc, event, i) => {
    // The first event will only contain role assignment and response metadata
    if (i === 0) {
      acc.choices[0].message.role = event.choices[0].delta.role ?? "assistant";
      return acc;
    }

    acc.choices[0].finish_reason = event.choices[0].finish_reason ?? "";
    if (event.choices[0].delta.content) {
      acc.choices[0].message.content += event.choices[0].delta.content;
    }

    return acc;
  }, merged);
  return merged;
}
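
A minimal sketch of how this aggregator collapses a stream, assuming the import path from this diff; the two chunk literals are hypothetical examples of the OpenAI-format events the proxy emits internally:

import { mergeEventsForMistralChat } from "./aggregators/mistral-chat";

const chunks: any[] = [
  // First event carries only the role assignment.
  { choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }] },
  // Subsequent events carry content deltas; the last one sets finish_reason.
  { choices: [{ index: 0, delta: { content: "Hello" }, finish_reason: null }] },
  { choices: [{ index: 0, delta: { content: "!" }, finish_reason: "stop" }] },
];
mergeEventsForMistralChat(chunks);
// => { choices: [{ index: 0, message: { role: "assistant", content: "Hello!" }, finish_reason: "stop" }] }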
@@ -0,0 +1,33 @@
import { OpenAIChatCompletionStreamEvent } from "../index";

export type MistralTextCompletionResponse = {
  outputs: {
    text: string;
    stop_reason: string | null;
  }[];
};

/**
 * Given a list of OpenAI chat completion events, compiles them into a single
 * finalized Mistral text completion response so that non-streaming middleware
 * can operate on it as if it were a blocking response.
 */
export function mergeEventsForMistralText(
  events: OpenAIChatCompletionStreamEvent[]
): MistralTextCompletionResponse {
  let merged: MistralTextCompletionResponse = {
    outputs: [{ text: "", stop_reason: "" }],
  };
  merged = events.reduce((acc, event, i) => {
    // The first event will only contain role assignment and response metadata
    if (i === 0) {
      return acc;
    }

    acc.outputs[0].text += event.choices[0].delta.content ?? "";
    acc.outputs[0].stop_reason = event.choices[0].finish_reason ?? "";

    return acc;
  }, merged);
  return merged;
}

@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
  if (eventType === "chunk") {
    result = input[eventType];
  } else {
    // AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
    // AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
    result = { [eventType]: input[eventType] } as any;
  }
  return result;

@@ -1,3 +1,4 @@
import express from "express";
import { APIFormat } from "../../../../shared/key-management";
import { assertNever } from "../../../../shared/utils";
import {
@@ -6,8 +7,13 @@ import {
  mergeEventsForAnthropicText,
  mergeEventsForOpenAIChat,
  mergeEventsForOpenAIText,
  mergeEventsForMistralChat,
  mergeEventsForMistralText,
  AnthropicV2StreamEvent,
  OpenAIChatCompletionStreamEvent,
  mistralAIToOpenAI,
  MistralAIStreamEvent,
  MistralChatCompletionEvent,
} from "./index";

/**
@@ -15,45 +21,70 @@ import {
 * compiles them into a single finalized response for downstream middleware.
 */
export class EventAggregator {
  private readonly format: APIFormat;
  private readonly model: string;
  private readonly requestFormat: APIFormat;
  private readonly responseFormat: APIFormat;
  private readonly events: OpenAIChatCompletionStreamEvent[];

  constructor({ format }: { format: APIFormat }) {
  constructor({ body, inboundApi, outboundApi }: express.Request) {
    this.events = [];
    this.format = format;
    this.requestFormat = inboundApi;
    this.responseFormat = outboundApi;
    this.model = body.model;
  }

  addEvent(event: OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent) {
  addEvent(
    event:
      | OpenAIChatCompletionStreamEvent
      | AnthropicV2StreamEvent
      | MistralAIStreamEvent
  ) {
    if (eventIsOpenAIEvent(event)) {
      this.events.push(event);
    } else {
      // horrible special case. previously all transformers' target format was
      // openai, so the event aggregator could conveniently assume all incoming
      // events were in openai format.
      // now we have added anthropic-chat-to-text, so aggregator needs to know
      // how to collapse events from two formats.
      // because that is annoying, we will simply transform anthropic events to
      // openai (even if the client didn't ask for openai) so we don't have to
      // write aggregation logic for anthropic chat (which is also a troublesome
      // stateful format).
      const openAIEvent = anthropicV2ToOpenAI({
      // now we have added some transformers that convert between non-openai
      // formats, so aggregator needs to know how to collapse for more than
      // just openai.
      // because writing aggregation logic for every possible output format is
      // annoying, we will just transform any non-openai output events to openai
      // format (even if the client did not request openai at all) so that we
      // still only need to write aggregators for openai SSEs.
      let openAIEvent: OpenAIChatCompletionStreamEvent | undefined;
      switch (this.requestFormat) {
        case "anthropic-text":
          assertIsAnthropicV2Event(event);
          openAIEvent = anthropicV2ToOpenAI({
            data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`,
            lastPosition: -1,
            index: 0,
            fallbackId: event.log_id || "event-aggregator-fallback",
            fallbackModel: event.model || "claude-3-fallback",
          });
          if (openAIEvent.event) {
            this.events.push(openAIEvent.event);
            fallbackId: event.log_id || "fallback-" + Date.now(),
            fallbackModel: event.model || this.model || "fallback-claude-3",
          })?.event;
          break;
        case "mistral-ai":
          assertIsMistralChatEvent(event);
          openAIEvent = mistralAIToOpenAI({
            data: `data: ${JSON.stringify(event)}\n\n`,
            lastPosition: -1,
            index: 0,
            fallbackId: "fallback-" + Date.now(),
            fallbackModel: this.model || "fallback-mistral",
          })?.event;
          break;
      }
      if (openAIEvent) {
        this.events.push(openAIEvent);
      }
    }
  }

  getFinalResponse() {
    switch (this.format) {
    switch (this.responseFormat) {
      case "openai":
      case "google-ai":
      case "mistral-ai":
      case "google-ai": // TODO: this is probably wrong now that we support native Google Makersuite prompts
        return mergeEventsForOpenAIChat(this.events);
      case "openai-text":
        return mergeEventsForOpenAIText(this.events);
@@ -61,10 +92,16 @@ export class EventAggregator {
        return mergeEventsForAnthropicText(this.events);
      case "anthropic-chat":
        return mergeEventsForAnthropicChat(this.events);
      case "mistral-ai":
        return mergeEventsForMistralChat(this.events);
      case "mistral-text":
        return mergeEventsForMistralText(this.events);
      case "openai-image":
        throw new Error(`SSE aggregation not supported for ${this.format}`);
        throw new Error(
          `SSE aggregation not supported for ${this.responseFormat}`
        );
      default:
        assertNever(this.format);
        assertNever(this.responseFormat);
    }
  }

@@ -78,3 +115,17 @@ function eventIsOpenAIEvent(
): event is OpenAIChatCompletionStreamEvent {
  return event?.object === "chat.completion.chunk";
}

function assertIsAnthropicV2Event(event: any): asserts event is AnthropicV2StreamEvent {
  if (!event?.completion) {
    throw new Error(`Bad event for Anthropic V2 SSE aggregation`);
  }
}

function assertIsMistralChatEvent(
  event: any
): asserts event is MistralChatCompletionEvent {
  if (!event?.choices) {
    throw new Error(`Bad event for Mistral SSE aggregation`);
  }
}
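
A sketch of the aggregator's lifecycle under the new constructor signature, assuming an Express request shaped like the proxy's (with body.model, inboundApi, and outboundApi set):

const aggregator = new EventAggregator(req);
for (const event of decodedEvents) {
  // Non-OpenAI events (Anthropic V2, Mistral) are converted to OpenAI chunks
  // internally based on req.inboundApi before being stored.
  aggregator.addEvent(event);
}
// Merged into whichever shape req.outboundApi dictates.
const finalBody = aggregator.getFinalResponse();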
@@ -7,6 +7,25 @@ export type SSEResponseTransformArgs<S = Record<string, any>> = {
  state?: S;
};

export type MistralChatCompletionEvent = {
  choices: {
    index: number;
    message: { role: string; content: string };
    stop_reason: string | null;
  }[];
};
export type MistralTextCompletionEvent = {
  outputs: { text: string; stop_reason: string | null }[];
};
export type MistralAIStreamEvent = {
  "amazon-bedrock-invocationMetrics"?: {
    inputTokenCount: number;
    outputTokenCount: number;
    invocationLatency: number;
    firstByteLatency: number;
  };
} & (MistralChatCompletionEvent | MistralTextCompletionEvent);

export type AnthropicV2StreamEvent = {
  log_id?: string;
  model?: string;
@@ -41,8 +60,12 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2";
export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { mistralAIToOpenAI } from "./transformers/mistral-ai-to-openai";
export { mistralTextToMistralChat } from "./transformers/mistral-text-to-mistral-chat";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropicText } from "./aggregators/anthropic-text";
export { mergeEventsForAnthropicChat } from "./aggregators/anthropic-chat";
export { mergeEventsForMistralChat } from "./aggregators/mistral-chat";
export { mergeEventsForMistralText } from "./aggregators/mistral-text";

@@ -11,8 +11,11 @@ import {
  googleAIToOpenAI,
  OpenAIChatCompletionStreamEvent,
  openAITextToOpenAIChat,
  mistralAIToOpenAI,
  mistralTextToMistralChat,
  passthroughToOpenAI,
  StreamingCompletionTransformer,
  MistralChatCompletionEvent,
} from "./index";

type SSEMessageTransformerOptions = TransformOptions & {
@@ -35,7 +38,9 @@ export class SSEMessageTransformer extends Transform {
  private readonly inputFormat: APIFormat;
  private readonly transformFn: StreamingCompletionTransformer<
    // TODO: Refactor transformers to not assume only OpenAI events as output
    OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
    | OpenAIChatCompletionStreamEvent
    | AnthropicV2StreamEvent
    | MistralChatCompletionEvent
  >;
  private readonly log;
  private readonly fallbackId: string;
@@ -121,16 +126,17 @@ function eventIsOpenAIEvent(
function getTransformer(
  responseApi: APIFormat,
  version?: string,
  // There's only one case where we're not transforming back to OpenAI, which is
  // Anthropic Chat response -> Anthropic Text request. This parameter is only
  // used for that case.
  // In most cases, we are transforming back to OpenAI. Some responses can be
  // translated between two non-OpenAI formats, eg Anthropic Chat -> Anthropic
  // Text, or Mistral Text -> Mistral Chat.
  requestApi: APIFormat = "openai"
): StreamingCompletionTransformer<
  OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
  | OpenAIChatCompletionStreamEvent
  | AnthropicV2StreamEvent
  | MistralChatCompletionEvent
> {
  switch (responseApi) {
    case "openai":
    case "mistral-ai":
      return passthroughToOpenAI;
    case "openai-text":
      return openAITextToOpenAIChat;
@@ -140,10 +146,16 @@ function getTransformer(
      : anthropicV2ToOpenAI;
    case "anthropic-chat":
      return requestApi === "anthropic-text"
        ? anthropicChatToAnthropicV2
        ? anthropicChatToAnthropicV2 // User's legacy text prompt was converted to chat, and response must be converted back to text
        : anthropicChatToOpenAI;
    case "google-ai":
      return googleAIToOpenAI;
    case "mistral-ai":
      return mistralAIToOpenAI;
    case "mistral-text":
      return requestApi === "mistral-ai"
        ? mistralTextToMistralChat // User's chat request was converted to text, and response must be converted back to chat
        : mistralAIToOpenAI;
    case "openai-image":
      throw new Error(`SSE transformation not supported for ${responseApi}`);
    default:
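
To illustrate the (responseApi, requestApi) pairs handled by the switch above — a hypothetical harness; the arguments follow the signature in this diff:

// Upstream responds in mistral-text, but the client originally sent mistral-ai
// (chat), so text SSEs must be converted back to chat events:
getTransformer("mistral-text", undefined, "mistral-ai"); // -> mistralTextToMistralChat
// A native mistral-text client: convert to OpenAI chunks so the shared
// aggregators can consume them:
getTransformer("mistral-text", undefined, "mistral-text"); // -> mistralAIToOpenAI
// Legacy Anthropic text prompt that was upgraded to anthropic-chat upstream:
getTransformer("anthropic-chat", undefined, "anthropic-text"); // -> anthropicChatToAnthropicV2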
@@ -55,8 +55,10 @@ export class SSEStreamAdapter extends Transform {

    if ("completion" in eventObj) {
      return ["event: completion", `data: ${event}`].join(`\n`);
    } else {
    } else if (eventObj.type) {
      return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
    } else {
      return `data: ${event}`;
    }
  }
  // noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected

@@ -0,0 +1,76 @@
import { logger } from "../../../../../logger";
import { MistralAIStreamEvent, SSEResponseTransformArgs } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";

const log = logger.child({
  module: "sse-transformer",
  transformer: "mistral-ai-to-openai",
});

export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => {
  const { data } = params;

  const rawEvent = parseEvent(data);
  if (!rawEvent.data || rawEvent.data === "[DONE]") {
    return { position: -1 };
  }

  const completionEvent = asCompletion(rawEvent);
  if (!completionEvent) {
    return { position: -1 };
  }

  if ("choices" in completionEvent) {
    const newChatEvent = {
      id: params.fallbackId,
      object: "chat.completion.chunk" as const,
      created: Date.now(),
      model: params.fallbackModel,
      choices: [
        {
          index: completionEvent.choices[0].index,
          delta: { content: completionEvent.choices[0].message.content },
          finish_reason: completionEvent.choices[0].stop_reason,
        },
      ],
    };
    return { position: -1, event: newChatEvent };
  } else if ("outputs" in completionEvent) {
    const newTextEvent = {
      id: params.fallbackId,
      object: "chat.completion.chunk" as const,
      created: Date.now(),
      model: params.fallbackModel,
      choices: [
        {
          index: 0,
          delta: { content: completionEvent.outputs[0].text },
          finish_reason: completionEvent.outputs[0].stop_reason,
        },
      ],
    };
    return { position: -1, event: newTextEvent };
  }

  // should never happen
  return { position: -1 };
};

function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null {
  try {
    const parsed = JSON.parse(event.data);
    if (
      (Array.isArray(parsed.choices) &&
        parsed.choices[0].message !== undefined) ||
      (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined)
    ) {
      return parsed;
    } else {
      // noinspection ExceptionCaughtLocallyJS
      throw new Error("Missing required fields");
    }
  } catch (error: any) {
    log.warn({ error: error.stack, event }, "Received invalid data event");
  }
  return null;
}
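
For reference, a sketch of one such transformation, assuming parseEvent accepts a single `data:` line; the payload is a hypothetical Bedrock Mistral chat chunk:

const { event } = mistralAIToOpenAI({
  data: 'data: {"choices":[{"index":0,"message":{"role":"assistant","content":"Hi"},"stop_reason":null}]}',
  lastPosition: -1,
  index: 0,
  fallbackId: "req-123",
  fallbackModel: "mistral-large",
});
// event => { id: "req-123", object: "chat.completion.chunk", created: <timestamp>,
//   model: "mistral-large", choices: [{ index: 0, delta: { content: "Hi" }, finish_reason: null }] }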
@@ -0,0 +1,63 @@
import {
  MistralChatCompletionEvent,
  MistralTextCompletionEvent,
  StreamingCompletionTransformer,
} from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";

const log = logger.child({
  module: "sse-transformer",
  transformer: "mistral-text-to-mistral-chat",
});

/**
 * Transforms an incoming Mistral Text SSE to an equivalent Mistral Chat SSE.
 * This is generally used when a client sends a Mistral Chat prompt, but we
 * convert it to Mistral Text before sending it to the API to work around
 * some bugs in Mistral/AWS prompt templating. In these cases we need to convert
 * the response back to Mistral Chat.
 */
export const mistralTextToMistralChat: StreamingCompletionTransformer<
  MistralChatCompletionEvent
> = (params) => {
  const { data } = params;

  const rawEvent = parseEvent(data);
  if (!rawEvent.data) {
    return { position: -1 };
  }

  const textCompletion = asTextCompletion(rawEvent);
  if (!textCompletion) {
    return { position: -1 };
  }

  const chatEvent: MistralChatCompletionEvent = {
    choices: [
      {
        index: 0,
        message: { role: "assistant", content: textCompletion.outputs[0].text },
        stop_reason: textCompletion.outputs[0].stop_reason,
      },
    ],
  };
  return { position: -1, event: chatEvent };
};

function asTextCompletion(
  event: ServerSentEvent
): MistralTextCompletionEvent | null {
  try {
    const parsed = JSON.parse(event.data);
    if (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) {
      return parsed;
    } else {
      // noinspection ExceptionCaughtLocallyJS
      throw new Error("Missing required fields");
    }
  } catch (error: any) {
    log.warn({ error: error.stack, event }, "Received invalid data event");
  }
  return null;
}
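
And the text-to-chat direction, with a hypothetical Bedrock text chunk as input:

const { event } = mistralTextToMistralChat({
  data: 'data: {"outputs":[{"text":" Hello","stop_reason":null}]}',
  lastPosition: -1,
  index: 0,
  fallbackId: "req-123",
  fallbackModel: "mistral-large",
});
// event => { choices: [{ index: 0, message: { role: "assistant", content: " Hello" }, stop_reason: null }] }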
+40
-5
@@ -1,4 +1,4 @@
import { RequestHandler, Router } from "express";
import express, { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
@@ -61,7 +61,7 @@ export const KNOWN_MISTRAL_AI_MODELS = [
  "mistral-medium-latest",
  "mistral-medium-2312",
  "mistral-tiny",
  "mistral-tiny-2312"
  "mistral-tiny-2312",
];

let modelsCache: any = null;
@@ -108,9 +108,24 @@ const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
    throw new Error("Expected body to be an object");
  }

  res.status(200).json({ ...body, proxy: body.proxy });
  let newBody = body;
  if (req.inboundApi === "mistral-text" && req.outboundApi === "mistral-ai") {
    newBody = transformMistralTextToMistralChat(body);
  }

  res.status(200).json({ ...newBody, proxy: body.proxy });
};

export function transformMistralTextToMistralChat(textBody: any) {
  return {
    ...textBody,
    choices: [
      { message: { content: textBody.outputs[0].text, role: "assistant" } },
    ],
    outputs: undefined,
  };
}

const mistralAIProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
    target: "https://api.mistral.ai",
@@ -133,12 +148,32 @@ mistralAIRouter.get("/v1/models", handleModelRequest);
mistralAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({
  createPreprocessorMiddleware(
    {
      inApi: "mistral-ai",
      outApi: "mistral-ai",
      service: "mistral-ai",
  }),
    },
    { beforeTransform: [detectMistralInputApi] }
  ),
  mistralAIProxy
);

/**
 * We can't determine if a request is Mistral text or chat just from the path
 * because they both use the same endpoint. We need to check the request body
 * for either `messages` or `prompt`.
 * @param req
 */
export function detectMistralInputApi(req: Request) {
  const { messages, prompt } = req.body;
  if (messages) {
    req.inboundApi = "mistral-ai";
    req.outboundApi = "mistral-ai";
  } else if (prompt) {
    req.inboundApi = "mistral-text";
    req.outboundApi = "mistral-text";
  }
}

export const mistralAI = mistralAIRouter;
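
The detector keys purely off the body shape; a sketch with stubbed requests (the `as any` casts stand in for the proxy's extended Request type):

const chatReq = { body: { model: "mistral-large-latest", messages: [{ role: "user", content: "Hi" }] } } as any;
detectMistralInputApi(chatReq);
// chatReq.inboundApi === "mistral-ai"

const textReq = { body: { model: "mistral-large-latest", prompt: "<s>[INST] Hi [/INST]" } } as any;
detectMistralInputApi(textReq);
// textReq.inboundApi === "mistral-text"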
+24
-20
@@ -1,44 +1,55 @@
import express, { Request, Response, NextFunction } from "express";
import { gatekeeper } from "./gatekeeper";
import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import express from "express";
import { addV1 } from "./add-v1";
import { anthropic } from "./anthropic";
import { aws } from "./aws";
import { azure } from "./azure";
import { checkRisuToken } from "./check-risu-token";
import { gatekeeper } from "./gatekeeper";
import { gcp } from "./gcp";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { gcp } from "./gcp";
import { azure } from "./azure";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { sendErrorToClient } from "./middleware/response/error-generator";

const proxyRouter = express.Router();

// Remove `expect: 100-continue` header from requests due to incompatibility
// with node-http-proxy.
proxyRouter.use((req, _res, next) => {
  if (req.headers.expect) {
    // node-http-proxy does not like it when clients send `expect: 100-continue`
    // and will stall. none of the upstream APIs use this header anyway.
    delete req.headers.expect;
  }
  next();
});

// Apply body parsers.
proxyRouter.use(
  express.json({ limit: "100mb" }),
  express.urlencoded({ extended: true, limit: "100mb" })
);

// Apply auth/rate limits.
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);

// Initialize request queue metadata.
proxyRouter.use((req, _res, next) => {
  req.startTime = Date.now();
  req.retryCount = 0;
  next();
});

// Proxy endpoints.
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/aws", aws);
proxyRouter.use("/gcp/claude", addV1, gcp);
proxyRouter.use("/azure/openai", addV1, azure);

// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {
  const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
@@ -48,7 +59,8 @@ proxyRouter.get("*", (req, res, next) => {
    next();
  }
});
// Handle 404s.

// Send a fake client error if user specifies an invalid proxy endpoint.
proxyRouter.use((req, res) => {
  sendErrorToClient({
    req,
@@ -69,11 +81,3 @@ proxyRouter.use((req, res) => {
});

export { proxyRouter as proxyRouter };

function addV1(req: Request, res: Response, next: NextFunction) {
  // Clients don't consistently use the /v1 prefix so we'll add it for them.
  if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
    req.url = `/v1${req.url}`;
  }
  next();
}

+91
-145
@@ -3,8 +3,6 @@ import {
  AnthropicKey,
  AwsBedrockKey,
  GcpKey,
  AzureOpenAIKey,
  GoogleAIKey,
  keyPool,
  OpenAIKey,
} from "./shared/key-management";
@@ -26,21 +24,14 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";

const CACHE_TTL = 2000;

type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
  k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
  k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
  k.service === "anthropic";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
  k.service === "google-ai";
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
  k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";

@@ -54,14 +45,15 @@ type ModelAggregates = {
  overQuota?: number;
  pozzed?: number;
  awsLogged?: number;
  awsSonnet?: number;
  awsSonnet35?: number;
  awsHaiku?: number;
  // needed to disambiguate aws-claude family's variants
  awsClaude2?: number;
  awsSonnet3?: number;
  awsSonnet3_5?: number;
  awsHaiku: number;
  gcpSonnet?: number;
  gcpSonnet35?: number;
  gcpHaiku?: number;
  queued: number;
  queueTime: string;
  tokens: number;
};
/** All possible combinations of model family and aggregate type. */
@@ -93,14 +85,10 @@ type AnthropicInfo = BaseFamilyInfo & {
};
type AwsInfo = BaseFamilyInfo & {
  privacy?: string;
  sonnetKeys?: number;
  sonnet35Keys?: number;
  haikuKeys?: number;
  enabledVariants?: string;
};
type GcpInfo = BaseFamilyInfo & {
  sonnetKeys?: number;
  sonnet35Keys?: number;
  haikuKeys?: number;
  enabledVariants?: string;
};

// prettier-ignore
@@ -108,12 +96,10 @@ export type ServiceInfo = {
  uptime: number;
  endpoints: {
    openai?: string;
    openai2?: string;
    anthropic?: string;
    "anthropic-claude-3"?: string;
    "google-ai"?: string;
    "mistral-ai"?: string;
    aws?: string;
    "aws"?: string;
    gcp?: string;
    azure?: string;
    "openai-image"?: string;
@@ -151,7 +137,6 @@ export type ServiceInfo = {
const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
  openai: {
    openai: `%BASE%/openai`,
    openai2: `%BASE%/openai/turbo-instruct`,
    "openai-image": `%BASE%/openai-image`,
  },
  anthropic: {
@@ -164,7 +149,8 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
    "mistral-ai": `%BASE%/mistral-ai`,
  },
  aws: {
    aws: `%BASE%/aws/claude`,
    "aws-claude": `%BASE%/aws/claude`,
    "aws-mistral": `%BASE%/aws/mistral`,
  },
  gcp: {
    gcp: `%BASE%/gcp/claude`,
@@ -175,7 +161,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
  },
};

const modelStats = new Map<ModelAggregateKey, number>();
const familyStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof AllStats, number>();

let cachedInfo: ServiceInfo | undefined;
@@ -192,7 +178,7 @@ export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
    .concat("turbo")
  );

  modelStats.clear();
  familyStats.clear();
  serviceStats.clear();
  keys.forEach(addKeyToAggregates);

@@ -311,150 +297,102 @@ function increment<T extends keyof AllStats | ModelAggregateKey>(
) {
  map.set(key, (map.get(key) || 0) + delta);
}
const addToService = increment.bind(null, serviceStats);
const addToFamily = increment.bind(null, familyStats);

function addKeyToAggregates(k: KeyPoolKey) {
  increment(serviceStats, "proompts", k.promptCount);
  increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
  increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
  increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
  increment(
    serviceStats,
    "mistral-ai__keys",
    k.service === "mistral-ai" ? 1 : 0
  );
  increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
  increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0);
  increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
  addToService("proompts", k.promptCount);
  addToService("openai__keys", k.service === "openai" ? 1 : 0);
  addToService("anthropic__keys", k.service === "anthropic" ? 1 : 0);
  addToService("google-ai__keys", k.service === "google-ai" ? 1 : 0);
  addToService("mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
  addToService("aws__keys", k.service === "aws" ? 1 : 0);
  addToService("gcp__keys", k.service === "gcp" ? 1 : 0);
  addToService("azure__keys", k.service === "azure" ? 1 : 0);

  let sumTokens = 0;
  let sumCost = 0;

  const incrementGenericFamilyStats = (f: ModelFamily) => {
    const tokens = (k as any)[`${f}Tokens`];
    sumTokens += tokens;
    sumCost += getTokenCostUsd(f, tokens);
    addToFamily(`${f}__tokens`, tokens);
    addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0);
    addToFamily(`${f}__active`, k.isDisabled ? 0 : 1);
  };

  switch (k.service) {
    case "openai":
      if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
      increment(
        serviceStats,
        "openai__uncheckedKeys",
        Boolean(k.lastChecked) ? 0 : 1
      );

      addToService("openai__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
        increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
        incrementGenericFamilyStats(f);
        addToFamily(`${f}__trial`, k.isTrial ? 1 : 0);
        addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
      });
      break;
    case "azure":
      if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
      });
      break;
    case "anthropic": {
    case "anthropic":
      if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
      addToService("anthropic__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__trial`, k.tier === "free" ? 1 : 0);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
        increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
        increment(modelStats, `${f}__pozzed`, k.isPozzed ? 1 : 0);
      });
      increment(
        serviceStats,
        "anthropic__uncheckedKeys",
        Boolean(k.lastChecked) ? 0 : 1
      );
      break;
    }
    case "google-ai": {
      if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((family) => {
        const tokens = k[`${family}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(family, tokens);
        increment(modelStats, `${family}__tokens`, tokens);
        increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
        increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
        incrementGenericFamilyStats(f);
        addToFamily(`${f}__trial`, k.tier === "free" ? 1 : 0);
        addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
        addToFamily(`${f}__pozzed`, k.isPozzed ? 1 : 0);
      });
      break;
    }
    case "mistral-ai": {
      if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
      });
      break;
    }

    case "aws": {
      if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
      k.modelFamilies.forEach(incrementGenericFamilyStats);
      if (!k.isDisabled) {
        // Don't add revoked keys to available AWS variants
        k.modelIds.forEach((id) => {
          if (id.includes("claude-3-sonnet")) {
            addToFamily(`aws-claude__awsSonnet3`, 1);
          } else if (id.includes("claude-3-5-sonnet")) {
            addToFamily(`aws-claude__awsSonnet3_5`, 1);
          } else if (id.includes("claude-3-haiku")) {
            addToFamily(`aws-claude__awsHaiku`, 1);
          } else if (id.includes("claude-v2")) {
            addToFamily(`aws-claude__awsClaude2`, 1);
          }
        });
      increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
      increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0);
      increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);

      }
      // Ignore revoked keys for aws logging stats, but include keys where the
      // logging status is unknown.
      const countAsLogged =
        k.lastChecked && !k.isDisabled && k.awsLoggingStatus === "enabled";
      increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
      addToFamily(`aws-claude__awsLogged`, countAsLogged ? 1 : 0);
      break;
    }
    case "gcp": {
    case "gcp":
      if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
      k.modelFamilies.forEach((f) => {
        const tokens = k[`${f}Tokens`];
        sumTokens += tokens;
        sumCost += getTokenCostUsd(f, tokens);
        increment(modelStats, `${f}__tokens`, tokens);
        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
      });
      increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0);
      increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0);
      increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 1 : 0);
      k.modelFamilies.forEach(incrementGenericFamilyStats);
      // TODO: add modelIds to GcpKey
      break;
    // These services don't have any additional stats to track.
    case "azure":
    case "google-ai":
    case "mistral-ai":
      k.modelFamilies.forEach(incrementGenericFamilyStats);
      break;
    }
    default:
      assertNever(k.service);
  }

  increment(serviceStats, "tokens", sumTokens);
  increment(serviceStats, "tokenCost", sumCost);
  addToService("tokens", sumTokens);
  addToService("tokenCost", sumCost);
}

function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
  const tokens = modelStats.get(`${family}__tokens`) || 0;
  const tokens = familyStats.get(`${family}__tokens`) || 0;
  const cost = getTokenCostUsd(family, tokens);
  let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
    usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
    activeKeys: modelStats.get(`${family}__active`) || 0,
    revokedKeys: modelStats.get(`${family}__revoked`) || 0,
    activeKeys: familyStats.get(`${family}__active`) || 0,
    revokedKeys: familyStats.get(`${family}__revoked`) || 0,
  };

  // Add service-specific stats to the info object.
@@ -462,8 +400,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
  const service = MODEL_FAMILY_SERVICE[family];
  switch (service) {
    case "openai":
      info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
      info.trialKeys = modelStats.get(`${family}__trial`) || 0;
      info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
      info.trialKeys = familyStats.get(`${family}__trial`) || 0;

      // Delete trial/revoked keys for non-turbo families.
      // Trials are turbo 99% of the time, and if a key is invalid we don't
@@ -474,16 +412,25 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
      }
      break;
    case "anthropic":
      info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
      info.trialKeys = modelStats.get(`${family}__trial`) || 0;
      info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0;
      info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
      info.trialKeys = familyStats.get(`${family}__trial`) || 0;
      info.prefilledKeys = familyStats.get(`${family}__pozzed`) || 0;
      break;
    case "aws":
      if (family === "aws-claude") {
        info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
        info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0;
        info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
        const logged = modelStats.get(`${family}__awsLogged`) || 0;
        const logged = familyStats.get(`${family}__awsLogged`) || 0;
        const variants = new Set<string>();
        if (familyStats.get(`${family}__awsClaude2`) || 0)
          variants.add("claude2");
        if (familyStats.get(`${family}__awsSonnet3`) || 0)
          variants.add("sonnet3");
        if (familyStats.get(`${family}__awsSonnet3_5`) || 0)
          variants.add("sonnet3.5");
        if (familyStats.get(`${family}__awsHaiku`) || 0)
          variants.add("haiku");
        info.enabledVariants = variants.size
          ? `${Array.from(variants).join(",")}`
          : undefined;
        if (logged > 0) {
          info.privacy = config.allowAwsLogging
            ? `AWS logging verification inactive. Prompts could be logged.`
@@ -493,9 +440,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
      break;
    case "gcp":
      if (family === "gcp-claude") {
        info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0;
        info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0;
        info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0;
        // TODO: implement
        info.enabledVariants = "not implemented";
      }
      break;
  }

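
The addToService/addToFamily helpers above are plain partial application; a self-contained sketch of the equivalence:

const increment = (map: Map<string, number>, key: string, delta: number) =>
  map.set(key, (map.get(key) || 0) + delta);
const familyStats = new Map<string, number>();
// Pin the first argument so call sites only name the key and delta.
const addToFamily = increment.bind(null, familyStats);
addToFamily("aws-claude__awsHaiku", 1); // same as increment(familyStats, "aws-claude__awsHaiku", 1)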
@@ -21,7 +21,11 @@ import {
  GoogleAIV1GenerateContentSchema,
  transformOpenAIToGoogleAI,
} from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
import {
  MistralAIV1ChatCompletionsSchema,
  MistralAIV1TextCompletionsSchema,
  transformMistralChatToText,
} from "./mistral-ai";

export { OpenAIChatMessage } from "./openai";
export {
@@ -49,6 +53,7 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
  "openai->openai-text": transformOpenAIToOpenAIText,
  "openai->openai-image": transformOpenAIToOpenAIImage,
  "openai->google-ai": transformOpenAIToGoogleAI,
  "mistral-ai->mistral-text": transformMistralChatToText,
};

export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
@@ -59,4 +64,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
  "openai-image": OpenAIV1ImagesGenerationSchema,
  "google-ai": GoogleAIV1GenerateContentSchema,
  "mistral-ai": MistralAIV1ChatCompletionsSchema,
  "mistral-text": MistralAIV1TextCompletionsSchema,
};

@@ -1,15 +1,34 @@
import { z } from "zod";
import { OPENAI_OUTPUT_MAX } from "./openai";
import { Template } from "@huggingface/jinja";
import { APIFormatTransformer } from "./index";
import { logger } from "../../logger";

const MistralChatMessageSchema = z.object({
  role: z.enum(["system", "user", "assistant", "tool"]), // TODO: implement tools
  content: z.string(),
  prefix: z.boolean().optional(),
});

const MistralMessagesSchema = z.array(MistralChatMessageSchema).refine(
  (input) => {
    const prefixIdx = input.findIndex((msg) => Boolean(msg.prefix));
    if (prefixIdx === -1) return true; // no prefix messages
    const lastIdx = input.length - 1;
    const lastMsg = input[lastIdx];
    return prefixIdx === lastIdx && lastMsg.role === "assistant";
  },
  {
    message:
      "`prefix` can only be set to `true` on the last message, and only for an assistant message.",
  }
);

// https://docs.mistral.ai/api#operation/createChatCompletion
export const MistralAIV1ChatCompletionsSchema = z.object({
const BaseMistralAIV1CompletionsSchema = z.object({
  model: z.string(),
  messages: z.array(
    z.object({
      role: z.enum(["system", "user", "assistant"]),
      content: z.string(),
    })
  ),
  messages: MistralMessagesSchema.optional(),
  prompt: z.string().optional(),
  temperature: z.number().optional().default(0.7),
  top_p: z.number().optional().default(1),
  max_tokens: z.coerce
@@ -18,12 +37,48 @@ export const MistralAIV1ChatCompletionsSchema = z.object({
    .nullish()
    .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
  stream: z.boolean().optional().default(false),
  // Mistral docs say that `stop` can be a string or array but AWS Mistral
  // blows up if a string is passed. We must convert it to an array.
  stop: z
    .union([z.string(), z.array(z.string())])
    .optional()
    .default([])
    .transform((v) => (Array.isArray(v) ? v : [v])),
  random_seed: z.number().int().min(0).optional(),
  response_format: z.enum(["text", "json_object"]).optional().default("text"),
  safe_prompt: z.boolean().optional().default(false),
  random_seed: z.number().int().optional(),
});
export type MistralAIChatMessage = z.infer<
  typeof MistralAIV1ChatCompletionsSchema
>["messages"][0];

export const MistralAIV1ChatCompletionsSchema =
  BaseMistralAIV1CompletionsSchema.and(
    z.object({ messages: MistralMessagesSchema })
  );
export const MistralAIV1TextCompletionsSchema =
  BaseMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));

/*
Slightly more strict version that only allows a subset of the parameters. AWS
Mistral helpfully returns no details if unsupported parameters are passed so
this list comes from trial and error as of 2024-08-12.
*/
const BaseAWSMistralAIV1CompletionsSchema =
  BaseMistralAIV1CompletionsSchema.pick({
    temperature: true,
    top_p: true,
    max_tokens: true,
    stop: true,
    random_seed: true,
    // response_format: true,
    // safe_prompt: true,
  }).strip();
export const AWSMistralV1ChatCompletionsSchema =
  BaseAWSMistralAIV1CompletionsSchema.and(
    z.object({ messages: MistralMessagesSchema })
  );
export const AWSMistralV1TextCompletionsSchema =
  BaseAWSMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));

export type MistralAIChatMessage = z.infer<typeof MistralChatMessageSchema>;

export function fixMistralPrompt(
  messages: MistralAIChatMessage[]
@@ -31,12 +86,11 @@ export function fixMistralPrompt(
  // Mistral uses OpenAI format but has some additional requirements:
  // - Only one system message per request, and it must be the first message if
  //   present.
  // - Final message must be a user message.
  // - Final message must be a user message, unless it has `prefix: true`.
  // - Cannot have multiple messages from the same role in a row.
  // While frontends should be able to handle this, we can fix it here in the
  // meantime.

  return messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
  const fixed = messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
    if (acc.length === 0) {
      acc.push(msg);
      return acc;
@@ -57,4 +111,54 @@ export function fixMistralPrompt(
    }
    return acc;
  }, []);

  // If the last message is an assistant message, mark it as a prefix. An
  // assistant message at the end of the conversation without `prefix: true`
  // results in an error.
  if (fixed[fixed.length - 1].role === "assistant") {
    fixed[fixed.length - 1].prefix = true;
  }
  return fixed;
}

let jinjaTemplate: Template;
let renderTemplate: (messages: MistralAIChatMessage[]) => string;
function renderMistralPrompt(messages: MistralAIChatMessage[]) {
  if (!jinjaTemplate) {
    logger.warn("Lazy loading mistral chat template...");
    const { chatTemplate, bosToken, eosToken } =
      require("./templates/mistral-template").MISTRAL_TEMPLATE;
    jinjaTemplate = new Template(chatTemplate);
    renderTemplate = (messages) =>
      jinjaTemplate.render({
        messages,
        bos_token: bosToken,
        eos_token: eosToken,
      });
  }

  return renderTemplate(messages);
}

/**
 * Attempts to convert a Mistral chat completions request to a text completions
 * request, using the official prompt template published by Mistral.
 */
export const transformMistralChatToText: APIFormatTransformer<
  typeof MistralAIV1TextCompletionsSchema
> = async (req) => {
  const { body } = req;
  const result = MistralAIV1ChatCompletionsSchema.safeParse(body);
  if (!result.success) {
    req.log.warn(
      { issues: result.error.issues, body },
      "Invalid Mistral chat completions request"
    );
    throw result.error;
  }

  const { messages, ...rest } = result.data;
  const prompt = renderMistralPrompt(messages);

  return { ...rest, prompt, messages: undefined };
};

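Two behaviors of the schemas above, shown as a sketch (assumes the exports from this file): string `stop` values are coerced to arrays for AWS compatibility, and `prefix` is only legal on a final assistant message.

const ok = MistralAIV1ChatCompletionsSchema.safeParse({
  model: "mistral-large-latest",
  stop: "###", // coerced to ["###"]
  messages: [
    { role: "user", content: "Write a haiku." },
    { role: "assistant", content: "Autumn moonlight—", prefix: true }, // last + assistant: allowed
  ],
});
// ok.success === true; ok.data.stop => ["###"]

const bad = MistralAIV1ChatCompletionsSchema.safeParse({
  model: "mistral-large-latest",
  messages: [{ role: "user", content: "Hi", prefix: true }], // prefix on a user message
});
// bad.success === false (fails the MistralMessagesSchema refinement)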
@@ -0,0 +1,36 @@
export const MISTRAL_TEMPLATE = {
  bosToken: "<s>",
  eosToken: "</s>",
  chatTemplate: `{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}

{%- for message in loop_messages %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}

{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "assistant" %}
{%- if loop.last and message.prefix is defined and message.prefix %}
{{- " " + message["content"] }}
{%- else %}
{{- " " + message["content"] + eos_token}}
{%- endif %}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}`,
};
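
Rendering this template with @huggingface/jinja produces the canonical Mistral instruct format; a sketch using the exports above (the conversation is a hypothetical example):

import { Template } from "@huggingface/jinja";
import { MISTRAL_TEMPLATE } from "./templates/mistral-template";

const prompt = new Template(MISTRAL_TEMPLATE.chatTemplate).render({
  messages: [
    { role: "user", content: "What is your favourite condiment?" },
    { role: "assistant", content: "Mayonnaise, of course!" },
    { role: "user", content: "Why?" },
  ],
  bos_token: MISTRAL_TEMPLATE.bosToken,
  eos_token: MISTRAL_TEMPLATE.eosToken,
});
// => "<s>[INST] What is your favourite condiment?[/INST] Mayonnaise, of course!</s>[INST] Why?[/INST]"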
@@ -1,5 +1,5 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
@@ -23,10 +23,6 @@ type AnthropicKeyUsage = {
export interface AnthropicKey extends Key, AnthropicKeyUsage {
  readonly service: "anthropic";
  readonly modelFamilies: AnthropicModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  /**
   * Whether this key requires a special preamble. For unclear reasons, some
   * Anthropic keys will throw an error if the prompt does not begin with a
@@ -217,22 +213,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
    key[`${getClaudeModelFamily(model)}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return the time until the first key is
    // ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * This is called when we receive a 429, which means there are already five

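createGenericGetLockoutPeriod itself is not shown in this diff; judging from the removed method above, it presumably factors that logic into a shared helper along these lines (a hypothetical sketch, not the actual implementation):

type RateLimitedKey = { isDisabled: boolean; rateLimitedUntil: number };

function createGenericGetLockoutPeriod(getKeys: () => RateLimitedKey[]) {
  return function getLockoutPeriod(): number {
    const activeKeys = getKeys().filter((k) => !k.isDisabled);
    // No keys at all: return 0 and let the add-key middleware surface the error.
    if (activeKeys.length === 0) return 0;
    const now = Date.now();
    const anyNotRateLimited = activeKeys.some((k) => now >= k.rateLimitedUntil);
    if (anyNotRateLimited) return 0;
    // All keys rate-limited: wait until the first one becomes ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  };
}
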
@@ -5,9 +5,25 @@ import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
|
||||
import { URL } from "url";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
|
||||
import { AwsBedrockModelFamily } from "../../models";
|
||||
import { getAwsBedrockModelFamily } from "../../models";
|
||||
import { config } from "../../../config";
|
||||
|
||||
type ParentModelId = string;
|
||||
type AliasModelId = string;
|
||||
type ModuleAliasTuple = [ParentModelId, ...AliasModelId[]];
|
||||
|
||||
const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [
|
||||
["anthropic.claude-v2", "anthropic.claude-v2:1"],
|
||||
["anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||
["anthropic.claude-3-haiku-20240307-v1:0"],
|
||||
["anthropic.claude-3-opus-20240229-v1:0"],
|
||||
["anthropic.claude-3-5-sonnet-20240620-v1:0"],
|
||||
["mistral.mistral-7b-instruct-v0:2"],
|
||||
["mistral.mixtral-8x7b-instruct-v0:1"],
|
||||
["mistral.mistral-large-2402-v1:0"],
|
||||
["mistral.mistral-large-2407-v1:0"],
|
||||
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
|
||||
];
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||
const AMZ_HOST =
|
||||
@@ -47,36 +63,26 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
}
|
||||
|
||||
protected async testKeyOrFail(key: AwsBedrockKey) {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
let checks: Promise<boolean>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
|
||||
if (isInitialCheck) {
|
||||
checks = [
|
||||
this.invokeModel("anthropic.claude-v2", key),
|
||||
this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key),
|
||||
];
|
||||
}
|
||||
|
||||
checks.unshift(this.checkLoggingConfiguration(key));
|
||||
|
||||
const [_logging, claudeV2, sonnet, haiku, opus, sonnet35] =
|
||||
await Promise.all(checks);
|
||||
|
||||
this.log.debug(
|
||||
      { key: key.hash, _logging, claudeV2, sonnet, haiku, opus, sonnet35 },
      "AWS model tests complete."
    );
    // Perform checks for all parent model IDs
    const results = await Promise.all(
      KNOWN_MODEL_IDS.filter(([model]) =>
        // Skip checks for models that are disabled anyway
        config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
      ).map(async ([model, ...aliases]) => ({
        models: [model, ...aliases],
        success: await this.invokeModel(model, key),
      }))
    );

    if (isInitialCheck) {
      const families: AwsBedrockModelFamily[] = [];
      if (claudeV2 || sonnet || sonnet35 || haiku) families.push("aws-claude");
      if (opus) families.push("aws-claude-opus");
      // Filter out models that are disabled
      const modelIds = results
        .filter(({ success }) => success)
        .flatMap(({ models }) => models);

      if (families.length === 0) {
        if (modelIds.length === 0) {
          this.log.warn(
            { key: key.hash },
            "Key does not have access to any models; disabling."
@@ -85,20 +91,19 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
      }

      this.updateKey(key.hash, {
        sonnetEnabled: sonnet,
        haikuEnabled: haiku,
        sonnet35Enabled: sonnet35,
        modelFamilies: families,
        modelIds,
        modelFamilies: Array.from(
          new Set(modelIds.map(getAwsBedrockModelFamily))
        ),
      });
    }

    this.log.info(
      {
        key: key.hash,
        sonnet,
        haiku,
        families: key.modelFamilies,
        logged: key.awsLoggingStatus,
        families: key.modelFamilies,
        models: key.modelIds,
      },
      "Checked key."
    );
@@ -169,7 +174,22 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
   * key has access to the model, false if it does not. Throws an error if the
   * key is disabled.
   */
  private async invokeModel(model: string, key: AwsBedrockKey) {
  private async invokeModel(
    model: string,
    key: AwsBedrockKey
  ): Promise<boolean> {
    if (model.includes("claude")) {
      return this.testClaudeModel(key, model);
    } else if (model.includes("mistral")) {
      return this.testMistralModel(key, model);
    }
    throw new Error("AwsKeyChecker#invokeModel: no implementation for model");
  }

  private async testClaudeModel(
    key: AwsBedrockKey,
    model: string
  ): Promise<boolean> {
    const creds = AwsKeyChecker.getCredentialsFromKey(key);
    // This is not a valid invocation payload, but a 400 response indicates that
    // the principal at least has permission to invoke the model.
@@ -196,7 +216,8 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
    const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
    const errorMessage = data?.message;

    // We only allow one type of 403 error, and we only allow it for one model.
    // This message indicates the key is valid but this particular model is not
    // accessible. Other 403s may indicate the key is not usable.
    if (
      status === 403 &&
      errorMessage?.match(/access to the model with the specified model ID/)
@@ -219,6 +240,61 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
    const correctErrorType = errorType === "ValidationException";
    const correctErrorMessage = errorMessage?.match(/max_tokens/);
    if (!correctErrorType || !correctErrorMessage) {
      this.log.debug(
        { key: key.hash, model, errorType, data, status },
        "AWS InvokeModel test unsuccessful."
      );
      return false;
    }

    this.log.debug(
      { key: key.hash, model, errorType, data, status },
      "AWS InvokeModel test successful."
    );
    return true;
  }

  private async testMistralModel(
    key: AwsBedrockKey,
    model: string
  ): Promise<boolean> {
    const creds = AwsKeyChecker.getCredentialsFromKey(key);

    // Deliberately invalid payload (negative max_tokens); a 400 validation
    // error from Bedrock still proves the key can invoke this model.
    const payload = {
      max_tokens: -1,
      prompt: "<s>[INST] What is your favourite condiment? [/INST]</s>",
    };
    const config: AxiosRequestConfig = {
      method: "POST",
      url: POST_INVOKE_MODEL_URL(creds.region, model),
      data: payload,
      validateStatus: (status) => [400, 403, 404].includes(status),
      headers: {
        "content-type": "application/json",
        accept: "*/*",
      },
    };
    await AwsKeyChecker.signRequestForAws(config, key);
    const response = await axios.request(config);
    const { data, status, headers } = response;
    const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
    const errorMessage = data?.message;

    if (status === 403 || status === 404) {
      this.log.debug(
        { key: key.hash, model, errorType, data, status },
        "AWS InvokeModel test returned 403 or 404."
      );
      return false;
    }

    const isBadRequest = status === 400;
    const isValidationError = errorMessage?.match(/validation error/i);
    if (isBadRequest && !isValidationError) {
      this.log.debug(
        { key: key.hash, model, errorType, data, status, headers },
        "AWS InvokeModel test returned 400 but not a validation error."
      );
      return false;
    }

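Both test methods above rely on the same probing trick: send a request that cannot succeed and classify the key by the failure class. A minimal standalone sketch of that decision logic; the ProbeResult shape is hypothetical, not part of this diff:

// Hypothetical shape for illustration; mirrors the status/errorType/message
// fields destructured from the axios response in the tests above.
type ProbeResult = { status: number; errorType?: string; message?: string };

function probeIndicatesAccess(r: ProbeResult): boolean {
  // 403/404 mean the request was rejected before the payload was validated:
  // the key cannot invoke this model.
  if (r.status === 403 || r.status === 404) return false;
  // A 400 validation complaint means Bedrock authorized the call and then
  // rejected the (deliberately invalid) body, which is what success looks like.
  return r.status === 400 && /validation error|max_tokens/i.test(r.message ?? "");
}

console.log(probeIndicatesAccess({ status: 400, message: "1 validation error detected" })); // true
console.log(probeIndicatesAccess({ status: 403, message: "AccessDeniedException" })); // false
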
@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { AwsKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker";

type AwsBedrockKeyUsage = {
  [K in AwsBedrockModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type AwsBedrockKeyUsage = {
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
  readonly service: "aws";
  readonly modelFamilies: AwsBedrockModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  /**
   * The confirmed logging status of this key. This is "unknown" until we
   * receive a response from the AWS API. Keys which are logged, or not
@@ -24,9 +21,7 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
   * set.
   */
  awsLoggingStatus: "unknown" | "disabled" | "enabled";
  sonnetEnabled: boolean;
  haikuEnabled: boolean;
  sonnet35Enabled: boolean;
  modelIds: string[];
}

/**
@@ -76,11 +71,13 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
        .digest("hex")
        .slice(0, 8)}`,
      lastChecked: 0,
      sonnetEnabled: true,
      haikuEnabled: false,
      sonnet35Enabled: false,
      modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
      ["aws-claudeTokens"]: 0,
      ["aws-claude-opusTokens"]: 0,
      ["aws-mistral-tinyTokens"]: 0,
      ["aws-mistral-smallTokens"]: 0,
      ["aws-mistral-mediumTokens"]: 0,
      ["aws-mistral-largeTokens"]: 0,
    };
    this.keys.push(newKey);
  }
@@ -99,41 +96,35 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
  }

  public get(model: string) {
    let neededVariantId = model;
    // This function accepts both Anthropic/Mistral IDs and AWS IDs.
    // Generally all AWS model IDs are supersets of the original vendor IDs.
    // Claude 2 is the only model that breaks this convention; Anthropic calls
    // it claude-2 but AWS calls it claude-v2.
    if (model.includes("claude-2")) neededVariantId = "claude-v2";
    const neededFamily = getAwsBedrockModelFamily(model);

    // this is a horrible mess
    // each of these should be separate model families, but adding model
    // families is not low enough friction for the rate at which aws claude
    // model variants are added.
    const needsSonnet35 =
      model.includes("claude-3-5-sonnet") && neededFamily === "aws-claude";
    const needsSonnet =
      !needsSonnet35 &&
      model.includes("sonnet") &&
      neededFamily === "aws-claude";
    const needsHaiku = model.includes("haiku") && neededFamily === "aws-claude";

    const availableKeys = this.keys.filter((k) => {
      const isNotLogged = k.awsLoggingStatus !== "enabled";
      // Select keys which
      return (
        // are enabled
        !k.isDisabled &&
        (isNotLogged || config.allowAwsLogging) &&
        (k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not
        (k.haikuEnabled || !needsHaiku) &&
        (k.sonnet35Enabled || !needsSonnet35) &&
        k.modelFamilies.includes(neededFamily)
        // are not logged, unless policy allows it
        (config.allowAwsLogging || k.awsLoggingStatus !== "enabled") &&
        // have access to the model family we need
        k.modelFamilies.includes(neededFamily) &&
        // have access to the specific variant we need
        k.modelIds.some((m) => m.includes(neededVariantId))
      );
    });

    this.log.debug(
      {
        model,
        neededFamily,
        needsSonnet,
        needsHaiku,
        needsSonnet35,
        availableKeys: availableKeys.length,
        requestedModel: model,
        selectedVariant: neededVariantId,
        selectedFamily: neededFamily,
        totalKeys: this.keys.length,
        availableKeys: availableKeys.length,
      },
      "Selecting AWS key"
    );
@@ -144,30 +135,8 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
    );
  }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //   a. If all keys were rate limited recently, select the least-recently
    //      rate limited key.
    // 3. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    const selectedKey = prioritizeKeys(availableKeys)[0];
    selectedKey.lastUsed = Date.now();
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }
@@ -195,22 +164,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
    key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    // TODO: same exact behavior for three providers, should be refactored
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * This is called when we receive a 429, which means there are already five

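The rewritten filter resolves variants by substring-matching the requested ID against each key's detected Bedrock IDs. A small illustration using the sonnet ID seeded in addKey above; the claude-2 line just restates the comment at the top of get():

// Vendor-style request IDs are substrings of the AWS IDs, so a key checked
// against the Bedrock ID also serves the vendor-style request.
const detectedIds = ["anthropic.claude-3-sonnet-20240229-v1:0"];
console.log(detectedIds.some((m) => m.includes("claude-3-sonnet-20240229"))); // true
// claude-2 breaks the convention (AWS names it claude-v2), hence the rewrite
// to neededVariantId before filtering.
console.log("anthropic.claude-v2:1".includes("claude-2")); // false
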
@@ -1,10 +1,13 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { PaymentRequiredError } from "../../errors";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { PaymentRequiredError } from "../../errors";
import {
  AzureOpenAIModelFamily,
  getAzureOpenAIModelFamily,
} from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AzureOpenAIKeyChecker } from "./checker";

type AzureOpenAIKeyUsage = {
@@ -14,10 +17,6 @@ type AzureOpenAIKeyUsage = {
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
  readonly service: "azure";
  readonly modelFamilies: AzureOpenAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  contentFiltering: boolean;
}

@@ -105,30 +104,8 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
    );
  }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //   a. If all keys were rate limited recently, select the least-recently
    //      rate limited key.
    // 3. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    const selectedKey = prioritizeKeys(availableKeys)[0];
    selectedKey.lastUsed = Date.now();
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }
@@ -156,26 +133,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
    key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
  }

  // TODO: all of this shit is duplicate code

  public getLockoutPeriod(family: AzureOpenAIModelFamily) {
    const activeKeys = this.keys.filter(
      (key) => !key.isDisabled && key.modelFamilies.includes(family)
    );

    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * This is called when we receive a 429, which means there are already five

@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { GcpKeyChecker } from "./checker";
import { PaymentRequiredError } from "../../errors";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker";

type GcpKeyUsage = {
  [K in GcpModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type GcpKeyUsage = {
export interface GcpKey extends Key, GcpKeyUsage {
  readonly service: "gcp";
  readonly modelFamilies: GcpModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  sonnetEnabled: boolean;
  haikuEnabled: boolean;
  sonnet35Enabled: boolean;
@@ -134,30 +131,8 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
    );
  }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //   a. If all keys were rate limited recently, select the least-recently
    //      rate limited key.
    // 3. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    const selectedKey = prioritizeKeys(availableKeys)[0];
    selectedKey.lastUsed = Date.now();
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }
@@ -185,22 +160,7 @@ export class GcpKeyProvider implements KeyProvider<GcpKey> {
    key[`${getGcpModelFamily(model)}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    // TODO: same exact behavior for three providers, should be refactored
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * This is called when we receive a 429, which means there are already five

@@ -1,9 +1,10 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
import { PaymentRequiredError } from "../../errors";
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GoogleAIKeyChecker } from "./checker";

// Note that Google AI is not the same as Vertex AI, both are provided by
@@ -28,10 +29,6 @@ type GoogleAIKeyUsage = {
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
  readonly service: "google-ai";
  readonly modelFamilies: GoogleAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  /** All detected model IDs on this key. */
  modelIds: string[];
}
@@ -112,29 +109,10 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
      throw new PaymentRequiredError("No Google AI keys available");
    }

    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //   a. If all keys were rate limited recently, select the least-recently
    //      rate limited key.
    // 3. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });
    const keysByPriority = prioritizeKeys(availableKeys);

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    selectedKey.lastUsed = Date.now();
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }
@@ -162,22 +140,7 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
    key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return the time until the first key is
    // ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * This is called when we receive a 429, which means there are already five

@@ -9,7 +9,8 @@ export type APIFormat =
  | "anthropic-chat" // Anthropic's newer messages array format
  | "anthropic-text" // Legacy flat string prompt format
  | "google-ai"
  | "mistral-ai";
  | "mistral-ai"
  | "mistral-text";

export interface Key {
  /** The API key itself. Never log this, use `hash` instead. */
@@ -30,6 +31,10 @@ export interface Key {
  lastChecked: number;
  /** Hash of the key, for logging and to find the key in the pool. */
  hash: string;
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
}

/*
@@ -58,10 +63,32 @@ export interface KeyProvider<T extends Key = Key> {
  recheck(): void;
}

export function createGenericGetLockoutPeriod<T extends Key>(
  getKeys: () => T[]
) {
  return function (this: unknown, family?: ModelFamily): number {
    const keys = getKeys();
    const activeKeys = keys.filter(
      (k) => !k.isDisabled && (!family || k.modelFamilies.includes(family))
    );

    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  };
}

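As a quick worked example of the shared lockout helper's math, with hypothetical timestamps (not taken from this diff):

// Sketch: three active keys, all currently rate limited.
const now = Date.now();
const rateLimitedUntil = [now + 4000, now + 1500, now + 9000];
const anyFree = rateLimitedUntil.some((t) => now >= t);
const lockoutMs = anyFree ? 0 : Math.min(...rateLimitedUntil.map((t) => t - now));
console.log(lockoutMs); // ~1500: wait until the soonest key becomes usable
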
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GoogleAIKey } from "././google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { GcpKey } from "./gcp/provider";
export { AzureOpenAIKey } from "./azure/provider";
export { GoogleAIKey } from "././google-ai/provider";
export { MistralAIKey } from "./mistral-ai/provider";
export { OpenAIKey } from "./openai/provider";

@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";
import { HttpError } from "../../errors";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { MistralAIKeyChecker } from "./checker";

type MistralAIKeyUsage = {
  [K in MistralAIModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type MistralAIKeyUsage = {
export interface MistralAIKey extends Key, MistralAIKeyUsage {
  readonly service: "mistral-ai";
  readonly modelFamilies: MistralAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
}

/**
@@ -98,30 +95,8 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
      throw new HttpError(402, "No Mistral AI keys available");
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //   a. If all keys were rate limited recently, select the least-recently
    //      rate limited key.
    // 3. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    const selectedKey = prioritizeKeys(availableKeys)[0];
    selectedKey.lastUsed = Date.now();
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }
@@ -150,22 +125,7 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
    key[`${family}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return the time until the first key is
    // ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * This is called when we receive a 429, which means there are already five

@@ -26,8 +26,6 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
  isTrial: boolean;
  /** Set when key check returns a non-transient 429. */
  isOverQuota: boolean;
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /**
   * Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
   * number.
@@ -111,6 +109,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
        .digest("hex")
        .slice(0, 8)}`,
      rateLimitedAt: 0,
      rateLimitedUntil: 0,
      rateLimitRequestsReset: 0,
      rateLimitTokensReset: 0,
      turboTokens: 0,

@@ -0,0 +1,24 @@
import { Key } from "./index";

export function prioritizeKeys<T extends Key>(keys: T[]) {
  // Sorts keys from highest priority to lowest priority, where priority is:
  // 1. Keys which are not rate limited
  //    a. If all keys were rate limited recently, select the least-recently
  //       rate limited key.
  // 2. Keys which have not been used in the longest time

  const now = Date.now();

  return keys.sort((a, b) => {
    // A key counts as rate limited while its rateLimitedUntil timestamp is
    // still in the future.
    const aRateLimited = now < a.rateLimitedUntil;
    const bRateLimited = now < b.rateLimitedUntil;

    if (aRateLimited && !bRateLimited) return 1;
    if (!aRateLimited && bRateLimited) return -1;
    if (aRateLimited && bRateLimited) {
      return a.rateLimitedAt - b.rateLimitedAt;
    }

    return a.lastUsed - b.lastUsed;
  });
}

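A minimal standalone sketch of the ordering this helper produces (slim stand-in objects, since real Key records carry many more fields):

type SlimKey = { hash: string; rateLimitedAt: number; rateLimitedUntil: number; lastUsed: number };
const now = Date.now();
const keys: SlimKey[] = [
  { hash: "a", rateLimitedAt: now - 500, rateLimitedUntil: now + 500, lastUsed: now - 10_000 },
  { hash: "b", rateLimitedAt: 0, rateLimitedUntil: 0, lastUsed: now - 60_000 },
  { hash: "c", rateLimitedAt: 0, rateLimitedUntil: 0, lastUsed: now - 5_000 },
];
keys.sort((a, b) => {
  const aLimited = now < a.rateLimitedUntil;
  const bLimited = now < b.rateLimitedUntil;
  if (aLimited && !bLimited) return 1;
  if (!aLimited && bLimited) return -1;
  if (aLimited && bLimited) return a.rateLimitedAt - b.rateLimitedAt;
  return a.lastUsed - b.lastUsed;
});
// "b" wins: not rate limited and idle the longest; "a" sorts last while limited.
console.log(keys.map((k) => k.hash)); // ["b", "c", "a"]
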
+25
-5
@@ -32,7 +32,9 @@ export type MistralAIModelFamily =
  // mistral changes their model classes frequently so these no longer
  // correspond to specific models. consider them rough pricing tiers.
  "mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
export type AwsBedrockModelFamily = `aws-${
  | AnthropicModelFamily
  | MistralAIModelFamily}`;
export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type ModelFamily =
@@ -64,6 +66,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
  "mistral-large",
  "aws-claude",
  "aws-claude-opus",
  "aws-mistral-tiny",
  "aws-mistral-small",
  "aws-mistral-medium",
  "aws-mistral-large",
  "gcp-claude",
  "gcp-claude-opus",
  "azure-turbo",
@@ -99,6 +105,10 @@ export const MODEL_FAMILY_SERVICE: {
  "claude-opus": "anthropic",
  "aws-claude": "aws",
  "aws-claude-opus": "aws",
  "aws-mistral-tiny": "aws",
  "aws-mistral-small": "aws",
  "aws-mistral-medium": "aws",
  "aws-mistral-large": "aws",
  "gcp-claude": "gcp",
  "gcp-claude-opus": "gcp",
  "azure-turbo": "azure",
@@ -180,8 +190,16 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
}

export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
  if (model.includes("opus")) return "aws-claude-opus";
  return "aws-claude";
  // remove vendor and version from AWS model ids
  // 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
  const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");

  if (["claude", "anthropic"].some((x) => model.includes(x))) {
    return `aws-${getClaudeModelFamily(deAwsified)}`;
  } else if (model.includes("tral")) {
    // "tral" catches both mistral and mixtral model ids
    return `aws-${getMistralAIModelFamily(deAwsified)}`;
  }
  return `aws-claude`;
}

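For reference, here is what the de-AWS-ification regex above yields for a few Bedrock-style IDs; a sketch with example IDs beyond the documented sonnet one, which are assumptions rather than a list from this diff:

// Illustrative only; the mixtral and claude-v2 IDs are assumed examples.
const deAwsify = (id: string) => id.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
console.log(deAwsify("anthropic.claude-3-5-sonnet-20240620-v1:0")); // "claude-3-5-sonnet-20240620"
console.log(deAwsify("anthropic.claude-v2:1")); // "claude"
console.log(deAwsify("mistral.mixtral-8x7b-instruct-v0:1")); // "mixtral-8x7b-instruct"
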
export function getGcpModelFamily(model: string): GcpModelFamily {
@@ -223,8 +241,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
  const model = req.body.model ?? "gpt-3.5-turbo";
  let modelFamily: ModelFamily;

  // Weird special case for AWS/GCP/Azure because they serve multiple models from
  // different vendors, even if currently only one is supported.
  // Weird special case for AWS/GCP/Azure because they serve models with
  // different API formats, so the outbound API alone is not sufficient to
  // determine the partition.
  if (req.service === "aws") {
    modelFamily = getAwsBedrockModelFamily(model);
  } else if (req.service === "gcp") {
@@ -246,6 +265,7 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
      modelFamily = getGoogleAIModelFamily(model);
      break;
    case "mistral-ai":
    case "mistral-text":
      modelFamily = getMistralAIModelFamily(model);
      break;
    default:

@@ -47,9 +47,9 @@ type GoogleAIChatTokenCountRequest = {
};

type MistralAIChatTokenCountRequest = {
  prompt: MistralAIChatMessage[];
  prompt: string | MistralAIChatMessage[];
  completion?: never;
  service: "mistral-ai";
  service: "mistral-ai" | "mistral-text";
};

type FlatPromptTokenCountRequest = {
@@ -128,6 +128,7 @@ export async function countTokens({
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "mistral-ai":
    case "mistral-text":
      return {
        ...getMistralAITokenCount(prompt ?? completion),
        tokenization_duration_ms: getElapsedMs(time),

@@ -431,6 +431,7 @@ function getModelFamilyForQuotaUsage(
    case "google-ai":
      return getGoogleAIModelFamily(model);
    case "mistral-ai":
    case "mistral-text":
      return getMistralAIModelFamily(model);
    default:
      assertNever(api);