AIX: Intake: improve schemas

This commit is contained in:
Enrico Ros
2024-07-12 01:48:23 -07:00
parent 03d633715a
commit 6fc6b23f38
5 changed files with 62 additions and 64 deletions
+5 -5
View File
@@ -6,7 +6,7 @@ import { fetchResponseOrTRPCThrow } from '~/server/api/trpc.router.fetchers';
import { IntakeHandler } from './intake/IntakeHandler';
import { createDispatch } from './dispatch/createDispatch';
import { intakeAccessSchema, intakeChatGenerateRequestSchema, intakeContextChatStreamSchema, intakeModelSchema } from './intake/schemas.intake.api';
import { intake_Access_Schema, intake_ChatGenerateRequest_Schema, intake_ContextChatStream_Schema, intake_Model_Schema } from './intake/schemas.intake.api';
export const aixRouter = createTRPCRouter({
@@ -17,10 +17,10 @@ export const aixRouter = createTRPCRouter({
*/
chatGenerateContent: publicProcedure
.input(z.object({
access: intakeAccessSchema,
model: intakeModelSchema,
chatGenerate: intakeChatGenerateRequestSchema,
context: intakeContextChatStreamSchema,
access: intake_Access_Schema,
model: intake_Model_Schema,
chatGenerate: intake_ChatGenerateRequest_Schema,
context: intake_ContextChatStream_Schema,
streaming: z.boolean(),
_debugRequestBody: z.boolean().optional(),
}))
@@ -1,6 +1,6 @@
import type { IntakeChatGenerateRequest, IntakeModel } from '../../intake/schemas.intake.api';
import type { IntakeChatMessage } from '../../intake/schemas.intake.parts';
import type { IntakeToolDefinition, IntakeToolsPolicy } from '../../intake/schemas.intake.tools';
import type { Intake_ChatGenerateRequest, Intake_Model } from '../../intake/schemas.intake.api';
import type { Intake_ChatMessage } from '../../intake/schemas.intake.parts';
import type { Intake_ToolDefinition, Intake_ToolsPolicy } from '../../intake/schemas.intake.tools';
import { anthropicWire_ImageBlock, AnthropicWire_MessageCreate, anthropicWire_MessageCreate_Schema, anthropicWire_TextBlock, anthropicWire_ToolResultBlock, anthropicWire_ToolUseBlock } from './anthropic.wiretypes';
@@ -8,12 +8,10 @@ import { anthropicWire_ImageBlock, AnthropicWire_MessageCreate, anthropicWire_Me
// configuration
const hotFixImagePartsFirst = true;
const hotFixMapModelImagesToUser = true;
// max from https://docs.anthropic.com/en/docs/about-claude/models
const ANTHROPIC_FALLBACK_MAX_TOKENS = 4096;
const hotFixMissingTokens = 4096; // [2024-07-12] max from https://docs.anthropic.com/en/docs/about-claude/models
export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate: IntakeChatGenerateRequest, streaming: boolean): AnthropicWire_MessageCreate {
export function intakeToAnthropicMessageCreate(model: Intake_Model, chatGenerate: Intake_ChatGenerateRequest, streaming: boolean): AnthropicWire_MessageCreate {
// Convert the system message
const systemMessage: AnthropicWire_MessageCreate['system'] = chatGenerate.systemMessage?.parts.length
@@ -38,7 +36,7 @@ export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate:
// Construct the request payload
const payload: AnthropicWire_MessageCreate = {
max_tokens: model.maxTokens !== undefined ? model.maxTokens : ANTHROPIC_FALLBACK_MAX_TOKENS,
max_tokens: model.maxTokens !== undefined ? model.maxTokens : hotFixMissingTokens,
model: model.id,
system: systemMessage,
messages: chatMessages,
@@ -61,7 +59,7 @@ export function intakeToAnthropicMessageCreate(model: IntakeModel, chatGenerate:
}
function* _generateAnthropicMessagesContentBlocks({ parts, role }: IntakeChatMessage): Generator<{
function* _generateAnthropicMessagesContentBlocks({ parts, role }: Intake_ChatMessage): Generator<{
role: 'user' | 'assistant',
content: AnthropicWire_MessageCreate['messages'][number]['content'][number]
}> {
@@ -130,7 +128,7 @@ function* _generateAnthropicMessagesContentBlocks({ parts, role }: IntakeChatMes
}
}
function _intakeToAnthropicTools(itds: IntakeToolDefinition[]): NonNullable<AnthropicWire_MessageCreate['tools']> {
function _intakeToAnthropicTools(itds: Intake_ToolDefinition[]): NonNullable<AnthropicWire_MessageCreate['tools']> {
return itds.map(itd => {
switch (itd.type) {
case 'function_call':
@@ -152,7 +150,7 @@ function _intakeToAnthropicTools(itds: IntakeToolDefinition[]): NonNullable<Anth
});
}
function _intakeToAnthropicToolChoice(itp: IntakeToolsPolicy): NonNullable<AnthropicWire_MessageCreate['tool_choice']> {
function _intakeToAnthropicToolChoice(itp: Intake_ToolsPolicy): NonNullable<AnthropicWire_MessageCreate['tool_choice']> {
switch (itp.type) {
case 'auto':
return { type: 'auto' as const };
@@ -5,20 +5,20 @@ import { geminiAccessSchema } from '~/modules/llms/server/gemini/gemini.router';
import { ollamaAccessSchema } from '~/modules/llms/server/ollama/ollama.router';
import { openAIAccessSchema } from '~/modules/llms/server/openai/openai.router';
import { intakeChatMessageSchema, intakeSystemMessageSchema } from './schemas.intake.parts';
import { intakeToolDefinitionSchema, intakeToolsPolicySchema } from './schemas.intake.tools';
import { intake_ChatMessage_Schema, intake_SystemMessage_Schema } from './schemas.intake.parts';
import { intake_ToolDefinition_Schema, intake_ToolsPolicy_Schema } from './schemas.intake.tools';
// Exported types: static TypeScript types derived with z.infer from the Zod schemas declared in this file
export type IntakeAccess = z.infer<typeof intakeAccessSchema>;
export type IntakeModel = z.infer<typeof intakeModelSchema>;
export type IntakeChatGenerateRequest = z.infer<typeof intakeChatGenerateRequestSchema>;
export type IntakeContextChatStream = z.infer<typeof intakeContextChatStreamSchema>;
export type Intake_Access = z.infer<typeof intake_Access_Schema>;
export type Intake_Model = z.infer<typeof intake_Model_Schema>;
export type Intake_ChatGenerateRequest = z.infer<typeof intake_ChatGenerateRequest_Schema>;
export type Intake_ContextChatStream = z.infer<typeof intake_ContextChatStream_Schema>;
// Intake Access Schema
export const intakeAccessSchema = z.discriminatedUnion(
export const intake_Access_Schema = z.discriminatedUnion(
'dialect',
[
anthropicAccessSchema,
@@ -31,7 +31,7 @@ export const intakeAccessSchema = z.discriminatedUnion(
// Intake Model Schema
export const intakeModelSchema = z.object({
export const intake_Model_Schema = z.object({
id: z.string(),
temperature: z.number().min(0).max(2).optional(),
maxTokens: z.number().min(1).max(1000000).optional(),
@@ -40,17 +40,17 @@ export const intakeModelSchema = z.object({
// Intake Content Generation Schema
// Shape of a chat generation request: an optional system message, a required array of
// chat messages, an optional array of tool definitions, and an optional tools policy.
export const intakeChatGenerateRequestSchema = z.object({
systemMessage: intakeSystemMessageSchema.optional(),
chatSequence: z.array(intakeChatMessageSchema),
tools: z.array(intakeToolDefinitionSchema).optional(),
toolsPolicy: intakeToolsPolicySchema.optional(),
export const intake_ChatGenerateRequest_Schema = z.object({
systemMessage: intake_SystemMessage_Schema.optional(),
chatSequence: z.array(intake_ChatMessage_Schema),
tools: z.array(intake_ToolDefinition_Schema).optional(),
toolsPolicy: intake_ToolsPolicy_Schema.optional(),
});
// Intake Context (Streaming) Schema
export const intakeContextChatStreamSchema = z.object({
export const intake_ContextChatStream_Schema = z.object({
method: z.literal('chat-stream'),
name: z.enum(['conversation', 'ai-diagram', 'ai-flattener', 'call', 'beam-scatter', 'beam-gather', 'persona-extract']),
ref: z.string(),
@@ -2,26 +2,26 @@ import { z } from 'zod';
// Export types
export type IntakeInlineImagePart = z.infer<typeof intakeInlineImagePartSchema>;
export type IntakeMetaReplyToPart = z.infer<typeof intakeMetaReplyToPartSchema>;
export type IntakeChatMessage = z.infer<typeof intakeChatMessageSchema>;
export type IntakeSystemMessage = z.infer<typeof intakeSystemMessageSchema>;
export type Intake_InlineImagePart = z.infer<typeof intake_InlineImagePart_Schema>;
export type Intake_MetaReplyToPart = z.infer<typeof intake_MetaReplyToPart_Schema>;
export type Intake_ChatMessage = z.infer<typeof intake_ChatMessage_Schema>;
export type Intake_SystemMessage = z.infer<typeof intake_SystemMessage_Schema>;
// Parts: mirror the TypeScript definitions from the frontend side
// Inline data payload: `idt` is fixed to the literal 'text'; carries the text itself
// and an optional MIME type string.
const dMessageDataInlineSchema = z.object({
const dMessage_DataInline_Schema = z.object({
idt: z.literal('text'),
text: z.string(),
mimeType: z.string().optional(),
});
// Text part: tagged with `pt: 'text'` and carrying a single `text` string.
const dMessageTextPartSchema = z.object({
const dMessage_TextPart_Schema = z.object({
pt: z.literal('text'),
text: z.string(),
});
const dMessageDocPartSchema = z.object({
const dMessage_DocPart_Schema = z.object({
pt: z.literal('doc'),
type: z.enum([
@@ -32,7 +32,7 @@ const dMessageDocPartSchema = z.object({
'text/plain',
]),
data: dMessageDataInlineSchema,
data: dMessage_DataInline_Schema,
// id of the document, to be known to the model
ref: z.string(),
@@ -40,14 +40,14 @@ const dMessageDocPartSchema = z.object({
// meta: ignored...
});
// Tool-call part: tagged with `pt: 'tool_call'`; holds the call `id`, the tool `name`,
// and optional `args` as a string-keyed record whose values are not validated (z.any).
const dMessageToolCallPartSchema = z.object({
const dMessage_ToolCallPart_Schema = z.object({
pt: z.literal('tool_call'),
id: z.string(),
name: z.string(),
args: z.record(z.any()).optional(),
});
const dMessageToolResponsePartSchema = z.object({
const dMessage_ToolResponsePart_Schema = z.object({
pt: z.literal('tool_response'),
id: z.string(),
name: z.string(),
@@ -56,7 +56,7 @@ const dMessageToolResponsePartSchema = z.object({
});
const intakeInlineImagePartSchema = z.object({
const intake_InlineImagePart_Schema = z.object({
pt: z.literal('inline_image'),
/**
* The MIME type of the image.
@@ -75,7 +75,7 @@ const intakeInlineAudioPartSchema = z.object({
base64: z.string(),
});*/
// Meta part: tagged with `pt: 'meta_reply_to'` and carrying a `replyTo` string.
const intakeMetaReplyToPartSchema = z.object({
const intake_MetaReplyToPart_Schema = z.object({
pt: z.literal('meta_reply_to'),
replyTo: z.string(),
});
@@ -83,17 +83,17 @@ const intakeMetaReplyToPartSchema = z.object({
// Messages
// System message: an object whose `parts` is an array of text parts only.
export const intakeSystemMessageSchema = z.object({
parts: z.array(dMessageTextPartSchema),
export const intake_SystemMessage_Schema = z.object({
parts: z.array(dMessage_TextPart_Schema),
});
export const intakeChatMessageSchema = z.discriminatedUnion('role', [
export const intake_ChatMessage_Schema = z.discriminatedUnion('role', [
// User
z.object({
role: z.literal('user'),
parts: z.array(z.discriminatedUnion('pt', [
dMessageTextPartSchema, intakeInlineImagePartSchema, dMessageDocPartSchema, intakeMetaReplyToPartSchema,
dMessage_TextPart_Schema, intake_InlineImagePart_Schema, dMessage_DocPart_Schema, intake_MetaReplyToPart_Schema,
])),
}),
@@ -101,14 +101,14 @@ export const intakeChatMessageSchema = z.discriminatedUnion('role', [
z.object({
role: z.literal('model'),
parts: z.array(z.discriminatedUnion('pt', [
dMessageTextPartSchema, intakeInlineImagePartSchema, dMessageToolCallPartSchema,
dMessage_TextPart_Schema, intake_InlineImagePart_Schema, dMessage_ToolCallPart_Schema,
])),
}),
// Tool
z.object({
role: z.literal('tool'),
parts: z.array(dMessageToolResponsePartSchema),
parts: z.array(dMessage_ToolResponsePart_Schema),
}),
]);
@@ -2,8 +2,8 @@ import { z } from 'zod';
// Export types
export type IntakeToolDefinition = z.infer<typeof intakeToolDefinitionSchema>;
export type IntakeToolsPolicy = z.infer<typeof intakeToolsPolicySchema>;
export type Intake_ToolDefinition = z.infer<typeof intake_ToolDefinition_Schema>;
export type Intake_ToolsPolicy = z.infer<typeof intake_ToolsPolicy_Schema>;
// Tools > Function Call
@@ -18,7 +18,7 @@ export type IntakeToolsPolicy = z.infer<typeof intakeToolsPolicySchema>;
* of the properties for our function calling use case.
*
*/
const openAPISchemaObjectSchema = z.object({
export const openAPI_SchemaObject_Schema = z.object({
// allowed data types - https://ai.google.dev/api/rest/v1beta/cachedContents#Type
type: z.enum(['string', 'number', 'integer', 'boolean', 'array', 'object']),
@@ -53,12 +53,12 @@ const openAPISchemaObjectSchema = z.object({
});
// an object-only subset of the above: the JSON object that owns the function's parameters
// Input schema of a function call: `properties` maps string keys to schema objects;
// `required` is an optional array of strings (presumably the names of mandatory
// properties, as in JSON Schema - TODO confirm).
const intakeFunctionCallInputSchemaSchema = z.object({
properties: z.record(openAPISchemaObjectSchema),
const intake_FunctionCallInputSchema_Schema = z.object({
properties: z.record(openAPI_SchemaObject_Schema),
required: z.array(z.string()).optional(),
});
const intakeFunctionCallSchema = z.object({
const intake_FunctionCall_Schema = z.object({
/**
* The name of the function to call. Up to 64 characters long, and can only contain letters, numbers, underscores, and hyphens.
*/
@@ -75,19 +75,19 @@ const intakeFunctionCallSchema = z.object({
* A JSON Schema object defining the expected parameters for the function call.
* (OpenAI,Google: parameters, Anthropic: input_schema)
*/
input_schema: intakeFunctionCallInputSchemaSchema.optional(),
input_schema: intake_FunctionCallInputSchema_Schema.optional(),
});
// Tool definition variant tagged with `type: 'function_call'`; its `function_call`
// field holds the function call description validated by the function-call schema.
const intakeToolFunctionCallDefinitionSchema = z.object({
const intake_ToolFunctionCallDefinition_Schema = z.object({
type: z.literal('function_call'),
function_call: intakeFunctionCallSchema,
function_call: intake_FunctionCall_Schema,
// domain: z.enum(['server', 'client']).optional(),
});
// Tools - Gemini Code Interpreter
// Tool definition variant tagged with `type: 'gemini_code_interpreter'`; it declares no other fields.
const intakeToolGeminiCodeInterpreterSchema = z.object({
const intake_ToolGeminiCodeInterpreter_Schema = z.object({
type: z.literal('gemini_code_interpreter'),
});
@@ -102,7 +102,7 @@ const intakeToolGeminiCodeInterpreterSchema = z.object({
*
* In the future we can have multiple preprocessors, such as data retrieval and generation (rag), etc.
*/
const intakeToolPreprocessorSchema = z.object({
const intake_ToolPreprocessor_Schema = z.object({
// tag value that selects this variant among the tool definitions
type: z.literal('preprocessor'),
// 'anthropic_artifacts' is the only preprocessor name this schema accepts
pname: z.literal('anthropic_artifacts'),
});
@@ -135,10 +135,10 @@ const intakeToolPreprocessorSchema = z.object({
* { type: 'preprocessor', pname: 'anthropic_artifacts' },
* ]
*/
export const intakeToolDefinitionSchema = z.discriminatedUnion('type', [
intakeToolFunctionCallDefinitionSchema,
intakeToolGeminiCodeInterpreterSchema,
intakeToolPreprocessorSchema,
export const intake_ToolDefinition_Schema = z.discriminatedUnion('type', [
// one member per `type` literal: 'function_call', 'gemini_code_interpreter', 'preprocessor'
intake_ToolFunctionCallDefinition_Schema,
intake_ToolGeminiCodeInterpreter_Schema,
intake_ToolPreprocessor_Schema,
]);
/**
@@ -148,7 +148,7 @@ export const intakeToolDefinitionSchema = z.discriminatedUnion('type', [
* - function_call: must use a specific Function Tool
* - none: same as not giving the model any tool [REMOVED - just give no tools]
*/
export const intakeToolsPolicySchema = z.discriminatedUnion('type', [
export const intake_ToolsPolicy_Schema = z.discriminatedUnion('type', [
z.object({ type: z.literal('auto') }),
z.object({ type: z.literal('any') /*, parallel: z.boolean()*/ }),
z.object({ type: z.literal('function_call'), function_call: z.object({ name: z.string() }) }),