Files
OAI-Proxy/src/proxy/middleware/response/streaming/sse-stream-adapter.ts
T

153 lines
5.2 KiB
TypeScript

import { Transform, TransformOptions } from "stream";
import { StringDecoder } from "string_decoder";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../../logger";
import { RetryableError } from "../index";
import { APIFormat } from "../../../../shared/key-management";
import StreamArray from "stream-json/streamers/StreamArray";
// Module-scoped logger, created via `logger.child` with this module's name as
// the `module` binding.
const log = logger.child({ module: "sse-stream-adapter" });
/**
 * Constructor options for `SSEStreamAdapter`: the standard `TransformOptions`
 * plus the two fields the adapter uses to choose how to decode the upstream
 * response.
 */
type SSEStreamAdapterOptions = TransformOptions & {
// Upstream response content type. The adapter compares it against
// "application/vnd.amazon.eventstream" to detect AWS binary event streams.
contentType?: string;
// Upstream API format. The adapter compares it against "google-ai" to detect
// Google AI responses.
api: APIFormat;
};
/**
 * Shape of a single decoded message as consumed from the AWS event stream
 * parser's "data" event.
 */
type AwsEventStreamMessage = {
headers: {
// "exception" messages are handled as errors by `processAwsEvent`.
":message-type": "event" | "exception";
// Compared against "throttlingException" to decide whether to retry.
":exception-type"?: string;
};
// `bytes` is base64-encoded: `processAwsEvent` decodes it with
// `Buffer.from(bytes, "base64")`.
// NOTE(review): `message` may be base64-encoded as well — this file never
// decodes it, so confirm against the AWS response format.
payload: { message?: string; bytes?: string };
};
/**
 * Transform stream that receives either text chunks, AWS binary event stream
 * chunks, or chunks of a Google AI JSON array, and emits full SSE events, each
 * terminated by a blank line (`\n\n`).
 *
 * The decoding strategy is fixed at construction time:
 * - content type `application/vnd.amazon.eventstream` -> chunks are fed to the
 *   AWS event stream `Parser`;
 * - api `google-ai` -> chunks are fed to a `stream-json` `StreamArray` parser;
 * - otherwise -> chunks are decoded as UTF-8 text and split on blank lines.
 */
export class SSEStreamAdapter extends Transform {
  private readonly isAwsStream: boolean;
  private readonly isGoogleStream: boolean;
  private parser = new Parser();
  private jsonStream = StreamArray.withParser();
  /** Tail of the text received so far that has not yet seen its terminator. */
  private partialMessage = "";
  /** Keeps multi-byte UTF-8 sequences intact across chunk boundaries. */
  private decoder = new StringDecoder("utf8");

  /**
   * @param options Standard transform options plus `contentType` and `api`,
   *   which select the decoding strategy (see class description).
   */
  constructor(options?: SSEStreamAdapterOptions) {
    super(options);
    this.isAwsStream =
      options?.contentType === "application/vnd.amazon.eventstream";
    this.isGoogleStream = options?.api === "google-ai";

    // Both helper parsers are always wired up; `_transform` only ever writes
    // to the one matching the detected stream type.
    this.parser.on("data", (data: AwsEventStreamMessage) => {
      const message = this.processAwsEvent(data);
      if (message) {
        this.push(Buffer.from(message + "\n\n"), "utf8");
      }
    });
    this.jsonStream.on("data", (data: { value: unknown }) => {
      const message = this.processGoogleValue(data.value);
      if (message) {
        this.push(Buffer.from(message + "\n\n"), "utf8");
      }
    });
  }

  /**
   * Converts one decoded AWS event stream message into SSE text.
   *
   * @param event Message produced by the AWS event stream parser.
   * @returns An `event: completion` SSE message (without the trailing blank
   *   line) carrying the base64-decoded payload, or a fake error completion
   *   for unexpected messages.
   * @throws RetryableError when AWS reports a `throttlingException`.
   */
  protected processAwsEvent(event: AwsEventStreamMessage): string | null {
    const { payload, headers } = event;
    if (headers[":message-type"] === "exception" || !payload.bytes) {
      const eventStr = JSON.stringify(event);
      // Under high load, AWS can rugpull us by returning a 200 and starting the
      // stream but then immediately sending a rate limit error as the first
      // event. My guess is some race condition in their rate limiting check
      // that occurs if two requests arrive at the same time when only one
      // concurrency slot is available.
      if (headers[":exception-type"] === "throttlingException") {
        log.warn(
          { event: eventStr },
          "AWS request throttled after streaming has already started; retrying"
        );
        throw new RetryableError("AWS request throttled mid-stream");
      } else {
        log.error(
          { event: eventStr },
          "Received unexpected AWS event stream message"
        );
        return getFakeErrorCompletion("proxy AWS error", eventStr);
      }
    } else {
      // technically this is a transformation but we don't really distinguish
      // between aws claude and anthropic claude at the APIFormat level, so
      // these will short circuit the message transformer
      return [
        "event: completion",
        `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`,
      ].join("\n");
    }
  }

  /**
   * Converts one element of the Google AI response array into SSE text.
   *
   * Google doesn't use event streams and just sends elements in an array over
   * a long-lived HTTP connection. Needs stream-json to parse the array.
   *
   * @param value One parsed array element.
   * @returns A `data:` SSE line for elements that have a `candidates` key, or
   *   a fake error completion for anything else.
   */
  protected processGoogleValue(value: unknown): string | null {
    // The `in` operator throws a TypeError on null and primitives, so make
    // sure we really have an object before probing it; anything else is
    // reported through the regular error path below.
    if (typeof value === "object" && value !== null && "candidates" in value) {
      return `data: ${JSON.stringify(value)}`;
    } else {
      log.error(
        { value },
        "Received unexpected Google AI event stream message"
      );
      return getFakeErrorCompletion(
        "proxy Google AI error",
        JSON.stringify(value)
      );
    }
  }

  /**
   * Routes an incoming chunk to the decoder matching the stream type.
   *
   * For plain text, a trailing fragment that has not yet received its
   * blank-line terminator is kept in `partialMessage` until a later chunk
   * completes it. The class defines no `_flush`, so a fragment still pending
   * when the input ends is not emitted.
   *
   * @param chunk Raw bytes from the upstream response.
   * @param _encoding Unused; chunks are handled as buffers.
   * @param callback Invoked once the chunk has been handled, with the error
   *   if handling threw.
   */
  _transform(
    chunk: Buffer,
    _encoding: BufferEncoding,
    callback: (error?: Error | null) => void
  ) {
    try {
      if (this.isAwsStream) {
        this.parser.write(chunk);
      } else if (this.isGoogleStream) {
        this.jsonStream.write(chunk);
      } else {
        // We may receive multiple (or partial) SSE messages in a single chunk,
        // so we need to buffer and emit separate stream events for full
        // messages so we can parse/transform them properly.
        const str = this.decoder.write(chunk);
        const fullMessages = (this.partialMessage + str).split(
          /\r\r|\n\n|\r\n\r\n/
        );
        this.partialMessage = fullMessages.pop() || "";
        for (const message of fullMessages) {
          // Mixing line endings will break some clients and our request queue
          // will have already sent \n for heartbeats, so we need to normalize
          // to \n.
          this.push(message.replace(/\r\n?/g, "\n") + "\n\n");
        }
      }
      callback();
    } catch (error) {
      // Handing the error to the callback is what makes the stream emit
      // "error"; emitting it manually as well would notify listeners twice.
      // The thrown value is passed through untouched so `instanceof` checks
      // (e.g. for RetryableError) keep working downstream.
      callback(error as Error);
    }
  }
}
/**
 * Builds a synthetic `completion` SSE event that reports a proxy-side
 * streaming error, so the failure is visible in the streamed output.
 *
 * @param type Short error label; used as `stop_reason` and shown in the text.
 * @param message Error details, embedded in a fenced block in the completion.
 * @returns Two lines, `event: completion` and `data: <json>`, with no trailing
 *   newline.
 */
function getFakeErrorCompletion(type: string, message: string): string {
  const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
  const fakeEvent = JSON.stringify({
    log_id: "aws-proxy-sse-message",
    stop_reason: type,
    completion:
      "\nProxy encountered an error during streaming response.\n" + content,
    truncated: false,
    stop: null,
    model: "",
  });
  // No blank-line terminator here: the stream adapter appends "\n\n" when it
  // pushes the event, and terminating twice would produce an extra empty
  // event.
  return ["event: completion", `data: ${fakeEvent}`].join("\n");
}