41 Commits

Author SHA1 Message Date
user 9e6fd7c24c Implement tools (function calling) for Claude 2024-09-08 00:04:03 +00:00
nai-degen ac92a19946 improves reliability of inference profile detection for AWS keychecker 2024-09-07 17:36:29 -05:00
khanon 96fe974ad0 Use AWS Inference Profiles for higher rate limits (khanon/oai-reverse-proxy!78) 2024-09-01 22:55:07 +00:00
nai-degen 578615fbd2 fixes typo in new Claude system prompt schema 2024-08-30 10:23:57 -05:00
nai-degen 5dc4050e52 disable periodic GCP key rechecks to workaround keychecker bug 2024-08-29 15:25:37 -05:00
nai-degen cf615ee62c applies prettier to GCP checker 2024-08-29 15:15:56 -05:00
nai-degen ee61f9be2b removes unnecessary log from last commit 2024-08-27 23:58:32 -05:00
nai-degen 0c448cb59d fixes azure dalle using wrong rate limit and out-of-spec Retry-After header 2024-08-27 23:53:28 -05:00
nai-degen 51a9ccceb2 supports alternate claude system prompt format 2024-08-27 23:27:20 -05:00
nai-degen ce490efd7d minor adjustments to HMAC signing 2024-08-22 19:54:02 -05:00
nai-degen 5000e59a61 fix for google makersuite prompt validation/transformation 2024-08-22 14:19:48 -05:00
nai-degen d54acad6ad adds support for sonnet 8192 output tokens on anthropic api 2024-08-15 11:55:13 -05:00
nai-degen 5e1fffe07d adds chatgpt-4o-latest 2024-08-15 11:54:42 -05:00
nai-degen f7fd5f00f2 fixes esponse_format schema for mistral la plateforme 2024-08-14 14:41:47 -05:00
nai-degen 6d323f6ea1 do not transform mistral chat prompts to text when using la plateforme 2024-08-14 12:26:27 -05:00
nai-degen 2959ed3f7f fixes aws keychecker not detecting claude 2.1 2024-08-14 10:49:02 -05:00
nai-degen b58e7cb830 always applies Mistral prompt fixes on messages input 2024-08-14 10:48:55 -05:00
khanon f531272b00 Refactor AWS service code and add AWS Mistral support (khanon/oai-reverse-proxy!75) 2024-08-14 04:40:41 +00:00
nai-degen 6c45c92ea0 updates dependencies 2024-08-12 19:10:15 -05:00
nai-degen b7cd326d2a handles 'invalid subscription' 403 errors from Mistral API 2024-08-07 14:14:53 -05:00
nai-degen 6c9f302fb9 minor gultra fix 2024-08-06 18:46:49 -05:00
nai-degen 9ab1e7d0ce adds new gpt4o id 2024-08-06 13:08:25 -05:00
nai-degen 81f8dc2613 updates README.md 2024-08-05 11:33:16 -05:00
khanon 0c936e97fe Merge GCP Vertex AI implementation from cg-dot/oai-reverse-proxy (khanon/oai-reverse-proxy!72) 2024-08-05 14:27:51 +00:00
nai-degen 29ed07492e fixes info page display for gemini flash/ultra 2024-08-03 22:18:05 -05:00
nai-degen 2f7315379c adds gemini/makersuite keychecker, native endpoint, and streaming fixes 2024-08-03 21:53:32 -05:00
nai-degen e91532f4f7 handle dead makersuite keys triggering 400 error instead of 401/403 2024-08-03 19:09:50 -05:00
nai-degen ca58770458 fixes issue with PROXY_KEY when used together with proof-of-work captcha 2024-07-29 19:41:57 -05:00
nai-degen 9a3cca6b80 adds new mistral models and updates older model lists/context limits 2024-07-28 13:15:03 -05:00
nai-degen 584bb3fbc7 addresses minor issue with quota refresh UI 2024-07-28 11:54:38 -05:00
nai-degen 2aa19e5b09 adds user-specific overrides for daily quota refresh 2024-07-27 14:25:53 -05:00
nai-degen f242777596 fixes token index used as msg idx in anthropic chat-to-openai SSE transformer 2024-07-07 13:33:33 -05:00
nai-degen edc0d094e2 tries to disable quarantined aws keys 2024-06-30 05:08:27 -05:00
nai-degen 994b30dcce adjusts gemini pro model assignment 2024-06-26 13:37:23 -05:00
nai-degen e3d1ab51d1 improves handling of AWS regions with Sonnet 3.5 enabled but Sonnet 3.0 disabled 2024-06-20 12:20:38 -05:00
nai-degen ff38eda066 improves model detection for AWS Sydney region 2024-06-20 12:19:44 -05:00
nai-degen 84b917f726 fixes AWS Sonnet 3.5 key assignment bug 2024-06-20 12:00:11 -05:00
nai-degen 5871025245 fixes AWS keychecker failure caused by Sonnet 3.5 gradual rollout 2024-06-20 11:24:47 -05:00
nai-degen b4fb97ca5c fixes model id typo 2024-06-20 10:42:48 -05:00
nai-degen eb700d3da6 adds untested claude 3.5 model ids and model assignment 2024-06-20 10:34:48 -05:00
nai-degen d706d4c59d adds USER_CONCURRENCY_LIMIT environment variable 2024-06-14 22:52:16 -05:00
96 changed files with 3928 additions and 1684 deletions
+13 -3
@@ -40,15 +40,21 @@ NODE_ENV=production
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus
# | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small
# | mistral-medium | mistral-large | aws-claude | aws-claude-opus | gcp-claude
# | gcp-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k
# | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
# 'azure-dall-e' to the list of allowed model families.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,gcp-claude,gcp-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
# Which services can be used to process prompts containing images via multimodal
# models. The following services are recognized:
# openai | anthropic | aws | azure | google-ai | mistral-ai
# openai | anthropic | aws | gcp | azure | google-ai | mistral-ai
# Do not enable this feature unless all users are trusted, as you will be liable
# for any user-submitted images containing illegal content.
# By default, no image services are allowed and image prompts are rejected.
@@ -118,6 +124,7 @@ NODE_ENV=production
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# TOKEN_QUOTA_GCP_CLAUDE=0
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
@@ -142,12 +149,15 @@ NODE_ENV=production
# You can add multiple API keys by separating them with a comma.
# For AWS credentials, separate the access key ID, secret key, and region with a colon.
# For GCP credentials, separate the project ID, client email, region, and private key with a colon.
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# See `docs/aws-configuration.md` for more information; additional steps may be required to set up AWS.
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
# See `docs/azure-configuration.md` for more information; additional steps may be required to set up Azure.
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key
GCP_CREDENTIALS=project-id:client-email:region:private-key
# With proxy_key gatekeeper, the password users must provide to access the API.
# PROXY_KEY=your-secret-key
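Pulling the variables above together, a minimal GCP-enabled configuration might look like this sketch (all values are placeholders, not working credentials):
```
NODE_ENV=production
ALLOWED_MODEL_FAMILIES=gcp-claude,gcp-claude-opus
GCP_CREDENTIALS=my-project:sa@my-project.iam.gserviceaccount.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----
PROXY_KEY=your-secret-key
```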
+3 -6
@@ -7,9 +7,8 @@ Reverse proxy server for various LLM APIs.
- [Features](#features)
- [Usage Instructions](#usage-instructions)
- [Self-hosting](#self-hosting)
- [Alternatives](#alternatives)
- [Huggingface (outdated, not advised)](#huggingface-outdated-not-advised)
- [Render (outdated, not advised)](#render-outdated-not-advised)
- [Huggingface (outdated, not advised)](#huggingface-outdated-not-advised)
- [Render (outdated, not advised)](#render-outdated-not-advised)
- [Local Development](#local-development)
## What is this?
@@ -20,6 +19,7 @@ This project allows you to run a reverse proxy server for various LLM APIs.
- [x] [OpenAI](https://openai.com/)
- [x] [Anthropic](https://www.anthropic.com/)
- [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
- [x] [Vertex AI (GCP)](https://cloud.google.com/vertex-ai/)
- [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
- [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
- [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses
@@ -41,9 +41,6 @@ If you'd like to run your own instance of this server, you'll need to deploy it
**Ensure you set the `TRUSTED_PROXIES` environment variable according to your deployment.** Refer to [.env.example](./.env.example) and [config.ts](./src/config.ts) for more information.
### Alternatives
Fiz and Sekrit are working on some alternative ways to deploy this conveniently. While I'm not involved in this effort beyond providing technical advice regarding my code, I'll link to their work here for convenience: [Sekrit's rentry](https://rentry.org/sekrit)
### Huggingface (outdated, not advised)
[See here for instructions on how to deploy to a Huggingface Space.](./docs/deploy-huggingface.md)
+35
@@ -0,0 +1,35 @@
# Configuring the proxy for Vertex AI (GCP)
The proxy supports GCP models via the `/proxy/gcp/claude` endpoint. There are a few extra steps necessary to use GCP compared to the other supported APIs.
- [Setting keys](#setting-keys)
- [Setup Vertex AI](#setup-vertex-ai)
- [Supported model IDs](#supported-model-ids)
## Setting keys
Use the `GCP_CREDENTIALS` environment variable to set the GCP API keys.
As with other APIs, you can provide multiple keys separated by commas. Each GCP key, however, is a set of credentials comprising the project ID, client email, region, and private key. These are separated by colons (`:`).
For example:
```
GCP_CREDENTIALS=my-first-project:xxx@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,my-first-project2:xxx2@yyy.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----
```
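For illustration, a minimal TypeScript sketch of how a string in this format could be split into its four parts. The names here are invented for the example; this is not the proxy's actual parser:
```
// Hypothetical parser for the GCP_CREDENTIALS format shown above;
// not the proxy's actual implementation.
interface GcpCredential {
  projectId: string;
  clientEmail: string;
  region: string;
  privateKey: string;
}

function parseGcpCredentials(raw: string): GcpCredential[] {
  // The PEM body is base64 and contains no commas, so splitting on "," is safe.
  return raw.split(",").map((item) => {
    // Take the first three colon-delimited fields; rejoin anything after them
    // back into the private key.
    const [projectId, clientEmail, region, ...rest] = item.split(":");
    return { projectId, clientEmail, region, privateKey: rest.join(":") };
  });
}
```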
## Setup Vertex AI
1. Go to [https://cloud.google.com/vertex-ai](https://cloud.google.com/vertex-ai) and sign up for a GCP account ($150 in free credits without a credit card, or $300 with one; credits expire after 90 days).
2. Go to [https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com](https://console.cloud.google.com/marketplace/product/google/aiplatform.googleapis.com) to enable the Vertex AI API.
3. Go to [https://console.cloud.google.com/vertex-ai](https://console.cloud.google.com/vertex-ai) and navigate to Model Garden to apply for access to the Claude models.
4. Create a [Service Account](https://console.cloud.google.com/projectselector/iam-admin/serviceaccounts/create?walkthrough_id=iam--create-service-account#step_index=1), and make sure to grant it the "Vertex AI User" or "Vertex AI Administrator" role.
5. On the service account page you just created, create a new key and select "JSON". The JSON file will be downloaded automatically.
6. The required credentials (project ID, client email, and private key) are in the JSON file you just downloaded; see the sketch below for assembling them into a `GCP_CREDENTIALS` entry.
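As a convenience, a hedged sketch of turning the downloaded service-account JSON into a `GCP_CREDENTIALS` entry. The field names `project_id`, `client_email`, and `private_key` are the standard service-account JSON keys; the region is your own choice, and exact newline handling for the key may differ by deployment:
```
import fs from "fs";

// Prints one GCP_CREDENTIALS entry from a downloaded service-account file.
const sa = JSON.parse(fs.readFileSync("service-account.json", "utf8"));
const region = "us-east5"; // a region where you were granted Claude access

// The key is printed with newlines stripped so the whole value stays on one
// line, matching the example above.
console.log(
  [sa.project_id, sa.client_email, region, sa.private_key.replace(/\n/g, "")].join(":")
);
```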
## Supported model IDs
Users can send these model IDs to the proxy to invoke the corresponding models.
- **Claude**
- `claude-3-haiku@20240307`
- `claude-3-sonnet@20240229`
- `claude-3-opus@20240229`
- `claude-3-5-sonnet@20240620`
+1 -1
@@ -129,7 +129,7 @@ also significantly reduce hash rates on mobile devices.
- Intel Core i9-13900K (Chrome, in VM limited to 4 cores): 12.2 - 13.0 H/s
- iPad Pro (M2) (Safari, 6 workers): 8.0 - 10 H/s
- Thermally throttles early; 8 cores is the normal concurrency, but unstable.
- iPhone 13 Pro (Safari): 4.0 - 4.6 H/s
- iPhone 15 Pro Max (Safari): 4.0 - 4.6 H/s
- Samsung Galaxy S10e (Chrome): 3.6 - 3.8 H/s
- This is a 2019 phone almost matching an iPhone five years newer because of
bad Safari performance.
+226 -141
@@ -11,14 +11,14 @@
"dependencies": {
"@anthropic-ai/tokenizer": "^0.0.4",
"@aws-crypto/sha256-js": "^5.2.0",
"@huggingface/jinja": "^0.3.0",
"@node-rs/argon2": "^1.8.3",
"@smithy/eventstream-codec": "^2.1.3",
"@smithy/eventstream-serde-node": "^2.1.3",
"@smithy/protocol-http": "^3.2.1",
"@smithy/signature-v4": "^2.1.3",
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.3.5",
"axios": "^1.7.4",
"better-sqlite3": "^10.0.0",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
@@ -29,7 +29,7 @@
"ejs": "^3.1.10",
"express": "^4.18.2",
"express-session": "^1.17.3",
"firebase-admin": "^12.1.0",
"firebase-admin": "^12.3.1",
"glob": "^10.3.12",
"googleapis": "^122.0.0",
"http-proxy-middleware": "^3.0.0-beta.1",
@@ -39,7 +39,7 @@
"node-schedule": "^2.1.1",
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"sanitize-html": "2.12.1",
"sanitize-html": "^2.13.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"source-map-support": "^0.5.21",
@@ -51,6 +51,7 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@smithy/types": "^3.3.0",
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
@@ -151,6 +152,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@aws-sdk/types/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@aws-sdk/util-utf8-browser": {
"version": "3.259.0",
"resolved": "https://registry.npmjs.org/@aws-sdk/util-utf8-browser/-/util-utf8-browser-3.259.0.tgz",
@@ -542,12 +554,9 @@
}
},
"node_modules/@fastify/busboy": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@fastify/busboy/-/busboy-2.1.1.tgz",
"integrity": "sha512-vBZP4NlzfOlerQTnba4aqZoMhE/a9HY7HRqoOPaETQcSQuWEIyZMHGfVu6w9wGtGK5fED5qRs2DteVCjOH60sA==",
"engines": {
"node": ">=14"
}
"version": "3.0.0",
"resolved": "https://registry.npmjs.org/@fastify/busboy/-/busboy-3.0.0.tgz",
"integrity": "sha512-83rnH2nCvclWaPQQKvkJ2pdOjG4TZyEVuFDnlOF6KP08lDaaceVyw/W63mDuafQT+MKHCvXIPpE5uYWeM0rT4w=="
},
"node_modules/@firebase/app-check-interop-types": {
"version": "0.3.1",
@@ -626,14 +635,14 @@
}
},
"node_modules/@google-cloud/firestore": {
"version": "7.6.0",
"resolved": "https://registry.npmjs.org/@google-cloud/firestore/-/firestore-7.6.0.tgz",
"integrity": "sha512-WUDbaLY8UnPxgwsyIaxj6uxCtSDAaUyvzWJykNH5rZ9i92/SZCsPNNMN0ajrVpAR81hPIL4amXTaMJ40y5L+Yg==",
"version": "7.9.0",
"resolved": "https://registry.npmjs.org/@google-cloud/firestore/-/firestore-7.9.0.tgz",
"integrity": "sha512-c4ALHT3G08rV7Zwv8Z2KG63gZh66iKdhCBeDfCpIkLrjX6EAjTD/szMdj14M+FnQuClZLFfW5bAgoOjfNmLtJg==",
"optional": true,
"dependencies": {
"fast-deep-equal": "^3.1.1",
"functional-red-black-tree": "^1.0.1",
"google-gax": "^4.3.1",
"google-gax": "^4.3.3",
"protobufjs": "^7.2.6"
},
"engines": {
@@ -839,12 +848,12 @@
}
},
"node_modules/@grpc/grpc-js": {
"version": "1.10.6",
"resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.10.6.tgz",
"integrity": "sha512-xP58G7wDQ4TCmN/cMUHh00DS7SRDv/+lC+xFLrTkMIN8h55X5NhZMLYbvy7dSELP15qlI6hPhNCRWVMtZMwqLA==",
"version": "1.11.1",
"resolved": "https://registry.npmjs.org/@grpc/grpc-js/-/grpc-js-1.11.1.tgz",
"integrity": "sha512-gyt/WayZrVPH2w/UTLansS7F9Nwld472JxxaETamrM8HNlsa+jSLNyKAZmhxI2Me4c3mQHFiS1wWHDY1g1Kthw==",
"optional": true,
"dependencies": {
"@grpc/proto-loader": "^0.7.10",
"@grpc/proto-loader": "^0.7.13",
"@js-sdsl/ordered-map": "^4.4.2"
},
"engines": {
@@ -852,14 +861,14 @@
}
},
"node_modules/@grpc/proto-loader": {
"version": "0.7.12",
"resolved": "https://registry.npmjs.org/@grpc/proto-loader/-/proto-loader-0.7.12.tgz",
"integrity": "sha512-DCVwMxqYzpUCiDMl7hQ384FqP4T3DbNpXU8pt681l3UWCip1WUiD5JrkImUwCB9a7f2cq4CUTmi5r/xIMRPY1Q==",
"version": "0.7.13",
"resolved": "https://registry.npmjs.org/@grpc/proto-loader/-/proto-loader-0.7.13.tgz",
"integrity": "sha512-AiXO/bfe9bmxBjxxtYxFAXGZvMaN5s8kO+jBHAJCON8rJoB5YS/D6X7ZNc6XQkuHNmyl4CYaMI1fJ/Gn27RGGw==",
"optional": true,
"dependencies": {
"lodash.camelcase": "^4.3.0",
"long": "^5.0.0",
"protobufjs": "^7.2.4",
"protobufjs": "^7.2.5",
"yargs": "^17.7.2"
},
"bin": {
@@ -869,6 +878,14 @@
"node": ">=6"
}
},
"node_modules/@huggingface/jinja": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.3.0.tgz",
"integrity": "sha512-GLJzso0M07ZncFkrJMIXVU4os6GFbPocD4g8fMQPMGJubf48FtGOsUORH2rtFdXPIPelz8SLBMn8ZRmOTwXm9Q==",
"engines": {
"node": ">=18"
}
},
"node_modules/@isaacs/cliui": {
"version": "8.0.2",
"resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz",
@@ -1322,6 +1339,17 @@
"tslib": "^2.5.0"
}
},
"node_modules/@smithy/eventstream-codec/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-node": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-node/-/eventstream-serde-node-2.1.3.tgz",
@@ -1335,6 +1363,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-node/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-universal": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/eventstream-serde-universal/-/eventstream-serde-universal-2.1.3.tgz",
@@ -1348,6 +1387,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/eventstream-serde-universal/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/is-array-buffer": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/is-array-buffer/-/is-array-buffer-2.1.1.tgz",
@@ -1371,6 +1421,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/protocol-http/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/signature-v4": {
"version": "2.1.3",
"resolved": "https://registry.npmjs.org/@smithy/signature-v4/-/signature-v4-2.1.3.tgz",
@@ -1389,17 +1450,29 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/types": {
"version": "2.10.1",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.10.1.tgz",
"integrity": "sha512-hjQO+4ru4cQ58FluQvKKiyMsFg0A6iRpGm2kqdH8fniyNd2WyanoOsYJfMX/IFLuLxEoW6gnRkNZy1y6fUUhtA==",
"node_modules/@smithy/signature-v4/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.5.0"
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/types": {
"version": "3.3.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-3.3.0.tgz",
"integrity": "sha512-IxvBBCTFDHbVoK7zIxqA1ZOdc4QfM5HM7rGleCuHi7L1wnKv5Pn69xXJQ9hgxH60ZVygH9/JG0jRgtUncE3QUA==",
"dev": true,
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=16.0.0"
}
},
"node_modules/@smithy/util-buffer-from": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/util-buffer-from/-/util-buffer-from-2.1.1.tgz",
@@ -1435,6 +1508,17 @@
"node": ">=14.0.0"
}
},
"node_modules/@smithy/util-middleware/node_modules/@smithy/types": {
"version": "2.12.0",
"resolved": "https://registry.npmjs.org/@smithy/types/-/types-2.12.0.tgz",
"integrity": "sha512-QwYgloJ0sVNBeBuBs65cIkTbfzV/Q6ZNPCJ99EICFEdJYG50nGIY/uYXp+TbsdJReIuPr0a0kXmCvren3MbRRw==",
"dependencies": {
"tslib": "^2.6.2"
},
"engines": {
"node": ">=14.0.0"
}
},
"node_modules/@smithy/util-uri-escape": {
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/@smithy/util-uri-escape/-/util-uri-escape-2.1.1.tgz",
@@ -1589,9 +1673,9 @@
}
},
"node_modules/@types/jsonwebtoken": {
"version": "9.0.2",
"resolved": "https://registry.npmjs.org/@types/jsonwebtoken/-/jsonwebtoken-9.0.2.tgz",
"integrity": "sha512-drE6uz7QBKq1fYqqoFKTDRdFCPHd5TCub75BM+D+cMx7NU9hUz7SESLfC2fSCXVFMO5Yj8sOWHuGqPgjc+fz0Q==",
"version": "9.0.6",
"resolved": "https://registry.npmjs.org/@types/jsonwebtoken/-/jsonwebtoken-9.0.6.tgz",
"integrity": "sha512-/5hndP5dCjloafCXns6SZyESp3Ldq7YjH3zwzwczYnjxIT0Fqzk5ROSYVGfFyczIue7IUEj8hkvLbPoLQ18vQw==",
"dependencies": {
"@types/node": "*"
}
@@ -1890,11 +1974,11 @@
}
},
"node_modules/axios": {
"version": "1.6.1",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
"integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
"version": "1.7.4",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.7.4.tgz",
"integrity": "sha512-DukmaFRnY6AzAALSH4J2M3k6PkaC+MfaAGdEERRWcC9q3/TWQwLpHR8ZRLKTdQ3aBDL64EdluRDjJqKw+BPZEw==",
"dependencies": {
"follow-redirects": "^1.15.0",
"follow-redirects": "^1.15.6",
"form-data": "^4.0.0",
"proxy-from-env": "^1.1.0"
}
@@ -2042,11 +2126,11 @@
}
},
"node_modules/braces": {
"version": "3.0.2",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
"integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
"integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
"dependencies": {
"fill-range": "^7.0.1"
"fill-range": "^7.1.1"
},
"engines": {
"node": ">=8"
@@ -2996,24 +3080,14 @@
"resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz",
"integrity": "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g=="
},
"node_modules/farmhash": {
"version": "3.3.1",
"resolved": "https://registry.npmjs.org/farmhash/-/farmhash-3.3.1.tgz",
"integrity": "sha512-XUizHanzlr/v7suBr/o85HSakOoWh6HKXZjFYl5C2+Gj0f0rkw+XTUZzrd9odDsgI9G5tRUcF4wSbKaX04T0DQ==",
"hasInstallScript": true,
"dependencies": {
"node-addon-api": "^5.1.0",
"prebuild-install": "^7.1.2"
},
"node_modules/farmhash-modern": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/farmhash-modern/-/farmhash-modern-1.1.0.tgz",
"integrity": "sha512-6ypT4XfgqJk/F3Yuv4SX26I3doUjt0GTG4a+JgWxXQpxXzTBq8fPUeGHfcYMMDPHJHm3yPOSjaeBwBGAHWXCdA==",
"engines": {
"node": ">=10"
"node": ">=18.0.0"
}
},
"node_modules/farmhash/node_modules/node-addon-api": {
"version": "5.1.0",
"resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-5.1.0.tgz",
"integrity": "sha512-eh0GgfEkpnoWDq+VY8OyvYhFEzBk6jIYbRKdIlyTiAXIVJ8PyBaKb0rp7oDtoddbdoHWhq8wwr+XZ81F1rpNdA=="
},
"node_modules/fast-copy": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/fast-copy/-/fast-copy-3.0.1.tgz",
@@ -3051,9 +3125,9 @@
"integrity": "sha512-VhXlQgj9ioXCqGstD37E/HBeqEGV/qOD/kmbVG8h5xKBYvM1L3lR1Zn4555cQ8GkYbJa8aJSipLPndE1k6zK2w=="
},
"node_modules/fast-xml-parser": {
"version": "4.3.6",
"resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.3.6.tgz",
"integrity": "sha512-M2SovcRxD4+vC493Uc2GZVcZaj66CCJhWurC4viynVSTvrpErCShNcDz1lAho6n9REQKvL/ll4A4/fw6Y9z8nw==",
"version": "4.4.1",
"resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.4.1.tgz",
"integrity": "sha512-xkjOecfnKGkSsOwtZ5Pz7Us/T6mrbPQrq0nh+aCO5V9nk5NLWmasAHumTKjiPJPWANe+kAZ84Jc8ooJkzZ88Sw==",
"funding": [
{
"type": "github",
@@ -3116,9 +3190,9 @@
}
},
"node_modules/fill-range": {
"version": "7.0.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
"integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
"version": "7.1.1",
"resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
"integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
"dependencies": {
"to-regex-range": "^5.0.1"
},
@@ -3144,35 +3218,46 @@
}
},
"node_modules/firebase-admin": {
"version": "12.1.0",
"resolved": "https://registry.npmjs.org/firebase-admin/-/firebase-admin-12.1.0.tgz",
"integrity": "sha512-bU7uPKMmIXAihWxntpY/Ma9zucn5y3ec+HQPqFQ/zcEfP9Avk9E/6D8u+yT/VwKHNZyg7yDVWOoJi73TIdR4Ww==",
"version": "12.3.1",
"resolved": "https://registry.npmjs.org/firebase-admin/-/firebase-admin-12.3.1.tgz",
"integrity": "sha512-vEr3s3esl8nPIA9r/feDT4nzIXCfov1CyyCSpMQWp6x63Q104qke0MEGZlrHUZVROtl8FLus6niP/M9I1s4VBA==",
"dependencies": {
"@fastify/busboy": "^2.1.0",
"@fastify/busboy": "^3.0.0",
"@firebase/database-compat": "^1.0.2",
"@firebase/database-types": "^1.0.0",
"@types/node": "^20.10.3",
"farmhash": "^3.3.0",
"@types/node": "^22.0.1",
"farmhash-modern": "^1.1.0",
"jsonwebtoken": "^9.0.0",
"jwks-rsa": "^3.0.1",
"long": "^5.2.3",
"jwks-rsa": "^3.1.0",
"node-forge": "^1.3.1",
"uuid": "^9.0.0"
"uuid": "^10.0.0"
},
"engines": {
"node": ">=14"
},
"optionalDependencies": {
"@google-cloud/firestore": "^7.1.0",
"@google-cloud/firestore": "^7.7.0",
"@google-cloud/storage": "^7.7.0"
}
},
"node_modules/firebase-admin/node_modules/@types/node": {
"version": "20.12.7",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.12.7.tgz",
"integrity": "sha512-wq0cICSkRLVaf3UGLMGItu/PtdY7oaXaI/RVU+xliKVOtRna3PRY57ZDfztpDL0n11vfymMUnXv8QwYCO7L1wg==",
"version": "22.2.0",
"resolved": "https://registry.npmjs.org/@types/node/-/node-22.2.0.tgz",
"integrity": "sha512-bm6EG6/pCpkxDf/0gDNDdtDILMOHgaQBVOJGdwsqClnxA3xL6jtMv76rLBc006RVMWbmaf0xbmom4Z/5o2nRkQ==",
"dependencies": {
"undici-types": "~5.26.4"
"undici-types": "~6.13.0"
}
},
"node_modules/firebase-admin/node_modules/uuid": {
"version": "10.0.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz",
"integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==",
"funding": [
"https://github.com/sponsors/broofa",
"https://github.com/sponsors/ctavan"
],
"bin": {
"uuid": "dist/bin/uuid"
}
},
"node_modules/follow-redirects": {
@@ -3400,21 +3485,21 @@
}
},
"node_modules/google-gax": {
"version": "4.3.2",
"resolved": "https://registry.npmjs.org/google-gax/-/google-gax-4.3.2.tgz",
"integrity": "sha512-2mw7qgei2LPdtGrmd1zvxQviOcduTnsvAWYzCxhOWXK4IQKmQztHnDQwD0ApB690fBQJemFKSU7DnceAy3RLzw==",
"version": "4.3.9",
"resolved": "https://registry.npmjs.org/google-gax/-/google-gax-4.3.9.tgz",
"integrity": "sha512-tcjQr7sXVGMdlvcG25wSv98ap1dtF4Z6mcV0rztGIddOcezw4YMb/uTXg72JPrLep+kXcVjaJjg6oo3KLf4itQ==",
"optional": true,
"dependencies": {
"@grpc/grpc-js": "~1.10.0",
"@grpc/proto-loader": "^0.7.0",
"@grpc/grpc-js": "^1.10.9",
"@grpc/proto-loader": "^0.7.13",
"@types/long": "^4.0.0",
"abort-controller": "^3.0.0",
"duplexify": "^4.0.0",
"google-auth-library": "^9.3.0",
"node-fetch": "^2.6.1",
"node-fetch": "^2.7.0",
"object-hash": "^3.0.0",
"proto3-json-serializer": "^2.0.0",
"protobufjs": "7.2.6",
"proto3-json-serializer": "^2.0.2",
"protobufjs": "^7.3.2",
"retry-request": "^7.0.0",
"uuid": "^9.0.1"
},
@@ -3435,9 +3520,9 @@
}
},
"node_modules/google-gax/node_modules/debug": {
"version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
"version": "4.3.6",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.6.tgz",
"integrity": "sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg==",
"optional": true,
"dependencies": {
"ms": "2.1.2"
@@ -3452,21 +3537,34 @@
}
},
"node_modules/google-gax/node_modules/gaxios": {
"version": "6.5.0",
"resolved": "https://registry.npmjs.org/gaxios/-/gaxios-6.5.0.tgz",
"integrity": "sha512-R9QGdv8j4/dlNoQbX3hSaK/S0rkMijqjVvW3YM06CoBdbU/VdKd159j4hePpng0KuE6Lh6JJ7UdmVGJZFcAG1w==",
"version": "6.7.0",
"resolved": "https://registry.npmjs.org/gaxios/-/gaxios-6.7.0.tgz",
"integrity": "sha512-DSrkyMTfAnAm4ks9Go20QGOcXEyW/NmZhvTYBU2rb4afBB393WIMQPWPEDMl/k8xqiNN9HYq2zao3oWXsdl2Tg==",
"optional": true,
"dependencies": {
"extend": "^3.0.2",
"https-proxy-agent": "^7.0.1",
"is-stream": "^2.0.0",
"node-fetch": "^2.6.9",
"uuid": "^9.0.1"
"uuid": "^10.0.0"
},
"engines": {
"node": ">=14"
}
},
"node_modules/google-gax/node_modules/gaxios/node_modules/uuid": {
"version": "10.0.0",
"resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz",
"integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==",
"funding": [
"https://github.com/sponsors/broofa",
"https://github.com/sponsors/ctavan"
],
"optional": true,
"bin": {
"uuid": "dist/bin/uuid"
}
},
"node_modules/google-gax/node_modules/gcp-metadata": {
"version": "6.1.0",
"resolved": "https://registry.npmjs.org/gcp-metadata/-/gcp-metadata-6.1.0.tgz",
@@ -3481,9 +3579,9 @@
}
},
"node_modules/google-gax/node_modules/google-auth-library": {
"version": "9.8.0",
"resolved": "https://registry.npmjs.org/google-auth-library/-/google-auth-library-9.8.0.tgz",
"integrity": "sha512-TJJXFzMlVGRlIH27gYZ6XXyPf5Y3OItsKFfefsDAafNNywYRTkei83nEO29IrYj8GtdHWU78YnW+YZdaZaXIJA==",
"version": "9.13.0",
"resolved": "https://registry.npmjs.org/google-auth-library/-/google-auth-library-9.13.0.tgz",
"integrity": "sha512-p9Y03Uzp/Igcs36zAaB0XTSwZ8Y0/tpYiz5KIde5By+H9DCVUSYtDWZu6aFXsWTqENMb8BD/pDT3hR8NVrPkfA==",
"optional": true,
"dependencies": {
"base64-js": "^1.3.0",
@@ -3511,9 +3609,9 @@
}
},
"node_modules/google-gax/node_modules/https-proxy-agent": {
"version": "7.0.4",
"resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.4.tgz",
"integrity": "sha512-wlwpilI7YdjSkWaQ/7omYBMTliDcmCN8OLihO6I9B86g06lMyAoqgoDpV0XqoaPOKj+0DIdAvnsWfyAAhmimcg==",
"version": "7.0.5",
"resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.5.tgz",
"integrity": "sha512-1e4Wqeblerz+tMKPIq2EMGiiWW1dIjZOksyHWSUm1rmuvw/how9hBHZ38lAGj5ID4Ik6EdkOw7NmWPy6LAwalw==",
"optional": true,
"dependencies": {
"agent-base": "^7.0.2",
@@ -4053,9 +4151,9 @@
}
},
"node_modules/jose": {
"version": "4.15.5",
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.5.tgz",
"integrity": "sha512-jc7BFxgKPKi94uOvEmzlSWFFe2+vASyXaKUpdQKatWAESU2MWjDfFf0fdfc83CDKcA5QecabZeNLyfhe3yKNkg==",
"version": "4.15.9",
"resolved": "https://registry.npmjs.org/jose/-/jose-4.15.9.tgz",
"integrity": "sha512-1vUQX+IdDMVPj4k8kOxgUqlcK518yluMuGZwqlr44FS1ppZB/5GWh4rZG89erpOBOJjU/OBsnCVFfapsRz6nEA==",
"funding": {
"url": "https://github.com/sponsors/panva"
}
@@ -4127,25 +4225,25 @@
}
},
"node_modules/jwks-rsa": {
"version": "3.0.1",
"resolved": "https://registry.npmjs.org/jwks-rsa/-/jwks-rsa-3.0.1.tgz",
"integrity": "sha512-UUOZ0CVReK1QVU3rbi9bC7N5/le8ziUj0A2ef1Q0M7OPD2KvjEYizptqIxGIo6fSLYDkqBrazILS18tYuRc8gw==",
"version": "3.1.0",
"resolved": "https://registry.npmjs.org/jwks-rsa/-/jwks-rsa-3.1.0.tgz",
"integrity": "sha512-v7nqlfezb9YfHHzYII3ef2a2j1XnGeSE/bK3WfumaYCqONAIstJbrEGapz4kadScZzEt7zYCN7bucj8C0Mv/Rg==",
"dependencies": {
"@types/express": "^4.17.14",
"@types/jsonwebtoken": "^9.0.0",
"@types/express": "^4.17.17",
"@types/jsonwebtoken": "^9.0.2",
"debug": "^4.3.4",
"jose": "^4.10.4",
"jose": "^4.14.6",
"limiter": "^1.1.5",
"lru-memoizer": "^2.1.4"
"lru-memoizer": "^2.2.0"
},
"engines": {
"node": ">=14"
}
},
"node_modules/jwks-rsa/node_modules/debug": {
"version": "4.3.4",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz",
"integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==",
"version": "4.3.6",
"resolved": "https://registry.npmjs.org/debug/-/debug-4.3.6.tgz",
"integrity": "sha512-O/09Bd4Z1fBrU4VzkhFqVgpPzaGbw6Sm9FEkBT1A/YBXQFGuuSxa1dN2nxgxS34JmKXqYx8CZAwEVoJFImUXIg==",
"dependencies": {
"ms": "2.1.2"
},
@@ -4196,7 +4294,8 @@
"node_modules/long": {
"version": "5.2.3",
"resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz",
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q=="
"integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==",
"optional": true
},
"node_modules/long-timeout": {
"version": "0.1.1",
@@ -4215,28 +4314,14 @@
}
},
"node_modules/lru-memoizer": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/lru-memoizer/-/lru-memoizer-2.2.0.tgz",
"integrity": "sha512-QfOZ6jNkxCcM/BkIPnFsqDhtrazLRsghi9mBwFAzol5GCvj4EkFT899Za3+QwikCg5sRX8JstioBDwOxEyzaNw==",
"version": "2.3.0",
"resolved": "https://registry.npmjs.org/lru-memoizer/-/lru-memoizer-2.3.0.tgz",
"integrity": "sha512-GXn7gyHAMhO13WSKrIiNfztwxodVsP8IoZ3XfrJV4yH2x0/OeTO/FIaAHTY5YekdGgW94njfuKmyyt1E0mR6Ug==",
"dependencies": {
"lodash.clonedeep": "^4.5.0",
"lru-cache": "~4.0.0"
"lru-cache": "6.0.0"
}
},
"node_modules/lru-memoizer/node_modules/lru-cache": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-4.0.2.tgz",
"integrity": "sha512-uQw9OqphAGiZhkuPlpFGmdTU2tEuhxTourM/19qGJrxBPHAr/f8BT1a0i/lOclESnGatdJG/UCkP9kZB/Lh1iw==",
"dependencies": {
"pseudomap": "^1.0.1",
"yallist": "^2.0.0"
}
},
"node_modules/lru-memoizer/node_modules/yallist": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-2.1.2.tgz",
"integrity": "sha512-ncTzHV7NvsQZkYe1DW7cbDLm0YpzHmZF5r/iyP3ZnQtMiJ+pjzisCiMNI+Sj+xQF5pXhSHxSB3uDbsBTzY/c2A=="
},
"node_modules/luxon": {
"version": "3.4.2",
"resolved": "https://registry.npmjs.org/luxon/-/luxon-3.4.2.tgz",
@@ -4495,9 +4580,9 @@
"integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
},
"node_modules/node-fetch": {
"version": "2.6.9",
"resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.9.tgz",
"integrity": "sha512-DJm/CJkZkRjKKj4Zi4BsKVZh3ValV5IR5s7LVZnW+6YMh0W1BfNA8XSs6DLMGYlId5F3KnA70uu2qepcR08Qqg==",
"version": "2.7.0",
"resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz",
"integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==",
"dependencies": {
"whatwg-url": "^5.0.0"
},
@@ -4981,9 +5066,9 @@
"integrity": "sha512-/1WZ8+VQjR6avWOgHeEPd7SDQmFQ1B5mC1eRXsCm5TarlNmx/wCsa5GEaxGm05BORRtyG/Ex/3xq3TuRvq57qg=="
},
"node_modules/proto3-json-serializer": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/proto3-json-serializer/-/proto3-json-serializer-2.0.1.tgz",
"integrity": "sha512-8awBvjO+FwkMd6gNoGFZyqkHZXCFd54CIYTb6De7dPaufGJ2XNW+QUNqbMr8MaAocMdb+KpsD4rxEOaTBDCffA==",
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/proto3-json-serializer/-/proto3-json-serializer-2.0.2.tgz",
"integrity": "sha512-SAzp/O4Yh02jGdRc+uIrGoe87dkN/XtwxfZ4ZyafJHymd79ozp5VG5nyZ7ygqPM5+cpLDjjGnYFUkngonyDPOQ==",
"optional": true,
"dependencies": {
"protobufjs": "^7.2.5"
@@ -4993,9 +5078,9 @@
}
},
"node_modules/protobufjs": {
"version": "7.2.6",
"resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.6.tgz",
"integrity": "sha512-dgJaEDDL6x8ASUZ1YqWciTRrdOuYNzoOf27oHNfdyvKqHr5i0FV7FSLU+aIeFjyFgVxrpTOtQUi0BLLBymZaBw==",
"version": "7.3.2",
"resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.3.2.tgz",
"integrity": "sha512-RXyHaACeqXeqAKGLDl68rQKbmObRsTIn4TYVUUug1KfS47YWCo5MacGITEryugIgZqORCvJWEk4l449POg5Txg==",
"hasInstallScript": true,
"optional": true,
"dependencies": {
@@ -5249,9 +5334,9 @@
"integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg=="
},
"node_modules/sanitize-html": {
"version": "2.12.1",
"resolved": "https://registry.npmjs.org/sanitize-html/-/sanitize-html-2.12.1.tgz",
"integrity": "sha512-Plh+JAn0UVDpBRP/xEjsk+xDCoOvMBwQUf/K+/cBAVuTbtX8bj2VB7S1sL1dssVpykqp0/KPSesHrqXtokVBpA==",
"version": "2.13.0",
"resolved": "https://registry.npmjs.org/sanitize-html/-/sanitize-html-2.13.0.tgz",
"integrity": "sha512-Xff91Z+4Mz5QiNSLdLWwjgBDm5b1RU6xBT0+12rapjiaR7SwfRdjw8f+6Rir2MXKLrDicRFHdb51hGOAxmsUIA==",
"dependencies": {
"deepmerge": "^4.2.2",
"escape-string-regexp": "^4.0.0",
@@ -5927,9 +6012,9 @@
"dev": true
},
"node_modules/undici-types": {
"version": "5.26.5",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz",
"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA=="
"version": "6.13.0",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.13.0.tgz",
"integrity": "sha512-xtFJHudx8S2DSoujjMd1WeWvn7KKWFRESZTMeL1RptAYERu29D6jphMjjY+vn96jvN3kVPDNxU/E13VTaXj6jg=="
},
"node_modules/unpipe": {
"version": "1.0.0",
+7 -5
@@ -20,14 +20,14 @@
"dependencies": {
"@anthropic-ai/tokenizer": "^0.0.4",
"@aws-crypto/sha256-js": "^5.2.0",
"@huggingface/jinja": "^0.3.0",
"@node-rs/argon2": "^1.8.3",
"@smithy/eventstream-codec": "^2.1.3",
"@smithy/eventstream-serde-node": "^2.1.3",
"@smithy/protocol-http": "^3.2.1",
"@smithy/signature-v4": "^2.1.3",
"@smithy/types": "^2.10.1",
"@smithy/util-utf8": "^2.1.1",
"axios": "^1.3.5",
"axios": "^1.7.4",
"better-sqlite3": "^10.0.0",
"check-disk-space": "^3.4.0",
"cookie-parser": "^1.4.6",
@@ -38,7 +38,7 @@
"ejs": "^3.1.10",
"express": "^4.18.2",
"express-session": "^1.17.3",
"firebase-admin": "^12.1.0",
"firebase-admin": "^12.3.1",
"glob": "^10.3.12",
"googleapis": "^122.0.0",
"http-proxy-middleware": "^3.0.0-beta.1",
@@ -48,7 +48,7 @@
"node-schedule": "^2.1.1",
"pino": "^8.11.0",
"pino-http": "^8.3.3",
"sanitize-html": "2.12.1",
"sanitize-html": "^2.13.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"source-map-support": "^0.5.21",
@@ -60,6 +60,7 @@
"zod-error": "^1.5.0"
},
"devDependencies": {
"@smithy/types": "^3.3.0",
"@types/better-sqlite3": "^7.6.10",
"@types/cookie-parser": "^1.4.3",
"@types/cors": "^2.8.13",
@@ -83,7 +84,8 @@
"typescript": "^5.4.2"
},
"overrides": {
"postcss": "^8.4.31",
"braces": "^3.0.3",
"fast-xml-parser": "^4.4.1",
"follow-redirects": "^1.15.4"
}
}
+33
@@ -230,6 +230,39 @@ Content-Type: application/json
]
}
###
# @name Proxy / GCP Claude -- Native Completion
POST {{proxy-host}}/proxy/gcp/claude/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json
{
"model": "claude-v2",
"max_tokens_to_sample": 10,
"temperature": 0,
"stream": true,
"prompt": "What is genshin impact\n\n:Assistant:"
}
###
# @name Proxy / GCP Claude -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/gcp/claude/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"max_tokens": 50,
"stream": true,
"messages": [
{
"role": "user",
"content": "What is genshin impact?"
}
]
}
###
# @name Proxy / Azure OpenAI -- Native Chat Completions
POST {{proxy-host}}/proxy/azure/openai/chat/completions
+2
@@ -51,6 +51,8 @@ function getRandomModelFamily() {
"mistral-large",
"aws-claude",
"aws-claude-opus",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
+118
@@ -0,0 +1,118 @@
// uses the aws sdk to sign a request, then uses axios to send it to the bedrock REST API manually
import axios from "axios";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
const AWS_ACCESS_KEY_ID = process.env.AWS_ACCESS_KEY_ID!;
const AWS_SECRET_ACCESS_KEY = process.env.AWS_SECRET_ACCESS_KEY!;
// Copied from amazon bedrock docs
// List models
// ListFoundationModels
// Service: Amazon Bedrock
// List of Bedrock foundation models that you can use. For more information, see Foundation models in the
// Bedrock User Guide.
// Request Syntax
// GET /foundation-models?
// byCustomizationType=byCustomizationType&byInferenceType=byInferenceType&byOutputModality=byOutputModality&byProvider=byProvider
// HTTP/1.1
// URI Request Parameters
// The request uses the following URI parameters.
// byCustomizationType (p. 38)
// List by customization type.
// Valid Values: FINE_TUNING
// byInferenceType (p. 38)
// List by inference type.
// Valid Values: ON_DEMAND | PROVISIONED
// byOutputModality (p. 38)
// List by output modality type.
// Valid Values: TEXT | IMAGE | EMBEDDING
// byProvider (p. 38)
// A Bedrock model provider.
// Pattern: ^[a-z0-9-]{1,63}$
// Request Body
// The request does not have a request body
// Run inference on a text model
// Send an invoke request to run inference on a Titan Text G1 - Express model. We set the accept
// parameter to accept any content type in the response.
// POST https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-text-express-v1/invoke
// -H accept: */*
// -H content-type: application/json
// Payload
// {"inputText": "Hello world"}
// Example response
// Response for the above request.
// -H content-type: application/json
// Payload
// <the model response>
const AMZ_REGION = "us-east-1";
const AMZ_HOST = "invoke-bedrock.us-east-1.amazonaws.com";
async function listModels() {
const httpRequest = new HttpRequest({
method: "GET",
protocol: "https:",
hostname: AMZ_HOST,
path: "/foundation-models",
headers: { ["Host"]: AMZ_HOST },
});
const signedRequest = await signRequest(httpRequest);
const response = await axios.get(
`https://${signedRequest.hostname}${signedRequest.path}`,
{ headers: signedRequest.headers }
);
console.log(response.data);
}
async function invokeModel() {
const model = "anthropic.claude-v1";
const httpRequest = new HttpRequest({
method: "POST",
protocol: "https:",
hostname: AMZ_HOST,
path: `/model/${model}/invoke`,
headers: {
["Host"]: AMZ_HOST,
["accept"]: "*/*",
["content-type"]: "application/json",
},
body: JSON.stringify({
temperature: 0.5,
prompt: "\n\nHuman:Hello world\n\nAssistant:",
max_tokens_to_sample: 10,
}),
});
console.log("httpRequest", httpRequest);
const signedRequest = await signRequest(httpRequest);
const response = await axios.post(
`https://${signedRequest.hostname}${signedRequest.path}`,
signedRequest.body,
{ headers: signedRequest.headers }
);
console.log(response.status);
console.log(response.headers);
console.log(response.data);
console.log("full url", response.request.res.responseUrl);
}
async function signRequest(request: HttpRequest) {
const signer = new SignatureV4({
sha256: Sha256,
credentials: {
accessKeyId: AWS_ACCESS_KEY_ID,
secretAccessKey: AWS_SECRET_ACCESS_KEY,
},
region: AMZ_REGION,
service: "bedrock",
});
return await signer.sign(request, { signingDate: new Date() });
}
// listModels();
// invokeModel();
+14 -13
@@ -17,7 +17,7 @@ import {
} from "../../shared/users/schema";
import { getLastNImages } from "../../shared/file-storage/image-history";
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
import { invalidatePowChallenges } from "../../user/web/pow-captcha";
const router = Router();
@@ -268,7 +268,13 @@ router.post("/maintenance", (req, res) => {
let flash = { type: "", message: "" };
switch (action) {
case "recheck": {
const checkable: LLMService[] = ["openai", "anthropic", "aws", "azure"];
const checkable: LLMService[] = [
"openai",
"anthropic",
"aws",
"gcp",
"azure",
];
checkable.forEach((s) => keyPool.recheck(s));
const keyCount = keyPool
.list()
@@ -317,7 +323,7 @@ router.post("/maintenance", (req, res) => {
user.disabledReason = "Admin forced expiration.";
userStore.upsertUser(user);
});
invalidatePowHmacKey();
invalidatePowChallenges();
flash.type = "success";
flash.message = `${temps.length} temporary users marked for expiration.`;
break;
@@ -342,20 +348,15 @@ router.post("/maintenance", (req, res) => {
throw new HttpError(400, "Invalid difficulty" + selected);
}
config.powDifficultyLevel = selected;
invalidatePowChallenges();
break;
}
case "generateTempIpReport": {
const tempUsers = userStore
.getUsers()
.filter((u) => u.type === "temporary");
const ipv4RangeMap: Map<string, Set<string>> = new Map<
string,
Set<string>
>();
const ipv6RangeMap: Map<string, Set<string>> = new Map<
string,
Set<string>
>();
const ipv4RangeMap = new Map<string, Set<string>>();
const ipv6RangeMap = new Map<string, Set<string>>();
tempUsers.forEach((u) => {
u.ip.forEach((ip) => {
@@ -365,14 +366,14 @@ router.post("/maintenance", (req, res) => {
const subnet =
parsed.toNormalizedString().split(".").slice(0, 3).join(".") +
".0/24";
const userSet = ipv4RangeMap.get(subnet) || new Set<string>();
const userSet = ipv4RangeMap.get(subnet) || new Set();
userSet.add(u.token);
ipv4RangeMap.set(subnet, userSet);
} else if (parsed.kind() === "ipv6") {
const subnet =
parsed.toNormalizedString().split(":").slice(0, 4).join(":") +
"::/48";
const userSet = ipv6RangeMap.get(subnet) || new Set<string>();
const userSet = ipv6RangeMap.get(subnet) || new Set();
userSet.add(u.token);
ipv6RangeMap.set(subnet, userSet);
}
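For reference, a standalone sketch of the IPv4 bucket-key derivation used in the temp-IP report above. The `parsed.kind()` / `toNormalizedString()` calls suggest ipaddr.js, which this example assumes:
```
import ipaddr from "ipaddr.js";

// Derives the same /24 bucket key the IPv4 branch above produces.
function ipv4RangeKey(ip: string): string {
  const parsed = ipaddr.parse(ip);
  return parsed.toNormalizedString().split(".").slice(0, 3).join(".") + ".0/24";
}

console.log(ipv4RangeKey("192.0.2.17")); // "192.0.2.0/24"
```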
+1 -1
@@ -43,7 +43,7 @@
<legend>Bulk Quota Management</legend>
<p>
<button id="refresh-quotas" type="button" onclick="submitForm('resetQuotas')">Refresh All Quotas</button>
Resets all users' quotas to the values set in the <code>TOKEN_QUOTA_*</code> environment variables.
Immediately refreshes all users' quotas by the configured amounts.
</p>
<p>
<button id="clear-token-counts" type="button" onclick="submitForm('resetCounts')">Clear All Token Counts</button>
+18 -9
@@ -101,6 +101,10 @@
<% ["nickname", "type", "disabledAt", "disabledReason", "maxIps", "adminNote"].forEach(function (key) { %>
<input type="hidden" name="<%- key %>" value="<%- user[key] %>" />
<% }); %>
<!-- tokenRefresh_ keys are dynamically generated -->
<% Object.entries(quota).forEach(([family]) => { %>
<input type="hidden" name="tokenRefresh_<%- family %>" value="<%- user.tokenRefresh[family] || quota[family] %>" />
<% }); %>
</form>
<h3>Quota Information</h3>
@@ -111,7 +115,7 @@
<button type="submit" class="btn btn-primary">Refresh Quotas for User</button>
</form>
<% } %>
<%- include("partials/shared_quota-info", { quota, user }) %>
<%- include("partials/shared_quota-info", { quota, user, showRefreshEdit: true }) %>
<p><a href="/admin/manage/list-users">Back to User List</a></p>
@@ -122,18 +126,25 @@
const token = a.dataset.token;
const field = a.dataset.field;
const existingValue = document.querySelector(`#current-values input[name=${field}]`).value;
let value = prompt(`Enter new value for '${field}'':`, existingValue);
let value = prompt(`Enter new value for '${field}':`, existingValue);
if (value !== null) {
if (value === "") {
value = null;
}
const payload = { _csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content") };
if (field.startsWith("tokenRefresh_")) {
const family = field.slice("tokenRefresh_".length);
payload.tokenRefresh = { [family]: Number(value) };
} else {
payload[field] = value;
}
fetch(`/admin/manage/edit-user/${token}`, {
method: "POST",
credentials: "same-origin",
body: JSON.stringify({
[field]: value,
_csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content"),
}),
body: JSON.stringify(payload),
headers: { "Content-Type": "application/json", Accept: "application/json" },
})
.then((res) => Promise.all([res.ok, res.json()]))
@@ -141,9 +152,7 @@
const url = new URL(window.location.href);
const params = new URLSearchParams();
if (!ok) {
params.set("flash", `error: ${json.error.message}`);
} else {
params.set("flash", `success: User's ${field} updated.`);
alert(`Failed to edit user: ${json.message}`);
}
url.search = params.toString();
window.location.assign(url);
+25 -27
@@ -45,6 +45,13 @@ type Config = {
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
*/
awsCredentials?: string;
/**
* Comma-delimited list of GCP credentials. Each credential item should be a
* colon-delimited list of project ID, client email, GCP region, and private key.
*
* @example `GCP_CREDENTIALS=project1:1@1.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----,project2:2@2.com:us-east5:-----BEGIN PRIVATE KEY-----xxx-----END PRIVATE KEY-----`
*/
gcpCredentials?: string;
/**
* Comma-delimited list of Azure OpenAI credentials. Each credential item
* should be a colon-delimited list of Azure resource name, deployment ID, and
@@ -349,7 +356,7 @@ type Config = {
*
* Defaults to no services, meaning image prompts are disabled. Use a comma-
* separated list. Available services are:
* openai,anthropic,google-ai,mistral-ai,aws,azure
* openai,anthropic,google-ai,mistral-ai,aws,gcp,azure
*/
allowedVisionServices: LLMService[];
/**
@@ -383,6 +390,7 @@ export const config: Config = {
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
@@ -407,40 +415,23 @@ export const config: Config = {
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 32768),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
0
32768
),
maxOutputTokensOpenAI: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_OPENAI", "MAX_OUTPUT_TOKENS"],
400
1024
),
maxOutputTokensAnthropic: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
400
1024
),
allowedModelFamilies: getEnvWithDefault(
"ALLOWED_MODEL_FAMILIES",
getDefaultModelFamilies()
),
allowedModelFamilies: getEnvWithDefault("ALLOWED_MODEL_FAMILIES", [
"turbo",
"gpt4",
"gpt4-32k",
"gpt4-turbo",
"gpt4o",
"claude",
"claude-opus",
"gemini-pro",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"mistral-large",
"aws-claude",
"aws-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
"azure-gpt4-turbo",
"azure-gpt4o",
]),
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
rejectMessage: getEnvWithDefault(
"REJECT_MESSAGE",
@@ -509,6 +500,7 @@ function generateSigningKey() {
config.googleAIKey,
config.mistralAIKey,
config.awsCredentials,
config.gcpCredentials,
config.azureCredentials,
];
if (secrets.filter((s) => s).length === 0) {
@@ -527,7 +519,7 @@ function generateSigningKey() {
}
const signingKey = generateSigningKey();
export const COOKIE_SECRET = signingKey;
export const SECRET_SIGNING_KEY = signingKey;
export async function assertConfigIsValid() {
if (process.env.MODEL_RATE_LIMIT !== undefined) {
@@ -646,6 +638,7 @@ export const OMITTED_KEYS = [
"googleAIKey",
"mistralAIKey",
"awsCredentials",
"gcpCredentials",
"azureCredentials",
"proxyKey",
"adminKey",
@@ -736,6 +729,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
"ANTHROPIC_KEY",
"GOOGLE_AI_KEY",
"AWS_CREDENTIALS",
"GCP_CREDENTIALS",
"AZURE_CREDENTIALS",
].includes(String(env))
) {
@@ -786,3 +780,7 @@ function parseCsv(val: string): string[] {
const matches = val.match(regex) || [];
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
}
function getDefaultModelFamilies(): ModelFamily[] {
return MODEL_FAMILIES.filter((f) => !f.includes("dall-e")) as ModelFamily[];
}
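As a usage note, any subset of the recognized family names can be supplied in the override; for example, an operator exposing only the new GCP Claude families might set:
```
ALLOWED_MODEL_FAMILIES=gcp-claude,gcp-claude-opus
```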
+11 -3
@@ -12,7 +12,7 @@ import { checkCsrfToken, injectCsrfToken } from "./shared/inject-csrf";
const INFO_PAGE_TTL = 2000;
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
turbo: "GPT-3.5 Turbo",
turbo: "GPT-4o Mini / 3.5 Turbo",
gpt4: "GPT-4",
"gpt4-32k": "GPT-4 32k",
"gpt4-turbo": "GPT-4 Turbo",
@@ -20,13 +20,21 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"dall-e": "DALL-E",
claude: "Claude (Sonnet)",
"claude-opus": "Claude (Opus)",
"gemini-flash": "Gemini Flash",
"gemini-pro": "Gemini Pro",
"gemini-ultra": "Gemini Ultra",
"mistral-tiny": "Mistral 7B",
"mistral-small": "Mixtral Small", // Originally 8x7B, but that now refers to the older open-weight version. Mixtral Small is a newer closed-weight update to the 8x7B model.
"mistral-small": "Mistral Nemo",
"mistral-medium": "Mistral Medium",
"mistral-large": "Mistral Large",
"aws-claude": "AWS Claude (Sonnet)",
"aws-claude-opus": "AWS Claude (Opus)",
"aws-mistral-tiny": "AWS Mistral 7B",
"aws-mistral-small": "AWS Mistral Nemo",
"aws-mistral-medium": "AWS Mistral Medium",
"aws-mistral-large": "AWS Mistral Large",
"gcp-claude": "GCP Claude (Sonnet)",
"gcp-claude-opus": "GCP Claude (Opus)",
"azure-turbo": "Azure GPT-3.5 Turbo",
"azure-gpt4": "Azure GPT-4",
"azure-gpt4-32k": "Azure GPT-4 32k",
@@ -37,7 +45,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
const converter = new showdown.Converter();
const customGreeting = fs.existsSync("greeting.md")
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
? `<div id="servergreeting">${fs.readFileSync("greeting.md", "utf8")}</div>`
: "";
let infoPageHtml: string | undefined;
let infoPageLastUpdated = 0;
+9
@@ -0,0 +1,9 @@
import { NextFunction, Request, Response } from "express";
export function addV1(req: Request, res: Response, next: NextFunction) {
// Clients don't consistently use the /v1 prefix so we'll add it for them.
if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
req.url = `/v1${req.url}`;
}
next();
}
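A sketch of how such a middleware would typically be mounted ahead of the proxy handlers; the mount path and module path here are illustrative, not the repository's actual wiring:
```
import express from "express";
import { addV1 } from "./add-v1"; // module path assumed for this example

const app = express();
// Mounted before the proxy handlers, a request to "/proxy/openai/models"
// is rewritten to "/proxy/openai/v1/models" before being forwarded.
app.use("/proxy/openai", addV1);
```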
+32 -68
@@ -46,6 +46,7 @@ const getModelsResponse = () => {
"claude-3-haiku-20240307",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-5-sonnet-20240620",
];
const models = claudeVariants.map((id) => ({
@@ -69,7 +70,7 @@ const handleModelRequest: RequestHandler = (_req, res) => {
};
/** Only used for non-streaming requests. */
const anthropicResponseHandler: ProxyResHandlerWithBody = async (
const anthropicBlockingResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
@@ -128,7 +129,7 @@ export function transformAnthropicChatResponseToAnthropicText(
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAnthropicTextResponseToOpenAI(
export function transformAnthropicTextResponseToOpenAI(
anthropicBody: Record<string, any>,
req: Request
): Record<string, any> {
@@ -178,6 +179,28 @@ export function transformAnthropicChatResponseToOpenAI(
};
}
/**
* If a client using the OpenAI compatibility endpoint requests an actual OpenAI
* model, reassigns it to Claude 3 Sonnet.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
if (!model.startsWith("gpt-")) return;
req.body.model = "claude-3-sonnet-20240229";
}
/**
* If client requests more than 4096 output tokens the request must have a
* particular version header.
* https://docs.anthropic.com/en/release-notes/api#july-15th-2024
*/
function setAnthropicBetaHeader(req: Request) {
const { max_tokens_to_sample } = req.body;
if (max_tokens_to_sample > 4096) {
req.headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15";
}
}
const anthropicProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.anthropic.com",
@@ -188,7 +211,7 @@ const anthropicProxy = createQueueMiddleware({
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, addAnthropicPreamble, finalizeBody],
}),
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
proxyRes: createOnProxyResHandler([anthropicBlockingResponseHandler]),
error: handleProxyError,
},
// Abusing pathFilter to rewrite the paths dynamically.
@@ -212,6 +235,11 @@ const anthropicProxy = createQueueMiddleware({
}),
});
const nativeAnthropicChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "anthropic" },
{ afterTransform: [setAnthropicBetaHeader] }
);
const nativeTextPreprocessor = createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-text",
@@ -267,11 +295,7 @@ anthropicRouter.get("/v1/models", handleModelRequest);
anthropicRouter.post(
"/v1/messages",
ipLimiter,
-  createPreprocessorMiddleware({
-    inApi: "anthropic-chat",
-    outApi: "anthropic-chat",
-    service: "anthropic",
-  }),
+  nativeAnthropicChatPreprocessor,
anthropicProxy
);
// Anthropic text completion endpoint. Translates to Anthropic chat completion
@@ -291,65 +315,5 @@ anthropicRouter.post(
preprocessOpenAICompatRequest,
anthropicProxy
);
// Temporarily force Anthropic Text to Anthropic Chat for frontends which do not
// yet support the new model. Forces claude-3. Will be removed once common
// frontends have been updated.
anthropicRouter.post(
"/v1/:type(sonnet|opus)/:action(complete|messages)",
ipLimiter,
handleAnthropicTextCompatRequest,
createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-chat",
service: "anthropic",
}),
anthropicProxy
);
function handleAnthropicTextCompatRequest(
req: Request,
res: Response,
next: any
) {
const type = req.params.type;
const action = req.params.action;
const alreadyInChatFormat = Boolean(req.body.messages);
const compatModel = `claude-3-${type}-20240229`;
req.log.info(
{ type, inputModel: req.body.model, compatModel, alreadyInChatFormat },
"Handling Anthropic compatibility request"
);
if (action === "messages" || alreadyInChatFormat) {
return sendErrorToClient({
req,
res,
options: {
title: "Unnecessary usage of compatibility endpoint",
message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/anthropic\` proxy endpoint instead.`,
format: "unknown",
statusCode: 400,
reqId: req.id,
obj: {
requested_endpoint: "/anthropic/" + type,
correct_endpoint: "/anthropic",
},
},
});
}
req.body.model = compatModel;
next();
}
/**
* If a client using the OpenAI compatibility endpoint requests an actual OpenAI
* model, reassigns it to Claude 3 Sonnet.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
if (!model.startsWith("gpt-")) return;
req.body.model = "claude-3-sonnet-20240229";
}
export const anthropic = anthropicRouter;
@@ -0,0 +1,253 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import {
transformAnthropicChatResponseToAnthropicText,
transformAnthropicChatResponseToOpenAI,
} from "./anthropic";
/** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
switch (`${req.inboundApi}<-${req.outboundApi}`) {
case "openai<-anthropic-text":
req.log.info("Transforming Anthropic Text back to OpenAI format");
newBody = transformAwsTextResponseToOpenAI(body, req);
break;
case "openai<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
newBody = transformAnthropicChatResponseToOpenAI(body);
break;
case "anthropic-text<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to Text format");
newBody = transformAnthropicChatResponseToAnthropicText(body);
break;
}
// AWS does not always confirm the model in the response, so we have to add it
if (!newBody.model && req.body.model) {
newBody.model = req.body.model;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAwsTextResponseToOpenAI(
awsBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "aws-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: awsBody.completion?.trim(),
},
finish_reason: awsBody.stop_reason,
index: 0,
},
],
};
}
const awsClaudeProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
}),
});
const nativeTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const textToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes text completion prompts to aws anthropic-chat if they need translation
* (claude-3 based models do not support the old text completion endpoint).
*/
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
textToChatPreprocessor(req, res, next);
} else {
nativeTextPreprocessor(req, res, next);
}
};
const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
* or the new Claude chat completion endpoint, based on the requested model.
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
oaiToAwsChatPreprocessor(req, res, next);
} else {
oaiToAwsTextPreprocessor(req, res, next);
}
};
const awsClaudeRouter = Router();
// Native(ish) Anthropic text completion endpoint.
awsClaudeRouter.post(
"/v1/complete",
ipLimiter,
preprocessAwsTextRequest,
awsClaudeProxy
);
// Native Anthropic chat completion endpoint.
awsClaudeRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
),
awsClaudeProxy
);
// OpenAI-to-AWS Anthropic compatibility endpoint.
awsClaudeRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
awsClaudeProxy
);
/**
* Tries to deal with:
* - frontends sending AWS model names even when they want to use the OpenAI-
* compatible endpoint
* - frontends sending Anthropic model names that AWS doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*
 * If the client sends an AWS model ID, it is used verbatim. Otherwise, various
 * strategies are used to map a non-AWS model name to an AWS model ID.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If it looks like an AWS model, use it as-is
if (model.includes("anthropic.claude")) {
return;
}
// Anthropic model names can look like:
// - claude-v1
// - claude-2.1
// - claude-3-5-sonnet-20240620-v1:0
const pattern =
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
const match = model.match(pattern);
  // If there's no match, fall back to Claude v2 as it is the most likely to be
  // available on AWS.
if (!match) {
req.body.model = `anthropic.claude-v2:1`;
return;
}
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
if (instant) {
req.body.model = "anthropic.claude-instant-v1";
return;
}
const ver = minor ? `${major}.${minor}` : major;
switch (ver) {
case "1":
case "1.0":
req.body.model = "anthropic.claude-v1";
return;
case "2":
case "2.0":
req.body.model = "anthropic.claude-v2";
return;
case "3":
case "3.0":
if (name.includes("opus")) {
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
} else if (name.includes("haiku")) {
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
} else {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
}
return;
case "3.5":
req.body.model = "anthropic.claude-3-5-sonnet-20240620-v1:0";
return;
}
// Fallback to Claude 2.1
req.body.model = `anthropic.claude-v2:1`;
return;
}
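// Illustrative mappings implied by the heuristic above (sketched, not an
// exhaustive list):
//   "gpt-4"                      -> "anthropic.claude-v2:1" (no match; fallback)
//   "claude-instant-v1"          -> "anthropic.claude-instant-v1"
//   "claude-2.0"                 -> "anthropic.claude-v2"
//   "claude-2.1"                 -> "anthropic.claude-v2:1"
//   "claude-3-opus-20240229"     -> "anthropic.claude-3-opus-20240229-v1:0"
//   "claude-3-5-sonnet-20240620" -> "anthropic.claude-3-5-sonnet-20240620-v1:0"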
export const awsClaude = awsClaudeRouter;
@@ -0,0 +1,110 @@
import { Request } from "express";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { createQueueMiddleware } from "./queue";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
signAwsRequest,
} from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai";
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
if (req.inboundApi === "mistral-ai" && req.outboundApi === "mistral-text") {
newBody = transformMistralTextToMistralChat(body);
}
// AWS does not always confirm the model in the response, so we have to add it
if (!newBody.model && req.body.model) {
newBody.model = req.body.model;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
const awsMistralProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
error: handleProxyError,
},
}),
});
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If it looks like an AWS model, use it as-is
if (model.startsWith("mistral.")) {
return;
}
  // Mistral 8x7B Instruct (must be checked before 7B, since "8x7b" also
  // matches the "7b" substring)
  else if (model.includes("8x7b")) {
    req.body.model = "mistral.mixtral-8x7b-instruct-v0:1";
  }
  // Mistral 7B Instruct
  else if (model.includes("7b")) {
    req.body.model = "mistral.mistral-7b-instruct-v0:2";
  }
// Mistral Large (Feb 2024)
else if (model.includes("large-2402")) {
req.body.model = "mistral.mistral-large-2402-v1:0";
}
// Mistral Large 2 (July 2024)
else if (model.includes("large")) {
req.body.model = "mistral.mistral-large-2407-v1:0";
}
// Mistral Small (Feb 2024)
else if (model.includes("small")) {
req.body.model = "mistral.mistral-small-2402-v1:0";
} else {
throw new Error(
`Can't map '${model}' to a supported AWS model ID; make sure you are requesting a Mistral model supported by Amazon Bedrock`
);
}
}
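// Illustrative mappings given the checks above (any other name throws):
//   "mistral-7b-instruct"   -> "mistral.mistral-7b-instruct-v0:2"
//   "mixtral-8x7b-instruct" -> "mistral.mixtral-8x7b-instruct-v0:1"
//   "mistral-large-2402"    -> "mistral.mistral-large-2402-v1:0"
//   "mistral-large-latest"  -> "mistral.mistral-large-2407-v1:0"
//   "mistral-small-latest"  -> "mistral.mistral-small-2402-v1:0"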
const nativeMistralChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "mistral-ai", outApi: "mistral-ai", service: "aws" },
{
beforeTransform: [detectMistralInputApi],
afterTransform: [maybeReassignModel],
}
);
const awsMistralRouter = Router();
awsMistralRouter.post(
"/v1/chat/completions",
ipLimiter,
nativeMistralChatPreprocessor,
awsMistralProxy
);
export const awsMistral = awsMistralRouter;
@@ -1,335 +1,75 @@
import { Request, RequestHandler, Response, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
/* Shared code between AWS Claude and AWS Mistral endpoints. */
import { Request, Response, Router } from "express";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { sendErrorToClient } from "./middleware/response/error-generator";
import { addV1 } from "./add-v1";
import { awsClaude } from "./aws-claude";
import { awsMistral } from "./aws-mistral";
import { AwsBedrockKey, keyPool } from "../shared/key-management";
const LATEST_AWS_V2_MINOR_VERSION = "1";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
const awsRouter = Router();
awsRouter.get(["/:vendor?/v1/models", "/:vendor?/models"], handleModelsRequest);
awsRouter.use("/claude", addV1, awsClaude);
awsRouter.use("/mistral", addV1, awsMistral);
const MODELS_CACHE_TTL = 10000;
let modelsCache: Record<string, any> = {};
let modelsCacheTime: Record<string, number> = {};
function handleModelsRequest(req: Request, res: Response) {
if (!config.awsCredentials) return { object: "list", data: [] };
const vendor = req.params.vendor?.length
? req.params.vendor === "claude"
? "anthropic"
: req.params.vendor
: "all";
const cacheTime = modelsCacheTime[vendor] || 0;
if (new Date().getTime() - cacheTime < MODELS_CACHE_TTL) {
return res.json(modelsCache[vendor]);
}
const availableModelIds = new Set<string>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "aws") continue;
(key as AwsBedrockKey).modelIds.forEach((id) => availableModelIds.add(id));
}
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html
const variants = [
const models = [
"anthropic.claude-v2",
"anthropic.claude-v2:1",
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-sonnet-20240229-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
];
const models = variants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "anthropic",
permission: [],
root: "claude",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const awsResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
switch (`${req.inboundApi}<-${req.outboundApi}`) {
case "openai<-anthropic-text":
req.log.info("Transforming Anthropic Text back to OpenAI format");
newBody = transformAwsTextResponseToOpenAI(body, req);
break;
case "openai<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to OpenAI format");
newBody = transformAnthropicChatResponseToOpenAI(body);
break;
case "anthropic-text<-anthropic-chat":
req.log.info("Transforming AWS Anthropic Chat back to Text format");
newBody = transformAnthropicChatResponseToAnthropicText(body);
break;
}
// AWS does not always confirm the model in the response, so we have to add it
if (!newBody.model && req.body.model) {
newBody.model = req.body.model;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAwsTextResponseToOpenAI(
awsBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "aws-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: awsBody.completion?.trim(),
},
finish_reason: awsBody.stop_reason,
index: 0,
},
],
};
}
const awsProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
}),
});
const nativeTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const textToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "anthropic-text", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes text completion prompts to aws anthropic-chat if they need translation
* (claude-3 based models do not support the old text completion endpoint).
*/
const preprocessAwsTextRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
textToChatPreprocessor(req, res, next);
} else {
nativeTextPreprocessor(req, res, next);
}
};
const oaiToAwsTextPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-text", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
const oaiToAwsChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
);
/**
* Routes an OpenAI prompt to either the legacy Claude text completion endpoint
* or the new Claude chat completion endpoint, based on the requested model.
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
if (req.body.model?.includes("claude-3")) {
oaiToAwsChatPreprocessor(req, res, next);
} else {
oaiToAwsTextPreprocessor(req, res, next);
}
};
const awsRouter = Router();
awsRouter.get("/v1/models", handleModelRequest);
// Native(ish) Anthropic text completion endpoint.
awsRouter.post("/v1/complete", ipLimiter, preprocessAwsTextRequest, awsProxy);
// Native Anthropic chat completion endpoint.
awsRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "aws" },
{ afterTransform: [maybeReassignModel] }
),
awsProxy
);
// Temporary force-Claude3 endpoint
awsRouter.post(
"/v1/sonnet/:action(complete|messages)",
ipLimiter,
handleCompatibilityRequest,
createPreprocessorMiddleware({
inApi: "anthropic-text",
outApi: "anthropic-chat",
service: "aws",
}),
awsProxy
);
// OpenAI-to-AWS Anthropic compatibility endpoint.
awsRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
awsProxy
);
/**
* Tries to deal with:
* - frontends sending AWS model names even when they want to use the OpenAI-
* compatible endpoint
* - frontends sending Anthropic model names that AWS doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
// If client already specified an AWS Claude model ID, use it
if (model.includes("anthropic.claude")) {
return;
}
const pattern =
/^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?(-sonnet-?|-opus-?|-haiku-?)(\d*)/i;
const match = model.match(pattern);
// If there's no match, return the latest v2 model
if (!match) {
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
const instant = match[2];
const major = match[4];
const minor = match[6];
if (instant) {
req.body.model = "anthropic.claude-instant-v1";
return;
}
// There's only one v1 model
if (major === "1") {
req.body.model = "anthropic.claude-v1";
return;
}
// Try to map Anthropic API v2 models to AWS v2 models
if (major === "2") {
if (minor === "0") {
req.body.model = "anthropic.claude-v2";
return;
}
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
// AWS currently only supports one v3 model.
const variant = match[8]; // sonnet, opus, or haiku
const variantVersion = match[9];
if (major === "3") {
if (variant.includes("opus")) {
req.body.model = "anthropic.claude-3-opus-20240229-v1:0";
} else if (variant.includes("haiku")) {
req.body.model = "anthropic.claude-3-haiku-20240307-v1:0";
} else {
req.body.model = "anthropic.claude-3-sonnet-20240229-v1:0";
}
return;
}
// Fallback to latest v2 model
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}
export function handleCompatibilityRequest(
req: Request,
res: Response,
next: any
) {
const action = req.params.action;
const alreadyInChatFormat = Boolean(req.body.messages);
const compatModel = "anthropic.claude-3-sonnet-20240229-v1:0";
req.log.info(
{ inputModel: req.body.model, compatModel, alreadyInChatFormat },
"Handling AWS compatibility request"
);
if (action === "messages" || alreadyInChatFormat) {
return sendErrorToClient({
req,
res,
options: {
title: "Unnecessary usage of compatibility endpoint",
message: `Your client seems to already support the new Claude API format. This endpoint is intended for clients that do not yet support the new format.\nUse the normal \`/aws/claude\` proxy endpoint instead.`,
format: "unknown",
statusCode: 400,
reqId: req.id,
obj: {
requested_endpoint: "/aws/claude/sonnet",
correct_endpoint: "/aws/claude",
},
},
"mistral.mistral-7b-instruct-v0:2",
"mistral.mixtral-8x7b-instruct-v0:1",
"mistral.mistral-large-2402-v1:0",
"mistral.mistral-large-2407-v1:0",
"mistral.mistral-small-2402-v1:0",
]
.filter((id) => availableModelIds.has(id))
.map((id) => {
const vendor = id.match(/^(.*)\./)?.[1];
return {
id,
object: "model",
created: new Date().getTime(),
owned_by: vendor,
permission: [],
root: vendor,
parent: null,
};
});
}
req.body.model = compatModel;
next();
modelsCache[vendor] = {
object: "list",
data: models.filter((m) => vendor === "all" || m.root === vendor),
};
modelsCacheTime[vendor] = new Date().getTime();
return res.json(modelsCache[vendor]);
}
export const aws = awsRouter;
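// Illustrative requests served by the router above (assuming it is mounted
// at /aws):
//   GET /aws/v1/models          -> all Bedrock models available on active keys
//   GET /aws/claude/v1/models   -> anthropic.* models only
//   GET /aws/mistral/v1/models  -> mistral.* models only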
@@ -12,6 +12,7 @@ function getProxyAuthorizationFromRequest(req: Request): string | undefined {
// pass the _proxy_ key in this header too, instead of providing it as a
// Bearer token in the Authorization header. So we need to check both.
// Prefer the Authorization header if both are present.
// Google AI uses a key querystring parameter.
if (req.headers.authorization) {
const token = req.headers.authorization?.slice("Bearer ".length);
@@ -24,6 +25,12 @@ function getProxyAuthorizationFromRequest(req: Request): string | undefined {
delete req.headers["x-api-key"];
return token;
}
if (req.query.key) {
const token = req.query.key?.toString();
delete req.query.key;
return token;
}
return undefined;
}
@@ -0,0 +1,193 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signGcpRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.gcpCredentials) return { object: "list", data: [] };
// https://docs.anthropic.com/en/docs/about-claude/models
const variants = [
"claude-3-haiku@20240307",
"claude-3-sonnet@20240229",
"claude-3-opus@20240229",
"claude-3-5-sonnet@20240620",
];
const models = variants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "anthropic",
permission: [],
root: "claude",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const gcpResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
let newBody = body;
switch (`${req.inboundApi}<-${req.outboundApi}`) {
case "openai<-anthropic-chat":
req.log.info("Transforming Anthropic Chat back to OpenAI format");
newBody = transformAnthropicChatResponseToOpenAI(body);
break;
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
const gcpProxy = createQueueMiddleware({
beforeProxy: signGcpRequest,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([gcpResponseHandler]),
error: handleProxyError,
},
}),
});
const oaiToChatPreprocessor = createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic-chat", service: "gcp" },
{ afterTransform: [maybeReassignModel] }
);
/**
 * Routes an OpenAI prompt to the Claude chat completion endpoint. Unlike AWS,
 * only the chat endpoint is used here, so no model-based routing is needed.
*/
const preprocessOpenAICompatRequest: RequestHandler = (req, res, next) => {
oaiToChatPreprocessor(req, res, next);
};
const gcpRouter = Router();
gcpRouter.get("/v1/models", handleModelRequest);
// Native Anthropic chat completion endpoint.
gcpRouter.post(
"/v1/messages",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic-chat", outApi: "anthropic-chat", service: "gcp" },
{ afterTransform: [maybeReassignModel] }
),
gcpProxy
);
// OpenAI-to-GCP Anthropic compatibility endpoint.
gcpRouter.post(
"/v1/chat/completions",
ipLimiter,
preprocessOpenAICompatRequest,
gcpProxy
);
/**
* Tries to deal with:
* - frontends sending GCP model names even when they want to use the OpenAI-
* compatible endpoint
* - frontends sending Anthropic model names that GCP doesn't recognize
* - frontends sending OpenAI model names because they expect the proxy to
* translate them
*
 * If the client sends a GCP model ID, it is used verbatim. Otherwise, various
 * strategies are used to map a non-GCP model name to a GCP model ID.
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;
  // If it looks like a GCP model, use it as-is
if (model.startsWith("claude-") && model.includes("@")) {
return;
}
// Anthropic model names can look like:
// - claude-v1
// - claude-2.1
// - claude-3-5-sonnet-20240620-v1:0
const pattern =
/^(claude-)?(instant-)?(v)?(\d+)([.-](\d{1}))?(-\d+k)?(-sonnet-|-opus-|-haiku-)?(\d*)/i;
const match = model.match(pattern);
  // If there's no match, fall back to Claude 3 Sonnet as it is the most likely
  // to be available on GCP.
if (!match) {
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
return;
}
const [_, _cl, instant, _v, major, _sep, minor, _ctx, name, _rev] = match;
const ver = minor ? `${major}.${minor}` : major;
switch (ver) {
case "3":
case "3.0":
if (name.includes("opus")) {
req.body.model = "claude-3-opus@20240229";
} else if (name.includes("haiku")) {
req.body.model = "claude-3-haiku@20240307";
} else {
req.body.model = "claude-3-sonnet@20240229";
}
return;
case "3.5":
req.body.model = "claude-3-5-sonnet@20240620";
return;
}
  // Fall back to Claude 3 Sonnet
req.body.model = `claude-3-sonnet@${LATEST_GCP_SONNET_MINOR_VERSION}`;
return;
}
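// Illustrative mappings (GCP Vertex model IDs use the name@version form):
//   "claude-3-sonnet@20240229"   -> unchanged (already a GCP model ID)
//   "claude-3-opus-20240229"     -> "claude-3-opus@20240229"
//   "claude-3-5-sonnet-20240620" -> "claude-3-5-sonnet@20240620"
//   "gpt-4"                      -> "claude-3-sonnet@20240229" (fallback)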
export const gcp = gcpRouter;
@@ -16,6 +16,7 @@ import {
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
import { GoogleAIKey, keyPool } from "../shared/key-management";
let modelsCache: any = null;
let modelsCacheTime = 0;
@@ -30,9 +31,19 @@ const getModelsResponse = () => {
if (!config.googleAIKey) return { object: "list", data: [] };
-  const googleAIVariants = ["gemini-pro", "gemini-1.0-pro", "gemini-1.5-pro"];
const keys = keyPool
.list()
.filter((k) => k.service === "google-ai") as GoogleAIKey[];
if (keys.length === 0) {
modelsCache = { object: "list", data: [] };
modelsCacheTime = new Date().getTime();
return modelsCache;
}
-  const models = googleAIVariants.map((id) => ({
+  const modelIds = Array.from(
+    new Set(keys.map((k) => k.modelIds).flat())
+  ).filter((id) => id.startsWith("models/gemini"));
+  const models = modelIds.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
@@ -109,7 +120,17 @@ const googleAIProxy = createQueueMiddleware({
},
changeOrigin: true,
selfHandleResponse: true,
-    logger,
+    // Prevent logging of the API key by HPM
+    logger: logger.child(
+      {},
+      {
+        redact: {
+          paths: ["*"],
+          censor: (v) =>
+            typeof v === "string" ? v.replace(/key=\S+/g, "key=xxxxxxx") : v,
+        },
+      }
+    ),
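    // Illustrative: a proxied URL logged as
    // "/v1beta/models/gemini-pro:generateContent?key=AIzaSyExample" is
    // censored to "/v1beta/models/gemini-pro:generateContent?key=xxxxxxx".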
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
@@ -120,16 +141,67 @@ const googleAIProxy = createQueueMiddleware({
const googleAIRouter = Router();
googleAIRouter.get("/v1/models", handleModelRequest);
// Native Google AI chat completion endpoint
googleAIRouter.post(
"/v1beta/models/:modelId:(generateContent|streamGenerateContent)",
ipLimiter,
createPreprocessorMiddleware(
{
inApi: "google-ai",
outApi: "google-ai",
service: "google-ai",
},
{ beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
),
googleAIProxy
);
// OpenAI-to-Google AI compatibility endpoint.
googleAIRouter.post(
"/v1/chat/completions",
ipLimiter,
-  createPreprocessorMiddleware({
-    inApi: "openai",
-    outApi: "google-ai",
-    service: "google-ai",
-  }),
+  createPreprocessorMiddleware(
+    { inApi: "openai", outApi: "google-ai", service: "google-ai" },
+    { afterTransform: [maybeReassignModel] }
+  ),
googleAIProxy
);
function setStreamFlag(req: Request) {
const isStreaming = req.url.includes("streamGenerateContent");
if (isStreaming) {
req.body.stream = true;
req.isStreaming = true;
} else {
req.body.stream = false;
req.isStreaming = false;
}
}
/**
* Replaces requests for non-Google AI models with gemini-pro-1.5-latest.
* Also strips models/ from the beginning of the model IDs.
**/
function maybeReassignModel(req: Request) {
// Ensure model is on body as a lot of middleware will expect it.
const model = req.body.model || req.url.split("/").pop()?.split(":").shift();
if (!model) {
throw new Error("You must specify a model with your request.");
}
req.body.model = model;
const requested = model;
if (requested.startsWith("models/")) {
req.body.model = requested.slice("models/".length);
}
if (requested.includes("gemini")) {
return;
}
req.log.info({ requested }, "Reassigning model to gemini-pro-1.5-latest");
req.body.model = "gemini-pro-1.5-latest";
}
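// Illustrative behavior of the reassignment above:
//   "models/gemini-1.5-pro-latest" -> "gemini-1.5-pro-latest" (prefix stripped)
//   "gemini-pro"                   -> unchanged
//   "gpt-4"                        -> "gemini-pro-1.5-latest" (reassigned)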
export const googleAI = googleAIRouter;
@@ -16,6 +16,7 @@ const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
const ANTHROPIC_MESSAGES_ENDPOINT = "/v1/messages";
const ANTHROPIC_SONNET_COMPAT_ENDPOINT = "/v1/sonnet";
const ANTHROPIC_OPUS_COMPAT_ENDPOINT = "/v1/opus";
const GOOGLE_AI_COMPLETION_ENDPOINT = "/v1beta/models";
export function isTextGenerationRequest(req: Request) {
return (
@@ -27,6 +28,7 @@ export function isTextGenerationRequest(req: Request) {
ANTHROPIC_MESSAGES_ENDPOINT,
ANTHROPIC_SONNET_COMPAT_ENDPOINT,
ANTHROPIC_OPUS_COMPAT_ENDPOINT,
GOOGLE_AI_COMPLETION_ENDPOINT,
].some((endpoint) => req.path.startsWith(endpoint))
);
}
@@ -221,9 +223,12 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
switch (format) {
case "openai":
case "mistral-ai":
-      // Can be null if the model wants to invoke tools rather than return a
-      // completion.
-      return body.choices[0].message.content || "";
+      // A few possible values:
+      // - choices[0].message.content
+      // - choices[0].message with no content if the model is invoking a tool
+      return body.choices?.[0]?.message?.content || "";
+    case "mistral-text":
+      return body.outputs?.[0]?.text || "";
case "openai-text":
return body.choices[0].text;
case "anthropic-chat":
@@ -260,22 +265,22 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
}
}
-export function getModelFromBody(req: Request, body: Record<string, any>) {
+export function getModelFromBody(req: Request, resBody: Record<string, any>) {
  const format = req.outboundApi;
  switch (format) {
    case "openai":
    case "openai-text":
+      return resBody.model;
    case "mistral-ai":
-      return body.model;
+    case "mistral-text":
    case "openai-image":
+    case "google-ai":
      // These formats don't have a model in the response body.
      return req.body.model;
    case "anthropic-chat":
    case "anthropic-text":
      // Anthropic confirms the model in the response, but AWS Claude doesn't.
-      return body.model || req.body.model;
-    case "google-ai":
-      // Google doesn't confirm the model in the response.
-      return req.body.model;
+      return resBody.model || req.body.model;
default:
assertNever(format);
}
@@ -15,6 +15,7 @@ export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
export { languageFilter } from "./preprocessors/language-filter";
export { setApiFormat } from "./preprocessors/set-api-format";
export { signAwsRequest } from "./preprocessors/sign-aws-request";
export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request";
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
export { validateContextSize } from "./preprocessors/validate-context-size";
export { validateVision } from "./preprocessors/validate-vision";
@@ -38,7 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
// translation now reassigns the model earlier in the request pipeline.
case "anthropic-text":
case "anthropic-chat":
assignedKey = keyPool.get("claude-v1", service, needsMultimodal);
case "mistral-ai":
case "mistral-text":
case "google-ai":
assignedKey = keyPool.get(body.model, service);
break;
case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct", service);
@@ -47,10 +50,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
assignedKey = keyPool.get("dall-e-3", service);
break;
case "openai":
case "google-ai":
case "mistral-ai":
throw new Error(
`add-key should not be called for outbound API ${outboundApi}`
`Outbound API ${outboundApi} is not supported for ${inboundApi}`
);
default:
assertNever(outboundApi);
@@ -83,6 +84,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
proxyReq.setHeader("api-key", azureKey);
break;
case "aws":
case "gcp":
case "google-ai":
throw new Error("add-key should not be used for this service.");
default:
@@ -1,14 +1,16 @@
import { HPMRequestCallback } from "../index";
import { config } from "../../../../config";
import { ForbiddenError } from "../../../../shared/errors";
import { getModelFamilyForRequest } from "../../../../shared/models";
import { HPMRequestCallback } from "../index";
/**
* Ensures the selected model family is enabled by the proxy configuration.
**/
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
*/
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req) => {
const family = getModelFamilyForRequest(req);
if (!config.allowedModelFamilies.includes(family)) {
throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
throw new ForbiddenError(
`Model family '${family}' is not enabled on this proxy`
);
}
};
@@ -1,7 +1,7 @@
import type { HPMRequestCallback } from "../index";
/**
- * For AWS/Azure/Google requests, the body is signed earlier in the request
+ * For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
* pipeline, before the proxy middleware. This function just assigns the path
* and headers to the proxy request.
*/
@@ -84,9 +84,9 @@ async function executePreprocessors(
} catch (error) {
if (error.constructor.name === "ZodError") {
const msg = error?.issues
-        ?.map((issue: ZodIssue) => issue.message)
+        ?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
        .join("; ");
-      req.log.info(msg, "Prompt validation failed.");
+      req.log.warn({ issues: msg }, "Prompt validation failed.");
} else {
req.log.error(error, "Error while executing request preprocessor");
}
@@ -143,7 +143,7 @@ const handleTestMessage: RequestHandler = (req, res) => {
};
function isTestMessage(body: any) {
-  const { messages, prompt } = body;
+  const { messages, prompt, contents } = body;
if (messages) {
return (
@@ -151,6 +151,11 @@ function isTestMessage(body: any) {
messages[0].role === "user" &&
messages[0].content === "Hi"
);
} else if (contents) {
return (
contents.length === 1 &&
contents[0].parts[0]?.text === "Hi"
);
} else {
return (
prompt?.trim() === "Human: Hi\n\nAssistant:" ||
@@ -2,39 +2,38 @@ import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const addGoogleAIKey: RequestPreprocessor = (req) => {
-  const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
+  const inboundValid =
+    req.inboundApi === "openai" || req.inboundApi === "google-ai";
+  const outboundValid = req.outboundApi === "google-ai";
  const serviceValid = req.service === "google-ai";
-  if (!apisValid || !serviceValid) {
+  if (!inboundValid || !outboundValid || !serviceValid) {
throw new Error("addGoogleAIKey called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
const model = req.body.model;
req.isStreaming = req.isStreaming || req.body.stream;
req.key = keyPool.get(model, "google-ai");
req.log.info(
-    { key: req.key.hash, model },
+    { key: req.key.hash, model, stream: req.isStreaming },
"Assigned Google AI API key to request"
);
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
req.isStreaming = req.isStreaming || req.body.stream;
delete req.body.stream;
const payload = { ...req.body, stream: undefined, model: undefined };
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: "generativelanguage.googleapis.com",
-    path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
+    path: `/v1beta/models/${model}:${
+      req.isStreaming ? "streamGenerateContent" : "generateContent"
+    }?key=${req.key.key}`,
headers: {
["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json",
},
-    body: JSON.stringify(req.body),
+    body: JSON.stringify(payload),
};
};
@@ -2,7 +2,6 @@ import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
import {
AnthropicChatMessage,
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
@@ -31,10 +30,13 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
}
case "anthropic-chat": {
req.outputTokens = req.body.max_tokens;
-      const prompt = {
-        system: req.body.system ?? "",
-        messages: req.body.messages,
-      };
+      let system = req.body.system ?? "";
+      if (Array.isArray(system)) {
+        system = system
+          .map((m: { type: string; text: string }) => m.text)
+          .join("\n");
+      }
+      const prompt = { system, messages: req.body.messages };
result = await countTokens({ req, prompt, service });
break;
}
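      // Illustrative: the alternate Claude system prompt schema
      //   system: [{ type: "text", text: "You are a helpful assistant." }]
      // is flattened to the plain string form above before token counting.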
@@ -50,9 +52,11 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
case "mistral-ai": {
case "mistral-ai":
case "mistral-text": {
req.outputTokens = req.body.max_tokens;
const prompt: MistralAIChatMessage[] = req.body.messages;
const prompt: string | MistralAIChatMessage[] =
req.body.messages ?? req.body.prompt;
result = await countTokens({ req, prompt, service });
break;
}
@@ -56,8 +56,6 @@ function getPromptFromRequest(req: Request) {
switch (service) {
case "anthropic-chat":
return flattenAnthropicMessages(body.messages);
case "anthropic-text":
return body.prompt;
case "openai":
case "mistral-ai":
return body.messages
@@ -72,8 +70,10 @@ function getPromptFromRequest(req: Request) {
return `${msg.role}: ${text}`;
})
.join("\n\n");
case "anthropic-text":
case "openai-text":
case "openai-image":
case "mistral-text":
return body.prompt;
case "google-ai":
return body.prompt.text;
@@ -1,4 +1,4 @@
import express from "express";
import express, { Request } from "express";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
@@ -6,8 +6,12 @@ import {
AnthropicV1TextSchema,
AnthropicV1MessagesSchema,
} from "../../../../shared/api-schemas";
import { keyPool } from "../../../../shared/key-management";
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import {
AWSMistralV1ChatCompletionsSchema,
AWSMistralV1TextCompletionsSchema,
} from "../../../../shared/api-schemas/mistral-ai";
const AMZ_HOST =
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
@@ -29,56 +33,33 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
req.body.prompt = preamble + req.body.prompt;
}
-  // AWS uses mostly the same parameters as Anthropic, with a few removed params
-  // and much stricter validation on unused parameters. Rather than treating it
-  // as a separate schema we will use the anthropic ones and strip the unused
-  // parameters.
-  // TODO: This should happen in transform-outbound-payload.ts
-  let strippedParams: Record<string, unknown>;
-  if (req.outboundApi === "anthropic-chat") {
-    strippedParams = AnthropicV1MessagesSchema.pick({
-      messages: true,
-      system: true,
-      max_tokens: true,
-      stop_sequences: true,
-      temperature: true,
-      top_k: true,
-      top_p: true,
-    })
-      .strip()
-      .parse(req.body);
-    strippedParams.anthropic_version = "bedrock-2023-05-31";
-  } else {
-    strippedParams = AnthropicV1TextSchema.pick({
-      prompt: true,
-      max_tokens_to_sample: true,
-      stop_sequences: true,
-      temperature: true,
-      top_k: true,
-      top_p: true,
-    })
-      .strip()
-      .parse(req.body);
-  }
const credential = getCredentialParts(req);
const host = AMZ_HOST.replace("%REGION%", credential.region);
// AWS only uses 2023-06-01 and does not actually check this header, but we
// set it so that the stream adapter always selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01";
+  // If our key has an inference profile compatible with the requested model,
+  // we want to use the inference profile instead of the model ID when calling
+  // InvokeModel as that will give us higher rate limits.
+  const profile =
+    (req.key as AwsBedrockKey).inferenceProfileIds.find((p) =>
+      p.includes(model)
+    ) || model;
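  // Illustrative: a key provisioned with the cross-region profile
  // "us.anthropic.claude-3-sonnet-20240229-v1:0" matches requests for
  // "anthropic.claude-3-sonnet-20240229-v1:0" and is used in its place.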
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request
// with the headers generated by the SDK.
const newRequest = new HttpRequest({
method: "POST",
protocol: "https:",
hostname: host,
-    path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`,
+    path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`,
headers: {
["Host"]: host,
["content-type"]: "application/json",
},
-    body: JSON.stringify(strippedParams),
+    body: JSON.stringify(applyAwsStrictValidation(req)),
});
if (stream) {
@@ -89,7 +70,13 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
const { key, body, inboundApi, outboundApi } = req;
req.log.info(
-    { key: key.hash, model: body.model, inboundApi, outboundApi },
+    {
+      key: key.hash,
+      model: body.model,
+      inferenceProfile: profile,
+      inboundApi,
+      outboundApi,
+    },
"Assigned AWS credentials to request"
);
@@ -128,3 +115,50 @@ async function sign(request: HttpRequest, credential: Credential) {
return signer.sign(request);
}
function applyAwsStrictValidation(req: Request): unknown {
// AWS uses vendor API formats but imposes additional (more strict) validation
// rules, namely that extraneous parameters are not allowed. We will validate
// using the vendor's zod schema but apply `.strip` to ensure that any
// extraneous parameters are removed.
let strippedParams: Record<string, unknown> = {};
switch (req.outboundApi) {
case "anthropic-text":
strippedParams = AnthropicV1TextSchema.pick({
prompt: true,
max_tokens_to_sample: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
})
.strip()
.parse(req.body);
break;
case "anthropic-chat":
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,
max_tokens: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
tools: true,
tool_choice: true,
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "bedrock-2023-05-31";
break;
case "mistral-ai":
strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body);
break;
case "mistral-text":
strippedParams = AWSMistralV1TextCompletionsSchema.parse(req.body);
break;
default:
throw new Error("Unexpected outbound API for AWS.");
}
return strippedParams;
}
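// Illustrative: for an anthropic-chat request, client-sent fields outside the
// picked set (e.g. model, stream, metadata) are dropped before signing, and
// anthropic_version: "bedrock-2023-05-31" is added, since Bedrock rejects
// extraneous parameters.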
@@ -0,0 +1,202 @@
import express from "express";
import crypto from "crypto";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
export const signGcpRequest: RequestPreprocessor = async (req) => {
const serviceValid = req.service === "gcp";
if (!serviceValid) {
throw new Error("addVertexAIKey called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
const { model, stream } = req.body;
req.key = keyPool.get(model, "gcp");
req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request");
req.isStreaming = String(stream) === "true";
// TODO: This should happen in transform-outbound-payload.ts
let strippedParams: Record<string, unknown>;
strippedParams = AnthropicV1MessagesSchema.pick({
messages: true,
system: true,
max_tokens: true,
stop_sequences: true,
temperature: true,
top_k: true,
top_p: true,
tools: true,
tool_choice: true,
stream: true,
})
.strip()
.parse(req.body);
strippedParams.anthropic_version = "vertex-2023-10-16";
const [accessToken, credential] = await getAccessToken(req);
const host = GCP_HOST.replace("%REGION%", credential.region);
// GCP doesn't use the anthropic-version header, but we set it to ensure the
// stream adapter selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01";
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: host,
path: `/v1/projects/${credential.projectId}/locations/${credential.region}/publishers/anthropic/models/${model}:streamRawPredict`,
headers: {
["host"]: host,
["content-type"]: "application/json",
["authorization"]: `Bearer ${accessToken}`,
},
body: JSON.stringify(strippedParams),
};
};
async function getAccessToken(
req: express.Request
): Promise<[string, Credential]> {
// TODO: access token caching to reduce latency
const credential = getCredentialParts(req);
const signedJWT = await createSignedJWT(
credential.clientEmail,
credential.privateKey
);
const [accessToken, jwtError] = await exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) {
req.log.warn(
{ key: req.key!.hash, jwtError },
"Unable to get the access token"
);
throw new Error("The access token is invalid.");
}
return [accessToken, credential];
}
async function createSignedJWT(email: string, pkey: string): Promise<string> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
str2ab(atob(pkey)),
{
name: "RSASSA-PKCS1-v1_5",
hash: { name: "SHA-256" },
},
false,
["sign"]
);
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = {
alg: "RS256",
typ: "JWT",
};
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = urlSafeBase64Encode(JSON.stringify(header));
const encodedPayload = urlSafeBase64Encode(JSON.stringify(payload));
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
cryptoKey,
str2ab(unsignedToken)
);
const encodedSignature = urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
async function exchangeJwtForAccessToken(
signedJwt: string
): Promise<[string | null, string]> {
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signedJwt,
};
const r = await fetch(authUrl, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: Object.entries(params)
.map(([k, v]) => `${k}=${v}`)
.join("&"),
}).then((res) => res.json());
if (r.access_token) {
return [r.access_token, ""];
}
return [null, JSON.stringify(r)];
}
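// Illustrative: on success the token endpoint responds with JSON like
//   { "access_token": "ya29.c.b0Example", "expires_in": 3599, "token_type": "Bearer" }
// and any other shape is returned as the stringified error above.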
function str2ab(str: string): ArrayBuffer {
const buffer = new ArrayBuffer(str.length);
const bufferView = new Uint8Array(buffer);
for (let i = 0; i < str.length; i++) {
bufferView[i] = str.charCodeAt(i);
}
return buffer;
}
function urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
type Credential = {
projectId: string;
clientEmail: string;
region: string;
privateKey: string;
};
function getCredentialParts(req: express.Request): Credential {
const [projectId, clientEmail, region, rawPrivateKey] =
req.key!.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
req.log.error(
{ key: req.key!.hash },
"GCP_CREDENTIALS isn't correctly formatted; refer to the docs"
);
throw new Error("The key assigned to this request is invalid.");
}
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
return { projectId, clientEmail, region, privateKey };
}
@@ -1,3 +1,4 @@
import { Request } from "express";
import {
API_REQUEST_VALIDATORS,
API_REQUEST_TRANSFORMERS,
@@ -12,29 +13,33 @@ import { RequestPreprocessor } from "../index";
/** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
-  const sameService = req.inboundApi === req.outboundApi;
const alreadyTransformed = req.retryCount > 0;
const notTransformable =
!isTextGenerationRequest(req) && !isImageGenerationRequest(req);
-  if (alreadyTransformed || notTransformable) return;
-  // TODO: this should be an APIFormatTransformer
-  if (req.inboundApi === "mistral-ai") {
-    const messages = req.body.messages;
-    req.body.messages = fixMistralPrompt(messages);
-    req.log.info(
-      { old: messages.length, new: req.body.messages.length },
-      "Fixed Mistral prompt"
+  if (alreadyTransformed) {
+    return;
+  } else if (notTransformable) {
+    // This is probably an indication of a bug in the proxy.
+    const { inboundApi, outboundApi, method, path } = req;
+    req.log.warn(
+      { inboundApi, outboundApi, method, path },
+      "`transformOutboundPayload` called on a non-transformable request."
    );
+    return;
  }
-  if (sameService) {
+  applyMistralPromptFixes(req);
+  // Native prompts are those which were already provided by the client in the
+  // target API format. We don't need to transform them.
+  const isNativePrompt = req.inboundApi === req.outboundApi;
+  if (isNativePrompt) {
const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body: req.body },
"Request validation failed"
"Native prompt request validation failed."
);
throw result.error;
}
@@ -42,11 +47,12 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return;
}
// Prompt requires translation from one API format to another.
const transformation = `${req.inboundApi}->${req.outboundApi}` as const;
const transFn = API_REQUEST_TRANSFORMERS[transformation];
if (transFn) {
req.log.info({ transformation }, "Transforming request");
req.log.info({ transformation }, "Transforming request...");
req.body = await transFn(req);
return;
}
@@ -55,3 +61,36 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
`${transformation} proxying is not supported. Make sure your client is configured to send requests in the correct format and to the correct endpoint.`
);
};
// handles weird cases that don't fit into our abstractions
function applyMistralPromptFixes(req: Request): void {
if (req.inboundApi === "mistral-ai") {
    // Mistral Chat is very similar to OpenAI but not identical, and many
    // clients don't properly handle the differences. We validate the Mistral
    // prompt and try to fix it if validation fails. It will be re-validated
    // after this function returns.
const result = API_REQUEST_VALIDATORS["mistral-ai"].parse(req.body);
req.body.messages = fixMistralPrompt(result.messages);
req.log.info(
{ n: req.body.messages.length, prev: result.messages.length },
"Applied Mistral chat prompt fixes."
);
// If the prompt relies on `prefix: true` for the last message, we need to
// convert it to a text completions request because AWS Mistral support for
// this feature is broken.
// On Mistral La Plateforme, we can't do this because they don't expose
// a text completions endpoint.
const { messages } = req.body;
const lastMessage = messages && messages[messages.length - 1];
if (lastMessage?.role === "assistant" && req.service === "aws") {
// enable prefix if client forgot, otherwise the template will insert an
// eos token which is very unlikely to be what the client wants.
lastMessage.prefix = true;
req.outboundApi = "mistral-text";
req.log.info(
"Native Mistral chat prompt relies on assistant message prefix. Converting to text completions request."
);
}
}
}
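// Illustrative: a chat request whose last message is
//   { role: "assistant", content: "Sure, here is the plan:" }
// is flagged with prefix: true on AWS and re-routed as a mistral-text request,
// so the model continues the assistant turn instead of closing it with an
// eos token.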
@@ -6,8 +6,9 @@ import { RequestPreprocessor } from "../index";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
-const GOOGLE_AI_MAX_CONTEXT = 32000;
-const MISTRAL_AI_MAX_CONTENT = 32768;
+// todo: make configurable
+const GOOGLE_AI_MAX_CONTEXT = 1024000;
+const MISTRAL_AI_MAX_CONTENT = 131072;
/**
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -37,6 +38,7 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
proxyMax = GOOGLE_AI_MAX_CONTEXT;
break;
case "mistral-ai":
case "mistral-text":
proxyMax = MISTRAL_AI_MAX_CONTENT;
break;
case "openai-image":
@@ -56,6 +58,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 16384;
} else if (model.match(/^gpt-4o/)) {
modelMax = 128000;
} else if (model.match(/^chatgpt-4o/)) {
modelMax = 128000;
} else if (model.match(/gpt-4-turbo(-\d{4}-\d{2}-\d{2})?$/)) {
modelMax = 131072;
} else if (model.match(/gpt-4-turbo(-preview)?$/)) {
@@ -80,17 +84,19 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 200000;
} else if (model.match(/^claude-3/)) {
modelMax = 200000;
-  } else if (model.match(/^gemini-\d{3}$/)) {
-    modelMax = GOOGLE_AI_MAX_CONTEXT;
-  } else if (model.match(/^mistral-(tiny|small|medium)$/)) {
-    modelMax = MISTRAL_AI_MAX_CONTENT;
+  } else if (model.match(/^gemini-/)) {
+    modelMax = 1024000;
} else if (model.match(/^anthropic\.claude-3/)) {
modelMax = 200000;
} else if (model.match(/^anthropic\.claude-v2:\d/)) {
modelMax = 200000;
} else if (model.match(/^anthropic\.claude/)) {
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
modelMax = 100000;
} else if (model.match(/tral/)) {
// catches mistral, mixtral, codestral, mathstral, etc. mistral models have
// no consistent naming convention and wildly different context windows, so
// this is a catch-all
modelMax = MISTRAL_AI_MAX_CONTENT;
} else {
req.log.warn({ model }, "Unknown model, using 200k token limit.");
modelMax = 200000;
@@ -28,6 +28,7 @@ export const validateVision: RequestPreprocessor = async (req) => {
case "anthropic-text":
case "google-ai":
case "mistral-ai":
case "mistral-text":
case "openai-image":
case "openai-text":
return;
@@ -65,7 +65,7 @@ type ErrorGeneratorOptions = {
format: APIFormat | "unknown";
title: string;
message: string;
obj?: Record<string, any>;
reqId: string | number | object;
model?: string;
statusCode?: number;
@@ -95,6 +95,23 @@ export function tryInferFormat(body: any): APIFormat | "unknown" {
return "unknown";
}
// avoid leaking upstream hostname on dns resolution error
function redactHostname(options: ErrorGeneratorOptions): ErrorGeneratorOptions {
if (!options.message.includes("getaddrinfo")) return options;
const redacted = { ...options };
redacted.message = "Could not resolve hostname";
if (typeof redacted.obj?.error === "object") {
redacted.obj = {
...redacted.obj,
error: { message: "Could not resolve hostname" },
};
}
return redacted;
}
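A before/after sketch of the redaction (hostname and payload are invented):
// in:  { message: "getaddrinfo ENOTFOUND upstream.example.com",
//        obj: { error: { message: "getaddrinfo ENOTFOUND upstream.example.com" } } }
// out: { message: "Could not resolve hostname",
//        obj: { error: { message: "Could not resolve hostname" } } }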
export function sendErrorToClient({
options,
req,
@@ -104,27 +121,26 @@ export function sendErrorToClient({
req: express.Request;
res: express.Response;
}) {
const redactedOpts = redactHostname(options);
const { format: inputFormat } = redactedOpts;
// This is an error thrown before we know the format of the request, so we
// can't send a response in the format the client expects.
const format =
inputFormat === "unknown" ? tryInferFormat(req.body) : inputFormat;
if (format === "unknown") {
return res.status(redactedOpts.statusCode || 400).json({
error: redactedOpts.message,
details: redactedOpts.obj,
});
}
const completion = buildSpoofedCompletion({ ...redactedOpts, format });
const event = buildSpoofedSSE({ ...redactedOpts, format });
const isStreaming =
req.isStreaming || req.body.stream === true || req.body.stream === "true";
if (!res.headersSent) {
res.setHeader("x-oai-proxy-error", options.title);
res.setHeader("x-oai-proxy-error-status", options.statusCode || 500);
res.setHeader("x-oai-proxy-error", redactedOpts.title);
res.setHeader("x-oai-proxy-error-status", redactedOpts.statusCode || 500);
}
if (isStreaming) {
@@ -173,6 +189,11 @@ export function buildSpoofedCompletion({
},
],
};
case "mistral-text":
return {
outputs: [{ text: content, stop_reason: title }],
model,
};
case "openai-text":
return {
id: "error-" + id,
@@ -204,13 +225,7 @@ export function buildSpoofedCompletion({
stop_sequence: null,
};
case "google-ai":
return {
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
@@ -257,6 +272,11 @@ export function buildSpoofedSSE({
choices: [{ delta: { content }, index: 0, finish_reason: title }],
};
break;
case "mistral-text":
event = {
outputs: [{ text: content, stop_reason: title }],
};
break;
case "openai-text":
event = {
id: "cmpl-" + id,
@@ -286,7 +306,10 @@ export function buildSpoofedSSE({
};
break;
case "google-ai":
// TODO: google ai supports two streaming transports, SSE and JSON.
// we currently only support SSE.
// return JSON.stringify({
event = {
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
@@ -296,7 +319,8 @@ export function buildSpoofedSSE({
safetyRatings: [],
},
],
};
break;
case "openai-image":
return JSON.stringify(obj);
default:
@@ -22,18 +22,19 @@ import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
const pipelineAsync = promisify(pipeline);
/**
* `handleStreamedResponse` consumes a streamed response from the upstream API,
* decodes chunk-by-chunk into a stream of events, transforms those events into
* the client's requested format, and forwards the result to the client.
*
* After the entire stream has been consumed, it resolves with the full response
* body so that subsequent middleware in the chain can process it as if it were
* a non-streaming response (to count output tokens, track usage, etc).
*
* In the event of an error, the request's streaming flag is unset and the
* request is bounced back to the non-streaming response handler. If the error
* is retryable, that handler will re-enqueue the request and also reset the
* streaming flag. Unfortunately the streaming flag is set and unset in multiple
* places, so it's hard to keep track of.
*/
export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes,
@@ -70,13 +71,21 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
logger: req.log,
};
// While the request is streaming, aggregator collects all events so that we
// can compile them into a single response object and publish that to the
// remaining middleware. Because we have an OpenAI transformer for every
// supported format, EventAggregator always consumes OpenAI events so that we
// only have to write one aggregator (OpenAI input) for each output format.
const aggregator = new EventAggregator(req);
// Decoder reads from the raw response buffer and produces a stream of
// discrete events in some format (text/event-stream, vnd.amazon.event-stream,
// streaming JSON, etc).
const decoder = getDecoder({ ...streamOptions, input: proxyRes });
// Adapter consumes the decoded events and produces server-sent events so we
// have a standard event format for the client and to translate between API
// message formats.
const adapter = new SSEStreamAdapter(streamOptions);
// Transformer converts server-sent events from one vendor's API message
// format to another.
const transformer = new SSEMessageTransformer({
@@ -1,4 +1,5 @@
/* This file is fucking horrendous, sorry */
// TODO: extract all per-service error response handling into its own modules
import { Request, Response } from "express";
import * as http from "http";
import { config } from "../../../config";
@@ -185,6 +186,13 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
throw new HttpError(statusCode, parseError.message);
}
const service = req.key!.service;
if (service === "gcp") {
if (Array.isArray(errorPayload)) {
errorPayload = errorPayload[0];
}
}
const errorType =
errorPayload.error?.code ||
errorPayload.error?.type ||
@@ -194,21 +202,24 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
{ statusCode, type: errorType, errorPayload, key: req.key?.hash },
`Received error response from upstream. (${proxyRes.statusMessage})`
);
// TODO: split upstream error handling into separate modules for each service,
// this is out of control.
if (service === "aws") {
// Try to standardize the error format for AWS
errorPayload.error = { message: errorPayload.message, type: errorType };
delete errorPayload.message;
} else if (service === "gcp") {
// Try to standardize the error format for GCP
if (errorPayload.error?.code) {
// GCP error
errorPayload.error = {
message: errorPayload.error.message,
type: errorPayload.error.status || errorPayload.error.code,
};
}
}
if (statusCode === 400) {
switch (service) {
case "openai":
case "google-ai":
case "mistral-ai":
case "azure":
const filteredCodes = ["content_policy_violation", "content_filter"];
@@ -225,7 +236,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
break;
case "anthropic":
case "aws":
case "gcp":
await handleAnthropicAwsBadRequestError(req, errorPayload);
break;
case "google-ai":
await handleGoogleAIBadRequestError(req, errorPayload);
break;
default:
assertNever(service);
@@ -247,7 +262,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
);
keyPool.update(req.key!, { allowsMultimodality: false });
await reenqueueRequest(req);
throw new RetryableError("Claude request re-enqueued because key does not support multimodality.");
throw new RetryableError(
"Claude request re-enqueued because key does not support multimodality."
);
} else {
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
@@ -275,6 +292,12 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
default:
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
}
return;
case "mistral-ai":
case "gcp":
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`;
return;
}
} else if (statusCode === 429) {
switch (service) {
@@ -287,6 +310,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "aws":
await handleAwsRateLimitError(req, errorPayload);
break;
case "gcp":
await handleGcpRateLimitError(req, errorPayload);
break;
case "azure":
case "mistral-ai":
await handleAzureRateLimitError(req, errorPayload);
@@ -323,6 +349,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
case "gcp":
errorPayload.proxy_note = `The requested GCP resource might not exist, or the key might not have access to it.`;
break;
case "azure":
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
break;
@@ -347,7 +376,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
throw new HttpError(statusCode, errorPayload.error?.message);
};
async function handleAnthropicAwsBadRequestError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
@@ -382,11 +411,13 @@ async function handleAnthropicBadRequestError(
return;
}
const isDisabled =
error?.message?.match(/organization has been disabled/i) ||
error?.message?.match(/^operation not allowed/i);
if (isDisabled) {
req.log.warn(
{ key: req.key?.hash, message: error?.message },
"Anthropic key has been disabled."
"Anthropic/AWS key has been disabled."
);
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been disabled. (${error?.message})`;
@@ -427,6 +458,19 @@ async function handleAwsRateLimitError(
}
}
async function handleGcpRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
if (errorPayload.error?.type === "RESOURCE_EXHAUSTED") {
keyPool.markRateLimited(req.key!);
await reenqueueRequest(req);
throw new RetryableError("GCP rate-limited request re-enqueued.");
} else {
errorPayload.proxy_note = `Unrecognized 429 Too Many Requests error from GCP.`;
}
}
async function handleOpenAIRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
@@ -512,7 +556,7 @@ async function handleOpenAIRateLimitError(
// keyPool.markRateLimited(req.key!);
// break;
default:
errorPayload.proxy_note = `This is likely a temporary error with the API. Try again in a few seconds.`;
break;
}
return errorPayload;
@@ -534,6 +578,42 @@ async function handleAzureRateLimitError(
}
}
//{"error":{"code":400,"message":"API Key not found. Please pass a valid API key.","status":"INVALID_ARGUMENT","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"API_KEY_INVALID","domain":"googleapis.com","metadata":{"service":"generativelanguage.googleapis.com"}}]}}
//{"error":{"code":400,"message":"Gemini API free tier is not available in your country. Please enable billing on your project in Google AI Studio.","status":"FAILED_PRECONDITION"}}
async function handleGoogleAIBadRequestError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
const error = errorPayload.error || {};
const { message, status, details } = error;
if (status === "INVALID_ARGUMENT") {
const reason = details?.[0]?.reason;
if (reason === "API_KEY_INVALID") {
req.log.warn(
{ key: req.key?.hash, status, reason, msg: error.message },
"Received `API_KEY_INVALID` error from Google AI. Check the configured API key."
);
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned API key is invalid.`;
}
} else if (status === "FAILED_PRECONDITION") {
if (message.match(/please enable billing/i)) {
req.log.warn(
{ key: req.key?.hash, status, msg: error.message },
"Cannot use key due to billing restrictions."
);
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned API key cannot be used.`;
}
} else {
req.log.warn(
{ key: req.key?.hash, status, msg: error.message },
"Received unexpected 400 error from Google AI."
);
}
}
//{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}
async function handleGoogleAIRateLimitError(
req: Request,
@@ -11,7 +11,8 @@ import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import {
AnthropicChatMessage,
flattenAnthropicMessages,
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../shared/api-schemas";
@@ -74,8 +75,16 @@ const getPromptForRequest = (
case "mistral-ai":
return req.body.messages;
case "anthropic-chat":
let system = req.body.system;
if (Array.isArray(system)) {
system = system
.map((m: { type: string; text: string }) => m.text)
.join("\n");
}
return { system, messages: req.body.messages };
case "openai-text":
case "anthropic-text":
case "mistral-text":
return req.body.prompt;
case "openai-image":
return {
@@ -85,8 +94,6 @@ const getPromptForRequest = (
quality: req.body.quality,
revisedPrompt: responseBody.data[0].revised_prompt,
};
case "anthropic-text":
return req.body.prompt;
case "google-ai":
return { contents: req.body.contents };
default:
@@ -113,9 +120,7 @@ const flattenMessages = (
if (isGoogleAIChatPrompt(val)) {
return val.contents
.map(({ parts, role }) => {
const text = parts.map((p) => p.text).join("\n");
return `${role}: ${text}`;
})
.join("\n");
@@ -143,11 +148,7 @@ const flattenMessages = (
function isGoogleAIChatPrompt(
val: unknown
): val is { contents: GoogleAIChatMessage[] } {
return typeof val === "object" && val !== null && "contents" in val;
}
function isAnthropicChatPrompt(
@@ -0,0 +1,39 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type MistralChatCompletionResponse = {
choices: {
index: number;
message: { role: string; content: string };
finish_reason: string | null;
}[];
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized Mistral chat completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForMistralChat(
events: OpenAIChatCompletionStreamEvent[]
): MistralChatCompletionResponse {
let merged: MistralChatCompletionResponse = {
choices: [
{ index: 0, message: { role: "", content: "" }, finish_reason: "" },
],
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
acc.choices[0].message.role = event.choices[0].delta.role ?? "assistant";
return acc;
}
acc.choices[0].finish_reason = event.choices[0].finish_reason ?? "";
if (event.choices[0].delta.content) {
acc.choices[0].message.content += event.choices[0].delta.content;
}
return acc;
}, merged);
return merged;
}
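A usage sketch with invented events; the first chunk carries only the role, later chunks carry content deltas, and the last carries the finish reason:
const merged = mergeEventsForMistralChat([
  { id: "e1", object: "chat.completion.chunk", created: 0, model: "m", choices: [{ index: 0, delta: { role: "assistant" }, finish_reason: null }] },
  { id: "e2", object: "chat.completion.chunk", created: 0, model: "m", choices: [{ index: 0, delta: { content: "Hello" }, finish_reason: null }] },
  { id: "e3", object: "chat.completion.chunk", created: 0, model: "m", choices: [{ index: 0, delta: { content: "!" }, finish_reason: "stop" }] },
]);
// merged.choices[0].message -> { role: "assistant", content: "Hello!" }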
@@ -0,0 +1,33 @@
import { OpenAIChatCompletionStreamEvent } from "../index";
export type MistralTextCompletionResponse = {
outputs: {
text: string;
stop_reason: string | null;
}[];
};
/**
* Given a list of OpenAI chat completion events, compiles them into a single
* finalized Mistral text completion response so that non-streaming middleware
* can operate on it as if it were a blocking response.
*/
export function mergeEventsForMistralText(
events: OpenAIChatCompletionStreamEvent[]
): MistralTextCompletionResponse {
let merged: MistralTextCompletionResponse = {
outputs: [{ text: "", stop_reason: "" }],
};
merged = events.reduce((acc, event, i) => {
// The first event will only contain role assignment and response metadata
if (i === 0) {
return acc;
}
acc.outputs[0].text += event.choices[0].delta.content ?? "";
acc.outputs[0].stop_reason = event.choices[0].finish_reason ?? "";
return acc;
}, merged);
return merged;
}
@@ -24,7 +24,7 @@ export function getAwsEventStreamDecoder(params: {
if (eventType === "chunk") {
result = input[eventType];
} else {
// AWS unmarshaller treats non-chunk events (errors and exceptions) oddly.
result = { [eventType]: input[eventType] } as any;
}
return result;
@@ -1,3 +1,4 @@
import express from "express";
import { APIFormat } from "../../../../shared/key-management";
import { assertNever } from "../../../../shared/utils";
import {
@@ -6,8 +7,13 @@ import {
mergeEventsForAnthropicText,
mergeEventsForOpenAIChat,
mergeEventsForOpenAIText,
mergeEventsForMistralChat,
mergeEventsForMistralText,
AnthropicV2StreamEvent,
OpenAIChatCompletionStreamEvent,
mistralAIToOpenAI,
MistralAIStreamEvent,
MistralChatCompletionEvent,
} from "./index";
/**
@@ -15,45 +21,70 @@ import {
* compiles them into a single finalized response for downstream middleware.
*/
export class EventAggregator {
private readonly model: string;
private readonly requestFormat: APIFormat;
private readonly responseFormat: APIFormat;
private readonly events: OpenAIChatCompletionStreamEvent[];
constructor({ body, inboundApi, outboundApi }: express.Request) {
this.events = [];
this.requestFormat = inboundApi;
this.responseFormat = outboundApi;
this.model = body.model;
}
addEvent(
event:
| OpenAIChatCompletionStreamEvent
| AnthropicV2StreamEvent
| MistralAIStreamEvent
) {
if (eventIsOpenAIEvent(event)) {
this.events.push(event);
} else {
// horrible special case. previously all transformers' target format was
// openai, so the event aggregator could conveniently assume all incoming
// events were in openai format.
// now we have added some transformers that convert between non-openai
// formats, so aggregator needs to know how to collapse for more than
// just openai.
// because writing aggregation logic for every possible output format is
// annoying, we will just transform any non-openai output events to openai
// format (even if the client did not request openai at all) so that we
// still only need to write aggregators for openai SSEs.
let openAIEvent: OpenAIChatCompletionStreamEvent | undefined;
switch (this.requestFormat) {
case "anthropic-text":
assertIsAnthropicV2Event(event);
openAIEvent = anthropicV2ToOpenAI({
data: `event: completion\ndata: ${JSON.stringify(event)}\n\n`,
lastPosition: -1,
index: 0,
fallbackId: event.log_id || "fallback-" + Date.now(),
fallbackModel: event.model || this.model || "fallback-claude-3",
})?.event;
break;
case "mistral-ai":
assertIsMistralChatEvent(event);
openAIEvent = mistralAIToOpenAI({
data: `data: ${JSON.stringify(event)}\n\n`,
lastPosition: -1,
index: 0,
fallbackId: "fallback-" + Date.now(),
fallbackModel: this.model || "fallback-mistral",
})?.event;
break;
}
if (openAIEvent) {
this.events.push(openAIEvent);
}
}
}
getFinalResponse() {
switch (this.responseFormat) {
case "openai":
case "google-ai":
case "mistral-ai":
case "google-ai": // TODO: this is probably wrong now that we support native Google Makersuite prompts
return mergeEventsForOpenAIChat(this.events);
case "openai-text":
return mergeEventsForOpenAIText(this.events);
@@ -61,10 +92,16 @@ export class EventAggregator {
return mergeEventsForAnthropicText(this.events);
case "anthropic-chat":
return mergeEventsForAnthropicChat(this.events);
case "mistral-ai":
return mergeEventsForMistralChat(this.events);
case "mistral-text":
return mergeEventsForMistralText(this.events);
case "openai-image":
throw new Error(
`SSE aggregation not supported for ${this.responseFormat}`
);
default:
assertNever(this.responseFormat);
}
}
@@ -78,3 +115,17 @@ function eventIsOpenAIEvent(
): event is OpenAIChatCompletionStreamEvent {
return event?.object === "chat.completion.chunk";
}
function assertIsAnthropicV2Event(event: any): asserts event is AnthropicV2StreamEvent {
if (!event?.completion) {
throw new Error(`Bad event for Anthropic V2 SSE aggregation`);
}
}
function assertIsMistralChatEvent(
event: any
): asserts event is MistralChatCompletionEvent {
if (!event?.choices) {
throw new Error(`Bad event for Mistral SSE aggregation`);
}
}
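A rough usage sketch; a minimal request-like object is cast for illustration (real callers pass the actual express request, as in handleStreamedResponse):
const agg = new EventAggregator({
  body: { model: "mistral-large-2407" },
  inboundApi: "mistral-ai",
  outboundApi: "mistral-ai",
} as express.Request);
agg.addEvent({ choices: [{ index: 0, message: { role: "assistant", content: "Hi" }, stop_reason: "stop" }] });
const response = agg.getFinalResponse(); // merged via mergeEventsForMistralChat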
@@ -7,6 +7,25 @@ export type SSEResponseTransformArgs<S = Record<string, any>> = {
state?: S;
};
export type MistralChatCompletionEvent = {
choices: {
index: number;
message: { role: string; content: string };
stop_reason: string | null;
}[];
};
export type MistralTextCompletionEvent = {
outputs: { text: string; stop_reason: string | null }[];
};
export type MistralAIStreamEvent = {
"amazon-bedrock-invocationMetrics"?: {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
} & (MistralChatCompletionEvent | MistralTextCompletionEvent);
export type AnthropicV2StreamEvent = {
log_id?: string;
model?: string;
@@ -41,8 +60,12 @@ export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { anthropicChatToAnthropicV2 } from "./transformers/anthropic-chat-to-anthropic-v2";
export { anthropicChatToOpenAI } from "./transformers/anthropic-chat-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { mistralAIToOpenAI } from "./transformers/mistral-ai-to-openai";
export { mistralTextToMistralChat } from "./transformers/mistral-text-to-mistral-chat";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropicText } from "./aggregators/anthropic-text";
export { mergeEventsForAnthropicChat } from "./aggregators/anthropic-chat";
export { mergeEventsForMistralChat } from "./aggregators/mistral-chat";
export { mergeEventsForMistralText } from "./aggregators/mistral-text";
@@ -11,8 +11,11 @@ import {
googleAIToOpenAI,
OpenAIChatCompletionStreamEvent,
openAITextToOpenAIChat,
mistralAIToOpenAI,
mistralTextToMistralChat,
passthroughToOpenAI,
StreamingCompletionTransformer,
MistralChatCompletionEvent,
} from "./index";
type SSEMessageTransformerOptions = TransformOptions & {
@@ -35,7 +38,9 @@ export class SSEMessageTransformer extends Transform {
private readonly inputFormat: APIFormat;
private readonly transformFn: StreamingCompletionTransformer<
// TODO: Refactor transformers to not assume only OpenAI events as output
| OpenAIChatCompletionStreamEvent
| AnthropicV2StreamEvent
| MistralChatCompletionEvent
>;
private readonly log;
private readonly fallbackId: string;
@@ -121,16 +126,17 @@ function eventIsOpenAIEvent(
function getTransformer(
responseApi: APIFormat,
version?: string,
// In most cases, we are transforming back to OpenAI. Some responses can be
// translated between two non-OpenAI formats, eg Anthropic Chat -> Anthropic
// Text, or Mistral Text -> Mistral Chat.
requestApi: APIFormat = "openai"
): StreamingCompletionTransformer<
| OpenAIChatCompletionStreamEvent
| AnthropicV2StreamEvent
| MistralChatCompletionEvent
> {
switch (responseApi) {
case "openai":
case "mistral-ai":
return passthroughToOpenAI;
case "openai-text":
return openAITextToOpenAIChat;
@@ -140,10 +146,16 @@ function getTransformer(
: anthropicV2ToOpenAI;
case "anthropic-chat":
return requestApi === "anthropic-text"
? anthropicChatToAnthropicV2 // User's legacy text prompt was converted to chat, and response must be converted back to text
: anthropicChatToOpenAI;
case "google-ai":
return googleAIToOpenAI;
case "mistral-ai":
return mistralAIToOpenAI;
case "mistral-text":
return requestApi === "mistral-ai"
? mistralTextToMistralChat // User's chat request was converted to text, and response must be converted back to chat
: mistralAIToOpenAI;
case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
@@ -55,8 +55,10 @@ export class SSEStreamAdapter extends Transform {
if ("completion" in eventObj) {
return ["event: completion", `data: ${event}`].join(`\n`);
} else if (eventObj.type) {
return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
} else {
return `data: ${event}`;
}
}
// noinspection FallThroughInSwitchStatementJS -- non-JSON data is unexpected
@@ -116,7 +118,7 @@ export class SSEStreamAdapter extends Transform {
try {
const hasParts = candidates[0].content?.parts?.length > 0;
if (hasParts) {
return `data: ${JSON.stringify(data.value ?? data)}`;
} else {
this.log.error({ event: data }, "Received bad Google AI event");
return `data: ${buildSpoofedSSE({
@@ -34,7 +34,7 @@ export const anthropicChatToOpenAI: StreamingCompletionTransformer = (
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content: deltaEvent.delta.text },
finish_reason: null,
},
@@ -0,0 +1,76 @@
import { logger } from "../../../../../logger";
import { MistralAIStreamEvent, SSEResponseTransformArgs } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
const log = logger.child({
module: "sse-transformer",
transformer: "mistral-ai-to-openai",
});
export const mistralAIToOpenAI = (params: SSEResponseTransformArgs) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
if ("choices" in completionEvent) {
const newChatEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: completionEvent.choices[0].index,
delta: { content: completionEvent.choices[0].message.content },
finish_reason: completionEvent.choices[0].stop_reason,
},
],
};
return { position: -1, event: newChatEvent };
} else if ("outputs" in completionEvent) {
const newTextEvent = {
id: params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content: completionEvent.outputs[0].text },
finish_reason: completionEvent.outputs[0].stop_reason,
},
],
};
return { position: -1, event: newTextEvent };
}
// should never happen
return { position: -1 };
};
function asCompletion(event: ServerSentEvent): MistralAIStreamEvent | null {
try {
const parsed = JSON.parse(event.data);
if (
(Array.isArray(parsed.choices) &&
parsed.choices[0].message !== undefined) ||
(Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined)
) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error: any) {
log.warn({ error: error.stack, event }, "Received invalid data event");
}
return null;
}
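A usage sketch with an invented SSE payload; the fallback id/model are supplied because Mistral events carry neither:
const result = mistralAIToOpenAI({
  data: 'data: {"outputs":[{"text":"Hello","stop_reason":null}]}',
  lastPosition: -1,
  index: 0,
  fallbackId: "evt-1",
  fallbackModel: "mistral-large-2407",
});
// result.event?.choices[0].delta -> { content: "Hello" }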
@@ -0,0 +1,63 @@
import {
MistralChatCompletionEvent,
MistralTextCompletionEvent,
StreamingCompletionTransformer,
} from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "mistral-text-to-mistral-chat",
});
/**
* Transforms an incoming Mistral Text SSE to an equivalent Mistral Chat SSE.
* This is generally used when a client sends a Mistral Chat prompt, but we
* convert it to Mistral Text before sending it to the API to work around
* some bugs in Mistral/AWS prompt templating. In these cases we need to convert
* the response back to Mistral Chat.
*/
export const mistralTextToMistralChat: StreamingCompletionTransformer<
MistralChatCompletionEvent
> = (params) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data) {
return { position: -1 };
}
const textCompletion = asTextCompletion(rawEvent);
if (!textCompletion) {
return { position: -1 };
}
const chatEvent: MistralChatCompletionEvent = {
choices: [
{
index: 0,
message: { role: "assistant", content: textCompletion.outputs[0].text },
stop_reason: textCompletion.outputs[0].stop_reason,
},
],
};
return { position: -1, event: chatEvent };
};
function asTextCompletion(
event: ServerSentEvent
): MistralTextCompletionEvent | null {
try {
const parsed = JSON.parse(event.data);
if (Array.isArray(parsed.outputs) && parsed.outputs[0].text !== undefined) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error: any) {
log.warn({ error: error.stack, event }, "Received invalid data event");
}
return null;
}
@@ -1,4 +1,4 @@
import { RequestHandler, Router } from "express";
import express, { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
@@ -21,28 +21,48 @@ import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { BadRequestError } from "../shared/errors";
// Mistral can't settle on a single naming scheme and deprecates models within
// months of releasing them so this list is hard to keep up to date. 2024-07-28
// https://docs.mistral.ai/platform/endpoints
export const KNOWN_MISTRAL_AI_MODELS = [
/*
Mistral Nemo
"A 12B model built with the partnership with Nvidia. It is easy to use and a
drop-in replacement in any system using Mistral 7B that it supersedes."
*/
"open-mistral-nemo",
"open-mistral-nemo-2407",
/*
Mistral Large
"Our flagship model with state-of-the-art reasoning, knowledge, and coding
capabilities."
*/
"mistral-large-latest",
"mistral-large-2407",
"mistral-large-2402", // deprecated
/*
Codestral
"A cutting-edge generative model that has been specifically designed and
optimized for code generation tasks, including fill-in-the-middle and code
completion."
note: this uses a separate bidi completion endpoint that is not implemented
*/
"codestral-latest",
"codestral-2405",
/* So-called "Research Models" */
"open-mistral-7b",
"mistral-tiny-2312",
"open-mixtral-8x7b",
"mistral-small-2312",
"open-mistral-8x22b",
"open-codestral-mamba",
/* Deprecated production models */
"mistral-small-latest",
"mistral-small-2402",
"mistral-medium-latest",
"mistral-medium-2312",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"mistral-tiny-2312",
];
let modelsCache: any = null;
@@ -89,9 +109,24 @@ const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
throw new Error("Expected body to be an object");
}
let newBody = body;
if (req.inboundApi === "mistral-text" && req.outboundApi === "mistral-ai") {
newBody = transformMistralTextToMistralChat(body);
}
res.status(200).json({ ...newBody, proxy: body.proxy });
};
export function transformMistralTextToMistralChat(textBody: any) {
return {
...textBody,
choices: [
{ message: { content: textBody.outputs[0].text, role: "assistant" } },
],
outputs: undefined,
};
}
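An input/output sketch (values invented):
// transformMistralTextToMistralChat({ outputs: [{ text: "Hi there", stop_reason: "stop" }] })
//   -> { outputs: undefined, choices: [{ message: { content: "Hi there", role: "assistant" } }] }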
const mistralAIProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.mistral.ai",
@@ -114,12 +149,37 @@ mistralAIRouter.get("/v1/models", handleModelRequest);
mistralAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware(
{
inApi: "mistral-ai",
outApi: "mistral-ai",
service: "mistral-ai",
},
{ beforeTransform: [detectMistralInputApi] }
),
mistralAIProxy
);
/**
* We can't determine if a request is Mistral text or chat just from the path
* because they both use the same endpoint. We need to check the request body
* for either `messages` or `prompt`.
* @param req
*/
export function detectMistralInputApi(req: Request) {
const { messages, prompt } = req.body;
if (messages) {
req.inboundApi = "mistral-ai";
req.outboundApi = "mistral-ai";
} else if (prompt && req.service === "mistral-ai") {
// Mistral La Plateforme doesn't expose a text completions endpoint.
throw new BadRequestError(
"Mistral (via La Plateforme API) does not support text completions. This format is only supported on Mistral via the AWS API."
);
} else if (prompt && req.service === "aws") {
req.inboundApi = "mistral-text";
req.outboundApi = "mistral-text";
}
}
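Two invented bodies showing how detection falls out:
// { messages: [{ role: "user", content: "hi" }] }
//   -> inbound/outbound set to "mistral-ai" (chat)
// { prompt: "<s>[INST] hi [/INST]" }
//   -> "mistral-text" when req.service === "aws"; a 400 on La Plateforme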
export const mistralAI = mistralAIRouter;
@@ -28,28 +28,44 @@ import {
// https://platform.openai.com/docs/models/overview
export const KNOWN_OPENAI_MODELS = [
// GPT4o
"gpt-4o",
"gpt-4o-2024-05-13",
"gpt-4-turbo", // alias for latest gpt4-turbo stable
"gpt-4o-2024-08-06",
// GPT4o Mini
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
// GPT4o (ChatGPT)
"chatgpt-4o-latest",
// GPT4 Turbo (superseded by GPT4o)
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09", // gpt4-turbo stable, with vision
"gpt-4-turbo-preview", // alias for latest turbo preview
"gpt-4-0125-preview", // gpt4-turbo preview 2
"gpt-4-1106-preview", // gpt4-turbo preview 1
"gpt-4-vision-preview", // gpt4-turbo preview 1 with vision
// Launch GPT4
"gpt-4",
"gpt-4-0613",
"gpt-4-0314", // EOL 2024-06-13
"gpt-4-32k",
"gpt-4-32k-0314", // EOL 2024-06-13
"gpt-4-32k-0613",
"gpt-4-0314", // legacy
// GPT3.5 Turbo (superseded by GPT4o Mini)
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301", // EOL 2024-06-13
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0125", // latest turbo
"gpt-3.5-turbo-1106", // older turbo
// Text Completion
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-instruct-0914",
// Embeddings
"text-embedding-ada-002",
// Known deprecated models
"gpt-4-32k", // alias for 0613
"gpt-4-32k-0314", // EOL 2025-06-06
"gpt-4-32k-0613", // EOL 2025-06-06
"gpt-4-vision-preview", // EOL 2024-12-06
"gpt-4-1106-vision-preview", // EOL 2024-12-06
"gpt-3.5-turbo-0613", // EOL 2024-09-13
"gpt-3.5-turbo-0301", // not on the website anymore, maybe unavailable
"gpt-3.5-turbo-16k", // alias for 0613
"gpt-3.5-turbo-16k-0613", // EOL 2024-09-13
];
let modelsCache: any = null;
@@ -22,7 +22,7 @@ import {
} from "../shared/models";
import { initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
import { sendErrorToClient } from "./middleware/response/error-generator";
@@ -30,10 +30,12 @@ import { sendErrorToClient } from "./middleware/response/error-generator";
const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = parseInt(
process.env.USER_CONCURRENCY_LIMIT ?? "1"
);
/** Maximum number of queue slots for Agnai.chat requests. */
const AGNAI_CONCURRENCY_LIMIT = USER_CONCURRENCY_LIMIT * 5;
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
const MAX_HEARTBEAT_SIZE =
1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
@@ -58,39 +60,20 @@ const QUEUE_JOIN_TIMEOUT = 5000;
function getIdentifier(req: Request) {
if (req.user) return req.user.token;
if (req.risuToken) return req.risuToken;
if (isFromSharedIp(req)) return "shared-ip";
// if (isFromSharedIp(req)) return "shared-ip";
return req.ip;
}
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
getIdentifier(queued) === getIdentifier(incoming);
async function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) {
throw new TooManyRequestsError(
"Your IP or user token already has another request in the queue."
);
}
// shitty hack to remove hpm's event listeners on retried requests
@@ -146,19 +129,7 @@ export async function reenqueueRequest(req: Request) {
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue.filter((req) => getModelFamilyForRequest(req) === partition);
}
export function dequeue(partition: ModelFamily): Request | undefined {
@@ -261,7 +232,6 @@ let waitTimes: {
partition: ModelFamily;
start: number;
end: number;
}[] = [];
/** Adds a successful request to the list of wait times. */
@@ -270,7 +240,6 @@ export function trackWaitTime(req: Request) {
partition: getModelFamilyForRequest(req),
start: req.startTime!,
end: req.queueOutTime ?? Date.now(),
});
}
@@ -296,8 +265,7 @@ function calculateWaitTime(partition: ModelFamily) {
.filter((wait) => {
const isSamePartition = wait.partition === partition;
const isRecent = now - wait.end < 300 * 1000;
return isSamePartition && isRecent;
})
.map((wait) => wait.end - wait.start);
const recentAverage = recentWaits.length
@@ -311,11 +279,7 @@ function calculateWaitTime(partition: ModelFamily) {
);
const currentWaits = queue
.filter((req) => getModelFamilyForRequest(req) === partition)
.map((req) => now - req.startTime!);
const longestCurrentWait = Math.max(...currentWaits, 0);
@@ -1,14 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { config } from "../config";
const ONE_MINUTE_MS = 60 * 1000;
type Timestamp = number;
@@ -20,7 +12,10 @@ const exemptedRequests: Timestamp[] = [];
const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
attempt > now - ONE_MINUTE_MS;
/**
* Returns duration in seconds to wait before retrying for Retry-After header.
*/
const getRetryAfter = (ip: string, type: "text" | "image") => {
const now = Date.now();
const attempts = lastAttempts.get(ip) || [];
const validAttempts = attempts.filter(isRecentAttempt(now));
@@ -29,7 +24,7 @@ const getTryAgainInMs = (ip: string, type: "text" | "image") => {
type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;
if (validAttempts.length >= limit) {
return (validAttempts[0] - now + ONE_MINUTE_MS) / 1000;
} else {
lastAttempts.set(ip, [...validAttempts, now]);
return 0;
@@ -96,22 +91,11 @@ export const ipLimiter = async (
if (!textLimit && !imageLimit) return next();
if (req.user?.type === "special") return next();
const type = (req.baseUrl + req.path).includes("openai-image")
? "image"
: "text";
const path = req.baseUrl + req.path;
const type =
path.includes("openai-image") || path.includes("images/generations")
? "image"
: "text";
const limit = type === "image" ? imageLimit : textLimit;
// If user is authenticated, key rate limiting by their token. Otherwise, key
@@ -123,15 +107,14 @@ export const ipLimiter = async (
res.set("X-RateLimit-Remaining", remaining.toString());
res.set("X-RateLimit-Reset", reset.toString());
const retryAfterTime = getRetryAfter(rateLimitKey, type);
if (retryAfterTime > 0) {
const waitSec = Math.ceil(retryAfterTime).toString();
res.set("Retry-After", waitSec);
res.status(429).json({
error: {
type: "proxy_rate_limited",
message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${waitSec} seconds.`,
},
});
} else {
@@ -1,42 +1,55 @@
import express, { Request, Response, NextFunction } from "express";
import { gatekeeper } from "./gatekeeper";
import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import express from "express";
import { addV1 } from "./add-v1";
import { anthropic } from "./anthropic";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { azure } from "./azure";
import { checkRisuToken } from "./check-risu-token";
import { gatekeeper } from "./gatekeeper";
import { gcp } from "./gcp";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { sendErrorToClient } from "./middleware/response/error-generator";
const proxyRouter = express.Router();
proxyRouter.use((req, _res, next) => {
if (req.headers.expect) {
// node-http-proxy does not like it when clients send `expect: 100-continue`
// and will stall. none of the upstream APIs use this header anyway.
delete req.headers.expect;
}
next();
});
// Apply body parsers.
proxyRouter.use(
express.json({ limit: "100mb" }),
express.urlencoded({ extended: true, limit: "100mb" })
);
// Apply auth/rate limits.
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);
// Initialize request queue metadata.
proxyRouter.use((req, _res, next) => {
req.startTime = Date.now();
req.retryCount = 0;
next();
});
// Proxy endpoints.
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/aws", aws);
proxyRouter.use("/gcp/claude", addV1, gcp);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
@@ -46,7 +59,8 @@ proxyRouter.get("*", (req, res, next) => {
next();
}
});
// Send a fake client error if user specifies an invalid proxy endpoint.
proxyRouter.use((req, res) => {
sendErrorToClient({
req,
@@ -67,11 +81,3 @@ proxyRouter.use((req, res) => {
});
export { proxyRouter as proxyRouter };
@@ -49,6 +49,7 @@ app.use(
// Don't log the prompt text on transform errors
"body.messages",
"body.prompt",
"body.contents",
],
censor: "********",
},
@@ -87,6 +88,15 @@ app.use(blacklist);
app.use(checkOrigin);
app.use("/admin", adminRouter);
app.use((req, _, next) => {
// For whatever reason SillyTavern just ignores the path a user provides
// when using Google AI with reverse proxy. We'll fix it here.
if (req.path.startsWith("/v1beta/models/")) {
req.url = `${config.proxyEndpointRoute}/google-ai${req.url}`;
return next();
}
next();
});
app.use(config.proxyEndpointRoute, proxyRouter);
app.use("/user", userRouter);
if (config.staticServiceInfo) {
@@ -2,8 +2,7 @@ import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
GcpKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
@@ -11,6 +10,7 @@ import {
AnthropicModelFamily,
assertIsKnownModelFamily,
AwsBedrockModelFamily,
GcpModelFamily,
AzureOpenAIModelFamily,
GoogleAIModelFamily,
LLM_SERVICES,
@@ -24,22 +24,16 @@ import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
const CACHE_TTL = 2000;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
const keyIsGcpKey = (k: KeyPoolKey): k is GcpKey => k.service === "gcp";
/** Stats aggregated across all keys for a given service. */
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
@@ -51,10 +45,15 @@ type ModelAggregates = {
overQuota?: number;
pozzed?: number;
awsLogged?: number;
// needed to disambiguate aws-claude family's variants
awsClaude2?: number;
awsSonnet3?: number;
awsSonnet3_5?: number;
awsHaiku?: number;
gcpSonnet?: number;
gcpSonnet35?: number;
gcpHaiku?: number;
queued: number;
queueTime: string;
tokens: number;
};
/** All possible combinations of model family and aggregate type. */
@@ -86,8 +85,10 @@ type AnthropicInfo = BaseFamilyInfo & {
};
type AwsInfo = BaseFamilyInfo & {
privacy?: string;
enabledVariants?: string;
};
type GcpInfo = BaseFamilyInfo & {
enabledVariants?: string;
};
// prettier-ignore
@@ -95,12 +96,11 @@ export type ServiceInfo = {
uptime: number;
endpoints: {
openai?: string;
anthropic?: string;
"anthropic-claude-3"?: string;
"google-ai"?: string;
"mistral-ai"?: string;
"aws"?: string;
gcp?: string;
azure?: string;
"openai-image"?: string;
"azure-image"?: string;
@@ -114,6 +114,7 @@ export type ServiceInfo = {
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
& { [f in AwsBedrockModelFamily]?: AwsInfo }
& { [f in GcpModelFamily]?: GcpInfo }
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
@@ -136,7 +137,6 @@ export type ServiceInfo = {
const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
openai: {
openai: `%BASE%/openai`,
"openai-image": `%BASE%/openai-image`,
},
anthropic: {
@@ -149,7 +149,11 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
"mistral-ai": `%BASE%/mistral-ai`,
},
aws: {
"aws-claude": `%BASE%/aws/claude`,
"aws-mistral": `%BASE%/aws/mistral`,
},
gcp: {
gcp: `%BASE%/gcp/claude`,
},
azure: {
azure: `%BASE%/azure/openai`,
@@ -157,7 +161,7 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
},
};
const familyStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof AllStats, number>();
let cachedInfo: ServiceInfo | undefined;
@@ -174,7 +178,7 @@ export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
.concat("turbo")
);
familyStats.clear();
serviceStats.clear();
keys.forEach(addKeyToAggregates);
@@ -293,131 +297,102 @@ function increment<T extends keyof AllStats | ModelAggregateKey>(
) {
map.set(key, (map.get(key) || 0) + delta);
}
const addToService = increment.bind(null, serviceStats);
const addToFamily = increment.bind(null, familyStats);
function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
increment(
serviceStats,
"mistral-ai__keys",
k.service === "mistral-ai" ? 1 : 0
);
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
addToService("proompts", k.promptCount);
addToService("openai__keys", k.service === "openai" ? 1 : 0);
addToService("anthropic__keys", k.service === "anthropic" ? 1 : 0);
addToService("google-ai__keys", k.service === "google-ai" ? 1 : 0);
addToService("mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
addToService("aws__keys", k.service === "aws" ? 1 : 0);
addToService("gcp__keys", k.service === "gcp" ? 1 : 0);
addToService("azure__keys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
let sumCost = 0;
const incrementGenericFamilyStats = (f: ModelFamily) => {
const tokens = (k as any)[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
addToFamily(`${f}__tokens`, tokens);
addToFamily(`${f}__revoked`, k.isRevoked ? 1 : 0);
addToFamily(`${f}__active`, k.isDisabled ? 0 : 1);
};
switch (k.service) {
case "openai":
if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
addToService("openai__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
k.modelFamilies.forEach((f) => {
incrementGenericFamilyStats(f);
addToFamily(`${f}__trial`, k.isTrial ? 1 : 0);
addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
});
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
case "anthropic":
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
addToService("anthropic__uncheckedKeys", Boolean(k.lastChecked) ? 0 : 1);
k.modelFamilies.forEach((f) => {
incrementGenericFamilyStats(f);
addToFamily(`${f}__trial`, k.tier === "free" ? 1 : 0);
addToFamily(`${f}__overQuota`, k.isOverQuota ? 1 : 0);
addToFamily(`${f}__pozzed`, k.isPozzed ? 1 : 0);
});
break;
}
case "aws": {
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach(incrementGenericFamilyStats);
if (!k.isDisabled) {
// Don't add revoked keys to available AWS variants
k.modelIds.forEach((id) => {
if (id.includes("claude-3-sonnet")) {
addToFamily(`aws-claude__awsSonnet3`, 1);
} else if (id.includes("claude-3-5-sonnet")) {
addToFamily(`aws-claude__awsSonnet3_5`, 1);
} else if (id.includes("claude-3-haiku")) {
addToFamily(`aws-claude__awsHaiku`, 1);
} else if (id.includes("claude-v2")) {
addToFamily(`aws-claude__awsClaude2`, 1);
}
});
}
// Ignore revoked keys for aws logging stats, but include keys where the
// logging status is unknown.
const countAsLogged =
k.lastChecked && !k.isDisabled && k.awsLoggingStatus === "enabled";
addToFamily(`aws-claude__awsLogged`, countAsLogged ? 1 : 0);
break;
}
case "gcp":
if (!keyIsGcpKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach(incrementGenericFamilyStats);
// TODO: add modelIds to GcpKey
break;
// These services don't have any additional stats to track.
case "azure":
case "google-ai":
case "mistral-ai":
k.modelFamilies.forEach(incrementGenericFamilyStats);
break;
default:
assertNever(k.service);
}
increment(serviceStats, "tokens", sumTokens);
increment(serviceStats, "tokenCost", sumCost);
addToService("tokens", sumTokens);
addToService("tokenCost", sumCost);
}
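// The aggregation helpers used above (increment, addToService, addToFamily)
// are defined earlier in this file and are not part of this hunk. A minimal
// sketch of the assumed shape, for reference only:
const serviceStatsSketch = new Map<string, number>();
const familyStatsSketch = new Map<string, number>();
function incrementSketch(map: Map<string, number>, key: string, delta = 1) {
map.set(key, (map.get(key) ?? 0) + delta);
}
const addToServiceSketch = (key: string, delta = 1) =>
incrementSketch(serviceStatsSketch, key, delta);
const addToFamilySketch = (key: string, delta = 1) =>
incrementSketch(familyStatsSketch, key, delta);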
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = familyStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo & GcpInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
activeKeys: familyStats.get(`${family}__active`) || 0,
revokedKeys: familyStats.get(`${family}__revoked`) || 0,
};
// Add service-specific stats to the info object.
@@ -425,8 +400,8 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const service = MODEL_FAMILY_SERVICE[family];
switch (service) {
case "openai":
info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
info.trialKeys = familyStats.get(`${family}__trial`) || 0;
// Delete trial/revoked keys for non-turbo families.
// Trials are turbo 99% of the time, and if a key is invalid we don't
@@ -437,15 +412,25 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
}
break;
case "anthropic":
info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0;
info.trialKeys = familyStats.get(`${family}__trial`) || 0;
info.prefilledKeys = familyStats.get(`${family}__pozzed`) || 0;
break;
case "aws":
if (family === "aws-claude") {
const logged = familyStats.get(`${family}__awsLogged`) || 0;
const variants = new Set<string>();
if (familyStats.get(`${family}__awsClaude2`) || 0)
variants.add("claude2");
if (familyStats.get(`${family}__awsSonnet3`) || 0)
variants.add("sonnet3");
if (familyStats.get(`${family}__awsSonnet3_5`) || 0)
variants.add("sonnet3.5");
if (familyStats.get(`${family}__awsHaiku`) || 0)
variants.add("haiku");
info.enabledVariants = variants.size
? `${Array.from(variants).join(",")}`
: undefined;
if (logged > 0) {
info.privacy = config.allowAwsLogging
? `AWS logging verification inactive. Prompts could be logged.`
@@ -453,6 +438,12 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
}
}
break;
case "gcp":
if (family === "gcp-claude") {
// TODO: implement
info.enabledVariants = "not implemented";
}
break;
}
}
+23 -1
@@ -19,7 +19,12 @@ const AnthropicV1BaseSchema = z
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
metadata: z.object({ user_id: z.string().optional() }).optional(),
tools: z.array(z.any()).optional(),
tool_choice: z.any().optional(),
})
.omit(
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, tool_choice: true }
)
.strip();
// https://docs.anthropic.com/claude/reference/complete_post [deprecated]
@@ -44,6 +49,18 @@ const AnthropicV1MessageMultimodalContentSchema = z.array(
data: z.string(),
}),
}),
z.object({
type: z.literal("tool_use"),
id: z.string(),
name: z.string(),
input: z.object({}).passthrough(),
}),
z.object({
type: z.literal("tool_result"),
tool_use_id: z.string(),
is_error: z.boolean().optional(),
content: z.union([z.string(), z.object({}).passthrough()]).optional(),
}),
])
);
@@ -63,7 +80,12 @@ export const AnthropicV1MessagesSchema = AnthropicV1BaseSchema.merge(
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
system: z
.union([
z.string(),
z.array(z.object({ type: z.literal("text"), text: z.string() })),
])
.optional(),
})
);
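// A minimal request body exercising the new fields above (all values
// hypothetical; `tools` entries are unvalidated by this schema):
const exampleAnthropicToolRequest = {
model: "claude-3-5-sonnet-20240620",
max_tokens: 1024,
system: [{ type: "text", text: "You are a terse assistant." }],
tools: [{ name: "get_weather", input_schema: { type: "object" } }],
messages: [
{ role: "user", content: "What's the weather in Paris?" },
{
role: "assistant",
content: [
{ type: "tool_use", id: "toolu_01", name: "get_weather", input: { city: "Paris" } },
],
},
{
role: "user",
content: [
{ type: "tool_result", tool_use_id: "toolu_01", content: "18C and clear" },
],
},
],
};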
export type AnthropicChatMessage = z.infer<
-181
@@ -1,181 +0,0 @@
import { z } from "zod";
import {
OPENAI_OUTPUT_MAX,
OpenAIV1ChatCompletionSchema,
flattenOpenAIMessageContent,
} from "./openai";
import { APIFormatTransformer } from ".";
// https://docs.cohere.com/reference/chat
export const CohereV1ChatSchema = z
.object({
message: z.string(),
model: z.string().default("command-r-plus"),
stream: z.boolean().default(false).optional(),
preamble: z.string().optional(),
chat_history: z
.array(
// Either a message from a chat participant, or a past tool call
z.union([
z.object({
role: z.enum(["CHATBOT", "SYSTEM", "USER"]),
message: z.string(),
tool_calls: z
.array(z.object({ name: z.string(), parameters: z.any() }))
.optional(),
}),
z.object({
role: z.enum(["TOOL"]),
tool_results: z.array(
z.object({
call: z.object({ name: z.string(), parameters: z.any() }),
outputs: z.array(z.any()),
})
),
}),
])
)
.optional(),
// Don't allow conversation_id as it causes calls to be stateful and we don't
// offer guarantees about which key a user's request will be routed to.
conversation_id: z.literal(undefined).optional(),
prompt_truncation: z
.enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"])
.optional(),
/*
Supporting RAG is complex because documents can be arbitrary size and have
to have embeddings generated, which incurs a cost that is not trivial to
estimate. We don't support it for now.
connectors: z
.array(
z.object({
id: z.string(),
user_access_token: z.string().optional(),
continue_on_failure: z.boolean().default(false).optional(),
options: z.any().optional(),
})
)
.optional(),
search_queries_only: z.boolean().default(false).optional(),
documents: z
.array(
z.object({
id: z.string().optional(),
title: z.string().optional(),
text: z.string(),
_excludes: z.array(z.string()).optional(),
})
)
.optional(),
citation_quality: z.enum(["accurate", "fast"]).optional(),
*/
temperature: z.number().default(0.3).optional(),
max_tokens: z
.number()
.int()
.nullish()
.default(Math.min(OPENAI_OUTPUT_MAX, 4096))
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
max_input_tokens: z.number().int().optional(),
k: z.number().int().min(0).max(500).default(0).optional(),
p: z.number().min(0.01).max(0.99).default(0.75).optional(),
seed: z.number().int().optional(),
stop_sequences: z.array(z.string()).max(5).optional(),
frequency_penalty: z.number().min(0).max(1).default(0).optional(),
presence_penalty: z.number().min(0).max(1).default(0).optional(),
tools: z
.array(
z.object({
name: z.string(),
description: z.string(),
parameter_definitions: z.record(
z.object({
description: z.string().optional(),
type: z.string(),
required: z.boolean().optional().default(false),
})
),
})
)
.optional(),
tool_results: z
.array(
z.object({
call: z.object({
name: z.string(),
parameters: z.record(z.any()),
}),
outputs: z.array(z.record(z.any())),
})
)
.optional(),
// We always force single step to avoid stateful calls or expensive multi-step
// generations when tools are involved.
force_single_step: z.literal(true).default(true).optional(),
})
.strip();
export type CohereChatMessage = NonNullable<
z.infer<typeof CohereV1ChatSchema>["chat_history"]
>[number];
export function flattenCohereMessageContent(
message: CohereChatMessage
): string {
return message.role === "TOOL"
? message.tool_results.map((r) => r.outputs[0].text).join("\n")
: message.message;
}
export const transformOpenAIToCohere: APIFormatTransformer<
typeof CohereV1ChatSchema
> = async (req) => {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({
...body,
model: "gpt-3.5-turbo",
});
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Cohere request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
// Final OAI message becomes the `message` field in Cohere
const message = messages[messages.length - 1];
// If the first message has system role, use it as preamble.
const hasSystemPreamble = messages[0]?.role === "system";
const preamble = hasSystemPreamble
? flattenOpenAIMessageContent(messages[0].content)
: undefined;
const chatHistory = messages.slice(0, -1).map((m) => {
const role: Exclude<CohereChatMessage["role"], "TOOL"> =
m.role === "assistant"
? "CHATBOT"
: m.role === "system"
? "SYSTEM"
: "USER";
const content = flattenOpenAIMessageContent(m.content);
const message = m.name ? `${m.name}: ${content}` : content;
return { role, message };
});
return {
model: rest.model,
preamble,
chat_history: chatHistory,
message: flattenOpenAIMessageContent(message.content),
stop_sequences:
typeof rest.stop === "string" ? [rest.stop] : rest.stop ?? undefined,
max_tokens: rest.max_tokens,
temperature: rest.temperature,
p: rest.top_p,
frequency_penalty: rest.frequency_penalty,
presence_penalty: rest.presence_penalty,
seed: rest.seed,
stream: rest.stream,
};
};
+11 -10
@@ -5,19 +5,20 @@ import {
} from "./openai";
import { APIFormatTransformer } from "./index";
const GoogleAIV1ContentSchema = z.object({
parts: z.array(z.object({ text: z.string() })), // TODO: add other media types
role: z.enum(["user", "model"]).optional(),
});
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z
.object({
model: z.string().max(100), //actually specified in path but we need it for the router
stream: z.boolean().optional().default(false), // also used for router
contents: z.array(GoogleAIV1ContentSchema),
tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).optional(),
systemInstruction: GoogleAIV1ContentSchema.optional(),
generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce
@@ -25,12 +26,12 @@ export const GoogleAIV1GenerateContentSchema = z
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 4096)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string().max(500)).max(5).optional(),
}).default({}),
})
.strip();
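// A minimal body accepted by the updated schema (hypothetical values); note
// that generationConfig may now be omitted entirely thanks to .default({}):
const exampleGenerateContentRequest = GoogleAIV1GenerateContentSchema.parse({
model: "gemini-1.5-pro",
contents: [{ role: "user", parts: [{ text: "Hello" }] }],
systemInstruction: { parts: [{ text: "Reply in one sentence." }] },
});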
export type GoogleAIChatMessage = z.infer<
@@ -103,7 +104,7 @@ export const transformOpenAIToGoogleAI: APIFormatTransformer<
stops = [...new Set(stops)].slice(0, 5);
return {
model: "gemini-pro",
model: req.body.model,
stream: rest.stream,
contents,
tools: [],
+7 -21
@@ -21,8 +21,11 @@ import {
GoogleAIV1GenerateContentSchema,
transformOpenAIToGoogleAI,
} from "./google-ai";
import {
MistralAIV1ChatCompletionsSchema,
MistralAIV1TextCompletionsSchema,
transformMistralChatToText,
} from "./mistral-ai";
export { OpenAIChatMessage } from "./openai";
export {
@@ -34,29 +37,15 @@ export {
export { GoogleAIChatMessage } from "./google-ai";
export { MistralAIChatMessage } from "./mistral-ai";
/** Represents a pair of API formats that can be transformed between. */
type APIPair = `${APIFormat}->${APIFormat}`;
/** Represents a map of API format pairs to transformer functions. */
type TransformerMap = {
[key in APIPair]?: APIFormatTransformer<any>;
};
/**
* Represents a transformer function that takes a Request and returns a Promise
* resolving to a value of the specified Zod schema type.
*
* @template Z The Zod schema type to transform the request into (from api-schemas).
* @param req The incoming Request to transform.
* @returns A Promise resolving to the transformed request body.
*/
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
req: Request
) => Promise<z.infer<Z>>;
/**
* Specifies possible translations between API formats and the corresponding
* transformer functions to apply them.
*/
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
@@ -64,12 +53,9 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"openai->openai-text": transformOpenAIToOpenAIText,
"openai->openai-image": transformOpenAIToOpenAIImage,
"openai->google-ai": transformOpenAIToGoogleAI,
"openai->cohere-chat": transformOpenAIToCohere,
"mistral-ai->mistral-text": transformMistralChatToText,
};
/**
* Specifies the schema for each API format to validate incoming requests.
*/
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"anthropic-chat": AnthropicV1MessagesSchema,
"anthropic-text": AnthropicV1TextSchema,
@@ -78,5 +64,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
"cohere-chat": CohereV1ChatSchema,
"mistral-text": MistralAIV1TextCompletionsSchema,
};
+120 -14
@@ -1,15 +1,34 @@
import { z } from "zod";
import { OPENAI_OUTPUT_MAX } from "./openai";
import { Template } from "@huggingface/jinja";
import { APIFormatTransformer } from "./index";
import { logger } from "../../logger";
const MistralChatMessageSchema = z.object({
role: z.enum(["system", "user", "assistant", "tool"]), // TODO: implement tools
content: z.string(),
prefix: z.boolean().optional(),
});
const MistralMessagesSchema = z.array(MistralChatMessageSchema).refine(
(input) => {
const prefixIdx = input.findIndex((msg) => Boolean(msg.prefix));
if (prefixIdx === -1) return true; // no prefix messages
const lastIdx = input.length - 1;
const lastMsg = input[lastIdx];
return prefixIdx === lastIdx && lastMsg.role === "assistant";
},
{
message:
"`prefix` can only be set to `true` on the last message, and only for an assistant message.",
}
);
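// Examples of the refinement's behavior (illustrative only):
// OK: `prefix` appears only on the final message, which is an assistant turn.
// MistralMessagesSchema.parse([
//   { role: "user", content: "Hi" },
//   { role: "assistant", content: "Sure,", prefix: true },
// ]);
// Rejected: `prefix` on a non-final or non-assistant message.
// MistralMessagesSchema.parse([
//   { role: "user", content: "Hi", prefix: true },
//   { role: "assistant", content: "Sure," },
// ]);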
// https://docs.mistral.ai/api#operation/createChatCompletion
const BaseMistralAIV1CompletionsSchema = z.object({
model: z.string(),
messages: MistralMessagesSchema.optional(),
prompt: z.string().optional(),
temperature: z.number().optional().default(0.7),
top_p: z.number().optional().default(1),
max_tokens: z.coerce
@@ -18,12 +37,50 @@ export const MistralAIV1ChatCompletionsSchema = z.object({
.nullish()
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
stream: z.boolean().optional().default(false),
// Mistral docs say that `stop` can be a string or array but AWS Mistral
// blows up if a string is passed. We must convert it to an array.
stop: z
.union([z.string(), z.array(z.string())])
.optional()
.default([])
.transform((v) => (Array.isArray(v) ? v : [v])),
random_seed: z.number().int().min(0).optional(),
response_format: z
.object({ type: z.enum(["text", "json_object"]) })
.optional(),
safe_prompt: z.boolean().optional().default(false),
});
export const MistralAIV1ChatCompletionsSchema =
BaseMistralAIV1CompletionsSchema.and(
z.object({ messages: MistralMessagesSchema })
);
export const MistralAIV1TextCompletionsSchema =
BaseMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));
/*
Slightly more strict version that only allows a subset of the parameters. AWS
Mistral helpfully returns no details if unsupported parameters are passed so
this list comes from trial and error as of 2024-08-12.
*/
const BaseAWSMistralAIV1CompletionsSchema =
BaseMistralAIV1CompletionsSchema.pick({
temperature: true,
top_p: true,
max_tokens: true,
stop: true,
random_seed: true,
// response_format: true,
// safe_prompt: true,
}).strip();
export const AWSMistralV1ChatCompletionsSchema =
BaseAWSMistralAIV1CompletionsSchema.and(
z.object({ messages: MistralMessagesSchema })
);
export const AWSMistralV1TextCompletionsSchema =
BaseAWSMistralAIV1CompletionsSchema.and(z.object({ prompt: z.string() }));
export type MistralAIChatMessage = z.infer<typeof MistralChatMessageSchema>;
export function fixMistralPrompt(
messages: MistralAIChatMessage[]
@@ -31,12 +88,11 @@ export function fixMistralPrompt(
// Mistral uses OpenAI format but has some additional requirements:
// - Only one system message per request, and it must be the first message if
// present.
// - Final message must be a user message.
// - Final message must be a user message, unless it has `prefix: true`.
// - Cannot have multiple messages from the same role in a row.
// While frontends should be able to handle this, we can fix it here in the
// meantime.
const fixed = messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
if (acc.length === 0) {
acc.push(msg);
return acc;
@@ -57,4 +113,54 @@ export function fixMistralPrompt(
}
return acc;
}, []);
// If the last message is an assistant message, mark it as a prefix. An
// assistant message at the end of the conversation without `prefix: true`
// results in an error.
if (fixed[fixed.length - 1].role === "assistant") {
fixed[fixed.length - 1].prefix = true;
}
return fixed;
}
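// Illustrative use of fixMistralPrompt on a hypothetical conversation. The
// reducer body is partially elided in this hunk, but consecutive same-role
// messages are presumably coalesced, and a trailing assistant message gains
// prefix: true per the block above:
const fixedExample = fixMistralPrompt([
{ role: "user", content: "Write a haiku." },
{ role: "user", content: "About autumn." },
{ role: "assistant", content: "Crisp leaves drift down" },
]);
// => last element: { role: "assistant", content: "Crisp leaves drift down", prefix: true }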
let jinjaTemplate: Template;
let renderTemplate: (messages: MistralAIChatMessage[]) => string;
function renderMistralPrompt(messages: MistralAIChatMessage[]) {
if (!jinjaTemplate) {
logger.warn("Lazy loading mistral chat template...");
const { chatTemplate, bosToken, eosToken } =
require("./templates/mistral-template").MISTRAL_TEMPLATE;
jinjaTemplate = new Template(chatTemplate);
renderTemplate = (messages) =>
jinjaTemplate.render({
messages,
bos_token: bosToken,
eos_token: eosToken,
});
}
return renderTemplate(messages);
}
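// For illustration, a short conversation rendered through the bundled template
// (see templates/mistral-template below) comes out roughly as:
// renderMistralPrompt([
//   { role: "system", content: "Be brief." },
//   { role: "user", content: "Hi" },
//   { role: "assistant", content: "Hello!" },
//   { role: "user", content: "Name a fruit." },
// ])
// => "<s>[INST] Hi[/INST] Hello!</s>[INST] Be brief.\n\nName a fruit.[/INST]"
// The system message is folded into the final user turn; a trailing assistant
// message with prefix: true would be emitted without its eos token.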
/**
* Attempts to convert a Mistral chat completions request to a text completions,
* using the official prompt template published by Mistral.
*/
export const transformMistralChatToText: APIFormatTransformer<
typeof MistralAIV1TextCompletionsSchema
> = async (req) => {
const { body } = req;
const result = MistralAIV1ChatCompletionsSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid Mistral chat completions request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = renderMistralPrompt(messages);
return { ...rest, prompt, messages: undefined };
};
+1 -1
@@ -52,7 +52,7 @@ export const OpenAIV1ChatCompletionSchema = z
.number()
.int()
.nullish()
.default(Math.min(OPENAI_OUTPUT_MAX, 16384))
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
@@ -0,0 +1,36 @@
export const MISTRAL_TEMPLATE = {
bosToken: "<s>",
eosToken: "</s>",
chatTemplate: `{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "assistant" %}
{%- if loop.last and message.prefix is defined and message.prefix %}
{{- " " + message["content"] }}
{%- else %}
{{- " " + message["content"] + eos_token}}
{%- endif %}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}`,
};
+18
@@ -0,0 +1,18 @@
/** Module for generating and verifying HMAC signatures. */
import crypto from "crypto";
import { SECRET_SIGNING_KEY } from "../config";
/**
* Generates a HMAC signature for the given message. Optionally salts the
* key with a provided string.
*/
export function signMessage(msg: any, salt: string = ""): string {
const hmac = crypto.createHmac("sha256", SECRET_SIGNING_KEY + salt);
if (typeof msg === "object") {
hmac.update(JSON.stringify(msg));
} else {
hmac.update(msg);
}
return hmac.digest("hex");
}
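// Example verification counterpart (a sketch; the real module may expose a
// different helper). Comparing with timingSafeEqual avoids timing leaks:
export function verifyMessage(msg: any, signature: string, salt = ""): boolean {
const expected = Buffer.from(signMessage(msg, salt), "hex");
const actual = Buffer.from(signature, "hex");
return expected.length === actual.length && crypto.timingSafeEqual(expected, actual);
}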
+2 -2
@@ -1,9 +1,9 @@
import { doubleCsrf } from "csrf-csrf";
import express from "express";
import { config, SECRET_SIGNING_KEY } from "../config";
const { generateToken, doubleCsrfProtection } = doubleCsrf({
getSecret: () => SECRET_SIGNING_KEY,
cookieName: "csrf",
cookieOptions: {
sameSite: "strict",
+3 -2
@@ -7,8 +7,9 @@ import * as userStore from "./users/user-store";
export const injectLocals: RequestHandler = (req, res, next) => {
// config-related locals
const quota = config.tokenQuota;
const sumOfQuotas = Object.values(quota).reduce((a, b) => a + b, 0);
res.locals.quotasEnabled = sumOfQuotas > 0;
res.locals.quota = quota;
res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
res.locals.persistenceEnabled = config.gatekeeperStore !== "memory";
@@ -122,7 +122,7 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
@@ -1,5 +1,5 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { AnthropicModelFamily, getClaudeModelFamily } from "../../models";
@@ -23,10 +23,6 @@ type AnthropicKeyUsage = {
export interface AnthropicKey extends Key, AnthropicKeyUsage {
readonly service: "anthropic";
readonly modelFamilies: AnthropicModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/**
* Whether this key requires a special preamble. For unclear reasons, some
* Anthropic keys will throw an error if the prompt does not begin with a
@@ -217,22 +213,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
key[`${getClaudeModelFamily(model)}Tokens`] += tokens;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
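// createGenericGetLockoutPeriod itself is outside this diff. Judging from the
// per-provider implementations it replaces, a sketch of the factored-out
// helper (Key and ModelFamily imports assumed):
export function createGenericGetLockoutPeriodSketch<K extends Key>(
getKeys: () => K[]
) {
return (family?: ModelFamily): number => {
const activeKeys = getKeys().filter(
(k) => !k.isDisabled && (!family || k.modelFamilies.includes(family))
);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
if (rateLimitedKeys.length < activeKeys.length) return 0;
// If all keys are rate-limited, return time until the first key is ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
};
}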
+224 -46
@@ -1,13 +1,31 @@
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios";
import { URL } from "url";
import { config } from "../../../config";
import { getAwsBedrockModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
type ParentModelId = string;
type AliasModelId = string;
type ModuleAliasTuple = [ParentModelId, ...AliasModelId[]];
const KNOWN_MODEL_IDS: ModuleAliasTuple[] = [
["anthropic.claude-v2", "anthropic.claude-v2:1"],
["anthropic.claude-3-sonnet-20240229-v1:0"],
["anthropic.claude-3-haiku-20240307-v1:0"],
["anthropic.claude-3-opus-20240229-v1:0"],
["anthropic.claude-3-5-sonnet-20240620-v1:0"],
["mistral.mistral-7b-instruct-v0:2"],
["mistral.mixtral-8x7b-instruct-v0:1"],
["mistral.mistral-large-2402-v1:0"],
["mistral.mistral-large-2407-v1:0"],
["mistral.mistral-small-2402-v1:0"], // Seems to return 400
];
const KEY_CHECK_BATCH_SIZE = 2; // AWS checker needs to do lots of concurrent requests so should lower the batch size
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const AMZ_HOST =
@@ -15,6 +33,8 @@ const AMZ_HOST =
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const GET_LIST_INFERENCE_PROFILES_URL = (region: string) =>
`https://bedrock.${region}.amazonaws.com/inference-profiles?maxResults=1000`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
`https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
const TEST_MESSAGES = [
@@ -24,6 +44,22 @@ const TEST_MESSAGES = [
type AwsError = { error: {} };
type GetInferenceProfilesResponse = {
inferenceProfileSummaries: {
inferenceProfileId: string;
inferenceProfileName: string;
inferenceProfileArn: string;
description?: string;
createdAt?: string;
updatedAt?: string;
status: "ACTIVE" | unknown;
type: "SYSTEM_DEFINED" | unknown;
models: {
modelArn?: string;
}[];
}[];
};
type GetLoggingConfigResponse = {
loggingConfig: null | {
cloudWatchConfig: null | unknown;
@@ -42,54 +78,67 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
service: "aws",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
keyCheckBatchSize: KEY_CHECK_BATCH_SIZE,
updateKey,
});
}
protected async testKeyOrFail(key: AwsBedrockKey) {
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
try {
await this.checkInferenceProfiles(key);
} catch (e) {
const asError = e as AxiosError<AwsError>;
const data = asError.response?.data;
this.log.warn(
{ key: key.hash, error: e.message, data },
"Cannot list inference profiles.\n\
Principal may be missing `AmazonBedrockFullAccess`, or has no policy allowing action `bedrock:ListInferenceProfiles` against resource `arn:aws:bedrock:*:*:inference-profile/*`.\n\
Requests will be made without inference profiles using on-demand quotas, which may be subject to more restrictive rate limits.\n\
See https://docs.aws.amazon.com/bedrock/latest/userguide/cross-region-inference-prereq.html."
);
}
// Perform checks for all parent model IDs
const results = await Promise.all(
KNOWN_MODEL_IDS.filter(([model]) =>
// Skip checks for models that are disabled anyway
config.allowedModelFamilies.includes(getAwsBedrockModelFamily(model))
).map(async ([model, ...aliases]) => ({
models: [model, ...aliases],
success: await this.invokeModel(model, key),
}))
);
// Filter out models that are disabled
const modelIds = results
.filter(({ success }) => success)
.flatMap(({ models }) => models);
if (modelIds.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
modelIds,
modelFamilies: Array.from(
new Set(modelIds.map(getAwsBedrockModelFamily))
),
});
}
// Logging status is checked on every cycle.
await this.checkLoggingConfiguration(key);
this.log.info(
{
key: key.hash,
logged: key.awsLoggingStatus,
families: key.modelFamilies,
models: key.modelIds,
},
"Checked key."
);
@@ -160,7 +209,52 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
* key has access to the model, false if it does not. Throws an error if the
* key is disabled.
*/
private async invokeModel(
model: string,
key: AwsBedrockKey
): Promise<boolean> {
if (model.includes("claude")) {
// If inference profiles are available, try testing model with them.
// If they are not available or the invocation fails with the inference
// profile, fall back to regular model ID.
const { region } = AwsKeyChecker.getCredentialsFromKey(key);
const continent = region.split("-")[0];
const profile = key.inferenceProfileIds.find(
(id) => `${continent}.${model}` === id
);
if (profile) {
this.log.debug(
{ key: key.hash, model, profile },
"Testing model via inference profile."
);
let result: boolean;
try {
result = await this.testClaudeModel(key, profile);
} catch (e) {
this.log.error(
{ key: key.hash, model, profile, error: e.message },
"Error testing model with inference profile; trying model ID directly."
);
result = false;
}
// If the profile worked, we'll return success. Caller will add the
// model (not the profile) to the list of enabled models, but the
// profile will be used when the key is used for inference.
if (result) return true;
}
return this.testClaudeModel(key, model);
} else if (model.includes("mistral")) {
return this.testMistralModel(key, model);
}
throw new Error("AwsKeyChecker#invokeModel: no implementation for model");
}
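// Illustration of the continent-prefix match above: for a key in us-east-1,
// the continent is "us", so model "anthropic.claude-3-sonnet-20240229-v1:0"
// pairs with inference profile id "us.anthropic.claude-3-sonnet-20240229-v1:0"
// (SYSTEM_DEFINED cross-region profiles follow this "<continent>.<model>"
// naming). The profile id is what gets invoked; the bare model id is what
// gets recorded as enabled.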
private async testClaudeModel(
key: AwsBedrockKey,
model: string
): Promise<boolean> {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
// This is not a valid invocation payload, but a 400 response indicates that
// the principal at least has permission to invoke the model.
@@ -175,7 +269,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
method: "POST",
url: POST_INVOKE_MODEL_URL(creds.region, model),
data: payload,
validateStatus: (status) => [400, 403, 404].includes(status),
};
config.headers = new AxiosHeaders({
"content-type": "application/json",
@@ -187,11 +281,26 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message;
// We only allow one type of 403 error, and we only allow it for one model.
// This message indicates the key is valid but this particular model is not
// accessible. Other 403s may indicate the key is not usable.
if (
status === 403 &&
errorMessage?.match(/access to the model with the specified model ID/)
) {
this.log.debug(
{ key: key.hash, model, errorType, data, status, headers },
"Model is not available (principal does not have access)."
);
return false;
}
// ResourceNotFound typically indicates that the tested model cannot be used
// on the configured region for this set of credentials.
if (status === 404) {
this.log.debug(
{ region: creds.region, model, key: key.hash },
"Model is not available (not supported in this AWS region)."
);
return false;
}
@@ -200,23 +309,91 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
const correctErrorType = errorType === "ValidationException";
const correctErrorMessage = errorMessage?.match(/max_tokens/);
if (!correctErrorType || !correctErrorMessage) {
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is not available (request rejected)."
);
return false;
// throw new AxiosError(
// `Unexpected error when invoking model ${model}: ${errorMessage}`,
// "AWS_ERROR",
// response.config,
// response.request,
// response
// );
}
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"AWS InvokeModel test successful."
"Model is available."
);
return true;
}
private async testMistralModel(
key: AwsBedrockKey,
model: string
): Promise<boolean> {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const payload = {
max_tokens: -1,
prompt: "<s>[INST] What is your favourite condiment? [/INST]</s>",
};
const config: AxiosRequestConfig = {
method: "POST",
url: POST_INVOKE_MODEL_URL(creds.region, model),
data: payload,
validateStatus: (status) => [400, 403, 404].includes(status),
headers: {
"content-type": "application/json",
accept: "*/*",
},
};
await AwsKeyChecker.signRequestForAws(config, key);
const response = await axios.request(config);
const { data, status, headers } = response;
const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
const errorMessage = data?.message;
if (status === 403 || status === 404) {
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is not available (no access or unsupported region)."
);
return false;
}
const isBadRequest = status === 400;
const isValidationError = errorMessage?.match(/validation error/i);
if (isBadRequest && !isValidationError) {
this.log.debug(
{ key: key.hash, model, errorType, data, status, headers },
"Model is not available (request rejected)."
);
return false;
}
this.log.debug(
{ key: key.hash, model, errorType, data, status },
"Model is available."
);
return true;
}
private async checkInferenceProfiles(key: AwsBedrockKey) {
const creds = AwsKeyChecker.getCredentialsFromKey(key);
const req: AxiosRequestConfig = {
method: "GET",
url: GET_LIST_INFERENCE_PROFILES_URL(creds.region),
headers: { accept: "application/json" },
};
await AwsKeyChecker.signRequestForAws(req, key);
const { data } = await axios.request<GetInferenceProfilesResponse>(req);
const { inferenceProfileSummaries } = data;
const profileIds = inferenceProfileSummaries.map(
(p) => p.inferenceProfileId
);
this.log.debug(
{ key: key.hash, profileIds, region: creds.region },
"Inference profiles found."
);
this.updateKey(key.hash, { inferenceProfileIds: profileIds });
}
private async checkLoggingConfiguration(key: AwsBedrockKey) {
if (config.allowAwsLogging) {
// Don't check logging status if we're allowing it to reduce API calls.
@@ -285,7 +462,8 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
method,
protocol: "https:",
hostname: url.hostname,
path: url.pathname,
query: Object.fromEntries(url.searchParams),
headers: { Host: url.hostname, ...plainHeaders },
});
+56 -60
@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { PaymentRequiredError } from "../../errors";
import { AwsBedrockModelFamily, getAwsBedrockModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AwsKeyChecker } from "./checker";
type AwsBedrockKeyUsage = {
[K in AwsBedrockModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type AwsBedrockKeyUsage = {
export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
readonly service: "aws";
readonly modelFamilies: AwsBedrockModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/**
* The confirmed logging status of this key. This is "unknown" until we
* receive a response from the AWS API. Keys which are logged, or not
@@ -24,8 +21,8 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
* set.
*/
awsLoggingStatus: "unknown" | "disabled" | "enabled";
modelIds: string[];
inferenceProfileIds: string[];
}
/**
@@ -75,10 +72,14 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
modelIds: ["anthropic.claude-3-sonnet-20240229-v1:0"],
inferenceProfileIds: [],
["aws-claudeTokens"]: 0,
["aws-claude-opusTokens"]: 0,
["aws-mistral-tinyTokens"]: 0,
["aws-mistral-smallTokens"]: 0,
["aws-mistral-mediumTokens"]: 0,
["aws-mistral-largeTokens"]: 0,
};
this.keys.push(newKey);
}
@@ -97,51 +98,61 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
}
public get(model: string) {
let neededVariantId = model;
// This function accepts both Anthropic/Mistral IDs and AWS IDs.
// Generally all AWS model IDs are supersets of the original vendor IDs.
// Claude 2 is the only model that breaks this convention; Anthropic calls
// it claude-2 but AWS calls it claude-v2.
if (model.includes("claude-2")) neededVariantId = "claude-v2";
const neededFamily = getAwsBedrockModelFamily(model);
const availableKeys = this.keys.filter((k) => {
// Select keys which
return (
// are enabled
!k.isDisabled &&
// are not logged, unless policy allows it
(config.allowAwsLogging || k.awsLoggingStatus !== "enabled") &&
// have access to the model family we need
k.modelFamilies.includes(neededFamily) &&
// have access to the specific variant we need
k.modelIds.some((m) => m.includes(neededVariantId))
);
});
this.log.debug(
{
requestedModel: model,
selectedVariant: neededVariantId,
selectedFamily: neededFamily,
totalKeys: this.keys.length,
availableKeys: availableKeys.length,
},
"Selecting AWS key"
);
if (availableKeys.length === 0) {
throw new PaymentRequiredError(
`No AWS Bedrock keys available for model ${model}`
);
}
/**
* Comparator for prioritizing keys on inference profile compatibility.
* Requests made via inference profiles have higher rate limits so we want
* to use keys with compatible inference profiles first.
*/
const hasInferenceProfile = (
a: AwsBedrockKey,
b: AwsBedrockKey
) => {
const aMatch = +a.inferenceProfileIds.some((p) => p.includes(model));
const bMatch = +b.inferenceProfileIds.some((p) => p.includes(model));
return aMatch - bMatch;
};
const selectedKey = prioritizeKeys(availableKeys, hasInferenceProfile)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
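// prioritizeKeys is outside this diff. Based on the sort logic it replaces, a
// sketch of the assumed behavior; the real module's ordering of the custom
// comparator relative to the built-in criteria may differ, and
// RATE_LIMIT_LOCKOUT is defined earlier in this file:
function prioritizeKeysSketch<K extends Key>(
keys: K[],
tiebreaker?: (a: K, b: K) => number
): K[] {
const now = Date.now();
return keys.slice().sort((a, b) => {
// Keys that were not rate limited recently come first...
const aLimited = +(now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT);
const bLimited = +(now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT);
if (aLimited !== bLimited) return aLimited - bLimited;
// ...then the caller-provided comparator breaks ties...
const t = tiebreaker?.(a, b) ?? 0;
if (t !== 0) return t;
// ...then least recently used.
return a.lastUsed - b.lastUsed;
});
}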
@@ -169,22 +180,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
key[`${getAwsBedrockModelFamily(model)}Tokens`] += tokens;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+10 -52
@@ -1,10 +1,13 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { PaymentRequiredError } from "../../errors";
import { logger } from "../../../logger";
import { PaymentRequiredError } from "../../errors";
import {
AzureOpenAIModelFamily,
getAzureOpenAIModelFamily,
} from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { AzureOpenAIKeyChecker } from "./checker";
type AzureOpenAIKeyUsage = {
@@ -14,10 +17,6 @@ type AzureOpenAIKeyUsage = {
export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
readonly service: "azure";
readonly modelFamilies: AzureOpenAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
contentFiltering: boolean;
}
@@ -105,30 +104,8 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
);
}
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -156,26 +133,7 @@ export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+294
@@ -0,0 +1,294 @@
import axios, { AxiosError } from "axios";
import crypto from "crypto";
import { KeyCheckerBase } from "../key-checker-base";
import type { GcpKey, GcpKeyProvider } from "./provider";
import { GcpModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
const POST_STREAM_RAW_URL = (project: string, region: string, model: string) =>
`https://${GCP_HOST.replace(
"%REGION%",
region
)}/v1/projects/${project}/locations/${region}/publishers/anthropic/models/${model}:streamRawPredict`;
const TEST_MESSAGES = [
{ role: "user", content: "Hi!" },
{ role: "assistant", content: "Hello!" },
];
type UpdateFn = typeof GcpKeyProvider.prototype.update;
export class GcpKeyChecker extends KeyCheckerBase<GcpKey> {
constructor(keys: GcpKey[], updateKey: UpdateFn) {
super(keys, {
service: "gcp",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: GcpKey) {
let checks: Promise<boolean>[] = [];
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
checks = [
this.invokeModel("claude-3-haiku@20240307", key, true),
this.invokeModel("claude-3-sonnet@20240229", key, true),
this.invokeModel("claude-3-opus@20240229", key, true),
this.invokeModel("claude-3-5-sonnet@20240620", key, true),
];
// Destructuring order matches the order of the checks array above.
const [haiku, sonnet, opus, sonnet35] = await Promise.all(checks);
this.log.debug(
{ key: key.hash, sonnet, haiku, opus, sonnet35 },
"GCP model initial tests complete."
);
const families: GcpModelFamily[] = [];
if (sonnet || sonnet35 || haiku) families.push("gcp-claude");
if (opus) families.push("gcp-claude-opus");
if (families.length === 0) {
this.log.warn(
{ key: key.hash },
"Key does not have access to any models; disabling."
);
return this.updateKey(key.hash, { isDisabled: true });
}
this.updateKey(key.hash, {
sonnetEnabled: sonnet,
haikuEnabled: haiku,
sonnet35Enabled: sonnet35,
modelFamilies: families,
});
} else {
if (key.haikuEnabled) {
await this.invokeModel("claude-3-haiku@20240307", key, false);
} else if (key.sonnetEnabled) {
await this.invokeModel("claude-3-sonnet@20240229", key, false);
} else if (key.sonnet35Enabled) {
await this.invokeModel("claude-3-5-sonnet@20240620", key, false);
} else {
await this.invokeModel("claude-3-opus@20240229", key, false);
}
this.updateKey(key.hash, { lastChecked: Date.now() });
this.log.debug({ key: key.hash }, "GCP key check complete.");
}
this.log.info(
{
key: key.hash,
families: key.modelFamilies,
},
"Checked key."
);
}
protected handleAxiosError(key: GcpKey, error: AxiosError) {
if (error.response && GcpKeyChecker.errorIsGcpError(error)) {
const { status, data } = error.response;
if (status === 400 || status === 401 || status === 403) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
} else if (status === 429) {
this.log.warn(
{ key: key.hash, error: data },
"Key is rate limited. Rechecking in a minute."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000);
this.updateKey(key.hash, { lastChecked: next });
} else {
this.log.error(
{ key: key.hash, status, error: data },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
}
return;
}
const { response, cause } = error;
const { headers, status, data } = response ?? {};
this.log.error(
{ key: key.hash, status, headers, data, cause, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
/**
* Attempt to invoke the given model with the given key. Returns true if the
* key has access to the model, false if it does not. Throws an error if the
* key is disabled.
*/
private async invokeModel(model: string, key: GcpKey, initial: boolean) {
const creds = GcpKeyChecker.getCredentialsFromKey(key);
const signedJWT = await GcpKeyChecker.createSignedJWT(
creds.clientEmail,
creds.privateKey
);
const [accessToken, jwtError] =
await GcpKeyChecker.exchangeJwtForAccessToken(signedJWT);
if (accessToken === null) {
this.log.warn(
{ key: key.hash, jwtError },
"Unable to get the access token"
);
return false;
}
const payload = {
max_tokens: 1,
messages: TEST_MESSAGES,
anthropic_version: "vertex-2023-10-16",
};
const { data, status } = await axios.post(
POST_STREAM_RAW_URL(creds.projectId, creds.region, model),
payload,
{
headers: GcpKeyChecker.getRequestHeaders(accessToken),
validateStatus: initial
? () => true
: (status: number) => status >= 200 && status < 300,
}
);
this.log.debug({ key: key.hash, data }, "Response from GCP");
if (initial) {
return (
(status >= 200 && status < 300) || status === 429 || status === 529
);
}
return true;
}
static errorIsGcpError(error: AxiosError): error is AxiosError {
const data = error.response?.data as any;
if (Array.isArray(data)) {
return data.length > 0 && data[0]?.error?.message;
} else {
return data?.error?.message;
}
}
static async createSignedJWT(email: string, pkey: string): Promise<string> {
let cryptoKey = await crypto.subtle.importKey(
"pkcs8",
GcpKeyChecker.str2ab(atob(pkey)),
{ name: "RSASSA-PKCS1-v1_5", hash: { name: "SHA-256" } },
false,
["sign"]
);
const authUrl = "https://www.googleapis.com/oauth2/v4/token";
const issued = Math.floor(Date.now() / 1000);
const expires = issued + 600;
const header = { alg: "RS256", typ: "JWT" };
const payload = {
iss: email,
aud: authUrl,
iat: issued,
exp: expires,
scope: "https://www.googleapis.com/auth/cloud-platform",
};
const encodedHeader = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(header)
);
const encodedPayload = GcpKeyChecker.urlSafeBase64Encode(
JSON.stringify(payload)
);
const unsignedToken = `${encodedHeader}.${encodedPayload}`;
const signature = await crypto.subtle.sign(
"RSASSA-PKCS1-v1_5",
cryptoKey,
GcpKeyChecker.str2ab(unsignedToken)
);
const encodedSignature = GcpKeyChecker.urlSafeBase64Encode(signature);
return `${unsignedToken}.${encodedSignature}`;
}
static async exchangeJwtForAccessToken(
signed_jwt: string
): Promise<[string | null, string]> {
const auth_url = "https://www.googleapis.com/oauth2/v4/token";
const params = {
grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
assertion: signed_jwt,
};
const r = await fetch(auth_url, {
method: "POST",
headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: Object.entries(params)
.map(([k, v]) => `${k}=${v}`)
.join("&"),
}).then((res) => res.json());
if (r.access_token) {
return [r.access_token, ""];
}
return [null, JSON.stringify(r)];
}
static str2ab(str: string): ArrayBuffer {
const buffer = new ArrayBuffer(str.length);
const bufferView = new Uint8Array(buffer);
for (let i = 0; i < str.length; i++) {
bufferView[i] = str.charCodeAt(i);
}
return buffer;
}
static urlSafeBase64Encode(data: string | ArrayBuffer): string {
let base64: string;
if (typeof data === "string") {
base64 = btoa(
encodeURIComponent(data).replace(/%([0-9A-F]{2})/g, (match, p1) =>
String.fromCharCode(parseInt("0x" + p1, 16))
)
);
} else {
base64 = btoa(String.fromCharCode(...new Uint8Array(data)));
}
return base64.replace(/\+/g, "-").replace(/\//g, "_").replace(/=+$/, "");
}
static getRequestHeaders(accessToken: string) {
return {
Authorization: `Bearer ${accessToken}`,
"Content-Type": "application/json",
};
}
static getCredentialsFromKey(key: GcpKey) {
const [projectId, clientEmail, region, rawPrivateKey] = key.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) {
throw new Error("Invalid GCP key");
}
const privateKey = rawPrivateKey
.replace(
/-----BEGIN PRIVATE KEY-----|-----END PRIVATE KEY-----|\r|\n|\\n/g,
""
)
.trim();
return { projectId, clientEmail, region, privateKey };
}
}
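// End-to-end sketch of the auth flow the checker uses (the helpers are the
// real static methods above; this wrapper itself is illustrative):
async function getGcpAccessTokenSketch(key: GcpKey): Promise<string> {
const { clientEmail, privateKey } = GcpKeyChecker.getCredentialsFromKey(key);
// The signed JWT is only valid for ten minutes per the `exp` claim above.
const jwt = await GcpKeyChecker.createSignedJWT(clientEmail, privateKey);
const [token, error] = await GcpKeyChecker.exchangeJwtForAccessToken(jwt);
if (!token) throw new Error(`JWT exchange failed: ${error}`);
return token;
}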
+202
@@ -0,0 +1,202 @@
import crypto from "crypto";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { PaymentRequiredError } from "../../errors";
import { GcpModelFamily, getGcpModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GcpKeyChecker } from "./checker";
type GcpKeyUsage = {
[K in GcpModelFamily as `${K}Tokens`]: number;
};
export interface GcpKey extends Key, GcpKeyUsage {
readonly service: "gcp";
readonly modelFamilies: GcpModelFamily[];
sonnetEnabled: boolean;
haikuEnabled: boolean;
sonnet35Enabled: boolean;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 4000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class GcpKeyProvider implements KeyProvider<GcpKey> {
readonly service = "gcp";
private keys: GcpKey[] = [];
private checker?: GcpKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.gcpCredentials?.trim();
if (!keyConfig) {
this.log.warn(
"GCP_CREDENTIALS is not set. GCP API will not be available."
);
return;
}
const bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: GcpKey = {
key,
service: this.service,
modelFamilies: ["gcp-claude"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `gcp-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
sonnetEnabled: true,
haikuEnabled: false,
sonnet35Enabled: false,
["gcp-claudeTokens"]: 0,
["gcp-claude-opusTokens"]: 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded GCP keys.");
}
public init() {
if (config.checkKeys) {
this.checker = new GcpKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(model: string) {
const neededFamily = getGcpModelFamily(model);
// this is a horrible mess
// each of these should be separate model families, but adding model
// families is not low enough friction for the rate at which gcp claude
// model variants are added.
const needsSonnet35 =
model.includes("claude-3-5-sonnet") && neededFamily === "gcp-claude";
const needsSonnet =
!needsSonnet35 &&
model.includes("sonnet") &&
neededFamily === "gcp-claude";
const needsHaiku = model.includes("haiku") && neededFamily === "gcp-claude";
const availableKeys = this.keys.filter((k) => {
return (
!k.isDisabled &&
(k.sonnetEnabled || !needsSonnet) && // sonnet and haiku are both under gcp-claude, while opus is not
(k.haikuEnabled || !needsHaiku) &&
(k.sonnet35Enabled || !needsSonnet35) &&
k.modelFamilies.includes(neededFamily)
);
});
this.log.debug(
{
model,
neededFamily,
needsSonnet,
needsHaiku,
needsSonnet35,
availableKeys: availableKeys.length,
totalKeys: this.keys.length,
},
"Selecting GCP key"
);
if (availableKeys.length === 0) {
throw new PaymentRequiredError(
`No GCP keys available for model ${model}`
);
}
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: GcpKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<GcpKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key[`${getGcpModelFamily(model)}Tokens`] += tokens;
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve, so all we can do is wait a bit and try
* again. We will lock the key for RATE_LIMIT_LOCKOUT milliseconds (currently
* 4 seconds) after getting a 429 before retrying, to give the other requests
* a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {
this.keys.forEach(({ hash }) =>
this.update(hash, { lastChecked: 0, isDisabled: false, isRevoked: false })
);
this.checker?.scheduleNextCheck();
}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
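// e.g. with KEY_REUSE_DELAY = 500, a key dequeued at t=0 cannot be handed out
// again before t=500ms; if a 429 already pushed rateLimitedUntil further out
// (t+4000 via RATE_LIMIT_LOCKOUT), Math.max preserves the longer lockout.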
}
@@ -0,0 +1,155 @@
import axios, { AxiosError } from "axios";
import type { GoogleAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { GoogleAIKey, GoogleAIKeyProvider } from "./provider";
import { getGoogleAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 60 * 1000; // 3 hours
const LIST_MODELS_URL =
"https://generativelanguage.googleapis.com/v1beta/models";
type ListModelsResponse = {
models: {
name: string;
baseModelId: string;
version: string;
displayName: string;
description: string;
inputTokenLimit: number;
outputTokenLimit: number;
supportedGenerationMethods: string[];
temperature: number;
maxTemperature: number;
topP: number;
topK: number;
}[];
nextPageToken: string;
};
type UpdateFn = typeof GoogleAIKeyProvider.prototype.update;
export class GoogleAIKeyChecker extends KeyCheckerBase<GoogleAIKey> {
constructor(keys: GoogleAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "google-ai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: GoogleAIKey) {
const provisionedModels = await this.getProvisionedModels(key);
const updates = {
modelFamilies: provisionedModels,
};
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies, ids: key.modelIds.length },
"Checked key."
);
}
private async getProvisionedModels(
key: GoogleAIKey
): Promise<GoogleAIModelFamily[]> {
const { data } = await axios.get<ListModelsResponse>(
`${LIST_MODELS_URL}?pageSize=1000&key=${key.key}`
);
const models = data.models;
const ids = new Set<string>();
const families = new Set<GoogleAIModelFamily>();
models.forEach(({ name }) => {
families.add(getGoogleAIModelFamily(name));
ids.add(name);
});
const familiesArray = Array.from(families);
this.updateKey(key.hash, {
modelFamilies: familiesArray,
modelIds: Array.from(ids),
});
return familiesArray;
}
protected handleAxiosError(key: GoogleAIKey, error: AxiosError): void {
if (error.response && GoogleAIKeyChecker.errorIsGoogleAIError(error)) {
const httpStatus = error.response.status;
const { code, message, status, details } = error.response.data.error;
switch (httpStatus) {
case 400:
const reason = details?.[0]?.reason;
if (status === "INVALID_ARGUMENT" && reason === "API_KEY_INVALID") {
this.log.warn(
{ key: key.hash, reason, details },
"Key check returned API_KEY_INVALID error. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
return;
} else if (
status === "FAILED_PRECONDITION" &&
message.match(/please enable billing/i)
) {
this.log.warn(
{ key: key.hash, message, details },
"Key check returned billing disabled error. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
return;
}
break;
case 401:
case 403:
this.log.warn(
{ key: key.hash, status, code, message, details },
"Key check returned Forbidden/Unauthorized error. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
return;
case 429:
this.log.warn(
{ key: key.hash, status, code, message, details },
"Key is rate limited. Rechecking key in 1 minute."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 60 * 1000);
this.updateKey(key.hash, { lastChecked: next });
return;
}
this.log.error(
{ key: key.hash, status, code, message, details },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
}
this.log.error(
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
return this.updateKey(key.hash, { lastChecked: next });
}
static errorIsGoogleAIError(
error: AxiosError
): error is AxiosError<GoogleAIError> {
const data = error.response?.data as any;
return Boolean(data?.error?.code || data?.error?.status);
}
}
type GoogleAIError = {
error: {
code: string;
message: string;
status: string;
details: any[];
};
};
+30 -53
@@ -1,13 +1,15 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { GoogleAIModelFamily } from "../../models";
import { HttpError, PaymentRequiredError } from "../../errors";
import { PaymentRequiredError } from "../../errors";
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { GoogleAIKeyChecker } from "./checker";
// Note that Google AI is not the same as Vertex AI, both are provided by Google
// but Vertex is the GCP product for enterprise. while Google AI is the
// consumer-ish product. The API is different, and keys are not compatible.
// Note that Google AI is not the same as Vertex AI, both are provided by
// Google but Vertex is the GCP product for enterprise, while Google AI is a
// development/hobbyist product. They use completely different APIs and keys.
// https://ai.google.dev/docs/migrate_to_cloud
export type GoogleAIKeyUpdate = Omit<
@@ -27,10 +29,8 @@ type GoogleAIKeyUsage = {
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
readonly service: "google-ai";
readonly modelFamilies: GoogleAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/** All detected model IDs on this key. */
modelIds: string[];
}
/**
@@ -49,6 +49,7 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
readonly service = "google-ai";
private keys: GoogleAIKey[] = [];
private checker?: GoogleAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
@@ -78,49 +79,40 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"gemini-flashTokens": 0,
"gemini-proTokens": 0,
"gemini-ultraTokens": 0,
modelIds: [],
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
}
public init() {}
public init() {
if (config.checkKeys) {
this.checker = new GoogleAIKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: string) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
public get(model: string) {
const neededFamily = getGoogleAIModelFamily(model);
const availableKeys = this.keys.filter(
(k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
);
if (availableKeys.length === 0) {
throw new PaymentRequiredError("No Google AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const keysByPriority = prioritizeKeys(availableKeys);
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -141,29 +133,14 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, _model: string, tokens: number) {
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key["gemini-proTokens"] += tokens;
key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+30 -3
@@ -10,7 +10,7 @@ export type APIFormat =
| "anthropic-text" // Legacy flat string prompt format
| "google-ai"
| "mistral-ai"
| "cohere-chat";
| "mistral-text"
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
@@ -31,6 +31,10 @@ export interface Key {
lastChecked: number;
/** Hash of the key, for logging and to find the key in the pool. */
hash: string;
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/*
@@ -59,9 +63,32 @@ export interface KeyProvider<T extends Key = Key> {
recheck(): void;
}
export function createGenericGetLockoutPeriod<T extends Key>(
getKeys: () => T[]
) {
return function (this: unknown, family?: ModelFamily): number {
const keys = getKeys();
const activeKeys = keys.filter(
(k) => !k.isDisabled && (!family || k.modelFamilies.includes(family))
);
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
};
}
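// Providers adopt this as an instance field, closing over their private key
// list, e.g.:
//   getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);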
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { AwsBedrockKey } from "./aws/provider";
export { GcpKey } from "./gcp/provider";
export { AzureOpenAIKey } from "./azure/provider";
export { GoogleAIKey } from "./google-ai/provider";
export { MistralAIKey } from "./mistral-ai/provider";
export { OpenAIKey } from "./openai/provider";
@@ -7,6 +7,7 @@ type KeyCheckerOptions<TKey extends Key = Key> = {
service: string;
keyCheckPeriod: number;
minCheckInterval: number;
keyCheckBatchSize?: number;
recurringChecksEnabled?: boolean;
updateKey: (hash: string, props: Partial<TKey>) => void;
};
@@ -22,6 +23,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
* than this.
*/
protected readonly keyCheckPeriod: number;
/** Maximum number of keys to check simultaneously. */
protected readonly keyCheckBatchSize: number;
protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
protected readonly keys: TKey[] = [];
protected log: pino.Logger;
@@ -33,6 +36,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
this.keyCheckPeriod = opts.keyCheckPeriod;
this.minCheckInterval = opts.minCheckInterval;
this.recurringChecksEnabled = opts.recurringChecksEnabled ?? true;
this.keyCheckBatchSize = opts.keyCheckBatchSize ?? 12;
this.updateKey = opts.updateKey;
this.service = opts.service;
this.log = logger.child({ module: "key-checker", service: opts.service });
@@ -78,7 +82,7 @@ export abstract class KeyCheckerBase<TKey extends Key> {
checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");
if (numUnchecked > 0) {
const keycheckBatch = uncheckedKeys.slice(0, 12);
const keycheckBatch = uncheckedKeys.slice(0, this.keyCheckBatchSize);
this.timeout = setTimeout(async () => {
try {
@@ -114,7 +118,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
);
// Don't check any individual key too often.
// Don't check anything at all at a rate faster than once per 3 seconds.
// Don't check anything at all more frequently than some minimum interval
// even if keys still need to be checked.
const nextCheck = Math.max(
oldestKey.lastChecked + this.keyCheckPeriod,
this.lastCheck + this.minCheckInterval
+7 -1
@@ -10,6 +10,7 @@ import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { GcpKeyProvider } from "./gcp/provider";
import { AzureOpenAIKeyProvider } from "./azure/provider";
import { MistralAIKeyProvider } from "./mistral-ai/provider";
@@ -27,6 +28,7 @@ export class KeyPool {
this.keyProviders.push(new GoogleAIKeyProvider());
this.keyProviders.push(new MistralAIKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new GcpKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider());
}
@@ -128,7 +130,11 @@ export class KeyPool {
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
if (!model.includes("@")) {
return "anthropic";
} else {
return "gcp";
}
} else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language
return "google-ai";
@@ -69,9 +69,9 @@ export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
const { status, data } = error.response;
if (status === 401) {
if ([401, 403].includes(status)) {
this.log.warn(
{ key: key.hash, error: data },
{ key: key.hash, error: data, status },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, {
@@ -1,10 +1,11 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";
import { HttpError } from "../../errors";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { createGenericGetLockoutPeriod, Key, KeyProvider } from "..";
import { prioritizeKeys } from "../prioritize-keys";
import { MistralAIKeyChecker } from "./checker";
type MistralAIKeyUsage = {
[K in MistralAIModelFamily as `${K}Tokens`]: number;
@@ -13,10 +14,6 @@ type MistralAIKeyUsage = {
export interface MistralAIKey extends Key, MistralAIKeyUsage {
readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/**
@@ -98,30 +95,8 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
throw new HttpError(402, "No Mistral AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 3. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
const selectedKey = prioritizeKeys(availableKeys)[0];
selectedKey.lastUsed = Date.now();
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
@@ -150,22 +125,7 @@ export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
key[`${family}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no keys available or the queue will stall.
// Just let it through so the add-key middleware can throw an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys);
/**
* This is called when we receive a 429, which means there are already five
+1 -2
@@ -26,8 +26,6 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
isTrial: boolean;
/** Set when key check returns a non-transient 429. */
isOverQuota: boolean;
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/**
* Last known X-RateLimit-Requests-Reset header from OpenAI, converted to a
* number.
@@ -111,6 +109,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
.digest("hex")
.slice(0, 8)}`,
rateLimitedAt: 0,
rateLimitedUntil: 0,
rateLimitRequestsReset: 0,
rateLimitTokensReset: 0,
turboTokens: 0,
@@ -0,0 +1,39 @@
import { Key } from "./index";
/**
* Given a list of keys, sorts the list in place from highest to lowest
* priority and returns it. Keys are prioritized in the following order:
*
* 1. Keys which are not rate limited
*    a. If all keys are rate limited, the least-recently rate limited key
*       comes first.
* 2. Keys ordered by the custom comparator, if one is provided
* 3. Keys which have not been used in the longest time
* @param keys The list of keys to sort
* @param customComparator A custom comparator function to use for sorting
*/
export function prioritizeKeys<T extends Key>(
keys: T[],
customComparator?: (a: T, b: T) => number
) {
const now = Date.now();
return keys.sort((a, b) => {
const aRateLimited = now < a.rateLimitedUntil;
const bRateLimited = now < b.rateLimitedUntil;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
if (customComparator) {
const result = customComparator(a, b);
if (result !== 0) return result;
}
return a.lastUsed - b.lastUsed;
});
}
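// Usage sketch: callers take the head of the sorted list as the selected key.
// The comparator shown here is illustrative, not taken from the providers:
//   const selected = prioritizeKeys(availableKeys)[0];
//   const byUsage = prioritizeKeys(availableKeys, (a, b) => a.promptCount - b.promptCount);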
+97 -64
@@ -1,12 +1,11 @@
// Don't import any other project files here as this is one of the first modules
// loaded and it will cause circular imports.
import pino from "pino";
import type { Request } from "express";
/**
* The service that a model is hosted on. Distinct from `APIFormat` because some
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
* services have interoperable APIs (eg Anthropic/AWS/GCP, OpenAI/Azure).
*/
export type LLMService =
| "openai"
@@ -14,8 +13,8 @@ export type LLMService =
| "google-ai"
| "mistral-ai"
| "aws"
| "azure"
| "cohere";
| "gcp"
| "azure";
export type OpenAIModelFamily =
| "turbo"
@@ -25,23 +24,27 @@ export type OpenAIModelFamily =
| "gpt4o"
| "dall-e";
export type AnthropicModelFamily = "claude" | "claude-opus";
export type GoogleAIModelFamily = "gemini-pro";
export type GoogleAIModelFamily =
| "gemini-flash"
| "gemini-pro"
| "gemini-ultra";
export type MistralAIModelFamily =
| "mistral-tiny"
| "mistral-small"
| "mistral-medium"
| "mistral-large";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
// mistral changes their model classes frequently so these no longer
// correspond to specific models. consider them rough pricing tiers.
"mistral-tiny" | "mistral-small" | "mistral-medium" | "mistral-large";
export type AwsBedrockModelFamily = `aws-${
| AnthropicModelFamily
| MistralAIModelFamily}`;
export type GcpModelFamily = "gcp-claude" | "gcp-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type CohereModelFamily = "command-r" | "command-r-plus";
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GoogleAIModelFamily
| MistralAIModelFamily
| AwsBedrockModelFamily
| AzureOpenAIModelFamily
| CohereModelFamily;
| GcpModelFamily
| AzureOpenAIModelFamily;
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -54,21 +57,27 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"dall-e",
"claude",
"claude-opus",
"gemini-flash",
"gemini-pro",
"gemini-ultra",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"mistral-large",
"aws-claude",
"aws-claude-opus",
"aws-mistral-tiny",
"aws-mistral-small",
"aws-mistral-medium",
"aws-mistral-large",
"gcp-claude",
"gcp-claude-opus",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-32k",
"azure-gpt4-turbo",
"azure-gpt4o",
"azure-dall-e",
"command-r",
"command-r-plus",
] as const);
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
@@ -79,12 +88,50 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
"google-ai",
"mistral-ai",
"aws",
"gcp",
"azure",
"cohere",
] as const);
export const MODEL_FAMILY_SERVICE: {
[f in ModelFamily]: LLMService;
} = {
turbo: "openai",
gpt4: "openai",
"gpt4-turbo": "openai",
"gpt4-32k": "openai",
gpt4o: "openai",
"dall-e": "openai",
claude: "anthropic",
"claude-opus": "anthropic",
"aws-claude": "aws",
"aws-claude-opus": "aws",
"aws-mistral-tiny": "aws",
"aws-mistral-small": "aws",
"aws-mistral-medium": "aws",
"aws-mistral-large": "aws",
"gcp-claude": "gcp",
"gcp-claude-opus": "gcp",
"azure-turbo": "azure",
"azure-gpt4": "azure",
"azure-gpt4-32k": "azure",
"azure-gpt4-turbo": "azure",
"azure-gpt4o": "azure",
"azure-dall-e": "azure",
"gemini-flash": "google-ai",
"gemini-pro": "google-ai",
"gemini-ultra": "google-ai",
"mistral-tiny": "mistral-ai",
"mistral-small": "mistral-ai",
"mistral-medium": "mistral-ai",
"mistral-large": "mistral-ai",
};
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4o": "gpt4o",
"^gpt-4o(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4o",
"^chatgpt-4o": "gpt4o",
"^gpt-4o-mini(-\\d{4}-\\d{2}-\\d{2})?$": "turbo", // closest match
"^gpt-4-turbo(-\\d{4}-\\d{2}-\\d{2})?$": "gpt4-turbo",
"^gpt-4-turbo(-preview)?$": "gpt4-turbo",
"^gpt-4-(0125|1106)(-preview)?$": "gpt4-turbo",
@@ -98,38 +145,6 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^dall-e-\\d{1}$": "dall-e",
};
export const MODEL_FAMILY_SERVICE: {
[f in ModelFamily]: LLMService;
} = {
turbo: "openai",
gpt4: "openai",
"gpt4-turbo": "openai",
"gpt4-32k": "openai",
"gpt4o": "openai",
"dall-e": "openai",
claude: "anthropic",
"claude-opus": "anthropic",
"aws-claude": "aws",
"aws-claude-opus": "aws",
"azure-turbo": "azure",
"azure-gpt4": "azure",
"azure-gpt4-32k": "azure",
"azure-gpt4-turbo": "azure",
"azure-gpt4o": "azure",
"azure-dall-e": "azure",
"gemini-pro": "google-ai",
"mistral-tiny": "mistral-ai",
"mistral-small": "mistral-ai",
"mistral-medium": "mistral-ai",
"mistral-large": "mistral-ai",
"command-r": "cohere",
"command-r-plus": "cohere",
};
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
export function getOpenAIModelFamily(
model: string,
defaultFamily: OpenAIModelFamily = "gpt4"
@@ -145,8 +160,12 @@ export function getClaudeModelFamily(model: string): AnthropicModelFamily {
return "claude";
}
export function getGoogleAIModelFamily(_model: string): ModelFamily {
return "gemini-pro";
export function getGoogleAIModelFamily(model: string): GoogleAIModelFamily {
return model.includes("ultra")
? "gemini-ultra"
: model.includes("flash")
? "gemini-flash"
: "gemini-pro";
}
export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
@@ -159,16 +178,34 @@ export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
return prunedModel as MistralAIModelFamily;
case "open-mistral-7b":
return "mistral-tiny";
case "open-mistral-nemo":
case "open-mixtral-8x7b":
case "codestral":
case "open-codestral-mamba":
return "mistral-small";
case "open-mixtral-8x22b":
return "mistral-medium";
default:
return "mistral-tiny";
return "mistral-small";
}
}
export function getAwsBedrockModelFamily(model: string): AwsBedrockModelFamily {
if (model.includes("opus")) return "aws-claude-opus";
return "aws-claude";
// remove vendor and version from AWS model ids
// 'anthropic.claude-3-5-sonnet-20240620-v1:0' -> 'claude-3-5-sonnet-20240620'
const deAwsified = model.replace(/^(\w+)\.(.+?)(-v\d+)?(:\d+)*$/, "$2");
if (["claude", "anthropic"].some((x) => model.includes(x))) {
return `aws-${getClaudeModelFamily(deAwsified)}`;
} else if (model.includes("tral")) {
return `aws-${getMistralAIModelFamily(deAwsified)}`;
}
return `aws-claude`;
}
export function getGcpModelFamily(model: string): GcpModelFamily {
if (model.includes("opus")) return "gcp-claude-opus";
return "gcp-claude";
}
export function getAzureOpenAIModelFamily(
@@ -189,11 +226,6 @@ export function getAzureOpenAIModelFamily(
return defaultFamily;
}
export function getCohereModelFamily(model: string): CohereModelFamily {
if (model.includes("plus")) return "command-r-plus";
return "command-r";
}
export function assertIsKnownModelFamily(
modelFamily: string
): asserts modelFamily is ModelFamily {
@@ -210,10 +242,13 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
const model = req.body.model ?? "gpt-3.5-turbo";
let modelFamily: ModelFamily;
// Weird special case for AWS/Azure because they serve multiple models from
// different vendors, even if currently only one is supported.
// Weird special case for AWS/GCP/Azure because they serve models with
// different API formats, so the outbound API alone is not sufficient to
// determine the partition.
if (req.service === "aws") {
modelFamily = getAwsBedrockModelFamily(model);
} else if (req.service === "gcp") {
modelFamily = getGcpModelFamily(model);
} else if (req.service === "azure") {
modelFamily = getAzureOpenAIModelFamily(model);
} else {
@@ -231,11 +266,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
modelFamily = getGoogleAIModelFamily(model);
break;
case "mistral-ai":
case "mistral-text":
modelFamily = getMistralAIModelFamily(model);
break;
case "cohere-chat":
modelFamily = getCohereModelFamily(model);
break;
default:
assertNever(req.outboundApi);
}
+2
@@ -30,10 +30,12 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
cost = 0.00001;
break;
case "aws-claude":
case "gcp-claude":
case "claude":
cost = 0.000008;
break;
case "aws-claude-opus":
case "gcp-claude-opus":
case "claude-opus":
cost = 0.000015;
break;
+3
@@ -67,6 +67,9 @@ async function getTokenCountForMessages({
case "image":
numTokens += await getImageTokenCount(part.source.data);
break;
case "tool_use":
case "tool_result":
break;
default:
throw new Error(`Unsupported Anthropic content type.`);
}
+3 -2
@@ -47,9 +47,9 @@ type GoogleAIChatTokenCountRequest = {
};
type MistralAIChatTokenCountRequest = {
prompt: MistralAIChatMessage[];
prompt: string | MistralAIChatMessage[];
completion?: never;
service: "mistral-ai";
service: "mistral-ai" | "mistral-text";
};
type FlatPromptTokenCountRequest = {
@@ -128,6 +128,7 @@ export async function countTokens({
tokenization_duration_ms: getElapsedMs(time),
};
case "mistral-ai":
case "mistral-text":
return {
...getMistralAITokenCount(prompt ?? completion),
tokenization_duration_ms: getElapsedMs(time),
+2
@@ -37,6 +37,8 @@ export const UserSchema = z
tokenCounts: tokenCountsSchema,
/** Maximum number of tokens the user can consume, by model family. */
tokenLimits: tokenCountsSchema,
/** User-specific token refresh amount, by model family. */
tokenRefresh: tokenCountsSchema,
/** Time at which the user was created. */
createdAt: z.number(),
/** Time at which the user last connected. */
+37 -12
@@ -13,6 +13,7 @@ import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import {
getAwsBedrockModelFamily,
getGcpModelFamily,
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGoogleAIModelFamily,
@@ -70,6 +71,7 @@ export function createUser(createOptions?: {
type?: User["type"];
expiresAt?: number;
tokenLimits?: User["tokenLimits"];
tokenRefresh?: User["tokenRefresh"];
}) {
const token = uuid();
const newUser: User = {
@@ -79,6 +81,7 @@ export function createUser(createOptions?: {
promptCount: 0,
tokenCounts: { ...INITIAL_TOKENS },
tokenLimits: createOptions?.tokenLimits ?? { ...config.tokenQuota },
tokenRefresh: createOptions?.tokenRefresh ?? { ...INITIAL_TOKENS },
createdAt: Date.now(),
meta: {},
};
@@ -123,6 +126,7 @@ export function upsertUser(user: UserUpdate) {
promptCount: 0,
tokenCounts: { ...INITIAL_TOKENS },
tokenLimits: { ...config.tokenQuota },
tokenRefresh: { ...INITIAL_TOKENS },
createdAt: Date.now(),
meta: {},
};
@@ -139,7 +143,6 @@ export function upsertUser(user: UserUpdate) {
}
}
// TODO: Write firebase migration to backfill new fields
if (updates.tokenCounts) {
for (const family of MODEL_FAMILIES) {
updates.tokenCounts[family] ??= 0;
@@ -150,6 +153,16 @@ export function upsertUser(user: UserUpdate) {
updates.tokenLimits[family] ??= 0;
}
}
// tokenRefresh is a special case where we want to merge the existing and
// updated values for each model family, ignoring falsy values.
if (updates.tokenRefresh) {
const merged = { ...existing.tokenRefresh };
for (const family of MODEL_FAMILIES) {
merged[family] =
updates.tokenRefresh[family] || existing.tokenRefresh[family];
}
updates.tokenRefresh = merged;
}
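// e.g. merging existing { claude: 100 } with update { claude: 0, gpt4: 5 }
// keeps claude at 100 (the falsy update is ignored) and sets gpt4 to 5.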
users.set(user.token, Object.assign(existing, updates));
usersToFlush.add(user.token);
@@ -245,19 +258,29 @@ export function hasAvailableQuota({
return tokensConsumed < tokenLimit;
}
/**
* For the given user, sets the token limit for each model family to the sum
* of the current token count and the refresh amount (the user-specific
* override if set, otherwise the configured default). Families with no
* refresh amount at all are left untouched.
*/
export function refreshQuota(token: string) {
const user = users.get(token);
if (!user) return;
const { tokenCounts, tokenLimits } = user;
const quotas = Object.entries(config.tokenQuota) as [ModelFamily, number][];
quotas
// If a quota is not configured, don't touch any existing limits a user may
// already have been assigned manually.
.filter(([, quota]) => quota > 0)
.forEach(
([model, quota]) =>
(tokenLimits[model] = (tokenCounts[model] ?? 0) + quota)
);
const { tokenQuota } = config;
const { tokenCounts, tokenLimits, tokenRefresh } = user;
// Get default quotas for each model family.
const defaultQuotas = Object.entries(tokenQuota) as [ModelFamily, number][];
// If any user-specific refresh quotas are present, override default quotas.
const userQuotas = defaultQuotas.map(
([f, q]) => [f, (tokenRefresh[f] ?? 0) || q] as const /* narrow to tuple */
);
userQuotas
// Ignore families with no global or user-specific refresh quota.
.filter(([, q]) => q > 0)
// Increase family token limit by the family's refresh amount.
.forEach(([f, q]) => (tokenLimits[f] = (tokenCounts[f] ?? 0) + q));
usersToFlush.add(token);
}
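// Worked example: with tokenCounts.claude = 1000, a default quota of 50000,
// and a user-specific tokenRefresh.claude = 100000, the new limit becomes
// 1000 + 100000 = 101000; with no user override it becomes 1000 + 50000.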
@@ -307,7 +330,7 @@ function cleanupExpiredTokens() {
user.meta.refreshable = config.captchaMode !== "none";
disabled++;
}
const purgeTimeout = config.powTokenPurgeHours * 60 * 60 * 1000;
if (user.disabledAt && user.disabledAt + purgeTimeout < now) {
users.delete(user.token);
usersToFlush.add(user.token);
@@ -395,6 +418,7 @@ function getModelFamilyForQuotaUsage(
// differentiate between Azure and OpenAI variants of the same model.
if (model.includes("azure")) return getAzureOpenAIModelFamily(model);
if (model.includes("anthropic.")) return getAwsBedrockModelFamily(model);
if (model.startsWith("claude-") && model.includes("@")) return getGcpModelFamily(model);
switch (api) {
case "openai":
@@ -407,6 +431,7 @@ function getModelFamilyForQuotaUsage(
case "google-ai":
return getGoogleAIModelFamily(model);
case "mistral-ai":
case "mistral-text":
return getMistralAIModelFamily(model);
default:
assertNever(api);
+20 -9
@@ -33,7 +33,7 @@
.pagination li a {
display: block;
padding: 0.5em 1em;
border-bottom: none;
text-decoration: none;
}
.pagination li.active a {
@@ -71,20 +71,24 @@
td.actions:hover {
background-color: #e0e6f6;
}
tr > td,
tr > th {
border-right: 1px solid #dedede;
}
@media (max-width: 800px) {
body {
padding: 0.5em;
}
table.full-width {
width: 100%;
position: static;
left: auto;
right: auto;
margin-left: 0;
margin-right: 0;
}
}
@media (prefers-color-scheme: dark) {
@@ -95,6 +99,13 @@
th.active {
background-color: #446;
}
td.actions:hover {
background-color: #446;
}
tr > td,
tr > th {
border-right: 1px solid #444;
}
}
</style>
</head>
@@ -1,4 +1,6 @@
<p>Next refresh: <time><%- nextQuotaRefresh %></time></p>
<p>
Next refresh: <time><%- nextQuotaRefresh %></time>
</p>
<table class="striped">
<thead>
<tr>
@@ -9,7 +11,7 @@
<% } %>
<th scope="col">Limit</th>
<th scope="col">Remaining</th>
<th scope="col">Refresh Amount</th>
<th scope="col" colspan="<%= showRefreshEdit ? 2 : 1 %>">Refresh Amount</th>
</tr>
</thead>
<tbody>
@@ -19,7 +21,7 @@
<td><%- prettyTokens(user.tokenCounts[key]) %></td>
<% if (showTokenCosts) { %>
<td>$<%- tokenCost(key, user.tokenCounts[key]).toFixed(2) %></td>
<% } %>
<% } %>
<% if (!user.tokenLimits[key]) { %>
<td colspan="2" style="text-align: center">unlimited</td>
<% } else { %>
@@ -29,7 +31,20 @@
<% if (user.type === "temporary") { %>
<td>N/A</td>
<% } else { %>
<td><%- prettyTokens(quota[key]) %></td>
<td><%- prettyTokens(user.tokenRefresh[key] || quota[key]) %></td>
<% } %>
<% if (showRefreshEdit) { %>
<td class="actions">
<a
title="Edit"
id="edit-refresh"
href="#"
data-field="tokenRefresh_<%= key %>"
data-token="<%= user.token %>"
data-modelFamily="<%= key %>"
>✏️</a
>
</td>
<% } %>
</tr>
<% }) %>
+3 -3
@@ -1,14 +1,14 @@
import cookieParser from "cookie-parser";
import expressSession from "express-session";
import MemoryStore from "memorystore";
import { config, COOKIE_SECRET } from "../config";
import { config, SECRET_SIGNING_KEY } from "../config";
const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
const cookieParserMiddleware = cookieParser(COOKIE_SECRET);
const cookieParserMiddleware = cookieParser(SECRET_SIGNING_KEY);
const sessionMiddleware = expressSession({
secret: COOKIE_SECRET,
secret: SECRET_SIGNING_KEY,
resave: false,
saveUninitialized: false,
store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
+9 -20
@@ -2,6 +2,7 @@ import crypto from "crypto";
import express from "express";
import argon2 from "@node-rs/argon2";
import { z } from "zod";
import { signMessage } from "../../shared/hmac-signing";
import {
authenticate,
createUser,
@@ -13,15 +14,13 @@ import { config } from "../../config";
/** Lockout time after verification in milliseconds */
const LOCKOUT_TIME = 1000 * 60; // 60 seconds
/** HMAC key for signing challenges; regenerated on startup */
let hmacSecret = crypto.randomBytes(32).toString("hex");
let powKeySalt = crypto.randomBytes(32).toString("hex");
/**
* Regenerates the salt used for signing challenges, invalidating any
* outstanding unsolved challenges.
*/
export function invalidatePowHmacKey() {
hmacSecret = crypto.randomBytes(32).toString("hex");
export function invalidatePowChallenges() {
powKeySalt = crypto.randomBytes(32).toString("hex");
}
const argon2Params = {
@@ -141,16 +140,6 @@ function generateChallenge(clientIp?: string, token?: string): Challenge {
};
}
function signMessage(msg: any): string {
const hmac = crypto.createHmac("sha256", hmacSecret);
if (typeof msg === "object") {
hmac.update(JSON.stringify(msg));
} else {
hmac.update(msg);
}
return hmac.digest("hex");
}
async function verifySolution(
challenge: Challenge,
solution: string,
@@ -213,7 +202,7 @@ router.post("/challenge", (req, res) => {
}
const { action, refreshToken, proxyKey } = data.data;
if (config.proxyKey && proxyKey !== config.proxyKey) {
res.status(400).json({ error: "Invalid proxy password" });
res.status(401).json({ error: "Invalid proxy password" });
return;
}
@@ -225,11 +214,11 @@ router.post("/challenge", (req, res) => {
return;
}
const challenge = generateChallenge(req.ip, refreshToken);
const signature = signMessage(challenge);
const signature = signMessage(challenge, powKeySalt);
res.json({ challenge, signature });
} else {
const challenge = generateChallenge(req.ip);
const signature = signMessage(challenge);
const signature = signMessage(challenge, powKeySalt);
res.json({ challenge, signature });
}
});
@@ -253,7 +242,7 @@ router.post("/verify", async (req, res) => {
}
const { challenge, signature, solution } = result.data;
if (signMessage(challenge) !== signature) {
if (signMessage(challenge, powKeySalt) !== signature) {
res.status(400).json({
error:
"Invalid signature; server may have restarted since challenge was issued. Please request a new challenge.",
@@ -303,6 +303,10 @@
_csrf: document.querySelector("meta[name=csrf-token]").getAttribute("content"),
};
if (localStorage.getItem("captcha-proxy-key")) {
body.proxyKey = localStorage.getItem("captcha-proxy-key");
}
fetch("/user/captcha/verify", {
method: "POST",
credentials: "same-origin",
+1 -1
@@ -64,7 +64,7 @@
</table>
<h3>Quota Information</h3>
<%- include("partials/shared_quota-info", { quota, user }) %>
<%- include("partials/shared_quota-info", { quota, user, showRefreshEdit: false }) %>
<form id="edit-nickname-form" style="display: none" action="/user/edit-nickname" method="post">
<input type="hidden" name="_csrf" value="<%= csrfToken %>" />
+5 -1
@@ -61,7 +61,11 @@
const refreshToken = token && action === "refresh" ? JSON.parse(token).token : undefined;
const keyInput = document.getElementById("proxy-key");
const proxyKey = (keyInput && keyInput.value) || undefined;
localStorage.setItem("captcha-proxy-key", proxyKey);
if (!proxyKey?.length) {
localStorage.removeItem("captcha-proxy-key");
} else {
localStorage.setItem("captcha-proxy-key", proxyKey);
}
fetch("/user/captcha/challenge", {
method: "POST",