implement catching and auto retrying when a service decides to shit on a connection due to rate limits or other things mid keycheck

This commit is contained in:
based
2024-06-03 21:27:42 +10:00
parent c286038efd
commit 2e235ac44c
+27 -8
View File
@@ -76,7 +76,8 @@ async def validate_openai(key: APIKey, sem):
api_keys.update(cloned_keys)
async def validate_anthropic(key: APIKey, retry_count, sem):
async def validate_anthropic(key: APIKey, sem):
retry_count = 20
async with sem, aiohttp.ClientSession() as session:
IO.conditional_print(f"Checking Anthropic key: {key.api_key}", args.verbose)
key_status = await check_anthropic(key, session)
@@ -177,6 +178,24 @@ def validate_vertexai(key: APIKey):
api_keys.add(key)
async def execute_with_retries(func, key, sem, retries):
attempt = 0
while attempt < retries:
try:
return await func(key, sem)
except (aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
attempt += 1
print(f"Attempt {attempt}/{retries} failed for {key.api_key}: {str(e)}")
if attempt < retries:
print(f"Retrying after 5 seconds...")
await asyncio.sleep(5)
else:
print(f"Failed to validate key {key.api_key} after {retries} attempts.")
except Exception as e:
print(f"Unexpected error occurred for key {key.api_key}: {str(e)}")
break
oai_regex = re.compile('(sk-[A-Za-z0-9]{20}T3BlbkFJ[A-Za-z0-9]{20})')
oai_secondary_regex = re.compile('(sk-proj-[A-Za-z0-9]{20}T3BlbkFJ[A-Za-z0-9]{20})')
anthropic_regex = re.compile(r'sk-ant-api03-[A-Za-z0-9\-_]{93}AA')
@@ -207,25 +226,25 @@ async def validate_keys():
if not match:
continue
key_obj = APIKey(Provider.ANTHROPIC, key)
tasks.append(validate_anthropic(key_obj, 20, concurrent_connections))
tasks.append(execute_with_retries(validate_anthropic, key_obj, concurrent_connections, 5))
elif "AIzaSy" in key[:6]:
match = makersuite_regex.match(key)
if not match:
continue
key_obj = APIKey(Provider.MAKERSUITE, key)
tasks.append(validate_makersuite(key_obj, concurrent_connections))
tasks.append(execute_with_retries(validate_makersuite, key_obj, concurrent_connections, 5))
elif "sk-or-v1-" in key:
match = openrouter_regex.match(key)
if not match:
continue
key_obj = APIKey(Provider.OPENROUTER, key)
tasks.append(validate_openrouter(key_obj, concurrent_connections))
tasks.append(execute_with_retries(validate_openrouter, key_obj, concurrent_connections, 5))
elif "sk-" in key:
match = oai_secondary_regex.match(key) if "-proj-" in key else oai_regex.match(key)
if not match:
continue
key_obj = APIKey(Provider.OPENAI, key)
tasks.append(validate_openai(key_obj, concurrent_connections))
tasks.append(execute_with_retries(validate_openai, key_obj, concurrent_connections, 5))
elif ":" and "AKIA" in key:
match = aws_regex.match(key)
if not match:
@@ -234,7 +253,7 @@ async def validate_keys():
if args.awslegacy:
futures.append(executor.submit(validate_aws, key_obj))
else:
tasks.append(validate_aws_async(key_obj, concurrent_connections))
tasks.append(execute_with_retries(validate_aws_async, key_obj, concurrent_connections, 5))
elif ":" in key and "AKIA" not in key:
match = azure_regex.match(key)
if not match:
@@ -248,10 +267,10 @@ async def validate_keys():
if not match:
continue
key_obj = APIKey(Provider.AI21, key)
tasks.append(validate_ai21_and_mistral(key_obj, concurrent_connections))
tasks.append(execute_with_retries(validate_ai21_and_mistral, key_obj, concurrent_connections, 5))
else:
key_obj = APIKey(Provider.ELEVENLABS, key)
tasks.append(validate_elevenlabs(key_obj, concurrent_connections))
tasks.append(execute_with_retries(validate_elevenlabs, key_obj, concurrent_connections, 5))
results = await asyncio.gather(*tasks)
for result in results:
if result is not None: