Ready up the OpenAI APIs - better typing, responsibilities

Enrico Ros
2023-04-10 21:07:52 -07:00
parent db6ce57dc0
commit cc6eaaed22
6 changed files with 200 additions and 154 deletions
+10 -10
@@ -1,4 +1,4 @@
-import { ApiChatInput, ApiChatMessage } from '../pages/api/openai/stream-chat';
+import { ApiChatInput } from '../pages/api/openai/chat';
 import { DMessage } from '@/lib/store-chats';
@@ -13,17 +13,17 @@ export async function streamAssistantMessageEdits(
   abortSignal: AbortSignal,
 ) {
-  const chatMessages: ApiChatMessage[] = history.map(({ role, text }) => ({
-    role: role,
-    content: text,
-  }));
   const payload: ApiChatInput = {
-    ...(apiKey && { apiKey }),
-    ...(apiHost && { apiHost }),
-    ...(apiOrgId && { apiOrgId }),
+    api: {
+      ...(apiKey && { apiKey }),
+      ...(apiHost && { apiHost }),
+      ...(apiOrgId && { apiOrgId }),
+    },
     model: chatModelId,
-    messages: chatMessages,
+    messages: history.map(({ role, text }) => ({
+      role: role,
+      content: text,
+    })),
     temperature: modelTemperature,
     max_tokens: modelMaxResponseTokens,
   };
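Side note on the pattern above: a conditional spread like ...(apiKey && { apiKey }) adds a key only when its value is truthy, so blank settings never produce empty fields on the wire. A minimal standalone sketch, with made-up values:

// illustrative only: how conditional spreads shape the payload
const apiKey = 'sk-placeholder'; // truthy: key is included
const apiHost = '';              // falsy: key is omitted entirely
const api = {
  ...(apiKey && { apiKey }),
  ...(apiHost && { apiHost }),
};
console.log(api); // { apiKey: 'sk-placeholder' }, with no apiHost property at all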
pages/api/openai/chat.ts (new) +103
@@ -0,0 +1,103 @@
import { NextRequest, NextResponse } from 'next/server';
import { OpenAIAPI } from '@/types/api-openai';

if (!process.env.OPENAI_API_KEY)
  console.warn(
    'OPENAI_API_KEY has not been provided in this deployment environment. ' +
    'Will use the optional keys incoming from the client, which is not recommended.',
  );
// helper functions

export async function extractOpenaiChatInputs(req: NextRequest): Promise<ApiChatInput> {
  const {
    api: userApi = {},
    model,
    messages,
    temperature = 0.5,
    max_tokens = 1024,
  } = (await req.json()) as Partial<ApiChatInput>;
  if (!model || !messages)
    throw new Error('Missing required parameters: api, model, messages');

  const api: OpenAIAPI.Configuration = {
    apiKey: (userApi.apiKey || process.env.OPENAI_API_KEY || '').trim(),
    apiHost: (userApi.apiHost || process.env.OPENAI_API_HOST || 'api.openai.com').trim().replaceAll('https://', ''),
    apiOrgId: (userApi.apiOrgId || process.env.OPENAI_API_ORG_ID || '').trim(),
  };
  if (!api.apiKey)
    throw new Error('Missing OpenAI API Key. Add it on the client side (Settings icon) or server side (your deployment).');

  return { api, model, messages, temperature, max_tokens };
}
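Note the resolution order above: a client-supplied value always wins over the deployment environment, and the host is normalized by stripping the scheme. A quick illustration with placeholder values (reusing the OpenAIAPI import):

// client sends a proxy host but no key; the env key fills the gap
const userApi: OpenAIAPI.Configuration = { apiHost: 'https://proxy.example.com' };
const apiKey = (userApi.apiKey || process.env.OPENAI_API_KEY || '').trim();
const apiHost = (userApi.apiHost || 'api.openai.com').trim().replaceAll('https://', '');
console.log(apiHost); // 'proxy.example.com', ready for `https://${apiHost}${path}`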
const openAIHeaders = (api: OpenAIAPI.Configuration): HeadersInit => ({
  'Content-Type': 'application/json',
  Authorization: `Bearer ${api.apiKey}`,
  ...(api.apiOrgId && { 'OpenAI-Organization': api.apiOrgId }),
});
export const chatCompletionPayload = (input: Omit<ApiChatInput, 'api'>, stream: boolean): OpenAIAPI.Chat.CompletionsRequest => ({
  model: input.model,
  messages: input.messages,
  ...(input.temperature && { temperature: input.temperature }),
  ...(input.max_tokens && { max_tokens: input.max_tokens }),
  stream,
  n: 1,
});
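One caveat with these truthiness guards: a legitimate temperature: 0 is falsy, so it is silently dropped and OpenAI's server-side default applies instead. A quick demonstration of the JavaScript semantics at work:

// temperature: 0 fails the truthiness test, so the guard drops the key
const input = { model: 'gpt-3.5-turbo', messages: [], temperature: 0 };
const payload = {
  model: input.model,
  ...(input.temperature && { temperature: input.temperature }),
};
console.log('temperature' in payload); // false: the 0 never reaches the API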
export async function postOpenAI<TBody extends object>(api: OpenAIAPI.Configuration, apiPath: string, body: TBody, signal?: AbortSignal): Promise<Response> {
  const response = await fetch(`https://${api.apiHost}${apiPath}`, {
    method: 'POST',
    headers: openAIHeaders(api),
    body: JSON.stringify(body),
    signal,
  });
  if (!response.ok) {
    let errorPayload: object | null = null;
    try {
      errorPayload = await response.json();
    } catch (e) {
      // ignore
    }
    throw new Error(`${response.status} · ${response.statusText}${errorPayload ? ' · ' + JSON.stringify(errorPayload) : ''}`);
  }
  return response;
}
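A usage sketch for postOpenAI, with a hypothetical caller cancelling the upstream request through an AbortController (configuration values are placeholders, not part of the commit):

const controller = new AbortController();
const res = await postOpenAI(
  { apiKey: 'sk-placeholder', apiHost: 'api.openai.com' },
  '/v1/chat/completions',
  chatCompletionPayload({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Hello' }] }, false),
  controller.signal,
);
// calling controller.abort() elsewhere rejects the fetch via the shared signal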
// I/O types for this endpoint

export interface ApiChatInput {
  api: OpenAIAPI.Configuration;
  model: string;
  messages: OpenAIAPI.Chat.Message[];
  temperature?: number;
  max_tokens?: number;
}

export interface ApiChatResponse {
  message: OpenAIAPI.Chat.Message;
}
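Concretely, a request body satisfying ApiChatInput and the matching ApiChatResponse look roughly like this (all values illustrative):

const exampleRequest: ApiChatInput = {
  api: { apiKey: 'sk-placeholder' }, // may be {} when the server holds OPENAI_API_KEY
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'Say hi' }],
  temperature: 0.5,
  max_tokens: 1024,
};
const exampleResponse: ApiChatResponse = {
  message: { role: 'assistant', content: 'Hi there!' },
};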
export default async function handler(req: NextRequest) {
  try {
    const { api, ...rest } = await extractOpenaiChatInputs(req);
    const response = await postOpenAI(api, '/v1/chat/completions', chatCompletionPayload(rest, false));
    const completion: OpenAIAPI.Chat.CompletionsResponse = await response.json();
    return new NextResponse(JSON.stringify({
      message: completion.choices[0].message,
    } as ApiChatResponse));
  } catch (error: any) {
    console.error('Fetch request failed:', error);
    return new NextResponse(`[Issue] ${error}`, { status: 400 });
  }
}
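And a minimal client-side call to this route (the /api/openai/chat path follows from the file's location under pages/api/; request and response shapes are the ones defined above):

const body: Partial<ApiChatInput> = {
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'Hello' }],
};
const res = await fetch('/api/openai/chat', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify(body),
});
if (!res.ok) throw new Error(await res.text()); // '[Issue] ...' with status 400
const { message } = (await res.json()) as ApiChatResponse;
console.log(message.content);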
// noinspection JSUnusedGlobalSymbols
export const config = {
  runtime: 'edge',
};
+1 -18
@@ -1,24 +1,7 @@
 import { NextResponse } from 'next/server';
+import { OpenAIAPI } from '@/types/api-openai';
-// definition for OpenAI wire types
-namespace OpenAIAPI.Models {
-  interface Model {
-    id: string;
-    object: 'model';
-    created: number;
-    owned_by: 'openai' | 'openai-dev' | 'openai-internal' | 'system' | string;
-    permission: any[];
-    root: string;
-    parent: null;
-  }
-  export interface ModelList {
-    object: string;
-    data: Model[];
-  }
-}
 async function fetchOpenAIModels(apiKey: string, apiHost: string): Promise<OpenAIAPI.Models.ModelList> {
   const response = await fetch(`https://${apiHost}/v1/models`, {
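The rest of this function (presumably pages/api/openai/models.ts) is unchanged, so the diff cuts it off. For orientation, a hypothetical equivalent of what such a call performs against the now-shared types; listModels is an illustrative name, not part of the commit:

async function listModels(apiKey: string, apiHost = 'api.openai.com'): Promise<OpenAIAPI.Models.ModelList> {
  const response = await fetch(`https://${apiHost}/v1/models`, {
    headers: { Authorization: `Bearer ${apiKey}` },
  });
  if (!response.ok) throw new Error(`${response.status} · ${response.statusText}`);
  return (await response.json()) as OpenAIAPI.Models.ModelList;
}
// each ModelList.data entry carries an id such as 'gpt-3.5-turbo'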
pages/api/openai/stream-chat.ts +8 -125
@@ -1,86 +1,8 @@
 import { NextRequest, NextResponse } from 'next/server';
 import { createParser } from 'eventsource-parser';
-if (!process.env.OPENAI_API_KEY)
-  console.warn(
-    'OPENAI_API_KEY has not been provided in this deployment environment. ' +
-    'Will use the optional keys incoming from the client, which is not recommended.',
-  );
-// definition for OpenAI wire types
-namespace OpenAIAPI.Chat {
-  export interface CompletionMessage {
-    role: 'assistant' | 'system' | 'user';
-    content: string;
-  }
-  export interface CompletionsRequest {
-    model: string;
-    messages: CompletionMessage[];
-    temperature?: number;
-    top_p?: number;
-    frequency_penalty?: number;
-    presence_penalty?: number;
-    max_tokens?: number;
-    stream: boolean;
-    n: number;
-  }
-  export interface CompletionsResponseChunked {
-    id: string; // unique id of this chunk
-    object: 'chat.completion.chunk';
-    created: number; // unix timestamp in seconds
-    model: string; // can differ from the ask, e.g. 'gpt-4-0314'
-    choices: {
-      delta: Partial<CompletionMessage>;
-      index: number; // always 0s for n=1
-      finish_reason: 'stop' | 'length' | null;
-    }[];
-  }
-}
-async function fetchOpenAIChatCompletions(
-  apiCommon: ApiCommonInputs,
-  completionRequest: Omit<OpenAIAPI.Chat.CompletionsRequest, 'stream' | 'n'>,
-  signal: AbortSignal,
-): Promise<Response> {
-  const streamingCompletionRequest: OpenAIAPI.Chat.CompletionsRequest = {
-    ...completionRequest,
-    stream: true,
-    n: 1,
-  };
-  const response = await fetch(`https://${apiCommon.apiHost}/v1/chat/completions`, {
-    headers: {
-      'Content-Type': 'application/json',
-      Authorization: `Bearer ${apiCommon.apiKey}`,
-      ...(apiCommon.apiOrgId && { 'OpenAI-Organization': apiCommon.apiOrgId }),
-    },
-    method: 'POST',
-    body: JSON.stringify(streamingCompletionRequest),
-    signal,
-  });
-  if (!response.ok) {
-    // try to parse the OpenAI error payload (incl. description)
-    let errorPayload: object | null = null;
-    try {
-      errorPayload = await response.json();
-    } catch (e) {
-      // ignore
-    }
-    throw new Error(`${response.status} · ${response.statusText}${errorPayload ? ' · ' + JSON.stringify(errorPayload) : ''}`);
-  }
-  return response;
-}
+import { ApiChatInput, chatCompletionPayload, extractOpenaiChatInputs, postOpenAI } from './chat';
+import { OpenAIAPI } from '@/types/api-openai';
 // error function: send them down the stream as text
@@ -90,8 +12,7 @@ const sendErrorAndClose = (controller: ReadableStreamDefaultController, encoder:
 };
-async function chatStreamRepeater(apiCommon: ApiCommonInputs, payload: Omit<OpenAIAPI.Chat.CompletionsRequest, 'stream' | 'n'>, signal: AbortSignal): Promise<ReadableStream> {
-  const encoder = new TextEncoder();
+async function chatStreamRepeater(input: ApiChatInput, signal: AbortSignal): Promise<ReadableStream> {
   // Handle the abort event when the connection is closed by the client
   signal.addEventListener('abort', () => {
@@ -99,10 +20,11 @@ async function chatStreamRepeater(apiCommon: ApiCommonInputs, payload: Omit<Open
   });
   // begin event streaming from the OpenAI API
+  const encoder = new TextEncoder();
   let upstreamResponse: Response;
   try {
-    upstreamResponse = await fetchOpenAIChatCompletions(apiCommon, payload, signal);
+    upstreamResponse = await postOpenAI(input.api, '/v1/chat/completions', chatCompletionPayload(input, true), signal);
   } catch (error: any) {
     console.log(error);
     const message = '[OpenAI Issue] ' + (error?.message || typeof error === 'string' ? error : JSON.stringify(error)) + (error?.cause ? ' · ' + error.cause : '');
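The streaming loop between these hunks is not touched by this commit, so the diff elides it; it pipes upstream bytes through eventsource-parser. A condensed sketch of that kind of wiring under the CompletionsResponseChunked shape (openAIStreamParser, onText and onDone are illustrative names, not identifiers from the file):

import { createParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser';
import { OpenAIAPI } from '@/types/api-openai';

function openAIStreamParser(onText: (piece: string) => void, onDone: () => void) {
  return createParser((event: ParsedEvent | ReconnectInterval) => {
    if (event.type !== 'event') return;
    if (event.data === '[DONE]') return onDone(); // OpenAI's end-of-stream sentinel
    const chunk: OpenAIAPI.Chat.CompletionsResponseChunked = JSON.parse(event.data);
    const piece = chunk.choices[0]?.delta.content;
    if (piece) onText(piece);
  });
}
// inside the ReadableStream: parser.feed(new TextDecoder().decode(upstreamChunk))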
@@ -172,24 +94,6 @@ async function chatStreamRepeater(apiCommon: ApiCommonInputs, payload: Omit<Open
 }
 // Next.js API route
-interface ApiCommonInputs {
-  apiKey?: string;
-  apiHost?: string;
-  apiOrgId?: string;
-}
-export interface ApiChatInput extends ApiCommonInputs {
-  model: string;
-  messages: ApiChatMessage[];
-  temperature?: number;
-  max_tokens?: number;
-}
-export type ApiChatMessage = OpenAIAPI.Chat.CompletionMessage;
 /**
  * The client will be sent a stream of words. As an extra (and totally optional) 'data channel' we send a
  * stringified JSON object with a few initial variables. We hope in the future to adopt a better
@@ -199,31 +103,11 @@ export interface ApiChatFirstOutput {
   model: string;
 }
 export default async function handler(req: NextRequest): Promise<Response> {
-  const {
-    apiKey: userApiKey, apiHost: userApiHost, apiOrgId: userApiOrgId,
-    model, messages,
-    temperature = 0.5, max_tokens = 2048,
-  } = await req.json() as ApiChatInput;
-  const apiCommon: ApiCommonInputs = {
-    apiKey: (userApiKey || process.env.OPENAI_API_KEY || '').trim(),
-    apiHost: (userApiHost || process.env.OPENAI_API_HOST || 'api.openai.com').trim().replaceAll('https://', ''),
-    apiOrgId: (userApiOrgId || process.env.OPENAI_API_ORG_ID || '').trim(),
-  };
-  if (!apiCommon.apiKey)
-    return new Response('[Issue] missing OpenAI API Key. Add it on the client side (Settings icon) or server side (your deployment).', { status: 400 });
   try {
-    const stream: ReadableStream = await chatStreamRepeater(apiCommon, {
-      model, messages,
-      temperature, max_tokens,
-    }, req.signal);
+    const apiChatInput = await extractOpenaiChatInputs(req);
+    const stream: ReadableStream = await chatStreamRepeater(apiChatInput, req.signal);
     return new NextResponse(stream);
   } catch (error: any) {
     if (error.name === 'AbortError') {
       console.log('Fetch request aborted in handler');
@@ -233,10 +117,9 @@ export default async function handler(req: NextRequest): Promise<Response> {
       return new Response('Connection reset by the client.', { status: 499 }); // Use 499 status code for client closed request
     } else {
       console.error('Fetch request failed:', error);
-      return new Response('[Issue] Fetch request failed.', { status: 500 });
+      return new NextResponse(`[Issue] ${error}`, { status: 400 });
     }
   }
 };
 // noinspection JSUnusedGlobalSymbols
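On the consuming side, a client can peel the optional leading JSON packet off the stream and render the rest as plain text. A minimal, illustrative reader; the brace-matching is naive but works because ApiChatFirstOutput is a flat object delivered before any text:

// client-side; ApiChatFirstOutput mirrors the type exported by stream-chat.ts
const res = await fetch('/api/openai/stream-chat', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ model: 'gpt-3.5-turbo', messages: [{ role: 'user', content: 'Hello' }] }),
});
const reader = res.body!.getReader();
const decoder = new TextDecoder();
let text = '', model = '';
while (true) {
  const { value, done } = await reader.read();
  if (done) break;
  text += decoder.decode(value, { stream: true });
  if (!model && text.startsWith('{')) {
    const end = text.indexOf('}');
    if (end > 0) {
      model = (JSON.parse(text.slice(0, end + 1)) as ApiChatFirstOutput).model;
      text = text.slice(end + 1);
    }
  }
  // render `text` incrementally here
}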
tsconfig.json +2 -1
@@ -18,7 +18,8 @@
"jsxImportSource": "@emotion/react",
"paths": {
"@/components/*": ["components/*"],
"@/lib/*": ["lib/*"]
"@/lib/*": ["lib/*"],
"@/types/*": ["types/*"]
},
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"],
types/api-openai.ts (new) +76
@@ -0,0 +1,76 @@
export namespace OpenAIAPI {

  // not an OpenAI type, but the endpoint configuration to access the API
  export interface Configuration {
    apiKey?: string;
    apiHost?: string;
    apiOrgId?: string;
  }

  // [API] Chat
  export namespace Chat {

    export interface Message {
      role: 'assistant' | 'system' | 'user';
      content: string;
    }

    export interface CompletionsRequest {
      model: string;
      messages: Message[];
      temperature?: number;
      top_p?: number;
      frequency_penalty?: number;
      presence_penalty?: number;
      max_tokens?: number;
      stream: boolean;
      n: number;
    }

    export interface CompletionsResponse {
      id: string;
      object: 'chat.completion';
      created: number; // unix timestamp in seconds
      model: string; // can differ from the ask, e.g. 'gpt-4-0314'
      choices: {
        index: number;
        message: Message;
        finish_reason: 'stop' | 'length' | null;
      }[];
      usage: {
        prompt_tokens: number;
        completion_tokens: number;
        total_tokens: number;
      };
    }

    export interface CompletionsResponseChunked {
      id: string;
      object: 'chat.completion.chunk';
      created: number;
      model: string;
      choices: {
        index: number;
        delta: Partial<Message>;
        finish_reason: 'stop' | 'length' | null;
      }[];
    }
  }

  // [API] Models
  export namespace Models {

    interface Model {
      id: string;
      object: 'model';
      created: number;
      owned_by: 'openai' | 'openai-dev' | 'openai-internal' | 'system' | string;
      permission: any[];
      root: string;
      parent: null;
    }

    export interface ModelList {
      object: string;
      data: Model[];
    }
  }
}
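To close, a small sketch of the shared types in use; note that stream and n are deliberately non-optional on CompletionsRequest, so every caller states its streaming mode explicitly (values illustrative):

import { OpenAIAPI } from '@/types/api-openai';

const request: OpenAIAPI.Chat.CompletionsRequest = {
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'ping' }],
  stream: false, // the type forces a choice: one-shot vs. chunked
  n: 1,
};

function firstChoiceText(res: OpenAIAPI.Chat.CompletionsResponse): string {
  const choice = res.choices[0];
  return choice.finish_reason === 'length'
    ? choice.message.content + ' [truncated by max_tokens]'
    : choice.message.content;
}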