Clone keys assigned to multiple organizations (khanon/oai-reverse-proxy!38)
This commit is contained in:
@@ -13,36 +13,34 @@ const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
|
||||
|
||||
const POST_CHAT_COMPLETIONS_URL = "https://api.openai.com/v1/chat/completions";
|
||||
const GET_MODELS_URL = "https://api.openai.com/v1/models";
|
||||
const GET_SUBSCRIPTION_URL =
|
||||
"https://api.openai.com/dashboard/billing/subscription";
|
||||
const GET_ORGANIZATIONS_URL = "https://api.openai.com/v1/organizations";
|
||||
|
||||
type GetModelsResponse = {
|
||||
data: [{ id: string }];
|
||||
};
|
||||
|
||||
type GetSubscriptionResponse = {
|
||||
plan: { title: string };
|
||||
has_payment_method: boolean;
|
||||
soft_limit_usd: number;
|
||||
hard_limit_usd: number;
|
||||
system_hard_limit_usd: number;
|
||||
type GetOrganizationsResponse = {
|
||||
data: [{ id: string; is_default: boolean }];
|
||||
};
|
||||
|
||||
type OpenAIError = {
|
||||
error: { type: string; code: string; param: unknown; message: string };
|
||||
};
|
||||
|
||||
type CloneFn = typeof OpenAIKeyProvider.prototype.clone;
|
||||
type UpdateFn = typeof OpenAIKeyProvider.prototype.update;
|
||||
|
||||
export class OpenAIKeyChecker {
|
||||
private readonly keys: OpenAIKey[];
|
||||
private log = logger.child({ module: "key-checker", service: "openai" });
|
||||
private timeout?: NodeJS.Timeout;
|
||||
private cloneKey: CloneFn;
|
||||
private updateKey: UpdateFn;
|
||||
private lastCheck = 0;
|
||||
|
||||
constructor(keys: OpenAIKey[], updateKey: UpdateFn) {
|
||||
constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
|
||||
this.keys = keys;
|
||||
this.cloneKey = cloneFn;
|
||||
this.updateKey = updateKey;
|
||||
}
|
||||
|
||||
@@ -131,17 +129,13 @@ export class OpenAIKeyChecker {
|
||||
try {
|
||||
// We only need to check for provisioned models on the initial check.
|
||||
if (isInitialCheck) {
|
||||
const [/* subscription,*/ provisionedModels, livenessTest] =
|
||||
await Promise.all([
|
||||
// this.getSubscription(key),
|
||||
this.getProvisionedModels(key),
|
||||
this.testLiveness(key),
|
||||
]);
|
||||
const [provisionedModels, livenessTest] = await Promise.all([
|
||||
this.getProvisionedModels(key),
|
||||
this.testLiveness(key),
|
||||
this.maybeCreateOrganizationClones(key),
|
||||
]);
|
||||
const updates = {
|
||||
isGpt4: provisionedModels.gpt4,
|
||||
// softLimit: subscription.soft_limit_usd,
|
||||
// hardLimit: subscription.hard_limit_usd,
|
||||
// systemHardLimit: subscription.system_hard_limit_usd,
|
||||
isTrial: livenessTest.rateLimit <= 250,
|
||||
softLimit: 0,
|
||||
hardLimit: 0,
|
||||
@@ -150,18 +144,8 @@ export class OpenAIKeyChecker {
|
||||
this.updateKey(key.hash, updates);
|
||||
} else {
|
||||
// Provisioned models don't change, so we don't need to check them again
|
||||
const [/* subscription, */ _livenessTest] = await Promise.all([
|
||||
// this.getSubscription(key),
|
||||
this.testLiveness(key),
|
||||
]);
|
||||
const updates = {
|
||||
// softLimit: subscription.soft_limit_usd,
|
||||
// hardLimit: subscription.hard_limit_usd,
|
||||
// systemHardLimit: subscription.system_hard_limit_usd,
|
||||
softLimit: 0,
|
||||
hardLimit: 0,
|
||||
systemHardLimit: 0,
|
||||
};
|
||||
const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
|
||||
const updates = { softLimit: 0, hardLimit: 0, systemHardLimit: 0 };
|
||||
this.updateKey(key.hash, updates);
|
||||
}
|
||||
this.log.info({ key: key.hash }, "Key check complete.");
|
||||
@@ -182,7 +166,7 @@ export class OpenAIKeyChecker {
|
||||
private async getProvisionedModels(
|
||||
key: OpenAIKey
|
||||
): Promise<{ turbo: boolean; gpt4: boolean }> {
|
||||
const opts = { headers: { Authorization: `Bearer ${key.key}` } };
|
||||
const opts = { headers: OpenAIKeyChecker.getHeaders(key) };
|
||||
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
|
||||
const models = data.data;
|
||||
const turbo = models.some(({ id }) => id.startsWith("gpt-3.5"));
|
||||
@@ -200,18 +184,27 @@ export class OpenAIKeyChecker {
|
||||
return { turbo, gpt4 };
|
||||
}
|
||||
|
||||
private async getSubscription(key: OpenAIKey) {
|
||||
const { data } = await axios.get<GetSubscriptionResponse>(
|
||||
GET_SUBSCRIPTION_URL,
|
||||
{ headers: { Authorization: `Bearer ${key.key}` } }
|
||||
private async maybeCreateOrganizationClones(key: OpenAIKey) {
|
||||
if (key.organizationId) return; // already cloned
|
||||
const opts = { headers: { Authorization: `Bearer ${key.key}` } };
|
||||
const { data } = await axios.get<GetOrganizationsResponse>(
|
||||
GET_ORGANIZATIONS_URL,
|
||||
opts
|
||||
);
|
||||
// See note above about updating the key's `lastChecked` timestamp.
|
||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
|
||||
this.updateKey(key.hash, {
|
||||
isTrial: !data.has_payment_method,
|
||||
lastChecked: keyFromPool.lastChecked,
|
||||
});
|
||||
return data;
|
||||
const organizations = data.data;
|
||||
if (organizations.length <= 1) return undefined;
|
||||
|
||||
this.log.info(
|
||||
{ parent: key.hash, organizations: organizations.map((org) => org.id) },
|
||||
"Key is associated with multiple organizations; cloning key for each organization."
|
||||
);
|
||||
|
||||
const defaultOrg = organizations.find(({ is_default }) => is_default);
|
||||
const ids = organizations
|
||||
.filter(({ is_default }) => !is_default)
|
||||
.map(({ id }) => id);
|
||||
this.updateKey(key.hash, { organizationId: defaultOrg?.id });
|
||||
this.cloneKey(key.hash, ids);
|
||||
}
|
||||
|
||||
private handleAxiosError(key: OpenAIKey, error: AxiosError) {
|
||||
@@ -318,7 +311,7 @@ export class OpenAIKeyChecker {
|
||||
POST_CHAT_COMPLETIONS_URL,
|
||||
payload,
|
||||
{
|
||||
headers: { Authorization: `Bearer ${key.key}` },
|
||||
headers: OpenAIKeyChecker.getHeaders(key),
|
||||
validateStatus: (status) => status === 400,
|
||||
}
|
||||
);
|
||||
@@ -341,4 +334,12 @@ export class OpenAIKeyChecker {
|
||||
const data = error.response?.data as any;
|
||||
return data?.error?.type;
|
||||
}
|
||||
|
||||
static getHeaders(key: OpenAIKey) {
|
||||
const headers = {
|
||||
Authorization: `Bearer ${key.key}`,
|
||||
...(key.organizationId && { "OpenAI-Organization": key.organizationId }),
|
||||
};
|
||||
return headers;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,12 @@ export const OPENAI_SUPPORTED_MODELS: readonly OpenAIModel[] = [
|
||||
|
||||
export interface OpenAIKey extends Key {
|
||||
readonly service: "openai";
|
||||
/**
|
||||
* Some keys are assigned to multiple organizations, each with their own quota
|
||||
* limits. We clone the key for each organization and track usage/disabled
|
||||
* status separately.
|
||||
*/
|
||||
organizationId?: string;
|
||||
/** Set when key check returns a 401. */
|
||||
isRevoked: boolean;
|
||||
/** Set when key check returns a non-transient 429. */
|
||||
@@ -107,7 +113,9 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
|
||||
public init() {
|
||||
if (config.checkKeys) {
|
||||
this.checker = new OpenAIKeyChecker(this.keys, this.update.bind(this));
|
||||
const cloneFn = this.clone.bind(this);
|
||||
const updateFn = this.update.bind(this);
|
||||
this.checker = new OpenAIKeyChecker(this.keys, cloneFn, updateFn);
|
||||
this.checker.start();
|
||||
}
|
||||
}
|
||||
@@ -191,6 +199,29 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
// this.writeKeyStatus();
|
||||
}
|
||||
|
||||
/** Called by the key checker to create clones of keys for the given orgs. */
|
||||
public clone(keyHash: string, newOrgIds: string[]) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
|
||||
const clones = newOrgIds.map((orgId) => {
|
||||
const clone: OpenAIKey = {
|
||||
...keyFromPool,
|
||||
organizationId: orgId,
|
||||
hash: `oai-${crypto
|
||||
.createHash("sha256")
|
||||
.update(keyFromPool.key + orgId)
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0, // Force re-check in case the org has different models
|
||||
};
|
||||
this.log.info(
|
||||
{ cloneHash: clone.hash, parentHash: keyFromPool.hash, orgId },
|
||||
"Cloned organization key"
|
||||
);
|
||||
return clone;
|
||||
});
|
||||
this.keys.push(...clones);
|
||||
}
|
||||
|
||||
/** Disables a key, or does nothing if the key isn't in this pool. */
|
||||
public disable(key: Key) {
|
||||
const keyFromPool = this.keys.find((k) => k.key === key.key);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Key, keyPool } from "../../../key-management";
|
||||
import { Key, OpenAIKey, keyPool } from "../../../key-management";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
@@ -57,9 +57,16 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
|
||||
"Assigned key to request"
|
||||
);
|
||||
|
||||
// TODO: KeyProvider should assemble all necessary headers
|
||||
if (assignedKey.service === "anthropic") {
|
||||
proxyReq.setHeader("X-API-Key", assignedKey.key);
|
||||
} else {
|
||||
} else if (assignedKey.service === "openai") {
|
||||
const key: OpenAIKey = assignedKey as OpenAIKey;
|
||||
if (key.organizationId) {
|
||||
proxyReq.setHeader("OpenAI-Organization", key.organizationId);
|
||||
}
|
||||
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
|
||||
} else {
|
||||
throw new Error(`Unknown service '${assignedKey.service}'`);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user