extracts Risu auth into new middleware so queue can use it too

This commit is contained in:
nai-degen
2023-07-22 13:48:02 -05:00
parent b8534dafae
commit e2bd8a6b86
5 changed files with 93 additions and 54 deletions
+64
View File
@@ -0,0 +1,64 @@
/**
* Authenticates RisuAI.xyz users using a special x-risu-tk header provided by
* RisuAI.xyz. This lets us rate limit and limit queue concurrency properly,
* since otherwise RisuAI.xyz users share the same IP address and can't be
* distinguished.
* Contributors: @kwaroran
*/
import axios from "axios";
import { Request, Response, NextFunction } from "express";
const RISUAI_TOKEN_CHECKER_URL = "https://sv.risuai.xyz/public/api/checktoken";
const validRisuTokens = new Set<string>();
let lastFailedRisuTokenCheck = 0;
export async function checkRisuToken(
req: Request,
_res: Response,
next: NextFunction
) {
let header = req.header("x-risu-tk") || null;
if (!header) {
return next();
}
const timeSinceLastFailedCheck = Date.now() - lastFailedRisuTokenCheck;
if (timeSinceLastFailedCheck < 60 * 1000) {
req.log.warn(
{ timeSinceLastFailedCheck },
"Skipping RisuAI token check due to recent failed check"
);
return next();
}
try {
if (!validRisuTokens.has(header)) {
req.log.info("Authenticating new RisuAI token");
const validCheck = await axios.post<{ vaild: boolean }>(
RISUAI_TOKEN_CHECKER_URL,
{ token: header },
{ headers: { "Content-Type": "application/json" } }
);
if (!validCheck.data.vaild) {
req.log.warn("Invalid RisuAI token; using IP instead");
} else {
req.log.info("RisuAI token authenticated");
validRisuTokens.add(header);
req.risuToken = header;
}
} else {
req.log.debug("RisuAI token already known");
req.risuToken = header;
}
} catch (err) {
lastFailedRisuTokenCheck = Date.now();
req.log.warn(
{ error: err.message },
"Error authenticating RisuAI token; using IP instead"
);
}
next();
}
+22 -12
View File
@@ -34,24 +34,34 @@ const AGNAI_CONCURRENCY_LIMIT = 15;
/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = 1;
const sameIpPredicate = (incoming: Request) => (queued: Request) =>
queued.ip === incoming.ip;
/**
* Returns a unique identifier for a request. This is used to determine if a
* request is already in the queue.
* This can be (in order of preference):
* - user token assigned by the proxy operator
* - x-risu-tk header, if the request is from RisuAI.xyz
* - IP address
*/
function getIdentifier(req: Request) {
if (req.user) {
return req.user.token;
}
if (req.risuToken) {
return req.risuToken;
}
return req.ip;
}
const sameUserPredicate = (incoming: Request) => (queued: Request) => {
const incomingUser = incoming.user ?? { token: incoming.ip };
const queuedUser = queued.user ?? { token: queued.ip };
return queuedUser.token === incomingUser.token;
const queuedId = getIdentifier(queued);
const incomingId = getIdentifier(incoming);
return queuedId === incomingId;
};
export function enqueue(req: Request) {
let enqueuedRequestCount = 0;
const enqueuedRequestCount = queue.filter(sameUserPredicate(req)).length;
let isGuest = req.user?.token === undefined;
if (isGuest) {
enqueuedRequestCount = queue.filter(sameIpPredicate(req)).length;
} else {
enqueuedRequestCount = queue.filter(sameUserPredicate(req)).length;
}
// All Agnai.chat requests come from the same IP, so we allow them to have
// more spots in the queue. Can't make it unlimited because people will
// intentionally abuse it.
+3 -42
View File
@@ -1,15 +1,13 @@
import axios from "axios";
import { Request, Response, NextFunction } from "express";
import { config } from "../config";
export const AGNAI_DOT_CHAT_IP = "157.230.249.32";
const RISUAI_TOKEN_CHECKER_URL = "https://sv.risuai.xyz/public/api/checktoken";
const RATE_LIMIT_ENABLED = Boolean(config.modelRateLimit);
const RATE_LIMIT = Math.max(1, config.modelRateLimit);
const ONE_MINUTE_MS = 60 * 1000;
const lastAttempts = new Map<string, number[]>();
const validRisuTokens = new Set<string>();
const expireOldAttempts = (now: number) => (attempt: number) =>
attempt > now - ONE_MINUTE_MS;
@@ -73,46 +71,9 @@ export const ipLimiter = async (
return;
}
// makes risuai.xyz rate limiting by x-risu-tk header since it's shared between a lot of users.
let risuToken = req.header("x-risu-tk") || null;
if (risuToken) {
try {
// checks the token only when it is not in freshRisuTokens or bitFreshRisuTokens
if (!validRisuTokens.has(risuToken)) {
req.log.info(
{ token: `${risuToken.slice(0, 4)}...` },
"Authenticating new RisuAI token"
);
// checks the token is vaild (fresh) to prevend abuse
const validCheck = await axios.post<{ vaild: boolean }>(
RISUAI_TOKEN_CHECKER_URL,
{ token: risuToken },
{ headers: { "Content-Type": "application/json" } }
);
if (!validCheck.data.vaild) {
//if its invaild, uses ip instead
req.log.warn("Invalid RisuAI token; rate limiting by IP instead");
risuToken = null;
} else {
req.log.info("RisuAI token authenticated; adding to known tokens");
validRisuTokens.add(risuToken);
}
}
} catch (e: any) {
//if request throws error, uses ip
// TODO: probably need a backoff here to avoid spamming RisuAI
req.log.warn(
{ error: e.message },
"Error authenticating RisuAI token; rate limiting by IP instead"
);
risuToken = null;
}
}
// If user is authenticated, key rate limiting by their token. Otherwise, key
// rate limiting by their IP address. Mitigates key sharing.
const rateLimitKey = req.user?.token || risuToken || req.ip;
const rateLimitKey = req.user?.token || req.risuToken || req.ip;
const { remaining, reset } = getStatus(rateLimitKey);
res.set("X-RateLimit-Limit", config.modelRateLimit.toString());
@@ -127,7 +88,7 @@ export const ipLimiter = async (
type: "proxy_rate_limited",
message: `This proxy is rate limited to ${
config.modelRateLimit
} model requests per minute. Please try again in ${Math.ceil(
} prompts per minute. Please try again in ${Math.ceil(
tryAgainInMs / 1000
)} seconds.`,
},
+2
View File
@@ -6,6 +6,7 @@ equivalent OpenAI requests. */
import * as express from "express";
import { gatekeeper } from "./auth/gatekeeper";
import { checkRisuToken } from "./auth/check-risu-token";
import { kobold } from "./kobold";
import { openai } from "./openai";
import { anthropic } from "./anthropic";
@@ -16,6 +17,7 @@ proxyRouter.use(
express.urlencoded({ extended: true, limit: "1536kb" })
);
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);
proxyRouter.use((req, _res, next) => {
req.startTime = Date.now();
req.retryCount = 0;
+2
View File
@@ -10,6 +10,8 @@ declare global {
inboundApi: AIService | "kobold";
/** Denotes the format of the request being proxied to the API. */
outboundApi: AIService;
/** If the request comes from a RisuAI.xyz user, this is their token. */
risuToken?: string;
user?: User;
isStreaming?: boolean;
startTime: number;