Re-signs AWS requests on every proxy attempt (each time a request is dequeued) to fix broken queueing
This commit is contained in:
@@ -13,7 +13,8 @@ import {
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
languageFilter,
|
||||
stripHeaders, createOnProxyReqHandler
|
||||
stripHeaders,
|
||||
createOnProxyReqHandler,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
ProxyResHandlerWithBody,
|
||||
@@ -129,8 +130,8 @@ function transformAnthropicResponse(
|
||||
};
|
||||
}
|
||||
|
||||
const anthropicProxy = createQueueMiddleware(
|
||||
createProxyMiddleware({
|
||||
const anthropicProxy = createQueueMiddleware({
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "https://api.anthropic.com",
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
@@ -154,8 +155,8 @@ const anthropicProxy = createQueueMiddleware(
|
||||
// Send OpenAI-compat requests to the real Anthropic endpoint.
|
||||
"^/v1/chat/completions": "/v1/complete",
|
||||
},
|
||||
})
|
||||
);
|
||||
}),
|
||||
});
|
||||
|
||||
const anthropicRouter = Router();
|
||||
anthropicRouter.get("/v1/models", handleModelRequest);
|
||||
|
||||
+8
-12
@@ -3,7 +3,6 @@ import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { v4 } from "uuid";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { keyPool } from "../shared/key-management";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
@@ -120,13 +119,12 @@ function transformAwsResponse(
|
||||
};
|
||||
}
|
||||
|
||||
const awsProxy = createQueueMiddleware(
|
||||
createProxyMiddleware({
|
||||
const awsProxy = createQueueMiddleware({
|
||||
beforeProxy: signAwsRequest,
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "bad-target-will-be-rewritten",
|
||||
router: ({ signedRequest }) => {
|
||||
if (!signedRequest) {
|
||||
throw new Error("AWS requests must go through signAwsRequest first");
|
||||
}
|
||||
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||
return `${signedRequest.protocol}//${signedRequest.hostname}`;
|
||||
},
|
||||
changeOrigin: true,
|
||||
@@ -135,9 +133,7 @@ const awsProxy = createQueueMiddleware(
|
||||
on: {
|
||||
proxyReq: createOnProxyReqHandler({
|
||||
pipeline: [
|
||||
(_, req) => keyPool.throttle(req.key!),
|
||||
applyQuotaLimits,
|
||||
// Credentials are added by signAwsRequest preprocessor
|
||||
languageFilter,
|
||||
blockZoomerOrigins,
|
||||
stripHeaders,
|
||||
@@ -147,8 +143,8 @@ const awsProxy = createQueueMiddleware(
|
||||
proxyRes: createOnProxyResHandler([awsResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
})
|
||||
);
|
||||
}),
|
||||
});
|
||||
|
||||
const awsRouter = Router();
|
||||
awsRouter.get("/v1/models", handleModelRequest);
|
||||
@@ -158,7 +154,7 @@ awsRouter.post(
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "anthropic", outApi: "anthropic", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel, signAwsRequest] }
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
awsProxy
|
||||
);
|
||||
@@ -168,7 +164,7 @@ awsRouter.post(
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "anthropic", service: "aws" },
|
||||
{ afterTransform: [maybeReassignModel, signAwsRequest] }
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
awsProxy
|
||||
);
|
||||
|
||||
@@ -59,7 +59,6 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
|
||||
}
|
||||
}
|
||||
|
||||
keyPool.throttle(assignedKey);
|
||||
req.key = assignedKey;
|
||||
req.log.info(
|
||||
{
|
||||
@@ -117,7 +116,7 @@ export const addKeyForEmbeddingsRequest: ProxyRequestMiddleware = (
|
||||
throw new Error("Embeddings requests must be from OpenAI");
|
||||
}
|
||||
|
||||
req.body = { input: req.body.input, model: "text-embedding-ada-002" }
|
||||
req.body = { input: req.body.input, model: "text-embedding-ada-002" };
|
||||
|
||||
const key = keyPool.get("text-embedding-ada-002") as OpenAIKey;
|
||||
|
||||
|
||||
+5
-4
@@ -25,6 +25,7 @@ import {
|
||||
limitCompletions,
|
||||
stripHeaders,
|
||||
createOnProxyReqHandler,
|
||||
signAwsRequest,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
@@ -163,8 +164,8 @@ function transformTurboInstructResponse(
|
||||
return transformed;
|
||||
}
|
||||
|
||||
const openaiProxy = createQueueMiddleware(
|
||||
createProxyMiddleware({
|
||||
const openaiProxy = createQueueMiddleware({
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "https://api.openai.com",
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
@@ -184,8 +185,8 @@ const openaiProxy = createQueueMiddleware(
|
||||
proxyRes: createOnProxyResHandler([openaiResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
})
|
||||
);
|
||||
}),
|
||||
});
|
||||
|
||||
const openaiEmbeddingsProxy = createProxyMiddleware({
|
||||
target: "https://api.openai.com",
|
||||
|
||||
+4
-4
@@ -143,8 +143,8 @@ function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
|
||||
);
|
||||
}
|
||||
|
||||
const googlePalmProxy = createQueueMiddleware(
|
||||
createProxyMiddleware({
|
||||
const googlePalmProxy = createQueueMiddleware({
|
||||
proxyMiddleware: createProxyMiddleware({
|
||||
target: "https://generativelanguage.googleapis.com",
|
||||
changeOrigin: true,
|
||||
selfHandleResponse: true,
|
||||
@@ -164,8 +164,8 @@ const googlePalmProxy = createQueueMiddleware(
|
||||
proxyRes: createOnProxyResHandler([palmResponseHandler]),
|
||||
error: handleProxyError,
|
||||
},
|
||||
})
|
||||
);
|
||||
}),
|
||||
});
|
||||
|
||||
const palmRouter = Router();
|
||||
palmRouter.get("/v1/models", handleModelRequest);
|
||||
|
||||
+18
-3
@@ -23,6 +23,7 @@ import { buildFakeSse, initializeSseStream } from "../shared/streaming";
|
||||
import { assertNever } from "../shared/utils";
|
||||
import { logger } from "../logger";
|
||||
import { SHARED_IP_ADDRESSES } from "./rate-limit";
|
||||
import { RequestPreprocessor } from "./middleware/request";
|
||||
|
||||
const queue: Request[] = [];
|
||||
const log = logger.child({ module: "request-queue" });
|
||||
@@ -52,7 +53,7 @@ function getIdentifier(req: Request) {
|
||||
const sharesIdentifierWith = (incoming: Request) => (queued: Request) =>
|
||||
getIdentifier(queued) === getIdentifier(incoming);
|
||||
|
||||
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip)
|
||||
const isFromSharedIp = (req: Request) => SHARED_IP_ADDRESSES.has(req.ip);
|
||||
|
||||
export function enqueue(req: Request) {
|
||||
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
|
||||
@@ -325,9 +326,23 @@ export function getQueueLength(partition: ModelFamily | "all" = "all") {
|
||||
return modelQueue.length;
|
||||
}
|
||||
|
||||
export function createQueueMiddleware(proxyMiddleware: Handler): Handler {
|
||||
export function createQueueMiddleware({
|
||||
beforeProxy,
|
||||
proxyMiddleware,
|
||||
}: {
|
||||
beforeProxy?: RequestPreprocessor;
|
||||
proxyMiddleware: Handler;
|
||||
}): Handler {
|
||||
return (req, res, next) => {
|
||||
req.proceed = () => {
|
||||
req.proceed = async () => {
|
||||
if (beforeProxy) {
|
||||
// Hack to let us run asynchronous middleware before the
|
||||
// http-proxy-middleware handler. This is used to sign AWS requests
|
||||
// before they are proxied, as the signing is asynchronous.
|
||||
// Unlike RequestPreprocessors, this runs every time the request is
|
||||
// dequeued, not just the first time.
|
||||
await beforeProxy(req);
|
||||
}
|
||||
proxyMiddleware(req, res, next);
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user