Mirror of https://github.com/enricoros/big-AGI.git
OpenAI: Improve namespacing, ahead of functions support
Summary: renames the upstream wire types from OpenAI.Wire.Chat.* to OpenAI.Wire.ChatCompletion.*, and stages (commented out) the schema and type shapes for upcoming function-calling support.
@@ -31,7 +31,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
 
   // prepare request objects
   const { headers, url } = openAIAccess(access, '/v1/chat/completions');
-  const body: OpenAI.Wire.Chat.CompletionRequest = openAICompletionRequest(model, history, true);
+  const body: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, true);
 
   // perform the request
   upstreamResponse = await fetch(url, { headers, method: 'POST', body: JSON.stringify(body), signal });

@@ -65,7 +65,7 @@ async function chatStreamRepeater(access: ChatGenerateSchema['access'], model: C
       }
 
       try {
-        const json: OpenAI.Wire.Chat.CompletionResponseChunked = JSON.parse(event.data);
+        const json: OpenAI.Wire.ChatCompletion.ResponseStreamingChunk = JSON.parse(event.data);
 
         // ignore any 'role' delta update
        if (json.choices[0].delta?.role && !json.choices[0].delta?.content)

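Both chatStreamRepeater changes above concern the streamed chunk type. For illustration, a minimal sketch (assuming the OpenAI types import used in this file) of extracting text from one parsed chunk under the new name; the helper and its name are illustrative, not part of the commit:

// hypothetical helper, not in the commit
function textFromChunk(eventData: string): string {
  const json: OpenAI.Wire.ChatCompletion.ResponseStreamingChunk = JSON.parse(eventData);
  const delta = json.choices[0]?.delta;
  // a role-only delta (typically the first chunk) carries no text yet
  if (delta?.role && !delta.content) return '';
  return delta?.content ?? '';
}
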
@@ -20,7 +20,7 @@ const actionRe = /^Action: (\w+): (.*)$/;
  * - loop() is a function that will update the state (in place)
  */
 interface State {
-  messages: OpenAI.Wire.Chat.Message[];
+  messages: OpenAI.Wire.ChatCompletion.RequestMessage[];
   nextPrompt: string;
   lastObservation: string;
   result: string | undefined;

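A minimal sketch of a State literal under the renamed message type; the values are illustrative, only the field names and types come from the diff:

// hypothetical initial state for the ReAct loop
const initialState: State = {
  messages: [{ role: 'system', content: 'You are a helpful agent.' }],
  nextPrompt: '',
  lastObservation: '',
  result: undefined,
};
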
@@ -5,7 +5,7 @@ import { useModelsStore } from '~/modules/llms/store-llms';
 import { OpenAI } from './openai/openai.types';
 
 
-export async function callChat(llmId: DLLMId, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
+export async function callChat(llmId: DLLMId, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
 
   // get the vendor
   const llm = useModelsStore.getState().llms.find(llm => llm.id === llmId);

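A hedged usage sketch of the renamed client-side signature (the id and messages are placeholders; assumes an async context with DLLMId in scope):

// someLlmId: DLLMId — a hypothetical id taken from the models store
const response: OpenAI.API.Chat.Response = await callChat(someLlmId, [
  { role: 'system', content: 'You are concise.' },
  { role: 'user', content: 'Name three prime numbers.' },
], 64 /* maxTokens */);
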
@@ -62,4 +62,4 @@ export interface ModelVendor {
   callChat: ModelVendorCallChatFn;
 }
 
-type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number) => Promise<OpenAI.API.Chat.Response>;
+type ModelVendorCallChatFn = (llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number) => Promise<OpenAI.API.Chat.Response>;

@@ -13,7 +13,7 @@ export const isValidOpenAIApiKey = (apiKey?: string) => !!apiKey && apiKey.start
 /**
  * This function either returns the LLM response, or throws a descriptive error string
  */
-export async function callChat(llm: DLLM, messages: OpenAI.Wire.Chat.Message[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
+export async function callChat(llm: DLLM, messages: OpenAI.Wire.ChatCompletion.RequestMessage[], maxTokens?: number): Promise<OpenAI.API.Chat.Response> {
   // access params (source)
   const partialSetup = llm._source.setup as Partial<SourceSetupOpenAI>;
   const sourceSetupOpenAI = normalizeOAISetup(partialSetup);

@@ -23,10 +23,24 @@ const modelSchema = z.object({
 });
 
 const historySchema = z.array(z.object({
-  role: z.enum(['assistant', 'system', 'user']),
+  role: z.enum(['assistant', 'system', 'user'/*, 'function'*/]),
   content: z.string(),
 }));
 
+/*const functionsSchema = z.array(z.object({
+  name: z.string(),
+  description: z.string().optional(),
+  parameters: z.object({
+    type: z.literal('object'),
+    properties: z.record(z.object({
+      type: z.enum(['string', 'number', 'integer', 'boolean']),
+      description: z.string().optional(),
+      enum: z.array(z.string()).optional(),
+    })),
+    required: z.array(z.string()).optional(),
+  }).optional(),
+}));*/
+
 export const chatGenerateSchema = z.object({ access: accessSchema, model: modelSchema, history: historySchema });
 export type ChatGenerateSchema = z.infer<typeof chatGenerateSchema>;

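For illustration, a payload that parses against the historySchema above (within the same module, since the schema is not exported); once function support lands, the commented 'function' role would join the enum:

// values are illustrative
const history = historySchema.parse([
  { role: 'system', content: 'You are a helpful assistant.' },
  { role: 'user', content: 'Hello!' },
]);
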
@@ -41,11 +55,11 @@ export const openAIRouter = createTRPCRouter({
     .mutation(async ({ input }): Promise<OpenAI.API.Chat.Response> => {
 
       const { access, model, history } = input;
-      const requestBody: OpenAI.Wire.Chat.CompletionRequest = openAICompletionRequest(model, history, false);
-      let wireCompletions: OpenAI.Wire.Chat.CompletionResponse;
+      const requestBody: OpenAI.Wire.ChatCompletion.Request = openAICompletionRequest(model, history, false);
+      let wireCompletions: OpenAI.Wire.ChatCompletion.Response;
 
       try {
-        wireCompletions = await openaiPOST<OpenAI.Wire.Chat.CompletionRequest, OpenAI.Wire.Chat.CompletionResponse>(access, requestBody, '/v1/chat/completions');
+        wireCompletions = await openaiPOST<OpenAI.Wire.ChatCompletion.Request, OpenAI.Wire.ChatCompletion.Response>(access, requestBody, '/v1/chat/completions');
       } catch (error: any) {
         // don't log 429 errors, they are expected
         if (!error || !(typeof error.startsWith === 'function') || !error.startsWith('Error: 429 · Too Many Requests'))

@@ -147,10 +161,11 @@ export function openAIAccess(access: AccessSchema, apiPath: string): { headers:
   };
 }
 
-export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.Chat.CompletionRequest {
+export function openAICompletionRequest(model: ModelSchema, history: HistorySchema, stream: boolean): OpenAI.Wire.ChatCompletion.Request {
   return {
     model: model.id,
     messages: history,
+    // ...(functions && { functions: functions, function_call: 'auto', }),
     ...(model.temperature && { temperature: model.temperature }),
     ...(model.maxTokens && { max_tokens: model.maxTokens }),
     stream,

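For illustration, the object the builder above would produce for a hypothetical model (note the truthy-spreads: a temperature of 0 would be silently omitted). The n field is declared required by the wire type below, so it is assumed to be set in the elided remainder of this function:

// sketch of the produced request body; values are placeholders
const body: OpenAI.Wire.ChatCompletion.Request = {
  model: 'gpt-4',
  messages: [{ role: 'user', content: 'Hi' }],
  temperature: 0.5, // included: model.temperature was truthy
  max_tokens: 1024, // included: model.maxTokens was truthy
  stream: false,
  n: 1, // assumed default; required by the Request interface
};
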
@@ -2,7 +2,7 @@ export namespace OpenAI {
 
   /// Client (Browser) -> Server (Next.js)
   export namespace API {
 
     export namespace Chat {
 
       export interface Response {

@@ -25,15 +25,11 @@ export namespace OpenAI {
 
   /// This is the upstream API, for Server (Next.js) -> Upstream Server
   export namespace Wire {
-    export namespace Chat {
-      export interface Message {
-        role: 'assistant' | 'system' | 'user';
-        content: string;
-      }
+    export namespace ChatCompletion {
 
-      export interface CompletionRequest {
+      export interface Request {
         model: string;
-        messages: Message[];
+        messages: RequestMessage[];
         temperature?: number;
         top_p?: number;
         frequency_penalty?: number;

@@ -41,17 +37,45 @@ export namespace OpenAI {
         max_tokens?: number;
         stream: boolean;
         n: number;
+        // only 2023-06-13 and later Chat models
+        // functions?: RequestFunction[],
+        // function_call?: 'auto' | 'none' | {
+        //   name: string;
+        // },
       }
 
-      export interface CompletionResponse {
+      export interface RequestMessage {
+        role: ('assistant' | 'system' | 'user'); // | 'function';
+        content: string;
+        //name?: string; // when role: 'function'
+      }
+
+      /*export interface RequestFunction {
+        name: string;
+        description?: string;
+        parameters?: {
+          type: 'object';
+          properties: {
+            [key: string]: {
+              type: 'string' | 'number' | 'integer' | 'boolean';
+              description?: string;
+              enum?: string[];
+            }
+          }
+          required?: string[];
+        };
+      }*/
+
+
+      export interface Response {
         id: string;
         object: 'chat.completion';
         created: number; // unix timestamp in seconds
         model: string; // can differ from the ask, e.g. 'gpt-4-0314'
         choices: {
           index: number;
-          message: Message;
-          finish_reason: 'stop' | 'length' | null;
+          message: ResponseMessage;
+          finish_reason: ('stop' | 'length' | null); // | 'function_call';
         }[];
         usage: {
           prompt_tokens: number;

@@ -60,19 +84,29 @@ export namespace OpenAI {
         };
       }
 
-      export interface CompletionResponseChunked {
+      export interface ResponseMessage {
+        role: 'assistant' | 'system' | 'user';
+        content: string; // | null; // null for function_calls
+        // function_call?: { // if content is null and finish_reason is 'function_call'
+        //   name: string;
+        //   arguments: string; // a JSON object, to deserialize
+        // };
+      }
+
+      export interface ResponseStreamingChunk {
         id: string;
         object: 'chat.completion.chunk';
         created: number;
         model: string;
         choices: {
           index: number;
-          delta: Partial<Message>;
+          delta: Partial<ResponseMessage>;
           finish_reason: 'stop' | 'length' | null;
         }[];
       }
     }
 
 
     export namespace Models {
       export interface ModelDescription {
         id: string;

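Once the commented-out shapes above are enabled, a function-calling request would look roughly like this hedged sketch; the model id, function name, and arguments are illustrative, only the shapes come from the commented types:

// a would-be request declaring a single function
const fnRequest = {
  model: 'gpt-3.5-turbo-0613', // a 2023-06-13-or-later chat model
  messages: [{ role: 'user' as const, content: 'Weather in Paris?' }],
  stream: false,
  n: 1,
  functions: [{
    name: 'get_weather',
    description: 'Get the current weather for a city',
    parameters: {
      type: 'object' as const,
      properties: { city: { type: 'string' as const, description: 'City name' } },
      required: ['city'],
    },
  }],
  function_call: 'auto' as const,
};
// per the ResponseMessage comments above, the reply could then carry
// finish_reason 'function_call' and a function_call whose arguments field
// is a JSON string to deserialize, e.g. '{"city":"Paris"}'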