Merge GCP Vertex AI implementation from cg-dot/oai-reverse-proxy (khanon/oai-reverse-proxy!72)
This commit is contained in:
+12
-3
@@ -40,15 +40,21 @@ NODE_ENV=production
|
|||||||
|
|
||||||
# Which model types users are allowed to access.
|
# Which model types users are allowed to access.
|
||||||
# The following model families are recognized:
|
# The following model families are recognized:
|
||||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
|
|
||||||
|
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus
|
||||||
|
# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small
|
||||||
|
# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude
|
||||||
|
# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k
|
||||||
|
# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
|
||||||
|
|
||||||
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
|
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
|
||||||
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
|
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
|
||||||
# 'azure-dall-e' to the list of allowed model families.
|
# 'azure-dall-e' to the list of allowed model families.
|
||||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
|
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
|
||||||
|
|
||||||
# Which services can be used to process prompts containing images via multimodal
|
# Which services can be used to process prompts containing images via multimodal
|
||||||
# models. The following services are recognized:
|
# models. The following services are recognized:
|
||||||
# openai | anthropic | aws | azure | google-ai | mistral-ai
|
# openai | anthropic | aws | gcp | azure | google-ai | mistral-ai
|
||||||
# Do not enable this feature unless all users are trusted, as you will be liable
|
# Do not enable this feature unless all users are trusted, as you will be liable
|
||||||
# for any user-submitted images containing illegal content.
|
# for any user-submitted images containing illegal content.
|
||||||
# By default, no image services are allowed and image prompts are rejected.
|
# By default, no image services are allowed and image prompts are rejected.
|
||||||
@@ -118,6 +124,7 @@ NODE_ENV=production
|
|||||||
# TOKEN_QUOTA_CLAUDE=0
|
# TOKEN_QUOTA_CLAUDE=0
|
||||||
# TOKEN_QUOTA_GEMINI_PRO=0
|
# TOKEN_QUOTA_GEMINI_PRO=0
|
||||||
# TOKEN_QUOTA_AWS_CLAUDE=0
|
# TOKEN_QUOTA_AWS_CLAUDE=0
|
||||||
|
# TOKEN_QUOTA_GCP_CLAUDE=0
|
||||||
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
|
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
|
||||||
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
|
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
|
||||||
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
|
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
|
||||||
@@ -142,6 +149,7 @@ NODE_ENV=production
|
|||||||
|
|
||||||
# You can add multiple API keys by separating them with a comma.
|
# You can add multiple API keys by separating them with a comma.
|
||||||
# For AWS credentials, separate the access key ID, secret key, and region with a colon.
|
# For AWS credentials, separate the access key ID, secret key, and region with a colon.
|
||||||
|
# For GCP credentials, separate the project ID, client email, region, and private key with a colon.
|
||||||
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||||
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||||
GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||||
@@ -149,6 +157,7 @@ GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
|||||||
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
|
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
|
||||||
# See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
|
# See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
|
||||||
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
|
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
|
||||||
|
GCP_CREDENTIALS=project-id:client-email:region:private-key
|
||||||
|
|
||||||
# With proxy_key gatekeeper, the password users must provide to access the API.
|
# With proxy_key gatekeeper, the password users must provide to access the API.
|
||||||
# PROXY_KEY=your-secret-key
|
# PROXY_KEY=your-secret-key
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ This project allows you to run a reverse proxy server for various LLM APIs.
|
|||||||
- [x] [OpenAI](https://openai.com/)
|
- [x] [OpenAI](https://openai.com/)
|
||||||
- [x] [Anthropic](https://www.anthropic.com/)
|
- [x] [Anthropic](https://www.anthropic.com/)
|
||||||
- [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
|
- [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
|
||||||
|
- [x] [Vertex AI (GCP)](https://cloud.google.com/vertex-ai/)
|
||||||
- [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
|
- [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
|
||||||
- [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
|
- [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
|
||||||
- [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses
|
- [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
# Configuring the proxy for Vertex AI (GCP)
|
||||||
|
|
||||||
|
The proxy supports GCP models via the `/proxy/gcp/claude` endpoint. There are a few extra steps necessary to use GCP compared to the other supported APIs.
|
||||||
|
|
||||||
|
- [Setting keys](#setting-keys)
|
||||||
|
- [Setup Vertex AI](#setup-vertex-ai)
|
||||||
|
- [Supported model IDs](#supported-model-ids)
|
||||||
|
|
||||||
|
## Setting keys
|
||||||
|
|
||||||
|
Use the `GCP_CREDENTIALS` environment variable to set the GCP API keys.
|
||||||
|
|
||||||
|
Like other APIs, you can provide multiple keys separated by commas. Each GCP key, however, is a set of credentials including the project id, client email, region and private key. These are separated by a colon (`:`).
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
GCP_CREDENTIALS=my-first-project:xxx@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,my-first-project2:xxx2@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----
|
||||||
|
```
|
||||||
|
|
||||||
|
## Setup Vertex AI
|
||||||
|
1. Go to [https://cloud.google.com/vertex-ai](https://cloud.google.com/vertex-ai) and sign up for a GCP account. ($150 free credits without credit card or $300 free credits with credit card, credits expire in 90 days)
|
||||||
|
2. Go to [https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com](https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com) to enable Vertex AI API.
|
||||||
|
3. Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) and navigate to Model Garden to apply for access to the Claude models.
|
||||||
|
4. Create a [Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1) , and make sure to grant the role of "Vertex AI User" or "Vertex AI Administrator".
|
||||||
|
5. On the service account page you just created, create a new key and select "JSON". The JSON file will be downloaded automatically.
|
||||||
|
6. The required credential is in the JSON file you just downloaded.
|
||||||
|
|
||||||
|
## Supported model IDs
|
||||||
|
Users can send these model IDs to the proxy to invoke the corresponding models.
|
||||||
|
- **Claude**
|
||||||
|
- `claude-3-haiku@20240307`
|
||||||
|
- `claude-3-sonnet@20240229`
|
||||||
|
- `claude-3-opus@20240229`
|
||||||
|
- `claude-3-5-sonnet@20240620`
|
||||||
@@ -230,6 +230,39 @@ Content-Type: application/json
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
###
|
||||||
|
# @name Proxy / GCP Claude -- Native Completion
|
||||||
|
POST {{proxy-host}}/proxy/gcp/claude/v1/complete
|
||||||
|
Authorization: Bearer {{proxy-key}}
|
||||||
|
anthropic-version: 2023-01-01
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"model": "claude-v2",
|
||||||
|
"max_tokens_to_sample": 10,
|
||||||
|
"temperature": 0,
|
||||||
|
"stream": true,
|
||||||
|
"prompt": "What is genshin impact\n\n:Assistant:"
|
||||||
|
}
|
||||||
|
|
||||||
|
###
|
||||||
|
# @name Proxy / GCP Claude -- OpenAI-to-Anthropic API Translation
|
||||||
|
POST {{proxy-host}}/proxy/gcp/claude/chat/completions
|
||||||
|
Authorization: Bearer {{proxy-key}}
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"max_tokens": 50,
|
||||||
|
"stream": true,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is genshin impact?"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
###
|
###
|
||||||
# @name Proxy / Azure OpenAI -- Native Chat Completions
|
# @name Proxy / Azure OpenAI -- Native Chat Completions
|
||||||
POST {{proxy-host}}/proxy/azure/openai/chat/completions
|
POST {{proxy-host}}/proxy/azure/openai/chat/completions
|
||||||
|
|||||||
@@ -51,6 +51,8 @@ function getRandomModelFamily() {
|
|||||||
"mistral-large",
|
"mistral-large",
|
||||||
"aws-claude",
|
"aws-claude",
|
||||||
"aws-claude-opus",
|
"aws-claude-opus",
|
||||||
|
"gcp-claude",
|
||||||
|
"gcp-claude-opus",
|
||||||
"azure-turbo",
|
"azure-turbo",
|
||||||
"azure-gpt4",
|
"azure-gpt4",
|
||||||
"azure-gpt4-32k",
|
"azure-gpt4-32k",
|
||||||
|
|||||||
@@ -268,7 +268,7 @@ router.post("/maintenance", (req, res) => {
|
|||||||
let flash = { type: "", message: "" };
|
let flash = { type: "", message: "" };
|
||||||
switch (action) {
|
switch (action) {
|
||||||
case "recheck": {
|
case "recheck": {
|
||||||
const checkable: LLMService[] = ["openai", "anthropic", "aws", "azure"];
|
const checkable: LLMService[] = ["openai", "anthropic", "aws", "gcp","azure"];
|
||||||
checkable.forEach((s) => keyPool.recheck(s));
|
checkable.forEach((s) => keyPool.recheck(s));
|
||||||
const keyCount = keyPool
|
const keyCount = keyPool
|
||||||
.list()
|
.list()
|
||||||
|
|||||||
+14
-1
@@ -45,6 +45,13 @@ type Config = {
|
|||||||
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
|
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
|
||||||
*/
|
*/
|
||||||
awsCredentials?: string;
|
awsCredentials?: string;
|
||||||
|
/**
|
||||||
|
* Comma-delimited list of GCP credentials. Each credential item should be a
|
||||||
|
* colon-delimited list of access key, secret key, and GCP region.
|
||||||
|
*
|
||||||
|
* @example `GCP_CREDENTIALS=project1:1@1.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,project2:2@2.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----`
|
||||||
|
*/
|
||||||
|
gcpCredentials?: string;
|
||||||
/**
|
/**
|
||||||
* Comma-delimited list of Azure OpenAI credentials. Each credential item
|
* Comma-delimited list of Azure OpenAI credentials. Each credential item
|
||||||
* should be a colon-delimited list of Azure resource name, deployment ID, and
|
* should be a colon-delimited list of Azure resource name, deployment ID, and
|
||||||
@@ -349,7 +356,7 @@ type Config = {
|
|||||||
*
|
*
|
||||||
* Defaults to no services, meaning image prompts are disabled. Use a comma-
|
* Defaults to no services, meaning image prompts are disabled. Use a comma-
|
||||||
* separated list. Available services are:
|
* separated list. Available services are:
|
||||||
* openai,anthropic,google-ai,mistral-ai,aws,azure
|
* openai,anthropic,google-ai,mistral-ai,aws,gcp,azure
|
||||||
*/
|
*/
|
||||||
allowedVisionServices: LLMService[];
|
allowedVisionServices: LLMService[];
|
||||||
/**
|
/**
|
||||||
@@ -383,6 +390,7 @@ export const config: Config = {
|
|||||||
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
|
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
|
||||||
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
|
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
|
||||||
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
||||||
|
gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""),
|
||||||
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
||||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||||
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
||||||
@@ -437,6 +445,8 @@ export const config: Config = {
|
|||||||
"mistral-large",
|
"mistral-large",
|
||||||
"aws-claude",
|
"aws-claude",
|
||||||
"aws-claude-opus",
|
"aws-claude-opus",
|
||||||
|
"gcp-claude",
|
||||||
|
"gcp-claude-opus",
|
||||||
"azure-turbo",
|
"azure-turbo",
|
||||||
"azure-gpt4",
|
"azure-gpt4",
|
||||||
"azure-gpt4-32k",
|
"azure-gpt4-32k",
|
||||||
@@ -511,6 +521,7 @@ function generateSigningKey() {
|
|||||||
config.googleAIKey,
|
config.googleAIKey,
|
||||||
config.mistralAIKey,
|
config.mistralAIKey,
|
||||||
config.awsCredentials,
|
config.awsCredentials,
|
||||||
|
config.gcpCredentials,
|
||||||
config.azureCredentials,
|
config.azureCredentials,
|
||||||
];
|
];
|
||||||
if (secrets.filter((s) => s).length === 0) {
|
if (secrets.filter((s) => s).length === 0) {
|
||||||
@@ -648,6 +659,7 @@ export const OMITTED_KEYS = [
|
|||||||
"googleAIKey",
|
"googleAIKey",
|
||||||
"mistralAIKey",
|
"mistralAIKey",
|
||||||
"awsCredentials",
|
"awsCredentials",
|
||||||
|
"gcpCredentials",
|
||||||
"azureCredentials",
|
"azureCredentials",
|
||||||
"proxyKey",
|
"proxyKey",
|
||||||
"adminKey",
|
"adminKey",
|
||||||
@@ -738,6 +750,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
|
|||||||
"ANTHROPIC_KEY",
|
"ANTHROPIC_KEY",
|
||||||
"GOOGLE_AI_KEY",
|
"GOOGLE_AI_KEY",
|
||||||
"AWS_CREDENTIALS",
|
"AWS_CREDENTIALS",
|
||||||
|
"GCP_CREDENTIALS",
|
||||||
"AZURE_CREDENTIALS",
|
"AZURE_CREDENTIALS",
|
||||||
].includes(String(env))
|
].includes(String(env))
|
||||||
) {
|
) {
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
|||||||
"mistral-large": "Mistral Large",
|
"mistral-large": "Mistral Large",
|
||||||
"aws-claude": "AWS Claude (Sonnet)",
|
"aws-claude": "AWS Claude (Sonnet)",
|
||||||
"aws-claude-opus": "AWS Claude (Opus)",
|
"aws-claude-opus": "AWS Claude (Opus)",
|
||||||
|
"gcp-claude": "GCP Claude (Sonnet)",
|
||||||
|
"gcp-claude-opus": "GCP Claude (Opus)",
|
||||||
"azure-turbo": "Azure GPT-3.5 Turbo",
|
"azure-turbo": "Azure GPT-3.5 Turbo",
|
||||||
"azure-gpt4": "Azure GPT-4",
|
"azure-gpt4": "Azure GPT-4",
|
||||||
"azure-gpt4-32k": "Azure GPT-4 32k",
|
"azure-gpt4-32k": "Azure GPT-4 32k",
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ export function transformAnthropicChatResponseToAnthropicText(
|
|||||||
* is only used for non-streaming requests as streaming requests are handled
|
* is only used for non-streaming requests as streaming requests are handled
|
||||||
* on-the-fly.
|
* on-the-fly.
|
||||||
*/
|
*/
|
||||||
function transformAnthropicTextResponseToOpenAI(
|
export function transformAnthropicTextResponseToOpenAI(
|
||||||
anthropicBody: Record<string, any>,
|
anthropicBody: Record<string, any>,
|
||||||
req: Request
|
req: Request
|
||||||
): Record<string, any> {
|
): Record<string, any> {
|
||||||
|
|||||||
@@ -0,0 +1,196 @@
|
|||||||
|
import { Request, RequestHandler, Response, Router } from "express";
|
||||||
|
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||||
|
import { v4 } from "uuid";
|
||||||
|
import { config } from "../config";
|
||||||
|
import { logger } from "../logger";
|
||||||
|
import { createQueueMiddleware } from "./queue";
|
||||||
|
import { ipLimiter } from "./rate-limit";
|
||||||
|
import { handleProxyError } from "./middleware/common";
|
||||||
|
import {
|
||||||
|
createPreprocessorMiddleware,
|
||||||
|
signGcpRequest,
|
||||||
|
finalizeSignedRequest,
|
||||||
|
createOnProxyReqHandler,
|
||||||
|
} from "./middleware/request";
|
||||||
|
import {
|
||||||
|
ProxyResHandlerWithBody,
|
||||||
|
createOnProxyResHandler,
|
||||||
|
} from "./middleware/response";
|
||||||
|
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
|
||||||
|
import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||||
|
|
||||||
|
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
|
||||||
|
|
||||||
|
let modelsCache: any = null;
|
||||||
|
let modelsCacheTime = 0;
|
||||||
|
|
||||||
|
const getModelsResponse = () => {
|
||||||
|
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||||
|
return modelsCache;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.gcpCredentials) return { object: "list", data: [] };
|
||||||
|
|
||||||
|
// https://docs.anthropic.com/en/docs/about-claude/models
|
||||||
|
const variants = [
|
||||||
|
"claude-3-haiku@20240307",
|
||||||
|
"claude-3-sonnet@20240229",
|
||||||
|
"claude-3-opus@20240229",
|
||||||
|
"claude-3-5-sonnet@20240620",
|
||||||
|
];
|
||||||
|
|
||||||
|
const models = variants.map((id) => ({
|
||||||
|
id,
|
||||||
|
object: "model",
|
||||||
|
created: new Date().getTime(),
|
||||||
|
owned_by: "anthropic",
|
||||||
|
permission: [],
|
||||||
|
root: "claude",
|
||||||
|
parent: null,
|
||||||
|
}));
|
||||||
|
|
||||||
|
modelsCache = { object: "list", data: models };
|
||||||
|
modelsCacheTime = new Date().getTime();
|
||||||
|
|
||||||
|
return modelsCache;
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||||
|
res.status(200).json(getModelsResponse());
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Only used for non-streaming requests. */
|
||||||
|
const gcpResponseHandler: ProxyResHandlerWithBody = async (
|
||||||
|
_proxyRes,
|
||||||
|
req,
|
||||||
|
res,
|
||||||
|
body
|
||||||
|
) => {
|
||||||
|
if (typeof body !== "object") {
|
||||||
|
throw new Error("Expected body to be an object");
|
||||||
|
}
|
||||||
|
|
||||||
|
let newBody = body;
|
||||||
|
switch (`${req.inboundApi}<-${req.outboundApi}`) {
|
||||||
|
case "openai<-anthropic-chat":
|
||||||
|
req.log.info("Transforming Anthropic Chat back to OpenAI format");
|
||||||
|
newBody = transformAnthropicChatResponseToOpenAI(body);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||||
|
};
|
||||||
|
|
||||||
|
const gcpProxy = createQueueMiddleware({
|
||||||
|
beforeProxy: signGcpRequest,
|
||||||
|
proxyMiddleware: createProxyMiddleware({
|
||||||
|
target: "bad-target-will-be-rewritten",
|
||||||
|
router: ({ signedRequest }) => {
|
||||||
|
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||||
|
return `${signedRequest.protocol}//${signedRequest.hostname}`;
|
||||||
|
},
|
||||||
|
changeOrigin: true,
|
||||||
|
selfHandleResponse: true,
|
||||||
|
logger,
|
||||||
|
on: {
|
||||||
|
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
|
||||||
|
proxyRes: createOnProxyResHandler([gcpResponseHandler]),
|
||||||
|
error: handleProxyError,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const oaiToChatPreprocessor = createPreprocessorMiddleware(
|
||||||
|
{ inApi: "openai", outApi: "anthropic-chat", service: "gcp" },
|
||||||
|
{ afterTransform: [maybeReassignModel] }
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
|
||||||
|
* or the new Claude chat completion endpoint, based on the requested model.
|
||||||
|
*/
|
||||||
|
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
|
||||||
|
oaiToChatPreprocessor(req, res, next);
|
||||||
|
};
|
||||||
|
|
||||||
|
const gcpRouter = Router();
|
||||||
|
gcpRouter.get("/v1/models", handleModelRequest);
|
||||||
|
// Native Anthropic chat completion endpoint.
|
||||||
|
gcpRouter.post(
|
||||||
|
"/v1/messages",
|
||||||
|
ipLimiter,
|
||||||
|
createPreprocessorMiddleware(
|
||||||
|
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "gcp" },
|
||||||
|
{ afterTransform: [maybeReassignModel] }
|
||||||
|
),
|
||||||
|
gcpProxy
|
||||||
|
);
|
||||||
|
|
||||||
|
// OpenAI-to-GCP Anthropic compatibility endpoint.
|
||||||
|
gcpRouter.post(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
ipLimiter,
|
||||||
|
preprocessOpenAICompatRequest,
|
||||||
|
gcpProxy
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tries to deal with:
|
||||||
|
* - frontends sending GCP model names even when they want to use the OpenAI-
|
||||||
|
* compatible endpoint
|
||||||
|
* - frontends sending Anthropic model names that GCP doesn't recognize
|
||||||
|
* - frontends sending OpenAI model names because they expect the proxy to
|
||||||
|
* translate them
|
||||||
|
*
|
||||||
|
* If client sends GCP model ID it will be used verbatim. Otherwise, various
|
||||||
|
* strategies are used to try to map a non-GCP model name to GCP model ID.
|
||||||
|
*/
|
||||||
|
function maybeReassignModel(req: Request) {
|
||||||
|
const model = req.body.model;
|
||||||
|
|
||||||
|
// If it looks like an GCP model, use it as-is
|
||||||
|
// if (model.includes("anthropic.claude")) {
|
||||||
|
if (model.startsWith("claude-") && model.includes("@")) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Anthropic model names can look like:
|
||||||
|
// - claude-v1
|
||||||
|
// - claude-2.1
|
||||||
|
// - claude-3-5-sonnet-20240620-v1:0
|
||||||
|
const pattern =
|
||||||
|
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
|
||||||
|
const match = model.match(pattern);
|
||||||
|
|
||||||
|
// If there's no match, fallback to Claude3 Sonnet as it is most likely to be
|
||||||
|
// available on GCP.
|
||||||
|
if (!match) {
|
||||||
|
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
|
||||||
|
|
||||||
|
const ver = minor ? `${major}.${minor}` : major;
|
||||||
|
switch (ver) {
|
||||||
|
case "3":
|
||||||
|
case "3.0":
|
||||||
|
if (name.includes("opus")) {
|
||||||
|
req.body.model = "claude-3-opus@20240229";
|
||||||
|
} else if (name.includes("haiku")) {
|
||||||
|
req.body.model = "claude-3-haiku@20240307";
|
||||||
|
} else {
|
||||||
|
req.body.model = "claude-3-sonnet@20240229";
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
case "3.5":
|
||||||
|
req.body.model = "claude-3-5-sonnet@20240620";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to Claude3 Sonnet
|
||||||
|
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const gcp = gcpRouter;
|
||||||
@@ -15,6 +15,7 @@ export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
|
|||||||
export { languageFilter } from "./preprocessors/language-filter";
|
export { languageFilter } from "./preprocessors/language-filter";
|
||||||
export { setApiFormat } from "./preprocessors/set-api-format";
|
export { setApiFormat } from "./preprocessors/set-api-format";
|
||||||
export { signAwsRequest } from "./preprocessors/sign-aws-request";
|
export { signAwsRequest } from "./preprocessors/sign-aws-request";
|
||||||
|
export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request";
|
||||||
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
|
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
|
||||||
export { validateContextSize } from "./preprocessors/validate-context-size";
|
export { validateContextSize } from "./preprocessors/validate-context-size";
|
||||||
export { validateVision } from "./preprocessors/validate-vision";
|
export { validateVision } from "./preprocessors/validate-vision";
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
|||||||
proxyReq.setHeader("api-key", azureKey);
|
proxyReq.setHeader("api-key", azureKey);
|
||||||
break;
|
break;
|
||||||
case "aws":
|
case "aws":
|
||||||
|
case "gcp":
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
throw new Error("add-key should not be used for this service.");
|
throw new Error("add-key should not be used for this service.");
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import type { HPMRequestCallback } from "../index";
|
import type { HPMRequestCallback } from "../index";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For AWS/Azure/Google requests, the body is signed earlier in the request
|
* For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
|
||||||
* pipeline, before the proxy middleware. This function just assigns the path
|
* pipeline, before the proxy middleware. This function just assigns the path
|
||||||
* and headers to the proxy request.
|
* and headers to the proxy request.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,201 @@
|
|||||||
|
import express from "express";
|
||||||
|
import crypto from "crypto";
|
||||||
|
import { keyPool } from "../../../../shared/key-management";
|
||||||
|
import { RequestPreprocessor } from "../index";
|
||||||
|
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
|
||||||
|
|
||||||
|
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
||||||
|
|
||||||
|
export const signGcpRequest: RequestPreprocessor = async (req) => {
|
||||||
|
const serviceValid = req.service === "gcp";
|
||||||
|
if (!serviceValid) {
|
||||||
|
throw new Error("addVertexAIKey called on invalid request");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!req.body?.model) {
|
||||||
|
throw new Error("You must specify a model with your request.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const { model, stream } = req.body;
|
||||||
|
req.key = keyPool.get(model, "gcp");
|
||||||
|
|
||||||
|
req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request");
|
||||||
|
|
||||||
|
req.isStreaming = String(stream) === "true";
|
||||||
|
|
||||||
|
// TODO: This should happen in transform-outbound-payload.ts
|
||||||
|
// TODO: Support tools
|
||||||
|
let strippedParams: Record<string, unknown>;
|
||||||
|
strippedParams = AnthropicV1MessagesSchema.pick({
|
||||||
|
messages: true,
|
||||||
|
system: true,
|
||||||
|
max_tokens: true,
|
||||||
|
stop_sequences: true,
|
||||||
|
temperature: true,
|
||||||
|
top_k: true,
|
||||||
|
top_p: true,
|
||||||
|
stream: true,
|
||||||
|
})
|
||||||
|
.strip()
|
||||||
|
.parse(req.body);
|
||||||
|
strippedParams.anthropic_version = "vertex-2023-10-16";
|
||||||
|
|
||||||
|
const [accessToken, credential] = await getAccessToken(req);
|
||||||
|
|
||||||
|
const host = GCP_HOST.replace("%REGION%", credential.region);
|
||||||
|
// GCP doesn't use the anthropic-version header, but we set it to ensure the
|
||||||
|
// stream adapter selects the correct transformer.
|
||||||
|
req.headers["anthropic-version"] = "2023-06-01";
|
||||||
|
|
||||||
|
req.signedRequest = {
|
||||||
|
method: "POST",
|
||||||
|
protocol: "https:",
|
||||||
|
hostname: host,
|
||||||
|
path: `/v1/projects/${credential.projectId}/locations/${credential.region}/publishers/anthropic/models/${model}:streamRawPredict`,
|
||||||
|
headers: {
|
||||||
|
["host"]: host,
|
||||||
|
["content-type"]: "application/json",
|
||||||
|
["authorization"]: `Bearer ${accessToken}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify(strippedParams),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
async function getAccessToken(
|
||||||
|
req: express.Request
|
||||||
|
): Promise<[string, Credential]> {
|
||||||
|
// TODO: access token caching to reduce latency
|
||||||
|
const credential = getCredentialParts(req);
|
||||||
|
const signedJWT = await createSignedJWT(
|
||||||
|
credential.clientEmail,
|
||||||
|
credential.privateKey
|
||||||
|
);
|
||||||
|
const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
|
||||||
|
if (accessToken === null) {
|
||||||
|
req.log.warn(
|
||||||
|
{ key: req.key!.hash, jwtError },
|
||||||
|
"Unable to get the access token"
|
||||||
|
);
|
||||||
|
throw new Error("The access token is invalid.");
|
||||||
|
}
|
||||||
|
return [accessToken, credential];
|
||||||
|
}
|
||||||
|
|
||||||
|
async function createSignedJWT(email: string, pkey: string): Promise<string> {
|
||||||
|
let cryptoKey = await crypto.subtle.importKey(
|
||||||
|
"pkcs8",
|
||||||
|
str2ab(atob(pkey)),
|
||||||
|
{
|
||||||
|
name: "RSASSA-PKCS1-v1_5",
|
||||||
|
hash: { name: "SHA-256" },
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
["sign"]
|
||||||
|
);
|
||||||
|
|
||||||
|
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||||
|
const issued = Math.floor(Date.now() / 1000);
|
||||||
|
const expires = issued + 600;
|
||||||
|
|
||||||
|
const header = {
|
||||||
|
alg: "RS256",
|
||||||
|
typ: "JWT",
|
||||||
|
};
|
||||||
|
|
||||||
|
const payload = {
|
||||||
|
iss: email,
|
||||||
|
aud: authUrl,
|
||||||
|
iat: issued,
|
||||||
|
exp: expires,
|
||||||
|
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||||
|
};
|
||||||
|
|
||||||
|
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
|
||||||
|
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
|
||||||
|
|
||||||
|
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||||
|
|
||||||
|
const signature = await crypto.subtle.sign(
|
||||||
|
"RSASSA-PKCS1-v1_5",
|
||||||
|
cryptoKey,
|
||||||
|
str2ab(unsignedToken)
|
||||||
|
);
|
||||||
|
|
||||||
|
const encodedSignature = urlSafeBase64Encode(signature);
|
||||||
|
return `${unsignedToken}.${encodedSignature}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function exchangeJwtForAccessToken(
|
||||||
|
signed_jwt: string
|
||||||
|
): Promise<[string | null, string]> {
|
||||||
|
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
|
||||||
|
const params = {
|
||||||
|
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||||
|
assertion: signed_jwt,
|
||||||
|
};
|
||||||
|
|
||||||
|
const r = await fetch(auth_url, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||||
|
body: Object.entries(params)
|
||||||
|
.map(([k, v]) => `${k}=${v}`)
|
||||||
|
.join("&"),
|
||||||
|
}).then((res) => res.json());
|
||||||
|
|
||||||
|
if (r.access_token) {
|
||||||
|
return [r.access_token, ""];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [null, JSON.stringify(r)];
|
||||||
|
}
|
||||||
|
|
||||||
|
function str2ab(str: string): ArrayBuffer {
|
||||||
|
const buffer = new ArrayBuffer(str.length);
|
||||||
|
const bufferView = new Uint8Array(buffer);
|
||||||
|
for (let i = 0; i < str.length; i++) {
|
||||||
|
bufferView[i] = str.charCodeAt(i);
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||||
|
let base64: string;
|
||||||
|
if (typeof data === "string") {
|
||||||
|
base64 = btoa(
|
||||||
|
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||||
|
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||||
|
)
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||||
|
}
|
||||||
|
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parsed components of a GCP service-account credential. The raw key string
// is colon-delimited: projectId:clientEmail:region:privateKey (see
// getCredentialParts).
type Credential = {
  // GCP project ID the key belongs to.
  projectId: string;
  // Service-account email; used as the JWT issuer when authenticating.
  clientEmail: string;
  // Region portion of the credential (Vertex AI location).
  region: string;
  // Base64 PKCS#8 key body with PEM armor and newlines stripped.
  privateKey: string;
};
|
||||||
|
|
||||||
|
function getCredentialParts(req: express.Request): Credential {
|
||||||
|
const [projectId, clientEmail, region, rawPrivateKey] =
|
||||||
|
req.key!.key.split(":");
|
||||||
|
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||||
|
req.log.error(
|
||||||
|
{ key: req.key!.hash },
|
||||||
|
"GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
|
||||||
|
);
|
||||||
|
throw new Error("The key assigned to this request is invalid.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const privateKey = rawPrivateKey
|
||||||
|
.replace(
|
||||||
|
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||||
|
""
|
||||||
|
)
|
||||||
|
.trim();
|
||||||
|
|
||||||
|
return { projectId, clientEmail, region, privateKey };
|
||||||
|
}
|
||||||
@@ -186,6 +186,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
throw new HttpError(statusCode, parseError.message);
|
throw new HttpError(statusCode, parseError.message);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const service = req.key!.service;
|
||||||
|
if (service === "gcp") {
|
||||||
|
if (Array.isArray(errorPayload)) {
|
||||||
|
errorPayload = errorPayload[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const errorType =
|
const errorType =
|
||||||
errorPayload.error?.code ||
|
errorPayload.error?.code ||
|
||||||
errorPayload.error?.type ||
|
errorPayload.error?.type ||
|
||||||
@@ -199,11 +206,15 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
// TODO: split upstream error handling into separate modules for each service,
|
// TODO: split upstream error handling into separate modules for each service,
|
||||||
// this is out of control.
|
// this is out of control.
|
||||||
|
|
||||||
const service = req.key!.service;
|
|
||||||
if (service === "aws") {
|
if (service === "aws") {
|
||||||
// Try to standardize the error format for AWS
|
// Try to standardize the error format for AWS
|
||||||
errorPayload.error = { message: errorPayload.message, type: errorType };
|
errorPayload.error = { message: errorPayload.message, type: errorType };
|
||||||
delete errorPayload.message;
|
delete errorPayload.message;
|
||||||
|
} else if (service === "gcp") {
|
||||||
|
// Try to standardize the error format for GCP
|
||||||
|
if (errorPayload.error?.code) { // GCP Error
|
||||||
|
errorPayload.error = { message: errorPayload.error.message, type: errorPayload.error.status || errorPayload.error.code };
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (statusCode === 400) {
|
if (statusCode === 400) {
|
||||||
@@ -225,6 +236,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
break;
|
break;
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
case "aws":
|
case "aws":
|
||||||
|
case "gcp":
|
||||||
await handleAnthropicAwsBadRequestError(req, errorPayload);
|
await handleAnthropicAwsBadRequestError(req, errorPayload);
|
||||||
break;
|
break;
|
||||||
case "google-ai":
|
case "google-ai":
|
||||||
@@ -280,6 +292,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
default:
|
default:
|
||||||
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
|
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
|
||||||
}
|
}
|
||||||
|
return;
|
||||||
|
case "gcp":
|
||||||
|
keyPool.disable(req.key!, "revoked");
|
||||||
|
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
} else if (statusCode === 429) {
|
} else if (statusCode === 429) {
|
||||||
switch (service) {
|
switch (service) {
|
||||||
@@ -292,6 +309,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
case "aws":
|
case "aws":
|
||||||
await handleAwsRateLimitError(req, errorPayload);
|
await handleAwsRateLimitError(req, errorPayload);
|
||||||
break;
|
break;
|
||||||
|
case "gcp":
|
||||||
|
await handleGcpRateLimitError(req, errorPayload);
|
||||||
|
break;
|
||||||
case "azure":
|
case "azure":
|
||||||
case "mistral-ai":
|
case "mistral-ai":
|
||||||
await handleAzureRateLimitError(req, errorPayload);
|
await handleAzureRateLimitError(req, errorPayload);
|
||||||
@@ -328,6 +348,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
|||||||
case "aws":
|
case "aws":
|
||||||
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
||||||
break;
|
break;
|
||||||
|
case "gcp":
|
||||||
|
errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
|
||||||
|
break;
|
||||||
case "azure":
|
case "azure":
|
||||||
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
|
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
|
||||||
break;
|
break;
|
||||||
@@ -434,6 +457,19 @@ async function handleAwsRateLimitError(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleGcpRateLimitError(
|
||||||
|
req: Request,
|
||||||
|
errorPayload: ProxiedErrorPayload
|
||||||
|
) {
|
||||||
|
if (errorPayload.error?.type === "RESOURCE_EXHAUSTED") {
|
||||||
|
keyPool.markRateLimited(req.key!);
|
||||||
|
await reenqueueRequest(req);
|
||||||
|
throw new RetryableError("GCP rate-limited request re-enqueued.");
|
||||||
|
} else {
|
||||||
|
errorPayload.proxy_note = `Unrecognized 429 Too Many Requests error from GCP.`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function handleOpenAIRateLimitError(
|
async function handleOpenAIRateLimitError(
|
||||||
req: Request,
|
req: Request,
|
||||||
errorPayload: ProxiedErrorPayload
|
errorPayload: ProxiedErrorPayload
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import { anthropic } from "./anthropic";
|
|||||||
import { googleAI } from "./google-ai";
|
import { googleAI } from "./google-ai";
|
||||||
import { mistralAI } from "./mistral-ai";
|
import { mistralAI } from "./mistral-ai";
|
||||||
import { aws } from "./aws";
|
import { aws } from "./aws";
|
||||||
|
import { gcp } from "./gcp";
|
||||||
import { azure } from "./azure";
|
import { azure } from "./azure";
|
||||||
import { sendErrorToClient } from "./middleware/response/error-generator";
|
import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||||
|
|
||||||
@@ -36,6 +37,7 @@ proxyRouter.use("/anthropic", addV1, anthropic);
|
|||||||
proxyRouter.use("/google-ai", addV1, googleAI);
|
proxyRouter.use("/google-ai", addV1, googleAI);
|
||||||
proxyRouter.use("/mistral-ai", addV1, mistralAI);
|
proxyRouter.use("/mistral-ai", addV1, mistralAI);
|
||||||
proxyRouter.use("/aws/claude", addV1, aws);
|
proxyRouter.use("/aws/claude", addV1, aws);
|
||||||
|
proxyRouter.use("/gcp/claude", addV1, gcp);
|
||||||
proxyRouter.use("/azure/openai", addV1, azure);
|
proxyRouter.use("/azure/openai", addV1, azure);
|
||||||
// Redirect browser requests to the homepage.
|
// Redirect browser requests to the homepage.
|
||||||
proxyRouter.get("*", (req, res, next) => {
|
proxyRouter.get("*", (req, res, next) => {
|
||||||
|
|||||||
+44
-1
@@ -2,6 +2,7 @@ import { config, listConfig } from "./config";
|
|||||||
import {
|
import {
|
||||||
AnthropicKey,
|
AnthropicKey,
|
||||||
AwsBedrockKey,
|
AwsBedrockKey,
|
||||||
|
GcpKey,
|
||||||
AzureOpenAIKey,
|
AzureOpenAIKey,
|
||||||
GoogleAIKey,
|
GoogleAIKey,
|
||||||
keyPool,
|
keyPool,
|
||||||
@@ -11,6 +12,7 @@ import {
|
|||||||
AnthropicModelFamily,
|
AnthropicModelFamily,
|
||||||
assertIsKnownModelFamily,
|
assertIsKnownModelFamily,
|
||||||
AwsBedrockModelFamily,
|
AwsBedrockModelFamily,
|
||||||
|
GcpModelFamily,
|
||||||
AzureOpenAIModelFamily,
|
AzureOpenAIModelFamily,
|
||||||
GoogleAIModelFamily,
|
GoogleAIModelFamily,
|
||||||
LLM_SERVICES,
|
LLM_SERVICES,
|
||||||
@@ -40,6 +42,7 @@ const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
|
|||||||
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
|
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
|
||||||
k.service === "mistral-ai";
|
k.service === "mistral-ai";
|
||||||
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
|
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
|
||||||
|
const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";
|
||||||
|
|
||||||
/** Stats aggregated across all keys for a given service. */
|
/** Stats aggregated across all keys for a given service. */
|
||||||
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
|
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
|
||||||
@@ -52,7 +55,11 @@ type ModelAggregates = {
|
|||||||
pozzed?: number;
|
pozzed?: number;
|
||||||
awsLogged?: number;
|
awsLogged?: number;
|
||||||
awsSonnet?: number;
|
awsSonnet?: number;
|
||||||
|
awsSonnet35?: number;
|
||||||
awsHaiku?: number;
|
awsHaiku?: number;
|
||||||
|
gcpSonnet?: number;
|
||||||
|
gcpSonnet35?: number;
|
||||||
|
gcpHaiku?: number;
|
||||||
queued: number;
|
queued: number;
|
||||||
queueTime: string;
|
queueTime: string;
|
||||||
tokens: number;
|
tokens: number;
|
||||||
@@ -87,6 +94,12 @@ type AnthropicInfo = BaseFamilyInfo & {
|
|||||||
type AwsInfo = BaseFamilyInfo & {
|
type AwsInfo = BaseFamilyInfo & {
|
||||||
privacy?: string;
|
privacy?: string;
|
||||||
sonnetKeys?: number;
|
sonnetKeys?: number;
|
||||||
|
sonnet35Keys?: number;
|
||||||
|
haikuKeys?: number;
|
||||||
|
};
|
||||||
|
type GcpInfo = BaseFamilyInfo & {
|
||||||
|
sonnetKeys?: number;
|
||||||
|
sonnet35Keys?: number;
|
||||||
haikuKeys?: number;
|
haikuKeys?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -101,6 +114,7 @@ export type ServiceInfo = {
|
|||||||
"google-ai"?: string;
|
"google-ai"?: string;
|
||||||
"mistral-ai"?: string;
|
"mistral-ai"?: string;
|
||||||
aws?: string;
|
aws?: string;
|
||||||
|
gcp?: string;
|
||||||
azure?: string;
|
azure?: string;
|
||||||
"openai-image"?: string;
|
"openai-image"?: string;
|
||||||
"azure-image"?: string;
|
"azure-image"?: string;
|
||||||
@@ -114,6 +128,7 @@ export type ServiceInfo = {
|
|||||||
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
|
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
|
||||||
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
|
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
|
||||||
& { [f in AwsBedrockModelFamily]?: AwsInfo }
|
& { [f in AwsBedrockModelFamily]?: AwsInfo }
|
||||||
|
& { [f in GcpModelFamily]?: GcpInfo }
|
||||||
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
|
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
|
||||||
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
|
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
|
||||||
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
|
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
|
||||||
@@ -151,6 +166,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
|
|||||||
aws: {
|
aws: {
|
||||||
aws: `%BASE%/aws/claude`,
|
aws: `%BASE%/aws/claude`,
|
||||||
},
|
},
|
||||||
|
gcp: {
|
||||||
|
gcp: `%BASE%/gcp/claude`,
|
||||||
|
},
|
||||||
azure: {
|
azure: {
|
||||||
azure: `%BASE%/azure/openai`,
|
azure: `%BASE%/azure/openai`,
|
||||||
"azure-image": `%BASE%/azure/openai`,
|
"azure-image": `%BASE%/azure/openai`,
|
||||||
@@ -305,6 +323,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||||||
k.service === "mistral-ai" ? 1 : 0
|
k.service === "mistral-ai" ? 1 : 0
|
||||||
);
|
);
|
||||||
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
|
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
|
||||||
|
increment(serviceStats, "gcp__keys", k.service === "gcp" ? 1 : 0);
|
||||||
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
|
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
|
||||||
|
|
||||||
let sumTokens = 0;
|
let sumTokens = 0;
|
||||||
@@ -396,6 +415,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||||
});
|
});
|
||||||
increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
|
increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
|
||||||
|
increment(modelStats, `aws-claude__awsSonnet35`, k.sonnet35Enabled ? 1 : 0);
|
||||||
increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);
|
increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);
|
||||||
|
|
||||||
// Ignore revoked keys for aws logging stats, but include keys where the
|
// Ignore revoked keys for aws logging stats, but include keys where the
|
||||||
@@ -405,6 +425,21 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||||||
increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
|
increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case "gcp": {
|
||||||
|
if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
|
||||||
|
k.modelFamilies.forEach((f) => {
|
||||||
|
const tokens = k[`${f}Tokens`];
|
||||||
|
sumTokens += tokens;
|
||||||
|
sumCost += getTokenCostUsd(f, tokens);
|
||||||
|
increment(modelStats, `${f}__tokens`, tokens);
|
||||||
|
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||||
|
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||||
|
});
|
||||||
|
increment(modelStats, `gcp-claude__gcpSonnet`, k.sonnetEnabled ? 1 : 0);
|
||||||
|
increment(modelStats, `gcp-claude__gcpSonnet35`, k.sonnet35Enabled ? 1 : 0);
|
||||||
|
increment(modelStats, `gcp-claude__gcpHaiku`, k.haikuEnabled ? 1 : 0);
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
assertNever(k.service);
|
assertNever(k.service);
|
||||||
}
|
}
|
||||||
@@ -416,7 +451,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
|||||||
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
||||||
const tokens = modelStats.get(`${family}__tokens`) || 0;
|
const tokens = modelStats.get(`${family}__tokens`) || 0;
|
||||||
const cost = getTokenCostUsd(family, tokens);
|
const cost = getTokenCostUsd(family, tokens);
|
||||||
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
|
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
|
||||||
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
|
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
|
||||||
activeKeys: modelStats.get(`${family}__active`) || 0,
|
activeKeys: modelStats.get(`${family}__active`) || 0,
|
||||||
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
|
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
|
||||||
@@ -446,6 +481,7 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
|||||||
case "aws":
|
case "aws":
|
||||||
if (family === "aws-claude") {
|
if (family === "aws-claude") {
|
||||||
info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
|
info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
|
||||||
|
info.sonnet35Keys = modelStats.get(`${family}__awsSonnet35`) || 0;
|
||||||
info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
|
info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
|
||||||
const logged = modelStats.get(`${family}__awsLogged`) || 0;
|
const logged = modelStats.get(`${family}__awsLogged`) || 0;
|
||||||
if (logged > 0) {
|
if (logged > 0) {
|
||||||
@@ -455,6 +491,13 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case "gcp":
|
||||||
|
if (family === "gcp-claude") {
|
||||||
|
info.sonnetKeys = modelStats.get(`${family}__gcpSonnet`) || 0;
|
||||||
|
info.sonnet35Keys = modelStats.get(`${family}__gcpSonnet35`) || 0;
|
||||||
|
info.haikuKeys = modelStats.get(`${family}__gcpHaiku`) || 0;
|
||||||
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
|||||||
{ key: key.hash, error: error.message },
|
{ key: key.hash, error: error.message },
|
||||||
"Network error while checking key; trying this key again in a minute."
|
"Network error while checking key; trying this key again in a minute."
|
||||||
);
|
);
|
||||||
const oneMinute = 10 * 1000;
|
const oneMinute = 60 * 1000;
|
||||||
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||||
this.updateKey(key.hash, { lastChecked: next });
|
this.updateKey(key.hash, { lastChecked: next });
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,277 @@
|
|||||||
|
import axios, { AxiosError } from "axios";
|
||||||
|
import crypto from "crypto";
|
||||||
|
import { KeyCheckerBase } from "../key-checker-base";
|
||||||
|
import type { GcpKey, GcpKeyProvider } from "./provider";
|
||||||
|
import { GcpModelFamily } from "../../models";
|
||||||
|
|
||||||
|
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||||
|
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||||
|
const GCP_HOST =
|
||||||
|
process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
||||||
|
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
|
||||||
|
`https://${GCP_HOST.replace("%REGION%", region)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
|
||||||
|
const TEST_MESSAGES = [
|
||||||
|
{ role: "user", content: "Hi!" },
|
||||||
|
{ role: "assistant", content: "Hello!" },
|
||||||
|
];
|
||||||
|
|
||||||
|
type UpdateFn = typeof GcpKeyProvider.prototype.update;
|
||||||
|
|
||||||
|
export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
||||||
|
constructor(keys: GcpKey[], updateKey: UpdateFn) {
|
||||||
|
super(keys, {
|
||||||
|
service: "gcp",
|
||||||
|
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||||
|
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||||
|
updateKey,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
protected async testKeyOrFail(key: GcpKey) {
|
||||||
|
let checks: Promise<boolean>[] = [];
|
||||||
|
const isInitialCheck = !key.lastChecked;
|
||||||
|
if (isInitialCheck) {
|
||||||
|
checks = [
|
||||||
|
this.invokeModel("claude-3-haiku@20240307", key, true),
|
||||||
|
this.invokeModel("claude-3-sonnet@20240229", key, true),
|
||||||
|
this.invokeModel("claude-3-opus@20240229", key, true),
|
||||||
|
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
|
||||||
|
];
|
||||||
|
|
||||||
|
const [sonnet, haiku, opus, sonnet35] =
|
||||||
|
await Promise.all(checks);
|
||||||
|
|
||||||
|
this.log.debug(
|
||||||
|
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
|
||||||
|
"GCP model initial tests complete."
|
||||||
|
);
|
||||||
|
|
||||||
|
const families: GcpModelFamily[] = [];
|
||||||
|
if (sonnet || sonnet35 || haiku) families.push("gcp-claude");
|
||||||
|
if (opus) families.push("gcp-claude-opus");
|
||||||
|
|
||||||
|
if (families.length === 0) {
|
||||||
|
this.log.warn(
|
||||||
|
{ key: key.hash },
|
||||||
|
"Key does not have access to any models; disabling."
|
||||||
|
);
|
||||||
|
return this.updateKey(key.hash, { isDisabled: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
this.updateKey(key.hash, {
|
||||||
|
sonnetEnabled: sonnet,
|
||||||
|
haikuEnabled: haiku,
|
||||||
|
sonnet35Enabled: sonnet35,
|
||||||
|
modelFamilies: families,
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
if (key.haikuEnabled) {
|
||||||
|
await this.invokeModel("claude-3-haiku@20240307", key, false)
|
||||||
|
} else if (key.sonnetEnabled) {
|
||||||
|
await this.invokeModel("claude-3-sonnet@20240229", key, false)
|
||||||
|
} else if (key.sonnet35Enabled) {
|
||||||
|
await this.invokeModel("claude-3-5-sonnet@20240620", key, false)
|
||||||
|
} else {
|
||||||
|
await this.invokeModel("claude-3-opus@20240229", key, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||||
|
this.log.debug(
|
||||||
|
{ key: key.hash},
|
||||||
|
"GCP key check complete."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.log.info(
|
||||||
|
{
|
||||||
|
key: key.hash,
|
||||||
|
families: key.modelFamilies,
|
||||||
|
},
|
||||||
|
"Checked key."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected handleAxiosError(key: GcpKey, error: AxiosError) {
|
||||||
|
if (error.response && GcpKeyChecker.errorIsGcpError(error)) {
|
||||||
|
const { status, data } = error.response;
|
||||||
|
if (status === 400 || status === 401 || status === 403) {
|
||||||
|
this.log.warn(
|
||||||
|
{ key: key.hash, error: data },
|
||||||
|
"Key is invalid or revoked. Disabling key."
|
||||||
|
);
|
||||||
|
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
|
||||||
|
} else if (status === 429) {
|
||||||
|
this.log.warn(
|
||||||
|
{ key: key.hash, error: data },
|
||||||
|
"Key is rate limited. Rechecking in a minute."
|
||||||
|
);
|
||||||
|
const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000);
|
||||||
|
this.updateKey(key.hash, { lastChecked: next });
|
||||||
|
} else {
|
||||||
|
this.log.error(
|
||||||
|
{ key: key.hash, status, error: data },
|
||||||
|
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
|
||||||
|
);
|
||||||
|
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const { response, cause } = error;
|
||||||
|
const { headers, status, data } = response ?? {};
|
||||||
|
this.log.error(
|
||||||
|
{ key: key.hash, status, headers, data, cause, error: error.message },
|
||||||
|
"Network error while checking key; trying this key again in a minute."
|
||||||
|
);
|
||||||
|
const oneMinute = 60 * 1000;
|
||||||
|
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||||
|
this.updateKey(key.hash, { lastChecked: next });
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempt to invoke the given model with the given key. Returns true if the
|
||||||
|
* key has access to the model, false if it does not. Throws an error if the
|
||||||
|
* key is disabled.
|
||||||
|
*/
|
||||||
|
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
|
||||||
|
const creds = GcpKeyChecker.getCredentialsFromKey(key);
|
||||||
|
const signedJWT = await GcpKeyChecker.createSignedJWT(creds.clientEmail, creds.privateKey)
|
||||||
|
const [accessToken, jwtError] = await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT)
|
||||||
|
if (accessToken === null) {
|
||||||
|
this.log.warn(
|
||||||
|
{ key: key.hash, jwtError },
|
||||||
|
"Unable to get the access token"
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const payload = {
|
||||||
|
max_tokens: 1,
|
||||||
|
messages: TEST_MESSAGES,
|
||||||
|
anthropic_version: "vertex-2023-10-16",
|
||||||
|
};
|
||||||
|
const { data, status } = await axios.post(
|
||||||
|
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
|
||||||
|
payload,
|
||||||
|
{
|
||||||
|
headers: GcpKeyChecker.getRequestHeaders(accessToken),
|
||||||
|
validateStatus: initial ? () => true : (status: number) => status >= 200 && status < 300
|
||||||
|
}
|
||||||
|
);
|
||||||
|
this.log.debug({ key: key.hash, data }, "Response from GCP");
|
||||||
|
|
||||||
|
if (initial) {
|
||||||
|
return (status >= 200 && status < 300) || (status === 429 || status === 529);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static errorIsGcpError(error: AxiosError): error is AxiosError {
|
||||||
|
const data = error.response?.data as any;
|
||||||
|
if (Array.isArray(data)) {
|
||||||
|
return data.length > 0 && data[0]?.error?.message;
|
||||||
|
} else {
|
||||||
|
return data?.error?.message;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static async createSignedJWT(email: string, pkey: string): Promise<string> {
|
||||||
|
let cryptoKey = await crypto.subtle.importKey(
|
||||||
|
"pkcs8",
|
||||||
|
GcpKeyChecker.str2ab(atob(pkey)),
|
||||||
|
{
|
||||||
|
name: "RSASSA-PKCS1-v1_5",
|
||||||
|
hash: { name: "SHA-256" },
|
||||||
|
},
|
||||||
|
false,
|
||||||
|
["sign"]
|
||||||
|
);
|
||||||
|
|
||||||
|
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||||
|
const issued = Math.floor(Date.now() / 1000);
|
||||||
|
const expires = issued + 600;
|
||||||
|
|
||||||
|
const header = {
|
||||||
|
alg: "RS256",
|
||||||
|
typ: "JWT",
|
||||||
|
};
|
||||||
|
|
||||||
|
const payload = {
|
||||||
|
iss: email,
|
||||||
|
aud: authUrl,
|
||||||
|
iat: issued,
|
||||||
|
exp: expires,
|
||||||
|
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||||
|
};
|
||||||
|
|
||||||
|
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(header));
|
||||||
|
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(JSON.stringify(payload));
|
||||||
|
|
||||||
|
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||||
|
|
||||||
|
const signature = await crypto.subtle.sign(
|
||||||
|
"RSASSA-PKCS1-v1_5",
|
||||||
|
cryptoKey,
|
||||||
|
GcpKeyChecker.str2ab(unsignedToken)
|
||||||
|
);
|
||||||
|
|
||||||
|
const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
|
||||||
|
return `${unsignedToken}.${encodedSignature}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
static async exchangeJwtForAccessToken(signed_jwt: string): Promise<[string | null, string]> {
|
||||||
|
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
|
||||||
|
const params = {
|
||||||
|
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||||
|
assertion: signed_jwt,
|
||||||
|
};
|
||||||
|
|
||||||
|
const r = await fetch(auth_url, {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||||
|
body: Object.entries(params)
|
||||||
|
.map(([k, v]) => `${k}=${v}`)
|
||||||
|
.join("&"),
|
||||||
|
}).then((res) => res.json());
|
||||||
|
|
||||||
|
if (r.access_token) {
|
||||||
|
return [r.access_token, ""];
|
||||||
|
}
|
||||||
|
|
||||||
|
return [null, JSON.stringify(r)];
|
||||||
|
}
|
||||||
|
|
||||||
|
static str2ab(str: string): ArrayBuffer {
|
||||||
|
const buffer = new ArrayBuffer(str.length);
|
||||||
|
const bufferView = new Uint8Array(buffer);
|
||||||
|
for (let i = 0; i < str.length; i++) {
|
||||||
|
bufferView[i] = str.charCodeAt(i);
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||||
|
let base64: string;
|
||||||
|
if (typeof data === "string") {
|
||||||
|
base64 = btoa(encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) => String.fromCharCode(parseInt("0x" + p1, 16))));
|
||||||
|
} else {
|
||||||
|
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||||
|
}
|
||||||
|
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
static getRequestHeaders(accessToken: string) {
|
||||||
|
return { "Authorization": `Bearer ${accessToken}`, "Content-Type": "application/json" };
|
||||||
|
}
|
||||||
|
|
||||||
|
static getCredentialsFromKey(key: GcpKey) {
|
||||||
|
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
|
||||||
|
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||||
|
throw new Error("Invalid GCP key");
|
||||||
|
}
|
||||||
|
const privateKey = rawPrivateKey
|
||||||
|
.replace(/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g, '')
|
||||||
|
.trim();
|
||||||
|
|
||||||
|
return { projectId, clientEmail, region, privateKey };
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,242 @@
|
|||||||
|
import crypto from "crypto";
|
||||||
|
import { Key, KeyProvider } from "..";
|
||||||
|
import { config } from "../../../config";
|
||||||
|
import { logger } from "../../../logger";
|
||||||
|
import { GcpModelFamily, getGcpModelFamily } from "../../models";
|
||||||
|
import { GcpKeyChecker } from "./checker";
|
||||||
|
import { PaymentRequiredError } from "../../errors";
|
||||||
|
|
||||||
|
// Per-model-family token usage counters keyed as `${family}Tokens`,
// e.g. "gcp-claudeTokens" and "gcp-claude-opusTokens".
type GcpKeyUsage = {
  [K in GcpModelFamily as `${K}Tokens`]: number;
};
|
||||||
|
|
||||||
|
/** A GCP Vertex AI service-account key tracked by the key pool. */
export interface GcpKey extends Key, GcpKeyUsage {
  readonly service: "gcp";
  readonly modelFamilies: GcpModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  // Capability flags set by the key checker based on which Claude model
  // probes succeed for this key's project.
  sonnetEnabled: boolean;
  haikuEnabled: boolean;
  sonnet35Enabled: boolean;
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Upon being rate limited, a key will be locked out for this many milliseconds
|
||||||
|
* while we wait for other concurrent requests to finish.
|
||||||
|
*/
|
||||||
|
const RATE_LIMIT_LOCKOUT = 4000;
|
||||||
|
/**
|
||||||
|
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
||||||
|
* to be used again. This is to prevent the queue from flooding a key with too
|
||||||
|
* many requests while we wait to learn whether previous ones succeeded.
|
||||||
|
*/
|
||||||
|
const KEY_REUSE_DELAY = 500;
|
||||||
|
|
||||||
|
export class GcpKeyProvider implements KeyProvider<GcpKey> {
|
||||||
|
readonly service = "gcp";
|
||||||
|
|
||||||
|
private keys: GcpKey[] = [];
|
||||||
|
private checker?: GcpKeyChecker;
|
||||||
|
private log = logger.child({ module: "key-provider", service: this.service });
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
const keyConfig = config.gcpCredentials?.trim();
|
||||||
|
if (!keyConfig) {
|
||||||
|
this.log.warn(
|
||||||
|
"GCP_CREDENTIALS is not set. GCP API will not be available."
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let bareKeys: string[];
|
||||||
|
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
||||||
|
for (const key of bareKeys) {
|
||||||
|
const newKey: GcpKey = {
|
||||||
|
key,
|
||||||
|
service: this.service,
|
||||||
|
modelFamilies: ["gcp-claude"],
|
||||||
|
isDisabled: false,
|
||||||
|
isRevoked: false,
|
||||||
|
promptCount: 0,
|
||||||
|
lastUsed: 0,
|
||||||
|
rateLimitedAt: 0,
|
||||||
|
rateLimitedUntil: 0,
|
||||||
|
hash: `gcp-${crypto
|
||||||
|
.createHash("sha256")
|
||||||
|
.update(key)
|
||||||
|
.digest("hex")
|
||||||
|
.slice(0, 8)}`,
|
||||||
|
lastChecked: 0,
|
||||||
|
sonnetEnabled: true,
|
||||||
|
haikuEnabled: false,
|
||||||
|
sonnet35Enabled: false,
|
||||||
|
["gcp-claudeTokens"]: 0,
|
||||||
|
["gcp-claude-opusTokens"]: 0,
|
||||||
|
};
|
||||||
|
this.keys.push(newKey);
|
||||||
|
}
|
||||||
|
this.log.info({ keyCount: this.keys.length }, "Loaded GCP keys.");
|
||||||
|
}
|
||||||
|
|
||||||
|
public init() {
|
||||||
|
if (config.checkKeys) {
|
||||||
|
this.checker = new GcpKeyChecker(this.keys, this.update.bind(this));
|
||||||
|
this.checker.start();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public list() {
|
||||||
|
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||||
|
}
|
||||||
|
|
||||||
|
public get(model: string) {
|
||||||
|
const neededFamily = getGcpModelFamily(model);
|
||||||
|
|
||||||
|
// this is a horrible mess
|
||||||
|
// each of these should be separate model families, but adding model
|
||||||
|
// families is not low enough friction for the rate at which gcp claude
|
||||||
|
// model variants are added.
|
||||||
|
const needsSonnet35 =
|
||||||
|
model.includes("claude-3-5-sonnet") && neededFamily === "gcp-claude";
|
||||||
|
const needsSonnet =
|
||||||
|
!needsSonnet35 &&
|
||||||
|
model.includes("sonnet") &&
|
||||||
|
neededFamily === "gcp-claude";
|
||||||
|
const needsHaiku = model.includes("haiku") && neededFamily === "gcp-claude";
|
||||||
|
|
||||||
|
const availableKeys = this.keys.filter((k) => {
|
||||||
|
return (
|
||||||
|
!k.isDisabled &&
|
||||||
|
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under gcp-claude, while opus is not
|
||||||
|
(k.haikuEnabled || !needsHaiku) &&
|
||||||
|
(k.sonnet35Enabled || !needsSonnet35) &&
|
||||||
|
k.modelFamilies.includes(neededFamily)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
this.log.debug(
|
||||||
|
{
|
||||||
|
model,
|
||||||
|
neededFamily,
|
||||||
|
needsSonnet,
|
||||||
|
needsHaiku,
|
||||||
|
needsSonnet35,
|
||||||
|
availableKeys: availableKeys.length,
|
||||||
|
totalKeys: this.keys.length,
|
||||||
|
},
|
||||||
|
"Selecting GCP key"
|
||||||
|
);
|
||||||
|
|
||||||
|
if (availableKeys.length === 0) {
|
||||||
|
throw new PaymentRequiredError(
|
||||||
|
`No GCP keys available for model ${model}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// (largely copied from the OpenAI provider, without trial key support)
|
||||||
|
// Select a key, from highest priority to lowest priority:
|
||||||
|
// 1. Keys which are not rate limited
|
||||||
|
// a. If all keys were rate limited recently, select the least-recently
|
||||||
|
// rate limited key.
|
||||||
|
// 3. Keys which have not been used in the longest time
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
|
||||||
|
const keysByPriority = availableKeys.sort((a, b) => {
|
||||||
|
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||||
|
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||||
|
|
||||||
|
if (aRateLimited && !bRateLimited) return 1;
|
||||||
|
if (!aRateLimited && bRateLimited) return -1;
|
||||||
|
if (aRateLimited && bRateLimited) {
|
||||||
|
return a.rateLimitedAt - b.rateLimitedAt;
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.lastUsed - b.lastUsed;
|
||||||
|
});
|
||||||
|
|
||||||
|
const selectedKey = keysByPriority[0];
|
||||||
|
selectedKey.lastUsed = now;
|
||||||
|
this.throttle(selectedKey.hash);
|
||||||
|
return { ...selectedKey };
|
||||||
|
}
|
||||||
|
|
||||||
|
public disable(key: GcpKey) {
|
||||||
|
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
||||||
|
if (!keyFromPool || keyFromPool.isDisabled) return;
|
||||||
|
keyFromPool.isDisabled = true;
|
||||||
|
this.log.warn({ key: key.hash }, "Key disabled");
|
||||||
|
}
|
||||||
|
|
||||||
|
public update(hash: string, update: Partial<GcpKey>) {
|
||||||
|
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
||||||
|
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
||||||
|
}
|
||||||
|
|
||||||
|
public available() {
|
||||||
|
return this.keys.filter((k) => !k.isDisabled).length;
|
||||||
|
}
|
||||||
|
|
||||||
|
public incrementUsage(hash: string, model: string, tokens: number) {
|
||||||
|
const key = this.keys.find((k) => k.hash === hash);
|
||||||
|
if (!key) return;
|
||||||
|
key.promptCount++;
|
||||||
|
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
public getLockoutPeriod() {
|
||||||
|
// TODO: same exact behavior for three providers, should be refactored
|
||||||
|
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||||
|
// Don't lock out if there are no keys available or the queue will stall.
|
||||||
|
// Just let it through so the add-key middleware can throw an error.
|
||||||
|
if (activeKeys.length === 0) return 0;
|
||||||
|
|
||||||
|
const now = Date.now();
|
||||||
|
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||||
|
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||||
|
|
||||||
|
if (anyNotRateLimited) return 0;
|
||||||
|
|
||||||
|
// If all keys are rate-limited, return time until the first key is ready.
|
||||||
|
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This is called when we receive a 429, which means there are already five
|
||||||
|
* concurrent requests running on this key. We don't have any information on
|
||||||
|
* when these requests will resolve, so all we can do is wait a bit and try
|
||||||
|
* again. We will lock the key for 2 seconds after getting a 429 before
|
||||||
|
* retrying in order to give the other requests a chance to finish.
|
||||||
|
*/
|
||||||
|
public markRateLimited(keyHash: string) {
|
||||||
|
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||||
|
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||||
|
const now = Date.now();
|
||||||
|
key.rateLimitedAt = now;
|
||||||
|
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
|
||||||
|
}
|
||||||
|
|
||||||
|
public recheck() {
|
||||||
|
this.keys.forEach(({ hash }) =>
|
||||||
|
this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false })
|
||||||
|
);
|
||||||
|
this.checker?.scheduleNextCheck();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies a short artificial delay to the key upon dequeueing, in order to
|
||||||
|
* prevent it from being immediately assigned to another request before the
|
||||||
|
* current one can be dispatched.
|
||||||
|
**/
|
||||||
|
private throttle(hash: string) {
|
||||||
|
const now = Date.now();
|
||||||
|
const key = this.keys.find((k) => k.hash === hash)!;
|
||||||
|
|
||||||
|
const currentRateLimit = key.rateLimitedUntil;
|
||||||
|
const nextRateLimit = now + KEY_REUSE_DELAY;
|
||||||
|
|
||||||
|
key.rateLimitedAt = now;
|
||||||
|
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -63,4 +63,5 @@ export { AnthropicKey } from "./anthropic/provider";
|
|||||||
export { OpenAIKey } from "./openai/provider";
|
export { OpenAIKey } from "./openai/provider";
|
||||||
export { GoogleAIKey } from "././google-ai/provider";
|
export { GoogleAIKey } from "././google-ai/provider";
|
||||||
export { AwsBedrockKey } from "./aws/provider";
|
export { AwsBedrockKey } from "./aws/provider";
|
||||||
|
export { GcpKey } from "./gcp/provider";
|
||||||
export { AzureOpenAIKey } from "./azure/provider";
|
export { AzureOpenAIKey } from "./azure/provider";
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
|||||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||||
|
import { GcpKeyProvider } from "./gcp/provider";
|
||||||
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
||||||
import { MistralAIKeyProvider } from "./mistral-ai/provider";
|
import { MistralAIKeyProvider } from "./mistral-ai/provider";
|
||||||
|
|
||||||
@@ -27,6 +28,7 @@ export class KeyPool {
|
|||||||
this.keyProviders.push(new GoogleAIKeyProvider());
|
this.keyProviders.push(new GoogleAIKeyProvider());
|
||||||
this.keyProviders.push(new MistralAIKeyProvider());
|
this.keyProviders.push(new MistralAIKeyProvider());
|
||||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||||
|
this.keyProviders.push(new GcpKeyProvider());
|
||||||
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,7 +130,11 @@ export class KeyPool {
|
|||||||
return "openai";
|
return "openai";
|
||||||
} else if (model.startsWith("claude-")) {
|
} else if (model.startsWith("claude-")) {
|
||||||
// https://console.anthropic.com/docs/api/reference#parameters
|
// https://console.anthropic.com/docs/api/reference#parameters
|
||||||
return "anthropic";
|
if (!model.includes('@')) {
|
||||||
|
return "anthropic";
|
||||||
|
} else {
|
||||||
|
return "gcp";
|
||||||
|
}
|
||||||
} else if (model.includes("gemini")) {
|
} else if (model.includes("gemini")) {
|
||||||
// https://developers.generativeai.google.com/models/language
|
// https://developers.generativeai.google.com/models/language
|
||||||
return "google-ai";
|
return "google-ai";
|
||||||
|
|||||||
+17
-2
@@ -5,7 +5,7 @@ import type { Request } from "express";
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* The service that a model is hosted on. Distinct from `APIFormat` because some
|
* The service that a model is hosted on. Distinct from `APIFormat` because some
|
||||||
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
|
* services have interoperable APIs (eg Anthropic/AWS/GCP, OpenAI/Azure).
|
||||||
*/
|
*/
|
||||||
export type LLMService =
|
export type LLMService =
|
||||||
| "openai"
|
| "openai"
|
||||||
@@ -13,6 +13,7 @@ export type LLMService =
|
|||||||
| "google-ai"
|
| "google-ai"
|
||||||
| "mistral-ai"
|
| "mistral-ai"
|
||||||
| "aws"
|
| "aws"
|
||||||
|
| "gcp"
|
||||||
| "azure";
|
| "azure";
|
||||||
|
|
||||||
export type OpenAIModelFamily =
|
export type OpenAIModelFamily =
|
||||||
@@ -32,6 +33,7 @@ export type MistralAIModelFamily =
|
|||||||
// correspond to specific models. consider them rough pricing tiers.
|
// correspond to specific models. consider them rough pricing tiers.
|
||||||
"mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
|
"mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
|
||||||
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
|
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
|
||||||
|
export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
|
||||||
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
|
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
|
||||||
export type ModelFamily =
|
export type ModelFamily =
|
||||||
| OpenAIModelFamily
|
| OpenAIModelFamily
|
||||||
@@ -39,6 +41,7 @@ export type ModelFamily =
|
|||||||
| GoogleAIModelFamily
|
| GoogleAIModelFamily
|
||||||
| MistralAIModelFamily
|
| MistralAIModelFamily
|
||||||
| AwsBedrockModelFamily
|
| AwsBedrockModelFamily
|
||||||
|
| GcpModelFamily
|
||||||
| AzureOpenAIModelFamily;
|
| AzureOpenAIModelFamily;
|
||||||
|
|
||||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||||
@@ -61,6 +64,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||||||
"mistral-large",
|
"mistral-large",
|
||||||
"aws-claude",
|
"aws-claude",
|
||||||
"aws-claude-opus",
|
"aws-claude-opus",
|
||||||
|
"gcp-claude",
|
||||||
|
"gcp-claude-opus",
|
||||||
"azure-turbo",
|
"azure-turbo",
|
||||||
"azure-gpt4",
|
"azure-gpt4",
|
||||||
"azure-gpt4-32k",
|
"azure-gpt4-32k",
|
||||||
@@ -77,6 +82,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
|||||||
"google-ai",
|
"google-ai",
|
||||||
"mistral-ai",
|
"mistral-ai",
|
||||||
"aws",
|
"aws",
|
||||||
|
"gcp",
|
||||||
"azure",
|
"azure",
|
||||||
] as const);
|
] as const);
|
||||||
|
|
||||||
@@ -93,6 +99,8 @@ export const MODEL_FAMILY_SERVICE: {
|
|||||||
"claude-opus": "anthropic",
|
"claude-opus": "anthropic",
|
||||||
"aws-claude": "aws",
|
"aws-claude": "aws",
|
||||||
"aws-claude-opus": "aws",
|
"aws-claude-opus": "aws",
|
||||||
|
"gcp-claude": "gcp",
|
||||||
|
"gcp-claude-opus": "gcp",
|
||||||
"azure-turbo": "azure",
|
"azure-turbo": "azure",
|
||||||
"azure-gpt4": "azure",
|
"azure-gpt4": "azure",
|
||||||
"azure-gpt4-32k": "azure",
|
"azure-gpt4-32k": "azure",
|
||||||
@@ -176,6 +184,11 @@ export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
|
|||||||
return "aws-claude";
|
return "aws-claude";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getGcpModelFamily(model: string): GcpModelFamily {
|
||||||
|
if (model.includes("opus")) return "gcp-claude-opus";
|
||||||
|
return "gcp-claude";
|
||||||
|
}
|
||||||
|
|
||||||
export function getAzureOpenAIModelFamily(
|
export function getAzureOpenAIModelFamily(
|
||||||
model: string,
|
model: string,
|
||||||
defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
|
defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
|
||||||
@@ -210,10 +223,12 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
|||||||
const model = req.body.model ?? "gpt-3.5-turbo";
|
const model = req.body.model ?? "gpt-3.5-turbo";
|
||||||
let modelFamily: ModelFamily;
|
let modelFamily: ModelFamily;
|
||||||
|
|
||||||
// Weird special case for AWS/Azure because they serve multiple models from
|
// Weird special case for AWS/GCP/Azure because they serve multiple models from
|
||||||
// different vendors, even if currently only one is supported.
|
// different vendors, even if currently only one is supported.
|
||||||
if (req.service === "aws") {
|
if (req.service === "aws") {
|
||||||
modelFamily = getAwsBedrockModelFamily(model);
|
modelFamily = getAwsBedrockModelFamily(model);
|
||||||
|
} else if (req.service === "gcp") {
|
||||||
|
modelFamily = getGcpModelFamily(model);
|
||||||
} else if (req.service === "azure") {
|
} else if (req.service === "azure") {
|
||||||
modelFamily = getAzureOpenAIModelFamily(model);
|
modelFamily = getAzureOpenAIModelFamily(model);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -30,10 +30,12 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
|
|||||||
cost = 0.00001;
|
cost = 0.00001;
|
||||||
break;
|
break;
|
||||||
case "aws-claude":
|
case "aws-claude":
|
||||||
|
case "gcp-claude":
|
||||||
case "claude":
|
case "claude":
|
||||||
cost = 0.000008;
|
cost = 0.000008;
|
||||||
break;
|
break;
|
||||||
case "aws-claude-opus":
|
case "aws-claude-opus":
|
||||||
|
case "gcp-claude-opus":
|
||||||
case "claude-opus":
|
case "claude-opus":
|
||||||
cost = 0.000015;
|
cost = 0.000015;
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import { v4 as uuid } from "uuid";
|
|||||||
import { config, getFirebaseApp } from "../../config";
|
import { config, getFirebaseApp } from "../../config";
|
||||||
import {
|
import {
|
||||||
getAwsBedrockModelFamily,
|
getAwsBedrockModelFamily,
|
||||||
|
getGcpModelFamily,
|
||||||
getAzureOpenAIModelFamily,
|
getAzureOpenAIModelFamily,
|
||||||
getClaudeModelFamily,
|
getClaudeModelFamily,
|
||||||
getGoogleAIModelFamily,
|
getGoogleAIModelFamily,
|
||||||
@@ -417,6 +418,7 @@ function getModelFamilyForQuotaUsage(
|
|||||||
// differentiate between Azure and OpenAI variants of the same model.
|
// differentiate between Azure and OpenAI variants of the same model.
|
||||||
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
|
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
|
||||||
if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model);
|
if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model);
|
||||||
|
if (model.startsWith("claude-") && model.includes("@")) return getGcpModelFamily(model);
|
||||||
|
|
||||||
switch (api) {
|
switch (api) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
|||||||
Reference in New Issue
Block a user