diff --git a/src/shared/key-management/google-ai/checker.ts b/src/shared/key-management/google-ai/checker.ts index 9f7b609..3f442cf 100644 --- a/src/shared/key-management/google-ai/checker.ts +++ b/src/shared/key-management/google-ai/checker.ts @@ -12,6 +12,9 @@ const LIST_MODELS_URL = "https://generativelanguage.googleapis.com/v1beta/models"; const GENERATE_CONTENT_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?key=%KEY%"; +const PRO_MODEL_ID = "gemini-2.5-pro-preview-05-06"; +const GENERATE_PRO_CONTENT_URL = + `https://generativelanguage.googleapis.com/v1beta/models/${PRO_MODEL_ID}:generateContent?key=%KEY%`; type ListModelsResponse = { models: { @@ -46,12 +49,27 @@ export class GoogleAIKeyChecker extends KeyCheckerBase { protected async testKeyOrFail(key: GoogleAIKey) { const provisionedModels = await this.getProvisionedModels(key); + + // Always test flash model access (existing behaviour) await this.testGenerateContent(key); - const updates = { modelFamilies: provisionedModels }; + // If key claims to support gemini-pro, perform a second layer test with a pro model. + let effectiveFamilies = [...provisionedModels]; + if (effectiveFamilies.includes("gemini-pro")) { + const proAccessible = await this.canAccessModel( + key, + GENERATE_PRO_CONTENT_URL + ); + if (!proAccessible) { + // Remove pro access if invocation fails + effectiveFamilies = effectiveFamilies.filter((f) => f !== "gemini-pro"); + } + } + + const updates = { modelFamilies: effectiveFamilies }; this.updateKey(key.hash, updates); this.log.info( - { key: key.hash, models: key.modelFamilies, ids: key.modelIds.length }, + { key: key.hash, models: effectiveFamilies, ids: key.modelIds?.length }, "Checked key." ); } @@ -94,6 +112,28 @@ export class GoogleAIKeyChecker extends KeyCheckerBase { ); } + private async canAccessModel( + key: GoogleAIKey, + modelGenerateUrlTemplate: string + ): Promise { + const payload = { + contents: [{ parts: { text: "hi" }, role: "user" }], + tools: [], + safetySettings: [], + generationConfig: { maxOutputTokens: 1 }, + }; + try { + await axios.post( + modelGenerateUrlTemplate.replace("%KEY%", key.key), + payload, + { validateStatus: (status) => status === 200 } + ); + return true; + } catch { + return false; + } + } + protected handleAxiosError(key: GoogleAIKey, error: AxiosError): void { if (error.response && GoogleAIKeyChecker.errorIsGoogleAIError(error)) { const httpStatus = error.response.status; @@ -103,12 +143,12 @@ export class GoogleAIKeyChecker extends KeyCheckerBase { case 400: { const keyDeadMsgs = [ /please enable billing/i, - /API key not valid/i, - /API key expired/i, - /pass a valid API/i, + /api key not valid/i, + /api key expired/i, + /pass a valid api/i, ]; const text = JSON.stringify(error.response.data.error); - if (text.match(keyDeadMsgs.join("|"))) { + if (keyDeadMsgs.some((r) => r.test(text))) { this.log.warn( { key: key.hash, error: text }, "Key check returned a non-transient 400 error. Disabling key." @@ -133,7 +173,7 @@ export class GoogleAIKeyChecker extends KeyCheckerBase { /GenerateContentRequestsPerMinutePerProjectPerRegion/i, /"quota_limit_value":"0"/i, ]; - if (keyDeadMsgs.some(r => r.test(text))) { + if (keyDeadMsgs.some((r) => r.test(text))) { this.log.warn( { key: key.hash, error: text }, "Key check returned a non-transient 429 error. Disabling key."