Per-user token quotas and automatic quota refreshing (khanon/oai-reverse-proxy!37)

khanon
2023-08-28 19:33:14 +00:00
parent 785b1f69f3
commit cb780e85da
31 changed files with 544 additions and 145 deletions
src/proxy/middleware/common.ts  +28 -4
@@ -1,6 +1,8 @@
 import { Request, Response } from "express";
 import httpProxy from "http-proxy";
 import { ZodError } from "zod";
+import { AIService } from "../../key-management";
+import { QuotaExceededError } from "./request/apply-quota-limits";
 
 const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
 const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
@@ -63,9 +65,7 @@ export const handleInternalError = (
   res: Response
 ) => {
   try {
-    const isZod = err instanceof ZodError;
-    const isForbidden = err.name === "ForbiddenError";
-    if (isZod) {
+    if (err instanceof ZodError) {
       writeErrorResponse(req, res, 400, {
         error: {
           type: "proxy_validation_error",
@@ -75,7 +75,7 @@ export const handleInternalError = (
         message: err.message,
       },
     });
-  } else if (isForbidden) {
+  } else if (err.name === "ForbiddenError") {
     // Spoofs a vaguely threatening OpenAI error message. Only invoked by the
     // block-zoomers rewriter to scare off tiktokers.
     writeErrorResponse(req, res, 403, {
@@ -86,6 +86,16 @@ export const handleInternalError = (
         message: err.message,
       },
     });
+  } else if (err instanceof QuotaExceededError) {
+    writeErrorResponse(req, res, 429, {
+      error: {
+        type: "proxy_quota_exceeded",
+        code: "quota_exceeded",
+        message: `You've exceeded your token quota for this model type.`,
+        info: err.quotaInfo,
+        stack: err.stack,
+      },
+    });
   } else {
     writeErrorResponse(req, res, 500, {
       error: {
@@ -141,3 +151,17 @@ export function buildFakeSseMessage(
   }
   return `data: ${JSON.stringify(fakeEvent)}\n\n`;
 }
+
+export function getCompletionForService({
+  service,
+  body,
+}: {
+  service: AIService;
+  body: Record<string, any>;
+}): { completion: string; model: string } {
+  if (service === "anthropic") {
+    return { completion: body.completion.trim(), model: body.model };
+  } else {
+    return { completion: body.choices[0].message.content, model: body.model };
+  }
+}
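
The new helper consolidates completion parsing for both upstream formats; it replaces the private getResponseForService that is deleted from log-prompt.ts at the end of this commit. A usage sketch with abbreviated, hypothetical response bodies (these are illustrations, not real API output):

    import { getCompletionForService } from "./common";

    // Hypothetical, abbreviated upstream bodies for illustration only.
    const anthropicBody = { completion: " Hi there.", model: "claude-v1" };
    const openaiBody = {
      model: "gpt-3.5-turbo",
      choices: [{ message: { role: "assistant", content: "Hi there." } }],
    };

    getCompletionForService({ service: "anthropic", body: anthropicBody });
    // => { completion: "Hi there.", model: "claude-v1" }
    getCompletionForService({ service: "openai", body: openaiBody });
    // => { completion: "Hi there.", model: "gpt-3.5-turbo" }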
src/proxy/middleware/request/apply-quota-limits.ts  (new file, +30)
@@ -0,0 +1,30 @@
+import { hasAvailableQuota } from "../../auth/user-store";
+import { isCompletionRequest } from "../common";
+import { ProxyRequestMiddleware } from ".";
+
+export class QuotaExceededError extends Error {
+  public quotaInfo: any;
+  constructor(message: string, quotaInfo: any) {
+    super(message);
+    this.name = "QuotaExceededError";
+    this.quotaInfo = quotaInfo;
+  }
+}
+
+export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
+  if (!isCompletionRequest(req) || !req.user) {
+    return;
+  }
+
+  const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
+  if (!hasAvailableQuota(req.user.token, req.body.model, requestedTokens)) {
+    throw new QuotaExceededError(
+      "You have exceeded your proxy token quota for this model.",
+      {
+        quota: req.user.tokenLimits,
+        used: req.user.tokenCounts,
+        requested: requestedTokens,
+      }
+    );
+  }
+};
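
hasAvailableQuota itself lives in the user store and is not part of this excerpt. Judging from the quota/used/requested fields attached to the error above, the check presumably reduces to a comparison like the following sketch (all names, shapes, and the per-model keying are assumptions):

    // Minimal sketch only; the real user-store implementation is not shown
    // here, and keying by raw model id (vs. a model family) is an assumption.
    interface QuotaUser {
      token: string;
      tokenLimits: Record<string, number | undefined>;
      tokenCounts: Record<string, number>;
    }

    const users = new Map<string, QuotaUser>(); // assumed in-memory store

    function hasAvailableQuota(
      token: string,
      model: string,
      requested: number
    ): boolean {
      const user = users.get(token);
      if (!user) return false;
      const limit = user.tokenLimits[model];
      if (!limit) return true; // assumed: no configured limit means unlimited
      const used = user.tokenCounts[model] ?? 0;
      return used + requested <= limit;
    }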
src/proxy/middleware/request/check-context-size.ts
@@ -1,7 +1,7 @@
 import { Request } from "express";
 import { z } from "zod";
 import { config } from "../../../config";
-import { countTokens } from "../../../tokenization";
+import { OpenAIPromptMessage, countTokens } from "../../../tokenization";
 import { RequestPreprocessor } from ".";
 
 const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
@@ -15,22 +15,26 @@ const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
  * request body.
  */
 export const checkContextSize: RequestPreprocessor = async (req) => {
-  let prompt;
+  const service = req.outboundApi;
+  let result;
 
-  switch (req.outboundApi) {
-    case "openai":
+  switch (service) {
+    case "openai": {
       req.outputTokens = req.body.max_tokens;
-      prompt = req.body.messages;
+      const prompt: OpenAIPromptMessage[] = req.body.messages;
+      result = await countTokens({ req, prompt, service });
       break;
-    case "anthropic":
+    }
+    case "anthropic": {
       req.outputTokens = req.body.max_tokens_to_sample;
-      prompt = req.body.prompt;
+      const prompt: string = req.body.prompt;
+      result = await countTokens({ req, prompt, service });
       break;
+    }
     default:
       throw new Error(`Unknown outbound API: ${req.outboundApi}`);
   }
 
-  const result = await countTokens({ req, prompt, service: req.outboundApi });
   req.promptTokens = result.token_count;
 
   // TODO: Remove once token counting is stable
@@ -89,6 +93,7 @@ function validateContextSize(req: Request) {
   );
   req.debug.prompt_tokens = promptTokens;
   req.debug.completion_tokens = outputTokens;
+  req.debug.max_model_tokens = modelMax;
   req.debug.max_proxy_tokens = proxyMax;
 }
@@ -101,7 +106,7 @@ function assertRequestHasTokenCounts(
       outputTokens: z.number().int().min(1),
     })
     .nonstrict()
-    .parse(req);
+    .parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
 }
 
 /**
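
After this change each switch case calls countTokens with a prompt whose type matches the service, so the tokenizer's argument union can be discriminated at the call site. The signature this implies, reconstructed only from the call sites in this commit (the tokenization module itself is not shown, and the message fields are assumed), is roughly:

    // Assumed shape; reconstructed from call sites, not the actual module.
    interface OpenAIPromptMessage {
      role: string;
      content: string;
    }

    type CountTokensArgs =
      | { req: unknown; prompt: OpenAIPromptMessage[]; service: "openai" }
      | { req: unknown; prompt: string; service: "anthropic" }
      | { req: unknown; completion: string; service: "openai" | "anthropic" };

    declare function countTokens(
      args: CountTokensArgs
    ): Promise<{ token_count: number }>;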
src/proxy/middleware/request/index.ts  +1
@@ -3,6 +3,7 @@ import type { ClientRequest } from "http";
 import type { ProxyReqCallback } from "http-proxy";
 
 // Express middleware (runs before http-proxy-middleware, can be async)
+export { applyQuotaLimits } from "./apply-quota-limits";
 export { createPreprocessorMiddleware } from "./preprocess";
 export { checkContextSize } from "./check-context-size";
 export { setApiFormat } from "./set-api-format";
src/proxy/middleware/response/index.ts  +60 -6
@@ -3,14 +3,21 @@ import { Request, Response } from "express";
 import * as http from "http";
 import util from "util";
 import zlib from "zlib";
 import { config } from "../../../config";
 import { logger } from "../../../logger";
 import { keyPool } from "../../../key-management";
 import { enqueue, trackWaitTime } from "../../queue";
-import { incrementPromptCount } from "../../auth/user-store";
-import { isCompletionRequest, writeErrorResponse } from "../common";
+import {
+  incrementPromptCount,
+  incrementTokenCount,
+} from "../../auth/user-store";
+import {
+  getCompletionForService,
+  isCompletionRequest,
+  writeErrorResponse,
+} from "../common";
 import { handleStreamedResponse } from "./handle-streamed-response";
 import { logPrompt } from "./log-prompt";
+import { countTokens } from "../../../tokenization";
 
 const DECODER_MAP = {
   gzip: util.promisify(zlib.gunzip),
@@ -84,12 +91,18 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
     if (req.isStreaming) {
       // `handleStreamedResponse` writes to the response and ends it, so
       // we can only execute middleware that doesn't write to the response.
-      middlewareStack.push(trackRateLimit, incrementKeyUsage, logPrompt);
+      middlewareStack.push(
+        trackRateLimit,
+        countResponseTokens,
+        incrementUsage,
+        logPrompt
+      );
     } else {
       middlewareStack.push(
         trackRateLimit,
         handleUpstreamErrors,
-        incrementKeyUsage,
+        countResponseTokens,
+        incrementUsage,
         copyHttpHeaders,
         logPrompt,
         ...apiMiddleware
@@ -394,15 +407,56 @@ function handleOpenAIRateLimitError(
   return errorPayload;
 }
 
-const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
+const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
   if (isCompletionRequest(req)) {
     keyPool.incrementPrompt(req.key!);
     if (req.user) {
       incrementPromptCount(req.user.token);
+      const model = req.body.model;
+      const tokensUsed = req.promptTokens! + req.outputTokens!;
+      incrementTokenCount(req.user.token, model, tokensUsed);
     }
   }
 };
 
+const countResponseTokens: ProxyResHandlerWithBody = async (
+  _proxyRes,
+  req,
+  _res,
+  body
+) => {
+  // This function is prone to breaking if the upstream API makes even minor
+  // changes to the response format, especially for SSE responses. If you're
+  // seeing errors in this function, check the reassembled response body from
+  // handleStreamedResponse to see if the upstream API has changed.
+  try {
+    if (typeof body !== "object") {
+      throw new Error("Expected body to be an object");
+    }
+
+    const service = req.outboundApi;
+    const { completion } = getCompletionForService({ service, body });
+    const tokens = await countTokens({ req, completion, service });
+
+    req.log.debug(
+      { service, tokens, prevOutputTokens: req.outputTokens },
+      `Counted tokens for completion`
+    );
+    if (req.debug) {
+      req.debug.completion_tokens = tokens;
+    }
+
+    req.outputTokens = tokens.token_count;
+  } catch (error) {
+    req.log.error(
+      error,
+      "Error while counting completion tokens; assuming `max_output_tokens`"
+    );
+    // req.outputTokens will already be set to `max_output_tokens` from the
+    // prompt counting middleware, so we don't need to do anything here.
+  }
+};
+
 const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
   keyPool.updateRateLimits(req.key!, proxyRes.headers);
 };
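
Taken together, the accounting works in three steps: checkContextSize counts promptTokens and seeds outputTokens with the request's max_tokens; applyQuotaLimits rejects up front if that worst case would exceed the quota; countResponseTokens then replaces outputTokens with the measured completion length (keeping the max_tokens bound on error) before incrementUsage bills the total. An illustrative trace with made-up numbers:

    // Illustrative values only.
    const promptTokens = 1200; // counted by checkContextSize before proxying
    const maxTokens = 500;     // request's max_tokens; provisional outputTokens

    // applyQuotaLimits checks the worst case before the request is sent:
    const requested = promptTokens + maxTokens; // 1700 must fit in the quota

    // After the response arrives, countResponseTokens measures the actual
    // completion (or keeps the 500-token upper bound if counting fails):
    const actualCompletionTokens = 83;

    // incrementUsage then bills the real total, not the worst case:
    const billed = promptTokens + actualCompletionTokens; // 1283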
src/proxy/middleware/response/log-prompt.ts  +2 -18
@@ -1,10 +1,8 @@
 import { Request } from "express";
 import { config } from "../../../config";
-import { AIService } from "../../../key-management";
 import { logQueue } from "../../../prompt-logging";
-import { isCompletionRequest } from "../common";
+import { getCompletionForService, isCompletionRequest } from "../common";
 import { ProxyResHandlerWithBody } from ".";
-import { logger } from "../../../logger";
 
 /** If prompt logging is enabled, enqueues the prompt for logging. */
 export const logPrompt: ProxyResHandlerWithBody = async (
@@ -26,7 +24,7 @@ export const logPrompt: ProxyResHandlerWithBody = async (
   const promptPayload = getPromptForRequest(req);
   const promptFlattened = flattenMessages(promptPayload);
 
-  const response = getResponseForService({
+  const response = getCompletionForService({
     service: req.outboundApi,
     body: responseBody,
   });
@@ -62,17 +60,3 @@ const flattenMessages = (messages: string | OaiMessage[]): string => {
   }
   return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
 };
-
-const getResponseForService = ({
-  service,
-  body,
-}: {
-  service: AIService;
-  body: Record<string, any>;
-}): { completion: string; model: string } => {
-  if (service === "anthropic") {
-    return { completion: body.completion.trim(), model: body.model };
-  } else {
-    return { completion: body.choices[0].message.content, model: body.model };
-  }
-};
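
The deleted getResponseForService duplicated the logic now exported from common.ts as getCompletionForService, so this file and the response middleware share one parser. For context, flattenMessages (unchanged above) renders an OpenAI-style message array one line per turn; a hypothetical input/output pair:

    // Hypothetical example of the flattening shown above.
    const messages = [
      { role: "user", content: "What is 2 + 2?" },
      { role: "assistant", content: "4" },
    ];
    const flattened = messages.map((m) => `${m.role}: ${m.content}`).join("\n");
    // => "user: What is 2 + 2?\nassistant: 4"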