Files
OAI-Proxy/src/shared/key-management/cohere/provider.ts
T

168 lines
5.0 KiB
TypeScript

import { createHash } from "crypto";

import { Key, KeyProvider, createGenericGetLockoutPeriod } from "..";
import { CohereKeyChecker } from "./checker";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { CohereModelFamily, ModelFamily } from "../../models"; // Added ModelFamily
// CohereKeyUsage is removed, tokenUsage from base Key interface will be used.
/**
 * A Cohere API key together with the bookkeeping state tracked for it.
 * Everything not declared here comes from the shared `Key` interface.
 */
export interface CohereKey extends Key {
  /** Literal tag identifying the service this key belongs to. */
  readonly service: "cohere";
  /** Model families this key is registered for. */
  readonly modelFamilies: Array<CohereModelFamily>;
  /**
   * Over-quota flag. The provider initialises it to `false` and clears it on
   * recheck; presumably the key checker is what raises it — confirm there.
   */
  isOverQuota: boolean;
}
/**
 * In-memory pool of Cohere API keys.
 *
 * Keys are read once, at construction time, from the comma-separated
 * `config.cohereKey` setting. The provider hands keys out to callers, tracks
 * per-key usage and rate-limit state, and (when `config.checkKeys` is on)
 * delegates key verification to a `CohereKeyChecker`.
 *
 * Methods that take a key hash silently ignore hashes that do not belong to
 * this provider.
 */
export class CohereKeyProvider implements KeyProvider<CohereKey> {
  /**
   * Upon being rate limited, a key will be locked out for this many milliseconds
   * while we wait for other concurrent requests to finish.
   */
  private static readonly RATE_LIMIT_LOCKOUT = 2000;

  /**
   * Upon assigning a key, we will wait this many milliseconds before allowing it
   * to be used again. This is to prevent the queue from flooding a key with too
   * many requests while we wait to learn whether previous ones succeeded.
   */
  private static readonly KEY_REUSE_DELAY = 500;

  readonly service = "cohere";

  private keys: CohereKey[] = [];
  private checker?: CohereKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  /**
   * Lockout-period getter built by the shared `createGenericGetLockoutPeriod`
   * helper over this provider's key list; see that helper for its semantics.
   */
  getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);

  /**
   * Loads keys from `config.cohereKey`. Entries are trimmed and empty entries
   * are skipped; when the setting is missing or blank the pool stays empty.
   */
  constructor() {
    const keyConfig = config.cohereKey?.trim();
    if (!keyConfig) {
      return;
    }

    const keys = keyConfig.split(",").map((k) => k.trim());
    for (const key of keys) {
      if (!key) continue;
      this.keys.push({
        key,
        service: this.service,
        modelFamilies: ["cohere"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        lastChecked: 0,
        hash: this.hashKey(key),
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        tokenUsage: {}, // Initialize new tokenUsage field
        isOverQuota: false,
      });
    }
  }

  /** Returns the hex-encoded SHA-256 digest of the raw key. */
  private hashKey(key: string): string {
    return createHash("sha256").update(key).digest("hex");
  }

  /** Looks a key up by its hash; `undefined` when the hash is unknown. */
  private findKey(hash: string): CohereKey | undefined {
    return this.keys.find((k) => k.hash === hash);
  }

  /**
   * Starts key verification. Does nothing when the pool is empty; logs a
   * warning and skips verification when `config.checkKeys` is off. Otherwise
   * creates the checker and kicks off a check for every key without awaiting
   * the results.
   */
  public init() {
    if (this.keys.length === 0) return;
    if (!config.checkKeys) {
      this.log.warn(
        "Key checking is disabled. Keys will not be verified."
      );
      return;
    }
    this.checker = new CohereKeyChecker(this.update.bind(this));
    for (const key of this.keys) {
      void this.checker.checkKey(key);
    }
  }

  /**
   * Picks a key for a request and returns a copy of it.
   *
   * Keys whose rate-limit/reuse lockout has elapsed are preferred so that a
   * key that was just handed out (or just rate limited) is not reused while
   * an idle one exists. If every enabled key is still locked out, one of them
   * is returned anyway. The selected key has `lastUsed` stamped and the reuse
   * delay applied before it is copied.
   *
   * @param model Requested model; currently unused because every key is
   *   registered for the single "cohere" family.
   * @returns A shallow copy of the selected key.
   * @throws Error when no enabled key exists.
   */
  public get(model: string): CohereKey {
    const availableKeys = this.keys.filter((k) => !k.isDisabled);
    if (availableKeys.length === 0) {
      throw new Error("No Cohere keys available");
    }

    const now = Date.now();
    const readyKeys = availableKeys.filter((k) => k.rateLimitedUntil <= now);
    const candidates = readyKeys.length > 0 ? readyKeys : availableKeys;

    const key = candidates[Math.floor(Math.random() * candidates.length)];
    key.lastUsed = now;
    this.throttle(key.hash);
    return { ...key };
  }

  /** Returns every key's state with the raw `key` secret stripped out. */
  public list(): Omit<CohereKey, "key">[] {
    return this.keys.map(({ key, ...rest }) => rest);
  }

  /** Marks the pool entry with the same hash as `key` as disabled. */
  public disable(key: CohereKey): void {
    const found = this.findKey(key.hash);
    if (found) {
      found.isDisabled = true;
    }
  }

  /** Merges `update` into the pool entry identified by `hash`. */
  public update(hash: string, update: Partial<CohereKey>): void {
    const key = this.findKey(hash);
    if (key) {
      Object.assign(key, update);
    }
  }

  /** Number of keys that are not disabled. */
  public available(): number {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  /**
   * Records one prompt against the key and adds the given token counts to the
   * key's running totals for `modelFamily`.
   */
  public incrementUsage(keyHash: string, modelFamily: CohereModelFamily, usage: { input: number; output: number }) {
    const key = this.findKey(keyHash);
    if (!key) return;

    key.promptCount++;
    if (!key.tokenUsage) {
      key.tokenUsage = {};
    }
    // Cohere only has one model family "cohere"
    if (!key.tokenUsage[modelFamily]) {
      key.tokenUsage[modelFamily] = { input: 0, output: 0 };
    }
    const currentFamilyUsage = key.tokenUsage[modelFamily]!;
    currentFamilyUsage.input += usage.input;
    currentFamilyUsage.output += usage.output;
  }

  /**
   * Locks the key out for `RATE_LIMIT_LOCKOUT` milliseconds starting now.
   * Unknown hashes are logged and otherwise ignored.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.findKey(keyHash);
    if (!key) return;

    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + CohereKeyProvider.RATE_LIMIT_LOCKOUT;
  }

  /**
   * Clears the over-quota and disabled flags on every key, resets
   * `lastChecked` to 0 and asks the checker to verify each key again.
   * No-op when no checker was created or key checking is disabled.
   *
   * NOTE(review): `isRevoked` is left untouched here, so a revoked key is
   * re-enabled until the checker reports back — confirm this is intended.
   */
  public recheck(): void {
    if (!this.checker || !config.checkKeys) return;
    for (const key of this.keys) {
      this.update(key.hash, {
        isOverQuota: false,
        isDisabled: false,
        lastChecked: 0
      });
      void this.checker.checkKey(key);
    }
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched. An existing, longer lockout is never
   * shortened.
   **/
  private throttle(hash: string) {
    const key = this.findKey(hash);
    if (!key) return;

    const now = Date.now();
    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + CohereKeyProvider.KEY_REUSE_DELAY;
    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}