AIX: types migration

Enrico Ros
2024-07-03 01:31:13 -07:00
parent 0a88a9cee6
commit a097b32d5c
6 changed files with 311 additions and 47 deletions
+45 -38
@@ -1,8 +1,11 @@
import type { DLLMId } from '~/modules/llms/store-llms';
import { VChatContextRef, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
import { aixStreamingChatGenerate, StreamingClientUpdate } from '~/modules/aix/client/aix.client';
import { ConversationsManager } from '~/common/chats/ConversationsManager';
import { DMessage, messageSingleTextOrThrow } from '~/common/stores/chat/chat.message';
import { DMessage, messageFragmentsReplaceLastContentText, messageSingleTextOrThrow } from '~/common/stores/chat/chat.message';
import { getUXLabsHighPerformance } from '~/common/state/store-ux-labs';
import { getInstantAppChatPanesCount } from '../components/panes/usePanesManager';
@@ -114,43 +117,47 @@ export async function streamPersonaMessage(
console.log('PERSONA HERE');
// try {
// await llmStreamingChatGenerate(llmId, messagesHistory, contextName, contextRef, null, null, abortSignal, (update: StreamingClientUpdate) => {
// const textSoFar = update.textSoFar;
//
// // grow the incremental message
// if (textSoFar) incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, textSoFar);
// if (update.originLLM) incrementalAnswer.originLLM = update.originLLM;
// if (update.typing !== undefined)
// incrementalAnswer.pendingIncomplete = update.typing ? true : undefined;
//
// // Update the data store, with optional max-frequency throttling (e.g. OpenAI is downsampled 50 -> 12Hz)
// // This can be toggled from the settings
// throttledEditMessage(incrementalAnswer);
//
// // 📢 TTS: first-line
// // if (textSoFar && autoSpeak === 'firstLine' && !spokenLine) {
// // let cutPoint = textSoFar.lastIndexOf('\n');
// // if (cutPoint < 0)
// // cutPoint = textSoFar.lastIndexOf('. ');
// // if (cutPoint > 100 && cutPoint < 400) {
// // spokenLine = true;
// // const firstParagraph = textSoFar.substring(0, cutPoint);
// // // fire/forget: we don't want to stall this loop
// // void speakText(firstParagraph);
// // }
// // }
// });
// } catch (error: any) {
// if (error?.name !== 'AbortError') {
// console.error('Fetch request error:', error);
// const errorText = ` [Issue: ${error.message || (typeof error === 'string' ? error : 'Chat stopped.')}]`;
// incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, errorText, true);
// returnStatus.outcome = 'errored';
// returnStatus.errorMessage = error.message;
// } else
// returnStatus.outcome = 'aborted';
// }
try {
const onUpdate = (update: StreamingClientUpdate, done: boolean) => {
console.log('PERSONA UPDATE', update, done);
const textSoFar = update.textSoFar;
// grow the incremental message
if (textSoFar) incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, textSoFar);
if (update.originLLM) incrementalAnswer.originLLM = update.originLLM;
if (update.typing !== undefined)
incrementalAnswer.pendingIncomplete = update.typing ? true : undefined;
// Update the data store, with optional max-frequency throttling (e.g. OpenAI is downsampled 50 -> 12Hz)
// This can be toggled from the settings
throttledEditMessage(incrementalAnswer);
// 📢 TTS: first-line
// if (textSoFar && autoSpeak === 'firstLine' && !spokenLine) {
// let cutPoint = textSoFar.lastIndexOf('\n');
// if (cutPoint < 0)
// cutPoint = textSoFar.lastIndexOf('. ');
// if (cutPoint > 100 && cutPoint < 400) {
// spokenLine = true;
// const firstParagraph = textSoFar.substring(0, cutPoint);
// // fire/forget: we don't want to stall this loop
// void speakText(firstParagraph);
// }
// }
};
await aixStreamingChatGenerate(llmId, messagesHistory, contextName, contextRef, null, null, abortSignal, onUpdate);
} catch (error: any) {
if (error?.name !== 'AbortError') {
console.error('Fetch request error:', error);
const errorText = ` [Issue: ${error.message || (typeof error === 'string' ? error : 'Chat stopped.')}]`;
incrementalAnswer.fragments = messageFragmentsReplaceLastContentText(incrementalAnswer.fragments, errorText, true);
returnStatus.outcome = 'errored';
returnStatus.errorMessage = error.message;
} else
returnStatus.outcome = 'aborted';
}
// Ensure the last content is flushed out, and mark as complete
onMessageUpdated({ ...incrementalAnswer, pendingIncomplete: undefined }, true);
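For context on this file's change: the new onUpdate contract delivers partial StreamingClientUpdate objects plus a done flag. A minimal standalone consumer, as a sketch (a hypothetical logging harness, not part of this commit):
const logUpdate = (update: StreamingClientUpdate, done: boolean): void => {
  // all fields are optional: only the pieces that changed arrive in each update
  if (update.textSoFar !== undefined) console.log('text so far:', update.textSoFar.length, 'chars');
  if (update.originLLM !== undefined) console.log('upstream model:', update.originLLM);
  if (update.typing !== undefined) console.log('typing:', update.typing);
  if (done) console.log('stream complete');
};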
+198
@@ -0,0 +1,198 @@
import type { ChatStreamingInputSchema } from '~/modules/llms/server/llm.server.streaming';
import { DLLMId } from '~/modules/llms/store-llms';
import { findVendorForLlmOrThrow } from '~/modules/llms/vendors/vendors.registry';
import { VChatContextRef, VChatFunctionIn, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
import { frontendSideFetch } from '~/common/util/clientFetchers';
import { AixGenerateContentInput } from '~/modules/aix/shared/aix.shared.chat';
import { AixAccess, AixHistory, AixModel, AixStreamGenerateContext } from '~/modules/aix/shared/aix.shared.types';
import { AixToolPolicy, AixTools } from '~/modules/aix/shared/aix.shared.tools';
export type StreamingClientUpdate = Partial<{
textSoFar: string;
typing: boolean;
originLLM: string;
}>;
export async function aixStreamingChatGenerate<TSourceSetup = unknown, TAccess extends ChatStreamingInputSchema['access'] = ChatStreamingInputSchema['access']>(
llmId: DLLMId,
history: VChatMessageIn[],
contextName: VChatStreamContextName,
contextRef: VChatContextRef,
functions: VChatFunctionIn[] | null,
forceFunctionName: string | null,
abortSignal: AbortSignal,
onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
): Promise<void> {
// id to DLLM and vendor
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess>(llmId);
// FIXME: relax the forced cast
const llmOptions = llm.options;
// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access'];
// get any vendor-specific rate limit delay
const delay = vendor.getRateLimitDelay?.(llm, partialSourceSetup) ?? 0;
if (delay > 0)
await new Promise(resolve => setTimeout(resolve, delay));
// [OpenAI-only] check for harmful content with the free 'moderation' API, if the user requests so
// if (access.dialect === 'openai' && access.moderationCheck) {
// const moderationUpdate = await _openAIModerationCheck(access, messages.at(-1) ?? null);
// if (moderationUpdate)
// return onUpdate({ textSoFar: moderationUpdate, typing: false }, true);
// }
// execute via the vendor
return await aixStreamGenerateDirect(
access,
aixModelFromLLMOptions(llmOptions, llmId),
history,
undefined,
undefined,
aixStreamGenerateContext(contextName, contextRef),
abortSignal,
onUpdate,
);
// return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, contextName, contextRef, functions, forceFunctionName, abortSignal, onUpdate);
}
function aixModelFromLLMOptions(llmOptions: Record<string, any>, debugLlmId: string): AixModel {
// model params (llm)
const { llmRef, llmTemperature, llmResponseTokens } = llmOptions || {};
if (!llmRef || llmTemperature === undefined)
throw new Error(`Error in configuration for model ${debugLlmId}: ${JSON.stringify(llmOptions)}`);
return {
id: llmRef,
temperature: llmTemperature,
...(llmResponseTokens ? { maxTokens: llmResponseTokens } : {}),
};
}
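To illustrate the mapping above, an example call (the option values are placeholders, not taken from the commit):
// sketch: client-side llm options mapped to the AixModel wire shape
const model = aixModelFromLLMOptions(
  { llmRef: 'gpt-4o', llmTemperature: 0.5, llmResponseTokens: 1024 }, // placeholder options
  'debug-llm-id',
);
// -> { id: 'gpt-4o', temperature: 0.5, maxTokens: 1024 }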
function aixStreamGenerateContext(contextName: VChatStreamContextName, contextRef: VChatContextRef): AixStreamGenerateContext {
return {
method: 'chat-stream',
name: contextName,
ref: contextRef,
};
}
/**
 * Client-side chat generation, with streaming. This decodes the (text) streaming response from
 * our server streaming endpoint (plain text, not EventSource), and signals updates via a callback.
 *
 * The vendor-specific implementation lives in our server backend (API) code. This function tries to be
 * as generic as possible.
 *
 * NOTE: onUpdate is a callback invoked whenever a piece of a message (text, model name, typing state, ...) is received
 */
export async function aixStreamGenerateDirect<TSourceSetup = unknown>(
// input
access: AixAccess,
model: AixModel,
history: AixHistory,
tools: AixTools | undefined,
toolPolicy: AixToolPolicy | undefined,
context: AixStreamGenerateContext,
// others
abortSignal: AbortSignal,
onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
): Promise<void> {
// assemble the input object
const aixGenerateContentInput: AixGenerateContentInput = {
access,
model,
history,
tools,
toolPolicy,
context,
};
// connect to the server-side streaming endpoint
const timeFetch = performance.now();
const streamResponse = await frontendSideFetch('/api/llms/stream', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(aixGenerateContentInput),
signal: abortSignal,
});
// connection error to our backend
if (!streamResponse.ok || !streamResponse.body) {
const errorMessage = streamResponse.body ? await streamResponse.text().catch(() => 'No content from server') : 'No response from server';
onUpdate({ textSoFar: errorMessage, typing: false }, true);
return;
}
const responseReader = streamResponse.body.getReader();
let incrementalText = '';
let parsedPreambleStart = false;
let parsedPreambleModel = false;
// loop forever until the read is done, or the abort controller is triggered
const textDecoder = new TextDecoder('utf-8');
while (true) {
// read until done
const { value, done } = await responseReader.read();
if (done) {
if (value?.length)
console.log('aixStreamGenerateDirect: unexpected value in the last packet:', value?.length);
break;
}
incrementalText += textDecoder.decode(value, { stream: true });
// the stream starts with two packets, each a serialized flat JSON object; this is side data, sent before the text flow starts
// while ((!parsedPreambleStart || !parsedPreambleModel) && incrementalText.startsWith('{')) {
//
// // extract a complete JSON object, if present
// const endOfJson = incrementalText.indexOf('}');
// if (endOfJson === -1) break;
// const jsonString = incrementalText.substring(0, endOfJson + 1);
// incrementalText = incrementalText.substring(endOfJson + 1);
//
// // first packet: preamble to let the Vercel edge function go over time
// if (!parsedPreambleStart) {
// parsedPreambleStart = true;
// try {
// const parsed: ChatStreamingPreambleStartSchema = JSON.parse(jsonString);
// if (parsed.type !== 'start')
// console.log('unifiedStreamingClient: unexpected preamble type:', parsed?.type, 'time:', performance.now() - timeFetch);
// } catch (e) {
// // error parsing JSON, ignore
// console.log('unifiedStreamingClient: error parsing start JSON:', e);
// }
// continue;
// }
//
// // second packet: the model name
// if (!parsedPreambleModel) {
// parsedPreambleModel = true;
// try {
// const parsed: ChatStreamingPreambleModelSchema = JSON.parse(jsonString);
// onUpdate({ originLLM: parsed.model }, false);
// } catch (e) {
// // error parsing JSON, ignore
// console.log('unifiedStreamingClient: error parsing model JSON:', e);
// }
// }
// }
if (incrementalText)
onUpdate({ textSoFar: incrementalText }, false);
}
}
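A minimal caller sketch for the new client entry point (the id, ref and message text are placeholders; error handling is elided):
const abortController = new AbortController();
await aixStreamingChatGenerate(
  'openai-gpt-4o' as DLLMId,            // placeholder DLLMId
  [{ role: 'user', content: 'Hello' }], // placeholder VChatMessageIn[]
  'conversation',                       // VChatStreamContextName
  'conv-123',                           // placeholder VChatContextRef
  null, null,                           // functions, forceFunctionName: unused for plain chat
  abortController.signal,
  (update, done) => {
    if (update.textSoFar) console.log(update.textSoFar);
    if (done) console.log('done');
  },
);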
+9 -8
@@ -1,18 +1,19 @@
import { z } from 'zod';
import { aixAccessSchema, aixHistorySchema, aixModelSchema, aixStreamingContextSchema } from '~/modules/aix/shared/aix.shared.types';
import { aixToolsPolicySchema, aixToolsSchema } from './aix.shared.tools';
/// GENERATE INPUT Schema ///
export const aixChatGenerateInputSchema = z.object({
// access: openAIAccessSchema,
// model: openAIModelSchema,
// history: openAIHistorySchema,
export type AixGenerateContentInput = z.infer<typeof aixGenerateContentInputSchema>;
export const aixGenerateContentInputSchema = z.object({
access: aixAccessSchema,
model: aixModelSchema,
history: aixHistorySchema,
tools: aixToolsSchema.optional(),
toolPolicy: aixToolsPolicySchema.optional(),
// context: llmsGenerateContextSchema,
// stream? -> implicit via the function name
context: aixStreamingContextSchema,
// stream? -> discriminated via the rpc function name
});
+4
@@ -137,6 +137,8 @@ export const aixToolsSchema = z.array(z.discriminatedUnion('type', [
aixToolPreprocessorSchema,
]));
export type AixTools = z.infer<typeof aixToolsSchema>;
/**
* Policy for tools that the model can use:
* - any: must use one tool at least
@@ -150,3 +152,5 @@ export const aixToolsPolicySchema = z.discriminatedUnion('type', [
z.object({ type: z.literal('function'), function: z.object({ name: z.string() }) }),
z.object({ type: z.literal('none') }),
]);
export type AixToolPolicy = z.infer<typeof aixToolsPolicySchema>;
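For illustration, values that parse against the two variants visible in this hunk ('get_weather' is a placeholder name):
aixToolsPolicySchema.parse({ type: 'none' });
aixToolsPolicySchema.parse({ type: 'function', function: { name: 'get_weather' } });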
+54
@@ -0,0 +1,54 @@
import { z } from 'zod';
import { anthropicAccessSchema } from '~/modules/llms/server/anthropic/anthropic.router';
import { geminiAccessSchema } from '~/modules/llms/server/gemini/gemini.router';
import { ollamaAccessSchema } from '~/modules/llms/server/ollama/ollama.router';
import { openAIAccessSchema } from '~/modules/llms/server/openai/openai.router';
// AIX Access Schema //
export type AixAccess = z.infer<typeof aixAccessSchema>;
export const aixAccessSchema = z.discriminatedUnion(
'dialect',
[
anthropicAccessSchema,
geminiAccessSchema,
ollamaAccessSchema,
openAIAccessSchema,
],
);
// AIX Context Schema //
export type AixStreamGenerateContext = z.infer<typeof aixStreamingContextSchema>;
export const aixStreamingContextSchema = z.object({
method: z.literal('chat-stream'),
name: z.enum(['conversation', 'ai-diagram', 'ai-flattener', 'call', 'beam-scatter', 'beam-gather', 'persona-extract']),
ref: z.string(),
});
// AIX History Schema //
export type AixHistory = z.infer<typeof aixHistorySchema>;
export const aixHistorySchema = z.array(z.object({
role: z.enum(['assistant', 'system', 'user'/*, 'function'*/]),
content: z.string(),
}));
// AIX Model Schema //
export type AixModel = z.infer<typeof aixModelSchema>;
// FIXME: have a more flexible schema
export const aixModelSchema = z.object({
id: z.string(),
temperature: z.number().min(0).max(2).optional(),
maxTokens: z.number().min(1).max(1000000).optional(),
});
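As a quick sanity check, objects conforming to the schemas above (the id, ref and text values are placeholders):
const model: AixModel = aixModelSchema.parse({ id: 'gpt-4o', temperature: 0.7, maxTokens: 4096 });
const history: AixHistory = aixHistorySchema.parse([{ role: 'user', content: 'Hello' }]);
const context: AixStreamGenerateContext = aixStreamingContextSchema.parse({
  method: 'chat-stream',
  name: 'conversation',
  ref: 'conv-123', // placeholder conversation ref
});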
+1 -1
@@ -8,7 +8,7 @@ import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vend
/**
* Large Language Model - description and configuration (data object, stored)
*/
export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
export interface DLLM<TSourceSetup = unknown, TLLMOptions = Record<string, any>> {
id: DLLMId;
// editable properties (kept on update, if isEdited)