Wizard: Models

This commit is contained in:
Enrico Ros
2024-10-15 13:49:03 -07:00
parent 2200bb9ee8
commit ec493ee91b
6 changed files with 373 additions and 48 deletions
+5 -1
View File
@@ -57,5 +57,9 @@ export function useHasLLMs(): boolean {
}
export function useModelsServices() {
return useModelsStore(state => state.sources);
return useModelsStore(useShallow(state => ({
modelsServices: state.sources,
confServiceId: state.confServiceId,
setConfServiceId: state.setConfServiceId,
})));
}
+9
View File
@@ -25,6 +25,8 @@ interface LlmsState {
chatLLMId: DLLMId | null;
fastLLMId: DLLMId | null;
confServiceId: DModelsServiceId | null;
}
interface LlmsActions {
@@ -43,6 +45,8 @@ interface LlmsActions {
setChatLLMId: (id: DLLMId | null) => void;
setFastLLMId: (id: DLLMId | null) => void;
setConfServiceId: (id: DModelsServiceId | null) => void;
// special
setOpenRouterKey: (key: string) => void;
@@ -59,6 +63,7 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
chatLLMId: null,
fastLLMId: null,
confServiceId: null,
// actions
@@ -166,6 +171,7 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
}
: s,
),
confServiceId: state.confServiceId ?? service.id,
};
}),
@@ -188,6 +194,9 @@ export const useModelsStore = create<LlmsState & LlmsActions>()(persist(
),
})),
setConfServiceId: (id: DModelsServiceId | null) =>
set({ confServiceId: id }),
setOpenRouterKey: (key: string) =>
set(state => {
const firstOpenRouterService = state.sources.find(s => s.vId === 'openrouter');
+107 -41
View File
@@ -1,21 +1,27 @@
import * as React from 'react';
import { Box, Divider } from '@mui/joy';
import { Box, Button, Divider } from '@mui/joy';
import type { DModelsService, DModelsServiceId } from '~/common/stores/llms/modelsservice.types';
import type { DModelsService } from '~/common/stores/llms/modelsservice.types';
import { GoodModal } from '~/common/components/modals/GoodModal';
import { llmsStoreState } from '~/common/stores/llms/store-llms';
import { optimaActions, optimaOpenModels, useOptimaModelsModalsState } from '~/common/layout/optima/useOptima';
import { runWhenIdle } from '~/common/util/pwaUtils';
import { useLLMsCount, useModelsServices } from '~/common/stores/llms/llms.hooks';
import { useIsMobile } from '~/common/components/useMatchMedia';
import { useHasLLMs, useModelsServices } from '~/common/stores/llms/llms.hooks';
import { LLMOptionsModal } from './LLMOptionsModal';
import { ModelsList } from './ModelsList';
import { ModelsServiceSelector } from './ModelsServiceSelector';
import { ModelsWizard } from './ModelsWizard';
import { createModelsServiceForDefaultVendor } from '../vendors/vendor.helpers';
import { findModelVendor } from '../vendors/vendors.registry';
// configuration
const MODELS_WIZARD_ENABLE_INITIALLY = true;
function VendorServiceSetup(props: { service: DModelsService }) {
const vendor = findModelVendor(props.service.vId);
if (!vendor)
@@ -24,35 +30,40 @@ function VendorServiceSetup(props: { service: DModelsService }) {
}
export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) {
/**
* Note: the reason for this component separation from the parent state, is delayed state intitialization.
*/
function ModelsConfiguratorModal(props: {
modelsServices: DModelsService[],
confServiceId: string | null,
setConfServiceId: (serviceId: string | null) => void,
allowAutoTrigger: boolean,
}) {
// local state
const [_selectedServiceId, setSelectedServiceId] = React.useState<DModelsServiceId | null>(null);
const { modelsServices, confServiceId, setConfServiceId } = props;
// state
// const [showAllServices, setShowAllServices] = React.useState<boolean>(false);
const [showWizard, setShowWizard] = React.useState<boolean>(MODELS_WIZARD_ENABLE_INITIALLY && !modelsServices.length);
const showAllServices = false;
// external state
const { showModels, showModelOptions } = useOptimaModelsModalsState();
const modelsServices = useModelsServices();
const llmCount = useLLMsCount();
const isMobile = useIsMobile();
const hasLLMs = useHasLLMs();
// auto-select the first service - note: we could use a useEffect() here, but this is more efficient
// also note that state-persistence is unneeded
const selectedServiceId = _selectedServiceId ?? modelsServices[modelsServices.length - 1]?.id ?? null;
const activeService = modelsServices.find(s => s.id === selectedServiceId);
// active service with fallback to the last added service
const activeServiceId = confServiceId
?? modelsServices[modelsServices.length - 1]?.id
?? null;
// const multiService = modelsServices.length > 1;
const activeService = modelsServices.find(s => s.id === activeServiceId);
const isMultiServices = modelsServices.length > 1;
// Auto-open this dialog - anytime no service is selected
const autoOpenTrigger = !selectedServiceId && !props.suspendAutoModelsSetup;
React.useEffect(() => {
if (autoOpenTrigger)
return runWhenIdle(() => optimaOpenModels(), 2000);
}, [autoOpenTrigger]);
// Auto-add the default service - at boot, when no service is present
const autoAddTrigger = showModels && !props.suspendAutoModelsSetup;
const autoAddTrigger = !showWizard && props.allowAutoTrigger;
React.useEffect(() => {
// Note: we use the immediate version to not react to deletions
const { addService, sources: modelsServices } = llmsStoreState();
@@ -61,43 +72,67 @@ export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) {
}, [autoAddTrigger]);
return <>
// handlers
const handleShowAdvanced = React.useCallback(() => {
setShowWizard(false);
}, []);
{/* Services Setup */}
{showModels && <GoodModal
title={<>Configure <b>AI Models</b></>}
const handleShowWizard = React.useCallback(() => {
setShowWizard(true);
}, []);
// start button
const startButton = React.useMemo(() => {
if (showWizard)
return <Button variant='outlined' color='neutral' onClick={handleShowAdvanced}>{isMobile ? 'Advanced' : 'Switch to Advanced'}</Button>;
// return <Badge size='sm' badgeContent='14 Services' color='neutral' variant='outlined'><Button variant='outlined' color='neutral' onClick={handleShowAdvanced}>{isMobile ? 'Advanced' : 'Switch to Advanced'}</Button></Badge>;
if (!isMultiServices)
return <Button variant='outlined' color='neutral' onClick={handleShowWizard}>{isMobile ? 'Easy Mode' : 'Easy Mode'}</Button>;
return undefined;
// if (isMultiServices) {
// return (
// <Checkbox
// label='All Services'
// sx={{ my: 'auto' }}
// checked={showAllServices} onChange={() => setShowAllServices(all => !all)}
// />
// );
// }
}, [handleShowAdvanced, handleShowWizard, isMobile, isMultiServices, showWizard]);
return (
<GoodModal
title={<>{showWizard ? 'Welcome · Setup' : 'Configure'} <b>AI Models</b></>}
open onClose={optimaActions().closeModels}
darkBottomClose
animateEnter={llmCount === 0}
closeText={showWizard ? 'Done' : undefined}
animateEnter={!hasLLMs}
unfilterBackdrop
// startButton={
// multiService ? <Checkbox
// label='All Services'
// sx={{ my: 'auto' }}
// checked={showAllServices} onChange={() => setShowAllServices(all => !all)}
// /> : undefined
// }
startButton={startButton}
sx={{
// forces some shrinkage of the contents (ModelsList)
overflow: 'auto',
}}
>
<ModelsServiceSelector selectedServiceId={selectedServiceId} setSelectedServiceId={setSelectedServiceId} />
{!showWizard && <ModelsServiceSelector modelsServices={modelsServices} selectedServiceId={activeServiceId} setSelectedServiceId={setConfServiceId} />}
{!!activeService && <Divider />}
<Divider />
{!!activeService && (
{showWizard && <ModelsWizard isMobile={isMobile} onSkip={optimaActions().closeModels} onSwitchToAdvanced={handleShowAdvanced} />}
{!showWizard && !!activeService && (
<Box sx={{ display: 'grid', gap: 'var(--Card-padding)' }}>
<VendorServiceSetup service={activeService} />
</Box>
)}
{!!llmCount && <Divider />}
{!showWizard && hasLLMs && <Divider />}
{!!llmCount && (
{!showWizard && hasLLMs && (
<ModelsList
filterServiceId={showAllServices ? null : selectedServiceId}
filterServiceId={showAllServices ? null : activeServiceId}
onOpenLLMOptions={optimaActions().openModelOptions}
sx={{
// works in tandem with the parent (GoodModal > Dialog) overflow: 'auto'
@@ -119,9 +154,40 @@ export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) {
/>
)}
<Divider sx={{ background: 'transparent'}} />
<Divider sx={{ background: 'transparent' }} />
</GoodModal>}
</GoodModal>
);
}
export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) {
// external state
const { showModels, showModelOptions } = useOptimaModelsModalsState();
const { modelsServices, confServiceId, setConfServiceId } = useModelsServices();
// [effect] Auto-open the configurator - anytime no service is selected
const hasNoServices = !modelsServices.length;
const autoOpenTrigger = hasNoServices && !props.suspendAutoModelsSetup;
React.useEffect(() => {
if (autoOpenTrigger)
return runWhenIdle(() => optimaOpenModels(), 2000);
}, [autoOpenTrigger]);
return <>
{/* Services Setup */}
{showModels && (
<ModelsConfiguratorModal
modelsServices={modelsServices}
confServiceId={confServiceId}
setConfServiceId={setConfServiceId}
allowAutoTrigger={!props.suspendAutoModelsSetup}
/>
)}
{/* per-LLM options */}
{!!showModelOptions && (
@@ -10,12 +10,12 @@ import { ConfirmationModal } from '~/common/components/modals/ConfirmationModal'
import { llmsStoreActions, llmsStoreState } from '~/common/stores/llms/store-llms';
import { themeZIndexOverMobileDrawer } from '~/common/app.theme';
import { useIsMobile } from '~/common/components/useMatchMedia';
import { useModelsServices } from '~/common/stores/llms/llms.hooks';
import { useOverlayComponents } from '~/common/layout/overlays/useOverlayComponents';
import type { IModelVendor } from '../vendors/IModelVendor';
import { createModelsServiceForVendor, vendorHasBackendCap } from '../vendors/vendor.helpers';
import { findAllModelVendors, findModelVendor, ModelVendorId } from '../vendors/vendors.registry';
// import { MODELS_WIZARD_OPTION_ID } from '~/modules/llms/models-modal/ModelsModal';
/*function locationIcon(vendor?: IModelVendor | null) {
@@ -35,7 +35,9 @@ function vendorIcon(vendor: IModelVendor | null, greenMark: boolean) {
export function ModelsServiceSelector(props: {
selectedServiceId: DModelsServiceId | null, setSelectedServiceId: (serviceId: DModelsServiceId | null) => void,
modelsServices: DModelsService[],
selectedServiceId: DModelsServiceId | null,
setSelectedServiceId: (serviceId: DModelsServiceId | null) => void,
}) {
// state
@@ -44,7 +46,6 @@ export function ModelsServiceSelector(props: {
// external state
const isMobile = useIsMobile();
const modelsServices = useModelsServices();
const handleShowVendors = (event: React.MouseEvent<HTMLElement>) => setVendorsMenuAnchor(event.currentTarget);
@@ -53,7 +54,7 @@ export function ModelsServiceSelector(props: {
// handlers
const { setSelectedServiceId } = props;
const { modelsServices, setSelectedServiceId } = props;
const handleAddServiceForVendor = React.useCallback((vendorId: ModelVendorId) => {
closeVendorsMenu();
@@ -67,7 +68,15 @@ export function ModelsServiceSelector(props: {
const enableDeleteButton = !!props.selectedServiceId && modelsServices.length > 1;
const handleDeleteService = React.useCallback(async (serviceId: DModelsServiceId) => {
const handleDeleteService = React.useCallback(async (serviceId: DModelsServiceId, skipConfirmation: boolean) => {
// [shift] to delete without confirmation
if (skipConfirmation) {
// select the next service
setSelectedServiceId(modelsServices.find(s => s.id !== serviceId)?.id ?? null);
// remove the service
llmsStoreActions().removeService(serviceId);
return;
}
showPromisedOverlay('llms-service-remove', {}, ({ onResolve, onUserReject }) =>
<ConfirmationModal
open onClose={onUserReject} onPositive={() => onResolve(true)}
@@ -222,7 +231,7 @@ export function ModelsServiceSelector(props: {
<IconButton
variant='plain' color='neutral' disabled={!enableDeleteButton} sx={{ ml: 'auto' }}
onClick={() => props.selectedServiceId && handleDeleteService(props.selectedServiceId)}
onClick={(event) => props.selectedServiceId && handleDeleteService(props.selectedServiceId, event.shiftKey)}
>
<DeleteOutlineIcon />
</IconButton>
@@ -0,0 +1,234 @@
import * as React from 'react';
import { useShallow } from 'zustand/react/shallow';
import { Avatar, Badge, Box, Button, CircularProgress, Input, Sheet, Typography } from '@mui/joy';
import { TooltipOutlined } from '~/common/components/TooltipOutlined';
import { llmsStoreState, useModelsStore } from '~/common/stores/llms/store-llms';
import { useShallowStabilizer } from '~/common/util/hooks/useShallowObject';
import type { IModelVendor } from '../vendors/IModelVendor';
import { ModelVendorAnthropic } from '../vendors/anthropic/anthropic.vendor';
import { ModelVendorGemini } from '../vendors/gemini/gemini.vendor';
import { ModelVendorOpenAI } from '../vendors/openai/openai.vendor';
import { createModelsServiceForVendor } from '../vendors/vendor.helpers';
import { llmsUpdateModelsForServiceOrThrow } from '../llm.client';
// configuration
const WizardVendors = [
{ vendor: ModelVendorOpenAI, apiKeyField: 'oaiKey' },
{ vendor: ModelVendorAnthropic, apiKeyField: 'anthropicKey' },
{ vendor: ModelVendorGemini, apiKeyField: 'geminiKey' },
// { vendor: ModelVendorOpenRouter, apiKeyField: 'oaiKey' },
] as const;
const wizardContainerSx = {
margin: 'calc(-1 * var(--Card-padding, 1rem))',
padding: 'var(--Card-padding)',
// background: 'linear-gradient(135deg, var(--joy-palette-primary-500), var(--joy-palette-primary-700))',
background: 'linear-gradient(135deg, var(--joy-palette-background-level1), var(--joy-palette-background-level1))',
display: 'grid',
gap: 'var(--Card-padding)',
};
function WizardProviderSetup(props: {
apiKeyField: string,
isFirst: boolean,
vendor: IModelVendor<Record<string, any>, Record<string, any>>,
}) {
const { id: vendorId, name: vendorName, Icon: VendorIcon } = props.vendor;
// state
const [localKey, setLocalKey] = React.useState<string | null>(null);
const [isLoading, setIsLoading] = React.useState(false);
const [updateError, setUpdateError] = React.useState<string | null>(null);
// external state
const stabilizeTransportAccess = useShallowStabilizer<Record<string, any>>();
const { serviceAPIKey, serviceLLMsCount } = useModelsStore(useShallow(({ llms, sources }) => {
// find the service | null
const vendorService = sources.find(s => s.vId === vendorId) ?? null;
// (safe) service-derived properties
const serviceLLMsCount = !vendorService ? null : llms.filter(llm => llm.sId === vendorService.id).length;
const serviceAccess = stabilizeTransportAccess(props.vendor.getTransportAccess(vendorService?.setup));
const serviceAPIKey = !serviceAccess ? null : serviceAccess[props.apiKeyField] ?? null;
return {
serviceAPIKey,
serviceLLMsCount,
};
}));
// [effect] initialize the local key
React.useEffect(() => {
if (localKey === null)
setLocalKey(serviceAPIKey || '');
}, [localKey, serviceAPIKey]);
// handlers
const handleTextChanged = React.useCallback((e: React.ChangeEvent) => {
setLocalKey((e.target as HTMLInputElement).value);
}, []);
const handleSetServiceKey = React.useCallback(async () => {
// create the service if missing
const { sources: llmsServices, addService, updateServiceSettings, setLLMs } = llmsStoreState();
let vendorService = llmsServices.find(s => s.vId === vendorId);
if (!vendorService) {
vendorService = createModelsServiceForVendor(vendorId, llmsServices);
addService(vendorService);
}
const vendorServiceId = vendorService.id;
// set the key
const newKey = localKey?.trim() ?? '';
updateServiceSettings(vendorServiceId, { [props.apiKeyField]: newKey });
// if the key is empty, remove the models
if (!newKey) {
setUpdateError(null);
setLLMs([], vendorServiceId, true, false);
return;
}
// update the models
setUpdateError(null);
setIsLoading(true);
try {
await llmsUpdateModelsForServiceOrThrow(vendorService.id, true);
} catch (error: any) {
let errorText = error.message || 'An error occurred';
if (errorText.includes('Incorrect API key'))
errorText = '[OpenAI issue] Unauthorized: Incorrect API key.';
setUpdateError(errorText);
setLLMs([], vendorServiceId, true, false);
}
setIsLoading(false);
}, [localKey, props.apiKeyField, vendorId]);
// memoed components
const endButtons = React.useMemo(() => ((localKey || '') === (serviceAPIKey || '')) ? null : (
<Box sx={{ display: 'flex', gap: 2 }}>
{/*<TooltipOutlined title='Clear Key'>*/}
{/* <IconButton variant='outlined' color='neutral' onClick={handleClear}>*/}
{/* <ClearIcon />*/}
{/* </IconButton>*/}
{/*</TooltipOutlined>*/}
{/*<TooltipOutlined title='Confirm'>*/}
<Button
variant='solid' color='primary'
onClick={handleSetServiceKey}
// endDecorator={<CheckRoundedIcon />}
>
{!serviceAPIKey ? 'Confirm' : !localKey?.trim() ? 'Clear' : 'Update'}
</Button>
{/*</TooltipOutlined>*/}
</Box>
), [handleSetServiceKey, localKey, serviceAPIKey]);
return (
<Box sx={{ display: 'flex', flexDirection: 'column', gap: 1 }}>
<Box sx={{ display: 'flex', alignItems: 'center', gap: 2 }}>
{/* Left Icon */}
<TooltipOutlined title={serviceLLMsCount ? `${serviceLLMsCount} ${vendorName} models available` : `${vendorName} API Key`} placement='top'>
<Badge
size='md' color='primary' variant='solid' badgeInset='12%'
badgeContent={serviceLLMsCount} showZero={false}
slotProps={{ badge: { sx: { boxShadow: 'xs', border: 'none' } } }}
>
<Avatar sx={{ height: '100%', aspectRatio: 1, backgroundColor: 'transparent' }}>
{isLoading ? <CircularProgress color='primary' variant='solid' size='sm' /> : <VendorIcon />}
</Avatar>
</Badge>
</TooltipOutlined>
{/* Main key inputs */}
<Box sx={{ flex: 1, display: 'grid' }}>
{/* Line 1 */}
{/*{!!props.serviceLabel && (*/}
{/* <Box sx={{ display: 'flex', alignItems: 'center', gap: 1 }}>*/}
{/* /!*<props.vendorIcon />*!/*/}
{/* <Box>{props.serviceLabel}</Box>*/}
{/* </Box>*/}
{/*)}*/}
{/* Line 2 */}
<Input
fullWidth
name={`wizard-api-key-${vendorId}`}
autoComplete='off'
variant='outlined'
value={localKey ?? ''}
onChange={handleTextChanged}
placeholder={`${vendorName} API Key`}
type='password'
// error={!isValidKey}
// startDecorator={<props.vendorIcon />}
endDecorator={endButtons}
/>
</Box>
</Box>
{/*{isLoading && <Typography level='body-xs' sx={{ ml: 7, px: 0.5 }}>Loading your models...</Typography>}*/}
{/*{!isLoading && !updateError && !!llmsCount && (*/}
{/* <Typography level='body-xs' sx={{ ml: 7, px: 0.5 }}>{llmsCount} models added.</Typography>*/}
{/*)}*/}
{!isLoading && !updateError && !serviceLLMsCount && !!serviceAPIKey && (
<Typography level='body-xs' color='warning' sx={{ ml: 7, px: 0.5 }}>No models found.</Typography>
)}
{!!updateError && <Typography level='body-xs' color='danger' sx={{ ml: 7, px: 0.5 }}>{updateError}</Typography>}
</Box>
);
}
export function ModelsWizard(props: {
isMobile: boolean,
onSkip?: () => void,
onSwitchToAdvanced?: () => void,
}) {
return (
<Sheet variant='soft' sx={wizardContainerSx}>
<Box sx={{ ml: 7.25, display: 'flex', flexDirection: 'column', gap: 0.25 }}>
{/*<Typography level='title-sm'>*/}
{/* Quick Start*/}
{/*</Typography>*/}
<Typography level='body-sm'>
Enter API keys to connect Big-AGI to your AI providers.{' '}
{/*{!props.isMobile && <>Switch to <Box component='a' onClick={props.onSwitchToAdvanced} sx={{ textDecoration: 'underline', cursor: 'pointer' }}>Advanced</Box> for more options.</>}*/}
</Typography>
</Box>
{WizardVendors.map(({ vendor, apiKeyField }, index) => (
<WizardProviderSetup key={vendor.id} apiKeyField={apiKeyField} isFirst={!index} vendor={vendor} />
))}
<Box sx={{ ml: 7.25, color: 'text.tertiary', fontSize: 'sm' }}>
{/*{!props.isMobile && <>Switch to <Box component='a' onClick={props.onSwitchToAdvanced} sx={{ textDecoration: 'underline', cursor: 'pointer' }}>Advanced</Box> to choose between {getModelVendorsCount()} services.</>}{' '}*/}
{!props.isMobile && <>Switch to <Box component='a' onClick={props.onSwitchToAdvanced} sx={{ textDecoration: 'underline', cursor: 'pointer' }}>Advanced</Box> for more services.</>}{' '}
Or <Box component='a' onClick={props.onSkip} sx={{ textDecoration: 'underline', cursor: 'pointer' }}>skip</Box> for now and do it later.
</Box>
</Sheet>
);
}
+3
View File
@@ -54,6 +54,9 @@ const MODEL_VENDOR_REGISTRY: Record<ModelVendorId, IModelVendor> = {
xai: ModelVendorXAI,
} as Record<string, IModelVendor>;
export function getModelVendorsCount(): number {
return Object.keys(MODEL_VENDOR_REGISTRY).length;
}
export function findAllModelVendors(): IModelVendor[] {
const modelVendors = Object.values(MODEL_VENDOR_REGISTRY);