Models: upgrade data structure to v2

This commit is contained in:
Enrico Ros
2024-04-12 05:36:18 -07:00
parent 14041b6012
commit b924d331f9
8 changed files with 145 additions and 73 deletions
+14 -4
View File
@@ -1,7 +1,7 @@
import type { ModelDescriptionSchema } from './server/llm.server.types';
import type { OpenAIWire } from './server/openai/openai.wiretypes';
import type { StreamingClientUpdate } from './vendors/unifiedStreamingClient';
import { DLLM, DLLMId, DModelSource, DModelSourceId, useModelsStore } from './store-llms';
import { DLLM, DLLMId, DModelSource, DModelSourceId, LLM_IF_OAI_Chat, useModelsStore } from './store-llms';
import { FALLBACK_LLM_TEMPERATURE } from './vendors/openai/openai.vendor';
import { findAccessForSourceOrThrow, findVendorForLlmOrThrow } from './vendors/vendors.registry';
@@ -62,16 +62,26 @@ function modelDescriptionToDLLMOpenAIOptions<TSourceSetup, TLLMOptions>(model: M
return {
id: `${source.id}-${model.id}`,
// editable properties
label: model.label,
created: model.created || 0,
updated: model.updated || 0,
description: model.description,
tags: [], // ['stream', 'chat'],
hidden: !!model.hidden,
// isEdited: false, // NOTE: this is set by the store on user edits
// hard properties
contextTokens,
maxOutputTokens,
hidden: !!model.hidden,
trainingDataCutoff: model.trainingDataCutoff,
interfaces: model.interfaces?.length ? model.interfaces : [LLM_IF_OAI_Chat],
// inputTypes: ...
benchmark: model.benchmark,
pricing: model.pricing,
isFree: model.pricing?.chatIn === 0 && model.pricing?.chatOut === 0,
// derived properties
tmpIsFree: model.pricing?.chatIn === 0 && model.pricing?.chatOut === 0,
tmpIsVision: model.interfaces?.includes(LLM_IF_OAI_Chat) === true,
sId: source.id,
_source: source,
@@ -126,7 +126,7 @@ export function LLMOptionsModal(props: { id: DLLMId, onClose: () => void }) {
<Typography level='body-md'>
{llm.id}
</Typography>
{llm.isFree && <Typography level='body-xs'>
{!!llm.tmpIsFree && <Typography level='body-xs'>
🎁 Free model - note: refresh models to check for updates in pricing
</Typography>}
{!!llm.description && <Typography level='body-xs'>
@@ -61,9 +61,11 @@ export function geminiModelToModelDescription(geminiModel: GeminiModelSchema, al
description: descriptionLong,
contextWindow: contextWindow,
maxCompletionTokens: outputTokenLimit,
// pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined,
// rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined,
// trainingDataCutoff: '...',
interfaces,
// rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined,
// benchmark: ...
// pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined,
hidden,
};
}
+6 -2
View File
@@ -18,6 +18,9 @@ const benchmarkSchema = z.object({
// reqPerMinute: z.number().optional(),
// });
const interfaceSchema = z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, LLM_IF_OAI_Vision, LLM_IF_OAI_Json]);
// NOTE: update the `fromManualMapping` function if you add new fields
const modelDescriptionSchema = z.object({
id: z.string(),
label: z.string(),
@@ -28,10 +31,11 @@ const modelDescriptionSchema = z.object({
maxCompletionTokens: z.number().optional(),
// rateLimits: rateLimitsSchema.optional(),
trainingDataCutoff: z.string().optional(),
interfaces: z.array(z.enum([LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Complete, LLM_IF_OAI_Vision, LLM_IF_OAI_Json])),
pricing: pricingSchema.optional(),
interfaces: z.array(interfaceSchema),
benchmark: benchmarkSchema.optional(),
pricing: pricingSchema.optional(),
hidden: z.boolean().optional(),
// TODO: add inputTypes/Kinds..
});
// this is also used by the Client
@@ -607,11 +607,11 @@ export function openRouterModelToModelDescription(wireModel: object): ModelDescr
chatIn: parseFloat(model.pricing.prompt) * 1000,
chatOut: parseFloat(model.pricing.completion),
};
const isFree = pricing.chatIn === 0 && pricing.chatOut === 0;
const seemsFree = pricing.chatIn === 0 && pricing.chatOut === 0;
// openrouter provides the fields we need as part of the model object
let label = model.name || model.id.replace('/', ' · ');
if (isFree)
if (seemsFree)
label += ' · 🎁'; // Free? Discounted?
// hidden: hide by default older models or models not in known families
@@ -626,8 +626,10 @@ export function openRouterModelToModelDescription(wireModel: object): ModelDescr
description: model.description,
contextWindow: model.context_length || 4096,
maxCompletionTokens: model.top_provider.max_completion_tokens || undefined,
pricing,
// trainingDataCutoff: ...
interfaces: [LLM_IF_OAI_Chat],
// benchmark: ...
pricing,
hidden,
});
}
@@ -905,8 +907,10 @@ function fromManualMapping(mappings: ManualMappings, id: string, created?: numbe
description: known.description,
contextWindow: known.contextWindow,
...(!!known.maxCompletionTokens && { maxCompletionTokens: known.maxCompletionTokens }),
...(!!known.pricing && { pricing: known.pricing }),
...(!!known.trainingDataCutoff && { trainingDataCutoff: known.trainingDataCutoff }),
interfaces: known.interfaces,
...(!!known.benchmark && { benchmark: known.benchmark }),
...(!!known.pricing && { pricing: known.pricing }),
...(!!known.hidden && { hidden: known.hidden }),
};
}
+103 -57
View File
@@ -1,5 +1,4 @@
import { create } from 'zustand';
import { shallow } from 'zustand/shallow';
import { persist } from 'zustand/middleware';
import type { ModelVendorId } from './vendors/vendors.registry';
@@ -11,18 +10,36 @@ import type { SourceSetupOpenRouter } from './vendors/openrouter/openrouter.vend
*/
export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
id: DLLMId;
// editable properties (kept on update, if isEdited)
label: string;
created: number | 0;
updated?: number | 0;
description: string;
tags: string[]; // UNUSED for now
// modelcaps: DModelCapability[];
hidden: boolean; // hidden from UI selectors
isEdited?: boolean; // user has edited the soft properties
// hard properties (overwritten on update)
contextTokens: number | null; // null: must assume it's unknown
maxOutputTokens: number | null; // null: must assume it's unknown
hidden: boolean; // hidden from Chat model UI selectors
trainingDataCutoff?: string; // [v2] 'Apr 2029'
interfaces: DModelInterfaceV1[]; // [v2] if set, meaning this is the known and comprehensive set of interfaces
// inputTypes: { // [v2] the supported input formats
// [key in DModelPartKind]?: {
// // maxItemsPerInput?: number;
// // maxFileSize?: number; // in bytes
// // maxDurationPerInput?: number; // in seconds, for audio and video
// // maxPagesPerInput?: number; // for PDF
// // encodings?: ('base64' | 'utf-8')[];
// mimeTypes?: string[];
// }
// };
benchmark?: { cbaElo?: number, cbaMmlu?: number }; // [v2] benchmark values
pricing?: { chatIn?: number, chatOut?: number }; // [v2] cost per million tokens
// temporary special flags - not graduated yet
isFree: boolean; // model is free to use
// derived properties
tmpIsFree?: boolean; // model is free to use [temporary, for now], this is a derived property from the pricing
tmpIsVision?: boolean; // model can take image inputs
// llm -> source
sId: DModelSourceId;
@@ -34,6 +51,25 @@ export interface DLLM<TSourceSetup = unknown, TLLMOptions = unknown> {
export type DLLMId = string;
// export type DModelPartKind = 'text' | 'image' | 'audio' | 'video' | 'pdf';
export type DModelInterfaceV1 =
// do not change anything below! those will be persisted in data
| 'oai-chat'
| 'oai-chat-json'
| 'oai-chat-vision'
| 'oai-chat-fn'
| 'oai-complete'
// only append below this line
;
// Model interfaces (chat, and function calls) - here as a preview, will be used more broadly in the future
export const LLM_IF_OAI_Chat: DModelInterfaceV1 = 'oai-chat';
export const LLM_IF_OAI_Json: DModelInterfaceV1 = 'oai-chat-json';
export const LLM_IF_OAI_Vision: DModelInterfaceV1 = 'oai-chat-vision';
export const LLM_IF_OAI_Fn: DModelInterfaceV1 = 'oai-chat-fn';
export const LLM_IF_OAI_Complete: DModelInterfaceV1 = 'oai-complete';
// export type DModelCapability =
// | 'input-text'
// | 'input-image-data'
@@ -44,13 +80,7 @@ export type DLLMId = string;
// | 'if-chat'
// | 'if-fast-chat'
// ;
// Model interfaces (chat, and function calls) - here as a preview, will be used more broadly in the future
export const LLM_IF_OAI_Chat = 'oai-chat';
export const LLM_IF_OAI_Json = 'oai-chat-json';
export const LLM_IF_OAI_Vision = 'oai-chat-vision';
export const LLM_IF_OAI_Fn = 'oai-chat-fn';
export const LLM_IF_OAI_Complete = 'oai-complete';
// modelcaps: DModelCapability[];
/**
@@ -231,8 +261,9 @@ export const useModelsStore = create<LlmsStore>()(
/* versioning:
* 1: adds maxOutputTokens (default to half of contextTokens)
* 2: large changes on all LLMs, and reset chat/fast/func LLMs
*/
version: 1,
version: 2,
migrate: (state: any, fromVersion: number): LlmsStore => {
// 0 -> 1: add 'maxOutputTokens' where missing
@@ -241,6 +272,19 @@ export const useModelsStore = create<LlmsStore>()(
if (llm.maxOutputTokens === undefined)
llm.maxOutputTokens = llm.contextTokens ? Math.round(llm.contextTokens / 2) : null;
// 1 -> 2: large changes
if (state && fromVersion < 2) {
for (const llm of state.llms) {
delete llm['tags'];
llm.interfaces = [LLM_IF_OAI_Chat];
// llm.inputTypes = { 'text': {} };
}
// const autoPickModels = updateSelectedIds(state.llms, null, null, null);
state.chatLLMId = null; // autoPickModels.chatLLMId;
state.fastLLMId = null; // ...
state.funcLLMId = null; // ...
}
return state;
},
@@ -285,64 +329,66 @@ export function findSourceOrThrow<TSourceSetup>(sourceId: DModelSourceId) {
}
const modelsKnowledgeMap: { contains: string[], cutoff: string }[] = [
{ contains: ['4-0125', '4-turbo'], cutoff: '2023-12' },
{ contains: ['4-1106', '4-vision'], cutoff: '2023-04' },
{ contains: ['4-0613', '4-0314', '4-32k', '3.5-turbo'], cutoff: '2021-09' },
] as const;
function groupLlmsByVendor(llms: DLLM[]): { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[] {
// group all LLMs by vendor
const grouped = llms.reduce((acc, llm) => {
if (llm.hidden) return acc;
const vendorId = llm._source.vId;
const vendor = acc.find(v => v.vendorId === vendorId);
if (!vendor) acc.push({ vendorId, llmsByElo: [{ id: llm.id, cbaElo: llm.benchmark?.cbaElo }] });
else vendor.llmsByElo.push({ id: llm.id, cbaElo: llm.benchmark?.cbaElo });
return acc;
}, [] as { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[]);
export function getKnowledgeMapCutoff(llmId?: DLLMId): string | null {
if (llmId)
for (const { contains, cutoff } of modelsKnowledgeMap)
if (contains.some(c => llmId.includes(c)))
return cutoff;
return null;
// sort each vendor's LLMs by elo, decreasing
for (const vendor of grouped)
vendor.llmsByElo.sort((a, b) => (b.cbaElo ?? -1) - (a.cbaElo ?? -1));
// sort all vendors by their highest elo, decreasing
grouped.sort((a, b) => (b.llmsByElo[0].cbaElo ?? -1) - (a.llmsByElo[0].cbaElo ?? -1));
return grouped;
}
const defaultChatSuffixPreference = ['gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-4-0613', 'gpt-4', 'gpt-4-32k', 'gpt-3.5-turbo'];
const defaultFastSuffixPreference = ['gpt-3.5-turbo-0125', 'gpt-3.5-turbo-1106', 'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-0613', 'gpt-3.5-turbo-16k', 'gpt-3.5-turbo'];
const defaultFuncSuffixPreference = ['gpt-4-0125-preview', 'gpt-4-1106-preview', 'gpt-3.5-turbo-16k-0613', 'gpt-3.5-turbo-0613', 'gpt-4-0613'];
function updateSelectedIds(allLlms: DLLM[], chatLlmId: DLLMId | null, fastLlmId: DLLMId | null, funcLlmId: DLLMId | null): Partial<ModelsData> {
if (chatLlmId && !allLlms.find(llm => llm.id === chatLlmId)) chatLlmId = null;
if (!chatLlmId) chatLlmId = findLlmIdBySuffix(allLlms, defaultChatSuffixPreference, true);
export function updateSelectedIds(allLlms: DLLM[], chatLlmId: DLLMId | null, fastLlmId: DLLMId | null, funcLlmId: DLLMId | null) {
if (fastLlmId && !allLlms.find(llm => llm.id === fastLlmId)) fastLlmId = null;
if (!fastLlmId) fastLlmId = findLlmIdBySuffix(allLlms, defaultFastSuffixPreference, true);
// the output of groupLlmsByVendor
let grouped: ReturnType<typeof groupLlmsByVendor> | null = null;
if (funcLlmId && !allLlms.find(llm => llm.id === funcLlmId)) funcLlmId = null;
if (!funcLlmId) funcLlmId = findLlmIdBySuffix(allLlms, defaultFuncSuffixPreference, false);
function cachedGrouped() {
if (!grouped) grouped = groupLlmsByVendor(allLlms);
return grouped;
}
// the best llm
if (!chatLlmId || !allLlms.find(llm => llm.id === chatLlmId)) {
const vendors = cachedGrouped();
chatLlmId = vendors.length ? vendors[0].llmsByElo[0].id : null;
}
// a fast llm (bottom elo of the top vendor ~~ not really a proxy, but not sure which heuristic to use here)
if (!fastLlmId && !allLlms.find(llm => llm.id === fastLlmId)) {
const vendors = cachedGrouped();
fastLlmId = vendors.length
? vendors[0].llmsByElo.findLast(llm => llm.cbaElo)?.id // last with ELO
?? vendors[0].llmsByElo[vendors[0].llmsByElo.length - 1].id ?? null // last
: null;
}
// a func llm (falls back to the chat llm)
if (!funcLlmId || !allLlms.find(llm => llm.id === funcLlmId))
funcLlmId = chatLlmId;
return { chatLLMId: chatLlmId, fastLLMId: fastLlmId, funcLLMId: funcLlmId };
}
function findLlmIdBySuffix(llms: DLLM[], suffixes: string[], fallbackToFirst: boolean): DLLMId | null {
if (!llms?.length) return null;
for (const suffix of suffixes)
for (const llm of llms)
if (llm.id.endsWith(suffix))
return llm.id;
if (!fallbackToFirst) return null;
// otherwise return first that's not hidden
for (const llm of llms)
if (!llm.hidden)
return llm.id;
// otherwise return first id
return llms[0].id;
}
/**
* Current 'Chat' LLM, or null
*/
export function useChatLLM() {
return useModelsStore(state => {
const { chatLLMId } = state;
const chatLLM = chatLLMId ? state.llms.find(llm => llm.id === chatLLMId) ?? null : null;
return { chatLLM };
}, shallow);
const chatLLM = useModelsStore(state => state.chatLLMId ? state.llms.find(llm => llm.id === state.chatLLMId) ?? null : null);
return { chatLLM };
}
export function getLLMsDebugInfo() {
+1 -1
View File
@@ -61,7 +61,7 @@ export const ModelVendorOpenRouter: IModelVendor<SourceSetupOpenRouter, OpenAIAc
getRateLimitDelay: (llm) => {
const now = Date.now();
const elapsed = now - nextGenerationTs;
const wait = llm.isFree
const wait = llm.tmpIsFree
? 5000 + 100 /* 5 seconds for free call, plus some safety margin */
: 100;
+8 -2
View File
@@ -1,4 +1,4 @@
import { DLLMId, getKnowledgeMapCutoff } from '~/modules/llms/store-llms';
import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
import { browserLangOrUS } from '~/common/util/pwaUtils';
@@ -78,7 +78,13 @@ export function bareBonesPromptMixer(_template: string, assistantLlmId: DLLMId |
mixed = mixed.replace('{{ToolBrowser0}}', 'Web browsing capabilities: Disabled');
// {{Cutoff}} or remove the line
const varCutoff = getKnowledgeMapCutoff(assistantLlmId);
let varCutoff: string | undefined;
try {
if (assistantLlmId)
varCutoff = findLLMOrThrow(assistantLlmId).trainingDataCutoff;
} catch (e) {
// ignore...
}
if (varCutoff)
mixed = mixed.replaceAll('{{Cutoff}}', varCutoff);
else