adds support for non-Anthropic models to AWS key manager

This commit is contained in:
nai-degen
2024-08-10 16:04:03 -05:00
parent a2d64e281e
commit 750dbee483
4 changed files with 72 additions and 96 deletions
+35 -42
View File
@@ -5,9 +5,21 @@ import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
import { URL } from "url";
import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
import { AwsBedrockModelFamily } from "../../models";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { config } from "../../../config";
const KNOWN_MODEL_IDS = [
"anthropic.claude-v2",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
"mistral.mistral-large-2407-v1:0",
"mistral.mistral-small-2402-v1:0",
];
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const AMZ_HOST =
@@ -47,41 +59,20 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
}
protected async testKeyOrFail(key: AwsBedrockKey) {
// Only check models on startup. For now all models must be available to
// the proxy because we don't route requests to different keys.
let checks: Promise<boolean>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
checks = [
this.invokeModel("anthropic.claude-v2", key),
this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key),
this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key),
this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key),
this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key),
this.invokeModel("mistral.mistral-7b-instruct-v0:2", key),
this.invokeModel("mistral.mixtral-8x7b-instruct-v0:1", key),
this.invokeModel("mistral.mistral-large-2402-v1:0", key),
this.invokeModel("mistral.mistral-large-2407-v1:0", key),
this.invokeModel("mistral.mistral-small-2402-v1:0", key),
];
}
checks.unshift(this.checkLoggingConfiguration(key));
const [_logging, claudeV2, sonnet, haiku, opus, sonnet35] =
await Promise.all(checks);
this.log.debug(
{ key: key.hash, _logging, claudeV2, sonnet, haiku, opus, sonnet35 },
"AWS model tests complete."
);
if (isInitialCheck) {
const families: AwsBedrockModelFamily[] = [];
if (claudeV2 || sonnet || sonnet35 || haiku) families.push("aws-claude");
if (opus) families.push("aws-claude-opus");
const checks = await Promise.all(
KNOWN_MODEL_IDS.map(async (model) => {
const success = await this.invokeModel(model, key);
return { model, success };
})
);
const modelIds = checks
.filter(({ success }) => success)
.map(({ model }) => model);
if (families.length === 0) {
if (modelIds.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
@@ -90,20 +81,19 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
}
this.updateKey(key.hash, {
sonnetEnabled: sonnet,
haikuEnabled: haiku,
sonnet35Enabled: sonnet35,
modelFamilies: families,
modelIds,
modelFamilies: Array.from(
new Set(modelIds.map(getAwsBedrockModelFamily))
),
});
}
this.log.info(
this.log.debug(
{
key: key.hash,
sonnet,
haiku,
families: key.modelFamilies,
logged: key.awsLoggingStatus,
families: key.modelFamilies,
models: key.modelIds,
},
"Checked key."
);
@@ -174,7 +164,10 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
* key has access to the model, false if it does not. Throws an error if the
* key is disabled.
*/
private async invokeModel(model: string, key: AwsBedrockKey) {
private async invokeModel(
model: string,
key: AwsBedrockKey
): Promise<boolean> {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
// This is not a valid invocation payload, but a 400 response indicates that
// the principal at least has permission to invoke the model.
@@ -208,7 +201,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
) {
return false;
}
// ResourceNotFound typically indicates that the tested model cannot be used
// on the configured region for this set of credentials.
if (status === 404) {
+33 -51
View File
@@ -1,5 +1,5 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
@@ -13,10 +13,6 @@ type AwsBedrockKeyUsage = {
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
readonly service: "aws";
readonly modelFamilies: AwsBedrockModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/**
* The confirmed logging status of this key. This is "unknown" until we
* receive a response from the AWS API. Keys which are logged, or not
@@ -24,9 +20,11 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
* set.
*/
awsLoggingStatus: "unknown" | "disabled" | "enabled";
sonnetEnabled: boolean;
haikuEnabled: boolean;
sonnet35Enabled: boolean;
// TODO: replace with list of model ids
// sonnetEnabled: boolean;
// haikuEnabled: boolean;
// sonnet35Enabled: boolean;
modelIds: string[];
}
/**
@@ -76,11 +74,16 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
sonnetEnabled: true,
haikuEnabled: false,
sonnet35Enabled: false,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
// sonnetEnabled: true,
// haikuEnabled: false,
// sonnet35Enabled: false,
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,
["aws-mistral-smallTokens"]: 0,
["aws-mistral-mediumTokens"]: 0,
["aws-mistral-largeTokens"]: 0,
};
this.keys.push(newKey);
}
@@ -99,41 +102,35 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
}
public get(model: string) {
let neededVariantId = model;
// The only AWS model that breaks naming convention is Claude v2. Anthropic
// calls this claude-2 but AWS calls it claude-v2.
if (model.includes("claude-2")) neededVariantId = "claude-v2";
const neededFamily = getAwsBedrockModelFamily(model);
// this is a horrible mess
// each of these should be separate model families, but adding model
// families is not low enough friction for the rate at which aws claude
// model variants are added.
const needsSonnet35 =
model.includes("claude-3-5-sonnet") && neededFamily === "aws-claude";
const needsSonnet =
!needsSonnet35 &&
model.includes("sonnet") &&
neededFamily === "aws-claude";
const needsHaiku = model.includes("haiku") && neededFamily === "aws-claude";
const availableKeys = this.keys.filter((k) => {
const isNotLogged = k.awsLoggingStatus !== "enabled";
// Select keys which
return (
// are enabled
!k.isDisabled &&
(isNotLogged || config.allowAwsLogging) &&
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not
(k.haikuEnabled || !needsHaiku) &&
(k.sonnet35Enabled || !needsSonnet35) &&
k.modelFamilies.includes(neededFamily)
// are not logged, unless policy allows it
(config.allowAwsLogging || k.awsLoggingStatus !== "enabled") &&
// have access to the model family we need
k.modelFamilies.includes(neededFamily) &&
// have access to the specific variant we need
// note that requests can be made for the AWS ID or original vendor ID;
// all vendor IDs are substrings of the AWS ID.
k.modelIds.some((m) => m.includes(neededVariantId))
);
});
this.log.debug(
{
model,
neededFamily,
needsSonnet,
needsHaiku,
needsSonnet35,
availableKeys: availableKeys.length,
requestedModel: model,
selectedVariant: neededVariantId,
selectedFamily: neededFamily,
totalKeys: this.keys.length,
availableKeys: availableKeys.length,
},
"Selecting AWS key"
);
@@ -195,22 +192,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
// TODO: same exact behavior for three providers, should be refactored
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+3 -2
View File
@@ -85,8 +85,9 @@ export function createGenericGetLockoutPeriod<T extends Key>(
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GoogleAIKey } from "././google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { GcpKey } from "./gcp/provider";
export { AzureOpenAIKey } from "./azure/provider";
export { GoogleAIKey } from "././google-ai/provider";
export { MistralAIKey } from "./mistral-ai/provider";
export { OpenAIKey } from "./openai/provider";
+1 -1
View File
@@ -192,7 +192,7 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
// remove vendor and version from AWS model ids
// 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d)?(:\d+)*$/, "$2");
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
if (["claude", "anthropic"].some((x) => model.includes(x))) {
return `aws-${getClaudeModelFamily(deAwsified)}`;