Add Google AI API (khanon/oai-reverse-proxy!57)
This commit is contained in:
+3
-3
@@ -34,10 +34,10 @@
|
||||
|
||||
# Which model types users are allowed to access.
|
||||
# The following model families are recognized:
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
|
||||
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
|
||||
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
|
||||
# generation, uncomment the line below and add 'dall-e' to the list.
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
|
||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
|
||||
|
||||
# URLs from which requests will be blocked.
|
||||
# BLOCKED_ORIGINS=reddit.com,9gag.com
|
||||
@@ -95,7 +95,7 @@
|
||||
# TOKEN_QUOTA_GPT4_TURBO=0
|
||||
# TOKEN_QUOTA_DALL_E=0
|
||||
# TOKEN_QUOTA_CLAUDE=0
|
||||
# TOKEN_QUOTA_BISON=0
|
||||
# TOKEN_QUOTA_GEMINI_PRO=0
|
||||
# TOKEN_QUOTA_AWS_CLAUDE=0
|
||||
|
||||
# How often to refresh token quotas. (hourly | daily)
|
||||
|
||||
@@ -35,7 +35,7 @@ Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL
|
||||
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e
|
||||
|
||||
# All models as of this writing
|
||||
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
|
||||
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,dall-e
|
||||
```
|
||||
|
||||
Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.
|
||||
|
||||
Generated
+34
@@ -36,6 +36,7 @@
|
||||
"sanitize-html": "^2.11.0",
|
||||
"sharp": "^0.32.6",
|
||||
"showdown": "^2.1.0",
|
||||
"stream-json": "^1.8.0",
|
||||
"tiktoken": "^1.0.10",
|
||||
"uuid": "^9.0.0",
|
||||
"zlib": "^1.0.5",
|
||||
@@ -51,6 +52,7 @@
|
||||
"@types/node-schedule": "^2.1.0",
|
||||
"@types/sanitize-html": "^2.9.0",
|
||||
"@types/showdown": "^2.0.0",
|
||||
"@types/stream-json": "^1.7.7",
|
||||
"@types/uuid": "^9.0.1",
|
||||
"concurrently": "^8.0.1",
|
||||
"esbuild": "^0.17.16",
|
||||
@@ -1185,6 +1187,25 @@
|
||||
"integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@types/stream-chain": {
|
||||
"version": "2.0.4",
|
||||
"resolved": "https://registry.npmjs.org/@types/stream-chain/-/stream-chain-2.0.4.tgz",
|
||||
"integrity": "sha512-V7TsWLHrx79KumkHqSD7F8eR6POpEuWb6PuXJ7s/dRHAf3uVst3Jkp1yZ5XqIfECZLQ4a28vBVstTErmsMBvaQ==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/stream-json": {
|
||||
"version": "1.7.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/stream-json/-/stream-json-1.7.7.tgz",
|
||||
"integrity": "sha512-hHG7cLQ09H/m9i0jzL6UJAeLLxIWej90ECn0svO4T8J0nGcl89xZDQ2ujT4WKlvg0GWkcxJbjIDzW/v7BYUM6Q==",
|
||||
"dev": true,
|
||||
"dependencies": {
|
||||
"@types/node": "*",
|
||||
"@types/stream-chain": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/uuid": {
|
||||
"version": "9.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz",
|
||||
@@ -5135,6 +5156,11 @@
|
||||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/stream-chain": {
|
||||
"version": "2.2.5",
|
||||
"resolved": "https://registry.npmjs.org/stream-chain/-/stream-chain-2.2.5.tgz",
|
||||
"integrity": "sha512-1TJmBx6aSWqZ4tx7aTpBDXK0/e2hhcNSTV8+CbFJtDjbb+I1mZ8lHit0Grw9GRT+6JbIrrDd8esncgBi8aBXGA=="
|
||||
},
|
||||
"node_modules/stream-events": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz",
|
||||
@@ -5144,6 +5170,14 @@
|
||||
"stubs": "^3.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/stream-json": {
|
||||
"version": "1.8.0",
|
||||
"resolved": "https://registry.npmjs.org/stream-json/-/stream-json-1.8.0.tgz",
|
||||
"integrity": "sha512-HZfXngYHUAr1exT4fxlbc1IOce1RYxp2ldeaf97LYCOPSoOqY/1Psp7iGvpb+6JIOgkra9zDYnPX01hGAHzEPw==",
|
||||
"dependencies": {
|
||||
"stream-chain": "^2.2.5"
|
||||
}
|
||||
},
|
||||
"node_modules/stream-shift": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz",
|
||||
|
||||
@@ -44,6 +44,7 @@
|
||||
"sanitize-html": "^2.11.0",
|
||||
"sharp": "^0.32.6",
|
||||
"showdown": "^2.1.0",
|
||||
"stream-json": "^1.8.0",
|
||||
"tiktoken": "^1.0.10",
|
||||
"uuid": "^9.0.0",
|
||||
"zlib": "^1.0.5",
|
||||
@@ -59,6 +60,7 @@
|
||||
"@types/node-schedule": "^2.1.0",
|
||||
"@types/sanitize-html": "^2.9.0",
|
||||
"@types/showdown": "^2.0.0",
|
||||
"@types/stream-json": "^1.7.7",
|
||||
"@types/uuid": "^9.0.1",
|
||||
"concurrently": "^8.0.1",
|
||||
"esbuild": "^0.17.16",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const axios = require("axios");
|
||||
|
||||
const concurrentRequests = 5;
|
||||
const concurrentRequests = 75;
|
||||
const headers = {
|
||||
Authorization: "Bearer test",
|
||||
"Content-Type": "application/json",
|
||||
@@ -16,7 +16,7 @@ const payload = {
|
||||
const makeRequest = async (i) => {
|
||||
try {
|
||||
const response = await axios.post(
|
||||
"http://localhost:7860/proxy/azure/openai/v1/chat/completions",
|
||||
"http://localhost:7860/proxy/google-ai/v1/chat/completions",
|
||||
payload,
|
||||
{ headers }
|
||||
);
|
||||
@@ -25,7 +25,8 @@ const makeRequest = async (i) => {
|
||||
response.data
|
||||
);
|
||||
} catch (error) {
|
||||
console.error(`Error in req ${i}:`, error.message);
|
||||
const msg = error.response
|
||||
console.error(`Error in req ${i}:`, error.message, msg || "");
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -200,7 +200,7 @@ router.post("/maintenance", (req, res) => {
|
||||
keyPool.recheck("anthropic");
|
||||
const size = keyPool
|
||||
.list()
|
||||
.filter((k) => k.service !== "google-palm").length;
|
||||
.filter((k) => k.service !== "google-ai").length;
|
||||
flash.type = "success";
|
||||
flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`;
|
||||
break;
|
||||
|
||||
+10
-6
@@ -19,8 +19,12 @@ type Config = {
|
||||
openaiKey?: string;
|
||||
/** Comma-delimited list of Anthropic API keys. */
|
||||
anthropicKey?: string;
|
||||
/** Comma-delimited list of Google PaLM API keys. */
|
||||
googlePalmKey?: string;
|
||||
/**
|
||||
* Comma-delimited list of Google AI API keys. Note that these are not the
|
||||
* same as the GCP keys/credentials used for Vertex AI; the models are the
|
||||
* same but the APIs are different. Vertex is the GCP product for enterprise.
|
||||
**/
|
||||
googleAIKey?: string;
|
||||
/**
|
||||
* Comma-delimited list of AWS credentials. Each credential item should be a
|
||||
* colon-delimited list of access key, secret key, and AWS region.
|
||||
@@ -197,7 +201,7 @@ export const config: Config = {
|
||||
port: getEnvWithDefault("PORT", 7860),
|
||||
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
|
||||
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
|
||||
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
|
||||
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
|
||||
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
|
||||
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
|
||||
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
|
||||
@@ -229,7 +233,7 @@ export const config: Config = {
|
||||
"gpt4-32k",
|
||||
"gpt4-turbo",
|
||||
"claude",
|
||||
"bison",
|
||||
"gemini-pro",
|
||||
"aws-claude",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
@@ -366,7 +370,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
|
||||
"logLevel",
|
||||
"openaiKey",
|
||||
"anthropicKey",
|
||||
"googlePalmKey",
|
||||
"googleAIKey",
|
||||
"awsCredentials",
|
||||
"azureCredentials",
|
||||
"proxyKey",
|
||||
@@ -433,7 +437,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
|
||||
[
|
||||
"OPENAI_KEY",
|
||||
"ANTHROPIC_KEY",
|
||||
"GOOGLE_PALM_KEY",
|
||||
"GOOGLE_AI_KEY",
|
||||
"AWS_CREDENTIALS",
|
||||
"AZURE_CREDENTIALS",
|
||||
].includes(String(env))
|
||||
|
||||
+35
-29
@@ -7,7 +7,7 @@ import {
|
||||
AnthropicKey,
|
||||
AwsBedrockKey,
|
||||
AzureOpenAIKey,
|
||||
GooglePalmKey,
|
||||
GoogleAIKey,
|
||||
keyPool,
|
||||
OpenAIKey,
|
||||
} from "./shared/key-management";
|
||||
@@ -33,8 +33,8 @@ const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
|
||||
k.service === "azure";
|
||||
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
|
||||
k.service === "anthropic";
|
||||
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
|
||||
k.service === "google-palm";
|
||||
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
|
||||
k.service === "google-ai";
|
||||
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
|
||||
|
||||
type ModelAggregates = {
|
||||
@@ -54,7 +54,7 @@ type ServiceAggregates = {
|
||||
openaiKeys?: number;
|
||||
openaiOrgs?: number;
|
||||
anthropicKeys?: number;
|
||||
palmKeys?: number;
|
||||
googleAIKeys?: number;
|
||||
awsKeys?: number;
|
||||
azureKeys?: number;
|
||||
proompts: number;
|
||||
@@ -100,7 +100,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
||||
|
||||
const openaiKeys = serviceStats.get("openaiKeys") || 0;
|
||||
const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
|
||||
const palmKeys = serviceStats.get("palmKeys") || 0;
|
||||
const googleAIKeys = serviceStats.get("googleAIKeys") || 0;
|
||||
const awsKeys = serviceStats.get("awsKeys") || 0;
|
||||
const azureKeys = serviceStats.get("azureKeys") || 0;
|
||||
const proompts = serviceStats.get("proompts") || 0;
|
||||
@@ -116,7 +116,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
||||
? { ["openai-image"]: baseUrl + "/openai-image" }
|
||||
: {}),
|
||||
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
|
||||
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
|
||||
...(googleAIKeys ? { "google-ai": baseUrl + "/google-ai" } : {}),
|
||||
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
|
||||
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
|
||||
};
|
||||
@@ -127,7 +127,13 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
||||
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
|
||||
};
|
||||
|
||||
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
|
||||
const keyInfo = {
|
||||
openaiKeys,
|
||||
anthropicKeys,
|
||||
googleAIKeys,
|
||||
awsKeys,
|
||||
azureKeys,
|
||||
};
|
||||
for (const key of Object.keys(keyInfo)) {
|
||||
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
|
||||
}
|
||||
@@ -135,7 +141,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
|
||||
const providerInfo = {
|
||||
...(openaiKeys ? getOpenAIInfo() : {}),
|
||||
...(anthropicKeys ? getAnthropicInfo() : {}),
|
||||
...(palmKeys ? getPalmInfo() : {}),
|
||||
...(googleAIKeys ? getGoogleAIInfo() : {}),
|
||||
...(awsKeys ? getAwsInfo() : {}),
|
||||
...(azureKeys ? getAzureInfo() : {}),
|
||||
};
|
||||
@@ -197,7 +203,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
||||
increment(serviceStats, "proompts", k.promptCount);
|
||||
increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
|
||||
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
|
||||
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
|
||||
increment(serviceStats, "googleAIKeys", k.service === "google-ai" ? 1 : 0);
|
||||
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
|
||||
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
|
||||
|
||||
@@ -251,14 +257,14 @@ function addKeyToAggregates(k: KeyPoolKey) {
|
||||
);
|
||||
break;
|
||||
}
|
||||
case "google-palm": {
|
||||
if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type");
|
||||
const family = "bison";
|
||||
sumTokens += k.bisonTokens;
|
||||
sumCost += getTokenCostUsd(family, k.bisonTokens);
|
||||
case "google-ai": {
|
||||
if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
|
||||
const family = "gemini-pro";
|
||||
sumTokens += k["gemini-proTokens"];
|
||||
sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
|
||||
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
|
||||
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
|
||||
increment(modelStats, `${family}__tokens`, k.bisonTokens);
|
||||
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
|
||||
break;
|
||||
}
|
||||
case "aws": {
|
||||
@@ -388,26 +394,26 @@ function getAnthropicInfo() {
|
||||
};
|
||||
}
|
||||
|
||||
function getPalmInfo() {
|
||||
const bisonInfo: Partial<ModelAggregates> = {
|
||||
active: modelStats.get("bison__active") || 0,
|
||||
revoked: modelStats.get("bison__revoked") || 0,
|
||||
function getGoogleAIInfo() {
|
||||
const googleAIInfo: Partial<ModelAggregates> = {
|
||||
active: modelStats.get("gemini-pro__active") || 0,
|
||||
revoked: modelStats.get("gemini-pro__revoked") || 0,
|
||||
};
|
||||
|
||||
const queue = getQueueInformation("bison");
|
||||
bisonInfo.queued = queue.proomptersInQueue;
|
||||
bisonInfo.queueTime = queue.estimatedQueueTime;
|
||||
const queue = getQueueInformation("gemini-pro");
|
||||
googleAIInfo.queued = queue.proomptersInQueue;
|
||||
googleAIInfo.queueTime = queue.estimatedQueueTime;
|
||||
|
||||
const tokens = modelStats.get("bison__tokens") || 0;
|
||||
const cost = getTokenCostUsd("bison", tokens);
|
||||
const tokens = modelStats.get("gemini-pro__tokens") || 0;
|
||||
const cost = getTokenCostUsd("gemini-pro", tokens);
|
||||
|
||||
return {
|
||||
bison: {
|
||||
gemini: {
|
||||
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
|
||||
activeKeys: bisonInfo.active,
|
||||
revokedKeys: bisonInfo.revoked,
|
||||
proomptersInQueue: bisonInfo.queued,
|
||||
estimatedQueueTime: bisonInfo.queueTime,
|
||||
activeKeys: googleAIInfo.active,
|
||||
revokedKeys: googleAIInfo.revoked,
|
||||
proomptersInQueue: googleAIInfo.queued,
|
||||
estimatedQueueTime: googleAIInfo.queueTime,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { v4 } from "uuid";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
createOnProxyReqHandler,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeSignedRequest,
|
||||
forceModel,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
|
||||
|
||||
/** Most recently built OpenAI-style model list, or null if none was built yet. */
let modelsCache: any = null;
/** Time (ms since epoch) at which `modelsCache` was last rebuilt. */
let modelsCacheTime = 0;

/**
 * Builds the OpenAI-compatible `/v1/models` payload for the Google AI models
 * served by this router.
 *
 * The list is rebuilt at most once per minute; within that window the cached
 * object is returned. When no Google AI key is configured an empty list is
 * returned and nothing is cached.
 *
 * @returns An object of the form `{ object: "list", data: [...] }`.
 */
const getModelsResponse = () => {
  const now = Date.now();
  const cacheTtlMs = 1000 * 60;

  // Only serve from the cache once something has actually been cached.
  if (modelsCache && now - modelsCacheTime < cacheTtlMs) {
    return modelsCache;
  }

  if (!config.googleAIKey) return { object: "list", data: [] };

  const googleAIVariants = ["gemini-pro"];

  // OpenAI model objects report `created` as a Unix timestamp in seconds.
  const created = Math.floor(now / 1000);

  const models = googleAIVariants.map((id) => ({
    id,
    object: "model",
    created,
    owned_by: "google",
    permission: [],
    root: "google",
    parent: null,
  }));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = now;

  return modelsCache;
};
|
||||
|
||||
/** Express handler for `GET /v1/models`: replies 200 with the model list. */
const handleModelRequest: RequestHandler = (_req, res) => {
  const modelList = getModelsResponse();
  res.status(200).json(modelList);
};
|
||||
|
||||
/**
 * Final response handler for Google AI requests. Only used for non-streaming
 * requests.
 *
 * Converts the upstream body to OpenAI format when the client spoke the OpenAI
 * API, then attaches proxy metadata (`proxy_note` when prompt logging is
 * enabled, `proxy_tokenizer` when tokenizer info is present on the request)
 * and sends the result with status 200.
 *
 * @throws Error if the upstream body was not parsed into an object.
 */
const googleAIResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  // `typeof null` is "object", so null has to be rejected explicitly.
  if (typeof body !== "object" || body === null) {
    throw new Error("Expected body to be an object");
  }

  let responseBody: Record<string, any> = body;

  // Transform first: the transform builds a brand-new object, so any proxy
  // metadata attached to the upstream body beforehand would be discarded.
  if (req.inboundApi === "openai") {
    req.log.info("Transforming Google AI response to OpenAI format");
    responseBody = transformGoogleAIResponse(responseBody, req);
  }

  if (config.promptLogging) {
    const host = req.get("host");
    responseBody.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.tokenizerInfo) {
    responseBody.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(responseBody);
};
|
||||
|
||||
/**
 * Converts a non-streaming Google AI response body into an OpenAI-style
 * `chat.completion` object.
 *
 * Token usage is taken from the counts already stored on the request
 * (`req.promptTokens` / `req.outputTokens`), not from the upstream response.
 *
 * @param googleAIResp Parsed JSON body returned by the Google AI API.
 * @param req The inbound request; supplies the model name and token counts.
 * @returns The OpenAI-format response body.
 */
function transformGoogleAIResponse(
  googleAIResp: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);

  // The upstream response may contain no candidate, or a candidate without
  // content (e.g. when generation was blocked), so read it defensively
  // instead of throwing a TypeError.
  const candidate = googleAIResp.candidates?.[0];
  const rawParts = candidate?.content?.parts;
  const parts: any[] = Array.isArray(rawParts) ? rawParts : [];
  // A reply may be split across several text parts; concatenate them.
  const content = parts.map((part) => part?.text ?? "").join("");

  return {
    id: "goo-" + v4(),
    object: "chat.completion",
    // OpenAI reports `created` as a Unix timestamp in seconds.
    created: Math.floor(Date.now() / 1000),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: {
          role: "assistant",
          content,
        },
        // Fall back to the prompt-level block reason when no candidate exists.
        finish_reason:
          candidate?.finishReason ??
          googleAIResp.promptFeedback?.blockReason ??
          null,
        index: 0,
      },
    ],
  };
}
|
||||
|
||||
/**
 * Queue-backed proxy for Google AI. `addGoogleAIKey` runs before proxying and
 * stores the upstream request description on `req.signedRequest`, which the
 * router callback below uses to build the real target URL.
 */
const googleAIProxy = createQueueMiddleware({
  beforeProxy: addGoogleAIKey,
  proxyMiddleware: createProxyMiddleware({
    // Placeholder only; `router` supplies the actual target per request.
    target: "bad-target-will-be-rewritten",
    router: (req) => {
      const signed = req.signedRequest;
      return `${signed.protocol}//${signed.hostname}${signed.path}`;
    },
    changeOrigin: true,
    // The response is written by the proxyRes handler chain, not by the proxy.
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({
        pipeline: [finalizeSignedRequest],
      }),
      proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
      error: handleProxyError,
    },
  }),
});
|
||||
|
||||
const googleAIRouter = Router();

// Model listing in OpenAI format.
googleAIRouter.get("/v1/models", handleModelRequest);

// OpenAI-to-Google AI compatibility endpoint: requests arrive in OpenAI chat
// format, are transformed to the Google AI format, and the model is forced to
// "gemini-pro" after the transform.
const openaiToGoogleAIPreprocessor = createPreprocessorMiddleware(
  { inApi: "openai", outApi: "google-ai", service: "google-ai" },
  { afterTransform: [forceModel("gemini-pro")] }
);
googleAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  openaiToGoogleAIPreprocessor,
  googleAIProxy
);

export const googleAI = googleAIRouter;
|
||||
@@ -177,8 +177,11 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
|
||||
return "";
|
||||
}
|
||||
return body.completion.trim();
|
||||
case "google-palm":
|
||||
return body.candidates[0].output;
|
||||
case "google-ai":
|
||||
if ("choices" in body) {
|
||||
return body.choices[0].message.content;
|
||||
}
|
||||
return body.candidates[0].content.parts[0].text;
|
||||
case "openai-image":
|
||||
return body.data?.map((item: any) => item.url).join("\n");
|
||||
default:
|
||||
@@ -197,7 +200,7 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
|
||||
case "anthropic":
|
||||
// Anthropic confirms the model in the response, but AWS Claude doesn't.
|
||||
return body.model || req.body.model;
|
||||
case "google-palm":
|
||||
case "google-ai":
|
||||
// Google doesn't confirm the model in the response.
|
||||
return req.body.model;
|
||||
default:
|
||||
|
||||
@@ -29,7 +29,9 @@ export const createOnProxyReqHandler = ({
|
||||
// The streaming flag must be set before any other onProxyReq handler runs,
|
||||
// as it may influence the behavior of subsequent handlers.
|
||||
// Image generation requests can't be streamed.
|
||||
req.isStreaming = req.body.stream === true || req.body.stream === "true";
|
||||
// TODO: this flag is set in too many places
|
||||
req.isStreaming =
|
||||
req.isStreaming || req.body.stream === true || req.body.stream === "true";
|
||||
req.body.stream = req.isStreaming;
|
||||
|
||||
try {
|
||||
|
||||
@@ -31,10 +31,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
||||
case "anthropic":
|
||||
assignedKey = keyPool.get("claude-v1");
|
||||
break;
|
||||
case "google-palm":
|
||||
assignedKey = keyPool.get("text-bison-001");
|
||||
delete req.body.stream;
|
||||
break;
|
||||
case "openai-text":
|
||||
assignedKey = keyPool.get("gpt-3.5-turbo-instruct");
|
||||
break;
|
||||
@@ -42,6 +38,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
||||
throw new Error(
|
||||
"OpenAI Chat as an API translation target is not supported"
|
||||
);
|
||||
case "google-ai":
|
||||
throw new Error("add-key should not be used for this model.");
|
||||
case "openai-image":
|
||||
assignedKey = keyPool.get("dall-e-3");
|
||||
break;
|
||||
@@ -73,21 +71,13 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
|
||||
}
|
||||
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
|
||||
break;
|
||||
case "google-palm":
|
||||
const originalPath = proxyReq.path;
|
||||
proxyReq.path = originalPath.replace(
|
||||
/(\?.*)?$/,
|
||||
`?key=${assignedKey.key}`
|
||||
);
|
||||
break;
|
||||
case "azure":
|
||||
const azureKey = assignedKey.key;
|
||||
proxyReq.setHeader("api-key", azureKey);
|
||||
break;
|
||||
case "aws":
|
||||
throw new Error(
|
||||
"add-key should not be used for AWS security credentials. Use sign-aws-request instead."
|
||||
);
|
||||
case "google-ai":
|
||||
throw new Error("add-key should not be used for this service.");
|
||||
default:
|
||||
assertNever(assignedKey.service);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import type { HPMRequestCallback } from "../index";
|
||||
|
||||
/**
|
||||
* For AWS/Azure requests, the body is signed earlier in the request pipeline,
|
||||
* before the proxy middleware. This function just assigns the path and headers
|
||||
* to the proxy request.
|
||||
* For AWS/Azure/Google requests, the body is signed earlier in the request
|
||||
* pipeline, before the proxy middleware. This function just assigns the path
|
||||
* and headers to the proxy request.
|
||||
*/
|
||||
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
|
||||
if (!req.signedRequest) {
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
import { keyPool } from "../../../../shared/key-management";
|
||||
import { RequestPreprocessor } from "../index";
|
||||
|
||||
/**
 * Request preprocessor that assigns a Google AI key from the key pool to the
 * request and stores the upstream request description on `req.signedRequest`.
 *
 * The upstream method is `streamGenerateContent` for streaming requests and
 * `generateContent` otherwise; the `stream` flag is removed from the body
 * before it is serialized.
 *
 * @throws Error if the request is not an OpenAI-to-Google AI request for the
 *   "google-ai" service, or if the body does not name a model.
 */
export const addGoogleAIKey: RequestPreprocessor = (req) => {
  const isGoogleAIRequest =
    req.inboundApi === "openai" &&
    req.outboundApi === "google-ai" &&
    req.service === "google-ai";
  if (!isGoogleAIRequest) {
    throw new Error("addGoogleAIKey called on invalid request");
  }

  const model = req.body?.model;
  if (!model) {
    throw new Error("You must specify a model with your request.");
  }

  req.key = keyPool.get(model);
  req.log.info(
    { key: req.key.hash, model },
    "Assigned Google AI API key to request"
  );

  // Upstream endpoints:
  //   https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
  //   https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=$API_KEY
  req.isStreaming = req.isStreaming || req.body.stream;
  // `stream` only selects the endpoint; it is not sent in the upstream body.
  delete req.body.stream;

  const host = "generativelanguage.googleapis.com";
  const rpc = req.isStreaming ? "streamGenerateContent" : "generateContent";

  req.signedRequest = {
    method: "POST",
    protocol: "https:",
    hostname: host,
    path: `/v1beta/models/${model}:${rpc}?key=${req.key.key}`,
    headers: {
      host,
      "content-type": "application/json",
    },
    body: JSON.stringify(req.body),
  };
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { RequestPreprocessor } from "../index";
|
||||
import { countTokens } from "../../../../shared/tokenization";
|
||||
import { assertNever } from "../../../../shared/utils";
|
||||
import type { OpenAIChatMessage } from "./transform-outbound-payload";
|
||||
import type { GoogleAIChatMessage, OpenAIChatMessage } from "./transform-outbound-payload";
|
||||
|
||||
/**
|
||||
* Given a request with an already-transformed body, counts the number of
|
||||
@@ -30,9 +30,9 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
case "google-palm": {
|
||||
req.outputTokens = req.body.maxOutputTokens;
|
||||
const prompt: string = req.body.prompt.text;
|
||||
case "google-ai": {
|
||||
req.outputTokens = req.body.generationConfig.maxOutputTokens;
|
||||
const prompt: GoogleAIChatMessage[] = req.body.contents;
|
||||
result = await countTokens({ req, prompt, service });
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ function getPromptFromRequest(req: Request) {
|
||||
case "openai-text":
|
||||
case "openai-image":
|
||||
return body.prompt;
|
||||
case "google-palm":
|
||||
case "google-ai":
|
||||
return body.prompt.text;
|
||||
default:
|
||||
assertNever(service);
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import { Request } from "express";
|
||||
import { z } from "zod";
|
||||
import { config } from "../../../../config";
|
||||
import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
|
||||
import {
|
||||
isTextGenerationRequest,
|
||||
isImageGenerationRequest,
|
||||
} from "../../common";
|
||||
import { RequestPreprocessor } from "../index";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
|
||||
@@ -121,10 +124,19 @@ const OpenAIV1ImagesGenerationSchema = z.object({
|
||||
user: z.string().optional(),
|
||||
});
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
|
||||
const PalmV1GenerateTextSchema = z.object({
|
||||
model: z.string(),
|
||||
prompt: z.object({ text: z.string() }),
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
|
||||
const GoogleAIV1GenerateContentSchema = z.object({
|
||||
model: z.string(), //actually specified in path but we need it for the router
|
||||
stream: z.boolean().optional().default(false), // also used for router
|
||||
contents: z.array(
|
||||
z.object({
|
||||
parts: z.array(z.object({ text: z.string() })),
|
||||
role: z.enum(["user", "model"]),
|
||||
})
|
||||
),
|
||||
tools: z.array(z.object({})).max(0).optional(),
|
||||
safetySettings: z.array(z.object({})).max(0).optional(),
|
||||
generationConfig: z.object({
|
||||
temperature: z.number().optional(),
|
||||
maxOutputTokens: z.coerce
|
||||
.number()
|
||||
@@ -135,16 +147,20 @@ const PalmV1GenerateTextSchema = z.object({
|
||||
candidateCount: z.literal(1).optional(),
|
||||
topP: z.number().optional(),
|
||||
topK: z.number().optional(),
|
||||
safetySettings: z.array(z.object({})).max(0).optional(),
|
||||
stopSequences: z.array(z.string()).max(5).optional(),
|
||||
}),
|
||||
});
|
||||
|
||||
export type GoogleAIChatMessage = z.infer<
|
||||
typeof GoogleAIV1GenerateContentSchema
|
||||
>["contents"][0];
|
||||
|
||||
const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
||||
anthropic: AnthropicV1CompleteSchema,
|
||||
openai: OpenAIV1ChatCompletionSchema,
|
||||
"openai-text": OpenAIV1TextCompletionSchema,
|
||||
"openai-image": OpenAIV1ImagesGenerationSchema,
|
||||
"google-palm": PalmV1GenerateTextSchema,
|
||||
"google-ai": GoogleAIV1GenerateContentSchema,
|
||||
};
|
||||
|
||||
/** Transforms an incoming request body to one that matches the target API. */
|
||||
@@ -174,8 +190,8 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (req.inboundApi === "openai" && req.outboundApi === "google-palm") {
|
||||
req.body = openaiToPalm(req);
|
||||
if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
|
||||
req.body = openaiToGoogleAI(req);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -310,7 +326,9 @@ function openaiToOpenaiImage(req: Request) {
|
||||
return OpenAIV1ImagesGenerationSchema.parse(transformed);
|
||||
}
|
||||
|
||||
function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
||||
function openaiToGoogleAI(
|
||||
req: Request
|
||||
): z.infer<typeof GoogleAIV1GenerateContentSchema> {
|
||||
const { body } = req;
|
||||
const result = OpenAIV1ChatCompletionSchema.safeParse({
|
||||
...body,
|
||||
@@ -319,13 +337,16 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
||||
if (!result.success) {
|
||||
req.log.warn(
|
||||
{ issues: result.error.issues, body },
|
||||
"Invalid OpenAI-to-Palm request"
|
||||
"Invalid OpenAI-to-Google AI request"
|
||||
);
|
||||
throw result.error;
|
||||
}
|
||||
|
||||
const { messages, ...rest } = result.data;
|
||||
const prompt = flattenOpenAIChatMessages(messages);
|
||||
const contents = messages.map((m) => ({
|
||||
parts: [{ text: flattenOpenAIMessageContent(m.content) }],
|
||||
role: m.role === "user" ? "user" as const : "model" as const,
|
||||
}));
|
||||
|
||||
let stops = rest.stop
|
||||
? Array.isArray(rest.stop)
|
||||
@@ -339,20 +360,22 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
|
||||
z.array(z.string()).max(5).parse(stops);
|
||||
|
||||
return {
|
||||
prompt: { text: prompt },
|
||||
model: "gemini-pro",
|
||||
stream: rest.stream,
|
||||
contents,
|
||||
tools: [],
|
||||
generationConfig: {
|
||||
maxOutputTokens: rest.max_tokens,
|
||||
stopSequences: stops,
|
||||
model: "text-bison-001",
|
||||
topP: rest.top_p,
|
||||
topK: 40, // openai schema doesn't have this, google ai defaults to 40
|
||||
temperature: rest.temperature,
|
||||
},
|
||||
safetySettings: [
|
||||
{ category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_DEROGATORY", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_TOXICITY", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_VIOLENCE", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_SEXUAL", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_MEDICAL", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_DANGEROUS", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
|
||||
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
|
||||
],
|
||||
};
|
||||
}
|
||||
@@ -428,3 +451,4 @@ function flattenOpenAIMessageContent(
|
||||
.join("\n")
|
||||
: content;
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import { RequestPreprocessor } from "../index";
|
||||
|
||||
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
|
||||
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
|
||||
const BISON_MAX_CONTEXT = 8100;
|
||||
const GOOGLE_AI_MAX_CONTEXT = 32000;
|
||||
|
||||
/**
|
||||
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
|
||||
@@ -31,8 +31,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
||||
case "anthropic":
|
||||
proxyMax = CLAUDE_MAX_CONTEXT;
|
||||
break;
|
||||
case "google-palm":
|
||||
proxyMax = BISON_MAX_CONTEXT;
|
||||
case "google-ai":
|
||||
proxyMax = GOOGLE_AI_MAX_CONTEXT;
|
||||
break;
|
||||
case "openai-image":
|
||||
return;
|
||||
@@ -62,8 +62,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
|
||||
modelMax = 100000;
|
||||
} else if (model.match(/^claude-2/)) {
|
||||
modelMax = 200000;
|
||||
} else if (model.match(/^text-bison-\d{3}$/)) {
|
||||
modelMax = BISON_MAX_CONTEXT;
|
||||
} else if (model.match(/^gemini-\d{3}$/)) {
|
||||
modelMax = GOOGLE_AI_MAX_CONTEXT;
|
||||
} else if (model.match(/^anthropic\.claude/)) {
|
||||
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
|
||||
modelMax = 100000;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import express from "express";
|
||||
import { pipeline } from "stream";
|
||||
import { promisify } from "util";
|
||||
import {
|
||||
@@ -59,7 +58,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
|
||||
const prefersNativeEvents = req.inboundApi === req.outboundApi;
|
||||
const contentType = proxyRes.headers["content-type"];
|
||||
|
||||
const adapter = new SSEStreamAdapter({ contentType });
|
||||
const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
|
||||
const aggregator = new EventAggregator({ format: req.outboundApi });
|
||||
const transformer = new SSEMessageTransformer({
|
||||
inputFormat: req.outboundApi,
|
||||
|
||||
@@ -288,7 +288,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
// For Anthropic, this is usually due to missing preamble.
|
||||
switch (service) {
|
||||
case "openai":
|
||||
case "google-palm":
|
||||
case "google-ai":
|
||||
case "azure":
|
||||
const filteredCodes = ["content_policy_violation", "content_filter"];
|
||||
if (filteredCodes.includes(errorPayload.error?.code)) {
|
||||
@@ -350,8 +350,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
case "azure":
|
||||
handleAzureRateLimitError(req, errorPayload);
|
||||
break;
|
||||
case "google-palm":
|
||||
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
|
||||
case "google-ai":
|
||||
handleGoogleAIRateLimitError(req, errorPayload);
|
||||
break;
|
||||
default:
|
||||
assertNever(service);
|
||||
@@ -373,8 +373,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
|
||||
case "anthropic":
|
||||
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
|
||||
break;
|
||||
case "google-palm":
|
||||
errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
|
||||
case "google-ai":
|
||||
errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
|
||||
break;
|
||||
case "aws":
|
||||
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
|
||||
@@ -529,6 +529,23 @@ function handleAzureRateLimitError(
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Handles a rate limit error returned by the Google AI API.
 *
 * Example upstream payload:
 * {"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}}
 *
 * @param req The proxied request; its key is marked as rate limited and the
 *   request is re-enqueued when the upstream status is `RESOURCE_EXHAUSTED`.
 * @param errorPayload The parsed upstream error body. For any other status a
 *   `proxy_note` is attached to it and the function returns normally.
 * @throws RetryableError after the request has been re-enqueued.
 */
function handleGoogleAIRateLimitError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
  const upstreamStatus = errorPayload.error?.status;

  if (upstreamStatus === "RESOURCE_EXHAUSTED") {
    keyPool.markRateLimited(req.key!);
    reenqueueRequest(req);
    throw new RetryableError("Rate-limited request re-enqueued.");
  }

  // Anything else is not a status we know how to retry; surface it to the user.
  errorPayload.proxy_note = `Unrecognized rate limit error from Google AI (${upstreamStatus}). Please report this.`;
}
|
||||
|
||||
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
|
||||
if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
|
||||
const model = req.body.model;
|
||||
|
||||
@@ -73,7 +73,7 @@ const getPromptForRequest = (
|
||||
};
|
||||
case "anthropic":
|
||||
return req.body.prompt;
|
||||
case "google-palm":
|
||||
case "google-ai":
|
||||
return req.body.prompt.text;
|
||||
default:
|
||||
assertNever(req.outboundApi);
|
||||
|
||||
@@ -27,12 +27,12 @@ export class EventAggregator {
|
||||
getFinalResponse() {
|
||||
switch (this.format) {
|
||||
case "openai":
|
||||
case "google-ai":
|
||||
return mergeEventsForOpenAIChat(this.events);
|
||||
case "openai-text":
|
||||
return mergeEventsForOpenAIText(this.events);
|
||||
case "anthropic":
|
||||
return mergeEventsForAnthropic(this.events);
|
||||
case "google-palm":
|
||||
case "openai-image":
|
||||
throw new Error(`SSE aggregation not supported for ${this.format}`);
|
||||
default:
|
||||
|
||||
@@ -25,6 +25,8 @@ export type StreamingCompletionTransformer = (
|
||||
export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
|
||||
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
|
||||
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
|
||||
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
|
||||
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
|
||||
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
|
||||
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
|
||||
export { mergeEventsForAnthropic } from "./aggregators/anthropic";
|
||||
|
||||
@@ -7,9 +7,10 @@ import {
|
||||
anthropicV2ToOpenAI,
|
||||
OpenAIChatCompletionStreamEvent,
|
||||
openAITextToOpenAIChat,
|
||||
googleAIToOpenAI,
|
||||
passthroughToOpenAI,
|
||||
StreamingCompletionTransformer,
|
||||
} from "./index";
|
||||
import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
|
||||
|
||||
const genlog = logger.child({ module: "sse-transformer" });
|
||||
|
||||
@@ -111,7 +112,8 @@ function getTransformer(
|
||||
return version === "2023-01-01"
|
||||
? anthropicV1ToOpenAI
|
||||
: anthropicV2ToOpenAI;
|
||||
case "google-palm":
|
||||
case "google-ai":
|
||||
return googleAIToOpenAI;
|
||||
case "openai-image":
|
||||
throw new Error(`SSE transformation not supported for ${responseApi}`);
|
||||
default:
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
import { Transform, TransformOptions } from "stream";
|
||||
|
||||
import { StringDecoder } from "string_decoder";
|
||||
// @ts-ignore
|
||||
import { Parser } from "lifion-aws-event-stream";
|
||||
import { logger } from "../../../../logger";
|
||||
import { RetryableError } from "../index";
|
||||
import { APIFormat } from "../../../../shared/key-management";
|
||||
import StreamArray from "stream-json/streamers/StreamArray";
|
||||
|
||||
const log = logger.child({ module: "sse-stream-adapter" });
|
||||
|
||||
type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
|
||||
type SSEStreamAdapterOptions = TransformOptions & {
|
||||
contentType?: string;
|
||||
api: APIFormat;
|
||||
};
|
||||
type AwsEventStreamMessage = {
|
||||
headers: {
|
||||
":message-type": "event" | "exception";
|
||||
@@ -22,7 +28,9 @@ type AwsEventStreamMessage = {
|
||||
*/
|
||||
export class SSEStreamAdapter extends Transform {
|
||||
private readonly isAwsStream;
|
||||
private readonly isGoogleStream;
|
||||
private parser = new Parser();
|
||||
private jsonStream = StreamArray.withParser();
|
||||
private partialMessage = "";
|
||||
private decoder = new StringDecoder("utf8");
|
||||
|
||||
@@ -30,6 +38,7 @@ export class SSEStreamAdapter extends Transform {
|
||||
super(options);
|
||||
this.isAwsStream =
|
||||
options?.contentType === "application/vnd.amazon.eventstream";
|
||||
this.isGoogleStream = options?.api === "google-ai";
|
||||
|
||||
this.parser.on("data", (data: AwsEventStreamMessage) => {
|
||||
const message = this.processAwsEvent(data);
|
||||
@@ -37,6 +46,12 @@ export class SSEStreamAdapter extends Transform {
|
||||
this.push(Buffer.from(message + "\n\n"), "utf8");
|
||||
}
|
||||
});
|
||||
this.jsonStream.on("data", (data: { value: any }) => {
|
||||
const message = this.processGoogleValue(data.value);
|
||||
if (message) {
|
||||
this.push(Buffer.from(message + "\n\n"), "utf8");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
protected processAwsEvent(event: AwsEventStreamMessage): string | null {
|
||||
@@ -73,17 +88,38 @@ export class SSEStreamAdapter extends Transform {
|
||||
}
|
||||
}
|
||||
|
||||
// Google doesn't use event streams and just sends elements in an array over
// a long-lived HTTP connection. Needs stream-json to parse the array.
/**
 * Converts one element of the streamed Google AI JSON array into an SSE
 * `data:` line.
 *
 * @param value One parsed array element emitted by the JSON stream parser.
 * @returns An SSE `data:` line containing the element when it has a
 *   `candidates` property; otherwise the value is logged and a fake error
 *   completion describing it is returned instead.
 */
protected processGoogleValue(value: any): string | null {
  // The `in` operator throws a TypeError for primitives and null, and this
  // method runs inside a stream event callback, so check the type first.
  const isObject = typeof value === "object" && value !== null;
  if (isObject && "candidates" in value) {
    return `data: ${JSON.stringify(value)}`;
  }

  log.error(
    { value },
    "Received unexpected Google AI event stream message"
  );
  return getFakeErrorCompletion(
    "proxy Google AI error",
    JSON.stringify(value)
  );
}
|
||||
|
||||
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
|
||||
try {
|
||||
if (this.isAwsStream) {
|
||||
this.parser.write(chunk);
|
||||
} else if (this.isGoogleStream) {
|
||||
this.jsonStream.write(chunk);
|
||||
} else {
|
||||
// We may receive multiple (or partial) SSE messages in a single chunk,
|
||||
// so we need to buffer and emit separate stream events for full
|
||||
// messages so we can parse/transform them properly.
|
||||
const str = this.decoder.write(chunk);
|
||||
|
||||
const fullMessages = (this.partialMessage + str).split(/\r\r|\n\n|\r\n\r\n/);
|
||||
const fullMessages = (this.partialMessage + str).split(
|
||||
/\r\r|\n\n|\r\n\r\n/
|
||||
);
|
||||
this.partialMessage = fullMessages.pop() || "";
|
||||
|
||||
for (const message of fullMessages) {
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
import { StreamingCompletionTransformer } from "../index";
|
||||
import { parseEvent, ServerSentEvent } from "../parse-sse";
|
||||
import { logger } from "../../../../../logger";
|
||||
|
||||
const log = logger.child({
|
||||
module: "sse-transformer",
|
||||
transformer: "google-ai-to-openai",
|
||||
});
|
||||
|
||||
/** Shape of one streamed Google AI response element as read by this transformer. */
type GoogleAIStreamEvent = {
  candidates: {
    // Optional: a candidate that finished without producing output (for
    // example one stopped for safety reasons) may not carry any content.
    // NOTE(review): confirm against the Google AI API reference.
    content?: { parts: { text: string }[]; role: string };
    finishReason?: "STOP" | "MAX_TOKENS" | "SAFETY" | "RECITATION" | "OTHER";
    index: number;
    tokenCount?: number;
    safetyRatings: { category: string; probability: string }[];
  }[];
};

/**
 * Transforms an incoming Google AI SSE to an equivalent OpenAI
 * chat.completion.chunk SSE.
 *
 * Only the first candidate and the text of its first part are used. Events
 * with no data, `[DONE]` events and events that fail validation produce no
 * output event. A candidate without content yields an empty text delta rather
 * than throwing, so its `finishReason` still reaches the client.
 *
 * @param params Transformer input; `data` is the raw SSE message, and
 *   `fallbackId` / `fallbackModel` fill the id and model of the emitted chunk.
 * @returns `{ position: -1 }`, plus the translated `event` when one was built.
 */
export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
  const { data } = params;

  const rawEvent = parseEvent(data);
  if (!rawEvent.data || rawEvent.data === "[DONE]") {
    return { position: -1 };
  }

  const completionEvent = asCompletion(rawEvent);
  if (!completionEvent) {
    return { position: -1 };
  }

  const candidate = completionEvent.candidates[0];
  // `content` (and therefore `parts`) can be absent; fall back to empty text.
  const text = candidate.content?.parts?.[0]?.text ?? "";
  const newEvent = {
    id: "goo-" + params.fallbackId,
    object: "chat.completion.chunk" as const,
    created: Date.now(),
    model: params.fallbackModel,
    choices: [
      {
        index: 0,
        delta: { content: text },
        finish_reason: candidate.finishReason ?? null,
      },
    ],
  };

  return { position: -1, event: newEvent };
};
|
||||
|
||||
/**
 * Parses the data of a server-sent event as a Google AI stream event.
 *
 * @param event The parsed SSE whose `data` field holds the JSON payload.
 * @returns The parsed event when it contains at least one candidate, or
 *   `null` (after logging a warning) when the JSON is invalid or has none.
 */
function asCompletion(event: ServerSentEvent): GoogleAIStreamEvent | null {
  let completion: GoogleAIStreamEvent | null = null;

  try {
    const payload = JSON.parse(event.data) as GoogleAIStreamEvent;
    const hasCandidates = payload.candidates?.length > 0;
    if (!hasCandidates) {
      // Thrown so both failure modes share the same logging path below.
      // noinspection ExceptionCaughtLocallyJS
      throw new Error("Missing required fields");
    }
    completion = payload;
  } catch (error) {
    log.warn({ error: error.stack, event }, "Received invalid event");
  }

  return completion;
}
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
} from "./middleware/response";
|
||||
import { generateModelList } from "./openai";
|
||||
import {
|
||||
mirrorGeneratedImage,
|
||||
OpenAIImageGenerationResult,
|
||||
} from "../shared/file-storage/mirror-generated-image";
|
||||
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
import { Request, RequestHandler, Router } from "express";
|
||||
import * as http from "http";
|
||||
import { createProxyMiddleware } from "http-proxy-middleware";
|
||||
import { v4 } from "uuid";
|
||||
import { config } from "../config";
|
||||
import { logger } from "../logger";
|
||||
import { createQueueMiddleware } from "./queue";
|
||||
import { ipLimiter } from "./rate-limit";
|
||||
import { handleProxyError } from "./middleware/common";
|
||||
import {
|
||||
addKey,
|
||||
createOnProxyReqHandler,
|
||||
createPreprocessorMiddleware,
|
||||
finalizeBody,
|
||||
forceModel,
|
||||
} from "./middleware/request";
|
||||
import {
|
||||
createOnProxyResHandler,
|
||||
ProxyResHandlerWithBody,
|
||||
} from "./middleware/response";
|
||||
|
||||
// Most recently built model list and the time (ms since epoch) it was built.
let modelsCache: any = null;
let modelsCacheTime = 0;

/**
 * Builds the OpenAI-style model list for the PaLM endpoint.
 *
 * The list is rebuilt at most once per minute; within that window the cached
 * object is returned. When no Google PaLM key is configured an empty list is
 * returned and nothing is cached.
 */
const getModelsResponse = () => {
  const cacheAgeMs = Date.now() - modelsCacheTime;
  if (cacheAgeMs < 1000 * 60) {
    return modelsCache;
  }

  if (!config.googlePalmKey) {
    return { object: "list", data: [] };
  }

  const bisonVariants = ["text-bison-001"];
  const data = bisonVariants.map((id) => {
    return {
      id,
      object: "model",
      created: Date.now(),
      owned_by: "google",
      permission: [],
      root: "palm",
      parent: null,
    };
  });

  modelsCache = { object: "list", data };
  modelsCacheTime = Date.now();

  return modelsCache;
};
|
||||
|
||||
/** Responds to a model list request with HTTP 200 and the model list as JSON. */
const handleModelRequest: RequestHandler = (_req, res) => {
  const okResponse = res.status(200);
  okResponse.json(getModelsResponse());
};
|
||||
|
||||
/**
 * Final response handler for Google PaLM requests. Only used for
 * non-streaming requests.
 *
 * Adds the prompt-logging notice when logging is enabled, converts the body
 * to the OpenAI shape for requests that arrived in OpenAI format, attaches
 * tokenizer info when present, and sends the result with HTTP 200.
 *
 * @throws Error when the upstream body is not an object.
 */
const palmResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  const wantsOpenAIShape = req.inboundApi === "openai";
  if (wantsOpenAIShape) {
    req.log.info("Transforming Google PaLM response to OpenAI format");
    body = transformPalmResponse(body, req);
  }

  const { tokenizerInfo } = req;
  if (tokenizerInfo) {
    body.proxy_tokenizer = tokenizerInfo;
  }

  // TODO: PaLM has no streaming capability which will pose a problem here if
  // requests wait in the queue for too long. Probably need to fake streaming
  // and return the entire completion in one stream event using the other
  // response handler.
  res.status(200).json(body);
};
|
||||
|
||||
/**
 * Transforms a model response from the Google PaLM API to match those from
 * the OpenAI API, for users using PaLM via the OpenAI-compatible endpoint.
 * This is only used for non-streaming requests.
 *
 * @param palmRespBody The parsed PaLM response body; the completion text is
 *   read from `candidates[0].output`.
 * @param req The originating request, supplying the model name and the
 *   prompt/output token counts reported under `usage`.
 * @returns An OpenAI `chat.completion`-shaped object with a single choice.
 */
function transformPalmResponse(
  palmRespBody: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  // A response without candidates would otherwise throw a TypeError here;
  // return an empty completion instead so the client still gets a valid body.
  // NOTE(review): presumably this happens when the output is filtered; confirm.
  const content = palmRespBody.candidates?.[0]?.output ?? "";
  return {
    id: "plm-" + v4(),
    object: "chat.completion",
    created: Date.now(),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: {
          role: "assistant",
          content,
        },
        finish_reason: null, // palm doesn't return this
        index: 0,
      },
    ],
  };
}
|
||||
|
||||
/**
 * Rewrites the outgoing proxy request path to the PaLM `generateText` route
 * for the model named in the request body.
 *
 * PaLM API specifies the model in the URL path, not the request body. This
 * doesn't work well with our rewriter architecture, so it is fixed manually:
 *
 *   POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateText
 *   POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateMessage
 *
 * The chat api (generateMessage) is not very useful at this time as it has
 * few params and no adjustable safety settings.
 *
 * @throws Error when the request asks for streaming.
 */
function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
  const { stream, model } = req.body;
  if (stream) {
    throw new Error("Google PaLM API doesn't support streaming requests");
  }

  const generateTextPath = `/v1beta2/models/${model}:generateText`;
  proxyReq.path = proxyReq.path.replace(
    /^\/v1\/chat\/completions/,
    generateTextPath
  );
}
|
||||
|
||||
// Proxy to the Google generative language API. Responses are handled by this
// proxy (selfHandleResponse) rather than piped straight through.
const palmProxyMiddleware = createProxyMiddleware({
  target: "https://generativelanguage.googleapis.com",
  changeOrigin: true,
  selfHandleResponse: true,
  logger,
  on: {
    // Runs in order: fix the path, attach the key, then finalize the body.
    proxyReq: createOnProxyReqHandler({
      pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
    }),
    proxyRes: createOnProxyResHandler([palmResponseHandler]),
    error: handleProxyError,
  },
});

/** PaLM proxy middleware wrapped by the request queue middleware. */
const googlePalmProxy = createQueueMiddleware({
  proxyMiddleware: palmProxyMiddleware,
});
|
||||
|
||||
// Preprocessor for OpenAI-format requests that are sent on to Google PaLM;
// the model is forced to text-bison-001 after the payload is transformed.
const palmPreprocessor = createPreprocessorMiddleware(
  { inApi: "openai", outApi: "google-palm", service: "google-palm" },
  { afterTransform: [forceModel("text-bison-001")] }
);

const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google PaLM compatibility endpoint.
palmRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  palmPreprocessor,
  googlePalmProxy
);

export const googlePalm = palmRouter;
|
||||
+2
-2
@@ -4,7 +4,7 @@ import { checkRisuToken } from "./check-risu-token";
|
||||
import { openai } from "./openai";
|
||||
import { openaiImage } from "./openai-image";
|
||||
import { anthropic } from "./anthropic";
|
||||
import { googlePalm } from "./palm";
|
||||
import { googleAI } from "./google-ai";
|
||||
import { aws } from "./aws";
|
||||
import { azure } from "./azure";
|
||||
|
||||
@@ -31,7 +31,7 @@ proxyRouter.use((req, _res, next) => {
|
||||
proxyRouter.use("/openai", addV1, openai);
|
||||
proxyRouter.use("/openai-image", addV1, openaiImage);
|
||||
proxyRouter.use("/anthropic", addV1, anthropic);
|
||||
proxyRouter.use("/google-palm", addV1, googlePalm);
|
||||
proxyRouter.use("/google-ai", addV1, googleAI);
|
||||
proxyRouter.use("/aws/claude", addV1, aws);
|
||||
proxyRouter.use("/azure/openai", addV1, azure);
|
||||
// Redirect browser requests to the homepage.
|
||||
|
||||
@@ -62,7 +62,6 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
|
||||
{ key: key.hash, error: error.message },
|
||||
"Key is rate limited. Rechecking in 10 seconds."
|
||||
);
|
||||
0;
|
||||
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
|
||||
this.updateKey(key.hash, { lastChecked: next });
|
||||
break;
|
||||
|
||||
@@ -6,7 +6,6 @@ import type { AzureOpenAIModelFamily } from "../../models";
|
||||
import { getAzureOpenAIModelFamily } from "../../models";
|
||||
import { OpenAIModel } from "../openai/provider";
|
||||
import { AzureOpenAIKeyChecker } from "./checker";
|
||||
import { AwsKeyChecker } from "../aws/checker";
|
||||
|
||||
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
|
||||
|
||||
|
||||
+28
-24
@@ -2,13 +2,17 @@ import crypto from "crypto";
|
||||
import { Key, KeyProvider } from "..";
|
||||
import { config } from "../../../config";
|
||||
import { logger } from "../../../logger";
|
||||
import type { GooglePalmModelFamily } from "../../models";
|
||||
import type { GoogleAIModelFamily } from "../../models";
|
||||
|
||||
// https://developers.generativeai.google.com/models/language
|
||||
export type GooglePalmModel = "text-bison-001";
|
||||
// Note that Google AI is not the same as Vertex AI, both are provided by Google
|
||||
// but Vertex is the GCP product for enterprise, while Google AI is the
|
||||
// consumer-ish product. The API is different, and keys are not compatible.
|
||||
// https://ai.google.dev/docs/migrate_to_cloud
|
||||
|
||||
export type GooglePalmKeyUpdate = Omit<
|
||||
Partial<GooglePalmKey>,
|
||||
export type GoogleAIModel = "gemini-pro";
|
||||
|
||||
export type GoogleAIKeyUpdate = Omit<
|
||||
Partial<GoogleAIKey>,
|
||||
| "key"
|
||||
| "hash"
|
||||
| "lastUsed"
|
||||
@@ -17,13 +21,13 @@ export type GooglePalmKeyUpdate = Omit<
|
||||
| "rateLimitedUntil"
|
||||
>;
|
||||
|
||||
type GooglePalmKeyUsage = {
|
||||
[K in GooglePalmModelFamily as `${K}Tokens`]: number;
|
||||
type GoogleAIKeyUsage = {
|
||||
[K in GoogleAIModelFamily as `${K}Tokens`]: number;
|
||||
};
|
||||
|
||||
export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
|
||||
readonly service: "google-palm";
|
||||
readonly modelFamilies: GooglePalmModelFamily[];
|
||||
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
|
||||
readonly service: "google-ai";
|
||||
readonly modelFamilies: GoogleAIModelFamily[];
|
||||
/** The time at which this key was last rate limited. */
|
||||
rateLimitedAt: number;
|
||||
/** The time until which this key is rate limited. */
|
||||
@@ -42,27 +46,27 @@ const RATE_LIMIT_LOCKOUT = 2000;
|
||||
*/
|
||||
const KEY_REUSE_DELAY = 500;
|
||||
|
||||
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
readonly service = "google-palm";
|
||||
export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
|
||||
readonly service = "google-ai";
|
||||
|
||||
private keys: GooglePalmKey[] = [];
|
||||
private keys: GoogleAIKey[] = [];
|
||||
private log = logger.child({ module: "key-provider", service: this.service });
|
||||
|
||||
constructor() {
|
||||
const keyConfig = config.googlePalmKey?.trim();
|
||||
const keyConfig = config.googleAIKey?.trim();
|
||||
if (!keyConfig) {
|
||||
this.log.warn(
|
||||
"GOOGLE_PALM_KEY is not set. PaLM API will not be available."
|
||||
"GOOGLE_AI_KEY is not set. Google AI API will not be available."
|
||||
);
|
||||
return;
|
||||
}
|
||||
let bareKeys: string[];
|
||||
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
|
||||
for (const key of bareKeys) {
|
||||
const newKey: GooglePalmKey = {
|
||||
const newKey: GoogleAIKey = {
|
||||
key,
|
||||
service: this.service,
|
||||
modelFamilies: ["bison"],
|
||||
modelFamilies: ["gemini-pro"],
|
||||
isDisabled: false,
|
||||
isRevoked: false,
|
||||
promptCount: 0,
|
||||
@@ -75,11 +79,11 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
.digest("hex")
|
||||
.slice(0, 8)}`,
|
||||
lastChecked: 0,
|
||||
bisonTokens: 0,
|
||||
"gemini-proTokens": 0,
|
||||
};
|
||||
this.keys.push(newKey);
|
||||
}
|
||||
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
|
||||
this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
|
||||
}
|
||||
|
||||
public init() {}
|
||||
@@ -88,10 +92,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
|
||||
}
|
||||
|
||||
public get(_model: GooglePalmModel) {
|
||||
public get(_model: GoogleAIModel) {
|
||||
const availableKeys = this.keys.filter((k) => !k.isDisabled);
|
||||
if (availableKeys.length === 0) {
|
||||
throw new Error("No Google PaLM keys available");
|
||||
throw new Error("No Google AI keys available");
|
||||
}
|
||||
|
||||
// (largely copied from the OpenAI provider, without trial key support)
|
||||
@@ -122,14 +126,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
return { ...selectedKey };
|
||||
}
|
||||
|
||||
public disable(key: GooglePalmKey) {
|
||||
public disable(key: GoogleAIKey) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
|
||||
if (!keyFromPool || keyFromPool.isDisabled) return;
|
||||
keyFromPool.isDisabled = true;
|
||||
this.log.warn({ key: key.hash }, "Key disabled");
|
||||
}
|
||||
|
||||
public update(hash: string, update: Partial<GooglePalmKey>) {
|
||||
public update(hash: string, update: Partial<GoogleAIKey>) {
|
||||
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
|
||||
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
|
||||
}
|
||||
@@ -142,7 +146,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
|
||||
const key = this.keys.find((k) => k.hash === hash);
|
||||
if (!key) return;
|
||||
key.promptCount++;
|
||||
key.bisonTokens += tokens;
|
||||
key["gemini-proTokens"] += tokens;
|
||||
}
|
||||
|
||||
public getLockoutPeriod() {
|
||||
@@ -1,6 +1,6 @@
|
||||
import { OpenAIModel } from "./openai/provider";
|
||||
import { AnthropicModel } from "./anthropic/provider";
|
||||
import { GooglePalmModel } from "./palm/provider";
|
||||
import { GoogleAIModel } from "./google-ai/provider";
|
||||
import { AwsBedrockModel } from "./aws/provider";
|
||||
import { AzureOpenAIModel } from "./azure/provider";
|
||||
import { KeyPool } from "./key-pool";
|
||||
@@ -10,20 +10,20 @@ import type { ModelFamily } from "../models";
|
||||
export type APIFormat =
|
||||
| "openai"
|
||||
| "anthropic"
|
||||
| "google-palm"
|
||||
| "google-ai"
|
||||
| "openai-text"
|
||||
| "openai-image";
|
||||
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
|
||||
export type LLMService =
|
||||
| "openai"
|
||||
| "anthropic"
|
||||
| "google-palm"
|
||||
| "google-ai"
|
||||
| "aws"
|
||||
| "azure";
|
||||
export type Model =
|
||||
| OpenAIModel
|
||||
| AnthropicModel
|
||||
| GooglePalmModel
|
||||
| GoogleAIModel
|
||||
| AwsBedrockModel
|
||||
| AzureOpenAIModel;
|
||||
|
||||
@@ -77,6 +77,6 @@ export interface KeyProvider<T extends Key = Key> {
|
||||
export const keyPool = new KeyPool();
|
||||
export { AnthropicKey } from "./anthropic/provider";
|
||||
export { OpenAIKey } from "./openai/provider";
|
||||
export { GooglePalmKey } from "./palm/provider";
|
||||
export { GoogleAIKey } from "././google-ai/provider";
|
||||
export { AwsBedrockKey } from "./aws/provider";
|
||||
export { AzureOpenAIKey } from "./azure/provider";
|
||||
|
||||
@@ -7,7 +7,7 @@ import { logger } from "../../logger";
|
||||
import { Key, Model, KeyProvider, LLMService } from "./index";
|
||||
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
|
||||
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
|
||||
import { GooglePalmKeyProvider } from "./palm/provider";
|
||||
import { GoogleAIKeyProvider } from "./google-ai/provider";
|
||||
import { AwsBedrockKeyProvider } from "./aws/provider";
|
||||
import { ModelFamily } from "../models";
|
||||
import { assertNever } from "../utils";
|
||||
@@ -24,7 +24,7 @@ export class KeyPool {
|
||||
constructor() {
|
||||
this.keyProviders.push(new OpenAIKeyProvider());
|
||||
this.keyProviders.push(new AnthropicKeyProvider());
|
||||
this.keyProviders.push(new GooglePalmKeyProvider());
|
||||
this.keyProviders.push(new GoogleAIKeyProvider());
|
||||
this.keyProviders.push(new AwsBedrockKeyProvider());
|
||||
this.keyProviders.push(new AzureOpenAIKeyProvider());
|
||||
}
|
||||
@@ -119,9 +119,9 @@ export class KeyPool {
|
||||
} else if (model.startsWith("claude-")) {
|
||||
// https://console.anthropic.com/docs/api/reference#parameters
|
||||
return "anthropic";
|
||||
} else if (model.includes("bison")) {
|
||||
} else if (model.includes("gemini")) {
|
||||
// https://developers.generativeai.google.com/models/language
|
||||
return "google-palm";
|
||||
return "google-ai";
|
||||
} else if (model.startsWith("anthropic.claude")) {
|
||||
// AWS offers models from a few providers
|
||||
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
@@ -142,8 +142,8 @@ export class KeyPool {
|
||||
return "openai";
|
||||
case "claude":
|
||||
return "anthropic";
|
||||
case "bison":
|
||||
return "google-palm";
|
||||
case "gemini-pro":
|
||||
return "google-ai";
|
||||
case "aws-claude":
|
||||
return "aws";
|
||||
case "azure-turbo":
|
||||
|
||||
+8
-10
@@ -11,7 +11,7 @@ export type OpenAIModelFamily =
|
||||
| "gpt4-turbo"
|
||||
| "dall-e";
|
||||
export type AnthropicModelFamily = "claude";
|
||||
export type GooglePalmModelFamily = "bison";
|
||||
export type GoogleAIModelFamily = "gemini-pro";
|
||||
export type AwsBedrockModelFamily = "aws-claude";
|
||||
export type AzureOpenAIModelFamily = `azure-${Exclude<
|
||||
OpenAIModelFamily,
|
||||
@@ -20,7 +20,7 @@ export type AzureOpenAIModelFamily = `azure-${Exclude<
|
||||
export type ModelFamily =
|
||||
| OpenAIModelFamily
|
||||
| AnthropicModelFamily
|
||||
| GooglePalmModelFamily
|
||||
| GoogleAIModelFamily
|
||||
| AwsBedrockModelFamily
|
||||
| AzureOpenAIModelFamily;
|
||||
|
||||
@@ -33,7 +33,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||
"gpt4-turbo",
|
||||
"dall-e",
|
||||
"claude",
|
||||
"bison",
|
||||
"gemini-pro",
|
||||
"aws-claude",
|
||||
"azure-turbo",
|
||||
"azure-gpt4",
|
||||
@@ -53,7 +53,7 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||
"^dall-e-\\d{1}$": "dall-e",
|
||||
};
|
||||
|
||||
const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
|
||||
pino({ level: "debug" }).child({ module: "startup" });
|
||||
|
||||
export function getOpenAIModelFamily(
|
||||
model: string,
|
||||
@@ -70,10 +70,8 @@ export function getClaudeModelFamily(model: string): ModelFamily {
|
||||
return "claude";
|
||||
}
|
||||
|
||||
export function getGooglePalmModelFamily(model: string): ModelFamily {
|
||||
if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
|
||||
modelLogger.warn({ model }, "Could not determine Google PaLM model family");
|
||||
return "bison";
|
||||
export function getGoogleAIModelFamily(_model: string): ModelFamily {
|
||||
return "gemini-pro";
|
||||
}
|
||||
|
||||
export function getAwsBedrockModelFamily(_model: string): ModelFamily {
|
||||
@@ -130,8 +128,8 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
||||
case "openai-image":
|
||||
modelFamily = getOpenAIModelFamily(model);
|
||||
break;
|
||||
case "google-palm":
|
||||
modelFamily = getGooglePalmModelFamily(model);
|
||||
case "google-ai":
|
||||
modelFamily = getGoogleAIModelFamily(model);
|
||||
break;
|
||||
default:
|
||||
assertNever(req.outboundApi);
|
||||
|
||||
@@ -74,7 +74,7 @@ export function buildFakeSse(type: string, string: string, req: Request) {
|
||||
log_id: "proxy-req-" + req.id,
|
||||
};
|
||||
break;
|
||||
case "google-palm":
|
||||
case "google-ai":
|
||||
case "openai-image":
|
||||
throw new Error(`SSE not supported for ${req.inboundApi} requests`);
|
||||
default:
|
||||
|
||||
@@ -2,7 +2,11 @@ import { Tiktoken } from "tiktoken/lite";
|
||||
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
|
||||
import { logger } from "../../logger";
|
||||
import { libSharp } from "../file-storage";
|
||||
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
|
||||
import type {
|
||||
GoogleAIChatMessage,
|
||||
OpenAIChatMessage,
|
||||
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
|
||||
import { z } from "zod";
|
||||
|
||||
const log = logger.child({ module: "tokenizer", service: "openai" });
|
||||
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
|
||||
@@ -29,11 +33,11 @@ export async function getTokenCount(
|
||||
return getTextTokenCount(prompt);
|
||||
}
|
||||
|
||||
const gpt4 = model.startsWith("gpt-4");
|
||||
const oldFormatting = model.startsWith("turbo-0301");
|
||||
const vision = model.includes("vision");
|
||||
|
||||
const tokensPerMessage = gpt4 ? 3 : 4;
|
||||
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
|
||||
const tokensPerMessage = oldFormatting ? 4 : 3;
|
||||
const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
|
||||
|
||||
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
|
||||
|
||||
@@ -228,3 +232,24 @@ export function getOpenAIImageCost(params: {
|
||||
token_count: Math.ceil(tokens),
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Produces a rough token estimate for a Google AI prompt by running it through
 * the module's tiktoken encoder. The result is labelled as an estimate because
 * the encoder is not Google's own tokenizer.
 *
 * @param prompt Either a plain string, or a list of Google AI chat messages
 *   whose `parts` each carry a `text` field.
 * @returns For a string prompt, whatever `getTextTokenCount` returns. For a
 *   message list, an object with a `tokenizer` label and the estimated
 *   `token_count`.
 */
export function estimateGoogleAITokenCount(
  prompt: string | GoogleAIChatMessage[]
) {
  if (typeof prompt === "string") {
    return getTextTokenCount(prompt);
  }

  // Flat per-message overhead added on top of the encoded text.
  const tokensPerMessage = 3;

  let numTokens = 0;
  for (const message of prompt) {
    numTokens += tokensPerMessage;
    // Count every part rather than only the first one, so multi-part messages
    // are not under-counted and a message with an empty `parts` array simply
    // contributes its fixed overhead instead of throwing.
    for (const part of message.parts) {
      numTokens += encoder.encode(part.text).length;
    }
  }

  // Fixed trailing allowance; presumably mirrors the reply-priming overhead
  // used for OpenAI chat prompts -- TODO confirm against Gemini's real counts.
  numTokens += 3;

  return {
    tokenizer: "tiktoken (google-ai estimate)",
    token_count: numTokens,
  };
}
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { Request } from "express";
|
||||
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
|
||||
import type {
|
||||
GoogleAIChatMessage,
|
||||
OpenAIChatMessage,
|
||||
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
|
||||
import { assertNever } from "../utils";
|
||||
import {
|
||||
init as initClaude,
|
||||
@@ -9,6 +12,7 @@ import {
|
||||
init as initOpenAi,
|
||||
getTokenCount as getOpenAITokenCount,
|
||||
getOpenAIImageCost,
|
||||
estimateGoogleAITokenCount,
|
||||
} from "./openai";
|
||||
import { APIFormat } from "../key-management";
|
||||
|
||||
@@ -24,8 +28,9 @@ type TokenCountRequest = { req: Request } & (
|
||||
| {
|
||||
prompt: string;
|
||||
completion?: never;
|
||||
service: "openai-text" | "anthropic" | "google-palm";
|
||||
service: "openai-text" | "anthropic" | "google-ai";
|
||||
}
|
||||
| { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
|
||||
| { prompt?: never; completion: string; service: APIFormat }
|
||||
| { prompt?: never; completion?: never; service: "openai-image" }
|
||||
);
|
||||
@@ -65,11 +70,11 @@ export async function countTokens({
|
||||
}),
|
||||
tokenization_duration_ms: getElapsedMs(time),
|
||||
};
|
||||
case "google-palm":
|
||||
// TODO: Can't find a tokenization library for PaLM. There is an API
|
||||
case "google-ai":
|
||||
// TODO: Can't find a tokenization library for Gemini. There is an API
|
||||
// endpoint for it but it adds significant latency to the request.
|
||||
return {
|
||||
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
|
||||
...estimateGoogleAITokenCount(prompt ?? (completion || [])),
|
||||
tokenization_duration_ms: getElapsedMs(time),
|
||||
};
|
||||
default:
|
||||
|
||||
@@ -9,7 +9,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
|
||||
"gpt4-turbo": z.number().optional().default(0),
|
||||
"dall-e": z.number().optional().default(0),
|
||||
claude: z.number().optional().default(0),
|
||||
bison: z.number().optional().default(0),
|
||||
"gemini-pro": z.number().optional().default(0),
|
||||
"aws-claude": z.number().optional().default(0),
|
||||
});
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import { config, getFirebaseApp } from "../../config";
|
||||
import {
|
||||
getAzureOpenAIModelFamily,
|
||||
getClaudeModelFamily,
|
||||
getGooglePalmModelFamily,
|
||||
getGoogleAIModelFamily,
|
||||
getOpenAIModelFamily,
|
||||
MODEL_FAMILIES,
|
||||
ModelFamily,
|
||||
@@ -33,7 +33,7 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
|
||||
"gpt4-turbo": 0,
|
||||
"dall-e": 0,
|
||||
claude: 0,
|
||||
bison: 0,
|
||||
"gemini-pro": 0,
|
||||
"aws-claude": 0,
|
||||
"azure-turbo": 0,
|
||||
"azure-gpt4": 0,
|
||||
@@ -397,8 +397,8 @@ function getModelFamilyForQuotaUsage(
|
||||
return getOpenAIModelFamily(model);
|
||||
case "anthropic":
|
||||
return getClaudeModelFamily(model);
|
||||
case "google-palm":
|
||||
return getGooglePalmModelFamily(model);
|
||||
case "google-ai":
|
||||
return getGoogleAIModelFamily(model);
|
||||
default:
|
||||
assertNever(api);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user