diff --git a/src/config.ts b/src/config.ts index 4d0f17e..50756c7 100644 --- a/src/config.ts +++ b/src/config.ts @@ -58,6 +58,10 @@ type Config = { * Comma-delimited list of Qwen API keys. */ qwenKey?: string; + /** + * Comma-delimited list of Moonshot API keys. + */ + moonshotKey?: string; /** * Comma-delimited list of AWS credentials. Each credential item should be a @@ -464,6 +468,7 @@ export const config: Config = { deepseekKey: getEnvWithDefault("DEEPSEEK_KEY", ""), xaiKey: getEnvWithDefault("XAI_KEY", ""), cohereKey: getEnvWithDefault("COHERE_KEY", ""), + moonshotKey: getEnvWithDefault("MOONSHOT_KEY", ""), awsCredentials: getEnvWithDefault("AWS_CREDENTIALS", ""), gcpCredentials: getEnvWithDefault("GCP_CREDENTIALS", ""), azureCredentials: getEnvWithDefault("AZURE_CREDENTIALS", ""), @@ -765,6 +770,7 @@ export const OMITTED_KEYS = [ "xaiKey", "cohereKey", "qwenKey", + "moonshotKey", "mistralAIKey", "awsCredentials", "gcpCredentials", diff --git a/src/info-page.ts b/src/info-page.ts index 27cd433..11afed4 100644 --- a/src/info-page.ts +++ b/src/info-page.ts @@ -31,6 +31,7 @@ const MODEL_FAMILY_FRIENDLY_NAME: { [f in ModelFamily]: string } = { cohere: "Cohere", deepseek: "Deepseek", xai: "Grok", + moonshot: "Moonshot", turbo: "GPT-4o Mini / 3.5 Turbo", gpt4: "GPT-4", "gpt4-32k": "GPT-4 32k", diff --git a/src/proxy/middleware/request/mutators/add-key.ts b/src/proxy/middleware/request/mutators/add-key.ts index 21b32a8..57a5c75 100644 --- a/src/proxy/middleware/request/mutators/add-key.ts +++ b/src/proxy/middleware/request/mutators/add-key.ts @@ -105,6 +105,9 @@ export const addKey: ProxyReqMutator = (manager) => { case "qwen": manager.setHeader("Authorization", `Bearer ${assignedKey.key}`); break; + case "moonshot": + manager.setHeader("Authorization", `Bearer ${assignedKey.key}`); + break; case "aws": case "gcp": case "google-ai": diff --git a/src/proxy/middleware/request/preprocessors/validate-context-size.ts b/src/proxy/middleware/request/preprocessors/validate-context-size.ts index 553a958..cd5ad7e 100644 --- a/src/proxy/middleware/request/preprocessors/validate-context-size.ts +++ b/src/proxy/middleware/request/preprocessors/validate-context-size.ts @@ -125,23 +125,30 @@ export const validateContextSize: RequestPreprocessor = async (req) => { modelMax = 100000; } else if (model.match(/^deepseek/)) { modelMax = 64000; + } else if (model.match(/^kimi-k2/)) { + // Kimi K2 models have 131k context window + modelMax = 131000; + } else if (model.match(/moonshot/)) { + // Moonshot models typically have 200k context window + modelMax = 200000; + } else if (model.match(/command[\w-]*-03-202[0-9]/)) { + // Cohere's command-a-03 models have 256k context window + modelMax = 256000; + } else if (model.match(/command/) || model.match(/cohere/)) { + // Default for all other Cohere models + modelMax = 128000; } else if (model.match(/^grok-4/)) { modelMax = 256000; } else if (model.match(/^grok/)) { modelMax = 128000; - } else if (model.match(/command-a-03-202[0-9]/)) { - // Cohere's command-a-03 models have 256k context window - modelMax = 256000; - } else if (model.match(/command[\w-]*-03-202[0-9]/)) { - // Other Command models with -03- pattern (including r, r-plus) have 128k context window - modelMax = 128000; - } else if (model.match(/command/) || model.match(/cohere/)) { - // Default for all other Cohere models - modelMax = 128000; } else if (model.match(/^magistral/)) { modelMax = 40000; } else if (model.match(/^magistral/)) { modelMax = 40000; + } else if (model.match(/^moonshot/)) { + modelMax = 200000; + } else if (model.match(/^kimi-k2/)) { + modelMax = 131000; } else if (model.match(/tral/)) { // catches mistral, mixtral, codestral, mathstral, etc. mistral models have // no name convention and wildly different context windows so this is a diff --git a/src/proxy/middleware/response/index.ts b/src/proxy/middleware/response/index.ts index 627c5e8..494a831 100644 --- a/src/proxy/middleware/response/index.ts +++ b/src/proxy/middleware/response/index.ts @@ -267,6 +267,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( case "qwen": // No special handling yet break; + case "moonshot": + errorPayload.proxy_note = `The Moonshot API rejected the request. Check the error message for details.`; + break; default: assertNever(service); } @@ -328,6 +331,7 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( return; case "mistral-ai": case "gcp": + case "moonshot": keyPool.disable(req.key!, "revoked"); errorPayload.proxy_note = `Assigned API key is invalid or revoked, please try again.`; return; @@ -366,6 +370,9 @@ const handleUpstreamErrors: ProxyResHandlerWithBody = async ( // Similar handling to OpenAI for rate limits await handleOpenAIRateLimitError(req, errorPayload); break; + case "moonshot": + await handleMoonshotRateLimitError(req, errorPayload); + break; default: assertNever(service as never); } @@ -598,6 +605,39 @@ async function handleCohereRateLimitError( errorPayload.proxy_note = "Too many requests to the Cohere API. Please try again later."; } +async function handleMoonshotRateLimitError( + req: Request, + errorPayload: ProxiedErrorPayload +) { + // Mark the current key as rate limited + keyPool.markRateLimited(req.key!); + + // Store the original request attempt count or initialize it + req.retryCount = (req.retryCount || 0) + 1; + + // Only retry up to 3 times with different keys + if (req.retryCount <= 3) { + try { + // Add a small delay before retrying (2-6 seconds for Moonshot) + const delayMs = 2000 + Math.floor(Math.random() * 4000); + await new Promise(resolve => setTimeout(resolve, delayMs)); + + // Re-enqueue the request to try with a different key + await reenqueueRequest(req); + req.log.info({ attempt: req.retryCount }, "Moonshot rate-limited request re-enqueued"); + throw new RetryableError(`Moonshot rate-limited request re-enqueued (attempt ${req.retryCount}/3).`); + } catch (error) { + if (error instanceof RetryableError) { + throw error; // Rethrow RetryableError to continue the flow + } + req.log.error({ error }, "Failed to re-enqueue rate-limited Moonshot request"); + } + } + + // If we've already retried 3 times, show the error to the user + errorPayload.proxy_note = "Too many requests to the Moonshot API. Please try again later."; +} + async function handleOpenAIRateLimitError( req: Request, errorPayload: ProxiedErrorPayload diff --git a/src/proxy/moonshot.ts b/src/proxy/moonshot.ts new file mode 100644 index 0000000..82424b6 --- /dev/null +++ b/src/proxy/moonshot.ts @@ -0,0 +1,219 @@ +import { Request, RequestHandler, Router } from "express"; +import { createPreprocessorMiddleware } from "./middleware/request"; +import { ipLimiter } from "./rate-limit"; +import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory"; +import { addKey, finalizeBody } from "./middleware/request"; +import { ProxyResHandlerWithBody } from "./middleware/response"; +import axios from "axios"; +import { MoonshotKey, keyPool } from "../shared/key-management"; +import { isMoonshotModel, isMoonshotVisionModel, enableMoonshotPartial, hasMoonshotPartialMode } from "../shared/api-schemas/moonshot"; +import { logger } from "../logger"; + +const log = logger.child({ module: "proxy", service: "moonshot" }); +let modelsCache: any = null; +let modelsCacheTime = 0; + +const moonshotResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + res.status(200).json({ ...body, proxy: body.proxy }); +}; + +const getModelsResponse = async () => { + // Return cache if less than 1 minute old + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return modelsCache; + } + + try { + const modelToUse = "moonshot-v1-8k"; + const moonshotKey = keyPool.get(modelToUse, "moonshot") as MoonshotKey; + + if (!moonshotKey || !moonshotKey.key) { + log.warn("No valid Moonshot key available for model listing"); + throw new Error("No valid Moonshot API key available"); + } + + // Fetch models from Moonshot API + const response = await axios.get("https://api.moonshot.cn/v1/models", { + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${moonshotKey.key}` + }, + }); + + if (!response.data || !response.data.data) { + throw new Error("Unexpected response format from Moonshot API"); + } + + // Format response to ensure OpenAI compatibility + const models = { + object: "list", + data: response.data.data.map((model: any) => ({ + id: model.id, + object: "model", + created: model.created || Math.floor(Date.now() / 1000), + owned_by: model.owned_by || "moonshot", + permission: model.permission || [], + root: model.root || model.id, + parent: model.parent || null, + })), + }; + + log.debug({ modelCount: models.data.length }, "Retrieved models from Moonshot API"); + + // Cache the response + modelsCache = models; + modelsCacheTime = new Date().getTime(); + return models; + } catch (error) { + if (error instanceof Error) { + log.error( + { errorMessage: error.message, stack: error.stack }, + "Error fetching Moonshot models" + ); + } else { + log.error({ error }, "Unknown error fetching Moonshot models"); + } + + // Return a default list of known Moonshot models as fallback + return { + object: "list", + data: [ + { id: "moonshot-v1-8k", object: "model", created: 1678888000, owned_by: "moonshot" }, + { id: "moonshot-v1-32k", object: "model", created: 1678888000, owned_by: "moonshot" }, + { id: "moonshot-v1-128k", object: "model", created: 1678888000, owned_by: "moonshot" }, + ], + }; + } +}; + +const handleModelRequest: RequestHandler = async (_req, res) => { + try { + const models = await getModelsResponse(); + res.status(200).json(models); + } catch (error) { + if (error instanceof Error) { + log.error( + { errorMessage: error.message, stack: error.stack }, + "Error handling model request" + ); + } else { + log.error({ error }, "Unknown error handling model request"); + } + res.status(500).json({ error: "Failed to fetch models" }); + } +}; + +// Function to handle partial mode for Moonshot +function handlePartialMode(req: Request) { + if (!process.env.NO_MOONSHOT_PARTIAL && req.body.messages && Array.isArray(req.body.messages)) { + const msgs = req.body.messages; + if (msgs.at(-1)?.role !== 'assistant') return; + + let i = msgs.length - 1; + let content = ''; + + while (i >= 0 && msgs[i].role === 'assistant') { + // Consolidate consecutive assistant messages + content = msgs[i--].content + content; + } + + // Replace consecutive assistant messages with single message with partial: true + msgs.splice(i + 1, msgs.length, { role: 'assistant', content, partial: true }); + log.debug("Consolidated assistant messages and enabled partial mode for Moonshot request"); + } +} + +// Function to handle vision model content transformation +function handleVisionContent(req: Request) { + const model = req.body.model; + + if (isMoonshotVisionModel(model) && req.body.messages) { + // Ensure vision content is properly formatted + req.body.messages = req.body.messages.map((msg: any) => { + if (msg.content && typeof msg.content === 'string') { + // Keep string content as is for non-vision requests + return msg; + } + return msg; + }); + } +} + +// Function to count tokens for Moonshot models +function countMoonshotTokens(req: Request) { + const model = req.body.model; + + if (isMoonshotModel(model)) { + if (req.promptTokens) { + log.debug( + { tokens: req.promptTokens, model }, + "Estimated token count for Moonshot prompt" + ); + } + } +} + +// Handle rate limit errors for Moonshot +async function handleMoonshotRateLimitError(req: Request, error: any) { + if (error.response?.status === 429) { + log.warn({ model: req.body.model }, "Moonshot rate limit hit, rotating key"); + + const currentKey = req.key as MoonshotKey; + keyPool.markRateLimited(currentKey); + + // Try to get a new key + const newKey = keyPool.get(req.body.model, "moonshot") as MoonshotKey; + if (newKey.hash !== currentKey.hash) { + req.key = newKey; + return true; // Retry with new key + } + } + return false; +} + +const moonshotProxy = createQueuedProxyMiddleware({ + mutations: [ + addKey, + finalizeBody + ], + target: "https://api.moonshot.cn", + blockingResponseHandler: moonshotResponseHandler, +}); + +const moonshotRouter = Router(); + +// Chat completions endpoint +moonshotRouter.post( + "/v1/chat/completions", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "openai", outApi: "openai", service: "moonshot" }, + { afterTransform: [ handlePartialMode, handleVisionContent, countMoonshotTokens ] } + ), + moonshotProxy +); + +// Embeddings endpoint +moonshotRouter.post( + "/v1/embeddings", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "openai", outApi: "openai", service: "moonshot" }, + { afterTransform: [ countMoonshotTokens ] } + ), + moonshotProxy +); + +// Models endpoint +moonshotRouter.get("/v1/models", handleModelRequest); + +export const moonshot = moonshotRouter; diff --git a/src/proxy/routes.ts b/src/proxy/routes.ts index e94d522..a463e48 100644 --- a/src/proxy/routes.ts +++ b/src/proxy/routes.ts @@ -14,6 +14,7 @@ import { deepseek } from "./deepseek"; import { xai } from "./xai"; import { cohere } from "./cohere"; import { qwen } from "./qwen"; +import { moonshot } from "./moonshot"; import { sendErrorToClient } from "./middleware/response/error-generator"; const proxyRouter = express.Router(); @@ -57,6 +58,7 @@ proxyRouter.use("/deepseek", addV1, deepseek); proxyRouter.use("/xai", addV1, xai); proxyRouter.use("/cohere", addV1, cohere); proxyRouter.use("/qwen", addV1, qwen); +proxyRouter.use("/moonshot", addV1, moonshot); // Redirect browser requests to the homepage. proxyRouter.get("*", (req, res, next) => { diff --git a/src/service-info.ts b/src/service-info.ts index 004064c..78e9817 100644 --- a/src/service-info.ts +++ b/src/service-info.ts @@ -9,6 +9,7 @@ import { XaiKey, CohereKey, QwenKey, + MoonshotKey, } from "./shared/key-management"; import { AnthropicModelFamily, @@ -27,6 +28,7 @@ import { XaiModelFamily, CohereModelFamily, QwenModelFamily, + MoonshotModelFamily, } from "./shared/models"; import { getCostSuffix, getTokenCostUsd, prettyTokens } from "./shared/stats"; import { getUniqueIps } from "./proxy/rate-limit"; @@ -50,6 +52,8 @@ const keyIsCohereKey = (k: KeyPoolKey): k is CohereKey => k.service === "cohere"; const keyIsQwenKey = (k: KeyPoolKey): k is QwenKey => k.service === "qwen"; +const keyIsMoonshotKey = (k: KeyPoolKey): k is MoonshotKey => + k.service === "moonshot"; /** Stats aggregated across all keys for a given service. */ type ServiceAggregate = "keys" | "uncheckedKeys" | "orgs"; @@ -147,7 +151,8 @@ export type ServiceInfo = { & { [f in DeepseekModelFamily]?: BaseFamilyInfo } & { [f in XaiModelFamily]?: BaseFamilyInfo } & { [f in CohereModelFamily]?: BaseFamilyInfo } - & { [f in QwenModelFamily]?: BaseFamilyInfo }; + & { [f in QwenModelFamily]?: BaseFamilyInfo } + & { [f in MoonshotModelFamily]?: BaseFamilyInfo }; // https://stackoverflow.com/a/66661477 // type DeepKeyOf = ( @@ -201,6 +206,9 @@ const SERVICE_ENDPOINTS: { [s in LLMService]: Record } = { qwen: { qwen: `%BASE%/qwen`, }, + moonshot: { + moonshot: `%BASE%/moonshot`, + }, }; const familyStats = new Map(); @@ -358,6 +366,7 @@ function addKeyToAggregates(k: KeyPoolKey) { addToService("xai__keys", k.service === "xai" ? 1 : 0); addToService("cohere__keys", k.service === "cohere" ? 1 : 0); addToService("qwen__keys", k.service === "qwen" ? 1 : 0); + addToService("moonshot__keys", k.service === "moonshot" ? 1 : 0); let sumInputTokens = 0; let sumOutputTokens = 0; @@ -521,6 +530,9 @@ function addKeyToAggregates(k: KeyPoolKey) { case "qwen": k.modelFamilies.forEach(incrementGenericFamilyStats); break; + case "moonshot": + k.modelFamilies.forEach(incrementGenericFamilyStats); + break; default: assertNever(k.service); } @@ -640,6 +652,9 @@ function getInfoForFamily(family: ModelFamily): BaseFamilyInfo { case "qwen": info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0; break; + case "moonshot": + info.overQuotaKeys = familyStats.get(`${family}__overQuota`) || 0; + break; } } diff --git a/src/shared/api-schemas/moonshot.ts b/src/shared/api-schemas/moonshot.ts new file mode 100644 index 0000000..50a1bfc --- /dev/null +++ b/src/shared/api-schemas/moonshot.ts @@ -0,0 +1,106 @@ +import { z } from "zod"; +import { OPENAI_OUTPUT_MAX } from "./openai"; + +/** + * Helper function to check if a model is from Moonshot + */ +export function isMoonshotModel(model: string): boolean { + return model.includes("moonshot"); +} + +/** + * Helper function to check if a model is a Moonshot vision model + */ +export function isMoonshotVisionModel(model: string): boolean { + return model.includes("moonshot") && model.includes("vision"); +} + +// Content schema for vision models +const MoonshotVisionContentSchema = z.union([ + z.string(), + z.array( + z.union([ + z.object({ + type: z.literal("text"), + text: z.string(), + }), + z.object({ + type: z.literal("image_url"), + image_url: z.object({ + url: z.string(), + detail: z.enum(["low", "high", "auto"]).optional(), + }), + }), + ]) + ), +]); + +// Basic chat message schema +const MoonshotChatMessageSchema = z.object({ + role: z.enum(["user", "assistant", "system"]), + content: z.union([z.string(), MoonshotVisionContentSchema]).nullable(), + name: z.string().optional(), + // Support for partial mode + partial: z.boolean().optional(), +}); + +const MoonshotMessagesSchema = z.array(MoonshotChatMessageSchema); + +// Schema for Moonshot chat completions +export const MoonshotV1ChatCompletionsSchema = z.object({ + model: z.string(), + messages: MoonshotMessagesSchema, + temperature: z.number().optional().default(0.3), + top_p: z.number().optional().default(1), + max_tokens: z.coerce + .number() + .int() + .nullish() + .transform((v) => Math.min(v ?? OPENAI_OUTPUT_MAX, OPENAI_OUTPUT_MAX)), + stream: z.boolean().optional().default(false), + stop: z + .union([z.string(), z.array(z.string()).max(5)]) + .optional() + .default([]) + .transform((v) => (Array.isArray(v) ? v : [v])), + seed: z.number().int().min(0).optional(), + response_format: z + .object({ + type: z.enum(["text", "json_object"]) + }) + .optional(), + tools: z.array(z.any()).optional(), + tool_choice: z.any().optional(), + frequency_penalty: z.number().min(-2).max(2).optional().default(0), + presence_penalty: z.number().min(-2).max(2).optional().default(0), + n: z.number().int().min(1).max(5).optional().default(1), +}); + +// Schema for Moonshot embeddings +export const MoonshotV1EmbeddingsSchema = z.object({ + model: z.string(), + input: z.union([z.string(), z.array(z.string())]), + encoding_format: z.enum(["float", "base64"]).optional() +}); + +// Helper function to enable partial mode for Moonshot (similar to Deepseek's prefill) +export function enableMoonshotPartial(messages: any[]): any[] { + // If the last message is from assistant and doesn't have partial flag, add it + if (messages.length > 0 && messages[messages.length - 1].role === 'assistant') { + const lastMessage = messages[messages.length - 1]; + if (!lastMessage.partial) { + return [ + ...messages.slice(0, -1), + { ...lastMessage, partial: true } + ]; + } + } + return messages; +} + +// Helper function to check if request uses partial mode +export function hasMoonshotPartialMode(messages: any[]): boolean { + return messages.length > 0 && + messages[messages.length - 1].role === 'assistant' && + messages[messages.length - 1].partial === true; +} diff --git a/src/shared/key-management/index.ts b/src/shared/key-management/index.ts index 7f09eaa..5bc382d 100644 --- a/src/shared/key-management/index.ts +++ b/src/shared/key-management/index.ts @@ -105,3 +105,4 @@ export { DeepseekKey } from "./deepseek/provider"; export { XaiKey } from "./xai/provider"; export { CohereKey } from "./cohere/provider"; export { QwenKey } from "./qwen/provider"; +export { MoonshotKey } from "./moonshot/provider"; diff --git a/src/shared/key-management/key-pool.ts b/src/shared/key-management/key-pool.ts index b54047f..8b1a13b 100644 --- a/src/shared/key-management/key-pool.ts +++ b/src/shared/key-management/key-pool.ts @@ -17,6 +17,7 @@ import { DeepseekKeyProvider } from "./deepseek/provider"; import { XaiKeyProvider } from "./xai/provider"; import { CohereKeyProvider } from "./cohere/provider"; import { QwenKeyProvider } from "./qwen/provider"; +import { MoonshotKeyProvider } from "./moonshot/provider"; type AllowedPartial = OpenAIKeyUpdate | AnthropicKeyUpdate | Partial; @@ -38,6 +39,7 @@ export class KeyPool { this.keyProviders.push(new XaiKeyProvider()); this.keyProviders.push(new CohereKeyProvider()); this.keyProviders.push(new QwenKeyProvider()); + this.keyProviders.push(new MoonshotKeyProvider()); } public init() { @@ -81,7 +83,8 @@ export class KeyPool { service instanceof DeepseekKeyProvider || service instanceof XaiKeyProvider || service instanceof CohereKeyProvider || - service instanceof QwenKeyProvider + service instanceof QwenKeyProvider || + service instanceof MoonshotKeyProvider ) { service.update(key.hash, { isOverQuota: reason === "quota" }); } @@ -211,6 +214,8 @@ export class KeyPool { return "cohere"; } else if (model.includes("qwen")) { return "qwen"; + } else if (model.includes("moonshot")) { + return "moonshot"; } else if (model.startsWith("anthropic.claude")) { // AWS offers models from a few providers // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html diff --git a/src/shared/key-management/moonshot/checker.ts b/src/shared/key-management/moonshot/checker.ts new file mode 100644 index 0000000..447ef40 --- /dev/null +++ b/src/shared/key-management/moonshot/checker.ts @@ -0,0 +1,127 @@ +import { MoonshotKey } from "./provider"; +import { logger } from "../../../logger"; +import { assertNever } from "../../utils"; + +const CHECK_TIMEOUT = 10000; +const API_URL = "https://api.moonshot.cn/v1/users/me/balance"; + +export class MoonshotKeyChecker { + private log = logger.child({ module: "key-checker", service: "moonshot" }); + + constructor(private readonly update: (hash: string, key: Partial) => void) { + this.log.info("MoonshotKeyChecker initialized"); + } + + public async checkKey(key: MoonshotKey): Promise { + this.log.info({ hash: key.hash }, "Starting key validation check"); + try { + const result = await this.validateKey(key); + this.handleCheckResult(key, result); + } catch (error) { + if (error instanceof Error) { + this.log.warn( + { error: error.message, stack: error.stack, hash: key.hash }, + "Failed to check key status" + ); + } else { + this.log.warn( + { error, hash: key.hash }, + "Failed to check key status with unknown error" + ); + } + } + } + + private async validateKey(key: MoonshotKey): Promise<"valid" | "invalid" | "quota"> { + const controller = new AbortController(); + const timeout = setTimeout(() => { + controller.abort(); + this.log.warn({ hash: key.hash }, "Key validation timed out after " + CHECK_TIMEOUT + "ms"); + }, CHECK_TIMEOUT); + + try { + // Check balance endpoint to verify key validity + const headers = { + "Content-Type": "application/json", + "Authorization": `Bearer ${key.key}` + }; + + const response = await fetch(API_URL, { + method: "GET", + headers, + signal: controller.signal, + }); + + if (response.status === 200) { + const data = await response.json(); + // Check if response has the expected Moonshot API structure + if (data && data.status === true && data.code === 0 && data.data) { + const balance = data.data.available_balance; + // Check if balance is too low (consider it quota exceeded if balance is 0 or negative) + if (typeof balance === 'number' && balance <= 0) { + return "quota"; + } + return "valid"; + } else { + this.log.warn( + { response: data, hash: key.hash }, + "Unexpected response format from Moonshot API" + ); + return "invalid"; + } + } else if (response.status === 401) { + // Unauthorized - invalid key + return "invalid"; + } else if (response.status === 429) { + // Rate limit - but key is valid + return "valid"; + } else { + this.log.warn( + { status: response.status, hash: key.hash }, + "Unexpected status code while testing key validity" + ); + return "invalid"; + } + } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + this.log.warn({ hash: key.hash }, "Key validation aborted"); + } + throw error; + } finally { + clearTimeout(timeout); + } + } + + private handleCheckResult( + key: MoonshotKey, + result: "valid" | "invalid" | "quota" + ): void { + switch (result) { + case "valid": + this.log.info({ hash: key.hash }, "Key is valid and enabled"); + this.update(key.hash, { + isDisabled: false, + lastChecked: Date.now(), + }); + break; + case "invalid": + this.log.warn({ hash: key.hash }, "Key is invalid, marking as revoked"); + this.update(key.hash, { + isDisabled: true, + isRevoked: true, + lastChecked: Date.now(), + }); + break; + case "quota": + this.log.warn({ hash: key.hash }, "Key has exceeded its quota, disabling"); + this.update(key.hash, { + isDisabled: true, + isOverQuota: true, + lastChecked: Date.now(), + }); + break; + default: + assertNever(result); + } + } +} diff --git a/src/shared/key-management/moonshot/index.ts b/src/shared/key-management/moonshot/index.ts new file mode 100644 index 0000000..fe6feb7 --- /dev/null +++ b/src/shared/key-management/moonshot/index.ts @@ -0,0 +1,2 @@ +export { MoonshotKey, MoonshotKeyProvider } from "./provider"; +export { MoonshotKeyChecker } from "./checker"; diff --git a/src/shared/key-management/moonshot/provider.ts b/src/shared/key-management/moonshot/provider.ts new file mode 100644 index 0000000..b9467ab --- /dev/null +++ b/src/shared/key-management/moonshot/provider.ts @@ -0,0 +1,166 @@ +import { Key, KeyProvider, createGenericGetLockoutPeriod } from ".."; +import { MoonshotKeyChecker } from "./checker"; +import { config } from "../../../config"; +import { logger } from "../../../logger"; +import { MoonshotModelFamily, ModelFamily } from "../../models"; + +export interface MoonshotKey extends Key { + readonly service: "moonshot"; + readonly modelFamilies: MoonshotModelFamily[]; + isOverQuota: boolean; +} + +export class MoonshotKeyProvider implements KeyProvider { + readonly service = "moonshot"; + + private keys: MoonshotKey[] = []; + private checker?: MoonshotKeyChecker; + private log = logger.child({ module: "key-provider", service: this.service }); + + constructor() { + const keyConfig = config.moonshotKey?.trim(); + if (!keyConfig) { + return; + } + + const keys = keyConfig.split(",").map((k) => k.trim()); + for (const key of keys) { + if (!key) continue; + this.keys.push({ + key, + service: this.service, + modelFamilies: ["moonshot"], + isDisabled: false, + isRevoked: false, + promptCount: 0, + lastUsed: 0, + lastChecked: 0, + hash: this.hashKey(key), + rateLimitedAt: 0, + rateLimitedUntil: 0, + tokenUsage: {}, + isOverQuota: false, + }); + } + } + + private hashKey(key: string): string { + return require("crypto").createHash("sha256").update(key).digest("hex"); + } + + public init() { + if (this.keys.length === 0) return; + if (!config.checkKeys) { + this.log.warn( + "Key checking is disabled. Keys will not be verified." + ); + return; + } + this.checker = new MoonshotKeyChecker(this.update.bind(this)); + for (const key of this.keys) { + void this.checker.checkKey(key); + } + } + + public get(model: string): MoonshotKey { + const availableKeys = this.keys.filter((k) => !k.isDisabled); + if (availableKeys.length === 0) { + throw new Error("No Moonshot keys available"); + } + const key = availableKeys[Math.floor(Math.random() * availableKeys.length)]; + key.lastUsed = Date.now(); + this.throttle(key.hash); + return { ...key }; + } + + public list(): Omit[] { + return this.keys.map(({ key, ...rest }) => rest); + } + + public disable(key: MoonshotKey): void { + const found = this.keys.find((k) => k.hash === key.hash); + if (found) { + found.isDisabled = true; + } + } + + public update(hash: string, update: Partial): void { + const key = this.keys.find((k) => k.hash === hash); + if (key) { + Object.assign(key, update); + } + } + + public available(): number { + return this.keys.filter((k) => !k.isDisabled).length; + } + + public incrementUsage(keyHash: string, modelFamily: MoonshotModelFamily, usage: { input: number; output: number }) { + const key = this.keys.find((k) => k.hash === keyHash); + if (!key) return; + + key.promptCount++; + + if (!key.tokenUsage) { + key.tokenUsage = {}; + } + // Moonshot only has one model family "moonshot" + if (!key.tokenUsage[modelFamily]) { + key.tokenUsage[modelFamily] = { input: 0, output: 0 }; + } + + const currentFamilyUsage = key.tokenUsage[modelFamily]!; + currentFamilyUsage.input += usage.input; + currentFamilyUsage.output += usage.output; + } + + /** + * Upon being rate limited, a key will be locked out for this many milliseconds + * while we wait for other concurrent requests to finish. + */ + private static readonly RATE_LIMIT_LOCKOUT = 2000; + /** + * Upon assigning a key, we will wait this many milliseconds before allowing it + * to be used again. This is to prevent the queue from flooding a key with too + * many requests while we wait to learn whether previous ones succeeded. + */ + private static readonly KEY_REUSE_DELAY = 500; + + getLockoutPeriod = createGenericGetLockoutPeriod(() => this.keys); + + public markRateLimited(keyHash: string) { + this.log.debug({ key: keyHash }, "Key rate limited"); + const key = this.keys.find((k) => k.hash === keyHash)!; + const now = Date.now(); + key.rateLimitedAt = now; + key.rateLimitedUntil = now + MoonshotKeyProvider.RATE_LIMIT_LOCKOUT; + } + + public recheck(): void { + if (!this.checker || !config.checkKeys) return; + for (const key of this.keys) { + this.update(key.hash, { + isOverQuota: false, + isDisabled: false, + lastChecked: 0 + }); + void this.checker.checkKey(key); + } + } + + /** + * Applies a short artificial delay to the key upon dequeueing, in order to + * prevent it from being immediately assigned to another request before the + * current one can be dispatched. + **/ + private throttle(hash: string) { + const now = Date.now(); + const key = this.keys.find((k) => k.hash === hash)!; + + const currentRateLimit = key.rateLimitedUntil; + const nextRateLimit = now + MoonshotKeyProvider.KEY_REUSE_DELAY; + + key.rateLimitedAt = now; + key.rateLimitedUntil = Math.max(currentRateLimit, nextRateLimit); + } +} diff --git a/src/shared/models.ts b/src/shared/models.ts index 88d94e2..15cd23c 100644 --- a/src/shared/models.ts +++ b/src/shared/models.ts @@ -18,7 +18,8 @@ export type LLMService = | "deepseek" | "xai" | "cohere" - | "qwen"; + | "qwen" + | "moonshot"; export type OpenAIModelFamily = | "turbo" @@ -58,6 +59,7 @@ export type DeepseekModelFamily = "deepseek"; export type XaiModelFamily = "xai"; export type CohereModelFamily = "cohere"; export type QwenModelFamily = "qwen"; +export type MoonshotModelFamily = "moonshot"; export type ModelFamily = | OpenAIModelFamily @@ -70,11 +72,13 @@ export type ModelFamily = | DeepseekModelFamily | XaiModelFamily | CohereModelFamily - | QwenModelFamily; + | QwenModelFamily + | MoonshotModelFamily; export const MODEL_FAMILIES = (( arr: A & ([ModelFamily] extends [A[number]] ? unknown : never) ) => arr)([ + "moonshot", "qwen", "cohere", "xai", @@ -149,12 +153,14 @@ export const LLM_SERVICES = (( "deepseek", "xai", "cohere", - "qwen" + "qwen", + "moonshot" ] as const); export const MODEL_FAMILY_SERVICE: { [f in ModelFamily]: LLMService; } = { + moonshot: "moonshot", qwen: "qwen", cohere: "cohere", xai: "xai", @@ -404,12 +410,10 @@ export function getModelFamilyForRequest(req: Request): ModelFamily { case "openai-image": if (req.service === "deepseek") { modelFamily = "deepseek"; - } else { - modelFamily = getOpenAIModelFamily(model); - } - break; - if (req.service === "xai") { + } else if (req.service === "xai") { modelFamily = "xai"; + } else if (req.service === "moonshot") { + modelFamily = "moonshot"; } else { modelFamily = getOpenAIModelFamily(model); } diff --git a/src/shared/stats.ts b/src/shared/stats.ts index 53581da..b2d2195 100644 --- a/src/shared/stats.ts +++ b/src/shared/stats.ts @@ -64,6 +64,7 @@ const MODEL_PRICING: Record