diff --git a/pages/api/openai/stream-chat.ts b/pages/api/openai/stream-chat.ts
index a8392e57d..99263172f 100644
--- a/pages/api/openai/stream-chat.ts
+++ b/pages/api/openai/stream-chat.ts
@@ -1,7 +1,7 @@
 import { NextRequest, NextResponse } from 'next/server';
 import { createParser } from 'eventsource-parser';
 
-import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAICompletionRequest } from '~/modules/llms/openai/openai.router';
+import { ChatGenerateSchema, chatGenerateSchema, openAIAccess, openAIChatCompletionRequest } from '~/modules/llms/openai/openai.router';
 import { OpenAI } from '~/modules/llms/openai/openai.types';
 
 
@@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
 
     // prepare request objects
     const { headers, url } = openAIAccess(access, '/v1/chat/completions');
-    const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true);
+    const body: OpenAI.Wire.ChatCompletion.Request = openAIChatCompletionRequest(model, history, null, true);
 
     // perform the request
     upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal });
diff --git a/src/modules/llms/llm.client.ts b/src/modules/llms/llm.client.ts
index a9496f0f9..14ee086a6 100644
--- a/src/modules/llms/llm.client.ts
+++ b/src/modules/llms/llm.client.ts
@@ -1,17 +1,49 @@
-import { DLLMId } from '~/modules/llms/llm.types';
-import { findVendorById } from '~/modules/llms/vendor.registry';
-import { useModelsStore } from '~/modules/llms/store-llms';
-
+import { DLLM, DLLMId } from './llm.types';
 import { OpenAI } from './openai/openai.types';
+import { findVendorById } from './vendor.registry';
+import { useModelsStore } from './store-llms';
 
 
-export async function callChatGenerate(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.Chat.Response> {
+export type ModelVendorCallChatFn = (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) => Promise<VChatMessageOut>;
+export type ModelVendorCallChatWithFunctionsFn = (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) => Promise<VChatMessageOrFunctionCallOut>;
 
-  // get the vendor
+export interface VChatMessageIn {
+  role: 'assistant' | 'system' | 'user'; // | 'function';
+  content: string;
+  //name?: string; // when role: 'function'
+}
+
+export type VChatFunctionIn = OpenAI.Wire.ChatCompletion.RequestFunctionDef;
+
+export interface VChatMessageOut {
+  role: 'assistant' | 'system' | 'user';
+  content: string;
+  finish_reason: 'stop' | 'length' | null;
+}
+
+export interface VChatFunctionCallOut {
+  function_name: string;
+  function_arguments: object | null;
+}
+
+export type VChatMessageOrFunctionCallOut = VChatMessageOut | VChatFunctionCallOut;
+
+
+
+export async function callChatGenerate(llmId: DLLMId, messages: VChatMessageIn[], maxTokens?: number): Promise<VChatMessageOut> {
+  const { llm, vendor } = getLLMAndVendorOrThrow(llmId);
+  return await vendor.callChat(llm, messages, maxTokens);
+}
+
+export async function callChatGenerateWithFunctions(llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number): Promise<VChatMessageOrFunctionCallOut> {
+  const { llm, vendor } = getLLMAndVendorOrThrow(llmId);
+  return await vendor.callChatWithFunctions(llm, messages, functions, maxTokens);
+}
+
+
+function getLLMAndVendorOrThrow(llmId: string) {
   const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);
   const vendor = findVendorById(llm?._source.vId);
   if (!llm || !vendor) throw new Error(`callChat: Vendor not found for LLM ${llmId}`);
-
-  // go for it
-  return await vendor.callChat(llm, messages, maxTokens);
+  return { llm, vendor };
 }
\ No newline at end of file
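For orientation, a minimal caller-side sketch of the new client API introduced above; the `get_current_weather` definition, the `llmId` parameter and the logging are hypothetical and not part of this patch:

```ts
import { callChatGenerateWithFunctions } from '~/modules/llms/llm.client';
import type { VChatFunctionIn } from '~/modules/llms/llm.client';

// hypothetical function definition, shaped like OpenAI.Wire.ChatCompletion.RequestFunctionDef
const getWeatherFn: VChatFunctionIn = {
  name: 'get_current_weather',
  description: 'Returns the current weather for a city',
  parameters: {
    type: 'object',
    properties: { city: { type: 'string', description: 'The city name' } },
    required: ['city'],
  },
};

async function demoFunctionCall(llmId: string) {
  const out = await callChatGenerateWithFunctions(
    llmId,
    [{ role: 'user', content: 'What is the weather in Paris?' }],
    [getWeatherFn],
    1024,
  );
  // the result is a union: either a parsed function call or a plain assistant message
  if ('function_name' in out)
    console.log('function call:', out.function_name, out.function_arguments);
  else
    console.log('assistant:', out.content);
}
```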
diff --git a/src/modules/llms/llm.types.ts b/src/modules/llms/llm.types.ts
index 2a729e485..4c9a8ec1d 100644
--- a/src/modules/llms/llm.types.ts
+++ b/src/modules/llms/llm.types.ts
@@ -1,12 +1,11 @@
 import type React from 'react';
 
 import type { LLMOptionsOpenAI, SourceSetupOpenAI } from './openai/openai.vendor';
-import type { OpenAI } from './openai/openai.types';
+import type { ModelVendorCallChatFn, ModelVendorCallChatWithFunctionsFn } from './llm.client';
 import type { SourceSetupLocalAI } from './localai/localai.vendor';
 
 
 export type DLLMId = string;
-// export type DLLMTags = 'stream' | 'chat';
 export type DLLMOptions = LLMOptionsOpenAI; //DLLMValuesOpenAI | DLLMVaLocalAIDLLMValues;
 export type DModelSourceId = string;
 export type DModelSourceSetup = SourceSetupOpenAI | SourceSetupLocalAI;
@@ -60,6 +59,5 @@ export interface ModelVendor {
 
   // functions
   callChat: ModelVendorCallChatFn;
-}
-
-type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise<OpenAI.Chat.Response>;
+  callChatWithFunctions: ModelVendorCallChatWithFunctionsFn;
+}
\ No newline at end of file
diff --git a/src/modules/llms/localai/localai.vendor.tsx b/src/modules/llms/localai/localai.vendor.tsx
index 58b77e6aa..89af7d057 100644
--- a/src/modules/llms/localai/localai.vendor.tsx
+++ b/src/modules/llms/localai/localai.vendor.tsx
@@ -17,7 +17,8 @@ export const ModelVendorLocalAI: ModelVendor = {
   LLMOptionsComponent: () => <>No LocalAI Options</>,
 
   // functions
-  callChat: () => Promise.reject(new Error('LocalAI is not implemented')),
+  callChat: () => Promise.reject(new Error('LocalAI chat is not implemented')),
+  callChatWithFunctions: () => Promise.reject(new Error('LocalAI chatWithFunctions is not implemented')),
 };
 
diff --git a/src/modules/llms/openai/openai.client.ts b/src/modules/llms/openai/openai.client.ts
index 7f25b179f..10ee3d1de 100644
--- a/src/modules/llms/openai/openai.client.ts
+++ b/src/modules/llms/openai/openai.client.ts
@@ -1,7 +1,7 @@
 import { apiAsync } from '~/modules/trpc/trpc.client';
 
-import { DLLM } from '../llm.types';
-import { OpenAI } from './openai.types';
+import type { DLLM } from '../llm.types';
+import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../llm.client';
 
 import { normalizeOAISetup, SourceSetupOpenAI } from './openai.vendor';
 
@@ -10,10 +10,17 @@ export const hasServerKeyOpenAI = !!process.env.HAS_SERVER_KEY_OPENAI;
 
 export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.startsWith('sk-') && apiKey.length > 40;
 
+export const callChat = async (llm: DLLM, messages: VChatMessageIn[], maxTokens?: number) =>
+  callChatOverloaded<VChatMessageOut>(llm, messages, null, maxTokens);
+
+export const callChatWithFunctions = async (llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[], maxTokens?: number) =>
+  callChatOverloaded<VChatMessageOrFunctionCallOut>(llm, messages, functions, maxTokens);
+
+
 /**
- * This function either returns the LLM response, or throws a descriptive error string
+ * This function either returns the LLM message, or function calls, or throws a descriptive error string
  */
-export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.Chat.Response> {
+async function callChatOverloaded<TOut extends VChatMessageOut | VChatMessageOrFunctionCallOut>(llm: DLLM, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, maxTokens?: number): Promise<TOut> {
 
   // access params (source)
   const partialSetup = llm._source.setup as Partial<SourceSetupOpenAI>;
   const sourceSetupOpenAI = normalizeOAISetup(partialSetup);
@@ -21,14 +28,18 @@ export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.R
 
   // model params (llm)
   const openaiLlmRef = llm.options.llmRef!;
   const modelTemperature = llm.options.llmTemperature || 0.5;
-  // const maxTokens = llm.options.llmResponseTokens || 1024; // <- note: this would be for chat answers, not programmatic chat calls
 
   try {
-    return await apiAsync.openai.chatGenerate.mutate({
+    return await apiAsync.openai.chatGenerateWithFunctions.mutate({
       access: sourceSetupOpenAI,
-      model: { id: openaiLlmRef, temperature: modelTemperature, ...(maxTokens && { maxTokens }) },
+      model: {
+        id: openaiLlmRef,
+        temperature: modelTemperature,
+        ...(maxTokens && { maxTokens }),
+      },
+      functions: functions ?? undefined,
       history: messages,
-    });
+    }) as TOut;
     // errorMessage = `issue fetching: ${response.status} · ${response.statusText}${errorPayload ? ' · ' + JSON.stringify(errorPayload) : ''}`;
   } catch (error: any) {
     const errorMessage = error?.message || error?.toString() || 'OpenAI Chat Fetch Error';
diff --git a/src/modules/llms/openai/openai.router.ts b/src/modules/llms/openai/openai.router.ts
index a5cf4f06c..f7c3a4660 100644
--- a/src/modules/llms/openai/openai.router.ts
+++ b/src/modules/llms/openai/openai.router.ts
@@ -10,6 +10,8 @@ import { OpenAI } from './openai.types';
 //   console.warn('OPENAI_API_KEY has not been provided in this deployment environment. Will need client-supplied keys, which is not recommended.');
 
 
+// Input Schemas
+
 const accessSchema = z.object({
   oaiKey: z.string().trim(),
   oaiOrg: z.string().trim(),
@@ -29,7 +31,7 @@ const historySchema = z.array(z.object({
   content: z.string(),
 }));
 
-/*const functionsSchema = z.array(z.object({
+const functionsSchema = z.array(z.object({
   name: z.string(),
   description: z.string().optional(),
   parameters: z.object({
@@ -41,12 +43,29 @@ const historySchema = z.array(z.object({
     })),
     required: z.array(z.string()).optional(),
   }).optional(),
-}));*/
+}));
 
-export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema });
+export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema, functions: functionsSchema.optional() });
 export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;
 
-export const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
+const chatModerationSchema = z.object({ access: accessSchema, text: z.string() });
+
+
+// Output Schemas
+
+const chatGenerateWithFunctionsOutputSchema = z.union([
+  z.object({
+    role: z.enum(['assistant', 'system', 'user']),
+    content: z.string(),
+    finish_reason: z.union([z.enum(['stop', 'length']), z.null()]),
+  }),
+  z.object({
+    function_name: z.string(),
+    function_arguments: z.record(z.any()),
+  }),
+]);
+
+
 
 export const openAIRouter = createTRPCRouter({
@@ -54,33 +73,29 @@ export const openAIRouter = createTRPCRouter({
   /**
    * Chat-based message generation
    */
-  chatGenerate: publicProcedure
-    .input(chatGenerateSchema)
-    .mutation(async ({ input }): Promise<OpenAI.Chat.Response> => {
+  chatGenerateWithFunctions: publicProcedure
+    .input(chatGenerateSchema)
+    .output(chatGenerateWithFunctionsOutputSchema)
+    .mutation(async ({ input }) => {
 
-      const { access, model, history } = input;
-      const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false);
-      let wireCompletions: OpenAI.Wire.ChatCompletion.Response;
+      const { access, model, history, functions } = input;
+      const isFunctionsCall = !!functions && functions.length > 0;
 
-      // try {
-      wireCompletions = await openaiPOST(access, requestBody, '/v1/chat/completions');
-      // } catch (error: any) {
-      //   // NOTE: disabled on 2023-06-19: show all errors, 429 is not that common now, and could explain issues
-      //   // don't log 429 errors on the server-side, they are expected
-      //   if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests'))
-      //     console.error('api/openai/chat error:', error);
-      //   throw error;
-      // }
+      const wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(
+        access,
+        openAIChatCompletionRequest(model, history, isFunctionsCall ? functions : null, false),
+        '/v1/chat/completions',
+      );
 
+      // expect a single output
       if (wireCompletions?.choices?.length !== 1)
         throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}` });
+      const { message, finish_reason } = wireCompletions.choices[0];
 
-      const singleChoice = wireCompletions.choices[0];
-      return {
-        role: singleChoice.message.role,
-        content: singleChoice.message.content,
-        finish_reason: singleChoice.finish_reason,
-      };
+      // check for a function output
+      return finish_reason === 'function_call'
+        ? parseChatGenerateFCOutput(isFunctionsCall, message as OpenAI.Wire.ChatCompletion.ResponseFunctionCall)
+        : parseChatGenerateOutput(message as OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason);
     }),
 
   /**
@@ -147,6 +162,7 @@ export const openAIRouter = createTRPCRouter({
 type AccessSchema = z.infer<typeof accessSchema>;
 type ModelSchema = z.infer<typeof modelSchema>;
 type HistorySchema = z.infer<typeof historySchema>;
+type FunctionsSchema = z.infer<typeof functionsSchema>;
 
 async function openaiGET<TOut>(access: AccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
   const { headers, url } = openAIAccess(access, apiPath);
@@ -171,7 +187,11 @@ async function openaiPOST<TBody, TOut>(access: AccessSchema, body: TBody, apiPat
       : `[Issue] ${response.statusText}`,
     });
   }
-  return await response.json() as TOut;
+  try {
+    return await response.json();
+  } catch (error: any) {
+    throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] ${error?.message || error}` });
+  }
 }
 
 export function openAIAccess(access: AccessSchema, apiPath: string): { headers: HeadersInit, url: string } {
@@ -203,14 +223,71 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers:
   };
 }
 
-export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
+export function openAIChatCompletionRequest(model: ModelSchema, history: HistorySchema, functions: FunctionsSchema | null, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
   return {
     model: model.id,
     messages: history,
-    // ...(functions && { functions: functions, function_call: 'auto', }),
+    ...(functions && { functions: functions, function_call: 'auto' }),
    ...(model.temperature && { temperature: model.temperature }),
     ...(model.maxTokens && { max_tokens: model.maxTokens }),
     stream,
     n: 1,
   };
-}
\ No newline at end of file
+}
+
+function parseChatGenerateFCOutput(isFunctionsCall: boolean, message: OpenAI.Wire.ChatCompletion.ResponseFunctionCall) {
+  // NOTE: Defensive: we run extensive validation because the API is not well tested and documented at the moment
+  if (!isFunctionsCall)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Received a function call without a function call request`,
+    });
+
+  // parse the function call
+  const fcMessage = message as any as OpenAI.Wire.ChatCompletion.ResponseFunctionCall;
+  if (fcMessage.content !== null)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Expected a function call, got a message`,
+    });
+
+  // got a function call, so parse it
+  const fc = fcMessage.function_call;
+  if (!fc || !fc.name || !fc.arguments)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Issue with the function call, missing name or arguments`,
+    });
+
+  // decode the function call
+  const fcName = fc.name;
+  let fcArgs: object;
+  try {
+    fcArgs = JSON.parse(fc.arguments);
+  } catch (error: any) {
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Issue with the function call, arguments are not valid JSON`,
+    });
+  }
+
+  return {
+    function_name: fcName,
+    function_arguments: fcArgs,
+  };
+}
+
+function parseChatGenerateOutput(message: OpenAI.Wire.ChatCompletion.ResponseMessage, finish_reason: 'stop' | 'length' | null) {
+  // validate the message
+  if (message.content === null)
+    throw new TRPCError({
+      code: 'INTERNAL_SERVER_ERROR',
+      message: `[OpenAI Issue] Expected a message, got a null message`,
+    });
+
+  return {
+    role: message.role,
+    content: message.content,
+    finish_reason: finish_reason,
+  };
+}
\ No newline at end of file
diff --git a/src/modules/llms/openai/openai.types.ts b/src/modules/llms/openai/openai.types.ts
index 153872edc..1692759b6 100644
--- a/src/modules/llms/openai/openai.types.ts
+++ b/src/modules/llms/openai/openai.types.ts
@@ -5,12 +5,6 @@ export namespace OpenAI {
 
   export namespace Chat {
 
-    export interface Response {
-      role: 'assistant' | 'system' | 'user';
-      content: string;
-      finish_reason: 'stop' | 'length' | null;
-    }
-
    /**
     * The client will be sent a stream of words. As an extra (and totally optional) 'data channel' we send a
     * string JSON object with the few initial variables. We hope in the future to adopt a better
@@ -23,7 +17,12 @@
   }
 
 
-  /// OpenAI API types - https://platform.openai.com/docs/api-reference/
+  /**
+   * OpenAI API types - https://platform.openai.com/docs/api-reference/
+   *
+   * Notes:
+   *  - [FN0613]: function calling capability - only 2023-06-13 and later Chat models
+   */
  export namespace Wire {
 
    export namespace ChatCompletion {
@@ -37,11 +36,11 @@ export namespace OpenAI {
       max_tokens?: number;
       stream: boolean;
       n: number;
-      // only 2023-06-13 and later Chat models
-      // functions?: RequestFunction[],
-      // function_call?: 'auto' | 'none' | {
-      //   name: string;
-      // },
+      // [FN0613]
+      functions?: RequestFunctionDef[],
+      function_call?: 'auto' | 'none' | {
+        name: string;
+      },
     }
 
     export interface RequestMessage {
@@ -50,7 +49,7 @@
       //name?: string; // when role: 'function'
     }
 
-    /*export interface RequestFunction {
+    export interface RequestFunctionDef { // [FN0613]
       name: string;
       description?: string;
       parameters?: {
@@ -64,7 +63,7 @@
        }
        required?: string[];
      };
-    }*/
+    }
 
 
    export interface Response {
@@ -74,8 +73,8 @@
      model: string; // can differ from the ask, e.g. 'gpt-4-0314'
      choices: {
        index: number;
-        message: ResponseMessage; // | ResponseFunctionCall;
-        finish_reason: 'stop' | 'length' | null; // | 'function_call'
+        message: ResponseMessage | ResponseFunctionCall; // [FN0613]
+        finish_reason: 'stop' | 'length' | null | 'function_call'; // [FN0613]
      }[];
      usage: {
        prompt_tokens: number;
@@ -84,19 +83,19 @@
      };
    }
 
-    interface ResponseMessage {
+    export interface ResponseMessage {
      role: 'assistant';
      content: string;
    }
 
-    /*interface ResponseFunctionCall {
+    export interface ResponseFunctionCall { // [FN0613]
      role: 'assistant';
      content: null;
      function_call: { // if content is null and finish_reason is 'function_call'
        name: string;
        arguments: string; // a JSON object, to deserialize
      };
-    }*/
+    }
 
    export interface ResponseStreamingChunk {
      id: string;
diff --git a/src/modules/llms/openai/openai.vendor.ts b/src/modules/llms/openai/openai.vendor.ts
index 55f0131cc..fe4109b5f 100644
--- a/src/modules/llms/openai/openai.vendor.ts
+++ b/src/modules/llms/openai/openai.vendor.ts
@@ -2,7 +2,7 @@ import { ModelVendor } from '../llm.types';
 import { OpenAIIcon } from './OpenAIIcon';
 import { OpenAILLMOptions } from './OpenAILLMOptions';
 import { OpenAISourceSetup } from './OpenAISourceSetup';
-import { callChat } from './openai.client';
+import { callChat, callChatWithFunctions } from './openai.client';
 
 
 export const ModelVendorOpenAI: ModelVendor = {
@@ -19,6 +19,7 @@ export const ModelVendorOpenAI: ModelVendor = {
 
   // functions
   callChat: callChat,
+  callChatWithFunctions: callChatWithFunctions,
 };
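Purely as an illustration of the wire shapes this patch enables (the model id, the `get_current_weather` function and its arguments are invented examples, not part of the change), a non-streaming function-calling exchange against `/v1/chat/completions` looks roughly like this:

```ts
import { OpenAI } from '~/modules/llms/openai/openai.types';

// a request advertising one callable function ([FN0613])
const request: OpenAI.Wire.ChatCompletion.Request = {
  model: 'gpt-4-0613',
  messages: [{ role: 'user', content: 'What is the weather in Paris?' }],
  functions: [{
    name: 'get_current_weather',
    description: 'Returns the current weather for a given city',
    parameters: {
      type: 'object',
      properties: { city: { type: 'string', description: 'The city name' } },
      required: ['city'],
    },
  }],
  function_call: 'auto',
  stream: false,
  n: 1,
};

// the kind of assistant turn the upstream API can answer with: content is null,
// finish_reason is 'function_call', and arguments arrive as a JSON string that
// parseChatGenerateFCOutput deserializes into { function_name, function_arguments }
const assistantTurn: OpenAI.Wire.ChatCompletion.ResponseFunctionCall = {
  role: 'assistant',
  content: null,
  function_call: { name: 'get_current_weather', arguments: '{ "city": "Paris" }' },
};
```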