From a097b32d5c1d2cf135e44edc73fa61933f1b8397 Mon Sep 17 00:00:00 2001
From: Enrico Ros
Date: Wed, 3 Jul 2024 01:31:13 -0700
Subject: [PATCH] AIX: types migration

---
 src/apps/chat/editors/chat-persona.ts      |  83 +++++----
 src/modules/aix/client/aix.client.ts       | 198 +++++++++++++++++++++
 src/modules/aix/shared/aix.shared.chat.ts  |  17 +-
 src/modules/aix/shared/aix.shared.tools.ts |   4 +
 src/modules/aix/shared/aix.shared.types.ts |  54 ++++++
 src/modules/llms/store-llms.ts             |   2 +-
 6 files changed, 311 insertions(+), 47 deletions(-)
 create mode 100644 src/modules/aix/client/aix.client.ts
 create mode 100644 src/modules/aix/shared/aix.shared.types.ts

diff --git a/src/apps/chat/editors/chat-persona.ts b/src/apps/chat/editors/chat-persona.ts
index aedc65dd8..a86b2b7c6 100644
--- a/src/apps/chat/editors/chat-persona.ts
+++ b/src/apps/chat/editors/chat-persona.ts
@@ -1,8 +1,11 @@
 import type { DLLMId } from '~/modules/llms/store-llms';
+
 import { VChatContextRef, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
+import { aixStreamingChatGenerate, StreamingClientUpdate } from '~/modules/aix/client/aix.client';
+
 import { ConversationsManager } from '~/common/chats/ConversationsManager';
-import { DMessage, messageSingleTextOrThrow } from '~/common/stores/chat/chat.message';
+import { DMessage, messageFragmentsReplaceLastContentText, messageSingleTextOrThrow } from '~/common/stores/chat/chat.message';
 import { getUXLabsHighPerformance } from '~/common/state/store-ux-labs';
 
 import { getInstantAppChatPanesCount } from '../components/panes/usePanesManager';
@@ -114,43 +117,47 @@
   console.log('PERSONA HERE');
 
-  // try {
-  //   await llmStreamingChatGenerate(llmId, messagesHistory, contextName, contextRef, null, null, abortSignal, (update: StreamingClientUpdate) => {
-  //     const textSoFar = update.textSoFar;
-  //
-  //     // grow the incremental message
-  //     if (textSoFar) incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, textSoFar);
-  //     if (update.originLLM) incrementalAnswer.originLLM = update.originLLM;
-  //     if (update.typing !== undefined)
-  //       incrementalAnswer.pendingIncomplete = update.typing ? true : undefined;
-  //
-  //     // Update the data store, with optional max-frequency throttling (e.g. OpenAI is downsamped 50 -> 12Hz)
-  //     // This can be toggled from the settings
-  //     throttledEditMessage(incrementalAnswer);
-  //
-  //     // 📢 TTS: first-line
-  //     // if (textSoFar && autoSpeak === 'firstLine' && !spokenLine) {
-  //     //   let cutPoint = textSoFar.lastIndexOf('\n');
-  //     //   if (cutPoint < 0)
-  //     //     cutPoint = textSoFar.lastIndexOf('. ');
-  //     //   if (cutPoint > 100 && cutPoint < 400) {
-  //     //     spokenLine = true;
-  //     //     const firstParagraph = textSoFar.substring(0, cutPoint);
-  //     //     // fire/forget: we don't want to stall this loop
-  //     //     void speakText(firstParagraph);
-  //     //   }
-  //     // }
-  //   });
-  // } catch (error: any) {
-  //   if (error?.name !== 'AbortError') {
-  //     console.error('Fetch request error:', error);
-  //     const errorText = ` [Issue: ${error.message || (typeof error === 'string' ? error : 'Chat stopped.')}]`;
-  //     incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, errorText, true);
-  //     returnStatus.outcome = 'errored';
-  //     returnStatus.errorMessage = error.message;
-  //   } else
-  //     returnStatus.outcome = 'aborted';
-  // }
+  try {
+    const onUpdate = (update: StreamingClientUpdate, done: boolean) => {
+      console.log('PERSONA UPDATE', update, done);
+      const textSoFar = update.textSoFar;
+
+      // grow the incremental message
+      if (textSoFar) incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, textSoFar);
+      if (update.originLLM) incrementalAnswer.originLLM = update.originLLM;
+      if (update.typing !== undefined)
+        incrementalAnswer.pendingIncomplete = update.typing ? true : undefined;
+
+      // Update the data store, with optional max-frequency throttling (e.g. OpenAI is downsampled 50 -> 12Hz)
+      // This can be toggled from the settings
+      throttledEditMessage(incrementalAnswer);
+
+      // 📢 TTS: first-line
+      // if (textSoFar && autoSpeak === 'firstLine' && !spokenLine) {
+      //   let cutPoint = textSoFar.lastIndexOf('\n');
+      //   if (cutPoint < 0)
+      //     cutPoint = textSoFar.lastIndexOf('. ');
+      //   if (cutPoint > 100 && cutPoint < 400) {
+      //     spokenLine = true;
+      //     const firstParagraph = textSoFar.substring(0, cutPoint);
+      //     // fire/forget: we don't want to stall this loop
+      //     void speakText(firstParagraph);
+      //   }
+      // }
+    };
+
+    await aixStreamingChatGenerate(llmId, messagesHistory, contextName, contextRef, null, null, abortSignal, onUpdate);
+
+  } catch (error: any) {
+    if (error?.name !== 'AbortError') {
+      console.error('Fetch request error:', error);
+      const errorText = ` [Issue: ${error.message || (typeof error === 'string' ? error : 'Chat stopped.')}]`;
+      incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, errorText, true);
+      returnStatus.outcome = 'errored';
+      returnStatus.errorMessage = error.message;
+    } else
+      returnStatus.outcome = 'aborted';
+  }
 
   // Ensure the last content is flushed out, and mark as complete
   onMessageUpdated({ ...incrementalAnswer, pendingIncomplete: undefined }, true);
diff --git a/src/modules/aix/client/aix.client.ts b/src/modules/aix/client/aix.client.ts
new file mode 100644
index 000000000..b57f60993
--- /dev/null
+++ b/src/modules/aix/client/aix.client.ts
@@ -0,0 +1,198 @@
+import type { ChatStreamingInputSchema } from '~/modules/llms/server/llm.server.streaming';
+import { DLLMId } from '~/modules/llms/store-llms';
+import { findVendorForLlmOrThrow } from '~/modules/llms/vendors/vendors.registry';
+import { VChatContextRef, VChatFunctionIn, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
+
+
+import { frontendSideFetch } from '~/common/util/clientFetchers';
+import { AixGenerateContentInput } from '~/modules/aix/shared/aix.shared.chat';
+import { AixAccess, AixHistory, AixModel, AixStreamGenerateContext } from '~/modules/aix/shared/aix.shared.types';
+import { AixToolPolicy, AixTools } from '~/modules/aix/shared/aix.shared.tools';
+
+
+export type StreamingClientUpdate = Partial<{
+  textSoFar: string;
+  typing: boolean;
+  originLLM: string;
+}>;
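+
+// Illustrative usage sketch, not part of this module: how a caller could consume
+// the updates below. 'my-llm-id' and 'my-conversation-id' are hypothetical values,
+// and textSoFar carries the full accumulated text, not a delta.
+//
+//   await aixStreamingChatGenerate('my-llm-id', [{ role: 'user', content: 'Hello' }],
+//     'conversation', 'my-conversation-id', null, null, new AbortController().signal,
+//     (update, done) => {
+//       if (update.textSoFar !== undefined) console.log(update.textSoFar);
+//       if (done) console.log('stream complete');
+//     });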
+
+
+export async function aixStreamingChatGenerate(
+  llmId: DLLMId,
+  history: VChatMessageIn[],
+  contextName: VChatStreamContextName,
+  contextRef: VChatContextRef,
+  functions: VChatFunctionIn[] | null,
+  forceFunctionName: string | null,
+  abortSignal: AbortSignal,
+  onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
+): Promise<void> {
+
+  // resolve the llmId to its DLLM and vendor
+  const { llm, vendor } = findVendorForLlmOrThrow(llmId);
+
+  // FIXME: relax the forced cast
+  const llmOptions = llm.options;
+
+  // get the access
+  const partialSourceSetup = llm._source.setup;
+  const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamingInputSchema['access'];
+
+  // get any vendor-specific rate limit delay
+  const delay = vendor.getRateLimitDelay?.(llm, partialSourceSetup) ?? 0;
+  if (delay > 0)
+    await new Promise(resolve => setTimeout(resolve, delay));
+
+  // [OpenAI-only] check for harmful content with the free 'moderation' API, if the user requests it
+  // if (access.dialect === 'openai' && access.moderationCheck) {
+  //   const moderationUpdate = await _openAIModerationCheck(access, messages.at(-1) ?? null);
+  //   if (moderationUpdate)
+  //     return onUpdate({ textSoFar: moderationUpdate, typing: false }, true);
+  // }
+
+
+  // execute via the vendor
+  return await aixStreamGenerateDirect(
+    access,
+    aixModelFromLLMOptions(llm.options, llmId),
+    history,
+    undefined,
+    undefined,
+    aixStreamGenerateContext(contextName, contextRef),
+    abortSignal,
+    onUpdate,
+  );
+  // return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, contextName, contextRef, functions, forceFunctionName, abortSignal, onUpdate);
+}
+
+
+function aixModelFromLLMOptions(llmOptions: Record<string, any>, debugLlmId: string): AixModel {
+  // model params (llm)
+  const { llmRef, llmTemperature, llmResponseTokens } = llmOptions || {};
+  if (!llmRef || llmTemperature === undefined)
+    throw new Error(`Error in configuration for model ${debugLlmId}: ${JSON.stringify(llmOptions)}`);
+
+  return {
+    id: llmRef,
+    temperature: llmTemperature,
+    ...(llmResponseTokens ? { maxTokens: llmResponseTokens } : {}),
+  };
+}
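+
+// Sketch of the mapping above, with hypothetical values (the exact llmOptions
+// fields come from the vendor-side model configuration):
+//   { llmRef: 'some-model-ref', llmTemperature: 0.5, llmResponseTokens: 1024 }
+//   ->  { id: 'some-model-ref', temperature: 0.5, maxTokens: 1024 }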
+
+function aixStreamGenerateContext(contextName: VChatStreamContextName, contextRef: VChatContextRef): AixStreamGenerateContext {
+  return {
+    method: 'chat-stream',
+    name: contextName,
+    ref: contextRef,
+  };
+}
+
+
+/**
+ * Client-side chat generation, with streaming. This decodes the (text) streaming response from
+ * our server streaming endpoint (plain text, not EventSource), and signals updates via a callback.
+ *
+ * The vendor-specific implementation lives in our server backend (API) code. This function tries
+ * to be as generic as possible.
+ *
+ * NOTE: onUpdate is the callback invoked whenever a piece of the message (text, model name, typing state, ..) is received
+ */
+export async function aixStreamGenerateDirect(
+  // input
+  access: AixAccess,
+  model: AixModel,
+  history: AixHistory,
+  tools: AixTools | undefined,
+  toolPolicy: AixToolPolicy | undefined,
+  context: AixStreamGenerateContext,
+  // others
+  abortSignal: AbortSignal,
+  onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
+): Promise<void> {
+
+  // assemble the input object
+  const aixGenerateContentInput: AixGenerateContentInput = {
+    access,
+    model,
+    history,
+    // tools: undefined,
+    // toolPolicy: undefined,
+    context,
+  };
+
+  // connect to the server-side streaming endpoint
+  const timeFetch = performance.now();
+  const streamResponse = await frontendSideFetch('/api/llms/stream', {
+    method: 'POST',
+    headers: { 'Content-Type': 'application/json' },
+    body: JSON.stringify(aixGenerateContentInput),
+    signal: abortSignal,
+  });
+
+  // connection error to our backend
+  if (!streamResponse.ok || !streamResponse.body) {
+    const errorMessage = streamResponse.body ? await streamResponse.text().catch(() => 'No content from server') : 'No response from server';
+    onUpdate({ textSoFar: errorMessage, typing: false }, true);
+    return;
+  }
+
+  const responseReader = streamResponse.body.getReader();
+
+  let incrementalText = '';
+  let parsedPreambleStart = false;
+  let parsedPreambleModel = false;
+
+  // loop until the read is done, or the abort controller is triggered
+  const textDecoder = new TextDecoder('utf-8');
+  while (true) {
+
+    // read until done
+    const { value, done } = await responseReader.read();
+    if (done) {
+      if (value?.length)
+        console.log('aixStreamGenerateDirect: unexpected value in the last packet:', value?.length);
+      break;
+    }
+
+    incrementalText += textDecoder.decode(value, { stream: true });
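+
+    // Hypothetical wire-format sketch (the actual framing is defined by the server
+    // endpoint): two flat JSON preamble packets, then the raw text flow, e.g.:
+    //   {"type":"start"}{"model":"some-model-id"}Hello! How can I help...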
+
+    // the stream begins with two packets, each a serialized flat json object; this is side data, sent before the text flow starts
+    // while ((!parsedPreambleStart || !parsedPreambleModel) && incrementalText.startsWith('{')) {
+    //
+    //   // extract a complete JSON object, if present
+    //   const endOfJson = incrementalText.indexOf('}');
+    //   if (endOfJson === -1) break;
+    //   const jsonString = incrementalText.substring(0, endOfJson + 1);
+    //   incrementalText = incrementalText.substring(endOfJson + 1);
+    //
+    //   // first packet: preamble to let the Vercel edge function go over time
+    //   if (!parsedPreambleStart) {
+    //     parsedPreambleStart = true;
+    //     try {
+    //       const parsed: ChatStreamingPreambleStartSchema = JSON.parse(jsonString);
+    //       if (parsed.type !== 'start')
+    //         console.log('unifiedStreamingClient: unexpected preamble type:', parsed?.type, 'time:', performance.now() - timeFetch);
+    //     } catch (e) {
+    //       // error parsing JSON, ignore
+    //       console.log('unifiedStreamingClient: error parsing start JSON:', e);
+    //     }
+    //     continue;
+    //   }
+    //
+    //   // second packet: the model name
+    //   if (!parsedPreambleModel) {
+    //     parsedPreambleModel = true;
+    //     try {
+    //       const parsed: ChatStreamingPreambleModelSchema = JSON.parse(jsonString);
+    //       onUpdate({ originLLM: parsed.model }, false);
+    //     } catch (e) {
+    //       // error parsing JSON, ignore
+    //       console.log('unifiedStreamingClient: error parsing model JSON:', e);
+    //     }
+    //   }
+    // }
+
+    if (incrementalText)
+      onUpdate({ textSoFar: incrementalText }, false);
+  }
+}
diff --git a/src/modules/aix/shared/aix.shared.chat.ts b/src/modules/aix/shared/aix.shared.chat.ts
index 269c56b5b..869c1b021 100644
--- a/src/modules/aix/shared/aix.shared.chat.ts
+++ b/src/modules/aix/shared/aix.shared.chat.ts
@@ -1,18 +1,19 @@
 import { z } from 'zod';
 
+import { aixAccessSchema, aixHistorySchema, aixModelSchema, aixStreamingContextSchema } from '~/modules/aix/shared/aix.shared.types';
 import { aixToolsPolicySchema, aixToolsSchema } from './aix.shared.tools';
 
 
 /// GENERATE INPUT Schema ///
 
-export const aixChatGenerateInputSchema = z.object({
-  // access: openAIAccessSchema,
-  // model: openAIModelSchema,
-  // history: openAIHistorySchema,
+export type AixGenerateContentInput = z.infer<typeof aixGenerateContentInputSchema>;
+
+export const aixGenerateContentInputSchema = z.object({
+  access: aixAccessSchema,
+  model: aixModelSchema,
+  history: aixHistorySchema,
   tools: aixToolsSchema.optional(),
   toolPolicy: aixToolsPolicySchema.optional(),
-  // context: llmsGenerateContextSchema,
-  // stream? -> implicit via the function name
+  context: aixStreamingContextSchema,
+  // stream? -> discriminated via the rpc function name
 });
-
-
diff --git a/src/modules/aix/shared/aix.shared.tools.ts b/src/modules/aix/shared/aix.shared.tools.ts
index 30cd09f13..5cb1458d6 100644
--- a/src/modules/aix/shared/aix.shared.tools.ts
+++ b/src/modules/aix/shared/aix.shared.tools.ts
@@ -137,6 +137,8 @@ export const aixToolsSchema = z.array(z.discriminatedUnion('type', [
   aixToolPreprocessorSchema,
 ]));
 
+export type AixTools = z.infer<typeof aixToolsSchema>;
+
 /**
  * Policy for tools that the model can use:
  * - any: must use one tool at least
@@ -150,3 +152,5 @@ export const aixToolsPolicySchema = z.discriminatedUnion('type', [
   z.object({ type: z.literal('function'), function: z.object({ name: z.string() }) }),
   z.object({ type: z.literal('none') }),
 ]);
+
+export type AixToolPolicy = z.infer<typeof aixToolsPolicySchema>;
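+
+// Illustrative values accepted by the policy schema above ('get_weather' is a
+// hypothetical function name):
+//   { type: 'function', function: { name: 'get_weather' } }
+//   { type: 'none' }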
diff --git a/src/modules/aix/shared/aix.shared.types.ts b/src/modules/aix/shared/aix.shared.types.ts
new file mode 100644
index 000000000..3dfbcea9b
--- /dev/null
+++ b/src/modules/aix/shared/aix.shared.types.ts
@@ -0,0 +1,54 @@
+import { z } from 'zod';
+
+import { anthropicAccessSchema } from '~/modules/llms/server/anthropic/anthropic.router';
+import { geminiAccessSchema } from '~/modules/llms/server/gemini/gemini.router';
+import { ollamaAccessSchema } from '~/modules/llms/server/ollama/ollama.router';
+import { openAIAccessSchema } from '~/modules/llms/server/openai/openai.router';
+
+
+// AIX Access Schema //
+
+export type AixAccess = z.infer<typeof aixAccessSchema>;
+
+export const aixAccessSchema = z.discriminatedUnion(
+  'dialect',
+  [
+    anthropicAccessSchema,
+    geminiAccessSchema,
+    ollamaAccessSchema,
+    openAIAccessSchema,
+  ],
+);
+
+
+// AIX Context Schema //
+
+export type AixStreamGenerateContext = z.infer<typeof aixStreamingContextSchema>;
+
+export const aixStreamingContextSchema = z.object({
+  method: z.literal('chat-stream'),
+  name: z.enum(['conversation', 'ai-diagram', 'ai-flattener', 'call', 'beam-scatter', 'beam-gather', 'persona-extract']),
+  ref: z.string(),
+});
+
+
+// AIX History Schema //
+
+export type AixHistory = z.infer<typeof aixHistorySchema>;
+
+export const aixHistorySchema = z.array(z.object({
+  role: z.enum(['assistant', 'system', 'user'/*, 'function'*/]),
+  content: z.string(),
+}));
+
+
+// AIX Model Schema //
+
+export type AixModel = z.infer<typeof aixModelSchema>;
+
+// FIXME: have a more flexible schema
+export const aixModelSchema = z.object({
+  id: z.string(),
+  temperature: z.number().min(0).max(2).optional(),
+  maxTokens: z.number().min(1).max(1000000).optional(),
+});
diff --git a/src/modules/llms/store-llms.ts b/src/modules/llms/store-llms.ts
index 3afd389be..82cab97dd 100644
--- a/src/modules/llms/store-llms.ts
+++ b/src/modules/llms/store-llms.ts
@@ -8,7 +8,7 @@ import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vend
 /**
  * Large Language Model - description and configuration (data object, stored)
  */
-export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
+export interface DLLM<TSourceSetup = unknown, TLLMOptions = Record<string, any>> {
   id: DLLMId;
 
   // editable properties (kept on update, if isEdited)