Support for GPT-4-Vision (khanon/oai-reverse-proxy!54)

This commit is contained in:
khanon
2023-11-19 05:06:21 +00:00
parent 7f2f324e26
commit f29049f993
13 changed files with 198 additions and 52 deletions
@@ -1,6 +1,7 @@
import { RequestPreprocessor } from "./index";
import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
import { countTokens } from "../../../shared/tokenization";
import { assertNever } from "../../../shared/utils";
import type { OpenAIChatMessage } from "./transform-outbound-payload";
/**
* Given a request with an already-transformed body, counts the number of
@@ -13,7 +14,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
switch (service) {
case "openai": {
req.outputTokens = req.body.max_tokens;
const prompt: OpenAIPromptMessage[] = req.body.messages;
const prompt: OpenAIChatMessage[] = req.body.messages;
result = await countTokens({ req, prompt, service });
break;
}
@@ -3,6 +3,7 @@ import { config } from "../../../config";
import { assertNever } from "../../../shared/utils";
import { RequestPreprocessor } from ".";
import { UserInputError } from "../../../shared/errors";
import { OpenAIChatMessage } from "./transform-outbound-payload";
// Per-client rejection counter (key is presumably an IP or user token —
// TODO confirm against the callers that populate it; value is the count).
const rejectedClients = new Map<string, number>();
@@ -53,9 +54,16 @@ function getPromptFromRequest(req: Request) {
return body.prompt;
case "openai":
return body.messages
.map(
(m: { content: string; role: string }) => `${m.role}: ${m.content}`
)
.map((msg: OpenAIChatMessage) => {
const text = Array.isArray(msg.content)
? msg.content
.map((c) => {
if ("text" in c) return c.text;
})
.join()
: msg.content;
return `${msg.role}: ${text}`;
})
.join("\n\n");
case "openai-text":
case "openai-image":
@@ -1,7 +1,6 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../config";
import { OpenAIPromptMessage } from "../../../shared/tokenization";
import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
import { RequestPreprocessor } from ".";
import { APIFormat } from "../../../shared/key-management";
@@ -9,6 +8,8 @@ import { APIFormat } from "../../../shared/key-management";
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
// TODO: move schemas to shared
// https://console.anthropic.com/docs/api/reference#-v1-complete
export const AnthropicV1CompleteSchema = z.object({
model: z.string(),
@@ -29,12 +30,25 @@ export const AnthropicV1CompleteSchema = z.object({
});
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatCompletionSchema = z.object({
// Multimodal (GPT-4 Vision) message `content`: an array of parts, each either
// a plain text part or an image reference.
// https://platform.openai.com/docs/api-reference/chat/create
const OpenAIV1ChatContentArraySchema = z.array(
  z.union([
    z.object({ type: z.literal("text"), text: z.string() }),
    z.object({
      type: z.literal("image_url"),
      image_url: z.object({
        url: z.string().url(),
        // NOTE(review): `.optional()` is redundant here — zod's `.default()`
        // already makes the field omittable; safe to drop in a later cleanup.
        detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
      }),
    }),
  ])
);
export const OpenAIV1ChatCompletionSchema = z.object({
model: z.string(),
messages: z.array(
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
name: z.string().optional(),
}),
{
@@ -68,6 +82,10 @@ const OpenAIV1ChatCompletionSchema = z.object({
seed: z.number().int().optional(),
});
/** A single chat message as validated by {@link OpenAIV1ChatCompletionSchema}. */
export type OpenAIChatMessage = z.infer<
  typeof OpenAIV1ChatCompletionSchema
>["messages"][number];
const OpenAIV1TextCompletionSchema = z
.object({
model: z
@@ -232,7 +250,7 @@ function openaiToOpenaiText(req: Request) {
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAiChatMessages(messages);
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
@@ -260,6 +278,9 @@ function openaiToOpenaiImage(req: Request) {
const { messages } = result.data;
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
if (Array.isArray(prompt)) {
throw new Error("Image generation prompt must be a text message.");
}
if (body.stream) {
throw new Error(
@@ -304,7 +325,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAiChatMessages(messages);
const prompt = flattenOpenAIChatMessages(messages);
let stops = rest.stop
? Array.isArray(rest.stop)
@@ -336,7 +357,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
};
}
export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
return (
messages
.map((m) => {
@@ -348,17 +369,17 @@ export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
} else if (role === "user") {
role = "Human";
}
const name = m.name?.trim();
const content = flattenOpenAIMessageContent(m.content);
// https://console.anthropic.com/docs/prompt-design
// `name` isn't supported by Anthropic but we can still try to use it.
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
m.content
}`;
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
})
.join("") + "\n\nAssistant:"
);
}
function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
// Temporary to allow experimenting with prompt strategies
const PROMPT_VERSION: number = 1;
switch (PROMPT_VERSION) {
@@ -375,7 +396,7 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
} else if (role === "user") {
role = "User";
}
return `\n\n${role}: ${m.content}`;
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
})
.join("") + "\n\nAssistant:"
);
@@ -387,10 +408,23 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
if (role === "system") {
role = "System: ";
}
return `\n\n${role}${m.content}`;
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
})
.join("");
default:
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
}
}
/**
 * Flattens an OpenAI chat message's `content` field to a plain string.
 * Multimodal (GPT-4 Vision) requests send an array of content parts: text
 * parts are kept verbatim and image parts are replaced with a placeholder so
 * downstream prompt flattening still produces usable text.
 *
 * @param content - Either a plain string or an array of content parts.
 * @returns The flattened text, with array parts joined by newlines.
 */
function flattenOpenAIMessageContent(
  content: OpenAIChatMessage["content"]
): string {
  if (!Array.isArray(content)) return content;
  return content
    .map((contentItem) => {
      if ("text" in contentItem) return contentItem.text;
      if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
      // The schema only permits text/image parts; the original fell through
      // with an implicit `undefined` and relied on Array#join coercing it to
      // "" — make the empty-string fallback explicit.
      return "";
    })
    .join("\n");
}
+16 -8
View File
@@ -9,6 +9,7 @@ import {
} from "../common";
import { ProxyResHandlerWithBody } from ".";
import { assertNever } from "../../../shared/utils";
import { OpenAIChatMessage } from "../request/transform-outbound-payload";
/** If prompt logging is enabled, enqueues the prompt for logging. */
export const logPrompt: ProxyResHandlerWithBody = async (
@@ -42,11 +43,6 @@ export const logPrompt: ProxyResHandlerWithBody = async (
});
};
type OaiMessage = {
role: "user" | "assistant" | "system";
content: string;
};
type OaiImageResult = {
prompt: string;
size: string;
@@ -58,7 +54,7 @@ type OaiImageResult = {
const getPromptForRequest = (
req: Request,
responseBody: Record<string, any>
): string | OaiMessage[] | OaiImageResult => {
): string | OpenAIChatMessage[] | OaiImageResult => {
// Since the prompt logger only runs after the request has been proxied, we
// can assume the body has already been transformed to the target API's
// format.
@@ -85,13 +81,25 @@ const getPromptForRequest = (
};
const flattenMessages = (
val: string | OaiMessage[] | OaiImageResult
val: string | OpenAIChatMessage[] | OaiImageResult
): string => {
if (typeof val === "string") {
return val.trim();
}
if (Array.isArray(val)) {
return val.map((m) => `${m.role}: ${m.content}`).join("\n");
return val
.map(({ content, role }) => {
const text = Array.isArray(content)
? content
.map((c) => {
if ("text" in c) return c.text;
if ("image_url" in c) return "(( Attached Image ))";
})
.join("\n")
: content;
return `${role}: ${text}`;
})
.join("\n");
}
return val.prompt.trim();
};
+1
View File
@@ -26,6 +26,7 @@ import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/r
// https://platform.openai.com/docs/models/overview
const KNOWN_MODELS = [
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0613",
"gpt-4-0314", // EOL 2024-06-13
+1 -1
View File
@@ -475,7 +475,7 @@ export function registerHeartbeat(req: Request) {
const res = req.res!;
const currentSize = getHeartbeatSize();
req.log.info({
req.log.debug({
currentSize,
HEARTBEAT_INTERVAL,
PAYLOAD_SCALE_FACTOR,
+2 -2
View File
@@ -17,8 +17,8 @@ proxyRouter.use((req, _res, next) => {
next();
});
proxyRouter.use(
express.json({ limit: "1536kb" }),
express.urlencoded({ extended: true, limit: "1536kb" })
express.json({ limit: "10mb" }),
express.urlencoded({ extended: true, limit: "10mb" })
);
proxyRouter.use(gatekeeper);
proxyRouter.use(checkRisuToken);