diff --git a/src/proxy/xai.ts b/src/proxy/xai.ts index 9d65a58..fd1078e 100644 --- a/src/proxy/xai.ts +++ b/src/proxy/xai.ts @@ -1,77 +1,156 @@ -import { Request, RequestHandler, Router } from "express"; -import { createPreprocessorMiddleware } from "./middleware/request"; -import { ipLimiter } from "./rate-limit"; -import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory"; -import { addKey, finalizeBody } from "./middleware/request"; -import { ProxyResHandlerWithBody } from "./middleware/response"; - -const xaiResponseHandler: ProxyResHandlerWithBody = async ( - _proxyRes, - req, - res, - body -) => { - if (typeof body !== "object") { - throw new Error("Expected body to be an object"); - } - - let newBody = body; - - res.status(200).json({ ...newBody, proxy: body.proxy }); -}; - -const handleModelRequest: RequestHandler = (_req, res) => { - res.status(200).json({ - object: "list", - data: [ - { id: "grok-2-1212", object: "model", owned_by: "xai" }, - { id: "grok-beta", object: "model", owned_by: "xai" }, - { id: "grok-3", object: "model", owned_by: "xai"}, - { id: "grok-3-fast", object: "model", owned_by: "xai"}, - { id: "grok-3-mini", object: "model", owned_by: "xai"}, - { id: "grok-3-mini-fast", object: "model", owned_by: "xai"}, - ], - }); -}; - -const xaiProxy = createQueuedProxyMiddleware({ - mutations: [addKey, finalizeBody], - target: "https://api.x.ai", - blockingResponseHandler: xaiResponseHandler, -}); - -const xaiRouter = Router(); - -// combines all the assistant messages at the end of the context and adds the -// beta 'prefix' option, makes prefills work the same way they work for Claude -function enablePrefill(req: Request) { - // If you want to disable - if (process.env.NO_XAI_PREFILL) return - - const msgs = req.body.messages; - if (msgs.at(-1)?.role !== 'assistant') return; - - let i = msgs.length - 1; - let content = ''; - - while (i >= 0 && msgs[i].role === 'assistant') { - // maybe we should also add a newline between messages? no for now. - content = msgs[i--].content + content; - } - - msgs.splice(i + 1, msgs.length, { role: 'assistant', content, prefix: true }); -} - -xaiRouter.post( - "/v1/chat/completions", - ipLimiter, - createPreprocessorMiddleware( - { inApi: "openai", outApi: "openai", service: "xai" }, - { afterTransform: [ enablePrefill ] } - ), - xaiProxy -); - -xaiRouter.get("/v1/models", handleModelRequest); - +import { Request, RequestHandler, Router } from "express"; +import { createPreprocessorMiddleware } from "./middleware/request"; +import { ipLimiter } from "./rate-limit"; +import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory"; +import { addKey, finalizeBody } from "./middleware/request"; +import { ProxyResHandlerWithBody } from "./middleware/response"; +import axios from "axios"; +import { XaiKey, keyPool } from "../shared/key-management"; + +let modelsCache: any = null; +let modelsCacheTime = 0; + +const xaiResponseHandler: ProxyResHandlerWithBody = async ( + _proxyRes, + req, + res, + body +) => { + if (typeof body !== "object") { + throw new Error("Expected body to be an object"); + } + + let newBody = body; + + res.status(200).json({ ...newBody, proxy: body.proxy }); +}; + +const getModelsResponse = async () => { + // Return cache if less than 1 minute old + if (new Date().getTime() - modelsCacheTime < 1000 * 60) { + return modelsCache; + } + + try { + // Get an XAI key directly using keyPool.get() + const modelToUse = "grok-3"; // Use any XAI model here - just for key selection + const xaiKey = keyPool.get(modelToUse, "xai") as XaiKey; + + if (!xaiKey || !xaiKey.key) { + throw new Error("Failed to get valid XAI key"); + } + + // Fetch models from XAI API with authorization + const response = await axios.get("https://api.x.ai/v1/models", { + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${xaiKey.key}` + }, + }); + + // If successful, update the cache + if (response.data && response.data.data) { + modelsCache = { + object: "list", + data: response.data.data.map((model: any) => ({ + id: model.id, + object: "model", + owned_by: "xai", + })), + }; + } else { + throw new Error("Unexpected response format from XAI API"); + } + } catch (error) { + console.error("Error fetching XAI models:", error); + throw error; // No fallback - error will be passed to caller + } + + modelsCacheTime = new Date().getTime(); + return modelsCache; +}; + +const handleModelRequest: RequestHandler = async (_req, res) => { + try { + const modelsResponse = await getModelsResponse(); + res.status(200).json(modelsResponse); + } catch (error) { + console.error("Error in handleModelRequest:", error); + res.status(500).json({ error: "Failed to fetch models" }); + } +}; + +const xaiProxy = createQueuedProxyMiddleware({ + mutations: [addKey, finalizeBody], + target: "https://api.x.ai", + blockingResponseHandler: xaiResponseHandler, +}); + +const xaiRouter = Router(); + +// combines all the assistant messages at the end of the context and adds the +// beta 'prefix' option, makes prefills work the same way they work for Claude +function enablePrefill(req: Request) { + // If you want to disable + if (process.env.NO_XAI_PREFILL) return + + const msgs = req.body.messages; + if (msgs.at(-1)?.role !== 'assistant') return; + + let i = msgs.length - 1; + let content = ''; + + while (i >= 0 && msgs[i].role === 'assistant') { + // maybe we should also add a newline between messages? no for now. + content = msgs[i--].content + content; + } + + msgs.splice(i + 1, msgs.length, { role: 'assistant', content, prefix: true }); +} + +// Function to remove parameters not supported by X.AI/Grok models +function removeUnsupportedParameters(req: Request) { + const model = req.body.model; + + // Check if this is a grok-3-mini variant + // This will match grok-3-mini, grok-3-mini-fast, grok-3-mini-beta, grok-3-mini-fast-beta, etc. + const isGrok3Mini = /^grok-3-mini(-[a-z]+)*(-beta)?$/.test(model); + + if (isGrok3Mini) { + // List of parameters not supported by Grok-3-mini models + const unsupportedParams = [ + 'presence_penalty', + 'frequency_penalty' + ]; + + for (const param of unsupportedParams) { + if (req.body[param] !== undefined) { + req.log.info(`Removing unsupported parameter for Grok-3-mini model: ${param}`); + delete req.body[param]; + } + } + + // Support reasoning_effort parameter + if (req.body.reasoning_effort) { + // If reasoning_effort is already present in the request, validate it + if (!['low', 'medium', 'high'].includes(req.body.reasoning_effort)) { + req.log.warn(`Invalid reasoning_effort value: ${req.body.reasoning_effort}, removing it`); + delete req.body.reasoning_effort; + } + } + } +} + +xaiRouter.post( + "/v1/chat/completions", + ipLimiter, + createPreprocessorMiddleware( + { inApi: "openai", outApi: "openai", service: "xai" }, + { afterTransform: [ enablePrefill, removeUnsupportedParameters ] } + ), + xaiProxy +); + +xaiRouter.get("/v1/models", handleModelRequest); + export const xai = xaiRouter; \ No newline at end of file