improves reliability of inference profile detection for AWS keychecker

This commit is contained in:
nai-degen
2024-09-07 17:36:29 -05:00
parent 96fe974ad0
commit ac92a19946
2 changed files with 37 additions and 1 deletions
+32
View File
@@ -25,6 +25,7 @@ const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
];
const KEY_CHECK_BATCH_SIZE = 2; // AWS checker needs to do lots of concurrent requests so should lower the batch size
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const AMZ_HOST =
@@ -77,6 +78,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
keyCheckBatchSize: KEY_CHECK_BATCH_SIZE,
updateKey,
});
}
@@ -212,6 +214,36 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
key: AwsBedrockKey
): Promise<boolean> {
if (model.includes("claude")) {
// If inference profiles are available, try testing model with them.
// If they are not available or the invocation fails with the inference
// profile, fall back to regular model ID.
const { region } = AwsKeyChecker.getCredentialsFromKey(key);
const continent = region.split("-")[0];
const profile = key.inferenceProfileIds.find(
(id) => `${continent}.${model}` === id
);
if (profile) {
this.log.debug(
{ key: key.hash, model, profile },
"Testing model via inference profile."
);
let result: boolean;
try {
result = await this.testClaudeModel(key, profile);
} catch (e) {
this.log.error(
{ key: key.hash, model, profile, error: e.message },
"Error testing model with inference profile; trying model ID directly."
);
result = false;
}
// If the profile worked, we'll return success. Caller will add the
// model (not the profile) to the list of enabled models, but the
// profile will be used when the key is used for inference.
if (result) return true;
}
return this.testClaudeModel(key, model);
} else if (model.includes("mistral")) {
return this.testMistralModel(key, model);
@@ -7,6 +7,7 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
keyCheckBatchSize?: number;
recurringChecksEnabled?: boolean;
updateKey: (hash: string, props: Partial<TKey>) => void;
};
@@ -22,6 +23,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* than this.
*/
protected readonly keyCheckPeriod: number;
/** Maximum number of keys to check simultaneously. */
protected readonly keyCheckBatchSize: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
@@ -33,6 +36,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
this.keyCheckPeriod = opts.keyCheckPeriod;
this.minCheckInterval = opts.minCheckInterval;
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12;
this.updateKey = opts.updateKey;
this.service = opts.service;
this.log = logger.child({ module: "key-checker", service: opts.service });
@@ -78,7 +82,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
if (numUnchecked > 0) {
const keycheckBatch = uncheckedKeys.slice(0, 12);
const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize);
this.timeout = setTimeout(async () => {
try {