mirror of
https://github.com/enricoros/big-AGI.git
synced 2026-05-10 21:50:14 -07:00
Llms: streaming as a vendor function (then all directed to the unified)
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
export const runtime = 'edge';
|
||||
export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llms.streaming';
|
||||
export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llm.server.streaming';
|
||||
@@ -13,10 +13,9 @@ import RecordVoiceOverIcon from '@mui/icons-material/RecordVoiceOver';
|
||||
|
||||
import { useChatLLMDropdown } from '../chat/components/applayout/useLLMDropdown';
|
||||
|
||||
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
|
||||
import { EXPERIMENTAL_speakTextStream } from '~/modules/elevenlabs/elevenlabs.client';
|
||||
import { SystemPurposeId, SystemPurposes } from '../../data';
|
||||
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
|
||||
import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';
|
||||
import { useElevenLabsVoiceDropdown } from '~/modules/elevenlabs/useElevenLabsVoiceDropdown';
|
||||
|
||||
import { Link } from '~/common/components/Link';
|
||||
@@ -216,7 +215,7 @@ export function CallUI(props: {
|
||||
responseAbortController.current = new AbortController();
|
||||
let finalText = '';
|
||||
let error: any | null = null;
|
||||
llmStreamChatGenerate(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial<DMessage>) => {
|
||||
llmStreamingChatGenerate(chatLLMId, callPrompt, null, null, responseAbortController.current.signal, (updatedMessage: Partial<DMessage>) => {
|
||||
const text = updatedMessage.text?.trim();
|
||||
if (text) {
|
||||
finalText = text;
|
||||
|
||||
@@ -3,7 +3,7 @@ import * as React from 'react';
|
||||
import { Chip, ColorPaletteProp, VariantProp } from '@mui/joy';
|
||||
import { SxProps } from '@mui/joy/styles/types';
|
||||
|
||||
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
|
||||
import type { VChatMessageIn } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
export function CallMessage(props: {
|
||||
|
||||
@@ -2,7 +2,7 @@ import { DLLMId } from '~/modules/llms/store-llms';
|
||||
import { SystemPurposeId } from '../../../data';
|
||||
import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions';
|
||||
import { autoTitle } from '~/modules/aifn/autotitle/autoTitle';
|
||||
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
|
||||
import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';
|
||||
import { speakText } from '~/modules/elevenlabs/elevenlabs.client';
|
||||
|
||||
import { DMessage, useChatStore } from '~/common/state/store-chats';
|
||||
@@ -63,7 +63,7 @@ async function streamAssistantMessage(
|
||||
const messages = history.map(({ role, text }) => ({ role, content: text }));
|
||||
|
||||
try {
|
||||
await llmStreamChatGenerate(llmId, messages, abortSignal,
|
||||
await llmStreamingChatGenerate(llmId, messages, null, null, abortSignal,
|
||||
(updatedMessage: Partial<DMessage>) => {
|
||||
// update the message in the store (and thus schedule a re-render)
|
||||
editMessage(updatedMessage);
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import * as React from 'react';
|
||||
|
||||
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
|
||||
import { DLLMId, useModelsStore } from '~/modules/llms/store-llms';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
|
||||
import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
export interface LLMChainStep {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import type { VChatFunctionIn } from '~/modules/llms/client/llm.client.types';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
|
||||
import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client';
|
||||
import { useModelsStore } from '~/modules/llms/store-llms';
|
||||
|
||||
import { useChatStore } from '~/common/state/store-chats';
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
|
||||
import { useModelsStore } from '~/modules/llms/store-llms';
|
||||
|
||||
import { useChatStore } from '~/common/state/store-chats';
|
||||
|
||||
@@ -8,7 +8,7 @@ import ReplayIcon from '@mui/icons-material/Replay';
|
||||
import StopOutlinedIcon from '@mui/icons-material/StopOutlined';
|
||||
import TelegramIcon from '@mui/icons-material/Telegram';
|
||||
|
||||
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
|
||||
import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';
|
||||
|
||||
import { ChatMessage } from '../../../apps/chat/components/message/ChatMessage';
|
||||
|
||||
@@ -86,7 +86,7 @@ export function DiagramsModal(props: { config: DiagramConfig, onClose: () => voi
|
||||
const diagramPrompt = bigDiagramPrompt(diagramType, diagramLanguage, systemMessage.text, subject, customInstruction);
|
||||
|
||||
try {
|
||||
await llmStreamChatGenerate(diagramLlm.id, diagramPrompt, stepAbortController.signal,
|
||||
await llmStreamingChatGenerate(diagramLlm.id, diagramPrompt, null, null, stepAbortController.signal,
|
||||
(update: Partial<{ text: string, typing: boolean, originLLM: string }>) => {
|
||||
assistantMessage = { ...assistantMessage, ...update };
|
||||
setMessage(assistantMessage);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
|
||||
|
||||
import type { FormRadioOption } from '~/common/components/forms/FormRadioControl';
|
||||
import type { VChatMessageIn } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
export type DiagramType = 'auto' | 'mind';
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
|
||||
import { useModelsStore } from '~/modules/llms/store-llms';
|
||||
|
||||
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
* porting of implementation from here: https://til.simonwillison.net/llms/python-react-pattern
|
||||
*/
|
||||
|
||||
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
|
||||
import { DLLMId } from '~/modules/llms/store-llms';
|
||||
import { callApiSearchGoogle } from '~/modules/google/search.client';
|
||||
import { callBrowseFetchPage } from '~/modules/browse/browse.client';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
|
||||
import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
// prompt to implement the ReAct paradigm: https://arxiv.org/abs/2210.03629
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
|
||||
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
// prompt to be tried when doing recursive summerization.
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import * as React from 'react';
|
||||
|
||||
import type { DLLMId } from '~/modules/llms/store-llms';
|
||||
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
|
||||
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
|
||||
import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
export function useStreamChatText() {
|
||||
@@ -25,7 +24,7 @@ export function useStreamChatText() {
|
||||
|
||||
try {
|
||||
let lastText = '';
|
||||
await llmStreamChatGenerate(llmId, prompt, abortControllerRef.current.signal, (update) => {
|
||||
await llmStreamingChatGenerate(llmId, prompt, null, null, abortControllerRef.current.signal, (update) => {
|
||||
if (update.text) {
|
||||
lastText = update.text;
|
||||
setPartialText(lastText);
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import type { OpenAIWire } from '../server/openai/openai.wiretypes';
|
||||
|
||||
|
||||
// Model List types
|
||||
// export { type ModelDescriptionSchema } from '../server/llm.server.types';
|
||||
|
||||
|
||||
// Chat Generate types
|
||||
|
||||
export interface VChatMessageIn {
|
||||
role: 'assistant' | 'system' | 'user'; // | 'function';
|
||||
content: string;
|
||||
//name?: string; // when role: 'function'
|
||||
}
|
||||
|
||||
export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef;
|
||||
|
||||
export interface VChatMessageOut {
|
||||
role: 'assistant' | 'system' | 'user';
|
||||
content: string;
|
||||
finish_reason: 'stop' | 'length' | null;
|
||||
}
|
||||
|
||||
export interface VChatMessageOrFunctionCallOut extends VChatMessageOut {
|
||||
function_name: string;
|
||||
function_arguments: object | null;
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
import type { DLLMId } from '../store-llms';
|
||||
import { findVendorForLlmOrThrow } from '../vendors/vendors.registry';
|
||||
|
||||
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from './llm.client.types';
|
||||
|
||||
|
||||
export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
|
||||
llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number,
|
||||
): Promise<VChatMessageOut | VChatMessageOrFunctionCallOut> {
|
||||
|
||||
// id to DLLM and vendor
|
||||
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);
|
||||
|
||||
// FIXME: relax the forced cast
|
||||
const options = llm.options as TLLMOptions;
|
||||
|
||||
// get the access
|
||||
const partialSourceSetup = llm._source.setup;
|
||||
const access = vendor.getTransportAccess(partialSourceSetup);
|
||||
|
||||
// execute via the vendor
|
||||
return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens);
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
import type { DLLMId } from './store-llms';
|
||||
import type { OpenAIWire } from './server/openai/openai.wiretypes';
|
||||
import { findVendorForLlmOrThrow } from './vendors/vendors.registry';
|
||||
|
||||
|
||||
// LLM Client Types
|
||||
// NOTE: Model List types in '../server/llm.server.types';
|
||||
|
||||
export interface VChatMessageIn {
|
||||
role: 'assistant' | 'system' | 'user'; // | 'function';
|
||||
content: string;
|
||||
//name?: string; // when role: 'function'
|
||||
}
|
||||
|
||||
export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef;
|
||||
|
||||
export interface VChatMessageOut {
|
||||
role: 'assistant' | 'system' | 'user';
|
||||
content: string;
|
||||
finish_reason: 'stop' | 'length' | null;
|
||||
}
|
||||
|
||||
export interface VChatMessageOrFunctionCallOut extends VChatMessageOut {
|
||||
function_name: string;
|
||||
function_arguments: object | null;
|
||||
}
|
||||
|
||||
|
||||
// LLM Client Functions
|
||||
|
||||
export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
|
||||
llmId: DLLMId,
|
||||
messages: VChatMessageIn[],
|
||||
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
|
||||
maxTokens?: number,
|
||||
): Promise<VChatMessageOut | VChatMessageOrFunctionCallOut> {
|
||||
|
||||
// id to DLLM and vendor
|
||||
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);
|
||||
|
||||
// FIXME: relax the forced cast
|
||||
const options = llm.options as TLLMOptions;
|
||||
|
||||
// get the access
|
||||
const partialSourceSetup = llm._source.setup;
|
||||
const access = vendor.getTransportAccess(partialSourceSetup);
|
||||
|
||||
// execute via the vendor
|
||||
return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens);
|
||||
}
|
||||
|
||||
|
||||
export async function llmStreamingChatGenerate<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
|
||||
llmId: DLLMId,
|
||||
messages: VChatMessageIn[],
|
||||
functions: VChatFunctionIn[] | null,
|
||||
forceFunctionName: string | null,
|
||||
abortSignal: AbortSignal,
|
||||
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
|
||||
): Promise<void> {
|
||||
|
||||
// id to DLLM and vendor
|
||||
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);
|
||||
|
||||
// FIXME: relax the forced cast
|
||||
const llmOptions = llm.options as TLLMOptions;
|
||||
|
||||
// get the access
|
||||
const partialSourceSetup = llm._source.setup;
|
||||
const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access'];
|
||||
|
||||
// execute via the vendor
|
||||
return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, functions, forceFunctionName, abortSignal, onUpdate);
|
||||
}
|
||||
+12
-9
@@ -9,6 +9,9 @@ import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, S
|
||||
import type { AnthropicWire } from './anthropic/anthropic.wiretypes';
|
||||
import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from './anthropic/anthropic.router';
|
||||
|
||||
// Gemini server imports
|
||||
import { geminiAccessSchema } from './gemini/gemini.router';
|
||||
|
||||
// Ollama server imports
|
||||
import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes';
|
||||
import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router';
|
||||
@@ -37,24 +40,24 @@ type EventStreamFormat = 'sse' | 'json-nl';
|
||||
type AIStreamParser = (data: string) => { text: string, close: boolean };
|
||||
|
||||
|
||||
const chatStreamInputSchema = z.object({
|
||||
access: z.union([anthropicAccessSchema, ollamaAccessSchema, openAIAccessSchema]),
|
||||
const chatStreamingInputSchema = z.object({
|
||||
access: z.union([anthropicAccessSchema, geminiAccessSchema, ollamaAccessSchema, openAIAccessSchema]),
|
||||
model: openAIModelSchema,
|
||||
history: openAIHistorySchema,
|
||||
});
|
||||
export type ChatStreamInputSchema = z.infer<typeof chatStreamInputSchema>;
|
||||
export type ChatStreamingInputSchema = z.infer<typeof chatStreamingInputSchema>;
|
||||
|
||||
const chatStreamFirstOutputPacketSchema = z.object({
|
||||
const chatStreamingFirstOutputPacketSchema = z.object({
|
||||
model: z.string(),
|
||||
});
|
||||
export type ChatStreamFirstOutputPacketSchema = z.infer<typeof chatStreamFirstOutputPacketSchema>;
|
||||
export type ChatStreamingFirstOutputPacketSchema = z.infer<typeof chatStreamingFirstOutputPacketSchema>;
|
||||
|
||||
|
||||
export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
|
||||
|
||||
// inputs - reuse the tRPC schema
|
||||
const body = await req.json();
|
||||
const { access, model, history } = chatStreamInputSchema.parse(body);
|
||||
const { access, model, history } = chatStreamingInputSchema.parse(body);
|
||||
|
||||
// access/dialect dependent setup:
|
||||
// - requestAccess: the headers and URL to use for the upstream API call
|
||||
@@ -240,7 +243,7 @@ function createAnthropicStreamParser(): AIStreamParser {
|
||||
// hack: prepend the model name to the first packet
|
||||
if (!hasBegun) {
|
||||
hasBegun = true;
|
||||
const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model };
|
||||
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model };
|
||||
text = JSON.stringify(firstPacket) + text;
|
||||
}
|
||||
|
||||
@@ -276,7 +279,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser {
|
||||
// hack: prepend the model name to the first packet
|
||||
if (!hasBegun && chunk.model) {
|
||||
hasBegun = true;
|
||||
const firstPacket: ChatStreamFirstOutputPacketSchema = { model: chunk.model };
|
||||
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: chunk.model };
|
||||
text = JSON.stringify(firstPacket) + text;
|
||||
}
|
||||
|
||||
@@ -317,7 +320,7 @@ function createOpenAIStreamParser(): AIStreamParser {
|
||||
// hack: prepend the model name to the first packet
|
||||
if (!hasBegun) {
|
||||
hasBegun = true;
|
||||
const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model };
|
||||
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model };
|
||||
text = JSON.stringify(firstPacket) + text;
|
||||
}
|
||||
|
||||
+12
-2
@@ -1,10 +1,10 @@
|
||||
import type React from 'react';
|
||||
import type { TRPCClientErrorBase } from '@trpc/client';
|
||||
|
||||
import type { DLLM, DModelSourceId } from '../store-llms';
|
||||
import type { DLLM, DLLMId, DModelSourceId } from '../store-llms';
|
||||
import type { ModelDescriptionSchema } from '../server/llm.server.types';
|
||||
import type { ModelVendorId } from './vendors.registry';
|
||||
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../client/llm.client.types';
|
||||
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '~/modules/llms/llm.client';
|
||||
|
||||
|
||||
export interface IModelVendor<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown, TDLLM = DLLM<TSourceSetup, TLLMOptions>> {
|
||||
@@ -43,4 +43,14 @@ export interface IModelVendor<TSourceSetup = unknown, TAccess = unknown, TLLMOpt
|
||||
maxTokens?: number,
|
||||
) => Promise<VChatMessageOut | VChatMessageOrFunctionCallOut>;
|
||||
|
||||
streamingChatGenerateOrThrow: (
|
||||
access: TAccess,
|
||||
llmId: DLLMId,
|
||||
llmOptions: TLLMOptions,
|
||||
messages: VChatMessageIn[],
|
||||
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
|
||||
abortSignal: AbortSignal,
|
||||
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
|
||||
) => Promise<void>;
|
||||
|
||||
}
|
||||
|
||||
+5
-1
@@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
|
||||
|
||||
import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router';
|
||||
import type { IModelVendor } from '../IModelVendor';
|
||||
import type { VChatMessageOut } from '../../client/llm.client.types';
|
||||
import type { VChatMessageOut } from '../../llm.client';
|
||||
import { unifiedStreamingClient } from '../unifiedStreamingClient';
|
||||
|
||||
import { LLMOptionsOpenAI } from '../openai/openai.vendor';
|
||||
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';
|
||||
@@ -77,4 +78,7 @@ export const ModelVendorAnthropic: IModelVendor<SourceSetupAnthropic, AnthropicA
|
||||
}
|
||||
},
|
||||
|
||||
// Chat Generate (streaming) with Functions
|
||||
streamingChatGenerateOrThrow: unifiedStreamingClient,
|
||||
|
||||
};
|
||||
|
||||
@@ -61,4 +61,5 @@ export const ModelVendorAzure: IModelVendor<SourceSetupAzure, OpenAIAccessSchema
|
||||
// OpenAI transport ('azure' dialect in 'access')
|
||||
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
|
||||
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
|
||||
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
|
||||
};
|
||||
+5
-1
@@ -6,7 +6,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
|
||||
|
||||
import type { GeminiAccessSchema } from '../../server/gemini/gemini.router';
|
||||
import type { IModelVendor } from '../IModelVendor';
|
||||
import type { VChatMessageOut } from '../../client/llm.client.types';
|
||||
import type { VChatMessageOut } from '../../llm.client';
|
||||
import { unifiedStreamingClient } from '../unifiedStreamingClient';
|
||||
|
||||
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';
|
||||
|
||||
@@ -86,4 +87,7 @@ export const ModelVendorGemini: IModelVendor<SourceSetupGemini, GeminiAccessSche
|
||||
}
|
||||
},
|
||||
|
||||
// Chat Generate (streaming) with Functions
|
||||
streamingChatGenerateOrThrow: unifiedStreamingClient,
|
||||
|
||||
};
|
||||
@@ -41,4 +41,5 @@ export const ModelVendorLocalAI: IModelVendor<SourceSetupLocalAI, OpenAIAccessSc
|
||||
// OpenAI transport ('localai' dialect in 'access')
|
||||
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
|
||||
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
|
||||
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
|
||||
};
|
||||
|
||||
@@ -51,4 +51,5 @@ export const ModelVendorMistral: IModelVendor<SourceSetupMistral, OpenAIAccessSc
|
||||
// OpenAI transport ('mistral' dialect in 'access')
|
||||
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
|
||||
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
|
||||
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
|
||||
};
|
||||
+5
-1
@@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
|
||||
|
||||
import type { IModelVendor } from '../IModelVendor';
|
||||
import type { OllamaAccessSchema } from '../../server/ollama/ollama.router';
|
||||
import type { VChatMessageOut } from '../../client/llm.client.types';
|
||||
import type { VChatMessageOut } from '../../llm.client';
|
||||
import { unifiedStreamingClient } from '../unifiedStreamingClient';
|
||||
|
||||
import type { LLMOptionsOpenAI } from '../openai/openai.vendor';
|
||||
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';
|
||||
@@ -70,4 +71,7 @@ export const ModelVendorOllama: IModelVendor<SourceSetupOllama, OllamaAccessSche
|
||||
}
|
||||
},
|
||||
|
||||
// Chat Generate (streaming) with Functions
|
||||
streamingChatGenerateOrThrow: unifiedStreamingClient,
|
||||
|
||||
};
|
||||
|
||||
@@ -41,4 +41,5 @@ export const ModelVendorOoobabooga: IModelVendor<SourceSetupOobabooga, OpenAIAcc
|
||||
// OpenAI transport (oobabooga dialect in 'access')
|
||||
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
|
||||
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
|
||||
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
|
||||
};
|
||||
+5
-1
@@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
|
||||
|
||||
import type { IModelVendor } from '../IModelVendor';
|
||||
import type { OpenAIAccessSchema } from '../../server/openai/openai.router';
|
||||
import type { VChatMessageOrFunctionCallOut } from '../../client/llm.client.types';
|
||||
import type { VChatMessageOrFunctionCallOut } from '../../llm.client';
|
||||
import { unifiedStreamingClient } from '../unifiedStreamingClient';
|
||||
|
||||
import { OpenAILLMOptions } from './OpenAILLMOptions';
|
||||
import { OpenAISourceSetup } from './OpenAISourceSetup';
|
||||
@@ -84,4 +85,7 @@ export const ModelVendorOpenAI: IModelVendor<SourceSetupOpenAI, OpenAIAccessSche
|
||||
}
|
||||
},
|
||||
|
||||
// Chat Generate (streaming) with Functions
|
||||
streamingChatGenerateOrThrow: unifiedStreamingClient,
|
||||
|
||||
};
|
||||
|
||||
@@ -62,4 +62,5 @@ export const ModelVendorOpenRouter: IModelVendor<SourceSetupOpenRouter, OpenAIAc
|
||||
// OpenAI transport ('openrouter' dialect in 'access')
|
||||
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
|
||||
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
|
||||
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
|
||||
};
|
||||
+12
-27
@@ -1,13 +1,11 @@
|
||||
import { apiAsync } from '~/common/util/trpc.client';
|
||||
|
||||
import type { ChatStreamFirstOutputPacketSchema, ChatStreamInputSchema } from '../server/llms.streaming';
|
||||
import type { DLLM, DLLMId } from '../store-llms';
|
||||
import { findVendorForLlmOrThrow } from '../vendors/vendors.registry';
|
||||
import type { ChatStreamingFirstOutputPacketSchema, ChatStreamingInputSchema } from '../server/llm.server.streaming';
|
||||
import type { DLLMId } from '../store-llms';
|
||||
import type { VChatFunctionIn, VChatMessageIn } from '../llm.client';
|
||||
|
||||
import type { OpenAIWire } from '../server/openai/openai.wiretypes';
|
||||
|
||||
import type { VChatMessageIn } from './llm.client.types';
|
||||
|
||||
|
||||
/**
|
||||
* Client side chat generation, with streaming. This decodes the (text) streaming response from
|
||||
@@ -16,27 +14,14 @@ import type { VChatMessageIn } from './llm.client.types';
|
||||
* Vendor-specific implementation is on our server backend (API) code. This function tries to be
|
||||
* as generic as possible.
|
||||
*
|
||||
* @param llmId LLM to use
|
||||
* @param messages the history of messages to send to the API endpoint
|
||||
* @param abortSignal used to initiate a client-side abort of the fetch request to the API endpoint
|
||||
* @param onUpdate callback when a piece of a message (text, model name, typing..) is received
|
||||
* NOTE: onUpdate is callback when a piece of a message (text, model name, typing..) is received
|
||||
*/
|
||||
export async function llmStreamChatGenerate(
|
||||
export async function unifiedStreamingClient<TSourceSetup = unknown, TLLMOptions = unknown>(
|
||||
access: ChatStreamingInputSchema['access'],
|
||||
llmId: DLLMId,
|
||||
llmOptions: TLLMOptions,
|
||||
messages: VChatMessageIn[],
|
||||
abortSignal: AbortSignal,
|
||||
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
|
||||
): Promise<void> {
|
||||
const { llm, vendor } = findVendorForLlmOrThrow(llmId);
|
||||
const access = vendor.getTransportAccess(llm._source.setup) as ChatStreamInputSchema['access'];
|
||||
return await vendorStreamChat(access, llm, messages, abortSignal, onUpdate);
|
||||
}
|
||||
|
||||
|
||||
async function vendorStreamChat<TSourceSetup = unknown, TLLMOptions = unknown>(
|
||||
access: ChatStreamInputSchema['access'],
|
||||
llm: DLLM<TSourceSetup, TLLMOptions>,
|
||||
messages: VChatMessageIn[],
|
||||
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
|
||||
abortSignal: AbortSignal,
|
||||
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
|
||||
) {
|
||||
@@ -80,12 +65,12 @@ async function vendorStreamChat<TSourceSetup = unknown, TLLMOptions = unknown>(
|
||||
}
|
||||
|
||||
// model params (llm)
|
||||
const { llmRef, llmTemperature, llmResponseTokens } = (llm.options as any) || {};
|
||||
const { llmRef, llmTemperature, llmResponseTokens } = (llmOptions as any) || {};
|
||||
if (!llmRef || llmTemperature === undefined || llmResponseTokens === undefined)
|
||||
throw new Error(`Error in configuration for model ${llm.id}: ${JSON.stringify(llm.options)}`);
|
||||
throw new Error(`Error in configuration for model ${llmId}: ${JSON.stringify(llmOptions)}`);
|
||||
|
||||
// prepare the input, similarly to the tRPC openAI.chatGenerate
|
||||
const input: ChatStreamInputSchema = {
|
||||
const input: ChatStreamingInputSchema = {
|
||||
access,
|
||||
model: {
|
||||
id: llmRef,
|
||||
@@ -132,7 +117,7 @@ async function vendorStreamChat<TSourceSetup = unknown, TLLMOptions = unknown>(
|
||||
incrementalText = incrementalText.substring(endOfJson + 1);
|
||||
parsedFirstPacket = true;
|
||||
try {
|
||||
const parsed: ChatStreamFirstOutputPacketSchema = JSON.parse(json);
|
||||
const parsed: ChatStreamingFirstOutputPacketSchema = JSON.parse(json);
|
||||
onUpdate({ originLLM: parsed.model }, false);
|
||||
} catch (e) {
|
||||
// error parsing JSON, ignore
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
import { ModelVendorAnthropic } from './anthropic/anthropic.vendor';
|
||||
import { ModelVendorAzure } from './azure/azure.vendor';
|
||||
import { ModelVendorGemini } from './googleai/gemini.vendor';
|
||||
import { ModelVendorGemini } from './gemini/gemini.vendor';
|
||||
import { ModelVendorLocalAI } from './localai/localai.vendor';
|
||||
import { ModelVendorMistral } from './mistral/mistral.vendor';
|
||||
import { ModelVendorOllama } from './ollama/ollama.vendor';
|
||||
|
||||
Reference in New Issue
Block a user