From 415c4e2ec32d9eb63efe6b3aa2d88b85b179dba4 Mon Sep 17 00:00:00 2001
From: Enrico Ros
Date: Tue, 9 Jul 2024 17:45:20 -0700
Subject: [PATCH] OpenAI Wire: full migration

---
 .../server/dispatch/openai/oai.wiretypes.ts   |  4 +-
 .../llms/server/openai/openai.router.ts       | 57 ++++++++++-------
 .../llms/server/openai/openai.wiretypes.ts    | 62 -------------------
 3 files changed, 38 insertions(+), 85 deletions(-)
 delete mode 100644 src/modules/llms/server/openai/openai.wiretypes.ts

diff --git a/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts b/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts
index e58091a50..3805610c3 100644
--- a/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts
+++ b/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts
@@ -163,7 +163,8 @@ export const openaiWire_chatCompletionRequest_Schema = z.object({
   max_tokens: z.number().optional(),
   temperature: z.number().min(0).max(2).optional(),
 
-  // other model configuration
+  // API configuration
+  n: z.number().int().positive().optional(), // defaulting 'n' to 1, as the derived-ecosystem does not support it
   stream: z.boolean().optional(), // If set, partial message deltas will be sent, with the stream terminated by a `data: [DONE]` message.
   stream_options: z.object({
     include_usage: z.boolean().optional(), // If set, an additional chunk will be streamed with a 'usage' field on the entire request.
@@ -199,7 +200,6 @@ export const openaiWire_chatCompletionRequest_Schema = z.object({
   // top_p: z.number().min(0).max(1).optional(),
 
   // (disabled) advanced API configuration
-  // n: z.number().int().positive().optional(), // defaulting 'n' to 1, as the derived-ecosystem does not support it
   // service_tier: z.unknown().optional(),
 
 });

diff --git a/src/modules/llms/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts
index c6ce64fed..83163619f 100644
--- a/src/modules/llms/server/openai/openai.router.ts
+++ b/src/modules/llms/server/openai/openai.router.ts
@@ -10,9 +10,7 @@ import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i
 import { Brand } from '~/common/app.config';
 import { fixupHost } from '~/common/util/urlUtils';
 
-import { OpenaiWire_ChatCompletionRequest, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_FunctionDefinition, openaiWire_FunctionDefinition_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
-
-import type { OpenAIWire } from './openai.wiretypes';
+import { OpenaiWire_ChatCompletionRequest, OpenaiWire_ChatCompletionResponse, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_FunctionDefinition, openaiWire_FunctionDefinition_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
 import { azureModelToModelDescription, deepseekModelToModelDescription, groqModelSortFn, groqModelToModelDescription, lmStudioModelToModelDescription, localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription, perplexityAIModelDescriptions, perplexityAIModelSort, togetherAIModelsToModelDescriptions } from './models.data';
 import { llmsChatGenerateWithFunctionsOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
 import { wilreLocalAIModelsApplyOutputSchema, wireLocalAIModelsAvailableOutputSchema, wireLocalAIModelsListOutputSchema } from './localai.wiretypes';
@@ -266,7 +264,7 @@ export const llmOpenAIRouter = createTRPCRouter({
 
       const isFunctionsCall = !!functions && functions.length > 0;
       const completionsBody = openAIChatCompletionPayload(access.dialect, model, history, isFunctionsCall ? functions : null, forceFunctionName ?? null, 1, false);
-      const wireCompletions = await openaiPOSTOrThrow<OpenAIWire.ChatCompletion.Response, OpenaiWire_ChatCompletionRequest>(
+      const wireCompletions = await openaiPOSTOrThrow<OpenaiWire_ChatCompletionResponse, OpenaiWire_ChatCompletionRequest>(
         access, model.id, completionsBody, '/v1/chat/completions',
       );
 
@@ -286,9 +284,9 @@ export const llmOpenAIRouter = createTRPCRouter({
 
       // check for a function output
      // NOTE: this includes a workaround for when we requested a function but the model could not deliver
-      return (finish_reason === 'function_call' || 'function_call' in message)
-        ? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAIWire.ChatCompletion.ResponseFunctionCall)
-        : parseChatGenerateOutput(message as OpenAIWire.ChatCompletion.ResponseMessage, finish_reason);
+      return (finish_reason === 'tool_calls' || 'tool_calls' in message)
+        ? parseChatGenerateSingleToolFunctionOutput(isFunctionsCall, message)
+        : parseChatGenerateOutput(message, finish_reason);
     }),
 
   /* [OpenAI/LocalAI] images/generations */
@@ -644,15 +642,33 @@ export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: Open
     }, [] as OpenAIHistorySchema);
   }
 
-  return {
+  const chatCompletionRequest: OpenaiWire_ChatCompletionRequest = {
     model: model.id,
     messages: history,
-    ...(functions && { functions: functions, function_call: forceFunctionName ? { name: forceFunctionName } : 'auto' }),
-    ...(model.temperature !== undefined && { temperature: model.temperature }),
-    ...(model.maxTokens && { max_tokens: model.maxTokens }),
-    ...(n > 1 && { n }),
-    stream,
+    stream: stream,
+    stream_options: {
+      include_usage: true,
+    },
   };
+  if (model.temperature !== undefined)
+    chatCompletionRequest.temperature = model.temperature;
+  if (model.maxTokens)
+    chatCompletionRequest.max_tokens = model.maxTokens;
+  if (functions?.length)
+    chatCompletionRequest.tools = functions.map(fun => ({
+      type: 'function',
+      function: fun,
+    }));
+  if (forceFunctionName)
+    chatCompletionRequest.tool_choice = {
+      type: 'function',
+      function: { name: forceFunctionName },
+    };
+  if (n > 1) {
+    throw new Error('OpenAI-derived APIs do not support n > 1 for chat completions, so we will not either');
+    // chatCompletionRequest.n = n;
+  }
+  return chatCompletionRequest;
 }
 
 async function openaiGETOrThrow<TOut extends object>(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
@@ -665,7 +681,7 @@ async function openaiPOSTOrThrow(
   return await fetchJsonOrTRPCThrow({ url, method: 'POST', headers, body, name: `OpenAI/${access.dialect}` });
 }
 
-function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire.ChatCompletion.ResponseFunctionCall) {
+function parseChatGenerateSingleToolFunctionOutput(isFunctionsCall: boolean, message: OpenaiWire_ChatCompletionResponse['choices'][number]['message']) {
   // NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment
   if (!isFunctionsCall)
     throw new TRPCError({
@@ -673,17 +689,16 @@ function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire
       message: `[OpenAI Issue] Received a function call without a function call request`,
     });
 
-  // parse the function call
-  const fcMessage = message as any as OpenAIWire.ChatCompletion.ResponseFunctionCall;
-  if (fcMessage.content !== null)
+  // validate a single function call
+  if (!message.tool_calls || message.tool_calls.length !== 1 || message.tool_calls[0].type !== 'function' || message.content)
     throw new TRPCError({
       code: 'INTERNAL_SERVER_ERROR',
       message: `[OpenAI Issue] Expected a function call, got a message`,
     });
 
   // got a function call, so parse it
-  const fc = fcMessage.function_call;
-  if (!fc || !fc.name || !fc.arguments)
+  const fc = message.tool_calls[0].function;
+  if (!fc.name || !fc.arguments)
     throw new TRPCError({
       code: 'INTERNAL_SERVER_ERROR',
       message: `[OpenAI Issue] Issue with the function call, missing name or arguments`,
@@ -707,7 +722,7 @@ function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire
   };
 }
 
-function parseChatGenerateOutput(message: OpenAIWire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) {
+function parseChatGenerateOutput(message: OpenaiWire_ChatCompletionResponse['choices'][number]['message'], finish_reason: OpenaiWire_ChatCompletionResponse['choices'][number]['finish_reason']) {
   // validate the message
   if (message.content === null)
     throw new TRPCError({
@@ -718,6 +733,6 @@ function parseChatGenerateOutput(message: OpenAIWire.ChatCompletion.ResponseMess
   return {
     role: message.role,
     content: message.content,
-    finish_reason: finish_reason,
+    finish_reason: (finish_reason === 'stop' || finish_reason === 'length') ? finish_reason : null,
   };
 }
\ No newline at end of file

diff --git a/src/modules/llms/server/openai/openai.wiretypes.ts b/src/modules/llms/server/openai/openai.wiretypes.ts
deleted file mode 100644
index 7c3254e60..000000000
--- a/src/modules/llms/server/openai/openai.wiretypes.ts
+++ /dev/null
@@ -1,62 +0,0 @@
-/**
- * OpenAI API types - https://platform.openai.com/docs/api-reference/
- *
- * Notes:
- *  - 2023-12-22:
- *    Below we have the manually typed types for the OpenAI API. Everywhere else we are switching
- *    to Zod inferred types, and we shall do it here sooner (so we can validate upon parsing too).
- */
-export namespace OpenAIWire {
-
-  export namespace ChatCompletion {
-
-    export interface RequestFunctionDef { // [FN0613]
-      name: string;
-      description?: string;
-      parameters?: {
-        type: 'object';
-        properties: {
-          [key: string]: {
-            type: 'string' | 'number' | 'integer' | 'boolean';
-            description?: string;
-            enum?: string[];
-          }
-        }
-        required?: string[];
-      };
-    }
-
-
-    export interface Response {
-      id: string;
-      object: 'chat.completion';
-      created: number; // unix timestamp in seconds
-      model: string; // can differ from the ask, e.g. 'gpt-4-0314'
-      choices: {
-        index: number;
-        message: ResponseMessage | ResponseFunctionCall; // [FN0613]
-        finish_reason: 'stop' | 'length' | null | 'function_call'; // [FN0613]
-      }[];
-      usage: {
-        prompt_tokens: number;
-        completion_tokens: number;
-        total_tokens: number;
-      };
-    }
-
-    export interface ResponseMessage {
-      role: 'assistant';
-      content: string;
-    }
-
-    export interface ResponseFunctionCall { // [FN0613]
-      role: 'assistant';
-      content: null;
-      function_call: { // if content is null and finish_reason is 'function_call'
-        name: string;
-        arguments: string; // a JSON object, to deserialize
-      };
-    }
-
-  }
-}
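
Reviewer aid (not part of the patch): a minimal usage sketch of the migrated openAIChatCompletionPayload(), showing how the legacy functions + forceFunctionName inputs now surface as the OpenAI 'tools' / 'tool_choice' wire fields. The model id, history entry, and get_weather definition are hypothetical; the argument order follows the call site in the -266 hunk above.

  // Hypothetical call: (dialect, model, history, functions, forceFunctionName, n, stream)
  const body = openAIChatCompletionPayload(
    'openai',
    { id: 'gpt-4o', temperature: 0.5, maxTokens: 1024 },            // model shape assumed from field accesses in the hunk
    [{ role: 'user', content: 'Weather in Rome?' }],
    [{ name: 'get_weather', parameters: { type: 'object', properties: { city: { type: 'string' } } } }],
    'get_weather',
    1,       // n: any value > 1 now throws, per the new guard
    false,   // stream; stream_options.include_usage is set unconditionally by the builder
  );
  // body.tools       -> [{ type: 'function', function: { name: 'get_weather', ... } }]
  // body.tool_choice -> { type: 'function', function: { name: 'get_weather' } }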
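
For reference, a sketch of the assistant message shape that the renamed parseChatGenerateSingleToolFunctionOutput() accepts under the OpenAI tool-calling wire format: exactly one tool call of type 'function', empty content, and a JSON-encoded arguments string. The id value is illustrative.

  // Sketch of a single-tool-call choice message; the parser rejects anything else
  const message = {
    role: 'assistant' as const,
    content: null,                 // must be falsy, or the parser throws
    tool_calls: [{
      id: 'call_abc123',           // illustrative id
      type: 'function' as const,
      function: { name: 'get_weather', arguments: '{"city":"Rome"}' }, // arguments: JSON string, deserialized downstream
    }],
  };
  // the enclosing choice would carry finish_reason: 'tool_calls'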