Llms: streaming as a vendor function (all vendors routed to the unified client)

Enrico Ros
2023-12-19 19:00:19 -08:00
parent 0ece1ce58c
commit bee49a4b1c
30 changed files with 153 additions and 116 deletions
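In short, call sites move from the per-function client modules to the unified ~/modules/llms/llm.client, and the streaming call gains two explicit function-calling arguments (both null for plain chat). A minimal before/after sketch of one call site, with llmId, messages, abortController and onUpdate as placeholder names:
// before: per-function helper that resolved the vendor internally
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
await llmStreamChatGenerate(llmId, messages, abortController.signal, onUpdate);
// after: unified client, with explicit (null) functions / forceFunctionName arguments
import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';
await llmStreamingChatGenerate(llmId, messages, null, null, abortController.signal, onUpdate);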
+1 -1
@@ -1,2 +1,2 @@
export const runtime = 'edge';
export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llms.streaming';
export { llmStreamingRelayHandler as POST } from '~/modules/llms/server/llm.server.streaming';
+2 -3
@@ -13,10 +13,9 @@ import RecordVoiceOverIcon from '@mui/icons-material/RecordVoiceOver';
import { useChatLLMDropdown } from '../chat/components/applayout/useLLMDropdown';
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
import { EXPERIMENTAL_speakTextStream } from '~/modules/elevenlabs/elevenlabs.client';
import { SystemPurposeId, SystemPurposes } from '../../data';
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';
import { useElevenLabsVoiceDropdown } from '~/modules/elevenlabs/useElevenLabsVoiceDropdown';
import { Link } from '~/common/components/Link';
@@ -216,7 +215,7 @@ export function CallUI(props: {
responseAbortController.current = new AbortController();
let finalText = '';
let error: any | null = null;
llmStreamChatGenerate(chatLLMId, callPrompt, responseAbortController.current.signal, (updatedMessage: Partial<DMessage>) => {
llmStreamingChatGenerate(chatLLMId, callPrompt, null, null, responseAbortController.current.signal, (updatedMessage: Partial<DMessage>) => {
const text = updatedMessage.text?.trim();
if (text) {
finalText = text;
+1 -1
@@ -3,7 +3,7 @@ import * as React from 'react';
import { Chip, ColorPaletteProp, VariantProp } from '@mui/joy';
import { SxProps } from '@mui/joy/styles/types';
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
import type { VChatMessageIn } from '~/modules/llms/llm.client';
export function CallMessage(props: {
+2 -2
@@ -2,7 +2,7 @@ import { DLLMId } from '~/modules/llms/store-llms';
import { SystemPurposeId } from '../../../data';
import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions';
import { autoTitle } from '~/modules/aifn/autotitle/autoTitle';
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';
import { speakText } from '~/modules/elevenlabs/elevenlabs.client';
import { DMessage, useChatStore } from '~/common/state/store-chats';
@@ -63,7 +63,7 @@ async function streamAssistantMessage(
const messages = history.map(({ role, text }) => ({ role, content: text }));
try {
await llmStreamChatGenerate(llmId, messages, abortSignal,
await llmStreamingChatGenerate(llmId, messages, null, null, abortSignal,
(updatedMessage: Partial<DMessage>) => {
// update the message in the store (and thus schedule a re-render)
editMessage(updatedMessage);
+1 -2
@@ -1,8 +1,7 @@
import * as React from 'react';
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
import { DLLMId, useModelsStore } from '~/modules/llms/store-llms';
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';
export interface LLMChainStep {
@@ -1,5 +1,4 @@
import type { VChatFunctionIn } from '~/modules/llms/client/llm.client.types';
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';
import { useChatStore } from '~/common/state/store-chats';
+1 -1
@@ -1,4 +1,4 @@
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';
import { useChatStore } from '~/common/state/store-chats';
+2 -2
@@ -8,7 +8,7 @@ import ReplayIcon from '@mui/icons-material/Replay';
import StopOutlinedIcon from '@mui/icons-material/StopOutlined';
import TelegramIcon from '@mui/icons-material/Telegram';
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
import { llmStreamingChatGenerate } from '~/modules/llms/llm.client';
import { ChatMessage } from '../../../apps/chat/components/message/ChatMessage';
@@ -86,7 +86,7 @@ export function DiagramsModal(props: { config: DiagramConfig, onClose: () => voi
const diagramPrompt = bigDiagramPrompt(diagramType, diagramLanguage, systemMessage.text, subject, customInstruction);
try {
await llmStreamChatGenerate(diagramLlm.id, diagramPrompt, stepAbortController.signal,
await llmStreamingChatGenerate(diagramLlm.id, diagramPrompt, null, null, stepAbortController.signal,
(update: Partial<{ text: string, typing: boolean, originLLM: string }>) => {
assistantMessage = { ...assistantMessage, ...update };
setMessage(assistantMessage);
+1 -2
@@ -1,6 +1,5 @@
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
import type { FormRadioOption } from '~/common/components/forms/FormRadioControl';
import type { VChatMessageIn } from '~/modules/llms/llm.client';
export type DiagramType = 'auto' | 'mind';
@@ -1,4 +1,4 @@
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';
+1 -2
@@ -2,11 +2,10 @@
* porting of implementation from here: https://til.simonwillison.net/llms/python-react-pattern
*/
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
import { DLLMId } from '~/modules/llms/store-llms';
import { callApiSearchGoogle } from '~/modules/google/search.client';
import { callBrowseFetchPage } from '~/modules/browse/browse.client';
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';
// prompt to implement the ReAct paradigm: https://arxiv.org/abs/2210.03629
+1 -1
@@ -1,5 +1,5 @@
import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
import { llmChatGenerateOrThrow } from '~/modules/llms/client/llmChatGenerate';
import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
// prompt to be tried when doing recursive summarization.
+2 -3
@@ -1,8 +1,7 @@
import * as React from 'react';
import type { DLLMId } from '~/modules/llms/store-llms';
import type { VChatMessageIn } from '~/modules/llms/client/llm.client.types';
import { llmStreamChatGenerate } from '~/modules/llms/client/llmStreamChatGenerate';
import { llmStreamingChatGenerate, VChatMessageIn } from '~/modules/llms/llm.client';
export function useStreamChatText() {
@@ -25,7 +24,7 @@ export function useStreamChatText() {
try {
let lastText = '';
await llmStreamChatGenerate(llmId, prompt, abortControllerRef.current.signal, (update) => {
await llmStreamingChatGenerate(llmId, prompt, null, null, abortControllerRef.current.signal, (update) => {
if (update.text) {
lastText = update.text;
setPartialText(lastText);
@@ -1,27 +0,0 @@
import type { OpenAIWire } from '../server/openai/openai.wiretypes';
// Model List types
// export { type ModelDescriptionSchema } from '../server/llm.server.types';
// Chat Generate types
export interface VChatMessageIn {
role: 'assistant' | 'system' | 'user'; // | 'function';
content: string;
//name?: string; // when role: 'function'
}
export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef;
export interface VChatMessageOut {
role: 'assistant' | 'system' | 'user';
content: string;
finish_reason: 'stop' | 'length' | null;
}
export interface VChatMessageOrFunctionCallOut extends VChatMessageOut {
function_name: string;
function_arguments: object | null;
}
@@ -1,23 +0,0 @@
import type { DLLMId } from '../store-llms';
import { findVendorForLlmOrThrow } from '../vendors/vendors.registry';
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from './llm.client.types';
export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId, messages: VChatMessageIn[], functions: VChatFunctionIn[] | null, forceFunctionName: string | null, maxTokens?: number,
): Promise<VChatMessageOut | VChatMessageOrFunctionCallOut> {
// id to DLLM and vendor
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);
// FIXME: relax the forced cast
const options = llm.options as TLLMOptions;
// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup);
// execute via the vendor
return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens);
}
+74
@@ -0,0 +1,74 @@
import type { DLLMId } from './store-llms';
import type { OpenAIWire } from './server/openai/openai.wiretypes';
import { findVendorForLlmOrThrow } from './vendors/vendors.registry';
// LLM Client Types
// NOTE: Model List types in '../server/llm.server.types';
export interface VChatMessageIn {
role: 'assistant' | 'system' | 'user'; // | 'function';
content: string;
//name?: string; // when role: 'function'
}
export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef;
export interface VChatMessageOut {
role: 'assistant' | 'system' | 'user';
content: string;
finish_reason: 'stop' | 'length' | null;
}
export interface VChatMessageOrFunctionCallOut extends VChatMessageOut {
function_name: string;
function_arguments: object | null;
}
// LLM Client Functions
export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId,
messages: VChatMessageIn[],
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
maxTokens?: number,
): Promise<VChatMessageOut | VChatMessageOrFunctionCallOut> {
// id to DLLM and vendor
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);
// FIXME: relax the forced cast
const options = llm.options as TLLMOptions;
// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup);
// execute via the vendor
return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens);
}
export async function llmStreamingChatGenerate<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId,
messages: VChatMessageIn[],
functions: VChatFunctionIn[] | null,
forceFunctionName: string | null,
abortSignal: AbortSignal,
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
): Promise<void> {
// id to DLLM and vendor
const { llm, vendor } = findVendorForLlmOrThrow<TSourceSetup, TAccess, TLLMOptions>(llmId);
// FIXME: relax the forced cast
const llmOptions = llm.options as TLLMOptions;
// get the access
const partialSourceSetup = llm._source.setup;
const access = vendor.getTransportAccess(partialSourceSetup); // as ChatStreamInputSchema['access'];
// execute via the vendor
return await vendor.streamingChatGenerateOrThrow(access, llmId, llmOptions, messages, functions, forceFunctionName, abortSignal, onUpdate);
}
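The non-streaming path accepts the same function-calling inputs; a hedged usage sketch, assuming VChatFunctionIn follows the OpenAI RequestFunctionDef shape (name, description, JSON-schema parameters); the get_current_weather definition below is purely illustrative:
import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client';
// illustrative function definition (assumed OpenAI-style JSON-schema parameters)
const weatherFunction: VChatFunctionIn = {
  name: 'get_current_weather',
  description: 'Look up the current weather for a given city',
  parameters: {
    type: 'object',
    properties: { city: { type: 'string' } },
    required: ['city'],
  },
};
const result = await llmChatGenerateOrThrow(llmId, messages, [weatherFunction], 'get_current_weather');
if ('function_name' in result)
  console.log(result.function_name, result.function_arguments); // VChatMessageOrFunctionCallOut
else
  console.log(result.content); // plain VChatMessageOut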
@@ -9,6 +9,9 @@ import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, S
import type { AnthropicWire } from './anthropic/anthropic.wiretypes';
import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from './anthropic/anthropic.router';
// Gemini server imports
import { geminiAccessSchema } from './gemini/gemini.router';
// Ollama server imports
import { wireOllamaChunkedOutputSchema } from './ollama/ollama.wiretypes';
import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from './ollama/ollama.router';
@@ -37,24 +40,24 @@ type EventStreamFormat = 'sse' | 'json-nl';
type AIStreamParser = (data: string) => { text: string, close: boolean };
const chatStreamInputSchema = z.object({
access: z.union([anthropicAccessSchema, ollamaAccessSchema, openAIAccessSchema]),
const chatStreamingInputSchema = z.object({
access: z.union([anthropicAccessSchema, geminiAccessSchema, ollamaAccessSchema, openAIAccessSchema]),
model: openAIModelSchema,
history: openAIHistorySchema,
});
export type ChatStreamInputSchema = z.infer<typeof chatStreamInputSchema>;
export type ChatStreamingInputSchema = z.infer<typeof chatStreamingInputSchema>;
const chatStreamFirstOutputPacketSchema = z.object({
const chatStreamingFirstOutputPacketSchema = z.object({
model: z.string(),
});
export type ChatStreamFirstOutputPacketSchema = z.infer<typeof chatStreamFirstOutputPacketSchema>;
export type ChatStreamingFirstOutputPacketSchema = z.infer<typeof chatStreamingFirstOutputPacketSchema>;
export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
// inputs - reuse the tRPC schema
const body = await req.json();
const { access, model, history } = chatStreamInputSchema.parse(body);
const { access, model, history } = chatStreamingInputSchema.parse(body);
// access/dialect dependent setup:
// - requestAccess: the headers and URL to use for the upstream API call
@@ -240,7 +243,7 @@ function createAnthropicStreamParser(): AIStreamParser {
// hack: prepend the model name to the first packet
if (!hasBegun) {
hasBegun = true;
const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model };
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model };
text = JSON.stringify(firstPacket) + text;
}
@@ -276,7 +279,7 @@ function createOllamaChatCompletionStreamParser(): AIStreamParser {
// hack: prepend the model name to the first packet
if (!hasBegun && chunk.model) {
hasBegun = true;
const firstPacket: ChatStreamFirstOutputPacketSchema = { model: chunk.model };
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: chunk.model };
text = JSON.stringify(firstPacket) + text;
}
@@ -317,7 +320,7 @@ function createOpenAIStreamParser(): AIStreamParser {
// hack: prepend the model name to the first packet
if (!hasBegun) {
hasBegun = true;
const firstPacket: ChatStreamFirstOutputPacketSchema = { model: json.model };
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model };
text = JSON.stringify(firstPacket) + text;
}
+12 -2
@@ -1,10 +1,10 @@
import type React from 'react';
import type { TRPCClientErrorBase } from '@trpc/client';
import type { DLLM, DModelSourceId } from '../store-llms';
import type { DLLM, DLLMId, DModelSourceId } from '../store-llms';
import type { ModelDescriptionSchema } from '../server/llm.server.types';
import type { ModelVendorId } from './vendors.registry';
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../client/llm.client.types';
import type { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '~/modules/llms/llm.client';
export interface IModelVendor<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown, TDLLM = DLLM<TSourceSetup, TLLMOptions>> {
@@ -43,4 +43,14 @@ export interface IModelVendor<TSourceSetup = unknown, TAccess = unknown, TLLMOpt
maxTokens?: number,
) => Promise<VChatMessageOut | VChatMessageOrFunctionCallOut>;
streamingChatGenerateOrThrow: (
access: TAccess,
llmId: DLLMId,
llmOptions: TLLMOptions,
messages: VChatMessageIn[],
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
abortSignal: AbortSignal,
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
) => Promise<void>;
}
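Concretely, the vendor files below satisfy this new member in one of two ways: vendors with their own protocol point at the shared unifiedStreamingClient, while OpenAI-dialect vendors reuse the OpenAI vendor's implementation. A condensed sketch of the two patterns:
// native-protocol vendors (Anthropic, Gemini, Ollama, OpenAI)
streamingChatGenerateOrThrow: unifiedStreamingClient,
// OpenAI-dialect vendors (Azure, LocalAI, Mistral, Oobabooga, OpenRouter)
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,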
+5 -1
@@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
import type { AnthropicAccessSchema } from '../../server/anthropic/anthropic.router';
import type { IModelVendor } from '../IModelVendor';
import type { VChatMessageOut } from '../../client/llm.client.types';
import type { VChatMessageOut } from '../../llm.client';
import { unifiedStreamingClient } from '../unifiedStreamingClient';
import { LLMOptionsOpenAI } from '../openai/openai.vendor';
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';
@@ -77,4 +78,7 @@ export const ModelVendorAnthropic: IModelVendor<SourceSetupAnthropic, AnthropicA
}
},
// Chat Generate (streaming) with Functions
streamingChatGenerateOrThrow: unifiedStreamingClient,
};
+1
@@ -61,4 +61,5 @@ export const ModelVendorAzure: IModelVendor<SourceSetupAzure, OpenAIAccessSchema
// OpenAI transport ('azure' dialect in 'access')
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
};
@@ -6,7 +6,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
import type { GeminiAccessSchema } from '../../server/gemini/gemini.router';
import type { IModelVendor } from '../IModelVendor';
import type { VChatMessageOut } from '../../client/llm.client.types';
import type { VChatMessageOut } from '../../llm.client';
import { unifiedStreamingClient } from '../unifiedStreamingClient';
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';
@@ -86,4 +87,7 @@ export const ModelVendorGemini: IModelVendor<SourceSetupGemini, GeminiAccessSche
}
},
// Chat Generate (streaming) with Functions
streamingChatGenerateOrThrow: unifiedStreamingClient,
};
+1
@@ -41,4 +41,5 @@ export const ModelVendorLocalAI: IModelVendor<SourceSetupLocalAI, OpenAIAccessSc
// OpenAI transport ('localai' dialect in 'access')
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
};
+1
@@ -51,4 +51,5 @@ export const ModelVendorMistral: IModelVendor<SourceSetupMistral, OpenAIAccessSc
// OpenAI transport ('mistral' dialect in 'access')
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
};
+5 -1
@@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
import type { IModelVendor } from '../IModelVendor';
import type { OllamaAccessSchema } from '../../server/ollama/ollama.router';
import type { VChatMessageOut } from '../../client/llm.client.types';
import type { VChatMessageOut } from '../../llm.client';
import { unifiedStreamingClient } from '../unifiedStreamingClient';
import type { LLMOptionsOpenAI } from '../openai/openai.vendor';
import { OpenAILLMOptions } from '../openai/OpenAILLMOptions';
@@ -70,4 +71,7 @@ export const ModelVendorOllama: IModelVendor<SourceSetupOllama, OllamaAccessSche
}
},
// Chat Generate (streaming) with Functions
streamingChatGenerateOrThrow: unifiedStreamingClient,
};
@@ -41,4 +41,5 @@ export const ModelVendorOoobabooga: IModelVendor<SourceSetupOobabooga, OpenAIAcc
// OpenAI transport (oobabooga dialect in 'access')
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
};
+5 -1
@@ -5,7 +5,8 @@ import { apiAsync, apiQuery } from '~/common/util/trpc.client';
import type { IModelVendor } from '../IModelVendor';
import type { OpenAIAccessSchema } from '../../server/openai/openai.router';
import type { VChatMessageOrFunctionCallOut } from '../../client/llm.client.types';
import type { VChatMessageOrFunctionCallOut } from '../../llm.client';
import { unifiedStreamingClient } from '../unifiedStreamingClient';
import { OpenAILLMOptions } from './OpenAILLMOptions';
import { OpenAISourceSetup } from './OpenAISourceSetup';
@@ -84,4 +85,7 @@ export const ModelVendorOpenAI: IModelVendor<SourceSetupOpenAI, OpenAIAccessSche
}
},
// Chat Generate (streaming) with Functions
streamingChatGenerateOrThrow: unifiedStreamingClient,
};
@@ -62,4 +62,5 @@ export const ModelVendorOpenRouter: IModelVendor<SourceSetupOpenRouter, OpenAIAc
// OpenAI transport ('openrouter' dialect in 'access')
rpcUpdateModelsQuery: ModelVendorOpenAI.rpcUpdateModelsQuery,
rpcChatGenerateOrThrow: ModelVendorOpenAI.rpcChatGenerateOrThrow,
streamingChatGenerateOrThrow: ModelVendorOpenAI.streamingChatGenerateOrThrow,
};
@@ -1,13 +1,11 @@
import { apiAsync } from '~/common/util/trpc.client';
import type { ChatStreamFirstOutputPacketSchema, ChatStreamInputSchema } from '../server/llms.streaming';
import type { DLLM, DLLMId } from '../store-llms';
import { findVendorForLlmOrThrow } from '../vendors/vendors.registry';
import type { ChatStreamingFirstOutputPacketSchema, ChatStreamingInputSchema } from '../server/llm.server.streaming';
import type { DLLMId } from '../store-llms';
import type { VChatFunctionIn, VChatMessageIn } from '../llm.client';
import type { OpenAIWire } from '../server/openai/openai.wiretypes';
import type { VChatMessageIn } from './llm.client.types';
/**
* Client side chat generation, with streaming. This decodes the (text) streaming response from
@@ -16,27 +14,14 @@ import type { VChatMessageIn } from './llm.client.types';
* Vendor-specific implementation is on our server backend (API) code. This function tries to be
* as generic as possible.
*
* @param llmId LLM to use
* @param messages the history of messages to send to the API endpoint
* @param abortSignal used to initiate a client-side abort of the fetch request to the API endpoint
* @param onUpdate callback when a piece of a message (text, model name, typing..) is received
* NOTE: onUpdate is the callback invoked when a piece of a message (text, model name, typing..) is received
*/
export async function llmStreamChatGenerate(
export async function unifiedStreamingClient<TSourceSetup = unknown, TLLMOptions = unknown>(
access: ChatStreamingInputSchema['access'],
llmId: DLLMId,
llmOptions: TLLMOptions,
messages: VChatMessageIn[],
abortSignal: AbortSignal,
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
): Promise<void> {
const { llm, vendor } = findVendorForLlmOrThrow(llmId);
const access = vendor.getTransportAccess(llm._source.setup) as ChatStreamInputSchema['access'];
return await vendorStreamChat(access, llm, messages, abortSignal, onUpdate);
}
async function vendorStreamChat<TSourceSetup = unknown, TLLMOptions = unknown>(
access: ChatStreamInputSchema['access'],
llm: DLLM<TSourceSetup, TLLMOptions>,
messages: VChatMessageIn[],
functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
abortSignal: AbortSignal,
onUpdate: (update: Partial<{ text: string, typing: boolean, originLLM: string }>, done: boolean) => void,
) {
@@ -80,12 +65,12 @@ async function vendorStreamChat<TSourceSetup = unknown, TLLMOptions = unknown>(
}
// model params (llm)
const { llmRef, llmTemperature, llmResponseTokens } = (llm.options as any) || {};
const { llmRef, llmTemperature, llmResponseTokens } = (llmOptions as any) || {};
if (!llmRef || llmTemperature === undefined || llmResponseTokens === undefined)
throw new Error(`Error in configuration for model ${llm.id}: ${JSON.stringify(llm.options)}`);
throw new Error(`Error in configuration for model ${llmId}: ${JSON.stringify(llmOptions)}`);
// prepare the input, similarly to the tRPC openAI.chatGenerate
const input: ChatStreamInputSchema = {
const input: ChatStreamingInputSchema = {
access,
model: {
id: llmRef,
@@ -132,7 +117,7 @@ async function vendorStreamChat<TSourceSetup = unknown, TLLMOptions = unknown>(
incrementalText = incrementalText.substring(endOfJson + 1);
parsedFirstPacket = true;
try {
const parsed: ChatStreamFirstOutputPacketSchema = JSON.parse(json);
const parsed: ChatStreamingFirstOutputPacketSchema = JSON.parse(json);
onUpdate({ originLLM: parsed.model }, false);
} catch (e) {
// error parsing JSON, ignore
+1 -1
@@ -1,6 +1,6 @@
import { ModelVendorAnthropic } from './anthropic/anthropic.vendor';
import { ModelVendorAzure } from './azure/azure.vendor';
import { ModelVendorGemini } from './googleai/gemini.vendor';
import { ModelVendorGemini } from './gemini/gemini.vendor';
import { ModelVendorLocalAI } from './localai/localai.vendor';
import { ModelVendorMistral } from './mistral/mistral.vendor';
import { ModelVendorOllama } from './ollama/ollama.vendor';