Compare commits
8 Commits
sqlite-users...cohere
| Author | SHA1 | Date |
|---|---|---|
| | 8d5059534a | |
| | 0ea43f61c2 | |
| | ca4321b4cb | |
| | 7660ed8b94 | |
| | 55f1bbed3b | |
| | 57fd17ede0 | |
| | 9d00b8a9de | |
| | 155e185c6e | |
+15 -6
@@ -46,6 +46,14 @@ NODE_ENV=production
 # 'azure-dall-e' to the list of allowed model families.
 # ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
 
+# Which services can be used to process prompts containing images via multimodal
+# models. The following services are recognized:
+# openai | anthropic | aws | azure | google-ai | mistral-ai
+# Do not enable this feature unless all users are trusted, as you will be liable
+# for any user-submitted images containing illegal content.
+# By default, no image services are allowed and image prompts are rejected.
+# ALLOWED_VISION_SERVICES=
+
 # IP addresses or CIDR blocks from which requests will be blocked.
 # IP_BLACKLIST=10.0.0.1/24
 # URLs from which requests will be blocked.
@@ -60,7 +68,7 @@ NODE_ENV=production
 # Avoid short or common phrases as this tests the entire prompt.
 # REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
 # Message to show when requests are rejected.
-# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
+# REJECT_MESSAGE="You can't say that here."
 
 # Whether prompts should be logged to Google Sheets.
 # Requires additional setup. See `docs/google-sheets.md` for more information.
@@ -102,18 +110,19 @@ NODE_ENV=production
 # ALLOW_NICKNAME_CHANGES=true
 
 # Default token quotas for each model family. (0 for unlimited)
-# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
-# which is similar to the cost of GPT-4 Turbo.
-# DALL-E 3 costs around US$0.10 per image (10000 tokens).
-# See `docs/dall-e-configuration.md` for more information.
+# Specify as TOKEN_QUOTA_MODEL_FAMILY=value, replacing dashes with underscores.
 # TOKEN_QUOTA_TURBO=0
 # TOKEN_QUOTA_GPT4=0
 # TOKEN_QUOTA_GPT4_32K=0
 # TOKEN_QUOTA_GPT4_TURBO=0
-# TOKEN_QUOTA_DALL_E=0
 # TOKEN_QUOTA_CLAUDE=0
 # TOKEN_QUOTA_GEMINI_PRO=0
 # TOKEN_QUOTA_AWS_CLAUDE=0
+# "Tokens" for image-generation models are counted at a rate of 100000 tokens
+# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
+# DALL-E 3 costs around US$0.10 per image (10000 tokens).
+# See `docs/dall-e-configuration.md` for more information.
+# TOKEN_QUOTA_DALL_E=0
 
 # How often to refresh token quotas. (hourly | daily)
 # Leave unset to never automatically refresh quotas.
@@ -70,4 +70,4 @@ You can provide a comma-separated list containing individual IPv4 or IPv6 addresses
 
 To whitelist an entire IP range, use CIDR notation. For example, `192.168.0.1/24` would whitelist all addresses from `192.168.0.0` to `192.168.0.255`.
 
-To disable the whitelist, set `ADMIN_WHITELIST=0.0.0.0/0`, which will allow access from any IP address. This is the default behavior.
+To disable the whitelist, set `ADMIN_WHITELIST=0.0.0.0/0,::0`, which will allow access from any IPv4 or IPv6 address. This is the default behavior.
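For readers unfamiliar with the semantics, here is a minimal sketch of CIDR matching as described above (illustrative TypeScript, not code from this repo; IPv4 only for brevity):

```ts
// Minimal IPv4 CIDR membership check, mirroring the whitelist semantics above.
function ipv4ToInt(ip: string): number {
  return ip.split(".").reduce((acc, octet) => (acc << 8) + Number(octet), 0) >>> 0;
}

function inCidr(ip: string, cidr: string): boolean {
  const [base, bits] = cidr.split("/");
  const prefix = Number(bits);
  // A /0 mask matches everything, which is why 0.0.0.0/0 disables the whitelist.
  const mask = prefix === 0 ? 0 : (~0 << (32 - prefix)) >>> 0;
  return (ipv4ToInt(ip) & mask) === (ipv4ToInt(base) & mask);
}

console.log(inCidr("192.168.0.42", "192.168.0.1/24")); // true
console.log(inCidr("203.0.113.9", "0.0.0.0/0"));       // true (allow-all default)
```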
+28 -8
@@ -3,7 +3,7 @@ import dotenv from "dotenv";
 import type firebase from "firebase-admin";
 import path from "path";
 import pino from "pino";
-import type { ModelFamily } from "./shared/models";
+import type { LLMService, ModelFamily } from "./shared/models";
 import { MODEL_FAMILIES } from "./shared/models";
 
 dotenv.config();
@@ -340,13 +340,18 @@ type Config = {
    */
  allowOpenAIToolUsage?: boolean;
  /**
-   * Whether to allow prompts containing images, for use with multimodal models.
-   * Avoid giving this to untrusted users, as they can submit illegal content.
+   * Which services will accept prompts containing images, for use with
+   * multimodal models. Users with `special` role are exempt from this
+   * restriction.
    *
-   * Applies to GPT-4 Vision and Claude Vision. Users with `special` role are
-   * exempt from this restriction.
+   * Do not enable this feature for untrusted users, as malicious users could
+   * send images which violate your provider's terms of service or local laws.
+   *
+   * Defaults to no services, meaning image prompts are disabled. Use a comma-
+   * separated list. Available services are:
+   * openai,anthropic,google-ai,mistral-ai,aws,azure
    */
-  allowImagePrompts?: boolean;
+  allowedVisionServices: LLMService[];
  /**
   * Allows overriding the default proxy endpoint route. Defaults to /proxy.
   * A leading slash is required.
@@ -479,9 +484,13 @@ export const config: Config = {
   staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
   trustedProxies: getEnvWithDefault("TRUSTED_PROXIES", 1),
   allowOpenAIToolUsage: getEnvWithDefault("ALLOW_OPENAI_TOOL_USAGE", false),
-  allowImagePrompts: getEnvWithDefault("ALLOW_IMAGE_PROMPTS", false),
+  allowedVisionServices: parseCsv(
+    getEnvWithDefault("ALLOWED_VISION_SERVICES", "")
+  ) as LLMService[],
   proxyEndpointRoute: getEnvWithDefault("PROXY_ENDPOINT_ROUTE", "/proxy"),
-  adminWhitelist: parseCsv(getEnvWithDefault("ADMIN_WHITELIST", "0.0.0.0/0")),
+  adminWhitelist: parseCsv(
+    getEnvWithDefault("ADMIN_WHITELIST", "0.0.0.0/0,::/0")
+  ),
   ipBlacklist: parseCsv(getEnvWithDefault("IP_BLACKLIST", "")),
 } as const;
 
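For context, the two helpers used here behave roughly as follows (a sketch under assumptions; the real implementations in config.ts may differ in details such as type coercion):

```ts
// Assumed behavior of the config helpers referenced in the hunk above.
function getEnvWithDefault<T>(name: string, defaultValue: T): T | string {
  const value = process.env[name];
  return value === undefined || value === "" ? defaultValue : value;
}

function parseCsv(value: string): string[] {
  return value
    .split(",")
    .map((part) => part.trim())
    .filter((part) => part.length > 0);
}

process.env.ALLOWED_VISION_SERVICES = "openai, anthropic";
console.log(parseCsv(getEnvWithDefault("ALLOWED_VISION_SERVICES", "") as string));
// => ["openai", "anthropic"], then cast to LLMService[] by the config object
```

Note that the unset-or-empty default of `""` parses to an empty array, matching the documented behavior that no vision services are allowed by default.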
@@ -534,6 +543,17 @@ export async function assertConfigIsValid() {
     );
   }
 
+  if (process.env.ALLOW_IMAGE_PROMPTS === "true") {
+    const hasAllowedServices = config.allowedVisionServices.length > 0;
+    if (!hasAllowedServices) {
+      config.allowedVisionServices = ["openai", "anthropic"];
+      startupLogger.warn(
+        { allowedVisionServices: config.allowedVisionServices },
+        "ALLOW_IMAGE_PROMPTS is deprecated. Use ALLOWED_VISION_SERVICES instead."
+      );
+    }
+  }
+
   if (config.promptLogging && !config.promptLoggingBackend) {
     throw new Error(
       "Prompt logging is enabled but no backend is configured. Set PROMPT_LOGGING_BACKEND to 'google_sheets' or 'file'."
@@ -66,7 +66,8 @@ export const gatekeeper: RequestHandler = (req, res, next) => {
         req,
         res,
         403,
-        "Forbidden: no more IPs can authenticate with this user token"
+        `Forbidden: no more IP addresses allowed for this user token`,
+        { currentIp: ip, maxIps: user?.maxIps }
       );
     case "disabled":
       const bannedUser = getUser(token);
@@ -84,7 +85,8 @@ function sendError(
   req: Request,
   res: Response,
   status: number,
-  message: string
+  message: string,
+  data: any = {}
 ) {
   const isPost = req.method === "POST";
   const hasBody = isPost && req.body;
@@ -103,6 +105,7 @@ function sendError(
       format: "unknown",
       statusCode: status,
       reqId: req.id,
+      obj: data,
     },
   });
 }
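The effect of threading `data` through is that callers can attach structured fields to the generated error body. A self-contained sketch of the pattern (names are illustrative, not the proxy's exact internals):

```ts
// Illustrative only: shows how an extra `data` object merges into an error payload.
type ErrorData = Record<string, unknown>;

function buildErrorBody(status: number, message: string, data: ErrorData = {}) {
  return { error: { message, statusCode: status, ...data } };
}

const body = buildErrorBody(
  403,
  "Forbidden: no more IP addresses allowed for this user token",
  { currentIp: "203.0.113.7", maxIps: 2 }
);
console.log(JSON.stringify(body, null, 2));
```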
@@ -9,9 +9,14 @@ import { ForbiddenError } from "../../../../shared/errors";
  * Rejects prompts containing images if multimodal prompts are disabled.
  */
 export const validateVision: RequestPreprocessor = async (req) => {
-  if (config.allowImagePrompts) return;
-  if (req.user?.type === "special") return;
+  if (req.service === undefined) {
+    throw new Error("Request service must be set before validateVision");
+  }
 
+  if (req.user?.type === "special") return;
+  if (config.allowedVisionServices.includes(req.service)) return;
+
+  // vision not allowed for req's service, block prompts with images
   let hasImage = false;
   switch (req.outboundApi) {
     case "openai":
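Distilled, the new gate works like this (simplified sketch; the real preprocessor goes on to inspect the outbound payload for image parts before rejecting):

```ts
type LLMService =
  | "openai" | "anthropic" | "google-ai" | "mistral-ai" | "aws" | "azure";

function visionPermitted(
  service: LLMService,
  allowedVisionServices: LLMService[],
  userType?: "normal" | "special"
): boolean {
  if (userType === "special") return true; // special users bypass the check
  return allowedVisionServices.includes(service);
}

console.log(visionPermitted("openai", ["openai", "anthropic"])); // true
console.log(visionPermitted("azure", ["openai", "anthropic"]));  // false
```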
@@ -52,7 +52,13 @@ function getMessageContent({
     delete obj.stack;
   }
 
-  return [header, friendlyMessage, serializedObj, prettyTrace].join("\n\n");
+  return [
+    header,
+    friendlyMessage,
+    serializedObj,
+    prettyTrace,
+    "<!-- oai-proxy-error -->",
+  ].join("\n\n");
 }
 
 type ErrorGeneratorOptions = {
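Appending an HTML-comment sentinel lets downstream tooling distinguish proxy-generated error text from genuine model output. A sketch of client-side detection (assumed client code, not part of this repo):

```ts
// Hypothetical client-side check for the sentinel added in the hunk above.
const PROXY_ERROR_SENTINEL = "<!-- oai-proxy-error -->";

function isProxyGeneratedError(completionText: string): boolean {
  return completionText.includes(PROXY_ERROR_SENTINEL);
}

console.log(isProxyGeneratedError("Upstream failed.\n\n<!-- oai-proxy-error -->")); // true
console.log(isProxyGeneratedError("A normal completion."));                         // false
```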
@@ -116,6 +122,11 @@ export function sendErrorToClient({
   const isStreaming =
     req.isStreaming || req.body.stream === true || req.body.stream === "true";
 
+  if (!res.headersSent) {
+    res.setHeader("x-oai-proxy-error", options.title);
+    res.setHeader("x-oai-proxy-error-status", options.statusCode || 500);
+  }
+
   if (isStreaming) {
     if (!res.headersSent) {
       initializeSseStream(res);
+1 -4
@@ -179,10 +179,7 @@ function cleanup() {
   process.exit(0);
 }
 
-process.on("exit", () => cleanup());
-process.on("SIGHUP", () => process.exit(128 + 1));
-process.on("SIGINT", () => process.exit(128 + 2));
-process.on("SIGTERM", () => process.exit(128 + 15));
+process.on("SIGINT", cleanup);
 
 function registerUncaughtExceptionHandler() {
   process.on("uncaughtException", (err: any) => {
@@ -119,7 +119,8 @@ export const transformOpenAIToAnthropicChat: APIFormatTransformer<
     stream: rest.stream,
     temperature: rest.temperature,
     top_p: rest.top_p,
-    stop_sequences: typeof rest.stop === "string" ? [rest.stop] : rest.stop,
+    stop_sequences:
+      typeof rest.stop === "string" ? [rest.stop] : rest.stop || undefined,
     ...(rest.user ? { metadata: { user_id: rest.user } } : {}),
     // Anthropic supports top_k, but OpenAI does not
     // OpenAI supports frequency_penalty, presence_penalty, logit_bias, n, seed,
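The fix matters because OpenAI clients may send `stop: null`, which previously flowed through as `stop_sequences: null`. Extracted as a standalone function, the normalization looks like this (sketch):

```ts
// Normalizes OpenAI's `stop` field into Anthropic's `stop_sequences`.
function toStopSequences(
  stop: string | string[] | null | undefined
): string[] | undefined {
  return typeof stop === "string" ? [stop] : stop || undefined;
}

console.log(toStopSequences("END"));      // ["END"]
console.log(toStopSequences(["a", "b"])); // ["a", "b"]
console.log(toStopSequences(null));       // undefined (previously null leaked through)
```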
@@ -0,0 +1,181 @@
+import { z } from "zod";
+import {
+  OPENAI_OUTPUT_MAX,
+  OpenAIV1ChatCompletionSchema,
+  flattenOpenAIMessageContent,
+} from "./openai";
+import { APIFormatTransformer } from ".";
+
+// https://docs.cohere.com/reference/chat
+export const CohereV1ChatSchema = z
+  .object({
+    message: z.string(),
+    model: z.string().default("command-r-plus"),
+    stream: z.boolean().default(false).optional(),
+    preamble: z.string().optional(),
+    chat_history: z
+      .array(
+        // Either a message from a chat participant, or a past tool call
+        z.union([
+          z.object({
+            role: z.enum(["CHATBOT", "SYSTEM", "USER"]),
+            message: z.string(),
+            tool_calls: z
+              .array(z.object({ name: z.string(), parameters: z.any() }))
+              .optional(),
+          }),
+          z.object({
+            role: z.enum(["TOOL"]),
+            tool_results: z.array(
+              z.object({
+                call: z.object({ name: z.string(), parameters: z.any() }),
+                outputs: z.array(z.any()),
+              })
+            ),
+          }),
+        ])
+      )
+      .optional(),
+    // Don't allow conversation_id as it causes calls to be stateful and we don't
+    // offer guarantees about which key a user's request will be routed to.
+    conversation_id: z.literal(undefined).optional(),
+    prompt_truncation: z
+      .enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"])
+      .optional(),
+    /*
+    Supporting RAG is complex because documents can be arbitrary size and have
+    to have embeddings generated, which incurs a cost that is not trivial to
+    estimate. We don't support it for now.
+    connectors: z
+      .array(
+        z.object({
+          id: z.string(),
+          user_access_token: z.string().optional(),
+          continue_on_failure: z.boolean().default(false).optional(),
+          options: z.any().optional(),
+        })
+      )
+      .optional(),
+    search_queries_only: z.boolean().default(false).optional(),
+    documents: z
+      .array(
+        z.object({
+          id: z.string().optional(),
+          title: z.string().optional(),
+          text: z.string(),
+          _excludes: z.array(z.string()).optional(),
+        })
+      )
+      .optional(),
+    citation_quality: z.enum(["accurate", "fast"]).optional(),
+    */
+    temperature: z.number().default(0.3).optional(),
+    max_tokens: z
+      .number()
+      .int()
+      .nullish()
+      .default(Math.min(OPENAI_OUTPUT_MAX, 4096))
+      .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
+    max_input_tokens: z.number().int().optional(),
+    k: z.number().int().min(0).max(500).default(0).optional(),
+    p: z.number().min(0.01).max(0.99).default(0.75).optional(),
+    seed: z.number().int().optional(),
+    stop_sequences: z.array(z.string()).max(5).optional(),
+    frequency_penalty: z.number().min(0).max(1).default(0).optional(),
+    presence_penalty: z.number().min(0).max(1).default(0).optional(),
+    tools: z
+      .array(
+        z.object({
+          name: z.string(),
+          description: z.string(),
+          parameter_definitions: z.record(
+            z.object({
+              description: z.string().optional(),
+              type: z.string(),
+              required: z.boolean().optional().default(false),
+            })
+          ),
+        })
+      )
+      .optional(),
+    tool_results: z
+      .array(
+        z.object({
+          call: z.object({
+            name: z.string(),
+            parameters: z.record(z.any()),
+          }),
+          outputs: z.array(z.record(z.any())),
+        })
+      )
+      .optional(),
+    // We always force single step to avoid stateful calls or expensive multi-step
+    // generations when tools are involved.
+    force_single_step: z.literal(true).default(true).optional(),
+  })
+  .strip();
+export type CohereChatMessage = NonNullable<
+  z.infer<typeof CohereV1ChatSchema>["chat_history"]
+>[number];
+
+export function flattenCohereMessageContent(
+  message: CohereChatMessage
+): string {
+  return message.role === "TOOL"
+    ? message.tool_results.map((r) => r.outputs[0].text).join("\n")
+    : message.message;
+}
+
+export const transformOpenAIToCohere: APIFormatTransformer<
+  typeof CohereV1ChatSchema
+> = async (req) => {
+  const { body } = req;
+  const result = OpenAIV1ChatCompletionSchema.safeParse({
+    ...body,
+    model: "gpt-3.5-turbo",
+  });
+  if (!result.success) {
+    req.log.warn(
+      { issues: result.error.issues, body },
+      "Invalid OpenAI-to-Cohere request"
+    );
+    throw result.error;
+  }
+
+  const { messages, ...rest } = result.data;
+  // Final OAI message becomes the `message` field in Cohere
+  const message = messages[messages.length - 1];
+  // If the first message has system role, use it as preamble.
+  const hasSystemPreamble = messages[0]?.role === "system";
+  const preamble = hasSystemPreamble
+    ? flattenOpenAIMessageContent(messages[0].content)
+    : undefined;
+
+  const chatHistory = messages.slice(0, -1).map((m) => {
+    const role: Exclude<CohereChatMessage["role"], "TOOL"> =
+      m.role === "assistant"
+        ? "CHATBOT"
+        : m.role === "system"
+          ? "SYSTEM"
+          : "USER";
+    const content = flattenOpenAIMessageContent(m.content);
+    const message = m.name ? `${m.name}: ${content}` : content;
+    return { role, message };
+  });
+
+  return {
+    model: rest.model,
+    preamble,
+    chat_history: chatHistory,
+    message: flattenOpenAIMessageContent(message.content),
+    stop_sequences:
+      typeof rest.stop === "string" ? [rest.stop] : rest.stop ?? undefined,
+    max_tokens: rest.max_tokens,
+    temperature: rest.temperature,
+    p: rest.top_p,
+    frequency_penalty: rest.frequency_penalty,
+    presence_penalty: rest.presence_penalty,
+    seed: rest.seed,
+    stream: rest.stream,
+  };
+};
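To make the message mapping concrete, here is the same shuffle on plain data (a sketch; the real transformer operates on a validated Express request):

```ts
// OpenAI-style input: first system message -> preamble, last message -> `message`,
// everything before the last message -> chat_history (mirroring the code above,
// which also includes a leading system message in the history as a SYSTEM turn).
const messages = [
  { role: "system", content: "You are terse." },
  { role: "user", content: "Hi!" },
  { role: "assistant", content: "Hello." },
  { role: "user", content: "Summarize our chat." },
] as const;

const preamble = messages[0].role === "system" ? messages[0].content : undefined;
const chat_history = messages.slice(0, -1).map((m) => ({
  role: m.role === "assistant" ? "CHATBOT" : m.role === "system" ? "SYSTEM" : "USER",
  message: m.content,
}));
const message = messages[messages.length - 1].content;

console.log({ preamble, chat_history, message });
```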
@@ -22,6 +22,7 @@ import {
   transformOpenAIToGoogleAI,
 } from "./google-ai";
 import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
+import { CohereV1ChatSchema, transformOpenAIToCohere } from "./cohere";
 
 export { OpenAIChatMessage } from "./openai";
 export {
@@ -33,15 +34,29 @@ export {
 export { GoogleAIChatMessage } from "./google-ai";
 export { MistralAIChatMessage } from "./mistral-ai";
 
+/** Represents a pair of API formats that can be transformed between. */
 type APIPair = `${APIFormat}->${APIFormat}`;
+/** Represents a map of API format pairs to transformer functions. */
 type TransformerMap = {
   [key in APIPair]?: APIFormatTransformer<any>;
 };
 
+/**
+ * Represents a transformer function that takes a Request and returns a Promise
+ * resolving to a value of the specified Zod schema type.
+ *
+ * @template Z The Zod schema type to transform the request into (from api-schemas).
+ * @param req The incoming Request to transform.
+ * @returns A Promise resolving to the transformed request body.
+ */
+export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
+  req: Request
+) => Promise<z.infer<Z>>;
+
 /**
  * Specifies possible translations between API formats and the corresponding
  * transformer functions to apply them.
  */
 export const API_REQUEST_TRANSFORMERS: TransformerMap = {
   "anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
   "openai->anthropic-chat": transformOpenAIToAnthropicChat,
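The `APIPair` template-literal key is the interesting part: it gives a sparse, type-safe dispatch table. A self-contained sketch of the pattern:

```ts
// Reduced version of the TransformerMap pattern used above.
type APIFormat = "openai" | "anthropic-chat" | "cohere-chat";
type APIPair = `${APIFormat}->${APIFormat}`;
type Transformer = (body: unknown) => Promise<unknown>;
type TransformerMap = { [key in APIPair]?: Transformer };

const transformers: TransformerMap = {
  // stand-in for transformOpenAIToCohere
  "openai->cohere-chat": async (body) => body,
};

// Lookup yields Transformer | undefined; unsupported pairs are simply absent.
const transform = transformers["openai->cohere-chat"];
```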
@@ -49,8 +64,12 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
   "openai->openai-text": transformOpenAIToOpenAIText,
   "openai->openai-image": transformOpenAIToOpenAIImage,
   "openai->google-ai": transformOpenAIToGoogleAI,
+  "openai->cohere-chat": transformOpenAIToCohere,
 };
 
+/**
+ * Specifies the schema for each API format to validate incoming requests.
+ */
 export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
   "anthropic-chat": AnthropicV1MessagesSchema,
   "anthropic-text": AnthropicV1TextSchema,
@@ -59,4 +78,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
   "openai-image": OpenAIV1ImagesGenerationSchema,
   "google-ai": GoogleAIV1GenerateContentSchema,
   "mistral-ai": MistralAIV1ChatCompletionsSchema,
+  "cohere-chat": CohereV1ChatSchema,
 };
@@ -47,7 +47,7 @@ export const OpenAIV1ChatCompletionSchema = z
   stream: z.boolean().optional().default(false),
   stop: z
     .union([z.string().max(500), z.array(z.string().max(500))])
-    .optional(),
+    .nullish(),
   max_tokens: z.coerce
     .number()
     .int()
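This one-token change is behavioral: `.optional()` admits only `undefined`, while `.nullish()` also admits `null`, which OpenAI clients routinely send for `stop`. A quick check (assumes `zod` is installed):

```ts
import { z } from "zod";

const stopOptional = z
  .union([z.string().max(500), z.array(z.string().max(500))])
  .optional();
const stopNullish = z
  .union([z.string().max(500), z.array(z.string().max(500))])
  .nullish();

console.log(stopOptional.safeParse(null).success); // false — request would be rejected
console.log(stopNullish.safeParse(null).success);  // true
```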
@@ -9,7 +9,8 @@ export type APIFormat =
   | "anthropic-chat" // Anthropic's newer messages array format
   | "anthropic-text" // Legacy flat string prompt format
   | "google-ai"
-  | "mistral-ai";
+  | "mistral-ai"
+  | "cohere-chat";
 
 export interface Key {
   /** The API key itself. Never log this, use `hash` instead. */
+18 -2
@@ -14,7 +14,8 @@ export type LLMService =
   | "google-ai"
   | "mistral-ai"
   | "aws"
-  | "azure";
+  | "azure"
+  | "cohere";
 
 export type OpenAIModelFamily =
   | "turbo"
@@ -32,13 +33,15 @@ export type MistralAIModelFamily =
   | "mistral-large";
 export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
 export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
+export type CohereModelFamily = "command-r" | "command-r-plus";
 export type ModelFamily =
   | OpenAIModelFamily
   | AnthropicModelFamily
   | GoogleAIModelFamily
   | MistralAIModelFamily
   | AwsBedrockModelFamily
-  | AzureOpenAIModelFamily;
+  | AzureOpenAIModelFamily
+  | CohereModelFamily;
 
 export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
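For readers unfamiliar with the `MODEL_FAMILIES` construction shown here: the intersection with a conditional type makes the literal array compile only if it covers every member of the union. A reduced sketch:

```ts
// If the array omits any Family member, [Family] extends [A[number]] is false,
// the parameter type collapses to `never`, and the call fails to compile.
type Family = "command-r" | "command-r-plus";

const ALL_FAMILIES = (<A extends readonly Family[]>(
  arr: A & ([Family] extends [A[number]] ? unknown : never)
) => arr)(["command-r", "command-r-plus"] as const);

// ALL_FAMILIES: readonly ["command-r", "command-r-plus"]
```

This is why adding `CohereModelFamily` to the union forces the `"command-r"` and `"command-r-plus"` entries in the next hunk.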
@@ -64,6 +67,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
   "azure-gpt4-turbo",
   "azure-gpt4o",
   "azure-dall-e",
+  "command-r",
+  "command-r-plus",
 ] as const);
 
 export const LLM_SERVICES = (<A extends readonly LLMService[]>(
@@ -75,6 +80,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
   "mistral-ai",
   "aws",
   "azure",
+  "cohere",
 ] as const);
 
 export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@@ -116,6 +122,8 @@ export const MODEL_FAMILY_SERVICE: {
   "mistral-small": "mistral-ai",
   "mistral-medium": "mistral-ai",
   "mistral-large": "mistral-ai",
+  "command-r": "cohere",
+  "command-r-plus": "cohere",
 };
 
 export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
@@ -181,6 +189,11 @@ export function getAzureOpenAIModelFamily(
   return defaultFamily;
 }
 
+export function getCohereModelFamily(model: string): CohereModelFamily {
+  if (model.includes("plus")) return "command-r-plus";
+  return "command-r";
+}
+
 export function assertIsKnownModelFamily(
   modelFamily: string
 ): asserts modelFamily is ModelFamily {
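Worth noting: the family lookup is a substring heuristic, so dated model snapshots resolve to the right family. Shown with representative inputs (sketch, duplicating the function above):

```ts
type CohereModelFamily = "command-r" | "command-r-plus";

function getCohereModelFamily(model: string): CohereModelFamily {
  if (model.includes("plus")) return "command-r-plus";
  return "command-r";
}

console.log(getCohereModelFamily("command-r-plus-04-2024")); // "command-r-plus"
console.log(getCohereModelFamily("command-r-03-2024"));      // "command-r"
```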
@@ -220,6 +233,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
     case "mistral-ai":
       modelFamily = getMistralAIModelFamily(model);
       break;
+    case "cohere-chat":
+      modelFamily = getCohereModelFamily(model);
+      break;
     default:
       assertNever(req.outboundApi);
   }