AIX: dynamic streaming support

This commit is contained in:
Enrico Ros
2024-07-11 02:31:44 -07:00
parent de139cada0
commit 1db71d9ba7
8 changed files with 138 additions and 64 deletions
+5 -5
View File
@@ -93,17 +93,17 @@ async function _aixStreamGenerateUnified(
onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
): Promise<void> {
const x = await apiStream.aix.chatGenerateContentStream.mutate(
{ access, model, chatGenerate, context },
const operation = await apiStream.aix.chatGenerateContent.mutate(
{ access, model, chatGenerate, context, streaming: true },
{ signal: abortSignal },
);
let incrementalText = '';
try {
for await (const update of x) {
console.log('cs update:', update);
for await (const update of operation) {
// console.log('cs update:', update);
// TODO: improve this recombination protocol...
if ('t' in update) {
incrementalText += update.t;
onUpdate({ textSoFar: incrementalText, typing: true }, false);
+27 -39
View File
@@ -15,21 +15,20 @@ export const aixRouter = createTRPCRouter({
* Chat content generation, streaming, multipart.
* Architecture: Client <-- (intake) --> Server <-- (dispatch) --> AI Service
*/
chatGenerateContentStream: publicProcedure
chatGenerateContent: publicProcedure
.input(z.object({
access: intakeAccessSchema,
model: intakeModelSchema,
chatGenerate: intakeChatGenerateRequestSchema,
context: intakeContextChatStreamSchema,
streaming: z.boolean(),
}))
.mutation(async function* ({ input, ctx }) {
// Using the variable to keep the implementation generic
const streaming = true;
// Intake derived state
const intakeAbortSignal = ctx.reqSignal;
const { access, model, chatGenerate } = input;
const { access, model, chatGenerate, streaming } = input;
const accessDialect = access.dialect;
const prettyDialect = serverCapitalizeFirstLetter(accessDialect);
@@ -43,7 +42,7 @@ export const aixRouter = createTRPCRouter({
try {
dispatch = createDispatch(access, model, chatGenerate, streaming);
} catch (error: any) {
yield* intakeHandler.yieldError('dispatch-prepare', `**[Service Creation Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service preparation error'}`);
yield* intakeHandler.yieldError('dispatch-prepare', `**[Configuration Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service preparation error'}`);
return; // exit
}
@@ -73,11 +72,22 @@ export const aixRouter = createTRPCRouter({
}
// Stream the response to the client
// [ALPHA] [NON-STREAMING] Read the full response and send operations down the intake
if (!streaming) {
try {
const dispatchBody = await dispatchResponse.text();
const messageAction = dispatch.parser(dispatchBody);
yield* intakeHandler.yieldDmaOps(messageAction, prettyDialect);
} catch (error: any) {
yield* intakeHandler.yieldError('dispatch-read', `**[Service Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service reading error'}`);
}
return; // exit
}
// STREAM the response to the client
const dispatchReader = (dispatchResponse.body || createEmptyReadableStream()).getReader();
const dispatchDecoder = new TextDecoder('utf-8', { fatal: false /* malformed data -> “�” (U+FFFD) */ });
const dispatchDemuxer = dispatch.demuxer.demux;
const dispatchParser = dispatch.parser;
// Data pump: AI Service -- (dispatch) --> Server -- (intake) --> Client
do {
@@ -109,51 +119,29 @@ export const aixRouter = createTRPCRouter({
// Demux the chunk into 0 or more events
for (const demuxedEvent of dispatchDemuxer(dispatchChunk)) {
intakeHandler.onReceivedDispatchEvent(demuxedEvent);
for (const demuxedItem of dispatch.demuxer.demux(dispatchChunk)) {
intakeHandler.onReceivedDispatchEvent(demuxedItem);
// ignore events post termination
if (intakeHandler.intakeTerminated) {
// warning on, because this is pretty important
console.warn('/api/llms/stream: Received event after termination:', demuxedEvent);
// warning on, because this is important and a sign of a bug
console.warn('/api/llms/stream: Received event after termination:', demuxedItem);
break; // inner for {}
}
// ignore superfluous stream events
if (demuxedEvent.type !== 'event')
if (demuxedItem.type !== 'event')
continue; // inner for {}
// [OpenAI] Special: event stream termination, close our transformed stream
if (demuxedEvent.data === '[DONE]') {
// [OpenAI] Special: stream termination marker
if (demuxedItem.data === '[DONE]') {
yield* intakeHandler.yieldTermination('event-done');
break; // inner for {}, then outer do
}
try {
const parsedEvents = dispatchParser(demuxedEvent.data, demuxedEvent.name);
for (const upe of parsedEvents) {
console.log('parsed dispatch:', upe);
// TODO: massively rework this into a good protocol
if (upe.op === 'parser-close') {
yield* intakeHandler.yieldTermination('parser-done');
break;
} else if (upe.op === 'text') {
yield* intakeHandler.yieldOp({
t: upe.text,
});
} else if (upe.op === 'issue') {
yield* intakeHandler.yieldOp({
t: ` ${upe.symbol} **[${prettyDialect} Issue]:** ${upe.issue}`,
});
} else if (upe.op === 'set') {
yield* intakeHandler.yieldOp({
set: upe.value,
});
} else {
// shall never reach this
console.error('Unexpected stream event:', upe);
}
}
const messageAction = dispatch.parser(demuxedItem.data, demuxedItem.name);
yield* intakeHandler.yieldDmaOps(messageAction, prettyDialect);
} catch (error: any) {
yield* intakeHandler.yieldError('dispatch-parse', ` **[Service Parsing Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown stream parsing error'}. Please open a support ticket.`);
break; // inner for {}, then outer do
@@ -12,7 +12,7 @@ const hotFixMapModelImagesToUser = true;
const DEFAULT_MAX_TOKENS = 4096;
export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, stream: boolean, conversionWarnings: string[]): AnthropicWire_MessageCreate {
export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, streaming: boolean): AnthropicWire_MessageCreate {
// Convert the system message
const systemMessage: AnthropicWire_MessageCreate['system'] = chatGenerate.systemMessage?.parts.length
@@ -45,7 +45,7 @@ export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate:
tool_choice: chatGenerate.toolsPolicy && _intakeToAnthropicToolChoice(chatGenerate.toolsPolicy),
// metadata: { user_id: ... }
// stop_sequences: undefined,
stream: stream,
stream: streaming,
temperature: model.temperature !== undefined ? model.temperature : undefined,
// top_k: undefined,
// top_p: undefined,
@@ -7,8 +7,8 @@ import type { IntakeAccess, IntakeChatGenerateRequest, IntakeModel } from '../in
import { intakeToAnthropicMessageCreate } from './anthropic/anthropic.adapters';
import { createDispatchDemuxer } from './dispatch.demuxers';
import { createDispatchParserAnthropicMessages, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers';
import { createDispatchDemuxer, nullDispatchDemuxer } from './dispatch.demuxers';
import { createDispatchParserAnthropicMessage, createDispatchParserAnthropicNS, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers';
import { geminiModelsGenerateContentPath, geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes';
@@ -71,16 +71,15 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
}
const conversionWarnings: string[] = [];
switch (access.dialect) {
case 'anthropic':
return {
request: {
...anthropicAccess(access, '/v1/messages'),
body: intakeToAnthropicMessageCreate(model, chatGenerate, streaming, conversionWarnings),
body: intakeToAnthropicMessageCreate(model, chatGenerate, streaming),
},
demuxer: createDispatchDemuxer('sse'),
parser: createDispatchParserAnthropicMessages(),
demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer,
parser: streaming ? createDispatchParserAnthropicMessage() : createDispatchParserAnthropicNS(),
};
case 'gemini':
@@ -89,7 +88,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
...geminiAccess(access, model.id, streaming ? geminiModelsStreamGenerateContentPath : geminiModelsGenerateContentPath),
body: geminiGenerateContentTextPayload(model, _hist, access.minSafetyLevel, 1),
},
demuxer: createDispatchDemuxer('sse'),
demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer,
parser: createDispatchParserGemini(model.id.replace('models/', '')),
};
@@ -99,7 +98,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
...ollamaAccess(access, OLLAMA_PATH_CHAT),
body: ollamaChatCompletionPayload(model, _hist, access.ollamaJson, streaming),
},
demuxer: createDispatchDemuxer('json-nl'),
demuxer: streaming ? createDispatchDemuxer('json-nl') : nullDispatchDemuxer,
parser: createDispatchParserOllama(),
};
@@ -119,7 +118,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
...openAIAccess(access, model.id, '/v1/chat/completions'),
body: openAIChatCompletionPayload(access.dialect, model, _hist, null, null, 1, streaming),
},
demuxer: createDispatchDemuxer('sse'),
demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer,
parser: createDispatchParserOpenAI(),
};
}
@@ -28,6 +28,14 @@ export function createDispatchDemuxer(format: DispatchDemuxFormat) {
}
}
/**
 * No-op demuxer for the non-streaming path, where the whole response body is
 * read and parsed in one shot and no event demultiplexing ever takes place.
 */
export const nullDispatchDemuxer: DispatchDemuxer = {
  demux: function () {
    // Reaching this is a wiring bug: this demuxer is only installed on non-streaming dispatches.
    console.warn('Null demuxer called - shall not happen, as it is only created in non-streaming');
    return [];
  },
  remaining: function () {
    return '';
  },
};
/**
* Creates a parser for an EventSource stream (e.g. OpenAI's format).
@@ -2,7 +2,7 @@ import { z } from 'zod';
import { safeErrorString } from '~/server/wire';
import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, AnthropicWire_MessageResponse, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema } from './anthropic/anthropic.wiretypes';
import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, AnthropicWire_MessageResponse, anthropicWire_MessageResponse_Schema, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema } from './anthropic/anthropic.wiretypes';
import { geminiGeneratedContentResponseSchema, geminiHarmProbabilitySortFunction, GeminiSafetyRatings } from './gemini/gemini.wiretypes';
import { openaiWire_ChatCompletionChunkResponse_Schema } from './openai/oai.wiretypes';
import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes';
@@ -15,7 +15,7 @@ const ISSUE_SYMBOL_RECITATION = '🦜';
const TEXT_SYMBOL_MAX_TOKENS = '🧱';
type DispatchParsedEvent = {
export type DispatchMessageAction = {
op: 'text',
text: string;
} | {
@@ -38,12 +38,12 @@ type DispatchParsedEvent = {
};
};
export type DispatchParser = (eventData: string, eventName?: string) => Generator<DispatchParsedEvent>;
export type DispatchParser = (eventData: string, eventName?: string) => Generator<DispatchMessageAction>;
/// Stream Parsers
export function createDispatchParserAnthropicMessages(): DispatchParser {
export function createDispatchParserAnthropicMessage(): DispatchParser {
let responseMessage: AnthropicWire_MessageResponse;
let hasErrored = false;
let messageStartTime: number | undefined = undefined;
@@ -52,7 +52,7 @@ export function createDispatchParserAnthropicMessages(): DispatchParser {
// Note: at this stage, the parser only returns the text content as text, which is streamed as text
// to the client. It is however building in parallel the responseMessage object, which is not
// yet used, but contains token counts, for instance.
return function* (eventData: string, eventName?: string): Generator<DispatchParsedEvent> {
return function* (eventData: string, eventName?: string): Generator<DispatchMessageAction> {
// if we've errored, we should not be receiving more data
if (hasErrored)
@@ -167,6 +167,59 @@ export function createDispatchParserAnthropicMessages(): DispatchParser {
}
/**
 * Non-Streaming (NS) parser for Anthropic Messages: receives the complete
 * response body in a single call and emits the same DispatchMessageAction
 * sequence the streaming parser would have produced incrementally.
 */
export function createDispatchParserAnthropicNS(): DispatchParser {
  // NOTE(review): measured from parser creation, so 'timeInner' presumably includes network time — confirm intent
  const parserCreationTime: number = Date.now();

  return function* (fullData: string): Generator<DispatchMessageAction> {

    // Validate the full payload (e.g. type: 'message' && role: 'assistant')
    const response = anthropicWire_MessageResponse_Schema.parse(JSON.parse(fullData));

    // -> Model
    if (response.model)
      yield { op: 'set', value: { model: response.model } };

    // -> Content Blocks: text is forwarded as-is; tool_use is a placeholder for now
    const lastIndex = response.content.length - 1;
    for (const [index, block] of response.content.entries()) {
      switch (block.type) {
        case 'text': {
          // append the truncation symbol when the last block was cut by the token limit
          const truncationSuffix = (index === lastIndex && response.stop_reason === 'max_tokens') ? ` ${TEXT_SYMBOL_MAX_TOKENS}` : '';
          yield { op: 'text', text: block.text + truncationSuffix };
          break;
        }
        case 'tool_use':
          yield { op: 'text', text: `TODO: [Tool Use] ${block.id} ${block.name} ${JSON.stringify(block.input)}` };
          break;
        default:
          throw new Error(`Unexpected content block type: ${(block as any).type}`);
      }
    }

    // -> Stats: token counts and an output rate derived from the elapsed time
    if (response.usage) {
      const elapsedSeconds = (Date.now() - parserCreationTime) / 1000;
      const outputRate = elapsedSeconds > 0 ? response.usage.output_tokens / elapsedSeconds : 0;
      yield {
        op: 'set', value: {
          stats: {
            chatInTokens: response.usage.input_tokens,
            chatOutTokens: response.usage.output_tokens,
            chatOutRate: Math.round(outputRate * 100) / 100, // round to 2 decimal places
            timeInner: elapsedSeconds,
          },
        },
      };
    }
  };
}
function explainGeminiSafetyIssues(safetyRatings?: GeminiSafetyRatings): string {
if (!safetyRatings || !safetyRatings.length)
return 'no safety ratings provided';
@@ -182,7 +235,7 @@ export function createDispatchParserGemini(modelName: string): DispatchParser {
let hasBegun = false;
// this can throw, it's caught by the caller
return function* (eventData): Generator<DispatchParsedEvent> {
return function* (eventData): Generator<DispatchMessageAction> {
// parse the JSON chunk
const wireGenerationChunk = JSON.parse(eventData);
@@ -253,7 +306,7 @@ export function createDispatchParserGemini(modelName: string): DispatchParser {
export function createDispatchParserOllama(): DispatchParser {
let hasBegun = false;
return function* (eventData: string): Generator<DispatchParsedEvent> {
return function* (eventData: string): Generator<DispatchMessageAction> {
// parse the JSON chunk
let wireJsonChunk: any;
@@ -302,7 +355,7 @@ export function createDispatchParserOpenAI(): DispatchParser {
let hasWarned = false;
// NOTE: could compute rate (tok/s) from the first textful event to the last (to ignore the prefill time)
return function* (eventData: string): Generator<DispatchParsedEvent> {
return function* (eventData: string): Generator<DispatchMessageAction> {
// Throws on malformed event data
const json = openaiWire_ChatCompletionChunkResponse_Schema.parse(JSON.parse(eventData));
@@ -1,6 +1,7 @@
import { SERVER_DEBUG_WIRE } from '~/server/wire';
import type { DemuxedEvent } from '../dispatch/dispatch.demuxers';
import type { DispatchMessageAction } from '../dispatch/dispatch.parsers';
// type IntakeProtoObject = IntakeControlProtoObject | IntakeEventProtoObject;
@@ -43,6 +44,32 @@ export class IntakeHandler {
yield op;
}
/**
 * Forwards a parser's DispatchMessageActions to the intake as wire operations,
 * and stops consuming the generator once the parser signals closure.
 */
* yieldDmaOps(parsedEvents: Generator<DispatchMessageAction>, prettyDialect: string) {
  // TODO: massively rework this into a good protocol
  for (const dma of parsedEvents) {
    switch (dma.op) {
      case 'parser-close':
        yield* this.yieldTermination('parser-done');
        return; // stop consuming further actions
      case 'text':
        yield* this.yieldOp({
          t: dma.text,
        });
        break;
      case 'issue':
        yield* this.yieldOp({
          t: ` ${dma.symbol} **[${prettyDialect} Issue]:** ${dma.issue}`,
        });
        break;
      case 'set':
        yield* this.yieldOp({
          set: dma.value,
        });
        break;
      default:
        // shall never reach this
        console.error('Unexpected stream event:', dma);
    }
  }
}
* yieldError(errorId: 'dispatch-prepare' | 'dispatch-fetch' | 'dispatch-read' | 'dispatch-parse', errorText: string, forceConsoleMessage?: boolean) {
if (SERVER_DEBUG_WIRE || forceConsoleMessage || true)
console.error(`[POST] /api/llms/stream: ${this.prettyDialect}: ${errorId}: ${errorText}`);
@@ -54,7 +54,6 @@ const openAPISchemaObjectSchema = z.object({
// An object-only subset of the OpenAPI schema above: the JSON-Schema container
// that owns a function call's parameters (top-level 'object' with named props).
const intakeFunctionCallInputSchemaSchema = z.object({
  type: z.literal('object'), // only object-typed parameter containers are accepted
  properties: z.record(openAPISchemaObjectSchema), // parameter name -> schema object
  required: z.array(z.string()).optional(), // names of mandatory parameters, if any
});