Author: twinkletoes
Date: 2023-12-25 18:33:16 +00:00
Committed by: khanon
Parent: 01e76cbb1c
Commit: 4a823b216f
27 changed files with 1070 additions and 12 deletions
+2 -2
@@ -34,10 +34,10 @@
 # Which model types users are allowed to access.
 # The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | mistral-tiny | mistral-small | mistral-medium | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
 # By default, all models are allowed except for 'dall-e'. To allow DALL-E image
 # generation, uncomment the line below and add 'dall-e' to the list.
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,mistral-tiny,mistral-small,mistral-medium,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
 # URLs from which requests will be blocked.
 # BLOCKED_ORIGINS=reddit.com,9gag.com
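Putting the two settings together: a minimal `.env` that enables the new families might look like the following sketch. The key values are placeholders; `MISTRAL_AI_KEY` is the variable read by the config change below.

```
# Illustrative values only — the keys are placeholders, not real credentials.
MISTRAL_AI_KEY=key-one,key-two
ALLOWED_MODEL_FAMILIES=turbo,gpt4,claude,mistral-tiny,mistral-small,mistral-medium
```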
+9
@@ -26,6 +26,10 @@ type Config = {
 * same but the APIs are different. Vertex is the GCP product for enterprise.
 **/
 googleAIKey?: string;
+/**
+* Comma-delimited list of Mistral AI API keys.
+*/
+mistralAIKey?: string;
 /**
 * Comma-delimited list of AWS credentials. Each credential item should be a
 * colon-delimited list of access key, secret key, and AWS region.
@@ -203,6 +207,7 @@ export const config: Config = {
 openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
 anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
 googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
+mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
 awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
 azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
 proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -235,6 +240,9 @@ export const config: Config = {
 "gpt4-turbo",
 "claude",
 "gemini-pro",
+"mistral-tiny",
+"mistral-small",
+"mistral-medium",
 "aws-claude",
 "azure-turbo",
 "azure-gpt4",
@@ -372,6 +380,7 @@ export const OMITTED_KEYS = [
 "openaiKey",
 "anthropicKey",
 "googleAIKey",
+"mistralAIKey",
 "awsCredentials",
 "azureCredentials",
 "proxyKey",
+3
@@ -17,6 +17,9 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"dall-e": "DALL-E", "dall-e": "DALL-E",
"claude": "Claude", "claude": "Claude",
"gemini-pro": "Gemini Pro", "gemini-pro": "Gemini Pro",
"mistral-tiny": "Mistral 7B",
"mistral-small": "Mixtral 8x7B",
"mistral-medium": "Mistral prototype",
"aws-claude": "AWS Claude", "aws-claude": "AWS Claude",
"azure-turbo": "Azure GPT-3.5 Turbo", "azure-turbo": "Azure GPT-3.5 Turbo",
"azure-gpt4": "Azure GPT-4", "azure-gpt4": "Azure GPT-4",
+2
@@ -193,6 +193,7 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
 const format = req.outboundApi;
 switch (format) {
 case "openai":
+case "mistral-ai":
 return body.choices[0].message.content;
 case "openai-text":
 return body.choices[0].text;
@@ -222,6 +223,7 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
 switch (format) {
 case "openai":
 case "openai-text":
+case "mistral-ai":
 return body.model;
 case "openai-image":
 return req.body.model;
@@ -40,6 +40,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
 );
 case "google-ai":
 throw new Error("add-key should not be used for this model.");
+case "mistral-ai":
+throw new Error("Mistral AI should never be translated");
 case "openai-image":
 assignedKey = keyPool.get("dall-e-3");
 break;
@@ -69,6 +71,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
 if (key.organizationId) {
 proxyReq.setHeader("OpenAI-Organization", key.organizationId);
 }
+case "mistral-ai":
 proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
 break;
 case "azure":
@@ -1,7 +1,11 @@
 import { RequestPreprocessor } from "../index";
 import { countTokens } from "../../../../shared/tokenization";
 import { assertNever } from "../../../../shared/utils";
-import type { GoogleAIChatMessage, OpenAIChatMessage } from "./transform-outbound-payload";
+import type {
+GoogleAIChatMessage,
+MistralAIChatMessage,
+OpenAIChatMessage,
+} from "./transform-outbound-payload";
 /**
 * Given a request with an already-transformed body, counts the number of
@@ -36,6 +40,12 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
 result = await countTokens({ req, prompt, service });
 break;
 }
+case "mistral-ai": {
+req.outputTokens = req.body.max_tokens;
+const prompt: MistralAIChatMessage[] = req.body.messages;
+result = await countTokens({ req, prompt, service });
+break;
+}
 case "openai-image": {
 req.outputTokens = 1;
 result = await countTokens({ req, service });
@@ -3,7 +3,10 @@ import { config } from "../../../../config";
 import { assertNever } from "../../../../shared/utils";
 import { RequestPreprocessor } from "../index";
 import { UserInputError } from "../../../../shared/errors";
-import { OpenAIChatMessage } from "./transform-outbound-payload";
+import {
+MistralAIChatMessage,
+OpenAIChatMessage,
+} from "./transform-outbound-payload";
 const rejectedClients = new Map<string, number>();
@@ -53,8 +56,9 @@ function getPromptFromRequest(req: Request) {
 case "anthropic":
 return body.prompt;
 case "openai":
+case "mistral-ai":
 return body.messages
-.map((msg: OpenAIChatMessage) => {
+.map((msg: OpenAIChatMessage | MistralAIChatMessage) => {
 const text = Array.isArray(msg.content)
 ? msg.content
 .map((c) => {
@@ -155,12 +155,38 @@ export type GoogleAIChatMessage = z.infer<
 typeof GoogleAIV1GenerateContentSchema
 >["contents"][0];
+// https://docs.mistral.ai/api#operation/createChatCompletion
+const MistralAIV1ChatCompletionsSchema = z.object({
+model: z.string(),
+messages: z.array(
+z.object({
+role: z.enum(["system", "user", "assistant"]),
+content: z.string(),
+})
+),
+temperature: z.number().optional().default(0.7),
+top_p: z.number().optional().default(1),
+max_tokens: z.coerce
+.number()
+.int()
+.nullish()
+.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
+stream: z.boolean().optional().default(false),
+safe_mode: z.boolean().optional().default(false),
+random_seed: z.number().int().optional(),
+});
+export type MistralAIChatMessage = z.infer<
+typeof MistralAIV1ChatCompletionsSchema
+>["messages"][0];
 const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
 anthropic: AnthropicV1CompleteSchema,
 openai: OpenAIV1ChatCompletionSchema,
 "openai-text": OpenAIV1TextCompletionSchema,
 "openai-image": OpenAIV1ImagesGenerationSchema,
 "google-ai": GoogleAIV1GenerateContentSchema,
+"mistral-ai": MistralAIV1ChatCompletionsSchema,
 };
 /** Transforms an incoming request body to one that matches the target API. */
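To sketch what this validator yields: parsing a minimal body fills in the declared defaults and clamps `max_tokens`, assuming for illustration that `OPENAI_OUTPUT_MAX` is 1024.

```ts
// Hypothetical parse; OPENAI_OUTPUT_MAX = 1024 is assumed for the example.
const parsed = MistralAIV1ChatCompletionsSchema.parse({
  model: "mistral-small",
  messages: [{ role: "user", content: "Hello!" }],
  max_tokens: 99999,
});
// parsed.temperature === 0.7, parsed.top_p === 1, parsed.stream === false
// parsed.max_tokens === 1024, clamped by the .transform() above
```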
@@ -7,6 +7,7 @@ import { RequestPreprocessor } from "../index";
 const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
 const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
 const GOOGLE_AI_MAX_CONTEXT = 32000;
+const MISTRAL_AI_MAX_CONTEXT = 32768;
 /**
 * Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -34,6 +35,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
 case "google-ai":
 proxyMax = GOOGLE_AI_MAX_CONTEXT;
 break;
+case "mistral-ai":
+proxyMax = MISTRAL_AI_MAX_CONTEXT;
+break;
 case "openai-image":
 return;
 default:
@@ -64,6 +67,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
 modelMax = 200000;
 } else if (model.match(/^gemini-\d{3}$/)) {
 modelMax = GOOGLE_AI_MAX_CONTEXT;
+} else if (model.match(/^mistral-(tiny|small|medium)$/)) {
+modelMax = MISTRAL_AI_MAX_CONTEXT;
 } else if (model.match(/^anthropic\.claude/)) {
 // Not sure if AWS Claude has the same context limit as Anthropic Claude.
 modelMax = 100000;
+5
@@ -292,6 +292,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
 switch (service) {
 case "openai":
 case "google-ai":
+case "mistral-ai":
 case "azure":
 const filteredCodes = ["content_policy_violation", "content_filter"];
 if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -351,6 +352,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
 handleAwsRateLimitError(req, errorPayload);
 break;
 case "azure":
+case "mistral-ai":
 handleAzureRateLimitError(req, errorPayload);
 break;
 case "google-ai":
@@ -379,6 +381,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
 case "google-ai":
 errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
 break;
+case "mistral-ai":
+errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
+break;
 case "aws":
 errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
 break;
+7 -3
@@ -9,7 +9,10 @@ import {
 } from "../common";
 import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
-import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
+import {
+MistralAIChatMessage,
+OpenAIChatMessage,
+} from "../request/preprocessors/transform-outbound-payload";
 /** If prompt logging is enabled, enqueues the prompt for logging. */
 export const logPrompt: ProxyResHandlerWithBody = async (
@@ -54,12 +57,13 @@ type OaiImageResult = {
 const getPromptForRequest = (
 req: Request,
 responseBody: Record<string, any>
-): string | OpenAIChatMessage[] | OaiImageResult => {
+): string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult => {
 // Since the prompt logger only runs after the request has been proxied, we
 // can assume the body has already been transformed to the target API's
 // format.
 switch (req.outboundApi) {
 case "openai":
+case "mistral-ai":
 return req.body.messages;
 case "openai-text":
 return req.body.prompt;
@@ -81,7 +85,7 @@ const getPromptForRequest = (
 };
 const flattenMessages = (
-val: string | OpenAIChatMessage[] | OaiImageResult
+val: string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult
 ): string => {
 if (typeof val === "string") {
 return val.trim();
@@ -4,7 +4,7 @@ import {
 mergeEventsForAnthropic,
 mergeEventsForOpenAIChat,
 mergeEventsForOpenAIText,
-OpenAIChatCompletionStreamEvent
+OpenAIChatCompletionStreamEvent,
 } from "./index";
 /**
@@ -28,6 +28,7 @@ export class EventAggregator {
 switch (this.format) {
 case "openai":
 case "google-ai":
+case "mistral-ai":
 return mergeEventsForOpenAIChat(this.events);
 case "openai-text":
 return mergeEventsForOpenAIText(this.events);
@@ -106,6 +106,7 @@ function getTransformer(
 ): StreamingCompletionTransformer {
 switch (responseApi) {
 case "openai":
+case "mistral-ai":
 return passthroughToOpenAI;
 case "openai-text":
 return openAITextToOpenAIChat;
+116
@@ -0,0 +1,116 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
getMistralAIModelFamily,
MistralAIModelFamily,
ModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
// https://docs.mistral.ai/platform/endpoints
export const KNOWN_MISTRAL_AI_MODELS = [
"mistral-tiny",
"mistral-small",
"mistral-medium",
];
let modelsCache: any = null;
let modelsCacheTime = 0;
export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
let available = new Set<MistralAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "mistral-ai") continue;
key.modelFamilies.forEach((family) =>
available.add(family as MistralAIModelFamily)
);
}
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
return models
.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "mistral-ai",
}))
.filter((model) => available.has(getMistralAIModelFamily(model.id)));
}
const handleModelRequest: RequestHandler = (_req, res) => {
// Serve the cached list if it is less than a minute old. The response must
// be sent explicitly; returning the cache object from an Express handler
// would leave the request hanging.
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
res.status(200).json(modelsCache);
return;
}
const result = generateModelList();
modelsCache = { object: "list", data: result };
modelsCacheTime = new Date().getTime();
res.status(200).json(modelsCache);
};
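For reference, the cached payload this handler serves would have roughly the following shape; the timestamp is invented, and note that `created` is in milliseconds here because it comes from `new Date().getTime()`.

```
{
  "object": "list",
  "data": [
    { "id": "mistral-tiny", "object": "model", "created": 1703527996000, "owned_by": "mistral-ai" },
    { "id": "mistral-small", "object": "model", "created": 1703527996000, "owned_by": "mistral-ai" },
    { "id": "mistral-medium", "object": "model", "created": 1703527996000, "owned_by": "mistral-ai" }
  ]
}
```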
const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
const mistralAIProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.mistral.ai",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
error: handleProxyError,
},
}),
});
const mistralAIRouter = Router();
mistralAIRouter.get("/v1/models", handleModelRequest);
// General chat completion endpoint.
mistralAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "mistral-ai",
outApi: "mistral-ai",
service: "mistral-ai",
}),
mistralAIProxy
);
export const mistralAI = mistralAIRouter;
+2
@@ -5,6 +5,7 @@ import { openai } from "./openai";
 import { openaiImage } from "./openai-image";
 import { anthropic } from "./anthropic";
 import { googleAI } from "./google-ai";
+import { mistralAI } from "./mistral-ai";
 import { aws } from "./aws";
 import { azure } from "./azure";
@@ -32,6 +33,7 @@ proxyRouter.use("/openai", addV1, openai);
 proxyRouter.use("/openai-image", addV1, openaiImage);
 proxyRouter.use("/anthropic", addV1, anthropic);
 proxyRouter.use("/google-ai", addV1, googleAI);
+proxyRouter.use("/mistral-ai", addV1, mistralAI);
 proxyRouter.use("/aws/claude", addV1, aws);
 proxyRouter.use("/azure/openai", addV1, azure);
 // Redirect browser requests to the homepage.
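A client-side sketch of exercising the new route, assuming `proxyRouter` is mounted at `/proxy`; the host, mount point, and token below are placeholders.

```ts
// Hypothetical client call; base URL, mount point, and token are placeholders.
const res = await fetch(
  "https://proxy.example.com/proxy/mistral-ai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <proxy-token>",
    },
    body: JSON.stringify({
      model: "mistral-small",
      messages: [{ role: "user", content: "Hello!" }],
    }),
  }
);
const completion = await res.json();
```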
+23 -1
@@ -16,6 +16,7 @@ import {
 GoogleAIModelFamily,
 LLM_SERVICES,
 LLMService,
+MistralAIModelFamily,
 MODEL_FAMILY_SERVICE,
 ModelFamily,
 OpenAIModelFamily,
@@ -24,6 +25,7 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
 import { getUniqueIps } from "./proxy/rate-limit";
 import { assertNever } from "./shared/utils";
 import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
+import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";
 const CACHE_TTL = 2000;
@@ -36,6 +38,8 @@ const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
 k.service === "anthropic";
 const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
 k.service === "google-ai";
+const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
+k.service === "mistral-ai";
 const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
 /** Stats aggregated across all keys for a given service. */
@@ -86,6 +90,7 @@ export type ServiceInfo = {
 "openai-image"?: string;
 anthropic?: string;
 "google-ai"?: string;
+"mistral-ai"?: string;
 aws?: string;
 azure?: string;
 };
@@ -99,7 +104,8 @@ export type ServiceInfo = {
 & { [f in AnthropicModelFamily]?: AnthropicInfo; }
 & { [f in AwsBedrockModelFamily]?: AwsInfo }
 & { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
-& { [f in GoogleAIModelFamily]?: BaseFamilyInfo };
+& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
+& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
 // https://stackoverflow.com/a/66661477
 // type DeepKeyOf<T> = (
@@ -128,6 +134,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
 "google-ai": {
 "google-ai": `%BASE%/google-ai`,
 },
+"mistral-ai": {
+"mistral-ai": `%BASE%/mistral-ai`,
+},
 aws: {
 aws: `%BASE%/aws/claude`,
 },
@@ -268,6 +277,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
 increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
 increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
 increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
+increment(serviceStats, "mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
 increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
 increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
@@ -331,6 +341,18 @@ function addKeyToAggregates(k: KeyPoolKey) {
 increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
 break;
 }
+case "mistral-ai": {
+if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
+k.modelFamilies.forEach((f) => {
+const tokens = k[`${f}Tokens`];
+sumTokens += tokens;
+sumCost += getTokenCostUsd(f, tokens);
+increment(modelStats, `${f}__tokens`, tokens);
+increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
+increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+});
+break;
+}
 case "aws": {
 if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
 const family = "aws-claude";
+1
@@ -11,6 +11,7 @@ export type APIFormat =
| "openai" | "openai"
| "anthropic" | "anthropic"
| "google-ai" | "google-ai"
| "mistral-ai"
| "openai-text" | "openai-text"
| "openai-image"; | "openai-image";
export type Model = export type Model =
+5
@@ -11,6 +11,7 @@ import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
 import { GoogleAIKeyProvider } from "./google-ai/provider";
 import { AwsBedrockKeyProvider } from "./aws/provider";
 import { AzureOpenAIKeyProvider } from "./azure/provider";
+import { MistralAIKeyProvider } from "./mistral-ai/provider";
 type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -24,6 +25,7 @@ export class KeyPool {
 this.keyProviders.push(new OpenAIKeyProvider());
 this.keyProviders.push(new AnthropicKeyProvider());
 this.keyProviders.push(new GoogleAIKeyProvider());
+this.keyProviders.push(new MistralAIKeyProvider());
 this.keyProviders.push(new AwsBedrockKeyProvider());
 this.keyProviders.push(new AzureOpenAIKeyProvider());
 }
@@ -121,6 +123,9 @@ export class KeyPool {
 } else if (model.includes("gemini")) {
 // https://developers.generativeai.google.com/models/language
 return "google-ai";
+} else if (model.includes("mistral")) {
+// https://docs.mistral.ai/platform/endpoints
+return "mistral-ai";
 } else if (model.startsWith("anthropic.claude")) {
 // AWS offers models from a few providers
 // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
@@ -0,0 +1,112 @@
import axios, { AxiosError } from "axios";
import type { MistralAIModelFamily, OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
import { getMistralAIModelFamily, getOpenAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const GET_MODELS_URL = "https://api.mistral.ai/v1/models";
type GetModelsResponse = {
data: [{ id: string }];
};
type MistralAIError = {
message: string;
request_id: string;
};
type UpdateFn = typeof MistralAIKeyProvider.prototype.update;
export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
constructor(keys: MistralAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "mistral-ai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: MistralAIKey) {
// We only need to check for provisioned models on the initial check.
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
const provisionedModels = await this.getProvisionedModels(key);
const updates = {
modelFamilies: provisionedModels,
};
this.updateKey(key.hash, updates);
}
this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
}
private async getProvisionedModels(
key: MistralAIKey
): Promise<MistralAIModelFamily[]> {
const opts = { headers: MistralAIKeyChecker.getHeaders(key) };
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
const models = data.data;
const families = new Set<MistralAIModelFamily>();
models.forEach(({ id }) => families.add(getMistralAIModelFamily(id)));
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.
const familiesArray = [...families];
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
this.updateKey(key.hash, {
modelFamilies: familiesArray,
lastChecked: keyFromPool.lastChecked,
});
return familiesArray;
}
protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
const { status, data } = error.response;
if (status === 401) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
modelFamilies: ["mistral-tiny"],
});
} else {
this.log.error(
{ key: key.hash, status, error: data },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
}
return;
}
this.log.error(
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
static errorIsMistralAIError(
error: AxiosError
): error is AxiosError<MistralAIError> {
const data = error.response?.data as any;
return !!(data?.message && data?.request_id);
}
static getHeaders(key: MistralAIKey) {
return {
Authorization: `Bearer ${key.key}`,
};
}
}
@@ -0,0 +1,210 @@
import crypto from "crypto";
import { Key, KeyProvider, Model } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";
export type MistralAIModel =
| "mistral-tiny"
| "mistral-small"
| "mistral-medium";
export type MistralAIKeyUpdate = Omit<
Partial<MistralAIKey>,
| "key"
| "hash"
| "lastUsed"
| "promptCount"
| "rateLimitedAt"
| "rateLimitedUntil"
>;
type MistralAIKeyUsage = {
[K in MistralAIModelFamily as `${K}Tokens`]: number;
};
export interface MistralAIKey extends Key, MistralAIKeyUsage {
readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
readonly service = "mistral-ai";
private keys: MistralAIKey[] = [];
private checker?: MistralAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.mistralAIKey?.trim();
if (!keyConfig) {
this.log.warn(
"MISTRAL_AI_KEY is not set. Mistral AI API will not be available."
);
return;
}
const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: MistralAIKey = {
key,
service: this.service,
modelFamilies: ["mistral-tiny", "mistral-small", "mistral-medium"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `mst-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"mistral-tinyTokens": 0,
"mistral-smallTokens": 0,
"mistral-mediumTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Mistral AI keys.");
}
public init() {
if (config.checkKeys) {
const updateFn = this.update.bind(this);
this.checker = new MistralAIKeyChecker(this.keys, updateFn);
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: Model) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
throw new Error("No Mistral AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: MistralAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<MistralAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
const family = getMistralAIModelFamily(model);
key[`${family}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve, so all we can do is wait a bit and try
* again. We will lock the key for 2 seconds after getting a 429 before
* retrying in order to give the other requests a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
}
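To make the selection order in `get()` concrete, here is a standalone sketch of the same comparator run over three fake keys (all timestamps invented):

```ts
// Standalone sketch of the key-selection comparator; values are invented.
const RATE_LIMIT_LOCKOUT_MS = 2000;
const now = 10_000;
const keys = [
  { hash: "mst-a", rateLimitedAt: 9_500, lastUsed: 8_000 }, // limited 500ms ago
  { hash: "mst-b", rateLimitedAt: 1_000, lastUsed: 9_000 }, // not recently limited
  { hash: "mst-c", rateLimitedAt: 1_000, lastUsed: 2_000 }, // idle the longest
];
keys.sort((a, b) => {
  const aLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT_MS;
  const bLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT_MS;
  if (aLimited && !bLimited) return 1;
  if (!aLimited && bLimited) return -1;
  if (aLimited && bLimited) return a.rateLimitedAt - b.rateLimitedAt;
  return a.lastUsed - b.lastUsed;
});
// Resulting order: mst-c (idle longest), mst-b, then mst-a (recently limited).
```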
+40 -2
@@ -8,7 +8,13 @@ import type { Request } from "express";
 * The service that a model is hosted on. Distinct from `APIFormat` because some
 * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
 */
-export type LLMService = "openai" | "anthropic" | "google-ai" | "aws" | "azure";
+export type LLMService =
+| "openai"
+| "anthropic"
+| "google-ai"
+| "mistral-ai"
+| "aws"
+| "azure";
 export type OpenAIModelFamily =
 | "turbo"
@@ -18,6 +24,10 @@ export type OpenAIModelFamily =
 | "dall-e";
 export type AnthropicModelFamily = "claude";
 export type GoogleAIModelFamily = "gemini-pro";
+export type MistralAIModelFamily =
+| "mistral-tiny"
+| "mistral-small"
+| "mistral-medium";
 export type AwsBedrockModelFamily = "aws-claude";
 export type AzureOpenAIModelFamily = `azure-${Exclude<
 OpenAIModelFamily,
@@ -27,6 +37,7 @@ export type ModelFamily =
 | OpenAIModelFamily
 | AnthropicModelFamily
 | GoogleAIModelFamily
+| MistralAIModelFamily
 | AwsBedrockModelFamily
 | AzureOpenAIModelFamily;
@@ -40,6 +51,9 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
 "dall-e",
 "claude",
 "gemini-pro",
+"mistral-tiny",
+"mistral-small",
+"mistral-medium",
 "aws-claude",
 "azure-turbo",
 "azure-gpt4",
@@ -49,7 +63,14 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
 export const LLM_SERVICES = (<A extends readonly LLMService[]>(
 arr: A & ([LLMService] extends [A[number]] ? unknown : never)
-) => arr)(["openai", "anthropic", "google-ai", "aws", "azure"] as const);
+) => arr)([
+"openai",
+"anthropic",
+"google-ai",
+"mistral-ai",
+"aws",
+"azure",
+] as const);
 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
 "^gpt-4-1106(-preview)?$": "gpt4-turbo",
@@ -78,6 +99,9 @@ export const MODEL_FAMILY_SERVICE: {
 "azure-gpt4-32k": "azure",
 "azure-gpt4-turbo": "azure",
 "gemini-pro": "google-ai",
+"mistral-tiny": "mistral-ai",
+"mistral-small": "mistral-ai",
+"mistral-medium": "mistral-ai",
 };
 pino({ level: "debug" }).child({ module: "startup" });
@@ -101,6 +125,17 @@ export function getGoogleAIModelFamily(_model: string): ModelFamily {
 return "gemini-pro";
 }
+export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
+switch (model) {
+case "mistral-tiny":
+case "mistral-small":
+case "mistral-medium":
+return model;
+default:
+return "mistral-tiny";
+}
+}
 export function getAwsBedrockModelFamily(_model: string): ModelFamily {
 return "aws-claude";
 }
@@ -158,6 +193,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
 case "google-ai":
 modelFamily = getGoogleAIModelFamily(model);
 break;
+case "mistral-ai":
+modelFamily = getMistralAIModelFamily(model);
+break;
 default:
 assertNever(req.outboundApi);
 }
+9
@@ -25,6 +25,15 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
case "claude": case "claude":
cost = 0.00001102; cost = 0.00001102;
break; break;
case "mistral-tiny":
cost = 0.00000031;
break;
case "mistral-small":
cost = 0.00000132;
break;
case "mistral-medium":
cost = 0.0000055;
break;
} }
return cost * Math.max(0, tokens); return cost * Math.max(0, tokens);
} }
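These are per-token USD rates, so a quick sanity check at one million tokens per family:

```ts
// Worked example of the rates above, per one million tokens:
getTokenCostUsd("mistral-tiny", 1_000_000);   // 1e6 * 0.00000031 = $0.31
getTokenCostUsd("mistral-small", 1_000_000);  // 1e6 * 0.00000132 = $1.32
getTokenCostUsd("mistral-medium", 1_000_000); // 1e6 * 0.0000055  = $5.50
```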
+1
@@ -64,6 +64,7 @@ export function makeCompletionSSE({
 switch (format) {
 case "openai":
+case "mistral-ai":
 event = {
 id: "chatcmpl-" + id,
 object: "chat.completion.chunk",
File diff suppressed because one or more lines are too long
+45
@@ -0,0 +1,45 @@
import { MistralAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload.js";
import * as tokenizer from "./mistral-tokenizer-js";
export function init() {
tokenizer.initializemistralTokenizer();
return true;
}
export function getTokenCount(prompt: MistralAIChatMessage[] | string) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const chunks: string[] = [];
for (const message of prompt) {
switch (message.role) {
case "system":
chunks.push(message.content);
break;
case "assistant":
chunks.push(message.content + "</s>");
break;
case "user":
chunks.push("[INST] " + message.content + " [/INST]");
break;
}
}
return getTextTokenCount(chunks.join(" "));
}
function getTextTokenCount(prompt: string) {
// Don't try tokenizing if the prompt is massive to prevent DoS.
// 500k characters should be sufficient for all supported models.
if (prompt.length > 500000) {
return {
tokenizer: "length fallback",
token_count: 100000,
};
}
return {
tokenizer: "mistral-tokenizer-js",
token_count: tokenizer.encode(prompt.normalize("NFKC"))!.length,
};
}
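For illustration, the flattening above turns a short conversation into a single string before tokenizing; the `[INST]`/`</s>` markers mirror Mistral's instruction format:

```ts
// Example input and the flattened string getTokenCount() would tokenize:
const messages: MistralAIChatMessage[] = [
  { role: "system", content: "Be brief." },
  { role: "user", content: "What is 2+2?" },
  { role: "assistant", content: "4" },
];
// The chunks are joined with a single space, yielding:
// "Be brief. [INST] What is 2+2? [/INST] 4</s>"
```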
+12
@@ -1,6 +1,7 @@
 import { Request } from "express";
 import type {
 GoogleAIChatMessage,
+MistralAIChatMessage,
 OpenAIChatMessage,
 } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
 import { assertNever } from "../utils";
@@ -14,11 +15,16 @@ import {
 getOpenAIImageCost,
 estimateGoogleAITokenCount,
 } from "./openai";
+import {
+init as initMistralAI,
+getTokenCount as getMistralAITokenCount,
+} from "./mistral";
 import { APIFormat } from "../key-management";
 export async function init() {
 initClaude();
 initOpenAi();
+initMistralAI();
 }
 /** Tagged union via `service` field of the different types of requests that can
@@ -31,6 +37,7 @@ type TokenCountRequest = { req: Request } & (
 service: "openai-text" | "anthropic" | "google-ai";
 }
 | { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
+| { prompt: MistralAIChatMessage[]; completion?: never; service: "mistral-ai" }
 | { prompt?: never; completion: string; service: APIFormat }
 | { prompt?: never; completion?: never; service: "openai-image" }
 );
@@ -77,6 +84,11 @@ export async function countTokens({
 ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
 tokenization_duration_ms: getElapsedMs(time),
 };
+case "mistral-ai":
+return {
+...getMistralAITokenCount(prompt ?? completion),
+tokenization_duration_ms: getElapsedMs(time),
+};
 default:
 assertNever(service);
 }
+6
@@ -15,6 +15,7 @@ import {
 getAzureOpenAIModelFamily,
 getClaudeModelFamily,
 getGoogleAIModelFamily,
+getMistralAIModelFamily,
 getOpenAIModelFamily,
 MODEL_FAMILIES,
 ModelFamily,
@@ -34,6 +35,9 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
 "dall-e": 0,
 claude: 0,
 "gemini-pro": 0,
+"mistral-tiny": 0,
+"mistral-small": 0,
+"mistral-medium": 0,
 "aws-claude": 0,
 "azure-turbo": 0,
 "azure-gpt4": 0,
@@ -399,6 +403,8 @@ function getModelFamilyForQuotaUsage(
 return getClaudeModelFamily(model);
 case "google-ai":
 return getGoogleAIModelFamily(model);
+case "mistral-ai":
+return getMistralAIModelFamily(model);
 default:
 assertNever(api);
 }