diff --git a/src/modules/aix/server/dispatch/dispatch.parsers.ts b/src/modules/aix/server/dispatch/dispatch.parsers.ts
index a16a076d6..8a492ddbe 100644
--- a/src/modules/aix/server/dispatch/dispatch.parsers.ts
+++ b/src/modules/aix/server/dispatch/dispatch.parsers.ts
@@ -2,10 +2,9 @@ import { z } from 'zod';
 
 import { safeErrorString } from '~/server/wire';
 
-import type { OpenAIWire } from '~/modules/llms/server/openai/openai.wiretypes';
-
 import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema, AnthropicWireMessageResponse } from './anthropic/anthropic.wiretypes';
 import { geminiGeneratedContentResponseSchema, geminiHarmProbabilitySortFunction, GeminiSafetyRatings } from './gemini/gemini.wiretypes';
+import { openaiWire_ChatCompletionChunkResponse_Schema } from './openai/oai.wiretypes';
 import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes';
 
@@ -282,7 +281,7 @@ export function createDispatchParserOpenAI(): DispatchParser {
   return function* (eventData: string): Generator {
 
     // Throws on malformed event data
-    const json: OpenAIWire.ChatCompletion.ResponseStreamingChunk = JSON.parse(eventData);
+    const json = openaiWire_ChatCompletionChunkResponse_Schema.parse(JSON.parse(eventData));
 
     // -> Model
     if (!hasBegun && json.model) {
diff --git a/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts b/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts
index f2366dd76..5def8e4b2 100644
--- a/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts
+++ b/src/modules/aix/server/dispatch/openai/oai.wiretypes.ts
@@ -131,7 +131,7 @@ const openaiWire_ToolChoice_Schema = z.union([
 
 /// API: Content Generation - Request
 
-export type OpenaiWire_Message = z.infer<typeof openaiWire_chatCompletionRequest_Schema>;
+export type OpenaiWire_ChatCompletionRequest = z.infer<typeof openaiWire_chatCompletionRequest_Schema>;
 export const openaiWire_chatCompletionRequest_Schema = z.object({
   // basic input
   model: z.string(),
@@ -203,6 +203,18 @@ const openaiWire_Usage_Schema = z.object({
   total_tokens: z.number(),
 });
 
+
+const openaiWire_UndocumentedError_Schema = z.object({
+  // (undocumented) first experienced on 2023-06-19 on streaming APIs
+  message: z.string().optional(),
+  type: z.string().optional(),
+  param: z.string().nullable().optional(),
+  code: z.string().nullable().optional(),
+});
+
+const openaiWire_UndocumentedWarning_Schema = z.string();
+
+
 const openaiWire_ChatCompletionChoice_Schema = z.object({
   index: z.number(),
@@ -252,7 +264,7 @@ const openaiWire_ChatCompletionChunkChoice_Schema = z.object({
 
 export type OpenaiWire_ChatCompletionChunkResponse = z.infer<typeof openaiWire_ChatCompletionChunkResponse_Schema>;
 export const openaiWire_ChatCompletionChunkResponse_Schema = z.object({
-  object: z.literal('chat.completion.chunk'),
+  object: z.enum(['chat.completion.chunk', '' /* [Azure] bad response */]),
   id: z.string(),
 
   /**
@@ -267,4 +279,8 @@ export const openaiWire_ChatCompletionChunkResponse_Schema = z.object({
   created: z.number(), // The Unix timestamp (in seconds) of when the chat completion was created.
   system_fingerprint: z.string().optional(), // The backend configuration that the model runs with.
   // service_tier: z.unknown().optional(),
+
+  // undocumented streaming messages
+  error: openaiWire_UndocumentedError_Schema.optional(),
+  warning: openaiWire_UndocumentedWarning_Schema.optional(),
 });
diff --git a/src/modules/llms/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts
index 64763f47d..ed66f345a 100644
--- a/src/modules/llms/server/openai/openai.router.ts
+++ b/src/modules/llms/server/openai/openai.router.ts
@@ -10,6 +10,8 @@ import { T2iCreateImageOutput, t2iCreateImagesOutputSchema } from '~/modules/t2i
 import { Brand } from '~/common/app.config';
 import { fixupHost } from '~/common/util/urlUtils';
 
+import type { OpenaiWire_ChatCompletionRequest } from '~/modules/aix/server/dispatch/openai/oai.wiretypes';
+
 import { OpenAIWire, WireOpenAICreateImageOutput, wireOpenAICreateImageOutputSchema, WireOpenAICreateImageRequest } from './openai.wiretypes';
 import { azureModelToModelDescription, deepseekModelToModelDescription, groqModelSortFn, groqModelToModelDescription, lmStudioModelToModelDescription, localAIModelToModelDescription, mistralModelsSort, mistralModelToModelDescription, oobaboogaModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openRouterModelFamilySortFn, openRouterModelToModelDescription, perplexityAIModelDescriptions, perplexityAIModelSort, togetherAIModelsToModelDescriptions } from './models.data';
 import { llmsChatGenerateWithFunctionsOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
@@ -283,7 +285,7 @@ export const llmOpenAIRouter = createTRPCRouter({
       const isFunctionsCall = !!functions && functions.length > 0;
 
       const completionsBody = openAIChatCompletionPayload(access.dialect, model, history, isFunctionsCall ? functions : null, forceFunctionName ?? null, 1, false);
-      const wireCompletions = await openaiPOSTOrThrow<OpenAIWire.ChatCompletion.Response, OpenAIWire.ChatCompletion.Request>(
+      const wireCompletions = await openaiPOSTOrThrow<OpenAIWire.ChatCompletion.Response, OpenaiWire_ChatCompletionRequest>(
         access, model.id, completionsBody, '/v1/chat/completions',
       );
 
@@ -627,7 +629,7 @@ export function openAIAccess(access: OpenAIAccessSchema, modelRefId: string | nu
 }
 
 
-export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: OpenAIModelSchema, history: OpenAIHistorySchema, functions: OpenAIFunctionsSchema | null, forceFunctionName: string | null, n: number, stream: boolean): OpenAIWire.ChatCompletion.Request {
+export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: OpenAIModelSchema, history: OpenAIHistorySchema, functions: OpenAIFunctionsSchema | null, forceFunctionName: string | null, n: number, stream: boolean): OpenaiWire_ChatCompletionRequest {
 
   // Hotfixes to comply with API restrictions
   const hotfixAlternateUARoles = dialect === 'perplexity';
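
Note (not part of the patch): the net effect of this change is that streaming chunks are validated at the trust boundary with zod instead of being blindly cast, so quirks like Azure's empty `object` field and the undocumented `error`/`warning` streaming payloads are modeled explicitly rather than slipping through as untyped data. Below is a minimal, self-contained sketch of that behavior; `chunkSchema` and the example values are illustrative stand-ins, not the full schemas from `oai.wiretypes.ts`.

```ts
import { z } from 'zod';

// Illustrative stand-in for openaiWire_ChatCompletionChunkResponse_Schema:
// tolerates Azure's empty `object` value and carries the undocumented
// error/warning fields that the diff adds for streaming responses.
const chunkSchema = z.object({
  object: z.enum(['chat.completion.chunk', '' /* [Azure] bad response */]),
  id: z.string(),
  model: z.string(),
  error: z.object({
    message: z.string().optional(),
    code: z.string().nullable().optional(),
  }).optional(),
  warning: z.string().optional(),
});

// Documented OpenAI chunk: parses cleanly.
chunkSchema.parse({ object: 'chat.completion.chunk', id: 'c1', model: 'gpt-4o' });

// Azure-style chunk with an empty `object` field: also accepted.
chunkSchema.parse({ object: '', id: 'c2', model: 'gpt-4o' });

// An upstream error surfaced in-band on the stream is now typed, not `any`
// (the message/code values here are made up for the example).
const withError = chunkSchema.parse({
  object: 'chat.completion.chunk', id: 'c3', model: 'gpt-4o',
  error: { message: 'rate limited', code: 'rate_limit_exceeded' },
});
if (withError.error)
  console.log('upstream error:', withError.error.message);

// Malformed chunk: .parse() throws a ZodError, which the dispatch parser
// lets propagate ("Throws on malformed event data"); safeParse shown here
// just to demonstrate the rejection without throwing.
const bad = chunkSchema.safeParse({ id: 'c4' });
console.log(bad.success); // false (missing `object` and `model`)
```

The tradeoff is the usual one for wire schemas: `.parse()` fails fast on unknown shapes, which catches provider drift early, at the cost of having to explicitly allow benign deviations, which is exactly what the `z.enum([...])` relaxation and the two `Undocumented*` schemas in this diff do.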