1 Commits

Author SHA1 Message Date
nai-degen 8d5059534a starts adding cohere api format and schemas 2024-06-09 12:46:02 -05:00
4 changed files with 221 additions and 3 deletions
+181
View File
@@ -0,0 +1,181 @@
import { z } from "zod";
import {
OPENAI_OUTPUT_MAX,
OpenAIV1ChatCompletionSchema,
flattenOpenAIMessageContent,
} from "./openai";
import { APIFormatTransformer } from ".";
// https://docs.cohere.com/reference/chat
// Request schema for Cohere's v1 /chat endpoint. Unknown keys are stripped
// (`.strip()`) rather than rejected, so stray OpenAI-style fields pass
// through harmlessly.
// NOTE(review): zod's `.default(x).optional()` wraps the default inside an
// optional, so an omitted field parses to `undefined` and the default never
// applies. That pattern appears throughout this schema (stream, temperature,
// k, p, frequency_penalty, ...) — the defaults are effectively documentation
// only; confirm whether applying them was intended.
export const CohereV1ChatSchema = z
  .object({
    // The current user turn; prior turns belong in `chat_history`.
    message: z.string(),
    model: z.string().default("command-r-plus"),
    stream: z.boolean().default(false).optional(),
    // Cohere's equivalent of an OpenAI system prompt.
    preamble: z.string().optional(),
    chat_history: z
      .array(
        // Either a message from a chat participant, or a past tool call
        z.union([
          z.object({
            role: z.enum(["CHATBOT", "SYSTEM", "USER"]),
            message: z.string(),
            tool_calls: z
              .array(z.object({ name: z.string(), parameters: z.any() }))
              .optional(),
          }),
          z.object({
            role: z.enum(["TOOL"]),
            tool_results: z.array(
              z.object({
                call: z.object({ name: z.string(), parameters: z.any() }),
                // Untyped on purpose — Cohere tool outputs are free-form
                // records; consumers must not assume any particular shape.
                outputs: z.array(z.any()),
              })
            ),
          }),
        ])
      )
      .optional(),
    // Don't allow conversation_id as it causes calls to be stateful and we don't
    // offer guarantees about which key a user's request will be routed to.
    conversation_id: z.literal(undefined).optional(),
    prompt_truncation: z
      .enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"])
      .optional(),
    /*
    Supporting RAG is complex because documents can be arbitrary size and have
    to have embeddings generated, which incurs a cost that is not trivial to
    estimate. We don't support it for now.
    connectors: z
      .array(
        z.object({
          id: z.string(),
          user_access_token: z.string().optional(),
          continue_on_failure: z.boolean().default(false).optional(),
          options: z.any().optional(),
        })
      )
      .optional(),
    search_queries_only: z.boolean().default(false).optional(),
    documents: z
      .array(
        z.object({
          id: z.string().optional(),
          title: z.string().optional(),
          text: z.string(),
          _excludes: z.array(z.string()).optional(),
        })
      )
      .optional(),
    citation_quality: z.enum(["accurate", "fast"]).optional(),
    */
    temperature: z.number().default(0.3).optional(),
    // Clamp requested output length to the proxy-wide OpenAI cap; a null or
    // omitted value falls back to OPENAI_OUTPUT_MAX inside the transform.
    max_tokens: z
      .number()
      .int()
      .nullish()
      .default(Math.min(OPENAI_OUTPUT_MAX, 4096))
      .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
    max_input_tokens: z.number().int().optional(),
    // Cohere sampling knobs: k = top-k, p = nucleus sampling.
    k: z.number().int().min(0).max(500).default(0).optional(),
    p: z.number().min(0.01).max(0.99).default(0.75).optional(),
    seed: z.number().int().optional(),
    stop_sequences: z.array(z.string()).max(5).optional(),
    frequency_penalty: z.number().min(0).max(1).default(0).optional(),
    presence_penalty: z.number().min(0).max(1).default(0).optional(),
    tools: z
      .array(
        z.object({
          name: z.string(),
          description: z.string(),
          parameter_definitions: z.record(
            z.object({
              description: z.string().optional(),
              type: z.string(),
              required: z.boolean().optional().default(false),
            })
          ),
        })
      )
      .optional(),
    tool_results: z
      .array(
        z.object({
          call: z.object({
            name: z.string(),
            parameters: z.record(z.any()),
          }),
          outputs: z.array(z.record(z.any())),
        })
      )
      .optional(),
    // We always force single step to avoid stateful calls or expensive multi-step
    // generations when tools are involved.
    force_single_step: z.literal(true).default(true).optional(),
  })
  .strip();
/**
 * One entry of `chat_history`: either a participant message
 * (CHATBOT/SYSTEM/USER) or a TOOL entry carrying `tool_results`.
 * Derived from the schema so it can never drift from validation.
 */
export type CohereChatMessage = NonNullable<
  z.infer<typeof CohereV1ChatSchema>["chat_history"]
>[number];
/**
 * Collapses a Cohere chat-history entry to a plain string.
 *
 * Participant entries (USER/CHATBOT/SYSTEM) return their `message` text.
 * TOOL entries return the `text` of each result's first output, joined with
 * newlines. Because `outputs` items are untyped (`z.array(z.any())`), a
 * result may have an empty `outputs` array or a first output without a
 * `text` field; both now yield an empty string for that result instead of
 * throwing a TypeError on `outputs[0]`.
 */
export function flattenCohereMessageContent(
  message: CohereChatMessage
): string {
  if (message.role === "TOOL") {
    // Optional chaining guards the empty-`outputs` case the schema permits.
    return message.tool_results
      .map((r) => r.outputs[0]?.text ?? "")
      .join("\n");
  }
  return message.message;
}
/**
 * Transforms an OpenAI chat-completions request body into a Cohere /chat
 * request body.
 *
 * Mapping: the final OpenAI message becomes Cohere's `message`; a leading
 * system message becomes `preamble`; everything in between becomes
 * `chat_history` with roles mapped assistant→CHATBOT, system→SYSTEM,
 * otherwise USER. Named messages are prefixed "name: content".
 *
 * @throws ZodError when the body fails OpenAI-schema validation.
 */
export const transformOpenAIToCohere: APIFormatTransformer<
  typeof CohereV1ChatSchema
> = async (req) => {
  const { body } = req;
  // Pin the model so the OpenAI schema's model validation doesn't reject
  // bodies that name a Cohere model.
  const result = OpenAIV1ChatCompletionSchema.safeParse({
    ...body,
    model: "gpt-3.5-turbo",
  });
  if (!result.success) {
    req.log.warn(
      { issues: result.error.issues, body },
      "Invalid OpenAI-to-Cohere request"
    );
    throw result.error;
  }
  const { messages, ...rest } = result.data;

  // Final OAI message becomes the `message` field in Cohere.
  const message = messages[messages.length - 1];

  // If the first message has system role, use it as preamble.
  const hasSystemPreamble = messages[0]?.role === "system";
  const preamble = hasSystemPreamble
    ? flattenOpenAIMessageContent(messages[0].content)
    : undefined;

  // Everything between the (optional) system preamble and the final message
  // becomes chat history. Skip the first message when it was consumed as the
  // preamble — previously `slice(0, -1)` kept it, so the system prompt was
  // sent twice (as `preamble` AND as a SYSTEM history entry).
  const chatHistory = messages
    .slice(hasSystemPreamble ? 1 : 0, -1)
    .map((m) => {
      const role: Exclude<CohereChatMessage["role"], "TOOL"> =
        m.role === "assistant"
          ? "CHATBOT"
          : m.role === "system"
          ? "SYSTEM"
          : "USER";
      const content = flattenOpenAIMessageContent(m.content);
      const message = m.name ? `${m.name}: ${content}` : content;
      return { role, message };
    });

  return {
    // NOTE(review): `rest.model` is the pinned "gpt-3.5-turbo", not the
    // caller's original model — confirm downstream code rewrites this before
    // the request reaches Cohere.
    model: rest.model,
    preamble,
    chat_history: chatHistory,
    message: flattenOpenAIMessageContent(message.content),
    // OpenAI `stop` may be a bare string or an array; Cohere wants an array.
    stop_sequences:
      typeof rest.stop === "string" ? [rest.stop] : rest.stop ?? undefined,
    max_tokens: rest.max_tokens,
    temperature: rest.temperature,
    p: rest.top_p,
    frequency_penalty: rest.frequency_penalty,
    presence_penalty: rest.presence_penalty,
    seed: rest.seed,
    stream: rest.stream,
  };
};
+20
View File
@@ -22,6 +22,7 @@ import {
transformOpenAIToGoogleAI,
} from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
import { CohereV1ChatSchema, transformOpenAIToCohere } from "./cohere";
export { OpenAIChatMessage } from "./openai";
export {
@@ -33,15 +34,29 @@ export {
export { GoogleAIChatMessage } from "./google-ai";
export { MistralAIChatMessage } from "./mistral-ai";
/** Represents a pair of API formats that can be transformed between. */
type APIPair = `${APIFormat}->${APIFormat}`;
/**
 * Represents a map of API format pairs to transformer functions. All keys
 * are optional — only pairs with a registered transformer are present.
 */
type TransformerMap = {
  [key in APIPair]?: APIFormatTransformer<any>;
};
/**
 * Represents a transformer function that takes a Request and returns a Promise
 * resolving to a value of the specified Zod schema type.
 *
 * @template Z The Zod schema type to transform the request into (from api-schemas).
 * @param req The incoming Request to transform.
 * @returns A Promise resolving to the transformed request body.
 */
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
  req: Request
) => Promise<z.infer<Z>>;
/**
* Specifies possible translations between API formats and the corresponding
* transformer functions to apply them.
*/
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
@@ -49,8 +64,12 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"openai->openai-text": transformOpenAIToOpenAIText,
"openai->openai-image": transformOpenAIToOpenAIImage,
"openai->google-ai": transformOpenAIToGoogleAI,
"openai->cohere-chat": transformOpenAIToCohere,
};
/**
* Specifies the schema for each API format to validate incoming requests.
*/
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"anthropic-chat": AnthropicV1MessagesSchema,
"anthropic-text": AnthropicV1TextSchema,
@@ -59,4 +78,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
"cohere-chat": CohereV1ChatSchema,
};
+2 -1
View File
@@ -9,7 +9,8 @@ export type APIFormat =
| "anthropic-chat" // Anthropic's newer messages array format
| "anthropic-text" // Legacy flat string prompt format
| "google-ai"
| "mistral-ai";
| "mistral-ai"
| "cohere-chat";
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
+18 -2
View File
@@ -14,7 +14,8 @@ export type LLMService =
| "google-ai"
| "mistral-ai"
| "aws"
| "azure";
| "azure"
| "cohere";
export type OpenAIModelFamily =
| "turbo"
@@ -32,13 +33,15 @@ export type MistralAIModelFamily =
| "mistral-large";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type CohereModelFamily = "command-r" | "command-r-plus";
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GoogleAIModelFamily
| MistralAIModelFamily
| AwsBedrockModelFamily
| AzureOpenAIModelFamily;
| AzureOpenAIModelFamily
| CohereModelFamily;
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -64,6 +67,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"azure-gpt4-turbo",
"azure-gpt4o",
"azure-dall-e",
"command-r",
"command-r-plus",
] as const);
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
@@ -75,6 +80,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
"mistral-ai",
"aws",
"azure",
"cohere",
] as const);
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@@ -116,6 +122,8 @@ export const MODEL_FAMILY_SERVICE: {
"mistral-small": "mistral-ai",
"mistral-medium": "mistral-ai",
"mistral-large": "mistral-ai",
"command-r": "cohere",
"command-r-plus": "cohere",
};
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
@@ -181,6 +189,11 @@ export function getAzureOpenAIModelFamily(
return defaultFamily;
}
/**
 * Maps a Cohere model ID to its model family. Any ID containing "plus"
 * (e.g. "command-r-plus", dated snapshots thereof) belongs to the
 * "command-r-plus" family; every other ID falls back to "command-r".
 */
export function getCohereModelFamily(model: string): CohereModelFamily {
  return model.includes("plus") ? "command-r-plus" : "command-r";
}
export function assertIsKnownModelFamily(
modelFamily: string
): asserts modelFamily is ModelFamily {
@@ -220,6 +233,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
case "mistral-ai":
modelFamily = getMistralAIModelFamily(model);
break;
case "cohere-chat":
modelFamily = getCohereModelFamily(model);
break;
default:
assertNever(req.outboundApi);
}