Replaces the eventstream library to (hopefully) fix interrupted AWS streams
This commit is contained in:
@@ -0,0 +1,97 @@
|
||||
import { Transform, TransformOptions } from "stream";
|
||||
import {
|
||||
EventStreamCodec,
|
||||
Message,
|
||||
MessageDecoderStream,
|
||||
} from "@smithy/eventstream-codec";
|
||||
import { fromUtf8, toUtf8 } from "@smithy/util-utf8";
|
||||
import { logger } from "../../../../logger";
|
||||
|
||||
const log = logger.child({ module: "aws-eventstream-decoder" });
|
||||
|
||||
/**
|
||||
* Consumes an HTTP response stream and transforms it into a decoded stream of
|
||||
* AWS vnd.amazon.eventstream messages.
|
||||
*
|
||||
* The AWS library uses async iterators, so this class needs to act as a bridge
|
||||
* between the async generator and the Node stream API for downstream consumers.
|
||||
*/
|
||||
export class AWSEventStreamDecoder extends Transform {
|
||||
private readonly decoder: EventStreamCodec;
|
||||
private messageStream: MessageDecoderStream | null = null;
|
||||
private queue: Uint8Array[] = [];
|
||||
private resolveChunk: ((value: Uint8Array | null) => void) | null = null;
|
||||
|
||||
constructor(options?: TransformOptions) {
|
||||
super({ ...options, objectMode: true });
|
||||
this.decoder = new EventStreamCodec(toUtf8, fromUtf8);
|
||||
this.setupStream();
|
||||
}
|
||||
|
||||
protected enqueueChunk(chunk: Uint8Array) {
|
||||
if (this.resolveChunk) {
|
||||
this.resolveChunk(chunk);
|
||||
this.resolveChunk = null;
|
||||
} else {
|
||||
this.queue.push(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
protected dequeueChunk(): Promise<Uint8Array | null> {
|
||||
if (this.queue.length > 0) {
|
||||
return Promise.resolve(this.queue.shift()!);
|
||||
}
|
||||
return new Promise((resolve) => (this.resolveChunk = resolve));
|
||||
}
|
||||
|
||||
protected setupStream() {
|
||||
const that = this;
|
||||
|
||||
// This generator wraps the response stream (via the chunk queue) in an
|
||||
// async iterable that can be consumed by the Amazon EventStream library.
|
||||
const inputGenerator = (async function* () {
|
||||
while (true) {
|
||||
const chunk = await that.dequeueChunk();
|
||||
if (chunk === null) break;
|
||||
yield chunk;
|
||||
}
|
||||
log.debug("Input stream generator finished");
|
||||
});
|
||||
|
||||
// MessageDecoderStream is an async iterator that consumes chunks from
|
||||
// inputGenerator and yields fully decoded individual messages.
|
||||
this.messageStream = new MessageDecoderStream({
|
||||
decoder: this.decoder,
|
||||
inputStream: inputGenerator(),
|
||||
});
|
||||
|
||||
// Start the generator and push messages downstream as they are decoded.
|
||||
let lastMessage: Message | null = null;
|
||||
(async function () {
|
||||
try {
|
||||
log.debug("Starting generator");
|
||||
for await (const message of that.messageStream!) {
|
||||
lastMessage = message;
|
||||
that.push(message);
|
||||
}
|
||||
that.push(null);
|
||||
} catch (err) {
|
||||
log.error({ err, lastMessage }, "Error decoding eventstream message");
|
||||
that.emit("error", err);
|
||||
}
|
||||
})();
|
||||
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: string, callback: () => void) {
|
||||
this.enqueueChunk(chunk);
|
||||
callback();
|
||||
}
|
||||
|
||||
_flush(callback: () => void) {
|
||||
log.debug("Received end of stream; stopping generator");
|
||||
if (this.resolveChunk) {
|
||||
this.resolveChunk(null);
|
||||
}
|
||||
callback();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import { Transform, TransformOptions } from "stream";
|
||||
|
||||
import { StringDecoder } from "string_decoder";
|
||||
// @ts-ignore
|
||||
import { Parser } from "lifion-aws-event-stream";
|
||||
import { logger } from "../../../../logger";
|
||||
import { RetryableError } from "../index";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
import { Message } from "@smithy/eventstream-codec";
|
||||
import StreamArray from "stream-json/streamers/StreamArray";
|
||||
import { StringDecoder } from "string_decoder";
|
||||
import { logger } from "../../../../logger";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
import { makeCompletionSSE } from "../../../../shared/streaming";
|
||||
import { RetryableError } from "../index";
|
||||
import { AWSEventStreamDecoder } from "./aws-eventstream-decoder";
|
||||
|
||||
const log = logger.child({ module: "sse-stream-adapter" });
|
||||
|
||||
@@ -15,25 +14,19 @@ type SSEStreamAdapterOptions = TransformOptions & {
|
||||
contentType?: string;
|
||||
api: APIFormat;
|
||||
};
|
||||
type AwsEventStreamMessage = {
|
||||
headers: {
|
||||
":message-type": "event" | "exception";
|
||||
":exception-type"?: string;
|
||||
};
|
||||
payload: { message?: string /** base64 encoded */; bytes?: string };
|
||||
};
|
||||
|
||||
/**
|
||||
* Receives either text chunks or AWS binary event stream chunks and emits
|
||||
* full SSE events.
|
||||
* Receives either text chunks or AWS vnd.amazon.eventstream messages and emits
|
||||
* full SSE-compliant messages.
|
||||
*/
|
||||
export class SSEStreamAdapter extends Transform {
|
||||
private readonly isAwsStream;
|
||||
private readonly isGoogleStream;
|
||||
private awsParser = new Parser();
|
||||
private awsDecoder = new AWSEventStreamDecoder();
|
||||
private jsonParser = StreamArray.withParser();
|
||||
private partialMessage = "";
|
||||
private decoder = new StringDecoder("utf8");
|
||||
private textDecoder = new TextDecoder("utf8");
|
||||
|
||||
constructor(options?: SSEStreamAdapterOptions) {
|
||||
super(options);
|
||||
@@ -41,10 +34,14 @@ export class SSEStreamAdapter extends Transform {
|
||||
options?.contentType === "application/vnd.amazon.eventstream";
|
||||
this.isGoogleStream = options?.api === "google-ai";
|
||||
|
||||
this.awsParser.on("data", (data: AwsEventStreamMessage) => {
|
||||
const message = this.processAwsEvent(data);
|
||||
if (message) {
|
||||
this.push(Buffer.from(message + "\n\n"), "utf8");
|
||||
this.awsDecoder.on("data", (data: Message) => {
|
||||
try {
|
||||
const message = this.processAwsEvent(data);
|
||||
if (message) {
|
||||
this.push(Buffer.from(message + "\n\n"), "utf8");
|
||||
}
|
||||
} catch (error) {
|
||||
this.emit("error", error);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -56,39 +53,52 @@ export class SSEStreamAdapter extends Transform {
|
||||
});
|
||||
}
|
||||
|
||||
protected processAwsEvent(event: AwsEventStreamMessage): string | null {
|
||||
const { payload, headers } = event;
|
||||
if (headers[":message-type"] === "exception" || !payload.bytes) {
|
||||
const eventStr = JSON.stringify(event);
|
||||
// Under high load, AWS can rugpull us by returning a 200 and starting the
|
||||
// stream but then immediately sending a rate limit error as the first
|
||||
// event. My guess is some race condition in their rate limiting check
|
||||
// that occurs if two requests arrive at the same time when only one
|
||||
// concurrency slot is available.
|
||||
if (headers[":exception-type"] === "throttlingException") {
|
||||
log.warn(
|
||||
{ event: eventStr },
|
||||
"AWS request throttled after streaming has already started; retrying"
|
||||
);
|
||||
throw new RetryableError("AWS request throttled mid-stream");
|
||||
} else {
|
||||
log.error({ event: eventStr }, "Received bad AWS stream event");
|
||||
return makeCompletionSSE({
|
||||
format: "anthropic",
|
||||
title: "Proxy stream error",
|
||||
message:
|
||||
"The proxy received malformed or unexpected data from AWS while streaming.",
|
||||
obj: event,
|
||||
reqId: "proxy-sse-adapter-message",
|
||||
model: "",
|
||||
});
|
||||
}
|
||||
} else {
|
||||
const { bytes } = payload;
|
||||
return [
|
||||
"event: completion",
|
||||
`data: ${Buffer.from(bytes, "base64").toString("utf8")}`,
|
||||
].join("\n");
|
||||
protected processAwsEvent(message: Message): string | null {
|
||||
// Per amazon, headers and body are always present. headers is an object,
|
||||
// body is a Uint8Array, potentially zero-length.
|
||||
const { headers, body } = message;
|
||||
const eventType = headers[":event-type"]?.value;
|
||||
const messageType = headers[":message-type"]?.value;
|
||||
const contentType = headers[":content-type"]?.value;
|
||||
const exceptionType = headers[":exception-type"]?.value;
|
||||
const errorCode = headers[":error-code"]?.value;
|
||||
const bodyStr = this.textDecoder.decode(body);
|
||||
|
||||
switch (messageType) {
|
||||
case "event":
|
||||
if (contentType === "application/json" && eventType === "chunk") {
|
||||
const { bytes } = JSON.parse(bodyStr);
|
||||
const event = Buffer.from(bytes, "base64").toString("utf8");
|
||||
return ["event: completion", `data: ${event}`].join(`\n`);
|
||||
}
|
||||
// Intentional fallthrough, non-JSON events will be something very weird
|
||||
// noinspection FallThroughInSwitchStatementJS
|
||||
case "exception":
|
||||
case "error":
|
||||
const type = exceptionType || errorCode || "UnknownError";
|
||||
switch (type) {
|
||||
case "ThrottlingException":
|
||||
log.warn(
|
||||
{ message, type },
|
||||
"AWS request throttled after streaming has already started; retrying"
|
||||
);
|
||||
throw new RetryableError("AWS request throttled mid-stream");
|
||||
default:
|
||||
log.error({ message, type }, "Received bad AWS stream event");
|
||||
return makeCompletionSSE({
|
||||
format: "anthropic",
|
||||
title: "Proxy stream error",
|
||||
message:
|
||||
"The proxy received an unrecognized error from AWS while streaming.",
|
||||
obj: message,
|
||||
reqId: "proxy-sse-adapter-message",
|
||||
model: "",
|
||||
});
|
||||
}
|
||||
default:
|
||||
// Amazon says this can't ever happen...
|
||||
log.error({ message }, "Received very bad AWS stream event");
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +131,7 @@ export class SSEStreamAdapter extends Transform {
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
|
||||
try {
|
||||
if (this.isAwsStream) {
|
||||
this.awsParser.write(chunk);
|
||||
this.awsDecoder.write(chunk);
|
||||
} else if (this.isGoogleStream) {
|
||||
this.jsonParser.write(chunk);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user