Implement AWS Bedrock support (khanon/oai-reverse-proxy!45)
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { config } from "../../../config";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import { RequestPreprocessor } from ".";
|
||||
|
||||
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
||||
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
||||
const BISON_MAX_CONTEXT = 8100;
|
||||
|
||||
/**
|
||||
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
|
||||
* and outbound API format, which combined determine the size of the context.
|
||||
* If the context is too large, an error is thrown.
|
||||
* This preprocessor should run after any preprocessor that transforms the
|
||||
* request body.
|
||||
*/
|
||||
export const validateContextSize: RequestPreprocessor = async (req) => {
|
||||
assertRequestHasTokenCounts(req);
|
||||
const promptTokens = req.promptTokens;
|
||||
const outputTokens = req.outputTokens;
|
||||
const contextTokens = promptTokens + outputTokens;
|
||||
const model = req.body.model;
|
||||
|
||||
let proxyMax: number;
|
||||
switch (req.outboundApi) {
|
||||
case "openai":
|
||||
case "openai-text":
|
||||
proxyMax = OPENAI_MAX_CONTEXT;
|
||||
break;
|
||||
case "anthropic":
|
||||
proxyMax = CLAUDE_MAX_CONTEXT;
|
||||
break;
|
||||
case "google-palm":
|
||||
proxyMax = BISON_MAX_CONTEXT;
|
||||
break;
|
||||
default:
|
||||
assertNever(req.outboundApi);
|
||||
}
|
||||
proxyMax ||= Number.MAX_SAFE_INTEGER;
|
||||
|
||||
let modelMax = 0;
|
||||
if (model.match(/gpt-3.5-turbo-16k/)) {
|
||||
modelMax = 16384;
|
||||
} else if (model.match(/gpt-3.5-turbo/)) {
|
||||
modelMax = 4096;
|
||||
} else if (model.match(/gpt-4-32k/)) {
|
||||
modelMax = 32768;
|
||||
} else if (model.match(/gpt-4/)) {
|
||||
modelMax = 8192;
|
||||
} else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?(?:-100k)/)) {
|
||||
modelMax = 100000;
|
||||
} else if (model.match(/claude-(?:instant-)?v1(?:\.\d)?$/)) {
|
||||
modelMax = 9000;
|
||||
} else if (model.match(/claude-2/)) {
|
||||
modelMax = 100000;
|
||||
} else if (model.match(/^text-bison-\d{3}$/)) {
|
||||
modelMax = BISON_MAX_CONTEXT;
|
||||
} else {
|
||||
// Don't really want to throw here because I don't want to have to update
|
||||
// this ASAP every time a new model is released.
|
||||
req.log.warn({ model }, "Unknown model, using 100k token limit.");
|
||||
modelMax = 100000;
|
||||
}
|
||||
|
||||
const finalMax = Math.min(proxyMax, modelMax);
|
||||
z.number()
|
||||
.int()
|
||||
.max(finalMax, {
|
||||
message: `Your request exceeds the context size limit for this model or proxy. (max: ${finalMax} tokens, requested: ${promptTokens} prompt + ${outputTokens} output = ${contextTokens} context tokens)`,
|
||||
})
|
||||
.parse(contextTokens);
|
||||
|
||||
req.log.debug(
|
||||
{ promptTokens, outputTokens, contextTokens, modelMax, proxyMax },
|
||||
"Prompt size validated"
|
||||
);
|
||||
|
||||
req.debug.prompt_tokens = promptTokens;
|
||||
req.debug.completion_tokens = outputTokens;
|
||||
req.debug.max_model_tokens = modelMax;
|
||||
req.debug.max_proxy_tokens = proxyMax;
|
||||
};
|
||||
|
||||
function assertRequestHasTokenCounts(
|
||||
req: Request
|
||||
): asserts req is Request & { promptTokens: number; outputTokens: number } {
|
||||
z.object({
|
||||
promptTokens: z.number().int().min(1),
|
||||
outputTokens: z.number().int().min(1),
|
||||
})
|
||||
.nonstrict()
|
||||
.parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
|
||||
}
|
||||
Reference in New Issue
Block a user