Revert "New 2.5 flash thinking budget parameter"

This reverts commit 2f8538519b
reanon
2025-04-17 21:12:39 +00:00
parent 2f8538519b
commit 1e8f55f96d
+225 -170
@@ -1,184 +1,239 @@
import { z } from "zod";
import { Request, RequestHandler, Router } from "express";
import { v4 } from "uuid";
import { GoogleAIKey, keyPool } from "../shared/key-management";
import { config } from "../config";
import { ipLimiter } from "./rate-limit";
import {
  flattenOpenAIMessageContent,
  OpenAIV1ChatCompletionSchema,
} from "./openai";
import { APIFormatTransformer } from "./index";
import {
  createPreprocessorMiddleware,
  finalizeSignedRequest,
} from "./middleware/request";
import { ProxyResHandlerWithBody } from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/mutators/add-google-ai-key";
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
import axios from "axios";
const TextPartSchema = z.object({
  text: z.string(),
  thought: z.boolean().optional(),
});

const InlineDataPartSchema = z.object({
  inlineData: z.object({
    mimeType: z.string(),
    data: z.string(),
  }),
});

const PartSchema = z.union([TextPartSchema, InlineDataPartSchema]);

let modelsCache: any = null;
let modelsCacheTime = 0;

// Cache for native Google AI models
let nativeModelsCache: any = null;
let nativeModelsCacheTime = 0;

// https://ai.google.dev/models/gemini
// TODO: list models https://ai.google.dev/tutorials/rest_quickstart#list_models
const GoogleAIV1ContentSchema = z.object({
  parts: z
    .union([PartSchema, z.array(PartSchema)])
    .transform((val) => (Array.isArray(val) ? val : [val])),
  role: z.enum(["user", "model"]).optional(),
});
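// e.g. (illustrative, not from real traffic) the transform above normalizes
// a single part into an array: { parts: { text: "hi" } } parses to
// { parts: [{ text: "hi" }] }, so downstream code can always assume an array.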
const SafetySettingsSchema = z
  .array(
    z.object({
      category: z.enum([
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_CIVIC_INTEGRITY",
      ]),
      threshold: z.enum([
        "OFF",
        "BLOCK_NONE",
        "BLOCK_ONLY_HIGH",
        "BLOCK_MEDIUM_AND_ABOVE",
        "BLOCK_LOW_AND_ABOVE",
        "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
      ]),
    })
  )
  .optional();
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z
  .object({
    model: z.string().max(100), // actually specified in path, but we need it for the router
    stream: z.boolean().optional().default(false), // also used for router
    contents: z.array(GoogleAIV1ContentSchema),
    tools: z.array(z.object({})).max(0).optional(),
    safetySettings: SafetySettingsSchema,
    systemInstruction: GoogleAIV1ContentSchema.optional(),
    // quick fix for SillyTavern, which uses camel case field names for
    // everything except for system_instruction, where it randomly uses snake
    // case. google api evidently accepts either case.
    system_instruction: GoogleAIV1ContentSchema.optional(),
    generationConfig: z
      .object({
        temperature: z.number().min(0).max(2).optional(),
        maxOutputTokens: z.coerce
          .number()
          .int()
          .optional()
          .default(16)
          .transform((v) => Math.min(v, 4096)), // TODO: Add config
        candidateCount: z.literal(1).optional(),
        topP: z.number().min(0).max(1).optional(),
        topK: z.number().min(1).max(40).optional(),
        stopSequences: z.array(z.string().max(500)).max(5).optional(),
        thinkingConfig: z
          .object({ includeThoughts: z.boolean().optional() })
          .optional(),
        // Support for the new Gemini 2.5 thinking budget
        thinking_budget: z
          .union([z.literal("auto"), z.number().int().min(0).max(24576)])
          .optional(),
      })
      .default({}),
  })
  .strip();
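// A minimal sketch (assumed values, not taken from real traffic) of a body
// that parses against the schema above; note the snake_case
// system_instruction alias and the single part being normalized to an array:
const exampleGenerateContentBody = GoogleAIV1GenerateContentSchema.parse({
  model: "gemini-1.5-pro-latest",
  contents: [{ parts: { text: "Hello" }, role: "user" }],
  system_instruction: { parts: [{ text: "You are a helpful assistant." }] },
  generationConfig: { temperature: 1.0, maxOutputTokens: 1024 },
});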
export type GoogleAIChatMessage = z.infer<
  typeof GoogleAIV1GenerateContentSchema
>["contents"][0];

export const transformOpenAIToGoogleAI: APIFormatTransformer<
  typeof GoogleAIV1GenerateContentSchema
> = async (req) => {
  const { body } = req;
  const result = OpenAIV1ChatCompletionSchema.safeParse({
    ...body,
    model: "gpt-3.5-turbo",
  });
  if (!result.success) {
    req.log.warn(
      { issues: result.error.issues, body },
      "Invalid OpenAI-to-Google AI request"
    );
    throw result.error;
  }

  const { messages, ...rest } = result.data;
  const foundNames = new Set<string>();
  const contents = messages
    .map((m) => {
      const role = m.role === "assistant" ? "model" : "user";
      // Detects character names so we can set stop sequences for them, as
      // Gemini is prone to continuing as the next character.
      // If names are not available, we'll still try to prefix the message
      // with generic names so we can set stops for them, but they don't work
      // as well as real names.
      const text = flattenOpenAIMessageContent(m.content);
      const propName = m.name?.trim();
      const textName =
        m.role === "system" ? "" : text.match(/^(.{0,50}?): /)?.[1]?.trim();
      const name =
        propName || textName || (role === "model" ? "Character" : "User");
      foundNames.add(name);
      // Prefixing messages with their character name seems to help avoid
      // Gemini trying to continue as the next character, or at the very least
      // ensures it will hit the stop sequence. Otherwise it will start a new
      // paragraph and switch perspectives.
      // The response will very likely include this prefix, so frontends will
      // need to strip it out.
      const textPrefix = textName ? "" : `${name}: `;
      return {
        parts: [{ text: textPrefix + text }],
        role: m.role === "assistant" ? ("model" as const) : ("user" as const),
      };
    })
    .reduce<GoogleAIChatMessage[]>((acc, msg) => {
      const last = acc[acc.length - 1];
      if (
        last?.role === msg.role &&
        "text" in last.parts[0] &&
        "text" in msg.parts[0]
      ) {
        last.parts[0].text += "\n\n" + msg.parts[0].text;
      } else {
        acc.push(msg);
      }
      return acc;
    }, []);

  let stops = rest.stop
    ? Array.isArray(rest.stop)
      ? rest.stop
      : [rest.stop]
    : [];
  stops.push(...Array.from(foundNames).map((name) => `\n${name}:`));
  stops = [...new Set(stops)].slice(0, 5);

  return {
    model: req.body.model,
    stream: rest.stream,
    contents,
    tools: [],
    generationConfig: {
      maxOutputTokens: rest.max_tokens,
      stopSequences: stops,
      topP: rest.top_p,
      topK: 40, // openai schema doesn't have this, google ai defaults to 40
      temperature: rest.temperature,
    },
    safetySettings: [
      { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
      { category: "HARM_CATEGORY_CIVIC_INTEGRITY", threshold: "BLOCK_NONE" },
    ],
  };
};

export function containsImageContent(
  contents: GoogleAIChatMessage[]
): boolean {
  return contents.some((content) => {
    const parts = Array.isArray(content.parts)
      ? content.parts
      : [content.parts];
    return parts.some((part) => "inlineData" in part);
  });
}

const getModelsResponse = () => {
  if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }
  if (!config.googleAIKey) return { object: "list", data: [] };
  const keys = keyPool
    .list()
    .filter((k) => k.service === "google-ai") as GoogleAIKey[];
  if (keys.length === 0) {
    modelsCache = { object: "list", data: [] };
    modelsCacheTime = new Date().getTime();
    return modelsCache;
  }
  // Get all model IDs from keys, excluding any with "bard" in the name
  const modelIds = Array.from(
    new Set(keys.map((k) => k.modelIds).flat())
  ).filter((id) => id.startsWith("models/") && !id.includes("bard"));
  const models = modelIds.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "google",
    permission: [],
    root: "google",
    parent: null,
  }));
  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();
  return modelsCache;
};
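// Illustrative trace of the transformer above (assumed inputs, not captured
// traffic): the OpenAI messages
//   [{ role: "user", content: "Alice: hi" }, { role: "user", content: "Alice: more" }]
// both match the name "Alice", so no generic prefix is added; the reduce step
// merges them into one content block
//   { parts: [{ text: "Alice: hi\n\nAlice: more" }], role: "user" }
// and "\nAlice:" is registered as a stop sequence so Gemini is cut off if it
// tries to keep speaking as the next character.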
// Function to fetch native models from the Google AI API
const getNativeModelsResponse = async () => {
  if (new Date().getTime() - nativeModelsCacheTime < 1000 * 60) {
    return nativeModelsCache;
  }
  if (!config.googleAIKey) return { models: [] };
  const keys = keyPool
    .list()
    .filter((k) => k.service === "google-ai") as GoogleAIKey[];
  if (keys.length === 0) {
    nativeModelsCache = { models: [] };
    nativeModelsCacheTime = new Date().getTime();
    return nativeModelsCache;
  }
  try {
    // Use the first available key to fetch models
    const key = keys[0];
    const apiVersion = "v1beta"; // Use the latest API version
    const url = `https://generativelanguage.googleapis.com/${apiVersion}/models`;
    const response = await axios.get(url, {
      headers: {
        "Content-Type": "application/json",
      },
      params: {
        key: key.key,
      },
    });
    // We'll update the model cache but won't attempt to update the keys.
    // This avoids type issues while still keeping our models list up to date.
    nativeModelsCache = response.data;
    nativeModelsCacheTime = new Date().getTime();
    return nativeModelsCache;
  } catch (error) {
    console.error("Error fetching Google AI models:", error);
    // Return empty model list on error
    return { models: [] };
  }
};
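// For reference, the upstream call above amounts to (key value assumed):
//   GET https://generativelanguage.googleapis.com/v1beta/models?key=<API_KEY>
// Google's native response shape, e.g. { models: [{ name: "models/...", ... }] },
// is cached for 60 seconds, mirroring the OpenAI-style cache in
// getModelsResponse above.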
const handleModelRequest: RequestHandler = (_req: Request, res: any) => {
  res.status(200).json(getModelsResponse());
};

// Native Gemini API model list request
const handleNativeModelRequest: RequestHandler = async (
  _req: Request,
  res: any
) => {
  try {
    const modelsResponse = await getNativeModelsResponse();
    res.status(200).json(modelsResponse);
  } catch (error) {
    console.error("Error in handleNativeModelRequest:", error);
    res.status(500).json({ error: "Failed to fetch models" });
  }
};
const googleAIBlockingResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  let newBody = body;
  if (req.inboundApi === "openai") {
    req.log.info("Transforming Google AI response to OpenAI format");
    newBody = transformGoogleAIResponse(body, req);
  }

  res.status(200).json({ ...newBody, proxy: body.proxy });
};
function transformGoogleAIResponse(
  resBody: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  const parts = resBody.candidates[0].content?.parts ?? [{ text: "" }];
  const content = parts[0].text.replace(/^(.{0,50}?): /, () => "");
  return {
    id: "goo-" + v4(),
    object: "chat.completion",
    created: Date.now(),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: { role: "assistant", content },
        finish_reason: resBody.candidates[0].finishReason,
        index: 0,
      },
    ],
  };
}
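// Sketch of the mapping above with assumed values: a Gemini body like
//   { candidates: [{ content: { parts: [{ text: "Bob: Hi there" }] }, finishReason: "STOP" }] }
// becomes
//   { choices: [{ message: { role: "assistant", content: "Hi there" }, finish_reason: "STOP", index: 0 }], ... }
// with the character-name prefix added during prompt transformation stripped
// from the start of the text.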
const googleAIProxy = createQueuedProxyMiddleware({
  target: ({ signedRequest }: { signedRequest: any }) => {
    if (!signedRequest) throw new Error("Must sign request before proxying");
    const { protocol, hostname } = signedRequest;
    return `${protocol}//${hostname}`;
  },
  mutations: [addGoogleAIKey, finalizeSignedRequest],
  blockingResponseHandler: googleAIBlockingResponseHandler,
});
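// e.g. (illustrative values): a signed request with protocol "https:" and
// hostname "generativelanguage.googleapis.com" yields the target
// "https://generativelanguage.googleapis.com"; the path and API key are
// applied by the mutators above.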
const googleAIRouter = Router();
googleAIRouter.get("/v1/models", handleModelRequest);
googleAIRouter.get("/:apiVersion(v1alpha|v1beta)/models", handleNativeModelRequest);
// Native Google AI chat completion endpoint
googleAIRouter.post(
  "/:apiVersion(v1alpha|v1beta)/models/:modelId:(generateContent|streamGenerateContent)",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "google-ai", outApi: "google-ai", service: "google-ai" },
    { beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] }
  ),
  googleAIProxy
);
// OpenAI-to-Google AI compatibility endpoint.
googleAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "google-ai", service: "google-ai" },
    { afterTransform: [maybeReassignModel] }
  ),
  googleAIProxy
);
function setStreamFlag(req: Request) {
  const isStreaming = req.url.includes("streamGenerateContent");
  if (isStreaming) {
    req.body.stream = true;
    req.isStreaming = true;
  } else {
    req.body.stream = false;
    req.isStreaming = false;
  }
}
/**
 * Replaces requests for non-Google AI models with gemini-1.5-pro-latest.
 * Also strips "models/" from the beginning of the model IDs.
 */
function maybeReassignModel(req: Request) {
  // Ensure the model is on the body, as a lot of middleware will expect it.
  const model = req.body.model || req.url.split("/").pop()?.split(":").shift();
  if (!model) {
    throw new Error("You must specify a model with your request.");
  }
  req.body.model = model;

  const requested = model;
  if (requested.startsWith("models/")) {
    req.body.model = requested.slice("models/".length);
  }
  if (requested.includes("gemini")) {
    return;
  }

  req.log.info({ requested }, "Reassigning model to gemini-1.5-pro-latest");
  req.body.model = "gemini-1.5-pro-latest";
}
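// Examples of the reassignment above (illustrative model names):
//   "models/gemini-1.5-flash" -> "gemini-1.5-flash"      (prefix stripped, model kept)
//   "gpt-4o"                  -> "gemini-1.5-pro-latest" (non-Gemini model reassigned)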
export const googleAI = googleAIRouter;