This commit is contained in:
khanon
2023-12-13 21:56:07 +00:00
parent 0d3682197c
commit fad16cc268
40 changed files with 588 additions and 357 deletions
+3 -3
View File
@@ -34,10 +34,10 @@
# Which model types users are allowed to access. # Which model types users are allowed to access.
# The following model families are recognized: # The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo # turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image # By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list. # generation, uncomment the line below and add 'dall-e' to the list.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo # ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# URLs from which requests will be blocked. # URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com # BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -95,7 +95,7 @@
# TOKEN_QUOTA_GPT4_TURBO=0 # TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_DALL_E=0 # TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0 # TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_BISON=0 # TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0 # TOKEN_QUOTA_AWS_CLAUDE=0
# How often to refresh token quotas. (hourly | daily) # How often to refresh token quotas. (hourly | daily)
+1 -1
View File
@@ -35,7 +35,7 @@ Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL
ALLOWED_MODEL_FAMILIES=turbo,gpt-4,gpt-4turbo,dall-e ALLOWED_MODEL_FAMILIES=turbo,gpt-4,gpt-4turbo,dall-e
# All models as of this writing # All models as of this writing
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,dall-e
``` ```
Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models. Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.
+34
View File
@@ -36,6 +36,7 @@
"sanitize-html": "^2.11.0", "sanitize-html": "^2.11.0",
"sharp": "^0.32.6", "sharp": "^0.32.6",
"showdown": "^2.1.0", "showdown": "^2.1.0",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10", "tiktoken": "^1.0.10",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zlib": "^1.0.5", "zlib": "^1.0.5",
@@ -51,6 +52,7 @@
"@types/node-schedule": "^2.1.0", "@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0", "@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0", "@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1", "@types/uuid": "^9.0.1",
"concurrently": "^8.0.1", "concurrently": "^8.0.1",
"esbuild": "^0.17.16", "esbuild": "^0.17.16",
@@ -1185,6 +1187,25 @@
"integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==", "integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==",
"dev": true "dev": true
}, },
"node_modules/@types/stream-chain": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/@types/stream-chain/-/stream-chain-2.0.4.tgz",
"integrity": "sha512-V7TsWLHrx79KumkHqSD7F8eR6POpEuWb6PuXJ7s/dRHAf3uVst3Jkp1yZ5XqIfECZLQ4a28vBVstTErmsMBvaQ==",
"dev": true,
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@types/stream-json": {
"version": "1.7.7",
"resolved": "https://registry.npmjs.org/@types/stream-json/-/stream-json-1.7.7.tgz",
"integrity": "sha512-hHG7cLQ09H/m9i0jzL6UJAeLLxIWej90ECn0svO4T8J0nGcl89xZDQ2ujT4WKlvg0GWkcxJbjIDzW/v7BYUM6Q==",
"dev": true,
"dependencies": {
"@types/node": "*",
"@types/stream-chain": "*"
}
},
"node_modules/@types/uuid": { "node_modules/@types/uuid": {
"version": "9.0.1", "version": "9.0.1",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz", "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz",
@@ -5135,6 +5156,11 @@
"node": ">= 0.8" "node": ">= 0.8"
} }
}, },
"node_modules/stream-chain": {
"version": "2.2.5",
"resolved": "https://registry.npmjs.org/stream-chain/-/stream-chain-2.2.5.tgz",
"integrity": "sha512-1TJmBx6aSWqZ4tx7aTpBDXK0/e2hhcNSTV8+CbFJtDjbb+I1mZ8lHit0Grw9GRT+6JbIrrDd8esncgBi8aBXGA=="
},
"node_modules/stream-events": { "node_modules/stream-events": {
"version": "1.0.5", "version": "1.0.5",
"resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz", "resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz",
@@ -5144,6 +5170,14 @@
"stubs": "^3.0.0" "stubs": "^3.0.0"
} }
}, },
"node_modules/stream-json": {
"version": "1.8.0",
"resolved": "https://registry.npmjs.org/stream-json/-/stream-json-1.8.0.tgz",
"integrity": "sha512-HZfXngYHUAr1exT4fxlbc1IOce1RYxp2ldeaf97LYCOPSoOqY/1Psp7iGvpb+6JIOgkra9zDYnPX01hGAHzEPw==",
"dependencies": {
"stream-chain": "^2.2.5"
}
},
"node_modules/stream-shift": { "node_modules/stream-shift": {
"version": "1.0.1", "version": "1.0.1",
"resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz", "resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz",
+2
View File
@@ -44,6 +44,7 @@
"sanitize-html": "^2.11.0", "sanitize-html": "^2.11.0",
"sharp": "^0.32.6", "sharp": "^0.32.6",
"showdown": "^2.1.0", "showdown": "^2.1.0",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10", "tiktoken": "^1.0.10",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zlib": "^1.0.5", "zlib": "^1.0.5",
@@ -59,6 +60,7 @@
"@types/node-schedule": "^2.1.0", "@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0", "@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0", "@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1", "@types/uuid": "^9.0.1",
"concurrently": "^8.0.1", "concurrently": "^8.0.1",
"esbuild": "^0.17.16", "esbuild": "^0.17.16",
+4 -3
View File
@@ -1,6 +1,6 @@
const axios = require("axios"); const axios = require("axios");
const concurrentRequests = 5; const concurrentRequests = 75;
const headers = { const headers = {
Authorization: "Bearer test", Authorization: "Bearer test",
"Content-Type": "application/json", "Content-Type": "application/json",
@@ -16,7 +16,7 @@ const payload = {
const makeRequest = async (i) => { const makeRequest = async (i) => {
try { try {
const response = await axios.post( const response = await axios.post(
"http://localhost:7860/proxy/azure/openai/v1/chat/completions", "http://localhost:7860/proxy/google-ai/v1/chat/completions",
payload, payload,
{ headers } { headers }
); );
@@ -25,7 +25,8 @@ const makeRequest = async (i) => {
response.data response.data
); );
} catch (error) { } catch (error) {
console.error(`Error in req ${i}:`, error.message); const msg = error.response
console.error(`Error in req ${i}:`, error.message, msg || "");
} }
}; };
+1 -1
View File
@@ -200,7 +200,7 @@ router.post("/maintenance", (req, res) => {
keyPool.recheck("anthropic"); keyPool.recheck("anthropic");
const size = keyPool const size = keyPool
.list() .list()
.filter((k) => k.service !== "google-palm").length; .filter((k) => k.service !== "google-ai").length;
flash.type = "success"; flash.type = "success";
flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`; flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`;
break; break;
+10 -6
View File
@@ -19,8 +19,12 @@ type Config = {
openaiKey?: string; openaiKey?: string;
/** Comma-delimited list of Anthropic API keys. */ /** Comma-delimited list of Anthropic API keys. */
anthropicKey?: string; anthropicKey?: string;
/** Comma-delimited list of Google PaLM API keys. */ /**
googlePalmKey?: string; * Comma-delimited list of Google AI API keys. Note that these are not the
* same as the GCP keys/credentials used for Vertex AI; the models are the
* same but the APIs are different. Vertex is the GCP product for enterprise.
**/
googleAIKey?: string;
/** /**
* Comma-delimited list of AWS credentials. Each credential item should be a * Comma-delimited list of AWS credentials. Each credential item should be a
* colon-delimited list of access key, secret key, and AWS region. * colon-delimited list of access key, secret key, and AWS region.
@@ -197,7 +201,7 @@ export const config: Config = {
port: getEnvWithDefault("PORT", 7860), port: getEnvWithDefault("PORT", 7860),
openaiKey: getEnvWithDefault("OPENAI_KEY", ""), openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""), anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""), googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""), awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""), proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -229,7 +233,7 @@ export const config: Config = {
"gpt4-32k", "gpt4-32k",
"gpt4-turbo", "gpt4-turbo",
"claude", "claude",
"bison", "gemini-pro",
"aws-claude", "aws-claude",
"azure-turbo", "azure-turbo",
"azure-gpt4", "azure-gpt4",
@@ -366,7 +370,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"logLevel", "logLevel",
"openaiKey", "openaiKey",
"anthropicKey", "anthropicKey",
"googlePalmKey", "googleAIKey",
"awsCredentials", "awsCredentials",
"azureCredentials", "azureCredentials",
"proxyKey", "proxyKey",
@@ -433,7 +437,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
[ [
"OPENAI_KEY", "OPENAI_KEY",
"ANTHROPIC_KEY", "ANTHROPIC_KEY",
"GOOGLE_PALM_KEY", "GOOGLE_AI_KEY",
"AWS_CREDENTIALS", "AWS_CREDENTIALS",
"AZURE_CREDENTIALS", "AZURE_CREDENTIALS",
].includes(String(env)) ].includes(String(env))
+35 -29
View File
@@ -7,7 +7,7 @@ import {
AnthropicKey, AnthropicKey,
AwsBedrockKey, AwsBedrockKey,
AzureOpenAIKey, AzureOpenAIKey,
GooglePalmKey, GoogleAIKey,
keyPool, keyPool,
OpenAIKey, OpenAIKey,
} from "./shared/key-management"; } from "./shared/key-management";
@@ -33,8 +33,8 @@ const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure"; k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey => const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic"; k.service === "anthropic";
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey => const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
k.service === "google-palm"; k.service === "google-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws"; const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
type ModelAggregates = { type ModelAggregates = {
@@ -54,7 +54,7 @@ type ServiceAggregates = {
openaiKeys?: number; openaiKeys?: number;
openaiOrgs?: number; openaiOrgs?: number;
anthropicKeys?: number; anthropicKeys?: number;
palmKeys?: number; googleAIKeys?: number;
awsKeys?: number; awsKeys?: number;
azureKeys?: number; azureKeys?: number;
proompts: number; proompts: number;
@@ -100,7 +100,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const openaiKeys = serviceStats.get("openaiKeys") || 0; const openaiKeys = serviceStats.get("openaiKeys") || 0;
const anthropicKeys = serviceStats.get("anthropicKeys") || 0; const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
const palmKeys = serviceStats.get("palmKeys") || 0; const googleAIKeys = serviceStats.get("googleAIKeys") || 0;
const awsKeys = serviceStats.get("awsKeys") || 0; const awsKeys = serviceStats.get("awsKeys") || 0;
const azureKeys = serviceStats.get("azureKeys") || 0; const azureKeys = serviceStats.get("azureKeys") || 0;
const proompts = serviceStats.get("proompts") || 0; const proompts = serviceStats.get("proompts") || 0;
@@ -116,7 +116,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
? { ["openai-image"]: baseUrl + "/openai-image" } ? { ["openai-image"]: baseUrl + "/openai-image" }
: {}), : {}),
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}), ...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}), ...(googleAIKeys ? { "google-ai": baseUrl + "/google-ai" } : {}),
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}), ...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}), ...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
}; };
@@ -127,7 +127,13 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}), ...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
}; };
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys }; const keyInfo = {
openaiKeys,
anthropicKeys,
googleAIKeys,
awsKeys,
azureKeys,
};
for (const key of Object.keys(keyInfo)) { for (const key of Object.keys(keyInfo)) {
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key]; if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
} }
@@ -135,7 +141,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const providerInfo = { const providerInfo = {
...(openaiKeys ? getOpenAIInfo() : {}), ...(openaiKeys ? getOpenAIInfo() : {}),
...(anthropicKeys ? getAnthropicInfo() : {}), ...(anthropicKeys ? getAnthropicInfo() : {}),
...(palmKeys ? getPalmInfo() : {}), ...(googleAIKeys ? getGoogleAIInfo() : {}),
...(awsKeys ? getAwsInfo() : {}), ...(awsKeys ? getAwsInfo() : {}),
...(azureKeys ? getAzureInfo() : {}), ...(azureKeys ? getAzureInfo() : {}),
}; };
@@ -197,7 +203,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount); increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0); increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0); increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0); increment(serviceStats, "googleAIKeys", k.service === "google-ai" ? 1 : 0);
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0); increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0); increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
@@ -251,14 +257,14 @@ function addKeyToAggregates(k: KeyPoolKey) {
); );
break; break;
} }
case "google-palm": { case "google-ai": {
if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type"); if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
const family = "bison"; const family = "gemini-pro";
sumTokens += k.bisonTokens; sumTokens += k["gemini-proTokens"];
sumCost += getTokenCostUsd(family, k.bisonTokens); sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1); increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0); increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.bisonTokens); increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
break; break;
} }
case "aws": { case "aws": {
@@ -388,26 +394,26 @@ function getAnthropicInfo() {
}; };
} }
function getPalmInfo() { function getGoogleAIInfo() {
const bisonInfo: Partial<ModelAggregates> = { const googleAIInfo: Partial<ModelAggregates> = {
active: modelStats.get("bison__active") || 0, active: modelStats.get("gemini-pro__active") || 0,
revoked: modelStats.get("bison__revoked") || 0, revoked: modelStats.get("gemini-pro__revoked") || 0,
}; };
const queue = getQueueInformation("bison"); const queue = getQueueInformation("gemini-pro");
bisonInfo.queued = queue.proomptersInQueue; googleAIInfo.queued = queue.proomptersInQueue;
bisonInfo.queueTime = queue.estimatedQueueTime; googleAIInfo.queueTime = queue.estimatedQueueTime;
const tokens = modelStats.get("bison__tokens") || 0; const tokens = modelStats.get("gemini-pro__tokens") || 0;
const cost = getTokenCostUsd("bison", tokens); const cost = getTokenCostUsd("gemini-pro", tokens);
return { return {
bison: { gemini: {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`, usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active, activeKeys: googleAIInfo.active,
revokedKeys: bisonInfo.revoked, revokedKeys: googleAIInfo.revoked,
proomptersInQueue: bisonInfo.queued, proomptersInQueue: googleAIInfo.queued,
estimatedQueueTime: bisonInfo.queueTime, estimatedQueueTime: googleAIInfo.queueTime,
}, },
}; };
} }
+141
View File
@@ -0,0 +1,141 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
// Cached /v1/models payload plus the time it was built, so repeated requests
// within a minute don't rebuild the same static list.
let modelsCache: any = null;
let modelsCacheTime = 0;

/** Returns an OpenAI-style model list for the configured Google AI models. */
const getModelsResponse = () => {
  const cacheIsFresh = new Date().getTime() - modelsCacheTime < 1000 * 60;
  if (cacheIsFresh) return modelsCache;

  // With no Google AI key configured there is nothing to advertise.
  if (!config.googleAIKey) return { object: "list", data: [] };

  const variantIds = ["gemini-pro"];
  const data = variantIds.map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "google",
    permission: [],
    root: "google",
    parent: null,
  }));

  modelsCache = { object: "list", data };
  modelsCacheTime = new Date().getTime();
  return modelsCache;
};
// GET /v1/models — serves the (possibly cached) model list.
const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const googleAIResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // Translate to OpenAI format *before* attaching proxy annotations. The
  // transform builds a brand-new response object, so anything assigned to the
  // Google-format body first (like proxy_note) would be silently dropped.
  if (req.inboundApi === "openai") {
    req.log.info("Transforming Google AI response to OpenAI format");
    body = transformGoogleAIResponse(body, req);
  }

  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  res.status(200).json(body);
};
/**
 * Rewrites a Google AI generateContent response into the OpenAI chat
 * completion shape expected by OpenAI-format clients.
 */
function transformGoogleAIResponse(
  googleAIResp: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  // Safety-blocked prompts can come back with no candidates at all (only
  // `promptFeedback`), so don't assume candidates[0] exists — an unguarded
  // access here would throw and surface as a proxy error.
  const candidate = googleAIResp.candidates?.[0];
  const content = candidate?.content?.parts?.[0]?.text ?? "";
  return {
    id: "goo-" + v4(),
    object: "chat.completion",
    // OpenAI timestamps are unix epoch *seconds*, not milliseconds.
    created: Math.floor(Date.now() / 1000),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: {
          role: "assistant",
          content,
        },
        finish_reason: candidate?.finishReason ?? null,
        index: 0,
      },
    ],
  };
}
// Proxies requests through the shared queue. `addGoogleAIKey` runs before
// proxying to pick a key and prebuild the signed upstream request.
const googleAIProxy = createQueueMiddleware({
  beforeProxy: addGoogleAIKey,
  proxyMiddleware: createProxyMiddleware({
    // Placeholder only; the router callback below derives the real target
    // from the signed request that addGoogleAIKey attached.
    target: "bad-target-will-be-rewritten",
    router: ({ signedRequest }) => {
      const { protocol, hostname, path } = signedRequest;
      return `${protocol}//${hostname}${path}`;
    },
    changeOrigin: true,
    // We consume and rewrite the upstream response body ourselves.
    selfHandleResponse: true,
    logger,
    on: {
      proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
      proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
      error: handleProxyError,
    },
  }),
});
const googleAIRouter = Router();
// Model listing is unauthenticated and served from a local cache.
googleAIRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google AI compatibility endpoint.
googleAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  // Accepts OpenAI-format requests, rewrites them to the Google AI
  // generateContent format, and pins every request to gemini-pro.
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "google-ai", service: "google-ai" },
    { afterTransform: [forceModel("gemini-pro")] }
  ),
  googleAIProxy
);

export const googleAI = googleAIRouter;
+6 -3
View File
@@ -177,8 +177,11 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
return ""; return "";
} }
return body.completion.trim(); return body.completion.trim();
case "google-palm": case "google-ai":
return body.candidates[0].output; if ("choices" in body) {
return body.choices[0].message.content;
}
return body.candidates[0].content.parts[0].text;
case "openai-image": case "openai-image":
return body.data?.map((item: any) => item.url).join("\n"); return body.data?.map((item: any) => item.url).join("\n");
default: default:
@@ -197,7 +200,7 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
case "anthropic": case "anthropic":
// Anthropic confirms the model in the response, but AWS Claude doesn't. // Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model; return body.model || req.body.model;
case "google-palm": case "google-ai":
// Google doesn't confirm the model in the response. // Google doesn't confirm the model in the response.
return req.body.model; return req.body.model;
default: default:
@@ -29,7 +29,9 @@ export const createOnProxyReqHandler = ({
// The streaming flag must be set before any other onProxyReq handler runs, // The streaming flag must be set before any other onProxyReq handler runs,
// as it may influence the behavior of subsequent handlers. // as it may influence the behavior of subsequent handlers.
// Image generation requests can't be streamed. // Image generation requests can't be streamed.
req.isStreaming = req.body.stream === true || req.body.stream === "true"; // TODO: this flag is set in too many places
req.isStreaming =
req.isStreaming || req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming; req.body.stream = req.isStreaming;
try { try {
@@ -31,10 +31,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
case "anthropic": case "anthropic":
assignedKey = keyPool.get("claude-v1"); assignedKey = keyPool.get("claude-v1");
break; break;
case "google-palm":
assignedKey = keyPool.get("text-bison-001");
delete req.body.stream;
break;
case "openai-text": case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct"); assignedKey = keyPool.get("gpt-3.5-turbo-instruct");
break; break;
@@ -42,6 +38,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
throw new Error( throw new Error(
"OpenAI Chat as an API translation target is not supported" "OpenAI Chat as an API translation target is not supported"
); );
case "google-ai":
throw new Error("add-key should not be used for this model.");
case "openai-image": case "openai-image":
assignedKey = keyPool.get("dall-e-3"); assignedKey = keyPool.get("dall-e-3");
break; break;
@@ -73,21 +71,13 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
} }
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`); proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
break; break;
case "google-palm":
const originalPath = proxyReq.path;
proxyReq.path = originalPath.replace(
/(\?.*)?$/,
`?key=${assignedKey.key}`
);
break;
case "azure": case "azure":
const azureKey = assignedKey.key; const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey); proxyReq.setHeader("api-key", azureKey);
break; break;
case "aws": case "aws":
throw new Error( case "google-ai":
"add-key should not be used for AWS security credentials. Use sign-aws-request instead." throw new Error("add-key should not be used for this service.");
);
default: default:
assertNever(assignedKey.service); assertNever(assignedKey.service);
} }
@@ -1,9 +1,9 @@
import type { HPMRequestCallback } from "../index"; import type { HPMRequestCallback } from "../index";
/** /**
* For AWS/Azure requests, the body is signed earlier in the request pipeline, * For AWS/Azure/Google requests, the body is signed earlier in the request
* before the proxy middleware. This function just assigns the path and headers * pipeline, before the proxy middleware. This function just assigns the path
* to the proxy request. * and headers to the proxy request.
*/ */
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => { export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
if (!req.signedRequest) { if (!req.signedRequest) {
@@ -0,0 +1,40 @@
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
/**
 * Assigns a Google AI key to the request and prebuilds the upstream request.
 * The Google AI API authenticates via a `key` query parameter (not a header)
 * and selects streaming via the URL verb, so the full signed request is
 * constructed here rather than in the generic add-key middleware.
 */
export const addGoogleAIKey: RequestPreprocessor = (req) => {
  const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
  const serviceValid = req.service === "google-ai";
  if (!apisValid || !serviceValid) {
    throw new Error("addGoogleAIKey called on invalid request");
  }

  if (!req.body?.model) {
    throw new Error("You must specify a model with your request.");
  }

  const model = req.body.model;
  req.key = keyPool.get(model);

  req.log.info(
    { key: req.key.hash, model },
    "Assigned Google AI API key to request"
  );

  // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
  // https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
  // Streaming is signalled via the URL path, not the body, so record the flag
  // and strip it from the serialized payload.
  req.isStreaming = req.isStreaming || req.body.stream;
  delete req.body.stream;

  req.signedRequest = {
    method: "POST",
    protocol: "https:",
    hostname: "generativelanguage.googleapis.com",
    path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
    headers: {
      ["host"]: `generativelanguage.googleapis.com`,
      ["content-type"]: "application/json",
    },
    body: JSON.stringify(req.body),
  };
};
@@ -1,7 +1,7 @@
import { RequestPreprocessor } from "../index"; import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization"; import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils"; import { assertNever } from "../../../../shared/utils";
import type { OpenAIChatMessage } from "./transform-outbound-payload"; import type { GoogleAIChatMessage, OpenAIChatMessage } from "./transform-outbound-payload";
/** /**
* Given a request with an already-transformed body, counts the number of * Given a request with an already-transformed body, counts the number of
@@ -30,9 +30,9 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service }); result = await countTokens({ req, prompt, service });
break; break;
} }
case "google-palm": { case "google-ai": {
req.outputTokens = req.body.maxOutputTokens; req.outputTokens = req.body.generationConfig.maxOutputTokens;
const prompt: string = req.body.prompt.text; const prompt: GoogleAIChatMessage[] = req.body.contents;
result = await countTokens({ req, prompt, service }); result = await countTokens({ req, prompt, service });
break; break;
} }
@@ -68,7 +68,7 @@ function getPromptFromRequest(req: Request) {
case "openai-text": case "openai-text":
case "openai-image": case "openai-image":
return body.prompt; return body.prompt;
case "google-palm": case "google-ai":
return body.prompt.text; return body.prompt.text;
default: default:
assertNever(service); assertNever(service);
@@ -1,7 +1,10 @@
import { Request } from "express"; import { Request } from "express";
import { z } from "zod"; import { z } from "zod";
import { config } from "../../../../config"; import { config } from "../../../../config";
import { isTextGenerationRequest, isImageGenerationRequest } from "../../common"; import {
isTextGenerationRequest,
isImageGenerationRequest,
} from "../../common";
import { RequestPreprocessor } from "../index"; import { RequestPreprocessor } from "../index";
import { APIFormat } from "../../../../shared/key-management"; import { APIFormat } from "../../../../shared/key-management";
@@ -121,30 +124,43 @@ const OpenAIV1ImagesGenerationSchema = z.object({
user: z.string().optional(), user: z.string().optional(),
}); });
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText // https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
const PalmV1GenerateTextSchema = z.object({ const GoogleAIV1GenerateContentSchema = z.object({
model: z.string(), model: z.string(), //actually specified in path but we need it for the router
prompt: z.object({ text: z.string() }), stream: z.boolean().optional().default(false), // also used for router
temperature: z.number().optional(), contents: z.array(
maxOutputTokens: z.coerce z.object({
.number() parts: z.array(z.object({ text: z.string() })),
.int() role: z.enum(["user", "model"]),
.optional() })
.default(16) ),
.transform((v) => Math.min(v, 1024)), // TODO: Add config tools: z.array(z.object({})).max(0).optional(),
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
safetySettings: z.array(z.object({})).max(0).optional(), safetySettings: z.array(z.object({})).max(0).optional(),
stopSequences: z.array(z.string()).max(5).optional(), generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce
.number()
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 1024)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
stopSequences: z.array(z.string()).max(5).optional(),
}),
}); });
export type GoogleAIChatMessage = z.infer<
typeof GoogleAIV1GenerateContentSchema
>["contents"][0];
const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = { const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
anthropic: AnthropicV1CompleteSchema, anthropic: AnthropicV1CompleteSchema,
openai: OpenAIV1ChatCompletionSchema, openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema, "openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema, "openai-image": OpenAIV1ImagesGenerationSchema,
"google-palm": PalmV1GenerateTextSchema, "google-ai": GoogleAIV1GenerateContentSchema,
}; };
/** Transforms an incoming request body to one that matches the target API. */ /** Transforms an incoming request body to one that matches the target API. */
@@ -174,8 +190,8 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return; return;
} }
if (req.inboundApi === "openai" && req.outboundApi === "google-palm") { if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
req.body = openaiToPalm(req); req.body = openaiToGoogleAI(req);
return; return;
} }
@@ -310,7 +326,9 @@ function openaiToOpenaiImage(req: Request) {
return OpenAIV1ImagesGenerationSchema.parse(transformed); return OpenAIV1ImagesGenerationSchema.parse(transformed);
} }
function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> { function openaiToGoogleAI(
req: Request
): z.infer<typeof GoogleAIV1GenerateContentSchema> {
const { body } = req; const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({ const result = OpenAIV1ChatCompletionSchema.safeParse({
...body, ...body,
@@ -319,13 +337,16 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
if (!result.success) { if (!result.success) {
req.log.warn( req.log.warn(
{ issues: result.error.issues, body }, { issues: result.error.issues, body },
"Invalid OpenAI-to-Palm request" "Invalid OpenAI-to-Google AI request"
); );
throw result.error; throw result.error;
} }
const { messages, ...rest } = result.data; const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages); const contents = messages.map((m) => ({
parts: [{ text: flattenOpenAIMessageContent(m.content) }],
role: m.role === "user" ? "user" as const : "model" as const,
}));
let stops = rest.stop let stops = rest.stop
? Array.isArray(rest.stop) ? Array.isArray(rest.stop)
@@ -339,20 +360,22 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
z.array(z.string()).max(5).parse(stops); z.array(z.string()).max(5).parse(stops);
return { return {
prompt: { text: prompt }, model: "gemini-pro",
maxOutputTokens: rest.max_tokens, stream: rest.stream,
stopSequences: stops, contents,
model: "text-bison-001", tools: [],
topP: rest.top_p, generationConfig: {
temperature: rest.temperature, maxOutputTokens: rest.max_tokens,
stopSequences: stops,
topP: rest.top_p,
topK: 40, // openai schema doesn't have this, google ai defaults to 40
temperature: rest.temperature,
},
safetySettings: [ safetySettings: [
{ category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_NONE" }, { category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DEROGATORY", threshold: "BLOCK_NONE" }, { category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_TOXICITY", threshold: "BLOCK_NONE" }, { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_VIOLENCE", threshold: "BLOCK_NONE" }, { category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_SEXUAL", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_MEDICAL", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DANGEROUS", threshold: "BLOCK_NONE" },
], ],
}; };
} }
@@ -428,3 +451,4 @@ function flattenOpenAIMessageContent(
.join("\n") .join("\n")
: content; : content;
} }
@@ -6,7 +6,7 @@ import { RequestPreprocessor } from "../index";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic; const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI; const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
const BISON_MAX_CONTEXT = 8100; const GOOGLE_AI_MAX_CONTEXT = 32000;
/** /**
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body * Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -31,8 +31,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
case "anthropic": case "anthropic":
proxyMax = CLAUDE_MAX_CONTEXT; proxyMax = CLAUDE_MAX_CONTEXT;
break; break;
case "google-palm": case "google-ai":
proxyMax = BISON_MAX_CONTEXT; proxyMax = GOOGLE_AI_MAX_CONTEXT;
break; break;
case "openai-image": case "openai-image":
return; return;
@@ -62,8 +62,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 100000; modelMax = 100000;
} else if (model.match(/^claude-2/)) { } else if (model.match(/^claude-2/)) {
modelMax = 200000; modelMax = 200000;
} else if (model.match(/^text-bison-\d{3}$/)) { } else if (model.match(/^gemini-\d{3}$/)) {
modelMax = BISON_MAX_CONTEXT; modelMax = GOOGLE_AI_MAX_CONTEXT;
} else if (model.match(/^anthropic\.claude/)) { } else if (model.match(/^anthropic\.claude/)) {
// Not sure if AWS Claude has the same context limit as Anthropic Claude. // Not sure if AWS Claude has the same context limit as Anthropic Claude.
modelMax = 100000; modelMax = 100000;
@@ -1,4 +1,3 @@
import express from "express";
import { pipeline } from "stream"; import { pipeline } from "stream";
import { promisify } from "util"; import { promisify } from "util";
import { import {
@@ -59,7 +58,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
const prefersNativeEvents = req.inboundApi === req.outboundApi; const prefersNativeEvents = req.inboundApi === req.outboundApi;
const contentType = proxyRes.headers["content-type"]; const contentType = proxyRes.headers["content-type"];
const adapter = new SSEStreamAdapter({ contentType }); const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
const aggregator = new EventAggregator({ format: req.outboundApi }); const aggregator = new EventAggregator({ format: req.outboundApi });
const transformer = new SSEMessageTransformer({ const transformer = new SSEMessageTransformer({
inputFormat: req.outboundApi, inputFormat: req.outboundApi,
+22 -5
View File
@@ -288,7 +288,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// For Anthropic, this is usually due to missing preamble. // For Anthropic, this is usually due to missing preamble.
switch (service) { switch (service) {
case "openai": case "openai":
case "google-palm": case "google-ai":
case "azure": case "azure":
const filteredCodes = ["content_policy_violation", "content_filter"]; const filteredCodes = ["content_policy_violation", "content_filter"];
if (filteredCodes.includes(errorPayload.error?.code)) { if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -350,8 +350,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "azure": case "azure":
handleAzureRateLimitError(req, errorPayload); handleAzureRateLimitError(req, errorPayload);
break; break;
case "google-palm": case "google-ai":
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`; handleGoogleAIRateLimitError(req, errorPayload);
break; break;
default: default:
assertNever(service); assertNever(service);
@@ -373,8 +373,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "anthropic": case "anthropic":
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`; errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
break; break;
case "google-palm": case "google-ai":
errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`; errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
break; break;
case "aws": case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`; errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
@@ -529,6 +529,23 @@ function handleAzureRateLimitError(
} }
} }
// Example upstream payload:
//{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}
/**
 * Handles a rate-limit error from the Google AI API. A RESOURCE_EXHAUSTED
 * status marks the key as rate-limited and re-enqueues the request for a
 * retry; any other status is surfaced to the client via `proxy_note`.
 */
function handleGoogleAIRateLimitError(
  req: Request,
  errorPayload: ProxiedErrorPayload
) {
  const status = errorPayload.error?.status;
  if (status === "RESOURCE_EXHAUSTED") {
    keyPool.markRateLimited(req.key!);
    reenqueueRequest(req);
    throw new RetryableError("Rate-limited request re-enqueued.");
  }
  errorPayload.proxy_note = `Unrecognized rate limit error from Google AI (${status}). Please report this.`;
}
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => { const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) { if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
const model = req.body.model; const model = req.body.model;
+1 -1
View File
@@ -73,7 +73,7 @@ const getPromptForRequest = (
}; };
case "anthropic": case "anthropic":
return req.body.prompt; return req.body.prompt;
case "google-palm": case "google-ai":
return req.body.prompt.text; return req.body.prompt.text;
default: default:
assertNever(req.outboundApi); assertNever(req.outboundApi);
@@ -27,12 +27,12 @@ export class EventAggregator {
getFinalResponse() { getFinalResponse() {
switch (this.format) { switch (this.format) {
case "openai": case "openai":
case "google-ai":
return mergeEventsForOpenAIChat(this.events); return mergeEventsForOpenAIChat(this.events);
case "openai-text": case "openai-text":
return mergeEventsForOpenAIText(this.events); return mergeEventsForOpenAIText(this.events);
case "anthropic": case "anthropic":
return mergeEventsForAnthropic(this.events); return mergeEventsForAnthropic(this.events);
case "google-palm":
case "openai-image": case "openai-image":
throw new Error(`SSE aggregation not supported for ${this.format}`); throw new Error(`SSE aggregation not supported for ${this.format}`);
default: default:
@@ -25,6 +25,8 @@ export type StreamingCompletionTransformer = (
export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai"; export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai"; export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai"; export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat"; export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text"; export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropic } from "./aggregators/anthropic"; export { mergeEventsForAnthropic } from "./aggregators/anthropic";
@@ -7,9 +7,10 @@ import {
anthropicV2ToOpenAI, anthropicV2ToOpenAI,
OpenAIChatCompletionStreamEvent, OpenAIChatCompletionStreamEvent,
openAITextToOpenAIChat, openAITextToOpenAIChat,
googleAIToOpenAI,
passthroughToOpenAI,
StreamingCompletionTransformer, StreamingCompletionTransformer,
} from "./index"; } from "./index";
import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
const genlog = logger.child({ module: "sse-transformer" }); const genlog = logger.child({ module: "sse-transformer" });
@@ -111,7 +112,8 @@ function getTransformer(
return version === "2023-01-01" return version === "2023-01-01"
? anthropicV1ToOpenAI ? anthropicV1ToOpenAI
: anthropicV2ToOpenAI; : anthropicV2ToOpenAI;
case "google-palm": case "google-ai":
return googleAIToOpenAI;
case "openai-image": case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`); throw new Error(`SSE transformation not supported for ${responseApi}`);
default: default:
@@ -1,13 +1,19 @@
import { Transform, TransformOptions } from "stream"; import { Transform, TransformOptions } from "stream";
import { StringDecoder } from "string_decoder"; import { StringDecoder } from "string_decoder";
// @ts-ignore // @ts-ignore
import { Parser } from "lifion-aws-event-stream"; import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../../logger"; import { logger } from "../../../../logger";
import { RetryableError } from "../index"; import { RetryableError } from "../index";
import { APIFormat } from "../../../../shared/key-management";
import StreamArray from "stream-json/streamers/StreamArray";
const log = logger.child({ module: "sse-stream-adapter" }); const log = logger.child({ module: "sse-stream-adapter" });
type SSEStreamAdapterOptions = TransformOptions & { contentType?: string }; type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;
api: APIFormat;
};
type AwsEventStreamMessage = { type AwsEventStreamMessage = {
headers: { headers: {
":message-type": "event" | "exception"; ":message-type": "event" | "exception";
@@ -22,7 +28,9 @@ type AwsEventStreamMessage = {
*/ */
export class SSEStreamAdapter extends Transform { export class SSEStreamAdapter extends Transform {
private readonly isAwsStream; private readonly isAwsStream;
private readonly isGoogleStream;
private parser = new Parser(); private parser = new Parser();
private jsonStream = StreamArray.withParser();
private partialMessage = ""; private partialMessage = "";
private decoder = new StringDecoder("utf8"); private decoder = new StringDecoder("utf8");
@@ -30,6 +38,7 @@ export class SSEStreamAdapter extends Transform {
super(options); super(options);
this.isAwsStream = this.isAwsStream =
options?.contentType === "application/vnd.amazon.eventstream"; options?.contentType === "application/vnd.amazon.eventstream";
this.isGoogleStream = options?.api === "google-ai";
this.parser.on("data", (data: AwsEventStreamMessage) => { this.parser.on("data", (data: AwsEventStreamMessage) => {
const message = this.processAwsEvent(data); const message = this.processAwsEvent(data);
@@ -37,6 +46,12 @@ export class SSEStreamAdapter extends Transform {
this.push(Buffer.from(message + "\n\n"), "utf8"); this.push(Buffer.from(message + "\n\n"), "utf8");
} }
}); });
this.jsonStream.on("data", (data: { value: any }) => {
const message = this.processGoogleValue(data.value);
if (message) {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
});
} }
protected processAwsEvent(event: AwsEventStreamMessage): string | null { protected processAwsEvent(event: AwsEventStreamMessage): string | null {
@@ -73,17 +88,38 @@ export class SSEStreamAdapter extends Transform {
} }
} }
// Google doesn't use event streams and just sends elements in an array over
// a long-lived HTTP connection. Needs stream-json to parse the array.
protected processGoogleValue(value: any): string | null {
if ("candidates" in value) {
return `data: ${JSON.stringify(value)}`;
} else {
log.error(
{ value },
"Received unexpected Google AI event stream message"
);
return getFakeErrorCompletion(
"proxy Google AI error",
JSON.stringify(value)
);
}
}
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) { _transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
try { try {
if (this.isAwsStream) { if (this.isAwsStream) {
this.parser.write(chunk); this.parser.write(chunk);
} else if (this.isGoogleStream) {
this.jsonStream.write(chunk);
} else { } else {
// We may receive multiple (or partial) SSE messages in a single chunk, // We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit separate stream events for full // so we need to buffer and emit separate stream events for full
// messages so we can parse/transform them properly. // messages so we can parse/transform them properly.
const str = this.decoder.write(chunk); const str = this.decoder.write(chunk);
const fullMessages = (this.partialMessage + str).split(/\r\r|\n\n|\r\n\r\n/); const fullMessages = (this.partialMessage + str).split(
/\r\r|\n\n|\r\n\r\n/
);
this.partialMessage = fullMessages.pop() || ""; this.partialMessage = fullMessages.pop() || "";
for (const message of fullMessages) { for (const message of fullMessages) {
@@ -0,0 +1,69 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
// Module-scoped logger for this transformer.
const log = logger.child({
  module: "sse-transformer",
  transformer: "google-ai-to-openai",
});

// Shape of one element of the JSON array streamed by the Google AI
// generateContent endpoint, as consumed by this transformer. Only the fields
// read below are modeled; see the Google AI API reference for the full shape.
type GoogleAIStreamEvent = {
  candidates: {
    content: { parts: { text: string }[]; role: string };
    finishReason?: "STOP" | "MAX_TOKENS" | "SAFETY" | "RECITATION" | "OTHER";
    index: number;
    tokenCount?: number;
    safetyRatings: { category: string; probability: string }[];
  }[];
};
/**
* Transforms an incoming Google AI SSE to an equivalent OpenAI
* chat.completion.chunk SSE.
*/
export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const parts = completionEvent.candidates[0].content.parts;
const text = parts[0]?.text ?? "";
const newEvent = {
id: "goo-" + params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content: text },
finish_reason: completionEvent.candidates[0].finishReason ?? null,
},
],
};
return { position: -1, event: newEvent };
};
/**
 * Parses a raw SSE into a Google AI stream event, returning null (and logging
 * a warning) when the payload is not valid JSON or lacks candidates.
 */
function asCompletion(event: ServerSentEvent): GoogleAIStreamEvent | null {
  try {
    const parsed = JSON.parse(event.data) as GoogleAIStreamEvent;
    if (parsed.candidates?.length > 0) {
      return parsed;
    }
    // noinspection ExceptionCaughtLocallyJS
    throw new Error("Missing required fields");
  } catch (error) {
    // `error` is `unknown` under strict mode; narrow before reading `.stack`.
    const stack = error instanceof Error ? error.stack : String(error);
    log.warn({ error: stack, event }, "Received invalid event");
  }
  return null;
}
-1
View File
@@ -17,7 +17,6 @@ import {
} from "./middleware/response"; } from "./middleware/response";
import { generateModelList } from "./openai"; import { generateModelList } from "./openai";
import { import {
mirrorGeneratedImage,
OpenAIImageGenerationResult, OpenAIImageGenerationResult,
} from "../shared/file-storage/mirror-generated-image"; } from "../shared/file-storage/mirror-generated-image";
-170
View File
@@ -1,170 +0,0 @@
import { Request, RequestHandler, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
// Cached OpenAI-style /v1/models payload, rebuilt at most once per minute.
let modelsCache: any = null;
let modelsCacheTime = 0;

/** Builds (or serves from cache) the model list for the PaLM endpoint. */
const getModelsResponse = () => {
  const now = new Date().getTime();
  const cacheIsFresh = now - modelsCacheTime < 1000 * 60;
  if (cacheIsFresh) return modelsCache;

  // Without a configured key there is nothing to advertise.
  if (!config.googlePalmKey) return { object: "list", data: [] };

  const models = ["text-bison-001"].map((id) => ({
    id,
    object: "model",
    created: new Date().getTime(),
    owned_by: "google",
    permission: [],
    root: "palm",
    parent: null,
  }));

  modelsCache = { object: "list", data: models };
  modelsCacheTime = new Date().getTime();
  return modelsCache;
};

/** GET /v1/models — serves the OpenAI-compatible model listing. */
const handleModelRequest: RequestHandler = (_req, res) => {
  res.status(200).json(getModelsResponse());
};
/**
 * Only used for non-streaming requests.
 * Final response handler for PaLM: optionally annotates the body with a
 * prompt-logging notice and tokenizer info, converts the upstream response to
 * OpenAI format when the client spoke the OpenAI API, then sends it.
 */
const palmResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  if (typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // Disclose prompt logging to the client when it is enabled on this proxy.
  if (config.promptLogging) {
    const host = req.get("host");
    body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
  }

  // Clients using the OpenAI-compatible endpoint get an OpenAI-shaped body.
  if (req.inboundApi === "openai") {
    req.log.info("Transforming Google PaLM response to OpenAI format");
    body = transformPalmResponse(body, req);
  }

  if (req.tokenizerInfo) {
    body.proxy_tokenizer = req.tokenizerInfo;
  }

  // TODO: PaLM has no streaming capability which will pose a problem here if
  // requests wait in the queue for too long. Probably need to fake streaming
  // and return the entire completion in one stream event using the other
  // response handler.
  res.status(200).json(body);
};
/**
 * Transforms a model response from the Google PaLM API to match those from
 * the OpenAI API, for users using PaLM via the OpenAI-compatible endpoint.
 * This is only used for non-streaming requests as streaming requests are
 * handled on-the-fly.
 *
 * @param palmRespBody - Raw PaLM response; only `candidates[0].output` is read.
 * @param req - The proxied request, used for token counts and the model name.
 * @returns An OpenAI chat.completion-shaped object.
 */
function transformPalmResponse(
  palmRespBody: Record<string, any>,
  req: Request
): Record<string, any> {
  const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
  return {
    id: "plm-" + v4(),
    object: "chat.completion",
    // OpenAI `created` is a unix timestamp in seconds, not milliseconds.
    created: Math.floor(Date.now() / 1000),
    model: req.body.model,
    usage: {
      prompt_tokens: req.promptTokens,
      completion_tokens: req.outputTokens,
      total_tokens: totalTokens,
    },
    choices: [
      {
        message: {
          role: "assistant",
          content: palmRespBody.candidates[0].output,
        },
        finish_reason: null, // palm doesn't return this
        index: 0,
      },
    ],
  };
}
/**
 * Rewrites the outbound proxy path to address the requested PaLM model.
 * PaLM API specifies the model in the URL path, not the request body. This
 * doesn't work well with our rewriter architecture, so we need to manually
 * fix it here.
 * POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateText
 * POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateMessage
 * The chat api (generateMessage) is not very useful at this time as it has
 * few params and no adjustable safety settings.
 */
function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
  if (req.body.stream) {
    throw new Error("Google PaLM API doesn't support streaming requests");
  }
  const rewrittenPath = proxyReq.path.replace(
    /^\/v1\/chat\/completions/,
    `/v1beta2/models/${req.body.model}:generateText`
  );
  proxyReq.path = rewrittenPath;
}
// Queue-wrapped proxy middleware for the PaLM upstream. Responses are handled
// by us (selfHandleResponse) so palmResponseHandler can rewrite the body.
const googlePalmProxy = createQueueMiddleware({
  proxyMiddleware: createProxyMiddleware({
    target: "https://generativelanguage.googleapis.com",
    changeOrigin: true,
    selfHandleResponse: true,
    logger,
    on: {
      // Outbound pipeline: fix the model path, attach a key, then finalize.
      proxyReq: createOnProxyReqHandler({
        pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
      }),
      proxyRes: createOnProxyResHandler([palmResponseHandler]),
      error: handleProxyError,
    },
  }),
});
// Router exposing the PaLM endpoints: a model listing plus an
// OpenAI-compatible chat completions route.
const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google PaLM compatibility endpoint. Requests are rate limited per
// IP, transformed from OpenAI format, and pinned to text-bison-001.
palmRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "google-palm", service: "google-palm" },
    { afterTransform: [forceModel("text-bison-001")] }
  ),
  googlePalmProxy
);

export const googlePalm = palmRouter;
+2 -2
View File
@@ -4,7 +4,7 @@ import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai"; import { openai } from "./openai";
import { openaiImage } from "./openai-image"; import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic"; import { anthropic } from "./anthropic";
import { googlePalm } from "./palm"; import { googleAI } from "./google-ai";
import { aws } from "./aws"; import { aws } from "./aws";
import { azure } from "./azure"; import { azure } from "./azure";
@@ -31,7 +31,7 @@ proxyRouter.use((req, _res, next) => {
proxyRouter.use("/openai", addV1, openai); proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage); proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic); proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm); proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/aws/claude", addV1, aws); proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure); proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage. // Redirect browser requests to the homepage.
@@ -62,7 +62,6 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
{ key: key.hash, error: error.message }, { key: key.hash, error: error.message },
"Key is rate limited. Rechecking in 10 seconds." "Key is rate limited. Rechecking in 10 seconds."
); );
0;
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000); const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
this.updateKey(key.hash, { lastChecked: next }); this.updateKey(key.hash, { lastChecked: next });
break; break;
@@ -6,7 +6,6 @@ import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models"; import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider"; import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker"; import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">; export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
@@ -2,13 +2,17 @@ import crypto from "crypto";
import { Key, KeyProvider } from ".."; import { Key, KeyProvider } from "..";
import { config } from "../../../config"; import { config } from "../../../config";
import { logger } from "../../../logger"; import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models"; import type { GoogleAIModelFamily } from "../../models";
// https://developers.generativeai.google.com/models/language // Note that Google AI is not the same as Vertex AI, both are provided by Google
export type GooglePalmModel = "text-bison-001"; // but Vertex is the GCP product for enterprise. while Google AI is the
// consumer-ish product. The API is different, and keys are not compatible.
// https://ai.google.dev/docs/migrate_to_cloud
export type GooglePalmKeyUpdate = Omit< export type GoogleAIModel = "gemini-pro";
Partial<GooglePalmKey>,
export type GoogleAIKeyUpdate = Omit<
Partial<GoogleAIKey>,
| "key" | "key"
| "hash" | "hash"
| "lastUsed" | "lastUsed"
@@ -17,13 +21,13 @@ export type GooglePalmKeyUpdate = Omit<
| "rateLimitedUntil" | "rateLimitedUntil"
>; >;
type GooglePalmKeyUsage = { type GoogleAIKeyUsage = {
[K in GooglePalmModelFamily as `${K}Tokens`]: number; [K in GoogleAIModelFamily as `${K}Tokens`]: number;
}; };
export interface GooglePalmKey extends Key, GooglePalmKeyUsage { export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
readonly service: "google-palm"; readonly service: "google-ai";
readonly modelFamilies: GooglePalmModelFamily[]; readonly modelFamilies: GoogleAIModelFamily[];
/** The time at which this key was last rate limited. */ /** The time at which this key was last rate limited. */
rateLimitedAt: number; rateLimitedAt: number;
/** The time until which this key is rate limited. */ /** The time until which this key is rate limited. */
@@ -42,27 +46,27 @@ const RATE_LIMIT_LOCKOUT = 2000;
*/ */
const KEY_REUSE_DELAY = 500; const KEY_REUSE_DELAY = 500;
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> { export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
readonly service = "google-palm"; readonly service = "google-ai";
private keys: GooglePalmKey[] = []; private keys: GoogleAIKey[] = [];
private log = logger.child({ module: "key-provider", service: this.service }); private log = logger.child({ module: "key-provider", service: this.service });
constructor() { constructor() {
const keyConfig = config.googlePalmKey?.trim(); const keyConfig = config.googleAIKey?.trim();
if (!keyConfig) { if (!keyConfig) {
this.log.warn( this.log.warn(
"GOOGLE_PALM_KEY is not set. PaLM API will not be available." "GOOGLE_AI_KEY is not set. Google AI API will not be available."
); );
return; return;
} }
let bareKeys: string[]; let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))]; bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) { for (const key of bareKeys) {
const newKey: GooglePalmKey = { const newKey: GoogleAIKey = {
key, key,
service: this.service, service: this.service,
modelFamilies: ["bison"], modelFamilies: ["gemini-pro"],
isDisabled: false, isDisabled: false,
isRevoked: false, isRevoked: false,
promptCount: 0, promptCount: 0,
@@ -75,11 +79,11 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
.digest("hex") .digest("hex")
.slice(0, 8)}`, .slice(0, 8)}`,
lastChecked: 0, lastChecked: 0,
bisonTokens: 0, "gemini-proTokens": 0,
}; };
this.keys.push(newKey); this.keys.push(newKey);
} }
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys."); this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
} }
public init() {} public init() {}
@@ -88,10 +92,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined })); return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
} }
public get(_model: GooglePalmModel) { public get(_model: GoogleAIModel) {
const availableKeys = this.keys.filter((k) => !k.isDisabled); const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) { if (availableKeys.length === 0) {
throw new Error("No Google PaLM keys available"); throw new Error("No Google AI keys available");
} }
// (largely copied from the OpenAI provider, without trial key support) // (largely copied from the OpenAI provider, without trial key support)
@@ -122,14 +126,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return { ...selectedKey }; return { ...selectedKey };
} }
public disable(key: GooglePalmKey) { public disable(key: GoogleAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash); const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return; if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true; keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled"); this.log.warn({ key: key.hash }, "Key disabled");
} }
public update(hash: string, update: Partial<GooglePalmKey>) { public update(hash: string, update: Partial<GoogleAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!; const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update }); Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
} }
@@ -142,7 +146,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
const key = this.keys.find((k) => k.hash === hash); const key = this.keys.find((k) => k.hash === hash);
if (!key) return; if (!key) return;
key.promptCount++; key.promptCount++;
key.bisonTokens += tokens; key["gemini-proTokens"] += tokens;
} }
public getLockoutPeriod() { public getLockoutPeriod() {
+5 -5
View File
@@ -1,6 +1,6 @@
import { OpenAIModel } from "./openai/provider"; import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider"; import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider"; import { GoogleAIModel } from "./google-ai/provider";
import { AwsBedrockModel } from "./aws/provider"; import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider"; import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool"; import { KeyPool } from "./key-pool";
@@ -10,20 +10,20 @@ import type { ModelFamily } from "../models";
export type APIFormat = export type APIFormat =
| "openai" | "openai"
| "anthropic" | "anthropic"
| "google-palm" | "google-ai"
| "openai-text" | "openai-text"
| "openai-image"; | "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */ /** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService = export type LLMService =
| "openai" | "openai"
| "anthropic" | "anthropic"
| "google-palm" | "google-ai"
| "aws" | "aws"
| "azure"; | "azure";
export type Model = export type Model =
| OpenAIModel | OpenAIModel
| AnthropicModel | AnthropicModel
| GooglePalmModel | GoogleAIModel
| AwsBedrockModel | AwsBedrockModel
| AzureOpenAIModel; | AzureOpenAIModel;
@@ -77,6 +77,6 @@ export interface KeyProvider<T extends Key = Key> {
export const keyPool = new KeyPool(); export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider"; export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider"; export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider"; export { GoogleAIKey } from "././google-ai/provider";
export { AwsBedrockKey } from "./aws/provider"; export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider"; export { AzureOpenAIKey } from "./azure/provider";
+6 -6
View File
@@ -7,7 +7,7 @@ import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index"; import { Key, Model, KeyProvider, LLMService } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider"; import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider"; import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider"; import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider"; import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models"; import { ModelFamily } from "../models";
import { assertNever } from "../utils"; import { assertNever } from "../utils";
@@ -24,7 +24,7 @@ export class KeyPool {
constructor() { constructor() {
this.keyProviders.push(new OpenAIKeyProvider()); this.keyProviders.push(new OpenAIKeyProvider());
this.keyProviders.push(new AnthropicKeyProvider()); this.keyProviders.push(new AnthropicKeyProvider());
this.keyProviders.push(new GooglePalmKeyProvider()); this.keyProviders.push(new GoogleAIKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider()); this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider()); this.keyProviders.push(new AzureOpenAIKeyProvider());
} }
@@ -119,9 +119,9 @@ export class KeyPool {
} else if (model.startsWith("claude-")) { } else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters // https://console.anthropic.com/docs/api/reference#parameters
return "anthropic"; return "anthropic";
} else if (model.includes("bison")) { } else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language // https://developers.generativeai.google.com/models/language
return "google-palm"; return "google-ai";
} else if (model.startsWith("anthropic.claude")) { } else if (model.startsWith("anthropic.claude")) {
// AWS offers models from a few providers // AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
@@ -142,8 +142,8 @@ export class KeyPool {
return "openai"; return "openai";
case "claude": case "claude":
return "anthropic"; return "anthropic";
case "bison": case "gemini-pro":
return "google-palm"; return "google-ai";
case "aws-claude": case "aws-claude":
return "aws"; return "aws";
case "azure-turbo": case "azure-turbo":
+8 -10
View File
@@ -11,7 +11,7 @@ export type OpenAIModelFamily =
| "gpt4-turbo" | "gpt4-turbo"
| "dall-e"; | "dall-e";
export type AnthropicModelFamily = "claude"; export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison"; export type GoogleAIModelFamily = "gemini-pro";
export type AwsBedrockModelFamily = "aws-claude"; export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude< export type AzureOpenAIModelFamily = `azure-${Exclude<
OpenAIModelFamily, OpenAIModelFamily,
@@ -20,7 +20,7 @@ export type AzureOpenAIModelFamily = `azure-${Exclude<
export type ModelFamily = export type ModelFamily =
| OpenAIModelFamily | OpenAIModelFamily
| AnthropicModelFamily | AnthropicModelFamily
| GooglePalmModelFamily | GoogleAIModelFamily
| AwsBedrockModelFamily | AwsBedrockModelFamily
| AzureOpenAIModelFamily; | AzureOpenAIModelFamily;
@@ -33,7 +33,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"gpt4-turbo", "gpt4-turbo",
"dall-e", "dall-e",
"claude", "claude",
"bison", "gemini-pro",
"aws-claude", "aws-claude",
"azure-turbo", "azure-turbo",
"azure-gpt4", "azure-gpt4",
@@ -53,7 +53,7 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^dall-e-\\d{1}$": "dall-e", "^dall-e-\\d{1}$": "dall-e",
}; };
const modelLogger = pino({ level: "debug" }).child({ module: "startup" }); pino({ level: "debug" }).child({ module: "startup" });
export function getOpenAIModelFamily( export function getOpenAIModelFamily(
model: string, model: string,
@@ -70,10 +70,8 @@ export function getClaudeModelFamily(model: string): ModelFamily {
return "claude"; return "claude";
} }
export function getGooglePalmModelFamily(model: string): ModelFamily { export function getGoogleAIModelFamily(_model: string): ModelFamily {
if (model.match(/^\w+-bison-\d{3}$/)) return "bison"; return "gemini-pro";
modelLogger.warn({ model }, "Could not determine Google PaLM model family");
return "bison";
} }
export function getAwsBedrockModelFamily(_model: string): ModelFamily { export function getAwsBedrockModelFamily(_model: string): ModelFamily {
@@ -130,8 +128,8 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
case "openai-image": case "openai-image":
modelFamily = getOpenAIModelFamily(model); modelFamily = getOpenAIModelFamily(model);
break; break;
case "google-palm": case "google-ai":
modelFamily = getGooglePalmModelFamily(model); modelFamily = getGoogleAIModelFamily(model);
break; break;
default: default:
assertNever(req.outboundApi); assertNever(req.outboundApi);
+1 -1
View File
@@ -74,7 +74,7 @@ export function buildFakeSse(type: string, string: string, req: Request) {
log_id: "proxy-req-" + req.id, log_id: "proxy-req-" + req.id,
}; };
break; break;
case "google-palm": case "google-ai":
case "openai-image": case "openai-image":
throw new Error(`SSE not supported for ${req.inboundApi} requests`); throw new Error(`SSE not supported for ${req.inboundApi} requests`);
default: default:
+29 -4
View File
@@ -2,7 +2,11 @@ import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json"; import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger"; import { logger } from "../../logger";
import { libSharp } from "../file-storage"; import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload"; import type {
GoogleAIChatMessage,
OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { z } from "zod";
const log = logger.child({ module: "tokenizer", service: "openai" }); const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170; const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
@@ -29,11 +33,11 @@ export async function getTokenCount(
return getTextTokenCount(prompt); return getTextTokenCount(prompt);
} }
const gpt4 = model.startsWith("gpt-4"); const oldFormatting = model.startsWith("turbo-0301");
const vision = model.includes("vision"); const vision = model.includes("vision");
const tokensPerMessage = gpt4 ? 3 : 4; const tokensPerMessage = oldFormatting ? 4 : 3;
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0; let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
@@ -228,3 +232,24 @@ export function getOpenAIImageCost(params: {
token_count: Math.ceil(tokens), token_count: Math.ceil(tokens),
}; };
} }
export function estimateGoogleAITokenCount(prompt: string | GoogleAIChatMessage[]) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const tokensPerMessage = 3;
let numTokens = 0;
for (const message of prompt) {
numTokens += tokensPerMessage;
numTokens += encoder.encode(message.parts[0].text).length;
}
numTokens += 3;
return {
tokenizer: "tiktoken (google-ai estimate)",
token_count: numTokens,
};
}
+10 -5
View File
@@ -1,5 +1,8 @@
import { Request } from "express"; import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload"; import type {
GoogleAIChatMessage,
OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils"; import { assertNever } from "../utils";
import { import {
init as initClaude, init as initClaude,
@@ -9,6 +12,7 @@ import {
init as initOpenAi, init as initOpenAi,
getTokenCount as getOpenAITokenCount, getTokenCount as getOpenAITokenCount,
getOpenAIImageCost, getOpenAIImageCost,
estimateGoogleAITokenCount,
} from "./openai"; } from "./openai";
import { APIFormat } from "../key-management"; import { APIFormat } from "../key-management";
@@ -24,8 +28,9 @@ type TokenCountRequest = { req: Request } & (
| { | {
prompt: string; prompt: string;
completion?: never; completion?: never;
service: "openai-text" | "anthropic" | "google-palm"; service: "openai-text" | "anthropic" | "google-ai";
} }
| { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
| { prompt?: never; completion: string; service: APIFormat } | { prompt?: never; completion: string; service: APIFormat }
| { prompt?: never; completion?: never; service: "openai-image" } | { prompt?: never; completion?: never; service: "openai-image" }
); );
@@ -65,11 +70,11 @@ export async function countTokens({
}), }),
tokenization_duration_ms: getElapsedMs(time), tokenization_duration_ms: getElapsedMs(time),
}; };
case "google-palm": case "google-ai":
// TODO: Can't find a tokenization library for PaLM. There is an API // TODO: Can't find a tokenization library for Gemini. There is an API
// endpoint for it but it adds significant latency to the request. // endpoint for it but it adds significant latency to the request.
return { return {
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)), ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
tokenization_duration_ms: getElapsedMs(time), tokenization_duration_ms: getElapsedMs(time),
}; };
default: default:
+1 -1
View File
@@ -9,7 +9,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
"gpt4-turbo": z.number().optional().default(0), "gpt4-turbo": z.number().optional().default(0),
"dall-e": z.number().optional().default(0), "dall-e": z.number().optional().default(0),
claude: z.number().optional().default(0), claude: z.number().optional().default(0),
bison: z.number().optional().default(0), "gemini-pro": z.number().optional().default(0),
"aws-claude": z.number().optional().default(0), "aws-claude": z.number().optional().default(0),
}); });
+4 -4
View File
@@ -14,7 +14,7 @@ import { config, getFirebaseApp } from "../../config";
import { import {
getAzureOpenAIModelFamily, getAzureOpenAIModelFamily,
getClaudeModelFamily, getClaudeModelFamily,
getGooglePalmModelFamily, getGoogleAIModelFamily,
getOpenAIModelFamily, getOpenAIModelFamily,
MODEL_FAMILIES, MODEL_FAMILIES,
ModelFamily, ModelFamily,
@@ -33,7 +33,7 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
"gpt4-turbo": 0, "gpt4-turbo": 0,
"dall-e": 0, "dall-e": 0,
claude: 0, claude: 0,
bison: 0, "gemini-pro": 0,
"aws-claude": 0, "aws-claude": 0,
"azure-turbo": 0, "azure-turbo": 0,
"azure-gpt4": 0, "azure-gpt4": 0,
@@ -397,8 +397,8 @@ function getModelFamilyForQuotaUsage(
return getOpenAIModelFamily(model); return getOpenAIModelFamily(model);
case "anthropic": case "anthropic":
return getClaudeModelFamily(model); return getClaudeModelFamily(model);
case "google-palm": case "google-ai":
return getGooglePalmModelFamily(model); return getGoogleAIModelFamily(model);
default: default:
assertNever(api); assertNever(api);
} }