khanon
2023-12-04 04:21:18 +00:00
parent cd1b9d0e0c
commit fbdea30264
31 changed files with 1237 additions and 216 deletions
+8 -4
@@ -11,7 +11,7 @@ import {
  createPreprocessorMiddleware,
  stripHeaders,
  signAwsRequest,
-  finalizeAwsRequest,
+  finalizeSignedRequest,
  createOnProxyReqHandler,
  blockZoomerOrigins,
} from "./middleware/request";
@@ -30,7 +30,11 @@ const getModelsResponse = () => {
  if (!config.awsCredentials) return { object: "list", data: [] };
-  const variants = ["anthropic.claude-v1", "anthropic.claude-v2"];
+  const variants = [
+    "anthropic.claude-v1",
+    "anthropic.claude-v2",
+    "anthropic.claude-v2:1",
+  ];
  const models = variants.map((id) => ({
    id,
@@ -134,7 +138,7 @@ const awsProxy = createQueueMiddleware({
        applyQuotaLimits,
        blockZoomerOrigins,
        stripHeaders,
-        finalizeAwsRequest,
+        finalizeSignedRequest,
      ],
    }),
    proxyRes: createOnProxyResHandler([awsResponseHandler]),
@@ -183,7 +187,7 @@ function maybeReassignModel(req: Request) {
req.body.model = "anthropic.claude-v1";
} else {
// User's client requested v2 or possibly some OpenAI model, default to v2
req.body.model = "anthropic.claude-v2";
req.body.model = "anthropic.claude-v2:1";
}
// TODO: Handle claude-instant
}
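Net effect of the aws.ts changes: requests that don't resolve to Claude v1 now default to the claude-v2:1 Bedrock snapshot. An illustrative mapping, assuming the elided branches of maybeReassignModel leave full Bedrock model IDs untouched (the matching logic for these inputs is not fully shown in the hunk):

// Illustrative inputs and outputs for maybeReassignModel after this commit:
//   "claude-v1"           -> "anthropic.claude-v1"
//   "claude-2"            -> "anthropic.claude-v2:1"  (previously ...claude-v2)
//   "gpt-4"               -> "anthropic.claude-v2:1"  (previously ...claude-v2)
//   "anthropic.claude-v2" -> unchanged, assuming Bedrock IDs pass through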
+140
@@ -0,0 +1,140 @@
import { RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { config } from "../config";
import { keyPool } from "../shared/key-management";
import {
  ModelFamily,
  AzureOpenAIModelFamily,
  getAzureOpenAIModelFamily,
} from "../shared/models";
import { logger } from "../logger";
import { KNOWN_OPENAI_MODELS } from "./openai";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
  applyQuotaLimits,
  blockZoomerOrigins,
  createOnProxyReqHandler,
  createPreprocessorMiddleware,
  finalizeSignedRequest,
  limitCompletions,
  stripHeaders,
} from "./middleware/request";
import {
  createOnProxyResHandler,
  ProxyResHandlerWithBody,
} from "./middleware/response";
import { addAzureKey } from "./middleware/request/add-azure-key";

let modelsCache: any = null;
let modelsCacheTime = 0;

function getModelsResponse() {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  let available = new Set<AzureOpenAIModelFamily>();
  for (const key of keyPool.list()) {
    if (key.isDisabled || key.service !== "azure") continue;
    key.modelFamilies.forEach((family) =>
      available.add(family as AzureOpenAIModelFamily)
    );
  }
  const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
  available = new Set([...available].filter((x) => allowed.has(x)));

  const models = KNOWN_OPENAI_MODELS.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "azure",
    permission: [
      {
        id: "modelperm-" + id,
        object: "model_permission",
        created: new Date().getTime(),
        organization: "*",
        group: null,
        is_blocking: false,
      },
    ],
    root: id,
    parent: null,
  })).filter((model) => available.has(getAzureOpenAIModelFamily(model.id)));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();

  return modelsCache;
}

const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};

const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(body);
};

const azureOpenAIProxy = createQueueMiddleware({
  beforeProxy: addAzureKey,
  proxyMiddleware: createProxyMiddleware({
    target: "will be set by router",
    router: (req) => {
      if (!req.signedRequest) throw new Error("signedRequest not set");
      const { hostname, path } = req.signedRequest;
      return `https://${hostname}${path}`;
    },
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [
          applyQuotaLimits,
          limitCompletions,
          blockZoomerOrigins,
          stripHeaders,
          finalizeSignedRequest,
        ],
      }),
      proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
      error: handleProxyError,
    },
  }),
});

const azureOpenAIRouter = Router();
azureOpenAIRouter.get("/v1/models", handleModelRequest);
azureOpenAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware({
    inApi: "openai",
    outApi: "openai",
    service: "azure",
  }),
  azureOpenAIProxy
);

export const azure = azureOpenAIRouter;
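A minimal sketch of exercising the new router, assuming a local deployment; the host, port, mount point, and auth scheme are not part of this diff:

// Hypothetical client call; addAzureKey prefixes the model with "azure-"
// internally, so clients keep using plain OpenAI model names.
const res = await fetch(
  "http://localhost:7860/proxy/azure/openai/v1/chat/completions",
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer <proxy-token>", // only if user auth is enabled
    },
    body: JSON.stringify({
      model: "gpt-4",
      messages: [{ role: "user", content: "Hello from the Azure route" }],
    }),
  }
);
console.log(await res.json());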
+1 -1
@@ -59,7 +59,7 @@ export function writeErrorResponse(
    res.write(`data: [DONE]\n\n`);
    res.end();
  } else {
-    if (req.tokenizerInfo && errorPayload.error) {
+    if (req.tokenizerInfo && typeof errorPayload.error === "object") {
      errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
    }
    res.status(statusCode).json(errorPayload);
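The stricter typeof check matters because upstream error payloads do not always carry an object in `error`; decorating a string primitive with `proxy_tokenizer` would misbehave. Illustrative shapes, not taken from this diff:

// Object-shaped error: safe to attach proxy_tokenizer to.
const objectError = { error: { code: "content_filter", message: "filtered" } };
// String-shaped error: truthy, so the old check passed and the property
// write was lost; the new typeof check skips it.
const stringError = { error: "invalid_request_error" };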
+50
@@ -0,0 +1,50 @@
import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
import { RequestPreprocessor } from ".";

export const addAzureKey: RequestPreprocessor = (req) => {
  const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
  const serviceValid = req.service === "azure";
  if (!apisValid || !serviceValid) {
    throw new Error("addAzureKey called on invalid request");
  }

  if (!req.body?.model) {
    throw new Error("You must specify a model with your request.");
  }

  const model = req.body.model.startsWith("azure-")
    ? req.body.model
    : `azure-${req.body.model}`;

  req.key = keyPool.get(model);
  req.body.model = model;

  req.log.info(
    { key: req.key.hash, model },
    "Assigned Azure OpenAI key to request"
  );

  const cred = req.key as AzureOpenAIKey;
  const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);

  req.signedRequest = {
    method: "POST",
    protocol: "https:",
    hostname: `${resourceName}.openai.azure.com`,
    path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
    headers: {
      ["host"]: `${resourceName}.openai.azure.com`,
      ["content-type"]: "application/json",
      ["api-key"]: apiKey,
    },
    body: JSON.stringify(req.body),
  };
};

function getCredentialsFromKey(key: AzureOpenAIKey) {
  const [resourceName, deploymentId, apiKey] = key.key.split(":");
  if (!resourceName || !deploymentId || !apiKey) {
    throw new Error("Assigned Azure OpenAI key is not in the correct format.");
  }
  return { resourceName, deploymentId, apiKey };
}
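For reference, getCredentialsFromKey implies a colon-delimited key format. A placeholder example (all values made up):

// "resourceName:deploymentId:apiKey" — how a configured Azure key is parsed:
const key = "my-resource:gpt4-deployment:0123456789abcdef0123456789abcdef";
const [resourceName, deploymentId, apiKey] = key.split(":");
// resourceName -> "my-resource"     (becomes my-resource.openai.azure.com)
// deploymentId -> "gpt4-deployment" (the deployment path segment)
// apiKey       -> "0123456789..."   (sent as the api-key header)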
+4
@@ -80,6 +80,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
        `?key=${assignedKey.key}`
      );
      break;
+    case "azure":
+      const azureKey = assignedKey.key;
+      proxyReq.setHeader("api-key", azureKey);
+      break;
    case "aws":
      throw new Error(
        "add-key should not be used for AWS security credentials. Use sign-aws-request instead."
+4 -4
@@ -1,11 +1,11 @@
import type { ProxyRequestMiddleware } from ".";

/**
- * For AWS requests, the body is signed earlier in the request pipeline, before
- * the proxy middleware. This function just assigns the path and headers to the
- * proxy request.
+ * For AWS/Azure requests, the body is signed earlier in the request pipeline,
+ * before the proxy middleware. This function just assigns the path and headers
+ * to the proxy request.
 */
-export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
+export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
  if (!req.signedRequest) {
    throw new Error("Expected req.signedRequest to be set");
  }
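The hunk ends at the guard; the rest of the renamed function is unchanged by this commit and not shown. A sketch of what it presumably does, given how req.signedRequest is assembled in add-azure-key above:

// Sketch only — the real function body is elided from this hunk.
proxyReq.path = req.signedRequest.path;
Object.entries(req.signedRequest.headers).forEach(([name, value]) => {
  proxyReq.setHeader(name, value as string);
});
proxyReq.write(req.signedRequest.body);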
+1 -1
@@ -22,7 +22,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
export { addAnthropicPreamble } from "./add-anthropic-preamble";
export { blockZoomerOrigins } from "./block-zoomer-origins";
export { finalizeBody } from "./finalize-body";
-export { finalizeAwsRequest } from "./finalize-aws-request";
+export { finalizeSignedRequest } from "./finalize-signed-request";
export { limitCompletions } from "./limit-completions";
export { stripHeaders } from "./strip-headers";
+11 -4
@@ -289,15 +289,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  switch (service) {
    case "openai":
    case "google-palm":
-      if (errorPayload.error?.code === "content_policy_violation") {
-        errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
+    case "azure":
+      const filteredCodes = ["content_policy_violation", "content_filter"];
+      if (filteredCodes.includes(errorPayload.error?.code)) {
+        errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`;
        refundLastAttempt(req);
      } else if (errorPayload.error?.code === "billing_hard_limit_reached") {
        // For some reason, some models return this 400 error instead of the
        // same 429 billing error that other models return.
        handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
      } else {
-        errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
+        errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
      }
      break;
    case "anthropic":
@@ -342,7 +344,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
      handleAwsRateLimitError(req, errorPayload);
      break;
    case "google-palm":
-      throw new Error("Rate limit handling not implemented for PaLM");
+    case "azure":
+      errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
+      break;
    default:
      assertNever(service);
  }
@@ -369,6 +373,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
case "azure":
errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
break;
default:
assertNever(service);
}
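The two codes in filteredCodes correspond to similar 400 bodies from the two providers; approximate shapes, abbreviated and possibly varying by API version:

// Approximate upstream 400 payloads matched by the filteredCodes branch:
const openaiModerationError = {
  error: { type: "invalid_request_error", code: "content_policy_violation", message: "..." },
};
const azureModerationError = {
  error: { code: "content_filter", status: 400, message: "The response was filtered ..." },
};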
+13
@@ -28,6 +28,7 @@ type SSEMessageTransformerOptions = TransformOptions & {
export class SSEMessageTransformer extends Transform {
  private lastPosition: number;
  private msgCount: number;
+  private readonly inputFormat: APIFormat;
  private readonly transformFn: StreamingCompletionTransformer;
  private readonly log;
  private readonly fallbackId: string;
@@ -42,6 +43,7 @@ export class SSEMessageTransformer extends Transform {
      options.inputFormat,
      options.inputApiVersion
    );
+    this.inputFormat = options.inputFormat;
    this.fallbackId = options.requestId;
    this.fallbackModel = options.requestedModel;
    this.log.debug(
@@ -67,6 +69,17 @@ export class SSEMessageTransformer extends Transform {
      });
      this.lastPosition = newPosition;

+      // Special case for Azure OpenAI, which is 99% the same as OpenAI but
+      // sometimes emits an extra event at the beginning of the stream with
+      // the content moderation system's response to the prompt. A lot of
+      // frontends don't expect this, and neither does our event aggregator,
+      // so we drop it.
+      if (this.inputFormat === "openai" && this.msgCount <= 1) {
+        if (originalMessage.includes("prompt_filter_results")) {
+          this.log.debug("Dropping Azure OpenAI content moderation SSE event");
+          return callback();
+        }
+      }

      this.emit("originalMessage", originalMessage);

      // Some events may not be transformed, e.g. ping events
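For context, the dropped first event looks roughly like this (abbreviated; the exact shape depends on the Azure api-version):

// data: {"id":"","object":"","created":0,"model":"","choices":[],
//   "prompt_filter_results":[{"prompt_index":0,
//     "content_filter_results":{"hate":{"filtered":false,"severity":"safe"}}}]}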
+2 -2
@@ -24,7 +24,7 @@ import {
import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";

// https://platform.openai.com/docs/models/overview
-const KNOWN_MODELS = [
+export const KNOWN_OPENAI_MODELS = [
  "gpt-4-1106-preview",
  "gpt-4-vision-preview",
  "gpt-4",
@@ -46,7 +46,7 @@ const KNOWN_MODELS = [
let modelsCache: any = null;
let modelsCacheTime = 0;

-export function generateModelList(models = KNOWN_MODELS) {
+export function generateModelList(models = KNOWN_OPENAI_MODELS) {
  let available = new Set<OpenAIModelFamily>();
  for (const key of keyPool.list()) {
    if (key.isDisabled || key.service !== "openai") continue;
+12 -7
@@ -26,6 +26,7 @@ import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
+import { handleProxyError } from "./middleware/common";

const queue: Request[] = [];
const log = logger.child({ module: "request-queue" });
@@ -34,7 +35,7 @@ const log = logger.child({ module: "request-queue" });
const AGNAI_CONCURRENCY_LIMIT = 5;
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1;
-const MIN_HEARTBEAT_SIZE = 512;
+const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
const MAX_HEARTBEAT_SIZE =
  1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
const HEARTBEAT_INTERVAL =
@@ -358,12 +359,16 @@ export function createQueueMiddleware({
  return (req, res, next) => {
    req.proceed = async () => {
      if (beforeProxy) {
-        // Hack to let us run asynchronous middleware before the
-        // http-proxy-middleware handler. This is used to sign AWS requests
-        // before they are proxied, as the signing is asynchronous.
-        // Unlike RequestPreprocessors, this runs every time the request is
-        // dequeued, not just the first time.
-        await beforeProxy(req);
+        try {
+          // Hack to let us run asynchronous middleware before the
+          // http-proxy-middleware handler. This is used to sign AWS requests
+          // before they are proxied, as the signing is asynchronous.
+          // Unlike RequestPreprocessors, this runs every time the request is
+          // dequeued, not just the first time.
+          await beforeProxy(req);
+        } catch (err) {
+          return handleProxyError(err, req, res);
+        }
      }
      proxyMiddleware(req, res, next);
    };
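Both heartbeat bounds are now environment-tunable. Illustrative overrides (the variable names are from the diff; the values are examples):

# optional .env overrides — defaults are 512 bytes and 1024 KB
MIN_HEARTBEAT_SIZE_B=1024
MAX_HEARTBEAT_SIZE_KB=2048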
+2
@@ -6,6 +6,7 @@ import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { aws } from "./aws";
+import { azure } from "./azure";

const proxyRouter = express.Router();
proxyRouter.use((req, _res, next) => {
@@ -32,6 +33,7 @@ proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
proxyRouter.get("*", (req, res, next) => {
const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
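Assuming proxyRouter is mounted at /proxy (the mount point is outside this diff) and that addV1 inserts the /v1 path segment when a client omits it, the new service resolves as:

//   GET  /proxy/azure/openai/v1/models           -> handleModelRequest
//   POST /proxy/azure/openai/v1/chat/completions -> azureOpenAIProxy
//   POST /proxy/azure/openai/chat/completions    -> rewritten by addV1 to /v1/...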