Per-user token quotas and automatic quota refreshing (khanon/oai-reverse-proxy!37)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import { Request, Response } from "express";
|
||||
import httpProxy from "http-proxy";
|
||||
import { ZodError } from "zod";
|
||||
import { AIService } from "../../key-management";
|
||||
import { QuotaExceededError } from "./request/apply-quota-limits";
|
||||
|
||||
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
|
||||
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
|
||||
@@ -63,9 +65,7 @@ export const handleInternalError = (
|
||||
res: Response
|
||||
) => {
|
||||
try {
|
||||
const isZod = err instanceof ZodError;
|
||||
const isForbidden = err.name === "ForbiddenError";
|
||||
if (isZod) {
|
||||
if (err instanceof ZodError) {
|
||||
writeErrorResponse(req, res, 400, {
|
||||
error: {
|
||||
type: "proxy_validation_error",
|
||||
@@ -75,7 +75,7 @@ export const handleInternalError = (
|
||||
message: err.message,
|
||||
},
|
||||
});
|
||||
} else if (isForbidden) {
|
||||
} else if (err.name === "ForbiddenError") {
|
||||
// Spoofs a vaguely threatening OpenAI error message. Only invoked by the
|
||||
// block-zoomers rewriter to scare off tiktokers.
|
||||
writeErrorResponse(req, res, 403, {
|
||||
@@ -86,6 +86,16 @@ export const handleInternalError = (
|
||||
message: err.message,
|
||||
},
|
||||
});
|
||||
} else if (err instanceof QuotaExceededError) {
|
||||
writeErrorResponse(req, res, 429, {
|
||||
error: {
|
||||
type: "proxy_quota_exceeded",
|
||||
code: "quota_exceeded",
|
||||
message: `You've exceeded your token quota for this model type.`,
|
||||
info: err.quotaInfo,
|
||||
stack: err.stack,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
writeErrorResponse(req, res, 500, {
|
||||
error: {
|
||||
@@ -141,3 +151,17 @@ export function buildFakeSseMessage(
|
||||
}
|
||||
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
|
||||
}
|
||||
|
||||
export function getCompletionForService({
|
||||
service,
|
||||
body,
|
||||
}: {
|
||||
service: AIService;
|
||||
body: Record<string, any>;
|
||||
}): { completion: string; model: string } {
|
||||
if (service === "anthropic") {
|
||||
return { completion: body.completion.trim(), model: body.model };
|
||||
} else {
|
||||
return { completion: body.choices[0].message.content, model: body.model };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import { hasAvailableQuota } from "../../auth/user-store";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
export class QuotaExceededError extends Error {
|
||||
public quotaInfo: any;
|
||||
constructor(message: string, quotaInfo: any) {
|
||||
super(message);
|
||||
this.name = "QuotaExceededError";
|
||||
this.quotaInfo = quotaInfo;
|
||||
}
|
||||
}
|
||||
|
||||
export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
|
||||
if (!isCompletionRequest(req) || !req.user) {
|
||||
return;
|
||||
}
|
||||
|
||||
const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
|
||||
if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) {
|
||||
throw new QuotaExceededError(
|
||||
"You have exceeded your proxy token quota for this model.",
|
||||
{
|
||||
quota: req.user.tokenLimits,
|
||||
used: req.user.tokenCounts,
|
||||
requested: requestedTokens,
|
||||
}
|
||||
);
|
||||
}
|
||||
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { config } from "../../../config";
|
||||
import { countTokens } from "../../../tokenization";
|
||||
import { OpenAIPromptMessage, countTokens } from "../../../tokenization";
|
||||
import { RequestPreprocessor } from ".";
|
||||
|
||||
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
||||
@@ -15,22 +15,26 @@ const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
||||
* request body.
|
||||
*/
|
||||
export const checkContextSize: RequestPreprocessor = async (req) => {
|
||||
let prompt;
|
||||
const service = req.outboundApi;
|
||||
let result;
|
||||
|
||||
switch (req.outboundApi) {
|
||||
case "openai":
|
||||
switch (service) {
|
||||
case "openai": {
|
||||
req.outputTokens = req.body.max_tokens;
|
||||
prompt = req.body.messages;
|
||||
const prompt: OpenAIPromptMessage[] = req.body.messages;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
case "anthropic":
|
||||
}
|
||||
case "anthropic": {
|
||||
req.outputTokens = req.body.max_tokens_to_sample;
|
||||
prompt = req.body.prompt;
|
||||
const prompt: string = req.body.prompt;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown outbound API: ${req.outboundApi}`);
|
||||
}
|
||||
|
||||
const result = await countTokens({ req, prompt, service: req.outboundApi });
|
||||
req.promptTokens = result.token_count;
|
||||
|
||||
// TODO: Remove once token counting is stable
|
||||
@@ -89,6 +93,7 @@ function validateContextSize(req: Request) {
|
||||
);
|
||||
|
||||
req.debug.prompt_tokens = promptTokens;
|
||||
req.debug.completion_tokens = outputTokens;
|
||||
req.debug.max_model_tokens = modelMax;
|
||||
req.debug.max_proxy_tokens = proxyMax;
|
||||
}
|
||||
@@ -101,7 +106,7 @@ function assertRequestHasTokenCounts(
|
||||
outputTokens: z.number().int().min(1),
|
||||
})
|
||||
.nonstrict()
|
||||
.parse(req);
|
||||
.parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { ClientRequest } from "http";
|
||||
import type { ProxyReqCallback } from "http-proxy";
|
||||
|
||||
// Express middleware (runs before http-proxy-middleware, can be async)
|
||||
export { applyQuotaLimits } from "./apply-quota-limits";
|
||||
export { createPreprocessorMiddleware } from "./preprocess";
|
||||
export { checkContextSize } from "./check-context-size";
|
||||
export { setApiFormat } from "./set-api-format";
|
||||
|
||||
@@ -3,14 +3,21 @@ import { Request, Response } from "express";
|
||||
import * as http from "http";
|
||||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { keyPool } from "../../../key-management";
|
||||
import { enqueue, trackWaitTime } from "../../queue";
|
||||
import { incrementPromptCount } from "../../auth/user-store";
|
||||
import { isCompletionRequest, writeErrorResponse } from "../common";
|
||||
import {
|
||||
incrementPromptCount,
|
||||
incrementTokenCount,
|
||||
} from "../../auth/user-store";
|
||||
import {
|
||||
getCompletionForService,
|
||||
isCompletionRequest,
|
||||
writeErrorResponse,
|
||||
} from "../common";
|
||||
import { handleStreamedResponse } from "./handle-streamed-response";
|
||||
import { logPrompt } from "./log-prompt";
|
||||
import { countTokens } from "../../../tokenization";
|
||||
|
||||
const DECODER_MAP = {
|
||||
gzip: util.promisify(zlib.gunzip),
|
||||
@@ -84,12 +91,18 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
|
||||
if (req.isStreaming) {
|
||||
// `handleStreamedResponse` writes to the response and ends it, so
|
||||
// we can only execute middleware that doesn't write to the response.
|
||||
middlewareStack.push(trackRateLimit, incrementKeyUsage, logPrompt);
|
||||
middlewareStack.push(
|
||||
trackRateLimit,
|
||||
countResponseTokens,
|
||||
incrementUsage,
|
||||
logPrompt
|
||||
);
|
||||
} else {
|
||||
middlewareStack.push(
|
||||
trackRateLimit,
|
||||
handleUpstreamErrors,
|
||||
incrementKeyUsage,
|
||||
countResponseTokens,
|
||||
incrementUsage,
|
||||
copyHttpHeaders,
|
||||
logPrompt,
|
||||
...apiMiddleware
|
||||
@@ -394,15 +407,56 @@ function handleOpenAIRateLimitError(
|
||||
return errorPayload;
|
||||
}
|
||||
|
||||
const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
|
||||
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
|
||||
if (isCompletionRequest(req)) {
|
||||
keyPool.incrementPrompt(req.key!);
|
||||
if (req.user) {
|
||||
incrementPromptCount(req.user.token);
|
||||
const model = req.body.model;
|
||||
const tokensUsed = req.promptTokens! + req.outputTokens!;
|
||||
incrementTokenCount(req.user.token, model, tokensUsed);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const countResponseTokens: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
_res,
|
||||
body
|
||||
) => {
|
||||
// This function is prone to breaking if the upstream API makes even minor
|
||||
// changes to the response format, especially for SSE responses. If you're
|
||||
// seeing errors in this function, check the reassembled response body from
|
||||
// handleStreamedResponse to see if the upstream API has changed.
|
||||
try {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
const service = req.outboundApi;
|
||||
const { completion } = getCompletionForService({ service, body });
|
||||
const tokens = await countTokens({ req, completion, service });
|
||||
|
||||
req.log.debug(
|
||||
{ service, tokens, prevOutputTokens: req.outputTokens },
|
||||
`Counted tokens for completion`
|
||||
);
|
||||
if (req.debug) {
|
||||
req.debug.completion_tokens = tokens;
|
||||
}
|
||||
|
||||
req.outputTokens = tokens.token_count;
|
||||
} catch (error) {
|
||||
req.log.error(
|
||||
error,
|
||||
"Error while counting completion tokens; assuming `max_output_tokens`"
|
||||
);
|
||||
// req.outputTokens will already be set to `max_output_tokens` from the
|
||||
// prompt counting middleware, so we don't need to do anything here.
|
||||
}
|
||||
};
|
||||
|
||||
const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
|
||||
keyPool.updateRateLimits(req.key!, proxyRes.headers);
|
||||
};
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import { Request } from "express";
|
||||
import { config } from "../../../config";
|
||||
import { AIService } from "../../../key-management";
|
||||
import { logQueue } from "../../../prompt-logging";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { getCompletionForService, isCompletionRequest } from "../common";
|
||||
import { ProxyResHandlerWithBody } from ".";
|
||||
import { logger } from "../../../logger";
|
||||
|
||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||
export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
@@ -26,7 +24,7 @@ export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
|
||||
const promptPayload = getPromptForRequest(req);
|
||||
const promptFlattened = flattenMessages(promptPayload);
|
||||
const response = getResponseForService({
|
||||
const response = getCompletionForService({
|
||||
service: req.outboundApi,
|
||||
body: responseBody,
|
||||
});
|
||||
@@ -62,17 +60,3 @@ const flattenMessages = (messages: string | OaiMessage[]): string => {
|
||||
}
|
||||
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
|
||||
};
|
||||
|
||||
const getResponseForService = ({
|
||||
service,
|
||||
body,
|
||||
}: {
|
||||
service: AIService;
|
||||
body: Record<string, any>;
|
||||
}): { completion: string; model: string } => {
|
||||
if (service === "anthropic") {
|
||||
return { completion: body.completion.trim(), model: body.model };
|
||||
} else {
|
||||
return { completion: body.choices[0].message.content, model: body.model };
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user