From 266bc976ea2d69411b80da8b837086d44ea582d3 Mon Sep 17 00:00:00 2001 From: based Date: Mon, 11 Dec 2023 03:44:49 +1000 Subject: [PATCH] Added Google Vertex AI support --- APIKey.py | 5 +++++ README.md | 2 +- VertexAI.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ main.py | 37 ++++++++++++++++++++++++++++++------- requirements.txt | Bin 398 -> 1422 bytes 5 files changed, 83 insertions(+), 8 deletions(-) create mode 100644 VertexAI.py diff --git a/APIKey.py b/APIKey.py index 7cd99b1..6113b16 100644 --- a/APIKey.py +++ b/APIKey.py @@ -36,6 +36,10 @@ class APIKey: self.deployments = [] self.unfiltered = False + elif provider == Provider.VERTEXAI: + self.project_id = "" + pass + class Provider(Enum): OPENAI = 1, @@ -44,3 +48,4 @@ class Provider(Enum): PALM = 4 AWS = 5 AZURE = 6 + VERTEXAI = 7 diff --git a/README.md b/README.md index 7fd5534..109d177 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ Currently supports and validates keys for the services below, and checks for the - Google PaLM - AWS - (Admin status, auto-fetch the region, logging status, username, bedrock status) - Azure - (Auto-fetch all deployments, auto-fetch best deployment/model, filter status) +- Google Cloud Vertex AI - (Requires a key file since oauth tokens expire hourly. Good luck scraping for those.) -May add support for Google Vertex in the future. # Usage: `pip install -r requirements.txt` diff --git a/VertexAI.py b/VertexAI.py new file mode 100644 index 0000000..c5e07c3 --- /dev/null +++ b/VertexAI.py @@ -0,0 +1,47 @@ +import APIKey +import json +import vertexai +from google.cloud import aiplatform +from google.oauth2 import service_account +import google.api_core.exceptions +from vertexai.language_models import TextGenerationModel + + +location = 'us-central1' # location doesn't matter unlike azure/aws + + +def check_vertexai(key: APIKey): + try: + credentials = service_account.Credentials.from_service_account_file(key.api_key) + with open(key.api_key, 'r') as file: + data = json.load(file) + if data.get('type') != 'service_account': + return + + project_id = data.get('project_id') + if not project_id: + return + key.project_id = project_id + + aiplatform.init(credentials=credentials, location=location, project=key.project_id) + test_model_response(key, credentials) + + except google.api_core.exceptions.InvalidArgument: + key.api_key = f'"{key.api_key}"' + return True # if we get to the stage where google yells at us for a bad parameter, 99% sure the key works. + except Exception as e: + return + + +def test_model_response(key: APIKey, credentials): + vertexai.init(project=key.project_id, location=location, credentials=credentials) + model = TextGenerationModel.from_pretrained("text-bison@002") + model.predict("bweh", **{"temperature": 0.1, "max_output_tokens": 0}) + + +def pretty_print_vertexai_keys(keys): + print('-' * 90) + print(f'Validated {len(keys)} Google Vertex AI keys:') + for key in keys: + print(f'{key.api_key} | {key.project_id}') + print(f'\n--- Total Valid Google Vertex AI Keys: {len(keys)} ---\n') diff --git a/main.py b/main.py index adcd9d0..38b1702 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ from AI21 import check_ai21, pretty_print_ai21_keys from Palm import check_palm, pretty_print_palm_keys from AWS import check_aws, pretty_print_aws_keys from Azure import check_azure, pretty_print_azure_keys +from VertexAI import check_vertexai, pretty_print_vertexai_keys from APIKey import APIKey, Provider from concurrent.futures import ThreadPoolExecutor, as_completed @@ -13,11 +14,12 @@ import sys from datetime import datetime import re import argparse +import os.path api_keys = set() -print("Enter API keys (OpenAI/Anthropic/AI21/PaLM/AWS/Azure) one per line. Press Enter on a blank line to start validation") -print("Expected format for AWS keys is accesskey:secret, for Azure keys it's resourcegroup:apikey") +print('Enter API keys (OpenAI/Anthropic/AI21/PaLM/AWS/Azure) one per line. Press Enter on a blank line to start validation') +print('Expected format for AWS keys is accesskey:secret, for Azure keys it\'s resourcegroup:apikey. For Vertex AI keys the absolute path to the secrets key file is expected in quotes. "/path/to/secrets.json"') inputted_keys = set() while True: @@ -86,12 +88,19 @@ def validate_azure(key: APIKey): api_keys.add(key) +def validate_vertexai(key: APIKey): + if check_vertexai(key) is None: + return + api_keys.add(key) + + oai_regex = re.compile('(sk-[A-Za-z0-9]{20}T3BlbkFJ[A-Za-z0-9]{20})') anthropic_regex = re.compile(r'sk-ant-api03-[A-Za-z0-9\-_]{93}AA') ai21_regex = re.compile('[A-Za-z0-9]{32}') palm_regex = re.compile(r'AIzaSy[A-Za-z0-9\-_]{33}') aws_regex = re.compile(r'^(AKIA[0-9A-Z]{16}):([A-Za-z0-9+/]{40})$') azure_regex = re.compile(r'^(.+):([a-z0-9]{32})$') +# vertex_regex = re.compile(r'^(.+):(ya29.[A-Za-z0-9\-_]{469})$') regex for the oauth tokens, useless since they expire hourly executor = ThreadPoolExecutor(max_workers=100) @@ -99,7 +108,13 @@ executor = ThreadPoolExecutor(max_workers=100) def validate_keys(): futures = [] for key in inputted_keys: - if "ant-api03" in key: + if '"' in key[:1]: + key = key.strip('"') + if not os.path.isfile(key): + continue + key_obj = APIKey(Provider.VERTEXAI, key) + futures.append(executor.submit(validate_vertexai, key_obj)) + elif "ant-api03" in key: match = anthropic_regex.match(key) if not match: continue @@ -141,14 +156,16 @@ def validate_keys(): futures.clear() -def get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys): +def get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys, valid_vertexai_keys): valid_oai_keys_set = set([key.api_key for key in valid_oai_keys]) valid_anthropic_keys_set = set([key.api_key for key in valid_anthropic_keys]) valid_ai21_keys_set = set([key.api_key for key in valid_ai21_keys]) valid_palm_keys_set = set([key.api_key for key in valid_palm_keys]) valid_aws_keys_set = set([key.api_key for key in valid_aws_keys]) valid_azure_keys_set = set([key.api_key for key in valid_azure_keys]) - invalid_keys = inputted_keys - valid_oai_keys_set - valid_anthropic_keys_set - valid_ai21_keys_set - valid_palm_keys_set - valid_aws_keys_set - valid_azure_keys_set + valid_vertexai_keys_set = set([key.api_key for key in valid_vertexai_keys]) + + invalid_keys = inputted_keys - valid_oai_keys_set - valid_anthropic_keys_set - valid_ai21_keys_set - valid_palm_keys_set - valid_aws_keys_set - valid_azure_keys_set - valid_vertexai_keys_set if len(invalid_keys) < 1: return print('\nInvalid Keys:') @@ -166,6 +183,8 @@ def output_keys(): valid_palm_keys = [] valid_aws_keys = [] valid_azure_keys = [] + valid_vertexai_keys = [] + for key in api_keys: if key.provider == Provider.OPENAI: valid_oai_keys.append(key) @@ -179,6 +198,8 @@ def output_keys(): valid_aws_keys.append(key) elif key.provider == Provider.AZURE: valid_azure_keys.append(key) + elif key.provider == Provider.VERTEXAI: + valid_vertexai_keys.append(key) if should_write: output_filename = "key_snapshots.txt" sys.stdout = Logger(output_filename) @@ -188,7 +209,7 @@ def output_keys(): print(f"Key snapshot from {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") print("#" * 90) print(f'\n--- Checked {len(inputted_keys)} keys | {len(inputted_keys) - len(api_keys)} were invalid ---') - get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys) + get_invalid_keys(valid_oai_keys, valid_anthropic_keys, valid_ai21_keys, valid_palm_keys, valid_aws_keys, valid_azure_keys, valid_vertexai_keys) print() if valid_oai_keys: pretty_print_oai_keys(valid_oai_keys) @@ -202,8 +223,10 @@ def output_keys(): pretty_print_aws_keys(valid_aws_keys) if valid_azure_keys: pretty_print_azure_keys(valid_azure_keys) + if valid_vertexai_keys: + pretty_print_vertexai_keys(valid_vertexai_keys) else: - # ai21 keys aren't supported in proxies so no point outputting them, filtered azure keys should be excluded. + # ai21 and vertex keys aren't supported in proxies so no point outputting them, filtered azure keys should be excluded. print("OPENAI_KEY=" + ','.join(key.api_key for key in valid_oai_keys)) print("ANTHROPIC_KEY=" + ','.join(key.api_key for key in valid_anthropic_keys)) print("AWS_CREDENTIALS=" + ','.join(f"{key.api_key}:{key.region}" for key in valid_aws_keys)) diff --git a/requirements.txt b/requirements.txt index 2878a70e7717658ecc48c1d66e8216a1f2230868..9bcdaa9bd5382db07fc5edaaecdaa99fb4fd11ff 100644 GIT binary patch literal 1422 zcmai!-EP`Y5QO)-QXj=pj1$^i^kIx4;MT^}KUI^5Py5ZT4~8^pgs>22c4l^VkN>=E zth14CVc+=Vc4I4h;CXAG?USX`3Pf+aGo}EQFqM^7*@08>bk^A5xO)(V^NFeAya!R+ zBVTE)&{#R*i=$SqD5etR;Cz`KUAg7IhPr2GQFSmU4|`YMgV|XI_5?OVb8#)dxye&> z_PS!i+hO*l41^3v*T@Qu&L-QsPK{E7-bh!{YMntT7N)M=(z;~W;B#+3;Zr(KUaih@ z@9(4P*ZtjjrP6-iv!*zB2ffcE!*fZ>B{lMpwjlRi_{hQFGwnCk!COI6!bot563 zNjRO{-)5d{i^naQm0qoxuHQ^KHei@{poOEO$Bb$9{L-Grj%reC+$@vo!Bilx6O>}h zzrRRDUEWf0J)@|rVNMCI9hDMX&NH+}_(rCF=|Y@=uIW(|5tFvz&Tqzd9B6hYuZ>-m z|L~%zA7N}ck3EXJtbAgRRae;$o_i8J@{D;ccXUZdNi|pg;--eb1-Ak7=seoJxHY-E zm5UO$$BV=?>3>hPfgK%BKI2G4J;_(ndm^m^D9v!|{qMl*ZPLtL+}Kxlpd7S0ai^4Y zFlxpH>WNctm`GU8AKdc*Vx(I4VW6*cALM4a&4a(8PMkJ#(y>1&=K)v0-Idhxiw&LN QZur!^SBtZ6vERG?U(OrV&;S4c delta 25 hcmeC