OpenAI Wire: full migration

Enrico Ros
2024-07-09 17:45:20 -07:00
parent 7afe4ab477
commit 415c4e2ec3
3 changed files with 38 additions and 85 deletions
@@ -163,7 +163,8 @@ export const openaiWire_chatCompletionRequest_Schema = z.object({
   max_tokens: z.number().optional(),
   temperature: z.number().min(0).max(2).optional(),
 
-  // other model configuration
+  // API configuration
+  n: z.number().int().positive().optional(), // defaulting 'n' to 1, as the derived-ecosystem does not support it
   stream: z.boolean().optional(), // If set, partial message deltas will be sent, with the stream terminated by a `data: [DONE]` message.
   stream_options: z.object({
     include_usage: z.boolean().optional(), // If set, an additional chunk will be streamed with a 'usage' field on the entire request.
@@ -199,7 +200,6 @@ export const openaiWire_chatCompletionRequest_Schema = z.object({
   // top_p: z.number().min(0).max(1).optional(),
 
   // (disabled) advanced API configuration
-  // n: z.number().int().positive().optional(), // defaulting 'n' to 1, as the derived-ecosystem does not support it
   // service_tier: z.unknown().optional(),
 });
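
With the request schema now expressed in Zod, the matching TypeScript type is derived rather than hand-maintained, which is the point of the migration: runtime validation and static typing can no longer drift apart. A minimal sketch of the derivation, assuming only the schema shown above (the type name matches the OpenaiWire_ChatCompletionRequest identifier imported in the router below):

import { z } from 'zod';

// The request type is inferred from the schema, so the two stay in sync by construction.
export type OpenaiWire_ChatCompletionRequest = z.infer<typeof openaiWire_chatCompletionRequest_Schema>;

// The same schema can then validate untrusted payloads at runtime:
// const request = openaiWire_chatCompletionRequest_Schema.parse(untrustedBody);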
+36 -21
@@ -10,9 +10,7 @@ import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i
 import { Brand } from '~/common/app.config';
 import { fixupHost } from '~/common/util/urlUtils';
 
-import { OpenaiWire_ChatCompletionRequest, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_FunctionDefinition, openaiWire_FunctionDefinition_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
-import type { OpenAIWire } from './openai.wiretypes';
+import { OpenaiWire_ChatCompletionRequest, OpenaiWire_ChatCompletionResponse, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_FunctionDefinition, openaiWire_FunctionDefinition_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
 import { azureModelToModelDescription, deepseekModelToModelDescription, groqModelSortFn, groqModelToModelDescription, lmStudioModelToModelDescription, localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription, perplexityAIModelDescriptions, perplexityAIModelSort, togetherAIModelsToModelDescriptions } from './models.data';
 import { llmsChatGenerateWithFunctionsOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
 import { wilreLocalAIModelsApplyOutputSchema, wireLocalAIModelsAvailableOutputSchema, wireLocalAIModelsListOutputSchema } from './localai.wiretypes';
@@ -266,7 +264,7 @@ export const llmOpenAIRouter = createTRPCRouter({
       const isFunctionsCall = !!functions && functions.length > 0;
 
       const completionsBody = openAIChatCompletionPayload(access.dialect, model, history, isFunctionsCall ? functions : null, forceFunctionName ?? null, 1, false);
-      const wireCompletions = await openaiPOSTOrThrow<OpenAIWire.ChatCompletion.Response, OpenaiWire_ChatCompletionRequest>(
+      const wireCompletions = await openaiPOSTOrThrow<OpenaiWire_ChatCompletionResponse, OpenaiWire_ChatCompletionRequest>(
         access, model.id, completionsBody, '/v1/chat/completions',
       );
@@ -286,9 +284,9 @@ export const llmOpenAIRouter = createTRPCRouter({
       // check for a function output
       // NOTE: this includes a workaround for when we requested a function but the model could not deliver
-      return (finish_reason === 'function_call' || 'function_call' in message)
-        ? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAIWire.ChatCompletion.ResponseFunctionCall)
-        : parseChatGenerateOutput(message as OpenAIWire.ChatCompletion.ResponseMessage, finish_reason);
+      return (finish_reason === 'tool_calls' || 'tool_calls' in message)
+        ? parseChatGenerateSingleToolFunctionOutput(isFunctionsCall, message)
+        : parseChatGenerateOutput(message, finish_reason);
     }),
 
   /* [OpenAI/LocalAI] images/generations */
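
The branch above is the core of the function-calling migration: the deprecated function_call field and finish_reason 'function_call' give way to the tool_calls array and finish_reason 'tool_calls'. For reference, a sketch of the two assistant-message shapes (field layout per the OpenAI API docs; the function name and call id are illustrative):

// Legacy shape: a single, implicit function call
const legacyMessage = {
  role: 'assistant',
  content: null,
  function_call: { name: 'get_weather', arguments: '{"city":"Paris"}' },
};

// Current shape: an array of tool calls, each with an id and a type discriminator
const toolCallsMessage = {
  role: 'assistant',
  content: null,
  tool_calls: [{
    id: 'call_abc123', // hypothetical id
    type: 'function',
    function: { name: 'get_weather', arguments: '{"city":"Paris"}' },
  }],
};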
@@ -644,15 +642,33 @@ export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: Open
     }, [] as OpenAIHistorySchema);
   }
 
-  return {
+  const chatCompletionRequest: OpenaiWire_ChatCompletionRequest = {
     model: model.id,
     messages: history,
-    ...(functions && { functions: functions, function_call: forceFunctionName ? { name: forceFunctionName } : 'auto' }),
-    ...(model.temperature !== undefined && { temperature: model.temperature }),
-    ...(model.maxTokens && { max_tokens: model.maxTokens }),
-    ...(n > 1 && { n }),
-    stream,
+    stream: stream,
     stream_options: {
       include_usage: true,
     },
   };
+  if (model.temperature !== undefined)
+    chatCompletionRequest.temperature = model.temperature;
+  if (model.maxTokens)
+    chatCompletionRequest.max_tokens = model.maxTokens;
+  if (functions?.length)
+    chatCompletionRequest.tools = functions.map(fun => ({
+      type: 'function',
+      function: fun,
+    }));
+  if (forceFunctionName)
+    chatCompletionRequest.tool_choice = {
+      type: 'function',
+      function: { name: forceFunctionName },
+    };
+  if (n > 1) {
+    throw new Error('OpenAI-derived API do not support n > 1 for chat completions, so we will not do it either');
+    // chatCompletionRequest.n = n;
+  }
+  return chatCompletionRequest;
 }
 
 async function openaiGETOrThrow<TOut extends object>(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
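
To make the rewritten builder concrete, a usage sketch under stated assumptions: getWeatherFn is a hypothetical OpenaiWire_FunctionDefinition, and the inline comments show illustrative values (the real model argument comes from the router's input schema):

// n must be 1 here: the builder now throws for n > 1.
const body = openAIChatCompletionPayload(
  access.dialect,   // e.g. 'openai'
  model,            // e.g. { id: 'gpt-4o', temperature: 0.7 }
  history,          // prior messages, already in wire format
  [getWeatherFn],   // exposed to the model as tools
  'get_weather',    // forces this tool via tool_choice
  1,                // n
  false,            // stream
);
// body.tools       -> [{ type: 'function', function: getWeatherFn }]
// body.tool_choice -> { type: 'function', function: { name: 'get_weather' } }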
@@ -665,7 +681,7 @@ async function openaiPOSTOrThrow<TOut extends object, TPostBody extends object>(
   return await fetchJsonOrTRPCThrow<TOut, TPostBody>({ url, method: 'POST', headers, body, name: `OpenAI/${access.dialect}` });
 }
 
-function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire.ChatCompletion.ResponseFunctionCall) {
+function parseChatGenerateSingleToolFunctionOutput(isFunctionsCall: boolean, message: OpenaiWire_ChatCompletionResponse['choices'][number]['message']) {
   // NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment
   if (!isFunctionsCall)
     throw new TRPCError({
@@ -673,17 +689,16 @@ function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire
       message: `[OpenAI Issue] Received a function call without a function call request`,
     });
 
-  // parse the function call
-  const fcMessage = message as any as OpenAIWire.ChatCompletion.ResponseFunctionCall;
-  if (fcMessage.content !== null)
+  // validate a single function call
+  if (!message.tool_calls || message.tool_calls.length !== 1 || message.tool_calls[0].type !== 'function' || message.content)
     throw new TRPCError({
       code: 'INTERNAL_SERVER_ERROR',
       message: `[OpenAI Issue] Expected a function call, got a message`,
     });
 
   // got a function call, so parse it
-  const fc = fcMessage.function_call;
-  if (!fc || !fc.name || !fc.arguments)
+  const fc = message.tool_calls[0].function;
+  if (!fc.name || !fc.arguments)
     throw new TRPCError({
       code: 'INTERNAL_SERVER_ERROR',
       message: `[OpenAI Issue] Issue with the function call, missing name or arguments`,
@@ -707,7 +722,7 @@ function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire
   };
 }
 
-function parseChatGenerateOutput(message: OpenAIWire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) {
+function parseChatGenerateOutput(message: OpenaiWire_ChatCompletionResponse['choices'][number]['message'], finish_reason: OpenaiWire_ChatCompletionResponse['choices'][number]['finish_reason']) {
   // validate the message
   if (message.content === null)
     throw new TRPCError({
@@ -718,6 +733,6 @@ function parseChatGenerateOutput(message: OpenAIWire.ChatCompletionResponseMess
   return {
     role: message.role,
     content: message.content,
-    finish_reason: finish_reason,
+    finish_reason: (finish_reason === 'stop' || finish_reason === 'length') ? finish_reason : null,
   };
 }
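
The coercion on finish_reason exists because the Zod response schema admits a wider union than the router's legacy output schema. A sketch of the narrowing, with the wire-level values assumed from the OpenAI docs:

// Wire level (assumed): more values than the legacy callers model.
type WireFinishReason = 'stop' | 'length' | 'tool_calls' | 'content_filter' | null;
// Legacy output: only these two, or null.
type LegacyFinishReason = 'stop' | 'length' | null;

const narrowFinishReason = (fr: WireFinishReason): LegacyFinishReason =>
  (fr === 'stop' || fr === 'length') ? fr : null;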
@@ -1,62 +0,0 @@
-/**
- * OpenAI API types - https://platform.openai.com/docs/api-reference/
- *
- * Notes:
- *  - 2023-12-22:
- *    Below we have the manually typed types for the OpenAI API. Everywhere else we are switching
- *    to Zod inferred types, and we shall do it here sooner (so we can validate upon parsing too).
- */
-export namespace OpenAIWire {
-
-  export namespace ChatCompletion {
-
-    export interface RequestFunctionDef { // [FN0613]
-      name: string;
-      description?: string;
-      parameters?: {
-        type: 'object';
-        properties: {
-          [key: string]: {
-            type: 'string' | 'number' | 'integer' | 'boolean';
-            description?: string;
-            enum?: string[];
-          }
-        }
-        required?: string[];
-      };
-    }
-
-    export interface Response {
-      id: string;
-      object: 'chat.completion';
-      created: number; // unix timestamp in seconds
-      model: string; // can differ from the ask, e.g. 'gpt-4-0314'
-      choices: {
-        index: number;
-        message: ResponseMessage | ResponseFunctionCall; // [FN0613]
-        finish_reason: 'stop' | 'length' | null | 'function_call'; // [FN0613]
-      }[];
-      usage: {
-        prompt_tokens: number;
-        completion_tokens: number;
-        total_tokens: number;
-      };
-    }
-
-    export interface ResponseMessage {
-      role: 'assistant';
-      content: string;
-    }
-
-    export interface ResponseFunctionCall { // [FN0613]
-      role: 'assistant';
-      content: null;
-      function_call: { // if content is null and finish_reason is 'function_call'
-        name: string;
-        arguments: string; // a JSON object, to deserialize
-      };
-    }
-
-  }
-}
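
The namespace deleted above is superseded by Zod schemas in oai.wiretypes. A sketch of what the replacement for ResponseMessage / ResponseFunctionCall plausibly looks like there, assuming zod and the naming conventions visible in the imports (the exact field list in the real schema may differ):

import { z } from 'zod';

// tool_calls supersedes the deprecated function_call field
const openaiWire_ToolCall_Schema = z.object({
  id: z.string(),
  type: z.literal('function'),
  function: z.object({
    name: z.string(),
    arguments: z.string(), // a JSON object, to deserialize
  }),
});

const openaiWire_ResponseMessage_Schema = z.object({
  role: z.literal('assistant'),
  content: z.string().nullable(), // null when the model calls tools
  tool_calls: z.array(openaiWire_ToolCall_Schema).optional(),
});

// the type is inferred where the old interface was hand-written:
// type ResponseMessage = z.infer<typeof openaiWire_ResponseMessage_Schema>;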