Gemini: improve support (incl. interfaces, cost, visibility)

This commit is contained in:
Enrico Ros
2024-05-14 15:15:53 -07:00
parent 2db74867f5
commit 9eb0cc0b62
3 changed files with 129 additions and 18 deletions
+121 -16
View File
@@ -1,16 +1,117 @@
import type { GeminiModelSchema } from './gemini.wiretypes';
import type { ModelDescriptionSchema } from '../llm.server.types';
import { LLM_IF_OAI_Chat, LLM_IF_OAI_Vision } from '../../store-llms';
import { LLM_IF_OAI_Chat, LLM_IF_OAI_Json, LLM_IF_OAI_Vision } from '../../store-llms';
// supported interfaces
const geminiChatInterfaces: GeminiModelSchema['supportedGenerationMethods'] = ['generateContent'];
// unsupported interfaces
// NOTE(review): the two filter lists below are presumably consumed by geminiFilterModels,
// whose body is not visible in this view - confirm against the full file.
// name fragments that exclude a model
const filterUnallowedNames = ['Legacy'];
// generation methods (answer generation, embeddings) that are not allowed
const filterUnallowedInterfaces: GeminiModelSchema['supportedGenerationMethods'] = ['generateAnswer', 'embedContent', 'embedText'];
// model ids treated as symlinks when checked with geminiLinkModels.includes(modelId)
const geminiLinkModels = ['models/gemini-pro', 'models/gemini-pro-vision'];
// interfaces mapping
// a model advertising any of these generation methods is considered to have a chat interface
const geminiChatInterfaces: GeminiModelSchema['supportedGenerationMethods'] = ['generateContent'];
// substrings of a model id that were used to detect vision support
const geminiVisionNames = ['-vision'];
/**
 * Hand-maintained details for known Gemini models.
 *
 * Entries are looked up by exact match of `id` against the model name reported by the API.
 *
 * Gemini name mapping, by example:
 *  - Latest version          gemini-1.0-pro-latest    <model>-<generation>-<variation>-latest
 *  - Latest stable version   gemini-1.0-pro           <model>-<generation>-<variation>
 *  - Stable versions         gemini-1.0-pro-001       <model>-<generation>-<variation>-<version>
 *
 * Fields:
 *  - id: full model name, including the 'models/' prefix
 *  - symLink: when set, the model is labelled as a link to this id and is hidden
 *  - hidden: when true, the model is hidden by default
 *  - interfaces / pricing / trainingDataCutoff: copied into the resulting model description
 *  - isLatest / isPreview: NOTE(review): not read by any code visible here - confirm usage
 */
const _knownGeminiModels: Array<
  Pick<ModelDescriptionSchema, 'interfaces' | 'pricing' | 'trainingDataCutoff' | 'hidden'> & {
    id: string;
    isLatest?: boolean;
    isPreview?: boolean;
    symLink?: string;
  }
> = [

  // ---- Generation 1.5 ----

  {
    id: 'models/gemini-1.5-flash-latest',
    isLatest: true,
    isPreview: true,
    pricing: {
      // the higher (> 128k tokens prompt) tier is the one recorded
      chatIn: 0.70, // 0.35 up to 128k tokens, 0.70 for prompts > 128k tokens
      chatOut: 1.05, // 0.53 up to 128k tokens, 1.05 for prompts > 128k tokens
    },
    interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Vision, LLM_IF_OAI_Json], // accepts audio, images and text
  },
  {
    id: 'models/gemini-1.5-pro-latest',
    // NOTE: as of 2024-05-14 there is no 'models/gemini-1.5-pro' (latest stable)
    isLatest: true,
    pricing: {
      // the higher (> 128K tokens prompt) tier is the one recorded
      chatIn: 7.00, // $3.50 / 1M tokens for prompts up to 128K tokens, $7.00 / 1M tokens for longer prompts
      chatOut: 21.00, // $10.50 / 1M tokens at 128K or less, $21.00 / 1M tokens at 128K+
    },
    trainingDataCutoff: 'Apr 2024',
    interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Vision, LLM_IF_OAI_Json], // accepts audio, images and text
  },

  // ---- Generation 1.0 ----

  {
    id: 'models/gemini-1.0-pro-latest',
    isLatest: true,
    pricing: {
      chatIn: 0.50,
      chatOut: 1.50,
    },
    interfaces: [LLM_IF_OAI_Chat],
  },
  {
    id: 'models/gemini-1.0-pro',
    pricing: {
      chatIn: 0.50,
      chatOut: 1.50,
    },
    interfaces: [LLM_IF_OAI_Chat],
    hidden: true,
  },
  {
    id: 'models/gemini-1.0-pro-001',
    pricing: {
      chatIn: 0.50,
      chatOut: 1.50,
    },
    interfaces: [LLM_IF_OAI_Chat],
    hidden: true,
  },

  // ---- Generation 1.0 + Vision ----

  {
    id: 'models/gemini-1.0-pro-vision-latest',
    pricing: {
      chatIn: 0.50,
      chatOut: 1.50,
    },
    interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Vision], // text and images
    hidden: true,
  },

  // ---- Older symlinks (pricing and interfaces are duplicated by hand from the link target) ----

  {
    id: 'models/gemini-pro',
    symLink: 'models/gemini-1.0-pro',
    pricing: {
      chatIn: 0.50,
      chatOut: 1.50,
    },
    interfaces: [LLM_IF_OAI_Chat],
    hidden: true,
  },
  {
    id: 'models/gemini-pro-vision',
    symLink: 'models/gemini-1.0-pro-vision',
    pricing: {
      chatIn: 0.50,
      chatOut: 1.50,
    },
    interfaces: [LLM_IF_OAI_Chat, LLM_IF_OAI_Vision], // text and images
    hidden: true,
  },

];
export function geminiFilterModels(geminiModel: GeminiModelSchema): boolean {
@@ -26,17 +127,20 @@ export function geminiSortModels(a: ModelDescriptionSchema, b: ModelDescriptionS
return b.label.localeCompare(a.label);
}
export function geminiModelToModelDescription(geminiModel: GeminiModelSchema, allModels: GeminiModelSchema[]): ModelDescriptionSchema {
export function geminiModelToModelDescription(geminiModel: GeminiModelSchema): ModelDescriptionSchema {
const { description, displayName, name: modelId, supportedGenerationMethods } = geminiModel;
// find known manual mapping
const knownModel = _knownGeminiModels.find(m => m.id === modelId);
// handle symlinks
const isSymlink = geminiLinkModels.includes(modelId);
const symlinked = isSymlink ? allModels.find(m => m.displayName === displayName && m.name !== modelId) : null;
const label = isSymlink ? `🔗 ${displayName.replace('1.0', '')}${symlinked ? symlinked.name : '?'}` : displayName;
const label = knownModel?.symLink
? `🔗 ${displayName.replace('1.0', '')}${knownModel.symLink}`
: displayName;
// handle hidden models
const hasChatInterfaces = supportedGenerationMethods.some(iface => geminiChatInterfaces.includes(iface));
const hidden = isSymlink || !hasChatInterfaces;
const hidden = knownModel?.hidden || !!knownModel?.symLink || !hasChatInterfaces;
// context window
const { inputTokenLimit, outputTokenLimit } = geminiModel;
@@ -46,11 +150,12 @@ export function geminiModelToModelDescription(geminiModel: GeminiModelSchema, al
const { version, topK, topP, temperature } = geminiModel;
const descriptionLong = description + ` (Version: ${version}, Defaults: temperature=${temperature}, topP=${topP}, topK=${topK}, interfaces=[${supportedGenerationMethods.join(',')}])`;
const interfaces: ModelDescriptionSchema['interfaces'] = [];
if (hasChatInterfaces) {
// use known interfaces, or add chat if this is a generateContent model
const interfaces: ModelDescriptionSchema['interfaces'] = knownModel?.interfaces || [];
if (!interfaces.length && hasChatInterfaces) {
interfaces.push(LLM_IF_OAI_Chat);
if (geminiVisionNames.some(name => modelId.includes(name)))
interfaces.push(LLM_IF_OAI_Vision);
// if (geminiVisionNames.some(name => modelId.includes(name)))
// interfaces.push(LLM_IF_OAI_Vision);
}
return {
@@ -61,11 +166,11 @@ export function geminiModelToModelDescription(geminiModel: GeminiModelSchema, al
description: descriptionLong,
contextWindow: contextWindow,
maxCompletionTokens: outputTokenLimit,
// trainingDataCutoff: '...',
trainingDataCutoff: knownModel?.trainingDataCutoff,
interfaces,
// rateLimits: isGeminiPro ? { reqPerMinute: 60 } : undefined,
// benchmarks: ...
// pricing: isGeminiPro ? { needs per-character and per-image pricing } : undefined,
pricing: knownModel?.pricing, // TODO: needs <>128k, and per-character and per-image pricing
hidden,
};
}
@@ -147,7 +147,7 @@ export const llmGeminiRouter = createTRPCRouter({
// map to our output schema
const models = detailedModels
.filter(geminiFilterModels)
.map(geminiModel => geminiModelToModelDescription(geminiModel, detailedModels))
.map(geminiModel => geminiModelToModelDescription(geminiModel))
.sort(geminiSortModels);
return {
@@ -875,7 +875,13 @@ export function groqModelSortFn(a: ModelDescriptionSchema, b: ModelDescriptionSc
// Helpers
type ManualMapping = ({ idPrefix: string, isLatest?: boolean, isPreview?: boolean, isLegacy?: boolean, symLink?: string } & Omit<ModelDescriptionSchema, 'id' | 'created' | 'updated'>);
/**
 * A hand-written model description entry.
 *
 * Carries every ModelDescriptionSchema field except 'id', 'created' and 'updated',
 * plus the mapping-specific fields declared below.
 * NOTE(review): `idPrefix` is presumably matched against the start of a model id
 * by fromManualMapping - confirm against its body, which is not visible here.
 */
type ManualMapping = Omit<ModelDescriptionSchema, 'id' | 'created' | 'updated'> & {
  idPrefix: string;
  isLatest?: boolean;
  isPreview?: boolean;
  isLegacy?: boolean;
  symLink?: string;
};
type ManualMappings = ManualMapping[];
function fromManualMapping(mappings: ManualMappings, id: string, created?: number, updated?: number, fallback?: ManualMapping): ModelDescriptionSchema {