Beam: Gather: enable customization

This commit is contained in:
Enrico Ros
2024-03-19 18:12:38 -07:00
parent ea109e6c30
commit f171cd4f03
5 changed files with 195 additions and 107 deletions
+4 -3
View File
@@ -29,7 +29,7 @@ export function BeamView(props: {
const {
/* root */ editInputHistoryMessage,
/* scatter */ setRayCount, startScatteringAll, stopScatteringAll,
/* gather */ setFusionIndex, setFusionLlmId, startFusion, stopFusion,
/* gather */ setFusionIndex, setFusionLlmId, fusionCustomize, fusionStart, fusionStop,
} = props.beamStore.getState();
const {
/* root */ inputHistory, inputIssues, inputReady,
@@ -140,8 +140,9 @@ export function BeamView(props: {
isMobile={props.isMobile}
fusionIndex={fusionIndex}
setFusionIndex={setFusionIndex}
onStartFusion={startFusion}
onStopFusion={stopFusion}
onFusionCustomize={fusionCustomize}
onFusionStart={fusionStart}
onFusionStop={fusionStop}
/>
{/* Fusion Output */}
+91 -57
View File
@@ -6,8 +6,10 @@ import { Box, Typography } from '@mui/joy';
import { ChatMessageMemo } from '../../../apps/chat/components/message/ChatMessage';
import { createDMessage } from '~/common/state/store-chats';
import { BeamStoreApi, useBeamStore } from '~/common/beam/store-beam.hooks';
import { GATHER_DEBUG_NONCUSTOM } from '~/common/beam/beam.config';
import type { TInstruction } from './beam.gather';
import { BeamStoreApi, useBeamStore } from '../store-beam.hooks';
import { GATHER_DEBUG_NONCUSTOM } from '../beam.config';
const gatherConfigWrapperSx: SxProps = {
@@ -34,6 +36,79 @@ const configChatInstructionSx: SxProps = {
flex: 1,
};
function InstructionWrapper(props: { children: React.ReactNode }) {
return (
<Box sx={{ display: 'flex', flexDirection: 'column', mx: 'var(--Pad)' }}>
{props.children}
</Box>
);
}
function ReadOnlyInstruction(props: { instruction: TInstruction, isMobile: boolean }) {
const { instruction } = props;
// render 'chat-generate'
if (instruction.type === 'chat-generate') {
return (
<InstructionWrapper>
<Box sx={{ display: 'flex', alignItems: 'center' }}>
<Typography level='body-xs'>
System Prompt:
</Typography>
<ChatMessageMemo
message={createDMessage('assistant', instruction.systemPrompt)}
fitScreen={props.isMobile}
showAvatar={false}
adjustContentScaling={-1}
sx={configChatInstructionSx}
/>
</Box>
<Box sx={{ display: 'flex', alignItems: 'center' }}>
<Typography level='body-xs'>
User Prompt:
</Typography>
<ChatMessageMemo
message={createDMessage('assistant', instruction.userPrompt)}
fitScreen={props.isMobile}
showAvatar={false}
adjustContentScaling={-1}
sx={configChatInstructionSx}
/>
</Box>
</InstructionWrapper>
);
}
// render 'user-input-checklist'
if (instruction.type === 'user-input-checklist') {
return (
<InstructionWrapper>
<Box sx={{ display: 'flex', alignItems: 'center' }}>
<Typography level='body-xs'>
Checklist:
</Typography>
<ChatMessageMemo
message={createDMessage('assistant', '#### Test\n- [ ] test\n- [ ] ...')}
fitScreen={props.isMobile}
showAvatar={false}
adjustContentScaling={-1}
sx={configChatInstructionSx}
/>
</Box>
</InstructionWrapper>
);
}
return (
<InstructionWrapper>
<Typography level='body-xs'>
Unknown Instruction
</Typography>
</InstructionWrapper>
);
}
export function BeamGatherConfig(props: {
beamStore: BeamStoreApi
@@ -41,7 +116,7 @@ export function BeamGatherConfig(props: {
}) {
// state
const [viewInstructionIndex, setViewInstructionIndex] = React.useState(0);
// const [viewInstructionIndex, setViewInstructionIndex] = React.useState(0);
// external state
const fusion = useBeamStore(props.beamStore, store => {
@@ -49,64 +124,23 @@ export function BeamGatherConfig(props: {
return (fusion?.isEditable || GATHER_DEBUG_NONCUSTOM) ? fusion : null;
});
// [effect] sync the fusion program index to the viewInstructionIndex
React.useEffect(() => {
fusion && setViewInstructionIndex(fusion.currentInstructionIndex);
}, [fusion]);
// // [effect] sync the fusion program index to the viewInstructionIndex
// React.useEffect(() => {
// fusion && setViewInstructionIndex(fusion.currentInstructionIndex);
// }, [fusion]);
// derived state
const instruction = React.useMemo(() => {
return fusion?.instructions[viewInstructionIndex] ?? null;
}, [fusion, viewInstructionIndex]);
// const instructions = React.useMemo(() => {
// return fusion?.instructions ?? null;
// }, [fusion]);
const instructions = fusion?.instructions ?? null;
// render instruction
const instructionComponent = React.useMemo(() => {
if (instruction && instruction.type === 'chat-generate') {
return <>
{instruction.systemPrompt && (
<Box sx={{ display: 'flex', alignItems: 'center', mx: 'var(--Pad)' }}>
<Typography level='body-xs'>
System Prompt:
</Typography>
<ChatMessageMemo
message={createDMessage('assistant', instruction.systemPrompt)}
fitScreen={props.isMobile}
showAvatar={false}
adjustContentScaling={-1}
sx={configChatInstructionSx}
/>
</Box>
)}
{instruction.userPrompt && (
<Box sx={{ display: 'flex', alignItems: 'center', ml: 'var(--Pad)' }}>
<Typography level='body-xs'>
User Prompt:
</Typography>
<ChatMessageMemo
message={createDMessage('assistant', instruction.userPrompt)}
fitScreen={props.isMobile}
showAvatar={false}
adjustContentScaling={-1}
sx={configChatInstructionSx}
/>
</Box>
)}
</>;
}
return null;
}, [instruction, props.isMobile]);
if (!instructionComponent)
return null;
return (
return !!instructions?.length ? (
<Box sx={gatherConfigWrapperSx}>
{instructionComponent}
{instructions.map((instruction, stepIndex) =>
<ReadOnlyInstruction key={'step-' + stepIndex} instruction={instruction} isMobile={props.isMobile} />,
)}
</Box>
);
) : null;
}
+25 -10
View File
@@ -1,9 +1,10 @@
import * as React from 'react';
import type { SxProps } from '@mui/joy/styles/types';
import { Box, Button, ButtonGroup, FormControl, SvgIconProps, Typography } from '@mui/joy';
import { Box, Button, ButtonGroup, FormControl, IconButton, SvgIconProps, Tooltip, Typography } from '@mui/joy';
import AutoAwesomeIcon from '@mui/icons-material/AutoAwesome';
import AutoAwesomeOutlinedIcon from '@mui/icons-material/AutoAwesomeOutlined';
import EditRoundedIcon from '@mui/icons-material/EditRounded';
import MergeRoundedIcon from '@mui/icons-material/MergeRounded';
import StopRoundedIcon from '@mui/icons-material/StopRounded';
@@ -15,7 +16,7 @@ import { useScrollToBottom } from '~/common/scroll-to-bottom/useScrollToBottom';
import { GATHER_COLOR } from '../beam.config';
import { beamPaneSx } from '../BeamCard';
import { FUSION_PROGRAMS } from './beam.gather';
import { FUSION_FACTORIES } from './beam.gather';
const mobileBeamGatherPane: SxProps = {
@@ -47,21 +48,25 @@ export function BeamGatherPane(props: {
gatherLlmIcon?: React.FunctionComponent<SvgIconProps>,
fusionIndex: number | null,
setFusionIndex: (index: number | null) => void
onStartFusion: () => void,
onStopFusion: () => void,
onFusionCustomize: (index: number) => void,
onFusionStart: () => void,
onFusionStop: () => void,
}) {
// external state
const { setStickToBottom } = useScrollToBottom();
// derived state
const { gatherCount, gatherEnabled, gatherBusy, setFusionIndex } = props;
const { gatherCount, gatherEnabled, gatherBusy, setFusionIndex, onFusionCustomize } = props;
const handleFusionActivate = React.useCallback((idx: number, shiftPressed: boolean) => {
setStickToBottom(true);
setFusionIndex((idx !== props.fusionIndex || !shiftPressed) ? idx : null);
}, [props.fusionIndex, setFusionIndex, setStickToBottom]);
const handleFusionCustomize = React.useCallback(() => {
props.fusionIndex !== null && onFusionCustomize(props.fusionIndex);
}, [onFusionCustomize, props.fusionIndex]);
const Icon = props.gatherLlmIcon || (gatherBusy ? AutoAwesomeIcon : AutoAwesomeOutlinedIcon);
@@ -85,10 +90,13 @@ export function BeamGatherPane(props: {
{/* Method */}
<FormControl sx={{ my: '-0.25rem' }}>
<FormLabelStart title={<><AutoAwesomeOutlinedIcon sx={{ fontSize: 'md', mr: 0.5 }} />Method</>} sx={{ mb: '0.25rem' /* orig: 6px */ }} />
<Box sx={{ display: 'flex', flexAlign: 'center', gap: 1 }}>
<FormLabelStart
title={<><AutoAwesomeOutlinedIcon sx={{ fontSize: 'md', mr: 0.5 }} />Method</>}
sx={{ mb: '0.25rem' /* orig: 6px */ }}
/>
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
<ButtonGroup variant='outlined'>
{FUSION_PROGRAMS.map((fusion, idx) => {
{FUSION_FACTORIES.map((fusion, idx) => {
const isActive = idx === props.fusionIndex;
return (
<Button
@@ -108,6 +116,13 @@ export function BeamGatherPane(props: {
);
})}
</ButtonGroup>
{(props.fusionIndex !== null) && (
<Tooltip disableInteractive title='Customize This Merge'>
<IconButton size='sm' color='success' disabled={props.gatherBusy || props.fusionIndex === 2} onClick={handleFusionCustomize}>
<EditRoundedIcon />
</IconButton>
</Tooltip>
)}
</Box>
</FormControl>
@@ -123,7 +138,7 @@ export function BeamGatherPane(props: {
variant='solid' color={GATHER_COLOR}
disabled={!gatherEnabled || gatherBusy} loading={gatherBusy}
endDecorator={<MergeRoundedIcon />}
onClick={props.onStartFusion}
onClick={props.onFusionStart}
sx={{ minWidth: 120 }}
>
Merge
@@ -133,7 +148,7 @@ export function BeamGatherPane(props: {
// key='gather-stop'
variant='solid' color='danger'
endDecorator={<StopRoundedIcon />}
onClick={props.onStopFusion}
onClick={props.onFusionStop}
sx={{ minWidth: 120 }}
>
Stop
+68 -31
View File
@@ -9,45 +9,53 @@ import { GATHER_PLACEHOLDER } from '../beam.config';
// Choose, Improve, Fuse, Manual
export const FUSION_PROGRAMS: { label: string, factory: () => BFusion }[] = [
const commonInitialization = (isEditable: boolean): Pick<BFusion,
'isEditable' | 'currentInstructionIndex' | 'llmId' | 'status' | 'outputMessage'
> => ({
isEditable,
currentInstructionIndex: 0,
llmId: null,
status: 'idle',
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER),
});
export const FUSION_FACTORIES: { label: string, factory: () => BFusion }[] = [
{
label: 'Guided', factory: () => ({
label: 'Guided',
factory: () => ({
instructions: [{
type: 'chat-generate',
systemPrompt: 'You are',
userPrompt: 'Perform this',
systemPrompt: 'You arfe',
userPrompt: 'Perform thiws',
outputType: 'fin',
}, {
type: 'user-input-checklist',
}],
currentInstructionIndex: 0,
isEditable: false,
llmId: null,
status: 'idle',
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER),
...commonInitialization(false),
}),
},
{
label: 'Fuse', factory: () => ({
label: 'Fuse',
factory: () => ({
instructions: [{
type: 'chat-generate',
systemPrompt: 'You are an editor',
userPrompt: 'Best of all',
outputType: 'fin',
}],
currentInstructionIndex: 0,
isEditable: false,
llmId: null,
status: 'idle',
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER + '2'),
...commonInitialization(false),
}),
},
{
label: 'Custom', factory: () => ({
instructions: [],
currentInstructionIndex: 0,
isEditable: true,
llmId: null,
status: 'idle',
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER + '3'),
label: 'Custom',
factory: () => ({
instructions: [{
type: 'chat-generate',
systemPrompt: 'You are a custom editor',
userPrompt: 'Best of all',
outputType: 'fin',
}],
...commonInitialization(true),
}),
},
];
@@ -73,7 +81,7 @@ export function fusionGatherStop(fusion: BFusion): BFusion {
/// Gather Store Slice ///
type TInstruction = {
export type TInstruction = {
type: 'chat-generate',
systemPrompt: string;
userPrompt: string;
@@ -84,14 +92,14 @@ type TInstruction = {
export interface BFusion {
// set at creation, adjusted later if this is a custom fusion (and only when idle)
isEditable: boolean; // only true on a single custom fusion
instructions: TInstruction[];
currentInstructionIndex: number;
isEditable: boolean;
// set at lifecycle
// set at start
llmId: DLLMId | null;
// variable
currentInstructionIndex: number; // points to the next instruction to execute
status: 'idle' | 'fusing' | 'success' | 'stopped' | 'error';
outputMessage: DMessage;
issue?: string;
@@ -115,7 +123,7 @@ export const reInitGatherStateSlice = (prevFusions: BFusion[]): GatherStateSlice
return {
// recreate all fusions (no recycle)
fusions: FUSION_PROGRAMS.map(spec => spec.factory()),
fusions: FUSION_FACTORIES.map(spec => spec.factory()),
fusionIndex: null,
fusionLlmId: null,
isGathering: false,
@@ -126,8 +134,12 @@ export interface GatherStoreSlice extends GatherStateSlice {
setFusionIndex: (index: number | null) => void;
setFusionLlmId: (llmId: DLLMId | null) => void;
startFusion: () => void;
stopFusion: () => void;
fusionCustomize: (sourceIndex: number) => void;
fusionStart: () => void;
fusionStop: () => void;
_fusionUpdate: (fusionIndex: number, update: Partial<BFusion> | ((fusion: BFusion) => (Partial<BFusion> | null))) => void;
}
@@ -147,12 +159,37 @@ export const createGatherSlice: StateCreator<GatherStoreSlice, [], [], GatherSto
fusionLlmId: llmId,
}),
startFusion: () => {
fusionCustomize: (sourceIndex: number) => {
const { fusions, setFusionIndex, _fusionUpdate } = _get();
const editableFusionIndex = fusions.findIndex(fusion => fusion.isEditable);
const fusionFactory = FUSION_FACTORIES[sourceIndex];
if (editableFusionIndex === -1 || editableFusionIndex === sourceIndex || !fusionFactory)
return;
_fusionUpdate(editableFusionIndex, customFusion => {
// Terminate current custom fusion, if any
fusionGatherStop(customFusion);
return {
...fusionFactory.factory(),
isEditable: true,
};
});
setFusionIndex(editableFusionIndex);
},
fusionStart: () => {
console.log('startGatheringCurrent');
},
stopFusion: () => {
fusionStop: () => {
console.log('stopGatheringCurrent');
},
_fusionUpdate: (fusionIndex: number, update: Partial<BFusion> | ((fusion: BFusion) => (Partial<BFusion> | null))) =>
_set(state => ({
fusions: state.fusions.map((fusion, index) => (index === fusionIndex)
? { ...fusion, ...(typeof update === 'function' ? update(fusion) : update) }
: fusion,
),
})),
});
+7 -6
View File
@@ -264,12 +264,13 @@ export const createScatterSlice: StateCreator<RootStoreSlice & GatherStoreSlice
scatterLlmId: llmId,
}),
_rayUpdate: (rayId: BRayId, update: Partial<BRay> | ((ray: BRay) => Partial<BRay>)) => _set(state => ({
rays: state.rays.map(ray => (ray.rayId === rayId)
? { ...ray, ...(typeof update === 'function' ? update(ray) : update) }
: ray,
),
})),
_rayUpdate: (rayId: BRayId, update: Partial<BRay> | ((ray: BRay) => Partial<BRay>)) =>
_set(state => ({
rays: state.rays.map(ray => (ray.rayId === rayId)
? { ...ray, ...(typeof update === 'function' ? update(ray) : update) }
: ray,
),
})),
syncRaysStateToBeam: () => {