MP: fix Beam, Rays and Fusions

This commit is contained in:
Enrico Ros
2024-05-17 00:07:50 -07:00
parent 3aea29bcb5
commit 3cef39da17
7 changed files with 39 additions and 28 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ function initTestConversation(): DConversation {
}
function initTestBeamStore(messages: DMessage[], beamStore: BeamStoreApi = createBeamVanillaStore()): BeamStoreApi {
beamStore.getState().open(messages, useModelsStore.getState().chatLLMId, (text) => alert(text));
beamStore.getState().open(messages, useModelsStore.getState().chatLLMId, (content) => alert(content));
return beamStore;
}
+5 -5
View File
@@ -7,7 +7,7 @@ import { createBeamVanillaStore } from '~/modules/beam/store-beam-vanilla';
import { ChatActions, getConversationSystemPurposeId, useChatStore } from '~/common/stores/chat/store-chats';
import { DConversationId } from '~/common/stores/chat/chat.conversation';
import { createDMessage, createTextPart, DMessage, fixmeThisReplacesAllParts, pendDMessage } from '~/common/stores/chat/chat.message';
import { createDMessage, createTextPart, DContentParts, DMessage, pendDMessage } from '~/common/stores/chat/chat.message';
import { EphemeralHandler, EphemeralsStore } from './EphemeralsStore';
import { createChatOverlayVanillaStore } from './store-chat-overlay-vanilla';
@@ -118,17 +118,17 @@ export class ConversationHandler {
beamInvoke(viewHistory: Readonly<DMessage[]>, importMessages: DMessage[], destReplaceMessageId: DMessage['id'] | null): void {
const { open: beamOpen, importRays: beamImportRays, terminateKeepingSettings } = this.beamStore.getState();
// TODO: we shall get a Message here, rather than a string - it's limiting
const onBeamSuccess = (beamText: string, llmId: DLLMId) => {
const onBeamSuccess = (content: DContentParts, llmId: DLLMId) => {
// set output when going back to the chat
if (destReplaceMessageId) {
// replace a single message in the conversation history
this.messageEdit(destReplaceMessageId, { content: fixmeThisReplacesAllParts(beamText), originLLM: llmId }, true); // [chat] replace assistant:Beam text
this.messageEdit(destReplaceMessageId, { content, originLLM: llmId, pendingIncomplete: undefined, pendingPlaceholderText: undefined }, true); // [chat] replace assistant:Beam contentParts
} else {
// replace (may truncate) the conversation history and append a message
const newMessage = createDMessage('assistant', beamText); // [chat] append Beam text
const newMessage = createDMessage('assistant', content); // [chat] append Beam contentParts
newMessage.originLLM = llmId;
newMessage.purposeId = getConversationSystemPurposeId(this.conversationId) ?? undefined;
// TODO: put the other rays in the metadata?! (reqby @Techfren)
this.messagesReplace([...viewHistory, newMessage]);
}
+17 -6
View File
@@ -9,7 +9,7 @@ export interface DMessage {
id: DMessageId; // unique message ID
role: DMessageRole;
content: DContentPart[]; // multi-part content (sent: mix of text/images/etc., received: usually one part)
content: DContentParts; // multi-part content (sent: mix of text/images/etc., received: usually one part)
userAttachments: DAttachmentPart[]; // higher-level multi-part to be sent (transformed to multipart before sending)
// pending state (not stored)
@@ -33,6 +33,8 @@ export interface DMessage {
updated: number | null; // updated timestamp - null means incomplete - TODO: disambiguate vs pendingIncomplete
}
export type DContentParts = DContentPart[];
export type DMessageId = string;
export type DMessageRole = 'user' | 'assistant' | 'system';
@@ -52,7 +54,7 @@ type DContentRef =
// Content Part
export type DContentPart =
type DContentPart =
| { type: 'text'; text: string }
| { type: 'image'; mimeType: string; source: DContentRef }
// | { type: 'audio'; mimeType: string; source: DContentRef }
@@ -87,7 +89,7 @@ export type DMessageUserFlag =
// helpers - creation
export function createDMessage(role: DMessageRole, content?: string | DContentPart[]): DMessage {
export function createDMessage(role: DMessageRole, content?: string | DContentParts): DMessage {
// ensure content is an array
if (content === undefined)
@@ -202,6 +204,15 @@ export function convertDMessage_V3_V4(message: DMessage) {
// helpers - text
export function reduceContentToText(content: DContentParts): string {
const partTextSeparator = '\n\n';
return content.map(part => {
if (part.type === 'text')
return part.text;
return '';
}).join(partTextSeparator);
}
export function singleTextOrThrow(message: DMessage): string {
if (message.content.length !== 1)
throw new Error('Expected single content');
@@ -210,7 +221,7 @@ export function singleTextOrThrow(message: DMessage): string {
return message.content[0].text;
}
export function singleTextOrThrow2(content?: DContentPart[]): string {
export function singleTextOrThrow2(content?: DContentParts): string {
if (!content || content.length !== 1)
throw new Error('Expected single content');
if (content[0].type !== 'text')
@@ -219,7 +230,7 @@ export function singleTextOrThrow2(content?: DContentPart[]): string {
}
// zustand-like deep replace
export function contentPartsReplaceText(message: DMessage, newText: string): DContentPart[] {
export function contentPartsReplaceText(message: DMessage, newText: string): DContentParts {
const lastTextPart = message.content.findLast(part => part.type === 'text');
if (!lastTextPart)
return [...message.content, createTextPart(newText)];
@@ -230,7 +241,7 @@ export function contentPartsReplaceText(message: DMessage, newText: string): DCo
);
}
export function fixmeThisReplacesAllParts(text: string): DContentPart[] {
export function fixmeThisReplacesAllParts(text: string): DContentParts {
return [createTextPart(text)];
}
+6 -5
View File
@@ -13,6 +13,7 @@ import { GoodTooltip } from '~/common/components/GoodTooltip';
import { InlineError } from '~/common/components/InlineError';
import { animationEnterBelow } from '~/common/util/animUtils';
import { copyToClipboard } from '~/common/util/clipboardUtils';
import { reduceContentToText } from '~/common/stores/chat/chat.message';
import { BeamCard, beamCardClasses, beamCardMessageScrollingSx, beamCardMessageSx, beamCardMessageWrapperSx } from '../BeamCard';
import { BeamStoreApi, useBeamStore } from '../store-beam.hooks';
@@ -68,16 +69,16 @@ export function Fusion(props: {
const handleFusionCopy = React.useCallback(() => {
const { fusions } = props.beamStore.getState();
const fusion = fusions.find(fusion => fusion.fusionId === props.fusionId);
if (fusion?.outputDMessage?.text)
copyToClipboard(fusion.outputDMessage.text, 'Merge');
if (fusion?.outputDMessage?.content.length)
copyToClipboard(reduceContentToText(fusion.outputDMessage.content), 'Merge');
}, [props.beamStore, props.fusionId]);
const handleFusionUse = React.useCallback(() => {
// get snapshot values, so we don't have to react to the hook
const { fusions, onSuccessCallback } = props.beamStore.getState();
const fusion = fusions.find(fusion => fusion.fusionId === props.fusionId);
if (fusion?.outputDMessage?.text && onSuccessCallback)
onSuccessCallback(fusion.outputDMessage.text, fusion.llmId || '');
if (fusion?.outputDMessage?.content.length && onSuccessCallback)
onSuccessCallback(fusion.outputDMessage.content, fusion.llmId || '');
}, [props.beamStore, props.fusionId]);
@@ -137,7 +138,7 @@ export function Fusion(props: {
{!!fusion?.fusingInstructionComponent && fusion.fusingInstructionComponent}
{/* Output Message */}
{(!!fusion?.outputDMessage?.text || fusion?.stage === 'fusing') && (
{(!!fusion?.outputDMessage?.content.length || fusion?.stage === 'fusing') && (
<Box sx={beamCardMessageWrapperSx}>
{!!fusion.outputDMessage && (
<ChatMessageMemo
+2 -4
View File
@@ -7,7 +7,6 @@ import type { DLLMId } from '~/modules/llms/store-llms';
import type { DMessage } from '~/common/stores/chat/chat.message';
import { CUSTOM_FACTORY_ID, FFactoryId, findFusionFactory, FUSION_FACTORIES, FUSION_FACTORY_DEFAULT } from './instructions/beam.gather.factories';
import { GATHER_PLACEHOLDER } from '../beam.config';
import { RootStoreSlice } from '../store-beam-vanilla';
import { ScatterStoreSlice } from '../scatter/beam.scatter';
import { gatherStartFusion, gatherStopFusion, Instruction } from './instructions/beam.gather.execution';
@@ -84,8 +83,7 @@ export function fusionIsStopped(fusion: BFusion | null): boolean {
}
export function fusionIsUsableOutput(fusion: BFusion | null): boolean {
const message = fusion?.outputDMessage ?? null;
return !!message && !!message.updated && !!message.text && message.text !== GATHER_PLACEHOLDER;
return !!fusion?.outputDMessage?.content.length;
}
export function fusionIsError(fusion: BFusion | null): boolean {
@@ -261,7 +259,7 @@ export const createGatherSlice: StateCreator<RootStoreSlice & ScatterStoreSlice
// start the fusion
const { inputHistory, rays, _fusionUpdate } = _get();
const chatMessages = inputHistory ? [...inputHistory] : [];
const rayMessages = rays.map(ray => ray.message).filter(message => !!message.text.trim());
const rayMessages = rays.map(ray => ray.message).filter(message => !!message.content.length);
const onUpdate = (update: FusionUpdateOrFn) => _fusionUpdate(fusion.fusionId, update);
gatherStartFusion(fusion, chatMessages, rayMessages, onUpdate);
},
+6 -5
View File
@@ -18,6 +18,7 @@ import { GoodTooltip } from '~/common/components/GoodTooltip';
import { InlineError } from '~/common/components/InlineError';
import { animationEnterBelow } from '~/common/util/animUtils';
import { copyToClipboard } from '~/common/util/clipboardUtils';
import { reduceContentToText } from '~/common/stores/chat/chat.message';
import { useLLMSelect } from '~/common/components/forms/useLLMSelect';
import { BeamCard, beamCardClasses, beamCardMessageScrollingSx, beamCardMessageSx, beamCardMessageWrapperSx } from '../BeamCard';
@@ -148,16 +149,16 @@ export function BeamRay(props: {
const handleRayCopy = React.useCallback(() => {
const { rays } = props.beamStore.getState();
const ray = rays.find(ray => ray.rayId === props.rayId);
if (ray?.message?.text)
copyToClipboard(ray.message.text, 'Beam');
if (ray?.message.content.length)
copyToClipboard(reduceContentToText(ray.message.content), 'Response');
}, [props.beamStore, props.rayId]);
const handleRayUse = React.useCallback(() => {
// get snapshot values, so we don't have to react to the hook
const { rays, onSuccessCallback } = props.beamStore.getState();
const ray = rays.find(ray => ray.rayId === props.rayId);
if (ray?.message?.text && onSuccessCallback)
onSuccessCallback(ray.message.text, llmId || '');
if (ray && ray.message.content.length && onSuccessCallback)
onSuccessCallback(ray.message.content, llmId || '');
}, [llmId, props.beamStore, props.rayId]);
const handleRayRemove = React.useCallback(() => {
@@ -200,7 +201,7 @@ export function BeamRay(props: {
{!!ray?.scatterIssue && <InlineError error={ray.scatterIssue} />}
{/* Ray Message */}
{(!!ray?.message?.text || ray?.status === 'scattering') && (
{(!!ray?.message?.content.length || ray?.status === 'scattering') && (
<Box sx={beamCardMessageWrapperSx}>
{!!ray.message && (
<ChatMessageMemo
+2 -2
View File
@@ -2,7 +2,7 @@ import { createStore, StateCreator } from 'zustand/vanilla';
import { DLLMId, getDiverseTopLlmIds } from '~/modules/llms/store-llms';
import { DMessage } from '~/common/stores/chat/chat.message';
import { DContentParts, DMessage } from '~/common/stores/chat/chat.message';
import { BeamConfigSnapshot, useModuleBeamStore } from './store-module-beam';
import { SCATTER_RAY_DEF } from './beam.config';
@@ -26,7 +26,7 @@ export const createBeamVanillaStore = () => createStore<BeamStore>()((...a) => (
/// Common Store Slice ///
type BeamSuccessCallback = (text: string, llmId: DLLMId) => void;
type BeamSuccessCallback = (content: DContentParts, llmId: DLLMId) => void;
interface RootStateSlice {