# Mirror of https://github.com/cunnymessiah/keychecker.git
# Synced 2026-05-10 18:39:04 -07:00
# 225 lines, 8.5 KiB, Python
import APIKey
|
|
import boto3
|
|
from botocore.auth import SigV4Auth
|
|
from botocore.awsrequest import AWSRequest
|
|
import xml.etree.ElementTree as ET
|
|
import asyncio
|
|
|
|
|
|
# AWS regions that are probed for Bedrock access. The permission probe and
# the model discovery both iterate over this list.
aws_regions = [
    "us-east-1",
    "us-west-2",
    "ap-southeast-2",
    "ap-south-1",
    "eu-west-3",
]
|
|
|
|
|
|
async def check_aws(key: APIKey, session):
    """Run the complete AWS validation sequence for a single key.

    The steps are: resolve the caller's user name, probe Bedrock invoke
    permissions, fetch the attached IAM policies and, for keys that are
    usable and Bedrock-enabled, record the logging status and the models
    that can be invoked in each region. All findings are stored on ``key``.

    Returns ``None`` when the user name cannot be resolved or the first
    invoke probe fails in every region, and ``True`` otherwise.
    """
    key.username = await get_username(key, session)
    if key.username is None:
        return

    first_probe_ok = await test_invoke_perms(
        key, session, "anthropic.claude-3-sonnet-20240229-v1:0"
    )
    if not first_probe_ok:
        return

    # Second probe with another model when the first one neither confirmed
    # Bedrock access nor left the key marked as useless.
    if not (key.bedrock_enabled or key.useless):
        await test_invoke_perms(key, session, "anthropic.claude-v2")

    policies = await get_key_policies(key, session)

    if not key.useless and key.bedrock_enabled:
        await check_logging(key, session)
        await retrieve_activated_models(key, session)
        # Collapse duplicate model names collected per region.
        for region in key.models:
            key.models[region] = list(set(key.models[region]))
    elif key.useless and policies is not None:
        key.useless_reasons.append('Key policies lack Admin or User Creation perms')

    return True
|
|
|
|
|
|
async def sign_request(key: APIKey, region, method, url, headers, data, service):
    """Sign an HTTP request with AWS Signature Version 4.

    ``key.api_key`` is split on ``":"``; the first piece is used as the
    access key id and the second as the secret access key.

    Returns a ``(headers, data)`` tuple taken from the signed request, ready
    to be handed to the HTTP client.
    """
    pieces = key.api_key.split(":")
    access_key, secret = pieces[0], pieces[1]

    credentials = boto3.Session(
        aws_access_key_id=access_key,
        aws_secret_access_key=secret,
        region_name=region,
    ).get_credentials()

    signed = AWSRequest(method=method, url=url, headers=headers, data=data)
    SigV4Auth(credentials, service, region).add_auth(signed)
    return signed.headers, signed.data
|
|
|
|
|
|
async def test_invoke_perms(key: APIKey, session, model):
    """Probe every region in ``aws_regions`` for permission to invoke ``model``.

    Each probe posts a body with ``max_tokens_to_sample`` set to -1. A
    400/404 reply whose message contains 'Malformed input request' is taken
    as confirmation that Bedrock is usable in that region: the first such
    region is stored in ``key.region``, later ones in ``key.alt_regions``,
    and ``key.bedrock_enabled`` is set.

    Returns ``True`` if at least one region probe succeeded.
    """
    probe_headers = {
        'content-type': 'application/json',
        'accept': '*/*',
    }
    probe_body = {
        "prompt": "\n\nHuman:\n\nAssistant:",
        "max_tokens_to_sample": -1,
    }
    # 403 messages that mean the credentials themselves were rejected.
    rejected_credentials = (
        'The request signature we calculated does not match the signature you provided',
        'The security token included in the request is invalid',
    )

    async def check_region(region):
        url = f'https://bedrock-runtime.{region}.amazonaws.com/model/{model}/invoke'
        signed_headers, signed_data = await sign_request(
            key, region, 'POST', url, probe_headers, probe_body, 'bedrock'
        )
        async with session.post(url, headers=signed_headers, data=signed_data) as response:
            payload = await response.json()
            message = payload.get('message')
            status = response.status

            # Only 400, 403 and 404 replies can count as a successful probe.
            if status not in (400, 403, 404):
                return False

            if status == 403:
                if message and any(text in message for text in rejected_credentials):
                    return False
            elif message and 'Malformed input request' in message:
                if key.region == "":
                    key.region = region
                else:
                    key.alt_regions.append(region)
                key.bedrock_enabled = True

            key.useless = False
            return True

    outcomes = await asyncio.gather(
        *(asyncio.create_task(check_region(region)) for region in aws_regions)
    )
    return any(outcomes)
|
|
|
|
|
|
async def get_username(key: APIKey, session):
    """Resolve the user name behind ``key`` via STS ``GetCallerIdentity``.

    Returns:
        ``None`` when the reply contains ``ErrorResponse``;
        ``"default"`` when the reply carries no usable ARN, or when the last
        ``/``-separated ARN segment still contains ``"iam::"`` (the ARN has
        no ``/`` at all);
        otherwise the last ``/``-separated segment of the ARN.
    """
    region = 'us-east-1'
    url = 'https://sts.amazonaws.com/'
    headers = {
        'Content-Type': 'application/x-www-form-urlencoded',
    }
    data = 'Action=GetCallerIdentity&Version=2011-06-15'
    signed_headers, signed_data = await sign_request(key, region, 'POST', url, headers, data, 'sts')

    async with session.post(url, headers=signed_headers, data=signed_data) as response:
        resp = await response.text()

    if 'ErrorResponse' in resp:
        return None

    root = ET.fromstring(resp)
    namespace = {'ns': 'https://sts.amazonaws.com/doc/2011-06-15/'}

    # findtext() yields the default when the element is missing and '' when
    # it is empty, so an absent <Arn> no longer raises AttributeError.
    arn = root.findtext('.//ns:Arn', default='', namespaces=namespace)
    if not arn:
        return "default"

    username = arn.split('/')[-1]
    if "iam::" in username:
        return "default"
    return username
|
|
|
|
|
|
async def get_key_policies(key: APIKey, session):
    """Fetch the managed policies attached to ``key.username`` from IAM.

    Side effects on ``key``:
        * ``admin_priv`` is set when 'AdministratorAccess' is attached.
        * When Bedrock is not enabled, the key is marked useless with reason
          'Failed Policy Fetch' (no policies found in the reply) or
          'Quarantined Key' (a policy name containing
          'AWSCompromisedKeyQuarantine').

    Returns the list of attached policy names, or ``None`` when the reply
    contains no attached policies.
    """
    # Imported here so the block does not depend on a file-level import.
    from urllib.parse import quote

    # User names can contain characters such as '+', '=' and ',' that are
    # significant inside a query string, so the value is percent-encoded
    # before it is placed in the URL that gets signed and sent.
    user_name = quote(key.username, safe='')
    url = f'https://iam.amazonaws.com/?Action=ListAttachedUserPolicies&UserName={user_name}&Version=2010-05-08'
    signed_headers, signed_data = await sign_request(key, 'us-east-1', 'GET', url, {}, None, 'iam')

    async with session.get(url, headers=signed_headers) as response:
        resp = await response.text()

    root = ET.fromstring(resp)
    namespace = {'iam': 'https://iam.amazonaws.com/doc/2010-05-08/'}
    attached_policies = root.findall('.//iam:AttachedPolicies/iam:member', namespaces=namespace)

    if not attached_policies:
        if not key.bedrock_enabled:
            key.useless = True
            key.useless_reasons.append('Failed Policy Fetch')
        return None

    policy_names = []
    for policy in attached_policies:
        policy_name = policy.findtext('iam:PolicyName', default='', namespaces=namespace)
        if not policy_name:
            # Entry without a usable name: nothing to record or inspect.
            continue
        policy_names.append(policy_name)

        if policy_name == 'AdministratorAccess':
            key.admin_priv = True
        if 'AWSCompromisedKeyQuarantine' in policy_name and not key.bedrock_enabled:
            key.useless = True
            key.useless_reasons.append('Quarantined Key')

    return policy_names
|
|
|
|
|
|
async def check_logging(key: APIKey, session):
    """Record on ``key.logged`` the model-invocation logging setting.

    Queries the Bedrock ``/logging/modelinvocations`` endpoint in
    ``key.region``. ``key.logged`` receives the reply's
    ``loggingConfig.textDataDeliveryEnabled`` value when the request returns
    200 and that field is present; in every other case it is set to ``False``.
    """
    region = key.region
    url = f'https://bedrock.{region}.amazonaws.com/logging/modelinvocations'
    signed_headers, signed_data = await sign_request(
        key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock'
    )

    async with session.get(url, headers=signed_headers) as response:
        logged = False
        if response.status == 200:
            body = await response.json()
            config = body['loggingConfig'] if 'loggingConfig' in body else None
            if config is not None and 'textDataDeliveryEnabled' in config:
                logged = config['textDataDeliveryEnabled']
        key.logged = logged
|
|
|
|
|
|
async def retrieve_activated_models(key: APIKey, session):
    """Collect the invocable models for every region in ``aws_regions``.

    Makes sure ``key.models`` has a list for each region, then processes all
    regions concurrently through ``handle_region``.
    """
    for region in aws_regions:
        key.models.setdefault(region, [])

    await asyncio.gather(
        *(handle_region(key, session, region) for region in aws_regions)
    )
|
|
|
|
|
|
async def handle_region(key: APIKey, session, region):
    """List the models of ``region`` and probe each of them concurrently."""
    candidates = await retrieve_models(key, session, region)
    await asyncio.gather(
        *(invoke_model(key, session, region, candidate) for candidate in candidates)
    )
|
|
|
|
|
|
async def invoke_model(key: APIKey, session, region, model):
    """Probe a single model in ``region`` and record the outcome on ``key``.

    ``model`` is a ``(model_id, model_name)`` pair. Only 400 replies are
    inspected:
        * a message containing 'Malformed input request' adds ``model_name``
          to ``key.models[region]``;
        * a message containing 'Operation not' marks the key useless with
          reason 'Quarantined Key'.
    """
    model_id, model_name = model
    data = {
        "prompt": "\n\nHuman:\n\nAssistant:",
        "max_tokens_to_sample": -1,
    }
    url = f'https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/invoke'
    signed_headers, signed_data = await sign_request(
        key, region, 'POST', url,
        {'content-type': 'application/json', 'accept': '*/*'},
        data, 'bedrock',
    )

    async with session.post(url, headers=signed_headers, data=signed_data) as response:
        if response.status != 400:
            return

        resp = await response.json()
        # .get() tolerates a reply without a 'message' field instead of
        # raising KeyError.
        message = resp.get('message') or ''

        if 'Malformed input request' in message:
            key.models[region].append(model_name)
        elif 'Operation not' in message:
            key.useless = True
            # This runs once per model and region; keep the reason unique.
            if 'Quarantined Key' not in key.useless_reasons:
                key.useless_reasons.append('Quarantined Key')
|
|
|
|
|
|
async def retrieve_models(key: APIKey, session, region):
    """List the foundation models of interest that ``region`` offers.

    Keeps models from Meta, Anthropic and Mistral AI, plus Cohere models
    whose name contains 'Command R', and only those whose id has at most one
    ``:``.

    Returns a list of ``(model_id, model_name)`` tuples. The list is empty
    when the request does not return 200 or the reply has no
    ``modelSummaries``.
    """
    url = f'https://bedrock.{region}.amazonaws.com/foundation-models'
    signed_headers, signed_data = await sign_request(
        key, region, 'GET', url, {'accept': 'application/json'}, {}, 'bedrock'
    )

    async with session.get(url, headers=signed_headers) as response:
        if response.status != 200:
            return []
        # Separate name so the response object is not shadowed by its body.
        payload = await response.json()

    model_providers = ("Meta", "Anthropic", "Mistral AI")
    model_info = []

    for model in payload.get("modelSummaries", []):
        provider_name = model["providerName"]
        model_id = model["modelId"]
        model_name = model["modelName"]

        # "Command R+" contains "Command R", so one substring test covers both.
        wanted = provider_name in model_providers or (
            provider_name == "Cohere" and "Command R" in model_name
        )
        # Ids with more than one ':' are skipped.
        # NOTE(review): presumably these are variant ids of a base model; confirm.
        if wanted and model_id.count(":") <= 1:
            model_info.append((model_id, model_name))

    return model_info
|