adds gemini/makersuite keychecker, native endpoint, and streaming fixes
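In rough terms: Google AI (MakerSuite) keys now get a checker that probes the ListModels endpoint, the proxy gains a native /v1beta/models/...:generateContent route alongside the existing OpenAI-compatible one, and SSE streaming for Gemini is patched up. A sketch of how a client might call the new native route through the proxy — the base URL, /proxy prefix, proxy key, and model ID below are placeholders; the prefix comes from config.proxyEndpointRoute:

// Illustrative only; not part of the commit.
const PROXY = "https://my-proxy.example.com/proxy/google-ai";

async function generate(prompt: string) {
  const res = await fetch(
    `${PROXY}/v1beta/models/gemini-1.5-pro-latest:generateContent?key=my-proxy-key`,
    {
      method: "POST",
      headers: { "content-type": "application/json" },
      body: JSON.stringify({
        // used by the proxy's router/schema; stripped before forwarding upstream
        model: "gemini-1.5-pro-latest",
        contents: [{ role: "user", parts: [{ text: prompt }] }],
        generationConfig: { maxOutputTokens: 1024 },
      }),
    }
  );
  return res.json();
}
// Use ...:streamGenerateContent instead to receive an SSE stream.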
@@ -40,11 +40,11 @@ NODE_ENV=production
# Which model types users are allowed to access.
# The following model families are recognized:
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-pro | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
# turbo | gpt4 | gpt4-32k | gpt4-turbo | gpt4o | dall-e | claude | claude-opus | gemini-flash | gemini-pro | gemini-ultra | mistral-tiny | mistral-small | mistral-medium | mistral-large | aws-claude | aws-claude-opus | azure-turbo | azure-gpt4 | azure-gpt4-32k | azure-gpt4-turbo | azure-gpt4o | azure-dall-e
# By default, all models are allowed except for 'dall-e' / 'azure-dall-e'.
# To allow DALL-E image generation, uncomment the line below and add 'dall-e' or
# 'azure-dall-e' to the list of allowed model families.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-flash,gemini-pro,gemini-ultra,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o

# Which services can be used to process prompts containing images via multimodal
# models. The following services are recognized:

@@ -144,6 +144,7 @@ NODE_ENV=production
# For AWS credentials, separate the access key ID, secret key, and region with a colon.
OPENAI_KEY=sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
ANTHROPIC_KEY=sk-ant-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
GOOGLE_AI_KEY=AIzaxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
# See `docs/aws-configuration.md` for more information, there may be additional steps required to set up AWS.
AWS_CREDENTIALS=myaccesskeyid:mysecretkey:us-east-1,anotheraccesskeyid:anothersecretkey:us-west-2
# See `docs/azure-configuration.md` for more information, there may be additional steps required to set up Azure.
@@ -428,7 +428,9 @@ export const config: Config = {
"gpt4o",
"claude",
"claude-opus",
"gemini-flash",
"gemini-pro",
"gemini-ultra",
"mistral-tiny",
"mistral-small",
"mistral-medium",

@@ -20,7 +20,9 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = {
"dall-e": "DALL-E",
claude: "Claude (Sonnet)",
"claude-opus": "Claude (Opus)",
"gemini-flash": "Gemini Flash",
"gemini-pro": "Gemini Pro",
"gemini-ultra": "Gemini Ultra",
"mistral-tiny": "Mistral 7B",
"mistral-small": "Mistral Nemo",
"mistral-medium": "Mistral Medium",
@@ -12,6 +12,7 @@ function getProxyAuthorizationFromRequest(req: Request): string | undefined {
// pass the _proxy_ key in this header too, instead of providing it as a
// Bearer token in the Authorization header. So we need to check both.
// Prefer the Authorization header if both are present.
// Google AI uses a key querystring parameter.

if (req.headers.authorization) {
const token = req.headers.authorization?.slice("Bearer ".length);

@@ -24,6 +25,12 @@ function getProxyAuthorizationFromRequest(req: Request): string | undefined {
delete req.headers["x-api-key"];
return token;
}

if (req.query.key) {
const token = req.query.key?.toString();
delete req.query.key;
return token;
}

return undefined;
}
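With the querystring check added above, a client can hand the proxy its access key in any of three places; a rough sketch with placeholder values — the two header forms already existed, the query-param form (how Google AI clients pass keys) is the new one:

// Illustrative only; "my-proxy-key" is a placeholder proxy access key.
const viaAuthHeader = { authorization: "Bearer my-proxy-key" };
const viaApiKeyHeader = { "x-api-key": "my-proxy-key" };
const viaQueryParam = "/v1beta/models/gemini-pro:generateContent?key=my-proxy-key";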
@@ -16,6 +16,7 @@ import {
ProxyResHandlerWithBody,
} from "./middleware/response";
import { addGoogleAIKey } from "./middleware/request/preprocessors/add-google-ai-key";
import { GoogleAIKey, keyPool } from "../shared/key-management";

let modelsCache: any = null;
let modelsCacheTime = 0;

@@ -30,14 +31,19 @@ const getModelsResponse = () => {
if (!config.googleAIKey) return { object: "list", data: [] };

const googleAIVariants = [
"gemini-pro",
"gemini-1.0-pro",
"gemini-1.5-pro",
"gemini-1.5-pro-latest",
];
const keys = keyPool
.list()
.filter((k) => k.service === "google-ai") as GoogleAIKey[];
if (keys.length === 0) {
modelsCache = { object: "list", data: [] };
modelsCacheTime = new Date().getTime();
return modelsCache;
}

const models = googleAIVariants.map((id) => ({
const modelIds = Array.from(
new Set(keys.map((k) => k.modelIds).flat())
).filter((id) => id.startsWith("models/gemini"));
const models = modelIds.map((id) => ({
id,
object: "model",
created: new Date().getTime(),

@@ -114,7 +120,17 @@ const googleAIProxy = createQueueMiddleware({
},
changeOrigin: true,
selfHandleResponse: true,
logger,
// Prevent logging of the API key by HPM
logger: logger.child(
{},
{
redact: {
paths: ["*"],
censor: (v) =>
typeof v === "string" ? v.replace(/key=\S+/g, "key=xxxxxxx") : v,
},
}
),
on: {
proxyReq: createOnProxyReqHandler({ pipeline: [finalizeSignedRequest] }),
proxyRes: createOnProxyResHandler([googleAIResponseHandler]),

@@ -125,6 +141,22 @@ const googleAIProxy = createQueueMiddleware({
const googleAIRouter = Router();
googleAIRouter.get("/v1/models", handleModelRequest);

// Native Google AI chat completion endpoint
googleAIRouter.post(
"/v1beta/models/:modelId:(generateContent|streamGenerateContent)",
ipLimiter,
createPreprocessorMiddleware(
{
inApi: "google-ai",
outApi: "google-ai",
service: "google-ai",
},
{ afterTransform: [maybeReassignModel, setStreamFlag] }
),
googleAIProxy
);

// OpenAI-to-Google AI compatibility endpoint.
googleAIRouter.post(
"/v1/chat/completions",

@@ -136,12 +168,38 @@ googleAIRouter.post(
googleAIProxy
);

/** Replaces requests for non-Google AI models with gemini-pro-1.5-latest. */
function setStreamFlag(req: Request) {
const isStreaming = req.url.includes("streamGenerateContent");
if (isStreaming) {
req.body.stream = true;
req.isStreaming = true;
} else {
req.body.stream = false;
req.isStreaming = false;
}
}

/**
 * Replaces requests for non-Google AI models with gemini-pro-1.5-latest.
 * Also strips models/ from the beginning of the model IDs.
 **/
function maybeReassignModel(req: Request) {
const requested = req.body.model;
// Ensure model is on body as a lot of middleware will expect it.
const model = req.body.model || req.url.split("/").pop()?.split(":").shift();
if (!model) {
throw new Error("You must specify a model with your request.");
}
req.body.model = model;

const requested = model;
if (requested.startsWith("models/")) {
req.body.model = requested.slice("models/".length);
}

if (requested.includes("gemini")) {
return;
}

req.log.info({ requested }, "Reassigning model to gemini-pro-1.5-latest");
req.body.model = "gemini-pro-1.5-latest";
}
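To make the effect of the new afterTransform hooks concrete, here is a standalone sketch of the model-normalization rule (not the proxy's actual export; the model IDs in the comments are illustrative):

// Mirrors maybeReassignModel above: strip the "models/" prefix and fall back
// to gemini-pro-1.5-latest for anything that isn't a Gemini model.
function normalizeGeminiModel(requested: string): string {
  const model = requested.startsWith("models/")
    ? requested.slice("models/".length)
    : requested;
  return model.includes("gemini") ? model : "gemini-pro-1.5-latest";
}
// normalizeGeminiModel("models/gemini-1.5-pro") === "gemini-1.5-pro"
// normalizeGeminiModel("gpt-4") === "gemini-pro-1.5-latest"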
@@ -143,7 +143,7 @@ const handleTestMessage: RequestHandler = (req, res) => {
};

function isTestMessage(body: any) {
const { messages, prompt } = body;
const { messages, prompt, contents } = body;

if (messages) {
return (

@@ -151,6 +151,11 @@ function isTestMessage(body: any) {
messages[0].role === "user" &&
messages[0].content === "Hi"
);
} else if (contents) {
return (
contents.length === 1 &&
contents[0].parts[0]?.text === "Hi"
);
} else {
return (
prompt?.trim() === "Human: Hi\n\nAssistant:" ||
@@ -2,39 +2,38 @@ import { keyPool } from "../../../../shared/key-management";
import { RequestPreprocessor } from "../index";

export const addGoogleAIKey: RequestPreprocessor = (req) => {
const apisValid = req.inboundApi === "openai" && req.outboundApi === "google-ai";
const inboundValid =
req.inboundApi === "openai" || req.inboundApi === "google-ai";
const outboundValid = req.outboundApi === "google-ai";

const serviceValid = req.service === "google-ai";
if (!apisValid || !serviceValid) {
if (!inboundValid || !outboundValid || !serviceValid) {
throw new Error("addGoogleAIKey called on invalid request");
}

if (!req.body?.model) {
throw new Error("You must specify a model with your request.");
}

const model = req.body.model;
req.isStreaming = req.isStreaming || req.body.stream;
req.key = keyPool.get(model, "google-ai");

req.log.info(
{ key: req.key.hash, model },
{ key: req.key.hash, model, stream: req.isStreaming },
"Assigned Google AI API key to request"
);

// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:generateContent?key=$API_KEY
// https://generativelanguage.googleapis.com/v1beta/models/$MODEL_ID:streamGenerateContent?key=${API_KEY}

req.isStreaming = req.isStreaming || req.body.stream;
delete req.body.stream;
const payload = { ...req.body, stream: undefined, model: undefined };

req.signedRequest = {
method: "POST",
protocol: "https:",
hostname: "generativelanguage.googleapis.com",
path: `/v1beta/models/${model}:${req.isStreaming ? "streamGenerateContent" : "generateContent"}?key=${req.key.key}`,
path: `/v1beta/models/${model}:${
req.isStreaming ? "streamGenerateContent" : "generateContent"
}?key=${req.key.key}`,
headers: {
["host"]: `generativelanguage.googleapis.com`,
["content-type"]: "application/json",
},
body: JSON.stringify(req.body),
body: JSON.stringify(payload),
};
};
@@ -143,6 +143,8 @@ export function sendErrorToClient({
res.setHeader("x-oai-proxy-error-status", redactedOpts.statusCode || 500);
}

req.log.info({ statusCode: res.statusCode, isStreaming, format, redactedOpts, event }, "Sending error response");

if (isStreaming) {
if (!res.headersSent) {
initializeSseStream(res);

@@ -223,19 +225,16 @@ export function buildSpoofedCompletion({
// TODO: Native Google AI non-streaming responses are not supported, this
// is an untested guess at what the response should look like.
return {
id: "error-" + id,
object: "chat.completion",
created: Date.now(),
model,
candidates: [
{
content: { parts: [{ text: content }], role: "model" },
content: { parts: [{ text: content }], role: "assistant" },
finishReason: title,
index: 0,
tokenCount: null,
safetyRatings: [],
},
],
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 },
};
case "openai-image":
return obj;

@@ -302,7 +301,10 @@ export function buildSpoofedSSE({
};
break;
case "google-ai":
return JSON.stringify({
// TODO: google ai supports two streaming transports, SSE and JSON.
// we currently only support SSE.
// return JSON.stringify({
event = {
candidates: [
{
content: { parts: [{ text: content }], role: "model" },

@@ -312,7 +314,8 @@ export function buildSpoofedSSE({
safetyRatings: [],
},
],
});
};
break;
case "openai-image":
return JSON.stringify(obj);
default:

@@ -561,7 +561,7 @@ async function handleGoogleAIBadRequestError(
errorPayload.proxy_note = `Assigned API key is invalid.`;
}
} else if (status === "FAILED_PRECONDITION") {
if (message.includes(/please enable billing/i)) {
if (message.match(/please enable billing/i)) {
req.log.warn(
{ key: req.key?.hash, status, msg: error.message },
"Cannot use key due to billing restrictions."

@@ -116,7 +116,7 @@ export class SSEStreamAdapter extends Transform {
try {
const hasParts = candidates[0].content?.parts?.length > 0;
if (hasParts) {
return `data: ${JSON.stringify(data.value ?? data)}\n`;
return `data: ${JSON.stringify(data.value ?? data)}`;
} else {
this.log.error({ event: data }, "Received bad Google AI event");
return `data: ${buildSpoofedSSE({
@@ -70,7 +70,7 @@ export { proxyRouter as proxyRouter };
function addV1(req: Request, res: Response, next: NextFunction) {
// Clients don't consistently use the /v1 prefix so we'll add it for them.
if (!req.path.startsWith("/v1/")) {
if (!req.path.startsWith("/v1/") && !req.path.startsWith("/v1beta/")) {
req.url = `/v1${req.url}`;
}
next();

@@ -87,6 +87,15 @@ app.use(blacklist);
app.use(checkOrigin);

app.use("/admin", adminRouter);
app.use((req, _, next) => {
// For whatever reason SillyTavern just ignores the path a user provides
// when using Google AI with reverse proxy. We'll fix it here.
if (req.path.startsWith("/v1beta/models/")) {
req.url = `${config.proxyEndpointRoute}/google-ai${req.url}`;
return next();
}
next();
});
app.use(config.proxyEndpointRoute, proxyRouter);
app.use("/user", userRouter);
if (config.staticServiceInfo) {
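A rough sketch of what that rewrite does to an incoming URL, assuming config.proxyEndpointRoute is "/proxy" (the prefix is configurable per deployment):

// Standalone illustration of the SillyTavern path fix above; not the actual middleware.
function fixSillyTavernPath(url: string, proxyEndpointRoute = "/proxy"): string {
  return url.startsWith("/v1beta/models/")
    ? `${proxyEndpointRoute}/google-ai${url}`
    : url;
}
// fixSillyTavernPath("/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent")
//   === "/proxy/google-ai/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent"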
@@ -5,19 +5,20 @@ import {
} from "./openai";
import { APIFormatTransformer } from "./index";

const GoogleAIV1ContentSchema = z.object({
parts: z.array(z.object({ text: z.string() })), // TODO: add other media types
role: z.enum(["user", "model"]).optional(),
});

// https://developers.generativeai.google/api/rest/generativelanguage/models/generateContent
export const GoogleAIV1GenerateContentSchema = z
.object({
model: z.string().max(100), //actually specified in path but we need it for the router
stream: z.boolean().optional().default(false), // also used for router
contents: z.array(
z.object({
parts: z.array(z.object({ text: z.string() })),
role: z.enum(["user", "model"]),
})
),
contents: z.array(GoogleAIV1ContentSchema),
tools: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).max(0).optional(),
safetySettings: z.array(z.object({})).optional(),
systemInstruction: GoogleAIV1ContentSchema.optional(),
generationConfig: z.object({
temperature: z.number().optional(),
maxOutputTokens: z.coerce

@@ -25,7 +26,7 @@ export const GoogleAIV1GenerateContentSchema = z
.int()
.optional()
.default(16)
.transform((v) => Math.min(v, 1024)), // TODO: Add config
.transform((v) => Math.min(v, 4096)), // TODO: Add config
candidateCount: z.literal(1).optional(),
topP: z.number().optional(),
topK: z.number().optional(),
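For reference, a request body the updated schema would accept looks roughly like this (values and the import path are illustrative; per the comments above, model and stream only exist for routing, and maxOutputTokens is clamped by the transform):

// Hypothetical usage sketch of the schema above.
import { GoogleAIV1GenerateContentSchema } from "./google-ai";

const parsed = GoogleAIV1GenerateContentSchema.parse({
  model: "gemini-1.5-pro-latest",
  stream: false,
  contents: [{ role: "user", parts: [{ text: "Hi" }] }],
  systemInstruction: { parts: [{ text: "Answer briefly." }] },
  generationConfig: { temperature: 0.7, maxOutputTokens: 8192 }, // clamped to 4096
});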
@@ -0,0 +1,155 @@
import axios, { AxiosError } from "axios";
import type { GoogleAIModelFamily } from "../../models";
import { KeyCheckerBase } from "../key-checker-base";
import type { GoogleAIKey, GoogleAIKeyProvider } from "./provider";
import { getGoogleAIModelFamily } from "../../models";

const MIN_CHECK_INTERVAL = 3 * 1000; // 3 seconds
const KEY_CHECK_PERIOD = 3 * 60 * 60 * 1000; // 3 hours
const LIST_MODELS_URL =
"https://generativelanguage.googleapis.com/v1beta/models";

type ListModelsResponse = {
models: {
name: string;
baseModelId: string;
version: string;
displayName: string;
description: string;
inputTokenLimit: number;
outputTokenLimit: number;
supportedGenerationMethods: string[];
temperature: number;
maxTemperature: number;
topP: number;
topK: number;
}[];
nextPageToken: string;
};

type UpdateFn = typeof GoogleAIKeyProvider.prototype.update;

export class GoogleAIKeyChecker extends KeyCheckerBase<GoogleAIKey> {
constructor(keys: GoogleAIKey[], updateKey: UpdateFn) {
super(keys, {
service: "google-ai",
keyCheckPeriod: KEY_CHECK_PERIOD,
minCheckInterval: MIN_CHECK_INTERVAL,
recurringChecksEnabled: false,
updateKey,
});
}

protected async testKeyOrFail(key: GoogleAIKey) {
const provisionedModels = await this.getProvisionedModels(key);
const updates = {
modelFamilies: provisionedModels,
};
this.updateKey(key.hash, updates);
this.log.info(
{ key: key.hash, models: key.modelFamilies, ids: key.modelIds.length },
"Checked key."
);
}

private async getProvisionedModels(
key: GoogleAIKey
): Promise<GoogleAIModelFamily[]> {
const { data } = await axios.get<ListModelsResponse>(
`${LIST_MODELS_URL}?pageSize=1000&key=${key.key}`
);
const models = data.models;

const ids = new Set<string>();
const families = new Set<GoogleAIModelFamily>();
models.forEach(({ name }) => {
families.add(getGoogleAIModelFamily(name));
ids.add(name);
});

const familiesArray = Array.from(families);
this.updateKey(key.hash, {
modelFamilies: familiesArray,
modelIds: Array.from(ids),
});

return familiesArray;
}

protected handleAxiosError(key: GoogleAIKey, error: AxiosError): void {
if (error.response && GoogleAIKeyChecker.errorIsGoogleAIError(error)) {
const httpStatus = error.response.status;
const { code, message, status, details } = error.response.data.error;

switch (httpStatus) {
case 400:
const reason = details?.[0]?.reason;
if (status === "INVALID_ARGUMENT" && reason === "API_KEY_INVALID") {
this.log.warn(
{ key: key.hash, reason, details },
"Key check returned API_KEY_INVALID error. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
return;
} else if (
status === "FAILED_PRECONDITION" &&
message.match(/please enable billing/i)
) {
this.log.warn(
{ key: key.hash, message, details },
"Key check returned billing disabled error. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
return;
}
break;
case 401:
case 403:
this.log.warn(
{ key: key.hash, status, code, message, details },
"Key check returned Forbidden/Unauthorized error. Disabling key."
);
this.updateKey(key.hash, { isDisabled: true, isRevoked: true });
return;
case 429:
this.log.warn(
{ key: key.hash, status, code, message, details },
"Key is rate limited. Rechecking key in 1 minute."
);
const next = Date.now() - (KEY_CHECK_PERIOD - 10 * 1000);
this.updateKey(key.hash, { lastChecked: next });
return;
}

this.log.error(
{ key: key.hash, status, code, message, details },
"Encountered unexpected error status while checking key. This may indicate a change in the API; please report this."
);
return this.updateKey(key.hash, { lastChecked: Date.now() });
}

this.log.error(
{ key: key.hash, error: error.message },
"Network error while checking key; trying this key again in a minute."
);
const oneMinute = 10 * 1000;
const next = Date.now() - (KEY_CHECK_PERIOD - oneMinute);
return this.updateKey(key.hash, { lastChecked: next });
}

static errorIsGoogleAIError(
error: AxiosError
): error is AxiosError<GoogleAIError> {
const data = error.response?.data as any;
return data?.error?.code || data?.error?.status;
}
}

type GoogleAIError = {
error: {
code: string;
message: string;
status: string;
details: any[];
};
};
@@ -2,12 +2,13 @@ import crypto from "crypto";
import { Key, KeyProvider } from "..";
import { config } from "../../../config";
import { logger } from "../../../logger";
import type { GoogleAIModelFamily } from "../../models";
import { HttpError, PaymentRequiredError } from "../../errors";
import { getGoogleAIModelFamily, type GoogleAIModelFamily } from "../../models";
import { PaymentRequiredError } from "../../errors";
import { GoogleAIKeyChecker } from "./checker";

// Note that Google AI is not the same as Vertex AI, both are provided by Google
// but Vertex is the GCP product for enterprise. while Google AI is the
// consumer-ish product. The API is different, and keys are not compatible.
// Note that Google AI is not the same as Vertex AI, both are provided by
// Google but Vertex is the GCP product for enterprise, while Google API is a
// development/hobbyist product. They use completely different APIs and keys.
// https://ai.google.dev/docs/migrate_to_cloud

export type GoogleAIKeyUpdate = Omit<

@@ -31,6 +32,8 @@ export interface GoogleAIKey extends Key, GoogleAIKeyUsage {
rateLimitedAt: number;
/** The time until which this key is rate limited. */
rateLimitedUntil: number;
/** All detected model IDs on this key. */
modelIds: string[];
}

/**

@@ -49,6 +52,7 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
readonly service = "google-ai";

private keys: GoogleAIKey[] = [];
private checker?: GoogleAIKeyChecker;
private log = logger.child({ module: "key-provider", service: this.service });

constructor() {

@@ -78,14 +82,22 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
.digest("hex")
.slice(0, 8)}`,
lastChecked: 0,
"gemini-flashTokens": 0,
"gemini-proTokens": 0,
"gemini-ultraTokens": 0,
modelIds: [],
};
this.keys.push(newKey);
}
this.log.info({ keyCount: this.keys.length }, "Loaded Google AI keys.");
}

public init() {}
public init() {
if (config.checkKeys) {
this.checker = new GoogleAIKeyChecker(this.keys, this.update.bind(this));
this.checker.start();
}
}

public list() {
return this.keys.map((k) => Object.freeze({ ...k, key: undefined }));

@@ -141,11 +153,11 @@ export class GoogleAIKeyProvider implements KeyProvider<GoogleAIKey> {
return this.keys.filter((k) => !k.isDisabled).length;
}

public incrementUsage(hash: string, _model: string, tokens: number) {
public incrementUsage(hash: string, model: string, tokens: number) {
const key = this.keys.find((k) => k.hash === hash);
if (!key) return;
key.promptCount++;
key["gemini-proTokens"] += tokens;
key[`${getGoogleAIModelFamily(model)}Tokens`] += tokens;
}

public getLockoutPeriod() {

@@ -114,7 +114,8 @@ export abstract class KeyCheckerBase<TKey extends Key> {
);

// Don't check any individual key too often.
// Don't check anything at all at a rate faster than once per 3 seconds.
// Don't check anything at all more frequently than some minimum interval
// even if keys still need to be checked.
const nextCheck = Math.max(
oldestKey.lastChecked + this.keyCheckPeriod,
this.lastCheck + this.minCheckInterval
@@ -23,7 +23,10 @@ export type OpenAIModelFamily =
| "gpt4o"
| "dall-e";
export type AnthropicModelFamily = "claude" | "claude-opus";
export type GoogleAIModelFamily = "gemini-pro";
export type GoogleAIModelFamily =
| "gemini-flash"
| "gemini-pro"
| "gemini-ultra";
export type MistralAIModelFamily =
// mistral changes their model classes frequently so these no longer
// correspond to specific models. consider them rough pricing tiers.

@@ -49,7 +52,9 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"dall-e",
"claude",
"claude-opus",
"gemini-flash",
"gemini-pro",
"gemini-ultra",
"mistral-tiny",
"mistral-small",
"mistral-medium",

@@ -94,7 +99,9 @@ export const MODEL_FAMILY_SERVICE: {
"azure-gpt4-turbo": "azure",
"azure-gpt4o": "azure",
"azure-dall-e": "azure",
"gemini-flash": "google-ai",
"gemini-pro": "google-ai",
"gemini-ultra": "google-ai",
"mistral-tiny": "mistral-ai",
"mistral-small": "mistral-ai",
"mistral-medium": "mistral-ai",

@@ -134,8 +141,12 @@ export function getClaudeModelFamily(model: string): AnthropicModelFamily {
return "claude";
}

export function getGoogleAIModelFamily(_model: string): ModelFamily {
return "gemini-pro";
export function getGoogleAIModelFamily(model: string): GoogleAIModelFamily {
return model.includes("ultra")
? "gemini-ultra"
: model.includes("flash")
? "gemini-flash"
: "gemini-pro";
}

export function getMistralAIModelFamily(model: string): MistralAIModelFamily {
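As a quick illustration of the new family mapping (the model IDs below are examples only):

// getGoogleAIModelFamily("models/gemini-1.5-flash-latest") === "gemini-flash"
// getGoogleAIModelFamily("gemini-ultra")                   === "gemini-ultra"
// getGoogleAIModelFamily("models/gemini-1.5-pro")          === "gemini-pro"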