Implement AWS Bedrock support (khanon/oai-reverse-proxy!45)

This commit is contained in:
khanon
2023-10-01 01:40:18 +00:00
parent 7e681a7bef
commit fa4bf468d2
38 changed files with 1438 additions and 410 deletions
+120 -70
View File
@@ -12,7 +12,7 @@ import {
incrementTokenCount,
} from "../../../shared/users/user-store";
import {
getCompletionForService,
getCompletionFromBody,
isCompletionRequest,
writeErrorResponse,
} from "../common";
@@ -173,7 +173,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
throw err;
}
const promise = new Promise<string>((resolve, reject) => {
return new Promise<string>((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
@@ -209,10 +209,14 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
}
});
});
return promise;
};
// TODO: This is too specific to OpenAI's error responses.
type ProxiedErrorPayload = {
error?: Record<string, any>;
message?: string;
proxy_note?: string;
};
/**
* Handles non-2xx responses from the upstream service. If the proxied response
* is an error, this will respond to the client with an error payload and throw
@@ -233,27 +237,19 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
return;
}
let errorPayload: Record<string, any>;
// Subtract 1 from available keys because if this message is being shown,
// it's because the key is about to be disabled.
const availableKeys = keyPool.available(req.outboundApi) - 1;
const tryAgainMessage = Boolean(availableKeys)
? `There are ${availableKeys} more keys available; try your request again.`
: "There are no more keys available.";
let errorPayload: ProxiedErrorPayload;
const tryAgainMessage = keyPool.available(req.body?.model)
? `There may be more keys available for this model; try again in a few seconds.`
: "There are no more keys available for this model.";
try {
if (typeof body === "object") {
errorPayload = body;
} else {
throw new Error("Received unparsable error response from upstream.");
}
} catch (parseError: any) {
assertJsonResponse(body);
errorPayload = body;
} catch (parseError) {
// Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
const hash = req.key?.hash;
const statusMessage = proxyRes.statusMessage || "Unknown error";
// Likely Bad Gateway or Gateway Timeout from reverse proxy/load balancer
logger.warn(
{ statusCode, statusMessage, key: req.key?.hash },
parseError.message
);
logger.warn({ statusCode, statusMessage, key: hash }, parseError.message);
const errorObject = {
statusCode,
@@ -265,53 +261,76 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
throw new Error(parseError.message);
}
const errorType =
errorPayload.error?.code ||
errorPayload.error?.type ||
getAwsErrorType(proxyRes.headers["x-amzn-errortype"]);
logger.warn(
{
statusCode,
type: errorPayload.error?.code,
errorPayload,
key: req.key?.hash,
},
{ statusCode, type: errorType, errorPayload, key: req.key?.hash },
`Received error response from upstream. (${proxyRes.statusMessage})`
);
const service = req.key!.service;
if (service === "aws") {
// Try to standardize the error format for AWS
errorPayload.error = { message: errorPayload.message, type: errorType };
delete errorPayload.message;
}
if (statusCode === 400) {
// Bad request (likely prompt is too long)
switch (req.outboundApi) {
// Bad request. For OpenAI, this is usually due to prompt length.
// For Anthropic, this is usually due to missing preamble.
switch (service) {
case "openai":
case "openai-text":
case "google-palm":
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
break;
case "anthropic":
case "aws":
maybeHandleMissingPreambleError(req, errorPayload);
break;
default:
assertNever(req.outboundApi);
assertNever(service);
}
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 403) {
// Amazon is the only service that returns 403.
switch (errorType) {
case "UnrecognizedClientException":
// Key is invalid.
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
break;
case "AccessDeniedException":
errorPayload.proxy_note = `API key doesn't have access to the requested resource.`;
break;
default:
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
}
} else if (statusCode === 429) {
switch (req.outboundApi) {
switch (service) {
case "openai":
case "openai-text":
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
break;
case "anthropic":
handleAnthropicRateLimitError(req, errorPayload);
break;
case "aws":
handleAwsRateLimitError(req, errorPayload);
break;
case "google-palm":
throw new Error("Rate limit handling not implemented for PaLM");
default:
assertNever(req.outboundApi);
assertNever(service);
}
} else if (statusCode === 404) {
// Most likely model not found
switch (req.outboundApi) {
switch (service) {
case "openai":
case "openai-text":
if (errorPayload.error?.code === "model_not_found") {
const requestedModel = req.body.model;
const modelFamily = getOpenAIModelFamily(requestedModel);
@@ -328,8 +347,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "google-palm":
errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
break;
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
default:
assertNever(req.outboundApi);
assertNever(service);
}
} else {
errorPayload.proxy_note = `Unrecognized error from upstream service.`;
@@ -368,7 +390,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
*/
function maybeHandleMissingPreambleError(
req: Request,
errorPayload: Record<string, any>
errorPayload: ProxiedErrorPayload
) {
if (
errorPayload.error?.type === "invalid_request_error" &&
@@ -388,7 +410,7 @@ function maybeHandleMissingPreambleError(
function handleAnthropicRateLimitError(
req: Request,
errorPayload: Record<string, any>
errorPayload: ProxiedErrorPayload
) {
if (errorPayload.error?.type === "rate_limit_error") {
keyPool.markRateLimited(req.key!);
@@ -399,35 +421,55 @@ function handleAnthropicRateLimitError(
}
}
function handleAwsRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
const errorType = errorPayload.error?.type;
switch (errorType) {
case "ThrottlingException":
keyPool.markRateLimited(req.key!);
reenqueueRequest(req);
throw new RetryableError("AWS rate-limited request re-enqueued.");
case "ModelNotReadyException":
errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`;
break;
default:
errorPayload.proxy_note = `Unrecognized rate limit error from AWS. (${errorType})`;
}
}
function handleOpenAIRateLimitError(
req: Request,
tryAgainMessage: string,
errorPayload: Record<string, any>
errorPayload: ProxiedErrorPayload
): Record<string, any> {
const type = errorPayload.error?.type;
if (type === "insufficient_quota") {
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
} else if (type === "access_terminated") {
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
} else if (type === "billing_not_active") {
// Billing is not active (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`;
} else if (type === "requests" || type === "tokens") {
// Per-minute request or token rate limit is exceeded, which we can retry
keyPool.markRateLimited(req.key!);
// I'm aware this is confusing -- throwing this class of error will cause
// the proxy response handler to return without terminating the request,
// so that it can be placed back in the queue.
reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
} else {
// OpenAI probably overloaded
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
switch (type) {
case "insufficient_quota":
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
break;
case "access_terminated":
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
break;
case "billing_not_active":
// Key valid but account billing is delinquent
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. ${tryAgainMessage}`;
break;
case "requests":
case "tokens":
// Per-minute request or token rate limit is exceeded, which we can retry
keyPool.markRateLimited(req.key!);
reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
default:
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
break;
}
return errorPayload;
}
@@ -455,12 +497,9 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
// seeing errors in this function, check the reassembled response body from
// handleStreamedResponse to see if the upstream API has changed.
try {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
assertJsonResponse(body);
const service = req.outboundApi;
const { completion } = getCompletionForService({ req, service, body });
const completion = getCompletionFromBody(req, body);
const tokens = await countTokens({ req, completion, service });
req.log.debug(
@@ -473,7 +512,7 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
req.outputTokens = tokens.token_count;
} catch (error) {
req.log.error(
req.log.warn(
error,
"Error while counting completion tokens; assuming `max_output_tokens`"
);
@@ -505,3 +544,14 @@ const copyHttpHeaders: ProxyResHandlerWithBody = async (
res.setHeader(key, proxyRes.headers[key] as string);
});
};
function getAwsErrorType(header: string | string[] | undefined) {
const val = String(header).match(/^(\w+):?/)?.[1];
return val || String(header);
}
function assertJsonResponse(body: any): asserts body is Record<string, any> {
if (typeof body !== "object") {
throw new Error("Expected response to be an object");
}
}