Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6f7abf0220 |
@@ -0,0 +1,58 @@
|
|||||||
|
/* Provides a single endpoint for all services. */
|
||||||
|
import { RequestHandler } from "express";
|
||||||
|
import { generateErrorMessage } from "zod-error";
|
||||||
|
import { APIFormat } from "../shared/key-management";
|
||||||
|
import {
|
||||||
|
getServiceForModel,
|
||||||
|
LLMService,
|
||||||
|
MODEL_FAMILIES,
|
||||||
|
MODEL_FAMILY_SERVICE,
|
||||||
|
ModelFamily,
|
||||||
|
} from "../shared/models";
|
||||||
|
import { API_SCHEMA_VALIDATORS } from "../shared/api-schemas";
|
||||||
|
|
||||||
|
/**
 * Determines which API format a request body conforms to.
 *
 * Candidate formats are tried in the order given; the first one whose
 * validator in `API_SCHEMA_VALIDATORS` accepts `body` is returned.
 *
 * @param body - Request body to classify.
 * @param formats - Candidate formats, in priority order.
 * @returns The first candidate format whose schema accepts the body.
 * @throws Error if no candidate accepts the body (including when `formats` is
 *   empty). The message lists the validation failure for each candidate.
 */
const detectApiFormat = (body: any, formats: APIFormat[]): APIFormat => {
  const failures: string[] = [];
  for (const format of formats) {
    const result = API_SCHEMA_VALIDATORS[format].safeParse(body);
    if (result.success) {
      return format;
    }
    // Convert each failure to text explicitly and label it with the format that
    // produced it, instead of relying on implicit array-to-string coercion.
    failures.push(`[${format}] ${result.error.message}`);
  }
  throw new Error(
    `Couldn't determine the format of your request. Errors: ${failures.join("; ")}`
  );
};
|
||||||
|
|
||||||
|
/**
 * Tries to infer LLMService and APIFormat using the model name and the presence
 * of certain fields in the request body.
 *
 * Throws an `Error` when the request body does not carry a non-empty string
 * `model`. Errors thrown by `getServiceForModel` and `detectApiFormat` are not
 * caught here and propagate to the caller.
 *
 * NOTE(review): this handler is still a stub. It does not store the inferred
 * service/format anywhere and never calls `next()` or sends a response, so a
 * request that reaches it without throwing is never completed. Only the
 * "openai" service is handled so far.
 */
const inferService: RequestHandler = (req, res, next) => {
  // `req.body` may be undefined (e.g. no body was parsed); treat that the same
  // as a missing model instead of failing with an opaque TypeError.
  const model: unknown = req.body?.model;
  // The model name is matched with string methods further down, so anything
  // other than a non-empty string is rejected here.
  if (typeof model !== "string" || !model) {
    throw new Error("No model specified");
  }

  // Service determines the key provider and is typically determined by the
  // requested model, though some models are served by multiple services.
  // API format determines the expected request/response format.
  let service: LLMService;
  let inboundApi: APIFormat;
  let outboundApi: APIFormat;

  // The `model` field may hold a model family name instead of a concrete model
  // name; families map directly to a service.
  if ((MODEL_FAMILIES as readonly string[]).includes(model)) {
    service = MODEL_FAMILY_SERVICE[model as ModelFamily];
  } else {
    service = getServiceForModel(model);
  }

  // Each service has typically one API format.
  switch (service) {
    case "openai": {
      // NOTE(review): `detected` is not used yet; presumably it is meant to
      // populate `inboundApi`/`outboundApi` — confirm with the author.
      const detected = detectApiFormat(req.body, ["openai", "openai-text", "openai-image"]);
    }
  }
};
|
||||||
@@ -4,8 +4,13 @@ import os from "os";
|
|||||||
import schedule from "node-schedule";
|
import schedule from "node-schedule";
|
||||||
import { config } from "../../config";
|
import { config } from "../../config";
|
||||||
import { logger } from "../../logger";
|
import { logger } from "../../logger";
|
||||||
import { LLMService, MODEL_FAMILY_SERVICE, ModelFamily } from "../models";
|
import {
|
||||||
import { Key, Model, KeyProvider } from "./index";
|
getServiceForModel,
|
||||||
|
LLMService,
|
||||||
|
MODEL_FAMILY_SERVICE,
|
||||||
|
ModelFamily,
|
||||||
|
} from "../models";
|
||||||
|
import { Key, KeyProvider, Model } from "./index";
|
||||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||||
@@ -42,7 +47,7 @@ export class KeyPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public get(model: Model): Key {
|
public get(model: Model): Key {
|
||||||
const service = this.getServiceForModel(model);
|
const service = getServiceForModel(model);
|
||||||
return this.getKeyProvider(service).get(model);
|
return this.getKeyProvider(service).get(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +77,7 @@ export class KeyPool {
|
|||||||
public available(model: Model | "all" = "all"): number {
|
public available(model: Model | "all" = "all"): number {
|
||||||
return this.keyProviders.reduce((sum, provider) => {
|
return this.keyProviders.reduce((sum, provider) => {
|
||||||
const includeProvider =
|
const includeProvider =
|
||||||
model === "all" || this.getServiceForModel(model) === provider.service;
|
model === "all" || getServiceForModel(model) === provider.service;
|
||||||
return sum + (includeProvider ? provider.available() : 0);
|
return sum + (includeProvider ? provider.available() : 0);
|
||||||
}, 0);
|
}, 0);
|
||||||
}
|
}
|
||||||
@@ -109,33 +114,6 @@ export class KeyPool {
|
|||||||
provider.recheck();
|
provider.recheck();
|
||||||
}
|
}
|
||||||
|
|
||||||
private getServiceForModel(model: Model): LLMService {
|
|
||||||
if (
|
|
||||||
model.startsWith("gpt") ||
|
|
||||||
model.startsWith("text-embedding-ada") ||
|
|
||||||
model.startsWith("dall-e")
|
|
||||||
) {
|
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
|
||||||
return "openai";
|
|
||||||
} else if (model.startsWith("claude-")) {
|
|
||||||
// https://console.anthropic.com/docs/api/reference#parameters
|
|
||||||
return "anthropic";
|
|
||||||
} else if (model.includes("gemini")) {
|
|
||||||
// https://developers.generativeai.google.com/models/language
|
|
||||||
return "google-ai";
|
|
||||||
} else if (model.includes("mistral")) {
|
|
||||||
// https://docs.mistral.ai/platform/endpoints
|
|
||||||
return "mistral-ai";
|
|
||||||
} else if (model.startsWith("anthropic.claude")) {
|
|
||||||
// AWS offers models from a few providers
|
|
||||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
|
||||||
return "aws";
|
|
||||||
} else if (model.startsWith("azure")) {
|
|
||||||
return "azure";
|
|
||||||
}
|
|
||||||
throw new Error(`Unknown service for model '${model}'`);
|
|
||||||
}
|
|
||||||
|
|
||||||
private getKeyProvider(service: LLMService): KeyProvider {
|
private getKeyProvider(service: LLMService): KeyProvider {
|
||||||
return this.keyProviders.find((provider) => provider.service === service)!;
|
return this.keyProviders.find((provider) => provider.service === service)!;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -205,6 +205,33 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
|||||||
return (req.modelFamily = modelFamily);
|
return (req.modelFamily = modelFamily);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Maps a model name to the service that serves it.
 *
 * Checks run top to bottom and the first match wins, so a name matching
 * several patterns resolves to the earliest one listed.
 *
 * @param model - Model name as supplied by the client.
 * @returns The service associated with the model name.
 * @throws Error if the name matches none of the known patterns.
 */
export function getServiceForModel(model: string): LLMService {
  // True when the model name begins with any of the given prefixes.
  const hasPrefix = (...prefixes: string[]): boolean =>
    prefixes.some((prefix) => model.startsWith(prefix));

  // https://platform.openai.com/docs/models/model-endpoint-compatibility
  if (hasPrefix("gpt", "text-embedding-ada", "dall-e")) return "openai";

  // https://console.anthropic.com/docs/api/reference#parameters
  if (hasPrefix("claude-")) return "anthropic";

  // https://developers.generativeai.google.com/models/language
  if (model.includes("gemini")) return "google-ai";

  // https://docs.mistral.ai/platform/endpoints
  if (model.includes("mistral")) return "mistral-ai";

  // AWS offers models from a few providers
  // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
  if (hasPrefix("anthropic.claude")) return "aws";

  if (hasPrefix("azure")) return "azure";

  throw new Error(`Unknown service for model '${model}'`);
}
|
||||||
|
|
||||||
function assertNever(x: never): never {
|
function assertNever(x: never): never {
|
||||||
throw new Error(`Called assertNever with argument ${x}.`);
|
throw new Error(`Called assertNever with argument ${x}.`);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user