55 Commits

Author SHA1 Message Date
nai-degen 6f7abf0220 wip 2024-02-04 13:31:27 -06:00
nai-degen 235510e588 fixes incorrect AWS Claude 2.1 max context limit 2024-02-01 20:40:15 -06:00
nai-degen 7eb6eb90ad moves api schema validators from transform-outbound-payload into shared 2024-01-29 19:38:22 -06:00
nai-degen 924db33f7e attempts to auto-convert Mistral prompts for its more strict rules 2024-01-28 17:42:23 -06:00
nai-degen 3f2f30e605 updates gpt4-v tokenizer for previous Risu change 2024-01-27 13:35:46 -06:00
nai-degen c9791acd85 makes gpt4-v input validation less strict to accommodate Risu 2024-01-27 13:24:11 -06:00
nai-degen e871b8ecf1 removes logprobs default value since it breaks gpt-4-vision 2024-01-27 12:19:24 -06:00
nai-degen 37ca98ad30 adds dark mode (infopage only currently) 2024-01-25 16:24:11 -06:00
nai-degen e6dc4475e6 fixes max context size for nu-gpt4-turbo 2024-01-25 14:07:42 -06:00
nai-degen 5e646b1c86 adds gpt-4-0125-preview and gpt-4-turbo-preview alias 2024-01-25 13:27:03 -06:00
nai-degen 6f626e623e fixes OAI trial keys bricking the dall-e queue 2024-01-25 01:47:51 -06:00
nai-degen 02a54bf4e3 fixes azure openai logprobs (actually tested this time) 2024-01-25 01:17:18 -06:00
nai-degen 79b2e5b6fd adds very basic support for OpenAI function calling 2024-01-24 16:42:26 -06:00
nai-degen 935a633325 fixes typo in Azure logprob adjustment 2024-01-24 16:03:47 -06:00
nai-degen 4a4b60ebcd handles Azure deviation from OpenAI spec on logprobs param 2024-01-24 16:01:19 -06:00
nai-degen ad465be363 fixes logprobs schema validation for turbo instruct endpoint 2024-01-24 14:31:10 -06:00
nai-degen c7a351baa8 adds support for requesting logprobs from OpenAI Chat Completions API 2024-01-24 11:46:09 -06:00
nai-degen ba8b052b17 adds bindAddress to omitted config keys 2024-01-18 04:14:15 -06:00
nai-degen e813cd9d22 default claude 2.1 instead of 1.3 in openai compat endpoint since 1.3 is not accessible on all keys 2024-01-18 04:14:15 -06:00
nai-degen 4c2a2c1e6c improves handle-streamed-response comments/docs [skip-ci] 2024-01-18 04:14:15 -06:00
nai-degen f1d927fa62 updates README with building/forking info [skip-ci] 2024-01-15 11:46:09 -06:00
nai-degen ad6e5224e3 allows binding to loopback interface via app config instead of only docker 2024-01-15 11:32:26 -06:00
nai-degen 85d89bdb9f fixes CI image tagging on main branch 2024-01-15 01:37:50 -06:00
khanon f5e7195cc9 Add Gitlab CI and self-hosting instructions (khanon/oai-reverse-proxy!61) 2024-01-15 06:51:12 +00:00
nai-degen 81f1e2bc37 fixes broken GET models endpoint for openai/mistral 2024-01-14 05:33:24 -06:00
nai-degen c2a686f229 Revert "reduces max request body size for now" (reverts 4ffa7fb12b) 2024-01-13 18:12:16 -06:00
twinkletoes 96a0f94041 Fix Mistral safe_prompt schema property (khanon/oai-reverse-proxy!60) 2024-01-14 00:11:39 +00:00
nai-degen d56043616e adds keychecker workaround for OpenAI API bug falsely returning gpt4-32k 2024-01-12 10:33:48 -06:00
nai-degen e3e06b065d fixes sourcemap dependency in package.json 2024-01-09 00:32:34 -06:00
nai-degen 1bbb515200 updates static service info 2024-01-08 23:32:25 -06:00
nai-degen a57cc4e8d4 updates dotenv 2024-01-08 23:25:02 -06:00
nai-degen 2239bead2c updates README.md 2024-01-08 19:36:35 -06:00
nai-degen 1a585ddd32 adds TRUSTED_PROXIES to .env.example 2024-01-08 16:41:30 -06:00
nai-degen be731691a1 allows configurable trust proxy setting for Render deployments 2024-01-08 16:39:28 -06:00
nai-degen c2e442e030 long overdue removal of tired in-joke 2024-01-08 11:01:44 -06:00
nai-degen d3ac3b362b trusts only one proxy hop (AWS WAF in huggingface's case) 2024-01-07 19:18:01 -06:00
nai-degen 7b0892ddae fixes unawaited call to async enqueue 2024-01-07 16:23:53 -06:00
nai-degen 7f92565739 SSE queueing adjustments, untested 2024-01-07 16:19:22 -06:00
nai-degen 936d3c0721 corrects nodejs max heap memory config 2024-01-07 16:16:27 -06:00
nai-degen 4ffa7fb12b reduces max request body size for now 2024-01-07 13:03:24 -06:00
nai-degen 8dc7464381 strips extraneous properties on zod schemas 2024-01-07 13:00:48 -06:00
nai-degen d2cd24bfd2 suggest larger nodejs max heap 2024-01-07 12:58:50 -06:00
twinkletoes e33f778192 Change mistral-medium friendly name (khanon/oai-reverse-proxy!59) 2023-12-26 00:27:17 +00:00
twinkletoes 4a823b216f Mistral AI support (khanon/oai-reverse-proxy!58) 2023-12-25 18:33:16 +00:00
nai-degen 01e76cbb1c restores accidentally deleted line breaking infopage stats 2023-12-17 00:25:58 -06:00
nai-degen 655703e680 refactors infopage 2023-12-16 20:30:20 -06:00
nai-degen 3be2687793 tries to detect Azure GPT4-Turbo deployments more reliably 2023-12-15 12:14:23 -06:00
nai-degen 5599a83ae4 improves streaming error handling 2023-12-14 05:01:10 -06:00
nai-degen de34d41918 fixes gemini name prefixing when 'Add character names' is disabled in ST 2023-12-13 23:21:30 -06:00
nai-degen c5cd90dcef adjusts prompt transform to discourage Gemini from speaking for user 2023-12-13 23:03:57 -06:00
nai-degen 8a135a960d fixes gemini prompt reformatting for jbs; adds stop sequences 2023-12-13 21:45:53 -06:00
nai-degen 707cbbce16 fixes gemini throwing an error on JB prompts 2023-12-13 19:14:31 -06:00
khanon fad16cc268 Add Google AI API (khanon/oai-reverse-proxy!57) 2023-12-13 21:56:07 +00:00
nai-degen 0d3682197c treats 403 from anthropic as key dead 2023-12-11 09:13:53 -06:00
valadaptive e0624e30fd Fix some corner cases in SSE parsing (khanon/oai-reverse-proxy!56) 2023-12-09 06:18:01 +00:00
79 changed files with 3582 additions and 1516 deletions
+16 -6
@@ -5,6 +5,9 @@
# All values have reasonable defaults, so you only need to change the ones you
# want to override.
# Use production mode unless you are developing locally.
NODE_ENV=production
# ------------------------------------------------------------------------------
# General settings:
@@ -34,10 +37,10 @@
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | mistral-tiny | mistral-small | mistral-medium | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,mistral-tiny,mistral-small,mistral-medium,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -57,8 +60,9 @@
# Requires additional setup. See `docs/google-sheets.md` for more information.
# PROMPT_LOGGING=false
# The port to listen on.
# The port and network interface to listen on.
# PORT=7860
# BIND_ADDRESS=0.0.0.0
# Whether cookies should be set without the Secure flag, for hosts that don't support SSL.
# USE_INSECURE_COOKIES=false
@@ -95,17 +99,23 @@
# TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_BISON=0
# TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# How often to refresh token quotas. (hourly | daily)
# Leave unset to never automatically refresh quotas.
# QUOTA_REFRESH_PERIOD=daily
# Specifies the number of proxies or load balancers in front of the server.
# For Cloudflare or Hugging Face deployments, the default of 1 is correct.
# For any other deployments, please see config.ts as the correct configuration
# depends on your setup. Misconfiguring this value can result in problems
# accurately tracking IP addresses and enforcing rate limits.
# TRUSTED_PROXIES=1
# ------------------------------------------------------------------------------
# Secrets and keys:
# Do not put any passwords or API keys directly in this file.
# For Huggingface, set them via the Secrets section in your Space's config UI.
# For Huggingface, set them via the Secrets section in your Space's config UI. Do not set them in .env.
# For Render, create a "secret file" called .env using the Environment tab.
# You can add multiple API keys by separating them with a comma.
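For illustration, a minimal `.env` combining the settings discussed above might look like the following. All values are examples, not recommendations:

```
# Example only -- adjust every value to your own deployment.
NODE_ENV=production
PORT=7860
BIND_ADDRESS=0.0.0.0
TRUSTED_PROXIES=1
ALLOWED_MODEL_FAMILIES=turbo,gpt4,claude,gemini-pro
QUOTA_REFRESH_PERIOD=daily
```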
+3 -1
@@ -1,8 +1,10 @@
.env
.env*
!.env.vault
.venv
.vscode
.idea
build
greeting.md
node_modules
http-client.private.env.json
+43 -15
@@ -1,34 +1,53 @@
# OAI Reverse Proxy
Reverse proxy server for the OpenAI and Anthropic APIs. Forwards text generation requests while rejecting administrative/billing requests. Includes optional rate limiting and prompt filtering to prevent abuse.
Reverse proxy server for various LLM APIs.
### Table of Contents
- [What is this?](#what-is-this)
- [Why?](#why)
- [Usage Instructions](#setup-instructions)
- [Deploy to Huggingface (Recommended)](#deploy-to-huggingface-recommended)
- [Deploy to Repl.it (WIP)](#deploy-to-replit-wip)
- [Features](#features)
- [Usage Instructions](#usage-instructions)
- [Self-hosting](#self-hosting)
- [Alternatives](#alternatives)
- [Huggingface (outdated, not advised)](#huggingface-outdated-not-advised)
- [Render (outdated, not advised)](#render-outdated-not-advised)
- [Local Development](#local-development)
## What is this?
If you would like to provide a friend access to an API via keys you own, you can use this to keep your keys safe while still allowing them to generate text with the API. You can also use this if you'd like to build a client-side application which uses the OpenAI or Anthropic APIs, but don't want to build your own backend. You should never embed your real API keys in a client-side application. Instead, you can have your frontend connect to this reverse proxy and forward requests to the downstream service.
This project allows you to run a reverse proxy server for various LLM APIs.
This keeps your keys safe and allows you to use the rate limiting and prompt filtering features of the proxy to prevent abuse.
## Why?
OpenAI keys have full account permissions. They can revoke themselves, generate new keys, modify spend quotas, etc. **You absolutely should not share them, post them publicly, nor embed them in client-side applications as they can be easily stolen.**
This proxy only forwards text generation requests to the downstream service and rejects requests which would otherwise modify your account.
## Features
- [x] Support for multiple APIs
- [x] [OpenAI](https://openai.com/)
- [x] [Anthropic](https://www.anthropic.com/)
- [x] [AWS Bedrock](https://aws.amazon.com/bedrock/)
- [x] [Google MakerSuite/Gemini API](https://ai.google.dev/)
- [x] [Azure OpenAI](https://azure.microsoft.com/en-us/products/ai-services/openai-service)
- [x] Translation from OpenAI-formatted prompts to any other API, including streaming responses (see the example after this list)
- [x] Multiple API keys with rotation and rate limit handling
- [x] Basic user management
- [x] Simple role-based permissions
- [x] Per-model token quotas
- [x] Temporary user accounts
- [x] Prompt and completion logging
- [x] Abuse detection and prevention
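As an example of the prompt-translation feature, an OpenAI-formatted Chat Completions request can be pointed at a non-OpenAI route and the proxy converts it to the target API's format. The endpoint path below appears in the project's http-client samples; the host, key, and body values are illustrative:

```
POST https://your-proxy-host/proxy/google-ai/v1/chat/completions
Authorization: Bearer your-proxy-key
Content-Type: application/json

{
  "model": "gpt-3.5-turbo",
  "messages": [{ "role": "user", "content": "Hello!" }]
}
```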
---
## Usage Instructions
If you'd like to run your own instance of this proxy, you'll need to deploy it somewhere and configure it with your API keys. A few easy options are provided below, though you can also deploy it to any other service you'd like.
If you'd like to run your own instance of this server, you'll need to deploy it somewhere and configure it with your API keys. A few easy options are provided below, though you can also deploy it to any other service you'd like if you know what you're doing and the service supports Node.js.
### Deploy to Huggingface (Recommended)
### Self-hosting
[See here for instructions on how to self-host the application on your own VPS or local machine.](./docs/self-hosting.md)
**Ensure you set the `TRUSTED_PROXIES` environment variable according to your deployment.** Refer to [.env.example](./.env.example) and [config.ts](./src/config.ts) for more information.
### Alternatives
Fiz and Sekrit are working on some alternative ways to deploy this conveniently. While I'm not involved in this effort beyond providing technical advice regarding my code, I'll link to their work here: [Sekrit's rentry](https://rentry.org/sekrit)
### Huggingface (outdated, not advised)
[See here for instructions on how to deploy to a Huggingface Space.](./docs/deploy-huggingface.md)
### Deploy to Render
### Render (outdated, not advised)
[See here for instructions on how to deploy to Render.com.](./docs/deploy-render.md)
## Local Development
@@ -40,3 +59,12 @@ To run the proxy locally for development or testing, install Node.js >= 18.0.0 a
4. Start the server in development mode with `npm run start:dev`.
You can also use `npm run start:dev:tsc` to enable project-wide type checking at the cost of slower startup times. `npm run type-check` can be used to run type checking without starting the server.
## Building
To build the project, run `npm run build`. This will compile the TypeScript code to JavaScript and output it to the `build` directory.
Note that if you are trying to build the server on a very memory-constrained (<= 1GB) VPS, you may need to run the build with `NODE_OPTIONS=--max_old_space_size=2048 npm run build` to avoid running out of memory during the build process, assuming you have swap enabled. The application itself should run fine on a 512MB VPS for most reasonable traffic levels.
## Forking
If you are forking the repository on GitGud, you may wish to disable GitLab CI/CD, or you will be spammed with emails about failed builds due to not having any CI runners. You can do this by going to *Settings > General > Visibility, project features, permissions* and disabling the "CI/CD" feature.
+21
@@ -0,0 +1,21 @@
stages:
- build
build_image:
stage: build
image:
name: gcr.io/kaniko-project/executor:debug
entrypoint: [""]
script:
- |
if [ "$CI_COMMIT_REF_NAME" = "main" ]; then
TAG="latest"
else
TAG=$CI_COMMIT_REF_NAME
fi
- echo "Building image with tag $TAG"
- BASE64_AUTH=$(echo -n "$DOCKER_HUB_USERNAME:$DOCKER_HUB_ACCESS_TOKEN" | base64)
- echo "{\"auths\":{\"https://index.docker.io/v1/\":{\"auth\":\"$BASE64_AUTH\"}}}" > /kaniko/.docker/config.json
- /kaniko/executor --context $CI_PROJECT_DIR --dockerfile $CI_PROJECT_DIR/docker/ci/Dockerfile --destination docker.io/khanonci/oai-reverse-proxy:$TAG --build-arg CI_COMMIT_REF_NAME=$CI_COMMIT_REF_NAME --build-arg CI_COMMIT_SHA=$CI_COMMIT_SHA --build-arg CI_PROJECT_PATH=$CI_PROJECT_PATH
only:
- main
+22
@@ -0,0 +1,22 @@
FROM node:18-bullseye-slim
WORKDIR /app
COPY . .
RUN npm ci
RUN npm run build
RUN npm prune --production
EXPOSE 7860
ENV PORT=7860
ENV NODE_ENV=production
ARG CI_COMMIT_REF_NAME
ARG CI_COMMIT_SHA
ARG CI_PROJECT_PATH
ENV GITGUD_BRANCH=$CI_COMMIT_REF_NAME
ENV GITGUD_COMMIT=$CI_COMMIT_SHA
ENV GITGUD_PROJECT=$CI_PROJECT_PATH
CMD [ "npm", "start" ]
+17
@@ -0,0 +1,17 @@
# Before running this, create a .env and greeting.md file.
# Refer to .env.example for the required environment variables.
# User-generated content is stored in the data directory.
# When self-hosting, it's recommended to run this behind a reverse proxy like
# nginx or Caddy to handle SSL/TLS and rate limiting. Refer to
# docs/self-hosting.md for more information and an example nginx config.
version: '3.8'
services:
oai-reverse-proxy:
image: khanonci/oai-reverse-proxy:latest
ports:
- "127.0.0.1:7860:7860"
env_file:
- ./.env
volumes:
- ./greeting.md:/app/greeting.md
- ./data:/app/data
+2
@@ -10,4 +10,6 @@ COPY Dockerfile greeting.md* .env* ./
RUN npm run build
EXPOSE 7860
ENV NODE_ENV=production
# Huggingface free VMs have 16GB of RAM so we can be greedy
ENV NODE_OPTIONS="--max-old-space-size=12882"
CMD [ "npm", "start" ]
+1 -1
@@ -35,7 +35,7 @@ Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e
# All models as of this writing
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,dall-e
```
Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.
+3
@@ -1,5 +1,7 @@
# Deploy to Huggingface Space
**⚠️ This method is no longer recommended. Please use the [self-hosting instructions](./self-hosting.md) instead.**
This repository can be deployed to a [Huggingface Space](https://huggingface.co/spaces). This is a free service that allows you to run a simple server in the cloud. You can use it to safely share your OpenAI API key with a friend.
### 1. Get an API key
@@ -32,6 +34,7 @@ COPY Dockerfile greeting.md* .env* ./
RUN npm run build
EXPOSE 7860
ENV NODE_ENV=production
ENV NODE_OPTIONS="--max-old-space-size=12882"
CMD [ "npm", "start" ]
```
- Click "Commit new file to `main`" to save the Dockerfile.
+5
@@ -1,4 +1,7 @@
# Deploy to Render.com
**⚠️ This method is no longer recommended. Please use the [self-hosting instructions](./self-hosting.md) instead.**
Render.com offers a free tier that includes 750 hours of compute time per month. This is enough to run a single proxy instance 24/7. Instances shut down after 15 minutes without traffic but start up again automatically when a request is received. You can use something like https://app.checklyhq.com/ to ping your proxy every 15 minutes to keep it alive.
### 1. Create account
@@ -28,6 +31,8 @@ The service will be created according to the instructions in the `render.yaml` f
- For example, `OPENAI_KEY=sk-abc123`.
- Click **Save Changes**.
**IMPORTANT:** Set `TRUSTED_PROXIES=3`, otherwise users' IP addresses will not be recorded correctly (the server will see the IP address of Render's load balancer instead of the user's real IP address).
The service will automatically rebuild and deploy with the new environment variables. This will take a few minutes. The link to your deployed proxy will appear at the top of the page.
If you want to change the URL, go to the **Settings** tab of your Web Service and click the **Edit** button next to **Name**. You can also set a custom domain, though I haven't tried this yet.
+150
@@ -0,0 +1,150 @@
# Quick self-hosting guide
Temporary guide for self-hosting. This will be improved in the future to provide more robust instructions and options. Provided commands are for Ubuntu.
This uses prebuilt Docker images for convenience. If you want to make adjustments to the code you can instead clone the repo and follow the Local Development guide in the [README](../README.md).
## Table of Contents
- [Requirements](#requirements)
- [Running the application](#running-the-application)
- [Setting up a reverse proxy](#setting-up-a-reverse-proxy)
- [trycloudflare](#trycloudflare)
- [nginx](#nginx)
- [Example basic nginx configuration (no SSL)](#example-basic-nginx-configuration-no-ssl)
- [Example with Cloudflare SSL](#example-with-cloudflare-ssl)
- [Updating/Restarting the application](#updatingrestarting-the-application)
## Requirements
- Docker
- Docker Compose
- A VPS with at least 512MB of RAM (1GB recommended)
- A domain name
If you don't have a VPS and domain name you can use TryCloudflare to set up a temporary URL that you can share with others. See [trycloudflare](#trycloudflare) for more information.
## Running the application
- Install Docker and Docker Compose
- Create a new directory for the application
- This will contain your .env file, greeting file, and any user-generated files
- Execute the following commands:
- ```
touch .env
touch greeting.md
echo "OPENAI_KEY=your-openai-key" >> .env
curl https://gitgud.io/khanon/oai-reverse-proxy/-/raw/main/docker/docker-compose-selfhost.yml -o docker-compose.yml
```
- You can set further environment variables and keys in the `.env` file. See [.env.example](../.env.example) for a list of available options.
- You can set a custom greeting in `greeting.md`. This will be displayed on the homepage.
- Run `docker compose up -d`
You can check logs with `docker compose logs -n 100 -f`.
The provided docker-compose file listens on port 7860 but binds to localhost only. You should use a reverse proxy to expose the application to the internet as described in the next section.
## Setting up a reverse proxy
Rather than exposing the application directly to the internet, it is recommended to set up a reverse proxy. This will allow you to use HTTPS and add additional security measures.
### trycloudflare
This will give you a temporary (72 hours) URL that you can use to let others connect to your instance securely, without having to set up a reverse proxy. If you are running the server on your home network, this is probably the best option.
- Install `cloudflared` following the instructions at [try.cloudflare.com](https://try.cloudflare.com/).
- Run `cloudflared tunnel --url http://localhost:7860`
- You will be given a temporary URL that you can share with others.
If you have a VPS, you should use a proper reverse proxy like nginx instead for a more permanent solution which will allow you to use your own domain name, handle SSL, and add additional security/anti-abuse measures.
### nginx
First, install nginx.
- `sudo apt update && sudo apt install nginx`
#### Example basic nginx configuration (no SSL)
- `sudo nano /etc/nginx/sites-available/oai.conf`
- ```
server {
listen 80;
server_name example.com;
location / {
proxy_pass http://localhost:7860;
}
}
```
- Replace `example.com` with your domain name.
- Ctrl+X to exit, Y to save, Enter to confirm.
- `sudo ln -s /etc/nginx/sites-available/oai.conf /etc/nginx/sites-enabled`
- `sudo nginx -t`
- This will check the configuration file for errors.
- `sudo systemctl restart nginx`
- This will restart nginx and apply the new configuration.
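Before adding SSL, it can help to confirm the chain end to end. A quick sketch using standard tools; replace `example.com` with your domain:

```
curl -sI http://localhost:7860   # the app directly, bypassing nginx
curl -sI http://example.com      # the app through nginx
```

Both commands should return an HTTP status line from the proxy.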
#### Example with Cloudflare SSL
This allows you to use a self-signed certificate on the server, and have Cloudflare handle client SSL. You need to have a Cloudflare account and have your domain set up with Cloudflare already, pointing to your server's IP address.
- Set Cloudflare to use Full SSL mode. Since we are using a self-signed certificate, don't use Full (strict) mode.
- Create a self-signed certificate:
- `openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout /etc/ssl/private/nginx-selfsigned.key -out /etc/ssl/certs/nginx-selfsigned.crt`
- `sudo nano /etc/nginx/sites-available/oai.conf`
- ```
server {
listen 443 ssl;
server_name yourdomain.com www.yourdomain.com;
ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt;
ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key;
# Only allow inbound traffic from Cloudflare
allow 173.245.48.0/20;
allow 103.21.244.0/22;
allow 103.22.200.0/22;
allow 103.31.4.0/22;
allow 141.101.64.0/18;
allow 108.162.192.0/18;
allow 190.93.240.0/20;
allow 188.114.96.0/20;
allow 197.234.240.0/22;
allow 198.41.128.0/17;
allow 162.158.0.0/15;
allow 104.16.0.0/13;
allow 104.24.0.0/14;
allow 172.64.0.0/13;
allow 131.0.72.0/22;
deny all;
location / {
proxy_pass http://localhost:7860;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection 'upgrade';
proxy_set_header Host $host;
proxy_cache_bypass $http_upgrade;
}
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers 'ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256';
ssl_prefer_server_ciphers on;
ssl_session_cache shared:SSL:10m;
}
```
- Replace `yourdomain.com` with your domain name.
- Ctrl+X to exit, Y to save, Enter to confirm.
- `sudo ln -s /etc/nginx/sites-available/oai.conf /etc/nginx/sites-enabled`
## Updating/Restarting the application
After making a .env change, you need to restart the application for it to take effect.
- `docker compose down`
- `docker compose up -d`
To update the application to the latest version:
- `docker compose pull`
- `docker compose down`
- `docker compose up -d`
- `docker image prune -f`
+45 -10
@@ -20,7 +20,7 @@
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
"csrf-csrf": "^2.3.0",
"dotenv": "^16.0.3",
"dotenv": "^16.3.1",
"ejs": "^3.1.9",
"express": "^4.18.2",
"express-session": "^1.17.3",
@@ -36,6 +36,8 @@
"sanitize-html": "^2.11.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"source-map-support": "^0.5.21",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
"zlib": "^1.0.5",
@@ -51,6 +53,7 @@
"@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1",
"concurrently": "^8.0.1",
"esbuild": "^0.17.16",
@@ -59,7 +62,6 @@
"nodemon": "^3.0.1",
"pino-pretty": "^10.2.3",
"prettier": "^3.0.3",
"source-map-support": "^0.5.21",
"ts-node": "^10.9.1",
"typescript": "^5.1.3"
},
@@ -1185,6 +1187,25 @@
"integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==",
"dev": true
},
"node_modules/@types/stream-chain": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/@types/stream-chain/-/stream-chain-2.0.4.tgz",
"integrity": "sha512-V7TsWLHrx79KumkHqSD7F8eR6POpEuWb6PuXJ7s/dRHAf3uVst3Jkp1yZ5XqIfECZLQ4a28vBVstTErmsMBvaQ==",
"dev": true,
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@types/stream-json": {
"version": "1.7.7",
"resolved": "https://registry.npmjs.org/@types/stream-json/-/stream-json-1.7.7.tgz",
"integrity": "sha512-hHG7cLQ09H/m9i0jzL6UJAeLLxIWej90ECn0svO4T8J0nGcl89xZDQ2ujT4WKlvg0GWkcxJbjIDzW/v7BYUM6Q==",
"dev": true,
"dependencies": {
"@types/node": "*",
"@types/stream-chain": "*"
}
},
"node_modules/@types/uuid": {
"version": "9.0.1",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz",
@@ -2228,11 +2249,14 @@
}
},
"node_modules/dotenv": {
"version": "16.0.3",
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.0.3.tgz",
"integrity": "sha512-7GO6HghkA5fYG9TYnNxi14/7K9f5occMlp3zXAuSxn7CKCxt9xbNWG7yF8hTCSUchlfWSe3uLmlPfigevRItzQ==",
"version": "16.3.1",
"resolved": "https://registry.npmjs.org/dotenv/-/dotenv-16.3.1.tgz",
"integrity": "sha512-IPzF4w4/Rd94bA9imS68tZBaYyBWSCE47V1RGuMrB94iyTOIEwRmVL2x/4An+6mETpLrKJ5hQkB8W4kFAadeIQ==",
"engines": {
"node": ">=12"
},
"funding": {
"url": "https://github.com/motdotla/dotenv?sponsor=1"
}
},
"node_modules/duplexify": {
@@ -2800,9 +2824,9 @@
}
},
"node_modules/follow-redirects": {
"version": "1.15.2",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz",
"integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==",
"version": "1.15.4",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz",
"integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==",
"funding": [
{
"type": "individual",
@@ -5090,7 +5114,6 @@
"version": "0.6.1",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz",
"integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==",
"devOptional": true,
"engines": {
"node": ">=0.10.0"
}
@@ -5107,7 +5130,6 @@
"version": "0.5.21",
"resolved": "https://registry.npmjs.org/source-map-support/-/source-map-support-0.5.21.tgz",
"integrity": "sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==",
"dev": true,
"dependencies": {
"buffer-from": "^1.0.0",
"source-map": "^0.6.0"
@@ -5135,6 +5157,11 @@
"node": ">= 0.8"
}
},
"node_modules/stream-chain": {
"version": "2.2.5",
"resolved": "https://registry.npmjs.org/stream-chain/-/stream-chain-2.2.5.tgz",
"integrity": "sha512-1TJmBx6aSWqZ4tx7aTpBDXK0/e2hhcNSTV8+CbFJtDjbb+I1mZ8lHit0Grw9GRT+6JbIrrDd8esncgBi8aBXGA=="
},
"node_modules/stream-events": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz",
@@ -5144,6 +5171,14 @@
"stubs": "^3.0.0"
}
},
"node_modules/stream-json": {
"version": "1.8.0",
"resolved": "https://registry.npmjs.org/stream-json/-/stream-json-1.8.0.tgz",
"integrity": "sha512-HZfXngYHUAr1exT4fxlbc1IOce1RYxp2ldeaf97LYCOPSoOqY/1Psp7iGvpb+6JIOgkra9zDYnPX01hGAHzEPw==",
"dependencies": {
"stream-chain": "^2.2.5"
}
},
"node_modules/stream-shift": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz",
+6 -3
@@ -28,7 +28,7 @@
"copyfiles": "^2.4.1",
"cors": "^2.8.5",
"csrf-csrf": "^2.3.0",
"dotenv": "^16.0.3",
"dotenv": "^16.3.1",
"ejs": "^3.1.9",
"express": "^4.18.2",
"express-session": "^1.17.3",
@@ -44,6 +44,8 @@
"sanitize-html": "^2.11.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"source-map-support": "^0.5.21",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
"zlib": "^1.0.5",
@@ -59,6 +61,7 @@
"@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1",
"concurrently": "^8.0.1",
"esbuild": "^0.17.16",
@@ -67,12 +70,12 @@
"nodemon": "^3.0.1",
"pino-pretty": "^10.2.3",
"prettier": "^3.0.3",
"source-map-support": "^0.5.21",
"ts-node": "^10.9.1",
"typescript": "^5.1.3"
},
"overrides": {
"google-gax": "^3.6.1",
"postcss": "^8.4.31"
"postcss": "^8.4.31",
"follow-redirects": "^1.15.4"
}
}
+31 -3
@@ -81,7 +81,7 @@ Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-3.5-turbo",
"model": "gpt-4-1106-preview",
"max_tokens": 20,
"stream": true,
"temperature": 1,
@@ -231,8 +231,36 @@ Content-Type: application/json
}
###
# @name Proxy / Google PaLM -- OpenAI-to-PaLM API Translation
POST {{proxy-host}}/proxy/google-palm/v1/chat/completions
# @name Proxy / Azure OpenAI -- Native Chat Completions
POST {{proxy-host}}/proxy/azure/openai/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
{
"model": "gpt-4",
"max_tokens": 20,
"stream": true,
"temperature": 1,
"seed": 2,
"messages": [
{
"role": "user",
"content": "Hi what is the name of the fourth president of the united states?"
},
{
"role": "assistant",
"content": "That would be George Washington."
},
{
"role": "user",
"content": "That's not right."
}
]
}
###
# @name Proxy / Google AI -- OpenAI-to-Google AI API Translation
POST {{proxy-host}}/proxy/google-ai/v1/chat/completions
Authorization: Bearer {{proxy-key}}
Content-Type: application/json
+4 -3
@@ -1,6 +1,6 @@
const axios = require("axios");
const concurrentRequests = 5;
const concurrentRequests = 75;
const headers = {
Authorization: "Bearer test",
"Content-Type": "application/json",
@@ -16,7 +16,7 @@ const payload = {
const makeRequest = async (i) => {
try {
const response = await axios.post(
"http://localhost:7860/proxy/azure/openai/v1/chat/completions",
"http://localhost:7860/proxy/google-ai/v1/chat/completions",
payload,
{ headers }
);
@@ -25,7 +25,8 @@ const makeRequest = async (i) => {
response.data
);
} catch (error) {
console.error(`Error in req ${i}:`, error.message);
const msg = error.response?.data;
console.error(`Error in req ${i}:`, error.message, msg || "");
}
};
+3 -2
@@ -4,7 +4,8 @@ import { HttpError } from "../shared/errors";
import { injectLocals } from "../shared/inject-locals";
import { withSession } from "../shared/with-session";
import { injectCsrfToken, checkCsrfToken } from "../shared/inject-csrf";
import { buildInfoPageHtml } from "../info-page";
import { renderPage } from "../info-page";
import { buildInfo } from "../service-info";
import { loginRouter } from "./login";
import { usersApiRouter as apiRouter } from "./api/users";
import { usersWebRouter as webRouter } from "./web/manage";
@@ -26,7 +27,7 @@ adminRouter.use("/", loginRouter);
adminRouter.use("/manage", authorize({ via: "cookie" }), webRouter);
adminRouter.use("/service-info", authorize({ via: "cookie" }), (req, res) => {
return res.send(
buildInfoPageHtml(req.protocol + "://" + req.get("host"), true)
renderPage(buildInfo(req.protocol + "://" + req.get("host"), true))
);
});
+1 -1
@@ -200,7 +200,7 @@ router.post("/maintenance", (req, res) => {
keyPool.recheck("anthropic");
const size = keyPool
.list()
.filter((k) => k.service !== "google-palm").length;
.filter((k) => k.service !== "google-ai").length;
flash.type = "success";
flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`;
break;
+93 -15
@@ -4,6 +4,7 @@ import path from "path";
import pino from "pino";
import type { ModelFamily } from "./shared/models";
import { MODEL_FAMILIES } from "./shared/models";
dotenv.config();
const startupLogger = pino({ level: "debug" }).child({ module: "startup" });
@@ -15,12 +16,22 @@ export const USER_ASSETS_DIR = path.join(DATA_DIR, "user-files");
type Config = {
/** The port the proxy server will listen on. */
port: number;
/** The network interface the proxy server will listen on. */
bindAddress: string;
/** Comma-delimited list of OpenAI API keys. */
openaiKey?: string;
/** Comma-delimited list of Anthropic API keys. */
anthropicKey?: string;
/** Comma-delimited list of Google PaLM API keys. */
googlePalmKey?: string;
/**
* Comma-delimited list of Google AI API keys. Note that these are not the
* same as the GCP keys/credentials used for Vertex AI; the models are the
* same but the APIs are different. Vertex is the GCP product for enterprise.
**/
googleAIKey?: string;
/**
* Comma-delimited list of Mistral AI API keys.
*/
mistralAIKey?: string;
/**
* Comma-delimited list of AWS credentials. Each credential item should be a
* colon-delimited list of access key, secret key, and AWS region.
@@ -189,15 +200,61 @@ type Config = {
* configured ADMIN_KEY and go to /admin/service-info.
**/
staticServiceInfo?: boolean;
/**
* Trusted proxy hops. If you are deploying the server behind a reverse proxy
* (Nginx, Cloudflare Tunnel, AWS WAF, etc.) the IP address of incoming
* requests will be the IP address of the proxy, not the actual user.
*
* Depending on your hosting configuration, there may be multiple proxies/load
* balancers between your server and the user. Each one will append the
* incoming IP address to the `X-Forwarded-For` header. The user's real IP
* address will be the first one in the list, assuming the header has not been
* tampered with. Setting this value correctly ensures that the server doesn't
* trust values in `X-Forwarded-For` not added by trusted proxies.
*
* In order for the server to determine the user's real IP address, you need
* to tell it how many proxies are between the user and the server so it can
* select the correct IP address from the `X-Forwarded-For` header.
*
* *WARNING:* If you set it incorrectly, the proxy will either record the
* wrong IP address, or it will be possible for users to spoof their IP
* addresses and bypass rate limiting. Check the request logs to see what
* incoming X-Forwarded-For values look like.
*
* Examples:
* - X-Forwarded-For: "34.1.1.1, 172.1.1.1, 10.1.1.1" => trustedProxies: 3
* - X-Forwarded-For: "34.1.1.1" => trustedProxies: 1
* - no X-Forwarded-For header => trustedProxies: 0 (the actual IP of the incoming request will be used)
*
* As of 2024/01/08:
* For HuggingFace or Cloudflare Tunnel, use 1.
* For Render, use 3.
* For deployments not behind a load balancer, use 0.
*
* You should double check against your actual request logs to be sure.
*
* Defaults to 1, as most deployments are on HuggingFace or Cloudflare Tunnel.
*/
trustedProxies?: number;
/**
* Whether to allow OpenAI tool usage. The proxy doesn't implement any
* support for tools/function calling but can pass requests and responses as
* is. Note that the proxy also cannot accurately track quota usage for
* requests involving tools, so you must opt in to this feature at your own
* risk.
*/
allowOpenAIToolUsage?: boolean;
};
// To change configs, create a file called .env in the root directory.
// See .env.example for an example.
export const config: Config = {
port: getEnvWithDefault("PORT", 7860),
bindAddress: getEnvWithDefault("BIND_ADDRESS", "0.0.0.0"),
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
mistralAIKey: getEnvWithDefault("MISTRAL_AI_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -229,7 +286,10 @@ export const config: Config = {
"gpt4-32k",
"gpt4-turbo",
"claude",
"bison",
"gemini-pro",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"aws-claude",
"azure-turbo",
"azure-gpt4",
@@ -273,6 +333,8 @@ export const config: Config = {
showRecentImages: getEnvWithDefault("SHOW_RECENT_IMAGES", true),
useInsecureCookies: getEnvWithDefault("USE_INSECURE_COOKIES", isDev),
staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
trustedProxies: getEnvWithDefault("TRUSTED_PROXIES", 1),
allowOpenAIToolUsage: getEnvWithDefault("ALLOW_OPENAI_TOOL_USAGE", false),
} as const;
function generateCookieSecret() {
@@ -361,17 +423,20 @@ export const SENSITIVE_KEYS: (keyof Config)[] = ["googleSheetsSpreadsheetId"];
* Config keys that are not displayed on the info page at all, generally because
* they are not relevant to the user or can be inferred from other config.
*/
export const OMITTED_KEYS: (keyof Config)[] = [
export const OMITTED_KEYS = [
"port",
"bindAddress",
"logLevel",
"openaiKey",
"anthropicKey",
"googlePalmKey",
"googleAIKey",
"mistralAIKey",
"awsCredentials",
"azureCredentials",
"proxyKey",
"adminKey",
"rejectPhrases",
"rejectMessage",
"showTokenCosts",
"googleSheetsKey",
"firebaseKey",
@@ -387,34 +452,47 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"staticServiceInfo",
"checkKeys",
"allowedModelFamilies",
];
"trustedProxies"
] satisfies (keyof Config)[];
type OmitKeys = (typeof OMITTED_KEYS)[number];
type Printable<T> = {
[P in keyof T as Exclude<P, OmitKeys>]: T[P] extends object
? Printable<T[P]>
: string;
};
type PublicConfig = Printable<Config>;
const getKeys = Object.keys as <T extends object>(obj: T) => Array<keyof T>;
export function listConfig(obj: Config = config): Record<string, any> {
const result: Record<string, any> = {};
export function listConfig(obj: Config = config) {
const result: Record<string, unknown> = {};
for (const key of getKeys(obj)) {
const value = obj[key]?.toString() || "";
const shouldOmit =
OMITTED_KEYS.includes(key) || value === "" || value === "undefined";
const shouldMask = SENSITIVE_KEYS.includes(key);
const shouldOmit =
OMITTED_KEYS.includes(key as OmitKeys) ||
value === "" ||
value === "undefined";
if (shouldOmit) {
continue;
}
const validKey = key as keyof Printable<Config>;
if (value && shouldMask) {
result[key] = "********";
result[validKey] = "********";
} else {
result[key] = value;
result[validKey] = value;
}
if (typeof obj[key] === "object" && !Array.isArray(obj[key])) {
result[key] = listConfig(obj[key] as unknown as Config);
}
}
return result;
return result as PublicConfig;
}
/**
@@ -433,7 +511,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
[
"OPENAI_KEY",
"ANTHROPIC_KEY",
"GOOGLE_PALM_KEY",
"GOOGLE_AI_KEY",
"AWS_CREDENTIALS",
"AZURE_CREDENTIALS",
].includes(String(env))
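The `trustedProxies` comment above describes Express's `trust proxy` behavior. As a standalone sketch of just the `X-Forwarded-For` arithmetic (not the server's actual code, which delegates this to Express):

```ts
// Simplified model of how a trustedProxies count selects the client IP.
// Express's "trust proxy" setting performs equivalent logic internally.
function clientIp(
  socketAddress: string, // address of the direct connection
  xff: string | undefined, // e.g. "34.1.1.1, 172.1.1.1, 10.1.1.1"
  trustedProxies: number
): string {
  if (trustedProxies === 0 || !xff) return socketAddress;
  const hops = xff.split(",").map((h) => h.trim());
  // Take the entry trustedProxies places from the right: everything to its
  // right was appended by our own proxies, while everything to its left is
  // unverifiable and may have been forged by the client.
  return hops[Math.max(0, hops.length - trustedProxies)];
}

// clientIp("10.0.0.2", "34.1.1.1, 172.1.1.1, 10.1.1.1", 3) === "34.1.1.1"
// clientIp("10.0.0.2", "34.1.1.1", 1) === "34.1.1.1"
// clientIp("34.1.1.1", undefined, 0) === "34.1.1.1"
```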
+60 -501
@@ -1,74 +1,39 @@
/** This whole module really sucks */
/** This whole module kinda sucks */
import fs from "fs";
import { Request, Response } from "express";
import showdown from "showdown";
import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
AzureOpenAIKey,
GooglePalmKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
import {
AzureOpenAIModelFamily,
ModelFamily,
OpenAIModelFamily,
} from "./shared/models";
import { getUniqueIps } from "./proxy/rate-limit";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { getTokenCostUsd, prettyTokens } from "./shared/stats";
import { assertNever } from "./shared/utils";
import { config } from "./config";
import { buildInfo, ServiceInfo } from "./service-info";
import { getLastNImages } from "./shared/file-storage/image-history";
import { keyPool } from "./shared/key-management";
import { MODEL_FAMILY_SERVICE, ModelFamily } from "./shared/models";
const INFO_PAGE_TTL = 2000;
const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"turbo": "GPT-3.5 Turbo",
"gpt4": "GPT-4",
"gpt4-32k": "GPT-4 32k",
"gpt4-turbo": "GPT-4 Turbo",
"dall-e": "DALL-E",
"claude": "Claude",
"gemini-pro": "Gemini Pro",
"mistral-tiny": "Mistral 7B",
"mistral-small": "Mixtral 8x7B",
"mistral-medium": "Mistral Medium (prototype)",
"aws-claude": "AWS Claude",
"azure-turbo": "Azure GPT-3.5 Turbo",
"azure-gpt4": "Azure GPT-4",
"azure-gpt4-32k": "Azure GPT-4 32k",
"azure-gpt4-turbo": "Azure GPT-4 Turbo",
};
const converter = new showdown.Converter();
const customGreeting = fs.existsSync("greeting.md")
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
: "";
let infoPageHtml: string | undefined;
let infoPageLastUpdated = 0;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
k.service === "google-palm";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
type ModelAggregates = {
active: number;
trial?: number;
revoked?: number;
overQuota?: number;
pozzed?: number;
awsLogged?: number;
queued: number;
queueTime: string;
tokens: number;
};
type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
type ServiceAggregates = {
status?: string;
openaiKeys?: number;
openaiOrgs?: number;
anthropicKeys?: number;
palmKeys?: number;
awsKeys?: number;
azureKeys?: number;
proompts: number;
tokens: number;
tokenCost: number;
openAiUncheckedKeys?: number;
anthropicUncheckedKeys?: number;
} & {
[modelFamily in ModelFamily]?: ModelAggregates;
};
const modelStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof ServiceAggregates, number>();
export const handleInfoPage = (req: Request, res: Response) => {
if (infoPageLastUpdated + INFO_PAGE_TTL > Date.now()) {
return res.send(infoPageHtml);
@@ -79,87 +44,16 @@ export const handleInfoPage = (req: Request, res: Response) => {
? getExternalUrlForHuggingfaceSpaceId(process.env.SPACE_ID)
: req.protocol + "://" + req.get("host");
infoPageHtml = buildInfoPageHtml(baseUrl + "/proxy");
const info = buildInfo(baseUrl + "/proxy");
infoPageHtml = renderPage(info);
infoPageLastUpdated = Date.now();
res.send(infoPageHtml);
};
function getCostString(cost: number) {
if (!config.showTokenCosts) return "";
return ` ($${cost.toFixed(2)})`;
}
export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const keys = keyPool.list();
const hideFullInfo = config.staticServiceInfo && !asAdmin;
modelStats.clear();
serviceStats.clear();
keys.forEach(addKeyToAggregates);
const openaiKeys = serviceStats.get("openaiKeys") || 0;
const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
const palmKeys = serviceStats.get("palmKeys") || 0;
const awsKeys = serviceStats.get("awsKeys") || 0;
const azureKeys = serviceStats.get("azureKeys") || 0;
const proompts = serviceStats.get("proompts") || 0;
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;
const allowDalle = config.allowedModelFamilies.includes("dall-e");
const endpoints = {
...(openaiKeys ? { openai: baseUrl + "/openai" } : {}),
...(openaiKeys ? { openai2: baseUrl + "/openai/turbo-instruct" } : {}),
...(openaiKeys && allowDalle
? { ["openai-image"]: baseUrl + "/openai-image" }
: {}),
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
};
const stats = {
proompts,
tookens: `${prettyTokens(tokens)}${getCostString(tokenCost)}`,
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
for (const key of Object.keys(keyInfo)) {
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
}
const providerInfo = {
...(openaiKeys ? getOpenAIInfo() : {}),
...(anthropicKeys ? getAnthropicInfo() : {}),
...(palmKeys ? getPalmInfo() : {}),
...(awsKeys ? getAwsInfo() : {}),
...(azureKeys ? getAzureInfo() : {}),
};
if (hideFullInfo) {
for (const provider of Object.keys(providerInfo)) {
delete (providerInfo as any)[provider].proomptersInQueue;
delete (providerInfo as any)[provider].estimatedQueueTime;
delete (providerInfo as any)[provider].usage;
}
}
const info = {
uptime: Math.floor(process.uptime()),
endpoints,
...(hideFullInfo ? {} : stats),
...keyInfo,
...providerInfo,
config: listConfig(),
build: process.env.BUILD_INFO || "dev",
};
export function renderPage(info: ServiceInfo) {
const title = getServerTitle();
const headerHtml = buildInfoPageHeader(new showdown.Converter(), title);
const headerHtml = buildInfoPageHeader(info);
return `<!DOCTYPE html>
<html lang="en">
@@ -167,8 +61,25 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
<meta charset="utf-8" />
<meta name="robots" content="noindex" />
<title>${title}</title>
<style>
body {
font-family: sans-serif;
background-color: #f0f0f0;
padding: 1em;
}
@media (prefers-color-scheme: dark) {
body {
background-color: #222;
color: #eee;
}
a:link, a:visited {
color: #bbe;
}
}
</style>
</head>
<body style="font-family: sans-serif; background-color: #f0f0f0; padding: 1em;">
<body>
${headerHtml}
<hr />
<h2>Service Info</h2>
@@ -178,324 +89,14 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
</html>`;
}
function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
const orgIds = new Set(
keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
);
return orgIds.size;
}
function increment<T extends keyof ServiceAggregates | ModelAggregateKey>(
map: Map<T, number>,
key: T,
delta = 1
) {
map.set(key, (map.get(key) || 0) + delta);
}
function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
let sumCost = 0;
switch (k.service) {
case "openai":
if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
increment(
serviceStats,
"openAiUncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
});
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
const family = "claude";
sumTokens += k.claudeTokens;
sumCost += getTokenCostUsd(family, k.claudeTokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.claudeTokens);
increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
increment(
serviceStats,
"anthropicUncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
break;
}
case "google-palm": {
if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type");
const family = "bison";
sumTokens += k.bisonTokens;
sumCost += getTokenCostUsd(family, k.bisonTokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.bisonTokens);
break;
}
case "aws": {
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
const family = "aws-claude";
sumTokens += k["aws-claudeTokens"];
sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
// Ignore revoked keys for aws logging stats, but include keys where the
// logging status is unknown.
const countAsLogged =
k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);
break;
}
default:
assertNever(k.service);
}
increment(serviceStats, "tokens", sumTokens);
increment(serviceStats, "tokenCost", sumCost);
}
function getOpenAIInfo() {
const info: { status?: string; openaiKeys?: number; openaiOrgs?: number } & {
[modelFamily in OpenAIModelFamily]?: {
usage?: string;
activeKeys: number;
trialKeys?: number;
revokedKeys?: number;
overQuotaKeys?: number;
proomptersInQueue?: number;
estimatedQueueTime?: string;
};
} = {};
const keys = keyPool.list().filter(keyIsOpenAIKey);
const enabledFamilies = new Set(config.allowedModelFamilies);
const accessibleFamilies = keys
.flatMap((k) => k.modelFamilies)
.filter((f) => enabledFamilies.has(f))
.concat("turbo");
const familySet = new Set(accessibleFamilies);
if (config.checkKeys) {
const unchecked = serviceStats.get("openAiUncheckedKeys") || 0;
if (unchecked > 0) {
info.status = `Checking ${unchecked} keys...`;
}
info.openaiKeys = keys.length;
info.openaiOrgs = getUniqueOpenAIOrgs(keys);
familySet.forEach((f) => {
const tokens = modelStats.get(`${f}__tokens`) || 0;
const cost = getTokenCostUsd(f, tokens);
info[f] = {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: modelStats.get(`${f}__active`) || 0,
trialKeys: modelStats.get(`${f}__trial`) || 0,
revokedKeys: modelStats.get(`${f}__revoked`) || 0,
overQuotaKeys: modelStats.get(`${f}__overQuota`) || 0,
};
// Don't show trial/revoked keys for non-turbo families.
// Generally those stats only make sense for the lowest-tier model.
if (f !== "turbo") {
delete info[f]!.trialKeys;
delete info[f]!.revokedKeys;
}
});
} else {
info.status = "Key checking is disabled.";
info.turbo = { activeKeys: keys.filter((k) => !k.isDisabled).length };
info.gpt4 = {
activeKeys: keys.filter(
(k) => !k.isDisabled && k.modelFamilies.includes("gpt4")
).length,
};
}
familySet.forEach((f) => {
if (enabledFamilies.has(f)) {
if (!info[f]) info[f] = { activeKeys: 0 }; // may occur if checkKeys is disabled
const { estimatedQueueTime, proomptersInQueue } = getQueueInformation(f);
info[f]!.proomptersInQueue = proomptersInQueue;
info[f]!.estimatedQueueTime = estimatedQueueTime;
} else {
(info[f]! as any).status = "GPT-3.5-Turbo is disabled on this proxy.";
}
});
return info;
}
function getAnthropicInfo() {
const claudeInfo: Partial<ModelAggregates> = {
active: modelStats.get("claude__active") || 0,
pozzed: modelStats.get("claude__pozzed") || 0,
revoked: modelStats.get("claude__revoked") || 0,
};
const queue = getQueueInformation("claude");
claudeInfo.queued = queue.proomptersInQueue;
claudeInfo.queueTime = queue.estimatedQueueTime;
const tokens = modelStats.get("claude__tokens") || 0;
const cost = getTokenCostUsd("claude", tokens);
const unchecked =
(config.checkKeys && serviceStats.get("anthropicUncheckedKeys")) || 0;
return {
claude: {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
...(unchecked > 0 ? { status: `Checking ${unchecked} keys...` } : {}),
activeKeys: claudeInfo.active,
revokedKeys: claudeInfo.revoked,
...(config.checkKeys ? { pozzedKeys: claudeInfo.pozzed } : {}),
proomptersInQueue: claudeInfo.queued,
estimatedQueueTime: claudeInfo.queueTime,
},
};
}
function getPalmInfo() {
const bisonInfo: Partial<ModelAggregates> = {
active: modelStats.get("bison__active") || 0,
revoked: modelStats.get("bison__revoked") || 0,
};
const queue = getQueueInformation("bison");
bisonInfo.queued = queue.proomptersInQueue;
bisonInfo.queueTime = queue.estimatedQueueTime;
const tokens = modelStats.get("bison__tokens") || 0;
const cost = getTokenCostUsd("bison", tokens);
return {
bison: {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active,
revokedKeys: bisonInfo.revoked,
proomptersInQueue: bisonInfo.queued,
estimatedQueueTime: bisonInfo.queueTime,
},
};
}
function getAwsInfo() {
const awsInfo: Partial<ModelAggregates> = {
active: modelStats.get("aws-claude__active") || 0,
revoked: modelStats.get("aws-claude__revoked") || 0,
};
const queue = getQueueInformation("aws-claude");
awsInfo.queued = queue.proomptersInQueue;
awsInfo.queueTime = queue.estimatedQueueTime;
const tokens = modelStats.get("aws-claude__tokens") || 0;
const cost = getTokenCostUsd("aws-claude", tokens);
const logged = modelStats.get("aws-claude__awsLogged") || 0;
const logMsg = config.allowAwsLogging
? `${logged} active keys are potentially logged.`
: `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
return {
"aws-claude": {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: awsInfo.active,
revokedKeys: awsInfo.revoked,
proomptersInQueue: awsInfo.queued,
estimatedQueueTime: awsInfo.queueTime,
...(logged > 0 ? { privacy: logMsg } : {}),
},
};
}
function getAzureInfo() {
const azureFamilies = [
"azure-turbo",
"azure-gpt4",
"azure-gpt4-turbo",
"azure-gpt4-32k",
] as const;
const azureInfo: {
[modelFamily in AzureOpenAIModelFamily]?: {
usage?: string;
activeKeys: number;
revokedKeys?: number;
proomptersInQueue?: number;
estimatedQueueTime?: string;
};
} = {};
for (const family of azureFamilies) {
const familyAllowed = config.allowedModelFamilies.includes(family);
const activeKeys = modelStats.get(`${family}__active`) || 0;
if (!familyAllowed || activeKeys === 0) continue;
azureInfo[family] = {
activeKeys,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
};
const queue = getQueueInformation(family);
azureInfo[family]!.proomptersInQueue = queue.proomptersInQueue;
azureInfo[family]!.estimatedQueueTime = queue.estimatedQueueTime;
const tokens = modelStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
azureInfo[family]!.usage = `${prettyTokens(tokens)} tokens${getCostString(
cost
)}`;
}
return azureInfo;
}
const customGreeting = fs.existsSync("greeting.md")
? `\n## Server Greeting\n${fs.readFileSync("greeting.md", "utf8")}`
: "";
/**
* If the server operator provides a `greeting.md` file, it will be included in
* the rendered info page.
**/
function buildInfoPageHeader(converter: showdown.Converter, title: string) {
function buildInfoPageHeader(info: ServiceInfo) {
const title = getServerTitle();
// TODO: use some templating engine instead of this mess
let infoBody = `<!-- Header for Showdown's parser, don't remove this line -->
# ${title}`;
let infoBody = `# ${title}`;
if (config.promptLogging) {
infoBody += `\n## Prompt Logging Enabled
This proxy keeps full logs of all prompts and AI responses. Prompt logs are anonymous and do not contain IP addresses or timestamps.
@@ -510,45 +111,18 @@ This proxy keeps full logs of all prompts and AI responses. Prompt logs are anon
}
const waits: string[] = [];
infoBody += `\n## Estimated Wait Times`;
-  if (config.openaiKey) {
-    // TODO: un-fuck this
-    const keys = keyPool.list().filter((k) => k.service === "openai");
+  for (const modelFamily of config.allowedModelFamilies) {
+    const service = MODEL_FAMILY_SERVICE[modelFamily];
-    const turboWait = getQueueInformation("turbo").estimatedQueueTime;
-    waits.push(`**Turbo:** ${turboWait}`);
+    const hasKeys = keyPool.list().some((k) => {
+      return k.service === service && k.modelFamilies.includes(modelFamily);
+    });
-    const gpt4Wait = getQueueInformation("gpt4").estimatedQueueTime;
-    const hasGpt4 = keys.some((k) => k.modelFamilies.includes("gpt4"));
-    const allowedGpt4 = config.allowedModelFamilies.includes("gpt4");
-    if (hasGpt4 && allowedGpt4) {
-      waits.push(`**GPT-4:** ${gpt4Wait}`);
+    const wait = info[modelFamily]?.estimatedQueueTime;
+    if (hasKeys && wait) {
+      waits.push(`**${MODEL_FAMILY_FRIENDLY_NAME[modelFamily] || modelFamily}**: ${wait}`);
    }
-    const gpt432kWait = getQueueInformation("gpt4-32k").estimatedQueueTime;
-    const hasGpt432k = keys.some((k) => k.modelFamilies.includes("gpt4-32k"));
-    const allowedGpt432k = config.allowedModelFamilies.includes("gpt4-32k");
-    if (hasGpt432k && allowedGpt432k) {
-      waits.push(`**GPT-4-32k:** ${gpt432kWait}`);
-    }
-    const dalleWait = getQueueInformation("dall-e").estimatedQueueTime;
-    const hasDalle = keys.some((k) => k.modelFamilies.includes("dall-e"));
-    const allowedDalle = config.allowedModelFamilies.includes("dall-e");
-    if (hasDalle && allowedDalle) {
-      waits.push(`**DALL-E:** ${dalleWait}`);
-    }
  }
-  if (config.anthropicKey) {
-    const claudeWait = getQueueInformation("claude").estimatedQueueTime;
-    waits.push(`**Claude:** ${claudeWait}`);
-  }
-  if (config.awsCredentials) {
-    const awsClaudeWait = getQueueInformation("aws-claude").estimatedQueueTime;
-    waits.push(`**Claude (AWS):** ${awsClaudeWait}`);
-  }
infoBody += "\n\n" + waits.join(" / ");
@@ -565,21 +139,6 @@ function getSelfServiceLinks() {
return `<footer style="font-size: 0.8em;"><hr /><a target="_blank" href="/user/lookup">Check your user token info</a></footer>`;
}
/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
function getQueueInformation(partition: ModelFamily) {
const waitMs = getEstimatedWaitTime(partition);
const waitTime =
waitMs < 60000
? `${Math.round(waitMs / 1000)}sec`
: `${Math.round(waitMs / 60000)}min, ${Math.round(
(waitMs % 60000) / 1000
)}sec`;
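  // e.g. waitMs = 125_000 yields "2min, 5sec"; waitMs = 1_500 is reported as
  // "no wait" because it falls under the 2000ms threshold below.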
return {
proomptersInQueue: getQueueLength(partition),
estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait",
};
}
function getServerTitle() {
// Use manually set title if available
if (process.env.SERVER_TITLE) {
+1 -10
@@ -173,16 +173,7 @@ anthropicRouter.post(
function maybeReassignModel(req: Request) {
const model = req.body.model;
if (!model.startsWith("gpt-")) return;
-  const bigModel = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k";
-  const contextSize = req.promptTokens! + req.outputTokens!;
-  if (contextSize > 8500) {
-    req.log.debug(
-      { model: bigModel, contextSize },
-      "Using Claude 100k model for OpenAI-to-Anthropic request"
-    );
-    req.body.model = bigModel;
-  }
+  req.body.model = "claude-2.1";
}
export const anthropic = anthropicRouter;
+58
@@ -0,0 +1,58 @@
/* Provides a single endpoint for all services. */
import { RequestHandler } from "express";
import { generateErrorMessage } from "zod-error";
import { APIFormat } from "../shared/key-management";
import {
getServiceForModel,
LLMService,
MODEL_FAMILIES,
MODEL_FAMILY_SERVICE,
ModelFamily,
} from "../shared/models";
import { API_SCHEMA_VALIDATORS } from "../shared/api-schemas";
const detectApiFormat = (body: any, formats: APIFormat[]): APIFormat => {
const errors = [];
for (const format of formats) {
const result = API_SCHEMA_VALIDATORS[format].safeParse(body);
if (result.success) {
return format;
} else {
errors.push(result.error);
}
}
throw new Error(`Couldn't determine the format of your request. Errors: ${errors}`);
};
/**
* Tries to infer LLMService and APIFormat using the model name and the presence
* of certain fields in the request body.
*/
const inferService: RequestHandler = (req, res, next) => {
const model = req.body.model;
if (!model) {
throw new Error("No model specified");
}
// Service determines the key provider and is typically determined by the
// requested model, though some models are served by multiple services.
// API format determines the expected request/response format.
let service: LLMService;
let inboundApi: APIFormat;
let outboundApi: APIFormat;
if (MODEL_FAMILIES.includes(model)) {
service = MODEL_FAMILY_SERVICE[model as ModelFamily];
} else {
service = getServiceForModel(model);
}
// Each service typically has one API format.
switch (service) {
case "openai": {
const detected = detectApiFormat(req.body, ["openai", "openai-text", "openai-image"]);
}
}
};
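// Hedged sketch only: the remaining cases are not part of this (WIP) commit.
// One plausible completion, assuming each non-OpenAI service maps to exactly
// one API format, would be:
//
//   case "anthropic":  inboundApi = outboundApi = "anthropic"; break;
//   case "mistral-ai": inboundApi = outboundApi = "mistral-ai"; break;
//   case "google-ai":  inboundApi = outboundApi = "google-ai"; break;
//
// with the detected OpenAI variant assigned to both inboundApi and
// outboundApi in the "openai" case above before calling next().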
+140
@@ -0,0 +1,140 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.googleAIKey) return { object: "list", data: [] };
const googleAIVariants = ["gemini-pro"];
const models = googleAIVariants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "google",
permission: [],
root: "google",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const googleAIResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
req.log.info("Transforming Google AI response to OpenAI format");
body = transformGoogleAIResponse(body, req);
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
function transformGoogleAIResponse(
resBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
const parts = resBody.candidates[0].content?.parts ?? [{ text: "" }];
const content = parts[0].text.replace(/^(.{0,50}?): /, () => "");
return {
id: "goo-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: { role: "assistant", content },
finish_reason: resBody.candidates[0].finishReason,
index: 0,
},
],
};
}
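// Illustrative example: a Google AI response of
//   { candidates: [{ content: { parts: [{ text: "Nero: Hi!" }] }, finishReason: "STOP" }] }
// becomes an OpenAI-style chat completion whose message content is "Hi!"; the
// regex above strips a short leading "Name: " prefix if one is present.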
const googleAIProxy = createQueueMiddleware({
beforeProxy: addGoogleAIKey,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
const { protocol, hostname, path } = signedRequest;
return `${protocol}//${hostname}${path}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
error: handleProxyError,
},
}),
});
const googleAIRouter = Router();
googleAIRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google AI compatibility endpoint.
googleAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "google-ai", service: "google-ai" },
{ afterTransform: [forceModel("gemini-pro")] }
),
googleAIProxy
);
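// Clients POST ordinary OpenAI chat-completion payloads to this route; the
// preprocessor chain validates and transforms them into Google AI's
// generateContent format and pins the model to gemini-pro.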
export const googleAI = googleAIRouter;
+58 -23
@@ -2,7 +2,7 @@ import { Request, Response } from "express";
import httpProxy from "http-proxy";
import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error";
-import { buildFakeSse } from "../../shared/streaming";
+import { makeCompletionSSE } from "../../shared/streaming";
import { assertNever } from "../../shared/utils";
import { QuotaExceededError } from "./request/preprocessors/apply-quota-limits";
@@ -40,11 +40,13 @@ export function writeErrorResponse(
req: Request,
res: Response,
statusCode: number,
+  statusMessage: string,
errorPayload: Record<string, any>
) {
-  const errorSource = errorPayload.error?.type?.startsWith("proxy")
-    ? "proxy"
-    : "upstream";
+  const msg =
+    statusCode === 500
+      ? `The proxy encountered an error while trying to process your prompt.`
+      : `The proxy encountered an error while trying to send your prompt to the upstream service.`;
// If we're mid-SSE stream, send a data event with the error payload and end
// the stream. Otherwise just send a normal error response.
@@ -52,10 +54,15 @@ export function writeErrorResponse(
res.headersSent ||
String(res.getHeader("content-type")).startsWith("text/event-stream")
) {
-    const errorTitle = `${errorSource} error (${statusCode})`;
-    const errorContent = JSON.stringify(errorPayload, null, 2);
-    const msg = buildFakeSse(errorTitle, errorContent, req);
-    res.write(msg);
+    const event = makeCompletionSSE({
+      format: req.inboundApi,
+      title: `Proxy error (HTTP ${statusCode} ${statusMessage})`,
+      message: `${msg} Further technical details are provided below.`,
+      obj: errorPayload,
+      reqId: req.id,
+      model: req.body?.model,
+    });
+    res.write(event);
res.write(`data: [DONE]\n\n`);
res.end();
} else {
@@ -77,8 +84,9 @@ export const classifyErrorAndSend = (
res: Response
) => {
try {
-    const { status, userMessage, ...errorDetails } = classifyError(err);
-    writeErrorResponse(req, res, status, {
+    const { statusCode, statusMessage, userMessage, ...errorDetails } =
+      classifyError(err);
+    writeErrorResponse(req, res, statusCode, statusMessage, {
error: { message: userMessage, ...errorDetails },
});
} catch (error) {
@@ -88,14 +96,17 @@ export const classifyErrorAndSend = (
function classifyError(err: Error): {
/** HTTP status code returned to the client. */
-  status: number;
+  statusCode: number;
+  /** HTTP status message returned to the client. */
+  statusMessage: string;
/** Message displayed to the user. */
userMessage: string;
/** Short error type, e.g. "proxy_validation_error". */
type: string;
} & Record<string, any> {
const defaultError = {
-    status: 500,
+    statusCode: 500,
+    statusMessage: "Internal Server Error",
userMessage: `Reverse proxy error: ${err.message}`,
type: "proxy_internal_error",
stack: err.stack,
@@ -112,19 +123,33 @@ function classifyError(err: Error): {
return `At '${rest.pathComponent}': ${issue.message}`;
},
});
-      return { status: 400, userMessage, type: "proxy_validation_error" };
-    case "ForbiddenError":
+      return {
+        statusCode: 400,
+        statusMessage: "Bad Request",
+        userMessage,
+        type: "proxy_validation_error",
+      };
case "ZoomerForbiddenError":
// Mimics a ban notice from OpenAI, thrown when blockZoomerOrigins blocks
// a request.
return {
-        status: 403,
+        statusCode: 403,
+        statusMessage: "Forbidden",
userMessage: `Your account has been disabled for violating our terms of service.`,
type: "organization_account_disabled",
code: "policy_violation",
};
case "ForbiddenError":
return {
statusCode: 403,
statusMessage: "Forbidden",
userMessage: `Request is not allowed. (${err.message})`,
type: "proxy_forbidden",
};
case "QuotaExceededError":
return {
-        status: 429,
+        statusCode: 429,
+        statusMessage: "Too Many Requests",
userMessage: `You've exceeded your token quota for this model type.`,
type: "proxy_quota_exceeded",
info: (err as QuotaExceededError).quotaInfo,
@@ -134,21 +159,24 @@ function classifyError(err: Error): {
switch (err.code) {
case "ENOTFOUND":
return {
-          status: 502,
+          statusCode: 502,
+          statusMessage: "Bad Gateway",
userMessage: `Reverse proxy encountered a DNS error while trying to connect to the upstream service.`,
type: "proxy_network_error",
code: err.code,
};
case "ECONNREFUSED":
return {
-          status: 502,
+          statusCode: 502,
+          statusMessage: "Bad Gateway",
userMessage: `Reverse proxy couldn't connect to the upstream service.`,
type: "proxy_network_error",
code: err.code,
};
case "ECONNRESET":
return {
-          status: 504,
+          statusCode: 504,
+          statusMessage: "Gateway Timeout",
userMessage: `Reverse proxy timed out while waiting for the upstream service to respond.`,
type: "proxy_network_error",
code: err.code,
@@ -165,7 +193,10 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
const format = req.outboundApi;
switch (format) {
case "openai":
-      return body.choices[0].message.content;
+    case "mistral-ai":
+      // Can be null if the model wants to invoke tools rather than return a
+      // completion.
+      return body.choices[0].message.content || "";
case "openai-text":
return body.choices[0].text;
case "anthropic":
@@ -177,8 +208,11 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
return "";
}
return body.completion.trim();
case "google-palm":
return body.candidates[0].output;
case "google-ai":
if ("choices" in body) {
return body.choices[0].message.content;
}
return body.candidates[0].content.parts[0].text;
case "openai-image":
return body.data?.map((item: any) => item.url).join("\n");
default:
@@ -191,13 +225,14 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
switch (format) {
case "openai":
case "openai-text":
case "mistral-ai":
return body.model;
case "openai-image":
return req.body.model;
case "anthropic":
// Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model;
case "google-palm":
case "google-ai":
// Google doesn't confirm the model in the response.
return req.body.model;
default:
@@ -29,7 +29,9 @@ export const createOnProxyReqHandler = ({
// The streaming flag must be set before any other onProxyReq handler runs,
// as it may influence the behavior of subsequent handlers.
// Image generation requests can't be streamed.
-    req.isStreaming = req.body.stream === true || req.body.stream === "true";
+    // TODO: this flag is set in too many places
+    req.isStreaming =
+      req.isStreaming || req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;
try {
@@ -31,10 +31,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
case "anthropic":
assignedKey = keyPool.get("claude-v1");
break;
case "google-palm":
assignedKey = keyPool.get("text-bison-001");
delete req.body.stream;
break;
case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct");
break;
@@ -42,6 +38,10 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
throw new Error(
"OpenAI Chat as an API translation target is not supported"
);
case "google-ai":
throw new Error("add-key should not be used for this model.");
case "mistral-ai":
throw new Error("Mistral AI should never be translated");
case "openai-image":
assignedKey = keyPool.get("dall-e-3");
break;
@@ -71,23 +71,16 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
if (key.organizationId) {
proxyReq.setHeader("OpenAI-Organization", key.organizationId);
}
case "mistral-ai":
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
break;
case "google-palm":
const originalPath = proxyReq.path;
proxyReq.path = originalPath.replace(
/(\?.*)?$/,
`?key=${assignedKey.key}`
);
break;
case "azure":
const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey);
break;
case "aws":
throw new Error(
"add-key should not be used for AWS security credentials. Use sign-aws-request instead."
);
case "google-ai":
throw new Error("add-key should not be used for this service.");
default:
assertNever(assignedKey.service);
}
@@ -2,10 +2,10 @@ import { HPMRequestCallback } from "../index";
const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
-class ForbiddenError extends Error {
+class ZoomerForbiddenError extends Error {
  constructor(message: string) {
    super(message);
-    this.name = "ForbiddenError";
+    this.name = "ZoomerForbiddenError";
}
}
@@ -22,7 +22,7 @@ export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => {
return;
}
-  throw new ForbiddenError(
+  throw new ZoomerForbiddenError(
`Your access was terminated due to violation of our policies, please check your email for more information. If you believe this is in error and would like to appeal, please contact us through our help center at help.openai.com.`
);
}
@@ -1,13 +1,14 @@
import { HPMRequestCallback } from "../index";
import { config } from "../../../../config";
import { ForbiddenError } from "../../../../shared/errors";
import { getModelFamilyForRequest } from "../../../../shared/models";
/**
* Ensures the selected model family is enabled by the proxy configuration.
**/
-export const checkModelFamily: HPMRequestCallback = (proxyReq, req) => {
+export const checkModelFamily: HPMRequestCallback = (_proxyReq, req, res) => {
  const family = getModelFamilyForRequest(req);
  if (!config.allowedModelFamilies.includes(family)) {
-    throw new Error(`Model family ${family} is not permitted on this proxy`);
+    throw new ForbiddenError(`Model family '${family}' is not enabled on this proxy`);
}
};
@@ -1,9 +1,9 @@
import type { HPMRequestCallback } from "../index";
/**
- * For AWS/Azure requests, the body is signed earlier in the request pipeline,
- * before the proxy middleware. This function just assigns the path and headers
- * to the proxy request.
+ * For AWS/Azure/Google requests, the body is signed earlier in the request
+ * pipeline, before the proxy middleware. This function just assigns the path
+ * and headers to the proxy request.
*/
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
if (!req.signedRequest) {
@@ -18,6 +18,22 @@ export const addAzureKey: RequestPreprocessor = (req) => {
req.key = keyPool.get(model);
req.body.model = model;
// Handles the sole Azure API deviation from the OpenAI spec (that I know of)
const notNullOrUndefined = (x: any) => x !== null && x !== undefined;
if ([req.body.logprobs, req.body.top_logprobs].some(notNullOrUndefined)) {
// OpenAI wants logprobs: true/false and top_logprobs: number
// Azure seems to just want to combine them into logprobs: number
// if (typeof req.body.logprobs === "boolean") {
// req.body.logprobs = req.body.top_logprobs || undefined;
// delete req.body.top_logprobs
// }
// Temporarily just disabling logprobs for Azure because their model support
// is random: `This model does not support the 'logprobs' parameter.`
delete req.body.logprobs;
delete req.body.top_logprobs;
}
req.log.info(
{ key: req.key.hash, model },
@@ -0,0 +1,40 @@
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const addGoogleAIKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
const serviceValid = req.service === "google-ai";
if (!apisValid || !serviceValid) {
throw new Error("addGoogleAIKey called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
const model = req.body.model;
req.key = keyPool.get(model);
req.log.info(
{ key: req.key.hash, model },
"Assigned Google AI API key to request"
);
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
req.isStreaming = req.isStreaming || req.body.stream;
delete req.body.stream;
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: "generativelanguage.googleapis.com",
path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
headers: {
["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json",
},
body: JSON.stringify(req.body),
};
};
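// e.g. for model "gemini-pro" with streaming enabled, the signed path becomes:
//   /v1beta/models/gemini-pro:streamGenerateContent?key=<API_KEY>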
@@ -1,7 +1,11 @@
import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
-import type { OpenAIChatMessage } from "./transform-outbound-payload";
+import {
+  GoogleAIChatMessage,
+  MistralAIChatMessage,
+  OpenAIChatMessage,
+} from "../../../../shared/api-schemas";
/**
* Given a request with an already-transformed body, counts the number of
@@ -30,9 +34,15 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
case "google-palm": {
req.outputTokens = req.body.maxOutputTokens;
const prompt: string = req.body.prompt.text;
case "google-ai": {
req.outputTokens = req.body.generationConfig.maxOutputTokens;
const prompt: GoogleAIChatMessage[] = req.body.contents;
result = await countTokens({ req, prompt, service });
break;
}
case "mistral-ai": {
req.outputTokens = req.body.max_tokens;
const prompt: MistralAIChatMessage[] = req.body.messages;
result = await countTokens({ req, prompt, service });
break;
}
@@ -3,7 +3,10 @@ import { config } from "../../../../config";
import { assertNever } from "../../../../shared/utils";
import { RequestPreprocessor } from "../index";
import { UserInputError } from "../../../../shared/errors";
import { OpenAIChatMessage } from "./transform-outbound-payload";
import {
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../../shared/api-schemas";
const rejectedClients = new Map<string, number>();
@@ -53,8 +56,9 @@ function getPromptFromRequest(req: Request) {
case "anthropic":
return body.prompt;
case "openai":
case "mistral-ai":
return body.messages
.map((msg: OpenAIChatMessage) => {
.map((msg: OpenAIChatMessage | MistralAIChatMessage) => {
const text = Array.isArray(msg.content)
? msg.content
.map((c) => {
@@ -68,7 +72,7 @@ function getPromptFromRequest(req: Request) {
case "openai-text":
case "openai-image":
return body.prompt;
case "google-palm":
case "google-ai":
return body.prompt.text;
default:
assertNever(service);
@@ -1,13 +1,14 @@
import { Request } from "express";
-import { APIFormat, LLMService } from "../../../../shared/key-management";
+import { APIFormat } from "../../../../shared/key-management";
+import { LLMService } from "../../../../shared/models";
import { RequestPreprocessor } from "../index";
export const setApiFormat = (api: {
inApi: Request["inboundApi"];
outApi: APIFormat;
-  service: LLMService,
+  service: LLMService;
}): RequestPreprocessor => {
-  return function configureRequestApiFormat (req) {
+  return function configureRequestApiFormat(req) {
req.inboundApi = api.inApi;
req.outboundApi = api.outApi;
req.service = api.service;
@@ -2,9 +2,9 @@ import express from "express";
import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http";
import { AnthropicV1CompleteSchema } from "../../../../shared/api-schemas/anthropic";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { AnthropicV1CompleteSchema } from "./transform-outbound-payload";
const AMZ_HOST =
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
@@ -32,7 +32,9 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
temperature: true,
top_k: true,
top_p: true,
-  }).parse(req.body);
+  })
+    .strip()
+    .parse(req.body);
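  // .strip() removes unrecognized keys from the parsed result, so unsupported
  // fields are dropped from the body before it is signed and sent to Bedrock.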
const credential = getCredentialParts(req);
const host = AMZ_HOST.replace("%REGION%", credential.region);
@@ -68,6 +70,7 @@ type Credential = {
secretAccessKey: string;
region: string;
};
function getCredentialParts(req: express.Request): Credential {
const [accessKeyId, secretAccessKey, region] = req.key!.key.split(":");
@@ -1,151 +1,14 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../../config";
-import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
+import {
+  isImageGenerationRequest,
+  isTextGenerationRequest,
+} from "../../common";
import { RequestPreprocessor } from "../index";
import { APIFormat } from "../../../../shared/key-management";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// TODO: move schemas to shared
// https://console.anthropic.com/docs/api/reference#-v1-complete
export const AnthropicV1CompleteSchema = z.object({
model: z.string(),
prompt: z.string({
required_error:
"No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
}),
max_tokens_to_sample: z.coerce
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
stop_sequences: z.array(z.string()).optional(),
stream: z.boolean().optional().default(false),
temperature: z.coerce.number().optional().default(1),
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
metadata: z.any().optional(),
});
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatContentArraySchema = z.array(
z.union([
z.object({ type: z.literal("text"), text: z.string() }),
z.object({
type: z.literal("image_url"),
image_url: z.object({
url: z.string().url(),
detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
}),
}),
])
);
export const OpenAIV1ChatCompletionSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
name: z.string().optional(),
}),
{
required_error:
"No `messages` found. Ensure you've set the correct completion endpoint.",
invalid_type_error:
"Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
}
),
temperature: z.number().optional().default(1),
top_p: z.number().optional().default(1),
n: z
.literal(1, {
errorMap: () => ({
message: "You may only request a single completion at a time.",
}),
})
.optional(),
stream: z.boolean().optional().default(false),
stop: z.union([z.string(), z.array(z.string())]).optional(),
max_tokens: z.coerce
.number()
.int()
.nullish()
.default(16)
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
logit_bias: z.any().optional(),
user: z.string().optional(),
seed: z.number().int().optional(),
});
export type OpenAIChatMessage = z.infer<
typeof OpenAIV1ChatCompletionSchema
>["messages"][0];
const OpenAIV1TextCompletionSchema = z
.object({
model: z
.string()
.regex(
/^gpt-3.5-turbo-instruct/,
"Model must start with 'gpt-3.5-turbo-instruct'"
),
prompt: z.string({
required_error:
"No `prompt` found. Ensure you've set the correct completion endpoint.",
}),
logprobs: z.number().int().nullish().default(null),
echo: z.boolean().optional().default(false),
best_of: z.literal(1).optional(),
stop: z.union([z.string(), z.array(z.string()).max(4)]).optional(),
suffix: z.string().optional(),
})
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true }));
// https://platform.openai.com/docs/api-reference/images/create
const OpenAIV1ImagesGenerationSchema = z.object({
prompt: z.string().max(4000),
model: z.string().optional(),
quality: z.enum(["standard", "hd"]).optional().default("standard"),
n: z.number().int().min(1).max(4).optional().default(1),
response_format: z.enum(["url", "b64_json"]).optional(),
size: z
.enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
.optional()
.default("1024x1024"),
style: z.enum(["vivid", "natural"]).optional().default("vivid"),
user: z.string().optional(),
});
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
const PalmV1GenerateTextSchema = z.object({
model: z.string(),
prompt: z.object({ text: z.string() }),
temperature: z.number().optional(),
maxOutputTokens: z.coerce
.number()
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 1024)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
stopSequences: z.array(z.string()).max(5).optional(),
});
const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
anthropic: AnthropicV1CompleteSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-palm": PalmV1GenerateTextSchema,
};
import { openAIToAnthropic } from "../../../../shared/api-schemas/anthropic";
import { openAIToOpenAIText } from "../../../../shared/api-schemas/openai-text";
import { openAIToOpenAIImage } from "../../../../shared/api-schemas/openai-image";
import { openAIToGoogleAI } from "../../../../shared/api-schemas/google-ai";
import { fixMistralPrompt } from "../../../../shared/api-schemas/mistral-ai";
import { API_SCHEMA_VALIDATORS } from "../../../../shared/api-schemas";
/** Transforms an incoming request body to one that matches the target API. */
export const transformOutboundPayload: RequestPreprocessor = async (req) => {
@@ -156,8 +19,17 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
if (alreadyTransformed || notTransformable) return;
if (req.inboundApi === "mistral-ai") {
const messages = req.body.messages;
req.body.messages = fixMistralPrompt(messages);
req.log.info(
{ old: messages.length, new: req.body.messages.length },
"Fixed Mistral prompt"
);
}
if (sameService) {
-    const result = VALIDATORS[req.inboundApi].safeParse(req.body);
+    const result = API_SCHEMA_VALIDATORS[req.inboundApi].safeParse(req.body);
if (!result.success) {
req.log.error(
{ issues: result.error.issues, body: req.body },
@@ -170,22 +42,22 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
}
if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
-    req.body = openaiToAnthropic(req);
+    req.body = openAIToAnthropic(req);
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "google-palm") {
req.body = openaiToPalm(req);
if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
req.body = openAIToGoogleAI(req);
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "openai-text") {
-    req.body = openaiToOpenaiText(req);
+    req.body = openAIToOpenAIText(req);
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "openai-image") {
-    req.body = openaiToOpenaiImage(req);
+    req.body = openAIToOpenAIImage(req);
return;
}
@@ -193,238 +65,3 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
`'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
);
};
function openaiToAnthropic(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Anthropic request"
);
throw result.error;
}
req.headers["anthropic-version"] = "2023-06-01";
const { messages, ...rest } = result.data;
const prompt = openAIMessagesToClaudePrompt(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
// Recommended by Anthropic
stops.push("\n\nHuman:");
// Helps with jailbreak prompts that send fake system messages and multi-bot
// chats that prefix bot messages with "System: Respond as <bot name>".
stops.push("\n\nSystem:");
// Remove duplicates
stops = [...new Set(stops)];
return {
// Model may be overridden in `calculate-context-size.ts` to avoid having
// a circular dependency (`calculate-context-size.ts` needs an already-
// transformed request body to count tokens, but this function would like
// to know the count to select a model).
model: process.env.CLAUDE_SMALL_MODEL || "claude-v1",
prompt: prompt,
max_tokens_to_sample: rest.max_tokens,
stop_sequences: stops,
stream: rest.stream,
temperature: rest.temperature,
top_p: rest.top_p,
};
}
function openaiToOpenaiText(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-text request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
stops.push("\n\nUser:");
stops = [...new Set(stops)];
const transformed = { ...rest, prompt: prompt, stop: stops };
return OpenAIV1TextCompletionSchema.parse(transformed);
}
// Takes the last chat message and uses it verbatim as the image prompt.
function openaiToOpenaiImage(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-image request"
);
throw result.error;
}
const { messages } = result.data;
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
if (Array.isArray(prompt)) {
throw new Error("Image generation prompt must be a text message.");
}
if (body.stream) {
throw new Error(
"Streaming is not supported for image generation requests."
);
}
// Some frontends do weird things with the prompt, like prefixing it with a
// character name or wrapping the entire thing in quotes. We will look for
// the index of "Image:" and use everything after that as the prompt.
const index = prompt?.toLowerCase().indexOf("image:");
if (index === -1 || !prompt) {
throw new Error(
`Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`
);
}
// TODO: Add some way to specify parameters via chat message
const transformed = {
model: body.model.includes("dall-e") ? body.model : "dall-e-3",
quality: "standard",
size: "1024x1024",
response_format: "url",
prompt: prompt.slice(index! + 6).trim(),
};
return OpenAIV1ImagesGenerationSchema.parse(transformed);
}
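// e.g. a final user message of `Draw something. Image: a red fox in
// watercolor` yields the DALL-E prompt `a red fox in watercolor`; everything
// up to and including "Image:" is discarded.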
function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({
...body,
model: "gpt-3.5-turbo",
});
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Palm request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
stops.push("\n\nUser:");
stops = [...new Set(stops)];
z.array(z.string()).max(5).parse(stops);
return {
prompt: { text: prompt },
maxOutputTokens: rest.max_tokens,
stopSequences: stops,
model: "text-bison-001",
topP: rest.top_p,
temperature: rest.temperature,
safetySettings: [
{ category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DEROGATORY", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_TOXICITY", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_VIOLENCE", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_SEXUAL", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_MEDICAL", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DANGEROUS", threshold: "BLOCK_NONE" },
],
};
}
export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
const name = m.name?.trim();
const content = flattenOpenAIMessageContent(m.content);
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
})
.join("") + "\n\nAssistant:"
);
}
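// e.g. [{ role: "system", content: "Be terse." }, { role: "user", content: "Hi" }]
// produces "\n\nSystem: Be terse.\n\nHuman: Hi\n\nAssistant:".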
function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
// Temporary to allow experimenting with prompt strategies
const PROMPT_VERSION: number = 1;
switch (PROMPT_VERSION) {
case 1:
return (
messages
.map((m) => {
// Claude-style human/assistant turns
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "User";
}
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
})
.join("") + "\n\nAssistant:"
);
case 2:
return messages
.map((m) => {
// Claude without prefixes (except system) and no Assistant priming
let role: string = "";
if (role === "system") {
role = "System: ";
}
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
})
.join("");
default:
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
}
}
function flattenOpenAIMessageContent(
content: OpenAIChatMessage["content"]
): string {
return Array.isArray(content)
? content
.map((contentItem) => {
if ("text" in contentItem) return contentItem.text;
if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
})
.join("\n")
: content;
}
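// e.g. an array content of [{ type: "text", text: "Describe this" },
// { type: "image_url", ... }] flattens to
// "Describe this\n[ Uploaded Image Omitted ]".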
@@ -6,7 +6,8 @@ import { RequestPreprocessor } from "../index";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
-const BISON_MAX_CONTEXT = 8100;
+const GOOGLE_AI_MAX_CONTEXT = 32000;
+const MISTRAL_AI_MAX_CONTENT = 32768;
/**
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -31,8 +32,11 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
case "anthropic":
proxyMax = CLAUDE_MAX_CONTEXT;
break;
case "google-palm":
proxyMax = BISON_MAX_CONTEXT;
case "google-ai":
proxyMax = GOOGLE_AI_MAX_CONTEXT;
break;
case "mistral-ai":
proxyMax = MISTRAL_AI_MAX_CONTENT;
break;
case "openai-image":
return;
@@ -44,7 +48,9 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
let modelMax: number;
if (model.match(/gpt-3.5-turbo-16k/)) {
modelMax = 16384;
-  } else if (model.match(/gpt-4-1106(-preview)?/)) {
+  } else if (model.match(/gpt-4-turbo(-preview)?$/)) {
    modelMax = 131072;
+  } else if (model.match(/gpt-4-(0125|1106)(-preview)?$/)) {
+    modelMax = 131072;
} else if (model.match(/^gpt-4(-\d{4})?-vision(-preview)?$/)) {
modelMax = 131072;
@@ -62,8 +68,12 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 100000;
} else if (model.match(/^claude-2/)) {
modelMax = 200000;
-  } else if (model.match(/^text-bison-\d{3}$/)) {
-    modelMax = BISON_MAX_CONTEXT;
+  } else if (model.match(/^gemini-\d{3}$/)) {
+    modelMax = GOOGLE_AI_MAX_CONTEXT;
+  } else if (model.match(/^mistral-(tiny|small|medium)$/)) {
+    modelMax = MISTRAL_AI_MAX_CONTENT;
+  } else if (model.match(/^anthropic\.claude-v2:\d/)) {
+    modelMax = 200000;
} else if (model.match(/^anthropic\.claude/)) {
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
modelMax = 100000;
@@ -1,8 +1,7 @@
import express from "express";
import { pipeline } from "stream";
import { promisify } from "util";
import {
-  buildFakeSse,
+  makeCompletionSSE,
copySseResponseHeaders,
initializeSseStream,
} from "../../../shared/streaming";
@@ -16,14 +15,18 @@ import { keyPool } from "../../../shared/key-management";
const pipelineAsync = promisify(pipeline);
/**
- * Consume the SSE stream and forward events to the client. Once the stream
- * is closed, resolve with the full response body so that subsequent
- * middleware can work with it.
+ * `handleStreamedResponse` consumes and transforms a streamed response from the
+ * upstream service, forwarding events to the client in their requested format.
+ * After the entire stream has been consumed, it resolves with the full response
+ * body so that subsequent middleware in the chain can process it as if it were
+ * a non-streaming response.
 *
- * Typically we would only need one of the raw response handlers to execute,
- * but in the event a streamed request results in a non-200 response, we need
- * to fall back to the non-streaming response handler so that the error
- * handler can inspect the error response.
+ * In the event of an error, the request's streaming flag is unset and the non-
+ * streaming response handler is called instead.
+ *
+ * If the error is retryable, that handler will re-enqueue the request and also
+ * reset the streaming flag. Unfortunately the streaming flag is set and unset
+ * in multiple places, so it's hard to keep track of.
*/
export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes,
@@ -49,8 +52,8 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
`Starting to proxy SSE stream.`
);
-  // Users waiting in the queue already have a SSE connection open for the
-  // heartbeat, so we can't always send the stream headers.
+  // Typically, streaming will have already been initialized by the request
+  // queue to send heartbeat pings.
if (!res.headersSent) {
copySseResponseHeaders(proxyRes, res);
initializeSseStream(res);
@@ -59,8 +62,11 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
const prefersNativeEvents = req.inboundApi === req.outboundApi;
const contentType = proxyRes.headers["content-type"];
-  const adapter = new SSEStreamAdapter({ contentType });
+  // Adapter turns some arbitrary stream (binary, JSON, etc.) into SSE events.
+  const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
+  // Aggregator compiles all events into a single response object.
  const aggregator = new EventAggregator({ format: req.outboundApi });
+  // Transformer converts events to the user's requested format.
const transformer = new SSEMessageTransformer({
inputFormat: req.outboundApi,
inputApiVersion: String(req.headers["anthropic-version"]),
@@ -89,10 +95,20 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
`Re-enqueueing request due to retryable error during streaming response.`
);
req.retryCount++;
-      enqueue(req);
+      await enqueue(req);
} else {
-      const errorEvent = buildFakeSse("stream-error", err.message, req);
-      res.write(`${errorEvent}data: [DONE]\n\n`);
+      const { message, stack, lastEvent } = err;
+      const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
+      const errorEvent = makeCompletionSSE({
+        format: req.inboundApi,
+        title: "Proxy stream error",
+        message: "An unexpected error occurred while streaming the response.",
+        obj: { message, stack, lastEvent: eventText },
+        reqId: req.id,
+        model: req.body?.model,
+      });
+      res.write(errorEvent);
+      res.write(`data: [DONE]\n\n`);
res.end();
}
throw err;
+112 -33
@@ -152,13 +152,13 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
};
};
-function reenqueueRequest(req: Request) {
+async function reenqueueRequest(req: Request) {
req.log.info(
{ key: req.key?.hash, retryCount: req.retryCount },
`Re-enqueueing request due to retryable error`
);
req.retryCount++;
-  enqueue(req);
+  await enqueue(req);
}
/**
@@ -192,7 +192,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
} else {
const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
req.log.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
-      writeErrorResponse(req, res, 500, {
+      writeErrorResponse(req, res, 500, "Internal Server Error", {
error: errorMessage,
contentEncoding,
});
@@ -209,7 +209,9 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
} catch (error: any) {
const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
req.log.warn({ error: error.stack, key: req.key?.hash }, errorMessage);
-        writeErrorResponse(req, res, 500, { error: errorMessage });
+        writeErrorResponse(req, res, 500, "Internal Server Error", {
+          error: errorMessage,
+        });
return reject(errorMessage);
}
});
@@ -237,6 +239,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
body
) => {
const statusCode = proxyRes.statusCode || 500;
const statusMessage = proxyRes.statusMessage || "Internal Server Error";
if (statusCode < 400) {
return;
@@ -253,16 +256,16 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} catch (parseError) {
// Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
const hash = req.key?.hash;
const statusMessage = proxyRes.statusMessage || "Unknown error";
req.log.warn({ statusCode, statusMessage, key: hash }, parseError.message);
const errorObject = {
statusCode,
statusMessage: proxyRes.statusMessage,
error: parseError.message,
proxy_note: `This is likely a temporary error with the upstream service.`,
status: statusCode,
statusMessage,
proxy_note: `Proxy got back an error, but it was not in JSON format. This is likely a temporary problem with the upstream service.`,
};
writeErrorResponse(req, res, statusCode, errorObject);
writeErrorResponse(req, res, statusCode, statusMessage, errorObject);
throw new HttpError(statusCode, parseError.message);
}
@@ -288,7 +291,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// For Anthropic, this is usually due to missing preamble.
switch (service) {
case "openai":
case "google-palm":
case "google-ai":
case "mistral-ai":
case "azure":
const filteredCodes = ["content_policy_violation", "content_filter"];
if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -297,14 +301,14 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (errorPayload.error?.code === "billing_hard_limit_reached") {
// For some reason, some models return this 400 error instead of the
// same 429 billing error that other models return.
-          handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
+          await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
} else {
errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
}
break;
case "anthropic":
case "aws":
-        maybeHandleMissingPreambleError(req, errorPayload);
+        await maybeHandleMissingPreambleError(req, errorPayload);
break;
default:
assertNever(service);
@@ -314,7 +318,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 403) {
-    // Amazon is the only service that returns 403.
+    if (service === "anthropic") {
+      keyPool.disable(req.key!, "revoked");
+      errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
+      return;
+    }
switch (errorType) {
case "UnrecognizedClientException":
// Key is invalid.
@@ -335,19 +343,20 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
} else if (statusCode === 429) {
switch (service) {
case "openai":
-        handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
+        await handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
break;
case "anthropic":
-        handleAnthropicRateLimitError(req, errorPayload);
+        await handleAnthropicRateLimitError(req, errorPayload);
break;
case "aws":
-        handleAwsRateLimitError(req, errorPayload);
+        await handleAwsRateLimitError(req, errorPayload);
break;
case "azure":
-        handleAzureRateLimitError(req, errorPayload);
+      case "mistral-ai":
+        await handleAzureRateLimitError(req, errorPayload);
break;
case "google-palm":
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
case "google-ai":
await handleGoogleAIRateLimitError(req, errorPayload);
break;
default:
assertNever(service);
@@ -369,8 +378,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "anthropic":
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
break;
case "google-palm":
errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
case "google-ai":
errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
break;
case "mistral-ai":
errorPayload.proxy_note = `The requested Mistral AI model might not exist, or the key might not be provisioned for it.`;
break;
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
@@ -393,7 +405,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
);
}
-  writeErrorResponse(req, res, statusCode, errorPayload);
+  writeErrorResponse(req, res, statusCode, statusMessage, errorPayload);
throw new HttpError(statusCode, errorPayload.error?.message);
};
@@ -416,7 +428,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
* }
* ```
*/
-function maybeHandleMissingPreambleError(
+async function maybeHandleMissingPreambleError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
@@ -429,27 +441,27 @@ function maybeHandleMissingPreambleError(
"Request failed due to missing preamble. Key will be marked as such for subsequent requests."
);
keyPool.update(req.key!, { requiresPreamble: true });
-    reenqueueRequest(req);
+    await reenqueueRequest(req);
throw new RetryableError("Claude request re-enqueued to add preamble.");
} else {
errorPayload.proxy_note = `Proxy received unrecognized error from Anthropic. Check the specific error for more information.`;
}
}
-function handleAnthropicRateLimitError(
+async function handleAnthropicRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
if (errorPayload.error?.type === "rate_limit_error") {
keyPool.markRateLimited(req.key!);
-    reenqueueRequest(req);
+    await reenqueueRequest(req);
throw new RetryableError("Claude rate-limited request re-enqueued.");
} else {
errorPayload.proxy_note = `Unrecognized rate limit error from Anthropic. Key may be over quota.`;
}
}
-function handleAwsRateLimitError(
+async function handleAwsRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
@@ -457,7 +469,7 @@ function handleAwsRateLimitError(
switch (errorType) {
case "ThrottlingException":
keyPool.markRateLimited(req.key!);
-      reenqueueRequest(req);
+      await reenqueueRequest(req);
throw new RetryableError("AWS rate-limited request re-enqueued.");
case "ModelNotReadyException":
errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`;
@@ -467,11 +479,11 @@ function handleAwsRateLimitError(
}
}
-function handleOpenAIRateLimitError(
+async function handleOpenAIRateLimitError(
  req: Request,
  tryAgainMessage: string,
  errorPayload: ProxiedErrorPayload
-): Record<string, any> {
+): Promise<Record<string, any>> {
const type = errorPayload.error?.type;
switch (type) {
case "insufficient_quota":
@@ -500,8 +512,58 @@ function handleOpenAIRateLimitError(
}
// Per-minute request or token rate limit is exceeded, which we can retry
-      reenqueueRequest(req);
+      await reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
// WIP/nonfunctional
// case "tokens_usage_based":
// // Weird new rate limit type that seems limited to preview models.
// // Distinct from `tokens` type. Can be per-minute or per-day.
//
// // I've seen reports of this error for 500k tokens/day and 10k tokens/min.
// // 10k tokens per minute is problematic, because this is much less than
// // GPT4-Turbo's max context size for a single prompt and is effectively a
// // cap on the max context size for just that key+model, which the app is
// // not able to deal with.
//
// // Similarly if there is a 500k tokens per day limit and 450k tokens have
// // been used today, the max context for that key becomes 50k tokens until
// // the next day and becomes progressively smaller as more tokens are used.
//
// // To work around these keys we will first retry the request a few times.
// // After that we will reject the request, and if it's a per-day limit we
// // will also disable the key.
//
// // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per day: Limit 500000, Used 460000, Requested 50000"
// // "Rate limit reached for gpt-4-1106-preview in organization org-xxxxxxxxxxxxxxxxxxx on tokens_usage_based per min: Limit 10000, Requested 40000"
//
// const regex =
// /Rate limit reached for .+ in organization .+ on \w+ per (day|min): Limit (\d+)(?:, Used (\d+))?, Requested (\d+)/;
// const [, period, limit, used, requested] =
// errorPayload.error?.message?.match(regex) || [];
//
// req.log.warn(
// { key: req.key?.hash, period, limit, used, requested },
// "Received `tokens_usage_based` rate limit error from OpenAI."
// );
//
// if (!period || !limit || !requested) {
// errorPayload.proxy_note = `Unrecognized rate limit error from OpenAI. (${errorPayload.error?.message})`;
// break;
// }
//
// if (req.retryCount < 2) {
// await reenqueueRequest(req);
// throw new RetryableError("Rate-limited request re-enqueued.");
// }
//
// if (period === "min") {
// errorPayload.proxy_note = `Assigned key can't be used for prompts longer than ${limit} tokens, and no other keys are available right now. Reduce the length of your prompt or try again in a few minutes.`;
// } else {
// errorPayload.proxy_note = `Assigned key has reached its per-day request limit for this model. Try another model.`;
// }
//
// keyPool.markRateLimited(req.key!);
// break;
default:
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
break;
@@ -509,7 +571,7 @@ function handleOpenAIRateLimitError(
return errorPayload;
}
-function handleAzureRateLimitError(
+async function handleAzureRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
@@ -517,7 +579,7 @@ function handleAzureRateLimitError(
switch (code) {
case "429":
keyPool.markRateLimited(req.key!);
-      reenqueueRequest(req);
+      await reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
default:
errorPayload.proxy_note = `Unrecognized rate limit error from Azure (${code}). Please report this.`;
@@ -525,6 +587,23 @@ function handleAzureRateLimitError(
}
}
//{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}
async function handleGoogleAIRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
const status = errorPayload.error?.status;
switch (status) {
case "RESOURCE_EXHAUSTED":
keyPool.markRateLimited(req.key!);
await reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
default:
errorPayload.proxy_note = `Unrecognized rate limit error from Google AI (${status}). Please report this.`;
break;
}
}
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
const model = req.body.model;
+8 -4
@@ -9,7 +9,10 @@ import {
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import { OpenAIChatMessage } from "../request/preprocessors/transform-outbound-payload";
import {
MistralAIChatMessage,
OpenAIChatMessage,
} from "../../../shared/api-schemas";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@@ -54,12 +57,13 @@ type OaiImageResult = {
const getPromptForRequest = (
req: Request,
responseBody: Record<string, any>
-): string | OpenAIChatMessage[] | OaiImageResult => {
+): string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult => {
// Since the prompt logger only runs after the request has been proxied, we
// can assume the body has already been transformed to the target API's
// format.
switch (req.outboundApi) {
case "openai":
case "mistral-ai":
return req.body.messages;
case "openai-text":
return req.body.prompt;
@@ -73,7 +77,7 @@ const getPromptForRequest = (
};
case "anthropic":
return req.body.prompt;
case "google-palm":
case "google-ai":
return req.body.prompt.text;
default:
assertNever(req.outboundApi);
@@ -81,7 +85,7 @@ const getPromptForRequest = (
};
const flattenMessages = (
-  val: string | OpenAIChatMessage[] | OaiImageResult
+  val: string | OpenAIChatMessage[] | MistralAIChatMessage[] | OaiImageResult
): string => {
if (typeof val === "string") {
return val.trim();
@@ -4,7 +4,7 @@ import {
mergeEventsForAnthropic,
mergeEventsForOpenAIChat,
mergeEventsForOpenAIText,
-  OpenAIChatCompletionStreamEvent
+  OpenAIChatCompletionStreamEvent,
} from "./index";
/**
@@ -27,12 +27,13 @@ export class EventAggregator {
getFinalResponse() {
switch (this.format) {
case "openai":
case "google-ai":
case "mistral-ai":
return mergeEventsForOpenAIChat(this.events);
case "openai-text":
return mergeEventsForOpenAIText(this.events);
case "anthropic":
return mergeEventsForAnthropic(this.events);
case "google-palm":
case "openai-image":
throw new Error(`SSE aggregation not supported for ${this.format}`);
default:
@@ -25,6 +25,8 @@ export type StreamingCompletionTransformer = (
export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropic } from "./aggregators/anthropic";
@@ -7,9 +7,10 @@ import {
anthropicV2ToOpenAI,
OpenAIChatCompletionStreamEvent,
openAITextToOpenAIChat,
+  googleAIToOpenAI,
+  passthroughToOpenAI,
  StreamingCompletionTransformer,
} from "./index";
-import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
const genlog = logger.child({ module: "sse-transformer" });
@@ -92,6 +93,7 @@ export class SSEMessageTransformer extends Transform {
this.push(transformedMessage);
callback();
} catch (err) {
err.lastEvent = chunk?.toString();
this.log.error(err, "Error transforming SSE message");
callback(err);
}
@@ -104,6 +106,7 @@ function getTransformer(
): StreamingCompletionTransformer {
switch (responseApi) {
case "openai":
case "mistral-ai":
return passthroughToOpenAI;
case "openai-text":
return openAITextToOpenAIChat;
@@ -111,7 +114,8 @@ function getTransformer(
return version === "2023-01-01"
? anthropicV1ToOpenAI
: anthropicV2ToOpenAI;
case "google-palm":
case "google-ai":
return googleAIToOpenAI;
case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
@@ -1,12 +1,20 @@
import { Transform, TransformOptions } from "stream";
import { StringDecoder } from "string_decoder";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../../logger";
import { RetryableError } from "../index";
import { APIFormat } from "../../../../shared/key-management";
import StreamArray from "stream-json/streamers/StreamArray";
import { makeCompletionSSE } from "../../../../shared/streaming";
const log = logger.child({ module: "sse-stream-adapter" });
type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;
api: APIFormat;
};
type AwsEventStreamMessage = {
headers: {
":message-type": "event" | "exception";
@@ -21,20 +29,31 @@ type AwsEventStreamMessage = {
*/
export class SSEStreamAdapter extends Transform {
private readonly isAwsStream;
private parser = new Parser();
private readonly isGoogleStream;
private awsParser = new Parser();
private jsonParser = StreamArray.withParser();
private partialMessage = "";
private decoder = new StringDecoder("utf8");
constructor(options?: SSEStreamAdapterOptions) {
super(options);
this.isAwsStream =
options?.contentType === "application/vnd.amazon.eventstream";
this.isGoogleStream = options?.api === "google-ai";
this.parser.on("data", (data: AwsEventStreamMessage) => {
this.awsParser.on("data", (data: AwsEventStreamMessage) => {
const message = this.processAwsEvent(data);
if (message) {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
});
this.jsonParser.on("data", (data: { value: any }) => {
const message = this.processGoogleValue(data.value);
if (message) {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
});
}
protected processAwsEvent(event: AwsEventStreamMessage): string | null {
@@ -53,17 +72,19 @@ export class SSEStreamAdapter extends Transform {
);
throw new RetryableError("AWS request throttled mid-stream");
} else {
log.error(
{ event: eventStr },
"Received unexpected AWS event stream message"
);
return getFakeErrorCompletion("proxy AWS error", eventStr);
log.error({ event: eventStr }, "Received bad AWS stream event");
return makeCompletionSSE({
format: "anthropic",
title: "Proxy stream error",
message:
"The proxy received malformed or unexpected data from AWS while streaming.",
obj: event,
reqId: "proxy-sse-adapter-message",
model: "",
});
}
} else {
const { bytes } = payload;
// technically this is a transformation but we don't really distinguish
// between aws claude and anthropic claude at the APIFormat level, so
// these will short circuit the message transformer
return [
"event: completion",
`data: ${Buffer.from(bytes, "base64").toString("utf8")}`,
@@ -71,44 +92,61 @@ export class SSEStreamAdapter extends Transform {
}
}
/** Processes an incoming array element from the Google AI JSON stream. */
protected processGoogleValue(value: any): string | null {
try {
const candidates = value.candidates ?? [{}];
const hasParts = candidates[0].content?.parts?.length > 0;
if (hasParts) {
return `data: ${JSON.stringify(value)}`;
} else {
log.error({ event: value }, "Received bad Google AI event");
return `data: ${makeCompletionSSE({
format: "google-ai",
title: "Proxy stream error",
message:
"The proxy received malformed or unexpected data from Google AI while streaming.",
obj: value,
reqId: "proxy-sse-adapter-message",
model: "",
})}`;
}
} catch (error) {
error.lastEvent = value;
this.emit("error", error);
return null;
}
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
try {
if (this.isAwsStream) {
this.parser.write(chunk);
this.awsParser.write(chunk);
} else if (this.isGoogleStream) {
this.jsonParser.write(chunk);
} else {
// We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit separate stream events for full
// messages so we can parse/transform them properly.
const str = chunk.toString("utf8");
const str = this.decoder.write(chunk);
const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
const fullMessages = (this.partialMessage + str).split(
/\r\r|\n\n|\r\n\r\n/
);
this.partialMessage = fullMessages.pop() || "";
for (const message of fullMessages) {
// Mixing line endings will break some clients and our request queue
// will have already sent \n for heartbeats, so we need to normalize
// to \n.
this.push(message.replace(/\r\n/g, "\n") + "\n\n");
this.push(message.replace(/\r\n?/g, "\n") + "\n\n");
}
}
callback();
} catch (error) {
error.lastEvent = chunk?.toString();
this.emit("error", error);
callback(error);
}
}
}
function getFakeErrorCompletion(type: string, message: string) {
const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
const fakeEvent = JSON.stringify({
log_id: "aws-proxy-sse-message",
stop_reason: type,
completion:
"\nProxy encountered an error during streaming response.\n" + content,
truncated: false,
stop: null,
model: "",
});
return ["event: completion", `data: ${fakeEvent}\n\n`].join("\n");
}
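// The Google branch above leans on stream-json's StreamArray because Google AI
// streams its response as one large JSON array of chunk objects rather than as
// SSE. A minimal sketch of that parser in isolation (sample payload invented
// for illustration):
import StreamArray from "stream-json/streamers/StreamArray";

const parser = StreamArray.withParser();
// Emits one { key, value } pair per completed array element, even when the
// bytes arrive split across arbitrary chunk boundaries.
parser.on("data", ({ value }: { value: any }) => {
  console.log(value.candidates?.[0]?.content?.parts?.[0]?.text); // "Hello"
});
parser.write('[{"candidates":[{"content":{"parts":[{"text":"Hel');
parser.write('lo"}],"role":"model"}}]}]');
parser.end();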
@@ -0,0 +1,76 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "google-ai-to-openai",
});
type GoogleAIStreamEvent = {
candidates: {
content: { parts: { text: string }[]; role: string };
finishReason?: "STOP" | "MAX_TOKENS" | "SAFETY" | "RECITATION" | "OTHER";
index: number;
tokenCount?: number;
safetyRatings: { category: string; probability: string }[];
}[];
};
/**
* Transforms an incoming Google AI SSE to an equivalent OpenAI
* chat.completion.chunk SSE.
*/
export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
const { data, index } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const parts = completionEvent.candidates[0].content.parts;
let content = parts[0]?.text ?? "";
// If this is the first chunk, try stripping speaker names from the response
// e.g. "John: Hello" -> "Hello"
if (index === 0) {
content = content.replace(/^(.*?): /, "").trim();
}
const newEvent = {
id: "goo-" + params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content },
finish_reason: completionEvent.candidates[0].finishReason ?? null,
},
],
};
return { position: -1, event: newEvent };
};
function asCompletion(event: ServerSentEvent): GoogleAIStreamEvent | null {
try {
const parsed = JSON.parse(event.data) as GoogleAIStreamEvent;
if (parsed.candidates?.length > 0) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid event");
}
return null;
}
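// Used in isolation, the transformer turns one raw Google AI SSE event into an
// OpenAI-style chunk. Parameter names are taken from the usage above; the
// sample event is invented:
const result = googleAIToOpenAI({
  data: 'data: {"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"index":0,"safetyRatings":[]}]}',
  index: 1,
  fallbackId: "req-123",
  fallbackModel: "gemini-pro",
});
// result.event.choices[0].delta.content === "Hello"
// (at index 0 a leading "Name: " speaker prefix would also be stripped)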
+118
@@ -0,0 +1,118 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
getMistralAIModelFamily,
MistralAIModelFamily,
ModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
// https://docs.mistral.ai/platform/endpoints
export const KNOWN_MISTRAL_AI_MODELS = [
"mistral-tiny",
"mistral-small",
"mistral-medium",
];
let modelsCache: any = null;
let modelsCacheTime = 0;
export function generateModelList(models = KNOWN_MISTRAL_AI_MODELS) {
let available = new Set<MistralAIModelFamily>();
for (const key of keyPool.list()) {
if (key.isDisabled || key.service !== "mistral-ai") continue;
key.modelFamilies.forEach((family) =>
available.add(family as MistralAIModelFamily)
);
}
const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
available = new Set([...available].filter((x) => allowed.has(x)));
return models
.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "mistral-ai",
}))
.filter((model) => available.has(getMistralAIModelFamily(model.id)));
}
const handleModelRequest: RequestHandler = (_req, res) => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return res.status(200).json(modelsCache);
}
const result = generateModelList();
modelsCache = { object: "list", data: result };
modelsCacheTime = new Date().getTime();
res.status(200).json(modelsCache);
};
const mistralAIResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
const mistralAIProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://api.mistral.ai",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
error: handleProxyError,
},
}),
});
const mistralAIRouter = Router();
mistralAIRouter.get("/v1/models", handleModelRequest);
// General chat completion endpoint.
mistralAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware({
inApi: "mistral-ai",
outApi: "mistral-ai",
service: "mistral-ai",
}),
mistralAIProxy
);
export const mistralAI = mistralAIRouter;
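// With the router mounted under /proxy (see routes.ts below), a completion
// request against the new endpoint looks like this sketch, assuming a local
// instance listening on port 7860 and user tokens enabled:
const res = await fetch(
  "http://localhost:7860/proxy/mistral-ai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <user-token>",
    },
    body: JSON.stringify({
      model: "mistral-medium",
      messages: [{ role: "user", content: "Hello" }],
    }),
  }
);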
+3 -2
@@ -17,7 +17,6 @@ import {
} from "./middleware/response";
import { generateModelList } from "./openai";
import {
mirrorGeneratedImage,
OpenAIImageGenerationResult,
} from "../shared/file-storage/mirror-generated-image";
@@ -26,7 +25,9 @@ const KNOWN_MODELS = ["dall-e-2", "dall-e-3"];
let modelListCache: any = null;
let modelListValid = 0;
const handleModelRequest: RequestHandler = (_req, res) => {
if (new Date().getTime() - modelListValid < 1000 * 60) return modelListCache;
if (new Date().getTime() - modelListValid < 1000 * 60) {
return res.status(200).json(modelListCache);
}
const result = generateModelList(KNOWN_MODELS);
modelListCache = { object: "list", data: result };
modelListValid = new Date().getTime();
+7 -5
@@ -28,6 +28,8 @@ import {
// https://platform.openai.com/docs/models/overview
export const KNOWN_OPENAI_MODELS = [
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
@@ -35,7 +37,7 @@ export const KNOWN_OPENAI_MODELS = [
"gpt-4-0314", // EOL 2024-06-13
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-4-32k-0314", // EOL 2024-06-13
// "gpt-4-32k-0314", // EOL 2024-06-13
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301", // EOL 2024-06-13
"gpt-3.5-turbo-0613",
@@ -83,7 +85,9 @@ export function generateModelList(models = KNOWN_OPENAI_MODELS) {
}
const handleModelRequest: RequestHandler = (_req, res) => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) return modelsCache;
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return res.status(200).json(modelsCache);
}
const result = generateModelList();
modelsCache = { object: "list", data: result };
modelsCacheTime = new Date().getTime();
@@ -161,9 +165,7 @@ const openaiProxy = createQueueMiddleware({
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, finalizeBody],
}),
proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }),
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
error: handleProxyError,
},
-170
@@ -1,170 +0,0 @@
import { Request, RequestHandler, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.googlePalmKey) return { object: "list", data: [] };
const bisonVariants = ["text-bison-001"];
const models = bisonVariants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "google",
permission: [],
root: "palm",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const palmResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
req.log.info("Transforming Google PaLM response to OpenAI format");
body = transformPalmResponse(body, req);
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
// TODO: PaLM has no streaming capability which will pose a problem here if
// requests wait in the queue for too long. Probably need to fake streaming
// and return the entire completion in one stream event using the other
// response handler.
res.status(200).json(body);
};
/**
* Transforms a model response from the Google PaLM API to match those from the
OpenAI API, for users using PaLM via the OpenAI-compatible endpoint. This is
only used for non-streaming requests, as the PaLM API does not support
streaming.
*/
function transformPalmResponse(
palmRespBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "plm-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: palmRespBody.candidates[0].output,
},
finish_reason: null, // palm doesn't return this
index: 0,
},
],
};
}
function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
if (req.body.stream) {
throw new Error("Google PaLM API doesn't support streaming requests");
}
// PaLM API specifies the model in the URL path, not the request body. This
// doesn't work well with our rewriter architecture, so we need to manually
// fix it here.
// POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateText
// POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateMessage
// The chat api (generateMessage) is not very useful at this time as it has
// few params and no adjustable safety settings.
proxyReq.path = proxyReq.path.replace(
/^\/v1\/chat\/completions/,
`/v1beta2/models/${req.body.model}:generateText`
);
}
const googlePalmProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://generativelanguage.googleapis.com",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([palmResponseHandler]),
error: handleProxyError,
},
}),
});
const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google PaLM compatibility endpoint.
palmRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "google-palm", service: "google-palm" },
{ afterTransform: [forceModel("text-bison-001")] }
),
googlePalmProxy
);
export const googlePalm = palmRouter;
+57 -32
@@ -14,8 +14,12 @@
import crypto from "crypto";
import type { Handler, Request } from "express";
import { keyPool } from "../shared/key-management";
import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import {
getModelFamilyForRequest,
MODEL_FAMILIES,
ModelFamily,
} from "../shared/models";
import { makeCompletionSSE, initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
@@ -37,6 +41,7 @@ const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "50");
const PAYLOAD_SCALE_FACTOR = parseFloat(
process.env.PAYLOAD_SCALE_FACTOR ?? "6"
);
const QUEUE_JOIN_TIMEOUT = 5000;
/**
* Returns an identifier for a request. This is used to determine if a
@@ -60,7 +65,7 @@ const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
export function enqueue(req: Request) {
export async function enqueue(req: Request) {
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
let isGuest = req.user?.token === undefined;
@@ -92,7 +97,7 @@ export function enqueue(req: Request) {
if (stream === "true" || stream === true || req.isStreaming) {
const res = req.res!;
if (!res.headersSent) {
initStreaming(req);
await initStreaming(req);
}
registerHeartbeat(req);
} else if (getProxyLoad() > LOAD_THRESHOLD) {
@@ -119,7 +124,9 @@ export function enqueue(req: Request) {
if ((req.retryCount ?? 0) > 0) {
req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`);
} else {
req.log.info(`Enqueued new request.`);
const size = req.socket.bytesRead;
const endpoint = req.url?.split("?")[0];
req.log.info({ size, endpoint }, `Enqueued new request.`);
}
}
@@ -189,10 +196,10 @@ function processQueue() {
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {
const modelFamily = getModelFamilyForRequest(req!);
req.log.info({
retries: req.retryCount,
partition: modelFamily,
}, `Dequeuing request.`);
req.log.info(
{ retries: req.retryCount, partition: modelFamily },
`Dequeuing request.`
);
req.proceed();
}
});
@@ -327,7 +334,7 @@ export function createQueueMiddleware({
beforeProxy?: RequestPreprocessor;
proxyMiddleware: Handler;
}): Handler {
return (req, res, next) => {
return async (req, res, next) => {
req.proceed = async () => {
if (beforeProxy) {
try {
@@ -345,7 +352,7 @@ export function createQueueMiddleware({
};
try {
enqueue(req);
await enqueue(req);
} catch (err: any) {
req.res!.status(429).json({
type: "proxy_error",
@@ -367,8 +374,15 @@ function killQueuedRequest(req: Request) {
try {
const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`;
if (res.headersSent) {
const fakeErrorEvent = buildFakeSse("proxy queue error", message, req);
res.write(fakeErrorEvent);
const event = makeCompletionSSE({
format: req.inboundApi,
title: "Proxy queue error",
message,
reqId: String(req.id),
model: req.body?.model,
});
res.write(event);
res.write(`data: [DONE]\n\n`);
res.end();
} else {
res.status(500).json({ error: message });
@@ -378,20 +392,39 @@ function killQueuedRequest(req: Request) {
}
}
function initStreaming(req: Request) {
async function initStreaming(req: Request) {
const res = req.res!;
initializeSseStream(res);
if (req.query.badSseParser) {
// Some clients have a broken SSE parser that doesn't handle comments
// correctly. These clients can pass ?badSseParser=true to
// disable comments in the SSE stream.
res.write(getHeartbeatPayload());
return;
}
const joinMsg = `: joining queue at position ${
queue.length
}\n\n${getHeartbeatPayload()}`;
res.write(`: joining queue at position ${queue.length}\n\n`);
res.write(getHeartbeatPayload());
let drainTimeout: NodeJS.Timeout;
const welcome = new Promise<void>((resolve, reject) => {
const onDrain = () => {
clearTimeout(drainTimeout);
req.log.debug(`Client finished consuming join message.`);
res.off("drain", onDrain);
resolve();
};
drainTimeout = setTimeout(() => {
res.off("drain", onDrain);
res.destroy();
reject(new Error("Unreponsive streaming client; killing connection"));
}, QUEUE_JOIN_TIMEOUT);
if (!res.write(joinMsg)) {
req.log.warn("Kernel buffer is full; holding client request.");
res.once("drain", onDrain);
} else {
clearTimeout(drainTimeout);
resolve();
}
});
await welcome;
}
/**
@@ -451,14 +484,6 @@ function removeProxyMiddlewareEventListeners(req: Request) {
export function registerHeartbeat(req: Request) {
const res = req.res!;
const currentSize = getHeartbeatSize();
req.log.debug({
currentSize,
HEARTBEAT_INTERVAL,
PAYLOAD_SCALE_FACTOR,
MAX_HEARTBEAT_SIZE,
}, "Joining queue with heartbeat.");
let isBufferFull = false;
let bufferFullCount = 0;
req.heartbeatInterval = setInterval(() => {
@@ -502,7 +527,7 @@ function monitorHeartbeat(req: Request) {
if (bytesSinceLast < minBytes) {
req.log.warn(
{ minBytes, bytesSinceLast },
"Queued request is processing heartbeats enough data or server is overloaded; killing connection."
"Queued request is not processing heartbeats enough data or server is overloaded; killing connection."
);
res.destroy();
}
+4 -2
@@ -4,7 +4,8 @@ import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { googleAI } from "./google-ai";
import { mistralAI } from "./mistral-ai";
import { aws } from "./aws";
import { azure } from "./azure";
@@ -31,7 +32,8 @@ proxyRouter.use((req, _res, next) => {
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/mistral-ai", addV1, mistralAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
+34 -16
@@ -13,6 +13,7 @@ import { keyPool } from "./shared/key-management";
import { adminRouter } from "./admin/routes";
import { proxyRouter } from "./proxy/routes";
import { handleInfoPage } from "./info-page";
import { buildInfo } from "./service-info";
import { logQueue } from "./shared/prompt-logging";
import { start as startRequestQueue } from "./proxy/queue";
import { init as initUserStore } from "./shared/users/user-store";
@@ -21,6 +22,7 @@ import { checkOrigin } from "./proxy/check-origin";
import { userRouter } from "./user/routes";
const PORT = config.port;
const BIND_ADDRESS = config.bindAddress;
const app = express();
// middleware
@@ -49,10 +51,7 @@ app.use(
})
);
// TODO: Detect (or support manual configuration of) whether the app is behind
// a load balancer/reverse proxy, which is necessary to determine request IP
// addresses correctly.
app.set("trust proxy", true);
app.set("trust proxy", Number(config.trustedProxies));
app.set("view engine", "ejs");
app.set("views", [
@@ -67,13 +66,18 @@ app.get("/health", (_req, res) => res.sendStatus(200));
app.use(cors());
app.use(checkOrigin);
// routes
app.get("/", handleInfoPage);
if (config.staticServiceInfo) {
app.get("/", (_req, res) => res.sendStatus(200));
} else {
app.get("/", handleInfoPage);
}
app.get("/status", (req, res) => {
res.json(buildInfo(req.protocol + "://" + req.get("host"), false));
});
app.use("/admin", adminRouter);
app.use("/proxy", proxyRouter);
app.use("/user", userRouter);
// 500 and 404
app.use((err: any, _req: unknown, res: express.Response, _next: unknown) => {
if (err.status) {
res.status(err.status).json({ error: err.message });
@@ -120,15 +124,18 @@ async function start() {
logger.info("Starting request queue...");
startRequestQueue();
app.listen(PORT, async () => {
logger.info({ port: PORT }, "Now listening for connections.");
registerUncaughtExceptionHandler();
});
const diskSpace = await checkDiskSpace(
__dirname.startsWith("/app") ? "/app" : os.homedir()
);
app.listen(PORT, BIND_ADDRESS, () => {
logger.info(
{ port: PORT, interface: BIND_ADDRESS },
"Now listening for connections."
);
registerUncaughtExceptionHandler();
});
logger.info(
{ build: process.env.BUILD_INFO, nodeEnv: process.env.NODE_ENV, diskSpace },
"Startup complete."
@@ -158,7 +165,18 @@ function registerUncaughtExceptionHandler() {
* didn't set it to something misleading.
*/
async function setBuildInfo() {
// Render .dockerignore's the .git directory but provides info in the env
// For CI builds, use the env vars set during the build process
if (process.env.GITGUD_BRANCH) {
const sha = process.env.GITGUD_COMMIT?.slice(0, 7) || "unknown SHA";
const branch = process.env.GITGUD_BRANCH;
const repo = process.env.GITGUD_PROJECT;
const buildInfo = `[ci] ${sha} (${branch}@${repo})`;
process.env.BUILD_INFO = buildInfo;
logger.info({ build: buildInfo }, "Using build info from CI image.");
return;
}
// For render, the git directory is dockerignore'd so we use env vars
if (process.env.RENDER) {
const sha = process.env.RENDER_GIT_COMMIT?.slice(0, 7) || "unknown SHA";
const branch = process.env.RENDER_GIT_BRANCH || "unknown branch";
@@ -169,10 +187,10 @@ async function setBuildInfo() {
return;
}
// For huggingface and bare metal deployments, we can get the info from git
try {
// Ignore git's complaints about dubious directory ownership on Huggingface
// (which evidently runs dockerized Spaces on Windows with weird NTFS perms)
if (process.env.SPACE_ID) {
// TODO: may not be necessary anymore with adjusted Huggingface dockerfile
childProcess.execSync("git config --global --add safe.directory /app");
}
@@ -192,7 +210,7 @@ async function setBuildInfo() {
let [sha, branch, remote, status] = await Promise.all(promises);
remote = remote.match(/.*[\/:]([\w-]+)\/([\w\-\.]+?)(?:\.git)?$/) || [];
remote = remote.match(/.*[\/:]([\w-]+)\/([\w\-.]+?)(?:\.git)?$/) || [];
const repo = remote.slice(-2).join("/");
status = status
// ignore Dockerfile changes since that's how the user deploys the app
+441
@@ -0,0 +1,441 @@
/** Calculates and returns stats about the service. */
import { config, listConfig } from "./config";
import {
AnthropicKey,
AwsBedrockKey,
AzureOpenAIKey,
GoogleAIKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
import {
AnthropicModelFamily,
assertIsKnownModelFamily,
AwsBedrockModelFamily,
AzureOpenAIModelFamily,
GoogleAIModelFamily,
LLM_SERVICES,
LLMService,
MistralAIModelFamily,
MODEL_FAMILY_SERVICE,
ModelFamily,
OpenAIModelFamily,
} from "./shared/models";
import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats";
import { getUniqueIps } from "./proxy/rate-limit";
import { assertNever } from "./shared/utils";
import { getEstimatedWaitTime, getQueueLength } from "./proxy/queue";
import { MistralAIKey } from "./shared/key-management/mistral-ai/provider";
const CACHE_TTL = 2000;
type KeyPoolKey = ReturnType<typeof keyPool.list>[0];
const keyIsOpenAIKey = (k: KeyPoolKey): k is OpenAIKey =>
k.service === "openai";
const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
k.service === "google-ai";
const keyIsMistralAIKey = (k: KeyPoolKey): k is MistralAIKey =>
k.service === "mistral-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
/** Stats aggregated across all keys for a given service. */
type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs";
/** Stats aggregated across all keys for a given model family. */
type ModelAggregates = {
active: number;
trial?: number;
revoked?: number;
overQuota?: number;
pozzed?: number;
awsLogged?: number;
queued: number;
queueTime: string;
tokens: number;
};
/** All possible combinations of model family and aggregate type. */
type ModelAggregateKey = `${ModelFamily}__${keyof ModelAggregates}`;
type AllStats = {
proompts: number;
tokens: number;
tokenCost: number;
} & { [modelFamily in ModelFamily]?: ModelAggregates } & {
[service in LLMService as `${service}__${ServiceAggregate}`]?: number;
};
type BaseFamilyInfo = {
usage?: string;
activeKeys: number;
revokedKeys?: number;
proomptersInQueue?: number;
estimatedQueueTime?: string;
};
type OpenAIInfo = BaseFamilyInfo & {
trialKeys?: number;
overQuotaKeys?: number;
};
type AnthropicInfo = BaseFamilyInfo & { pozzedKeys?: number };
type AwsInfo = BaseFamilyInfo & { privacy?: string };
// prettier-ignore
export type ServiceInfo = {
uptime: number;
endpoints: {
openai?: string;
openai2?: string;
"openai-image"?: string;
anthropic?: string;
"google-ai"?: string;
"mistral-ai"?: string;
aws?: string;
azure?: string;
};
proompts?: number;
tookens?: string;
proomptersNow?: number;
status?: string;
config: ReturnType<typeof listConfig>;
build: string;
} & { [f in OpenAIModelFamily]?: OpenAIInfo }
& { [f in AnthropicModelFamily]?: AnthropicInfo; }
& { [f in AwsBedrockModelFamily]?: AwsInfo }
& { [f in AzureOpenAIModelFamily]?: BaseFamilyInfo; }
& { [f in GoogleAIModelFamily]?: BaseFamilyInfo }
& { [f in MistralAIModelFamily]?: BaseFamilyInfo };
// https://stackoverflow.com/a/66661477
// type DeepKeyOf<T> = (
// [T] extends [never]
// ? ""
// : T extends object
// ? {
// [K in Exclude<keyof T, symbol>]: `${K}${DotPrefix<DeepKeyOf<T[K]>>}`;
// }[Exclude<keyof T, symbol>]
// : ""
// ) extends infer D
// ? Extract<D, string>
// : never;
// type DotPrefix<T extends string> = T extends "" ? "" : `.${T}`;
// type ServiceInfoPath = `{${DeepKeyOf<ServiceInfo>}}`;
const SERVICE_ENDPOINTS: { [s in LLMService]: Record<string, string> } = {
openai: {
openai: `%BASE%/openai`,
openai2: `%BASE%/openai/turbo-instruct`,
"openai-image": `%BASE%/openai-image`,
},
anthropic: {
anthropic: `%BASE%/anthropic`,
},
"google-ai": {
"google-ai": `%BASE%/google-ai`,
},
"mistral-ai": {
"mistral-ai": `%BASE%/mistral-ai`,
},
aws: {
aws: `%BASE%/aws/claude`,
},
azure: {
azure: `%BASE%/azure/openai`,
},
};
const modelStats = new Map<ModelAggregateKey, number>();
const serviceStats = new Map<keyof AllStats, number>();
let cachedInfo: ServiceInfo | undefined;
let cacheTime = 0;
export function buildInfo(baseUrl: string, forAdmin = false): ServiceInfo {
if (cacheTime + CACHE_TTL > Date.now()) return cachedInfo!;
const keys = keyPool.list();
const accessibleFamilies = new Set(
keys
.flatMap((k) => k.modelFamilies)
.filter((f) => config.allowedModelFamilies.includes(f))
.concat("turbo")
);
modelStats.clear();
serviceStats.clear();
keys.forEach(addKeyToAggregates);
const endpoints = getEndpoints(baseUrl, accessibleFamilies);
const trafficStats = getTrafficStats();
const { serviceInfo, modelFamilyInfo } =
getServiceModelStats(accessibleFamilies);
const status = getStatus();
if (config.staticServiceInfo && !forAdmin) {
delete trafficStats.proompts;
delete trafficStats.tookens;
delete trafficStats.proomptersNow;
for (const family of Object.keys(modelFamilyInfo)) {
assertIsKnownModelFamily(family);
delete modelFamilyInfo[family]?.proomptersInQueue;
delete modelFamilyInfo[family]?.estimatedQueueTime;
delete modelFamilyInfo[family]?.usage;
}
}
return (cachedInfo = {
uptime: Math.floor(process.uptime()),
endpoints,
...trafficStats,
...serviceInfo,
status,
...modelFamilyInfo,
config: listConfig(),
build: process.env.BUILD_INFO || "dev",
});
}
function getStatus() {
if (!config.checkKeys) return "Key checking is disabled.";
let unchecked = 0;
for (const service of LLM_SERVICES) {
unchecked += serviceStats.get(`${service}__uncheckedKeys`) || 0;
}
return unchecked ? `Checking ${unchecked} keys...` : undefined;
}
function getEndpoints(baseUrl: string, accessibleFamilies: Set<ModelFamily>) {
const endpoints: Record<string, string> = {};
for (const service of LLM_SERVICES) {
for (const [name, url] of Object.entries(SERVICE_ENDPOINTS[service])) {
endpoints[name] = url.replace("%BASE%", baseUrl);
}
if (service === "openai" && !accessibleFamilies.has("dall-e")) {
delete endpoints["openai-image"];
}
}
return endpoints;
}
type TrafficStats = Pick<ServiceInfo, "proompts" | "tookens" | "proomptersNow">;
function getTrafficStats(): TrafficStats {
const tokens = serviceStats.get("tokens") || 0;
const tokenCost = serviceStats.get("tokenCost") || 0;
return {
proompts: serviceStats.get("proompts") || 0,
tookens: `${prettyTokens(tokens)}${getCostSuffix(tokenCost)}`,
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};
}
function getServiceModelStats(accessibleFamilies: Set<ModelFamily>) {
const serviceInfo: {
[s in LLMService as `${s}${"Keys" | "Orgs"}`]?: number;
} = {};
const modelFamilyInfo: { [f in ModelFamily]?: BaseFamilyInfo } = {};
for (const service of LLM_SERVICES) {
const hasKeys = serviceStats.get(`${service}__keys`) || 0;
if (!hasKeys) continue;
serviceInfo[`${service}Keys`] = hasKeys;
accessibleFamilies.forEach((f) => {
if (MODEL_FAMILY_SERVICE[f] === service) {
modelFamilyInfo[f] = getInfoForFamily(f);
}
});
if (service === "openai" && config.checkKeys) {
serviceInfo.openaiOrgs = getUniqueOpenAIOrgs(keyPool.list());
}
}
return { serviceInfo, modelFamilyInfo };
}
function getUniqueOpenAIOrgs(keys: KeyPoolKey[]) {
const orgIds = new Set(
keys.filter((k) => k.service === "openai").map((k: any) => k.organizationId)
);
return orgIds.size;
}
function increment<T extends keyof AllStats | ModelAggregateKey>(
map: Map<T, number>,
key: T,
delta = 1
) {
map.set(key, (map.get(key) || 0) + delta);
}
function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openai__keys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropic__keys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "google-ai__keys", k.service === "google-ai" ? 1 : 0);
increment(serviceStats, "mistral-ai__keys", k.service === "mistral-ai" ? 1 : 0);
increment(serviceStats, "aws__keys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azure__keys", k.service === "azure" ? 1 : 0);
let sumTokens = 0;
let sumCost = 0;
switch (k.service) {
case "openai":
if (!keyIsOpenAIKey(k)) throw new Error("Invalid key type");
increment(
serviceStats,
"openai__uncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__trial`, k.isTrial ? 1 : 0);
increment(modelStats, `${f}__overQuota`, k.isOverQuota ? 1 : 0);
});
break;
case "azure":
if (!keyIsAzureKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
});
break;
case "anthropic": {
if (!keyIsAnthropicKey(k)) throw new Error("Invalid key type");
const family = "claude";
sumTokens += k.claudeTokens;
sumCost += getTokenCostUsd(family, k.claudeTokens);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.claudeTokens);
increment(modelStats, `${family}__pozzed`, k.isPozzed ? 1 : 0);
increment(
serviceStats,
"anthropic__uncheckedKeys",
Boolean(k.lastChecked) ? 0 : 1
);
break;
}
case "google-ai": {
if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
const family = "gemini-pro";
sumTokens += k["gemini-proTokens"];
sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
break;
}
case "mistral-ai": {
if (!keyIsMistralAIKey(k)) throw new Error("Invalid key type");
k.modelFamilies.forEach((f) => {
const tokens = k[`${f}Tokens`];
sumTokens += tokens;
sumCost += getTokenCostUsd(f, tokens);
increment(modelStats, `${f}__tokens`, tokens);
increment(modelStats, `${f}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${f}__active`, k.isDisabled ? 0 : 1);
});
break;
}
case "aws": {
if (!keyIsAwsKey(k)) throw new Error("Invalid key type");
const family = "aws-claude";
sumTokens += k["aws-claudeTokens"];
sumCost += getTokenCostUsd(family, k["aws-claudeTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k["aws-claudeTokens"]);
// Ignore revoked keys for aws logging stats, but include keys where the
// logging status is unknown.
const countAsLogged =
k.lastChecked && !k.isDisabled && k.awsLoggingStatus !== "disabled";
increment(modelStats, `${family}__awsLogged`, countAsLogged ? 1 : 0);
break;
}
default:
assertNever(k.service);
}
increment(serviceStats, "tokens", sumTokens);
increment(serviceStats, "tokenCost", sumCost);
}
function getInfoForFamily(family: ModelFamily): BaseFamilyInfo {
const tokens = modelStats.get(`${family}__tokens`) || 0;
const cost = getTokenCostUsd(family, tokens);
let info: BaseFamilyInfo & OpenAIInfo & AnthropicInfo & AwsInfo = {
usage: `${prettyTokens(tokens)} tokens${getCostSuffix(cost)}`,
activeKeys: modelStats.get(`${family}__active`) || 0,
revokedKeys: modelStats.get(`${family}__revoked`) || 0,
};
// Add service-specific stats to the info object.
if (config.checkKeys) {
const service = MODEL_FAMILY_SERVICE[family];
switch (service) {
case "openai":
info.overQuotaKeys = modelStats.get(`${family}__overQuota`) || 0;
info.trialKeys = modelStats.get(`${family}__trial`) || 0;
// Delete trial/revoked keys for non-turbo families.
// Trials are turbo 99% of the time, and if a key is invalid we don't
// know what models it might have had assigned to it.
if (family !== "turbo") {
delete info.trialKeys;
delete info.revokedKeys;
}
break;
case "anthropic":
info.pozzedKeys = modelStats.get(`${family}__pozzed`) || 0;
break;
case "aws":
const logged = modelStats.get(`${family}__awsLogged`) || 0;
if (logged > 0) {
info.privacy = config.allowAwsLogging
? `${logged} active keys are potentially logged.`
: `${logged} active keys are potentially logged and can't be used. Set ALLOW_AWS_LOGGING=true to override.`;
}
break;
}
}
// Add queue stats to the info object.
const queue = getQueueInformation(family);
info.proomptersInQueue = queue.proomptersInQueue;
info.estimatedQueueTime = queue.estimatedQueueTime;
return info;
}
/** Returns queue time in seconds, or minutes + seconds if over 60 seconds. */
function getQueueInformation(partition: ModelFamily) {
const waitMs = getEstimatedWaitTime(partition);
const waitTime =
waitMs < 60000
? `${Math.round(waitMs / 1000)}sec`
: `${Math.floor(waitMs / 60000)}min, ${Math.round(
(waitMs % 60000) / 1000
)}sec`;
return {
proomptersInQueue: getQueueLength(partition),
estimatedQueueTime: waitMs > 2000 ? waitTime : "no wait",
};
}
+92
@@ -0,0 +1,92 @@
import { z } from "zod";
import { Request } from "express";
import { config } from "../../config";
import {
flattenOpenAIMessageContent,
OpenAIChatMessage,
OpenAIV1ChatCompletionSchema,
} from "./openai";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
// https://console.anthropic.com/docs/api/reference#-v1-complete
export const AnthropicV1CompleteSchema = z
.object({
model: z.string().max(100),
prompt: z.string({
required_error:
"No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
}),
max_tokens_to_sample: z.coerce
.number()
.int()
.transform((v) => Math.min(v, CLAUDE_OUTPUT_MAX)),
stop_sequences: z.array(z.string().max(500)).optional(),
stream: z.boolean().optional().default(false),
temperature: z.coerce.number().optional().default(1),
top_k: z.coerce.number().optional(),
top_p: z.coerce.number().optional(),
})
.strip();
export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "Human";
}
const name = m.name?.trim();
const content = flattenOpenAIMessageContent(m.content);
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
})
.join("") + "\n\nAssistant:"
);
}
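// The mapping above yields Anthropic's Human/Assistant turn format, with the
// unsupported `name` field folded into the text. A quick sketch:
openAIMessagesToClaudePrompt([
  { role: "system", content: "Be terse." },
  { role: "user", content: "Hi", name: "Anon" },
]);
// => "\n\nSystem: Be terse.\n\nHuman: (as Anon) Hi\n\nAssistant:"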
export function openAIToAnthropic(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Anthropic request"
);
throw result.error;
}
req.headers["anthropic-version"] = "2023-06-01";
const { messages, ...rest } = result.data;
const prompt = openAIMessagesToClaudePrompt(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
// Recommended by Anthropic
stops.push("\n\nHuman:");
// Helps with jailbreak prompts that send fake system messages and multi-bot
// chats that prefix bot messages with "System: Respond as <bot name>".
stops.push("\n\nSystem:");
// Remove duplicates
stops = [...new Set(stops)];
return {
model: rest.model,
prompt: prompt,
max_tokens_to_sample: rest.max_tokens,
stop_sequences: stops,
stream: rest.stream,
temperature: rest.temperature,
top_p: rest.top_p,
};
}
+124
@@ -0,0 +1,124 @@
import { z } from "zod";
import { Request } from "express";
import {
flattenOpenAIMessageContent,
OpenAIV1ChatCompletionSchema,
} from "./openai";
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z
.object({
model: z.string().max(100), // actually specified in path but we need it for the router
stream: z.boolean().optional().default(false), // also used for router
contents: z.array(
z.object({
parts: z.array(z.object({ text: z.string() })),
role: z.enum(["user", "model"]),
}),
),
tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce
.number()
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 1024)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string().max(500)).max(5).optional(),
}),
})
.strip();
export type GoogleAIChatMessage = z.infer<
typeof GoogleAIV1GenerateContentSchema
>["contents"][0];
export function openAIToGoogleAI(
req: Request,
): z.infer<typeof GoogleAIV1GenerateContentSchema> {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({
...body,
model: "gpt-3.5-turbo",
});
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Google AI request",
);
throw result.error;
}
const { messages, ...rest } = result.data;
const foundNames = new Set<string>();
const contents = messages
.map((m) => {
const role = m.role === "assistant" ? "model" : "user";
// Detects character names so we can set stop sequences for them as Gemini
// is prone to continuing as the next character.
// If names are not available, we'll still try to prefix the message
// with generic names so we can set stops for them but they don't work
// as well as real names.
const text = flattenOpenAIMessageContent(m.content);
const propName = m.name?.trim();
const textName =
m.role === "system" ? "" : text.match(/^(.{0,50}?): /)?.[1]?.trim();
const name =
propName || textName || (role === "model" ? "Character" : "User");
foundNames.add(name);
// Prefixing messages with their character name seems to help avoid
// Gemini trying to continue as the next character, or at the very least
// ensures it will hit the stop sequence. Otherwise it will start a new
// paragraph and switch perspectives.
// The response will be very likely to include this prefix so frontends
// will need to strip it out.
const textPrefix = textName ? "" : `${name}: `;
return {
parts: [{ text: textPrefix + text }],
role: m.role === "assistant" ? ("model" as const) : ("user" as const),
};
})
.reduce<GoogleAIChatMessage[]>((acc, msg) => {
const last = acc[acc.length - 1];
if (last?.role === msg.role) {
last.parts[0].text += "\n\n" + msg.parts[0].text;
} else {
acc.push(msg);
}
return acc;
}, []);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
stops.push(...Array.from(foundNames).map((name) => `\n${name}:`));
stops = [...new Set(stops)].slice(0, 5);
return {
model: "gemini-pro",
stream: rest.stream,
contents,
tools: [],
generationConfig: {
maxOutputTokens: rest.max_tokens,
stopSequences: stops,
topP: rest.top_p,
topK: 40, // openai schema doesn't have this, google ai defaults to 40
temperature: rest.temperature,
},
safetySettings: [
{ category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
],
};
}
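// Worked through on a small chat log, the name detection above produces both
// the prefixed contents and the derived stop sequences (sample input invented;
// shown as comments because the function takes a full Express Request):
// messages = [
//   { role: "user", content: "Alice: Hi there" },
//   { role: "assistant", content: "Bob: Hello yourself" },
// ]
// => contents[0] = { role: "user",  parts: [{ text: "Alice: Hi there" }] }
//    contents[1] = { role: "model", parts: [{ text: "Bob: Hello yourself" }] }
//    generationConfig.stopSequences includes "\nAlice:" and "\nBob:"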
+21
@@ -0,0 +1,21 @@
import { z } from "zod";
import { APIFormat } from "../key-management";
import { AnthropicV1CompleteSchema } from "./anthropic";
import { OpenAIV1ChatCompletionSchema } from "./openai";
import { OpenAIV1TextCompletionSchema } from "./openai-text";
import { OpenAIV1ImagesGenerationSchema } from "./openai-image";
import { GoogleAIV1GenerateContentSchema } from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
export { OpenAIChatMessage } from "./openai";
export { GoogleAIChatMessage } from "./google-ai";
export { MistralAIChatMessage } from "./mistral-ai";
export const API_SCHEMA_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
anthropic: AnthropicV1CompleteSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
};
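// A sketch of how a preprocessor can consume this map; the real validation
// middleware lives in proxy/middleware/request, this is just the lookup
// pattern:
function validateForFormat(format: APIFormat, body: unknown) {
  const result = API_SCHEMA_VALIDATORS[format].safeParse(body);
  if (!result.success) throw result.error;
  return result.data; // parsed body with defaults and transforms applied
}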
+60
@@ -0,0 +1,60 @@
import { z } from "zod";
import { OPENAI_OUTPUT_MAX } from "./openai";
// https://docs.mistral.ai/api#operation/createChatCompletion
export const MistralAIV1ChatCompletionsSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
})
),
temperature: z.number().optional().default(0.7),
top_p: z.number().optional().default(1),
max_tokens: z.coerce
.number()
.int()
.nullish()
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
stream: z.boolean().optional().default(false),
safe_prompt: z.boolean().optional().default(false),
random_seed: z.number().int().optional(),
});
export type MistralAIChatMessage = z.infer<
typeof MistralAIV1ChatCompletionsSchema
>["messages"][0];
export function fixMistralPrompt(
messages: MistralAIChatMessage[]
): MistralAIChatMessage[] {
// Mistral uses OpenAI format but has some additional requirements:
// - Only one system message per request, and it must be the first message if
// present.
// - Final message must be a user message.
// - Cannot have multiple messages from the same role in a row.
// While frontends should be able to handle this, we can fix it here in the
// meantime.
return messages.reduce<MistralAIChatMessage[]>((acc, msg) => {
if (acc.length === 0) {
acc.push(msg);
return acc;
}
const copy = { ...msg };
// Reattribute subsequent system messages to the user
if (msg.role === "system") {
copy.role = "user";
}
// Consolidate multiple messages from the same role
const last = acc[acc.length - 1];
if (last.role === copy.role) {
last.content += "\n\n" + copy.content;
} else {
acc.push(copy);
}
return acc;
}, []);
}
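// For example, a second system message is reattributed to the user and then
// merged with the adjacent user turn:
fixMistralPrompt([
  { role: "system", content: "Be brief." },
  { role: "system", content: "Use emoji." },
  { role: "user", content: "Hi" },
]);
// => [
//   { role: "system", content: "Be brief." },
//   { role: "user", content: "Use emoji.\n\nHi" },
// ]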
+66
@@ -0,0 +1,66 @@
import { z } from "zod";
import { Request } from "express";
import { OpenAIV1ChatCompletionSchema } from "./openai";
// https://platform.openai.com/docs/api-reference/images/create
export const OpenAIV1ImagesGenerationSchema = z
.object({
prompt: z.string().max(4000),
model: z.string().max(100).optional(),
quality: z.enum(["standard", "hd"]).optional().default("standard"),
n: z.number().int().min(1).max(4).optional().default(1),
response_format: z.enum(["url", "b64_json"]).optional(),
size: z
.enum(["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"])
.optional()
.default("1024x1024"),
style: z.enum(["vivid", "natural"]).optional().default("vivid"),
user: z.string().max(500).optional(),
})
.strip();
// Takes the last chat message and uses it verbatim as the image prompt.
export function openAIToOpenAIImage(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-image request",
);
throw result.error;
}
const { messages } = result.data;
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
if (Array.isArray(prompt)) {
throw new Error("Image generation prompt must be a text message.");
}
if (body.stream) {
throw new Error(
"Streaming is not supported for image generation requests.",
);
}
// Some frontends do weird things with the prompt, like prefixing it with a
// character name or wrapping the entire thing in quotes. We will look for
// the index of "Image:" and use everything after that as the prompt.
const index = prompt?.toLowerCase().indexOf("image:");
if (index === -1 || !prompt) {
throw new Error(
`Start your prompt with 'Image:' followed by a description of the image you want to generate (received: ${prompt}).`,
);
}
// TODO: Add some way to specify parameters via chat message
const transformed = {
model: body.model.includes("dall-e") ? body.model : "dall-e-3",
quality: "standard",
size: "1024x1024",
response_format: "url",
prompt: prompt.slice(index! + 6).trim(),
};
return OpenAIV1ImagesGenerationSchema.parse(transformed);
}
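// For example, a final user message of "Draw something. Image: a red fox in
// the snow" is transformed into (sketch):
// {
//   model: "dall-e-3", // unless the client already requested a dall-e model
//   quality: "standard",
//   size: "1024x1024",
//   response_format: "url",
//   prompt: "a red fox in the snow",
// }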
+56
@@ -0,0 +1,56 @@
import { z } from "zod";
import {
flattenOpenAIChatMessages,
OpenAIV1ChatCompletionSchema,
} from "./openai";
import { Request } from "express";
export const OpenAIV1TextCompletionSchema = z
.object({
model: z
.string()
.max(100)
.regex(
/^gpt-3\.5-turbo-instruct/,
"Model must start with 'gpt-3.5-turbo-instruct'"
),
prompt: z.string({
required_error:
"No `prompt` found. Ensure you've set the correct completion endpoint.",
}),
logprobs: z.number().int().nullish().default(null),
echo: z.boolean().optional().default(false),
best_of: z.literal(1).optional(),
stop: z
.union([z.string().max(500), z.array(z.string().max(500)).max(4)])
.optional(),
suffix: z.string().max(1000).optional(),
})
.strip()
.merge(OpenAIV1ChatCompletionSchema.omit({ messages: true, logprobs: true }));
export function openAIToOpenAIText(req: Request) {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-OpenAI-text request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
? rest.stop
: [rest.stop]
: [];
stops.push("\n\nUser:");
stops = [...new Set(stops)];
const transformed = { ...rest, prompt: prompt, stop: stops };
return OpenAIV1TextCompletionSchema.parse(transformed);
}
+133
@@ -0,0 +1,133 @@
import { z } from "zod";
import { config } from "../../config";
export const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatContentArraySchema = z.array(
z.union([
z.object({ type: z.literal("text"), text: z.string() }),
z.object({
type: z.union([z.literal("image"), z.literal("image_url")]),
image_url: z.object({
url: z.string().url(),
detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
}),
}),
])
);
export const OpenAIV1ChatCompletionSchema = z
.object({
model: z.string().max(100),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant", "tool", "function"]),
content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
name: z.string().optional(),
tool_calls: z.array(z.any()).optional(),
function_call: z.array(z.any()).optional(),
tool_call_id: z.string().optional(),
}),
{
required_error:
"No `messages` found. Ensure you've set the correct completion endpoint.",
invalid_type_error:
"Messages were not formatted correctly. Refer to the OpenAI Chat API documentation for more information.",
}
),
temperature: z.number().optional().default(1),
top_p: z.number().optional().default(1),
n: z
.literal(1, {
errorMap: () => ({
message: "You may only request a single completion at a time.",
}),
})
.optional(),
stream: z.boolean().optional().default(false),
stop: z
.union([z.string().max(500), z.array(z.string().max(500))])
.optional(),
max_tokens: z.coerce
.number()
.int()
.nullish()
.default(16)
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
logit_bias: z.any().optional(),
user: z.string().max(500).optional(),
seed: z.number().int().optional(),
// Be warned that Azure OpenAI combines these two into a single field.
// It's the only deviation from the OpenAI API that I'm aware of so I have
// special cased it in `addAzureKey` rather than expecting clients to do it.
logprobs: z.boolean().optional(),
top_logprobs: z.number().int().optional(),
// Quickly adding some newer tool usage params, not tested. They will be
// passed through to the API as-is.
tools: z.array(z.any()).optional(),
functions: z.array(z.any()).optional(),
tool_choice: z.any().optional(),
function_choice: z.any().optional(),
response_format: z.any(),
})
// Tool usage must be enabled via config because we currently have no way to
// track quota usage for them or enforce limits.
.omit(
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, functions: true }
)
.strip();
export type OpenAIChatMessage = z.infer<
typeof OpenAIV1ChatCompletionSchema
>["messages"][0];
export function flattenOpenAIMessageContent(
content: OpenAIChatMessage["content"]
): string {
return Array.isArray(content)
? content
.map((contentItem) => {
if ("text" in contentItem) return contentItem.text;
if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
})
.join("\n")
: content;
}
export function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
// Temporary to allow experimenting with prompt strategies
const PROMPT_VERSION: number = 1;
switch (PROMPT_VERSION) {
case 1:
return (
messages
.map((m) => {
// Claude-style human/assistant turns
let role: string = m.role;
if (role === "assistant") {
role = "Assistant";
} else if (role === "system") {
role = "System";
} else if (role === "user") {
role = "User";
}
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
})
.join("") + "\n\nAssistant:"
);
case 2:
return messages
.map((m) => {
// Claude without prefixes (except system) and no Assistant priming
let role: string = "";
if (role === "system") {
role = "System: ";
}
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
})
.join("");
default:
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
}
}
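// For example, with PROMPT_VERSION 1 the messages
//   [{ role: "system", content: "Be brief." },
//    { role: "user", content: "Hi" }]
// flatten to:
//   "\n\nSystem: Be brief.\n\nUser: Hi\n\nAssistant:"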
+5 -3
@@ -1,8 +1,10 @@
// noinspection JSUnusedGlobalSymbols,ES6UnusedImports
import type { HttpRequest } from "@smithy/types";
import { Express } from "express-serve-static-core";
import { APIFormat, Key, LLMService } from "../shared/key-management";
import { User } from "../shared/users/schema";
import { ModelFamily } from "../shared/models";
import { APIFormat, Key } from "./key-management";
import { User } from "./users/schema";
import { LLMService, ModelFamily } from "./models";
declare global {
namespace Express {
@@ -48,20 +48,20 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
protected handleAxiosError(key: AnthropicKey, error: AxiosError) {
if (error.response && AnthropicKeyChecker.errorIsAnthropicAPIError(error)) {
const { status, data } = error.response;
if (status === 401) {
if (status === 401 || status === 403) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
} else if (status === 429) {
}
else if (status === 429) {
switch (data.error.type) {
case "rate_limit_error":
this.log.warn(
{ key: key.hash, error: error.message },
"Key is rate limited. Rechecking in 10 seconds."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
this.updateKey(key.hash, { lastChecked: next });
break;
+41 -26
@@ -36,34 +36,10 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
protected async testKeyOrFail(key: AzureOpenAIKey) {
const model = await this.testModel(key);
this.log.info(
{ key: key.hash, deploymentModel: model },
"Checked key."
);
this.log.info({ key: key.hash, deploymentModel: model }, "Checked key.");
this.updateKey(key.hash, { modelFamilies: [model] });
}
// provided api-key header isn't valid (401)
// {
// "error": {
// "code": "401",
// "message": "Access denied due to invalid subscription key or wrong API endpoint. Make sure to provide a valid key for an active subscription and use a correct regional API endpoint for your resource."
// }
// }
// api key correct but deployment id is wrong (404)
// {
// "error": {
// "code": "DeploymentNotFound",
// "message": "The API deployment for this resource does not exist. If you created the deployment within the last 5 minutes, please wait a moment and try again."
// }
// }
// resource name is wrong (node will throw ENOTFOUND)
// rate limited (429)
// TODO: try to reproduce this
protected handleAxiosError(key: AzureOpenAIKey, error: AxiosError) {
if (error.response && AzureOpenAIKeyChecker.errorIsAzureError(error)) {
const data = error.response.data;
@@ -88,6 +64,20 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
isDisabled: true,
isRevoked: true,
});
case "429":
this.log.warn(
{ key: key.hash, errorType, error: error.response.data },
"Key is rate limited. Rechecking key in 1 minute."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
setTimeout(async () => {
this.log.info(
{ key: key.hash },
"Rechecking Azure key after rate limit."
);
await this.checkKey(key);
}, 1000 * 60);
return;
default:
this.log.error(
{ key: key.hash, errorType, error: error.response.data, status },
@@ -129,7 +119,32 @@ export class AzureOpenAIKeyChecker extends KeyCheckerBase<AzureOpenAIKey> {
headers: { "Content-Type": "application/json", "api-key": apiKey },
});
return getAzureOpenAIModelFamily(data.model);
const family = getAzureOpenAIModelFamily(data.model);
// Azure returns "gpt-4" even for GPT-4 Turbo, so we need further checks.
// Otherwise we can use the model family Azure returned.
if (family !== "azure-gpt4") {
return family;
}
// Probe the context size by requesting more output tokens (9000) than base
// GPT-4's 8k context window allows. GPT-4 Turbo's larger context accepts
// this, but base GPT-4 returns a Bad Request with `context_length_exceeded`.
const contextText = {
max_tokens: 9000,
stream: false,
temperature: 0,
seed: 0,
messages: [{ role: "user", content: "" }],
};
const { data: contextTest, status } = await axios.post(url, contextText, {
headers: { "Content-Type": "application/json", "api-key": apiKey },
validateStatus: (status) => status === 400 || status === 200,
});
const code = contextTest.error?.code;
this.log.debug({ code, status }, "Performed Azure GPT4 context size test.");
if (code === "context_length_exceeded") return "azure-gpt4";
return "azure-gpt4-turbo";
}
static errorIsAzureError(error: AxiosError): error is AxiosError<AzureError> {
@@ -6,7 +6,6 @@ import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
@@ -2,13 +2,17 @@ import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import type { GoogleAIModelFamily } from "../../models";
// https://developers.generativeai.google.com/models/language
export type GooglePalmModel = "text-bison-001";
// Note that Google AI is not the same as Vertex AI, both are provided by Google
// but Vertex is the GCP product for enterprise. while Google AI is the
// consumer-ish product. The API is different, and keys are not compatible.
// https://ai.google.dev/docs/migrate_to_cloud
export type GooglePalmKeyUpdate = Omit<
Partial<GooglePalmKey>,
export type GoogleAIModel = "gemini-pro";
export type GoogleAIKeyUpdate = Omit<
Partial<GoogleAIKey>,
| "key"
| "hash"
| "lastUsed"
@@ -17,13 +21,13 @@ export type GooglePalmKeyUpdate = Omit<
| "rateLimitedUntil"
>;
type GooglePalmKeyUsage = {
[K in GooglePalmModelFamily as `${K}Tokens`]: number;
type GoogleAIKeyUsage = {
[K in GoogleAIModelFamily as `${K}Tokens`]: number;
};
export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
readonly service: "google-palm";
readonly modelFamilies: GooglePalmModelFamily[];
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
readonly service: "google-ai";
readonly modelFamilies: GoogleAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
@@ -42,27 +46,27 @@ const RATE_LIMIT_LOCKOUT = 2000;
*/
const KEY_REUSE_DELAY = 500;
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
readonly service = "google-palm";
export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
readonly service = "google-ai";
private keys: GooglePalmKey[] = [];
private keys: GoogleAIKey[] = [];
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.googlePalmKey?.trim();
const keyConfig = config.googleAIKey?.trim();
if (!keyConfig) {
this.log.warn(
"GOOGLE_PALM_KEY is not set. PaLM API will not be available."
"GOOGLE_AI_KEY is not set. Google AI API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: GooglePalmKey = {
const newKey: GoogleAIKey = {
key,
service: this.service,
modelFamilies: ["bison"],
modelFamilies: ["gemini-pro"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
@@ -75,11 +79,11 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
"gemini-proTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
}
public init() {}
@@ -88,10 +92,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: GooglePalmModel) {
public get(_model: GoogleAIModel) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
throw new Error("No Google PaLM keys available");
throw new Error("No Google AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
@@ -122,14 +126,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return { ...selectedKey };
}
public disable(key: GooglePalmKey) {
public disable(key: GoogleAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<GooglePalmKey>) {
public update(hash: string, update: Partial<GoogleAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
@@ -142,7 +146,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key.bisonTokens += tokens;
key["gemini-proTokens"] += tokens;
}
public getLockoutPeriod() {
+6 -12
@@ -1,29 +1,23 @@
import type { LLMService, ModelFamily } from "../models";
import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { GoogleAIModel } from "./google-ai/provider";
import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool";
import type { ModelFamily } from "../models";
/** The request and response format used by a model's API. */
export type APIFormat =
| "openai"
| "anthropic"
| "google-palm"
| "google-ai"
| "mistral-ai"
| "openai-text"
| "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService =
| "openai"
| "anthropic"
| "google-palm"
| "aws"
| "azure";
export type Model =
| OpenAIModel
| AnthropicModel
| GooglePalmModel
| GoogleAIModel
| AwsBedrockModel
| AzureOpenAIModel;
@@ -77,6 +71,6 @@ export interface KeyProvider<T extends Key = Key> {
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { GoogleAIKey } from "././google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider";
+14 -56
@@ -4,14 +4,19 @@ import os from "os";
import schedule from "node-schedule";
import { config } from "../../config";
import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index";
import {
getServiceForModel,
LLMService,
MODEL_FAMILY_SERVICE,
ModelFamily,
} from "../models";
import { Key, KeyProvider, Model } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
import { AzureOpenAIKeyProvider } from "./azure/provider";
import { MistralAIKeyProvider } from "./mistral-ai/provider";
type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate;
@@ -24,7 +29,8 @@ export class KeyPool {
constructor() {
this.keyProviders.push(new OpenAIKeyProvider());
this.keyProviders.push(new AnthropicKeyProvider());
this.keyProviders.push(new GooglePalmKeyProvider());
this.keyProviders.push(new GoogleAIKeyProvider());
this.keyProviders.push(new MistralAIKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider());
}
@@ -41,7 +47,7 @@ export class KeyPool {
}
public get(model: Model): Key {
const service = this.getServiceForModel(model);
const service = getServiceForModel(model);
return this.getKeyProvider(service).get(model);
}
@@ -71,7 +77,7 @@ export class KeyPool {
public available(model: Model | "all" = "all"): number {
return this.keyProviders.reduce((sum, provider) => {
const includeProvider =
model === "all" || this.getServiceForModel(model) === provider.service;
model === "all" || getServiceForModel(model) === provider.service;
return sum + (includeProvider ? provider.available() : 0);
}, 0);
}
@@ -82,7 +88,7 @@ export class KeyPool {
}
public getLockoutPeriod(family: ModelFamily): number {
const service = this.getServiceForModelFamily(family);
const service = MODEL_FAMILY_SERVICE[family];
return this.getKeyProvider(service).getLockoutPeriod(family);
}
@@ -108,54 +114,6 @@ export class KeyPool {
provider.recheck();
}
private getServiceForModel(model: Model): LLMService {
if (
model.startsWith("gpt") ||
model.startsWith("text-embedding-ada") ||
model.startsWith("dall-e")
) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
} else if (model.includes("bison")) {
// https://developers.generativeai.google.com/models/language
return "google-palm";
} else if (model.startsWith("anthropic.claude")) {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
return "aws";
} else if (model.startsWith("azure")) {
return "azure";
}
throw new Error(`Unknown service for model '${model}'`);
}
private getServiceForModelFamily(modelFamily: ModelFamily): LLMService {
switch (modelFamily) {
case "gpt4":
case "gpt4-32k":
case "gpt4-turbo":
case "turbo":
case "dall-e":
return "openai";
case "claude":
return "anthropic";
case "bison":
return "google-palm";
case "aws-claude":
return "aws";
case "azure-turbo":
case "azure-gpt4":
case "azure-gpt4-32k":
case "azure-gpt4-turbo":
return "azure";
default:
assertNever(modelFamily);
}
}
private getKeyProvider(service: LLMService): KeyProvider {
return this.keyProviders.find((provider) => provider.service === service)!;
}
@@ -0,0 +1,112 @@
import axios, { AxiosError } from "axios";
import type { MistralAIModelFamily, OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
import { getMistralAIModelFamily, getOpenAIModelFamily } from "../../models";
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
const GET_MODELS_URL = "https://api.mistral.ai/v1/models";
type GetModelsResponse = {
data: [{ id: string }];
};
type MistralAIError = {
message: string;
request_id: string;
};
type UpdateFn = typeof MistralAIKeyProvider.prototype.update;
export class MistralAIKeyChecker extends KeyCheckerBase<MistralAIKey> {
constructor(keys: MistralAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "mistral-ai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}
protected async testKeyOrFail(key: MistralAIKey) {
// We only need to check for provisioned models on the initial check.
const isInitialCheck = !key.lastChecked;
if (isInitialCheck) {
const provisionedModels = await this.getProvisionedModels(key);
const updates = {
modelFamilies: provisionedModels,
};
this.updateKey(key.hash, updates);
}
this.log.info({ key: key.hash, models: key.modelFamilies }, "Checked key.");
}
private async getProvisionedModels(
key: MistralAIKey
): Promise<MistralAIModelFamily[]> {
const opts = { headers: MistralAIKeyChecker.getHeaders(key) };
const { data } = await axios.get<GetModelsResponse>(GET_MODELS_URL, opts);
const models = data.data;
const families = new Set<MistralAIModelFamily>();
models.forEach(({ id }) => families.add(getMistralAIModelFamily(id)));
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.
const familiesArray = [...families];
const keyFromPool = this.keys.find((k) => k.hash === key.hash)!;
this.updateKey(key.hash, {
modelFamilies: familiesArray,
lastChecked: keyFromPool.lastChecked,
});
return familiesArray;
}
protected handleAxiosError(key: MistralAIKey, error: AxiosError) {
if (error.response && MistralAIKeyChecker.errorIsMistralAIError(error)) {
const { status, data } = error.response;
if (status === 401) {
this.log.warn(
{ key: key.hash, error: data },
"Key is invalid or revoked. Disabling key."
);
this.updateKey(key.hash, {
isDisabled: true,
isRevoked: true,
modelFamilies: ["mistral-tiny"],
});
} else {
this.log.error(
{ key: key.hash, status, error: data },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
this.updateKey(key.hash, { lastChecked: Date.now() });
}
return;
}
this.log.error(
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 60 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
this.updateKey(key.hash, { lastChecked: next });
}
static errorIsMistralAIError(
error: AxiosError
): error is AxiosError<MistralAIError> {
const data = error.response?.data as any;
return data?.message && data?.request_id;
}
static getHeaders(key: MistralAIKey) {
return {
Authorization: `Bearer ${key.key}`,
};
}
}
@@ -0,0 +1,210 @@
import crypto from "crypto";
import { Key, KeyProvider, Model } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { MistralAIKeyChecker } from "./checker";
export type MistralAIModel =
| "mistral-tiny"
| "mistral-small"
| "mistral-medium";
export type MistralAIKeyUpdate = Omit<
Partial<MistralAIKey>,
| "key"
| "hash"
| "lastUsed"
| "promptCount"
| "rateLimitedAt"
| "rateLimitedUntil"
>;
type MistralAIKeyUsage = {
[K in MistralAIModelFamily as `${K}Tokens`]: number;
};
export interface MistralAIKey extends Key, MistralAIKeyUsage {
readonly service: "mistral-ai";
readonly modelFamilies: MistralAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
}
/**
* Upon being rate limited, a key will be locked out for this many milliseconds
* while we wait for other concurrent requests to finish.
*/
const RATE_LIMIT_LOCKOUT = 2000;
/**
* Upon assigning a key, we will wait this many milliseconds before allowing it
* to be used again. This is to prevent the queue from flooding a key with too
* many requests while we wait to learn whether previous ones succeeded.
*/
const KEY_REUSE_DELAY = 500;
export class MistralAIKeyProvider implements KeyProvider<MistralAIKey> {
readonly service = "mistral-ai";
private keys: MistralAIKey[] = [];
private checker?: MistralAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.mistralAIKey?.trim();
if (!keyConfig) {
this.log.warn(
"MISTRAL_AI_KEY is not set. Mistral AI API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: MistralAIKey = {
key,
service: this.service,
modelFamilies: ["mistral-tiny", "mistral-small", "mistral-medium"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
lastUsed: 0,
rateLimitedAt: 0,
rateLimitedUntil: 0,
hash: `mst-${crypto
.createHash("sha256")
.update(key)
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"mistral-tinyTokens": 0,
"mistral-smallTokens": 0,
"mistral-mediumTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Mistral AI keys.");
}
public init() {
if (config.checkKeys) {
const updateFn = this.update.bind(this);
this.checker = new MistralAIKeyChecker(this.keys, updateFn);
this.checker.start();
}
}
public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: Model) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
throw new Error("No Mistral AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
// Select a key, from highest priority to lowest priority:
// 1. Keys which are not rate limited
// a. If all keys were rate limited recently, select the least-recently
// rate limited key.
// 2. Keys which have not been used in the longest time
const now = Date.now();
const keysByPriority = availableKeys.sort((a, b) => {
const aRateLimited = now - a.rateLimitedAt < RATE_LIMIT_LOCKOUT;
const bRateLimited = now - b.rateLimitedAt < RATE_LIMIT_LOCKOUT;
if (aRateLimited && !bRateLimited) return 1;
if (!aRateLimited && bRateLimited) return -1;
if (aRateLimited && bRateLimited) {
return a.rateLimitedAt - b.rateLimitedAt;
}
return a.lastUsed - b.lastUsed;
});
const selectedKey = keysByPriority[0];
selectedKey.lastUsed = now;
this.throttle(selectedKey.hash);
return { ...selectedKey };
}
public disable(key: MistralAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<MistralAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
public available() {
return this.keys.filter((k) => !k.isDisabled).length;
}
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
const family = getMistralAIModelFamily(model);
key[`${family}Tokens`] += tokens;
}
public getLockoutPeriod() {
const activeKeys = this.keys.filter((k) => !k.isDisabled);
// Don't lock out if there are no active keys, otherwise the queue will
// stall. Just let the request through so the add-key middleware can throw
// an error.
if (activeKeys.length === 0) return 0;
const now = Date.now();
const rateLimitedKeys = activeKeys.filter((k) => now < k.rateLimitedUntil);
const anyNotRateLimited = rateLimitedKeys.length < activeKeys.length;
if (anyNotRateLimited) return 0;
// If all keys are rate-limited, return the time until the first key is
// ready.
return Math.min(...activeKeys.map((k) => k.rateLimitedUntil - now));
}
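// Worked example: if three active keys are rate-limited until now+500ms,
// now+1500ms, and now+3000ms, every key is limited, so the queue is told to
// wait Math.min(500, 1500, 3000) = 500ms before retrying.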
/**
* This is called when we receive a 429, which means there are already five
* concurrent requests running on this key. We don't have any information on
* when these requests will resolve, so all we can do is wait a bit and try
* again. We will lock the key for 2 seconds after getting a 429 before
* retrying in order to give the other requests a chance to finish.
*/
public markRateLimited(keyHash: string) {
this.log.debug({ key: keyHash }, "Key rate limited");
const key = this.keys.find((k) => k.hash === keyHash)!;
const now = Date.now();
key.rateLimitedAt = now;
key.rateLimitedUntil = now + RATE_LIMIT_LOCKOUT;
}
public recheck() {}
/**
* Applies a short artificial delay to the key upon dequeueing, in order to
* prevent it from being immediately assigned to another request before the
* current one can be dispatched.
**/
private throttle(hash: string) {
const now = Date.now();
const key = this.keys.find((k) => k.hash === hash)!;
const currentRateLimit = key.rateLimitedUntil;
const nextRateLimit = now + KEY_REUSE_DELAY;
key.rateLimitedAt = now;
key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit);
}
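// Worked example: a key dequeued at time t with no outstanding rate limit
// becomes available again at t + 500ms (KEY_REUSE_DELAY); if it was already
// rate-limited until t + 2000ms, that later deadline wins via Math.max.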
}
@@ -73,6 +73,12 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
const families = new Set<OpenAIModelFamily>();
models.forEach(({ id }) => families.add(getOpenAIModelFamily(id, "turbo")));
// disable dall-e for trial keys due to very low per-day quota that tends to
// render the key unusable.
if (key.isTrial) {
families.delete("dall-e");
}
// as of 2023-11-18, many keys no longer return the dalle3 model but still
// have access to it via the api for whatever reason.
@@ -80,6 +86,15 @@ export class OpenAIKeyChecker extends KeyCheckerBase<OpenAIKey> {
// families.delete("dall-e");
// }
// as of 2024-01-10, the models endpoint has a bug and sometimes returns the
// gpt-4-32k-0314 snapshot even though the key doesn't have access to
// base gpt-4-32k. we will ignore this model if the snapshot is returned
// without the base model.
const has32k = models.find(({ id }) => id === "gpt-4-32k");
if (families.has("gpt4-32k") && !has32k) {
families.delete("gpt4-32k");
}
// We want to update the key's model families here, but we don't want to
// update its `lastChecked` timestamp because we need to let the liveness
// check run before we can consider the key checked.
+108 -13
@@ -1,8 +1,20 @@
// Don't import anything here, this is imported by config.ts
// Don't import any other project files here as this is one of the first modules
// loaded and it will cause circular imports.
import pino from "pino";
import type { Request } from "express";
import { assertNever } from "./utils";
/**
* The service that a model is hosted on. Distinct from `APIFormat` because some
* services have interoperable APIs (eg Anthropic/AWS, OpenAI/Azure).
*/
export type LLMService =
| "openai"
| "anthropic"
| "google-ai"
| "mistral-ai"
| "aws"
| "azure";
export type OpenAIModelFamily =
| "turbo"
@@ -11,7 +23,11 @@ export type OpenAIModelFamily =
| "gpt4-turbo"
| "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type GoogleAIModelFamily = "gemini-pro";
export type MistralAIModelFamily =
| "mistral-tiny"
| "mistral-small"
| "mistral-medium";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
OpenAIModelFamily,
@@ -20,7 +36,8 @@ export type AzureOpenAIModelFamily = `azure-${Exclude<
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GooglePalmModelFamily
| GoogleAIModelFamily
| MistralAIModelFamily
| AwsBedrockModelFamily
| AzureOpenAIModelFamily;
@@ -33,7 +50,10 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"gpt4-turbo",
"dall-e",
"claude",
"bison",
"gemini-pro",
"mistral-tiny",
"mistral-small",
"mistral-medium",
"aws-claude",
"azure-turbo",
"azure-gpt4",
@@ -41,8 +61,20 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"azure-gpt4-turbo",
] as const);
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
arr: A & ([LLMService] extends [A[number]] ? unknown : never)
) => arr)([
"openai",
"anthropic",
"google-ai",
"mistral-ai",
"aws",
"azure",
] as const);
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^gpt-4-1106(-preview)?$": "gpt4-turbo",
"^gpt-4-turbo(-preview)?$": "gpt4-turbo",
"^gpt-4-(0125|1106)(-preview)?$": "gpt4-turbo",
"^gpt-4(-\\d{4})?-vision(-preview)?$": "gpt4-turbo",
"^gpt-4-32k-\\d{4}$": "gpt4-32k",
"^gpt-4-32k$": "gpt4-32k",
@@ -53,7 +85,27 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^dall-e-\\d{1}$": "dall-e",
};
const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
export const MODEL_FAMILY_SERVICE: {
[f in ModelFamily]: LLMService;
} = {
turbo: "openai",
gpt4: "openai",
"gpt4-turbo": "openai",
"gpt4-32k": "openai",
"dall-e": "openai",
claude: "anthropic",
"aws-claude": "aws",
"azure-turbo": "azure",
"azure-gpt4": "azure",
"azure-gpt4-32k": "azure",
"azure-gpt4-turbo": "azure",
"gemini-pro": "google-ai",
"mistral-tiny": "mistral-ai",
"mistral-small": "mistral-ai",
"mistral-medium": "mistral-ai",
};
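// Usage sketch: this map gives a constant-time family-to-service lookup,
// replacing the switch previously used by KeyPool.getLockoutPeriod, e.g.:
//   const service = MODEL_FAMILY_SERVICE["azure-gpt4-turbo"]; // "azure"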
pino({ level: "debug" }).child({ module: "startup" });
export function getOpenAIModelFamily(
model: string,
@@ -70,10 +122,19 @@ export function getClaudeModelFamily(model: string): ModelFamily {
return "claude";
}
export function getGooglePalmModelFamily(model: string): ModelFamily {
if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
modelLogger.warn({ model }, "Could not determine Google PaLM model family");
return "bison";
export function getGoogleAIModelFamily(_model: string): ModelFamily {
return "gemini-pro";
}
export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
switch (model) {
case "mistral-tiny":
case "mistral-small":
case "mistral-medium":
return model;
default:
return "mistral-tiny";
}
}
export function getAwsBedrockModelFamily(_model: string): ModelFamily {
@@ -130,8 +191,11 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
case "openai-image":
modelFamily = getOpenAIModelFamily(model);
break;
case "google-palm":
modelFamily = getGooglePalmModelFamily(model);
case "google-ai":
modelFamily = getGoogleAIModelFamily(model);
break;
case "mistral-ai":
modelFamily = getMistralAIModelFamily(model);
break;
default:
assertNever(req.outboundApi);
@@ -140,3 +204,34 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
return (req.modelFamily = modelFamily);
}
export function getServiceForModel(model: string): LLMService {
if (
model.startsWith("gpt") ||
model.startsWith("text-embedding-ada") ||
model.startsWith("dall-e")
) {
// https://platform.openai.com/docs/models/model-endpoint-compatibility
return "openai";
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
} else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language
return "google-ai";
} else if (model.includes("mistral")) {
// https://docs.mistral.ai/platform/endpoints
return "mistral-ai";
} else if (model.startsWith("anthropic.claude")) {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
return "aws";
} else if (model.startsWith("azure")) {
return "azure";
}
throw new Error(`Unknown service for model '${model}'`);
}
function assertNever(x: never): never {
throw new Error(`Called assertNever with argument ${x}.`);
}
+15
@@ -1,3 +1,4 @@
import { config } from "../config";
import { ModelFamily } from "./models";
// technically slightly underestimates, because completion tokens cost more
@@ -24,6 +25,15 @@ export function getTokenCostUsd(model: ModelFamily, tokens: number) {
case "claude":
cost = 0.00001102;
break;
case "mistral-tiny":
cost = 0.00000031;
break;
case "mistral-small":
cost = 0.00000132;
break;
case "mistral-medium":
cost = 0.0000055;
break;
}
return cost * Math.max(0, tokens);
}
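// Worked example: 1,000,000 mistral-medium tokens cost
// 0.0000055 * 1,000,000 = $5.50 (a slight underestimate, since completion
// tokens are priced higher than the prompt rate used here).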
@@ -40,3 +50,8 @@ export function prettyTokens(tokens: number): string {
return (tokens / 1000000000).toFixed(3) + "b";
}
}
export function getCostSuffix(cost: number) {
if (!config.showTokenCosts) return "";
return ` ($${cost.toFixed(2)})`;
}
+55 -24
@@ -1,6 +1,7 @@
import { Request, Response } from "express";
import { Response } from "express";
import { IncomingMessage } from "http";
import { assertNever } from "./utils";
import { APIFormat } from "./key-management";
export function initializeSseStream(res: Response) {
res.statusCode = 200;
@@ -39,54 +40,84 @@ export function copySseResponseHeaders(
* that the request is being proxied to. Used to send error messages to the
* client in the middle of a streaming request.
*/
export function buildFakeSse(type: string, string: string, req: Request) {
let fakeEvent;
const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
export function makeCompletionSSE({
format,
title,
message,
obj,
reqId,
model = "unknown",
}: {
format: APIFormat;
title: string;
message: string;
obj?: object;
reqId: string | number | object;
model?: string;
}) {
const id = String(reqId);
const content = `\n\n**${title}**\n${message}${
obj ? `\n\`\`\`\n${JSON.stringify(obj, null, 2)}\n\`\`\`\n` : ""
}`;
switch (req.inboundApi) {
let event;
switch (format) {
case "openai":
fakeEvent = {
id: "chatcmpl-" + req.id,
case "mistral-ai":
event = {
id: "chatcmpl-" + id,
object: "chat.completion.chunk",
created: Date.now(),
model: req.body?.model,
choices: [{ delta: { content }, index: 0, finish_reason: type }],
model,
choices: [{ delta: { content }, index: 0, finish_reason: title }],
};
break;
case "openai-text":
fakeEvent = {
id: "cmpl-" + req.id,
event = {
id: "cmpl-" + id,
object: "text_completion",
created: Date.now(),
choices: [
{ text: content, index: 0, logprobs: null, finish_reason: type },
{ text: content, index: 0, logprobs: null, finish_reason: title },
],
model: req.body?.model,
model,
};
break;
case "anthropic":
fakeEvent = {
event = {
completion: content,
stop_reason: type,
truncated: false, // I've never seen this be true
stop_reason: title,
truncated: false,
stop: null,
model: req.body?.model,
log_id: "proxy-req-" + req.id,
model,
log_id: "proxy-req-" + id,
};
break;
case "google-palm":
case "google-ai":
return JSON.stringify({
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
finishReason: title,
index: 0,
tokenCount: null,
safetyRatings: [],
},
],
});
case "openai-image":
throw new Error(`SSE not supported for ${req.inboundApi} requests`);
throw new Error(`SSE not supported for ${format} requests`);
default:
assertNever(req.inboundApi);
assertNever(format);
}
if (req.inboundApi === "anthropic") {
if (format === "anthropic") {
return (
["event: completion", `data: ${JSON.stringify(fakeEvent)}`].join("\n") +
["event: completion", `data: ${JSON.stringify(event)}`].join("\n") +
"\n\n"
);
}
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
return `data: ${JSON.stringify(event)}\n\n`;
}
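// A usage sketch, assuming an OpenAI-format client and an Express `res` in
// scope (the argument values here are illustrative):
//   res.write(
//     makeCompletionSSE({
//       format: "openai",
//       title: "Proxy error",
//       message: "The upstream request failed.",
//       reqId: "abc123",
//       model: "gpt-4",
//     })
//   );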
File diff suppressed because one or more lines are too long
+45
@@ -0,0 +1,45 @@
import * as tokenizer from "./mistral-tokenizer-js";
import { MistralAIChatMessage } from "../api-schemas";
export function init() {
tokenizer.initializemistralTokenizer();
return true;
}
export function getTokenCount(prompt: MistralAIChatMessage[] | string) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const chunks: string[] = [];
for (const message of prompt) {
switch (message.role) {
case "system":
chunks.push(message.content);
break;
case "assistant":
chunks.push(message.content + "</s>");
break;
case "user":
chunks.push("[INST] " + message.content + " [/INST]");
break;
}
}
return getTextTokenCount(chunks.join(" "));
}
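// For example, [{ role: "user", content: "Hi" },
// { role: "assistant", content: "Hello" }] is serialized as
// "[INST] Hi [/INST] Hello</s>" before tokenizing, approximating Mistral's
// instruct template.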
function getTextTokenCount(prompt: string) {
// Don't try tokenizing if the prompt is massive to prevent DoS.
// 500k characters should be sufficient for all supported models.
if (prompt.length > 500000) {
return {
tokenizer: "length fallback",
token_count: 100000,
};
}
return {
tokenizer: "mistral-tokenizer-js",
token_count: tokenizer.encode(prompt.normalize("NFKC"))!.length,
};
}
+28 -5
@@ -2,7 +2,7 @@ import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { GoogleAIChatMessage, OpenAIChatMessage } from "../api-schemas";
const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
@@ -29,11 +29,11 @@ export async function getTokenCount(
return getTextTokenCount(prompt);
}
const gpt4 = model.startsWith("gpt-4");
const oldFormatting = model.startsWith("turbo-0301");
const vision = model.includes("vision");
const tokensPerMessage = gpt4 ? 3 : 4;
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
const tokensPerMessage = oldFormatting ? 4 : 3;
const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
@@ -50,7 +50,7 @@ export async function getTokenCount(
for (const item of value) {
if (item.type === "text") {
textContent += item.text;
} else if (item.type === "image_url") {
} else if (["image", "image_url"].includes(item.type)) {
const { url, detail } = item.image_url;
const cost = await getGpt4VisionTokenCost(url, detail);
numTokens += cost ?? 0;
@@ -228,3 +228,26 @@ export function getOpenAIImageCost(params: {
token_count: Math.ceil(tokens),
};
}
export function estimateGoogleAITokenCount(
prompt: string | GoogleAIChatMessage[]
) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const tokensPerMessage = 3;
let numTokens = 0;
for (const message of prompt) {
numTokens += tokensPerMessage;
numTokens += encoder.encode(message.parts[0].text).length;
}
numTokens += 3;
return {
tokenizer: "tiktoken (google-ai estimate)",
token_count: numTokens,
};
}
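// Worked example: a two-message prompt whose parts encode to 10 and 7 tokens
// is estimated at (3 + 10) + (3 + 7) + 3 = 26 tokens.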
+29 -8
@@ -1,20 +1,30 @@
import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils";
import {
init as initClaude,
getTokenCount as getClaudeTokenCount,
init as initClaude,
} from "./claude";
import {
init as initOpenAi,
getTokenCount as getOpenAITokenCount,
estimateGoogleAITokenCount,
getOpenAIImageCost,
getTokenCount as getOpenAITokenCount,
init as initOpenAi,
} from "./openai";
import {
getTokenCount as getMistralAITokenCount,
init as initMistralAI,
} from "./mistral";
import { APIFormat } from "../key-management";
import {
GoogleAIChatMessage,
MistralAIChatMessage,
OpenAIChatMessage,
} from "../api-schemas";
export async function init() {
initClaude();
initOpenAi();
initMistralAI();
}
/** Tagged union via `service` field of the different types of requests that can
@@ -24,7 +34,13 @@ type TokenCountRequest = { req: Request } & (
| {
prompt: string;
completion?: never;
service: "openai-text" | "anthropic" | "google-palm";
service: "openai-text" | "anthropic" | "google-ai";
}
| { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
| {
prompt: MistralAIChatMessage[];
completion?: never;
service: "mistral-ai";
}
| { prompt?: never; completion: string; service: APIFormat }
| { prompt?: never; completion?: never; service: "openai-image" }
@@ -65,11 +81,16 @@ export async function countTokens({
}),
tokenization_duration_ms: getElapsedMs(time),
};
case "google-palm":
// TODO: Can't find a tokenization library for PaLM. There is an API
case "google-ai":
// TODO: Can't find a tokenization library for Gemini. There is an API
// endpoint for it but it adds significant latency to the request.
return {
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
...estimateGoogleAITokenCount(prompt ?? (completion || [])),
tokenization_duration_ms: getElapsedMs(time),
};
case "mistral-ai":
return {
...getMistralAITokenCount(prompt ?? completion),
tokenization_duration_ms: getElapsedMs(time),
};
default:
+1 -1
@@ -9,7 +9,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
"gpt4-turbo": z.number().optional().default(0),
"dall-e": z.number().optional().default(0),
claude: z.number().optional().default(0),
bison: z.number().optional().default(0),
"gemini-pro": z.number().optional().default(0),
"aws-claude": z.number().optional().default(0),
});
+10 -4
@@ -14,7 +14,8 @@ import { config, getFirebaseApp } from "../../config";
import {
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGooglePalmModelFamily,
getGoogleAIModelFamily,
getMistralAIModelFamily,
getOpenAIModelFamily,
MODEL_FAMILIES,
ModelFamily,
@@ -33,7 +34,10 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
"gpt4-turbo": 0,
"dall-e": 0,
claude: 0,
bison: 0,
"gemini-pro": 0,
"mistral-tiny": 0,
"mistral-small": 0,
"mistral-medium": 0,
"aws-claude": 0,
"azure-turbo": 0,
"azure-gpt4": 0,
@@ -397,8 +401,10 @@ function getModelFamilyForQuotaUsage(
return getOpenAIModelFamily(model);
case "anthropic":
return getClaudeModelFamily(model);
case "google-palm":
return getGooglePalmModelFamily(model);
case "google-ai":
return getGoogleAIModelFamily(model);
case "mistral-ai":
return getMistralAIModelFamily(model);
default:
assertNever(api);
}
+1 -1
@@ -15,5 +15,5 @@
},
"include": ["src"],
"exclude": ["node_modules"],
"files": ["src/types/custom.d.ts"]
"files": ["src/shared/custom.d.ts"]
}