mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-10 21:50:14 -07:00
AIX: dynamic streaming support
This commit is contained in:
@@ -93,17 +93,17 @@ async function _aixStreamGenerateUnified(
|
||||
onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
|
||||
): Promise<void> {
|
||||
|
||||
const x = await apiStream.aix.chatGenerateContentStream.mutate(
|
||||
{ access, model, chatGenerate, context },
|
||||
const operation = await apiStream.aix.chatGenerateContent.mutate(
|
||||
{ access, model, chatGenerate, context, streaming: true },
|
||||
{ signal: abortSignal },
|
||||
);
|
||||
|
||||
let incrementalText = '';
|
||||
|
||||
try {
|
||||
for await (const update of x) {
|
||||
console.log('cs update:', update);
|
||||
|
||||
for await (const update of operation) {
|
||||
// console.log('cs update:', update);
|
||||
// TODO: improve this recombination protocol...
|
||||
if ('t' in update) {
|
||||
incrementalText += update.t;
|
||||
onUpdate({ textSoFar: incrementalText, typing: true }, false);
|
||||
|
||||
@@ -15,21 +15,20 @@ export const aixRouter = createTRPCRouter({
|
||||
* Chat content generation, streaming, multipart.
|
||||
* Architecture: Client <-- (intake) --> Server <-- (dispatch) --> AI Service
|
||||
*/
|
||||
chatGenerateContentStream: publicProcedure
|
||||
chatGenerateContent: publicProcedure
|
||||
.input(z.object({
|
||||
access: intakeAccessSchema,
|
||||
model: intakeModelSchema,
|
||||
chatGenerate: intakeChatGenerateRequestSchema,
|
||||
context: intakeContextChatStreamSchema,
|
||||
streaming: z.boolean(),
|
||||
}))
|
||||
.mutation(async function* ({ input, ctx }) {
|
||||
|
||||
// Using the variable to keep the implementation generic
|
||||
const streaming = true;
|
||||
|
||||
// Intake derived state
|
||||
const intakeAbortSignal = ctx.reqSignal;
|
||||
const { access, model, chatGenerate } = input;
|
||||
const { access, model, chatGenerate, streaming } = input;
|
||||
const accessDialect = access.dialect;
|
||||
const prettyDialect = serverCapitalizeFirstLetter(accessDialect);
|
||||
|
||||
@@ -43,7 +42,7 @@ export const aixRouter = createTRPCRouter({
|
||||
try {
|
||||
dispatch = createDispatch(access, model, chatGenerate, streaming);
|
||||
} catch (error: any) {
|
||||
yield* intakeHandler.yieldError('dispatch-prepare', `**[Service Creation Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service preparation error'}`);
|
||||
yield* intakeHandler.yieldError('dispatch-prepare', `**[Configuration Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service preparation error'}`);
|
||||
return; // exit
|
||||
}
|
||||
|
||||
@@ -73,11 +72,22 @@ export const aixRouter = createTRPCRouter({
|
||||
}
|
||||
|
||||
|
||||
// Stream the response to the client
|
||||
// [ALPHA] [NON-STREAMING] Read the full response and send operations down the intake
|
||||
if (!streaming) {
|
||||
try {
|
||||
const dispatchBody = await dispatchResponse.text();
|
||||
const messageAction = dispatch.parser(dispatchBody);
|
||||
yield* intakeHandler.yieldDmaOps(messageAction, prettyDialect);
|
||||
} catch (error: any) {
|
||||
yield* intakeHandler.yieldError('dispatch-read', `**[Service Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown service reading error'}`);
|
||||
}
|
||||
return; // exit
|
||||
}
|
||||
|
||||
|
||||
// STREAM the response to the client
|
||||
const dispatchReader = (dispatchResponse.body || createEmptyReadableStream()).getReader();
|
||||
const dispatchDecoder = new TextDecoder('utf-8', { fatal: false /* malformed data -> “ ” (U+FFFD) */ });
|
||||
const dispatchDemuxer = dispatch.demuxer.demux;
|
||||
const dispatchParser = dispatch.parser;
|
||||
|
||||
// Data pump: AI Service -- (dispatch) --> Server -- (intake) --> Client
|
||||
do {
|
||||
@@ -109,51 +119,29 @@ export const aixRouter = createTRPCRouter({
|
||||
|
||||
|
||||
// Demux the chunk into 0 or more events
|
||||
for (const demuxedEvent of dispatchDemuxer(dispatchChunk)) {
|
||||
intakeHandler.onReceivedDispatchEvent(demuxedEvent);
|
||||
for (const demuxedItem of dispatch.demuxer.demux(dispatchChunk)) {
|
||||
intakeHandler.onReceivedDispatchEvent(demuxedItem);
|
||||
|
||||
// ignore events post termination
|
||||
if (intakeHandler.intakeTerminated) {
|
||||
// warning on, because this is pretty important
|
||||
console.warn('/api/llms/stream: Received event after termination:', demuxedEvent);
|
||||
// warning on, because this is important and a sign of a bug
|
||||
console.warn('/api/llms/stream: Received event after termination:', demuxedItem);
|
||||
break; // inner for {}
|
||||
}
|
||||
|
||||
// ignore superfluos stream events
|
||||
if (demuxedEvent.type !== 'event')
|
||||
if (demuxedItem.type !== 'event')
|
||||
continue; // inner for {}
|
||||
|
||||
// [OpenAI] Special: event stream termination, close our transformed stream
|
||||
if (demuxedEvent.data === '[DONE]') {
|
||||
// [OpenAI] Special: stream termination marker
|
||||
if (demuxedItem.data === '[DONE]') {
|
||||
yield* intakeHandler.yieldTermination('event-done');
|
||||
break; // inner for {}, then outer do
|
||||
}
|
||||
|
||||
try {
|
||||
const parsedEvents = dispatchParser(demuxedEvent.data, demuxedEvent.name);
|
||||
for (const upe of parsedEvents) {
|
||||
console.log('parsed dispatch:', upe);
|
||||
// TODO: massively rework this into a good protocol
|
||||
if (upe.op === 'parser-close') {
|
||||
yield* intakeHandler.yieldTermination('parser-done');
|
||||
break;
|
||||
} else if (upe.op === 'text') {
|
||||
yield* intakeHandler.yieldOp({
|
||||
t: upe.text,
|
||||
});
|
||||
} else if (upe.op === 'issue') {
|
||||
yield* intakeHandler.yieldOp({
|
||||
t: ` ${upe.symbol} **[${prettyDialect} Issue]:** ${upe.issue}`,
|
||||
});
|
||||
} else if (upe.op === 'set') {
|
||||
yield* intakeHandler.yieldOp({
|
||||
set: upe.value,
|
||||
});
|
||||
} else {
|
||||
// shall never reach this
|
||||
console.error('Unexpected stream event:', upe);
|
||||
}
|
||||
}
|
||||
const messageAction = dispatch.parser(demuxedItem.data, demuxedItem.name);
|
||||
yield* intakeHandler.yieldDmaOps(messageAction, prettyDialect);
|
||||
} catch (error: any) {
|
||||
yield* intakeHandler.yieldError('dispatch-parse', ` **[Service Parsing Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown stream parsing error'}. Please open a support ticket.`);
|
||||
break; // inner for {}, then outer do
|
||||
|
||||
@@ -12,7 +12,7 @@ const hotFixMapModelImagesToUser = true;
|
||||
const DEFAULT_MAX_TOKENS = 4096;
|
||||
|
||||
|
||||
export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, stream: boolean, conversionWarnings: string[]): AnthropicWire_MessageCreate {
|
||||
export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, streaming: boolean): AnthropicWire_MessageCreate {
|
||||
|
||||
// Convert the system message
|
||||
const systemMessage: AnthropicWire_MessageCreate['system'] = chatGenerate.systemMessage?.parts.length
|
||||
@@ -45,7 +45,7 @@ export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate:
|
||||
tool_choice: chatGenerate.toolsPolicy && _intakeToAnthropicToolChoice(chatGenerate.toolsPolicy),
|
||||
// metadata: { user_id: ... }
|
||||
// stop_sequences: undefined,
|
||||
stream: stream,
|
||||
stream: streaming,
|
||||
temperature: model.temperature !== undefined ? model.temperature : undefined,
|
||||
// top_k: undefined,
|
||||
// top_p: undefined,
|
||||
|
||||
@@ -7,8 +7,8 @@ import type { IntakeAccess, IntakeChatGenerateRequest, IntakeModel } from '../in
|
||||
|
||||
import { intakeToAnthropicMessageCreate } from './anthropic/anthropic.adapters';
|
||||
|
||||
import { createDispatchDemuxer } from './dispatch.demuxers';
|
||||
import { createDispatchParserAnthropicMessages, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers';
|
||||
import { createDispatchDemuxer, nullDispatchDemuxer } from './dispatch.demuxers';
|
||||
import { createDispatchParserAnthropicMessage, createDispatchParserAnthropicNS, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers';
|
||||
import { geminiModelsGenerateContentPath, geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes';
|
||||
|
||||
|
||||
@@ -71,16 +71,15 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
|
||||
}
|
||||
|
||||
|
||||
const conversionWarnings: string[] = [];
|
||||
switch (access.dialect) {
|
||||
case 'anthropic':
|
||||
return {
|
||||
request: {
|
||||
...anthropicAccess(access, '/v1/messages'),
|
||||
body: intakeToAnthropicMessageCreate(model, chatGenerate, streaming, conversionWarnings),
|
||||
body: intakeToAnthropicMessageCreate(model, chatGenerate, streaming),
|
||||
},
|
||||
demuxer: createDispatchDemuxer('sse'),
|
||||
parser: createDispatchParserAnthropicMessages(),
|
||||
demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer,
|
||||
parser: streaming ? createDispatchParserAnthropicMessage() : createDispatchParserAnthropicNS(),
|
||||
};
|
||||
|
||||
case 'gemini':
|
||||
@@ -89,7 +88,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
|
||||
...geminiAccess(access, model.id, streaming ? geminiModelsStreamGenerateContentPath : geminiModelsGenerateContentPath),
|
||||
body: geminiGenerateContentTextPayload(model, _hist, access.minSafetyLevel, 1),
|
||||
},
|
||||
demuxer: createDispatchDemuxer('sse'),
|
||||
demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer,
|
||||
parser: createDispatchParserGemini(model.id.replace('models/', '')),
|
||||
};
|
||||
|
||||
@@ -99,7 +98,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
|
||||
...ollamaAccess(access, OLLAMA_PATH_CHAT),
|
||||
body: ollamaChatCompletionPayload(model, _hist, access.ollamaJson, streaming),
|
||||
},
|
||||
demuxer: createDispatchDemuxer('json-nl'),
|
||||
demuxer: streaming ? createDispatchDemuxer('json-nl') : nullDispatchDemuxer,
|
||||
parser: createDispatchParserOllama(),
|
||||
};
|
||||
|
||||
@@ -119,7 +118,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
|
||||
...openAIAccess(access, model.id, '/v1/chat/completions'),
|
||||
body: openAIChatCompletionPayload(access.dialect, model, _hist, null, null, 1, streaming),
|
||||
},
|
||||
demuxer: createDispatchDemuxer('sse'),
|
||||
demuxer: streaming ? createDispatchDemuxer('sse') : nullDispatchDemuxer,
|
||||
parser: createDispatchParserOpenAI(),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -28,6 +28,14 @@ export function createDispatchDemuxer(format: DispatchDemuxFormat) {
|
||||
}
|
||||
}
|
||||
|
||||
export const nullDispatchDemuxer: DispatchDemuxer = {
|
||||
demux: () => {
|
||||
console.warn('Null demuxer called - shall not happen, as it is only created in non-streaming');
|
||||
return [];
|
||||
},
|
||||
remaining: () => '',
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* Creates a parser for an EventSource stream (e.g. OpenAI's format).
|
||||
|
||||
@@ -2,7 +2,7 @@ import { z } from 'zod';
|
||||
|
||||
import { safeErrorString } from '~/server/wire';
|
||||
|
||||
import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, AnthropicWire_MessageResponse, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema } from './anthropic/anthropic.wiretypes';
|
||||
import { anthropicWire_ContentBlockDeltaEvent_Schema, anthropicWire_ContentBlockStartEvent_Schema, anthropicWire_ContentBlockStopEvent_Schema, anthropicWire_MessageDeltaEvent_Schema, AnthropicWire_MessageResponse, anthropicWire_MessageResponse_Schema, anthropicWire_MessageStartEvent_Schema, anthropicWire_MessageStopEvent_Schema } from './anthropic/anthropic.wiretypes';
|
||||
import { geminiGeneratedContentResponseSchema, geminiHarmProbabilitySortFunction, GeminiSafetyRatings } from './gemini/gemini.wiretypes';
|
||||
import { openaiWire_ChatCompletionChunkResponse_Schema } from './openai/oai.wiretypes';
|
||||
import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes';
|
||||
@@ -15,7 +15,7 @@ const ISSUE_SYMBOL_RECITATION = '🦜';
|
||||
const TEXT_SYMBOL_MAX_TOKENS = '🧱';
|
||||
|
||||
|
||||
type DispatchParsedEvent = {
|
||||
export type DispatchMessageAction = {
|
||||
op: 'text',
|
||||
text: string;
|
||||
} | {
|
||||
@@ -38,12 +38,12 @@ type DispatchParsedEvent = {
|
||||
};
|
||||
};
|
||||
|
||||
export type DispatchParser = (eventData: string, eventName?: string) => Generator<DispatchParsedEvent>;
|
||||
export type DispatchParser = (eventData: string, eventName?: string) => Generator<DispatchMessageAction>;
|
||||
|
||||
|
||||
/// Stream Parsers
|
||||
|
||||
export function createDispatchParserAnthropicMessages(): DispatchParser {
|
||||
export function createDispatchParserAnthropicMessage(): DispatchParser {
|
||||
let responseMessage: AnthropicWire_MessageResponse;
|
||||
let hasErrored = false;
|
||||
let messageStartTime: number | undefined = undefined;
|
||||
@@ -52,7 +52,7 @@ export function createDispatchParserAnthropicMessages(): DispatchParser {
|
||||
// Note: at this stage, the parser only returns the text content as text, which is streamed as text
|
||||
// to the client. It is however building in parallel the responseMessage object, which is not
|
||||
// yet used, but contains token counts, for instance.
|
||||
return function* (eventData: string, eventName?: string): Generator<DispatchParsedEvent> {
|
||||
return function* (eventData: string, eventName?: string): Generator<DispatchMessageAction> {
|
||||
|
||||
// if we've errored, we should not be receiving more data
|
||||
if (hasErrored)
|
||||
@@ -167,6 +167,59 @@ export function createDispatchParserAnthropicMessages(): DispatchParser {
|
||||
}
|
||||
|
||||
|
||||
export function createDispatchParserAnthropicNS(): DispatchParser {
|
||||
let messageStartTime: number = Date.now();
|
||||
|
||||
return function* (fullData: string): Generator<DispatchMessageAction> {
|
||||
|
||||
// parse with validation (e.g. type: 'message' && role: 'assistant')
|
||||
const {
|
||||
model,
|
||||
content,
|
||||
stop_reason,
|
||||
usage,
|
||||
} = anthropicWire_MessageResponse_Schema.parse(JSON.parse(fullData));
|
||||
|
||||
// -> Model
|
||||
if (model)
|
||||
yield { op: 'set', value: { model } };
|
||||
|
||||
// -> Content Blocks
|
||||
for (let i = 0; i < content.length; i++) {
|
||||
const contentBlock = content[i];
|
||||
const isLastBlock = i === content.length - 1;
|
||||
switch (contentBlock.type) {
|
||||
case 'text':
|
||||
const hitMaxTokens = (isLastBlock && stop_reason === 'max_tokens') ? ` ${TEXT_SYMBOL_MAX_TOKENS}` : '';
|
||||
yield { op: 'text', text: contentBlock.text + hitMaxTokens };
|
||||
break;
|
||||
case 'tool_use':
|
||||
yield { op: 'text', text: `TODO: [Tool Use] ${contentBlock.id} ${contentBlock.name} ${JSON.stringify(contentBlock.input)}` };
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unexpected content block type: ${(contentBlock as any).type}`);
|
||||
}
|
||||
}
|
||||
|
||||
// -> Stats
|
||||
if (usage) {
|
||||
const elapsedTimeSeconds = (Date.now() - messageStartTime) / 1000;
|
||||
const chatOutRate = elapsedTimeSeconds > 0 ? usage.output_tokens / elapsedTimeSeconds : 0;
|
||||
yield {
|
||||
op: 'set', value: {
|
||||
stats: {
|
||||
chatInTokens: usage.input_tokens,
|
||||
chatOutTokens: usage.output_tokens,
|
||||
chatOutRate: Math.round(chatOutRate * 100) / 100, // Round to 2 decimal places
|
||||
timeInner: elapsedTimeSeconds,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
function explainGeminiSafetyIssues(safetyRatings?: GeminiSafetyRatings): string {
|
||||
if (!safetyRatings || !safetyRatings.length)
|
||||
return 'no safety ratings provided';
|
||||
@@ -182,7 +235,7 @@ export function createDispatchParserGemini(modelName: string): DispatchParser {
|
||||
let hasBegun = false;
|
||||
|
||||
// this can throw, it's caught by the caller
|
||||
return function* (eventData): Generator<DispatchParsedEvent> {
|
||||
return function* (eventData): Generator<DispatchMessageAction> {
|
||||
|
||||
// parse the JSON chunk
|
||||
const wireGenerationChunk = JSON.parse(eventData);
|
||||
@@ -253,7 +306,7 @@ export function createDispatchParserGemini(modelName: string): DispatchParser {
|
||||
export function createDispatchParserOllama(): DispatchParser {
|
||||
let hasBegun = false;
|
||||
|
||||
return function* (eventData: string): Generator<DispatchParsedEvent> {
|
||||
return function* (eventData: string): Generator<DispatchMessageAction> {
|
||||
|
||||
// parse the JSON chunk
|
||||
let wireJsonChunk: any;
|
||||
@@ -302,7 +355,7 @@ export function createDispatchParserOpenAI(): DispatchParser {
|
||||
let hasWarned = false;
|
||||
// NOTE: could compute rate (tok/s) from the first textful event to the last (to ignore the prefill time)
|
||||
|
||||
return function* (eventData: string): Generator<DispatchParsedEvent> {
|
||||
return function* (eventData: string): Generator<DispatchMessageAction> {
|
||||
|
||||
// Throws on malformed event data
|
||||
const json = openaiWire_ChatCompletionChunkResponse_Schema.parse(JSON.parse(eventData));
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { SERVER_DEBUG_WIRE } from '~/server/wire';
|
||||
|
||||
import type { DemuxedEvent } from '../dispatch/dispatch.demuxers';
|
||||
import type { DispatchMessageAction } from '../dispatch/dispatch.parsers';
|
||||
|
||||
|
||||
// type IntakeProtoObject = IntakeControlProtoObject | IntakeEventProtoObject;
|
||||
@@ -43,6 +44,32 @@ export class IntakeHandler {
|
||||
yield op;
|
||||
}
|
||||
|
||||
* yieldDmaOps(parsedEvents: Generator<DispatchMessageAction>, prettyDialect: string) {
|
||||
for (const dma of parsedEvents) {
|
||||
// console.log('parsed dispatch:', dma);
|
||||
// TODO: massively rework this into a good protocol
|
||||
if (dma.op === 'parser-close') {
|
||||
yield* this.yieldTermination('parser-done');
|
||||
break;
|
||||
} else if (dma.op === 'text') {
|
||||
yield* this.yieldOp({
|
||||
t: dma.text,
|
||||
});
|
||||
} else if (dma.op === 'issue') {
|
||||
yield* this.yieldOp({
|
||||
t: ` ${dma.symbol} **[${prettyDialect} Issue]:** ${dma.issue}`,
|
||||
});
|
||||
} else if (dma.op === 'set') {
|
||||
yield* this.yieldOp({
|
||||
set: dma.value,
|
||||
});
|
||||
} else {
|
||||
// shall never reach this
|
||||
console.error('Unexpected stream event:', dma);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
* yieldError(errorId: 'dispatch-prepare' | 'dispatch-fetch' | 'dispatch-read' | 'dispatch-parse', errorText: string, forceConsoleMessage?: boolean) {
|
||||
if (SERVER_DEBUG_WIRE || forceConsoleMessage || true)
|
||||
console.error(`[POST] /api/llms/stream: ${this.prettyDialect}: ${errorId}: ${errorText}`);
|
||||
|
||||
@@ -54,7 +54,6 @@ const openAPISchemaObjectSchema = z.object({
|
||||
|
||||
// an object-only subset of the above, which is the JSON object owner of the parameters
|
||||
const intakeFunctionCallInputSchemaSchema = z.object({
|
||||
type: z.literal('object'),
|
||||
properties: z.record(openAPISchemaObjectSchema),
|
||||
required: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user