mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-10 21:50:14 -07:00
Beam: Gather: enable customization
This commit is contained in:
@@ -29,7 +29,7 @@ export function BeamView(props: {
|
||||
const {
|
||||
/* root */ editInputHistoryMessage,
|
||||
/* scatter */ setRayCount, startScatteringAll, stopScatteringAll,
|
||||
/* gather */ setFusionIndex, setFusionLlmId, startFusion, stopFusion,
|
||||
/* gather */ setFusionIndex, setFusionLlmId, fusionCustomize, fusionStart, fusionStop,
|
||||
} = props.beamStore.getState();
|
||||
const {
|
||||
/* root */ inputHistory, inputIssues, inputReady,
|
||||
@@ -140,8 +140,9 @@ export function BeamView(props: {
|
||||
isMobile={props.isMobile}
|
||||
fusionIndex={fusionIndex}
|
||||
setFusionIndex={setFusionIndex}
|
||||
onStartFusion={startFusion}
|
||||
onStopFusion={stopFusion}
|
||||
onFusionCustomize={fusionCustomize}
|
||||
onFusionStart={fusionStart}
|
||||
onFusionStop={fusionStop}
|
||||
/>
|
||||
|
||||
{/* Fusion Output */}
|
||||
|
||||
@@ -6,8 +6,10 @@ import { Box, Typography } from '@mui/joy';
|
||||
import { ChatMessageMemo } from '../../../apps/chat/components/message/ChatMessage';
|
||||
|
||||
import { createDMessage } from '~/common/state/store-chats';
|
||||
import { BeamStoreApi, useBeamStore } from '~/common/beam/store-beam.hooks';
|
||||
import { GATHER_DEBUG_NONCUSTOM } from '~/common/beam/beam.config';
|
||||
|
||||
import type { TInstruction } from './beam.gather';
|
||||
import { BeamStoreApi, useBeamStore } from '../store-beam.hooks';
|
||||
import { GATHER_DEBUG_NONCUSTOM } from '../beam.config';
|
||||
|
||||
|
||||
const gatherConfigWrapperSx: SxProps = {
|
||||
@@ -34,6 +36,79 @@ const configChatInstructionSx: SxProps = {
|
||||
flex: 1,
|
||||
};
|
||||
|
||||
function InstructionWrapper(props: { children: React.ReactNode }) {
|
||||
return (
|
||||
<Box sx={{ display: 'flex', flexDirection: 'column', mx: 'var(--Pad)' }}>
|
||||
{props.children}
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
function ReadOnlyInstruction(props: { instruction: TInstruction, isMobile: boolean }) {
|
||||
const { instruction } = props;
|
||||
|
||||
// render 'chat-generate'
|
||||
if (instruction.type === 'chat-generate') {
|
||||
return (
|
||||
<InstructionWrapper>
|
||||
<Box sx={{ display: 'flex', alignItems: 'center' }}>
|
||||
<Typography level='body-xs'>
|
||||
System Prompt:
|
||||
</Typography>
|
||||
<ChatMessageMemo
|
||||
message={createDMessage('assistant', instruction.systemPrompt)}
|
||||
fitScreen={props.isMobile}
|
||||
showAvatar={false}
|
||||
adjustContentScaling={-1}
|
||||
sx={configChatInstructionSx}
|
||||
/>
|
||||
</Box>
|
||||
<Box sx={{ display: 'flex', alignItems: 'center' }}>
|
||||
<Typography level='body-xs'>
|
||||
User Prompt:
|
||||
</Typography>
|
||||
<ChatMessageMemo
|
||||
message={createDMessage('assistant', instruction.userPrompt)}
|
||||
fitScreen={props.isMobile}
|
||||
showAvatar={false}
|
||||
adjustContentScaling={-1}
|
||||
sx={configChatInstructionSx}
|
||||
/>
|
||||
</Box>
|
||||
</InstructionWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
// render 'user-input-checklist'
|
||||
if (instruction.type === 'user-input-checklist') {
|
||||
return (
|
||||
<InstructionWrapper>
|
||||
<Box sx={{ display: 'flex', alignItems: 'center' }}>
|
||||
<Typography level='body-xs'>
|
||||
Checklist:
|
||||
</Typography>
|
||||
<ChatMessageMemo
|
||||
message={createDMessage('assistant', '#### Test\n- [ ] test\n- [ ] ...')}
|
||||
fitScreen={props.isMobile}
|
||||
showAvatar={false}
|
||||
adjustContentScaling={-1}
|
||||
sx={configChatInstructionSx}
|
||||
/>
|
||||
</Box>
|
||||
</InstructionWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<InstructionWrapper>
|
||||
<Typography level='body-xs'>
|
||||
Unknown Instruction
|
||||
</Typography>
|
||||
</InstructionWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
export function BeamGatherConfig(props: {
|
||||
beamStore: BeamStoreApi
|
||||
@@ -41,7 +116,7 @@ export function BeamGatherConfig(props: {
|
||||
}) {
|
||||
|
||||
// state
|
||||
const [viewInstructionIndex, setViewInstructionIndex] = React.useState(0);
|
||||
// const [viewInstructionIndex, setViewInstructionIndex] = React.useState(0);
|
||||
|
||||
// external state
|
||||
const fusion = useBeamStore(props.beamStore, store => {
|
||||
@@ -49,64 +124,23 @@ export function BeamGatherConfig(props: {
|
||||
return (fusion?.isEditable || GATHER_DEBUG_NONCUSTOM) ? fusion : null;
|
||||
});
|
||||
|
||||
|
||||
// [effect] sync the fusion program index to the viewInstructionIndex
|
||||
React.useEffect(() => {
|
||||
fusion && setViewInstructionIndex(fusion.currentInstructionIndex);
|
||||
}, [fusion]);
|
||||
|
||||
// // [effect] sync the fusion program index to the viewInstructionIndex
|
||||
// React.useEffect(() => {
|
||||
// fusion && setViewInstructionIndex(fusion.currentInstructionIndex);
|
||||
// }, [fusion]);
|
||||
|
||||
// derived state
|
||||
const instruction = React.useMemo(() => {
|
||||
return fusion?.instructions[viewInstructionIndex] ?? null;
|
||||
}, [fusion, viewInstructionIndex]);
|
||||
// const instructions = React.useMemo(() => {
|
||||
// return fusion?.instructions ?? null;
|
||||
// }, [fusion]);
|
||||
|
||||
const instructions = fusion?.instructions ?? null;
|
||||
|
||||
// render instruction
|
||||
const instructionComponent = React.useMemo(() => {
|
||||
if (instruction && instruction.type === 'chat-generate') {
|
||||
return <>
|
||||
{instruction.systemPrompt && (
|
||||
<Box sx={{ display: 'flex', alignItems: 'center', mx: 'var(--Pad)' }}>
|
||||
<Typography level='body-xs'>
|
||||
System Prompt:
|
||||
</Typography>
|
||||
<ChatMessageMemo
|
||||
message={createDMessage('assistant', instruction.systemPrompt)}
|
||||
fitScreen={props.isMobile}
|
||||
showAvatar={false}
|
||||
adjustContentScaling={-1}
|
||||
sx={configChatInstructionSx}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
{instruction.userPrompt && (
|
||||
<Box sx={{ display: 'flex', alignItems: 'center', ml: 'var(--Pad)' }}>
|
||||
<Typography level='body-xs'>
|
||||
User Prompt:
|
||||
</Typography>
|
||||
<ChatMessageMemo
|
||||
message={createDMessage('assistant', instruction.userPrompt)}
|
||||
fitScreen={props.isMobile}
|
||||
showAvatar={false}
|
||||
adjustContentScaling={-1}
|
||||
sx={configChatInstructionSx}
|
||||
/>
|
||||
|
||||
</Box>
|
||||
)}
|
||||
</>;
|
||||
}
|
||||
return null;
|
||||
}, [instruction, props.isMobile]);
|
||||
|
||||
|
||||
if (!instructionComponent)
|
||||
return null;
|
||||
|
||||
return (
|
||||
return !!instructions?.length ? (
|
||||
<Box sx={gatherConfigWrapperSx}>
|
||||
{instructionComponent}
|
||||
{instructions.map((instruction, stepIndex) =>
|
||||
<ReadOnlyInstruction key={'step-' + stepIndex} instruction={instruction} isMobile={props.isMobile} />,
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
) : null;
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
import * as React from 'react';
|
||||
|
||||
import type { SxProps } from '@mui/joy/styles/types';
|
||||
import { Box, Button, ButtonGroup, FormControl, SvgIconProps, Typography } from '@mui/joy';
|
||||
import { Box, Button, ButtonGroup, FormControl, IconButton, SvgIconProps, Tooltip, Typography } from '@mui/joy';
|
||||
import AutoAwesomeIcon from '@mui/icons-material/AutoAwesome';
|
||||
import AutoAwesomeOutlinedIcon from '@mui/icons-material/AutoAwesomeOutlined';
|
||||
import EditRoundedIcon from '@mui/icons-material/EditRounded';
|
||||
import MergeRoundedIcon from '@mui/icons-material/MergeRounded';
|
||||
import StopRoundedIcon from '@mui/icons-material/StopRounded';
|
||||
|
||||
@@ -15,7 +16,7 @@ import { useScrollToBottom } from '~/common/scroll-to-bottom/useScrollToBottom';
|
||||
import { GATHER_COLOR } from '../beam.config';
|
||||
import { beamPaneSx } from '../BeamCard';
|
||||
|
||||
import { FUSION_PROGRAMS } from './beam.gather';
|
||||
import { FUSION_FACTORIES } from './beam.gather';
|
||||
|
||||
|
||||
const mobileBeamGatherPane: SxProps = {
|
||||
@@ -47,21 +48,25 @@ export function BeamGatherPane(props: {
|
||||
gatherLlmIcon?: React.FunctionComponent<SvgIconProps>,
|
||||
fusionIndex: number | null,
|
||||
setFusionIndex: (index: number | null) => void
|
||||
onStartFusion: () => void,
|
||||
onStopFusion: () => void,
|
||||
onFusionCustomize: (index: number) => void,
|
||||
onFusionStart: () => void,
|
||||
onFusionStop: () => void,
|
||||
}) {
|
||||
|
||||
// external state
|
||||
const { setStickToBottom } = useScrollToBottom();
|
||||
|
||||
// derived state
|
||||
const { gatherCount, gatherEnabled, gatherBusy, setFusionIndex } = props;
|
||||
const { gatherCount, gatherEnabled, gatherBusy, setFusionIndex, onFusionCustomize } = props;
|
||||
|
||||
const handleFusionActivate = React.useCallback((idx: number, shiftPressed: boolean) => {
|
||||
setStickToBottom(true);
|
||||
setFusionIndex((idx !== props.fusionIndex || !shiftPressed) ? idx : null);
|
||||
}, [props.fusionIndex, setFusionIndex, setStickToBottom]);
|
||||
|
||||
const handleFusionCustomize = React.useCallback(() => {
|
||||
props.fusionIndex !== null && onFusionCustomize(props.fusionIndex);
|
||||
}, [onFusionCustomize, props.fusionIndex]);
|
||||
|
||||
const Icon = props.gatherLlmIcon || (gatherBusy ? AutoAwesomeIcon : AutoAwesomeOutlinedIcon);
|
||||
|
||||
@@ -85,10 +90,13 @@ export function BeamGatherPane(props: {
|
||||
|
||||
{/* Method */}
|
||||
<FormControl sx={{ my: '-0.25rem' }}>
|
||||
<FormLabelStart title={<><AutoAwesomeOutlinedIcon sx={{ fontSize: 'md', mr: 0.5 }} />Method</>} sx={{ mb: '0.25rem' /* orig: 6px */ }} />
|
||||
<Box sx={{ display: 'flex', flexAlign: 'center', gap: 1 }}>
|
||||
<FormLabelStart
|
||||
title={<><AutoAwesomeOutlinedIcon sx={{ fontSize: 'md', mr: 0.5 }} />Method</>}
|
||||
sx={{ mb: '0.25rem' /* orig: 6px */ }}
|
||||
/>
|
||||
<Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>
|
||||
<ButtonGroup variant='outlined'>
|
||||
{FUSION_PROGRAMS.map((fusion, idx) => {
|
||||
{FUSION_FACTORIES.map((fusion, idx) => {
|
||||
const isActive = idx === props.fusionIndex;
|
||||
return (
|
||||
<Button
|
||||
@@ -108,6 +116,13 @@ export function BeamGatherPane(props: {
|
||||
);
|
||||
})}
|
||||
</ButtonGroup>
|
||||
{(props.fusionIndex !== null) && (
|
||||
<Tooltip disableInteractive title='Customize This Merge'>
|
||||
<IconButton size='sm' color='success' disabled={props.gatherBusy || props.fusionIndex === 2} onClick={handleFusionCustomize}>
|
||||
<EditRoundedIcon />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
)}
|
||||
</Box>
|
||||
</FormControl>
|
||||
|
||||
@@ -123,7 +138,7 @@ export function BeamGatherPane(props: {
|
||||
variant='solid' color={GATHER_COLOR}
|
||||
disabled={!gatherEnabled || gatherBusy} loading={gatherBusy}
|
||||
endDecorator={<MergeRoundedIcon />}
|
||||
onClick={props.onStartFusion}
|
||||
onClick={props.onFusionStart}
|
||||
sx={{ minWidth: 120 }}
|
||||
>
|
||||
Merge
|
||||
@@ -133,7 +148,7 @@ export function BeamGatherPane(props: {
|
||||
// key='gather-stop'
|
||||
variant='solid' color='danger'
|
||||
endDecorator={<StopRoundedIcon />}
|
||||
onClick={props.onStopFusion}
|
||||
onClick={props.onFusionStop}
|
||||
sx={{ minWidth: 120 }}
|
||||
>
|
||||
Stop
|
||||
|
||||
@@ -9,45 +9,53 @@ import { GATHER_PLACEHOLDER } from '../beam.config';
|
||||
|
||||
// Choose, Improve, Fuse, Manual
|
||||
|
||||
export const FUSION_PROGRAMS: { label: string, factory: () => BFusion }[] = [
|
||||
const commonInitialization = (isEditable: boolean): Pick<BFusion,
|
||||
'isEditable' | 'currentInstructionIndex' | 'llmId' | 'status' | 'outputMessage'
|
||||
> => ({
|
||||
isEditable,
|
||||
currentInstructionIndex: 0,
|
||||
llmId: null,
|
||||
status: 'idle',
|
||||
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER),
|
||||
});
|
||||
|
||||
export const FUSION_FACTORIES: { label: string, factory: () => BFusion }[] = [
|
||||
{
|
||||
label: 'Guided', factory: () => ({
|
||||
label: 'Guided',
|
||||
factory: () => ({
|
||||
instructions: [{
|
||||
type: 'chat-generate',
|
||||
systemPrompt: 'You are',
|
||||
userPrompt: 'Perform this',
|
||||
systemPrompt: 'You arfe',
|
||||
userPrompt: 'Perform thiws',
|
||||
outputType: 'fin',
|
||||
}, {
|
||||
type: 'user-input-checklist',
|
||||
}],
|
||||
currentInstructionIndex: 0,
|
||||
isEditable: false,
|
||||
llmId: null,
|
||||
status: 'idle',
|
||||
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER),
|
||||
...commonInitialization(false),
|
||||
}),
|
||||
},
|
||||
{
|
||||
label: 'Fuse', factory: () => ({
|
||||
label: 'Fuse',
|
||||
factory: () => ({
|
||||
instructions: [{
|
||||
type: 'chat-generate',
|
||||
systemPrompt: 'You are an editor',
|
||||
userPrompt: 'Best of all',
|
||||
outputType: 'fin',
|
||||
}],
|
||||
currentInstructionIndex: 0,
|
||||
isEditable: false,
|
||||
llmId: null,
|
||||
status: 'idle',
|
||||
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER + '2'),
|
||||
...commonInitialization(false),
|
||||
}),
|
||||
},
|
||||
{
|
||||
label: 'Custom', factory: () => ({
|
||||
instructions: [],
|
||||
currentInstructionIndex: 0,
|
||||
isEditable: true,
|
||||
llmId: null,
|
||||
status: 'idle',
|
||||
outputMessage: createDMessage('assistant', GATHER_PLACEHOLDER + '3'),
|
||||
label: 'Custom',
|
||||
factory: () => ({
|
||||
instructions: [{
|
||||
type: 'chat-generate',
|
||||
systemPrompt: 'You are a custom editor',
|
||||
userPrompt: 'Best of all',
|
||||
outputType: 'fin',
|
||||
}],
|
||||
...commonInitialization(true),
|
||||
}),
|
||||
},
|
||||
];
|
||||
@@ -73,7 +81,7 @@ export function fusionGatherStop(fusion: BFusion): BFusion {
|
||||
|
||||
/// Gather Store Slice ///
|
||||
|
||||
type TInstruction = {
|
||||
export type TInstruction = {
|
||||
type: 'chat-generate',
|
||||
systemPrompt: string;
|
||||
userPrompt: string;
|
||||
@@ -84,14 +92,14 @@ type TInstruction = {
|
||||
|
||||
export interface BFusion {
|
||||
// set at creation, adjusted later if this is a custom fusion (and only when idle)
|
||||
isEditable: boolean; // only true on a single custom fusion
|
||||
instructions: TInstruction[];
|
||||
currentInstructionIndex: number;
|
||||
isEditable: boolean;
|
||||
|
||||
// set at lifecycle
|
||||
// set at start
|
||||
llmId: DLLMId | null;
|
||||
|
||||
// variable
|
||||
currentInstructionIndex: number; // points to the next instruction to execute
|
||||
status: 'idle' | 'fusing' | 'success' | 'stopped' | 'error';
|
||||
outputMessage: DMessage;
|
||||
issue?: string;
|
||||
@@ -115,7 +123,7 @@ export const reInitGatherStateSlice = (prevFusions: BFusion[]): GatherStateSlice
|
||||
|
||||
return {
|
||||
// recreate all fusions (no recycle)
|
||||
fusions: FUSION_PROGRAMS.map(spec => spec.factory()),
|
||||
fusions: FUSION_FACTORIES.map(spec => spec.factory()),
|
||||
fusionIndex: null,
|
||||
fusionLlmId: null,
|
||||
isGathering: false,
|
||||
@@ -126,8 +134,12 @@ export interface GatherStoreSlice extends GatherStateSlice {
|
||||
|
||||
setFusionIndex: (index: number | null) => void;
|
||||
setFusionLlmId: (llmId: DLLMId | null) => void;
|
||||
startFusion: () => void;
|
||||
stopFusion: () => void;
|
||||
|
||||
fusionCustomize: (sourceIndex: number) => void;
|
||||
fusionStart: () => void;
|
||||
fusionStop: () => void;
|
||||
|
||||
_fusionUpdate: (fusionIndex: number, update: Partial<BFusion> | ((fusion: BFusion) => (Partial<BFusion> | null))) => void;
|
||||
|
||||
}
|
||||
|
||||
@@ -147,12 +159,37 @@ export const createGatherSlice: StateCreator<GatherStoreSlice, [], [], GatherSto
|
||||
fusionLlmId: llmId,
|
||||
}),
|
||||
|
||||
startFusion: () => {
|
||||
fusionCustomize: (sourceIndex: number) => {
|
||||
const { fusions, setFusionIndex, _fusionUpdate } = _get();
|
||||
const editableFusionIndex = fusions.findIndex(fusion => fusion.isEditable);
|
||||
const fusionFactory = FUSION_FACTORIES[sourceIndex];
|
||||
if (editableFusionIndex === -1 || editableFusionIndex === sourceIndex || !fusionFactory)
|
||||
return;
|
||||
_fusionUpdate(editableFusionIndex, customFusion => {
|
||||
// Terminate current custom fusion, if any
|
||||
fusionGatherStop(customFusion);
|
||||
return {
|
||||
...fusionFactory.factory(),
|
||||
isEditable: true,
|
||||
};
|
||||
});
|
||||
setFusionIndex(editableFusionIndex);
|
||||
},
|
||||
|
||||
fusionStart: () => {
|
||||
console.log('startGatheringCurrent');
|
||||
},
|
||||
|
||||
stopFusion: () => {
|
||||
fusionStop: () => {
|
||||
console.log('stopGatheringCurrent');
|
||||
},
|
||||
|
||||
_fusionUpdate: (fusionIndex: number, update: Partial<BFusion> | ((fusion: BFusion) => (Partial<BFusion> | null))) =>
|
||||
_set(state => ({
|
||||
fusions: state.fusions.map((fusion, index) => (index === fusionIndex)
|
||||
? { ...fusion, ...(typeof update === 'function' ? update(fusion) : update) }
|
||||
: fusion,
|
||||
),
|
||||
})),
|
||||
|
||||
});
|
||||
|
||||
@@ -264,12 +264,13 @@ export const createScatterSlice: StateCreator<RootStoreSlice & GatherStoreSlice
|
||||
scatterLlmId: llmId,
|
||||
}),
|
||||
|
||||
_rayUpdate: (rayId: BRayId, update: Partial<BRay> | ((ray: BRay) => Partial<BRay>)) => _set(state => ({
|
||||
rays: state.rays.map(ray => (ray.rayId === rayId)
|
||||
? { ...ray, ...(typeof update === 'function' ? update(ray) : update) }
|
||||
: ray,
|
||||
),
|
||||
})),
|
||||
_rayUpdate: (rayId: BRayId, update: Partial<BRay> | ((ray: BRay) => Partial<BRay>)) =>
|
||||
_set(state => ({
|
||||
rays: state.rays.map(ray => (ray.rayId === rayId)
|
||||
? { ...ray, ...(typeof update === 'function' ? update(ray) : update) }
|
||||
: ray,
|
||||
),
|
||||
})),
|
||||
|
||||
|
||||
syncRaysStateToBeam: () => {
|
||||
|
||||
Reference in New Issue
Block a user