Compare commits
16 Commits
| Author | SHA1 | Date |
| --- | --- | --- |
| | b8cc5e563e | |
| | 00402c8310 | |
| | df2e986366 | |
| | f9620991e7 | |
| | dd511fe60d | |
| | ea2bfb9eef | |
| | 39436e7492 | |
| | 3b9013cd1e | |
| | 8884544b05 | |
| | 05ab8c37eb | |
| | f53e328398 | |
| | 21af866fd9 | |
| | 5d3433268f | |
| | 4114dba4f5 | |
| | e44d24a3af | |
| | d611aeee18 | |
@@ -0,0 +1,4 @@
root = true

[*]
end_of_line = crlf

+6 -33
@@ -11,17 +11,11 @@
# The title displayed on the info page.
# SERVER_TITLE=Coom Tunnel

# Text model requests allowed per minute per user.
# TEXT_MODEL_RATE_LIMIT=4
# Image model requests allowed per minute per user.
# IMAGE_MODEL_RATE_LIMIT=2

# Max number of context tokens a user can request at once.
# Increase this if your proxy allows GPT 32k or 128k context.
# MAX_CONTEXT_TOKENS_OPENAI=16384
# Model requests allowed per minute per user.
# MODEL_RATE_LIMIT=4

# Max number of output tokens a user can request at once.
# MAX_OUTPUT_TOKENS_OPENAI=400
# MAX_OUTPUT_TOKENS_OPENAI=300
# MAX_OUTPUT_TOKENS_ANTHROPIC=400

# Whether to show the estimated cost of consumed tokens on the info page.
@@ -33,11 +27,7 @@
# CHECK_KEYS=true

# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# ALLOWED_MODEL_FAMILIES=claude,turbo,gpt4,gpt4-32k

# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -46,10 +36,8 @@
# Destination to redirect blocked requests to.
# BLOCK_REDIRECT="https://roblox.com/"

# Comma-separated list of phrases that will be rejected. Only whole words are matched.
# Surround phrases with quotes if they contain commas.
# Avoid short or common phrases as this tests the entire prompt.
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
# Whether to reject requests containing disallowed content.
# REJECT_DISALLOWED=false
# Message to show when requests are rejected.
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."

@@ -60,9 +48,6 @@
# The port to listen on.
# PORT=7860

# Whether cookies should be set without the Secure flag, for hosts that don't support SSL.
# USE_INSECURE_COOKIES=false

# Detail level of logging. (trace | debug | info | warn | error)
# LOG_LEVEL=info

@@ -78,25 +63,15 @@

# Maximum number of unique IPs a user can connect from. (0 for unlimited)
# MAX_IPS_PER_USER=0
# Whether user_tokens should be automatically disabled when reaching the IP limit.
# MAX_IPS_AUTO_BAN=true

# With user_token gatekeeper, whether to allow users to change their nickname.
# ALLOW_NICKNAME_CHANGES=true

# Default token quotas for each model family. (0 for unlimited)
# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
# which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
# See `docs/dall-e-configuration.md` for more information.
# TOKEN_QUOTA_TURBO=0
# TOKEN_QUOTA_GPT4=0
# TOKEN_QUOTA_GPT4_32K=0
# TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_BISON=0
# TOKEN_QUOTA_AWS_CLAUDE=0

# How often to refresh token quotas. (hourly | daily)
# Leave unset to never automatically refresh quotas.
@@ -114,8 +89,6 @@ OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# See `docs/aws-configuration.md` for more information; there may be additional steps required to set up AWS.
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
# See `docs/azure-configuration.md` for more information; there may be additional steps required to set up Azure.
AZURE_CREDENTIALS=azure-resource-name:deployment-id:api-key,another-azure-resource-name:another-deployment-id:another-api-key

# With proxy_key gatekeeper, the password users must provide to access the API.
# PROXY_KEY=your-secret-key

@@ -5,4 +5,3 @@
build
greeting.md
node_modules
http-client.private.env.json

@@ -1,2 +0,0 @@
*
!.gitkeep
@@ -3,8 +3,6 @@ RUN apt-get update && \
    apt-get install -y git
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
RUN chown -R 1000:1000 /app
USER 1000
RUN npm install
COPY Dockerfile greeting.md* .env* ./
RUN npm run build

@@ -45,11 +45,10 @@ You can also request Claude Instant, but support for this isn't fully implemente
### Supported model IDs
Users can send these model IDs to the proxy to invoke the corresponding models.
- **Claude**
  - `anthropic.claude-v1` (~18k context, claude 1.3)
  - `anthropic.claude-v2` (~100k context, claude 2.0)
  - `anthropic.claude-v2:1` (~200k context, claude 2.1)
  - `anthropic.claude-v1` (~18k context)
  - `anthropic.claude-v2` (~100k context)
- **Claude Instant**
  - `anthropic.claude-instant-v1` (~100k context, claude instant 1.2)
  - `anthropic.claude-instant-v1`

## Note regarding logging


@@ -1,30 +0,0 @@
# Configuring the proxy for Azure

The proxy supports Azure OpenAI Service via the `/proxy/azure/openai` endpoint. The process of setting it up is slightly different from regular OpenAI.

- [Setting keys](#setting-keys)
- [Model assignment](#model-assignment)

## Setting keys

Use the `AZURE_CREDENTIALS` environment variable to set the Azure API keys.

Like other APIs, you can provide multiple keys separated by commas. Each Azure key, however, is a set of values including the Resource Name, Deployment ID, and API key. These are separated by a colon (`:`).

For example:
```
AZURE_CREDENTIALS=contoso-ml:gpt4-8k:0123456789abcdef0123456789abcdef,northwind-corp:testdeployment:0123456789abcdef0123456789abcdef
```
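
The following TypeScript sketch (not from the proxy's source; the type and function names are hypothetical) shows one way such a credential string could be split into its parts:

```ts
// Hypothetical helper: splits AZURE_CREDENTIALS into colon-delimited triplets.
type AzureCredential = {
  resourceName: string;
  deploymentId: string;
  apiKey: string;
};

function parseAzureCredentials(raw: string): AzureCredential[] {
  return raw
    .split(",")
    .map((s) => s.trim())
    .filter((s) => s.length > 0)
    .map((item) => {
      const [resourceName, deploymentId, apiKey] = item.split(":");
      if (!resourceName || !deploymentId || !apiKey) {
        throw new Error(`Malformed Azure credential: ${item}`);
      }
      return { resourceName, deploymentId, apiKey };
    });
}
```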

## Model assignment
Note that each Azure deployment is assigned a model when you create it in the Azure OpenAI Service portal. If you want to use a different model, you'll need to create a new deployment, and therefore a new key to be added to the AZURE_CREDENTIALS environment variable. Each credential only grants access to one model.

### Supported model IDs
Users can send normal OpenAI model IDs to the proxy to invoke the corresponding models. For the most part they work the same with Azure. GPT-3.5 Turbo has an ID of "gpt-35-turbo" because Azure doesn't allow periods in model names, but the proxy should automatically convert this to the correct ID.
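
The conversion described above could look like the following sketch; the function is illustrative, not the proxy's actual code:

```ts
// Azure model IDs cannot contain periods, so "gpt-3.5-turbo" becomes "gpt-35-turbo".
function toAzureModelId(openaiModelId: string): string {
  return openaiModelId.replace("gpt-3.5", "gpt-35");
}

toAzureModelId("gpt-3.5-turbo"); // => "gpt-35-turbo"
```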

As noted above, you can only use model IDs for which a deployment has been created and added to the proxy.

## On content filtering
Be aware that all Azure OpenAI Service deployments have content filtering enabled by default at a Medium level. Prompts or responses which are deemed to be inappropriate will be rejected by the API. This is a feature of the Azure OpenAI Service and not the proxy.

You can disable this from the deployment's settings within Azure, but you would need to request an exemption from Microsoft for your organization first. See [this page](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/content-filters) for more information.
@@ -1,71 +0,0 @@
# Configuring the proxy for DALL-E

The proxy supports DALL-E 2 and DALL-E 3 image generation via the `/proxy/openai-images` endpoint. By default it is disabled as it is somewhat expensive and potentially more open to abuse than text generation.

- [Updating your Dockerfile](#updating-your-dockerfile)
- [Enabling DALL-E](#enabling-dall-e)
- [Setting quotas](#setting-quotas)
- [Rate limiting](#rate-limiting)

## Updating your Dockerfile
If you are using a previous version of the Dockerfile supplied with the proxy, it doesn't have the necessary permissions to let the proxy save temporary files.

You can replace the entire thing with the new Dockerfile at [./docker/huggingface/Dockerfile](../docker/huggingface/Dockerfile) (or the equivalent for Render deployments).

You can also modify your existing Dockerfile; just add the following lines after the `WORKDIR` line:

```Dockerfile
# Existing
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app

# Take ownership of the app directory and switch to the non-root user
RUN chown -R 1000:1000 /app
USER 1000

# Existing
RUN npm install
```

## Enabling DALL-E
Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL-E. For example:

```
# GPT-3.5 Turbo, GPT-4, GPT-4 Turbo, and DALL-E
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e

# All models as of this writing
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
```

Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.

## Setting quotas
DALL-E doesn't bill by token like text generation models. Instead there is a fixed cost per image generated, depending on the model, image size, and selected quality.

The proxy still uses tokens to set quotas for users. The cost for each generated image will be converted to "tokens" at a rate of 100000 tokens per US$1.00. This works out to a similar cost-per-token as GPT-4 Turbo, so you can use similar token quotas for both.

Use `TOKEN_QUOTA_DALL_E` to set the default quota for image generation. Otherwise it works the same as token quotas for other models.

```
# ~50 standard DALL-E images per refresh period, or US$2.00
TOKEN_QUOTA_DALL_E=200000
```

Refer to [https://openai.com/pricing](https://openai.com/pricing) for the latest pricing information. As of this writing, the cheapest DALL-E 3 image costs $0.04 per generation, which works out to 4000 tokens. Higher resolution and quality settings can cost up to $0.12 per image, or 12000 tokens.
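
To make the arithmetic concrete, here is a small TypeScript sketch of the conversion described above (names are illustrative, not from the proxy's source):

```ts
// 100000 "tokens" per US$1.00 of image generation cost.
const TOKENS_PER_USD = 100_000;

function imageCostToTokens(costUsd: number): number {
  return Math.round(costUsd * TOKENS_PER_USD);
}

imageCostToTokens(0.04); // => 4000  (cheapest DALL-E 3 image)
imageCostToTokens(0.12); // => 12000 (highest resolution/quality)
// A TOKEN_QUOTA_DALL_E of 200000 therefore covers 200000 / 4000 = 50 standard images.
```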

## Rate limiting
The old `MODEL_RATE_LIMIT` setting has been split into `TEXT_MODEL_RATE_LIMIT` and `IMAGE_MODEL_RATE_LIMIT`. Whatever value you previously set for `MODEL_RATE_LIMIT` will be used for text models.

If you don't specify an `IMAGE_MODEL_RATE_LIMIT`, it defaults to half of the `TEXT_MODEL_RATE_LIMIT`, with a minimum of 1 image per minute.

```
# 4 text generations per minute, 2 images per minute
TEXT_MODEL_RATE_LIMIT=4
IMAGE_MODEL_RATE_LIMIT=2
```
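
This fallback mirrors the migration logic in `src/config.ts` shown later in this diff; as a sketch:

```ts
// When only the legacy MODEL_RATE_LIMIT (or TEXT_MODEL_RATE_LIMIT) is set:
const textModelRateLimit = 4;
const imageModelRateLimit = Math.max(Math.floor(textModelRateLimit / 2), 1); // => 2
```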

If a prompt is filtered by OpenAI's content filter, it won't count towards the rate limit.

## Hiding recent images
By default, the proxy shows the last 12 images generated by users. You can hide this section by setting `SHOW_RECENT_IMAGES` to `false`.
@@ -25,8 +25,6 @@ RUN apt-get update && \
    apt-get install -y git
RUN git clone https://gitgud.io/khanon/oai-reverse-proxy.git /app
WORKDIR /app
RUN chown -R 1000:1000 /app
USER 1000
RUN npm install
COPY Dockerfile greeting.md* .env* ./
RUN npm run build

@@ -1,9 +0,0 @@
{
  "dev": {
    "proxy-host": "http://localhost:7860",
    "oai-key-1": "override in http-client.private.env.json",
    "proxy-key": "override in http-client.private.env.json",
    "azu-resource-name": "override in http-client.private.env.json",
    "azu-deployment-id": "override in http-client.private.env.json"
  }
}
Generated +9 -397
@@ -15,7 +15,6 @@
        "@smithy/signature-v4": "^2.0.10",
        "@smithy/types": "^2.3.4",
        "axios": "^1.3.5",
        "check-disk-space": "^3.4.0",
        "cookie-parser": "^1.4.6",
        "copyfiles": "^2.4.1",
        "cors": "^2.8.5",
@@ -34,7 +33,6 @@
        "pino": "^8.11.0",
        "pino-http": "^8.3.3",
        "sanitize-html": "^2.11.0",
        "sharp": "^0.32.6",
        "showdown": "^2.1.0",
        "tiktoken": "^1.0.10",
        "uuid": "^9.0.0",
@@ -1375,20 +1373,15 @@
      }
    },
    "node_modules/axios": {
      "version": "1.6.1",
      "resolved": "https://registry.npmjs.org/axios/-/axios-1.6.1.tgz",
      "integrity": "sha512-vfBmhDpKafglh0EldBEbVuoe7DyAavGSLWhuSm5ZSEKQnHhBf0xAAwybbNH1IkrJNGnS/VG4I5yxig1pCEXE4g==",
      "version": "1.3.5",
      "resolved": "https://registry.npmjs.org/axios/-/axios-1.3.5.tgz",
      "integrity": "sha512-glL/PvG/E+xCWwV8S6nCHcrfg1exGx7vxyUIivIA1iL7BIh6bePylCfVHwp6k13ao7SATxB6imau2kqY+I67kw==",
      "dependencies": {
        "follow-redirects": "^1.15.0",
        "form-data": "^4.0.0",
        "proxy-from-env": "^1.1.0"
      }
    },
    "node_modules/b4a": {
      "version": "1.6.4",
      "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.6.4.tgz",
      "integrity": "sha512-fpWrvyVHEKyeEvbKZTVOeZF3VSKKWtJxFIxX/jaVPf+cLbGUSitjb49pHLqPV2BUNNZ0LcoeEGfE/YCpyDYHIw=="
    },
    "node_modules/balanced-match": {
      "version": "1.0.2",
      "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz",
@@ -1430,52 +1423,6 @@
        "node": ">=8"
      }
    },
    "node_modules/bl": {
      "version": "4.1.0",
      "resolved": "https://registry.npmjs.org/bl/-/bl-4.1.0.tgz",
      "integrity": "sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==",
      "dependencies": {
        "buffer": "^5.5.0",
        "inherits": "^2.0.4",
        "readable-stream": "^3.4.0"
      }
    },
    "node_modules/bl/node_modules/buffer": {
      "version": "5.7.1",
      "resolved": "https://registry.npmjs.org/buffer/-/buffer-5.7.1.tgz",
      "integrity": "sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==",
      "funding": [
        {
          "type": "github",
          "url": "https://github.com/sponsors/feross"
        },
        {
          "type": "patreon",
          "url": "https://www.patreon.com/feross"
        },
        {
          "type": "consulting",
          "url": "https://feross.org/support"
        }
      ],
      "dependencies": {
        "base64-js": "^1.3.1",
        "ieee754": "^1.1.13"
      }
    },
    "node_modules/bl/node_modules/readable-stream": {
      "version": "3.6.2",
      "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
      "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
      "dependencies": {
        "inherits": "^2.0.3",
        "string_decoder": "^1.1.1",
        "util-deprecate": "^1.0.1"
      },
      "engines": {
        "node": ">= 6"
      }
    },
    "node_modules/bluebird": {
      "version": "3.7.2",
      "resolved": "https://registry.npmjs.org/bluebird/-/bluebird-3.7.2.tgz",
@@ -1635,14 +1582,6 @@
        "node": ">=8"
      }
    },
    "node_modules/check-disk-space": {
      "version": "3.4.0",
      "resolved": "https://registry.npmjs.org/check-disk-space/-/check-disk-space-3.4.0.tgz",
      "integrity": "sha512-drVkSqfwA+TvuEhFipiR1OC9boEGZL5RrWvVsOthdcvQNXyCCuKkEiTOTXZ7qxSf/GLwq4GvzfrQD/Wz325hgw==",
      "engines": {
        "node": ">=16"
      }
    },
    "node_modules/chokidar": {
      "version": "3.5.3",
      "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz",
@@ -1670,11 +1609,6 @@
        "fsevents": "~2.3.2"
      }
    },
    "node_modules/chownr": {
      "version": "1.1.4",
      "resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
      "integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
    },
    "node_modules/cliui": {
      "version": "8.0.1",
      "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz",
@@ -1689,18 +1623,6 @@
        "node": ">=12"
      }
    },
    "node_modules/color": {
      "version": "4.2.3",
      "resolved": "https://registry.npmjs.org/color/-/color-4.2.3.tgz",
      "integrity": "sha512-1rXeuUUiGGrykh+CeBdu5Ie7OJwinCgQY0bc7GCRxy5xVHy+moaqkpL/jqQq0MtQOeYcrqEz4abc5f0KtU7W4A==",
      "dependencies": {
        "color-convert": "^2.0.1",
        "color-string": "^1.9.0"
      },
      "engines": {
        "node": ">=12.5.0"
      }
    },
    "node_modules/color-convert": {
      "version": "2.0.1",
      "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz",
@@ -1717,15 +1639,6 @@
      "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz",
      "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA=="
    },
    "node_modules/color-string": {
      "version": "1.9.1",
      "resolved": "https://registry.npmjs.org/color-string/-/color-string-1.9.1.tgz",
      "integrity": "sha512-shrVawQFojnZv6xM40anx4CkoDP+fZsw/ZerEMsW/pyzsRbElpsL/DBVW7q3ExxwusdNXI3lXpuhEZkzs8p5Eg==",
      "dependencies": {
        "color-name": "^1.0.0",
        "simple-swizzle": "^0.2.2"
      }
    },
    "node_modules/colorette": {
      "version": "2.0.20",
      "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz",
@@ -2087,28 +2000,6 @@
        "ms": "2.0.0"
      }
    },
    "node_modules/decompress-response": {
      "version": "6.0.0",
      "resolved": "https://registry.npmjs.org/decompress-response/-/decompress-response-6.0.0.tgz",
      "integrity": "sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==",
      "dependencies": {
        "mimic-response": "^3.1.0"
      },
      "engines": {
        "node": ">=10"
      },
      "funding": {
        "url": "https://github.com/sponsors/sindresorhus"
      }
    },
    "node_modules/deep-extend": {
      "version": "0.6.0",
      "resolved": "https://registry.npmjs.org/deep-extend/-/deep-extend-0.6.0.tgz",
      "integrity": "sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==",
      "engines": {
        "node": ">=4.0.0"
      }
    },
    "node_modules/deep-is": {
      "version": "0.1.4",
      "resolved": "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz",
@@ -2148,14 +2039,6 @@
        "npm": "1.2.8000 || >= 1.4.16"
      }
    },
    "node_modules/detect-libc": {
      "version": "2.0.2",
      "resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.0.2.tgz",
      "integrity": "sha512-UX6sGumvvqSaXgdKGUsgZWqcUyIXZ/vZTrlRT/iobiKhGL0zL4d3osHj3uqllWJK+i+sixDS/3COVEOFbupFyw==",
      "engines": {
        "node": ">=8"
      }
    },
    "node_modules/diff": {
      "version": "4.0.2",
      "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz",
@@ -2305,6 +2188,7 @@
      "version": "1.4.4",
      "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.4.tgz",
      "integrity": "sha512-+uw1inIHVPQoaVuHzRyXd21icM+cnt4CzD5rW+NC1wjOUSTOs+Te7FOv7AhN7vS9x/oIyhLP5PR1H+phQAHu5Q==",
      "devOptional": true,
      "dependencies": {
        "once": "^1.4.0"
      }
@@ -2589,14 +2473,6 @@
        "node": ">=0.8.x"
      }
    },
    "node_modules/expand-template": {
      "version": "2.0.3",
      "resolved": "https://registry.npmjs.org/expand-template/-/expand-template-2.0.3.tgz",
      "integrity": "sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==",
      "engines": {
        "node": ">=6"
      }
    },
    "node_modules/express": {
      "version": "4.18.2",
      "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz",
@@ -2681,11 +2557,6 @@
      "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==",
      "optional": true
    },
    "node_modules/fast-fifo": {
      "version": "1.3.2",
      "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz",
      "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ=="
    },
    "node_modules/fast-levenshtein": {
      "version": "2.0.6",
      "resolved": "https://registry.npmjs.org/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz",
@@ -2847,11 +2718,6 @@
        "node": ">= 0.6"
      }
    },
    "node_modules/fs-constants": {
      "version": "1.0.0",
      "resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
      "integrity": "sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow=="
    },
    "node_modules/fs.realpath": {
      "version": "1.0.0",
      "resolved": "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz",
@@ -2929,11 +2795,6 @@
        "url": "https://github.com/sponsors/ljharb"
      }
    },
    "node_modules/github-from-package": {
      "version": "0.0.0",
      "resolved": "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz",
      "integrity": "sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw=="
    },
    "node_modules/glob": {
      "version": "8.1.0",
      "resolved": "https://registry.npmjs.org/glob/-/glob-8.1.0.tgz",
@@ -3385,11 +3246,6 @@
      "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
      "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ=="
    },
    "node_modules/ini": {
      "version": "1.3.8",
      "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz",
      "integrity": "sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew=="
    },
    "node_modules/ipaddr.js": {
      "version": "1.9.1",
      "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
@@ -3398,11 +3254,6 @@
        "node": ">= 0.10"
      }
    },
    "node_modules/is-arrayish": {
      "version": "0.3.2",
      "resolved": "https://registry.npmjs.org/is-arrayish/-/is-arrayish-0.3.2.tgz",
      "integrity": "sha512-eVRqCvVlZbuw3GrM63ovNSNAeA1K16kaR/LRY/92w0zxQ5/1YzwblUX652i4Xs9RwAGjW9d9y6X88t8OaAJfWQ=="
    },
    "node_modules/is-binary-path": {
      "version": "2.1.0",
      "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz",
@@ -3929,17 +3780,6 @@
        "node": ">= 0.6"
      }
    },
    "node_modules/mimic-response": {
      "version": "3.1.0",
      "resolved": "https://registry.npmjs.org/mimic-response/-/mimic-response-3.1.0.tgz",
      "integrity": "sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==",
      "engines": {
        "node": ">=10"
      },
      "funding": {
        "url": "https://github.com/sponsors/sindresorhus"
      }
    },
    "node_modules/minimatch": {
      "version": "3.1.2",
      "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
@@ -3970,11 +3810,6 @@
        "node": ">=10"
      }
    },
    "node_modules/mkdirp-classic": {
      "version": "0.5.3",
      "resolved": "https://registry.npmjs.org/mkdirp-classic/-/mkdirp-classic-0.5.3.tgz",
      "integrity": "sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A=="
    },
    "node_modules/ms": {
      "version": "2.0.0",
      "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz",
@@ -4025,11 +3860,6 @@
        "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1"
      }
    },
    "node_modules/napi-build-utils": {
      "version": "1.0.2",
      "resolved": "https://registry.npmjs.org/napi-build-utils/-/napi-build-utils-1.0.2.tgz",
      "integrity": "sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg=="
    },
    "node_modules/negotiator": {
      "version": "0.6.3",
      "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-0.6.3.tgz",
@@ -4038,22 +3868,6 @@
        "node": ">= 0.6"
      }
    },
    "node_modules/node-abi": {
      "version": "3.51.0",
      "resolved": "https://registry.npmjs.org/node-abi/-/node-abi-3.51.0.tgz",
      "integrity": "sha512-SQkEP4hmNWjlniS5zdnfIXTk1x7Ome85RDzHlTbBtzE97Gfwz/Ipw4v/Ryk20DWIy3yCNVLVlGKApCnmvYoJbA==",
      "dependencies": {
        "semver": "^7.3.5"
      },
      "engines": {
        "node": ">=10"
      }
    },
    "node_modules/node-addon-api": {
      "version": "6.1.0",
      "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz",
      "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA=="
    },
    "node_modules/node-fetch": {
      "version": "2.6.9",
      "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.6.9.tgz",
@@ -4403,70 +4217,6 @@
        "node": "^10 || ^12 || >=14"
      }
    },
    "node_modules/prebuild-install": {
      "version": "7.1.1",
      "resolved": "https://registry.npmjs.org/prebuild-install/-/prebuild-install-7.1.1.tgz",
      "integrity": "sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==",
      "dependencies": {
        "detect-libc": "^2.0.0",
        "expand-template": "^2.0.3",
        "github-from-package": "0.0.0",
        "minimist": "^1.2.3",
        "mkdirp-classic": "^0.5.3",
        "napi-build-utils": "^1.0.1",
        "node-abi": "^3.3.0",
        "pump": "^3.0.0",
        "rc": "^1.2.7",
        "simple-get": "^4.0.0",
        "tar-fs": "^2.0.0",
        "tunnel-agent": "^0.6.0"
      },
      "bin": {
        "prebuild-install": "bin.js"
      },
      "engines": {
        "node": ">=10"
      }
    },
    "node_modules/prebuild-install/node_modules/readable-stream": {
      "version": "3.6.2",
      "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz",
      "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==",
      "dependencies": {
        "inherits": "^2.0.3",
        "string_decoder": "^1.1.1",
        "util-deprecate": "^1.0.1"
      },
      "engines": {
        "node": ">= 6"
      }
    },
    "node_modules/prebuild-install/node_modules/tar-fs": {
      "version": "2.1.1",
      "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-2.1.1.tgz",
      "integrity": "sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==",
      "dependencies": {
        "chownr": "^1.1.1",
        "mkdirp-classic": "^0.5.2",
        "pump": "^3.0.0",
        "tar-stream": "^2.1.4"
      }
    },
    "node_modules/prebuild-install/node_modules/tar-stream": {
      "version": "2.2.0",
      "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-2.2.0.tgz",
      "integrity": "sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==",
      "dependencies": {
        "bl": "^4.0.3",
        "end-of-stream": "^1.4.1",
        "fs-constants": "^1.0.0",
        "inherits": "^2.0.3",
        "readable-stream": "^3.1.1"
      },
      "engines": {
        "node": ">=6"
      }
    },
    "node_modules/prettier": {
      "version": "3.0.3",
      "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.0.3.tgz",
@@ -4602,6 +4352,7 @@
      "version": "3.0.0",
      "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.0.tgz",
      "integrity": "sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==",
      "dev": true,
      "dependencies": {
        "end-of-stream": "^1.1.0",
        "once": "^1.3.1"
@@ -4621,11 +4372,6 @@
        "url": "https://github.com/sponsors/ljharb"
      }
    },
    "node_modules/queue-tick": {
      "version": "1.0.1",
      "resolved": "https://registry.npmjs.org/queue-tick/-/queue-tick-1.0.1.tgz",
      "integrity": "sha512-kJt5qhMxoszgU/62PLP1CJytzd2NKetjSRnyuj31fDd3Rlcz3fzlFdFLD1SItunPwyqEOkca6GbV612BWfaBag=="
    },
    "node_modules/quick-format-unescaped": {
      "version": "4.0.4",
      "resolved": "https://registry.npmjs.org/quick-format-unescaped/-/quick-format-unescaped-4.0.4.tgz",
@@ -4661,28 +4407,6 @@
        "node": ">= 0.8"
      }
    },
    "node_modules/rc": {
      "version": "1.2.8",
      "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz",
      "integrity": "sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==",
      "dependencies": {
        "deep-extend": "^0.6.0",
        "ini": "~1.3.0",
        "minimist": "^1.2.0",
        "strip-json-comments": "~2.0.1"
      },
      "bin": {
        "rc": "cli.js"
      }
    },
    "node_modules/rc/node_modules/strip-json-comments": {
      "version": "2.0.1",
      "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-2.0.1.tgz",
      "integrity": "sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==",
      "engines": {
        "node": ">=0.10.0"
      }
    },
    "node_modules/readable-stream": {
      "version": "4.3.0",
      "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-4.3.0.tgz",
@@ -4891,9 +4615,9 @@
      "dev": true
    },
    "node_modules/semver": {
      "version": "7.5.4",
      "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz",
      "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==",
      "version": "7.5.3",
      "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.3.tgz",
      "integrity": "sha512-QBlUtyVk/5EeHbi7X0fw6liDZc7BBmEaSYn01fMU1OUYbf6GPsbTtd8WmnqbI20SeycoHSeiybkE/q1Q+qlThQ==",
      "dependencies": {
        "lru-cache": "^6.0.0"
      },
@@ -4951,28 +4675,6 @@
      "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz",
      "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw=="
    },
    "node_modules/sharp": {
      "version": "0.32.6",
      "resolved": "https://registry.npmjs.org/sharp/-/sharp-0.32.6.tgz",
      "integrity": "sha512-KyLTWwgcR9Oe4d9HwCwNM2l7+J0dUQwn/yf7S0EnTtb0eVS4RxO0eUSvxPtzT4F3SY+C4K6fqdv/DO27sJ/v/w==",
      "hasInstallScript": true,
      "dependencies": {
        "color": "^4.2.3",
        "detect-libc": "^2.0.2",
        "node-addon-api": "^6.1.0",
        "prebuild-install": "^7.1.1",
        "semver": "^7.5.4",
        "simple-get": "^4.0.1",
        "tar-fs": "^3.0.4",
        "tunnel-agent": "^0.6.0"
      },
      "engines": {
        "node": ">=14.15.0"
      },
      "funding": {
        "url": "https://opencollective.com/libvips"
      }
    },
    "node_modules/shell-quote": {
      "version": "1.8.1",
      "resolved": "https://registry.npmjs.org/shell-quote/-/shell-quote-1.8.1.tgz",
@@ -5010,57 +4712,6 @@
        "url": "https://github.com/sponsors/ljharb"
      }
    },
    "node_modules/simple-concat": {
      "version": "1.0.1",
      "resolved": "https://registry.npmjs.org/simple-concat/-/simple-concat-1.0.1.tgz",
      "integrity": "sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==",
      "funding": [
        {
          "type": "github",
          "url": "https://github.com/sponsors/feross"
        },
        {
          "type": "patreon",
          "url": "https://www.patreon.com/feross"
        },
        {
          "type": "consulting",
          "url": "https://feross.org/support"
        }
      ]
    },
    "node_modules/simple-get": {
      "version": "4.0.1",
      "resolved": "https://registry.npmjs.org/simple-get/-/simple-get-4.0.1.tgz",
      "integrity": "sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==",
      "funding": [
        {
          "type": "github",
          "url": "https://github.com/sponsors/feross"
        },
        {
          "type": "patreon",
          "url": "https://www.patreon.com/feross"
        },
        {
          "type": "consulting",
          "url": "https://feross.org/support"
        }
      ],
      "dependencies": {
        "decompress-response": "^6.0.0",
        "once": "^1.3.1",
        "simple-concat": "^1.0.0"
      }
    },
    "node_modules/simple-swizzle": {
      "version": "0.2.2",
      "resolved": "https://registry.npmjs.org/simple-swizzle/-/simple-swizzle-0.2.2.tgz",
      "integrity": "sha512-JA//kQgZtbuY83m+xT+tXJkmJncGMTFT+C+g2h2R9uxkYIrE2yy9sgmcLhCnw57/WSD+Eh3J97FPEDFnbXnDUg==",
      "dependencies": {
        "is-arrayish": "^0.3.1"
      }
    },
    "node_modules/simple-update-notifier": {
      "version": "2.0.0",
      "resolved": "https://registry.npmjs.org/simple-update-notifier/-/simple-update-notifier-2.0.0.tgz",
@@ -5158,19 +4809,11 @@
        "node": ">=10.0.0"
      }
    },
    "node_modules/streamx": {
      "version": "2.15.4",
      "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.15.4.tgz",
      "integrity": "sha512-uSXKl88bibiUCQ1eMpItRljCzDENcDx18rsfDmV79r0e/ThfrAwxG4Y2FarQZ2G4/21xcOKmFFd1Hue+ZIDwHw==",
      "dependencies": {
        "fast-fifo": "^1.1.0",
        "queue-tick": "^1.0.1"
      }
    },
    "node_modules/string_decoder": {
      "version": "1.3.0",
      "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz",
      "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==",
      "devOptional": true,
      "dependencies": {
        "safe-buffer": "~5.2.0"
      }
@@ -5229,26 +4872,6 @@
        "node": ">=4"
      }
    },
    "node_modules/tar-fs": {
      "version": "3.0.4",
      "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.0.4.tgz",
      "integrity": "sha512-5AFQU8b9qLfZCX9zp2duONhPmZv0hGYiBPJsyUdqMjzq/mqVpy/rEUSeHk1+YitmxugaptgBh5oDGU3VsAJq4w==",
      "dependencies": {
        "mkdirp-classic": "^0.5.2",
        "pump": "^3.0.0",
        "tar-stream": "^3.1.5"
      }
    },
    "node_modules/tar-stream": {
      "version": "3.1.6",
      "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.1.6.tgz",
      "integrity": "sha512-B/UyjYwPpMBv+PaFSWAmtYjwdrlEaZQEhMIBFNC5oEG8lpiW8XjcSdmEaClj28ArfKScKHs2nshz3k2le6crsg==",
      "dependencies": {
        "b4a": "^1.6.4",
        "fast-fifo": "^1.2.0",
        "streamx": "^2.15.0"
      }
    },
    "node_modules/teeny-request": {
      "version": "8.0.3",
      "resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-8.0.3.tgz",
@@ -5424,17 +5047,6 @@
      "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
      "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
    },
    "node_modules/tunnel-agent": {
      "version": "0.6.0",
      "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz",
      "integrity": "sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==",
      "dependencies": {
        "safe-buffer": "^5.0.1"
      },
      "engines": {
        "node": "*"
      }
    },
    "node_modules/type-is": {
      "version": "1.6.18",
      "resolved": "https://registry.npmjs.org/type-is/-/type-is-1.6.18.tgz",

@@ -23,7 +23,6 @@
    "@smithy/signature-v4": "^2.0.10",
    "@smithy/types": "^2.3.4",
    "axios": "^1.3.5",
    "check-disk-space": "^3.4.0",
    "cookie-parser": "^1.4.6",
    "copyfiles": "^2.4.1",
    "cors": "^2.8.5",
@@ -42,7 +41,6 @@
    "pino": "^8.11.0",
    "pino-http": "^8.3.3",
    "sanitize-html": "^2.11.0",
    "sharp": "^0.32.6",
    "showdown": "^2.1.0",
    "tiktoken": "^1.0.10",
    "uuid": "^9.0.0",

@@ -1,248 +0,0 @@
# OAI Reverse Proxy

###
# @name OpenAI -- Chat Completions
POST https://api.openai.com/v1/chat/completions
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 30,
  "stream": false,
  "messages": [
    {
      "role": "user",
      "content": "This is a test prompt."
    }
  ]
}

###
# @name OpenAI -- Text Completions
POST https://api.openai.com/v1/completions
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo-instruct",
  "max_tokens": 30,
  "stream": false,
  "prompt": "This is a test prompt where"
}

###
# @name OpenAI -- Create Embedding
POST https://api.openai.com/v1/embeddings
Authorization: Bearer {{oai-key-1}}
Content-Type: application/json

{
  "model": "text-embedding-ada-002",
  "input": "This is a test embedding input."
}

###
# @name OpenAI -- Get Organizations
GET https://api.openai.com/v1/organizations
Authorization: Bearer {{oai-key-1}}

###
# @name OpenAI -- Get Models
GET https://api.openai.com/v1/models
Authorization: Bearer {{oai-key-1}}

###
# @name Azure OpenAI -- Chat Completions
POST https://{{azu-resource-name}}.openai.azure.com/openai/deployments/{{azu-deployment-id}}/chat/completions?api-version=2023-09-01-preview
api-key: {{azu-key-1}}
Content-Type: application/json

{
  "max_tokens": 1,
  "stream": false,
  "messages": [
    {
      "role": "user",
      "content": "This is a test prompt."
    }
  ]
}

###
# @name Proxy / OpenAI -- Get Models
GET {{proxy-host}}/proxy/openai/v1/models
Authorization: Bearer {{proxy-key}}

###
# @name Proxy / OpenAI -- Native Chat Completions
POST {{proxy-host}}/proxy/openai/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 20,
  "stream": true,
  "temperature": 1,
  "seed": 123,
  "messages": [
    {
      "role": "user",
      "content": "phrase one"
    }
  ]
}

###
# @name Proxy / OpenAI -- Native Text Completions
POST {{proxy-host}}/proxy/openai/v1/turbo-instruct/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo-instruct",
  "max_tokens": 20,
  "temperature": 0,
  "prompt": "Genshin Impact is a game about",
  "stream": false
}

###
# @name Proxy / OpenAI -- Chat-to-Text API Translation
# Accepts a chat completion request and reformats it to work with the text completion API. `model` is ignored.
POST {{proxy-host}}/proxy/openai/turbo-instruct/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-4",
  "max_tokens": 20,
  "stream": true,
  "messages": [
    {
      "role": "user",
      "content": "What is the name of the fourth president of the united states?"
    },
    {
      "role": "assistant",
      "content": "That would be George Washington."
    },
    {
      "role": "user",
      "content": "I don't think that's right..."
    }
  ]
}

###
# @name Proxy / OpenAI -- Create Embedding
POST {{proxy-host}}/proxy/openai/embeddings
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "text-embedding-ada-002",
  "input": "This is a test embedding input."
}


###
# @name Proxy / Anthropic -- Native Completion (old API)
POST {{proxy-host}}/proxy/anthropic/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json

{
  "model": "claude-v1.3",
  "max_tokens_to_sample": 20,
  "temperature": 0.2,
  "stream": true,
  "prompt": "What is genshin impact\n\n:Assistant:"
}

###
# @name Proxy / Anthropic -- Native Completion (2023-06-01 API)
POST {{proxy-host}}/proxy/anthropic/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-06-01
Content-Type: application/json

{
  "model": "claude-v1.3",
  "max_tokens_to_sample": 20,
  "temperature": 0.2,
  "stream": true,
  "prompt": "What is genshin impact\n\n:Assistant:"
}

###
# @name Proxy / Anthropic -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/anthropic/v1/chat/completions
Authorization: Bearer {{proxy-key}}
#anthropic-version: 2023-06-01
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 20,
  "stream": false,
  "temperature": 0,
  "messages": [
    {
      "role": "user",
      "content": "What is genshin impact"
    }
  ]
}

###
# @name Proxy / AWS Claude -- Native Completion
POST {{proxy-host}}/proxy/aws/claude/v1/complete
Authorization: Bearer {{proxy-key}}
anthropic-version: 2023-01-01
Content-Type: application/json

{
  "model": "claude-v2",
  "max_tokens_to_sample": 10,
  "temperature": 0,
  "stream": true,
  "prompt": "What is genshin impact\n\n:Assistant:"
}

###
# @name Proxy / AWS Claude -- OpenAI-to-Anthropic API Translation
POST {{proxy-host}}/proxy/aws/claude/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "max_tokens": 50,
  "stream": true,
  "messages": [
    {
      "role": "user",
      "content": "What is genshin impact?"
    }
  ]
}

###
# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json

{
  "model": "gpt-4",
  "max_tokens": 42,
  "messages": [
    {
      "role": "user",
      "content": "Hi what is the name of the fourth president of the united states?"
    }
  ]
}
@@ -1,44 +0,0 @@
const axios = require("axios");

const concurrentRequests = 5;
const headers = {
  Authorization: "Bearer test",
  "Content-Type": "application/json",
};

const payload = {
  model: "gpt-4",
  max_tokens: 1,
  stream: false,
  messages: [{ role: "user", content: "Hi" }],
};

const makeRequest = async (i) => {
  try {
    const response = await axios.post(
      "http://localhost:7860/proxy/azure/openai/v1/chat/completions",
      payload,
      { headers }
    );
    console.log(
      `Req ${i} finished with status code ${response.status} and response:`,
      response.data
    );
  } catch (error) {
    console.error(`Error in req ${i}:`, error.message);
  }
};

const executeRequestsConcurrently = () => {
  const promises = [];
  for (let i = 1; i <= concurrentRequests; i++) {
    console.log(`Starting request ${i}`);
    promises.push(makeRequest(i));
  }

  Promise.all(promises).then(() => {
    console.log("All requests finished");
  });
};

executeRequestsConcurrently();
@@ -4,7 +4,6 @@ import { HttpError } from "../shared/errors";
import { injectLocals } from "../shared/inject-locals";
import { withSession } from "../shared/with-session";
import { injectCsrfToken, checkCsrfToken } from "../shared/inject-csrf";
import { buildInfoPageHtml } from "../info-page";
import { loginRouter } from "./login";
import { usersApiRouter as apiRouter } from "./api/users";
import { usersWebRouter as webRouter } from "./web/manage";
@@ -24,11 +23,6 @@ adminRouter.use(checkCsrfToken);
adminRouter.use(injectLocals);
adminRouter.use("/", loginRouter);
adminRouter.use("/manage", authorize({ via: "cookie" }), webRouter);
adminRouter.use("/service-info", authorize({ via: "cookie" }), (req, res) => {
  return res.send(
    buildInfoPageHtml(req.protocol + "://" + req.get("host"), true)
  );
});

adminRouter.use(
  (

@@ -1,11 +1,5 @@
<%- include("partials/shared_header", { title: "OAI Reverse Proxy Admin" }) %>
<h1>OAI Reverse Proxy Admin</h1>
<% if (!usersEnabled) { %>
<p style="color: red; background-color: #eedddd; padding: 1em">
  <strong>🚨 <code>user_token</code> gatekeeper is not enabled.</strong><br />
  <br />None of the user management features will do anything.
</p>
<% } %>
<% if (!persistenceEnabled) { %>
<p style="color: red; background-color: #eedddd; padding: 1em">
  <strong>⚠️ Users will be lost when the server restarts because persistence is not configured.</strong><br />
@@ -25,7 +19,6 @@
  <li><a href="/admin/manage/import-users">Import Users</a></li>
  <li><a href="/admin/manage/export-users">Export Users</a></li>
  <li><a href="/admin/manage/download-stats">Download Rentry Stats</a>
  <li><a href="/admin/service-info">Service Info</a></li>
</ul>
<h3>Maintenance</h3>
<form id="maintenanceForm" action="/admin/manage/maintenance" method="post">

+57 -105
@@ -1,17 +1,13 @@
|
||||
import dotenv from "dotenv";
|
||||
import type firebase from "firebase-admin";
|
||||
import path from "path";
|
||||
import { hostname } from "os";
|
||||
import pino from "pino";
|
||||
import type { ModelFamily } from "./shared/models";
|
||||
import { MODEL_FAMILIES } from "./shared/models";
|
||||
dotenv.config();
|
||||
|
||||
const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
|
||||
const isDev = process.env.NODE_ENV !== "production";
|
||||
|
||||
export const DATA_DIR = path.join(__dirname, "..", "data");
|
||||
export const USER_ASSETS_DIR = path.join(DATA_DIR, "user-files");
|
||||
|
||||
type Config = {
|
||||
/** The port the proxy server will listen on. */
|
||||
port: number;
|
||||
@@ -33,17 +29,6 @@ type Config = {
|
||||
* @example `AWS_CREDENTIALS=access_key_1:secret_key_1:us-east-1,access_key_2:secret_key_2:us-west-2`
|
||||
*/
|
||||
awsCredentials?: string;
|
||||
/**
|
||||
* Comma-delimited list of Azure OpenAI credentials. Each credential item
|
||||
* should be a colon-delimited list of Azure resource name, deployment ID, and
|
||||
* API key.
|
||||
*
|
||||
* The resource name is the subdomain in your Azure OpenAI deployment's URL,
|
||||
* e.g. `https://resource-name.openai.azure.com
|
||||
*
|
||||
* @example `AZURE_CREDENTIALS=resource_name_1:deployment_id_1:api_key_1,resource_name_2:deployment_id_2:api_key_2`
|
||||
*/
|
||||
azureCredentials?: string;
|
||||
/**
|
||||
* The proxy key to require for requests. Only applicable if the user
|
||||
* management mode is set to 'proxy_key', and required if so.
|
||||
@@ -66,12 +51,12 @@ type Config = {
|
||||
*/
|
||||
gatekeeper: "none" | "proxy_key" | "user_token";
|
||||
/**
|
||||
* Persistence layer to use for user management.
|
||||
* - `memory`: Users are stored in memory and are lost on restart (default)
|
||||
* - `firebase_rtdb`: Users are stored in a Firebase Realtime Database;
|
||||
* requires `firebaseKey` and `firebaseRtdbUrl` to be set.
|
||||
* Persistence layer to use for user and key management.
|
||||
* - `memory`: Data is stored in memory and lost on restart (default)
|
||||
* - `firebase_rtdb`: Data is stored in Firebase Realtime Database; requires
|
||||
* `firebaseKey` and `firebaseRtdbUrl` to be set.
|
||||
*/
|
||||
gatekeeperStore: "memory" | "firebase_rtdb";
|
||||
persistenceProvider: "memory" | "firebase_rtdb";
|
||||
/** URL of the Firebase Realtime Database if using the Firebase RTDB store. */
|
||||
firebaseRtdbUrl?: string;
|
||||
/**
|
||||
@@ -81,20 +66,26 @@ type Config = {
|
||||
*/
|
||||
firebaseKey?: string;
|
||||
/**
|
||||
* Maximum number of IPs allowed per user token.
|
||||
* The root key under which data will be stored in the Firebase RTDB. This
|
||||
* allows multiple instances of the proxy to share the same database while
|
||||
* keeping their data separate.
|
||||
*
|
||||
* If you want multiple proxies to share the same data, set all of their
|
||||
* `firebaseRtdbRoot` to the same value. Beware that there will likely
|
||||
* be conflicts because concurrent writes are not yet supported and proxies
|
||||
* currently assume they have exclusive access to the database.
|
||||
*
|
||||
* Defaults to the system hostname so that data is kept separate.
|
||||
*/
|
||||
firebaseRtdbRoot: string;
|
||||
/**
|
||||
* Maximum number of IPs per user, after which their token is disabled.
|
||||
* Users with the manually-assigned `special` role are exempt from this limit.
|
||||
* - Defaults to 0, which means that users are not IP-limited.
|
||||
*/
|
||||
maxIpsPerUser: number;
|
||||
/**
|
||||
* Whether a user token should be automatically disabled if it exceeds the
|
||||
* `maxIpsPerUser` limit, or if only connections from new IPs are be rejected.
|
||||
*/
|
||||
maxIpsAutoBan: boolean;
|
||||
/** Per-IP limit for requests per minute to text and chat models. */
|
||||
textModelRateLimit: number;
|
||||
/** Per-IP limit for requests per minute to image generation models. */
|
||||
imageModelRateLimit: number;
|
||||
/** Per-IP limit for requests per minute to OpenAI's completions endpoint. */
|
||||
modelRateLimit: number;
|
||||
/**
|
||||
* For OpenAI, the maximum number of context tokens (prompt + max output) a
|
||||
* user can request before their request is rejected.
|
||||
@@ -113,10 +104,10 @@ type Config = {
|
||||
maxOutputTokensOpenAI: number;
|
||||
/** For Anthropic, the maximum number of sampled tokens a user can request. */
|
||||
maxOutputTokensAnthropic: number;
|
||||
/** Whether requests containing the following phrases should be rejected. */
|
||||
rejectPhrases: string[];
|
||||
/** Whether requests containing disallowed characters should be rejected. */
|
||||
rejectDisallowed?: boolean;
|
||||
/** Message to return when rejecting requests. */
|
||||
rejectMessage: string;
|
||||
rejectMessage?: string;
|
||||
/** Verbosity level of diagnostic logging. */
|
||||
logLevel: "trace" | "debug" | "info" | "warn" | "error";
|
||||
/**
|
||||
@@ -175,20 +166,6 @@ type Config = {
|
||||
quotaRefreshPeriod?: "hourly" | "daily" | string;
|
||||
/** Whether to allow users to change their own nicknames via the UI. */
|
||||
allowNicknameChanges: boolean;
|
||||
/** Whether to show recent DALL-E image generations on the homepage. */
|
||||
showRecentImages: boolean;
|
||||
/**
|
||||
 * If true, cookies will be set without the `Secure` attribute, allowing
 * the admin UI to be used over HTTP.
 */
useInsecureCookies: boolean;
/**
 * Whether to use a more minimal public Service Info page with static content.
 * Disables all stats pertaining to traffic, prompt/token usage, and queues.
 * The full info page will appear if you have signed in as an admin using the
 * configured ADMIN_KEY and go to /admin/service-info.
 **/
staticServiceInfo?: boolean;
};

// To change configs, create a file called .env in the root directory.
@@ -199,25 +176,23 @@ export const config: Config = {
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
adminKey: getEnvWithDefault("ADMIN_KEY", ""),
gatekeeper: getEnvWithDefault("GATEKEEPER", "none"),
gatekeeperStore: getEnvWithDefault("GATEKEEPER_STORE", "memory"),
persistenceProvider: getEnvWithDefault("PERSISTENCE_PROVIDER", "memory"),
maxIpsPerUser: getEnvWithDefault("MAX_IPS_PER_USER", 0),
maxIpsAutoBan: getEnvWithDefault("MAX_IPS_AUTO_BAN", true),
firebaseRtdbUrl: getEnvWithDefault("FIREBASE_RTDB_URL", undefined),
firebaseKey: getEnvWithDefault("FIREBASE_KEY", undefined),
textModelRateLimit: getEnvWithDefault("TEXT_MODEL_RATE_LIMIT", 4),
imageModelRateLimit: getEnvWithDefault("IMAGE_MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 16384),
firebaseRtdbRoot: getEnvWithDefault("FIREBASE_RTDB_ROOT", hostname()),
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 4),
maxContextTokensOpenAI: getEnvWithDefault("MAX_CONTEXT_TOKENS_OPENAI", 0),
maxContextTokensAnthropic: getEnvWithDefault(
"MAX_CONTEXT_TOKENS_ANTHROPIC",
0
),
maxOutputTokensOpenAI: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_OPENAI", "MAX_OUTPUT_TOKENS"],
400
300
),
maxOutputTokensAnthropic: getEnvWithDefault(
["MAX_OUTPUT_TOKENS_ANTHROPIC", "MAX_OUTPUT_TOKENS"],
@@ -227,16 +202,11 @@ export const config: Config = {
"turbo",
"gpt4",
"gpt4-32k",
"gpt4-turbo",
"claude",
"bison",
"aws-claude",
"azure-turbo",
"azure-gpt4",
"azure-gpt4-turbo",
"azure-gpt4-32k",
]),
rejectPhrases: parseCsv(getEnvWithDefault("REJECT_PHRASES", "")),
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
rejectMessage: getEnvWithDefault(
"REJECT_MESSAGE",
"This content violates /aicg/'s acceptable use policy."
@@ -258,21 +228,16 @@ export const config: Config = {
"You must be over the age of majority in your country to use this service."
),
blockRedirect: getEnvWithDefault("BLOCK_REDIRECT", "https://www.9gag.com"),
tokenQuota: MODEL_FAMILIES.reduce(
(acc, family: ModelFamily) => {
acc[family] = getEnvWithDefault(
`TOKEN_QUOTA_${family.toUpperCase().replace(/-/g, "_")}`,
0
) as number;
return acc;
},
{} as { [key in ModelFamily]: number }
),
tokenQuota: {
turbo: getEnvWithDefault("TOKEN_QUOTA_TURBO", 0),
gpt4: getEnvWithDefault("TOKEN_QUOTA_GPT4", 0),
"gpt4-32k": getEnvWithDefault("TOKEN_QUOTA_GPT4_32K", 0),
claude: getEnvWithDefault("TOKEN_QUOTA_CLAUDE", 0),
bison: getEnvWithDefault("TOKEN_QUOTA_BISON", 0),
"aws-claude": getEnvWithDefault("TOKEN_QUOTA_AWS_CLAUDE", 0),
},
quotaRefreshPeriod: getEnvWithDefault("QUOTA_REFRESH_PERIOD", undefined),
allowNicknameChanges: getEnvWithDefault("ALLOW_NICKNAME_CHANGES", true),
showRecentImages: getEnvWithDefault("SHOW_RECENT_IMAGES", true),
useInsecureCookies: getEnvWithDefault("USE_INSECURE_COOKIES", isDev),
staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
} as const;
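
// The config above passes an array of env var names in a few places, e.g.
// ["MAX_OUTPUT_TOKENS_OPENAI", "MAX_OUTPUT_TOKENS"]. A minimal sketch of how
// getEnvWithDefault likely resolves such a list, assuming first-match-wins
// precedence; the real implementation (partially shown further down in this
// diff) also special-cases sensitive key variables and type coercion.
function getEnvWithDefaultSketch<T>(env: string | string[], defaultValue: T): T {
  const names = Array.isArray(env) ? env : [env];
  for (const name of names) {
    const value = process.env[name];
    // The first defined variable in the list wins, so the more specific name
    // (MAX_OUTPUT_TOKENS_OPENAI) should come before the generic fallback.
    if (value !== undefined) return value as unknown as T;
  }
  return defaultValue;
}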

function generateCookieSecret() {
@@ -288,17 +253,20 @@ function generateCookieSecret() {
export const COOKIE_SECRET = generateCookieSecret();

export async function assertConfigIsValid() {
if (process.env.MODEL_RATE_LIMIT !== undefined) {
const limit =
parseInt(process.env.MODEL_RATE_LIMIT, 10) || config.textModelRateLimit;

config.textModelRateLimit = limit;
config.imageModelRateLimit = Math.max(Math.floor(limit / 2), 1);

if (process.env.TURBO_ONLY === "true") {
startupLogger.warn(
{ textLimit: limit, imageLimit: config.imageModelRateLimit },
"MODEL_RATE_LIMIT is deprecated. Use TEXT_MODEL_RATE_LIMIT and IMAGE_MODEL_RATE_LIMIT instead."
"TURBO_ONLY is deprecated. Use ALLOWED_MODEL_FAMILIES=turbo instead."
);
config.allowedModelFamilies = config.allowedModelFamilies.filter(
(f) => !f.includes("gpt4")
);
}

if (!!process.env.GATEKEEPER_STORE) {
startupLogger.warn(
"GATEKEEPER_STORE is deprecated. Use PERSISTENCE_PROVIDER instead. Configuration will be migrated."
);
config.persistenceProvider = process.env.GATEKEEPER_STORE as any;
}

if (!["none", "proxy_key", "user_token"].includes(config.gatekeeper)) {
@@ -326,11 +294,11 @@ export async function assertConfigIsValid() {
}

if (
config.gatekeeperStore === "firebase_rtdb" &&
config.persistenceProvider === "firebase_rtdb" &&
(!config.firebaseKey || !config.firebaseRtdbUrl)
) {
throw new Error(
"Firebase RTDB store requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
"Firebase RTDB persistence requires `FIREBASE_KEY` and `FIREBASE_RTDB_URL` to be set."
);
}

@@ -338,8 +306,7 @@ export async function assertConfigIsValid() {
// them to users.
for (const key of getKeys(config)) {
const maybeSensitive = ["key", "credentials", "secret", "password"].some(
(sensitive) =>
key.toLowerCase().includes(sensitive) && !["checkKeys"].includes(key)
(sensitive) => key.toLowerCase().includes(sensitive)
);
const secured = new Set([...SENSITIVE_KEYS, ...OMITTED_KEYS]);
if (maybeSensitive && !secured.has(key))
@@ -368,25 +335,19 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"anthropicKey",
"googlePalmKey",
"awsCredentials",
"azureCredentials",
"proxyKey",
"adminKey",
"rejectPhrases",
"checkKeys",
"showTokenCosts",
"googleSheetsKey",
"persistenceProvider",
"firebaseKey",
"firebaseRtdbUrl",
"gatekeeperStore",
"maxIpsPerUser",
"blockedOrigins",
"blockMessage",
"blockRedirect",
"allowNicknameChanges",
"showRecentImages",
"useInsecureCookies",
"staticServiceInfo",
"checkKeys",
"allowedModelFamilies",
];

const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
@@ -435,7 +396,6 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
"ANTHROPIC_KEY",
"GOOGLE_PALM_KEY",
"AWS_CREDENTIALS",
"AZURE_CREDENTIALS",
].includes(String(env))
) {
return value as unknown as T;
@@ -455,7 +415,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
let firebaseApp: firebase.app.App | undefined;

async function maybeInitializeFirebase() {
if (!config.gatekeeperStore.startsWith("firebase")) {
if (!config.persistenceProvider.startsWith("firebase")) {
return;
}

@@ -477,11 +437,3 @@ export function getFirebaseApp(): firebase.app.App {
}
return firebaseApp;
}

function parseCsv(val: string): string[] {
if (!val) return [];

const regex = /(".*?"|[^",]+)(?=\s*,|\s*$)/g;
const matches = val.match(regex) || [];
return matches.map((item) => item.replace(/^"|"$/g, "").trim());
}
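
// A quick usage sketch of parseCsv above: quoted items may contain commas and
// the surrounding quotes are stripped. (Hypothetical values for illustration.)
//
//   parseCsv('phrase one,"phrase two, with comma",phrase three')
//   // => ["phrase one", "phrase two, with comma", "phrase three"]
//   parseCsv("")
//   // => []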

+94
-237
@@ -1,4 +1,3 @@
/** This whole module really sucks */
import fs from "fs";
import { Request, Response } from "express";
import showdown from "showdown";
@@ -6,21 +5,15 @@ import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
AzureOpenAIKey,
GooglePalmKey,
keyPool,
OpenAIKey,
keyPool,
} from "./shared/key-management";
import {
AzureOpenAIModelFamily,
ModelFamily,
OpenAIModelFamily,
} from "./shared/models";
import { ModelFamily, OpenAIModelFamily } from "./shared/models";
import { getUniqueIps } from "./proxy/rate-limit";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { getTokenCostUsd, prettyTokens } from "./shared/stats";
import { assertNever } from "./shared/utils";
import { getLastNImages } from "./shared/file-storage/image-history";

const INFO_PAGE_TTL = 2000;
let infoPageHtml: string | undefined;
@@ -29,8 +22,6 @@ let infoPageLastUpdated = 0;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
@@ -56,7 +47,6 @@ type ServiceAggregates = {
anthropicKeys?: number;
palmKeys?: number;
awsKeys?: number;
azureKeys?: number;
proompts: number;
tokens: number;
tokenCost: number;
@@ -71,18 +61,17 @@ const serviceStats = new Map<keyof ServiceAggregates, number>();

export const handleInfoPage = (req: Request, res: Response) => {
if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
return res.send(infoPageHtml);
res.send(infoPageHtml);
return;
}

// Sometimes huggingface doesn't send the host header and makes us guess.
const baseUrl =
process.env.SPACE_ID && !req.get("host")?.includes("hf.space")
? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
: req.protocol + "://" + req.get("host");

infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
infoPageLastUpdated = Date.now();

res.send(infoPageHtml);
res.send(cacheInfoPageHtml(baseUrl));
};

function getCostString(cost: number) {
@@ -90,9 +79,8 @@ function getCostString(cost: number) {
return ` ($${cost.toFixed(2)})`;
}

export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
function cacheInfoPageHtml(baseUrl: string) {
const keys = keyPool.list();
const hideFullInfo = config.staticServiceInfo && !asAdmin;

modelStats.clear();
serviceStats.clear();
@@ -102,58 +90,32 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
const palmKeys = serviceStats.get("palmKeys") || 0;
const awsKeys = serviceStats.get("awsKeys") || 0;
const azureKeys = serviceStats.get("azureKeys") || 0;
const proompts = serviceStats.get("proompts") || 0;
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;

const allowDalle = config.allowedModelFamilies.includes("dall-e");

const endpoints = {
...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
...(openaiKeys && allowDalle
? { ["openai-image"]: baseUrl + "/openai-image" }
: {}),
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
};

const stats = {
proompts,
tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};

const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
for (const key of Object.keys(keyInfo)) {
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
}

const providerInfo = {
...(openaiKeys ? getOpenAIInfo() : {}),
...(anthropicKeys ? getAnthropicInfo() : {}),
...(palmKeys ? getPalmInfo() : {}),
...(awsKeys ? getAwsInfo() : {}),
...(azureKeys ? getAzureInfo() : {}),
};

if (hideFullInfo) {
for (const provider of Object.keys(providerInfo)) {
delete (providerInfo as any)[provider].proomptersInQueue;
delete (providerInfo as any)[provider].estimatedQueueTime;
delete (providerInfo as any)[provider].usage;
}
}

const info = {
uptime: Math.floor(process.uptime()),
endpoints,
...(hideFullInfo ? {} : stats),
...keyInfo,
...providerInfo,
endpoints: {
...(openaiKeys ? { openai: baseUrl + "/proxy/openai" } : {}),
...(openaiKeys
? { ["openai2"]: baseUrl + "/proxy/openai/turbo-instruct" }
: {}),
...(anthropicKeys ? { anthropic: baseUrl + "/proxy/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/proxy/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/proxy/aws/claude" } : {}),
},
proompts,
tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
...(config.modelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
openaiKeys,
anthropicKeys,
palmKeys,
awsKeys,
...(openaiKeys ? getOpenAIInfo() : {}),
...(anthropicKeys ? getAnthropicInfo() : {}),
...(palmKeys ? { "palm-bison": getPalmInfo() } : {}),
...(awsKeys ? { "aws-claude": getAwsInfo() } : {}),
config: listConfig(),
build: process.env.BUILD_INFO || "dev",
};
@@ -161,7 +123,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const title = getServerTitle();
const headerHtml = buildInfoPageHeader(new showdown.Converter(), title);

return `<!DOCTYPE html>
const pageBody = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
@@ -176,6 +138,11 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
${getSelfServiceLinks()}
</body>
</html>`;

infoPageHtml = pageBody;
infoPageLastUpdated = Date.now();

return pageBody;
}

function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
@@ -199,10 +166,13 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);

let sumTokens = 0;
let sumCost = 0;
let family: ModelFamily;
const families = k.modelFamilies.filter((f) =>
config.allowedModelFamilies.includes(f)
);

switch (k.service) {
case "openai":
@@ -213,35 +183,30 @@ function addKeyToAggregates(k: KeyPoolKey) {
Boolean(k.lastChecked) ? 0 : 1
);

// Technically this would not account for keys that have tokens recorded
// on models they aren't provisioned for, but that would be strange
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
});

if (families.includes("gpt4-32k")) {
family = "gpt4-32k";
} else if (families.includes("gpt4")) {
family = "gpt4";
} else {
family = "turbo";
}

increment(modelStats, `${family}__trial`, k.isTrial ? 1 : 0);
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
case "anthropic":
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
const family = "claude";
family = "claude";
sumTokens += k.claudeTokens;
sumCost += getTokenCostUsd(family, k.claudeTokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.claudeTokens);
increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
increment(
@@ -250,24 +215,18 @@ function addKeyToAggregates(k: KeyPoolKey) {
Boolean(k.lastChecked) ? 0 : 1
);
break;
}
case "google-palm": {
case "google-palm":
if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type");
const family = "bison";
family = "bison";
sumTokens += k.bisonTokens;
sumCost += getTokenCostUsd(family, k.bisonTokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.bisonTokens);
break;
}
case "aws": {
case "aws":
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
const family = "aws-claude";
family = "aws-claude";
sumTokens += k["aws-claudeTokens"];
sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);

// Ignore revoked keys for aws logging stats, but include keys where the
@@ -277,13 +236,19 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);

break;
}
default:
assertNever(k.service);
}

increment(serviceStats, "tokens", sumTokens);
increment(serviceStats, "tokenCost", sumCost);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
if ("isRevoked" in k) {
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
}
if ("isOverQuota" in k) {
increment(modelStats, `${family}__overQuota`, k.isOverQuota ? 1 : 0);
}
}
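
// `increment` is used throughout the aggregation above but its definition is
// outside this diff. A plausible sketch, assuming it upserts a numeric counter
// into one of the stats Maps; the actual helper may differ.
function incrementSketch<K extends string>(
  map: Map<K, number>,
  key: K,
  delta: number = 1
) {
  // Missing keys start at zero so callers can pass conditional 0/1 deltas.
  map.set(key, (map.get(key) ?? 0) + delta);
}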

function getOpenAIInfo() {
@@ -299,13 +264,14 @@ function getOpenAIInfo() {
};
} = {};

const keys = keyPool.list().filter(keyIsOpenAIKey);
const enabledFamilies = new Set(config.allowedModelFamilies);
const accessibleFamilies = keys
.flatMap((k) => k.modelFamilies)
.filter((f) => enabledFamilies.has(f))
.concat("turbo");
const familySet = new Set(accessibleFamilies);
const allowedFamilies = new Set(config.allowedModelFamilies);
let families = new Set<OpenAIModelFamily>();
const keys = keyPool.list().filter((k) => {
const isOpenAI = keyIsOpenAIKey(k);
if (isOpenAI) k.modelFamilies.forEach((f) => families.add(f));
return isOpenAI;
}) as Omit<OpenAIKey, "key">[];
families = new Set([...families].filter((f) => allowedFamilies.has(f)));

if (config.checkKeys) {
const unchecked = serviceStats.get("openAiUncheckedKeys") || 0;
@@ -315,7 +281,7 @@ function getOpenAIInfo() {
info.openaiKeys = keys.length;
info.openaiOrgs = getUniqueOpenAIOrgs(keys);

familySet.forEach((f) => {
families.forEach((f) => {
const tokens = modelStats.get(`${f}__tokens`) || 0;
const cost = getTokenCostUsd(f, tokens);

@@ -326,13 +292,6 @@ function getOpenAIInfo() {
revokedKeys: modelStats.get(`${f}__revoked`) || 0,
overQuotaKeys: modelStats.get(`${f}__overQuota`) || 0,
};

// Don't show trial/revoked keys for non-turbo families.
// Generally those stats only make sense for the lowest-tier model.
if (f !== "turbo") {
delete info[f]!.trialKeys;
delete info[f]!.revokedKeys;
}
});
} else {
info.status = "Key checking is disabled.";
@@ -344,14 +303,11 @@ function getOpenAIInfo() {
};
}

familySet.forEach((f) => {
if (enabledFamilies.has(f)) {
if (!info[f]) info[f] = { activeKeys: 0 }; // may occur if checkKeys is disabled
families.forEach((f) => {
if (info[f]) {
const { estimatedQueueTime, proomptersInQueue } = getQueueInformation(f);
info[f]!.proomptersInQueue = proomptersInQueue;
info[f]!.estimatedQueueTime = estimatedQueueTime;
} else {
(info[f]! as any).status = "GPT-3.5-Turbo is disabled on this proxy.";
}
});

@@ -362,7 +318,6 @@ function getAnthropicInfo() {
const claudeInfo: Partial<ModelAggregates> = {
active: modelStats.get("claude__active") || 0,
pozzed: modelStats.get("claude__pozzed") || 0,
revoked: modelStats.get("claude__revoked") || 0,
};

const queue = getQueueInformation("claude");
@@ -380,7 +335,6 @@ function getAnthropicInfo() {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
...(unchecked > 0 ? { status: `Checking ${unchecked} keys...` } : {}),
activeKeys: claudeInfo.active,
revokedKeys: claudeInfo.revoked,
...(config.checkKeys ? { pozzedKeys: claudeInfo.pozzed } : {}),
proomptersInQueue: claudeInfo.queued,
estimatedQueueTime: claudeInfo.queueTime,
@@ -391,7 +345,6 @@ function getAnthropicInfo() {
function getPalmInfo() {
const bisonInfo: Partial<ModelAggregates> = {
active: modelStats.get("bison__active") || 0,
revoked: modelStats.get("bison__revoked") || 0,
};

const queue = getQueueInformation("bison");
@@ -402,20 +355,16 @@ function getPalmInfo() {
const cost = getTokenCostUsd("bison", tokens);

return {
bison: {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active,
revokedKeys: bisonInfo.revoked,
proomptersInQueue: bisonInfo.queued,
estimatedQueueTime: bisonInfo.queueTime,
},
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active,
proomptersInQueue: bisonInfo.queued,
estimatedQueueTime: bisonInfo.queueTime,
};
}

function getAwsInfo() {
const awsInfo: Partial<ModelAggregates> = {
active: modelStats.get("aws-claude__active") || 0,
revoked: modelStats.get("aws-claude__revoked") || 0,
};

const queue = getQueueInformation("aws-claude");
@@ -428,65 +377,20 @@ function getAwsInfo() {
const logged = modelStats.get("aws-claude__awsLogged") || 0;
const logMsg = config.allowAwsLogging
? `${logged} active keys are potentially logged.`
: `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
: `${logged} active keys are potentially logged and can't be used.`;

return {
"aws-claude": {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: awsInfo.active,
revokedKeys: awsInfo.revoked,
proomptersInQueue: awsInfo.queued,
estimatedQueueTime: awsInfo.queueTime,
...(logged > 0 ? { privacy: logMsg } : {}),
},
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: awsInfo.active,
proomptersInQueue: awsInfo.queued,
estimatedQueueTime: awsInfo.queueTime,
...(logged > 0 ? { privacy: logMsg } : {}),
};
}

function getAzureInfo() {
const azureFamilies = [
"azure-turbo",
"azure-gpt4",
"azure-gpt4-turbo",
"azure-gpt4-32k",
] as const;

const azureInfo: {
[modelFamily in AzureOpenAIModelFamily]?: {
usage?: string;
activeKeys: number;
revokedKeys?: number;
proomptersInQueue?: number;
estimatedQueueTime?: string;
};
} = {};
for (const family of azureFamilies) {
const familyAllowed = config.allowedModelFamilies.includes(family);
const activeKeys = modelStats.get(`${family}__active`) || 0;

if (!familyAllowed || activeKeys === 0) continue;

azureInfo[family] = {
activeKeys,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
};

const queue = getQueueInformation(family);
azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;

const tokens = modelStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
cost
)}`;
}

return azureInfo;
}

const customGreeting = fs.existsSync("greeting.md")
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
: "";
? fs.readFileSync("greeting.md", "utf8")
: null;

/**
 * If the server operator provides a `greeting.md` file, it will be included in
@@ -497,20 +401,16 @@ function buildInfoPageHeader(converter: showdown.Converter, title: string) {
let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
# ${title}`;
if (config.promptLogging) {
infoBody += `\n## Prompt Logging Enabled
This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
infoBody += `\n## Prompt logging is enabled!
The server operator has enabled prompt logging. The prompts you send to this proxy and the AI responses you receive may be saved.

[You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/shared/prompt-logging/index.ts).
Logs are anonymous and do not contain IP addresses or timestamps. [You can see the type of data logged here, along with the rest of the code.](https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/prompt-logging/index.ts).

**If you are uncomfortable with this, don't send prompts to this proxy!**`;
}

if (config.staticServiceInfo) {
return converter.makeHtml(infoBody + customGreeting);
}

const waits: string[] = [];
infoBody += `\n## Estimated Wait Times`;
infoBody += `\n## Estimated Wait Times\nIf the AI is busy, your prompt will be processed when a slot frees up.`;

if (config.openaiKey) {
// TODO: un-fuck this
@@ -532,13 +432,6 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
if (hasGpt432k && allowedGpt432k) {
waits.push(`**GPT-4-32k:** ${gpt432kWait}`);
}

const dalleWait = getQueueInformation("dall-e").estimatedQueueTime;
const hasDalle = keys.some((k) => k.modelFamilies.includes("dall-e"));
const allowedDalle = config.allowedModelFamilies.includes("dall-e");
if (hasDalle && allowedDalle) {
waits.push(`**DALL-E:** ${dalleWait}`);
}
}

if (config.anthropicKey) {
@@ -553,10 +446,9 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon

infoBody += "\n\n" + waits.join(" / ");

infoBody += customGreeting;

infoBody += buildRecentImageSection();

if (customGreeting) {
infoBody += `\n## Server Greeting\n${customGreeting}`;
}
return converter.makeHtml(infoBody);
}

@@ -599,44 +491,9 @@ function getServerTitle() {
return "OAI Reverse Proxy";
}

function buildRecentImageSection() {
if (
!config.allowedModelFamilies.includes("dall-e") ||
!config.showRecentImages
) {
return "";
}

let html = `<h2>Recent DALL-E Generations</h2>`;
const recentImages = getLastNImages(12).reverse();
if (recentImages.length === 0) {
html += `<p>No images yet.</p>`;
return html;
}

html += `<div style="display: flex; flex-wrap: wrap;" id="recent-images">`;
for (const { url, prompt } of recentImages) {
const thumbUrl = url.replace(/\.png$/, "_t.jpg");
const escapedPrompt = escapeHtml(prompt);
html += `<div style="margin: 0.5em;" class="recent-image">
<a href="${url}" target="_blank"><img src="${thumbUrl}" title="${escapedPrompt}" alt="${escapedPrompt}" style="max-width: 150px; max-height: 150px;" /></a>
</div>`;
}
html += `</div>`;

return html;
}

function escapeHtml(unsafe: string) {
return unsafe
.replace(/&/g, "&amp;")
.replace(/</g, "&lt;")
.replace(/>/g, "&gt;")
.replace(/"/g, "&quot;")
.replace(/'/g, "&#39;");
}

function getExternalUrlForHuggingfaceSpaceId(spaceId: string) {
// Huggingface broke their amazon elb config and no longer sends the
// x-forwarded-host header. This is a workaround.
try {
const [username, spacename] = spaceId.split("/");
return `https://${username}-${spacename.replace(/_/g, "-")}.hf.space`;
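
// A quick usage sketch of the mapping above, with a hypothetical space id:
//
//   getExternalUrlForHuggingfaceSpaceId("someuser/my_cool_proxy")
//   // => "https://someuser-my-cool-proxy.hf.space"
//
// Underscores become hyphens since they are replaced via the /_/g pattern above.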

+21
-10
@@ -7,10 +7,13 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
applyQuotaLimits,
addAnthropicPreamble,
blockZoomerOrigins,
createPreprocessorMiddleware,
finalizeBody,
createOnProxyReqHandler,
languageFilter,
stripHeaders, createOnProxyReqHandler
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
@@ -39,9 +42,8 @@ const getModelsResponse = () => {
"claude-instant-v1.1",
"claude-instant-v1.1-100k",
"claude-instant-v1.0",
"claude-2",
"claude-2", // claude-2 is 100k by default it seems
"claude-2.0",
"claude-2.1",
];

const models = claudeVariants.map((id) => ({
@@ -85,8 +87,9 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
body = transformAnthropicResponse(body, req);
}

if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
}

res.status(200).json(body);
@@ -126,15 +129,23 @@ function transformAnthropicResponse(
};
}

const anthropicProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
const anthropicProxy = createQueueMiddleware(
createProxyMiddleware({
target: "https://api.anthropic.com",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, addAnthropicPreamble, finalizeBody],
pipeline: [
applyQuotaLimits,
addKey,
addAnthropicPreamble,
languageFilter,
blockZoomerOrigins,
stripHeaders,
finalizeBody,
],
}),
proxyRes: createOnProxyResHandler([anthropicResponseHandler]),
error: handleProxyError,
@@ -143,8 +154,8 @@ const anthropicProxy = createQueueMiddleware({
// Send OpenAI-compat requests to the real Anthropic endpoint.
"^/v1/chat/completions": "/v1/complete",
},
}),
});
})
);

const anthropicRouter = Router();
anthropicRouter.get("/v1/models", handleModelRequest);

+36
-58
@@ -7,18 +7,20 @@ import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
applyQuotaLimits,
createPreprocessorMiddleware,
stripHeaders,
signAwsRequest,
finalizeSignedRequest,
finalizeAwsRequest,
createOnProxyReqHandler,
languageFilter,
blockZoomerOrigins,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";

const LATEST_AWS_V2_MINOR_VERSION = "1";

let modelsCache: any = null;
let modelsCacheTime = 0;

@@ -29,11 +31,7 @@ const getModelsResponse = () => {

if (!config.awsCredentials) return { object: "list", data: [] };

const variants = [
"anthropic.claude-v1",
"anthropic.claude-v2",
"anthropic.claude-v2:1",
];
const variants = ["anthropic.claude-v1", "anthropic.claude-v2"];

const models = variants.map((id) => ({
id,
@@ -76,8 +74,9 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
body = transformAwsResponse(body, req);
}

if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
// TODO: Remove once tokenization is stable
if (req.debug) {
body.proxy_tokenizer_debug_info = req.debug;
}

// AWS does not confirm the model in the response, so we have to add it
@@ -120,24 +119,34 @@ function transformAwsResponse(
};
}

const awsProxy = createQueueMiddleware({
beforeProxy: signAwsRequest,
proxyMiddleware: createProxyMiddleware({
const awsProxy = createQueueMiddleware(
createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
if (!signedRequest) throw new Error("Must sign request before proxying");
if (!signedRequest) {
throw new Error("AWS requests must go through signAwsRequest first");
}
return `${signedRequest.protocol}//${signedRequest.hostname}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyReq: createOnProxyReqHandler({
pipeline: [
applyQuotaLimits,
// Credentials are added by signAwsRequest preprocessor
languageFilter,
blockZoomerOrigins,
stripHeaders,
finalizeAwsRequest,
],
}),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
}),
});
})
);

const awsRouter = Router();
awsRouter.get("/v1/models", handleModelRequest);
@@ -147,7 +156,7 @@ awsRouter.post(
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "anthropic", outApi: "anthropic", service: "aws" },
{ afterTransform: [maybeReassignModel] }
{ afterTransform: [maybeReassignModel, signAwsRequest] }
),
awsProxy
);
@@ -157,7 +166,7 @@ awsRouter.post(
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "anthropic", service: "aws" },
{ afterTransform: [maybeReassignModel] }
{ afterTransform: [maybeReassignModel, signAwsRequest] }
),
awsProxy
);
@@ -172,47 +181,16 @@ awsRouter.post(
*/
function maybeReassignModel(req: Request) {
const model = req.body.model;

// If client already specified an AWS Claude model ID, use it
if (model.includes("anthropic.claude")) {
return;
}

const pattern = /^(claude-)?(instant-)?(v)?(\d+)(\.(\d+))?(-\d+k)?$/i;
const match = model.match(pattern);

// If there's no match, return the latest v2 model
if (!match) {
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}

const [, , instant, , major, , minor] = match;

if (instant) {
req.body.model = "anthropic.claude-instant-v1";
return;
}

// There's only one v1 model
if (major === "1") {
// User's client sent an AWS model already
if (model.includes("anthropic.claude")) return;
// User's client is sending Anthropic-style model names, check for v1
if (model.match(/^claude-v?1/)) {
req.body.model = "anthropic.claude-v1";
return;
} else {
// User's client requested v2 or possibly some OpenAI model, default to v2
req.body.model = "anthropic.claude-v2";
}

// Try to map Anthropic API v2 models to AWS v2 models
if (major === "2") {
if (minor === "0") {
req.body.model = "anthropic.claude-v2";
return;
}
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
}

// Fallback to latest v2 model
req.body.model = `anthropic.claude-v2:${LATEST_AWS_V2_MINOR_VERSION}`;
return;
// TODO: Handle claude-instant
}
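
// How the regex-based branch of maybeReassignModel (the variant above that
// uses LATEST_AWS_V2_MINOR_VERSION) maps client-supplied names, worked out
// from the pattern and branches; illustrative inputs only:
//
//   "anthropic.claude-v2"  -> unchanged (already an AWS model ID)
//   "claude-instant-v1.1"  -> "anthropic.claude-instant-v1"
//   "claude-v1"            -> "anthropic.claude-v1"
//   "claude-2.0"           -> "anthropic.claude-v2"
//   "claude-2.1"           -> "anthropic.claude-v2:1"
//   "gpt-4" (no match)     -> "anthropic.claude-v2:1" (latest v2 fallback)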

export const aws = awsRouter;

@@ -1,128 +0,0 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
ModelFamily,
AzureOpenAIModelFamily,
getAzureOpenAIModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { KNOWN_OPENAI_MODELS } from "./openai";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addAzureKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";

let modelsCache: any = null;
let modelsCacheTime = 0;

function getModelsResponse() {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}

let available = new Set<AzureOpenAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "azure") continue;
key.modelFamilies.forEach((family) =>
available.add(family as AzureOpenAIModelFamily)
);
}
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));

const models = KNOWN_OPENAI_MODELS.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "azure",
permission: [
{
id: "modelperm-" + id,
object: "model_permission",
created: new Date().getTime(),
organization: "*",
group: null,
is_blocking: false,
},
],
root: id,
parent: null,
})).filter((model) => available.has(getAzureOpenAIModelFamily(model.id)));

modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();

return modelsCache;
}

const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};

const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}

if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}

if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}

res.status(200).json(body);
};

const azureOpenAIProxy = createQueueMiddleware({
beforeProxy: addAzureKey,
proxyMiddleware: createProxyMiddleware({
target: "will be set by router",
router: (req) => {
if (!req.signedRequest) throw new Error("signedRequest not set");
const { hostname, path } = req.signedRequest;
return `https://${hostname}${path}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
error: handleProxyError,
},
}),
});

const azureOpenAIRouter = Router();
azureOpenAIRouter.get("/v1/models", handleModelRequest);
azureOpenAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "openai",
outApi: "openai",
service: "azure",
}),
azureOpenAIProxy
);

export const azure = azureOpenAIRouter;

+11
-14
@@ -46,22 +46,19 @@ export const gatekeeper: RequestHandler = (req, res, next) => {
}

if (GATEKEEPER === "user_token" && token) {
const { user, result } = authenticate(token, req.ip);

switch (result) {
case "success":
req.user = user;
return next();
case "limited":
const user = authenticate(token, req.ip);
if (user) {
req.user = user;
return next();
} else {
const maybeBannedUser = getUser(token);
if (maybeBannedUser?.disabledAt) {
return res.status(403).json({
error: `Forbidden: no more IPs can authenticate with this token`,
error: `Forbidden: ${
maybeBannedUser.disabledReason || "Token disabled"
}`,
});
case "disabled":
const bannedUser = getUser(token);
if (bannedUser?.disabledAt) {
const reason = bannedUser.disabledReason || "Token disabled";
return res.status(403).json({ error: `Forbidden: ${reason}` });
}
}
}
}
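
// The rewritten gatekeeper above destructures { user, result } from
// authenticate(). A plausible shape for that return value, inferred only from
// the three case labels in the switch; the actual type lives outside this diff:
type AuthenticateResultSketch = {
  user?: unknown; // the matched user record, present on "success"
  result: "success" | "limited" | "disabled";
};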

@@ -4,15 +4,16 @@ import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
import { buildFakeSse } from "../../shared/streaming";
import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
import { QuotaExceededError } from "./request/apply-quota-limits";

const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
const OPENAI_TEXT_COMPLETION_ENDPOINT = "/v1/completions";
const OPENAI_EMBEDDINGS_ENDPOINT = "/v1/embeddings";
const OPENAI_IMAGE_COMPLETION_ENDPOINT = "/v1/images/generations";
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";

export function isTextGenerationRequest(req: Request) {
/** Returns true if we're making a request to a completion endpoint. */
export function isCompletionRequest(req: Request) {
// 99% sure this function is not needed anymore
return (
req.method === "POST" &&
[
@@ -23,13 +24,6 @@ export function isTextGenerationRequest(req: Request) {
);
}

export function isImageGenerationRequest(req: Request) {
return (
req.method === "POST" &&
req.path.startsWith(OPENAI_IMAGE_COMPLETION_ENDPOINT)
);
}

export function isEmbeddingsRequest(req: Request) {
return (
req.method === "POST" && req.path.startsWith(OPENAI_EMBEDDINGS_ENDPOINT)
@@ -59,8 +53,8 @@ export function writeErrorResponse(
res.write(`data: [DONE]\n\n`);
res.end();
} else {
if (req.tokenizerInfo && typeof errorPayload.error === "object") {
errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
if (req.debug && errorPayload.error) {
errorPayload.error.proxy_tokenizer_debug_info = req.debug;
}
res.status(statusCode).json(errorPayload);
}
@@ -96,7 +90,7 @@ function classifyError(err: Error): {
} & Record<string, any> {
const defaultError = {
status: 500,
userMessage: `Reverse proxy error: ${err.message}`,
userMessage: `Reverse proxy encountered an unexpected error. (${err.message})`,
type: "proxy_internal_error",
stack: err.stack,
};
@@ -109,7 +103,7 @@ function classifyError(err: Error): {
code: { enabled: false },
maxErrors: 3,
transform: ({ issue, ...rest }) => {
return `At '${rest.pathComponent}': ${issue.message}`;
return `At '${rest.pathComponent}', ${issue.message}`;
},
});
return { status: 400, userMessage, type: "proxy_validation_error" };
@@ -179,8 +173,6 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
return body.completion.trim();
case "google-palm":
return body.candidates[0].output;
case "openai-image":
return body.data?.map((item: any) => item.url).join("\n");
default:
assertNever(format);
}
@@ -192,8 +184,6 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
case "openai":
case "openai-text":
return body.model;
case "openai-image":
return req.body.model;
case "anthropic":
// Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model;

+5
-5
@@ -1,17 +1,17 @@
import { AnthropicKey, Key } from "../../../../shared/key-management";
import { isTextGenerationRequest } from "../../common";
import { HPMRequestCallback } from "../index";
import { AnthropicKey, Key } from "../../../shared/key-management";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";

/**
 * Some keys require the prompt to start with `\n\nHuman:`. There is no way to
 * know this without trying to send the request and seeing if it fails. If a
 * key is marked as requiring a preamble, it will be added here.
 */
export const addAnthropicPreamble: HPMRequestCallback = (
export const addAnthropicPreamble: ProxyRequestMiddleware = (
_proxyReq,
req
) => {
if (!isTextGenerationRequest(req) || req.key?.service !== "anthropic") {
if (!isCompletionRequest(req) || req.key?.service !== "anthropic") {
return;
}

+23
-14
@@ -1,12 +1,24 @@
import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management";
import { isEmbeddingsRequest } from "../../common";
import { HPMRequestCallback } from "../index";
import { assertNever } from "../../../../shared/utils";
import { Key, OpenAIKey, keyPool } from "../../../shared/key-management";
import { isCompletionRequest, isEmbeddingsRequest } from "../common";
import { ProxyRequestMiddleware } from ".";
import { assertNever } from "../../../shared/utils";

/** Add a key that can service this request to the request object. */
export const addKey: HPMRequestCallback = (proxyReq, req) => {
export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
let assignedKey: Key;

if (!isCompletionRequest(req)) {
// Horrible, horrible hack to stop the proxy from complaining about clients
// not sending a model when they are requesting the list of models (which
// requires a key, but obviously not a model).

// I don't think this is needed anymore since models requests are no longer
// proxied to the upstream API. Everything going through this is either a
// completion request or a special case like OpenAI embeddings.
req.log.warn({ path: req.path }, "addKey called on non-completion request");
req.body.model = "gpt-3.5-turbo";
}

if (!req.inboundApi || !req.outboundApi) {
const err = new Error(
"Request API format missing. Did you forget to add the request preprocessor to your router?"
@@ -22,6 +34,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
throw new Error("You must specify a model with your request.");
}

// TODO: use separate middleware to deal with stream flags
req.isStreaming = req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;

if (req.inboundApi === req.outboundApi) {
assignedKey = keyPool.get(req.body.model);
} else {
@@ -42,9 +58,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
throw new Error(
"OpenAI Chat as an API translation target is not supported"
);
case "openai-image":
assignedKey = keyPool.get("dall-e-3");
break;
default:
assertNever(req.outboundApi);
}
@@ -80,10 +93,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
`?key=${assignedKey.key}`
);
break;
case "azure":
const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey);
break;
case "aws":
throw new Error(
"add-key should not be used for AWS security credentials. Use sign-aws-request instead."
@@ -97,7 +106,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
 * Special case for embeddings requests which don't go through the normal
 * request pipeline.
 */
export const addKeyForEmbeddingsRequest: HPMRequestCallback = (
export const addKeyForEmbeddingsRequest: ProxyRequestMiddleware = (
proxyReq,
req
) => {
@@ -111,7 +120,7 @@ export const addKeyForEmbeddingsRequest: HPMRequestCallback = (
throw new Error("Embeddings requests must be from OpenAI");
}

req.body = { input: req.body.input, model: "text-embedding-ada-002" };
req.body = { input: req.body.input, model: "text-embedding-ada-002" }

const key = keyPool.get("text-embedding-ada-002") as OpenAIKey;

@@ -0,0 +1,30 @@
import { hasAvailableQuota } from "../../../shared/users/user-store";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";

export class QuotaExceededError extends Error {
public quotaInfo: any;
constructor(message: string, quotaInfo: any) {
super(message);
this.name = "QuotaExceededError";
this.quotaInfo = quotaInfo;
}
}

export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!isCompletionRequest(req) || !req.user) {
return;
}

const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) {
throw new QuotaExceededError(
"You have exceeded your proxy token quota for this model.",
{
quota: req.user.tokenLimits,
used: req.user.tokenCounts,
requested: requestedTokens,
}
);
}
};
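
// A rough sketch of the quota check applyQuotaLimits relies on, assuming
// per-model-family token counts and limits on the user record, with a quota
// of 0 meaning unlimited; the real hasAvailableQuota lives in user-store and
// may differ in details:
function hasAvailableQuotaSketch(
  limits: Record<string, number>,
  counts: Record<string, number>,
  family: string,
  requested: number
): boolean {
  const limit = limits[family] ?? 0;
  if (limit === 0) return true; // zero quota is treated as unlimited
  return (counts[family] ?? 0) + requested <= limit;
}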

+7
-2
@@ -1,4 +1,5 @@
import { HPMRequestCallback } from "../index";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";

const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");

@@ -13,7 +14,11 @@ class ForbiddenError extends Error {
 * Blocks requests from Janitor AI users with a fake, scary error message so I
 * stop getting emails asking for tech support.
 */
export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => {
export const blockZoomerOrigins: ProxyRequestMiddleware = (_proxyReq, req) => {
if (!isCompletionRequest(req)) {
return;
}

const origin = req.headers.origin || req.headers.referer;
if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) {
// Venus-derivatives send a test prompt to check if the proxy is working.

+8
-13
@@ -1,7 +1,6 @@
import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
import type { OpenAIChatMessage } from "./transform-outbound-payload";
import { RequestPreprocessor } from "./index";
import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
import { assertNever } from "../../../shared/utils";

/**
 * Given a request with an already-transformed body, counts the number of
@@ -14,7 +13,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
switch (service) {
case "openai": {
req.outputTokens = req.body.max_tokens;
const prompt: OpenAIChatMessage[] = req.body.messages;
const prompt: OpenAIPromptMessage[] = req.body.messages;
result = await countTokens({ req, prompt, service });
break;
}
@@ -36,18 +35,14 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
case "openai-image": {
req.outputTokens = 1;
result = await countTokens({ req, service });
break;
}
default:
assertNever(service);
}

req.promptTokens = result.token_count;

// TODO: Remove once token counting is stable
req.log.debug({ result: result }, "Counted prompt tokens.");
req.tokenizerInfo = req.tokenizerInfo ?? {};
req.tokenizerInfo = { ...req.tokenizerInfo, ...result };
};
req.debug = req.debug ?? {};
req.debug = { ...req.debug, ...result };
};

+5
-5
@@ -1,11 +1,11 @@
import type { HPMRequestCallback } from "../index";
import type { ProxyRequestMiddleware } from ".";

/**
 * For AWS/Azure requests, the body is signed earlier in the request pipeline,
 * before the proxy middleware. This function just assigns the path and headers
 * to the proxy request.
 * For AWS requests, the body is signed earlier in the request pipeline, before
 * the proxy middleware. This function just assigns the path and headers to the
 * proxy request.
 */
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
if (!req.signedRequest) {
throw new Error("Expected req.signedRequest to be set");
}

+2
-7
@@ -1,14 +1,9 @@
import { fixRequestBody } from "http-proxy-middleware";
import type { HPMRequestCallback } from "../index";
import type { ProxyRequestMiddleware } from ".";

/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: HPMRequestCallback = (proxyReq, req) => {
export const finalizeBody: ProxyRequestMiddleware = (proxyReq, req) => {
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
// For image generation requests, remove stream flag.
if (req.outboundApi === "openai-image") {
delete req.body.stream;
}

const updatedBody = JSON.stringify(req.body);
proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
(req as any).rawBody = Buffer.from(updatedBody);
@@ -2,30 +2,29 @@ import type { Request } from "express";
import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";

export { createOnProxyReqHandler } from "./onproxyreq-factory";
export { createOnProxyReqHandler } from "./rewrite";
export {
createPreprocessorMiddleware,
createEmbeddingsPreprocessorMiddleware,
} from "./preprocessor-factory";
} from "./preprocess";

// Express middleware (runs before http-proxy-middleware, can be async)
export { addAzureKey } from "./preprocessors/add-azure-key";
export { applyQuotaLimits } from "./preprocessors/apply-quota-limits";
export { validateContextSize } from "./preprocessors/validate-context-size";
export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
export { languageFilter } from "./preprocessors/language-filter";
export { setApiFormat } from "./preprocessors/set-api-format";
export { signAwsRequest } from "./preprocessors/sign-aws-request";
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
export { applyQuotaLimits } from "./apply-quota-limits";
export { validateContextSize } from "./validate-context-size";
export { countPromptTokens } from "./count-prompt-tokens";
export { setApiFormat } from "./set-api-format";
export { signAwsRequest } from "./sign-aws-request";
export { transformOutboundPayload } from "./transform-outbound-payload";

// http-proxy-middleware callbacks (runs on onProxyReq, cannot be async)
export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key";
export { addAnthropicPreamble } from "./onproxyreq/add-anthropic-preamble";
export { blockZoomerOrigins } from "./onproxyreq/block-zoomer-origins";
export { checkModelFamily } from "./onproxyreq/check-model-family";
export { finalizeBody } from "./onproxyreq/finalize-body";
export { finalizeSignedRequest } from "./onproxyreq/finalize-signed-request";
export { stripHeaders } from "./onproxyreq/strip-headers";
// HPM middleware (runs on onProxyReq, cannot be async)
export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
export { addAnthropicPreamble } from "./add-anthropic-preamble";
export { blockZoomerOrigins } from "./block-zoomer-origins";
export { finalizeBody } from "./finalize-body";
export { finalizeAwsRequest } from "./finalize-aws-request";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { stripHeaders } from "./strip-headers";

/**
 * Middleware that runs prior to the request being handled by http-proxy-
@@ -44,7 +43,7 @@ export { stripHeaders } from "./onproxyreq/strip-headers";
export type RequestPreprocessor = (req: Request) => void | Promise<void>;

/**
 * Callbacks that run immediately before the request is sent to the API in
 * Middleware that runs immediately before the request is sent to the API in
 * response to http-proxy-middleware's `proxyReq` event.
 *
 * Async functions cannot be used here as HPM's event emitter is not async and
@@ -54,7 +53,7 @@ export type RequestPreprocessor = (req: Request) => void | Promise<void>;
 * first attempt is rate limited and the request is automatically retried by the
 * request queue middleware.
 */
export type HPMRequestCallback = ProxyReqCallback<ClientRequest, Request>;
export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;

export const forceModel = (model: string) => (req: Request) =>
void (req.body.model = model);
|
||||
|
||||
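The split this index encodes is worth spelling out with a sketch (my own, under the type names shown above; the header is made up): preprocessors run as ordinary Express middleware and may be async, while proxyReq callbacks fire inside http-proxy-middleware's synchronous event emitter, so a returned Promise would simply be ignored.

```ts
import type { Request } from "express";

type RequestPreprocessor = (req: Request) => void | Promise<void>;

// Fine: async work (tokenization, key lookups) happens before proxying.
const countSomething: RequestPreprocessor = async (req) => {
  req.body.counted = true; // hypothetical field
};

// Must stay synchronous: throw to abort; awaiting here would be a no-op
// from the emitter's perspective.
const tagRequest = (proxyReq: { setHeader(k: string, v: string): void }) => {
  proxyReq.setHeader("x-example", "1"); // hypothetical header
};
```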
@@ -0,0 +1,56 @@
import { Request } from "express";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { assertNever } from "../../../shared/utils";
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";

const DISALLOWED_REGEX =
  /[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;

// Our shitty free-tier VMs will fall over if we test every single character in
// each 15k character request ten times a second. So we'll just sample 20% of
// the characters and hope that's enough.
const containsDisallowedCharacters = (text: string) => {
  const sampleSize = Math.ceil(text.length * 0.2);
  const sample = text
    .split("")
    .sort(() => 0.5 - Math.random())
    .slice(0, sampleSize)
    .join("");
  return DISALLOWED_REGEX.test(sample);
};

/** Block requests containing too many disallowed characters. */
export const languageFilter: ProxyRequestMiddleware = (_proxyReq, req) => {
  if (!config.rejectDisallowed) {
    return;
  }

  if (isCompletionRequest(req)) {
    const combinedText = getPromptFromRequest(req);
    if (containsDisallowedCharacters(combinedText)) {
      logger.warn(`Blocked request containing bad characters`);
      _proxyReq.destroy(new Error(config.rejectMessage));
    }
  }
};

function getPromptFromRequest(req: Request) {
  const service = req.outboundApi;
  const body = req.body;
  switch (service) {
    case "anthropic":
      return body.prompt;
    case "openai":
      return body.messages
        .map((m: { content: string }) => m.content)
        .join("\n");
    case "openai-text":
      return body.prompt;
    case "google-palm":
      return body.prompt.text;
    default:
      assertNever(service);
  }
}
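A rough back-of-envelope check of the 20% sampling trade-off the comment above describes (my own arithmetic, not project code): if a prompt contains k disallowed characters, a random 20% sample misses all of them with probability about 0.8^k, independent of prompt length. One stray character often slips through, but real CJK text is caught reliably.

```ts
// Sketch: probability that a 20% sample contains no disallowed characters.
const missProbability = (k: number) => Math.pow(0.8, k);
console.log(missProbability(1));  // ~0.8   — a lone character usually escapes
console.log(missProbability(20)); // ~0.012 — a CJK sentence almost never does
```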
@@ -0,0 +1,16 @@
import { isCompletionRequest } from "../common";
import { ProxyRequestMiddleware } from ".";

/**
 * Don't allow multiple completions to be requested to prevent abuse.
 * OpenAI-only, Anthropic provides no such parameter.
 **/
export const limitCompletions: ProxyRequestMiddleware = (_proxyReq, req) => {
  if (isCompletionRequest(req) && req.outboundApi === "openai") {
    const originalN = req.body?.n || 1;
    req.body.n = 1;
    if (originalN !== req.body.n) {
      req.log.warn(`Limiting completion choices from ${originalN} to 1`);
    }
  }
};
@@ -1,43 +0,0 @@
import {
  applyQuotaLimits,
  blockZoomerOrigins,
  checkModelFamily,
  HPMRequestCallback,
  stripHeaders,
} from "./index";

type ProxyReqHandlerFactoryOptions = { pipeline: HPMRequestCallback[] };

/**
 * Returns an http-proxy-middleware request handler that runs the given set of
 * onProxyReq callback functions in sequence.
 *
 * These will run each time a request is proxied, including on automatic retries
 * by the queue after encountering a rate limit.
 */
export const createOnProxyReqHandler = ({
  pipeline,
}: ProxyReqHandlerFactoryOptions): HPMRequestCallback => {
  const callbackPipeline = [
    checkModelFamily,
    applyQuotaLimits,
    blockZoomerOrigins,
    stripHeaders,
    ...pipeline,
  ];
  return (proxyReq, req, res, options) => {
    // The streaming flag must be set before any other onProxyReq handler runs,
    // as it may influence the behavior of subsequent handlers.
    // Image generation requests can't be streamed.
    req.isStreaming = req.body.stream === true || req.body.stream === "true";
    req.body.stream = req.isStreaming;

    try {
      for (const fn of callbackPipeline) {
        fn(proxyReq, req, res, options);
      }
    } catch (error) {
      proxyReq.destroy(error);
    }
  };
};
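The factory above is an instance of a general pattern: compose small synchronous callbacks into one handler and convert any thrown error into an aborted request. A generic sketch (mine, not project code):

```ts
type Callback<A extends unknown[]> = (...args: A) => void;

function compose<A extends unknown[]>(
  steps: Callback<A>[],
  onError: (err: unknown) => void
): Callback<A> {
  return (...args) => {
    try {
      // Run every step in order; any step may throw to veto the request.
      for (const step of steps) step(...args);
    } catch (err) {
      onError(err); // e.g. proxyReq.destroy(err)
    }
  };
}
```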
@@ -1,13 +0,0 @@
import { HPMRequestCallback } from "../index";
import { config } from "../../../../config";
import { getModelFamilyForRequest } from "../../../../shared/models";

/**
 * Ensures the selected model family is enabled by the proxy configuration.
 **/
export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => {
  const family = getModelFamilyForRequest(req);
  if (!config.allowedModelFamilies.includes(family)) {
    throw new Error(`Model family ${family} is not permitted on this proxy`);
  }
};
+2
-24
@@ -7,9 +7,7 @@ import {
  countPromptTokens,
  setApiFormat,
  transformOutboundPayload,
  languageFilter,
} from ".";
import { ZodIssue } from "zod";

type RequestPreprocessorOptions = {
  /**
@@ -29,14 +27,6 @@ type RequestPreprocessorOptions = {
/**
 * Returns a middleware function that processes the request body into the given
 * API format, and then sequentially runs the given additional preprocessors.
 *
 * These run first in the request lifecycle, a single time per request before it
 * is added to the request queue. They aren't run again if the request is
 * re-attempted after a rate limit.
 *
 * To run a preprocessor on every re-attempt, pass it to createQueueMiddleware.
 * It will run after these preprocessors, but before the request is sent to
 * http-proxy-middleware.
 */
export const createPreprocessorMiddleware = (
  apiFormat: Parameters<typeof setApiFormat>[0],
@@ -47,7 +37,6 @@ export const createPreprocessorMiddleware = (
    ...(beforeTransform ?? []),
    transformOutboundPayload,
    countPromptTokens,
    languageFilter,
    ...(afterTransform ?? []),
    validateContextSize,
  ];
@@ -77,25 +66,14 @@ async function executePreprocessors(
    }
    next();
  } catch (error) {
    if (error.constructor.name === "ZodError") {
      const msg = error?.issues
        ?.map((issue: ZodIssue) => issue.message)
        .join("; ");
      req.log.info(msg, "Prompt validation failed.");
    } else {
      req.log.error(error, "Error while executing request preprocessor");
    }
    req.log.error(error, "Error while executing request preprocessor");

    // If the requester has opted into streaming, the client probably won't
    // handle a non-eventstream response, but we haven't initialized the SSE
    // stream yet as that is typically done later by the request queue. We'll
    // do that here and then call classifyErrorAndSend to use the streaming
    // error handler.
    const { stream } = req.body;
    const isStreaming = stream === "true" || stream === true;
    if (isStreaming && !res.headersSent) {
      initializeSseStream(res);
    }
    initializeSseStream(res)
    classifyErrorAndSend(error as Error, req, res);
  }
}
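A small standalone illustration of the ZodError branch above (assuming zod is installed; exact messages vary by zod version): a failed `safeParse` carries an `issues` array, and joining the per-issue messages yields the single human-readable line that gets logged.

```ts
import { z } from "zod";

const schema = z.object({ model: z.string(), max_tokens: z.number() });
const result = schema.safeParse({ max_tokens: "lots" });
if (!result.success) {
  // e.g. "Required; Expected number, received string"
  const msg = result.error.issues.map((i) => i.message).join("; ");
  console.log(msg);
}
```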
@@ -1,50 +0,0 @@
import { AzureOpenAIKey, keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";

export const addAzureKey: RequestPreprocessor = (req) => {
  const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
  const serviceValid = req.service === "azure";
  if (!apisValid || !serviceValid) {
    throw new Error("addAzureKey called on invalid request");
  }

  if (!req.body?.model) {
    throw new Error("You must specify a model with your request.");
  }

  const model = req.body.model.startsWith("azure-")
    ? req.body.model
    : `azure-${req.body.model}`;

  req.key = keyPool.get(model);
  req.body.model = model;

  req.log.info(
    { key: req.key.hash, model },
    "Assigned Azure OpenAI key to request"
  );

  const cred = req.key as AzureOpenAIKey;
  const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);

  req.signedRequest = {
    method: "POST",
    protocol: "https:",
    hostname: `${resourceName}.openai.azure.com`,
    path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
    headers: {
      ["host"]: `${resourceName}.openai.azure.com`,
      ["content-type"]: "application/json",
      ["api-key"]: apiKey,
    },
    body: JSON.stringify(req.body),
  };
};

function getCredentialsFromKey(key: AzureOpenAIKey) {
  const [resourceName, deploymentId, apiKey] = key.key.split(":");
  if (!resourceName || !deploymentId || !apiKey) {
    throw new Error("Assigned Azure OpenAI key is not in the correct format.");
  }
  return { resourceName, deploymentId, apiKey };
}
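The colon-delimited credential format the preprocessor above expects, in isolation (all values here are made up):

```ts
// "resourceName:deploymentId:apiKey" — hypothetical credential string.
const key = "myresource:gpt4-deployment:0123456789abcdef";
const [resourceName, deploymentId, apiKey] = key.split(":");
// The request is then addressed to:
//   https://myresource.openai.azure.com/openai/deployments/gpt4-deployment/
//     chat/completions?api-version=2023-09-01-preview
console.log(resourceName, deploymentId, apiKey.length);
```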
@@ -1,37 +0,0 @@
import { hasAvailableQuota } from "../../../../shared/users/user-store";
import { isImageGenerationRequest, isTextGenerationRequest } from "../../common";
import { HPMRequestCallback } from "../index";

export class QuotaExceededError extends Error {
  public quotaInfo: any;
  constructor(message: string, quotaInfo: any) {
    super(message);
    this.name = "QuotaExceededError";
    this.quotaInfo = quotaInfo;
  }
}

export const applyQuotaLimits: HPMRequestCallback = (_proxyReq, req) => {
  const subjectToQuota =
    isTextGenerationRequest(req) || isImageGenerationRequest(req);
  if (!subjectToQuota || !req.user) return;

  const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  if (
    !hasAvailableQuota({
      userToken: req.user.token,
      model: req.body.model,
      api: req.outboundApi,
      requested: requestedTokens,
    })
  ) {
    throw new QuotaExceededError(
      "You have exceeded your proxy token quota for this model.",
      {
        quota: req.user.tokenLimits,
        used: req.user.tokenCounts,
        requested: requestedTokens,
      }
    );
  }
};
@@ -1,76 +0,0 @@
import { Request } from "express";
import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { UserInputError } from "../../../../shared/errors";
import { OpenAIChatMessage } from "./transform-outbound-payload";

const rejectedClients = new Map<string, number>();

setInterval(() => {
  rejectedClients.forEach((count, ip) => {
    if (count > 0) {
      rejectedClients.set(ip, Math.floor(count / 2));
    } else {
      rejectedClients.delete(ip);
    }
  });
}, 30000);

/**
 * Block requests containing blacklisted phrases. Repeated rejections from the
 * same IP address will be throttled.
 */
export const languageFilter: RequestPreprocessor = async (req) => {
  if (!config.rejectPhrases.length) return;

  const prompt = getPromptFromRequest(req);
  const match = config.rejectPhrases.find((phrase) =>
    prompt.match(new RegExp(phrase, "i"))
  );

  if (match) {
    const ip = req.ip;
    const rejections = (rejectedClients.get(req.ip) || 0) + 1;
    const delay = Math.min(60000, Math.pow(2, rejections - 1) * 1000);
    rejectedClients.set(ip, rejections);
    req.log.warn(
      { match, ip, rejections, delay },
      "Prompt contains rejected phrase"
    );
    await new Promise((resolve) => {
      req.res!.once("close", resolve);
      setTimeout(resolve, delay);
    });
    throw new UserInputError(config.rejectMessage);
  }
};

function getPromptFromRequest(req: Request) {
  const service = req.outboundApi;
  const body = req.body;
  switch (service) {
    case "anthropic":
      return body.prompt;
    case "openai":
      return body.messages
        .map((msg: OpenAIChatMessage) => {
          const text = Array.isArray(msg.content)
            ? msg.content
                .map((c) => {
                  if ("text" in c) return c.text;
                })
                .join()
            : msg.content;
          return `${msg.role}: ${text}`;
        })
        .join("\n\n");
    case "openai-text":
    case "openai-image":
      return body.prompt;
    case "google-palm":
      return body.prompt.text;
    default:
      assertNever(service);
  }
}
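The throttling curve used above, pulled out for clarity (my sketch): the nth rejection from one IP delays the response by min(60s, 2^(n-1) seconds), and the half-life decay loop lets well-behaved clients recover over time.

```ts
const delayMs = (rejections: number) =>
  Math.min(60_000, Math.pow(2, rejections - 1) * 1000);

console.log([1, 2, 3, 4, 5, 6, 7].map(delayMs));
// [1000, 2000, 4000, 8000, 16000, 32000, 60000] — capped at one minute
```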
@@ -0,0 +1,35 @@
import { Request } from "express";
import { ClientRequest } from "http";
import httpProxy from "http-proxy";
import { ProxyRequestMiddleware } from "./index";

type ProxyReqCallback = httpProxy.ProxyReqCallback<ClientRequest, Request>;
type RewriterOptions = {
  beforeRewrite?: ProxyReqCallback[];
  pipeline: ProxyRequestMiddleware[];
};

export const createOnProxyReqHandler = ({
  beforeRewrite = [],
  pipeline,
}: RewriterOptions): ProxyReqCallback => {
  return (proxyReq, req, res, options) => {
    try {
      for (const validator of beforeRewrite) {
        validator(proxyReq, req, res, options);
      }
    } catch (error) {
      req.log.error(error, "Error while executing proxy request validator");
      proxyReq.destroy(error);
    }

    try {
      for (const rewriter of pipeline) {
        rewriter(proxyReq, req, res, options);
      }
    } catch (error) {
      req.log.error(error, "Error while executing proxy request rewriter");
      proxyReq.destroy(error);
    }
  };
};
+4
-4
@@ -1,13 +1,13 @@
import { Request } from "express";
import { APIFormat, LLMService } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { APIFormat, LLMService } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";

export const setApiFormat = (api: {
  inApi: Request["inboundApi"];
  outApi: APIFormat;
  service: LLMService,
  service: LLMService;
}): RequestPreprocessor => {
  return function configureRequestApiFormat (req) {
  return function configureRequestApiFormat(req) {
    req.inboundApi = api.inApi;
    req.outboundApi = api.outApi;
    req.service = api.service;
+3
-3
@@ -2,12 +2,12 @@ import express from "express";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { keyPool } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";
import { AnthropicV1CompleteSchema } from "./transform-outbound-payload";

const AMZ_HOST =
  process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
  process.env.AMZ_HOST || "invoke-bedrock.%REGION%.amazonaws.com";

/**
 * Signs an outgoing AWS request with the appropriate headers and modifies the
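A condensed sketch of what a SigV4 signing step looks like with the same libraries this file imports; the credentials, region, and model path below are placeholders, not values from this diff.

```ts
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";

async function signBedrockRequest(body: string) {
  const signer = new SignatureV4({
    credentials: { accessKeyId: "AKIA...", secretAccessKey: "..." }, // placeholders
    region: "us-east-1",
    service: "bedrock",
    sha256: Sha256,
  });
  const request = new HttpRequest({
    method: "POST",
    protocol: "https:",
    hostname: "bedrock-runtime.us-east-1.amazonaws.com",
    path: "/model/anthropic.claude-v2/invoke", // hypothetical model path
    headers: {
      host: "bedrock-runtime.us-east-1.amazonaws.com",
      "content-type": "application/json",
    },
    body,
  });
  // Returns a copy of the request with Authorization and x-amz-* headers set.
  return signer.sign(request);
}
```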
+2
-2
@@ -1,10 +1,10 @@
import { HPMRequestCallback } from "../index";
import { ProxyRequestMiddleware } from ".";

/**
 * Removes origin and referer headers before sending the request to the API for
 * privacy reasons.
 **/
export const stripHeaders: HPMRequestCallback = (proxyReq) => {
export const stripHeaders: ProxyRequestMiddleware = (proxyReq) => {
  proxyReq.setHeader("origin", "");
  proxyReq.setHeader("referer", "");
+20
-118
@@ -1,15 +1,14 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../../config";
import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
import { RequestPreprocessor } from "../index";
import { APIFormat } from "../../../../shared/key-management";
import { config } from "../../../config";
import { OpenAIPromptMessage } from "../../../shared/tokenization";
import { isCompletionRequest } from "../common";
import { RequestPreprocessor } from ".";
import { APIFormat } from "../../../shared/key-management";

const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;

// TODO: move schemas to shared

// https://console.anthropic.com/docs/api/reference#-v1-complete
export const AnthropicV1CompleteSchema = z.object({
  model: z.string(),
@@ -30,25 +29,12 @@ export const AnthropicV1CompleteSchema = z.object({
});

// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatContentArraySchema = z.array(
  z.union([
    z.object({ type: z.literal("text"), text: z.string() }),
    z.object({
      type: z.literal("image_url"),
      image_url: z.object({
        url: z.string().url(),
        detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
      }),
    }),
  ])
);

export const OpenAIV1ChatCompletionSchema = z.object({
const OpenAIV1ChatCompletionSchema = z.object({
  model: z.string(),
  messages: z.array(
    z.object({
      role: z.enum(["system", "user", "assistant"]),
      content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
      content: z.string(),
      name: z.string().optional(),
    }),
    {
@@ -79,13 +65,8 @@ export const OpenAIV1ChatCompletionSchema = z.object({
  presence_penalty: z.number().optional().default(0),
  logit_bias: z.any().optional(),
  user: z.string().optional(),
  seed: z.number().int().optional(),
});

export type OpenAIChatMessage = z.infer<
  typeof OpenAIV1ChatCompletionSchema
>["messages"][0];

const OpenAIV1TextCompletionSchema = z
  .object({
    model: z
@@ -106,21 +87,6 @@ const OpenAIV1TextCompletionSchema = z
  })
  .merge(OpenAIV1ChatCompletionSchema.omit({ messages: true }));

// https://platform.openai.com/docs/api-reference/images/create
const OpenAIV1ImagesGenerationSchema = z.object({
  prompt: z.string().max(4000),
  model: z.string().optional(),
  quality: z.enum(["standard", "hd"]).optional().default("standard"),
  n: z.number().int().min(1).max(4).optional().default(1),
  response_format: z.enum(["url", "b64_json"]).optional(),
  size: z
    .enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
    .optional()
    .default("1024x1024"),
  style: z.enum(["vivid", "natural"]).optional().default("vivid"),
  user: z.string().optional(),
});

// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
const PalmV1GenerateTextSchema = z.object({
  model: z.string(),
@@ -143,7 +109,6 @@ const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
  anthropic: AnthropicV1CompleteSchema,
  openai: OpenAIV1ChatCompletionSchema,
  "openai-text": OpenAIV1TextCompletionSchema,
  "openai-image": OpenAIV1ImagesGenerationSchema,
  "google-palm": PalmV1GenerateTextSchema,
};

@@ -151,10 +116,11 @@ const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
  const sameService = req.inboundApi === req.outboundApi;
  const alreadyTransformed = req.retryCount > 0;
  const notTransformable =
    !isTextGenerationRequest(req) && !isImageGenerationRequest(req);
  const notTransformable = !isCompletionRequest(req);

  if (alreadyTransformed || notTransformable) return;
  if (alreadyTransformed || notTransformable) {
    return;
  }

  if (sameService) {
    const result = VALIDATORS[req.inboundApi].safeParse(req.body);
@@ -184,11 +150,6 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
    return;
  }

  if (req.inboundApi === "openai" && req.outboundApi === "openai-image") {
    req.body = openaiToOpenaiImage(req);
    return;
  }

  throw new Error(
    `'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
  );
@@ -250,7 +211,7 @@ function openaiToOpenaiText(req: Request) {
  }

  const { messages, ...rest } = result.data;
  const prompt = flattenOpenAIChatMessages(messages);
  const prompt = flattenOpenAiChatMessages(messages);

  let stops = rest.stop
    ? Array.isArray(rest.stop)
@@ -264,52 +225,6 @@ function openaiToOpenaiText(req: Request) {
  return OpenAIV1TextCompletionSchema.parse(transformed);
}

// Takes the last chat message and uses it verbatim as the image prompt.
function openaiToOpenaiImage(req: Request) {
  const { body } = req;
  const result = OpenAIV1ChatCompletionSchema.safeParse(body);
  if (!result.success) {
    req.log.warn(
      { issues: result.error.issues, body },
      "Invalid OpenAI-to-OpenAI-image request"
    );
    throw result.error;
  }

  const { messages } = result.data;
  const prompt = messages.filter((m) => m.role === "user").pop()?.content;
  if (Array.isArray(prompt)) {
    throw new Error("Image generation prompt must be a text message.");
  }

  if (body.stream) {
    throw new Error(
      "Streaming is not supported for image generation requests."
    );
  }

  // Some frontends do weird things with the prompt, like prefixing it with a
  // character name or wrapping the entire thing in quotes. We will look for
  // the index of "Image:" and use everything after that as the prompt.

  const index = prompt?.toLowerCase().indexOf("image:");
  if (index === -1 || !prompt) {
    throw new Error(
      `Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`
    );
  }

  // TODO: Add some way to specify parameters via chat message
  const transformed = {
    model: body.model.includes("dall-e") ? body.model : "dall-e-3",
    quality: "standard",
    size: "1024x1024",
    response_format: "url",
    prompt: prompt.slice(index! + 6).trim(),
  };
  return OpenAIV1ImagesGenerationSchema.parse(transformed);
}

function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
  const { body } = req;
  const result = OpenAIV1ChatCompletionSchema.safeParse({
@@ -325,7 +240,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
  }

  const { messages, ...rest } = result.data;
  const prompt = flattenOpenAIChatMessages(messages);
  const prompt = flattenOpenAiChatMessages(messages);

  let stops = rest.stop
    ? Array.isArray(rest.stop)
@@ -357,7 +272,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
  };
}

export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
  return (
    messages
      .map((m) => {
@@ -369,17 +284,17 @@ export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
        } else if (role === "user") {
          role = "Human";
        }
        const name = m.name?.trim();
        const content = flattenOpenAIMessageContent(m.content);
        // https://console.anthropic.com/docs/prompt-design
        // `name` isn't supported by Anthropic but we can still try to use it.
        return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
        return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
          m.content
        }`;
      })
      .join("") + "\n\nAssistant:"
  );
}

function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
  // Temporary to allow experimenting with prompt strategies
  const PROMPT_VERSION: number = 1;
  switch (PROMPT_VERSION) {
@@ -396,7 +311,7 @@ function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
        } else if (role === "user") {
          role = "User";
        }
        return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
        return `\n\n${role}: ${m.content}`;
      })
      .join("") + "\n\nAssistant:"
  );
@@ -408,23 +323,10 @@ function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
      if (role === "system") {
        role = "System: ";
      }
      return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
      return `\n\n${role}${m.content}`;
    })
    .join("");
    default:
      throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
  }
}

function flattenOpenAIMessageContent(
  content: OpenAIChatMessage["content"]
): string {
  return Array.isArray(content)
    ? content
        .map((contentItem) => {
          if ("text" in contentItem) return contentItem.text;
          if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
        })
        .join("\n")
    : content;
}
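A worked example (my own, standalone) of the OpenAI-to-Claude prompt flattening shown above: roles map to "Human"/"Assistant" turns and the string ends with an open "Assistant:" turn for the model to complete.

```ts
type Msg = { role: "system" | "user" | "assistant"; content: string };

const toClaudePrompt = (messages: Msg[]) =>
  messages
    .map((m) => {
      const role =
        m.role === "assistant" ? "Assistant" : m.role === "user" ? "Human" : "System";
      return `\n\n${role}: ${m.content}`;
    })
    .join("") + "\n\nAssistant:";

console.log(toClaudePrompt([
  { role: "system", content: "Be terse." },
  { role: "user", content: "Hi!" },
]));
// "\n\nSystem: Be terse.\n\nHuman: Hi!\n\nAssistant:"
```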
+12
-18
@@ -1,8 +1,8 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";

const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
@@ -34,8 +34,6 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    case "google-palm":
      proxyMax = BISON_MAX_CONTEXT;
      break;
    case "openai-image":
      return;
    default:
      assertNever(req.outboundApi);
  }
@@ -44,10 +42,6 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
  let modelMax: number;
  if (model.match(/gpt-3.5-turbo-16k/)) {
    modelMax = 16384;
  } else if (model.match(/gpt-4-1106(-preview)?/)) {
    modelMax = 131072;
  } else if (model.match(/^gpt-4(-\d{4})?-vision(-preview)?$/)) {
    modelMax = 131072;
  } else if (model.match(/gpt-3.5-turbo/)) {
    modelMax = 4096;
  } else if (model.match(/gpt-4-32k/)) {
@@ -58,18 +52,18 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    modelMax = 100000;
  } else if (model.match(/^claude-(?:instant-)?v1(?:\.\d)?$/)) {
    modelMax = 9000;
  } else if (model.match(/^claude-2\.0/)) {
    modelMax = 100000;
  } else if (model.match(/^claude-2/)) {
    modelMax = 200000;
    modelMax = 100000;
  } else if (model.match(/^text-bison-\d{3}$/)) {
    modelMax = BISON_MAX_CONTEXT;
  } else if (model.match(/^anthropic\.claude/)) {
    // Not sure if AWS Claude has the same context limit as Anthropic Claude.
    modelMax = 100000;
  } else {
    req.log.warn({ model }, "Unknown model, using 200k token limit.");
    modelMax = 200000;
    // Don't really want to throw here because I don't want to have to update
    // this ASAP every time a new model is released.
    req.log.warn({ model }, "Unknown model, using 100k token limit.");
    modelMax = 100000;
  }

  const finalMax = Math.min(proxyMax, modelMax);
@@ -87,10 +81,10 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
    "Prompt size validated"
  );

  req.tokenizerInfo.prompt_tokens = promptTokens;
  req.tokenizerInfo.completion_tokens = outputTokens;
  req.tokenizerInfo.max_model_tokens = modelMax;
  req.tokenizerInfo.max_proxy_tokens = proxyMax;
  req.debug.prompt_tokens = promptTokens;
  req.debug.completion_tokens = outputTokens;
  req.debug.max_model_tokens = modelMax;
  req.debug.max_proxy_tokens = proxyMax;
};

function assertRequestHasTokenCounts(
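The clamping rule above in isolation (a sketch): the effective context budget is the smaller of what the proxy operator allows and what the model physically supports, and the prompt plus requested output must fit within it.

```ts
const effectiveLimit = (proxyMax: number, modelMax: number) =>
  Math.min(proxyMax, modelMax);

const fits = (promptTokens: number, outputTokens: number, limit: number) =>
  promptTokens + outputTokens <= limit;

console.log(effectiveLimit(16384, 131072)); // 16384 — the proxy config wins
console.log(fits(15000, 2000, 16384));      // false — request is rejected
```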
@@ -1,17 +1,14 @@
import express from "express";
import { pipeline } from "stream";
import { promisify } from "util";
import {
  buildFakeSse,
  copySseResponseHeaders,
  initializeSseStream,
  initializeSseStream
} from "../../../shared/streaming";
import { enqueue } from "../../queue";
import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
import { decodeResponseBody, RawResponseBodyHandler } from ".";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { EventAggregator } from "./streaming/event-aggregator";
import { keyPool } from "../../../shared/key-management";

const pipelineAsync = promisify(pipeline);

@@ -36,7 +33,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
  }

  if (proxyRes.statusCode! > 201) {
    req.isStreaming = false;
    req.isStreaming = false; // Forces non-streaming response handler to execute
    req.log.warn(
      { statusCode: proxyRes.statusCode, key: hash },
      `Streaming request returned error status code. Falling back to non-streaming response handler.`
@@ -62,7 +59,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
  const adapter = new SSEStreamAdapter({ contentType });
  const aggregator = new EventAggregator({ format: req.outboundApi });
  const transformer = new SSEMessageTransformer({
    inputFormat: req.outboundApi,
    inputFormat: req.outboundApi, // outbound from the request's perspective
    inputApiVersion: String(req.headers["anthropic-version"]),
    logger: req.log,
    requestId: String(req.id),
@@ -82,19 +79,9 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
    res.end();
    return aggregator.getFinalResponse();
  } catch (err) {
    if (err instanceof RetryableError) {
      keyPool.markRateLimited(req.key!);
      req.log.warn(
        { key: req.key!.hash, retryCount: req.retryCount },
        `Re-enqueueing request due to retryable error during streaming response.`
      );
      req.retryCount++;
      enqueue(req);
    } else {
      const errorEvent = buildFakeSse("stream-error", err.message, req);
      res.write(`${errorEvent}data: [DONE]\n\n`);
      res.end();
    }
    const errorEvent = buildFakeSse("stream-error", err.message, req);
    res.write(`${errorEvent}data: [DONE]\n\n`);
    res.end();
    throw err;
  }
};
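The streaming path above is a classic Node stream pipeline. A sketch with stand-in transforms (PassThrough substitutes for the real adapter and transformer): raw upstream bytes flow through SSE framing, per-message transformation, and finally to the client response.

```ts
import { pipeline, PassThrough } from "stream";
import { promisify } from "util";

const pipelineAsync = promisify(pipeline);

async function run(source: NodeJS.ReadableStream, sink: NodeJS.WritableStream) {
  const adapter = new PassThrough();     // stands in for SSEStreamAdapter
  const transformer = new PassThrough(); // stands in for SSEMessageTransformer
  // Errors from any stage reject the promise, mirroring the catch block above.
  await pipelineAsync(source, adapter, transformer, sink);
}
```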
|
||||
@@ -3,9 +3,10 @@ import { Request, Response } from "express";
|
||||
import * as http from "http";
|
||||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { logger } from "../../../logger";
|
||||
import { enqueue, trackWaitTime } from "../../queue";
|
||||
import { HttpError } from "../../../shared/errors";
|
||||
import { keyPool } from "../../../shared/key-management";
|
||||
import { AnthropicKey, keyPool } from "../../../shared/key-management";
|
||||
import { getOpenAIModelFamily } from "../../../shared/models";
|
||||
import { countTokens } from "../../../shared/tokenization";
|
||||
import {
|
||||
@@ -13,16 +14,13 @@ import {
|
||||
incrementTokenCount,
|
||||
} from "../../../shared/users/user-store";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import { refundLastAttempt } from "../../rate-limit";
|
||||
import {
|
||||
getCompletionFromBody,
|
||||
isImageGenerationRequest,
|
||||
isTextGenerationRequest,
|
||||
isCompletionRequest,
|
||||
writeErrorResponse,
|
||||
} from "../common";
|
||||
import { handleStreamedResponse } from "./handle-streamed-response";
|
||||
import { logPrompt } from "./log-prompt";
|
||||
import { saveImage } from "./save-image";
|
||||
|
||||
const DECODER_MAP = {
|
||||
gzip: util.promisify(zlib.gunzip),
|
||||
@@ -36,7 +34,7 @@ const isSupportedContentEncoding = (
|
||||
return contentEncoding in DECODER_MAP;
|
||||
};
|
||||
|
||||
export class RetryableError extends Error {
|
||||
class RetryableError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
this.name = "RetryableError";
|
||||
@@ -109,7 +107,6 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
|
||||
countResponseTokens,
|
||||
incrementUsage,
|
||||
copyHttpHeaders,
|
||||
saveImage,
|
||||
logPrompt,
|
||||
...apiMiddleware
|
||||
);
|
||||
@@ -191,7 +188,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
|
||||
body = await decoder(body);
|
||||
} else {
|
||||
const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
|
||||
req.log.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
|
||||
logger.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
|
||||
writeErrorResponse(req, res, 500, {
|
||||
error: errorMessage,
|
||||
contentEncoding,
|
||||
@@ -208,7 +205,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
|
||||
return resolve(body.toString());
|
||||
} catch (error: any) {
|
||||
const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
|
||||
req.log.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
|
||||
logger.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
|
||||
writeErrorResponse(req, res, 500, { error: errorMessage });
|
||||
return reject(errorMessage);
|
||||
}
|
||||
@@ -254,7 +251,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
// Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
|
||||
const hash = req.key?.hash;
|
||||
const statusMessage = proxyRes.statusMessage || "Unknown error";
|
||||
req.log.warn({ statusCode, statusMessage, key: hash }, parseError.message);
|
||||
logger.warn({ statusCode, statusMessage, key: hash }, parseError.message);
|
||||
|
||||
const errorObject = {
|
||||
statusCode,
|
||||
@@ -271,7 +268,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
errorPayload.error?.type ||
|
||||
getAwsErrorType(proxyRes.headers["x-amzn-errortype"]);
|
||||
|
||||
req.log.warn(
|
||||
logger.warn(
|
||||
{ statusCode, type: errorType, errorPayload, key: req.key?.hash },
|
||||
`Received error response from upstream. (${proxyRes.statusMessage})`
|
||||
);
|
||||
@@ -289,18 +286,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
switch (service) {
|
||||
case "openai":
|
||||
case "google-palm":
|
||||
case "azure":
|
||||
const filteredCodes = ["content_policy_violation", "content_filter"];
|
||||
if (filteredCodes.includes(errorPayload.error?.code)) {
|
||||
errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`;
|
||||
refundLastAttempt(req);
|
||||
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
|
||||
// For some reason, some models return this 400 error instead of the
|
||||
// same 429 billing error that other models return.
|
||||
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
|
||||
} else {
|
||||
errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
|
||||
}
|
||||
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
|
||||
break;
|
||||
case "anthropic":
|
||||
case "aws":
|
||||
@@ -343,12 +329,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
case "aws":
|
||||
handleAwsRateLimitError(req, errorPayload);
|
||||
break;
|
||||
case "azure":
|
||||
handleAzureRateLimitError(req, errorPayload);
|
||||
break;
|
||||
case "google-palm":
|
||||
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
|
||||
break;
|
||||
throw new Error("Rate limit handling not implemented for PaLM");
|
||||
default:
|
||||
assertNever(service);
|
||||
}
|
||||
@@ -375,9 +357,6 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
case "aws":
|
||||
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
||||
break;
|
||||
case "azure":
|
||||
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
|
||||
break;
|
||||
default:
|
||||
assertNever(service);
|
||||
}
|
||||
@@ -428,7 +407,7 @@ function maybeHandleMissingPreambleError(
|
||||
{ key: req.key?.hash },
|
||||
"Request failed due to missing preamble. Key will be marked as such for subsequent requests."
|
||||
);
|
||||
keyPool.update(req.key!, { requiresPreamble: true });
|
||||
keyPool.update(req.key as AnthropicKey, { requiresPreamble: true });
|
||||
reenqueueRequest(req);
|
||||
throw new RetryableError("Claude request re-enqueued to add preamble.");
|
||||
} else {
|
||||
@@ -475,7 +454,6 @@ function handleOpenAIRateLimitError(
|
||||
const type = errorPayload.error?.type;
|
||||
switch (type) {
|
||||
case "insufficient_quota":
|
||||
case "invalid_request_error": // this is the billing_hard_limit_reached error seen in some cases
|
||||
// Billing quota exceeded (key is dead, disable it)
|
||||
keyPool.disable(req.key!, "quota");
|
||||
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
|
||||
@@ -492,14 +470,8 @@ function handleOpenAIRateLimitError(
|
||||
break;
|
||||
case "requests":
|
||||
case "tokens":
|
||||
keyPool.markRateLimited(req.key!);
|
||||
if (errorPayload.error?.message?.match(/on requests per day/)) {
|
||||
// This key has a very low rate limit, so we can't re-enqueue it.
|
||||
errorPayload.proxy_note = `Assigned key has reached its per-day request limit for this model. Try another model.`;
|
||||
break;
|
||||
}
|
||||
|
||||
// Per-minute request or token rate limit is exceeded, which we can retry
|
||||
keyPool.markRateLimited(req.key!);
|
||||
reenqueueRequest(req);
|
||||
throw new RetryableError("Rate-limited request re-enqueued.");
|
||||
default:
|
||||
@@ -509,39 +481,14 @@ function handleOpenAIRateLimitError(
|
||||
return errorPayload;
|
||||
}
|
||||
|
||||
function handleAzureRateLimitError(
|
||||
req: Request,
|
||||
errorPayload: ProxiedErrorPayload
|
||||
) {
|
||||
const code = errorPayload.error?.code;
|
||||
switch (code) {
|
||||
case "429":
|
||||
keyPool.markRateLimited(req.key!);
|
||||
reenqueueRequest(req);
|
||||
throw new RetryableError("Rate-limited request re-enqueued.");
|
||||
default:
|
||||
errorPayload.proxy_note = `Unrecognized rate limit error from Azure (${code}). Please report this.`;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
|
||||
if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
|
||||
if (isCompletionRequest(req)) {
|
||||
const model = req.body.model;
|
||||
const tokensUsed = req.promptTokens! + req.outputTokens!;
|
||||
req.log.debug(
|
||||
{
|
||||
model,
|
||||
tokensUsed,
|
||||
promptTokens: req.promptTokens,
|
||||
outputTokens: req.outputTokens,
|
||||
},
|
||||
`Incrementing usage for model`
|
||||
);
|
||||
keyPool.incrementUsage(req.key!, model, tokensUsed);
|
||||
if (req.user) {
|
||||
incrementPromptCount(req.user.token);
|
||||
incrementTokenCount(req.user.token, model, req.outboundApi, tokensUsed);
|
||||
incrementTokenCount(req.user.token, model, tokensUsed);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -552,12 +499,6 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
|
||||
_res,
|
||||
body
|
||||
) => {
|
||||
if (req.outboundApi === "openai-image") {
|
||||
req.outputTokens = req.promptTokens;
|
||||
req.promptTokens = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
// This function is prone to breaking if the upstream API makes even minor
|
||||
// changes to the response format, especially for SSE responses. If you're
|
||||
// seeing errors in this function, check the reassembled response body from
|
||||
@@ -572,8 +513,8 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
|
||||
{ service, tokens, prevOutputTokens: req.outputTokens },
|
||||
`Counted tokens for completion`
|
||||
);
|
||||
if (req.tokenizerInfo) {
|
||||
req.tokenizerInfo.completion_tokens = tokens;
|
||||
if (req.debug) {
|
||||
req.debug.completion_tokens = tokens;
|
||||
}
|
||||
|
||||
req.outputTokens = tokens.token_count;
|
||||
|
||||
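A sketch of the retry control flow that RetryableError enables throughout this file: handlers signal "try this request again" by throwing, the queue layer catches that specific class and re-enqueues, and every other error falls through to the client-facing error handler.

```ts
class RetryableError extends Error {}

function handle(req: { retryCount: number }, err: unknown, enqueue: () => void) {
  if (err instanceof RetryableError) {
    req.retryCount++;
    enqueue(); // retried later, e.g. on another key or after the limit window
  } else {
    throw err; // surfaced to the client error handler
  }
}
```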
@@ -4,12 +4,10 @@ import { logQueue } from "../../../shared/prompt-logging";
import {
  getCompletionFromBody,
  getModelFromBody,
  isImageGenerationRequest,
  isTextGenerationRequest,
  isCompletionRequest,
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";

/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@@ -25,11 +23,11 @@ export const logPrompt: ProxyResHandlerWithBody = async (
    throw new Error("Expected body to be an object");
  }

  const loggable =
    isTextGenerationRequest(req) || isImageGenerationRequest(req);
  if (!loggable) return;
  if (!isCompletionRequest(req)) {
    return;
  }

  const promptPayload = getPromptForRequest(req, responseBody);
  const promptPayload = getPromptForRequest(req);
  const promptFlattened = flattenMessages(promptPayload);
  const response = getCompletionFromBody(req, responseBody);
  const model = getModelFromBody(req, responseBody);
@@ -43,18 +41,12 @@ export const logPrompt: ProxyResHandlerWithBody = async (
  });
};

type OaiImageResult = {
  prompt: string;
  size: string;
  style: string;
  quality: string;
  revisedPrompt?: string;
type OaiMessage = {
  role: "user" | "assistant" | "system";
  content: string;
};

const getPromptForRequest = (
  req: Request,
  responseBody: Record<string, any>
): string | OpenAIChatMessage[] | OaiImageResult => {
const getPromptForRequest = (req: Request): string | OaiMessage[] => {
  // Since the prompt logger only runs after the request has been proxied, we
  // can assume the body has already been transformed to the target API's
  // format.
@@ -63,14 +55,6 @@ const getPromptForRequest = (
      return req.body.messages;
    case "openai-text":
      return req.body.prompt;
    case "openai-image":
      return {
        prompt: req.body.prompt,
        size: req.body.size,
        style: req.body.style,
        quality: req.body.quality,
        revisedPrompt: responseBody.data[0].revised_prompt,
      };
    case "anthropic":
      return req.body.prompt;
    case "google-palm":
@@ -80,26 +64,9 @@ const getPromptForRequest = (
  }
};

const flattenMessages = (
  val: string | OpenAIChatMessage[] | OaiImageResult
): string => {
  if (typeof val === "string") {
    return val.trim();
const flattenMessages = (messages: string | OaiMessage[]): string => {
  if (typeof messages === "string") {
    return messages.trim();
  }
  if (Array.isArray(val)) {
    return val
      .map(({ content, role }) => {
        const text = Array.isArray(content)
          ? content
              .map((c) => {
                if ("text" in c) return c.text;
                if ("image_url" in c) return "(( Attached Image ))";
              })
              .join("\n")
          : content;
        return `${role}: ${text}`;
      })
      .join("\n");
  }
  return val.prompt.trim();
  return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
};
@@ -1,27 +0,0 @@
import { ProxyResHandlerWithBody } from "./index";
import { mirrorGeneratedImage, OpenAIImageGenerationResult } from "../../../shared/file-storage/mirror-generated-image";

export const saveImage: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  _res,
  body,
) => {
  if (req.outboundApi !== "openai-image") {
    return;
  }

  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (body.data) {
    const baseUrl = req.protocol + "://" + req.get("host");
    const prompt = body.data[0].revised_prompt ?? req.body.prompt;
    await mirrorGeneratedImage(
      baseUrl,
      prompt,
      body as OpenAIImageGenerationResult
    );
  }
};
@@ -4,7 +4,7 @@ import {
  mergeEventsForAnthropic,
  mergeEventsForOpenAIChat,
  mergeEventsForOpenAIText,
  OpenAIChatCompletionStreamEvent
  OpenAIChatCompletionStreamEvent,
} from "./index";

/**
@@ -33,10 +33,9 @@ export class EventAggregator {
      case "anthropic":
        return mergeEventsForAnthropic(this.events);
      case "google-palm":
      case "openai-image":
        throw new Error(`SSE aggregation not supported for ${this.format}`);
        throw new Error("Google PaLM API does not support streaming responses");
      default:
        assertNever(this.format);
    }
  }
}
@@ -28,7 +28,6 @@ type SSEMessageTransformerOptions = TransformOptions & {
export class SSEMessageTransformer extends Transform {
  private lastPosition: number;
  private msgCount: number;
  private readonly inputFormat: APIFormat;
  private readonly transformFn: StreamingCompletionTransformer;
  private readonly log;
  private readonly fallbackId: string;
@@ -43,7 +42,6 @@ export class SSEMessageTransformer extends Transform {
      options.inputFormat,
      options.inputApiVersion
    );
    this.inputFormat = options.inputFormat;
    this.fallbackId = options.requestId;
    this.fallbackModel = options.requestedModel;
    this.log.debug(
@@ -69,24 +67,12 @@ export class SSEMessageTransformer extends Transform {
      });
      this.lastPosition = newPosition;

      // Special case for Azure OpenAI, which is 99% the same as OpenAI but
      // sometimes emits an extra event at the beginning of the stream with the
      // content moderation system's response to the prompt. A lot of frontends
      // don't expect this and neither does our event aggregator so we drop it.
      if (this.inputFormat === "openai" && this.msgCount <= 1) {
        if (originalMessage.includes("prompt_filter_results")) {
          this.log.debug("Dropping Azure OpenAI content moderation SSE event");
          return callback();
        }
      }

      this.emit("originalMessage", originalMessage);

      // Some events may not be transformed, e.g. ping events
      if (!transformedMessage) return callback();

      if (this.msgCount === 1) {
        // TODO: does this need to be skipped for passthroughToOpenAI?
        this.push(createInitialMessage(transformedMessage));
      }
      this.push(transformedMessage);
@@ -112,8 +98,7 @@ function getTransformer(
        ? anthropicV1ToOpenAI
        : anthropicV2ToOpenAI;
    case "google-palm":
    case "openai-image":
      throw new Error(`SSE transformation not supported for ${responseApi}`);
      throw new Error("Google PaLM does not support streaming responses");
    default:
      assertNever(responseApi);
  }
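What the dropped Azure event looks like in practice (the payload below is abbreviated and illustrative; the field name follows Azure OpenAI's streaming content-filter annotations): the very first SSE data frame can carry `prompt_filter_results` instead of a completion delta, so a substring check on the first message is enough to identify and skip it.

```ts
const firstFrame =
  'data: {"choices":[],"prompt_filter_results":[{"prompt_index":0}]}';

const isAzureModerationEvent = (msg: string, msgCount: number) =>
  msgCount <= 1 && msg.includes("prompt_filter_results");

console.log(isAzureModerationEvent(firstFrame, 1)); // true — drop the frame
```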
@@ -2,16 +2,12 @@ import { Transform, TransformOptions } from "stream";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../../logger";
import { RetryableError } from "../index";

const log = logger.child({ module: "sse-stream-adapter" });

type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
type AwsEventStreamMessage = {
  headers: {
    ":message-type": "event" | "exception";
    ":exception-type"?: string;
  };
  headers: { ":message-type": "event" | "exception" };
  payload: { message?: string /** base64 encoded */; bytes?: string };
};

@@ -40,25 +36,12 @@ export class SSEStreamAdapter extends Transform {
  protected processAwsEvent(event: AwsEventStreamMessage): string | null {
    const { payload, headers } = event;
    if (headers[":message-type"] === "exception" || !payload.bytes) {
      const eventStr = JSON.stringify(event);
      // Under high load, AWS can rugpull us by returning a 200 and starting the
      // stream but then immediately sending a rate limit error as the first
      // event. My guess is some race condition in their rate limiting check
      // that occurs if two requests arrive at the same time when only one
      // concurrency slot is available.
      if (headers[":exception-type"] === "throttlingException") {
        log.warn(
          { event: eventStr },
          "AWS request throttled after streaming has already started; retrying"
        );
        throw new RetryableError("AWS request throttled mid-stream");
      } else {
        log.error(
          { event: eventStr },
          "Received unexpected AWS event stream message"
        );
        return getFakeErrorCompletion("proxy AWS error", eventStr);
      }
      log.error(
        { event: JSON.stringify(event) },
        "Received bad streaming event from AWS"
      );
      const message = JSON.stringify(event);
      return getFakeErrorCompletion("proxy AWS error", message);
    } else {
      const { bytes } = payload;
      // technically this is a transformation but we don't really distinguish
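The shape of the mid-stream throttling case handled above, in isolation (a sketch; header names follow the AWS event-stream framing shown in the type): a 200 response can still deliver an exception frame as its very first event, and only the throttling variant should trigger a retry.

```ts
type AwsEvent = {
  headers: { ":message-type": "event" | "exception"; ":exception-type"?: string };
  payload: { bytes?: string };
};

const isMidStreamThrottle = (e: AwsEvent) =>
  e.headers[":message-type"] === "exception" &&
  e.headers[":exception-type"] === "throttlingException";

// A throttle frame should be retried; any other exception is terminal.
console.log(isMidStreamThrottle({
  headers: { ":message-type": "exception", ":exception-type": "throttlingException" },
  payload: {},
})); // true
```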
@@ -1,142 +0,0 @@
import { RequestHandler, Router, Request } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  addKey,
  createPreprocessorMiddleware,
  finalizeBody,
  createOnProxyReqHandler,
} from "./middleware/request";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";
import { generateModelList } from "./openai";
import {
  mirrorGeneratedImage,
  OpenAIImageGenerationResult,
} from "../shared/file-storage/mirror-generated-image";

const KNOWN_MODELS = ["dall-e-2", "dall-e-3"];

let modelListCache: any = null;
let modelListValid = 0;
const handleModelRequest: RequestHandler = (_req, res) => {
  if (new Date().getTime() - modelListValid < 1000 * 60) return modelListCache;
  const result = generateModelList(KNOWN_MODELS);
  modelListCache = { object: "list", data: result };
  modelListValid = new Date().getTime();
  res.status(200).json(modelListCache);
};

const openaiImagesResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.inboundApi === "openai") {
    req.log.info("Transforming OpenAI image response to OpenAI chat format");
    body = transformResponseForChat(body as OpenAIImageGenerationResult, req);
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(body);
};

/**
 * Transforms a DALL-E image generation response into a chat response, simply
 * embedding the image URL into the chat message as a Markdown image.
 */
function transformResponseForChat(
  imageBody: OpenAIImageGenerationResult,
  req: Request
): Record<string, any> {
  const prompt = imageBody.data[0].revised_prompt ?? req.body.prompt;
  const content = imageBody.data
    .map((item) => {
      const { url, b64_json } = item;
      if (b64_json) {
        return `![${prompt}](data:image/png;base64,${b64_json})`;
      } else {
        return `![${prompt}](${url})`;
      }
||||
})
|
||||
.join("\n\n");
|
||||
|
||||
return {
|
||||
id: "dalle-" + req.id,
|
||||
object: "chat.completion",
|
||||
created: Date.now(),
|
||||
model: req.body.model,
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: req.outputTokens,
|
||||
total_tokens: req.outputTokens,
|
||||
},
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content },
|
||||
finish_reason: "stop",
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const openaiImagesProxy = createQueueMiddleware({
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "https://api.openai.com",
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
logger,
|
||||
pathRewrite: {
|
||||
"^/v1/chat/completions": "/v1/images/generations",
|
||||
},
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }),
|
||||
proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
const openaiImagesRouter = Router();
|
||||
openaiImagesRouter.get("/v1/models", handleModelRequest);
|
||||
openaiImagesRouter.post(
|
||||
"/v1/images/generations",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "openai-image",
|
||||
outApi: "openai-image",
|
||||
service: "openai",
|
||||
}),
|
||||
openaiImagesProxy
|
||||
);
|
||||
openaiImagesRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware({
|
||||
inApi: "openai",
|
||||
outApi: "openai-image",
|
||||
service: "openai",
|
||||
}),
|
||||
openaiImagesProxy
|
||||
);
|
||||
export const openaiImage = openaiImagesRouter;
|
||||
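For reference, a standalone sketch (not part of the diff) of the chat-shaped payload that `transformResponseForChat` produces for a single generated image. The ID, model, and image values are illustrative, and `outputTokens` stands in for `req.outputTokens`:

```ts
const outputTokens = 0; // illustrative; the proxy fills this from req.outputTokens
const chatShapedResult = {
  id: "dalle-req-abc123",
  object: "chat.completion",
  created: Date.now(),
  model: "dall-e-3",
  usage: {
    prompt_tokens: 0,
    completion_tokens: outputTokens,
    total_tokens: outputTokens,
  },
  choices: [
    {
      // The generated image arrives as a Markdown image in the message body.
      message: {
        role: "assistant",
        content: "![a red fox](https://proxy.example.com/user_content/fox.png)",
      },
      finish_reason: "stop",
      index: 0,
    },
  ],
};
console.log(JSON.stringify(chatShapedResult, null, 2));
```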
+55
-38
@@ -3,53 +3,60 @@ import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
  getOpenAIModelFamily,
  ModelFamily,
  OpenAIModelFamily,
  getOpenAIModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  RequestPreprocessor,
  addKey,
  addKeyForEmbeddingsRequest,
  applyQuotaLimits,
  blockZoomerOrigins,
  createEmbeddingsPreprocessorMiddleware,
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeBody,
  forceModel,
  RequestPreprocessor,
  languageFilter,
  limitCompletions,
  stripHeaders,
  createOnProxyReqHandler,
} from "./middleware/request";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";

// https://platform.openai.com/docs/models/overview
export const KNOWN_OPENAI_MODELS = [
  "gpt-4-1106-preview",
  "gpt-4-vision-preview",
  "gpt-4",
  "gpt-4-0613",
  "gpt-4-0314", // EOL 2024-06-13
  "gpt-4-32k",
  "gpt-4-32k-0613",
  "gpt-4-32k-0314", // EOL 2024-06-13
  "gpt-3.5-turbo",
  "gpt-3.5-turbo-0301", // EOL 2024-06-13
  "gpt-3.5-turbo-0613",
  "gpt-3.5-turbo-16k",
  "gpt-3.5-turbo-16k-0613",
  "gpt-3.5-turbo-instruct",
  "gpt-3.5-turbo-instruct-0914",
  "text-embedding-ada-002",
];

let modelsCache: any = null;
let modelsCacheTime = 0;

export function generateModelList(models = KNOWN_OPENAI_MODELS) {
function getModelsResponse() {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  // https://platform.openai.com/docs/models/overview
  const knownModels = [
    "gpt-4",
    "gpt-4-0613",
    "gpt-4-0314", // EOL 2024-06-13
    "gpt-4-32k",
    "gpt-4-32k-0613",
    "gpt-4-32k-0314", // EOL 2024-06-13
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0301", // EOL 2024-06-13
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-instruct",
    "gpt-3.5-turbo-instruct-0914",
    "text-embedding-ada-002",
  ];

  let available = new Set<OpenAIModelFamily>();
  for (const key of keyPool.list()) {
    if (key.isDisabled || key.service !== "openai") continue;
@@ -60,7 +67,7 @@ export function generateModelList(models = KNOWN_OPENAI_MODELS) {
  const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
  available = new Set([...available].filter((x) => allowed.has(x)));

  return models
  const models = knownModels
    .map((id) => ({
      id,
      object: "model",
@@ -80,14 +87,15 @@ export function generateModelList(models = KNOWN_OPENAI_MODELS) {
      parent: null,
    }))
    .filter((model) => available.has(getOpenAIModelFamily(model.id)));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();

  return modelsCache;
}

const handleModelRequest: RequestHandler = (_req, res) => {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    res.status(200).json(modelsCache);
    return;
  }
  const result = generateModelList();
  modelsCache = { object: "list", data: result };
  modelsCacheTime = new Date().getTime();
  res.status(200).json(modelsCache);
  res.status(200).json(getModelsResponse());
};

/** Handles some turbo-instruct special cases. */
@@ -129,8 +137,9 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async (
    body = transformTurboInstructResponse(body);
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  // TODO: Remove once tokenization is stable
  if (req.debug) {
    body.proxy_tokenizer_debug_info = req.debug;
  }

  res.status(200).json(body);
@@ -154,21 +163,29 @@ function transformTurboInstructResponse(
  return transformed;
}

const openaiProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
const openaiProxy = createQueueMiddleware(
  createProxyMiddleware({
    target: "https://api.openai.com",
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [addKey, finalizeBody],
        pipeline: [
          applyQuotaLimits,
          addKey,
          languageFilter,
          limitCompletions,
          blockZoomerOrigins,
          stripHeaders,
          finalizeBody,
        ],
      }),
      proxyRes: createOnProxyResHandler([openaiResponseHandler]),
      error: handleProxyError,
    },
  }),
});
  })
);

const openaiEmbeddingsProxy = createProxyMiddleware({
  target: "https://api.openai.com",
@@ -177,7 +194,7 @@ const openaiEmbeddingsProxy = createProxyMiddleware({
  logger,
  on: {
    proxyReq: createOnProxyReqHandler({
      pipeline: [addKeyForEmbeddingsRequest, finalizeBody],
      pipeline: [addKeyForEmbeddingsRequest, stripHeaders, finalizeBody],
    }),
    error: handleProxyError,
  },
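Both versions of the model-list endpoint above use the same 60-second memoization pattern. A generic, standalone sketch of that pattern (`makeList` is a hypothetical stand-in producer, not a function from the repo):

```ts
// 60-second cache, as used by generateModelList / handleModelRequest above.
function cached<T>(ttlMs: number, produce: () => T): () => T {
  let value: T | null = null;
  let freshAt = 0;
  return () => {
    const now = Date.now();
    if (value !== null && now - freshAt < ttlMs) return value; // serve cached copy
    value = produce();
    freshAt = now;
    return value;
  };
}

const makeList = () => ({ object: "list", data: [{ id: "gpt-4", object: "model" }] });
const getModels = cached(60 * 1000, makeList);
console.log(getModels() === getModels()); // true within the TTL window
```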
+20
-7
@@ -9,10 +9,14 @@ import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  addKey,
  applyQuotaLimits,
  blockZoomerOrigins,
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeBody,
  forceModel,
  languageFilter,
  stripHeaders,
} from "./middleware/request";
import {
  createOnProxyResHandler,
@@ -72,8 +76,9 @@ const palmResponseHandler: ProxyResHandlerWithBody = async (
    body = transformPalmResponse(body, req);
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  // TODO: Remove once tokenization is stable
  if (req.debug) {
    body.proxy_tokenizer_debug_info = req.debug;
  }

  // TODO: PaLM has no streaming capability which will pose a problem here if
@@ -138,21 +143,29 @@ function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
  );
}

const googlePalmProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
const googlePalmProxy = createQueueMiddleware(
  createProxyMiddleware({
    target: "https://generativelanguage.googleapis.com",
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
        beforeRewrite: [reassignPathForPalmModel],
        pipeline: [
          applyQuotaLimits,
          addKey,
          languageFilter,
          blockZoomerOrigins,
          stripHeaders,
          finalizeBody,
        ],
      }),
      proxyRes: createOnProxyResHandler([palmResponseHandler]),
      error: handleProxyError,
    },
  }),
});
  })
);

const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
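An illustrative sketch (not the proxy's actual middleware types) of how an ordered `proxyReq` pipeline like the one above composes: each step mutates the outbound request in sequence, which is why ordering matters (e.g. the key is attached before the body is finalized). All names here are stand-ins:

```ts
type Mutator = (req: { headers: Record<string, string>; body: any }) => void;

const runPipeline =
  (steps: Mutator[]) =>
  (req: { headers: Record<string, string>; body: any }) => {
    for (const step of steps) step(req); // steps run strictly in order
  };

const addKey: Mutator = (req) => (req.headers["Authorization"] = "Bearer <key>");
const stripHeaders: Mutator = (req) => delete req.headers["x-forwarded-for"];
const finalizeBody: Mutator = (req) => (req.body = JSON.stringify(req.body));

runPipeline([addKey, stripHeaders, finalizeBody])({
  headers: { "x-forwarded-for": "1.2.3.4" },
  body: { model: "text-bison-001" },
});
```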
+124
-248
@@ -4,6 +4,10 @@
 * a given key has generated, so our queue will simply retry requests that fail
 * with a non-billing related 429 over and over again until they succeed.
 *
 * Dequeueing can operate in one of two modes:
 * - 'fair': requests are dequeued in the order they were enqueued.
 * - 'random': requests are dequeued randomly, not really a queue at all.
 *
 * When a request to a proxied endpoint is received, we create a closure around
 * the call to http-proxy-middleware and attach it to the request. This allows
 * us to pause the request until we have a key available. Further, if the
@@ -11,15 +15,18 @@
 * back in the queue and it will be retried later using the same closure.
 */

import crypto from "crypto";
import type { Handler, Request } from "express";
import { keyPool } from "../shared/key-management";
import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
import {
  getClaudeModelFamily,
  getGooglePalmModelFamily,
  getOpenAIModelFamily,
  ModelFamily,
} from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";

const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@@ -28,50 +35,44 @@ const log = logger.child({ module: "request-queue" });
const AGNAI_CONCURRENCY_LIMIT = 5;
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1;
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
const MAX_HEARTBEAT_SIZE =
  1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
const HEARTBEAT_INTERVAL =
  1000 * parseInt(process.env.HEARTBEAT_INTERVAL_SEC ?? "5");
const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "50");
const PAYLOAD_SCALE_FACTOR = parseFloat(
  process.env.PAYLOAD_SCALE_FACTOR ?? "6"
);

/**
 * Returns an identifier for a request. This is used to determine if a
 * Returns a unique identifier for a request. This is used to determine if a
 * request is already in the queue.
 *
 * This can be (in order of preference):
 * - user token assigned by the proxy operator
 * - x-risu-tk header, if the request is from RisuAI.xyz
 * - 'shared-ip' if the request is from a shared IP address like Agnai.chat
 * - IP address
 */
function getIdentifier(req: Request) {
  if (req.user) return req.user.token;
  if (req.risuToken) return req.risuToken;
  if (isFromSharedIp(req)) return "shared-ip";
  if (req.user) {
    return req.user.token;
  }
  if (req.risuToken) {
    return req.risuToken;
  }
  return req.ip;
}

const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
  getIdentifier(queued) === getIdentifier(incoming);

const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
const sameUserPredicate = (incoming: Request) => (queued: Request) => {
  const queuedId = getIdentifier(queued);
  const incomingId = getIdentifier(incoming);
  return queuedId === incomingId;
};

export function enqueue(req: Request) {
  const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
  const enqueuedRequestCount = queue.filter(sameUserPredicate(req)).length;
  let isGuest = req.user?.token === undefined;

  // Requests from shared IP addresses such as Agnai.chat are exempt from IP-
  // based rate limiting but can only occupy a certain number of slots in the
  // queue. Authenticated users always get a single spot in the queue.
  const isSharedIp = isFromSharedIp(req);
  // All Agnai.chat requests come from the same IP, so we allow them to have
  // more spots in the queue. Can't make it unlimited because people will
  // intentionally abuse it.
  // Authenticated users always get a single spot in the queue.
  const isAgnai = AGNAI_DOT_CHAT_IP.includes(req.ip);
  const maxConcurrentQueuedRequests =
    isGuest && isSharedIp ? AGNAI_CONCURRENCY_LIMIT : USER_CONCURRENCY_LIMIT;
    isGuest && isAgnai ? AGNAI_CONCURRENCY_LIMIT : USER_CONCURRENCY_LIMIT;
  if (enqueuedRequestCount >= maxConcurrentQueuedRequests) {
    if (isSharedIp) {
    if (isAgnai) {
      // Re-enqueued requests are not counted towards the limit since they
      // already made it through the queue once.
      if (req.retryCount === 0) {
@@ -82,6 +83,9 @@ export function enqueue(req: Request) {
    }
  }

  queue.push(req);
  req.queueOutTime = 0;

  // shitty hack to remove hpm's event listeners on retried requests
  removeProxyMiddlewareEventListeners(req);

@@ -94,24 +98,31 @@ export function enqueue(req: Request) {
    if (!res.headersSent) {
      initStreaming(req);
    }
    registerHeartbeat(req);
  } else if (getProxyLoad() > LOAD_THRESHOLD) {
    throw new Error(
      "Due to heavy traffic on this proxy, you must enable streaming for your request."
    );
    req.heartbeatInterval = setInterval(() => {
      if (process.env.NODE_ENV === "production") {
        if (!req.query.badSseParser) req.res!.write(": queue heartbeat\n\n");
      } else {
        req.log.info(`Sending heartbeat to request in queue.`);
        const partition = getPartitionForRequest(req);
        const avgWait = Math.round(getEstimatedWaitTime(partition) / 1000);
        const currentDuration = Math.round((Date.now() - req.startTime) / 1000);
        const debugMsg = `queue length: ${queue.length}; elapsed time: ${currentDuration}s; avg wait: ${avgWait}s`;
        req.res!.write(buildFakeSse("heartbeat", debugMsg, req));
      }
    }, 10000);
  }

  queue.push(req);
  req.queueOutTime = 0;

  // Register a handler to remove the request from the queue if the connection
  // is aborted or closed before it is dequeued.
  const removeFromQueue = () => {
    req.log.info(`Removing aborted request from queue.`);
    const index = queue.indexOf(req);
    if (index !== -1) {
      queue.splice(index, 1);
    }
    if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
    if (req.monitorInterval) clearInterval(req.monitorInterval);
    if (req.heartbeatInterval) {
      clearInterval(req.heartbeatInterval);
    }
  };
  req.onAborted = removeFromQueue;
  req.res!.once("close", removeFromQueue);

@@ -123,20 +134,33 @@ export function enqueue(req: Request) {
  }
}

function getPartitionForRequest(req: Request): ModelFamily {
  // There is a single request queue, but it is partitioned by model family.
  // Model families are typically separated on cost/rate limit boundaries so
  // they should be treated as separate queues.
  const model = req.body.model ?? "gpt-3.5-turbo";

  // Weird special case for AWS because they serve multiple models from
  // different vendors, even if currently only one is supported.
  if (req.service === "aws") {
    return "aws-claude";
  }

  switch (req.outboundApi) {
    case "anthropic":
      return getClaudeModelFamily(model);
    case "openai":
    case "openai-text":
      return getOpenAIModelFamily(model);
    case "google-palm":
      return getGooglePalmModelFamily(model);
    default:
      assertNever(req.outboundApi);
  }
}

function getQueueForPartition(partition: ModelFamily): Request[] {
  return queue
    .filter((req) => getModelFamilyForRequest(req) === partition)
    .sort((a, b) => {
      // Certain requests are exempted from IP-based rate limiting because they
      // come from a shared IP address. To prevent these requests from starving
      // out other requests during periods of high traffic, we sort them to the
      // end of the queue.
      const aIsExempted = isFromSharedIp(a);
      const bIsExempted = isFromSharedIp(b);
      if (aIsExempted && !bIsExempted) return 1;
      if (!aIsExempted && bIsExempted) return -1;
      return 0;
    });
  return queue.filter((req) => getPartitionForRequest(req) === partition);
}

export function dequeue(partition: ModelFamily): Request | undefined {
@@ -156,8 +180,9 @@ export function dequeue(partition: ModelFamily): Request | undefined {
    req.onAborted = undefined;
  }

  if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
  if (req.monitorInterval) clearInterval(req.monitorInterval);
  if (req.heartbeatInterval) {
    clearInterval(req.heartbeatInterval);
  }

  // Track the time leaving the queue now, but don't add it to the wait times
  // yet because we don't know if the request will succeed or fail. We track
@@ -176,23 +201,40 @@ export function dequeue(partition: ModelFamily): Request | undefined {
function processQueue() {
  // This isn't completely correct, because a key can service multiple models.
  // Currently if a key is locked out on one model it will also stop servicing
  // the others, because we only track rate limits for the key as a whole.
  // the others, because we only track one rate limit per key.

  // TODO: `getLockoutPeriod` uses model names instead of model families
  // TODO: genericize this, it's really ugly
  const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
  const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
  const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
  const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
  const palmLockout = keyPool.getLockoutPeriod("text-bison-001");
  const awsClaudeLockout = keyPool.getLockoutPeriod("anthropic.claude-v2");

  const reqs: (Request | undefined)[] = [];
  MODEL_FAMILIES.forEach((modelFamily) => {
    const lockout = keyPool.getLockoutPeriod(modelFamily);
    if (lockout === 0) {
      reqs.push(dequeue(modelFamily));
    }
  });
  if (gpt432kLockout === 0) {
    reqs.push(dequeue("gpt4-32k"));
  }
  if (gpt4Lockout === 0) {
    reqs.push(dequeue("gpt4"));
  }
  if (turboLockout === 0) {
    reqs.push(dequeue("turbo"));
  }
  if (claudeLockout === 0) {
    reqs.push(dequeue("claude"));
  }
  if (palmLockout === 0) {
    reqs.push(dequeue("bison"));
  }
  if (awsClaudeLockout === 0) {
    reqs.push(dequeue("aws-claude"));
  }

  reqs.filter(Boolean).forEach((req) => {
    if (req?.proceed) {
      const modelFamily = getModelFamilyForRequest(req!);
      req.log.info({
        retries: req.retryCount,
        partition: modelFamily,
      }, `Dequeuing request.`);
      req.log.info({ retries: req.retryCount }, `Dequeuing request.`);
      req.proceed();
    }
  });
@@ -225,93 +267,38 @@ function cleanQueue() {
}

export function start() {
  MODEL_FAMILIES.forEach((modelFamily) => {
    historicalEmas.set(modelFamily, 0);
    currentEmas.set(modelFamily, 0);
    estimates.set(modelFamily, 0);
  });
  processQueue();
  cleanQueue();
  log.info(`Started request queue.`);
}

let waitTimes: {
  partition: ModelFamily;
  start: number;
  end: number;
  isDeprioritized: boolean;
}[] = [];
let waitTimes: { partition: ModelFamily; start: number; end: number }[] = [];

/** Adds a successful request to the list of wait times. */
export function trackWaitTime(req: Request) {
  waitTimes.push({
    partition: getModelFamilyForRequest(req),
    partition: getPartitionForRequest(req),
    start: req.startTime!,
    end: req.queueOutTime ?? Date.now(),
    isDeprioritized: isFromSharedIp(req),
  });
}

const WAIT_TIME_INTERVAL = 3000;
const ALPHA_HISTORICAL = 0.2;
const ALPHA_CURRENT = 0.3;
const historicalEmas: Map<ModelFamily, number> = new Map();
const currentEmas: Map<ModelFamily, number> = new Map();
const estimates: Map<ModelFamily, number> = new Map();

/** Returns average wait time in milliseconds. */
export function getEstimatedWaitTime(partition: ModelFamily) {
  return estimates.get(partition) ?? 0;
}

/**
 * Returns estimated wait time for the given queue partition in milliseconds.
 * Requests which are deprioritized are not included in the calculation as they
 * would skew the results due to their longer wait times.
 */
function calculateWaitTime(partition: ModelFamily) {
  const now = Date.now();
  const recentWaits = waitTimes
    .filter((wait) => {
      const isSamePartition = wait.partition === partition;
      const isRecent = now - wait.end < 300 * 1000;
      const isNormalPriority = !wait.isDeprioritized;
      return isSamePartition && isRecent && isNormalPriority;
    })
    .map((wait) => wait.end - wait.start);
  const recentAverage = recentWaits.length
    ? recentWaits.reduce((sum, wait) => sum + wait, 0) / recentWaits.length
    : 0;

  const historicalEma = historicalEmas.get(partition) ?? 0;
  historicalEmas.set(
    partition,
    ALPHA_HISTORICAL * recentAverage + (1 - ALPHA_HISTORICAL) * historicalEma
  const recentWaits = waitTimes.filter(
    (wt) => wt.partition === partition && now - wt.end < 300 * 1000
  );
  if (recentWaits.length === 0) {
    return 0;
  }

  const currentWaits = queue
    .filter((req) => {
      const isSamePartition = getModelFamilyForRequest(req) === partition;
      const isNormalPriority = !isFromSharedIp(req);
      return isSamePartition && isNormalPriority;
    })
    .map((req) => now - req.startTime!);
  const longestCurrentWait = Math.max(...currentWaits, 0);

  const currentEma = currentEmas.get(partition) ?? 0;
  currentEmas.set(
    partition,
    ALPHA_CURRENT * longestCurrentWait + (1 - ALPHA_CURRENT) * currentEma
  return (
    recentWaits.reduce((sum, wt) => sum + wt.end - wt.start, 0) /
    recentWaits.length
  );

  return (historicalEma + currentEma) / 2;
}

setInterval(() => {
  MODEL_FAMILIES.forEach((modelFamily) => {
    estimates.set(modelFamily, calculateWaitTime(modelFamily));
  });
}, WAIT_TIME_INTERVAL);

export function getQueueLength(partition: ModelFamily | "all" = "all") {
  if (partition === "all") {
    return queue.length;
@@ -320,27 +307,9 @@ export function getQueueLength(partition: ModelFamily | "all" = "all") {
  return modelQueue.length;
}

export function createQueueMiddleware({
  beforeProxy,
  proxyMiddleware,
}: {
  beforeProxy?: RequestPreprocessor;
  proxyMiddleware: Handler;
}): Handler {
export function createQueueMiddleware(proxyMiddleware: Handler): Handler {
  return (req, res, next) => {
    req.proceed = async () => {
      if (beforeProxy) {
        try {
          // Hack to let us run asynchronous middleware before the
          // http-proxy-middleware handler. This is used to sign AWS requests
          // before they are proxied, as the signing is asynchronous.
          // Unlike RequestPreprocessors, this runs every time the request is
          // dequeued, not just the first time.
          await beforeProxy(req);
        } catch (err) {
          return handleProxyError(err, req, res);
        }
      }
    req.proceed = () => {
      proxyMiddleware(req, res, next);
    };

@@ -360,12 +329,11 @@
function killQueuedRequest(req: Request) {
  if (!req.res || req.res.writableEnded) {
    req.log.warn(`Attempted to terminate request that has already ended.`);
    queue.splice(queue.indexOf(req), 1);
    return;
  }
  const res = req.res;
  try {
    const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`;
    const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes. The queue is currently ${queue.length} requests long.`;
    if (res.headersSent) {
      const fakeErrorEvent = buildFakeSse("proxy queue error", message, req);
      res.write(fakeErrorEvent);
@@ -386,12 +354,10 @@ function initStreaming(req: Request) {
    // Some clients have a broken SSE parser that doesn't handle comments
    // correctly. These clients can pass ?badSseParser=true to
    // disable comments in the SSE stream.
    res.write(getHeartbeatPayload());
    return;
  }

  res.write(`: joining queue at position ${queue.length}\n\n`);
  res.write(getHeartbeatPayload());
}

/**
@@ -447,93 +413,3 @@ function removeProxyMiddlewareEventListeners(req: Request) {
    req.removeListener("error", reqOnError as any);
  }
}

export function registerHeartbeat(req: Request) {
  const res = req.res!;

  const currentSize = getHeartbeatSize();
  req.log.debug({
    currentSize,
    HEARTBEAT_INTERVAL,
    PAYLOAD_SCALE_FACTOR,
    MAX_HEARTBEAT_SIZE,
  }, "Joining queue with heartbeat.");

  let isBufferFull = false;
  let bufferFullCount = 0;
  req.heartbeatInterval = setInterval(() => {
    if (isBufferFull) {
      bufferFullCount++;
      if (bufferFullCount >= 3) {
        req.log.error("Heartbeat skipped too many times; killing connection.");
        res.destroy();
      } else {
        req.log.warn({ bufferFullCount }, "Heartbeat skipped; buffer is full.");
      }
      return;
    }

    const data = getHeartbeatPayload();
    if (!res.write(data)) {
      isBufferFull = true;
      res.once("drain", () => (isBufferFull = false));
    }
  }, HEARTBEAT_INTERVAL);
  monitorHeartbeat(req);
}

function monitorHeartbeat(req: Request) {
  const res = req.res!;

  let lastBytesSent = 0;
  req.monitorInterval = setInterval(() => {
    const bytesSent = res.socket?.bytesWritten ?? 0;
    const bytesSinceLast = bytesSent - lastBytesSent;
    req.log.debug(
      {
        previousBytesSent: lastBytesSent,
        currentBytesSent: bytesSent,
      },
      "Heartbeat monitor check."
    );
    lastBytesSent = bytesSent;

    const minBytes = Math.floor(getHeartbeatSize() / 2);
    if (bytesSinceLast < minBytes) {
      req.log.warn(
        { minBytes, bytesSinceLast },
        "Queued request is not processing heartbeat data quickly enough, or the server is overloaded; killing connection."
      );
      res.destroy();
    }
  }, HEARTBEAT_INTERVAL * 2);
}

/** Sends larger heartbeats when the queue is overloaded */
function getHeartbeatSize() {
  const load = getProxyLoad();

  if (load <= LOAD_THRESHOLD) {
    return MIN_HEARTBEAT_SIZE;
  } else {
    const excessLoad = load - LOAD_THRESHOLD;
    const size =
      MIN_HEARTBEAT_SIZE + Math.pow(excessLoad * PAYLOAD_SCALE_FACTOR, 2);
    if (size > MAX_HEARTBEAT_SIZE) return MAX_HEARTBEAT_SIZE;
    return size;
  }
}

function getHeartbeatPayload() {
  const size = getHeartbeatSize();
  const data =
    process.env.NODE_ENV === "production"
      ? crypto.randomBytes(size).toString("base64")
      : `payload size: ${size}`;

  return `: queue heartbeat ${data}\n\n`;
}

function getProxyLoad() {
  return Math.max(getUniqueIps(), queue.length);
}
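A worked example of the `getHeartbeatSize` scaling removed above, using the default env values (`MIN_HEARTBEAT_SIZE_B=512`, `MAX_HEARTBEAT_SIZE_KB=1024`, `LOAD_THRESHOLD=50`, `PAYLOAD_SCALE_FACTOR=6`): payloads stay at the minimum until load crosses the threshold, then grow quadratically with the excess, capped at the maximum.

```ts
const MIN = 512, MAX = 1024 * 1024, THRESHOLD = 50, FACTOR = 6;
const heartbeatSize = (load: number) =>
  load <= THRESHOLD ? MIN : Math.min(MAX, MIN + Math.pow((load - THRESHOLD) * FACTOR, 2));

console.log(heartbeatSize(50));  // 512 — at or below the threshold
console.log(heartbeatSize(60));  // 512 + (10 * 6)^2 = 4112
console.log(heartbeatSize(250)); // (200 * 6)^2 exceeds the cap, so 1048576
```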
+30
-67
@@ -1,34 +1,28 @@
import { Request, Response, NextFunction } from "express";
import { config } from "../config";

export const SHARED_IP_ADDRESSES = new Set([
  // Agnai.chat
export const AGNAI_DOT_CHAT_IP = [
  "157.230.249.32", // old
  "157.245.148.56",
  "174.138.29.50",
  "209.97.162.44",
]);
];

const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit);
const RATE_LIMIT = Math.max(1, config.modelRateLimit);
const ONE_MINUTE_MS = 60 * 1000;

type Timestamp = number;
/** Tracks time of last attempts from each IP address or token. */
const lastAttempts = new Map<string, Timestamp[]>();
/** Tracks time of exempted attempts from shared IPs like Agnai.chat. */
const exemptedRequests: Timestamp[] = [];
const lastAttempts = new Map<string, number[]>();

const isRecentAttempt = (now: Timestamp) => (attempt: Timestamp) =>
const expireOldAttempts = (now: number) => (attempt: number) =>
  attempt > now - ONE_MINUTE_MS;

const getTryAgainInMs = (ip: string, type: "text" | "image") => {
const getTryAgainInMs = (ip: string) => {
  const now = Date.now();
  const attempts = lastAttempts.get(ip) || [];
  const validAttempts = attempts.filter(isRecentAttempt(now));
  const validAttempts = attempts.filter(expireOldAttempts(now));

  const limit =
    type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;

  if (validAttempts.length >= limit) {
  if (validAttempts.length >= RATE_LIMIT) {
    return validAttempts[0] - now + ONE_MINUTE_MS;
  } else {
    lastAttempts.set(ip, [...validAttempts, now]);
@@ -36,25 +30,21 @@ const getTryAgainInMs = (ip: string, type: "text" | "image") => {
  }
};

const getStatus = (ip: string, type: "text" | "image") => {
const getStatus = (ip: string) => {
  const now = Date.now();
  const attempts = lastAttempts.get(ip) || [];
  const validAttempts = attempts.filter(isRecentAttempt(now));

  const limit =
    type === "text" ? config.textModelRateLimit : config.imageModelRateLimit;

  const validAttempts = attempts.filter(expireOldAttempts(now));
  return {
    remaining: Math.max(0, limit - validAttempts.length),
    remaining: Math.max(0, RATE_LIMIT - validAttempts.length),
    reset: validAttempts.length > 0 ? validAttempts[0] + ONE_MINUTE_MS : now,
  };
};

/** Prunes attempts and IPs that are no longer relevant after one minute. */
/** Prunes attempts and IPs that are no longer relevant after one minutes. */
const clearOldAttempts = () => {
  const now = Date.now();
  for (const [ip, attempts] of lastAttempts.entries()) {
    const validAttempts = attempts.filter(isRecentAttempt(now));
    const validAttempts = attempts.filter(expireOldAttempts(now));
    if (validAttempts.length === 0) {
      lastAttempts.delete(ip);
    } else {
@@ -64,25 +54,8 @@ const clearOldAttempts = () => {
};
setInterval(clearOldAttempts, 10 * 1000);

/** Prunes exempted requests which are older than one minute. */
const clearOldExemptions = () => {
  const now = Date.now();
  const validExemptions = exemptedRequests.filter(isRecentAttempt(now));
  exemptedRequests.splice(0, exemptedRequests.length, ...validExemptions);
};
setInterval(clearOldExemptions, 10 * 1000);

export const getUniqueIps = () => lastAttempts.size;

/**
 * Can be used to manually remove the most recent attempt from an IP address,
 * ie. in case a prompt triggered OpenAI's content filter and therefore did not
 * result in a generation.
 */
export const refundLastAttempt = (req: Request) => {
  const key = req.user?.token || req.risuToken || req.ip;
  const attempts = lastAttempts.get(key) || [];
  attempts.pop();
export const getUniqueIps = () => {
  return lastAttempts.size;
};

export const ipLimiter = async (
@@ -90,46 +63,36 @@ export const ipLimiter = async (
  res: Response,
  next: NextFunction
) => {
  const imageLimit = config.imageModelRateLimit;
  const textLimit = config.textModelRateLimit;

  if (!textLimit && !imageLimit) return next();
  if (!RATE_LIMIT_ENABLED) return next();
  if (req.user?.type === "special") return next();

  // Exempts Agnai.chat from IP-based rate limiting because its IPs are shared
  // by many users. Instead, the request queue will limit the number of such
  // requests that may wait in the queue at a time, and sorts them to the end to
  // let individual users go first.
  if (SHARED_IP_ADDRESSES.has(req.ip)) {
    exemptedRequests.push(Date.now());
    req.log.info(
      { ip: req.ip, recentExemptions: exemptedRequests.length },
      "Exempting Agnai request from rate limiting."
    );
    return next();
  // Exempt Agnai.chat from rate limiting since it's shared between a lot of
  // users. Dunno how to prevent this from being abused without some sort of
  // identifier sent from Agnaistic to identify specific users.
  if (AGNAI_DOT_CHAT_IP.includes(req.ip)) {
    req.log.info("Exempting Agnai request from rate limiting.");
    next();
    return;
  }

  const type = (req.baseUrl + req.path).includes("openai-image")
    ? "image"
    : "text";
  const limit = type === "image" ? imageLimit : textLimit;

  // If user is authenticated, key rate limiting by their token. Otherwise, key
  // rate limiting by their IP address. Mitigates key sharing.
  const rateLimitKey = req.user?.token || req.risuToken || req.ip;

  const { remaining, reset } = getStatus(rateLimitKey, type);
  res.set("X-RateLimit-Limit", limit.toString());
  const { remaining, reset } = getStatus(rateLimitKey);
  res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
  res.set("X-RateLimit-Remaining", remaining.toString());
  res.set("X-RateLimit-Reset", reset.toString());

  const tryAgainInMs = getTryAgainInMs(rateLimitKey, type);
  const tryAgainInMs = getTryAgainInMs(rateLimitKey);
  if (tryAgainInMs > 0) {
    res.set("Retry-After", tryAgainInMs.toString());
    res.status(429).json({
      error: {
        type: "proxy_rate_limited",
        message: `This model type is rate limited to ${limit} prompts per minute. Please try again in ${Math.ceil(
        message: `This proxy is rate limited to ${
          config.modelRateLimit
        } prompts per minute. Please try again in ${Math.ceil(
          tryAgainInMs / 1000
        )} seconds.`,
      },
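A standalone sketch of the one-minute sliding window both versions of this limiter implement: timestamps older than 60 seconds are dropped, a request is only admitted while fewer than `RATE_LIMIT` stamps remain, and the oldest surviving stamp determines the retry delay. Values here are illustrative:

```ts
const ONE_MINUTE_MS = 60 * 1000;
const RATE_LIMIT = 4;

function tryAgainInMs(attempts: number[], now: number): number {
  const valid = attempts.filter((t) => t > now - ONE_MINUTE_MS);
  if (valid.length >= RATE_LIMIT) return valid[0] - now + ONE_MINUTE_MS;
  valid.push(now); // record this attempt
  return 0; // allowed
}

const now = Date.now();
// Four attempts in the last minute, the oldest 59s ago: wait ~1000ms.
console.log(tryAgainInMs([now - 59_000, now - 30_000, now - 10_000, now - 1_000], now));
```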
+2
-6
@@ -2,11 +2,9 @@ import express, { Request, Response, NextFunction } from "express";
import { gatekeeper } from "./gatekeeper";
import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { aws } from "./aws";
import { azure } from "./azure";

const proxyRouter = express.Router();
proxyRouter.use((req, _res, next) => {
@@ -18,8 +16,8 @@ proxyRouter.use((req, _res, next) => {
  next();
});
proxyRouter.use(
  express.json({ limit: "10mb" }),
  express.urlencoded({ extended: true, limit: "10mb" })
  express.json({ limit: "1536kb" }),
  express.urlencoded({ extended: true, limit: "1536kb" })
);
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);
@@ -29,11 +27,9 @@ proxyRouter.use((req, _res, next) => {
  next();
});
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {
  const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
+16
-31
@@ -1,23 +1,20 @@
import { assertConfigIsValid, config, USER_ASSETS_DIR } from "./config";
import { assertConfigIsValid, config } from "./config";
import "source-map-support/register";
import checkDiskSpace from "check-disk-space";
import express from "express";
import cors from "cors";
import path from "path";
import pinoHttp from "pino-http";
import os from "os";
import childProcess from "child_process";
import { logger } from "./logger";
import { setupAssetsDir } from "./shared/file-storage/setup-assets-dir";
import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
import { handleInfoPage } from "./info-page";
import { logQueue } from "./shared/prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./shared/users/user-store";
import { init as initTokenizers } from "./shared/tokenization";
import { logger } from "./logger";
import { adminRouter } from "./admin/routes";
import { checkOrigin } from "./proxy/check-origin";
import { start as startRequestQueue } from "./proxy/queue";
import { proxyRouter } from "./proxy/routes";
import { init as initKeyPool } from "./shared/key-management/key-pool";
import { logQueue } from "./shared/prompt-logging";
import { init as initTokenizers } from "./shared/tokenization";
import { init as initUserStore } from "./shared/users/user-store";
import { userRouter } from "./user/routes";

const PORT = config.port;
@@ -28,7 +25,9 @@ app.use(
  pinoHttp({
    quietReqLogger: true,
    logger,
    autoLogging: { ignore: ({ url }) => ["/health"].includes(url as string) },
    autoLogging: {
      ignore: ({ url }) => ["/health"].includes(url as string),
    },
    redact: {
      paths: [
        "req.headers.cookie",
@@ -41,11 +40,6 @@
      ],
      censor: "********",
    },
    customProps: (req) => {
      const user = (req as express.Request).user;
      if (user) return { userToken: `...${user.token.slice(-5)}` };
      return {};
    },
  })
);

@@ -61,8 +55,6 @@ app.set("views", [
  path.join(__dirname, "shared/views"),
]);

app.use("/user_content", express.static(USER_ASSETS_DIR));

app.get("/health", (_req, res) => res.sendStatus(200));
app.use(cors());
app.use(checkOrigin);
@@ -100,21 +92,18 @@ async function start() {
  logger.info("Checking configs and external dependencies...");
  await assertConfigIsValid();

  keyPool.init();
  logger.info("Starting key pool...");
  await initKeyPool();

  await initTokenizers();

  if (config.allowedModelFamilies.includes("dall-e")) {
    await setupAssetsDir();
  }

  if (config.gatekeeper === "user_token") {
    await initUserStore();
  }

  if (config.promptLogging) {
    logger.info("Starting prompt logging...");
    await logQueue.start();
    logQueue.start();
  }

  logger.info("Starting request queue...");
@@ -125,12 +114,8 @@ async function start() {
    registerUncaughtExceptionHandler();
  });

  const diskSpace = await checkDiskSpace(
    __dirname.startsWith("/app") ? "/app" : os.homedir()
  );

  logger.info(
    { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV, diskSpace },
    { build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV },
    "Startup complete."
  );
}
@@ -1,22 +0,0 @@
const IMAGE_HISTORY_SIZE = 30;
const imageHistory = new Array<ImageHistory>(IMAGE_HISTORY_SIZE);
let index = 0;

type ImageHistory = { url: string; prompt: string };

export function addToImageHistory(image: ImageHistory) {
  imageHistory[index] = image;
  index = (index + 1) % IMAGE_HISTORY_SIZE;
}

export function getLastNImages(n: number) {
  const result: ImageHistory[] = [];
  let currentIndex = (index - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE;

  for (let i = 0; i < n; i++) {
    if (imageHistory[currentIndex]) result.unshift(imageHistory[currentIndex]);
    currentIndex = (currentIndex - 1 + IMAGE_HISTORY_SIZE) % IMAGE_HISTORY_SIZE;
  }

  return result;
}
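A behavior sketch of the ring buffer this deleted file implemented, scaled down to capacity 3 so the wrap-around is visible: writing a fourth entry overwrites the oldest, and the last-N read walks backward from the write cursor so results come out newest-last.

```ts
const CAP = 3;
const buf = new Array<string>(CAP);
let idx = 0;

const add = (v: string) => {
  buf[idx] = v;
  idx = (idx + 1) % CAP; // cursor wraps; oldest slot is overwritten next
};
const lastN = (n: number) => {
  const out: string[] = [];
  let cur = (idx - 1 + CAP) % CAP; // start at the most recent write
  for (let i = 0; i < n; i++) {
    if (buf[cur]) out.unshift(buf[cur]);
    cur = (cur - 1 + CAP) % CAP;
  }
  return out;
};

["a", "b", "c", "d"].forEach(add);
console.log(lastN(3)); // ["b", "c", "d"] — "a" was overwritten by "d"
```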
@@ -1,6 +0,0 @@
// We need to control the timing of when sharp is imported because it has a
// native dependency that causes conflicts with node-canvas if they are not
// imported in a specific order.
import sharp from "sharp";

export { sharp as libSharp };
@@ -1,73 +0,0 @@
import axios from "axios";
import { promises as fs } from "fs";
import path from "path";
import { v4 } from "uuid";
import { USER_ASSETS_DIR } from "../../config";
import { addToImageHistory } from "./image-history";
import { libSharp } from "./index";

export type OpenAIImageGenerationResult = {
  created: number;
  data: {
    revised_prompt?: string;
    url: string;
    b64_json: string;
  }[];
};

async function downloadImage(url: string) {
  const { data } = await axios.get(url, { responseType: "arraybuffer" });
  const buffer = Buffer.from(data, "binary");
  const newFilename = `${v4()}.png`;

  const filepath = path.join(USER_ASSETS_DIR, newFilename);
  await fs.writeFile(filepath, buffer);
  return filepath;
}

async function saveB64Image(b64: string) {
  const buffer = Buffer.from(b64, "base64");
  const newFilename = `${v4()}.png`;

  const filepath = path.join(USER_ASSETS_DIR, newFilename);
  await fs.writeFile(filepath, buffer);
  return filepath;
}

async function createThumbnail(filepath: string) {
  const thumbnailPath = filepath.replace(/(\.[\wd_-]+)$/i, "_t.jpg");

  await libSharp(filepath)
    .resize(150, 150, {
      fit: "inside",
      withoutEnlargement: true,
    })
    .toFormat("jpeg")
    .toFile(thumbnailPath);

  return thumbnailPath;
}

/**
 * Downloads generated images and mirrors them to the user_content directory.
 * Mutates the result object.
 */
export async function mirrorGeneratedImage(
  host: string,
  prompt: string,
  result: OpenAIImageGenerationResult
): Promise<OpenAIImageGenerationResult> {
  for (const item of result.data) {
    let mirror: string;
    if (item.b64_json) {
      mirror = await saveB64Image(item.b64_json);
    } else {
      mirror = await downloadImage(item.url);
    }
    item.url = `${host}/user_content/${path.basename(mirror)}`;
    await createThumbnail(mirror);
    addToImageHistory({ url: item.url, prompt });
  }
  return result;
}
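A self-contained sketch of the URL rewrite that the deleted `mirrorGeneratedImage` performed for each item: the short-lived upstream OpenAI URL is replaced with a link to the locally mirrored copy. The host and filenames below are illustrative:

```ts
import path from "path";

const host = "https://proxy.example.com";
const item = { url: "https://oaidalleapiprodscus.blob.core.windows.net/org-x/img.png" };
const mirroredFile = "/app/user_content/0b9fb1a2.png"; // where downloadImage() saved it

// Rewrite the response so clients fetch the durable mirrored copy instead.
item.url = `${host}/user_content/${path.basename(mirroredFile)}`;
console.log(item.url); // https://proxy.example.com/user_content/0b9fb1a2.png
```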
@@ -1,20 +0,0 @@
import { promises as fs } from "fs";
import { logger } from "../../logger";
import { USER_ASSETS_DIR } from "../../config";

const log = logger.child({ module: "file-storage" });

export async function setupAssetsDir() {
  try {
    log.info({ dir: USER_ASSETS_DIR }, "Setting up user assets directory");
    await fs.mkdir(USER_ASSETS_DIR, { recursive: true });
    const stats = await fs.stat(USER_ASSETS_DIR);
    const mode = stats.mode | 0o666;
    if (stats.mode !== mode) {
      await fs.chmod(USER_ASSETS_DIR, mode);
    }
  } catch (e) {
    log.error(e);
    throw new Error(
      "Could not create user assets directory for DALL-E image generation. You may need to update your Dockerfile to `chown` the working directory to user 1000. See the proxy docs for more information."
    );
  }
}
@@ -1,15 +1,11 @@
import { doubleCsrf } from "csrf-csrf";
import express from "express";
import { config, COOKIE_SECRET } from "../config";
import { COOKIE_SECRET } from "../config";

const { generateToken, doubleCsrfProtection } = doubleCsrf({
  getSecret: () => COOKIE_SECRET,
  cookieName: "csrf",
  cookieOptions: {
    sameSite: "strict",
    path: "/",
    secure: !config.useInsecureCookies,
  },
  cookieOptions: { sameSite: "strict", path: "/" },
  getTokenFromRequest: (req) => {
    const val = req.body["_csrf"] || req.query["_csrf"];
    delete req.body["_csrf"];
@@ -11,8 +11,7 @@ export const injectLocals: RequestHandler = (req, res, next) => {
    quota.turbo > 0 || quota.gpt4 > 0 || quota.claude > 0;
  res.locals.quota = quota;
  res.locals.nextQuotaRefresh = userStore.getNextQuotaRefresh();
  res.locals.persistenceEnabled = config.gatekeeperStore !== "memory";
  res.locals.usersEnabled = config.gatekeeper === "user_token";
  res.locals.persistenceEnabled = config.persistenceProvider !== "memory";
  res.locals.showTokenCosts = config.showTokenCosts;
  res.locals.maxIps = config.maxIpsPerUser;

@@ -26,23 +26,46 @@ type AnthropicAPIError = {
type UpdateFn = typeof AnthropicKeyProvider.prototype.update;

export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
  private readonly updateKey: UpdateFn;

  constructor(keys: AnthropicKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "anthropic",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      updateKey,
    });
    this.updateKey = updateKey;
  }

  protected async testKeyOrFail(key: AnthropicKey) {
    const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
    const updates = { isPozzed: pozzed };
    this.updateKey(key.hash, updates);
    this.log.info(
      { key: key.hash, models: key.modelFamilies },
      "Checked key."
    );
  protected async checkKey(key: AnthropicKey) {
    if (key.isDisabled) {
      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
      this.scheduleNextCheck();
      return;
    }

    this.log.debug({ key: key.hash }, "Checking key...");
    let isInitialCheck = !key.lastChecked;
    try {
      const [{ pozzed }] = await Promise.all([this.testLiveness(key)]);
      const updates = { isPozzed: pozzed };
      this.updateKey(key.hash, updates);
      this.log.info(
        { key: key.hash, models: key.modelFamilies },
        "Key check complete."
      );
    } catch (error) {
      // touch the key so we don't check it again for a while
      this.updateKey(key.hash, {});
      this.handleAxiosError(key, error as AxiosError);
    }

    this.lastCheck = Date.now();
    // Only enqueue the next check if this wasn't a startup check, since those
    // are batched together elsewhere.
    if (!isInitialCheck) {
      this.scheduleNextCheck();
    }
  }

  protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
@@ -61,7 +84,6 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
        { key: key.hash, error: error.message },
        "Key is rate limited. Rechecking in 10 seconds."
      );
      const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
      this.updateKey(key.hash, { lastChecked: next });
      break;
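The backdating trick in that last hunk deserves a worked example. Assuming the 3-minute `KEY_CHECK_PERIOD` these checkers use (the AWS checker below defines the same constant), a key is due once `lastChecked + KEY_CHECK_PERIOD <= now`, so backdating `lastChecked` leaves exactly 10 seconds on the clock:

```ts
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes, assumed to match the checker's constant

const now = Date.now();
const lastChecked = now - (KEY_CHECK_PERIOD - 10 * 1000); // backdated stamp
console.log(lastChecked + KEY_CHECK_PERIOD - now); // 10000 ms until the key is due again
```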
@@ -1,28 +1,22 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AnthropicModelFamily } from "../../models";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
import { AnthropicKeyChecker } from "./checker";

// https://docs.anthropic.com/claude/reference/selecting-a-model
export type AnthropicModel =
  | "claude-instant-v1"
  | "claude-instant-v1-100k"
  | "claude-v1"
  | "claude-v1-100k"
  | "claude-2"
  | "claude-2.1";
const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;

export type AnthropicKeyUpdate = Omit<
  Partial<AnthropicKey>,
  | "key"
  | "hash"
  | "lastUsed"
  | "promptCount"
  | "rateLimitedAt"
  | "rateLimitedUntil"
>;
// https://docs.anthropic.com/claude/reference/selecting-a-model
export const ANTHROPIC_SUPPORTED_MODELS = [
  "claude-instant-v1",
  "claude-instant-v1-100k",
  "claude-v1",
  "claude-v1-100k",
  "claude-2",
] as const;
export type AnthropicModel = (typeof ANTHROPIC_SUPPORTED_MODELS)[number];

type AnthropicKeyUsage = {
  [K in AnthropicModelFamily as `${K}Tokens`]: number;
@@ -50,72 +44,33 @@ export interface AnthropicKey extends Key, AnthropicKeyUsage {
  isPozzed: boolean;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 2000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 500;
export class AnthropicKeyProvider extends KeyProviderBase<AnthropicKey> {
  readonly service = "anthropic" as const;

export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
  readonly service = "anthropic";

  private keys: AnthropicKey[] = [];
  protected readonly keys: AnthropicKey[] = [];
  private checker?: AnthropicKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });
  protected log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.anthropicKey?.trim();
    if (!keyConfig) {
      this.log.warn(
        "ANTHROPIC_KEY is not set. Anthropic API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: AnthropicKey = {
        key,
        service: this.service,
        modelFamilies: ["claude"],
        isDisabled: false,
        isRevoked: false,
        isPozzed: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        requiresPreamble: false,
        hash: `ant-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        claudeTokens: 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded Anthropic keys.");
  }
  public async init() {
    const storeName = this.store.constructor.name;
    const loadedKeys = await this.store.load();

    if (loadedKeys.length === 0) {
      return this.log.warn({ via: storeName }, "No Anthropic keys found.");
    }

    this.keys.push(...loadedKeys);
    this.log.info(
      { count: this.keys.length, via: storeName },
      "Loaded Anthropic keys."
    );

  public init() {
    if (config.checkKeys) {
      this.checker = new AnthropicKeyChecker(this.keys, this.update.bind(this));
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(_model: AnthropicModel) {
    // Currently, all Anthropic keys have access to all models. This will almost
    // certainly change when they move out of beta later this year.
@@ -152,26 +107,14 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    selectedKey.rateLimitedAt = now;
    // Intended to throttle the queue processor as otherwise it will just
    // flood the API with requests and we want to wait a sec to see if we're
    // going to get a rate limit error on this key.
    selectedKey.rateLimitedUntil = now + KEY_REUSE_DELAY;
    return { ...selectedKey };
  }

  public disable(key: AnthropicKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<AnthropicKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, _model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
@@ -179,7 +122,7 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
    key.claudeTokens += tokens;
  }

  public getLockoutPeriod() {
  public getLockoutPeriod(_model: AnthropicModel) {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
@@ -221,20 +164,4 @@ export class AnthropicKeyProvider implements KeyProvider<AnthropicKey> {
    });
    this.checker?.scheduleNextCheck();
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
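A minimal sketch of the `throttle`/`KEY_REUSE_DELAY` behavior both versions implement: after a key is handed out, it is treated as rate-limited for a short window so the queue cannot immediately hand the same key to the next request. The `Math.max` matters because it never shortens a longer existing lockout (e.g. a real 429 backoff):

```ts
const KEY_REUSE_DELAY = 500;

type ThrottledKey = { rateLimitedAt: number; rateLimitedUntil: number };

function throttle(key: ThrottledKey, now: number): void {
  key.rateLimitedAt = now;
  // Extend the lockout, but never shorten one already in effect.
  key.rateLimitedUntil = Math.max(key.rateLimitedUntil, now + KEY_REUSE_DELAY);
}

const key = { rateLimitedAt: 0, rateLimitedUntil: 0 };
throttle(key, 1_000);
console.log(key.rateLimitedUntil); // 1500 — reusable 500ms after assignment
```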
@@ -0,0 +1,43 @@
|
||||
import crypto from "crypto";
|
||||
import type { AnthropicKey, SerializedKey } from "../index";
|
||||
import { KeySerializerBase } from "../key-serializer-base";
|
||||
|
||||
const SERIALIZABLE_FIELDS: (keyof AnthropicKey)[] = [
|
||||
"key",
|
||||
"service",
|
||||
"hash",
|
||||
"promptCount",
|
||||
"claudeTokens",
|
||||
];
|
||||
export type SerializedAnthropicKey = SerializedKey &
|
||||
Partial<Pick<AnthropicKey, (typeof SERIALIZABLE_FIELDS)[number]>>;
|
||||
|
||||
export class AnthropicKeySerializer extends KeySerializerBase<AnthropicKey> {
|
||||
constructor() {
|
||||
super(SERIALIZABLE_FIELDS);
|
||||
}
|
||||
|
||||
deserialize({ key, ...rest }: SerializedAnthropicKey): AnthropicKey {
|
||||
return {
|
||||
key,
|
||||
service: "anthropic" as const,
|
||||
modelFamilies: ["claude" as const],
|
||||
isDisabled: false,
|
||||
isRevoked: false,
|
||||
isPozzed: false,
|
||||
promptCount: 0,
|
||||
lastUsed: 0,
|
||||
rateLimitedAt: 0,
|
||||
rateLimitedUntil: 0,
|
||||
requiresPreamble: false,
|
||||
hash: `ant-${crypto
|
||||
.createHash("sha256")
|
||||
.update(key)
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
claudeTokens: 0,
|
||||
...rest,
|
||||
};
|
||||
}
|
||||
}
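A quick round-trip sketch of the serializer above; the key string and token count are invented for illustration:

const serializer = new AnthropicKeySerializer();
const restored = serializer.deserialize({ key: "sk-ant-example", claudeTokens: 1234 });
// restored.hash === "ant-" + first 8 hex chars of sha256("sk-ant-example")
// restored.claudeTokens === 1234; the trailing `...rest` spread lets persisted
// counters override the zeroed defaults. serialize() later writes back only
// SERIALIZABLE_FIELDS, so transient flags like isPozzed never persist.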
@@ -8,13 +8,11 @@ import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const AMZ_HOST =
  process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
const GET_CALLER_IDENTITY_URL = `https://sts.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15`;
const GET_INVOCATION_LOGGING_CONFIG_URL = (region: string) =>
  `https://bedrock.${region}.amazonaws.com/logging/modelinvocations`;
const POST_INVOKE_MODEL_URL = (region: string, model: string) =>
  `https://${AMZ_HOST.replace("%REGION%", region)}/model/${model}/invoke`;
  `https://invoke-bedrock.${region}.amazonaws.com/model/${model}/invoke`;
const TEST_PROMPT = "\n\nHuman:\n\nAssistant:";

type AwsError = { error: {} };
@@ -32,36 +30,58 @@ type GetLoggingConfigResponse = {
type UpdateFn = typeof AwsBedrockKeyProvider.prototype.update;

export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
  private readonly updateKey: UpdateFn;

  constructor(keys: AwsBedrockKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "aws",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      updateKey,
    });
    this.updateKey = updateKey;
  }

  protected async testKeyOrFail(key: AwsBedrockKey) {
    // Only check models on startup. For now all models must be available to
    // the proxy because we don't route requests to different keys.
    const modelChecks: Promise<unknown>[] = [];
    const isInitialCheck = !key.lastChecked;
    if (isInitialCheck) {
      modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
      modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
  protected async checkKey(key: AwsBedrockKey) {
    if (key.isDisabled) {
      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
      this.scheduleNextCheck();
      return;
    }

    await Promise.all(modelChecks);
    await this.checkLoggingConfiguration(key);
    this.log.debug({ key: key.hash }, "Checking key...");
    let isInitialCheck = !key.lastChecked;
    try {
      // Only check models on startup. For now all models must be available to
      // the proxy because we don't route requests to different keys.
      const modelChecks: Promise<unknown>[] = [];
      if (isInitialCheck) {
        modelChecks.push(this.invokeModel("anthropic.claude-v1", key));
        modelChecks.push(this.invokeModel("anthropic.claude-v2", key));
      }

    this.log.info(
      {
        key: key.hash,
        models: key.modelFamilies,
        logged: key.awsLoggingStatus,
      },
      "Checked key."
    );
      await Promise.all(modelChecks);
      await this.checkLoggingConfiguration(key);

      this.log.info(
        {
          key: key.hash,
          models: key.modelFamilies,
          logged: key.awsLoggingStatus,
        },
        "Key check complete."
      );
    } catch (error) {
      this.handleAxiosError(key, error as AxiosError);
    }

    this.updateKey(key.hash, {});

    this.lastCheck = Date.now();
    // Only enqueue the next check if this wasn't a startup check, since those
    // are batched together elsewhere.
    if (!isInitialCheck) {
      this.scheduleNextCheck();
    }
  }

  protected handleAxiosError(key: AwsBedrockKey, error: AxiosError) {
@@ -145,10 +165,12 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
    const errorType = (headers["x-amzn-errortype"] as string).split(":")[0];
    const errorMessage = data?.message;

    // We're looking for a specific error type and message here
    // We're looking for a specific error type and message here:
    // "ValidationException"
    // "Malformed input request: -1 is not greater or equal to 0, please reformat your input and try again."
    // "Malformed input request: 2 schema violations found, please reformat your input and try again." (if there are multiple issues)
    const correctErrorType = errorType === "ValidationException";
    const correctErrorMessage = errorMessage?.match(/max_tokens_to_sample/);
    const correctErrorMessage = errorMessage?.match(/malformed input request/i);
    if (!correctErrorType || !correctErrorMessage) {
      throw new AxiosError(
        `Unexpected error when invoking model ${model}: ${errorMessage}`,
@@ -160,7 +182,7 @@ export class AwsKeyChecker extends KeyCheckerBase<AwsBedrockKey> {
    }

    this.log.debug(
      { key: key.hash, errorType, data, status, model },
      { key: key.hash, errorType, data, status },
      "Liveness test complete."
    );
  }
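The liveness test relies on Bedrock rejecting a deliberately malformed body; a sketch of the idea (the exact request payload is not shown in this hunk, so the field values here are assumptions):

// A negative max_tokens_to_sample forces a 400 ValidationException
// ("Malformed input request...") without generating any tokens. Receiving
// that specific error proves the credentials and model access are good;
// any other response is re-thrown as an unexpected failure.
const probeBody = { prompt: TEST_PROMPT, max_tokens_to_sample: -1 };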
@@ -1,15 +1,20 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AwsBedrockModelFamily } from "../../models";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";
import { AwsKeyChecker } from "./checker";

const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;

// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
export type AwsBedrockModel =
  | "anthropic.claude-v1"
  | "anthropic.claude-v2"
  | "anthropic.claude-instant-v1";
export const AWS_BEDROCK_SUPPORTED_MODELS = [
  "anthropic.claude-v1",
  "anthropic.claude-v2",
  "anthropic.claude-instant-v1",
] as const;
export type AwsBedrockModel = (typeof AWS_BEDROCK_SUPPORTED_MODELS)[number];

type AwsBedrockKeyUsage = {
  [K in AwsBedrockModelFamily as `${K}Tokens`]: number;
@@ -31,71 +36,33 @@ export interface AwsBedrockKey extends Key, AwsBedrockKeyUsage {
  awsLoggingStatus: "unknown" | "disabled" | "enabled";
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 4000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 250;
export class AwsBedrockKeyProvider extends KeyProviderBase<AwsBedrockKey> {
  readonly service = "aws" as const;

export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
  readonly service = "aws";

  private keys: AwsBedrockKey[] = [];
  protected readonly keys: AwsBedrockKey[] = [];
  private checker?: AwsKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });
  protected log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.awsCredentials?.trim();
    if (!keyConfig) {
      this.log.warn(
        "AWS_CREDENTIALS is not set. AWS Bedrock API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: AwsBedrockKey = {
        key,
        service: this.service,
        modelFamilies: ["aws-claude"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        awsLoggingStatus: "unknown",
        hash: `aws-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        ["aws-claudeTokens"]: 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded AWS Bedrock keys.");
  }
  public async init() {
    const storeName = this.store.constructor.name;
    const loadedKeys = await this.store.load();

    if (loadedKeys.length === 0) {
      return this.log.warn({ via: storeName }, "No AWS credentials found.");
    }

    this.keys.push(...loadedKeys);
    this.log.info(
      { count: this.keys.length, via: storeName },
      "Loaded AWS Bedrock keys."
    );

  public init() {
    if (config.checkKeys) {
      this.checker = new AwsKeyChecker(this.keys, this.update.bind(this));
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(_model: AwsBedrockModel) {
    const availableKeys = this.keys.filter((k) => {
      const isNotLogged = k.awsLoggingStatus === "disabled";
@@ -129,26 +96,14 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    selectedKey.rateLimitedAt = now;
    // Intended to throttle the queue processor as otherwise it will just
    // flood the API with requests and we want to wait a sec to see if we're
    // going to get a rate limit error on this key.
    selectedKey.rateLimitedUntil = now + KEY_REUSE_DELAY;
    return { ...selectedKey };
  }

  public disable(key: AwsBedrockKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<AwsBedrockKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, _model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
@@ -156,7 +111,7 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
    key["aws-claudeTokens"] += tokens;
  }

  public getLockoutPeriod() {
  public getLockoutPeriod(_model: AwsBedrockModel) {
    // TODO: same exact behavior for three providers, should be refactored
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
@@ -193,20 +148,4 @@ export class AwsBedrockKeyProvider implements KeyProvider<AwsBedrockKey> {
      this.update(hash, { lastChecked: 0, isDisabled: false })
    );
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
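A note on the get() filter above, inferred from the variable name and the interface's awsLoggingStatus field rather than spelled out in this hunk:

// Keys whose invocation logging is confirmed "disabled" are treated as safe
// to hand out, since AWS prompt logging would otherwise record user inputs;
// keys still reporting "unknown" or "enabled" are presumably deprioritized
// or skipped by the rest of the (elided) filter body.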
@@ -0,0 +1,43 @@
import crypto from "crypto";
import type { AwsBedrockKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";

const SERIALIZABLE_FIELDS: (keyof AwsBedrockKey)[] = [
  "key",
  "service",
  "hash",
  "promptCount",
  "aws-claudeTokens",
];
export type SerializedAwsBedrockKey = SerializedKey &
  Partial<Pick<AwsBedrockKey, (typeof SERIALIZABLE_FIELDS)[number]>>;

export class AwsBedrockKeySerializer extends KeySerializerBase<AwsBedrockKey> {
  constructor() {
    super(SERIALIZABLE_FIELDS);
  }

  deserialize(serializedKey: SerializedAwsBedrockKey): AwsBedrockKey {
    const { key, ...rest } = serializedKey;
    return {
      key,
      service: "aws",
      modelFamilies: ["aws-claude"],
      isDisabled: false,
      isRevoked: false,
      promptCount: 0,
      lastUsed: 0,
      rateLimitedAt: 0,
      rateLimitedUntil: 0,
      awsLoggingStatus: "unknown",
      hash: `aws-${crypto
        .createHash("sha256")
        .update(key)
        .digest("hex")
        .slice(0, 8)}`,
      lastChecked: 0,
      ["aws-claudeTokens"]: 0,
      ...rest,
    };
  }
}
@@ -1,149 +0,0 @@
import axios, { AxiosError } from "axios";
import { KeyCheckerBase } from "../key-checker-base";
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
import { getAzureOpenAIModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 1000; // 3 minutes
const AZURE_HOST = process.env.AZURE_HOST || "%RESOURCE_NAME%.openai.azure.com";
const POST_CHAT_COMPLETIONS = (resourceName: string, deploymentId: string) =>
  `https://${AZURE_HOST.replace(
    "%RESOURCE_NAME%",
    resourceName
  )}/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`;

type AzureError = {
  error: {
    message: string;
    type: string | null;
    param: string;
    code: string;
    status: number;
  };
};
type UpdateFn = typeof AzureOpenAIKeyProvider.prototype.update;

export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
  constructor(keys: AzureOpenAIKey[], updateKey: UpdateFn) {
    super(keys, {
      service: "azure",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      recurringChecksEnabled: false,
      updateKey,
    });
  }

  protected async testKeyOrFail(key: AzureOpenAIKey) {
    const model = await this.testModel(key);
    this.log.info(
      { key: key.hash, deploymentModel: model },
      "Checked key."
    );
    this.updateKey(key.hash, { modelFamilies: [model] });
  }

  // provided api-key header isn't valid (401)
  // {
  //   "error": {
  //     "code": "401",
  //     "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
  //   }
  // }

  // api key correct but deployment id is wrong (404)
  // {
  //   "error": {
  //     "code": "DeploymentNotFound",
  //     "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
  //   }
  // }

  // resource name is wrong (node will throw ENOTFOUND)

  // rate limited (429)
  // TODO: try to reproduce this

  protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
    if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
      const data = error.response.data;
      const status = data.error.status;
      const errorType = data.error.code || data.error.type;
      switch (errorType) {
        case "DeploymentNotFound":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is revoked or deployment ID is incorrect. Disabling key."
          );
          return this.updateKey(key.hash, {
            isDisabled: true,
            isRevoked: true,
          });
        case "401":
          this.log.warn(
            { key: key.hash, errorType, error: error.response.data },
            "Key is disabled or incorrect. Disabling key."
          );
          return this.updateKey(key.hash, {
            isDisabled: true,
            isRevoked: true,
          });
        default:
          this.log.error(
            { key: key.hash, errorType, error: error.response.data, status },
            "Unknown Azure API error while checking key. Please report this."
          );
          return this.updateKey(key.hash, { lastChecked: Date.now() });
      }
    }

    const { response, code } = error;
    if (code === "ENOTFOUND") {
      this.log.warn(
        { key: key.hash, error: error.message },
        "Resource name is probably incorrect. Disabling key."
      );
      return this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
    }

    const { headers, status, data } = response ?? {};
    this.log.error(
      { key: key.hash, status, headers, data, error: error.message },
      "Network error while checking key; trying this key again in a minute."
    );
    const oneMinute = 60 * 1000;
    const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
    this.updateKey(key.hash, { lastChecked: next });
  }

  private async testModel(key: AzureOpenAIKey) {
    const { apiKey, deploymentId, resourceName } =
      AzureOpenAIKeyChecker.getCredentialsFromKey(key);
    const url = POST_CHAT_COMPLETIONS(resourceName, deploymentId);
    const testRequest = {
      max_tokens: 1,
      stream: false,
      messages: [{ role: "user", content: "" }],
    };
    const { data } = await axios.post(url, testRequest, {
      headers: { "Content-Type": "application/json", "api-key": apiKey },
    });

    return getAzureOpenAIModelFamily(data.model);
  }

  static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
    const data = error.response?.data as any;
    return data?.error?.code || data?.error?.type;
  }

  static getCredentialsFromKey(key: AzureOpenAIKey) {
    const [resourceName, deploymentId, apiKey] = key.key.split(":");
    if (!resourceName || !deploymentId || !apiKey) {
      throw new Error(
        "Invalid Azure credential format. Refer to .env.example and ensure your credentials are in the format RESOURCE_NAME:DEPLOYMENT_ID:API_KEY with commas between each credential set."
      );
    }
    return { resourceName, deploymentId, apiKey };
  }
}
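For clarity, what getCredentialsFromKey expects; the values are placeholders:

// AZURE_CREDENTIALS="contoso-ai:gpt4-deployment:abcd1234"
const { resourceName, deploymentId, apiKey } =
  AzureOpenAIKeyChecker.getCredentialsFromKey(key);
// resourceName === "contoso-ai"       -> https://contoso-ai.openai.azure.com
// deploymentId === "gpt4-deployment"  -> /openai/deployments/gpt4-deployment
// apiKey       === "abcd1234"         -> sent in the "api-key" header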
@@ -1,215 +0,0 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";

export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;

type AzureOpenAIKeyUsage = {
  [K in AzureOpenAIModelFamily as `${K}Tokens`]: number;
};

export interface AzureOpenAIKey extends Key, AzureOpenAIKeyUsage {
  readonly service: "azure";
  readonly modelFamilies: AzureOpenAIModelFamily[];
  /** The time at which this key was last rate limited. */
  rateLimitedAt: number;
  /** The time until which this key is rate limited. */
  rateLimitedUntil: number;
  contentFiltering: boolean;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 4000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 250;

export class AzureOpenAIKeyProvider implements KeyProvider<AzureOpenAIKey> {
  readonly service = "azure";

  private keys: AzureOpenAIKey[] = [];
  private checker?: AzureOpenAIKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.azureCredentials;
    if (!keyConfig) {
      this.log.warn(
        "AZURE_CREDENTIALS is not set. Azure OpenAI API will not be available."
      );
      return;
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: AzureOpenAIKey = {
        key,
        service: this.service,
        modelFamilies: ["azure-gpt4"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        contentFiltering: false,
        hash: `azu-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        "azure-turboTokens": 0,
        "azure-gpt4Tokens": 0,
        "azure-gpt4-32kTokens": 0,
        "azure-gpt4-turboTokens": 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded Azure OpenAI keys.");
  }

  public init() {
    if (config.checkKeys) {
      this.checker = new AzureOpenAIKeyChecker(
        this.keys,
        this.update.bind(this)
      );
      this.checker.start();
    }
  }

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public get(model: AzureOpenAIModel) {
    const neededFamily = getAzureOpenAIModelFamily(model);
    const availableKeys = this.keys.filter(
      (k) => !k.isDisabled && k.modelFamilies.includes(neededFamily)
    );
    if (availableKeys.length === 0) {
      throw new Error(`No keys available for model family '${neededFamily}'.`);
    }

    // (largely copied from the OpenAI provider, without trial key support)
    // Select a key, from highest priority to lowest priority:
    // 1. Keys which are not rate limited
    //    a. If all keys were rate limited recently, select the least-recently
    //       rate limited key.
    // 2. Keys which have not been used in the longest time

    const now = Date.now();

    const keysByPriority = availableKeys.sort((a, b) => {
      const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
      const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;

      if (aRateLimited && !bRateLimited) return 1;
      if (!aRateLimited && bRateLimited) return -1;
      if (aRateLimited && bRateLimited) {
        return a.rateLimitedAt - b.rateLimitedAt;
      }

      return a.lastUsed - b.lastUsed;
    });

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }

  public disable(key: AzureOpenAIKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<AzureOpenAIKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
    key.promptCount++;
    key[`${getAzureOpenAIModelFamily(model)}Tokens`] += tokens;
  }

  // TODO: all of this shit is duplicate code

  public getLockoutPeriod(family: AzureOpenAIModelFamily) {
    const activeKeys = this.keys.filter(
      (key) => !key.isDisabled && key.modelFamilies.includes(family)
    );

    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;

    const now = Date.now();
    const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
    const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;

    if (anyNotRateLimited) return 0;

    // If all keys are rate-limited, return time until the first key is ready.
    return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
  }

  /**
   * This is called when we receive a 429, which means there are already five
   * concurrent requests running on this key. We don't have any information on
   * when these requests will resolve, so all we can do is wait a bit and try
   * again. We will lock the key for RATE_LIMIT_LOCKOUT milliseconds after
   * getting a 429 before retrying in order to give the other requests a
   * chance to finish.
   */
  public markRateLimited(keyHash: string) {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const now = Date.now();
    key.rateLimitedAt = now;
    key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
  }

  public recheck() {
    this.keys.forEach(({ hash }) =>
      this.update(hash, { lastChecked: 0, isDisabled: false })
    );
  }

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
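A worked example of the comparator's ordering (timestamps invented for illustration):

// now = 10_000, RATE_LIMIT_LOCKOUT = 4000
// key A: rateLimitedAt = 9_500 (limited 0.5s ago)  lastUsed = 1_000
// key B: rateLimitedAt = 2_000 (lockout expired)   lastUsed = 8_000
// key C: rateLimitedAt = 2_000 (lockout expired)   lastUsed = 3_000
// Resulting order: C, B, A. Non-limited keys come first, least-recently-used
// breaks the tie, and recently-limited keys are pushed to the back.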
@@ -1,82 +1,10 @@
import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";

/** The request and response format used by a model's API. */
export type APIFormat =
  | "openai"
  | "anthropic"
  | "google-palm"
  | "openai-text"
  | "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService =
  | "openai"
  | "anthropic"
  | "google-palm"
  | "aws"
  | "azure";
export type Model =
  | OpenAIModel
  | AnthropicModel
  | GooglePalmModel
  | AwsBedrockModel
  | AzureOpenAIModel;

export interface Key {
  /** The API key itself. Never log this, use `hash` instead. */
  readonly key: string;
  /** The service that this key is for. */
  service: LLMService;
  /** The model families that this key has access to. */
  modelFamilies: ModelFamily[];
  /** Whether this key is currently disabled, meaning its quota has been exceeded or it has been revoked. */
  isDisabled: boolean;
  /** Whether this key specifically has been revoked. */
  isRevoked: boolean;
  /** The number of prompts that have been sent with this key. */
  promptCount: number;
  /** The time at which this key was last used. */
  lastUsed: number;
  /** The time at which this key was last checked. */
  lastChecked: number;
  /** Hash of the key, for logging and to find the key in the pool. */
  hash: string;
}

/*
  KeyPool and KeyProvider's similarities are a relic of the old design where
  there was only a single KeyPool for OpenAI keys. Now that there are multiple
  supported services, the service-specific functionality has been moved to
  KeyProvider and KeyPool is just a wrapper around multiple KeyProviders,
  delegating to the appropriate one based on the model requested.

  Existing code will continue to call methods on KeyPool, which routes them to
  the appropriate KeyProvider or returns data aggregated across all KeyProviders
  for service-agnostic functionality.
*/

export interface KeyProvider<T extends Key = Key> {
  readonly service: LLMService;
  init(): void;
  get(model: Model): T;
  list(): Omit<T, "key">[];
  disable(key: T): void;
  update(hash: string, update: Partial<T>): void;
  available(): number;
  incrementUsage(hash: string, model: string, tokens: number): void;
  getLockoutPeriod(model: ModelFamily): number;
  markRateLimited(hash: string): void;
  recheck(): void;
}

export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider";
export { keyPool } from "./key-pool";
export { OPENAI_SUPPORTED_MODELS } from "./openai/provider";
export { ANTHROPIC_SUPPORTED_MODELS } from "./anthropic/provider";
export { GOOGLE_PALM_SUPPORTED_MODELS } from "./palm/provider";
export { AWS_BEDROCK_SUPPORTED_MODELS } from "./aws/provider";
export type { AnthropicKey } from "./anthropic/provider";
export type { OpenAIKey } from "./openai/provider";
export type { GooglePalmKey } from "./palm/provider";
export type { AwsBedrockKey } from "./aws/provider";
export * from "./types";
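A minimal sketch of the delegation this preserves for callers; the model name is an assumption, and the pool must have been initialized via the new key-pool module first:

import { keyPool } from "./key-pool";

// KeyPool maps the "claude-" prefix to the Anthropic provider internally;
// callers never pick a provider themselves.
const key = keyPool.get("claude-v1");
keyPool.incrementUsage(key, "claude-v1", 512);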
@@ -1,19 +1,16 @@
import { AxiosError } from "axios";
import pino from "pino";
import { logger } from "../../logger";
import { Key } from "./index";
import { AxiosError } from "axios";
import { Key } from "./types";

type KeyCheckerOptions<TKey extends Key = Key> = {
type KeyCheckerOptions = {
  service: string;
  keyCheckPeriod: number;
  minCheckInterval: number;
  recurringChecksEnabled?: boolean;
  updateKey: (hash: string, props: Partial<TKey>) => void;
};

export abstract class KeyCheckerBase<TKey extends Key> {
  protected readonly service: string;
  protected readonly RECURRING_CHECKS_ENABLED: boolean;
  /** Minimum time in between any two key checks. */
  protected readonly MIN_CHECK_INTERVAL: number;
  /**
@@ -22,19 +19,16 @@ export abstract class KeyCheckerBase<TKey extends Key> {
   * than this.
   */
  protected readonly KEY_CHECK_PERIOD: number;
  protected readonly updateKey: (hash: string, props: Partial<TKey>) => void;
  protected readonly keys: TKey[] = [];
  protected log: pino.Logger;
  protected timeout?: NodeJS.Timeout;
  protected lastCheck = 0;

  protected constructor(keys: TKey[], opts: KeyCheckerOptions<TKey>) {
  protected constructor(keys: TKey[], opts: KeyCheckerOptions) {
    const { service, keyCheckPeriod, minCheckInterval } = opts;
    this.keys = keys;
    this.KEY_CHECK_PERIOD = keyCheckPeriod;
    this.MIN_CHECK_INTERVAL = minCheckInterval;
    this.RECURRING_CHECKS_ENABLED = opts.recurringChecksEnabled ?? true;
    this.updateKey = opts.updateKey;
    this.service = service;
    this.log = logger.child({ module: "key-checker", service });
  }
@@ -58,34 +52,31 @@ export abstract class KeyCheckerBase<TKey extends Key> {
   * the minimum check interval.
   */
  public scheduleNextCheck() {
    // Gives each concurrent check a correlation ID to make logs less confusing.
    const callId = Math.random().toString(36).slice(2, 8);
    const timeoutId = this.timeout?.[Symbol.toPrimitive]?.();
    const checkLog = this.log.child({ callId, timeoutId });

    const enabledKeys = this.keys.filter((key) => !key.isDisabled);
    const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
    const numEnabled = enabledKeys.length;
    const numUnchecked = uncheckedKeys.length;
    checkLog.debug({ enabled: enabledKeys.length }, "Scheduling next check...");

    clearTimeout(this.timeout);
    this.timeout = undefined;

    if (!numEnabled) {
      checkLog.warn("All keys are disabled. Stopping.");
    if (enabledKeys.length === 0) {
      checkLog.warn("All keys are disabled. Key checker stopping.");
      return;
    }

    checkLog.debug({ numEnabled, numUnchecked }, "Scheduling next check...");

    if (numUnchecked > 0) {
      const keycheckBatch = uncheckedKeys.slice(0, 12);
    // Perform startup checks for any keys that haven't been checked yet.
    const uncheckedKeys = enabledKeys.filter((key) => !key.lastChecked);
    checkLog.debug({ unchecked: uncheckedKeys.length }, "# of unchecked keys");
    if (uncheckedKeys.length > 0) {
      const keysToCheck = uncheckedKeys.slice(0, 12);

      this.timeout = setTimeout(async () => {
        try {
          await Promise.all(keycheckBatch.map((key) => this.checkKey(key)));
          await Promise.all(keysToCheck.map((key) => this.checkKey(key)));
        } catch (error) {
          checkLog.error({ error }, "Error checking one or more keys.");
          this.log.error({ error }, "Error checking one or more keys.");
        }
        checkLog.info("Batch complete.");
        this.scheduleNextCheck();
@@ -93,18 +84,11 @@ export abstract class KeyCheckerBase<TKey extends Key> {

      checkLog.info(
        {
          batch: keycheckBatch.map((k) => k.hash),
          remaining: uncheckedKeys.length - keycheckBatch.length,
          batch: keysToCheck.map((k) => k.hash),
          remaining: uncheckedKeys.length - keysToCheck.length,
          newTimeoutId: this.timeout?.[Symbol.toPrimitive]?.(),
        },
        "Scheduled batch of initial checks."
      );
      return;
    }

    if (!this.RECURRING_CHECKS_ENABLED) {
      checkLog.info(
        "Initial checks complete and recurring checks are disabled for this service. Stopping."
        "Scheduled batch check."
      );
      return;
    }
@@ -122,35 +106,14 @@ export abstract class KeyCheckerBase<TKey extends Key> {
    );

    const delay = nextCheck - Date.now();
    this.timeout = setTimeout(
      () => this.checkKey(oldestKey).then(() => this.scheduleNextCheck()),
      delay
    );
    this.timeout = setTimeout(() => this.checkKey(oldestKey), delay);
    checkLog.debug(
      { key: oldestKey.hash, nextCheck: new Date(nextCheck), delay },
      "Scheduled next recurring check."
      "Scheduled single key check."
    );
  }

  public async checkKey(key: TKey): Promise<void> {
    if (key.isDisabled) {
      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
      this.scheduleNextCheck();
      return;
    }
    this.log.debug({ key: key.hash }, "Checking key...");

    try {
      await this.testKeyOrFail(key);
    } catch (error) {
      this.updateKey(key.hash, {});
      this.handleAxiosError(key, error as AxiosError);
    }

    this.lastCheck = Date.now();
  }

  protected abstract testKeyOrFail(key: TKey): Promise<void>;
  protected abstract checkKey(key: TKey): Promise<void>;

  protected abstract handleAxiosError(key: TKey, error: AxiosError): void;
}
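The scheduling behavior after this refactor, summarized from the code above (the batch size of 12 and the timing constants are as written):

// Startup: enabled-but-unchecked keys drain in concurrent batches of 12;
// each finished batch calls scheduleNextCheck() again until none remain.
// Steady state: the oldest-checked key is probed roughly once per
// KEY_CHECK_PERIOD, never sooner than MIN_CHECK_INTERVAL after the last
// check. With recurringChecksEnabled: false (OpenAI, Azure), the checker
// stops once the startup batches are done.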
@@ -4,44 +4,42 @@ import os from "os";
import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { KeyProviderBase } from "./key-provider-base";
import { getSerializer } from "./serializers";
import { FirebaseKeyStore, MemoryKeyStore } from "./stores";
import { AnthropicKeyProvider } from "./anthropic/provider";
import { OpenAIKeyProvider } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
import { AzureOpenAIKeyProvider } from "./azure/provider";

type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
import { Key, KeyStore, LLMService, Model, ServiceToKey } from "./types";

export class KeyPool {
  private keyProviders: KeyProvider[] = [];
  private keyProviders: KeyProviderBase[] = [];
  private recheckJobs: Partial<Record<LLMService, schedule.Job | null>> = {
    openai: null,
  };

  constructor() {
    this.keyProviders.push(new OpenAIKeyProvider());
    this.keyProviders.push(new AnthropicKeyProvider());
    this.keyProviders.push(new GooglePalmKeyProvider());
    this.keyProviders.push(new AwsBedrockKeyProvider());
    this.keyProviders.push(new AzureOpenAIKeyProvider());
    this.keyProviders.push(
      new OpenAIKeyProvider(createKeyStore("openai")),
      new AnthropicKeyProvider(createKeyStore("anthropic")),
      new GooglePalmKeyProvider(createKeyStore("google-palm")),
      new AwsBedrockKeyProvider(createKeyStore("aws"))
    );
  }

  public init() {
    this.keyProviders.forEach((provider) => provider.init());
  public async init() {
    await Promise.all(this.keyProviders.map((p) => p.init()));

    const availableKeys = this.available("all");
    if (availableKeys === 0) {
      throw new Error(
        "No keys loaded. Ensure that at least one key is configured."
      );
      throw new Error("No keys loaded, the application cannot start.");
    }
    this.scheduleRecheck();
  }

  public get(model: Model): Key {
    const service = this.getServiceForModel(model);
    const service = this.getService(model);
    return this.getKeyProvider(service).get(model);
  }

@@ -63,7 +61,7 @@ export class KeyPool {
    }
  }

  public update(key: Key, props: AllowedPartial): void {
  public update<T extends Key>(key: T, props: Partial<T>): void {
    const service = this.getKeyProvider(key.service);
    service.update(key.hash, props);
  }
@@ -71,7 +69,7 @@ export class KeyPool {
  public available(model: Model | "all" = "all"): number {
    return this.keyProviders.reduce((sum, provider) => {
      const includeProvider =
        model === "all" || this.getServiceForModel(model) === provider.service;
        model === "all" || this.getService(model) === provider.service;
      return sum + (includeProvider ? provider.available() : 0);
    }, 0);
  }
@@ -81,9 +79,9 @@ export class KeyPool {
    provider.incrementUsage(key.hash, model, tokens);
  }

  public getLockoutPeriod(family: ModelFamily): number {
    const service = this.getServiceForModelFamily(family);
    return this.getKeyProvider(service).getLockoutPeriod(family);
  public getLockoutPeriod(model: Model): number {
    const service = this.getService(model);
    return this.getKeyProvider(service).getLockoutPeriod(model);
  }

  public markRateLimited(key: Key): void {
@@ -108,12 +106,8 @@ export class KeyPool {
    provider.recheck();
  }

  private getServiceForModel(model: Model): LLMService {
    if (
      model.startsWith("gpt") ||
      model.startsWith("text-embedding-ada") ||
      model.startsWith("dall-e")
    ) {
  private getService(model: Model): LLMService {
    if (model.startsWith("gpt") || model.startsWith("text-embedding-ada")) {
      // https://platform.openai.com/docs/models/model-endpoint-compatibility
      return "openai";
    } else if (model.startsWith("claude-")) {
@@ -126,37 +120,11 @@ export class KeyPool {
      // AWS offers models from a few providers
      // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
      return "aws";
    } else if (model.startsWith("azure")) {
      return "azure";
    }
    throw new Error(`Unknown service for model '${model}'`);
  }

  private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
    switch (modelFamily) {
      case "gpt4":
      case "gpt4-32k":
      case "gpt4-turbo":
      case "turbo":
      case "dall-e":
        return "openai";
      case "claude":
        return "anthropic";
      case "bison":
        return "google-palm";
      case "aws-claude":
        return "aws";
      case "azure-turbo":
      case "azure-gpt4":
      case "azure-gpt4-32k":
      case "azure-gpt4-turbo":
        return "azure";
      default:
        assertNever(modelFamily);
    }
  }

  private getKeyProvider(service: LLMService): KeyProvider {
  private getKeyProvider(service: LLMService): KeyProviderBase {
    return this.keyProviders.find((provider) => provider.service === service)!;
  }

@@ -185,3 +153,25 @@ export class KeyPool {
    this.recheckJobs.openai = job;
  }
}

function createKeyStore<S extends LLMService>(
  service: S
): KeyStore<ServiceToKey[S]> {
  const serializer = getSerializer(service);

  switch (config.persistenceProvider) {
    case "memory":
      return new MemoryKeyStore(service, serializer);
    case "firebase_rtdb":
      return new FirebaseKeyStore(service, serializer);
    default:
      throw new Error(`Unknown store type: ${config.persistenceProvider}`);
  }
}

export let keyPool: KeyPool;

export async function init() {
  keyPool = new KeyPool();
  await keyPool.init();
}
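A sketch of the startup sequence this implies; the import path and bootstrap wrapper are assumptions, not part of this diff:

import * as keyManagement from "./shared/key-management/key-pool";

async function startServer() {
  // Loads keys from the configured store (memory or firebase_rtdb) and
  // throws before the server binds if no keys could be loaded.
  await keyManagement.init();
  // ...only then start accepting proxy traffic.
}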
@@ -0,0 +1,65 @@
import { logger } from "../../logger";
import { Key, KeyStore, LLMService, Model } from "./types";

export abstract class KeyProviderBase<K extends Key = Key> {
  public abstract readonly service: LLMService;

  protected abstract readonly keys: K[];
  protected abstract log: typeof logger;
  protected readonly store: KeyStore<K>;

  public constructor(keyStore: KeyStore<K>) {
    this.store = keyStore;
  }

  public abstract init(): Promise<void>;

  public addKey(key: K): void {
    this.keys.push(key);
    this.store.add(key);
  }

  public abstract get(model: Model): K;

  /**
   * Returns a list of all keys, with the actual key value removed. Don't
   * mutate the returned objects; use `update` instead to ensure the changes
   * are synced to the key store.
   */
  public list(): Omit<K, "key">[] {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
  }

  public disable(key: K): void {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    this.update(key.hash, { isDisabled: true } as Partial<K>, true);
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<K>, force = false): void {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) {
      throw new Error(`No key with hash ${hash}`);
    }

    Object.assign(key, { lastChecked: Date.now(), ...update });
    this.store.update(hash, update, force);
  }

  public available(): number {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public abstract incrementUsage(
    hash: string,
    model: string,
    tokens: number
  ): void;

  public abstract getLockoutPeriod(model: Model): number;

  public abstract markRateLimited(hash: string): void;

  public abstract recheck(): void;
}
@@ -0,0 +1,31 @@
import { Key, KeySerializer, SerializedKey } from "./types";

export abstract class KeySerializerBase<K extends Key>
  implements KeySerializer<K>
{
  protected constructor(protected serializableFields: (keyof K)[]) {}

  serialize(keyObj: K): SerializedKey {
    return {
      ...Object.fromEntries(
        this.serializableFields
          .map((f) => [f, keyObj[f]])
          .filter(([, v]) => v !== undefined)
      ),
      key: keyObj.key,
    };
  }

  partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey> {
    return {
      ...Object.fromEntries(
        this.serializableFields
          .map((f) => [f, update[f]])
          .filter(([, v]) => v !== undefined)
      ),
      key,
    };
  }

  abstract deserialize(serializedKey: SerializedKey): K;
}
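How the base serializer filters an update in practice; the field names come from the Anthropic subclass above, the values are invented:

const s = new AnthropicKeySerializer();
s.partialSerialize("sk-ant-example", { claudeTokens: 42, isPozzed: true });
// => { key: "sk-ant-example", claudeTokens: 42 }
// isPozzed is dropped because it is not in serializableFields, so transient
// runtime state never reaches the persistent store.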
@@ -2,7 +2,6 @@ import axios, { AxiosError } from "axios";
import type { OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
import { getOpenAIModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
@@ -27,41 +26,65 @@ type UpdateFn = typeof OpenAIKeyProvider.prototype.update;

export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
  private readonly cloneKey: CloneFn;
  private readonly updateKey: UpdateFn;

  constructor(keys: OpenAIKey[], cloneFn: CloneFn, updateKey: UpdateFn) {
    super(keys, {
      service: "openai",
      keyCheckPeriod: KEY_CHECK_PERIOD,
      minCheckInterval: MIN_CHECK_INTERVAL,
      recurringChecksEnabled: false,
      updateKey,
    });
    this.cloneKey = cloneFn;
    this.updateKey = updateKey;
  }

  protected async testKeyOrFail(key: OpenAIKey) {
    // We only need to check for provisioned models on the initial check.
    const isInitialCheck = !key.lastChecked;
    if (isInitialCheck) {
      const [provisionedModels, livenessTest] = await Promise.all([
        this.getProvisionedModels(key),
        this.testLiveness(key),
        this.maybeCreateOrganizationClones(key),
      ]);
      const updates = {
        modelFamilies: provisionedModels,
        isTrial: livenessTest.rateLimit <= 250,
      };
      this.updateKey(key.hash, updates);
    } else {
      // No updates needed as models and trial status generally don't change.
      const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
      this.updateKey(key.hash, {});
  protected async checkKey(key: OpenAIKey) {
    if (key.isDisabled) {
      this.log.warn({ key: key.hash }, "Skipping check for disabled key.");
      this.scheduleNextCheck();
      return;
    }

    this.log.debug({ key: key.hash }, "Checking key...");
    let isInitialCheck = !key.lastChecked;
    try {
      // We only need to check for provisioned models on the initial check.
      if (isInitialCheck) {
        const [provisionedModels, livenessTest] = await Promise.all([
          this.getProvisionedModels(key),
          this.testLiveness(key),
          this.maybeCreateOrganizationClones(key),
        ]);
        const updates = {
          modelFamilies: provisionedModels,
          isTrial: livenessTest.rateLimit <= 250,
        };
        this.updateKey(key.hash, updates);
      } else {
        // No updates needed as models and trial status generally don't change.
        const [_livenessTest] = await Promise.all([this.testLiveness(key)]);
        this.updateKey(key.hash, {});
      }
      this.log.info(
        { key: key.hash, models: key.modelFamilies, trial: key.isTrial },
        "Key check complete."
      );
    } catch (error) {
      // touch the key so we don't check it again for a while
      this.updateKey(key.hash, {});
      this.handleAxiosError(key, error as AxiosError);
    }

    this.lastCheck = Date.now();
    // Only enqueue the next check if this wasn't a startup check, since those
    // are batched together elsewhere.
    if (!isInitialCheck) {
      this.log.info(
        { key: key.hash },
        "Recurring keychecks are disabled, no-op."
      );
      // this.scheduleNextCheck();
    }
    this.log.info(
      { key: key.hash, models: key.modelFamilies, trial: key.isTrial },
      "Checked key."
    );
  }

  private async getProvisionedModels(
@@ -71,26 +94,29 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
    const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
    const models = data.data;

    const families = new Set<OpenAIModelFamily>();
    models.forEach(({ id }) => families.add(getOpenAIModelFamily(id, "turbo")));
    const families: OpenAIModelFamily[] = [];
    if (models.some(({ id }) => id.startsWith("gpt-3.5-turbo"))) {
      families.push("turbo");
    }

    // as of 2023-11-18, many keys no longer return the dalle3 model but still
    // have access to it via the api for whatever reason.
    // if (families.has("dall-e") && !models.find(({ id }) => id === "dall-e-3")) {
    //   families.delete("dall-e");
    // }
    if (models.some(({ id }) => id.startsWith("gpt-4"))) {
      families.push("gpt4");
    }

    if (models.some(({ id }) => id.startsWith("gpt-4-32k"))) {
      families.push("gpt4-32k");
    }

    // We want to update the key's model families here, but we don't want to
    // update its `lastChecked` timestamp because we need to let the liveness
    // check run before we can consider the key checked.

    const familiesArray = [...families];
    const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
    this.updateKey(key.hash, {
      modelFamilies: familiesArray,
      modelFamilies: families,
      lastChecked: keyFromPool.lastChecked,
    });
    return familiesArray;
    return families;
  }

  private async maybeCreateOrganizationClones(key: OpenAIKey) {
@@ -114,17 +140,6 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
      .filter(({ is_default }) => !is_default)
      .map(({ id }) => id);
    this.cloneKey(key.hash, ids);

    // It's possible that the keychecker may be stopped if all non-cloned keys
    // happened to be unusable, in which case this clone will never be checked
    // unless we restart the keychecker.
    if (!this.timeout) {
      this.log.warn(
        { parent: key.hash },
        "Restarting key checker to check cloned keys."
      );
      this.scheduleNextCheck();
    }
  }

  protected handleAxiosError(key: OpenAIKey, error: AxiosError) {
||||
@@ -1,26 +1,23 @@
|
||||
/* Manages OpenAI API keys. Tracks usage, disables expired keys, and provides
|
||||
round-robin access to keys. Keys are stored in the OPENAI_KEY environment
|
||||
variable as a comma-separated list of keys. */
|
||||
import crypto from "crypto";
|
||||
import http from "http";
|
||||
import { Key, KeyProvider, Model } from "../index";
|
||||
import { IncomingHttpHeaders } from "http";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { OpenAIKeyChecker } from "./checker";
|
||||
import { getOpenAIModelFamily, OpenAIModelFamily } from "../../models";
|
||||
import { Key, Model } from "../types";
|
||||
import { OpenAIKeyChecker } from "./checker";
|
||||
import { KeyProviderBase } from "../key-provider-base";
|
||||
|
||||
export type OpenAIModel =
|
||||
| "gpt-3.5-turbo"
|
||||
| "gpt-3.5-turbo-instruct"
|
||||
| "gpt-4"
|
||||
| "gpt-4-32k"
|
||||
| "gpt-4-1106"
|
||||
| "text-embedding-ada-002"
|
||||
| "dall-e-2"
|
||||
| "dall-e-3"
|
||||
const KEY_REUSE_DELAY = 1000;
|
||||
|
||||
export const OPENAI_SUPPORTED_MODELS = [
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
"gpt-4",
|
||||
"gpt-4-32k",
|
||||
"text-embedding-ada-002",
|
||||
] as const;
|
||||
export type OpenAIModel = (typeof OPENAI_SUPPORTED_MODELS)[number];
|
||||
|
||||
// Flattening model families instead of using a nested object for easier
|
||||
// cloning.
|
||||
type OpenAIKeyUsage = {
|
||||
[K in OpenAIModelFamily as `${K}Tokens`]: number;
|
||||
};
|
||||
@@ -62,77 +59,32 @@ export interface OpenAIKey extends Key, OpenAIKeyUsage {
   * tokens.
   */
  rateLimitTokensReset: number;
  /**
   * This key's maximum request rate for GPT-4, per minute.
   */
  gpt4Rpm: number;
}

export type OpenAIKeyUpdate = Omit<
  Partial<OpenAIKey>,
  "key" | "hash" | "promptCount"
>;

/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 1000;

export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
export class OpenAIKeyProvider extends KeyProviderBase<OpenAIKey> {
  readonly service = "openai" as const;

  private keys: OpenAIKey[] = [];
  protected readonly keys: OpenAIKey[] = [];
  private checker?: OpenAIKeyChecker;
  private log = logger.child({ module: "key-provider", service: this.service });
  protected log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyString = config.openaiKey?.trim();
    if (!keyString) {
      this.log.warn("OPENAI_KEY is not set. OpenAI API will not be available.");
      return;
    }
    let bareKeys: string[];
    bareKeys = keyString.split(",").map((k) => k.trim());
    bareKeys = [...new Set(bareKeys)];
    for (const k of bareKeys) {
      const newKey: OpenAIKey = {
        key: k,
        service: "openai" as const,
        modelFamilies: [
          "turbo" as const,
          "gpt4" as const,
          "gpt4-turbo" as const,
        ],
        isTrial: false,
        isDisabled: false,
        isRevoked: false,
        isOverQuota: false,
        lastUsed: 0,
        lastChecked: 0,
        promptCount: 0,
        hash: `oai-${crypto
          .createHash("sha256")
          .update(k)
          .digest("hex")
          .slice(0, 8)}`,
        rateLimitedAt: 0,
        rateLimitRequestsReset: 0,
        rateLimitTokensReset: 0,
        turboTokens: 0,
        gpt4Tokens: 0,
        "gpt4-32kTokens": 0,
        "gpt4-turboTokens": 0,
        "dall-eTokens": 0,
        gpt4Rpm: 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded OpenAI keys.");
  }
  public async init() {
    const storeName = this.store.constructor.name;
    const loadedKeys = await this.store.load();

    // TODO: after key management UI, keychecker should always be enabled
    // because keys may be added after initialization.

    if (loadedKeys.length === 0) {
      return this.log.warn({ via: storeName }, "No OpenAI keys found.");
    }

    this.keys.push(...loadedKeys);
    this.log.info(
      { count: this.keys.length, via: storeName },
      "Loaded OpenAI keys."
    );

  public init() {
    if (config.checkKeys) {
      const cloneFn = this.clone.bind(this);
      const updateFn = this.update.bind(this);
@@ -141,33 +93,26 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
    }
  }

  /**
   * Returns a list of all keys, with the key field removed.
   * Don't mutate returned keys, use a KeyPool method instead.
   **/
  public list() {
    return this.keys.map((key) => {
      return Object.freeze({
        ...key,
        key: undefined,
      });
    });
  }

  public get(model: Model) {
    const neededFamily = getOpenAIModelFamily(model);
    const excludeTrials = model === "text-embedding-ada-002";

    const availableKeys = this.keys.filter(
      // Allow keys which
      // Allow keys which...
      (key) =>
        !key.isDisabled && // are not disabled
        key.modelFamilies.includes(neededFamily) && // have access to the model
        (!excludeTrials || !key.isTrial) // and are not trials (if applicable)
        !key.isDisabled && // ...are not disabled
        key.modelFamilies.includes(neededFamily) && // ...have access to the model
        (!excludeTrials || !key.isTrial) // ...and are not trials (if applicable)
    );

    if (availableKeys.length === 0) {
      throw new Error(`No keys available for model family '${neededFamily}'.`);
      throw new Error(`No active keys available for ${neededFamily} models.`);
    }

    if (!config.allowedModelFamilies.includes(neededFamily)) {
      throw new Error(
        `Proxy operator has disabled access to ${neededFamily} models.`
      );
    }

    // Select a key, from highest priority to lowest priority:
@@ -211,16 +156,29 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
      return a.lastUsed - b.lastUsed;
    });

    // logger.debug(
    //   {
    //     byPriority: keysByPriority.map((k) => ({
    //       hash: k.hash,
    //       isRateLimited: now - k.rateLimitedAt < rateLimitThreshold,
    //       modelFamilies: k.modelFamilies,
    //     })),
    //   },
    //   "Keys sorted by priority"
    // );

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    return { ...selectedKey };
  }

  /** Called by the key checker to update key information. */
  public update(keyHash: string, update: OpenAIKeyUpdate) {
    const keyFromPool = this.keys.find((k) => k.hash === keyHash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
    // When a key is selected, we rate-limit it for a brief period of time to
    // prevent the queue processor from immediately flooding it with requests
    // while the initial request is still being processed (which is when we will
    // get new rate limit headers).
    // Instead, we will let a request through every second until the key
    // becomes fully saturated and locked out again.
    selectedKey.rateLimitedAt = now;
    selectedKey.rateLimitRequestsReset = KEY_REUSE_DELAY;
    return { ...selectedKey };
  }

  /** Called by the key checker to create clones of keys for the given orgs. */
@@ -231,8 +189,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
        ...keyFromPool,
        organizationId: orgId,
        isDisabled: false,
        isRevoked: false,
        isOverQuota: false,
        hash: `oai-${crypto
          .createHash("sha256")
          .update(keyFromPool.key + orgId)
@@ -246,33 +202,25 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
      );
      return clone;
    });
    this.keys.push(...clones);
  }

  /** Disables a key, or does nothing if the key isn't in this pool. */
  public disable(key: Key) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    this.update(key.hash, { isDisabled: true });
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
    clones.forEach((clone) => this.addKey(clone));
  }

  /**
   * Given a model, returns the period until a key will be available to service
   * the request, or returns 0 if a key is ready immediately.
   */
  public getLockoutPeriod(family: OpenAIModelFamily): number {
  public getLockoutPeriod(model: Model = "gpt-4"): number {
    const neededFamily = getOpenAIModelFamily(model);
    const activeKeys = this.keys.filter(
      (key) => !key.isDisabled && key.modelFamilies.includes(family)
      (key) => !key.isDisabled && key.modelFamilies.includes(neededFamily)
    );

    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
    if (activeKeys.length === 0) return 0;
    if (activeKeys.length === 0) {
      // If there are no active keys for this model we can't fulfill requests.
      // We'll return 0 to let the request through and return an error,
      // otherwise the request will be stuck in the queue forever.
      return 0;
    }

    // A key is rate-limited if its `rateLimitedAt` plus the greater of its
    // `rateLimitRequestsReset` and `rateLimitTokensReset` is after the
@@ -285,7 +233,7 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
        key.rateLimitRequestsReset,
        key.rateLimitTokensReset
      );
      return now < key.rateLimitedAt + Math.min(20000, resetTime);
      return now < key.rateLimitedAt + resetTime;
    }).length;
    const anyNotRateLimited = rateLimitedKeys < activeKeys.length;

@@ -294,16 +242,14 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
    }

    // If all keys are rate-limited, return the time until the first key is
    // ready. We don't want to wait longer than 10 seconds because rate limits
    // are a rolling window and keys may become available sooner than the stated
    // reset time.
    // ready.
    return Math.min(
      ...activeKeys.map((key) => {
        const resetTime = Math.max(
          key.rateLimitRequestsReset,
          key.rateLimitTokensReset
        );
        return key.rateLimitedAt + Math.min(20000, resetTime) - now;
        return key.rateLimitedAt + resetTime - now;
      })
    );
  }
@@ -312,10 +258,6 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
    this.log.debug({ key: keyHash }, "Key rate limited");
    const key = this.keys.find((k) => k.hash === keyHash)!;
    key.rateLimitedAt = Date.now();
    // DALL-E requests do not send headers telling us when the rate limit will
    // be reset so we need to set a fallback value here. Other models will have
    // this overwritten by the `updateRateLimits` method.
    key.rateLimitRequestsReset = 20000;
  }

  public incrementUsage(keyHash: string, model: string, tokens: number) {
@@ -325,21 +267,35 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
    key[`${getOpenAIModelFamily(model)}Tokens`] += tokens;
  }

  public updateRateLimits(keyHash: string, headers: http.IncomingHttpHeaders) {
  public updateRateLimits(keyHash: string, headers: IncomingHttpHeaders) {
    const key = this.keys.find((k) => k.hash === keyHash)!;
    const requestsReset = headers["x-ratelimit-reset-requests"];
    const tokensReset = headers["x-ratelimit-reset-tokens"];

    if (typeof requestsReset === "string") {
    // Sometimes OpenAI only sends one of the two rate limit headers, it's
    // unclear why.

    if (requestsReset && typeof requestsReset === "string") {
      this.log.debug(
        { key: key.hash, requestsReset },
        `Updating rate limit requests reset time`
      );
      key.rateLimitRequestsReset = getResetDurationMillis(requestsReset);
    }

    if (typeof tokensReset === "string") {
    if (tokensReset && typeof tokensReset === "string") {
      this.log.debug(
        { key: key.hash, tokensReset },
        `Updating rate limit tokens reset time`
      );
      key.rateLimitTokensReset = getResetDurationMillis(tokensReset);
    }

    if (!requestsReset && !tokensReset) {
      this.log.warn({ key: key.hash }, `No ratelimit headers; skipping update`);
      this.log.warn(
        { key: key.hash },
        `No rate limit headers in OpenAI response; skipping update`
      );
      return;
    }
  }
@@ -355,67 +311,21 @@ export class OpenAIKeyProvider implements KeyProvider<OpenAIKey> {
    });
    this.checker?.scheduleNextCheck();
  }

  /**
   * Called when a key is selected for a request, briefly disabling it to
   * avoid spamming the API with requests while we wait to learn whether this
   * key is already rate limited.
   */
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit =
      Math.max(key.rateLimitRequestsReset, key.rateLimitTokensReset) +
      key.rateLimitedAt;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    // Don't throttle if the key is already naturally rate limited.
    if (currentRateLimit > nextRateLimit) return;

    key.rateLimitedAt = Date.now();
    key.rateLimitRequestsReset = KEY_REUSE_DELAY;
  }
}

// wip
function calculateRequestsPerMinute(headers: http.IncomingHttpHeaders) {
  const requestsLimit = headers["x-ratelimit-limit-requests"];
  const requestsReset = headers["x-ratelimit-reset-requests"];

  if (typeof requestsLimit !== "string" || typeof requestsReset !== "string") {
    return 0;
  }

  const limit = parseInt(requestsLimit, 10);
  const reset = getResetDurationMillis(requestsReset);

  // If `reset` is less than one minute, OpenAI specifies the `limit` as an
  // integer representing requests per minute. Otherwise it actually means the
  // requests per day.
  const isPerMinute = reset < 60000;
  if (isPerMinute) return limit;
  return limit / 1440;
}

/**
 * Converts reset string ("14m25s", "21.0032s", "14ms" or "21ms") to a number of
 * milliseconds.
 * Converts reset string ("21.0032s" or "21ms") to a number of milliseconds.
 * Result is clamped to 10s even though the API returns up to 60s, because the
 * API returns the time until the entire quota is reset, even if a key may be
 * able to fulfill requests before then due to partial resets.
 **/
function getResetDurationMillis(resetDuration?: string): number {
  const match = resetDuration?.match(
    /(?:(\d+)m(?!s))?(?:(\d+(?:\.\d+)?)s)?(?:(\d+)ms)?/
  );

  const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/);
  if (match) {
    const [, minutes, seconds, milliseconds] = match.map(Number);

    const minutesToMillis = (minutes || 0) * 60 * 1000;
    const secondsToMillis = (seconds || 0) * 1000;
    const millisecondsValue = milliseconds || 0;

    return minutesToMillis + secondsToMillis + millisecondsValue;
    const [, time, , unit] = match;
    const value = parseFloat(time);
    const result = unit === "s" ? value * 1000 : value;
    return Math.min(result, 10000);
  }

  return 0;
}
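A quick sanity check of the simplified parser, as a standalone sketch. The function is module-private, so its body is copied here rather than imported:

```ts
// Standalone copy of the new getResetDurationMillis, for illustration only;
// the real function lives (unexported) in the OpenAI key provider.
function getResetDurationMillis(resetDuration?: string): number {
  const match = resetDuration?.match(/(\d+(\.\d+)?)(s|ms)/);
  if (match) {
    const [, time, , unit] = match;
    const value = parseFloat(time);
    const result = unit === "s" ? value * 1000 : value;
    return Math.min(result, 10000); // clamp to 10s
  }
  return 0;
}

console.log(getResetDurationMillis("21.0032s")); // 10000 (21003.2ms, clamped)
console.log(getResetDurationMillis("450ms"));    // 450
console.log(getResetDurationMillis("6s"));       // 6000
console.log(getResetDurationMillis(undefined));  // 0
```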
@@ -0,0 +1,49 @@
import crypto from "crypto";
import type { OpenAIKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";

const SERIALIZABLE_FIELDS: (keyof OpenAIKey)[] = [
  "key",
  "service",
  "hash",
  "organizationId",
  "promptCount",
  "gpt4Tokens",
  "gpt4-32kTokens",
  "turboTokens",
];
export type SerializedOpenAIKey = SerializedKey &
  Partial<Pick<OpenAIKey, (typeof SERIALIZABLE_FIELDS)[number]>>;

export class OpenAIKeySerializer extends KeySerializerBase<OpenAIKey> {
  constructor() {
    super(SERIALIZABLE_FIELDS);
  }

  deserialize({ key, ...rest }: SerializedOpenAIKey): OpenAIKey {
    return {
      key,
      service: "openai",
      modelFamilies: ["turbo" as const, "gpt4" as const],
      isTrial: false,
      isDisabled: false,
      isRevoked: false,
      isOverQuota: false,
      lastUsed: 0,
      lastChecked: 0,
      promptCount: 0,
      hash: `oai-${crypto
        .createHash("sha256")
        .update(key)
        .digest("hex")
        .slice(0, 8)}`,
      rateLimitedAt: 0,
      rateLimitRequestsReset: 0,
      rateLimitTokensReset: 0,
      turboTokens: 0,
      gpt4Tokens: 0,
      "gpt4-32kTokens": 0,
      ...rest,
    };
  }
}
@@ -1,21 +1,14 @@
import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import { KeyProviderBase } from "../key-provider-base";
import { Key } from "../types";

const RATE_LIMIT_LOCKOUT = 2000;
const KEY_REUSE_DELAY = 500;

// https://developers.generativeai.google.com/models/language
export type GooglePalmModel = "text-bison-001";

export type GooglePalmKeyUpdate = Omit<
  Partial<GooglePalmKey>,
  | "key"
  | "hash"
  | "lastUsed"
  | "promptCount"
  | "rateLimitedAt"
  | "rateLimitedUntil"
>;
export const GOOGLE_PALM_SUPPORTED_MODELS = ["text-bison-001"] as const;
export type GooglePalmModel = (typeof GOOGLE_PALM_SUPPORTED_MODELS)[number];

type GooglePalmKeyUsage = {
  [K in GooglePalmModelFamily as `${K}Tokens`]: number;
@@ -30,62 +23,25 @@ export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
  rateLimitedUntil: number;
}

/**
 * Upon being rate limited, a key will be locked out for this many milliseconds
 * while we wait for other concurrent requests to finish.
 */
const RATE_LIMIT_LOCKOUT = 2000;
/**
 * Upon assigning a key, we will wait this many milliseconds before allowing it
 * to be used again. This is to prevent the queue from flooding a key with too
 * many requests while we wait to learn whether previous ones succeeded.
 */
const KEY_REUSE_DELAY = 500;

export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
export class GooglePalmKeyProvider extends KeyProviderBase<GooglePalmKey> {
  readonly service = "google-palm";

  private keys: GooglePalmKey[] = [];
  private log = logger.child({ module: "key-provider", service: this.service });
  protected keys: GooglePalmKey[] = [];
  protected log = logger.child({ module: "key-provider", service: this.service });

  constructor() {
    const keyConfig = config.googlePalmKey?.trim();
    if (!keyConfig) {
      this.log.warn(
        "GOOGLE_PALM_KEY is not set. PaLM API will not be available."
      );
      return;
  public async init() {
    const storeName = this.store.constructor.name;
    const loadedKeys = await this.store.load();

    if (loadedKeys.length === 0) {
      return this.log.warn({ via: storeName }, "No Google PaLM keys found.");
    }
    let bareKeys: string[];
    bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
    for (const key of bareKeys) {
      const newKey: GooglePalmKey = {
        key,
        service: this.service,
        modelFamilies: ["bison"],
        isDisabled: false,
        isRevoked: false,
        promptCount: 0,
        lastUsed: 0,
        rateLimitedAt: 0,
        rateLimitedUntil: 0,
        hash: `plm-${crypto
          .createHash("sha256")
          .update(key)
          .digest("hex")
          .slice(0, 8)}`,
        lastChecked: 0,
        bisonTokens: 0,
      };
      this.keys.push(newKey);
    }
    this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
  }

  public init() {}

  public list() {
    return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
    this.keys.push(...loadedKeys);
    this.log.info(
      { count: this.keys.length, via: storeName },
      "Loaded PaLM keys."
    );
  }

  public get(_model: GooglePalmModel) {
@@ -118,26 +74,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {

    const selectedKey = keysByPriority[0];
    selectedKey.lastUsed = now;
    this.throttle(selectedKey.hash);
    selectedKey.rateLimitedAt = now;
    // Intended to throttle the queue processor as otherwise it will just
    // flood the API with requests and we want to wait a sec to see if we're
    // going to get a rate limit error on this key.
    selectedKey.rateLimitedUntil = now + KEY_REUSE_DELAY;
    return { ...selectedKey };
  }

  public disable(key: GooglePalmKey) {
    const keyFromPool = this.keys.find((k) => k.hash === key.hash);
    if (!keyFromPool || keyFromPool.isDisabled) return;
    keyFromPool.isDisabled = true;
    this.log.warn({ key: key.hash }, "Key disabled");
  }

  public update(hash: string, update: Partial<GooglePalmKey>) {
    const keyFromPool = this.keys.find((k) => k.hash === hash)!;
    Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
  }

  public available() {
    return this.keys.filter((k) => !k.isDisabled).length;
  }

  public incrementUsage(hash: string, _model: string, tokens: number) {
    const key = this.keys.find((k) => k.hash === hash);
    if (!key) return;
@@ -145,7 +89,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
    key.bisonTokens += tokens;
  }

  public getLockoutPeriod() {
  public getLockoutPeriod(_model: GooglePalmModel) {
    const activeKeys = this.keys.filter((k) => !k.isDisabled);
    // Don't lock out if there are no keys available or the queue will stall.
    // Just let it through so the add-key middleware can throw an error.
@@ -178,20 +122,4 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
  }

  public recheck() {}

  /**
   * Applies a short artificial delay to the key upon dequeueing, in order to
   * prevent it from being immediately assigned to another request before the
   * current one can be dispatched.
   **/
  private throttle(hash: string) {
    const now = Date.now();
    const key = this.keys.find((k) => k.hash === hash)!;

    const currentRateLimit = key.rateLimitedUntil;
    const nextRateLimit = now + KEY_REUSE_DELAY;

    key.rateLimitedAt = now;
    key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
  }
}
@@ -0,0 +1,42 @@
import crypto from "crypto";
import type { GooglePalmKey, SerializedKey } from "../index";
import { KeySerializerBase } from "../key-serializer-base";

const SERIALIZABLE_FIELDS: (keyof GooglePalmKey)[] = [
  "key",
  "service",
  "hash",
  "promptCount",
  "bisonTokens",
];
export type SerializedGooglePalmKey = SerializedKey &
  Partial<Pick<GooglePalmKey, (typeof SERIALIZABLE_FIELDS)[number]>>;

export class GooglePalmKeySerializer extends KeySerializerBase<GooglePalmKey> {
  constructor() {
    super(SERIALIZABLE_FIELDS);
  }

  deserialize(serializedKey: SerializedGooglePalmKey): GooglePalmKey {
    const { key, ...rest } = serializedKey;
    return {
      key,
      service: "google-palm" as const,
      modelFamilies: ["bison"],
      isDisabled: false,
      isRevoked: false,
      promptCount: 0,
      lastUsed: 0,
      rateLimitedAt: 0,
      rateLimitedUntil: 0,
      hash: `plm-${crypto
        .createHash("sha256")
        .update(key)
        .digest("hex")
        .slice(0, 8)}`,
      lastChecked: 0,
      bisonTokens: 0,
      ...rest,
    };
  }
}
@@ -0,0 +1,36 @@
import { assertNever } from "../utils";
import {
  Key,
  KeySerializer,
  LLMService,
  SerializedKey,
  ServiceToKey,
} from "./types";
import { OpenAIKeySerializer } from "./openai/serializer";
import { AnthropicKeySerializer } from "./anthropic/serializer";
import { GooglePalmKeySerializer } from "./palm/serializer";
import { AwsBedrockKeySerializer } from "./aws/serializer";

export function assertSerializedKey(k: any): asserts k is SerializedKey {
  if (typeof k !== "object" || !k || typeof (k as any).key !== "string") {
    throw new Error("Invalid serialized key data");
  }
}

export function getSerializer<S extends LLMService>(
  service: S
): KeySerializer<ServiceToKey[S]>;
export function getSerializer(service: LLMService): KeySerializer<Key> {
  switch (service) {
    case "openai":
      return new OpenAIKeySerializer();
    case "anthropic":
      return new AnthropicKeySerializer();
    case "google-palm":
      return new GooglePalmKeySerializer();
    case "aws":
      return new AwsBedrockKeySerializer();
    default:
      assertNever(service);
  }
}
@@ -0,0 +1,125 @@
import firebase from "firebase-admin";
import { config, getFirebaseApp } from "../../../config";
import { logger } from "../../../logger";
import { assertSerializedKey } from "../serializers";
import type {
  Key,
  KeySerializer,
  KeyStore,
  LLMService,
  SerializedKey,
} from "../types";
import { MemoryKeyStore } from "./index";

export class FirebaseKeyStore<K extends Key> implements KeyStore<K> {
  private readonly db: firebase.database.Database;
  private readonly log: typeof logger;
  private readonly pendingUpdates: Map<string, Partial<SerializedKey>>;
  private readonly root: string;
  private readonly serializer: KeySerializer<K>;
  private readonly service: LLMService;
  private flushInterval: NodeJS.Timeout | null = null;
  private keysRef: firebase.database.Reference | null = null;

  constructor(
    service: LLMService,
    serializer: KeySerializer<K>,
    app = getFirebaseApp()
  ) {
    this.db = firebase.database(app);
    this.log = logger.child({ module: "firebase-key-store", service });
    this.root = `keys/${config.firebaseRtdbRoot.toLowerCase()}/${service}`;
    this.serializer = serializer;
    this.service = service;
    this.pendingUpdates = new Map();
    this.scheduleFlush();
  }

  public async load(isMigrating = false): Promise<K[]> {
    const keysRef = this.db.ref(this.root);
    const snapshot = await keysRef.once("value");
    const keys = snapshot.val();
    this.keysRef = keysRef;

    if (!keys) {
      if (isMigrating) return [];
      this.log.warn("No keys found in Firebase. Migrating from environment.");
      await this.migrate();
      return this.load(true);
    }

    return Object.values(keys).map((k) => {
      assertSerializedKey(k);
      return this.serializer.deserialize(k);
    });
  }

  public add(key: K) {
    const serialized = this.serializer.serialize(key);
    this.pendingUpdates.set(key.hash, serialized);
    this.forceFlush();
  }

  public update(id: string, update: Partial<K>, force = false) {
    const existing = this.pendingUpdates.get(id) ?? {};
    Object.assign(existing, this.serializer.partialSerialize(id, update));
    this.pendingUpdates.set(id, existing);
    if (force) this.forceFlush();
  }

  private forceFlush() {
    if (this.flushInterval) clearInterval(this.flushInterval);
    this.flushInterval = setTimeout(() => this.flush(), 0);
  }

  private scheduleFlush() {
    if (this.flushInterval) clearInterval(this.flushInterval);
    this.flushInterval = setInterval(() => this.flush(), 1000 * 60 * 5);
  }

  private async flush() {
    if (!this.keysRef) {
      this.log.warn(
        { pendingUpdates: this.pendingUpdates.size },
        "Database not loaded yet. Skipping flush."
      );
      return this.scheduleFlush();
    }

    if (this.pendingUpdates.size === 0) {
      this.log.debug("No pending key updates to flush.");
      return this.scheduleFlush();
    }

    const updates: Record<string, Partial<SerializedKey>> = {};
    this.pendingUpdates.forEach((v, k) => (updates[k] = v));
    this.pendingUpdates.clear();

    await this.keysRef.update(updates);

    this.log.debug(
      { count: Object.keys(updates).length },
      "Flushed pending key updates."
    );
    this.scheduleFlush();
  }

  private async migrate(): Promise<SerializedKey[]> {
    const keysRef = this.db.ref(this.root);
    const envStore = new MemoryKeyStore<K>(this.service, this.serializer);
    const keys = await envStore.load();

    if (keys.length === 0) {
      this.log.warn("No keys found in environment or Firebase.");
      return [];
    }

    const updates: Record<string, SerializedKey> = {};
    keys.forEach((k) => (updates[k.hash] = this.serializer.serialize(k)));
    await keysRef.update(updates);

    this.log.info({ count: keys.length }, "Migrated keys from environment.");
    return Object.values(updates);
  }
}
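The store batches writes: `update()` only stages a partial record in `pendingUpdates`, and a timer (or `forceFlush()`, triggered by `add()` or `force = true`) pushes the whole batch to RTDB in a single `update()` call. A minimal usage sketch, assuming Firebase credentials are already configured via `getFirebaseApp()`; the wiring shown here is hypothetical, since the real call sites live in the key providers:

```ts
import { FirebaseKeyStore } from "./stores";
import { getSerializer } from "./serializers";

async function demo() {
  const store = new FirebaseKeyStore("openai", getSerializer("openai"));
  const keys = await store.load(); // falls back to migrating from OPENAI_KEY on first run

  // Staged only; the 5-minute interval flushes it to Firebase later.
  store.update(keys[0].hash, { promptCount: 42 });

  // force=true (like add()) flushes immediately via a zero-delay timeout.
  store.update(keys[0].hash, { isDisabled: true }, true);
}

demo().catch(console.error);
```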
@@ -0,0 +1,2 @@
export { FirebaseKeyStore } from "./firebase";
export { MemoryKeyStore } from "./memory";
@@ -0,0 +1,41 @@
import { assertNever } from "../../utils";
import { Key, KeySerializer, KeyStore, LLMService } from "../types";

export class MemoryKeyStore<K extends Key> implements KeyStore<K> {
  private readonly env: string;
  private readonly serializer: KeySerializer<K>;

  constructor(service: LLMService, serializer: KeySerializer<K>) {
    switch (service) {
      case "anthropic":
        this.env = "ANTHROPIC_KEY";
        break;
      case "openai":
        this.env = "OPENAI_KEY";
        break;
      case "google-palm":
        this.env = "GOOGLE_PALM_KEY";
        break;
      case "aws":
        this.env = "AWS_CREDENTIALS";
        break;
      default:
        assertNever(service);
    }
    this.serializer = serializer;
  }

  public async load() {
    let envKeys: string[];
    envKeys = [
      ...new Set(process.env[this.env]?.split(",").map((k) => k.trim())),
    ];
    return envKeys
      .filter((k) => k)
      .map((k) => this.serializer.deserialize({ key: k }));
  }

  public add() {}

  public update() {}
}
@@ -0,0 +1,64 @@
import type { OpenAIKey, OpenAIModel } from "./openai/provider";
import type { AnthropicKey, AnthropicModel } from "./anthropic/provider";
import type { GooglePalmKey, GooglePalmModel } from "./palm/provider";
import type { AwsBedrockKey, AwsBedrockModel } from "./aws/provider";
import type { ModelFamily } from "../models";

/** The request and response format used by a model's API. */
export type APIFormat = "openai" | "anthropic" | "google-palm" | "openai-text";
/**
 * The service that a model is hosted on; distinct because services like AWS
 * provide APIs from other service providers, but have their own authentication
 * and key management.
 */
export type LLMService = "openai" | "anthropic" | "google-palm" | "aws";

export type Model =
  | OpenAIModel
  | AnthropicModel
  | GooglePalmModel
  | AwsBedrockModel;

type AllKeys = OpenAIKey | AnthropicKey | GooglePalmKey | AwsBedrockKey;
export type ServiceToKey = {
  [K in AllKeys["service"]]: Extract<AllKeys, { service: K }>;
};
export type SerializedKey = { key: string };

export interface Key {
  /** The API key itself. Never log this, use `hash` instead. */
  readonly key: string;
  /** The service that this key is for. */
  service: LLMService;
  /** The model families that this key has access to. */
  modelFamilies: ModelFamily[];
  /** Whether this key is currently disabled for some reason. */
  isDisabled: boolean;
  /**
   * Whether this key specifically has been revoked. This is different from
   * `isDisabled` because a key can be disabled for other reasons, such as
   * exceeding its quota. A revoked key is assumed to be permanently disabled,
   * and KeyStore implementations should not return it when loading keys.
   */
  isRevoked: boolean;
  /** The number of prompts that have been sent with this key. */
  promptCount: number;
  /** The time at which this key was last used. */
  lastUsed: number;
  /** The time at which this key was last checked. */
  lastChecked: number;
  /** Hash of the key, for logging and to find the key in the pool. */
  hash: string;
}

export interface KeySerializer<K> {
  serialize(keyObj: K): SerializedKey;
  deserialize(serializedKey: SerializedKey): K;
  partialSerialize(key: string, update: Partial<K>): Partial<SerializedKey>;
}

export interface KeyStore<K extends Key> {
  load(): Promise<K[]>;
  add(key: K): void;
  update(id: string, update: Partial<K>, force?: boolean): void;
}
+10
-89
@@ -1,28 +1,14 @@
// Don't import anything here, this is imported by config.ts
import { logger } from "../logger";

import pino from "pino";
import type { Request } from "express";
import { assertNever } from "./utils";

export type OpenAIModelFamily =
  | "turbo"
  | "gpt4"
  | "gpt4-32k"
  | "gpt4-turbo"
  | "dall-e";
export type OpenAIModelFamily = "turbo" | "gpt4" | "gpt4-32k";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
  OpenAIModelFamily,
  "dall-e"
>}`;
export type ModelFamily =
  | OpenAIModelFamily
  | AnthropicModelFamily
  | GooglePalmModelFamily
  | AwsBedrockModelFamily
  | AzureOpenAIModelFamily;
  | AwsBedrockModelFamily;

export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
  arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -30,49 +16,37 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
  "turbo",
  "gpt4",
  "gpt4-32k",
  "gpt4-turbo",
  "dall-e",
  "claude",
  "bison",
  "aws-claude",
  "azure-turbo",
  "azure-gpt4",
  "azure-gpt4-32k",
  "azure-gpt4-turbo",
] as const);

export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
  "^gpt-4-1106(-preview)?$": "gpt4-turbo",
  "^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
  "^gpt-4-32k-\\d{4}$": "gpt4-32k",
  "^gpt-4-32k$": "gpt4-32k",
  "^gpt-4-\\d{4}$": "gpt4",
  "^gpt-4$": "gpt4",
  "^gpt-3.5-turbo": "turbo",
  "^text-embedding-ada-002$": "turbo",
  "^dall-e-\\d{1}$": "dall-e",
};
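The map is keyed by regex and scanned in insertion order, so the more specific patterns (dated 32k snapshots, vision previews) must precede the generic ones. A small illustration of the intended lookups; `familyFor` is a hypothetical name that mirrors `getOpenAIModelFamily`'s first-match-wins scan:

```ts
import { OPENAI_MODEL_FAMILY_MAP } from "./models";

// Mirrors getOpenAIModelFamily's scan over Object.entries(); sketch only.
function familyFor(model: string): string {
  for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
    if (model.match(regex)) return family;
  }
  return "gpt4"; // the new fallback, logged as "Unmapped model family"
}

console.log(familyFor("gpt-4-1106-preview")); // gpt4-turbo
console.log(familyFor("gpt-4-32k-0613"));     // gpt4-32k (dated pattern precedes the generic one)
console.log(familyFor("gpt-3.5-turbo-16k"));  // turbo ("^gpt-3.5-turbo" is an unanchored prefix)
console.log(familyFor("my-custom-model"));    // gpt4 (fallback)
```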
const modelLogger = pino({ level: "debug" }).child({ module: "startup" });

export function getOpenAIModelFamily(
  model: string,
  defaultFamily: OpenAIModelFamily = "gpt4"
): OpenAIModelFamily {
export function getOpenAIModelFamily(model: string): OpenAIModelFamily {
  for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
    if (model.match(regex)) return family;
  }
  return defaultFamily;
  const stack = new Error().stack;
  logger.warn({ model, stack }, "Unmapped model family");
  return "gpt4";
}

export function getClaudeModelFamily(model: string): ModelFamily {
  if (model.startsWith("anthropic.")) return getAwsBedrockModelFamily(model);
export function getClaudeModelFamily(_model: string): ModelFamily {
  return "claude";
}

export function getGooglePalmModelFamily(model: string): ModelFamily {
  if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
  modelLogger.warn({ model }, "Could not determine Google PaLM model family");
  const stack = new Error().stack;
  logger.warn({ model, stack }, "Unmapped PaLM model family");
  return "bison";
}
@@ -80,24 +54,6 @@ export function getAwsBedrockModelFamily(_model: string): ModelFamily {
  return "aws-claude";
}

export function getAzureOpenAIModelFamily(
  model: string,
  defaultFamily: AzureOpenAIModelFamily = "azure-gpt4"
): AzureOpenAIModelFamily {
  // Azure model names omit periods. addAzureKey also prepends "azure-" to the
  // model name to route the request to the correct key provider, so we need to
  // remove that as well.
  const modified = model
    .replace("gpt-35-turbo", "gpt-3.5-turbo")
    .replace("azure-", "");
  for (const [regex, family] of Object.entries(OPENAI_MODEL_FAMILY_MAP)) {
    if (modified.match(regex)) {
      return `azure-${family}` as AzureOpenAIModelFamily;
    }
  }
  return defaultFamily;
}

export function assertIsKnownModelFamily(
  modelFamily: string
): asserts modelFamily is ModelFamily {
@@ -105,38 +61,3 @@ export function assertIsKnownModelFamily(
    throw new Error(`Unknown model family: ${modelFamily}`);
  }
}

export function getModelFamilyForRequest(req: Request): ModelFamily {
  if (req.modelFamily) return req.modelFamily;
  // There is a single request queue, but it is partitioned by model family.
  // Model families are typically separated on cost/rate limit boundaries so
  // they should be treated as separate queues.
  const model = req.body.model ?? "gpt-3.5-turbo";
  let modelFamily: ModelFamily;

  // Weird special case for AWS/Azure because they serve multiple models from
  // different vendors, even if currently only one is supported.
  if (req.service === "aws") {
    modelFamily = getAwsBedrockModelFamily(model);
  } else if (req.service === "azure") {
    modelFamily = getAzureOpenAIModelFamily(model);
  } else {
    switch (req.outboundApi) {
      case "anthropic":
        modelFamily = getClaudeModelFamily(model);
        break;
      case "openai":
      case "openai-text":
      case "openai-image":
        modelFamily = getOpenAIModelFamily(model);
        break;
      case "google-palm":
        modelFamily = getGooglePalmModelFamily(model);
        break;
      default:
        assertNever(req.outboundApi);
    }
  }

  return (req.modelFamily = modelFamily);
}
@@ -256,8 +256,8 @@ export const appendBatch = async (batch: PromptLogEntry[]) => {
    return [
      entry.model,
      entry.endpoint,
      entry.promptRaw.slice(-50000),
      entry.promptFlattened.slice(-50000),
      entry.promptRaw.slice(0, 50000),
      entry.promptFlattened.slice(0, 50000),
      entry.response.slice(0, 50000),
    ];
  });
@@ -396,7 +396,7 @@ export const init = async (onStop: () => void) => {
    await loadIndexSheet(false);
    await writeIndexSheet();
  } catch (e) {
    log.warn({ error: e.message }, "Could not load index sheet. Creating a new one.");
    log.info("Creating new index sheet.");
    await createIndexSheet();
  }
};

@@ -69,7 +69,7 @@ export const start = async () => {
    log.info("Logging backend initialized.");
    started = true;
  } catch (e) {
    log.error({ error: e.message }, "Could not initialize logging backend.");
    log.error(e, "Could not initialize logging backend.");
    return;
  }
  scheduleFlush();
+2
-8
@@ -5,9 +5,6 @@ import { ModelFamily } from "./models";
export function getTokenCostUsd(model: ModelFamily, tokens: number) {
  let cost = 0;
  switch (model) {
    case "gpt4-turbo":
      cost = 0.00001;
      break;
    case "gpt4-32k":
      cost = 0.00006;
      break;
@@ -15,10 +12,7 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
      cost = 0.00003;
      break;
    case "turbo":
      cost = 0.000001;
      break;
    case "dall-e":
      cost = 0.00001;
      cost = 0.0000015;
      break;
    case "aws-claude":
    case "claude":
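The per-token rates make the cost display easy to spot-check. A sketch of the arithmetic, assuming the function's tail (outside this hunk) multiplies `cost * tokens`:

```ts
// Spot checks under that assumption:
// turbo at $0.000001/token -> 1,000,000 tokens = $1.00
// gpt4-32k at $0.00006/token -> 100,000 tokens = $6.00
const turboUsd = 0.000001 * 1_000_000; // 1
const gpt432kUsd = 0.00006 * 100_000;  // 6
console.log({ turboUsd, gpt432kUsd });
```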
@@ -37,6 +31,6 @@ export function prettyTokens(tokens: number): string {
  } else if (absTokens < 1000000000) {
    return (tokens / 1000000).toFixed(2) + "m";
  } else {
    return (tokens / 1000000000).toFixed(3) + "b";
    return (tokens / 1000000000).toFixed(2) + "b";
  }
}

+15
-12
@@ -39,7 +39,11 @@ export function copySseResponseHeaders(
 * that the request is being proxied to. Used to send error messages to the
 * client in the middle of a streaming request.
 */
export function buildFakeSse(type: string, string: string, req: Request) {
export function buildFakeSse(
  type: string,
  string: string,
  req: Request
) {
  let fakeEvent;
  const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;

@@ -50,7 +54,7 @@ export function buildFakeSse(type: string, string: string, req: Request) {
        object: "chat.completion.chunk",
        created: Date.now(),
        model: req.body?.model,
        choices: [{ delta: { content }, index: 0, finish_reason: type }],
        choices: [{ delta: { content }, index: 0, finish_reason: type }]
      };
      break;
    case "openai-text":
@@ -59,9 +63,9 @@ export function buildFakeSse(type: string, string: string, req: Request) {
        object: "text_completion",
        created: Date.now(),
        choices: [
          { text: content, index: 0, logprobs: null, finish_reason: type },
          { text: content, index: 0, logprobs: null, finish_reason: type }
        ],
        model: req.body?.model,
        model: req.body?.model
      };
      break;
    case "anthropic":
@@ -71,22 +75,21 @@ export function buildFakeSse(type: string, string: string, req: Request) {
        truncated: false, // I've never seen this be true
        stop: null,
        model: req.body?.model,
        log_id: "proxy-req-" + req.id,
        log_id: "proxy-req-" + req.id
      };
      break;
    case "google-palm":
    case "openai-image":
      throw new Error(`SSE not supported for ${req.inboundApi} requests`);
      throw new Error("PaLM not supported as an inbound API format");
    default:
      assertNever(req.inboundApi);
  }

  if (req.inboundApi === "anthropic") {
    return (
      ["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") +
      "\n\n"
    );
    return [
      "event: completion",
      `data: ${JSON.stringify(fakeEvent)}`,
    ].join("\n") + "\n\n";
  }

  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}

@@ -1 +1,2 @@
export { OpenAIPromptMessage } from "./openai";
export { init, countTokens } from "./tokenizer";
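For reference, a sketch of the wire format `buildFakeSse` produces, assuming an OpenAI-style inbound request; any event fields defined above the hunk (e.g. an id) are omitted here, and the values are illustrative:

```ts
const content = "```\n[proxy_error: upstream timeout]\n```\n";
const fakeEvent = {
  object: "chat.completion.chunk",
  created: Date.now(),
  model: "gpt-4",
  choices: [{ delta: { content }, index: 0, finish_reason: "proxy_error" }],
};

// Non-Anthropic clients receive a bare SSE data line:
console.log(`data: ${JSON.stringify(fakeEvent)}\n\n`);

// Anthropic clients also get the event name line their SSE framing requires:
console.log(
  ["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") + "\n\n"
);
```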
@@ -1,11 +1,5 @@
import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";

const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;

let encoder: Tiktoken;
@@ -21,8 +15,8 @@ export function init() {
// Tested against:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

export async function getTokenCount(
  prompt: string | OpenAIChatMessage[],
export function getTokenCount(
  prompt: string | OpenAIPromptMessage[],
  model: string
) {
  if (typeof prompt === "string") {
@@ -30,49 +24,31 @@
  }

  const gpt4 = model.startsWith("gpt-4");
  const vision = model.includes("vision");

  const tokensPerMessage = gpt4 ? 3 : 4;
  const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present

  let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
  let numTokens = 0;

  for (const message of prompt) {
    numTokens += tokensPerMessage;
    for (const key of Object.keys(message)) {
      {
      let textContent: string = "";
      const value = message[key as keyof OpenAIChatMessage];

      if (!value) continue;

      if (Array.isArray(value)) {
        for (const item of value) {
          if (item.type === "text") {
            textContent += item.text;
          } else if (item.type === "image_url") {
            const { url, detail } = item.image_url;
            const cost = await getGpt4VisionTokenCost(url, detail);
            numTokens += cost ?? 0;
          }
        }
      } else {
        textContent = value;
      }

      const value = message[key as keyof OpenAIPromptMessage];
      if (!value || typeof value !== "string") continue;
      // Break if we get a huge message or exceed the token limit to prevent
      // DoS.
      // 200k tokens allows for future 200k GPT-4 models and 500k characters
      // 100k tokens allows for future 100k GPT-4 models and 500k characters
      // is just a sanity check
      if (textContent.length > 500000 || numTokens > 200000) {
        numTokens = 200000;
      if (value.length > 500000 || numTokens > 100000) {
        numTokens = 100000;
        return {
          tokenizer: "tiktoken (prompt length limit exceeded)",
          token_count: numTokens,
        };
      }

      numTokens += encoder.encode(textContent).length;
      numTokens += encoder.encode(value).length;
      if (key === "name") {
        numTokens += tokensPerName;
      }
@@ -83,78 +59,6 @@
  return { tokenizer: "tiktoken", token_count: numTokens };
}

async function getGpt4VisionTokenCost(
  url: string,
  detail: "auto" | "low" | "high" = "auto"
) {
  // For now we do not allow remote images as the proxy would have to download
  // them, which is a potential DoS vector.
  if (!url.startsWith("data:image/")) {
    throw new Error(
      "Remote images are not supported. Add the image to your prompt as a base64 data URL."
    );
  }

  const base64Data = url.split(",")[1];
  const buffer = Buffer.from(base64Data, "base64");
  const image = libSharp(buffer);
  const metadata = await image.metadata();

  if (!metadata || !metadata.width || !metadata.height) {
    throw new Error("Prompt includes an image that could not be parsed");
  }

  const { width, height } = metadata;

  let selectedDetail: "low" | "high";
  if (detail === "auto") {
    const threshold = 512 * 512;
    const imageSize = width * height;
    selectedDetail = imageSize > threshold ? "high" : "low";
  } else {
    selectedDetail = detail;
  }

  // https://platform.openai.com/docs/guides/vision/calculating-costs
  if (selectedDetail === "low") {
    log.info(
      { width, height, tokens: 85 },
      "Using fixed GPT-4-Vision token cost for low detail image"
    );
    return 85;
  }

  let newWidth = width;
  let newHeight = height;
  if (width > 2048 || height > 2048) {
    const aspectRatio = width / height;
    if (width > height) {
      newWidth = 2048;
      newHeight = Math.round(2048 / aspectRatio);
    } else {
      newHeight = 2048;
      newWidth = Math.round(2048 * aspectRatio);
    }
  }

  if (newWidth < newHeight) {
    newHeight = Math.round((newHeight / newWidth) * 768);
    newWidth = 768;
  } else {
    newWidth = Math.round((newWidth / newHeight) * 768);
    newHeight = 768;
  }

  const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512);
  const tokens = 170 * tiles + 85;

  log.info(
    { width, height, newWidth, newHeight, tiles, tokens },
    "Calculated GPT-4-Vision token cost for high detail image"
  );
  return tokens;
}
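The removed vision-cost helper followed OpenAI's published tile formula: fit the image into a 2048x2048 box, scale the short side to 768, then charge 170 tokens per 512px tile plus a fixed 85. A worked example of that arithmetic, standalone since the helper is deleted in this commit:

```ts
// 1920x1080 at detail "high": already fits within 2048x2048, so only the
// short-side rescale applies: 1080 -> 768, 1920 -> round(1920 / 1080 * 768) = 1365.
const newWidth = 1365;
const newHeight = 768;
const tiles = Math.ceil(newWidth / 512) * Math.ceil(newHeight / 512); // 3 * 2 = 6
const tokens = 170 * tiles + 85; // 1105
console.log({ tiles, tokens });
```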
function getTextTokenCount(prompt: string) {
  if (prompt.length > 500000) {
    return {
@@ -169,62 +73,8 @@ function getTextTokenCount(prompt: string) {
    };
  }

// Model        Resolution              Price
// DALL·E 3     1024×1024               $0.040 / image
//              1024×1792, 1792×1024    $0.080 / image
// DALL·E 3 HD  1024×1024               $0.080 / image
//              1024×1792, 1792×1024    $0.120 / image
// DALL·E 2     1024×1024               $0.020 / image
//              512×512                 $0.018 / image
//              256×256                 $0.016 / image

export const DALLE_TOKENS_PER_DOLLAR = 100000;

/**
 * OpenAI image generation with DALL-E doesn't use tokens but everything else
 * in the application does. There is a fixed cost for each image generation
 * request depending on the model and selected quality/resolution parameters,
 * which we convert to tokens at a rate of 100000 tokens per dollar.
 */
export function getOpenAIImageCost(params: {
  model: "dall-e-2" | "dall-e-3";
  quality: "standard" | "hd";
  resolution: "512x512" | "256x256" | "1024x1024" | "1024x1792" | "1792x1024";
  n: number | null;
}) {
  const { model, quality, resolution, n } = params;
  const usd = (() => {
    switch (model) {
      case "dall-e-2":
        switch (resolution) {
          case "512x512":
            return 0.018;
          case "256x256":
            return 0.016;
          case "1024x1024":
            return 0.02;
          default:
            throw new Error("Invalid resolution");
        }
      case "dall-e-3":
        switch (resolution) {
          case "1024x1024":
            return quality === "standard" ? 0.04 : 0.08;
          case "1024x1792":
          case "1792x1024":
            return quality === "standard" ? 0.08 : 0.12;
          default:
            throw new Error("Invalid resolution");
        }
      default:
        throw new Error("Invalid image generation model");
    }
  })();

  const tokens = (n ?? 1) * (usd * DALLE_TOKENS_PER_DOLLAR);

  return {
    tokenizer: `openai-image cost`,
    token_count: Math.ceil(tokens),
  };
}
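At 100,000 tokens per dollar the conversion is straightforward; a usage illustration of the exported function:

```ts
import { getOpenAIImageCost } from "./openai";

// Two HD DALL·E 3 images at 1024x1792 cost $0.12 each:
// tokens = 2 * (0.12 * 100,000) = 24,000
const result = getOpenAIImageCost({
  model: "dall-e-3",
  quality: "hd",
  resolution: "1024x1792",
  n: 2,
});
console.log(result.token_count); // 24000
```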
export type OpenAIPromptMessage = {
  name?: string;
  content: string;
  role: string;
};
@@ -1,5 +1,4 @@
import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils";
import {
  init as initClaude,
@@ -8,7 +7,7 @@ import {
import {
  init as initOpenAi,
  getTokenCount as getOpenAITokenCount,
  getOpenAIImageCost,
  OpenAIPromptMessage,
} from "./openai";
import { APIFormat } from "../key-management";
@@ -20,14 +19,13 @@ export async function init() {
/** Tagged union via `service` field of the different types of requests that can
 * be made to the tokenization service, for both prompts and completions */
type TokenCountRequest = { req: Request } & (
  | { prompt: OpenAIChatMessage[]; completion?: never; service: "openai" }
  | { prompt: OpenAIPromptMessage[]; completion?: never; service: "openai" }
  | {
      prompt: string;
      completion?: never;
      service: "openai-text" | "anthropic" | "google-palm";
    }
  | { prompt?: never; completion: string; service: APIFormat }
  | { prompt?: never; completion?: never; service: "openai-image" }
);

type TokenCountResult = {
@@ -52,24 +50,14 @@ export async function countTokens({
    case "openai":
    case "openai-text":
      return {
        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "openai-image":
      return {
        ...getOpenAIImageCost({
          model: req.body.model,
          quality: req.body.quality,
          resolution: req.body.size,
          n: parseInt(req.body.n, 10) || null,
        }),
        ...getOpenAITokenCount(prompt ?? completion, req.body.model),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "google-palm":
      // TODO: Can't find a tokenization library for PaLM. There is an API
      // endpoint for it but it adds significant latency to the request.
      return {
        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
        ...getOpenAITokenCount(prompt ?? completion, req.body.model),
        tokenization_duration_ms: getElapsedMs(time),
      };
    default:
@@ -6,8 +6,6 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
  turbo: z.number().optional().default(0),
  gpt4: z.number().optional().default(0),
  "gpt4-32k": z.number().optional().default(0),
  "gpt4-turbo": z.number().optional().default(0),
  "dall-e": z.number().optional().default(0),
  claude: z.number().optional().default(0),
  bison: z.number().optional().default(0),
  "aws-claude": z.number().optional().default(0),
@@ -11,18 +11,9 @@ import admin from "firebase-admin";
import schedule from "node-schedule";
import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config";
import {
  getAzureOpenAIModelFamily,
  getClaudeModelFamily,
  getGooglePalmModelFamily,
  getOpenAIModelFamily,
  MODEL_FAMILIES,
  ModelFamily,
} from "../models";
import { MODEL_FAMILIES, ModelFamily } from "../models";
import { logger } from "../../logger";
import { User, UserTokenCounts, UserUpdate } from "./schema";
import { APIFormat } from "../key-management";
import { assertNever } from "../utils";

const log = logger.child({ module: "users" });
@@ -30,15 +21,9 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
  turbo: 0,
  gpt4: 0,
  "gpt4-32k": 0,
  "gpt4-turbo": 0,
  "dall-e": 0,
  claude: 0,
  bison: 0,
  "aws-claude": 0,
  "azure-turbo": 0,
  "azure-gpt4": 0,
  "azure-gpt4-turbo": 0,
  "azure-gpt4-32k": 0,
};

const users: Map<string, User> = new Map();
@@ -47,8 +32,8 @@ let quotaRefreshJob: schedule.Job | null = null;
let userCleanupJob: schedule.Job | null = null;

export async function init() {
  log.info({ store: config.gatekeeperStore }, "Initializing user store...");
  if (config.gatekeeperStore === "firebase_rtdb") {
  log.info({ store: config.persistenceProvider }, "Initializing user store...");
  if (config.persistenceProvider === "firebase_rtdb") {
    await initFirebase();
  }
  if (config.quotaRefreshPeriod) {
@@ -161,7 +146,7 @@ export function upsertUser(user: UserUpdate) {
  usersToFlush.add(user.token);

  // Immediately schedule a flush to the database if we're using Firebase.
  if (config.gatekeeperStore === "firebase_rtdb") {
  if (config.persistenceProvider === "firebase_rtdb") {
    setImmediate(flushUsers);
  }
@@ -180,12 +165,11 @@ export function incrementPromptCount(token: string) {
export function incrementTokenCount(
  token: string,
  model: string,
  api: APIFormat,
  consumption: number
) {
  const user = users.get(token);
  if (!user) return;
  const modelFamily = getModelFamilyForQuotaUsage(model, api);
  const modelFamily = getModelFamilyForQuotaUsage(model);
  const existing = user.tokenCounts[modelFamily] ?? 0;
  user.tokenCounts[modelFamily] = existing + consumption;
  usersToFlush.add(token);
@@ -196,52 +180,34 @@
 * to the user's list of IPs. Returns the user if they exist and are not
 * disabled, otherwise returns undefined.
 */
export function authenticate(
  token: string,
  ip: string
): { user?: User; result: "success" | "disabled" | "not_found" | "limited" } {
export function authenticate(token: string, ip: string) {
  const user = users.get(token);
  if (!user) return { result: "not_found" };
  if (user.disabledAt) return { result: "disabled" };

  const newIp = !user.ip.includes(ip);

  const userLimit = user.maxIps ?? config.maxIpsPerUser;
  const enforcedLimit =
    user.type === "special" || !userLimit ? Infinity : userLimit;

  if (newIp && user.ip.length >= enforcedLimit) {
    if (config.maxIpsAutoBan) {
      user.ip.push(ip);
      disableUser(token, "IP address limit exceeded.");
      return { result: "disabled" };
    }
    return { result: "limited" };
  } else if (newIp) {
    user.ip.push(ip);
  if (!user || user.disabledAt) return;
  if (!user.ip.includes(ip)) user.ip.push(ip);

  const configIpLimit = user.maxIps ?? config.maxIpsPerUser;
  const ipLimit =
    user.type === "special" || !configIpLimit ? Infinity : configIpLimit;
  if (user.ip.length > ipLimit) {
    disableUser(token, "IP address limit exceeded.");
    return;
  }

  user.lastUsedAt = Date.now();
  usersToFlush.add(token);
  return { user, result: "success" };
  return user;
}

export function hasAvailableQuota({
  userToken,
  model,
  api,
  requested,
}: {
  userToken: string;
  model: string;
  api: APIFormat;
  requested: number;
}) {
  const user = users.get(userToken);
export function hasAvailableQuota(
  token: string,
  model: string,
  requested: number
) {
  const user = users.get(token);
  if (!user) return false;
  if (user.type === "special") return true;

  const modelFamily = getModelFamilyForQuotaUsage(model, api);
  const modelFamily = getModelFamilyForQuotaUsage(model);
  const { tokenCounts, tokenLimits } = user;
  const tokenLimit = tokenLimits[modelFamily];
@@ -383,25 +349,27 @@ async function flushUsers() {
  );
}

function getModelFamilyForQuotaUsage(
  model: string,
  api: APIFormat
): ModelFamily {
  // TODO: this seems incorrect
  if (model.includes("azure")) return getAzureOpenAIModelFamily(model);

  switch (api) {
    case "openai":
    case "openai-text":
    case "openai-image":
      return getOpenAIModelFamily(model);
    case "anthropic":
      return getClaudeModelFamily(model);
    case "google-palm":
      return getGooglePalmModelFamily(model);
    default:
      assertNever(api);
// TODO: use key-management/models.ts for family mapping
function getModelFamilyForQuotaUsage(model: string): ModelFamily {
  if (model.includes("32k")) {
    return "gpt4-32k";
  }
  if (model.startsWith("gpt-4")) {
    return "gpt4";
  }
  if (model.startsWith("gpt-3.5")) {
    return "turbo";
  }
  if (model.includes("bison")) {
    return "bison";
  }
  if (model.startsWith("claude")) {
    return "claude";
  }
  if (model.startsWith("anthropic.claude")) {
    return "aws-claude";
  }
  throw new Error(`Unknown quota model family for model ${model}`);
}
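Note the checks are order-sensitive: the `includes("32k")` test must run before the `gpt-4` prefix test, or 32k models would be quota-billed as plain gpt4. A tiny runnable demonstration of why:

```ts
// "gpt-4-32k-0613" satisfies both predicates, so the 32k branch must come first.
const model = "gpt-4-32k-0613";
console.log(model.includes("32k"));     // true -> gpt4-32k (checked first)
console.log(model.startsWith("gpt-4")); // true -> would otherwise match gpt4
```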
function getRefreshCrontab() {
@@ -1,7 +1,7 @@
import cookieParser from "cookie-parser";
import expressSession from "express-session";
import MemoryStore from "memorystore";
import { config, COOKIE_SECRET } from "../config";
import { COOKIE_SECRET } from "../config";

const ONE_WEEK = 1000 * 60 * 60 * 24 * 7;
@@ -12,12 +12,7 @@ const sessionMiddleware = expressSession({
  resave: false,
  saveUninitialized: false,
  store: new (MemoryStore(expressSession))({ checkPeriod: ONE_WEEK }),
  cookie: {
    sameSite: "strict",
    maxAge: ONE_WEEK,
    signed: true,
    secure: !config.useInsecureCookies,
  },
  cookie: { sameSite: "strict", maxAge: ONE_WEEK, signed: true },
});

const withSession = [cookieParserMiddleware, sessionMiddleware];