Implements prompt logging via Google Sheets (khanon/oai-reverse-proxy!1)

This commit is contained in:
nai-degen
2023-04-15 01:21:04 +00:00
parent a767044850
commit fc3043dad0
30 changed files with 1078 additions and 80 deletions
+45
View File
@@ -0,0 +1,45 @@
import { Key, Model, keyPool, SUPPORTED_MODELS } from "../../../key-management";
import type { ExpressHttpProxyReqCallback } from ".";
/** Add an OpenAI key from the pool to the request. */
export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
let assignedKey: Key;
// Not all clients request a particular model.
// If they request a model, just use that.
// If they don't request a model, use a GPT-4 key if there is an active one,
// otherwise use a GPT-3.5 key.
// TODO: Anthropic mode should prioritize Claude over Claude Instant.
// Each provider needs to define some priority order for their models.
if (bodyHasModel(req.body)) {
assignedKey = keyPool.get(req.body.model);
} else {
try {
assignedKey = keyPool.get("gpt-4");
} catch {
assignedKey = keyPool.get("gpt-3.5-turbo");
}
}
req.key = assignedKey;
req.log.info(
{
key: assignedKey.hash,
model: req.body?.model,
isGpt4: assignedKey.isGpt4,
},
"Assigned key to request"
);
// TODO: Requests to Anthropic models use `X-API-Key`.
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
};
function bodyHasModel(body: any): body is { model: Model } {
// Model names can have suffixes indicating the frozen release version but
// OpenAI and Anthropic will use the latest version if you omit the suffix.
const isSupportedModel = (model: string) =>
SUPPORTED_MODELS.some((supported) => model.startsWith(supported));
return typeof body?.model === "string" && isSupportedModel(body.model);
}
@@ -0,0 +1,8 @@
import type { ExpressHttpProxyReqCallback } from ".";
/** Disable token streaming as the proxy middleware doesn't support it. */
export const disableStream: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
if (req.method === "POST" && req.body && req.body.stream) {
req.body.stream = false;
}
};
@@ -0,0 +1,14 @@
import { fixRequestBody } from "http-proxy-middleware";
import type { ExpressHttpProxyReqCallback } from ".";
/** Finalize the rewritten request body. Must be the last rewriter. */
export const finalizeBody: ExpressHttpProxyReqCallback = (proxyReq, req) => {
if (["POST", "PUT", "PATCH"].includes(req.method ?? "") && req.body) {
const updatedBody = JSON.stringify(req.body);
proxyReq.setHeader("Content-Length", Buffer.byteLength(updatedBody));
(req as any).rawBody = Buffer.from(updatedBody);
// body-parser and http-proxy-middleware don't play nice together
fixRequestBody(proxyReq, req);
}
};
+16
View File
@@ -0,0 +1,16 @@
import type { Request } from "express";
import type { ClientRequest } from "http";
import type { ProxyReqCallback } from "http-proxy";
export { addKey } from "./add-key";
export { disableStream } from "./disable-stream";
export { finalizeBody } from "./finalize-body";
export { languageFilter } from "./language-filter";
export { limitCompletions } from "./limit-completions";
export { limitOutputTokens } from "./limit-output-tokens";
export { transformKoboldPayload } from "./transform-kobold-payload";
export type ExpressHttpProxyReqCallback = ProxyReqCallback<
ClientRequest,
Request
>;
@@ -0,0 +1,36 @@
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
const DISALLOWED_REGEX =
/[\u2E80-\u2E99\u2E9B-\u2EF3\u2F00-\u2FD5\u3005\u3007\u3021-\u3029\u3038-\u303B\u3400-\u4DB5\u4E00-\u9FD5\uF900-\uFA6D\uFA70-\uFAD9]/;
// Our shitty free-tier VMs will fall over if we test every single character in
// each 15k character request ten times a second. So we'll just sample 20% of
// the characters and hope that's enough.
const containsDisallowedCharacters = (text: string) => {
const sampleSize = Math.ceil(text.length * (config.rejectSampleRate || 0.2));
const sample = text
.split("")
.sort(() => 0.5 - Math.random())
.slice(0, sampleSize)
.join("");
return DISALLOWED_REGEX.test(sample);
};
/** Block requests containing too many disallowed characters. */
export const languageFilter: ExpressHttpProxyReqCallback = (_proxyReq, req) => {
if (!config.rejectDisallowed) {
return;
}
if (req.method === "POST" && req.body?.messages) {
const combinedText = req.body.messages
.map((m: { role: string; content: string }) => m.content)
.join(",");
if (containsDisallowedCharacters(combinedText)) {
logger.warn(`Blocked request containing bad characters`);
_proxyReq.destroy(new Error(config.rejectMessage));
}
}
};
@@ -0,0 +1,17 @@
import type { ExpressHttpProxyReqCallback } from ".";
const OPENAI_CHAT_COMPLETION_ENDPOINT = "/v1/chat/completions";
/** Don't allow multiple completions to be requested to prevent abuse. */
export const limitCompletions: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.method === "POST" && req.path === OPENAI_CHAT_COMPLETION_ENDPOINT) {
const originalN = req.body?.n || 1;
req.body.n = 1;
if (originalN !== req.body.n) {
req.log.warn(`Limiting completion choices from ${originalN} to 1`);
}
}
};
@@ -0,0 +1,29 @@
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
const MAX_TOKENS = config.maxOutputTokens;
/** Enforce a maximum number of tokens requested from OpenAI. */
export const limitOutputTokens: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.method === "POST" && req.body?.max_tokens) {
// convert bad or missing input to a MAX_TOKENS
if (typeof req.body.max_tokens !== "number") {
logger.warn(
`Invalid max_tokens value: ${req.body.max_tokens}. Using ${MAX_TOKENS}`
);
req.body.max_tokens = MAX_TOKENS;
}
const originalTokens = req.body.max_tokens;
req.body.max_tokens = Math.min(req.body.max_tokens, MAX_TOKENS);
if (originalTokens !== req.body.max_tokens) {
logger.warn(
`Limiting max_tokens from ${originalTokens} to ${req.body.max_tokens}`
);
}
}
};
@@ -0,0 +1,99 @@
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
// Kobold requests look like this:
// body:
// {
// prompt: "Aqua is character from Konosuba anime. Aqua is a goddess, before life in the Fantasy World, she was a goddess of water who guided humans to the afterlife. Aqua looks like young woman with beauty no human could match. Aqua has light blue hair, blue eyes, slim figure, long legs, wide hips, blue waist-long hair that is partially tied into a loop with a spherical clip. Aqua's measurements are 83-56-83 cm. Aqua's height 157cm. Aqua wears sleeveless dark-blue dress with white trimmings, extremely short dark blue miniskirt, green bow around her chest with a blue gem in the middle, detached white sleeves with blue and golden trimmings, thigh-high blue heeled boots over white stockings with blue trimmings. Aqua is very strong in water magic, but a little stupid, so she does not always use it to the place. Aqua is high-spirited, cheerful, carefree. Aqua rarely thinks about the consequences of her actions and always acts or speaks on her whims. Because very easy to taunt Aqua with jeers or lure her with praises.\n" +
// "Aqua's personality: high-spirited, likes to party, carefree, cheerful.\n" +
// 'Circumstances and context of the dialogue: Aqua is standing in the city square and is looking for new followers\n' +
// 'This is how Aqua should talk\n' +
// 'You: Hi Aqua, I heard you like to spend time in the pub.\n' +
// "Aqua: *excitedly* Oh my goodness, yes! I just love spending time at the pub! It's so much fun to talk to all the adventurers and hear about their exciting adventures! And you are?\n" +
// "You: I'm a new here and I wanted to ask for your advice.\n" +
// 'Aqua: *giggles* Oh, advice! I love giving advice! And in gratitude for that, treat me to a drink! *gives signals to the bartender*\n' +
// 'This is how Aqua should talk\n' +
// 'You: Hello\n' +
// "Aqua: *excitedly* Hello there, dear! Are you new to Axel? Don't worry, I, Aqua the goddess of water, am here to help you! Do you need any assistance? And may I say, I look simply radiant today! *strikes a pose and looks at you with puppy eyes*\n" +
// '\n' +
// 'Then the roleplay chat between You and Aqua begins.\n' +
// "Aqua: *She is in the town square of a city named Axel. It's morning on a Saturday and she suddenly notices a person who looks like they don't know what they're doing. She approaches him and speaks* \n" +
// '\n' +
// `"Are you new here? Do you need help? Don't worry! I, Aqua the Goddess of Water, shall help you! Do I look beautiful?" \n` +
// '\n' +
// '*She strikes a pose and looks at him with puppy eyes.*\n' +
// 'You: test\n' +
// 'You: test\n' +
// 'You: t\n' +
// 'You: test\n',
// use_story: false,
// use_memory: false,
// use_authors_note: false,
// use_world_info: false,
// max_context_length: 2048,
// max_length: 180,
// rep_pen: 1.1,
// rep_pen_range: 1024,
// rep_pen_slope: 0.9,
// temperature: 0.65,
// tfs: 0.9,
// top_a: 0,
// top_k: 0,
// top_p: 0.9,
// typical: 1,
// sampler_order: [
// 6, 0, 1, 2,
// 3, 4, 5
// ],
// singleline: false
// }
// OpenAI expects this body:
// { model: 'gpt-3.5-turbo', temperature: 0.65, top_p: 0.9, max_tokens: 180, messages }
// there's also a frequency_penalty but it's not clear how that maps to kobold's
// rep_pen.
// messages is an array of { role: "system" | "assistant" | "user", content: ""}
// kobold only sends us the entire prompt. we can try to split the last two
// lines into user and assistant messages, but that's not always correct. For
// now it will have to do.
/** Transforms a KoboldAI payload into an OpenAI payload. */
export const transformKoboldPayload: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
const { body } = req;
const { prompt, max_length, rep_pen, top_p, temperature } = body;
if (!max_length) {
logger.error("KoboldAI request missing max_length.");
throw new Error("You must specify a max_length parameter.");
}
const promptLines = prompt.split("\n");
// The very last line is the contentless "Assistant: " hint to the AI.
// Tavern just leaves an empty line, Agnai includes the AI's name.
const assistantHint = promptLines.pop();
// The second-to-last line is the user's prompt, generally.
const userPrompt = promptLines.pop();
const messages = [
{ role: "system", content: promptLines.join("\n") },
{ role: "user", content: userPrompt },
{ role: "assistant", content: assistantHint },
];
// Kobold doesn't select a model. If the addKey rewriter assigned us a GPT-4
// key, use that. Otherwise, use GPT-3.5-turbo.
const model = req.key!.isGpt4 ? "gpt-4" : "gpt-3.5-turbo";
const newBody = {
model,
temperature,
top_p,
frequency_penalty: rep_pen, // remove this if model turns schizo
max_tokens: max_length,
messages,
};
req.body = newBody;
};
+283
View File
@@ -0,0 +1,283 @@
import { Request, Response } from "express";
import * as http from "http";
import util from "util";
import zlib from "zlib";
import * as httpProxy from "http-proxy";
import { logger } from "../../../logger";
import { keyPool } from "../../../key-management";
import { logPrompt } from "./log-prompt";
export const QUOTA_ROUTES = ["/v1/chat/completions"];
const DECODER_MAP = {
gzip: util.promisify(zlib.gunzip),
deflate: util.promisify(zlib.inflate),
br: util.promisify(zlib.brotliDecompress),
};
const isSupportedContentEncoding = (
contentEncoding: string
): contentEncoding is keyof typeof DECODER_MAP => {
return contentEncoding in DECODER_MAP;
};
type DecodeResponseBodyHandler = (
proxyRes: http.IncomingMessage,
req: Request,
res: Response
) => Promise<string | Record<string, any>>;
export type ProxyResHandlerWithBody = (
proxyRes: http.IncomingMessage,
req: Request,
res: Response,
/**
* This will be an object if the response content-type is application/json,
* otherwise it will be a string.
*/
body: string | Record<string, any>
) => Promise<void>;
export type ProxyResMiddleware = ProxyResHandlerWithBody[];
/**
* Returns a on.proxyRes handler that executes the given middleware stack after
* the common proxy response handlers have processed the response and decoded
* the body. Custom middleware won't execute if the response is determined to
* be an error from the downstream service as the response will be taken over
* by the common error handler.
*/
export const createOnProxyResHandler = (middleware: ProxyResMiddleware) => {
return async (
proxyRes: http.IncomingMessage,
req: Request,
res: Response
) => {
let lastMiddlewareName = decodeResponseBody.name;
try {
const body = await decodeResponseBody(proxyRes, req, res);
const middlewareStack: ProxyResMiddleware = [
handleDownstreamErrors,
incrementKeyUsage,
copyHttpHeaders,
logPrompt,
...middleware,
];
for (const middleware of middlewareStack) {
lastMiddlewareName = middleware.name;
await middleware(proxyRes, req, res, body);
}
} catch (error: any) {
// downstream errors will have already been responded to
if (res.headersSent) {
return;
}
const message = `Error while executing proxy response middleware: ${lastMiddlewareName} (${error.message})`;
logger.error(
{
error: error.stack,
thrownBy: lastMiddlewareName,
key: req.key?.hash,
},
message
);
res
.status(500)
.json({ error: "Internal server error", proxy_note: message });
}
};
};
/**
* Handles the response from the downstream service and decodes the body if
* necessary. If the response is JSON, it will be parsed and returned as an
* object. Otherwise, it will be returned as a string.
* @throws {Error} Unsupported content-encoding or invalid application/json body
*/
const decodeResponseBody: DecodeResponseBodyHandler = async (
proxyRes,
req,
res
) => {
const promise = new Promise<string>((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
let body = Buffer.concat(chunks);
const contentEncoding = proxyRes.headers["content-encoding"];
if (contentEncoding) {
if (isSupportedContentEncoding(contentEncoding)) {
const decoder = DECODER_MAP[contentEncoding];
body = await decoder(body);
} else {
const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
logger.warn({ contentEncoding, key: req.key?.hash }, errorMessage);
res.status(500).json({ error: errorMessage, contentEncoding });
return reject(errorMessage);
}
}
try {
if (proxyRes.headers["content-type"]?.includes("application/json")) {
const json = JSON.parse(body.toString());
return resolve(json);
}
return resolve(body.toString());
} catch (error: any) {
const errorMessage = `Proxy received response with invalid JSON: ${error.message}`;
logger.warn({ error, key: req.key?.hash }, errorMessage);
res.status(500).json({ error: errorMessage });
return reject(errorMessage);
}
});
});
return promise;
};
// TODO: This is too specific to OpenAI's error responses, Anthropic errors
// will need a different handler.
/**
* Handles non-2xx responses from the downstream service. If the proxied
* response is an error, this will respond to the client with an error payload
* and throw an error to stop the middleware stack.
* @throws {Error} HTTP error status code from downstream service
*/
const handleDownstreamErrors: ProxyResHandlerWithBody = async (
proxyRes,
req,
res,
body
) => {
const statusCode = proxyRes.statusCode || 500;
if (statusCode < 400) {
return;
}
let errorPayload: Record<string, any>;
// Subtract 1 from available keys because if this message is being shown,
// it's because the key is about to be disabled.
const availableKeys = keyPool.available() - 1;
const tryAgainMessage = Boolean(availableKeys)
? `There are ${availableKeys} more keys available; try your request again.`
: "There are no more keys available.";
try {
if (typeof body === "object") {
errorPayload = body;
} else {
throw new Error("Received non-JSON error response from downstream.");
}
} catch (parseError: any) {
const statusMessage = proxyRes.statusMessage || "Unknown error";
// Likely Bad Gateway or Gateway Timeout from OpenAI's Cloudflare proxy
logger.warn(
{ statusCode, statusMessage, key: req.key?.hash },
parseError.message
);
const errorObject = {
statusCode,
statusMessage: proxyRes.statusMessage,
error: parseError.message,
proxy_note: `This is likely a temporary error with the downstream service.`,
};
res.status(statusCode).json(errorObject);
throw new Error(parseError.message);
}
logger.warn(
{
statusCode,
type: errorPayload.error?.code,
errorPayload,
key: req.key?.hash,
},
`Received error response from downstream. (${proxyRes.statusMessage})`
);
if (statusCode === 400) {
// Bad request (likely prompt is too long)
errorPayload.proxy_note = `OpenAI rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!);
errorPayload.proxy_note = `The OpenAI key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 429) {
// One of:
// - Quota exceeded (key is dead, disable it)
// - Rate limit exceeded (key is fine, just try again)
// - Model overloaded (their fault, just try again)
if (errorPayload.error?.type === "insufficient_quota") {
keyPool.disable(req.key!);
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
} else {
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
}
} else if (statusCode === 404) {
// Most likely model not found
if (errorPayload.error?.code === "model_not_found") {
if (req.key!.isGpt4) {
keyPool.downgradeKey(req.key?.hash);
errorPayload.proxy_note = `This key was incorrectly assigned to GPT-4. It has been downgraded to Turbo.`;
} else {
errorPayload.proxy_note = `No model was found for this key.`;
}
}
} else {
errorPayload.proxy_note = `Unrecognized error from OpenAI.`;
}
res.status(statusCode).json(errorPayload);
throw new Error(errorPayload.error?.message);
};
/** Handles errors in the request rewriter pipeline. */
export const handleInternalError: httpProxy.ErrorCallback = (
err,
_req,
res
) => {
logger.error({ error: err }, "Error in proxy request pipeline.");
(res as http.ServerResponse).writeHead(500, {
"Content-Type": "application/json",
});
res.end(
JSON.stringify({
error: {
type: "proxy_error",
message: err.message,
stack: err.stack,
proxy_note: `Reverse proxy encountered an error before it could reach the downstream API.`,
},
})
);
};
const incrementKeyUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (QUOTA_ROUTES.includes(req.path)) {
keyPool.incrementPrompt(req.key?.hash);
}
};
const copyHttpHeaders: ProxyResHandlerWithBody = async (
proxyRes,
_req,
res
) => {
Object.keys(proxyRes.headers).forEach((key) => {
// Omit content-encoding because we will always decode the response body
if (key === "content-encoding") {
return;
}
// We're usually using res.json() to send the response, which causes express
// to set content-length. That's not valid for chunked responses and some
// clients will reject it so we need to omit it.
if (key === "transfer-encoding") {
return;
}
res.setHeader(key, proxyRes.headers[key] as string);
});
};
@@ -0,0 +1,54 @@
import { config } from "../../../config";
import { logQueue } from "../../../prompt-logging";
import { ProxyResHandlerWithBody } from ".";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
_proxyRes,
req,
_res,
responseBody
) => {
if (!config.promptLogging) {
return;
}
if (typeof responseBody !== "object") {
throw new Error("Expected body to be an object");
}
const model = req.body.model;
const promptFlattened = flattenMessages(req.body.messages);
const response = getResponseForModel({ model, body: responseBody });
logQueue.enqueue({
model,
endpoint: req.api,
promptRaw: JSON.stringify(req.body.messages),
promptFlattened,
response,
});
};
type OaiMessage = {
role: "user" | "assistant" | "system";
content: string;
};
const flattenMessages = (messages: OaiMessage[]): string => {
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
};
const getResponseForModel = ({
model,
body,
}: {
model: string;
body: Record<string, any>;
}) => {
if (model.startsWith("claude")) {
// TODO: confirm if there is supposed to be a leading space
return body.completion.trim();
} else {
return body.choices[0].message.content;
}
};