AIX: improved all uplinks

This commit is contained in:
Enrico Ros
2024-07-03 19:57:09 -07:00
parent eb7a32ed16
commit 2a410f52b5
6 changed files with 182 additions and 63 deletions
@@ -10,7 +10,7 @@ import { fixupHost } from '~/common/util/urlUtils';
import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';
import { llmsChatGenerateOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema } from '../llm.server.types';
import { AnthropicWireMessagesRequest, anthropicWireMessagesRequestSchema, AnthropicWireMessagesResponse, anthropicWireMessagesResponseSchema } from './anthropic.wiretypes';
import { AnthropicWireMessageCreate, anthropicWireMessageCreateSchema, AnthropicWireMessageResponse, anthropicWireMessageResponseSchema } from '~/modules/llms/server/anthropic/anthropic.wiretypes';
import { hardcodedAnthropicModels } from './anthropic.models';
@@ -63,7 +63,7 @@ export function anthropicAccess(access: AnthropicAccessSchema, apiPath: string):
};
}
export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean): AnthropicWireMessagesRequest {
export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean): AnthropicWireMessageCreate {
// Take the System prompt, if it's the first message
// But if it's the only message, treat it as a user message
@@ -79,7 +79,7 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor
// skip empty messages
if (!historyItem.content.trim()) return acc;
const lastMessage: AnthropicWireMessagesRequest['messages'][number] | undefined = acc[acc.length - 1];
const lastMessage: AnthropicWireMessageCreate['messages'][number] | undefined = acc[acc.length - 1];
const anthropicRole = historyItem.role === 'assistant' ? 'assistant' : 'user';
if (index === 0 || anthropicRole !== lastMessage?.role) {
@@ -103,13 +103,13 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor
});
} else {
// Merge consecutive messages with the same role
(lastMessage.content as AnthropicWireMessagesRequest['messages'][number]['content']).push(
(lastMessage.content as AnthropicWireMessageCreate['messages'][number]['content']).push(
{ type: 'text', text: historyItem.content },
);
}
return acc;
},
[] as AnthropicWireMessagesRequest['messages'],
[] as AnthropicWireMessageCreate['messages'],
);
// NOTE: if the last message is 'assistant', then the API will perform a continuation - shall we add a user message? TBD
@@ -120,9 +120,9 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor
// messages.push({ role: 'user', content: [{ type: 'text', text: '' }] });
// Construct the request payload
const payload: AnthropicWireMessagesRequest = {
const payload: AnthropicWireMessageCreate = {
model: model.id,
...(systemPrompt !== undefined && { system: systemPrompt }),
...(systemPrompt !== undefined && { system: [{ type: 'text', text: systemPrompt }] }),
messages: messages,
max_tokens: model.maxTokens || DEFAULT_MAX_TOKENS,
stream: stream,
@@ -136,7 +136,7 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor
};
// Validate the payload against the schema to ensure correctness
const validated = anthropicWireMessagesRequestSchema.safeParse(payload);
const validated = anthropicWireMessageCreateSchema.safeParse(payload);
if (!validated.success)
throw new Error(`Invalid message sequence for Anthropic models: ${validated.error.errors?.[0]?.message || validated.error}`);
@@ -187,8 +187,8 @@ export const llmAnthropicRouter = createTRPCRouter({
// throw if the message sequence is not okay
const payload = anthropicMessagesPayloadOrThrow(model, history, false);
const response = await anthropicPOST<AnthropicWireMessagesResponse, AnthropicWireMessagesRequest>(access, payload, '/v1/messages');
const completion = anthropicWireMessagesResponseSchema.parse(response);
const response = await anthropicPOST<AnthropicWireMessageResponse, AnthropicWireMessageCreate>(access, payload, '/v1/messages');
const completion = anthropicWireMessageResponseSchema.parse(response);
// validate output
if (!completion || completion.type !== 'message' || completion.role !== 'assistant' || completion.stop_reason === undefined)
@@ -1,54 +1,95 @@
import { z } from 'zod';
// See the latest Anthropic Typescript definitions on:
// https://github.com/anthropics/anthropic-sdk-typescript/blob/main/src/resources/messages.ts
// text, e.g.: { 'type': 'text', 'text': 'Hello, Claude' }
const anthropicWireTextBlockSchema = z.object({
// Content Blocks
const anthropicWire_TextBlock_Schema = z.object({
type: z.literal('text'),
text: z.string(),
});
// image, e.g.: { 'type': 'image', 'source': { 'type': 'base64', 'media_type': 'image/jpeg', 'data': '/9j/4AAQSkZJRg...' } }
const anthropicWireImageBlockSchema = z.object({
const anthropicWire_ImageBlock_Schema = z.object({
type: z.literal('image'),
source: z.object({
type: z.enum(['base64']),
type: z.literal('base64'),
media_type: z.enum(['image/jpeg', 'image/png', 'image/gif', 'image/webp']),
data: z.string(),
}),
});
const anthropicWireMessagesSchema = z.array(
z.object({
role: z.enum(['user', 'assistant']),
// NOTE: could be a string or an array of text/image blocks, but for a better implementation
// we will assume it's always an array
// content: z.union([
// z.array(z.union([anthropicWireTextBlockSchema, anthropicWireImageBlockSchema])),
// z.string(),
// ]),
content: z.array(
z.union([
anthropicWireTextBlockSchema,
anthropicWireImageBlockSchema,
]),
),
}),
);
const anthropicWire_ToolUseBlock_Schema = z.object({
type: z.literal('tool_use'),
id: z.string(),
name: z.string(),
input: z.unknown(),
});
export const anthropicWireMessagesRequestSchema = z.object({
const anthropicWire_ToolResultBlock_Schema = z.object({
type: z.literal('tool_result'),
tool_use_id: z.string(),
// NOTE: could be a string too, but we force it to be an array for a better implementation
content: z.array(z.union([anthropicWire_TextBlock_Schema, anthropicWire_ImageBlock_Schema])).optional(),
is_error: z.boolean().optional(),
});
// Uplink
const anthropicWire_ContentBlockUL_Schema = z.discriminatedUnion('type', [
anthropicWire_TextBlock_Schema,
anthropicWire_ImageBlock_Schema,
anthropicWire_ToolUseBlock_Schema,
anthropicWire_ToolResultBlock_Schema,
]);
const anthropicWire_HistoryMessageUL_Schema = z.object({
role: z.enum(['user', 'assistant']),
content: z.array(anthropicWire_ContentBlockUL_Schema), // NOTE: could be a string, but we force it to be an array
});
const anthropicWire_ToolUL_Schema = z.object({
name: z.string(),
// Description of what this tool does.
description: z.string().optional(),
/**
* [JSON schema](https://json-schema.org/) for this tool's input.
*
* This defines the shape of the `input` that your tool accepts and that the model will provide.
*/
input_schema: z.object({
type: z.literal('object'),
properties: z.record(z.unknown()).optional(),
}).and(z.record(z.unknown())),
});
export type AnthropicWireMessageCreate = z.infer<typeof anthropicWireMessageCreateSchema>;
export const anthropicWireMessageCreateSchema = z.object({
/**
* (required) The maximum number of tokens to generate before stopping.
*/
max_tokens: z.number(),
/**
* (required) The model to use for generating the response.
* See [models](https://docs.anthropic.com/en/docs/models-overview) for additional details and options.
*/
model: z.string(),
/**
* If you want to include a system prompt, you can use the top-level system parameter — there is no "system" role for input messages in the Messages API.
*/
system: z.string().optional(),
system: z.array(anthropicWire_TextBlock_Schema).optional(),
/**
* (required) Input messages. - operates on alternating user and assistant conversational turns - the first message must always use the user role
* If the final message uses the assistant role, the response content will continue immediately from the content in that message.
* This can be used to constrain part of the model's response.
*/
messages: anthropicWireMessagesSchema.refine(
messages: z.array(anthropicWire_HistoryMessageUL_Schema).refine(
(messages) => {
// Ensure the first message uses the user role
@@ -66,10 +107,18 @@ export const anthropicWireMessagesRequestSchema = z.object({
),
/**
* (required) The maximum number of tokens to generate before stopping.
* How the model should use the provided tools. The model can use a specific tool, any available tool, or decide by itself.
*/
max_tokens: z.number(),
tool_choice: z.union([
z.object({ type: z.literal('auto') }),
z.object({ type: z.literal('any') }), // use one at least
z.object({ type: z.literal('tool'), name: z.string() }),
]).optional(),
  /**
   * Definitions of tools the model may use: each has a name, an optional description, and a JSON-schema `input_schema` describing its arguments.
   */
tools: z.array(anthropicWire_ToolUL_Schema).optional(),
/**
* (optional) Metadata to include with the request.
@@ -89,27 +138,35 @@ export const anthropicWireMessagesRequestSchema = z.object({
*/
stream: z.boolean().optional(),
/**
* Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks.
*/
temperature: z.number().optional(),
/**
* Use nucleus sampling.
* Recommended for advanced use cases only. You usually only need to use temperature.
*/
top_p: z.number().optional(),
/**
* Only sample from the top K options for each subsequent token.
* Recommended for advanced use cases only. You usually only need to use temperature.
* Recommended for advanced use cases only. You usually only need to use `temperature`.
*/
top_k: z.number().optional(),
/**
* Use nucleus sampling.
* Recommended for advanced use cases only. You usually only need to use `temperature`.
* */
top_p: z.number().optional(),
});
export type AnthropicWireMessagesRequest = z.infer<typeof anthropicWireMessagesRequestSchema>;
export const anthropicWireMessagesResponseSchema = z.object({
/// Downlink
const anthropicWire_ContentBlockDL_Schema = z.discriminatedUnion('type', [
anthropicWire_TextBlock_Schema,
anthropicWire_ToolUseBlock_Schema,
]);
export type AnthropicWireMessageResponse = z.infer<typeof anthropicWireMessageResponseSchema>;
export const anthropicWireMessageResponseSchema = z.object({
// Unique object identifier.
id: z.string(),
@@ -117,14 +174,14 @@ export const anthropicWireMessagesResponseSchema = z.object({
type: z.literal('message'),
// Conversational role of the generated message. This will always be "assistant".
role: z.literal('assistant'),
// The model that handled the request.
model: z.string(),
/**
* Content generated by the model.
* This is an array of content blocks, each of which has a type that determines its shape: currently "text" or "tool_use".
*/
content: z.array(anthropicWireTextBlockSchema),
// The model that handled the request.
model: z.string(),
content: z.array(anthropicWire_ContentBlockDL_Schema),
/**
* This may be one the following values:
@@ -136,8 +193,7 @@ export const anthropicWireMessagesResponseSchema = z.object({
*
* In non-streaming mode this value is always non-null. In streaming mode, it is null in the message_start event and non-null otherwise.
*/
stop_reason: z.enum(['end_turn', 'max_tokens', 'stop_sequence']).nullable(),
stop_reason: z.enum(['end_turn', 'max_tokens', 'stop_sequence', 'tool_use']).nullable(),
// Which custom stop sequence was generated, if any.
stop_sequence: z.string().nullable(),
@@ -146,6 +202,56 @@ export const anthropicWireMessagesResponseSchema = z.object({
input_tokens: z.number(),
output_tokens: z.number(),
}),
});
export type AnthropicWireMessagesResponse = z.infer<typeof anthropicWireMessagesResponseSchema>;
// Events - Message
export const anthropicWire_MessageStartEvent_Schema = z.object({
type: z.literal('message_start'),
message: anthropicWireMessageResponseSchema,
});
export const anthropicWire_MessageStopEvent_Schema = z.object({
type: z.literal('message_stop'),
});
export const anthropicWire_MessageDeltaEvent_Schema = z.object({
type: z.literal('message_delta'),
// MessageDelta
delta: z.object({
stop_reason: z.enum(['end_turn', 'max_tokens', 'stop_sequence', 'tool_use']).nullable(),
stop_sequence: z.string().nullable(),
}),
// MessageDeltaUsage
usage: z.object({ output_tokens: z.number() }),
});
// Events - Content Block
export const anthropicWire_ContentBlockStartEvent_Schema = z.object({
type: z.literal('content_block_start'),
index: z.number(),
content_block: anthropicWire_ContentBlockDL_Schema,
});
export const anthropicWire_ContentBlockStopEvent_Schema = z.object({
type: z.literal('content_block_stop'),
index: z.number(),
});
export const anthropicWire_ContentBlockDeltaEvent_Schema = z.object({
type: z.literal('content_block_delta'),
index: z.number(),
delta: z.union([
z.object({
type: z.literal('text_delta'),
text: z.string(),
}),
z.object({
type: z.literal('input_json_delta'),
partial_json: z.string(),
}),
]),
});
@@ -185,7 +185,14 @@ export const geminiGeneratedContentResponseSchema = z.object({
tokenCount: z.number().optional(),
// groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls.
})).optional(),
// NOTE: promptFeedback is only sent in the first chunk in a streaming response
usageMetadata: z.object({
promptTokenCount: z.number(),
candidatesTokenCount: z.number(),
totalTokenCount: z.number(),
}).optional(),
// NOTE: promptFeedback is only sent in the first chunk in a streaming response
promptFeedback: z.object({
blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(),
safetyRatings: z.array(geminiSafetyRatingSchema).optional(),
@@ -6,7 +6,7 @@ import { createEmptyReadableStream, debugGenerateCurlCommand, nonTrpcServerFetch
// Anthropic server imports
import { AnthropicWireMessagesResponse, anthropicWireMessagesResponseSchema } from './anthropic/anthropic.wiretypes';
import { AnthropicWireMessageResponse, anthropicWireMessageResponseSchema } from './anthropic/anthropic.wiretypes';
import { anthropicAccess, anthropicAccessSchema, anthropicMessagesPayloadOrThrow } from './anthropic/anthropic.router';
// Gemini server imports
@@ -265,7 +265,7 @@ function createUpstreamTransformer(muxingFormat: MuxingFormat, vendorTextParser:
/// Stream Parsers
function createStreamParserAnthropicMessages(): AIStreamParser {
let responseMessage: AnthropicWireMessagesResponse | null = null;
let responseMessage: AnthropicWireMessageResponse | null = null;
let hasErrored = false;
// Note: at this stage, the parser only returns the text content as text, which is streamed as text
@@ -287,7 +287,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser {
case 'message_start':
const firstMessage = !responseMessage;
const { message } = JSON.parse(data);
responseMessage = anthropicWireMessagesResponseSchema.parse(message);
responseMessage = anthropicWireMessageResponseSchema.parse(message);
// hack: prepend the model name to the first packet
if (firstMessage) {
const firstPacket: ChatStreamingPreambleModelSchema = { model: responseMessage.model };
@@ -301,7 +301,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser {
const { index, content_block } = JSON.parse(data);
if (responseMessage.content[index] === undefined)
responseMessage.content[index] = content_block;
text = responseMessage.content[index].text;
text = (responseMessage.content[index] as any).text;
} else
throw new Error('Unexpected content block start');
break;
@@ -314,7 +314,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser {
throw new Error(`Unexpected content block non-text delta (${delta.type})`);
if (responseMessage.content[index] === undefined)
throw new Error(`Unexpected content block delta location (${index})`);
responseMessage.content[index].text += delta.text;
(responseMessage.content[index] as any).text += delta.text;
text = delta.text;
} else
throw new Error('Unexpected content block delta');
@@ -101,10 +101,10 @@ export const wireOllamaChunkedOutputSchema = z.union([
// only on the last message
// context: z.array(z.number()), // non-chat endpoint
// total_duration: z.number(),
// prompt_eval_count: z.number(),
prompt_eval_count: z.number().optional(),
// prompt_eval_duration: z.number(),
// eval_count: z.number(),
// eval_duration: z.number(),
eval_count: z.number().optional(),
eval_duration: z.number().optional(),
}),
// Possible Error
@@ -105,6 +105,12 @@ export namespace OpenAIWire {
param: string | null;
code: string | null;
};
// [OpenRouter/LocalAI] Extended usage statistics
usage?: {
prompt_tokens?: number;
completion_tokens?: number;
total_tokens?: number;
};
}
}