Gemini: update dispatch responses

This commit is contained in:
Enrico Ros
2024-07-09 12:55:10 -07:00
parent b58e0f85f9
commit 21ec7219c3
2 changed files with 56 additions and 31 deletions
@@ -12,7 +12,7 @@ import { llmsChatGenerateOutputSchema, llmsGenerateContextSchema, llmsListModels
import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';
import { GeminiBlockSafetyLevel, geminiBlockSafetyLevelSchema, GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes';
import { GeminiBlockSafetyLevel, geminiBlockSafetyLevelEnum, GeminiContentSchema, GeminiGenerateContentRequest, geminiGeneratedContentResponseSchema, geminiModelsGenerateContentPath, geminiModelsListOutputSchema, geminiModelsListPath } from './gemini.wiretypes';
import { geminiFilterModels, geminiModelToModelDescription, geminiSortModels } from '~/modules/llms/server/gemini/gemini.models';
@@ -110,7 +110,7 @@ async function geminiPOST<TOut extends object, TPostBody extends object>(access:
/**
 * Zod schema describing the access parameters for the 'gemini' dialect:
 * a fixed `dialect` discriminator, a `geminiKey` string and a `minSafetyLevel`.
 *
 * NOTE(review): `geminiKey` is presumably the API key — confirm against the caller.
 */
export const geminiAccessSchema = z.object({
dialect: z.enum(['gemini']),
geminiKey: z.string(),
// NOTE(review): `minSafetyLevel` is listed twice below (once with the `...Schema`
// name, once with the `...Enum` name); this looks like the before/after pair of a
// rename — confirm that only the `...Enum` line is meant to remain.
minSafetyLevel: geminiBlockSafetyLevelSchema,
minSafetyLevel: geminiBlockSafetyLevelEnum,
});
/** Static type inferred from `geminiAccessSchema`. */
export type GeminiAccessSchema = z.infer<typeof geminiAccessSchema>;
@@ -148,7 +148,7 @@ const geminiToolConfigSchema = z.object({
}).optional(),
});
const geminiHarmCategorySchema = z.enum([
const geminiHarmCategoryEnum = z.enum([
'HARM_CATEGORY_UNSPECIFIED',
'HARM_CATEGORY_DEROGATORY',
'HARM_CATEGORY_TOXICITY',
@@ -162,7 +162,7 @@ const geminiHarmCategorySchema = z.enum([
'HARM_CATEGORY_DANGEROUS_CONTENT',
]);
export const geminiBlockSafetyLevelSchema = z.enum([
export const geminiBlockSafetyLevelEnum = z.enum([
'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
'BLOCK_LOW_AND_ABOVE',
'BLOCK_MEDIUM_AND_ABOVE',
@@ -170,11 +170,11 @@ export const geminiBlockSafetyLevelSchema = z.enum([
'BLOCK_NONE',
]);
export type GeminiBlockSafetyLevel = z.infer<typeof geminiBlockSafetyLevelSchema>;
export type GeminiBlockSafetyLevel = z.infer<typeof geminiBlockSafetyLevelEnum>;
/**
 * Zod schema for a single safety setting: a harm `category` paired with the
 * blocking `threshold` to apply to it.
 */
const geminiSafetySettingSchema = z.object({
// NOTE(review): `category` and `threshold` each appear twice below (with the
// `...Schema` names and with the `...Enum` names); this looks like the before/after
// pair of a rename — confirm that only the `...Enum` lines are meant to remain.
category: geminiHarmCategorySchema,
threshold: geminiBlockSafetyLevelSchema,
category: geminiHarmCategoryEnum,
threshold: geminiBlockSafetyLevelEnum,
});
const geminiGenerationConfigSchema = z.object({
@@ -220,7 +220,8 @@ export function geminiHarmProbabilitySortFunction(a: { probability: string }, b:
return order.indexOf(a.probability) - order.indexOf(b.probability);
}
const geminiHarmProbabilitySchema = z.enum([
const geminiHarmProbabilityEnum = z.enum([
'HARM_PROBABILITY_UNSPECIFIED',
'NEGLIGIBLE',
'LOW',
@@ -228,14 +229,7 @@ const geminiHarmProbabilitySchema = z.enum([
'HIGH',
]);
export type GeminiSafetyRatings = z.infer<typeof geminiSafetyRatingsSchema>;
const geminiSafetyRatingsSchema = z.array(z.object({
'category': geminiHarmCategorySchema,
'probability': geminiHarmProbabilitySchema,
'blocked': z.boolean().optional(),
}));
const geminiFinishReasonSchema = z.enum([
const geminiFinishReasonEnum = z.enum([
'FINISH_REASON_UNSPECIFIED',
'STOP',
'MAX_TOKENS',
@@ -244,13 +238,54 @@ const geminiFinishReasonSchema = z.enum([
'OTHER',
]);
/** Allowed values for the `blockReason` field of the prompt feedback. */
const geminiBlockReasonEnum = z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']);
/**
 * Schema for a list of safety ratings. Each entry pairs a harm category with
 * a harm probability and may additionally carry a `blocked` flag.
 */
const geminiSafetyRatingsSchema = z.array(
  z.object({
    category: geminiHarmCategoryEnum,
    probability: geminiHarmProbabilityEnum,
    blocked: z.boolean().optional(),
  }),
);
/** Static type of a parsed safety-ratings array. */
export type GeminiSafetyRatings = z.infer<typeof geminiSafetyRatingsSchema>;
/*const geminiGroundingAttributionSchema = z.object({
sourceId: z.object({
groundingPassage: z.object({
passageId: z.string(),
partIndex: z.number(),
}).optional(),
semanticRetrieverChunk: z.object({
source: z.string(),
chunk: z.string(),
}).optional(),
}),
content: geminiContentSchema,
});*/
/**
 * Feedback about the prompt itself: an optional reason it was blocked and the
 * optional safety ratings associated with it.
 */
const geminiPromptFeedbackSchema = z.object({
  blockReason: geminiBlockReasonEnum.optional(),
  safetyRatings: geminiSafetyRatingsSchema.optional(),
});
/**
 * Token accounting attached to a response.
 * Prompt, candidates and total counts are required numbers; the cached-content
 * count is optional.
 */
const geminiUsageMetadataSchema = z.object({
  promptTokenCount: z.number(),
  cachedContentTokenCount: z.number().optional(),
  candidatesTokenCount: z.number(),
  totalTokenCount: z.number(),
});
export const geminiGeneratedContentResponseSchema = z.object({
// either all requested candidates are returned or no candidates at all
// no candidates are returned only if there was something wrong with the prompt (see promptFeedback)
candidates: z.array(z.object({
index: z.number(),
content: geminiContentSchema.optional(), // this can be missing if the finishReason is not 'MAX_TOKENS'
finishReason: geminiFinishReasonSchema.optional(),
finishReason: geminiFinishReasonEnum.optional(),
safetyRatings: geminiSafetyRatingsSchema.optional(), // undefined when finishReason is 'RECITATION'
citationMetadata: z.object({
startIndex: z.number().optional(),
@@ -259,18 +294,8 @@ export const geminiGeneratedContentResponseSchema = z.object({
license: z.string().optional(),
}).optional(),
tokenCount: z.number().optional(),
// groundingAttributions: z.array(GroundingAttribution).optional(), // This field is populated for GenerateAnswer calls.
})).optional(),
usageMetadata: z.object({
promptTokenCount: z.number(),
candidatesTokenCount: z.number(),
totalTokenCount: z.number(),
}).optional(),
// NOTE: promptFeedback is only sent in the first chunk of a streaming response
promptFeedback: z.object({
blockReason: z.enum(['BLOCK_REASON_UNSPECIFIED', 'SAFETY', 'OTHER']).optional(),
safetyRatings: geminiSafetyRatingsSchema.optional(),
}).optional(),
// groundingAttributions: z.array(geminiGroundingAttributionSchema).optional(), // This field is populated for GenerateAnswer calls.
})),
promptFeedback: geminiPromptFeedbackSchema.optional(), // only sent in the 1st chunk of a streaming response
usageMetadata: geminiUsageMetadataSchema.optional(), // only use (sent?) at the end
});