mirror of
https://github.com/cunnymessiah/keychecker.git
synced 2026-05-10 18:39:04 -07:00
async aws checker (run with -awsasync)
will eventually replace the old one if people don't find any issue with it
This commit is contained in:
+220
@@ -0,0 +1,220 @@
|
|||||||
|
import APIKey
|
||||||
|
import boto3
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
aws_regions = [
|
||||||
|
"us-east-1",
|
||||||
|
"us-west-2",
|
||||||
|
"ap-southeast-2",
|
||||||
|
"ap-south-1",
|
||||||
|
"eu-west-3",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def check_aws(key: APIKey, session):
|
||||||
|
key.username = await get_username(key, session)
|
||||||
|
if key.username is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not await test_invoke_perms(key, session, "anthropic.claude-3-sonnet-20240229-v1:0"):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not key.bedrock_enabled and not key.useless:
|
||||||
|
await test_invoke_perms(key, session, "anthropic.claude-v2")
|
||||||
|
|
||||||
|
policies = await get_key_policies(key, session)
|
||||||
|
if not key.useless and key.bedrock_enabled:
|
||||||
|
await check_logging(key, session)
|
||||||
|
await retrieve_activated_models(key, session)
|
||||||
|
for region, models in key.models.items():
|
||||||
|
key.models[region] = list(set(models))
|
||||||
|
elif key.useless and policies is not None:
|
||||||
|
key.useless_reasons.append('Key policies lack Admin or User Creation perms')
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def sign_request(key: APIKey, region, method, url, headers, data, service):
|
||||||
|
line = key.api_key.split(":")
|
||||||
|
access_key = line[0]
|
||||||
|
secret = line[1]
|
||||||
|
boto3_session = boto3.Session(
|
||||||
|
aws_access_key_id=access_key,
|
||||||
|
aws_secret_access_key=secret,
|
||||||
|
region_name=region
|
||||||
|
)
|
||||||
|
credentials = boto3_session.get_credentials()
|
||||||
|
signer = SigV4Auth(credentials, service, region)
|
||||||
|
request = AWSRequest(method=method, url=url, headers=headers, data=data)
|
||||||
|
signer.add_auth(request)
|
||||||
|
return request.headers, request.data
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invoke_perms(key: APIKey, session, model):
|
||||||
|
async def check_region(region):
|
||||||
|
host = f'bedrock-runtime.{region}.amazonaws.com'
|
||||||
|
url = f'https://{host}/model/{model}/invoke'
|
||||||
|
|
||||||
|
signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'bedrock')
|
||||||
|
async with session.post(url, headers=signed_headers, data=signed_data) as response:
|
||||||
|
resp = await response.json()
|
||||||
|
if response.status == 403:
|
||||||
|
if resp['message'] and 'The request signature we calculated does not match the signature you provided' in resp['message'] or 'The security token included in the request is invalid' in resp['message']:
|
||||||
|
return False
|
||||||
|
elif response.status == 400 or response.status == 404:
|
||||||
|
if resp['message'] and 'Malformed input request' in resp['message']:
|
||||||
|
if key.region == "":
|
||||||
|
key.region = region
|
||||||
|
else:
|
||||||
|
key.alt_regions.append(region)
|
||||||
|
key.bedrock_enabled = True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
key.useless = False
|
||||||
|
return True
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'content-type': 'application/json',
|
||||||
|
'accept': '*/*',
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"prompt": "\n\nHuman:\n\nAssistant:",
|
||||||
|
"max_tokens_to_sample": -1,
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks = [asyncio.create_task(check_region(region)) for region in aws_regions]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
return any(results)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_username(key: APIKey, session):
|
||||||
|
region = 'us-east-1'
|
||||||
|
url = 'https://sts.amazonaws.com/'
|
||||||
|
headers = {
|
||||||
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
|
}
|
||||||
|
data = 'Action=GetCallerIdentity&Version=2011-06-15'
|
||||||
|
signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'sts')
|
||||||
|
|
||||||
|
async with session.post(url, headers=signed_headers, data=signed_data) as response:
|
||||||
|
resp = await response.text()
|
||||||
|
if 'ErrorResponse' in resp:
|
||||||
|
return
|
||||||
|
|
||||||
|
root = ET.fromstring(resp)
|
||||||
|
namespace = {'ns': 'https://sts.amazonaws.com/doc/2011-06-15/'}
|
||||||
|
|
||||||
|
arn = root.find('.//ns:Arn', namespaces=namespace).text
|
||||||
|
if not arn:
|
||||||
|
return "default"
|
||||||
|
username = arn.split('/')[-1]
|
||||||
|
if "iam::" in username:
|
||||||
|
return "default"
|
||||||
|
return username
|
||||||
|
|
||||||
|
|
||||||
|
async def get_key_policies(key: APIKey, session):
|
||||||
|
url = f'https://iam.amazonaws.com/?Action=ListAttachedUserPolicies&UserName={key.username}&Version=2010-05-08'
|
||||||
|
signed_headers, signed_data = await sign_request(key, 'us-east-1', 'GET', url, {}, None, 'iam')
|
||||||
|
async with session.get(url, headers=signed_headers) as response:
|
||||||
|
resp = await response.text()
|
||||||
|
root = ET.fromstring(resp)
|
||||||
|
namespace = {'iam': 'https://iam.amazonaws.com/doc/2010-05-08/'}
|
||||||
|
|
||||||
|
attached_policies = root.findall('.//iam:AttachedPolicies/iam:member', namespaces=namespace)
|
||||||
|
policy_names = []
|
||||||
|
if not attached_policies:
|
||||||
|
if not key.bedrock_enabled:
|
||||||
|
key.useless = True
|
||||||
|
key.useless_reasons.append('Failed Policy Fetch')
|
||||||
|
return
|
||||||
|
|
||||||
|
for policy in attached_policies:
|
||||||
|
policy_name = policy.find('iam:PolicyName', namespaces=namespace).text
|
||||||
|
policy_names.append(policy_name)
|
||||||
|
|
||||||
|
if policy_name == 'AdministratorAccess':
|
||||||
|
key.admin_priv = True
|
||||||
|
if 'AWSCompromisedKeyQuarantine' in policy_name:
|
||||||
|
if not key.bedrock_enabled:
|
||||||
|
key.useless = True
|
||||||
|
key.useless_reasons.append('Quarantined Key')
|
||||||
|
return policy_names
|
||||||
|
|
||||||
|
|
||||||
|
async def check_logging(key: APIKey, session):
|
||||||
|
region = key.region
|
||||||
|
host = f'bedrock.{region}.amazonaws.com'
|
||||||
|
url = f'https://{host}/logging/modelinvocations'
|
||||||
|
signed_headers, signed_data = await sign_request(key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock')
|
||||||
|
async with session.get(url, headers=signed_headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
logging_config = await response.json()
|
||||||
|
if 'loggingConfig' in logging_config and logging_config['loggingConfig'] is not None and 'textDataDeliveryEnabled' in logging_config['loggingConfig']:
|
||||||
|
key.logged = logging_config['loggingConfig']['textDataDeliveryEnabled']
|
||||||
|
else:
|
||||||
|
key.logged = False
|
||||||
|
else:
|
||||||
|
key.logged = False
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_activated_models(key: APIKey, session):
|
||||||
|
tasks = []
|
||||||
|
for region in aws_regions:
|
||||||
|
if region not in key.models:
|
||||||
|
key.models[region] = []
|
||||||
|
task = handle_region(key, session, region)
|
||||||
|
tasks.append(task)
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_region(key: APIKey, session, region):
|
||||||
|
listed_models = await retrieve_models(key, session, region)
|
||||||
|
tasks = [invoke_model(key, session, region, model) for model in listed_models]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|
||||||
|
async def invoke_model(key: APIKey, session, region, model):
|
||||||
|
model_id, model_name = model
|
||||||
|
data = {
|
||||||
|
"prompt": "\n\nHuman:\n\nAssistant:",
|
||||||
|
"max_tokens_to_sample": -1,
|
||||||
|
}
|
||||||
|
host = f'bedrock-runtime.{region}.amazonaws.com'
|
||||||
|
url = f'https://{host}/model/{model_id}/invoke'
|
||||||
|
signed_headers, signed_data = await sign_request(key, region, 'POST', url, {'content-type': 'application/json', 'accept': '*/*'}, data, 'bedrock')
|
||||||
|
async with session.post(url, headers=signed_headers, data=signed_data) as response:
|
||||||
|
if response.status == 400:
|
||||||
|
resp = await response.json()
|
||||||
|
if resp['message'] and 'Malformed input request' in resp['message']:
|
||||||
|
key.models[region].append(model_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def retrieve_models(key: APIKey, session, region):
|
||||||
|
host = f'bedrock.{region}.amazonaws.com'
|
||||||
|
url = f'https://{host}/foundation-models'
|
||||||
|
signed_headers, signed_data = await sign_request(key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock')
|
||||||
|
async with session.get(url, headers=signed_headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
response = await response.json()
|
||||||
|
models = response["modelSummaries"]
|
||||||
|
|
||||||
|
model_providers = ["Meta", "Anthropic", "Mistral AI"]
|
||||||
|
model_info = []
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
provider_name = model["providerName"]
|
||||||
|
model_id = model["modelId"]
|
||||||
|
model_name = model["modelName"]
|
||||||
|
|
||||||
|
if provider_name in model_providers or (provider_name == "Cohere" and ("Command R+" in model_name or "Command R" in model_name)):
|
||||||
|
parts = model_id.split(":")
|
||||||
|
if len(parts) <= 2:
|
||||||
|
model_info.append((model_id, model_name))
|
||||||
|
|
||||||
|
return model_info
|
||||||
|
else:
|
||||||
|
return []
|
||||||
@@ -19,6 +19,7 @@ import argparse
|
|||||||
import os.path
|
import os.path
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import AWSAsync
|
||||||
|
|
||||||
api_keys = set()
|
api_keys = set()
|
||||||
|
|
||||||
@@ -30,6 +31,7 @@ def parse_args():
|
|||||||
parser.add_argument('-file', '--file', action='store', dest='file', help='read slop from a provided filename')
|
parser.add_argument('-file', '--file', action='store', dest='file', help='read slop from a provided filename')
|
||||||
parser.add_argument('-verbose', '--verbose', action='store_true', help='watch as your slop is checked real time')
|
parser.add_argument('-verbose', '--verbose', action='store_true', help='watch as your slop is checked real time')
|
||||||
parser.add_argument('-awsmodels', '--awsmodels', action='store_true', help='output activated aws models for a key (warning: slow)')
|
parser.add_argument('-awsmodels', '--awsmodels', action='store_true', help='output activated aws models for a key (warning: slow)')
|
||||||
|
parser.add_argument('-awsasync', '--awsasync', action='store_true', help='use the AWS REST API for checking keys instead of boto3 (way faster but not as well tested, -awsmodels autoapplies here due to the speedup)')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -148,6 +150,16 @@ def validate_aws(key: APIKey):
|
|||||||
api_keys.add(key)
|
api_keys.add(key)
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_aws_async(key: APIKey, sem):
|
||||||
|
async with sem, aiohttp.ClientSession() as session:
|
||||||
|
IO.conditional_print(f"Checking AWS key asynchronously: {key.api_key}", args.verbose)
|
||||||
|
if await AWSAsync.check_aws(key, session) is None:
|
||||||
|
IO.conditional_print(f"Invalid AWS key: {key.api_key}", args.verbose)
|
||||||
|
return
|
||||||
|
IO.conditional_print(f"AWS key '{key.api_key}' is valid", args.verbose)
|
||||||
|
api_keys.add(key)
|
||||||
|
|
||||||
|
|
||||||
def validate_azure(key: APIKey):
|
def validate_azure(key: APIKey):
|
||||||
IO.conditional_print(f"Checking Azure key: {key.api_key}", args.verbose)
|
IO.conditional_print(f"Checking Azure key: {key.api_key}", args.verbose)
|
||||||
if check_azure(key) is None:
|
if check_azure(key) is None:
|
||||||
@@ -220,7 +232,10 @@ async def validate_keys():
|
|||||||
if not match:
|
if not match:
|
||||||
continue
|
continue
|
||||||
key_obj = APIKey(Provider.AWS, key)
|
key_obj = APIKey(Provider.AWS, key)
|
||||||
futures.append(executor.submit(validate_aws, key_obj))
|
if args.awsasync:
|
||||||
|
tasks.append(validate_aws_async(key_obj, concurrent_connections))
|
||||||
|
else:
|
||||||
|
futures.append(executor.submit(validate_aws, key_obj))
|
||||||
elif ":" in key and "AKIA" not in key:
|
elif ":" in key and "AKIA" not in key:
|
||||||
match = azure_regex.match(key)
|
match = azure_regex.match(key)
|
||||||
if not match:
|
if not match:
|
||||||
|
|||||||
Reference in New Issue
Block a user