diff --git a/pages/api/openai/stream-chat.ts b/pages/api/openai/stream-chat.ts
index bfb7b658c..5d4d265f0 100644
--- a/pages/api/openai/stream-chat.ts
+++ b/pages/api/openai/stream-chat.ts
@@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
 
   // prepare request objects
   const { headers, url } = openAIAccess(access, '/v1/chat/completions');
-  const body: OpenAI.Wire.Chat.CompletionRequest = openAICompletionRequest(model, history, true);
+  const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true);
 
   // perform the request
   upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal });
@@ -65,7 +65,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
       }
 
       try {
-        const json: OpenAI.Wire.Chat.CompletionResponseChunked = JSON.parse(event.data);
+        const json: OpenAI.Wire.ChatCompletion.ResponseStreamingChunk = JSON.parse(event.data);
 
         // ignore any 'role' delta update
         if (json.choices[0].delta?.role && !json.choices[0].delta?.content)
diff --git a/src/modules/aifn/react/react.ts b/src/modules/aifn/react/react.ts
index 93d798297..9de27a49f 100644
--- a/src/modules/aifn/react/react.ts
+++ b/src/modules/aifn/react/react.ts
@@ -20,7 +20,7 @@ const actionRe = /^Action: (\w+): (.*)$/;
  * - loop() is a function that will update the state (in place)
  */
 interface State {
-  messages: OpenAI.Wire.Chat.Message[];
+  messages: OpenAI.Wire.ChatCompletion.RequestMessage[];
   nextPrompt: string;
   lastObservation: string;
   result: string | undefined;
diff --git a/src/modules/llms/llm.client.ts b/src/modules/llms/llm.client.ts
index 79e5db831..9d6ab60a3 100644
--- a/src/modules/llms/llm.client.ts
+++ b/src/modules/llms/llm.client.ts
@@ -5,7 +5,7 @@ import { useModelsStore } from '~/modules/llms/store-llms';
 
 import { OpenAI } from './openai/openai.types';
 
-export async function callChat(llmId: DLLMId, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number): Promise<string> {
+export async function callChat(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<string> {
 
   // get the vendor
   const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);
diff --git a/src/modules/llms/llm.types.ts b/src/modules/llms/llm.types.ts
index b36afc4b2..2a729e485 100644
--- a/src/modules/llms/llm.types.ts
+++ b/src/modules/llms/llm.types.ts
@@ -62,4 +62,4 @@ export interface ModelVendor {
   callChat: ModelVendorCallChatFn;
 }
 
-type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number) => Promise<string>;
+type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise<string>;
diff --git a/src/modules/llms/openai/openai.client.ts b/src/modules/llms/openai/openai.client.ts
index 766084434..4330d245a 100644
--- a/src/modules/llms/openai/openai.client.ts
+++ b/src/modules/llms/openai/openai.client.ts
@@ -13,7 +13,7 @@ export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.start
 /**
  * This function either returns the LLM response, or throws a descriptive error string
  */
-export async function callChat(llm: DLLM, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number): Promise<string> {
+export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<string> {
   // access params (source)
   const partialSetup = llm._source.setup as Partial<SourceSetupOpenAI>;
   const sourceSetupOpenAI = normalizeOAISetup(partialSetup);
diff --git a/src/modules/llms/openai/openai.router.ts b/src/modules/llms/openai/openai.router.ts
index a22ff2387..49c07c67f 100644
--- a/src/modules/llms/openai/openai.router.ts
+++ b/src/modules/llms/openai/openai.router.ts
@@ -23,10 +23,24 @@ const modelSchema = z.object({
 });
 
 const historySchema = z.array(z.object({
-  role: z.enum(['assistant', 'system', 'user']),
+  role: z.enum(['assistant', 'system', 'user'/*, 'function'*/]),
   content: z.string(),
 }));
 
+/*const functionsSchema = z.array(z.object({
+  name: z.string(),
+  description: z.string().optional(),
+  parameters: z.object({
+    type: z.literal('object'),
+    properties: z.record(z.object({
+      type: z.enum(['string', 'number', 'integer', 'boolean']),
+      description: z.string().optional(),
+      enum: z.array(z.string()).optional(),
+    })),
+    required: z.array(z.string()).optional(),
+  }).optional(),
+}));*/
+
 export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema });
 export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;
 
@@ -41,11 +55,11 @@ export const openAIRouter = createTRPCRouter({
     .mutation(async ({ input }): Promise<OpenAI.API.Chat.Response> => {
       const { access, model, history } = input;
 
-      const requestBody: OpenAI.Wire.Chat.CompletionRequest = openAICompletionRequest(model, history, false);
-      let wireCompletions: OpenAI.Wire.Chat.CompletionResponse;
+      const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false);
+      let wireCompletions: OpenAI.Wire.ChatCompletion.Response;
 
       try {
-        wireCompletions = await openaiPOST<OpenAI.Wire.Chat.CompletionResponse>(access, requestBody, '/v1/chat/completions');
+        wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Response>(access, requestBody, '/v1/chat/completions');
       } catch (error: any) {
         // don't log 429 errors, they are expected
         if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests'))
@@ -147,10 +161,11 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers:
   };
 }
 
-export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.Chat.CompletionRequest {
+export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
   return {
     model: model.id,
     messages: history,
+    // ...(functions && { functions: functions, function_call: 'auto', }),
     ...(model.temperature && { temperature: model.temperature }),
     ...(model.maxTokens && { max_tokens: model.maxTokens }),
     stream,
diff --git a/src/modules/llms/openai/openai.types.ts b/src/modules/llms/openai/openai.types.ts
index 4f15695e6..a0883f293 100644
--- a/src/modules/llms/openai/openai.types.ts
+++ b/src/modules/llms/openai/openai.types.ts
@@ -2,7 +2,7 @@ export namespace OpenAI {
 
   /// Client (Browser) -> Server (Next.js)
   export namespace API {
-    
+
     export namespace Chat {
 
       export interface Response {
@@ -25,15 +25,11 @@
   /// This is the upstream API, for Server (Next.js) -> Upstream Server
   export namespace Wire {
-    export namespace Chat {
-
-      export interface Message {
-        role: 'assistant' | 'system' | 'user';
-        content: string;
-      }
+    export namespace ChatCompletion {
 
-      export interface CompletionRequest {
+      export interface Request {
         model: string;
-        messages: Message[];
+        messages: RequestMessage[];
         temperature?: number;
         top_p?: number;
         frequency_penalty?: number;
@@ -41,17 +37,45 @@
         max_tokens?: number;
         stream: boolean;
         n: number;
+        // only 2023-06-13 and later Chat models
+        // functions?: RequestFunction[],
+        // function_call?: 'auto' | 'none' | {
+        //   name: string;
+        // },
       }
 
-      export interface CompletionResponse {
+      export interface RequestMessage {
+        role: ('assistant' | 'system' | 'user'); // | 'function';
+        content: string;
+        //name?: string; // when role: 'function'
+      }
+
+      /*export interface RequestFunction {
+        name: string;
+        description?: string;
+        parameters?: {
+          type: 'object';
+          properties: {
+            [key: string]: {
+              type: 'string' | 'number' | 'integer' | 'boolean';
+              description?: string;
+              enum?: string[];
+            }
+          }
+          required?: string[];
+        };
+      }*/
+
+
+      export interface Response {
         id: string;
         object: 'chat.completion';
         created: number; // unix timestamp in seconds
         model: string; // can differ from the ask, e.g. 'gpt-4-0314'
         choices: {
           index: number;
-          message: Message;
-          finish_reason: 'stop' | 'length' | null;
+          message: ResponseMessage;
+          finish_reason: ('stop' | 'length' | null); // | 'function_call';
         }[];
         usage: {
           prompt_tokens: number;
@@ -60,19 +84,29 @@
         };
       }
 
-      export interface CompletionResponseChunked {
+      export interface ResponseMessage {
+        role: 'assistant' | 'system' | 'user';
+        content: string; // | null; // null for function_calls
+        // function_call?: { // if content is null and finish_reason is 'function_call'
+        //   name: string;
+        //   arguments: string; // a JSON object, to deserialize
+        // };
+      }
+
+      export interface ResponseStreamingChunk {
         id: string;
         object: 'chat.completion.chunk';
         created: number;
         model: string;
         choices: {
           index: number;
-          delta: Partial<Message>;
+          delta: Partial<ResponseMessage>;
           finish_reason: 'stop' | 'length' | null;
         }[];
       }
 
     }
 
+
     export namespace Models {
 
       export interface ModelDescription {
         id: string;
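
Note: this diff only activates the type renames; the function-calling support (functionsSchema, RequestFunction, functions / function_call) lands commented out. For reference, once those fields are uncommented, a request built against OpenAI.Wire.ChatCompletion.Request would look roughly like the sketch below, following the OpenAI 2023-06-13 function-calling wire format. The get_current_weather function is a hypothetical example, not part of this change.

// Sketch only (not in this diff): a chat completion request with the
// commented-out `functions` / `function_call` fields enabled as written above.
// `get_current_weather` is a hypothetical function, for illustration.
const body = {
  model: 'gpt-3.5-turbo-0613', // function calling requires the 2023-06-13+ chat models
  messages: [
    { role: 'user', content: 'What is the weather like in Boston?' },
  ],
  functions: [{
    name: 'get_current_weather',
    description: 'Get the current weather in a given city',
    parameters: {
      type: 'object',
      properties: {
        city: { type: 'string', description: 'The city name, e.g. Boston' },
      },
      required: ['city'],
    },
  }],
  function_call: 'auto', // let the model decide whether to call a function
  stream: false,
  n: 1,
};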