Mistral AI support (khanon/oai-reverse-proxy!58)
@@ -34,10 +34,10 @@
# Which model types users are allowed to access.
# The following model families are recognized:
-# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
+# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | mistral-tiny | mistral-small | mistral-medium | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list.
-# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
+# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,mistral-tiny,mistral-small,mistral-medium,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo

# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -26,6 +26,10 @@ type Config = {
   * same but the APIs are different. Vertex is the GCP product for enterprise.
   **/
  googleAIKey?: string;
+  /**
+   * Comma-delimited list of Mistral AI API keys.
+   */
+  mistralAIKey?: string;
  /**
   * Comma-delimited list of AWS credentials. Each credential item should be a
   * colon-delimited list of access key, secret key, and AWS region.
@@ -203,6 +207,7 @@ export const config: Config = {
  openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
  anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
  googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
+  mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
  awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
  azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
  proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -235,6 +240,9 @@ export const config: Config = {
  "gpt4-turbo",
  "claude",
  "gemini-pro",
+  "mistral-tiny",
+  "mistral-small",
+  "mistral-medium",
  "aws-claude",
  "azure-turbo",
  "azure-gpt4",
@@ -372,6 +380,7 @@ export const OMITTED_KEYS = [
  "openaiKey",
  "anthropicKey",
  "googleAIKey",
+  "mistralAIKey",
  "awsCredentials",
  "azureCredentials",
  "proxyKey",
@@ -17,6 +17,9 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
  "dall-e": "DALL-E",
  "claude": "Claude",
  "gemini-pro": "Gemini Pro",
+  "mistral-tiny": "Mistral 7B",
+  "mistral-small": "Mixtral 8x7B",
+  "mistral-medium": "Mistral prototype",
  "aws-claude": "AWS Claude",
  "azure-turbo": "Azure GPT-3.5 Turbo",
  "azure-gpt4": "Azure GPT-4",
@@ -193,6 +193,7 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
  const format = req.outboundApi;
  switch (format) {
    case "openai":
+    case "mistral-ai":
      return body.choices[0].message.content;
    case "openai-text":
      return body.choices[0].text;
@@ -222,6 +223,7 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
  switch (format) {
    case "openai":
    case "openai-text":
+    case "mistral-ai":
      return body.model;
    case "openai-image":
      return req.body.model;
@@ -40,6 +40,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
      );
    case "google-ai":
      throw new Error("add-key should not be used for this model.");
+    case "mistral-ai":
+      throw new Error("Mistral AI should never be translated");
    case "openai-image":
      assignedKey = keyPool.get("dall-e-3");
      break;
@@ -69,6 +71,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
      if (key.organizationId) {
        proxyReq.setHeader("OpenAI-Organization", key.organizationId);
      }
+    case "mistral-ai":
      proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
      break;
    case "azure":
@@ -1,7 +1,11 @@
import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
-import type { GoogleAIChatMessage, OpenAIChatMessage } from "./transform-outbound-payload";
+import type {
+  GoogleAIChatMessage,
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "./transform-outbound-payload";

/**
 * Given a request with an already-transformed body, counts the number of
@@ -36,6 +40,12 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
      result = await countTokens({ req, prompt, service });
      break;
    }
+    case "mistral-ai": {
+      req.outputTokens = req.body.max_tokens;
+      const prompt: MistralAIChatMessage[] = req.body.messages;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
    case "openai-image": {
      req.outputTokens = 1;
      result = await countTokens({ req, service });
@@ -3,7 +3,10 @@ import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { UserInputError } from "../../../../shared/errors";
-import { OpenAIChatMessage } from "./transform-outbound-payload";
+import {
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "./transform-outbound-payload";

const rejectedClients = new Map<string, number>();

@@ -53,8 +56,9 @@ function getPromptFromRequest(req: Request) {
    case "anthropic":
      return body.prompt;
    case "openai":
+    case "mistral-ai":
      return body.messages
-        .map((msg: OpenAIChatMessage) => {
+        .map((msg: OpenAIChatMessage | MistralAIChatMessage) => {
          const text = Array.isArray(msg.content)
            ? msg.content
                .map((c) => {
@@ -155,12 +155,38 @@ export type GoogleAIChatMessage = z.infer<
  typeof GoogleAIV1GenerateContentSchema
>["contents"][0];

+// https://docs.mistral.ai/api#operation/createChatCompletion
+const MistralAIV1ChatCompletionsSchema = z.object({
+  model: z.string(),
+  messages: z.array(
+    z.object({
+      role: z.enum(["system", "user", "assistant"]),
+      content: z.string(),
+    })
+  ),
+  temperature: z.number().optional().default(0.7),
+  top_p: z.number().optional().default(1),
+  max_tokens: z.coerce
+    .number()
+    .int()
+    .nullish()
+    .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
+  stream: z.boolean().optional().default(false),
+  safe_mode: z.boolean().optional().default(false),
+  random_seed: z.number().int().optional(),
+});
+
+export type MistralAIChatMessage = z.infer<
+  typeof MistralAIV1ChatCompletionsSchema
+>["messages"][0];
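As an illustration, here is a request body that satisfies the new schema. The values are invented; parse fills in the declared defaults, and because max_tokens is omitted (nullish) the transform clamps it to OPENAI_OUTPUT_MAX:

const parsed = MistralAIV1ChatCompletionsSchema.parse({
  model: "mistral-small",
  messages: [
    { role: "system", content: "You are a helpful assistant." },
    { role: "user", content: "Hello!" },
  ],
});
// parsed.temperature === 0.7, parsed.top_p === 1, parsed.stream === false,
// parsed.safe_mode === false, parsed.max_tokens === OPENAI_OUTPUT_MAX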

const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
  anthropic: AnthropicV1CompleteSchema,
  openai: OpenAIV1ChatCompletionSchema,
  "openai-text": OpenAIV1TextCompletionSchema,
  "openai-image": OpenAIV1ImagesGenerationSchema,
  "google-ai": GoogleAIV1GenerateContentSchema,
+  "mistral-ai": MistralAIV1ChatCompletionsSchema,
};

/** Transforms an incoming request body to one that matches the target API. */
@@ -7,6 +7,7 @@ import { RequestPreprocessor } from "../index";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
const GOOGLE_AI_MAX_CONTEXT = 32000;
+const MISTRAL_AI_MAX_CONTEXT = 32768;

/**
 * Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -34,6 +35,9 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    case "google-ai":
      proxyMax = GOOGLE_AI_MAX_CONTEXT;
      break;
+    case "mistral-ai":
+      proxyMax = MISTRAL_AI_MAX_CONTEXT;
+      break;
    case "openai-image":
      return;
    default:
@@ -64,6 +67,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    modelMax = 200000;
  } else if (model.match(/^gemini-\d{3}$/)) {
    modelMax = GOOGLE_AI_MAX_CONTEXT;
+  } else if (model.match(/^mistral-(tiny|small|medium)$/)) {
+    modelMax = MISTRAL_AI_MAX_CONTEXT;
  } else if (model.match(/^anthropic\.claude/)) {
    // Not sure if AWS Claude has the same context limit as Anthropic Claude.
    modelMax = 100000;
@@ -292,6 +292,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  switch (service) {
    case "openai":
    case "google-ai":
+    case "mistral-ai":
    case "azure":
      const filteredCodes = ["content_policy_violation", "content_filter"];
      if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -351,6 +352,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
      handleAwsRateLimitError(req, errorPayload);
      break;
    case "azure":
+    case "mistral-ai":
      handleAzureRateLimitError(req, errorPayload);
      break;
    case "google-ai":
@@ -379,6 +381,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
    case "google-ai":
      errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
      break;
+    case "mistral-ai":
+      errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
+      break;
    case "aws":
      errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
      break;
@@ -9,7 +9,10 @@ import {
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
-import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
+import {
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "../request/preprocessors/transform-outbound-payload";

/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@@ -54,12 +57,13 @@ type OaiImageResult = {
const getPromptForRequest = (
  req: Request,
  responseBody: Record<string, any>
-): string | OpenAIChatMessage[] | OaiImageResult => {
+): string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult => {
  // Since the prompt logger only runs after the request has been proxied, we
  // can assume the body has already been transformed to the target API's
  // format.
  switch (req.outboundApi) {
    case "openai":
+    case "mistral-ai":
      return req.body.messages;
    case "openai-text":
      return req.body.prompt;
@@ -81,7 +85,7 @@ const getPromptForRequest = (
};

const flattenMessages = (
-  val: string | OpenAIChatMessage[] | OaiImageResult
+  val: string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult
): string => {
  if (typeof val === "string") {
    return val.trim();
@@ -4,7 +4,7 @@ import {
  mergeEventsForAnthropic,
  mergeEventsForOpenAIChat,
  mergeEventsForOpenAIText,
-  OpenAIChatCompletionStreamEvent
+  OpenAIChatCompletionStreamEvent,
} from "./index";

/**
@@ -28,6 +28,7 @@ export class EventAggregator {
    switch (this.format) {
      case "openai":
      case "google-ai":
+      case "mistral-ai":
        return mergeEventsForOpenAIChat(this.events);
      case "openai-text":
        return mergeEventsForOpenAIText(this.events);
@@ -106,6 +106,7 @@ function getTransformer(
): StreamingCompletionTransformer {
  switch (responseApi) {
    case "openai":
+    case "mistral-ai":
      return passthroughToOpenAI;
    case "openai-text":
      return openAITextToOpenAIChat;
@@ -0,0 +1,116 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
  getMistralAIModelFamily,
  MistralAIModelFamily,
  ModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  addKey,
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeBody,
} from "./middleware/request";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";

// https://docs.mistral.ai/platform/endpoints
export const KNOWN_MISTRAL_AI_MODELS = [
  "mistral-tiny",
  "mistral-small",
  "mistral-medium",
];

let modelsCache: any = null;
let modelsCacheTime = 0;

export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
  let available = new Set<MistralAIModelFamily>();
  for (const key of keyPool.list()) {
    if (key.isDisabled || key.service !== "mistral-ai") continue;
    key.modelFamilies.forEach((family) =>
      available.add(family as MistralAIModelFamily)
    );
  }
  const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
  available = new Set([...available].filter((x) => allowed.has(x)));

  return models
    .map((id) => ({
      id,
      object: "model",
      created: new Date().getTime(),
      owned_by: "mistral-ai",
    }))
    .filter((model) => available.has(getMistralAIModelFamily(model.id)));
}
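For reference, the models endpoint ultimately serves this list wrapped in an OpenAI-style envelope (see handleModelRequest below). The ids and timestamps here are illustrative; note that created comes from new Date().getTime() and is therefore in milliseconds:

const exampleModelList = {
  object: "list",
  data: [
    { id: "mistral-tiny", object: "model", created: 1704067200000, owned_by: "mistral-ai" },
    { id: "mistral-small", object: "model", created: 1704067200000, owned_by: "mistral-ai" },
  ],
};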

const handleModelRequest: RequestHandler = (_req, res) => {
  // Serve from cache if the list was generated within the last minute.
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    res.status(200).json(modelsCache);
    return;
  }
  const result = generateModelList();
  modelsCache = { object: "list", data: result };
  modelsCacheTime = new Date().getTime();
  res.status(200).json(modelsCache);
};

const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(body);
};

const mistralAIProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
    target: "https://api.mistral.ai",
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [addKey, finalizeBody],
      }),
      proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const mistralAIRouter = Router();
mistralAIRouter.get("/v1/models", handleModelRequest);
// General chat completion endpoint.
mistralAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({
    inApi: "mistral-ai",
    outApi: "mistral-ai",
    service: "mistral-ai",
  }),
  mistralAIProxy
);

export const mistralAI = mistralAIRouter;
@@ -5,6 +5,7 @@ import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googleAI } from "./google-ai";
+import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { azure } from "./azure";

@@ -32,6 +33,7 @@ proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-ai", addV1, googleAI);
+proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
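With the route registered, clients reach Mistral AI through the same gateway paths as the other services. A hypothetical end-to-end call, where the host and the /proxy mount prefix are assumptions rather than something this diff establishes:

const res = await fetch(
  "https://my-proxy.example.com/proxy/mistral-ai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <proxy key or user token>",
    },
    body: JSON.stringify({
      model: "mistral-tiny",
      messages: [{ role: "user", content: "Hello!" }],
    }),
  }
);
console.log((await res.json()).choices[0].message.content);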
@@ -16,6 +16,7 @@ import {
  GoogleAIModelFamily,
  LLM_SERVICES,
  LLMService,
+  MistralAIModelFamily,
  MODEL_FAMILY_SERVICE,
  ModelFamily,
  OpenAIModelFamily,
@@ -24,6 +25,7 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
+import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";

const CACHE_TTL = 2000;

@@ -36,6 +38,8 @@ const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
  k.service === "anthropic";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
  k.service === "google-ai";
+const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
+  k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";

/** Stats aggregated across all keys for a given service. */
@@ -86,6 +90,7 @@ export type ServiceInfo = {
  "openai-image"?: string;
  anthropic?: string;
  "google-ai"?: string;
+  "mistral-ai"?: string;
  aws?: string;
  azure?: string;
};
@@ -99,7 +104,8 @@ export type ServiceInfo = {
  & { [f in AnthropicModelFamily]?: AnthropicInfo; }
  & { [f in AwsBedrockModelFamily]?: AwsInfo }
  & { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
-  & { [f in GoogleAIModelFamily]?: BaseFamilyInfo };
+  & { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
+  & { [f in MistralAIModelFamily]?: BaseFamilyInfo };

// https://stackoverflow.com/a/66661477
// type DeepKeyOf<T> = (
@@ -128,6 +134,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
  "google-ai": {
    "google-ai": `%BASE%/google-ai`,
  },
+  "mistral-ai": {
+    "mistral-ai": `%BASE%/mistral-ai`,
+  },
  aws: {
    aws: `%BASE%/aws/claude`,
  },
@@ -268,6 +277,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
  increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
  increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
  increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
+  increment(serviceStats, "mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
  increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
  increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
@@ -331,6 +341,18 @@ function addKeyToAggregates(k: KeyPoolKey) {
      increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
      break;
    }
+    case "mistral-ai": {
+      if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
+      k.modelFamilies.forEach((f) => {
+        const tokens = k[`${f}Tokens`];
+        sumTokens += tokens;
+        sumCost += getTokenCostUsd(f, tokens);
+        increment(modelStats, `${f}__tokens`, tokens);
+        increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
+        increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
+      });
+      break;
+    }
    case "aws": {
      if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
      const family = "aws-claude";
@@ -11,6 +11,7 @@ export type APIFormat =
  | "openai"
  | "anthropic"
  | "google-ai"
+  | "mistral-ai"
  | "openai-text"
  | "openai-image";
export type Model =
@@ -11,6 +11,7 @@ import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { AzureOpenAIKeyProvider } from "./azure/provider";
+import { MistralAIKeyProvider } from "./mistral-ai/provider";

type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;

@@ -24,6 +25,7 @@ export class KeyPool {
    this.keyProviders.push(new OpenAIKeyProvider());
    this.keyProviders.push(new AnthropicKeyProvider());
    this.keyProviders.push(new GoogleAIKeyProvider());
+    this.keyProviders.push(new MistralAIKeyProvider());
    this.keyProviders.push(new AwsBedrockKeyProvider());
    this.keyProviders.push(new AzureOpenAIKeyProvider());
  }
@@ -121,6 +123,9 @@ export class KeyPool {
    } else if (model.includes("gemini")) {
      // https://developers.generativeai.google.com/models/language
      return "google-ai";
+    } else if (model.includes("mistral")) {
+      // https://docs.mistral.ai/platform/endpoints
+      return "mistral-ai";
    } else if (model.startsWith("anthropic.claude")) {
      // AWS offers models from a few providers
      // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
@@ -0,0 +1,112 @@
import axios, { AxiosError } from "axios";
import type { MistralAIModelFamily, OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
import { getMistralAIModelFamily, getOpenAIModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const GET_MODELS_URL = "https://api.mistral.ai/v1/models";

type GetModelsResponse = {
  data: [{ id: string }];
};

type MistralAIError = {
  message: string;
  request_id: string;
};

type UpdateFn = typeof MistralAIKeyProvider.prototype.update;

export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
  constructor(keys: MistralAIKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "mistral-ai",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      recurringChecksEnabled: false,
      updateKey,
    });
  }

  protected async testKeyOrFail(key: MistralAIKey) {
    // We only need to check for provisioned models on the initial check.
    const isInitialCheck = !key.lastChecked;
    if (isInitialCheck) {
      const provisionedModels = await this.getProvisionedModels(key);
      const updates = {
        modelFamilies: provisionedModels,
      };
      this.updateKey(key.hash, updates);
    }
    this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
  }

  private async getProvisionedModels(
    key: MistralAIKey
  ): Promise<MistralAIModelFamily[]> {
    const opts = { headers: MistralAIKeyChecker.getHeaders(key) };
    const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
    const models = data.data;

    const families = new Set<MistralAIModelFamily>();
    models.forEach(({ id }) => families.add(getMistralAIModelFamily(id)));

    // We want to update the key's model families here, but we don't want to
    // update its `lastChecked` timestamp because we need to let the liveness
    // check run before we can consider the key checked.

    const familiesArray = [...families];
    const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
    this.updateKey(key.hash, {
      modelFamilies: familiesArray,
      lastChecked: keyFromPool.lastChecked,
    });
    return familiesArray;
  }

  protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
    if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
      const { status, data } = error.response;
      if (status === 401) {
        this.log.warn(
          { key: key.hash, error: data },
          "Key is invalid or revoked. Disabling key."
        );
        this.updateKey(key.hash, {
          isDisabled: true,
          isRevoked: true,
          modelFamilies: ["mistral-tiny"],
        });
      } else {
        this.log.error(
          { key: key.hash, status, error: data },
          "Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
        );
        this.updateKey(key.hash, { lastChecked: Date.now() });
      }
      return;
    }
    this.log.error(
      { key: key.hash, error: error.message },
      "Network error while checking key; trying this key again in a minute."
    );
    const oneMinute = 60 * 1000;
    const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
    this.updateKey(key.hash, { lastChecked: next });
  }

  static errorIsMistralAIError(
    error: AxiosError
  ): error is AxiosError<MistralAIError> {
    const data = error.response?.data as any;
    return data?.message && data?.request_id;
  }

  static getHeaders(key: MistralAIKey) {
    return {
      Authorization: `Bearer ${key.key}`,
    };
  }
}
@@ -0,0 +1,210 @@
import crypto from "crypto";
import { Key, KeyProvider, Model } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";

export type MistralAIModel =
  | "mistral-tiny"
  | "mistral-small"
  | "mistral-medium";

export type MistralAIKeyUpdate = Omit<
  Partial<MistralAIKey>,
  | "key"
  | "hash"
  | "lastUsed"
  | "promptCount"
  | "rateLimitedAt"
  | "rateLimitedUntil"
>;

type MistralAIKeyUsage = {
  [K in MistralAIModelFamily as `${K}Tokens`]: number;
};
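The mapped type above derives one token counter per model family. Written out by hand, it is equivalent to:

type MistralAIKeyUsageExpanded = {
  "mistral-tinyTokens": number;
  "mistral-smallTokens": number;
  "mistral-mediumTokens": number;
};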

export interface MistralAIKey extends Key, MistralAIKeyUsage {
  readonly service: "mistral-ai";
  readonly modelFamilies: MistralAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 2000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 500;

export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
  readonly service = "mistral-ai";

  private keys: MistralAIKey[] = [];
  private checker?: MistralAIKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.mistralAIKey?.trim();
    if (!keyConfig) {
      this.log.warn(
        "MISTRAL_AI_KEY is not set. Mistral AI API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: MistralAIKey = {
        key,
        service: this.service,
        modelFamilies: ["mistral-tiny", "mistral-small", "mistral-medium"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        hash: `mst-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        "mistral-tinyTokens": 0,
        "mistral-smallTokens": 0,
        "mistral-mediumTokens": 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded Mistral AI keys.");
  }

  public init() {
    if (config.checkKeys) {
      const updateFn = this.update.bind(this);
      this.checker = new MistralAIKeyChecker(this.keys, updateFn);
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(_model: Model) {
    const availableKeys = this.keys.filter((k) => !k.isDisabled);
    if (availableKeys.length === 0) {
      throw new Error("No Mistral AI keys available");
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }
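A standalone sketch of the ordering this comparator produces, using invented timestamps (the constants mirror those defined above):

const now = 10_000;
const RATE_LIMIT_LOCKOUT = 2_000;
const keys = [
  { hash: "A", rateLimitedAt: 9_500, lastUsed: 1_000 }, // rate limited 500ms ago
  { hash: "B", rateLimitedAt: 3_000, lastUsed: 9_000 }, // lockout long expired
];
const sorted = [...keys].sort((a, b) => {
  const aLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
  const bLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
  if (aLimited && !bLimited) return 1;
  if (!aLimited && bLimited) return -1;
  if (aLimited && bLimited) return a.rateLimitedAt - b.rateLimitedAt;
  return a.lastUsed - b.lastUsed;
});
console.log(sorted[0].hash); // "B": a key outside its lockout window wins even if used more recently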

  public disable(key: MistralAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<MistralAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    const family = getMistralAIModelFamily(model);
    key[`${family}Tokens`] += tokens;
  }

  public getLockoutPeriod() {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return the time until the first key is
    // ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * This is called when we receive a 429, which means there are already five
   * concurrent requests running on this key. We don't have any information on
   * when these requests will resolve, so all we can do is wait a bit and try
   * again. We will lock the key for 2 seconds after getting a 429 before
   * retrying in order to give the other requests a chance to finish.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  public recheck() {}

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
@@ -8,7 +8,13 @@ import type { Request } from "express";
 * The service that a model is hosted on. Distinct from `APIFormat` because some
 * services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
 */
-export type LLMService = "openai" | "anthropic" | "google-ai" | "aws" | "azure";
+export type LLMService =
+  | "openai"
+  | "anthropic"
+  | "google-ai"
+  | "mistral-ai"
+  | "aws"
+  | "azure";

export type OpenAIModelFamily =
  | "turbo"
@@ -18,6 +24,10 @@ export type OpenAIModelFamily =
  | "dall-e";
export type AnthropicModelFamily = "claude";
export type GoogleAIModelFamily = "gemini-pro";
+export type MistralAIModelFamily =
+  | "mistral-tiny"
+  | "mistral-small"
+  | "mistral-medium";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
  OpenAIModelFamily,
@@ -27,6 +37,7 @@ export type ModelFamily =
  | OpenAIModelFamily
  | AnthropicModelFamily
  | GoogleAIModelFamily
+  | MistralAIModelFamily
  | AwsBedrockModelFamily
  | AzureOpenAIModelFamily;

@@ -40,6 +51,9 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
  "dall-e",
  "claude",
  "gemini-pro",
+  "mistral-tiny",
+  "mistral-small",
+  "mistral-medium",
  "aws-claude",
  "azure-turbo",
  "azure-gpt4",
@@ -49,7 +63,14 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(

export const LLM_SERVICES = (<A extends readonly LLMService[]>(
  arr: A & ([LLMService] extends [A[number]] ? unknown : never)
-) => arr)(["openai", "anthropic", "google-ai", "aws", "azure"] as const);
+) => arr)([
+  "openai",
+  "anthropic",
+  "google-ai",
+  "mistral-ai",
+  "aws",
+  "azure",
+] as const);

export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
  "^gpt-4-1106(-preview)?$": "gpt4-turbo",
@@ -78,6 +99,9 @@ export const MODEL_FAMILY_SERVICE: {
  "azure-gpt4-32k": "azure",
  "azure-gpt4-turbo": "azure",
  "gemini-pro": "google-ai",
+  "mistral-tiny": "mistral-ai",
+  "mistral-small": "mistral-ai",
+  "mistral-medium": "mistral-ai",
};

pino({ level: "debug" }).child({ module: "startup" });
@@ -101,6 +125,17 @@ export function getGoogleAIModelFamily(_model: string): ModelFamily {
  return "gemini-pro";
}

+export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
+  switch (model) {
+    case "mistral-tiny":
+    case "mistral-small":
+    case "mistral-medium":
+      return model;
+    default:
+      return "mistral-tiny";
+  }
+}
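The mapper's behavior in brief; "mistral-embed" here just stands in for any id outside the three known families:

getMistralAIModelFamily("mistral-medium"); // "mistral-medium"
getMistralAIModelFamily("mistral-embed");  // "mistral-tiny" (unknown ids fall back)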

export function getAwsBedrockModelFamily(_model: string): ModelFamily {
  return "aws-claude";
}
@@ -158,6 +193,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
    case "google-ai":
      modelFamily = getGoogleAIModelFamily(model);
      break;
+    case "mistral-ai":
+      modelFamily = getMistralAIModelFamily(model);
+      break;
    default:
      assertNever(req.outboundApi);
  }
@@ -25,6 +25,15 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
    case "claude":
      cost = 0.00001102;
      break;
+    case "mistral-tiny":
+      cost = 0.00000031;
+      break;
+    case "mistral-small":
+      cost = 0.00000132;
+      break;
+    case "mistral-medium":
+      cost = 0.0000055;
+      break;
  }
  return cost * Math.max(0, tokens);
}
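A quick sanity check on the rates above, assuming cost is assigned by the switch as shown:

getTokenCostUsd("mistral-medium", 1_000_000); // 0.0000055 * 1,000,000 = 5.5 USD
getTokenCostUsd("mistral-tiny", 1_000_000);   // 0.00000031 * 1,000,000 = 0.31 USD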
@@ -64,6 +64,7 @@ export function makeCompletionSSE({

  switch (format) {
    case "openai":
+    case "mistral-ai":
      event = {
        id: "chatcmpl-" + id,
        object: "chat.completion.chunk",
File diff suppressed because one or more lines are too long
@@ -0,0 +1,45 @@
import { MistralAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload.js";
import * as tokenizer from "./mistral-tokenizer-js";

export function init() {
  tokenizer.initializemistralTokenizer();
  return true;
}

export function getTokenCount(prompt: MistralAIChatMessage[] | string) {
  if (typeof prompt === "string") {
    return getTextTokenCount(prompt);
  }

  let chunks = [];
  for (const message of prompt) {
    switch (message.role) {
      case "system":
        chunks.push(message.content);
        break;
      case "assistant":
        chunks.push(message.content + "</s>");
        break;
      case "user":
        chunks.push("[INST] " + message.content + " [/INST]");
        break;
    }
  }
  return getTextTokenCount(chunks.join(" "));
}
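To make the flattening concrete, here is how a small, invented chat would be serialized before encoding:

const prompt: MistralAIChatMessage[] = [
  { role: "system", content: "Be concise." },
  { role: "user", content: "Hi" },
  { role: "assistant", content: "Hello!" },
];
// chunks.join(" ") yields:
// "Be concise. [INST] Hi [/INST] Hello!</s>"
getTokenCount(prompt); // { tokenizer: "mistral-tokenizer-js", token_count: ... }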

function getTextTokenCount(prompt: string) {
  // Don't try tokenizing if the prompt is massive to prevent DoS.
  // 500k characters should be sufficient for all supported models.
  if (prompt.length > 500000) {
    return {
      tokenizer: "length fallback",
      token_count: 100000,
    };
  }

  return {
    tokenizer: "mistral-tokenizer-js",
    token_count: tokenizer.encode(prompt.normalize("NFKC"))!.length,
  };
}
@@ -1,6 +1,7 @@
import { Request } from "express";
import type {
  GoogleAIChatMessage,
+  MistralAIChatMessage,
  OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils";
@@ -14,11 +15,16 @@ import {
  getOpenAIImageCost,
  estimateGoogleAITokenCount,
} from "./openai";
+import {
+  init as initMistralAI,
+  getTokenCount as getMistralAITokenCount,
+} from "./mistral";
import { APIFormat } from "../key-management";

export async function init() {
  initClaude();
  initOpenAi();
+  initMistralAI();
}

/** Tagged union via `service` field of the different types of requests that can
@@ -31,6 +37,7 @@ type TokenCountRequest = { req: Request } & (
      service: "openai-text" | "anthropic" | "google-ai";
    }
  | { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
+  | { prompt: MistralAIChatMessage[]; completion?: never; service: "mistral-ai" }
  | { prompt?: never; completion: string; service: APIFormat }
  | { prompt?: never; completion?: never; service: "openai-image" }
);
@@ -77,6 +84,11 @@ export async function countTokens({
        ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
        tokenization_duration_ms: getElapsedMs(time),
      };
+    case "mistral-ai":
+      return {
+        ...getMistralAITokenCount(prompt ?? completion),
+        tokenization_duration_ms: getElapsedMs(time),
+      };
    default:
      assertNever(service);
  }
@@ -15,6 +15,7 @@ import {
  getAzureOpenAIModelFamily,
  getClaudeModelFamily,
  getGoogleAIModelFamily,
+  getMistralAIModelFamily,
  getOpenAIModelFamily,
  MODEL_FAMILIES,
  ModelFamily,
@@ -34,6 +35,9 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
  "dall-e": 0,
  claude: 0,
  "gemini-pro": 0,
+  "mistral-tiny": 0,
+  "mistral-small": 0,
+  "mistral-medium": 0,
  "aws-claude": 0,
  "azure-turbo": 0,
  "azure-gpt4": 0,
@@ -399,6 +403,8 @@ function getModelFamilyForQuotaUsage(
      return getClaudeModelFamily(model);
    case "google-ai":
      return getGoogleAIModelFamily(model);
+    case "mistral-ai":
+      return getMistralAIModelFamily(model);
    default:
      assertNever(api);
  }