AIX: First image.

This commit is contained in:
Enrico Ros
2024-07-10 03:53:33 -07:00
parent 83b1e0ffba
commit 3ee3c312ef
3 changed files with 166 additions and 57 deletions
+26 -12
View File
@@ -15,11 +15,15 @@ import { isContentFragment, isContentOrAttachmentFragment, isTextPart } from '~/
import { PersonaChatMessageSpeak } from './persona/PersonaChatMessageSpeak';
import { getChatAutoAI } from '../store-app-chat';
import { getInstantAppChatPanesCount } from '../components/panes/usePanesManager';
import { getImageAsset } from '~/modules/dblobs/dblobs.images';
// FIXME: complete and optimize. This translates our 'message at rest' data structure into the Aix Request structure
// for chat generate
async function historyToChatGenerateRequest(history: Readonly<DMessage[]>): Promise<AixChatContentGenerateRequest> {
// reduce history
return history.reduce((acc, m, index) => {
return await history.reduce(async (accPromise, m, index) => {
const acc = await accPromise;
// extract system
if (index === 0 && m.role === 'system') {
@@ -57,7 +61,8 @@ async function historyToChatGenerateRequest(history: Readonly<DMessage[]>): Prom
return mMsg;
}, { role: 'model', parts: [] } as AixChatMessageModel);
} else if (m.role === 'user') {
aixChatMessage = m.fragments.reduce((mMsg, srcFragment) => {
aixChatMessage = await m.fragments.reduce(async (mMsgPromise, srcFragment) => {
const mMsg = await mMsgPromise;
if (!isContentOrAttachmentFragment(srcFragment))
return mMsg;
switch (srcFragment.part.pt) {
@@ -66,14 +71,23 @@ async function historyToChatGenerateRequest(history: Readonly<DMessage[]>): Prom
break;
case 'image_ref':
console.log('DEV: historyToChatGenerateRequest: image_ref', srcFragment.part);
// const imageDataRef = srcFragment.part.dataRef;
// if (imageDataRef.reftype === 'dblob' && imageDataRef.dblobAssetId) {
// const imageAsset = await getImageAsset(imageDataRef.dblobAssetId);
// }
//
//
//
// mMsg.parts.push({ pt: 'inline_image',mimeType });
const imageDataRef = srcFragment.part.dataRef;
if (imageDataRef.reftype === 'dblob' && imageDataRef.dblobAssetId) {
const imageAsset = await getImageAsset(imageDataRef.dblobAssetId);
if (imageAsset) {
mMsg.parts.push({
pt: 'inline_image',
mimeType: imageDataRef.mimeType || imageAsset.data.mimeType || 'image/png' as any,
base64: imageAsset.data.base64,
});
} else {
console.warn('historyToChatGenerateRequest: image_ref: missing image asset', imageDataRef);
throw new Error('Missing image asset');
}
} else {
console.warn('historyToChatGenerateRequest: image_ref: unexpected data ref', imageDataRef);
throw new Error('Unexpected data ref');
}
break;
case 'doc':
mMsg.parts.push(srcFragment.part);
@@ -82,14 +96,14 @@ async function historyToChatGenerateRequest(history: Readonly<DMessage[]>): Prom
console.warn('historyToChatGenerateRequest: unexpected user fragment part type', srcFragment.part);
}
return mMsg;
}, { role: 'user', parts: [] } as AixChatMessageUser);
}, Promise.resolve({ role: 'user', parts: [] } as AixChatMessageUser));
} else {
console.warn('historyToChatGenerateRequest: unexpected message role', m.role);
}
if (aixChatMessage)
acc.chat.push(aixChatMessage);
return acc;
}, { chat: [] } as AixChatContentGenerateRequest);
}, Promise.resolve({ chat: [] } as AixChatContentGenerateRequest));
}
@@ -0,0 +1,92 @@
import { OpenAIModelSchema } from '~/modules/llms/server/openai/openai.router';
import { AnthropicWireMessageCreate, anthropicWireMessageCreateSchema } from '~/modules/aix/server/dispatch/anthropic/anthropic.wiretypes';
import type { IntakeChatGenerateRequest } from '../../intake/schemas.intake.api';
// Anthropic's Messages API requires `max_tokens`; this is the cap used when the model config doesn't set one.
const DEFAULT_MAX_TOKENS = 1024;

/**
 * Converts an intake chat-generate request into Anthropic's Messages-API wire format.
 *
 * - The system instruction travels in the top-level `system` field, not as a chat message.
 * - `text` and `inline_image` parts are mapped to Anthropic content blocks; `tool_call` /
 *   `tool_response` parts are dropped with a warning (not yet supported by this conversion).
 * - The resulting payload is validated against `anthropicWireMessageCreateSchema` before return.
 *
 * @param model - model configuration (id, maxTokens, temperature) — NOTE(review): typed as
 *   OpenAIModelSchema; presumably a shared shape, verify against callers.
 * @param chatGenerate - the intake request (system message + chat turns) to convert.
 * @param stream - whether to request a streaming response.
 * @returns the validated Anthropic `messages.create` payload.
 * @throws Error when the message sequence cannot form a valid Anthropic request.
 */
export function NEWanthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, chatGenerate: IntakeChatGenerateRequest, stream: boolean): AnthropicWireMessageCreate {

  // Extract the (first text part of the) system message, if any
  const systemMessage = chatGenerate.systemMessage?.parts.find(part => part.pt === 'text')?.text;

  // Transform the chat messages into Anthropic's format
  const messages: AnthropicWireMessageCreate['messages'] = chatGenerate.chat.reduce((acc, message) => {
    const anthropicRole = message.role === 'model' ? 'assistant' : 'user';

    const content = message.parts.map(part => {
      switch (part.pt) {
        case 'text':
          return { type: 'text' as const, text: part.text };
        case 'inline_image':
          return {
            type: 'image' as const,
            source: {
              type: 'base64' as const,
              media_type: part.mimeType,
              data: part.base64,
            },
          };
        case 'tool_call':
        case 'tool_response':
          // These might need special handling depending on Anthropic's API
          console.warn('Tool calls and results are not directly supported in this conversion');
          return null;
        default:
          console.warn(`Unsupported part type: ${(part as any).pt}`);
          return null;
      }
    }).filter(Boolean);

    // skip turns whose parts were all dropped — Anthropic rejects empty content arrays
    if (content.length > 0) {
      acc.push({ role: anthropicRole, content: content as any /* FIXME: filter(Boolean) does not narrow away null in the inferred type */ });
    }
    return acc;
  }, [] as AnthropicWireMessageCreate['messages']);

  // Anthropic requires the conversation to start with a user turn.
  // FIX: the previous fallback used `systemMessage || ''`, which could prepend an EMPTY
  // text block — the Anthropic API rejects empty text content — so the request would only
  // fail remotely. Throw locally instead, honoring the `...OrThrow` contract.
  // NOTE(review): reusing systemMessage here duplicates it (it is also sent via the
  // top-level `system` field below) — confirm this is the intended degraded behavior.
  if (messages.length === 0 || messages[0].role !== 'user') {
    if (!systemMessage)
      throw new Error('Invalid message sequence for Anthropic models: conversation must start with a user message');
    messages.unshift({ role: 'user', content: [{ type: 'text', text: systemMessage }] });
  }

  // Construct the request payload; optional fields are spread in only when present
  const payload: AnthropicWireMessageCreate = {
    model: model.id,
    messages,
    max_tokens: model.maxTokens || DEFAULT_MAX_TOKENS,
    stream,
    ...(model.temperature !== undefined && { temperature: model.temperature }),
    ...(systemMessage && { system: [{ type: 'text', text: systemMessage }] }),
  };

  // // Handle tools and tool policy
  // if (chatGenerate.tools && chatGenerate.tools.length > 0) {
  //   payload.tools = chatGenerate.tools.map(tool => ({
  //     name: tool.name,
  //     description: tool.description,
  //     input_schema: {
  //       type: 'object',
  //       properties: tool.parameters.properties,
  //       required: tool.parameters.required,
  //     },
  //   }));
  //
  //   if (chatGenerate.toolsPolicy) {
  //     switch (chatGenerate.toolsPolicy.type) {
  //       case 'auto':
  //         payload.tool_choice = { type: 'auto' };
  //         break;
  //       case 'any':
  //         payload.tool_choice = { type: 'any' };
  //         break;
  //       case 'force':
  //         payload.tool_choice = { type: 'tool', name: chatGenerate.toolsPolicy.name };
  //         break;
  //     }
  //   }
  // }

  // Validate the payload against the schema to ensure correctness before dispatch
  const validated = anthropicWireMessageCreateSchema.safeParse(payload);
  if (!validated.success)
    throw new Error(`Invalid message sequence for Anthropic models: ${validated.error.errors?.[0]?.message || validated.error}`);

  return validated.data;
}
@@ -1,5 +1,5 @@
import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaChatCompletionPayload } from '~/modules/llms/server/ollama/ollama.router';
import { anthropicAccess, anthropicMessagesPayloadOrThrow } from '~/modules/llms/server/anthropic/anthropic.router';
import { anthropicAccess } from '~/modules/llms/server/anthropic/anthropic.router';
import { geminiAccess, geminiGenerateContentTextPayload } from '~/modules/llms/server/gemini/gemini.router';
import { openAIAccess, openAIChatCompletionPayload, OpenAIHistorySchema } from '~/modules/llms/server/openai/openai.router';
@@ -8,6 +8,7 @@ import type { IntakeAccess, IntakeChatGenerateRequest, IntakeModel } from '../in
import { createDispatchDemuxer } from './dispatch.demuxers';
import { createDispatchParserAnthropicMessages, createDispatchParserGemini, createDispatchParserOllama, createDispatchParserOpenAI, DispatchParser } from './dispatch.parsers';
import { geminiModelsStreamGenerateContentPath } from './gemini/gemini.wiretypes';
import { NEWanthropicMessagesPayloadOrThrow } from '~/modules/aix/server/dispatch/anthropic/anthropic.adapter';
export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGenerate: IntakeChatGenerateRequest): {
@@ -19,52 +20,54 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
// temporarily re-cast back to history
const _hist: OpenAIHistorySchema = [];
chatGenerate.systemMessage?.parts.forEach(systemPart => {
_hist.push({ role: 'system', content: systemPart.text });
});
chatGenerate.chat.forEach(({ role, parts }) => {
switch (role) {
if (access.dialect !== 'anthropic') {
chatGenerate.systemMessage?.parts.forEach(systemPart => {
_hist.push({ role: 'system', content: systemPart.text });
});
chatGenerate.chat.forEach(({ role, parts }) => {
switch (role) {
case 'user':
parts.forEach(userPart => {
switch (userPart.pt) {
case 'text':
_hist.push({ role: 'user', content: userPart.text });
break;
case 'inline_image':
throw new Error('Inline images are not supported');
case 'doc':
_hist.push({ role: 'user', content: userPart.data.text });
break;
case 'meta_reply_to':
throw new Error('Meta reply to is not supported');
}
});
break;
case 'user':
parts.forEach(userPart => {
switch (userPart.pt) {
case 'text':
_hist.push({ role: 'user', content: userPart.text });
break;
case 'inline_image':
throw new Error('Inline images are not supported');
case 'doc':
_hist.push({ role: 'user', content: userPart.data.text });
break;
case 'meta_reply_to':
throw new Error('Meta reply to is not supported');
}
});
break;
case 'model':
parts.forEach(modelPart => {
switch (modelPart.pt) {
case 'text':
_hist.push({ role: 'assistant', content: modelPart.text });
break;
case 'tool_call':
throw new Error('Tool calls are not supported');
}
});
break;
case 'model':
parts.forEach(modelPart => {
switch (modelPart.pt) {
case 'text':
_hist.push({ role: 'assistant', content: modelPart.text });
break;
case 'tool_call':
throw new Error('Tool calls are not supported');
}
});
break;
case 'tool':
parts.forEach(toolPart => {
switch (toolPart.pt) {
case 'tool_response':
throw new Error('Tool responses are not supported');
}
});
break;
}
});
console.log('converted chatGenerate to history', _hist.length, '<- items');
case 'tool':
parts.forEach(toolPart => {
switch (toolPart.pt) {
case 'tool_response':
throw new Error('Tool responses are not supported');
}
});
break;
}
});
console.log('converted chatGenerate to history', _hist.length, '<- items');
}
switch (access.dialect) {
@@ -72,7 +75,7 @@ export function createDispatch(access: IntakeAccess, model: IntakeModel, chatGen
return {
request: {
...anthropicAccess(access, '/v1/messages'),
body: anthropicMessagesPayloadOrThrow(model, _hist, true),
body: NEWanthropicMessagesPayloadOrThrow(model, chatGenerate, true),
},
demuxer: createDispatchDemuxer('sse'),
parser: createDispatchParserAnthropicMessages(),