Clone keys assigned to multiple organizations (khanon/oai-reverse-proxy!38)

This commit is contained in:
khanon
2023-08-28 21:11:49 +00:00
parent 7c9c3a640c
commit 6833736392
3 changed files with 85 additions and 46 deletions
+44 -43
View File
@@ -13,36 +13,34 @@ const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const POST_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
const GET_MODELS_URL = "https://api.openai.com/v1/models";
const GET_SUBSCRIPTION_URL =
"https://api.openai.com/dashboard/billing/subscription";
const GET_ORGANIZATIONS_URL = "https://api.openai.com/v1/organizations";
type GetModelsResponse = {
data: [{ id: string }];
};
type GetSubscriptionResponse = {
plan: { title: string };
has_payment_method: boolean;
soft_limit_usd: number;
hard_limit_usd: number;
system_hard_limit_usd: number;
type GetOrganizationsResponse = {
data: [{ id: string; is_default: boolean }];
};
type OpenAIError = {
error: { type: string; code: string; param: unknown; message: string };
};
type CloneFn = typeof OpenAIKeyProvider.prototype.clone;
type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
export class OpenAIKeyChecker {
private readonly keys: OpenAIKey[];
private log = logger.child({ module: "key-checker", service: "openai" });
private timeout?: NodeJS.Timeout;
private cloneKey: CloneFn;
private updateKey: UpdateFn;
private lastCheck = 0;
constructor(keys: OpenAIKey[], updateKey: UpdateFn) {
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
this.keys = keys;
this.cloneKey = cloneFn;
this.updateKey = updateKey;
}
@@ -131,17 +129,13 @@ export class OpenAIKeyChecker {
try {
// We only need to check for provisioned models on the initial check.
if (isInitialCheck) {
const [/* subscription,*/ provisionedModels, livenessTest] =
await Promise.all([
// this.getSubscription(key),
this.getProvisionedModels(key),
this.testLiveness(key),
]);
const [provisionedModels, livenessTest] = await Promise.all([
this.getProvisionedModels(key),
this.testLiveness(key),
this.maybeCreateOrganizationClones(key),
]);
const updates = {
isGpt4: provisionedModels.gpt4,
// softLimit: subscription.soft_limit_usd,
// hardLimit: subscription.hard_limit_usd,
// systemHardLimit: subscription.system_hard_limit_usd,
isTrial: livenessTest.rateLimit <= 250,
softLimit: 0,
hardLimit: 0,
@@ -150,18 +144,8 @@ export class OpenAIKeyChecker {
this.updateKey(key.hash, updates);
} else {
// Provisioned models don't change, so we don't need to check them again
const [/* subscription, */ _livenessTest] = await Promise.all([
// this.getSubscription(key),
this.testLiveness(key),
]);
const updates = {
// softLimit: subscription.soft_limit_usd,
// hardLimit: subscription.hard_limit_usd,
// systemHardLimit: subscription.system_hard_limit_usd,
softLimit: 0,
hardLimit: 0,
systemHardLimit: 0,
};
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
const updates = { softLimit: 0, hardLimit: 0, systemHardLimit: 0 };
this.updateKey(key.hash, updates);
}
this.log.info({ key: key.hash }, "Key check complete.");
@@ -182,7 +166,7 @@ export class OpenAIKeyChecker {
private async getProvisionedModels(
key: OpenAIKey
): Promise<{ turbo: boolean; gpt4: boolean }> {
const opts = { headers: { Authorization: `Bearer ${key.key}` } };
const opts = { headers: OpenAIKeyChecker.getHeaders(key) };
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
const models = data.data;
const turbo = models.some(({ id }) => id.startsWith("gpt-3.5"));
@@ -200,18 +184,27 @@ export class OpenAIKeyChecker {
return { turbo, gpt4 };
}
private async getSubscription(key: OpenAIKey) {
const { data } = await axios.get<GetSubscriptionResponse>(
GET_SUBSCRIPTION_URL,
{ headers: { Authorization: `Bearer ${key.key}` } }
private async maybeCreateOrganizationClones(key: OpenAIKey) {
if (key.organizationId) return; // already cloned
const opts = { headers: { Authorization: `Bearer ${key.key}` } };
const { data } = await axios.get<GetOrganizationsResponse>(
GET_ORGANIZATIONS_URL,
opts
);
// See note above about updating the key's `lastChecked` timestamp.
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
this.updateKey(key.hash, {
isTrial: !data.has_payment_method,
lastChecked: keyFromPool.lastChecked,
});
return data;
const organizations = data.data;
if (organizations.length <= 1) return undefined;
this.log.info(
{ parent: key.hash, organizations: organizations.map((org) => org.id) },
"Key is associated with multiple organizations; cloning key for each organization."
);
const defaultOrg = organizations.find(({ is_default }) => is_default);
const ids = organizations
.filter(({ is_default }) => !is_default)
.map(({ id }) => id);
this.updateKey(key.hash, { organizationId: defaultOrg?.id });
this.cloneKey(key.hash, ids);
}
private handleAxiosError(key: OpenAIKey, error: AxiosError) {
@@ -318,7 +311,7 @@ export class OpenAIKeyChecker {
POST_CHAT_COMPLETIONS_URL,
payload,
{
headers: { Authorization: `Bearer ${key.key}` },
headers: OpenAIKeyChecker.getHeaders(key),
validateStatus: (status) => status === 400,
}
);
@@ -341,4 +334,12 @@ export class OpenAIKeyChecker {
const data = error.response?.data as any;
return data?.error?.type;
}
static getHeaders(key: OpenAIKey) {
const headers = {
Authorization: `Bearer ${key.key}`,
...(key.organizationId && { "OpenAI-Organization": key.organizationId }),
};
return headers;
}
}
+32 -1
View File
@@ -18,6 +18,12 @@ export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
export interface OpenAIKey extends Key {
readonly service: "openai";
/**
* Some keys are assigned to multiple organizations, each with their own quota
* limits. We clone the key for each organization and track usage/disabled
* status separately.
*/
organizationId?: string;
/** Set when key check returns a 401. */
isRevoked: boolean;
/** Set when key check returns a non-transient 429. */
@@ -107,7 +113,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
public init() {
if (config.checkKeys) {
this.checker = new OpenAIKeyChecker(this.keys, this.update.bind(this));
const cloneFn = this.clone.bind(this);
const updateFn = this.update.bind(this);
this.checker = new OpenAIKeyChecker(this.keys, cloneFn, updateFn);
this.checker.start();
}
}
@@ -191,6 +199,29 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
// this.writeKeyStatus();
}
/** Called by the key checker to create clones of keys for the given orgs. */
public clone(keyHash: string, newOrgIds: string[]) {
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
const clones = newOrgIds.map((orgId) => {
const clone: OpenAIKey = {
...keyFromPool,
organizationId: orgId,
hash: `oai-${crypto
.createHash("sha256")
.update(keyFromPool.key + orgId)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0, // Force re-check in case the org has different models
};
this.log.info(
{ cloneHash: clone.hash, parentHash: keyFromPool.hash, orgId },
"Cloned organization key"
);
return clone;
});
this.keys.push(...clones);
}
/** Disables a key, or does nothing if the key isn't in this pool. */
public disable(key: Key) {
const keyFromPool = this.keys.find((k) => k.key === key.key);
+9 -2
View File
@@ -1,4 +1,4 @@
import { Key, keyPool } from "../../../key-management";
import { Key, OpenAIKey, keyPool } from "../../../key-management";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
@@ -57,9 +57,16 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
"Assigned key to request"
);
// TODO: KeyProvider should assemble all necessary headers
if (assignedKey.service === "anthropic") {
proxyReq.setHeader("X-API-Key", assignedKey.key);
} else {
} else if (assignedKey.service === "openai") {
const key: OpenAIKey = assignedKey as OpenAIKey;
if (key.organizationId) {
proxyReq.setHeader("OpenAI-Organization", key.organizationId);
}
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
} else {
throw new Error(`Unknown service '${assignedKey.service}'`);
}
};