Anthropic endpoint improvements (khanon/oai-reverse-proxy!16)

khanon
2023-05-30 03:13:17 +00:00
parent 2c8c81e6dd
commit 6723cbf662
15 changed files with 192 additions and 153 deletions
@@ -9,9 +9,22 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
// Horrible, horrible hack to stop the proxy from complaining about clients
// not sending a model when they are requesting the list of models (which
// requires a key, but obviously not a model).
// TODO: shouldn't even proxy /models to the upstream API, just fake it
// using the models our key pool has available.
req.body.model = "gpt-3.5-turbo";
}
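The TODO above suggests answering /models locally instead of proxying it. A minimal sketch of that idea, assuming a hypothetical keyPool.list() helper exposing each key's provisioned models (import path and key pool interface are illustrative, not from this diff):
import { Router } from "express";
import { keyPool } from "../../key-management";

// Hypothetical local /models handler: build the list from whatever models the
// pooled keys are provisioned for, rather than forwarding to the upstream API.
const modelsRouter = Router();
modelsRouter.get("/v1/models", (_req, res) => {
  const models = new Set<string>();
  for (const key of keyPool.list()) {
    (key.models ?? []).forEach((m: string) => models.add(m));
  }
  res.json({
    object: "list",
    data: [...models].map((id) => ({ id, object: "model", owned_by: "proxy" })),
  });
});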
if (!req.inboundApi || !req.outboundApi) {
const err = new Error(
"Request API format missing. Did you forget to add the `setApiFormat` middleware to your route?"
);
req.log.error(
{ in: req.inboundApi, out: req.outboundApi, path: req.path },
err.message
);
throw err;
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
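The error above implies a setApiFormat middleware that stamps the two formats onto the request before the proxy handlers run. Its implementation isn't part of this diff; a plausible sketch:
import type { RequestHandler } from "express";

type ApiFormat = "kobold" | "openai" | "anthropic";

// Hypothetical factory: each route declares the format clients send (inbound)
// and the format the upstream expects (outbound). Assumes the project's
// Express Request type is augmented with inboundApi/outboundApi, as used above.
const setApiFormat = (api: { in: ApiFormat; out: ApiFormat }): RequestHandler =>
  (req, _res, next) => {
    req.inboundApi = api.in;
    req.outboundApi = api.out;
    next();
  };

// e.g. native Anthropic vs. the OpenAI-compatible route on the same service:
// router.post("/v1/complete", setApiFormat({ in: "anthropic", out: "anthropic" }), ...);
// router.post("/v1/chat/completions", setApiFormat({ in: "openai", out: "anthropic" }), ...);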
@@ -25,14 +38,8 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
// the requested model is an OpenAI one even though we're actually sending
// an Anthropic request.
// For such cases, ignore the requested model entirely.
// Real Anthropic requests come in via /proxy/anthropic/v1/complete
// The OpenAI-compatible endpoint is /proxy/anthropic/v1/chat/completions
const openaiCompatible =
req.originalUrl === "/proxy/anthropic/v1/chat/completions";
if (openaiCompatible) {
if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
req.log.debug("Using an Anthropic key for an OpenAI-compatible request");
req.api = "openai";
// We don't assign the model here, that will happen when transforming the
// request body.
assignedKey = keyPool.get("claude-v1");
@@ -45,8 +52,8 @@ export const addKey: ExpressHttpProxyReqCallback = (proxyReq, req) => {
{
key: assignedKey.hash,
model: req.body?.model,
fromApi: req.api,
toApi: assignedKey.service,
fromApi: req.inboundApi,
toApi: req.outboundApi,
},
"Assigned key to request"
);
@@ -10,7 +10,7 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
req
) => {
if (isCompletionRequest(req) && req.body?.max_tokens) {
const requestedMaxTokens = getMaxTokensFromRequest(req);
const requestedMaxTokens = Number.parseInt(getMaxTokensFromRequest(req));
let maxTokens = requestedMaxTokens;
if (typeof requestedMaxTokens !== "number") {
@@ -24,9 +24,9 @@ export const limitOutputTokens: ExpressHttpProxyReqCallback = (
// TODO: this is not going to scale well, need to implement a better way
// of translating request parameters from one API to another.
maxTokens = Math.min(maxTokens, MAX_TOKENS);
if (req.key!.service === "openai") {
if (req.outboundApi === "openai") {
req.body.max_tokens = maxTokens;
} else if (req.key!.service === "anthropic") {
} else if (req.outboundApi === "anthropic") {
req.body.max_tokens_to_sample = maxTokens;
}
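For the TODO above, one direction is a declarative rename map instead of per-field branches; a sketch limited to parameters that appear in this diff:
// OpenAI -> Anthropic parameter renames. Fields needing real conversion
// (e.g. messages -> prompt) still require bespoke handling elsewhere.
const OPENAI_TO_ANTHROPIC_PARAMS: Record<string, string> = {
  max_tokens: "max_tokens_to_sample",
  stop: "stop_sequences",
};

function renameParams(body: Record<string, any>): Record<string, any> {
  const out: Record<string, any> = {};
  for (const [key, value] of Object.entries(body)) {
    out[OPENAI_TO_ANTHROPIC_PARAMS[key] ?? key] = value;
  }
  return out;
}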
@@ -1,7 +1,8 @@
/**
* Transforms a KoboldAI payload into an OpenAI payload.
* @deprecated Kobold input format isn't supported anymore as all popular
* frontends support reverse proxies or changing their base URL.
* frontends support reverse proxies or changing their base URL. It adds too
* many edge cases to be worth maintaining and doesn't work with newer features.
*/
import { logger } from "../../../logger";
import type { ExpressHttpProxyReqCallback } from ".";
@@ -68,7 +69,7 @@ export const transformKoboldPayload: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.api !== "kobold") {
if (req.inboundApi !== "kobold") {
throw new Error("transformKoboldPayload called for non-kobold request.");
}
@@ -4,33 +4,40 @@ import { ExpressHttpProxyReqCallback, isCompletionRequest } from ".";
// https://console.anthropic.com/docs/api/reference#-v1-complete
const AnthropicV1CompleteSchema = z.object({
model: z.string().regex(/^claude-/),
prompt: z.string(),
max_tokens_to_sample: z.number(),
model: z.string().regex(/^claude-/, "Model must start with 'claude-'"),
prompt: z.string({
required_error:
"No prompt found. Are you sending an OpenAI-formatted request to the Claude endpoint?",
}),
max_tokens_to_sample: z.coerce.number(),
stop_sequences: z.array(z.string()).optional(),
stream: z.boolean().optional().default(false),
temperature: z.number().optional().default(1),
top_k: z.number().optional().default(-1),
top_p: z.number().optional().default(-1),
temperature: z.coerce.number().optional().default(1),
top_k: z.coerce.number().optional().default(-1),
top_p: z.coerce.number().optional().default(-1),
metadata: z.any().optional(),
});
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatCompletionSchema = z.object({
model: z.string().regex(/^gpt/),
model: z.string().regex(/^gpt/, "Model must start with 'gpt-'"),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
name: z.string().optional(),
})
}),
{
required_error:
"No prompt found. Are you sending an Anthropic-formatted request to the OpenAI endpoint?",
}
),
temperature: z.number().optional().default(1),
top_p: z.number().optional().default(1),
n: z.literal(1).optional(),
stream: z.boolean().optional().default(false),
stop: z.union([z.string(), z.array(z.string())]).optional(),
max_tokens: z.number().optional(),
max_tokens: z.coerce.number().optional(),
frequency_penalty: z.number().optional().default(0),
presence_penalty: z.number().optional().default(0),
logit_bias: z.any().optional(),
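The switch to z.coerce.number() means string-typed numerics from loosely configured clients now validate (and are converted) instead of being rejected outright. For example:
// "300" would have failed plain z.number(); z.coerce.number() converts it.
const parsed = AnthropicV1CompleteSchema.safeParse({
  model: "claude-v1",
  prompt: "\n\nHuman: Hi there.\n\nAssistant:",
  max_tokens_to_sample: "300",
});
if (parsed.success) {
  console.log(typeof parsed.data.max_tokens_to_sample); // "number"
}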
@@ -42,39 +49,47 @@ export const transformOutboundPayload: ExpressHttpProxyReqCallback = (
_proxyReq,
req
) => {
if (req.retryCount > 0 || !isCompletionRequest(req)) {
const sameService = req.inboundApi === req.outboundApi;
const alreadyTransformed = req.retryCount > 0;
const notTransformable = !isCompletionRequest(req);
if (alreadyTransformed || notTransformable) {
return;
}
const inboundService = req.api;
const outboundService = req.key!.service;
if (inboundService === outboundService) {
if (sameService) {
// Just validate, don't transform.
const validator =
req.outboundApi === "openai"
? OpenAIV1ChatCompletionSchema
: AnthropicV1CompleteSchema;
const result = validator.safeParse(req.body);
if (!result.success) {
req.log.error(
{ issues: result.error.issues, params: req.body },
"Request validation failed"
);
throw result.error;
}
return;
}
// Not supported yet and unnecessary as everything supports OpenAI.
if (inboundService === "anthropic" && outboundService === "openai") {
throw new Error(
"Anthropic -> OpenAI request transformation not supported. Provide an OpenAI-compatible payload, or use the /claude endpoint."
);
}
if (inboundService === "openai" && outboundService === "anthropic") {
if (req.inboundApi === "openai" && req.outboundApi === "anthropic") {
req.body = openaiToAnthropic(req.body, req);
return;
}
throw new Error(
`Unsupported transformation: ${inboundService} -> ${outboundService}`
`'${req.inboundApi}' -> '${req.outboundApi}' request proxying is not supported. Make sure your client is configured to use the correct API.`
);
};
function openaiToAnthropic(body: any, req: Request) {
const result = OpenAIV1ChatCompletionSchema.safeParse(body);
if (!result.success) {
// don't log the prompt
const { messages, ...params } = body;
// don't log the prompt (usually `messages` but maybe `prompt` if the user
// misconfigured their client)
const { messages, prompt, ...params } = body;
req.log.error(
{ issues: result.error.issues, params },
"Invalid OpenAI-to-Anthropic request"
@@ -48,31 +48,28 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
// If these differ, the user is using the OpenAI-compatible endpoint, so
// we need to translate the SSE events into OpenAI completion events for their
// frontend.
const fromApi = req.api;
const toApi = req.key!.service;
if (!req.isStreaming) {
req.log.error(
{ api: req.api, key: req.key?.hash },
`handleStreamedResponse called for non-streaming request, which isn't valid.`
const err = new Error(
"handleStreamedResponse called for non-streaming request."
);
throw new Error("handleStreamedResponse called for non-streaming request.");
req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
throw err;
}
const key = req.key!;
if (proxyRes.statusCode !== 200) {
// Ensure we use the non-streaming middleware stack since we won't be
// getting any events.
req.isStreaming = false;
req.log.warn(
`Streaming request to ${req.api} returned ${proxyRes.statusCode} status code. Falling back to non-streaming response handler.`
{ statusCode: proxyRes.statusCode, key: key.hash },
`Streaming request returned error status code. Falling back to non-streaming response handler.`
);
return decodeResponseBody(proxyRes, req, res);
}
return new Promise((resolve, reject) => {
req.log.info(
{ api: req.api, key: req.key?.hash },
`Starting to proxy SSE stream.`
);
req.log.info({ key: key.hash }, `Starting to proxy SSE stream.`);
// Queued streaming requests will already have a connection open and headers
// sent due to the heartbeat handler. In that case we can just start
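The heartbeat handler referenced above is outside this diff. The usual technique is to periodically write an SSE comment so queued clients don't time out while waiting for a key; a generic sketch, not the project's actual handler:
import type { Response } from "express";

// SSE lines starting with ":" are comments that clients ignore, so they keep
// the connection (and any intermediate proxies) alive without emitting a
// visible completion chunk. The interval here is arbitrary.
function startHeartbeat(res: Response, intervalMs = 10_000): () => void {
  if (!res.headersSent) {
    res.writeHead(200, {
      "Content-Type": "text/event-stream",
      "Cache-Control": "no-cache",
      Connection: "keep-alive",
    });
  }
  const timer = setInterval(() => res.write(": heartbeat\n\n"), intervalMs);
  return () => clearInterval(timer);
}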
@@ -105,9 +102,9 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes.on(
"data",
withErrorHandling((chunk) => {
// We may receive multiple (or partial) SSE messages in a single chunk, so
// we need to buffer and emit separate stream events for full messages so
// we can parse/transform them properly.
// We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit separate stream events for full
// messages so we can parse/transform them properly.
const str = chunk.toString();
chunkBuffer.push(str);
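The flush logic that emits full-sse-event is outside this hunk; the idea is to join the buffer, emit only the portions terminated by the blank-line SSE delimiter, and keep any trailing partial message buffered. A sketch of that approach:
// Complete SSE messages end with "\n\n"; anything after the final delimiter is
// an incomplete message and stays in the buffer for the next chunk.
function flushCompleteEvents(chunkBuffer: string[]): string[] {
  const joined = chunkBuffer.join("");
  const lastDelimiter = joined.lastIndexOf("\n\n");
  if (lastDelimiter === -1) return []; // nothing complete yet
  const tail = joined.slice(lastDelimiter + 2);
  chunkBuffer.length = 0;
  if (tail) chunkBuffer.push(tail);
  return joined
    .slice(0, lastDelimiter)
    .split("\n\n")
    .filter((msg) => msg.trim().length > 0);
}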
@@ -126,12 +123,12 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
proxyRes.on(
"full-sse-event",
withErrorHandling((data) => {
const { event, position } = transformEvent(
const { event, position } = transformEvent({
data,
fromApi,
toApi,
lastPosition
);
requestApi: req.inboundApi,
responseApi: req.outboundApi,
lastPosition,
});
fullChunks.push(event);
lastPosition = position;
res.write(event + "\n\n");
@@ -142,20 +139,14 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
"end",
withErrorHandling(() => {
let finalBody = convertEventsToFinalResponse(fullChunks, req);
req.log.info(
{ api: req.api, key: req.key?.hash },
`Finished proxying SSE stream.`
);
req.log.info({ key: key.hash }, `Finished proxying SSE stream.`);
res.end();
resolve(finalBody);
})
);
proxyRes.on("error", (err) => {
req.log.error(
{ error: err, api: req.api, key: req.key?.hash },
`Error while streaming response.`
);
req.log.error({ error: err, key: key.hash }, `Mid-stream error.`);
const fakeErrorEvent = buildFakeSseMessage(
"mid-stream-error",
err.message,
@@ -173,12 +164,17 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
* Transforms SSE events from the given response API into events compatible with
* the API requested by the client.
*/
function transformEvent(
data: string,
requestApi: string,
responseApi: string,
lastPosition: number
) {
function transformEvent({
data,
requestApi,
responseApi,
lastPosition,
}: {
data: string;
requestApi: string;
responseApi: string;
lastPosition: number;
}) {
if (requestApi === responseApi) {
return { position: -1, event: data };
}
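The cross-API branch is outside this hunk. At the time, Anthropic's /v1/complete stream repeated the completion-so-far in each event (hence the lastPosition tracking), and the new text beyond that position becomes an OpenAI chat.completion.chunk delta. A rough sketch, not the commit's exact implementation:
// Illustrative Anthropic -> OpenAI SSE translation. Assumes `data` is the
// Anthropic event's "data:" line, whose JSON payload carries a `completion`
// field holding the full text generated so far.
function anthropicToOpenAiChunk(data: string, lastPosition: number) {
  const payload = JSON.parse(data.replace(/^data:\s*/, ""));
  const fullText: string = payload.completion ?? "";
  const delta = fullText.slice(lastPosition);
  const chunk = {
    id: "chatcmpl-proxied", // placeholder id
    object: "chat.completion.chunk",
    created: Math.floor(Date.now() / 1000),
    model: payload.model,
    choices: [{ index: 0, delta: { content: delta }, finish_reason: null }],
  };
  return { position: fullText.length, event: `data: ${JSON.stringify(chunk)}` };
}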
@@ -236,7 +232,7 @@ function copyHeaders(proxyRes: http.IncomingMessage, res: Response) {
}
function convertEventsToFinalResponse(events: string[], req: Request) {
if (req.key!.service === "openai") {
if (req.outboundApi === "openai") {
let response: OpenAiChatCompletionResponse = {
id: "",
object: "",
@@ -278,7 +274,7 @@ function convertEventsToFinalResponse(events: string[], req: Request) {
}, response);
return response;
}
if (req.key!.service === "anthropic") {
if (req.outboundApi === "anthropic") {
/*
* Full complete responses from Anthropic are conveniently just the same as
* the final SSE event before the "DONE" event, so we can reuse that
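The comment above is cut off by the hunk boundary; the implied implementation is simply to parse the last data event before the terminator and return it as the response body. Roughly:
// Illustrative: the final meaningful SSE event from Anthropic already has the
// shape of a non-streaming /v1/complete response, so it can be reused as-is.
function lastAnthropicEvent(events: string[]) {
  const dataEvents = events
    .map((e) => e.trim())
    .filter((e) => e.startsWith("data:") && !e.includes("[DONE]"));
  const last = dataEvents[dataEvents.length - 1];
  return last ? JSON.parse(last.slice("data:".length).trim()) : null;
}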
@@ -155,11 +155,9 @@ export const decodeResponseBody: RawResponseBodyHandler = async (
res
) => {
if (req.isStreaming) {
req.log.error(
{ api: req.api, key: req.key?.hash },
`decodeResponseBody called for a streaming request, which isn't valid.`
);
throw new Error("decodeResponseBody called for a streaming request.");
const err = new Error("decodeResponseBody called for a streaming request.");
req.log.error({ stack: err.stack, api: req.inboundApi }, err.message);
throw err;
}
const promise = new Promise<string>((resolve, reject) => {
@@ -273,14 +271,14 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
errorPayload.proxy_note = `API key is invalid or revoked. ${tryAgainMessage}`;
} else if (statusCode === 429) {
// OpenAI uses this for a bunch of different rate-limiting scenarios.
if (req.key!.service === "openai") {
if (req.outboundApi === "openai") {
handleOpenAIRateLimitError(req, tryAgainMessage, errorPayload);
} else {
handleAnthropicRateLimitError(req, errorPayload);
}
} else if (statusCode === 404) {
// Most likely model not found
if (req.key!.service === "openai") {
if (req.outboundApi === "openai") {
// TODO: this probably doesn't handle GPT-4-32k variants properly if the
// proxy has keys for both the 8k and 32k context models at the same time.
if (errorPayload.error?.code === "model_not_found") {
@@ -290,7 +288,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
errorPayload.proxy_note = `No model was found for this key.`;
}
}
} else if (req.key!.service === "anthropic") {
} else if (req.outboundApi === "anthropic") {
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
}
} else {
@@ -313,7 +311,6 @@ function handleAnthropicRateLimitError(
req: Request,
errorPayload: Record<string, any>
) {
//{"error":{"type":"rate_limit_error","message":"Number of concurrent connections to Claude exceeds your rate limit. Please try again, or contact sales@anthropic.com to discuss your options for a rate limit increase."}}
if (errorPayload.error?.type === "rate_limit_error") {
keyPool.markRateLimited(req.key!);
if (config.queueMode !== "none") {
@@ -1,3 +1,4 @@
import { Request } from "express";
import { config } from "../../../config";
import { AIService } from "../../../key-management";
import { logQueue } from "../../../prompt-logging";
@@ -22,19 +23,19 @@ export const logPrompt: ProxyResHandlerWithBody = async (
return;
}
const model = req.body.model;
const promptFlattened = flattenMessages(req.body.messages);
const promptPayload = getPromptForRequest(req);
const promptFlattened = flattenMessages(promptPayload);
const response = getResponseForService({
service: req.key!.service,
service: req.outboundApi,
body: responseBody,
});
logQueue.enqueue({
model,
endpoint: req.api,
promptRaw: JSON.stringify(req.body.messages),
endpoint: req.inboundApi,
promptRaw: JSON.stringify(promptPayload),
promptFlattened,
response,
model: response.model, // may differ from the requested model
response: response.completion,
});
};
@@ -43,7 +44,21 @@ type OaiMessage = {
content: string;
};
const flattenMessages = (messages: OaiMessage[]): string => {
const getPromptForRequest = (req: Request): string | OaiMessage[] => {
// Since the prompt logger only runs after the request has been proxied, we
// can assume the body has already been transformed to the target API's
// format.
if (req.outboundApi === "anthropic") {
return req.body.prompt;
} else {
return req.body.messages;
}
};
const flattenMessages = (messages: string | OaiMessage[]): string => {
if (typeof messages === "string") {
return messages;
}
return messages.map((m) => `${m.role}: ${m.content}`).join("\n");
};
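With the widened signature, the same flattener now covers both logged prompt shapes, e.g.:
// Anthropic prompts are already flat strings; OpenAI chat payloads collapse
// into one "role: content" line per message.
flattenMessages("\n\nHuman: Hi\n\nAssistant:"); // returned unchanged
flattenMessages([
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Hi" },
]); // => "system: You are a helpful assistant.\nuser: Hi"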
@@ -53,10 +68,10 @@ const getResponseForService = ({
}: {
service: AIService;
body: Record<string, any>;
}) => {
}): { completion: string; model: string } => {
if (service === "anthropic") {
return body.completion.trim();
return { completion: body.completion.trim(), model: body.model };
} else {
return body.choices[0].message.content;
return { completion: body.choices[0].message.content, model: body.model };
}
};