Add GPT-4-32k support (khanon/oai-reverse-proxy!39)
This commit is contained in:
+23
-13
@@ -16,13 +16,17 @@
|
||||
*/
|
||||
|
||||
import type { Handler, Request } from "express";
|
||||
import { keyPool, SupportedModel } from "../key-management";
|
||||
import {
|
||||
getClaudeModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
keyPool,
|
||||
ModelFamily,
|
||||
SupportedModel,
|
||||
} from "../key-management";
|
||||
import { logger } from "../logger";
|
||||
import { AGNAI_DOT_CHAT_IP } from "./rate-limit";
|
||||
import { buildFakeSseMessage } from "./middleware/common";
|
||||
|
||||
export type QueuePartition = "claude" | "turbo" | "gpt-4";
|
||||
|
||||
const queue: Request[] = [];
|
||||
const log = logger.child({ module: "request-queue" });
|
||||
|
||||
@@ -129,7 +133,7 @@ export function enqueue(req: Request) {
|
||||
}
|
||||
}
|
||||
|
||||
function getPartitionForRequest(req: Request): QueuePartition {
|
||||
function getPartitionForRequest(req: Request): ModelFamily {
|
||||
// There is a single request queue, but it is partitioned by model and API
|
||||
// provider.
|
||||
// - claude: requests for the Anthropic API, regardless of model
|
||||
@@ -138,19 +142,19 @@ function getPartitionForRequest(req: Request): QueuePartition {
|
||||
const provider = req.outboundApi;
|
||||
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
|
||||
if (provider === "anthropic") {
|
||||
return "claude";
|
||||
return getClaudeModelFamily(model);
|
||||
}
|
||||
if (provider === "openai" && model.startsWith("gpt-4")) {
|
||||
return "gpt-4";
|
||||
if (provider === "openai") {
|
||||
return getOpenAIModelFamily(model);
|
||||
}
|
||||
return "turbo";
|
||||
}
|
||||
|
||||
function getQueueForPartition(partition: QueuePartition): Request[] {
|
||||
function getQueueForPartition(partition: ModelFamily): Request[] {
|
||||
return queue.filter((req) => getPartitionForRequest(req) === partition);
|
||||
}
|
||||
|
||||
export function dequeue(partition: QueuePartition): Request | undefined {
|
||||
export function dequeue(partition: ModelFamily): Request | undefined {
|
||||
const modelQueue = getQueueForPartition(partition);
|
||||
|
||||
if (modelQueue.length === 0) {
|
||||
@@ -189,13 +193,19 @@ function processQueue() {
|
||||
// This isn't completely correct, because a key can service multiple models.
|
||||
// Currently if a key is locked out on one model it will also stop servicing
|
||||
// the others, because we only track one rate limit per key.
|
||||
|
||||
// TODO: `getLockoutPeriod` uses model names instead of model families
|
||||
const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
|
||||
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
|
||||
const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
|
||||
const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
|
||||
|
||||
const reqs: (Request | undefined)[] = [];
|
||||
if (gpt432kLockout === 0) {
|
||||
reqs.push(dequeue("gpt4-32k"));
|
||||
}
|
||||
if (gpt4Lockout === 0) {
|
||||
reqs.push(dequeue("gpt-4"));
|
||||
reqs.push(dequeue("gpt4"));
|
||||
}
|
||||
if (turboLockout === 0) {
|
||||
reqs.push(dequeue("turbo"));
|
||||
@@ -244,7 +254,7 @@ export function start() {
|
||||
log.info(`Started request queue.`);
|
||||
}
|
||||
|
||||
let waitTimes: { partition: QueuePartition; start: number; end: number }[] = [];
|
||||
let waitTimes: { partition: ModelFamily; start: number; end: number }[] = [];
|
||||
|
||||
/** Adds a successful request to the list of wait times. */
|
||||
export function trackWaitTime(req: Request) {
|
||||
@@ -256,7 +266,7 @@ export function trackWaitTime(req: Request) {
|
||||
}
|
||||
|
||||
/** Returns average wait time in milliseconds. */
|
||||
export function getEstimatedWaitTime(partition: QueuePartition) {
|
||||
export function getEstimatedWaitTime(partition: ModelFamily) {
|
||||
const now = Date.now();
|
||||
const recentWaits = waitTimes.filter(
|
||||
(wt) => wt.partition === partition && now - wt.end < 300 * 1000
|
||||
@@ -271,7 +281,7 @@ export function getEstimatedWaitTime(partition: QueuePartition) {
|
||||
);
|
||||
}
|
||||
|
||||
export function getQueueLength(partition: QueuePartition | "all" = "all") {
|
||||
export function getQueueLength(partition: ModelFamily | "all" = "all") {
|
||||
if (partition === "all") {
|
||||
return queue.length;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user