diff --git a/src/modules/llms/server/openai/models/models.data.ts b/src/modules/llms/server/openai/models/models.data.ts index 4006e2eaa..3d1760c71 100644 --- a/src/modules/llms/server/openai/models/models.data.ts +++ b/src/modules/llms/server/openai/models/models.data.ts @@ -1,4 +1,5 @@ -import { LLM_IF_OAI_Chat } from '~/common/stores/llms/llms.types'; +import { LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, LLM_IF_OAI_Reasoning, LLM_IF_OAI_Vision } from '~/common/stores/llms/llms.types'; +import { capitalizeFirstLetter } from '~/common/util/textUtils'; import type { ModelDescriptionSchema } from '../../llm.server.types'; @@ -27,34 +28,61 @@ export function lmStudioModelToModelDescription(modelId: string): ModelDescripti // [LocalAI] -const _knownLocalAIChatModels: ManualMappings = [ - { - idPrefix: 'ggml-gpt4all-j', - label: 'GPT4All-J', - description: 'GPT4All-J on LocalAI', - contextWindow: 2048, - interfaces: [LLM_IF_OAI_Chat], - }, - { - idPrefix: 'luna-ai-llama2', - label: 'Luna AI Llama2 Uncensored', - description: 'Luna AI Llama2 on LocalAI', - contextWindow: 4096, - interfaces: [LLM_IF_OAI_Chat], - }, +const _knownLocalAIChatModels: ManualMappings = []; +const _knownLocalAIPrice = { input: 'free', output: 'free' } as const; +const _hideLocalAIModels = [ + 'jina-reranker-v1-base-en', // vector search + 'stablediffusion', // text-to-image + 'text-embedding-ada-002', // embedding generator + 'tts-1', // text-to-speech + 'whisper-1', // speech-to-text ]; +export function localAIModelSortFn(a: ModelDescriptionSchema, b: ModelDescriptionSchema): number { + // hidden to the bottom + if (a.hidden && !b.hidden) return 1; + if (!a.hidden && b.hidden) return -1; + + // keep the order from the API + return 0; +} + + export function localAIModelToModelDescription(modelId: string): ModelDescriptionSchema { + + // heuristics to extract a label from the model ID + const label = modelId + .replace('.gguf', '') + .replace('ggml-', '') + .replace('.bin', '') + .replaceAll('-', ' ') + 
.replace(' Q4_K_M', ' (Q4_K_M)') + .replace(' F16', ' (F16)') + .split(' ') + .map(capitalizeFirstLetter) + .join(' '); + + const description = `LocalAI model. File: ${modelId}`; + + // very dull heuristics + const interfaces = [LLM_IF_OAI_Chat, LLM_IF_OAI_Fn]; + if (modelId.includes('vision') || modelId.includes('llava')) + interfaces.push(LLM_IF_OAI_Vision); + if (modelId.includes('r1')) + interfaces.push(LLM_IF_OAI_Reasoning); + return fromManualMapping(_knownLocalAIChatModels, modelId, undefined, undefined, { idPrefix: modelId, - label: modelId - .replace('ggml-', '') - .replace('.bin', '') - .replaceAll('-', ' '), - description: 'Unknown localAI model. Please update `models.data.ts` with this ID', + label, + description, contextWindow: null, // 'not provided' - interfaces: [LLM_IF_OAI_Chat], // assume.. - chatPrice: { input: 'free', output: 'free' }, + interfaces, + // parameterSpecs + // maxCompletionTokens + // trainingDataCutoff + // benchmark + chatPrice: _knownLocalAIPrice, + hidden: _hideLocalAIModels.includes(modelId), }); } diff --git a/src/modules/llms/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts index 925498b58..977312362 100644 --- a/src/modules/llms/server/openai/openai.router.ts +++ b/src/modules/llms/server/openai/openai.router.ts @@ -16,7 +16,7 @@ import { ListModelsResponse_schema, ModelDescriptionSchema } from '../llm.server import { azureModelToModelDescription, openAIModelFilter, openAIModelToModelDescription, openAISortModels } from './models/openai.models'; import { deepseekModelFilter, deepseekModelSort, deepseekModelToModelDescription } from './models/deepseek.models'; import { groqModelFilter, groqModelSortFn, groqModelToModelDescription } from './models/groq.models'; -import { lmStudioModelToModelDescription, localAIModelToModelDescription } from './models/models.data'; +import { lmStudioModelToModelDescription, localAIModelToModelDescription, localAIModelSortFn } from './models/models.data'; import 
{ mistralModelsSort, mistralModelToModelDescription } from './models/mistral.models'; import { openPipeModelDescriptions, openPipeModelSort, openPipeModelToModelDescriptions } from './models/openpipe.models'; import { openRouterModelFamilySortFn, openRouterModelToModelDescription } from './models/openrouter.models'; @@ -177,7 +177,8 @@ export const llmOpenAIRouter = createTRPCRouter({ // [LocalAI]: map id to label case 'localai': models = openAIModels - .map(model => localAIModelToModelDescription(model.id)); + .map(({ id }) => localAIModelToModelDescription(id)) + .sort(localAIModelSortFn); break; case 'mistral':