Azure OpenAI support (khanon/oai-reverse-proxy!48)
+8
-4
@@ -11,7 +11,7 @@ import {
   createPreprocessorMiddleware,
   stripHeaders,
   signAwsRequest,
-  finalizeAwsRequest,
+  finalizeSignedRequest,
   createOnProxyReqHandler,
   blockZoomerOrigins,
 } from "./middleware/request";
@@ -30,7 +30,11 @@ const getModelsResponse = () => {
 
   if (!config.awsCredentials) return { object: "list", data: [] };
 
-  const variants = ["anthropic.claude-v1", "anthropic.claude-v2"];
+  const variants = [
+    "anthropic.claude-v1",
+    "anthropic.claude-v2",
+    "anthropic.claude-v2:1",
+  ];
 
   const models = variants.map((id) => ({
     id,
@@ -134,7 +138,7 @@ const awsProxy = createQueueMiddleware({
       applyQuotaLimits,
       blockZoomerOrigins,
       stripHeaders,
-      finalizeAwsRequest,
+      finalizeSignedRequest,
     ],
   }),
   proxyRes: createOnProxyResHandler([awsResponseHandler]),
@@ -183,7 +187,7 @@ function maybeReassignModel(req: Request) {
     req.body.model = "anthropic.claude-v1";
   } else {
     // User's client requested v2 or possibly some OpenAI model, default to v2
-    req.body.model = "anthropic.claude-v2";
+    req.body.model = "anthropic.claude-v2:1";
   }
   // TODO: Handle claude-instant
 }
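
Editor's note: for illustration, a hypothetical request showing the new default above. The "gpt-4" value and the call shape are assumptions for this sketch, not part of the diff:

// A client that asks for an OpenAI model now lands on the updated default:
req.body.model = "gpt-4";      // as sent by the client
maybeReassignModel(req);       // no Claude v1 match, so the else branch runs
console.log(req.body.model);   // "anthropic.claude-v2:1"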
@@ -0,0 +1,140 @@
+import { RequestHandler, Router } from "express";
+import { createProxyMiddleware } from "http-proxy-middleware";
+import { config } from "../config";
+import { keyPool } from "../shared/key-management";
+import {
+  ModelFamily,
+  AzureOpenAIModelFamily,
+  getAzureOpenAIModelFamily,
+} from "../shared/models";
+import { logger } from "../logger";
+import { KNOWN_OPENAI_MODELS } from "./openai";
+import { createQueueMiddleware } from "./queue";
+import { ipLimiter } from "./rate-limit";
+import { handleProxyError } from "./middleware/common";
+import {
+  applyQuotaLimits,
+  blockZoomerOrigins,
+  createOnProxyReqHandler,
+  createPreprocessorMiddleware,
+  finalizeSignedRequest,
+  limitCompletions,
+  stripHeaders,
+} from "./middleware/request";
+import {
+  createOnProxyResHandler,
+  ProxyResHandlerWithBody,
+} from "./middleware/response";
+import { addAzureKey } from "./middleware/request/add-azure-key";
+
+let modelsCache: any = null;
+let modelsCacheTime = 0;
+
+function getModelsResponse() {
+  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
+    return modelsCache;
+  }
+
+  let available = new Set<AzureOpenAIModelFamily>();
+  for (const key of keyPool.list()) {
+    if (key.isDisabled || key.service !== "azure") continue;
+    key.modelFamilies.forEach((family) =>
+      available.add(family as AzureOpenAIModelFamily)
+    );
+  }
+  const allowed = new Set<ModelFamily>(config.allowedModelFamilies);
+  available = new Set([...available].filter((x) => allowed.has(x)));
+
+  const models = KNOWN_OPENAI_MODELS.map((id) => ({
+    id,
+    object: "model",
+    created: new Date().getTime(),
+    owned_by: "azure",
+    permission: [
+      {
+        id: "modelperm-" + id,
+        object: "model_permission",
+        created: new Date().getTime(),
+        organization: "*",
+        group: null,
+        is_blocking: false,
+      },
+    ],
+    root: id,
+    parent: null,
+  })).filter((model) => available.has(getAzureOpenAIModelFamily(model.id)));
+
+  modelsCache = { object: "list", data: models };
+  modelsCacheTime = new Date().getTime();
+
+  return modelsCache;
+}
+
+const handleModelRequest: RequestHandler = (_req, res) => {
+  res.status(200).json(getModelsResponse());
+};
+
+const azureOpenaiResponseHandler: ProxyResHandlerWithBody = async (
+  _proxyRes,
+  req,
+  res,
+  body
+) => {
+  if (typeof body !== "object") {
+    throw new Error("Expected body to be an object");
+  }
+
+  if (config.promptLogging) {
+    const host = req.get("host");
+    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
+  }
+
+  if (req.tokenizerInfo) {
+    body.proxy_tokenizer = req.tokenizerInfo;
+  }
+
+  res.status(200).json(body);
+};
+
+const azureOpenAIProxy = createQueueMiddleware({
+  beforeProxy: addAzureKey,
+  proxyMiddleware: createProxyMiddleware({
+    target: "will be set by router",
+    router: (req) => {
+      if (!req.signedRequest) throw new Error("signedRequest not set");
+      const { hostname, path } = req.signedRequest;
+      return `https://${hostname}${path}`;
+    },
+    changeOrigin: true,
+    selfHandleResponse: true,
+    logger,
+    on: {
+      proxyReq: createOnProxyReqHandler({
+        pipeline: [
+          applyQuotaLimits,
+          limitCompletions,
+          blockZoomerOrigins,
+          stripHeaders,
+          finalizeSignedRequest,
+        ],
+      }),
+      proxyRes: createOnProxyResHandler([azureOpenaiResponseHandler]),
+      error: handleProxyError,
+    },
+  }),
+});
+
+const azureOpenAIRouter = Router();
+azureOpenAIRouter.get("/v1/models", handleModelRequest);
+azureOpenAIRouter.post(
+  "/v1/chat/completions",
+  ipLimiter,
+  createPreprocessorMiddleware({
+    inApi: "openai",
+    outApi: "openai",
+    service: "azure",
+  }),
+  azureOpenAIProxy
+);
+
+export const azure = azureOpenAIRouter;
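
Editor's note: a minimal sketch of a client call to the new router. The public base URL, the /proxy mount prefix, and the proxy password are assumptions about a deployment, not part of this diff:

// Hypothetical client call to the new Azure OpenAI route:
const resp = await fetch(
  "https://my-proxy.example/proxy/azure/openai/v1/chat/completions", // base URL and /proxy prefix assumed
  {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: "Bearer my-proxy-password", // hypothetical proxy credential
    },
    body: JSON.stringify({
      model: "gpt-4", // addAzureKey prefixes this to "azure-gpt-4" for key selection
      messages: [{ role: "user", content: "Hello!" }],
    }),
  }
);
console.log(await resp.json());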
@@ -59,7 +59,7 @@ export function writeErrorResponse(
     res.write(`data: [DONE]\n\n`);
     res.end();
   } else {
-    if (req.tokenizerInfo && errorPayload.error) {
+    if (req.tokenizerInfo && typeof errorPayload.error === "object") {
       errorPayload.error.proxy_tokenizer = req.tokenizerInfo;
     }
     res.status(statusCode).json(errorPayload);
@@ -0,0 +1,50 @@
+import { AzureOpenAIKey, keyPool } from "../../../shared/key-management";
+import { RequestPreprocessor } from ".";
+
+export const addAzureKey: RequestPreprocessor = (req) => {
+  const apisValid = req.inboundApi === "openai" && req.outboundApi === "openai";
+  const serviceValid = req.service === "azure";
+  if (!apisValid || !serviceValid) {
+    throw new Error("addAzureKey called on invalid request");
+  }
+
+  if (!req.body?.model) {
+    throw new Error("You must specify a model with your request.");
+  }
+
+  const model = req.body.model.startsWith("azure-")
+    ? req.body.model
+    : `azure-${req.body.model}`;
+
+  req.key = keyPool.get(model);
+  req.body.model = model;
+
+  req.log.info(
+    { key: req.key.hash, model },
+    "Assigned Azure OpenAI key to request"
+  );
+
+  const cred = req.key as AzureOpenAIKey;
+  const { resourceName, deploymentId, apiKey } = getCredentialsFromKey(cred);
+
+  req.signedRequest = {
+    method: "POST",
+    protocol: "https:",
+    hostname: `${resourceName}.openai.azure.com`,
+    path: `/openai/deployments/${deploymentId}/chat/completions?api-version=2023-09-01-preview`,
+    headers: {
+      ["host"]: `${resourceName}.openai.azure.com`,
+      ["content-type"]: "application/json",
+      ["api-key"]: apiKey,
+    },
+    body: JSON.stringify(req.body),
+  };
+};
+
+function getCredentialsFromKey(key: AzureOpenAIKey) {
+  const [resourceName, deploymentId, apiKey] = key.key.split(":");
+  if (!resourceName || !deploymentId || !apiKey) {
+    throw new Error("Assigned Azure OpenAI key is not in the correct format.");
+  }
+  return { resourceName, deploymentId, apiKey };
+}
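
Editor's note: the colon-delimited format getCredentialsFromKey expects, with placeholder values (the resource name, deployment, and key below are invented):

// Hypothetical key value in resourceName:deploymentId:apiKey form:
//   "contoso-eastus:gpt-4-deployment:0123456789abcdef0123456789abcdef"
// With these values, addAzureKey would target:
//   https://contoso-eastus.openai.azure.com/openai/deployments/gpt-4-deployment/chat/completions?api-version=2023-09-01-preview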
@@ -80,6 +80,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
         `?key=${assignedKey.key}`
       );
       break;
+    case "azure":
+      const azureKey = assignedKey.key;
+      proxyReq.setHeader("api-key", azureKey);
+      break;
     case "aws":
       throw new Error(
         "add-key should not be used for AWS security credentials. Use sign-aws-request instead."
+4
-4
@@ -1,11 +1,11 @@
 import type { ProxyRequestMiddleware } from ".";
 
 /**
- * For AWS requests, the body is signed earlier in the request pipeline, before
- * the proxy middleware. This function just assigns the path and headers to the
- * proxy request.
+ * For AWS/Azure requests, the body is signed earlier in the request pipeline,
+ * before the proxy middleware. This function just assigns the path and headers
+ * to the proxy request.
 */
-export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
+export const finalizeSignedRequest: ProxyRequestMiddleware = (proxyReq, req) => {
   if (!req.signedRequest) {
     throw new Error("Expected req.signedRequest to be set");
   }
@@ -22,7 +22,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
 export { addAnthropicPreamble } from "./add-anthropic-preamble";
 export { blockZoomerOrigins } from "./block-zoomer-origins";
 export { finalizeBody } from "./finalize-body";
-export { finalizeAwsRequest } from "./finalize-aws-request";
+export { finalizeSignedRequest } from "./finalize-signed-request";
 export { limitCompletions } from "./limit-completions";
 export { stripHeaders } from "./strip-headers";
@@ -289,15 +289,17 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
   switch (service) {
     case "openai":
     case "google-palm":
-      if (errorPayload.error?.code === "content_policy_violation") {
-        errorPayload.proxy_note = `Request was filtered by OpenAI's content moderation system. Try another prompt.`;
+    case "azure":
+      const filteredCodes = ["content_policy_violation", "content_filter"];
+      if (filteredCodes.includes(errorPayload.error?.code)) {
+        errorPayload.proxy_note = `Request was filtered by the upstream API's content moderation system. Modify your prompt and try again.`;
         refundLastAttempt(req);
       } else if (errorPayload.error?.code === "billing_hard_limit_reached") {
         // For some reason, some models return this 400 error instead of the
         // same 429 billing error that other models return.
         handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
       } else {
-        errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
+        errorPayload.proxy_note = `The upstream API rejected the request. Your prompt may be too long for ${req.body?.model}.`;
       }
       break;
     case "anthropic":
@@ -342,7 +344,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
       handleAwsRateLimitError(req, errorPayload);
       break;
     case "google-palm":
-      throw new Error("Rate limit handling not implemented for PaLM");
+    case "azure":
+      errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
+      break;
     default:
       assertNever(service);
   }
@@ -369,6 +373,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     case "aws":
       errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
       break;
+    case "azure":
+      errorPayload.proxy_note = `The assigned Azure deployment does not support the requested model.`;
+      break;
     default:
       assertNever(service);
   }
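
Editor's note: an upstream error body of the shape that would take the new content_filter branch above. This is an illustrative sketch; only the "content_filter" code comes from the diff, the other fields are assumptions:

// Illustrative upstream error payload (field values are assumptions):
{
  "error": {
    "code": "content_filter",
    "message": "The response was filtered due to the prompt triggering content management policy.",
    "type": "invalid_request_error",
    "param": "prompt"
  }
}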
@@ -28,6 +28,7 @@ type SSEMessageTransformerOptions = TransformOptions & {
 export class SSEMessageTransformer extends Transform {
   private lastPosition: number;
   private msgCount: number;
+  private readonly inputFormat: APIFormat;
   private readonly transformFn: StreamingCompletionTransformer;
   private readonly log;
   private readonly fallbackId: string;
@@ -42,6 +43,7 @@ export class SSEMessageTransformer extends Transform {
       options.inputFormat,
       options.inputApiVersion
     );
+    this.inputFormat = options.inputFormat;
    this.fallbackId = options.requestId;
    this.fallbackModel = options.requestedModel;
    this.log.debug(
@@ -67,6 +69,17 @@ export class SSEMessageTransformer extends Transform {
     });
     this.lastPosition = newPosition;
 
+    // Special case for Azure OpenAI, which is 99% the same as OpenAI but
+    // sometimes emits an extra event at the beginning of the stream with the
+    // content moderation system's response to the prompt. A lot of frontends
+    // don't expect this and neither does our event aggregator so we drop it.
+    if (this.inputFormat === "openai" && this.msgCount <= 1) {
+      if (originalMessage.includes("prompt_filter_results")) {
+        this.log.debug("Dropping Azure OpenAI content moderation SSE event");
+        return callback();
+      }
+    }
+
     this.emit("originalMessage", originalMessage);
 
     // Some events may not be transformed, e.g. ping events
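
Editor's note: to make the check above concrete, the dropped event looks roughly like this. The exact field set is an approximation; the "prompt_filter_results" key is the only part the code matches on:

// Illustrative first SSE event from Azure (shape approximated):
data: {"id":"","object":"","created":0,"model":"","choices":[],"prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"}}}]}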
+2
-2
@@ -24,7 +24,7 @@ import {
 import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/response";
 
 // https://platform.openai.com/docs/models/overview
-const KNOWN_MODELS = [
+export const KNOWN_OPENAI_MODELS = [
   "gpt-4-1106-preview",
   "gpt-4-vision-preview",
   "gpt-4",
@@ -46,7 +46,7 @@ const KNOWN_MODELS = [
 let modelsCache: any = null;
 let modelsCacheTime = 0;
 
-export function generateModelList(models = KNOWN_MODELS) {
+export function generateModelList(models = KNOWN_OPENAI_MODELS) {
   let available = new Set<OpenAIModelFamily>();
   for (const key of keyPool.list()) {
     if (key.isDisabled || key.service !== "openai") continue;
+12
-7
@@ -26,6 +26,7 @@ import { assertNever } from "../shared/utils";
 import { logger } from "../logger";
 import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
 import { RequestPreprocessor } from "./middleware/request";
+import { handleProxyError } from "./middleware/common";
 
 const queue: Request[] = [];
 const log = logger.child({ module: "request-queue" });
@@ -34,7 +35,7 @@ const log = logger.child({ module: "request-queue" });
 const AGNAI_CONCURRENCY_LIMIT = 5;
 /** Maximum number of queue slots for individual users. */
 const USER_CONCURRENCY_LIMIT = 1;
-const MIN_HEARTBEAT_SIZE = 512;
+const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
 const MAX_HEARTBEAT_SIZE =
   1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
 const HEARTBEAT_INTERVAL =
@@ -358,12 +359,16 @@ export function createQueueMiddleware({
   return (req, res, next) => {
     req.proceed = async () => {
       if (beforeProxy) {
-        // Hack to let us run asynchronous middleware before the
-        // http-proxy-middleware handler. This is used to sign AWS requests
-        // before they are proxied, as the signing is asynchronous.
-        // Unlike RequestPreprocessors, this runs every time the request is
-        // dequeued, not just the first time.
-        await beforeProxy(req);
+        try {
+          // Hack to let us run asynchronous middleware before the
+          // http-proxy-middleware handler. This is used to sign AWS requests
+          // before they are proxied, as the signing is asynchronous.
+          // Unlike RequestPreprocessors, this runs every time the request is
+          // dequeued, not just the first time.
+          await beforeProxy(req);
+        } catch (err) {
+          return handleProxyError(err, req, res);
+        }
       }
       proxyMiddleware(req, res, next);
     };
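
Editor's note: a side effect of the heartbeat change above is that both sizes are now tunable at deploy time. A hypothetical .env override; the variable names come from the diff, the values are invented:

# Hypothetical overrides; defaults are 512 bytes and 1024 KB as coded above.
MIN_HEARTBEAT_SIZE_B=1024
MAX_HEARTBEAT_SIZE_KB=2048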
@@ -6,6 +6,7 @@ import { openaiImage } from "./openai-image";
 import { anthropic } from "./anthropic";
 import { googlePalm } from "./palm";
 import { aws } from "./aws";
+import { azure } from "./azure";
 
 const proxyRouter = express.Router();
 proxyRouter.use((req, _res, next) => {
@@ -32,6 +33,7 @@ proxyRouter.use("/openai-image", addV1, openaiImage);
 proxyRouter.use("/anthropic", addV1, anthropic);
 proxyRouter.use("/google-palm", addV1, googlePalm);
 proxyRouter.use("/aws/claude", addV1, aws);
+proxyRouter.use("/azure/openai", addV1, azure);
 // Redirect browser requests to the homepage.
 proxyRouter.get("*", (req, res, next) => {
   const isBrowser = req.headers["user-agent"]?.includes("Mozilla");
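
Editor's note: finally, a quick smoke test against the newly mounted route. The base URL and /proxy prefix are assumptions about the deployment, not part of this diff:

// Hypothetical check that the Azure model list is served:
const models = await fetch("https://my-proxy.example/proxy/azure/openai/v1/models");
console.log((await models.json()).data.map((m: { id: string }) => m.id));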