This commit is contained in:
khanon
2023-12-13 21:56:07 +00:00
parent 0d3682197c
commit fad16cc268
40 changed files with 588 additions and 357 deletions
+3 -3
View File
@@ -34,10 +34,10 @@
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | bison | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# turbo | gpt4 | gpt4-32k | gpt4-turbo | dall-e | claude | gemini-pro | aws-claude | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo
# By default, all models are allowed except for 'dall-e'. To allow DALL-E image
# generation, uncomment the line below and add 'dall-e' to the list.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo
# URLs from which requests will be blocked.
# BLOCKED_ORIGINS=reddit.com,9gag.com
@@ -95,7 +95,7 @@
# TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_BISON=0
# TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# How often to refresh token quotas. (hourly | daily)
+1 -1
View File
@@ -35,7 +35,7 @@ Add `dall-e` to the `ALLOWED_MODEL_FAMILIES` environment variable to enable DALL
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-turbo,dall-e
# All models as of this writing
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,bison,aws-claude,dall-e
ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,claude,gemini-pro,aws-claude,dall-e
```
Refer to [.env.example](../.env.example) for a full list of supported model families. You can add `dall-e` to that list to enable all models.
+34
View File
@@ -36,6 +36,7 @@
"sanitize-html": "^2.11.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
"zlib": "^1.0.5",
@@ -51,6 +52,7 @@
"@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1",
"concurrently": "^8.0.1",
"esbuild": "^0.17.16",
@@ -1185,6 +1187,25 @@
"integrity": "sha512-70xBJoLv+oXjB5PhtA8vo7erjLDp9/qqI63SRHm4REKrwuPOLs8HhXwlZJBJaB4kC18cCZ1UUZ6Fb/PLFW4TCA==",
"dev": true
},
"node_modules/@types/stream-chain": {
"version": "2.0.4",
"resolved": "https://registry.npmjs.org/@types/stream-chain/-/stream-chain-2.0.4.tgz",
"integrity": "sha512-V7TsWLHrx79KumkHqSD7F8eR6POpEuWb6PuXJ7s/dRHAf3uVst3Jkp1yZ5XqIfECZLQ4a28vBVstTErmsMBvaQ==",
"dev": true,
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@types/stream-json": {
"version": "1.7.7",
"resolved": "https://registry.npmjs.org/@types/stream-json/-/stream-json-1.7.7.tgz",
"integrity": "sha512-hHG7cLQ09H/m9i0jzL6UJAeLLxIWej90ECn0svO4T8J0nGcl89xZDQ2ujT4WKlvg0GWkcxJbjIDzW/v7BYUM6Q==",
"dev": true,
"dependencies": {
"@types/node": "*",
"@types/stream-chain": "*"
}
},
"node_modules/@types/uuid": {
"version": "9.0.1",
"resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.1.tgz",
@@ -5135,6 +5156,11 @@
"node": ">= 0.8"
}
},
"node_modules/stream-chain": {
"version": "2.2.5",
"resolved": "https://registry.npmjs.org/stream-chain/-/stream-chain-2.2.5.tgz",
"integrity": "sha512-1TJmBx6aSWqZ4tx7aTpBDXK0/e2hhcNSTV8+CbFJtDjbb+I1mZ8lHit0Grw9GRT+6JbIrrDd8esncgBi8aBXGA=="
},
"node_modules/stream-events": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz",
@@ -5144,6 +5170,14 @@
"stubs": "^3.0.0"
}
},
"node_modules/stream-json": {
"version": "1.8.0",
"resolved": "https://registry.npmjs.org/stream-json/-/stream-json-1.8.0.tgz",
"integrity": "sha512-HZfXngYHUAr1exT4fxlbc1IOce1RYxp2ldeaf97LYCOPSoOqY/1Psp7iGvpb+6JIOgkra9zDYnPX01hGAHzEPw==",
"dependencies": {
"stream-chain": "^2.2.5"
}
},
"node_modules/stream-shift": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.1.tgz",
+2
View File
@@ -44,6 +44,7 @@
"sanitize-html": "^2.11.0",
"sharp": "^0.32.6",
"showdown": "^2.1.0",
"stream-json": "^1.8.0",
"tiktoken": "^1.0.10",
"uuid": "^9.0.0",
"zlib": "^1.0.5",
@@ -59,6 +60,7 @@
"@types/node-schedule": "^2.1.0",
"@types/sanitize-html": "^2.9.0",
"@types/showdown": "^2.0.0",
"@types/stream-json": "^1.7.7",
"@types/uuid": "^9.0.1",
"concurrently": "^8.0.1",
"esbuild": "^0.17.16",
+4 -3
View File
@@ -1,6 +1,6 @@
const axios = require("axios");
const concurrentRequests = 5;
const concurrentRequests = 75;
const headers = {
Authorization: "Bearer test",
"Content-Type": "application/json",
@@ -16,7 +16,7 @@ const payload = {
const makeRequest = async (i) => {
try {
const response = await axios.post(
"http://localhost:7860/proxy/azure/openai/v1/chat/completions",
"http://localhost:7860/proxy/google-ai/v1/chat/completions",
payload,
{ headers }
);
@@ -25,7 +25,8 @@ const makeRequest = async (i) => {
response.data
);
} catch (error) {
console.error(`Error in req ${i}:`, error.message);
const msg = error.response?.data;
console.error(`Error in req ${i}:`, error.message, msg || "");
}
};
+1 -1
View File
@@ -200,7 +200,7 @@ router.post("/maintenance", (req, res) => {
keyPool.recheck("anthropic");
const size = keyPool
.list()
.filter((k) => k.service !== "google-palm").length;
.filter((k) => k.service !== "google-ai").length;
flash.type = "success";
flash.message = `Scheduled recheck of ${size} keys for OpenAI and Anthropic.`;
break;
+10 -6
View File
@@ -19,8 +19,12 @@ type Config = {
openaiKey?: string;
/** Comma-delimited list of Anthropic API keys. */
anthropicKey?: string;
/** Comma-delimited list of Google PaLM API keys. */
googlePalmKey?: string;
/**
* Comma-delimited list of Google AI API keys. Note that these are not the
* same as the GCP keys/credentials used for Vertex AI; the models are the
* same but the APIs are different. Vertex is the GCP product for enterprise.
**/
googleAIKey?: string;
/**
* Comma-delimited list of AWS credentials. Each credential item should be a
* colon-delimited list of access key, secret key, and AWS region.
@@ -197,7 +201,7 @@ export const config: Config = {
port: getEnvWithDefault("PORT", 7860),
openaiKey: getEnvWithDefault("OPENAI_KEY", ""),
anthropicKey: getEnvWithDefault("ANTHROPIC_KEY", ""),
googlePalmKey: getEnvWithDefault("GOOGLE_PALM_KEY", ""),
googleAIKey: getEnvWithDefault("GOOGLE_AI_KEY", ""),
awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""),
azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""),
proxyKey: getEnvWithDefault("PROXY_KEY", ""),
@@ -229,7 +233,7 @@ export const config: Config = {
"gpt4-32k",
"gpt4-turbo",
"claude",
"bison",
"gemini-pro",
"aws-claude",
"azure-turbo",
"azure-gpt4",
@@ -366,7 +370,7 @@ export const OMITTED_KEYS: (keyof Config)[] = [
"logLevel",
"openaiKey",
"anthropicKey",
"googlePalmKey",
"googleAIKey",
"awsCredentials",
"azureCredentials",
"proxyKey",
@@ -433,7 +437,7 @@ function getEnvWithDefault<T>(env: string | string[], defaultValue: T): T {
[
"OPENAI_KEY",
"ANTHROPIC_KEY",
"GOOGLE_PALM_KEY",
"GOOGLE_AI_KEY",
"AWS_CREDENTIALS",
"AZURE_CREDENTIALS",
].includes(String(env))
+35 -29
View File
@@ -7,7 +7,7 @@ import {
AnthropicKey,
AwsBedrockKey,
AzureOpenAIKey,
GooglePalmKey,
GoogleAIKey,
keyPool,
OpenAIKey,
} from "./shared/key-management";
@@ -33,8 +33,8 @@ const keyIsAzureKey = (k: KeyPoolKey): k is AzureOpenAIKey =>
k.service === "azure";
const keyIsAnthropicKey = (k: KeyPoolKey): k is AnthropicKey =>
k.service === "anthropic";
const keyIsGooglePalmKey = (k: KeyPoolKey): k is GooglePalmKey =>
k.service === "google-palm";
const keyIsGoogleAIKey = (k: KeyPoolKey): k is GoogleAIKey =>
k.service === "google-ai";
const keyIsAwsKey = (k: KeyPoolKey): k is AwsBedrockKey => k.service === "aws";
type ModelAggregates = {
@@ -54,7 +54,7 @@ type ServiceAggregates = {
openaiKeys?: number;
openaiOrgs?: number;
anthropicKeys?: number;
palmKeys?: number;
googleAIKeys?: number;
awsKeys?: number;
azureKeys?: number;
proompts: number;
@@ -100,7 +100,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const openaiKeys = serviceStats.get("openaiKeys") || 0;
const anthropicKeys = serviceStats.get("anthropicKeys") || 0;
const palmKeys = serviceStats.get("palmKeys") || 0;
const googleAIKeys = serviceStats.get("googleAIKeys") || 0;
const awsKeys = serviceStats.get("awsKeys") || 0;
const azureKeys = serviceStats.get("azureKeys") || 0;
const proompts = serviceStats.get("proompts") || 0;
@@ -116,7 +116,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
? { ["openai-image"]: baseUrl + "/openai-image" }
: {}),
...(anthropicKeys ? { anthropic: baseUrl + "/anthropic" } : {}),
...(palmKeys ? { "google-palm": baseUrl + "/google-palm" } : {}),
...(googleAIKeys ? { "google-ai": baseUrl + "/google-ai" } : {}),
...(awsKeys ? { aws: baseUrl + "/aws/claude" } : {}),
...(azureKeys ? { azure: baseUrl + "/azure/openai" } : {}),
};
@@ -127,7 +127,13 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
...(config.textModelRateLimit ? { proomptersNow: getUniqueIps() } : {}),
};
const keyInfo = { openaiKeys, anthropicKeys, palmKeys, awsKeys, azureKeys };
const keyInfo = {
openaiKeys,
anthropicKeys,
googleAIKeys,
awsKeys,
azureKeys,
};
for (const key of Object.keys(keyInfo)) {
if (!(keyInfo as any)[key]) delete (keyInfo as any)[key];
}
@@ -135,7 +141,7 @@ export function buildInfoPageHtml(baseUrl: string, asAdmin = false) {
const providerInfo = {
...(openaiKeys ? getOpenAIInfo() : {}),
...(anthropicKeys ? getAnthropicInfo() : {}),
...(palmKeys ? getPalmInfo() : {}),
...(googleAIKeys ? getGoogleAIInfo() : {}),
...(awsKeys ? getAwsInfo() : {}),
...(azureKeys ? getAzureInfo() : {}),
};
@@ -197,7 +203,7 @@ function addKeyToAggregates(k: KeyPoolKey) {
increment(serviceStats, "proompts", k.promptCount);
increment(serviceStats, "openaiKeys", k.service === "openai" ? 1 : 0);
increment(serviceStats, "anthropicKeys", k.service === "anthropic" ? 1 : 0);
increment(serviceStats, "palmKeys", k.service === "google-palm" ? 1 : 0);
increment(serviceStats, "googleAIKeys", k.service === "google-ai" ? 1 : 0);
increment(serviceStats, "awsKeys", k.service === "aws" ? 1 : 0);
increment(serviceStats, "azureKeys", k.service === "azure" ? 1 : 0);
@@ -251,14 +257,14 @@ function addKeyToAggregates(k: KeyPoolKey) {
);
break;
}
case "google-palm": {
if (!keyIsGooglePalmKey(k)) throw new Error("Invalid key type");
const family = "bison";
sumTokens += k.bisonTokens;
sumCost += getTokenCostUsd(family, k.bisonTokens);
case "google-ai": {
if (!keyIsGoogleAIKey(k)) throw new Error("Invalid key type");
const family = "gemini-pro";
sumTokens += k["gemini-proTokens"];
sumCost += getTokenCostUsd(family, k["gemini-proTokens"]);
increment(modelStats, `${family}__active`, k.isDisabled ? 0 : 1);
increment(modelStats, `${family}__revoked`, k.isRevoked ? 1 : 0);
increment(modelStats, `${family}__tokens`, k.bisonTokens);
increment(modelStats, `${family}__tokens`, k["gemini-proTokens"]);
break;
}
case "aws": {
@@ -388,26 +394,26 @@ function getAnthropicInfo() {
};
}
function getPalmInfo() {
const bisonInfo: Partial<ModelAggregates> = {
active: modelStats.get("bison__active") || 0,
revoked: modelStats.get("bison__revoked") || 0,
function getGoogleAIInfo() {
const googleAIInfo: Partial<ModelAggregates> = {
active: modelStats.get("gemini-pro__active") || 0,
revoked: modelStats.get("gemini-pro__revoked") || 0,
};
const queue = getQueueInformation("bison");
bisonInfo.queued = queue.proomptersInQueue;
bisonInfo.queueTime = queue.estimatedQueueTime;
const queue = getQueueInformation("gemini-pro");
googleAIInfo.queued = queue.proomptersInQueue;
googleAIInfo.queueTime = queue.estimatedQueueTime;
const tokens = modelStats.get("bison__tokens") || 0;
const cost = getTokenCostUsd("bison", tokens);
const tokens = modelStats.get("gemini-pro__tokens") || 0;
const cost = getTokenCostUsd("gemini-pro", tokens);
return {
bison: {
gemini: {
usage: `${prettyTokens(tokens)} tokens${getCostString(cost)}`,
activeKeys: bisonInfo.active,
revokedKeys: bisonInfo.revoked,
proomptersInQueue: bisonInfo.queued,
estimatedQueueTime: bisonInfo.queueTime,
activeKeys: googleAIInfo.active,
revokedKeys: googleAIInfo.revoked,
proomptersInQueue: googleAIInfo.queued,
estimatedQueueTime: googleAIInfo.queueTime,
},
};
}
+141
View File
@@ -0,0 +1,141 @@
import { Request, RequestHandler, Router } from "express";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeSignedRequest,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.googleAIKey) return { object: "list", data: [] };
const googleAIVariants = ["gemini-pro"];
const models = googleAIVariants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "google",
permission: [],
root: "google",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const googleAIResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
req.log.info("Transforming Google AI response to OpenAI format");
body = transformGoogleAIResponse(body, req);
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
res.status(200).json(body);
};
function transformGoogleAIResponse(
googleAIResp: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "goo-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: googleAIResp.candidates[0].content.parts[0].text,
},
finish_reason: googleAIResp.candidates[0].finishReason,
index: 0,
},
],
};
}
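A rough sketch of what this transformation produces for a typical non-streaming `generateContent` response (the sample payload is illustrative, not captured from the API):
```ts
// Illustrative Gemini response, shaped per the accesses above.
const googleAIResp = {
  candidates: [
    {
      content: { role: "model", parts: [{ text: "Hello!" }] },
      finishReason: "STOP",
      index: 0,
    },
  ],
};
// transformGoogleAIResponse(googleAIResp, req) then returns roughly:
// {
//   id: "goo-<uuid>",
//   object: "chat.completion",
//   model: req.body.model, // e.g. "gemini-pro"
//   usage: { prompt_tokens, completion_tokens, total_tokens },
//   choices: [{
//     message: { role: "assistant", content: "Hello!" },
//     finish_reason: "STOP",
//     index: 0,
//   }],
// }
```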
const googleAIProxy = createQueueMiddleware({
beforeProxy: addGoogleAIKey,
proxyMiddleware: createProxyMiddleware({
target: "bad-target-will-be-rewritten",
router: ({ signedRequest }) => {
const { protocol, hostname, path } = signedRequest;
return `${protocol}//${hostname}${path}`;
},
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([googleAIResponseHandler]),
error: handleProxyError,
},
}),
});
const googleAIRouter = Router();
googleAIRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google AI compatibility endpoint.
googleAIRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "google-ai", service: "google-ai" },
{ afterTransform: [forceModel("gemini-pro")] }
),
googleAIProxy
);
export const googleAI = googleAIRouter;
+6 -3
View File
@@ -177,8 +177,11 @@ export function getCompletionFromBody(req: Request, body: Record<string, any>) {
return "";
}
return body.completion.trim();
case "google-palm":
return body.candidates[0].output;
case "google-ai":
if ("choices" in body) {
return body.choices[0].message.content;
}
return body.candidates[0].content.parts[0].text;
case "openai-image":
return body.data?.map((item: any) => item.url).join("\n");
default:
@@ -197,7 +200,7 @@ export function getModelFromBody(req: Request, body: Record<string, any>) {
case "anthropic":
// Anthropic confirms the model in the response, but AWS Claude doesn't.
return body.model || req.body.model;
case "google-palm":
case "google-ai":
// Google doesn't confirm the model in the response.
return req.body.model;
default:
@@ -29,7 +29,9 @@ export const createOnProxyReqHandler = ({
// The streaming flag must be set before any other onProxyReq handler runs,
// as it may influence the behavior of subsequent handlers.
// Image generation requests can't be streamed.
req.isStreaming = req.body.stream === true || req.body.stream === "true";
// TODO: this flag is set in too many places
req.isStreaming =
req.isStreaming || req.body.stream === true || req.body.stream === "true";
req.body.stream = req.isStreaming;
try {
@@ -31,10 +31,6 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
case "anthropic":
assignedKey = keyPool.get("claude-v1");
break;
case "google-palm":
assignedKey = keyPool.get("text-bison-001");
delete req.body.stream;
break;
case "openai-text":
assignedKey = keyPool.get("gpt-3.5-turbo-instruct");
break;
@@ -42,6 +38,8 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
throw new Error(
"OpenAI Chat as an API translation target is not supported"
);
case "google-ai":
throw new Error("add-key should not be used for this model.");
case "openai-image":
assignedKey = keyPool.get("dall-e-3");
break;
@@ -73,21 +71,13 @@ export const addKey: HPMRequestCallback = (proxyReq, req) => {
}
proxyReq.setHeader("Authorization", `Bearer ${assignedKey.key}`);
break;
case "google-palm":
const originalPath = proxyReq.path;
proxyReq.path = originalPath.replace(
/(\?.*)?$/,
`?key=${assignedKey.key}`
);
break;
case "azure":
const azureKey = assignedKey.key;
proxyReq.setHeader("api-key", azureKey);
break;
case "aws":
throw new Error(
"add-key should not be used for AWS security credentials. Use sign-aws-request instead."
);
case "google-ai":
throw new Error("add-key should not be used for this service.");
default:
assertNever(assignedKey.service);
}
@@ -1,9 +1,9 @@
import type { HPMRequestCallback } from "../index";
/**
* For AWS/Azure requests, the body is signed earlier in the request pipeline,
* before the proxy middleware. This function just assigns the path and headers
* to the proxy request.
* For AWS/Azure/Google requests, the body is signed earlier in the request
* pipeline, before the proxy middleware. This function just assigns the path
* and headers to the proxy request.
*/
export const finalizeSignedRequest: HPMRequestCallback = (proxyReq, req) => {
if (!req.signedRequest) {
@@ -0,0 +1,40 @@
import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";
export const addGoogleAIKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
const serviceValid = req.service === "google-ai";
if (!apisValid || !serviceValid) {
throw new Error("addGoogleAIKey called on invalid request");
}
if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}
const model = req.body.model;
req.key = keyPool.get(model);
req.log.info(
{ key: req.key.hash, model },
"Assigned Google AI API key to request"
);
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}
req.isStreaming = req.isStreaming || req.body.stream;
delete req.body.stream;
req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: "generativelanguage.googleapis.com",
path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
headers: {
["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json",
},
body: JSON.stringify(req.body),
};
};
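For reference, a minimal sketch of how the path in `req.signedRequest` is derived (the key value is a placeholder):
```ts
// Placeholder values; mirrors the template string above.
const model = "gemini-pro";
const isStreaming = true;
const apiKey = "AIza-example"; // not a real key
const verb = isStreaming ? "streamGenerateContent" : "generateContent";
const path = `/v1beta/models/${model}:${verb}?key=${apiKey}`;
// => "/v1beta/models/gemini-pro:streamGenerateContent?key=AIza-example"
```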
@@ -1,7 +1,7 @@
import { RequestPreprocessor } from "../index";
import { countTokens } from "../../../../shared/tokenization";
import { assertNever } from "../../../../shared/utils";
import type { OpenAIChatMessage } from "./transform-outbound-payload";
import type { GoogleAIChatMessage, OpenAIChatMessage } from "./transform-outbound-payload";
/**
* Given a request with an already-transformed body, counts the number of
@@ -30,9 +30,9 @@ export const countPromptTokens: RequestPreprocessor = async (req) => {
result = await countTokens({ req, prompt, service });
break;
}
case "google-palm": {
req.outputTokens = req.body.maxOutputTokens;
const prompt: string = req.body.prompt.text;
case "google-ai": {
req.outputTokens = req.body.generationConfig.maxOutputTokens;
const prompt: GoogleAIChatMessage[] = req.body.contents;
result = await countTokens({ req, prompt, service });
break;
}
@@ -68,7 +68,7 @@ function getPromptFromRequest(req: Request) {
case "openai-text":
case "openai-image":
return body.prompt;
case "google-palm":
case "google-ai":
// google-ai bodies carry `contents`, not `prompt`
return body.contents
  .map((c: { parts: { text: string }[] }) => c.parts.map((p) => p.text).join("\n"))
  .join("\n");
default:
assertNever(service);
@@ -1,7 +1,10 @@
import { Request } from "express";
import { z } from "zod";
import { config } from "../../../../config";
import { isTextGenerationRequest, isImageGenerationRequest } from "../../common";
import {
isTextGenerationRequest,
isImageGenerationRequest,
} from "../../common";
import { RequestPreprocessor } from "../index";
import { APIFormat } from "../../../../shared/key-management";
@@ -121,10 +124,19 @@ const OpenAIV1ImagesGenerationSchema = z.object({
user: z.string().optional(),
});
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateText
const PalmV1GenerateTextSchema = z.object({
model: z.string(),
prompt: z.object({ text: z.string() }),
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
const GoogleAIV1GenerateContentSchema = z.object({
model: z.string(), // actually specified in the path, but we need it for the router
stream: z.boolean().optional().default(false), // also used for router
contents: z.array(
z.object({
parts: z.array(z.object({ text: z.string() })),
role: z.enum(["user", "model"]),
})
),
tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce
.number()
@@ -135,16 +147,20 @@ const PalmV1GenerateTextSchema = z.object({
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
stopSequences: z.array(z.string()).max(5).optional(),
}),
});
export type GoogleAIChatMessage = z.infer<
typeof GoogleAIV1GenerateContentSchema
>["contents"][0];
const VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
anthropic: AnthropicV1CompleteSchema,
openai: OpenAIV1ChatCompletionSchema,
"openai-text": OpenAIV1TextCompletionSchema,
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-palm": PalmV1GenerateTextSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
};
/** Transforms an incoming request body to one that matches the target API. */
@@ -174,8 +190,8 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
return;
}
if (req.inboundApi === "openai" && req.outboundApi === "google-palm") {
req.body = openaiToPalm(req);
if (req.inboundApi === "openai" && req.outboundApi === "google-ai") {
req.body = openaiToGoogleAI(req);
return;
}
@@ -310,7 +326,9 @@ function openaiToOpenaiImage(req: Request) {
return OpenAIV1ImagesGenerationSchema.parse(transformed);
}
function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
function openaiToGoogleAI(
req: Request
): z.infer<typeof GoogleAIV1GenerateContentSchema> {
const { body } = req;
const result = OpenAIV1ChatCompletionSchema.safeParse({
...body,
@@ -319,13 +337,16 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
if (!result.success) {
req.log.warn(
{ issues: result.error.issues, body },
"Invalid OpenAI-to-Palm request"
"Invalid OpenAI-to-Google AI request"
);
throw result.error;
}
const { messages, ...rest } = result.data;
const prompt = flattenOpenAIChatMessages(messages);
const contents = messages.map((m) => ({
parts: [{ text: flattenOpenAIMessageContent(m.content) }],
role: m.role === "user" ? "user" as const : "model" as const,
}));
let stops = rest.stop
? Array.isArray(rest.stop)
@@ -339,20 +360,22 @@ function openaiToPalm(req: Request): z.infer<typeof PalmV1GenerateTextSchema> {
z.array(z.string()).max(5).parse(stops);
return {
prompt: { text: prompt },
model: "gemini-pro",
stream: rest.stream,
contents,
tools: [],
generationConfig: {
maxOutputTokens: rest.max_tokens,
stopSequences: stops,
model: "text-bison-001",
topP: rest.top_p,
topK: 40, // openai schema doesn't have this, google ai defaults to 40
temperature: rest.temperature,
},
safetySettings: [
{ category: "HARM_CATEGORY_UNSPECIFIED", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DEROGATORY", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_TOXICITY", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_VIOLENCE", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_SEXUAL", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_MEDICAL", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DANGEROUS", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_HARASSMENT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_HATE_SPEECH", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold: "BLOCK_NONE" },
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", threshold: "BLOCK_NONE" },
],
};
}
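Worked example of the message mapping above: every non-`user` OpenAI role (including `system`) collapses to Gemini's `model` role.
```ts
// Input OpenAI messages (illustrative):
// [
//   { role: "system", content: "Be terse." },
//   { role: "user", content: "Hi" },
//   { role: "assistant", content: "Hello." },
// ]
// Resulting `contents`:
// [
//   { role: "model", parts: [{ text: "Be terse." }] },
//   { role: "user", parts: [{ text: "Hi" }] },
//   { role: "model", parts: [{ text: "Hello." }] },
// ]
```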
@@ -428,3 +451,4 @@ function flattenOpenAIMessageContent(
.join("\n")
: content;
}
@@ -6,7 +6,7 @@ import { RequestPreprocessor } from "../index";
const CLAUDE_MAX_CONTEXT = config.maxContextTokensAnthropic;
const OPENAI_MAX_CONTEXT = config.maxContextTokensOpenAI;
const BISON_MAX_CONTEXT = 8100;
const GOOGLE_AI_MAX_CONTEXT = 32000;
/**
* Assigns `req.promptTokens` and `req.outputTokens` based on the request body
@@ -31,8 +31,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
case "anthropic":
proxyMax = CLAUDE_MAX_CONTEXT;
break;
case "google-palm":
proxyMax = BISON_MAX_CONTEXT;
case "google-ai":
proxyMax = GOOGLE_AI_MAX_CONTEXT;
break;
case "openai-image":
return;
@@ -62,8 +62,8 @@ export const validateContextSize: RequestPreprocessor = async (req) => {
modelMax = 100000;
} else if (model.match(/^claude-2/)) {
modelMax = 200000;
} else if (model.match(/^text-bison-\d{3}$/)) {
modelMax = BISON_MAX_CONTEXT;
} else if (model.match(/^gemini-/)) {
modelMax = GOOGLE_AI_MAX_CONTEXT;
} else if (model.match(/^anthropic\.claude/)) {
// Not sure if AWS Claude has the same context limit as Anthropic Claude.
modelMax = 100000;
@@ -1,4 +1,3 @@
import express from "express";
import { pipeline } from "stream";
import { promisify } from "util";
import {
@@ -59,7 +58,7 @@ export const handleStreamedResponse: RawResponseBodyHandler = async (
const prefersNativeEvents = req.inboundApi === req.outboundApi;
const contentType = proxyRes.headers["content-type"];
const adapter = new SSEStreamAdapter({ contentType });
const adapter = new SSEStreamAdapter({ contentType, api: req.outboundApi });
const aggregator = new EventAggregator({ format: req.outboundApi });
const transformer = new SSEMessageTransformer({
inputFormat: req.outboundApi,
+22 -5
View File
@@ -288,7 +288,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
// For Anthropic, this is usually due to missing preamble.
switch (service) {
case "openai":
case "google-palm":
case "google-ai":
case "azure":
const filteredCodes = ["content_policy_violation", "content_filter"];
if (filteredCodes.includes(errorPayload.error?.code)) {
@@ -350,8 +350,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "azure":
handleAzureRateLimitError(req, errorPayload);
break;
case "google-palm":
errorPayload.proxy_note = `Automatic rate limit retries are not supported for this service. Try again in a few seconds.`;
case "google-ai":
handleGoogleAIRateLimitError(req, errorPayload);
break;
default:
assertNever(service);
@@ -373,8 +373,8 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async (
case "anthropic":
errorPayload.proxy_note = `The requested Claude model might not exist, or the key might not be provisioned for it.`;
break;
case "google-palm":
errorPayload.proxy_note = `The requested Google PaLM model might not exist, or the key might not be provisioned for it.`;
case "google-ai":
errorPayload.proxy_note = `The requested Google AI model might not exist, or the key might not be provisioned for it.`;
break;
case "aws":
errorPayload.proxy_note = `The requested AWS resource might not exist, or the key might not have access to it.`;
@@ -529,6 +529,23 @@ function handleAzureRateLimitError(
}
}
//{"error":{"code":429,"message":"Resource has been exhausted (e.g. check quota).","status":"RESOURCE_EXHAUSTED"}
function handleGoogleAIRateLimitError(
req: Request,
errorPayload: ProxiedErrorPayload
) {
const status = errorPayload.error?.status;
switch (status) {
case "RESOURCE_EXHAUSTED":
keyPool.markRateLimited(req.key!);
reenqueueRequest(req);
throw new RetryableError("Rate-limited request re-enqueued.");
default:
errorPayload.proxy_note = `Unrecognized rate limit error from Google AI (${status}). Please report this.`;
break;
}
}
const incrementUsage: ProxyResHandlerWithBody = async (_proxyRes, req) => {
if (isTextGenerationRequest(req) || isImageGenerationRequest(req)) {
const model = req.body.model;
+1 -1
View File
@@ -73,7 +73,7 @@ const getPromptForRequest = (
};
case "anthropic":
return req.body.prompt;
case "google-palm":
case "google-ai":
// google-ai bodies carry `contents`, not `prompt`
return req.body.contents
  .map((c: { parts: { text: string }[] }) => c.parts.map((p) => p.text).join("\n"))
  .join("\n");
default:
assertNever(req.outboundApi);
@@ -27,12 +27,12 @@ export class EventAggregator {
getFinalResponse() {
switch (this.format) {
case "openai":
case "google-ai":
return mergeEventsForOpenAIChat(this.events);
case "openai-text":
return mergeEventsForOpenAIText(this.events);
case "anthropic":
return mergeEventsForAnthropic(this.events);
case "google-palm":
case "openai-image":
throw new Error(`SSE aggregation not supported for ${this.format}`);
default:
@@ -25,6 +25,8 @@ export type StreamingCompletionTransformer = (
export { openAITextToOpenAIChat } from "./transformers/openai-text-to-openai";
export { anthropicV1ToOpenAI } from "./transformers/anthropic-v1-to-openai";
export { anthropicV2ToOpenAI } from "./transformers/anthropic-v2-to-openai";
export { googleAIToOpenAI } from "./transformers/google-ai-to-openai";
export { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
export { mergeEventsForOpenAIChat } from "./aggregators/openai-chat";
export { mergeEventsForOpenAIText } from "./aggregators/openai-text";
export { mergeEventsForAnthropic } from "./aggregators/anthropic";
@@ -7,9 +7,10 @@ import {
anthropicV2ToOpenAI,
OpenAIChatCompletionStreamEvent,
openAITextToOpenAIChat,
googleAIToOpenAI,
passthroughToOpenAI,
StreamingCompletionTransformer,
} from "./index";
import { passthroughToOpenAI } from "./transformers/passthrough-to-openai";
const genlog = logger.child({ module: "sse-transformer" });
@@ -111,7 +112,8 @@ function getTransformer(
return version === "2023-01-01"
? anthropicV1ToOpenAI
: anthropicV2ToOpenAI;
case "google-palm":
case "google-ai":
return googleAIToOpenAI;
case "openai-image":
throw new Error(`SSE transformation not supported for ${responseApi}`);
default:
@@ -1,13 +1,19 @@
import { Transform, TransformOptions } from "stream";
import { StringDecoder } from "string_decoder";
// @ts-ignore
import { Parser } from "lifion-aws-event-stream";
import { logger } from "../../../../logger";
import { RetryableError } from "../index";
import { APIFormat } from "../../../../shared/key-management";
import StreamArray from "stream-json/streamers/StreamArray";
const log = logger.child({ module: "sse-stream-adapter" });
type SSEStreamAdapterOptions = TransformOptions & { contentType?: string };
type SSEStreamAdapterOptions = TransformOptions & {
contentType?: string;
api: APIFormat;
};
type AwsEventStreamMessage = {
headers: {
":message-type": "event" | "exception";
@@ -22,7 +28,9 @@ type AwsEventStreamMessage = {
*/
export class SSEStreamAdapter extends Transform {
private readonly isAwsStream;
private readonly isGoogleStream;
private parser = new Parser();
private jsonStream = StreamArray.withParser();
private partialMessage = "";
private decoder = new StringDecoder("utf8");
@@ -30,6 +38,7 @@ export class SSEStreamAdapter extends Transform {
super(options);
this.isAwsStream =
options?.contentType === "application/vnd.amazon.eventstream";
this.isGoogleStream = options?.api === "google-ai";
this.parser.on("data", (data: AwsEventStreamMessage) => {
const message = this.processAwsEvent(data);
@@ -37,6 +46,12 @@ export class SSEStreamAdapter extends Transform {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
});
this.jsonStream.on("data", (data: { value: any }) => {
const message = this.processGoogleValue(data.value);
if (message) {
this.push(Buffer.from(message + "\n\n"), "utf8");
}
});
}
protected processAwsEvent(event: AwsEventStreamMessage): string | null {
@@ -73,17 +88,38 @@ export class SSEStreamAdapter extends Transform {
}
}
// Google doesn't use event streams and just sends elements in an array over
// a long-lived HTTP connection. Needs stream-json to parse the array.
protected processGoogleValue(value: any): string | null {
if ("candidates" in value) {
return `data: ${JSON.stringify(value)}`;
} else {
log.error(
{ value },
"Received unexpected Google AI event stream message"
);
return getFakeErrorCompletion(
"proxy Google AI error",
JSON.stringify(value)
);
}
}
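Self-contained sketch of the stream-json behavior relied on here: `StreamArray.withParser()` accepts arbitrary chunk boundaries and emits one `data` event per completed array element.
```ts
import StreamArray from "stream-json/streamers/StreamArray";

const jsonStream = StreamArray.withParser();
jsonStream.on("data", ({ value }: { value: any }) => {
  console.log("element:", value); // fires once per completed element
});
// Chunks may split elements anywhere; the parser buffers internally.
jsonStream.write('[{"candidates":[{"x":1}]},');
jsonStream.write('{"candidates":[{"x":2}]}]');
jsonStream.end();
```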
_transform(chunk: Buffer, _encoding: BufferEncoding, callback: Function) {
try {
if (this.isAwsStream) {
this.parser.write(chunk);
} else if (this.isGoogleStream) {
this.jsonStream.write(chunk);
} else {
// We may receive multiple (or partial) SSE messages in a single chunk,
// so we need to buffer and emit separate stream events for full
// messages so we can parse/transform them properly.
const str = this.decoder.write(chunk);
const fullMessages = (this.partialMessage + str).split(/\r\r|\n\n|\r\n\r\n/);
const fullMessages = (this.partialMessage + str).split(
/\r\r|\n\n|\r\n\r\n/
);
this.partialMessage = fullMessages.pop() || "";
for (const message of fullMessages) {
@@ -0,0 +1,69 @@
import { StreamingCompletionTransformer } from "../index";
import { parseEvent, ServerSentEvent } from "../parse-sse";
import { logger } from "../../../../../logger";
const log = logger.child({
module: "sse-transformer",
transformer: "google-ai-to-openai",
});
type GoogleAIStreamEvent = {
candidates: {
content: { parts: { text: string }[]; role: string };
finishReason?: "STOP" | "MAX_TOKENS" | "SAFETY" | "RECITATION" | "OTHER";
index: number;
tokenCount?: number;
safetyRatings: { category: string; probability: string }[];
}[];
};
/**
* Transforms an incoming Google AI SSE to an equivalent OpenAI
* chat.completion.chunk SSE.
*/
export const googleAIToOpenAI: StreamingCompletionTransformer = (params) => {
const { data } = params;
const rawEvent = parseEvent(data);
if (!rawEvent.data || rawEvent.data === "[DONE]") {
return { position: -1 };
}
const completionEvent = asCompletion(rawEvent);
if (!completionEvent) {
return { position: -1 };
}
const parts = completionEvent.candidates[0].content.parts;
const text = parts[0]?.text ?? "";
const newEvent = {
id: "goo-" + params.fallbackId,
object: "chat.completion.chunk" as const,
created: Date.now(),
model: params.fallbackModel,
choices: [
{
index: 0,
delta: { content: text },
finish_reason: completionEvent.candidates[0].finishReason ?? null,
},
],
};
return { position: -1, event: newEvent };
};
function asCompletion(event: ServerSentEvent): GoogleAIStreamEvent | null {
try {
const parsed = JSON.parse(event.data) as GoogleAIStreamEvent;
if (parsed.candidates?.length > 0) {
return parsed;
} else {
// noinspection ExceptionCaughtLocallyJS
throw new Error("Missing required fields");
}
} catch (error) {
log.warn({ error: error.stack, event }, "Received invalid event");
}
return null;
}
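Sketch of one framed Gemini element passing through the transformer; parameter names other than `data`, `fallbackId`, and `fallbackModel` are assumptions about `StreamingCompletionTransformer`'s params.
```ts
const out = googleAIToOpenAI({
  data: 'data: {"candidates":[{"content":{"parts":[{"text":"Hi"}],"role":"model"},"index":0,"safetyRatings":[]}]}',
  fallbackId: "req-123",
  fallbackModel: "gemini-pro",
} as any); // cast because the full params type isn't shown here
// out.event => {
//   id: "goo-req-123",
//   object: "chat.completion.chunk",
//   model: "gemini-pro",
//   choices: [{ index: 0, delta: { content: "Hi" }, finish_reason: null }],
// }
```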
-1
View File
@@ -17,7 +17,6 @@ import {
} from "./middleware/response";
import { generateModelList } from "./openai";
import {
mirrorGeneratedImage,
OpenAIImageGenerationResult,
} from "../shared/file-storage/mirror-generated-image";
-170
View File
@@ -1,170 +0,0 @@
import { Request, RequestHandler, Router } from "express";
import * as http from "http";
import { createProxyMiddleware } from "http-proxy-middleware";
import { v4 } from "uuid";
import { config } from "../config";
import { logger } from "../logger";
import { createQueueMiddleware } from "./queue";
import { ipLimiter } from "./rate-limit";
import { handleProxyError } from "./middleware/common";
import {
addKey,
createOnProxyReqHandler,
createPreprocessorMiddleware,
finalizeBody,
forceModel,
} from "./middleware/request";
import {
createOnProxyResHandler,
ProxyResHandlerWithBody,
} from "./middleware/response";
let modelsCache: any = null;
let modelsCacheTime = 0;
const getModelsResponse = () => {
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
return modelsCache;
}
if (!config.googlePalmKey) return { object: "list", data: [] };
const bisonVariants = ["text-bison-001"];
const models = bisonVariants.map((id) => ({
id,
object: "model",
created: new Date().getTime(),
owned_by: "google",
permission: [],
root: "palm",
parent: null,
}));
modelsCache = { object: "list", data: models };
modelsCacheTime = new Date().getTime();
return modelsCache;
};
const handleModelRequest: RequestHandler = (_req, res) => {
res.status(200).json(getModelsResponse());
};
/** Only used for non-streaming requests. */
const palmResponseHandler: ProxyResHandlerWithBody = async (
_proxyRes,
req,
res,
body
) => {
if (typeof body !== "object") {
throw new Error("Expected body to be an object");
}
if (config.promptLogging) {
const host = req.get("host");
body.proxy_note = `Prompts are logged on this proxy instance. See ${host} for more information.`;
}
if (req.inboundApi === "openai") {
req.log.info("Transforming Google PaLM response to OpenAI format");
body = transformPalmResponse(body, req);
}
if (req.tokenizerInfo) {
body.proxy_tokenizer = req.tokenizerInfo;
}
// TODO: PaLM has no streaming capability which will pose a problem here if
// requests wait in the queue for too long. Probably need to fake streaming
// and return the entire completion in one stream event using the other
// response handler.
res.status(200).json(body);
};
/**
* Transforms a model response from the Anthropic API to match those from the
* OpenAI API, for users using Claude via the OpenAI-compatible endpoint. This
* is only used for non-streaming requests as streaming requests are handled
* on-the-fly.
*/
function transformPalmResponse(
palmRespBody: Record<string, any>,
req: Request
): Record<string, any> {
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0);
return {
id: "plm-" + v4(),
object: "chat.completion",
created: Date.now(),
model: req.body.model,
usage: {
prompt_tokens: req.promptTokens,
completion_tokens: req.outputTokens,
total_tokens: totalTokens,
},
choices: [
{
message: {
role: "assistant",
content: palmRespBody.candidates[0].output,
},
finish_reason: null, // palm doesn't return this
index: 0,
},
],
};
}
function reassignPathForPalmModel(proxyReq: http.ClientRequest, req: Request) {
if (req.body.stream) {
throw new Error("Google PaLM API doesn't support streaming requests");
}
// PaLM API specifies the model in the URL path, not the request body. This
// doesn't work well with our rewriter architecture, so we need to manually
// fix it here.
// POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateText
// POST https://generativelanguage.googleapis.com/v1beta2/{model=models/*}:generateMessage
// The chat api (generateMessage) is not very useful at this time as it has
// few params and no adjustable safety settings.
proxyReq.path = proxyReq.path.replace(
/^\/v1\/chat\/completions/,
`/v1beta2/models/${req.body.model}:generateText`
);
}
const googlePalmProxy = createQueueMiddleware({
proxyMiddleware: createProxyMiddleware({
target: "https://generativelanguage.googleapis.com",
changeOrigin: true,
selfHandleResponse: true,
logger,
on: {
proxyReq: createOnProxyReqHandler({
pipeline: [reassignPathForPalmModel, addKey, finalizeBody],
}),
proxyRes: createOnProxyResHandler([palmResponseHandler]),
error: handleProxyError,
},
}),
});
const palmRouter = Router();
palmRouter.get("/v1/models", handleModelRequest);
// OpenAI-to-Google PaLM compatibility endpoint.
palmRouter.post(
"/v1/chat/completions",
ipLimiter,
createPreprocessorMiddleware(
{ inApi: "openai", outApi: "google-palm", service: "google-palm" },
{ afterTransform: [forceModel("text-bison-001")] }
),
googlePalmProxy
);
export const googlePalm = palmRouter;
+2 -2
View File
@@ -4,7 +4,7 @@ import { checkRisuToken } from "./check-risu-token";
import { openai } from "./openai";
import { openaiImage } from "./openai-image";
import { anthropic } from "./anthropic";
import { googlePalm } from "./palm";
import { googleAI } from "./google-ai";
import { aws } from "./aws";
import { azure } from "./azure";
@@ -31,7 +31,7 @@ proxyRouter.use((req, _res, next) => {
proxyRouter.use("/openai", addV1, openai);
proxyRouter.use("/openai-image", addV1, openaiImage);
proxyRouter.use("/anthropic", addV1, anthropic);
proxyRouter.use("/google-palm", addV1, googlePalm);
proxyRouter.use("/google-ai", addV1, googleAI);
proxyRouter.use("/aws/claude", addV1, aws);
proxyRouter.use("/azure/openai", addV1, azure);
// Redirect browser requests to the homepage.
@@ -62,7 +62,6 @@ export class AnthropicKeyChecker extends KeyCheckerBase<AnthropicKey> {
{ key: key.hash, error: error.message },
"Key is rate limited. Rechecking in 10 seconds."
);
0;
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
this.updateKey(key.hash, { lastChecked: next });
break;
@@ -6,7 +6,6 @@ import type { AzureOpenAIModelFamily } from "../../models";
import { getAzureOpenAIModelFamily } from "../../models";
import { OpenAIModel } from "../openai/provider";
import { AzureOpenAIKeyChecker } from "./checker";
import { AwsKeyChecker } from "../aws/checker";
export type AzureOpenAIModel = Exclude<OpenAIModel, "dall-e">;
@@ -2,13 +2,17 @@ import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { GooglePalmModelFamily } from "../../models";
import type { GoogleAIModelFamily } from "../../models";
// https://developers.generativeai.google.com/models/language
export type GooglePalmModel = "text-bison-001";
// Note that Google AI is not the same as Vertex AI; both are provided by
// Google, but Vertex is the GCP product for enterprise, while Google AI is the
// consumer-ish product. The APIs are different, and keys are not compatible.
// https://ai.google.dev/docs/migrate_to_cloud
export type GooglePalmKeyUpdate = Omit<
Partial<GooglePalmKey>,
export type GoogleAIModel = "gemini-pro";
export type GoogleAIKeyUpdate = Omit<
Partial<GoogleAIKey>,
| "key"
| "hash"
| "lastUsed"
@@ -17,13 +21,13 @@ export type GooglePalmKeyUpdate = Omit<
| "rateLimitedUntil"
>;
type GooglePalmKeyUsage = {
[K in GooglePalmModelFamily as `${K}Tokens`]: number;
type GoogleAIKeyUsage = {
[K in GoogleAIModelFamily as `${K}Tokens`]: number;
};
export interface GooglePalmKey extends Key, GooglePalmKeyUsage {
readonly service: "google-palm";
readonly modelFamilies: GooglePalmModelFamily[];
export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
readonly service: "google-ai";
readonly modelFamilies: GoogleAIModelFamily[];
/** The time at which this key was last rate limited. */
rateLimitedAt: number;
/** The time until which this key is rate limited. */
@@ -42,27 +46,27 @@ const RATE_LIMIT_LOCKOUT = 2000;
*/
const KEY_REUSE_DELAY = 500;
export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
readonly service = "google-palm";
export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
readonly service = "google-ai";
private keys: GooglePalmKey[] = [];
private keys: GoogleAIKey[] = [];
private log = logger.child({ module: "key-provider", service: this.service });
constructor() {
const keyConfig = config.googlePalmKey?.trim();
const keyConfig = config.googleAIKey?.trim();
if (!keyConfig) {
this.log.warn(
"GOOGLE_PALM_KEY is not set. PaLM API will not be available."
"GOOGLE_AI_KEY is not set. Google AI API will not be available."
);
return;
}
let bareKeys: string[];
bareKeys = [...new Set(keyConfig.split(",").map((k) => k.trim()))];
for (const key of bareKeys) {
const newKey: GooglePalmKey = {
const newKey: GoogleAIKey = {
key,
service: this.service,
modelFamilies: ["bison"],
modelFamilies: ["gemini-pro"],
isDisabled: false,
isRevoked: false,
promptCount: 0,
@@ -75,11 +79,11 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
bisonTokens: 0,
"gemini-proTokens": 0,
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded PaLM keys.");
this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
}
public init() {}
@@ -88,10 +92,10 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));
}
public get(_model: GooglePalmModel) {
public get(_model: GoogleAIModel) {
const availableKeys = this.keys.filter((k) => !k.isDisabled);
if (availableKeys.length === 0) {
throw new Error("No Google PaLM keys available");
throw new Error("No Google AI keys available");
}
// (largely copied from the OpenAI provider, without trial key support)
@@ -122,14 +126,14 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
return { ...selectedKey };
}
public disable(key: GooglePalmKey) {
public disable(key: GoogleAIKey) {
const keyFromPool = this.keys.find((k) => k.hash === key.hash);
if (!keyFromPool || keyFromPool.isDisabled) return;
keyFromPool.isDisabled = true;
this.log.warn({ key: key.hash }, "Key disabled");
}
public update(hash: string, update: Partial<GooglePalmKey>) {
public update(hash: string, update: Partial<GoogleAIKey>) {
const keyFromPool = this.keys.find((k) => k.hash === hash)!;
Object.assign(keyFromPool, { lastChecked: Date.now(), ...update });
}
@@ -142,7 +146,7 @@ export class GooglePalmKeyProvider implements KeyProvider<GooglePalmKey> {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key.bisonTokens += tokens;
key["gemini-proTokens"] += tokens;
}
public getLockoutPeriod() {
+5 -5
View File
@@ -1,6 +1,6 @@
import { OpenAIModel } from "./openai/provider";
import { AnthropicModel } from "./anthropic/provider";
import { GooglePalmModel } from "./palm/provider";
import { GoogleAIModel } from "./google-ai/provider";
import { AwsBedrockModel } from "./aws/provider";
import { AzureOpenAIModel } from "./azure/provider";
import { KeyPool } from "./key-pool";
@@ -10,20 +10,20 @@ import type { ModelFamily } from "../models";
export type APIFormat =
| "openai"
| "anthropic"
| "google-palm"
| "google-ai"
| "openai-text"
| "openai-image";
/** The service that a model is hosted on; distinct because services like AWS provide multiple APIs, but have their own endpoints and authentication. */
export type LLMService =
| "openai"
| "anthropic"
| "google-palm"
| "google-ai"
| "aws"
| "azure";
export type Model =
| OpenAIModel
| AnthropicModel
| GooglePalmModel
| GoogleAIModel
| AwsBedrockModel
| AzureOpenAIModel;
@@ -77,6 +77,6 @@ export interface KeyProvider<T extends Key = Key> {
export const keyPool = new KeyPool();
export { AnthropicKey } from "./anthropic/provider";
export { OpenAIKey } from "./openai/provider";
export { GooglePalmKey } from "./palm/provider";
export { GoogleAIKey } from "./google-ai/provider";
export { AwsBedrockKey } from "./aws/provider";
export { AzureOpenAIKey } from "./azure/provider";
+6 -6
View File
@@ -7,7 +7,7 @@ import { logger } from "../../logger";
import { Key, Model, KeyProvider, LLMService } from "./index";
import { AnthropicKeyProvider, AnthropicKeyUpdate } from "./anthropic/provider";
import { OpenAIKeyProvider, OpenAIKeyUpdate } from "./openai/provider";
import { GooglePalmKeyProvider } from "./palm/provider";
import { GoogleAIKeyProvider } from "./google-ai/provider";
import { AwsBedrockKeyProvider } from "./aws/provider";
import { ModelFamily } from "../models";
import { assertNever } from "../utils";
@@ -24,7 +24,7 @@ export class KeyPool {
constructor() {
this.keyProviders.push(new OpenAIKeyProvider());
this.keyProviders.push(new AnthropicKeyProvider());
this.keyProviders.push(new GooglePalmKeyProvider());
this.keyProviders.push(new GoogleAIKeyProvider());
this.keyProviders.push(new AwsBedrockKeyProvider());
this.keyProviders.push(new AzureOpenAIKeyProvider());
}
@@ -119,9 +119,9 @@ export class KeyPool {
} else if (model.startsWith("claude-")) {
// https://console.anthropic.com/docs/api/reference#parameters
return "anthropic";
} else if (model.includes("bison")) {
} else if (model.includes("gemini")) {
// https://developers.generativeai.google.com/models/language
return "google-palm";
return "google-ai";
} else if (model.startsWith("anthropic.claude")) {
// AWS offers models from a few providers
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
@@ -142,8 +142,8 @@ export class KeyPool {
return "openai";
case "claude":
return "anthropic";
case "bison":
return "google-palm";
case "gemini-pro":
return "google-ai";
case "aws-claude":
return "aws";
case "azure-turbo":
+8 -10
View File
@@ -11,7 +11,7 @@ export type OpenAIModelFamily =
| "gpt4-turbo"
| "dall-e";
export type AnthropicModelFamily = "claude";
export type GooglePalmModelFamily = "bison";
export type GoogleAIModelFamily = "gemini-pro";
export type AwsBedrockModelFamily = "aws-claude";
export type AzureOpenAIModelFamily = `azure-${Exclude<
OpenAIModelFamily,
@@ -20,7 +20,7 @@ export type AzureOpenAIModelFamily = `azure-${Exclude<
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GooglePalmModelFamily
| GoogleAIModelFamily
| AwsBedrockModelFamily
| AzureOpenAIModelFamily;
@@ -33,7 +33,7 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"gpt4-turbo",
"dall-e",
"claude",
"bison",
"gemini-pro",
"aws-claude",
"azure-turbo",
"azure-gpt4",
@@ -53,7 +53,7 @@ export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
"^dall-e-\\d{1}$": "dall-e",
};
const modelLogger = pino({ level: "debug" }).child({ module: "startup" });
pino({ level: "debug" }).child({ module: "startup" });
export function getOpenAIModelFamily(
model: string,
@@ -70,10 +70,8 @@ export function getClaudeModelFamily(model: string): ModelFamily {
return "claude";
}
export function getGooglePalmModelFamily(model: string): ModelFamily {
if (model.match(/^\w+-bison-\d{3}$/)) return "bison";
modelLogger.warn({ model }, "Could not determine Google PaLM model family");
return "bison";
export function getGoogleAIModelFamily(_model: string): ModelFamily {
return "gemini-pro";
}
export function getAwsBedrockModelFamily(_model: string): ModelFamily {
@@ -130,8 +128,8 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
case "openai-image":
modelFamily = getOpenAIModelFamily(model);
break;
case "google-palm":
modelFamily = getGooglePalmModelFamily(model);
case "google-ai":
modelFamily = getGoogleAIModelFamily(model);
break;
default:
assertNever(req.outboundApi);
+1 -1
View File
@@ -74,7 +74,7 @@ export function buildFakeSse(type: string, string: string, req: Request) {
log_id: "proxy-req-" + req.id,
};
break;
case "google-palm":
case "google-ai":
case "openai-image":
throw new Error(`SSE not supported for ${req.inboundApi} requests`);
default:
+29 -4
View File
@@ -2,7 +2,11 @@ import { Tiktoken } from "tiktoken/lite";
import cl100k_base from "tiktoken/encoders/cl100k_base.json";
import { logger } from "../../logger";
import { libSharp } from "../file-storage";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import type {
GoogleAIChatMessage,
OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { z } from "zod";
const log = logger.child({ module: "tokenizer", service: "openai" });
const GPT4_VISION_SYSTEM_PROMPT_SIZE = 170;
@@ -29,11 +33,11 @@ export async function getTokenCount(
return getTextTokenCount(prompt);
}
const gpt4 = model.startsWith("gpt-4");
const oldFormatting = model.includes("turbo-0301");
const vision = model.includes("vision");
const tokensPerMessage = gpt4 ? 3 : 4;
const tokensPerName = gpt4 ? 1 : -1; // turbo omits role if name is present
const tokensPerMessage = oldFormatting ? 4 : 3;
const tokensPerName = oldFormatting ? -1 : 1; // older formatting replaces role with name if name is present
let numTokens = vision ? GPT4_VISION_SYSTEM_PROMPT_SIZE : 0;
@@ -228,3 +232,24 @@ export function getOpenAIImageCost(params: {
token_count: Math.ceil(tokens),
};
}
export function estimateGoogleAITokenCount(prompt: string | GoogleAIChatMessage[]) {
if (typeof prompt === "string") {
return getTextTokenCount(prompt);
}
const tokensPerMessage = 3;
let numTokens = 0;
for (const message of prompt) {
numTokens += tokensPerMessage;
numTokens += encoder.encode(message.parts[0].text).length;
}
numTokens += 3;
return {
tokenizer: "tiktoken (google-ai estimate)",
token_count: numTokens,
};
}
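Usage sketch for the estimator above; the arithmetic mirrors the OpenAI per-message heuristic.
```ts
const prompt: GoogleAIChatMessage[] = [
  { role: "user", parts: [{ text: "Hello there" }] },
];
const { token_count, tokenizer } = estimateGoogleAITokenCount(prompt);
// token_count = 3 (per-message overhead)
//             + encoder.encode("Hello there").length
//             + 3 (trailing constant)
// tokenizer === "tiktoken (google-ai estimate)"
```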
+10 -5
View File
@@ -1,5 +1,8 @@
import { Request } from "express";
import type { OpenAIChatMessage } from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import type {
GoogleAIChatMessage,
OpenAIChatMessage,
} from "../../proxy/middleware/request/preprocessors/transform-outbound-payload";
import { assertNever } from "../utils";
import {
init as initClaude,
@@ -9,6 +12,7 @@ import {
init as initOpenAi,
getTokenCount as getOpenAITokenCount,
getOpenAIImageCost,
estimateGoogleAITokenCount,
} from "./openai";
import { APIFormat } from "../key-management";
@@ -24,8 +28,9 @@ type TokenCountRequest = { req: Request } & (
| {
prompt: string;
completion?: never;
service: "openai-text" | "anthropic" | "google-palm";
service: "openai-text" | "anthropic" | "google-ai";
}
| { prompt?: GoogleAIChatMessage[]; completion?: never; service: "google-ai" }
| { prompt?: never; completion: string; service: APIFormat }
| { prompt?: never; completion?: never; service: "openai-image" }
);
@@ -65,11 +70,11 @@ export async function countTokens({
}),
tokenization_duration_ms: getElapsedMs(time),
};
case "google-palm":
// TODO: Can't find a tokenization library for PaLM. There is an API
case "google-ai":
// TODO: Can't find a tokenization library for Gemini. There is an API
// endpoint for it but it adds significant latency to the request.
return {
...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
...estimateGoogleAITokenCount(prompt ?? (completion || [])),
tokenization_duration_ms: getElapsedMs(time),
};
default:
+1 -1
View File
@@ -9,7 +9,7 @@ export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object({
"gpt4-turbo": z.number().optional().default(0),
"dall-e": z.number().optional().default(0),
claude: z.number().optional().default(0),
bison: z.number().optional().default(0),
"gemini-pro": z.number().optional().default(0),
"aws-claude": z.number().optional().default(0),
});
+4 -4
View File
@@ -14,7 +14,7 @@ import { config, getFirebaseApp } from "../../config";
import {
getAzureOpenAIModelFamily,
getClaudeModelFamily,
getGooglePalmModelFamily,
getGoogleAIModelFamily,
getOpenAIModelFamily,
MODEL_FAMILIES,
ModelFamily,
@@ -33,7 +33,7 @@ const INITIAL_TOKENS: Required<UserTokenCounts> = {
"gpt4-turbo": 0,
"dall-e": 0,
claude: 0,
bison: 0,
"gemini-pro": 0,
"aws-claude": 0,
"azure-turbo": 0,
"azure-gpt4": 0,
@@ -397,8 +397,8 @@ function getModelFamilyForQuotaUsage(
return getOpenAIModelFamily(model);
case "anthropic":
return getClaudeModelFamily(model);
case "google-palm":
return getGooglePalmModelFamily(model);
case "google-ai":
return getGoogleAIModelFamily(model);
default:
assertNever(api);
}