Added Google Vertex AI support

This commit is contained in:
based
2023-12-11 03:44:49 +10:00
parent 782864d6cc
commit 266bc976ea
5 changed files with 83 additions and 8 deletions
+5
View File
@@ -36,6 +36,10 @@ class APIKey:
self.deployments = []
self.unfiltered = False
elif provider == Provider.VERTEXAI:
self.project_id = ""
pass
class Provider(Enum):
OPENAI = 1,
@@ -44,3 +48,4 @@ class Provider(Enum):
PALM = 4
AWS = 5
AZURE = 6
VERTEXAI = 7
+1 -1
View File
@@ -9,8 +9,8 @@ Currently supports and validates keys for the services below, and checks for the
- Google PaLM
- AWS - (Admin status, auto-fetch the region, logging status, username, bedrock status)
- Azure - (Auto-fetch all deployments, auto-fetch best deployment/model, filter status)
- Google Cloud Vertex AI - (Requires a key file since oauth tokens expire hourly. Good luck scraping for those.)
May add support for Google Vertex in the future.
# Usage:
`pip install -r requirements.txt`
+47
View File
@@ -0,0 +1,47 @@
import APIKey
import json
import vertexai
from google.cloud import aiplatform
from google.oauth2 import service_account
import google.api_core.exceptions
from vertexai.language_models import TextGenerationModel
location = 'us-central1' # location doesn't matter unlike azure/aws
def check_vertexai(key: APIKey):
try:
credentials = service_account.Credentials.from_service_account_file(key.api_key)
with open(key.api_key, 'r') as file:
data = json.load(file)
if data.get('type') != 'service_account':
return
project_id = data.get('project_id')
if not project_id:
return
key.project_id = project_id
aiplatform.init(credentials=credentials, location=location, project=key.project_id)
test_model_response(key, credentials)
except google.api_core.exceptions.InvalidArgument:
key.api_key = f'"{key.api_key}"'
return True # if we get to the stage where google yells at us for a bad parameter, 99% sure the key works.
except Exception as e:
return
def test_model_response(key: APIKey, credentials):
vertexai.init(project=key.project_id, location=location, credentials=credentials)
model = TextGenerationModel.from_pretrained("text-bison@002")
model.predict("bweh", **{"temperature": 0.1, "max_output_tokens": 0})
def pretty_print_vertexai_keys(keys):
print('-' * 90)
print(f'Validated {len(keys)} Google Vertex AI keys:')
for key in keys:
print(f'{key.api_key} | {key.project_id}')
print(f'\n--- Total Valid Google Vertex AI Keys: {len(keys)} ---\n')
+30 -7
View File
@@ -6,6 +6,7 @@ from AI21 import check_ai21, pretty_print_ai21_keys
from Palm import check_palm, pretty_print_palm_keys
from AWS import check_aws, pretty_print_aws_keys
from Azure import check_azure, pretty_print_azure_keys
from VertexAI import check_vertexai, pretty_print_vertexai_keys
from APIKey import APIKey, Provider
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -13,11 +14,12 @@ import sys
from datetime import datetime
import re
import argparse
import os.path
api_keys = set()
print("Enter API keys (OpenAI/Anthropic/AI21/PaLM/AWS/Azure) one per line. Press Enter on a blank line to start validation")
print("Expected format for AWS keys is accesskey:secret, for Azure keys it's resourcegroup:apikey")
print('Enter API keys (OpenAI/Anthropic/AI21/PaLM/AWS/Azure) one per line. Press Enter on a blank line to start validation')
print('Expected format for AWS keys is accesskey:secret, for Azure keys it\'s resourcegroup:apikey. For Vertex AI keys the absolute path to the secrets key file is expected in quotes. "/path/to/secrets.json"')
inputted_keys = set()
while True:
@@ -86,12 +88,19 @@ def validate_azure(key: APIKey):
api_keys.add(key)
def validate_vertexai(key: APIKey):
if check_vertexai(key) is None:
return
api_keys.add(key)
oai_regex = re.compile('(sk-[A-Za-z0-9]{20}T3BlbkFJ[A-Za-z0-9]{20})')
anthropic_regex = re.compile(r'sk-ant-api03-[A-Za-z0-9\-_]{93}AA')
ai21_regex = re.compile('[A-Za-z0-9]{32}')
palm_regex = re.compile(r'AIzaSy[A-Za-z0-9\-_]{33}')
aws_regex = re.compile(r'^(AKIA[0-9A-Z]{16}):([A-Za-z0-9+/]{40})$')
azure_regex = re.compile(r'^(.+):([a-z0-9]{32})$')
# vertex_regex = re.compile(r'^(.+):(ya29.[A-Za-z0-9\-_]{469})$') regex for the oauth tokens, useless since they expire hourly
executor = ThreadPoolExecutor(max_workers=100)
@@ -99,7 +108,13 @@ executor = ThreadPoolExecutor(max_workers=100)
def validate_keys():
futures = []
for key in inputted_keys:
if "ant-api03" in key:
if '"' in key[:1]:
key = key.strip('"')
if not os.path.isfile(key):
continue
key_obj = APIKey(Provider.VERTEXAI, key)
futures.append(executor.submit(validate_vertexai, key_obj))
elif "ant-api03" in key:
match = anthropic_regex.match(key)
if not match:
continue
@@ -141,14 +156,16 @@ def validate_keys():
futures.clear()
def get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys):
def get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys, valid_vertexai_keys):
valid_oai_keys_set = set([key.api_key for key in valid_oai_keys])
valid_anthropic_keys_set = set([key.api_key for key in valid_anthropic_keys])
valid_ai21_keys_set = set([key.api_key for key in valid_ai21_keys])
valid_palm_keys_set = set([key.api_key for key in valid_palm_keys])
valid_aws_keys_set = set([key.api_key for key in valid_aws_keys])
valid_azure_keys_set = set([key.api_key for key in valid_azure_keys])
invalid_keys = inputted_keys - valid_oai_keys_set - valid_anthropic_keys_set - valid_ai21_keys_set - valid_palm_keys_set - valid_aws_keys_set - valid_azure_keys_set
valid_vertexai_keys_set = set([key.api_key for key in valid_vertexai_keys])
invalid_keys = inputted_keys - valid_oai_keys_set - valid_anthropic_keys_set - valid_ai21_keys_set - valid_palm_keys_set - valid_aws_keys_set - valid_azure_keys_set - valid_vertexai_keys_set
if len(invalid_keys) < 1:
return
print('\nInvalid Keys:')
@@ -166,6 +183,8 @@ def output_keys():
valid_palm_keys = []
valid_aws_keys = []
valid_azure_keys = []
valid_vertexai_keys = []
for key in api_keys:
if key.provider == Provider.OPENAI:
valid_oai_keys.append(key)
@@ -179,6 +198,8 @@ def output_keys():
valid_aws_keys.append(key)
elif key.provider == Provider.AZURE:
valid_azure_keys.append(key)
elif key.provider == Provider.VERTEXAI:
valid_vertexai_keys.append(key)
if should_write:
output_filename = "key_snapshots.txt"
sys.stdout = Logger(output_filename)
@@ -188,7 +209,7 @@ def output_keys():
print(f"Key snapshot from {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("#" * 90)
print(f'\n--- Checked {len(inputted_keys)} keys | {len(inputted_keys) - len(api_keys)} were invalid ---')
get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys)
get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys, valid_vertexai_keys)
print()
if valid_oai_keys:
pretty_print_oai_keys(valid_oai_keys)
@@ -202,8 +223,10 @@ def output_keys():
pretty_print_aws_keys(valid_aws_keys)
if valid_azure_keys:
pretty_print_azure_keys(valid_azure_keys)
if valid_vertexai_keys:
pretty_print_vertexai_keys(valid_vertexai_keys)
else:
# ai21 keys aren't supported in proxies so no point outputting them, filtered azure keys should be excluded.
# ai21 and vertex keys aren't supported in proxies so no point outputting them, filtered azure keys should be excluded.
print("OPENAI_KEY=" + ','.join(key.api_key for key in valid_oai_keys))
print("ANTHROPIC_KEY=" + ','.join(key.api_key for key in valid_anthropic_keys))
print("AWS_CREDENTIALS=" + ','.join(f"{key.api_key}:{key.region}" for key in valid_aws_keys))
BIN
View File
Binary file not shown.