From 3cef39da17f8b42a5653ebd2bc8afc49008bc3ac Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Fri, 17 May 2024 00:07:50 -0700 Subject: [PATCH] MP: fix Beam, Rays and Fusions --- src/apps/beam/AppBeam.tsx | 2 +- src/common/chats/ConversationHandler.ts | 10 +++++----- src/common/stores/chat/chat.message.ts | 23 +++++++++++++++++------ src/modules/beam/gather/Fusion.tsx | 11 ++++++----- src/modules/beam/gather/beam.gather.ts | 6 ++---- src/modules/beam/scatter/BeamRay.tsx | 11 ++++++----- src/modules/beam/store-beam-vanilla.ts | 4 ++-- 7 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/apps/beam/AppBeam.tsx b/src/apps/beam/AppBeam.tsx index b5a95bc32..5c3b835db 100644 --- a/src/apps/beam/AppBeam.tsx +++ b/src/apps/beam/AppBeam.tsx @@ -22,7 +22,7 @@ function initTestConversation(): DConversation { } function initTestBeamStore(messages: DMessage[], beamStore: BeamStoreApi = createBeamVanillaStore()): BeamStoreApi { - beamStore.getState().open(messages, useModelsStore.getState().chatLLMId, (text) => alert(text)); + beamStore.getState().open(messages, useModelsStore.getState().chatLLMId, (content) => alert(content)); return beamStore; } diff --git a/src/common/chats/ConversationHandler.ts b/src/common/chats/ConversationHandler.ts index a49cc0a6e..3b03791d4 100644 --- a/src/common/chats/ConversationHandler.ts +++ b/src/common/chats/ConversationHandler.ts @@ -7,7 +7,7 @@ import { createBeamVanillaStore } from '~/modules/beam/store-beam-vanilla'; import { ChatActions, getConversationSystemPurposeId, useChatStore } from '~/common/stores/chat/store-chats'; import { DConversationId } from '~/common/stores/chat/chat.conversation'; -import { createDMessage, createTextPart, DMessage, fixmeThisReplacesAllParts, pendDMessage } from '~/common/stores/chat/chat.message'; +import { createDMessage, createTextPart, DContentParts, DMessage, pendDMessage } from '~/common/stores/chat/chat.message'; import { EphemeralHandler, EphemeralsStore } from './EphemeralsStore'; import { createChatOverlayVanillaStore } from './store-chat-overlay-vanilla'; @@ -118,17 +118,17 @@ export class ConversationHandler { beamInvoke(viewHistory: Readonly, importMessages: DMessage[], destReplaceMessageId: DMessage['id'] | null): void { const { open: beamOpen, importRays: beamImportRays, terminateKeepingSettings } = this.beamStore.getState(); - // TODO: we shall get a Message here, rather than a string - it's limiting - const onBeamSuccess = (beamText: string, llmId: DLLMId) => { + const onBeamSuccess = (content: DContentParts, llmId: DLLMId) => { // set output when going back to the chat if (destReplaceMessageId) { // replace a single message in the conversation history - this.messageEdit(destReplaceMessageId, { content: fixmeThisReplacesAllParts(beamText), originLLM: llmId }, true); // [chat] replace assistant:Beam text + this.messageEdit(destReplaceMessageId, { content, originLLM: llmId, pendingIncomplete: undefined, pendingPlaceholderText: undefined }, true); // [chat] replace assistant:Beam contentParts } else { // replace (may truncate) the conversation history and append a message - const newMessage = createDMessage('assistant', beamText); // [chat] append Beam text + const newMessage = createDMessage('assistant', content); // [chat] append Beam contentParts newMessage.originLLM = llmId; newMessage.purposeId = getConversationSystemPurposeId(this.conversationId) ?? undefined; + // TODO: put the other rays in the metadata?! (reqby @Techfren) this.messagesReplace([...viewHistory, newMessage]); } diff --git a/src/common/stores/chat/chat.message.ts b/src/common/stores/chat/chat.message.ts index 88ca2ed61..d93eb2e36 100644 --- a/src/common/stores/chat/chat.message.ts +++ b/src/common/stores/chat/chat.message.ts @@ -9,7 +9,7 @@ export interface DMessage { id: DMessageId; // unique message ID role: DMessageRole; - content: DContentPart[]; // multi-part content (sent: mix of text/images/etc., received: usually one part) + content: DContentParts; // multi-part content (sent: mix of text/images/etc., received: usually one part) userAttachments: DAttachmentPart[]; // higher-level multi-part to be sent (transformed to multipart before sending) // pending state (not stored) @@ -33,6 +33,8 @@ export interface DMessage { updated: number | null; // updated timestamp - null means incomplete - TODO: disambiguate vs pendingIncomplete } +export type DContentParts = DContentPart[]; + export type DMessageId = string; export type DMessageRole = 'user' | 'assistant' | 'system'; @@ -52,7 +54,7 @@ type DContentRef = // Content Part -export type DContentPart = +type DContentPart = | { type: 'text'; text: string } | { type: 'image'; mimeType: string; source: DContentRef } // | { type: 'audio'; mimeType: string; source: DContentRef } @@ -87,7 +89,7 @@ export type DMessageUserFlag = // helpers - creation -export function createDMessage(role: DMessageRole, content?: string | DContentPart[]): DMessage { +export function createDMessage(role: DMessageRole, content?: string | DContentParts): DMessage { // ensure content is an array if (content === undefined) @@ -202,6 +204,15 @@ export function convertDMessage_V3_V4(message: DMessage) { // helpers - text +export function reduceContentToText(content: DContentParts): string { + const partTextSeparator = '\n\n'; + return content.map(part => { + if (part.type === 'text') + return part.text; + return ''; + }).join(partTextSeparator); +} + export function singleTextOrThrow(message: DMessage): string { if (message.content.length !== 1) throw new Error('Expected single content'); @@ -210,7 +221,7 @@ export function singleTextOrThrow(message: DMessage): string { return message.content[0].text; } -export function singleTextOrThrow2(content?: DContentPart[]): string { +export function singleTextOrThrow2(content?: DContentParts): string { if (!content || content.length !== 1) throw new Error('Expected single content'); if (content[0].type !== 'text') @@ -219,7 +230,7 @@ export function singleTextOrThrow2(content?: DContentPart[]): string { } // zustand-like deep replace -export function contentPartsReplaceText(message: DMessage, newText: string): DContentPart[] { +export function contentPartsReplaceText(message: DMessage, newText: string): DContentParts { const lastTextPart = message.content.findLast(part => part.type === 'text'); if (!lastTextPart) return [...message.content, createTextPart(newText)]; @@ -230,7 +241,7 @@ export function contentPartsReplaceText(message: DMessage, newText: string): DCo ); } -export function fixmeThisReplacesAllParts(text: string): DContentPart[] { +export function fixmeThisReplacesAllParts(text: string): DContentParts { return [createTextPart(text)]; } diff --git a/src/modules/beam/gather/Fusion.tsx b/src/modules/beam/gather/Fusion.tsx index 57782e83e..8b381176c 100644 --- a/src/modules/beam/gather/Fusion.tsx +++ b/src/modules/beam/gather/Fusion.tsx @@ -13,6 +13,7 @@ import { GoodTooltip } from '~/common/components/GoodTooltip'; import { InlineError } from '~/common/components/InlineError'; import { animationEnterBelow } from '~/common/util/animUtils'; import { copyToClipboard } from '~/common/util/clipboardUtils'; +import { reduceContentToText } from '~/common/stores/chat/chat.message'; import { BeamCard, beamCardClasses, beamCardMessageScrollingSx, beamCardMessageSx, beamCardMessageWrapperSx } from '../BeamCard'; import { BeamStoreApi, useBeamStore } from '../store-beam.hooks'; @@ -68,16 +69,16 @@ export function Fusion(props: { const handleFusionCopy = React.useCallback(() => { const { fusions } = props.beamStore.getState(); const fusion = fusions.find(fusion => fusion.fusionId === props.fusionId); - if (fusion?.outputDMessage?.text) - copyToClipboard(fusion.outputDMessage.text, 'Merge'); + if (fusion?.outputDMessage?.content.length) + copyToClipboard(reduceContentToText(fusion.outputDMessage.content), 'Merge'); }, [props.beamStore, props.fusionId]); const handleFusionUse = React.useCallback(() => { // get snapshot values, so we don't have to react to the hook const { fusions, onSuccessCallback } = props.beamStore.getState(); const fusion = fusions.find(fusion => fusion.fusionId === props.fusionId); - if (fusion?.outputDMessage?.text && onSuccessCallback) - onSuccessCallback(fusion.outputDMessage.text, fusion.llmId || ''); + if (fusion?.outputDMessage?.content.length && onSuccessCallback) + onSuccessCallback(fusion.outputDMessage.content, fusion.llmId || ''); }, [props.beamStore, props.fusionId]); @@ -137,7 +138,7 @@ export function Fusion(props: { {!!fusion?.fusingInstructionComponent && fusion.fusingInstructionComponent} {/* Output Message */} - {(!!fusion?.outputDMessage?.text || fusion?.stage === 'fusing') && ( + {(!!fusion?.outputDMessage?.content.length || fusion?.stage === 'fusing') && ( {!!fusion.outputDMessage && ( ray.message).filter(message => !!message.text.trim()); + const rayMessages = rays.map(ray => ray.message).filter(message => !!message.content.length); const onUpdate = (update: FusionUpdateOrFn) => _fusionUpdate(fusion.fusionId, update); gatherStartFusion(fusion, chatMessages, rayMessages, onUpdate); }, diff --git a/src/modules/beam/scatter/BeamRay.tsx b/src/modules/beam/scatter/BeamRay.tsx index 1df103ab6..f026a27d5 100644 --- a/src/modules/beam/scatter/BeamRay.tsx +++ b/src/modules/beam/scatter/BeamRay.tsx @@ -18,6 +18,7 @@ import { GoodTooltip } from '~/common/components/GoodTooltip'; import { InlineError } from '~/common/components/InlineError'; import { animationEnterBelow } from '~/common/util/animUtils'; import { copyToClipboard } from '~/common/util/clipboardUtils'; +import { reduceContentToText } from '~/common/stores/chat/chat.message'; import { useLLMSelect } from '~/common/components/forms/useLLMSelect'; import { BeamCard, beamCardClasses, beamCardMessageScrollingSx, beamCardMessageSx, beamCardMessageWrapperSx } from '../BeamCard'; @@ -148,16 +149,16 @@ export function BeamRay(props: { const handleRayCopy = React.useCallback(() => { const { rays } = props.beamStore.getState(); const ray = rays.find(ray => ray.rayId === props.rayId); - if (ray?.message?.text) - copyToClipboard(ray.message.text, 'Beam'); + if (ray?.message.content.length) + copyToClipboard(reduceContentToText(ray.message.content), 'Response'); }, [props.beamStore, props.rayId]); const handleRayUse = React.useCallback(() => { // get snapshot values, so we don't have to react to the hook const { rays, onSuccessCallback } = props.beamStore.getState(); const ray = rays.find(ray => ray.rayId === props.rayId); - if (ray?.message?.text && onSuccessCallback) - onSuccessCallback(ray.message.text, llmId || ''); + if (ray && ray.message.content.length && onSuccessCallback) + onSuccessCallback(ray.message.content, llmId || ''); }, [llmId, props.beamStore, props.rayId]); const handleRayRemove = React.useCallback(() => { @@ -200,7 +201,7 @@ export function BeamRay(props: { {!!ray?.scatterIssue && } {/* Ray Message */} - {(!!ray?.message?.text || ray?.status === 'scattering') && ( + {(!!ray?.message?.content.length || ray?.status === 'scattering') && ( {!!ray.message && ( createStore()((...a) => ( /// Common Store Slice /// -type BeamSuccessCallback = (text: string, llmId: DLLMId) => void; +type BeamSuccessCallback = (content: DContentParts, llmId: DLLMId) => void; interface RootStateSlice {