From 2d4c0e9c64b314ce40cbf86c7c0c8eb079e2013a Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Wed, 28 Jun 2023 03:00:25 -0700 Subject: [PATCH] CallChatWithFunctions - functions support, incl. OpenAI Implementation May be rough on the edges, but should not create issues. The implementation is defensive, excessively validates the return types as the OpenAI API is brittle and can easily misbehave --- pages/api/openai/stream-chat.ts | 4 +- src/modules/llms/llm.client.ts | 50 ++++++-- src/modules/llms/llm.types.ts | 8 +- src/modules/llms/localai/localai.vendor.tsx | 3 +- src/modules/llms/openai/openai.client.ts | 27 ++-- src/modules/llms/openai/openai.router.ts | 131 ++++++++++++++++---- src/modules/llms/openai/openai.types.ts | 37 +++--- src/modules/llms/openai/openai.vendor.ts | 3 +- 8 files changed, 191 insertions(+), 72 deletions(-) diff --git a/pages/api/openai/stream-chat.ts b/pages/api/openai/stream-chat.ts index a8392e57d..99263172f 100644 --- a/pages/api/openai/stream-chat.ts +++ b/pages/api/openai/stream-chat.ts @@ -1,7 +1,7 @@ import { NextRequest, NextResponse } from 'next/server'; import { createParser } from 'eventsource-parser'; -import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAICompletionRequest } from '~/modules/llms/openai/openai.router'; +import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAIChatCompletionRequest } from '~/modules/llms/openai/openai.router'; import { OpenAI } from '~/modules/llms/openai/openai.types'; @@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C // prepare request objects const { headers, url } = openAIAccess(access, '/v1/chat/completions'); - const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true); + const body: OpenAI.Wire.ChatCompletion.Request = openAIChatCompletionRequest(model, history, null, true); // perform the request upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal }); diff --git a/src/modules/llms/llm.client.ts b/src/modules/llms/llm.client.ts index a9496f0f9..14ee086a6 100644 --- a/src/modules/llms/llm.client.ts +++ b/src/modules/llms/llm.client.ts @@ -1,17 +1,49 @@ -import { DLLMId } from '~/modules/llms/llm.types'; -import { findVendorById } from '~/modules/llms/vendor.registry'; -import { useModelsStore } from '~/modules/llms/store-llms'; - +import { DLLM, DLLMId } from './llm.types'; import { OpenAI } from './openai/openai.types'; +import { findVendorById } from './vendor.registry'; +import { useModelsStore } from './store-llms'; -export async function callChatGenerate(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise { +export type ModelVendorCallChatFn = (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) => Promise; +export type ModelVendorCallChatWithFunctionsFn = (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) => Promise; - // get the vendor +export interface VChatMessageIn { + role: 'assistant' | 'system' | 'user'; // | 'function'; + content: string; + //name?: string; // when role: 'function' +} + +export type VChatFunctionIn = OpenAI.Wire.ChatCompletion.RequestFunctionDef; + +export interface VChatMessageOut { + role: 'assistant' | 'system' | 'user'; + content: string; + finish_reason: 'stop' | 'length' | null; +} + +export interface VChatFunctionCallOut { + function_name: string; + function_arguments: object | null; +} + +export type VChatMessageOrFunctionCallOut = VChatMessageOut | VChatFunctionCallOut; + + + +export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise { + const { llm, vendor } = getLLMAndVendorOrThrow(llmId); + return await vendor.callChat(llm, messages, maxTokens); +} + +export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number): Promise { + const { llm, vendor } = getLLMAndVendorOrThrow(llmId); + return await vendor.callChatWithFunctions(llm, messages, functions, maxTokens); +} + + +function getLLMAndVendorOrThrow(llmId: string) { const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId); const vendor = findVendorById(llm?._source.vId); if (!llm || !vendor) throw new Error(`callChat: Vendor not found for LLM ${llmId}`); - - // go for it - return await vendor.callChat(llm, messages, maxTokens); + return { llm, vendor }; } \ No newline at end of file diff --git a/src/modules/llms/llm.types.ts b/src/modules/llms/llm.types.ts index 2a729e485..4c9a8ec1d 100644 --- a/src/modules/llms/llm.types.ts +++ b/src/modules/llms/llm.types.ts @@ -1,12 +1,11 @@ import type React from 'react'; import type { LLMOptionsOpenAI, SourceSetupOpenAI } from './openai/openai.vendor'; -import type { OpenAI } from './openai/openai.types'; +import type { ModelVendorCallChatFn, ModelVendorCallChatWithFunctionsFn } from './llm.client'; import type { SourceSetupLocalAI } from './localai/localai.vendor'; export type DLLMId = string; -// export type DLLMTags = 'stream' | 'chat'; export type DLLMOptions = LLMOptionsOpenAI; //DLLMValuesOpenAI | DLLMVaLocalAIDLLMValues; export type DModelSourceId = string; export type DModelSourceSetup = SourceSetupOpenAI | SourceSetupLocalAI; @@ -60,6 +59,5 @@ export interface ModelVendor { // functions callChat: ModelVendorCallChatFn; -} - -type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise; + callChatWithFunctions: ModelVendorCallChatWithFunctionsFn; +} \ No newline at end of file diff --git a/src/modules/llms/localai/localai.vendor.tsx b/src/modules/llms/localai/localai.vendor.tsx index 58b77e6aa..89af7d057 100644 --- a/src/modules/llms/localai/localai.vendor.tsx +++ b/src/modules/llms/localai/localai.vendor.tsx @@ -17,7 +17,8 @@ export const ModelVendorLocalAI: ModelVendor = { LLMOptionsComponent: () => <>No LocalAI Options, // functions - callChat: () => Promise.reject(new Error('LocalAI is not implemented')), + callChat: () => Promise.reject(new Error('LocalAI chat is not implemented')), + callChatWithFunctions: () => Promise.reject(new Error('LocalAI chatWithFunctions is not implemented')), }; diff --git a/src/modules/llms/openai/openai.client.ts b/src/modules/llms/openai/openai.client.ts index 7f25b179f..10ee3d1de 100644 --- a/src/modules/llms/openai/openai.client.ts +++ b/src/modules/llms/openai/openai.client.ts @@ -1,7 +1,7 @@ import { apiAsync } from '~/modules/trpc/trpc.client'; -import { DLLM } from '../llm.types'; -import { OpenAI } from './openai.types'; +import type { DLLM } from '../llm.types'; +import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../llm.client'; import { normalizeOAISetup, SourceSetupOpenAI } from './openai.vendor'; @@ -10,10 +10,17 @@ export const hasServerKeyOpenAI = !!process.env.HAS_SERVER_KEY_OPENAI; export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.startsWith('sk-') && apiKey.length > 40; +export const callChat = async (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) => + callChatOverloaded(llm, messages, null, maxTokens); + +export const callChatWithFunctions = async (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) => + callChatOverloaded(llm, messages, functions, maxTokens); + + /** - * This function either returns the LLM response, or throws a descriptive error string + * This function either returns the LLM message, or function calls, or throws a descriptive error string */ -export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise { +async function callChatOverloaded(llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, maxTokens?: number): Promise { // access params (source) const partialSetup = llm._source.setup as Partial; const sourceSetupOpenAI = normalizeOAISetup(partialSetup); @@ -21,14 +28,18 @@ export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.R // model params (llm) const openaiLlmRef = llm.options.llmRef!; const modelTemperature = llm.options.llmTemperature || 0.5; - // const maxTokens = llm.options.llmResponseTokens || 1024; // <- note: this would be for chat answers, not programmatic chat calls try { - return await apiAsync.openai.chatGenerate.mutate({ + return await apiAsync.openai.chatGenerateWithFunctions.mutate({ access: sourceSetupOpenAI, - model: { id: openaiLlmRef, temperature: modelTemperature, ...(maxTokens && { maxTokens }) }, + model: { + id: openaiLlmRef, + temperature: modelTemperature, + ...(maxTokens && { maxTokens }), + }, + functions: functions ?? undefined, history: messages, - }); + }) as TOut; // errorMessage = `issue fetching: ${response.status} · ${response.statusText}${errorPayload ? ' · ' + JSON.stringify(errorPayload) : ''}`; } catch (error: any) { const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Fetch Error'; diff --git a/src/modules/llms/openai/openai.router.ts b/src/modules/llms/openai/openai.router.ts index a5cf4f06c..f7c3a4660 100644 --- a/src/modules/llms/openai/openai.router.ts +++ b/src/modules/llms/openai/openai.router.ts @@ -10,6 +10,8 @@ import { OpenAI } from './openai.types'; // console.warn('OPENAI_API_KEY has not been provided in this deployment environment. Will need client-supplied keys, which is not recommended.'); +// Input Schemas + const accessSchema = z.object({ oaiKey: z.string().trim(), oaiOrg: z.string().trim(), @@ -29,7 +31,7 @@ const historySchema = z.array(z.object({ content: z.string(), })); -/*const functionsSchema = z.array(z.object({ +const functionsSchema = z.array(z.object({ name: z.string(), description: z.string().optional(), parameters: z.object({ @@ -41,12 +43,29 @@ const historySchema = z.array(z.object({ })), required: z.array(z.string()).optional(), }).optional(), -}));*/ +})); -export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema }); +export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema, functions: functionsSchema.optional() }); export type ChatGenerateSchema = z.infer; -export const chatModerationSchema = z.object({ access: accessSchema, text: z.string() }); +const chatModerationSchema = z.object({ access: accessSchema, text: z.string() }); + + +// Output Schemas + +const chatGenerateWithFunctionsOutputSchema = z.union([ + z.object({ + role: z.enum(['assistant', 'system', 'user']), + content: z.string(), + finish_reason: z.union([z.enum(['stop', 'length']), z.null()]), + }), + z.object({ + function_name: z.string(), + function_arguments: z.record(z.any()), + }), +]); + + export const openAIRouter = createTRPCRouter({ @@ -54,33 +73,29 @@ export const openAIRouter = createTRPCRouter({ /** * Chat-based message generation */ - chatGenerate: publicProcedure + chatGenerateWithFunctions: publicProcedure .input(chatGenerateSchema) - .mutation(async ({ input }): Promise => { + .output(chatGenerateWithFunctionsOutputSchema) + .mutation(async ({ input }) => { - const { access, model, history } = input; - const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false); - let wireCompletions: OpenAI.Wire.ChatCompletion.Response; + const { access, model, history, functions } = input; + const isFunctionsCall = !!functions && functions.length > 0; - // try { - wireCompletions = await openaiPOST(access, requestBody, '/v1/chat/completions'); - // } catch (error: any) { - // // NOTE: disabled on 2023-06-19: show all errors, 429 is not that common now, and could explain issues - // // don't log 429 errors on the server-side, they are expected - // if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests')) - // console.error('api/openai/chat error:', error); - // throw error; - // } + const wireCompletions = await openaiPOST( + access, + openAIChatCompletionRequest(model, history, isFunctionsCall ? functions : null, false), + '/v1/chat/completions', + ); + // expect a single output if (wireCompletions?.choices?.length !== 1) throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}` }); + const { message, finish_reason } = wireCompletions.choices[0]; - const singleChoice = wireCompletions.choices[0]; - return { - role: singleChoice.message.role, - content: singleChoice.message.content, - finish_reason: singleChoice.finish_reason, - }; + // check for a function output + return finish_reason === 'function_call' + ? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAI.Wire.ChatCompletion.ResponseFunctionCall) + : parseChatGenerateOutput(message as OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason); }), /** @@ -147,6 +162,7 @@ export const openAIRouter = createTRPCRouter({ type AccessSchema = z.infer; type ModelSchema = z.infer; type HistorySchema = z.infer; +type FunctionsSchema = z.infer; async function openaiGET(access: AccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise { const { headers, url } = openAIAccess(access, apiPath); @@ -171,7 +187,11 @@ async function openaiPOST(access: AccessSchema, body: TBody, apiPat : `[Issue] ${response.statusText}`, }); } - return await response.json() as TOut; + try { + return await response.json(); + } catch (error: any) { + throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] ${error?.message || error}` }); + } } export function openAIAccess(access: AccessSchema, apiPath: string): { headers: HeadersInit, url: string } { @@ -203,14 +223,71 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers: }; } -export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request { +export function openAIChatCompletionRequest(model: ModelSchema, history: HistorySchema, functions: FunctionsSchema | null, stream: boolean): OpenAI.Wire.ChatCompletion.Request { return { model: model.id, messages: history, - // ...(functions && { functions: functions, function_call: 'auto', }), + ...(functions && { functions: functions, function_call: 'auto' }), ...(model.temperature && { temperature: model.temperature }), ...(model.maxTokens && { max_tokens: model.maxTokens }), stream, n: 1, }; +} + +function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAI.Wire.ChatCompletion.ResponseFunctionCall) { + // NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment + if (!isFunctionsCall) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `[OpenAI Issue] Received a function call without a function call request`, + }); + + // parse the function call + const fcMessage = message as any as OpenAI.Wire.ChatCompletion.ResponseFunctionCall; + if (fcMessage.content !== null) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `[OpenAI Issue] Expected a function call, got a message`, + }); + + // got a function call, so parse it + const fc = fcMessage.function_call; + if (!fc || !fc.name || !fc.arguments) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `[OpenAI Issue] Issue with the function call, missing name or arguments`, + }); + + // decode the function call + const fcName = fc.name; + let fcArgs: object; + try { + fcArgs = JSON.parse(fc.arguments); + } catch (error: any) { + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `[OpenAI Issue] Issue with the function call, arguments are not valid JSON`, + }); + } + + return { + function_name: fcName, + function_arguments: fcArgs, + }; +} + +function parseChatGenerateOutput(message: OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) { + // validate the message + if (message.content === null) + throw new TRPCError({ + code: 'INTERNAL_SERVER_ERROR', + message: `[OpenAI Issue] Expected a message, got a null message`, + }); + + return { + role: message.role, + content: message.content, + finish_reason: finish_reason, + }; } \ No newline at end of file diff --git a/src/modules/llms/openai/openai.types.ts b/src/modules/llms/openai/openai.types.ts index 153872edc..1692759b6 100644 --- a/src/modules/llms/openai/openai.types.ts +++ b/src/modules/llms/openai/openai.types.ts @@ -5,12 +5,6 @@ export namespace OpenAI { export namespace Chat { - export interface Response { - role: 'assistant' | 'system' | 'user'; - content: string; - finish_reason: 'stop' | 'length' | null; - } - /** * The client will be sent a stream of words. As an extra (an totally optional) 'data channel' we send a * string JSON object with the few initial variables. We hope in the future to adopt a better @@ -23,7 +17,12 @@ export namespace OpenAI { } - /// OpenAI API types - https://platform.openai.com/docs/api-reference/ + /** + * OpenAI API types - https://platform.openai.com/docs/api-reference/ + * + * Notes: + * - [FN0613]: function calling capability - only 2023-06-13 and later Chat models + */ export namespace Wire { export namespace ChatCompletion { @@ -37,11 +36,11 @@ export namespace OpenAI { max_tokens?: number; stream: boolean; n: number; - // only 2023-06-13 and later Chat models - // functions?: RequestFunction[], - // function_call?: 'auto' | 'none' | { - // name: string; - // }, + // [FN0613] + functions?: RequestFunctionDef[], + function_call?: 'auto' | 'none' | { + name: string; + }, } export interface RequestMessage { @@ -50,7 +49,7 @@ export namespace OpenAI { //name?: string; // when role: 'function' } - /*export interface RequestFunction { + export interface RequestFunctionDef { // [FN0613] name: string; description?: string; parameters?: { @@ -64,7 +63,7 @@ export namespace OpenAI { } required?: string[]; }; - }*/ + } export interface Response { @@ -74,8 +73,8 @@ export namespace OpenAI { model: string; // can differ from the ask, e.g. 'gpt-4-0314' choices: { index: number; - message: ResponseMessage; // | ResponseFunctionCall; - finish_reason: 'stop' | 'length' | null; // | 'function_call' + message: ResponseMessage | ResponseFunctionCall; // [FN0613] + finish_reason: 'stop' | 'length' | null | 'function_call'; // [FN0613] }[]; usage: { prompt_tokens: number; @@ -84,19 +83,19 @@ export namespace OpenAI { }; } - interface ResponseMessage { + export interface ResponseMessage { role: 'assistant'; content: string; } - /*interface ResponseFunctionCall { + export interface ResponseFunctionCall { // [FN0613] role: 'assistant'; content: null; function_call: { // if content is null and finish_reason is 'function_call' name: string; arguments: string; // a JSON object, to deserialize }; - }*/ + } export interface ResponseStreamingChunk { id: string; diff --git a/src/modules/llms/openai/openai.vendor.ts b/src/modules/llms/openai/openai.vendor.ts index 55f0131cc..fe4109b5f 100644 --- a/src/modules/llms/openai/openai.vendor.ts +++ b/src/modules/llms/openai/openai.vendor.ts @@ -2,7 +2,7 @@ import { ModelVendor } from '../llm.types'; import { OpenAIIcon } from './OpenAIIcon'; import { OpenAILLMOptions } from './OpenAILLMOptions'; import { OpenAISourceSetup } from './OpenAISourceSetup'; -import { callChat } from './openai.client'; +import { callChat, callChatWithFunctions } from './openai.client'; export const ModelVendorOpenAI: ModelVendor = { @@ -19,6 +19,7 @@ export const ModelVendorOpenAI: ModelVendor = { // functions callChat: callChat, + callChatWithFunctions: callChatWithFunctions, };