From 5ff13d4fe3f02ceab9315b50bb8e934dae0478e4 Mon Sep 17 00:00:00 2001 From: based Date: Sun, 12 May 2024 04:36:36 +1000 Subject: [PATCH] async aws checker (run with -awsasync) will eventually replace the old one if people don't find any issue with it --- AWSAsync.py | 220 ++++++++++++++++++++++++++++++++++++++++++++++++++++ main.py | 17 +++- 2 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 AWSAsync.py diff --git a/AWSAsync.py b/AWSAsync.py new file mode 100644 index 0000000..ad53a0e --- /dev/null +++ b/AWSAsync.py @@ -0,0 +1,220 @@ +import APIKey +import boto3 +from botocore.auth import SigV4Auth +from botocore.awsrequest import AWSRequest +import xml.etree.ElementTree as ET +import asyncio + + +aws_regions = [ + "us-east-1", + "us-west-2", + "ap-southeast-2", + "ap-south-1", + "eu-west-3", +] + + +async def check_aws(key: APIKey, session): + key.username = await get_username(key, session) + if key.username is None: + return + + if not await test_invoke_perms(key, session, "anthropic.claude-3-sonnet-20240229-v1:0"): + return + + if not key.bedrock_enabled and not key.useless: + await test_invoke_perms(key, session, "anthropic.claude-v2") + + policies = await get_key_policies(key, session) + if not key.useless and key.bedrock_enabled: + await check_logging(key, session) + await retrieve_activated_models(key, session) + for region, models in key.models.items(): + key.models[region] = list(set(models)) + elif key.useless and policies is not None: + key.useless_reasons.append('Key policies lack Admin or User Creation perms') + return True + + +async def sign_request(key: APIKey, region, method, url, headers, data, service): + line = key.api_key.split(":") + access_key = line[0] + secret = line[1] + boto3_session = boto3.Session( + aws_access_key_id=access_key, + aws_secret_access_key=secret, + region_name=region + ) + credentials = boto3_session.get_credentials() + signer = SigV4Auth(credentials, service, region) + request = AWSRequest(method=method, url=url, headers=headers, data=data) + signer.add_auth(request) + return request.headers, request.data + + +async def test_invoke_perms(key: APIKey, session, model): + async def check_region(region): + host = f'bedrock-runtime.{region}.amazonaws.com' + url = f'https://{host}/model/{model}/invoke' + + signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'bedrock') + async with session.post(url, headers=signed_headers, data=signed_data) as response: + resp = await response.json() + if response.status == 403: + if resp['message'] and 'The request signature we calculated does not match the signature you provided' in resp['message'] or 'The security token included in the request is invalid' in resp['message']: + return False + elif response.status == 400 or response.status == 404: + if resp['message'] and 'Malformed input request' in resp['message']: + if key.region == "": + key.region = region + else: + key.alt_regions.append(region) + key.bedrock_enabled = True + else: + return False + key.useless = False + return True + + headers = { + 'content-type': 'application/json', + 'accept': '*/*', + } + data = { + "prompt": "\n\nHuman:\n\nAssistant:", + "max_tokens_to_sample": -1, + } + + tasks = [asyncio.create_task(check_region(region)) for region in aws_regions] + results = await asyncio.gather(*tasks) + return any(results) + + +async def get_username(key: APIKey, session): + region = 'us-east-1' + url = 'https://sts.amazonaws.com/' + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + } + data = 'Action=GetCallerIdentity&Version=2011-06-15' + signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'sts') + + async with session.post(url, headers=signed_headers, data=signed_data) as response: + resp = await response.text() + if 'ErrorResponse' in resp: + return + + root = ET.fromstring(resp) + namespace = {'ns': 'https://sts.amazonaws.com/doc/2011-06-15/'} + + arn = root.find('.//ns:Arn', namespaces=namespace).text + if not arn: + return "default" + username = arn.split('/')[-1] + if "iam::" in username: + return "default" + return username + + +async def get_key_policies(key: APIKey, session): + url = f'https://iam.amazonaws.com/?Action=ListAttachedUserPolicies&UserName={key.username}&Version=2010-05-08' + signed_headers, signed_data = await sign_request(key, 'us-east-1', 'GET', url, {}, None, 'iam') + async with session.get(url, headers=signed_headers) as response: + resp = await response.text() + root = ET.fromstring(resp) + namespace = {'iam': 'https://iam.amazonaws.com/doc/2010-05-08/'} + + attached_policies = root.findall('.//iam:AttachedPolicies/iam:member', namespaces=namespace) + policy_names = [] + if not attached_policies: + if not key.bedrock_enabled: + key.useless = True + key.useless_reasons.append('Failed Policy Fetch') + return + + for policy in attached_policies: + policy_name = policy.find('iam:PolicyName', namespaces=namespace).text + policy_names.append(policy_name) + + if policy_name == 'AdministratorAccess': + key.admin_priv = True + if 'AWSCompromisedKeyQuarantine' in policy_name: + if not key.bedrock_enabled: + key.useless = True + key.useless_reasons.append('Quarantined Key') + return policy_names + + +async def check_logging(key: APIKey, session): + region = key.region + host = f'bedrock.{region}.amazonaws.com' + url = f'https://{host}/logging/modelinvocations' + signed_headers, signed_data = await sign_request(key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock') + async with session.get(url, headers=signed_headers) as response: + if response.status == 200: + logging_config = await response.json() + if 'loggingConfig' in logging_config and logging_config['loggingConfig'] is not None and 'textDataDeliveryEnabled' in logging_config['loggingConfig']: + key.logged = logging_config['loggingConfig']['textDataDeliveryEnabled'] + else: + key.logged = False + else: + key.logged = False + + +async def retrieve_activated_models(key: APIKey, session): + tasks = [] + for region in aws_regions: + if region not in key.models: + key.models[region] = [] + task = handle_region(key, session, region) + tasks.append(task) + await asyncio.gather(*tasks) + + +async def handle_region(key: APIKey, session, region): + listed_models = await retrieve_models(key, session, region) + tasks = [invoke_model(key, session, region, model) for model in listed_models] + await asyncio.gather(*tasks) + + +async def invoke_model(key: APIKey, session, region, model): + model_id, model_name = model + data = { + "prompt": "\n\nHuman:\n\nAssistant:", + "max_tokens_to_sample": -1, + } + host = f'bedrock-runtime.{region}.amazonaws.com' + url = f'https://{host}/model/{model_id}/invoke' + signed_headers, signed_data = await sign_request(key, region, 'POST', url, {'content-type': 'application/json', 'accept': '*/*'}, data, 'bedrock') + async with session.post(url, headers=signed_headers, data=signed_data) as response: + if response.status == 400: + resp = await response.json() + if resp['message'] and 'Malformed input request' in resp['message']: + key.models[region].append(model_name) + + +async def retrieve_models(key: APIKey, session, region): + host = f'bedrock.{region}.amazonaws.com' + url = f'https://{host}/foundation-models' + signed_headers, signed_data = await sign_request(key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock') + async with session.get(url, headers=signed_headers) as response: + if response.status == 200: + response = await response.json() + models = response["modelSummaries"] + + model_providers = ["Meta", "Anthropic", "Mistral AI"] + model_info = [] + + for model in models: + provider_name = model["providerName"] + model_id = model["modelId"] + model_name = model["modelName"] + + if provider_name in model_providers or (provider_name == "Cohere" and ("Command R+" in model_name or "Command R" in model_name)): + parts = model_id.split(":") + if len(parts) <= 2: + model_info.append((model_id, model_name)) + + return model_info + else: + return [] diff --git a/main.py b/main.py index 0e1d9c3..23e0c02 100644 --- a/main.py +++ b/main.py @@ -19,6 +19,7 @@ import argparse import os.path import asyncio import aiohttp +import AWSAsync api_keys = set() @@ -30,6 +31,7 @@ def parse_args(): parser.add_argument('-file', '--file', action='store', dest='file', help='read slop from a provided filename') parser.add_argument('-verbose', '--verbose', action='store_true', help='watch as your slop is checked real time') parser.add_argument('-awsmodels', '--awsmodels', action='store_true', help='output activated aws models for a key (warning: slow)') + parser.add_argument('-awsasync', '--awsasync', action='store_true', help='use the AWS REST API for checking keys instead of boto3 (way faster but not as well tested, -awsmodels autoapplies here due to the speedup)') return parser.parse_args() @@ -148,6 +150,16 @@ def validate_aws(key: APIKey): api_keys.add(key) +async def validate_aws_async(key: APIKey, sem): + async with sem, aiohttp.ClientSession() as session: + IO.conditional_print(f"Checking AWS key asynchronously: {key.api_key}", args.verbose) + if await AWSAsync.check_aws(key, session) is None: + IO.conditional_print(f"Invalid AWS key: {key.api_key}", args.verbose) + return + IO.conditional_print(f"AWS key '{key.api_key}' is valid", args.verbose) + api_keys.add(key) + + def validate_azure(key: APIKey): IO.conditional_print(f"Checking Azure key: {key.api_key}", args.verbose) if check_azure(key) is None: @@ -220,7 +232,10 @@ async def validate_keys(): if not match: continue key_obj = APIKey(Provider.AWS, key) - futures.append(executor.submit(validate_aws, key_obj)) + if args.awsasync: + tasks.append(validate_aws_async(key_obj, concurrent_connections)) + else: + futures.append(executor.submit(validate_aws, key_obj)) elif ":" in key and "AKIA" not in key: match = azure_regex.match(key) if not match: