diff --git a/src/shared/key-management/aws/checker.ts b/src/shared/key-management/aws/checker.ts index 09e209a..6541286 100644 --- a/src/shared/key-management/aws/checker.ts +++ b/src/shared/key-management/aws/checker.ts @@ -25,6 +25,7 @@ const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [ ["mistral.mistral-small-2402-v1:0"], // Seems to return 400 ]; +const KEY_CHECK_BATCH_SIZE = 2; // AWS checker needs to do lots of concurrent requests so should lower the batch size const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes const AMZ_HOST = @@ -77,6 +78,7 @@ export class AwsKeyChecker extends KeyCheckerBase { service: "aws", keyCheckPeriod: KEY_CHECK_PERIOD, minCheckInterval: MIN_CHECK_INTERVAL, + keyCheckBatchSize: KEY_CHECK_BATCH_SIZE, updateKey, }); } @@ -212,6 +214,36 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference- key: AwsBedrockKey ): Promise { if (model.includes("claude")) { + // If inference profiles are available, try testing model with them. + // If they are not available or the invocation fails with the inference + // profile, fall back to regular model ID. + const { region } = AwsKeyChecker.getCredentialsFromKey(key); + const continent = region.split("-")[0]; + const profile = key.inferenceProfileIds.find( + (id) => `${continent}.${model}` === id + ); + + if (profile) { + this.log.debug( + { key: key.hash, model, profile }, + "Testing model via inference profile." + ); + let result: boolean; + try { + result = await this.testClaudeModel(key, profile); + } catch (e) { + this.log.error( + { key: key.hash, model, profile, error: e.message }, + "Error testing model with inference profile; trying model ID directly." + ); + result = false; + } + + // If the profile worked, we'll return success. Caller will add the + // model (not the profile) to the list of enabled models, but the + // profile will be used when the key is used for inference. + if (result) return true; + } return this.testClaudeModel(key, model); } else if (model.includes("mistral")) { return this.testMistralModel(key, model); diff --git a/src/shared/key-management/key-checker-base.ts b/src/shared/key-management/key-checker-base.ts index e740bce..5d388e7 100644 --- a/src/shared/key-management/key-checker-base.ts +++ b/src/shared/key-management/key-checker-base.ts @@ -7,6 +7,7 @@ type KeyCheckerOptions = { service: string; keyCheckPeriod: number; minCheckInterval: number; + keyCheckBatchSize?: number; recurringChecksEnabled?: boolean; updateKey: (hash: string, props: Partial) => void; }; @@ -22,6 +23,8 @@ export abstract class KeyCheckerBase { * than this. */ protected readonly keyCheckPeriod: number; + /** Maximum number of keys to check simultaneously. */ + protected readonly keyCheckBatchSize: number; protected readonly updateKey: (hash: string, props: Partial) => void; protected readonly keys: TKey[] = []; protected log: pino.Logger; @@ -33,6 +36,7 @@ export abstract class KeyCheckerBase { this.keyCheckPeriod = opts.keyCheckPeriod; this.minCheckInterval = opts.minCheckInterval; this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true; + this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12; this.updateKey = opts.updateKey; this.service = opts.service; this.log = logger.child({ module: "key-checker", service: opts.service }); @@ -78,7 +82,7 @@ export abstract class KeyCheckerBase { checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check..."); if (numUnchecked > 0) { - const keycheckBatch = uncheckedKeys.slice(0, 12); + const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize); this.timeout = setTimeout(async () => { try {