refactors SSEStreamAdapter to fix leaking decoder streams
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import { pipeline } from "stream";
|
||||
import { pipeline, Transform } from "stream";
|
||||
import StreamArray from "stream-json/streamers/StreamArray";
|
||||
import { StringDecoder } from "string_decoder";
|
||||
import { promisify } from "util";
|
||||
import {
|
||||
makeCompletionSSE,
|
||||
@@ -10,7 +12,9 @@ import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
|
||||
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
|
||||
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
|
||||
import { EventAggregator } from "./streaming/event-aggregator";
|
||||
import { keyPool } from "../../../shared/key-management";
|
||||
import { APIFormat, keyPool } from "../../../shared/key-management";
|
||||
import { AWSEventStreamDecoder } from "./streaming/aws-eventstream-decoder";
|
||||
import pino from "pino";
|
||||
|
||||
const pipelineAsync = promisify(pipeline);
|
||||
|
||||
@@ -61,12 +65,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
|
||||
const prefersNativeEvents = req.inboundApi === req.outboundApi;
|
||||
const contentType = proxyRes.headers["content-type"];
|
||||
const options = { contentType, api: req.outboundApi, logger: req.log };
|
||||
|
||||
// Adapter turns some arbitrary stream (binary, JSON, etc.) into SSE events.
|
||||
const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
|
||||
// Decoder turns the raw response stream into a stream of events in some
|
||||
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
|
||||
const decoder = selectDecoderStream(options);
|
||||
// Adapter transforms the decoded events into server-sent events.
|
||||
const adapter = new SSEStreamAdapter(options);
|
||||
// Aggregator compiles all events into a single response object.
|
||||
const aggregator = new EventAggregator({ format: req.outboundApi });
|
||||
// Transformer converts events to the user's requested format.
|
||||
// Transformer converts server-sent events from one vendor's API message
|
||||
// format to another.
|
||||
const transformer = new SSEMessageTransformer({
|
||||
inputFormat: req.outboundApi,
|
||||
inputApiVersion: String(req.headers["anthropic-version"]),
|
||||
@@ -83,7 +92,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
});
|
||||
|
||||
try {
|
||||
await pipelineAsync(proxyRes, adapter, transformer);
|
||||
await pipelineAsync(proxyRes, decoder, adapter, transformer);
|
||||
req.log.debug({ key: hash }, `Finished proxying SSE stream.`);
|
||||
res.end();
|
||||
return aggregator.getFinalResponse();
|
||||
@@ -98,7 +107,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
await enqueue(req);
|
||||
} else {
|
||||
const { message, stack, lastEvent } = err;
|
||||
const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined"
|
||||
const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
|
||||
const errorEvent = makeCompletionSSE({
|
||||
format: req.inboundApi,
|
||||
title: "Proxy stream error",
|
||||
@@ -114,3 +123,29 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
throw err;
|
||||
}
|
||||
};
|
||||
|
||||
function selectDecoderStream(options: {
|
||||
api: APIFormat;
|
||||
contentType?: string;
|
||||
logger: pino.Logger;
|
||||
}): NodeJS.ReadWriteStream {
|
||||
const { api, contentType, logger } = options;
|
||||
if (contentType?.includes("application/vnd.amazon.eventstream")) {
|
||||
return new AWSEventStreamDecoder({ logger });
|
||||
} else if (api === "google-ai") {
|
||||
return StreamArray.withParser();
|
||||
} else {
|
||||
// Passthrough stream, but ensures split chunks across multi-byte characters
|
||||
// are handled correctly.
|
||||
const stringDecoder = new StringDecoder("utf8");
|
||||
return new Transform({
|
||||
readableObjectMode: true,
|
||||
writableObjectMode: false,
|
||||
transform(chunk, _encoding, callback) {
|
||||
const text = stringDecoder.write(chunk);
|
||||
if (text) this.push(text);
|
||||
callback();
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import pino from "pino";
|
||||
import { Transform, TransformOptions } from "stream";
|
||||
import {
|
||||
EventStreamCodec,
|
||||
@@ -5,9 +6,6 @@ import {
|
||||
MessageDecoderStream,
|
||||
} from "@smithy/eventstream-codec";
|
||||
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
|
||||
import { logger } from "../../../../logger";
|
||||
|
||||
const log = logger.child({ module: "aws-eventstream-decoder" });
|
||||
|
||||
/**
|
||||
* Consumes an HTTP response stream and transforms it into a decoded stream of
|
||||
@@ -21,10 +19,12 @@ export class AWSEventStreamDecoder extends Transform {
|
||||
private messageStream: MessageDecoderStream | null = null;
|
||||
private queue: Uint8Array[] = [];
|
||||
private resolveChunk: ((value: Uint8Array | null) => void) | null = null;
|
||||
private readonly log: pino.Logger;
|
||||
|
||||
constructor(options?: TransformOptions) {
|
||||
constructor(options: TransformOptions & { logger: pino.Logger }) {
|
||||
super({ ...options, objectMode: true });
|
||||
this.decoder = new EventStreamCodec(toUtf8, fromUtf8);
|
||||
this.log = options.logger.child({ module: "aws-eventstream-decoder" });
|
||||
this.setupStream();
|
||||
}
|
||||
|
||||
@@ -49,14 +49,14 @@ export class AWSEventStreamDecoder extends Transform {
|
||||
|
||||
// This generator wraps the response stream (via the chunk queue) in an
|
||||
// async iterable that can be consumed by the Amazon EventStream library.
|
||||
const inputGenerator = (async function* () {
|
||||
const inputGenerator = async function* () {
|
||||
while (true) {
|
||||
const chunk = await that.dequeueChunk();
|
||||
if (chunk === null) break;
|
||||
yield chunk;
|
||||
}
|
||||
log.debug("Input stream generator finished");
|
||||
});
|
||||
that.log.debug("Input stream generator finished");
|
||||
};
|
||||
|
||||
// MessageDecoderStream is an async iterator that consumes chunks from
|
||||
// inputGenerator and yields fully decoded individual messages.
|
||||
@@ -69,14 +69,14 @@ export class AWSEventStreamDecoder extends Transform {
|
||||
let lastMessage: Message | null = null;
|
||||
(async function () {
|
||||
try {
|
||||
log.debug("Starting generator");
|
||||
that.log.debug("Starting generator");
|
||||
for await (const message of that.messageStream!) {
|
||||
lastMessage = message;
|
||||
that.push(message);
|
||||
}
|
||||
that.push(null);
|
||||
} catch (err) {
|
||||
log.error({ err, lastMessage }, "Error decoding eventstream message");
|
||||
that.log.error({ err, lastMessage }, "Error decoding eventstream message");
|
||||
that.emit("error", err);
|
||||
}
|
||||
})();
|
||||
@@ -88,10 +88,18 @@ export class AWSEventStreamDecoder extends Transform {
|
||||
}
|
||||
|
||||
_flush(callback: () => void) {
|
||||
log.debug("Received end of stream; stopping generator");
|
||||
this.log.debug("Received end of stream; stopping generator");
|
||||
if (this.resolveChunk) {
|
||||
this.resolveChunk(null);
|
||||
}
|
||||
callback();
|
||||
}
|
||||
|
||||
_destroy(err: Error | null, callback: (error: Error | null) => void) {
|
||||
this.log.debug("Destroying stream");
|
||||
if (this.resolveChunk) {
|
||||
this.resolveChunk(null);
|
||||
}
|
||||
callback(err);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +1,39 @@
|
||||
import pino from "pino";
|
||||
import { Transform, TransformOptions } from "stream";
|
||||
import { Message } from "@smithy/eventstream-codec";
|
||||
import StreamArray from "stream-json/streamers/StreamArray";
|
||||
import { StringDecoder } from "string_decoder";
|
||||
import { logger } from "../../../../logger";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
import { makeCompletionSSE } from "../../../../shared/streaming";
|
||||
import { RetryableError } from "../index";
|
||||
import { AWSEventStreamDecoder } from "./aws-eventstream-decoder";
|
||||
|
||||
const log = logger.child({ module: "sse-stream-adapter" });
|
||||
|
||||
type SSEStreamAdapterOptions = TransformOptions & {
|
||||
contentType?: string;
|
||||
api: APIFormat;
|
||||
logger: pino.Logger;
|
||||
};
|
||||
|
||||
/**
|
||||
* Receives either text chunks or AWS vnd.amazon.eventstream messages and emits
|
||||
* full SSE-compliant messages.
|
||||
* Receives a stream of events in a variety of formats and transforms them into
|
||||
* Server-Sent Events.
|
||||
*
|
||||
* This is an object-mode stream, so it expects to receive objects and will emit
|
||||
* strings.
|
||||
*/
|
||||
export class SSEStreamAdapter extends Transform {
|
||||
private readonly isAwsStream;
|
||||
private readonly isGoogleStream;
|
||||
private awsDecoder = new AWSEventStreamDecoder();
|
||||
private jsonParser = StreamArray.withParser();
|
||||
private partialMessage = "";
|
||||
private decoder = new StringDecoder("utf8");
|
||||
private textDecoder = new TextDecoder("utf8");
|
||||
private log: pino.Logger;
|
||||
|
||||
constructor(options?: SSEStreamAdapterOptions) {
|
||||
super(options);
|
||||
constructor(options: SSEStreamAdapterOptions) {
|
||||
super({ ...options, objectMode: true });
|
||||
this.isAwsStream =
|
||||
options?.contentType === "application/vnd.amazon.eventstream";
|
||||
this.isGoogleStream = options?.api === "google-ai";
|
||||
|
||||
this.awsDecoder.on("data", (data: Message) => {
|
||||
try {
|
||||
const message = this.processAwsEvent(data);
|
||||
if (message) {
|
||||
this.push(Buffer.from(message + "\n\n"), "utf8");
|
||||
}
|
||||
} catch (error) {
|
||||
this.emit("error", error);
|
||||
}
|
||||
});
|
||||
|
||||
this.jsonParser.on("data", (data: { value: any }) => {
|
||||
const message = this.processGoogleValue(data.value);
|
||||
if (message) {
|
||||
this.push(Buffer.from(message + "\n\n"), "utf8");
|
||||
}
|
||||
});
|
||||
this.log = options.logger.child({ module: "sse-stream-adapter" });
|
||||
}
|
||||
|
||||
protected processAwsEvent(message: Message): string | null {
|
||||
protected processAwsMessage(message: Message): string | null {
|
||||
// Per amazon, headers and body are always present. headers is an object,
|
||||
// body is a Uint8Array, potentially zero-length.
|
||||
const { headers, body } = message;
|
||||
@@ -78,13 +58,13 @@ export class SSEStreamAdapter extends Transform {
|
||||
const type = exceptionType || errorCode || "UnknownError";
|
||||
switch (type) {
|
||||
case "ThrottlingException":
|
||||
log.warn(
|
||||
this.log.warn(
|
||||
{ message, type },
|
||||
"AWS request throttled after streaming has already started; retrying"
|
||||
);
|
||||
throw new RetryableError("AWS request throttled mid-stream");
|
||||
default:
|
||||
log.error({ message, type }, "Received bad AWS stream event");
|
||||
this.log.error({ message, type }, "Received bad AWS stream event");
|
||||
return makeCompletionSSE({
|
||||
format: "anthropic",
|
||||
title: "Proxy stream error",
|
||||
@@ -97,20 +77,20 @@ export class SSEStreamAdapter extends Transform {
|
||||
}
|
||||
default:
|
||||
// Amazon says this can't ever happen...
|
||||
log.error({ message }, "Received very bad AWS stream event");
|
||||
this.log.error({ message }, "Received very bad AWS stream event");
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Processes an incoming array element from the Google AI JSON stream. */
|
||||
protected processGoogleValue(value: any): string | null {
|
||||
protected processGoogleObject(value: any): string | null {
|
||||
try {
|
||||
const candidates = value.candidates ?? [{}];
|
||||
const hasParts = candidates[0].content?.parts?.length > 0;
|
||||
if (hasParts) {
|
||||
return `data: ${JSON.stringify(value)}`;
|
||||
} else {
|
||||
log.error({ event: value }, "Received bad Google AI event");
|
||||
this.log.error({ event: value }, "Received bad Google AI event");
|
||||
return `data: ${makeCompletionSSE({
|
||||
format: "google-ai",
|
||||
title: "Proxy stream error",
|
||||
@@ -124,23 +104,23 @@ export class SSEStreamAdapter extends Transform {
|
||||
} catch (error) {
|
||||
error.lastEvent = value;
|
||||
this.emit("error", error);
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
|
||||
_transform(data: any, _enc: string, callback: (err?: Error | null) => void) {
|
||||
try {
|
||||
if (this.isAwsStream) {
|
||||
this.awsDecoder.write(chunk);
|
||||
// `data` is a Message object
|
||||
const message = this.processAwsMessage(data);
|
||||
if (message) this.push(message + "\n\n");
|
||||
} else if (this.isGoogleStream) {
|
||||
this.jsonParser.write(chunk);
|
||||
// `data` is an element from the Google AI JSON stream
|
||||
const message = this.processGoogleObject(data);
|
||||
if (message) this.push(message + "\n\n");
|
||||
} else {
|
||||
// We may receive multiple (or partial) SSE messages in a single chunk,
|
||||
// so we need to buffer and emit separate stream events for full
|
||||
// messages so we can parse/transform them properly.
|
||||
const str = this.decoder.write(chunk);
|
||||
|
||||
const fullMessages = (this.partialMessage + str).split(
|
||||
// `data` is a string, but possibly only a partial message
|
||||
const fullMessages = (this.partialMessage + data).split(
|
||||
/\r\r|\n\n|\r\n\r\n/
|
||||
);
|
||||
this.partialMessage = fullMessages.pop() || "";
|
||||
@@ -154,7 +134,7 @@ export class SSEStreamAdapter extends Transform {
|
||||
}
|
||||
callback();
|
||||
} catch (error) {
|
||||
error.lastEvent = chunk?.toString();
|
||||
error.lastEvent = data?.toString();
|
||||
this.emit("error", error);
|
||||
callback(error);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user