uses EventStreamMarshaller from the AWS SDK to correctly reassemble eventstream messages that arrive split across multiple chunks

This commit is contained in:
nai-degen
2024-02-05 19:56:41 -06:00
parent a8fd3c7240
commit ecc804887b
6 changed files with 109 additions and 118 deletions
@@ -1,7 +1,8 @@
import { pipeline, Transform } from "stream";
import { pipeline, Transform, Readable } from "stream";
import StreamArray from "stream-json/streamers/StreamArray";
import { StringDecoder } from "string_decoder";
import { promisify } from "util";
import { APIFormat, keyPool } from "../../../shared/key-management";
import {
makeCompletionSSE,
copySseResponseHeaders,
@@ -9,12 +10,10 @@ import {
} from "../../../shared/streaming";
import { enqueue } from "../../queue";
import { decodeResponseBody, RawResponseBodyHandler, RetryableError } from ".";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { EventAggregator } from "./streaming/event-aggregator";
import { APIFormat, keyPool } from "../../../shared/key-management";
import { AWSEventStreamDecoder } from "./streaming/aws-eventstream-decoder";
import pino from "pino";
import { SSEMessageTransformer } from "./streaming/sse-message-transformer";
import { SSEStreamAdapter } from "./streaming/sse-stream-adapter";
import { viaEventStreamMarshaller } from "./streaming/via-event-stream-marshaller";
const pipelineAsync = promisify(pipeline);
@@ -65,13 +64,13 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
const prefersNativeEvents = req.inboundApi === req.outboundApi;
const contentType = proxyRes.headers["content-type"];
const options = { contentType, api: req.outboundApi, logger: req.log };
const streamOptions = { contentType, api: req.outboundApi, logger: req.log };
// Decoder turns the raw response stream into a stream of events in some
// format (text/event-stream, vnd.amazon.event-stream, streaming JSON, etc).
const decoder = selectDecoderStream(options);
const decoder = selectDecoderStream({ ...streamOptions, input: proxyRes });
// Adapter transforms the decoded events into server-sent events.
const adapter = new SSEStreamAdapter(options);
const adapter = new SSEStreamAdapter(streamOptions);
// Aggregator compiles all events into a single response object.
const aggregator = new EventAggregator({ format: req.outboundApi });
// Transformer converts server-sent events from one vendor's API message
@@ -89,6 +88,8 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
.on("data", (msg) => {
if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
aggregator.addEvent(msg);
}).on("end", () => {
req.log.debug({ key: hash }, `Finished streaming response.`);
});
try {
@@ -125,13 +126,13 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
};
function selectDecoderStream(options: {
input: Readable;
api: APIFormat;
contentType?: string;
logger: pino.Logger;
}): NodeJS.ReadWriteStream {
const { api, contentType, logger } = options;
}) {
const { api, contentType, input } = options;
if (contentType?.includes("application/vnd.amazon.eventstream")) {
return new AWSEventStreamDecoder({ logger });
return viaEventStreamMarshaller(input);
} else if (api === "google-ai") {
return StreamArray.withParser();
} else {
+2
View File
@@ -188,6 +188,8 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
if (contentEncoding) {
if (isSupportedContentEncoding(contentEncoding)) {
const decoder = DECODER_MAP[contentEncoding];
// @ts-ignore - started failing after upgrading TypeScript, don't care
// as it was never a problem.
body = await decoder(body);
} else {
const errorMessage = `Proxy received response with unsupported content-encoding: ${contentEncoding}`;
@@ -1,105 +0,0 @@
import pino from "pino";
import { Transform, TransformOptions } from "stream";
import {
EventStreamCodec,
Message,
MessageDecoderStream,
} from "@smithy/eventstream-codec";
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
/**
 * Consumes an HTTP response stream and transforms it into a decoded stream of
 * AWS vnd.amazon.eventstream messages.
 *
 * The AWS library uses async iterators, so this class needs to act as a bridge
 * between the async generator and the Node stream API for downstream consumers.
 */
export class AWSEventStreamDecoder extends Transform {
  // Codec handed to MessageDecoderStream to parse eventstream frames.
  private readonly decoder: EventStreamCodec;
  // Async-iterator pipeline that yields fully decoded messages.
  private messageStream: MessageDecoderStream | null = null;
  // Chunks received from _transform but not yet consumed by the generator.
  private queue: Uint8Array[] = [];
  // Pending resolver set while the generator awaits the next chunk; resolved
  // with `null` as the end-of-input sentinel.
  private resolveChunk: ((value: Uint8Array | null) => void) | null = null;
  private readonly log: pino.Logger;

  constructor(options: TransformOptions & { logger: pino.Logger }) {
    // objectMode: downstream receives decoded Message objects, not bytes.
    super({ ...options, objectMode: true });
    this.decoder = new EventStreamCodec(toUtf8, fromUtf8);
    this.log = options.logger.child({ module: "aws-eventstream-decoder" });
    this.setupStream();
  }

  /**
   * Hands a raw chunk to the async generator: delivered immediately if the
   * generator is currently awaiting one, otherwise buffered in the queue.
   */
  protected enqueueChunk(chunk: Uint8Array) {
    if (this.resolveChunk) {
      this.resolveChunk(chunk);
      this.resolveChunk = null;
    } else {
      this.queue.push(chunk);
    }
  }

  /**
   * Returns the next buffered chunk, or a promise that resolves when one
   * arrives. Resolves to `null` when the input stream has ended.
   */
  protected dequeueChunk(): Promise<Uint8Array | null> {
    if (this.queue.length > 0) {
      return Promise.resolve(this.queue.shift()!);
    }
    return new Promise((resolve) => (this.resolveChunk = resolve));
  }

  /**
   * Wires the chunk queue into the AWS MessageDecoderStream and starts an
   * async loop that pushes each decoded message downstream as it is produced.
   */
  protected setupStream() {
    const that = this;
    // This generator wraps the response stream (via the chunk queue) in an
    // async iterable that can be consumed by the Amazon EventStream library.
    const inputGenerator = async function* () {
      while (true) {
        const chunk = await that.dequeueChunk();
        if (chunk === null) break; // null is the end-of-stream sentinel
        yield chunk;
      }
      that.log.debug("Input stream generator finished");
    };
    // MessageDecoderStream is an async iterator that consumes chunks from
    // inputGenerator and yields fully decoded individual messages.
    this.messageStream = new MessageDecoderStream({
      decoder: this.decoder,
      inputStream: inputGenerator(),
    });
    // Start the generator and push messages downstream as they are decoded.
    // lastMessage is retained only for error-log context.
    let lastMessage: Message | null = null;
    (async function () {
      try {
        that.log.debug("Starting generator");
        for await (const message of that.messageStream!) {
          lastMessage = message;
          that.push(message);
        }
        that.push(null);
      } catch (err) {
        that.log.error({ err, lastMessage }, "Error decoding eventstream message");
        that.emit("error", err);
      }
    })();
  }

  // Transform hook: buffer incoming raw bytes for the generator to consume.
  _transform(chunk: Buffer, _encoding: string, callback: () => void) {
    this.enqueueChunk(chunk);
    callback();
  }

  // End of input: unblock a waiting generator with the null sentinel so it
  // terminates and the decode loop can push EOF downstream.
  _flush(callback: () => void) {
    this.log.debug("Received end of stream; stopping generator");
    if (this.resolveChunk) {
      this.resolveChunk(null);
    }
    callback();
  }

  // Teardown: also release a waiting generator so it does not leak.
  _destroy(err: Error | null, callback: (error: Error | null) => void) {
    this.log.debug("Destroying stream");
    if (this.resolveChunk) {
      this.resolveChunk(null);
    }
    callback(err);
  }
}
@@ -0,0 +1,65 @@
import { Duplex, Readable } from "stream";
import { EventStreamMarshaller } from "@smithy/eventstream-serde-node";
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
import { Message } from "@smithy/eventstream-codec";
/**
 * Decodes a Readable stream, such as a proxied HTTP response, into a stream of
 * Message objects using the AWS SDK's EventStreamMarshaller.
 * @param input - Raw vnd.amazon.eventstream response body to decode.
 * @returns Object-mode Duplex emitting one decoded Message per event.
 */
export function viaEventStreamMarshaller(input: Readable): Duplex {
  const config = { utf8Encoder: toUtf8, utf8Decoder: fromUtf8 };
  // The deserializer receives an object with exactly one key (the event type)
  // whose value is the Message; we only need the Message, so unwrap it here
  // and let SSEStreamAdapter turn it into an SSE stream downstream.
  // (Parameter renamed from `input` to avoid shadowing the outer Readable.)
  const eventStream = new EventStreamMarshaller(config).deserialize(
    input,
    async (wrapped: Record<string, Message>) => Object.values(wrapped)[0]
  );
  return new StreamFromIterable(eventStream);
}
// In theory, Duplex.from(eventStream) would have rendered this wrapper
// unnecessary, but I was not able to get it to work for a number of reasons and
// needed more control over the stream's lifecycle.
/**
 * Object-mode Duplex whose readable side is fed from an async iterable of
 * Messages. The writable side is a no-op sink. Honors readable backpressure
 * by pausing iteration when push() returns false.
 */
class StreamFromIterable extends Duplex {
  private readonly iterator: AsyncIterator<Message>;
  // True while an async read loop is in flight, so concurrent _read calls
  // don't drive the iterator twice at once.
  private reading = false;
  // Set once the iterator completes so we never push(null) a second time.
  private finished = false;

  constructor(asyncIterable: AsyncIterable<Message>, options = {}) {
    super({ ...options, objectMode: true });
    this.iterator = asyncIterable[Symbol.asyncIterator]();
  }

  async _read(_size: number) {
    if (this.reading || this.finished) return;
    this.reading = true;
    try {
      while (true) {
        const { value, done } = await this.iterator.next();
        if (done) {
          this.finished = true;
          this.push(null);
          break;
        }
        // Backpressure: stop iterating until _read is invoked again.
        if (!this.push(value)) break;
      }
    } catch (err) {
      // catch binds `unknown` under strict mode; destroy() requires an Error.
      this.destroy(err instanceof Error ? err : new Error(String(err)));
    } finally {
      this.reading = false;
    }
  }

  // Writable side is intentionally a sink; the marshaller consumes the input
  // stream directly.
  _write(_chunk: any, _encoding: string, callback: () => void) {
    callback();
  }

  _final(callback: () => void) {
    callback();
  }
}