mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-11 14:10:15 -07:00
OpenAI Wire: full migration
This commit is contained in:
@@ -163,7 +163,8 @@ export const openaiWire_chatCompletionRequest_Schema = z.object({
|
||||
max_tokens: z.number().optional(),
|
||||
temperature: z.number().min(0).max(2).optional(),
|
||||
|
||||
// other model configuration
|
||||
// API configuration
|
||||
n: z.number().int().positive().optional(), // defaulting 'n' to 1, as the derived-ecosystem does not support it
|
||||
stream: z.boolean().optional(), // If set, partial message deltas will be sent, with the stream terminated by a `data: [DONE]` message.
|
||||
stream_options: z.object({
|
||||
include_usage: z.boolean().optional(), // If set, an additional chunk will be streamed with a 'usage' field on the entire request.
|
||||
@@ -199,7 +200,6 @@ export const openaiWire_chatCompletionRequest_Schema = z.object({
|
||||
// top_p: z.number().min(0).max(1).optional(),
|
||||
|
||||
// (disabled) advanced API configuration
|
||||
// n: z.number().int().positive().optional(), // defaulting 'n' to 1, as the derived-ecosystem does not support it
|
||||
// service_tier: z.unknown().optional(),
|
||||
|
||||
});
|
||||
|
||||
@@ -10,9 +10,7 @@ import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i
|
||||
import { Brand } from '~/common/app.config';
|
||||
import { fixupHost } from '~/common/util/urlUtils';
|
||||
|
||||
import { OpenaiWire_ChatCompletionRequest, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_FunctionDefinition, openaiWire_FunctionDefinition_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
|
||||
|
||||
import type { OpenAIWire } from './openai.wiretypes';
|
||||
import { OpenaiWire_ChatCompletionRequest, OpenaiWire_ChatCompletionResponse, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_FunctionDefinition, openaiWire_FunctionDefinition_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
|
||||
import { azureModelToModelDescription, deepseekModelToModelDescription, groqModelSortFn, groqModelToModelDescription, lmStudioModelToModelDescription, localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription, perplexityAIModelDescriptions, perplexityAIModelSort, togetherAIModelsToModelDescriptions } from './models.data';
|
||||
import { llmsChatGenerateWithFunctionsOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
|
||||
import { wilreLocalAIModelsApplyOutputSchema, wireLocalAIModelsAvailableOutputSchema, wireLocalAIModelsListOutputSchema } from './localai.wiretypes';
|
||||
@@ -266,7 +264,7 @@ export const llmOpenAIRouter = createTRPCRouter({
|
||||
const isFunctionsCall = !!functions && functions.length > 0;
|
||||
|
||||
const completionsBody = openAIChatCompletionPayload(access.dialect, model, history, isFunctionsCall ? functions : null, forceFunctionName ?? null, 1, false);
|
||||
const wireCompletions = await openaiPOSTOrThrow<OpenAIWire.ChatCompletion.Response, OpenaiWire_ChatCompletionRequest>(
|
||||
const wireCompletions = await openaiPOSTOrThrow<OpenaiWire_ChatCompletionResponse, OpenaiWire_ChatCompletionRequest>(
|
||||
access, model.id, completionsBody, '/v1/chat/completions',
|
||||
);
|
||||
|
||||
@@ -286,9 +284,9 @@ export const llmOpenAIRouter = createTRPCRouter({
|
||||
|
||||
// check for a function output
|
||||
// NOTE: this includes a workaround for when we requested a function but the model could not deliver
|
||||
return (finish_reason === 'function_call' || 'function_call' in message)
|
||||
? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAIWire.ChatCompletion.ResponseFunctionCall)
|
||||
: parseChatGenerateOutput(message as OpenAIWire.ChatCompletion.ResponseMessage, finish_reason);
|
||||
return (finish_reason === 'tool_calls' || 'tool_calls' in message)
|
||||
? parseChatGenerateSingleToolFunctionOutput(isFunctionsCall, message)
|
||||
: parseChatGenerateOutput(message, finish_reason);
|
||||
}),
|
||||
|
||||
/* [OpenAI/LocalAI] images/generations */
|
||||
@@ -644,15 +642,33 @@ export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: Open
|
||||
}, [] as OpenAIHistorySchema);
|
||||
}
|
||||
|
||||
return {
|
||||
const chatCompletionRequest: OpenaiWire_ChatCompletionRequest = {
|
||||
model: model.id,
|
||||
messages: history,
|
||||
...(functions && { functions: functions, function_call: forceFunctionName ? { name: forceFunctionName } : 'auto' }),
|
||||
...(model.temperature !== undefined && { temperature: model.temperature }),
|
||||
...(model.maxTokens && { max_tokens: model.maxTokens }),
|
||||
...(n > 1 && { n }),
|
||||
stream,
|
||||
stream: stream,
|
||||
stream_options: {
|
||||
include_usage: true,
|
||||
},
|
||||
};
|
||||
if (model.temperature !== undefined)
|
||||
chatCompletionRequest.temperature = model.temperature;
|
||||
if (model.maxTokens)
|
||||
chatCompletionRequest.max_tokens = model.maxTokens;
|
||||
if (functions?.length)
|
||||
chatCompletionRequest.tools = functions.map(fun => ({
|
||||
type: 'function',
|
||||
function: fun,
|
||||
}));
|
||||
if (forceFunctionName)
|
||||
chatCompletionRequest.tool_choice = {
|
||||
type: 'function',
|
||||
function: { name: forceFunctionName },
|
||||
};
|
||||
if (n > 1) {
|
||||
throw new Error('OpenAI-derived API do not support n > 1 for chat completions, so we will not do it either');
|
||||
// chatCompletionRequest.n = n;
|
||||
}
|
||||
return chatCompletionRequest;
|
||||
}
|
||||
|
||||
async function openaiGETOrThrow<TOut extends object>(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
|
||||
@@ -665,7 +681,7 @@ async function openaiPOSTOrThrow<TOut extends object, TPostBody extends object>(
|
||||
return await fetchJsonOrTRPCThrow<TOut, TPostBody>({ url, method: 'POST', headers, body, name: `OpenAI/${access.dialect}` });
|
||||
}
|
||||
|
||||
function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire.ChatCompletion.ResponseFunctionCall) {
|
||||
function parseChatGenerateSingleToolFunctionOutput(isFunctionsCall: boolean, message: OpenaiWire_ChatCompletionResponse['choices'][number]['message']) {
|
||||
// NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment
|
||||
if (!isFunctionsCall)
|
||||
throw new TRPCError({
|
||||
@@ -673,17 +689,16 @@ function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire
|
||||
message: `[OpenAI Issue] Received a function call without a function call request`,
|
||||
});
|
||||
|
||||
// parse the function call
|
||||
const fcMessage = message as any as OpenAIWire.ChatCompletion.ResponseFunctionCall;
|
||||
if (fcMessage.content !== null)
|
||||
// validate a single function call
|
||||
if (!message.tool_calls || message.tool_calls.length !== 1 || message.tool_calls[0].type !== 'function' || message.content)
|
||||
throw new TRPCError({
|
||||
code: 'INTERNAL_SERVER_ERROR',
|
||||
message: `[OpenAI Issue] Expected a function call, got a message`,
|
||||
});
|
||||
|
||||
// got a function call, so parse it
|
||||
const fc = fcMessage.function_call;
|
||||
if (!fc || !fc.name || !fc.arguments)
|
||||
const fc = message.tool_calls[0].function;
|
||||
if (!fc.name || !fc.arguments)
|
||||
throw new TRPCError({
|
||||
code: 'INTERNAL_SERVER_ERROR',
|
||||
message: `[OpenAI Issue] Issue with the function call, missing name or arguments`,
|
||||
@@ -707,7 +722,7 @@ function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAIWire
|
||||
};
|
||||
}
|
||||
|
||||
function parseChatGenerateOutput(message: OpenAIWire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) {
|
||||
function parseChatGenerateOutput(message: OpenaiWire_ChatCompletionResponse['choices'][number]['message'], finish_reason: OpenaiWire_ChatCompletionResponse['choices'][number]['finish_reason']) {
|
||||
// validate the message
|
||||
if (message.content === null)
|
||||
throw new TRPCError({
|
||||
@@ -718,6 +733,6 @@ function parseChatGenerateOutput(message: OpenAIWire.ChatCompletion.ResponseMess
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
finish_reason: finish_reason,
|
||||
finish_reason: (finish_reason === 'stop' || finish_reason === 'length') ? finish_reason : null,
|
||||
};
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
/**
|
||||
* OpenAI API types - https://platform.openai.com/docs/api-reference/
|
||||
*
|
||||
* Notes:
|
||||
* - 2023-12-22:
|
||||
* Below we have the manually typed types for the OpenAI API. Everywhere else we are switching
|
||||
* to Zod inferred types, and we shall do it here soon (so we can validate upon parsing too).
|
||||
*/
|
||||
// Manually written (non-Zod) type declarations for OpenAI API wire objects.
export namespace OpenAIWire {
|
||||
|
||||
// Shapes related to chat completions: one function definition and the response objects.
export namespace ChatCompletion {
|
||||
|
||||
// A function definition: a required name, an optional description, and an optional
// object-typed parameter schema whose properties are restricted to primitive types.
// NOTE(review): presumably sent as part of the request (per the name) - confirm against callers.
export interface RequestFunctionDef { // [FN0613]
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters?: {
|
||||
type: 'object';
|
||||
properties: {
|
||||
// one entry per parameter, keyed by an arbitrary string
[key: string]: {
|
||||
type: 'string' | 'number' | 'integer' | 'boolean';
|
||||
description?: string;
|
||||
enum?: string[];
|
||||
}
|
||||
}
|
||||
required?: string[];
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
// A 'chat.completion' object: identification fields, the list of choices, and token usage counts.
export interface Response {
|
||||
id: string;
|
||||
object: 'chat.completion';
|
||||
created: number; // unix timestamp in seconds
|
||||
model: string; // can differ from the ask, e.g. 'gpt-4-0314'
|
||||
// each choice holds either a text message or a function call, plus the reason generation ended
choices: {
|
||||
index: number;
|
||||
message: ResponseMessage | ResponseFunctionCall; // [FN0613]
|
||||
finish_reason: 'stop' | 'length' | null | 'function_call'; // [FN0613]
|
||||
}[];
|
||||
usage: {
|
||||
prompt_tokens: number;
|
||||
completion_tokens: number;
|
||||
total_tokens: number;
|
||||
};
|
||||
}
|
||||
|
||||
// Assistant message variant whose content is a string.
export interface ResponseMessage {
|
||||
role: 'assistant';
|
||||
content: string;
|
||||
}
|
||||
|
||||
// Assistant message variant whose content is null and which carries a function call instead.
export interface ResponseFunctionCall { // [FN0613]
|
||||
role: 'assistant';
|
||||
content: null;
|
||||
function_call: { // if content is null and finish_reason is 'function_call'
|
||||
name: string;
|
||||
arguments: string; // a JSON object, to deserialize
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user