Refactor SSEStreamAdapter to fix leaking decoder streams

This commit is contained in:
nai-degen
2024-02-04 18:38:06 -06:00
parent 98cea2da02
commit 40240601f5
3 changed files with 88 additions and 65 deletions
@@ -1,4 +1,6 @@
import { pipeline } from "stream";
import { pipeline, Transform } from "stream";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { promisify } from "util";
import {
makeCompletionSSE,
@@ -10,7 +12,9 @@ import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { EventAggregator } from "./streaming/event-aggregator";
import { keyPool } from "../../../shared/key-management";
import { APIFormat, keyPool } from "../../../shared/key-management";
import { AWSEventStreamDecoder } from "./streaming/aws-eventstream-decoder";
import pino from "pino";
const pipelineAsync = promisify(pipeline);
@@ -61,12 +65,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
const prefersNativeEvents = req.inboundApi === req.outboundApi;
const contentType = proxyRes.headers["content-type"];
const options = { contentType, api: req.outboundApi, logger: req.log };
// Adapter turns some arbitrary stream (binary, JSON, etc.) into SSE events.
const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
// Decoder turns the raw response stream into a stream of events in some
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
const decoder = selectDecoderStream(options);
// Adapter transforms the decoded events into server-sent events.
const adapter = new SSEStreamAdapter(options);
// Aggregator compiles all events into a single response object.
const aggregator = new EventAggregator({ format: req.outboundApi });
// Transformer converts events to the user's requested format.
// Transformer converts server-sent events from one vendor's API message
// format to another.
const transformer = new SSEMessageTransformer({
inputFormat: req.outboundApi,
inputApiVersion: String(req.headers["anthropic-version"]),
@@ -83,7 +92,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
});
try {
await pipelineAsync(proxyRes, adapter, transformer);
await pipelineAsync(proxyRes, decoder, adapter, transformer);
req.log.debug({ key: hash }, `Finished proxying SSE stream.`);
res.end();
return aggregator.getFinalResponse();
@@ -98,7 +107,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
await enqueue(req);
} else {
const { message, stack, lastEvent } = err;
const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined"
const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
const errorEvent = makeCompletionSSE({
format: req.inboundApi,
title: "Proxy stream error",
@@ -114,3 +123,29 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
throw err;
}
};
/**
 * Selects a decoder stream appropriate for the upstream response format.
 * - AWS `application/vnd.amazon.eventstream` bodies get an AWSEventStreamDecoder.
 * - Google AI responses are a streamed JSON array; stream-json's StreamArray
 *   parser emits each array element as an object.
 * - Everything else is passed through as text, re-chunked on UTF-8 boundaries.
 *
 * @param options.api - Outbound API format of the proxied request.
 * @param options.contentType - Upstream response `Content-Type`, if present.
 * @param options.logger - Request-scoped logger handed to the AWS decoder.
 * @returns A duplex stream that consumes the raw response body and emits
 *   decoded events (object mode) for the SSE adapter.
 */
function selectDecoderStream(options: {
  api: APIFormat;
  contentType?: string;
  logger: pino.Logger;
}): NodeJS.ReadWriteStream {
  const { api, contentType, logger } = options;
  if (contentType?.includes("application/vnd.amazon.eventstream")) {
    return new AWSEventStreamDecoder({ logger });
  } else if (api === "google-ai") {
    return StreamArray.withParser();
  } else {
    // Passthrough stream, but ensures split chunks across multi-byte characters
    // are handled correctly.
    const stringDecoder = new StringDecoder("utf8");
    return new Transform({
      readableObjectMode: true,
      writableObjectMode: false,
      transform(chunk, _encoding, callback) {
        const text = stringDecoder.write(chunk);
        if (text) this.push(text);
        callback();
      },
      flush(callback) {
        // Emit any bytes StringDecoder buffered from an incomplete multi-byte
        // sequence so they are not silently dropped at end-of-stream.
        const remainder = stringDecoder.end();
        if (remainder) this.push(remainder);
        callback();
      },
    });
  }
}
@@ -1,3 +1,4 @@
import pino from "pino";
import { Transform, TransformOptions } from "stream";
import {
EventStreamCodec,
@@ -5,9 +6,6 @@ import {
MessageDecoderStream,
} from "@smithy/eventstream-codec";
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
import { logger } from "../../../../logger";
const log = logger.child({ module: "aws-eventstream-decoder" });
/**
* Consumes an HTTP response stream and transforms it into a decoded stream of
@@ -21,10 +19,12 @@ export class AWSEventStreamDecoder extends Transform {
private messageStream: MessageDecoderStream | null = null;
private queue: Uint8Array[] = [];
private resolveChunk: ((value: Uint8Array | null) => void) | null = null;
private readonly log: pino.Logger;
constructor(options?: TransformOptions) {
constructor(options: TransformOptions & { logger: pino.Logger }) {
super({ ...options, objectMode: true });
this.decoder = new EventStreamCodec(toUtf8, fromUtf8);
this.log = options.logger.child({ module: "aws-eventstream-decoder" });
this.setupStream();
}
@@ -49,14 +49,14 @@ export class AWSEventStreamDecoder extends Transform {
// This generator wraps the response stream (via the chunk queue) in an
// async iterable that can be consumed by the Amazon EventStream library.
const inputGenerator = (async function* () {
const inputGenerator = async function* () {
while (true) {
const chunk = await that.dequeueChunk();
if (chunk === null) break;
yield chunk;
}
log.debug("Input stream generator finished");
});
that.log.debug("Input stream generator finished");
};
// MessageDecoderStream is an async iterator that consumes chunks from
// inputGenerator and yields fully decoded individual messages.
@@ -69,14 +69,14 @@ export class AWSEventStreamDecoder extends Transform {
let lastMessage: Message | null = null;
(async function () {
try {
log.debug("Starting generator");
that.log.debug("Starting generator");
for await (const message of that.messageStream!) {
lastMessage = message;
that.push(message);
}
that.push(null);
} catch (err) {
log.error({ err, lastMessage }, "Error decoding eventstream message");
that.log.error({ err, lastMessage }, "Error decoding eventstream message");
that.emit("error", err);
}
})();
@@ -88,10 +88,18 @@ export class AWSEventStreamDecoder extends Transform {
}
_flush(callback: () => void) {
log.debug("Received end of stream; stopping generator");
this.log.debug("Received end of stream; stopping generator");
if (this.resolveChunk) {
this.resolveChunk(null);
}
callback();
}
/**
 * Tears down the stream, waking the pending chunk waiter (if one exists) so
 * the input generator can terminate instead of waiting forever.
 */
_destroy(err: Error | null, callback: (error: Error | null) => void) {
  this.log.debug("Destroying stream");
  // Optional call: only invoked when a dequeue is currently blocked.
  this.resolveChunk?.(null);
  callback(err);
}
}
@@ -1,59 +1,39 @@
import pino from "pino";
import { Transform, TransformOptions } from "stream";
import { Message } from "@smithy/eventstream-codec";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { logger } from "../../../../logger";
import { APIFormat } from "../../../../shared/key-management";
import { makeCompletionSSE } from "../../../../shared/streaming";
import { RetryableError } from "../index";
import { AWSEventStreamDecoder } from "./aws-eventstream-decoder";
const log = logger.child({ module: "sse-stream-adapter" });
type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;
api: APIFormat;
logger: pino.Logger;
};
/**
* Receives either text chunks or AWS vnd.amazon.eventstream messages and emits
* full SSE-compliant messages.
* Receives a stream of events in a variety of formats and transforms them into
* Server-Sent Events.
*
* This is an object-mode stream, so it expects to receive objects and will emit
* strings.
*/
export class SSEStreamAdapter extends Transform {
private readonly isAwsStream;
private readonly isGoogleStream;
private awsDecoder = new AWSEventStreamDecoder();
private jsonParser = StreamArray.withParser();
private partialMessage = "";
private decoder = new StringDecoder("utf8");
private textDecoder = new TextDecoder("utf8");
private log: pino.Logger;
constructor(options?: SSEStreamAdapterOptions) {
super(options);
constructor(options: SSEStreamAdapterOptions) {
super({ ...options, objectMode: true });
this.isAwsStream =
options?.contentType === "application/vnd.amazon.eventstream";
this.isGoogleStream = options?.api === "google-ai";
this.awsDecoder.on("data", (data: Message) => {
try {
const message = this.processAwsEvent(data);
if (message) {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
} catch (error) {
this.emit("error", error);
}
});
this.jsonParser.on("data", (data: { value: any }) => {
const message = this.processGoogleValue(data.value);
if (message) {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
});
this.log = options.logger.child({ module: "sse-stream-adapter" });
}
protected processAwsEvent(message: Message): string | null {
protected processAwsMessage(message: Message): string | null {
// Per amazon, headers and body are always present. headers is an object,
// body is a Uint8Array, potentially zero-length.
const { headers, body } = message;
@@ -78,13 +58,13 @@ export class SSEStreamAdapter extends Transform {
const type = exceptionType || errorCode || "UnknownError";
switch (type) {
case "ThrottlingException":
log.warn(
this.log.warn(
{ message, type },
"AWS request throttled after streaming has already started; retrying"
);
throw new RetryableError("AWS request throttled mid-stream");
default:
log.error({ message, type }, "Received bad AWS stream event");
this.log.error({ message, type }, "Received bad AWS stream event");
return makeCompletionSSE({
format: "anthropic",
title: "Proxy stream error",
@@ -97,20 +77,20 @@ export class SSEStreamAdapter extends Transform {
}
default:
// Amazon says this can't ever happen...
log.error({ message }, "Received very bad AWS stream event");
this.log.error({ message }, "Received very bad AWS stream event");
return null;
}
}
/** Processes an incoming array element from the Google AI JSON stream. */
protected processGoogleValue(value: any): string | null {
protected processGoogleObject(value: any): string | null {
try {
const candidates = value.candidates ?? [{}];
const hasParts = candidates[0].content?.parts?.length > 0;
if (hasParts) {
return `data: ${JSON.stringify(value)}`;
} else {
log.error({ event: value }, "Received bad Google AI event");
this.log.error({ event: value }, "Received bad Google AI event");
return `data: ${makeCompletionSSE({
format: "google-ai",
title: "Proxy stream error",
@@ -124,23 +104,23 @@ export class SSEStreamAdapter extends Transform {
} catch (error) {
error.lastEvent = value;
this.emit("error", error);
return null;
}
return null;
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
_transform(data: any, _enc: string, callback: (err?: Error | null) => void) {
try {
if (this.isAwsStream) {
this.awsDecoder.write(chunk);
// `data` is a Message object
const message = this.processAwsMessage(data);
if (message) this.push(message + "\n\n");
} else if (this.isGoogleStream) {
this.jsonParser.write(chunk);
// `data` is an element from the Google AI JSON stream
const message = this.processGoogleObject(data);
if (message) this.push(message + "\n\n");
} else {
// We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit separate stream events for full
// messages so we can parse/transform them properly.
const str = this.decoder.write(chunk);
const fullMessages = (this.partialMessage + str).split(
// `data` is a string, but possibly only a partial message
const fullMessages = (this.partialMessage + data).split(
/\r\r|\n\n|\r\n\r\n/
);
this.partialMessage = fullMessages.pop() || "";
@@ -154,7 +134,7 @@ export class SSEStreamAdapter extends Transform {
}
callback();
} catch (error) {
error.lastEvent = chunk?.toString();
error.lastEvent = data?.toString();
this.emit("error", error);
callback(error);
}