mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-10 21:50:14 -07:00
Improve Fast LLM auto-select (by price)
This commit is contained in:
@@ -3,7 +3,7 @@ import * as React from 'react';
|
||||
import type { SxProps } from '@mui/joy/styles/types';
|
||||
import { Box, ColorPaletteProp, Tooltip } from '@mui/joy';
|
||||
|
||||
import { DChatGeneratePricing, getLlmPriceForTokens } from '~/common/stores/llms/llms.pricing';
|
||||
import { DChatGeneratePricing, getLlmCostForTokens } from '~/common/stores/llms/llms.pricing';
|
||||
import { adjustContentScaling, themeScalingMap } from '~/common/app.theme';
|
||||
import { formatModelsCost } from '~/common/util/costUtils';
|
||||
import { useUIContentScaling } from '~/common/state/store-ui';
|
||||
@@ -43,8 +43,8 @@ export function tokenCountsMathAndMessage(tokenLimit: number | 0, directTokens:
|
||||
|
||||
// add the price, if available
|
||||
if (chatPricing) {
|
||||
const inputPrice = getLlmPriceForTokens(usedInputTokens, usedInputTokens, chatPricing.input);
|
||||
const outputPrice = getLlmPriceForTokens(usedInputTokens, responseMaxTokens || 0, chatPricing.output);
|
||||
const inputPrice = getLlmCostForTokens(usedInputTokens, usedInputTokens, chatPricing.input);
|
||||
const outputPrice = getLlmCostForTokens(usedInputTokens, responseMaxTokens || 0, chatPricing.output);
|
||||
|
||||
costMin = inputPrice;
|
||||
const costOutMax = outputPrice;
|
||||
|
||||
@@ -9,21 +9,26 @@ export type DModelPricing = {
|
||||
chat?: DChatGeneratePricing,
|
||||
}
|
||||
|
||||
// NOTE: (!) keep this in sync with ChatGeneratePricing_schema (modules/llms/server/llm.server.types.ts)
|
||||
export type DChatGeneratePricing = {
|
||||
// unit: 'USD_Mtok',
|
||||
input?: DTieredPrice;
|
||||
output?: DTieredPrice;
|
||||
input?: DTieredPricing;
|
||||
output?: DTieredPricing;
|
||||
cache?: {
|
||||
cType: 'ant-bp';
|
||||
read: DTieredPrice;
|
||||
write: DTieredPrice;
|
||||
read: DTieredPricing;
|
||||
write: DTieredPricing;
|
||||
duration: number; // seconds
|
||||
} | {
|
||||
cType: 'oai-apc';
|
||||
read: DTieredPricing;
|
||||
// write: DTieredPricing; // Not needed, as it's automatic
|
||||
};
|
||||
// NOT in AixWire_API_ListModels.PriceChatGenerate_schema
|
||||
// NOT in AixWire_API_ListModels.ChatGeneratePricing_schema
|
||||
_isFree?: boolean; // precomputed, so we avoid recalculating it
|
||||
}
|
||||
|
||||
type DTieredPrice = DPricePerMToken | DPriceUpTo[];
|
||||
type DTieredPricing = DPricePerMToken | DPriceUpTo[];
|
||||
|
||||
type DPriceUpTo = {
|
||||
upTo: number | null,
|
||||
@@ -35,46 +40,43 @@ type DPricePerMToken = number | 'free';
|
||||
|
||||
/// detect Free Pricing
|
||||
|
||||
export function isModelPriceFree(priceChatGenerate: DChatGeneratePricing): boolean {
|
||||
if (!priceChatGenerate) return true;
|
||||
return _isPriceFree(priceChatGenerate.input) && _isPriceFree(priceChatGenerate.output);
|
||||
export function isModelPricingFree(pricingChatGenerate: DChatGeneratePricing): boolean {
|
||||
if (!pricingChatGenerate) return true;
|
||||
return _isPricingFree(pricingChatGenerate.input) && _isPricingFree(pricingChatGenerate.output);
|
||||
}
|
||||
|
||||
function _isPriceFree(price: DTieredPrice | undefined): boolean {
|
||||
if (price === 'free') return true;
|
||||
if (price === undefined) return false;
|
||||
if (typeof price === 'number') return price === 0;
|
||||
return price.every(tier => _isPricePerMTokenFree(tier.price));
|
||||
}
|
||||
|
||||
function _isPricePerMTokenFree(price: DPricePerMToken): boolean {
|
||||
return price === 'free' || price === 0;
|
||||
function _isPricingFree(pricing: DTieredPricing | undefined): boolean {
|
||||
if (pricing === 'free') return true;
|
||||
if (pricing === undefined) return false;
|
||||
if (typeof pricing === 'number') return pricing === 0;
|
||||
return pricing.every(tier => tier.price === 'free' || tier.price === 0);
|
||||
}
|
||||
|
||||
|
||||
/// Human readable price formatting
|
||||
/// Human readable cost
|
||||
|
||||
export function getLlmPriceForTokens(inputTokens: number, tokens: number, pricing: DTieredPrice | undefined): number | undefined {
|
||||
export function getLlmCostForTokens(inputTokens: number, tokens: number, pricing: DTieredPricing | undefined): number | undefined {
|
||||
if (!pricing) return undefined;
|
||||
if (pricing === 'free') return 0;
|
||||
|
||||
// Cost = tokens * price / 1e6
|
||||
if (typeof pricing === 'number') return tokens * pricing / 1e6;
|
||||
|
||||
// Find the applicable tier based on input tokens
|
||||
const applicableTier = pricing.find(tier => tier.upTo === null || inputTokens <= tier.upTo);
|
||||
|
||||
// This should not happen if the pricing is well-formed
|
||||
if (!applicableTier) {
|
||||
console.log('[DEV] getPriceForTokens: No applicable tier found for input tokens', { inputTokens, pricing });
|
||||
console.log('[DEV] getLlmCostForTokens: No applicable tier found for input tokens', { inputTokens, pricing });
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Apply the price of the found tier to all tokens
|
||||
// Cost = tier pricing * tokens / 1e6 (or free)
|
||||
if (applicableTier.price === 'free') return 0;
|
||||
// Note: apply the pricing of the found tier to all tokens
|
||||
return tokens * applicableTier.price / 1e6;
|
||||
}
|
||||
|
||||
|
||||
// Compatibiltiy layer for pricing V2 -> V3
|
||||
// Compatibility layer for pricing V2 -> V3
|
||||
|
||||
interface Was_DModelPricingV2 {
|
||||
chatIn?: number
|
||||
@@ -92,7 +94,7 @@ export function portModelPricingV2toV3(llm: DLLM): void {
|
||||
V3.input = pretendIsV2.chatIn;
|
||||
if (pretendIsV2.chatOut)
|
||||
V3.output = pretendIsV2.chatOut;
|
||||
V3._isFree = isModelPriceFree(V3);
|
||||
V3._isFree = isModelPricingFree(V3);
|
||||
llm.pricing = { chat: V3 };
|
||||
delete pretendIsV2.chatIn;
|
||||
delete pretendIsV2.chatOut;
|
||||
|
||||
@@ -10,7 +10,7 @@ import type { ModelVendorId } from '~/modules/llms/vendors/vendors.registry';
|
||||
|
||||
import type { DLLM, DLLMId } from './llms.types';
|
||||
import type { DModelsService, DModelsServiceId } from './modelsservice.types';
|
||||
import { portModelPricingV2toV3 } from './llms.pricing';
|
||||
import { getLlmCostForTokens, portModelPricingV2toV3 } from './llms.pricing';
|
||||
|
||||
|
||||
/// ModelsStore - a store for configured LLMs and configured services
|
||||
@@ -331,57 +331,52 @@ export function getLLMsDebugInfo() {
|
||||
|
||||
function _heuristicUpdateSelectedLLMs(allLlms: DLLM[], chatLlmId: DLLMId | null, fastLlmId: DLLMId | null, funcLlmId: DLLMId | null) {
|
||||
|
||||
// the output of _groupLlmsByVendorRankedByElo
|
||||
let grouped: ReturnType<typeof _groupLlmsByVendorRankedByElo> | null = null;
|
||||
let grouped: GroupedVendorLLMs;
|
||||
|
||||
function cachedGrouped() {
|
||||
if (!grouped) grouped = _groupLlmsByVendorRankedByElo(allLlms);
|
||||
return grouped;
|
||||
}
|
||||
|
||||
// the best llm
|
||||
// default Chat: top vendor by Elo, top model
|
||||
if (!chatLlmId || !allLlms.find(llm => llm.id === chatLlmId)) {
|
||||
const vendors = cachedGrouped();
|
||||
chatLlmId = vendors.length ? vendors[0].llmsByElo[0].id : null;
|
||||
}
|
||||
|
||||
// a fast llm (bottom elo of the top vendor ~~ not really a proxy, but not sure which heuristic to use here)
|
||||
if (!fastLlmId && !allLlms.find(llm => llm.id === fastLlmId)) {
|
||||
// default Fast: vendors by Elo, lowest cost (if available)
|
||||
if (!fastLlmId || !allLlms.find(llm => llm.id === fastLlmId)) {
|
||||
const vendors = cachedGrouped();
|
||||
fastLlmId = vendors.length
|
||||
? vendors[0].llmsByElo.findLast(llm => llm.cbaElo)?.id // last with ELO
|
||||
?? vendors[0].llmsByElo[vendors[0].llmsByElo.length - 1].id ?? null // last
|
||||
: null;
|
||||
fastLlmId = _selectFastLlmID(vendors);
|
||||
}
|
||||
|
||||
// a func llm (same as chat for now, hoping the highest grade also has function calling)
|
||||
// a func llm (same as chat for now)
|
||||
if (!funcLlmId || !allLlms.find(llm => llm.id === funcLlmId))
|
||||
funcLlmId = chatLlmId;
|
||||
|
||||
return { chatLLMId: chatLlmId, fastLLMId: fastLlmId, funcLLMId: funcLlmId };
|
||||
}
|
||||
|
||||
function _groupLlmsByVendorRankedByElo(llms: DLLM[]): { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[] {
|
||||
|
||||
type BenchVendorLLMs = { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined, costRank: number | undefined }[] };
|
||||
type GroupedVendorLLMs = BenchVendorLLMs[];
|
||||
|
||||
function _groupLlmsByVendorRankedByElo(llms: DLLM[]): GroupedVendorLLMs {
|
||||
// group all LLMs by vendor
|
||||
const grouped = llms.reduce((acc, llm) => {
|
||||
if (llm.hidden) return acc;
|
||||
const vendor = acc.find(v => v.vendorId === llm.vId);
|
||||
if (!vendor) {
|
||||
acc.push({
|
||||
vendorId: llm.vId,
|
||||
llmsByElo: [{
|
||||
id: llm.id,
|
||||
cbaElo: llm.benchmark?.cbaElo,
|
||||
}],
|
||||
});
|
||||
} else {
|
||||
vendor.llmsByElo.push({
|
||||
id: llm.id,
|
||||
cbaElo: llm.benchmark?.cbaElo,
|
||||
});
|
||||
}
|
||||
const group = acc.find(v => v.vendorId === llm.vId);
|
||||
const eloCostItem = {
|
||||
id: llm.id,
|
||||
cbaElo: llm.benchmark?.cbaElo,
|
||||
costRank: !llm.pricing ? undefined : _getLlmCostBenchmark(llm),
|
||||
};
|
||||
if (!group)
|
||||
acc.push({ vendorId: llm.vId, llmsByElo: [eloCostItem] });
|
||||
else
|
||||
group.llmsByElo.push(eloCostItem);
|
||||
return acc;
|
||||
}, [] as { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[]);
|
||||
}, [] as GroupedVendorLLMs);
|
||||
|
||||
// sort each vendor's LLMs by elo, decreasing
|
||||
for (const vendor of grouped)
|
||||
@@ -391,3 +386,28 @@ function _groupLlmsByVendorRankedByElo(llms: DLLM[]): { vendorId: ModelVendorId,
|
||||
grouped.sort((a, b) => (b.llmsByElo[0].cbaElo ?? -1) - (a.llmsByElo[0].cbaElo ?? -1));
|
||||
return grouped;
|
||||
}
|
||||
|
||||
/**
 * Hypothetical cost benchmark for a model, used to rank models by cost.
 *
 * The figure is the summed cost of a fixed, made-up workload: 100k input tokens
 * plus 10k output tokens. The 100k input size is also what is passed to
 * getLlmCostForTokens as the input-token count for both lookups.
 *
 * @param llm the model to benchmark
 * @returns the combined input + output cost, or undefined when the model has no
 *          chat pricing or when either side of the cost cannot be computed
 */
function _getLlmCostBenchmark(llm: DLLM): number | undefined {
  const chatPricing = llm.pricing?.chat;
  if (!chatPricing) return undefined;

  // fixed workload of the benchmark
  const benchInputTokens = 100000;
  const benchOutputTokens = 10000;

  const inputCost = getLlmCostForTokens(benchInputTokens, benchInputTokens, chatPricing.input);
  const outputCost = getLlmCostForTokens(benchInputTokens, benchOutputTokens, chatPricing.output);

  // a partial price is not comparable: both sides must be known
  if (inputCost === undefined || outputCost === undefined)
    return undefined;

  return inputCost + outputCost;
}
|
||||
|
||||
/**
 * Selects the default 'fast' LLM.
 *
 * Walks the vendor groups in the given order and uses the first group that has at
 * least one model. Within that group:
 *  - the model with the lowest known `costRank` wins; a cost of 0 (free) is a valid,
 *    lowest-possible cost, and ties keep the earlier model in `llmsByElo`;
 *  - models with an undefined `costRank` are ignored while looking for the cheapest one,
 *    so an unpriced first model no longer blocks cheaper priced models from being chosen;
 *  - if no model in the group has a known cost, the first model of the group is returned.
 *
 * @param vendors vendor groups, each with its models in `llmsByElo`
 * @returns the id of the selected model, or null when there are no models at all
 */
function _selectFastLlmID(vendors: GroupedVendorLLMs): DLLMId | null {
  for (const vendor of vendors) {

    // defensive: a group without models cannot provide a selection, try the next one
    const fallbackLlm = vendor.llmsByElo[0];
    if (!fallbackLlm) continue;

    // find the lowest-cost model among those with a known cost
    let cheapestLlm: BenchVendorLLMs['llmsByElo'][number] | null = null;
    let cheapestCost = Number.POSITIVE_INFINITY;
    for (const llm of vendor.llmsByElo) {
      const cost = llm.costRank;
      // explicit undefined check (not truthiness): 0 means free, not 'unknown'
      if (cost === undefined || Number.isNaN(cost)) continue;
      // strict '<' keeps the earlier model on equal cost
      if (cheapestLlm === null || cost < cheapestCost) {
        cheapestLlm = llm;
        cheapestCost = cost;
      }
    }

    // no priced models for this vendor: fall back to its first model
    return (cheapestLlm ?? fallbackLlm).id;
  }
  return null;
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { DChatGeneratePricing, getLlmPriceForTokens, isModelPriceFree } from '~/common/stores/llms/llms.pricing';
|
||||
import { DChatGeneratePricing, getLlmCostForTokens, isModelPricingFree } from '~/common/stores/llms/llms.pricing';
|
||||
|
||||
/**
|
||||
* This is a stored type - IMPORTANT: do not break.
|
||||
@@ -105,7 +105,7 @@ export function computeChatGenerationCosts(metrics?: Readonly<DChatGenerateMetri
|
||||
return { $code: 'no-pricing' };
|
||||
|
||||
// pricing: bail if free
|
||||
if (isModelPriceFree(pricing))
|
||||
if (isModelPricingFree(pricing))
|
||||
return { $code: 'free' };
|
||||
|
||||
|
||||
@@ -113,8 +113,8 @@ export function computeChatGenerationCosts(metrics?: Readonly<DChatGenerateMetri
|
||||
const isPartialMessage = metrics.TsR === 'pending' || metrics.TsR === 'aborted';
|
||||
|
||||
// Calculate costs
|
||||
const $in = getLlmPriceForTokens(inputTokens, inputTokens, pricing.input);
|
||||
const $out = getLlmPriceForTokens(inputTokens, outputTokens, pricing.output);
|
||||
const $in = getLlmCostForTokens(inputTokens, inputTokens, pricing.input);
|
||||
const $out = getLlmCostForTokens(inputTokens, outputTokens, pricing.output);
|
||||
if ($in === undefined || $out === undefined)
|
||||
return { $code: 'partial-price' };
|
||||
|
||||
@@ -128,13 +128,13 @@ export function computeChatGenerationCosts(metrics?: Readonly<DChatGenerateMetri
|
||||
throw new Error('Tiered pricing with cache is not supported');
|
||||
|
||||
const inputNoCache = inputTokens + cacheReadTokens + cacheWriteTokens;
|
||||
const $cacheRead = getLlmPriceForTokens(inputNoCache, cacheReadTokens, pricing.cache?.read);
|
||||
const $cacheWrite = getLlmPriceForTokens(inputNoCache, cacheWriteTokens, pricing.cache?.write);
|
||||
const $cacheRead = getLlmCostForTokens(inputNoCache, cacheReadTokens, pricing.cache?.read);
|
||||
const $cacheWrite = getLlmCostForTokens(inputNoCache, cacheWriteTokens, pricing.cache?.write);
|
||||
if ($cacheRead === undefined || $cacheWrite === undefined)
|
||||
return { $code: 'partial-price' };
|
||||
|
||||
// compute the advantage from caching
|
||||
const $inNoCache = getLlmPriceForTokens(inputNoCache, inputNoCache, pricing.input)!;
|
||||
const $inNoCache = getLlmCostForTokens(inputNoCache, inputNoCache, pricing.input)!;
|
||||
return {
|
||||
$c: Math.round(($in + $out + $cacheRead + $cacheWrite) * USD_TO_CENTS * 10000) / 10000,
|
||||
$cdCache: Math.round(($inNoCache - $in - $cacheRead - $cacheWrite) * USD_TO_CENTS * 10000) / 10000,
|
||||
|
||||
@@ -7,7 +7,7 @@ import type { OpenAIWire_Tools } from '~/modules/aix/server/dispatch/wiretypes/o
|
||||
import type { DModelsService, DModelsServiceId } from '~/common/stores/llms/modelsservice.types';
|
||||
import { DLLM, DLLMId, LLM_IF_OAI_Chat } from '~/common/stores/llms/llms.types';
|
||||
import { llmsStoreActions } from '~/common/stores/llms/store-llms';
|
||||
import { isModelPriceFree } from '~/common/stores/llms/llms.pricing';
|
||||
import { isModelPricingFree } from '~/common/stores/llms/llms.pricing';
|
||||
|
||||
import type { ModelDescriptionSchema } from './server/llm.server.types';
|
||||
import { DOpenAILLMOptions, FALLBACK_LLM_TEMPERATURE } from './vendors/openai/openai.vendor';
|
||||
@@ -116,7 +116,7 @@ function _createDLLMFromModelDescription(d: ModelDescriptionSchema, service: DMo
|
||||
chat: {
|
||||
...d.chatPrice,
|
||||
// compute the free status
|
||||
_isFree: isModelPriceFree(d.chatPrice),
|
||||
_isFree: isModelPricingFree(d.chatPrice),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -48,27 +48,28 @@ const PriceUpTo_schema = z.object({
|
||||
price: PricePerMToken_schema,
|
||||
});
|
||||
|
||||
const TieredPrice_schema = z.union([
|
||||
const TieredPricing_schema = z.union([
|
||||
PricePerMToken_schema,
|
||||
z.array(PriceUpTo_schema),
|
||||
]);
|
||||
|
||||
// NOTE: (!) keep this in sync with DChatGeneratePricing (llms.pricing.ts)
|
||||
const ChatGeneratePricing_schema = z.object({
|
||||
input: TieredPrice_schema.optional(),
|
||||
output: TieredPrice_schema.optional(),
|
||||
input: TieredPricing_schema.optional(),
|
||||
output: TieredPricing_schema.optional(),
|
||||
// Future: Perplexity has a cost per request, consider this for future additions
|
||||
// perRequest: z.number().optional(), // New field for fixed per-request pricing
|
||||
cache: z.discriminatedUnion('cType', [
|
||||
z.object({
|
||||
cType: z.literal('ant-bp'), // [Anthropic] Breakpoint-based caching
|
||||
read: TieredPrice_schema,
|
||||
write: TieredPrice_schema,
|
||||
read: TieredPricing_schema,
|
||||
write: TieredPricing_schema,
|
||||
duration: z.number(),
|
||||
}),
|
||||
z.object({
|
||||
cType: z.literal('oai-apc'), // [OpenAI] Automatic Prompt Caching
|
||||
read: TieredPrice_schema,
|
||||
// write: TieredPrice_schema, // Not needed, as it's automatic
|
||||
read: TieredPricing_schema,
|
||||
// write: TieredPricing_schema, // Not needed, as it's automatic
|
||||
}),
|
||||
]).optional(),
|
||||
// Not for the server-side, computed on the client only
|
||||
|
||||
@@ -327,7 +327,7 @@ export const _knownOpenAIChatModels: ManualMappings = [
|
||||
trainingDataCutoff: 'Sep 2021',
|
||||
interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn],
|
||||
chatPrice: { input: 0.5, output: 1.5 },
|
||||
// benchmark: { cbaElo: 1106 }, // disabled so that it won't be picked up as 'fast' model
|
||||
benchmark: { cbaElo: 1106 },
|
||||
},
|
||||
{
|
||||
idPrefix: 'gpt-3.5-turbo',
|
||||
@@ -341,7 +341,7 @@ export const _knownOpenAIChatModels: ManualMappings = [
|
||||
trainingDataCutoff: 'Sep 2021',
|
||||
interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn],
|
||||
chatPrice: { input: 0.5, output: 1.5 },
|
||||
// benchmark: { cbaElo: 1106 }, // disabled so that it won't be picked up as 'fast' model
|
||||
benchmark: { cbaElo: 1106 },
|
||||
},
|
||||
{
|
||||
idPrefix: 'gpt-3.5-turbo-1106',
|
||||
@@ -352,7 +352,7 @@ export const _knownOpenAIChatModels: ManualMappings = [
|
||||
trainingDataCutoff: 'Sep 2021',
|
||||
interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn],
|
||||
chatPrice: { input: 1, output: 2 },
|
||||
// benchmark: { cbaElo: 1072 }, // disabled so that it won't be picked up as 'fast' model
|
||||
benchmark: { cbaElo: 1072 },
|
||||
hidden: true,
|
||||
},
|
||||
|
||||
|
||||
Reference in New Issue
Block a user