AIX: Client API first port

This commit is contained in:
Enrico Ros
2024-07-10 03:30:18 -07:00
parent cc7242dfd3
commit 83b1e0ffba
4 changed files with 118 additions and 67 deletions
+89 -16
View File
@@ -1,20 +1,98 @@
import type { DLLMId } from '~/modules/llms/store-llms';
import type { VChatContextRef, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
import { AixChatContentGenerateRequest, AixChatMessage, AixChatMessageModel, AixChatMessageUser } from '~/modules/aix/client/aix.client.api';
import type { IntakeContextChatStream } from '~/modules/aix/server/intake/schemas.intake.api';
import { aixStreamingChatGenerate, StreamingClientUpdate } from '~/modules/aix/client/aix.client';
import { autoConversationTitle } from '~/modules/aifn/autotitle/autoTitle';
import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions';
import { PersonaChatMessageSpeak } from './persona/PersonaChatMessageSpeak';
import type { DConversationId } from '~/common/stores/chat/chat.conversation';
import { ConversationsManager } from '~/common/chats/ConversationsManager';
import { DMessage, messageFragmentsReplaceLastContentText, messageSingleTextOrThrow } from '~/common/stores/chat/chat.message';
import { DMessage, messageFragmentsReplaceLastContentText } from '~/common/stores/chat/chat.message';
import { getUXLabsHighPerformance } from '~/common/state/store-ux-labs';
import { isContentFragment, isTextPart } from '~/common/stores/chat/chat.fragments';
import { isContentFragment, isContentOrAttachmentFragment, isTextPart } from '~/common/stores/chat/chat.fragments';
import { PersonaChatMessageSpeak } from './persona/PersonaChatMessageSpeak';
import { getChatAutoAI } from '../store-app-chat';
import { getInstantAppChatPanesCount } from '../components/panes/usePanesManager';
/**
 * Converts the chat history (DMessage[]) into an AixChatContentGenerateRequest.
 *
 * - a leading 'system' message is folded into `systemMessage` (text content parts only)
 * - 'assistant' messages map to 'model' messages (text and tool_call parts)
 * - 'user' messages map to 'user' messages (text and doc parts; image_ref
 *   conversion to inline images is not implemented yet — see the TODO below)
 *
 * Fragments or roles that don't fit the above are skipped with a console warning.
 */
async function historyToChatGenerateRequest(history: Readonly<DMessage[]>): Promise<AixChatContentGenerateRequest> {
  const request: AixChatContentGenerateRequest = { chat: [] };

  history.forEach((m, index) => {

    // leading system message -> request.systemMessage
    if (index === 0 && m.role === 'system') {
      if (!request.systemMessage)
        request.systemMessage = { parts: [] };
      for (const systemFragment of m.fragments) {
        if (isContentFragment(systemFragment) && isTextPart(systemFragment.part))
          request.systemMessage.parts.push(systemFragment.part);
        else
          console.warn('historyToChatGenerateRequest: unexpected system fragment', systemFragment);
      }
      return;
    }

    // map the remaining messages by role
    let aixChatMessage: AixChatMessage | undefined = undefined;

    if (m.role === 'assistant') {
      const modelMessage: AixChatMessageModel = { role: 'model', parts: [] };
      for (const srcFragment of m.fragments) {
        if (!isContentOrAttachmentFragment(srcFragment))
          continue;
        switch (srcFragment.part.pt) {
          case 'text':
          case 'tool_call':
            modelMessage.parts.push(srcFragment.part);
            break;
          default:
            console.warn('historyToChatGenerateRequest: unexpected model fragment part type', srcFragment.part);
            break;
        }
      }
      aixChatMessage = modelMessage;
    } else if (m.role === 'user') {
      const userMessage: AixChatMessageUser = { role: 'user', parts: [] };
      for (const srcFragment of m.fragments) {
        if (!isContentOrAttachmentFragment(srcFragment))
          continue;
        switch (srcFragment.part.pt) {
          case 'text':
          case 'doc':
            userMessage.parts.push(srcFragment.part);
            break;
          case 'image_ref':
            // TODO: resolve the dblob asset (getImageAsset) and push an inline image part
            console.log('DEV: historyToChatGenerateRequest: image_ref', srcFragment.part);
            break;
          default:
            console.warn('historyToChatGenerateRequest: unexpected user fragment part type', srcFragment.part);
        }
      }
      aixChatMessage = userMessage;
    } else {
      console.warn('historyToChatGenerateRequest: unexpected message role', m.role);
    }

    if (aixChatMessage)
      request.chat.push(aixChatMessage);
  });

  return request;
}
/**
* The main "chat" function.
*/
@@ -45,17 +123,12 @@ export async function runPersonaOnConversationHead(
const abortController = new AbortController();
cHandler.setAbortController(abortController);
// stream the assistant's messages directly to the state store
let instructions: VChatMessageIn[];
try {
instructions = history.map((m): VChatMessageIn => ({ role: m.role, content: messageSingleTextOrThrow(m) /* BIG FIXME */ }));
} catch (error) {
console.error('runAssistantUpdatingState: error:', error, history);
throw error;
}
const aixChatContentGenerateRequest = await historyToChatGenerateRequest(history);
const messageStatus = await llmGenerateContentStream(
assistantLlmId,
instructions,
aixChatContentGenerateRequest,
'conversation',
conversationId,
parallelViewCount,
@@ -103,9 +176,9 @@ type StreamMessageUpdate = Pick<DMessage, 'fragments' | 'originLLM' | 'pendingIn
export async function llmGenerateContentStream(
llmId: DLLMId,
messagesHistory: VChatMessageIn[],
contextName: VChatStreamContextName,
contextRef: VChatContextRef,
chatGenerate: AixChatContentGenerateRequest,
intakeContextName: IntakeContextChatStream['name'],
intakeContextRef: string,
parallelViewCount: number, // 0: disable, 1: default throttle (12Hz), 2+ reduce frequency with the square root
abortSignal: AbortSignal,
onMessageUpdated: (incrementalMessage: Partial<StreamMessageUpdate>, messageComplete: boolean) => void,
@@ -122,7 +195,7 @@ export async function llmGenerateContentStream(
try {
await aixStreamingChatGenerate(llmId, messagesHistory, contextName, contextRef, null, null, abortSignal,
await aixStreamingChatGenerate(llmId, chatGenerate, intakeContextName, intakeContextRef, abortSignal,
(update: StreamingClientUpdate, done: boolean) => {
// grow the incremental message
+2 -2
View File
@@ -44,12 +44,12 @@ export type AixChatMessage =
| AixChatMessageModel
| AixChatMessageTool;
interface AixChatMessageUser {
/**
 * A user-authored message in the Aix chat-generate request: may carry text,
 * inline images, attached documents, and reply-to metadata parts.
 */
export interface AixChatMessageUser {
  role: 'user',
  parts: (DMessageTextPart | AixInlineImagePart | DMessageDocPart | AixMetaReplyToPart)[];
}
interface AixChatMessageModel {
/**
 * A model-authored (assistant) message in the Aix chat-generate request:
 * restricted to text and tool-call parts.
 */
export interface AixChatMessageModel {
  role: 'model',
  parts: (DMessageTextPart | DMessageToolCallPart)[];
}
+26 -48
View File
@@ -4,10 +4,9 @@ import { findVendorForLlmOrThrow } from '~/modules/llms/vendors/vendors.registry
import { apiStream } from '~/common/util/trpc.client';
import type { VChatContextRef, VChatFunctionIn, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
import type { IntakeAccess, IntakeContextChatStream, IntakeModel } from '../server/intake/schemas.intake.api';
import type { AixAccess, AixHistory, AixModel, AixStreamGenerateContext } from '../server/intake/aix.intake.types';
import type { AixToolPolicy, AixTools } from '../server/intake/aix.tool.types';
import type { AixChatContentGenerateRequest } from './aix.client.api';
export type StreamingClientUpdate = Partial<{
@@ -19,11 +18,9 @@ export type StreamingClientUpdate = Partial<{
export async function aixStreamingChatGenerate<TSourceSetup = unknown, TAccess extends ChatStreamingInputSchema['access'] = ChatStreamingInputSchema['access']>(
llmId: DLLMId,
history: VChatMessageIn[],
contextName: VChatStreamContextName,
contextRef: VChatContextRef,
functions: VChatFunctionIn[] | null,
forceFunctionName: string | null,
chatGenerate: AixChatContentGenerateRequest,
intakeContextName: IntakeContextChatStream['name'],
intakeContextRef: string,
abortSignal: AbortSignal,
onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
): Promise<void> {
@@ -32,11 +29,12 @@ export async function aixStreamingChatGenerate<TSourceSetup = unknown, TAccess e
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess>(llmId);
// FIXME: relax the forced cast
const llmOptions = llm.options;
// const llmOptions = llm.options;
const intakeModel = intakeModelFromLLMOptions(llm.options, llmId);
// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access'];
const intakeAccess = vendor.getTransportAccess(partialSourceSetup);
// get any vendor-specific rate limit delay
const delay = vendor.getRateLimitDelay?.(llm, partialSourceSetup) ?? 0;
@@ -44,29 +42,24 @@ export async function aixStreamingChatGenerate<TSourceSetup = unknown, TAccess e
await new Promise(resolve => setTimeout(resolve, delay));
// [OpenAI-only] check for harmful content with the free 'moderation' API, if the user requests so
// if (access.dialect === 'openai' && access.moderationCheck) {
// const moderationUpdate = await _openAIModerationCheck(access, messages.at(-1) ?? null);
// if (intakeAccess.dialect === 'openai' && intakeAccess.moderationCheck) {
// const moderationUpdate = await _openAIModerationCheck(intakeAccess, messages.at(-1) ?? null);
// if (moderationUpdate)
// return onUpdate({ textSoFar: moderationUpdate, typing: false }, true);
// }
// execute via the vendor
return await aixStreamGenerateUnified(
access,
aixModelFromLLMOptions(llm.options, llmId),
history,
undefined,
undefined,
aixStreamGenerateContext(contextName, contextRef),
abortSignal,
onUpdate,
);
// return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, contextName, contextRef, functions, forceFunctionName, abortSignal, onUpdate);
// return await vendor.streamingChatGenerateOrThrow(intakeAccess, llmId, llmOptions, messages, contextName, contextRef, functions, forceFunctionName, abortSignal, onUpdate);
const intakeContext = intakeContextChatStream(intakeContextName, intakeContextRef);
return await _aixStreamGenerateUnified(intakeAccess, intakeModel, chatGenerate, intakeContext, abortSignal, onUpdate);
}
/**
 * Builds the chat-stream intake context descriptor sent to the Aix server.
 */
function intakeContextChatStream(name: IntakeContextChatStream['name'], ref: string): IntakeContextChatStream {
  const context: IntakeContextChatStream = {
    method: 'chat-stream',
    name,
    ref,
  };
  return context;
}
function aixModelFromLLMOptions(llmOptions: Record<string, any>, debugLlmId: string): AixModel {
function intakeModelFromLLMOptions(llmOptions: Record<string, any>, debugLlmId: string): IntakeModel {
// model params (llm)
const { llmRef, llmTemperature, llmResponseTokens } = llmOptions || {};
if (!llmRef || llmTemperature === undefined)
@@ -79,14 +72,6 @@ function aixModelFromLLMOptions(llmOptions: Record<string, any>, debugLlmId: str
};
}
function aixStreamGenerateContext(contextName: AixStreamGenerateContext['name'], contextRef: AixStreamGenerateContext['ref']): AixStreamGenerateContext {
return {
method: 'chat-stream',
name: contextName,
ref: contextRef,
};
}
/**
* Client side chat generation, with streaming. This decodes the (text) streaming response from
@@ -97,27 +82,21 @@ function aixStreamGenerateContext(contextName: AixStreamGenerateContext['name'],
*
* NOTE: onUpdate is callback when a piece of a message (text, model name, typing..) is received
*/
export async function aixStreamGenerateUnified<TSourceSetup = unknown>(
async function _aixStreamGenerateUnified(
// input
access: AixAccess,
model: AixModel,
history: AixHistory,
tools: AixTools | undefined,
toolPolicy: AixToolPolicy | undefined,
context: AixStreamGenerateContext,
access: IntakeAccess,
model: IntakeModel,
chatGenerate: AixChatContentGenerateRequest,
context: IntakeContextChatStream,
// others
abortSignal: AbortSignal,
onUpdate: (update: StreamingClientUpdate, done: boolean) => void,
): Promise<void> {
const x = await apiStream.aix.chatGenerateContentStream.mutate({
access,
model,
history,
tools,
toolPolicy,
context,
}, { signal: abortSignal });
const x = await apiStream.aix.chatGenerateContentStream.mutate(
{ access, model, chatGenerate, context },
{ signal: abortSignal },
);
let incrementalText = '';
@@ -150,5 +129,4 @@ export async function aixStreamGenerateUnified<TSourceSetup = unknown>(
console.log('HERE', abortSignal.aborted ? 'client-initiated ABORTED' : '');
onUpdate({ typing: false }, true);
}
+1 -1
View File
@@ -57,7 +57,7 @@ async function _addDBImageAsset(imageAsset: DBlobImageAsset, contextId: DBlobDBC
// return await getDBAssetsByType<DBlobImageAsset>(DBlobAssetType.IMAGE);
// }
async function getImageAsset(id: DBlobAssetId) {
/**
 * Fetches an image asset from the DBlob store by its asset id.
 */
export async function getImageAsset(id: DBlobAssetId) {
  const imageAsset = await getDBAsset<DBlobImageAsset>(id);
  return imageAsset;
}