mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-10 21:50:14 -07:00
MP: fix Beam, Rays and Fusions
This commit is contained in:
@@ -22,7 +22,7 @@ function initTestConversation(): DConversation {
|
||||
}
|
||||
|
||||
function initTestBeamStore(messages: DMessage[], beamStore: BeamStoreApi = createBeamVanillaStore()): BeamStoreApi {
|
||||
beamStore.getState().open(messages, useModelsStore.getState().chatLLMId, (text) => alert(text));
|
||||
beamStore.getState().open(messages, useModelsStore.getState().chatLLMId, (content) => alert(content));
|
||||
return beamStore;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import { createBeamVanillaStore } from '~/modules/beam/store-beam-vanilla';
|
||||
|
||||
import { ChatActions, getConversationSystemPurposeId, useChatStore } from '~/common/stores/chat/store-chats';
|
||||
import { DConversationId } from '~/common/stores/chat/chat.conversation';
|
||||
import { createDMessage, createTextPart, DMessage, fixmeThisReplacesAllParts, pendDMessage } from '~/common/stores/chat/chat.message';
|
||||
import { createDMessage, createTextPart, DContentParts, DMessage, pendDMessage } from '~/common/stores/chat/chat.message';
|
||||
|
||||
import { EphemeralHandler, EphemeralsStore } from './EphemeralsStore';
|
||||
import { createChatOverlayVanillaStore } from './store-chat-overlay-vanilla';
|
||||
@@ -118,17 +118,17 @@ export class ConversationHandler {
|
||||
beamInvoke(viewHistory: Readonly<DMessage[]>, importMessages: DMessage[], destReplaceMessageId: DMessage['id'] | null): void {
|
||||
const { open: beamOpen, importRays: beamImportRays, terminateKeepingSettings } = this.beamStore.getState();
|
||||
|
||||
// TODO: we shall get a Message here, rather than a string - it's limiting
|
||||
const onBeamSuccess = (beamText: string, llmId: DLLMId) => {
|
||||
const onBeamSuccess = (content: DContentParts, llmId: DLLMId) => {
|
||||
// set output when going back to the chat
|
||||
if (destReplaceMessageId) {
|
||||
// replace a single message in the conversation history
|
||||
this.messageEdit(destReplaceMessageId, { content: fixmeThisReplacesAllParts(beamText), originLLM: llmId }, true); // [chat] replace assistant:Beam text
|
||||
this.messageEdit(destReplaceMessageId, { content, originLLM: llmId, pendingIncomplete: undefined, pendingPlaceholderText: undefined }, true); // [chat] replace assistant:Beam contentParts
|
||||
} else {
|
||||
// replace (may truncate) the conversation history and append a message
|
||||
const newMessage = createDMessage('assistant', beamText); // [chat] append Beam text
|
||||
const newMessage = createDMessage('assistant', content); // [chat] append Beam contentParts
|
||||
newMessage.originLLM = llmId;
|
||||
newMessage.purposeId = getConversationSystemPurposeId(this.conversationId) ?? undefined;
|
||||
// TODO: put the other rays in the metadata?! (reqby @Techfren)
|
||||
this.messagesReplace([...viewHistory, newMessage]);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ export interface DMessage {
|
||||
id: DMessageId; // unique message ID
|
||||
|
||||
role: DMessageRole;
|
||||
content: DContentPart[]; // multi-part content (sent: mix of text/images/etc., received: usually one part)
|
||||
content: DContentParts; // multi-part content (sent: mix of text/images/etc., received: usually one part)
|
||||
userAttachments: DAttachmentPart[]; // higher-level multi-part to be sent (transformed to multipart before sending)
|
||||
|
||||
// pending state (not stored)
|
||||
@@ -33,6 +33,8 @@ export interface DMessage {
|
||||
updated: number | null; // updated timestamp - null means incomplete - TODO: disambiguate vs pendingIncomplete
|
||||
}
|
||||
|
||||
export type DContentParts = DContentPart[];
|
||||
|
||||
export type DMessageId = string;
|
||||
|
||||
export type DMessageRole = 'user' | 'assistant' | 'system';
|
||||
@@ -52,7 +54,7 @@ type DContentRef =
|
||||
|
||||
// Content Part
|
||||
|
||||
export type DContentPart =
|
||||
type DContentPart =
|
||||
| { type: 'text'; text: string }
|
||||
| { type: 'image'; mimeType: string; source: DContentRef }
|
||||
// | { type: 'audio'; mimeType: string; source: DContentRef }
|
||||
@@ -87,7 +89,7 @@ export type DMessageUserFlag =
|
||||
|
||||
// helpers - creation
|
||||
|
||||
export function createDMessage(role: DMessageRole, content?: string | DContentPart[]): DMessage {
|
||||
export function createDMessage(role: DMessageRole, content?: string | DContentParts): DMessage {
|
||||
|
||||
// ensure content is an array
|
||||
if (content === undefined)
|
||||
@@ -202,6 +204,15 @@ export function convertDMessage_V3_V4(message: DMessage) {
|
||||
|
||||
// helpers - text
|
||||
|
||||
export function reduceContentToText(content: DContentParts): string {
|
||||
const partTextSeparator = '\n\n';
|
||||
return content.map(part => {
|
||||
if (part.type === 'text')
|
||||
return part.text;
|
||||
return '';
|
||||
}).join(partTextSeparator);
|
||||
}
|
||||
|
||||
export function singleTextOrThrow(message: DMessage): string {
|
||||
if (message.content.length !== 1)
|
||||
throw new Error('Expected single content');
|
||||
@@ -210,7 +221,7 @@ export function singleTextOrThrow(message: DMessage): string {
|
||||
return message.content[0].text;
|
||||
}
|
||||
|
||||
export function singleTextOrThrow2(content?: DContentPart[]): string {
|
||||
export function singleTextOrThrow2(content?: DContentParts): string {
|
||||
if (!content || content.length !== 1)
|
||||
throw new Error('Expected single content');
|
||||
if (content[0].type !== 'text')
|
||||
@@ -219,7 +230,7 @@ export function singleTextOrThrow2(content?: DContentPart[]): string {
|
||||
}
|
||||
|
||||
// zustand-like deep replace
|
||||
export function contentPartsReplaceText(message: DMessage, newText: string): DContentPart[] {
|
||||
export function contentPartsReplaceText(message: DMessage, newText: string): DContentParts {
|
||||
const lastTextPart = message.content.findLast(part => part.type === 'text');
|
||||
if (!lastTextPart)
|
||||
return [...message.content, createTextPart(newText)];
|
||||
@@ -230,7 +241,7 @@ export function contentPartsReplaceText(message: DMessage, newText: string): DCo
|
||||
);
|
||||
}
|
||||
|
||||
export function fixmeThisReplacesAllParts(text: string): DContentPart[] {
|
||||
export function fixmeThisReplacesAllParts(text: string): DContentParts {
|
||||
return [createTextPart(text)];
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import { GoodTooltip } from '~/common/components/GoodTooltip';
|
||||
import { InlineError } from '~/common/components/InlineError';
|
||||
import { animationEnterBelow } from '~/common/util/animUtils';
|
||||
import { copyToClipboard } from '~/common/util/clipboardUtils';
|
||||
import { reduceContentToText } from '~/common/stores/chat/chat.message';
|
||||
|
||||
import { BeamCard, beamCardClasses, beamCardMessageScrollingSx, beamCardMessageSx, beamCardMessageWrapperSx } from '../BeamCard';
|
||||
import { BeamStoreApi, useBeamStore } from '../store-beam.hooks';
|
||||
@@ -68,16 +69,16 @@ export function Fusion(props: {
|
||||
const handleFusionCopy = React.useCallback(() => {
|
||||
const { fusions } = props.beamStore.getState();
|
||||
const fusion = fusions.find(fusion => fusion.fusionId === props.fusionId);
|
||||
if (fusion?.outputDMessage?.text)
|
||||
copyToClipboard(fusion.outputDMessage.text, 'Merge');
|
||||
if (fusion?.outputDMessage?.content.length)
|
||||
copyToClipboard(reduceContentToText(fusion.outputDMessage.content), 'Merge');
|
||||
}, [props.beamStore, props.fusionId]);
|
||||
|
||||
const handleFusionUse = React.useCallback(() => {
|
||||
// get snapshot values, so we don't have to react to the hook
|
||||
const { fusions, onSuccessCallback } = props.beamStore.getState();
|
||||
const fusion = fusions.find(fusion => fusion.fusionId === props.fusionId);
|
||||
if (fusion?.outputDMessage?.text && onSuccessCallback)
|
||||
onSuccessCallback(fusion.outputDMessage.text, fusion.llmId || '');
|
||||
if (fusion?.outputDMessage?.content.length && onSuccessCallback)
|
||||
onSuccessCallback(fusion.outputDMessage.content, fusion.llmId || '');
|
||||
}, [props.beamStore, props.fusionId]);
|
||||
|
||||
|
||||
@@ -137,7 +138,7 @@ export function Fusion(props: {
|
||||
{!!fusion?.fusingInstructionComponent && fusion.fusingInstructionComponent}
|
||||
|
||||
{/* Output Message */}
|
||||
{(!!fusion?.outputDMessage?.text || fusion?.stage === 'fusing') && (
|
||||
{(!!fusion?.outputDMessage?.content.length || fusion?.stage === 'fusing') && (
|
||||
<Box sx={beamCardMessageWrapperSx}>
|
||||
{!!fusion.outputDMessage && (
|
||||
<ChatMessageMemo
|
||||
|
||||
@@ -7,7 +7,6 @@ import type { DLLMId } from '~/modules/llms/store-llms';
|
||||
import type { DMessage } from '~/common/stores/chat/chat.message';
|
||||
|
||||
import { CUSTOM_FACTORY_ID, FFactoryId, findFusionFactory, FUSION_FACTORIES, FUSION_FACTORY_DEFAULT } from './instructions/beam.gather.factories';
|
||||
import { GATHER_PLACEHOLDER } from '../beam.config';
|
||||
import { RootStoreSlice } from '../store-beam-vanilla';
|
||||
import { ScatterStoreSlice } from '../scatter/beam.scatter';
|
||||
import { gatherStartFusion, gatherStopFusion, Instruction } from './instructions/beam.gather.execution';
|
||||
@@ -84,8 +83,7 @@ export function fusionIsStopped(fusion: BFusion | null): boolean {
|
||||
}
|
||||
|
||||
export function fusionIsUsableOutput(fusion: BFusion | null): boolean {
|
||||
const message = fusion?.outputDMessage ?? null;
|
||||
return !!message && !!message.updated && !!message.text && message.text !== GATHER_PLACEHOLDER;
|
||||
return !!fusion?.outputDMessage?.content.length;
|
||||
}
|
||||
|
||||
export function fusionIsError(fusion: BFusion | null): boolean {
|
||||
@@ -261,7 +259,7 @@ export const createGatherSlice: StateCreator<RootStoreSlice & ScatterStoreSlice
|
||||
// start the fusion
|
||||
const { inputHistory, rays, _fusionUpdate } = _get();
|
||||
const chatMessages = inputHistory ? [...inputHistory] : [];
|
||||
const rayMessages = rays.map(ray => ray.message).filter(message => !!message.text.trim());
|
||||
const rayMessages = rays.map(ray => ray.message).filter(message => !!message.content.length);
|
||||
const onUpdate = (update: FusionUpdateOrFn) => _fusionUpdate(fusion.fusionId, update);
|
||||
gatherStartFusion(fusion, chatMessages, rayMessages, onUpdate);
|
||||
},
|
||||
|
||||
@@ -18,6 +18,7 @@ import { GoodTooltip } from '~/common/components/GoodTooltip';
|
||||
import { InlineError } from '~/common/components/InlineError';
|
||||
import { animationEnterBelow } from '~/common/util/animUtils';
|
||||
import { copyToClipboard } from '~/common/util/clipboardUtils';
|
||||
import { reduceContentToText } from '~/common/stores/chat/chat.message';
|
||||
import { useLLMSelect } from '~/common/components/forms/useLLMSelect';
|
||||
|
||||
import { BeamCard, beamCardClasses, beamCardMessageScrollingSx, beamCardMessageSx, beamCardMessageWrapperSx } from '../BeamCard';
|
||||
@@ -148,16 +149,16 @@ export function BeamRay(props: {
|
||||
const handleRayCopy = React.useCallback(() => {
|
||||
const { rays } = props.beamStore.getState();
|
||||
const ray = rays.find(ray => ray.rayId === props.rayId);
|
||||
if (ray?.message?.text)
|
||||
copyToClipboard(ray.message.text, 'Beam');
|
||||
if (ray?.message.content.length)
|
||||
copyToClipboard(reduceContentToText(ray.message.content), 'Response');
|
||||
}, [props.beamStore, props.rayId]);
|
||||
|
||||
const handleRayUse = React.useCallback(() => {
|
||||
// get snapshot values, so we don't have to react to the hook
|
||||
const { rays, onSuccessCallback } = props.beamStore.getState();
|
||||
const ray = rays.find(ray => ray.rayId === props.rayId);
|
||||
if (ray?.message?.text && onSuccessCallback)
|
||||
onSuccessCallback(ray.message.text, llmId || '');
|
||||
if (ray && ray.message.content.length && onSuccessCallback)
|
||||
onSuccessCallback(ray.message.content, llmId || '');
|
||||
}, [llmId, props.beamStore, props.rayId]);
|
||||
|
||||
const handleRayRemove = React.useCallback(() => {
|
||||
@@ -200,7 +201,7 @@ export function BeamRay(props: {
|
||||
{!!ray?.scatterIssue && <InlineError error={ray.scatterIssue} />}
|
||||
|
||||
{/* Ray Message */}
|
||||
{(!!ray?.message?.text || ray?.status === 'scattering') && (
|
||||
{(!!ray?.message?.content.length || ray?.status === 'scattering') && (
|
||||
<Box sx={beamCardMessageWrapperSx}>
|
||||
{!!ray.message && (
|
||||
<ChatMessageMemo
|
||||
|
||||
@@ -2,7 +2,7 @@ import { createStore, StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { DLLMId, getDiverseTopLlmIds } from '~/modules/llms/store-llms';
|
||||
|
||||
import { DMessage } from '~/common/stores/chat/chat.message';
|
||||
import { DContentParts, DMessage } from '~/common/stores/chat/chat.message';
|
||||
|
||||
import { BeamConfigSnapshot, useModuleBeamStore } from './store-module-beam';
|
||||
import { SCATTER_RAY_DEF } from './beam.config';
|
||||
@@ -26,7 +26,7 @@ export const createBeamVanillaStore = () => createStore<BeamStore>()((...a) => (
|
||||
|
||||
/// Common Store Slice ///
|
||||
|
||||
type BeamSuccessCallback = (text: string, llmId: DLLMId) => void;
|
||||
type BeamSuccessCallback = (content: DContentParts, llmId: DLLMId) => void;
|
||||
|
||||
interface RootStateSlice {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user