Add configurable network interface or SOCKS/HTTP proxy for outgoing requests (khanon/oai-reverse-proxy!80)

This commit is contained in:
khanon
2024-09-16 15:17:57 +00:00
parent 6e97e036b2
commit d21e274358
55 changed files with 1983 additions and 920 deletions
+67 -56
View File
@@ -8,6 +8,9 @@
# Use production mode unless you are developing locally. # Use production mode unless you are developing locally.
NODE_ENV=production NODE_ENV=production
# Detail level of diagnostic logging. (trace | debug | info | warn | error)
# LOG_LEVEL=info
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# General settings: # General settings:
@@ -24,23 +27,22 @@ NODE_ENV=production
# Max number of context tokens a user can request at once. # Max number of context tokens a user can request at once.
# Increase this if your proxy allow GPT 32k or 128k context # Increase this if your proxy allow GPT 32k or 128k context
# MAX_CONTEXT_TOKENS_OPENAI=16384 # MAX_CONTEXT_TOKENS_OPENAI=32768
# MAX_CONTEXT_TOKENS_ANTHROPIC=32768
# Max number of output tokens a user can request at once. # Max number of output tokens a user can request at once.
# MAX_OUTPUT_TOKENS_OPENAI=400 # MAX_OUTPUT_TOKENS_OPENAI=1024
# MAX_OUTPUT_TOKENS_ANTHROPIC=400 # MAX_OUTPUT_TOKENS_ANTHROPIC=1024
# Whether to show the estimated cost of consumed tokens on the info page. # Whether to show the estimated cost of consumed tokens on the info page.
# SHOW_TOKEN_COSTS=false # SHOW_TOKEN_COSTS=false
# Whether to automatically check API keys for validity. # Whether to automatically check API keys for validity.
# Note: CHECK_KEYS is disabled by default in local development mode, but enabled # Disabled by default in local development mode, but enabled in production.
# by default in production mode.
# CHECK_KEYS=true # CHECK_KEYS=true
# Which model types users are allowed to access. # Which model types users are allowed to access.
# The following model families are recognized: # The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | o1 | dall-e | claude # turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | o1 | dall-e | claude
# | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | # | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny |
# | mistral-small | mistral-medium | mistral-large | aws-claude | # | mistral-small | mistral-medium | mistral-large | aws-claude |
@@ -60,6 +62,42 @@ NODE_ENV=production
# By default, no image services are allowed and image prompts are rejected. # By default, no image services are allowed and image prompts are rejected.
# ALLOWED_VISION_SERVICES= # ALLOWED_VISION_SERVICES=
# Whether prompts should be logged to Google Sheets.
# Requires additional setup. See `docs/google-sheets.md` for more information.
# PROMPT_LOGGING=false
# Specifies the number of proxies or load balancers in front of the server.
# For Cloudflare or Hugging Face deployments, the default of 1 is correct.
# For any other deployments, please see config.ts as the correct configuration
# depends on your setup. Misconfiguring this value can result in problems
# accurately tracking IP addresses and enforcing rate limits.
# TRUSTED_PROXIES=1
# Whether cookies should be set without the Secure flag, for hosts that don't
# support SSL. True by default in development, false in production.
# USE_INSECURE_COOKIES=false
# Reorganizes requests in the queue according to their token count, placing
# larger prompts further back. The penalty is determined by (promptTokens *
# TOKENS_PUNISHMENT_FACTOR). A value of 1.0 adds one second per 1000 tokens.
# When there is no queue or it is very short, the effect is negligible (this
# setting only reorders the queue, it does not artificially delay requests).
# TOKENS_PUNISHMENT_FACTOR=0.0
# Captcha verification settings. Refer to docs/pow-captcha.md for guidance.
# CAPTCHA_MODE=none
# POW_TOKEN_HOURS=24
# POW_TOKEN_MAX_IPS=2
# POW_DIFFICULTY_LEVEL=low
# POW_CHALLENGE_TIMEOUT=30
# -------------------------------------------------------------------------------
# Blocking settings:
# Allows blocking requests depending on content, referers, or IP addresses.
# This is a convenience feature; if you need more robust functionality it is
# highly recommended to put this application behind nginx or Cloudflare, as they
# will have better performance.
# IP addresses or CIDR blocks from which requests will be blocked. # IP addresses or CIDR blocks from which requests will be blocked.
# IP_BLACKLIST=10.0.0.1/24 # IP_BLACKLIST=10.0.0.1/24
# URLs from which requests will be blocked. # URLs from which requests will be blocked.
@@ -68,35 +106,13 @@ NODE_ENV=production
# BLOCK_MESSAGE="You must be over the age of majority in your country to use this service." # BLOCK_MESSAGE="You must be over the age of majority in your country to use this service."
# Destination to redirect blocked requests to. # Destination to redirect blocked requests to.
# BLOCK_REDIRECT="https://roblox.com/" # BLOCK_REDIRECT="https://roblox.com/"
# Comma-separated list of phrases that will be rejected. Surround phrases with
# Comma-separated list of phrases that will be rejected. Only whole words are matched. # quotes if they contain commas. You can use regular expression tokens.
# Surround phrases with quotes if they contain commas. # Avoid overly broad phrases as will trigger on any match in the entire prompt.
# Avoid short or common phrases as this tests the entire prompt.
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four" # REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
# Message to show when requests are rejected. # Message to show when requests are rejected.
# REJECT_MESSAGE="You can't say that here." # REJECT_MESSAGE="You can't say that here."
# Whether prompts should be logged to Google Sheets.
# Requires additional setup. See `docs/google-sheets.md` for more information.
# PROMPT_LOGGING=false
# The port and network interface to listen on.
# PORT=7860
# BIND_ADDRESS=0.0.0.0
# Whether cookies should be set without the Secure flag, for hosts that don't support SSL.
# USE_INSECURE_COOKIES=false
# Detail level of logging. (trace | debug | info | warn | error)
# LOG_LEVEL=info
# Captcha verification settings. Refer to docs/pow-captcha.md for guidance.
# CAPTCHA_MODE=none
# POW_TOKEN_HOURS=24
# POW_TOKEN_MAX_IPS=2
# POW_DIFFICULTY_LEVEL=low
# POW_CHALLENGE_TIMEOUT=30
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Optional settings for user management, access control, and quota enforcement: # Optional settings for user management, access control, and quota enforcement:
# See `docs/user-management.md` for more information and setup instructions. # See `docs/user-management.md` for more information and setup instructions.
@@ -116,15 +132,8 @@ NODE_ENV=production
# ALLOW_NICKNAME_CHANGES=true # ALLOW_NICKNAME_CHANGES=true
# Default token quotas for each model family. (0 for unlimited) # Default token quotas for each model family. (0 for unlimited)
# Specify as TOKEN_QUOTA_MODEL_FAMILY=value, replacing dashes with underscores. # Specify as TOKEN_QUOTA_MODEL_FAMILY=value (replacing dashes with underscores).
# TOKEN_QUOTA_TURBO=0 # eg. TOKEN_QUOTA_TURBO=0, TOKEN_QUOTA_GPT4=1000000, TOKEN_QUOTA_GPT4_32K=100000
# TOKEN_QUOTA_GPT4=0
# TOKEN_QUOTA_GPT4_32K=0
# TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# TOKEN_QUOTA_GCP_CLAUDE=0
# "Tokens" for image-generation models are counted at a rate of 100000 tokens # "Tokens" for image-generation models are counted at a rate of 100000 tokens
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo. # per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens). # DALL-E 3 costs around US$0.10 per image (10000 tokens).
@@ -135,12 +144,22 @@ NODE_ENV=production
# Leave unset to never automatically refresh quotas. # Leave unset to never automatically refresh quotas.
# QUOTA_REFRESH_PERIOD=daily # QUOTA_REFRESH_PERIOD=daily
# Specifies the number of proxies or load balancers in front of the server. # -------------------------------------------------------------------------------
# For Cloudflare or Hugging Face deployments, the default of 1 is correct. # HTTP agent settings:
# For any other deployments, please see config.ts as the correct configuration # If you need to change how the proxy makes requests to other servers, such
# depends on your setup. Misconfiguring this value can result in problems # as when checking keys or forwarding users' requests to external services,
# accurately tracking IP addresses and enforcing rate limits. # you can configure an alternative HTTP agent. Otherwise the default OS settings
# TRUSTED_PROXIES=1 # will be used.
# The name of the network interface to use. The first external IPv4 address
# belonging to this interface will be used for outgoing requests.
# HTTP_AGENT_INTERFACE=enp0s3
# The URL of a proxy server to use. Supports SOCKS4, SOCKS5, HTTP, and HTTPS.
# Note that if your proxy server issues a self-signed certificate, you may need
# NODE_EXTRA_CA_CERTS set to the path to your certificate. You will need to set
# that variable in your environment, not in this file.
# HTTP_AGENT_PROXY_URL=http://test:test@127.0.0.1:8000
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Secrets and keys: # Secrets and keys:
@@ -164,11 +183,10 @@ GCP_CREDENTIALS=project-id:client-email:region:private-key
# With user_token gatekeeper, the admin password used to manage users. # With user_token gatekeeper, the admin password used to manage users.
# ADMIN_KEY=your-very-secret-key # ADMIN_KEY=your-very-secret-key
# To restrict access to the admin interface to specific IP addresses, set the # Restrict access to the admin interface to specific IP addresses, specified
# ADMIN_WHITELIST environment variable to a comma-separated list of CIDR blocks. # as a comma-separated list of CIDR ranges.
# ADMIN_WHITELIST=0.0.0.0/0 # ADMIN_WHITELIST=0.0.0.0/0
# With firebase_rtdb gatekeeper storage, the Firebase project credentials. # With firebase_rtdb gatekeeper storage, the Firebase project credentials.
# FIREBASE_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # FIREBASE_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# FIREBASE_RTDB_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.firebaseio.com # FIREBASE_RTDB_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.firebaseio.com
@@ -176,10 +194,3 @@ GCP_CREDENTIALS=project-id:client-email:region:private-key
# With prompt logging, the Google Sheets credentials. # With prompt logging, the Google Sheets credentials.
# GOOGLE_SHEETS_SPREADSHEET_ID=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # GOOGLE_SHEETS_SPREADSHEET_ID=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# GOOGLE_SHEETS_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx # GOOGLE_SHEETS_KEY=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# Prioritize requests in the queue according to their
# token count, placing larger requests further back.
#
# Punishes requests with a second's delay per 1k tokens
# when the value is 1.0, two seconds when it's 2, etc.
# TOKENS_PUNISHMENT_FACTOR=0.0
+918 -125
View File
File diff suppressed because it is too large Load Diff
+12 -7
View File
@@ -5,9 +5,11 @@
"scripts": { "scripts": {
"build": "tsc && copyfiles -u 1 src/**/*.ejs build", "build": "tsc && copyfiles -u 1 src/**/*.ejs build",
"database:migrate": "ts-node scripts/migrate.ts", "database:migrate": "ts-node scripts/migrate.ts",
"postinstall": "patch-package",
"prepare": "husky install", "prepare": "husky install",
"start": "node build/server.js", "start": "node --trace-deprecation --trace-warnings build/server.js",
"start:dev": "nodemon --watch src --exec ts-node --transpile-only src/server.ts", "start:dev": "nodemon --watch src --exec ts-node --transpile-only src/server.ts",
"start:debug": "ts-node --inspect --transpile-only src/server.ts",
"start:replit": "tsc && node build/server.js", "start:replit": "tsc && node build/server.js",
"start:watch": "nodemon --require source-map-support/register build/server.js", "start:watch": "nodemon --require source-map-support/register build/server.js",
"type-check": "tsc --noEmit" "type-check": "tsc --noEmit"
@@ -36,18 +38,21 @@
"csrf-csrf": "^2.3.0", "csrf-csrf": "^2.3.0",
"dotenv": "^16.3.1", "dotenv": "^16.3.1",
"ejs": "^3.1.10", "ejs": "^3.1.10",
"express": "^4.18.2", "express": "^4.19.3",
"express-session": "^1.17.3", "express-session": "^1.17.3",
"firebase-admin": "^12.3.1", "firebase-admin": "^12.5.0",
"glob": "^10.3.12", "glob": "^10.3.12",
"googleapis": "^122.0.0", "googleapis": "^122.0.0",
"http-proxy-middleware": "^3.0.0-beta.1", "http-proxy": "1.18.1",
"http-proxy-middleware": "^3.0.2",
"ipaddr.js": "^2.1.0", "ipaddr.js": "^2.1.0",
"memorystore": "^1.6.7", "memorystore": "^1.6.7",
"multer": "^1.4.5-lts.1", "multer": "^1.4.5-lts.1",
"node-schedule": "^2.1.1", "node-schedule": "^2.1.1",
"patch-package": "^8.0.0",
"pino": "^8.11.0", "pino": "^8.11.0",
"pino-http": "^8.3.3", "pino-http": "^8.3.3",
"proxy-agent": "^6.4.0",
"sanitize-html": "^2.13.0", "sanitize-html": "^2.13.0",
"sharp": "^0.32.6", "sharp": "^0.32.6",
"showdown": "^2.1.0", "showdown": "^2.1.0",
@@ -84,8 +89,8 @@
"typescript": "^5.4.2" "typescript": "^5.4.2"
}, },
"overrides": { "overrides": {
"braces": "^3.0.3", "node-fetch@2.x": {
"fast-xml-parser": "^4.4.1", "whatwg-url": "14.x"
"follow-redirects": "^1.15.4" }
} }
} }
+23
View File
@@ -0,0 +1,23 @@
# Patches
Contains monkey patches for certain packages, applied using `patch-package`.
## `http-proxy+1.18.1.patch`
Modifies the `http-proxy` package to work around an incompatibility with
body-parser and SOCKS5 proxies due to some esoteric stream handling behavior
when `socks-proxy-agent` is used instead of a generic http.Agent.
Modification involves adjusting the `buffer` property on ProxyServer's `options`
object to be a function that returns a stream instead of a stream itself. This
allows us to give it a function which produces a new Readable from the already-
parsed request body.
With the old implementation we would need to create an entirely new ProxyServer
instance for each request, which is not ideal under heavy load.
`http-proxy` hasn't been updated in six years so it's unlikely that this patch
will be broken by future updates, but it's stil pinned to 1.18.1 for now.
### See also
https://github.com/chimurai/http-proxy-middleware/issues/40
https://github.com/chimurai/http-proxy-middleware/issues/299
https://github.com/http-party/node-http-proxy/pull/1027
+13
View File
@@ -0,0 +1,13 @@
diff --git a/node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js b/node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js
index 7ae7355..c825c27 100644
--- a/node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js
+++ b/node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js
@@ -167,7 +167,7 @@ module.exports = {
}
}
- (options.buffer || req).pipe(proxyReq);
+ (options.buffer(req) || req).pipe(proxyReq);
proxyReq.on('response', function(proxyRes) {
if(server) { server.emit('proxyRes', proxyRes, req, res); }
+48 -29
View File
@@ -385,6 +385,36 @@ type Config = {
* Accepts floats. * Accepts floats.
*/ */
tokensPunishmentFactor: number; tokensPunishmentFactor: number;
/**
* Configuration for HTTP requests made by the proxy to other servers, such
* as when checking keys or forwarding users' requests to external services.
* If not set, all requests will be made using the default agent.
*
* If set, the proxy may make requests to other servers using the specified
* settings. This is useful if you wish to route users' requests through
* another proxy or VPN, or if you have multiple network interfaces and want
* to use a specific one for outgoing requests.
*/
httpAgent?: {
/**
* The name of the network interface to use. The first external IPv4 address
* belonging to this interface will be used for outgoing requests.
*/
interface?: string;
/**
* The URL of a proxy server to use. Supports SOCKS4, SOCKS5, HTTP, and
* HTTPS. If not set, the proxy will be made using the default agent.
* - SOCKS4: `socks4://some-socks-proxy.com:9050`
* - SOCKS5: `socks5://username:password@some-socks-proxy.com:9050`
* - HTTP: `http://proxy-server-over-tcp.com:3128`
* - HTTPS: `https://proxy-server-over-tls.com:3129`
*
* **Note:** If your proxy server issues a certificate, you may need to set
* `NODE_EXTRA_CA_CERTS` to the path to your certificate, otherwise this
* application will reject TLS connections.
*/
proxyUrl?: string;
};
}; };
// To change configs, create a file called .env in the root directory. // To change configs, create a file called .env in the root directory.
@@ -491,6 +521,10 @@ export const config: Config = {
), ),
ipBlacklist: parseCsv(getEnvWithDefault("IP_BLACKLIST", "")), ipBlacklist: parseCsv(getEnvWithDefault("IP_BLACKLIST", "")),
tokensPunishmentFactor: getEnvWithDefault("TOKENS_PUNISHMENT_FACTOR", 0.0), tokensPunishmentFactor: getEnvWithDefault("TOKENS_PUNISHMENT_FACTOR", 0.0),
httpAgent: {
interface: getEnvWithDefault("HTTP_AGENT_INTERFACE", undefined),
proxyUrl: getEnvWithDefault("HTTP_AGENT_PROXY_URL", undefined),
},
} as const; } as const;
function generateSigningKey() { function generateSigningKey() {
@@ -610,6 +644,16 @@ export async function assertConfigIsValid() {
); );
} }
if (Object.values(config.httpAgent || {}).filter(Boolean).length === 0) {
delete config.httpAgent;
} else if (config.httpAgent) {
if (config.httpAgent.interface && config.httpAgent.proxyUrl) {
throw new Error(
"Cannot set both `HTTP_AGENT_INTERFACE` and `HTTP_AGENT_PROXY_URL`."
);
}
}
// Ensure forks which add new secret-like config keys don't unwittingly expose // Ensure forks which add new secret-like config keys don't unwittingly expose
// them to users. // them to users.
for (const key of getKeys(config)) { for (const key of getKeys(config)) {
@@ -623,15 +667,16 @@ export async function assertConfigIsValid() {
`Config key "${key}" may be sensitive but is exposed. Add it to SENSITIVE_KEYS or OMITTED_KEYS.` `Config key "${key}" may be sensitive but is exposed. Add it to SENSITIVE_KEYS or OMITTED_KEYS.`
); );
} }
await maybeInitializeFirebase();
} }
/** /**
* Config keys that are masked on the info page, but not hidden as their * Config keys that are masked on the info page, but not hidden as their
* presence may be relevant to the user due to privacy implications. * presence may be relevant to the user due to privacy implications.
*/ */
export const SENSITIVE_KEYS: (keyof Config)[] = ["googleSheetsSpreadsheetId"]; export const SENSITIVE_KEYS: (keyof Config)[] = [
"googleSheetsSpreadsheetId",
"httpAgent",
];
/** /**
* Config keys that are not displayed on the info page at all, generally because * Config keys that are not displayed on the info page at all, generally because
@@ -755,32 +800,6 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
} }
} }
let firebaseApp: firebase.app.App | undefined;
async function maybeInitializeFirebase() {
if (!config.gatekeeperStore.startsWith("firebase")) {
return;
}
const firebase = await import("firebase-admin");
const firebaseKey = Buffer.from(config.firebaseKey!, "base64").toString();
const app = firebase.initializeApp({
credential: firebase.credential.cert(JSON.parse(firebaseKey)),
databaseURL: config.firebaseRtdbUrl,
});
await app.database().ref("connection-test").set(Date.now());
firebaseApp = app;
}
export function getFirebaseApp(): firebase.app.App {
if (!firebaseApp) {
throw new Error("Firebase app not initialized.");
}
return firebaseApp;
}
function parseCsv(val: string): string[] { function parseCsv(val: string): string[] {
if (!val) return []; if (!val) return [];
+29 -52
View File
@@ -1,22 +1,14 @@
import { Request, Response, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config"; import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
addKey, addKey,
addAnthropicPreamble,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeBody, finalizeBody,
createOnProxyReqHandler,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
ProxyResHandlerWithBody, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
createOnProxyResHandler, import { ProxyReqManager } from "./middleware/request/proxy-req-manager";
} from "./middleware/response";
import { sendErrorToClient } from "./middleware/response/error-generator";
let modelsCache: any = null; let modelsCache: any = null;
let modelsCacheTime = 0; let modelsCacheTime = 0;
@@ -69,7 +61,6 @@ const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse()); res.status(200).json(getModelsResponse());
}; };
/** Only used for non-streaming requests. */
const anthropicBlockingResponseHandler: ProxyResHandlerWithBody = async ( const anthropicBlockingResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes, _proxyRes,
req, req,
@@ -123,13 +114,7 @@ export function transformAnthropicChatResponseToAnthropicText(
}; };
} }
/** function transformAnthropicTextResponseToOpenAI(
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
export function transformAnthropicTextResponseToOpenAI(
anthropicBody: Record<string, any>, anthropicBody: Record<string, any>,
req: Request req: Request
): Record<string, any> { ): Record<string, any> {
@@ -201,38 +186,30 @@ function setAnthropicBetaHeader(req: Request) {
} }
} }
const anthropicProxy = createQueueMiddleware({ function selectUpstreamPath(manager: ProxyReqManager) {
proxyMiddleware: createProxyMiddleware({ const req = manager.request;
target: "https://api.anthropic.com", const pathname = req.url.split("?")[0];
changeOrigin: true, req.log.debug({ pathname }, "Anthropic path filter");
selfHandleResponse: true, const isText = req.outboundApi === "anthropic-text";
logger, const isChat = req.outboundApi === "anthropic-chat";
on: { if (isChat && pathname === "/v1/complete") {
proxyReq: createOnProxyReqHandler({ manager.setPath("/v1/messages");
pipeline: [addKey, addAnthropicPreamble, finalizeBody], }
}), if (isText && pathname === "/v1/chat/completions") {
proxyRes: createOnProxyResHandler([anthropicBlockingResponseHandler]), manager.setPath("/v1/complete");
error: handleProxyError, }
}, if (isChat && pathname === "/v1/chat/completions") {
// Abusing pathFilter to rewrite the paths dynamically. manager.setPath("/v1/messages");
pathFilter: (pathname, req) => { }
const isText = req.outboundApi === "anthropic-text"; if (isChat && ["sonnet", "opus"].includes(req.params.type)) {
const isChat = req.outboundApi === "anthropic-chat"; manager.setPath("/v1/messages");
if (isChat && pathname === "/v1/complete") { }
req.url = "/v1/messages"; }
}
if (isText && pathname === "/v1/chat/completions") { const anthropicProxy = createQueuedProxyMiddleware({
req.url = "/v1/complete"; target: "https://api.anthropic.com",
} mutations: [selectUpstreamPath, addKey, finalizeBody],
if (isChat && pathname === "/v1/chat/completions") { blockingResponseHandler: anthropicBlockingResponseHandler,
req.url = "/v1/messages";
}
if (isChat && ["sonnet", "opus"].includes(req.params.type)) {
req.url = "/v1/messages";
}
return true;
},
}),
}); });
const nativeAnthropicChatPreprocessor = createPreprocessorMiddleware( const nativeAnthropicChatPreprocessor = createPreprocessorMiddleware(
+16 -40
View File
@@ -1,27 +1,19 @@
import { Request, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid"; import { v4 } from "uuid";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createPreprocessorMiddleware,
signAwsRequest,
finalizeSignedRequest,
createOnProxyReqHandler,
} from "./middleware/request";
import {
ProxyResHandlerWithBody,
createOnProxyResHandler,
} from "./middleware/response";
import { import {
transformAnthropicChatResponseToAnthropicText, transformAnthropicChatResponseToAnthropicText,
transformAnthropicChatResponseToOpenAI, transformAnthropicChatResponseToOpenAI,
} from "./anthropic"; } from "./anthropic";
import { ipLimiter } from "./rate-limit";
import {
createPreprocessorMiddleware,
finalizeSignedRequest,
signAwsRequest,
} from "./middleware/request";
import { ProxyResHandlerWithBody } from "./middleware/response";
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
/** Only used for non-streaming requests. */ const awsBlockingResponseHandler: ProxyResHandlerWithBody = async (
const awsResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes, _proxyRes,
req, req,
res, res,
@@ -55,12 +47,6 @@ const awsResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json({ ...newBody, proxy: body.proxy }); res.status(200).json({ ...newBody, proxy: body.proxy });
}; };
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformAwsTextResponseToOpenAI( function transformAwsTextResponseToOpenAI(
awsBody: Record<string, any>, awsBody: Record<string, any>,
req: Request req: Request
@@ -89,23 +75,13 @@ function transformAwsTextResponseToOpenAI(
}; };
} }
const awsClaudeProxy = createQueueMiddleware({ const awsClaudeProxy = createQueuedProxyMiddleware({
beforeProxy: signAwsRequest, target: ({ signedRequest }) => {
proxyMiddleware: createProxyMiddleware({ if (!signedRequest) throw new Error("Must sign request before proxying");
target: "bad-target-will-be-rewritten", return `${signedRequest.protocol}//${signedRequest.hostname}`;
router: ({ signedRequest }) => { },
if (!signedRequest) throw new Error("Must sign request before proxying"); mutations: [signAwsRequest,finalizeSignedRequest],
return `${signedRequest.protocol}//${signedRequest.hostname}`; blockingResponseHandler: awsBlockingResponseHandler,
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsResponseHandler]),
error: handleProxyError,
},
}),
}); });
const nativeTextPreprocessor = createPreprocessorMiddleware( const nativeTextPreprocessor = createPreprocessorMiddleware(
+14 -29
View File
@@ -1,21 +1,16 @@
import { Request } from "express"; import { Request, Router } from "express";
import { import {
createOnProxyResHandler, detectMistralInputApi,
ProxyResHandlerWithBody, transformMistralTextToMistralChat,
} from "./middleware/response"; } from "./mistral-ai";
import { createQueueMiddleware } from "./queue"; import { ipLimiter } from "./rate-limit";
import { ProxyResHandlerWithBody } from "./middleware/response";
import { import {
createOnProxyReqHandler,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeSignedRequest, finalizeSignedRequest,
signAwsRequest, signAwsRequest,
} from "./middleware/request"; } from "./middleware/request";
import { createProxyMiddleware } from "http-proxy-middleware"; import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
import { logger } from "../logger";
import { handleProxyError } from "./middleware/common";
import { Router } from "express";
import { ipLimiter } from "./rate-limit";
import { detectMistralInputApi, transformMistralTextToMistralChat } from "./mistral-ai";
const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async ( const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes, _proxyRes,
@@ -39,23 +34,13 @@ const awsMistralBlockingResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json({ ...newBody, proxy: body.proxy }); res.status(200).json({ ...newBody, proxy: body.proxy });
}; };
const awsMistralProxy = createQueueMiddleware({ const awsMistralProxy = createQueuedProxyMiddleware({
beforeProxy: signAwsRequest, target: ({ signedRequest }) => {
proxyMiddleware: createProxyMiddleware({ if (!signedRequest) throw new Error("Must sign request before proxying");
target: "bad-target-will-be-rewritten", return `${signedRequest.protocol}//${signedRequest.hostname}`;
router: ({ signedRequest }) => { },
if (!signedRequest) throw new Error("Must sign request before proxying"); mutations: [signAwsRequest,finalizeSignedRequest],
return `${signedRequest.protocol}//${signedRequest.hostname}`; blockingResponseHandler: awsMistralBlockingResponseHandler,
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([awsMistralBlockingResponseHandler]),
error: handleProxyError,
},
}),
}); });
function maybeReassignModel(req: Request) { function maybeReassignModel(req: Request) {
+11 -27
View File
@@ -1,21 +1,14 @@
import { RequestHandler, Router } from "express"; import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config"; import { config } from "../config";
import { logger } from "../logger";
import { generateModelList } from "./openai"; import { generateModelList } from "./openai";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
addAzureKey, addAzureKey,
createOnProxyReqHandler,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeSignedRequest, finalizeSignedRequest,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
createOnProxyResHandler, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
ProxyResHandlerWithBody,
} from "./middleware/response";
let modelsCache: any = null; let modelsCache: any = null;
let modelsCacheTime = 0; let modelsCacheTime = 0;
@@ -47,26 +40,17 @@ const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json({ ...body, proxy: body.proxy }); res.status(200).json({ ...body, proxy: body.proxy });
}; };
const azureOpenAIProxy = createQueueMiddleware({ const azureOpenAIProxy = createQueuedProxyMiddleware({
beforeProxy: addAzureKey, target: ({ signedRequest }) => {
proxyMiddleware: createProxyMiddleware({ if (!signedRequest) throw new Error("Must sign request before proxying");
target: "will be set by router", const { hostname, path } = signedRequest;
router: (req) => { return `https://${hostname}${path}`;
if (!req.signedRequest) throw new Error("signedRequest not set"); },
const { hostname, path } = req.signedRequest; mutations: [addAzureKey, finalizeSignedRequest],
return `https://${hostname}${path}`; blockingResponseHandler: azureOpenaiResponseHandler,
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
error: handleProxyError,
},
}),
}); });
const azureOpenAIRouter = Router(); const azureOpenAIRouter = Router();
azureOpenAIRouter.get("/v1/models", handleModelRequest); azureOpenAIRouter.get("/v1/models", handleModelRequest);
azureOpenAIRouter.post( azureOpenAIRouter.post(
+13 -30
View File
@@ -1,21 +1,15 @@
import { Request, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config"; import { config } from "../config";
import { logger } from "../logger"; import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
createPreprocessorMiddleware, createPreprocessorMiddleware,
signGcpRequest,
finalizeSignedRequest, finalizeSignedRequest,
createOnProxyReqHandler, signGcpRequest,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
ProxyResHandlerWithBody, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
createOnProxyResHandler,
} from "./middleware/response";
import { transformAnthropicChatResponseToOpenAI } from "./anthropic";
const LATEST_GCP_SONNET_MINOR_VERSION = "20240229"; const LATEST_GCP_SONNET_MINOR_VERSION = "20240229";
let modelsCache: any = null; let modelsCache: any = null;
@@ -56,8 +50,7 @@ const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse()); res.status(200).json(getModelsResponse());
}; };
/** Only used for non-streaming requests. */ const gcpBlockingResponseHandler: ProxyResHandlerWithBody = async (
const gcpResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes, _proxyRes,
req, req,
res, res,
@@ -78,23 +71,13 @@ const gcpResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json({ ...newBody, proxy: body.proxy }); res.status(200).json({ ...newBody, proxy: body.proxy });
}; };
const gcpProxy = createQueueMiddleware({ const gcpProxy = createQueuedProxyMiddleware({
beforeProxy: signGcpRequest, target: ({ signedRequest }) => {
proxyMiddleware: createProxyMiddleware({ if (!signedRequest) throw new Error("Must sign request before proxying");
target: "bad-target-will-be-rewritten", return `${signedRequest.protocol}//${signedRequest.hostname}`;
router: ({ signedRequest }) => { },
if (!signedRequest) throw new Error("Must sign request before proxying"); mutations: [signGcpRequest, finalizeSignedRequest],
return `${signedRequest.protocol}//${signedRequest.hostname}`; blockingResponseHandler: gcpBlockingResponseHandler,
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([gcpResponseHandler]),
error: handleProxyError,
},
}),
}); });
const oaiToChatPreprocessor = createPreprocessorMiddleware( const oaiToChatPreprocessor = createPreprocessorMiddleware(
+13 -40
View File
@@ -1,22 +1,15 @@
import { Request, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid"; import { v4 } from "uuid";
import { GoogleAIKey, keyPool } from "../shared/key-management";
import { config } from "../config"; import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
createOnProxyReqHandler,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeSignedRequest, finalizeSignedRequest,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
createOnProxyResHandler, import { addGoogleAIKey } from "./middleware/request/mutators/add-google-ai-key";
ProxyResHandlerWithBody, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
} from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
import { GoogleAIKey, keyPool } from "../shared/key-management";
let modelsCache: any = null; let modelsCache: any = null;
let modelsCacheTime = 0; let modelsCacheTime = 0;
@@ -63,8 +56,7 @@ const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse()); res.status(200).json(getModelsResponse());
}; };
/** Only used for non-streaming requests. */ const googleAIBlockingResponseHandler: ProxyResHandlerWithBody = async (
const googleAIResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes, _proxyRes,
req, req,
res, res,
@@ -110,33 +102,14 @@ function transformGoogleAIResponse(
}; };
} }
const googleAIProxy = createQueueMiddleware({ const googleAIProxy = createQueuedProxyMiddleware({
beforeProxy: addGoogleAIKey, target: ({ signedRequest }) => {
proxyMiddleware: createProxyMiddleware({ if (!signedRequest) throw new Error("Must sign request before proxying");
target: "bad-target-will-be-rewritten", const { protocol, hostname, path } = signedRequest;
router: ({ signedRequest }) => { return `${protocol}//${hostname}${path}`;
const { protocol, hostname, path } = signedRequest; },
return `${protocol}//${hostname}${path}`; mutations: [addGoogleAIKey, finalizeSignedRequest],
}, blockingResponseHandler: googleAIBlockingResponseHandler,
changeOrigin: true,
selfHandleResponse: true,
// Prevent logging of the API key by HPM
logger: logger.child(
{},
{
redact: {
paths: ["*"],
censor: (v) =>
typeof v === "string" ? v.replace(/key=\S+/g, "key=xxxxxxx") : v,
},
}
),
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
error: handleProxyError,
},
}),
}); });
const googleAIRouter = Router(); const googleAIRouter = Router();
+14 -7
View File
@@ -1,6 +1,6 @@
import { Request, Response } from "express"; import { Request, Response } from "express";
import http from "http"; import http from "http";
import httpProxy from "http-proxy"; import { Socket } from "net";
import { ZodError } from "zod"; import { ZodError } from "zod";
import { generateErrorMessage } from "zod-error"; import { generateErrorMessage } from "zod-error";
import { HttpError } from "../../shared/errors"; import { HttpError } from "../../shared/errors";
@@ -72,16 +72,23 @@ export function sendProxyError(
}); });
} }
export const handleProxyError: httpProxy.ErrorCallback = (err, req, res) => { /**
req.log.error(err, `Error during http-proxy-middleware request`); * Handles errors thrown during preparation of a proxy request (before it is
classifyErrorAndSend(err, req as Request, res as Response); * sent to the upstream API), typically due to validation, quota, or other
}; * pre-flight checks. Depending on the error class, this function will send an
* appropriate error response to the client, streaming it if necessary.
*/
export const classifyErrorAndSend = ( export const classifyErrorAndSend = (
err: Error, err: Error,
req: Request, req: Request,
res: Response res: Response | Socket
) => { ) => {
if (res instanceof Socket) {
// We should always have an Express response object here, but http-proxy's
// ErrorCallback type says it could be just a Socket.
req.log.error(err, "Caught error while proxying request to target but cannot send error response to client.");
return res.destroy();
}
try { try {
const { statusCode, statusMessage, userMessage, ...errorDetails } = const { statusCode, statusMessage, userMessage, ...errorDetails } =
classifyError(err); classifyError(err);
+26 -35
View File
@@ -1,44 +1,38 @@
import type { Request } from "express"; import type { Request } from "express";
import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
export { createOnProxyReqHandler } from "./onproxyreq-factory"; import { ProxyReqManager } from "./proxy-req-manager";
export { export {
createPreprocessorMiddleware, createPreprocessorMiddleware,
createEmbeddingsPreprocessorMiddleware, createEmbeddingsPreprocessorMiddleware,
} from "./preprocessor-factory"; } from "./preprocessor-factory";
// Express middleware (runs before http-proxy-middleware, can be async) // Preprocessors (runs before request is queued, usually body transformation/validation)
export { addAzureKey } from "./preprocessors/add-azure-key";
export { applyQuotaLimits } from "./preprocessors/apply-quota-limits"; export { applyQuotaLimits } from "./preprocessors/apply-quota-limits";
export { blockZoomerOrigins } from "./preprocessors/block-zoomer-origins";
export { countPromptTokens } from "./preprocessors/count-prompt-tokens"; export { countPromptTokens } from "./preprocessors/count-prompt-tokens";
export { languageFilter } from "./preprocessors/language-filter"; export { languageFilter } from "./preprocessors/language-filter";
export { setApiFormat } from "./preprocessors/set-api-format"; export { setApiFormat } from "./preprocessors/set-api-format";
export { signAwsRequest } from "./preprocessors/sign-aws-request";
export { signGcpRequest } from "./preprocessors/sign-vertex-ai-request";
export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload"; export { transformOutboundPayload } from "./preprocessors/transform-outbound-payload";
export { validateContextSize } from "./preprocessors/validate-context-size"; export { validateContextSize } from "./preprocessors/validate-context-size";
export { validateModelFamily } from "./preprocessors/validate-model-family";
export { validateVision } from "./preprocessors/validate-vision"; export { validateVision } from "./preprocessors/validate-vision";
// http-proxy-middleware callbacks (runs on onProxyReq, cannot be async) // Proxy request mutators (runs every time request is dequeued, before proxying, usually for auth/signing)
export { addAnthropicPreamble } from "./onproxyreq/add-anthropic-preamble"; export { addKey, addKeyForEmbeddingsRequest } from "./mutators/add-key";
export { addKey, addKeyForEmbeddingsRequest } from "./onproxyreq/add-key"; export { addAzureKey } from "./mutators/add-azure-key";
export { blockZoomerOrigins } from "./onproxyreq/block-zoomer-origins"; export { finalizeBody } from "./mutators/finalize-body";
export { checkModelFamily } from "./onproxyreq/check-model-family"; export { finalizeSignedRequest } from "./mutators/finalize-signed-request";
export { finalizeBody } from "./onproxyreq/finalize-body"; export { signAwsRequest } from "./mutators/sign-aws-request";
export { finalizeSignedRequest } from "./onproxyreq/finalize-signed-request"; export { signGcpRequest } from "./mutators/sign-vertex-ai-request";
export { stripHeaders } from "./onproxyreq/strip-headers"; export { stripHeaders } from "./mutators/strip-headers";
/** /**
* Middleware that runs prior to the request being handled by http-proxy- * Middleware that runs prior to the request being queued or handled by
* middleware. * http-proxy-middleware. You will not have access to the proxied
* request/response objects since they have not yet been sent to the API.
* *
* Async functions can be used here, but you will not have access to the proxied * User will have been authenticated by the proxy's gatekeeper, but the request
* request/response objects, nor the data set by ProxyRequestMiddleware * won't have been assigned an upstream API key yet.
* functions as they have not yet been run.
*
* User will have been authenticated by the time this middleware runs, but your
* request won't have been assigned an API key yet.
* *
* Note that these functions only run once ever per request, even if the request * Note that these functions only run once ever per request, even if the request
* is automatically retried by the request queue middleware. * is automatically retried by the request queue middleware.
@@ -46,17 +40,14 @@ export { stripHeaders } from "./onproxyreq/strip-headers";
export type RequestPreprocessor = (req: Request) => void | Promise<void>; export type RequestPreprocessor = (req: Request) => void | Promise<void>;
/** /**
* Callbacks that run immediately before the request is sent to the API in * Middleware that runs immediately before the request is proxied to the
* response to http-proxy-middleware's `proxyReq` event. * upstream API, after dequeueing the request from the request queue.
* *
* Async functions cannot be used here as HPM's event emitter is not async and * Because these middleware may be run multiple times per request if a retryable
* will not wait for the promise to resolve before sending the request. * error occurs and the request put back in the queue, they must be idempotent.
* * A change manager is provided to allow the middleware to make changes to the
* Note that these functions may be run multiple times per request if the * request which can be automatically reverted.
* first attempt is rate limited and the request is automatically retried by the
* request queue middleware.
*/ */
export type HPMRequestCallback = ProxyReqCallback<ClientRequest, Request>; export type ProxyReqMutator = (
changeManager: ProxyReqManager
export const forceModel = (model: string) => (req: Request) => ) => void | Promise<void>;
void (req.body.model = model);
@@ -3,14 +3,16 @@ import {
AzureOpenAIKey, AzureOpenAIKey,
keyPool, keyPool,
} from "../../../../shared/key-management"; } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index"; import { ProxyReqMutator } from "../index";
export const addAzureKey: RequestPreprocessor = (req) => { export const addAzureKey: ProxyReqMutator = async (manager) => {
const req = manager.request;
const validAPIs: APIFormat[] = ["openai", "openai-image"]; const validAPIs: APIFormat[] = ["openai", "openai-image"];
const apisValid = [req.outboundApi, req.inboundApi].every((api) => const apisValid = [req.outboundApi, req.inboundApi].every((api) =>
validAPIs.includes(api) validAPIs.includes(api)
); );
const serviceValid = req.service === "azure"; const serviceValid = req.service === "azure";
if (!apisValid || !serviceValid) { if (!apisValid || !serviceValid) {
throw new Error("addAzureKey called on invalid request"); throw new Error("addAzureKey called on invalid request");
} }
@@ -22,11 +24,15 @@ export const addAzureKey: RequestPreprocessor = (req) => {
const model = req.body.model.startsWith("azure-") const model = req.body.model.startsWith("azure-")
? req.body.model ? req.body.model
: `azure-${req.body.model}`; : `azure-${req.body.model}`;
// TODO: untracked mutation to body, I think this should just be a
req.key = keyPool.get(model, "azure"); // RequestPreprocessor because we don't need to do it every dequeue.
req.body.model = model; req.body.model = model;
const key = keyPool.get(model, "azure");
manager.setKey(key);
// Handles the sole Azure API deviation from the OpenAI spec (that I know of) // Handles the sole Azure API deviation from the OpenAI spec (that I know of)
// TODO: this should also probably be a RequestPreprocessor
const notNullOrUndefined = (x: any) => x !== null && x !== undefined; const notNullOrUndefined = (x: any) => x !== null && x !== undefined;
if ([req.body.logprobs, req.body.top_logprobs].some(notNullOrUndefined)) { if ([req.body.logprobs, req.body.top_logprobs].some(notNullOrUndefined)) {
// OpenAI wants logprobs: true/false and top_logprobs: number // OpenAI wants logprobs: true/false and top_logprobs: number
@@ -43,7 +49,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
} }
req.log.info( req.log.info(
{ key: req.key.hash, model }, { key: key.hash, model },
"Assigned Azure OpenAI key to request" "Assigned Azure OpenAI key to request"
); );
@@ -55,7 +61,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
const apiVersion = const apiVersion =
req.outboundApi === "openai" ? "2023-09-01-preview" : "2024-02-15-preview"; req.outboundApi === "openai" ? "2023-09-01-preview" : "2024-02-15-preview";
req.signedRequest = { manager.setSignedRequest({
method: "POST", method: "POST",
protocol: "https:", protocol: "https:",
hostname: `${resourceName}.openai.azure.com`, hostname: `${resourceName}.openai.azure.com`,
@@ -66,7 +72,7 @@ export const addAzureKey: RequestPreprocessor = (req) => {
["api-key"]: apiKey, ["api-key"]: apiKey,
}, },
body: JSON.stringify(req.body), body: JSON.stringify(req.body),
}; });
}; };
function getCredentialsFromKey(key: AzureOpenAIKey) { function getCredentialsFromKey(key: AzureOpenAIKey) {
@@ -1,7 +1,8 @@
import { keyPool } from "../../../../shared/key-management"; import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index"; import { ProxyReqMutator} from "../index";
export const addGoogleAIKey: RequestPreprocessor = (req) => { export const addGoogleAIKey: ProxyReqMutator = (manager) => {
const req = manager.request;
const inboundValid = const inboundValid =
req.inboundApi === "openai" || req.inboundApi === "google-ai"; req.inboundApi === "openai" || req.inboundApi === "google-ai";
const outboundValid = req.outboundApi === "google-ai"; const outboundValid = req.outboundApi === "google-ai";
@@ -12,10 +13,11 @@ export const addGoogleAIKey: RequestPreprocessor = (req) => {
} }
const model = req.body.model; const model = req.body.model;
req.isStreaming = req.isStreaming || req.body.stream; const key = keyPool.get(model, "google-ai");
req.key = keyPool.get(model, "google-ai"); manager.setKey(key);
req.log.info( req.log.info(
{ key: req.key.hash, model, stream: req.isStreaming }, { key: key.hash, model, stream: req.isStreaming },
"Assigned Google AI API key to request" "Assigned Google AI API key to request"
); );
@@ -23,17 +25,20 @@ export const addGoogleAIKey: RequestPreprocessor = (req) => {
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY} // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
const payload = { ...req.body, stream: undefined, model: undefined }; const payload = { ...req.body, stream: undefined, model: undefined };
req.signedRequest = { // TODO: this isn't actually signed, so the manager api is a little unclear
// with the ProxyReqManager refactor, it's probably no longer necesasry to
// do this because we can modify the path using Manager.setPath.
manager.setSignedRequest({
method: "POST", method: "POST",
protocol: "https:", protocol: "https:",
hostname: "generativelanguage.googleapis.com", hostname: "generativelanguage.googleapis.com",
path: `/v1beta/models/${model}:${ path: `/v1beta/models/${model}:${
req.isStreaming ? "streamGenerateContent" : "generateContent" req.isStreaming ? "streamGenerateContent" : "generateContent"
}?key=${req.key.key}`, }?key=${key.key}`,
headers: { headers: {
["host"]: `generativelanguage.googleapis.com`, ["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json", ["content-type"]: "application/json",
}, },
body: JSON.stringify(payload), body: JSON.stringify(payload),
}; });
}; };
@@ -2,10 +2,12 @@ import { AnthropicChatMessage } from "../../../../shared/api-schemas";
import { containsImageContent } from "../../../../shared/api-schemas/anthropic"; import { containsImageContent } from "../../../../shared/api-schemas/anthropic";
import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management"; import { Key, OpenAIKey, keyPool } from "../../../../shared/key-management";
import { isEmbeddingsRequest } from "../../common"; import { isEmbeddingsRequest } from "../../common";
import { HPMRequestCallback } from "../index";
import { assertNever } from "../../../../shared/utils"; import { assertNever } from "../../../../shared/utils";
import { ProxyReqMutator } from "../index";
export const addKey: ProxyReqMutator = (manager) => {
const req = manager.request;
export const addKey: HPMRequestCallback = (proxyReq, req) => {
let assignedKey: Key; let assignedKey: Key;
const { service, inboundApi, outboundApi, body } = req; const { service, inboundApi, outboundApi, body } = req;
@@ -58,7 +60,7 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
} }
} }
req.key = assignedKey; manager.setKey(assignedKey);
req.log.info( req.log.info(
{ key: assignedKey.hash, model: body.model, inboundApi, outboundApi }, { key: assignedKey.hash, model: body.model, inboundApi, outboundApi },
"Assigned key to request" "Assigned key to request"
@@ -67,21 +69,21 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
// TODO: KeyProvider should assemble all necessary headers // TODO: KeyProvider should assemble all necessary headers
switch (assignedKey.service) { switch (assignedKey.service) {
case "anthropic": case "anthropic":
proxyReq.setHeader("X-API-Key", assignedKey.key); manager.setHeader("X-API-Key", assignedKey.key);
break; break;
case "openai": case "openai":
const key: OpenAIKey = assignedKey as OpenAIKey; const key: OpenAIKey = assignedKey as OpenAIKey;
if (key.organizationId && !key.key.includes("svcacct")) { if (key.organizationId && !key.key.includes("svcacct")) {
proxyReq.setHeader("OpenAI-Organization", key.organizationId); manager.setHeader("OpenAI-Organization", key.organizationId);
} }
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`); manager.setHeader("Authorization", `Bearer ${assignedKey.key}`);
break; break;
case "mistral-ai": case "mistral-ai":
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`); manager.setHeader("Authorization", `Bearer ${assignedKey.key}`);
break; break;
case "azure": case "azure":
const azureKey = assignedKey.key; const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey); manager.setHeader("api-key", azureKey);
break; break;
case "aws": case "aws":
case "gcp": case "gcp":
@@ -96,10 +98,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
* Special case for embeddings requests which don't go through the normal * Special case for embeddings requests which don't go through the normal
* request pipeline. * request pipeline.
*/ */
export const addKeyForEmbeddingsRequest: HPMRequestCallback = ( export const addKeyForEmbeddingsRequest: ProxyReqMutator = (manager) => {
proxyReq, const req = manager.request;
req
) => {
if (!isEmbeddingsRequest(req)) { if (!isEmbeddingsRequest(req)) {
throw new Error( throw new Error(
"addKeyForEmbeddingsRequest called on non-embeddings request" "addKeyForEmbeddingsRequest called on non-embeddings request"
@@ -110,18 +110,18 @@ export const addKeyForEmbeddingsRequest: HPMRequestCallback = (
throw new Error("Embeddings requests must be from OpenAI"); throw new Error("Embeddings requests must be from OpenAI");
} }
req.body = { input: req.body.input, model: "text-embedding-ada-002" }; manager.setBody({ input: req.body.input, model: "text-embedding-ada-002" });
const key = keyPool.get("text-embedding-ada-002", "openai") as OpenAIKey; const key = keyPool.get("text-embedding-ada-002", "openai") as OpenAIKey;
req.key = key; manager.setKey(key);
req.log.info( req.log.info(
{ key: key.hash, toApi: req.outboundApi }, { key: key.hash, toApi: req.outboundApi },
"Assigned Turbo key to embeddings request" "Assigned Turbo key to embeddings request"
); );
proxyReq.setHeader("Authorization", `Bearer ${key.key}`); manager.setHeader("Authorization", `Bearer ${key.key}`);
if (key.organizationId) { if (key.organizationId) {
proxyReq.setHeader("OpenAI-Organization", key.organizationId); manager.setHeader("OpenAI-Organization", key.organizationId);
} }
}; };
@@ -0,0 +1,22 @@
import type { ProxyReqMutator } from "../index";
/** Finalize the rewritten request body. Must be the last mutator. */
export const finalizeBody: ProxyReqMutator = (manager) => {
const req = manager.request;
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
// For image generation requests, remove stream flag.
if (req.outboundApi === "openai-image") {
delete req.body.stream;
}
// For anthropic text to chat requests, remove undefined prompt.
if (req.outboundApi === "anthropic-chat") {
delete req.body.prompt;
}
const serialized =
typeof req.body === "string" ? req.body : JSON.stringify(req.body);
manager.setHeader("Content-Length", String(Buffer.byteLength(serialized)));
manager.setBody(serialized);
}
};
@@ -0,0 +1,32 @@
import { ProxyReqMutator } from "../index";
/**
* For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
* pipeline, before the proxy middleware. This function just assigns the path
* and headers to the proxy request.
*/
export const finalizeSignedRequest: ProxyReqMutator = (manager) => {
const req = manager.request;
if (!req.signedRequest) {
throw new Error("Expected req.signedRequest to be set");
}
// The path depends on the selected model and the assigned key's region.
manager.setPath(req.signedRequest.path);
// Amazon doesn't want extra headers, so we need to remove all of them and
// reassign only the ones specified in the signed request.
const headers = req.signedRequest.headers;
Object.keys(headers).forEach((key) => {
manager.removeHeader(key);
});
Object.entries(req.signedRequest.headers).forEach(([key, value]) => {
manager.setHeader(key, value);
});
const serialized =
typeof req.signedRequest.body === "string"
? req.signedRequest.body
: JSON.stringify(req.signedRequest.body);
manager.setHeader("Content-Length", String(Buffer.byteLength(serialized)));
manager.setBody(serialized);
};
@@ -7,11 +7,11 @@ import {
AnthropicV1MessagesSchema, AnthropicV1MessagesSchema,
} from "../../../../shared/api-schemas"; } from "../../../../shared/api-schemas";
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management"; import { AwsBedrockKey, keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { import {
AWSMistralV1ChatCompletionsSchema, AWSMistralV1ChatCompletionsSchema,
AWSMistralV1TextCompletionsSchema, AWSMistralV1TextCompletionsSchema,
} from "../../../../shared/api-schemas/mistral-ai"; } from "../../../../shared/api-schemas/mistral-ai";
import { ProxyReqMutator } from "../index";
const AMZ_HOST = const AMZ_HOST =
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com"; process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com";
@@ -21,32 +21,24 @@ const AMZ_HOST =
* request object in place to fix the path. * request object in place to fix the path.
* This happens AFTER request transformation. * This happens AFTER request transformation.
*/ */
export const signAwsRequest: RequestPreprocessor = async (req) => { export const signAwsRequest: ProxyReqMutator = async (manager) => {
const req = manager.request;
const { model, stream } = req.body; const { model, stream } = req.body;
req.key = keyPool.get(model, "aws"); const key = keyPool.get(model, "aws") as AwsBedrockKey;
manager.setKey(key);
req.isStreaming = stream === true || stream === "true";
// same as addAnthropicPreamble for non-AWS requests, but has to happen here
if (req.outboundApi === "anthropic-text") {
let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:";
req.body.prompt = preamble + req.body.prompt;
}
const credential = getCredentialParts(req); const credential = getCredentialParts(req);
const host = AMZ_HOST.replace("%REGION%", credential.region); const host = AMZ_HOST.replace("%REGION%", credential.region);
// AWS only uses 2023-06-01 and does not actually check this header, but we // AWS only uses 2023-06-01 and does not actually check this header, but we
// set it so that the stream adapter always selects the correct transformer. // set it so that the stream adapter always selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01"; manager.setHeader("anthropic-version", "2023-06-01");
// If our key has an inference profile compatible with the requested model, // If our key has an inference profile compatible with the requested model,
// we want to use the inference profile instead of the model ID when calling // we want to use the inference profile instead of the model ID when calling
// InvokeModel as that will give us higher rate limits. // InvokeModel as that will give us higher rate limits.
const profile = const profile =
(req.key as AwsBedrockKey).inferenceProfileIds.find((p) => key.inferenceProfileIds.find((p) => p.includes(model)) || model;
p.includes(model)
) || model;
// Uses the AWS SDK to sign a request, then modifies our HPM proxy request // Uses the AWS SDK to sign a request, then modifies our HPM proxy request
// with the headers generated by the SDK. // with the headers generated by the SDK.
@@ -59,7 +51,7 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
["Host"]: host, ["Host"]: host,
["content-type"]: "application/json", ["content-type"]: "application/json",
}, },
body: JSON.stringify(applyAwsStrictValidation(req)), body: JSON.stringify(getStrictlyValidatedBodyForAws(req)),
}); });
if (stream) { if (stream) {
@@ -68,19 +60,13 @@ export const signAwsRequest: RequestPreprocessor = async (req) => {
newRequest.headers["accept"] = "*/*"; newRequest.headers["accept"] = "*/*";
} }
const { key, body, inboundApi, outboundApi } = req; const { body, inboundApi, outboundApi } = req;
req.log.info( req.log.info(
{ { key: key.hash, model: body.model, profile, inboundApi, outboundApi },
key: key.hash,
model: body.model,
inferenceProfile: profile,
inboundApi,
outboundApi,
},
"Assigned AWS credentials to request" "Assigned AWS credentials to request"
); );
req.signedRequest = await sign(newRequest, getCredentialParts(req)); manager.setSignedRequest(await sign(newRequest, getCredentialParts(req)));
}; };
type Credential = { type Credential = {
@@ -116,7 +102,7 @@ async function sign(request: HttpRequest, credential: Credential) {
return signer.sign(request); return signer.sign(request);
} }
function applyAwsStrictValidation(req: Request): unknown { function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown {
// AWS uses vendor API formats but imposes additional (more strict) validation // AWS uses vendor API formats but imposes additional (more strict) validation
// rules, namely that extraneous parameters are not allowed. We will validate // rules, namely that extraneous parameters are not allowed. We will validate
// using the vendor's zod schema but apply `.strip` to ensure that any // using the vendor's zod schema but apply `.strip` to ensure that any
@@ -1,12 +1,16 @@
import express from "express"; import { Request } from "express";
import crypto from "crypto"; import crypto from "crypto";
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas"; import { AnthropicV1MessagesSchema } from "../../../../shared/api-schemas";
import { keyPool } from "../../../../shared/key-management";
import { getAxiosInstance } from "../../../../shared/network";
import { ProxyReqMutator } from "../index";
const axios = getAxiosInstance();
const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com"; const GCP_HOST = process.env.GCP_HOST || "%REGION%-aiplatform.googleapis.com";
export const signGcpRequest: RequestPreprocessor = async (req) => { export const signGcpRequest: ProxyReqMutator = async (manager) => {
const req = manager.request;
const serviceValid = req.service === "gcp"; const serviceValid = req.service === "gcp";
if (!serviceValid) { if (!serviceValid) {
throw new Error("addVertexAIKey called on invalid request"); throw new Error("addVertexAIKey called on invalid request");
@@ -16,12 +20,11 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
throw new Error("You must specify a model with your request."); throw new Error("You must specify a model with your request.");
} }
const { model, stream } = req.body; const { model } = req.body;
req.key = keyPool.get(model, "gcp"); const key = keyPool.get(model, "gcp");
manager.setKey(key);
req.log.info({ key: req.key.hash, model }, "Assigned GCP key to request"); req.log.info({ key: key.hash, model }, "Assigned GCP key to request");
req.isStreaming = String(stream) === "true";
// TODO: This should happen in transform-outbound-payload.ts // TODO: This should happen in transform-outbound-payload.ts
// TODO: Support tools // TODO: Support tools
@@ -45,9 +48,9 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
const host = GCP_HOST.replace("%REGION%", credential.region); const host = GCP_HOST.replace("%REGION%", credential.region);
// GCP doesn't use the anthropic-version header, but we set it to ensure the // GCP doesn't use the anthropic-version header, but we set it to ensure the
// stream adapter selects the correct transformer. // stream adapter selects the correct transformer.
req.headers["anthropic-version"] = "2023-06-01"; manager.setHeader("anthropic-version", "2023-06-01");
req.signedRequest = { manager.setSignedRequest({
method: "POST", method: "POST",
protocol: "https:", protocol: "https:",
hostname: host, hostname: host,
@@ -58,11 +61,11 @@ export const signGcpRequest: RequestPreprocessor = async (req) => {
["authorization"]: `Bearer ${accessToken}`, ["authorization"]: `Bearer ${accessToken}`,
}, },
body: JSON.stringify(strippedParams), body: JSON.stringify(strippedParams),
}; });
}; };
async function getAccessToken( async function getAccessToken(
req: express.Request req: Readonly<Request>
): Promise<[string, Credential]> { ): Promise<[string, Credential]> {
// TODO: access token caching to reduce latency // TODO: access token caching to reduce latency
const credential = getCredentialParts(req); const credential = getCredentialParts(req);
@@ -134,19 +137,23 @@ async function exchangeJwtForAccessToken(
assertion: signedJwt, assertion: signedJwt,
}; };
const r = await fetch(authUrl, { try {
method: "POST", const response = await axios.post(authUrl, params, {
headers: { "Content-Type": "application/x-www-form-urlencoded" }, headers: { "Content-Type": "application/x-www-form-urlencoded" },
body: Object.entries(params) });
.map(([k, v]) => `${k}=${v}`)
.join("&"),
}).then((res) => res.json());
if (r.access_token) { if (response.data.access_token) {
return [r.access_token, ""]; return [response.data.access_token, ""];
} else {
return [null, JSON.stringify(response.data)];
}
} catch (error) {
if ("response" in error && "data" in error.response) {
return [null, JSON.stringify(error.response.data)];
} else {
return [null, "An unexpected error occurred"];
}
} }
return [null, JSON.stringify(r)];
} }
function str2ab(str: string): ArrayBuffer { function str2ab(str: string): ArrayBuffer {
@@ -179,7 +186,7 @@ type Credential = {
privateKey: string; privateKey: string;
}; };
function getCredentialParts(req: express.Request): Credential { function getCredentialParts(req: Readonly<Request>): Credential {
const [projectId, clientEmail, region, rawPrivateKey] = const [projectId, clientEmail, region, rawPrivateKey] =
req.key!.key.split(":"); req.key!.key.split(":");
if (!projectId || !clientEmail || !region || !rawPrivateKey) { if (!projectId || !clientEmail || !region || !rawPrivateKey) {
@@ -0,0 +1,21 @@
import { ProxyReqMutator } from "../index";
/**
* Removes origin and referer headers before sending the request to the API for
* privacy reasons.
*/
export const stripHeaders: ProxyReqMutator = (manager) => {
  // Blank these rather than removing them so the upstream sees empty values.
  manager.setHeader("origin", "");
  manager.setHeader("referer", "");
  // Headers added by tunnels/reverse proxies (Tailscale, Cloudflare, generic
  // forwarding) that would reveal the client's identity or address.
  const identifyingHeaders = [
    "tailscale-user-login",
    "tailscale-user-name",
    "tailscale-headers-info",
    "tailscale-user-profile-pic",
    "cf-connecting-ip",
    "forwarded",
    "true-client-ip",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-proto",
    "x-real-ip",
  ];
  for (const header of identifyingHeaders) {
    manager.removeHeader(header);
  }
};
@@ -1,45 +0,0 @@
import {
applyQuotaLimits,
blockZoomerOrigins,
checkModelFamily,
HPMRequestCallback,
stripHeaders,
} from "./index";
type ProxyReqHandlerFactoryOptions = { pipeline: HPMRequestCallback[] };
/**
* Returns an http-proxy-middleware request handler that runs the given set of
* onProxyReq callback functions in sequence.
*
* These will run each time a request is proxied, including on automatic retries
* by the queue after encountering a rate limit.
*/
export const createOnProxyReqHandler = ({
  pipeline,
}: ProxyReqHandlerFactoryOptions): HPMRequestCallback => {
  // Built-in callbacks always run first, in this fixed order, before any
  // route-specific callbacks supplied by the caller via `pipeline`.
  const callbackPipeline = [
    checkModelFamily,
    applyQuotaLimits,
    blockZoomerOrigins,
    stripHeaders,
    ...pipeline,
  ];
  return (proxyReq, req, res, options) => {
    // The streaming flag must be set before any other onProxyReq handler runs,
    // as it may influence the behavior of subsequent handlers.
    // Image generation requests can't be streamed.
    // TODO: this flag is set in too many places
    req.isStreaming =
      req.isStreaming || req.body.stream === true || req.body.stream === "true";
    req.body.stream = req.isStreaming;
    try {
      // Callbacks run synchronously and in order; the first one to throw
      // aborts the whole proxied request.
      for (const fn of callbackPipeline) {
        fn(proxyReq, req, res, options);
      }
    } catch (error) {
      // Destroy the outbound socket with the error so the request is aborted
      // and the error propagates through the stream's error handling.
      proxyReq.destroy(error);
    }
  };
};
@@ -1,33 +0,0 @@
import { AnthropicKey, Key } from "../../../../shared/key-management";
import { isTextGenerationRequest } from "../../common";
import { HPMRequestCallback } from "../index";
/**
* Some keys require the prompt to start with `\n\nHuman:`. There is no way to
* know this without trying to send the request and seeing if it fails. If a
* key is marked as requiring a preamble, it will be added here.
*/
export const addAnthropicPreamble: HPMRequestCallback = (_proxyReq, req) => {
  // Only applies to text-completion requests bound for Anthropic's text API.
  const isEligible =
    isTextGenerationRequest(req) &&
    req.key?.service === "anthropic" &&
    req.outboundApi === "anthropic-text";
  if (!isEligible) return;

  assertAnthropicKey(req.key);
  const prompt = req.body.prompt;
  let preamble = "";
  // Keys flagged as requiring a preamble need the prompt to begin with
  // `\n\nHuman:`; leave it alone if it already does.
  if (req.key.requiresPreamble && prompt) {
    preamble = prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:";
    req.log.debug({ key: req.key.hash, preamble }, "Adding preamble to prompt");
  }
  req.body.prompt = preamble + prompt;
};
/** Narrows a Key to AnthropicKey, throwing if the service does not match. */
function assertAnthropicKey(key: Key): asserts key is AnthropicKey {
  const { service } = key;
  if (service !== "anthropic") {
    throw new Error(`Expected an Anthropic key, got '${service}'`);
  }
}
@@ -1,23 +0,0 @@
import { fixRequestBody } from "http-proxy-middleware";
import type { HPMRequestCallback } from "../index";
/** Finalize the rewritten request body. Must be the last rewriter. */
/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: HPMRequestCallback = (proxyReq, req) => {
  const isWriteMethod = ["POST", "PUT", "PATCH"].includes(req.method ?? "");
  if (!isWriteMethod || !req.body) return;

  // Image generation requests don't support the stream flag.
  if (req.outboundApi === "openai-image") {
    delete req.body.stream;
  }
  // Anthropic text-to-chat conversion can leave an undefined prompt behind.
  if (req.outboundApi === "anthropic-chat") {
    delete req.body.prompt;
  }

  const serialized = JSON.stringify(req.body);
  proxyReq.setHeader("Content-Length", Buffer.byteLength(serialized));
  (req as any).rawBody = Buffer.from(serialized);
  // body-parser and http-proxy-middleware don't play nice together
  fixRequestBody(proxyReq, req);
};
@@ -1,26 +0,0 @@
import type { HPMRequestCallback } from "../index";
/**
* For AWS/GCP/Azure/Google requests, the body is signed earlier in the request
* pipeline, before the proxy middleware. This function just assigns the path
* and headers to the proxy request.
*/
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
  const signed = req.signedRequest;
  if (!signed) {
    throw new Error("Expected req.signedRequest to be set");
  }
  // The path depends on the selected model and the assigned key's region.
  proxyReq.path = signed.path;
  // Amazon doesn't want extra headers, so clear everything on the outbound
  // request and reapply only what the signer produced.
  for (const name of proxyReq.getRawHeaderNames()) {
    proxyReq.removeHeader(name);
  }
  for (const [name, value] of Object.entries(signed.headers)) {
    proxyReq.setHeader(name, value);
  }
  // Don't use fixRequestBody here because it adds a content-length header.
  // Amazon doesn't want that and it breaks the signature.
  proxyReq.write(signed.body);
};
@@ -1,21 +0,0 @@
import { HPMRequestCallback } from "../index";
/**
* Removes origin and referer headers before sending the request to the API for
* privacy reasons.
**/
export const stripHeaders: HPMRequestCallback = (proxyReq) => {
  // Blank these rather than removing them so the upstream sees empty values.
  proxyReq.setHeader("origin", "");
  proxyReq.setHeader("referer", "");
  // Headers added by tunnels/reverse proxies (Tailscale, Cloudflare, generic
  // forwarding) that would reveal the client's identity or address. Listing
  // them once avoids the repeated-call boilerplate (the original also dropped
  // a statement terminator on one line, relying on ASI).
  const identifyingHeaders = [
    "tailscale-user-login",
    "tailscale-user-name",
    "tailscale-headers-info",
    "tailscale-user-profile-pic",
    "cf-connecting-ip",
    "forwarded",
    "true-client-ip",
    "x-forwarded-for",
    "x-forwarded-host",
    "x-forwarded-proto",
    "x-real-ip",
  ];
  for (const header of identifyingHeaders) {
    proxyReq.removeHeader(header);
  }
};
@@ -4,12 +4,15 @@ import { initializeSseStream } from "../../../shared/streaming";
import { classifyErrorAndSend } from "../common"; import { classifyErrorAndSend } from "../common";
import { import {
RequestPreprocessor, RequestPreprocessor,
blockZoomerOrigins,
countPromptTokens, countPromptTokens,
languageFilter, languageFilter,
setApiFormat, setApiFormat,
transformOutboundPayload, transformOutboundPayload,
validateContextSize, validateContextSize,
validateModelFamily,
validateVision, validateVision,
applyQuotaLimits,
} from "."; } from ".";
type RequestPreprocessorOptions = { type RequestPreprocessorOptions = {
@@ -30,14 +33,15 @@ type RequestPreprocessorOptions = {
/** /**
* Returns a middleware function that processes the request body into the given * Returns a middleware function that processes the request body into the given
* API format, and then sequentially runs the given additional preprocessors. * API format, and then sequentially runs the given additional preprocessors.
* These should be used for validation and transformations that only need to
* happen once per request.
* *
* These run first in the request lifecycle, a single time per request before it * These run first in the request lifecycle, a single time per request before it
* is added to the request queue. They aren't run again if the request is * is added to the request queue. They aren't run again if the request is
* re-attempted after a rate limit. * re-attempted after a rate limit.
* *
* To run a preprocessor on every re-attempt, pass it to createQueueMiddleware. * To run functions against requests every time they are re-attempted, write a
* It will run after these preprocessors, but before the request is sent to * ProxyReqMutator and pass it to createQueuedProxyMiddleware instead.
* http-proxy-middleware.
*/ */
export const createPreprocessorMiddleware = ( export const createPreprocessorMiddleware = (
apiFormat: Parameters<typeof setApiFormat>[0], apiFormat: Parameters<typeof setApiFormat>[0],
@@ -45,6 +49,7 @@ export const createPreprocessorMiddleware = (
): RequestHandler => { ): RequestHandler => {
const preprocessors: RequestPreprocessor[] = [ const preprocessors: RequestPreprocessor[] = [
setApiFormat(apiFormat), setApiFormat(apiFormat),
blockZoomerOrigins,
...(beforeTransform ?? []), ...(beforeTransform ?? []),
transformOutboundPayload, transformOutboundPayload,
countPromptTokens, countPromptTokens,
@@ -52,6 +57,8 @@ export const createPreprocessorMiddleware = (
...(afterTransform ?? []), ...(afterTransform ?? []),
validateContextSize, validateContextSize,
validateVision, validateVision,
validateModelFamily,
applyQuotaLimits,
]; ];
return async (...args) => executePreprocessors(preprocessors, args); return async (...args) => executePreprocessors(preprocessors, args);
}; };
@@ -83,10 +90,10 @@ async function executePreprocessors(
next(); next();
} catch (error) { } catch (error) {
if (error.constructor.name === "ZodError") { if (error.constructor.name === "ZodError") {
const msg = error?.issues const issues = error?.issues
?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`) ?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
.join("; "); .join("; ");
req.log.warn({ issues: msg }, "Prompt validation failed."); req.log.warn({ issues }, "Prompt failed preprocessor validation.");
} else { } else {
req.log.error(error, "Error while executing request preprocessor"); req.log.error(error, "Error while executing request preprocessor");
} }
@@ -152,10 +159,7 @@ function isTestMessage(body: any) {
messages[0].content === "Hi" messages[0].content === "Hi"
); );
} else if (contents) { } else if (contents) {
return ( return contents.length === 1 && contents[0].parts[0]?.text === "Hi";
contents.length === 1 &&
contents[0].parts[0]?.text === "Hi"
);
} else { } else {
return ( return (
prompt?.trim() === "Human: Hi\n\nAssistant:" || prompt?.trim() === "Human: Hi\n\nAssistant:" ||
@@ -1,6 +1,6 @@
import { hasAvailableQuota } from "../../../../shared/users/user-store"; import { hasAvailableQuota } from "../../../../shared/users/user-store";
import { isImageGenerationRequest, isTextGenerationRequest } from "../../common"; import { isImageGenerationRequest, isTextGenerationRequest } from "../../common";
import { HPMRequestCallback } from "../index"; import { RequestPreprocessor } from "../index";
export class QuotaExceededError extends Error { export class QuotaExceededError extends Error {
public quotaInfo: any; public quotaInfo: any;
@@ -11,7 +11,7 @@ export class QuotaExceededError extends Error {
} }
} }
export const applyQuotaLimits: HPMRequestCallback = (_proxyReq, req) => { export const applyQuotaLimits: RequestPreprocessor = (req) => {
const subjectToQuota = const subjectToQuota =
isTextGenerationRequest(req) || isImageGenerationRequest(req); isTextGenerationRequest(req) || isImageGenerationRequest(req);
if (!subjectToQuota || !req.user) return; if (!subjectToQuota || !req.user) return;
@@ -1,4 +1,4 @@
import { HPMRequestCallback } from "../index"; import { RequestPreprocessor } from "../index";
const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(","); const DISALLOWED_ORIGIN_SUBSTRINGS = "janitorai.com,janitor.ai".split(",");
@@ -13,7 +13,7 @@ class ZoomerForbiddenError extends Error {
* Blocks requests from Janitor AI users with a fake, scary error message so I * Blocks requests from Janitor AI users with a fake, scary error message so I
* stop getting emails asking for tech support. * stop getting emails asking for tech support.
*/ */
export const blockZoomerOrigins: HPMRequestCallback = (_proxyReq, req) => { export const blockZoomerOrigins: RequestPreprocessor = (req) => {
const origin = req.headers.origin || req.headers.referer; const origin = req.headers.origin || req.headers.referer;
if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) { if (origin && DISALLOWED_ORIGIN_SUBSTRINGS.some((s) => origin.includes(s))) {
// Venus-derivatives send a test prompt to check if the proxy is working. // Venus-derivatives send a test prompt to check if the proxy is working.
@@ -17,7 +17,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
switch (service) { switch (service) {
case "openai": { case "openai": {
req.outputTokens = req.body.max_tokens; req.outputTokens = req.body.max_completion_tokens || req.body.max_tokens;
const prompt: OpenAIChatMessage[] = req.body.messages; const prompt: OpenAIChatMessage[] = req.body.messages;
result = await countTokens({ req, prompt, service }); result = await countTokens({ req, prompt, service });
break; break;
@@ -4,8 +4,22 @@ import { LLMService } from "../../../../shared/models";
import { RequestPreprocessor } from "../index"; import { RequestPreprocessor } from "../index";
export const setApiFormat = (api: { export const setApiFormat = (api: {
/**
* The API format the user made the request in and expects the response to be
* in.
*/
inApi: Request["inboundApi"]; inApi: Request["inboundApi"];
/**
* The API format the proxy will make the request in and expects the response
* to be in. If different from `inApi`, the proxy will transform the user's
* request body to this format, and will transform the response body or stream
* events from this format.
*/
outApi: APIFormat; outApi: APIFormat;
/**
* The service the request will be sent to, which determines authentication
* and possibly the streaming transport.
*/
service: LLMService; service: LLMService;
}): RequestPreprocessor => { }): RequestPreprocessor => {
return function configureRequestApiFormat(req) { return function configureRequestApiFormat(req) {
@@ -35,15 +35,8 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
// target API format. We don't need to transform them. // target API format. We don't need to transform them.
const isNativePrompt = req.inboundApi === req.outboundApi; const isNativePrompt = req.inboundApi === req.outboundApi;
if (isNativePrompt) { if (isNativePrompt) {
const result = API_REQUEST_VALIDATORS[req.inboundApi].safeParse(req.body); const result = API_REQUEST_VALIDATORS[req.inboundApi].parse(req.body);
if (!result.success) { req.body = result;
req.log.warn(
{ issues: result.error.issues, body: req.body },
"Native prompt request validation failed."
);
throw result.error;
}
req.body = result.data;
return; return;
} }
@@ -1,12 +1,12 @@
import { config } from "../../../../config"; import { config } from "../../../../config";
import { ForbiddenError } from "../../../../shared/errors"; import { ForbiddenError } from "../../../../shared/errors";
import { getModelFamilyForRequest } from "../../../../shared/models"; import { getModelFamilyForRequest } from "../../../../shared/models";
import { HPMRequestCallback } from "../index"; import { RequestPreprocessor } from "../index";
/** /**
* Ensures the selected model family is enabled by the proxy configuration. * Ensures the selected model family is enabled by the proxy configuration.
*/ */
export const checkModelFamily: HPMRequestCallback = (_proxyReq, req) => { export const validateModelFamily: RequestPreprocessor = (req) => {
const family = getModelFamilyForRequest(req); const family = getModelFamilyForRequest(req);
if (!config.allowedModelFamilies.includes(family)) { if (!config.allowedModelFamilies.includes(family)) {
throw new ForbiddenError( throw new ForbiddenError(
@@ -0,0 +1,129 @@
import { Request, Response } from "express";
import http from "http";
import ProxyServer from "http-proxy";
import { Readable } from "stream";
import {
createProxyMiddleware,
Options,
debugProxyErrorsPlugin,
proxyEventsPlugin,
} from "http-proxy-middleware";
import { ProxyReqMutator, RequestPreprocessor } from "./index";
import { createOnProxyResHandler, ProxyResHandlerWithBody } from "../response";
import { createQueueMiddleware } from "../../queue";
import { getHttpAgents } from "../../../shared/network";
import { classifyErrorAndSend } from "../common";
/**
* Options for the `createQueuedProxyMiddleware` factory function.
*/
type ProxyMiddlewareFactoryOptions = {
/**
* Functions which receive a ProxyReqManager and can modify the request before
* it is proxied. The modifications will be automatically reverted if the
* request needs to be returned to the queue.
*/
mutations?: ProxyReqMutator[];
/**
* The target URL to proxy requests to. This can be a string or a function
* which accepts the request and returns a string.
*/
target: string | Options<Request>["router"];
/**
* A function which receives the proxy response and the JSON-decoded request
* body. Only fired for non-streaming responses; streaming responses are
* handled in `handle-streaming-response.ts`.
*/
blockingResponseHandler?: ProxyResHandlerWithBody;
};
/**
* Returns a middleware function that accepts incoming requests and places them
* into the request queue. When the request is dequeued, it is proxied to the
* target URL using the given options and middleware. Non-streaming responses
* are handled by the given `blockingResponseHandler`.
*/
export function createQueuedProxyMiddleware({
  target,
  mutations,
  blockingResponseHandler,
}: ProxyMiddlewareFactoryOptions) {
  // HPM requires a string target even when a router function is supplied; the
  // placeholder is never contacted because the router overrides it per-request.
  const hpmTarget = typeof target === "string" ? target : "https://setbyrouter";
  const hpmRouter = typeof target === "function" ? target : undefined;
  // Agents come from shared/network (presumably carrying the configured
  // outbound interface/proxy settings — confirm there); pick one by scheme.
  const [httpAgent, httpsAgent] = getHttpAgents();
  const agent = hpmTarget.startsWith("http:") ? httpAgent : httpsAgent;
  const proxyMiddleware = createProxyMiddleware<Request, Response>({
    target: hpmTarget,
    router: hpmRouter,
    agent,
    changeOrigin: true,
    toProxy: true,
    // Only intercept the response body when a blocking handler was supplied;
    // otherwise the response is piped through to the client untouched.
    selfHandleResponse: typeof blockingResponseHandler === "function",
    // Disable HPM logger plugin (requires re-adding the other default plugins).
    // Contrary to name, debugProxyErrorsPlugin is not just for debugging and
    // fixes several error handling/connection close issues in http-proxy core.
    ejectPlugins: true,
    // Inferred (via Options<express.Request>) as Plugin<express.Request>, but
    // the default plugins only allow http.IncomingMessage for TReq. They are
    // compatible with express.Request, so we can use them. `Plugin` type is not
    // exported for some reason.
    plugins: [
      debugProxyErrorsPlugin,
      pinoLoggerPlugin,
      proxyEventsPlugin,
    ] as any,
    on: {
      proxyRes: createOnProxyResHandler(
        blockingResponseHandler ? [blockingResponseHandler] : []
      ),
      error: classifyErrorAndSend,
    },
    // Re-serializes the already-parsed request body as a fresh Readable each
    // time the request is (re)proxied.
    buffer: ((req: Request) => {
      // This is a hack/monkey patch and is not part of the official
      // http-proxy-middleware package. See patches/http-proxy+1.18.1.patch.
      let payload = req.body;
      if (typeof payload === "string") {
        payload = Buffer.from(payload);
      }
      const stream = new Readable();
      stream.push(payload);
      stream.push(null);
      return stream;
    }) as any,
  });
  // Wrap in the queue middleware so requests wait for dequeue (and have their
  // mutations applied/reverted) before reaching the proxy itself.
  return createQueueMiddleware({ mutations, proxyMiddleware });
}
type ProxiedResponse = http.IncomingMessage & Response & any;
/** Replaces HPM's default logger plugin with the request-scoped pino logger. */
function pinoLoggerPlugin(proxyServer: ProxyServer<Request>) {
  proxyServer.on("error", (err, req, res, target) => {
    req.log.error(
      { originalUrl: req.originalUrl, targetUrl: target?.toString(), err },
      "Error occurred while proxying request to target"
    );
  });
  proxyServer.on("proxyReq", (proxyReq, req, res) => {
    req.log.info(
      {
        originalUrl: req.originalUrl,
        targetHost: `${proxyReq.protocol}//${proxyReq.host}`,
        targetPath: res.req.url,
      },
      "Sending request to upstream API..."
    );
  });
  proxyServer.on("proxyRes", (proxyRes: ProxiedResponse, req, _res) => {
    req.log.info(
      {
        originalUrl: req.originalUrl,
        targetHost: `${proxyRes.req.protocol}//${proxyRes.req.hostname}`,
        targetPath: proxyRes.req.path,
        statusCode: proxyRes.statusCode,
      },
      "Got response from upstream API."
    );
  });
}
@@ -0,0 +1,112 @@
import { Request } from "express";
import { Key } from "../../../shared/key-management";
import { assertNever } from "../../../shared/utils";
/**
 * Represents a change to the request that will be reverted if the request
 * fails.
 */
interface ProxyReqMutation {
  // Which aspect of the request this mutation touched.
  target: "header" | "path" | "body" | "api-key" | "signed-request";
  // Header name for "header" mutations; informational label for the others.
  key?: string;
  // Value prior to the mutation; undefined means it was not previously set.
  originalValue: any | undefined;
}
/**
 * Manages a request's headers, body, and path, allowing them to be modified
 * before the request is proxied and automatically reverted if the request
 * needs to be retried.
 */
export class ProxyReqManager {
  private req: Request;
  private mutations: ProxyReqMutation[] = [];

  /**
   * A read-only proxy of the request object. Avoid changing any properties
   * here as they will persist across retries.
   */
  public readonly request: Readonly<Request>;

  constructor(req: Request) {
    this.req = req;
    // Read-only view: property reads pass through, everything else (symbols,
    // writes) is unsupported so mutators must use the tracked setters below.
    this.request = new Proxy(req, {
      get: (target, prop) => {
        if (typeof prop === "string") return target[prop as keyof Request];
        return undefined;
      },
    });
  }

  /** Sets a request header, recording the prior value for revert. */
  setHeader(name: string, newValue: string): void {
    const originalValue = this.req.get(name);
    this.mutations.push({ target: "header", key: name, originalValue });
    this.req.headers[name.toLowerCase()] = newValue;
  }

  /** Removes a request header, recording the prior value for revert. */
  removeHeader(name: string): void {
    const originalValue = this.req.get(name);
    this.mutations.push({ target: "header", key: name, originalValue });
    delete this.req.headers[name.toLowerCase()];
  }

  /** Replaces the request body, recording the prior value for revert. */
  setBody(newBody: any): void {
    const originalValue = this.req.body;
    this.mutations.push({ target: "body", key: "body", originalValue });
    this.req.body = newBody;
  }

  /** Assigns an API key to the request. Not reverted — see revert(). */
  setKey(newKey: Key): void {
    const originalValue = this.req.key;
    this.mutations.push({ target: "api-key", key: "key", originalValue });
    this.req.key = newKey;
  }

  /** Rewrites the request path, recording the prior URL for revert. */
  setPath(newPath: string): void {
    // Record `req.url`, not `req.path`: express derives `path` from `url` and
    // strips the query string, so reverting from `path` would silently drop
    // any query parameters the original request carried.
    const originalValue = this.req.url;
    this.mutations.push({ target: "path", key: "path", originalValue });
    this.req.url = newPath;
  }

  /** Attaches a pre-signed upstream request, recording the prior value. */
  setSignedRequest(newSignedRequest: typeof this.req.signedRequest): void {
    const originalValue = this.req.signedRequest;
    this.mutations.push({
      target: "signed-request",
      key: "signedRequest",
      originalValue,
    });
    this.req.signedRequest = newSignedRequest;
  }

  /** Returns true if any tracked mutation has been applied. */
  hasChanged(): boolean {
    return this.mutations.length > 0;
  }

  /**
   * Undoes all tracked mutations in reverse (LIFO) order so the request can
   * be safely returned to the queue and retried later.
   */
  revert(): void {
    for (const mutation of this.mutations.reverse()) {
      switch (mutation.target) {
        case "header":
          if (mutation.originalValue === undefined) {
            // Header did not exist before the mutation; remove it entirely.
            delete this.req.headers[mutation.key!.toLowerCase()];
          } else {
            this.req.headers[mutation.key!.toLowerCase()] =
              mutation.originalValue;
          }
          break;
        case "path":
          this.req.url = mutation.originalValue;
          break;
        case "body":
          this.req.body = mutation.originalValue;
          break;
        case "api-key":
          // We don't reset the key here because it's not a property of the
          // inbound request, so we'd only ever be reverting it to null.
          break;
        case "signed-request":
          this.req.signedRequest = mutation.originalValue;
          break;
        default:
          assertNever(mutation.target);
      }
    }
    this.mutations = [];
  }
}
@@ -2,36 +2,33 @@ import express from "express";
import { APIFormat } from "../../../shared/key-management"; import { APIFormat } from "../../../shared/key-management";
import { assertNever } from "../../../shared/utils"; import { assertNever } from "../../../shared/utils";
import { initializeSseStream } from "../../../shared/streaming"; import { initializeSseStream } from "../../../shared/streaming";
import http from "http";
function getMessageContent({ /**
title, * Returns a Markdown-formatted message that renders semi-nicely in most chat
message, * frontends. For example:
obj, *
}: { * **Proxy error (HTTP 404 Not Found)**
* The proxy encountered an error while trying to send your prompt to the upstream service. Further technical details are provided below.
* ***
* *The requested Claude model might not exist, or the key might not be provisioned for it.*
* ```
* {
* "type": "error",
* "error": {
* "type": "not_found_error",
* "message": "model: some-invalid-model-id",
* },
* "proxy_note": "The requested Claude model might not exist, or the key might not be provisioned for it."
* }
* ```
*/
function getMessageContent(params: {
title: string; title: string;
message: string; message: string;
obj?: Record<string, any>; obj?: Record<string, any>;
}) { }) {
/* const { title, message, obj } = params;
Constructs a Markdown-formatted message that renders semi-nicely in most chat
frontends. For example:
**Proxy error (HTTP 404 Not Found)**
The proxy encountered an error while trying to send your prompt to the upstream service. Further technical details are provided below.
***
*The requested Claude model might not exist, or the key might not be provisioned for it.*
```
{
"type": "error",
"error": {
"type": "not_found_error",
"message": "model: some-invalid-model-id",
},
"proxy_note": "The requested Claude model might not exist, or the key might not be provisioned for it."
}
```
*/
const note = obj?.proxy_note || obj?.error?.message || ""; const note = obj?.proxy_note || obj?.error?.message || "";
const header = `### **${title}**`; const header = `### **${title}**`;
const friendlyMessage = note ? `${message}\n\n----\n\n*${note}*` : message; const friendlyMessage = note ? `${message}\n\n----\n\n*${note}*` : message;
@@ -71,7 +68,11 @@ type ErrorGeneratorOptions = {
statusCode?: number; statusCode?: number;
}; };
export function tryInferFormat(body: any): APIFormat | "unknown" { /**
* Very crude inference of the request format based on the request body. Don't
* rely on this to be very accurate.
*/
function tryInferFormat(body: any): APIFormat | "unknown" {
if (typeof body !== "object" || !body.model) { if (typeof body !== "object" || !body.model) {
return "unknown"; return "unknown";
} }
@@ -95,7 +96,11 @@ export function tryInferFormat(body: any): APIFormat | "unknown" {
return "unknown"; return "unknown";
} }
// avoid leaking upstream hostname on dns resolution error /**
* Redacts the hostname from the error message if it contains a DNS resolution
* error. This is to avoid leaking upstream hostnames on DNS resolution errors,
* as those may contain sensitive information about the proxy's configuration.
*/
function redactHostname(options: ErrorGeneratorOptions): ErrorGeneratorOptions { function redactHostname(options: ErrorGeneratorOptions): ErrorGeneratorOptions {
if (!options.message.includes("getaddrinfo")) return options; if (!options.message.includes("getaddrinfo")) return options;
@@ -112,46 +117,61 @@ function redactHostname(options: ErrorGeneratorOptions): ErrorGeneratorOptions {
return redacted; return redacted;
} }
export function sendErrorToClient({ /**
options, * Generates an appropriately-formatted error response and sends it to the
req, * client over their requested transport (blocking or SSE stream).
res, */
}: { export function sendErrorToClient(params: {
options: ErrorGeneratorOptions; options: ErrorGeneratorOptions;
req: express.Request; req: express.Request;
res: express.Response; res: express.Response;
}) { }) {
const redactedOpts = redactHostname(options); const { req, res } = params;
const { format: inputFormat } = redactedOpts; const options = redactHostname(params.options);
const { statusCode, message, title, obj: details } = options;
// Since we want to send the error in a format the client understands, we
// need to know the request format. `setApiFormat` might not have been called
// yet, so we'll try to infer it from the request body.
const format = const format =
inputFormat === "unknown" ? tryInferFormat(req.body) : inputFormat; options.format === "unknown" ? tryInferFormat(req.body) : options.format;
if (format === "unknown") { if (format === "unknown") {
return res.status(redactedOpts.statusCode || 400).json({ // Early middleware error (auth, rate limit) so we can only send something
error: redactedOpts.message, // generic.
details: redactedOpts.obj, const code = statusCode || 400;
const hasDetails = details && Object.keys(details).length > 0;
return res.status(code).json({
error: {
message,
type: http.STATUS_CODES[code]!.replace(/\s+/g, "_").toLowerCase(),
},
...(hasDetails ? { details } : {}),
}); });
} }
const completion = buildSpoofedCompletion({ ...redactedOpts, format }); // Cannot modify headers if client opted into streaming and made it into the
const event = buildSpoofedSSE({ ...redactedOpts, format }); // proxy request queue, because that immediately starts an SSE stream.
const isStreaming =
req.isStreaming || req.body.stream === true || req.body.stream === "true";
if (!res.headersSent) { if (!res.headersSent) {
res.setHeader("x-oai-proxy-error", redactedOpts.title); res.setHeader("x-oai-proxy-error", title);
res.setHeader("x-oai-proxy-error-status", redactedOpts.statusCode || 500); res.setHeader("x-oai-proxy-error-status", statusCode || 500);
} }
// By this point, we know the request format. To get the error to display in
// chat clients' UIs, we'll send it as a 200 response as a spoofed completion
// from the language model. Depending on whether the client is streaming, we
// will either send an SSE event or a JSON response.
const isStreaming = req.isStreaming || String(req.body.stream) === "true";
if (isStreaming) { if (isStreaming) {
// User can have opted into streaming but not made it into the queue yet,
// in which case the stream must be started first.
if (!res.headersSent) { if (!res.headersSent) {
initializeSseStream(res); initializeSseStream(res);
} }
res.write(event); res.write(buildSpoofedSSE({ ...options, format }));
res.write(`data: [DONE]\n\n`); res.write(`data: [DONE]\n\n`);
res.end(); res.end();
} else { } else {
res.status(200).json(completion); res.status(200).json(buildSpoofedCompletion({ ...options, format }));
} }
} }
@@ -193,7 +213,7 @@ export function buildSpoofedCompletion({
return { return {
outputs: [{ text: content, stop_reason: title }], outputs: [{ text: content, stop_reason: title }],
model, model,
} };
case "openai-text": case "openai-text":
return { return {
id: "error-" + id, id: "error-" + id,
+29 -14
View File
@@ -47,7 +47,7 @@ export type ProxyResHandlerWithBody = (
*/ */
body: string | Record<string, any> body: string | Record<string, any>
) => Promise<void>; ) => Promise<void>;
export type ProxyResMiddleware = ProxyResHandlerWithBody[]; export type ProxyResMiddleware = ProxyResHandlerWithBody[] | undefined;
/** /**
* Returns a on.proxyRes handler that executes the given middleware stack after * Returns a on.proxyRes handler that executes the given middleware stack after
@@ -71,11 +71,22 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
req: Request, req: Request,
res: Response res: Response
) => { ) => {
const initialHandler: RawResponseBodyHandler = req.isStreaming // Proxied request has by now been sent to the upstream API, so we revert
// tracked mutations that were only needed to send the request.
// This generally means path adjustment, headers, and body serialization.
if (req.changeManager) {
req.changeManager.revert();
}
const initialHandler = req.isStreaming
? handleStreamedResponse ? handleStreamedResponse
: handleBlockingResponse; : handleBlockingResponse;
let lastMiddleware = initialHandler.name; let lastMiddleware = initialHandler.name;
if (Buffer.isBuffer(req.body)) {
req.body = JSON.parse(req.body.toString());
}
try { try {
const body = await initialHandler(proxyRes, req, res); const body = await initialHandler(proxyRes, req, res);
const middlewareStack: ProxyResMiddleware = []; const middlewareStack: ProxyResMiddleware = [];
@@ -100,7 +111,7 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
saveImage, saveImage,
logPrompt, logPrompt,
logEvent, logEvent,
...apiMiddleware ...(apiMiddleware ?? [])
); );
} }
@@ -723,22 +734,26 @@ const trackKeyRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
keyPool.updateRateLimits(req.key!, proxyRes.headers); keyPool.updateRateLimits(req.key!, proxyRes.headers);
}; };
const omittedHeaders = new Set<string>([
// Omit content-encoding because we will always decode the response body
"content-encoding",
// Omit transfer-encoding because we are using response.json which will
// set a content-length header, which is not valid for chunked responses.
"transfer-encoding",
// Don't set cookies from upstream APIs because proxied requests are stateless
"set-cookie",
"openai-organization",
"x-request-id",
"cf-ray",
]);
const copyHttpHeaders: ProxyResHandlerWithBody = async ( const copyHttpHeaders: ProxyResHandlerWithBody = async (
proxyRes, proxyRes,
_req, _req,
res res
) => { ) => {
Object.keys(proxyRes.headers).forEach((key) => { Object.keys(proxyRes.headers).forEach((key) => {
// Omit content-encoding because we will always decode the response body if (omittedHeaders.has(key)) return;
if (key === "content-encoding") {
return;
}
// We're usually using res.json() to send the response, which causes express
// to set content-length. That's not valid for chunked responses and some
// clients will reject it so we need to omit it.
if (key === "transfer-encoding") {
return;
}
res.setHeader(key, proxyRes.headers[key] as string); res.setHeader(key, proxyRes.headers[key] as string);
}); });
}; };
@@ -782,6 +797,6 @@ function getAwsErrorType(header: string | string[] | undefined) {
function assertJsonResponse(body: any): asserts body is Record<string, any> { function assertJsonResponse(body: any): asserts body is Record<string, any> {
if (typeof body !== "object") { if (typeof body !== "object") {
throw new Error("Expected response to be an object"); throw new Error(`Expected response to be an object, got ${typeof body}`);
} }
} }
+9 -26
View File
@@ -1,27 +1,20 @@
import express, { Request, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware"; import { BadRequestError } from "../shared/errors";
import { config } from "../config";
import { keyPool } from "../shared/key-management"; import { keyPool } from "../shared/key-management";
import { import {
getMistralAIModelFamily, getMistralAIModelFamily,
MistralAIModelFamily, MistralAIModelFamily,
ModelFamily, ModelFamily,
} from "../shared/models"; } from "../shared/models";
import { logger } from "../logger"; import { config } from "../config";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
addKey, addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeBody, finalizeBody,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
createOnProxyResHandler, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
ProxyResHandlerWithBody,
} from "./middleware/response";
import { BadRequestError } from "../shared/errors";
// Mistral can't settle on a single naming scheme and deprecates models within // Mistral can't settle on a single naming scheme and deprecates models within
// months of releasing them so this list is hard to keep up to date. 2024-07-28 // months of releasing them so this list is hard to keep up to date. 2024-07-28
@@ -127,20 +120,10 @@ export function transformMistralTextToMistralChat(textBody: any) {
}; };
} }
const mistralAIProxy = createQueueMiddleware({ const mistralAIProxy = createQueuedProxyMiddleware({
proxyMiddleware: createProxyMiddleware({ target: "https://api.mistral.ai",
target: "https://api.mistral.ai", mutations: [addKey, finalizeBody],
changeOrigin: true, blockingResponseHandler: mistralAIResponseHandler,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([mistralAIResponseHandler]),
error: handleProxyError,
},
}),
}); });
const mistralAIRouter = Router(); const mistralAIRouter = Router();
+19 -28
View File
@@ -1,22 +1,15 @@
import { RequestHandler, Router, Request } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware"; import { OpenAIImageGenerationResult } from "../shared/file-storage/mirror-generated-image";
import { config } from "../config"; import { generateModelList } from "./openai";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
addKey, addKey,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeBody, finalizeBody,
createOnProxyReqHandler,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
createOnProxyResHandler, import { ProxyReqManager } from "./middleware/request/proxy-req-manager";
ProxyResHandlerWithBody, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
} from "./middleware/response";
import { generateModelList } from "./openai";
import { OpenAIImageGenerationResult } from "../shared/file-storage/mirror-generated-image";
const KNOWN_MODELS = ["dall-e-2", "dall-e-3"]; const KNOWN_MODELS = ["dall-e-2", "dall-e-3"];
@@ -96,21 +89,19 @@ function transformResponseForChat(
}; };
} }
const openaiImagesProxy = createQueueMiddleware({ function replacePath(manager: ProxyReqManager) {
proxyMiddleware: createProxyMiddleware({ const req = manager.request;
target: "https://api.openai.com", const pathname = req.url.split("?")[0];
changeOrigin: true, req.log.debug({ pathname }, "OpenAI image path filter");
selfHandleResponse: true, if (req.path.startsWith("/v1/chat/completions")) {
logger, manager.setPath("/v1/images/generations");
pathRewrite: { }
"^/v1/chat/completions": "/v1/images/generations", }
},
on: { const openaiImagesProxy = createQueuedProxyMiddleware({
proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }), target: "https://api.openai.com",
proxyRes: createOnProxyResHandler([openaiImagesResponseHandler]), mutations: [replacePath, addKey, finalizeBody],
error: handleProxyError, blockingResponseHandler: openaiImagesResponseHandler,
},
}),
}); });
const openaiImagesRouter = Router(); const openaiImagesRouter = Router();
+12 -33
View File
@@ -1,26 +1,18 @@
import { Request, RequestHandler, Router } from "express"; import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config"; import { config } from "../config";
import { AzureOpenAIKey, keyPool, OpenAIKey } from "../shared/key-management"; import { AzureOpenAIKey, keyPool, OpenAIKey } from "../shared/key-management";
import { getOpenAIModelFamily } from "../shared/models"; import { getOpenAIModelFamily } from "../shared/models";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit"; import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import { import {
addKey, addKey,
addKeyForEmbeddingsRequest, addKeyForEmbeddingsRequest,
createEmbeddingsPreprocessorMiddleware, createEmbeddingsPreprocessorMiddleware,
createOnProxyReqHandler,
createPreprocessorMiddleware, createPreprocessorMiddleware,
finalizeBody, finalizeBody,
forceModel,
RequestPreprocessor, RequestPreprocessor,
} from "./middleware/request"; } from "./middleware/request";
import { import { ProxyResHandlerWithBody } from "./middleware/response";
createOnProxyResHandler, import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
ProxyResHandlerWithBody,
} from "./middleware/response";
// https://platform.openai.com/docs/models/overview // https://platform.openai.com/docs/models/overview
let modelsCache: any = null; let modelsCache: any = null;
@@ -126,7 +118,6 @@ const openaiResponseHandler: ProxyResHandlerWithBody = async (
res.status(200).json({ ...newBody, proxy: body.proxy }); res.status(200).json({ ...newBody, proxy: body.proxy });
}; };
/** Only used for non-streaming responses. */
function transformTurboInstructResponse( function transformTurboInstructResponse(
turboInstructBody: Record<string, any> turboInstructBody: Record<string, any>
): Record<string, any> { ): Record<string, any> {
@@ -144,31 +135,15 @@ function transformTurboInstructResponse(
return transformed; return transformed;
} }
const openaiProxy = createQueueMiddleware({ const openaiProxy = createQueuedProxyMiddleware({
proxyMiddleware: createProxyMiddleware({ mutations: [addKey, finalizeBody],
target: "https://api.openai.com", target: "https://api.openai.com",
changeOrigin: true, blockingResponseHandler: openaiResponseHandler,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [addKey, finalizeBody] }),
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
error: handleProxyError,
},
}),
}); });
const openaiEmbeddingsProxy = createProxyMiddleware({ const openaiEmbeddingsProxy = createQueuedProxyMiddleware({
mutations: [addKeyForEmbeddingsRequest, finalizeBody],
target: "https://api.openai.com", target: "https://api.openai.com",
changeOrigin: true,
selfHandleResponse: false,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [addKeyForEmbeddingsRequest, finalizeBody],
}),
error: handleProxyError,
},
}); });
const openaiRouter = Router(); const openaiRouter = Router();
@@ -215,6 +190,10 @@ openaiRouter.post(
openaiEmbeddingsProxy openaiEmbeddingsProxy
); );
function forceModel(model: string): RequestPreprocessor {
return (req: Request) => void (req.body.model = model);
}
function fixupMaxTokens(req: Request) { function fixupMaxTokens(req: Request) {
if (!req.body.max_completion_tokens) { if (!req.body.max_completion_tokens) {
req.body.max_completion_tokens = req.body.max_tokens; req.body.max_completion_tokens = req.body.max_tokens;
+40 -15
View File
@@ -24,9 +24,10 @@ import {
import { initializeSseStream } from "../shared/streaming"; import { initializeSseStream } from "../shared/streaming";
import { logger } from "../logger"; import { logger } from "../logger";
import { getUniqueIps } from "./rate-limit"; import { getUniqueIps } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request"; import { ProxyReqMutator, RequestPreprocessor } from "./middleware/request";
import { handleProxyError } from "./middleware/common";
import { sendErrorToClient } from "./middleware/response/error-generator"; import { sendErrorToClient } from "./middleware/response/error-generator";
import { ProxyReqManager } from "./middleware/request/proxy-req-manager";
import { classifyErrorAndSend } from "./middleware/common";
const queue: Request[] = []; const queue: Request[] = [];
const log = logger.child({ module: "request-queue" }); const log = logger.child({ module: "request-queue" });
@@ -67,6 +68,14 @@ const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
getIdentifier(queued) === getIdentifier(incoming); getIdentifier(queued) === getIdentifier(incoming);
async function enqueue(req: Request) { async function enqueue(req: Request) {
if (req.socket.destroyed || req.res?.writableEnded) {
// In rare cases, a request can be disconnected after it is dequeued for a
// retry, but before it is re-enqueued. In this case we may miss the abort
// and the request will loop in the queue forever.
req.log.warn("Attempt to enqueue aborted request.");
throw new Error("Attempt to enqueue aborted request.");
}
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length; const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) { if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT) {
@@ -139,7 +148,14 @@ export function dequeue(partition: ModelFamily): Request | undefined {
} }
const req = modelQueue.reduce((prev, curr) => const req = modelQueue.reduce((prev, curr) =>
prev.startTime + config.tokensPunishmentFactor*((prev.promptTokens ?? 0) + (prev.outputTokens ?? 0)) < curr.startTime + config.tokensPunishmentFactor*((curr.promptTokens ?? 0) + (curr.outputTokens ?? 0)) ? prev : curr prev.startTime +
config.tokensPunishmentFactor *
((prev.promptTokens ?? 0) + (prev.outputTokens ?? 0)) <
curr.startTime +
config.tokensPunishmentFactor *
((curr.promptTokens ?? 0) + (curr.outputTokens ?? 0))
? prev
: curr
); );
queue.splice(queue.indexOf(req), 1); queue.splice(queue.indexOf(req), 1);
@@ -306,26 +322,35 @@ export function getQueueLength(partition: ModelFamily | "all" = "all") {
} }
export function createQueueMiddleware({ export function createQueueMiddleware({
beforeProxy, mutations = [],
proxyMiddleware, proxyMiddleware,
}: { }: {
beforeProxy?: RequestPreprocessor; mutations?: ProxyReqMutator[];
proxyMiddleware: Handler; proxyMiddleware: Handler;
}): Handler { }): Handler {
return async (req, res, next) => { return async (req, res, next) => {
req.proceed = async () => { req.proceed = async () => {
if (beforeProxy) { // canonicalize the stream field which is set in a few places not always
try { // consistently
// Hack to let us run asynchronous middleware before the req.isStreaming = req.isStreaming || String(req.body.stream) === "true";
// http-proxy-middleware handler. This is used to sign AWS requests req.body.stream = req.isStreaming;
// before they are proxied, as the signing is asynchronous.
// Unlike RequestPreprocessors, this runs every time the request is try {
// dequeued, not just the first time. // Just before executing the proxyMiddleware, we will create a
await beforeProxy(req); // ProxyReqManager to track modifications to the request. This allows
} catch (err) { // us to revert those changes if the proxied request fails with a
return handleProxyError(err, req, res); // retryable error. That happens in proxyMiddleware's onProxyRes
// handler.
const changeManager = new ProxyReqManager(req);
req.changeManager = changeManager;
for (const mutator of mutations) {
await mutator(changeManager);
} }
} catch (err) {
// Failure during request preparation is a fatal error.
return classifyErrorAndSend(err, req, res);
} }
proxyMiddleware(req, res, next); proxyMiddleware(req, res, next);
}; };
+8 -1
View File
@@ -23,6 +23,7 @@ import { init as initTokenizers } from "./shared/tokenization";
import { checkOrigin } from "./proxy/check-origin"; import { checkOrigin } from "./proxy/check-origin";
import { sendErrorToClient } from "./proxy/middleware/response/error-generator"; import { sendErrorToClient } from "./proxy/middleware/response/error-generator";
import { initializeDatabase, getDatabase } from "./shared/database"; import { initializeDatabase, getDatabase } from "./shared/database";
import { initializeFirebase } from "./shared/firebase";
const PORT = config.port; const PORT = config.port;
const BIND_ADDRESS = config.bindAddress; const BIND_ADDRESS = config.bindAddress;
@@ -137,6 +138,12 @@ async function start() {
logger.info("Checking configs and external dependencies..."); logger.info("Checking configs and external dependencies...");
await assertConfigIsValid(); await assertConfigIsValid();
if (config.gatekeeperStore.startsWith("firebase")) {
logger.info("Testing Firebase connection...");
await initializeFirebase();
logger.info("Firebase connection successful.");
}
keyPool.init(); keyPool.init();
await initTokenizers(); await initTokenizers();
@@ -166,7 +173,7 @@ async function start() {
app.listen(PORT, BIND_ADDRESS, () => { app.listen(PORT, BIND_ADDRESS, () => {
logger.info( logger.info(
{ port: PORT, interface: BIND_ADDRESS }, { port: PORT, interface: BIND_ADDRESS },
"Now listening for connections." "Server ready to accept connections."
); );
registerUncaughtExceptionHandler(); registerUncaughtExceptionHandler();
}); });
+2
View File
@@ -5,6 +5,7 @@ import { Express } from "express-serve-static-core";
import { APIFormat, Key } from "./key-management"; import { APIFormat, Key } from "./key-management";
import { User } from "./users/schema"; import { User } from "./users/schema";
import { LLMService, ModelFamily } from "./models"; import { LLMService, ModelFamily } from "./models";
import { ProxyReqManager } from "../proxy/middleware/request/proxy-req-manager";
declare global { declare global {
namespace Express { namespace Express {
@@ -24,6 +25,7 @@ declare global {
queueOutTime?: number; queueOutTime?: number;
onAborted?: () => void; onAborted?: () => void;
proceed: () => void; proceed: () => void;
changeManager?: ProxyReqManager;
heartbeatInterval?: NodeJS.Timeout; heartbeatInterval?: NodeJS.Timeout;
monitorInterval?: NodeJS.Timeout; monitorInterval?: NodeJS.Timeout;
promptTokens?: number; promptTokens?: number;
@@ -1,12 +1,14 @@
import axios from "axios";
import express from "express"; import express from "express";
import { promises as fs } from "fs"; import { promises as fs } from "fs";
import path from "path"; import path from "path";
import { v4 } from "uuid"; import { v4 } from "uuid";
import { USER_ASSETS_DIR } from "../../config"; import { USER_ASSETS_DIR } from "../../config";
import { getAxiosInstance } from "../network";
import { addToImageHistory } from "./image-history"; import { addToImageHistory } from "./image-history";
import { libSharp } from "./index"; import { libSharp } from "./index";
const axios = getAxiosInstance();
export type OpenAIImageGenerationResult = { export type OpenAIImageGenerationResult = {
created: number; created: number;
data: { data: {
+28
View File
@@ -0,0 +1,28 @@
import type firebase from "firebase-admin";
import { config } from "../config";
import { getHttpAgents } from "./network";
let firebaseApp: firebase.app.App | undefined;
/**
 * Initializes the firebase-admin app from the base64-encoded service account
 * key and RTDB URL in config, then verifies connectivity by writing a
 * timestamp to the `connection-test` ref. The app is memoized for later
 * retrieval via `getFirebaseApp`.
 */
export async function initializeFirebase() {
  const firebase = await import("firebase-admin");
  const serviceAccount = JSON.parse(
    Buffer.from(config.firebaseKey!, "base64").toString()
  );
  const app = firebase.initializeApp({
    // RTDB doesn't actually seem to use this but respects `WS_PROXY` if set,
    // so we do that in the network module.
    httpAgent: getHttpAgents()[0],
    credential: firebase.credential.cert(serviceAccount),
    databaseURL: config.firebaseRtdbUrl,
  });
  await app.database().ref("connection-test").set(Date.now());
  firebaseApp = app;
}
/**
 * Returns the memoized firebase-admin app.
 * @throws if `initializeFirebase` has not been called successfully.
 */
export function getFirebaseApp(): firebase.app.App {
  if (firebaseApp) return firebaseApp;
  throw new Error("Firebase app not initialized.");
}
@@ -1,7 +1,10 @@
import axios, { AxiosError, AxiosResponse } from "axios"; import { AxiosError, AxiosResponse } from "axios";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { AnthropicKey, AnthropicKeyProvider } from "./provider"; import type { AnthropicKey, AnthropicKeyProvider } from "./provider";
const axios = getAxiosInstance();
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 1000 * 60 * 60 * 6; // 6 hours const KEY_CHECK_PERIOD = 1000 * 60 * 60 * 6; // 6 hours
const POST_MESSAGES_URL = "https://api.anthropic.com/v1/messages"; const POST_MESSAGES_URL = "https://api.anthropic.com/v1/messages";
+4 -1
View File
@@ -1,13 +1,16 @@
import { Sha256 } from "@aws-crypto/sha256-js"; import { Sha256 } from "@aws-crypto/sha256-js";
import { SignatureV4 } from "@smithy/signature-v4"; import { SignatureV4 } from "@smithy/signature-v4";
import { HttpRequest } from "@smithy/protocol-http"; import { HttpRequest } from "@smithy/protocol-http";
import axios, { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios"; import { AxiosError, AxiosHeaders, AxiosRequestConfig } from "axios";
import { URL } from "url"; import { URL } from "url";
import { config } from "../../../config"; import { config } from "../../../config";
import { getAwsBedrockModelFamily } from "../../models"; import { getAwsBedrockModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider"; import type { AwsBedrockKey, AwsBedrockKeyProvider } from "./provider";
const axios = getAxiosInstance();
type ParentModelId = string; type ParentModelId = string;
type AliasModelId = string; type AliasModelId = string;
type ModuleAliasTuple = [ParentModelId, ...AliasModelId[]]; type ModuleAliasTuple = [ParentModelId, ...AliasModelId[]];
+5 -2
View File
@@ -1,7 +1,10 @@
import axios, { AxiosError } from "axios"; import { AxiosError } from "axios";
import { getAzureOpenAIModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider"; import type { AzureOpenAIKey, AzureOpenAIKeyProvider } from "./provider";
import { getAzureOpenAIModelFamily } from "../../models";
const axios = getAxiosInstance();
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
+5 -2
View File
@@ -1,8 +1,11 @@
import axios, { AxiosError } from "axios"; import { AxiosError } from "axios";
import crypto from "crypto"; import crypto from "crypto";
import { GcpModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { GcpKey, GcpKeyProvider } from "./provider"; import type { GcpKey, GcpKeyProvider } from "./provider";
import { GcpModelFamily } from "../../models";
const axios = getAxiosInstance();
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes const KEY_CHECK_PERIOD = 90 * 60 * 1000; // 90 minutes
@@ -1,8 +1,10 @@
import axios, { AxiosError } from "axios"; import { AxiosError } from "axios";
import type { GoogleAIModelFamily } from "../../models"; import { GoogleAIModelFamily, getGoogleAIModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { GoogleAIKey, GoogleAIKeyProvider } from "./provider"; import type { GoogleAIKey, GoogleAIKeyProvider } from "./provider";
import { getGoogleAIModelFamily } from "../../models";
const axios = getAxiosInstance();
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 60 * 1000; // 3 hours const KEY_CHECK_PERIOD = 3 * 60 * 60 * 1000; // 3 hours
@@ -1,8 +1,10 @@
import axios, { AxiosError } from "axios"; import { AxiosError } from "axios";
import type { MistralAIModelFamily } from "../../models"; import { MistralAIModelFamily, getMistralAIModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { MistralAIKey, MistralAIKeyProvider } from "./provider"; import type { MistralAIKey, MistralAIKeyProvider } from "./provider";
import { getMistralAIModelFamily } from "../../models";
const axios = getAxiosInstance();
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
+5 -3
View File
@@ -1,8 +1,10 @@
import axios, { AxiosError } from "axios"; import { AxiosError } from "axios";
import type { OpenAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base"; import { KeyCheckerBase } from "../key-checker-base";
import type { OpenAIKey, OpenAIKeyProvider } from "./provider"; import type { OpenAIKey, OpenAIKeyProvider } from "./provider";
import { getOpenAIModelFamily } from "../../models"; import { OpenAIModelFamily, getOpenAIModelFamily } from "../../models";
import { getAxiosInstance } from "../../network";
const axios = getAxiosInstance();
const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour const KEY_CHECK_PERIOD = 60 * 60 * 1000; // 1 hour
+69
View File
@@ -0,0 +1,69 @@
import axios, { AxiosInstance } from "axios";
import http from "http";
import https from "https";
import os from "os";
import { ProxyAgent } from "proxy-agent";
import { config } from "../config";
import { logger } from "../logger";
const log = logger.child({ module: "network" });
export type HttpAgent = http.Agent | https.Agent;
/** HTTP agent used by http-proxy-middleware when forwarding requests. */
let httpAgent: HttpAgent;
/** HTTPS agent used by http-proxy-middleware when forwarding requests. */
let httpsAgent: HttpAgent;
/** Axios instance used for any non-proxied requests. */
let axiosInstance: AxiosInstance;
/**
 * Resolves the named OS network interface to its first external IPv4 address.
 * @throws if the interface does not exist or has no external IPv4 address.
 */
function getInterfaceAddress(iface: string) {
  const ifaces = os.networkInterfaces();
  log.debug({ ifaces, iface }, "Found network interfaces.");
  const entries = ifaces[iface];
  if (!entries) {
    throw new Error(`Interface ${iface} not found.`);
  }
  // Internal (loopback) and IPv6 entries are not usable as a bind address.
  const candidates = entries.filter(
    (entry) => entry.family === "IPv4" && !entry.internal
  );
  if (candidates.length === 0) {
    throw new Error(`Interface ${iface} has no external IPv4 addresses.`);
  }
  const [selected] = candidates;
  log.debug({ selected }, "Selected network interface.");
  return selected.address;
}
/**
 * Lazily creates (and memoizes) the HTTP/HTTPS agents used for all outgoing
 * requests. Depending on config, agents are either bound to a specific
 * network interface's external IPv4 address, routed through a proxy server,
 * or left as plain default agents.
 * @returns a `[httpAgent, httpsAgent]` pair.
 */
export function getHttpAgents() {
  if (httpAgent) return [httpAgent, httpsAgent];
  const { interface: iface, proxyUrl } = config.httpAgent || {};
  if (iface) {
    // Bind outgoing sockets to the configured interface's address.
    const address = getInterfaceAddress(iface);
    httpAgent = new http.Agent({ localAddress: address, keepAlive: true });
    httpsAgent = new https.Agent({ localAddress: address, keepAlive: true });
    log.info({ address }, "Using configured interface for outgoing requests.");
  } else if (proxyUrl) {
    // ProxyAgent picks up the proxy target from the environment. WS_PROXY is
    // set for consumers that proxy websocket connections (e.g. firebase RTDB).
    process.env.HTTP_PROXY = proxyUrl;
    process.env.HTTPS_PROXY = proxyUrl;
    process.env.WS_PROXY = proxyUrl;
    httpAgent = new ProxyAgent();
    httpsAgent = httpAgent; // ProxyAgent automatically handles HTTPS
    // Mask credentials before logging. The previous pattern (/:.*@/ ->
    // "@******") greedily matched from the scheme's colon and mangled the
    // logged URL; match only the `//userinfo@` section instead.
    const proxy = proxyUrl.replace(/\/\/.*@/, "//********@");
    log.info({ proxy }, "Using proxy server for outgoing requests.");
  } else {
    httpAgent = new http.Agent();
    httpsAgent = new https.Agent();
  }
  return [httpAgent, httpsAgent];
}
export function getAxiosInstance() {
if (axiosInstance) return axiosInstance;
const [httpAgent, httpsAgent] = getHttpAgents();
axiosInstance = axios.create({ httpAgent, httpsAgent, proxy: false });
return axiosInstance;
}
+5 -4
View File
@@ -10,7 +10,10 @@
import admin from "firebase-admin"; import admin from "firebase-admin";
import schedule from "node-schedule"; import schedule from "node-schedule";
import { v4 as uuid } from "uuid"; import { v4 as uuid } from "uuid";
import { config, getFirebaseApp } from "../../config"; import { config } from "../../config";
import { logger } from "../../logger";
import { getFirebaseApp } from "../firebase";
import { APIFormat } from "../key-management";
import { import {
getAwsBedrockModelFamily, getAwsBedrockModelFamily,
getGcpModelFamily, getGcpModelFamily,
@@ -22,10 +25,8 @@ import {
MODEL_FAMILIES, MODEL_FAMILIES,
ModelFamily, ModelFamily,
} from "../models"; } from "../models";
import { logger } from "../../logger";
import { User, UserTokenCounts, UserUpdate } from "./schema";
import { APIFormat } from "../key-management";
import { assertNever } from "../utils"; import { assertNever } from "../utils";
import { User, UserTokenCounts, UserUpdate } from "./schema";
const log = logger.child({ module: "users" }); const log = logger.child({ module: "users" });