Improve Fast LLM auto-select (by price)

This commit is contained in:
Enrico Ros
2024-10-01 14:19:15 -07:00
parent dcce5a5b1d
commit d09d4455aa
7 changed files with 99 additions and 76 deletions
@@ -3,7 +3,7 @@ import * as React from 'react';
import type { SxProps } from '@mui/joy/styles/types';
import { Box, ColorPaletteProp, Tooltip } from '@mui/joy';
import { DChatGeneratePricing, getLlmPriceForTokens } from '~/common/stores/llms/llms.pricing';
import { DChatGeneratePricing, getLlmCostForTokens } from '~/common/stores/llms/llms.pricing';
import { adjustContentScaling, themeScalingMap } from '~/common/app.theme';
import { formatModelsCost } from '~/common/util/costUtils';
import { useUIContentScaling } from '~/common/state/store-ui';
@@ -43,8 +43,8 @@ export function tokenCountsMathAndMessage(tokenLimit: number | 0, directTokens:
// add the price, if available
if (chatPricing) {
const inputPrice = getLlmPriceForTokens(usedInputTokens, usedInputTokens, chatPricing.input);
const outputPrice = getLlmPriceForTokens(usedInputTokens, responseMaxTokens || 0, chatPricing.output);
const inputPrice = getLlmCostForTokens(usedInputTokens, usedInputTokens, chatPricing.input);
const outputPrice = getLlmCostForTokens(usedInputTokens, responseMaxTokens || 0, chatPricing.output);
costMin = inputPrice;
const costOutMax = outputPrice;
+28 -26
View File
@@ -9,21 +9,26 @@ export type DModelPricing = {
chat?: DChatGeneratePricing,
}
// NOTE: (!) keep this in sync with ChatGeneratePricing_schema (modules/llms/server/llm.server.types.ts)
export type DChatGeneratePricing = {
// unit: 'USD_Mtok',
input?: DTieredPrice;
output?: DTieredPrice;
input?: DTieredPricing;
output?: DTieredPricing;
cache?: {
cType: 'ant-bp';
read: DTieredPrice;
write: DTieredPrice;
read: DTieredPricing;
write: DTieredPricing;
duration: number; // seconds
} | {
cType: 'oai-apc';
read: DTieredPricing;
// write: DTieredPricing; // Not needed, as it's automatic
};
// NOT in AixWire_API_ListModels.PriceChatGenerate_schema
// NOT in AixWire_API_ListModels.ChatGeneratePricing_schema
_isFree?: boolean; // precomputed, so we avoid recalculating it
}
type DTieredPrice = DPricePerMToken | DPriceUpTo[];
type DTieredPricing = DPricePerMToken | DPriceUpTo[];
type DPriceUpTo = {
upTo: number | null,
@@ -35,46 +40,43 @@ type DPricePerMToken = number | 'free';
/// detect Free Pricing
export function isModelPriceFree(priceChatGenerate: DChatGeneratePricing): boolean {
if (!priceChatGenerate) return true;
return _isPriceFree(priceChatGenerate.input) && _isPriceFree(priceChatGenerate.output);
export function isModelPricingFree(pricingChatGenerate: DChatGeneratePricing): boolean {
  // no pricing object at all → consider the model free
  if (!pricingChatGenerate) return true;
  // free only when both directions are free
  const { input, output } = pricingChatGenerate;
  return _isPricingFree(input) && _isPricingFree(output);
}
function _isPriceFree(price: DTieredPrice | undefined): boolean {
if (price === 'free') return true;
if (price === undefined) return false;
if (typeof price === 'number') return price === 0;
return price.every(tier => _isPricePerMTokenFree(tier.price));
}
function _isPricePerMTokenFree(price: DPricePerMToken): boolean {
return price === 'free' || price === 0;
function _isPricingFree(pricing: DTieredPricing | undefined): boolean {
  // unknown pricing is NOT assumed free
  if (pricing === undefined) return false;
  // the 'free' literal and a flat price of 0 are free
  if (pricing === 'free') return true;
  if (typeof pricing === 'number') return pricing === 0;
  // tiered pricing: free only if every single tier is free
  for (const tier of pricing) {
    if (tier.price !== 'free' && tier.price !== 0) return false;
  }
  return true;
}
/// Human readable price formatting
/// Human readable cost
export function getLlmPriceForTokens(inputTokens: number, tokens: number, pricing: DTieredPrice | undefined): number | undefined {
export function getLlmCostForTokens(inputTokens: number, tokens: number, pricing: DTieredPricing | undefined): number | undefined {
if (!pricing) return undefined;
if (pricing === 'free') return 0;
// Cost = tokens * price / 1e6
if (typeof pricing === 'number') return tokens * pricing / 1e6;
// Find the applicable tier based on input tokens
const applicableTier = pricing.find(tier => tier.upTo === null || inputTokens <= tier.upTo);
// This should not happen if the pricing is well-formed
if (!applicableTier) {
console.log('[DEV] getPriceForTokens: No applicable tier found for input tokens', { inputTokens, pricing });
console.log('[DEV] getLlmCostForTokens: No applicable tier found for input tokens', { inputTokens, pricing });
return undefined;
}
// Apply the price of the found tier to all tokens
// Cost = tier pricing * tokens / 1e6 (or free)
if (applicableTier.price === 'free') return 0;
// Note: apply the pricing of the found tier to all tokens
return tokens * applicableTier.price / 1e6;
}
// Compatibiltiy layer for pricing V2 -> V3
// Compatibility layer for pricing V2 -> V3
interface Was_DModelPricingV2 {
chatIn?: number
@@ -92,7 +94,7 @@ export function portModelPricingV2toV3(llm: DLLM): void {
V3.input = pretendIsV2.chatIn;
if (pretendIsV2.chatOut)
V3.output = pretendIsV2.chatOut;
V3._isFree = isModelPriceFree(V3);
V3._isFree = isModelPricingFree(V3);
llm.pricing = { chat: V3 };
delete pretendIsV2.chatIn;
delete pretendIsV2.chatOut;
+48 -28
View File
@@ -10,7 +10,7 @@ import type { ModelVendorId } from '~/modules/llms/vendors/vendors.registry';
import type { DLLM, DLLMId } from './llms.types';
import type { DModelsService, DModelsServiceId } from './modelsservice.types';
import { portModelPricingV2toV3 } from './llms.pricing';
import { getLlmCostForTokens, portModelPricingV2toV3 } from './llms.pricing';
/// ModelsStore - a store for configured LLMs and configured services
@@ -331,57 +331,52 @@ export function getLLMsDebugInfo() {
function _heuristicUpdateSelectedLLMs(allLlms: DLLM[], chatLlmId: DLLMId | null, fastLlmId: DLLMId | null, funcLlmId: DLLMId | null) {
// the output of _groupLlmsByVendorRankedByElo
let grouped: ReturnType<typeof _groupLlmsByVendorRankedByElo> | null = null;
let grouped: GroupedVendorLLMs;
function cachedGrouped() {
if (!grouped) grouped = _groupLlmsByVendorRankedByElo(allLlms);
return grouped;
}
// the best llm
// default Chat: top vendor by Elo, top model
if (!chatLlmId || !allLlms.find(llm => llm.id === chatLlmId)) {
const vendors = cachedGrouped();
chatLlmId = vendors.length ? vendors[0].llmsByElo[0].id : null;
}
// a fast llm (bottom elo of the top vendor ~~ not really a proxy, but not sure which heuristic to use here)
if (!fastLlmId && !allLlms.find(llm => llm.id === fastLlmId)) {
// default Fast: vendors by Elo, lowest cost (if available)
if (!fastLlmId || !allLlms.find(llm => llm.id === fastLlmId)) {
const vendors = cachedGrouped();
fastLlmId = vendors.length
? vendors[0].llmsByElo.findLast(llm => llm.cbaElo)?.id // last with ELO
?? vendors[0].llmsByElo[vendors[0].llmsByElo.length - 1].id ?? null // last
: null;
fastLlmId = _selectFastLlmID(vendors);
}
// a func llm (same as chat for now, hoping the highest grade also has function calling)
// a func llm (same as chat for now)
if (!funcLlmId || !allLlms.find(llm => llm.id === funcLlmId))
funcLlmId = chatLlmId;
return { chatLLMId: chatLlmId, fastLLMId: fastLlmId, funcLLMId: funcLlmId };
}
function _groupLlmsByVendorRankedByElo(llms: DLLM[]): { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[] {
type BenchVendorLLMs = { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined, costRank: number | undefined }[] };
type GroupedVendorLLMs = BenchVendorLLMs[];
function _groupLlmsByVendorRankedByElo(llms: DLLM[]): GroupedVendorLLMs {
// group all LLMs by vendor
const grouped = llms.reduce((acc, llm) => {
if (llm.hidden) return acc;
const vendor = acc.find(v => v.vendorId === llm.vId);
if (!vendor) {
acc.push({
vendorId: llm.vId,
llmsByElo: [{
id: llm.id,
cbaElo: llm.benchmark?.cbaElo,
}],
});
} else {
vendor.llmsByElo.push({
id: llm.id,
cbaElo: llm.benchmark?.cbaElo,
});
}
const group = acc.find(v => v.vendorId === llm.vId);
const eloCostItem = {
id: llm.id,
cbaElo: llm.benchmark?.cbaElo,
costRank: !llm.pricing ? undefined : _getLlmCostBenchmark(llm),
};
if (!group)
acc.push({ vendorId: llm.vId, llmsByElo: [eloCostItem] });
else
group.llmsByElo.push(eloCostItem);
return acc;
}, [] as { vendorId: ModelVendorId, llmsByElo: { id: DLLMId, cbaElo: number | undefined }[] }[]);
}, [] as GroupedVendorLLMs);
// sort each vendor's LLMs by elo, decreasing
for (const vendor of grouped)
@@ -391,3 +386,28 @@ function _groupLlmsByVendorRankedByElo(llms: DLLM[]): { vendorId: ModelVendorId,
grouped.sort((a, b) => (b.llmsByElo[0].cbaElo ?? -1) - (a.llmsByElo[0].cbaElo ?? -1));
return grouped;
}
// Hypothetical cost benchmark for a model, based on total cost of 100k input tokens and 10k output tokens.
function _getLlmCostBenchmark(llm: DLLM): number | undefined {
  const chatPricing = llm.pricing?.chat;
  if (!chatPricing) return undefined;
  // price a hypothetical generation: 100k tokens in, 10k tokens out
  const inputCost = getLlmCostForTokens(100000, 100000, chatPricing.input);
  const outputCost = getLlmCostForTokens(100000, 10000, chatPricing.output);
  // only comparable when both sides are priced
  if (inputCost === undefined || outputCost === undefined) return undefined;
  return inputCost + outputCost;
}
// Selects the 'fast' llm
// Heuristic: scan vendors in order; within a vendor, pick the model with the lowest
// known cost (costRank), falling back to the first model when no cost data exists.
// NOTE(review): presumes each vendor's llmsByElo is already Elo-sorted, so the
// fallback is the top-Elo model — confirm with _groupLlmsByVendorRankedByElo.
function _selectFastLlmID(vendors: GroupedVendorLLMs) {
  if (!vendors.length) return null;
  for (const vendor of vendors) {
    let cheapest: BenchVendorLLMs['llmsByElo'][number] | null = null;
    for (const llm of vendor.llmsByElo) {
      if (!cheapest) {
        cheapest = llm;
        continue;
      }
      // skip models without cost data; NOTE a costRank of 0 (a free model) is a
      // valid — and ideal — value, so compare with `=== undefined`, not truthiness
      if (llm.costRank === undefined) continue;
      // a model with a known cost beats one without, or one that costs more
      if (cheapest.costRank === undefined || llm.costRank < cheapest.costRank)
        cheapest = llm;
    }
    if (cheapest) return cheapest.id;
  }
  return null;
}
@@ -1,4 +1,4 @@
import { DChatGeneratePricing, getLlmPriceForTokens, isModelPriceFree } from '~/common/stores/llms/llms.pricing';
import { DChatGeneratePricing, getLlmCostForTokens, isModelPricingFree } from '~/common/stores/llms/llms.pricing';
/**
* This is a stored type - IMPORTANT: do not break.
@@ -105,7 +105,7 @@ export function computeChatGenerationCosts(metrics?: Readonly<DChatGenerateMetri
return { $code: 'no-pricing' };
// pricing: bail if free
if (isModelPriceFree(pricing))
if (isModelPricingFree(pricing))
return { $code: 'free' };
@@ -113,8 +113,8 @@ export function computeChatGenerationCosts(metrics?: Readonly<DChatGenerateMetri
const isPartialMessage = metrics.TsR === 'pending' || metrics.TsR === 'aborted';
// Calculate costs
const $in = getLlmPriceForTokens(inputTokens, inputTokens, pricing.input);
const $out = getLlmPriceForTokens(inputTokens, outputTokens, pricing.output);
const $in = getLlmCostForTokens(inputTokens, inputTokens, pricing.input);
const $out = getLlmCostForTokens(inputTokens, outputTokens, pricing.output);
if ($in === undefined || $out === undefined)
return { $code: 'partial-price' };
@@ -128,13 +128,13 @@ export function computeChatGenerationCosts(metrics?: Readonly<DChatGenerateMetri
throw new Error('Tiered pricing with cache is not supported');
const inputNoCache = inputTokens + cacheReadTokens + cacheWriteTokens;
const $cacheRead = getLlmPriceForTokens(inputNoCache, cacheReadTokens, pricing.cache?.read);
const $cacheWrite = getLlmPriceForTokens(inputNoCache, cacheWriteTokens, pricing.cache?.write);
const $cacheRead = getLlmCostForTokens(inputNoCache, cacheReadTokens, pricing.cache?.read);
const $cacheWrite = getLlmCostForTokens(inputNoCache, cacheWriteTokens, pricing.cache?.write);
if ($cacheRead === undefined || $cacheWrite === undefined)
return { $code: 'partial-price' };
// compute the advantage from caching
const $inNoCache = getLlmPriceForTokens(inputNoCache, inputNoCache, pricing.input)!;
const $inNoCache = getLlmCostForTokens(inputNoCache, inputNoCache, pricing.input)!;
return {
$c: Math.round(($in + $out + $cacheRead + $cacheWrite) * USD_TO_CENTS * 10000) / 10000,
$cdCache: Math.round(($inNoCache - $in - $cacheRead - $cacheWrite) * USD_TO_CENTS * 10000) / 10000,
+2 -2
View File
@@ -7,7 +7,7 @@ import type { OpenAIWire_Tools } from '~/modules/aix/server/dispatch/wiretypes/o
import type { DModelsService, DModelsServiceId } from '~/common/stores/llms/modelsservice.types';
import { DLLM, DLLMId, LLM_IF_OAI_Chat } from '~/common/stores/llms/llms.types';
import { llmsStoreActions } from '~/common/stores/llms/store-llms';
import { isModelPriceFree } from '~/common/stores/llms/llms.pricing';
import { isModelPricingFree } from '~/common/stores/llms/llms.pricing';
import type { ModelDescriptionSchema } from './server/llm.server.types';
import { DOpenAILLMOptions, FALLBACK_LLM_TEMPERATURE } from './vendors/openai/openai.vendor';
@@ -116,7 +116,7 @@ function _createDLLMFromModelDescription(d: ModelDescriptionSchema, service: DMo
chat: {
...d.chatPrice,
// compute the free status
_isFree: isModelPriceFree(d.chatPrice),
_isFree: isModelPricingFree(d.chatPrice),
},
};
}
+8 -7
View File
@@ -48,27 +48,28 @@ const PriceUpTo_schema = z.object({
price: PricePerMToken_schema,
});
const TieredPrice_schema = z.union([
const TieredPricing_schema = z.union([
PricePerMToken_schema,
z.array(PriceUpTo_schema),
]);
// NOTE: (!) keep this in sync with DChatGeneratePricing (llms.pricing.ts)
const ChatGeneratePricing_schema = z.object({
input: TieredPrice_schema.optional(),
output: TieredPrice_schema.optional(),
input: TieredPricing_schema.optional(),
output: TieredPricing_schema.optional(),
// Future: Perplexity has a cost per request, consider this for future additions
// perRequest: z.number().optional(), // New field for fixed per-request pricing
cache: z.discriminatedUnion('cType', [
z.object({
cType: z.literal('ant-bp'), // [Anthropic] Breakpoint-based caching
read: TieredPrice_schema,
write: TieredPrice_schema,
read: TieredPricing_schema,
write: TieredPricing_schema,
duration: z.number(),
}),
z.object({
cType: z.literal('oai-apc'), // [OpenAI] Automatic Prompt Caching
read: TieredPrice_schema,
// write: TieredPrice_schema, // Not needed, as it's automatic
read: TieredPricing_schema,
// write: TieredPricing_schema, // Not needed, as it's automatic
}),
]).optional(),
// Not for the server-side, computed on the client only
@@ -327,7 +327,7 @@ export const _knownOpenAIChatModels: ManualMappings = [
trainingDataCutoff: 'Sep 2021',
interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn],
chatPrice: { input: 0.5, output: 1.5 },
// benchmark: { cbaElo: 1106 }, // disabled so that it won't be picked up as 'fast' model
benchmark: { cbaElo: 1106 },
},
{
idPrefix: 'gpt-3.5-turbo',
@@ -341,7 +341,7 @@ export const _knownOpenAIChatModels: ManualMappings = [
trainingDataCutoff: 'Sep 2021',
interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn],
chatPrice: { input: 0.5, output: 1.5 },
// benchmark: { cbaElo: 1106 }, // disabled so that it won't be picked up as 'fast' model
benchmark: { cbaElo: 1106 },
},
{
idPrefix: 'gpt-3.5-turbo-1106',
@@ -352,7 +352,7 @@ export const _knownOpenAIChatModels: ManualMappings = [
trainingDataCutoff: 'Sep 2021',
interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn],
chatPrice: { input: 1, output: 2 },
// benchmark: { cbaElo: 1072 }, // disabled so that it won't be picked up as 'fast' model
benchmark: { cbaElo: 1072 },
hidden: true,
},