// Cohere proxy router: accepts OpenAI-format requests, adapts them for
// Cohere's compatibility API, and serves a cached OpenAI-style model list.
import { Request, RequestHandler, Router } from "express";
|
|
import { createPreprocessorMiddleware } from "./middleware/request";
|
|
import { ipLimiter } from "./rate-limit";
|
|
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
|
|
import { addKey, finalizeBody } from "./middleware/request";
|
|
import { ProxyResHandlerWithBody } from "./middleware/response";
|
|
import axios from "axios";
|
|
import { CohereKey, keyPool } from "../shared/key-management";
|
|
import { isCohereModel, normalizeMessages } from "../shared/api-schemas/cohere";
|
|
import { logger } from "../logger";
|
|
|
|
// Scoped logger for this proxy service.
const log = logger.child({ module: "proxy", service: "cohere" });

// In-memory cache for the /v1/models response, refreshed at most once per
// minute by getModelsResponse. `modelsCacheTime` is an epoch-ms timestamp.
let modelsCache: any = null;
let modelsCacheTime = 0;
const cohereResponseHandler: ProxyResHandlerWithBody = async (
|
|
_proxyRes,
|
|
req,
|
|
res,
|
|
body
|
|
) => {
|
|
if (typeof body !== "object") {
|
|
throw new Error("Expected body to be an object");
|
|
}
|
|
|
|
res.status(200).json({ ...body, proxy: body.proxy });
|
|
};
|
|
|
|
const getModelsResponse = async () => {
|
|
// Return cache if less than 1 minute old
|
|
if (new Date().getTime() - modelsCacheTime < 1000 * 60) {
|
|
return modelsCache;
|
|
}
|
|
|
|
try {
|
|
// Get a Cohere key directly
|
|
const modelToUse = "command"; // Use any Cohere model here - just for key selection
|
|
const cohereKey = keyPool.get(modelToUse, "cohere") as CohereKey;
|
|
|
|
if (!cohereKey || !cohereKey.key) {
|
|
log.warn("No valid Cohere key available for model listing");
|
|
throw new Error("No valid Cohere API key available");
|
|
}
|
|
|
|
// Fetch models directly from Cohere API
|
|
const response = await axios.get("https://api.cohere.com/v1/models", {
|
|
headers: {
|
|
"Content-Type": "application/json",
|
|
"Authorization": `Bearer ${cohereKey.key}`,
|
|
"Cohere-Version": "2022-12-06"
|
|
},
|
|
});
|
|
|
|
if (!response.data || !response.data.models) {
|
|
throw new Error("Unexpected response format from Cohere API");
|
|
}
|
|
|
|
// Extract models and filter by those that support the chat endpoint
|
|
const filteredModels = response.data.models
|
|
.filter((model: any) => {
|
|
return model.endpoints && model.endpoints.includes("chat");
|
|
})
|
|
.map((model: any) => ({
|
|
id: model.name,
|
|
name: model.name,
|
|
// Adding additional OpenAI-compatible fields
|
|
context_window: model.context_window_size || 4096,
|
|
max_tokens: model.max_tokens || 4096
|
|
}));
|
|
|
|
log.debug({ modelCount: filteredModels.length, models: filteredModels.map((m: any) => m.id) }, "Filtered models from Cohere API");
|
|
|
|
// Format response to ensure OpenAI compatibility
|
|
const models = {
|
|
object: "list",
|
|
data: filteredModels.map((model: any) => ({
|
|
id: model.id,
|
|
object: "model",
|
|
created: Math.floor(Date.now() / 1000),
|
|
owned_by: "cohere",
|
|
permission: [],
|
|
root: model.id,
|
|
parent: null,
|
|
context_length: model.context_window,
|
|
})),
|
|
};
|
|
|
|
log.debug({ modelCount: filteredModels.length }, "Retrieved models from Cohere API");
|
|
|
|
// Cache the response
|
|
modelsCache = models;
|
|
modelsCacheTime = new Date().getTime();
|
|
return models;
|
|
} catch (error) {
|
|
// Provide detailed logging for better troubleshooting
|
|
if (error instanceof Error) {
|
|
log.error(
|
|
{ errorMessage: error.message, stack: error.stack },
|
|
"Error fetching Cohere models"
|
|
);
|
|
} else {
|
|
log.error({ error }, "Unknown error fetching Cohere models");
|
|
}
|
|
|
|
// Return empty list as fallback
|
|
return {
|
|
object: "list",
|
|
data: [],
|
|
};
|
|
}
|
|
};
|
|
|
|
const handleModelRequest: RequestHandler = async (_req, res) => {
|
|
try {
|
|
const models = await getModelsResponse();
|
|
res.status(200).json(models);
|
|
} catch (error) {
|
|
if (error instanceof Error) {
|
|
log.error(
|
|
{ errorMessage: error.message, stack: error.stack },
|
|
"Error handling model request"
|
|
);
|
|
} else {
|
|
log.error({ error }, "Unknown error handling model request");
|
|
}
|
|
res.status(500).json({ error: "Failed to fetch models" });
|
|
}
|
|
};
|
|
|
|
// Function to prepare messages for Cohere API
|
|
function prepareMessages(req: Request) {
|
|
if (req.body.messages && Array.isArray(req.body.messages)) {
|
|
req.body.messages = normalizeMessages(req.body.messages);
|
|
}
|
|
}
|
|
|
|
// Function to remove parameters not supported by Cohere models
|
|
function removeUnsupportedParameters(req: Request) {
|
|
const model = req.body.model;
|
|
|
|
// Remove parameters that Cohere doesn't support
|
|
if (req.body.logit_bias !== undefined) {
|
|
delete req.body.logit_bias;
|
|
}
|
|
|
|
if (req.body.top_logprobs !== undefined) {
|
|
delete req.body.top_logprobs;
|
|
}
|
|
|
|
if (req.body.max_completion_tokens !== undefined) {
|
|
delete req.body.max_completion_tokens;
|
|
}
|
|
|
|
// Handle structured output format
|
|
if (req.body.response_format && req.body.response_format.schema) {
|
|
// Transform to Cohere's format if needed
|
|
const jsonSchema = req.body.response_format.schema;
|
|
req.body.response_format = {
|
|
type: "json_object",
|
|
schema: jsonSchema
|
|
};
|
|
}
|
|
|
|
// Logging for debugging
|
|
if (process.env.NODE_ENV !== 'production') {
|
|
log.debug({ body: req.body }, "Request after parameter cleanup");
|
|
}
|
|
}
|
|
|
|
// Set up count token functionality for Cohere models
|
|
function countCohereTokens(req: Request) {
|
|
const model = req.body.model;
|
|
|
|
if (isCohereModel(model)) {
|
|
// Count tokens using prompt tokens (simplified)
|
|
if (req.promptTokens) {
|
|
req.log.debug(
|
|
{ tokens: req.promptTokens },
|
|
"Estimated token count for Cohere prompt"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Queued proxy middleware that forwards requests to Cohere's
// OpenAI-compatibility endpoint, applying request mutations in order.
const cohereProxy = createQueuedProxyMiddleware({
  mutations: [
    // Attaches an API key from the pool to the outgoing request.
    addKey,
    // Add Cohere-Version header to every request
    (manager) => {
      manager.setHeader("Cohere-Version", "2022-12-06");
    },
    // Serializes the (possibly transformed) body for the upstream request.
    finalizeBody
  ],
  // NOTE(review): model listing above hits api.cohere.com while this proxy
  // targets api.cohere.ai — confirm both hosts are intended.
  target: "https://api.cohere.ai/compatibility",
  blockingResponseHandler: cohereResponseHandler,
});
const cohereRouter = Router();

// OpenAI-compatible chat completions: incoming OpenAI-format requests are
// normalized, stripped of unsupported params, and token-counted before
// being proxied upstream.
cohereRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "openai", service: "cohere" },
    { afterTransform: [ prepareMessages, removeUnsupportedParameters, countCohereTokens ] }
  ),
  cohereProxy
);

// Embeddings pass through with no extra transform steps.
cohereRouter.post(
  "/v1/embeddings",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "openai", service: "cohere" },
    { afterTransform: [] }
  ),
  cohereProxy
);

// Model listing is handled locally (with a one-minute cache), not proxied.
cohereRouter.get("/v1/models", handleModelRequest);

export const cohere = cohereRouter;