diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 95c6530..5d9f4f2 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -1,12 +1,25 @@ -import { OPENAI_SUPPORTED_MODELS, OpenAIModel } from "./openai/provider"; +import type { ModelFamily } from "../models"; +import { KeyPool } from "./key-pool"; +import { + OPENAI_SUPPORTED_MODELS, + OpenAIKey, + OpenAIModel, +} from "./openai/provider"; import { ANTHROPIC_SUPPORTED_MODELS, + AnthropicKey, AnthropicModel, } from "./anthropic/provider"; -import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider"; -import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider"; -import { KeyPool } from "./key-pool"; -import type { ModelFamily } from "../models"; +import { + GOOGLE_PALM_SUPPORTED_MODELS, + GooglePalmKey, + GooglePalmModel, +} from "./palm/provider"; +import { + AWS_BEDROCK_SUPPORTED_MODELS, + AwsBedrockKey, + AwsBedrockModel, +} from "./aws/provider"; /** The request and response format used by a model's API. */ export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text"; @@ -18,6 +31,11 @@ export type Model = | GooglePalmModel | AwsBedrockModel; +type AllKeys = OpenAIKey | AnthropicKey | GooglePalmKey | AwsBedrockKey; +export type ServiceToKey = { + [K in AllKeys["service"]]: Extract; +}; + export interface Key { /** The API key itself. Never log this, use `hash` instead. */ readonly key: string; @@ -83,4 +101,4 @@ export { GooglePalmKey } from "./palm/provider"; export { AwsBedrockKey } from "./aws/provider"; export { assertSerializedKey } from "./serializers"; export { SerializedKey } from "./serializers"; -export { KeySerializer } from "./serializers"; \ No newline at end of file +export { KeySerializer } from "./serializers"; diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index bcb9076..111d803 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -4,7 +4,7 @@ import os from "os"; import schedule from "node-schedule"; import { config } from "../../config"; import { logger } from "../../logger"; -import { Key, Model, KeyProvider, LLMService } from "./index"; +import { Key, KeyProvider, LLMService, Model, ServiceToKey } from "./index"; import { getSerializer } from "./serializers"; import { FirebaseKeyStore, KeyStore, MemoryKeyStore } from "./stores"; import { AnthropicKeyProvider } from "./anthropic/provider"; @@ -153,7 +153,9 @@ export class KeyPool { } } -function createKeyStore(service: LLMService): KeyStore { +function createKeyStore( + service: S +): KeyStore { const serializer = getSerializer(service); switch (config.persistenceProvider) { @@ -165,3 +167,4 @@ function createKeyStore(service: LLMService): KeyStore { throw new Error(`Unknown store type: ${config.persistenceProvider}`); } } + diff --git a/src/shared/key-management/serializers.ts b/src/shared/key-management/serializers.ts index 076c6fb..8132862 100644 --- a/src/shared/key-management/serializers.ts +++ b/src/shared/key-management/serializers.ts @@ -1,5 +1,5 @@ -import { Key, LLMService } from "."; import { assertNever } from "../utils"; +import { Key, LLMService, ServiceToKey } from "./index"; import { OpenAIKeySerializer } from "./openai/serializer"; import { AnthropicKeySerializer } from "./anthropic/serializer"; import { GooglePalmKeySerializer } from "./palm/serializer"; @@ -41,6 +41,9 @@ export function assertSerializedKey(k: any): asserts k is SerializedKey { } } +export function getSerializer( + service: S +): KeySerializer; export function getSerializer(service: LLMService): KeySerializer { switch (service) { case "openai":