implements language filter
This commit is contained in:
@@ -21,3 +21,5 @@ OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
# MAX_OUTPUT_TOKENS=256
|
||||
# LOG_LEVEL=info
|
||||
# LOG_PROMPTS=false
|
||||
# REJECT_DISALLOWED=false
|
||||
# REJECT_MESSAGE=This content violates /aicg/'s acceptable use policy.
|
||||
|
||||
@@ -12,6 +12,10 @@ type Config = {
|
||||
modelRateLimit: number;
|
||||
/** Max number of tokens to generate. Requests which specify a higher value will be rewritten to use this value. */
|
||||
maxOutputTokens: number;
|
||||
/** Whether requests containing disallowed characters should be rejected. */
|
||||
rejectDisallowed?: boolean;
|
||||
/** Message to return when rejecting requests. */
|
||||
rejectMessage?: string;
|
||||
/** Logging threshold. */
|
||||
logLevel?: "debug" | "info" | "warn" | "error";
|
||||
/** Whether prompts and responses should be logged. */
|
||||
@@ -24,6 +28,11 @@ export const config: Config = {
|
||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||
modelRateLimit: getEnvWithDefault("MODEL_RATE_LIMIT", 2),
|
||||
maxOutputTokens: getEnvWithDefault("MAX_OUTPUT_TOKENS", 256),
|
||||
rejectDisallowed: getEnvWithDefault("REJECT_DISALLOWED", false),
|
||||
rejectMessage: getEnvWithDefault(
|
||||
"REJECT_MESSAGE",
|
||||
"This content violates /aicg/'s acceptable use policy."
|
||||
),
|
||||
logLevel: getEnvWithDefault("LOG_LEVEL", "info"),
|
||||
logPrompts: getEnvWithDefault("LOG_PROMPTS", false),
|
||||
} as const;
|
||||
|
||||
@@ -6,6 +6,7 @@ import { handleResponse, onError } from "./common";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import {
|
||||
addKey,
|
||||
languageFilter,
|
||||
disableStream,
|
||||
finalizeBody,
|
||||
limitOutputTokens,
|
||||
@@ -18,6 +19,7 @@ const rewriteRequest = (
|
||||
) => {
|
||||
const rewriterPipeline = [
|
||||
addKey,
|
||||
languageFilter,
|
||||
disableStream,
|
||||
limitOutputTokens,
|
||||
finalizeBody,
|
||||
|
||||
@@ -3,6 +3,7 @@ import type { ClientRequest } from "http";
|
||||
import type { ProxyReqCallback } from "http-proxy";
|
||||
|
||||
export { addKey } from "./add-key";
|
||||
export { languageFilter } from "./language-filter";
|
||||
export { disableStream } from "./disable-stream";
|
||||
export { limitOutputTokens } from "./limit-output-tokens";
|
||||
export { finalizeBody } from "./finalize-body";
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import { config } from "../../config";
|
||||
import type { ExpressHttpProxyReqCallback } from ".";
|
||||
import { logger } from "../../logger";
|
||||
|
||||
const DISALLOWED_REGEX =
|
||||
/[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
|
||||
|
||||
// Our shitty free-tier will fall over if we test every single character in each
|
||||
// 15k character request ten times a second. So we'll just sample 20% of the
|
||||
// characters and hope that's enough.
|
||||
const containsDisallowedCharacters = (text: string) => {
|
||||
const sampleSize = Math.floor(text.length * 0.2);
|
||||
const sample = text
|
||||
.split("")
|
||||
.sort(() => 0.5 - Math.random())
|
||||
.slice(0, sampleSize)
|
||||
.join("");
|
||||
return DISALLOWED_REGEX.test(sample);
|
||||
};
|
||||
|
||||
/** Block requests containing too many disallowed characters. */
|
||||
export const languageFilter: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
|
||||
if (!config.rejectDisallowed) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (req.method === "POST" && req.body?.messages) {
|
||||
const combinedText = req.body.messages
|
||||
.map((m: { role: string; content: string }) => m.content)
|
||||
.join(",");
|
||||
if (containsDisallowedCharacters(combinedText)) {
|
||||
logger.warn(`Blocked request containing bad characters`);
|
||||
_proxyReq.destroy(new Error(config.rejectMessage));
|
||||
}
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user