Per-user token quotas and automatic quota refreshing (khanon/oai-reverse-proxy!37)
This commit is contained in:
@@ -8,6 +8,7 @@ import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
applyQuotaLimits,
|
||||
addAnthropicPreamble,
|
||||
blockZoomerOrigins,
|
||||
createPreprocessorMiddleware,
|
||||
@@ -72,6 +73,7 @@ const rewriteAnthropicRequest = (
|
||||
res: http.ServerResponse
|
||||
) => {
|
||||
const rewriterPipeline = [
|
||||
applyQuotaLimits,
|
||||
addKey,
|
||||
addAnthropicPreamble,
|
||||
languageFilter,
|
||||
@@ -108,7 +110,7 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
|
||||
|
||||
if (req.inboundApi === "openai") {
|
||||
req.log.info("Transforming Anthropic response to OpenAI format");
|
||||
body = transformAnthropicResponse(body);
|
||||
body = transformAnthropicResponse(body, req);
|
||||
}
|
||||
|
||||
// TODO: Remove once tokenization is stable
|
||||
@@ -126,17 +128,19 @@ const anthropicResponseHandler: ProxyResHandlerWithBody = async (
|
||||
* on-the-fly.
|
||||
*/
|
||||
function transformAnthropicResponse(
|
||||
anthropicBody: Record<string, any>
|
||||
anthropicBody: Record<string, any>,
|
||||
req: Request
|
||||
): Record<string, any> {
|
||||
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
|
||||
return {
|
||||
id: "ant-" + anthropicBody.log_id,
|
||||
object: "chat.completion",
|
||||
created: Date.now(),
|
||||
model: anthropicBody.model,
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
prompt_tokens: req.promptTokens,
|
||||
completion_tokens: req.outputTokens,
|
||||
total_tokens: totalTokens,
|
||||
},
|
||||
choices: [
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { Request, RequestHandler } from "express";
|
||||
import { config } from "../../config";
|
||||
import { authenticate, getUser } from "./user-store";
|
||||
import { authenticate, getUser, hasAvailableQuota } from "./user-store";
|
||||
|
||||
const GATEKEEPER = config.gatekeeper;
|
||||
const PROXY_KEY = config.proxyKey;
|
||||
|
||||
+108
-17
@@ -8,10 +8,16 @@
|
||||
*/
|
||||
|
||||
import admin from "firebase-admin";
|
||||
import schedule from "node-schedule";
|
||||
import { v4 as uuid } from "uuid";
|
||||
import { config, getFirebaseApp } from "../../config";
|
||||
import { logger } from "../../logger";
|
||||
|
||||
const log = logger.child({ module: "users" });
|
||||
|
||||
// TODO: Consolidate model families with QueuePartition and KeyProvider.
|
||||
type QuotaModel = "claude" | "turbo" | "gpt4";
|
||||
|
||||
export interface User {
|
||||
/** The user's personal access token. */
|
||||
token: string;
|
||||
@@ -21,8 +27,12 @@ export interface User {
|
||||
type: UserType;
|
||||
/** The number of prompts the user has made. */
|
||||
promptCount: number;
|
||||
/** The number of tokens the user has consumed. Not yet implemented. */
|
||||
tokenCount: number;
|
||||
/** @deprecated Use `tokenCounts` instead. */
|
||||
tokenCount?: never;
|
||||
/** The number of tokens the user has consumed, by model family. */
|
||||
tokenCounts: Record<QuotaModel, number>;
|
||||
/** The maximum number of tokens the user can consume, by model family. */
|
||||
tokenLimits: Record<QuotaModel, number>;
|
||||
/** The time at which the user was created. */
|
||||
createdAt: number;
|
||||
/** The time at which the user last connected. */
|
||||
@@ -37,7 +47,6 @@ export interface User {
|
||||
* Possible privilege levels for a user.
|
||||
* - `normal`: Default role. Subject to usual rate limits and quotas.
|
||||
* - `special`: Special role. Higher quotas and exempt from auto-ban/lockout.
|
||||
* TODO: implement auto-ban/lockout for normal users when they do naughty shit
|
||||
*/
|
||||
export type UserType = "normal" | "special";
|
||||
|
||||
@@ -49,11 +58,32 @@ const users: Map<string, User> = new Map();
|
||||
const usersToFlush = new Set<string>();
|
||||
|
||||
/**
 * Initializes the user store.
 *
 * When the configured store is `firebase_rtdb`, the Firebase-backed
 * initialization runs first. Afterwards, if `config.quotaRefreshPeriod` is
 * set, a recurring job is scheduled that calls `refreshQuota` for every user
 * currently held in memory.
 *
 * @throws Error if the quota refresh job could not be scheduled.
 */
export async function init() {
  log.info({ store: config.gatekeeperStore }, "Initializing user store...");
  if (config.gatekeeperStore === "firebase_rtdb") {
    await initFirebase();
  }

  if (config.quotaRefreshPeriod) {
    const quotaRefreshJob = schedule.scheduleJob(getRefreshCrontab(), () => {
      for (const user of users.values()) {
        refreshQuota(user.token);
      }
      log.info(
        { users: users.size, nextRefresh: quotaRefreshJob.nextInvocation() },
        "Token quotas refreshed."
      );
    });

    // A falsy job means nothing was scheduled; fail fast so a bad
    // QUOTA_REFRESH_PERIOD is noticed at startup instead of silently ignored.
    if (!quotaRefreshJob) {
      throw new Error(
        "Unable to schedule quota refresh. Is QUOTA_REFRESH_PERIOD set correctly?"
      );
    }
    log.debug(
      { nextRefresh: quotaRefreshJob.nextInvocation() },
      "Scheduled token quota refresh."
    );
  }

  // Logged once, after every initialization step above has completed.
  log.info("User store initialized.");
}
|
||||
|
||||
/** Creates a new user and returns their token. */
|
||||
@@ -64,7 +94,8 @@ export function createUser() {
|
||||
ip: [],
|
||||
type: "normal",
|
||||
promptCount: 0,
|
||||
tokenCount: 0,
|
||||
tokenCounts: { turbo: 0, gpt4: 0, claude: 0 },
|
||||
tokenLimits: { ...config.tokenQuota },
|
||||
createdAt: Date.now(),
|
||||
});
|
||||
usersToFlush.add(token);
|
||||
@@ -86,12 +117,14 @@ export function getUsers() {
|
||||
* user information via JSON. Use other functions for more specific operations.
|
||||
*/
|
||||
export function upsertUser(user: UserUpdate) {
|
||||
// TODO: May need better merging for nested objects
|
||||
const existing: User = users.get(user.token) ?? {
|
||||
token: user.token,
|
||||
ip: [],
|
||||
type: "normal",
|
||||
promptCount: 0,
|
||||
tokenCount: 0,
|
||||
tokenCounts: { turbo: 0, gpt4: 0, claude: 0 },
|
||||
tokenLimits: { ...config.tokenQuota },
|
||||
createdAt: Date.now(),
|
||||
};
|
||||
|
||||
@@ -117,11 +150,16 @@ export function incrementPromptCount(token: string) {
|
||||
usersToFlush.add(token);
|
||||
}
|
||||
|
||||
/**
 * Increments token consumption for the given user and model.
 *
 * The model name is mapped to its quota family via `getModelFamily`, and the
 * consumption is added to that family's running total. Unknown tokens are
 * ignored. The user is marked for the next flush.
 *
 * @param token The user's personal access token.
 * @param model The model name the tokens were spent on.
 * @param consumption The number of tokens to add to the family's count.
 */
export function incrementTokenCount(
  token: string,
  model: string,
  consumption: number
) {
  const user = users.get(token);
  if (!user) return;
  const modelFamily = getModelFamily(model);
  user.tokenCounts[modelFamily] += consumption;
  usersToFlush.add(token);
}
|
||||
|
||||
@@ -148,6 +186,40 @@ export function authenticate(token: string, ip: string) {
|
||||
return user;
|
||||
}
|
||||
|
||||
/**
 * Reports whether a user may spend `requested` more tokens on the quota
 * family that `model` belongs to.
 *
 * - Unknown tokens are always refused.
 * - Users of type `special` are always allowed.
 * - A missing or zero limit for the family means "no limit".
 * - Otherwise the request is allowed only while consumed + requested stays
 *   strictly below the limit.
 *
 * @param token The user's personal access token.
 * @param model The model name being requested.
 * @param requested The number of tokens the request would consume.
 * @returns `true` if the request fits within the user's quota.
 */
export function hasAvailableQuota(
  token: string,
  model: string,
  requested: number
) {
  const user = users.get(token);
  if (!user) return false;
  if (user.type === "special") return true;

  const family = getModelFamily(model);
  const limit = user.tokenLimits[family];
  if (!limit) return true;

  return user.tokenCounts[family] + requested < limit;
}
|
||||
|
||||
/**
 * Grants the user a fresh allotment for every configured quota family.
 *
 * For each family in `config.tokenQuota` with a positive quota, the user's
 * limit becomes their current consumption plus that quota. Families without a
 * configured quota are left alone so manually assigned limits survive. The
 * user is marked for the next flush. Unknown tokens are ignored.
 *
 * @param token The user's personal access token.
 */
export function refreshQuota(token: string) {
  const user = users.get(token);
  if (!user) return;

  const configured = Object.entries(config.tokenQuota) as [
    QuotaModel,
    number
  ][];
  for (const [family, quota] of configured) {
    // If a quota is not configured, don't touch any existing limits a user
    // may already have been assigned manually.
    if (quota > 0) {
      user.tokenLimits[family] = user.tokenCounts[family] + quota;
    }
  }
  usersToFlush.add(token);
}
|
||||
|
||||
/** Disables the given user, optionally providing a reason. */
|
||||
export function disableUser(token: string, reason?: string) {
|
||||
const user = users.get(token);
|
||||
@@ -163,7 +235,7 @@ export function disableUser(token: string, reason?: string) {
|
||||
let firebaseTimeout: NodeJS.Timeout | undefined;
|
||||
|
||||
async function initFirebase() {
|
||||
logger.info("Connecting to Firebase...");
|
||||
log.info("Connecting to Firebase...");
|
||||
const app = getFirebaseApp();
|
||||
const db = admin.database(app);
|
||||
const usersRef = db.ref("users");
|
||||
@@ -171,7 +243,7 @@ async function initFirebase() {
|
||||
const users: Record<string, User> | null = snapshot.val();
|
||||
firebaseTimeout = setInterval(flushUsers, 20 * 1000);
|
||||
if (!users) {
|
||||
logger.info("No users found in Firebase.");
|
||||
log.info("No users found in Firebase.");
|
||||
return;
|
||||
}
|
||||
for (const token in users) {
|
||||
@@ -179,7 +251,7 @@ async function initFirebase() {
|
||||
}
|
||||
usersToFlush.clear();
|
||||
const numUsers = Object.keys(users).length;
|
||||
logger.info({ users: numUsers }, "Loaded users from Firebase");
|
||||
log.info({ users: numUsers }, "Loaded users from Firebase");
|
||||
}
|
||||
|
||||
async function flushUsers() {
|
||||
@@ -204,8 +276,27 @@ async function flushUsers() {
|
||||
}
|
||||
|
||||
await usersRef.update(updates);
|
||||
logger.info(
|
||||
{ users: Object.keys(updates).length },
|
||||
"Flushed users to Firebase"
|
||||
);
|
||||
log.info({ users: Object.keys(updates).length }, "Flushed users to Firebase");
|
||||
}
|
||||
|
||||
/**
 * Maps a model name to the quota family it is billed against.
 *
 * Names starting with `gpt-4` belong to `gpt4`, names starting with `gpt-3.5`
 * belong to `turbo`, and every other name falls through to `claude`.
 */
function getModelFamily(model: string): QuotaModel {
  // TODO: add 32k models
  if (model.startsWith("gpt-4")) return "gpt4";
  return model.startsWith("gpt-3.5") ? "turbo" : "claude";
}
|
||||
|
||||
/**
 * Translates `config.quotaRefreshPeriod` into a crontab expression.
 *
 * `hourly` and `daily` are shorthands; any other value is returned as-is and
 * is expected to already be a crontab string. When the setting is absent the
 * daily schedule is used.
 */
function getRefreshCrontab() {
  const period = config.quotaRefreshPeriod!;
  if (period === "hourly") return "0 * * * *";
  if (period === "daily") return "0 0 * * *";
  return config.quotaRefreshPeriod ?? "0 0 * * *";
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Request, Response } from "express";
|
||||
import httpProxy from "http-proxy";
|
||||
import { ZodError } from "zod";
|
||||
import { AIService } from "../../key-management";
|
||||
import { QuotaExceededError } from "./request/apply-quota-limits";
|
||||
|
||||
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
|
||||
const ANTHROPIC_COMPLETION_ENDPOINT = "/v1/complete";
|
||||
@@ -63,9 +65,7 @@ export const handleInternalError = (
|
||||
res: Response
|
||||
) => {
|
||||
try {
|
||||
const isZod = err instanceof ZodError;
|
||||
const isForbidden = err.name === "ForbiddenError";
|
||||
if (isZod) {
|
||||
if (err instanceof ZodError) {
|
||||
writeErrorResponse(req, res, 400, {
|
||||
error: {
|
||||
type: "proxy_validation_error",
|
||||
@@ -75,7 +75,7 @@ export const handleInternalError = (
|
||||
message: err.message,
|
||||
},
|
||||
});
|
||||
} else if (isForbidden) {
|
||||
} else if (err.name === "ForbiddenError") {
|
||||
// Spoofs a vaguely threatening OpenAI error message. Only invoked by the
|
||||
// block-zoomers rewriter to scare off tiktokers.
|
||||
writeErrorResponse(req, res, 403, {
|
||||
@@ -86,6 +86,16 @@ export const handleInternalError = (
|
||||
message: err.message,
|
||||
},
|
||||
});
|
||||
} else if (err instanceof QuotaExceededError) {
|
||||
writeErrorResponse(req, res, 429, {
|
||||
error: {
|
||||
type: "proxy_quota_exceeded",
|
||||
code: "quota_exceeded",
|
||||
message: `You've exceeded your token quota for this model type.`,
|
||||
info: err.quotaInfo,
|
||||
stack: err.stack,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
writeErrorResponse(req, res, 500, {
|
||||
error: {
|
||||
@@ -141,3 +151,17 @@ export function buildFakeSseMessage(
|
||||
}
|
||||
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
|
||||
}
|
||||
|
||||
/**
 * Extracts the completion text and model name from an upstream response body.
 *
 * For the `anthropic` service the text is read from `body.completion` and
 * trimmed; for any other service it is read, untrimmed, from
 * `body.choices[0].message.content`.
 *
 * @throws TypeError if the body lacks the fields expected for the service.
 */
export function getCompletionForService({
  service,
  body,
}: {
  service: AIService;
  body: Record<string, any>;
}): { completion: string; model: string } {
  const completion =
    service === "anthropic"
      ? body.completion.trim()
      : body.choices[0].message.content;
  return { completion, model: body.model };
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import { hasAvailableQuota } from "../../auth/user-store";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { ProxyRequestMiddleware } from ".";
|
||||
|
||||
/**
 * Raised when a request would take a user past their token quota.
 *
 * Carries an arbitrary `quotaInfo` payload describing the quota state at the
 * moment the request was refused.
 */
export class QuotaExceededError extends Error {
  /** Details about the quota that was exceeded, supplied by the thrower. */
  public quotaInfo: any;

  /**
   * @param message Human-readable description of the failure.
   * @param quotaInfo Payload stored on the error as `quotaInfo`.
   */
  constructor(message: string, quotaInfo: any) {
    super(message);
    // Set explicitly so the error can be recognised by name.
    this.name = "QuotaExceededError";
    this.quotaInfo = quotaInfo;
  }
}
|
||||
|
||||
/**
 * Request rewriter that rejects completion requests exceeding the user's
 * token quota.
 *
 * Requests that are not completion requests, or that have no attached user,
 * pass through untouched. The requested amount is the sum of the request's
 * prompt and output token counts, each treated as 0 when unset.
 *
 * @throws QuotaExceededError when `hasAvailableQuota` refuses the request.
 */
export const applyQuotaLimits: ProxyRequestMiddleware = (_proxyReq, req) => {
  const { user } = req;
  if (!isCompletionRequest(req) || !user) {
    return;
  }

  const requestedTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  if (hasAvailableQuota(user.token, req.body.model, requestedTokens)) {
    return;
  }

  throw new QuotaExceededError(
    "You have exceeded your proxy token quota for this model.",
    {
      quota: user.tokenLimits,
      used: user.tokenCounts,
      requested: requestedTokens,
    }
  );
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { config } from "../../../config";
|
||||
import { countTokens } from "../../../tokenization";
|
||||
import { OpenAIPromptMessage, countTokens } from "../../../tokenization";
|
||||
import { RequestPreprocessor } from ".";
|
||||
|
||||
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
||||
@@ -15,22 +15,26 @@ const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
||||
* request body.
|
||||
*/
|
||||
export const checkContextSize: RequestPreprocessor = async (req) => {
|
||||
let prompt;
|
||||
const service = req.outboundApi;
|
||||
let result;
|
||||
|
||||
switch (req.outboundApi) {
|
||||
case "openai":
|
||||
switch (service) {
|
||||
case "openai": {
|
||||
req.outputTokens = req.body.max_tokens;
|
||||
prompt = req.body.messages;
|
||||
const prompt: OpenAIPromptMessage[] = req.body.messages;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
case "anthropic":
|
||||
}
|
||||
case "anthropic": {
|
||||
req.outputTokens = req.body.max_tokens_to_sample;
|
||||
prompt = req.body.prompt;
|
||||
const prompt: string = req.body.prompt;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown outbound API: ${req.outboundApi}`);
|
||||
}
|
||||
|
||||
const result = await countTokens({ req, prompt, service: req.outboundApi });
|
||||
req.promptTokens = result.token_count;
|
||||
|
||||
// TODO: Remove once token counting is stable
|
||||
@@ -89,6 +93,7 @@ function validateContextSize(req: Request) {
|
||||
);
|
||||
|
||||
req.debug.prompt_tokens = promptTokens;
|
||||
req.debug.completion_tokens = outputTokens;
|
||||
req.debug.max_model_tokens = modelMax;
|
||||
req.debug.max_proxy_tokens = proxyMax;
|
||||
}
|
||||
@@ -101,7 +106,7 @@ function assertRequestHasTokenCounts(
|
||||
outputTokens: z.number().int().min(1),
|
||||
})
|
||||
.nonstrict()
|
||||
.parse(req);
|
||||
.parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { ClientRequest } from "http";
|
||||
import type { ProxyReqCallback } from "http-proxy";
|
||||
|
||||
// Express middleware (runs before http-proxy-middleware, can be async)
|
||||
export { applyQuotaLimits } from "./apply-quota-limits";
|
||||
export { createPreprocessorMiddleware } from "./preprocess";
|
||||
export { checkContextSize } from "./check-context-size";
|
||||
export { setApiFormat } from "./set-api-format";
|
||||
|
||||
@@ -3,14 +3,21 @@ import { Request, Response } from "express";
|
||||
import * as http from "http";
|
||||
import util from "util";
|
||||
import zlib from "zlib";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import { keyPool } from "../../../key-management";
|
||||
import { enqueue, trackWaitTime } from "../../queue";
|
||||
import { incrementPromptCount } from "../../auth/user-store";
|
||||
import { isCompletionRequest, writeErrorResponse } from "../common";
|
||||
import {
|
||||
incrementPromptCount,
|
||||
incrementTokenCount,
|
||||
} from "../../auth/user-store";
|
||||
import {
|
||||
getCompletionForService,
|
||||
isCompletionRequest,
|
||||
writeErrorResponse,
|
||||
} from "../common";
|
||||
import { handleStreamedResponse } from "./handle-streamed-response";
|
||||
import { logPrompt } from "./log-prompt";
|
||||
import { countTokens } from "../../../tokenization";
|
||||
|
||||
const DECODER_MAP = {
|
||||
gzip: util.promisify(zlib.gunzip),
|
||||
@@ -84,12 +91,18 @@ export const createOnProxyResHandler = (apiMiddleware: ProxyResMiddleware) => {
|
||||
if (req.isStreaming) {
|
||||
// `handleStreamedResponse` writes to the response and ends it, so
|
||||
// we can only execute middleware that doesn't write to the response.
|
||||
middlewareStack.push(trackRateLimit, incrementKeyUsage, logPrompt);
|
||||
middlewareStack.push(
|
||||
trackRateLimit,
|
||||
countResponseTokens,
|
||||
incrementUsage,
|
||||
logPrompt
|
||||
);
|
||||
} else {
|
||||
middlewareStack.push(
|
||||
trackRateLimit,
|
||||
handleUpstreamErrors,
|
||||
incrementKeyUsage,
|
||||
countResponseTokens,
|
||||
incrementUsage,
|
||||
copyHttpHeaders,
|
||||
logPrompt,
|
||||
...apiMiddleware
|
||||
@@ -394,15 +407,56 @@ function handleOpenAIRateLimitError(
|
||||
return errorPayload;
|
||||
}
|
||||
|
||||
/**
 * Records usage for a completed completion request.
 *
 * Always bumps the prompt counter of the key that served the request. When a
 * user is attached to the request, also bumps that user's prompt count and
 * adds the request's prompt + output tokens to their per-model token count.
 * Requests that are not completion requests are ignored.
 */
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
  if (!isCompletionRequest(req)) return;

  keyPool.incrementPrompt(req.key!);

  if (!req.user) return;
  incrementPromptCount(req.user.token);

  // Treat unset counts as 0 (as the quota check does) so an undefined value
  // can never turn the user's running total into NaN.
  const tokensUsed = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  incrementTokenCount(req.user.token, req.body.model, tokensUsed);
};
|
||||
|
||||
/**
 * Counts the tokens in the upstream completion and stores the result on
 * `req.outputTokens`.
 *
 * Any failure is logged and otherwise ignored, leaving `req.outputTokens` at
 * whatever value it already had.
 */
const countResponseTokens: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  _res,
  body
) => {
  // This function is prone to breaking if the upstream API makes even minor
  // changes to the response format, especially for SSE responses. If you're
  // seeing errors in this function, check the reassembled response body from
  // handleStreamedResponse to see if the upstream API has changed.
  try {
    if (typeof body !== "object") {
      throw new Error("Expected body to be an object");
    }

    const api = req.outboundApi;
    const extracted = getCompletionForService({ service: api, body });
    const counted = await countTokens({
      req,
      completion: extracted.completion,
      service: api,
    });

    req.log.debug(
      { service: api, tokens: counted, prevOutputTokens: req.outputTokens },
      `Counted tokens for completion`
    );
    if (req.debug) {
      // NOTE(review): this stores the whole tokenizer result, not just the
      // count — confirm that is what debug consumers expect.
      req.debug.completion_tokens = counted;
    }

    req.outputTokens = counted.token_count;
  } catch (error) {
    req.log.error(
      error,
      "Error while counting completion tokens; assuming `max_output_tokens`"
    );
    // req.outputTokens will already be set to `max_output_tokens` from the
    // prompt counting middleware, so we don't need to do anything here.
  }
};
|
||||
|
||||
/**
 * Passes the upstream response headers to the key pool so it can update the
 * rate-limit state of the key that served this request.
 */
const trackRateLimit: ProxyResHandlerWithBody = async (proxyRes, req) => {
  const servingKey = req.key!;
  keyPool.updateRateLimits(servingKey, proxyRes.headers);
};
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import { Request } from "express";
|
||||
import { config } from "../../../config";
|
||||
import { AIService } from "../../../key-management";
|
||||
import { logQueue } from "../../../prompt-logging";
|
||||
import { isCompletionRequest } from "../common";
|
||||
import { getCompletionForService, isCompletionRequest } from "../common";
|
||||
import { ProxyResHandlerWithBody } from ".";
|
||||
import { logger } from "../../../logger";
|
||||
|
||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||
export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
@@ -26,7 +24,7 @@ export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
|
||||
const promptPayload = getPromptForRequest(req);
|
||||
const promptFlattened = flattenMessages(promptPayload);
|
||||
const response = getResponseForService({
|
||||
const response = getCompletionForService({
|
||||
service: req.outboundApi,
|
||||
body: responseBody,
|
||||
});
|
||||
@@ -62,17 +60,3 @@ const flattenMessages = (messages: string | OaiMessage[]): string => {
|
||||
}
|
||||
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
|
||||
};
|
||||
|
||||
/**
 * Extracts the completion text and model name from an upstream response body.
 *
 * For the `anthropic` service the text is read from `body.completion` and
 * trimmed; for any other service it is read, untrimmed, from
 * `body.choices[0].message.content`.
 */
const getResponseForService = ({
  service,
  body,
}: {
  service: AIService;
  body: Record<string, any>;
}): { completion: string; model: string } => {
  const completion =
    service === "anthropic"
      ? body.completion.trim()
      : body.choices[0].message.content;
  return { completion, model: body.model };
};
|
||||
|
||||
@@ -9,6 +9,7 @@ import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
applyQuotaLimits,
|
||||
blockZoomerOrigins,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
@@ -90,6 +91,7 @@ const rewriteRequest = (
|
||||
res: http.ServerResponse
|
||||
) => {
|
||||
const rewriterPipeline = [
|
||||
applyQuotaLimits,
|
||||
addKey,
|
||||
languageFilter,
|
||||
limitCompletions,
|
||||
|
||||
Reference in New Issue
Block a user