From 9e6fd7c24c7789b4158d0cae6be9ef4e360d9963 Mon Sep 17 00:00:00 2001 From: user Date: Fri, 16 Aug 2024 15:45:49 +0300 Subject: [PATCH] Implement tools (function calling) for Claude --- .../request/preprocessors/sign-aws-request.ts | 2 ++ .../preprocessors/sign-vertex-ai-request.ts | 3 ++- src/shared/api-schemas/anthropic.ts | 17 +++++++++++++++++ src/shared/tokenization/claude.ts | 3 +++ 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts index 5a937bc..a564594 100644 --- a/src/proxy/middleware/request/preprocessors/sign-aws-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-aws-request.ts @@ -144,6 +144,8 @@ function applyAwsStrictValidation(req: Request): unknown { temperature: true, top_k: true, top_p: true, + tools: true, + tool_choice: true, }) .strip() .parse(req.body); diff --git a/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts b/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts index afc6fb7..d95b1be 100644 --- a/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts +++ b/src/proxy/middleware/request/preprocessors/sign-vertex-ai-request.ts @@ -24,7 +24,6 @@ export const signGcpRequest: RequestPreprocessor = async (req) => { req.isStreaming = String(stream) === "true"; // TODO: This should happen in transform-outbound-payload.ts - // TODO: Support tools let strippedParams: Record; strippedParams = AnthropicV1MessagesSchema.pick({ messages: true, @@ -34,6 +33,8 @@ export const signGcpRequest: RequestPreprocessor = async (req) => { temperature: true, top_k: true, top_p: true, + tools: true, + tool_choice: true, stream: true, }) .strip() diff --git a/src/shared/api-schemas/anthropic.ts b/src/shared/api-schemas/anthropic.ts index aa077f3..9bcd132 100644 --- a/src/shared/api-schemas/anthropic.ts +++ b/src/shared/api-schemas/anthropic.ts @@ -19,7 +19,12 @@ const AnthropicV1BaseSchema = z top_k: z.coerce.number().optional(), top_p: z.coerce.number().optional(), metadata: z.object({ user_id: z.string().optional() }).optional(), + tools: z.array(z.any()).optional(), + tool_choice: z.any().optional(), }) + .omit( + Boolean(config.allowOpenAIToolUsage) ? {} : { tools: true, tool_choice: true } + ) .strip(); // https://docs.anthropic.com/claude/reference/complete_post [deprecated] @@ -44,6 +49,18 @@ const AnthropicV1MessageMultimodalContentSchema = z.array( data: z.string(), }), }), + z.object({ + type: z.literal("tool_use"), + id: z.string(), + name: z.string(), + input: z.object({}).passthrough(), + }), + z.object({ + type: z.literal("tool_result"), + tool_use_id: z.string(), + is_error: z.boolean().optional(), + content: z.union([z.string(), z.object({}).passthrough()]).optional(), + }), ]) ); diff --git a/src/shared/tokenization/claude.ts b/src/shared/tokenization/claude.ts index 880db82..891ed2f 100644 --- a/src/shared/tokenization/claude.ts +++ b/src/shared/tokenization/claude.ts @@ -67,6 +67,9 @@ async function getTokenCountForMessages({ case "image": numTokens += await getImageTokenCount(part.source.data); break; + case "tool_use": + case "tool_result": + break; default: throw new Error(`Unsupported Anthropic content type.`); }