OpenAI Wire: port image generation and moderations
@@ -7,6 +7,9 @@ import { z } from 'zod';
// - 2024-07-09: ignoring the advanced model configuration
//

//
// Chat > Create chat completion
//

/// Content parts - Input

@@ -262,7 +265,6 @@ const openaiWire_ChatCompletionChunkChoice_Schema = z.object({
  // logprobs: ... // Log probability information for the choice.
});

export type OpenaiWire_ChatCompletionChunkResponse = z.infer<typeof openaiWire_ChatCompletionChunkResponse_Schema>;
export const openaiWire_ChatCompletionChunkResponse_Schema = z.object({
  object: z.enum(['chat.completion.chunk', '' /* [Azure] bad response */]),
  id: z.string(),
@@ -284,3 +286,114 @@ export const openaiWire_ChatCompletionChunkResponse_Schema = z.object({
  error: openaiWire_UndocumentedError_Schema.optional(),
  warning: openaiWire_UndocumentedWarning_Schema.optional(),
});
+
+
+//
+// Images > Create Image
+// https://platform.openai.com/docs/api-reference/images/create
+//
+
+export type OpenaiWire_CreateImageRequest = z.infer<typeof openaiWire_CreateImageRequest_Schema>;
+const openaiWire_CreateImageRequest_Schema = z.object({
+  // The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3
+  prompt: z.string().max(4000),
+
+  // The model to use for image generation
+  model: z.enum(['dall-e-2', 'dall-e-3']).optional().default('dall-e-2'),
+
+  // The number of images to generate. Must be between 1 and 10. For dall-e-3, only n=1 is supported.
+  n: z.number().min(1).max(10).nullable().optional(),
+
+  // 'hd' creates images with finer details and greater consistency across the image. This param is only supported for dall-e-3
+  quality: z.enum(['standard', 'hd']).optional(),
+
+  // The format in which the generated images are returned
+  response_format: z.enum(['url', 'b64_json']).optional(), //.default('url'),
+
+  // 'dall-e-2': must be one of 256x256, 512x512, or 1024x1024
+  // 'dall-e-3': must be one of 1024x1024, 1792x1024, or 1024x1792
+  size: z.enum(['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']).optional().default('1024x1024'),
+
+  // only used by 'dall-e-3': 'vivid' (hyper-real and dramatic images) or 'natural'
+  style: z.enum(['vivid', 'natural']).optional().default('vivid'),
+
+  // A unique identifier representing your end-user
+  user: z.string().optional(),
+});
+
+
+export type OpenaiWire_CreateImageResponse = z.infer<typeof openaiWire_CreateImageResponse_Schema>;
+export const openaiWire_CreateImageResponse_Schema = z.object({
+  created: z.number(),
+  data: z.array(z.object({
+    url: z.string().url().optional(),
+    b64_json: z.string().optional(),
+    revised_prompt: z.string().optional(),
+  })),
+});
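As a usage sketch (hypothetical helper, not part of this commit; the router below uses its own openaiPOSTOrThrow wrapper): validating the images/generations payload at the wire boundary makes malformed upstream responses fail loudly.

  // hypothetical sketch -- POST to /v1/images/generations and validate the reply
  async function createImages(apiKey: string, request: OpenaiWire_CreateImageRequest) {
    const response = await fetch('https://api.openai.com/v1/images/generations', {
      method: 'POST',
      headers: { 'Authorization': `Bearer ${apiKey}`, 'Content-Type': 'application/json' },
      body: JSON.stringify(request),
    });
    // .parse throws a ZodError if the payload does not match the schema
    return openaiWire_CreateImageResponse_Schema.parse(await response.json()).data;
  }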
+
+
+//
+// Models > List Models
+//
+
+// Model object schema
+export type OpenaiWire_Model = z.infer<typeof openaiWire_Model_Schema>;
+const openaiWire_Model_Schema = z.object({
+  id: z.string(),
+  object: z.literal('model'),
+  created: z.number().optional(),
+  // [dialect:OpenAI] 'openai' | 'openai-dev' | 'openai-internal' | 'system'
+  // [dialect:Oobabooga] 'user'
+  owned_by: z.string().optional(),
+
+  // **Extensions**
+  // [Openrouter] non-standard - commented because dynamically added by the Openrouter vendor code
+  // context_length: z.number().optional(),
+});
+
+// List models response schema
+export type OpenaiWire_ModelList = z.infer<typeof openaiWire_ModelList_Schema>;
+const openaiWire_ModelList_Schema = z.object({
+  object: z.literal('list'),
+  data: z.array(openaiWire_Model_Schema),
+});
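Because created and owned_by are optional, model lists from non-OpenAI dialects (Oobabooga, local servers) still validate. A quick sketch of that property, assuming the schema is exercised directly in this file:

  // hypothetical check -- a minimal Oobabooga-style entry satisfies the schema
  const ok = openaiWire_ModelList_Schema.safeParse({
    object: 'list',
    data: [{ id: 'llama-3-8b', object: 'model', owned_by: 'user' }],
  }).success; // true: 'created' is optional and 'owned_by' accepts any string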
+
+
+//
+// Moderations > Create Moderation
+//
+
+export const openaiWire_ModerationCategory_Schema = z.enum([
+  'sexual',
+  'hate',
+  'harassment',
+  'self-harm',
+  'sexual/minors',
+  'hate/threatening',
+  'violence/graphic',
+  'self-harm/intent',
+  'self-harm/instructions',
+  'harassment/threatening',
+  'violence',
+]);
+
+export type OpenaiWire_ModerationRequest = z.infer<typeof openaiWire_ModerationRequest_Schema>;
+const openaiWire_ModerationRequest_Schema = z.object({
+  // input: z.union([z.string(), z.array(z.string())]),
+  input: z.string(),
+  model: z.enum(['text-moderation-stable', 'text-moderation-latest']).optional(),
+});
+
+const openaiWire_ModerationResult_Schema = z.object({
+  flagged: z.boolean(),
+  categories: z.record(openaiWire_ModerationCategory_Schema, z.boolean()),
+  category_scores: z.record(openaiWire_ModerationCategory_Schema, z.number()),
+});
+
+export type OpenaiWire_ModerationResponse = z.infer<typeof openaiWire_ModerationResponse_Schema>;
+const openaiWire_ModerationResponse_Schema = z.object({
+  id: z.string(),
+  model: z.string(),
+  results: z.array(openaiWire_ModerationResult_Schema),
+});
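A minimal sketch of consuming these schemas (hypothetical helper, not in the commit): validate the /v1/moderations payload once, then work with typed results.

  // hypothetical sketch -- validate a raw moderation payload before use
  function parseModerationPayload(rawJson: unknown): OpenaiWire_ModerationResponse {
    // throws a ZodError on shape mismatch, so callers never see malformed data
    return openaiWire_ModerationResponse_Schema.parse(rawJson);
  }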
@@ -18,8 +18,8 @@ import { wireOllamaChunkedOutputSchema } from '../../aix/server/dispatch/ollama/
import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router';

// OpenAI server imports
-import type { OpenAIWire } from './openai/openai.wiretypes';
import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router';
+import { openaiWire_ChatCompletionChunkResponse_Schema } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';


import { llmsStreamingContextSchema } from './llm.server.types';

@@ -455,7 +455,7 @@ function createStreamParserOpenAI(): AIStreamParser {

  return (data: string) => {

-    const json: OpenAIWire.ChatCompletion.ResponseStreamingChunk = JSON.parse(data);
+    const json = openaiWire_ChatCompletionChunkResponse_Schema.parse(JSON.parse(data));

    // [OpenAI] an upstream error will be handled gracefully and transmitted as text (throw to transmit as 'error')
    if (json.error)
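The changed line above is also a behavioral shift: the old code trusted a type cast, while .parse now validates every streaming chunk and throws on mismatch. A sketch (hypothetical, not what this commit does) of surfacing a friendlier error with zod's safeParse:

    // hypothetical variant -- report a malformed chunk with a clear message
    const parsed = openaiWire_ChatCompletionChunkResponse_Schema.safeParse(JSON.parse(data));
    if (!parsed.success)
      throw new Error(`[OpenAI] unexpected streaming chunk shape: ${parsed.error.message}`);
    const json = parsed.data;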
@@ -1,7 +1,8 @@
+import type { OpenaiWire_Model } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';

import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Json, LLM_IF_OAI_Vision } from '../../store-llms';

import type { ModelDescriptionSchema } from '../llm.server.types';
import type { OpenAIWire } from './openai.wiretypes';
import { wireGroqModelsListOutputSchema } from './groq.wiretypes';
import { wireMistralModelsListOutputSchema } from './mistral.wiretypes';
import { wireOpenrouterModelsListOutputSchema } from './openrouter.wiretypes';
@@ -306,11 +307,11 @@ const openAIModelsDenyList: string[] = [
  'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-0301', 'gpt-3.5-turbo-16k',
];

-export function openAIModelFilter(model: OpenAIWire.Models.ModelDescription) {
+export function openAIModelFilter(model: OpenaiWire_Model) {
  return !openAIModelsDenyList.some(deny => model.id.includes(deny));
}

-export function openAIModelToModelDescription(modelId: string, modelCreated: number, modelUpdated?: number): ModelDescriptionSchema {
+export function openAIModelToModelDescription(modelId: string, modelCreated: number | undefined, modelUpdated?: number): ModelDescriptionSchema {
  return fromManualMapping(_knownOpenAIChatModels, modelId, modelCreated, modelUpdated);
}
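Taken together, the two retyped functions compose into the models-listing flow. A hypothetical sketch (name assumed, not in this hunk) that also shows why modelCreated was widened to number | undefined: the zod schema marks created as optional.

  // hypothetical composition -- from a parsed wire list to UI-facing descriptions
  function describeChatModels(wireList: OpenaiWire_ModelList): ModelDescriptionSchema[] {
    return wireList.data
      .filter(openAIModelFilter)
      .map(model => openAIModelToModelDescription(model.id, model.created)); // 'created' may be undefined
  }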
@@ -663,7 +664,7 @@ const _knownOobaboogaNonChatModels: string[] = [
  /* 'gpt-3.5-turbo' // used to be here, but now it's the way to select the actively loaded oobabooga model */
];

-export function oobaboogaModelToModelDescription(modelId: string, created: number): ModelDescriptionSchema {
+export function oobaboogaModelToModelDescription(modelId: string, created: number | undefined): ModelDescriptionSchema {
  let label = modelId.replaceAll(/[_-]/g, ' ').split(' ').map(word => word[0].toUpperCase() + word.slice(1)).join(' ');
  if (label.endsWith('.bin'))
    label = label.slice(0, -4);
@@ -10,9 +10,9 @@ import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i
import { Brand } from '~/common/app.config';
import { fixupHost } from '~/common/util/urlUtils';

-import type { OpenaiWire_ChatCompletionRequest } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
+import { OpenaiWire_ChatCompletionRequest, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';

-import { OpenAIWire, WireOpenAICreateImageOutput, wireOpenAICreateImageOutputSchema, WireOpenAICreateImageRequest } from './openai.wiretypes';
+import type { OpenAIWire } from './openai.wiretypes';
import { azureModelToModelDescription, deepseekModelToModelDescription, groqModelSortFn, groqModelToModelDescription, lmStudioModelToModelDescription, localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription, perplexityAIModelDescriptions, perplexityAIModelSort, togetherAIModelsToModelDescriptions } from './models.data';
import { llmsChatGenerateWithFunctionsOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
import { wilreLocalAIModelsApplyOutputSchema, wireLocalAIModelsAvailableOutputSchema, wireLocalAIModelsListOutputSchema } from './localai.wiretypes';
@@ -155,13 +155,13 @@ export const llmOpenAIRouter = createTRPCRouter({


      // [non-Azure]: fetch openAI-style for all but Azure (will be then used in each dialect)
-      const openAIWireModelsResponse = await openaiGETOrThrow<OpenAIWire.Models.Response>(access, '/v1/models');
+      const openAIWireModelsResponse = await openaiGETOrThrow<OpenaiWire_ModelList>(access, '/v1/models');

      // [Together] missing the .data property
      if (access.dialect === 'togetherai')
        return { models: togetherAIModelsToModelDescriptions(openAIWireModelsResponse) };

-      let openAIModels: OpenAIWire.Models.ModelDescription[] = openAIWireModelsResponse.data || [];
+      let openAIModels = openAIWireModelsResponse.data || [];

      // de-duplicate by ids (can happen for local servers.. upstream bugs)
      const preCount = openAIModels.length;
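The de-duplication itself falls outside this hunk; a minimal sketch of what such a pass looks like (hypothetical, assuming first occurrence wins):

      // hypothetical sketch -- drop duplicate model ids, keeping the first occurrence
      const seenIds = new Set<string>();
      openAIModels = openAIModels.filter(model => {
        if (seenIds.has(model.id)) return false;
        seenIds.add(model.id);
        return true;
      });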
@@ -321,7 +321,7 @@ export const llmOpenAIRouter = createTRPCRouter({
        throw new TRPCError({ code: 'BAD_REQUEST', message: `[OpenAI Issue] dall-e-3 model does not support more than 1 image` });

      // images/generations request body
-      const requestBody: WireOpenAICreateImageRequest = {
+      const requestBody: OpenaiWire_CreateImageRequest = {
        prompt: config.prompt,
        model: config.model,
        n: config.count,
@@ -337,7 +337,7 @@ export const llmOpenAIRouter = createTRPCRouter({
        delete requestBody.response_format;

      // create 1 image (dall-e-3 won't support more than 1, so better transfer the burden to the client)
-      const wireOpenAICreateImageOutput = await openaiPOSTOrThrow<WireOpenAICreateImageOutput, WireOpenAICreateImageRequest>(
+      const wireOpenAICreateImageOutput = await openaiPOSTOrThrow<OpenaiWire_CreateImageResponse, OpenaiWire_CreateImageRequest>(
        access, null, requestBody, '/v1/images/generations',
      );
@@ -350,7 +350,7 @@ export const llmOpenAIRouter = createTRPCRouter({
      const { count: _count, responseFormat: _responseFormat, prompt: origPrompt, ...parameters } = config;

      // expect a single image, returned as b64_json
-      const generatedImages = wireOpenAICreateImageOutputSchema.parse(wireOpenAICreateImageOutput).data;
+      const generatedImages = openaiWire_CreateImageResponse_Schema.parse(wireOpenAICreateImageOutput).data;
      return generatedImages.map((image): T2iCreateImageOutput => {
        if (!('b64_json' in image))
          throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected a b64_json, got a url` });
@@ -371,10 +371,10 @@ export const llmOpenAIRouter = createTRPCRouter({
  /* [OpenAI] check for content policy violations */
  moderation: publicProcedure
    .input(moderationInputSchema)
-    .mutation(async ({ input: { access, text } }): Promise<OpenAIWire.Moderation.Response> => {
+    .mutation(async ({ input: { access, text } }): Promise<OpenaiWire_ModerationResponse> => {
      try {

-        return await openaiPOSTOrThrow<OpenAIWire.Moderation.Response, OpenAIWire.Moderation.Request>(access, null, {
+        return await openaiPOSTOrThrow<OpenaiWire_ModerationResponse, OpenaiWire_ModerationRequest>(access, null, {
          input: text,
          model: 'text-moderation-latest',
        }, '/v1/moderations');
@@ -1,6 +1,3 @@
-import { z } from 'zod';
-
-
/**
 * OpenAI API types - https://platform.openai.com/docs/api-reference/
 *
@@ -8,36 +5,11 @@ import { z } from 'zod';
 * - 2023-12-22:
 * Below we have the manually typed types for the OpenAI API. Everywhere else we are switching
 * to Zod inferred types, and we shall do it here sooner (so we can validate upon parsing too).
 *
 * - [FN0613]: function calling capability - only 2023-06-13 and later Chat models
 */
export namespace OpenAIWire {

  export namespace ChatCompletion {

-    export interface Request {
-      model: string;
-      messages: RequestMessage[];
-      temperature?: number;
-      top_p?: number;
-      frequency_penalty?: number;
-      presence_penalty?: number;
-      max_tokens?: number;
-      stream: boolean;
-      n?: number;
-      // [FN0613]
-      functions?: RequestFunctionDef[],
-      function_call?: 'auto' | 'none' | {
-        name: string;
-      },
-    }
-
-    export interface RequestMessage {
-      role: 'assistant' | 'system' | 'user'; // | 'function';
-      content: string;
-      //name?: string; // when role: 'function'
-    }
-
-    export interface RequestFunctionDef { // [FN0613]
-      name: string;
-      description?: string;
@@ -86,128 +58,5 @@ export namespace OpenAIWire {
      };
    }

-    export interface ResponseStreamingChunk {
-      id: string;
-      object: 'chat.completion.chunk' | ''; // '' is for some Azure responses
-      created: number;
-      model: string;
-      choices: {
-        index: number;
-        delta: Partial<ResponseMessage>;
-        finish_reason: 'stop' | 'length' | null;
-      }[];
-      // undocumented, but can be present, e.g. "This model version is deprecated and a newer version 'gpt-4-0613' is available. Migrate before..."
-      warning?: string;
-      // this could also be an error - first experienced on 2023-06-19 on streaming APIs (undocumented)
-      error?: {
-        message: string;
-        type: 'server_error' | string;
-        param: string | null;
-        code: string | null;
-      };
-      // [OpenRouter/LocalAI] Extended usage statistics
-      usage?: {
-        prompt_tokens?: number;
-        completion_tokens?: number;
-        total_tokens?: number;
-      };
-    }
  }


-  export namespace Models {
-    export interface ModelDescription {
-      id: string;
-      object: 'model';
-      created: number;
-      owned_by: 'openai' | 'openai-dev' | 'openai-internal' | 'system' | string; // 'user' for Oobabooga models
-      // [2023-11-08] Note: the following properties are not present in OpenAI responses any longer
-      // permission: any[];
-      // root: string;
-      // parent: null;
-
-      // non-standard properties
-      //context_length?: number; // Openrouter-only models, non-standard - commented because dynamically added by the Openrouter vendor code
-    }
-
-    export interface Response {
-      object: string;
-      data: ModelDescription[];
-    }
-  }


-  export namespace Moderation {
-    export interface Request {
-      input: string | string[];
-      model?: 'text-moderation-stable' | 'text-moderation-latest';
-    }
-
-    export enum ModerationCategory {
-      // noinspection JSUnusedGlobalSymbols
-      hate = 'hate',
-      'hate/threatening' = 'hate/threatening',
-      'self-harm' = 'self-harm',
-      sexual = 'sexual',
-      'sexual/minors' = 'sexual/minors',
-      violence = 'violence',
-      'violence/graphic' = 'violence/graphic',
-    }
-
-    export interface Response {
-      id: string;
-      model: string;
-      results: [
-        {
-          categories: { [key in ModerationCategory]: boolean };
-          category_scores: { [key in ModerationCategory]: number };
-          flagged: boolean;
-        }
-      ];
-    }
-  }

}


-// OpenAI text to image generation - https://platform.openai.com/docs/api-reference/images/create
-
-const wireOpenAICreateImageRequestSchema = z.object({
-  // The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3
-  prompt: z.string().max(4000),
-
-  // The model to use for image generation
-  model: z.enum(['dall-e-2', 'dall-e-3']).optional().default('dall-e-2'),
-
-  // The number of images to generate. Must be between 1 and 10. For dall-e-3, only n=1 is supported.
-  n: z.number().min(1).max(10).nullable().optional(),
-
-  // 'hd' creates images with finer details and greater consistency across the image. This param is only supported for dall-e-3
-  quality: z.enum(['standard', 'hd']).optional(),
-
-  // The format in which the generated images are returned
-  response_format: z.enum(['url', 'b64_json']).optional(), //.default('url'),
-
-  // 'dall-e-2': must be one of 256x256, 512x512, or 1024x1024
-  // 'dall-e-3': must be one of 1024x1024, 1792x1024, or 1024x1792
-  size: z.enum(['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']).optional().default('1024x1024'),
-
-  // only used by 'dall-e-3': 'vivid' (hyper-real and dramatic images) or 'natural'
-  style: z.enum(['vivid', 'natural']).optional().default('vivid'),
-
-  // A unique identifier representing your end-user
-  user: z.string().optional(),
-});
-
-export type WireOpenAICreateImageRequest = z.infer<typeof wireOpenAICreateImageRequestSchema>;
-
-export const wireOpenAICreateImageOutputSchema = z.object({
-  created: z.number(),
-  data: z.array(z.object({
-    b64_json: z.string().optional(),
-    url: z.string().optional(),
-    revised_prompt: z.string().optional(),
-  })),
-});
-
-export type WireOpenAICreateImageOutput = z.infer<typeof wireOpenAICreateImageOutputSchema>;
@@ -6,7 +6,6 @@ import type { DLLMId } from '../store-llms';
import type { VChatContextRef, VChatFunctionIn, VChatMessageIn, VChatStreamContextName } from '../llm.client';

import type { OpenAIAccessSchema } from '../server/openai/openai.router';
-import type { OpenAIWire } from '../server/openai/openai.wiretypes';


export type StreamingClientUpdate = Partial<{

@@ -149,7 +148,7 @@ async function _openAIModerationCheck(access: OpenAIAccessSchema, lastMessage: V
  return null;

  try {
-    const moderationResult: OpenAIWire.Moderation.Response = await apiAsync.llmOpenAI.moderation.mutate({
+    const moderationResult = await apiAsync.llmOpenAI.moderation.mutate({
      access, text: lastMessage.content,
    });
    const issues = moderationResult.results.reduce((acc, result) => {
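The reduce body is cut off by the hunk; a hypothetical continuation consistent with the response shape (flagged results carry per-category booleans):

    // hypothetical continuation -- collect the names of flagged categories
    const issues = moderationResult.results.reduce((acc, result) => {
      if (result.flagged)
        Object.entries(result.categories)
          .filter(([_name, value]) => value)
          .forEach(([name]) => acc.add(name));
      return acc;
    }, new Set<string>());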