OpenAI Wire: port image generation and moderations

This commit is contained in:
Enrico Ros
2024-07-09 17:10:42 -07:00
parent eecf220bfe
commit 69a58c435b
6 changed files with 131 additions and 169 deletions
@@ -7,6 +7,9 @@ import { z } from 'zod';
// - 2024-07-09: ignoring the advanced model configuration
//
//
// Chat > Create chat completion
//
/// Content parts - Input
@@ -262,7 +265,6 @@ const openaiWire_ChatCompletionChunkChoice_Schema = z.object({
// logprobs: ... // Log probability information for the choice.
});
export type OpenaiWire_ChatCompletionChunkResponse = z.infer<typeof openaiWire_ChatCompletionChunkResponse_Schema>;
export const openaiWire_ChatCompletionChunkResponse_Schema = z.object({
object: z.enum(['chat.completion.chunk', '' /* [Azure] bad response */]),
id: z.string(),
@@ -284,3 +286,114 @@ export const openaiWire_ChatCompletionChunkResponse_Schema = z.object({
error: openaiWire_UndocumentedError_Schema.optional(),
warning: openaiWire_UndocumentedWarning_Schema.optional(),
});
//
// Images > Create Image
// https://platform.openai.com/docs/api-reference/images/create
//

export type OpenaiWire_CreateImageRequest = z.infer<typeof openaiWire_CreateImageRequest_Schema>;
const openaiWire_CreateImageRequest_Schema = z.object({
  // The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3
  prompt: z.string().max(4000),
  // The model to use for image generation
  model: z.enum(['dall-e-2', 'dall-e-3']).optional().default('dall-e-2'),
  // The number of images to generate. Must be a whole number between 1 and 10. For dall-e-3, only n=1 is supported.
  // .int() enforces the API's integer requirement; fractional values previously passed validation.
  n: z.number().int().min(1).max(10).nullable().optional(),
  // 'hd' creates images with finer details and greater consistency across the image. This param is only supported for dall-e-3
  quality: z.enum(['standard', 'hd']).optional(),
  // The format in which the generated images are returned
  response_format: z.enum(['url', 'b64_json']).optional(), //.default('url'),
  // 'dall-e-2': must be one of 256x256, 512x512, or 1024x1024
  // 'dall-e-3': must be one of 1024x1024, 1792x1024, or 1024x1792
  size: z.enum(['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']).optional().default('1024x1024'),
  // only used by 'dall-e-3': 'vivid' (hyper-real and dramatic images) or 'natural'
  style: z.enum(['vivid', 'natural']).optional().default('vivid'),
  // A unique identifier representing your end-user
  user: z.string().optional(),
});
export type OpenaiWire_CreateImageResponse = z.infer<typeof openaiWire_CreateImageResponse_Schema>;
export const openaiWire_CreateImageResponse_Schema = z.object({
// Unix timestamp of the response
created: z.number(),
// one entry per generated image
data: z.array(z.object({
// present when the request's response_format is 'url' — presumably mutually exclusive with b64_json; verify against endpoint
url: z.string().url().optional(),
// present when the request's response_format is 'b64_json'
b64_json: z.string().optional(),
// NOTE(review): looks like the dall-e-3 rewritten prompt — confirm it is absent for dall-e-2
revised_prompt: z.string().optional(),
})),
});
//
// Models > List Models
// https://platform.openai.com/docs/api-reference/models/list
//

// Model object schema
export type OpenaiWire_Model = z.infer<typeof openaiWire_Model_Schema>;
const openaiWire_Model_Schema = z.object({
// model identifier (e.g. matched against the deny-list in openAIModelFilter)
id: z.string(),
object: z.literal('model'),
// optional: some OpenAI-compatible servers omit the creation timestamp (hence `number | undefined` in the model-description mappers)
created: z.number().optional(),
// [dialect:OpenAI] 'openai' | 'openai-dev' | 'openai-internal' | 'system'
// [dialect:Oobabooga] 'user'
owned_by: z.string().optional(),
// **Extensions**
// [Openrouter] non-standard - commented because dynamically added by the Openrouter vendor code
// context_length: z.number().optional(),
});

// List models response schema
export type OpenaiWire_ModelList = z.infer<typeof openaiWire_ModelList_Schema>;
const openaiWire_ModelList_Schema = z.object({
object: z.literal('list'),
data: z.array(openaiWire_Model_Schema),
});
//
// Moderations > Create Moderation
// https://platform.openai.com/docs/api-reference/moderations
//

// Category labels returned by the moderation endpoint
export const openaiWire_ModerationCategory_Schema = z.enum([
'sexual',
'hate',
'harassment',
'self-harm',
'sexual/minors',
'hate/threatening',
'violence/graphic',
'self-harm/intent',
'self-harm/instructions',
'harassment/threatening',
'violence',
]);

export type OpenaiWire_ModerationRequest = z.infer<typeof openaiWire_ModerationRequest_Schema>;
const openaiWire_ModerationRequest_Schema = z.object({
// the API also accepts an array of strings; this port deliberately restricts input to a single string
// input: z.union([z.string(), z.array(z.string())]),
input: z.string(),
model: z.enum(['text-moderation-stable', 'text-moderation-latest']).optional(),
});

// Per-input moderation verdict
const openaiWire_ModerationResult_Schema = z.object({
// true if any category was flagged
flagged: z.boolean(),
// NOTE(review): z.record with an enum key does not require every category to be present at parse time — confirm the upstream response always includes all categories if callers index unconditionally
categories: z.record(openaiWire_ModerationCategory_Schema, z.boolean()),
category_scores: z.record(openaiWire_ModerationCategory_Schema, z.number()),
});

export type OpenaiWire_ModerationResponse = z.infer<typeof openaiWire_ModerationResponse_Schema>;
const openaiWire_ModerationResponse_Schema = z.object({
id: z.string(),
model: z.string(),
// one result per input (a single result here, since the request sends one string)
results: z.array(openaiWire_ModerationResult_Schema),
});
@@ -18,8 +18,8 @@ import { wireOllamaChunkedOutputSchema } from '../../aix/server/dispatch/ollama/
import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router';
// OpenAI server imports
import type { OpenAIWire } from './openai/openai.wiretypes';
import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router';
import { openaiWire_ChatCompletionChunkResponse_Schema } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
import { llmsStreamingContextSchema } from './llm.server.types';
@@ -455,7 +455,7 @@ function createStreamParserOpenAI(): AIStreamParser {
return (data: string) => {
const json: OpenAIWire.ChatCompletion.ResponseStreamingChunk = JSON.parse(data);
const json = openaiWire_ChatCompletionChunkResponse_Schema.parse(JSON.parse(data));
// [OpenAI] an upstream error will be handled gracefully and transmitted as text (throw to transmit as 'error')
if (json.error)
@@ -1,7 +1,8 @@
import type { OpenaiWire_Model } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
import { LLM_IF_OAI_Chat, LLM_IF_OAI_Complete, LLM_IF_OAI_Fn, LLM_IF_OAI_Json, LLM_IF_OAI_Vision } from '../../store-llms';
import type { ModelDescriptionSchema } from '../llm.server.types';
import type { OpenAIWire } from './openai.wiretypes';
import { wireGroqModelsListOutputSchema } from './groq.wiretypes';
import { wireMistralModelsListOutputSchema } from './mistral.wiretypes';
import { wireOpenrouterModelsListOutputSchema } from './openrouter.wiretypes';
@@ -306,11 +307,11 @@ const openAIModelsDenyList: string[] = [
'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-0301', 'gpt-3.5-turbo-16k',
];
export function openAIModelFilter(model: OpenAIWire.Models.ModelDescription) {
export function openAIModelFilter(model: OpenaiWire_Model) {
return !openAIModelsDenyList.some(deny => model.id.includes(deny));
}
export function openAIModelToModelDescription(modelId: string, modelCreated: number, modelUpdated?: number): ModelDescriptionSchema {
export function openAIModelToModelDescription(modelId: string, modelCreated: number | undefined, modelUpdated?: number): ModelDescriptionSchema {
return fromManualMapping(_knownOpenAIChatModels, modelId, modelCreated, modelUpdated);
}
@@ -663,7 +664,7 @@ const _knownOobaboogaNonChatModels: string[] = [
/* 'gpt-3.5-turbo' // used to be here, but now it's the way to select the activly loaded ooababooga model */
];
export function oobaboogaModelToModelDescription(modelId: string, created: number): ModelDescriptionSchema {
export function oobaboogaModelToModelDescription(modelId: string, created: number | undefined): ModelDescriptionSchema {
let label = modelId.replaceAll(/[_-]/g, ' ').split(' ').map(word => word[0].toUpperCase() + word.slice(1)).join(' ');
if (label.endsWith('.bin'))
label = label.slice(0, -4);
@@ -10,9 +10,9 @@ import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i
import { Brand } from '~/common/app.config';
import { fixupHost } from '~/common/util/urlUtils';
import type { OpenaiWire_ChatCompletionRequest } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
import { OpenaiWire_ChatCompletionRequest, OpenaiWire_CreateImageRequest, OpenaiWire_CreateImageResponse, openaiWire_CreateImageResponse_Schema, OpenaiWire_ModelList, OpenaiWire_ModerationRequest, OpenaiWire_ModerationResponse } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
import { OpenAIWire, WireOpenAICreateImageOutput, wireOpenAICreateImageOutputSchema, WireOpenAICreateImageRequest } from './openai.wiretypes';
import type { OpenAIWire } from './openai.wiretypes';
import { azureModelToModelDescription, deepseekModelToModelDescription, groqModelSortFn, groqModelToModelDescription, lmStudioModelToModelDescription, localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription, perplexityAIModelDescriptions, perplexityAIModelSort, togetherAIModelsToModelDescriptions } from './models.data';
import { llmsChatGenerateWithFunctionsOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
import { wilreLocalAIModelsApplyOutputSchema, wireLocalAIModelsAvailableOutputSchema, wireLocalAIModelsListOutputSchema } from './localai.wiretypes';
@@ -155,13 +155,13 @@ export const llmOpenAIRouter = createTRPCRouter({
// [non-Azure]: fetch openAI-style for all but Azure (will be then used in each dialect)
const openAIWireModelsResponse = await openaiGETOrThrow<OpenAIWire.Models.Response>(access, '/v1/models');
const openAIWireModelsResponse = await openaiGETOrThrow<OpenaiWire_ModelList>(access, '/v1/models');
// [Together] missing the .data property
if (access.dialect === 'togetherai')
return { models: togetherAIModelsToModelDescriptions(openAIWireModelsResponse) };
let openAIModels: OpenAIWire.Models.ModelDescription[] = openAIWireModelsResponse.data || [];
let openAIModels = openAIWireModelsResponse.data || [];
// de-duplicate by ids (can happen for local servers.. upstream bugs)
const preCount = openAIModels.length;
@@ -321,7 +321,7 @@ export const llmOpenAIRouter = createTRPCRouter({
throw new TRPCError({ code: 'BAD_REQUEST', message: `[OpenAI Issue] dall-e-3 model does not support more than 1 image` });
// images/generations request body
const requestBody: WireOpenAICreateImageRequest = {
const requestBody: OpenaiWire_CreateImageRequest = {
prompt: config.prompt,
model: config.model,
n: config.count,
@@ -337,7 +337,7 @@ export const llmOpenAIRouter = createTRPCRouter({
delete requestBody.response_format;
// create 1 image (dall-e-3 won't support more than 1, so better transfer the burden to the client)
const wireOpenAICreateImageOutput = await openaiPOSTOrThrow<WireOpenAICreateImageOutput, WireOpenAICreateImageRequest>(
const wireOpenAICreateImageOutput = await openaiPOSTOrThrow<OpenaiWire_CreateImageResponse, OpenaiWire_CreateImageRequest>(
access, null, requestBody, '/v1/images/generations',
);
@@ -350,7 +350,7 @@ export const llmOpenAIRouter = createTRPCRouter({
const { count: _count, responseFormat: _responseFormat, prompt: origPrompt, ...parameters } = config;
// expect a single image and as URL
const generatedImages = wireOpenAICreateImageOutputSchema.parse(wireOpenAICreateImageOutput).data;
const generatedImages = openaiWire_CreateImageResponse_Schema.parse(wireOpenAICreateImageOutput).data;
return generatedImages.map((image): T2iCreateImageOutput => {
if (!('b64_json' in image))
throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected a b64_json, got a url` });
@@ -371,10 +371,10 @@ export const llmOpenAIRouter = createTRPCRouter({
/* [OpenAI] check for content policy violations */
moderation: publicProcedure
.input(moderationInputSchema)
.mutation(async ({ input: { access, text } }): Promise<OpenAIWire.Moderation.Response> => {
.mutation(async ({ input: { access, text } }): Promise<OpenaiWire_ModerationResponse> => {
try {
return await openaiPOSTOrThrow<OpenAIWire.Moderation.Response, OpenAIWire.Moderation.Request>(access, null, {
return await openaiPOSTOrThrow<OpenaiWire_ModerationResponse, OpenaiWire_ModerationRequest>(access, null, {
input: text,
model: 'text-moderation-latest',
}, '/v1/moderations');
@@ -1,6 +1,3 @@
import { z } from 'zod';
/**
* OpenAI API types - https://platform.openai.com/docs/api-reference/
*
@@ -8,36 +5,11 @@ import { z } from 'zod';
* - 2023-12-22:
* Below we have the manually typed types for the OpenAI API. Everywhere else we are switching
* to Zod inferred types, and we shall do it here sooner (so we can validate upon parsing too).
*
* - [FN0613]: function calling capability - only 2023-06-13 and later Chat models
*/
export namespace OpenAIWire {
export namespace ChatCompletion {
export interface Request {
model: string;
messages: RequestMessage[];
temperature?: number;
top_p?: number;
frequency_penalty?: number;
presence_penalty?: number;
max_tokens?: number;
stream: boolean;
n?: number;
// [FN0613]
functions?: RequestFunctionDef[],
function_call?: 'auto' | 'none' | {
name: string;
},
}
export interface RequestMessage {
role: 'assistant' | 'system' | 'user'; // | 'function';
content: string;
//name?: string; // when role: 'function'
}
export interface RequestFunctionDef { // [FN0613]
name: string;
description?: string;
@@ -86,128 +58,5 @@ export namespace OpenAIWire {
};
}
export interface ResponseStreamingChunk {
id: string;
object: 'chat.completion.chunk' | ''; // '' is for some Azure responses
created: number;
model: string;
choices: {
index: number;
delta: Partial<ResponseMessage>;
finish_reason: 'stop' | 'length' | null;
}[];
// undocumented, but can be present, e.g. "This model version is deprecated and a newer version \'gpt-4-0613\' is available. Migrate before..."
warning?: string;
// this could also be an error - first experienced on 2023-06-19 on streaming APIs (undocumented)
error?: {
message: string;
type: 'server_error' | string;
param: string | null;
code: string | null;
};
// [OpenRouter/LocalAI] Extended usage statistics
usage?: {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
}
}
export namespace Models {
export interface ModelDescription {
id: string;
object: 'model';
created: number;
owned_by: 'openai' | 'openai-dev' | 'openai-internal' | 'system' | string; // 'user' for Oobabooga models
// [2023-11-08] Note: the following properties are not present in OpenAI responses any longer
// permission: any[];
// root: string;
// parent: null;
// non-standard properties
//context_length?: number; // Openrouter-only models, non-standard - commented because dynamically added by the Openrouter vendor code
}
export interface Response {
object: string;
data: ModelDescription[];
}
}
export namespace Moderation {
export interface Request {
input: string | string[];
model?: 'text-moderation-stable' | 'text-moderation-latest';
}
export enum ModerationCategory {
// noinspection JSUnusedGlobalSymbols
hate = 'hate',
'hate/threatening' = 'hate/threatening',
'self-harm' = 'self-harm',
sexual = 'sexual',
'sexual/minors' = 'sexual/minors',
violence = 'violence',
'violence/graphic' = 'violence/graphic',
}
export interface Response {
id: string;
model: string;
results: [
{
categories: { [key in ModerationCategory]: boolean };
category_scores: { [key in ModerationCategory]: number };
flagged: boolean;
}
];
}
}
}
// OpenAI text to image generation - https://platform.openai.com/docs/api-reference/images/create
// NOTE(review): duplicated by openaiWire_CreateImageRequest_Schema / openaiWire_CreateImageResponse_Schema
// in oai.wiretypes — this legacy copy is superseded and slated for removal.
const wireOpenAICreateImageRequestSchema = z.object({
// The maximum length is 1000 characters for dall-e-2 and 4000 characters for dall-e-3
prompt: z.string().max(4000),
// The model to use for image generation
model: z.enum(['dall-e-2', 'dall-e-3']).optional().default('dall-e-2'),
// The number of images to generate. Must be between 1 and 10. For dall-e-3, only n=1 is supported.
n: z.number().min(1).max(10).nullable().optional(),
// 'hd' creates images with finer details and greater consistency across the image. This param is only supported for dall-e-3
quality: z.enum(['standard', 'hd']).optional(),
// The format in which the generated images are returned
response_format: z.enum(['url', 'b64_json']).optional(), //.default('url'),
// 'dall-e-2': must be one of 256x256, 512x512, or 1024x1024
// 'dall-e-3': must be one of 1024x1024, 1792x1024, or 1024x1792
size: z.enum(['256x256', '512x512', '1024x1024', '1792x1024', '1024x1792']).optional().default('1024x1024'),
// only used by 'dall-e-3': 'vivid' (hyper-real and dramatic images) or 'natural'
style: z.enum(['vivid', 'natural']).optional().default('vivid'),
// A unique identifier representing your end-user
user: z.string().optional(),
});
export type WireOpenAICreateImageRequest = z.infer<typeof wireOpenAICreateImageRequestSchema>;

// Response: note url is not validated as a URL here, unlike the replacement schema
export const wireOpenAICreateImageOutputSchema = z.object({
created: z.number(),
data: z.array(z.object({
b64_json: z.string().optional(),
url: z.string().optional(),
revised_prompt: z.string().optional(),
})),
});
export type WireOpenAICreateImageOutput = z.infer<typeof wireOpenAICreateImageOutputSchema>;
+1 -2
View File
@@ -6,7 +6,6 @@ import type { DLLMId } from '../store-llms';
import type { VChatContextRef, VChatFunctionIn, VChatMessageIn, VChatStreamContextName } from '../llm.client';
import type { OpenAIAccessSchema } from '../server/openai/openai.router';
import type { OpenAIWire } from '../server/openai/openai.wiretypes';
export type StreamingClientUpdate = Partial<{
@@ -149,7 +148,7 @@ async function _openAIModerationCheck(access: OpenAIAccessSchema, lastMessage: V
return null;
try {
const moderationResult: OpenAIWire.Moderation.Response = await apiAsync.llmOpenAI.moderation.mutate({
const moderationResult = await apiAsync.llmOpenAI.moderation.mutate({
access, text: lastMessage.content,
});
const issues = moderationResult.results.reduce((acc, result) => {