Implement AWS Bedrock support (khanon/oai-reverse-proxy!45)
This commit is contained in:
@@ -177,10 +177,6 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
||||
return this.keys.filter((k) => !k.isDisabled).length;
|
||||
}
|
||||
|
||||
public anyUnchecked() {
|
||||
return this.keys.some((k) => k.lastChecked === 0);
|
||||
}
|
||||
|
||||
public incrementUsage(hash: string, _model: string, tokens: number) {
|
||||
const key = this.keys.find((k) => k.hash === hash);
|
||||
if (!key) return;
|
||||
@@ -202,10 +198,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
||||
|
||||
// If all keys are rate-limited, return the time until the first key is
|
||||
// ready.
|
||||
const timeUntilFirstReady = Math.min(
|
||||
...activeKeys.map((k) => k.rateLimitedUntil - now)
|
||||
);
|
||||
return timeUntilFirstReady;
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -216,7 +209,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
||||
* retrying in order to give the other requests a chance to finish.
|
||||
*/
|
||||
public markRateLimited(keyHash: string) {
|
||||
this.log.warn({ key: keyHash }, "Key rate limited");
|
||||
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||
const now = Date.now();
|
||||
key.rateLimitedAt = now;
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import type { AwsBedrockModelFamily } from "../../models";
|
||||
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
export const AWS_BEDROCK_SUPPORTED_MODELS = [
|
||||
"anthropic.claude-v1",
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-instant-v1",
|
||||
] as const;
|
||||
export type AwsBedrockModel = (typeof AWS_BEDROCK_SUPPORTED_MODELS)[number];
|
||||
|
||||
/**
 * Per-family token usage counters for an AWS Bedrock key. For every member of
 * `AwsBedrockModelFamily` this yields one numeric property named
 * `<family>Tokens` (e.g. `"aws-claudeTokens"`).
 */
type AwsBedrockKeyUsage = {
|
||||
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
|
||||
};
|
||||
|
||||
/**
 * A single AWS Bedrock credential tracked by the key provider. Combines the
 * service-agnostic `Key` fields with the per-family token counters from
 * `AwsBedrockKeyUsage` and the rate-limit bookkeeping below.
 */
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
||||
/** Discriminator identifying this key as belonging to the AWS service. */
readonly service: "aws";
|
||||
/** The AWS Bedrock model families this key may serve. */
readonly modelFamilies: AwsBedrockModelFamily[];
|
||||
/**
 * The time at which this key was last rate limited, as a `Date.now()`
 * timestamp in milliseconds. The provider also stamps this when it hands the
 * key out, not only when a rate-limit response is reported.
 */
|
||||
rateLimitedAt: number;
|
||||
/**
 * The time until which this key is rate limited, as a `Date.now()` timestamp
 * in milliseconds. The key counts as locked out while `now` is below this.
 */
|
||||
rateLimitedUntil: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
||||
* while we wait for other concurrent requests to finish.
|
||||
*/
|
||||
const RATE_LIMIT_LOCKOUT = 300;
|
||||
/**
|
||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
||||
* to be used again. This is to prevent the queue from flooding a key with too
|
||||
* many requests while we wait to learn whether previous ones succeeded.
|
||||
*/
|
||||
const KEY_REUSE_DELAY = 500;
|
||||
|
||||
/**
 * Key provider for AWS Bedrock.
 *
 * Credentials are read once, at construction time, from
 * `config.awsCredentials` as a comma-separated list. Each distinct, non-empty
 * entry becomes one key in the pool. Keys are handed out by `get()` and
 * throttled through the `rateLimitedAt` / `rateLimitedUntil` timestamps.
 */
export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
  readonly service = "aws";

  private keys: AwsBedrockKey[] = [];
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.awsCredentials?.trim();
    if (!keyConfig) {
      this.log.warn(
        "AWS_CREDENTIALS is not set. AWS Bedrock API will not be available."
      );
      return;
    }

    // Trim each entry, drop blanks (a trailing or doubled comma would
    // otherwise register an empty credential as a usable key), and dedupe.
    const bareKeys = [
      ...new Set(
        keyConfig
          .split(",")
          .map((k) => k.trim())
          .filter((k) => k.length > 0)
      ),
    ];

    for (const key of bareKeys) {
      const newKey: AwsBedrockKey = {
        key,
        service: this.service,
        modelFamilies: ["aws-claude"],
        isTrial: false,
        isDisabled: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        // Short, loggable identifier derived from the credential; the raw
        // credential itself must never be logged.
        hash: `aws-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        ["aws-claudeTokens"]: 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys.");
  }

  /** No-op: this provider has no asynchronous initialization step. */
  public init() {}

  /**
   * Returns a frozen snapshot of every key with the secret `key` field
   * blanked out, so the result is safe to expose or log.
   */
  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  /**
   * Selects a key to service a request and marks it as just used.
   *
   * @param _model Requested model; currently unused because every key is
   *   assumed to serve every supported model.
   * @returns A shallow copy of the selected key.
   * @throws Error if every key is disabled (or none were loaded).
   */
  public get(_model: AwsBedrockModel) {
    // `filter` returns a fresh array, so the in-place sort below does not
    // reorder `this.keys`.
    const availableKeys = this.keys.filter((k) => !k.isDisabled);
    if (availableKeys.length === 0) {
      throw new Error("No AWS Bedrock keys available");
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    // NOTE(review): stamping `rateLimitedAt` on selection also makes this key
    // sort as "rate limited" for the next RATE_LIMIT_LOCKOUT ms. This looks
    // intentional (it spreads load across keys) but should be confirmed.
    selectedKey.rateLimitedAt = now;
    // Intended to throttle the queue processor as otherwise it will just
    // flood the API with requests and we want to wait a sec to see if we're
    // going to get a rate limit error on this key.
    selectedKey.rateLimitedUntil = now + KEY_REUSE_DELAY;
    return { ...selectedKey };
  }

  /**
   * Marks the pooled key matching `key.hash` as disabled. Does nothing if the
   * key is unknown or already disabled.
   */
  public disable(key: AwsBedrockKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  /**
   * Applies a partial update to the pooled key with the given hash.
   * `lastChecked` is refreshed to the current time unless `update` supplies
   * its own value. Unknown hashes are logged and ignored instead of throwing
   * a TypeError.
   */
  public update(hash: string, update: Partial<AwsBedrockKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash);
    if (!keyFromPool) {
      this.log.warn({ key: hash }, "Cannot update unknown key");
      return;
    }
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  /** Returns the number of keys that are not disabled. */
  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  /**
   * Records one prompt and `tokens` tokens of usage against the key with the
   * given hash. Unknown hashes are ignored.
   */
  public incrementUsage(hash: string, _model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    key["aws-claudeTokens"] += tokens;
  }

  /**
   * Returns how long (in milliseconds) a caller must wait before any key is
   * usable, or 0 if at least one key is ready now.
   */
  public getLockoutPeriod(_model: AwsBedrockModel) {
    // TODO: same exact behavior for three providers, should be refactored
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * Called when a request made with this key was rate limited (HTTP 429). We
   * don't have any information on when the key will be usable again, so the
   * key is locked out for RATE_LIMIT_LOCKOUT milliseconds before it may be
   * retried. Unknown hashes are logged and ignored instead of throwing a
   * TypeError.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash);
    if (!key) {
      this.log.warn({ key: keyHash }, "Cannot rate limit unknown key");
      return;
    }
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  /** No-op: AWS Bedrock keys are not rechecked by this provider. */
  public recheck() {}
}
|
||||
@@ -4,17 +4,25 @@ import {
|
||||
AnthropicModel,
|
||||
} from "./anthropic/provider";
|
||||
import { GOOGLE_PALM_SUPPORTED_MODELS, GooglePalmModel } from "./palm/provider";
|
||||
import { AWS_BEDROCK_SUPPORTED_MODELS, AwsBedrockModel } from "./aws/provider";
|
||||
import { KeyPool } from "./key-pool";
|
||||
import type { ModelFamily } from "../models";
|
||||
|
||||
/** The request and response format used by a model's API. */
|
||||
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
|
||||
export type Model = OpenAIModel | AnthropicModel | GooglePalmModel;
|
||||
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
|
||||
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";
|
||||
export type Model =
|
||||
| OpenAIModel
|
||||
| AnthropicModel
|
||||
| GooglePalmModel
|
||||
| AwsBedrockModel;
|
||||
|
||||
export interface Key {
|
||||
/** The API key itself. Never log this, use `hash` instead. */
|
||||
readonly key: string;
|
||||
/** The service that this key is for. */
|
||||
service: APIFormat;
|
||||
service: LLMService;
|
||||
/** Whether this is a free trial key. These are prioritized over paid keys if they can fulfill the request. */
|
||||
isTrial: boolean;
|
||||
/** The model families that this key has access to. */
|
||||
@@ -44,14 +52,13 @@ for service-agnostic functionality.
|
||||
*/
|
||||
|
||||
export interface KeyProvider<T extends Key = Key> {
|
||||
readonly service: APIFormat;
|
||||
readonly service: LLMService;
|
||||
init(): void;
|
||||
get(model: Model): T;
|
||||
list(): Omit<T, "key">[];
|
||||
disable(key: T): void;
|
||||
update(hash: string, update: Partial<T>): void;
|
||||
available(): number;
|
||||
anyUnchecked(): boolean;
|
||||
incrementUsage(hash: string, model: string, tokens: number): void;
|
||||
getLockoutPeriod(model: Model): number;
|
||||
markRateLimited(hash: string): void;
|
||||
@@ -68,7 +75,9 @@ export {
|
||||
OPENAI_SUPPORTED_MODELS,
|
||||
ANTHROPIC_SUPPORTED_MODELS,
|
||||
GOOGLE_PALM_SUPPORTED_MODELS,
|
||||
AWS_BEDROCK_SUPPORTED_MODELS,
|
||||
};
|
||||
export { AnthropicKey } from "./anthropic/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
export { GooglePalmKey } from "./palm/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
@@ -4,16 +4,17 @@ import os from "os";
|
||||
import schedule from "node-schedule";
|
||||
import { config } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
import { Key, Model, KeyProvider, APIFormat } from "./index";
|
||||
import { Key, Model, KeyProvider, LLMService } from "./index";
|
||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||
import { GooglePalmKeyProvider } from "./palm/provider";
|
||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||
|
||||
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
|
||||
|
||||
export class KeyPool {
|
||||
private keyProviders: KeyProvider[] = [];
|
||||
private recheckJobs: Partial<Record<APIFormat, schedule.Job | null>> = {
|
||||
private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
|
||||
openai: null,
|
||||
};
|
||||
|
||||
@@ -21,6 +22,7 @@ export class KeyPool {
|
||||
this.keyProviders.push(new OpenAIKeyProvider());
|
||||
this.keyProviders.push(new AnthropicKeyProvider());
|
||||
this.keyProviders.push(new GooglePalmKeyProvider());
|
||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||
}
|
||||
|
||||
public init() {
|
||||
@@ -28,7 +30,7 @@ export class KeyPool {
|
||||
const availableKeys = this.available("all");
|
||||
if (availableKeys === 0) {
|
||||
throw new Error(
|
||||
"No keys loaded. Ensure OPENAI_KEY, ANTHROPIC_KEY, or GOOGLE_PALM_KEY are set."
|
||||
"No keys loaded. Ensure that at least one key is configured."
|
||||
);
|
||||
}
|
||||
this.scheduleRecheck();
|
||||
@@ -43,6 +45,11 @@ export class KeyPool {
|
||||
return this.keyProviders.flatMap((provider) => provider.list());
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks a key as disabled for a specific reason. `revoked` should be used
|
||||
* to indicate a key that can never be used again, while `quota` should be
|
||||
* used to indicate a key that is still valid but has exceeded its quota.
|
||||
*/
|
||||
public disable(key: Key, reason: "quota" | "revoked"): void {
|
||||
const service = this.getKeyProvider(key.service);
|
||||
service.disable(key);
|
||||
@@ -59,17 +66,14 @@ export class KeyPool {
|
||||
service.update(key.hash, props);
|
||||
}
|
||||
|
||||
public available(service: APIFormat | "all" = "all"): number {
|
||||
public available(model: Model | "all" = "all"): number {
|
||||
return this.keyProviders.reduce((sum, provider) => {
|
||||
const includeProvider = service === "all" || service === provider.service;
|
||||
const includeProvider =
|
||||
model === "all" || this.getService(model) === provider.service;
|
||||
return sum + (includeProvider ? provider.available() : 0);
|
||||
}, 0);
|
||||
}
|
||||
|
||||
public anyUnchecked(): boolean {
|
||||
return this.keyProviders.some((provider) => provider.anyUnchecked());
|
||||
}
|
||||
|
||||
public incrementUsage(key: Key, model: string, tokens: number): void {
|
||||
const provider = this.getKeyProvider(key.service);
|
||||
provider.incrementUsage(key.hash, model, tokens);
|
||||
@@ -92,7 +96,7 @@ export class KeyPool {
|
||||
}
|
||||
}
|
||||
|
||||
public recheck(service: APIFormat): void {
|
||||
public recheck(service: LLMService): void {
|
||||
if (!config.checkKeys) {
|
||||
logger.info("Skipping key recheck because key checking is disabled");
|
||||
return;
|
||||
@@ -102,7 +106,7 @@ export class KeyPool {
|
||||
provider.recheck();
|
||||
}
|
||||
|
||||
private getService(model: Model): APIFormat {
|
||||
private getService(model: Model): LLMService {
|
||||
if (model.startsWith("gpt") || model.startsWith("text-embedding-ada")) {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
return "openai";
|
||||
@@ -112,16 +116,15 @@ export class KeyPool {
|
||||
} else if (model.includes("bison")) {
|
||||
// https://developers.generativeai.google.com/models/language
|
||||
return "google-palm";
|
||||
} else if (model.startsWith("anthropic.claude")) {
|
||||
// AWS offers models from a few providers
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
return "aws";
|
||||
}
|
||||
throw new Error(`Unknown service for model '${model}'`);
|
||||
}
|
||||
|
||||
private getKeyProvider(service: APIFormat): KeyProvider {
|
||||
// The "openai-text" service is a special case handled by OpenAIKeyProvider.
|
||||
if (service === "openai-text") {
|
||||
service = "openai";
|
||||
}
|
||||
|
||||
private getKeyProvider(service: LLMService): KeyProvider {
|
||||
return this.keyProviders.find((provider) => provider.service === service)!;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,10 +33,10 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
|
||||
|
||||
export class OpenAIKeyChecker {
|
||||
private readonly keys: OpenAIKey[];
|
||||
private log = logger.child({ module: "key-checker", service: "openai" });
|
||||
private timeout?: NodeJS.Timeout;
|
||||
private cloneKey: CloneFn;
|
||||
private updateKey: UpdateFn;
|
||||
private log = logger.child({ module: "key-checker", service: "openai" });
|
||||
private timeout?: NodeJS.Timeout;
|
||||
private lastCheck = 0;
|
||||
|
||||
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
|
||||
@@ -248,10 +248,10 @@ export class OpenAIKeyChecker {
|
||||
} else if (status === 429) {
|
||||
switch (data.error.type) {
|
||||
case "insufficient_quota":
|
||||
case "access_terminated":
|
||||
case "billing_not_active":
|
||||
const isOverQuota = data.error.type === "insufficient_quota";
|
||||
const isRevoked = !isOverQuota;
|
||||
case "access_terminated":
|
||||
const isRevoked = data.error.type === "access_terminated";
|
||||
const isOverQuota = !isRevoked;
|
||||
const modelFamilies: OpenAIModelFamily[] = isRevoked
|
||||
? ["turbo"]
|
||||
: key.modelFamilies;
|
||||
@@ -392,10 +392,9 @@ export class OpenAIKeyChecker {
|
||||
}
|
||||
|
||||
static getHeaders(key: OpenAIKey) {
|
||||
const headers = {
|
||||
return {
|
||||
Authorization: `Bearer ${key.key}`,
|
||||
...(key.organizationId && { "OpenAI-Organization": key.organizationId }),
|
||||
};
|
||||
return headers;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@ round-robin access to keys. Keys are stored in the OPENAI_KEY environment
|
||||
variable as a comma-separated list of keys. */
|
||||
import crypto from "crypto";
|
||||
import http from "http";
|
||||
import { KeyProvider, Key, Model } from "../index";
|
||||
import { Key, KeyProvider, Model } from "../index";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { OpenAIKeyChecker } from "./checker";
|
||||
import { OpenAIModelFamily, getOpenAIModelFamily } from "../../models";
|
||||
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
|
||||
|
||||
export type OpenAIModel =
|
||||
| "gpt-3.5-turbo"
|
||||
@@ -276,10 +276,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
return this.keys.filter((k) => !k.isDisabled).length;
|
||||
}
|
||||
|
||||
public anyUnchecked() {
|
||||
return !!config.checkKeys && this.keys.some((key) => !key.lastChecked);
|
||||
}
|
||||
|
||||
/**
|
||||
* Given a model, returns the period until a key will be available to service
|
||||
* the request, or returns 0 if a key is ready immediately.
|
||||
@@ -318,7 +314,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
|
||||
// If all keys are rate-limited, return the time until the first key is
|
||||
// ready.
|
||||
const timeUntilFirstReady = Math.min(
|
||||
return Math.min(
|
||||
...activeKeys.map((key) => {
|
||||
const resetTime = Math.max(
|
||||
key.rateLimitRequestsReset,
|
||||
@@ -327,11 +323,10 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
return key.rateLimitedAt + resetTime - now;
|
||||
})
|
||||
);
|
||||
return timeUntilFirstReady;
|
||||
}
|
||||
|
||||
public markRateLimited(keyHash: string) {
|
||||
this.log.warn({ key: keyHash }, "Key rate limited");
|
||||
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||
key.rateLimitedAt = Date.now();
|
||||
}
|
||||
|
||||
@@ -146,10 +146,6 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
return this.keys.filter((k) => !k.isDisabled).length;
|
||||
}
|
||||
|
||||
public anyUnchecked() {
|
||||
return false;
|
||||
}
|
||||
|
||||
public incrementUsage(hash: string, _model: string, tokens: number) {
|
||||
const key = this.keys.find((k) => k.hash === hash);
|
||||
if (!key) return;
|
||||
@@ -171,10 +167,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
|
||||
// If all keys are rate-limited, return the time until the first key is
|
||||
// ready.
|
||||
const timeUntilFirstReady = Math.min(
|
||||
...activeKeys.map((k) => k.rateLimitedUntil - now)
|
||||
);
|
||||
return timeUntilFirstReady;
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -185,7 +178,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
* retrying in order to give the other requests a chance to finish.
|
||||
*/
|
||||
public markRateLimited(keyHash: string) {
|
||||
this.log.warn({ key: keyHash }, "Key rate limited");
|
||||
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||
const now = Date.now();
|
||||
key.rateLimitedAt = now;
|
||||
|
||||
+15
-2
@@ -3,14 +3,23 @@ import { logger } from "../logger";
|
||||
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k";
|
||||
export type AnthropicModelFamily = "claude";
|
||||
export type GooglePalmModelFamily = "bison";
|
||||
export type AwsBedrockModelFamily = "aws-claude";
|
||||
export type ModelFamily =
|
||||
| OpenAIModelFamily
|
||||
| AnthropicModelFamily
|
||||
| GooglePalmModelFamily;
|
||||
| GooglePalmModelFamily
|
||||
| AwsBedrockModelFamily;
|
||||
|
||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
||||
) => arr)(["turbo", "gpt4", "gpt4-32k", "claude", "bison"] as const);
|
||||
) => arr)([
|
||||
"turbo",
|
||||
"gpt4",
|
||||
"gpt4-32k",
|
||||
"claude",
|
||||
"bison",
|
||||
"aws-claude",
|
||||
] as const);
|
||||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^gpt-4-32k-\\d{4}$": "gpt4-32k",
|
||||
@@ -41,6 +50,10 @@ export function getGooglePalmModelFamily(model: string): ModelFamily {
|
||||
return "bison";
|
||||
}
|
||||
|
||||
export function getAwsBedrockModelFamily(_model: string): ModelFamily {
|
||||
return "aws-claude";
|
||||
}
|
||||
|
||||
export function assertIsKnownModelFamily(
|
||||
modelFamily: string
|
||||
): asserts modelFamily is ModelFamily {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Request } from "express";
|
||||
import { config } from "../../config";
|
||||
import { assertNever } from "../utils";
|
||||
import {
|
||||
init as initClaude,
|
||||
@@ -13,12 +12,8 @@ import {
|
||||
import { APIFormat } from "../key-management";
|
||||
|
||||
export async function init() {
|
||||
if (config.anthropicKey) {
|
||||
initClaude();
|
||||
}
|
||||
if (config.openaiKey || config.googlePalmKey) {
|
||||
initOpenAi();
|
||||
}
|
||||
initClaude();
|
||||
initOpenAi();
|
||||
}
|
||||
|
||||
/** Tagged union via `service` field of the different types of requests that can
|
||||
|
||||
@@ -8,6 +8,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
|
||||
"gpt4-32k": z.number().optional().default(0),
|
||||
claude: z.number().optional().default(0),
|
||||
bison: z.number().optional().default(0),
|
||||
"aws-claude": z.number().optional().default(0),
|
||||
});
|
||||
|
||||
export const UserSchema = z
|
||||
|
||||
@@ -11,7 +11,7 @@ import admin from "firebase-admin";
|
||||
import schedule from "node-schedule";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { config, getFirebaseApp } from "../../config";
|
||||
import { ModelFamily } from "../models";
|
||||
import { MODEL_FAMILIES, ModelFamily } from "../models";
|
||||
import { logger } from "../../logger";
|
||||
import { User, UserTokenCounts, UserUpdate } from "./schema";
|
||||
|
||||
@@ -23,6 +23,7 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
|
||||
"gpt4-32k": 0,
|
||||
claude: 0,
|
||||
bison: 0,
|
||||
"aws-claude": 0,
|
||||
};
|
||||
|
||||
const users: Map<string, User> = new Map();
|
||||
@@ -131,12 +132,14 @@ export function upsertUser(user: UserUpdate) {
|
||||
|
||||
// TODO: Write firebase migration to backfill new fields
|
||||
if (updates.tokenCounts) {
|
||||
updates.tokenCounts["gpt4-32k"] ??= 0;
|
||||
updates.tokenCounts["bison"] ??= 0;
|
||||
for (const family of MODEL_FAMILIES) {
|
||||
updates.tokenCounts[family] ??= 0;
|
||||
}
|
||||
}
|
||||
if (updates.tokenLimits) {
|
||||
updates.tokenLimits["gpt4-32k"] ??= 0;
|
||||
updates.tokenLimits["bison"] ??= 0;
|
||||
for (const family of MODEL_FAMILIES) {
|
||||
updates.tokenLimits[family] ??= 0;
|
||||
}
|
||||
}
|
||||
|
||||
users.set(user.token, Object.assign(existing, updates));
|
||||
@@ -360,9 +363,12 @@ function getModelFamilyForQuotaUsage(model: string): ModelFamily {
|
||||
if (model.includes("bison")) {
|
||||
return "bison";
|
||||
}
|
||||
if (model.includes("claude")) {
|
||||
if (model.startsWith("claude")) {
|
||||
return "claude";
|
||||
}
|
||||
if(model.startsWith("anthropic.claude")) {
|
||||
return "aws-claude";
|
||||
}
|
||||
throw new Error(`Unknown quota model family for model ${model}`);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user