From 8d5059534a65b4611ed8a062a9df2b3faa370911 Mon Sep 17 00:00:00 2001 From: nai-degen Date: Sun, 9 Jun 2024 12:46:02 -0500 Subject: [PATCH] starts adding cohere api format and schemas --- src/shared/api-schemas/cohere.ts | 181 +++++++++++++++++++++++++++++ src/shared/api-schemas/index.ts | 20 ++++ src/shared/key-management/index.ts | 3 +- src/shared/models.ts | 20 +++- 4 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 src/shared/api-schemas/cohere.ts diff --git a/src/shared/api-schemas/cohere.ts b/src/shared/api-schemas/cohere.ts new file mode 100644 index 0000000..4e62f32 --- /dev/null +++ b/src/shared/api-schemas/cohere.ts @@ -0,0 +1,181 @@ +import { z } from "zod"; +import { + OPENAI_OUTPUT_MAX, + OpenAIV1ChatCompletionSchema, + flattenOpenAIMessageContent, +} from "./openai"; +import { APIFormatTransformer } from "."; + +// https://docs.cohere.com/reference/chat +export const CohereV1ChatSchema = z + .object({ + message: z.string(), + model: z.string().default("command-r-plus"), + stream: z.boolean().default(false).optional(), + preamble: z.string().optional(), + chat_history: z + .array( + // Either a message from a chat participant, or a past tool call + z.union([ + z.object({ + role: z.enum(["CHATBOT", "SYSTEM", "USER"]), + message: z.string(), + tool_calls: z + .array(z.object({ name: z.string(), parameters: z.any() })) + .optional(), + }), + z.object({ + role: z.enum(["TOOL"]), + tool_results: z.array( + z.object({ + call: z.object({ name: z.string(), parameters: z.any() }), + outputs: z.array(z.any()), + }) + ), + }), + ]) + ) + .optional(), + // Don't allow conversation_id as it causes calls to be stateful and we don't + // offer guarantees about which key a user's request will be routed to. 
+ conversation_id: z.literal(undefined).optional(), + prompt_truncation: z + .enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"]) + .optional(), + /* + Supporting RAG is complex because documents can be arbitrary size and have + to have embeddings generated, which incurs a cost that is not trivial to + estimate. We don't support it for now. + connectors: z + .array( + z.object({ + id: z.string(), + user_access_token: z.string().optional(), + continue_on_failure: z.boolean().default(false).optional(), + options: z.any().optional(), + }) + ) + .optional(), + search_queries_only: z.boolean().default(false).optional(), + documents: z + .array( + z.object({ + id: z.string().optional(), + title: z.string().optional(), + text: z.string(), + _excludes: z.array(z.string()).optional(), + }) + ) + .optional(), + citation_quality: z.enum(["accurate", "fast"]).optional(), + */ + temperature: z.number().default(0.3).optional(), + max_tokens: z + .number() + .int() + .nullish() + .default(Math.min(OPENAI_OUTPUT_MAX, 4096)) + .transform((v) => Math.min(v ?? 
OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)), + max_input_tokens: z.number().int().optional(), + k: z.number().int().min(0).max(500).default(0).optional(), + p: z.number().min(0.01).max(0.99).default(0.75).optional(), + seed: z.number().int().optional(), + stop_sequences: z.array(z.string()).max(5).optional(), + frequency_penalty: z.number().min(0).max(1).default(0).optional(), + presence_penalty: z.number().min(0).max(1).default(0).optional(), + tools: z + .array( + z.object({ + name: z.string(), + description: z.string(), + parameter_definitions: z.record( + z.object({ + description: z.string().optional(), + type: z.string(), + required: z.boolean().optional().default(false), + }) + ), + }) + ) + .optional(), + tool_results: z + .array( + z.object({ + call: z.object({ + name: z.string(), + parameters: z.record(z.any()), + }), + outputs: z.array(z.record(z.any())), + }) + ) + .optional(), + // We always force single step to avoid stateful calls or expensive multi-step + // generations when tools are involved. + force_single_step: z.literal(true).default(true).optional(), + }) + .strip(); +export type CohereChatMessage = NonNullable< + z.infer<typeof CohereV1ChatSchema>["chat_history"] +>[number]; + +export function flattenCohereMessageContent( + message: CohereChatMessage +): string { + return message.role === "TOOL" + ?
message.tool_results.map((r) => r.outputs[0].text).join("\n") + : message.message; +} + +export const transformOpenAIToCohere: APIFormatTransformer< + typeof CohereV1ChatSchema +> = async (req) => { + const { body } = req; + const result = OpenAIV1ChatCompletionSchema.safeParse({ + ...body, + model: "gpt-3.5-turbo", + }); + if (!result.success) { + req.log.warn( + { issues: result.error.issues, body }, + "Invalid OpenAI-to-Cohere request" + ); + throw result.error; + } + + const { messages, ...rest } = result.data; + // Final OAI message becomes the `message` field in Cohere + const message = messages[messages.length - 1]; + // If the first message has system role, use it as preamble. + const hasSystemPreamble = messages[0]?.role === "system"; + const preamble = hasSystemPreamble + ? flattenOpenAIMessageContent(messages[0].content) + : undefined; + + const chatHistory = messages.slice(0, -1).map((m) => { + const role: Exclude<CohereChatMessage["role"], "TOOL"> = + m.role === "assistant" + ? "CHATBOT" + : m.role === "system" + ? "SYSTEM" + : "USER"; + const content = flattenOpenAIMessageContent(m.content); + const message = m.name ? `${m.name}: ${content}` : content; + return { role, message }; + }); + + return { + model: rest.model, + preamble, + chat_history: chatHistory, + message: flattenOpenAIMessageContent(message.content), + stop_sequences: + typeof rest.stop === "string" ? [rest.stop] : rest.stop ??
undefined, + max_tokens: rest.max_tokens, + temperature: rest.temperature, + p: rest.top_p, + frequency_penalty: rest.frequency_penalty, + presence_penalty: rest.presence_penalty, + seed: rest.seed, + stream: rest.stream, + }; +}; diff --git a/src/shared/api-schemas/index.ts b/src/shared/api-schemas/index.ts index 598bf23..f333415 100644 --- a/src/shared/api-schemas/index.ts +++ b/src/shared/api-schemas/index.ts @@ -22,6 +22,7 @@ import { transformOpenAIToGoogleAI, } from "./google-ai"; import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai"; +import { CohereV1ChatSchema, transformOpenAIToCohere } from "./cohere"; export { OpenAIChatMessage } from "./openai"; export { @@ -33,15 +34,29 @@ export { GoogleAIChatMessage } from "./google-ai"; export { MistralAIChatMessage } from "./mistral-ai"; +/** Represents a pair of API formats that can be transformed between. */ type APIPair = `${APIFormat}->${APIFormat}`; +/** Represents a map of API format pairs to transformer functions. */ type TransformerMap = { [key in APIPair]?: APIFormatTransformer<any>; }; +/** + * Represents a transformer function that takes a Request and returns a Promise + * resolving to a value of the specified Zod schema type. + * + * @template Z The Zod schema type to transform the request into (from api-schemas). + * @param req The incoming Request to transform. + * @returns A Promise resolving to the transformed request body. + */ export type APIFormatTransformer<Z extends z.ZodType<any, any>> = ( req: Request ) => Promise<z.infer<Z>>; +/** + * Specifies possible translations between API formats and the corresponding + * transformer functions to apply them.
+ */ export const API_REQUEST_TRANSFORMERS: TransformerMap = { "anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat, "openai->anthropic-chat": transformOpenAIToAnthropicChat, @@ -49,8 +64,12 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = { "openai->openai-text": transformOpenAIToOpenAIText, "openai->openai-image": transformOpenAIToOpenAIImage, "openai->google-ai": transformOpenAIToGoogleAI, + "openai->cohere-chat": transformOpenAIToCohere, }; +/** + * Specifies the schema for each API format to validate incoming requests. + */ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = { "anthropic-chat": AnthropicV1MessagesSchema, "anthropic-text": AnthropicV1TextSchema, @@ -59,4 +78,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = { "openai-image": OpenAIV1ImagesGenerationSchema, "google-ai": GoogleAIV1GenerateContentSchema, "mistral-ai": MistralAIV1ChatCompletionsSchema, + "cohere-chat": CohereV1ChatSchema, }; diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 5e43e57..2bf3c12 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -9,7 +9,8 @@ export type APIFormat = | "anthropic-chat" // Anthropic's newer messages array format | "anthropic-text" // Legacy flat string prompt format | "google-ai" - | "mistral-ai"; + | "mistral-ai" + | "cohere-chat"; export interface Key { /** The API key itself. Never log this, use `hash` instead.
*/ diff --git a/src/shared/models.ts b/src/shared/models.ts index 3721f0d..2a1a581 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -14,7 +14,8 @@ export type LLMService = | "google-ai" | "mistral-ai" | "aws" - | "azure"; + | "azure" + | "cohere"; export type OpenAIModelFamily = | "turbo" @@ -32,13 +33,15 @@ export type MistralAIModelFamily = | "mistral-large"; export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus"; export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`; +export type CohereModelFamily = "command-r" | "command-r-plus"; export type ModelFamily = | OpenAIModelFamily | AnthropicModelFamily | GoogleAIModelFamily | MistralAIModelFamily | AwsBedrockModelFamily - | AzureOpenAIModelFamily; + | AzureOpenAIModelFamily + | CohereModelFamily; export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>( arr: A & ([ModelFamily] extends [A[number]] ? unknown : never) @@ -64,6 +67,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>( "azure-gpt4-turbo", "azure-gpt4o", "azure-dall-e", + "command-r", + "command-r-plus", ] as const); export const LLM_SERVICES = (<A extends readonly LLMService[]>( @@ -75,6 +80,7 @@ "mistral-ai", "aws", "azure", + "cohere", ] as const); export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = { @@ -116,6 +122,8 @@ export const MODEL_FAMILY_SERVICE: { "mistral-small": "mistral-ai", "mistral-medium": "mistral-ai", "mistral-large": "mistral-ai", + "command-r": "cohere", + "command-r-plus": "cohere", }; export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"]; @@ -181,6 +189,11 @@ export function getAzureOpenAIModelFamily( return defaultFamily; } +export function getCohereModelFamily(model: string): CohereModelFamily { + if (model.includes("plus")) return "command-r-plus"; + return "command-r"; +} + export function assertIsKnownModelFamily( modelFamily: string ): asserts modelFamily is ModelFamily { @@ -220,6 +233,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily { case "mistral-ai":
modelFamily = getMistralAIModelFamily(model); break; + case "cohere-chat": + modelFamily = getCohereModelFamily(model); + break; default: assertNever(req.outboundApi); }