Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8d5059534a | |||
| 0ea43f61c2 | |||
| ca4321b4cb | |||
| 7660ed8b94 | |||
| 55f1bbed3b | |||
| 57fd17ede0 | |||
| 9d00b8a9de | |||
| 155e185c6e |
+15
-6
@@ -46,6 +46,14 @@ NODE_ENV=production
|
|||||||
# 'azure-dall-e' to the list of allowed model families.
|
# 'azure-dall-e' to the list of allowed model families.
|
||||||
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
|
# ALLOWED_MODEL_FAMILIES=turbo,gpt4,gpt4-32k,gpt4-turbo,gpt4o,claude,claude-opus,gemini-pro,mistral-tiny,mistral-small,mistral-medium,mistral-large,aws-claude,aws-claude-opus,azure-turbo,azure-gpt4,azure-gpt4-32k,azure-gpt4-turbo,azure-gpt4o
|
||||||
|
|
||||||
|
# Which services can be used to process prompts containing images via multimodal
|
||||||
|
# models. The following services are recognized:
|
||||||
|
# openai | anthropic | aws | azure | google-ai | mistral-ai
|
||||||
|
# Do not enable this feature unless all users are trusted, as you will be liable
|
||||||
|
# for any user-submitted images containing illegal content.
|
||||||
|
# By default, no image services are allowed and image prompts are rejected.
|
||||||
|
# ALLOWED_VISION_SERVICES=
|
||||||
|
|
||||||
# IP addresses or CIDR blocks from which requests will be blocked.
|
# IP addresses or CIDR blocks from which requests will be blocked.
|
||||||
# IP_BLACKLIST=10.0.0.1/24
|
# IP_BLACKLIST=10.0.0.1/24
|
||||||
# URLs from which requests will be blocked.
|
# URLs from which requests will be blocked.
|
||||||
@@ -60,7 +68,7 @@ NODE_ENV=production
|
|||||||
# Avoid short or common phrases as this tests the entire prompt.
|
# Avoid short or common phrases as this tests the entire prompt.
|
||||||
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
|
# REJECT_PHRASES="phrase one,phrase two,"phrase three, which has a comma",phrase four"
|
||||||
# Message to show when requests are rejected.
|
# Message to show when requests are rejected.
|
||||||
# REJECT_MESSAGE="This content violates /aicg/'s acceptable use policy."
|
# REJECT_MESSAGE="You can't say that here."
|
||||||
|
|
||||||
# Whether prompts should be logged to Google Sheets.
|
# Whether prompts should be logged to Google Sheets.
|
||||||
# Requires additional setup. See `docs/google-sheets.md` for more information.
|
# Requires additional setup. See `docs/google-sheets.md` for more information.
|
||||||
@@ -102,18 +110,19 @@ NODE_ENV=production
|
|||||||
# ALLOW_NICKNAME_CHANGES=true
|
# ALLOW_NICKNAME_CHANGES=true
|
||||||
|
|
||||||
# Default token quotas for each model family. (0 for unlimited)
|
# Default token quotas for each model family. (0 for unlimited)
|
||||||
# DALL-E "tokens" are counted at a rate of 100000 tokens per US$1.00 generated,
|
# Specify as TOKEN_QUOTA_MODEL_FAMILY=value, replacing dashes with underscores.
|
||||||
# which is similar to the cost of GPT-4 Turbo.
|
|
||||||
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
|
|
||||||
# See `docs/dall-e-configuration.md` for more information.
|
|
||||||
# TOKEN_QUOTA_TURBO=0
|
# TOKEN_QUOTA_TURBO=0
|
||||||
# TOKEN_QUOTA_GPT4=0
|
# TOKEN_QUOTA_GPT4=0
|
||||||
# TOKEN_QUOTA_GPT4_32K=0
|
# TOKEN_QUOTA_GPT4_32K=0
|
||||||
# TOKEN_QUOTA_GPT4_TURBO=0
|
# TOKEN_QUOTA_GPT4_TURBO=0
|
||||||
# TOKEN_QUOTA_DALL_E=0
|
|
||||||
# TOKEN_QUOTA_CLAUDE=0
|
# TOKEN_QUOTA_CLAUDE=0
|
||||||
# TOKEN_QUOTA_GEMINI_PRO=0
|
# TOKEN_QUOTA_GEMINI_PRO=0
|
||||||
# TOKEN_QUOTA_AWS_CLAUDE=0
|
# TOKEN_QUOTA_AWS_CLAUDE=0
|
||||||
|
# "Tokens" for image-generation models are counted at a rate of 100000 tokens
|
||||||
|
# per US$1.00 of images generated, which is similar to the cost of GPT-4 Turbo.
|
||||||
|
# DALL-E 3 costs around US$0.10 per image (10000 tokens).
|
||||||
|
# See `docs/dall-e-configuration.md` for more information.
|
||||||
|
# TOKEN_QUOTA_DALL_E=0
|
||||||
|
|
||||||
# How often to refresh token quotas. (hourly | daily)
|
# How often to refresh token quotas. (hourly | daily)
|
||||||
# Leave unset to never automatically refresh quotas.
|
# Leave unset to never automatically refresh quotas.
|
||||||
|
|||||||
@@ -70,4 +70,4 @@ You can provide a comma-separated list containing individual IPv4 or IPv6 addres
|
|||||||
|
|
||||||
To whitelist an entire IP range, use CIDR notation. For example, `192.168.0.1/24` would whitelist all addresses from `192.168.0.0` to `192.168.0.255`.
|
To whitelist an entire IP range, use CIDR notation. For example, `192.168.0.1/24` would whitelist all addresses from `192.168.0.0` to `192.168.0.255`.
|
||||||
|
|
||||||
To disable the whitelist, set `ADMIN_WHITELIST=0.0.0.0/0`, which will allow access from any IP address. This is the default behavior.
|
To disable the whitelist, set `ADMIN_WHITELIST=0.0.0.0/0,::/0`, which will allow access from any IPv4 or IPv6 address. This is the default behavior.
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { Router } from "express";
|
import { Router } from "express";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { encodeCursor, decodeCursor } from "../../shared/utils";
|
import { encodeCursor, decodeCursor } from "../../shared/utils";
|
||||||
import { eventsRepo } from "../../shared/database/repos/events";
|
import { eventsRepo } from "../../shared/database/repos/event";
|
||||||
|
|
||||||
const router = Router();
|
const router = Router();
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { Router } from "express";
|
|||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import * as userStore from "../../shared/users/user-store";
|
import * as userStore from "../../shared/users/user-store";
|
||||||
import { parseSort, sortBy } from "../../shared/utils";
|
import { parseSort, sortBy } from "../../shared/utils";
|
||||||
import { UserPartialSchema, UserSchema } from "../../shared/database/repos/users";
|
import { UserPartialSchema, UserSchema } from "../../shared/users/schema";
|
||||||
|
|
||||||
const router = Router();
|
const router = Router();
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,15 @@ import { parseSort, sortBy, paginate } from "../../shared/utils";
|
|||||||
import { keyPool } from "../../shared/key-management";
|
import { keyPool } from "../../shared/key-management";
|
||||||
import { LLMService, MODEL_FAMILIES } from "../../shared/models";
|
import { LLMService, MODEL_FAMILIES } from "../../shared/models";
|
||||||
import { getTokenCostUsd, prettyTokens } from "../../shared/stats";
|
import { getTokenCostUsd, prettyTokens } from "../../shared/stats";
|
||||||
|
import {
|
||||||
|
User,
|
||||||
|
UserPartialSchema,
|
||||||
|
UserSchema,
|
||||||
|
UserTokenCounts,
|
||||||
|
} from "../../shared/users/schema";
|
||||||
import { getLastNImages } from "../../shared/file-storage/image-history";
|
import { getLastNImages } from "../../shared/file-storage/image-history";
|
||||||
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
|
import { blacklists, parseCidrs, whitelists } from "../../shared/cidr";
|
||||||
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
|
import { invalidatePowHmacKey } from "../../user/web/pow-captcha";
|
||||||
import { User, UserPartialSchema, UserSchema, UserTokenCounts } from "../../shared/database/repos/users";
|
|
||||||
|
|
||||||
const router = Router();
|
const router = Router();
|
||||||
|
|
||||||
|
|||||||
+28
-8
@@ -3,7 +3,7 @@ import dotenv from "dotenv";
|
|||||||
import type firebase from "firebase-admin";
|
import type firebase from "firebase-admin";
|
||||||
import path from "path";
|
import path from "path";
|
||||||
import pino from "pino";
|
import pino from "pino";
|
||||||
import type { ModelFamily } from "./shared/models";
|
import type { LLMService, ModelFamily } from "./shared/models";
|
||||||
import { MODEL_FAMILIES } from "./shared/models";
|
import { MODEL_FAMILIES } from "./shared/models";
|
||||||
|
|
||||||
dotenv.config();
|
dotenv.config();
|
||||||
@@ -340,13 +340,18 @@ type Config = {
|
|||||||
*/
|
*/
|
||||||
allowOpenAIToolUsage?: boolean;
|
allowOpenAIToolUsage?: boolean;
|
||||||
/**
|
/**
|
||||||
* Whether to allow prompts containing images, for use with multimodal models.
|
* Which services will accept prompts containing images, for use with
|
||||||
* Avoid giving this to untrusted users, as they can submit illegal content.
|
* multimodal models. Users with `special` role are exempt from this
|
||||||
|
* restriction.
|
||||||
*
|
*
|
||||||
* Applies to GPT-4 Vision and Claude Vision. Users with `special` role are
|
* Do not enable this feature for untrusted users, as malicious users could
|
||||||
* exempt from this restriction.
|
* send images which violate your provider's terms of service or local laws.
|
||||||
|
*
|
||||||
|
* Defaults to no services, meaning image prompts are disabled. Use a comma-
|
||||||
|
* separated list. Available services are:
|
||||||
|
* openai,anthropic,google-ai,mistral-ai,aws,azure
|
||||||
*/
|
*/
|
||||||
allowImagePrompts?: boolean;
|
allowedVisionServices: LLMService[];
|
||||||
/**
|
/**
|
||||||
* Allows overriding the default proxy endpoint route. Defaults to /proxy.
|
* Allows overriding the default proxy endpoint route. Defaults to /proxy.
|
||||||
* A leading slash is required.
|
* A leading slash is required.
|
||||||
@@ -479,9 +484,13 @@ export const config: Config = {
|
|||||||
staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
|
staticServiceInfo: getEnvWithDefault("STATIC_SERVICE_INFO", false),
|
||||||
trustedProxies: getEnvWithDefault("TRUSTED_PROXIES", 1),
|
trustedProxies: getEnvWithDefault("TRUSTED_PROXIES", 1),
|
||||||
allowOpenAIToolUsage: getEnvWithDefault("ALLOW_OPENAI_TOOL_USAGE", false),
|
allowOpenAIToolUsage: getEnvWithDefault("ALLOW_OPENAI_TOOL_USAGE", false),
|
||||||
allowImagePrompts: getEnvWithDefault("ALLOW_IMAGE_PROMPTS", false),
|
allowedVisionServices: parseCsv(
|
||||||
|
getEnvWithDefault("ALLOWED_VISION_SERVICES", "")
|
||||||
|
) as LLMService[],
|
||||||
proxyEndpointRoute: getEnvWithDefault("PROXY_ENDPOINT_ROUTE", "/proxy"),
|
proxyEndpointRoute: getEnvWithDefault("PROXY_ENDPOINT_ROUTE", "/proxy"),
|
||||||
adminWhitelist: parseCsv(getEnvWithDefault("ADMIN_WHITELIST", "0.0.0.0/0")),
|
adminWhitelist: parseCsv(
|
||||||
|
getEnvWithDefault("ADMIN_WHITELIST", "0.0.0.0/0,::/0")
|
||||||
|
),
|
||||||
ipBlacklist: parseCsv(getEnvWithDefault("IP_BLACKLIST", "")),
|
ipBlacklist: parseCsv(getEnvWithDefault("IP_BLACKLIST", "")),
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
@@ -534,6 +543,17 @@ export async function assertConfigIsValid() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (process.env.ALLOW_IMAGE_PROMPTS === "true") {
|
||||||
|
const hasAllowedServices = config.allowedVisionServices.length > 0;
|
||||||
|
if (!hasAllowedServices) {
|
||||||
|
config.allowedVisionServices = ["openai", "anthropic"];
|
||||||
|
startupLogger.warn(
|
||||||
|
{ allowedVisionServices: config.allowedVisionServices },
|
||||||
|
"ALLOW_IMAGE_PROMPTS is deprecated. Use ALLOWED_VISION_SERVICES instead."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (config.promptLogging && !config.promptLoggingBackend) {
|
if (config.promptLogging && !config.promptLoggingBackend) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"Prompt logging is enabled but no backend is configured. Set PROMPT_LOGGING_BACKEND to 'google_sheets' or 'file'."
|
"Prompt logging is enabled but no backend is configured. Set PROMPT_LOGGING_BACKEND to 'google_sheets' or 'file'."
|
||||||
|
|||||||
@@ -66,7 +66,8 @@ export const gatekeeper: RequestHandler = (req, res, next) => {
|
|||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
403,
|
403,
|
||||||
"Forbidden: no more IPs can authenticate with this user token"
|
`Forbidden: no more IP addresses allowed for this user token`,
|
||||||
|
{ currentIp: ip, maxIps: user?.maxIps }
|
||||||
);
|
);
|
||||||
case "disabled":
|
case "disabled":
|
||||||
const bannedUser = getUser(token);
|
const bannedUser = getUser(token);
|
||||||
@@ -84,7 +85,8 @@ function sendError(
|
|||||||
req: Request,
|
req: Request,
|
||||||
res: Response,
|
res: Response,
|
||||||
status: number,
|
status: number,
|
||||||
message: string
|
message: string,
|
||||||
|
data: any = {}
|
||||||
) {
|
) {
|
||||||
const isPost = req.method === "POST";
|
const isPost = req.method === "POST";
|
||||||
const hasBody = isPost && req.body;
|
const hasBody = isPost && req.body;
|
||||||
@@ -103,6 +105,7 @@ function sendError(
|
|||||||
format: "unknown",
|
format: "unknown",
|
||||||
statusCode: status,
|
statusCode: status,
|
||||||
reqId: req.id,
|
reqId: req.id,
|
||||||
|
obj: data,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,9 +9,14 @@ import { ForbiddenError } from "../../../../shared/errors";
|
|||||||
* Rejects prompts containing images if multimodal prompts are disabled.
|
* Rejects prompts containing images unless vision is allowed for the request's service.
|
||||||
*/
|
*/
|
||||||
export const validateVision: RequestPreprocessor = async (req) => {
|
export const validateVision: RequestPreprocessor = async (req) => {
|
||||||
if (config.allowImagePrompts) return;
|
if (req.service === undefined) {
|
||||||
if (req.user?.type === "special") return;
|
throw new Error("Request service must be set before validateVision");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.user?.type === "special") return;
|
||||||
|
if (config.allowedVisionServices.includes(req.service)) return;
|
||||||
|
|
||||||
|
// vision not allowed for req's service, block prompts with images
|
||||||
let hasImage = false;
|
let hasImage = false;
|
||||||
switch (req.outboundApi) {
|
switch (req.outboundApi) {
|
||||||
case "openai":
|
case "openai":
|
||||||
|
|||||||
@@ -52,7 +52,13 @@ function getMessageContent({
|
|||||||
delete obj.stack;
|
delete obj.stack;
|
||||||
}
|
}
|
||||||
|
|
||||||
return [header, friendlyMessage, serializedObj, prettyTrace].join("\n\n");
|
return [
|
||||||
|
header,
|
||||||
|
friendlyMessage,
|
||||||
|
serializedObj,
|
||||||
|
prettyTrace,
|
||||||
|
"<!-- oai-proxy-error -->",
|
||||||
|
].join("\n\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
type ErrorGeneratorOptions = {
|
type ErrorGeneratorOptions = {
|
||||||
@@ -116,6 +122,11 @@ export function sendErrorToClient({
|
|||||||
const isStreaming =
|
const isStreaming =
|
||||||
req.isStreaming || req.body.stream === true || req.body.stream === "true";
|
req.isStreaming || req.body.stream === true || req.body.stream === "true";
|
||||||
|
|
||||||
|
if (!res.headersSent) {
|
||||||
|
res.setHeader("x-oai-proxy-error", options.title);
|
||||||
|
res.setHeader("x-oai-proxy-error-status", options.statusCode || 500);
|
||||||
|
}
|
||||||
|
|
||||||
if (isStreaming) {
|
if (isStreaming) {
|
||||||
if (!res.headersSent) {
|
if (!res.headersSent) {
|
||||||
initializeSseStream(res);
|
initializeSseStream(res);
|
||||||
|
|||||||
+1
-4
@@ -179,10 +179,7 @@ function cleanup() {
|
|||||||
process.exit(0);
|
process.exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
process.on("exit", () => cleanup());
|
process.on("SIGINT", cleanup);
|
||||||
process.on("SIGHUP", () => process.exit(128 + 1));
|
|
||||||
process.on("SIGINT", () => process.exit(128 + 2));
|
|
||||||
process.on("SIGTERM", () => process.exit(128 + 15));
|
|
||||||
|
|
||||||
function registerUncaughtExceptionHandler() {
|
function registerUncaughtExceptionHandler() {
|
||||||
process.on("uncaughtException", (err: any) => {
|
process.on("uncaughtException", (err: any) => {
|
||||||
|
|||||||
@@ -119,7 +119,8 @@ export const transformOpenAIToAnthropicChat: APIFormatTransformer<
|
|||||||
stream: rest.stream,
|
stream: rest.stream,
|
||||||
temperature: rest.temperature,
|
temperature: rest.temperature,
|
||||||
top_p: rest.top_p,
|
top_p: rest.top_p,
|
||||||
stop_sequences: typeof rest.stop === "string" ? [rest.stop] : rest.stop,
|
stop_sequences:
|
||||||
|
typeof rest.stop === "string" ? [rest.stop] : rest.stop || undefined,
|
||||||
...(rest.user ? { metadata: { user_id: rest.user } } : {}),
|
...(rest.user ? { metadata: { user_id: rest.user } } : {}),
|
||||||
// Anthropic supports top_k, but OpenAI does not
|
// Anthropic supports top_k, but OpenAI does not
|
||||||
// OpenAI supports frequency_penalty, presence_penalty, logit_bias, n, seed,
|
// OpenAI supports frequency_penalty, presence_penalty, logit_bias, n, seed,
|
||||||
|
|||||||
@@ -0,0 +1,181 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
import {
|
||||||
|
OPENAI_OUTPUT_MAX,
|
||||||
|
OpenAIV1ChatCompletionSchema,
|
||||||
|
flattenOpenAIMessageContent,
|
||||||
|
} from "./openai";
|
||||||
|
import { APIFormatTransformer } from ".";
|
||||||
|
|
||||||
|
// https://docs.cohere.com/reference/chat
|
||||||
|
export const CohereV1ChatSchema = z
|
||||||
|
.object({
|
||||||
|
message: z.string(),
|
||||||
|
model: z.string().default("command-r-plus"),
|
||||||
|
stream: z.boolean().default(false).optional(),
|
||||||
|
preamble: z.string().optional(),
|
||||||
|
chat_history: z
|
||||||
|
.array(
|
||||||
|
// Either a message from a chat participant, or a past tool call
|
||||||
|
z.union([
|
||||||
|
z.object({
|
||||||
|
role: z.enum(["CHATBOT", "SYSTEM", "USER"]),
|
||||||
|
message: z.string(),
|
||||||
|
tool_calls: z
|
||||||
|
.array(z.object({ name: z.string(), parameters: z.any() }))
|
||||||
|
.optional(),
|
||||||
|
}),
|
||||||
|
z.object({
|
||||||
|
role: z.enum(["TOOL"]),
|
||||||
|
tool_results: z.array(
|
||||||
|
z.object({
|
||||||
|
call: z.object({ name: z.string(), parameters: z.any() }),
|
||||||
|
outputs: z.array(z.any()),
|
||||||
|
})
|
||||||
|
),
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
)
|
||||||
|
.optional(),
|
||||||
|
// Don't allow conversation_id as it causes calls to be stateful and we don't
|
||||||
|
// offer guarantees about which key a user's request will be routed to.
|
||||||
|
conversation_id: z.literal(undefined).optional(),
|
||||||
|
prompt_truncation: z
|
||||||
|
.enum(["AUTO", "AUTO_PRESERVE_ORDER", "OFF"])
|
||||||
|
.optional(),
|
||||||
|
/*
|
||||||
|
Supporting RAG is complex because documents can be arbitrary size and have
|
||||||
|
to have embeddings generated, which incurs a cost that is not trivial to
|
||||||
|
estimate. We don't support it for now.
|
||||||
|
connectors: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
id: z.string(),
|
||||||
|
user_access_token: z.string().optional(),
|
||||||
|
continue_on_failure: z.boolean().default(false).optional(),
|
||||||
|
options: z.any().optional(),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.optional(),
|
||||||
|
search_queries_only: z.boolean().default(false).optional(),
|
||||||
|
documents: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
id: z.string().optional(),
|
||||||
|
title: z.string().optional(),
|
||||||
|
text: z.string(),
|
||||||
|
_excludes: z.array(z.string()).optional(),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.optional(),
|
||||||
|
citation_quality: z.enum(["accurate", "fast"]).optional(),
|
||||||
|
*/
|
||||||
|
temperature: z.number().default(0.3).optional(),
|
||||||
|
max_tokens: z
|
||||||
|
.number()
|
||||||
|
.int()
|
||||||
|
.nullish()
|
||||||
|
.default(Math.min(OPENAI_OUTPUT_MAX, 4096))
|
||||||
|
.transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)),
|
||||||
|
max_input_tokens: z.number().int().optional(),
|
||||||
|
k: z.number().int().min(0).max(500).default(0).optional(),
|
||||||
|
p: z.number().min(0.01).max(0.99).default(0.75).optional(),
|
||||||
|
seed: z.number().int().optional(),
|
||||||
|
stop_sequences: z.array(z.string()).max(5).optional(),
|
||||||
|
frequency_penalty: z.number().min(0).max(1).default(0).optional(),
|
||||||
|
presence_penalty: z.number().min(0).max(1).default(0).optional(),
|
||||||
|
tools: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
name: z.string(),
|
||||||
|
description: z.string(),
|
||||||
|
parameter_definitions: z.record(
|
||||||
|
z.object({
|
||||||
|
description: z.string().optional(),
|
||||||
|
type: z.string(),
|
||||||
|
required: z.boolean().optional().default(false),
|
||||||
|
})
|
||||||
|
),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.optional(),
|
||||||
|
tool_results: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
call: z.object({
|
||||||
|
name: z.string(),
|
||||||
|
parameters: z.record(z.any()),
|
||||||
|
}),
|
||||||
|
outputs: z.array(z.record(z.any())),
|
||||||
|
})
|
||||||
|
)
|
||||||
|
.optional(),
|
||||||
|
// We always force single step to avoid stateful calls or expensive multi-step
|
||||||
|
// generations when tools are involved.
|
||||||
|
force_single_step: z.literal(true).default(true).optional(),
|
||||||
|
})
|
||||||
|
.strip();
|
||||||
|
export type CohereChatMessage = NonNullable<
|
||||||
|
z.infer<typeof CohereV1ChatSchema>["chat_history"]
|
||||||
|
>[number];
|
||||||
|
|
||||||
|
export function flattenCohereMessageContent(
|
||||||
|
message: CohereChatMessage
|
||||||
|
): string {
|
||||||
|
return message.role === "TOOL"
|
||||||
|
? message.tool_results.map((r) => r.outputs[0].text).join("\n")
|
||||||
|
: message.message;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const transformOpenAIToCohere: APIFormatTransformer<
|
||||||
|
typeof CohereV1ChatSchema
|
||||||
|
> = async (req) => {
|
||||||
|
const { body } = req;
|
||||||
|
const result = OpenAIV1ChatCompletionSchema.safeParse({
|
||||||
|
...body,
|
||||||
|
model: "gpt-3.5-turbo",
|
||||||
|
});
|
||||||
|
if (!result.success) {
|
||||||
|
req.log.warn(
|
||||||
|
{ issues: result.error.issues, body },
|
||||||
|
"Invalid OpenAI-to-Cohere request"
|
||||||
|
);
|
||||||
|
throw result.error;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { messages, ...rest } = result.data;
|
||||||
|
// Final OAI message becomes the `message` field in Cohere
|
||||||
|
const message = messages[messages.length - 1];
|
||||||
|
// If the first message has system role, use it as preamble.
|
||||||
|
const hasSystemPreamble = messages[0]?.role === "system";
|
||||||
|
const preamble = hasSystemPreamble
|
||||||
|
? flattenOpenAIMessageContent(messages[0].content)
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
const chatHistory = messages.slice(0, -1).map((m) => {
|
||||||
|
const role: Exclude<CohereChatMessage["role"], "TOOL"> =
|
||||||
|
m.role === "assistant"
|
||||||
|
? "CHATBOT"
|
||||||
|
: m.role === "system"
|
||||||
|
? "SYSTEM"
|
||||||
|
: "USER";
|
||||||
|
const content = flattenOpenAIMessageContent(m.content);
|
||||||
|
const message = m.name ? `${m.name}: ${content}` : content;
|
||||||
|
return { role, message };
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
model: rest.model,
|
||||||
|
preamble,
|
||||||
|
chat_history: chatHistory,
|
||||||
|
message: flattenOpenAIMessageContent(message.content),
|
||||||
|
stop_sequences:
|
||||||
|
typeof rest.stop === "string" ? [rest.stop] : rest.stop ?? undefined,
|
||||||
|
max_tokens: rest.max_tokens,
|
||||||
|
temperature: rest.temperature,
|
||||||
|
p: rest.top_p,
|
||||||
|
frequency_penalty: rest.frequency_penalty,
|
||||||
|
presence_penalty: rest.presence_penalty,
|
||||||
|
seed: rest.seed,
|
||||||
|
stream: rest.stream,
|
||||||
|
};
|
||||||
|
};
|
||||||
@@ -22,6 +22,7 @@ import {
|
|||||||
transformOpenAIToGoogleAI,
|
transformOpenAIToGoogleAI,
|
||||||
} from "./google-ai";
|
} from "./google-ai";
|
||||||
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
|
import { MistralAIV1ChatCompletionsSchema } from "./mistral-ai";
|
||||||
|
import { CohereV1ChatSchema, transformOpenAIToCohere } from "./cohere";
|
||||||
|
|
||||||
export { OpenAIChatMessage } from "./openai";
|
export { OpenAIChatMessage } from "./openai";
|
||||||
export {
|
export {
|
||||||
@@ -33,15 +34,29 @@ export {
|
|||||||
export { GoogleAIChatMessage } from "./google-ai";
|
export { GoogleAIChatMessage } from "./google-ai";
|
||||||
export { MistralAIChatMessage } from "./mistral-ai";
|
export { MistralAIChatMessage } from "./mistral-ai";
|
||||||
|
|
||||||
|
/** Represents a pair of API formats that can be transformed between. */
|
||||||
type APIPair = `${APIFormat}->${APIFormat}`;
|
type APIPair = `${APIFormat}->${APIFormat}`;
|
||||||
|
/** Represents a map of API format pairs to transformer functions. */
|
||||||
type TransformerMap = {
|
type TransformerMap = {
|
||||||
[key in APIPair]?: APIFormatTransformer<any>;
|
[key in APIPair]?: APIFormatTransformer<any>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Represents a transformer function that takes a Request and returns a Promise
|
||||||
|
* resolving to a value of the specified Zod schema type.
|
||||||
|
*
|
||||||
|
* @template Z The Zod schema type to transform the request into (from api-schemas).
|
||||||
|
* @param req The incoming Request to transform.
|
||||||
|
* @returns A Promise resolving to the transformed request body.
|
||||||
|
*/
|
||||||
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
|
export type APIFormatTransformer<Z extends z.ZodType<any, any>> = (
|
||||||
req: Request
|
req: Request
|
||||||
) => Promise<z.infer<Z>>;
|
) => Promise<z.infer<Z>>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specifies possible translations between API formats and the corresponding
|
||||||
|
* transformer functions to apply them.
|
||||||
|
*/
|
||||||
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
|
export const API_REQUEST_TRANSFORMERS: TransformerMap = {
|
||||||
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
|
"anthropic-text->anthropic-chat": transformAnthropicTextToAnthropicChat,
|
||||||
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
|
"openai->anthropic-chat": transformOpenAIToAnthropicChat,
|
||||||
@@ -49,8 +64,12 @@ export const API_REQUEST_TRANSFORMERS: TransformerMap = {
|
|||||||
"openai->openai-text": transformOpenAIToOpenAIText,
|
"openai->openai-text": transformOpenAIToOpenAIText,
|
||||||
"openai->openai-image": transformOpenAIToOpenAIImage,
|
"openai->openai-image": transformOpenAIToOpenAIImage,
|
||||||
"openai->google-ai": transformOpenAIToGoogleAI,
|
"openai->google-ai": transformOpenAIToGoogleAI,
|
||||||
|
"openai->cohere-chat": transformOpenAIToCohere,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Specifies the schema for each API format to validate incoming requests.
|
||||||
|
*/
|
||||||
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
||||||
"anthropic-chat": AnthropicV1MessagesSchema,
|
"anthropic-chat": AnthropicV1MessagesSchema,
|
||||||
"anthropic-text": AnthropicV1TextSchema,
|
"anthropic-text": AnthropicV1TextSchema,
|
||||||
@@ -59,4 +78,5 @@ export const API_REQUEST_VALIDATORS: Record<APIFormat, z.ZodSchema<any>> = {
|
|||||||
"openai-image": OpenAIV1ImagesGenerationSchema,
|
"openai-image": OpenAIV1ImagesGenerationSchema,
|
||||||
"google-ai": GoogleAIV1GenerateContentSchema,
|
"google-ai": GoogleAIV1GenerateContentSchema,
|
||||||
"mistral-ai": MistralAIV1ChatCompletionsSchema,
|
"mistral-ai": MistralAIV1ChatCompletionsSchema,
|
||||||
|
"cohere-chat": CohereV1ChatSchema,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ export const OpenAIV1ChatCompletionSchema = z
|
|||||||
stream: z.boolean().optional().default(false),
|
stream: z.boolean().optional().default(false),
|
||||||
stop: z
|
stop: z
|
||||||
.union([z.string().max(500), z.array(z.string().max(500))])
|
.union([z.string().max(500), z.array(z.string().max(500))])
|
||||||
.optional(),
|
.nullish(),
|
||||||
max_tokens: z.coerce
|
max_tokens: z.coerce
|
||||||
.number()
|
.number()
|
||||||
.int()
|
.int()
|
||||||
|
|||||||
Vendored
+1
-1
@@ -3,8 +3,8 @@
|
|||||||
import type { HttpRequest } from "@smithy/types";
|
import type { HttpRequest } from "@smithy/types";
|
||||||
import { Express } from "express-serve-static-core";
|
import { Express } from "express-serve-static-core";
|
||||||
import { APIFormat, Key } from "./key-management";
|
import { APIFormat, Key } from "./key-management";
|
||||||
|
import { User } from "./users/schema";
|
||||||
import { LLMService, ModelFamily } from "./models";
|
import { LLMService, ModelFamily } from "./models";
|
||||||
import { User } from "./database/repos/users";
|
|
||||||
|
|
||||||
declare global {
|
declare global {
|
||||||
namespace Express {
|
namespace Express {
|
||||||
|
|||||||
@@ -23,11 +23,7 @@ export async function initializeDatabase() {
|
|||||||
log.info("Initializing database...");
|
log.info("Initializing database...");
|
||||||
|
|
||||||
const sqlite3 = await import("better-sqlite3");
|
const sqlite3 = await import("better-sqlite3");
|
||||||
database = sqlite3.default(config.sqliteDataPath, {
|
database = sqlite3.default(config.sqliteDataPath);
|
||||||
verbose: process.env.SQLITE_VERBOSE === "true"
|
|
||||||
? (msg, ...args) => log.debug({ args }, String(msg))
|
|
||||||
: undefined,
|
|
||||||
});
|
|
||||||
migrateDatabase();
|
migrateDatabase();
|
||||||
database.pragma("journal_mode = WAL");
|
database.pragma("journal_mode = WAL");
|
||||||
log.info("Database initialized.");
|
log.info("Database initialized.");
|
||||||
@@ -90,5 +86,4 @@ function assertNumber(value: unknown): asserts value is number {
|
|||||||
throw new Error("Expected number");
|
throw new Error("Expected number");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
export { EventLogEntry } from "./repos/event";
|
||||||
export { EventLogEntry } from "./repos/events";
|
|
||||||
|
|||||||
@@ -58,65 +58,4 @@ export const migrations = [
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "add users schema",
|
|
||||||
version: 4,
|
|
||||||
up: (db) => {
|
|
||||||
// language=SQLite
|
|
||||||
const sql = `
|
|
||||||
CREATE TABLE IF NOT EXISTS users
|
|
||||||
(
|
|
||||||
token TEXT PRIMARY KEY NOT NULL,
|
|
||||||
nickname TEXT,
|
|
||||||
type TEXT CHECK (type IN ('normal', 'special', 'temporary')) NOT NULL,
|
|
||||||
createdAt INTEGER NOT NULL,
|
|
||||||
lastUsedAt INTEGER,
|
|
||||||
disabledAt INTEGER,
|
|
||||||
disabledReason TEXT,
|
|
||||||
expiresAt INTEGER,
|
|
||||||
maxIps INTEGER,
|
|
||||||
adminNote TEXT
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS user_ips
|
|
||||||
(
|
|
||||||
userToken TEXT NOT NULL,
|
|
||||||
ip TEXT NOT NULL,
|
|
||||||
PRIMARY KEY (userToken, ip),
|
|
||||||
FOREIGN KEY (userToken) REFERENCES users (token)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS user_token_counts
|
|
||||||
(
|
|
||||||
userToken TEXT NOT NULL,
|
|
||||||
modelFamily TEXT NOT NULL,
|
|
||||||
inputTokens INTEGER NOT NULL,
|
|
||||||
outputTokens INTEGER NOT NULL,
|
|
||||||
tokenLimit INTEGER NOT NULL,
|
|
||||||
prompts INTEGER NOT NULL,
|
|
||||||
PRIMARY KEY (userToken, modelFamily)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS user_meta
|
|
||||||
(
|
|
||||||
userToken TEXT NOT NULL,
|
|
||||||
key TEXT NOT NULL,
|
|
||||||
value TEXT NOT NULL,
|
|
||||||
PRIMARY KEY (userToken, key),
|
|
||||||
FOREIGN KEY (userToken) REFERENCES users (token)
|
|
||||||
);
|
|
||||||
`;
|
|
||||||
db.exec(sql);
|
|
||||||
},
|
|
||||||
down: (db) => {
|
|
||||||
// language=SQLite
|
|
||||||
const sql = `
|
|
||||||
DROP TABLE users;
|
|
||||||
DROP TABLE user_ips;
|
|
||||||
DROP TABLE user_token_counts;
|
|
||||||
DROP TABLE user_meta;
|
|
||||||
`;
|
|
||||||
db.exec(sql);
|
|
||||||
},
|
|
||||||
},
|
|
||||||
] satisfies Migration[];
|
] satisfies Migration[];
|
||||||
|
|||||||
@@ -1,420 +0,0 @@
|
|||||||
import { ZodType, z } from "zod";
|
|
||||||
import { MODEL_FAMILIES, ModelFamily } from "../../models";
|
|
||||||
import { makeOptionalPropsNullable } from "../../utils";
|
|
||||||
import { getDatabase } from "../index";
|
|
||||||
import type { Transaction } from "better-sqlite3";
|
|
||||||
|
|
||||||
// Dynamically builds a Zod object schema with one key per model family. Each
// value is an object of `input`, `output`, `limit` and `prompts` counters,
// every one of which defaults to 0 when omitted.
|
|
||||||
/**
 * Builds the shape passed to `z.object` for `tokenCountsSchema`: one entry per
 * model family, each holding the four per-family counters.
 */
function buildModelTokenCountsShape() {
  const shape = {} as Record<
    ModelFamily,
    ZodType<{ input: number; output: number; limit: number; prompts: number }>
  >;
  for (const family of MODEL_FAMILIES) {
    shape[family] = z.object({
      input: z.number().optional().default(0),
      output: z.number().optional().default(0),
      limit: z.number().optional().default(0),
      prompts: z.number().optional().default(0),
    });
  }
  return shape;
}

/** Per-model-family usage counters; omitted counters default to 0. */
export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
  buildModelTokenCountsShape()
);
|
|
||||||
|
|
||||||
// Old token counts schema before counts were combined into a single object.
|
|
||||||
/**
 * Builds the shape for the legacy counts schema: one numeric entry per model
 * family, defaulting to 0 when omitted.
 */
function buildLegacyTokenCountsShape() {
  const shape = {} as Record<ModelFamily, ZodType<number>>;
  for (const family of MODEL_FAMILIES) {
    shape[family] = z.number().optional().default(0);
  }
  return shape;
}

const tokenCountsSchemaOld = z.object(buildLegacyTokenCountsShape());
|
|
||||||
|
|
||||||
/**
 * Validation schema for a full user record. Declared strict, so objects
 * carrying keys not listed here are rejected.
 */
export const UserSchema = z.strictObject({
  /** Personal access token identifying the user. */
  token: z.string(),
  /** IP addresses the user has connected from. */
  ip: z.array(z.string()),
  /** Optional display name, at most 80 characters. */
  nickname: z.string().max(80).optional(),
  /**
   * Privilege level of the user.
   * - `normal`: Default role. Subject to usual rate limits and quotas.
   * - `special`: Special role. Higher quotas and exempt from auto-ban/lockout.
   * - `temporary`: accepted by the schema; presumably the role of expiring
   *   users (see `expiresAt`) - TODO confirm.
   **/
  type: z.enum(["normal", "special", "temporary"]),
  /** Number of prompts the user has made. */
  promptCount: z.number(),
  /**
   * @deprecated Use `tokenCounts` instead.
   * Never used; retained for backwards compatibility.
   */
  tokenCount: z.any().optional(),
  /** Tokens consumed by the user, keyed by model family (legacy shape). */
  tokenCounts: tokenCountsSchemaOld,
  /** Maximum tokens the user may consume, keyed by model family (legacy shape). */
  tokenLimits: tokenCountsSchemaOld,
  /** Combined token data for the user, keyed by model family. */
  modelTokenCounts: tokenCountsSchema,
  /** Time at which the user was created. */
  createdAt: z.number(),
  /** Time at which the user last connected. */
  lastUsedAt: z.number().optional(),
  /** Time at which the user was disabled, if applicable. */
  disabledAt: z.number().optional(),
  /** Reason for which the user was disabled, if applicable. */
  disabledReason: z.string().optional(),
  /** Time at which the user will expire and be disabled (for temp users). */
  expiresAt: z.number().optional(),
  /** The user's maximum number of IP addresses; supersedes the global max. */
  maxIps: z.coerce.number().int().min(0).optional(),
  /** Private note about the user. */
  adminNote: z.string().optional(),
  /** Free-form key/value metadata attached to the user. */
  meta: z.record(z.any()).optional(),
});
|
|
||||||
/**
|
|
||||||
* Variant of `UserSchema` which allows for partial updates, and makes any
|
|
||||||
* optional properties on the base schema nullable. Null values are used to
|
|
||||||
* indicate that the property should be deleted from the user object.
|
|
||||||
*/
|
|
||||||
// `UserSchema` with its optional properties additionally accepting null.
const nullableUserSchema = makeOptionalPropsNullable(UserSchema);

// Every property becomes optional for partial updates, except `token`, which
// is re-declared as a required string.
export const UserPartialSchema = nullableUserSchema
  .partial()
  .extend({ token: z.string() });
|
|
||||||
/** Usage counters kept for every model family. */
export type UserTokenCounts = Record<
  ModelFamily,
  { input: number; output: number; limit: number; prompts: number }
>;

/** Legacy shape: a single number (or nothing) per model family. */
export type UserTokenCountsOld = Record<ModelFamily, number | undefined>;

/** A fully validated user record. */
export type User = z.infer<typeof UserSchema>;

/** A partial user update; `token` is always required. */
export type UserUpdate = z.infer<typeof UserPartialSchema>;

/** A user row as returned by list queries, carrying an IP count. */
export type VirtualUser = User & { virtual: true; ipCount: number };
|
|
||||||
|
|
||||||
/**
 * Data-access helpers for the `users`, `user_ips`, `user_token_counts` and
 * `user_meta` tables.
 *
 * Every statement handed to `prepare()` is a single SQL statement; work that
 * needs several statements runs them one by one inside a transaction.
 */
export const UsersRepo = {
  /**
   * Looks up one user by token.
   *
   * @param token - The user's access token.
   * @returns The marshalled user, or `undefined` when no such user exists.
   */
  getUserByToken: (token: string) => {
    const db = getDatabase();
    // Each child table is aggregated in its own correlated subquery. Joining
    // all three tables and aggregating in the outer query multiplies rows
    // (duplicated IPs), and an aggregate query without GROUP BY always
    // returns one row, so a missing user could never be detected.
    // language=SQLite
    const sql = `
      SELECT u.*,
             (SELECT json_group_array(ui.ip)
              FROM user_ips ui
              WHERE ui.userToken = u.token) AS ip,
             (SELECT json_group_object(utc.modelFamily,
                                       json_object('input', utc.inputTokens,
                                                   'output', utc.outputTokens,
                                                   'limit', utc.tokenLimit,
                                                   'prompts', utc.prompts))
              FROM user_token_counts utc
              WHERE utc.userToken = u.token) AS tokenCounts,
             (SELECT json_group_object(um.key, um.value)
              FROM user_meta um
              WHERE um.userToken = u.token) AS meta
      FROM users u
      WHERE u.token = ?
    `;

    const row = db.prepare(sql).get(token);
    if (!row) return;

    return marshalUser(row);
  },

  /**
   * Lists users ordered by token, descending, one page at a time.
   *
   * @param pagination - `limit` is the page size; `cursor`, when given, is the
   *   last token of the previous page and only smaller tokens are returned.
   * @returns The page of users, each carrying its number of known IPs.
   */
  getUsers: (pagination: { limit: number; cursor?: string }): VirtualUser[] => {
    const db = getDatabase();
    const { limit, cursor } = pagination;
    const params: (string | number)[] = [];

    // Correlated subqueries keep the IP count and the per-family counters
    // independent of each other; joining the child tables would multiply rows
    // and inflate count(ip).
    // language=SQLite
    let sql = `
      SELECT u.*,
             (SELECT count(*)
              FROM user_ips ui
              WHERE ui.userToken = u.token) AS ipCount,
             (SELECT json_group_object(utc.modelFamily,
                                       json_object('input', utc.inputTokens,
                                                   'output', utc.outputTokens,
                                                   'limit', utc.tokenLimit,
                                                   'prompts', utc.prompts))
              FROM user_token_counts utc
              WHERE utc.userToken = u.token) AS tokenCounts,
             (SELECT json_group_object(um.key, um.value)
              FROM user_meta um
              WHERE um.userToken = u.token) AS meta
      FROM users u
    `;

    if (cursor) {
      sql += ` WHERE u.token < ?`;
      params.push(cursor);
    }

    sql += ` ORDER BY u.token DESC LIMIT ?`;
    params.push(limit);

    return db
      .prepare(sql)
      .all(params)
      .map((r: any) => {
        const virtual: VirtualUser = {
          ...marshalUser(r),
          virtual: true,
          ipCount: r.ipCount ?? 0,
        };
        return virtual;
      });
  },

  /**
   * Upserts a user record by user token. Intended for use via the REST API,
   * prefer a more targeted method if possible. Undefined values are ignored,
   * null values are used to indicate that the field should be cleared.
   *
   * Runs inside a transaction; when called outside of one, it opens its own.
   *
   * @param update - The user data to upsert, with `token` required.
   */
  upsertUser: (update: UserUpdate): void => {
    const db = getDatabase();
    if (!db.inTransaction) {
      db.transaction(() => UsersRepo.upsertUser(update))();
      return;
    }

    // Only real columns of `users` may appear in the generated SQL. Keys such
    // as `ip`, `meta` or the token-count objects live in other tables.
    const scalarColumns = [
      "nickname",
      "type",
      "createdAt",
      "lastUsedAt",
      "disabledAt",
      "disabledReason",
      "expiresAt",
      "maxIps",
      "adminNote",
    ] as const;

    const provided = scalarColumns.filter(
      (column) => update[column] !== undefined
    );
    const onConflict =
      provided.length > 0
        ? `DO UPDATE SET ${provided
            .map((column) => `${column} = :${column}`)
            .join(", ")}`
        : "DO NOTHING";

    // Every named parameter of the INSERT must be bound, so unspecified
    // columns are bound as NULL. `type` and `createdAt` are NOT NULL columns
    // and get placeholder values so the INSERT half is always valid; on
    // conflict they are only written when explicitly provided.
    const params: Record<string, unknown> = { token: update.token };
    for (const column of scalarColumns) {
      params[column] = update[column] ?? null;
    }
    params.type = update.type ?? "normal";
    params.createdAt = update.createdAt ?? Date.now();

    // scalars
    // language=SQLite
    const sql = `
      INSERT INTO users (token, nickname, type, createdAt, lastUsedAt, disabledAt, disabledReason, expiresAt, maxIps,
                         adminNote)
      VALUES (:token, :nickname, :type, :createdAt, :lastUsedAt, :disabledAt, :disabledReason, :expiresAt, :maxIps,
              :adminNote)
      ON CONFLICT(token) ${onConflict}
    `;
    db.prepare(sql).run(params);

    // replace ip addresses
    if (update.ip) {
      db.prepare(`DELETE FROM user_ips WHERE userToken = ?`).run(update.token);
      // OR IGNORE tolerates duplicate addresses in the provided list.
      const insertIp = db.prepare(
        `INSERT OR IGNORE INTO user_ips (userToken, ip) VALUES (?, ?)`
      );
      for (const ip of update.ip) {
        insertIp.run(update.token, ip);
      }
    }

    if (update.modelTokenCounts) {
      // language=SQLite
      const upsertCounts = db.prepare(`
        INSERT INTO user_token_counts (userToken, modelFamily, inputTokens, outputTokens, tokenLimit, prompts)
        VALUES (:token, :modelFamily, :inputTokens, :outputTokens, :tokenLimit, :prompts)
        ON CONFLICT(userToken, modelFamily) DO UPDATE SET inputTokens  = :inputTokens,
                                                          outputTokens = :outputTokens,
                                                          tokenLimit   = :tokenLimit,
                                                          prompts      = :prompts
      `);

      for (const [family, counts] of Object.entries(update.modelTokenCounts)) {
        // Map the schema's counter names onto the column-named parameters.
        upsertCounts.run({
          token: update.token,
          modelFamily: family,
          inputTokens: counts?.input ?? 0,
          outputTokens: counts?.output ?? 0,
          tokenLimit: counts?.limit ?? 0,
          prompts: counts?.prompts ?? 0,
        });
      }
    }

    // replace metadata; an explicit null clears it
    if (update.meta !== undefined) {
      db.prepare(`DELETE FROM user_meta WHERE userToken = ?`).run(update.token);
      if (update.meta) {
        const insertMeta = db.prepare(
          `INSERT INTO user_meta (userToken, key, value) VALUES (?, ?, ?)`
        );
        for (const [key, value] of Object.entries(update.meta)) {
          insertMeta.run(update.token, key, value);
        }
      }
    }
  },

  /**
   * Inserts or updates multiple user records in batches of 50, one
   * transaction per batch. Yields to the event loop before each batch to
   * prevent blocking the main thread for too long.
   *
   * @param updates - The user data to upsert.
   */
  upsertUsers: async (updates: UserUpdate[]) => {
    const db = getDatabase();
    const BATCH_SIZE = 50;

    const applyBatch = db.transaction((batch: UserUpdate[]) => {
      for (const update of batch) {
        UsersRepo.upsertUser(update);
      }
    });

    for (let start = 0; start < updates.length; start += BATCH_SIZE) {
      await new Promise((resolve) => setTimeout(resolve, 0));
      applyBatch(updates.slice(start, start + BATCH_SIZE));
    }
  },

  /**
   * Increments the token usage counters for a user's token by the provided
   * values, and increments prompt count by 1. Creates the counter row (with a
   * token limit of 0) if it does not exist yet.
   */
  incrementUsage(
    userToken: string,
    family: ModelFamily,
    input: number,
    output: number
  ) {
    const db = getDatabase();

    // language=SQLite
    const sql = `
      INSERT INTO user_token_counts (userToken, modelFamily, inputTokens, outputTokens, tokenLimit, prompts)
      VALUES (:userToken, :modelFamily, :inputTokens, :outputTokens, 0, 1)
      ON CONFLICT(userToken, modelFamily) DO UPDATE SET inputTokens  = inputTokens + :inputTokens,
                                                        outputTokens = outputTokens + :outputTokens,
                                                        prompts      = prompts + 1
    `;

    db.prepare(sql).run({
      userToken,
      modelFamily: family,
      inputTokens: input,
      outputTokens: output,
    });
  },

  /**
   * Disables user, optionally with reason, and records `refreshable = 'false'`
   * in the user's metadata. Both writes happen in one transaction.
   */
  disableUser(userToken: string, reason?: string) {
    const db = getDatabase();
    const disabledAt = Date.now();

    const disable = db.transaction(() => {
      // language=SQLite
      db.prepare(
        `
        UPDATE users
        SET disabledAt     = :disabledAt,
            disabledReason = :reason
        WHERE token = :userToken
      `
      ).run({ userToken, disabledAt, reason: reason ?? null });

      // language=SQLite
      db.prepare(
        `
        INSERT OR REPLACE INTO user_meta (userToken, key, value)
        VALUES (?, 'refreshable', 'false')
      `
      ).run(userToken);
    });

    disable();
  },

  /**
   * Restores quotas for a user: for each provided family the token limit is
   * set to the tokens already consumed plus the provided refresh amount.
   */
  refreshQuotas(
    userToken: string,
    tokensByFamily: Record<ModelFamily, number>
  ): void {
    const db = getDatabase();
    if (!db.inTransaction) {
      return db.transaction(() =>
        UsersRepo.refreshQuotas(userToken, tokensByFamily)
      )();
    }

    // Prepared once and reused for every family.
    // language=SQLite
    const refresh = db.prepare(`
      INSERT INTO user_token_counts (userToken, modelFamily, inputTokens, outputTokens, tokenLimit, prompts)
      VALUES (:userToken, :modelFamily, 0, 0, :refreshAmount, 0)
      ON CONFLICT(userToken, modelFamily) DO UPDATE SET tokenLimit = inputTokens + outputTokens + :refreshAmount
    `);

    for (const [family, tokens] of Object.entries(tokensByFamily)) {
      refresh.run({
        userToken,
        modelFamily: family,
        refreshAmount: tokens,
      });
    }
  },

  /**
   * Resets token usage counters for a given user by deleting all of the
   * user's rows in `user_token_counts`.
   */
  resetUsage(userToken: string) {
    const db = getDatabase();
    // language=SQLite
    const sql = `
      DELETE
      FROM user_token_counts
      WHERE userToken = :token
    `;
    db.prepare(sql).run({ token: userToken });
  },
};
|
|
||||||
|
|
||||||
/**
 * Converts a raw `users` query row into a `User` object.
 *
 * The `ip`, `meta` and `tokenCounts` columns are expected to hold JSON text
 * (or be absent). Every known model family receives a counter object, and the
 * legacy `promptCount`, `tokenCount`, `tokenCounts` and `tokenLimits` fields
 * are derived from the per-family counters.
 *
 * @param row - Row returned by one of the repository's SELECT statements.
 * @returns The marshalled user.
 */
function marshalUser(row: any): User {
  const modelTokenCounts = JSON.parse(row.tokenCounts ?? "{}") as z.infer<
    typeof tokenCountsSchema
  >;

  // legacy token fields
  const tokenCounts = {} as z.infer<typeof tokenCountsSchemaOld>;
  const tokenLimits = {} as z.infer<typeof tokenCountsSchemaOld>;
  let promptCount = 0;
  let tokenCount = 0;

  // Runs even when the row carries no counters, so every family is always
  // present on the returned user.
  for (const family of MODEL_FAMILIES) {
    // initialize missing model families
    if (!modelTokenCounts[family]) {
      modelTokenCounts[family] = { input: 0, output: 0, limit: 0, prompts: 0 };
    }

    // aggregate legacy fields
    const { input, output, limit, prompts } = modelTokenCounts[family];
    promptCount += prompts;
    tokenCount += input + output;
    tokenCounts[family] = input + output;
    tokenLimits[family] = limit;
  }

  const user: User = {
    token: row.token,
    nickname: row.nickname,
    type: row.type,
    createdAt: row.createdAt,
    lastUsedAt: row.lastUsedAt,
    disabledAt: row.disabledAt,
    disabledReason: row.disabledReason,
    expiresAt: row.expiresAt,
    maxIps: row.maxIps,
    adminNote: row.adminNote,
    ip: row.ip ? JSON.parse(row.ip) : [],
    meta: row.meta ? JSON.parse(row.meta) : {},
    modelTokenCounts,
    promptCount,
    tokenCount,
    tokenCounts,
    tokenLimits,
  };

  return user;
}
|
|
||||||
@@ -9,7 +9,8 @@ export type APIFormat =
|
|||||||
| "anthropic-chat" // Anthropic's newer messages array format
|
| "anthropic-chat" // Anthropic's newer messages array format
|
||||||
| "anthropic-text" // Legacy flat string prompt format
|
| "anthropic-text" // Legacy flat string prompt format
|
||||||
| "google-ai"
|
| "google-ai"
|
||||||
| "mistral-ai";
|
| "mistral-ai"
|
||||||
|
| "cohere-chat";
|
||||||
|
|
||||||
export interface Key {
|
export interface Key {
|
||||||
/** The API key itself. Never log this, use `hash` instead. */
|
/** The API key itself. Never log this, use `hash` instead. */
|
||||||
|
|||||||
+18
-2
@@ -14,7 +14,8 @@ export type LLMService =
|
|||||||
| "google-ai"
|
| "google-ai"
|
||||||
| "mistral-ai"
|
| "mistral-ai"
|
||||||
| "aws"
|
| "aws"
|
||||||
| "azure";
|
| "azure"
|
||||||
|
| "cohere";
|
||||||
|
|
||||||
export type OpenAIModelFamily =
|
export type OpenAIModelFamily =
|
||||||
| "turbo"
|
| "turbo"
|
||||||
@@ -32,13 +33,15 @@ export type MistralAIModelFamily =
|
|||||||
| "mistral-large";
|
| "mistral-large";
|
||||||
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
|
export type AwsBedrockModelFamily = "aws-claude" | "aws-claude-opus";
|
||||||
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
|
export type AzureOpenAIModelFamily = `azure-${OpenAIModelFamily}`;
|
||||||
|
export type CohereModelFamily = "command-r" | "command-r-plus";
|
||||||
export type ModelFamily =
|
export type ModelFamily =
|
||||||
| OpenAIModelFamily
|
| OpenAIModelFamily
|
||||||
| AnthropicModelFamily
|
| AnthropicModelFamily
|
||||||
| GoogleAIModelFamily
|
| GoogleAIModelFamily
|
||||||
| MistralAIModelFamily
|
| MistralAIModelFamily
|
||||||
| AwsBedrockModelFamily
|
| AwsBedrockModelFamily
|
||||||
| AzureOpenAIModelFamily;
|
| AzureOpenAIModelFamily
|
||||||
|
| CohereModelFamily;
|
||||||
|
|
||||||
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
||||||
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
arr: A & ([ModelFamily] extends [A[number]] ? unknown : never)
|
||||||
@@ -64,6 +67,8 @@ export const MODEL_FAMILIES = (<A extends readonly ModelFamily[]>(
|
|||||||
"azure-gpt4-turbo",
|
"azure-gpt4-turbo",
|
||||||
"azure-gpt4o",
|
"azure-gpt4o",
|
||||||
"azure-dall-e",
|
"azure-dall-e",
|
||||||
|
"command-r",
|
||||||
|
"command-r-plus",
|
||||||
] as const);
|
] as const);
|
||||||
|
|
||||||
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
||||||
@@ -75,6 +80,7 @@ export const LLM_SERVICES = (<A extends readonly LLMService[]>(
|
|||||||
"mistral-ai",
|
"mistral-ai",
|
||||||
"aws",
|
"aws",
|
||||||
"azure",
|
"azure",
|
||||||
|
"cohere",
|
||||||
] as const);
|
] as const);
|
||||||
|
|
||||||
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
export const OPENAI_MODEL_FAMILY_MAP: { [regex: string]: OpenAIModelFamily } = {
|
||||||
@@ -116,6 +122,8 @@ export const MODEL_FAMILY_SERVICE: {
|
|||||||
"mistral-small": "mistral-ai",
|
"mistral-small": "mistral-ai",
|
||||||
"mistral-medium": "mistral-ai",
|
"mistral-medium": "mistral-ai",
|
||||||
"mistral-large": "mistral-ai",
|
"mistral-large": "mistral-ai",
|
||||||
|
"command-r": "cohere",
|
||||||
|
"command-r-plus": "cohere",
|
||||||
};
|
};
|
||||||
|
|
||||||
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
|
export const IMAGE_GEN_MODELS: ModelFamily[] = ["dall-e", "azure-dall-e"];
|
||||||
@@ -181,6 +189,11 @@ export function getAzureOpenAIModelFamily(
|
|||||||
return defaultFamily;
|
return defaultFamily;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Maps a Cohere model name to its model family.
 *
 * @param model - The requested model name.
 * @returns `"command-r-plus"` when the name contains "plus", otherwise
 *   `"command-r"`.
 */
export function getCohereModelFamily(model: string): CohereModelFamily {
  const isPlusVariant = model.includes("plus");
  return isPlusVariant ? "command-r-plus" : "command-r";
}
|
||||||
|
|
||||||
export function assertIsKnownModelFamily(
|
export function assertIsKnownModelFamily(
|
||||||
modelFamily: string
|
modelFamily: string
|
||||||
): asserts modelFamily is ModelFamily {
|
): asserts modelFamily is ModelFamily {
|
||||||
@@ -220,6 +233,9 @@ export function getModelFamilyForRequest(req: Request): ModelFamily {
|
|||||||
case "mistral-ai":
|
case "mistral-ai":
|
||||||
modelFamily = getMistralAIModelFamily(model);
|
modelFamily = getMistralAIModelFamily(model);
|
||||||
break;
|
break;
|
||||||
|
case "cohere-chat":
|
||||||
|
modelFamily = getCohereModelFamily(model);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
assertNever(req.outboundApi);
|
assertNever(req.outboundApi);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { config } from "../../config";
|
import { config } from "../../config";
|
||||||
import type { EventLogEntry } from "../database";
|
import type { EventLogEntry } from "../database";
|
||||||
import { eventsRepo } from "../database/repos/events";
|
import { eventsRepo } from "../database/repos/event";
|
||||||
|
|
||||||
export const logEvent = (payload: Omit<EventLogEntry, "date">) => {
|
export const logEvent = (payload: Omit<EventLogEntry, "date">) => {
|
||||||
if (!config.eventLogging) {
|
if (!config.eventLogging) {
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import { ZodType, z } from "zod";
|
||||||
|
import { MODEL_FAMILIES, ModelFamily } from "../models";
|
||||||
|
import { makeOptionalPropsNullable } from "../utils";
|
||||||
|
|
||||||
|
// This just dynamically creates a Zod object type with a key for each model
|
||||||
|
// family and an optional number value.
|
||||||
|
/**
 * Builds the shape for `tokenCountsSchema`: one numeric entry per model
 * family, defaulting to 0 when omitted.
 */
function buildFamilyCountShape() {
  const shape = {} as Record<ModelFamily, ZodType<number>>;
  for (const family of MODEL_FAMILIES) {
    shape[family] = z.number().optional().default(0);
  }
  return shape;
}

export const tokenCountsSchema: ZodType<UserTokenCounts> = z.object(
  buildFamilyCountShape()
);
|
||||||
|
|
||||||
|
/**
 * Validation schema for a full user record. Declared strict, so objects
 * carrying keys not listed here are rejected.
 */
export const UserSchema = z.strictObject({
  /** Personal access token identifying the user. */
  token: z.string(),
  /** IP addresses the user has connected from. */
  ip: z.array(z.string()),
  /** Optional display name, at most 80 characters. */
  nickname: z.string().max(80).optional(),
  /**
   * Privilege level of the user.
   * - `normal`: Default role. Subject to usual rate limits and quotas.
   * - `special`: Special role. Higher quotas and exempt from
   *   auto-ban/lockout.
   * - `temporary`: accepted by the schema; presumably the role of expiring
   *   users (see `expiresAt`) - TODO confirm.
   **/
  type: z.enum(["normal", "special", "temporary"]),
  /** Number of prompts the user has made. */
  promptCount: z.number(),
  /**
   * @deprecated Use `tokenCounts` instead.
   * Never used; retained for backwards compatibility.
   */
  tokenCount: z.any().optional(),
  /** Tokens consumed by the user, keyed by model family. */
  tokenCounts: tokenCountsSchema,
  /** Maximum tokens the user may consume, keyed by model family. */
  tokenLimits: tokenCountsSchema,
  /** Time at which the user was created. */
  createdAt: z.number(),
  /** Time at which the user last connected. */
  lastUsedAt: z.number().optional(),
  /** Time at which the user was disabled, if applicable. */
  disabledAt: z.number().optional(),
  /** Reason for which the user was disabled, if applicable. */
  disabledReason: z.string().optional(),
  /** Time at which the user will expire and be disabled (for temp users). */
  expiresAt: z.number().optional(),
  /** The user's maximum number of IP addresses; supersedes the global max. */
  maxIps: z.coerce.number().int().min(0).optional(),
  /** Private note about the user. */
  adminNote: z.string().optional(),
  /** Free-form key/value metadata attached to the user. */
  meta: z.record(z.any()).optional(),
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Variant of `UserSchema` which allows for partial updates, and makes any
|
||||||
|
* optional properties on the base schema nullable. Null values are used to
|
||||||
|
* indicate that the property should be deleted from the user object.
|
||||||
|
*/
|
||||||
|
// `UserSchema` with its optional properties additionally accepting null.
const userSchemaWithNullables = makeOptionalPropsNullable(UserSchema);

// Every property becomes optional for partial updates, except `token`, which
// is re-declared as a required string.
export const UserPartialSchema = userSchemaWithNullables
  .partial()
  .extend({ token: z.string() });
|
||||||
|
|
||||||
|
/** A single number (or nothing) per model family. */
export type UserTokenCounts = Record<ModelFamily, number | undefined>;

/** A fully validated user record. */
export type User = z.infer<typeof UserSchema>;

/** A partial user update; `token` is always required. */
export type UserUpdate = z.infer<typeof UserPartialSchema>;
|
||||||
@@ -22,9 +22,9 @@ import {
|
|||||||
ModelFamily,
|
ModelFamily,
|
||||||
} from "../models";
|
} from "../models";
|
||||||
import { logger } from "../../logger";
|
import { logger } from "../../logger";
|
||||||
|
import { User, UserTokenCounts, UserUpdate } from "./schema";
|
||||||
import { APIFormat } from "../key-management";
|
import { APIFormat } from "../key-management";
|
||||||
import { assertNever } from "../utils";
|
import { assertNever } from "../utils";
|
||||||
import { User, UserTokenCounts, UserUpdate } from "../database/repos/users";
|
|
||||||
|
|
||||||
const log = logger.child({ module: "users" });
|
const log = logger.child({ module: "users" });
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import { Router } from "express";
|
import { Router } from "express";
|
||||||
|
import { UserPartialSchema } from "../../shared/users/schema";
|
||||||
import * as userStore from "../../shared/users/user-store";
|
import * as userStore from "../../shared/users/user-store";
|
||||||
import { ForbiddenError, BadRequestError } from "../../shared/errors";
|
import { ForbiddenError, BadRequestError } from "../../shared/errors";
|
||||||
import { sanitizeAndTrim } from "../../shared/utils";
|
import { sanitizeAndTrim } from "../../shared/utils";
|
||||||
import { config } from "../../config";
|
import { config } from "../../config";
|
||||||
import { UserPartialSchema } from "../../shared/database/repos/users";
|
|
||||||
|
|
||||||
const router = Router();
|
const router = Router();
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user