properly enforce allowedModelFamilies; refactor HPM proxyReq handlers

This commit is contained in:
nai-degen
2023-12-05 21:41:04 -06:00
parent 12276a1f59
commit 94d4efe9bb
34 changed files with 204 additions and 262 deletions
+5 -39
View File
@@ -14,17 +14,8 @@
import crypto from "crypto";
import type { Handler, Request } from "express";
import { keyPool } from "../shared/key-management";
import {
getAwsBedrockModelFamily,
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGooglePalmModelFamily,
getOpenAIModelFamily,
MODEL_FAMILIES,
ModelFamily,
} from "../shared/models";
import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
import { assertNever } from "../shared/utils";
import { logger } from "../logger";
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
import { RequestPreprocessor } from "./middleware/request";
@@ -132,34 +123,9 @@ export function enqueue(req: Request) {
}
}
/**
 * Resolves the queue partition (model family) that a request belongs to.
 *
 * There is only one request queue, but it is partitioned by model family.
 * Families are typically split along cost/rate-limit boundaries, so each one
 * is treated as if it were its own queue.
 *
 * @param req The incoming request; `req.body.model`, `req.service` and
 *   `req.outboundApi` are consulted.
 * @returns The model family used as the partition key for this request.
 */
function getPartitionForRequest(req: Request): ModelFamily {
  // Requests that do not name a model fall back to "gpt-3.5-turbo".
  const requestedModel = req.body.model ?? "gpt-3.5-turbo";

  // AWS and Azure are special-cased by service rather than by outbound API
  // because they serve multiple models from different vendors, even if only
  // one is currently supported.
  switch (req.service) {
    case "aws":
      return getAwsBedrockModelFamily(requestedModel);
    case "azure":
      return getAzureOpenAIModelFamily(requestedModel);
  }

  // Every other service is resolved from the outbound API format.
  const api = req.outboundApi;
  if (api === "anthropic") {
    return getClaudeModelFamily(requestedModel);
  }
  if (api === "openai" || api === "openai-text" || api === "openai-image") {
    return getOpenAIModelFamily(requestedModel);
  }
  if (api === "google-palm") {
    return getGooglePalmModelFamily(requestedModel);
  }

  // All known outbound APIs are handled above; anything else is a bug.
  assertNever(api);
}
function getQueueForPartition(partition: ModelFamily): Request[] {
return queue
.filter((req) => getPartitionForRequest(req) === partition)
.filter((req) => getModelFamilyForRequest(req) === partition)
.sort((a, b) => {
// Certain requests are exempted from IP-based rate limiting because they
// come from a shared IP address. To prevent these requests from starving
@@ -222,7 +188,7 @@ function processQueue() {
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {
const modelFamily = getPartitionForRequest(req!);
const modelFamily = getModelFamilyForRequest(req!);
req.log.info({
retries: req.retryCount,
partition: modelFamily,
@@ -279,7 +245,7 @@ let waitTimes: {
/** Adds a successful request to the list of wait times. */
export function trackWaitTime(req: Request) {
waitTimes.push({
partition: getPartitionForRequest(req),
partition: getModelFamilyForRequest(req),
start: req.startTime!,
end: req.queueOutTime ?? Date.now(),
isDeprioritized: isFromSharedIp(req),
@@ -324,7 +290,7 @@ function calculateWaitTime(partition: ModelFamily) {
const currentWaits = queue
.filter((req) => {
const isSamePartition = getPartitionForRequest(req) === partition;
const isSamePartition = getModelFamilyForRequest(req) === partition;
const isNormalPriority = !isFromSharedIp(req);
return isSamePartition && isNormalPriority;
})