Implement AWS Bedrock support (khanon/oai-reverse-proxy!45)

This commit is contained in:
khanon
2023-10-01 01:40:18 +00:00
parent 7e681a7bef
commit fa4bf468d2
38 changed files with 1438 additions and 410 deletions
+16 -5
View File
@@ -93,7 +93,8 @@ export function enqueue(req: Request) {
// If the request opted into streaming, we need to register a heartbeat
// handler to keep the connection alive while it waits in the queue. We
// deregister the handler when the request is dequeued.
if (req.body.stream === "true" || req.body.stream === true) {
const { stream } = req.body;
if (stream === "true" || stream === true || req.isStreaming) {
const res = req.res!;
if (!res.headersSent) {
initStreaming(req);
@@ -138,9 +139,15 @@ function getPartitionForRequest(req: Request): ModelFamily {
// There is a single request queue, but it is partitioned by model family.
// Model families are typically separated on cost/rate limit boundaries so
// they should be treated as separate queues.
const provider = req.outboundApi;
const model = (req.body.model as SupportedModel) ?? "gpt-3.5-turbo";
switch (provider) {
// Special case for AWS: it serves multiple models from different vendors, so
// it is checked before the outbound API, even though only one is supported.
if (req.service === "aws") {
return "aws-claude";
}
switch (req.outboundApi) {
case "anthropic":
return getClaudeModelFamily(model);
case "openai":
@@ -149,7 +156,7 @@ function getPartitionForRequest(req: Request): ModelFamily {
case "google-palm":
return getGooglePalmModelFamily(model);
default:
assertNever(provider);
assertNever(req.outboundApi);
}
}
@@ -198,12 +205,13 @@ function processQueue() {
// the others, because we only track one rate limit per key.
// TODO: `getLockoutPeriod` uses model names instead of model families
// TODO: genericize this
// TODO: genericize this; it's really ugly
const gpt432kLockout = keyPool.getLockoutPeriod("gpt-4-32k");
const gpt4Lockout = keyPool.getLockoutPeriod("gpt-4");
const turboLockout = keyPool.getLockoutPeriod("gpt-3.5-turbo");
const claudeLockout = keyPool.getLockoutPeriod("claude-v1");
const palmLockout = keyPool.getLockoutPeriod("text-bison-001");
const awsClaudeLockout = keyPool.getLockoutPeriod("anthropic.claude-v2");
const reqs: (Request | undefined)[] = [];
if (gpt432kLockout === 0) {
@@ -221,6 +229,9 @@ function processQueue() {
if (palmLockout === 0) {
reqs.push(dequeue("bison"));
}
if (awsClaudeLockout === 0) {
reqs.push(dequeue("aws-claude"));
}
reqs.filter(Boolean).forEach((req) => {
if (req?.proceed) {