From ec493ee91bc96b9b22a8beef606772a52b538b2a Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Tue, 15 Oct 2024 13:49:03 -0700 Subject: [PATCH] Wizard: Models --- src/common/stores/llms/llms.hooks.ts | 6 +- src/common/stores/llms/store-llms.ts | 9 + src/modules/llms/models-modal/ModelsModal.tsx | 148 ++++++++--- .../models-modal/ModelsServiceSelector.tsx | 21 +- .../llms/models-modal/ModelsWizard.tsx | 234 ++++++++++++++++++ src/modules/llms/vendors/vendors.registry.ts | 3 + 6 files changed, 373 insertions(+), 48 deletions(-) create mode 100644 src/modules/llms/models-modal/ModelsWizard.tsx diff --git a/src/common/stores/llms/llms.hooks.ts b/src/common/stores/llms/llms.hooks.ts index 5a30595b2..eafff742a 100644 --- a/src/common/stores/llms/llms.hooks.ts +++ b/src/common/stores/llms/llms.hooks.ts @@ -57,5 +57,9 @@ export function useHasLLMs(): boolean { } export function useModelsServices() { - return useModelsStore(state => state.sources); + return useModelsStore(useShallow(state => ({ + modelsServices: state.sources, + confServiceId: state.confServiceId, + setConfServiceId: state.setConfServiceId, + }))); } \ No newline at end of file diff --git a/src/common/stores/llms/store-llms.ts b/src/common/stores/llms/store-llms.ts index 9215af637..87f2f603e 100644 --- a/src/common/stores/llms/store-llms.ts +++ b/src/common/stores/llms/store-llms.ts @@ -25,6 +25,8 @@ interface LlmsState { chatLLMId: DLLMId | null; fastLLMId: DLLMId | null; + confServiceId: DModelsServiceId | null; + } interface LlmsActions { @@ -43,6 +45,8 @@ interface LlmsActions { setChatLLMId: (id: DLLMId | null) => void; setFastLLMId: (id: DLLMId | null) => void; + setConfServiceId: (id: DModelsServiceId | null) => void; + // special setOpenRouterKey: (key: string) => void; @@ -59,6 +63,7 @@ export const useModelsStore = create()(persist( chatLLMId: null, fastLLMId: null, + confServiceId: null, // actions @@ -166,6 +171,7 @@ export const useModelsStore = create()(persist( } : s, ), + confServiceId: state.confServiceId ?? service.id, }; }), @@ -188,6 +194,9 @@ export const useModelsStore = create()(persist( ), })), + setConfServiceId: (id: DModelsServiceId | null) => + set({ confServiceId: id }), + setOpenRouterKey: (key: string) => set(state => { const firstOpenRouterService = state.sources.find(s => s.vId === 'openrouter'); diff --git a/src/modules/llms/models-modal/ModelsModal.tsx b/src/modules/llms/models-modal/ModelsModal.tsx index 0d42cd7b6..00f73fe29 100644 --- a/src/modules/llms/models-modal/ModelsModal.tsx +++ b/src/modules/llms/models-modal/ModelsModal.tsx @@ -1,21 +1,27 @@ import * as React from 'react'; -import { Box, Divider } from '@mui/joy'; +import { Box, Button, Divider } from '@mui/joy'; -import type { DModelsService, DModelsServiceId } from '~/common/stores/llms/modelsservice.types'; +import type { DModelsService } from '~/common/stores/llms/modelsservice.types'; import { GoodModal } from '~/common/components/modals/GoodModal'; import { llmsStoreState } from '~/common/stores/llms/store-llms'; import { optimaActions, optimaOpenModels, useOptimaModelsModalsState } from '~/common/layout/optima/useOptima'; import { runWhenIdle } from '~/common/util/pwaUtils'; -import { useLLMsCount, useModelsServices } from '~/common/stores/llms/llms.hooks'; +import { useIsMobile } from '~/common/components/useMatchMedia'; +import { useHasLLMs, useModelsServices } from '~/common/stores/llms/llms.hooks'; import { LLMOptionsModal } from './LLMOptionsModal'; import { ModelsList } from './ModelsList'; import { ModelsServiceSelector } from './ModelsServiceSelector'; +import { ModelsWizard } from './ModelsWizard'; import { createModelsServiceForDefaultVendor } from '../vendors/vendor.helpers'; import { findModelVendor } from '../vendors/vendors.registry'; +// configuration +const MODELS_WIZARD_ENABLE_INITIALLY = true; + + function VendorServiceSetup(props: { service: DModelsService }) { const vendor = findModelVendor(props.service.vId); if (!vendor) @@ -24,35 +30,40 @@ function VendorServiceSetup(props: { service: DModelsService }) { } -export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) { +/** + * Note: the reason for this component separation from the parent state, is delayed state intitialization. + */ +function ModelsConfiguratorModal(props: { + modelsServices: DModelsService[], + confServiceId: string | null, + setConfServiceId: (serviceId: string | null) => void, + allowAutoTrigger: boolean, +}) { - // local state - const [_selectedServiceId, setSelectedServiceId] = React.useState(null); + const { modelsServices, confServiceId, setConfServiceId } = props; + + // state // const [showAllServices, setShowAllServices] = React.useState(false); + const [showWizard, setShowWizard] = React.useState(MODELS_WIZARD_ENABLE_INITIALLY && !modelsServices.length); const showAllServices = false; // external state - const { showModels, showModelOptions } = useOptimaModelsModalsState(); - const modelsServices = useModelsServices(); - const llmCount = useLLMsCount(); + const isMobile = useIsMobile(); + const hasLLMs = useHasLLMs(); - // auto-select the first service - note: we could use a useEffect() here, but this is more efficient - // also note that state-persistence is unneeded - const selectedServiceId = _selectedServiceId ?? modelsServices[modelsServices.length - 1]?.id ?? null; - const activeService = modelsServices.find(s => s.id === selectedServiceId); + // active service with fallback to the last added service + const activeServiceId = confServiceId + ?? modelsServices[modelsServices.length - 1]?.id + ?? null; - // const multiService = modelsServices.length > 1; + const activeService = modelsServices.find(s => s.id === activeServiceId); + + const isMultiServices = modelsServices.length > 1; - // Auto-open this dialog - anytime no service is selected - const autoOpenTrigger = !selectedServiceId && !props.suspendAutoModelsSetup; - React.useEffect(() => { - if (autoOpenTrigger) - return runWhenIdle(() => optimaOpenModels(), 2000); - }, [autoOpenTrigger]); // Auto-add the default service - at boot, when no service is present - const autoAddTrigger = showModels && !props.suspendAutoModelsSetup; + const autoAddTrigger = !showWizard && props.allowAutoTrigger; React.useEffect(() => { // Note: we use the immediate version to not react to deletions const { addService, sources: modelsServices } = llmsStoreState(); @@ -61,43 +72,67 @@ export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) { }, [autoAddTrigger]); - return <> + // handlers + const handleShowAdvanced = React.useCallback(() => { + setShowWizard(false); + }, []); - {/* Services Setup */} - {showModels && Configure AI Models} + const handleShowWizard = React.useCallback(() => { + setShowWizard(true); + }, []); + + + // start button + const startButton = React.useMemo(() => { + if (showWizard) + return ; + // return ; + if (!isMultiServices) + return ; + return undefined; + // if (isMultiServices) { + // return ( + // setShowAllServices(all => !all)} + // /> + // ); + // } + }, [handleShowAdvanced, handleShowWizard, isMobile, isMultiServices, showWizard]); + + return ( + {showWizard ? 'Welcome ยท Setup' : 'Configure'} AI Models} open onClose={optimaActions().closeModels} darkBottomClose - animateEnter={llmCount === 0} + closeText={showWizard ? 'Done' : undefined} + animateEnter={!hasLLMs} unfilterBackdrop - // startButton={ - // multiService ? setShowAllServices(all => !all)} - // /> : undefined - // } + startButton={startButton} sx={{ // forces some shrinkage of the contents (ModelsList) overflow: 'auto', }} > - + {!showWizard && } - {!!activeService && } + - {!!activeService && ( + {showWizard && } + + {!showWizard && !!activeService && ( )} - {!!llmCount && } + {!showWizard && hasLLMs && } - {!!llmCount && ( + {!showWizard && hasLLMs && ( Dialog) overflow: 'auto' @@ -119,9 +154,40 @@ export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) { /> )} - + - } + + ); +} + + +export function ModelsModal(props: { suspendAutoModelsSetup?: boolean }) { + + // external state + const { showModels, showModelOptions } = useOptimaModelsModalsState(); + const { modelsServices, confServiceId, setConfServiceId } = useModelsServices(); + + + // [effect] Auto-open the configurator - anytime no service is selected + const hasNoServices = !modelsServices.length; + const autoOpenTrigger = hasNoServices && !props.suspendAutoModelsSetup; + React.useEffect(() => { + if (autoOpenTrigger) + return runWhenIdle(() => optimaOpenModels(), 2000); + }, [autoOpenTrigger]); + + + return <> + + {/* Services Setup */} + {showModels && ( + + )} {/* per-LLM options */} {!!showModelOptions && ( diff --git a/src/modules/llms/models-modal/ModelsServiceSelector.tsx b/src/modules/llms/models-modal/ModelsServiceSelector.tsx index 55604d457..54cd16263 100644 --- a/src/modules/llms/models-modal/ModelsServiceSelector.tsx +++ b/src/modules/llms/models-modal/ModelsServiceSelector.tsx @@ -10,12 +10,12 @@ import { ConfirmationModal } from '~/common/components/modals/ConfirmationModal' import { llmsStoreActions, llmsStoreState } from '~/common/stores/llms/store-llms'; import { themeZIndexOverMobileDrawer } from '~/common/app.theme'; import { useIsMobile } from '~/common/components/useMatchMedia'; -import { useModelsServices } from '~/common/stores/llms/llms.hooks'; import { useOverlayComponents } from '~/common/layout/overlays/useOverlayComponents'; import type { IModelVendor } from '../vendors/IModelVendor'; import { createModelsServiceForVendor, vendorHasBackendCap } from '../vendors/vendor.helpers'; import { findAllModelVendors, findModelVendor, ModelVendorId } from '../vendors/vendors.registry'; +// import { MODELS_WIZARD_OPTION_ID } from '~/modules/llms/models-modal/ModelsModal'; /*function locationIcon(vendor?: IModelVendor | null) { @@ -35,7 +35,9 @@ function vendorIcon(vendor: IModelVendor | null, greenMark: boolean) { export function ModelsServiceSelector(props: { - selectedServiceId: DModelsServiceId | null, setSelectedServiceId: (serviceId: DModelsServiceId | null) => void, + modelsServices: DModelsService[], + selectedServiceId: DModelsServiceId | null, + setSelectedServiceId: (serviceId: DModelsServiceId | null) => void, }) { // state @@ -44,7 +46,6 @@ export function ModelsServiceSelector(props: { // external state const isMobile = useIsMobile(); - const modelsServices = useModelsServices(); const handleShowVendors = (event: React.MouseEvent) => setVendorsMenuAnchor(event.currentTarget); @@ -53,7 +54,7 @@ export function ModelsServiceSelector(props: { // handlers - const { setSelectedServiceId } = props; + const { modelsServices, setSelectedServiceId } = props; const handleAddServiceForVendor = React.useCallback((vendorId: ModelVendorId) => { closeVendorsMenu(); @@ -67,7 +68,15 @@ export function ModelsServiceSelector(props: { const enableDeleteButton = !!props.selectedServiceId && modelsServices.length > 1; - const handleDeleteService = React.useCallback(async (serviceId: DModelsServiceId) => { + const handleDeleteService = React.useCallback(async (serviceId: DModelsServiceId, skipConfirmation: boolean) => { + // [shift] to delete without confirmation + if (skipConfirmation) { + // select the next service + setSelectedServiceId(modelsServices.find(s => s.id !== serviceId)?.id ?? null); + // remove the service + llmsStoreActions().removeService(serviceId); + return; + } showPromisedOverlay('llms-service-remove', {}, ({ onResolve, onUserReject }) => onResolve(true)} @@ -222,7 +231,7 @@ export function ModelsServiceSelector(props: { props.selectedServiceId && handleDeleteService(props.selectedServiceId)} + onClick={(event) => props.selectedServiceId && handleDeleteService(props.selectedServiceId, event.shiftKey)} > diff --git a/src/modules/llms/models-modal/ModelsWizard.tsx b/src/modules/llms/models-modal/ModelsWizard.tsx new file mode 100644 index 000000000..31409e831 --- /dev/null +++ b/src/modules/llms/models-modal/ModelsWizard.tsx @@ -0,0 +1,234 @@ +import * as React from 'react'; +import { useShallow } from 'zustand/react/shallow'; + +import { Avatar, Badge, Box, Button, CircularProgress, Input, Sheet, Typography } from '@mui/joy'; + +import { TooltipOutlined } from '~/common/components/TooltipOutlined'; +import { llmsStoreState, useModelsStore } from '~/common/stores/llms/store-llms'; +import { useShallowStabilizer } from '~/common/util/hooks/useShallowObject'; + +import type { IModelVendor } from '../vendors/IModelVendor'; +import { ModelVendorAnthropic } from '../vendors/anthropic/anthropic.vendor'; +import { ModelVendorGemini } from '../vendors/gemini/gemini.vendor'; +import { ModelVendorOpenAI } from '../vendors/openai/openai.vendor'; +import { createModelsServiceForVendor } from '../vendors/vendor.helpers'; +import { llmsUpdateModelsForServiceOrThrow } from '../llm.client'; + + +// configuration +const WizardVendors = [ + { vendor: ModelVendorOpenAI, apiKeyField: 'oaiKey' }, + { vendor: ModelVendorAnthropic, apiKeyField: 'anthropicKey' }, + { vendor: ModelVendorGemini, apiKeyField: 'geminiKey' }, + // { vendor: ModelVendorOpenRouter, apiKeyField: 'oaiKey' }, +] as const; + + +const wizardContainerSx = { + margin: 'calc(-1 * var(--Card-padding, 1rem))', + padding: 'var(--Card-padding)', + // background: 'linear-gradient(135deg, var(--joy-palette-primary-500), var(--joy-palette-primary-700))', + background: 'linear-gradient(135deg, var(--joy-palette-background-level1), var(--joy-palette-background-level1))', + display: 'grid', + gap: 'var(--Card-padding)', +}; + + +function WizardProviderSetup(props: { + apiKeyField: string, + isFirst: boolean, + vendor: IModelVendor, Record>, +}) { + + const { id: vendorId, name: vendorName, Icon: VendorIcon } = props.vendor; + + // state + const [localKey, setLocalKey] = React.useState(null); + const [isLoading, setIsLoading] = React.useState(false); + const [updateError, setUpdateError] = React.useState(null); + + // external state + const stabilizeTransportAccess = useShallowStabilizer>(); + const { serviceAPIKey, serviceLLMsCount } = useModelsStore(useShallow(({ llms, sources }) => { + + // find the service | null + const vendorService = sources.find(s => s.vId === vendorId) ?? null; + + // (safe) service-derived properties + const serviceLLMsCount = !vendorService ? null : llms.filter(llm => llm.sId === vendorService.id).length; + const serviceAccess = stabilizeTransportAccess(props.vendor.getTransportAccess(vendorService?.setup)); + const serviceAPIKey = !serviceAccess ? null : serviceAccess[props.apiKeyField] ?? null; + + return { + serviceAPIKey, + serviceLLMsCount, + }; + })); + + // [effect] initialize the local key + React.useEffect(() => { + if (localKey === null) + setLocalKey(serviceAPIKey || ''); + }, [localKey, serviceAPIKey]); + + + // handlers + + const handleTextChanged = React.useCallback((e: React.ChangeEvent) => { + setLocalKey((e.target as HTMLInputElement).value); + }, []); + + const handleSetServiceKey = React.useCallback(async () => { + + // create the service if missing + const { sources: llmsServices, addService, updateServiceSettings, setLLMs } = llmsStoreState(); + let vendorService = llmsServices.find(s => s.vId === vendorId); + if (!vendorService) { + vendorService = createModelsServiceForVendor(vendorId, llmsServices); + addService(vendorService); + } + const vendorServiceId = vendorService.id; + + // set the key + const newKey = localKey?.trim() ?? ''; + updateServiceSettings(vendorServiceId, { [props.apiKeyField]: newKey }); + + // if the key is empty, remove the models + if (!newKey) { + setUpdateError(null); + setLLMs([], vendorServiceId, true, false); + return; + } + + // update the models + setUpdateError(null); + setIsLoading(true); + try { + await llmsUpdateModelsForServiceOrThrow(vendorService.id, true); + } catch (error: any) { + let errorText = error.message || 'An error occurred'; + if (errorText.includes('Incorrect API key')) + errorText = '[OpenAI issue] Unauthorized: Incorrect API key.'; + setUpdateError(errorText); + setLLMs([], vendorServiceId, true, false); + } + setIsLoading(false); + + }, [localKey, props.apiKeyField, vendorId]); + + + // memoed components + + const endButtons = React.useMemo(() => ((localKey || '') === (serviceAPIKey || '')) ? null : ( + + {/**/} + {/* */} + {/* */} + {/* */} + {/**/} + {/**/} + + {/**/} + + ), [handleSetServiceKey, localKey, serviceAPIKey]); + + + return ( + + + + + {/* Left Icon */} + + + + {isLoading ? : } + + + + + {/* Main key inputs */} + + + {/* Line 1 */} + {/*{!!props.serviceLabel && (*/} + {/* */} + {/* /!**!/*/} + {/* {props.serviceLabel}*/} + {/* */} + {/*)}*/} + + {/* Line 2 */} + } + endDecorator={endButtons} + /> + + + + + + {/*{isLoading && Loading your models...}*/} + {/*{!isLoading && !updateError && !!llmsCount && (*/} + {/* {llmsCount} models added.*/} + {/*)}*/} + {!isLoading && !updateError && !serviceLLMsCount && !!serviceAPIKey && ( + No models found. + )} + {!!updateError && {updateError}} + + + ); +} + + +export function ModelsWizard(props: { + isMobile: boolean, + onSkip?: () => void, + onSwitchToAdvanced?: () => void, +}) { + return ( + + + + {/**/} + {/* Quick Start*/} + {/**/} + + Enter API keys to connect Big-AGI to your AI providers.{' '} + {/*{!props.isMobile && <>Switch to Advanced for more options.}*/} + + + + {WizardVendors.map(({ vendor, apiKeyField }, index) => ( + + ))} + + + {/*{!props.isMobile && <>Switch to Advanced to choose between {getModelVendorsCount()} services.}{' '}*/} + {!props.isMobile && <>Switch to Advanced for more services.}{' '} + Or skip for now and do it later. + + + + ); +} \ No newline at end of file diff --git a/src/modules/llms/vendors/vendors.registry.ts b/src/modules/llms/vendors/vendors.registry.ts index 3cb5afe92..f4b0c611b 100644 --- a/src/modules/llms/vendors/vendors.registry.ts +++ b/src/modules/llms/vendors/vendors.registry.ts @@ -54,6 +54,9 @@ const MODEL_VENDOR_REGISTRY: Record = { xai: ModelVendorXAI, } as Record; +export function getModelVendorsCount(): number { + return Object.keys(MODEL_VENDOR_REGISTRY).length; +} export function findAllModelVendors(): IModelVendor[] { const modelVendors = Object.values(MODEL_VENDOR_REGISTRY);