diff --git a/src/apps/chat/components/message/ContentPartImageRef.tsx b/src/apps/chat/components/message/ContentPartImageRef.tsx index 42708879b..ef914f684 100644 --- a/src/apps/chat/components/message/ContentPartImageRef.tsx +++ b/src/apps/chat/components/message/ContentPartImageRef.tsx @@ -1,5 +1,6 @@ import * as React from 'react'; import TimeAgo from 'react-timeago'; +import { useQuery } from '@tanstack/react-query'; import type { SxProps } from '@mui/joy/styles/types'; import { Box } from '@mui/joy'; @@ -8,6 +9,7 @@ import type { DBlobAssetId, DBlobImageAsset } from '~/modules/dblobs/dblobs.type import { RenderImageURL } from '~/modules/blocks/RenderImageURL'; import { blocksRendererSx } from '~/modules/blocks/BlocksRenderer'; import { getImageAssetAsBlobURL } from '~/modules/dblobs/dblobs.images'; +import { t2iGenerateImageContentFragments } from '~/modules/t2i/t2i.client'; import { useDBAsset } from '~/modules/dblobs/dblobs.hooks'; import type { DMessageContentFragment, DMessageDataRef, DMessageImageRefPart } from '~/common/stores/chat/chat.message'; @@ -46,25 +48,22 @@ function ContentPartImageRefDBlob(props: { const [imageItem] = useDBAsset(props.dataRefDBlobAssetId); // handlers - const { label: imageItemLabel, origin: imageItemOrigin, metadata: imageItemMetadata } = imageItem || {}; const recreationPrompt = ((imageItemOrigin?.ot === 'generated') ? imageItemOrigin.prompt : undefined) || imageItemLabel || props.imageAltText; - const recreationWidth = imageItemMetadata?.width || props.imageWidth; - const recreationHeight = imageItemMetadata?.height || props.imageHeight; + const _recreationWidth = imageItemMetadata?.width || props.imageWidth; + const _recreationHeight = imageItemMetadata?.height || props.imageHeight; - const handleImageRegenerate = React.useCallback(() => { - // TODO: ... t2iGenerateImagesOrThrow() - console.log('ContentPartImageDBlob: handleImageRegenerate: notImplemented', imageItem, recreationPrompt, recreationWidth, recreationHeight); - - // props.onImageReplace( createImageContentFragment() - // { - // type: 'image', - // dataRef: { reftype: 'dblob', dblobAssetId: props.dataRefDBlobAssetId }, - // altText: props.imageAltText, - // width: props.imageWidth, - // height: props.imageHeight, - // }); - }, [imageItem, recreationPrompt, recreationWidth, recreationHeight]); + // async image regeneration + const { isLoading: isRegenerating, refetch: handleImageRegenerate } = useQuery({ + enabled: false, + queryKey: ['regen-image-asset', props.dataRefDBlobAssetId, recreationPrompt], + queryFn: async ({ signal }) => { + if (signal?.aborted || !recreationPrompt) return; + const newImageFragments = await t2iGenerateImageContentFragments(null, recreationPrompt, 1, 'global', 'app-chat'); + if (newImageFragments.length === 1) + props.onReplaceFragment(newImageFragments[0]); + }, + }); // memo the description and overlay text const { dataUrl, altText, overlayText } = React.useMemo(() => { diff --git a/src/apps/chat/editors/image-generate.ts b/src/apps/chat/editors/image-generate.ts index 2c8368acc..53823dce5 100644 --- a/src/apps/chat/editors/image-generate.ts +++ b/src/apps/chat/editors/image-generate.ts @@ -1,10 +1,9 @@ import type { DBlobAssetId } from '~/modules/dblobs/dblobs.types'; -import { addDBImageAsset, gcDBImageAssets } from '~/modules/dblobs/dblobs.images'; -import { getActiveTextToImageProviderOrThrow, t2iGenerateImagesOrThrow } from '~/modules/t2i/t2i.client'; +import { gcDBImageAssets } from '~/modules/dblobs/dblobs.images'; +import { getActiveTextToImageProviderOrThrow, t2iGenerateImageContentFragments } from '~/modules/t2i/t2i.client'; import type { ConversationHandler } from '~/common/chats/ConversationHandler'; import type { TextToImageProvider } from '~/common/components/useCapabilities'; -import { createDMessageDataRefDBlob, createImageContentFragment } from '~/common/stores/chat/chat.message'; import { useChatStore } from '~/common/stores/chat/store-chats'; @@ -38,39 +37,12 @@ export async function runImageGenerationUpdatingState(cHandler: ConversationHand ); try { - const images = await t2iGenerateImagesOrThrow(t2iProvider, imageText, repeat); - for (const _i of images) { + const imageContentFragments = await t2iGenerateImageContentFragments(t2iProvider, imageText, repeat, 'global', 'app-chat'); - // add the image to the DB - const dblobAssetId = await addDBImageAsset('global', 'app-chat', { - label: imageText, // 'Image: ' + _i.altText - data: { - mimeType: _i.mimeType as any, /* we assume the mime is supported */ - base64: _i.base64Data, - }, - origin: { - ot: 'generated', - source: 'ai-text-to-image', - generatorName: _i.generatorName, - prompt: _i.altText, - parameters: _i.parameters, - generatedAt: _i.generatedAt, - }, - metadata: { - width: _i.width || 0, - height: _i.height || 0, - // description: '', - }, - }); - - // Create and add an Image Content Fragment - const imageContentFragment = createImageContentFragment( - createDMessageDataRefDBlob(dblobAssetId, _i.mimeType, _i.base64Data.length), - _i.altText, - _i.width, _i.height, - ); + // add the image content fragments to the message + for (const imageContentFragment of imageContentFragments) cHandler.messageAppendContentFragment(assistantMessageId, imageContentFragment, true, true); - } + return true; } catch (error: any) { diff --git a/src/apps/draw/DrawCreate.tsx b/src/apps/draw/DrawCreate.tsx index 62552b46f..681e2517c 100644 --- a/src/apps/draw/DrawCreate.tsx +++ b/src/apps/draw/DrawCreate.tsx @@ -5,14 +5,12 @@ import type { SxProps } from '@mui/joy/styles/types'; import { Box, Card, Skeleton } from '@mui/joy'; import type { ImageBlock } from '~/modules/blocks/blocks'; -import { addDBImageAsset } from '~/modules/dblobs/dblobs.images'; -import { getActiveTextToImageProviderOrThrow, t2iGenerateImagesOrThrow } from '~/modules/t2i/t2i.client'; +import { t2iGenerateImageContentFragments } from '~/modules/t2i/t2i.client'; import type { TextToImageProvider } from '~/common/components/useCapabilities'; import { InlineError } from '~/common/components/InlineError'; import { ScrollToBottom } from '~/common/scroll-to-bottom/ScrollToBottom'; import { ScrollToBottomButton } from '~/common/scroll-to-bottom/ScrollToBottomButton'; -import { createDMessageDataRefDBlob } from '~/common/stores/chat/chat.message'; import { DesignerPrompt, PromptComposer } from './create/PromptComposer'; import { ProviderConfigure } from './create/ProviderConfigure'; @@ -46,47 +44,12 @@ const imagineScrollContainerSx: SxProps = { * @returns up-to `vectorSize` image URLs */ async function queryActiveGenerateImageVector(singlePrompt: string, vectorSize: number = 1) { - const t2iProvider = getActiveTextToImageProviderOrThrow(); + const imageContentFragments = await t2iGenerateImageContentFragments(null, singlePrompt, vectorSize, 'global', 'app-draw'); - const images = await t2iGenerateImagesOrThrow(t2iProvider, singlePrompt, vectorSize); - if (!images?.length) - throw new Error('No image generated'); - - // save the generated images - for (const _i of images) { - - // add the image to the DB - const dblobAssetId = await addDBImageAsset('global', 'app-draw', { - label: singlePrompt, - data: { - mimeType: _i.mimeType as any, /* we assume the mime is supported */ - base64: _i.base64Data, - }, - origin: { - ot: 'generated', - source: 'ai-text-to-image', - generatorName: _i.generatorName, // t2iProvider.painter? - prompt: _i.altText, - parameters: _i.parameters, - generatedAt: _i.generatedAt, - }, - metadata: { - width: _i.width || 0, - height: _i.height || 0, - // description: '', - }, - }); - - // Create a data reference for the image from the message - const imagePartDataRef = createDMessageDataRefDBlob(dblobAssetId, _i.mimeType, _i.base64Data.length); - - // TODO: move to DMessageImagePart? - console.log('TODO: notImplemented: imagePartDataRef: CRUD and View of blobs as ImageBlocks', imagePartDataRef); + for (const imageContentFragment of imageContentFragments) { + console.log('TODO: notImplemented: imagePartDataRef: CRUD and View of blobs as ImageBlocks', imageContentFragment.part); } - - // const block = heuristicMarkdownImageReferenceBlocks(images.join('\n')); - // if (!block?.length) - // throw new Error('No URLs in the generated images'); + // TODO continue... return []; } diff --git a/src/modules/t2i/t2i.client.ts b/src/modules/t2i/t2i.client.ts index 3bc51df71..ec4b7de67 100644 --- a/src/modules/t2i/t2i.client.ts +++ b/src/modules/t2i/t2i.client.ts @@ -3,11 +3,14 @@ import { shallow } from 'zustand/shallow'; import { useShallow } from 'zustand/react/shallow'; import { useStoreWithEqualityFn } from 'zustand/traditional'; +import type { DBlobDBAsset } from '~/modules/dblobs/dblobs.types'; import type { ModelVendorId } from '~/modules/llms/vendors/vendors.registry'; import { DLLM, DModelSource, DModelSourceId, useModelsStore } from '~/modules/llms/store-llms'; +import { addDBImageAsset } from '~/modules/dblobs/dblobs.images'; import { getBackendCapabilities } from '~/modules/backend/store-backend-capabilities'; import type { CapabilityTextToImage, TextToImageProvider } from '~/common/components/useCapabilities'; +import { createDMessageDataRefDBlob, createImageContentFragment, DMessageContentFragment } from '~/common/stores/chat/chat.message'; import type { T2iCreateImageOutput } from './t2i.server'; import { openAIGenerateImagesOrThrow } from './dalle/openaiGenerateImages'; @@ -89,7 +92,7 @@ export function getActiveTextToImageProviderOrThrow() { return activeProvider; } -export async function t2iGenerateImagesOrThrow({ id, vendor }: TextToImageProvider, prompt: string, count: number): Promise { +async function _t2iGenerateImagesOrThrow({ id, vendor }: TextToImageProvider, prompt: string, count: number): Promise { switch (vendor) { case 'localai': @@ -113,6 +116,59 @@ export async function t2iGenerateImagesOrThrow({ id, vendor }: TextToImageProvid } } +/** + * Generate image content fragments using the provided TextToImageProvider + * If t2iprovider is null, the active provider will be used + */ +export async function t2iGenerateImageContentFragments( + t2iProvider: TextToImageProvider | null, prompt: string, count: number, + contextId: DBlobDBAsset['contextId'], scopeId: DBlobDBAsset['scopeId'], +): Promise { + + // T2I: Use the active provider if null + if (!t2iProvider) + t2iProvider = getActiveTextToImageProviderOrThrow(); + + // T2I: Generate + const generatedImages = await _t2iGenerateImagesOrThrow(t2iProvider, prompt, count); + if (!generatedImages?.length) + throw new Error('No image generated'); + + const imageFragments: DMessageContentFragment[] = []; + for (const _i of generatedImages) { + + // add the image to the DB + const dblobAssetId = await addDBImageAsset(contextId, scopeId, { + label: prompt, + data: { + mimeType: _i.mimeType as any, + base64: _i.base64Data, + }, + origin: { + ot: 'generated', + source: 'ai-text-to-image', + generatorName: _i.generatorName, + prompt: _i.altText, + parameters: _i.parameters, + generatedAt: _i.generatedAt, + }, + metadata: { + width: _i.width || 0, + height: _i.height || 0, + // description: '', + }, + }); + + // create a data reference for the image + const imageAssetDataRef = createDMessageDataRefDBlob(dblobAssetId, _i.mimeType, _i.base64Data.length); + + // create an Image Content Fragment + const imageContentFragment = createImageContentFragment(imageAssetDataRef, _i.altText, _i.width, _i.height); + imageFragments.push(imageContentFragment); + } + return imageFragments; +} + /// Private