From 372ad85283f3a2fdb57fc249937c8471e9bf0529 Mon Sep 17 00:00:00 2001
From: user
Date: Tue, 31 Dec 2024 10:16:04 +0000
Subject: [PATCH] Support camelCase Gemini params and validate vision

---
 .../transform-outbound-payload.ts             | 85 +++++++++++++++++++
 .../request/preprocessors/validate-vision.ts  |  5 +-
 src/shared/api-schemas/google-ai.ts           | 20 +++++
 3 files changed, 109 insertions(+), 1 deletion(-)

diff --git a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts
index 51035c6..70d1bdf 100644
--- a/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts
+++ b/src/proxy/middleware/request/preprocessors/transform-outbound-payload.ts
@@ -30,6 +30,7 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
   }
 
   applyMistralPromptFixes(req);
+  applyGoogleAIKeyTransforms(req);
 
   // Native prompts are those which were already provided by the client in the
   // target API format. We don't need to transform them.
@@ -87,3 +88,87 @@ function applyMistralPromptFixes(req: Request): void {
     }
   }
 }
+
+/** Converts a camelCase identifier to snake_case, e.g. `topK` -> `top_k`. */
+function toSnakeCase(str: string): string {
+  return str.replace(/[A-Z]/g, (letter) => `_${letter.toLowerCase()}`);
+}
+
+/**
+ * Request keys whose values are free-form JSON supplied by the caller
+ * (`function_call.args`, `function_response.response`). Nothing beneath them
+ * is a Gemini API field, so their contents are forwarded untouched.
+ */
+const GOOGLE_AI_OPAQUE_VALUE_KEYS = new Set(["args", "response"]);
+
+/**
+ * Request keys whose values are maps keyed by caller-chosen names (the
+ * `properties` of a function-declaration or response schema). The names are
+ * kept verbatim; each mapped value is a schema object and is still normalized.
+ */
+const GOOGLE_AI_NAME_MAP_KEYS = new Set(["properties"]);
+
+/**
+ * Rebuilds `obj` with camelCase API keys renamed to snake_case, recursing
+ * through nested objects and arrays.
+ *
+ * @param obj Any JSON-like value; primitives and `null` are returned as-is.
+ * @param hasTransformed Out-parameter; `value` is set to true once at least
+ *   one key has actually been renamed.
+ * @returns The normalized value; the input itself is never mutated.
+ */
+function transformKeysToSnakeCase(obj: any, hasTransformed = { value: false }): any {
+  if (Array.isArray(obj)) {
+    return obj.map((item) => transformKeysToSnakeCase(item, hasTransformed));
+  }
+
+  if (obj === null || typeof obj !== "object") {
+    return obj;
+  }
+
+  return Object.fromEntries(
+    Object.entries(obj).map(([key, value]) => {
+      const snakeKey = toSnakeCase(key);
+      if (snakeKey !== key) {
+        hasTransformed.value = true;
+      }
+
+      if (GOOGLE_AI_OPAQUE_VALUE_KEYS.has(snakeKey)) {
+        return [snakeKey, value];
+      }
+
+      const isNameMap =
+        GOOGLE_AI_NAME_MAP_KEYS.has(snakeKey) &&
+        value !== null &&
+        typeof value === "object" &&
+        !Array.isArray(value);
+      if (isNameMap) {
+        // Keep the caller's property names; normalize only the schemas.
+        const schemas = Object.entries(value).map(([name, schema]) => [
+          name,
+          transformKeysToSnakeCase(schema, hasTransformed),
+        ]);
+        return [snakeKey, Object.fromEntries(schemas)];
+      }
+
+      return [snakeKey, transformKeysToSnakeCase(value, hasTransformed)];
+    })
+  );
+}
+
+/**
+ * Normalizes the keys of a Gemini request body to snake_case.
+ *
+ * The Gemini API accepts both snake_case and camelCase field names even though
+ * its docs use snake_case, and some frontends (e.g. ST) send camelCase, so all
+ * API keys are normalized to snake_case here. No-op for other outbound APIs.
+ */
+function applyGoogleAIKeyTransforms(req: Request): void {
+  if (req.outboundApi !== "google-ai") return;
+
+  const hasTransformed = { value: false };
+  req.body = transformKeysToSnakeCase(req.body, hasTransformed);
+  if (hasTransformed.value) {
+    req.log.info("Applied Gemini camelCase -> snake_case transform");
+  }
+}
diff --git a/src/proxy/middleware/request/preprocessors/validate-vision.ts b/src/proxy/middleware/request/preprocessors/validate-vision.ts
index 5940222..705c2f9 100644
--- a/src/proxy/middleware/request/preprocessors/validate-vision.ts
+++ b/src/proxy/middleware/request/preprocessors/validate-vision.ts
@@ -3,6 +3,7 @@ import { assertNever } from "../../../../shared/utils";
 import { RequestPreprocessor } from "../index";
 import { containsImageContent as containsImageContentOpenAI } from "../../../../shared/api-schemas/openai";
 import { containsImageContent as containsImageContentAnthropic } from "../../../../shared/api-schemas/anthropic";
+import { containsImageContent as containsImageContentGoogleAI } from "../../../../shared/api-schemas/google-ai";
 import { ForbiddenError } from "../../../../shared/errors";
 
 /**
@@ -25,8 +26,10 @@ export const validateVision: RequestPreprocessor = async (req) => {
     case "anthropic-chat":
       hasImage = containsImageContentAnthropic(req.body.messages);
       break;
-    case "anthropic-text":
     case "google-ai":
+      hasImage = containsImageContentGoogleAI(req.body.contents);
+      break;
+    case "anthropic-text":
     case "mistral-ai":
     case "mistral-text":
     case "openai-image":
diff --git a/src/shared/api-schemas/google-ai.ts b/src/shared/api-schemas/google-ai.ts
index 3283e87..ef80c04 100644
--- a/src/shared/api-schemas/google-ai.ts
+++ b/src/shared/api-schemas/google-ai.ts
@@ -165,3 +165,23 @@ export const transformOpenAIToGoogleAI: APIFormatTransformer<
     ],
   };
 };
+
+/**
+ * Reports whether any message in a Gemini `contents` array has a part with
+ * inline media, under either the `inline_data` or `inlineData` spelling.
+ */
+export function containsImageContent(contents: GoogleAIChatMessage[]): boolean {
+  // Tolerate a missing or malformed `contents` instead of throwing.
+  if (!Array.isArray(contents)) return false;
+
+  return contents.some((content) => {
+    const parts = Array.isArray(content.parts) ? content.parts : [content.parts];
+    // `in` throws a TypeError on primitives and undefined, so only probe objects.
+    return parts.some(
+      (part) =>
+        typeof part === "object" &&
+        part !== null &&
+        ("inline_data" in part || "inlineData" in part)
+    );
+  });
+}