mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-10 21:50:14 -07:00
Models: upgrade data structure to v2
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import type { ModelDescriptionSchema } from './server/llm.server.types';
|
||||
import type { OpenAIWire } from './server/openai/openai.wiretypes';
|
||||
import type { StreamingClientUpdate } from './vendors/unifiedStreamingClient';
|
||||
import { DLLM, DLLMId, DModelSource, DModelSourceId, useModelsStore } from './store-llms';
|
||||
import { DLLM, DLLMId, DModelSource, DModelSourceId, LLM_IF_OAI_Chat, useModelsStore } from './store-llms';
|
||||
import { FALLBACK_LLM_TEMPERATURE } from './vendors/openai/openai.vendor';
|
||||
import { findAccessForSourceOrThrow, findVendorForLlmOrThrow } from './vendors/vendors.registry';
|
||||
|
||||
@@ -62,16 +62,26 @@ function modelDescriptionToDLLMOpenAIOptions<TSourceSetup, TLLMOptions>(model: M
|
||||
return {
|
||||
id: `${source.id}-${model.id}`,
|
||||
|
||||
// editable properties
|
||||
label: model.label,
|
||||
created: model.created || 0,
|
||||
updated: model.updated || 0,
|
||||
description: model.description,
|
||||
tags: [], // ['stream', 'chat'],
|
||||
hidden: !!model.hidden,
|
||||
// isEdited: false, // NOTE: this is set by the store on user edits
|
||||
|
||||
// hard properties
|
||||
contextTokens,
|
||||
maxOutputTokens,
|
||||
hidden: !!model.hidden,
|
||||
trainingDataCutoff: model.trainingDataCutoff,
|
||||
interfaces: model.interfaces?.length ? model.interfaces : [LLM_IF_OAI_Chat],
|
||||
// inputTypes: ...
|
||||
benchmark: model.benchmark,
|
||||
pricing: model.pricing,
|
||||
|
||||
isFree: model.pricing?.chatIn === 0 && model.pricing?.chatOut === 0,
|
||||
// derived properties
|
||||
tmpIsFree: model.pricing?.chatIn === 0 && model.pricing?.chatOut === 0,
|
||||
tmpIsVision: model.interfaces?.includes(LLM_IF_OAI_Chat) === true,
|
||||
|
||||
sId: source.id,
|
||||
_source: source,
|
||||
|
||||
@@ -126,7 +126,7 @@ export function LLMOptionsModal(props: { id: DLLMId, onClose: () => void }) {
|
||||
<Typography level='body-md'>
|
||||
{llm.id}
|
||||
</Typography>
|
||||
{llm.isFree && <Typography level='body-xs'>
|
||||
{!!llm.tmpIsFree && <Typography level='body-xs'>
|
||||
🎁 Free model - note: refresh models to check for updates in pricing
|
||||
</Typography>}
|
||||
{!!llm.description && <Typography level='body-xs'>
|
||||
|
||||
@@ -61,9 +61,11 @@ export function geminiModelToModelDescription(geminiModel: GeminiModelSchema, al
|
||||
description: descriptionLong,
|
||||
contextWindow: contextWindow,
|
||||
maxCompletionTokens: outputTokenLimit,
|
||||
// pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined,
|
||||
// rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined,
|
||||
// trainingDataCutoff: '...',
|
||||
interfaces,
|
||||
// rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined,
|
||||
// benchmarks: ...
|
||||
// pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined,
|
||||
hidden,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@ const benchmarkSchema = z.object({
|
||||
// reqPerMinute: z.number().optional(),
|
||||
// });
|
||||
|
||||
const interfaceSchema = z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, LLM_IF_OAI_Vision, LLM_IF_OAI_Json]);
|
||||
|
||||
// NOTE: update the `fromManualMapping` function if you add new fields
|
||||
const modelDescriptionSchema = z.object({
|
||||
id: z.string(),
|
||||
label: z.string(),
|
||||
@@ -28,10 +31,11 @@ const modelDescriptionSchema = z.object({
|
||||
maxCompletionTokens: z.number().optional(),
|
||||
// rateLimits: rateLimitsSchema.optional(),
|
||||
trainingDataCutoff: z.string().optional(),
|
||||
interfaces: z.array(z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, LLM_IF_OAI_Vision, LLM_IF_OAI_Json])),
|
||||
pricing: pricingSchema.optional(),
|
||||
interfaces: z.array(interfaceSchema),
|
||||
benchmark: benchmarkSchema.optional(),
|
||||
pricing: pricingSchema.optional(),
|
||||
hidden: z.boolean().optional(),
|
||||
// TODO: add inputTypes/Kinds..
|
||||
});
|
||||
|
||||
// this is also used by the Client
|
||||
|
||||
@@ -607,11 +607,11 @@ export function openRouterModelToModelDescription(wireModel: object): ModelDescr
|
||||
chatIn: parseFloat(model.pricing.prompt) * 1000,
|
||||
chatOut: parseFloat(model.pricing.completion),
|
||||
};
|
||||
const isFree = pricing.chatIn === 0 && pricing.chatOut === 0;
|
||||
const seemsFree = pricing.chatIn === 0 && pricing.chatOut === 0;
|
||||
|
||||
// openrouter provides the fields we need as part of the model object
|
||||
let label = model.name || model.id.replace('/', ' · ');
|
||||
if (isFree)
|
||||
if (seemsFree)
|
||||
label += ' · 🎁'; // Free? Discounted?
|
||||
|
||||
// hidden: hide by default older models or models not in known families
|
||||
@@ -626,8 +626,10 @@ export function openRouterModelToModelDescription(wireModel: object): ModelDescr
|
||||
description: model.description,
|
||||
contextWindow: model.context_length || 4096,
|
||||
maxCompletionTokens: model.top_provider.max_completion_tokens || undefined,
|
||||
pricing,
|
||||
// trainingDataCutoff: ...
|
||||
interfaces: [LLM_IF_OAI_Chat],
|
||||
// benchmark: ...
|
||||
pricing,
|
||||
hidden,
|
||||
});
|
||||
}
|
||||
@@ -905,8 +907,10 @@ function fromManualMapping(mappings: ManualMappings, id: string, created?: numbe
|
||||
description: known.description,
|
||||
contextWindow: known.contextWindow,
|
||||
...(!!known.maxCompletionTokens && { maxCompletionTokens: known.maxCompletionTokens }),
|
||||
...(!!known.pricing && { pricing: known.pricing }),
|
||||
...(!!known.trainingDataCutoff && { trainingDataCutoff: known.trainingDataCutoff }),
|
||||
interfaces: known.interfaces,
|
||||
...(!!known.benchmark && { benchmark: known.benchmark }),
|
||||
...(!!known.pricing && { pricing: known.pricing }),
|
||||
...(!!known.hidden && { hidden: known.hidden }),
|
||||
};
|
||||
}
|
||||
+103
-57
@@ -1,5 +1,4 @@
|
||||
import { create } from 'zustand';
|
||||
import { shallow } from 'zustand/shallow';
|
||||
import { persist } from 'zustand/middleware';
|
||||
|
||||
import type { ModelVendorId } from './vendors/vendors.registry';
|
||||
@@ -11,18 +10,36 @@ import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vend
|
||||
*/
|
||||
export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
|
||||
id: DLLMId;
|
||||
|
||||
// editable properties (kept on update, if isEdited)
|
||||
label: string;
|
||||
created: number | 0;
|
||||
updated?: number | 0;
|
||||
description: string;
|
||||
tags: string[]; // UNUSED for now
|
||||
// modelcaps: DModelCapability[];
|
||||
hidden: boolean; // hidden from UI selectors
|
||||
isEdited?: boolean; // user has edited the soft properties
|
||||
|
||||
// hard properties (overwritten on update)
|
||||
contextTokens: number | null; // null: must assume it's unknown
|
||||
maxOutputTokens: number | null; // null: must assume it's unknown
|
||||
hidden: boolean; // hidden from Chat model UI selectors
|
||||
trainingDataCutoff?: string; // [v2] 'Apr 2029'
|
||||
interfaces: DModelInterfaceV1[]; // [v2] if set, meaning this is the known and comprehensive set of interfaces
|
||||
// inputTypes: { // [v2] the supported input formats
|
||||
// [key in DModelPartKind]?: {
|
||||
// // maxItemsPerInput?: number;
|
||||
// // maxFileSize?: number; // in bytes
|
||||
// // maxDurationPerInput?: number; // in seconds, for audio and video
|
||||
// // maxPagesPerInput?: number; // for PDF
|
||||
// // encodings?: ('base64' | 'utf-8')[];
|
||||
// mimeTypes?: string[];
|
||||
// }
|
||||
// };
|
||||
benchmark?: { cbaElo?: number, cbaMmlu?: number }; // [v2] benchmark values
|
||||
pricing?: { chatIn?: number, chatOut?: number }; // [v2] cost per million tokens
|
||||
|
||||
// temporary special flags - not graduated yet
|
||||
isFree: boolean; // model is free to use
|
||||
// derived properties
|
||||
tmpIsFree?: boolean; // model is free to use [temporary, for now], this is a derived property from the pricing
|
||||
tmpIsVision?: boolean; // model can take image inputs
|
||||
|
||||
// llm -> source
|
||||
sId: DModelSourceId;
|
||||
@@ -34,6 +51,25 @@ export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
|
||||
|
||||
export type DLLMId = string;
|
||||
|
||||
// export type DModelPartKind = 'text' | 'image' | 'audio' | 'video' | 'pdf';
|
||||
|
||||
export type DModelInterfaceV1 =
|
||||
// do not change anything below! those will be persisted in data
|
||||
| 'oai-chat'
|
||||
| 'oai-chat-json'
|
||||
| 'oai-chat-vision'
|
||||
| 'oai-chat-fn'
|
||||
| 'oai-complete'
|
||||
// only append below this line
|
||||
;
|
||||
|
||||
// Model interfaces (chat, and function calls) - here as a preview, will be used more broadly in the future
|
||||
export const LLM_IF_OAI_Chat: DModelInterfaceV1 = 'oai-chat';
|
||||
export const LLM_IF_OAI_Json: DModelInterfaceV1 = 'oai-chat-json';
|
||||
export const LLM_IF_OAI_Vision: DModelInterfaceV1 = 'oai-chat-vision';
|
||||
export const LLM_IF_OAI_Fn: DModelInterfaceV1 = 'oai-chat-fn';
|
||||
export const LLM_IF_OAI_Complete: DModelInterfaceV1 = 'oai-complete';
|
||||
|
||||
// export type DModelCapability =
|
||||
// | 'input-text'
|
||||
// | 'input-image-data'
|
||||
@@ -44,13 +80,7 @@ export type DLLMId = string;
|
||||
// | 'if-chat'
|
||||
// | 'if-fast-chat'
|
||||
// ;
|
||||
|
||||
// Model interfaces (chat, and function calls) - here as a preview, will be used more broadly in the future
|
||||
export const LLM_IF_OAI_Chat = 'oai-chat';
|
||||
export const LLM_IF_OAI_Json = 'oai-chat-json';
|
||||
export const LLM_IF_OAI_Vision = 'oai-chat-vision';
|
||||
export const LLM_IF_OAI_Fn = 'oai-chat-fn';
|
||||
export const LLM_IF_OAI_Complete = 'oai-complete';
|
||||
// modelcaps: DModelCapability[];
|
||||
|
||||
|
||||
/**
|
||||
@@ -231,8 +261,9 @@ export const useModelsStore = create<LlmsStore>()(
|
||||
|
||||
/* versioning:
|
||||
* 1: adds maxOutputTokens (default to half of contextTokens)
|
||||
* 2: large changes on all LLMs, and reset chat/fast/func LLMs
|
||||
*/
|
||||
version: 1,
|
||||
version: 2,
|
||||
migrate: (state: any, fromVersion: number): LlmsStore => {
|
||||
|
||||
// 0 -> 1: add 'maxOutputTokens' where missing
|
||||
@@ -241,6 +272,19 @@ export const useModelsStore = create<LlmsStore>()(
|
||||
if (llm.maxOutputTokens === undefined)
|
||||
llm.maxOutputTokens = llm.contextTokens ? Math.round(llm.contextTokens / 2) : null;
|
||||
|
||||
// 1 -> 2: large changes
|
||||
if (state && fromVersion < 2) {
|
||||
for (const llm of state.llms) {
|
||||
delete llm['tags'];
|
||||
llm.interfaces = [LLM_IF_OAI_Chat];
|
||||
// llm.inputTypes = { 'text': {} };
|
||||
}
|
||||
// const autoPickModels = updateSelectedIds(state.llms, null, null, null);
|
||||
state.chatLLMId = null; // autoPickModels.chatLLMId;
|
||||
state.fastLLMId = null; // ...
|
||||
state.funcLLMId = null; // ...
|
||||
}
|
||||
|
||||
return state;
|
||||
},
|
||||
|
||||
@@ -285,64 +329,66 @@ export function findSourceOrThrow<TSourceSetup>(sourceId: DModelSourceId) {
|
||||
}
|
||||
|
||||
|
||||
const modelsKnowledgeMap: { contains: string[], cutoff: string }[] = [
|
||||
{ contains: ['4-0125', '4-turbo'], cutoff: '2023-12' },
|
||||
{ contains: ['4-1106', '4-vision'], cutoff: '2023-04' },
|
||||
{ contains: ['4-0613', '4-0314', '4-32k', '3.5-turbo'], cutoff: '2021-09' },
|
||||
] as const;
|
||||
/**
 * Buckets the visible (non-hidden) LLMs by the vendor of their source, and ranks them.
 *
 * - within each vendor, models are ordered by `benchmark.cbaElo`, highest first
 *   (models without an ELO are treated as -1)
 * - vendors are ordered by the ELO of their top-ranked model, highest first
 *
 * @param llms the full list of LLMs; hidden ones are ignored
 * @returns one entry per vendor that has at least one visible model
 */
function groupLlmsByVendor(llms: DLLM[]): { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[] {
  const byVendor: { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[] = [];

  // bucket the visible models, creating the vendor bucket on first sight
  for (const llm of llms) {
    if (llm.hidden) continue;
    const vendorId = llm._source.vId;
    const entry = { id: llm.id, cbaElo: llm.benchmark?.cbaElo };
    const bucket = byVendor.find(v => v.vendorId === vendorId);
    if (bucket)
      bucket.llmsByElo.push(entry);
    else
      byVendor.push({ vendorId, llmsByElo: [entry] });
  }

  // missing ELO ranks below any real score
  const eloOrLowest = (elo: number | undefined) => elo ?? -1;

  // rank the models of each vendor, best first
  for (const bucket of byVendor)
    bucket.llmsByElo.sort((a, b) => eloOrLowest(b.cbaElo) - eloOrLowest(a.cbaElo));

  // rank the vendors by their best model (every bucket holds at least one model)
  byVendor.sort((a, b) => eloOrLowest(b.llmsByElo[0].cbaElo) - eloOrLowest(a.llmsByElo[0].cbaElo));
  return byVendor;
}

/**
 * Looks up a knowledge cutoff for a model id using the `modelsKnowledgeMap` table:
 * returns the `cutoff` of the first entry for which the id includes one of its
 * `contains` substrings, or null when no id is given or nothing matches.
 */
export function getKnowledgeMapCutoff(llmId?: DLLMId): string | null {
  if (llmId)
    for (const { contains, cutoff } of modelsKnowledgeMap)
      if (contains.some(c => llmId.includes(c)))
        return cutoff;
  return null;
}
|
||||
|
||||
const defaultChatSuffixPreference = ['gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-4-0613', 'gpt-4', 'gpt-4-32k', 'gpt-3.5-turbo'];
|
||||
const defaultFastSuffixPreference = ['gpt-3.5-turbo-0125', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo'];
|
||||
const defaultFuncSuffixPreference = ['gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-0613', 'gpt-4-0613'];
|
||||
|
||||
/**
 * Validates the currently selected chat/fast/func LLM ids against the available
 * models, and picks a replacement for every id that is missing or no longer present.
 *
 * Replacement heuristics (based on the output of groupLlmsByVendor):
 * - chat: the top-ranked model of the top-ranked vendor
 * - fast: the lowest-ranked model of the top-ranked vendor that has an ELO score,
 *         or that vendor's last model when none has a score
 * - func: the (possibly just re-picked) chat model
 *
 * @param allLlms the models currently available
 * @param chatLlmId the selected chat model id, or null
 * @param fastLlmId the selected fast model id, or null
 * @param funcLlmId the selected function-calling model id, or null
 * @returns the three ids; each is null only when no visible model is available
 */
export function updateSelectedIds(allLlms: DLLM[], chatLlmId: DLLMId | null, fastLlmId: DLLMId | null, funcLlmId: DLLMId | null) {

  // the output of groupLlmsByVendor, computed at most once and only if a fallback is needed
  let grouped: ReturnType<typeof groupLlmsByVendor> | null = null;

  function cachedGrouped() {
    if (!grouped) grouped = groupLlmsByVendor(allLlms);
    return grouped;
  }

  // an id is usable only when it is set AND still refers to an available model
  const isKnownId = (id: DLLMId | null): boolean =>
    !!id && allLlms.some(llm => llm.id === id);

  // the best llm
  if (!isKnownId(chatLlmId)) {
    const vendors = cachedGrouped();
    chatLlmId = vendors.length ? vendors[0].llmsByElo[0].id : null;
  }

  // a fast llm (bottom elo of the top vendor ~~ not really a proxy, but not sure which heuristic to use here)
  // NOTE: this guard used to read `!fastLlmId && !find(...)`, which never replaced a stale
  //       (non-null but no longer available) id; it now matches the chat/func guards
  if (!isKnownId(fastLlmId)) {
    const vendors = cachedGrouped();
    fastLlmId = vendors.length
      ? vendors[0].llmsByElo.findLast(llm => llm.cbaElo)?.id // last with ELO
      ?? vendors[0].llmsByElo[vendors[0].llmsByElo.length - 1].id ?? null // last
      : null;
  }

  // a func llm: falls back to the chat llm
  if (!isKnownId(funcLlmId))
    funcLlmId = chatLlmId;

  return { chatLLMId: chatLlmId, fastLLMId: fastLlmId, funcLLMId: funcLlmId };
}
|
||||
|
||||
/**
 * Picks an LLM id by preference: for each suffix, in order, returns the first
 * model whose id ends with that suffix.
 *
 * @param llms candidate models
 * @param suffixes id suffixes, most preferred first
 * @param fallbackToFirst when nothing matches: if true, returns the first non-hidden
 *   model (or the very first model if all are hidden); if false, returns null
 * @returns the chosen id, or null when the list is empty or nothing qualifies
 */
function findLlmIdBySuffix(llms: DLLM[], suffixes: string[], fallbackToFirst: boolean): DLLMId | null {
  if (!llms?.length) return null;

  // suffix order wins over model order
  for (const wanted of suffixes) {
    const match = llms.find(candidate => candidate.id.endsWith(wanted));
    if (match) return match.id;
  }

  if (!fallbackToFirst) return null;

  // prefer a visible model, otherwise settle for the first one
  const firstVisible = llms.find(candidate => !candidate.hidden);
  return (firstVisible ?? llms[0]).id;
}
|
||||
|
||||
|
||||
/**
|
||||
* Current 'Chat' LLM, or null
|
||||
*/
|
||||
export function useChatLLM() {
  // select the LLM object itself, so the hook's selected value is the model (or null)
  const chatLLM = useModelsStore(({ chatLLMId, llms }) => {
    if (!chatLLMId) return null;
    return llms.find(llm => llm.id === chatLLMId) ?? null;
  });
  return { chatLLM };
}
|
||||
|
||||
export function getLLMsDebugInfo() {
|
||||
|
||||
@@ -61,7 +61,7 @@ export const ModelVendorOpenRouter: IModelVendor<SourceSetupOpenRouter, OpenAIAc
|
||||
getRateLimitDelay: (llm) => {
|
||||
const now = Date.now();
|
||||
const elapsed = now - nextGenerationTs;
|
||||
const wait = llm.isFree
|
||||
const wait = llm.tmpIsFree
|
||||
? 5000 + 100 /* 5 seconds for free call, plus some safety margin */
|
||||
: 100;
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { DLLMId, getKnowledgeMapCutoff } from '~/modules/llms/store-llms';
|
||||
import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
|
||||
|
||||
import { browserLangOrUS } from '~/common/util/pwaUtils';
|
||||
|
||||
@@ -78,7 +78,13 @@ export function bareBonesPromptMixer(_template: string, assistantLlmId: DLLMId |
|
||||
mixed = mixed.replace('{{ToolBrowser0}}', 'Web browsing capabilities: Disabled');
|
||||
|
||||
// {{Cutoff}} or remove the line
|
||||
const varCutoff = getKnowledgeMapCutoff(assistantLlmId);
|
||||
let varCutoff: string | undefined;
|
||||
try {
|
||||
if (assistantLlmId)
|
||||
varCutoff = findLLMOrThrow(assistantLlmId).trainingDataCutoff;
|
||||
} catch (e) {
|
||||
// ignore...
|
||||
}
|
||||
if (varCutoff)
|
||||
mixed = mixed.replaceAll('{{Cutoff}}', varCutoff);
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user