Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d5059534a |
+3
-13
@@ -40,21 +40,15 @@ NODE_ENV=production
|
||||
|
||||
# Which model types users are allowed to access.
|
||||
# The following model families are recognized:
|
||||
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus
|
||||
# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small
|
||||
# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude
|
||||
# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k
|
||||
# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
|
||||
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
|
||||
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
|
||||
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
|
||||
# 'azure-dall-e' to the list of allowed model families.
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
|
||||
|
||||
# Which services can be used to process prompts containing images via multimodal
|
||||
# models. The following services are recognized:
|
||||
# openai | anthropic | aws | gcp | azure | google-ai | mistral-ai
|
||||
# openai | anthropic | aws | azure | google-ai | mistral-ai
|
||||
# Do not enable this feature unless all users are trusted, as you will be liable
|
||||
# for any user-submitted images containing illegal content.
|
||||
# By default, no image services are allowed and image prompts are rejected.
|
||||
@@ -124,7 +118,6 @@ NODE_ENV=production
|
||||
# TOKEN_QUOTA_CLAUDE=0
|
||||
# TOKEN_QUOTA_GEMINI_PRO=0
|
||||
# TOKEN_QUOTA_AWS_CLAUDE=0
|
||||
# TOKEN_QUOTA_GCP_CLAUDE=0
|
||||
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
|
||||
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
|
||||
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
|
||||
@@ -149,15 +142,12 @@ NODE_ENV=production
|
||||
|
||||
# You can add multiple API keys by separating them with a comma.
|
||||
# For AWS credentials, separate the access key ID, secret key, and region with a colon.
|
||||
# For GCP credentials, separate the project ID, client email, region, and private key with a colon.
|
||||
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
# See `docs/aws-configuration.md` for more information, there may be additional steps required to set up AWS.
|
||||
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
|
||||
# See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
|
||||
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
|
||||
GCP_CREDENTIALS=project-id:client-email:region:private-key
|
||||
|
||||
# With proxy_key gatekeeper, the password users must provide to access the API.
|
||||
# PROXY_KEY=your-secret-key
|
||||
|
||||
@@ -7,8 +7,9 @@ Reverse proxy server for various LLM APIs.
|
||||
- [Features](#features)
|
||||
- [Usage Instructions](#usage-instructions)
|
||||
- [Self-hosting](#self-hosting)
|
||||
- [Huggingface (outdated, not advised)](#huggingface-outdated-not-advised)
|
||||
- [Render (outdated, not advised)](#render-outdated-not-advised)
|
||||
- [Alternatives](#alternatives)
|
||||
- [Huggingface (outdated, not advised)](#huggingface-outdated-not-advised)
|
||||
- [Render (outdated, not advised)](#render-outdated-not-advised)
|
||||
- [Local Development](#local-development)
|
||||
|
||||
## What is this?
|
||||
@@ -19,7 +20,6 @@ This project allows you to run a reverse proxy server for various LLM APIs.
|
||||
- [x] [OpenAI](https://openai.com/)
|
||||
- [x] [Anthropic](https://www.anthropic.com/)
|
||||
- [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
|
||||
- [x] [Vertex AI (GCP)](https://cloud.google.com/vertex-ai/)
|
||||
- [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
|
||||
- [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
|
||||
- [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses
|
||||
@@ -41,6 +41,9 @@ If you'd like to run your own instance of this server, you'll need to deploy it
|
||||
|
||||
**Ensure you set the `TRUSTED_PROXIES` environment variable according to your deployment.** Refer to [.env.example](./.env.example) and [config.ts](./src/config.ts) for more information.
|
||||
|
||||
### Alternatives
|
||||
Fiz and Sekrit are working on some alternative ways to deploy this conveniently. While I'm not involved in this effort beyond providing technical advice regarding my code, I'll link to their work here for convenience: [Sekrit's rentry](https://rentry.org/sekrit)
|
||||
|
||||
### Huggingface (outdated, not advised)
|
||||
[See here for instructions on how to deploy to a Huggingface Space.](./docs/deploy-huggingface.md)
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
# Configuring the proxy for Vertex AI (GCP)
|
||||
|
||||
The proxy supports GCP models via the `/proxy/gcp/claude` endpoint. There are a few extra steps necessary to use GCP compared to the other supported APIs.
|
||||
|
||||
- [Setting keys](#setting-keys)
|
||||
- [Setup Vertex AI](#setup-vertex-ai)
|
||||
- [Supported model IDs](#supported-model-ids)
|
||||
|
||||
## Setting keys
|
||||
|
||||
Use the `GCP_CREDENTIALS` environment variable to set the GCP API keys.
|
||||
|
||||
Like other APIs, you can provide multiple keys separated by commas. Each GCP key, however, is a set of credentials including the project id, client email, region and private key. These are separated by a colon (`:`).
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
GCP_CREDENTIALS=my-first-project:xxx@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,my-first-project2:xxx2@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----
|
||||
```
|
||||
|
||||
## Setup Vertex AI
|
||||
1. Go to [https://cloud.google.com/vertex-ai](https://cloud.google.com/vertex-ai) and sign up for a GCP account. ($150 free credits without credit card or $300 free credits with credit card, credits expire in 90 days)
|
||||
2. Go to [https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com](https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com) to enable Vertex AI API.
|
||||
3. Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) and navigate to Model Garden to apply for access to the Claude models.
|
||||
4. Create a [Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1) , and make sure to grant the role of "Vertex AI User" or "Vertex AI Administrator".
|
||||
5. On the service account page you just created, create a new key and select "JSON". The JSON file will be downloaded automatically.
|
||||
6. The required credential is in the JSON file you just downloaded.
|
||||
|
||||
## Supported model IDs
|
||||
Users can send these model IDs to the proxy to invoke the corresponding models.
|
||||
- **Claude**
|
||||
- `claude-3-haiku@20240307`
|
||||
- `claude-3-sonnet@20240229`
|
||||
- `claude-3-opus@20240229`
|
||||
- `claude-3-5-sonnet@20240620`
|
||||
+1
-1
@@ -129,7 +129,7 @@ also significantly reduce hash rates on mobile devices.
|
||||
- Intel Core i9-13900K (Chrome, in VM limited to 4 cores): 12.2 - 13.0 H/s
|
||||
- iPad Pro (M2) (Safari, 6 workers): 8.0 - 10 H/s
|
||||
- Thermal throttles early. 8 cores is normal concurrency, but unstable.
|
||||
- iPhone 15 Pro Max (Safari): 4.0 - 4.6 H/s
|
||||
- iPhone 13 Pro (Safari): 4.0 - 4.6 H/s
|
||||
- Samsung Galaxy S10e (Chrome): 3.6 - 3.8 H/s
|
||||
- This is a 2019 phone almost matching an iPhone five years newer because of
|
||||
bad Safari performance.
|
||||
|
||||
Generated
+141
-226
@@ -11,14 +11,14 @@
|
||||
"dependencies": {
|
||||
"@anthropic-ai/tokenizer": "^0.0.4",
|
||||
"@aws-crypto/sha256-js": "^5.2.0",
|
||||
"@huggingface/jinja": "^0.3.0",
|
||||
"@node-rs/argon2": "^1.8.3",
|
||||
"@smithy/eventstream-codec": "^2.1.3",
|
||||
"@smithy/eventstream-serde-node": "^2.1.3",
|
||||
"@smithy/protocol-http": "^3.2.1",
|
||||
"@smithy/signature-v4": "^2.1.3",
|
||||
"@smithy/types": "^2.10.1",
|
||||
"@smithy/util-utf8": "^2.1.1",
|
||||
"axios": "^1.7.4",
|
||||
"axios": "^1.3.5",
|
||||
"better-sqlite3": "^10.0.0",
|
||||
"check-disk-space": "^3.4.0",
|
||||
"cookie-parser": "^1.4.6",
|
||||
@@ -29,7 +29,7 @@
|
||||
"ejs": "^3.1.10",
|
||||
"express": "^4.18.2",
|
||||
"express-session": "^1.17.3",
|
||||
"firebase-admin": "^12.3.1",
|
||||
"firebase-admin": "^12.1.0",
|
||||
"glob": "^10.3.12",
|
||||
"googleapis": "^122.0.0",
|
||||
"http-proxy-middleware": "^3.0.0-beta.1",
|
||||
@@ -39,7 +39,7 @@
|
||||
"node-schedule": "^2.1.1",
|
||||
"pino": "^8.11.0",
|
||||
"pino-http": "^8.3.3",
|
||||
"sanitize-html": "^2.13.0",
|
||||
"sanitize-html": "2.12.1",
|
||||
"sharp": "^0.32.6",
|
||||
"showdown": "^2.1.0",
|
||||
"source-map-support": "^0.5.21",
|
||||
@@ -51,7 +51,6 @@
|
||||
"zod-error": "^1.5.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@smithy/types": "^3.3.0",
|
||||
"@types/better-sqlite3": "^7.6.10",
|
||||
"@types/cookie-parser": "^1.4.3",
|
||||
"@types/cors": "^2.8.13",
|
||||
@@ -152,17 +151,6 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@aws-sdk/types/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@aws-sdk/util-utf8-browser": {
|
||||
"version": "3.259.0",
|
||||
"resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz",
|
||||
@@ -554,9 +542,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@fastify/busboy": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/@fastify/busboy/-/busboy-3.0.0.tgz",
|
||||
"integrity": "sha512-83rnH2nCvclWaPQQKvkJ2pdOjG4TZyEVuFDnlOF6KP08lDaaceVyw/W63mDuafQT+MKHCvXIPpE5uYWeM0rT4w=="
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@fastify/busboy/-/busboy-2.1.1.tgz",
|
||||
"integrity": "sha512-vBZP4NlzfOlerQTnba4aqZoMhE/a9HY7HRqoOPaETQcSQuWEIyZMHGfVu6w9wGtGK5fED5qRs2DteVCjOH60sA==",
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
}
|
||||
},
|
||||
"node_modules/@firebase/app-check-interop-types": {
|
||||
"version": "0.3.1",
|
||||
@@ -635,14 +626,14 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@google-cloud/firestore": {
|
||||
"version": "7.9.0",
|
||||
"resolved": "https://registry.npmjs.org/@google-cloud/firestore/-/firestore-7.9.0.tgz",
|
||||
"integrity": "sha512-c4ALHT3G08rV7Zwv8Z2KG63gZh66iKdhCBeDfCpIkLrjX6EAjTD/szMdj14M+FnQuClZLFfW5bAgoOjfNmLtJg==",
|
||||
"version": "7.6.0",
|
||||
"resolved": "https://registry.npmjs.org/@google-cloud/firestore/-/firestore-7.6.0.tgz",
|
||||
"integrity": "sha512-WUDbaLY8UnPxgwsyIaxj6uxCtSDAaUyvzWJykNH5rZ9i92/SZCsPNNMN0ajrVpAR81hPIL4amXTaMJ40y5L+Yg==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"fast-deep-equal": "^3.1.1",
|
||||
"functional-red-black-tree": "^1.0.1",
|
||||
"google-gax": "^4.3.3",
|
||||
"google-gax": "^4.3.1",
|
||||
"protobufjs": "^7.2.6"
|
||||
},
|
||||
"engines": {
|
||||
@@ -848,12 +839,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@grpc/grpc-js": {
|
||||
"version": "1.11.1",
|
||||
"resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.11.1.tgz",
|
||||
"integrity": "sha512-gyt/WayZrVPH2w/UTLansS7F9Nwld472JxxaETamrM8HNlsa+jSLNyKAZmhxI2Me4c3mQHFiS1wWHDY1g1Kthw==",
|
||||
"version": "1.10.6",
|
||||
"resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.10.6.tgz",
|
||||
"integrity": "sha512-xP58G7wDQ4TCmN/cMUHh00DS7SRDv/+lC+xFLrTkMIN8h55X5NhZMLYbvy7dSELP15qlI6hPhNCRWVMtZMwqLA==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@grpc/proto-loader": "^0.7.13",
|
||||
"@grpc/proto-loader": "^0.7.10",
|
||||
"@js-sdsl/ordered-map": "^4.4.2"
|
||||
},
|
||||
"engines": {
|
||||
@@ -861,14 +852,14 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@grpc/proto-loader": {
|
||||
"version": "0.7.13",
|
||||
"resolved": "https://registry.npmjs.org/@grpc/proto-loader/-/proto-loader-0.7.13.tgz",
|
||||
"integrity": "sha512-AiXO/bfe9bmxBjxxtYxFAXGZvMaN5s8kO+jBHAJCON8rJoB5YS/D6X7ZNc6XQkuHNmyl4CYaMI1fJ/Gn27RGGw==",
|
||||
"version": "0.7.12",
|
||||
"resolved": "https://registry.npmjs.org/@grpc/proto-loader/-/proto-loader-0.7.12.tgz",
|
||||
"integrity": "sha512-DCVwMxqYzpUCiDMl7hQ384FqP4T3DbNpXU8pt681l3UWCip1WUiD5JrkImUwCB9a7f2cq4CUTmi5r/xIMRPY1Q==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"lodash.camelcase": "^4.3.0",
|
||||
"long": "^5.0.0",
|
||||
"protobufjs": "^7.2.5",
|
||||
"protobufjs": "^7.2.4",
|
||||
"yargs": "^17.7.2"
|
||||
},
|
||||
"bin": {
|
||||
@@ -878,14 +869,6 @@
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/@huggingface/jinja": {
|
||||
"version": "0.3.0",
|
||||
"resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz",
|
||||
"integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==",
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/cliui": {
|
||||
"version": "8.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz",
|
||||
@@ -1339,17 +1322,6 @@
|
||||
"tslib": "^2.5.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/eventstream-serde-node": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz",
|
||||
@@ -1363,17 +1335,6 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/eventstream-serde-node/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/eventstream-serde-universal": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz",
|
||||
@@ -1387,17 +1348,6 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/is-array-buffer": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz",
|
||||
@@ -1421,17 +1371,6 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/protocol-http/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/signature-v4": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz",
|
||||
@@ -1450,29 +1389,17 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/signature-v4/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"node_modules/@smithy/types": {
|
||||
"version": "2.10.1",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.10.1.tgz",
|
||||
"integrity": "sha512-hjQO+4ru4cQ58FluQvKKiyMsFg0A6iRpGm2kqdH8fniyNd2WyanoOsYJfMX/IFLuLxEoW6gnRkNZy1y6fUUhtA==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
"tslib": "^2.5.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/types": {
|
||||
"version": "3.3.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-3.3.0.tgz",
|
||||
"integrity": "sha512-IxvBBCTFDHbVoK7zIxqA1ZOdc4QfM5HM7rGleCuHi7L1wnKv5Pn69xXJQ9hgxH60ZVygH9/JG0jRgtUncE3QUA==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/util-buffer-from": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz",
|
||||
@@ -1508,17 +1435,6 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/util-middleware/node_modules/@smithy/types": {
|
||||
"version": "2.12.0",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
|
||||
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
|
||||
"dependencies": {
|
||||
"tslib": "^2.6.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@smithy/util-uri-escape": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz",
|
||||
@@ -1673,9 +1589,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@types/jsonwebtoken": {
|
||||
"version": "9.0.6",
|
||||
"resolved": "https://registry.npmjs.org/@types/jsonwebtoken/-/jsonwebtoken-9.0.6.tgz",
|
||||
"integrity": "sha512-/5hndP5dCjloafCXns6SZyESp3Ldq7YjH3zwzwczYnjxIT0Fqzk5ROSYVGfFyczIue7IUEj8hkvLbPoLQ18vQw==",
|
||||
"version": "9.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@types/jsonwebtoken/-/jsonwebtoken-9.0.2.tgz",
|
||||
"integrity": "sha512-drE6uz7QBKq1fYqqoFKTDRdFCPHd5TCub75BM+D+cMx7NU9hUz7SESLfC2fSCXVFMO5Yj8sOWHuGqPgjc+fz0Q==",
|
||||
"dependencies": {
|
||||
"@types/node": "*"
|
||||
}
|
||||
@@ -1974,11 +1890,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.7.4",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.4.tgz",
|
||||
"integrity": "sha512-DukmaFRnY6AzAALSH4J2M3k6PkaC+MfaAGdEERRWcC9q3/TWQwLpHR8ZRLKTdQ3aBDL64EdluRDjJqKw+BPZEw==",
|
||||
"version": "1.6.1",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
|
||||
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.6",
|
||||
"follow-redirects": "^1.15.0",
|
||||
"form-data": "^4.0.0",
|
||||
"proxy-from-env": "^1.1.0"
|
||||
}
|
||||
@@ -2126,11 +2042,11 @@
|
||||
}
|
||||
},
|
||||
"node_modules/braces": {
|
||||
"version": "3.0.3",
|
||||
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
|
||||
"integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
|
||||
"version": "3.0.2",
|
||||
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
|
||||
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
|
||||
"dependencies": {
|
||||
"fill-range": "^7.1.1"
|
||||
"fill-range": "^7.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
@@ -3080,14 +2996,24 @@
|
||||
"resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz",
|
||||
"integrity": "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g=="
|
||||
},
|
||||
"node_modules/farmhash-modern": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/farmhash-modern/-/farmhash-modern-1.1.0.tgz",
|
||||
"integrity": "sha512-6ypT4XfgqJk/F3Yuv4SX26I3doUjt0GTG4a+JgWxXQpxXzTBq8fPUeGHfcYMMDPHJHm3yPOSjaeBwBGAHWXCdA==",
|
||||
"node_modules/farmhash": {
|
||||
"version": "3.3.1",
|
||||
"resolved": "https://registry.npmjs.org/farmhash/-/farmhash-3.3.1.tgz",
|
||||
"integrity": "sha512-XUizHanzlr/v7suBr/o85HSakOoWh6HKXZjFYl5C2+Gj0f0rkw+XTUZzrd9odDsgI9G5tRUcF4wSbKaX04T0DQ==",
|
||||
"hasInstallScript": true,
|
||||
"dependencies": {
|
||||
"node-addon-api": "^5.1.0",
|
||||
"prebuild-install": "^7.1.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
"node": ">=10"
|
||||
}
|
||||
},
|
||||
"node_modules/farmhash/node_modules/node-addon-api": {
|
||||
"version": "5.1.0",
|
||||
"resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-5.1.0.tgz",
|
||||
"integrity": "sha512-eh0GgfEkpnoWDq+VY8OyvYhFEzBk6jIYbRKdIlyTiAXIVJ8PyBaKb0rp7oDtoddbdoHWhq8wwr+XZ81F1rpNdA=="
|
||||
},
|
||||
"node_modules/fast-copy": {
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/fast-copy/-/fast-copy-3.0.1.tgz",
|
||||
@@ -3125,9 +3051,9 @@
|
||||
"integrity": "sha512-VhXlQgj9ioXCqGstD37E/HBeqEGV/qOD/kmbVG8h5xKBYvM1L3lR1Zn4555cQ8GkYbJa8aJSipLPndE1k6zK2w=="
|
||||
},
|
||||
"node_modules/fast-xml-parser": {
|
||||
"version": "4.4.1",
|
||||
"resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.4.1.tgz",
|
||||
"integrity": "sha512-xkjOecfnKGkSsOwtZ5Pz7Us/T6mrbPQrq0nh+aCO5V9nk5NLWmasAHumTKjiPJPWANe+kAZ84Jc8ooJkzZ88Sw==",
|
||||
"version": "4.3.6",
|
||||
"resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.3.6.tgz",
|
||||
"integrity": "sha512-M2SovcRxD4+vC493Uc2GZVcZaj66CCJhWurC4viynVSTvrpErCShNcDz1lAho6n9REQKvL/ll4A4/fw6Y9z8nw==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "github",
|
||||
@@ -3190,9 +3116,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/fill-range": {
|
||||
"version": "7.1.1",
|
||||
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
|
||||
"integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
|
||||
"version": "7.0.1",
|
||||
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
|
||||
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
|
||||
"dependencies": {
|
||||
"to-regex-range": "^5.0.1"
|
||||
},
|
||||
@@ -3218,46 +3144,35 @@
|
||||
}
|
||||
},
|
||||
"node_modules/firebase-admin": {
|
||||
"version": "12.3.1",
|
||||
"resolved": "https://registry.npmjs.org/firebase-admin/-/firebase-admin-12.3.1.tgz",
|
||||
"integrity": "sha512-vEr3s3esl8nPIA9r/feDT4nzIXCfov1CyyCSpMQWp6x63Q104qke0MEGZlrHUZVROtl8FLus6niP/M9I1s4VBA==",
|
||||
"version": "12.1.0",
|
||||
"resolved": "https://registry.npmjs.org/firebase-admin/-/firebase-admin-12.1.0.tgz",
|
||||
"integrity": "sha512-bU7uPKMmIXAihWxntpY/Ma9zucn5y3ec+HQPqFQ/zcEfP9Avk9E/6D8u+yT/VwKHNZyg7yDVWOoJi73TIdR4Ww==",
|
||||
"dependencies": {
|
||||
"@fastify/busboy": "^3.0.0",
|
||||
"@fastify/busboy": "^2.1.0",
|
||||
"@firebase/database-compat": "^1.0.2",
|
||||
"@firebase/database-types": "^1.0.0",
|
||||
"@types/node": "^22.0.1",
|
||||
"farmhash-modern": "^1.1.0",
|
||||
"@types/node": "^20.10.3",
|
||||
"farmhash": "^3.3.0",
|
||||
"jsonwebtoken": "^9.0.0",
|
||||
"jwks-rsa": "^3.1.0",
|
||||
"jwks-rsa": "^3.0.1",
|
||||
"long": "^5.2.3",
|
||||
"node-forge": "^1.3.1",
|
||||
"uuid": "^10.0.0"
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@google-cloud/firestore": "^7.7.0",
|
||||
"@google-cloud/firestore": "^7.1.0",
|
||||
"@google-cloud/storage": "^7.7.0"
|
||||
}
|
||||
},
|
||||
"node_modules/firebase-admin/node_modules/@types/node": {
|
||||
"version": "22.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.2.0.tgz",
|
||||
"integrity": "sha512-bm6EG6/pCpkxDf/0gDNDdtDILMOHgaQBVOJGdwsqClnxA3xL6jtMv76rLBc006RVMWbmaf0xbmom4Z/5o2nRkQ==",
|
||||
"version": "20.12.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.12.7.tgz",
|
||||
"integrity": "sha512-wq0cICSkRLVaf3UGLMGItu/PtdY7oaXaI/RVU+xliKVOtRna3PRY57ZDfztpDL0n11vfymMUnXv8QwYCO7L1wg==",
|
||||
"dependencies": {
|
||||
"undici-types": "~6.13.0"
|
||||
}
|
||||
},
|
||||
"node_modules/firebase-admin/node_modules/uuid": {
|
||||
"version": "10.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz",
|
||||
"integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==",
|
||||
"funding": [
|
||||
"https://github.com/sponsors/broofa",
|
||||
"https://github.com/sponsors/ctavan"
|
||||
],
|
||||
"bin": {
|
||||
"uuid": "dist/bin/uuid"
|
||||
"undici-types": "~5.26.4"
|
||||
}
|
||||
},
|
||||
"node_modules/follow-redirects": {
|
||||
@@ -3485,21 +3400,21 @@
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax": {
|
||||
"version": "4.3.9",
|
||||
"resolved": "https://registry.npmjs.org/google-gax/-/google-gax-4.3.9.tgz",
|
||||
"integrity": "sha512-tcjQr7sXVGMdlvcG25wSv98ap1dtF4Z6mcV0rztGIddOcezw4YMb/uTXg72JPrLep+kXcVjaJjg6oo3KLf4itQ==",
|
||||
"version": "4.3.2",
|
||||
"resolved": "https://registry.npmjs.org/google-gax/-/google-gax-4.3.2.tgz",
|
||||
"integrity": "sha512-2mw7qgei2LPdtGrmd1zvxQviOcduTnsvAWYzCxhOWXK4IQKmQztHnDQwD0ApB690fBQJemFKSU7DnceAy3RLzw==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"@grpc/grpc-js": "^1.10.9",
|
||||
"@grpc/proto-loader": "^0.7.13",
|
||||
"@grpc/grpc-js": "~1.10.0",
|
||||
"@grpc/proto-loader": "^0.7.0",
|
||||
"@types/long": "^4.0.0",
|
||||
"abort-controller": "^3.0.0",
|
||||
"duplexify": "^4.0.0",
|
||||
"google-auth-library": "^9.3.0",
|
||||
"node-fetch": "^2.7.0",
|
||||
"node-fetch": "^2.6.1",
|
||||
"object-hash": "^3.0.0",
|
||||
"proto3-json-serializer": "^2.0.2",
|
||||
"protobufjs": "^7.3.2",
|
||||
"proto3-json-serializer": "^2.0.0",
|
||||
"protobufjs": "7.2.6",
|
||||
"retry-request": "^7.0.0",
|
||||
"uuid": "^9.0.1"
|
||||
},
|
||||
@@ -3520,9 +3435,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax/node_modules/debug": {
|
||||
"version": "4.3.6",
|
||||
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.6.tgz",
|
||||
"integrity": "sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg==",
|
||||
"version": "4.3.4",
|
||||
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
|
||||
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"ms": "2.1.2"
|
||||
@@ -3537,34 +3452,21 @@
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax/node_modules/gaxios": {
|
||||
"version": "6.7.0",
|
||||
"resolved": "https://registry.npmjs.org/gaxios/-/gaxios-6.7.0.tgz",
|
||||
"integrity": "sha512-DSrkyMTfAnAm4ks9Go20QGOcXEyW/NmZhvTYBU2rb4afBB393WIMQPWPEDMl/k8xqiNN9HYq2zao3oWXsdl2Tg==",
|
||||
"version": "6.5.0",
|
||||
"resolved": "https://registry.npmjs.org/gaxios/-/gaxios-6.5.0.tgz",
|
||||
"integrity": "sha512-R9QGdv8j4/dlNoQbX3hSaK/S0rkMijqjVvW3YM06CoBdbU/VdKd159j4hePpng0KuE6Lh6JJ7UdmVGJZFcAG1w==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"extend": "^3.0.2",
|
||||
"https-proxy-agent": "^7.0.1",
|
||||
"is-stream": "^2.0.0",
|
||||
"node-fetch": "^2.6.9",
|
||||
"uuid": "^10.0.0"
|
||||
"uuid": "^9.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax/node_modules/gaxios/node_modules/uuid": {
|
||||
"version": "10.0.0",
|
||||
"resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz",
|
||||
"integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==",
|
||||
"funding": [
|
||||
"https://github.com/sponsors/broofa",
|
||||
"https://github.com/sponsors/ctavan"
|
||||
],
|
||||
"optional": true,
|
||||
"bin": {
|
||||
"uuid": "dist/bin/uuid"
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax/node_modules/gcp-metadata": {
|
||||
"version": "6.1.0",
|
||||
"resolved": "https://registry.npmjs.org/gcp-metadata/-/gcp-metadata-6.1.0.tgz",
|
||||
@@ -3579,9 +3481,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax/node_modules/google-auth-library": {
|
||||
"version": "9.13.0",
|
||||
"resolved": "https://registry.npmjs.org/google-auth-library/-/google-auth-library-9.13.0.tgz",
|
||||
"integrity": "sha512-p9Y03Uzp/Igcs36zAaB0XTSwZ8Y0/tpYiz5KIde5By+H9DCVUSYtDWZu6aFXsWTqENMb8BD/pDT3hR8NVrPkfA==",
|
||||
"version": "9.8.0",
|
||||
"resolved": "https://registry.npmjs.org/google-auth-library/-/google-auth-library-9.8.0.tgz",
|
||||
"integrity": "sha512-TJJXFzMlVGRlIH27gYZ6XXyPf5Y3OItsKFfefsDAafNNywYRTkei83nEO29IrYj8GtdHWU78YnW+YZdaZaXIJA==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"base64-js": "^1.3.0",
|
||||
@@ -3609,9 +3511,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/google-gax/node_modules/https-proxy-agent": {
|
||||
"version": "7.0.5",
|
||||
"resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.5.tgz",
|
||||
"integrity": "sha512-1e4Wqeblerz+tMKPIq2EMGiiWW1dIjZOksyHWSUm1rmuvw/how9hBHZ38lAGj5ID4Ik6EdkOw7NmWPy6LAwalw==",
|
||||
"version": "7.0.4",
|
||||
"resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.4.tgz",
|
||||
"integrity": "sha512-wlwpilI7YdjSkWaQ/7omYBMTliDcmCN8OLihO6I9B86g06lMyAoqgoDpV0XqoaPOKj+0DIdAvnsWfyAAhmimcg==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"agent-base": "^7.0.2",
|
||||
@@ -4151,9 +4053,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/jose": {
|
||||
"version": "4.15.9",
|
||||
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz",
|
||||
"integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==",
|
||||
"version": "4.15.5",
|
||||
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.5.tgz",
|
||||
"integrity": "sha512-jc7BFxgKPKi94uOvEmzlSWFFe2+vASyXaKUpdQKatWAESU2MWjDfFf0fdfc83CDKcA5QecabZeNLyfhe3yKNkg==",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/panva"
|
||||
}
|
||||
@@ -4225,25 +4127,25 @@
|
||||
}
|
||||
},
|
||||
"node_modules/jwks-rsa": {
|
||||
"version": "3.1.0",
|
||||
"resolved": "https://registry.npmjs.org/jwks-rsa/-/jwks-rsa-3.1.0.tgz",
|
||||
"integrity": "sha512-v7nqlfezb9YfHHzYII3ef2a2j1XnGeSE/bK3WfumaYCqONAIstJbrEGapz4kadScZzEt7zYCN7bucj8C0Mv/Rg==",
|
||||
"version": "3.0.1",
|
||||
"resolved": "https://registry.npmjs.org/jwks-rsa/-/jwks-rsa-3.0.1.tgz",
|
||||
"integrity": "sha512-UUOZ0CVReK1QVU3rbi9bC7N5/le8ziUj0A2ef1Q0M7OPD2KvjEYizptqIxGIo6fSLYDkqBrazILS18tYuRc8gw==",
|
||||
"dependencies": {
|
||||
"@types/express": "^4.17.17",
|
||||
"@types/jsonwebtoken": "^9.0.2",
|
||||
"@types/express": "^4.17.14",
|
||||
"@types/jsonwebtoken": "^9.0.0",
|
||||
"debug": "^4.3.4",
|
||||
"jose": "^4.14.6",
|
||||
"jose": "^4.10.4",
|
||||
"limiter": "^1.1.5",
|
||||
"lru-memoizer": "^2.2.0"
|
||||
"lru-memoizer": "^2.1.4"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14"
|
||||
}
|
||||
},
|
||||
"node_modules/jwks-rsa/node_modules/debug": {
|
||||
"version": "4.3.6",
|
||||
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.6.tgz",
|
||||
"integrity": "sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg==",
|
||||
"version": "4.3.4",
|
||||
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
|
||||
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
|
||||
"dependencies": {
|
||||
"ms": "2.1.2"
|
||||
},
|
||||
@@ -4294,8 +4196,7 @@
|
||||
"node_modules/long": {
|
||||
"version": "5.2.3",
|
||||
"resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz",
|
||||
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==",
|
||||
"optional": true
|
||||
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q=="
|
||||
},
|
||||
"node_modules/long-timeout": {
|
||||
"version": "0.1.1",
|
||||
@@ -4314,14 +4215,28 @@
|
||||
}
|
||||
},
|
||||
"node_modules/lru-memoizer": {
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/lru-memoizer/-/lru-memoizer-2.3.0.tgz",
|
||||
"integrity": "sha512-GXn7gyHAMhO13WSKrIiNfztwxodVsP8IoZ3XfrJV4yH2x0/OeTO/FIaAHTY5YekdGgW94njfuKmyyt1E0mR6Ug==",
|
||||
"version": "2.2.0",
|
||||
"resolved": "https://registry.npmjs.org/lru-memoizer/-/lru-memoizer-2.2.0.tgz",
|
||||
"integrity": "sha512-QfOZ6jNkxCcM/BkIPnFsqDhtrazLRsghi9mBwFAzol5GCvj4EkFT899Za3+QwikCg5sRX8JstioBDwOxEyzaNw==",
|
||||
"dependencies": {
|
||||
"lodash.clonedeep": "^4.5.0",
|
||||
"lru-cache": "6.0.0"
|
||||
"lru-cache": "~4.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/lru-memoizer/node_modules/lru-cache": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-4.0.2.tgz",
|
||||
"integrity": "sha512-uQw9OqphAGiZhkuPlpFGmdTU2tEuhxTourM/19qGJrxBPHAr/f8BT1a0i/lOclESnGatdJG/UCkP9kZB/Lh1iw==",
|
||||
"dependencies": {
|
||||
"pseudomap": "^1.0.1",
|
||||
"yallist": "^2.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/lru-memoizer/node_modules/yallist": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/yallist/-/yallist-2.1.2.tgz",
|
||||
"integrity": "sha512-ncTzHV7NvsQZkYe1DW7cbDLm0YpzHmZF5r/iyP3ZnQtMiJ+pjzisCiMNI+Sj+xQF5pXhSHxSB3uDbsBTzY/c2A=="
|
||||
},
|
||||
"node_modules/luxon": {
|
||||
"version": "3.4.2",
|
||||
"resolved": "https://registry.npmjs.org/luxon/-/luxon-3.4.2.tgz",
|
||||
@@ -4580,9 +4495,9 @@
|
||||
"integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
|
||||
},
|
||||
"node_modules/node-fetch": {
|
||||
"version": "2.7.0",
|
||||
"resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz",
|
||||
"integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==",
|
||||
"version": "2.6.9",
|
||||
"resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.9.tgz",
|
||||
"integrity": "sha512-DJm/CJkZkRjKKj4Zi4BsKVZh3ValV5IR5s7LVZnW+6YMh0W1BfNA8XSs6DLMGYlId5F3KnA70uu2qepcR08Qqg==",
|
||||
"dependencies": {
|
||||
"whatwg-url": "^5.0.0"
|
||||
},
|
||||
@@ -5066,9 +4981,9 @@
|
||||
"integrity": "sha512-/1WZ8+VQjR6avWOgHeEPd7SDQmFQ1B5mC1eRXsCm5TarlNmx/wCsa5GEaxGm05BORRtyG/Ex/3xq3TuRvq57qg=="
|
||||
},
|
||||
"node_modules/proto3-json-serializer": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/proto3-json-serializer/-/proto3-json-serializer-2.0.2.tgz",
|
||||
"integrity": "sha512-SAzp/O4Yh02jGdRc+uIrGoe87dkN/XtwxfZ4ZyafJHymd79ozp5VG5nyZ7ygqPM5+cpLDjjGnYFUkngonyDPOQ==",
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/proto3-json-serializer/-/proto3-json-serializer-2.0.1.tgz",
|
||||
"integrity": "sha512-8awBvjO+FwkMd6gNoGFZyqkHZXCFd54CIYTb6De7dPaufGJ2XNW+QUNqbMr8MaAocMdb+KpsD4rxEOaTBDCffA==",
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
"protobufjs": "^7.2.5"
|
||||
@@ -5078,9 +4993,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/protobufjs": {
|
||||
"version": "7.3.2",
|
||||
"resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.3.2.tgz",
|
||||
"integrity": "sha512-RXyHaACeqXeqAKGLDl68rQKbmObRsTIn4TYVUUug1KfS47YWCo5MacGITEryugIgZqORCvJWEk4l449POg5Txg==",
|
||||
"version": "7.2.6",
|
||||
"resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.6.tgz",
|
||||
"integrity": "sha512-dgJaEDDL6x8ASUZ1YqWciTRrdOuYNzoOf27oHNfdyvKqHr5i0FV7FSLU+aIeFjyFgVxrpTOtQUi0BLLBymZaBw==",
|
||||
"hasInstallScript": true,
|
||||
"optional": true,
|
||||
"dependencies": {
|
||||
@@ -5334,9 +5249,9 @@
|
||||
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
|
||||
},
|
||||
"node_modules/sanitize-html": {
|
||||
"version": "2.13.0",
|
||||
"resolved": "https://registry.npmjs.org/sanitize-html/-/sanitize-html-2.13.0.tgz",
|
||||
"integrity": "sha512-Xff91Z+4Mz5QiNSLdLWwjgBDm5b1RU6xBT0+12rapjiaR7SwfRdjw8f+6Rir2MXKLrDicRFHdb51hGOAxmsUIA==",
|
||||
"version": "2.12.1",
|
||||
"resolved": "https://registry.npmjs.org/sanitize-html/-/sanitize-html-2.12.1.tgz",
|
||||
"integrity": "sha512-Plh+JAn0UVDpBRP/xEjsk+xDCoOvMBwQUf/K+/cBAVuTbtX8bj2VB7S1sL1dssVpykqp0/KPSesHrqXtokVBpA==",
|
||||
"dependencies": {
|
||||
"deepmerge": "^4.2.2",
|
||||
"escape-string-regexp": "^4.0.0",
|
||||
@@ -6012,9 +5927,9 @@
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/undici-types": {
|
||||
"version": "6.13.0",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.13.0.tgz",
|
||||
"integrity": "sha512-xtFJHudx8S2DSoujjMd1WeWvn7KKWFRESZTMeL1RptAYERu29D6jphMjjY+vn96jvN3kVPDNxU/E13VTaXj6jg=="
|
||||
"version": "5.26.5",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
|
||||
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA=="
|
||||
},
|
||||
"node_modules/unpipe": {
|
||||
"version": "1.0.0",
|
||||
|
||||
+5
-7
@@ -20,14 +20,14 @@
|
||||
"dependencies": {
|
||||
"@anthropic-ai/tokenizer": "^0.0.4",
|
||||
"@aws-crypto/sha256-js": "^5.2.0",
|
||||
"@huggingface/jinja": "^0.3.0",
|
||||
"@node-rs/argon2": "^1.8.3",
|
||||
"@smithy/eventstream-codec": "^2.1.3",
|
||||
"@smithy/eventstream-serde-node": "^2.1.3",
|
||||
"@smithy/protocol-http": "^3.2.1",
|
||||
"@smithy/signature-v4": "^2.1.3",
|
||||
"@smithy/types": "^2.10.1",
|
||||
"@smithy/util-utf8": "^2.1.1",
|
||||
"axios": "^1.7.4",
|
||||
"axios": "^1.3.5",
|
||||
"better-sqlite3": "^10.0.0",
|
||||
"check-disk-space": "^3.4.0",
|
||||
"cookie-parser": "^1.4.6",
|
||||
@@ -38,7 +38,7 @@
|
||||
"ejs": "^3.1.10",
|
||||
"express": "^4.18.2",
|
||||
"express-session": "^1.17.3",
|
||||
"firebase-admin": "^12.3.1",
|
||||
"firebase-admin": "^12.1.0",
|
||||
"glob": "^10.3.12",
|
||||
"googleapis": "^122.0.0",
|
||||
"http-proxy-middleware": "^3.0.0-beta.1",
|
||||
@@ -48,7 +48,7 @@
|
||||
"node-schedule": "^2.1.1",
|
||||
"pino": "^8.11.0",
|
||||
"pino-http": "^8.3.3",
|
||||
"sanitize-html": "^2.13.0",
|
||||
"sanitize-html": "2.12.1",
|
||||
"sharp": "^0.32.6",
|
||||
"showdown": "^2.1.0",
|
||||
"source-map-support": "^0.5.21",
|
||||
@@ -60,7 +60,6 @@
|
||||
"zod-error": "^1.5.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@smithy/types": "^3.3.0",
|
||||
"@types/better-sqlite3": "^7.6.10",
|
||||
"@types/cookie-parser": "^1.4.3",
|
||||
"@types/cors": "^2.8.13",
|
||||
@@ -84,8 +83,7 @@
|
||||
"typescript": "^5.4.2"
|
||||
},
|
||||
"overrides": {
|
||||
"braces": "^3.0.3",
|
||||
"fast-xml-parser": "^4.4.1",
|
||||
"postcss": "^8.4.31",
|
||||
"follow-redirects": "^1.15.4"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,39 +230,6 @@ Content-Type: application/json
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / GCP Claude -- Native Completion
|
||||
POST {{proxy-host}}/proxy/gcp/claude/v1/complete
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
anthropic-version: 2023-01-01
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-v2",
|
||||
"max_tokens_to_sample": 10,
|
||||
"temperature": 0,
|
||||
"stream": true,
|
||||
"prompt": "What is genshin impact\n\n:Assistant:"
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / GCP Claude -- OpenAI-to-Anthropic API Translation
|
||||
POST {{proxy-host}}/proxy/gcp/claude/chat/completions
|
||||
Authorization: Bearer {{proxy-key}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"max_tokens": 50,
|
||||
"stream": true,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is genshin impact?"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
# @name Proxy / Azure OpenAI -- Native Chat Completions
|
||||
POST {{proxy-host}}/proxy/azure/openai/chat/completions
|
||||
|
||||
@@ -51,8 +51,6 @@ function getRandomModelFamily() {
|
||||
"mistral-large",
|
||||
"aws-claude",
|
||||
"aws-claude-opus",
|
||||
"gcp-claude",
|
||||
"gcp-claude-opus",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
// uses the aws sdk to sign a request, then uses axios to send it to the bedrock REST API manually
|
||||
import axios from "axios";
|
||||
import { Sha256 } from "@aws-crypto/sha256-js";
|
||||
import { SignatureV4 } from "@smithy/signature-v4";
|
||||
import { HttpRequest } from "@smithy/protocol-http";
|
||||
|
||||
const AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID!;
|
||||
const AWS_SECRET_ACCESS_KEY = process.env.AWS_SECRET_ACCESS_KEY!;
|
||||
|
||||
// Copied from amazon bedrock docs
|
||||
|
||||
// List models
|
||||
// ListFoundationModels
|
||||
// Service: Amazon Bedrock
|
||||
// List of Bedrock foundation models that you can use. For more information, see Foundation models in the
|
||||
// Bedrock User Guide.
|
||||
// Request Syntax
|
||||
// GET /foundation-models?
|
||||
// byCustomizationType=byCustomizationType&byInferenceType=byInferenceType&byOutputModality=byOutputModality&byProvider=byProvider
|
||||
// HTTP/1.1
|
||||
// URI Request Parameters
|
||||
// The request uses the following URI parameters.
|
||||
// byCustomizationType (p. 38)
|
||||
// List by customization type.
|
||||
// Valid Values: FINE_TUNING
|
||||
// byInferenceType (p. 38)
|
||||
// List by inference type.
|
||||
// Valid Values: ON_DEMAND | PROVISIONED
|
||||
// byOutputModality (p. 38)
|
||||
// List by output modality type.
|
||||
// Valid Values: TEXT | IMAGE | EMBEDDING
|
||||
// byProvider (p. 38)
|
||||
// A Bedrock model provider.
|
||||
// Pattern: ^[a-z0-9-]{1,63}$
|
||||
// Request Body
|
||||
// The request does not have a request body
|
||||
|
||||
// Run inference on a text model
|
||||
// Send an invoke request to run inference on a Titan Text G1 - Express model. We set the accept
|
||||
// parameter to accept any content type in the response.
|
||||
// POST https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke
|
||||
// -H accept: */*
|
||||
// -H content-type: application/json
|
||||
// Payload
|
||||
// {"inputText": "Hello world"}
|
||||
// Example response
|
||||
// Response for the above request.
|
||||
// -H content-type: application/json
|
||||
// Payload
|
||||
// <the model response>
|
||||
|
||||
const AMZ_REGION = "us-east-1";
|
||||
const AMZ_HOST = "invoke-bedrock.us-east-1.amazonaws.com";
|
||||
|
||||
async function listModels() {
|
||||
const httpRequest = new HttpRequest({
|
||||
method: "GET",
|
||||
protocol: "https:",
|
||||
hostname: AMZ_HOST,
|
||||
path: "/foundation-models",
|
||||
headers: { ["Host"]: AMZ_HOST },
|
||||
});
|
||||
|
||||
const signedRequest = await signRequest(httpRequest);
|
||||
const response = await axios.get(
|
||||
`https://${signedRequest.hostname}${signedRequest.path}`,
|
||||
{ headers: signedRequest.headers }
|
||||
);
|
||||
console.log(response.data);
|
||||
}
|
||||
|
||||
async function invokeModel() {
|
||||
const model = "anthropic.claude-v1";
|
||||
const httpRequest = new HttpRequest({
|
||||
method: "POST",
|
||||
protocol: "https:",
|
||||
hostname: AMZ_HOST,
|
||||
path: `/model/${model}/invoke`,
|
||||
headers: {
|
||||
["Host"]: AMZ_HOST,
|
||||
["accept"]: "*/*",
|
||||
["content-type"]: "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
temperature: 0.5,
|
||||
prompt: "\n\nHuman:Hello world\n\nAssistant:",
|
||||
max_tokens_to_sample: 10,
|
||||
}),
|
||||
});
|
||||
console.log("httpRequest", httpRequest);
|
||||
|
||||
const signedRequest = await signRequest(httpRequest);
|
||||
const response = await axios.post(
|
||||
`https://${signedRequest.hostname}${signedRequest.path}`,
|
||||
signedRequest.body,
|
||||
{ headers: signedRequest.headers }
|
||||
);
|
||||
console.log(response.status);
|
||||
console.log(response.headers);
|
||||
console.log(response.data);
|
||||
console.log("full url", response.request.res.responseUrl);
|
||||
}
|
||||
|
||||
async function signRequest(request: HttpRequest) {
|
||||
const signer = new SignatureV4({
|
||||
sha256: Sha256,
|
||||
credentials: {
|
||||
accessKeyId: AWS_ACCESS_KEY_ID,
|
||||
secretAccessKey: AWS_SECRET_ACCESS_KEY,
|
||||
},
|
||||
region: AMZ_REGION,
|
||||
service: "bedrock",
|
||||
});
|
||||
return await signer.sign(request, { signingDate: new Date() });
|
||||
}
|
||||
|
||||
// listModels();
|
||||
// invokeModel();
|
||||
+13
-14
@@ -17,7 +17,7 @@ import {
|
||||
} from "../../shared/users/schema";
|
||||
import { getLastNImages } from "../../shared/file-storage/image-history";
|
||||
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
|
||||
import { invalidatePowChallenges } from "../../user/web/pow-captcha";
|
||||
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
|
||||
|
||||
const router = Router();
|
||||
|
||||
@@ -268,13 +268,7 @@ router.post("/maintenance", (req, res) => {
|
||||
let flash = { type: "", message: "" };
|
||||
switch (action) {
|
||||
case "recheck": {
|
||||
const checkable: LLMService[] = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"aws",
|
||||
"gcp",
|
||||
"azure",
|
||||
];
|
||||
const checkable: LLMService[] = ["openai", "anthropic", "aws", "azure"];
|
||||
checkable.forEach((s) => keyPool.recheck(s));
|
||||
const keyCount = keyPool
|
||||
.list()
|
||||
@@ -323,7 +317,7 @@ router.post("/maintenance", (req, res) => {
|
||||
user.disabledReason = "Admin forced expiration.";
|
||||
userStore.upsertUser(user);
|
||||
});
|
||||
invalidatePowChallenges();
|
||||
invalidatePowHmacKey();
|
||||
flash.type = "success";
|
||||
flash.message = `${temps.length} temporary users marked for expiration.`;
|
||||
break;
|
||||
@@ -348,15 +342,20 @@ router.post("/maintenance", (req, res) => {
|
||||
throw new HttpError(400, "Invalid difficulty" + selected);
|
||||
}
|
||||
config.powDifficultyLevel = selected;
|
||||
invalidatePowChallenges();
|
||||
break;
|
||||
}
|
||||
case "generateTempIpReport": {
|
||||
const tempUsers = userStore
|
||||
.getUsers()
|
||||
.filter((u) => u.type === "temporary");
|
||||
const ipv4RangeMap = new Map<string, Set<string>>();
|
||||
const ipv6RangeMap = new Map<string, Set<string>>();
|
||||
const ipv4RangeMap: Map<string, Set<string>> = new Map<
|
||||
string,
|
||||
Set<string>
|
||||
>();
|
||||
const ipv6RangeMap: Map<string, Set<string>> = new Map<
|
||||
string,
|
||||
Set<string>
|
||||
>();
|
||||
|
||||
tempUsers.forEach((u) => {
|
||||
u.ip.forEach((ip) => {
|
||||
@@ -366,14 +365,14 @@ router.post("/maintenance", (req, res) => {
|
||||
const subnet =
|
||||
parsed.toNormalizedString().split(".").slice(0, 3).join(".") +
|
||||
".0/24";
|
||||
const userSet = ipv4RangeMap.get(subnet) || new Set();
|
||||
const userSet = ipv4RangeMap.get(subnet) || new Set<string>();
|
||||
userSet.add(u.token);
|
||||
ipv4RangeMap.set(subnet, userSet);
|
||||
} else if (parsed.kind() === "ipv6") {
|
||||
const subnet =
|
||||
parsed.toNormalizedString().split(":").slice(0, 4).join(":") +
|
||||
"::/48";
|
||||
const userSet = ipv6RangeMap.get(subnet) || new Set();
|
||||
const userSet = ipv6RangeMap.get(subnet) || new Set<string>();
|
||||
userSet.add(u.token);
|
||||
ipv6RangeMap.set(subnet, userSet);
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@
|
||||
<legend>Bulk Quota Management</legend>
|
||||
<p>
|
||||
<button id="refresh-quotas" type="button" onclick="submitForm('resetQuotas')">Refresh All Quotas</button>
|
||||
Immediately refreshes all users' quotas by the configured amounts.
|
||||
Resets all users' quotas to the values set in the <code>TOKEN_QUOTA_*</code> environment variables.
|
||||
</p>
|
||||
<p>
|
||||
<button id="clear-token-counts" type="button" onclick="submitForm('resetCounts')">Clear All Token Counts</button>
|
||||
|
||||
@@ -101,10 +101,6 @@
|
||||
<% ["nickname", "type", "disabledAt", "disabledReason", "maxIps", "adminNote"].forEach(function (key) { %>
|
||||
<input type="hidden" name="<%- key %>" value="<%- user[key] %>" />
|
||||
<% }); %>
|
||||
<!-- tokenRefresh_ keys are dynamically generated -->
|
||||
<% Object.entries(quota).forEach(([family]) => { %>
|
||||
<input type="hidden" name="tokenRefresh_<%- family %>" value="<%- user.tokenRefresh[family] || quota[family] %>" />
|
||||
<% }); %>
|
||||
</form>
|
||||
|
||||
<h3>Quota Information</h3>
|
||||
@@ -115,7 +111,7 @@
|
||||
<button type="submit" class="btn btn-primary">Refresh Quotas for User</button>
|
||||
</form>
|
||||
<% } %>
|
||||
<%- include("partials/shared_quota-info", { quota, user, showRefreshEdit: true }) %>
|
||||
<%- include("partials/shared_quota-info", { quota, user }) %>
|
||||
|
||||
<p><a href="/admin/manage/list-users">Back to User List</a></p>
|
||||
|
||||
@@ -126,25 +122,18 @@
|
||||
const token = a.dataset.token;
|
||||
const field = a.dataset.field;
|
||||
const existingValue = document.querySelector(`#current-values input[name=${field}]`).value;
|
||||
|
||||
let value = prompt(`Enter new value for '${field}':`, existingValue);
|
||||
let value = prompt(`Enter new value for '${field}'':`, existingValue);
|
||||
if (value !== null) {
|
||||
if (value === "") {
|
||||
value = null;
|
||||
}
|
||||
|
||||
const payload = { _csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content") };
|
||||
if (field.startsWith("tokenRefresh_")) {
|
||||
const family = field.slice("tokenRefresh_".length);
|
||||
payload.tokenRefresh = { [family]: Number(value) };
|
||||
} else {
|
||||
payload[field] = value;
|
||||
}
|
||||
|
||||
fetch(`/admin/manage/edit-user/${token}`, {
|
||||
method: "POST",
|
||||
credentials: "same-origin",
|
||||
body: JSON.stringify(payload),
|
||||
body: JSON.stringify({
|
||||
[field]: value,
|
||||
_csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content"),
|
||||
}),
|
||||
headers: { "Content-Type": "application/json", Accept: "application/json" },
|
||||
})
|
||||
.then((res) => Promise.all([res.ok, res.json()]))
|
||||
@@ -152,7 +141,9 @@
|
||||
const url = new URL(window.location.href);
|
||||
const params = new URLSearchParams();
|
||||
if (!ok) {
|
||||
alert(`Failed to edit user: ${json.message}`);
|
||||
params.set("flash", `error: ${json.error.message}`);
|
||||
} else {
|
||||
params.set("flash", `success: User's ${field} updated.`);
|
||||
}
|
||||
url.search = params.toString();
|
||||
window.location.assign(url);
|
||||
|
||||
+27
-25
@@ -45,13 +45,6 @@ type Config = {
|
||||
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
|
||||
*/
|
||||
awsCredentials?: string;
|
||||
/**
|
||||
* Comma-delimited list of GCP credentials. Each credential item should be a
|
||||
* colon-delimited list of access key, secret key, and GCP region.
|
||||
*
|
||||
* @example `GCP_CREDENTIALS=project1:1@1.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,project2:2@2.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----`
|
||||
*/
|
||||
gcpCredentials?: string;
|
||||
/**
|
||||
* Comma-delimited list of Azure OpenAI credentials. Each credential item
|
||||
* should be a colon-delimited list of Azure resource name, deployment ID, and
|
||||
@@ -356,7 +349,7 @@ type Config = {
|
||||
*
|
||||
* Defaults to no services, meaning image prompts are disabled. Use a comma-
|
||||
* separated list. Available services are:
|
||||
* openai,anthropic,google-ai,mistral-ai,aws,gcp,azure
|
||||
* openai,anthropic,google-ai,mistral-ai,aws,azure
|
||||
*/
|
||||
allowedVisionServices: LLMService[];
|
||||
/**
|
||||
@@ -390,7 +383,6 @@ export const config: Config = {
|
||||
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
|
||||
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
|
||||
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
||||
gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""),
|
||||
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
|
||||
@@ -415,23 +407,40 @@ export const config: Config = {
|
||||
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
|
||||
textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
|
||||
imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
|
||||
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 32768),
|
||||
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
|
||||
maxContextTokensAnthropic: getEnvWithDefault(
|
||||
"MAX_CONTEXT_TOKENS_ANTHROPIC",
|
||||
32768
|
||||
0
|
||||
),
|
||||
maxOutputTokensOpenAI: getEnvWithDefault(
|
||||
["MAX_OUTPUT_TOKENS_OPENAI", "MAX_OUTPUT_TOKENS"],
|
||||
1024
|
||||
400
|
||||
),
|
||||
maxOutputTokensAnthropic: getEnvWithDefault(
|
||||
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
|
||||
1024
|
||||
),
|
||||
allowedModelFamilies: getEnvWithDefault(
|
||||
"ALLOWED_MODEL_FAMILIES",
|
||||
getDefaultModelFamilies()
|
||||
400
|
||||
),
|
||||
allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
|
||||
"turbo",
|
||||
"gpt4",
|
||||
"gpt4-32k",
|
||||
"gpt4-turbo",
|
||||
"gpt4o",
|
||||
"claude",
|
||||
"claude-opus",
|
||||
"gemini-pro",
|
||||
"mistral-tiny",
|
||||
"mistral-small",
|
||||
"mistral-medium",
|
||||
"mistral-large",
|
||||
"aws-claude",
|
||||
"aws-claude-opus",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
"azure-gpt4-turbo",
|
||||
"azure-gpt4o",
|
||||
]),
|
||||
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
|
||||
rejectMessage: getEnvWithDefault(
|
||||
"REJECT_MESSAGE",
|
||||
@@ -500,7 +509,6 @@ function generateSigningKey() {
|
||||
config.googleAIKey,
|
||||
config.mistralAIKey,
|
||||
config.awsCredentials,
|
||||
config.gcpCredentials,
|
||||
config.azureCredentials,
|
||||
];
|
||||
if (secrets.filter((s) => s).length === 0) {
|
||||
@@ -519,7 +527,7 @@ function generateSigningKey() {
|
||||
}
|
||||
|
||||
const signingKey = generateSigningKey();
|
||||
export const SECRET_SIGNING_KEY = signingKey;
|
||||
export const COOKIE_SECRET = signingKey;
|
||||
|
||||
export async function assertConfigIsValid() {
|
||||
if (process.env.MODEL_RATE_LIMIT !== undefined) {
|
||||
@@ -638,7 +646,6 @@ export const OMITTED_KEYS = [
|
||||
"googleAIKey",
|
||||
"mistralAIKey",
|
||||
"awsCredentials",
|
||||
"gcpCredentials",
|
||||
"azureCredentials",
|
||||
"proxyKey",
|
||||
"adminKey",
|
||||
@@ -729,7 +736,6 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
|
||||
"ANTHROPIC_KEY",
|
||||
"GOOGLE_AI_KEY",
|
||||
"AWS_CREDENTIALS",
|
||||
"GCP_CREDENTIALS",
|
||||
"AZURE_CREDENTIALS",
|
||||
].includes(String(env))
|
||||
) {
|
||||
@@ -780,7 +786,3 @@ function parseCsv(val: string): string[] {
|
||||
const matches = val.match(regex) || [];
|
||||
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
|
||||
}
|
||||
|
||||
function getDefaultModelFamilies(): ModelFamily[] {
|
||||
return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
|
||||
}
|
||||
|
||||
+3
-11
@@ -12,7 +12,7 @@ import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf";
|
||||
|
||||
const INFO_PAGE_TTL = 2000;
|
||||
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
||||
turbo: "GPT-4o Mini / 3.5 Turbo",
|
||||
turbo: "GPT-3.5 Turbo",
|
||||
gpt4: "GPT-4",
|
||||
"gpt4-32k": "GPT-4 32k",
|
||||
"gpt4-turbo": "GPT-4 Turbo",
|
||||
@@ -20,21 +20,13 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
||||
"dall-e": "DALL-E",
|
||||
claude: "Claude (Sonnet)",
|
||||
"claude-opus": "Claude (Opus)",
|
||||
"gemini-flash": "Gemini Flash",
|
||||
"gemini-pro": "Gemini Pro",
|
||||
"gemini-ultra": "Gemini Ultra",
|
||||
"mistral-tiny": "Mistral 7B",
|
||||
"mistral-small": "Mistral Nemo",
|
||||
"mistral-small": "Mixtral Small", // Originally 8x7B, but that now refers to the older open-weight version. Mixtral Small is a newer closed-weight update to the 8x7B model.
|
||||
"mistral-medium": "Mistral Medium",
|
||||
"mistral-large": "Mistral Large",
|
||||
"aws-claude": "AWS Claude (Sonnet)",
|
||||
"aws-claude-opus": "AWS Claude (Opus)",
|
||||
"aws-mistral-tiny": "AWS Mistral 7B",
|
||||
"aws-mistral-small": "AWS Mistral Nemo",
|
||||
"aws-mistral-medium": "AWS Mistral Medium",
|
||||
"aws-mistral-large": "AWS Mistral Large",
|
||||
"gcp-claude": "GCP Claude (Sonnet)",
|
||||
"gcp-claude-opus": "GCP Claude (Opus)",
|
||||
"azure-turbo": "Azure GPT-3.5 Turbo",
|
||||
"azure-gpt4": "Azure GPT-4",
|
||||
"azure-gpt4-32k": "Azure GPT-4 32k",
|
||||
@@ -45,7 +37,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
|
||||
|
||||
const converter = new showdown.Converter();
|
||||
const customGreeting = fs.existsSync("greeting.md")
|
||||
? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>`
|
||||
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
|
||||
: "";
|
||||
let infoPageHtml: string | undefined;
|
||||
let infoPageLastUpdated = 0;
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
|
||||
export function addV1(req: Request, res: Response, next: NextFunction) {
|
||||
// Clients don't consistently use the /v1 prefix so we'll add it for them.
|
||||
if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
|
||||
req.url = `/v1${req.url}`;
|
||||
}
|
||||
next();
|
||||
}
|
||||
+68
-32
@@ -46,7 +46,6 @@ const getModelsResponse = () => {
|
||||
"claude-3-haiku-20240307",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-sonnet-20240229",
|
||||
"claude-3-5-sonnet-20240620",
|
||||
];
|
||||
|
||||
const models = claudeVariants.map((id) => ({
|
||||
@@ -70,7 +69,7 @@ const handleModelRequest: RequestHandler = (_req, res) => {
|
||||
};
|
||||
|
||||
/** Only used for non-streaming requests. */
|
||||
const anthropicBlockingResponseHandler: ProxyResHandlerWithBody = async (
|
||||
const anthropicResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
@@ -129,7 +128,7 @@ export function transformAnthropicChatResponseToAnthropicText(
|
||||
* is only used for non-streaming requests as streaming requests are handled
|
||||
* on-the-fly.
|
||||
*/
|
||||
export function transformAnthropicTextResponseToOpenAI(
|
||||
function transformAnthropicTextResponseToOpenAI(
|
||||
anthropicBody: Record<string, any>,
|
||||
req: Request
|
||||
): Record<string, any> {
|
||||
@@ -179,28 +178,6 @@ export function transformAnthropicChatResponseToOpenAI(
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* If a client using the OpenAI compatibility endpoint requests an actual OpenAI
|
||||
* model, reassigns it to Claude 3 Sonnet.
|
||||
*/
|
||||
function maybeReassignModel(req: Request) {
|
||||
const model = req.body.model;
|
||||
if (!model.startsWith("gpt-")) return;
|
||||
req.body.model = "claude-3-sonnet-20240229";
|
||||
}
|
||||
|
||||
/**
|
||||
* If client requests more than 4096 output tokens the request must have a
|
||||
* particular version header.
|
||||
* https://docs.anthropic.com/en/release-notes/api#july-15th-2024
|
||||
*/
|
||||
function setAnthropicBetaHeader(req: Request) {
|
||||
const { max_tokens_to_sample } = req.body;
|
||||
if (max_tokens_to_sample > 4096) {
|
||||
req.headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15";
|
||||
}
|
||||
}
|
||||
|
||||
const anthropicProxy = createQueueMiddleware({
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "https://api.anthropic.com",
|
||||
@@ -211,7 +188,7 @@ const anthropicProxy = createQueueMiddleware({
|
||||
proxyReq: createOnProxyReqHandler({
|
||||
pipeline: [addKey, addAnthropicPreamble, finalizeBody],
|
||||
}),
|
||||
proxyRes: createOnProxyResHandler([anthropicBlockingResponseHandler]),
|
||||
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
// Abusing pathFilter to rewrite the paths dynamically.
|
||||
@@ -235,11 +212,6 @@ const anthropicProxy = createQueueMiddleware({
|
||||
}),
|
||||
});
|
||||
|
||||
const nativeAnthropicChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "anthropic" },
|
||||
{ afterTransform: [setAnthropicBetaHeader] }
|
||||
);
|
||||
|
||||
const nativeTextPreprocessor = createPreprocessorMiddleware({
|
||||
inApi: "anthropic-text",
|
||||
outApi: "anthropic-text",
|
||||
@@ -295,7 +267,11 @@ anthropicRouter.get("/v1/models", handleModelRequest);
|
||||
anthropicRouter.post(
|
||||
"/v1/messages",
|
||||
ipLimiter,
|
||||
nativeAnthropicChatPreprocessor,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "anthropic-chat",
|
||||
outApi: "anthropic-chat",
|
||||
service: "anthropic",
|
||||
}),
|
||||
anthropicProxy
|
||||
);
|
||||
// Anthropic text completion endpoint. Translates to Anthropic chat completion
|
||||
@@ -315,5 +291,65 @@ anthropicRouter.post(
|
||||
preprocessOpenAICompatRequest,
|
||||
anthropicProxy
|
||||
);
|
||||
// Temporarily force Anthropic Text to Anthropic Chat for frontends which do not
|
||||
// yet support the new model. Forces claude-3. Will be removed once common
|
||||
// frontends have been updated.
|
||||
anthropicRouter.post(
|
||||
"/v1/:type(sonnet|opus)/:action(complete|messages)",
|
||||
ipLimiter,
|
||||
handleAnthropicTextCompatRequest,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "anthropic-text",
|
||||
outApi: "anthropic-chat",
|
||||
service: "anthropic",
|
||||
}),
|
||||
anthropicProxy
|
||||
);
|
||||
|
||||
function handleAnthropicTextCompatRequest(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: any
|
||||
) {
|
||||
const type = req.params.type;
|
||||
const action = req.params.action;
|
||||
const alreadyInChatFormat = Boolean(req.body.messages);
|
||||
const compatModel = `claude-3-${type}-20240229`;
|
||||
req.log.info(
|
||||
{ type, inputModel: req.body.model, compatModel, alreadyInChatFormat },
|
||||
"Handling Anthropic compatibility request"
|
||||
);
|
||||
|
||||
if (action === "messages" || alreadyInChatFormat) {
|
||||
return sendErrorToClient({
|
||||
req,
|
||||
res,
|
||||
options: {
|
||||
title: "Unnecessary usage of compatibility endpoint",
|
||||
message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/anthropic\` proxy endpoint instead.`,
|
||||
format: "unknown",
|
||||
statusCode: 400,
|
||||
reqId: req.id,
|
||||
obj: {
|
||||
requested_endpoint: "/anthropic/" + type,
|
||||
correct_endpoint: "/anthropic",
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
req.body.model = compatModel;
|
||||
next();
|
||||
}
|
||||
|
||||
/**
|
||||
* If a client using the OpenAI compatibility endpoint requests an actual OpenAI
|
||||
* model, reassigns it to Claude 3 Sonnet.
|
||||
*/
|
||||
function maybeReassignModel(req: Request) {
|
||||
const model = req.body.model;
|
||||
if (!model.startsWith("gpt-")) return;
|
||||
req.body.model = "claude-3-sonnet-20240229";
|
||||
}
|
||||
|
||||
export const anthropic = anthropicRouter;
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { v4 } from "uuid";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
createPreprocessorMiddleware,
|
||||
signAwsRequest,
|
||||
finalizeSignedRequest,
|
||||
createOnProxyReqHandler,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
ProxyResHandlerWithBody,
|
||||
createOnProxyResHandler,
|
||||
} from "./middleware/response";
|
||||
import {
|
||||
transformAnthropicChatResponseToAnthropicText,
|
||||
transformAnthropicChatResponseToOpenAI,
|
||||
} from "./anthropic";
|
||||
|
||||
/** Only used for non-streaming requests. */
|
||||
const awsResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
let newBody = body;
|
||||
switch (`${req.inboundApi}<-${req.outboundApi}`) {
|
||||
case "openai<-anthropic-text":
|
||||
req.log.info("Transforming Anthropic Text back to OpenAI format");
|
||||
newBody = transformAwsTextResponseToOpenAI(body, req);
|
||||
break;
|
||||
case "openai<-anthropic-chat":
|
||||
req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
|
||||
newBody = transformAnthropicChatResponseToOpenAI(body);
|
||||
break;
|
||||
case "anthropic-text<-anthropic-chat":
|
||||
req.log.info("Transforming AWS Anthropic Chat back to Text format");
|
||||
newBody = transformAnthropicChatResponseToAnthropicText(body);
|
||||
break;
|
||||
}
|
||||
|
||||
// AWS does not always confirm the model in the response, so we have to add it
|
||||
if (!newBody.model && req.body.model) {
|
||||
newBody.model = req.body.model;
|
||||
}
|
||||
|
||||
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||
};
|
||||
|
||||
/**
|
||||
* Transforms a model response from the Anthropic API to match those from the
|
||||
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
|
||||
* is only used for non-streaming requests as streaming requests are handled
|
||||
* on-the-fly.
|
||||
*/
|
||||
function transformAwsTextResponseToOpenAI(
|
||||
awsBody: Record<string, any>,
|
||||
req: Request
|
||||
): Record<string, any> {
|
||||
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
|
||||
return {
|
||||
id: "aws-" + v4(),
|
||||
object: "chat.completion",
|
||||
created: Date.now(),
|
||||
model: req.body.model,
|
||||
usage: {
|
||||
prompt_tokens: req.promptTokens,
|
||||
completion_tokens: req.outputTokens,
|
||||
total_tokens: totalTokens,
|
||||
},
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: awsBody.completion?.trim(),
|
||||
},
|
||||
finish_reason: awsBody.stop_reason,
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const awsClaudeProxy = createQueueMiddleware({
|
||||
beforeProxy: signAwsRequest,
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "bad-target-will-be-rewritten",
|
||||
router: ({ signedRequest }) => {
|
||||
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||
return `${signedRequest.protocol}//${signedRequest.hostname}`;
|
||||
},
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
|
||||
proxyRes: createOnProxyResHandler([awsResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
const nativeTextPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
const textToChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
/**
|
||||
* Routes text completion prompts to aws anthropic-chat if they need translation
|
||||
* (claude-3 based models do not support the old text completion endpoint).
|
||||
*/
|
||||
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
|
||||
if (req.body.model?.includes("claude-3")) {
|
||||
textToChatPreprocessor(req, res, next);
|
||||
} else {
|
||||
nativeTextPreprocessor(req, res, next);
|
||||
}
|
||||
};
|
||||
|
||||
const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "anthropic-text", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "anthropic-chat", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
/**
|
||||
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
|
||||
* or the new Claude chat completion endpoint, based on the requested model.
|
||||
*/
|
||||
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
|
||||
if (req.body.model?.includes("claude-3")) {
|
||||
oaiToAwsChatPreprocessor(req, res, next);
|
||||
} else {
|
||||
oaiToAwsTextPreprocessor(req, res, next);
|
||||
}
|
||||
};
|
||||
|
||||
const awsClaudeRouter = Router();
|
||||
// Native(ish) Anthropic text completion endpoint.
|
||||
awsClaudeRouter.post(
|
||||
"/v1/complete",
|
||||
ipLimiter,
|
||||
preprocessAwsTextRequest,
|
||||
awsClaudeProxy
|
||||
);
|
||||
// Native Anthropic chat completion endpoint.
|
||||
awsClaudeRouter.post(
|
||||
"/v1/messages",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
awsClaudeProxy
|
||||
);
|
||||
|
||||
// OpenAI-to-AWS Anthropic compatibility endpoint.
|
||||
awsClaudeRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
preprocessOpenAICompatRequest,
|
||||
awsClaudeProxy
|
||||
);
|
||||
|
||||
/**
|
||||
* Tries to deal with:
|
||||
* - frontends sending AWS model names even when they want to use the OpenAI-
|
||||
* compatible endpoint
|
||||
* - frontends sending Anthropic model names that AWS doesn't recognize
|
||||
* - frontends sending OpenAI model names because they expect the proxy to
|
||||
* translate them
|
||||
*
|
||||
* If client sends AWS model ID it will be used verbatim. Otherwise, various
|
||||
* strategies are used to try to map a non-AWS model name to AWS model ID.
|
||||
*/
|
||||
function maybeReassignModel(req: Request) {
|
||||
const model = req.body.model;
|
||||
|
||||
// If it looks like an AWS model, use it as-is
|
||||
if (model.includes("anthropic.claude")) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Anthropic model names can look like:
|
||||
// - claude-v1
|
||||
// - claude-2.1
|
||||
// - claude-3-5-sonnet-20240620-v1:0
|
||||
const pattern =
|
||||
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
|
||||
const match = model.match(pattern);
|
||||
|
||||
// If there's no match, fallback to Claude v2 as it is most likely to be
|
||||
// available on AWS.
|
||||
if (!match) {
|
||||
req.body.model = `anthropic.claude-v2:1`;
|
||||
return;
|
||||
}
|
||||
|
||||
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
|
||||
|
||||
if (instant) {
|
||||
req.body.model = "anthropic.claude-instant-v1";
|
||||
return;
|
||||
}
|
||||
|
||||
const ver = minor ? `${major}.${minor}` : major;
|
||||
switch (ver) {
|
||||
case "1":
|
||||
case "1.0":
|
||||
req.body.model = "anthropic.claude-v1";
|
||||
return;
|
||||
case "2":
|
||||
case "2.0":
|
||||
req.body.model = "anthropic.claude-v2";
|
||||
return;
|
||||
case "3":
|
||||
case "3.0":
|
||||
if (name.includes("opus")) {
|
||||
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
|
||||
} else if (name.includes("haiku")) {
|
||||
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
|
||||
} else {
|
||||
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
|
||||
}
|
||||
return;
|
||||
case "3.5":
|
||||
req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback to Claude 2.1
|
||||
req.body.model = `anthropic.claude-v2:1`;
|
||||
return;
|
||||
}
|
||||
|
||||
export const awsClaude = awsClaudeRouter;
|
||||
@@ -1,110 +0,0 @@
|
||||
import { Request } from "express";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import {
|
||||
createOnProxyReqHandler,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeSignedRequest,
|
||||
signAwsRequest,
|
||||
} from "./middleware/request";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { logger } from "../logger";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import { Router } from "express";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai";
|
||||
|
||||
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
let newBody = body;
|
||||
if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") {
|
||||
newBody = transformMistralTextToMistralChat(body);
|
||||
}
|
||||
// AWS does not always confirm the model in the response, so we have to add it
|
||||
if (!newBody.model && req.body.model) {
|
||||
newBody.model = req.body.model;
|
||||
}
|
||||
|
||||
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||
};
|
||||
|
||||
const awsMistralProxy = createQueueMiddleware({
|
||||
beforeProxy: signAwsRequest,
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "bad-target-will-be-rewritten",
|
||||
router: ({ signedRequest }) => {
|
||||
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||
return `${signedRequest.protocol}//${signedRequest.hostname}`;
|
||||
},
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
|
||||
proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
function maybeReassignModel(req: Request) {
|
||||
const model = req.body.model;
|
||||
|
||||
// If it looks like an AWS model, use it as-is
|
||||
if (model.startsWith("mistral.")) {
|
||||
return;
|
||||
}
|
||||
// Mistral 7B Instruct
|
||||
else if (model.includes("7b")) {
|
||||
req.body.model = "mistral.mistral-7b-instruct-v0:2";
|
||||
}
|
||||
// Mistral 8x7B Instruct
|
||||
else if (model.includes("8x7b")) {
|
||||
req.body.model = "mistral.mixtral-8x7b-instruct-v0:1";
|
||||
}
|
||||
// Mistral Large (Feb 2024)
|
||||
else if (model.includes("large-2402")) {
|
||||
req.body.model = "mistral.mistral-large-2402-v1:0";
|
||||
}
|
||||
// Mistral Large 2 (July 2024)
|
||||
else if (model.includes("large")) {
|
||||
req.body.model = "mistral.mistral-large-2407-v1:0";
|
||||
}
|
||||
// Mistral Small (Feb 2024)
|
||||
else if (model.includes("small")) {
|
||||
req.body.model = "mistral.mistral-small-2402-v1:0";
|
||||
} else {
|
||||
throw new Error(
|
||||
`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" },
|
||||
{
|
||||
beforeTransform: [detectMistralInputApi],
|
||||
afterTransform: [maybeReassignModel],
|
||||
}
|
||||
);
|
||||
|
||||
const awsMistralRouter = Router();
|
||||
awsMistralRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
nativeMistralChatPreprocessor,
|
||||
awsMistralProxy
|
||||
);
|
||||
|
||||
export const awsMistral = awsMistralRouter;
|
||||
+318
-58
@@ -1,75 +1,335 @@
|
||||
/* Shared code between AWS Claude and AWS Mistral endpoints. */
|
||||
|
||||
import { Request, Response, Router } from "express";
|
||||
import { Request, RequestHandler, Response, Router } from "express";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { v4 } from "uuid";
|
||||
import { config } from "../config";
|
||||
import { addV1 } from "./add-v1";
|
||||
import { awsClaude } from "./aws-claude";
|
||||
import { awsMistral } from "./aws-mistral";
|
||||
import { AwsBedrockKey, keyPool } from "../shared/key-management";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
createPreprocessorMiddleware,
|
||||
signAwsRequest,
|
||||
finalizeSignedRequest,
|
||||
createOnProxyReqHandler,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
ProxyResHandlerWithBody,
|
||||
createOnProxyResHandler,
|
||||
} from "./middleware/response";
|
||||
import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic";
|
||||
import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||
|
||||
const awsRouter = Router();
|
||||
awsRouter.get(["/:vendor?/v1/models", "/:vendor?/models"], handleModelsRequest);
|
||||
awsRouter.use("/claude", addV1, awsClaude);
|
||||
awsRouter.use("/mistral", addV1, awsMistral);
|
||||
const LATEST_AWS_V2_MINOR_VERSION = "1";
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
const getModelsResponse = () => {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
const MODELS_CACHE_TTL = 10000;
|
||||
let modelsCache: Record<string, any> = {};
|
||||
let modelsCacheTime: Record<string, number> = {};
|
||||
function handleModelsRequest(req: Request, res: Response) {
|
||||
if (!config.awsCredentials) return { object: "list", data: [] };
|
||||
|
||||
const vendor = req.params.vendor?.length
|
||||
? req.params.vendor === "claude"
|
||||
? "anthropic"
|
||||
: req.params.vendor
|
||||
: "all";
|
||||
|
||||
const cacheTime = modelsCacheTime[vendor] || 0;
|
||||
if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
|
||||
return res.json(modelsCache[vendor]);
|
||||
}
|
||||
|
||||
const availableModelIds = new Set<string>();
|
||||
for (const key of keyPool.list()) {
|
||||
if (key.isDisabled || key.service !== "aws") continue;
|
||||
(key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id));
|
||||
}
|
||||
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
|
||||
const models = [
|
||||
const variants = [
|
||||
"anthropic.claude-v2",
|
||||
"anthropic.claude-v2:1",
|
||||
"anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"mistral.mistral-7b-instruct-v0:2",
|
||||
"mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"mistral.mistral-large-2407-v1:0",
|
||||
"mistral.mistral-small-2402-v1:0",
|
||||
]
|
||||
.filter((id) => availableModelIds.has(id))
|
||||
.map((id) => {
|
||||
const vendor = id.match(/^(.*)\./)?.[1];
|
||||
return {
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: vendor,
|
||||
permission: [],
|
||||
root: vendor,
|
||||
parent: null,
|
||||
};
|
||||
});
|
||||
];
|
||||
|
||||
modelsCache[vendor] = {
|
||||
object: "list",
|
||||
data: models.filter((m) => vendor === "all" || m.root === vendor),
|
||||
const models = variants.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: "anthropic",
|
||||
permission: [],
|
||||
root: "claude",
|
||||
parent: null,
|
||||
}));
|
||||
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
return modelsCache;
|
||||
};
|
||||
|
||||
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||
res.status(200).json(getModelsResponse());
|
||||
};
|
||||
|
||||
/** Only used for non-streaming requests. */
|
||||
const awsResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
let newBody = body;
|
||||
switch (`${req.inboundApi}<-${req.outboundApi}`) {
|
||||
case "openai<-anthropic-text":
|
||||
req.log.info("Transforming Anthropic Text back to OpenAI format");
|
||||
newBody = transformAwsTextResponseToOpenAI(body, req);
|
||||
break;
|
||||
case "openai<-anthropic-chat":
|
||||
req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
|
||||
newBody = transformAnthropicChatResponseToOpenAI(body);
|
||||
break;
|
||||
case "anthropic-text<-anthropic-chat":
|
||||
req.log.info("Transforming AWS Anthropic Chat back to Text format");
|
||||
newBody = transformAnthropicChatResponseToAnthropicText(body);
|
||||
break;
|
||||
}
|
||||
|
||||
// AWS does not always confirm the model in the response, so we have to add it
|
||||
if (!newBody.model && req.body.model) {
|
||||
newBody.model = req.body.model;
|
||||
}
|
||||
|
||||
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||
};
|
||||
|
||||
/**
|
||||
* Transforms a model response from the Anthropic API to match those from the
|
||||
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
|
||||
* is only used for non-streaming requests as streaming requests are handled
|
||||
* on-the-fly.
|
||||
*/
|
||||
function transformAwsTextResponseToOpenAI(
|
||||
awsBody: Record<string, any>,
|
||||
req: Request
|
||||
): Record<string, any> {
|
||||
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
|
||||
return {
|
||||
id: "aws-" + v4(),
|
||||
object: "chat.completion",
|
||||
created: Date.now(),
|
||||
model: req.body.model,
|
||||
usage: {
|
||||
prompt_tokens: req.promptTokens,
|
||||
completion_tokens: req.outputTokens,
|
||||
total_tokens: totalTokens,
|
||||
},
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: awsBody.completion?.trim(),
|
||||
},
|
||||
finish_reason: awsBody.stop_reason,
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
};
|
||||
modelsCacheTime[vendor] = new Date().getTime();
|
||||
}
|
||||
|
||||
return res.json(modelsCache[vendor]);
|
||||
const awsProxy = createQueueMiddleware({
|
||||
beforeProxy: signAwsRequest,
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "bad-target-will-be-rewritten",
|
||||
router: ({ signedRequest }) => {
|
||||
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||
return `${signedRequest.protocol}//${signedRequest.hostname}`;
|
||||
},
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
|
||||
proxyRes: createOnProxyResHandler([awsResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
const nativeTextPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
const textToChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
/**
|
||||
* Routes text completion prompts to aws anthropic-chat if they need translation
|
||||
* (claude-3 based models do not support the old text completion endpoint).
|
||||
*/
|
||||
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
|
||||
if (req.body.model?.includes("claude-3")) {
|
||||
textToChatPreprocessor(req, res, next);
|
||||
} else {
|
||||
nativeTextPreprocessor(req, res, next);
|
||||
}
|
||||
};
|
||||
|
||||
const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "anthropic-text", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "anthropic-chat", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
/**
|
||||
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
|
||||
* or the new Claude chat completion endpoint, based on the requested model.
|
||||
*/
|
||||
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
|
||||
if (req.body.model?.includes("claude-3")) {
|
||||
oaiToAwsChatPreprocessor(req, res, next);
|
||||
} else {
|
||||
oaiToAwsTextPreprocessor(req, res, next);
|
||||
}
|
||||
};
|
||||
|
||||
const awsRouter = Router();
|
||||
awsRouter.get("/v1/models", handleModelRequest);
|
||||
// Native(ish) Anthropic text completion endpoint.
|
||||
awsRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy);
|
||||
// Native Anthropic chat completion endpoint.
|
||||
awsRouter.post(
|
||||
"/v1/messages",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
awsProxy
|
||||
);
|
||||
// Temporary force-Claude3 endpoint
|
||||
awsRouter.post(
|
||||
"/v1/sonnet/:action(complete|messages)",
|
||||
ipLimiter,
|
||||
handleCompatibilityRequest,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "anthropic-text",
|
||||
outApi: "anthropic-chat",
|
||||
service: "aws",
|
||||
}),
|
||||
awsProxy
|
||||
);
|
||||
|
||||
// OpenAI-to-AWS Anthropic compatibility endpoint.
|
||||
awsRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
preprocessOpenAICompatRequest,
|
||||
awsProxy
|
||||
);
|
||||
|
||||
/**
|
||||
* Tries to deal with:
|
||||
* - frontends sending AWS model names even when they want to use the OpenAI-
|
||||
* compatible endpoint
|
||||
* - frontends sending Anthropic model names that AWS doesn't recognize
|
||||
* - frontends sending OpenAI model names because they expect the proxy to
|
||||
* translate them
|
||||
*/
|
||||
function maybeReassignModel(req: Request) {
|
||||
const model = req.body.model;
|
||||
|
||||
// If client already specified an AWS Claude model ID, use it
|
||||
if (model.includes("anthropic.claude")) {
|
||||
return;
|
||||
}
|
||||
|
||||
const pattern =
|
||||
/^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?(-sonnet-?|-opus-?|-haiku-?)(\d*)/i;
|
||||
const match = model.match(pattern);
|
||||
|
||||
// If there's no match, return the latest v2 model
|
||||
if (!match) {
|
||||
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
|
||||
return;
|
||||
}
|
||||
|
||||
const instant = match[2];
|
||||
const major = match[4];
|
||||
const minor = match[6];
|
||||
|
||||
if (instant) {
|
||||
req.body.model = "anthropic.claude-instant-v1";
|
||||
return;
|
||||
}
|
||||
|
||||
// There's only one v1 model
|
||||
if (major === "1") {
|
||||
req.body.model = "anthropic.claude-v1";
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to map Anthropic API v2 models to AWS v2 models
|
||||
if (major === "2") {
|
||||
if (minor === "0") {
|
||||
req.body.model = "anthropic.claude-v2";
|
||||
return;
|
||||
}
|
||||
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
|
||||
return;
|
||||
}
|
||||
|
||||
// AWS currently only supports one v3 model.
|
||||
const variant = match[8]; // sonnet, opus, or haiku
|
||||
const variantVersion = match[9];
|
||||
if (major === "3") {
|
||||
if (variant.includes("opus")) {
|
||||
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
|
||||
} else if (variant.includes("haiku")) {
|
||||
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
|
||||
} else {
|
||||
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback to latest v2 model
|
||||
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
|
||||
return;
|
||||
}
|
||||
|
||||
export function handleCompatibilityRequest(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: any
|
||||
) {
|
||||
const action = req.params.action;
|
||||
const alreadyInChatFormat = Boolean(req.body.messages);
|
||||
const compatModel = "anthropic.claude-3-sonnet-20240229-v1:0";
|
||||
req.log.info(
|
||||
{ inputModel: req.body.model, compatModel, alreadyInChatFormat },
|
||||
"Handling AWS compatibility request"
|
||||
);
|
||||
|
||||
if (action === "messages" || alreadyInChatFormat) {
|
||||
return sendErrorToClient({
|
||||
req,
|
||||
res,
|
||||
options: {
|
||||
title: "Unnecessary usage of compatibility endpoint",
|
||||
message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/aws/claude\` proxy endpoint instead.`,
|
||||
format: "unknown",
|
||||
statusCode: 400,
|
||||
reqId: req.id,
|
||||
obj: {
|
||||
requested_endpoint: "/aws/claude/sonnet",
|
||||
correct_endpoint: "/aws/claude",
|
||||
},
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
req.body.model = compatModel;
|
||||
next();
|
||||
}
|
||||
|
||||
export const aws = awsRouter;
|
||||
|
||||
@@ -12,7 +12,6 @@ function getProxyAuthorizationFromRequest(req: Request): string | undefined {
|
||||
// pass the _proxy_ key in this header too, instead of providing it as a
|
||||
// Bearer token in the Authorization header. So we need to check both.
|
||||
// Prefer the Authorization header if both are present.
|
||||
// Google AI uses a key querystring parameter.
|
||||
|
||||
if (req.headers.authorization) {
|
||||
const token = req.headers.authorization?.slice("Bearer ".length);
|
||||
@@ -25,12 +24,6 @@ function getProxyAuthorizationFromRequest(req: Request): string | undefined {
|
||||
delete req.headers["x-api-key"];
|
||||
return token;
|
||||
}
|
||||
|
||||
if (req.query.key) {
|
||||
const token = req.query.key?.toString();
|
||||
delete req.query.key;
|
||||
return token;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
createPreprocessorMiddleware,
|
||||
signGcpRequest,
|
||||
finalizeSignedRequest,
|
||||
createOnProxyReqHandler,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
ProxyResHandlerWithBody,
|
||||
createOnProxyResHandler,
|
||||
} from "./middleware/response";
|
||||
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
|
||||
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
const getModelsResponse = () => {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
if (!config.gcpCredentials) return { object: "list", data: [] };
|
||||
|
||||
// https://docs.anthropic.com/en/docs/about-claude/models
|
||||
const variants = [
|
||||
"claude-3-haiku@20240307",
|
||||
"claude-3-sonnet@20240229",
|
||||
"claude-3-opus@20240229",
|
||||
"claude-3-5-sonnet@20240620",
|
||||
];
|
||||
|
||||
const models = variants.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: "anthropic",
|
||||
permission: [],
|
||||
root: "claude",
|
||||
parent: null,
|
||||
}));
|
||||
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
return modelsCache;
|
||||
};
|
||||
|
||||
const handleModelRequest: RequestHandler = (_req, res) => {
|
||||
res.status(200).json(getModelsResponse());
|
||||
};
|
||||
|
||||
/** Only used for non-streaming requests. */
|
||||
const gcpResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
let newBody = body;
|
||||
switch (`${req.inboundApi}<-${req.outboundApi}`) {
|
||||
case "openai<-anthropic-chat":
|
||||
req.log.info("Transforming Anthropic Chat back to OpenAI format");
|
||||
newBody = transformAnthropicChatResponseToOpenAI(body);
|
||||
break;
|
||||
}
|
||||
|
||||
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||
};
|
||||
|
||||
const gcpProxy = createQueueMiddleware({
|
||||
beforeProxy: signGcpRequest,
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "bad-target-will-be-rewritten",
|
||||
router: ({ signedRequest }) => {
|
||||
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||
return `${signedRequest.protocol}//${signedRequest.hostname}`;
|
||||
},
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
|
||||
proxyRes: createOnProxyResHandler([gcpResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
const oaiToChatPreprocessor = createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "anthropic-chat", service: "gcp" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
);
|
||||
|
||||
/**
|
||||
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
|
||||
* or the new Claude chat completion endpoint, based on the requested model.
|
||||
*/
|
||||
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
|
||||
oaiToChatPreprocessor(req, res, next);
|
||||
};
|
||||
|
||||
const gcpRouter = Router();
|
||||
gcpRouter.get("/v1/models", handleModelRequest);
|
||||
// Native Anthropic chat completion endpoint.
|
||||
gcpRouter.post(
|
||||
"/v1/messages",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "gcp" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
gcpProxy
|
||||
);
|
||||
|
||||
// OpenAI-to-GCP Anthropic compatibility endpoint.
|
||||
gcpRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
preprocessOpenAICompatRequest,
|
||||
gcpProxy
|
||||
);
|
||||
|
||||
/**
|
||||
* Tries to deal with:
|
||||
* - frontends sending GCP model names even when they want to use the OpenAI-
|
||||
* compatible endpoint
|
||||
* - frontends sending Anthropic model names that GCP doesn't recognize
|
||||
* - frontends sending OpenAI model names because they expect the proxy to
|
||||
* translate them
|
||||
*
|
||||
* If client sends GCP model ID it will be used verbatim. Otherwise, various
|
||||
* strategies are used to try to map a non-GCP model name to GCP model ID.
|
||||
*/
|
||||
function maybeReassignModel(req: Request) {
|
||||
const model = req.body.model;
|
||||
|
||||
// If it looks like an GCP model, use it as-is
|
||||
// if (model.includes("anthropic.claude")) {
|
||||
if (model.startsWith("claude-") && model.includes("@")) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Anthropic model names can look like:
|
||||
// - claude-v1
|
||||
// - claude-2.1
|
||||
// - claude-3-5-sonnet-20240620-v1:0
|
||||
const pattern =
|
||||
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
|
||||
const match = model.match(pattern);
|
||||
|
||||
// If there's no match, fallback to Claude3 Sonnet as it is most likely to be
|
||||
// available on GCP.
|
||||
if (!match) {
|
||||
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
|
||||
return;
|
||||
}
|
||||
|
||||
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
|
||||
|
||||
const ver = minor ? `${major}.${minor}` : major;
|
||||
switch (ver) {
|
||||
case "3":
|
||||
case "3.0":
|
||||
if (name.includes("opus")) {
|
||||
req.body.model = "claude-3-opus@20240229";
|
||||
} else if (name.includes("haiku")) {
|
||||
req.body.model = "claude-3-haiku@20240307";
|
||||
} else {
|
||||
req.body.model = "claude-3-sonnet@20240229";
|
||||
}
|
||||
return;
|
||||
case "3.5":
|
||||
req.body.model = "claude-3-5-sonnet@20240620";
|
||||
return;
|
||||
}
|
||||
|
||||
// Fallback to Claude3 Sonnet
|
||||
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
|
||||
return;
|
||||
}
|
||||
|
||||
export const gcp = gcpRouter;
|
||||
+8
-80
@@ -16,7 +16,6 @@ import {
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
|
||||
import { GoogleAIKey, keyPool } from "../shared/key-management";
|
||||
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
@@ -31,19 +30,9 @@ const getModelsResponse = () => {
|
||||
|
||||
if (!config.googleAIKey) return { object: "list", data: [] };
|
||||
|
||||
const keys = keyPool
|
||||
.list()
|
||||
.filter((k) => k.service === "google-ai") as GoogleAIKey[];
|
||||
if (keys.length === 0) {
|
||||
modelsCache = { object: "list", data: [] };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
return modelsCache;
|
||||
}
|
||||
const googleAIVariants = ["gemini-pro", "gemini-1.0-pro", "gemini-1.5-pro"];
|
||||
|
||||
const modelIds = Array.from(
|
||||
new Set(keys.map((k) => k.modelIds).flat())
|
||||
).filter((id) => id.startsWith("models/gemini"));
|
||||
const models = modelIds.map((id) => ({
|
||||
const models = googleAIVariants.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
@@ -120,17 +109,7 @@ const googleAIProxy = createQueueMiddleware({
|
||||
},
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
// Prevent logging of the API key by HPM
|
||||
logger: logger.child(
|
||||
{},
|
||||
{
|
||||
redact: {
|
||||
paths: ["*"],
|
||||
censor: (v) =>
|
||||
typeof v === "string" ? v.replace(/key=\S+/g, "key=xxxxxxx") : v,
|
||||
},
|
||||
}
|
||||
),
|
||||
logger,
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
|
||||
proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
|
||||
@@ -141,67 +120,16 @@ const googleAIProxy = createQueueMiddleware({
|
||||
|
||||
const googleAIRouter = Router();
|
||||
googleAIRouter.get("/v1/models", handleModelRequest);
|
||||
|
||||
// Native Google AI chat completion endpoint
|
||||
googleAIRouter.post(
|
||||
"/v1beta/models/:modelId:(generateContent|streamGenerateContent)",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{
|
||||
inApi: "google-ai",
|
||||
outApi: "google-ai",
|
||||
service: "google-ai",
|
||||
},
|
||||
{ beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
|
||||
),
|
||||
googleAIProxy
|
||||
);
|
||||
|
||||
// OpenAI-to-Google AI compatibility endpoint.
|
||||
googleAIRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "google-ai", service: "google-ai" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "openai",
|
||||
outApi: "google-ai",
|
||||
service: "google-ai",
|
||||
}),
|
||||
googleAIProxy
|
||||
);
|
||||
|
||||
function setStreamFlag(req: Request) {
|
||||
const isStreaming = req.url.includes("streamGenerateContent");
|
||||
if (isStreaming) {
|
||||
req.body.stream = true;
|
||||
req.isStreaming = true;
|
||||
} else {
|
||||
req.body.stream = false;
|
||||
req.isStreaming = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces requests for non-Google AI models with gemini-pro-1.5-latest.
|
||||
* Also strips models/ from the beginning of the model IDs.
|
||||
**/
|
||||
function maybeReassignModel(req: Request) {
|
||||
// Ensure model is on body as a lot of middleware will expect it.
|
||||
const model = req.body.model || req.url.split("/").pop()?.split(":").shift();
|
||||
if (!model) {
|
||||
throw new Error("You must specify a model with your request.");
|
||||
}
|
||||
req.body.model = model;
|
||||
|
||||
const requested = model;
|
||||
if (requested.startsWith("models/")) {
|
||||
req.body.model = requested.slice("models/".length);
|
||||
}
|
||||
|
||||
if (requested.includes("gemini")) {
|
||||
return;
|
||||
}
|
||||
|
||||
req.log.info({ requested }, "Reassigning model to gemini-pro-1.5-latest");
|
||||
req.body.model = "gemini-pro-1.5-latest";
|
||||
}
|
||||
|
||||
export const googleAI = googleAIRouter;
|
||||
|
||||
@@ -16,7 +16,6 @@ const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
|
||||
const ANTHROPIC_MESSAGES_ENDPOINT = "/v1/messages";
|
||||
const ANTHROPIC_SONNET_COMPAT_ENDPOINT = "/v1/sonnet";
|
||||
const ANTHROPIC_OPUS_COMPAT_ENDPOINT = "/v1/opus";
|
||||
const GOOGLE_AI_COMPLETION_ENDPOINT = "/v1beta/models";
|
||||
|
||||
export function isTextGenerationRequest(req: Request) {
|
||||
return (
|
||||
@@ -28,7 +27,6 @@ export function isTextGenerationRequest(req: Request) {
|
||||
ANTHROPIC_MESSAGES_ENDPOINT,
|
||||
ANTHROPIC_SONNET_COMPAT_ENDPOINT,
|
||||
ANTHROPIC_OPUS_COMPAT_ENDPOINT,
|
||||
GOOGLE_AI_COMPLETION_ENDPOINT,
|
||||
].some((endpoint) => req.path.startsWith(endpoint))
|
||||
);
|
||||
}
|
||||
@@ -223,12 +221,9 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
|
||||
switch (format) {
|
||||
case "openai":
|
||||
case "mistral-ai":
|
||||
// Few possible values:
|
||||
// - choices[0].message.content
|
||||
// - choices[0].message with no content if model is invoking a tool
|
||||
return body.choices?.[0]?.message?.content || "";
|
||||
case "mistral-text":
|
||||
return body.outputs?.[0]?.text || "";
|
||||
// Can be null if the model wants to invoke tools rather than return a
|
||||
// completion.
|
||||
return body.choices[0].message.content || "";
|
||||
case "openai-text":
|
||||
return body.choices[0].text;
|
||||
case "anthropic-chat":
|
||||
@@ -265,22 +260,22 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
|
||||
}
|
||||
}
|
||||
|
||||
export function getModelFromBody(req: Request, resBody: Record<string, any>) {
|
||||
export function getModelFromBody(req: Request, body: Record<string, any>) {
|
||||
const format = req.outboundApi;
|
||||
switch (format) {
|
||||
case "openai":
|
||||
case "openai-text":
|
||||
return resBody.model;
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
return body.model;
|
||||
case "openai-image":
|
||||
case "google-ai":
|
||||
// These formats don't have a model in the response body.
|
||||
return req.body.model;
|
||||
case "anthropic-chat":
|
||||
case "anthropic-text":
|
||||
// Anthropic confirms the model in the response, but AWS Claude doesn't.
|
||||
return resBody.model || req.body.model;
|
||||
return body.model || req.body.model;
|
||||
case "google-ai":
|
||||
// Google doesn't confirm the model in the response.
|
||||
return req.body.model;
|
||||
default:
|
||||
assertNever(format);
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
|
||||
export { languageFilter } from "./preprocessors/language-filter";
|
||||
export { setApiFormat } from "./preprocessors/set-api-format";
|
||||
export { signAwsRequest } from "./preprocessors/sign-aws-request";
|
||||
export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request";
|
||||
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
|
||||
export { validateContextSize } from "./preprocessors/validate-context-size";
|
||||
export { validateVision } from "./preprocessors/validate-vision";
|
||||
|
||||
@@ -38,10 +38,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
||||
// translation now reassigns the model earlier in the request pipeline.
|
||||
case "anthropic-text":
|
||||
case "anthropic-chat":
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
case "google-ai":
|
||||
assignedKey = keyPool.get(body.model, service);
|
||||
assignedKey = keyPool.get("claude-v1", service, needsMultimodal);
|
||||
break;
|
||||
case "openai-text":
|
||||
assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
|
||||
@@ -50,8 +47,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
||||
assignedKey = keyPool.get("dall-e-3", service);
|
||||
break;
|
||||
case "openai":
|
||||
case "google-ai":
|
||||
case "mistral-ai":
|
||||
throw new Error(
|
||||
`Outbound API ${outboundApi} is not supported for ${inboundApi}`
|
||||
`add-key should not be called for outbound API ${outboundApi}`
|
||||
);
|
||||
default:
|
||||
assertNever(outboundApi);
|
||||
@@ -84,7 +83,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
||||
proxyReq.setHeader("api-key", azureKey);
|
||||
break;
|
||||
case "aws":
|
||||
case "gcp":
|
||||
case "google-ai":
|
||||
throw new Error("add-key should not be used for this service.");
|
||||
default:
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import { HPMRequestCallback } from "../index";
|
||||
import { config } from "../../../../config";
|
||||
import { ForbiddenError } from "../../../../shared/errors";
|
||||
import { getModelFamilyForRequest } from "../../../../shared/models";
|
||||
import { HPMRequestCallback } from "../index";
|
||||
|
||||
/**
|
||||
* Ensures the selected model family is enabled by the proxy configuration.
|
||||
*/
|
||||
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req) => {
|
||||
**/
|
||||
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
|
||||
const family = getModelFamilyForRequest(req);
|
||||
if (!config.allowedModelFamilies.includes(family)) {
|
||||
throw new ForbiddenError(
|
||||
`Model family '${family}' is not enabled on this proxy`
|
||||
);
|
||||
throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { HPMRequestCallback } from "../index";
|
||||
|
||||
/**
|
||||
* For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
|
||||
* For AWS/Azure/Google requests, the body is signed earlier in the request
|
||||
* pipeline, before the proxy middleware. This function just assigns the path
|
||||
* and headers to the proxy request.
|
||||
*/
|
||||
|
||||
@@ -84,9 +84,9 @@ async function executePreprocessors(
|
||||
} catch (error) {
|
||||
if (error.constructor.name === "ZodError") {
|
||||
const msg = error?.issues
|
||||
?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
|
||||
?.map((issue: ZodIssue) => issue.message)
|
||||
.join("; ");
|
||||
req.log.warn({ issues: msg }, "Prompt validation failed.");
|
||||
req.log.info(msg, "Prompt validation failed.");
|
||||
} else {
|
||||
req.log.error(error, "Error while executing request preprocessor");
|
||||
}
|
||||
@@ -143,7 +143,7 @@ const handleTestMessage: RequestHandler = (req, res) => {
|
||||
};
|
||||
|
||||
function isTestMessage(body: any) {
|
||||
const { messages, prompt, contents } = body;
|
||||
const { messages, prompt } = body;
|
||||
|
||||
if (messages) {
|
||||
return (
|
||||
@@ -151,11 +151,6 @@ function isTestMessage(body: any) {
|
||||
messages[0].role === "user" &&
|
||||
messages[0].content === "Hi"
|
||||
);
|
||||
} else if (contents) {
|
||||
return (
|
||||
contents.length === 1 &&
|
||||
contents[0].parts[0]?.text === "Hi"
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
prompt?.trim() === "Human: Hi\n\nAssistant:" ||
|
||||
|
||||
@@ -2,38 +2,39 @@ import { keyPool } from "../../../../shared/key-management";
|
||||
import { RequestPreprocessor } from "../index";
|
||||
|
||||
export const addGoogleAIKey: RequestPreprocessor = (req) => {
|
||||
const inboundValid =
|
||||
req.inboundApi === "openai" || req.inboundApi === "google-ai";
|
||||
const outboundValid = req.outboundApi === "google-ai";
|
||||
|
||||
const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
|
||||
const serviceValid = req.service === "google-ai";
|
||||
if (!inboundValid || !outboundValid || !serviceValid) {
|
||||
if (!apisValid || !serviceValid) {
|
||||
throw new Error("addGoogleAIKey called on invalid request");
|
||||
}
|
||||
|
||||
|
||||
if (!req.body?.model) {
|
||||
throw new Error("You must specify a model with your request.");
|
||||
}
|
||||
|
||||
const model = req.body.model;
|
||||
req.isStreaming = req.isStreaming || req.body.stream;
|
||||
req.key = keyPool.get(model, "google-ai");
|
||||
|
||||
req.log.info(
|
||||
{ key: req.key.hash, model, stream: req.isStreaming },
|
||||
{ key: req.key.hash, model },
|
||||
"Assigned Google AI API key to request"
|
||||
);
|
||||
|
||||
|
||||
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
|
||||
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
|
||||
const payload = { ...req.body, stream: undefined, model: undefined };
|
||||
|
||||
req.isStreaming = req.isStreaming || req.body.stream;
|
||||
delete req.body.stream;
|
||||
|
||||
req.signedRequest = {
|
||||
method: "POST",
|
||||
protocol: "https:",
|
||||
hostname: "generativelanguage.googleapis.com",
|
||||
path: `/v1beta/models/${model}:${
|
||||
req.isStreaming ? "streamGenerateContent" : "generateContent"
|
||||
}?key=${req.key.key}`,
|
||||
path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
|
||||
headers: {
|
||||
["host"]: `generativelanguage.googleapis.com`,
|
||||
["content-type"]: "application/json",
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
body: JSON.stringify(req.body),
|
||||
};
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ import { RequestPreprocessor } from "../index";
|
||||
import { countTokens } from "../../../../shared/tokenization";
|
||||
import { assertNever } from "../../../../shared/utils";
|
||||
import {
|
||||
AnthropicChatMessage,
|
||||
GoogleAIChatMessage,
|
||||
MistralAIChatMessage,
|
||||
OpenAIChatMessage,
|
||||
@@ -30,13 +31,10 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
||||
}
|
||||
case "anthropic-chat": {
|
||||
req.outputTokens = req.body.max_tokens;
|
||||
let system = req.body.system ?? "";
|
||||
if (Array.isArray(system)) {
|
||||
system = system
|
||||
.map((m: { type: string; text: string }) => m.text)
|
||||
.join("\n");
|
||||
}
|
||||
const prompt = { system, messages: req.body.messages };
|
||||
const prompt = {
|
||||
system: req.body.system ?? "",
|
||||
messages: req.body.messages,
|
||||
};
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
@@ -52,11 +50,9 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
case "mistral-ai":
|
||||
case "mistral-text": {
|
||||
case "mistral-ai": {
|
||||
req.outputTokens = req.body.max_tokens;
|
||||
const prompt: string | MistralAIChatMessage[] =
|
||||
req.body.messages ?? req.body.prompt;
|
||||
const prompt: MistralAIChatMessage[] = req.body.messages;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -56,6 +56,8 @@ function getPromptFromRequest(req: Request) {
|
||||
switch (service) {
|
||||
case "anthropic-chat":
|
||||
return flattenAnthropicMessages(body.messages);
|
||||
case "anthropic-text":
|
||||
return body.prompt;
|
||||
case "openai":
|
||||
case "mistral-ai":
|
||||
return body.messages
|
||||
@@ -70,10 +72,8 @@ function getPromptFromRequest(req: Request) {
|
||||
return `${msg.role}: ${text}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
case "anthropic-text":
|
||||
case "openai-text":
|
||||
case "openai-image":
|
||||
case "mistral-text":
|
||||
return body.prompt;
|
||||
case "google-ai":
|
||||
return body.prompt.text;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import express, { Request } from "express";
|
||||
import express from "express";
|
||||
import { Sha256 } from "@aws-crypto/sha256-js";
|
||||
import { SignatureV4 } from "@smithy/signature-v4";
|
||||
import { HttpRequest } from "@smithy/protocol-http";
|
||||
@@ -6,12 +6,8 @@ import {
|
||||
AnthropicV1TextSchema,
|
||||
AnthropicV1MessagesSchema,
|
||||
} from "../../../../shared/api-schemas";
|
||||
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management";
|
||||
import { keyPool } from "../../../../shared/key-management";
|
||||
import { RequestPreprocessor } from "../index";
|
||||
import {
|
||||
AWSMistralV1ChatCompletionsSchema,
|
||||
AWSMistralV1TextCompletionsSchema,
|
||||
} from "../../../../shared/api-schemas/mistral-ai";
|
||||
|
||||
const AMZ_HOST =
|
||||
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
|
||||
@@ -33,33 +29,56 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
|
||||
req.body.prompt = preamble + req.body.prompt;
|
||||
}
|
||||
|
||||
// AWS uses mostly the same parameters as Anthropic, with a few removed params
|
||||
// and much stricter validation on unused parameters. Rather than treating it
|
||||
// as a separate schema we will use the anthropic ones and strip the unused
|
||||
// parameters.
|
||||
// TODO: This should happen in transform-outbound-payload.ts
|
||||
let strippedParams: Record<string, unknown>;
|
||||
if (req.outboundApi === "anthropic-chat") {
|
||||
strippedParams = AnthropicV1MessagesSchema.pick({
|
||||
messages: true,
|
||||
system: true,
|
||||
max_tokens: true,
|
||||
stop_sequences: true,
|
||||
temperature: true,
|
||||
top_k: true,
|
||||
top_p: true,
|
||||
})
|
||||
.strip()
|
||||
.parse(req.body);
|
||||
strippedParams.anthropic_version = "bedrock-2023-05-31";
|
||||
} else {
|
||||
strippedParams = AnthropicV1TextSchema.pick({
|
||||
prompt: true,
|
||||
max_tokens_to_sample: true,
|
||||
stop_sequences: true,
|
||||
temperature: true,
|
||||
top_k: true,
|
||||
top_p: true,
|
||||
})
|
||||
.strip()
|
||||
.parse(req.body);
|
||||
}
|
||||
|
||||
const credential = getCredentialParts(req);
|
||||
const host = AMZ_HOST.replace("%REGION%", credential.region);
|
||||
|
||||
// AWS only uses 2023-06-01 and does not actually check this header, but we
|
||||
// set it so that the stream adapter always selects the correct transformer.
|
||||
req.headers["anthropic-version"] = "2023-06-01";
|
||||
|
||||
// If our key has an inference profile compatible with the requested model,
|
||||
// we want to use the inference profile instead of the model ID when calling
|
||||
// InvokeModel as that will give us higher rate limits.
|
||||
const profile =
|
||||
(req.key as AwsBedrockKey).inferenceProfileIds.find((p) =>
|
||||
p.includes(model)
|
||||
) || model;
|
||||
|
||||
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
|
||||
// with the headers generated by the SDK.
|
||||
const newRequest = new HttpRequest({
|
||||
method: "POST",
|
||||
protocol: "https:",
|
||||
hostname: host,
|
||||
path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`,
|
||||
path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`,
|
||||
headers: {
|
||||
["Host"]: host,
|
||||
["content-type"]: "application/json",
|
||||
},
|
||||
body: JSON.stringify(applyAwsStrictValidation(req)),
|
||||
body: JSON.stringify(strippedParams),
|
||||
});
|
||||
|
||||
if (stream) {
|
||||
@@ -70,13 +89,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
|
||||
|
||||
const { key, body, inboundApi, outboundApi } = req;
|
||||
req.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
model: body.model,
|
||||
inferenceProfile: profile,
|
||||
inboundApi,
|
||||
outboundApi,
|
||||
},
|
||||
{ key: key.hash, model: body.model, inboundApi, outboundApi },
|
||||
"Assigned AWS credentials to request"
|
||||
);
|
||||
|
||||
@@ -115,50 +128,3 @@ async function sign(request: HttpRequest, credential: Credential) {
|
||||
|
||||
return signer.sign(request);
|
||||
}
|
||||
|
||||
function applyAwsStrictValidation(req: Request): unknown {
|
||||
// AWS uses vendor API formats but imposes additional (more strict) validation
|
||||
// rules, namely that extraneous parameters are not allowed. We will validate
|
||||
// using the vendor's zod schema but apply `.strip` to ensure that any
|
||||
// extraneous parameters are removed.
|
||||
let strippedParams: Record<string, unknown> = {};
|
||||
switch (req.outboundApi) {
|
||||
case "anthropic-text":
|
||||
strippedParams = AnthropicV1TextSchema.pick({
|
||||
prompt: true,
|
||||
max_tokens_to_sample: true,
|
||||
stop_sequences: true,
|
||||
temperature: true,
|
||||
top_k: true,
|
||||
top_p: true,
|
||||
})
|
||||
.strip()
|
||||
.parse(req.body);
|
||||
break;
|
||||
case "anthropic-chat":
|
||||
strippedParams = AnthropicV1MessagesSchema.pick({
|
||||
messages: true,
|
||||
system: true,
|
||||
max_tokens: true,
|
||||
stop_sequences: true,
|
||||
temperature: true,
|
||||
top_k: true,
|
||||
top_p: true,
|
||||
tools: true,
|
||||
tool_choice: true,
|
||||
})
|
||||
.strip()
|
||||
.parse(req.body);
|
||||
strippedParams.anthropic_version = "bedrock-2023-05-31";
|
||||
break;
|
||||
case "mistral-ai":
|
||||
strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
|
||||
break;
|
||||
case "mistral-text":
|
||||
strippedParams = AWSMistralV1TextCompletionsSchema.parse(req.body);
|
||||
break;
|
||||
default:
|
||||
throw new Error("Unexpected outbound API for AWS.");
|
||||
}
|
||||
return strippedParams;
|
||||
}
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
import express from "express";
|
||||
import crypto from "crypto";
|
||||
import { keyPool } from "../../../../shared/key-management";
|
||||
import { RequestPreprocessor } from "../index";
|
||||
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
|
||||
|
||||
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
||||
|
||||
export const signGcpRequest: RequestPreprocessor = async (req) => {
|
||||
const serviceValid = req.service === "gcp";
|
||||
if (!serviceValid) {
|
||||
throw new Error("addVertexAIKey called on invalid request");
|
||||
}
|
||||
|
||||
if (!req.body?.model) {
|
||||
throw new Error("You must specify a model with your request.");
|
||||
}
|
||||
|
||||
const { model, stream } = req.body;
|
||||
req.key = keyPool.get(model, "gcp");
|
||||
|
||||
req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request");
|
||||
|
||||
req.isStreaming = String(stream) === "true";
|
||||
|
||||
// TODO: This should happen in transform-outbound-payload.ts
|
||||
let strippedParams: Record<string, unknown>;
|
||||
strippedParams = AnthropicV1MessagesSchema.pick({
|
||||
messages: true,
|
||||
system: true,
|
||||
max_tokens: true,
|
||||
stop_sequences: true,
|
||||
temperature: true,
|
||||
top_k: true,
|
||||
top_p: true,
|
||||
tools: true,
|
||||
tool_choice: true,
|
||||
stream: true,
|
||||
})
|
||||
.strip()
|
||||
.parse(req.body);
|
||||
strippedParams.anthropic_version = "vertex-2023-10-16";
|
||||
|
||||
const [accessToken, credential] = await getAccessToken(req);
|
||||
|
||||
const host = GCP_HOST.replace("%REGION%", credential.region);
|
||||
// GCP doesn't use the anthropic-version header, but we set it to ensure the
|
||||
// stream adapter selects the correct transformer.
|
||||
req.headers["anthropic-version"] = "2023-06-01";
|
||||
|
||||
req.signedRequest = {
|
||||
method: "POST",
|
||||
protocol: "https:",
|
||||
hostname: host,
|
||||
path: `/v1/projects/${credential.projectId}/locations/${credential.region}/publishers/anthropic/models/${model}:streamRawPredict`,
|
||||
headers: {
|
||||
["host"]: host,
|
||||
["content-type"]: "application/json",
|
||||
["authorization"]: `Bearer ${accessToken}`,
|
||||
},
|
||||
body: JSON.stringify(strippedParams),
|
||||
};
|
||||
};
|
||||
|
||||
async function getAccessToken(
|
||||
req: express.Request
|
||||
): Promise<[string, Credential]> {
|
||||
// TODO: access token caching to reduce latency
|
||||
const credential = getCredentialParts(req);
|
||||
const signedJWT = await createSignedJWT(
|
||||
credential.clientEmail,
|
||||
credential.privateKey
|
||||
);
|
||||
const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
|
||||
if (accessToken === null) {
|
||||
req.log.warn(
|
||||
{ key: req.key!.hash, jwtError },
|
||||
"Unable to get the access token"
|
||||
);
|
||||
throw new Error("The access token is invalid.");
|
||||
}
|
||||
return [accessToken, credential];
|
||||
}
|
||||
|
||||
async function createSignedJWT(email: string, pkey: string): Promise<string> {
|
||||
let cryptoKey = await crypto.subtle.importKey(
|
||||
"pkcs8",
|
||||
str2ab(atob(pkey)),
|
||||
{
|
||||
name: "RSASSA-PKCS1-v1_5",
|
||||
hash: { name: "SHA-256" },
|
||||
},
|
||||
false,
|
||||
["sign"]
|
||||
);
|
||||
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const issued = Math.floor(Date.now() / 1000);
|
||||
const expires = issued + 600;
|
||||
|
||||
const header = {
|
||||
alg: "RS256",
|
||||
typ: "JWT",
|
||||
};
|
||||
|
||||
const payload = {
|
||||
iss: email,
|
||||
aud: authUrl,
|
||||
iat: issued,
|
||||
exp: expires,
|
||||
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||
};
|
||||
|
||||
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
|
||||
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
|
||||
|
||||
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||
|
||||
const signature = await crypto.subtle.sign(
|
||||
"RSASSA-PKCS1-v1_5",
|
||||
cryptoKey,
|
||||
str2ab(unsignedToken)
|
||||
);
|
||||
|
||||
const encodedSignature = urlSafeBase64Encode(signature);
|
||||
return `${unsignedToken}.${encodedSignature}`;
|
||||
}
|
||||
|
||||
async function exchangeJwtForAccessToken(
|
||||
signedJwt: string
|
||||
): Promise<[string | null, string]> {
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const params = {
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
assertion: signedJwt,
|
||||
};
|
||||
|
||||
const r = await fetch(authUrl, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: Object.entries(params)
|
||||
.map(([k, v]) => `${k}=${v}`)
|
||||
.join("&"),
|
||||
}).then((res) => res.json());
|
||||
|
||||
if (r.access_token) {
|
||||
return [r.access_token, ""];
|
||||
}
|
||||
|
||||
return [null, JSON.stringify(r)];
|
||||
}
|
||||
|
||||
function str2ab(str: string): ArrayBuffer {
|
||||
const buffer = new ArrayBuffer(str.length);
|
||||
const bufferView = new Uint8Array(buffer);
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
bufferView[i] = str.charCodeAt(i);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||
let base64: string;
|
||||
if (typeof data === "string") {
|
||||
base64 = btoa(
|
||||
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||
)
|
||||
);
|
||||
} else {
|
||||
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||
}
|
||||
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||
}
|
||||
|
||||
type Credential = {
|
||||
projectId: string;
|
||||
clientEmail: string;
|
||||
region: string;
|
||||
privateKey: string;
|
||||
};
|
||||
|
||||
function getCredentialParts(req: express.Request): Credential {
|
||||
const [projectId, clientEmail, region, rawPrivateKey] =
|
||||
req.key!.key.split(":");
|
||||
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||
req.log.error(
|
||||
{ key: req.key!.hash },
|
||||
"GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
|
||||
);
|
||||
throw new Error("The key assigned to this request is invalid.");
|
||||
}
|
||||
|
||||
const privateKey = rawPrivateKey
|
||||
.replace(
|
||||
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||
""
|
||||
)
|
||||
.trim();
|
||||
|
||||
return { projectId, clientEmail, region, privateKey };
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
import { Request } from "express";
|
||||
import {
|
||||
API_REQUEST_VALIDATORS,
|
||||
API_REQUEST_TRANSFORMERS,
|
||||
@@ -13,33 +12,29 @@ import { RequestPreprocessor } from "../index";
|
||||
|
||||
/** Transforms an incoming request body to one that matches the target API. */
|
||||
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
|
||||
const sameService = req.inboundApi === req.outboundApi;
|
||||
const alreadyTransformed = req.retryCount > 0;
|
||||
const notTransformable =
|
||||
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
|
||||
|
||||
if (alreadyTransformed) {
|
||||
return;
|
||||
} else if (notTransformable) {
|
||||
// This is probably an indication of a bug in the proxy.
|
||||
const { inboundApi, outboundApi, method, path } = req;
|
||||
req.log.warn(
|
||||
{ inboundApi, outboundApi, method, path },
|
||||
"`transformOutboundPayload` called on a non-transformable request."
|
||||
if (alreadyTransformed || notTransformable) return;
|
||||
|
||||
// TODO: this should be an APIFormatTransformer
|
||||
if (req.inboundApi === "mistral-ai") {
|
||||
const messages = req.body.messages;
|
||||
req.body.messages = fixMistralPrompt(messages);
|
||||
req.log.info(
|
||||
{ old: messages.length, new: req.body.messages.length },
|
||||
"Fixed Mistral prompt"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
applyMistralPromptFixes(req);
|
||||
|
||||
// Native prompts are those which were already provided by the client in the
|
||||
// target API format. We don't need to transform them.
|
||||
const isNativePrompt = req.inboundApi === req.outboundApi;
|
||||
if (isNativePrompt) {
|
||||
if (sameService) {
|
||||
const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body);
|
||||
if (!result.success) {
|
||||
req.log.warn(
|
||||
{ issues: result.error.issues, body: req.body },
|
||||
"Native prompt request validation failed."
|
||||
"Request validation failed"
|
||||
);
|
||||
throw result.error;
|
||||
}
|
||||
@@ -47,12 +42,11 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prompt requires translation from one API format to another.
|
||||
const transformation = `${req.inboundApi}->${req.outboundApi}` as const;
|
||||
const transFn = API_REQUEST_TRANSFORMERS[transformation];
|
||||
|
||||
if (transFn) {
|
||||
req.log.info({ transformation }, "Transforming request...");
|
||||
req.log.info({ transformation }, "Transforming request");
|
||||
req.body = await transFn(req);
|
||||
return;
|
||||
}
|
||||
@@ -61,36 +55,3 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
|
||||
`${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.`
|
||||
);
|
||||
};
|
||||
|
||||
// handles weird cases that don't fit into our abstractions
|
||||
function applyMistralPromptFixes(req: Request): void {
|
||||
if (req.inboundApi === "mistral-ai") {
|
||||
// Mistral Chat is very similar to OpenAI but not identical and many clients
|
||||
// don't properly handle the differences. We will try to validate the
|
||||
// mistral prompt and try to fix it if it fails. It will be re-validated
|
||||
// after this function returns.
|
||||
const result = API_REQUEST_VALIDATORS["mistral-ai"].parse(req.body);
|
||||
req.body.messages = fixMistralPrompt(result.messages);
|
||||
req.log.info(
|
||||
{ n: req.body.messages.length, prev: result.messages.length },
|
||||
"Applied Mistral chat prompt fixes."
|
||||
);
|
||||
|
||||
// If the prompt relies on `prefix: true` for the last message, we need to
|
||||
// convert it to a text completions request because AWS Mistral support for
|
||||
// this feature is broken.
|
||||
// On Mistral La Plateforme, we can't do this because they don't expose
|
||||
// a text completions endpoint.
|
||||
const { messages } = req.body;
|
||||
const lastMessage = messages && messages[messages.length - 1];
|
||||
if (lastMessage?.role === "assistant" && req.service === "aws") {
|
||||
// enable prefix if client forgot, otherwise the template will insert an
|
||||
// eos token which is very unlikely to be what the client wants.
|
||||
lastMessage.prefix = true;
|
||||
req.outboundApi = "mistral-text";
|
||||
req.log.info(
|
||||
"Native Mistral chat prompt relies on assistant message prefix. Converting to text completions request."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,8 @@ import { RequestPreprocessor } from "../index";
|
||||
|
||||
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
||||
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
||||
// todo: make configurable
|
||||
const GOOGLE_AI_MAX_CONTEXT = 1024000;
|
||||
const MISTRAL_AI_MAX_CONTENT = 131072;
|
||||
const GOOGLE_AI_MAX_CONTEXT = 32000;
|
||||
const MISTRAL_AI_MAX_CONTENT = 32768;
|
||||
|
||||
/**
|
||||
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
|
||||
@@ -38,7 +37,6 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
||||
proxyMax = GOOGLE_AI_MAX_CONTEXT;
|
||||
break;
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
proxyMax = MISTRAL_AI_MAX_CONTENT;
|
||||
break;
|
||||
case "openai-image":
|
||||
@@ -58,8 +56,6 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
||||
modelMax = 16384;
|
||||
} else if (model.match(/^gpt-4o/)) {
|
||||
modelMax = 128000;
|
||||
} else if (model.match(/^chatgpt-4o/)) {
|
||||
modelMax = 128000;
|
||||
} else if (model.match(/gpt-4-turbo(-\d{4}-\d{2}-\d{2})?$/)) {
|
||||
modelMax = 131072;
|
||||
} else if (model.match(/gpt-4-turbo(-preview)?$/)) {
|
||||
@@ -84,19 +80,17 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
||||
modelMax = 200000;
|
||||
} else if (model.match(/^claude-3/)) {
|
||||
modelMax = 200000;
|
||||
} else if (model.match(/^gemini-/)) {
|
||||
modelMax = 1024000;
|
||||
} else if (model.match(/^gemini-\d{3}$/)) {
|
||||
modelMax = GOOGLE_AI_MAX_CONTEXT;
|
||||
} else if (model.match(/^mistral-(tiny|small|medium)$/)) {
|
||||
modelMax = MISTRAL_AI_MAX_CONTENT;
|
||||
} else if (model.match(/^anthropic\.claude-3/)) {
|
||||
modelMax = 200000;
|
||||
} else if (model.match(/^anthropic\.claude-v2:\d/)) {
|
||||
modelMax = 200000;
|
||||
} else if (model.match(/^anthropic\.claude/)) {
|
||||
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
|
||||
modelMax = 100000;
|
||||
} else if (model.match(/tral/)) {
|
||||
// catches mistral, mixtral, codestral, mathstral, etc. mistral models have
|
||||
// no name convention and wildly different context windows so this is a
|
||||
// catch-all
|
||||
modelMax = MISTRAL_AI_MAX_CONTENT;
|
||||
} else {
|
||||
req.log.warn({ model }, "Unknown model, using 200k token limit.");
|
||||
modelMax = 200000;
|
||||
|
||||
@@ -28,7 +28,6 @@ export const validateVision: RequestPreprocessor = async (req) => {
|
||||
case "anthropic-text":
|
||||
case "google-ai":
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
case "openai-image":
|
||||
case "openai-text":
|
||||
return;
|
||||
|
||||
@@ -65,7 +65,7 @@ type ErrorGeneratorOptions = {
|
||||
format: APIFormat | "unknown";
|
||||
title: string;
|
||||
message: string;
|
||||
obj?: Record<string, any>;
|
||||
obj?: object;
|
||||
reqId: string | number | object;
|
||||
model?: string;
|
||||
statusCode?: number;
|
||||
@@ -95,23 +95,6 @@ export function tryInferFormat(body: any): APIFormat | "unknown" {
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
// avoid leaking upstream hostname on dns resolution error
|
||||
function redactHostname(options: ErrorGeneratorOptions): ErrorGeneratorOptions {
|
||||
if (!options.message.includes("getaddrinfo")) return options;
|
||||
|
||||
const redacted = { ...options };
|
||||
redacted.message = "Could not resolve hostname";
|
||||
|
||||
if (typeof redacted.obj?.error === "object") {
|
||||
redacted.obj = {
|
||||
...redacted.obj,
|
||||
error: { message: "Could not resolve hostname" },
|
||||
};
|
||||
}
|
||||
|
||||
return redacted;
|
||||
}
|
||||
|
||||
export function sendErrorToClient({
|
||||
options,
|
||||
req,
|
||||
@@ -121,26 +104,27 @@ export function sendErrorToClient({
|
||||
req: express.Request;
|
||||
res: express.Response;
|
||||
}) {
|
||||
const redactedOpts = redactHostname(options);
|
||||
const { format: inputFormat } = redactedOpts;
|
||||
const { format: inputFormat } = options;
|
||||
|
||||
// This is an error thrown before we know the format of the request, so we
|
||||
// can't send a response in the format the client expects.
|
||||
const format =
|
||||
inputFormat === "unknown" ? tryInferFormat(req.body) : inputFormat;
|
||||
if (format === "unknown") {
|
||||
return res.status(redactedOpts.statusCode || 400).json({
|
||||
error: redactedOpts.message,
|
||||
details: redactedOpts.obj,
|
||||
return res.status(options.statusCode || 400).json({
|
||||
error: options.message,
|
||||
details: options.obj,
|
||||
});
|
||||
}
|
||||
|
||||
const completion = buildSpoofedCompletion({ ...redactedOpts, format });
|
||||
const event = buildSpoofedSSE({ ...redactedOpts, format });
|
||||
const completion = buildSpoofedCompletion({ ...options, format });
|
||||
const event = buildSpoofedSSE({ ...options, format });
|
||||
const isStreaming =
|
||||
req.isStreaming || req.body.stream === true || req.body.stream === "true";
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.setHeader("x-oai-proxy-error", redactedOpts.title);
|
||||
res.setHeader("x-oai-proxy-error-status", redactedOpts.statusCode || 500);
|
||||
res.setHeader("x-oai-proxy-error", options.title);
|
||||
res.setHeader("x-oai-proxy-error-status", options.statusCode || 500);
|
||||
}
|
||||
|
||||
if (isStreaming) {
|
||||
@@ -189,11 +173,6 @@ export function buildSpoofedCompletion({
|
||||
},
|
||||
],
|
||||
};
|
||||
case "mistral-text":
|
||||
return {
|
||||
outputs: [{ text: content, stop_reason: title }],
|
||||
model,
|
||||
}
|
||||
case "openai-text":
|
||||
return {
|
||||
id: "error-" + id,
|
||||
@@ -225,7 +204,13 @@ export function buildSpoofedCompletion({
|
||||
stop_sequence: null,
|
||||
};
|
||||
case "google-ai":
|
||||
// TODO: Native Google AI non-streaming responses are not supported, this
|
||||
// is an untested guess at what the response should look like.
|
||||
return {
|
||||
id: "error-" + id,
|
||||
object: "chat.completion",
|
||||
created: Date.now(),
|
||||
model,
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: content }], role: "model" },
|
||||
@@ -272,11 +257,6 @@ export function buildSpoofedSSE({
|
||||
choices: [{ delta: { content }, index: 0, finish_reason: title }],
|
||||
};
|
||||
break;
|
||||
case "mistral-text":
|
||||
event = {
|
||||
outputs: [{ text: content, stop_reason: title }],
|
||||
};
|
||||
break;
|
||||
case "openai-text":
|
||||
event = {
|
||||
id: "cmpl-" + id,
|
||||
@@ -306,10 +286,7 @@ export function buildSpoofedSSE({
|
||||
};
|
||||
break;
|
||||
case "google-ai":
|
||||
// TODO: google ai supports two streaming transports, SSE and JSON.
|
||||
// we currently only support SSE.
|
||||
// return JSON.stringify({
|
||||
event = {
|
||||
return JSON.stringify({
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: content }], role: "model" },
|
||||
@@ -319,8 +296,7 @@ export function buildSpoofedSSE({
|
||||
safetyRatings: [],
|
||||
},
|
||||
],
|
||||
};
|
||||
break;
|
||||
});
|
||||
case "openai-image":
|
||||
return JSON.stringify(obj);
|
||||
default:
|
||||
|
||||
@@ -22,19 +22,18 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
|
||||
const pipelineAsync = promisify(pipeline);
|
||||
|
||||
/**
|
||||
* `handleStreamedResponse` consumes a streamed response from the upstream API,
|
||||
* decodes chunk-by-chunk into a stream of events, transforms those events into
|
||||
* the client's requested format, and forwards the result to the client.
|
||||
*
|
||||
* `handleStreamedResponse` consumes and transforms a streamed response from the
|
||||
* upstream service, forwarding events to the client in their requested format.
|
||||
* After the entire stream has been consumed, it resolves with the full response
|
||||
* body so that subsequent middleware in the chain can process it as if it were
|
||||
* a non-streaming response (to count output tokens, track usage, etc).
|
||||
* a non-streaming response.
|
||||
*
|
||||
* In the event of an error, the request's streaming flag is unset and the
|
||||
* request is bounced back to the non-streaming response handler. If the error
|
||||
* is retryable, that handler will re-enqueue the request and also reset the
|
||||
* streaming flag. Unfortunately the streaming flag is set and unset in multiple
|
||||
* places, so it's hard to keep track of.
|
||||
* In the event of an error, the request's streaming flag is unset and the non-
|
||||
* streaming response handler is called instead.
|
||||
*
|
||||
* If the error is retryable, that handler will re-enqueue the request and also
|
||||
* reset the streaming flag. Unfortunately the streaming flag is set and unset
|
||||
* in multiple places, so it's hard to keep track of.
|
||||
*/
|
||||
export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
proxyRes,
|
||||
@@ -71,21 +70,13 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
logger: req.log,
|
||||
};
|
||||
|
||||
// While the request is streaming, aggregator collects all events so that we
|
||||
// can compile them into a single response object and publish that to the
|
||||
// remaining middleware. Because we have an OpenAI transformer for every
|
||||
// supported format, EventAggregator always consumes OpenAI events so that we
|
||||
// only have to write one aggregator (OpenAI input) for each output format.
|
||||
const aggregator = new EventAggregator(req);
|
||||
|
||||
// Decoder reads from the raw response buffer and produces a stream of
|
||||
// discrete events in some format (text/event-stream, vnd.amazon.event-stream,
|
||||
// streaming JSON, etc).
|
||||
// Decoder turns the raw response stream into a stream of events in some
|
||||
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
|
||||
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
|
||||
// Adapter consumes the decoded events and produces server-sent events so we
|
||||
// have a standard event format for the client and to translate between API
|
||||
// message formats.
|
||||
// Adapter transforms the decoded events into server-sent events.
|
||||
const adapter = new SSEStreamAdapter(streamOptions);
|
||||
// Aggregator compiles all events into a single response object.
|
||||
const aggregator = new EventAggregator({ format: req.outboundApi });
|
||||
// Transformer converts server-sent events from one vendor's API message
|
||||
// format to another.
|
||||
const transformer = new SSEMessageTransformer({
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
/* This file is fucking horrendous, sorry */
|
||||
// TODO: extract all per-service error response handling into its own modules
|
||||
import { Request, Response } from "express";
|
||||
import * as http from "http";
|
||||
import { config } from "../../../config";
|
||||
@@ -186,13 +185,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
throw new HttpError(statusCode, parseError.message);
|
||||
}
|
||||
|
||||
const service = req.key!.service;
|
||||
if (service === "gcp") {
|
||||
if (Array.isArray(errorPayload)) {
|
||||
errorPayload = errorPayload[0];
|
||||
}
|
||||
}
|
||||
|
||||
const errorType =
|
||||
errorPayload.error?.code ||
|
||||
errorPayload.error?.type ||
|
||||
@@ -202,24 +194,21 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
{ statusCode, type: errorType, errorPayload, key: req.key?.hash },
|
||||
`Received error response from upstream. (${proxyRes.statusMessage})`
|
||||
);
|
||||
|
||||
|
||||
// TODO: split upstream error handling into separate modules for each service,
|
||||
// this is out of control.
|
||||
|
||||
const service = req.key!.service;
|
||||
if (service === "aws") {
|
||||
// Try to standardize the error format for AWS
|
||||
errorPayload.error = { message: errorPayload.message, type: errorType };
|
||||
delete errorPayload.message;
|
||||
} else if (service === "gcp") {
|
||||
// Try to standardize the error format for GCP
|
||||
if (errorPayload.error?.code) { // GCP Error
|
||||
errorPayload.error = { message: errorPayload.error.message, type: errorPayload.error.status || errorPayload.error.code };
|
||||
}
|
||||
}
|
||||
|
||||
if (statusCode === 400) {
|
||||
switch (service) {
|
||||
case "openai":
|
||||
case "google-ai":
|
||||
case "mistral-ai":
|
||||
case "azure":
|
||||
const filteredCodes = ["content_policy_violation", "content_filter"];
|
||||
@@ -236,11 +225,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
break;
|
||||
case "anthropic":
|
||||
case "aws":
|
||||
case "gcp":
|
||||
await handleAnthropicAwsBadRequestError(req, errorPayload);
|
||||
break;
|
||||
case "google-ai":
|
||||
await handleGoogleAIBadRequestError(req, errorPayload);
|
||||
await handleAnthropicBadRequestError(req, errorPayload);
|
||||
break;
|
||||
default:
|
||||
assertNever(service);
|
||||
@@ -262,9 +247,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
);
|
||||
keyPool.update(req.key!, { allowsMultimodality: false });
|
||||
await reenqueueRequest(req);
|
||||
throw new RetryableError(
|
||||
"Claude request re-enqueued because key does not support multimodality."
|
||||
);
|
||||
throw new RetryableError("Claude request re-enqueued because key does not support multimodality.");
|
||||
} else {
|
||||
keyPool.disable(req.key!, "revoked");
|
||||
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
|
||||
@@ -292,12 +275,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
default:
|
||||
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
|
||||
}
|
||||
return;
|
||||
case "mistral-ai":
|
||||
case "gcp":
|
||||
keyPool.disable(req.key!, "revoked");
|
||||
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
|
||||
return;
|
||||
}
|
||||
} else if (statusCode === 429) {
|
||||
switch (service) {
|
||||
@@ -310,9 +287,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
case "aws":
|
||||
await handleAwsRateLimitError(req, errorPayload);
|
||||
break;
|
||||
case "gcp":
|
||||
await handleGcpRateLimitError(req, errorPayload);
|
||||
break;
|
||||
case "azure":
|
||||
case "mistral-ai":
|
||||
await handleAzureRateLimitError(req, errorPayload);
|
||||
@@ -349,9 +323,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
case "aws":
|
||||
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
||||
break;
|
||||
case "gcp":
|
||||
errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
|
||||
break;
|
||||
case "azure":
|
||||
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
|
||||
break;
|
||||
@@ -376,7 +347,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
throw new HttpError(statusCode, errorPayload.error?.message);
|
||||
};
|
||||
|
||||
async function handleAnthropicAwsBadRequestError(
|
||||
async function handleAnthropicBadRequestError(
|
||||
req: Request,
|
||||
errorPayload: ProxiedErrorPayload
|
||||
) {
|
||||
@@ -411,13 +382,11 @@ async function handleAnthropicAwsBadRequestError(
|
||||
return;
|
||||
}
|
||||
|
||||
const isDisabled =
|
||||
error?.message?.match(/organization has been disabled/i) ||
|
||||
error?.message?.match(/^operation not allowed/i);
|
||||
const isDisabled = error?.message?.match(/organization has been disabled/i);
|
||||
if (isDisabled) {
|
||||
req.log.warn(
|
||||
{ key: req.key?.hash, message: error?.message },
|
||||
"Anthropic/AWS key has been disabled."
|
||||
"Anthropic key has been disabled."
|
||||
);
|
||||
keyPool.disable(req.key!, "revoked");
|
||||
errorPayload.proxy_note = `Assigned key has been disabled. (${error?.message})`;
|
||||
@@ -458,19 +427,6 @@ async function handleAwsRateLimitError(
|
||||
}
|
||||
}
|
||||
|
||||
async function handleGcpRateLimitError(
|
||||
req: Request,
|
||||
errorPayload: ProxiedErrorPayload
|
||||
) {
|
||||
if (errorPayload.error?.type === "RESOURCE_EXHAUSTED") {
|
||||
keyPool.markRateLimited(req.key!);
|
||||
await reenqueueRequest(req);
|
||||
throw new RetryableError("GCP rate-limited request re-enqueued.");
|
||||
} else {
|
||||
errorPayload.proxy_note = `Unrecognized 429 Too Many Requests error from GCP.`;
|
||||
}
|
||||
}
|
||||
|
||||
async function handleOpenAIRateLimitError(
|
||||
req: Request,
|
||||
errorPayload: ProxiedErrorPayload
|
||||
@@ -556,7 +512,7 @@ async function handleOpenAIRateLimitError(
|
||||
// keyPool.markRateLimited(req.key!);
|
||||
// break;
|
||||
default:
|
||||
errorPayload.proxy_note = `This is likely a temporary error with the API. Try again in a few seconds.`;
|
||||
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
|
||||
break;
|
||||
}
|
||||
return errorPayload;
|
||||
@@ -578,42 +534,6 @@ async function handleAzureRateLimitError(
|
||||
}
|
||||
}
|
||||
|
||||
//{"error":{"code":400,"message":"API Key not found. Please pass a valid API key.","status":"INVALID_ARGUMENT","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"API_KEY_INVALID","domain":"googleapis.com","metadata":{"service":"generativelanguage.googleapis.com"}}]}}
|
||||
//{"error":{"code":400,"message":"Gemini API free tier is not available in your country. Please enable billing on your project in Google AI Studio.","status":"FAILED_PRECONDITION"}}
|
||||
async function handleGoogleAIBadRequestError(
|
||||
req: Request,
|
||||
errorPayload: ProxiedErrorPayload
|
||||
) {
|
||||
const error = errorPayload.error || {};
|
||||
const { message, status, details } = error;
|
||||
|
||||
if (status === "INVALID_ARGUMENT") {
|
||||
const reason = details?.[0]?.reason;
|
||||
if (reason === "API_KEY_INVALID") {
|
||||
req.log.warn(
|
||||
{ key: req.key?.hash, status, reason, msg: error.message },
|
||||
"Received `API_KEY_INVALID` error from Google AI. Check the configured API key."
|
||||
);
|
||||
keyPool.disable(req.key!, "revoked");
|
||||
errorPayload.proxy_note = `Assigned API key is invalid.`;
|
||||
}
|
||||
} else if (status === "FAILED_PRECONDITION") {
|
||||
if (message.match(/please enable billing/i)) {
|
||||
req.log.warn(
|
||||
{ key: req.key?.hash, status, msg: error.message },
|
||||
"Cannot use key due to billing restrictions."
|
||||
);
|
||||
keyPool.disable(req.key!, "revoked");
|
||||
errorPayload.proxy_note = `Assigned API key cannot be used.`;
|
||||
}
|
||||
} else {
|
||||
req.log.warn(
|
||||
{ key: req.key?.hash, status, msg: error.message },
|
||||
"Received unexpected 400 error from Google AI."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
//{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}
|
||||
async function handleGoogleAIRateLimitError(
|
||||
req: Request,
|
||||
|
||||
@@ -11,8 +11,7 @@ import { ProxyResHandlerWithBody } from ".";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import {
|
||||
AnthropicChatMessage,
|
||||
flattenAnthropicMessages,
|
||||
GoogleAIChatMessage,
|
||||
flattenAnthropicMessages, GoogleAIChatMessage,
|
||||
MistralAIChatMessage,
|
||||
OpenAIChatMessage,
|
||||
} from "../../../shared/api-schemas";
|
||||
@@ -75,16 +74,8 @@ const getPromptForRequest = (
|
||||
case "mistral-ai":
|
||||
return req.body.messages;
|
||||
case "anthropic-chat":
|
||||
let system = req.body.system;
|
||||
if (Array.isArray(system)) {
|
||||
system = system
|
||||
.map((m: { type: string; text: string }) => m.text)
|
||||
.join("\n");
|
||||
}
|
||||
return { system, messages: req.body.messages };
|
||||
return { system: req.body.system, messages: req.body.messages };
|
||||
case "openai-text":
|
||||
case "anthropic-text":
|
||||
case "mistral-text":
|
||||
return req.body.prompt;
|
||||
case "openai-image":
|
||||
return {
|
||||
@@ -94,6 +85,8 @@ const getPromptForRequest = (
|
||||
quality: req.body.quality,
|
||||
revisedPrompt: responseBody.data[0].revised_prompt,
|
||||
};
|
||||
case "anthropic-text":
|
||||
return req.body.prompt;
|
||||
case "google-ai":
|
||||
return { contents: req.body.contents };
|
||||
default:
|
||||
@@ -120,7 +113,9 @@ const flattenMessages = (
|
||||
if (isGoogleAIChatPrompt(val)) {
|
||||
return val.contents
|
||||
.map(({ parts, role }) => {
|
||||
const text = parts.map((p) => p.text).join("\n");
|
||||
const text = parts
|
||||
.map((p) => p.text)
|
||||
.join("\n");
|
||||
return `${role}: ${text}`;
|
||||
})
|
||||
.join("\n");
|
||||
@@ -148,7 +143,11 @@ const flattenMessages = (
|
||||
function isGoogleAIChatPrompt(
|
||||
val: unknown
|
||||
): val is { contents: GoogleAIChatMessage[] } {
|
||||
return typeof val === "object" && val !== null && "contents" in val;
|
||||
return (
|
||||
typeof val === "object" &&
|
||||
val !== null &&
|
||||
"contents" in val
|
||||
);
|
||||
}
|
||||
|
||||
function isAnthropicChatPrompt(
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import { OpenAIChatCompletionStreamEvent } from "../index";
|
||||
|
||||
export type MistralChatCompletionResponse = {
|
||||
choices: {
|
||||
index: number;
|
||||
message: { role: string; content: string };
|
||||
finish_reason: string | null;
|
||||
}[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Given a list of OpenAI chat completion events, compiles them into a single
|
||||
* finalized Mistral chat completion response so that non-streaming middleware
|
||||
* can operate on it as if it were a blocking response.
|
||||
*/
|
||||
export function mergeEventsForMistralChat(
|
||||
events: OpenAIChatCompletionStreamEvent[]
|
||||
): MistralChatCompletionResponse {
|
||||
let merged: MistralChatCompletionResponse = {
|
||||
choices: [
|
||||
{ index: 0, message: { role: "", content: "" }, finish_reason: "" },
|
||||
],
|
||||
};
|
||||
merged = events.reduce((acc, event, i) => {
|
||||
// The first event will only contain role assignment and response metadata
|
||||
if (i === 0) {
|
||||
acc.choices[0].message.role = event.choices[0].delta.role ?? "assistant";
|
||||
return acc;
|
||||
}
|
||||
|
||||
acc.choices[0].finish_reason = event.choices[0].finish_reason ?? "";
|
||||
if (event.choices[0].delta.content) {
|
||||
acc.choices[0].message.content += event.choices[0].delta.content;
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, merged);
|
||||
return merged;
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
import { OpenAIChatCompletionStreamEvent } from "../index";
|
||||
|
||||
export type MistralTextCompletionResponse = {
|
||||
outputs: {
|
||||
text: string;
|
||||
stop_reason: string | null;
|
||||
}[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Given a list of OpenAI chat completion events, compiles them into a single
|
||||
* finalized Mistral text completion response so that non-streaming middleware
|
||||
* can operate on it as if it were a blocking response.
|
||||
*/
|
||||
export function mergeEventsForMistralText(
|
||||
events: OpenAIChatCompletionStreamEvent[]
|
||||
): MistralTextCompletionResponse {
|
||||
let merged: MistralTextCompletionResponse = {
|
||||
outputs: [{ text: "", stop_reason: "" }],
|
||||
};
|
||||
merged = events.reduce((acc, event, i) => {
|
||||
// The first event will only contain role assignment and response metadata
|
||||
if (i === 0) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
acc.outputs[0].text += event.choices[0].delta.content ?? "";
|
||||
acc.outputs[0].stop_reason = event.choices[0].finish_reason ?? "";
|
||||
|
||||
return acc;
|
||||
}, merged);
|
||||
return merged;
|
||||
}
|
||||
@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
|
||||
if (eventType === "chunk") {
|
||||
result = input[eventType];
|
||||
} else {
|
||||
// AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
|
||||
// AWS unmarshaller treats non-chunk (errors and exceptions) oddly.
|
||||
result = { [eventType]: input[eventType] } as any;
|
||||
}
|
||||
return result;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import express from "express";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
import { assertNever } from "../../../../shared/utils";
|
||||
import {
|
||||
@@ -7,13 +6,8 @@ import {
|
||||
mergeEventsForAnthropicText,
|
||||
mergeEventsForOpenAIChat,
|
||||
mergeEventsForOpenAIText,
|
||||
mergeEventsForMistralChat,
|
||||
mergeEventsForMistralText,
|
||||
AnthropicV2StreamEvent,
|
||||
OpenAIChatCompletionStreamEvent,
|
||||
mistralAIToOpenAI,
|
||||
MistralAIStreamEvent,
|
||||
MistralChatCompletionEvent,
|
||||
} from "./index";
|
||||
|
||||
/**
|
||||
@@ -21,70 +15,45 @@ import {
|
||||
* compiles them into a single finalized response for downstream middleware.
|
||||
*/
|
||||
export class EventAggregator {
|
||||
private readonly model: string;
|
||||
private readonly requestFormat: APIFormat;
|
||||
private readonly responseFormat: APIFormat;
|
||||
private readonly format: APIFormat;
|
||||
private readonly events: OpenAIChatCompletionStreamEvent[];
|
||||
|
||||
constructor({ body, inboundApi, outboundApi }: express.Request) {
|
||||
constructor({ format }: { format: APIFormat }) {
|
||||
this.events = [];
|
||||
this.requestFormat = inboundApi;
|
||||
this.responseFormat = outboundApi;
|
||||
this.model = body.model;
|
||||
this.format = format;
|
||||
}
|
||||
|
||||
addEvent(
|
||||
event:
|
||||
| OpenAIChatCompletionStreamEvent
|
||||
| AnthropicV2StreamEvent
|
||||
| MistralAIStreamEvent
|
||||
) {
|
||||
addEvent(event: OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent) {
|
||||
if (eventIsOpenAIEvent(event)) {
|
||||
this.events.push(event);
|
||||
} else {
|
||||
// horrible special case. previously all transformers' target format was
|
||||
// openai, so the event aggregator could conveniently assume all incoming
|
||||
// events were in openai format.
|
||||
// now we have added some transformers that convert between non-openai
|
||||
// formats, so aggregator needs to know how to collapse for more than
|
||||
// just openai.
|
||||
// because writing aggregation logic for every possible output format is
|
||||
// annoying, we will just transform any non-openai output events to openai
|
||||
// format (even if the client did not request openai at all) so that we
|
||||
// still only need to write aggregators for openai SSEs.
|
||||
let openAIEvent: OpenAIChatCompletionStreamEvent | undefined;
|
||||
switch (this.requestFormat) {
|
||||
case "anthropic-text":
|
||||
assertIsAnthropicV2Event(event);
|
||||
openAIEvent = anthropicV2ToOpenAI({
|
||||
data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`,
|
||||
lastPosition: -1,
|
||||
index: 0,
|
||||
fallbackId: event.log_id || "fallback-" + Date.now(),
|
||||
fallbackModel: event.model || this.model || "fallback-claude-3",
|
||||
})?.event;
|
||||
break;
|
||||
case "mistral-ai":
|
||||
assertIsMistralChatEvent(event);
|
||||
openAIEvent = mistralAIToOpenAI({
|
||||
data: `data: ${JSON.stringify(event)}\n\n`,
|
||||
lastPosition: -1,
|
||||
index: 0,
|
||||
fallbackId: "fallback-" + Date.now(),
|
||||
fallbackModel: this.model || "fallback-mistral",
|
||||
})?.event;
|
||||
break;
|
||||
}
|
||||
if (openAIEvent) {
|
||||
this.events.push(openAIEvent);
|
||||
// now we have added anthropic-chat-to-text, so aggregator needs to know
|
||||
// how to collapse events from two formats.
|
||||
// because that is annoying, we will simply transform anthropic events to
|
||||
// openai (even if the client didn't ask for openai) so we don't have to
|
||||
// write aggregation logic for anthropic chat (which is also a troublesome
|
||||
// stateful format).
|
||||
const openAIEvent = anthropicV2ToOpenAI({
|
||||
data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`,
|
||||
lastPosition: -1,
|
||||
index: 0,
|
||||
fallbackId: event.log_id || "event-aggregator-fallback",
|
||||
fallbackModel: event.model || "claude-3-fallback",
|
||||
});
|
||||
if (openAIEvent.event) {
|
||||
this.events.push(openAIEvent.event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getFinalResponse() {
|
||||
switch (this.responseFormat) {
|
||||
switch (this.format) {
|
||||
case "openai":
|
||||
case "google-ai": // TODO: this is probably wrong now that we support native Google Makersuite prompts
|
||||
case "google-ai":
|
||||
case "mistral-ai":
|
||||
return mergeEventsForOpenAIChat(this.events);
|
||||
case "openai-text":
|
||||
return mergeEventsForOpenAIText(this.events);
|
||||
@@ -92,16 +61,10 @@ export class EventAggregator {
|
||||
return mergeEventsForAnthropicText(this.events);
|
||||
case "anthropic-chat":
|
||||
return mergeEventsForAnthropicChat(this.events);
|
||||
case "mistral-ai":
|
||||
return mergeEventsForMistralChat(this.events);
|
||||
case "mistral-text":
|
||||
return mergeEventsForMistralText(this.events);
|
||||
case "openai-image":
|
||||
throw new Error(
|
||||
`SSE aggregation not supported for ${this.responseFormat}`
|
||||
);
|
||||
throw new Error(`SSE aggregation not supported for ${this.format}`);
|
||||
default:
|
||||
assertNever(this.responseFormat);
|
||||
assertNever(this.format);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,17 +78,3 @@ function eventIsOpenAIEvent(
|
||||
): event is OpenAIChatCompletionStreamEvent {
|
||||
return event?.object === "chat.completion.chunk";
|
||||
}
|
||||
|
||||
function assertIsAnthropicV2Event(event: any): asserts event is AnthropicV2StreamEvent {
|
||||
if (!event?.completion) {
|
||||
throw new Error(`Bad event for Anthropic V2 SSE aggregation`);
|
||||
}
|
||||
}
|
||||
|
||||
function assertIsMistralChatEvent(
|
||||
event: any
|
||||
): asserts event is MistralChatCompletionEvent {
|
||||
if (!event?.choices) {
|
||||
throw new Error(`Bad event for Mistral SSE aggregation`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,25 +7,6 @@ export type SSEResponseTransformArgs<S = Record<string, any>> = {
|
||||
state?: S;
|
||||
};
|
||||
|
||||
export type MistralChatCompletionEvent = {
|
||||
choices: {
|
||||
index: number;
|
||||
message: { role: string; content: string };
|
||||
stop_reason: string | null;
|
||||
}[];
|
||||
};
|
||||
export type MistralTextCompletionEvent = {
|
||||
outputs: { text: string; stop_reason: string | null }[];
|
||||
};
|
||||
export type MistralAIStreamEvent = {
|
||||
"amazon-bedrock-invocationMetrics"?: {
|
||||
inputTokenCount: number;
|
||||
outputTokenCount: number;
|
||||
invocationLatency: number;
|
||||
firstByteLatency: number;
|
||||
};
|
||||
} & (MistralChatCompletionEvent | MistralTextCompletionEvent);
|
||||
|
||||
export type AnthropicV2StreamEvent = {
|
||||
log_id?: string;
|
||||
model?: string;
|
||||
@@ -60,12 +41,8 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
|
||||
export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2";
|
||||
export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai";
|
||||
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
|
||||
export { mistralAIToOpenAI } from "./transformers/mistral-ai-to-openai";
|
||||
export { mistralTextToMistralChat } from "./transformers/mistral-text-to-mistral-chat";
|
||||
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
|
||||
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
|
||||
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
|
||||
export { mergeEventsForAnthropicText } from "./aggregators/anthropic-text";
|
||||
export { mergeEventsForAnthropicChat } from "./aggregators/anthropic-chat";
|
||||
export { mergeEventsForMistralChat } from "./aggregators/mistral-chat";
|
||||
export { mergeEventsForMistralText } from "./aggregators/mistral-text";
|
||||
|
||||
@@ -11,11 +11,8 @@ import {
|
||||
googleAIToOpenAI,
|
||||
OpenAIChatCompletionStreamEvent,
|
||||
openAITextToOpenAIChat,
|
||||
mistralAIToOpenAI,
|
||||
mistralTextToMistralChat,
|
||||
passthroughToOpenAI,
|
||||
StreamingCompletionTransformer,
|
||||
MistralChatCompletionEvent,
|
||||
} from "./index";
|
||||
|
||||
type SSEMessageTransformerOptions = TransformOptions & {
|
||||
@@ -38,9 +35,7 @@ export class SSEMessageTransformer extends Transform {
|
||||
private readonly inputFormat: APIFormat;
|
||||
private readonly transformFn: StreamingCompletionTransformer<
|
||||
// TODO: Refactor transformers to not assume only OpenAI events as output
|
||||
| OpenAIChatCompletionStreamEvent
|
||||
| AnthropicV2StreamEvent
|
||||
| MistralChatCompletionEvent
|
||||
OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
|
||||
>;
|
||||
private readonly log;
|
||||
private readonly fallbackId: string;
|
||||
@@ -126,17 +121,16 @@ function eventIsOpenAIEvent(
|
||||
function getTransformer(
|
||||
responseApi: APIFormat,
|
||||
version?: string,
|
||||
// In most cases, we are transforming back to OpenAI. Some responses can be
|
||||
// translated between two non-OpenAI formats, eg Anthropic Chat -> Anthropic
|
||||
// Text, or Mistral Text -> Mistral Chat.
|
||||
// There's only one case where we're not transforming back to OpenAI, which is
|
||||
// Anthropic Chat response -> Anthropic Text request. This parameter is only
|
||||
// used for that case.
|
||||
requestApi: APIFormat = "openai"
|
||||
): StreamingCompletionTransformer<
|
||||
| OpenAIChatCompletionStreamEvent
|
||||
| AnthropicV2StreamEvent
|
||||
| MistralChatCompletionEvent
|
||||
OpenAIChatCompletionStreamEvent | AnthropicV2StreamEvent
|
||||
> {
|
||||
switch (responseApi) {
|
||||
case "openai":
|
||||
case "mistral-ai":
|
||||
return passthroughToOpenAI;
|
||||
case "openai-text":
|
||||
return openAITextToOpenAIChat;
|
||||
@@ -146,16 +140,10 @@ function getTransformer(
|
||||
: anthropicV2ToOpenAI;
|
||||
case "anthropic-chat":
|
||||
return requestApi === "anthropic-text"
|
||||
? anthropicChatToAnthropicV2 // User's legacy text prompt was converted to chat, and response must be converted back to text
|
||||
? anthropicChatToAnthropicV2
|
||||
: anthropicChatToOpenAI;
|
||||
case "google-ai":
|
||||
return googleAIToOpenAI;
|
||||
case "mistral-ai":
|
||||
return mistralAIToOpenAI;
|
||||
case "mistral-text":
|
||||
return requestApi === "mistral-ai"
|
||||
? mistralTextToMistralChat // User's chat request was converted to text, and response must be converted back to chat
|
||||
: mistralAIToOpenAI;
|
||||
case "openai-image":
|
||||
throw new Error(`SSE transformation not supported for ${responseApi}`);
|
||||
default:
|
||||
|
||||
@@ -55,10 +55,8 @@ export class SSEStreamAdapter extends Transform {
|
||||
|
||||
if ("completion" in eventObj) {
|
||||
return ["event: completion", `data: ${event}`].join(`\n`);
|
||||
} else if (eventObj.type) {
|
||||
return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
|
||||
} else {
|
||||
return `data: ${event}`;
|
||||
return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
|
||||
}
|
||||
}
|
||||
// noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected
|
||||
@@ -118,7 +116,7 @@ export class SSEStreamAdapter extends Transform {
|
||||
try {
|
||||
const hasParts = candidates[0].content?.parts?.length > 0;
|
||||
if (hasParts) {
|
||||
return `data: ${JSON.stringify(data.value ?? data)}`;
|
||||
return `data: ${JSON.stringify(data.value ?? data)}\n`;
|
||||
} else {
|
||||
this.log.error({ event: data }, "Received bad Google AI event");
|
||||
return `data: ${buildSpoofedSSE({
|
||||
|
||||
@@ -34,7 +34,7 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
|
||||
model: params.fallbackModel,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
index: params.index,
|
||||
delta: { content: deltaEvent.delta.text },
|
||||
finish_reason: null,
|
||||
},
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
import { logger } from "../../../../../logger";
|
||||
import { MistralAIStreamEvent, SSEResponseTransformArgs } from "../index";
|
||||
import { parseEvent, ServerSentEvent } from "../parse-sse";
|
||||
|
||||
const log = logger.child({
|
||||
module: "sse-transformer",
|
||||
transformer: "mistral-ai-to-openai",
|
||||
});
|
||||
|
||||
export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => {
|
||||
const { data } = params;
|
||||
|
||||
const rawEvent = parseEvent(data);
|
||||
if (!rawEvent.data || rawEvent.data === "[DONE]") {
|
||||
return { position: -1 };
|
||||
}
|
||||
|
||||
const completionEvent = asCompletion(rawEvent);
|
||||
if (!completionEvent) {
|
||||
return { position: -1 };
|
||||
}
|
||||
|
||||
if ("choices" in completionEvent) {
|
||||
const newChatEvent = {
|
||||
id: params.fallbackId,
|
||||
object: "chat.completion.chunk" as const,
|
||||
created: Date.now(),
|
||||
model: params.fallbackModel,
|
||||
choices: [
|
||||
{
|
||||
index: completionEvent.choices[0].index,
|
||||
delta: { content: completionEvent.choices[0].message.content },
|
||||
finish_reason: completionEvent.choices[0].stop_reason,
|
||||
},
|
||||
],
|
||||
};
|
||||
return { position: -1, event: newChatEvent };
|
||||
} else if ("outputs" in completionEvent) {
|
||||
const newTextEvent = {
|
||||
id: params.fallbackId,
|
||||
object: "chat.completion.chunk" as const,
|
||||
created: Date.now(),
|
||||
model: params.fallbackModel,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
delta: { content: completionEvent.outputs[0].text },
|
||||
finish_reason: completionEvent.outputs[0].stop_reason,
|
||||
},
|
||||
],
|
||||
};
|
||||
return { position: -1, event: newTextEvent };
|
||||
}
|
||||
|
||||
// should never happen
|
||||
return { position: -1 };
|
||||
};
|
||||
|
||||
function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null {
|
||||
try {
|
||||
const parsed = JSON.parse(event.data);
|
||||
if (
|
||||
(Array.isArray(parsed.choices) &&
|
||||
parsed.choices[0].message !== undefined) ||
|
||||
(Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined)
|
||||
) {
|
||||
return parsed;
|
||||
} else {
|
||||
// noinspection ExceptionCaughtLocallyJS
|
||||
throw new Error("Missing required fields");
|
||||
}
|
||||
} catch (error) {
|
||||
log.warn({ error: error.stack, event }, "Received invalid data event");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
import {
|
||||
MistralChatCompletionEvent,
|
||||
MistralTextCompletionEvent,
|
||||
StreamingCompletionTransformer,
|
||||
} from "../index";
|
||||
import { parseEvent, ServerSentEvent } from "../parse-sse";
|
||||
import { logger } from "../../../../../logger";
|
||||
|
||||
const log = logger.child({
|
||||
module: "sse-transformer",
|
||||
transformer: "mistral-text-to-mistral-chat",
|
||||
});
|
||||
|
||||
/**
|
||||
* Transforms an incoming Mistral Text SSE to an equivalent Mistral Chat SSE.
|
||||
* This is generally used when a client sends a Mistral Chat prompt, but we
|
||||
* convert it to Mistral Text before sending it to the API to work around
|
||||
* some bugs in Mistral/AWS prompt templating. In these cases we need to convert
|
||||
* the response back to Mistral Chat.
|
||||
*/
|
||||
export const mistralTextToMistralChat: StreamingCompletionTransformer<
|
||||
MistralChatCompletionEvent
|
||||
> = (params) => {
|
||||
const { data } = params;
|
||||
|
||||
const rawEvent = parseEvent(data);
|
||||
if (!rawEvent.data) {
|
||||
return { position: -1 };
|
||||
}
|
||||
|
||||
const textCompletion = asTextCompletion(rawEvent);
|
||||
if (!textCompletion) {
|
||||
return { position: -1 };
|
||||
}
|
||||
|
||||
const chatEvent: MistralChatCompletionEvent = {
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: { role: "assistant", content: textCompletion.outputs[0].text },
|
||||
stop_reason: textCompletion.outputs[0].stop_reason,
|
||||
},
|
||||
],
|
||||
};
|
||||
return { position: -1, event: chatEvent };
|
||||
};
|
||||
|
||||
function asTextCompletion(
|
||||
event: ServerSentEvent
|
||||
): MistralTextCompletionEvent | null {
|
||||
try {
|
||||
const parsed = JSON.parse(event.data);
|
||||
if (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) {
|
||||
return parsed;
|
||||
} else {
|
||||
// noinspection ExceptionCaughtLocallyJS
|
||||
throw new Error("Missing required fields");
|
||||
}
|
||||
} catch (error: any) {
|
||||
log.warn({ error: error.stack, event }, "Received invalid data event");
|
||||
}
|
||||
return null;
|
||||
}
|
||||
+19
-79
@@ -1,4 +1,4 @@
|
||||
import express, { Request, RequestHandler, Router } from "express";
|
||||
import { RequestHandler, Router } from "express";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { config } from "../config";
|
||||
import { keyPool } from "../shared/key-management";
|
||||
@@ -21,48 +21,28 @@ import {
|
||||
createOnProxyResHandler,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
import { BadRequestError } from "../shared/errors";
|
||||
|
||||
// Mistral can't settle on a single naming scheme and deprecates models within
|
||||
// months of releasing them so this list is hard to keep up to date. 2024-07-28
|
||||
// https://docs.mistral.ai/platform/endpoints
|
||||
export const KNOWN_MISTRAL_AI_MODELS = [
|
||||
/*
|
||||
Mistral Nemo
|
||||
"A 12B model built with the partnership with Nvidia. It is easy to use and a
|
||||
drop-in replacement in any system using Mistral 7B that it supersedes."
|
||||
*/
|
||||
"open-mistral-nemo",
|
||||
"open-mistral-nemo-2407",
|
||||
/*
|
||||
Mistral Large
|
||||
"Our flagship model with state-of-the-art reasoning, knowledge, and coding
|
||||
capabilities."
|
||||
*/
|
||||
"mistral-large-latest",
|
||||
"mistral-large-2407",
|
||||
"mistral-large-2402", // deprecated
|
||||
/*
|
||||
Codestral
|
||||
"A cutting-edge generative model that has been specifically designed and
|
||||
optimized for code generation tasks, including fill-in-the-middle and code
|
||||
completion."
|
||||
note: this uses a separate bidi completion endpoint that is not implemented
|
||||
*/
|
||||
"codestral-latest",
|
||||
"codestral-2405",
|
||||
/* So-called "Research Models" */
|
||||
// Mistral 7b (open weight, legacy)
|
||||
"open-mistral-7b",
|
||||
"mistral-tiny-2312",
|
||||
// Mixtral 8x7b (open weight, legacy)
|
||||
"open-mixtral-8x7b",
|
||||
"open-mistral-8x22b",
|
||||
"open-codestral-mamba",
|
||||
/* Deprecated production models */
|
||||
"mistral-small-2312",
|
||||
// Mixtral Small (newer 8x7b, closed weight)
|
||||
"mistral-small-latest",
|
||||
"mistral-small-2402",
|
||||
// Mistral Medium
|
||||
"mistral-medium-latest",
|
||||
"mistral-medium-2312",
|
||||
// Mistral Large
|
||||
"mistral-large-latest",
|
||||
"mistral-large-2402",
|
||||
// Deprecated identifiers (2024-05-01)
|
||||
"mistral-tiny",
|
||||
"mistral-tiny-2312",
|
||||
"mistral-small",
|
||||
"mistral-medium",
|
||||
];
|
||||
|
||||
let modelsCache: any = null;
|
||||
@@ -109,24 +89,9 @@ const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
let newBody = body;
|
||||
if (req.inboundApi === "mistral-text" && req.outboundApi === "mistral-ai") {
|
||||
newBody = transformMistralTextToMistralChat(body);
|
||||
}
|
||||
|
||||
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||
res.status(200).json({ ...body, proxy: body.proxy });
|
||||
};
|
||||
|
||||
export function transformMistralTextToMistralChat(textBody: any) {
|
||||
return {
|
||||
...textBody,
|
||||
choices: [
|
||||
{ message: { content: textBody.outputs[0].text, role: "assistant" } },
|
||||
],
|
||||
outputs: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
const mistralAIProxy = createQueueMiddleware({
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "https://api.mistral.ai",
|
||||
@@ -149,37 +114,12 @@ mistralAIRouter.get("/v1/models", handleModelRequest);
|
||||
mistralAIRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{
|
||||
inApi: "mistral-ai",
|
||||
outApi: "mistral-ai",
|
||||
service: "mistral-ai",
|
||||
},
|
||||
{ beforeTransform: [detectMistralInputApi] }
|
||||
),
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "mistral-ai",
|
||||
outApi: "mistral-ai",
|
||||
service: "mistral-ai",
|
||||
}),
|
||||
mistralAIProxy
|
||||
);
|
||||
|
||||
/**
|
||||
* We can't determine if a request is Mistral text or chat just from the path
|
||||
* because they both use the same endpoint. We need to check the request body
|
||||
* for either `messages` or `prompt`.
|
||||
* @param req
|
||||
*/
|
||||
export function detectMistralInputApi(req: Request) {
|
||||
const { messages, prompt } = req.body;
|
||||
if (messages) {
|
||||
req.inboundApi = "mistral-ai";
|
||||
req.outboundApi = "mistral-ai";
|
||||
} else if (prompt && req.service === "mistral-ai") {
|
||||
// Mistral La Plateforme doesn't expose a text completions endpoint.
|
||||
throw new BadRequestError(
|
||||
"Mistral (via La Plateforme API) does not support text completions. This format is only supported on Mistral via the AWS API."
|
||||
);
|
||||
} else if (prompt && req.service === "aws") {
|
||||
req.inboundApi = "mistral-text";
|
||||
req.outboundApi = "mistral-text";
|
||||
}
|
||||
}
|
||||
|
||||
export const mistralAI = mistralAIRouter;
|
||||
|
||||
+10
-26
@@ -28,44 +28,28 @@ import {
|
||||
|
||||
// https://platform.openai.com/docs/models/overview
|
||||
export const KNOWN_OPENAI_MODELS = [
|
||||
// GPT4o
|
||||
"gpt-4o",
|
||||
"gpt-4o-2024-05-13",
|
||||
"gpt-4o-2024-08-06",
|
||||
// GPT4o Mini
|
||||
"gpt-4o-mini",
|
||||
"gpt-4o-mini-2024-07-18",
|
||||
// GPT4o (ChatGPT)
|
||||
"chatgpt-4o-latest",
|
||||
// GPT4 Turbo (superceded by GPT4o)
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo", // alias for latest gpt4-turbo stable
|
||||
"gpt-4-turbo-2024-04-09", // gpt4-turbo stable, with vision
|
||||
"gpt-4-turbo-preview", // alias for latest turbo preview
|
||||
"gpt-4-0125-preview", // gpt4-turbo preview 2
|
||||
"gpt-4-1106-preview", // gpt4-turbo preview 1
|
||||
// Launch GPT4
|
||||
"gpt-4-vision-preview", // gpt4-turbo preview 1 with vision
|
||||
"gpt-4",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-0314", // legacy
|
||||
// GPT3.5 Turbo (superceded by GPT4o Mini)
|
||||
"gpt-4-0314", // EOL 2024-06-13
|
||||
"gpt-4-32k",
|
||||
"gpt-4-32k-0314", // EOL 2024-06-13
|
||||
"gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0125", // latest turbo
|
||||
"gpt-3.5-turbo-1106", // older turbo
|
||||
// Text Completion
|
||||
"gpt-3.5-turbo-0301", // EOL 2024-06-13
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-3.5-turbo-instruct-0914",
|
||||
// Embeddings
|
||||
"text-embedding-ada-002",
|
||||
// Known deprecated models
|
||||
"gpt-4-32k", // alias for 0613
|
||||
"gpt-4-32k-0314", // EOL 2025-06-06
|
||||
"gpt-4-32k-0613", // EOL 2025-06-06
|
||||
"gpt-4-vision-preview", // EOL 2024-12-06
|
||||
"gpt-4-1106-vision-preview", // EOL 2024-12-06
|
||||
"gpt-3.5-turbo-0613", // EOL 2024-09-13
|
||||
"gpt-3.5-turbo-0301", // not on the website anymore, maybe unavailable
|
||||
"gpt-3.5-turbo-16k", // alias for 0613
|
||||
"gpt-3.5-turbo-16k-0613", // EOL 2024-09-13
|
||||
];
|
||||
|
||||
let modelsCache: any = null;
|
||||
|
||||
+50
-14
@@ -22,7 +22,7 @@ import {
|
||||
} from "../shared/models";
|
||||
import { initializeSseStream } from "../shared/streaming";
|
||||
import { logger } from "../logger";
|
||||
import { getUniqueIps } from "./rate-limit";
|
||||
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
|
||||
import { RequestPreprocessor } from "./middleware/request";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||
@@ -30,12 +30,10 @@ import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||
const queue: Request[] = [];
|
||||
const log = logger.child({ module: "request-queue" });
|
||||
|
||||
/** Maximum number of queue slots for individual users. */
|
||||
const USER_CONCURRENCY_LIMIT = parseInt(
|
||||
process.env.USER_CONCURRENCY_LIMIT ?? "1"
|
||||
);
|
||||
/** Maximum number of queue slots for Agnai.chat requests. */
|
||||
const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5;
|
||||
const AGNAI_CONCURRENCY_LIMIT = 5;
|
||||
/** Maximum number of queue slots for individual users. */
|
||||
const USER_CONCURRENCY_LIMIT = 1;
|
||||
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
|
||||
const MAX_HEARTBEAT_SIZE =
|
||||
1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
|
||||
@@ -60,20 +58,39 @@ const QUEUE_JOIN_TIMEOUT = 5000;
|
||||
function getIdentifier(req: Request) {
|
||||
if (req.user) return req.user.token;
|
||||
if (req.risuToken) return req.risuToken;
|
||||
// if (isFromSharedIp(req)) return "shared-ip";
|
||||
if (isFromSharedIp(req)) return "shared-ip";
|
||||
return req.ip;
|
||||
}
|
||||
|
||||
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
|
||||
getIdentifier(queued) === getIdentifier(incoming);
|
||||
|
||||
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
|
||||
|
||||
async function enqueue(req: Request) {
|
||||
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
|
||||
let isGuest = req.user?.token === undefined;
|
||||
|
||||
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) {
|
||||
throw new TooManyRequestsError(
|
||||
"Your IP or user token already has another request in the queue."
|
||||
);
|
||||
// Requests from shared IP addresses such as Agnai.chat are exempt from IP-
|
||||
// based rate limiting but can only occupy a certain number of slots in the
|
||||
// queue. Authenticated users always get a single spot in the queue.
|
||||
const isSharedIp = isFromSharedIp(req);
|
||||
const maxConcurrentQueuedRequests =
|
||||
isGuest && isSharedIp ? AGNAI_CONCURRENCY_LIMIT : USER_CONCURRENCY_LIMIT;
|
||||
if (enqueuedRequestCount >= maxConcurrentQueuedRequests) {
|
||||
if (isSharedIp) {
|
||||
// Re-enqueued requests are not counted towards the limit since they
|
||||
// already made it through the queue once.
|
||||
if (req.retryCount === 0) {
|
||||
throw new TooManyRequestsError(
|
||||
"Too many agnai.chat requests are already queued"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
throw new TooManyRequestsError(
|
||||
"Your IP or user token already has another request in the queue."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// shitty hack to remove hpm's event listeners on retried requests
|
||||
@@ -129,7 +146,19 @@ export async function reenqueueRequest(req: Request) {
|
||||
}
|
||||
|
||||
function getQueueForPartition(partition: ModelFamily): Request[] {
|
||||
return queue.filter((req) => getModelFamilyForRequest(req) === partition);
|
||||
return queue
|
||||
.filter((req) => getModelFamilyForRequest(req) === partition)
|
||||
.sort((a, b) => {
|
||||
// Certain requests are exempted from IP-based rate limiting because they
|
||||
// come from a shared IP address. To prevent these requests from starving
|
||||
// out other requests during periods of high traffic, we sort them to the
|
||||
// end of the queue.
|
||||
const aIsExempted = isFromSharedIp(a);
|
||||
const bIsExempted = isFromSharedIp(b);
|
||||
if (aIsExempted && !bIsExempted) return 1;
|
||||
if (!aIsExempted && bIsExempted) return -1;
|
||||
return 0;
|
||||
});
|
||||
}
|
||||
|
||||
export function dequeue(partition: ModelFamily): Request | undefined {
|
||||
@@ -232,6 +261,7 @@ let waitTimes: {
|
||||
partition: ModelFamily;
|
||||
start: number;
|
||||
end: number;
|
||||
isDeprioritized: boolean;
|
||||
}[] = [];
|
||||
|
||||
/** Adds a successful request to the list of wait times. */
|
||||
@@ -240,6 +270,7 @@ export function trackWaitTime(req: Request) {
|
||||
partition: getModelFamilyForRequest(req),
|
||||
start: req.startTime!,
|
||||
end: req.queueOutTime ?? Date.now(),
|
||||
isDeprioritized: isFromSharedIp(req),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -265,7 +296,8 @@ function calculateWaitTime(partition: ModelFamily) {
|
||||
.filter((wait) => {
|
||||
const isSamePartition = wait.partition === partition;
|
||||
const isRecent = now - wait.end < 300 * 1000;
|
||||
return isSamePartition && isRecent;
|
||||
const isNormalPriority = !wait.isDeprioritized;
|
||||
return isSamePartition && isRecent && isNormalPriority;
|
||||
})
|
||||
.map((wait) => wait.end - wait.start);
|
||||
const recentAverage = recentWaits.length
|
||||
@@ -279,7 +311,11 @@ function calculateWaitTime(partition: ModelFamily) {
|
||||
);
|
||||
|
||||
const currentWaits = queue
|
||||
.filter((req) => getModelFamilyForRequest(req) === partition)
|
||||
.filter((req) => {
|
||||
const isSamePartition = getModelFamilyForRequest(req) === partition;
|
||||
const isNormalPriority = !isFromSharedIp(req);
|
||||
return isSamePartition && isNormalPriority;
|
||||
})
|
||||
.map((req) => now - req.startTime!);
|
||||
const longestCurrentWait = Math.max(...currentWaits, 0);
|
||||
|
||||
|
||||
+32
-15
@@ -1,6 +1,14 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { config } from "../config";
|
||||
|
||||
export const SHARED_IP_ADDRESSES = new Set([
|
||||
// Agnai.chat
|
||||
"157.230.249.32", // old
|
||||
"157.245.148.56",
|
||||
"174.138.29.50",
|
||||
"209.97.162.44",
|
||||
]);
|
||||
|
||||
const ONE_MINUTE_MS = 60 * 1000;
|
||||
|
||||
type Timestamp = number;
|
||||
@@ -12,10 +20,7 @@ const exemptedRequests: Timestamp[] = [];
|
||||
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
|
||||
attempt > now - ONE_MINUTE_MS;
|
||||
|
||||
/**
|
||||
* Returns duration in seconds to wait before retrying for Retry-After header.
|
||||
*/
|
||||
const getRetryAfter = (ip: string, type: "text" | "image") => {
|
||||
const getTryAgainInMs = (ip: string, type: "text" | "image") => {
|
||||
const now = Date.now();
|
||||
const attempts = lastAttempts.get(ip) || [];
|
||||
const validAttempts = attempts.filter(isRecentAttempt(now));
|
||||
@@ -24,7 +29,7 @@ const getRetryAfter = (ip: string, type: "text" | "image") => {
|
||||
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
|
||||
|
||||
if (validAttempts.length >= limit) {
|
||||
return (validAttempts[0] - now + ONE_MINUTE_MS) / 1000;
|
||||
return validAttempts[0] - now + ONE_MINUTE_MS;
|
||||
} else {
|
||||
lastAttempts.set(ip, [...validAttempts, now]);
|
||||
return 0;
|
||||
@@ -91,11 +96,22 @@ export const ipLimiter = async (
|
||||
if (!textLimit && !imageLimit) return next();
|
||||
if (req.user?.type === "special") return next();
|
||||
|
||||
const path = req.baseUrl + req.path;
|
||||
const type =
|
||||
path.includes("openai-image") || path.includes("images/generations")
|
||||
? "image"
|
||||
: "text";
|
||||
// Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
|
||||
// by many users. Instead, the request queue will limit the number of such
|
||||
// requests that may wait in the queue at a time, and sorts them to the end to
|
||||
// let individual users go first.
|
||||
if (SHARED_IP_ADDRESSES.has(req.ip)) {
|
||||
exemptedRequests.push(Date.now());
|
||||
req.log.info(
|
||||
{ ip: req.ip, recentExemptions: exemptedRequests.length },
|
||||
"Exempting Agnai request from rate limiting."
|
||||
);
|
||||
return next();
|
||||
}
|
||||
|
||||
const type = (req.baseUrl + req.path).includes("openai-image")
|
||||
? "image"
|
||||
: "text";
|
||||
const limit = type === "image" ? imageLimit : textLimit;
|
||||
|
||||
// If user is authenticated, key rate limiting by their token. Otherwise, key
|
||||
@@ -107,14 +123,15 @@ export const ipLimiter = async (
|
||||
res.set("X-RateLimit-Remaining", remaining.toString());
|
||||
res.set("X-RateLimit-Reset", reset.toString());
|
||||
|
||||
const retryAfterTime = getRetryAfter(rateLimitKey, type);
|
||||
if (retryAfterTime > 0) {
|
||||
const waitSec = Math.ceil(retryAfterTime).toString();
|
||||
res.set("Retry-After", waitSec);
|
||||
const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
|
||||
if (tryAgainInMs > 0) {
|
||||
res.set("Retry-After", tryAgainInMs.toString());
|
||||
res.status(429).json({
|
||||
error: {
|
||||
type: "proxy_rate_limited",
|
||||
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${waitSec} seconds.`,
|
||||
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
|
||||
tryAgainInMs / 1000
|
||||
)} seconds.`,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
|
||||
+19
-25
@@ -1,55 +1,42 @@
|
||||
import express from "express";
|
||||
import { addV1 } from "./add-v1";
|
||||
import { anthropic } from "./anthropic";
|
||||
import { aws } from "./aws";
|
||||
import { azure } from "./azure";
|
||||
import { checkRisuToken } from "./check-risu-token";
|
||||
import express, { Request, Response, NextFunction } from "express";
|
||||
import { gatekeeper } from "./gatekeeper";
|
||||
import { gcp } from "./gcp";
|
||||
import { googleAI } from "./google-ai";
|
||||
import { mistralAI } from "./mistral-ai";
|
||||
import { checkRisuToken } from "./check-risu-token";
|
||||
import { openai } from "./openai";
|
||||
import { openaiImage } from "./openai-image";
|
||||
import { anthropic } from "./anthropic";
|
||||
import { googleAI } from "./google-ai";
|
||||
import { mistralAI } from "./mistral-ai";
|
||||
import { aws } from "./aws";
|
||||
import { azure } from "./azure";
|
||||
import { sendErrorToClient } from "./middleware/response/error-generator";
|
||||
|
||||
const proxyRouter = express.Router();
|
||||
|
||||
// Remove `expect: 100-continue` header from requests due to incompatibility
|
||||
// with node-http-proxy.
|
||||
proxyRouter.use((req, _res, next) => {
|
||||
if (req.headers.expect) {
|
||||
// node-http-proxy does not like it when clients send `expect: 100-continue`
|
||||
// and will stall. none of the upstream APIs use this header anyway.
|
||||
delete req.headers.expect;
|
||||
}
|
||||
next();
|
||||
});
|
||||
|
||||
// Apply body parsers.
|
||||
proxyRouter.use(
|
||||
express.json({ limit: "100mb" }),
|
||||
express.urlencoded({ extended: true, limit: "100mb" })
|
||||
);
|
||||
|
||||
// Apply auth/rate limits.
|
||||
proxyRouter.use(gatekeeper);
|
||||
proxyRouter.use(checkRisuToken);
|
||||
|
||||
// Initialize request queue metadata.
|
||||
proxyRouter.use((req, _res, next) => {
|
||||
req.startTime = Date.now();
|
||||
req.retryCount = 0;
|
||||
next();
|
||||
});
|
||||
|
||||
// Proxy endpoints.
|
||||
proxyRouter.use("/openai", addV1, openai);
|
||||
proxyRouter.use("/openai-image", addV1, openaiImage);
|
||||
proxyRouter.use("/anthropic", addV1, anthropic);
|
||||
proxyRouter.use("/google-ai", addV1, googleAI);
|
||||
proxyRouter.use("/mistral-ai", addV1, mistralAI);
|
||||
proxyRouter.use("/aws", aws);
|
||||
proxyRouter.use("/gcp/claude", addV1, gcp);
|
||||
proxyRouter.use("/aws/claude", addV1, aws);
|
||||
proxyRouter.use("/azure/openai", addV1, azure);
|
||||
|
||||
// Redirect browser requests to the homepage.
|
||||
proxyRouter.get("*", (req, res, next) => {
|
||||
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
|
||||
@@ -59,8 +46,7 @@ proxyRouter.get("*", (req, res, next) => {
|
||||
next();
|
||||
}
|
||||
});
|
||||
|
||||
// Send a fake client error if user specifies an invalid proxy endpoint.
|
||||
// Handle 404s.
|
||||
proxyRouter.use((req, res) => {
|
||||
sendErrorToClient({
|
||||
req,
|
||||
@@ -81,3 +67,11 @@ proxyRouter.use((req, res) => {
|
||||
});
|
||||
|
||||
export { proxyRouter as proxyRouter };
|
||||
|
||||
function addV1(req: Request, res: Response, next: NextFunction) {
|
||||
// Clients don't consistently use the /v1 prefix so we'll add it for them.
|
||||
if (!req.path.startsWith("/v1/")) {
|
||||
req.url = `/v1${req.url}`;
|
||||
}
|
||||
next();
|
||||
}
|
||||
|
||||
@@ -49,7 +49,6 @@ app.use(
|
||||
// Don't log the prompt text on transform errors
|
||||
"body.messages",
|
||||
"body.prompt",
|
||||
"body.contents",
|
||||
],
|
||||
censor: "********",
|
||||
},
|
||||
@@ -88,15 +87,6 @@ app.use(blacklist);
|
||||
app.use(checkOrigin);
|
||||
|
||||
app.use("/admin", adminRouter);
|
||||
app.use((req, _, next) => {
|
||||
// For whatever reason SillyTavern just ignores the path a user provides
|
||||
// when using Google AI with reverse proxy. We'll fix it here.
|
||||
if (req.path.startsWith("/v1beta/models/")) {
|
||||
req.url = `${config.proxyEndpointRoute}/google-ai${req.url}`;
|
||||
return next();
|
||||
}
|
||||
next();
|
||||
});
|
||||
app.use(config.proxyEndpointRoute, proxyRouter);
|
||||
app.use("/user", userRouter);
|
||||
if (config.staticServiceInfo) {
|
||||
|
||||
+127
-118
@@ -2,7 +2,8 @@ import { config, listConfig } from "./config";
|
||||
import {
|
||||
AnthropicKey,
|
||||
AwsBedrockKey,
|
||||
GcpKey,
|
||||
AzureOpenAIKey,
|
||||
GoogleAIKey,
|
||||
keyPool,
|
||||
OpenAIKey,
|
||||
} from "./shared/key-management";
|
||||
@@ -10,7 +11,6 @@ import {
|
||||
AnthropicModelFamily,
|
||||
assertIsKnownModelFamily,
|
||||
AwsBedrockModelFamily,
|
||||
GcpModelFamily,
|
||||
AzureOpenAIModelFamily,
|
||||
GoogleAIModelFamily,
|
||||
LLM_SERVICES,
|
||||
@@ -24,16 +24,22 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
|
||||
import { getUniqueIps } from "./proxy/rate-limit";
|
||||
import { assertNever } from "./shared/utils";
|
||||
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
|
||||
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";
|
||||
|
||||
const CACHE_TTL = 2000;
|
||||
|
||||
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
|
||||
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
|
||||
k.service === "openai";
|
||||
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
|
||||
k.service === "azure";
|
||||
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
|
||||
k.service === "anthropic";
|
||||
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
|
||||
k.service === "google-ai";
|
||||
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
|
||||
k.service === "mistral-ai";
|
||||
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
|
||||
const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";
|
||||
|
||||
/** Stats aggregated across all keys for a given service. */
|
||||
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
|
||||
@@ -45,15 +51,10 @@ type ModelAggregates = {
|
||||
overQuota?: number;
|
||||
pozzed?: number;
|
||||
awsLogged?: number;
|
||||
// needed to disambugiate aws-claude family's variants
|
||||
awsClaude2?: number;
|
||||
awsSonnet3?: number;
|
||||
awsSonnet3_5?: number;
|
||||
awsHaiku: number;
|
||||
gcpSonnet?: number;
|
||||
gcpSonnet35?: number;
|
||||
gcpHaiku?: number;
|
||||
awsSonnet?: number;
|
||||
awsHaiku?: number;
|
||||
queued: number;
|
||||
queueTime: string;
|
||||
tokens: number;
|
||||
};
|
||||
/** All possible combinations of model family and aggregate type. */
|
||||
@@ -85,10 +86,8 @@ type AnthropicInfo = BaseFamilyInfo & {
|
||||
};
|
||||
type AwsInfo = BaseFamilyInfo & {
|
||||
privacy?: string;
|
||||
enabledVariants?: string;
|
||||
};
|
||||
type GcpInfo = BaseFamilyInfo & {
|
||||
enabledVariants?: string;
|
||||
sonnetKeys?: number;
|
||||
haikuKeys?: number;
|
||||
};
|
||||
|
||||
// prettier-ignore
|
||||
@@ -96,11 +95,12 @@ export type ServiceInfo = {
|
||||
uptime: number;
|
||||
endpoints: {
|
||||
openai?: string;
|
||||
openai2?: string;
|
||||
anthropic?: string;
|
||||
"anthropic-claude-3"?: string;
|
||||
"google-ai"?: string;
|
||||
"mistral-ai"?: string;
|
||||
"aws"?: string;
|
||||
gcp?: string;
|
||||
aws?: string;
|
||||
azure?: string;
|
||||
"openai-image"?: string;
|
||||
"azure-image"?: string;
|
||||
@@ -114,7 +114,6 @@ export type ServiceInfo = {
|
||||
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
|
||||
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
|
||||
& { [f in AwsBedrockModelFamily]?: AwsInfo }
|
||||
& { [f in GcpModelFamily]?: GcpInfo }
|
||||
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
|
||||
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
|
||||
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
|
||||
@@ -137,6 +136,7 @@ export type ServiceInfo = {
|
||||
const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
|
||||
openai: {
|
||||
openai: `%BASE%/openai`,
|
||||
openai2: `%BASE%/openai/turbo-instruct`,
|
||||
"openai-image": `%BASE%/openai-image`,
|
||||
},
|
||||
anthropic: {
|
||||
@@ -149,11 +149,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
|
||||
"mistral-ai": `%BASE%/mistral-ai`,
|
||||
},
|
||||
aws: {
|
||||
"aws-claude": `%BASE%/aws/claude`,
|
||||
"aws-mistral": `%BASE%/aws/mistral`,
|
||||
},
|
||||
gcp: {
|
||||
gcp: `%BASE%/gcp/claude`,
|
||||
aws: `%BASE%/aws/claude`,
|
||||
},
|
||||
azure: {
|
||||
azure: `%BASE%/azure/openai`,
|
||||
@@ -161,7 +157,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
|
||||
},
|
||||
};
|
||||
|
||||
const familyStats = new Map<ModelAggregateKey, number>();
|
||||
const modelStats = new Map<ModelAggregateKey, number>();
|
||||
const serviceStats = new Map<keyof AllStats, number>();
|
||||
|
||||
let cachedInfo: ServiceInfo | undefined;
|
||||
@@ -178,7 +174,7 @@ export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
|
||||
.concat("turbo")
|
||||
);
|
||||
|
||||
familyStats.clear();
|
||||
modelStats.clear();
|
||||
serviceStats.clear();
|
||||
keys.forEach(addKeyToAggregates);
|
||||
|
||||
@@ -297,102 +293,131 @@ function increment<T extends keyof AllStats | ModelAggregateKey>(
|
||||
) {
|
||||
map.set(key, (map.get(key) || 0) + delta);
|
||||
}
|
||||
const addToService = increment.bind(null, serviceStats);
|
||||
const addToFamily = increment.bind(null, familyStats);
|
||||
|
||||
function addKeyToAggregates(k: KeyPoolKey) {
|
||||
addToService("proompts", k.promptCount);
|
||||
addToService("openai__keys", k.service === "openai" ? 1 : 0);
|
||||
addToService("anthropic__keys", k.service === "anthropic" ? 1 : 0);
|
||||
addToService("google-ai__keys", k.service === "google-ai" ? 1 : 0);
|
||||
addToService("mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
|
||||
addToService("aws__keys", k.service === "aws" ? 1 : 0);
|
||||
addToService("gcp__keys", k.service === "gcp" ? 1 : 0);
|
||||
addToService("azure__keys", k.service === "azure" ? 1 : 0);
|
||||
increment(serviceStats, "proompts", k.promptCount);
|
||||
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
|
||||
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
|
||||
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
|
||||
increment(
|
||||
serviceStats,
|
||||
"mistral-ai__keys",
|
||||
k.service === "mistral-ai" ? 1 : 0
|
||||
);
|
||||
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
|
||||
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
|
||||
|
||||
let sumTokens = 0;
|
||||
let sumCost = 0;
|
||||
|
||||
const incrementGenericFamilyStats = (f: ModelFamily) => {
|
||||
const tokens = (k as any)[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
addToFamily(`${f}__tokens`, tokens);
|
||||
addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
addToFamily(`${f}__active`, k.isDisabled ? 0 : 1);
|
||||
};
|
||||
|
||||
switch (k.service) {
|
||||
case "openai":
|
||||
if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
|
||||
addToService("openai__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
|
||||
k.modelFamilies.forEach((f) => {
|
||||
incrementGenericFamilyStats(f);
|
||||
addToFamily(`${f}__trial`, k.isTrial ? 1 : 0);
|
||||
addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
|
||||
});
|
||||
break;
|
||||
case "anthropic":
|
||||
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
|
||||
addToService("anthropic__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
|
||||
k.modelFamilies.forEach((f) => {
|
||||
incrementGenericFamilyStats(f);
|
||||
addToFamily(`${f}__trial`, k.tier === "free" ? 1 : 0);
|
||||
addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
|
||||
addToFamily(`${f}__pozzed`, k.isPozzed ? 1 : 0);
|
||||
});
|
||||
break;
|
||||
increment(
|
||||
serviceStats,
|
||||
"openai__uncheckedKeys",
|
||||
Boolean(k.lastChecked) ? 0 : 1
|
||||
);
|
||||
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
|
||||
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
|
||||
});
|
||||
break;
|
||||
case "azure":
|
||||
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
});
|
||||
break;
|
||||
case "anthropic": {
|
||||
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__trial`, k.tier === "free" ? 1 : 0);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
|
||||
increment(modelStats, `${f}__pozzed`, k.isPozzed ? 1 : 0);
|
||||
});
|
||||
increment(
|
||||
serviceStats,
|
||||
"anthropic__uncheckedKeys",
|
||||
Boolean(k.lastChecked) ? 0 : 1
|
||||
);
|
||||
break;
|
||||
}
|
||||
case "google-ai": {
|
||||
if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
|
||||
const family = "gemini-pro";
|
||||
sumTokens += k["gemini-proTokens"];
|
||||
sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
|
||||
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
|
||||
break;
|
||||
}
|
||||
case "mistral-ai": {
|
||||
if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "aws": {
|
||||
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
|
||||
k.modelFamilies.forEach(incrementGenericFamilyStats);
|
||||
if (!k.isDisabled) {
|
||||
// Don't add revoked keys to available AWS variants
|
||||
k.modelIds.forEach((id) => {
|
||||
if (id.includes("claude-3-sonnet")) {
|
||||
addToFamily(`aws-claude__awsSonnet3`, 1);
|
||||
} else if (id.includes("claude-3-5-sonnet")) {
|
||||
addToFamily(`aws-claude__awsSonnet3_5`, 1);
|
||||
} else if (id.includes("claude-3-haiku")) {
|
||||
addToFamily(`aws-claude__awsHaiku`, 1);
|
||||
} else if (id.includes("claude-v2")) {
|
||||
addToFamily(`aws-claude__awsClaude2`, 1);
|
||||
}
|
||||
});
|
||||
}
|
||||
k.modelFamilies.forEach((f) => {
|
||||
const tokens = k[`${f}Tokens`];
|
||||
sumTokens += tokens;
|
||||
sumCost += getTokenCostUsd(f, tokens);
|
||||
increment(modelStats, `${f}__tokens`, tokens);
|
||||
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
|
||||
});
|
||||
increment(modelStats, `aws-claude__awsSonnet`, k.sonnetEnabled ? 1 : 0);
|
||||
increment(modelStats, `aws-claude__awsHaiku`, k.haikuEnabled ? 1 : 0);
|
||||
|
||||
// Ignore revoked keys for aws logging stats, but include keys where the
|
||||
// logging status is unknown.
|
||||
const countAsLogged =
|
||||
k.lastChecked && !k.isDisabled && k.awsLoggingStatus === "enabled";
|
||||
addToFamily(`aws-claude__awsLogged`, countAsLogged ? 1 : 0);
|
||||
increment(modelStats, `aws-claude__awsLogged`, countAsLogged ? 1 : 0);
|
||||
break;
|
||||
}
|
||||
case "gcp":
|
||||
if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
|
||||
k.modelFamilies.forEach(incrementGenericFamilyStats);
|
||||
// TODO: add modelIds to GcpKey
|
||||
break;
|
||||
// These services don't have any additional stats to track.
|
||||
case "azure":
|
||||
case "google-ai":
|
||||
case "mistral-ai":
|
||||
k.modelFamilies.forEach(incrementGenericFamilyStats);
|
||||
break;
|
||||
default:
|
||||
assertNever(k.service);
|
||||
}
|
||||
|
||||
addToService("tokens", sumTokens);
|
||||
addToService("tokenCost", sumCost);
|
||||
increment(serviceStats, "tokens", sumTokens);
|
||||
increment(serviceStats, "tokenCost", sumCost);
|
||||
}
|
||||
|
||||
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
||||
const tokens = familyStats.get(`${family}__tokens`) || 0;
|
||||
const tokens = modelStats.get(`${family}__tokens`) || 0;
|
||||
const cost = getTokenCostUsd(family, tokens);
|
||||
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
|
||||
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
|
||||
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
|
||||
activeKeys: familyStats.get(`${family}__active`) || 0,
|
||||
revokedKeys: familyStats.get(`${family}__revoked`) || 0,
|
||||
activeKeys: modelStats.get(`${family}__active`) || 0,
|
||||
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
|
||||
};
|
||||
|
||||
// Add service-specific stats to the info object.
|
||||
@@ -400,8 +425,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
||||
const service = MODEL_FAMILY_SERVICE[family];
|
||||
switch (service) {
|
||||
case "openai":
|
||||
info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
|
||||
info.trialKeys = familyStats.get(`${family}__trial`) || 0;
|
||||
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
|
||||
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
|
||||
|
||||
// Delete trial/revoked keys for non-turbo families.
|
||||
// Trials are turbo 99% of the time, and if a key is invalid we don't
|
||||
@@ -412,25 +437,15 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
||||
}
|
||||
break;
|
||||
case "anthropic":
|
||||
info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
|
||||
info.trialKeys = familyStats.get(`${family}__trial`) || 0;
|
||||
info.prefilledKeys = familyStats.get(`${family}__pozzed`) || 0;
|
||||
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
|
||||
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
|
||||
info.prefilledKeys = modelStats.get(`${family}__pozzed`) || 0;
|
||||
break;
|
||||
case "aws":
|
||||
if (family === "aws-claude") {
|
||||
const logged = familyStats.get(`${family}__awsLogged`) || 0;
|
||||
const variants = new Set<string>();
|
||||
if (familyStats.get(`${family}__awsClaude2`) || 0)
|
||||
variants.add("claude2");
|
||||
if (familyStats.get(`${family}__awsSonnet3`) || 0)
|
||||
variants.add("sonnet3");
|
||||
if (familyStats.get(`${family}__awsSonnet3_5`) || 0)
|
||||
variants.add("sonnet3.5");
|
||||
if (familyStats.get(`${family}__awsHaiku`) || 0)
|
||||
variants.add("haiku");
|
||||
info.enabledVariants = variants.size
|
||||
? `${Array.from(variants).join(",")}`
|
||||
: undefined;
|
||||
info.sonnetKeys = modelStats.get(`${family}__awsSonnet`) || 0;
|
||||
info.haikuKeys = modelStats.get(`${family}__awsHaiku`) || 0;
|
||||
const logged = modelStats.get(`${family}__awsLogged`) || 0;
|
||||
if (logged > 0) {
|
||||
info.privacy = config.allowAwsLogging
|
||||
? `AWS logging verification inactive. Prompts could be logged.`
|
||||
@@ -438,12 +453,6 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
|
||||
}
|
||||
}
|
||||
break;
|
||||
case "gcp":
|
||||
if (family === "gcp-claude") {
|
||||
// TODO: implement
|
||||
info.enabledVariants = "not implemented";
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,12 +19,7 @@ const AnthropicV1BaseSchema = z
|
||||
top_k: z.coerce.number().optional(),
|
||||
top_p: z.coerce.number().optional(),
|
||||
metadata: z.object({ user_id: z.string().optional() }).optional(),
|
||||
tools: z.array(z.any()).optional(),
|
||||
tool_choice: z.any().optional(),
|
||||
})
|
||||
.omit(
|
||||
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, tool_choice: true }
|
||||
)
|
||||
.strip();
|
||||
|
||||
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
|
||||
@@ -49,18 +44,6 @@ const AnthropicV1MessageMultimodalContentSchema = z.array(
|
||||
data: z.string(),
|
||||
}),
|
||||
}),
|
||||
z.object({
|
||||
type: z.literal("tool_use"),
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
input: z.object({}).passthrough(),
|
||||
}),
|
||||
z.object({
|
||||
type: z.literal("tool_result"),
|
||||
tool_use_id: z.string(),
|
||||
is_error: z.boolean().optional(),
|
||||
content: z.union([z.string(), z.object({}).passthrough()]).optional(),
|
||||
}),
|
||||
])
|
||||
);
|
||||
|
||||
@@ -80,12 +63,7 @@ export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
|
||||
.number()
|
||||
.int()
|
||||
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
|
||||
system: z
|
||||
.union([
|
||||
z.string(),
|
||||
z.array(z.object({ type: z.literal("text"), text: z.string() })),
|
||||
])
|
||||
.optional(),
|
||||
system: z.string().optional(),
|
||||
})
|
||||
);
|
||||
export type AnthropicChatMessage = z.infer<
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
import { z } from "zod";
|
||||
import {
|
||||
OPENAI_OUTPUT_MAX,
|
||||
OpenAIV1ChatCompletionSchema,
|
||||
flattenOpenAIMessageContent,
|
||||
} from "./openai";
|
||||
import { APIFormatTransformer } from ".";
|
||||
|
||||
// https://docs.cohere.com/reference/chat
|
||||
export const CohereV1ChatSchema = z
|
||||
.object({
|
||||
message: z.string(),
|
||||
model: z.string().default("command-r-plus"),
|
||||
stream: z.boolean().default(false).optional(),
|
||||
preamble: z.string().optional(),
|
||||
chat_history: z
|
||||
.array(
|
||||
// Either a message from a chat participant, or a past tool call
|
||||
z.union([
|
||||
z.object({
|
||||
role: z.enum(["CHATBOT", "SYSTEM", "USER"]),
|
||||
message: z.string(),
|
||||
tool_calls: z
|
||||
.array(z.object({ name: z.string(), parameters: z.any() }))
|
||||
.optional(),
|
||||
}),
|
||||
z.object({
|
||||
role: z.enum(["TOOL"]),
|
||||
tool_results: z.array(
|
||||
z.object({
|
||||
call: z.object({ name: z.string(), parameters: z.any() }),
|
||||
outputs: z.array(z.any()),
|
||||
})
|
||||
),
|
||||
}),
|
||||
])
|
||||
)
|
||||
.optional(),
|
||||
// Don't allow conversation_id as it causes calls to be stateful and we don't
|
||||
// offer guarantees about which key a user's request will be routed to.
|
||||
conversation_id: z.literal(undefined).optional(),
|
||||
prompt_truncation: z
|
||||
.enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"])
|
||||
.optional(),
|
||||
/*
|
||||
Supporting RAG is complex because documents can be arbitrary size and have
|
||||
to have embeddings generated, which incurs a cost that is not trivial to
|
||||
estimate. We don't support it for now.
|
||||
connectors: z
|
||||
.array(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
user_access_token: z.string().optional(),
|
||||
continue_on_failure: z.boolean().default(false).optional(),
|
||||
options: z.any().optional(),
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
search_queries_only: z.boolean().default(false).optional(),
|
||||
documents: z
|
||||
.array(
|
||||
z.object({
|
||||
id: z.string().optional(),
|
||||
title: z.string().optional(),
|
||||
text: z.string(),
|
||||
_excludes: z.array(z.string()).optional(),
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
citation_quality: z.enum(["accurate", "fast"]).optional(),
|
||||
*/
|
||||
temperature: z.number().default(0.3).optional(),
|
||||
max_tokens: z
|
||||
.number()
|
||||
.int()
|
||||
.nullish()
|
||||
.default(Math.min(OPENAI_OUTPUT_MAX, 4096))
|
||||
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
|
||||
max_input_tokens: z.number().int().optional(),
|
||||
k: z.number().int().min(0).max(500).default(0).optional(),
|
||||
p: z.number().min(0.01).max(0.99).default(0.75).optional(),
|
||||
seed: z.number().int().optional(),
|
||||
stop_sequences: z.array(z.string()).max(5).optional(),
|
||||
frequency_penalty: z.number().min(0).max(1).default(0).optional(),
|
||||
presence_penalty: z.number().min(0).max(1).default(0).optional(),
|
||||
tools: z
|
||||
.array(
|
||||
z.object({
|
||||
name: z.string(),
|
||||
description: z.string(),
|
||||
parameter_definitions: z.record(
|
||||
z.object({
|
||||
description: z.string().optional(),
|
||||
type: z.string(),
|
||||
required: z.boolean().optional().default(false),
|
||||
})
|
||||
),
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
tool_results: z
|
||||
.array(
|
||||
z.object({
|
||||
call: z.object({
|
||||
name: z.string(),
|
||||
parameters: z.record(z.any()),
|
||||
}),
|
||||
outputs: z.array(z.record(z.any())),
|
||||
})
|
||||
)
|
||||
.optional(),
|
||||
// We always force single step to avoid stateful calls or expensive multi-step
|
||||
// generations when tools are involved.
|
||||
force_single_step: z.literal(true).default(true).optional(),
|
||||
})
|
||||
.strip();
|
||||
export type CohereChatMessage = NonNullable<
|
||||
z.infer<typeof CohereV1ChatSchema>["chat_history"]
|
||||
>[number];
|
||||
|
||||
export function flattenCohereMessageContent(
|
||||
message: CohereChatMessage
|
||||
): string {
|
||||
return message.role === "TOOL"
|
||||
? message.tool_results.map((r) => r.outputs[0].text).join("\n")
|
||||
: message.message;
|
||||
}
|
||||
|
||||
export const transformOpenAIToCohere: APIFormatTransformer<
|
||||
typeof CohereV1ChatSchema
|
||||
> = async (req) => {
|
||||
const { body } = req;
|
||||
const result = OpenAIV1ChatCompletionSchema.safeParse({
|
||||
...body,
|
||||
model: "gpt-3.5-turbo",
|
||||
});
|
||||
if (!result.success) {
|
||||
req.log.warn(
|
||||
{ issues: result.error.issues, body },
|
||||
"Invalid OpenAI-to-Cohere request"
|
||||
);
|
||||
throw result.error;
|
||||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
// Final OAI message becomes the `message` field in Cohere
|
||||
const message = messages[messages.length - 1];
|
||||
// If the first message has system role, use it as preamble.
|
||||
const hasSystemPreamble = messages[0]?.role === "system";
|
||||
const preamble = hasSystemPreamble
|
||||
? flattenOpenAIMessageContent(messages[0].content)
|
||||
: undefined;
|
||||
|
||||
const chatHistory = messages.slice(0, -1).map((m) => {
|
||||
const role: Exclude<CohereChatMessage["role"], "TOOL"> =
|
||||
m.role === "assistant"
|
||||
? "CHATBOT"
|
||||
: m.role === "system"
|
||||
? "SYSTEM"
|
||||
: "USER";
|
||||
const content = flattenOpenAIMessageContent(m.content);
|
||||
const message = m.name ? `${m.name}: ${content}` : content;
|
||||
return { role, message };
|
||||
});
|
||||
|
||||
return {
|
||||
model: rest.model,
|
||||
preamble,
|
||||
chat_history: chatHistory,
|
||||
message: flattenOpenAIMessageContent(message.content),
|
||||
stop_sequences:
|
||||
typeof rest.stop === "string" ? [rest.stop] : rest.stop ?? undefined,
|
||||
max_tokens: rest.max_tokens,
|
||||
temperature: rest.temperature,
|
||||
p: rest.top_p,
|
||||
frequency_penalty: rest.frequency_penalty,
|
||||
presence_penalty: rest.presence_penalty,
|
||||
seed: rest.seed,
|
||||
stream: rest.stream,
|
||||
};
|
||||
};
|
||||
@@ -5,20 +5,19 @@ import {
|
||||
} from "./openai";
|
||||
import { APIFormatTransformer } from "./index";
|
||||
|
||||
const GoogleAIV1ContentSchema = z.object({
|
||||
parts: z.array(z.object({ text: z.string() })), // TODO: add other media types
|
||||
role: z.enum(["user", "model"]).optional(),
|
||||
});
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
|
||||
export const GoogleAIV1GenerateContentSchema = z
|
||||
.object({
|
||||
model: z.string().max(100), //actually specified in path but we need it for the router
|
||||
stream: z.boolean().optional().default(false), // also used for router
|
||||
contents: z.array(GoogleAIV1ContentSchema),
|
||||
contents: z.array(
|
||||
z.object({
|
||||
parts: z.array(z.object({ text: z.string() })),
|
||||
role: z.enum(["user", "model"]),
|
||||
})
|
||||
),
|
||||
tools: z.array(z.object({})).max(0).optional(),
|
||||
safetySettings: z.array(z.object({})).optional(),
|
||||
systemInstruction: GoogleAIV1ContentSchema.optional(),
|
||||
safetySettings: z.array(z.object({})).max(0).optional(),
|
||||
generationConfig: z.object({
|
||||
temperature: z.number().optional(),
|
||||
maxOutputTokens: z.coerce
|
||||
@@ -26,12 +25,12 @@ export const GoogleAIV1GenerateContentSchema = z
|
||||
.int()
|
||||
.optional()
|
||||
.default(16)
|
||||
.transform((v) => Math.min(v, 4096)), // TODO: Add config
|
||||
.transform((v) => Math.min(v, 1024)), // TODO: Add config
|
||||
candidateCount: z.literal(1).optional(),
|
||||
topP: z.number().optional(),
|
||||
topK: z.number().optional(),
|
||||
stopSequences: z.array(z.string().max(500)).max(5).optional(),
|
||||
}).default({}),
|
||||
}),
|
||||
})
|
||||
.strip();
|
||||
export type GoogleAIChatMessage = z.infer<
|
||||
@@ -104,7 +103,7 @@ export const transformOpenAIToGoogleAI: APIFormatTransformer<
|
||||
stops = [...new Set(stops)].slice(0, 5);
|
||||
|
||||
return {
|
||||
model: req.body.model,
|
||||
model: "gemini-pro",
|
||||
stream: rest.stream,
|
||||
contents,
|
||||
tools: [],
|
||||
|
||||
@@ -21,11 +21,8 @@ import {
|
||||
GoogleAIV1GenerateContentSchema,
|
||||
transformOpenAIToGoogleAI,
|
||||
} from "./google-ai";
|
||||
import {
|
||||
MistralAIV1ChatCompletionsSchema,
|
||||
MistralAIV1TextCompletionsSchema,
|
||||
transformMistralChatToText,
|
||||
} from "./mistral-ai";
|
||||
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
|
||||
import { CohereV1ChatSchema, transformOpenAIToCohere } from "./cohere";
|
||||
|
||||
export { OpenAIChatMessage } from "./openai";
|
||||
export {
|
||||
@@ -37,15 +34,29 @@ export {
|
||||
export { GoogleAIChatMessage } from "./google-ai";
|
||||
export { MistralAIChatMessage } from "./mistral-ai";
|
||||
|
||||
/** Represents a pair of API formats that can be transformed between. */
|
||||
type APIPair = `${APIFormat}->${APIFormat}`;
|
||||
/** Represents a map of API format pairs to transformer functions. */
|
||||
type TransformerMap = {
|
||||
[key in APIPair]?: APIFormatTransformer<any>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Represents a transformer function that takes a Request and returns a Promise
|
||||
* resolving to a value of the specified Zod schema type.
|
||||
*
|
||||
* @template Z The Zod schema type to transform the request into (from api-schemas).
|
||||
* @param req The incoming Request to transform.
|
||||
* @returns A Promise resolving to the transformed request body.
|
||||
*/
|
||||
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
|
||||
req: Request
|
||||
) => Promise<z.infer<Z>>;
|
||||
|
||||
/**
|
||||
* Specifies possible translations between API formats and the corresponding
|
||||
* transformer functions to apply them.
|
||||
*/
|
||||
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
|
||||
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
|
||||
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
|
||||
@@ -53,9 +64,12 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
|
||||
"openai->openai-text": transformOpenAIToOpenAIText,
|
||||
"openai->openai-image": transformOpenAIToOpenAIImage,
|
||||
"openai->google-ai": transformOpenAIToGoogleAI,
|
||||
"mistral-ai->mistral-text": transformMistralChatToText,
|
||||
"openai->cohere-chat": transformOpenAIToCohere,
|
||||
};
|
||||
|
||||
/**
|
||||
* Specifies the schema for each API format to validate incoming requests.
|
||||
*/
|
||||
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
||||
"anthropic-chat": AnthropicV1MessagesSchema,
|
||||
"anthropic-text": AnthropicV1TextSchema,
|
||||
@@ -64,5 +78,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
||||
"openai-image": OpenAIV1ImagesGenerationSchema,
|
||||
"google-ai": GoogleAIV1GenerateContentSchema,
|
||||
"mistral-ai": MistralAIV1ChatCompletionsSchema,
|
||||
"mistral-text": MistralAIV1TextCompletionsSchema,
|
||||
"cohere-chat": CohereV1ChatSchema,
|
||||
};
|
||||
|
||||
@@ -1,34 +1,15 @@
|
||||
import { z } from "zod";
|
||||
import { OPENAI_OUTPUT_MAX } from "./openai";
|
||||
import { Template } from "@huggingface/jinja";
|
||||
import { APIFormatTransformer } from "./index";
|
||||
import { logger } from "../../logger";
|
||||
|
||||
const MistralChatMessageSchema = z.object({
|
||||
role: z.enum(["system", "user", "assistant", "tool"]), // TODO: implement tools
|
||||
content: z.string(),
|
||||
prefix: z.boolean().optional(),
|
||||
});
|
||||
|
||||
const MistralMessagesSchema = z.array(MistralChatMessageSchema).refine(
|
||||
(input) => {
|
||||
const prefixIdx = input.findIndex((msg) => Boolean(msg.prefix));
|
||||
if (prefixIdx === -1) return true; // no prefix messages
|
||||
const lastIdx = input.length - 1;
|
||||
const lastMsg = input[lastIdx];
|
||||
return prefixIdx === lastIdx && lastMsg.role === "assistant";
|
||||
},
|
||||
{
|
||||
message:
|
||||
"`prefix` can only be set to `true` on the last message, and only for an assistant message.",
|
||||
}
|
||||
);
|
||||
|
||||
// https://docs.mistral.ai/api#operation/createChatCompletion
|
||||
const BaseMistralAIV1CompletionsSchema = z.object({
|
||||
export const MistralAIV1ChatCompletionsSchema = z.object({
|
||||
model: z.string(),
|
||||
messages: MistralMessagesSchema.optional(),
|
||||
prompt: z.string().optional(),
|
||||
messages: z.array(
|
||||
z.object({
|
||||
role: z.enum(["system", "user", "assistant"]),
|
||||
content: z.string(),
|
||||
})
|
||||
),
|
||||
temperature: z.number().optional().default(0.7),
|
||||
top_p: z.number().optional().default(1),
|
||||
max_tokens: z.coerce
|
||||
@@ -37,50 +18,12 @@ const BaseMistralAIV1CompletionsSchema = z.object({
|
||||
.nullish()
|
||||
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
|
||||
stream: z.boolean().optional().default(false),
|
||||
// Mistral docs say that `stop` can be a string or array but AWS Mistral
|
||||
// blows up if a string is passed. We must convert it to an array.
|
||||
stop: z
|
||||
.union([z.string(), z.array(z.string())])
|
||||
.optional()
|
||||
.default([])
|
||||
.transform((v) => (Array.isArray(v) ? v : [v])),
|
||||
random_seed: z.number().int().min(0).optional(),
|
||||
response_format: z
|
||||
.object({ type: z.enum(["text", "json_object"]) })
|
||||
.optional(),
|
||||
safe_prompt: z.boolean().optional().default(false),
|
||||
random_seed: z.number().int().optional(),
|
||||
});
|
||||
|
||||
export const MistralAIV1ChatCompletionsSchema =
|
||||
BaseMistralAIV1CompletionsSchema.and(
|
||||
z.object({ messages: MistralMessagesSchema })
|
||||
);
|
||||
export const MistralAIV1TextCompletionsSchema =
|
||||
BaseMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));
|
||||
|
||||
/*
|
||||
Slightly more strict version that only allows a subset of the parameters. AWS
|
||||
Mistral helpfully returns no details if unsupported parameters are passed so
|
||||
this list comes from trial and error as of 2024-08-12.
|
||||
*/
|
||||
const BaseAWSMistralAIV1CompletionsSchema =
|
||||
BaseMistralAIV1CompletionsSchema.pick({
|
||||
temperature: true,
|
||||
top_p: true,
|
||||
max_tokens: true,
|
||||
stop: true,
|
||||
random_seed: true,
|
||||
// response_format: true,
|
||||
// safe_prompt: true,
|
||||
}).strip();
|
||||
export const AWSMistralV1ChatCompletionsSchema =
|
||||
BaseAWSMistralAIV1CompletionsSchema.and(
|
||||
z.object({ messages: MistralMessagesSchema })
|
||||
);
|
||||
export const AWSMistralV1TextCompletionsSchema =
|
||||
BaseAWSMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));
|
||||
|
||||
export type MistralAIChatMessage = z.infer<typeof MistralChatMessageSchema>;
|
||||
export type MistralAIChatMessage = z.infer<
|
||||
typeof MistralAIV1ChatCompletionsSchema
|
||||
>["messages"][0];
|
||||
|
||||
export function fixMistralPrompt(
|
||||
messages: MistralAIChatMessage[]
|
||||
@@ -88,11 +31,12 @@ export function fixMistralPrompt(
|
||||
// Mistral uses OpenAI format but has some additional requirements:
|
||||
// - Only one system message per request, and it must be the first message if
|
||||
// present.
|
||||
// - Final message must be a user message, unless it has `prefix: true`.
|
||||
// - Final message must be a user message.
|
||||
// - Cannot have multiple messages from the same role in a row.
|
||||
// While frontends should be able to handle this, we can fix it here in the
|
||||
// meantime.
|
||||
const fixed = messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
|
||||
|
||||
return messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
|
||||
if (acc.length === 0) {
|
||||
acc.push(msg);
|
||||
return acc;
|
||||
@@ -113,54 +57,4 @@ export function fixMistralPrompt(
|
||||
}
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
// If the last message is an assistant message, mark it as a prefix. An
|
||||
// assistant message at the end of the conversation without `prefix: true`
|
||||
// results in an error.
|
||||
if (fixed[fixed.length - 1].role === "assistant") {
|
||||
fixed[fixed.length - 1].prefix = true;
|
||||
}
|
||||
return fixed;
|
||||
}
|
||||
|
||||
let jinjaTemplate: Template;
|
||||
let renderTemplate: (messages: MistralAIChatMessage[]) => string;
|
||||
function renderMistralPrompt(messages: MistralAIChatMessage[]) {
|
||||
if (!jinjaTemplate) {
|
||||
logger.warn("Lazy loading mistral chat template...");
|
||||
const { chatTemplate, bosToken, eosToken } =
|
||||
require("./templates/mistral-template").MISTRAL_TEMPLATE;
|
||||
jinjaTemplate = new Template(chatTemplate);
|
||||
renderTemplate = (messages) =>
|
||||
jinjaTemplate.render({
|
||||
messages,
|
||||
bos_token: bosToken,
|
||||
eos_token: eosToken,
|
||||
});
|
||||
}
|
||||
|
||||
return renderTemplate(messages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to convert a Mistral chat completions request to a text completions,
|
||||
* using the official prompt template published by Mistral.
|
||||
*/
|
||||
export const transformMistralChatToText: APIFormatTransformer<
|
||||
typeof MistralAIV1TextCompletionsSchema
|
||||
> = async (req) => {
|
||||
const { body } = req;
|
||||
const result = MistralAIV1ChatCompletionsSchema.safeParse(body);
|
||||
if (!result.success) {
|
||||
req.log.warn(
|
||||
{ issues: result.error.issues, body },
|
||||
"Invalid Mistral chat completions request"
|
||||
);
|
||||
throw result.error;
|
||||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const prompt = renderMistralPrompt(messages);
|
||||
|
||||
return { ...rest, prompt, messages: undefined };
|
||||
};
|
||||
|
||||
@@ -52,7 +52,7 @@ export const OpenAIV1ChatCompletionSchema = z
|
||||
.number()
|
||||
.int()
|
||||
.nullish()
|
||||
.default(Math.min(OPENAI_OUTPUT_MAX, 16384))
|
||||
.default(Math.min(OPENAI_OUTPUT_MAX, 4096))
|
||||
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
|
||||
frequency_penalty: z.number().optional().default(0),
|
||||
presence_penalty: z.number().optional().default(0),
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
export const MISTRAL_TEMPLATE = {
|
||||
bosToken: "<s>",
|
||||
eosToken: "</s>",
|
||||
chatTemplate: `"{%- if messages[0]["role"] == "system" %}
|
||||
{%- set system_message = messages[0]["content"] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
|
||||
|
||||
{%- for message in loop_messages %}
|
||||
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if message["role"] == "user" %}
|
||||
{%- if loop.last and system_message is defined %}
|
||||
{{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}
|
||||
{%- else %}
|
||||
{{- "[INST] " + message["content"] + "[/INST]" }}
|
||||
{%- endif %}
|
||||
{%- elif message["role"] == "assistant" %}
|
||||
{%- if loop.last and message.prefix is defined and message.prefix %}
|
||||
{{- " " + message["content"] }}
|
||||
{%- else %}
|
||||
{{- " " + message["content"] + eos_token}}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
|
||||
{%- endif %}
|
||||
{%- endfor %}`,
|
||||
};
|
||||
@@ -1,18 +0,0 @@
|
||||
/** Module for generating and verifying HMAC signatures. */
|
||||
|
||||
import crypto from "crypto";
|
||||
import { SECRET_SIGNING_KEY } from "../config";
|
||||
|
||||
/**
|
||||
* Generates a HMAC signature for the given message. Optionally salts the
|
||||
* key with a provided string.
|
||||
*/
|
||||
export function signMessage(msg: any, salt: string = ""): string {
|
||||
const hmac = crypto.createHmac("sha256", SECRET_SIGNING_KEY + salt);
|
||||
if (typeof msg === "object") {
|
||||
hmac.update(JSON.stringify(msg));
|
||||
} else {
|
||||
hmac.update(msg);
|
||||
}
|
||||
return hmac.digest("hex");
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
import { doubleCsrf } from "csrf-csrf";
|
||||
import express from "express";
|
||||
import { config, SECRET_SIGNING_KEY } from "../config";
|
||||
import { config, COOKIE_SECRET } from "../config";
|
||||
|
||||
const { generateToken, doubleCsrfProtection } = doubleCsrf({
|
||||
getSecret: () => SECRET_SIGNING_KEY,
|
||||
getSecret: () => COOKIE_SECRET,
|
||||
cookieName: "csrf",
|
||||
cookieOptions: {
|
||||
sameSite: "strict",
|
||||
|
||||
@@ -7,9 +7,8 @@ import * as userStore from "./users/user-store";
|
||||
export const injectLocals: RequestHandler = (req, res, next) => {
|
||||
// config-related locals
|
||||
const quota = config.tokenQuota;
|
||||
const sumOfQuotas = Object.values(quota).reduce((a, b) => a + b, 0);
|
||||
|
||||
res.locals.quotasEnabled = sumOfQuotas > 0;
|
||||
res.locals.quotasEnabled =
|
||||
quota.turbo > 0 || quota.gpt4 > 0 || quota.claude > 0;
|
||||
res.locals.quota = quota;
|
||||
res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
|
||||
res.locals.persistenceEnabled = config.gatekeeperStore !== "memory";
|
||||
|
||||
@@ -122,7 +122,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
||||
{ key: key.hash, error: error.message },
|
||||
"Network error while checking key; trying this key again in a minute."
|
||||
);
|
||||
const oneMinute = 60 * 1000;
|
||||
const oneMinute = 10 * 1000;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import crypto from "crypto";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
|
||||
@@ -23,6 +23,10 @@ type AnthropicKeyUsage = {
|
||||
export interface AnthropicKey extends Key, AnthropicKeyUsage {
|
||||
readonly service: "anthropic";
|
||||
readonly modelFamilies: AnthropicModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
/**
|
||||
* Whether this key requires a special preamble. For unclear reasons, some
|
||||
* Anthropic keys will throw an error if the prompt does not begin with a
|
||||
@@ -213,7 +217,22 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
|
||||
key[`${getClaudeModelFamily(model)}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
public getLockoutPeriod() {
|
||||
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return the time until the first key is
|
||||
// ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
|
||||
@@ -1,31 +1,13 @@
|
||||
import { Sha256 } from "@aws-crypto/sha256-js";
|
||||
import { SignatureV4 } from "@smithy/signature-v4";
|
||||
import { HttpRequest } from "@smithy/protocol-http";
|
||||
import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios";
|
||||
import axios, { AxiosError, AxiosRequestConfig, AxiosHeaders } from "axios";
|
||||
import { URL } from "url";
|
||||
import { config } from "../../../config";
|
||||
import { getAwsBedrockModelFamily } from "../../models";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
|
||||
import { AwsBedrockModelFamily } from "../../models";
|
||||
import { config } from "../../../config";
|
||||
|
||||
type ParentModelId = string;
|
||||
type AliasModelId = string;
|
||||
type ModuleAliasTuple = [ParentModelId, ...AliasModelId[]];
|
||||
|
||||
const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [
|
||||
["anthropic.claude-v2", "anthropic.claude-v2:1"],
|
||||
["anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||
["anthropic.claude-3-haiku-20240307-v1:0"],
|
||||
["anthropic.claude-3-opus-20240229-v1:0"],
|
||||
["anthropic.claude-3-5-sonnet-20240620-v1:0"],
|
||||
["mistral.mistral-7b-instruct-v0:2"],
|
||||
["mistral.mixtral-8x7b-instruct-v0:1"],
|
||||
["mistral.mistral-large-2402-v1:0"],
|
||||
["mistral.mistral-large-2407-v1:0"],
|
||||
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
|
||||
];
|
||||
|
||||
const KEY_CHECK_BATCH_SIZE = 2; // AWS checker needs to do lots of concurrent requests so should lower the batch size
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||
const AMZ_HOST =
|
||||
@@ -33,8 +15,6 @@ const AMZ_HOST =
|
||||
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
|
||||
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
|
||||
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
|
||||
const GET_LIST_INFERENCE_PROFILES_URL = (region: string) =>
|
||||
`https://bedrock.${region}.amazonaws.com/inference-profiles?maxResults=1000`;
|
||||
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
|
||||
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
|
||||
const TEST_MESSAGES = [
|
||||
@@ -44,22 +24,6 @@ const TEST_MESSAGES = [
|
||||
|
||||
type AwsError = { error: {} };
|
||||
|
||||
type GetInferenceProfilesResponse = {
|
||||
inferenceProfileSummaries: {
|
||||
inferenceProfileId: string;
|
||||
inferenceProfileName: string;
|
||||
inferenceProfileArn: string;
|
||||
description?: string;
|
||||
createdAt?: string;
|
||||
updatedAt?: string;
|
||||
status: "ACTIVE" | unknown;
|
||||
type: "SYSTEM_DEFINED" | unknown;
|
||||
models: {
|
||||
modelArn?: string;
|
||||
}[];
|
||||
}[];
|
||||
};
|
||||
|
||||
type GetLoggingConfigResponse = {
|
||||
loggingConfig: null | {
|
||||
cloudWatchConfig: null | unknown;
|
||||
@@ -78,67 +42,54 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
|
||||
service: "aws",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
keyCheckBatchSize: KEY_CHECK_BATCH_SIZE,
|
||||
updateKey,
|
||||
});
|
||||
}
|
||||
|
||||
protected async testKeyOrFail(key: AwsBedrockKey) {
|
||||
// Only check models on startup. For now all models must be available to
|
||||
// the proxy because we don't route requests to different keys.
|
||||
let checks: Promise<boolean>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
checks = [
|
||||
this.invokeModel("anthropic.claude-v2", key),
|
||||
this.invokeModel("anthropic.claude-3-sonnet-20240229-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-haiku-20240307-v1:0", key),
|
||||
this.invokeModel("anthropic.claude-3-opus-20240229-v1:0", key),
|
||||
];
|
||||
}
|
||||
checks.unshift(this.checkLoggingConfiguration(key));
|
||||
|
||||
const [_logging, claudeV2, sonnet, haiku, opus] = await Promise.all(checks);
|
||||
|
||||
if (isInitialCheck) {
|
||||
try {
|
||||
await this.checkInferenceProfiles(key);
|
||||
} catch (e) {
|
||||
const asError = e as AxiosError<AwsError>;
|
||||
const data = asError.response?.data;
|
||||
const families: AwsBedrockModelFamily[] = [];
|
||||
if (claudeV2 || sonnet || haiku) families.push("aws-claude");
|
||||
if (opus) families.push("aws-claude-opus");
|
||||
|
||||
if (families.length === 0) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, error: e.message, data },
|
||||
"Cannot list inference profiles.\n\
|
||||
Principal may be missing `AmazonBedrockFullAccess`, or has no policy allowing action `bedrock:ListInferenceProfiles` against resource `arn:aws:bedrock:*:*:inference-profile/*`.\n\
|
||||
Requests will be made without inference profiles using on-demand quotas, which may be subject to more restrictive rate limits.\n\
|
||||
See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-prereq.html."
|
||||
{ key: key.hash },
|
||||
"Key does not have access to any models; disabling."
|
||||
);
|
||||
return this.updateKey(key.hash, { isDisabled: true });
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, {
|
||||
sonnetEnabled: sonnet,
|
||||
haikuEnabled: haiku,
|
||||
modelFamilies: families,
|
||||
});
|
||||
}
|
||||
|
||||
// Perform checks for all parent model IDs
|
||||
const results = await Promise.all(
|
||||
KNOWN_MODEL_IDS.filter(([model]) =>
|
||||
// Skip checks for models that are disabled anyway
|
||||
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
|
||||
).map(async ([model, ...aliases]) => ({
|
||||
models: [model, ...aliases],
|
||||
success: await this.invokeModel(model, key),
|
||||
}))
|
||||
);
|
||||
|
||||
// Filter out models that are disabled
|
||||
const modelIds = results
|
||||
.filter(({ success }) => success)
|
||||
.flatMap(({ models }) => models);
|
||||
|
||||
if (modelIds.length === 0) {
|
||||
this.log.warn(
|
||||
{ key: key.hash },
|
||||
"Key does not have access to any models; disabling."
|
||||
);
|
||||
return this.updateKey(key.hash, { isDisabled: true });
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, {
|
||||
modelIds,
|
||||
modelFamilies: Array.from(
|
||||
new Set(modelIds.map(getAwsBedrockModelFamily))
|
||||
),
|
||||
});
|
||||
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
logged: key.awsLoggingStatus,
|
||||
sonnet,
|
||||
haiku,
|
||||
families: key.modelFamilies,
|
||||
models: key.modelIds,
|
||||
logged: key.awsLoggingStatus,
|
||||
},
|
||||
"Checked key."
|
||||
);
|
||||
@@ -209,52 +160,7 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
|
||||
* key has access to the model, false if it does not. Throws an error if the
|
||||
* key is disabled.
|
||||
*/
|
||||
private async invokeModel(
|
||||
model: string,
|
||||
key: AwsBedrockKey
|
||||
): Promise<boolean> {
|
||||
if (model.includes("claude")) {
|
||||
// If inference profiles are available, try testing model with them.
|
||||
// If they are not available or the invocation fails with the inference
|
||||
// profile, fall back to regular model ID.
|
||||
const { region } = AwsKeyChecker.getCredentialsFromKey(key);
|
||||
const continent = region.split("-")[0];
|
||||
const profile = key.inferenceProfileIds.find(
|
||||
(id) => `${continent}.${model}` === id
|
||||
);
|
||||
|
||||
if (profile) {
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, profile },
|
||||
"Testing model via inference profile."
|
||||
);
|
||||
let result: boolean;
|
||||
try {
|
||||
result = await this.testClaudeModel(key, profile);
|
||||
} catch (e) {
|
||||
this.log.error(
|
||||
{ key: key.hash, model, profile, error: e.message },
|
||||
"Error testing model with inference profile; trying model ID directly."
|
||||
);
|
||||
result = false;
|
||||
}
|
||||
|
||||
// If the profile worked, we'll return success. Caller will add the
|
||||
// model (not the profile) to the list of enabled models, but the
|
||||
// profile will be used when the key is used for inference.
|
||||
if (result) return true;
|
||||
}
|
||||
return this.testClaudeModel(key, model);
|
||||
} else if (model.includes("mistral")) {
|
||||
return this.testMistralModel(key, model);
|
||||
}
|
||||
throw new Error("AwsKeyChecker#invokeModel: no implementation for model");
|
||||
}
|
||||
|
||||
private async testClaudeModel(
|
||||
key: AwsBedrockKey,
|
||||
model: string
|
||||
): Promise<boolean> {
|
||||
private async invokeModel(model: string, key: AwsBedrockKey) {
|
||||
const creds = AwsKeyChecker.getCredentialsFromKey(key);
|
||||
// This is not a valid invocation payload, but a 400 response indicates that
|
||||
// the principal at least has permission to invoke the model.
|
||||
@@ -269,7 +175,7 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
|
||||
method: "POST",
|
||||
url: POST_INVOKE_MODEL_URL(creds.region, model),
|
||||
data: payload,
|
||||
validateStatus: (status) => [400, 403, 404].includes(status),
|
||||
validateStatus: (status) => status === 400 || status === 403,
|
||||
};
|
||||
config.headers = new AxiosHeaders({
|
||||
"content-type": "application/json",
|
||||
@@ -281,26 +187,11 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
|
||||
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
|
||||
const errorMessage = data?.message;
|
||||
|
||||
// This message indicates the key is valid but this particular model is not
|
||||
// accessible. Other 403s may indicate the key is not usable.
|
||||
// We only allow one type of 403 error, and we only allow it for one model.
|
||||
if (
|
||||
status === 403 &&
|
||||
errorMessage?.match(/access to the model with the specified model ID/)
|
||||
) {
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, errorType, data, status, headers },
|
||||
"Model is not available (principal does not have access)."
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
// ResourceNotFound typically indicates that the tested model cannot be used
|
||||
// on the configured region for this set of credentials.
|
||||
if (status === 404) {
|
||||
this.log.debug(
|
||||
{ region: creds.region, model, key: key.hash },
|
||||
"Model is not available (not supported in this AWS region)."
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -309,91 +200,23 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
|
||||
const correctErrorType = errorType === "ValidationException";
|
||||
const correctErrorMessage = errorMessage?.match(/max_tokens/);
|
||||
if (!correctErrorType || !correctErrorMessage) {
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, errorType, data, status },
|
||||
"Model is not available (request rejected)."
|
||||
);
|
||||
return false;
|
||||
// throw new AxiosError(
|
||||
// `Unexpected error when invoking model ${model}: ${errorMessage}`,
|
||||
// "AWS_ERROR",
|
||||
// response.config,
|
||||
// response.request,
|
||||
// response
|
||||
// );
|
||||
}
|
||||
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, errorType, data, status },
|
||||
"Model is available."
|
||||
"AWS InvokeModel test successful."
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
private async testMistralModel(
|
||||
key: AwsBedrockKey,
|
||||
model: string
|
||||
): Promise<boolean> {
|
||||
const creds = AwsKeyChecker.getCredentialsFromKey(key);
|
||||
|
||||
const payload = {
|
||||
max_tokens: -1,
|
||||
prompt: "<s>[INST] What is your favourite condiment? [/INST]</s>",
|
||||
};
|
||||
const config: AxiosRequestConfig = {
|
||||
method: "POST",
|
||||
url: POST_INVOKE_MODEL_URL(creds.region, model),
|
||||
data: payload,
|
||||
validateStatus: (status) => [400, 403, 404].includes(status),
|
||||
headers: {
|
||||
"content-type": "application/json",
|
||||
accept: "*/*",
|
||||
},
|
||||
};
|
||||
await AwsKeyChecker.signRequestForAws(config, key);
|
||||
const response = await axios.request(config);
|
||||
const { data, status, headers } = response;
|
||||
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
|
||||
const errorMessage = data?.message;
|
||||
|
||||
if (status === 403 || status === 404) {
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, errorType, data, status },
|
||||
"Model is not available (no access or unsupported region)."
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
const isBadRequest = status === 400;
|
||||
const isValidationError = errorMessage?.match(/validation error/i);
|
||||
if (isBadRequest && !isValidationError) {
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, errorType, data, status, headers },
|
||||
"Model is not available (request rejected)."
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
this.log.debug(
|
||||
{ key: key.hash, model, errorType, data, status },
|
||||
"Model is available."
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
private async checkInferenceProfiles(key: AwsBedrockKey) {
|
||||
const creds = AwsKeyChecker.getCredentialsFromKey(key);
|
||||
const req: AxiosRequestConfig = {
|
||||
method: "GET",
|
||||
url: GET_LIST_INFERENCE_PROFILES_URL(creds.region),
|
||||
headers: { accept: "application/json" },
|
||||
};
|
||||
await AwsKeyChecker.signRequestForAws(req, key);
|
||||
const { data } = await axios.request<GetInferenceProfilesResponse>(req);
|
||||
const { inferenceProfileSummaries } = data;
|
||||
const profileIds = inferenceProfileSummaries.map(
|
||||
(p) => p.inferenceProfileId
|
||||
);
|
||||
this.log.debug(
|
||||
{ key: key.hash, profileIds, region: creds.region },
|
||||
"Inference profiles found."
|
||||
);
|
||||
this.updateKey(key.hash, { inferenceProfileIds: profileIds });
|
||||
}
|
||||
|
||||
private async checkLoggingConfiguration(key: AwsBedrockKey) {
|
||||
if (config.allowAwsLogging) {
|
||||
// Don't check logging status if we're allowing it to reduce API calls.
|
||||
@@ -462,8 +285,7 @@ See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-
|
||||
method,
|
||||
protocol: "https:",
|
||||
hostname: url.hostname,
|
||||
path: url.pathname,
|
||||
query: Object.fromEntries(url.searchParams),
|
||||
path: url.pathname + url.search,
|
||||
headers: { Host: url.hostname, ...plainHeaders },
|
||||
});
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { PaymentRequiredError } from "../../errors";
|
||||
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { prioritizeKeys } from "../prioritize-keys";
|
||||
import { AwsKeyChecker } from "./checker";
|
||||
import { PaymentRequiredError } from "../../errors";
|
||||
|
||||
type AwsBedrockKeyUsage = {
|
||||
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
|
||||
@@ -14,6 +13,10 @@ type AwsBedrockKeyUsage = {
|
||||
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
||||
readonly service: "aws";
|
||||
readonly modelFamilies: AwsBedrockModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
/**
|
||||
* The confirmed logging status of this key. This is "unknown" until we
|
||||
* receive a response from the AWS API. Keys which are logged, or not
|
||||
@@ -21,8 +24,8 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
|
||||
* set.
|
||||
*/
|
||||
awsLoggingStatus: "unknown" | "disabled" | "enabled";
|
||||
modelIds: string[];
|
||||
inferenceProfileIds: string[];
|
||||
sonnetEnabled: boolean;
|
||||
haikuEnabled: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -72,14 +75,10 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
|
||||
inferenceProfileIds: [],
|
||||
sonnetEnabled: true,
|
||||
haikuEnabled: false,
|
||||
["aws-claudeTokens"]: 0,
|
||||
["aws-claude-opusTokens"]: 0,
|
||||
["aws-mistral-tinyTokens"]: 0,
|
||||
["aws-mistral-smallTokens"]: 0,
|
||||
["aws-mistral-mediumTokens"]: 0,
|
||||
["aws-mistral-largeTokens"]: 0,
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
@@ -98,61 +97,51 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
}
|
||||
|
||||
public get(model: string) {
|
||||
let neededVariantId = model;
|
||||
// This function accepts both Anthropic/Mistral IDs and AWS IDs.
|
||||
// Generally all AWS model IDs are supersets of the original vendor IDs.
|
||||
// Claude 2 is the only model that breaks this convention; Anthropic calls
|
||||
// it claude-2 but AWS calls it claude-v2.
|
||||
if (model.includes("claude-2")) neededVariantId = "claude-v2";
|
||||
const neededFamily = getAwsBedrockModelFamily(model);
|
||||
|
||||
const availableKeys = this.keys.filter((k) => {
|
||||
// Select keys which
|
||||
const isNotLogged = k.awsLoggingStatus !== "enabled";
|
||||
const neededFamily = getAwsBedrockModelFamily(model);
|
||||
const needsSonnet =
|
||||
model.includes("sonnet") && neededFamily === "aws-claude";
|
||||
const needsHaiku =
|
||||
model.includes("haiku") && neededFamily === "aws-claude";
|
||||
return (
|
||||
// are enabled
|
||||
!k.isDisabled &&
|
||||
// are not logged, unless policy allows it
|
||||
(config.allowAwsLogging || k.awsLoggingStatus !== "enabled") &&
|
||||
// have access to the model family we need
|
||||
k.modelFamilies.includes(neededFamily) &&
|
||||
// have access to the specific variant we need
|
||||
k.modelIds.some((m) => m.includes(neededVariantId))
|
||||
(isNotLogged || config.allowAwsLogging) &&
|
||||
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under aws-claude, while opus is not
|
||||
(k.haikuEnabled || !needsHaiku) &&
|
||||
k.modelFamilies.includes(neededFamily)
|
||||
);
|
||||
});
|
||||
|
||||
this.log.debug(
|
||||
{
|
||||
requestedModel: model,
|
||||
selectedVariant: neededVariantId,
|
||||
selectedFamily: neededFamily,
|
||||
totalKeys: this.keys.length,
|
||||
availableKeys: availableKeys.length,
|
||||
},
|
||||
"Selecting AWS key"
|
||||
);
|
||||
|
||||
if (availableKeys.length === 0) {
|
||||
throw new PaymentRequiredError(
|
||||
`No AWS Bedrock keys available for model ${model}`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Comparator for prioritizing keys on inference profile compatibility.
|
||||
* Requests made via inference profiles have higher rate limits so we want
|
||||
* to use keys with compatible inference profiles first.
|
||||
*/
|
||||
const hasInferenceProfile = (
|
||||
a: AwsBedrockKey,
|
||||
b: AwsBedrockKey
|
||||
) => {
|
||||
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
|
||||
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
|
||||
return aMatch - bMatch;
|
||||
};
|
||||
// (largely copied from the OpenAI provider, without trial key support)
|
||||
// Select a key, from highest priority to lowest priority:
|
||||
// 1. Keys which are not rate limited
|
||||
// a. If all keys were rate limited recently, select the least-recently
|
||||
// rate limited key.
|
||||
// 3. Keys which have not been used in the longest time
|
||||
|
||||
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
|
||||
selectedKey.lastUsed = Date.now();
|
||||
const now = Date.now();
|
||||
|
||||
const keysByPriority = availableKeys.sort((a, b) => {
|
||||
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
|
||||
if (aRateLimited && !bRateLimited) return 1;
|
||||
if (!aRateLimited && bRateLimited) return -1;
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
return { ...selectedKey };
|
||||
}
|
||||
@@ -180,7 +169,22 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
|
||||
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
public getLockoutPeriod() {
|
||||
// TODO: same exact behavior for three providers, should be refactored
|
||||
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return time until the first key is ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { PaymentRequiredError } from "../../errors";
|
||||
import {
|
||||
AzureOpenAIModelFamily,
|
||||
getAzureOpenAIModelFamily,
|
||||
} from "../../models";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { prioritizeKeys } from "../prioritize-keys";
|
||||
import { logger } from "../../../logger";
|
||||
import type { AzureOpenAIModelFamily } from "../../models";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
import { AzureOpenAIKeyChecker } from "./checker";
|
||||
|
||||
type AzureOpenAIKeyUsage = {
|
||||
@@ -17,6 +14,10 @@ type AzureOpenAIKeyUsage = {
|
||||
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
|
||||
readonly service: "azure";
|
||||
readonly modelFamilies: AzureOpenAIModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
contentFiltering: boolean;
|
||||
}
|
||||
|
||||
@@ -104,8 +105,30 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
|
||||
);
|
||||
}
|
||||
|
||||
const selectedKey = prioritizeKeys(availableKeys)[0];
|
||||
selectedKey.lastUsed = Date.now();
|
||||
// (largely copied from the OpenAI provider, without trial key support)
|
||||
// Select a key, from highest priority to lowest priority:
|
||||
// 1. Keys which are not rate limited
|
||||
// a. If all keys were rate limited recently, select the least-recently
|
||||
// rate limited key.
|
||||
// 3. Keys which have not been used in the longest time
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
const keysByPriority = availableKeys.sort((a, b) => {
|
||||
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
|
||||
if (aRateLimited && !bRateLimited) return 1;
|
||||
if (!aRateLimited && bRateLimited) return -1;
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
return { ...selectedKey };
|
||||
}
|
||||
@@ -133,7 +156,26 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
|
||||
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
// TODO: all of this shit is duplicate code
|
||||
|
||||
public getLockoutPeriod(family: AzureOpenAIModelFamily) {
|
||||
const activeKeys = this.keys.filter(
|
||||
(key) => !key.isDisabled && key.modelFamilies.includes(family)
|
||||
);
|
||||
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return time until the first key is ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
|
||||
@@ -1,294 +0,0 @@
|
||||
import axios, { AxiosError } from "axios";
|
||||
import crypto from "crypto";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { GcpKey, GcpKeyProvider } from "./provider";
|
||||
import { GcpModelFamily } from "../../models";
|
||||
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
|
||||
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
|
||||
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
|
||||
`https://${GCP_HOST.replace(
|
||||
"%REGION%",
|
||||
region
|
||||
)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
|
||||
const TEST_MESSAGES = [
|
||||
{ role: "user", content: "Hi!" },
|
||||
{ role: "assistant", content: "Hello!" },
|
||||
];
|
||||
|
||||
type UpdateFn = typeof GcpKeyProvider.prototype.update;
|
||||
|
||||
export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
|
||||
constructor(keys: GcpKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "gcp",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
recurringChecksEnabled: false,
|
||||
updateKey,
|
||||
});
|
||||
}
|
||||
|
||||
protected async testKeyOrFail(key: GcpKey) {
|
||||
let checks: Promise<boolean>[] = [];
|
||||
const isInitialCheck = !key.lastChecked;
|
||||
if (isInitialCheck) {
|
||||
checks = [
|
||||
this.invokeModel("claude-3-haiku@20240307", key, true),
|
||||
this.invokeModel("claude-3-sonnet@20240229", key, true),
|
||||
this.invokeModel("claude-3-opus@20240229", key, true),
|
||||
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
|
||||
];
|
||||
|
||||
const [sonnet, haiku, opus, sonnet35] = await Promise.all(checks);
|
||||
|
||||
this.log.debug(
|
||||
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
|
||||
"GCP model initial tests complete."
|
||||
);
|
||||
|
||||
const families: GcpModelFamily[] = [];
|
||||
if (sonnet || sonnet35 || haiku) families.push("gcp-claude");
|
||||
if (opus) families.push("gcp-claude-opus");
|
||||
|
||||
if (families.length === 0) {
|
||||
this.log.warn(
|
||||
{ key: key.hash },
|
||||
"Key does not have access to any models; disabling."
|
||||
);
|
||||
return this.updateKey(key.hash, { isDisabled: true });
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, {
|
||||
sonnetEnabled: sonnet,
|
||||
haikuEnabled: haiku,
|
||||
sonnet35Enabled: sonnet35,
|
||||
modelFamilies: families,
|
||||
});
|
||||
} else {
|
||||
if (key.haikuEnabled) {
|
||||
await this.invokeModel("claude-3-haiku@20240307", key, false);
|
||||
} else if (key.sonnetEnabled) {
|
||||
await this.invokeModel("claude-3-sonnet@20240229", key, false);
|
||||
} else if (key.sonnet35Enabled) {
|
||||
await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
|
||||
} else {
|
||||
await this.invokeModel("claude-3-opus@20240229", key, false);
|
||||
}
|
||||
|
||||
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||
this.log.debug({ key: key.hash }, "GCP key check complete.");
|
||||
}
|
||||
|
||||
this.log.info(
|
||||
{
|
||||
key: key.hash,
|
||||
families: key.modelFamilies,
|
||||
},
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: GcpKey, error: AxiosError) {
|
||||
if (error.response && GcpKeyChecker.errorIsGcpError(error)) {
|
||||
const { status, data } = error.response;
|
||||
if (status === 400 || status === 401 || status === 403) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, error: data },
|
||||
"Key is invalid or revoked. Disabling key."
|
||||
);
|
||||
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
|
||||
} else if (status === 429) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, error: data },
|
||||
"Key is rate limited. Rechecking in a minute."
|
||||
);
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
} else {
|
||||
this.log.error(
|
||||
{ key: key.hash, status, error: data },
|
||||
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
|
||||
);
|
||||
this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||
}
|
||||
return;
|
||||
}
|
||||
const { response, cause } = error;
|
||||
const { headers, status, data } = response ?? {};
|
||||
this.log.error(
|
||||
{ key: key.hash, status, headers, data, cause, error: error.message },
|
||||
"Network error while checking key; trying this key again in a minute."
|
||||
);
|
||||
const oneMinute = 60 * 1000;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempt to invoke the given model with the given key. Returns true if the
|
||||
* key has access to the model, false if it does not. Throws an error if the
|
||||
* key is disabled.
|
||||
*/
|
||||
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
|
||||
const creds = GcpKeyChecker.getCredentialsFromKey(key);
|
||||
const signedJWT = await GcpKeyChecker.createSignedJWT(
|
||||
creds.clientEmail,
|
||||
creds.privateKey
|
||||
);
|
||||
const [accessToken, jwtError] =
|
||||
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
|
||||
if (accessToken === null) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, jwtError },
|
||||
"Unable to get the access token"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
const payload = {
|
||||
max_tokens: 1,
|
||||
messages: TEST_MESSAGES,
|
||||
anthropic_version: "vertex-2023-10-16",
|
||||
};
|
||||
const { data, status } = await axios.post(
|
||||
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
|
||||
payload,
|
||||
{
|
||||
headers: GcpKeyChecker.getRequestHeaders(accessToken),
|
||||
validateStatus: initial
|
||||
? () => true
|
||||
: (status: number) => status >= 200 && status < 300,
|
||||
}
|
||||
);
|
||||
this.log.debug({ key: key.hash, data }, "Response from GCP");
|
||||
|
||||
if (initial) {
|
||||
return (
|
||||
(status >= 200 && status < 300) || status === 429 || status === 529
|
||||
);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static errorIsGcpError(error: AxiosError): error is AxiosError {
|
||||
const data = error.response?.data as any;
|
||||
if (Array.isArray(data)) {
|
||||
return data.length > 0 && data[0]?.error?.message;
|
||||
} else {
|
||||
return data?.error?.message;
|
||||
}
|
||||
}
|
||||
|
||||
static async createSignedJWT(email: string, pkey: string): Promise<string> {
|
||||
let cryptoKey = await crypto.subtle.importKey(
|
||||
"pkcs8",
|
||||
GcpKeyChecker.str2ab(atob(pkey)),
|
||||
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
|
||||
false,
|
||||
["sign"]
|
||||
);
|
||||
|
||||
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const issued = Math.floor(Date.now() / 1000);
|
||||
const expires = issued + 600;
|
||||
|
||||
const header = { alg: "RS256", typ: "JWT" };
|
||||
|
||||
const payload = {
|
||||
iss: email,
|
||||
aud: authUrl,
|
||||
iat: issued,
|
||||
exp: expires,
|
||||
scope: "https://www.googleapis.com/auth/cloud-platform",
|
||||
};
|
||||
|
||||
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
|
||||
JSON.stringify(header)
|
||||
);
|
||||
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
|
||||
JSON.stringify(payload)
|
||||
);
|
||||
|
||||
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
|
||||
|
||||
const signature = await crypto.subtle.sign(
|
||||
"RSASSA-PKCS1-v1_5",
|
||||
cryptoKey,
|
||||
GcpKeyChecker.str2ab(unsignedToken)
|
||||
);
|
||||
|
||||
const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
|
||||
return `${unsignedToken}.${encodedSignature}`;
|
||||
}
|
||||
|
||||
static async exchangeJwtForAccessToken(
|
||||
signed_jwt: string
|
||||
): Promise<[string | null, string]> {
|
||||
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
|
||||
const params = {
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
assertion: signed_jwt,
|
||||
};
|
||||
|
||||
const r = await fetch(auth_url, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: Object.entries(params)
|
||||
.map(([k, v]) => `${k}=${v}`)
|
||||
.join("&"),
|
||||
}).then((res) => res.json());
|
||||
|
||||
if (r.access_token) {
|
||||
return [r.access_token, ""];
|
||||
}
|
||||
|
||||
return [null, JSON.stringify(r)];
|
||||
}
|
||||
|
||||
static str2ab(str: string): ArrayBuffer {
|
||||
const buffer = new ArrayBuffer(str.length);
|
||||
const bufferView = new Uint8Array(buffer);
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
bufferView[i] = str.charCodeAt(i);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
|
||||
let base64: string;
|
||||
if (typeof data === "string") {
|
||||
base64 = btoa(
|
||||
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
|
||||
String.fromCharCode(parseInt("0x" + p1, 16))
|
||||
)
|
||||
);
|
||||
} else {
|
||||
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
|
||||
}
|
||||
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
|
||||
}
|
||||
|
||||
static getRequestHeaders(accessToken: string) {
|
||||
return {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
}
|
||||
|
||||
static getCredentialsFromKey(key: GcpKey) {
|
||||
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
|
||||
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
|
||||
throw new Error("Invalid GCP key");
|
||||
}
|
||||
const privateKey = rawPrivateKey
|
||||
.replace(
|
||||
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
|
||||
""
|
||||
)
|
||||
.trim();
|
||||
|
||||
return { projectId, clientEmail, region, privateKey };
|
||||
}
|
||||
}
|
||||
@@ -1,202 +0,0 @@
|
||||
import crypto from "crypto";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { PaymentRequiredError } from "../../errors";
|
||||
import { GcpModelFamily, getGcpModelFamily } from "../../models";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { prioritizeKeys } from "../prioritize-keys";
|
||||
import { GcpKeyChecker } from "./checker";
|
||||
|
||||
type GcpKeyUsage = {
|
||||
[K in GcpModelFamily as `${K}Tokens`]: number;
|
||||
};
|
||||
|
||||
export interface GcpKey extends Key, GcpKeyUsage {
|
||||
readonly service: "gcp";
|
||||
readonly modelFamilies: GcpModelFamily[];
|
||||
sonnetEnabled: boolean;
|
||||
haikuEnabled: boolean;
|
||||
sonnet35Enabled: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Upon being rate limited, a key will be locked out for this many milliseconds
|
||||
* while we wait for other concurrent requests to finish.
|
||||
*/
|
||||
const RATE_LIMIT_LOCKOUT = 4000;
|
||||
/**
|
||||
* Upon assigning a key, we will wait this many milliseconds before allowing it
|
||||
* to be used again. This is to prevent the queue from flooding a key with too
|
||||
* many requests while we wait to learn whether previous ones succeeded.
|
||||
*/
|
||||
const KEY_REUSE_DELAY = 500;
|
||||
|
||||
export class GcpKeyProvider implements KeyProvider<GcpKey> {
|
||||
readonly service = "gcp";
|
||||
|
||||
private keys: GcpKey[] = [];
|
||||
private checker?: GcpKeyChecker;
|
||||
private log = logger.child({ module: "key-provider", service: this.service });
|
||||
|
||||
constructor() {
|
||||
const keyConfig = config.gcpCredentials?.trim();
|
||||
if (!keyConfig) {
|
||||
this.log.warn(
|
||||
"GCP_CREDENTIALS is not set. GCP API will not be available."
|
||||
);
|
||||
return;
|
||||
}
|
||||
let bareKeys: string[];
|
||||
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
||||
for (const key of bareKeys) {
|
||||
const newKey: GcpKey = {
|
||||
key,
|
||||
service: this.service,
|
||||
modelFamilies: ["gcp-claude"],
|
||||
isDisabled: false,
|
||||
isRevoked: false,
|
||||
promptCount: 0,
|
||||
lastUsed: 0,
|
||||
rateLimitedAt: 0,
|
||||
rateLimitedUntil: 0,
|
||||
hash: `gcp-${crypto
|
||||
.createHash("sha256")
|
||||
.update(key)
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
sonnetEnabled: true,
|
||||
haikuEnabled: false,
|
||||
sonnet35Enabled: false,
|
||||
["gcp-claudeTokens"]: 0,
|
||||
["gcp-claude-opusTokens"]: 0,
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
this.log.info({ keyCount: this.keys.length }, "Loaded GCP keys.");
|
||||
}
|
||||
|
||||
public init() {
|
||||
if (config.checkKeys) {
|
||||
this.checker = new GcpKeyChecker(this.keys, this.update.bind(this));
|
||||
this.checker.start();
|
||||
}
|
||||
}
|
||||
|
||||
public list() {
|
||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||
}
|
||||
|
||||
public get(model: string) {
|
||||
const neededFamily = getGcpModelFamily(model);
|
||||
|
||||
// this is a horrible mess
|
||||
// each of these should be separate model families, but adding model
|
||||
// families is not low enough friction for the rate at which gcp claude
|
||||
// model variants are added.
|
||||
const needsSonnet35 =
|
||||
model.includes("claude-3-5-sonnet") && neededFamily === "gcp-claude";
|
||||
const needsSonnet =
|
||||
!needsSonnet35 &&
|
||||
model.includes("sonnet") &&
|
||||
neededFamily === "gcp-claude";
|
||||
const needsHaiku = model.includes("haiku") && neededFamily === "gcp-claude";
|
||||
|
||||
const availableKeys = this.keys.filter((k) => {
|
||||
return (
|
||||
!k.isDisabled &&
|
||||
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under gcp-claude, while opus is not
|
||||
(k.haikuEnabled || !needsHaiku) &&
|
||||
(k.sonnet35Enabled || !needsSonnet35) &&
|
||||
k.modelFamilies.includes(neededFamily)
|
||||
);
|
||||
});
|
||||
|
||||
this.log.debug(
|
||||
{
|
||||
model,
|
||||
neededFamily,
|
||||
needsSonnet,
|
||||
needsHaiku,
|
||||
needsSonnet35,
|
||||
availableKeys: availableKeys.length,
|
||||
totalKeys: this.keys.length,
|
||||
},
|
||||
"Selecting GCP key"
|
||||
);
|
||||
|
||||
if (availableKeys.length === 0) {
|
||||
throw new PaymentRequiredError(
|
||||
`No GCP keys available for model ${model}`
|
||||
);
|
||||
}
|
||||
|
||||
const selectedKey = prioritizeKeys(availableKeys)[0];
|
||||
selectedKey.lastUsed = Date.now();
|
||||
this.throttle(selectedKey.hash);
|
||||
return { ...selectedKey };
|
||||
}
|
||||
|
||||
public disable(key: GcpKey) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
||||
keyFromPool.isDisabled = true;
|
||||
this.log.warn({ key: key.hash }, "Key disabled");
|
||||
}
|
||||
|
||||
public update(hash: string, update: Partial<GcpKey>) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
||||
}
|
||||
|
||||
public available() {
|
||||
return this.keys.filter((k) => !k.isDisabled).length;
|
||||
}
|
||||
|
||||
public incrementUsage(hash: string, model: string, tokens: number) {
|
||||
const key = this.keys.find((k) => k.hash === hash);
|
||||
if (!key) return;
|
||||
key.promptCount++;
|
||||
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
* concurrent requests running on this key. We don't have any information on
|
||||
* when these requests will resolve, so all we can do is wait a bit and try
|
||||
* again. We will lock the key for 2 seconds after getting a 429 before
|
||||
* retrying in order to give the other requests a chance to finish.
|
||||
*/
|
||||
public markRateLimited(keyHash: string) {
|
||||
this.log.debug({ key: keyHash }, "Key rate limited");
|
||||
const key = this.keys.find((k) => k.hash === keyHash)!;
|
||||
const now = Date.now();
|
||||
key.rateLimitedAt = now;
|
||||
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
|
||||
}
|
||||
|
||||
public recheck() {
|
||||
this.keys.forEach(({ hash }) =>
|
||||
this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false })
|
||||
);
|
||||
this.checker?.scheduleNextCheck();
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies a short artificial delay to the key upon dequeueing, in order to
|
||||
* prevent it from being immediately assigned to another request before the
|
||||
* current one can be dispatched.
|
||||
**/
|
||||
private throttle(hash: string) {
|
||||
const now = Date.now();
|
||||
const key = this.keys.find((k) => k.hash === hash)!;
|
||||
|
||||
const currentRateLimit = key.rateLimitedUntil;
|
||||
const nextRateLimit = now + KEY_REUSE_DELAY;
|
||||
|
||||
key.rateLimitedAt = now;
|
||||
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
|
||||
}
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
import axios, { AxiosError } from "axios";
|
||||
import type { GoogleAIModelFamily } from "../../models";
|
||||
import { KeyCheckerBase } from "../key-checker-base";
|
||||
import type { GoogleAIKey, GoogleAIKeyProvider } from "./provider";
|
||||
import { getGoogleAIModelFamily } from "../../models";
|
||||
|
||||
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
|
||||
const KEY_CHECK_PERIOD = 3 * 60 * 60 * 1000; // 3 hours
|
||||
const LIST_MODELS_URL =
|
||||
"https://generativelanguage.googleapis.com/v1beta/models";
|
||||
|
||||
type ListModelsResponse = {
|
||||
models: {
|
||||
name: string;
|
||||
baseModelId: string;
|
||||
version: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
inputTokenLimit: number;
|
||||
outputTokenLimit: number;
|
||||
supportedGenerationMethods: string[];
|
||||
temperature: number;
|
||||
maxTemperature: number;
|
||||
topP: number;
|
||||
topK: number;
|
||||
}[];
|
||||
nextPageToken: string;
|
||||
};
|
||||
|
||||
type UpdateFn = typeof GoogleAIKeyProvider.prototype.update;
|
||||
|
||||
export class GoogleAIKeyChecker extends KeyCheckerBase<GoogleAIKey> {
|
||||
constructor(keys: GoogleAIKey[], updateKey: UpdateFn) {
|
||||
super(keys, {
|
||||
service: "google-ai",
|
||||
keyCheckPeriod: KEY_CHECK_PERIOD,
|
||||
minCheckInterval: MIN_CHECK_INTERVAL,
|
||||
recurringChecksEnabled: false,
|
||||
updateKey,
|
||||
});
|
||||
}
|
||||
|
||||
protected async testKeyOrFail(key: GoogleAIKey) {
|
||||
const provisionedModels = await this.getProvisionedModels(key);
|
||||
const updates = {
|
||||
modelFamilies: provisionedModels,
|
||||
};
|
||||
this.updateKey(key.hash, updates);
|
||||
this.log.info(
|
||||
{ key: key.hash, models: key.modelFamilies, ids: key.modelIds.length },
|
||||
"Checked key."
|
||||
);
|
||||
}
|
||||
|
||||
private async getProvisionedModels(
|
||||
key: GoogleAIKey
|
||||
): Promise<GoogleAIModelFamily[]> {
|
||||
const { data } = await axios.get<ListModelsResponse>(
|
||||
`${LIST_MODELS_URL}?pageSize=1000&key=${key.key}`
|
||||
);
|
||||
const models = data.models;
|
||||
|
||||
const ids = new Set<string>();
|
||||
const families = new Set<GoogleAIModelFamily>();
|
||||
models.forEach(({ name }) => {
|
||||
families.add(getGoogleAIModelFamily(name));
|
||||
ids.add(name);
|
||||
});
|
||||
|
||||
const familiesArray = Array.from(families);
|
||||
this.updateKey(key.hash, {
|
||||
modelFamilies: familiesArray,
|
||||
modelIds: Array.from(ids),
|
||||
});
|
||||
|
||||
return familiesArray;
|
||||
}
|
||||
|
||||
protected handleAxiosError(key: GoogleAIKey, error: AxiosError): void {
|
||||
if (error.response && GoogleAIKeyChecker.errorIsGoogleAIError(error)) {
|
||||
const httpStatus = error.response.status;
|
||||
const { code, message, status, details } = error.response.data.error;
|
||||
|
||||
switch (httpStatus) {
|
||||
case 400:
|
||||
const reason = details?.[0]?.reason;
|
||||
if (status === "INVALID_ARGUMENT" && reason === "API_KEY_INVALID") {
|
||||
this.log.warn(
|
||||
{ key: key.hash, reason, details },
|
||||
"Key check returned API_KEY_INVALID error. Disabling key."
|
||||
);
|
||||
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
|
||||
return;
|
||||
} else if (
|
||||
status === "FAILED_PRECONDITION" &&
|
||||
message.match(/please enable billing/i)
|
||||
) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, message, details },
|
||||
"Key check returned billing disabled error. Disabling key."
|
||||
);
|
||||
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
|
||||
return;
|
||||
}
|
||||
break;
|
||||
case 401:
|
||||
case 403:
|
||||
this.log.warn(
|
||||
{ key: key.hash, status, code, message, details },
|
||||
"Key check returned Forbidden/Unauthorized error. Disabling key."
|
||||
);
|
||||
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
|
||||
return;
|
||||
case 429:
|
||||
this.log.warn(
|
||||
{ key: key.hash, status, code, message, details },
|
||||
"Key is rate limited. Rechecking key in 1 minute."
|
||||
);
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
return;
|
||||
}
|
||||
|
||||
this.log.error(
|
||||
{ key: key.hash, status, code, message, details },
|
||||
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
|
||||
);
|
||||
return this.updateKey(key.hash, { lastChecked: Date.now() });
|
||||
}
|
||||
|
||||
this.log.error(
|
||||
{ key: key.hash, error: error.message },
|
||||
"Network error while checking key; trying this key again in a minute."
|
||||
);
|
||||
const oneMinute = 10 * 1000;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
|
||||
return this.updateKey(key.hash, { lastChecked: next });
|
||||
}
|
||||
|
||||
static errorIsGoogleAIError(
|
||||
error: AxiosError
|
||||
): error is AxiosError<GoogleAIError> {
|
||||
const data = error.response?.data as any;
|
||||
return data?.error?.code || data?.error?.status;
|
||||
}
|
||||
}
|
||||
|
||||
type GoogleAIError = {
|
||||
error: {
|
||||
code: string;
|
||||
message: string;
|
||||
status: string;
|
||||
details: any[];
|
||||
};
|
||||
};
|
||||
@@ -1,15 +1,13 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { PaymentRequiredError } from "../../errors";
|
||||
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { prioritizeKeys } from "../prioritize-keys";
|
||||
import { GoogleAIKeyChecker } from "./checker";
|
||||
import type { GoogleAIModelFamily } from "../../models";
|
||||
import { HttpError, PaymentRequiredError } from "../../errors";
|
||||
|
||||
// Note that Google AI is not the same as Vertex AI, both are provided by
|
||||
// Google but Vertex is the GCP product for enterprise, while Google API is a
|
||||
// development/hobbyist product. They use completely different APIs and keys.
|
||||
// Note that Google AI is not the same as Vertex AI, both are provided by Google
|
||||
// but Vertex is the GCP product for enterprise. while Google AI is the
|
||||
// consumer-ish product. The API is different, and keys are not compatible.
|
||||
// https://ai.google.dev/docs/migrate_to_cloud
|
||||
|
||||
export type GoogleAIKeyUpdate = Omit<
|
||||
@@ -29,8 +27,10 @@ type GoogleAIKeyUsage = {
|
||||
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
|
||||
readonly service: "google-ai";
|
||||
readonly modelFamilies: GoogleAIModelFamily[];
|
||||
/** All detected model IDs on this key. */
|
||||
modelIds: string[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -49,7 +49,6 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
|
||||
readonly service = "google-ai";
|
||||
|
||||
private keys: GoogleAIKey[] = [];
|
||||
private checker?: GoogleAIKeyChecker;
|
||||
private log = logger.child({ module: "key-provider", service: this.service });
|
||||
|
||||
constructor() {
|
||||
@@ -79,40 +78,49 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
"gemini-flashTokens": 0,
|
||||
"gemini-proTokens": 0,
|
||||
"gemini-ultraTokens": 0,
|
||||
modelIds: [],
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
|
||||
}
|
||||
|
||||
public init() {
|
||||
if (config.checkKeys) {
|
||||
this.checker = new GoogleAIKeyChecker(this.keys, this.update.bind(this));
|
||||
this.checker.start();
|
||||
}
|
||||
}
|
||||
public init() {}
|
||||
|
||||
public list() {
|
||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||
}
|
||||
|
||||
public get(model: string) {
|
||||
const neededFamily = getGoogleAIModelFamily(model);
|
||||
const availableKeys = this.keys.filter(
|
||||
(k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
|
||||
);
|
||||
public get(_model: string) {
|
||||
const availableKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
if (availableKeys.length === 0) {
|
||||
throw new PaymentRequiredError("No Google AI keys available");
|
||||
}
|
||||
|
||||
const keysByPriority = prioritizeKeys(availableKeys);
|
||||
// (largely copied from the OpenAI provider, without trial key support)
|
||||
// Select a key, from highest priority to lowest priority:
|
||||
// 1. Keys which are not rate limited
|
||||
// a. If all keys were rate limited recently, select the least-recently
|
||||
// rate limited key.
|
||||
// 3. Keys which have not been used in the longest time
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
const keysByPriority = availableKeys.sort((a, b) => {
|
||||
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
|
||||
if (aRateLimited && !bRateLimited) return 1;
|
||||
if (!aRateLimited && bRateLimited) return -1;
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = Date.now();
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
return { ...selectedKey };
|
||||
}
|
||||
@@ -133,14 +141,29 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
|
||||
return this.keys.filter((k) => !k.isDisabled).length;
|
||||
}
|
||||
|
||||
public incrementUsage(hash: string, model: string, tokens: number) {
|
||||
public incrementUsage(hash: string, _model: string, tokens: number) {
|
||||
const key = this.keys.find((k) => k.hash === hash);
|
||||
if (!key) return;
|
||||
key.promptCount++;
|
||||
key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
|
||||
key["gemini-proTokens"] += tokens;
|
||||
}
|
||||
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
public getLockoutPeriod() {
|
||||
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return the time until the first key is
|
||||
// ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
|
||||
@@ -10,7 +10,7 @@ export type APIFormat =
|
||||
| "anthropic-text" // Legacy flat string prompt format
|
||||
| "google-ai"
|
||||
| "mistral-ai"
|
||||
| "mistral-text"
|
||||
| "cohere-chat";
|
||||
|
||||
export interface Key {
|
||||
/** The API key itself. Never log this, use `hash` instead. */
|
||||
@@ -31,10 +31,6 @@ export interface Key {
|
||||
lastChecked: number;
|
||||
/** Hash of the key, for logging and to find the key in the pool. */
|
||||
hash: string;
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -63,32 +59,9 @@ export interface KeyProvider<T extends Key = Key> {
|
||||
recheck(): void;
|
||||
}
|
||||
|
||||
export function createGenericGetLockoutPeriod<T extends Key>(
|
||||
getKeys: () => T[]
|
||||
) {
|
||||
return function (this: unknown, family?: ModelFamily): number {
|
||||
const keys = getKeys();
|
||||
const activeKeys = keys.filter(
|
||||
(k) => !k.isDisabled && (!family || k.modelFamilies.includes(family))
|
||||
);
|
||||
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
};
|
||||
}
|
||||
|
||||
export const keyPool = new KeyPool();
|
||||
export { AnthropicKey } from "./anthropic/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
export { GcpKey } from "./gcp/provider";
|
||||
export { AzureOpenAIKey } from "./azure/provider";
|
||||
export { GoogleAIKey } from "././google-ai/provider";
|
||||
export { MistralAIKey } from "./mistral-ai/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
export { GoogleAIKey } from "././google-ai/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
export { AzureOpenAIKey } from "./azure/provider";
|
||||
|
||||
@@ -7,7 +7,6 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
|
||||
service: string;
|
||||
keyCheckPeriod: number;
|
||||
minCheckInterval: number;
|
||||
keyCheckBatchSize?: number;
|
||||
recurringChecksEnabled?: boolean;
|
||||
updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||
};
|
||||
@@ -23,8 +22,6 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
* than this.
|
||||
*/
|
||||
protected readonly keyCheckPeriod: number;
|
||||
/** Maximum number of keys to check simultaneously. */
|
||||
protected readonly keyCheckBatchSize: number;
|
||||
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
|
||||
protected readonly keys: TKey[] = [];
|
||||
protected log: pino.Logger;
|
||||
@@ -36,7 +33,6 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
this.keyCheckPeriod = opts.keyCheckPeriod;
|
||||
this.minCheckInterval = opts.minCheckInterval;
|
||||
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
|
||||
this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12;
|
||||
this.updateKey = opts.updateKey;
|
||||
this.service = opts.service;
|
||||
this.log = logger.child({ module: "key-checker", service: opts.service });
|
||||
@@ -82,7 +78,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
|
||||
|
||||
if (numUnchecked > 0) {
|
||||
const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize);
|
||||
const keycheckBatch = uncheckedKeys.slice(0, 12);
|
||||
|
||||
this.timeout = setTimeout(async () => {
|
||||
try {
|
||||
@@ -118,8 +114,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
|
||||
);
|
||||
|
||||
// Don't check any individual key too often.
|
||||
// Don't check anything at all more frequently than some minimum interval
|
||||
// even if keys still need to be checked.
|
||||
// Don't check anything at all at a rate faster than once per 3 seconds.
|
||||
const nextCheck = Math.max(
|
||||
oldestKey.lastChecked + this.keyCheckPeriod,
|
||||
this.lastCheck + this.minCheckInterval
|
||||
|
||||
@@ -10,7 +10,6 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||
import { GcpKeyProvider } from "./gcp/provider";
|
||||
import { AzureOpenAIKeyProvider } from "./azure/provider";
|
||||
import { MistralAIKeyProvider } from "./mistral-ai/provider";
|
||||
|
||||
@@ -28,7 +27,6 @@ export class KeyPool {
|
||||
this.keyProviders.push(new GoogleAIKeyProvider());
|
||||
this.keyProviders.push(new MistralAIKeyProvider());
|
||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||
this.keyProviders.push(new GcpKeyProvider());
|
||||
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
||||
}
|
||||
|
||||
@@ -130,11 +128,7 @@ export class KeyPool {
|
||||
return "openai";
|
||||
} else if (model.startsWith("claude-")) {
|
||||
// https://console.anthropic.com/docs/api/reference#parameters
|
||||
if (!model.includes('@')) {
|
||||
return "anthropic";
|
||||
} else {
|
||||
return "gcp";
|
||||
}
|
||||
return "anthropic";
|
||||
} else if (model.includes("gemini")) {
|
||||
// https://developers.generativeai.google.com/models/language
|
||||
return "google-ai";
|
||||
|
||||
@@ -69,9 +69,9 @@ export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
|
||||
protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
|
||||
if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
|
||||
const { status, data } = error.response;
|
||||
if ([401, 403].includes(status)) {
|
||||
if (status === 401) {
|
||||
this.log.warn(
|
||||
{ key: key.hash, error: data, status },
|
||||
{ key: key.hash, error: data },
|
||||
"Key is invalid or revoked. Disabling key."
|
||||
);
|
||||
this.updateKey(key.hash, {
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { HttpError } from "../../errors";
|
||||
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
|
||||
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
|
||||
import { prioritizeKeys } from "../prioritize-keys";
|
||||
import { MistralAIKeyChecker } from "./checker";
|
||||
import { HttpError } from "../../errors";
|
||||
|
||||
type MistralAIKeyUsage = {
|
||||
[K in MistralAIModelFamily as `${K}Tokens`]: number;
|
||||
@@ -14,6 +13,10 @@ type MistralAIKeyUsage = {
|
||||
export interface MistralAIKey extends Key, MistralAIKeyUsage {
|
||||
readonly service: "mistral-ai";
|
||||
readonly modelFamilies: MistralAIModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
rateLimitedUntil: number;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -95,8 +98,30 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
|
||||
throw new HttpError(402, "No Mistral AI keys available");
|
||||
}
|
||||
|
||||
const selectedKey = prioritizeKeys(availableKeys)[0];
|
||||
selectedKey.lastUsed = Date.now();
|
||||
// (largely copied from the OpenAI provider, without trial key support)
|
||||
// Select a key, from highest priority to lowest priority:
|
||||
// 1. Keys which are not rate limited
|
||||
// a. If all keys were rate limited recently, select the least-recently
|
||||
// rate limited key.
|
||||
// 3. Keys which have not been used in the longest time
|
||||
|
||||
const now = Date.now();
|
||||
|
||||
const keysByPriority = availableKeys.sort((a, b) => {
|
||||
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
|
||||
|
||||
if (aRateLimited && !bRateLimited) return 1;
|
||||
if (!aRateLimited && bRateLimited) return -1;
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
|
||||
const selectedKey = keysByPriority[0];
|
||||
selectedKey.lastUsed = now;
|
||||
this.throttle(selectedKey.hash);
|
||||
return { ...selectedKey };
|
||||
}
|
||||
@@ -125,7 +150,22 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
|
||||
key[`${family}Tokens`] += tokens;
|
||||
}
|
||||
|
||||
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
|
||||
public getLockoutPeriod() {
|
||||
const activeKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
// Don't lock out if there are no keys available or the queue will stall.
|
||||
// Just let it through so the add-key middleware can throw an error.
|
||||
if (activeKeys.length === 0) return 0;
|
||||
|
||||
const now = Date.now();
|
||||
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
|
||||
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
|
||||
|
||||
if (anyNotRateLimited) return 0;
|
||||
|
||||
// If all keys are rate-limited, return the time until the first key is
|
||||
// ready.
|
||||
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
|
||||
}
|
||||
|
||||
/**
|
||||
* This is called when we receive a 429, which means there are already five
|
||||
|
||||
@@ -26,6 +26,8 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
|
||||
isTrial: boolean;
|
||||
/** Set when key check returns a non-transient 429. */
|
||||
isOverQuota: boolean;
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/**
|
||||
* Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
|
||||
* number.
|
||||
@@ -109,7 +111,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
rateLimitedAt: 0,
|
||||
rateLimitedUntil: 0,
|
||||
rateLimitRequestsReset: 0,
|
||||
rateLimitTokensReset: 0,
|
||||
turboTokens: 0,
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import { Key } from "./index";
|
||||
|
||||
/**
|
||||
* Given a list of keys, returns a new list of keys sorted from highest to
|
||||
* lowest priority. Keys are prioritized in the following order:
|
||||
*
|
||||
* 1. Keys which are not rate limited
|
||||
* a. If all keys were rate limited recently, select the least-recently
|
||||
* rate limited key.
|
||||
* b. Otherwise, select the first key.
|
||||
* 2. Keys which have not been used in the longest time
|
||||
* 3. Keys according to the custom comparator, if provided
|
||||
* @param keys The list of keys to sort
|
||||
* @param customComparator A custom comparator function to use for sorting
|
||||
*/
|
||||
export function prioritizeKeys<T extends Key>(
|
||||
keys: T[],
|
||||
customComparator?: (a: T, b: T) => number
|
||||
) {
|
||||
const now = Date.now();
|
||||
|
||||
return keys.sort((a, b) => {
|
||||
const aRateLimited = now - a.rateLimitedAt < a.rateLimitedUntil;
|
||||
const bRateLimited = now - b.rateLimitedAt < b.rateLimitedUntil;
|
||||
|
||||
if (aRateLimited && !bRateLimited) return 1;
|
||||
if (!aRateLimited && bRateLimited) return -1;
|
||||
if (aRateLimited && bRateLimited) {
|
||||
return a.rateLimitedAt - b.rateLimitedAt;
|
||||
}
|
||||
|
||||
if (customComparator) {
|
||||
const result = customComparator(a, b);
|
||||
if (result !== 0) return result;
|
||||
}
|
||||
|
||||
return a.lastUsed - b.lastUsed;
|
||||
});
|
||||
}
|
||||
+64
-97
@@ -1,11 +1,12 @@
|
||||
// Don't import any other project files here as this is one of the first modules
|
||||
// loaded and it will cause circular imports.
|
||||
|
||||
import pino from "pino";
|
||||
import type { Request } from "express";
|
||||
|
||||
/**
|
||||
* The service that a model is hosted on. Distinct from `APIFormat` because some
|
||||
* services have interoperable APIs (eg Anthropic/AWS/GCP, OpenAI/Azure).
|
||||
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
|
||||
*/
|
||||
export type LLMService =
|
||||
| "openai"
|
||||
@@ -13,8 +14,8 @@ export type LLMService =
|
||||
| "google-ai"
|
||||
| "mistral-ai"
|
||||
| "aws"
|
||||
| "gcp"
|
||||
| "azure";
|
||||
| "azure"
|
||||
| "cohere";
|
||||
|
||||
export type OpenAIModelFamily =
|
||||
| "turbo"
|
||||
@@ -24,27 +25,23 @@ export type OpenAIModelFamily =
|
||||
| "gpt4o"
|
||||
| "dall-e";
|
||||
export type AnthropicModelFamily = "claude" | "claude-opus";
|
||||
export type GoogleAIModelFamily =
|
||||
| "gemini-flash"
|
||||
| "gemini-pro"
|
||||
| "gemini-ultra";
|
||||
export type GoogleAIModelFamily = "gemini-pro";
|
||||
export type MistralAIModelFamily =
|
||||
// mistral changes their model classes frequently so these no longer
|
||||
// correspond to specific models. consider them rough pricing tiers.
|
||||
"mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
|
||||
export type AwsBedrockModelFamily = `aws-${
|
||||
| AnthropicModelFamily
|
||||
| MistralAIModelFamily}`;
|
||||
export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
|
||||
| "mistral-tiny"
|
||||
| "mistral-small"
|
||||
| "mistral-medium"
|
||||
| "mistral-large";
|
||||
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
|
||||
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
|
||||
export type CohereModelFamily = "command-r" | "command-r-plus";
|
||||
export type ModelFamily =
|
||||
| OpenAIModelFamily
|
||||
| AnthropicModelFamily
|
||||
| GoogleAIModelFamily
|
||||
| MistralAIModelFamily
|
||||
| AwsBedrockModelFamily
|
||||
| GcpModelFamily
|
||||
| AzureOpenAIModelFamily;
|
||||
| AzureOpenAIModelFamily
|
||||
| CohereModelFamily;
|
||||
|
||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
||||
@@ -57,27 +54,21 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
"dall-e",
|
||||
"claude",
|
||||
"claude-opus",
|
||||
"gemini-flash",
|
||||
"gemini-pro",
|
||||
"gemini-ultra",
|
||||
"mistral-tiny",
|
||||
"mistral-small",
|
||||
"mistral-medium",
|
||||
"mistral-large",
|
||||
"aws-claude",
|
||||
"aws-claude-opus",
|
||||
"aws-mistral-tiny",
|
||||
"aws-mistral-small",
|
||||
"aws-mistral-medium",
|
||||
"aws-mistral-large",
|
||||
"gcp-claude",
|
||||
"gcp-claude-opus",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
"azure-gpt4-32k",
|
||||
"azure-gpt4-turbo",
|
||||
"azure-gpt4o",
|
||||
"azure-dall-e",
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
] as const);
|
||||
|
||||
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
||||
@@ -88,50 +79,12 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
||||
"google-ai",
|
||||
"mistral-ai",
|
||||
"aws",
|
||||
"gcp",
|
||||
"azure",
|
||||
"cohere",
|
||||
] as const);
|
||||
|
||||
export const MODEL_FAMILY_SERVICE: {
|
||||
[f in ModelFamily]: LLMService;
|
||||
} = {
|
||||
turbo: "openai",
|
||||
gpt4: "openai",
|
||||
"gpt4-turbo": "openai",
|
||||
"gpt4-32k": "openai",
|
||||
gpt4o: "openai",
|
||||
"dall-e": "openai",
|
||||
claude: "anthropic",
|
||||
"claude-opus": "anthropic",
|
||||
"aws-claude": "aws",
|
||||
"aws-claude-opus": "aws",
|
||||
"aws-mistral-tiny": "aws",
|
||||
"aws-mistral-small": "aws",
|
||||
"aws-mistral-medium": "aws",
|
||||
"aws-mistral-large": "aws",
|
||||
"gcp-claude": "gcp",
|
||||
"gcp-claude-opus": "gcp",
|
||||
"azure-turbo": "azure",
|
||||
"azure-gpt4": "azure",
|
||||
"azure-gpt4-32k": "azure",
|
||||
"azure-gpt4-turbo": "azure",
|
||||
"azure-gpt4o": "azure",
|
||||
"azure-dall-e": "azure",
|
||||
"gemini-flash": "google-ai",
|
||||
"gemini-pro": "google-ai",
|
||||
"gemini-ultra": "google-ai",
|
||||
"mistral-tiny": "mistral-ai",
|
||||
"mistral-small": "mistral-ai",
|
||||
"mistral-medium": "mistral-ai",
|
||||
"mistral-large": "mistral-ai",
|
||||
};
|
||||
|
||||
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
|
||||
|
||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^gpt-4o(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4o",
|
||||
"^chatgpt-4o": "gpt4o",
|
||||
"^gpt-4o-mini(-\\d{4}-\\d{2}-\\d{2})?$": "turbo", // closest match
|
||||
"^gpt-4o": "gpt4o",
|
||||
"^gpt-4-turbo(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4-turbo",
|
||||
"^gpt-4-turbo(-preview)?$": "gpt4-turbo",
|
||||
"^gpt-4-(0125|1106)(-preview)?$": "gpt4-turbo",
|
||||
@@ -145,6 +98,38 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^dall-e-\\d{1}$": "dall-e",
|
||||
};
|
||||
|
||||
export const MODEL_FAMILY_SERVICE: {
|
||||
[f in ModelFamily]: LLMService;
|
||||
} = {
|
||||
turbo: "openai",
|
||||
gpt4: "openai",
|
||||
"gpt4-turbo": "openai",
|
||||
"gpt4-32k": "openai",
|
||||
"gpt4o": "openai",
|
||||
"dall-e": "openai",
|
||||
claude: "anthropic",
|
||||
"claude-opus": "anthropic",
|
||||
"aws-claude": "aws",
|
||||
"aws-claude-opus": "aws",
|
||||
"azure-turbo": "azure",
|
||||
"azure-gpt4": "azure",
|
||||
"azure-gpt4-32k": "azure",
|
||||
"azure-gpt4-turbo": "azure",
|
||||
"azure-gpt4o": "azure",
|
||||
"azure-dall-e": "azure",
|
||||
"gemini-pro": "google-ai",
|
||||
"mistral-tiny": "mistral-ai",
|
||||
"mistral-small": "mistral-ai",
|
||||
"mistral-medium": "mistral-ai",
|
||||
"mistral-large": "mistral-ai",
|
||||
"command-r": "cohere",
|
||||
"command-r-plus": "cohere",
|
||||
};
|
||||
|
||||
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
|
||||
|
||||
pino({ level: "debug" }).child({ module: "startup" });
|
||||
|
||||
export function getOpenAIModelFamily(
|
||||
model: string,
|
||||
defaultFamily: OpenAIModelFamily = "gpt4"
|
||||
@@ -160,12 +145,8 @@ export function getClaudeModelFamily(model: string): AnthropicModelFamily {
|
||||
return "claude";
|
||||
}
|
||||
|
||||
export function getGoogleAIModelFamily(model: string): GoogleAIModelFamily {
|
||||
return model.includes("ultra")
|
||||
? "gemini-ultra"
|
||||
: model.includes("flash")
|
||||
? "gemini-flash"
|
||||
: "gemini-pro";
|
||||
export function getGoogleAIModelFamily(_model: string): ModelFamily {
|
||||
return "gemini-pro";
|
||||
}
|
||||
|
||||
export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
|
||||
@@ -178,34 +159,16 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
|
||||
return prunedModel as MistralAIModelFamily;
|
||||
case "open-mistral-7b":
|
||||
return "mistral-tiny";
|
||||
case "open-mistral-nemo":
|
||||
case "open-mixtral-8x7b":
|
||||
case "codestral":
|
||||
case "open-codestral-mamba":
|
||||
return "mistral-small";
|
||||
case "open-mixtral-8x22b":
|
||||
return "mistral-medium";
|
||||
default:
|
||||
return "mistral-small";
|
||||
return "mistral-tiny";
|
||||
}
|
||||
}
|
||||
|
||||
export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
|
||||
// remove vendor and version from AWS model ids
|
||||
// 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
|
||||
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
|
||||
|
||||
if (["claude", "anthropic"].some((x) => model.includes(x))) {
|
||||
return `aws-${getClaudeModelFamily(deAwsified)}`;
|
||||
} else if (model.includes("tral")) {
|
||||
return `aws-${getMistralAIModelFamily(deAwsified)}`;
|
||||
}
|
||||
return `aws-claude`;
|
||||
}
|
||||
|
||||
export function getGcpModelFamily(model: string): GcpModelFamily {
|
||||
if (model.includes("opus")) return "gcp-claude-opus";
|
||||
return "gcp-claude";
|
||||
if (model.includes("opus")) return "aws-claude-opus";
|
||||
return "aws-claude";
|
||||
}
|
||||
|
||||
export function getAzureOpenAIModelFamily(
|
||||
@@ -226,6 +189,11 @@ export function getAzureOpenAIModelFamily(
|
||||
return defaultFamily;
|
||||
}
|
||||
|
||||
export function getCohereModelFamily(model: string): CohereModelFamily {
|
||||
if (model.includes("plus")) return "command-r-plus";
|
||||
return "command-r";
|
||||
}
|
||||
|
||||
export function assertIsKnownModelFamily(
|
||||
modelFamily: string
|
||||
): asserts modelFamily is ModelFamily {
|
||||
@@ -242,13 +210,10 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
||||
const model = req.body.model ?? "gpt-3.5-turbo";
|
||||
let modelFamily: ModelFamily;
|
||||
|
||||
// Weird special case for AWS/GCP/Azure because they serve models with
|
||||
// different API formats, so the outbound API alone is not sufficient to
|
||||
// determine the partition.
|
||||
// Weird special case for AWS/Azure because they serve multiple models from
|
||||
// different vendors, even if currently only one is supported.
|
||||
if (req.service === "aws") {
|
||||
modelFamily = getAwsBedrockModelFamily(model);
|
||||
} else if (req.service === "gcp") {
|
||||
modelFamily = getGcpModelFamily(model);
|
||||
} else if (req.service === "azure") {
|
||||
modelFamily = getAzureOpenAIModelFamily(model);
|
||||
} else {
|
||||
@@ -266,9 +231,11 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
||||
modelFamily = getGoogleAIModelFamily(model);
|
||||
break;
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
modelFamily = getMistralAIModelFamily(model);
|
||||
break;
|
||||
case "cohere-chat":
|
||||
modelFamily = getCohereModelFamily(model);
|
||||
break;
|
||||
default:
|
||||
assertNever(req.outboundApi);
|
||||
}
|
||||
|
||||
@@ -30,12 +30,10 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
|
||||
cost = 0.00001;
|
||||
break;
|
||||
case "aws-claude":
|
||||
case "gcp-claude":
|
||||
case "claude":
|
||||
cost = 0.000008;
|
||||
break;
|
||||
case "aws-claude-opus":
|
||||
case "gcp-claude-opus":
|
||||
case "claude-opus":
|
||||
cost = 0.000015;
|
||||
break;
|
||||
|
||||
@@ -67,9 +67,6 @@ async function getTokenCountForMessages({
|
||||
case "image":
|
||||
numTokens += await getImageTokenCount(part.source.data);
|
||||
break;
|
||||
case "tool_use":
|
||||
case "tool_result":
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unsupported Anthropic content type.`);
|
||||
}
|
||||
|
||||
@@ -47,9 +47,9 @@ type GoogleAIChatTokenCountRequest = {
|
||||
};
|
||||
|
||||
type MistralAIChatTokenCountRequest = {
|
||||
prompt: string | MistralAIChatMessage[];
|
||||
prompt: MistralAIChatMessage[];
|
||||
completion?: never;
|
||||
service: "mistral-ai" | "mistral-text";
|
||||
service: "mistral-ai";
|
||||
};
|
||||
|
||||
type FlatPromptTokenCountRequest = {
|
||||
@@ -128,7 +128,6 @@ export async function countTokens({
|
||||
tokenization_duration_ms: getElapsedMs(time),
|
||||
};
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
return {
|
||||
...getMistralAITokenCount(prompt ?? completion),
|
||||
tokenization_duration_ms: getElapsedMs(time),
|
||||
|
||||
@@ -37,8 +37,6 @@ export const UserSchema = z
|
||||
tokenCounts: tokenCountsSchema,
|
||||
/** Maximum number of tokens the user can consume, by model family. */
|
||||
tokenLimits: tokenCountsSchema,
|
||||
/** User-specific token refresh amount, by model family. */
|
||||
tokenRefresh: tokenCountsSchema,
|
||||
/** Time at which the user was created. */
|
||||
createdAt: z.number(),
|
||||
/** Time at which the user last connected. */
|
||||
|
||||
@@ -13,7 +13,6 @@ import { v4 as uuid } from "uuid";
|
||||
import { config, getFirebaseApp } from "../../config";
|
||||
import {
|
||||
getAwsBedrockModelFamily,
|
||||
getGcpModelFamily,
|
||||
getAzureOpenAIModelFamily,
|
||||
getClaudeModelFamily,
|
||||
getGoogleAIModelFamily,
|
||||
@@ -71,7 +70,6 @@ export function createUser(createOptions?: {
|
||||
type?: User["type"];
|
||||
expiresAt?: number;
|
||||
tokenLimits?: User["tokenLimits"];
|
||||
tokenRefresh?: User["tokenRefresh"];
|
||||
}) {
|
||||
const token = uuid();
|
||||
const newUser: User = {
|
||||
@@ -81,7 +79,6 @@ export function createUser(createOptions?: {
|
||||
promptCount: 0,
|
||||
tokenCounts: { ...INITIAL_TOKENS },
|
||||
tokenLimits: createOptions?.tokenLimits ?? { ...config.tokenQuota },
|
||||
tokenRefresh: createOptions?.tokenRefresh ?? { ...INITIAL_TOKENS },
|
||||
createdAt: Date.now(),
|
||||
meta: {},
|
||||
};
|
||||
@@ -126,7 +123,6 @@ export function upsertUser(user: UserUpdate) {
|
||||
promptCount: 0,
|
||||
tokenCounts: { ...INITIAL_TOKENS },
|
||||
tokenLimits: { ...config.tokenQuota },
|
||||
tokenRefresh: { ...INITIAL_TOKENS },
|
||||
createdAt: Date.now(),
|
||||
meta: {},
|
||||
};
|
||||
@@ -143,6 +139,7 @@ export function upsertUser(user: UserUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Write firebase migration to backfill new fields
|
||||
if (updates.tokenCounts) {
|
||||
for (const family of MODEL_FAMILIES) {
|
||||
updates.tokenCounts[family] ??= 0;
|
||||
@@ -153,16 +150,6 @@ export function upsertUser(user: UserUpdate) {
|
||||
updates.tokenLimits[family] ??= 0;
|
||||
}
|
||||
}
|
||||
// tokenRefresh is a special case where we want to merge the existing and
|
||||
// updated values for each model family, ignoring falsy values.
|
||||
if (updates.tokenRefresh) {
|
||||
const merged = { ...existing.tokenRefresh };
|
||||
for (const family of MODEL_FAMILIES) {
|
||||
merged[family] =
|
||||
updates.tokenRefresh[family] || existing.tokenRefresh[family];
|
||||
}
|
||||
updates.tokenRefresh = merged;
|
||||
}
|
||||
|
||||
users.set(user.token, Object.assign(existing, updates));
|
||||
usersToFlush.add(user.token);
|
||||
@@ -258,29 +245,19 @@ export function hasAvailableQuota({
|
||||
return tokensConsumed < tokenLimit;
|
||||
}
|
||||
|
||||
/**
|
||||
* For the given user, sets token limits for each model family to the sum of the
|
||||
* current count and the refresh amount, up to the default limit. If a quota is
|
||||
* not specified for a model family, it is not touched.
|
||||
*/
|
||||
export function refreshQuota(token: string) {
|
||||
const user = users.get(token);
|
||||
if (!user) return;
|
||||
const { tokenQuota } = config;
|
||||
const { tokenCounts, tokenLimits, tokenRefresh } = user;
|
||||
|
||||
// Get default quotas for each model family.
|
||||
const defaultQuotas = Object.entries(tokenQuota) as [ModelFamily, number][];
|
||||
// If any user-specific refresh quotas are present, override default quotas.
|
||||
const userQuotas = defaultQuotas.map(
|
||||
([f, q]) => [f, (tokenRefresh[f] ?? 0) || q] as const /* narrow to tuple */
|
||||
);
|
||||
|
||||
userQuotas
|
||||
// Ignore families with no global or user-specific refresh quota.
|
||||
.filter(([, q]) => q > 0)
|
||||
// Increase family token limit by the family's refresh amount.
|
||||
.forEach(([f, q]) => (tokenLimits[f] = (tokenCounts[f] ?? 0) + q));
|
||||
const { tokenCounts, tokenLimits } = user;
|
||||
const quotas = Object.entries(config.tokenQuota) as [ModelFamily, number][];
|
||||
quotas
|
||||
// If a quota is not configured, don't touch any existing limits a user may
|
||||
// already have been assigned manually.
|
||||
.filter(([, quota]) => quota > 0)
|
||||
.forEach(
|
||||
([model, quota]) =>
|
||||
(tokenLimits[model] = (tokenCounts[model] ?? 0) + quota)
|
||||
);
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
|
||||
@@ -330,7 +307,7 @@ function cleanupExpiredTokens() {
|
||||
user.meta.refreshable = config.captchaMode !== "none";
|
||||
disabled++;
|
||||
}
|
||||
const purgeTimeout = config.powTokenPurgeHours * 60 * 60 * 1000;
|
||||
const purgeTimeout = config.powTokenPurgeHours * 60 * 60 * 1000;
|
||||
if (user.disabledAt && user.disabledAt + purgeTimeout < now) {
|
||||
users.delete(user.token);
|
||||
usersToFlush.add(user.token);
|
||||
@@ -418,7 +395,6 @@ function getModelFamilyForQuotaUsage(
|
||||
// differentiate between Azure and OpenAI variants of the same model.
|
||||
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
|
||||
if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model);
|
||||
if (model.startsWith("claude-") && model.includes("@")) return getGcpModelFamily(model);
|
||||
|
||||
switch (api) {
|
||||
case "openai":
|
||||
@@ -431,7 +407,6 @@ function getModelFamilyForQuotaUsage(
|
||||
case "google-ai":
|
||||
return getGoogleAIModelFamily(model);
|
||||
case "mistral-ai":
|
||||
case "mistral-text":
|
||||
return getMistralAIModelFamily(model);
|
||||
default:
|
||||
assertNever(api);
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
.pagination li a {
|
||||
display: block;
|
||||
padding: 0.5em 1em;
|
||||
border-bottom: none;
|
||||
border-bottom: none;
|
||||
text-decoration: none;
|
||||
}
|
||||
.pagination li.active a {
|
||||
@@ -71,24 +71,20 @@
|
||||
td.actions:hover {
|
||||
background-color: #e0e6f6;
|
||||
}
|
||||
tr > td,
|
||||
tr > th {
|
||||
border-right: 1px solid #dedede;
|
||||
}
|
||||
|
||||
@media (max-width: 800px) {
|
||||
body {
|
||||
padding: 0.5em;
|
||||
}
|
||||
|
||||
table.full-width {
|
||||
width: 100%;
|
||||
position: static;
|
||||
left: auto;
|
||||
right: auto;
|
||||
margin-left: 0;
|
||||
margin-right: 0;
|
||||
}
|
||||
table.full-width {
|
||||
width: 100%;
|
||||
position: static;
|
||||
left: auto;
|
||||
right: auto;
|
||||
margin-left: 0;
|
||||
margin-right: 0;
|
||||
}
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
@@ -99,13 +95,6 @@
|
||||
th.active {
|
||||
background-color: #446;
|
||||
}
|
||||
td.actions:hover {
|
||||
background-color: #446;
|
||||
}
|
||||
tr > td,
|
||||
tr > th {
|
||||
border-right: 1px solid #444;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
<p>
|
||||
Next refresh: <time><%- nextQuotaRefresh %></time>
|
||||
</p>
|
||||
<p>Next refresh: <time><%- nextQuotaRefresh %></time></p>
|
||||
<table class="striped">
|
||||
<thead>
|
||||
<tr>
|
||||
@@ -11,7 +9,7 @@
|
||||
<% } %>
|
||||
<th scope="col">Limit</th>
|
||||
<th scope="col">Remaining</th>
|
||||
<th scope="col" colspan="<%= showRefreshEdit ? 2 : 1 %>">Refresh Amount</th>
|
||||
<th scope="col">Refresh Amount</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
@@ -21,7 +19,7 @@
|
||||
<td><%- prettyTokens(user.tokenCounts[key]) %></td>
|
||||
<% if (showTokenCosts) { %>
|
||||
<td>$<%- tokenCost(key, user.tokenCounts[key]).toFixed(2) %></td>
|
||||
<% } %>
|
||||
<% } %>
|
||||
<% if (!user.tokenLimits[key]) { %>
|
||||
<td colspan="2" style="text-align: center">unlimited</td>
|
||||
<% } else { %>
|
||||
@@ -31,20 +29,7 @@
|
||||
<% if (user.type === "temporary") { %>
|
||||
<td>N/A</td>
|
||||
<% } else { %>
|
||||
<td><%- prettyTokens(user.tokenRefresh[key] || quota[key]) %></td>
|
||||
<% } %>
|
||||
<% if (showRefreshEdit) { %>
|
||||
<td class="actions">
|
||||
<a
|
||||
title="Edit"
|
||||
id="edit-refresh"
|
||||
href="#"
|
||||
data-field="tokenRefresh_<%= key %>"
|
||||
data-token="<%= user.token %>"
|
||||
data-modelFamily="<%= key %>"
|
||||
>✏️</a
|
||||
>
|
||||
</td>
|
||||
<td><%- prettyTokens(quota[key]) %></td>
|
||||
<% } %>
|
||||
</tr>
|
||||
<% }) %>
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import cookieParser from "cookie-parser";
|
||||
import expressSession from "express-session";
|
||||
import MemoryStore from "memorystore";
|
||||
import { config, SECRET_SIGNING_KEY } from "../config";
|
||||
import { config, COOKIE_SECRET } from "../config";
|
||||
|
||||
const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
|
||||
|
||||
const cookieParserMiddleware = cookieParser(SECRET_SIGNING_KEY);
|
||||
const cookieParserMiddleware = cookieParser(COOKIE_SECRET);
|
||||
|
||||
const sessionMiddleware = expressSession({
|
||||
secret: SECRET_SIGNING_KEY,
|
||||
secret: COOKIE_SECRET,
|
||||
resave: false,
|
||||
saveUninitialized: false,
|
||||
store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
|
||||
|
||||
@@ -2,7 +2,6 @@ import crypto from "crypto";
|
||||
import express from "express";
|
||||
import argon2 from "@node-rs/argon2";
|
||||
import { z } from "zod";
|
||||
import { signMessage } from "../../shared/hmac-signing";
|
||||
import {
|
||||
authenticate,
|
||||
createUser,
|
||||
@@ -14,13 +13,15 @@ import { config } from "../../config";
|
||||
/** Lockout time after verification in milliseconds */
|
||||
const LOCKOUT_TIME = 1000 * 60; // 60 seconds
|
||||
|
||||
let powKeySalt = crypto.randomBytes(32).toString("hex");
|
||||
/** HMAC key for signing challenges; regenerated on startup */
|
||||
let hmacSecret = crypto.randomBytes(32).toString("hex");
|
||||
|
||||
/**
|
||||
* Invalidates any outstanding unsolved challenges.
|
||||
* Regenerate the HMAC key used for signing challenges. Calling this function
|
||||
* will invalidate all existing challenges.
|
||||
*/
|
||||
export function invalidatePowChallenges() {
|
||||
powKeySalt = crypto.randomBytes(32).toString("hex");
|
||||
export function invalidatePowHmacKey() {
|
||||
hmacSecret = crypto.randomBytes(32).toString("hex");
|
||||
}
|
||||
|
||||
const argon2Params = {
|
||||
@@ -140,6 +141,16 @@ function generateChallenge(clientIp?: string, token?: string): Challenge {
|
||||
};
|
||||
}
|
||||
|
||||
function signMessage(msg: any): string {
|
||||
const hmac = crypto.createHmac("sha256", hmacSecret);
|
||||
if (typeof msg === "object") {
|
||||
hmac.update(JSON.stringify(msg));
|
||||
} else {
|
||||
hmac.update(msg);
|
||||
}
|
||||
return hmac.digest("hex");
|
||||
}
|
||||
|
||||
async function verifySolution(
|
||||
challenge: Challenge,
|
||||
solution: string,
|
||||
@@ -202,7 +213,7 @@ router.post("/challenge", (req, res) => {
|
||||
}
|
||||
const { action, refreshToken, proxyKey } = data.data;
|
||||
if (config.proxyKey && proxyKey !== config.proxyKey) {
|
||||
res.status(401).json({ error: "Invalid proxy password" });
|
||||
res.status(400).json({ error: "Invalid proxy password" });
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -214,11 +225,11 @@ router.post("/challenge", (req, res) => {
|
||||
return;
|
||||
}
|
||||
const challenge = generateChallenge(req.ip, refreshToken);
|
||||
const signature = signMessage(challenge, powKeySalt);
|
||||
const signature = signMessage(challenge);
|
||||
res.json({ challenge, signature });
|
||||
} else {
|
||||
const challenge = generateChallenge(req.ip);
|
||||
const signature = signMessage(challenge, powKeySalt);
|
||||
const signature = signMessage(challenge);
|
||||
res.json({ challenge, signature });
|
||||
}
|
||||
});
|
||||
@@ -242,7 +253,7 @@ router.post("/verify", async (req, res) => {
|
||||
}
|
||||
|
||||
const { challenge, signature, solution } = result.data;
|
||||
if (signMessage(challenge, powKeySalt) !== signature) {
|
||||
if (signMessage(challenge) !== signature) {
|
||||
res.status(400).json({
|
||||
error:
|
||||
"Invalid signature; server may have restarted since challenge was issued. Please request a new challenge.",
|
||||
|
||||
@@ -303,10 +303,6 @@
|
||||
_csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content"),
|
||||
};
|
||||
|
||||
if (localStorage.getItem("captcha-proxy-key")) {
|
||||
body.proxyKey = localStorage.getItem("captcha-proxy-key");
|
||||
}
|
||||
|
||||
fetch("/user/captcha/verify", {
|
||||
method: "POST",
|
||||
credentials: "same-origin",
|
||||
|
||||
@@ -64,7 +64,7 @@
|
||||
</table>
|
||||
|
||||
<h3>Quota Information</h3>
|
||||
<%- include("partials/shared_quota-info", { quota, user, showRefreshEdit: false }) %>
|
||||
<%- include("partials/shared_quota-info", { quota, user }) %>
|
||||
|
||||
<form id="edit-nickname-form" style="display: none" action="/user/edit-nickname" method="post">
|
||||
<input type="hidden" name="_csrf" value="<%= csrfToken %>" />
|
||||
|
||||
@@ -61,11 +61,7 @@
|
||||
const refreshToken = token && action === "refresh" ? JSON.parse(token).token : undefined;
|
||||
const keyInput = document.getElementById("proxy-key");
|
||||
const proxyKey = (keyInput && keyInput.value) || undefined;
|
||||
if (!proxyKey?.length) {
|
||||
localStorage.removeItem("captcha-proxy-key");
|
||||
} else {
|
||||
localStorage.setItem("captcha-proxy-key", proxyKey);
|
||||
}
|
||||
localStorage.setItem("captcha-proxy-key", proxyKey);
|
||||
|
||||
fetch("/user/captcha/challenge", {
|
||||
method: "POST",
|
||||
|
||||
Reference in New Issue
Block a user