Mistral AI support (khanon/oai-reverse-proxy!58)
This commit is contained in:
+2
-2
@@ -34,10 +34,10 @@
|
|||||||
|
|
||||||
# Which model types users are allowed to access.
|
# Which model types users are allowed to access.
|
||||||
# The following model families are recognized:
|
# The following model families are recognized:
|
||||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
|
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | mistral-tiny | mistral-small | mistral-medium | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
|
||||||
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
|
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
|
||||||
# generation, uncomment the line below and add 'dall-e' to the list.
|
# generation, uncomment the line below and add 'dall-e' to the list.
|
||||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
|
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,mistral-tiny,mistral-small,mistral-medium,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
|
||||||
|
|
||||||
# URLs from which requests will be blocked.
|
# URLs from which requests will be blocked.
|
||||||
# BLOCKED_ORIGINS=reddit.com,9gag.com
|
# BLOCKED_ORIGINS=reddit.com,9gag.com
|
||||||
|
|||||||
@@ -26,6 +26,10 @@ type Config = {
|
|||||||
* same but the APIs are different. Vertex is the GCP product for enterprise.
|
* same but the APIs are different. Vertex is the GCP product for enterprise.
|
||||||
**/
|
**/
|
||||||
googleAIKey?: string;
|
googleAIKey?: string;
|
||||||
|
/**
|
||||||
|
* Comma-delimited list of Mistral AI API keys.
|
||||||
|
*/
|
||||||
|
mistralAIKey?: string;
|
||||||
/**
|
/**
|
||||||
* Comma-delimited list of AWS credentials. Each credential item should be a
|
* Comma-delimited list of AWS credentials. Each credential item should be a
|
||||||
* colon-delimited list of access key, secret key, and AWS region.
|
* colon-delimited list of access key, secret key, and AWS region.
|
||||||
@@ -203,6 +207,7 @@ export const config: Config = {
|
|||||||
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
|
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
|
||||||
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
|
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
|
||||||
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
|
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
|
||||||
|
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
|
||||||
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
||||||
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
||||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||||
@@ -235,6 +240,9 @@ export const config: Config = {
|
|||||||
"gpt4-turbo",
|
"gpt4-turbo",
|
||||||
"claude",
|
"claude",
|
||||||
"gemini-pro",
|
"gemini-pro",
|
||||||
|
"mistral-tiny",
|
||||||
|
"mistral-small",
|
||||||
|
"mistral-medium",
|
||||||
"aws-claude",
|
"aws-claude",
|
||||||
"azure-turbo",
|
"azure-turbo",
|
||||||
"azure-gpt4",
|
"azure-gpt4",
|
||||||
@@ -372,6 +380,7 @@ export const OMITTED_KEYS = [
|
|||||||
"openaiKey",
|
"openaiKey",
|
||||||
"anthropicKey",
|
"anthropicKey",
|
||||||
"googleAIKey",
|
"googleAIKey",
|
||||||
|
"mistralAIKey",
|
||||||
"awsCredentials",
|
"awsCredentials",
|
||||||
"azureCredentials",
|
"azureCredentials",
|
||||||
"proxyKey",
|
"proxyKey",
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
|||||||
"dall-e": "DALL-E",
|
"dall-e": "DALL-E",
|
||||||
"claude": "Claude",
|
"claude": "Claude",
|
||||||
"gemini-pro": "Gemini Pro",
|
"gemini-pro": "Gemini Pro",
|
||||||
|
"mistral-tiny": "Mistral 7B",
|
||||||
|
"mistral-small": "Mixtral 8x7B",
|
||||||
|
"mistral-medium": "Mistral prototype",
|
||||||
"aws-claude": "AWS Claude",
|
"aws-claude": "AWS Claude",
|
||||||
"azure-turbo": "Azure GPT-3.5 Turbo",
|
"azure-turbo": "Azure GPT-3.5 Turbo",
|
||||||
"azure-gpt4": "Azure GPT-4",
|
"azure-gpt4": "Azure GPT-4",
|
||||||
|
|||||||
@@ -193,6 +193,7 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
|
|||||||
const format = req.outboundApi;
|
const format = req.outboundApi;
|
||||||
switch (format) {
|
switch (format) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
case "mistral-ai":
|
||||||
return body.choices[0].message.content;
|
return body.choices[0].message.content;
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
return body.choices[0].text;
|
return body.choices[0].text;
|
||||||
@@ -222,6 +223,7 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
|
|||||||
switch (format) {
|
switch (format) {
|
||||||
case "openai":
|
case "openai":
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
|
case "mistral-ai":
|
||||||
return body.model;
|
return body.model;
|
||||||
case "openai-image":
|
case "openai-image":
|
||||||
return req.body.model;
|
return req.body.model;
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
|||||||
);
|
);
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
throw new Error("add-key should not be used for this model.");
|
throw new Error("add-key should not be used for this model.");
|
||||||
|
case "mistral-ai":
|
||||||
|
throw new Error("Mistral AI should never be translated");
|
||||||
case "openai-image":
|
case "openai-image":
|
||||||
assignedKey = keyPool.get("dall-e-3");
|
assignedKey = keyPool.get("dall-e-3");
|
||||||
break;
|
break;
|
||||||
@@ -69,6 +71,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
|||||||
if (key.organizationId) {
|
if (key.organizationId) {
|
||||||
proxyReq.setHeader("OpenAI-Organization", key.organizationId);
|
proxyReq.setHeader("OpenAI-Organization", key.organizationId);
|
||||||
}
|
}
|
||||||
|
case "mistral-ai":
|
||||||
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
|
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
|
||||||
break;
|
break;
|
||||||
case "azure":
|
case "azure":
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
import { RequestPreprocessor } from "../index";
|
import { RequestPreprocessor } from "../index";
|
||||||
import { countTokens } from "../../../../shared/tokenization";
|
import { countTokens } from "../../../../shared/tokenization";
|
||||||
import { assertNever } from "../../../../shared/utils";
|
import { assertNever } from "../../../../shared/utils";
|
||||||
import type { GoogleAIChatMessage, OpenAIChatMessage } from "./transform-outbound-payload";
|
import type {
|
||||||
|
GoogleAIChatMessage,
|
||||||
|
MistralAIChatMessage,
|
||||||
|
OpenAIChatMessage,
|
||||||
|
} from "./transform-outbound-payload";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Given a request with an already-transformed body, counts the number of
|
* Given a request with an already-transformed body, counts the number of
|
||||||
@@ -36,6 +40,12 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
|||||||
result = await countTokens({ req, prompt, service });
|
result = await countTokens({ req, prompt, service });
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case "mistral-ai": {
|
||||||
|
req.outputTokens = req.body.max_tokens;
|
||||||
|
const prompt: MistralAIChatMessage[] = req.body.messages;
|
||||||
|
result = await countTokens({ req, prompt, service });
|
||||||
|
break;
|
||||||
|
}
|
||||||
case "openai-image": {
|
case "openai-image": {
|
||||||
req.outputTokens = 1;
|
req.outputTokens = 1;
|
||||||
result = await countTokens({ req, service });
|
result = await countTokens({ req, service });
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ import { config } from "../../../../config";
|
|||||||
import { assertNever } from "../../../../shared/utils";
|
import { assertNever } from "../../../../shared/utils";
|
||||||
import { RequestPreprocessor } from "../index";
|
import { RequestPreprocessor } from "../index";
|
||||||
import { UserInputError } from "../../../../shared/errors";
|
import { UserInputError } from "../../../../shared/errors";
|
||||||
import { OpenAIChatMessage } from "./transform-outbound-payload";
|
import {
|
||||||
|
MistralAIChatMessage,
|
||||||
|
OpenAIChatMessage,
|
||||||
|
} from "./transform-outbound-payload";
|
||||||
|
|
||||||
const rejectedClients = new Map<string, number>();
|
const rejectedClients = new Map<string, number>();
|
||||||
|
|
||||||
@@ -53,8 +56,9 @@ function getPromptFromRequest(req: Request) {
|
|||||||
case "anthropic":
|
case "anthropic":
|
||||||
return body.prompt;
|
return body.prompt;
|
||||||
case "openai":
|
case "openai":
|
||||||
|
case "mistral-ai":
|
||||||
return body.messages
|
return body.messages
|
||||||
.map((msg: OpenAIChatMessage) => {
|
.map((msg: OpenAIChatMessage | MistralAIChatMessage) => {
|
||||||
const text = Array.isArray(msg.content)
|
const text = Array.isArray(msg.content)
|
||||||
? msg.content
|
? msg.content
|
||||||
.map((c) => {
|
.map((c) => {
|
||||||
|
|||||||
@@ -155,12 +155,38 @@ export type GoogleAIChatMessage = z.infer<
|
|||||||
typeof GoogleAIV1GenerateContentSchema
|
typeof GoogleAIV1GenerateContentSchema
|
||||||
>["contents"][0];
|
>["contents"][0];
|
||||||
|
|
||||||
|
// https://docs.mistral.ai/api#operation/createChatCompletion
|
||||||
|
const MistralAIV1ChatCompletionsSchema = z.object({
|
||||||
|
model: z.string(),
|
||||||
|
messages: z.array(
|
||||||
|
z.object({
|
||||||
|
role: z.enum(["system", "user", "assistant"]),
|
||||||
|
content: z.string(),
|
||||||
|
})
|
||||||
|
),
|
||||||
|
temperature: z.number().optional().default(0.7),
|
||||||
|
top_p: z.number().optional().default(1),
|
||||||
|
max_tokens: z.coerce
|
||||||
|
.number()
|
||||||
|
.int()
|
||||||
|
.nullish()
|
||||||
|
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
|
||||||
|
stream: z.boolean().optional().default(false),
|
||||||
|
safe_mode: z.boolean().optional().default(false),
|
||||||
|
random_seed: z.number().int().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export type MistralAIChatMessage = z.infer<
|
||||||
|
typeof MistralAIV1ChatCompletionsSchema
|
||||||
|
>["messages"][0];
|
||||||
|
|
||||||
const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
||||||
anthropic: AnthropicV1CompleteSchema,
|
anthropic: AnthropicV1CompleteSchema,
|
||||||
openai: OpenAIV1ChatCompletionSchema,
|
openai: OpenAIV1ChatCompletionSchema,
|
||||||
"openai-text": OpenAIV1TextCompletionSchema,
|
"openai-text": OpenAIV1TextCompletionSchema,
|
||||||
"openai-image": OpenAIV1ImagesGenerationSchema,
|
"openai-image": OpenAIV1ImagesGenerationSchema,
|
||||||
"google-ai": GoogleAIV1GenerateContentSchema,
|
"google-ai": GoogleAIV1GenerateContentSchema,
|
||||||
|
"mistral-ai": MistralAIV1ChatCompletionsSchema,
|
||||||
};
|
};
|
||||||
|
|
||||||
/** Transforms an incoming request body to one that matches the target API. */
|
/** Transforms an incoming request body to one that matches the target API. */
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { RequestPreprocessor } from "../index";
|
|||||||
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
||||||
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
||||||
const GOOGLE_AI_MAX_CONTEXT = 32000;
|
const GOOGLE_AI_MAX_CONTEXT = 32000;
|
||||||
|
const MISTRAL_AI_MAX_CONTENT = 32768;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
|
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
|
||||||
@@ -34,6 +35,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
|||||||
case "google-ai":
|
case "google-ai":
|
||||||
proxyMax = GOOGLE_AI_MAX_CONTEXT;
|
proxyMax = GOOGLE_AI_MAX_CONTEXT;
|
||||||
break;
|
break;
|
||||||
|
case "mistral-ai":
|
||||||
|
proxyMax = MISTRAL_AI_MAX_CONTENT;
|
||||||
case "openai-image":
|
case "openai-image":
|
||||||
return;
|
return;
|
||||||
default:
|
default:
|
||||||
@@ -64,6 +67,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
|||||||
modelMax = 200000;
|
modelMax = 200000;
|
||||||
} else if (model.match(/^gemini-\d{3}$/)) {
|
} else if (model.match(/^gemini-\d{3}$/)) {
|
||||||
modelMax = GOOGLE_AI_MAX_CONTEXT;
|
modelMax = GOOGLE_AI_MAX_CONTEXT;
|
||||||
|
} else if (model.match(/^mistral-(tiny|small|medium)$/)) {
|
||||||
|
modelMax = MISTRAL_AI_MAX_CONTENT;
|
||||||
} else if (model.match(/^anthropic\.claude/)) {
|
} else if (model.match(/^anthropic\.claude/)) {
|
||||||
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
|
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
|
||||||
modelMax = 100000;
|
modelMax = 100000;
|
||||||
|
|||||||
@@ -292,6 +292,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
switch (service) {
|
switch (service) {
|
||||||
case "openai":
|
case "openai":
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
|
case "mistral-ai":
|
||||||
case "azure":
|
case "azure":
|
||||||
const filteredCodes = ["content_policy_violation", "content_filter"];
|
const filteredCodes = ["content_policy_violation", "content_filter"];
|
||||||
if (filteredCodes.includes(errorPayload.error?.code)) {
|
if (filteredCodes.includes(errorPayload.error?.code)) {
|
||||||
@@ -351,6 +352,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
handleAwsRateLimitError(req, errorPayload);
|
handleAwsRateLimitError(req, errorPayload);
|
||||||
break;
|
break;
|
||||||
case "azure":
|
case "azure":
|
||||||
|
case "mistral-ai":
|
||||||
handleAzureRateLimitError(req, errorPayload);
|
handleAzureRateLimitError(req, errorPayload);
|
||||||
break;
|
break;
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
@@ -379,6 +381,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
case "google-ai":
|
case "google-ai":
|
||||||
errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
|
errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
|
||||||
break;
|
break;
|
||||||
|
case "mistral-ai":
|
||||||
|
errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
|
||||||
|
break;
|
||||||
case "aws":
|
case "aws":
|
||||||
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -9,7 +9,10 @@ import {
|
|||||||
} from "../common";
|
} from "../common";
|
||||||
import { ProxyResHandlerWithBody } from ".";
|
import { ProxyResHandlerWithBody } from ".";
|
||||||
import { assertNever } from "../../../shared/utils";
|
import { assertNever } from "../../../shared/utils";
|
||||||
import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
|
import {
|
||||||
|
MistralAIChatMessage,
|
||||||
|
OpenAIChatMessage,
|
||||||
|
} from "../request/preprocessors/transform-outbound-payload";
|
||||||
|
|
||||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||||
export const logPrompt: ProxyResHandlerWithBody = async (
|
export const logPrompt: ProxyResHandlerWithBody = async (
|
||||||
@@ -54,12 +57,13 @@ type OaiImageResult = {
|
|||||||
const getPromptForRequest = (
|
const getPromptForRequest = (
|
||||||
req: Request,
|
req: Request,
|
||||||
responseBody: Record<string, any>
|
responseBody: Record<string, any>
|
||||||
): string | OpenAIChatMessage[] | OaiImageResult => {
|
): string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult => {
|
||||||
// Since the prompt logger only runs after the request has been proxied, we
|
// Since the prompt logger only runs after the request has been proxied, we
|
||||||
// can assume the body has already been transformed to the target API's
|
// can assume the body has already been transformed to the target API's
|
||||||
// format.
|
// format.
|
||||||
switch (req.outboundApi) {
|
switch (req.outboundApi) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
case "mistral-ai":
|
||||||
return req.body.messages;
|
return req.body.messages;
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
return req.body.prompt;
|
return req.body.prompt;
|
||||||
@@ -81,7 +85,7 @@ const getPromptForRequest = (
|
|||||||
};
|
};
|
||||||
|
|
||||||
const flattenMessages = (
|
const flattenMessages = (
|
||||||
val: string | OpenAIChatMessage[] | OaiImageResult
|
val: string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult
|
||||||
): string => {
|
): string => {
|
||||||
if (typeof val === "string") {
|
if (typeof val === "string") {
|
||||||
return val.trim();
|
return val.trim();
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import {
|
|||||||
mergeEventsForAnthropic,
|
mergeEventsForAnthropic,
|
||||||
mergeEventsForOpenAIChat,
|
mergeEventsForOpenAIChat,
|
||||||
mergeEventsForOpenAIText,
|
mergeEventsForOpenAIText,
|
||||||
OpenAIChatCompletionStreamEvent
|
OpenAIChatCompletionStreamEvent,
|
||||||
} from "./index";
|
} from "./index";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -28,6 +28,7 @@ export class EventAggregator {
|
|||||||
switch (this.format) {
|
switch (this.format) {
|
||||||
case "openai":
|
case "openai":
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
|
case "mistral-ai":
|
||||||
return mergeEventsForOpenAIChat(this.events);
|
return mergeEventsForOpenAIChat(this.events);
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
return mergeEventsForOpenAIText(this.events);
|
return mergeEventsForOpenAIText(this.events);
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ function getTransformer(
|
|||||||
): StreamingCompletionTransformer {
|
): StreamingCompletionTransformer {
|
||||||
switch (responseApi) {
|
switch (responseApi) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
case "mistral-ai":
|
||||||
return passthroughToOpenAI;
|
return passthroughToOpenAI;
|
||||||
case "openai-text":
|
case "openai-text":
|
||||||
return openAITextToOpenAIChat;
|
return openAITextToOpenAIChat;
|
||||||
|
|||||||
@@ -0,0 +1,116 @@
|
|||||||
|
import { RequestHandler, Router } from "express";
|
||||||
|
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||||
|
import { config } from "../config";
|
||||||
|
import { keyPool } from "../shared/key-management";
|
||||||
|
import {
|
||||||
|
getMistralAIModelFamily,
|
||||||
|
MistralAIModelFamily,
|
||||||
|
ModelFamily,
|
||||||
|
} from "../shared/models";
|
||||||
|
import { logger } from "../logger";
|
||||||
|
import { createQueueMiddleware } from "./queue";
|
||||||
|
import { ipLimiter } from "./rate-limit";
|
||||||
|
import { handleProxyError } from "./middleware/common";
|
||||||
|
import {
|
||||||
|
addKey,
|
||||||
|
createOnProxyReqHandler,
|
||||||
|
createPreprocessorMiddleware,
|
||||||
|
finalizeBody,
|
||||||
|
} from "./middleware/request";
|
||||||
|
import {
|
||||||
|
createOnProxyResHandler,
|
||||||
|
ProxyResHandlerWithBody,
|
||||||
|
} from "./middleware/response";
|
||||||
|
|
||||||
|
// https://docs.mistral.ai/platform/endpoints
|
||||||
|
export const KNOWN_MISTRAL_AI_MODELS = [
|
||||||
|
"mistral-tiny",
|
||||||
|
"mistral-small",
|
||||||
|
"mistral-medium",
|
||||||
|
];
|
||||||
|
|
||||||
|
let modelsCache: any = null;
|
||||||
|
let modelsCacheTime = 0;
|
||||||
|
|
||||||
|
export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
|
||||||
|
let available = new Set<MistralAIModelFamily>();
|
||||||
|
for (const key of keyPool.list()) {
|
||||||
|
if (key.isDisabled || key.service !== "mistral-ai") continue;
|
||||||
|
key.modelFamilies.forEach((family) =>
|
||||||
|
available.add(family as MistralAIModelFamily)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
|
||||||
|
available = new Set([...available].filter((x) => allowed.has(x)));
|
||||||
|
|
||||||
|
return models
|
||||||
|
.map((id) => ({
|
||||||
|
id,
|
||||||
|
object: "model",
|
||||||
|
created: new Date().getTime(),
|
||||||
|
owned_by: "mistral-ai",
|
||||||
|
}))
|
||||||
|
.filter((model) => available.has(getMistralAIModelFamily(model.id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||||
|
if (new Date().getTime() - modelsCacheTime < 1000 * 60) return modelsCache;
|
||||||
|
const result = generateModelList();
|
||||||
|
modelsCache = { object: "list", data: result };
|
||||||
|
modelsCacheTime = new Date().getTime();
|
||||||
|
res.status(200).json(modelsCache);
|
||||||
|
};
|
||||||
|
|
||||||
|
const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
|
||||||
|
_proxyRes,
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
body
|
||||||
|
) => {
|
||||||
|
if (typeof body !== "object") {
|
||||||
|
throw new Error("Expected body to be an object");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config.promptLogging) {
|
||||||
|
const host = req.get("host");
|
||||||
|
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.tokenizerInfo) {
|
||||||
|
body.proxy_tokenizer = req.tokenizerInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
res.status(200).json(body);
|
||||||
|
};
|
||||||
|
|
||||||
|
const mistralAIProxy = createQueueMiddleware({
|
||||||
|
proxyMiddleware: createProxyMiddleware({
|
||||||
|
target: "https://api.mistral.ai",
|
||||||
|
changeOrigin: true,
|
||||||
|
selfHandleResponse: true,
|
||||||
|
logger,
|
||||||
|
on: {
|
||||||
|
proxyReq: createOnProxyReqHandler({
|
||||||
|
pipeline: [addKey, finalizeBody],
|
||||||
|
}),
|
||||||
|
proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
|
||||||
|
error: handleProxyError,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const mistralAIRouter = Router();
|
||||||
|
mistralAIRouter.get("/v1/models", handleModelRequest);
|
||||||
|
// General chat completion endpoint.
|
||||||
|
mistralAIRouter.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
ipLimiter,
|
||||||
|
createPreprocessorMiddleware({
|
||||||
|
inApi: "mistral-ai",
|
||||||
|
outApi: "mistral-ai",
|
||||||
|
service: "mistral-ai",
|
||||||
|
}),
|
||||||
|
mistralAIProxy
|
||||||
|
);
|
||||||
|
|
||||||
|
export const mistralAI = mistralAIRouter;
|
||||||
@@ -5,6 +5,7 @@ import { openai } from "./openai";
|
|||||||
import { openaiImage } from "./openai-image";
|
import { openaiImage } from "./openai-image";
|
||||||
import { anthropic } from "./anthropic";
|
import { anthropic } from "./anthropic";
|
||||||
import { googleAI } from "./google-ai";
|
import { googleAI } from "./google-ai";
|
||||||
|
import { mistralAI } from "./mistral-ai";
|
||||||
import { aws } from "./aws";
|
import { aws } from "./aws";
|
||||||
import { azure } from "./azure";
|
import { azure } from "./azure";
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ proxyRouter.use("/openai", addV1, openai);
|
|||||||
proxyRouter.use("/openai-image", addV1, openaiImage);
|
proxyRouter.use("/openai-image", addV1, openaiImage);
|
||||||
proxyRouter.use("/anthropic", addV1, anthropic);
|
proxyRouter.use("/anthropic", addV1, anthropic);
|
||||||
proxyRouter.use("/google-ai", addV1, googleAI);
|
proxyRouter.use("/google-ai", addV1, googleAI);
|
||||||
|
proxyRouter.use("/mistral-ai", addV1, mistralAI);
|
||||||
proxyRouter.use("/aws/claude", addV1, aws);
|
proxyRouter.use("/aws/claude", addV1, aws);
|
||||||
proxyRouter.use("/azure/openai", addV1, azure);
|
proxyRouter.use("/azure/openai", addV1, azure);
|
||||||
// Redirect browser requests to the homepage.
|
// Redirect browser requests to the homepage.
|
||||||
|
|||||||
+23
-1
@@ -16,6 +16,7 @@ import {
|
|||||||
GoogleAIModelFamily,
|
GoogleAIModelFamily,
|
||||||
LLM_SERVICES,
|
LLM_SERVICES,
|
||||||
LLMService,
|
LLMService,
|
||||||
|
MistralAIModelFamily,
|
||||||
MODEL_FAMILY_SERVICE,
|
MODEL_FAMILY_SERVICE,
|
||||||
ModelFamily,
|
ModelFamily,
|
||||||
OpenAIModelFamily,
|
OpenAIModelFamily,
|
||||||
@@ -24,6 +25,7 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
|
|||||||
import { getUniqueIps } from "./proxy/rate-limit";
|
import { getUniqueIps } from "./proxy/rate-limit";
|
||||||
import { assertNever } from "./shared/utils";
|
import { assertNever } from "./shared/utils";
|
||||||
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
|
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
|
||||||
|
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";
|
||||||
|
|
||||||
const CACHE_TTL = 2000;
|
const CACHE_TTL = 2000;
|
||||||
|
|
||||||
@@ -36,6 +38,8 @@ const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
|
|||||||
k.service === "anthropic";
|
k.service === "anthropic";
|
||||||
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
|
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
|
||||||
k.service === "google-ai";
|
k.service === "google-ai";
|
||||||
|
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
|
||||||
|
k.service === "mistral-ai";
|
||||||
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
|
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
|
||||||
|
|
||||||
/** Stats aggregated across all keys for a given service. */
|
/** Stats aggregated across all keys for a given service. */
|
||||||
@@ -86,6 +90,7 @@ export type ServiceInfo = {
|
|||||||
"openai-image"?: string;
|
"openai-image"?: string;
|
||||||
anthropic?: string;
|
anthropic?: string;
|
||||||
"google-ai"?: string;
|
"google-ai"?: string;
|
||||||
|
"mistral-ai"?: string;
|
||||||
aws?: string;
|
aws?: string;
|
||||||
azure?: string;
|
azure?: string;
|
||||||
};
|
};
|
||||||
@@ -99,7 +104,8 @@ export type ServiceInfo = {
|
|||||||
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
|
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
|
||||||
& { [f in AwsBedrockModelFamily]?: AwsInfo }
|
& { [f in AwsBedrockModelFamily]?: AwsInfo }
|
||||||
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
|
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
|
||||||
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo };
|
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
|
||||||
|
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
|
||||||
|
|
||||||
// https://stackoverflow.com/a/66661477
|
// https://stackoverflow.com/a/66661477
|
||||||
// type DeepKeyOf<T> = (
|
// type DeepKeyOf<T> = (
|
||||||
@@ -128,6 +134,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
|
|||||||
"google-ai": {
|
"google-ai": {
|
||||||
"google-ai": `%BASE%/google-ai`,
|
"google-ai": `%BASE%/google-ai`,
|
||||||
},
|
},
|
||||||
|
"mistral-ai": {
|
||||||
|
"mistral-ai": `%BASE%/mistral-ai`,
|
||||||
|
},
|
||||||
aws: {
|
aws: {
|
||||||
aws: `%BASE%/aws/claude`,
|
aws: `%BASE%/aws/claude`,
|
||||||
},
|
},
|
||||||
@@ -268,6 +277,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||||||
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
|
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
|
||||||
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
|
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
|
||||||
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
|
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
|
||||||
|
increment(serviceStats, "mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
|
||||||
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
|
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
|
||||||
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
|
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
|
||||||
|
|
||||||
@@ -331,6 +341,18 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||||||
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
|
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case "mistral-ai": {
|
||||||
|
if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
|
||||||
|
k.modelFamilies.forEach((f) => {
|
||||||
|
const tokens = k[`${f}Tokens`];
|
||||||
|
sumTokens += tokens;
|
||||||
|
sumCost += getTokenCostUsd(f, tokens);
|
||||||
|
increment(modelStats, `${f}__tokens`, tokens);
|
||||||
|
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||||
|
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
}
|
||||||
case "aws": {
|
case "aws": {
|
||||||
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
|
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
|
||||||
const family = "aws-claude";
|
const family = "aws-claude";
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ export type APIFormat =
|
|||||||
| "openai"
|
| "openai"
|
||||||
| "anthropic"
|
| "anthropic"
|
||||||
| "google-ai"
|
| "google-ai"
|
||||||
|
| "mistral-ai"
|
||||||
| "openai-text"
|
| "openai-text"
|
||||||
| "openai-image";
|
| "openai-image";
|
||||||
export type Model =
|
export type Model =
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
|||||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||||
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
||||||
|
import { MistralAIKeyProvider } from "./mistral-ai/provider";
|
||||||
|
|
||||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
||||||
|
|
||||||
@@ -24,6 +25,7 @@ export class KeyPool {
|
|||||||
this.keyProviders.push(new OpenAIKeyProvider());
|
this.keyProviders.push(new OpenAIKeyProvider());
|
||||||
this.keyProviders.push(new AnthropicKeyProvider());
|
this.keyProviders.push(new AnthropicKeyProvider());
|
||||||
this.keyProviders.push(new GoogleAIKeyProvider());
|
this.keyProviders.push(new GoogleAIKeyProvider());
|
||||||
|
this.keyProviders.push(new MistralAIKeyProvider());
|
||||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||||
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
||||||
}
|
}
|
||||||
@@ -121,6 +123,9 @@ export class KeyPool {
|
|||||||
} else if (model.includes("gemini")) {
|
} else if (model.includes("gemini")) {
|
||||||
// https://developers.generativeai.google.com/models/language
|
// https://developers.generativeai.google.com/models/language
|
||||||
return "google-ai";
|
return "google-ai";
|
||||||
|
} else if (model.includes("mistral")) {
|
||||||
|
// https://docs.mistral.ai/platform/endpoints
|
||||||
|
return "mistral-ai";
|
||||||
} else if (model.startsWith("anthropic.claude")) {
|
} else if (model.startsWith("anthropic.claude")) {
|
||||||
// AWS offers models from a few providers
|
// AWS offers models from a few providers
|
||||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||||
|
|||||||
@@ -0,0 +1,112 @@
|
|||||||
|
import axios, { AxiosError } from "axios";
|
||||||
|
import type { MistralAIModelFamily, OpenAIModelFamily } from "../../models";
|
||||||
|
import { KeyCheckerBase } from "../key-checker-base";
|
||||||
|
import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
|
||||||
|
import { getMistralAIModelFamily, getOpenAIModelFamily } from "../../models";
|
||||||
|
|
||||||
|
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||||
|
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
|
||||||
|
const GET_MODELS_URL = "https://api.mistral.ai/v1/models";
|
||||||
|
|
||||||
|
type GetModelsResponse = {
|
||||||
|
data: [{ id: string }];
|
||||||
|
};
|
||||||
|
|
||||||
|
type MistralAIError = {
|
||||||
|
message: string;
|
||||||
|
request_id: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
type UpdateFn = typeof MistralAIKeyProvider.prototype.update;
|
||||||
|
|
||||||
|
export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
|
||||||
|
constructor(keys: MistralAIKey[], updateKey: UpdateFn) {
|
||||||
|
super(keys, {
|
||||||
|
service: "mistral-ai",
|
||||||
|
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||||
|
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||||
|
recurringChecksEnabled: false,
|
||||||
|
updateKey,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
protected async testKeyOrFail(key: MistralAIKey) {
|
||||||
|
// We only need to check for provisioned models on the initial check.
|
||||||
|
const isInitialCheck = !key.lastChecked;
|
||||||
|
if (isInitialCheck) {
|
||||||
|
const provisionedModels = await this.getProvisionedModels(key);
|
||||||
|
const updates = {
|
||||||
|
modelFamilies: provisionedModels,
|
||||||
|
};
|
||||||
|
this.updateKey(key.hash, updates);
|
||||||
|
}
|
||||||
|
this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private async getProvisionedModels(
|
||||||
|
key: MistralAIKey
|
||||||
|
): Promise<MistralAIModelFamily[]> {
|
||||||
|
const opts = { headers: MistralAIKeyChecker.getHeaders(key) };
|
||||||
|
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
|
||||||
|
const models = data.data;
|
||||||
|
|
||||||
|
const families = new Set<MistralAIModelFamily>();
|
||||||
|
models.forEach(({ id }) => families.add(getMistralAIModelFamily(id)));
|
||||||
|
|
||||||
|
// We want to update the key's model families here, but we don't want to
|
||||||
|
// update its `lastChecked` timestamp because we need to let the liveness
|
||||||
|
// check run before we can consider the key checked.
|
||||||
|
|
||||||
|
const familiesArray = [...families];
|
||||||
|
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
|
||||||
|
this.updateKey(key.hash, {
|
||||||
|
modelFamilies: familiesArray,
|
||||||
|
lastChecked: keyFromPool.lastChecked,
|
||||||
|
});
|
||||||
|
return familiesArray;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
|
||||||
|
if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
|
||||||
|
const { status, data } = error.response;
|
||||||
|
if (status === 401) {
|
||||||
|
this.log.warn(
|
||||||
|
{ key: key.hash, error: data },
|
||||||
|
"Key is invalid or revoked. Disabling key."
|
||||||
|
);
|
||||||
|
this.updateKey(key.hash, {
|
||||||
|
isDisabled: true,
|
||||||
|
isRevoked: true,
|
||||||
|
modelFamilies: ["mistral-tiny"],
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
this.log.error(
|
||||||
|
{ key: key.hash, status, error: data },
|
||||||
|
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
|
||||||
|
);
|
||||||
|
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
this.log.error(
|
||||||
|
{ key: key.hash, error: error.message },
|
||||||
|
"Network error while checking key; trying this key again in a minute."
|
||||||
|
);
|
||||||
|
const oneMinute = 60 * 1000;
|
||||||
|
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||||
|
this.updateKey(key.hash, { lastChecked: next });
|
||||||
|
}
|
||||||
|
|
||||||
|
static errorIsMistralAIError(
|
||||||
|
error: AxiosError
|
||||||
|
): error is AxiosError<MistralAIError> {
|
||||||
|
const data = error.response?.data as any;
|
||||||
|
return data?.message && data?.request_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
static getHeaders(key: MistralAIKey) {
|
||||||
|
return {
|
||||||
|
Authorization: `Bearer ${key.key}`,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import { Key, KeyProvider, Model } from "..";
|
||||||
|
import { config } from "../../../config";
|
||||||
|
import { logger } from "../../../logger";
|
||||||
|
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
|
||||||
|
import { MistralAIKeyChecker } from "./checker";
|
||||||
|
|
||||||
|
export type MistralAIModel =
|
||||||
|
| "mistral-tiny"
|
||||||
|
| "mistral-small"
|
||||||
|
| "mistral-medium";
|
||||||
|
|
||||||
|
export type MistralAIKeyUpdate = Omit<
|
||||||
|
Partial<MistralAIKey>,
|
||||||
|
| "key"
|
||||||
|
| "hash"
|
||||||
|
| "lastUsed"
|
||||||
|
| "promptCount"
|
||||||
|
| "rateLimitedAt"
|
||||||
|
| "rateLimitedUntil"
|
||||||
|
>;
|
||||||
|
|
||||||
|
type MistralAIKeyUsage = {
|
||||||
|
[K in MistralAIModelFamily as `${K}Tokens`]: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export interface MistralAIKey extends Key, MistralAIKeyUsage {
|
||||||
|
readonly service: "mistral-ai";
|
||||||
|
readonly modelFamilies: MistralAIModelFamily[];
|
||||||
|
/** The time at which this key was last rate limited. */
|
||||||
|
rateLimitedAt: number;
|
||||||
|
/** The time until which this key is rate limited. */
|
||||||
|
rateLimitedUntil: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Upon being rate limited, a key will be locked out for this many milliseconds
|
||||||
|
* while we wait for other concurrent requests to finish.
|
||||||
|
*/
|
||||||
|
const RATE_LIMIT_LOCKOUT = 2000;
|
||||||
|
/**
|
||||||
|
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
||||||
|
* to be used again. This is to prevent the queue from flooding a key with too
|
||||||
|
* many requests while we wait to learn whether previous ones succeeded.
|
||||||
|
*/
|
||||||
|
const KEY_REUSE_DELAY = 500;
|
||||||
|
|
||||||
|
export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
|
||||||
|
readonly service = "mistral-ai";
|
||||||
|
|
||||||
|
private keys: MistralAIKey[] = [];
|
||||||
|
private checker?: MistralAIKeyChecker;
|
||||||
|
private log = logger.child({ module: "key-provider", service: this.service });
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
const keyConfig = config.mistralAIKey?.trim();
|
||||||
|
if (!keyConfig) {
|
||||||
|
this.log.warn(
|
||||||
|
"MISTRAL_AI_KEY is not set. Mistral AI API will not be available."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let bareKeys: string[];
|
||||||
|
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
||||||
|
for (const key of bareKeys) {
|
||||||
|
const newKey: MistralAIKey = {
|
||||||
|
key,
|
||||||
|
service: this.service,
|
||||||
|
modelFamilies: ["mistral-tiny", "mistral-small", "mistral-medium"],
|
||||||
|
isDisabled: false,
|
||||||
|
isRevoked: false,
|
||||||
|
promptCount: 0,
|
||||||
|
lastUsed: 0,
|
||||||
|
rateLimitedAt: 0,
|
||||||
|
rateLimitedUntil: 0,
|
||||||
|
hash: `mst-${crypto
|
||||||
|
.createHash("sha256")
|
||||||
|
.update(key)
|
||||||
|
.digest("hex")
|
||||||
|
.slice(0, 8)}`,
|
||||||
|
lastChecked: 0,
|
||||||
|
"mistral-tinyTokens": 0,
|
||||||
|
"mistral-smallTokens": 0,
|
||||||
|
"mistral-mediumTokens": 0,
|
||||||
|
};
|
||||||
|
this.keys.push(newKey);
|
||||||
|
}
|
||||||
|
this.log.info({ keyCount: this.keys.length }, "Loaded Mistral AI keys.");
|
||||||
|
}
|
||||||
|
|
||||||
|
public init() {
|
||||||
|
if (config.checkKeys) {
|
||||||
|
const updateFn = this.update.bind(this);
|
||||||
|
this.checker = new MistralAIKeyChecker(this.keys, updateFn);
|
||||||
|
this.checker.start();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public list() {
|
||||||
|
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||||
|
}
|
||||||
|
|
||||||
|
public get(_model: Model) {
|
||||||
|
const availableKeys = this.keys.filter((k) => !k.isDisabled);
|
||||||
|
if (availableKeys.length === 0) {
|
||||||
|
throw new Error("No Mistral AI keys available");
|
||||||
|
}
|
||||||
|
|
||||||
|
// (largely copied from the OpenAI provider, without trial key support)
|
||||||
|
// Select a key, from highest priority to lowest priority:
|
||||||
|
// 1. Keys which are not rate limited
|
||||||
|
// a. If all keys were rate limited recently, select the least-recently
|
||||||
|
// rate limited key.
|
||||||
|
// 3. Keys which have not been used in the longest time
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
|
||||||
|
const keysByPriority = availableKeys.sort((a, b) => {
|
||||||
|
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||||
|
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||||
|
|
||||||
|
if (aRateLimited && !bRateLimited) return 1;
|
||||||
|
if (!aRateLimited && bRateLimited) return -1;
|
||||||
|
if (aRateLimited && bRateLimited) {
|
||||||
|
return a.rateLimitedAt - b.rateLimitedAt;
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.lastUsed - b.lastUsed;
|
||||||
|
});
|
||||||
|
|
||||||
|
const selectedKey = keysByPriority[0];
|
||||||
|
selectedKey.lastUsed = now;
|
||||||
|
this.throttle(selectedKey.hash);
|
||||||
|
return { ...selectedKey };
|
||||||
|
}
|
||||||
|
|
||||||
|
public disable(key: MistralAIKey) {
|
||||||
|
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
||||||
|
if (!keyFromPool || keyFromPool.isDisabled) return;
|
||||||
|
keyFromPool.isDisabled = true;
|
||||||
|
this.log.warn({ key: key.hash }, "Key disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
public update(hash: string, update: Partial<MistralAIKey>) {
|
||||||
|
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
||||||
|
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
||||||
|
}
|
||||||
|
|
||||||
|
public available() {
|
||||||
|
return this.keys.filter((k) => !k.isDisabled).length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public incrementUsage(hash: string, model: string, tokens: number) {
|
||||||
|
const key = this.keys.find((k) => k.hash === hash);
|
||||||
|
if (!key) return;
|
||||||
|
key.promptCount++;
|
||||||
|
const family = getMistralAIModelFamily(model);
|
||||||
|
key[`${family}Tokens`] += tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
public getLockoutPeriod() {
|
||||||
|
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||||
|
// Don't lock out if there are no keys available or the queue will stall.
|
||||||
|
// Just let it through so the add-key middleware can throw an error.
|
||||||
|
if (activeKeys.length === 0) return 0;
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||||
|
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||||
|
|
||||||
|
if (anyNotRateLimited) return 0;
|
||||||
|
|
||||||
|
// If all keys are rate-limited, return the time until the first key is
|
||||||
|
// ready.
|
||||||
|
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is called when we receive a 429, which means there are already five
|
||||||
|
* concurrent requests running on this key. We don't have any information on
|
||||||
|
* when these requests will resolve, so all we can do is wait a bit and try
|
||||||
|
* again. We will lock the key for 2 seconds after getting a 429 before
|
||||||
|
* retrying in order to give the other requests a chance to finish.
|
||||||
|
*/
|
||||||
|
public markRateLimited(keyHash: string) {
|
||||||
|
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||||
|
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||||
|
const now = Date.now();
|
||||||
|
key.rateLimitedAt = now;
|
||||||
|
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
|
||||||
|
}
|
||||||
|
|
||||||
|
public recheck() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies a short artificial delay to the key upon dequeueing, in order to
|
||||||
|
* prevent it from being immediately assigned to another request before the
|
||||||
|
* current one can be dispatched.
|
||||||
|
**/
|
||||||
|
private throttle(hash: string) {
|
||||||
|
const now = Date.now();
|
||||||
|
const key = this.keys.find((k) => k.hash === hash)!;
|
||||||
|
|
||||||
|
const currentRateLimit = key.rateLimitedUntil;
|
||||||
|
const nextRateLimit = now + KEY_REUSE_DELAY;
|
||||||
|
|
||||||
|
key.rateLimitedAt = now;
|
||||||
|
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
|
||||||
|
}
|
||||||
|
}
|
||||||
+40
-2
@@ -8,7 +8,13 @@ import type { Request } from "express";
|
|||||||
* The service that a model is hosted on. Distinct from `APIFormat` because some
|
* The service that a model is hosted on. Distinct from `APIFormat` because some
|
||||||
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
|
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
|
||||||
*/
|
*/
|
||||||
export type LLMService = "openai" | "anthropic" | "google-ai" | "aws" | "azure";
|
export type LLMService =
|
||||||
|
| "openai"
|
||||||
|
| "anthropic"
|
||||||
|
| "google-ai"
|
||||||
|
| "mistral-ai"
|
||||||
|
| "aws"
|
||||||
|
| "azure";
|
||||||
|
|
||||||
export type OpenAIModelFamily =
|
export type OpenAIModelFamily =
|
||||||
| "turbo"
|
| "turbo"
|
||||||
@@ -18,6 +24,10 @@ export type OpenAIModelFamily =
|
|||||||
| "dall-e";
|
| "dall-e";
|
||||||
export type AnthropicModelFamily = "claude";
|
export type AnthropicModelFamily = "claude";
|
||||||
export type GoogleAIModelFamily = "gemini-pro";
|
export type GoogleAIModelFamily = "gemini-pro";
|
||||||
|
export type MistralAIModelFamily =
|
||||||
|
| "mistral-tiny"
|
||||||
|
| "mistral-small"
|
||||||
|
| "mistral-medium";
|
||||||
export type AwsBedrockModelFamily = "aws-claude";
|
export type AwsBedrockModelFamily = "aws-claude";
|
||||||
export type AzureOpenAIModelFamily = `azure-${Exclude<
|
export type AzureOpenAIModelFamily = `azure-${Exclude<
|
||||||
OpenAIModelFamily,
|
OpenAIModelFamily,
|
||||||
@@ -27,6 +37,7 @@ export type ModelFamily =
|
|||||||
| OpenAIModelFamily
|
| OpenAIModelFamily
|
||||||
| AnthropicModelFamily
|
| AnthropicModelFamily
|
||||||
| GoogleAIModelFamily
|
| GoogleAIModelFamily
|
||||||
|
| MistralAIModelFamily
|
||||||
| AwsBedrockModelFamily
|
| AwsBedrockModelFamily
|
||||||
| AzureOpenAIModelFamily;
|
| AzureOpenAIModelFamily;
|
||||||
|
|
||||||
@@ -40,6 +51,9 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||||||
"dall-e",
|
"dall-e",
|
||||||
"claude",
|
"claude",
|
||||||
"gemini-pro",
|
"gemini-pro",
|
||||||
|
"mistral-tiny",
|
||||||
|
"mistral-small",
|
||||||
|
"mistral-medium",
|
||||||
"aws-claude",
|
"aws-claude",
|
||||||
"azure-turbo",
|
"azure-turbo",
|
||||||
"azure-gpt4",
|
"azure-gpt4",
|
||||||
@@ -49,7 +63,14 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||||||
|
|
||||||
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
||||||
arr: A & ([LLMService] extends [A[number]] ? unknown : never)
|
arr: A & ([LLMService] extends [A[number]] ? unknown : never)
|
||||||
) => arr)(["openai", "anthropic", "google-ai", "aws", "azure"] as const);
|
) => arr)([
|
||||||
|
"openai",
|
||||||
|
"anthropic",
|
||||||
|
"google-ai",
|
||||||
|
"mistral-ai",
|
||||||
|
"aws",
|
||||||
|
"azure",
|
||||||
|
] as const);
|
||||||
|
|
||||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||||
"^gpt-4-1106(-preview)?$": "gpt4-turbo",
|
"^gpt-4-1106(-preview)?$": "gpt4-turbo",
|
||||||
@@ -78,6 +99,9 @@ export const MODEL_FAMILY_SERVICE: {
|
|||||||
"azure-gpt4-32k": "azure",
|
"azure-gpt4-32k": "azure",
|
||||||
"azure-gpt4-turbo": "azure",
|
"azure-gpt4-turbo": "azure",
|
||||||
"gemini-pro": "google-ai",
|
"gemini-pro": "google-ai",
|
||||||
|
"mistral-tiny": "mistral-ai",
|
||||||
|
"mistral-small": "mistral-ai",
|
||||||
|
"mistral-medium": "mistral-ai",
|
||||||
};
|
};
|
||||||
|
|
||||||
pino({ level: "debug" }).child({ module: "startup" });
|
pino({ level: "debug" }).child({ module: "startup" });
|
||||||
@@ -101,6 +125,17 @@ export function getGoogleAIModelFamily(_model: string): ModelFamily {
|
|||||||
return "gemini-pro";
|
return "gemini-pro";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
|
||||||
|
switch (model) {
|
||||||
|
case "mistral-tiny":
|
||||||
|
case "mistral-small":
|
||||||
|
case "mistral-medium":
|
||||||
|
return model;
|
||||||
|
default:
|
||||||
|
return "mistral-tiny";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export function getAwsBedrockModelFamily(_model: string): ModelFamily {
|
export function getAwsBedrockModelFamily(_model: string): ModelFamily {
|
||||||
return "aws-claude";
|
return "aws-claude";
|
||||||
}
|
}
|
||||||
@@ -158,6 +193,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
|||||||
case "google-ai":
|
case "google-ai":
|
||||||
modelFamily = getGoogleAIModelFamily(model);
|
modelFamily = getGoogleAIModelFamily(model);
|
||||||
break;
|
break;
|
||||||
|
case "mistral-ai":
|
||||||
|
modelFamily = getMistralAIModelFamily(model);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
assertNever(req.outboundApi);
|
assertNever(req.outboundApi);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,15 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
|
|||||||
case "claude":
|
case "claude":
|
||||||
cost = 0.00001102;
|
cost = 0.00001102;
|
||||||
break;
|
break;
|
||||||
|
case "mistral-tiny":
|
||||||
|
cost = 0.00000031;
|
||||||
|
break;
|
||||||
|
case "mistral-small":
|
||||||
|
cost = 0.00000132;
|
||||||
|
break;
|
||||||
|
case "mistral-medium":
|
||||||
|
cost = 0.0000055;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
return cost * Math.max(0, tokens);
|
return cost * Math.max(0, tokens);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ export function makeCompletionSSE({
|
|||||||
|
|
||||||
switch (format) {
|
switch (format) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
case "mistral-ai":
|
||||||
event = {
|
event = {
|
||||||
id: "chatcmpl-" + id,
|
id: "chatcmpl-" + id,
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,45 @@
|
|||||||
|
import { MistralAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload.js";
|
||||||
|
import * as tokenizer from "./mistral-tokenizer-js";
|
||||||
|
|
||||||
|
export function init() {
|
||||||
|
tokenizer.initializemistralTokenizer();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getTokenCount(prompt: MistralAIChatMessage[] | string) {
|
||||||
|
if (typeof prompt === "string") {
|
||||||
|
return getTextTokenCount(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunks = [];
|
||||||
|
for (const message of prompt) {
|
||||||
|
switch (message.role) {
|
||||||
|
case "system":
|
||||||
|
chunks.push(message.content);
|
||||||
|
break;
|
||||||
|
case "assistant":
|
||||||
|
chunks.push(message.content + "</s>");
|
||||||
|
break;
|
||||||
|
case "user":
|
||||||
|
chunks.push("[INST] " + message.content + " [/INST]");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return getTextTokenCount(chunks.join(" "));
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTextTokenCount(prompt: string) {
|
||||||
|
// Don't try tokenizing if the prompt is massive to prevent DoS.
|
||||||
|
// 500k characters should be sufficient for all supported models.
|
||||||
|
if (prompt.length > 500000) {
|
||||||
|
return {
|
||||||
|
tokenizer: "length fallback",
|
||||||
|
token_count: 100000,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
tokenizer: "mistral-tokenizer-js",
|
||||||
|
token_count: tokenizer.encode(prompt.normalize("NFKC"))!.length,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import { Request } from "express";
|
import { Request } from "express";
|
||||||
import type {
|
import type {
|
||||||
GoogleAIChatMessage,
|
GoogleAIChatMessage,
|
||||||
|
MistralAIChatMessage,
|
||||||
OpenAIChatMessage,
|
OpenAIChatMessage,
|
||||||
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
|
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
|
||||||
import { assertNever } from "../utils";
|
import { assertNever } from "../utils";
|
||||||
@@ -14,11 +15,16 @@ import {
|
|||||||
getOpenAIImageCost,
|
getOpenAIImageCost,
|
||||||
estimateGoogleAITokenCount,
|
estimateGoogleAITokenCount,
|
||||||
} from "./openai";
|
} from "./openai";
|
||||||
|
import {
|
||||||
|
init as initMistralAI,
|
||||||
|
getTokenCount as getMistralAITokenCount,
|
||||||
|
} from "./mistral";
|
||||||
import { APIFormat } from "../key-management";
|
import { APIFormat } from "../key-management";
|
||||||
|
|
||||||
export async function init() {
|
export async function init() {
|
||||||
initClaude();
|
initClaude();
|
||||||
initOpenAi();
|
initOpenAi();
|
||||||
|
initMistralAI();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Tagged union via `service` field of the different types of requests that can
|
/** Tagged union via `service` field of the different types of requests that can
|
||||||
@@ -31,6 +37,7 @@ type TokenCountRequest = { req: Request } & (
|
|||||||
service: "openai-text" | "anthropic" | "google-ai";
|
service: "openai-text" | "anthropic" | "google-ai";
|
||||||
}
|
}
|
||||||
| { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
|
| { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
|
||||||
|
| { prompt: MistralAIChatMessage[]; completion?: never; service: "mistral-ai" }
|
||||||
| { prompt?: never; completion: string; service: APIFormat }
|
| { prompt?: never; completion: string; service: APIFormat }
|
||||||
| { prompt?: never; completion?: never; service: "openai-image" }
|
| { prompt?: never; completion?: never; service: "openai-image" }
|
||||||
);
|
);
|
||||||
@@ -77,6 +84,11 @@ export async function countTokens({
|
|||||||
...estimateGoogleAITokenCount(prompt ?? (completion || [])),
|
...estimateGoogleAITokenCount(prompt ?? (completion || [])),
|
||||||
tokenization_duration_ms: getElapsedMs(time),
|
tokenization_duration_ms: getElapsedMs(time),
|
||||||
};
|
};
|
||||||
|
case "mistral-ai":
|
||||||
|
return {
|
||||||
|
...getMistralAITokenCount(prompt ?? completion),
|
||||||
|
tokenization_duration_ms: getElapsedMs(time),
|
||||||
|
};
|
||||||
default:
|
default:
|
||||||
assertNever(service);
|
assertNever(service);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
getAzureOpenAIModelFamily,
|
getAzureOpenAIModelFamily,
|
||||||
getClaudeModelFamily,
|
getClaudeModelFamily,
|
||||||
getGoogleAIModelFamily,
|
getGoogleAIModelFamily,
|
||||||
|
getMistralAIModelFamily,
|
||||||
getOpenAIModelFamily,
|
getOpenAIModelFamily,
|
||||||
MODEL_FAMILIES,
|
MODEL_FAMILIES,
|
||||||
ModelFamily,
|
ModelFamily,
|
||||||
@@ -34,6 +35,9 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
|
|||||||
"dall-e": 0,
|
"dall-e": 0,
|
||||||
claude: 0,
|
claude: 0,
|
||||||
"gemini-pro": 0,
|
"gemini-pro": 0,
|
||||||
|
"mistral-tiny": 0,
|
||||||
|
"mistral-small": 0,
|
||||||
|
"mistral-medium": 0,
|
||||||
"aws-claude": 0,
|
"aws-claude": 0,
|
||||||
"azure-turbo": 0,
|
"azure-turbo": 0,
|
||||||
"azure-gpt4": 0,
|
"azure-gpt4": 0,
|
||||||
@@ -399,6 +403,8 @@ function getModelFamilyForQuotaUsage(
|
|||||||
return getClaudeModelFamily(model);
|
return getClaudeModelFamily(model);
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
return getGoogleAIModelFamily(model);
|
return getGoogleAIModelFamily(model);
|
||||||
|
case "mistral-ai":
|
||||||
|
return getMistralAIModelFamily(model);
|
||||||
default:
|
default:
|
||||||
assertNever(api);
|
assertNever(api);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user