diff --git a/src/modules/aix/client/aix.client.ts b/src/modules/aix/client/aix.client.ts index ff5fd0842..8e4b28791 100644 --- a/src/modules/aix/client/aix.client.ts +++ b/src/modules/aix/client/aix.client.ts @@ -93,17 +93,17 @@ async function _aixStreamGenerateUnified( onUpdate: (update: StreamingClientUpdate, done: boolean) => void, ): Promise { - const x = await apiStream.aix.chatGenerateContentStream.mutate( - { access, model, chatGenerate, context }, + const operation = await apiStream.aix.chatGenerateContent.mutate( + { access, model, chatGenerate, context, streaming: true }, { signal: abortSignal }, ); let incrementalText = ''; try { - for await (const update of x) { - console.log('cs update:', update); - + for await (const update of operation) { + // console.log('cs update:', update); + // TODO: improve this recombination protocol... if ('t' in update) { incrementalText += update.t; onUpdate({ textSoFar: incrementalText, typing: true }, false); diff --git a/src/modules/aix/server/aix.router.ts b/src/modules/aix/server/aix.router.ts index d615fd39c..0113b2a6d 100644 --- a/src/modules/aix/server/aix.router.ts +++ b/src/modules/aix/server/aix.router.ts @@ -15,21 +15,20 @@ export const aixRouter = createTRPCRouter({ * Chat content generation, streaming, multipart. * Architecture: Client <-- (intake) --> Server <-- (dispatch) --> AI Service */ - chatGenerateContentStream: publicProcedure + chatGenerateContent: publicProcedure .input(z.object({ access: intakeAccessSchema, model: intakeModelSchema, chatGenerate: intakeChatGenerateRequestSchema, context: intakeContextChatStreamSchema, + streaming: z.boolean(), })) .mutation(async function* ({ input, ctx }) { - // Using the variable to keep the implementation generic - const streaming = true; // Intake derived state const intakeAbortSignal = ctx.reqSignal; - const { access, model, chatGenerate } = input; + const { access, model, chatGenerate, streaming } = input; const accessDialect = access.dialect; const prettyDialect = serverCapitalizeFirstLetter(accessDialect); @@ -43,7 +42,7 @@ export const aixRouter = createTRPCRouter({ try { dispatch = createDispatch(access, model, chatGenerate, streaming); } catch (error: any) { - yield* intakeHandler.yieldError('dispatch-prepare', `**[Service Creation Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service preparation error'}`); + yield* intakeHandler.yieldError('dispatch-prepare', `**[Configuration Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service preparation error'}`); return; // exit } @@ -73,11 +72,22 @@ export const aixRouter = createTRPCRouter({ } - // Stream the response to the client + // [ALPHA] [NON-STREAMING] Read the full response and send operations down the intake + if (!streaming) { + try { + const dispatchBody = await dispatchResponse.text(); + const messageAction = dispatch.parser(dispatchBody); + yield* intakeHandler.yieldDmaOps(messageAction, prettyDialect); + } catch (error: any) { + yield* intakeHandler.yieldError('dispatch-read', `**[Service Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service reading error'}`); + } + return; // exit + } + + + // STREAM the response to the client const dispatchReader = (dispatchResponse.body || createEmptyReadableStream()).getReader(); const dispatchDecoder = new TextDecoder('utf-8', { fatal: false /* malformed data -> “ ” (U+FFFD) */ }); - const dispatchDemuxer = dispatch.demuxer.demux; - const dispatchParser = dispatch.parser; // Data pump: AI Service -- (dispatch) --> Server -- (intake) --> Client do { @@ -109,51 +119,29 @@ export const aixRouter = createTRPCRouter({ // Demux the chunk into 0 or more events - for (const demuxedEvent of dispatchDemuxer(dispatchChunk)) { - intakeHandler.onReceivedDispatchEvent(demuxedEvent); + for (const demuxedItem of dispatch.demuxer.demux(dispatchChunk)) { + intakeHandler.onReceivedDispatchEvent(demuxedItem); // ignore events post termination if (intakeHandler.intakeTerminated) { - // warning on, because this is pretty important - console.warn('/api/llms/stream: Received event after termination:', demuxedEvent); + // warning on, because this is important and a sign of a bug + console.warn('/api/llms/stream: Received event after termination:', demuxedItem); break; // inner for {} } // ignore superfluos stream events - if (demuxedEvent.type !== 'event') + if (demuxedItem.type !== 'event') continue; // inner for {} - // [OpenAI] Special: event stream termination, close our transformed stream - if (demuxedEvent.data === '[DONE]') { + // [OpenAI] Special: stream termination marker + if (demuxedItem.data === '[DONE]') { yield* intakeHandler.yieldTermination('event-done'); break; // inner for {}, then outer do } try { - const parsedEvents = dispatchParser(demuxedEvent.data, demuxedEvent.name); - for (const upe of parsedEvents) { - console.log('parsed dispatch:', upe); - // TODO: massively rework this into a good protocol - if (upe.op === 'parser-close') { - yield* intakeHandler.yieldTermination('parser-done'); - break; - } else if (upe.op === 'text') { - yield* intakeHandler.yieldOp({ - t: upe.text, - }); - } else if (upe.op === 'issue') { - yield* intakeHandler.yieldOp({ - t: ` ${upe.symbol} **[${prettyDialect} Issue]:** ${upe.issue}`, - }); - } else if (upe.op === 'set') { - yield* intakeHandler.yieldOp({ - set: upe.value, - }); - } else { - // shall never reach this - console.error('Unexpected stream event:', upe); - } - } + const messageAction = dispatch.parser(demuxedItem.data, demuxedItem.name); + yield* intakeHandler.yieldDmaOps(messageAction, prettyDialect); } catch (error: any) { yield* intakeHandler.yieldError('dispatch-parse', ` **[Service Parsing Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown stream parsing error'}. Please open a support ticket.`); break; // inner for {}, then outer do diff --git a/src/modules/aix/server/dispatch/anthropic/anthropic.adapters.ts b/src/modules/aix/server/dispatch/anthropic/anthropic.adapters.ts index a298329d6..525b91024 100644 --- a/src/modules/aix/server/dispatch/anthropic/anthropic.adapters.ts +++ b/src/modules/aix/server/dispatch/anthropic/anthropic.adapters.ts @@ -12,7 +12,7 @@ const hotFixMapModelImagesToUser = true; const DEFAULT_MAX_TOKENS = 4096; -export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, stream: boolean, conversionWarnings: string[]): AnthropicWire_MessageCreate { +export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, streaming: boolean): AnthropicWire_MessageCreate { // Convert the system message const systemMessage: AnthropicWire_MessageCreate['system'] = chatGenerate.systemMessage?.parts.length @@ -45,7 +45,7 @@ export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: tool_choice: chatGenerate.toolsPolicy && _intakeToAnthropicToolChoice(chatGenerate.toolsPolicy), // metadata: { user_id: ... } // stop_sequences: undefined, - stream: stream, + stream: streaming, temperature: model.temperature !== undefined ? model.temperature : undefined, // top_k: undefined, // top_p: undefined, diff --git a/src/modules/aix/server/dispatch/createDispatch.ts b/src/modules/aix/server/dispatch/createDispatch.ts index 283efc151..489835f24 100644 --- a/src/modules/aix/server/dispatch/createDispatch.ts +++ b/src/modules/aix/server/dispatch/createDispatch.ts @@ -7,8 +7,8 @@ import type { IntakeAccess, IntakeChatGenerateRequest, IntakeModel } from '../in import { intakeToAnthropicMessageCreate } from './anthropic/anthropic.adapters'; -import { createDispatchDemuxer } from './dispatch.demuxers'; -import { createDispatchParserAnthropicMessages, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers'; +import { createDispatchDemuxer, nullDispatchDemuxer } from './dispatch.demuxers'; +import { createDispatchParserAnthropicMessage, createDispatchParserAnthropicNS, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers'; import { geminiModelsGenerateContentPath, geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes'; @@ -71,16 +71,15 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen } - const conversionWarnings: string[] = []; switch (access.dialect) { case 'anthropic': return { request: { ...anthropicAccess(access, '/v1/messages'), - body: intakeToAnthropicMessageCreate(model, chatGenerate, streaming, conversionWarnings), + body: intakeToAnthropicMessageCreate(model, chatGenerate, streaming), }, - demuxer: createDispatchDemuxer('sse'), - parser: createDispatchParserAnthropicMessages(), + demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer, + parser: streaming ? createDispatchParserAnthropicMessage() : createDispatchParserAnthropicNS(), }; case 'gemini': @@ -89,7 +88,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen ...geminiAccess(access, model.id, streaming ? geminiModelsStreamGenerateContentPath : geminiModelsGenerateContentPath), body: geminiGenerateContentTextPayload(model, _hist, access.minSafetyLevel, 1), }, - demuxer: createDispatchDemuxer('sse'), + demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer, parser: createDispatchParserGemini(model.id.replace('models/', '')), }; @@ -99,7 +98,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen ...ollamaAccess(access, OLLAMA_PATH_CHAT), body: ollamaChatCompletionPayload(model, _hist, access.ollamaJson, streaming), }, - demuxer: createDispatchDemuxer('json-nl'), + demuxer: streaming ? createDispatchDemuxer('json-nl') : nullDispatchDemuxer, parser: createDispatchParserOllama(), }; @@ -119,7 +118,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen ...openAIAccess(access, model.id, '/v1/chat/completions'), body: openAIChatCompletionPayload(access.dialect, model, _hist, null, null, 1, streaming), }, - demuxer: createDispatchDemuxer('sse'), + demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer, parser: createDispatchParserOpenAI(), }; } diff --git a/src/modules/aix/server/dispatch/dispatch.demuxers.ts b/src/modules/aix/server/dispatch/dispatch.demuxers.ts index 747a0863f..21c1f185c 100644 --- a/src/modules/aix/server/dispatch/dispatch.demuxers.ts +++ b/src/modules/aix/server/dispatch/dispatch.demuxers.ts @@ -28,6 +28,14 @@ export function createDispatchDemuxer(format: DispatchDemuxFormat) { } } +export const nullDispatchDemuxer: DispatchDemuxer = { + demux: () => { + console.warn('Null demuxer called - shall not happen, as it is only created in non-streaming'); + return []; + }, + remaining: () => '', +}; + /** * Creates a parser for an EventSource stream (e.g. OpenAI's format). diff --git a/src/modules/aix/server/dispatch/dispatch.parsers.ts b/src/modules/aix/server/dispatch/dispatch.parsers.ts index e840209ea..903d57d20 100644 --- a/src/modules/aix/server/dispatch/dispatch.parsers.ts +++ b/src/modules/aix/server/dispatch/dispatch.parsers.ts @@ -2,7 +2,7 @@ import { z } from 'zod'; import { safeErrorString } from '~/server/wire'; -import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, AnthropicWire_MessageResponse, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema } from './anthropic/anthropic.wiretypes'; +import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, AnthropicWire_MessageResponse, anthropicWire_MessageResponse_Schema, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema } from './anthropic/anthropic.wiretypes'; import { geminiGeneratedContentResponseSchema, geminiHarmProbabilitySortFunction, GeminiSafetyRatings } from './gemini/gemini.wiretypes'; import { openaiWire_ChatCompletionChunkResponse_Schema } from './openai/oai.wiretypes'; import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes'; @@ -15,7 +15,7 @@ const ISSUE_SYMBOL_RECITATION = '🦜'; const TEXT_SYMBOL_MAX_TOKENS = '🧱'; -type DispatchParsedEvent = { +export type DispatchMessageAction = { op: 'text', text: string; } | { @@ -38,12 +38,12 @@ type DispatchParsedEvent = { }; }; -export type DispatchParser = (eventData: string, eventName?: string) => Generator; +export type DispatchParser = (eventData: string, eventName?: string) => Generator; /// Stream Parsers -export function createDispatchParserAnthropicMessages(): DispatchParser { +export function createDispatchParserAnthropicMessage(): DispatchParser { let responseMessage: AnthropicWire_MessageResponse; let hasErrored = false; let messageStartTime: number | undefined = undefined; @@ -52,7 +52,7 @@ export function createDispatchParserAnthropicMessages(): DispatchParser { // Note: at this stage, the parser only returns the text content as text, which is streamed as text // to the client. It is however building in parallel the responseMessage object, which is not // yet used, but contains token counts, for instance. - return function* (eventData: string, eventName?: string): Generator { + return function* (eventData: string, eventName?: string): Generator { // if we've errored, we should not be receiving more data if (hasErrored) @@ -167,6 +167,59 @@ export function createDispatchParserAnthropicMessages(): DispatchParser { } +export function createDispatchParserAnthropicNS(): DispatchParser { + let messageStartTime: number = Date.now(); + + return function* (fullData: string): Generator { + + // parse with validation (e.g. type: 'message' && role: 'assistant') + const { + model, + content, + stop_reason, + usage, + } = anthropicWire_MessageResponse_Schema.parse(JSON.parse(fullData)); + + // -> Model + if (model) + yield { op: 'set', value: { model } }; + + // -> Content Blocks + for (let i = 0; i < content.length; i++) { + const contentBlock = content[i]; + const isLastBlock = i === content.length - 1; + switch (contentBlock.type) { + case 'text': + const hitMaxTokens = (isLastBlock && stop_reason === 'max_tokens') ? ` ${TEXT_SYMBOL_MAX_TOKENS}` : ''; + yield { op: 'text', text: contentBlock.text + hitMaxTokens }; + break; + case 'tool_use': + yield { op: 'text', text: `TODO: [Tool Use] ${contentBlock.id} ${contentBlock.name} ${JSON.stringify(contentBlock.input)}` }; + break; + default: + throw new Error(`Unexpected content block type: ${(contentBlock as any).type}`); + } + } + + // -> Stats + if (usage) { + const elapsedTimeSeconds = (Date.now() - messageStartTime) / 1000; + const chatOutRate = elapsedTimeSeconds > 0 ? usage.output_tokens / elapsedTimeSeconds : 0; + yield { + op: 'set', value: { + stats: { + chatInTokens: usage.input_tokens, + chatOutTokens: usage.output_tokens, + chatOutRate: Math.round(chatOutRate * 100) / 100, // Round to 2 decimal places + timeInner: elapsedTimeSeconds, + }, + }, + }; + } + }; +} + + function explainGeminiSafetyIssues(safetyRatings?: GeminiSafetyRatings): string { if (!safetyRatings || !safetyRatings.length) return 'no safety ratings provided'; @@ -182,7 +235,7 @@ export function createDispatchParserGemini(modelName: string): DispatchParser { let hasBegun = false; // this can throw, it's caught by the caller - return function* (eventData): Generator { + return function* (eventData): Generator { // parse the JSON chunk const wireGenerationChunk = JSON.parse(eventData); @@ -253,7 +306,7 @@ export function createDispatchParserGemini(modelName: string): DispatchParser { export function createDispatchParserOllama(): DispatchParser { let hasBegun = false; - return function* (eventData: string): Generator { + return function* (eventData: string): Generator { // parse the JSON chunk let wireJsonChunk: any; @@ -302,7 +355,7 @@ export function createDispatchParserOpenAI(): DispatchParser { let hasWarned = false; // NOTE: could compute rate (tok/s) from the first textful event to the last (to ignore the prefill time) - return function* (eventData: string): Generator { + return function* (eventData: string): Generator { // Throws on malformed event data const json = openaiWire_ChatCompletionChunkResponse_Schema.parse(JSON.parse(eventData)); diff --git a/src/modules/aix/server/intake/IntakeHandler.ts b/src/modules/aix/server/intake/IntakeHandler.ts index 832d11bb8..510d7d917 100644 --- a/src/modules/aix/server/intake/IntakeHandler.ts +++ b/src/modules/aix/server/intake/IntakeHandler.ts @@ -1,6 +1,7 @@ import { SERVER_DEBUG_WIRE } from '~/server/wire'; import type { DemuxedEvent } from '../dispatch/dispatch.demuxers'; +import type { DispatchMessageAction } from '../dispatch/dispatch.parsers'; // type IntakeProtoObject = IntakeControlProtoObject | IntakeEventProtoObject; @@ -43,6 +44,32 @@ export class IntakeHandler { yield op; } + * yieldDmaOps(parsedEvents: Generator, prettyDialect: string) { + for (const dma of parsedEvents) { + // console.log('parsed dispatch:', dma); + // TODO: massively rework this into a good protocol + if (dma.op === 'parser-close') { + yield* this.yieldTermination('parser-done'); + break; + } else if (dma.op === 'text') { + yield* this.yieldOp({ + t: dma.text, + }); + } else if (dma.op === 'issue') { + yield* this.yieldOp({ + t: ` ${dma.symbol} **[${prettyDialect} Issue]:** ${dma.issue}`, + }); + } else if (dma.op === 'set') { + yield* this.yieldOp({ + set: dma.value, + }); + } else { + // shall never reach this + console.error('Unexpected stream event:', dma); + } + } + } + * yieldError(errorId: 'dispatch-prepare' | 'dispatch-fetch' | 'dispatch-read' | 'dispatch-parse', errorText: string, forceConsoleMessage?: boolean) { if (SERVER_DEBUG_WIRE || forceConsoleMessage || true) console.error(`[POST] /api/llms/stream: ${this.prettyDialect}: ${errorId}: ${errorText}`); diff --git a/src/modules/aix/server/intake/schemas.intake.tools.ts b/src/modules/aix/server/intake/schemas.intake.tools.ts index a45554dc2..2f8e322e8 100644 --- a/src/modules/aix/server/intake/schemas.intake.tools.ts +++ b/src/modules/aix/server/intake/schemas.intake.tools.ts @@ -54,7 +54,6 @@ const openAPISchemaObjectSchema = z.object({ // an object-only subset of the above, which is the JSON object owner of the parameters const intakeFunctionCallInputSchemaSchema = z.object({ - type: z.literal('object'), properties: z.record(openAPISchemaObjectSchema), required: z.array(z.string()).optional(), });