properly enforce allowedModelFamilies; refactor HPM proxyReq handlers
This commit is contained in:
+5
-39
@@ -14,17 +14,8 @@
|
||||
import crypto from "crypto";
|
||||
import type { Handler, Request } from "express";
|
||||
import { keyPool } from "../shared/key-management";
|
||||
import {
|
||||
getAwsBedrockModelFamily,
|
||||
getAzureOpenAIModelFamily,
|
||||
getClaudeModelFamily,
|
||||
getGooglePalmModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
MODEL_FAMILIES,
|
||||
ModelFamily,
|
||||
} from "../shared/models";
|
||||
import { getModelFamilyForRequest, MODEL_FAMILIES, ModelFamily } from "../shared/models";
|
||||
import { buildFakeSse, initializeSseStream } from "../shared/streaming";
|
||||
import { assertNever } from "../shared/utils";
|
||||
import { logger } from "../logger";
|
||||
import { getUniqueIps, SHARED_IP_ADDRESSES } from "./rate-limit";
|
||||
import { RequestPreprocessor } from "./middleware/request";
|
||||
@@ -132,34 +123,9 @@ export function enqueue(req: Request) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Resolves the queue partition (a ModelFamily) that a request belongs to.
 *
 * The family is derived from the requested model name together with
 * `req.service` and `req.outboundApi`.
 *
 * @param req Incoming request; reads `req.body.model`, `req.service` and
 *   `req.outboundApi`.
 * @returns The ModelFamily produced by the matching per-vendor helper.
 */
function getPartitionForRequest(req: Request): ModelFamily {
|
||||
// There is a single request queue, but it is partitioned by model family.
|
||||
// Model families are typically separated on cost/rate limit boundaries so
|
||||
// they should be treated as separate queues.
|
||||
// Falls back to "gpt-3.5-turbo" when the body has no model (null/undefined).
const model = req.body.model ?? "gpt-3.5-turbo";
|
||||

||||
// Weird special case for AWS/Azure because they serve multiple models from
|
||||
// different vendors, even if currently only one is supported.
|
||||
// These service checks run before the outboundApi switch below, so they
// take precedence over it.
if (req.service === "aws") return getAwsBedrockModelFamily(model);
|
||||
if (req.service === "azure") return getAzureOpenAIModelFamily(model);
|
||||

||||
switch (req.outboundApi) {
|
||||
case "anthropic":
|
||||
return getClaudeModelFamily(model);
|
||||
// All three OpenAI API flavours share the same family lookup.
case "openai":
|
||||
case "openai-text":
|
||||
case "openai-image":
|
||||
return getOpenAIModelFamily(model);
|
||||
case "google-palm":
|
||||
return getGooglePalmModelFamily(model);
|
||||
default:
|
||||
// NOTE(review): presumably an exhaustiveness guard for unhandled
// outboundApi values; assertNever is defined in ../shared/utils - confirm.
assertNever(req.outboundApi);
|
||||
}
|
||||
}
|
||||
|
||||
function getQueueForPartition(partition: ModelFamily): Request[] {
|
||||
return queue
|
||||
.filter((req) => getPartitionForRequest(req) === partition)
|
||||
.filter((req) => getModelFamilyForRequest(req) === partition)
|
||||
.sort((a, b) => {
|
||||
// Certain requests are exempted from IP-based rate limiting because they
|
||||
// come from a shared IP address. To prevent these requests from starving
|
||||
@@ -222,7 +188,7 @@ function processQueue() {
|
||||
|
||||
reqs.filter(Boolean).forEach((req) => {
|
||||
if (req?.proceed) {
|
||||
const modelFamily = getPartitionForRequest(req!);
|
||||
const modelFamily = getModelFamilyForRequest(req!);
|
||||
req.log.info({
|
||||
retries: req.retryCount,
|
||||
partition: modelFamily,
|
||||
@@ -279,7 +245,7 @@ let waitTimes: {
|
||||
/** Adds a successful request to the list of wait times. */
|
||||
export function trackWaitTime(req: Request) {
|
||||
waitTimes.push({
|
||||
partition: getPartitionForRequest(req),
|
||||
partition: getModelFamilyForRequest(req),
|
||||
start: req.startTime!,
|
||||
end: req.queueOutTime ?? Date.now(),
|
||||
isDeprioritized: isFromSharedIp(req),
|
||||
@@ -324,7 +290,7 @@ function calculateWaitTime(partition: ModelFamily) {
|
||||
|
||||
const currentWaits = queue
|
||||
.filter((req) => {
|
||||
const isSamePartition = getPartitionForRequest(req) === partition;
|
||||
const isSamePartition = getModelFamilyForRequest(req) === partition;
|
||||
const isNormalPriority = !isFromSharedIp(req);
|
||||
return isSamePartition && isNormalPriority;
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user