From 2a410f52b56627460eb75d03bb0aacc9157800a7 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Wed, 3 Jul 2024 19:57:09 -0700 Subject: [PATCH] AIX: improved all uplinks --- .../llms/server/anthropic/anthropic.router.ts | 20 +- .../server/anthropic/anthropic.wiretypes.ts | 194 ++++++++++++++---- .../llms/server/gemini/gemini.wiretypes.ts | 9 +- .../llms/server/llm.server.streaming.ts | 10 +- .../llms/server/ollama/ollama.wiretypes.ts | 6 +- .../llms/server/openai/openai.wiretypes.ts | 6 + 6 files changed, 182 insertions(+), 63 deletions(-) diff --git a/src/modules/llms/server/anthropic/anthropic.router.ts b/src/modules/llms/server/anthropic/anthropic.router.ts index 5e2864c6b..169023626 100644 --- a/src/modules/llms/server/anthropic/anthropic.router.ts +++ b/src/modules/llms/server/anthropic/anthropic.router.ts @@ -10,7 +10,7 @@ import { fixupHost } from '~/common/util/urlUtils'; import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router'; import { llmsChatGenerateOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema } from '../llm.server.types'; -import { AnthropicWireMessagesRequest, anthropicWireMessagesRequestSchema, AnthropicWireMessagesResponse, anthropicWireMessagesResponseSchema } from './anthropic.wiretypes'; +import { AnthropicWireMessageCreate, anthropicWireMessageCreateSchema, AnthropicWireMessageResponse, anthropicWireMessageResponseSchema } from '~/modules/llms/server/anthropic/anthropic.wiretypes'; import { hardcodedAnthropicModels } from './anthropic.models'; @@ -63,7 +63,7 @@ export function anthropicAccess(access: AnthropicAccessSchema, apiPath: string): }; } -export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean): AnthropicWireMessagesRequest { +export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean): AnthropicWireMessageCreate { // Take the 
System prompt, if it's the first message // But if it's the only message, treat it as a user message @@ -79,7 +79,7 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor // skip empty messages if (!historyItem.content.trim()) return acc; - const lastMessage: AnthropicWireMessagesRequest['messages'][number] | undefined = acc[acc.length - 1]; + const lastMessage: AnthropicWireMessageCreate['messages'][number] | undefined = acc[acc.length - 1]; const anthropicRole = historyItem.role === 'assistant' ? 'assistant' : 'user'; if (index === 0 || anthropicRole !== lastMessage?.role) { @@ -103,13 +103,13 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor }); } else { // Merge consecutive messages with the same role - (lastMessage.content as AnthropicWireMessagesRequest['messages'][number]['content']).push( + (lastMessage.content as AnthropicWireMessageCreate['messages'][number]['content']).push( { type: 'text', text: historyItem.content }, ); } return acc; }, - [] as AnthropicWireMessagesRequest['messages'], + [] as AnthropicWireMessageCreate['messages'], ); // NOTE: if the last message is 'assistant', then the API will perform a continuation - shall we add a user message? 
TBD @@ -120,9 +120,9 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor // messages.push({ role: 'user', content: [{ type: 'text', text: '' }] }); // Construct the request payload - const payload: AnthropicWireMessagesRequest = { + const payload: AnthropicWireMessageCreate = { model: model.id, - ...(systemPrompt !== undefined && { system: systemPrompt }), + ...(systemPrompt !== undefined && { system: [{ type: 'text', text: systemPrompt }] }), messages: messages, max_tokens: model.maxTokens || DEFAULT_MAX_TOKENS, stream: stream, @@ -136,7 +136,7 @@ export function anthropicMessagesPayloadOrThrow(model: OpenAIModelSchema, histor }; // Validate the payload against the schema to ensure correctness - const validated = anthropicWireMessagesRequestSchema.safeParse(payload); + const validated = anthropicWireMessageCreateSchema.safeParse(payload); if (!validated.success) throw new Error(`Invalid message sequence for Anthropic models: ${validated.error.errors?.[0]?.message || validated.error}`); @@ -187,8 +187,8 @@ export const llmAnthropicRouter = createTRPCRouter({ // throw if the message sequence is not okay const payload = anthropicMessagesPayloadOrThrow(model, history, false); - const response = await anthropicPOST(access, payload, '/v1/messages'); - const completion = anthropicWireMessagesResponseSchema.parse(response); + const response = await anthropicPOST(access, payload, '/v1/messages'); + const completion = anthropicWireMessageResponseSchema.parse(response); // validate output if (!completion || completion.type !== 'message' || completion.role !== 'assistant' || completion.stop_reason === undefined) diff --git a/src/modules/llms/server/anthropic/anthropic.wiretypes.ts b/src/modules/llms/server/anthropic/anthropic.wiretypes.ts index 54918716c..a6696f7c5 100644 --- a/src/modules/llms/server/anthropic/anthropic.wiretypes.ts +++ b/src/modules/llms/server/anthropic/anthropic.wiretypes.ts @@ -1,54 +1,95 @@ import { z } from 'zod'; +// 
See the latest Anthropic Typescript definitions on: +// https://github.com/anthropics/anthropic-sdk-typescript/blob/main/src/resources/messages.ts -// text, e.g.: { 'type': 'text', 'text': 'Hello, Claude' } -const anthropicWireTextBlockSchema = z.object({ + +// Content Blocks + +const anthropicWire_TextBlock_Schema = z.object({ type: z.literal('text'), text: z.string(), }); -// image, e.g.: { 'type': 'image', 'source': { 'type': 'base64', 'media_type': 'image/jpeg', 'data': '/9j/4AAQSkZJRg...' } } -const anthropicWireImageBlockSchema = z.object({ +const anthropicWire_ImageBlock_Schema = z.object({ type: z.literal('image'), source: z.object({ - type: z.enum(['base64']), + type: z.literal('base64'), media_type: z.enum(['image/jpeg', 'image/png', 'image/gif', 'image/webp']), data: z.string(), }), }); -const anthropicWireMessagesSchema = z.array( - z.object({ - role: z.enum(['user', 'assistant']), - // NOTE: could be a string or an array of text/image blocks, but for a better implementation - // we will assume it's always an array - // content: z.union([ - // z.array(z.union([anthropicWireTextBlockSchema, anthropicWireImageBlockSchema])), - // z.string(), - // ]), - content: z.array( - z.union([ - anthropicWireTextBlockSchema, - anthropicWireImageBlockSchema, - ]), - ), - }), -); +const anthropicWire_ToolUseBlock_Schema = z.object({ + type: z.literal('tool_use'), + id: z.string(), + name: z.string(), + input: z.unknown(), +}); -export const anthropicWireMessagesRequestSchema = z.object({ +const anthropicWire_ToolResultBlock_Schema = z.object({ + type: z.literal('tool_result'), + tool_use_id: z.string(), + // NOTE: could be a string too, but we force it to be an array for a better implementation + content: z.array(z.union([anthropicWire_TextBlock_Schema, anthropicWire_ImageBlock_Schema])).optional(), + is_error: z.boolean().optional(), +}); + + +// Uplink + +const anthropicWire_ContentBlockUL_Schema = z.discriminatedUnion('type', [ + anthropicWire_TextBlock_Schema, + 
anthropicWire_ImageBlock_Schema, + anthropicWire_ToolUseBlock_Schema, + anthropicWire_ToolResultBlock_Schema, +]); + +const anthropicWire_HistoryMessageUL_Schema = z.object({ + role: z.enum(['user', 'assistant']), + content: z.array(anthropicWire_ContentBlockUL_Schema), // NOTE: could be a string, but we force it to be an array +}); + +const anthropicWire_ToolUL_Schema = z.object({ + name: z.string(), + // Description of what this tool does. + description: z.string().optional(), + /** + * [JSON schema](https://json-schema.org/) for this tool's input. + * + * This defines the shape of the `input` that your tool accepts and that the model will provide. + */ + input_schema: z.object({ + type: z.literal('object'), + properties: z.record(z.unknown()).optional(), + }).and(z.record(z.unknown())), +}); + + +export type AnthropicWireMessageCreate = z.infer; +export const anthropicWireMessageCreateSchema = z.object({ + /** + * (required) The maximum number of tokens to generate before stopping. + */ + max_tokens: z.number(), + + /** + * (required) The model to use for generating the response. + * See [models](https://docs.anthropic.com/en/docs/models-overview) for additional details and options. + */ model: z.string(), /** * If you want to include a system prompt, you can use the top-level system parameter — there is no "system" role for input messages in the Messages API. */ - system: z.string().optional(), + system: z.array(anthropicWire_TextBlock_Schema).optional(), /** * (required) Input messages. - operates on alternating user and assistant conversational turns - the first message must always use the user role * If the final message uses the assistant role, the response content will continue immediately from the content in that message. * This can be used to constrain part of the model's response. 
*/ - messages: anthropicWireMessagesSchema.refine( + messages: z.array(anthropicWire_HistoryMessageUL_Schema).refine( (messages) => { // Ensure the first message uses the user role @@ -66,10 +107,18 @@ export const anthropicWireMessagesRequestSchema = z.object({ ), /** - * (required) The maximum number of tokens to generate before stopping. + * How the model should use the provided tools. The model can use a specific tool, any available tool, or decide by itself. */ - max_tokens: z.number(), + tool_choice: z.union([ + z.object({ type: z.literal('auto') }), + z.object({ type: z.literal('any') }), // use one at least + z.object({ type: z.literal('tool'), name: z.string() }), + ]).optional(), + /** + * + */ + tools: z.array(anthropicWire_ToolUL_Schema).optional(), /** * (optional) Metadata to include with the request. @@ -89,27 +138,35 @@ export const anthropicWireMessagesRequestSchema = z.object({ */ stream: z.boolean().optional(), + /** * Defaults to 1.0. Ranges from 0.0 to 1.0. Use temperature closer to 0.0 for analytical / multiple choice, and closer to 1.0 for creative and generative tasks. */ temperature: z.number().optional(), - /** - * Use nucleus sampling. - * Recommended for advanced use cases only. You usually only need to use temperature. - */ - top_p: z.number().optional(), - /** * Only sample from the top K options for each subsequent token. - * Recommended for advanced use cases only. You usually only need to use temperature. + * Recommended for advanced use cases only. You usually only need to use `temperature`. */ top_k: z.number().optional(), + + /** + * Use nucleus sampling. + * Recommended for advanced use cases only. You usually only need to use `temperature`. 
+ * */ + top_p: z.number().optional(), }); -export type AnthropicWireMessagesRequest = z.infer; -export const anthropicWireMessagesResponseSchema = z.object({ +/// Downlink + +const anthropicWire_ContentBlockDL_Schema = z.discriminatedUnion('type', [ + anthropicWire_TextBlock_Schema, + anthropicWire_ToolUseBlock_Schema, +]); + +export type AnthropicWireMessageResponse = z.infer; +export const anthropicWireMessageResponseSchema = z.object({ // Unique object identifier. id: z.string(), @@ -117,14 +174,14 @@ export const anthropicWireMessagesResponseSchema = z.object({ type: z.literal('message'), // Conversational role of the generated message. This will always be "assistant". role: z.literal('assistant'), + // The model that handled the request. + model: z.string(), + /** * Content generated by the model. * This is an array of content blocks, each of which has a type that determines its shape. Currently, the only type in responses is "text". */ - content: z.array(anthropicWireTextBlockSchema), - - // The model that handled the request. - model: z.string(), + content: z.array(anthropicWire_ContentBlockDL_Schema), /** * This may be one the following values: @@ -136,8 +193,7 @@ export const anthropicWireMessagesResponseSchema = z.object({ * * In non-streaming mode this value is always non-null. In streaming mode, it is null in the message_start event and non-null otherwise. */ - stop_reason: z.enum(['end_turn', 'max_tokens', 'stop_sequence']).nullable(), - + stop_reason: z.enum(['end_turn', 'max_tokens', 'stop_sequence', 'tool_use']).nullable(), // Which custom stop sequence was generated, if any. 
stop_sequence: z.string().nullable(), @@ -146,6 +202,56 @@ export const anthropicWireMessagesResponseSchema = z.object({ input_tokens: z.number(), output_tokens: z.number(), }), - }); -export type AnthropicWireMessagesResponse = z.infer; + + +// Events - Message + +export const anthropicWire_MessageStartEvent_Schema = z.object({ + type: z.literal('message_start'), + message: anthropicWireMessageResponseSchema, +}); + +export const anthropicWire_MessageStopEvent_Schema = z.object({ + type: z.literal('message_stop'), +}); + +export const anthropicWire_MessageDeltaEvent_Schema = z.object({ + type: z.literal('message_delta'), + // MessageDelta + delta: z.object({ + stop_reason: z.enum(['end_turn', 'max_tokens', 'stop_sequence', 'tool_use']).nullable(), + stop_sequence: z.string().nullable(), + }), + // MessageDeltaUsage + usage: z.object({ output_tokens: z.number() }), +}); + + +// Events - Content Block + +export const anthropicWire_ContentBlockStartEvent_Schema = z.object({ + type: z.literal('content_block_start'), + index: z.number(), + content_block: anthropicWire_ContentBlockDL_Schema, +}); + +export const anthropicWire_ContentBlockStopEvent_Schema = z.object({ + type: z.literal('content_block_stop'), + index: z.number(), +}); + +export const anthropicWire_ContentBlockDeltaEvent_Schema = z.object({ + type: z.literal('content_block_delta'), + index: z.number(), + delta: z.union([ + z.object({ + type: z.literal('text_delta'), + text: z.string(), + }), + z.object({ + type: z.literal('input_json_delta'), + partial_json: z.string(), + }), + ]), +}); diff --git a/src/modules/llms/server/gemini/gemini.wiretypes.ts b/src/modules/llms/server/gemini/gemini.wiretypes.ts index 318599685..4032ed1f7 100644 --- a/src/modules/llms/server/gemini/gemini.wiretypes.ts +++ b/src/modules/llms/server/gemini/gemini.wiretypes.ts @@ -185,7 +185,14 @@ export const geminiGeneratedContentResponseSchema = z.object({ tokenCount: z.number().optional(), // groundingAttributions: 
z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls. })).optional(), - // NOTE: promptFeedback is only send in the first chunk in a streaming response + + usageMetadata: z.object({ + promptTokenCount: z.number(), + candidatesTokenCount: z.number(), + totalTokenCount: z.number(), + }).optional(), + +// NOTE: promptFeedback is only sent in the first chunk in a streaming response promptFeedback: z.object({ blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(), safetyRatings: z.array(geminiSafetyRatingSchema).optional(), diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts index d632a4522..aac651e34 100644 --- a/src/modules/llms/server/llm.server.streaming.ts +++ b/src/modules/llms/server/llm.server.streaming.ts @@ -6,7 +6,7 @@ import { createEmptyReadableStream, debugGenerateCurlCommand, nonTrpcServerFetch // Anthropic server imports -import { AnthropicWireMessagesResponse, anthropicWireMessagesResponseSchema } from './anthropic/anthropic.wiretypes'; +import { AnthropicWireMessageResponse, anthropicWireMessageResponseSchema } from './anthropic/anthropic.wiretypes'; import { anthropicAccess, anthropicAccessSchema, anthropicMessagesPayloadOrThrow } from './anthropic/anthropic.router'; // Gemini server imports @@ -265,7 +265,7 @@ function createUpstreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: /// Stream Parsers function createStreamParserAnthropicMessages(): AIStreamParser { - let responseMessage: AnthropicWireMessagesResponse | null = null; + let responseMessage: AnthropicWireMessageResponse | null = null; let hasErrored = false; // Note: at this stage, the parser only returns the text content as text, which is streamed as text @@ -287,7 +287,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser { case 'message_start': const firstMessage = !responseMessage; const { message } = JSON.parse(data); - responseMessage = 
anthropicWireMessagesResponseSchema.parse(message); + responseMessage = anthropicWireMessageResponseSchema.parse(message); // hack: prepend the model name to the first packet if (firstMessage) { const firstPacket: ChatStreamingPreambleModelSchema = { model: responseMessage.model }; @@ -301,7 +301,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser { const { index, content_block } = JSON.parse(data); if (responseMessage.content[index] === undefined) responseMessage.content[index] = content_block; - text = responseMessage.content[index].text; + text = (responseMessage.content[index] as any).text; } else throw new Error('Unexpected content block start'); break; @@ -314,7 +314,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser { throw new Error(`Unexpected content block non-text delta (${delta.type})`); if (responseMessage.content[index] === undefined) throw new Error(`Unexpected content block delta location (${index})`); - responseMessage.content[index].text += delta.text; + (responseMessage.content[index] as any).text += delta.text; text = delta.text; } else throw new Error('Unexpected content block delta'); diff --git a/src/modules/llms/server/ollama/ollama.wiretypes.ts b/src/modules/llms/server/ollama/ollama.wiretypes.ts index 9a600e993..7602b3488 100644 --- a/src/modules/llms/server/ollama/ollama.wiretypes.ts +++ b/src/modules/llms/server/ollama/ollama.wiretypes.ts @@ -101,10 +101,10 @@ export const wireOllamaChunkedOutputSchema = z.union([ // only on the last message // context: z.array(z.number()), // non-chat endpoint // total_duration: z.number(), - // prompt_eval_count: z.number(), + prompt_eval_count: z.number().optional(), // prompt_eval_duration: z.number(), - // eval_count: z.number(), - // eval_duration: z.number(), + eval_count: z.number().optional(), + eval_duration: z.number().optional(), }), // Possible Error diff --git a/src/modules/llms/server/openai/openai.wiretypes.ts 
b/src/modules/llms/server/openai/openai.wiretypes.ts index 898595eda..e1bb81939 100644 --- a/src/modules/llms/server/openai/openai.wiretypes.ts +++ b/src/modules/llms/server/openai/openai.wiretypes.ts @@ -105,6 +105,12 @@ export namespace OpenAIWire { param: string | null; code: string | null; }; + // [OpenRouter/LocalAI] Extended usage statistics + usage?: { + prompt_tokens?: number; + completion_tokens?: number; + total_tokens?: number; + }; } }