Revert "New 2.5 flash thinking budget parameter"
This reverts commit 2f8538519b
This commit is contained in:
+225
-170
@@ -1,184 +1,239 @@
|
||||
import { z } from "zod";
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import { v4 } from "uuid";
|
||||
import { GoogleAIKey, keyPool } from "../shared/key-management";
|
||||
import { config } from "../config";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import {
|
||||
flattenOpenAIMessageContent,
|
||||
OpenAIV1ChatCompletionSchema,
|
||||
} from "./openai";
|
||||
import { APIFormatTransformer } from "./index";
|
||||
createPreprocessorMiddleware,
|
||||
finalizeSignedRequest,
|
||||
} from "./middleware/request";
|
||||
import { ProxyResHandlerWithBody } from "./middleware/response";
|
||||
import { addGoogleAIKey } from "./middleware/request/mutators/add-google-ai-key";
|
||||
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
|
||||
import axios from "axios";
|
||||
|
||||
const TextPartSchema = z.object({
|
||||
text: z.string(),
|
||||
thought: z.boolean().optional()
|
||||
});
|
||||
let modelsCache: any = null;
|
||||
let modelsCacheTime = 0;
|
||||
|
||||
const InlineDataPartSchema = z.object({
|
||||
inlineData: z.object({
|
||||
mimeType: z.string(),
|
||||
data: z.string(),
|
||||
}),
|
||||
});
|
||||
// Cache for native Google AI models
|
||||
let nativeModelsCache: any = null;
|
||||
let nativeModelsCacheTime = 0;
|
||||
|
||||
const PartSchema = z.union([TextPartSchema, InlineDataPartSchema]);
|
||||
// https://ai.google.dev/models/gemini
|
||||
// TODO: list models https://ai.google.dev/tutorials/rest_quickstart#list_models
|
||||
|
||||
const GoogleAIV1ContentSchema = z.object({
|
||||
parts: z
|
||||
.union([PartSchema, z.array(PartSchema)])
|
||||
.transform((val) => (Array.isArray(val) ? val : [val])),
|
||||
role: z.enum(["user", "model"]).optional(),
|
||||
});
|
||||
|
||||
|
||||
const SafetySettingsSchema = z
|
||||
.array(
|
||||
z.object({
|
||||
category: z.enum([
|
||||
"HARM_CATEGORY_HARASSMENT",
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
]),
|
||||
threshold: z.enum([
|
||||
"OFF",
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
"HARM_BLOCK_THRESHOLD_UNSPECIFIED",
|
||||
]),
|
||||
})
|
||||
)
|
||||
.optional();
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
|
||||
export const GoogleAIV1GenerateContentSchema = z
|
||||
.object({
|
||||
model: z.string().max(100), //actually specified in path but we need it for the router
|
||||
stream: z.boolean().optional().default(false), // also used for router
|
||||
contents: z.array(GoogleAIV1ContentSchema),
|
||||
tools: z.array(z.object({})).max(0).optional(),
|
||||
safetySettings: SafetySettingsSchema,
|
||||
systemInstruction: GoogleAIV1ContentSchema.optional(),
|
||||
// quick fix for SillyTavern, which uses camel case field names for everything
|
||||
// except for system_instruction where it randomly uses snake case.
|
||||
// google api evidently accepts either case.
|
||||
system_instruction: GoogleAIV1ContentSchema.optional(),
|
||||
generationConfig: z
|
||||
.object({
|
||||
temperature: z.number().min(0).max(2).optional(),
|
||||
maxOutputTokens: z.coerce
|
||||
.number()
|
||||
.int()
|
||||
.optional()
|
||||
.default(16)
|
||||
.transform((v) => Math.min(v, 4096)), // TODO: Add config
|
||||
candidateCount: z.literal(1).optional(),
|
||||
topP: z.number().min(0).max(1).optional(),
|
||||
topK: z.number().min(1).max(40).optional(),
|
||||
stopSequences: z.array(z.string().max(500)).max(5).optional(),
|
||||
thinkingConfig: z.object({includeThoughts: z.boolean().optional()}).optional(),
|
||||
// Support for new Gemini 2.5 thinking budget
|
||||
thinking_budget: z.union([
|
||||
z.literal("auto"),
|
||||
z.number().int().min(0).max(24576)
|
||||
]).optional()
|
||||
})
|
||||
.default({}),
|
||||
})
|
||||
.strip();
|
||||
export type GoogleAIChatMessage = z.infer<
|
||||
typeof GoogleAIV1GenerateContentSchema
|
||||
>["contents"][0];
|
||||
|
||||
export const transformOpenAIToGoogleAI: APIFormatTransformer<
|
||||
typeof GoogleAIV1GenerateContentSchema
|
||||
> = async (req) => {
|
||||
const { body } = req;
|
||||
const result = OpenAIV1ChatCompletionSchema.safeParse({
|
||||
...body,
|
||||
model: "gpt-3.5-turbo",
|
||||
});
|
||||
if (!result.success) {
|
||||
req.log.warn(
|
||||
{ issues: result.error.issues, body },
|
||||
"Invalid OpenAI-to-Google AI request"
|
||||
);
|
||||
throw result.error;
|
||||
const getModelsResponse = () => {
|
||||
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const foundNames = new Set<string>();
|
||||
const contents = messages
|
||||
.map((m) => {
|
||||
const role = m.role === "assistant" ? "model" : "user";
|
||||
// Detects character names so we can set stop sequences for them as Gemini
|
||||
// is prone to continuing as the next character.
|
||||
// If names are not available, we'll still try to prefix the message
|
||||
// with generic names so we can set stops for them but they don't work
|
||||
// as well as real names.
|
||||
const text = flattenOpenAIMessageContent(m.content);
|
||||
const propName = m.name?.trim();
|
||||
const textName =
|
||||
m.role === "system" ? "" : text.match(/^(.{0,50}?): /)?.[1]?.trim();
|
||||
const name =
|
||||
propName || textName || (role === "model" ? "Character" : "User");
|
||||
if (!config.googleAIKey) return { object: "list", data: [] };
|
||||
|
||||
foundNames.add(name);
|
||||
const keys = keyPool
|
||||
.list()
|
||||
.filter((k) => k.service === "google-ai") as GoogleAIKey[];
|
||||
if (keys.length === 0) {
|
||||
modelsCache = { object: "list", data: [] };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
return modelsCache;
|
||||
}
|
||||
|
||||
// Prefixing messages with their character name seems to help avoid
|
||||
// Gemini trying to continue as the next character, or at the very least
|
||||
// ensures it will hit the stop sequence. Otherwise it will start a new
|
||||
// paragraph and switch perspectives.
|
||||
// The response will be very likely to include this prefix so frontends
|
||||
// will need to strip it out.
|
||||
const textPrefix = textName ? "" : `${name}: `;
|
||||
return {
|
||||
parts: [{ text: textPrefix + text }],
|
||||
role: m.role === "assistant" ? ("model" as const) : ("user" as const),
|
||||
};
|
||||
})
|
||||
.reduce<GoogleAIChatMessage[]>((acc, msg) => {
|
||||
const last = acc[acc.length - 1];
|
||||
if (last?.role === msg.role && 'text' in last.parts[0] && 'text' in msg.parts[0]) {
|
||||
last.parts[0].text += "\n\n" + msg.parts[0].text;
|
||||
} else {
|
||||
acc.push(msg);
|
||||
}
|
||||
return acc;
|
||||
}, []);
|
||||
// Get all model IDs from keys, excluding any with "bard" in the name
|
||||
const modelIds = Array.from(
|
||||
new Set(keys.map((k) => k.modelIds).flat())
|
||||
).filter((id) => id.startsWith("models/") && !id.includes("bard"));
|
||||
|
||||
const models = modelIds.map((id) => ({
|
||||
id,
|
||||
object: "model",
|
||||
created: new Date().getTime(),
|
||||
owned_by: "google",
|
||||
permission: [],
|
||||
root: "google",
|
||||
parent: null,
|
||||
}));
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
? rest.stop
|
||||
: [rest.stop]
|
||||
: [];
|
||||
stops.push(...Array.from(foundNames).map((name) => `\n${name}:`));
|
||||
stops = [...new Set(stops)].slice(0, 5);
|
||||
modelsCache = { object: "list", data: models };
|
||||
modelsCacheTime = new Date().getTime();
|
||||
|
||||
return {
|
||||
model: req.body.model,
|
||||
stream: rest.stream,
|
||||
contents,
|
||||
tools: [],
|
||||
generationConfig: {
|
||||
maxOutputTokens: rest.max_tokens,
|
||||
stopSequences: stops,
|
||||
topP: rest.top_p,
|
||||
topK: 40, // openai schema doesn't have this, google ai defaults to 40
|
||||
temperature: rest.temperature,
|
||||
},
|
||||
safetySettings: [
|
||||
{ category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_CIVIC_INTEGRITY", threshold: "BLOCK_NONE" },
|
||||
],
|
||||
};
|
||||
return modelsCache;
|
||||
};
|
||||
|
||||
export function containsImageContent(contents: GoogleAIChatMessage[]): boolean {
|
||||
return contents.some(content => {
|
||||
const parts = Array.isArray(content.parts) ? content.parts : [content.parts];
|
||||
return parts.some(part => 'inlineData' in part);
|
||||
});
|
||||
// Function to fetch native models from Google AI API
|
||||
const getNativeModelsResponse = async () => {
|
||||
if (new Date().getTime() - nativeModelsCacheTime < 1000 * 60) {
|
||||
return nativeModelsCache;
|
||||
}
|
||||
|
||||
if (!config.googleAIKey) return { models: [] };
|
||||
|
||||
const keys = keyPool
|
||||
.list()
|
||||
.filter((k) => k.service === "google-ai") as GoogleAIKey[];
|
||||
if (keys.length === 0) {
|
||||
nativeModelsCache = { models: [] };
|
||||
nativeModelsCacheTime = new Date().getTime();
|
||||
return nativeModelsCache;
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the first available key to fetch models
|
||||
const key = keys[0];
|
||||
const apiVersion = "v1beta"; // Use the latest API version
|
||||
const url = `https://generativelanguage.googleapis.com/${apiVersion}/models`;
|
||||
|
||||
const response = await axios.get(url, {
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
params: {
|
||||
key: key.key,
|
||||
},
|
||||
});
|
||||
|
||||
// We'll update the model cache but won't attempt to update the keys
|
||||
// This avoids type issues while still keeping our models list up to date
|
||||
nativeModelsCache = response.data;
|
||||
nativeModelsCacheTime = new Date().getTime();
|
||||
return nativeModelsCache;
|
||||
} catch (error) {
|
||||
console.error("Error fetching Google AI models:", error);
|
||||
// Return empty model list on error
|
||||
return { models: [] };
|
||||
}
|
||||
};
|
||||
|
||||
const handleModelRequest: RequestHandler = (_req: Request, res: any) => {
|
||||
res.status(200).json(getModelsResponse());
|
||||
};
|
||||
|
||||
// Native Gemini API model list request
|
||||
const handleNativeModelRequest: RequestHandler = async (_req: Request, res: any) => {
|
||||
try {
|
||||
const modelsResponse = await getNativeModelsResponse();
|
||||
res.status(200).json(modelsResponse);
|
||||
} catch (error) {
|
||||
console.error("Error in handleNativeModelRequest:", error);
|
||||
res.status(500).json({ error: "Failed to fetch models" });
|
||||
}
|
||||
};
|
||||
|
||||
const googleAIBlockingResponseHandler: ProxyResHandlerWithBody = async (
|
||||
_proxyRes,
|
||||
req,
|
||||
res,
|
||||
body
|
||||
) => {
|
||||
if (typeof body !== "object") {
|
||||
throw new Error("Expected body to be an object");
|
||||
}
|
||||
|
||||
let newBody = body;
|
||||
if (req.inboundApi === "openai") {
|
||||
req.log.info("Transforming Google AI response to OpenAI format");
|
||||
newBody = transformGoogleAIResponse(body, req);
|
||||
}
|
||||
|
||||
res.status(200).json({ ...newBody, proxy: body.proxy });
|
||||
};
|
||||
|
||||
function transformGoogleAIResponse(
|
||||
resBody: Record<string, any>,
|
||||
req: Request
|
||||
): Record<string, any> {
|
||||
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
|
||||
const parts = resBody.candidates[0].content?.parts ?? [{ text: "" }];
|
||||
const content = parts[0].text.replace(/^(.{0,50}?): /, () => "");
|
||||
return {
|
||||
id: "goo-" + v4(),
|
||||
object: "chat.completion",
|
||||
created: Date.now(),
|
||||
model: req.body.model,
|
||||
usage: {
|
||||
prompt_tokens: req.promptTokens,
|
||||
completion_tokens: req.outputTokens,
|
||||
total_tokens: totalTokens,
|
||||
},
|
||||
choices: [
|
||||
{
|
||||
message: { role: "assistant", content },
|
||||
finish_reason: resBody.candidates[0].finishReason,
|
||||
index: 0,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
const googleAIProxy = createQueuedProxyMiddleware({
|
||||
target: ({ signedRequest }: { signedRequest: any }) => {
|
||||
if (!signedRequest) throw new Error("Must sign request before proxying");
|
||||
const { protocol, hostname} = signedRequest;
|
||||
return `${protocol}//${hostname}`;
|
||||
},
|
||||
mutations: [addGoogleAIKey, finalizeSignedRequest],
|
||||
blockingResponseHandler: googleAIBlockingResponseHandler,
|
||||
});
|
||||
|
||||
const googleAIRouter = Router();
|
||||
googleAIRouter.get("/v1/models", handleModelRequest);
|
||||
googleAIRouter.get("/:apiVersion(v1alpha|v1beta)/models", handleNativeModelRequest);
|
||||
|
||||
// Native Google AI chat completion endpoint
|
||||
googleAIRouter.post(
|
||||
"/:apiVersion(v1alpha|v1beta)/models/:modelId:(generateContent|streamGenerateContent)",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "google-ai", outApi: "google-ai", service: "google-ai" },
|
||||
{ beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
|
||||
),
|
||||
googleAIProxy
|
||||
);
|
||||
|
||||
// OpenAI-to-Google AI compatibility endpoint.
|
||||
googleAIRouter.post(
|
||||
"/v1/chat/completions",
|
||||
ipLimiter,
|
||||
createPreprocessorMiddleware(
|
||||
{ inApi: "openai", outApi: "google-ai", service: "google-ai" },
|
||||
{ afterTransform: [maybeReassignModel] }
|
||||
),
|
||||
googleAIProxy
|
||||
);
|
||||
|
||||
function setStreamFlag(req: Request) {
|
||||
const isStreaming = req.url.includes("streamGenerateContent");
|
||||
if (isStreaming) {
|
||||
req.body.stream = true;
|
||||
req.isStreaming = true;
|
||||
} else {
|
||||
req.body.stream = false;
|
||||
req.isStreaming = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Replaces requests for non-Google AI models with gemini-1.5-pro-latest.
|
||||
* Also strips models/ from the beginning of the model IDs.
|
||||
**/
|
||||
function maybeReassignModel(req: Request) {
|
||||
// Ensure model is on body as a lot of middleware will expect it.
|
||||
const model = req.body.model || req.url.split("/").pop()?.split(":").shift();
|
||||
if (!model) {
|
||||
throw new Error("You must specify a model with your request.");
|
||||
}
|
||||
req.body.model = model;
|
||||
|
||||
const requested = model;
|
||||
if (requested.startsWith("models/")) {
|
||||
req.body.model = requested.slice("models/".length);
|
||||
}
|
||||
|
||||
if (requested.includes("gemini")) {
|
||||
return;
|
||||
}
|
||||
|
||||
req.log.info({ requested }, "Reassigning model to gemini-1.5-pro-latest");
|
||||
req.body.model = "gemini-1.5-pro-latest";
|
||||
}
|
||||
|
||||
export const googleAI = googleAIRouter;
|
||||
|
||||
Reference in New Issue
Block a user