Support for GPT-4-Vision (khanon/oai-reverse-proxy!54)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import { RequestPreprocessor } from "./index";
|
||||
import { countTokens, OpenAIPromptMessage } from "../../../shared/tokenization";
|
||||
import { countTokens } from "../../../shared/tokenization";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import type { OpenAIChatMessage } from "./transform-outbound-payload";
|
||||
|
||||
/**
|
||||
* Given a request with an already-transformed body, counts the number of
|
||||
@@ -13,7 +14,7 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
||||
switch (service) {
|
||||
case "openai": {
|
||||
req.outputTokens = req.body.max_tokens;
|
||||
const prompt: OpenAIPromptMessage[] = req.body.messages;
|
||||
const prompt: OpenAIChatMessage[] = req.body.messages;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ import { config } from "../../../config";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import { RequestPreprocessor } from ".";
|
||||
import { UserInputError } from "../../../shared/errors";
|
||||
import { OpenAIChatMessage } from "./transform-outbound-payload";
|
||||
|
||||
const rejectedClients = new Map<string, number>();
|
||||
|
||||
@@ -53,9 +54,16 @@ function getPromptFromRequest(req: Request) {
|
||||
return body.prompt;
|
||||
case "openai":
|
||||
return body.messages
|
||||
.map(
|
||||
(m: { content: string; role: string }) => `${m.role}: ${m.content}`
|
||||
)
|
||||
.map((msg: OpenAIChatMessage) => {
|
||||
const text = Array.isArray(msg.content)
|
||||
? msg.content
|
||||
.map((c) => {
|
||||
if ("text" in c) return c.text;
|
||||
})
|
||||
.join()
|
||||
: msg.content;
|
||||
return `${msg.role}: ${text}`;
|
||||
})
|
||||
.join("\n\n");
|
||||
case "openai-text":
|
||||
case "openai-image":
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { config } from "../../../config";
|
||||
import { OpenAIPromptMessage } from "../../../shared/tokenization";
|
||||
import { isTextGenerationRequest, isImageGenerationRequest } from "../common";
|
||||
import { RequestPreprocessor } from ".";
|
||||
import { APIFormat } from "../../../shared/key-management";
|
||||
@@ -9,6 +8,8 @@ import { APIFormat } from "../../../shared/key-management";
|
||||
const CLAUDE_OUTPUT_MAX = config.maxOutputTokensAnthropic;
|
||||
const OPENAI_OUTPUT_MAX = config.maxOutputTokensOpenAI;
|
||||
|
||||
// TODO: move schemas to shared
|
||||
|
||||
// https://console.anthropic.com/docs/api/reference#-v1-complete
|
||||
export const AnthropicV1CompleteSchema = z.object({
|
||||
model: z.string(),
|
||||
@@ -29,12 +30,25 @@ export const AnthropicV1CompleteSchema = z.object({
|
||||
});
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat/create
|
||||
const OpenAIV1ChatCompletionSchema = z.object({
|
||||
/** A plain-text part of a multi-part chat message: `{ type: "text", text }`. */
const OpenAIV1ChatTextPartSchema = z.object({
  type: z.literal("text"),
  text: z.string(),
});

/**
 * An image part of a multi-part chat message. `image_url.url` must parse as a
 * URL; `image_url.detail` is one of "low" | "auto" | "high" and falls back to
 * "auto" when omitted.
 */
const OpenAIV1ChatImagePartSchema = z.object({
  type: z.literal("image_url"),
  image_url: z.object({
    url: z.string().url(),
    detail: z.enum(["low", "auto", "high"]).optional().default("auto"),
  }),
});

/**
 * Array form of a chat message's `content`: every element is either a text
 * part or an image part.
 */
const OpenAIV1ChatContentArraySchema = z.array(
  z.union([OpenAIV1ChatTextPartSchema, OpenAIV1ChatImagePartSchema])
);
|
||||
|
||||
export const OpenAIV1ChatCompletionSchema = z.object({
|
||||
model: z.string(),
|
||||
messages: z.array(
|
||||
z.object({
|
||||
role: z.enum(["system", "user", "assistant"]),
|
||||
content: z.string(),
|
||||
content: z.union([z.string(), OpenAIV1ChatContentArraySchema]),
|
||||
name: z.string().optional(),
|
||||
}),
|
||||
{
|
||||
@@ -68,6 +82,10 @@ const OpenAIV1ChatCompletionSchema = z.object({
|
||||
seed: z.number().int().optional(),
|
||||
});
|
||||
|
||||
/** Parsed (output) shape of a request body accepted by OpenAIV1ChatCompletionSchema. */
type OpenAIV1ChatCompletionBody = z.infer<typeof OpenAIV1ChatCompletionSchema>;

/**
 * A single entry of the `messages` array of a validated chat completion
 * request. Its `content` is either a plain string or an array of content
 * parts, as declared by the schema.
 */
export type OpenAIChatMessage = OpenAIV1ChatCompletionBody["messages"][number];
|
||||
|
||||
const OpenAIV1TextCompletionSchema = z
|
||||
.object({
|
||||
model: z
|
||||
@@ -232,7 +250,7 @@ function openaiToOpenaiText(req: Request) {
|
||||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const prompt = flattenOpenAiChatMessages(messages);
|
||||
const prompt = flattenOpenAIChatMessages(messages);
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
@@ -260,6 +278,9 @@ function openaiToOpenaiImage(req: Request) {
|
||||
|
||||
const { messages } = result.data;
|
||||
const prompt = messages.filter((m) => m.role === "user").pop()?.content;
|
||||
if (Array.isArray(prompt)) {
|
||||
throw new Error("Image generation prompt must be a text message.");
|
||||
}
|
||||
|
||||
if (body.stream) {
|
||||
throw new Error(
|
||||
@@ -304,7 +325,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
||||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const prompt = flattenOpenAiChatMessages(messages);
|
||||
const prompt = flattenOpenAIChatMessages(messages);
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
@@ -336,7 +357,7 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
||||
};
|
||||
}
|
||||
|
||||
export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
|
||||
export function openAIMessagesToClaudePrompt(messages: OpenAIChatMessage[]) {
|
||||
return (
|
||||
messages
|
||||
.map((m) => {
|
||||
@@ -348,17 +369,17 @@ export function openAIMessagesToClaudePrompt(messages: OpenAIPromptMessage[]) {
|
||||
} else if (role === "user") {
|
||||
role = "Human";
|
||||
}
|
||||
const name = m.name?.trim();
|
||||
const content = flattenOpenAIMessageContent(m.content);
|
||||
// https://console.anthropic.com/docs/prompt-design
|
||||
// `name` isn't supported by Anthropic but we can still try to use it.
|
||||
return `\n\n${role}: ${m.name?.trim() ? `(as ${m.name}) ` : ""}${
|
||||
m.content
|
||||
}`;
|
||||
return `\n\n${role}: ${name ? `(as ${name}) ` : ""}${content}`;
|
||||
})
|
||||
.join("") + "\n\nAssistant:"
|
||||
);
|
||||
}
|
||||
|
||||
function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
|
||||
function flattenOpenAIChatMessages(messages: OpenAIChatMessage[]) {
|
||||
// Temporary to allow experimenting with prompt strategies
|
||||
const PROMPT_VERSION: number = 1;
|
||||
switch (PROMPT_VERSION) {
|
||||
@@ -375,7 +396,7 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
|
||||
} else if (role === "user") {
|
||||
role = "User";
|
||||
}
|
||||
return `\n\n${role}: ${m.content}`;
|
||||
return `\n\n${role}: ${flattenOpenAIMessageContent(m.content)}`;
|
||||
})
|
||||
.join("") + "\n\nAssistant:"
|
||||
);
|
||||
@@ -387,10 +408,23 @@ function flattenOpenAiChatMessages(messages: OpenAIPromptMessage[]) {
|
||||
if (role === "system") {
|
||||
role = "System: ";
|
||||
}
|
||||
return `\n\n${role}${m.content}`;
|
||||
return `\n\n${role}${flattenOpenAIMessageContent(m.content)}`;
|
||||
})
|
||||
.join("");
|
||||
default:
|
||||
throw new Error(`Unknown prompt version: ${PROMPT_VERSION}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Collapses the `content` of an OpenAI chat message into a single string.
 *
 * - String content is returned unchanged.
 * - Array content is flattened part by part and joined with newlines: text
 *   parts contribute their `text`, image parts are replaced by the placeholder
 *   `[ Uploaded Image Omitted ]`.
 *
 * @param content The message content, either a string or an array of parts.
 * @returns The flattened text.
 */
function flattenOpenAIMessageContent(
  content: OpenAIChatMessage["content"]
): string {
  if (!Array.isArray(content)) {
    return content;
  }

  const parts = content.map((contentItem): string => {
    if ("text" in contentItem) return contentItem.text;
    if ("image_url" in contentItem) return "[ Uploaded Image Omitted ]";
    // Unrecognized part: contribute an empty segment explicitly instead of
    // implicitly returning `undefined` and relying on `join` to coerce it.
    return "";
  });

  return parts.join("\n");
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
} from "../common";
|
||||
import { ProxyResHandlerWithBody } from ".";
|
||||
import { assertNever } from "../../../shared/utils";
|
||||
import { OpenAIChatMessage } from "../request/transform-outbound-payload";
|
||||
|
||||
/** If prompt logging is enabled, enqueues the prompt for logging. */
|
||||
export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
@@ -42,11 +43,6 @@ export const logPrompt: ProxyResHandlerWithBody = async (
|
||||
});
|
||||
};
|
||||
|
||||
type OaiMessage = {
|
||||
role: "user" | "assistant" | "system";
|
||||
content: string;
|
||||
};
|
||||
|
||||
type OaiImageResult = {
|
||||
prompt: string;
|
||||
size: string;
|
||||
@@ -58,7 +54,7 @@ type OaiImageResult = {
|
||||
const getPromptForRequest = (
|
||||
req: Request,
|
||||
responseBody: Record<string, any>
|
||||
): string | OaiMessage[] | OaiImageResult => {
|
||||
): string | OpenAIChatMessage[] | OaiImageResult => {
|
||||
// Since the prompt logger only runs after the request has been proxied, we
|
||||
// can assume the body has already been transformed to the target API's
|
||||
// format.
|
||||
@@ -85,13 +81,25 @@ const getPromptForRequest = (
|
||||
};
|
||||
|
||||
/**
 * Reduces a prompt value to a single string.
 *
 * - A string is returned trimmed.
 * - An array of chat messages is rendered as one `role: text` line per
 *   message, joined with newlines. Multi-part content is flattened with
 *   newlines between parts; image parts are replaced by the placeholder
 *   `(( Attached Image ))`.
 * - Any other value is treated as an image result and its `prompt` is
 *   returned trimmed.
 *
 * @param val The prompt: raw string, chat messages, or image result.
 * @returns The flattened prompt text.
 */
const flattenMessages = (
  val: string | OpenAIChatMessage[] | OaiImageResult
): string => {
  if (typeof val === "string") {
    return val.trim();
  }
  if (Array.isArray(val)) {
    return val
      .map(({ content, role }) => {
        const text = Array.isArray(content)
          ? content
              .map((c): string => {
                if ("text" in c) return c.text;
                if ("image_url" in c) return "(( Attached Image ))";
                // Unrecognized part: return an explicit empty segment rather
                // than an implicit `undefined` that `join` would coerce.
                return "";
              })
              .join("\n")
          : content;
        return `${role}: ${text}`;
      })
      .join("\n");
  }
  return val.prompt.trim();
};
|
||||
|
||||
@@ -26,6 +26,7 @@ import { createOnProxyResHandler, ProxyResHandlerWithBody } from "./middleware/r
|
||||
// https://platform.openai.com/docs/models/overview
|
||||
const KNOWN_MODELS = [
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-0314", // EOL 2024-06-13
|
||||
|
||||
+1
-1
@@ -475,7 +475,7 @@ export function registerHeartbeat(req: Request) {
|
||||
const res = req.res!;
|
||||
|
||||
const currentSize = getHeartbeatSize();
|
||||
req.log.info({
|
||||
req.log.debug({
|
||||
currentSize,
|
||||
HEARTBEAT_INTERVAL,
|
||||
PAYLOAD_SCALE_FACTOR,
|
||||
|
||||
+2
-2
@@ -17,8 +17,8 @@ proxyRouter.use((req, _res, next) => {
|
||||
next();
|
||||
});
|
||||
proxyRouter.use(
|
||||
express.json({ limit: "1536kb" }),
|
||||
express.urlencoded({ extended: true, limit: "1536kb" })
|
||||
express.json({ limit: "10mb" }),
|
||||
express.urlencoded({ extended: true, limit: "10mb" })
|
||||
);
|
||||
proxyRouter.use(gatekeeper);
|
||||
proxyRouter.use(checkRisuToken);
|
||||
|
||||
Reference in New Issue
Block a user