Implement AWS Bedrock support (khanon/oai-reverse-proxy!45)

This commit is contained in:
khanon
2023-10-01 01:40:18 +00:00
parent 7e681a7bef
commit fa4bf468d2
38 changed files with 1438 additions and 410 deletions
@@ -3,6 +3,7 @@ import * as http from "http";
import { buildFakeSseMessage } from "../common";
import { RawResponseBodyHandler, decodeResponseBody } from ".";
import { assertNever } from "../../../shared/utils";
import { ServerSentEventStreamAdapter } from "./sse-stream-adapter";
type OpenAiChatCompletionResponse = {
id: string;
@@ -82,6 +83,11 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
return decodeResponseBody(proxyRes, req, res);
}
req.log.debug(
{ headers: proxyRes.headers, key: key.hash },
`Received SSE headers.`
);
return new Promise((resolve, reject) => {
req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
@@ -97,75 +103,50 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
res.flushHeaders();
}
const originalEvents: string[] = [];
let partialMessage = "";
const adapter = new ServerSentEventStreamAdapter({
isAwsStream:
proxyRes.headers["content-type"] ===
"application/vnd.amazon.eventstream",
});
const events: string[] = [];
let lastPosition = 0;
let eventCount = 0;
type ProxyResHandler<T extends unknown> = (...args: T[]) => void;
function withErrorHandling<T extends unknown>(fn: ProxyResHandler<T>) {
return (...args: T[]) => {
try {
fn(...args);
} catch (error) {
proxyRes.emit("error", error);
}
};
}
proxyRes.pipe(adapter);
proxyRes.on(
"data",
withErrorHandling((chunk: Buffer) => {
// We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit seperate stream events for full
// messages so we can parse/transform them properly.
const str = chunk.toString();
// Anthropic uses CRLF line endings (out-of-spec btw)
const fullMessages = (partialMessage + str).split(/\r?\n\r?\n/);
partialMessage = fullMessages.pop() || "";
for (const message of fullMessages) {
proxyRes.emit("full-sse-event", message);
}
})
);
proxyRes.on(
"full-sse-event",
withErrorHandling((data) => {
originalEvents.push(data);
adapter.on("data", (chunk: any) => {
try {
const { event, position } = transformEvent({
data,
data: chunk.toString(),
requestApi: req.inboundApi,
responseApi: req.outboundApi,
lastPosition,
index: eventCount++,
});
events.push(event);
lastPosition = position;
res.write(event + "\n\n");
})
);
} catch (err) {
adapter.emit("error", err);
}
});
proxyRes.on(
"end",
withErrorHandling(() => {
let finalBody = convertEventsToFinalResponse(originalEvents, req);
adapter.on("end", () => {
try {
req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
const finalBody = convertEventsToFinalResponse(events, req);
res.end();
resolve(finalBody);
})
);
} catch (err) {
adapter.emit("error", err);
}
});
proxyRes.on("error", (err) => {
adapter.on("error", (err) => {
req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
const fakeErrorEvent = buildFakeSseMessage(
"mid-stream-error",
err.message,
req
);
res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`);
res.write("data: [DONE]\n\n");
const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`);
res.end();
reject(err);
});
@@ -197,8 +178,6 @@ function transformEvent(params: SSETransformationArgs) {
case "openai->anthropic":
// TODO: handle new anthropic streaming format
return transformV1AnthropicEventToOpenAI(params);
case "openai->google-palm":
return transformPalmEventToOpenAI(params);
default:
throw new Error(`Unsupported streaming API transformation. ${trans}`);
}
@@ -288,11 +267,6 @@ function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) {
};
}
function transformPalmEventToOpenAI({ data }: SSETransformationArgs) {
throw new Error("PaLM streaming not yet supported.");
return { position: -1, event: data };
}
/** Copy headers, excluding ones we're already setting for the SSE response. */
function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
const toOmit = [
@@ -366,7 +340,7 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
choices: [],
// TODO: merge logprobs
};
merged = events.reduce((acc, event, i) => {
merged = events.reduce((acc, event) => {
if (!event.startsWith("data: ")) return acc;
if (event === "data: [DONE]") return acc;
@@ -390,16 +364,37 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
return merged;
}
case "anthropic": {
/*
* Full complete responses from Anthropic are conveniently just the same as
* the final SSE event before the "DONE" event, so we can reuse that
*/
const lastEvent = events[events.length - 2].toString();
const data = JSON.parse(
lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
);
const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
return final;
if (req.headers["anthropic-version"] === "2023-01-01") {
return convertAnthropicV1(events, req);
}
let merged: AnthropicCompletionResponse = {
completion: "",
stop_reason: "",
truncated: false,
stop: null,
model: req.body.model,
log_id: "",
exception: null,
}
merged = events.reduce((acc, event) => {
if (!event.startsWith("data: ")) return acc;
if (event === "data: [DONE]") return acc;
const data = JSON.parse(event.slice("data: ".length));
return {
completion: acc.completion + data.completion,
stop_reason: data.stop_reason,
truncated: data.truncated,
stop: data.stop,
log_id: data.log_id,
exception: data.exception,
model: acc.model,
};
}, merged);
return merged;
}
case "google-palm": {
throw new Error("PaLM streaming not yet supported.");
@@ -408,3 +403,16 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
assertNever(req.outboundApi);
}
}
/** Older Anthropic streaming format which sent full completion each time. */
function convertAnthropicV1(
events: string[],
req: Request
) {
const lastEvent = events[events.length - 2].toString();
const data = JSON.parse(
lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
);
const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
return final;
}
+120 -70
View File
@@ -12,7 +12,7 @@ import {
incrementTokenCount,
} from "../../../shared/users/user-store";
import {
getCompletionForService,
getCompletionFromBody,
isCompletionRequest,
writeErrorResponse,
} from "../common";
@@ -173,7 +173,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
throw err;
}
const promise = new Promise<string>((resolve, reject) => {
return new Promise<string>((resolve, reject) => {
let chunks: Buffer[] = [];
proxyRes.on("data", (chunk) => chunks.push(chunk));
proxyRes.on("end", async () => {
@@ -209,10 +209,14 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
}
});
});
return promise;
};
// TODO: This is too specific to OpenAI's error responses.
type ProxiedErrorPayload = {
error?: Record<string, any>;
message?: string;
proxy_note?: string;
};
/**
* Handles non-2xx responses from the upstream service. If the proxied response
* is an error, this will respond to the client with an error payload and throw
@@ -233,27 +237,19 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
return;
}
let errorPayload: Record<string, any>;
// Subtract 1 from available keys because if this message is being shown,
// it's because the key is about to be disabled.
const availableKeys = keyPool.available(req.outboundApi) - 1;
const tryAgainMessage = Boolean(availableKeys)
? `There are ${availableKeys} more keys available; try your request again.`
: "There are no more keys available.";
let errorPayload: ProxiedErrorPayload;
const tryAgainMessage = keyPool.available(req.body?.model)
? `There may be more keys available for this model; try again in a few seconds.`
: "There are no more keys available for this model.";
try {
if (typeof body === "object") {
errorPayload = body;
} else {
throw new Error("Received unparsable error response from upstream.");
}
} catch (parseError: any) {
assertJsonResponse(body);
errorPayload = body;
} catch (parseError) {
// Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
const hash = req.key?.hash;
const statusMessage = proxyRes.statusMessage || "Unknown error";
// Likely Bad Gateway or Gateway Timeout from reverse proxy/load balancer
logger.warn(
{ statusCode, statusMessage, key: req.key?.hash },
parseError.message
);
logger.warn({ statusCode, statusMessage, key: hash }, parseError.message);
const errorObject = {
statusCode,
@@ -265,53 +261,76 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
throw new Error(parseError.message);
}
const errorType =
errorPayload.error?.code ||
errorPayload.error?.type ||
getAwsErrorType(proxyRes.headers["x-amzn-errortype"]);
logger.warn(
{
statusCode,
type: errorPayload.error?.code,
errorPayload,
key: req.key?.hash,
},
{ statusCode, type: errorType, errorPayload, key: req.key?.hash },
`Received error response from upstream. (${proxyRes.statusMessage})`
);
const service = req.key!.service;
if (service === "aws") {
// Try to standardize the error format for AWS
errorPayload.error = { message: errorPayload.message, type: errorType };
delete errorPayload.message;
}
if (statusCode === 400) {
// Bad request (likely prompt is too long)
switch (req.outboundApi) {
// Bad request. For OpenAI, this is usually due to prompt length.
// For Anthropic, this is usually due to missing preamble.
switch (service) {
case "openai":
case "openai-text":
case "google-palm":
errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
break;
case "anthropic":
case "aws":
maybeHandleMissingPreambleError(req, errorPayload);
break;
default:
assertNever(req.outboundApi);
assertNever(service);
}
} else if (statusCode === 401) {
// Key is invalid or was revoked
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 403) {
// Amazon is the only service that returns 403.
switch (errorType) {
case "UnrecognizedClientException":
// Key is invalid.
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
break;
case "AccessDeniedException":
errorPayload.proxy_note = `API key doesn't have access to the requested resource.`;
break;
default:
errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
}
} else if (statusCode === 429) {
switch (req.outboundApi) {
switch (service) {
case "openai":
case "openai-text":
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
break;
case "anthropic":
handleAnthropicRateLimitError(req, errorPayload);
break;
case "aws":
handleAwsRateLimitError(req, errorPayload);
break;
case "google-palm":
throw new Error("Rate limit handling not implemented for PaLM");
default:
assertNever(req.outboundApi);
assertNever(service);
}
} else if (statusCode === 404) {
// Most likely model not found
switch (req.outboundApi) {
switch (service) {
case "openai":
case "openai-text":
if (errorPayload.error?.code === "model_not_found") {
const requestedModel = req.body.model;
const modelFamily = getOpenAIModelFamily(requestedModel);
@@ -328,8 +347,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "google-palm":
errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
break;
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
break;
default:
assertNever(req.outboundApi);
assertNever(service);
}
} else {
errorPayload.proxy_note = `Unrecognized error from upstream service.`;
@@ -368,7 +390,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
*/
function maybeHandleMissingPreambleError(
req: Request,
errorPayload: Record<string, any>
errorPayload: ProxiedErrorPayload
) {
if (
errorPayload.error?.type === "invalid_request_error" &&
@@ -388,7 +410,7 @@ function maybeHandleMissingPreambleError(
function handleAnthropicRateLimitError(
req: Request,
errorPayload: Record<string, any>
errorPayload: ProxiedErrorPayload
) {
if (errorPayload.error?.type === "rate_limit_error") {
keyPool.markRateLimited(req.key!);
@@ -399,35 +421,55 @@ function handleAnthropicRateLimitError(
}
}
function handleAwsRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
const errorType = errorPayload.error?.type;
switch (errorType) {
case "ThrottlingException":
keyPool.markRateLimited(req.key!);
reenqueueRequest(req);
throw new RetryableError("AWS rate-limited request re-enqueued.");
case "ModelNotReadyException":
errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`;
break;
default:
errorPayload.proxy_note = `Unrecognized rate limit error from AWS. (${errorType})`;
}
}
function handleOpenAIRateLimitError(
req: Request,
tryAgainMessage: string,
errorPayload: Record<string, any>
errorPayload: ProxiedErrorPayload
): Record<string, any> {
const type = errorPayload.error?.type;
if (type === "insufficient_quota") {
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
} else if (type === "access_terminated") {
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
} else if (type === "billing_not_active") {
// Billing is not active (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`;
} else if (type === "requests" || type === "tokens") {
// Per-minute request or token rate limit is exceeded, which we can retry
keyPool.markRateLimited(req.key!);
// I'm aware this is confusing -- throwing this class of error will cause
// the proxy response handler to return without terminating the request,
// so that it can be placed back in the queue.
reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
} else {
// OpenAI probably overloaded
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
switch (type) {
case "insufficient_quota":
// Billing quota exceeded (key is dead, disable it)
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
break;
case "access_terminated":
// Account banned (key is dead, disable it)
keyPool.disable(req.key!, "revoked");
errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
break;
case "billing_not_active":
// Key valid but account billing is delinquent
keyPool.disable(req.key!, "quota");
errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. ${tryAgainMessage}`;
break;
case "requests":
case "tokens":
// Per-minute request or token rate limit is exceeded, which we can retry
keyPool.markRateLimited(req.key!);
reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
default:
errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
break;
}
return errorPayload;
}
@@ -455,12 +497,9 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
// seeing errors in this function, check the reassembled response body from
// handleStreamedResponse to see if the upstream API has changed.
try {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
assertJsonResponse(body);
const service = req.outboundApi;
const { completion } = getCompletionForService({ req, service, body });
const completion = getCompletionFromBody(req, body);
const tokens = await countTokens({ req, completion, service });
req.log.debug(
@@ -473,7 +512,7 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
req.outputTokens = tokens.token_count;
} catch (error) {
req.log.error(
req.log.warn(
error,
"Error while counting completion tokens; assuming `max_output_tokens`"
);
@@ -505,3 +544,14 @@ const copyHttpHeaders: ProxyResHandlerWithBody = async (
res.setHeader(key, proxyRes.headers[key] as string);
});
};
function getAwsErrorType(header: string | string[] | undefined) {
const val = String(header).match(/^(\w+):?/)?.[1];
return val || String(header);
}
function assertJsonResponse(body: any): asserts body is Record<string, any> {
if (typeof body !== "object") {
throw new Error("Expected response to be an object");
}
}
+9 -7
View File
@@ -1,7 +1,11 @@
import { Request } from "express";
import { config } from "../../../config";
import { logQueue } from "../../../shared/prompt-logging";
import { getCompletionForService, isCompletionRequest } from "../common";
import {
getCompletionFromBody,
getModelFromBody,
isCompletionRequest,
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
@@ -25,17 +29,15 @@ export const logPrompt: ProxyResHandlerWithBody = async (
const promptPayload = getPromptForRequest(req);
const promptFlattened = flattenMessages(promptPayload);
const response = getCompletionForService({
service: req.outboundApi,
body: responseBody,
});
const response = getCompletionFromBody(req, responseBody);
const model = getModelFromBody(req, responseBody);
logQueue.enqueue({
endpoint: req.inboundApi,
promptRaw: JSON.stringify(promptPayload),
promptFlattened,
model: response.model, // may differ from the requested model
response: response.completion,
model,
response,
});
};
@@ -0,0 +1,85 @@
import { Transform, TransformOptions } from "stream";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../logger";
const log = logger.child({ module: "sse-stream-adapter" });
type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean };
type AwsEventStreamMessage = {
headers: { ":message-type": "event" | "exception" };
payload: { message?: string /** base64 encoded */; bytes?: string };
};
/**
* Receives either text chunks or AWS binary event stream chunks and emits
* full SSE events.
*/
export class ServerSentEventStreamAdapter extends Transform {
private readonly isAwsStream;
private parser = new Parser();
private partialMessage = "";
constructor(options?: SSEStreamAdapterOptions) {
super(options);
this.isAwsStream = options?.isAwsStream || false;
this.parser.on("data", (data: AwsEventStreamMessage) => {
const message = this.processAwsEvent(data);
if (message) {
this.push(Buffer.from(message, "utf8"));
}
});
}
processAwsEvent(event: AwsEventStreamMessage): string | null {
const { payload, headers } = event;
if (headers[":message-type"] === "exception" || !payload.bytes) {
log.error(
{ event: JSON.stringify(event) },
"Received bad streaming event from AWS"
);
const message = JSON.stringify(event);
return getFakeErrorCompletion("proxy AWS error", message);
} else {
return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`;
}
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
try {
if (this.isAwsStream) {
this.parser.write(chunk);
} else {
// We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit separate stream events for full
// messages so we can parse/transform them properly.
const str = chunk.toString("utf8");
const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
this.partialMessage = fullMessages.pop() || "";
for (const message of fullMessages) {
this.push(message);
}
}
callback();
} catch (error) {
this.emit("error", error);
callback(error);
}
}
}
function getFakeErrorCompletion(type: string, message: string) {
const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
const fakeEvent = {
log_id: "aws-proxy-sse-message",
stop_reason: type,
completion:
"\nProxy encountered an error during streaming response.\n" + content,
truncated: false,
stop: null,
model: "",
};
return `data: ${JSON.stringify(fakeEvent)}\n\n`;
}