diff --git a/docs/environment-variables.md b/docs/environment-variables.md index d6d021c16..37081be5c 100644 --- a/docs/environment-variables.md +++ b/docs/environment-variables.md @@ -28,6 +28,7 @@ AZURE_OPENAI_API_ENDPOINT= AZURE_OPENAI_API_KEY= ANTHROPIC_API_KEY= ANTHROPIC_API_HOST= +OLLAMA_API_HOST= OPENROUTER_API_KEY= # Model Observability: Helicone @@ -73,6 +74,7 @@ requiring the user to enter an API key | `AZURE_OPENAI_API_KEY` | Azure OpenAI API key, see [config-azure-openai.md](config-azure-openai.md) | Optional, but if set `AZURE_OPENAI_API_ENDPOINT` must also be set | | `ANTHROPIC_API_KEY` | The API key for Anthropic | Optional | | `ANTHROPIC_API_HOST` | Changes the backend host for the Anthropic vendor, to enable platforms such as [config-aws-bedrock.md](config-aws-bedrock.md) | Optional | +| `OLLAMA_API_HOST` | Changes the backend host for the Ollama vendor. See [config-ollama.md](config-ollama.md) | | | `OPENROUTER_API_KEY` | The API key for OpenRouter | Optional | ### Model Observability: Helicone diff --git a/next.config.js b/next.config.js index 5f19faedf..dcc711b0d 100644 --- a/next.config.js +++ b/next.config.js @@ -7,6 +7,7 @@ let nextConfig = { HAS_SERVER_KEY_ANTHROPIC: !!process.env.ANTHROPIC_API_KEY, HAS_SERVER_KEY_AZURE_OPENAI: !!process.env.AZURE_OPENAI_API_KEY && !!process.env.AZURE_OPENAI_API_ENDPOINT, HAS_SERVER_KEY_ELEVENLABS: !!process.env.ELEVENLABS_API_KEY, + HAS_SERVER_HOST_OLLAMA: !!process.env.OLLAMA_API_HOST, HAS_SERVER_KEY_OPENAI: !!process.env.OPENAI_API_KEY, HAS_SERVER_KEY_OPENROUTER: !!process.env.OPENROUTER_API_KEY, HAS_SERVER_KEY_PRODIA: !!process.env.PRODIA_API_KEY, diff --git a/pages/api/elevenlabs/speech.ts b/pages/api/elevenlabs/speech.ts index a460009f4..065936b24 100644 --- a/pages/api/elevenlabs/speech.ts +++ b/pages/api/elevenlabs/speech.ts @@ -1,8 +1,7 @@ import { NextRequest, NextResponse } from 'next/server'; -import { safeErrorString, serverFetchOrThrow } from '~/server/wire'; +import { createEmptyReadableStream, safeErrorString, serverFetchOrThrow } from '~/server/wire'; -import { createEmptyReadableStream } from '~/modules/llms/transports/server/openai/openai.streaming'; import { elevenlabsAccess, elevenlabsVoiceId, ElevenlabsWire, speechInputSchema } from '~/modules/elevenlabs/elevenlabs.router'; diff --git a/pages/api/llms/stream.ts b/pages/api/llms/stream.ts index 88883ee09..f8bdc6088 100644 --- a/pages/api/llms/stream.ts +++ b/pages/api/llms/stream.ts @@ -1,4 +1,4 @@ -export { openaiStreamingResponse as default } from '~/modules/llms/transports/server/openai/openai.streaming'; +export { openaiStreamingRelayHandler as default } from '~/modules/llms/transports/server/openai/openai.streaming'; // noinspection JSUnusedGlobalSymbols export const runtime = 'edge'; \ No newline at end of file diff --git a/src/common/types/env.d.ts b/src/common/types/env.d.ts index d3a8f4bc7..f4bf9bfb0 100644 --- a/src/common/types/env.d.ts +++ b/src/common/types/env.d.ts @@ -24,6 +24,9 @@ declare namespace NodeJS { ANTHROPIC_API_KEY?: string; ANTHROPIC_API_HOST?: string; + // LLM: Ollama + OLLAMA_API_HOST?: string; + // LLM: OpenRouter OPENROUTER_API_KEY: string; @@ -52,6 +55,7 @@ declare namespace NodeJS { HAS_SERVER_KEY_ANTHROPIC?: boolean; HAS_SERVER_KEY_AZURE_OPENAI?: boolean; HAS_SERVER_KEY_ELEVENLABS: boolean; + HAS_SERVER_HOST_OLLAMA?: boolean; HAS_SERVER_KEY_OPENAI?: boolean; HAS_SERVER_KEY_OPENROUTER?: boolean; HAS_SERVER_KEY_PRODIA: boolean; diff --git a/src/common/util/modelUtils.ts b/src/common/util/modelUtils.ts index 43459aa73..fd36dba02 
100644 --- a/src/common/util/modelUtils.ts +++ b/src/common/util/modelUtils.ts @@ -9,5 +9,5 @@ export function prettyBaseModel(model: string | undefined): string { if (model.includes('gpt-3.5-turbo-16k')) return '3.5 Turbo 16k'; if (model.includes('gpt-3.5-turbo')) return '3.5 Turbo'; if (model.endsWith('.bin')) return model.slice(0, -4); - return model; + return model.replaceAll(':', ' '); } \ No newline at end of file diff --git a/src/modules/llms/transports/server/ollama/ollama.router.ts b/src/modules/llms/transports/server/ollama/ollama.router.ts new file mode 100644 index 000000000..7f94c51d5 --- /dev/null +++ b/src/modules/llms/transports/server/ollama/ollama.router.ts @@ -0,0 +1,269 @@ +import { z } from 'zod'; + +import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server'; +import { fetchJsonOrTRPCError, fetchTextOrTRPCError } from '~/server/api/trpc.serverutils'; + +import { LLM_IF_OAI_Chat } from '../../../store-llms'; + +import { capitalizeFirstLetter } from '~/common/util/textUtils'; + +import { fixupHost, openAIChatGenerateOutputSchema, openAIHistorySchema, openAIModelSchema } from '../openai/openai.router'; +import { listModelsOutputSchema, ModelDescriptionSchema } from '../server.schemas'; + +import { wireOllamaGenerationSchema } from './ollama.wiretypes'; + + +/** + * This is here because the API does not provide a list of available upstream models, and does not provide + * descriptions for the models. + * (nor does it reliably provide context window sizes) - TODO: open a bug upstream + * + * from: https://ollama.ai/library?sort=popular + */ +const OLLAMA_BASE_MODELS: { [key: string]: string } = { + 'mistral': 'The Mistral 7B model released by Mistral AI', + 'llama2': 'The most popular model for general use.', + 'codellama': 'A large language model that can use text prompts to generate and discuss code.', + 'vicuna': 'General use chat model based on Llama and Llama 2 with 2K to 16K context sizes.', + 'llama2-uncensored': 'Uncensored Llama 2 model by George Sung and Jarrad Hope.', + 'orca-mini': 'A general-purpose model ranging from 3 billion parameters to 70 billion, suitable for entry-level hardware.', + 'wizard-vicuna-uncensored': 'Wizard Vicuna Uncensored is a 7B, 13B, and 30B parameter model based on Llama 2 uncensored by Eric Hartford.', + 'nous-hermes': 'General use models based on Llama and Llama 2 from Nous Research.', + 'phind-codellama': 'Code generation model based on CodeLlama.', + 'mistral-openorca': 'Mistral OpenOrca is a 7 billion parameter model, fine-tuned on top of the Mistral 7B model using the OpenOrca dataset.', + 'wizardcoder': 'Llama based code generation model focused on Python.', + 'wizard-math': 'Model focused on math and logic problems', + 'llama2-chinese': 'Llama 2 based model fine tuned to improve Chinese dialogue ability.', + 'stable-beluga': 'Llama 2 based model fine tuned on an Orca-style dataset. 
Originally called Free Willy.', + 'zephyr': 'Zephyr beta is a fine-tuned 7B version of mistral that was trained on a mix of publicly available, synthetic datasets.', + 'codeup': 'Great code generation model based on Llama2.', + 'falcon': 'A large language model built by the Technology Innovation Institute (TII) for use in summarization, text generation, and chat bots.', + 'everythinglm': 'Uncensored Llama2 based model with 16k context size.', + 'wizardlm-uncensored': 'Uncensored version of Wizard LM model', + 'medllama2': 'Fine-tuned Llama 2 model to answer medical questions based on an open source medical dataset.', + 'wizard-vicuna': 'Wizard Vicuna is a 13B parameter model based on Llama 2 trained by MelodysDreamj.', + 'open-orca-platypus2': 'Merge of the Open Orca OpenChat model and the Garage-bAInd Platypus 2 model. Designed for chat and code generation.', + 'starcoder': 'StarCoder is a code generation model trained on 80+ programming languages.', + 'samantha-mistral': 'A companion assistant trained in philosophy, psychology, and personal relationships. Based on Mistral.', + 'openhermes2-mistral': 'OpenHermes 2 Mistral is a 7B model fine-tuned on Mistral with 900,000 entries of primarily GPT-4 generated data from open datasets.', + 'wizardlm': 'General use 70 billion parameter model based on Llama 2.', + 'sqlcoder': 'SQLCoder is a code completion model fine-tuned on StarCoder for SQL generation tasks', + 'dolphin2.2-mistral': 'An instruct-tuned model based on Mistral. Version 2.2 is fine-tuned for improved conversation and empathy.', + 'dolphin2.1-mistral': 'An instruct-tuned model based on Mistral and trained on a dataset filtered to remove alignment and bias.', + 'yarn-mistral': 'An extension of Mistral to support a context of up to 128k tokens.', + 'codebooga': 'A high-performing code instruct model created by merging two existing code models.', + 'openhermes2.5-mistral': 'OpenHermes 2.5 Mistral 7B is a Mistral 7B fine-tune, a continuation of OpenHermes 2 model, which trained on additional code datasets.', + 'mistrallite': 'MistralLite is a fine-tuned model based on Mistral with enhanced capabilities of processing long contexts.', + 'nexusraven': 'Nexus Raven is a 13B instruction tuned model for function calling tasks.', + 'yarn-llama2': 'An extension of Llama 2 that supports a context of up to 128k tokens.', + 'xwinlm': 'Conversational model based on Llama 2 that performs competitively on various benchmarks.', +}; + +// Input Schemas + +export const ollamaAccessSchema = z.object({ + dialect: z.enum(['ollama']), + ollamaHost: z.string().trim(), +}); +export type OllamaAccessSchema = z.infer<typeof ollamaAccessSchema>; + +const accessOnlySchema = z.object({ + access: ollamaAccessSchema, +}); + +const adminPullModelSchema = z.object({ + access: ollamaAccessSchema, + name: z.string(), +}); + +const chatGenerateInputSchema = z.object({ + access: ollamaAccessSchema, + model: openAIModelSchema, history: openAIHistorySchema, + // functions: openAIFunctionsSchema.optional(), forceFunctionName: z.string().optional(), +}); + + +// Output Schemas + +const listPullableOutputSchema = z.object({ + pullable: z.array(z.object({ + id: z.string(), + label: z.string(), + tag: z.string(), + description: z.string(), + })), +}); + + +export const llmOllamaRouter = createTRPCRouter({ + + /* Ollama: models that can be pulled */ + adminListPullable: publicProcedure + .input(accessOnlySchema) + .output(listPullableOutputSchema) + .query(async ({}) => { + return { + pullable: Object.entries(OLLAMA_BASE_MODELS).map(([model,
description]) => ({ + id: model, + label: capitalizeFirstLetter(model), + tag: 'latest', + description, + })), + }; + }), + + /* Ollama: pull a model */ + adminPull: publicProcedure + .input(adminPullModelSchema) + .mutation(async ({ input }) => { + + // fetch as a large text buffer, made of JSONs separated by newlines + const { headers, url } = ollamaAccess(input.access, '/api/pull'); + const pullRequest = await fetchTextOrTRPCError(url, 'POST', headers, { 'name': input.name }, 'Ollama::pull'); + + // accumulate status and error messages + let lastStatus: string = 'unknown'; + let lastError: string | undefined = undefined; + for (let string of pullRequest.trim().split('\n')) { + const message = JSON.parse(string); + if (message.status) + lastStatus = input.name + ': ' + message.status; + if (message.error) + lastError = message.error; + } + + return { status: lastStatus, error: lastError }; + }), + + /* Ollama: List the Models available */ + listModels: publicProcedure + .input(accessOnlySchema) + .output(listModelsOutputSchema) + .query(async ({ input }) => { + + // get the models + const wireModels = await ollamaGET(input.access, '/api/tags'); + const wireOllamaListModelsSchema = z.object({ + models: z.array(z.object({ + name: z.string(), + modified_at: z.string(), + size: z.number(), + digest: z.string(), + })), + }); + let models = wireOllamaListModelsSchema.parse(wireModels).models; + + // retrieve info for each of the models (/api/show, post call, in parallel) + const detailedModels = await Promise.all(models.map(async model => { + const wireModelInfo = await ollamaPOST(input.access, { 'name': model.name }, '/api/show'); + const wireOllamaModelInfoSchema = z.object({ + license: z.string().optional(), + modelfile: z.string(), + parameters: z.string(), + template: z.string(), + }); + const modelInfo = wireOllamaModelInfoSchema.parse(wireModelInfo); + return { ...model, ...modelInfo }; + })); + + return { + models: detailedModels.map(model => { + // the model name is in the format "name:tag" (default tag = 'latest') + const [modelName, modelTag] = model.name.split(':'); + + // pretty label and description + const label = capitalizeFirstLetter(modelName) + ((modelTag && modelTag !== 'latest') ? ` · ${modelTag}` : ''); + const description = OLLAMA_BASE_MODELS[modelName] ?? 'Model unknown'; + + // console.log('>>> ollama model', model.name, model.template, model.modelfile, '\n'); + + return { + id: model.name, + label, + created: Date.parse(model.modified_at) ?? undefined, + updated: Date.parse(model.modified_at) ?? undefined, + description: description, // description: (model.license ? `License: ${model.license}. Info: ` : '') + model.modelfile || 'Model unknown', + contextWindow: 4096, // FIXME: request this information upstream? + interfaces: [LLM_IF_OAI_Chat], + } satisfies ModelDescriptionSchema; + }), + }; + }), + + /* Ollama: Chat generation */ + chatGenerate: publicProcedure + .input(chatGenerateInputSchema) + .output(openAIChatGenerateOutputSchema) + .mutation(async ({ input: { access, history, model } }) => { + + const wireGeneration = await ollamaPOST(access, ollamaChatCompletionPayload(model, history, false), '/api/generate'); + const generation = wireOllamaGenerationSchema.parse(wireGeneration); + + return { + role: 'assistant', + content: generation.response, + finish_reason: generation.done ? 
'stop' : null, + }; + }), + +}); + + +type ModelSchema = z.infer<typeof openAIModelSchema>; +type HistorySchema = z.infer<typeof openAIHistorySchema>; + +async function ollamaGET<TOut>(access: OllamaAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> { + const { headers, url } = ollamaAccess(access, apiPath); + return await fetchJsonOrTRPCError(url, 'GET', headers, undefined, 'Ollama'); +} + +async function ollamaPOST<TOut, TPostBody>(access: OllamaAccessSchema, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> { + const { headers, url } = ollamaAccess(access, apiPath); + return await fetchJsonOrTRPCError(url, 'POST', headers, body, 'Ollama'); +} + + +const DEFAULT_OLLAMA_HOST = 'http://127.0.0.1:11434'; + +export function ollamaAccess(access: OllamaAccessSchema, apiPath: string): { headers: HeadersInit, url: string } { + + const ollamaHost = fixupHost(access.ollamaHost || process.env.OLLAMA_API_HOST || DEFAULT_OLLAMA_HOST, apiPath); + + return { + headers: { + 'Content-Type': 'application/json', + }, + url: ollamaHost + apiPath, + }; + +} + +export function ollamaChatCompletionPayload(model: ModelSchema, history: HistorySchema, stream: boolean) { + + // if the first message is the system prompt, extract it + let systemPrompt: string | undefined = undefined; + if (history.length && history[0].role === 'system') { + const [firstMessage, ...rest] = history; + systemPrompt = firstMessage.content; + history = rest; + } + + // encode the prompt for ollama, assuming the same template for everyone for now + const prompt = history.map(({ role, content }) => { + return role === 'assistant' ? `\n\nAssistant: ${content}` : `\n\nHuman: ${content}`; + }).join('') + '\n\nAssistant:\n'; + + // const prompt = history.map(({ role, content }) => { + // return role === 'assistant' ? `### Response:\n${content}\n\n` : `### User:\n${content}\n\n`; + // }).join('') + '### Response:\n'; + + return { + model: model.id, + prompt, + options: { + ...(model.temperature && { temperature: model.temperature }), + }, + ...(systemPrompt && { system: systemPrompt }), + stream, + }; +} diff --git a/src/modules/llms/transports/server/ollama/ollama.wiretypes.ts b/src/modules/llms/transports/server/ollama/ollama.wiretypes.ts new file mode 100644 index 000000000..c9936e3bc --- /dev/null +++ b/src/modules/llms/transports/server/ollama/ollama.wiretypes.ts @@ -0,0 +1,16 @@ +import { z } from 'zod'; + +export const wireOllamaGenerationSchema = z.object({ + model: z.string(), + // created_at: z.string(), // commented because unused + response: z.string(), + done: z.boolean(), + + // only on the last message + // context: z.array(z.number()), + // total_duration: z.number(), + // load_duration: z.number(), + // eval_duration: z.number(), + // prompt_eval_count: z.number(), + // eval_count: z.number(), +}); diff --git a/src/modules/llms/transports/server/openai/openai.streaming.ts b/src/modules/llms/transports/server/openai/openai.streaming.ts index 805b94e61..e704d3f5a 100644 --- a/src/modules/llms/transports/server/openai/openai.streaming.ts +++ b/src/modules/llms/transports/server/openai/openai.streaming.ts @@ -1,160 +1,32 @@ import { z } from 'zod'; import { NextRequest, NextResponse } from 'next/server'; -import { createParser as createEventsourceParser, EventSourceParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; +import { createParser as createEventsourceParser, EventSourceParseCallback, EventSourceParser, ParsedEvent, ReconnectInterval } from 'eventsource-parser'; -import { debugGenerateCurlCommand, safeErrorString, SERVER_DEBUG_WIRE,
serverFetchOrThrow } from '~/server/wire'; +import { createEmptyReadableStream, debugGenerateCurlCommand, safeErrorString, SERVER_DEBUG_WIRE, serverFetchOrThrow } from '~/server/wire'; import type { AnthropicWire } from '../anthropic/anthropic.wiretypes'; import type { OpenAIWire } from './openai.wiretypes'; import { anthropicAccess, anthropicAccessSchema, anthropicChatCompletionPayload } from '../anthropic/anthropic.router'; +import { ollamaAccess, ollamaAccessSchema, ollamaChatCompletionPayload } from '../ollama/ollama.router'; import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai.router'; +import { wireOllamaGenerationSchema } from '../ollama/ollama.wiretypes'; /** * Vendor stream parsers * - The vendor can decide to terminate the connection (close: true), transmitting anything in 'text' before doing so * - The vendor can also throw from this function, which will error and terminate the connection + * + * The peculiarity of our parser is the injection of a JSON structure at the beginning of the stream, to + * communicate parameters before the text starts flowing to the client. */ type AIStreamParser = (data: string) => { text: string, close: boolean }; - -// The peculiarity of our parser is the injection of a JSON structure at the beginning of the stream, to -// communicate parameters before the text starts flowing to the client. -function parseOpenAIStream(): AIStreamParser { - let hasBegun = false; - let hasWarned = false; - - return data => { - - const json: OpenAIWire.ChatCompletion.ResponseStreamingChunk = JSON.parse(data); - - // [OpenAI] an upstream error will be handled gracefully and transmitted as text (throw to transmit as 'error') - if (json.error) - return { text: `[OpenAI Issue] ${safeErrorString(json.error)}`, close: true }; - - if (json.choices.length !== 1) { - // [Azure] we seem to 'prompt_annotations' or 'prompt_filter_results' objects - which we will ignore to suppress the error - if (json.id === '' && json.object === '' && json.model === '') - return { text: '', close: false }; - throw new Error(`Expected 1 completion, got ${json.choices.length}`); - } - - const index = json.choices[0].index; - if (index !== 0 && index !== undefined /* LocalAI hack/workaround until https://github.com/go-skynet/LocalAI/issues/788 */) - throw new Error(`Expected completion index 0, got ${index}`); - let text = json.choices[0].delta?.content /*|| json.choices[0]?.text*/ || ''; - - // hack: prepend the model name to the first packet - if (!hasBegun) { - hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { - model: json.model, - }; - text = JSON.stringify(firstPacket) + text; - } - - // if there's a warning, log it once - if (json.warning && !hasWarned) { - hasWarned = true; - console.log('/api/llms/stream: OpenAI stream warning:', json.warning); - } - - // workaround: LocalAI doesn't send the [DONE] event, but similarly to OpenAI, it sends a "finish_reason" delta update - const close = !!json.choices[0].finish_reason; - return { text, close }; - }; -} - - -// Anthropic event stream parser -function parseAnthropicStream(): AIStreamParser { - let hasBegun = false; - - return data => { - - const json: AnthropicWire.Complete.Response = JSON.parse(data); - let text = json.completion; - - // hack: prepend the model name to the first packet - if (!hasBegun) { - hasBegun = true; - const firstPacket: ChatStreamFirstPacketSchema = { - model: json.model, - }; - text = JSON.stringify(firstPacket) + text; - } 
- - return { text, close: false }; - }; -} - - -/** - * Creates a TransformStream that parses events from an EventSource stream using a custom parser. - * @returns {TransformStream} TransformStream parsing events. - */ -function createEventStreamTransformer(vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream { - const textDecoder = new TextDecoder(); - const textEncoder = new TextEncoder(); - let eventSourceParser: EventSourceParser; - - return new TransformStream({ - start: async (controller): Promise => { - - // only used for debugging - let debugLastMs: number | null = null; - - eventSourceParser = createEventsourceParser( - (event: ParsedEvent | ReconnectInterval) => { - - if (SERVER_DEBUG_WIRE) { - const nowMs = Date.now(); - const elapsedMs = debugLastMs ? nowMs - debugLastMs : 0; - debugLastMs = nowMs; - console.log(`<- SSE (${elapsedMs} ms):`, event); - } - - // ignore 'reconnect-interval' and events with no data - if (event.type !== 'event' || !('data' in event)) - return; - - // event stream termination, close our transformed stream - if (event.data === '[DONE]') { - controller.terminate(); - return; - } - - try { - const { text, close } = vendorTextParser(event.data); - if (text) - controller.enqueue(textEncoder.encode(text)); - if (close) - controller.terminate(); - } catch (error: any) { - // console.log(`/api/llms/stream: parse issue: ${error?.message || error}`); - controller.enqueue(textEncoder.encode(`[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}`)); - controller.terminate(); - } - }, - ); - }, - - // stream=true is set because the data is not guaranteed to be final and un-chunked - transform: (chunk: Uint8Array) => { - eventSourceParser.feed(textDecoder.decode(chunk, { stream: true })); - }, - }); -} - -export function createEmptyReadableStream(): ReadableStream { - return new ReadableStream({ - start: (controller) => controller.close(), - }); -} +type EventStreamFormat = 'sse' | 'json-nl'; const chatStreamInputSchema = z.object({ - access: z.union([openAIAccessSchema, anthropicAccessSchema]), + access: z.union([anthropicAccessSchema, ollamaAccessSchema, openAIAccessSchema]), model: openAIModelSchema, history: openAIHistorySchema, }); export type ChatStreamInputSchema = z.infer; @@ -164,7 +36,8 @@ const chatStreamFirstPacketSchema = z.object({ }); export type ChatStreamFirstPacketSchema = z.infer; -export async function openaiStreamingResponse(req: NextRequest): Promise { + +export async function openaiStreamingRelayHandler(req: NextRequest): Promise { // inputs - reuse the tRPC schema const { access, model, history } = chatStreamInputSchema.parse(await req.json()); @@ -173,6 +46,7 @@ export async function openaiStreamingResponse(req: NextRequest): Promise streaming curl', debugGenerateCurlCommand('POST', headersUrl.url, headersUrl.headers, body)); + console.log('-> streaming:', debugGenerateCurlCommand('POST', headersUrl.url, headersUrl.headers, body)); // POST to our API route upstreamResponse = await serverFetchOrThrow(headersUrl.url, 'POST', headersUrl.headers, body); @@ -205,7 +86,7 @@ export async function openaiStreamingResponse(req: NextRequest): Promise { + + const json: AnthropicWire.Complete.Response = JSON.parse(data); + let text = json.completion; + + // hack: prepend the model name to the first packet + if (!hasBegun) { + hasBegun = true; + const firstPacket: ChatStreamFirstPacketSchema = { model: json.model }; + text = JSON.stringify(firstPacket) + text; + } + + return { text, close: false 
}; + }; +} + +function createOllamaStreamParser(): AIStreamParser { + let hasBegun = false; + + return (data: string) => { + + const wireGeneration = JSON.parse(data); + const generation = wireOllamaGenerationSchema.parse(wireGeneration); + let text = generation.response; + + // hack: prepend the model name to the first packet + if (!hasBegun) { + hasBegun = true; + const firstPacket: ChatStreamFirstPacketSchema = { model: generation.model }; + text = JSON.stringify(firstPacket) + text; + } + + return { text, close: generation.done }; + }; +} + +function createOpenAIStreamParser(): AIStreamParser { + let hasBegun = false; + let hasWarned = false; + + return (data: string) => { + + const json: OpenAIWire.ChatCompletion.ResponseStreamingChunk = JSON.parse(data); + + // [OpenAI] an upstream error will be handled gracefully and transmitted as text (throw to transmit as 'error') + if (json.error) + return { text: `[OpenAI Issue] ${safeErrorString(json.error)}`, close: true }; + + // [OpenAI] if there's a warning, log it once + if (json.warning && !hasWarned) { + hasWarned = true; + console.log('/api/llms/stream: OpenAI upstream warning:', json.warning); + } + + if (json.choices.length !== 1) { + // [Azure] we seem to 'prompt_annotations' or 'prompt_filter_results' objects - which we will ignore to suppress the error + if (json.id === '' && json.object === '' && json.model === '') + return { text: '', close: false }; + throw new Error(`Expected 1 completion, got ${json.choices.length}`); + } + + const index = json.choices[0].index; + if (index !== 0 && index !== undefined /* LocalAI hack/workaround until https://github.com/go-skynet/LocalAI/issues/788 */) + throw new Error(`Expected completion index 0, got ${index}`); + let text = json.choices[0].delta?.content /*|| json.choices[0]?.text*/ || ''; + + // hack: prepend the model name to the first packet + if (!hasBegun) { + hasBegun = true; + const firstPacket: ChatStreamFirstPacketSchema = { model: json.model }; + text = JSON.stringify(firstPacket) + text; + } + + // [LocalAI] workaround: LocalAI doesn't send the [DONE] event, but similarly to OpenAI, it sends a "finish_reason" delta update + const close = !!json.choices[0].finish_reason; + return { text, close }; + }; +} + + +// Event Stream Transformers + +/** + * Creates a TransformStream that parses events from an EventSource stream using a custom parser. + * @returns {TransformStream} TransformStream parsing events. + */ +function createEventStreamTransformer(vendorTextParser: AIStreamParser, inputFormat: EventStreamFormat, dialectLabel: string): TransformStream { + const textDecoder = new TextDecoder(); + const textEncoder = new TextEncoder(); + let eventSourceParser: EventSourceParser; + + return new TransformStream({ + start: async (controller): Promise => { + + // only used for debugging + let debugLastMs: number | null = null; + + const onNewEvent = (event: ParsedEvent | ReconnectInterval) => { + if (SERVER_DEBUG_WIRE) { + const nowMs = Date.now(); + const elapsedMs = debugLastMs ? 
nowMs - debugLastMs : 0; + debugLastMs = nowMs; + console.log(`<- SSE (${elapsedMs} ms):`, event); + } + + // ignore 'reconnect-interval' and events with no data + if (event.type !== 'event' || !('data' in event)) + return; + + // event stream termination, close our transformed stream + if (event.data === '[DONE]') { + controller.terminate(); + return; + } + + try { + const { text, close } = vendorTextParser(event.data); + if (text) + controller.enqueue(textEncoder.encode(text)); + if (close) + controller.terminate(); + } catch (error: any) { + // console.log(`/api/llms/stream: parse issue: ${error?.message || error}`); + controller.enqueue(textEncoder.encode(`[Stream Issue] ${dialectLabel}: ${safeErrorString(error) || 'Unknown stream parsing error'}`)); + controller.terminate(); + } + }; + + if (inputFormat === 'sse') + eventSourceParser = createEventsourceParser(onNewEvent); + else if (inputFormat === 'json-nl') + eventSourceParser = createJsonNewlineParser(onNewEvent); + }, + + // stream=true is set because the data is not guaranteed to be final and un-chunked + transform: (chunk: Uint8Array) => { + eventSourceParser.feed(textDecoder.decode(chunk, { stream: true })); + }, + }); +} + +/** + * Creates a parser for a 'JSON\n' non-event stream, to be swapped with an EventSource parser. + * Ollama is the only vendor that uses this format. + */ +function createJsonNewlineParser(onParse: EventSourceParseCallback): EventSourceParser { + let accumulator: string = ''; + return { + // feeds a new chunk to the parser - we accumulate in case of partial data, and only execute on full lines + feed: (chunk: string): void => { + accumulator += chunk; + if (accumulator.endsWith('\n')) { + for (const jsonString of accumulator.split('\n').filter(line => !!line)) { + const mimicEvent: ParsedEvent = { + type: 'event', + id: undefined, + event: undefined, + data: jsonString, + }; + onParse(mimicEvent); + } + accumulator = ''; + } + }, + + // resets the parser state - not useful with our driving of the parser + reset: (): void => { + console.error('createJsonNewlineParser.reset() not implemented'); + }, + }; +} diff --git a/src/modules/llms/vendors/IModelVendor.ts b/src/modules/llms/vendors/IModelVendor.ts index 2276967bb..799b62204 100644 --- a/src/modules/llms/vendors/IModelVendor.ts +++ b/src/modules/llms/vendors/IModelVendor.ts @@ -4,7 +4,7 @@ import type { DLLM, DModelSourceId } from '../store-llms'; import { VChatFunctionIn, VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../transports/chatGenerate'; -export type ModelVendorId = 'anthropic' | 'azure' | 'localai' | 'oobabooga' | 'openai' | 'openrouter'; +export type ModelVendorId = 'anthropic' | 'azure' | 'localai' | 'ollama' | 'oobabooga' | 'openai' | 'openrouter'; export interface IModelVendor> { diff --git a/src/modules/llms/vendors/ollama/OllamaAdmin.tsx b/src/modules/llms/vendors/ollama/OllamaAdmin.tsx new file mode 100644 index 000000000..4c2841104 --- /dev/null +++ b/src/modules/llms/vendors/ollama/OllamaAdmin.tsx @@ -0,0 +1,102 @@ +import * as React from 'react'; + +import { Box, Button, Divider, FormControl, FormHelperText, FormLabel, Input, Option, Select, Typography } from '@mui/joy'; + +import { GoodModal } from '~/common/components/GoodModal'; +import { apiQuery } from '~/common/util/trpc.client'; +import { settingsGap } from '~/common/theme'; + +import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; + + +export function OllamaAdmin(props: { access: OllamaAccessSchema, onClose: () => void }) {
+ + // state + const [pullModel, setPullModel] = React.useState('llama2'); + const [pullTag, setPullTag] = React.useState(''); + + // external state + const { data: pullable } = apiQuery.llmOllama.adminListPullable.useQuery({ access: props.access }, { + staleTime: 1000 * 60, + refetchOnWindowFocus: false, + }); + const { + data: pullData, isLoading: isPulling, status: pullStatus, error: pullModelError, + mutate: pullMutate, + } = apiQuery.llmOllama.adminPull.useMutation(); + + // derived state + const pullModelDescription = pullable?.pullable.find(p => p.id === pullModel)?.description ?? null; + + const handlePull = () => { + if (pullModel) { + pullMutate({ + access: props.access, + name: pullModel + (pullTag ? ':' + pullTag : ''), + }); + } + }; + + return ( + + + + + + + We assume your Ollama host is running and models are already available. + However we provide a way to pull models from the Ollama host, for convenience. + + + + + + Name + + + + + + Tag + + setPullTag(event.target.value)} + sx={{ minWidth: 100 }} + slotProps={{ input: { size: 10 } }} // halve the min width + /> + + + + + + + {pullModelError?.message || pullModelDescription} + + {!!pullData?.error + ? {pullData.error} + : !!pullData?.status + ? {pullData.status} + : null + } + + + + + + + + + + ); +} \ No newline at end of file diff --git a/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx new file mode 100644 index 000000000..a82c4d904 --- /dev/null +++ b/src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx @@ -0,0 +1,67 @@ +import * as React from 'react'; + +import { Box, Button } from '@mui/joy'; + +import { FormTextField } from '~/common/components/forms/FormTextField'; +import { InlineError } from '~/common/components/InlineError'; +import { Link } from '~/common/components/Link'; +import { SetupFormRefetchButton } from '~/common/components/forms/SetupFormRefetchButton'; +import { apiQuery } from '~/common/util/trpc.client'; +import { asValidURL } from '~/common/util/urlUtils'; +import { settingsGap } from '~/common/theme'; + +import { DModelSourceId, useModelsStore, useSourceSetup } from '../../store-llms'; +import { ModelVendorOllama } from './ollama.vendor'; +import { OllamaAdmin } from './OllamaAdmin'; +import { modelDescriptionToDLLM } from '../openai/OpenAISourceSetup'; + + +export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) { + + // state + const [adminOpen, setAdminOpen] = React.useState(false); + + // external state + const { source, sourceHasLLMs, access, updateSetup } = + useSourceSetup(props.sourceId, ModelVendorOllama.getAccess); + + // derived state + const { ollamaHost } = access; + + const hostValid = !!asValidURL(ollamaHost); + const hostError = !!ollamaHost && !hostValid; + const shallFetchSucceed = !hostError; + + // fetch models + const { isFetching, refetch, isError, error } = apiQuery.llmOllama.listModels.useQuery({ access }, { + enabled: !sourceHasLLMs && shallFetchSucceed, + onSuccess: models => source && useModelsStore.getState().addLLMs(models.models.map(model => modelDescriptionToDLLM(model, source))), + staleTime: Infinity, + }); + + return + + information} + placeholder='http://127.0.0.1:11434' + isError={hostError} + value={ollamaHost || ''} + onChange={text => updateSetup({ ollamaHost: text })} + /> + + setAdminOpen(true)}> + Ollama Admin + + } + /> + + {isError && } + + {adminOpen && setAdminOpen(false)} />} + + ; +} \ No newline at end of file diff --git a/src/modules/llms/vendors/ollama/ollama.vendor.ts 
b/src/modules/llms/vendors/ollama/ollama.vendor.ts new file mode 100644 index 000000000..5a1ab993d --- /dev/null +++ b/src/modules/llms/vendors/ollama/ollama.vendor.ts @@ -0,0 +1,86 @@ +import { apiAsync } from '~/common/util/trpc.client'; + +import type { IModelVendor } from '../IModelVendor'; +import type { VChatMessageIn, VChatMessageOrFunctionCallOut, VChatMessageOut } from '../../transports/chatGenerate'; +import type { OllamaAccessSchema } from '../../transports/server/ollama/ollama.router'; + +import { LLMOptionsOpenAI } from '../openai/openai.vendor'; +import { OpenAILLMOptions } from '../openai/OpenAILLMOptions'; + +import { OllamaSourceSetup } from './OllamaSourceSetup'; + + +export interface SourceSetupOllama { + ollamaHost: string; +} + +/** Implementation Notes for the Ollama Vendor + * + * TODO: Work in progress... + * + */ +export const ModelVendorOllama: IModelVendor = { + id: 'ollama', + name: 'Ollama', + rank: 22, + location: 'local', + instanceLimit: 2, + hasServerKey: !!process.env.HAS_SERVER_HOST_OLLAMA, + + // components + Icon: '🦙', + SourceSetupComponent: OllamaSourceSetup, + LLMOptionsComponent: OpenAILLMOptions, + + // functions + getAccess: (partialSetup): OllamaAccessSchema => ({ + dialect: 'ollama', + ollamaHost: partialSetup?.ollamaHost || '', + }), + callChatGenerate(llm, messages: VChatMessageIn[], maxTokens?: number): Promise { + return ollamaCallChatGenerate(this.getAccess(llm._source.setup), llm.options, messages, maxTokens); + }, + callChatGenerateWF(): Promise { + throw new Error('Ollama does not support "Functions" yet'); + }, +}; + + +/** + * This function either returns the LLM message, or throws a descriptive error string + */ +async function ollamaCallChatGenerate( + access: OllamaAccessSchema, llmOptions: Partial, messages: VChatMessageIn[], + maxTokens?: number, +): Promise { + const { llmRef, llmTemperature = 0.5, llmResponseTokens } = llmOptions; + try { + return await apiAsync.llmOllama.chatGenerate.mutate({ + access, + model: { + id: llmRef!, + temperature: llmTemperature, + maxTokens: maxTokens || llmResponseTokens || 1024, + }, + history: messages, + }) as TOut; + } catch (error: any) { + const errorMessage = error?.message || error?.toString() || 'Ollama Chat Generate Error'; + console.error(`ollamaCallChatGenerate: ${errorMessage}`); + throw new Error(errorMessage); + } +} + +/*import * as React from 'react'; + +import { SvgIcon } from '@mui/joy'; +import { SxProps } from '@mui/joy/styles/types'; + +export function OllamaIcon(props: { sx?: SxProps }) { + return + + ; +}*/ diff --git a/src/modules/llms/vendors/vendor.registry.ts b/src/modules/llms/vendors/vendor.registry.ts index b66cb8652..862745cdc 100644 --- a/src/modules/llms/vendors/vendor.registry.ts +++ b/src/modules/llms/vendors/vendor.registry.ts @@ -1,6 +1,7 @@ import { ModelVendorAnthropic } from './anthropic/anthropic.vendor'; import { ModelVendorAzure } from './azure/azure.vendor'; import { ModelVendorLocalAI } from './localai/localai.vendor'; +import { ModelVendorOllama } from './ollama/ollama.vendor'; import { ModelVendorOoobabooga } from './oobabooga/oobabooga.vendor'; import { ModelVendorOpenAI } from './openai/openai.vendor'; import { ModelVendorOpenRouter } from './openrouter/openrouter.vendor'; @@ -13,6 +14,7 @@ const MODEL_VENDOR_REGISTRY: Record = { anthropic: ModelVendorAnthropic, azure: ModelVendorAzure, localai: ModelVendorLocalAI, + ollama: ModelVendorOllama, oobabooga: ModelVendorOoobabooga, openai: ModelVendorOpenAI, openrouter: ModelVendorOpenRouter, diff --git 
a/src/server/api/trpc.router.ts b/src/server/api/trpc.router.ts index c3df756ce..1de787687 100644 --- a/src/server/api/trpc.router.ts +++ b/src/server/api/trpc.router.ts @@ -3,6 +3,7 @@ import { createTRPCRouter } from './trpc.server'; import { elevenlabsRouter } from '~/modules/elevenlabs/elevenlabs.router'; import { googleSearchRouter } from '~/modules/google/search.router'; import { llmAnthropicRouter } from '~/modules/llms/transports/server/anthropic/anthropic.router'; +import { llmOllamaRouter } from '~/modules/llms/transports/server/ollama/ollama.router'; import { llmOpenAIRouter } from '~/modules/llms/transports/server/openai/openai.router'; import { prodiaRouter } from '~/modules/prodia/prodia.router'; import { tradeRouter } from '../../apps/chat/trade/server/trade.router'; @@ -17,6 +18,7 @@ export const appRouterEdge = createTRPCRouter({ elevenlabs: elevenlabsRouter, googleSearch: googleSearchRouter, llmAnthropic: llmAnthropicRouter, + llmOllama: llmOllamaRouter, llmOpenAI: llmOpenAIRouter, prodia: prodiaRouter, ytpersona: ytPersonaRouter, diff --git a/src/server/wire.ts b/src/server/wire.ts index 68939fdce..cdc1acd2a 100644 --- a/src/server/wire.ts +++ b/src/server/wire.ts @@ -65,3 +65,9 @@ export function debugGenerateCurlCommand(method: 'GET' | 'POST', url: string, he return curl; } + +export function createEmptyReadableStream(): ReadableStream { + return new ReadableStream({ + start: (controller) => controller.close(), + }); +} \ No newline at end of file
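
For reference, the 'json-nl' input format handled above corresponds to Ollama's /api/generate streaming output: one JSON object per line, with the incremental text in the response field and done: true on the final object (the same fields validated by wireOllamaGenerationSchema). The sketch below is a minimal, standalone TypeScript illustration of a client for that wire format; it assumes only the default host and payload fields used in this patch (model, prompt, stream: true) and is not code from the repository.

// Illustrative only: consume Ollama's newline-delimited JSON stream from /api/generate.
async function* streamOllamaGenerate(prompt: string, model = 'llama2', host = 'http://127.0.0.1:11434'): AsyncGenerator<string> {
  const res = await fetch(`${host}/api/generate`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ model, prompt, stream: true }),
  });
  if (!res.ok || !res.body) throw new Error(`Ollama error: ${res.status}`);
  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  let buffer = '';
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split('\n');
    buffer = lines.pop() || ''; // keep a partial trailing line for the next chunk
    for (const line of lines.filter(l => !!l)) {
      const chunk = JSON.parse(line); // one object per line: { model, created_at, response, done, ... }
      if (chunk.response) yield chunk.response as string;
      if (chunk.done) return;
    }
  }
}
// usage sketch: for await (const piece of streamOllamaGenerate('Why is the sky blue?')) process.stdout.write(piece);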