OpenAI: Improve namespacing, ahead of functions support

Enrico Ros
2023-06-13 21:44:01 -07:00
parent b00bc2e1e2
commit 4b170a09dc
7 changed files with 73 additions and 24 deletions
+2 -2
@@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
// prepare request objects
const { headers, url } = openAIAccess(access, '/v1/chat/completions');
- const body: OpenAI.Wire.Chat.CompletionRequest = openAICompletionRequest(model, history, true);
+ const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true);
// perform the request
upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal });
@@ -65,7 +65,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
}
try {
- const json: OpenAI.Wire.Chat.CompletionResponseChunked = JSON.parse(event.data);
+ const json: OpenAI.Wire.ChatCompletion.ResponseStreamingChunk = JSON.parse(event.data);
// ignore any 'role' delta update
if (json.choices[0].delta?.role && !json.choices[0].delta?.content)
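For orientation, here is a minimal sketch (not part of this commit) of how a consumer could fold the renamed ResponseStreamingChunk deltas into the accumulated assistant text; the parseAndAppend helper and its names are hypothetical, only the chunk shape and the role-only-delta skip mirror the handler above.

    // Hypothetical helper, assuming the OpenAI namespace import used elsewhere in this commit.
    function parseAndAppend(eventData: string, accumulated: string): string {
      const json: OpenAI.Wire.ChatCompletion.ResponseStreamingChunk = JSON.parse(eventData);
      const delta = json.choices[0]?.delta;
      // a bare 'role' delta carries no text and is skipped, as in chatStreamRepeater above
      if (delta?.role && !delta?.content) return accumulated;
      return accumulated + (delta?.content || '');
    }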
+1 -1
@@ -20,7 +20,7 @@ const actionRe = /^Action: (\w+): (.*)$/;
* - loop() is a function that will update the state (in place)
*/
interface State {
- messages: OpenAI.Wire.Chat.Message[];
+ messages: OpenAI.Wire.ChatCompletion.RequestMessage[];
nextPrompt: string;
lastObservation: string;
result: string | undefined;
+1 -1
@@ -5,7 +5,7 @@ import { useModelsStore } from '~/modules/llms/store-llms';
import { OpenAI } from './openai/openai.types';
- export async function callChat(llmId: DLLMId, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
+ export async function callChat(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
// get the vendor
const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);
+1 -1
@@ -62,4 +62,4 @@ export interface ModelVendor {
callChat: ModelVendorCallChatFn;
}
- type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number) => Promise<OpenAI.API.Chat.Response>;
+ type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise<OpenAI.API.Chat.Response>;
+1 -1
@@ -13,7 +13,7 @@ export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.start
/**
* This function either returns the LLM response, or throws a descriptive error string
*/
- export async function callChat(llm: DLLM, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
+ export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
// access params (source)
const partialSetup = llm._source.setup as Partial<SourceSetupOpenAI>;
const sourceSetupOpenAI = normalizeOAISetup(partialSetup);
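For reference, a hypothetical call site for the updated signature; the llm value is assumed to be an already-resolved OpenAI DLLM, and the message contents are placeholders rather than anything from this commit.

    // Hypothetical usage of the renamed RequestMessage type with callChat:
    const messages: OpenAI.Wire.ChatCompletion.RequestMessage[] = [
      { role: 'system', content: 'You are a helpful assistant.' },
      { role: 'user', content: 'Say hello.' },
    ];
    const chatResponse: OpenAI.API.Chat.Response = await callChat(llm, messages, 256);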
+20 -5
@@ -23,10 +23,24 @@ const modelSchema = z.object({
});
const historySchema = z.array(z.object({
- role: z.enum(['assistant', 'system', 'user']),
+ role: z.enum(['assistant', 'system', 'user'/*, 'function'*/]),
content: z.string(),
}));
+ /*const functionsSchema = z.array(z.object({
+ name: z.string(),
+ description: z.string().optional(),
+ parameters: z.object({
+ type: z.literal('object'),
+ properties: z.record(z.object({
+ type: z.enum(['string', 'number', 'integer', 'boolean']),
+ description: z.string().optional(),
+ enum: z.array(z.string()).optional(),
+ })),
+ required: z.array(z.string()).optional(),
+ }).optional(),
+ }));*/
export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema });
export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;
@@ -41,11 +55,11 @@ export const openAIRouter = createTRPCRouter({
.mutation(async ({ input }): Promise<OpenAI.API.Chat.Response> => {
const { access, model, history } = input;
- const requestBody: OpenAI.Wire.Chat.CompletionRequest = openAICompletionRequest(model, history, false);
- let wireCompletions: OpenAI.Wire.Chat.CompletionResponse;
+ const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false);
+ let wireCompletions: OpenAI.Wire.ChatCompletion.Response;
try {
- wireCompletions = await openaiPOST<OpenAI.Wire.Chat.CompletionRequest, OpenAI.Wire.Chat.CompletionResponse>(access, requestBody, '/v1/chat/completions');
+ wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(access, requestBody, '/v1/chat/completions');
} catch (error: any) {
// don't log 429 errors, they are expected
if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests'))
@@ -147,10 +161,11 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers:
};
}
- export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.Chat.CompletionRequest {
+ export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
return {
model: model.id,
messages: history,
+ // ...(functions && { functions: functions, function_call: 'auto', }),
...(model.temperature && { temperature: model.temperature }),
...(model.maxTokens && { max_tokens: model.maxTokens }),
stream,
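To show where the commented-out pieces are pointing, here is a sketch of the request body openAICompletionRequest could produce once functions are wired in; the get_weather function, the model id, and the literal values are illustrative assumptions, not part of this commit.

    // Illustrative only: 'functions' is still commented out in the wire types,
    // so this draft is a plain object rather than a ChatCompletion.Request.
    const draftFunctionCallRequest = {
      model: 'gpt-3.5-turbo-0613', // assumed: one of the 2023-06-13+ chat models
      messages: [{ role: 'user', content: 'What is the weather in Paris?' }],
      functions: [{
        name: 'get_weather', // hypothetical function
        description: 'Get the current weather for a city',
        parameters: {
          type: 'object',
          properties: { city: { type: 'string', description: 'City name' } },
          required: ['city'],
        },
      }],
      function_call: 'auto',
      stream: false,
      n: 1,
    };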
+47 -13
@@ -2,7 +2,7 @@ export namespace OpenAI {
/// Client (Browser) -> Server (Next.js)
export namespace API {
export namespace Chat {
export interface Response {
@@ -25,15 +25,11 @@ export namespace OpenAI {
/// This is the upstream API, for Server (Next.js) -> Upstream Server
export namespace Wire {
- export namespace Chat {
- export interface Message {
- role: 'assistant' | 'system' | 'user';
- content: string;
- }
+ export namespace ChatCompletion {
- export interface CompletionRequest {
+ export interface Request {
model: string;
- messages: Message[];
+ messages: RequestMessage[];
temperature?: number;
top_p?: number;
frequency_penalty?: number;
@@ -41,17 +37,45 @@ export namespace OpenAI {
max_tokens?: number;
stream: boolean;
n: number;
+ // only 2023-06-13 and later Chat models
+ // functions?: RequestFunction[],
+ // function_call?: 'auto' | 'none' | {
+ // name: string;
+ // },
}
- export interface CompletionResponse {
+ export interface RequestMessage {
+ role: ('assistant' | 'system' | 'user'); // | 'function';
+ content: string;
+ //name?: string; // when role: 'function'
+ }
+ /*export interface RequestFunction {
+ name: string;
+ description?: string;
+ parameters?: {
+ type: 'object';
+ properties: {
+ [key: string]: {
+ type: 'string' | 'number' | 'integer' | 'boolean';
+ description?: string;
+ enum?: string[];
+ }
+ }
+ required?: string[];
+ };
+ }*/
+ export interface Response {
id: string;
object: 'chat.completion';
created: number; // unix timestamp in seconds
model: string; // can differ from the ask, e.g. 'gpt-4-0314'
choices: {
index: number;
- message: Message;
- finish_reason: 'stop' | 'length' | null;
+ message: ResponseMessage;
+ finish_reason: ('stop' | 'length' | null); // | 'function_call';
}[];
usage: {
prompt_tokens: number;
@@ -60,19 +84,29 @@ export namespace OpenAI {
};
}
- export interface CompletionResponseChunked {
+ export interface ResponseMessage {
+ role: 'assistant' | 'system' | 'user';
+ content: string; // | null; // null for function_calls
+ // function_call?: { // if content is null and finish_reason is 'function_call'
+ // name: string;
+ // arguments: string; // a JSON object, to deserialize
+ // };
+ }
+ export interface ResponseStreamingChunk {
id: string;
object: 'chat.completion.chunk';
created: number;
model: string;
choices: {
index: number;
- delta: Partial<Message>;
+ delta: Partial<ResponseMessage>;
finish_reason: 'stop' | 'length' | null;
}[];
}
}
export namespace Models {
export interface ModelDescription {
id: string;
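Taken together, the renames group the request and response shapes under a single ChatCompletion namespace. A minimal non-streaming sketch under the new names follows; the endpoint URL and headers stand in for the openAIAccess/openaiPOST plumbing and are assumptions, not taken from this commit.

    // Hypothetical non-streaming round trip using the renamed wire types.
    async function completeOnce(apiKey: string): Promise<string> {
      const request: OpenAI.Wire.ChatCompletion.Request = {
        model: 'gpt-3.5-turbo',
        messages: [{ role: 'user', content: 'Hello' }],
        stream: false,
        n: 1,
      };
      const res = await fetch('https://api.openai.com/v1/chat/completions', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${apiKey}` },
        body: JSON.stringify(request),
      });
      const completion: OpenAI.Wire.ChatCompletion.Response = await res.json();
      return completion.choices[0].message.content;
    }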