attempts to auto-convert Mistral prompts for its more strict rules

This commit is contained in:
nai-degen
2024-01-28 17:42:23 -06:00
parent 3f2f30e605
commit 924db33f7e
2 changed files with 47 additions and 4 deletions
@@ -106,9 +106,7 @@ export const OpenAIV1ChatCompletionSchema = z
// Tool usage must be enabled via config because we currently have no way to
// track quota usage for them or enforce limits.
.omit(
Boolean(config.allowOpenAIToolUsage)
? {}
: { tools: true, functions: true }
Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, functions: true }
)
.strip();
@@ -233,6 +231,15 @@ export const transformOutboundPayload: RequestPreprocessor = async (req) => {
if (alreadyTransformed || notTransformable) return;
if (req.inboundApi === "mistral-ai") {
const messages = req.body.messages;
req.body.messages = fixMistralPrompt(messages);
req.log.info(
{ old: messages.length, new: req.body.messages.length },
"Fixed Mistral prompt"
);
}
if (sameService) {
const result = VALIDATORS[req.inboundApi].safeParse(req.body);
if (!result.success) {
@@ -540,3 +547,39 @@ function flattenOpenAIMessageContent(
.join("\n")
: content;
}
/**
 * Rewrites a chat message list so it is more likely to satisfy Mistral's
 * stricter prompt rules. Mistral uses the OpenAI message format but adds
 * these requirements:
 * - Only one system message per request, and it must be the first message if
 *   present.
 * - Final message must be a user message.
 * - Cannot have multiple messages from the same role in a row.
 *
 * What this function actually does:
 * - The first message is kept with its role untouched.
 * - Any later message whose role is "system" is reattributed to "user".
 * - Consecutive messages that end up with the same role are merged into one,
 *   with their contents joined by a blank line ("\n\n").
 *
 * NOTE: the "final message must be a user message" rule is NOT enforced here;
 * a conversation ending on a non-user message is returned that way.
 *
 * While frontends should be able to handle this, we can fix it here in the
 * meantime.
 *
 * @param messages Messages in conversation order. Neither the array nor the
 *   message objects it holds are modified.
 * @returns A new array of new message objects with the fixes applied.
 */
function fixMistralPrompt(
  messages: MistralAIChatMessage[]
): MistralAIChatMessage[] {
  const result: MistralAIChatMessage[] = [];

  for (const msg of messages) {
    // Always work on a shallow copy. Merging appends to the previous entry's
    // content, so storing the caller's own object (even just the first one)
    // would silently rewrite the caller's data.
    const copy = { ...msg };

    const last = result[result.length - 1];
    if (last === undefined) {
      // First message: accepted as-is, whatever its role.
      result.push(copy);
      continue;
    }

    // Reattribute subsequent system messages to the user
    if (copy.role === "system") {
      copy.role = "user";
    }

    // Consolidate multiple messages from the same role
    if (last.role === copy.role) {
      last.content += "\n\n" + copy.content;
    } else {
      result.push(copy);
    }
  }

  return result;
}