Implement AWS Bedrock support (khanon/oai-reverse-proxy!45)
@@ -1,7 +1,6 @@
 import { Request, Response } from "express";
 import httpProxy from "http-proxy";
 import { ZodError } from "zod";
-import { APIFormat } from "../../shared/key-management";
 import { assertNever } from "../../shared/utils";
 import { QuotaExceededError } from "./request/apply-quota-limits";
 
@@ -59,7 +58,7 @@ export function writeErrorResponse(
     res.write(`data: [DONE]\n\n`);
     res.end();
   } else {
-    if (req.debug) {
+    if (req.debug && errorPayload.error) {
       errorPayload.error.proxy_tokenizer_debug_info = req.debug;
     }
     res.status(statusCode).json(errorPayload);
@@ -132,10 +131,7 @@ export function buildFakeSseMessage(
   req: Request
 ) {
   let fakeEvent;
-  const useBackticks = !type.includes("403");
-  const msgContent = useBackticks
-    ? `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`
-    : `[${type}: ${string}]`;
+  const content = `\`\`\`\n[${type}: ${string}]\n\`\`\`\n`;
 
   switch (req.inboundApi) {
     case "openai":
@@ -144,13 +140,7 @@ export function buildFakeSseMessage(
         object: "chat.completion.chunk",
         created: Date.now(),
         model: req.body?.model,
-        choices: [
-          {
-            delta: { content: msgContent },
-            index: 0,
-            finish_reason: type,
-          },
-        ],
+        choices: [{ delta: { content }, index: 0, finish_reason: type }],
       };
       break;
     case "openai-text":
@@ -159,14 +149,14 @@ export function buildFakeSseMessage(
         object: "text_completion",
         created: Date.now(),
         choices: [
-          { text: msgContent, index: 0, logprobs: null, finish_reason: type },
+          { text: content, index: 0, logprobs: null, finish_reason: type },
         ],
         model: req.body?.model,
       };
       break;
     case "anthropic":
       fakeEvent = {
-        completion: msgContent,
+        completion: content,
         stop_reason: type,
         truncated: false, // I've never seen this be true
         stop: null,
@@ -182,25 +172,42 @@ export function buildFakeSseMessage(
   return `data: ${JSON.stringify(fakeEvent)}\n\n`;
 }
 
-export function getCompletionForService({
-  service,
-  body,
-  req,
-}: {
-  service: APIFormat;
-  body: Record<string, any>;
-  req?: Request;
-}): { completion: string; model: string } {
-  switch (service) {
+export function getCompletionFromBody(req: Request, body: Record<string, any>) {
+  const format = req.outboundApi;
+  switch (format) {
     case "openai":
-      return { completion: body.choices[0].message.content, model: body.model };
+      return body.choices[0].message.content;
     case "openai-text":
-      return { completion: body.choices[0].text, model: body.model };
+      return body.choices[0].text;
     case "anthropic":
-      return { completion: body.completion.trim(), model: body.model };
+      if (!body.completion) {
+        req.log.error(
+          { body: JSON.stringify(body) },
+          "Received empty Anthropic completion"
+        );
+        return "";
+      }
+      return body.completion.trim();
     case "google-palm":
-      return { completion: body.candidates[0].output, model: req?.body.model };
+      return body.candidates[0].output;
     default:
-      assertNever(service);
+      assertNever(format);
   }
 }
 
+export function getModelFromBody(req: Request, body: Record<string, any>) {
+  const format = req.outboundApi;
+  switch (format) {
+    case "openai":
+    case "openai-text":
+      return body.model;
+    case "anthropic":
+      // Anthropic confirms the model in the response, but AWS Claude doesn't.
+      return body.model || req.body.model;
+    case "google-palm":
+      // Google doesn't confirm the model in the response.
+      return req.body.model;
+    default:
+      assertNever(format);
+  }
+}
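
Note: getCompletionFromBody and getModelFromBody replace the old getCompletionForService tuple. A minimal sketch of the intended call pattern downstream (mirroring how logPrompt and countResponseTokens use them later in this diff; `responseBody` is assumed to be the parsed upstream JSON):

// Sketch only: req.outboundApi must already be set by setApiFormat.
const completion: string = getCompletionFromBody(req, responseBody);
const model: string = getModelFromBody(req, responseBody); // may differ from the requested model
req.log.debug({ model, chars: completion.length }, "Extracted completion.");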

@@ -80,7 +80,6 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
       proxyReq.setHeader("X-API-Key", assignedKey.key);
       break;
     case "openai":
-    case "openai-text":
       const key: OpenAIKey = assignedKey as OpenAIKey;
       if (key.organizationId) {
         proxyReq.setHeader("OpenAI-Organization", key.organizationId);
@@ -94,6 +93,10 @@ export const addKey: ProxyRequestMiddleware = (proxyReq, req) => {
         `?key=${assignedKey.key}`
       );
       break;
+    case "aws":
+      throw new Error(
+        "add-key should not be used for AWS security credentials. Use sign-aws-request instead."
+      );
     default:
       assertNever(assignedKey.service);
   }

@@ -0,0 +1,48 @@
+import { RequestPreprocessor } from "./index";
+import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
+import { assertNever } from "../../../shared/utils";
+
+/**
+ * Given a request with an already-transformed body, counts the number of
+ * tokens and assigns the count to the request.
+ */
+export const countPromptTokens: RequestPreprocessor = async (req) => {
+  const service = req.outboundApi;
+  let result;
+
+  switch (service) {
+    case "openai": {
+      req.outputTokens = req.body.max_tokens;
+      const prompt: OpenAIPromptMessage[] = req.body.messages;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
+    case "openai-text": {
+      req.outputTokens = req.body.max_tokens;
+      const prompt: string = req.body.prompt;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
+    case "anthropic": {
+      req.outputTokens = req.body.max_tokens_to_sample;
+      const prompt: string = req.body.prompt;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
+    case "google-palm": {
+      req.outputTokens = req.body.maxOutputTokens;
+      const prompt: string = req.body.prompt.text;
+      result = await countTokens({ req, prompt, service });
+      break;
+    }
+    default:
+      assertNever(service);
+  }
+
+  req.promptTokens = result.token_count;
+
+  // TODO: Remove once token counting is stable
+  req.log.debug({ result: result }, "Counted prompt tokens.");
+  req.debug = req.debug ?? {};
+  req.debug = { ...req.debug, ...result };
+};
@@ -0,0 +1,26 @@
+import type { ProxyRequestMiddleware } from ".";
+
+/**
+ * For AWS requests, the body is signed earlier in the request pipeline, before
+ * the proxy middleware. This function just assigns the path and headers to the
+ * proxy request.
+ */
+export const finalizeAwsRequest: ProxyRequestMiddleware = (proxyReq, req) => {
+  if (!req.signedRequest) {
+    throw new Error("Expected req.signedRequest to be set");
+  }
+
+  // The path depends on the selected model and the assigned key's region.
+  proxyReq.path = req.signedRequest.path;
+
+  // Amazon doesn't want extra headers, so we need to remove all of them and
+  // reassign only the ones specified in the signed request.
+  proxyReq.getRawHeaderNames().forEach(proxyReq.removeHeader.bind(proxyReq));
+  Object.entries(req.signedRequest.headers).forEach(([key, value]) => {
+    proxyReq.setHeader(key, value);
+  });
+
+  // Don't use fixRequestBody here because it adds a content-length header.
+  // Amazon doesn't want that and it breaks the signature.
+  proxyReq.write(req.signedRequest.body);
+};
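
Note: the header reset above matters because SigV4 signs an exact header set; any header http-proxy adds after signing (notably content-length) would change the canonical request and invalidate the signature. Illustrative failure mode this avoids (message shape only, not an exact quote):

// If an unsigned header slipped through, Bedrock would reject the call with
// a 403 along the lines of:
//   { "message": "The request signature we calculated does not match the
//     signature you provided. ..." }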
@@ -2,14 +2,17 @@ import type { Request } from "express";
 import type { ClientRequest } from "http";
 import type { ProxyReqCallback } from "http-proxy";
 
-// Express middleware (runs before http-proxy-middleware, can be async)
-export { applyQuotaLimits } from "./apply-quota-limits";
 export {
   createPreprocessorMiddleware,
   createEmbeddingsPreprocessorMiddleware,
 } from "./preprocess";
-export { checkContextSize } from "./check-context-size";
 
+// Express middleware (runs before http-proxy-middleware, can be async)
+export { applyQuotaLimits } from "./apply-quota-limits";
+export { validateContextSize } from "./validate-context-size";
+export { countPromptTokens } from "./count-prompt-tokens";
+export { setApiFormat } from "./set-api-format";
+export { signAwsRequest } from "./sign-aws-request";
+export { transformOutboundPayload } from "./transform-outbound-payload";
 
 // HPM middleware (runs on onProxyReq, cannot be async)
@@ -17,6 +20,7 @@ export { addKey, addKeyForEmbeddingsRequest } from "./add-key";
 export { addAnthropicPreamble } from "./add-anthropic-preamble";
 export { blockZoomerOrigins } from "./block-zoomer-origins";
 export { finalizeBody } from "./finalize-body";
+export { finalizeAwsRequest } from "./finalize-aws-request";
 export { languageFilter } from "./language-filter";
 export { limitCompletions } from "./limit-completions";
 export { stripHeaders } from "./strip-headers";
@@ -50,3 +54,6 @@ export type RequestPreprocessor = (req: Request) => void | Promise<void>;
  * request queue middleware.
  */
 export type ProxyRequestMiddleware = ProxyReqCallback<ClientRequest, Request>;
+
+export const forceModel = (model: string) => (req: Request) =>
+  void (req.body.model = model);

@@ -2,24 +2,42 @@ import { RequestHandler } from "express";
 import { handleInternalError } from "../common";
 import {
   RequestPreprocessor,
-  checkContextSize,
+  validateContextSize,
+  countPromptTokens,
   setApiFormat,
   transformOutboundPayload,
 } from ".";
 
+type RequestPreprocessorOptions = {
+  /**
+   * Functions to run before the request body is transformed between API
+   * formats. Use this to change the behavior of the transformation, such as for
+   * endpoints which can accept multiple API formats.
+   */
+  beforeTransform?: RequestPreprocessor[];
+  /**
+   * Functions to run after the request body is transformed and token counts are
+   * assigned. Use this to perform validation or other actions that depend on
+   * the request body being in the final API format.
+   */
+  afterTransform?: RequestPreprocessor[];
+};
+
 /**
  * Returns a middleware function that processes the request body into the given
  * API format, and then sequentially runs the given additional preprocessors.
  */
 export const createPreprocessorMiddleware = (
   apiFormat: Parameters<typeof setApiFormat>[0],
-  additionalPreprocessors?: RequestPreprocessor[]
+  { beforeTransform, afterTransform }: RequestPreprocessorOptions = {}
 ): RequestHandler => {
   const preprocessors: RequestPreprocessor[] = [
     setApiFormat(apiFormat),
-    ...(additionalPreprocessors ?? []),
+    ...(beforeTransform ?? []),
     transformOutboundPayload,
-    checkContextSize,
+    countPromptTokens,
+    ...(afterTransform ?? []),
+    validateContextSize,
   ];
   return async (...args) => executePreprocessors(preprocessors, args);
 };
@@ -29,13 +47,10 @@ export const createPreprocessorMiddleware = (
  * OpenAI's embeddings API. Tokens are not counted because embeddings requests
  * are basically free.
  */
-export const createEmbeddingsPreprocessorMiddleware = (
-  additionalPreprocessors?: RequestPreprocessor[]
-): RequestHandler => {
+export const createEmbeddingsPreprocessorMiddleware = (): RequestHandler => {
   const preprocessors: RequestPreprocessor[] = [
-    setApiFormat({ inApi: "openai", outApi: "openai" }),
+    setApiFormat({ inApi: "openai", outApi: "openai", service: "openai" }),
     (req) => void (req.promptTokens = req.outputTokens = 0),
-    ...(additionalPreprocessors ?? []),
   ];
   return async (...args) => executePreprocessors(preprocessors, args);
 };
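
Note: a hypothetical route wiring that exercises the new beforeTransform/afterTransform options (the router setup is not part of this diff; `awsClaudeHandler`, `awsProxy`, and `router` are illustrative names, and running signAwsRequest in afterTransform is an assumption consistent with its need for a transformed body):

// Sketch only: signAwsRequest must run after the body is in Anthropic format.
const awsClaudeHandler = createPreprocessorMiddleware(
  { inApi: "anthropic", outApi: "anthropic", service: "aws" },
  { afterTransform: [signAwsRequest] }
);
router.post("/v1/complete", awsClaudeHandler, awsProxy);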
@@ -1,13 +1,15 @@
 import { Request } from "express";
-import { APIFormat } from "../../../shared/key-management";
+import { APIFormat, LLMService } from "../../../shared/key-management";
 import { RequestPreprocessor } from ".";
 
 export const setApiFormat = (api: {
   inApi: Request["inboundApi"];
   outApi: APIFormat;
+  service: LLMService;
 }): RequestPreprocessor => {
   return (req) => {
     req.inboundApi = api.inApi;
     req.outboundApi = api.outApi;
+    req.service = api.service;
   };
 };

@@ -0,0 +1,93 @@
+import express from "express";
+import { Sha256 } from "@aws-crypto/sha256-js";
+import { SignatureV4 } from "@smithy/signature-v4";
+import { HttpRequest } from "@smithy/protocol-http";
+import { keyPool } from "../../../shared/key-management";
+import { RequestPreprocessor } from ".";
+import { AnthropicV1CompleteSchema } from "./transform-outbound-payload";
+
+const AMZ_HOST =
+  process.env.AMZ_HOST || "invoke-bedrock.%REGION%.amazonaws.com";
+
+/**
+ * Signs an outgoing AWS request with the appropriate headers and modifies the
+ * request object in place to fix the path.
+ */
+export const signAwsRequest: RequestPreprocessor = async (req) => {
+  req.key = keyPool.get("anthropic.claude-v2");
+
+  const { model, stream } = req.body;
+  req.isStreaming = stream === true || stream === "true";
+
+  let preamble = req.body.prompt.startsWith("\n\nHuman:") ? "" : "\n\nHuman:";
+  req.body.prompt = preamble + req.body.prompt;
+
+  // AWS supports only a subset of Anthropic's parameters and is more strict
+  // about unknown parameters.
+  // TODO: This should happen in transform-outbound-payload.ts
+  const strippedParams = AnthropicV1CompleteSchema.pick({
+    prompt: true,
+    max_tokens_to_sample: true,
+    stop_sequences: true,
+    temperature: true,
+    top_k: true,
+    top_p: true,
+  }).parse(req.body);
+
+  const credential = getCredentialParts(req);
+  const host = AMZ_HOST.replace("%REGION%", credential.region);
+
+  // Uses the AWS SDK to sign a request, then modifies our HPM proxy request
+  // with the headers generated by the SDK.
+  const newRequest = new HttpRequest({
+    method: "POST",
+    protocol: "https:",
+    hostname: host,
+    path: `/model/${model}/invoke${stream ? "-with-response-stream" : ""}`,
+    headers: {
+      ["Host"]: host,
+      ["content-type"]: "application/json",
+    },
+    body: JSON.stringify(strippedParams),
+  });
+
+  if (stream) {
+    newRequest.headers["x-amzn-bedrock-accept"] = "application/json";
+  } else {
+    newRequest.headers["accept"] = "*/*";
+  }
+
+  req.signedRequest = await sign(newRequest, getCredentialParts(req));
+};
+
+type Credential = {
+  accessKeyId: string;
+  secretAccessKey: string;
+  region: string;
+};
+
+function getCredentialParts(req: express.Request): Credential {
+  const [accessKeyId, secretAccessKey, region] = req.key!.key.split(":");
+
+  if (!accessKeyId || !secretAccessKey || !region) {
+    req.log.error(
+      { key: req.key!.hash },
+      "AWS_CREDENTIALS isn't correctly formatted; refer to the docs"
+    );
+    throw new Error("The key assigned to this request is invalid.");
+  }
+
+  return { accessKeyId, secretAccessKey, region };
+}
+
+async function sign(request: HttpRequest, credential: Credential) {
+  const { accessKeyId, secretAccessKey, region } = credential;
+
+  const signer = new SignatureV4({
+    sha256: Sha256,
+    credentials: { accessKeyId, secretAccessKey },
+    region,
+    service: "bedrock",
+  });
+
+  return signer.sign(request);
+}
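
Note: per getCredentialParts, each AWS key is a colon-separated accessKeyId:secretAccessKey:region triple. An illustrative (fake) value and the request it would produce, assuming model "anthropic.claude-v2" and stream=true:

// AWS_CREDENTIALS="AKIAXXXXXXXXXXXXXXXX:examplesecret:us-east-1"  (fake)
// POST https://invoke-bedrock.us-east-1.amazonaws.com
//   /model/anthropic.claude-v2/invoke-with-response-stream
// signer.sign() then adds standard SigV4 headers, roughly:
//   x-amz-date: 20230901T000000Z
//   authorization: AWS4-HMAC-SHA256
//     Credential=AKIA.../20230901/us-east-1/bedrock/aws4_request,
//     SignedHeaders=..., Signature=...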
@@ -10,8 +10,8 @@ const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
 const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
 
 // https://console.anthropic.com/docs/api/reference#-v1-complete
-const AnthropicV1CompleteSchema = z.object({
-  model: z.string().regex(/^claude-/, "Model must start with 'claude-'"),
+export const AnthropicV1CompleteSchema = z.object({
+  model: z.string(),
   prompt: z.string({
     required_error:
       "No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
@@ -23,14 +23,14 @@ const AnthropicV1CompleteSchema = z.object({
   stop_sequences: z.array(z.string()).optional(),
   stream: z.boolean().optional().default(false),
   temperature: z.coerce.number().optional().default(1),
-  top_k: z.coerce.number().optional().default(-1),
-  top_p: z.coerce.number().optional().default(-1),
+  top_k: z.coerce.number().optional(),
+  top_p: z.coerce.number().optional(),
   metadata: z.any().optional(),
 });
 
 // https://platform.openai.com/docs/api-reference/chat/create
 const OpenAIV1ChatCompletionSchema = z.object({
-  model: z.string().regex(/^gpt/, "Model must start with 'gpt-'"),
+  model: z.string(),
   messages: z.array(
     z.object({
       role: z.enum(["system", "user", "assistant"]),
@@ -89,7 +89,7 @@ const OpenAIV1TextCompletionSchema = z
 
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
 const PalmV1GenerateTextSchema = z.object({
-  model: z.string().regex(/^\w+-bison-\d{3}$/),
+  model: z.string(),
   prompt: z.object({ text: z.string() }),
   temperature: z.number().optional(),
   maxOutputTokens: z.coerce
@@ -159,7 +159,7 @@ function openaiToAnthropic(req: Request) {
   const { body } = req;
   const result = OpenAIV1ChatCompletionSchema.safeParse(body);
   if (!result.success) {
-    req.log.error(
+    req.log.warn(
       { issues: result.error.issues, body },
       "Invalid OpenAI-to-Anthropic request"
     );
@@ -208,7 +208,7 @@ function openaiToOpenaiText(req: Request) {
   const { body } = req;
   const result = OpenAIV1ChatCompletionSchema.safeParse(body);
   if (!result.success) {
-    req.log.error(
+    req.log.warn(
       { issues: result.error.issues, body },
       "Invalid OpenAI-to-OpenAI-text request"
     );
@@ -227,8 +227,7 @@ function openaiToOpenaiText(req: Request) {
   stops = [...new Set(stops)];
 
   const transformed = { ...rest, prompt: prompt, stop: stops };
-  const validated = OpenAIV1TextCompletionSchema.parse(transformed);
-  return validated;
+  return OpenAIV1TextCompletionSchema.parse(transformed);
 }
 
 function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
@@ -238,7 +237,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
     model: "gpt-3.5-turbo",
   });
   if (!result.success) {
-    req.log.error(
+    req.log.warn(
       { issues: result.error.issues, body },
       "Invalid OpenAI-to-Palm request"
     );

@@ -1,9 +1,8 @@
 import { Request } from "express";
 import { z } from "zod";
 import { config } from "../../../config";
-import { OpenAIPromptMessage, countTokens } from "../../../shared/tokenization";
-import { RequestPreprocessor } from ".";
 import { assertNever } from "../../../shared/utils";
+import { RequestPreprocessor } from ".";
 
 const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
 const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
@@ -16,51 +15,7 @@ const BISON_MAX_CONTEXT = 8100;
  * This preprocessor should run after any preprocessor that transforms the
  * request body.
  */
-export const checkContextSize: RequestPreprocessor = async (req) => {
-  const service = req.outboundApi;
-  let result;
-
-  switch (service) {
-    case "openai": {
-      req.outputTokens = req.body.max_tokens;
-      const prompt: OpenAIPromptMessage[] = req.body.messages;
-      result = await countTokens({ req, prompt, service });
-      break;
-    }
-    case "openai-text": {
-      req.outputTokens = req.body.max_tokens;
-      const prompt: string = req.body.prompt;
-      result = await countTokens({ req, prompt, service });
-      break;
-    }
-    case "anthropic": {
-      req.outputTokens = req.body.max_tokens_to_sample;
-      const prompt: string = req.body.prompt;
-      result = await countTokens({ req, prompt, service });
-      break;
-    }
-    case "google-palm": {
-      req.outputTokens = req.body.maxOutputTokens;
-      const prompt: string = req.body.prompt.text;
-      result = await countTokens({ req, prompt, service });
-      break;
-    }
-    default:
-      assertNever(service);
-  }
-
-  req.promptTokens = result.token_count;
-
-  // TODO: Remove once token counting is stable
-  req.log.debug({ result: result }, "Counted prompt tokens.");
-  req.debug = req.debug ?? {};
-  req.debug = { ...req.debug, ...result };
-
-  maybeTranslateOpenAIModel(req);
-  validateContextSize(req);
-};
-
-function validateContextSize(req: Request) {
+export const validateContextSize: RequestPreprocessor = async (req) => {
   assertRequestHasTokenCounts(req);
   const promptTokens = req.promptTokens;
   const outputTokens = req.outputTokens;
@@ -125,7 +80,7 @@ function validateContextSize(req: Request) {
   req.debug.completion_tokens = outputTokens;
   req.debug.max_model_tokens = modelMax;
   req.debug.max_proxy_tokens = proxyMax;
-}
+};
 
 function assertRequestHasTokenCounts(
   req: Request
@@ -137,27 +92,3 @@ function assertRequestHasTokenCounts(
     .nonstrict()
     .parse({ promptTokens: req.promptTokens, outputTokens: req.outputTokens });
 }
-
-/**
- * For OpenAI-to-Anthropic requests, users can't specify the model, so we need
- * to pick one based on the final context size. Ideally this would happen in
- * the `transformOutboundPayload` preprocessor, but we don't have the context
- * size at that point (and need a transformed body to calculate it).
- */
-function maybeTranslateOpenAIModel(req: Request) {
-  if (req.inboundApi !== "openai" || req.outboundApi !== "anthropic") {
-    return;
-  }
-
-  const bigModel = process.env.CLAUDE_BIG_MODEL || "claude-v1-100k";
-  const contextSize = req.promptTokens! + req.outputTokens!;
-
-  if (contextSize > 8500) {
-    req.log.debug(
-      { model: bigModel, contextSize },
-      "Using Claude 100k model for OpenAI-to-Anthropic request"
-    );
-    req.body.model = bigModel;
-  }
-  // Small model is the default already set in `transformOutboundPayload`
-}
@@ -3,6 +3,7 @@ import * as http from "http";
 import { buildFakeSseMessage } from "../common";
 import { RawResponseBodyHandler, decodeResponseBody } from ".";
 import { assertNever } from "../../../shared/utils";
+import { ServerSentEventStreamAdapter } from "./sse-stream-adapter";
 
 type OpenAiChatCompletionResponse = {
   id: string;
@@ -82,6 +83,11 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
     return decodeResponseBody(proxyRes, req, res);
   }
 
+  req.log.debug(
+    { headers: proxyRes.headers, key: key.hash },
+    `Received SSE headers.`
+  );
+
   return new Promise((resolve, reject) => {
     req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
 
@@ -97,75 +103,50 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
       res.flushHeaders();
     }
 
-    const originalEvents: string[] = [];
-    let partialMessage = "";
+    const adapter = new ServerSentEventStreamAdapter({
+      isAwsStream:
+        proxyRes.headers["content-type"] ===
+        "application/vnd.amazon.eventstream",
+    });
+
+    const events: string[] = [];
     let lastPosition = 0;
+    let eventCount = 0;
 
-    type ProxyResHandler<T extends unknown> = (...args: T[]) => void;
-    function withErrorHandling<T extends unknown>(fn: ProxyResHandler<T>) {
-      return (...args: T[]) => {
-        try {
-          fn(...args);
-        } catch (error) {
-          proxyRes.emit("error", error);
-        }
-      };
-    }
+    proxyRes.pipe(adapter);
 
-    proxyRes.on(
-      "data",
-      withErrorHandling((chunk: Buffer) => {
-        // We may receive multiple (or partial) SSE messages in a single chunk,
-        // so we need to buffer and emit seperate stream events for full
-        // messages so we can parse/transform them properly.
-        const str = chunk.toString();
-
-        // Anthropic uses CRLF line endings (out-of-spec btw)
-        const fullMessages = (partialMessage + str).split(/\r?\n\r?\n/);
-        partialMessage = fullMessages.pop() || "";
-
-        for (const message of fullMessages) {
-          proxyRes.emit("full-sse-event", message);
-        }
-      })
-    );
-
-    proxyRes.on(
-      "full-sse-event",
-      withErrorHandling((data) => {
-        originalEvents.push(data);
+    adapter.on("data", (chunk: any) => {
+      try {
         const { event, position } = transformEvent({
-          data,
+          data: chunk.toString(),
           requestApi: req.inboundApi,
           responseApi: req.outboundApi,
           lastPosition,
+          index: eventCount++,
         });
+        events.push(event);
         lastPosition = position;
         res.write(event + "\n\n");
-      })
-    );
+      } catch (err) {
+        adapter.emit("error", err);
+      }
+    });
 
-    proxyRes.on(
-      "end",
-      withErrorHandling(() => {
-        let finalBody = convertEventsToFinalResponse(originalEvents, req);
+    adapter.on("end", () => {
+      try {
         req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
+        const finalBody = convertEventsToFinalResponse(events, req);
         res.end();
         resolve(finalBody);
-      })
-    );
+      } catch (err) {
+        adapter.emit("error", err);
+      }
+    });
 
-    proxyRes.on("error", (err) => {
+    adapter.on("error", (err) => {
       req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
-      const fakeErrorEvent = buildFakeSseMessage(
-        "mid-stream-error",
-        err.message,
-        req
-      );
-      res.write(`data: ${JSON.stringify(fakeErrorEvent)}\n\n`);
-      res.write("data: [DONE]\n\n");
+      const errorEvent = buildFakeSseMessage("stream-error", err.message, req);
+      res.write(`data: ${JSON.stringify(errorEvent)}\n\ndata: [DONE]\n\n`);
       res.end();
       reject(err);
     });
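
Note: the rewrite above replaces the hand-rolled "full-sse-event" re-emitting with a Transform stream, so plain SSE and AWS event-stream bodies share one pipeline. The buffering idea in isolation (a reduced sketch of the adapter's non-AWS path; names are illustrative):

import { Transform } from "stream";

class SseBuffer extends Transform {
  private partial = "";
  _transform(chunk: Buffer, _enc: BufferEncoding, cb: Function) {
    // Split on blank lines (CRLF or LF) and emit only complete messages;
    // the trailing fragment is held until the next chunk arrives.
    const messages = (this.partial + chunk.toString("utf8")).split(/\r?\n\r?\n/);
    this.partial = messages.pop() || "";
    for (const m of messages) this.push(m);
    cb();
  }
}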
@@ -197,8 +178,6 @@ function transformEvent(params: SSETransformationArgs) {
     case "openai->anthropic":
       // TODO: handle new anthropic streaming format
       return transformV1AnthropicEventToOpenAI(params);
-    case "openai->google-palm":
-      return transformPalmEventToOpenAI(params);
     default:
       throw new Error(`Unsupported streaming API transformation. ${trans}`);
   }
@@ -288,11 +267,6 @@ function transformV1AnthropicEventToOpenAI(params: SSETransformationArgs) {
   };
 }
 
-function transformPalmEventToOpenAI({ data }: SSETransformationArgs) {
-  throw new Error("PaLM streaming not yet supported.");
-  return { position: -1, event: data };
-}
-
 /** Copy headers, excluding ones we're already setting for the SSE response. */
 function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
   const toOmit = [
@@ -366,7 +340,7 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
     choices: [],
     // TODO: merge logprobs
   };
-  merged = events.reduce((acc, event, i) => {
+  merged = events.reduce((acc, event) => {
     if (!event.startsWith("data: ")) return acc;
     if (event === "data: [DONE]") return acc;
 
@@ -390,16 +364,37 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
       return merged;
     }
     case "anthropic": {
-      /*
-       * Full complete responses from Anthropic are conveniently just the same as
-       * the final SSE event before the "DONE" event, so we can reuse that
-       */
-      const lastEvent = events[events.length - 2].toString();
-      const data = JSON.parse(
-        lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
-      );
-      const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
-      return final;
+      if (req.headers["anthropic-version"] === "2023-01-01") {
+        return convertAnthropicV1(events, req);
+      }
+
+      let merged: AnthropicCompletionResponse = {
+        completion: "",
+        stop_reason: "",
+        truncated: false,
+        stop: null,
+        model: req.body.model,
+        log_id: "",
+        exception: null,
+      }
+
+      merged = events.reduce((acc, event) => {
+        if (!event.startsWith("data: ")) return acc;
+        if (event === "data: [DONE]") return acc;
+
+        const data = JSON.parse(event.slice("data: ".length));
+
+        return {
+          completion: acc.completion + data.completion,
+          stop_reason: data.stop_reason,
+          truncated: data.truncated,
+          stop: data.stop,
+          log_id: data.log_id,
+          exception: data.exception,
+          model: acc.model,
+        };
+      }, merged);
+      return merged;
     }
     case "google-palm": {
       throw new Error("PaLM streaming not yet supported.");
@@ -408,3 +403,16 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
       assertNever(req.outboundApi);
   }
 }
+
+/** Older Anthropic streaming format which sent full completion each time. */
+function convertAnthropicV1(
+  events: string[],
+  req: Request
+) {
+  const lastEvent = events[events.length - 2].toString();
+  const data = JSON.parse(
+    lastEvent.slice(lastEvent.indexOf("data: ") + "data: ".length)
+  );
+  const final: AnthropicCompletionResponse = { ...data, log_id: req.id };
+  return final;
+}

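Note: the split above exists because Anthropic's 2023-01-01 stream format sent the full completion so far in every event, while the newer format sends only the newly generated text, which is why the reducer concatenates acc.completion + data.completion. Illustrative payloads (values made up):

// anthropic-version 2023-01-01 -- the last event already holds the whole text:
//   data: {"completion": "Hello"}
//   data: {"completion": "Hello world"}
// newer format -- deltas that must be concatenated:
//   data: {"completion": "Hello"}
//   data: {"completion": " world"}
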
@@ -12,7 +12,7 @@ import {
   incrementTokenCount,
 } from "../../../shared/users/user-store";
 import {
-  getCompletionForService,
+  getCompletionFromBody,
   isCompletionRequest,
   writeErrorResponse,
 } from "../common";
@@ -173,7 +173,7 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
     throw err;
   }
 
-  const promise = new Promise<string>((resolve, reject) => {
+  return new Promise<string>((resolve, reject) => {
     let chunks: Buffer[] = [];
     proxyRes.on("data", (chunk) => chunks.push(chunk));
     proxyRes.on("end", async () => {
@@ -209,10 +209,14 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
       }
     });
   });
-  return promise;
 };
 
+// TODO: This is too specific to OpenAI's error responses.
+type ProxiedErrorPayload = {
+  error?: Record<string, any>;
+  message?: string;
+  proxy_note?: string;
+};
+
 /**
  * Handles non-2xx responses from the upstream service. If the proxied response
  * is an error, this will respond to the client with an error payload and throw
@@ -233,27 +237,19 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     return;
   }
 
-  let errorPayload: Record<string, any>;
-  // Subtract 1 from available keys because if this message is being shown,
-  // it's because the key is about to be disabled.
-  const availableKeys = keyPool.available(req.outboundApi) - 1;
-  const tryAgainMessage = Boolean(availableKeys)
-    ? `There are ${availableKeys} more keys available; try your request again.`
-    : "There are no more keys available.";
+  let errorPayload: ProxiedErrorPayload;
+  const tryAgainMessage = keyPool.available(req.body?.model)
+    ? `There may be more keys available for this model; try again in a few seconds.`
+    : "There are no more keys available for this model.";
 
   try {
-    if (typeof body === "object") {
-      errorPayload = body;
-    } else {
-      throw new Error("Received unparsable error response from upstream.");
-    }
-  } catch (parseError: any) {
-    // Likely Bad Gateway or Gateway Timeout from upstream's reverse proxy
+    assertJsonResponse(body);
+    errorPayload = body;
+  } catch (parseError) {
+    const hash = req.key?.hash;
     const statusMessage = proxyRes.statusMessage || "Unknown error";
-    logger.warn(
-      { statusCode, statusMessage, key: req.key?.hash },
-      parseError.message
-    );
+    // Likely Bad Gateway or Gateway Timeout from reverse proxy/load balancer
+    logger.warn({ statusCode, statusMessage, key: hash }, parseError.message);
 
     const errorObject = {
       statusCode,
@@ -265,53 +261,76 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
     throw new Error(parseError.message);
   }
 
+  const errorType =
+    errorPayload.error?.code ||
+    errorPayload.error?.type ||
+    getAwsErrorType(proxyRes.headers["x-amzn-errortype"]);
+
   logger.warn(
-    {
-      statusCode,
-      type: errorPayload.error?.code,
-      errorPayload,
-      key: req.key?.hash,
-    },
+    { statusCode, type: errorType, errorPayload, key: req.key?.hash },
     `Received error response from upstream. (${proxyRes.statusMessage})`
   );
 
+  const service = req.key!.service;
+  if (service === "aws") {
+    // Try to standardize the error format for AWS
+    errorPayload.error = { message: errorPayload.message, type: errorType };
+    delete errorPayload.message;
+  }
+
   if (statusCode === 400) {
-    // Bad request (likely prompt is too long)
-    switch (req.outboundApi) {
+    // Bad request. For OpenAI, this is usually due to prompt length.
+    // For Anthropic, this is usually due to missing preamble.
+    switch (service) {
      case "openai":
      case "openai-text":
      case "google-palm":
        errorPayload.proxy_note = `Upstream service rejected the request as invalid. Your prompt may be too long for ${req.body?.model}.`;
        break;
      case "anthropic":
+      case "aws":
        maybeHandleMissingPreambleError(req, errorPayload);
        break;
      default:
-        assertNever(req.outboundApi);
+        assertNever(service);
    }
  } else if (statusCode === 401) {
    // Key is invalid or was revoked
    keyPool.disable(req.key!, "revoked");
    errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
+  } else if (statusCode === 403) {
+    // Amazon is the only service that returns 403.
+    switch (errorType) {
+      case "UnrecognizedClientException":
+        // Key is invalid.
+        keyPool.disable(req.key!, "revoked");
+        errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
+        break;
+      case "AccessDeniedException":
+        errorPayload.proxy_note = `API key doesn't have access to the requested resource.`;
+        break;
+      default:
+        errorPayload.proxy_note = `Received 403 error. Key may be invalid.`;
+    }
  } else if (statusCode === 429) {
-    switch (req.outboundApi) {
+    switch (service) {
      case "openai":
      case "openai-text":
        handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
        break;
      case "anthropic":
        handleAnthropicRateLimitError(req, errorPayload);
        break;
+      case "aws":
+        handleAwsRateLimitError(req, errorPayload);
+        break;
      case "google-palm":
        throw new Error("Rate limit handling not implemented for PaLM");
      default:
-        assertNever(req.outboundApi);
+        assertNever(service);
    }
  } else if (statusCode === 404) {
    // Most likely model not found
-    switch (req.outboundApi) {
+    switch (service) {
      case "openai":
      case "openai-text":
        if (errorPayload.error?.code === "model_not_found") {
          const requestedModel = req.body.model;
          const modelFamily = getOpenAIModelFamily(requestedModel);
@@ -328,8 +347,11 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
      case "google-palm":
        errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
        break;
+      case "aws":
+        errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
+        break;
      default:
-        assertNever(req.outboundApi);
+        assertNever(service);
    }
  } else {
    errorPayload.proxy_note = `Unrecognized error from upstream service.`;
@@ -368,7 +390,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
  */
 function maybeHandleMissingPreambleError(
   req: Request,
-  errorPayload: Record<string, any>
+  errorPayload: ProxiedErrorPayload
 ) {
   if (
     errorPayload.error?.type === "invalid_request_error" &&
@@ -388,7 +410,7 @@ function maybeHandleMissingPreambleError(
 
 function handleAnthropicRateLimitError(
   req: Request,
-  errorPayload: Record<string, any>
+  errorPayload: ProxiedErrorPayload
 ) {
   if (errorPayload.error?.type === "rate_limit_error") {
     keyPool.markRateLimited(req.key!);
@@ -399,35 +421,55 @@ function handleAnthropicRateLimitError(
   }
 }
 
+function handleAwsRateLimitError(
+  req: Request,
+  errorPayload: ProxiedErrorPayload
+) {
+  const errorType = errorPayload.error?.type;
+  switch (errorType) {
+    case "ThrottlingException":
+      keyPool.markRateLimited(req.key!);
+      reenqueueRequest(req);
+      throw new RetryableError("AWS rate-limited request re-enqueued.");
+    case "ModelNotReadyException":
+      errorPayload.proxy_note = `The requested model is overloaded. Try again in a few seconds.`;
+      break;
+    default:
+      errorPayload.proxy_note = `Unrecognized rate limit error from AWS. (${errorType})`;
+  }
+}
+
 function handleOpenAIRateLimitError(
   req: Request,
   tryAgainMessage: string,
-  errorPayload: Record<string, any>
+  errorPayload: ProxiedErrorPayload
 ): Record<string, any> {
   const type = errorPayload.error?.type;
-  if (type === "insufficient_quota") {
-    // Billing quota exceeded (key is dead, disable it)
-    keyPool.disable(req.key!, "quota");
-    errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
-  } else if (type === "access_terminated") {
-    // Account banned (key is dead, disable it)
-    keyPool.disable(req.key!, "revoked");
-    errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
-  } else if (type === "billing_not_active") {
-    // Billing is not active (key is dead, disable it)
-    keyPool.disable(req.key!, "revoked");
-    errorPayload.proxy_note = `Assigned key was deactivated by OpenAI. ${tryAgainMessage}`;
-  } else if (type === "requests" || type === "tokens") {
-    // Per-minute request or token rate limit is exceeded, which we can retry
-    keyPool.markRateLimited(req.key!);
-    // I'm aware this is confusing -- throwing this class of error will cause
-    // the proxy response handler to return without terminating the request,
-    // so that it can be placed back in the queue.
-    reenqueueRequest(req);
-    throw new RetryableError("Rate-limited request re-enqueued.");
-  } else {
-    // OpenAI probably overloaded
-    errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
+  switch (type) {
+    case "insufficient_quota":
+      // Billing quota exceeded (key is dead, disable it)
+      keyPool.disable(req.key!, "quota");
+      errorPayload.proxy_note = `Assigned key's quota has been exceeded. ${tryAgainMessage}`;
+      break;
+    case "access_terminated":
+      // Account banned (key is dead, disable it)
+      keyPool.disable(req.key!, "revoked");
+      errorPayload.proxy_note = `Assigned key has been banned by OpenAI for policy violations. ${tryAgainMessage}`;
+      break;
+    case "billing_not_active":
+      // Key valid but account billing is delinquent
+      keyPool.disable(req.key!, "quota");
+      errorPayload.proxy_note = `Assigned key has been disabled due to delinquent billing. ${tryAgainMessage}`;
+      break;
+    case "requests":
+    case "tokens":
+      // Per-minute request or token rate limit is exceeded, which we can retry
+      keyPool.markRateLimited(req.key!);
+      reenqueueRequest(req);
+      throw new RetryableError("Rate-limited request re-enqueued.");
+    default:
+      errorPayload.proxy_note = `This is likely a temporary error with OpenAI. Try again in a few seconds.`;
+      break;
   }
   return errorPayload;
 }
@@ -455,12 +497,9 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
   // seeing errors in this function, check the reassembled response body from
   // handleStreamedResponse to see if the upstream API has changed.
   try {
-    if (typeof body !== "object") {
-      throw new Error("Expected body to be an object");
-    }
-
+    assertJsonResponse(body);
     const service = req.outboundApi;
-    const { completion } = getCompletionForService({ req, service, body });
+    const completion = getCompletionFromBody(req, body);
     const tokens = await countTokens({ req, completion, service });
 
     req.log.debug(
@@ -473,7 +512,7 @@ const countResponseTokens: ProxyResHandlerWithBody = async (
 
     req.outputTokens = tokens.token_count;
   } catch (error) {
-    req.log.error(
+    req.log.warn(
       error,
       "Error while counting completion tokens; assuming `max_output_tokens`"
     );
@@ -505,3 +544,14 @@ const copyHttpHeaders: ProxyResHandlerWithBody = async (
     res.setHeader(key, proxyRes.headers[key] as string);
   });
 };
+
+function getAwsErrorType(header: string | string[] | undefined) {
+  const val = String(header).match(/^(\w+):?/)?.[1];
+  return val || String(header);
+}
+
+function assertJsonResponse(body: any): asserts body is Record<string, any> {
+  if (typeof body !== "object") {
+    throw new Error("Expected response to be an object");
+  }
+}

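Note: getAwsErrorType exists because Bedrock reports the error class in the x-amzn-errortype response header, sometimes with a trailing URI after a colon (the exact value shape is an assumption here):

getAwsErrorType("ThrottlingException:http://internal..."); // => "ThrottlingException"
getAwsErrorType("AccessDeniedException");                  // => "AccessDeniedException"
getAwsErrorType(undefined);                                // => "undefined"
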
@@ -1,7 +1,11 @@
 import { Request } from "express";
 import { config } from "../../../config";
 import { logQueue } from "../../../shared/prompt-logging";
-import { getCompletionForService, isCompletionRequest } from "../common";
+import {
+  getCompletionFromBody,
+  getModelFromBody,
+  isCompletionRequest,
+} from "../common";
 import { ProxyResHandlerWithBody } from ".";
 import { assertNever } from "../../../shared/utils";
 
@@ -25,17 +29,15 @@ export const logPrompt: ProxyResHandlerWithBody = async (
 
   const promptPayload = getPromptForRequest(req);
   const promptFlattened = flattenMessages(promptPayload);
-  const response = getCompletionForService({
-    service: req.outboundApi,
-    body: responseBody,
-  });
+  const response = getCompletionFromBody(req, responseBody);
+  const model = getModelFromBody(req, responseBody);
 
   logQueue.enqueue({
     endpoint: req.inboundApi,
     promptRaw: JSON.stringify(promptPayload),
     promptFlattened,
-    model: response.model, // may differ from the requested model
-    response: response.completion,
+    model,
+    response,
   });
 };

@@ -0,0 +1,85 @@
+import { Transform, TransformOptions } from "stream";
+// @ts-ignore
+import { Parser } from "lifion-aws-event-stream";
+import { logger } from "../../../logger";
+
+const log = logger.child({ module: "sse-stream-adapter" });
+
+type SSEStreamAdapterOptions = TransformOptions & { isAwsStream?: boolean };
+type AwsEventStreamMessage = {
+  headers: { ":message-type": "event" | "exception" };
+  payload: { message?: string /** base64 encoded */; bytes?: string };
+};
+
+/**
+ * Receives either text chunks or AWS binary event stream chunks and emits
+ * full SSE events.
+ */
+export class ServerSentEventStreamAdapter extends Transform {
+  private readonly isAwsStream;
+  private parser = new Parser();
+  private partialMessage = "";
+
+  constructor(options?: SSEStreamAdapterOptions) {
+    super(options);
+    this.isAwsStream = options?.isAwsStream || false;
+
+    this.parser.on("data", (data: AwsEventStreamMessage) => {
+      const message = this.processAwsEvent(data);
+      if (message) {
+        this.push(Buffer.from(message, "utf8"));
+      }
+    });
+  }
+
+  processAwsEvent(event: AwsEventStreamMessage): string | null {
+    const { payload, headers } = event;
+    if (headers[":message-type"] === "exception" || !payload.bytes) {
+      log.error(
+        { event: JSON.stringify(event) },
+        "Received bad streaming event from AWS"
+      );
+      const message = JSON.stringify(event);
+      return getFakeErrorCompletion("proxy AWS error", message);
+    } else {
+      return `data: ${Buffer.from(payload.bytes, "base64").toString("utf8")}`;
+    }
+  }
+
+  _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
+    try {
+      if (this.isAwsStream) {
+        this.parser.write(chunk);
+      } else {
+        // We may receive multiple (or partial) SSE messages in a single chunk,
+        // so we need to buffer and emit separate stream events for full
+        // messages so we can parse/transform them properly.
+        const str = chunk.toString("utf8");
+        const fullMessages = (this.partialMessage + str).split(/\r?\n\r?\n/);
+        this.partialMessage = fullMessages.pop() || "";
+
+        for (const message of fullMessages) {
+          this.push(message);
+        }
+      }
+      callback();
+    } catch (error) {
+      this.emit("error", error);
+      callback(error);
+    }
+  }
+}
+
+function getFakeErrorCompletion(type: string, message: string) {
+  const content = `\`\`\`\n[${type}: ${message}]\n\`\`\`\n`;
+  const fakeEvent = {
+    log_id: "aws-proxy-sse-message",
+    stop_reason: type,
+    completion:
+      "\nProxy encountered an error during streaming response.\n" + content,
+    truncated: false,
+    stop: null,
+    model: "",
+  };
+  return `data: ${JSON.stringify(fakeEvent)}\n\n`;
+}
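
Note: for the AWS path, each decoded event-stream message carries a base64 payload.bytes field containing one JSON completion chunk; the adapter re-wraps it as an SSE "data:" line. A worked example (fake chunk):

const bytes = "eyJjb21wbGV0aW9uIjoiIEhpIn0=";
Buffer.from(bytes, "base64").toString("utf8");
// => {"completion":" Hi"}
// pushed downstream as: data: {"completion":" Hi"}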
Block a user