Adds support for non-Anthropic models to the AWS key manager
This commit is contained in:
@@ -5,9 +5,21 @@ import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
|
||||
import { URL } from "url";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
|
||||
import { AwsBedrockModelFamily } from "../../models";
|
||||
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
|
||||
import { config } from "../../../config";
|
||||
|
||||
// AWS Bedrock model IDs that the key checker probes with `invokeModel` when a
// key is checked for the first time. The IDs that respond successfully are
// stored on the key as `modelIds`, and its model families are derived from
// them via `getAwsBedrockModelFamily`.
const KNOWN_MODEL_IDS = [
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"mistral.mistral-large-2407-v1:0",
|
||||
"mistral.mistral-small-2402-v1:0",
|
||||
];
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||
const AMZ_HOST =
|
||||
@@ -47,41 +59,20 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
}
|
||||
|
||||
protected async testKeyOrFail(key: AwsBedrockKey) {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
let checks: Promise<boolean>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
checks = [
|
||||
this.invokeModel("anthropic.claude-v2", key),
|
||||
this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-5-sonnet-20240620-v1:0", key),
|
||||
this.invokeModel("mistral.mistral-7b-instruct-v0:2", key),
|
||||
this.invokeModel("mistral.mixtral-8x7b-instruct-v0:1", key),
|
||||
this.invokeModel("mistral.mistral-large-2402-v1:0", key),
|
||||
this.invokeModel("mistral.mistral-large-2407-v1:0", key),
|
||||
this.invokeModel("mistral.mistral-small-2402-v1:0", key),
|
||||
];
|
||||
}
|
||||
|
||||
checks.unshift(this.checkLoggingConfiguration(key));
|
||||
|
||||
const [_logging, claudeV2, sonnet, haiku, opus, sonnet35] =
|
||||
await Promise.all(checks);
|
||||
|
||||
this.log.debug(
|
||||
{ key: key.hash, _logging, claudeV2, sonnet, haiku, opus, sonnet35 },
|
||||
"AWS model tests complete."
|
||||
);
|
||||
|
||||
if (isInitialCheck) {
|
||||
const families: AwsBedrockModelFamily[] = [];
|
||||
if (claudeV2 || sonnet || sonnet35 || haiku) families.push("aws-claude");
|
||||
if (opus) families.push("aws-claude-opus");
|
||||
const checks = await Promise.all(
|
||||
KNOWN_MODEL_IDS.map(async (model) => {
|
||||
const success = await this.invokeModel(model, key);
|
||||
return { model, success };
|
||||
})
|
||||
);
|
||||
const modelIds = checks
|
||||
.filter(({ success }) => success)
|
||||
.map(({ model }) => model);
|
||||
|
||||
if (families.length === 0) {
|
||||
if (modelIds.length === 0) {
|
||||
this.log.warn(
|
||||
{ key: key.hash },
|
||||
"Key does not have access to any models; disabling."
|
||||
@@ -90,20 +81,19 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, {
|
||||
sonnetEnabled: sonnet,
|
||||
haikuEnabled: haiku,
|
||||
sonnet35Enabled: sonnet35,
|
||||
modelFamilies: families,
|
||||
modelIds,
|
||||
modelFamilies: Array.from(
|
||||
new Set(modelIds.map(getAwsBedrockModelFamily))
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
this.log.info(
|
||||
this.log.debug(
|
||||
{
|
||||
key: key.hash,
|
||||
sonnet,
|
||||
haiku,
|
||||
families: key.modelFamilies,
|
||||
logged: key.awsLoggingStatus,
|
||||
families: key.modelFamilies,
|
||||
models: key.modelIds,
|
||||
},
|
||||
"Checked key."
|
||||
);
|
||||
@@ -174,7 +164,10 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
* key has access to the model, false if it does not. Throws an error if the
|
||||
* key is disabled.
|
||||
*/
|
||||
private async invokeModel(model: string, key: AwsBedrockKey) {
|
||||
private async invokeModel(
|
||||
model: string,
|
||||
key: AwsBedrockKey
|
||||
): Promise<boolean> {
|
||||
const creds = AwsKeyChecker.getCredentialsFromKey(key);
|
||||
// This is not a valid invocation payload, but a 400 response indicates that
|
||||
// the principal at least has permission to invoke the model.
|
||||
@@ -208,7 +201,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// ResourceNotFound typically indicates that the tested model cannot be used
|
||||
// on the configured region for this set of credentials.
|
||||
if (status === 404) {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
|
||||
@@ -13,10 +13,6 @@ type AwsBedrockKeyUsage = {
|
||||
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
||||
readonly service: "aws";
|
||||
readonly modelFamilies: AwsBedrockModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
/**
|
||||
* The confirmed logging status of this key. This is "unknown" until we
|
||||
* receive a response from the AWS API. Keys which are logged, or not
|
||||
@@ -24,9 +20,11 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
||||
* set.
|
||||
*/
|
||||
awsLoggingStatus: "unknown" | "disabled" | "enabled";
|
||||
sonnetEnabled: boolean;
|
||||
haikuEnabled: boolean;
|
||||
sonnet35Enabled: boolean;
|
||||
// TODO: replace with list of model ids
|
||||
// sonnetEnabled: boolean;
|
||||
// haikuEnabled: boolean;
|
||||
// sonnet35Enabled: boolean;
|
||||
modelIds: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -76,11 +74,16 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
sonnetEnabled: true,
|
||||
haikuEnabled: false,
|
||||
sonnet35Enabled: false,
|
||||
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||
// sonnetEnabled: true,
|
||||
// haikuEnabled: false,
|
||||
// sonnet35Enabled: false,
|
||||
["aws-claudeTokens"]: 0,
|
||||
["aws-claude-opusTokens"]: 0,
|
||||
["aws-mistral-tinyTokens"]: 0,
|
||||
["aws-mistral-smallTokens"]: 0,
|
||||
["aws-mistral-mediumTokens"]: 0,
|
||||
["aws-mistral-largeTokens"]: 0,
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
@@ -99,41 +102,35 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
}
|
||||
|
||||
public get(model: string) {
|
||||
let neededVariantId = model;
|
||||
// The only AWS model that breaks naming convention is Claude v2. Anthropic
|
||||
// calls this claude-2 but AWS calls it claude-v2.
|
||||
if (model.includes("claude-2")) neededVariantId = "claude-v2";
|
||||
const neededFamily = getAwsBedrockModelFamily(model);
|
||||
|
||||
// this is a horrible mess
|
||||
// each of these should be separate model families, but adding model
|
||||
// families is not low enough friction for the rate at which aws claude
|
||||
// model variants are added.
|
||||
const needsSonnet35 =
|
||||
model.includes("claude-3-5-sonnet") && neededFamily === "aws-claude";
|
||||
const needsSonnet =
|
||||
!needsSonnet35 &&
|
||||
model.includes("sonnet") &&
|
||||
neededFamily === "aws-claude";
|
||||
const needsHaiku = model.includes("haiku") && neededFamily === "aws-claude";
|
||||
|
||||
const availableKeys = this.keys.filter((k) => {
|
||||
const isNotLogged = k.awsLoggingStatus !== "enabled";
|
||||
// Select keys which
|
||||
return (
|
||||
// are enabled
|
||||
!k.isDisabled &&
|
||||
(isNotLogged || config.allowAwsLogging) &&
|
||||
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not
|
||||
(k.haikuEnabled || !needsHaiku) &&
|
||||
(k.sonnet35Enabled || !needsSonnet35) &&
|
||||
k.modelFamilies.includes(neededFamily)
|
||||
// are not logged, unless policy allows it
|
||||
(config.allowAwsLogging || k.awsLoggingStatus !== "enabled") &&
|
||||
// have access to the model family we need
|
||||
k.modelFamilies.includes(neededFamily) &&
|
||||
// have access to the specific variant we need
|
||||
// note that requests can be made for the AWS ID or original vendor ID;
|
||||
// all vendor IDs are substrings of the AWS ID.
|
||||
k.modelIds.some((m) => m.includes(neededVariantId))
|
||||
);
|
||||
});
|
||||
|
||||
this.log.debug(
|
||||
{
|
||||
model,
|
||||
neededFamily,
|
||||
needsSonnet,
|
||||
needsHaiku,
|
||||
needsSonnet35,
|
||||
availableKeys: availableKeys.length,
|
||||
requestedModel: model,
|
||||
selectedVariant: neededVariantId,
|
||||
selectedFamily: neededFamily,
|
||||
totalKeys: this.keys.length,
|
||||
availableKeys: availableKeys.length,
|
||||
},
|
||||
"Selecting AWS key"
|
||||
);
|
||||
@@ -195,22 +192,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
/**
 * Computes the lockout period in milliseconds, considering only keys that
 * are not disabled.
 *
 * @returns 0 when there are no enabled keys or when at least one enabled key
 * is not currently rate limited; otherwise the smallest remaining time
 * (`rateLimitedUntil - now`) among the enabled keys.
 */
public getLockoutPeriod() {
|
||||
// TODO: same exact behavior for three providers, should be refactored
|
||||
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
// A key counts as rate limited while `now` is still before its
|
||||
// `rateLimitedUntil` timestamp.
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return time until the first key is ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
|
||||
@@ -85,8 +85,9 @@ export function createGenericGetLockoutPeriod<T extends Key>(
|
||||
|
||||
export const keyPool = new KeyPool();
|
||||
export { AnthropicKey } from "./anthropic/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
export { GoogleAIKey } from "././google-ai/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
export { GcpKey } from "./gcp/provider";
|
||||
export { AzureOpenAIKey } from "./azure/provider";
|
||||
export { GoogleAIKey } from "././google-ai/provider";
|
||||
export { MistralAIKey } from "./mistral-ai/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
|
||||
@@ -192,7 +192,7 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
|
||||
export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
|
||||
// remove vendor and version from AWS model ids
|
||||
// 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
|
||||
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d)?(:\d+)*$/, "$2");
|
||||
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
|
||||
|
||||
if (["claude", "anthropic"].some((x) => model.includes(x))) {
|
||||
return `aws-${getClaudeModelFamily(deAwsified)}`;
|
||||
|
||||
Reference in New Issue
Block a user