8 Commits

Author SHA1 Message Date
nai-degen 8d5059534a starts adding cohere api format and schemas 2024-06-09 12:46:02 -05:00
nai-degen 0ea43f61c2 fixes incorrect variable name in .env.example docs 2024-06-09 11:36:20 -05:00
nai-degen ca4321b4cb adjusts openai schema validation to allow
null stop sequence
2024-06-07 14:29:18 -05:00
nai-degen 7660ed8b94 allows enabling vision prompts on a per-service basis 2024-06-07 12:09:43 -05:00
nai-degen 55f1bbed3b adds ipv6 mask to default ADMIN_WHITELIST 2024-06-02 20:49:18 -05:00
nai-degen 57fd17ede0 makes it easier for clients to detect proxy errors programmatically 2024-05-27 15:30:28 -05:00
nai-degen 9d00b8a9de adjusts max IP error message wording 2024-05-27 08:24:56 -05:00
nai-degen 155e185c6e fixes shutdown handler fuckup 2024-05-26 15:36:54 -05:00
25 changed files with 378 additions and 524 deletions
+15 -6
View File
@@ -46,6 +46,14 @@ NODE_ENV=production
# 'azure-dall-e' to the list of allowed model families.
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
# Which services can be used to process prompts containing images via multimodal
# models. The following services are recognized:
# openai | anthropic | aws | azure | google-ai | mistral-ai
# Do not enable this feature unless all users are trusted, as you will be liable
# for any user-submitted images containing illegal content.
# By default, no image services are allowed and image prompts are rejected.
# ALLOWED_VISION_SERVICES=
# IP addresses or CIDR blocks from which requests will be blocked.
# IP_BLACKLIST=10.0.0.1/24
# URLs from which requests will be blocked.
@@ -60,7 +68,7 @@ NODE_ENV=production
# Avoid short or common phrases as this tests the entire prompt.
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
# Message to show when requests are rejected.
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
# REJECT_MESSAGE="You can't say that here."
# Whether prompts should be logged to Google Sheets.
# Requires additional setup. See `docs/google-sheets.md` for more information.
@@ -102,18 +110,19 @@ NODE_ENV=production
# ALLOW_NICKNAME_CHANGES=true
# Default token quotas for each model family. (0 for unlimited)
# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
# which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
# See `docs/dall-e-configuration.md` for more information.
# Specify as TOKEN_QUOTA_MODEL_FAMILY=value, replacing dashes with underscores.
# TOKEN_QUOTA_TURBO=0
# TOKEN_QUOTA_GPT4=0
# TOKEN_QUOTA_GPT4_32K=0
# TOKEN_QUOTA_GPT4_TURBO=0
# TOKEN_QUOTA_DALL_E=0
# TOKEN_QUOTA_CLAUDE=0
# TOKEN_QUOTA_GEMINI_PRO=0
# TOKEN_QUOTA_AWS_CLAUDE=0
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
# per US$1.00 generated, which is similar to the cost of GPT-4 Turbo.
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
# See `docs/dall-e-configuration.md` for more information.
# TOKEN_QUOTA_DALL_E=0
# How often to refresh token quotas. (hourly | daily)
# Leave unset to never automatically refresh quotas.
+1 -1
View File
@@ -70,4 +70,4 @@ You can provide a comma-separated list containing individual IPv4 or IPv6 addres
To whitelist an entire IP range, use CIDR notation. For example, `192.168.0.1/24` would whitelist all addresses from `192.168.0.0` to `192.168.0.255`.
To disable the whitelist, set `ADMIN_WHITELIST=0.0.0.0/0`, which will allow access from any IP address. This is the default behavior.
To disable the whitelist, set `ADMIN_WHITELIST=0.0.0.0/0,::/0`, which will allow access from any IPv4 or IPv6 address. This is the default behavior.
+1 -1
View File
@@ -1,7 +1,7 @@
import { Router } from "express";
import { z } from "zod";
import { encodeCursor, decodeCursor } from "../../shared/utils";
import { eventsRepo } from "../../shared/database/repos/events";
import { eventsRepo } from "../../shared/database/repos/event";
const router = Router();
+1 -1
View File
@@ -2,7 +2,7 @@ import { Router } from "express";
import { z } from "zod";
import * as userStore from "../../shared/users/user-store";
import { parseSort, sortBy } from "../../shared/utils";
import { UserPartialSchema, UserSchema } from "../../shared/database/repos/users";
import { UserPartialSchema, UserSchema } from "../../shared/users/schema";
const router = Router();
+6 -1
View File
@@ -9,10 +9,15 @@ import { parseSort, sortBy, paginate } from "../../shared/utils";
import { keyPool } from "../../shared/key-management";
import { LLMService, MODEL_FAMILIES } from "../../shared/models";
import { getTokenCostUsd, prettyTokens } from "../../shared/stats";
import {
User,
UserPartialSchema,
UserSchema,
UserTokenCounts,
} from "../../shared/users/schema";
import { getLastNImages } from "../../shared/file-storage/image-history";
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
import { User, UserPartialSchema, UserSchema, UserTokenCounts } from "../../shared/database/repos/users";
const router = Router();
+28 -8
View File
@@ -3,7 +3,7 @@ import dotenv from "dotenv";
import type firebase from "firebase-admin";
import path from "path";
import pino from "pino";
import type { ModelFamily } from "./shared/models";
import type { LLMService, ModelFamily } from "./shared/models";
import { MODEL_FAMILIES } from "./shared/models";
dotenv.config();
@@ -340,13 +340,18 @@ type Config = {
*/
allowOpenAIToolUsage?: boolean;
/**
* Whether to allow prompts containing images, for use with multimodal models.
* Avoid giving this to untrusted users, as they can submit illegal content.
* Which services will accept prompts containing images, for use with
* multimodal models. Users with `special` role are exempt from this
* restriction.
*
* Applies to GPT-4 Vision and Claude Vision. Users with `special` role are
* exempt from this restriction.
* Do not enable this feature for untrusted users, as malicious users could
* send images which violate your provider's terms of service or local laws.
*
* Defaults to no services, meaning image prompts are disabled. Use a comma-
* separated list. Available services are:
* openai,anthropic,google-ai,mistral-ai,aws,azure
*/
allowImagePrompts?: boolean;
allowedVisionServices: LLMService[];
/**
* Allows overriding the default proxy endpoint route. Defaults to /proxy.
* A leading slash is required.
@@ -479,9 +484,13 @@ export const config: Config = {
staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
trustedProxies: getEnvWithDefault("TRUSTED_PROXIES", 1),
allowOpenAIToolUsage: getEnvWithDefault("ALLOW_OPENAI_TOOL_USAGE", false),
allowImagePrompts: getEnvWithDefault("ALLOW_IMAGE_PROMPTS", false),
allowedVisionServices: parseCsv(
getEnvWithDefault("ALLOWED_VISION_SERVICES", "")
) as LLMService[],
proxyEndpointRoute: getEnvWithDefault("PROXY_ENDPOINT_ROUTE", "/proxy"),
adminWhitelist: parseCsv(getEnvWithDefault("ADMIN_WHITELIST", "0.0.0.0/0")),
adminWhitelist: parseCsv(
getEnvWithDefault("ADMIN_WHITELIST", "0.0.0.0/0,::/0")
),
ipBlacklist: parseCsv(getEnvWithDefault("IP_BLACKLIST", "")),
} as const;
@@ -534,6 +543,17 @@ export async function assertConfigIsValid() {
);
}
if (process.env.ALLOW_IMAGE_PROMPTS === "true") {
const hasAllowedServices = config.allowedVisionServices.length > 0;
if (!hasAllowedServices) {
config.allowedVisionServices = ["openai", "anthropic"];
startupLogger.warn(
{ allowedVisionServices: config.allowedVisionServices },
"ALLOW_IMAGE_PROMPTS is deprecated. Use ALLOWED_VISION_SERVICES instead."
);
}
}
if (config.promptLogging && !config.promptLoggingBackend) {
throw new Error(
"Prompt logging is enabled but no backend is configured. Set PROMPT_LOGGING_BACKEND to 'google_sheets' or 'file'."
+5 -2
View File
@@ -66,7 +66,8 @@ export const gatekeeper: RequestHandler = (req, res, next) => {
req,
res,
403,
"Forbidden: no more IPs can authenticate with this user token"
`Forbidden: no more IP addresses allowed for this user token`,
{ currentIp: ip, maxIps: user?.maxIps }
);
case "disabled":
const bannedUser = getUser(token);
@@ -84,7 +85,8 @@ function sendError(
req: Request,
res: Response,
status: number,
message: string
message: string,
data: any = {}
) {
const isPost = req.method === "POST";
const hasBody = isPost && req.body;
@@ -103,6 +105,7 @@ function sendError(
format: "unknown",
statusCode: status,
reqId: req.id,
obj: data,
},
});
}
@@ -9,9 +9,14 @@ import { ForbiddenError } from "../../../../shared/errors";
* Rejects prompts containing images if multimodal prompts are disabled.
*/
export const validateVision: RequestPreprocessor = async (req) => {
if (config.allowImagePrompts) return;
if (req.user?.type === "special") return;
if (req.service === undefined) {
throw new Error("Request service must be set before validateVision");
}
if (req.user?.type === "special") return;
if (config.allowedVisionServices.includes(req.service)) return;
// vision not allowed for req's service, block prompts with images
let hasImage = false;
switch (req.outboundApi) {
case "openai":
@@ -52,7 +52,13 @@ function getMessageContent({
delete obj.stack;
}
return [header, friendlyMessage, serializedObj, prettyTrace].join("\n\n");
return [
header,
friendlyMessage,
serializedObj,
prettyTrace,
"<!-- oai-proxy-error -->",
].join("\n\n");
}
type ErrorGeneratorOptions = {
@@ -116,6 +122,11 @@ export function sendErrorToClient({
const isStreaming =
req.isStreaming || req.body.stream === true || req.body.stream === "true";
if (!res.headersSent) {
res.setHeader("x-oai-proxy-error", options.title);
res.setHeader("x-oai-proxy-error-status", options.statusCode || 500);
}
if (isStreaming) {
if (!res.headersSent) {
initializeSseStream(res);
+1 -4
View File
@@ -179,10 +179,7 @@ function cleanup() {
process.exit(0);
}
process.on("exit", () => cleanup());
process.on("SIGHUP", () => process.exit(128 + 1));
process.on("SIGINT", () => process.exit(128 + 2));
process.on("SIGTERM", () => process.exit(128 + 15));
process.on("SIGINT", cleanup);
function registerUncaughtExceptionHandler() {
process.on("uncaughtException", (err: any) => {
+2 -1
View File
@@ -119,7 +119,8 @@ export const transformOpenAIToAnthropicChat: APIFormatTransformer<
stream: rest.stream,
temperature: rest.temperature,
top_p: rest.top_p,
stop_sequences: typeof rest.stop === "string" ? [rest.stop] : rest.stop,
stop_sequences:
typeof rest.stop === "string" ? [rest.stop] : rest.stop || undefined,
...(rest.user ? { metadata: { user_id: rest.user } } : {}),
// Anthropic supports top_k, but OpenAI does not
// OpenAI supports frequency_penalty, presence_penalty, logit_bias, n, seed,
+181
View File
@@ -0,0 +1,181 @@
import { z } from "zod";
import {
OPENAI_OUTPUT_MAX,
OpenAIV1ChatCompletionSchema,
flattenOpenAIMessageContent,
} from "./openai";
import { APIFormatTransformer } from ".";
// https://docs.cohere.com/reference/chat
/**
 * Zod schema for inbound Cohere v1 Chat API requests.
 * Unrecognized keys are stripped (see `.strip()` at the end) rather than
 * rejected, so unknown client fields are silently dropped.
 */
export const CohereV1ChatSchema = z
  .object({
    message: z.string(),
    model: z.string().default("command-r-plus"),
    stream: z.boolean().default(false).optional(),
    preamble: z.string().optional(),
    chat_history: z
      .array(
        // Either a message from a chat participant, or a past tool call
        z.union([
          z.object({
            role: z.enum(["CHATBOT", "SYSTEM", "USER"]),
            message: z.string(),
            tool_calls: z
              .array(z.object({ name: z.string(), parameters: z.any() }))
              .optional(),
          }),
          z.object({
            role: z.enum(["TOOL"]),
            tool_results: z.array(
              z.object({
                call: z.object({ name: z.string(), parameters: z.any() }),
                // NOTE: no minimum length is enforced, so an empty outputs
                // array is valid input.
                outputs: z.array(z.any()),
              })
            ),
          }),
        ])
      )
      .optional(),
    // Don't allow conversation_id as it causes calls to be stateful and we don't
    // offer guarantees about which key a user's request will be routed to.
    conversation_id: z.literal(undefined).optional(),
    prompt_truncation: z
      .enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"])
      .optional(),
    /*
    Supporting RAG is complex because documents can be arbitrary size and have
    to have embeddings generated, which incurs a cost that is not trivial to
    estimate. We don't support it for now.
    connectors: z
      .array(
        z.object({
          id: z.string(),
          user_access_token: z.string().optional(),
          continue_on_failure: z.boolean().default(false).optional(),
          options: z.any().optional(),
        })
      )
      .optional(),
    search_queries_only: z.boolean().default(false).optional(),
    documents: z
      .array(
        z.object({
          id: z.string().optional(),
          title: z.string().optional(),
          text: z.string(),
          _excludes: z.array(z.string()).optional(),
        })
      )
      .optional(),
    citation_quality: z.enum(["accurate", "fast"]).optional(),
    */
    temperature: z.number().default(0.3).optional(),
    // Clamp requested output tokens to the proxy-wide ceiling; null/undefined
    // fall back to OPENAI_OUTPUT_MAX.
    max_tokens: z
      .number()
      .int()
      .nullish()
      .default(Math.min(OPENAI_OUTPUT_MAX, 4096))
      .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
    max_input_tokens: z.number().int().optional(),
    k: z.number().int().min(0).max(500).default(0).optional(),
    p: z.number().min(0.01).max(0.99).default(0.75).optional(),
    seed: z.number().int().optional(),
    stop_sequences: z.array(z.string()).max(5).optional(),
    frequency_penalty: z.number().min(0).max(1).default(0).optional(),
    presence_penalty: z.number().min(0).max(1).default(0).optional(),
    tools: z
      .array(
        z.object({
          name: z.string(),
          description: z.string(),
          parameter_definitions: z.record(
            z.object({
              description: z.string().optional(),
              type: z.string(),
              required: z.boolean().optional().default(false),
            })
          ),
        })
      )
      .optional(),
    tool_results: z
      .array(
        z.object({
          call: z.object({
            name: z.string(),
            parameters: z.record(z.any()),
          }),
          outputs: z.array(z.record(z.any())),
        })
      )
      .optional(),
    // We always force single step to avoid stateful calls or expensive multi-step
    // generations when tools are involved.
    force_single_step: z.literal(true).default(true).optional(),
  })
  .strip();
/** A single entry in a Cohere `chat_history` array. */
export type CohereChatMessage = NonNullable<
  z.infer<typeof CohereV1ChatSchema>["chat_history"]
>[number];

/**
 * Flattens a Cohere chat history entry into plain text. For TOOL entries,
 * the first output of each tool result is joined with newlines; for
 * participant entries, the message text is returned as-is.
 *
 * Fix: the schema permits an empty `outputs` array and arbitrary record
 * shapes, so guard against a missing first output / missing `text` key
 * instead of throwing a TypeError or emitting the string "undefined".
 */
export function flattenCohereMessageContent(
  message: CohereChatMessage
): string {
  return message.role === "TOOL"
    ? message.tool_results
        .map((r) => r.outputs[0]?.text ?? "")
        .join("\n")
    : message.message;
}
/**
 * Translates a validated OpenAI chat completion request into a Cohere v1
 * Chat request body.
 * - The final OpenAI message becomes Cohere's `message` field.
 * - A leading system message becomes the `preamble`.
 * - All remaining messages are mapped onto `chat_history`.
 * @throws ZodError if the inbound body fails OpenAI schema validation.
 */
export const transformOpenAIToCohere: APIFormatTransformer<
  typeof CohereV1ChatSchema
> = async (req) => {
  const { body } = req;
  const result = OpenAIV1ChatCompletionSchema.safeParse({
    ...body,
    model: "gpt-3.5-turbo",
  });
  if (!result.success) {
    req.log.warn(
      { issues: result.error.issues, body },
      "Invalid OpenAI-to-Cohere request"
    );
    throw result.error;
  }

  const { messages, ...rest } = result.data;

  // Final OAI message becomes the `message` field in Cohere
  const message = messages[messages.length - 1];

  // If the first message has system role, use it as preamble.
  const hasSystemPreamble = messages[0]?.role === "system";
  const preamble = hasSystemPreamble
    ? flattenOpenAIMessageContent(messages[0].content)
    : undefined;

  // Fix: skip the message already used as the preamble so the system prompt
  // is not sent twice; also drop the final message, which is sent separately
  // as `message`.
  const chatHistory = messages
    .slice(hasSystemPreamble ? 1 : 0, -1)
    .map((m) => {
      const role: Exclude<CohereChatMessage["role"], "TOOL"> =
        m.role === "assistant"
          ? "CHATBOT"
          : m.role === "system"
          ? "SYSTEM"
          : "USER";
      const content = flattenOpenAIMessageContent(m.content);
      const message = m.name ? `${m.name}: ${content}` : content;
      return { role, message };
    });

  return {
    // NOTE(review): `rest.model` is the placeholder "gpt-3.5-turbo" forced in
    // above, not a Cohere model name — presumably overridden downstream by
    // model selection; verify before relying on this field.
    model: rest.model,
    preamble,
    chat_history: chatHistory,
    message: flattenOpenAIMessageContent(message.content),
    stop_sequences:
      typeof rest.stop === "string" ? [rest.stop] : rest.stop ?? undefined,
    max_tokens: rest.max_tokens,
    temperature: rest.temperature,
    p: rest.top_p,
    frequency_penalty: rest.frequency_penalty,
    presence_penalty: rest.presence_penalty,
    seed: rest.seed,
    stream: rest.stream,
  };
};
+20
View File
@@ -22,6 +22,7 @@ import {
transformOpenAIToGoogleAI,
} from "./google-ai";
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
import { CohereV1ChatSchema, transformOpenAIToCohere } from "./cohere";
export { OpenAIChatMessage } from "./openai";
export {
@@ -33,15 +34,29 @@ export {
export { GoogleAIChatMessage } from "./google-ai";
export { MistralAIChatMessage } from "./mistral-ai";
/** Represents a pair of API formats that can be transformed between. */
type APIPair = `${APIFormat}->${APIFormat}`;
/** Represents a map of API format pairs to transformer functions. */
type TransformerMap = {
[key in APIPair]?: APIFormatTransformer<any>;
};
/**
* Represents a transformer function that takes a Request and returns a Promise
* resolving to a value of the specified Zod schema type.
*
* @template Z The Zod schema type to transform the request into (from api-schemas).
* @param req The incoming Request to transform.
* @returns A Promise resolving to the transformed request body.
*/
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
req: Request
) => Promise<z.infer<Z>>;
/**
* Specifies possible translations between API formats and the corresponding
* transformer functions to apply them.
*/
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
@@ -49,8 +64,12 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
"openai->openai-text": transformOpenAIToOpenAIText,
"openai->openai-image": transformOpenAIToOpenAIImage,
"openai->google-ai": transformOpenAIToGoogleAI,
"openai->cohere-chat": transformOpenAIToCohere,
};
/**
* Specifies the schema for each API format to validate incoming requests.
*/
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"anthropic-chat": AnthropicV1MessagesSchema,
"anthropic-text": AnthropicV1TextSchema,
@@ -59,4 +78,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
"openai-image": OpenAIV1ImagesGenerationSchema,
"google-ai": GoogleAIV1GenerateContentSchema,
"mistral-ai": MistralAIV1ChatCompletionsSchema,
"cohere-chat": CohereV1ChatSchema,
};
+1 -1
View File
@@ -47,7 +47,7 @@ export const OpenAIV1ChatCompletionSchema = z
stream: z.boolean().optional().default(false),
stop: z
.union([z.string().max(500), z.array(z.string().max(500))])
.optional(),
.nullish(),
max_tokens: z.coerce
.number()
.int()
+1 -1
View File
@@ -3,8 +3,8 @@
import type { HttpRequest } from "@smithy/types";
import { Express } from "express-serve-static-core";
import { APIFormat, Key } from "./key-management";
import { User } from "./users/schema";
import { LLMService, ModelFamily } from "./models";
import { User } from "./database/repos/users";
declare global {
namespace Express {
+2 -7
View File
@@ -23,11 +23,7 @@ export async function initializeDatabase() {
log.info("Initializing database...");
const sqlite3 = await import("better-sqlite3");
database = sqlite3.default(config.sqliteDataPath, {
verbose: process.env.SQLITE_VERBOSE === "true"
? (msg, ...args) => log.debug({ args }, String(msg))
: undefined,
});
database = sqlite3.default(config.sqliteDataPath);
migrateDatabase();
database.pragma("journal_mode = WAL");
log.info("Database initialized.");
@@ -90,5 +86,4 @@ function assertNumber(value: unknown): asserts value is number {
throw new Error("Expected number");
}
}
export { EventLogEntry } from "./repos/events";
export { EventLogEntry } from "./repos/event";
-61
View File
@@ -58,65 +58,4 @@ export const migrations = [
);
},
},
{
name: "add users schema",
version: 4,
up: (db) => {
// language=SQLite
const sql = `
CREATE TABLE IF NOT EXISTS users
(
token TEXT PRIMARY KEY NOT NULL,
nickname TEXT,
type TEXT CHECK (type IN ('normal', 'special', 'temporary')) NOT NULL,
createdAt INTEGER NOT NULL,
lastUsedAt INTEGER,
disabledAt INTEGER,
disabledReason TEXT,
expiresAt INTEGER,
maxIps INTEGER,
adminNote TEXT
);
CREATE TABLE IF NOT EXISTS user_ips
(
userToken TEXT NOT NULL,
ip TEXT NOT NULL,
PRIMARY KEY (userToken, ip),
FOREIGN KEY (userToken) REFERENCES users (token)
);
CREATE TABLE IF NOT EXISTS user_token_counts
(
userToken TEXT NOT NULL,
modelFamily TEXT NOT NULL,
inputTokens INTEGER NOT NULL,
outputTokens INTEGER NOT NULL,
tokenLimit INTEGER NOT NULL,
prompts INTEGER NOT NULL,
PRIMARY KEY (userToken, modelFamily)
);
CREATE TABLE IF NOT EXISTS user_meta
(
userToken TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
PRIMARY KEY (userToken, key),
FOREIGN KEY (userToken) REFERENCES users (token)
);
`;
db.exec(sql);
},
down: (db) => {
// language=SQLite
const sql = `
DROP TABLE users;
DROP TABLE user_ips;
DROP TABLE user_token_counts;
DROP TABLE user_meta;
`;
db.exec(sql);
},
},
] satisfies Migration[];
-420
View File
@@ -1,420 +0,0 @@
import { ZodType, z } from "zod";
import { MODEL_FAMILIES, ModelFamily } from "../../models";
import { makeOptionalPropsNullable } from "../../utils";
import { getDatabase } from "../index";
import type { Transaction } from "better-sqlite3";
// Dynamically builds a Zod object schema with a key for each model family,
// each mapping to that family's usage stats object (input/output token
// totals, quota limit, and prompt count — all defaulting to 0).
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
  MODEL_FAMILIES.reduce(
    (acc, family) => {
      return {
        ...acc,
        [family]: z.object({
          input: z.number().optional().default(0),
          output: z.number().optional().default(0),
          limit: z.number().optional().default(0),
          prompts: z.number().optional().default(0),
        }),
      };
    },
    {} as Record<
      ModelFamily,
      ZodType<{ input: number; output: number; limit: number; prompts: number }>
    >
  )
);
// Old token counts schema before counts were combined into a single object.
// Maps each model family to a bare number (defaulting to 0); retained for
// backwards compatibility with legacy user records.
const tokenCountsSchemaOld = z.object(
  MODEL_FAMILIES.reduce(
    (acc, family) => ({ ...acc, [family]: z.number().optional().default(0) }),
    {} as Record<ModelFamily, ZodType<number>>
  )
);
export const UserSchema = z
  .object({
    /** User's personal access token. */
    token: z.string(),
    /** IP addresses the user has connected from. */
    ip: z.array(z.string()),
    /** User's nickname. */
    nickname: z.string().max(80).optional(),
    /**
     * The user's privilege level.
     * - `normal`: Default role. Subject to usual rate limits and quotas.
     * - `special`: Special role. Higher quotas and exempt from auto-ban/lockout.
     * - `temporary`: Presumably a time-limited user (see `expiresAt`) — TODO confirm.
     **/
    type: z.enum(["normal", "special", "temporary"]),
    /** Number of prompts the user has made. */
    promptCount: z.number(),
    /**
     * @deprecated Use `tokenCounts` instead.
     * Never used; retained for backwards compatibility.
     */
    tokenCount: z.any().optional(),
    /** Number of tokens the user has consumed, by model family. */
    tokenCounts: tokenCountsSchemaOld,
    /** Maximum number of tokens the user can consume, by model family. */
    tokenLimits: tokenCountsSchemaOld,
    /** Token data for the user, by model family. */
    modelTokenCounts: tokenCountsSchema,
    /** Time at which the user was created. */
    createdAt: z.number(),
    /** Time at which the user last connected. */
    lastUsedAt: z.number().optional(),
    /** Time at which the user was disabled, if applicable. */
    disabledAt: z.number().optional(),
    /** Reason for which the user was disabled, if applicable. */
    disabledReason: z.string().optional(),
    /** Time at which the user will expire and be disabled (for temp users). */
    expiresAt: z.number().optional(),
    /** The user's maximum number of IP addresses; supersedes global max. */
    maxIps: z.coerce.number().int().min(0).optional(),
    /** Private note about the user. */
    adminNote: z.string().optional(),
    /** Arbitrary key/value metadata attached to the user. */
    meta: z.record(z.any()).optional(),
  })
  .strict();
/**
 * Variant of `UserSchema` which allows for partial updates, and makes any
 * optional properties on the base schema nullable. Null values are used to
 * indicate that the property should be deleted from the user object.
 */
export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
  .partial()
  .extend({ token: z.string() });
/** Per-model-family token usage statistics for a user. */
export type UserTokenCounts = {
  [K in ModelFamily]: {
    input: number;
    output: number;
    limit: number;
    prompts: number;
  };
};

/** @deprecated Legacy flat token counts: a single number per model family. */
export type UserTokenCountsOld = {
  [K in ModelFamily]: number | undefined;
};

/** A full user record, as validated by `UserSchema`. */
export type User = z.infer<typeof UserSchema>;

/** A partial user update, as validated by `UserPartialSchema`. */
export type UserUpdate = z.infer<typeof UserPartialSchema>;

/** User row shape returned by list queries; `ipCount` stands in for the full IP list. */
export type VirtualUser = User & { virtual: true; ipCount: number };
/**
 * SQLite-backed repository for user records. Rows are split across the
 * `users`, `user_ips`, `user_token_counts`, and `user_meta` tables and
 * reassembled into `User` objects via `marshalUser`.
 */
export const UsersRepo = {
  /** Fetches a single user (with IPs, token counts, and meta) by token. */
  getUserByToken: (token: string) => {
    const db = getDatabase();
    // language=SQLite
    const sql = `
      SELECT u.*,
             json_group_array(ui.ip) as ip,
             json_group_object(utc.modelFamily,
                               json_object('input', utc.inputTokens,
                                           'output', utc.outputTokens,
                                           'limit', utc.tokenLimit,
                                           'prompts', utc.prompts)) as tokenCounts,
             json_object(um.key, um.value) as meta
      FROM users u
               LEFT JOIN user_ips ui ON u.token = ui.userToken
               LEFT JOIN user_token_counts utc ON u.token = utc.userToken
               LEFT JOIN user_meta um ON u.token = um.userToken
      WHERE u.token = ?;
    `;
    const user = db.prepare(sql).get(token);
    if (!user) return;
    return marshalUser(user);
  },

  /**
   * Returns a page of users ordered by token descending, using the token
   * itself as the pagination cursor. IPs are returned only as a count.
   */
  getUsers: (pagination: { limit: number; cursor?: string }): VirtualUser[] => {
    const db = getDatabase();
    const { limit, cursor } = pagination;
    const params = [];
    let sql = `
      SELECT u.*,
             count(ui.ip) as ipCount,
             json_group_object(utc.modelFamily,
                               json_object('input', utc.inputTokens,
                                           'output', utc.outputTokens,
                                           'limit', utc.tokenLimit,
                                           'prompts', utc.prompts)) as tokenCounts,
             json_object(um.key, um.value) as meta
      FROM users u
               LEFT JOIN user_ips ui ON u.token = ui.userToken
               LEFT JOIN user_token_counts utc ON u.token = utc.userToken
               LEFT JOIN user_meta um ON u.token = um.userToken
    `;
    if (cursor) {
      sql += ` WHERE u.token < ?`;
      params.push(cursor);
    }
    sql += ` GROUP BY u.token ORDER BY u.token DESC LIMIT ?`;
    params.push(limit);
    return db
      .prepare(sql)
      .all(params)
      .map((r: any) => {
        const virtual: VirtualUser = {
          ...marshalUser(r),
          virtual: true,
          ipCount: r.ipCount ?? 0,
        };
        return virtual;
      });
  },

  /**
   * Upserts a user record by user token. Intended for use via the REST API,
   * prefer a more targeted method if possible. Undefined values are ignored,
   * null values are used to indicate that the field should be cleared.
   *
   * @param update - The user data to upsert, with `token` required.
   */
  upsertUser: (update: UserUpdate): void => {
    const db = getDatabase();
    // Re-enter inside a transaction so the multi-table writes are atomic.
    if (!db.inTransaction) {
      return db.transaction(() => UsersRepo.upsertUser(update))();
    }
    const updates: Partial<User> = {};
    for (const field of Object.entries(update)) {
      const [key, value] = field as [keyof User, any]; // assertion validated by zod
      if (value === undefined || key === "token") continue;
      updates[key] = value;
    }
    const setFields = Object.keys(updates)
      .map((key) => `${key} = :${key}`)
      .join(", ");
    const params = { ...updates, token: update.token };
    // scalars
    // NOTE(review): the VALUES clause names every column, but `params` only
    // contains the fields present in `update` — better-sqlite3 throws on a
    // missing named parameter, so a partial update of a new row looks like it
    // would fail. Verify intended behavior.
    const sql = `
      INSERT INTO users (token, nickname, type, createdAt, lastUsedAt, disabledAt, disabledReason, expiresAt, maxIps,
                         adminNote)
      VALUES (:token, :nickname, :type, :createdAt, :lastUsedAt, :disabledAt, :disabledReason, :expiresAt, :maxIps,
              :adminNote)
      ON CONFLICT(token) DO UPDATE SET ${setFields};
    `;
    db.prepare(sql).run(params);
    // replace ip addresses
    // NOTE(review): better-sqlite3 `prepare()` accepts only a single SQL
    // statement; this string bundles DELETE + INSERT (and mixes a named
    // `:token` with positional `?` binds), which looks like it would throw —
    // confirm, likely needs two separate prepared statements.
    if (update.ip) {
      const sql = `
        DELETE
        FROM user_ips
        WHERE userToken = :token;
        INSERT INTO user_ips (userToken, ip)
        VALUES ${update.ip.map(() => "(?, ?)").join(", ")};
      `;
      db.prepare(sql).run(
        update.ip.flatMap((ip: string) => [update.token, ip])
      );
    }
    if (update.modelTokenCounts) {
      const sql = `
        INSERT INTO user_token_counts (userToken, modelFamily, inputTokens, outputTokens, tokenLimit, prompts)
        VALUES (:token, :modelFamily, :inputTokens, :outputTokens, :tokenLimit, :prompts)
        ON CONFLICT(userToken, modelFamily) DO UPDATE SET inputTokens = :inputTokens,
                                                          outputTokens = :outputTokens,
                                                          tokenLimit = :tokenLimit,
                                                          prompts = :prompts;
      `;
      for (const [family, counts] of Object.entries(update.modelTokenCounts)) {
        db.prepare(sql).run({
          token: update.token,
          modelFamily: family,
          ...counts,
        });
      }
    }
    // NOTE(review): same multi-statement / mixed-bind concern as the IP
    // replacement above.
    if (update.meta) {
      const sql = `
        DELETE
        FROM user_meta
        WHERE userToken = :token;
        INSERT INTO user_meta (userToken, key, value)
        VALUES ${Object.keys(update.meta)
          .map(() => "(?, ?, ?)")
          .join(", ")};
      `;
      db.prepare(sql).run(
        Object.entries(update.meta).flatMap(([key, value]) => [
          update.token,
          key,
          value,
        ])
      );
    }
  },

  /**
   * Inserts or updates multiple user records in a single transaction.
   * Periodically commits the transaction and yields to the event loop to
   * prevent blocking the main thread for too long.
   * @param updates - The user data to upsert.
   */
  upsertUsers: async (updates: UserUpdate[]) => {
    const db = getDatabase();
    const BATCH_SIZE = 50;
    // Chunk the updates so each transaction stays small.
    const chunked = updates.reduce<UserUpdate[][]>((acc, _, i) => {
      if (i % BATCH_SIZE === 0) acc.push(updates.slice(i, i + BATCH_SIZE));
      return acc;
    }, []);
    const transaction = db.transaction((updates: UserUpdate[]) => {
      for (const update of updates) {
        UsersRepo.upsertUser(update);
      }
    });
    for (const chunk of chunked) {
      // Yield to the event loop between batches.
      await new Promise((resolve) => setTimeout(resolve, 0));
      transaction(chunk);
    }
  },

  /**
   * Increments the token usage counters for a user's token by the provided
   * values, and increments prompt count by 1.
   */
  incrementUsage(
    userToken: string,
    family: ModelFamily,
    input: number,
    output: number
  ) {
    const db = getDatabase();
    const sql = `
      INSERT INTO user_token_counts (userToken, modelFamily, inputTokens, outputTokens, tokenLimit, prompts)
      VALUES (:userToken, :modelFamily, :inputTokens, :outputTokens, 0, 1)
      ON CONFLICT(userToken, modelFamily) DO UPDATE SET inputTokens = inputTokens + :inputTokens,
                                                        outputTokens = outputTokens + :outputTokens,
                                                        prompts = prompts + 1;
    `;
    db.prepare(sql).run({
      userToken,
      modelFamily: family,
      inputTokens: input,
      outputTokens: output,
    });
  },

  /**
   * Disables user, optionally with reason.
   */
  disableUser(userToken: string, reason?: string) {
    const db = getDatabase();
    const disabledAt = Date.now();
    // NOTE(review): two statements in one `prepare()` — better-sqlite3 only
    // accepts a single statement per prepare; confirm this doesn't throw.
    const sql = `
      UPDATE users
      SET disabledAt = :disabledAt,
          disabledReason = :reason
      WHERE token = :userToken;
      INSERT OR REPLACE INTO user_meta (userToken, key, value)
      VALUES (:userToken, 'refreshable', 'false');
    `;
    db.prepare(sql).run({ userToken, disabledAt, reason });
  },

  /**
   * Restores quotas for a user by adding the provided token counts to their
   * existing counts.
   */
  refreshQuotas(
    userToken: string,
    tokensByFamily: Record<ModelFamily, number>
  ): void {
    const db = getDatabase();
    // Re-enter inside a transaction so all families update atomically.
    if (!db.inTransaction) {
      return db.transaction(() =>
        UsersRepo.refreshQuotas(userToken, tokensByFamily)
      )();
    }
    // for each provided family, increment the tokenLimit to equal inputTokens + outputTokens + refresh amount
    const sql = `
      INSERT INTO user_token_counts (userToken, modelFamily, inputTokens, outputTokens, tokenLimit, prompts)
      VALUES (:userToken, :modelFamily, 0, 0, :refreshAmount, 0)
      ON CONFLICT(userToken, modelFamily) DO UPDATE SET tokenLimit = inputTokens + outputTokens + :refreshAmount;
    `;
    for (const [family, tokens] of Object.entries(tokensByFamily)) {
      db.prepare(sql).run({
        userToken,
        modelFamily: family,
        refreshAmount: tokens,
      });
    }
  },

  /**
   * Resets token usage counters for a given user to zero.
   */
  resetUsage(userToken: string) {
    const db = getDatabase();
    const sql = `
      DELETE
      FROM user_token_counts
      WHERE userToken = :token
    `;
    db.prepare(sql).run({ token: userToken });
  },
};
/**
 * Reassembles a joined SQL row into a `User` object, parsing the JSON
 * aggregate columns and recomputing the legacy token fields
 * (`promptCount`, `tokenCount`, `tokenCounts`, `tokenLimits`) from
 * `modelTokenCounts`.
 *
 * Fix: `UserSchema` declares `tokenLimits` as required, but it was never
 * populated here before the `as User` cast; it is now derived from each
 * family's `limit`.
 */
function marshalUser(row: any): User {
  const user: Partial<User> = {
    token: row.token,
    nickname: row.nickname,
    type: row.type,
    createdAt: row.createdAt,
    lastUsedAt: row.lastUsedAt,
    disabledAt: row.disabledAt,
    disabledReason: row.disabledReason,
    expiresAt: row.expiresAt,
    maxIps: row.maxIps,
    adminNote: row.adminNote,
  };
  user.ip = row.ip ? JSON.parse(row.ip) : [];
  user.meta = row.meta ? JSON.parse(row.meta) : {};
  user.modelTokenCounts = JSON.parse(row.tokenCounts ?? "{}") as z.infer<
    typeof tokenCountsSchema
  >;
  // legacy token fields
  user.promptCount = 0;
  user.tokenCount = 0;
  user.tokenCounts = {} as z.infer<typeof tokenCountsSchemaOld>;
  user.tokenLimits = {} as z.infer<typeof tokenCountsSchemaOld>;
  if (row.tokenCounts) {
    // initialize missing model families
    for (const family of MODEL_FAMILIES) {
      if (!user.modelTokenCounts[family]) {
        user.modelTokenCounts[family] = {
          input: 0,
          output: 0,
          limit: 0,
          prompts: 0,
        };
      }
      // aggregate legacy fields
      user.promptCount += user.modelTokenCounts[family].prompts;
      user.tokenCount +=
        user.modelTokenCounts[family].input +
        user.modelTokenCounts[family].output;
      user.tokenCounts[family] =
        user.modelTokenCounts[family].input +
        user.modelTokenCounts[family].output;
      user.tokenLimits[family] = user.modelTokenCounts[family].limit;
    }
  }
  return user as User;
}
+2 -1
View File
@@ -9,7 +9,8 @@ export type APIFormat =
| "anthropic-chat" // Anthropic's newer messages array format
| "anthropic-text" // Legacy flat string prompt format
| "google-ai"
| "mistral-ai";
| "mistral-ai"
| "cohere-chat";
export interface Key {
/** The API key itself. Never log this, use `hash` instead. */
+18 -2
View File
@@ -14,7 +14,8 @@ export type LLMService =
| "google-ai"
| "mistral-ai"
| "aws"
| "azure";
| "azure"
| "cohere";
export type OpenAIModelFamily =
| "turbo"
@@ -32,13 +33,15 @@ export type MistralAIModelFamily =
| "mistral-large";
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
export type CohereModelFamily = "command-r" | "command-r-plus";
export type ModelFamily =
| OpenAIModelFamily
| AnthropicModelFamily
| GoogleAIModelFamily
| MistralAIModelFamily
| AwsBedrockModelFamily
| AzureOpenAIModelFamily;
| AzureOpenAIModelFamily
| CohereModelFamily;
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
@@ -64,6 +67,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
"azure-gpt4-turbo",
"azure-gpt4o",
"azure-dall-e",
"command-r",
"command-r-plus",
] as const);
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
@@ -75,6 +80,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
"mistral-ai",
"aws",
"azure",
"cohere",
] as const);
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
@@ -116,6 +122,8 @@ export const MODEL_FAMILY_SERVICE: {
"mistral-small": "mistral-ai",
"mistral-medium": "mistral-ai",
"mistral-large": "mistral-ai",
"command-r": "cohere",
"command-r-plus": "cohere",
};
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
@@ -181,6 +189,11 @@ export function getAzureOpenAIModelFamily(
return defaultFamily;
}
export function getCohereModelFamily(model: string): CohereModelFamily {
if (model.includes("plus")) return "command-r-plus";
return "command-r";
}
export function assertIsKnownModelFamily(
modelFamily: string
): asserts modelFamily is ModelFamily {
@@ -220,6 +233,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
case "mistral-ai":
modelFamily = getMistralAIModelFamily(model);
break;
case "cohere-chat":
modelFamily = getCohereModelFamily(model);
break;
default:
assertNever(req.outboundApi);
}
+1 -1
View File
@@ -1,6 +1,6 @@
import { config } from "../../config";
import type { EventLogEntry } from "../database";
import { eventsRepo } from "../database/repos/events";
import { eventsRepo } from "../database/repos/event";
export const logEvent = (payload: Omit<EventLogEntry, "date">) => {
if (!config.eventLogging) {
+71
View File
@@ -0,0 +1,71 @@
import { ZodType, z } from "zod";
import { MODEL_FAMILIES, ModelFamily } from "../models";
import { makeOptionalPropsNullable } from "../utils";
// This just dynamically creates a Zod object type with a key for each model
// family and an optional number value.
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
MODEL_FAMILIES.reduce(
(acc, family) => ({ ...acc, [family]: z.number().optional().default(0) }),
{} as Record<ModelFamily, ZodType<number>>
)
);
export const UserSchema = z
.object({
/** User's personal access token. */
token: z.string(),
/** IP addresses the user has connected from. */
ip: z.array(z.string()),
/** User's nickname. */
nickname: z.string().max(80).optional(),
/**
* The user's privilege level.
* - `normal`: Default role. Subject to usual rate limits and quotas.
* - `special`: Special role. Higher quotas and exempt from
* auto-ban/lockout.
**/
type: z.enum(["normal", "special", "temporary"]),
/** Number of prompts the user has made. */
promptCount: z.number(),
/**
* @deprecated Use `tokenCounts` instead.
* Never used; retained for backwards compatibility.
*/
tokenCount: z.any().optional(),
/** Number of tokens the user has consumed, by model family. */
tokenCounts: tokenCountsSchema,
/** Maximum number of tokens the user can consume, by model family. */
tokenLimits: tokenCountsSchema,
/** Time at which the user was created. */
createdAt: z.number(),
/** Time at which the user last connected. */
lastUsedAt: z.number().optional(),
/** Time at which the user was disabled, if applicable. */
disabledAt: z.number().optional(),
/** Reason for which the user was disabled, if applicable. */
disabledReason: z.string().optional(),
/** Time at which the user will expire and be disabled (for temp users). */
expiresAt: z.number().optional(),
    /** The user's maximum number of IP addresses; supersedes global max. */
maxIps: z.coerce.number().int().min(0).optional(),
/** Private note about the user. */
adminNote: z.string().optional(),
meta: z.record(z.any()).optional(),
})
.strict();
/**
* Variant of `UserSchema` which allows for partial updates, and makes any
* optional properties on the base schema nullable. Null values are used to
* indicate that the property should be deleted from the user object.
*/
export const UserPartialSchema = makeOptionalPropsNullable(UserSchema)
.partial()
.extend({ token: z.string() });
export type UserTokenCounts = {
[K in ModelFamily]: number | undefined;
};
export type User = z.infer<typeof UserSchema>;
export type UserUpdate = z.infer<typeof UserPartialSchema>;
+1 -1
View File
@@ -22,9 +22,9 @@ import {
ModelFamily,
} from "../models";
import { logger } from "../../logger";
import { User, UserTokenCounts, UserUpdate } from "./schema";
import { APIFormat } from "../key-management";
import { assertNever } from "../utils";
import { User, UserTokenCounts, UserUpdate } from "../database/repos/users";
const log = logger.child({ module: "users" });
+1 -1
View File
@@ -1,9 +1,9 @@
import { Router } from "express";
import { UserPartialSchema } from "../../shared/users/schema";
import * as userStore from "../../shared/users/user-store";
import { ForbiddenError, BadRequestError } from "../../shared/errors";
import { sanitizeAndTrim } from "../../shared/utils";
import { config } from "../../config";
import { UserPartialSchema } from "../../shared/database/repos/users";
const router = Router();