async aws checker (run with -awsasync)

will eventually replace the old one if people don't find any issue with it
This commit is contained in:
based
2024-05-12 04:36:36 +10:00
parent 229d528bf4
commit 5ff13d4fe3
2 changed files with 236 additions and 1 deletions
+220
View File
@@ -0,0 +1,220 @@
import APIKey
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
import xml.etree.ElementTree as ET
import asyncio
aws_regions = [
"us-east-1",
"us-west-2",
"ap-southeast-2",
"ap-south-1",
"eu-west-3",
]
async def check_aws(key: APIKey, session):
key.username = await get_username(key, session)
if key.username is None:
return
if not await test_invoke_perms(key, session, "anthropic.claude-3-sonnet-20240229-v1:0"):
return
if not key.bedrock_enabled and not key.useless:
await test_invoke_perms(key, session, "anthropic.claude-v2")
policies = await get_key_policies(key, session)
if not key.useless and key.bedrock_enabled:
await check_logging(key, session)
await retrieve_activated_models(key, session)
for region, models in key.models.items():
key.models[region] = list(set(models))
elif key.useless and policies is not None:
key.useless_reasons.append('Key policies lack Admin or User Creation perms')
return True
async def sign_request(key: APIKey, region, method, url, headers, data, service):
line = key.api_key.split(":")
access_key = line[0]
secret = line[1]
boto3_session = boto3.Session(
aws_access_key_id=access_key,
aws_secret_access_key=secret,
region_name=region
)
credentials = boto3_session.get_credentials()
signer = SigV4Auth(credentials, service, region)
request = AWSRequest(method=method, url=url, headers=headers, data=data)
signer.add_auth(request)
return request.headers, request.data
async def test_invoke_perms(key: APIKey, session, model):
async def check_region(region):
host = f'bedrock-runtime.{region}.amazonaws.com'
url = f'https://{host}/model/{model}/invoke'
signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'bedrock')
async with session.post(url, headers=signed_headers, data=signed_data) as response:
resp = await response.json()
if response.status == 403:
if resp['message'] and 'The request signature we calculated does not match the signature you provided' in resp['message'] or 'The security token included in the request is invalid' in resp['message']:
return False
elif response.status == 400 or response.status == 404:
if resp['message'] and 'Malformed input request' in resp['message']:
if key.region == "":
key.region = region
else:
key.alt_regions.append(region)
key.bedrock_enabled = True
else:
return False
key.useless = False
return True
headers = {
'content-type': 'application/json',
'accept': '*/*',
}
data = {
"prompt": "\n\nHuman:\n\nAssistant:",
"max_tokens_to_sample": -1,
}
tasks = [asyncio.create_task(check_region(region)) for region in aws_regions]
results = await asyncio.gather(*tasks)
return any(results)
async def get_username(key: APIKey, session):
region = 'us-east-1'
url = 'https://sts.amazonaws.com/'
headers = {
'Content-Type': 'application/x-www-form-urlencoded',
}
data = 'Action=GetCallerIdentity&Version=2011-06-15'
signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'sts')
async with session.post(url, headers=signed_headers, data=signed_data) as response:
resp = await response.text()
if 'ErrorResponse' in resp:
return
root = ET.fromstring(resp)
namespace = {'ns': 'https://sts.amazonaws.com/doc/2011-06-15/'}
arn = root.find('.//ns:Arn', namespaces=namespace).text
if not arn:
return "default"
username = arn.split('/')[-1]
if "iam::" in username:
return "default"
return username
async def get_key_policies(key: APIKey, session):
url = f'https://iam.amazonaws.com/?Action=ListAttachedUserPolicies&UserName={key.username}&Version=2010-05-08'
signed_headers, signed_data = await sign_request(key, 'us-east-1', 'GET', url, {}, None, 'iam')
async with session.get(url, headers=signed_headers) as response:
resp = await response.text()
root = ET.fromstring(resp)
namespace = {'iam': 'https://iam.amazonaws.com/doc/2010-05-08/'}
attached_policies = root.findall('.//iam:AttachedPolicies/iam:member', namespaces=namespace)
policy_names = []
if not attached_policies:
if not key.bedrock_enabled:
key.useless = True
key.useless_reasons.append('Failed Policy Fetch')
return
for policy in attached_policies:
policy_name = policy.find('iam:PolicyName', namespaces=namespace).text
policy_names.append(policy_name)
if policy_name == 'AdministratorAccess':
key.admin_priv = True
if 'AWSCompromisedKeyQuarantine' in policy_name:
if not key.bedrock_enabled:
key.useless = True
key.useless_reasons.append('Quarantined Key')
return policy_names
async def check_logging(key: APIKey, session):
region = key.region
host = f'bedrock.{region}.amazonaws.com'
url = f'https://{host}/logging/modelinvocations'
signed_headers, signed_data = await sign_request(key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock')
async with session.get(url, headers=signed_headers) as response:
if response.status == 200:
logging_config = await response.json()
if 'loggingConfig' in logging_config and logging_config['loggingConfig'] is not None and 'textDataDeliveryEnabled' in logging_config['loggingConfig']:
key.logged = logging_config['loggingConfig']['textDataDeliveryEnabled']
else:
key.logged = False
else:
key.logged = False
async def retrieve_activated_models(key: APIKey, session):
tasks = []
for region in aws_regions:
if region not in key.models:
key.models[region] = []
task = handle_region(key, session, region)
tasks.append(task)
await asyncio.gather(*tasks)
async def handle_region(key: APIKey, session, region):
listed_models = await retrieve_models(key, session, region)
tasks = [invoke_model(key, session, region, model) for model in listed_models]
await asyncio.gather(*tasks)
async def invoke_model(key: APIKey, session, region, model):
model_id, model_name = model
data = {
"prompt": "\n\nHuman:\n\nAssistant:",
"max_tokens_to_sample": -1,
}
host = f'bedrock-runtime.{region}.amazonaws.com'
url = f'https://{host}/model/{model_id}/invoke'
signed_headers, signed_data = await sign_request(key, region, 'POST', url, {'content-type': 'application/json', 'accept': '*/*'}, data, 'bedrock')
async with session.post(url, headers=signed_headers, data=signed_data) as response:
if response.status == 400:
resp = await response.json()
if resp['message'] and 'Malformed input request' in resp['message']:
key.models[region].append(model_name)
async def retrieve_models(key: APIKey, session, region):
host = f'bedrock.{region}.amazonaws.com'
url = f'https://{host}/foundation-models'
signed_headers, signed_data = await sign_request(key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock')
async with session.get(url, headers=signed_headers) as response:
if response.status == 200:
response = await response.json()
models = response["modelSummaries"]
model_providers = ["Meta", "Anthropic", "Mistral AI"]
model_info = []
for model in models:
provider_name = model["providerName"]
model_id = model["modelId"]
model_name = model["modelName"]
if provider_name in model_providers or (provider_name == "Cohere" and ("Command R+" in model_name or "Command R" in model_name)):
parts = model_id.split(":")
if len(parts) <= 2:
model_info.append((model_id, model_name))
return model_info
else:
return []
+16 -1
View File
@@ -19,6 +19,7 @@ import argparse
import os.path
import asyncio
import aiohttp
import AWSAsync
api_keys = set()
@@ -30,6 +31,7 @@ def parse_args():
parser.add_argument('-file', '--file', action='store', dest='file', help='read slop from a provided filename')
parser.add_argument('-verbose', '--verbose', action='store_true', help='watch as your slop is checked real time')
parser.add_argument('-awsmodels', '--awsmodels', action='store_true', help='output activated aws models for a key (warning: slow)')
parser.add_argument('-awsasync', '--awsasync', action='store_true', help='use the AWS REST API for checking keys instead of boto3 (way faster but not as well tested, -awsmodels autoapplies here due to the speedup)')
return parser.parse_args()
@@ -148,6 +150,16 @@ def validate_aws(key: APIKey):
api_keys.add(key)
async def validate_aws_async(key: APIKey, sem):
async with sem, aiohttp.ClientSession() as session:
IO.conditional_print(f"Checking AWS key asynchronously: {key.api_key}", args.verbose)
if await AWSAsync.check_aws(key, session) is None:
IO.conditional_print(f"Invalid AWS key: {key.api_key}", args.verbose)
return
IO.conditional_print(f"AWS key '{key.api_key}' is valid", args.verbose)
api_keys.add(key)
def validate_azure(key: APIKey):
IO.conditional_print(f"Checking Azure key: {key.api_key}", args.verbose)
if check_azure(key) is None:
@@ -220,7 +232,10 @@ async def validate_keys():
if not match:
continue
key_obj = APIKey(Provider.AWS, key)
futures.append(executor.submit(validate_aws, key_obj))
if args.awsasync:
tasks.append(validate_aws_async(key_obj, concurrent_connections))
else:
futures.append(executor.submit(validate_aws, key_obj))
elif ":" in key and "AKIA" not in key:
match = azure_regex.match(key)
if not match: