From 1e8f55f96d3a9e25571b450416cd3a258de60dc5 Mon Sep 17 00:00:00 2001 From: reanon <85157-reanon@users.noreply.gitgud.io> Date: Thu, 17 Apr 2025 21:12:39 +0000 Subject: [PATCH] Revert "New 2.5 flash thinking budget parameter" This reverts commit 2f8538519b76f192d05b0b4fe041be9d639518c5 --- src/proxy/google-ai.ts | 395 +++++++++++++++++++++++------------------ 1 file changed, 225 insertions(+), 170 deletions(-) diff --git a/src/proxy/google-ai.ts b/src/proxy/google-ai.ts index e100244..5ac7d9d 100644 --- a/src/proxy/google-ai.ts +++ b/src/proxy/google-ai.ts @@ -1,184 +1,239 @@ -import { z } from "zod"; +import { Request, RequestHandler, Router } from "express"; +import { v4 } from "uuid"; +import { GoogleAIKey, keyPool } from "../shared/key-management"; +import { config } from "../config"; +import { ipLimiter } from "./rate-limit"; import { - flattenOpenAIMessageContent, - OpenAIV1ChatCompletionSchema, -} from "./openai"; -import { APIFormatTransformer } from "./index"; + createPreprocessorMiddleware, + finalizeSignedRequest, +} from "./middleware/request"; +import { ProxyResHandlerWithBody } from "./middleware/response"; +import { addGoogleAIKey } from "./middleware/request/mutators/add-google-ai-key"; +import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory"; +import axios from "axios"; -const TextPartSchema = z.object({ - text: z.string(), - thought: z.boolean().optional() -}); +let modelsCache: any = null; +let modelsCacheTime = 0; -const InlineDataPartSchema = z.object({ - inlineData: z.object({ - mimeType: z.string(), - data: z.string(), - }), -}); +// Cache for native Google AI models +let nativeModelsCache: any = null; +let nativeModelsCacheTime = 0; -const PartSchema = z.union([TextPartSchema, InlineDataPartSchema]); +// https://ai.google.dev/models/gemini +// TODO: list models https://ai.google.dev/tutorials/rest_quickstart#list_models -const GoogleAIV1ContentSchema = z.object({ - parts: z - .union([PartSchema, z.array(PartSchema)]) - .transform((val) => (Array.isArray(val) ? val : [val])), - role: z.enum(["user", "model"]).optional(), -}); - - -const SafetySettingsSchema = z - .array( - z.object({ - category: z.enum([ - "HARM_CATEGORY_HARASSMENT", - "HARM_CATEGORY_HATE_SPEECH", - "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "HARM_CATEGORY_DANGEROUS_CONTENT", - "HARM_CATEGORY_CIVIC_INTEGRITY", - ]), - threshold: z.enum([ - "OFF", - "BLOCK_NONE", - "BLOCK_ONLY_HIGH", - "BLOCK_MEDIUM_AND_ABOVE", - "BLOCK_LOW_AND_ABOVE", - "HARM_BLOCK_THRESHOLD_UNSPECIFIED", - ]), - }) - ) - .optional(); - -// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent -export const GoogleAIV1GenerateContentSchema = z - .object({ - model: z.string().max(100), //actually specified in path but we need it for the router - stream: z.boolean().optional().default(false), // also used for router - contents: z.array(GoogleAIV1ContentSchema), - tools: z.array(z.object({})).max(0).optional(), - safetySettings: SafetySettingsSchema, - systemInstruction: GoogleAIV1ContentSchema.optional(), - // quick fix for SillyTavern, which uses camel case field names for everything - // except for system_instruction where it randomly uses snake case. - // google api evidently accepts either case. - system_instruction: GoogleAIV1ContentSchema.optional(), - generationConfig: z - .object({ - temperature: z.number().min(0).max(2).optional(), - maxOutputTokens: z.coerce - .number() - .int() - .optional() - .default(16) - .transform((v) => Math.min(v, 4096)), // TODO: Add config - candidateCount: z.literal(1).optional(), - topP: z.number().min(0).max(1).optional(), - topK: z.number().min(1).max(40).optional(), - stopSequences: z.array(z.string().max(500)).max(5).optional(), - thinkingConfig: z.object({includeThoughts: z.boolean().optional()}).optional(), - // Support for new Gemini 2.5 thinking budget - thinking_budget: z.union([ - z.literal("auto"), - z.number().int().min(0).max(24576) - ]).optional() - }) - .default({}), - }) - .strip(); -export type GoogleAIChatMessage = z.infer< - typeof GoogleAIV1GenerateContentSchema ->["contents"][0]; - -export const transformOpenAIToGoogleAI: APIFormatTransformer< - typeof GoogleAIV1GenerateContentSchema -> = async (req) => { - const { body } = req; - const result = OpenAIV1ChatCompletionSchema.safeParse({ - ...body, - model: "gpt-3.5-turbo", - }); - if (!result.success) { - req.log.warn( - { issues: result.error.issues, body }, - "Invalid OpenAI-to-Google AI request" - ); - throw result.error; +const getModelsResponse = () => { + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return modelsCache; } - const { messages, ...rest } = result.data; - const foundNames = new Set(); - const contents = messages - .map((m) => { - const role = m.role === "assistant" ? "model" : "user"; - // Detects character names so we can set stop sequences for them as Gemini - // is prone to continuing as the next character. - // If names are not available, we'll still try to prefix the message - // with generic names so we can set stops for them but they don't work - // as well as real names. - const text = flattenOpenAIMessageContent(m.content); - const propName = m.name?.trim(); - const textName = - m.role === "system" ? "" : text.match(/^(.{0,50}?): /)?.[1]?.trim(); - const name = - propName || textName || (role === "model" ? "Character" : "User"); + if (!config.googleAIKey) return { object: "list", data: [] }; - foundNames.add(name); + const keys = keyPool + .list() + .filter((k) => k.service === "google-ai") as GoogleAIKey[]; + if (keys.length === 0) { + modelsCache = { object: "list", data: [] }; + modelsCacheTime = new Date().getTime(); + return modelsCache; + } - // Prefixing messages with their character name seems to help avoid - // Gemini trying to continue as the next character, or at the very least - // ensures it will hit the stop sequence. Otherwise it will start a new - // paragraph and switch perspectives. - // The response will be very likely to include this prefix so frontends - // will need to strip it out. - const textPrefix = textName ? "" : `${name}: `; - return { - parts: [{ text: textPrefix + text }], - role: m.role === "assistant" ? ("model" as const) : ("user" as const), - }; - }) - .reduce((acc, msg) => { - const last = acc[acc.length - 1]; - if (last?.role === msg.role && 'text' in last.parts[0] && 'text' in msg.parts[0]) { - last.parts[0].text += "\n\n" + msg.parts[0].text; - } else { - acc.push(msg); - } - return acc; - }, []); + // Get all model IDs from keys, excluding any with "bard" in the name + const modelIds = Array.from( + new Set(keys.map((k) => k.modelIds).flat()) + ).filter((id) => id.startsWith("models/") && !id.includes("bard")); + + const models = modelIds.map((id) => ({ + id, + object: "model", + created: new Date().getTime(), + owned_by: "google", + permission: [], + root: "google", + parent: null, + })); - let stops = rest.stop - ? Array.isArray(rest.stop) - ? rest.stop - : [rest.stop] - : []; - stops.push(...Array.from(foundNames).map((name) => `\n${name}:`)); - stops = [...new Set(stops)].slice(0, 5); + modelsCache = { object: "list", data: models }; + modelsCacheTime = new Date().getTime(); - return { - model: req.body.model, - stream: rest.stream, - contents, - tools: [], - generationConfig: { - maxOutputTokens: rest.max_tokens, - stopSequences: stops, - topP: rest.top_p, - topK: 40, // openai schema doesn't have this, google ai defaults to 40 - temperature: rest.temperature, - }, - safetySettings: [ - { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" }, - { category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" }, - { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" }, - { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" }, - { category: "HARM_CATEGORY_CIVIC_INTEGRITY", threshold: "BLOCK_NONE" }, - ], - }; + return modelsCache; }; -export function containsImageContent(contents: GoogleAIChatMessage[]): boolean { - return contents.some(content => { - const parts = Array.isArray(content.parts) ? content.parts : [content.parts]; - return parts.some(part => 'inlineData' in part); - }); +// Function to fetch native models from Google AI API +const getNativeModelsResponse = async () => { + if (new Date().getTime() - nativeModelsCacheTime < 1000 * 60) { + return nativeModelsCache; + } + + if (!config.googleAIKey) return { models: [] }; + + const keys = keyPool + .list() + .filter((k) => k.service === "google-ai") as GoogleAIKey[]; + if (keys.length === 0) { + nativeModelsCache = { models: [] }; + nativeModelsCacheTime = new Date().getTime(); + return nativeModelsCache; + } + + try { + // Use the first available key to fetch models + const key = keys[0]; + const apiVersion = "v1beta"; // Use the latest API version + const url = `https://generativelanguage.googleapis.com/${apiVersion}/models`; + + const response = await axios.get(url, { + headers: { + "Content-Type": "application/json", + }, + params: { + key: key.key, + }, + }); + + // We'll update the model cache but won't attempt to update the keys + // This avoids type issues while still keeping our models list up to date + nativeModelsCache = response.data; + nativeModelsCacheTime = new Date().getTime(); + return nativeModelsCache; + } catch (error) { + console.error("Error fetching Google AI models:", error); + // Return empty model list on error + return { models: [] }; + } +}; + +const handleModelRequest: RequestHandler = (_req: Request, res: any) => { + res.status(200).json(getModelsResponse()); +}; + +// Native Gemini API model list request +const handleNativeModelRequest: RequestHandler = async (_req: Request, res: any) => { + try { + const modelsResponse = await getNativeModelsResponse(); + res.status(200).json(modelsResponse); + } catch (error) { + console.error("Error in handleNativeModelRequest:", error); + res.status(500).json({ error: "Failed to fetch models" }); + } +}; + +const googleAIBlockingResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + let newBody = body; + if (req.inboundApi === "openai") { + req.log.info("Transforming Google AI response to OpenAI format"); + newBody = transformGoogleAIResponse(body, req); + } + + res.status(200).json({ ...newBody, proxy: body.proxy }); +}; + +function transformGoogleAIResponse( + resBody: Record, + req: Request +): Record { + const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0); + const parts = resBody.candidates[0].content?.parts ?? [{ text: "" }]; + const content = parts[0].text.replace(/^(.{0,50}?): /, () => ""); + return { + id: "goo-" + v4(), + object: "chat.completion", + created: Date.now(), + model: req.body.model, + usage: { + prompt_tokens: req.promptTokens, + completion_tokens: req.outputTokens, + total_tokens: totalTokens, + }, + choices: [ + { + message: { role: "assistant", content }, + finish_reason: resBody.candidates[0].finishReason, + index: 0, + }, + ], + }; } + +const googleAIProxy = createQueuedProxyMiddleware({ + target: ({ signedRequest }: { signedRequest: any }) => { + if (!signedRequest) throw new Error("Must sign request before proxying"); + const { protocol, hostname} = signedRequest; + return `${protocol}//${hostname}`; + }, + mutations: [addGoogleAIKey, finalizeSignedRequest], + blockingResponseHandler: googleAIBlockingResponseHandler, +}); + +const googleAIRouter = Router(); +googleAIRouter.get("/v1/models", handleModelRequest); +googleAIRouter.get("/:apiVersion(v1alpha|v1beta)/models", handleNativeModelRequest); + +// Native Google AI chat completion endpoint +googleAIRouter.post( + "/:apiVersion(v1alpha|v1beta)/models/:modelId:(generateContent|streamGenerateContent)", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "google-ai", outApi: "google-ai", service: "google-ai" }, + { beforeTransform: [maybeReassignModel], afterTransform: [setStreamFlag] } + ), + googleAIProxy +); + +// OpenAI-to-Google AI compatibility endpoint. +googleAIRouter.post( + "/v1/chat/completions", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "openai", outApi: "google-ai", service: "google-ai" }, + { afterTransform: [maybeReassignModel] } + ), + googleAIProxy +); + +function setStreamFlag(req: Request) { + const isStreaming = req.url.includes("streamGenerateContent"); + if (isStreaming) { + req.body.stream = true; + req.isStreaming = true; + } else { + req.body.stream = false; + req.isStreaming = false; + } +} + +/** + * Replaces requests for non-Google AI models with gemini-1.5-pro-latest. + * Also strips models/ from the beginning of the model IDs. + **/ +function maybeReassignModel(req: Request) { + // Ensure model is on body as a lot of middleware will expect it. + const model = req.body.model || req.url.split("/").pop()?.split(":").shift(); + if (!model) { + throw new Error("You must specify a model with your request."); + } + req.body.model = model; + + const requested = model; + if (requested.startsWith("models/")) { + req.body.model = requested.slice("models/".length); + } + + if (requested.includes("gemini")) { + return; + } + + req.log.info({ requested }, "Reassigning model to gemini-1.5-pro-latest"); + req.body.model = "gemini-1.5-pro-latest"; +} + +export const googleAI = googleAIRouter;