CallChatWithFunctions - functions support, incl. OpenAI Implementation

May be rough around the edges, but should not create issues.
The implementation is defensive and excessively validates the
return types, as the OpenAI API is brittle and can easily misbehave.
Enrico Ros
2023-06-28 03:00:25 -07:00
parent 87d9309a8e
commit 2d4c0e9c64
8 changed files with 191 additions and 72 deletions
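For a sense of the new API surface, here is a minimal usage sketch of the client-side entry point added by this commit. The import path, model id, and the get_weather function are illustrative assumptions, not part of the commit:

    import { callChatGenerateWithFunctions, VChatFunctionIn, VChatMessageIn } from '~/modules/llms/llm.client'; // path assumed from the module layout below

    // hypothetical function definition (VChatFunctionIn aliases OpenAI.Wire.ChatCompletion.RequestFunctionDef)
    const functions: VChatFunctionIn[] = [{
      name: 'get_weather',
      description: 'Get the current weather for a city',
      parameters: {
        type: 'object',
        properties: {
          city: { type: 'string', description: 'The city name' },
        },
        required: ['city'],
      },
    }];

    const messages: VChatMessageIn[] = [{ role: 'user', content: 'What is the weather like in Paris?' }];

    // the first argument must be a DLLMId registered in the models store; this id is made up
    const out = await callChatGenerateWithFunctions('openai-gpt-3.5-turbo-0613', messages, functions);

    if ('function_name' in out)
      console.log('function call:', out.function_name, out.function_arguments); // arguments arrive JSON-parsed, or the server throws
    else
      console.log('message:', out.content); // a regular assistant message

The return type is the VChatMessageOut | VChatFunctionCallOut union defined below, so callers discriminate on the function_name key.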
+2 -2
@@ -1,7 +1,7 @@
 import { NextRequest, NextResponse } from 'next/server';
 import { createParser } from 'eventsource-parser';
-import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAICompletionRequest } from '~/modules/llms/openai/openai.router';
+import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAIChatCompletionRequest } from '~/modules/llms/openai/openai.router';
 import { OpenAI } from '~/modules/llms/openai/openai.types';
@@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
   // prepare request objects
   const { headers, url } = openAIAccess(access, '/v1/chat/completions');
-  const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true);
+  const body: OpenAI.Wire.ChatCompletion.Request = openAIChatCompletionRequest(model, history, null, true);
   // perform the request
   upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal });
+41 -9
@@ -1,17 +1,49 @@
-import { DLLMId } from '~/modules/llms/llm.types';
-import { findVendorById } from '~/modules/llms/vendor.registry';
-import { useModelsStore } from '~/modules/llms/store-llms';
-
+import { DLLM, DLLMId } from './llm.types';
 import { OpenAI } from './openai/openai.types';
+import { findVendorById } from './vendor.registry';
+import { useModelsStore } from './store-llms';
 
-export async function callChatGenerate(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
-  // get the vendor
+export type ModelVendorCallChatFn = (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) => Promise<VChatMessageOut>;
+export type ModelVendorCallChatWithFunctionsFn = (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) => Promise<VChatMessageOrFunctionCallOut>;
+
+export interface VChatMessageIn {
+  role: 'assistant' | 'system' | 'user'; // | 'function';
+  content: string;
+  //name?: string; // when role: 'function'
+}
+
+export type VChatFunctionIn = OpenAI.Wire.ChatCompletion.RequestFunctionDef;
+
+export interface VChatMessageOut {
+  role: 'assistant' | 'system' | 'user';
+  content: string;
+  finish_reason: 'stop' | 'length' | null;
+}
+
+export interface VChatFunctionCallOut {
+  function_name: string;
+  function_arguments: object | null;
+}
+
+export type VChatMessageOrFunctionCallOut = VChatMessageOut | VChatFunctionCallOut;
+
+export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise<VChatMessageOut> {
+  const { llm, vendor } = getLLMAndVendorOrThrow(llmId);
+  return await vendor.callChat(llm, messages, maxTokens);
+}
+
+export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number): Promise<VChatMessageOrFunctionCallOut> {
+  const { llm, vendor } = getLLMAndVendorOrThrow(llmId);
+  return await vendor.callChatWithFunctions(llm, messages, functions, maxTokens);
+}
+
+function getLLMAndVendorOrThrow(llmId: string) {
   const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);
   const vendor = findVendorById(llm?._source.vId);
   if (!llm || !vendor) throw new Error(`callChat: Vendor not found for LLM ${llmId}`);
-  // go for it
-  return await vendor.callChat(llm, messages, maxTokens);
+  return { llm, vendor };
 }
+3 -5
@@ -1,12 +1,11 @@
 import type React from 'react';
 
 import type { LLMOptionsOpenAI, SourceSetupOpenAI } from './openai/openai.vendor';
-import type { OpenAI } from './openai/openai.types';
+import type { ModelVendorCallChatFn, ModelVendorCallChatWithFunctionsFn } from './llm.client';
 import type { SourceSetupLocalAI } from './localai/localai.vendor';
 
 export type DLLMId = string;
 // export type DLLMTags = 'stream' | 'chat';
 export type DLLMOptions = LLMOptionsOpenAI; //DLLMValuesOpenAI | DLLMVaLocalAIDLLMValues;
 export type DModelSourceId = string;
 export type DModelSourceSetup = SourceSetupOpenAI | SourceSetupLocalAI;
@@ -60,6 +59,5 @@ export interface ModelVendor {
   // functions
   callChat: ModelVendorCallChatFn;
-}
-
-type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise<OpenAI.API.Chat.Response>;
+  callChatWithFunctions: ModelVendorCallChatWithFunctionsFn;
+}
+2 -1
@@ -17,7 +17,8 @@ export const ModelVendorLocalAI: ModelVendor = {
   LLMOptionsComponent: () => <>No LocalAI Options</>,
 
   // functions
-  callChat: () => Promise.reject(new Error('LocalAI is not implemented')),
+  callChat: () => Promise.reject(new Error('LocalAI chat is not implemented')),
+  callChatWithFunctions: () => Promise.reject(new Error('LocalAI chatWithFunctions is not implemented')),
 };
+19 -8
@@ -1,7 +1,7 @@
 import { apiAsync } from '~/modules/trpc/trpc.client';
 
-import { DLLM } from '../llm.types';
-import { OpenAI } from './openai.types';
+import type { DLLM } from '../llm.types';
+import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../llm.client';
 import { normalizeOAISetup, SourceSetupOpenAI } from './openai.vendor';
@@ -10,10 +10,17 @@ export const hasServerKeyOpenAI = !!process.env.HAS_SERVER_KEY_OPENAI;
 export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.startsWith('sk-') && apiKey.length > 40;
 
+export const callChat = async (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) =>
+  callChatOverloaded<VChatMessageOut>(llm, messages, null, maxTokens);
+
+export const callChatWithFunctions = async (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) =>
+  callChatOverloaded<VChatMessageOrFunctionCallOut>(llm, messages, functions, maxTokens);
+
 /**
- * This function either returns the LLM response, or throws a descriptive error string
+ * This function either returns the LLM message, or function calls, or throws a descriptive error string
  */
-export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
+async function callChatOverloaded<TOut extends VChatMessageOrFunctionCallOut>(llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, maxTokens?: number): Promise<TOut> {
 
   // access params (source)
   const partialSetup = llm._source.setup as Partial<SourceSetupOpenAI>;
   const sourceSetupOpenAI = normalizeOAISetup(partialSetup);
@@ -21,14 +28,18 @@ export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.R
 
   // model params (llm)
   const openaiLlmRef = llm.options.llmRef!;
   const modelTemperature = llm.options.llmTemperature || 0.5;
   // const maxTokens = llm.options.llmResponseTokens || 1024; // <- note: this would be for chat answers, not programmatic chat calls
 
   try {
-    return await apiAsync.openai.chatGenerate.mutate({
+    return await apiAsync.openai.chatGenerateWithFunctions.mutate({
       access: sourceSetupOpenAI,
-      model: { id: openaiLlmRef, temperature: modelTemperature, ...(maxTokens && { maxTokens }) },
+      model: {
+        id: openaiLlmRef,
+        temperature: modelTemperature,
+        ...(maxTokens && { maxTokens }),
+      },
+      functions: functions ?? undefined,
       history: messages,
-    });
+    }) as TOut;
     // errorMessage = `issue fetching: ${response.status} · ${response.statusText}${errorPayload ? ' · ' + JSON.stringify(errorPayload) : ''}`;
   } catch (error: any) {
     const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Fetch Error';
+104 -27
@@ -10,6 +10,8 @@ import { OpenAI } from './openai.types';
 // console.warn('OPENAI_API_KEY has not been provided in this deployment environment. Will need client-supplied keys, which is not recommended.');
 
+// Input Schemas
+
 const accessSchema = z.object({
   oaiKey: z.string().trim(),
   oaiOrg: z.string().trim(),
@@ -29,7 +31,7 @@ const historySchema = z.array(z.object({
   content: z.string(),
 }));
 
-/*const functionsSchema = z.array(z.object({
+const functionsSchema = z.array(z.object({
   name: z.string(),
   description: z.string().optional(),
   parameters: z.object({
@@ -41,12 +43,29 @@ const historySchema = z.array(z.object({
   })),
   required: z.array(z.string()).optional(),
 }).optional(),
-}));*/
+}));
 
-export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema });
+export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema, functions: functionsSchema.optional() });
 export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;
 
-export const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
+const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
+
+// Output Schemas
+
+const chatGenerateWithFunctionsOutputSchema = z.union([
+  z.object({
+    role: z.enum(['assistant', 'system', 'user']),
+    content: z.string(),
+    finish_reason: z.union([z.enum(['stop', 'length']), z.null()]),
+  }),
+  z.object({
+    function_name: z.string(),
+    function_arguments: z.record(z.any()),
+  }),
+]);
 
 export const openAIRouter = createTRPCRouter({
@@ -54,33 +73,29 @@ export const openAIRouter = createTRPCRouter({
 
   /**
    * Chat-based message generation
   */
-  chatGenerate: publicProcedure
+  chatGenerateWithFunctions: publicProcedure
     .input(chatGenerateSchema)
-    .mutation(async ({ input }): Promise<OpenAI.API.Chat.Response> => {
+    .output(chatGenerateWithFunctionsOutputSchema)
+    .mutation(async ({ input }) => {
 
-      const { access, model, history } = input;
-      const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false);
-      let wireCompletions: OpenAI.Wire.ChatCompletion.Response;
+      const { access, model, history, functions } = input;
+      const isFunctionsCall = !!functions && functions.length > 0;
 
-      // try {
-      wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(access, requestBody, '/v1/chat/completions');
-      // } catch (error: any) {
-      //   // NOTE: disabled on 2023-06-19: show all errors, 429 is not that common now, and could explain issues
-      //   // don't log 429 errors on the server-side, they are expected
-      //   if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests'))
-      //     console.error('api/openai/chat error:', error);
-      //   throw error;
-      // }
+      const wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(
+        access,
+        openAIChatCompletionRequest(model, history, isFunctionsCall ? functions : null, false),
+        '/v1/chat/completions',
+      );
 
       // expect a single output
       if (wireCompletions?.choices?.length !== 1)
         throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}` });
-      const singleChoice = wireCompletions.choices[0];
-      return {
-        role: singleChoice.message.role,
-        content: singleChoice.message.content,
-        finish_reason: singleChoice.finish_reason,
-      };
+      const { message, finish_reason } = wireCompletions.choices[0];
+
+      // check for a function output
+      return finish_reason === 'function_call'
+        ? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAI.Wire.ChatCompletion.ResponseFunctionCall)
+        : parseChatGenerateOutput(message as OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason);
     }),
 
   /**
@@ -147,6 +162,7 @@ export const openAIRouter = createTRPCRouter({
 type AccessSchema = z.infer<typeof accessSchema>;
 type ModelSchema = z.infer<typeof modelSchema>;
 type HistorySchema = z.infer<typeof historySchema>;
+type FunctionsSchema = z.infer<typeof functionsSchema>;
 
 async function openaiGET<TOut>(access: AccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
   const { headers, url } = openAIAccess(access, apiPath);
@@ -171,7 +187,11 @@ async function openaiPOST<TBody, TOut>(access: AccessSchema, body: TBody, apiPat
         : `[Issue] ${response.statusText}`,
     });
   }
-  return await response.json() as TOut;
+  try {
+    return await response.json();
+  } catch (error: any) {
+    throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] ${error?.message || error}` });
+  }
 }
 
 export function openAIAccess(access: AccessSchema, apiPath: string): { headers: HeadersInit, url: string } {
@@ -203,14 +223,71 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers:
   };
 }
 
-export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
+export function openAIChatCompletionRequest(model: ModelSchema, history: HistorySchema, functions: FunctionsSchema | null, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
   return {
     model: model.id,
     messages: history,
-    // ...(functions && { functions: functions, function_call: 'auto', }),
+    ...(functions && { functions: functions, function_call: 'auto' }),
     ...(model.temperature && { temperature: model.temperature }),
     ...(model.maxTokens && { max_tokens: model.maxTokens }),
     stream,
     n: 1,
   };
 }
+
+function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAI.Wire.ChatCompletion.ResponseFunctionCall) {
+  // NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment
+  if (!isFunctionsCall)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Received a function call without a function call request`,
+    });
+
+  // parse the function call
+  const fcMessage = message as any as OpenAI.Wire.ChatCompletion.ResponseFunctionCall;
+  if (fcMessage.content !== null)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Expected a function call, got a message`,
+    });
+
+  // got a function call, so parse it
+  const fc = fcMessage.function_call;
+  if (!fc || !fc.name || !fc.arguments)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Issue with the function call, missing name or arguments`,
+    });
+
+  // decode the function call
+  const fcName = fc.name;
+  let fcArgs: object;
+  try {
+    fcArgs = JSON.parse(fc.arguments);
+  } catch (error: any) {
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Issue with the function call, arguments are not valid JSON`,
+    });
+  }
+
+  return {
+    function_name: fcName,
+    function_arguments: fcArgs,
+  };
+}
+
+function parseChatGenerateOutput(message: OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) {
+  // validate the message
+  if (message.content === null)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Expected a message, got a null message`,
+    });
+
+  return {
+    role: message.role,
+    content: message.content,
+    finish_reason: finish_reason,
+  };
+}
+18 -19
@@ -5,12 +5,6 @@ export namespace OpenAI {
 
   export namespace Chat {
 
-    export interface Response {
-      role: 'assistant' | 'system' | 'user';
-      content: string;
-      finish_reason: 'stop' | 'length' | null;
-    }
-
     /**
      * The client will be sent a stream of words. As an extra (a totally optional) 'data channel' we send a
      * string JSON object with the few initial variables. We hope in the future to adopt a better
@@ -23,7 +17,12 @@
   }
 
-  /// OpenAI API types - https://platform.openai.com/docs/api-reference/
+  /**
+   * OpenAI API types - https://platform.openai.com/docs/api-reference/
+   *
+   * Notes:
+   *  - [FN0613]: function calling capability - only 2023-06-13 and later Chat models
+   */
   export namespace Wire {
 
     export namespace ChatCompletion {
@@ -37,11 +36,11 @@
         max_tokens?: number;
         stream: boolean;
         n: number;
-        // only 2023-06-13 and later Chat models
-        // functions?: RequestFunction[],
-        // function_call?: 'auto' | 'none' | {
-        //   name: string;
-        // },
+        // [FN0613]
+        functions?: RequestFunctionDef[],
+        function_call?: 'auto' | 'none' | {
+          name: string;
+        },
       }
 
       export interface RequestMessage {
@@ -50,7 +49,7 @@
        //name?: string; // when role: 'function'
       }
 
-      /*export interface RequestFunction {
+      export interface RequestFunctionDef { // [FN0613]
        name: string;
        description?: string;
        parameters?: {
@@ -64,7 +63,7 @@
          }
          required?: string[];
        };
-      }*/
+      }
 
      export interface Response {
@@ -74,8 +73,8 @@
        model: string; // can differ from the ask, e.g. 'gpt-4-0314'
        choices: {
          index: number;
-          message: ResponseMessage; // | ResponseFunctionCall;
-          finish_reason: 'stop' | 'length' | null; // | 'function_call'
+          message: ResponseMessage | ResponseFunctionCall; // [FN0613]
+          finish_reason: 'stop' | 'length' | null | 'function_call'; // [FN0613]
        }[];
        usage: {
          prompt_tokens: number;
@@ -84,19 +83,19 @@
        };
      }
 
-      interface ResponseMessage {
+      export interface ResponseMessage {
        role: 'assistant';
        content: string;
      }
 
-      /*interface ResponseFunctionCall {
+      export interface ResponseFunctionCall { // [FN0613]
        role: 'assistant';
        content: null;
        function_call: { // if content is null and finish_reason is 'function_call'
          name: string;
          arguments: string; // a JSON object, to deserialize
        };
-      }*/
+      }
 
      export interface ResponseStreamingChunk {
        id: string;
+2 -1
@@ -2,7 +2,7 @@ import { ModelVendor } from '../llm.types';
 
 import { OpenAIIcon } from './OpenAIIcon';
 import { OpenAILLMOptions } from './OpenAILLMOptions';
 import { OpenAISourceSetup } from './OpenAISourceSetup';
-import { callChat } from './openai.client';
+import { callChat, callChatWithFunctions } from './openai.client';
 
 export const ModelVendorOpenAI: ModelVendor = {
@@ -19,6 +19,7 @@ export const ModelVendorOpenAI: ModelVendor = {
 
   // functions
   callChat: callChat,
+  callChatWithFunctions: callChatWithFunctions,
 };
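As a wire-level companion to the diffs above, this is roughly the JSON body that the new openAIChatCompletionRequest helper assembles when function definitions are passed (the model id and function are illustrative, not from this commit):

    const body: OpenAI.Wire.ChatCompletion.Request = {
      model: 'gpt-3.5-turbo-0613', // [FN0613] functions need a 2023-06-13 or later chat model
      messages: [{ role: 'user', content: 'What is the weather like in Paris?' }],
      functions: [{ name: 'get_weather' /* description and JSON-schema parameters omitted */ }],
      function_call: 'auto', // added alongside `functions`
      temperature: 0.5,
      stream: false,
      n: 1,
    };

Per the validation added in the router, a 'function_call' finish_reason is only accepted when functions were actually requested, and the function arguments must parse as JSON before they are returned to the client.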