diff --git a/.gitignore b/.gitignore index 5076b36e..9590699f 100755 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,7 @@ captures/ !/.idea/vcs.xml !/.idea/fileTemplates/ !/.idea/inspectionProfiles/ +!/.idea/icon.svg !/.idea/scopes/ !/.idea/codeStyleSettings.xml !/.idea/encodings.xml diff --git a/.idea/icon.svg b/.idea/icon.svg new file mode 100644 index 00000000..7f7eab01 --- /dev/null +++ b/.idea/icon.svg @@ -0,0 +1,106 @@ + + + + + + + + + + + + + + + + + + diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 5aeb348d..9ad3e14a 100755 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -51,27 +51,6 @@ android { println("[Signature] -> Build will be signed with signature: $alias") buildTypes.getByName("release").signingConfig = signingConfigs.getByName("release") } - - flavorDimensions += "type" - productFlavors { - create("dev") { - dimension = "type" - applicationIdSuffix = ".dev" - resValue("string", "app_name", "SDAI Dev") - buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"") - } - create("foss") { - dimension = "type" - applicationIdSuffix = ".foss" - resValue("string", "app_name", "SDAI FOSS") - buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"") - } - create("playstore") { - dimension = "type" - resValue("string", "app_name", "SDAI") - buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"") - } - } } dependencies { @@ -85,6 +64,7 @@ dependencies { implementation(project(":domain")) implementation(project(":feature:auth")) implementation(project(":feature:diffusion")) + implementation(project(":feature:mediapipe")) implementation(project(":feature:work")) implementation(project(":data")) implementation(project(":demo")) diff --git a/app/src/dev/AndroidManifest.xml b/app/src/full/AndroidManifest.xml similarity index 100% rename from app/src/dev/AndroidManifest.xml rename to app/src/full/AndroidManifest.xml diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index 67b1befd..4aac9eb0 100755 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -12,6 +12,10 @@ android:theme="@style/Theme.AiSdCompose.Splash" android:usesCleartextTraffic="true"> + + + + append(" FULL") + BuildType.FOSS -> append(" FOSS") + BuildType.PLAY -> Unit + } } } } @@ -172,14 +176,14 @@ val providersModule = module { single { DeviceNNAPIFlagProvider { - get().localUseNNAPI + get().localOnnxUseNNAPI .let { nnApi -> if (nnApi) LocalDiffusionFlag.NN_API else LocalDiffusionFlag.CPU } .let(LocalDiffusionFlag::value) } } single { - LocalModelIdProvider { get().localModelId } + LocalModelIdProvider { get().localOnnxModelId } } single { diff --git a/build-logic/convention/build.gradle.kts b/build-logic/convention/build.gradle.kts index 75b16dee..2d730493 100644 --- a/build-logic/convention/build.gradle.kts +++ b/build-logic/convention/build.gradle.kts @@ -36,6 +36,10 @@ gradlePlugin { id = "generic.application" implementationClass = "ApplicationConventionPlugin" } + register("Flavors") { + id = "generic.flavors" + implementationClass = "FlavorsConventionPlugin" + } register("BaselineProFm") { id = "generic.baseline.profm" implementationClass = "BaselineProFmConventionPlugin" diff --git a/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt b/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt index 2b126264..34cbad9a 100644 --- a/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt +++ b/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt @@ -1,6 +1,8 @@ import com.android.build.api.dsl.ApplicationExtension +import com.android.build.gradle.BaseExtension import com.shifthackz.aisdv1.buildlogic.configureApplication import com.shifthackz.aisdv1.buildlogic.configureCompose +import com.shifthackz.aisdv1.buildlogic.configureFlavors import com.shifthackz.aisdv1.buildlogic.libs import org.gradle.api.Plugin import org.gradle.api.Project @@ -22,6 +24,9 @@ class ApplicationConventionPlugin : Plugin { configureCompose(this) defaultConfig.targetSdk = libs.findVersion("targetSdk").get().toString().toInt() } + extensions.configure { + configureFlavors(this) + } } } } diff --git a/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt b/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt new file mode 100644 index 00000000..743465f1 --- /dev/null +++ b/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt @@ -0,0 +1,16 @@ +import com.android.build.gradle.LibraryExtension +import com.shifthackz.aisdv1.buildlogic.configureFlavorsCommon +import org.gradle.api.Plugin +import org.gradle.api.Project +import org.gradle.kotlin.dsl.configure + +class FlavorsConventionPlugin : Plugin { + + override fun apply(target: Project) { + with(target) { + extensions.configure { + configureFlavorsCommon(this) + } + } + } +} diff --git a/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt b/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt index 5bc8f1b3..6e83a3fc 100644 --- a/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt +++ b/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt @@ -7,7 +7,7 @@ import org.gradle.kotlin.dsl.configure import org.gradle.kotlin.dsl.withType import org.gradle.testing.jacoco.plugins.JacocoTaskExtension -class JacocoConventionPlugin : Plugin { +class JacocoConventionPlugin : Plugin { override fun apply(target: Project) { with(target) { diff --git a/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt new file mode 100644 index 00000000..6357b9ff --- /dev/null +++ b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt @@ -0,0 +1,43 @@ +package com.shifthackz.aisdv1.buildlogic + +import com.android.build.api.dsl.CommonExtension +import com.android.build.gradle.BaseExtension +import org.gradle.api.Project + +internal fun Project.configureFlavors( + commonExtension: BaseExtension, +) { + commonExtension.apply { + flavorDimensions("type") + productFlavors.create("full") { + dimension = "type" + applicationIdSuffix = ".full" + resValue("string", "app_name", "SDAI Full") + buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FULL\"") + + } + productFlavors.create("foss") { + dimension = "type" + applicationIdSuffix = ".foss" + resValue("string", "app_name", "SDAI FOSS") + buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"") + + } + productFlavors.create("playstore") { + dimension = "type" + resValue("string", "app_name", "SDAI") + buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"") + } + } +} + +internal fun Project.configureFlavorsCommon( + commonExtension: CommonExtension<*, *, *, *, *, *>, +) { + commonExtension.apply { + flavorDimensions += listOf("type") + productFlavors.create("full") { dimension = "type" } + productFlavors.create("foss") { dimension = "type" } + productFlavors.create("playstore") { dimension = "type" } + } +} diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt index d431d0f8..2f4c4471 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt @@ -1,11 +1,13 @@ package com.shifthackz.aisdv1.core.common.appbuild enum class BuildType { + FULL, FOSS, PLAY; companion object { fun fromBuildConfig(input: String) = when (input) { + "FULL" -> FULL "FOSS" -> FOSS else -> PLAY } diff --git a/core/localization/src/main/res/values-ru/strings.xml b/core/localization/src/main/res/values-ru/strings.xml index acbaa8e7..ef0e1272 100644 --- a/core/localization/src/main/res/values-ru/strings.xml +++ b/core/localization/src/main/res/values-ru/strings.xml @@ -133,9 +133,11 @@ Укажите свйой URL-адрес Swarm UI Модульный веб-интерфейс Stable Diffusion, в котором особое внимание уделяется обеспечению легкого доступа к инструментам, высокой производительности и расширяемости. - Эта конфигурация позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. + Эта конфигурация использует Microsoft ONNX и позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. ВНИМАНИЕ! Функциональность Local Diffusion в бета-тестировании. Не ожидайте высококачественных изображений в локальном режиме. \n\nЭта реализация может не работать должным образом на мобильных телефонах. Производительность и скорость генерации зависят от ресурсов вашего телефона (ЦП, ОЗУ) и размера сгенерированного изображения (чем меньше размер изображения, тем быстрее генерируется). + Эта конфигурация использует Google AI MediaPipe и позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку. + Веб Txt2Img diff --git a/core/localization/src/main/res/values-tr/strings.xml b/core/localization/src/main/res/values-tr/strings.xml index eed6dea1..8319a6c1 100644 --- a/core/localization/src/main/res/values-tr/strings.xml +++ b/core/localization/src/main/res/values-tr/strings.xml @@ -133,9 +133,11 @@ Swarm UI URL\'nizi sağlayın Araçlara kolay erişim, yüksek performans ve genişletilebilirlik üzerine odaklanan Modüler, Kararlı Yaygın Web Kullanıcı Arayüzü. - Bu yapılandırma, telefonunuzda uzak sunucuya/buluta bağlanmaya gerek kalmadan Stable Diffusion AI nesillerini çalıştırmanıza izin verir. + Bu yapılandırma Microsoft ONNX çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır. Uyarı! Yerel Yayılma işlevi beta testindedir. Yerel modu kullanarak yüksek kaliteli görüntüler beklemeyin. \n\nBu uygulama, güçlü olmayan telefonlarda iyi çalışmayabilir. Oluşturma performansı ve hızı, telefonunuzun kaynaklarına (CPU, RAM) ve oluşturulan görüntünün boyutuna bağlıdır (görüntü boyutu ne kadar küçükse, oluşturma o kadar hızlı olur). + Bu yapılandırma Google AI MediaPipe çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır. + Web arayüzü Txt2Img diff --git a/core/localization/src/main/res/values-uk/strings.xml b/core/localization/src/main/res/values-uk/strings.xml index da32118e..87fe4b5d 100644 --- a/core/localization/src/main/res/values-uk/strings.xml +++ b/core/localization/src/main/res/values-uk/strings.xml @@ -133,9 +133,11 @@ Provide your Swarm UI URL Модульний веб-інтерфейс Stable Diffusion з наголосом на полегшення доступу до інструментів, високу продуктивність і розширюваність. - Ця конфігурація дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. + Ця конфігурація використовує Microsoft ONNX та дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. УВАГА! Функціональність Local Diffusion у бета-тестуванні. Не очікуйте високоякісних зображень у локальному режимі. \n\nЦя реалізація може не працювати належним чином на телефонах із слабкою потужністю. Продуктивність і швидкість генерації залежать від ресурсів вашого телефону (ЦП, ОЗУ) і розміру згенерованого зображення (чим менший розмір зображення, тим швидше генерується). + Ця конфігурація використовує Google AI MediaPipe та дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари. + Веб Txt2Img diff --git a/core/localization/src/main/res/values-zh/strings.xml b/core/localization/src/main/res/values-zh/strings.xml index c3483d2e..7329dc23 100644 --- a/core/localization/src/main/res/values-zh/strings.xml +++ b/core/localization/src/main/res/values-zh/strings.xml @@ -167,9 +167,11 @@ 本地扩散 - 此配置允许在您的手机上运行Stable Diffusion AI生成,无需连接到远程服务器/云。 + 此配置使用 Microsoft ONNX 运行时,并允许在手机上运行稳定的 Diffusion AI 生成,无需连接到远程服务器/云。 警告!本地扩散功能处于测试版。不要期望使用本地模式获得高质量图像。 \n\n此实现可能在不强大的手机上运行不佳。生成性能和速度取决于您的手机资源(CPU、RAM)和生成的图像大小(图像越小,生成越快)。 + 此配置使用 Google AI MediaPipe 运行时,并允许在手机上运行稳定的 Diffusion AI 生成,无需连接到远程服务器/云。 + 网络界面 diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml index cba23819..b591977e 100755 --- a/core/localization/src/main/res/values/strings.xml +++ b/core/localization/src/main/res/values/strings.xml @@ -70,8 +70,10 @@ A1111 Horde AI Cloud Horde - Local Diffusion (Beta) - Local + Local Diffusion Microsoft ONNX (Beta) + ONNX + Local Diffusion Google AI MediaPipe (Beta) + MediaPipe Hugging Face Inference HuggingFace Open AI @@ -150,10 +152,13 @@ Provide your Swarm UI URL A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility. - Local Diffusion - This configuration allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. + Local Diffusion Microsoft ONNX + This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation). + Local Diffusion Google AI MediaPipe + This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud. + Web UI Txt2Img diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt index 6e08ea30..21d8edda 100755 --- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl +import com.shifthackz.aisdv1.data.repository.MediaPipeGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl @@ -33,6 +34,7 @@ import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository import com.shifthackz.aisdv1.domain.repository.LorasRepository +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.RandomImageRepository import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository @@ -63,6 +65,7 @@ val repositoryModule = module { singleOf(::TemporaryGenerationResultRepositoryImpl) bind TemporaryGenerationResultRepository::class factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class + factoryOf(::MediaPipeGenerationRepositoryImpl) bind MediaPipeGenerationRepository::class factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt index 0e4b9f02..4ab031c2 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.reactivex.rxjava3.core.Completable +import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File @@ -22,72 +23,94 @@ internal class DownloadableModelLocalDataSource( private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelDataSource.Local { - override fun getAll() = dao - .query() + override fun getAllOnnx() = dao + .queryByType(LocalAiModel.Type.ONNX.key) .map(List::mapEntityToDomain) .map { models -> buildList { addAll(models) - if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM) + if (buildInfoProvider.type != BuildType.PLAY) { + add(LocalAiModel.CustomOnnx) + } + } + } + .flatMap { models -> models.withLocalData() } + + override fun getAllMediaPipe(): Single> = dao + .queryByType(LocalAiModel.Type.MediaPipe.key) + .map(List::mapEntityToDomain) + .map { models -> + buildList { + addAll(models) + if (buildInfoProvider.type != BuildType.PLAY) { + add(LocalAiModel.CustomMediaPipe) + } } } .flatMap { models -> models.withLocalData() } override fun getById(id: String): Single { - val chain = if (id == LocalAiModel.CUSTOM.id) { - Single.just(LocalAiModel.CUSTOM) - } else { - dao + val chain = when (id) { + LocalAiModel.CustomOnnx.id -> Single.just(LocalAiModel.CustomOnnx) + LocalAiModel.CustomMediaPipe.id -> Single.just(LocalAiModel.CustomMediaPipe) + else -> dao .queryById(id) .map(LocalModelEntity::mapEntityToDomain) } - return chain.flatMap { model -> model.withLocalData() } } - override fun getSelected() = Single - .just(preferenceManager.localModelId) - .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) } + override fun getSelectedOnnx() = Single + .just(preferenceManager.localOnnxModelId) .flatMap(::getById) .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) } - override fun observeAll() = dao - .observe() + override fun observeAllOnnx(): Flowable> = dao + .observeByType(LocalAiModel.Type.ONNX.key) .map(List::mapEntityToDomain) .map { models -> buildList { addAll(models) - if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM) + if (buildInfoProvider.type != BuildType.PLAY) add(LocalAiModel.CustomOnnx) } } .flatMap { models -> models.withLocalData().toFlowable() } - override fun select(id: String) = Completable.fromAction { - preferenceManager.localModelId = id - } - override fun save(list: List) = list - .filter { it.id != LocalAiModel.CUSTOM.id } + .filter { it.id != LocalAiModel.CustomOnnx.id } .mapDomainToEntity() .let(dao::insertList) - override fun isDownloaded(id: String) = Single.create { emitter -> + override fun delete(id: String): Completable = Completable.fromAction { + getLocalModelDirectory(id).deleteRecursively() + } + + private fun isDownloaded(model: LocalAiModel) = Single.create { emitter -> try { - if (id == LocalAiModel.CUSTOM.id) { - if (!emitter.isDisposed) emitter.onSuccess(true) - } else { - val files = getLocalModelFiles(id) - if (!emitter.isDisposed) emitter.onSuccess(files.size == 4) + when (model.id) { + LocalAiModel.CustomOnnx.id, + LocalAiModel.CustomMediaPipe.id -> emitter.onSuccess(true) + + else -> { + + when (model.type) { + LocalAiModel.Type.ONNX -> { + val files = getLocalModelFiles(model.id).filter { it.isDirectory } + emitter.onSuccess(files.size == 4) + } + + LocalAiModel.Type.MediaPipe -> { + val files = getLocalModelFiles(model.id) + emitter.onSuccess(files.isNotEmpty()) + } + } + } } } catch (e: Exception) { if (!emitter.isDisposed) emitter.onSuccess(false) } } - override fun delete(id: String): Completable = Completable.fromAction { - getLocalModelDirectory(id).deleteRecursively() - } - private fun getLocalModelDirectory(id: String): File { return File("${fileProviderDescriptor.localModelDirPath}/${id}") } @@ -95,7 +118,7 @@ internal class DownloadableModelLocalDataSource( private fun getLocalModelFiles(id: String): List { val localModelDir = getLocalModelDirectory(id) if (!localModelDir.exists()) return emptyList() - return localModelDir.listFiles()?.filter { it.isDirectory } ?: emptyList() + return localModelDir.listFiles()?.toList() ?: emptyList() } private fun List.withLocalData() = Observable @@ -103,11 +126,14 @@ internal class DownloadableModelLocalDataSource( .flatMapSingle { model -> model.withLocalData() } .toList() - private fun LocalAiModel.withLocalData() = isDownloaded(id) + private fun LocalAiModel.withLocalData() = isDownloaded(this) .map { downloaded -> copy( downloaded = downloaded, - selected = preferenceManager.localModelId == id, + selected = when (this.type) { + LocalAiModel.Type.ONNX -> preferenceManager.localOnnxModelId == id + LocalAiModel.Type.MediaPipe -> preferenceManager.localMediaPipeModelId == id + }, ) } } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt index d9261085..7d7033aa 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt @@ -5,12 +5,16 @@ import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity //region RAW --> DOMAIN -fun List.mapRawToCheckpointDomain(): List = - map(DownloadableModelResponse::mapRawToCheckpointDomain) +fun List.mapRawToCheckpointDomain( + type: LocalAiModel.Type, +): List = map { it.mapRawToCheckpointDomain(type) } -fun DownloadableModelResponse.mapRawToCheckpointDomain(): LocalAiModel = with(this) { +fun DownloadableModelResponse.mapRawToCheckpointDomain( + type: LocalAiModel.Type, +): LocalAiModel = with(this) { LocalAiModel( id = id ?: "", + type = type, name = name ?: "", size = size ?: "", sources = sources ?: emptyList(), @@ -23,7 +27,7 @@ fun List.mapDomainToEntity(): List = map(LocalAiModel::mapDomainToEntity) fun LocalAiModel.mapDomainToEntity(): LocalModelEntity = with(this) { - LocalModelEntity(id, name, size, sources) + LocalModelEntity(id, type.key, name, size, sources) } //endregion @@ -32,6 +36,6 @@ fun List.mapEntityToDomain(): List = map(LocalModelEntity::mapEntityToDomain) fun LocalModelEntity.mapEntityToDomain(): LocalAiModel = with(this) { - LocalAiModel(id, name, size, sources) + LocalAiModel(id, LocalAiModel.Type.parse(type), name, size, sources) } //endregion diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index e88fcf47..ac47d61e 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -60,7 +60,16 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } - override var localDiffusionCustomModelPath: String + override var localMediaPipeCustomModelPath: String + get() = preferences.getString( + KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH, + LOCAL_DIFFUSION_CUSTOM_PATH + ) ?: LOCAL_DIFFUSION_CUSTOM_PATH + set(value) = preferences.edit() + .putString(KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH, value) + .apply() + + override var localOnnxCustomModelPath: String get() = preferences.getString( KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, LOCAL_DIFFUSION_CUSTOM_PATH, @@ -69,14 +78,14 @@ class PreferenceManagerImpl( .putString(KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, value) .apply() - override var localDiffusionAllowCancel: Boolean + override var localOnnxAllowCancel: Boolean get() = preferences.getBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, false) set(value) = preferences.edit() .putBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, value) .apply() .also { onPreferencesChanged() } - override var localDiffusionSchedulerThread: SchedulersToken + override var localOnnxSchedulerThread: SchedulersToken get() = preferences .getInt(KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD, SchedulersToken.COMPUTATION.ordinal) .let { SchedulersToken.entries[it] } @@ -196,20 +205,27 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } - override var localModelId: String + override var localOnnxModelId: String get() = preferences.getString(KEY_LOCAL_MODEL_ID, "") ?: "" set(value) = preferences.edit() .putString(KEY_LOCAL_MODEL_ID, value) .apply() .also { onPreferencesChanged() } - override var localUseNNAPI: Boolean + override var localOnnxUseNNAPI: Boolean get() = preferences.getBoolean(KEY_LOCAL_NN_API, false) set(value) = preferences.edit() .putBoolean(KEY_LOCAL_NN_API, value) .apply() .also { onPreferencesChanged() } + override var localMediaPipeModelId: String + get() = preferences.getString(KEY_MEDIA_PIPE_MODEL_ID, "") ?: "" + set(value) = preferences.edit() + .putString(KEY_MEDIA_PIPE_MODEL_ID, value) + .apply() + .also { onPreferencesChanged() } + override var designUseSystemColorPalette: Boolean get() = preferences.getBoolean(KEY_DESIGN_DYNAMIC_COLORS, false) set(value) = preferences.edit() @@ -273,8 +289,8 @@ class PreferenceManagerImpl( sdModel = sdModel, demoMode = demoMode, developerMode = developerMode, - localDiffusionAllowCancel = localDiffusionAllowCancel, - localDiffusionSchedulerThread = localDiffusionSchedulerThread, + localDiffusionAllowCancel = localOnnxAllowCancel, + localDiffusionSchedulerThread = localOnnxSchedulerThread, monitorConnectivity = monitorConnectivity, backgroundGeneration = backgroundGeneration, autoSaveAiResults = autoSaveAiResults, @@ -283,7 +299,7 @@ class PreferenceManagerImpl( formPromptTaggedInput = formPromptTaggedInput, source = source, hordeApiKey = hordeApiKey, - localUseNNAPI = localUseNNAPI, + localUseNNAPI = localOnnxUseNNAPI, designUseSystemColorPalette = designUseSystemColorPalette, designUseSystemDarkTheme = designUseSystemDarkTheme, designDarkTheme = designDarkTheme, @@ -302,6 +318,7 @@ class PreferenceManagerImpl( const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_DEVELOPER_MODE = "key_developer_mode" const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path" + const val KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH = "key_mediapipe_custom_model_path" const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel" const val KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD = "key_local_diffusion_scheduler_thread" const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" @@ -319,6 +336,7 @@ class PreferenceManagerImpl( const val KEY_STABILITY_AI_ENGINE_ID_KEY = "key_stability_ai_engine_id_key" const val KEY_ON_BOARDING_COMPLETE = "key_on_boarding_complete" const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.6.2" + const val KEY_MEDIA_PIPE_MODEL_ID = "key_mediapipe_model_id" const val KEY_LOCAL_MODEL_ID = "key_local_model_id" const val KEY_LOCAL_NN_API = "key_local_nn_api" const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors" diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt index a2e6b193..1a552534 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt @@ -5,10 +5,12 @@ import com.shifthackz.aisdv1.core.common.file.unzip import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single import java.io.File internal class DownloadableModelRemoteDataSource( @@ -16,12 +18,18 @@ internal class DownloadableModelRemoteDataSource( private val fileProviderDescriptor: FileProviderDescriptor, ) : DownloadableModelDataSource.Remote { - override fun fetch() = api - .fetchDownloadableModels() - .map(List::mapRawToCheckpointDomain) + override fun fetch(): Single> = Single.zip( + api + .fetchOnnxModels() + .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) }, + api + .fetchMediaPipeModels() + .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe) }, + ::Pair, + ) + .map { (onnx, mediapipe) -> listOf(onnx, mediapipe).flatten() } - override fun download(id: String, url: String): Observable = - Completable + override fun download(id: String, url: String): Observable = Completable .fromAction { val dir = File("${fileProviderDescriptor.localModelDirPath}/${id}") val destination = File(getDestinationPath(id)) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt index 5bde3662..6f84189c 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt @@ -1,15 +1,18 @@ package com.shifthackz.aisdv1.data.repository +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository +import io.reactivex.rxjava3.core.Single internal class DownloadableModelRepositoryImpl( private val remoteDataSource: DownloadableModelDataSource.Remote, private val localDataSource: DownloadableModelDataSource.Local, + private val buildInfoProvider: BuildInfoProvider, ) : DownloadableModelRepository { - override fun isModelDownloaded(id: String) = localDataSource.isDownloaded(id) - override fun download(id: String) = localDataSource .getById(id) .flatMapObservable { model -> @@ -18,15 +21,22 @@ internal class DownloadableModelRepositoryImpl( override fun delete(id: String) = localDataSource.delete(id) - override fun getAll() = remoteDataSource + override fun getAllOnnx() = remoteDataSource .fetch() .flatMapCompletable(localDataSource::save) - .andThen(localDataSource.getAll()) - .onErrorResumeNext { localDataSource.getAll() } - - override fun getById(id: String) = localDataSource.getById(id) + .andThen(localDataSource.getAllOnnx()) + .onErrorResumeNext { localDataSource.getAllOnnx() } - override fun observeAll() = localDataSource.observeAll() - - override fun select(id: String) = localDataSource.select(id) + override fun getAllMediaPipe(): Single> { + if (buildInfoProvider.type == BuildType.FOSS) { + return Single.just(emptyList()) + } + return remoteDataSource + .fetch() + .flatMapCompletable(localDataSource::save) + .andThen(localDataSource.getAllMediaPipe()) + .onErrorResumeNext { localDataSource.getAllMediaPipe() } + } + + override fun observeAllOnnx() = localDataSource.observeAllOnnx() } diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt index 0254b3d8..625a1ca8 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt @@ -36,7 +36,7 @@ internal class LocalDiffusionGenerationRepositoryImpl( override fun observeStatus() = localDiffusion.observeStatus() override fun generateFromText(payload: TextToImagePayload) = downloadableLocalDataSource - .getSelected() + .getSelectedOnnx() .flatMap { model -> if (model.downloaded) generate(payload) else Single.error(IllegalStateException("Model not downloaded.")) @@ -46,7 +46,7 @@ internal class LocalDiffusionGenerationRepositoryImpl( private fun generate(payload: TextToImagePayload) = localDiffusion .process(payload) - .subscribeOn(schedulersProvider.byToken(preferenceManager.localDiffusionSchedulerThread)) + .subscribeOn(schedulersProvider.byToken(preferenceManager.localOnnxSchedulerThread)) .map(BitmapToBase64Converter::Input) .flatMap(bitmapToBase64Converter::invoke) .map(BitmapToBase64Converter.Output::base64ImageString) diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt new file mode 100644 index 00000000..88471209 --- /dev/null +++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt @@ -0,0 +1,45 @@ +package com.shifthackz.aisdv1.data.repository + +import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider +import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter +import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter +import com.shifthackz.aisdv1.data.core.CoreGenerationRepository +import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult +import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver +import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository +import io.reactivex.rxjava3.core.Single +import io.reactivex.rxjava3.schedulers.Schedulers + +internal class MediaPipeGenerationRepositoryImpl( + mediaStoreGateway: MediaStoreGateway, + base64ToBitmapConverter: Base64ToBitmapConverter, + localDataSource: GenerationResultDataSource.Local, + backgroundWorkObserver: BackgroundWorkObserver, + preferenceManager: PreferenceManager, + private val schedulersProvider: SchedulersProvider, + private val mediaPipe: MediaPipe, + private val bitmapToBase64Converter: BitmapToBase64Converter, +) : CoreGenerationRepository( + mediaStoreGateway = mediaStoreGateway, + base64ToBitmapConverter = base64ToBitmapConverter, + localDataSource = localDataSource, + preferenceManager = preferenceManager, + backgroundWorkObserver = backgroundWorkObserver, +), MediaPipeGenerationRepository { + + override fun generateFromText(payload: TextToImagePayload): Single = mediaPipe + .process(payload) + .subscribeOn(schedulersProvider.singleThread.let(Schedulers::from)) + .map(BitmapToBase64Converter::Input) + .flatMap(bitmapToBase64Converter::invoke) + .map(BitmapToBase64Converter.Output::base64ImageString) + .map { base64 -> payload to base64 } + .map(Pair::mapLocalDiffusionToAiGenResult) + .flatMap(::insertGenerationResult) +} diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt index e1c2c830..90fe0cbe 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt @@ -13,8 +13,6 @@ import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity import io.mockk.every import io.mockk.mockk -import io.mockk.mockkConstructor -import io.mockk.mockkStatic import io.reactivex.rxjava3.core.BackpressureStrategy import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Flowable @@ -22,7 +20,6 @@ import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.subjects.BehaviorSubject import org.junit.Assert import org.junit.Test -import java.io.File class DownloadableModelLocalDataSourceTest { @@ -43,7 +40,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns models list, app build type is PLAY, expected valid domain models list`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(mockLocalModelEntities) every { @@ -51,7 +48,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.PLAY every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -61,7 +58,7 @@ class DownloadableModelLocalDataSourceTest { val expected = mockLocalModelEntities.mapEntityToDomain() localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue { actual -> @@ -74,7 +71,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns empty models list, app build type is PLAY, expected empty domain models list`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(emptyList()) every { @@ -82,11 +79,11 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.PLAY every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(emptyList()) @@ -97,7 +94,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns models list, app build type is FOSS, expected valid domain models list with CUSTOM model included`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(mockLocalModelEntities) every { @@ -105,7 +102,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.FOSS every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -114,11 +111,11 @@ class DownloadableModelLocalDataSourceTest { val expected = buildList { addAll(mockLocalModelEntities.mapEntityToDomain()) - add(LocalAiModel.CUSTOM.copy(downloaded = true)) + add(LocalAiModel.CustomOnnx.copy(downloaded = true)) } localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue { actual -> @@ -131,7 +128,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao returns empty models list, app build type is FOSS, expected domain models list with only CUSTOM model included`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.just(emptyList()) every { @@ -139,14 +136,14 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.FOSS every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" localDataSource - .getAll() + .getAllOnnx() .test() .assertNoErrors() - .assertValue(listOf(LocalAiModel.CUSTOM.copy(downloaded = true))) + .assertValue(listOf(LocalAiModel.CustomOnnx.copy(downloaded = true))) .await() .assertComplete() } @@ -154,11 +151,11 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get all models, dao throws exception, expected error value`() { every { - stubDao.query() + stubDao.queryByType(any()) } returns Single.error(stubException) localDataSource - .getAll() + .getAllOnnx() .test() .assertError(stubException) .assertNoValues() @@ -173,7 +170,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.just(mockLocalModelEntity) every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -198,7 +195,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.just(mockLocalModelEntity) every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "5598" every { @@ -238,7 +235,7 @@ class DownloadableModelLocalDataSourceTest { } returns Single.just(mockLocalModelEntity) every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "5598" every { @@ -248,7 +245,7 @@ class DownloadableModelLocalDataSourceTest { val expected = mockLocalModelEntity.mapEntityToDomain().copy(selected = true) localDataSource - .getSelected() + .getSelectedOnnx() .test() .assertNoErrors() .assertValue(expected) @@ -259,11 +256,11 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to get selected model, preference throws exception, expected error value`() { every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" localDataSource - .getSelected() + .getSelectedOnnx() .test() .assertError { t -> t is IllegalStateException && t.message == "No selected model." @@ -276,7 +273,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is PLAY, expected empty list, then domain list with two items`() { every { - stubDao.observe() + stubDao.observeByType(any()) } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) every { @@ -284,7 +281,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.PLAY every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -292,7 +289,7 @@ class DownloadableModelLocalDataSourceTest { } returns "/tmp/local" val stubObserver = localDataSource - .observeAll() + .observeAllOnnx() .test() stubLocalModels.onNext(emptyList()) @@ -311,7 +308,7 @@ class DownloadableModelLocalDataSourceTest { @Test fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is FOSS, expected list with only CUSTOM model included, then domain list with two items and CUSTOM`() { every { - stubDao.observe() + stubDao.observeByType(any()) } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) every { @@ -319,7 +316,7 @@ class DownloadableModelLocalDataSourceTest { } returns BuildType.FOSS every { - stubPreferenceManager.localModelId + stubPreferenceManager.localOnnxModelId } returns "" every { @@ -327,14 +324,14 @@ class DownloadableModelLocalDataSourceTest { } returns "/tmp/local" val stubObserver = localDataSource - .observeAll() + .observeAllOnnx() .test() stubLocalModels.onNext(emptyList()) stubObserver .assertNoErrors() - .assertValueAt(0, listOf(LocalAiModel.CUSTOM.copy(downloaded = true))) + .assertValueAt(0, listOf(LocalAiModel.CustomOnnx.copy(downloaded = true))) stubLocalModels.onNext(mockLocalModelEntities) @@ -342,18 +339,18 @@ class DownloadableModelLocalDataSourceTest { .assertNoErrors() .assertValueAt(1, buildList { addAll(mockLocalModelEntities.mapEntityToDomain()) - add(LocalAiModel.CUSTOM.copy(downloaded = true)) + add(LocalAiModel.CustomOnnx.copy(downloaded = true)) }) } @Test fun `given attempt to observe all models, dao throws exception, expected error value`() { every { - stubDao.observe() + stubDao.observeByType(any()) } returns Flowable.error(stubException) localDataSource - .observeAll() + .observeAllOnnx() .test() .assertError(stubException) .assertNoValues() @@ -361,44 +358,6 @@ class DownloadableModelLocalDataSourceTest { .assertNotComplete() } - @Test - fun `given attempt to select model, preference changed, expected preference returns changed selected model id value`() { - every { - stubPreferenceManager.localModelId - } returns "" - - every { - stubPreferenceManager::localModelId.set(any()) - } returns Unit - - localDataSource - .select("5598") - .test() - .assertNoErrors() - .await() - .assertComplete() - - every { - stubPreferenceManager.localModelId - } returns "5598" - - Assert.assertEquals("5598", stubPreferenceManager.localModelId) - } - - @Test - fun `given attempt to select model, preference throws exception, expected error value`() { - every { - stubPreferenceManager::localModelId.set(any()) - } throws stubException - - localDataSource - .select("5598") - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - @Test fun `given attempt to save local model list, dao insert success, expected complete value`() { every { @@ -427,8 +386,6 @@ class DownloadableModelLocalDataSourceTest { .assertNotComplete() } - //-- - @Test fun `given attempt to delete file, delete operation success, expected complete value`() { every { @@ -456,15 +413,4 @@ class DownloadableModelLocalDataSourceTest { .await() .assertNotComplete() } - - @Test - fun `given attempt to check if CUSTOM model is downloaded, expected true`() { - localDataSource - .isDownloaded(LocalAiModel.CUSTOM.id) - .test() - .assertNoErrors() - .assertValue(true) - .await() - .assertComplete() - } } diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt index 6eee82bc..9aaa20a8 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel val mockLocalAiModel = LocalAiModel( id = "5598", + type = LocalAiModel.Type.ONNX, name = "Model 5598", size = "5 Gb", sources = listOf("https://example.com/1.html"), diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt index 33b45cd9..7beaa4be 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity val mockLocalModelEntity = LocalModelEntity( id = "5598", + type = "onnx", name = "Best model in entire universe", size = "5598 Gb", sources = listOf("https://5598.is.my.favourite.com"), @@ -12,6 +13,7 @@ val mockLocalModelEntity = LocalModelEntity( val mockLocalModelEntities = listOf( LocalModelEntity( id = "1", + type = "onnx", name = "Model 1", size = "1 Gb", sources = listOf("https://example.com/1.php"), diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt index 0535d7f8..20b1f5b3 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt @@ -223,17 +223,17 @@ class PreferenceManagerImplTest { Assert.assertEquals(ServerSource.AUTOMATIC1111, preferenceManager.source) whenever(stubPreference.getString(eq(KEY_SERVER_SOURCE), any())) - .thenReturn(ServerSource.LOCAL.key) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX.key) - preferenceManager.source = ServerSource.LOCAL + preferenceManager.source = ServerSource.LOCAL_MICROSOFT_ONNX - Assert.assertEquals(ServerSource.LOCAL, preferenceManager.source) + Assert.assertEquals(ServerSource.LOCAL_MICROSOFT_ONNX, preferenceManager.source) preferenceManager .observe() .test() .assertNoErrors() - .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL } + .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL_MICROSOFT_ONNX } } @Test @@ -373,14 +373,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) .thenReturn("") - Assert.assertEquals("", preferenceManager.localModelId) + Assert.assertEquals("", preferenceManager.localOnnxModelId) whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any())) .thenReturn("key") - preferenceManager.localModelId = "key" + preferenceManager.localOnnxModelId = "key" - Assert.assertEquals("key", preferenceManager.localModelId) + Assert.assertEquals("key", preferenceManager.localOnnxModelId) } @Test @@ -388,14 +388,14 @@ class PreferenceManagerImplTest { whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any())) .thenReturn(false) - Assert.assertEquals(false, preferenceManager.localUseNNAPI) + Assert.assertEquals(false, preferenceManager.localOnnxUseNNAPI) whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any())) .thenReturn(true) - preferenceManager.localUseNNAPI = true + preferenceManager.localOnnxUseNNAPI = true - Assert.assertEquals(true, preferenceManager.localUseNNAPI) + Assert.assertEquals(true, preferenceManager.localOnnxUseNNAPI) preferenceManager .observe() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt index c97b0a9e..775fd5a6 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt @@ -5,6 +5,7 @@ import com.nhaarman.mockitokotlin2.whenever import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain import com.shifthackz.aisdv1.data.mocks.mockDownloadableModelsResponse +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi import io.reactivex.rxjava3.core.Single import org.junit.Test @@ -22,10 +23,16 @@ class DownloadableModelRemoteDataSourceTest { @Test fun `given attempt to fetch models list, api returns data, expected valid domain models list`() { - whenever(stubApi.fetchDownloadableModels()) + whenever(stubApi.fetchOnnxModels()) .thenReturn(Single.just(mockDownloadableModelsResponse)) - val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain() + whenever(stubApi.fetchMediaPipeModels()) + .thenReturn(Single.just(mockDownloadableModelsResponse)) + + val expected = listOf( + mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX), + mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe), + ).flatten() remoteDataSource .fetch() @@ -38,7 +45,10 @@ class DownloadableModelRemoteDataSourceTest { @Test fun `given attempt to fetch models list, api returns empty data, expected empty domain models list`() { - whenever(stubApi.fetchDownloadableModels()) + whenever(stubApi.fetchOnnxModels()) + .thenReturn(Single.just(emptyList())) + + whenever(stubApi.fetchMediaPipeModels()) .thenReturn(Single.just(emptyList())) remoteDataSource @@ -52,7 +62,10 @@ class DownloadableModelRemoteDataSourceTest { @Test fun `given attempt to fetch models list, api returns error, expected error value`() { - whenever(stubApi.fetchDownloadableModels()) + whenever(stubApi.fetchOnnxModels()) + .thenReturn(Single.error(stubException)) + + whenever(stubApi.fetchMediaPipeModels()) .thenReturn(Single.error(stubException)) remoteDataSource diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt index 2daf1b5f..379bc3c5 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.data.repository +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.data.mocks.mockLocalAiModel import com.shifthackz.aisdv1.data.mocks.mockLocalAiModels import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource @@ -25,16 +27,22 @@ class DownloadableModelRepositoryImplTest { private val stubDownloadState = BehaviorSubject.create() private val stubRemoteDataSource = mockk() private val stubLocalDataSource = mockk() + private val stubBuildInfoProvider = mockk() private val repository = DownloadableModelRepositoryImpl( remoteDataSource = stubRemoteDataSource, localDataSource = stubLocalDataSource, + buildInfoProvider = stubBuildInfoProvider, ) @Before fun initialize() { every { - stubLocalDataSource.observeAll() + stubBuildInfoProvider.type + } returns BuildType.FULL + + every { + stubLocalDataSource.observeAllOnnx() } returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST) every { @@ -42,51 +50,6 @@ class DownloadableModelRepositoryImplTest { } returns stubDownloadState } - @Test - fun `given attempt to check if model downloaded, local data source returns true, expected true value`() { - every { - stubLocalDataSource.isDownloaded(any()) - } returns Single.just(true) - - repository - .isModelDownloaded("5598") - .test() - .assertNoErrors() - .assertValue(true) - .await() - .assertComplete() - } - - @Test - fun `given attempt to check if model downloaded, local data source returns false, expected false value`() { - every { - stubLocalDataSource.isDownloaded(any()) - } returns Single.just(false) - - repository - .isModelDownloaded("5598") - .test() - .assertNoErrors() - .assertValue(false) - .await() - .assertComplete() - } - - @Test - fun `given attempt to check if model downloaded, local data source throws exception, expected error value`() { - every { - stubLocalDataSource.isDownloaded(any()) - } returns Single.error(stubException) - - repository - .isModelDownloaded("5598") - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } - @Test fun `given attempt to delete model, local data source completes, expected complete value`() { every { @@ -115,34 +78,6 @@ class DownloadableModelRepositoryImplTest { .assertNotComplete() } - @Test - fun `given attempt to select model, local data source completes, expected complete value`() { - every { - stubLocalDataSource.select(any()) - } returns Completable.complete() - - repository - .select("5598") - .test() - .assertNoErrors() - .await() - .assertComplete() - } - - @Test - fun `given attempt to select model, local data source throws exception, expected error value`() { - every { - stubLocalDataSource.select(any()) - } returns Completable.error(stubException) - - repository - .select("5598") - .test() - .assertError(stubException) - .await() - .assertNotComplete() - } - @Test fun `given attempt to get all, remote returns list, save success, local query success, expected valid domain model list value`() { every { @@ -154,11 +89,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.complete() every { - stubLocalDataSource.getAll() + stubLocalDataSource.getAllOnnx() } returns Single.just(mockLocalAiModels) repository - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(mockLocalAiModels) @@ -177,11 +112,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.error(stubException) every { - stubLocalDataSource.getAll() + stubLocalDataSource.getAllOnnx() } returns Single.just(mockLocalAiModels) repository - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(mockLocalAiModels) @@ -200,11 +135,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.complete() every { - stubLocalDataSource.getAll() + stubLocalDataSource.getAllOnnx() } returns Single.just(mockLocalAiModels) repository - .getAll() + .getAllOnnx() .test() .assertNoErrors() .assertValue(mockLocalAiModels) @@ -223,41 +158,11 @@ class DownloadableModelRepositoryImplTest { } returns Completable.complete() every { - stubLocalDataSource.getAll() - } returns Single.error(stubException) - - repository - .getAll() - .test() - .assertError(stubException) - .assertNoValues() - .await() - .assertNotComplete() - } - - @Test - fun `given attempt to get by id, local data source returns data, expected valid domain model value`() { - every { - stubLocalDataSource.getById(any()) - } returns Single.just(mockLocalAiModel) - - repository - .getById("5598") - .test() - .assertNoErrors() - .assertValue(mockLocalAiModel) - .await() - .assertComplete() - } - - @Test - fun `given attempt to get by id, local data source fails, expected error value`() { - every { - stubLocalDataSource.getById(any()) + stubLocalDataSource.getAllOnnx() } returns Single.error(stubException) repository - .getById("5598") + .getAllOnnx() .test() .assertError(stubException) .assertNoValues() @@ -267,7 +172,7 @@ class DownloadableModelRepositoryImplTest { @Test fun `given observe all models, local data source emits empty list, then another list, expected empty value, then valid domain models list value`() { - val stubObserver = repository.observeAll().test() + val stubObserver = repository.observeAllOnnx().test() stubLocalModels.onNext(emptyList()) @@ -284,7 +189,7 @@ class DownloadableModelRepositoryImplTest { @Test fun `given observe all models, local data source emits list, then changed list, expected valid domain models list value, then changed value`() { - val stubObserver = repository.observeAll().test() + val stubObserver = repository.observeAllOnnx().test() stubLocalModels.onNext(mockLocalAiModels) @@ -302,11 +207,11 @@ class DownloadableModelRepositoryImplTest { @Test fun `given observe all models, local data source throws exception, expected error value`() { every { - stubLocalDataSource.observeAll() + stubLocalDataSource.observeAllOnnx() } returns Flowable.error(stubException) repository - .observeAll() + .observeAllOnnx() .test() .assertError(stubException) .assertNoValues() diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt index 75df885f..126edac1 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway @@ -31,7 +32,7 @@ class LocalDiffusionGenerationRepositoryImplTest { private val stubBitmap = mockk() private val stubException = Throwable("Something went wrong.") - private val stubStatus = BehaviorSubject.create() + private val stubStatus = BehaviorSubject.create() private val stubMediaStoreGateway = mockk() private val stubBase64ToBitmapConverter = mockk() private val stubBitmapToBase64Converter = mockk() @@ -63,7 +64,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Before fun initialize() { every { - stubPreferenceManager::localDiffusionSchedulerThread.get() + stubPreferenceManager::localOnnxSchedulerThread.get() } returns SchedulersToken.COMPUTATION every { @@ -83,17 +84,17 @@ class LocalDiffusionGenerationRepositoryImplTest { fun `given attempt to observe status, local emits two values, expected same values with same order`() { val stubObserver = repository.observeStatus().test() - stubStatus.onNext(LocalDiffusion.Status(1, 2)) + stubStatus.onNext(LocalDiffusionStatus(1, 2)) stubObserver .assertNoErrors() - .assertValueAt(0, LocalDiffusion.Status(1, 2)) + .assertValueAt(0, LocalDiffusionStatus(1, 2)) - stubStatus.onNext(LocalDiffusion.Status(2, 2)) + stubStatus.onNext(LocalDiffusionStatus(2, 2)) stubObserver .assertNoErrors() - .assertValueAt(1, LocalDiffusion.Status(2, 2)) + .assertValueAt(1, LocalDiffusionStatus(2, 2)) } @Test @@ -142,7 +143,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, no selected model, expected error value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.error(stubException) repository @@ -157,7 +158,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, has selected not downloaded model, expected IllegalStateException error value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.just(mockLocalAiModel.copy(downloaded = false)) every { @@ -182,7 +183,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, has selected downloaded model, local process success, expected valid domain model value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.just(mockLocalAiModel.copy(downloaded = true)) every { @@ -205,7 +206,7 @@ class LocalDiffusionGenerationRepositoryImplTest { @Test fun `given attempt to generate from text, has selected downloaded model, local process fails, expected error value`() { every { - stubDownloadableLocalDataSource.getSelected() + stubDownloadableLocalDataSource.getSelectedOnnx() } returns Single.just(mockLocalAiModel.copy(downloaded = true)) every { diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt index 3cc310df..7ef7f17a 100644 --- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt +++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt @@ -40,7 +40,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to fetch, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubRemoteDataSource.fetch() @@ -62,7 +62,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to fetch and get, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubRemoteDataSource.fetch() @@ -88,7 +88,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to fetch and observe, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubRemoteDataSource.fetch() @@ -110,7 +110,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to get, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX every { stubLocalDataSource.get() @@ -128,7 +128,7 @@ class StabilityAiCreditsRepositoryImplTest { fun `given server source is not STABILITY_AI, attempt to observe, expected IllegalStateException error value`() { every { stubPreferenceManager.source - } returns ServerSource.LOCAL + } returns ServerSource.LOCAL_MICROSOFT_ONNX repository .observe() diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt index b2320323..cb56c621 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt @@ -15,13 +15,12 @@ sealed interface DownloadableModelDataSource { } interface Local : DownloadableModelDataSource { - fun getAll(): Single> + fun getAllOnnx(): Single> + fun getAllMediaPipe(): Single> fun getById(id: String): Single - fun getSelected(): Single - fun observeAll(): Flowable> - fun select(id: String): Completable + fun getSelectedOnnx(): Single + fun observeAllOnnx(): Flowable> fun save(list: List): Completable - fun isDownloaded(id: String): Single fun delete(id: String): Completable } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt index bccb6ebd..5bdaf74e 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt @@ -36,10 +36,12 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteAllGalleryUseCase import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteAllGalleryUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase @@ -92,6 +94,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCaseImpl +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCaseImpl import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase @@ -155,15 +159,17 @@ internal val useCasesModule = module { factoryOf(::SaveLastResultToCacheUseCaseImpl) bind SaveLastResultToCacheUseCase::class factoryOf(::GetLastResultFromCacheUseCaseImpl) bind GetLastResultFromCacheUseCase::class factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class - factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class + factoryOf(::GetLocalOnnxModelsUseCaseImpl) bind GetLocalOnnxModelsUseCase::class + factoryOf(::GetLocalMediaPipeModelsUseCaseImpl) bind GetLocalMediaPipeModelsUseCase::class factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class - factoryOf(::ObserveLocalAiModelsUseCaseImpl) bind ObserveLocalAiModelsUseCase::class + factoryOf(::ObserveLocalOnnxModelsUseCaseImpl) bind ObserveLocalOnnxModelsUseCase::class factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class factoryOf(::InterruptGenerationUseCaseImpl) bind InterruptGenerationUseCase::class factoryOf(::ConnectToHordeUseCaseImpl) bind ConnectToHordeUseCase::class factoryOf(::ConnectToLocalDiffusionUseCaseImpl) bind ConnectToLocalDiffusionUseCase::class + factoryOf(::ConnectToMediaPipeUseCaseImpl) bind ConnectToMediaPipeUseCase::class factoryOf(::ConnectToA1111UseCaseImpl) bind ConnectToA1111UseCase::class factoryOf(::ConnectToSwarmUiUseCaseImpl) bind ConnectToSwarmUiUseCase::class factoryOf(::ConnectToHuggingFaceUseCaseImpl) bind ConnectToHuggingFaceUseCase::class diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index fcf04c19..40c5e10a 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -15,6 +15,8 @@ data class Configuration( val stabilityAiApiKey: String = "", val stabilityAiEngineId: String = "", val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None, - val localModelId: String = "", - val localModelPath: String = "", + val localOnnxModelId: String = "", + val localOnnxModelPath: String = "", + val localMediaPipeModelId: String = "", + val localMediaPipeModelPath: String = "", ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt index 734ad5ed..b7958512 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt @@ -2,15 +2,34 @@ package com.shifthackz.aisdv1.domain.entity data class LocalAiModel( val id: String, + val type: Type, val name: String, val size: String, val sources: List, val downloaded: Boolean = false, val selected: Boolean = false, ) { + enum class Type(val key: String) { + ONNX("onnx"), + MediaPipe("mediapipe"); + + companion object { + fun parse(value: String?) = entries.find { it.key == value } ?: ONNX + } + } + companion object { - val CUSTOM = LocalAiModel( + val CustomOnnx = LocalAiModel( id = "CUSTOM", + type = Type.ONNX, + name = "Custom", + size = "NaN", + sources = emptyList(), + ) + + val CustomMediaPipe = LocalAiModel( + id = "CUSTOM_MP", + type = Type.MediaPipe, name = "Custom", size = "NaN", sources = emptyList(), diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt new file mode 100644 index 00000000..779d8bc6 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt @@ -0,0 +1,6 @@ +package com.shifthackz.aisdv1.domain.entity + +data class LocalDiffusionStatus( + val current: Int, + val total: Int, +) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt index 4eeca344..15244e1d 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt @@ -1,8 +1,11 @@ package com.shifthackz.aisdv1.domain.entity +import com.shifthackz.aisdv1.core.common.appbuild.BuildType + enum class ServerSource( val key: String, val featureTags: Set, + val allowedInBuilds: Set = setOf(BuildType.FOSS, BuildType.PLAY, BuildType.FULL), ) { AUTOMATIC1111( key = "custom", @@ -62,13 +65,22 @@ enum class ServerSource( FeatureTag.Batch, ), ), - LOCAL( + LOCAL_MICROSOFT_ONNX( key = "local", featureTags = setOf( FeatureTag.Offline, FeatureTag.Txt2Img, FeatureTag.MultipleModels, ), + ), + LOCAL_GOOGLE_MEDIA_PIPE( + key = "local_google_media_pipe", + featureTags = setOf( + FeatureTag.Offline, + FeatureTag.Txt2Img, + FeatureTag.MultipleModels, + ), + allowedInBuilds = setOf(BuildType.PLAY, BuildType.FULL), ); companion object { diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt index d1f260e7..afcdefd0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.domain.feature.diffusion import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable @@ -9,7 +10,5 @@ import io.reactivex.rxjava3.core.Single interface LocalDiffusion { fun process(payload: TextToImagePayload): Single fun interrupt(): Completable - fun observeStatus(): Observable - - data class Status(val current: Int, val total: Int) + fun observeStatus(): Observable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt new file mode 100644 index 00000000..d27e0502 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.domain.feature.mediapipe + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Single + +interface MediaPipe { + fun process(payload: TextToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt index 4c42f81c..5959a035 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase @@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase interface SetupConnectionInterActor { val connectToHorde: ConnectToHordeUseCase val connectToLocal: ConnectToLocalDiffusionUseCase + val connectToMediaPipe: ConnectToMediaPipeUseCase val connectToA1111: ConnectToA1111UseCase val connectToHuggingFace: ConnectToHuggingFaceUseCase val connectToOpenAi: ConnectToOpenAiUseCase diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt index 05517631..306da094 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt @@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase +import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase @@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase internal data class SetupConnectionInterActorImpl( override val connectToHorde: ConnectToHordeUseCase, override val connectToLocal: ConnectToLocalDiffusionUseCase, + override val connectToMediaPipe: ConnectToMediaPipeUseCase, override val connectToA1111: ConnectToA1111UseCase, override val connectToHuggingFace: ConnectToHuggingFaceUseCase, override val connectToOpenAi: ConnectToOpenAiUseCase, diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index 1ee8950b..42c32959 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -12,9 +12,10 @@ interface PreferenceManager { var swarmUiModel: String var demoMode: Boolean var developerMode: Boolean - var localDiffusionCustomModelPath: String - var localDiffusionAllowCancel: Boolean - var localDiffusionSchedulerThread: SchedulersToken + var localMediaPipeCustomModelPath: String + var localOnnxCustomModelPath: String + var localOnnxAllowCancel: Boolean + var localOnnxSchedulerThread: SchedulersToken var monitorConnectivity: Boolean var autoSaveAiResults: Boolean var saveToMediaStore: Boolean @@ -30,8 +31,9 @@ interface PreferenceManager { var stabilityAiEngineId: String var onBoardingComplete: Boolean var forceSetupAfterUpdate: Boolean - var localModelId: String - var localUseNNAPI: Boolean + var localOnnxModelId: String + var localOnnxUseNNAPI: Boolean + var localMediaPipeModelId: String var designUseSystemColorPalette: Boolean var designUseSystemDarkTheme: Boolean var designDarkTheme: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt index dcab5955..79efed66 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt @@ -8,11 +8,9 @@ import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single interface DownloadableModelRepository { - fun isModelDownloaded(id: String): Single fun download(id: String): Observable fun delete(id: String): Completable - fun getAll(): Single> - fun getById(id: String): Single - fun observeAll(): Flowable> - fun select(id: String): Completable + fun getAllOnnx(): Single> + fun getAllMediaPipe(): Single> + fun observeAllOnnx(): Flowable> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt index 0e109843..3ed5c370 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt @@ -1,14 +1,14 @@ package com.shifthackz.aisdv1.domain.repository import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import io.reactivex.rxjava3.core.Completable import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single interface LocalDiffusionGenerationRepository { - fun observeStatus(): Observable + fun observeStatus(): Observable fun generateFromText(payload: TextToImagePayload): Single fun interruptGeneration(): Completable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt new file mode 100644 index 00000000..a7b65fb5 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.domain.repository + +import com.shifthackz.aisdv1.domain.entity.AiGenerationResult +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import io.reactivex.rxjava3.core.Observable +import io.reactivex.rxjava3.core.Single + +interface MediaPipeGenerationRepository { + fun generateFromText(payload: TextToImagePayload): Single +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt new file mode 100644 index 00000000..cd93801f --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt @@ -0,0 +1,8 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import io.reactivex.rxjava3.core.Single + +interface GetLocalMediaPipeModelsUseCase { + operator fun invoke(): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt new file mode 100644 index 00000000..4da8eeba --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt @@ -0,0 +1,10 @@ +package com.shifthackz.aisdv1.domain.usecase.downloadable + +import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository + +internal class GetLocalMediaPipeModelsUseCaseImpl( + private val downloadableModelRepository: DownloadableModelRepository, + ) : GetLocalMediaPipeModelsUseCase { + + override fun invoke() = downloadableModelRepository.getAllMediaPipe() +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt similarity index 84% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt index efe71374..f2cce297 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt @@ -3,6 +3,6 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Single -interface GetLocalAiModelsUseCase { +interface GetLocalOnnxModelsUseCase { operator fun invoke(): Single> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt similarity index 59% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt index 7bdbe7e7..444eccca 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt @@ -2,9 +2,9 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository -internal class GetLocalAiModelsUseCaseImpl( +internal class GetLocalOnnxModelsUseCaseImpl( private val downloadableModelRepository: DownloadableModelRepository, -) : GetLocalAiModelsUseCase { +) : GetLocalOnnxModelsUseCase { - override fun invoke() = downloadableModelRepository.getAll() + override fun invoke() = downloadableModelRepository.getAllOnnx() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt similarity index 83% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt index f79d31b1..021ebd1e 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt @@ -3,6 +3,6 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.entity.LocalAiModel import io.reactivex.rxjava3.core.Flowable -interface ObserveLocalAiModelsUseCase { +interface ObserveLocalOnnxModelsUseCase { operator fun invoke(): Flowable> } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt similarity index 70% rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt index 5fae54e1..e5e290f8 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt @@ -2,11 +2,11 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository -internal class ObserveLocalAiModelsUseCaseImpl( +internal class ObserveLocalOnnxModelsUseCaseImpl( private val repository: DownloadableModelRepository, -) : ObserveLocalAiModelsUseCase { +) : ObserveLocalOnnxModelsUseCase { override fun invoke() = repository - .observeAll() + .observeAllOnnx() .distinctUntilChanged() } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt index 0f769197..37f6720c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt @@ -17,7 +17,7 @@ internal class InterruptGenerationUseCaseImpl( override fun invoke() = when (preferenceManager.source) { ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.interruptGeneration() ServerSource.HORDE -> hordeGenerationRepository.interruptGeneration() - ServerSource.LOCAL -> localDiffusionGenerationRepository.interruptGeneration() + ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.interruptGeneration() else -> Completable.complete() } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt index 3a65f91b..9530b6b0 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt @@ -1,8 +1,8 @@ package com.shifthackz.aisdv1.domain.usecase.generation -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import io.reactivex.rxjava3.core.Observable interface ObserveLocalDiffusionProcessStatusUseCase { - operator fun invoke(): Observable + operator fun invoke(): Observable } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt index 45be1d30..f823a153 100755 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt @@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository @@ -22,6 +23,7 @@ internal class TextToImageUseCaseImpl( private val stabilityAiGenerationRepository: StabilityAiGenerationRepository, private val swarmUiGenerationRepository: SwarmUiGenerationRepository, private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository, + private val mediaPipeGenerationRepository: MediaPipeGenerationRepository, private val preferenceManager: PreferenceManager, ) : TextToImageUseCase { @@ -34,11 +36,12 @@ internal class TextToImageUseCaseImpl( private fun generate(payload: TextToImagePayload) = when (preferenceManager.source) { ServerSource.HORDE -> hordeGenerationRepository.generateFromText(payload) - ServerSource.LOCAL -> localDiffusionGenerationRepository.generateFromText(payload) + ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.generateFromText(payload) ServerSource.HUGGING_FACE -> huggingFaceGenerationRepository.generateFromText(payload) ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload) ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload) ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromText(payload) ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromText(payload) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> mediaPipeGenerationRepository.generateFromText(payload) } } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt index 0bbed0c9..d2519425 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt @@ -11,8 +11,8 @@ internal class ConnectToLocalDiffusionUseCaseImpl( override fun invoke(modelId: String) = getConfigurationUseCase() .map { originalConfiguration -> originalConfiguration.copy( - source = ServerSource.LOCAL, - localModelId = modelId, + source = ServerSource.LOCAL_MICROSOFT_ONNX, + localOnnxModelId = modelId, ) } .flatMapCompletable(setServerConfigurationUseCase::invoke) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt new file mode 100644 index 00000000..5c1027e6 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt @@ -0,0 +1,7 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import io.reactivex.rxjava3.core.Single + +interface ConnectToMediaPipeUseCase { + operator fun invoke(modelId: String): Single> +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt new file mode 100644 index 00000000..a60bdc68 --- /dev/null +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt @@ -0,0 +1,21 @@ +package com.shifthackz.aisdv1.domain.usecase.settings + +import com.shifthackz.aisdv1.domain.entity.ServerSource +import io.reactivex.rxjava3.core.Single + +internal class ConnectToMediaPipeUseCaseImpl( + private val getConfigurationUseCase: GetConfigurationUseCase, + private val setServerConfigurationUseCase: SetServerConfigurationUseCase, +) : ConnectToMediaPipeUseCase { + + override fun invoke(modelId: String): Single> = getConfigurationUseCase() + .map { originalConfiguration -> + originalConfiguration.copy( + source = ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + localMediaPipeModelId = modelId, + ) + } + .flatMapCompletable(setServerConfigurationUseCase::invoke) + .andThen(Single.just(Result.success(Unit))) + .onErrorResumeNext { t -> Single.just(Result.failure(t)) } +} diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index 4070a0a6..e88ebd63 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -24,8 +24,10 @@ internal class GetConfigurationUseCaseImpl( stabilityAiApiKey = preferenceManager.stabilityAiApiKey, stabilityAiEngineId = preferenceManager.stabilityAiEngineId, authCredentials = authorizationStore.getAuthorizationCredentials(), - localModelId = preferenceManager.localModelId, - localModelPath = preferenceManager.localDiffusionCustomModelPath, + localOnnxModelId = preferenceManager.localOnnxModelId, + localOnnxModelPath = preferenceManager.localOnnxCustomModelPath, + localMediaPipeModelId = preferenceManager.localMediaPipeModelId, + localMediaPipeModelPath = preferenceManager.localMediaPipeCustomModelPath, ) ) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index 8d03619c..aab29abc 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -24,7 +24,9 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.huggingFaceModel = configuration.huggingFaceModel preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId - preferenceManager.localModelId = configuration.localModelId - preferenceManager.localDiffusionCustomModelPath = configuration.localModelPath + preferenceManager.localOnnxModelId = configuration.localOnnxModelId + preferenceManager.localOnnxCustomModelPath = configuration.localOnnxModelPath + preferenceManager.localMediaPipeModelId = configuration.localMediaPipeModelId + preferenceManager.localMediaPipeCustomModelPath = configuration.localMediaPipeModelPath } } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt index 73503575..da1d486a 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -12,6 +12,6 @@ val mockConfiguration = Configuration( huggingFaceModel = "5598", stabilityAiApiKey = "5598", stabilityAiEngineId = "5598", - localModelId = "5598", - localModelPath = "/storage/emulated/0/5598", + localOnnxModelId = "5598", + localOnnxModelPath = "/storage/emulated/0/5598", ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt index 69fcf3f0..5ce3fa3b 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt @@ -3,9 +3,10 @@ package com.shifthackz.aisdv1.domain.mocks import com.shifthackz.aisdv1.domain.entity.LocalAiModel val mockLocalAiModels = listOf( - LocalAiModel.CUSTOM, + LocalAiModel.CustomOnnx, LocalAiModel( id = "1", + type = LocalAiModel.Type.ONNX, name = "Model 1", size = "5 Gb", sources = listOf("https://example.com/1.html"), diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt similarity index 84% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt index 679ca297..d154c163 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt @@ -7,15 +7,15 @@ import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository import io.reactivex.rxjava3.core.Single import org.junit.Test -class GetLocalAiModelsUseCaseImplTest { +class GetLocalOnnxModelsUseCaseImplTest { private val stubRepository = mock() - private val useCase = GetLocalAiModelsUseCaseImpl(stubRepository) + private val useCase = GetLocalOnnxModelsUseCaseImpl(stubRepository) @Test fun `given repository returned models list, expected valid models list value`() { - whenever(stubRepository.getAll()) + whenever(stubRepository.getAllOnnx()) .thenReturn(Single.just(mockLocalAiModels)) useCase() @@ -28,7 +28,7 @@ class GetLocalAiModelsUseCaseImplTest { @Test fun `given repository returned empty models list, expected empty models list value`() { - whenever(stubRepository.getAll()) + whenever(stubRepository.getAllOnnx()) .thenReturn(Single.just(emptyList())) useCase() @@ -43,7 +43,7 @@ class GetLocalAiModelsUseCaseImplTest { fun `given repository thrown exception, expected error value`() { val stubException = Throwable("Unable to collect local models.") - whenever(stubRepository.getAll()) + whenever(stubRepository.getAllOnnx()) .thenReturn(Single.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt similarity index 91% rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt index 00fde27a..a61684b1 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt @@ -11,16 +11,16 @@ import io.reactivex.rxjava3.subjects.BehaviorSubject import org.junit.Before import org.junit.Test -class ObserveLocalAiModelsUseCaseImplTest { +class ObserveLocalOnnxModelsUseCaseImplTest { private val stubLocalModels = BehaviorSubject.create>() private val stubRepository = mock() - private val useCase = ObserveLocalAiModelsUseCaseImpl(stubRepository) + private val useCase = ObserveLocalOnnxModelsUseCaseImpl(stubRepository) @Before fun initialize() { - whenever(stubRepository.observeAll()) + whenever(stubRepository.observeAllOnnx()) .thenReturn(stubLocalModels.toFlowable(BackpressureStrategy.LATEST)) } @@ -68,7 +68,7 @@ class ObserveLocalAiModelsUseCaseImplTest { .assertNoErrors() .assertValueAt(0, mockLocalAiModels) - val changedLocalAiModels = listOf(LocalAiModel.CUSTOM) + val changedLocalAiModels = listOf(LocalAiModel.CustomOnnx) stubLocalModels.onNext(changedLocalAiModels) stubObserver @@ -97,7 +97,7 @@ class ObserveLocalAiModelsUseCaseImplTest { fun `given observer terminates with unexpected error, expected receive error value`() { val stubException = Throwable("Unexpected Flowable termination.") - whenever(stubRepository.observeAll()) + whenever(stubRepository.observeAllOnnx()) .thenReturn(Flowable.error(stubException)) useCase() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt index 9f2fac3b..5641d4f2 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt @@ -293,7 +293,7 @@ class ImageToImageUseCaseImplTest { @Test fun `given source is LOCAL, expected Img2Img not yet supported error`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) useCase(mockImageToImagePayload) .test() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt index 5ef6334c..895bb8a8 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt @@ -88,7 +88,7 @@ class InterruptGenerationUseCaseImplTest { @Test fun `given source is LOCAL, api interrupt success, expected complete value`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) .thenReturn(Completable.complete()) @@ -103,7 +103,7 @@ class InterruptGenerationUseCaseImplTest { @Test fun `given source is LOCAL, api interrupt fail, expected error value`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.interruptGeneration()) .thenReturn(Completable.error(stubException)) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt index 528e2e6e..47836c4e 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt @@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.domain.usecase.generation import com.nhaarman.mockitokotlin2.mock import com.nhaarman.mockitokotlin2.whenever -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.subjects.BehaviorSubject @@ -12,7 +12,7 @@ import org.junit.Test class ObserveLocalDiffusionProcessStatusUseCaseImplTest { private val stubException = Throwable("Error loading Local Diffusion.") - private val stubLocalStatus = BehaviorSubject.create() + private val stubLocalStatus = BehaviorSubject.create() private val stubRepository = mock() private val useCase = ObserveLocalDiffusionProcessStatusUseCaseImpl(stubRepository) @@ -27,23 +27,23 @@ class ObserveLocalDiffusionProcessStatusUseCaseImplTest { fun `given repository processes three steps, expected three valid status values`() { val stubObserver = useCase().test() - stubLocalStatus.onNext(LocalDiffusion.Status(1, 3)) + stubLocalStatus.onNext(LocalDiffusionStatus(1, 3)) stubObserver .assertNoErrors() - .assertValueAt(0, LocalDiffusion.Status(1, 3)) + .assertValueAt(0, LocalDiffusionStatus(1, 3)) - stubLocalStatus.onNext(LocalDiffusion.Status(2, 3)) + stubLocalStatus.onNext(LocalDiffusionStatus(2, 3)) stubObserver .assertNoErrors() - .assertValueAt(1, LocalDiffusion.Status(2, 3)) + .assertValueAt(1, LocalDiffusionStatus(2, 3)) - stubLocalStatus.onNext(LocalDiffusion.Status(3, 3)) + stubLocalStatus.onNext(LocalDiffusionStatus(3, 3)) stubObserver .assertNoErrors() - .assertValueAt(2, LocalDiffusion.Status(3, 3)) + .assertValueAt(2, LocalDiffusionStatus(3, 3)) .assertValueCount(3) } @@ -51,23 +51,23 @@ class ObserveLocalDiffusionProcessStatusUseCaseImplTest { fun `given repository processes two steps, emits same step twice, expected two valid status values`() { val stubObserver = useCase().test() - stubLocalStatus.onNext(LocalDiffusion.Status(1, 2)) + stubLocalStatus.onNext(LocalDiffusionStatus(1, 2)) stubObserver .assertNoErrors() - .assertValueAt(0, LocalDiffusion.Status(1, 2)) + .assertValueAt(0, LocalDiffusionStatus(1, 2)) - stubLocalStatus.onNext(LocalDiffusion.Status(1, 2)) + stubLocalStatus.onNext(LocalDiffusionStatus(1, 2)) stubObserver .assertNoErrors() .assertValueCount(1) - stubLocalStatus.onNext(LocalDiffusion.Status(2, 2)) + stubLocalStatus.onNext(LocalDiffusionStatus(2, 2)) stubObserver .assertNoErrors() - .assertValueAt(1, LocalDiffusion.Status(2, 2)) + .assertValueAt(1, LocalDiffusionStatus(2, 2)) .assertValueCount(2) } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt index b6d3bced..98505eac 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository +import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository @@ -27,6 +28,7 @@ class TextToImageUseCaseImplTest { private val stubStabilityAiGenerationRepository = mock() private val stubSwarmUiGenerationRepository = mock() private val stubLocalDiffusionGenerationRepository = mock() + private val stubMediaPipeGenerationRepository = mock() private val stubPreferenceManager = mock() private val useCase = TextToImageUseCaseImpl( @@ -37,6 +39,7 @@ class TextToImageUseCaseImplTest { stabilityAiGenerationRepository = stubStabilityAiGenerationRepository, localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository, swarmUiGenerationRepository = stubSwarmUiGenerationRepository, + mediaPipeGenerationRepository = stubMediaPipeGenerationRepository, preferenceManager = stubPreferenceManager, ) @@ -363,7 +366,7 @@ class TextToImageUseCaseImplTest { @Test fun `given source is LOCAL, batch count is 1, generated successfully, expected generations list with size 1`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) .thenReturn(Single.just(mockAiGenerationResult)) @@ -386,7 +389,7 @@ class TextToImageUseCaseImplTest { @Test fun `given source is LOCAL, batch count is 10, generated successfully, expected generations list with size 10`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) .thenReturn(Single.just(mockAiGenerationResult)) @@ -409,7 +412,7 @@ class TextToImageUseCaseImplTest { @Test fun `given source is LOCAL, batch count is 1, generate failed, expected error`() { whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) whenever(stubLocalDiffusionGenerationRepository.generateFromText(any())) .thenReturn(Single.error(stubException)) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index 55290a49..5a5d12bc 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -69,12 +69,20 @@ class GetConfigurationUseCaseImplTest { } returns mockConfiguration.stabilityAiEngineId every { - stubPreferenceManager::localModelId.get() - } returns mockConfiguration.localModelId + stubPreferenceManager::localOnnxModelId.get() + } returns mockConfiguration.localOnnxModelId every { - stubPreferenceManager::localDiffusionCustomModelPath.get() - } returns mockConfiguration.localModelPath + stubPreferenceManager::localOnnxCustomModelPath.get() + } returns mockConfiguration.localOnnxModelPath + + every { + stubPreferenceManager::localMediaPipeModelId.get() + } returns mockConfiguration.localMediaPipeModelId + + every { + stubPreferenceManager::localMediaPipeCustomModelPath.get() + } returns mockConfiguration.localMediaPipeModelPath useCase .invoke() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index 62a74c34..a1620120 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -68,11 +68,19 @@ class SetServerConfigurationUseCaseImplTest { } returns Unit every { - stubPreferenceManager::localModelId.set(any()) + stubPreferenceManager::localOnnxModelId.set(any()) } returns Unit every { - stubPreferenceManager::localDiffusionCustomModelPath.set(any()) + stubPreferenceManager::localOnnxCustomModelPath.set(any()) + } returns Unit + + every { + stubPreferenceManager::localMediaPipeModelId.set(any()) + } returns Unit + + every { + stubPreferenceManager::localMediaPipeCustomModelPath.set(any()) } returns Unit useCase diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt index 1c2b3a0d..092b4c19 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt @@ -89,7 +89,7 @@ class SplashNavigationUseCaseImplTest { .thenReturn("") whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) useCase() .test() @@ -109,7 +109,7 @@ class SplashNavigationUseCaseImplTest { .thenReturn("http://192.168.0.1:7860") whenever(stubPreferenceManager.source) - .thenReturn(ServerSource.LOCAL) + .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX) useCase() .test() diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt index 56b3f39d..775cad2c 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt @@ -4,6 +4,7 @@ import ai.onnxruntime.OnnxTensor import android.graphics.Bitmap import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.TextToImagePayload import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.TAG @@ -20,7 +21,7 @@ internal class LocalDiffusionImpl( private val ortEnvironmentProvider: OrtEnvironmentProvider, ) : LocalDiffusion { - private val statusSubject: PublishSubject = PublishSubject.create() + private val statusSubject: PublishSubject = PublishSubject.create() override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> try { @@ -31,7 +32,7 @@ internal class LocalDiffusionImpl( uNet.setCallback(object : UNet.Callback { override fun onStep(maxStep: Int, step: Int) { debugLog(TAG, "Received step update: ${maxStep}/${step}") - statusSubject.onNext(LocalDiffusion.Status(step, maxStep)) + statusSubject.onNext(LocalDiffusionStatus(step, maxStep)) } override fun onBuildImage(status: Int, bitmap: Bitmap?) { diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt index 7cb6df87..1b22e6d9 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt @@ -11,8 +11,8 @@ fun modelPathPrefix( localModelIdProvider: LocalModelIdProvider, ): String { val modelId = localModelIdProvider.get() - return if (modelId == LocalAiModel.CUSTOM.id) { - preferenceManager.localDiffusionCustomModelPath + return if (modelId == LocalAiModel.CustomOnnx.id) { + preferenceManager.localOnnxCustomModelPath } else { "${fileProviderDescriptor.localModelDirPath}/${modelId}" } diff --git a/feature/mediapipe/.gitignore b/feature/mediapipe/.gitignore new file mode 100644 index 00000000..42afabfd --- /dev/null +++ b/feature/mediapipe/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/feature/mediapipe/build.gradle.kts b/feature/mediapipe/build.gradle.kts new file mode 100644 index 00000000..0249feed --- /dev/null +++ b/feature/mediapipe/build.gradle.kts @@ -0,0 +1,17 @@ +plugins { + alias(libs.plugins.generic.library) + alias(libs.plugins.generic.flavors) +} + +android { + namespace = "com.shifthackz.aisdv1.feature.mediapipe" +} + +dependencies { + implementation(project(":core:common")) + implementation(project(":domain")) + implementation(libs.koin.core) + implementation(libs.rx.kotlin) + fullImplementation(libs.google.mediapipe.image.generator) + playstoreImplementation(libs.google.mediapipe.image.generator) +} diff --git a/feature/mediapipe/consumer-rules.pro b/feature/mediapipe/consumer-rules.pro new file mode 100644 index 00000000..e69de29b diff --git a/feature/mediapipe/proguard-rules.pro b/feature/mediapipe/proguard-rules.pro new file mode 100644 index 00000000..481bb434 --- /dev/null +++ b/feature/mediapipe/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile \ No newline at end of file diff --git a/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt new file mode 100644 index 00000000..3be83e5e --- /dev/null +++ b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -0,0 +1,13 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import android.graphics.Bitmap +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import io.reactivex.rxjava3.core.Single + +internal class MediaPipeImpl : MediaPipe { + + override fun process(payload: TextToImagePayload): Single { + return Single.error(IllegalStateException("Google AI MediaPipe is not supported on FOSS build.")) + } +} diff --git a/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt new file mode 100644 index 00000000..a6e7e98a --- /dev/null +++ b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -0,0 +1,63 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import android.content.Context +import android.graphics.Bitmap +import com.google.mediapipe.framework.image.BitmapExtractor +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.core.common.log.debugLog +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.feature.mediapipe.extensions.modelPath +import io.reactivex.rxjava3.core.Single + +internal class MediaPipeImpl( + private val context: Context, + private val preferenceManager: PreferenceManager, + private val fileProviderDescriptor: FileProviderDescriptor, +) : MediaPipe { + + private var imageGenerator: ImageGenerator? = null + + override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> + try { + initialize() + debugLog("Generating...") + val result = imageGenerator?.generate( + payload.prompt, + payload.samplingSteps, + payload.seed.toIntOrNull() ?: 0, + ) + debugLog("Extracting bitmap...") + val bitmap = BitmapExtractor.extract(result?.generatedImage()) + debugLog("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") + close() + if (!emitter.isDisposed) emitter.onSuccess(bitmap) + } catch (e: Exception) { + close() + if (!emitter.isDisposed) emitter.onError(e) + } + } + + private fun initialize(): ImageGenerator { + val path = modelPath(preferenceManager, fileProviderDescriptor) + + val options = ImageGeneratorOptions.builder() + .setImageGeneratorModelDirectory(path) + .build() + + val generator = ImageGenerator.createFromOptions(context, options) + imageGenerator = generator + debugLog("Initialized successfully! Path: $path") + return generator + } + + private fun close() = runCatching { + debugLog("Closing...") + imageGenerator?.close() + imageGenerator = null + debugLog("Session closed!") + } +} diff --git a/feature/mediapipe/src/main/AndroidManifest.xml b/feature/mediapipe/src/main/AndroidManifest.xml new file mode 100644 index 00000000..44008a43 --- /dev/null +++ b/feature/mediapipe/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt new file mode 100644 index 00000000..3327827e --- /dev/null +++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt @@ -0,0 +1,11 @@ +package com.shifthackz.aisdv1.feature.mediapipe.di + +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.feature.mediapipe.MediaPipeImpl +import org.koin.core.module.dsl.factoryOf +import org.koin.dsl.bind +import org.koin.dsl.module + +val mediaPipeModule = module { + factoryOf(::MediaPipeImpl) bind MediaPipe::class +} diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt new file mode 100644 index 00000000..91e49404 --- /dev/null +++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt @@ -0,0 +1,17 @@ +package com.shifthackz.aisdv1.feature.mediapipe.extensions + +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager + +fun modelPath( + preferenceManager: PreferenceManager, + fileProviderDescriptor: FileProviderDescriptor, +): String { + val modelId = preferenceManager.localMediaPipeModelId + return if (modelId == LocalAiModel.CustomMediaPipe.id) { + preferenceManager.localMediaPipeCustomModelPath + } else { + "${fileProviderDescriptor.localModelDirPath}/${modelId}" + } +} diff --git a/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt new file mode 100644 index 00000000..a6e7e98a --- /dev/null +++ b/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt @@ -0,0 +1,63 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import android.content.Context +import android.graphics.Bitmap +import com.google.mediapipe.framework.image.BitmapExtractor +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator +import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions +import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.core.common.log.debugLog +import com.shifthackz.aisdv1.domain.entity.TextToImagePayload +import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe +import com.shifthackz.aisdv1.domain.preference.PreferenceManager +import com.shifthackz.aisdv1.feature.mediapipe.extensions.modelPath +import io.reactivex.rxjava3.core.Single + +internal class MediaPipeImpl( + private val context: Context, + private val preferenceManager: PreferenceManager, + private val fileProviderDescriptor: FileProviderDescriptor, +) : MediaPipe { + + private var imageGenerator: ImageGenerator? = null + + override fun process(payload: TextToImagePayload): Single = Single.create { emitter -> + try { + initialize() + debugLog("Generating...") + val result = imageGenerator?.generate( + payload.prompt, + payload.samplingSteps, + payload.seed.toIntOrNull() ?: 0, + ) + debugLog("Extracting bitmap...") + val bitmap = BitmapExtractor.extract(result?.generatedImage()) + debugLog("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}") + close() + if (!emitter.isDisposed) emitter.onSuccess(bitmap) + } catch (e: Exception) { + close() + if (!emitter.isDisposed) emitter.onError(e) + } + } + + private fun initialize(): ImageGenerator { + val path = modelPath(preferenceManager, fileProviderDescriptor) + + val options = ImageGeneratorOptions.builder() + .setImageGeneratorModelDirectory(path) + .build() + + val generator = ImageGenerator.createFromOptions(context, options) + imageGenerator = generator + debugLog("Initialized successfully! Path: $path") + return generator + } + + private fun close() = runCatching { + debugLog("Closing...") + imageGenerator?.close() + imageGenerator = null + debugLog("Session closed!") + } +} diff --git a/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt b/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt new file mode 100644 index 00000000..412481c5 --- /dev/null +++ b/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt @@ -0,0 +1,16 @@ +package com.shifthackz.aisdv1.feature.mediapipe + +import org.junit.Assert.* +import org.junit.Test + +/** + * Example local unit test, which will execute on the development machine (host). + * + * See [testing documentation](http://d.android.com/tools/testing). + */ +class ExampleUnitTest { + @Test + fun addition_isCorrect() { + assertEquals(4, 2 + 2) + } +} \ No newline at end of file diff --git a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt index d3076dee..b77330ee 100644 --- a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt +++ b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt @@ -93,7 +93,7 @@ internal abstract class CoreGenerationWorker( body = subTitle, silent = true, progress = status.current to status.total, - canCancel = preferenceManager.localDiffusionAllowCancel, + canCancel = preferenceManager.localOnnxAllowCancel, ) } } @@ -112,7 +112,7 @@ internal abstract class CoreGenerationWorker( setForegroundNotification( title = title, body = subTitle, - canCancel = source != ServerSource.LOCAL, + canCancel = source != ServerSource.LOCAL_MICROSOFT_ONNX, ) } diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d2f0faac..d1a1645c 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -46,6 +46,7 @@ catppuccin = "0.1.2" turbine = "1.1.0" roboelectric = "4.13" testCoroutines = "1.8.1" +mediaPipeGenerator = "0.10.14" [libraries] android-tools-build-gradle = { group = "com.android.tools.build", name = "gradle", version.ref = "agp"} @@ -77,6 +78,7 @@ androidx-work-runtime = { group = "androidx.work", name = "work-runtime", versio google-gson = { group = "com.google.code.gson", name = "gson", version.ref = "gson" } google-material = { group = "com.google.android.material", name = "material", version.ref = "material" } google-accompanist-systemuicontroller = { group = "com.google.accompanist", name = "accompanist-systemuicontroller", version.ref = "accompanistSystemUi" } +google-mediapipe-image-generator = { group = "com.google.mediapipe", name = "tasks-vision-image-generator", version.ref = "mediaPipeGenerator" } retrofit-core = { group = "com.squareup.retrofit2", name = "retrofit", version.ref = "retrofit" } retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit" } retrofit-adapter-rxjava3 = { group = "com.squareup.retrofit2", name = "adapter-rxjava3", version.ref = "retrofit" } @@ -115,6 +117,7 @@ android-application = { id = "com.android.application", version.ref = "agp" } android-library = { id = "com.android.library", version.ref = "agp" } jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } jetbrains-kotlin-kapt = { id = "org.jetbrains.kotlin.kapt", version="unspecified" } +generic-flavors = { id = "generic.flavors", version = "unspecified" } generic-library = { id = "generic.library", version = "unspecified" } generic-baseline-profm = { id = "generic.baseline.profm", version = "unspecified" } generic-compose = { id = "generic.compose", version = "unspecified" } diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt index 90adc17b..78a35382 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt @@ -11,7 +11,9 @@ import java.io.File interface DownloadableModelsApi { - fun fetchDownloadableModels(): Single> + fun fetchOnnxModels(): Single> + + fun fetchMediaPipeModels(): Single> fun downloadModel( remoteUrl: String, @@ -23,7 +25,10 @@ interface DownloadableModelsApi { interface RawApi { @GET("/models.json") - fun fetchDownloadableModels(): Single> + fun fetchOnnxModels(): Single> + + @GET("/mediapipe.json") + fun fetchMediaPipeModels(): Single> @Streaming @GET diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt index 88ddc5ef..04fd71c1 100644 --- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt +++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.network.api.sdai import com.shifthackz.aisdv1.network.extensions.saveFile +import com.shifthackz.aisdv1.network.response.DownloadableModelResponse import io.reactivex.rxjava3.core.Observable import io.reactivex.rxjava3.core.Single import java.io.File @@ -9,7 +10,9 @@ internal class DownloadableModelsApiImpl( private val rawApi: DownloadableModelsApi.RawApi, ) : DownloadableModelsApi { - override fun fetchDownloadableModels() = rawApi.fetchDownloadableModels() + override fun fetchOnnxModels() = rawApi.fetchOnnxModels() + + override fun fetchMediaPipeModels() = rawApi.fetchMediaPipeModels() override fun downloadModel( remoteUrl: String, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt index 714f0ef9..689fc79c 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt @@ -9,6 +9,7 @@ import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.OpenAiSize import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.StabilityAiSampler @@ -118,7 +119,7 @@ abstract class GenerationMviViewModel? get() = status?.let { (current, total) -> current to total } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt index 125eb82c..a9a2a438 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt @@ -47,7 +47,7 @@ internal class MainRouterImpl : MainRouter { override fun navigateToServerSetup(source: LaunchSource) { effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.ordinal}") { if (source == LaunchSource.SPLASH) { - popUpTo(Constants.ROUTE_SPLASH) { + popUpTo(0) { inclusive = true } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt index 0d2fd5df..c6b5f7de 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt @@ -64,7 +64,7 @@ class DebugMenuViewModel( DebugMenuIntent.ViewLogs -> mainRouter.navigateToLogger() DebugMenuIntent.AllowLocalDiffusionCancel -> { - preferenceManager.localDiffusionAllowCancel = !currentState.localDiffusionAllowCancel + preferenceManager.localOnnxAllowCancel = !currentState.localDiffusionAllowCancel } DebugMenuIntent.LocalDiffusionScheduler.Request -> updateState { @@ -72,7 +72,7 @@ class DebugMenuViewModel( } is DebugMenuIntent.LocalDiffusionScheduler.Confirm -> { - preferenceManager.localDiffusionSchedulerThread = intent.token + preferenceManager.localOnnxSchedulerThread = intent.token } DebugMenuIntent.DismissModal -> updateState { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt index d5ad617a..482ea2d9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt @@ -32,7 +32,7 @@ class GalleryPagingSource( limit = pageSize, offset = pageNext * Constants.PAGINATION_PAYLOAD_SIZE, ) - .subscribeOn(schedulersProvider.io) + .subscribeOn(schedulersProvider.computation) .flatMapObservable { Observable.fromIterable(it) } .map { ai -> ai.id to ai.image } .map { (id, base64) -> id to Input(base64) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt index a6052000..67572859 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt @@ -468,7 +468,7 @@ fun GalleryScreenContent( val selected = state.selection.contains(item.id) GalleryUiItem( modifier = Modifier - .animateItemPlacement(tween(500)) + .animateItem(tween(500)) .shake( enabled = state.selectionMode && !selected, animationDurationMillis = 188, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt index fe8eeec1..000bca56 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt @@ -54,6 +54,9 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.sp import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.math.roundTo +import com.shifthackz.aisdv1.core.model.UiText +import com.shifthackz.aisdv1.core.model.asString +import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.ui.MviComponent import com.shifthackz.aisdv1.domain.entity.AiGenerationResult import com.shifthackz.aisdv1.domain.entity.ServerSource @@ -135,7 +138,7 @@ private fun ScreenContent( ) }, actions = { - if (state.mode != ServerSource.LOCAL) { + if (state.mode != ServerSource.LOCAL_MICROSOFT_ONNX) { IconButton( onClick = { processIntent( @@ -233,17 +236,33 @@ private fun ScreenContent( ) Text( modifier = Modifier.padding(top = 14.dp), - text = stringResource( - if (state.mode == ServerSource.LOCAL) LocalizationR.string.local_no_img2img_support_sub_title - else LocalizationR.string.dalle_no_img2img_support_sub_title - ), + text = when (state.mode) { + ServerSource.OPEN_AI -> LocalizationR.string + .dalle_no_img2img_support_sub_title + .asUiText() + + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string + .local_no_img2img_support_sub_title + .asUiText() + + else -> UiText.empty + }.asString(), ) Text( modifier = Modifier.padding(top = 14.dp), - text = stringResource( - if (state.mode == ServerSource.LOCAL) LocalizationR.string.local_no_img2img_support_sub_title_2 - else LocalizationR.string.dalle_no_img2img_support_sub_title_2 - ), + text = when (state.mode) { + ServerSource.OPEN_AI -> LocalizationR.string + .dalle_no_img2img_support_sub_title_2 + .asUiText() + + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string + .local_no_img2img_support_sub_title_2 + .asUiText() + + else -> UiText.empty + }.asString(), ) } } @@ -251,7 +270,8 @@ private fun ScreenContent( }, bottomBar = { val isEnabled = when (state.mode) { - ServerSource.LOCAL, + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.OPEN_AI -> true else -> !state.hasValidationErrors && !state.imageState.isEmpty @@ -271,20 +291,28 @@ private fun ScreenContent( keyboardController?.hide() when (state.mode) { ServerSource.OPEN_AI, - ServerSource.LOCAL -> processIntent(GenerationMviIntent.Configuration) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, + ServerSource.LOCAL_MICROSOFT_ONNX -> processIntent( + GenerationMviIntent.Configuration + ) else -> { promptChipTextFieldState.value.text.takeIf(String::isNotBlank) ?.let { "${state.prompt}, ${it.trim()}" } ?.let(GenerationMviIntent.Update::Prompt) ?.let(processIntent::invoke) - ?.also { promptChipTextFieldState.value = TextFieldValue("") } + ?.also { + promptChipTextFieldState.value = TextFieldValue("") + } negativePromptChipTextFieldState.value.text.takeIf(String::isNotBlank) ?.let { "${state.negativePrompt}, ${it.trim()}" } ?.let(GenerationMviIntent.Update::NegativePrompt) ?.let(processIntent::invoke) - ?.also { negativePromptChipTextFieldState.value = TextFieldValue("") } + ?.also { + negativePromptChipTextFieldState.value = + TextFieldValue("") + } processIntent(GenerationMviIntent.Generate) } @@ -292,8 +320,11 @@ private fun ScreenContent( }, enabled = isEnabled, ) { - if (state.mode != ServerSource.LOCAL) { - Icon( + when (state.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit + + else -> Icon( modifier = Modifier.size(18.dp), imageVector = Icons.Default.AutoFixNormal, contentDescription = "Imagine", @@ -303,7 +334,8 @@ private fun ScreenContent( modifier = Modifier.padding(start = 8.dp), text = stringResource( id = when (state.mode) { - ServerSource.LOCAL, + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.OPEN_AI -> LocalizationR.string.action_change_configuration else -> LocalizationR.string.action_generate diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt index 533b7b9b..3e301a66 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt @@ -90,7 +90,7 @@ data class ImageToImageState( heightValidationError: UiText?, nsfw: Boolean, batchCount: Int, - generateButtonEnabled: Boolean + generateButtonEnabled: Boolean, ): GenerationMviState = copy( onBoardingDemo = onBoardingDemo, screenModal = screenModal, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt index 47e94240..98c45cbd 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt @@ -133,7 +133,7 @@ fun InPaintComponent( } MotionEvent.Move -> { - currentPath.quadraticBezierTo( + currentPath.quadraticTo( previousPosition.x, previousPosition.y, (previousPosition.x + currentPosition.x) / 2, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt index c97b3966..e3833534 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt @@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter import com.shifthackz.android.core.mvi.EmptyEffect import com.shifthackz.android.core.mvi.EmptyIntent import io.reactivex.rxjava3.kotlin.subscribeBy +import java.util.concurrent.TimeUnit import com.shifthackz.aisdv1.core.localization.R as LocalizationR class ConfigurationLoaderViewModel( @@ -28,6 +29,7 @@ class ConfigurationLoaderViewModel( init { !dataPreLoaderUseCase() + .timeout(15L, TimeUnit.SECONDS) .doOnSubscribe { updateState { ConfigurationLoaderState.StatusNotification( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt index fbfb6e79..9138f169 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt @@ -60,7 +60,7 @@ fun LocalDiffusionPageContent( modifier = localModifier, state = TextToImageState( onBoardingDemo = true, - mode = ServerSource.LOCAL, + mode = ServerSource.LOCAL_MICROSOFT_ONNX, advancedToggleButtonVisible = false, advancedOptionsVisible = true, formPromptTaggedInput = true, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt index 7e50da66..e611a02e 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt @@ -235,7 +235,8 @@ private fun ContentSettingsState( ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face_short ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai - ServerSource.LOCAL -> LocalizationR.string.srv_type_local_short + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local_short + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe_short ServerSource.SWARM_UI -> LocalizationR.string.srv_type_swarm_ui }.asUiText(), onClick = { processIntent(SettingsIntent.NavigateConfiguration) }, @@ -256,7 +257,7 @@ private fun ContentSettingsState( endValueText = state.sdModelSelected.asUiText(), onClick = { processIntent(SettingsIntent.SdModel.OpenChooser) }, ) - if (state.showLocalUseNNAPI) { + if (state.showLocalMICROSOFTONNXUseNNAPI) { SettingsItem( modifier = itemModifier, loading = state.loading, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt index aff0ad43..312ecd55 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt @@ -37,8 +37,8 @@ data class SettingsState( val showStabilityAiCredits: Boolean get() = serverSource == ServerSource.STABILITY_AI - val showLocalUseNNAPI: Boolean - get() = serverSource == ServerSource.LOCAL + val showLocalMICROSOFTONNXUseNNAPI: Boolean + get() = serverSource == ServerSource.LOCAL_MICROSOFT_ONNX val showSdModelSelector: Boolean get() = serverSource == ServerSource.AUTOMATIC1111 diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt index a669acc0..d7756582 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt @@ -25,6 +25,7 @@ import com.shifthackz.aisdv1.presentation.screen.debug.DebugMenuAccessor import com.shifthackz.aisdv1.presentation.screen.drawer.DrawerIntent import io.reactivex.rxjava3.core.Flowable import io.reactivex.rxjava3.kotlin.subscribeBy +import java.util.concurrent.TimeUnit import com.shifthackz.aisdv1.core.localization.R as LocalizationR class SettingsViewModel( @@ -48,6 +49,7 @@ class SettingsViewModel( private val appVersionProducer = Flowable.fromCallable { buildInfoProvider.toString() } private val sdModelsProducer = getStableDiffusionModelsUseCase() + .timeout(10L, TimeUnit.SECONDS) .toFlowable() .onErrorReturn { emptyList() } @@ -142,7 +144,7 @@ class SettingsViewModel( } is SettingsIntent.UpdateFlag.NNAPI -> { - preferenceManager.localUseNNAPI = intent.flag + preferenceManager.localOnnxUseNNAPI = intent.flag } is SettingsIntent.UpdateFlag.TaggedInput -> { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt index a83405f0..612b4c90 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt @@ -153,7 +153,11 @@ fun ServerSetupScreenContent( onClick = { processIntent(ServerSetupIntent.MainButtonClick) }, enabled = when (state.step) { ServerSetupState.Step.CONFIGURE -> when (state.mode) { - ServerSource.LOCAL -> state.localModels.any { + ServerSource.LOCAL_MICROSOFT_ONNX -> state.localOnnxModels.any { + it.downloaded && it.selected + } + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.localMediaPipeModels.any { it.downloaded && it.selected } @@ -168,7 +172,9 @@ fun ServerSetupScreenContent( id = when (state.step) { ServerSetupState.Step.SOURCE -> LocalizationR.string.next else -> when (state.mode) { - ServerSource.LOCAL -> LocalizationR.string.action_setup + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string + .action_setup else -> LocalizationR.string.action_connect } }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 553e7bde..44a37715 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -5,9 +5,11 @@ import com.shifthackz.aisdv1.core.common.links.LinksProvider import com.shifthackz.aisdv1.core.model.UiText import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.DownloadState +import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials import com.shifthackz.aisdv1.presentation.model.Modal +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants import com.shifthackz.android.core.mvi.MviState import org.koin.core.component.KoinComponent @@ -33,9 +35,12 @@ data class ServerSetupState( val password: String = "", val huggingFaceModels: List = emptyList(), val huggingFaceModel: String = "", - val localModels: List = emptyList(), - val localCustomModel: Boolean = false, - val localCustomModelPath: String = "", + val localOnnxModels: List = emptyList(), + val localOnnxCustomModel: Boolean = false, + val localOnnxCustomModelPath: String = "", + val localMediaPipeModels: List = emptyList(), + val localMediaPipeCustomModel: Boolean = false, + val localMediaPipeCustomModelPath: String = "", val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, val swarmUiUrlValidationError: UiText? = null, @@ -45,9 +50,38 @@ data class ServerSetupState( val huggingFaceApiKeyValidationError: UiText? = null, val openAiApiKeyValidationError: UiText? = null, val stabilityAiApiKeyValidationError: UiText? = null, - val localCustomModelPathValidationError: UiText? = null, + val localCustomOnnxPathValidationError: UiText? = null, + val localCustomMediaPipePathValidationError: UiText? = null, ) : MviState, KoinComponent { + val localCustomModel: Boolean + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localOnnxCustomModel + } else { + localMediaPipeCustomModel + } + + val localCustomModelPath: String + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localOnnxCustomModelPath + } else { + localMediaPipeCustomModelPath + } + + val localModels: List + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localOnnxModels + } else { + localMediaPipeModels + } + + val localCustomModelPathValidationError: UiText? + get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + localCustomOnnxPathValidationError + } else { + localCustomMediaPipePathValidationError + } + val demoModeUrl: String get() { val linksProvider: LinksProvider by inject() @@ -60,13 +94,101 @@ data class ServerSetupState( ) fun withCredentials(value: AuthorizationCredentials) = when (value) { - is AuthorizationCredentials.HttpBasic -> this.copy( + is AuthorizationCredentials.HttpBasic -> copy( login = value.login, password = value.password, ) + AuthorizationCredentials.None -> this } + fun withLocalCustomModelPath(value: String): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + localOnnxCustomModelPath = value, + localCustomOnnxPathValidationError = null, + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + localMediaPipeCustomModelPath = value, + localCustomMediaPipePathValidationError = null, + ) + + else -> this + } + + fun withUpdatedLocalModel(value: LocalModel): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + localOnnxModels = localOnnxModels.withNewState(value) + ) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + localMediaPipeModels = localMediaPipeModels.withNewState(value) + ) + else -> this + } + + fun withDeletedLocalModel(value: LocalModel): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + screenModal = Modal.None, + localOnnxModels = localOnnxModels.withNewState( + value.copy( + downloadState = DownloadState.Unknown, + downloaded = false, + ), + ) + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + screenModal = Modal.None, + localMediaPipeModels = localMediaPipeModels.withNewState( + value.copy( + downloadState = DownloadState.Unknown, + downloaded = false, + ), + ) + ) + + else -> copy(screenModal = Modal.None) + } + + fun withSelectedLocalModel(value: LocalModel): ServerSetupState = when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> copy( + localOnnxModels = localOnnxModels.withNewState( + value.copy(selected = true), + ), + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy( + localMediaPipeModels = localMediaPipeModels.withNewState( + value.copy(selected = true), + ), + ) + + else -> this + } + + fun withAllowCustomModel(value: Boolean): ServerSetupState { + fun List.updateCustomModelSelection(id: String) = withNewState( + find { m -> m.id == id }?.copy(selected = value) + ) + return when (mode) { + ServerSource.LOCAL_MICROSOFT_ONNX -> this.copy( + localOnnxCustomModel = value, + localOnnxModels = localOnnxModels.updateCustomModelSelection( + id = LocalAiModel.CustomOnnx.id, + ), + ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> this.copy( + localMediaPipeCustomModel = value, + localMediaPipeModels = localMediaPipeModels.updateCustomModelSelection( + id = LocalAiModel.CustomMediaPipe.id, + ), + ) + + else -> this + } + } + enum class Step { SOURCE, CONFIGURE; diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index 3677b6b3..d054a301 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -1,7 +1,10 @@ package com.shifthackz.aisdv1.presentation.screen.setup +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType import com.shifthackz.aisdv1.core.common.log.errorLog import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider +import com.shifthackz.aisdv1.core.common.model.Quadruple import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.model.asUiText @@ -11,7 +14,6 @@ import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel -import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials import com.shifthackz.aisdv1.domain.interactor.settings.SetupConnectionInterActor @@ -19,15 +21,17 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.presentation.model.LaunchSource import com.shifthackz.aisdv1.presentation.model.Modal import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter -import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomModelSwitchState +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomMediaPipeSwitchState +import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomOnnxSwitchState import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi -import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState import com.shifthackz.aisdv1.presentation.utils.Constants import io.reactivex.rxjava3.core.Single import io.reactivex.rxjava3.disposables.Disposable @@ -37,7 +41,8 @@ class ServerSetupViewModel( launchSource: LaunchSource, dispatchersProvider: DispatchersProvider, getConfigurationUseCase: GetConfigurationUseCase, - getLocalAiModelsUseCase: GetLocalAiModelsUseCase, + getLocalOnnxModelsUseCase: GetLocalOnnxModelsUseCase, + getLocalMediaPipeModelsUseCase: GetLocalMediaPipeModelsUseCase, fetchAndGetHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, private val urlValidator: UrlValidator, private val stringValidator: CommonStringValidator, @@ -49,6 +54,7 @@ class ServerSetupViewModel( private val preferenceManager: PreferenceManager, private val wakeLockInterActor: WakeLockInterActor, private val mainRouter: MainRouter, + private val buildInfoProvider: BuildInfoProvider, ) : MviRxViewModel() { override val initialState = ServerSetupState( @@ -74,12 +80,13 @@ class ServerSetupViewModel( init { !Single.zip( getConfigurationUseCase(), - getLocalAiModelsUseCase(), + getLocalOnnxModelsUseCase(), + getLocalMediaPipeModelsUseCase(), fetchAndGetHuggingFaceModelsUseCase(), - ::Triple, + ::Quadruple, ) .subscribeOnMainThread(schedulersProvider) - .subscribeBy(::errorLog) { (configuration, localModels, hfModels) -> + .subscribeBy(::errorLog) { (configuration, onnxModels, mpModels, hfModels) -> updateState { state -> state.copy( huggingFaceModels = hfModels.map(HuggingFaceModel::alias), @@ -87,10 +94,14 @@ class ServerSetupViewModel( huggingFaceApiKey = configuration.huggingFaceApiKey, openAiApiKey = configuration.openAiApiKey, stabilityAiApiKey = configuration.stabilityAiApiKey, - localModels = localModels.mapToUi(), - localCustomModel = localModels.mapLocalCustomModelSwitchState(), - localCustomModelPath = configuration.localModelPath, + localOnnxModels = onnxModels.mapToUi(), + localOnnxCustomModel = onnxModels.mapLocalCustomOnnxSwitchState(), + localOnnxCustomModelPath = configuration.localOnnxModelPath, + localMediaPipeModels = mpModels.mapToUi(), + localMediaPipeCustomModel = mpModels.mapLocalCustomMediaPipeSwitchState(), + localMediaPipeCustomModelPath = configuration.localMediaPipeModelPath, mode = configuration.source, + allowedModes = buildInfoProvider.allowedModes, demoMode = configuration.demoMode, serverUrl = configuration.serverUrl, swarmUiUrl = configuration.swarmUiUrl, @@ -110,15 +121,8 @@ class ServerSetupViewModel( } override fun processIntent(intent: ServerSetupIntent) = when (intent) { - is ServerSetupIntent.AllowLocalCustomModel -> updateState { - it.copy( - localCustomModel = intent.allow, - localModels = currentState.localModels.withNewState( - currentState.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }?.copy( - selected = intent.allow, - ), - ), - ) + is ServerSetupIntent.AllowLocalCustomModel -> updateState { state -> + state.withAllowCustomModel(intent.allow) } ServerSetupIntent.DismissDialog -> setScreenModal(Modal.None) @@ -129,28 +133,11 @@ class ServerSetupViewModel( !deleteModelUseCase(intent.model.id) .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) - it.copy( - screenModal = Modal.None, - localModels = currentState.localModels.withNewState( - intent.model.copy( - downloadState = DownloadState.Unknown, - downloaded = false, - ), - ), - ) + it.withDeletedLocalModel(intent.model) } - is ServerSetupIntent.SelectLocalModel -> { - if (currentState.localModels.any { it.downloadState is DownloadState.Downloading }) { - Unit - } - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - intent.model.copy(selected = true), - ), - ) - } + is ServerSetupIntent.SelectLocalModel -> updateState { state -> + state.withSelectedLocalModel(intent.model) } ServerSetupIntent.MainButtonClick -> when (currentState.step) { @@ -236,11 +223,8 @@ class ServerSetupViewModel( ServerSetupIntent.ConnectToLocalHost -> connectToServer() - is ServerSetupIntent.SelectLocalModelPath -> updateState { - it.copy( - localCustomModelPath = intent.value, - localCustomModelPathValidationError = null, - ) + is ServerSetupIntent.SelectLocalModelPath -> updateState { state -> + state.withLocalCustomModelPath(intent.value) } } @@ -253,12 +237,13 @@ class ServerSetupViewModel( emitEffect(ServerSetupEffect.HideKeyboard) !when (currentState.mode) { ServerSource.HORDE -> connectToHorde() - ServerSource.LOCAL -> connectToLocalDiffusion() + ServerSource.LOCAL_MICROSOFT_ONNX -> connectToLocalDiffusion() ServerSource.AUTOMATIC1111 -> connectToAutomaticInstance() ServerSource.HUGGING_FACE -> connectToHuggingFace() ServerSource.OPEN_AI -> connectToOpenAi() ServerSource.STABILITY_AI -> connectToStabilityAi() ServerSource.SWARM_UI -> connectToSwarmUi() + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> connectToMediaPipe() } .doOnSubscribe { setScreenModal(Modal.Communicating(canCancel = false)) } .subscribeOnMainThread(schedulersProvider) @@ -292,15 +277,27 @@ class ServerSetupViewModel( } } - ServerSource.LOCAL -> { - if (currentState.localCustomModel) { - val validation = filePathValidator(currentState.localCustomModelPath) + ServerSource.LOCAL_MICROSOFT_ONNX -> if (currentState.localOnnxCustomModel) { + val validation = filePathValidator(currentState.localOnnxCustomModelPath) + updateState { + it.copy(localCustomOnnxPathValidationError = validation.mapToUi()) + } + validation.isValid + } else { + currentState.localOnnxModels.find { it.selected && it.downloaded } != null + } + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> when { + buildInfoProvider.type == BuildType.FOSS -> false + currentState.localMediaPipeCustomModel -> { + val validation = filePathValidator(currentState.localMediaPipeCustomModelPath) updateState { - it.copy(localCustomModelPathValidationError = validation.mapToUi()) + it.copy(localCustomMediaPipePathValidationError = validation.mapToUi()) } validation.isValid - } else { - currentState.localModels.find { it.selected && it.downloaded } != null + } + else -> { + currentState.localMediaPipeModels.find { it.selected && it.downloaded } != null } } @@ -405,87 +402,81 @@ class ServerSetupViewModel( } private fun connectToLocalDiffusion(): Single> { - preferenceManager.localDiffusionCustomModelPath = currentState.localCustomModelPath - val localModelId = currentState.localModels.find { it.selected }?.id ?: "" + preferenceManager.localOnnxCustomModelPath = currentState.localOnnxCustomModelPath + val localModelId = currentState.localOnnxModels.find { it.selected }?.id ?: "" return setupConnectionInterActor.connectToLocal(localModelId) } - private fun localModelDownloadClickReducer(localModel: ServerSetupState.LocalModel) { + private fun connectToMediaPipe(): Single> { + preferenceManager.localMediaPipeCustomModelPath = currentState.localMediaPipeCustomModelPath + val localModelId = currentState.localMediaPipeModels.find { it.selected }?.id ?: "" + return setupConnectionInterActor.connectToMediaPipe(localModelId) + } + + private fun localModelDownloadClickReducer(value: ServerSetupState.LocalModel) { + fun localModel(): ServerSetupState.LocalModel = + currentState.localModels.firstOrNull { it.id == value.id } + ?.let { value.copy(selected = it.selected) } + ?: value + when { // User cancels download - localModel.downloadState is DownloadState.Downloading -> { - val index = downloadDisposables.indexOfFirst { it.first == localModel.id } + localModel().downloadState is DownloadState.Downloading -> { + val index = downloadDisposables.indexOfFirst { it.first == localModel().id } if (index != -1) { downloadDisposables[index].second.dispose() downloadDisposables.removeAt(index) } - !deleteModelUseCase(localModel.id) + !deleteModelUseCase(localModel().id) .subscribeOnMainThread(schedulersProvider) .subscribeBy(::errorLog) - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - localModel.copy(downloadState = DownloadState.Unknown), - ), + updateState { state -> + state.withUpdatedLocalModel( + value = localModel().copy(downloadState = DownloadState.Unknown), ) } } // User deletes local model - localModel.downloaded -> updateState { - it.copy(screenModal = Modal.DeleteLocalModelConfirm(localModel)) + localModel().downloaded -> updateState { + it.copy(screenModal = Modal.DeleteLocalModelConfirm(localModel())) } // User requested new download operation else -> { - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - localModel.copy( - downloadState = DownloadState.Downloading(), - ), - ), + updateState { state -> + state.withUpdatedLocalModel( + localModel().copy(downloadState = DownloadState.Downloading()), ) } - !downloadModelUseCase(localModel.id) + !downloadModelUseCase(localModel().id) .distinctUntilChanged() .doOnSubscribe { wakeLockInterActor.acquireWakelockUseCase() } .doFinally { wakeLockInterActor.releaseWakeLockUseCase() } - .subscribeOnMainThread(schedulersProvider).subscribeBy( + .subscribeOnMainThread(schedulersProvider) + .subscribeBy( onError = { t -> errorLog(t) val message = t.localizedMessage ?: "Error" - updateState { - it.copy( - localModels = currentState.localModels.withNewState( - localModel.copy( - downloadState = DownloadState.Error(t), - ), + updateState { state -> + state.withUpdatedLocalModel( + localModel().copy( + downloadState = DownloadState.Error(t), ), ) } setScreenModal(Modal.Error(message.asUiText())) }, onNext = { downloadState -> - updateState { - when (downloadState) { - is DownloadState.Complete -> it.copy( - localModels = it.localModels.withNewState( - localModel.copy( - downloadState = downloadState, - downloaded = true, - ), - ), - ) - - else -> it.copy( - localModels = it.localModels.withNewState( - localModel.copy(downloadState = downloadState), - ), - ) - } + updateState { state -> + state.withUpdatedLocalModel( + localModel().copy( + downloadState = downloadState, + downloaded = downloadState is DownloadState.Complete + ), + ) } }, ) - .also { downloadDisposables.add(localModel.id to it) } + .also { downloadDisposables.add(localModel().id to it) } } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt index 19815b66..c84b0547 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt @@ -72,7 +72,8 @@ fun ConfigurationModeButton( ServerSource.OPEN_AI, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE -> Icons.Default.Cloud - ServerSource.LOCAL -> Icons.Default.Android + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Icons.Default.Android else -> Icons.Default.QuestionMark }, contentDescription = null, @@ -91,9 +92,10 @@ fun ConfigurationModeButton( ServerSource.HORDE -> LocalizationR.string.hint_server_horde_sub_title ServerSource.HUGGING_FACE -> LocalizationR.string.hint_hugging_face_sub_title ServerSource.OPEN_AI -> LocalizationR.string.hint_open_ai_sub_title - ServerSource.LOCAL -> LocalizationR.string.hint_local_diffusion_sub_title + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.hint_local_diffusion_sub_title ServerSource.STABILITY_AI -> LocalizationR.string.hint_stability_ai_sub_title ServerSource.SWARM_UI -> LocalizationR.string.hint_swarm_ui_sub_title + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.hint_mediapipe_sub_title else -> null } descriptionId?.let { resId -> diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt index 6ebfed22..4408ad5d 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt @@ -54,6 +54,7 @@ import com.shifthackz.aisdv1.core.extensions.getRealPath import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupScreenTags.CUSTOM_MODEL_SWITCH import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState @@ -91,7 +92,8 @@ fun LocalDiffusionForm( val icon = when (model.downloadState) { is DownloadState.Downloading -> Icons.Outlined.FileDownload else -> when { - model.id == LocalAiModel.CUSTOM.id -> Icons.Outlined.Landslide + model.id == LocalAiModel.CustomOnnx.id -> Icons.Outlined.Landslide + model.id == LocalAiModel.CustomMediaPipe.id -> Icons.Outlined.Landslide model.downloaded -> Icons.Outlined.FileDownloadDone else -> Icons.Outlined.FileDownloadOff } @@ -113,14 +115,20 @@ fun LocalDiffusionForm( overflow = TextOverflow.Ellipsis, maxLines = 2 ) - if (model.id != LocalAiModel.CUSTOM.id) { - Text( + when (model.id) { + LocalAiModel.CustomOnnx.id, + LocalAiModel.CustomMediaPipe.id -> Unit + + else -> Text( text = model.size, maxLines = 1 ) } } - if (model.id != LocalAiModel.CUSTOM.id) { + // Do not display action button for custom model + if (model.id != LocalAiModel.CustomOnnx.id + && model.id != LocalAiModel.CustomMediaPipe.id + ) { Button( modifier = Modifier.padding(end = 8.dp), onClick = { processIntent(ServerSetupIntent.LocalModel.ClickReduce(model)) }, @@ -142,7 +150,9 @@ fun LocalDiffusionForm( } } } - if (model.id == LocalAiModel.CUSTOM.id) { + if (model.id == LocalAiModel.CustomOnnx.id + || model.id == LocalAiModel.CustomMediaPipe.id + ) { Column( modifier = Modifier.padding(8.dp), ) { @@ -150,81 +160,83 @@ fun LocalDiffusionForm( text = stringResource(id = LocalizationR.string.model_local_custom_title), style = MaterialTheme.typography.bodyMedium, ) - Spacer(modifier = Modifier.height(4.dp)) - Text( - text = stringResource(id = LocalizationR.string.model_local_custom_sub_title), - style = MaterialTheme.typography.bodyMedium, - ) - Spacer(modifier = Modifier.height(4.dp)) + if (model.id == LocalAiModel.CustomOnnx.id) { + Spacer(modifier = Modifier.height(4.dp)) + Text( + text = stringResource(id = LocalizationR.string.model_local_custom_sub_title), + style = MaterialTheme.typography.bodyMedium, + ) + Spacer(modifier = Modifier.height(4.dp)) - fun folderModifier(treeNum: Int) = - Modifier.padding(start = (treeNum - 1) * 12.dp) + fun folderModifier(treeNum: Int) = + Modifier.padding(start = (treeNum - 1) * 12.dp) - val folderStyle = MaterialTheme.typography.bodySmall - Text( - modifier = Modifier.padding(start = 12.dp), - text = state.localCustomModelPath, - style = folderStyle, - ) + val folderStyle = MaterialTheme.typography.bodySmall + Text( + modifier = Modifier.padding(start = 12.dp), + text = state.localOnnxCustomModelPath, + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "text_encoder", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "model.ort", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "text_encoder", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "model.ort", + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "tokenizer", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "merges.txt", - style = folderStyle, - ) - Text( - modifier = folderModifier(3), - text = "special_tokens_map.json", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "tokenizer_config.json", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "vocab.json", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "tokenizer", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "merges.txt", + style = folderStyle, + ) + Text( + modifier = folderModifier(3), + text = "special_tokens_map.json", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "tokenizer_config.json", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "vocab.json", + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "unet", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "model.ort", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "unet", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "model.ort", + style = folderStyle, + ) - Text( - modifier = folderModifier(3), - text = "vae_decoder", - style = folderStyle, - ) - Text( - modifier = folderModifier(4), - text = "model.ort", - style = folderStyle, - ) + Text( + modifier = folderModifier(3), + text = "vae_decoder", + style = folderStyle, + ) + Text( + modifier = folderModifier(4), + text = "model.ort", + style = folderStyle, + ) + } } } when (model.downloadState) { @@ -258,17 +270,29 @@ fun LocalDiffusionForm( modifier = Modifier .fillMaxWidth() .padding(top = 32.dp, bottom = 8.dp), - text = stringResource(id = LocalizationR.string.hint_local_diffusion_title), + text = stringResource( + id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + LocalizationR.string.hint_local_diffusion_title + } else { + LocalizationR.string.hint_mediapipe_title + }, + ), style = MaterialTheme.typography.bodyLarge, textAlign = TextAlign.Center, fontWeight = FontWeight.Bold, ) Text( modifier = Modifier.padding(top = 16.dp, bottom = 16.dp), - text = stringResource(id = LocalizationR.string.hint_local_diffusion_sub_title), + text = stringResource( + id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) { + LocalizationR.string.hint_local_diffusion_sub_title + } else { + LocalizationR.string.hint_mediapipe_sub_title + }, + ), style = MaterialTheme.typography.bodyMedium, ) - if (buildInfoProvider.type == BuildType.FOSS) { + if (buildInfoProvider.type != BuildType.PLAY) { Row( verticalAlignment = Alignment.CenterVertically, ) { @@ -285,7 +309,7 @@ fun LocalDiffusionForm( ) } } - if (state.localCustomModel && buildInfoProvider.type == BuildType.FOSS) { + if (state.localCustomModel && buildInfoProvider.type != BuildType.PLAY) { Text( modifier = Modifier .align(Alignment.CenterHorizontally) @@ -342,9 +366,12 @@ fun LocalDiffusionForm( .fillMaxWidth() .padding(top = 14.dp), value = state.localCustomModelPath, - onValueChange = { processIntent(ServerSetupIntent.SelectLocalModelPath(it)) }, + onValueChange = { string -> + string.filter { it != '\n' } + .let(ServerSetupIntent::SelectLocalModelPath) + .let(processIntent::invoke) + }, enabled = true, - singleLine = true, label = { Text(stringResource(LocalizationR.string.model_local_path_title)) }, trailingIcon = { IconButton( @@ -387,7 +414,8 @@ fun LocalDiffusionForm( } state.localModels .filter { - val customPredicate = it.id == LocalAiModel.CUSTOM.id + val customPredicate = + it.id == LocalAiModel.CustomOnnx.id || it.id == LocalAiModel.CustomMediaPipe.id if (state.localCustomModel) customPredicate else !customPredicate } .forEach { localModel -> modelItemUi(localModel) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt new file mode 100644 index 00000000..3bb26b1d --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt @@ -0,0 +1,29 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.forms + +import androidx.compose.runtime.Composable +import androidx.compose.ui.Modifier +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.core.common.appbuild.BuildType +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent +import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState + +@Composable +fun MediaPipeForm( + modifier: Modifier = Modifier, + state: ServerSetupState, + buildInfoProvider: BuildInfoProvider = BuildInfoProvider.stub, + processIntent: (ServerSetupIntent) -> Unit = {}, +) { + when (buildInfoProvider.type) { + BuildType.FOSS -> { + + } + + else -> LocalDiffusionForm( + modifier = modifier, + state = state, + buildInfoProvider = buildInfoProvider, + processIntent = processIntent, + ) + } +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt index f3a9ddc6..3cd1e443 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt @@ -5,8 +5,11 @@ import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState fun List.mapToUi(): List = map(LocalAiModel::mapToUi) -fun List.mapLocalCustomModelSwitchState(): Boolean = - find { it.selected && it.id == LocalAiModel.CUSTOM.id } != null +fun List.mapLocalCustomOnnxSwitchState(): Boolean = + find { it.selected && it.id == LocalAiModel.CustomOnnx.id } != null + +fun List.mapLocalCustomMediaPipeSwitchState(): Boolean = + find { it.selected && it.id == LocalAiModel.CustomMediaPipe.id } != null fun LocalAiModel.mapToUi(): ServerSetupState.LocalModel = with(this) { ServerSetupState.LocalModel( diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt new file mode 100644 index 00000000..1dd0b9c2 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt @@ -0,0 +1,9 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.mappers + +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider +import com.shifthackz.aisdv1.domain.entity.ServerSource + +val BuildInfoProvider.allowedModes: List + get() = ServerSource + .entries + .filter { it.allowedInBuilds.contains(type) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt index 4d53ca89..d25e8237 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt @@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.forms.Automatic1111Form import com.shifthackz.aisdv1.presentation.screen.setup.forms.HordeForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.HuggingFaceForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.LocalDiffusionForm +import com.shifthackz.aisdv1.presentation.screen.setup.forms.MediaPipeForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.OpenAiForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.StabilityAiForm import com.shifthackz.aisdv1.presentation.screen.setup.forms.SwarmUiForm @@ -33,7 +34,7 @@ fun ConfigurationStep( processIntent = processIntent, ) - ServerSource.LOCAL -> LocalDiffusionForm( + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalDiffusionForm( state = state, buildInfoProvider = buildInfoProvider, processIntent = processIntent, @@ -58,6 +59,12 @@ fun ConfigurationStep( state = state, processIntent = processIntent, ) + + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> MediaPipeForm( + state = state, + buildInfoProvider = buildInfoProvider, + processIntent = processIntent, + ) } } } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt index 54c9f1b0..f7779458 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt @@ -5,16 +5,23 @@ import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.padding import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.LazyListItemInfo import androidx.compose.foundation.lazy.items import androidx.compose.foundation.lazy.rememberLazyListState import androidx.compose.runtime.Composable import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableIntStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue import androidx.compose.ui.Modifier +import androidx.compose.ui.layout.onSizeChanged import androidx.compose.ui.unit.dp import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState import com.shifthackz.aisdv1.presentation.screen.setup.components.ConfigurationModeButton +import kotlin.math.abs @Composable fun SourceSelectionStep( @@ -23,12 +30,26 @@ fun SourceSelectionStep( processIntent: (ServerSetupIntent) -> Unit = {}, ) { val lazyListState = rememberLazyListState() + var lazyListHeight by remember { mutableIntStateOf(0) } + var lazyListItemHeight by remember { mutableIntStateOf(0) } + LaunchedEffect(state.mode) { // Adding 1 here, because item with index == 0 is top spacer - lazyListState.animateScrollToItem(state.mode.ordinal + 1) + val newIndex = state.mode.ordinal +1 + val visibleIndexes = lazyListState.layoutInfo.visibleItemsInfo + .filter { it.offset >= 0 } + .filter { + if (lazyListHeight == 0 || lazyListItemHeight == 0) true + else abs(lazyListHeight - it.offset) >= lazyListItemHeight + } + .map(LazyListItemInfo::index) + + if (!visibleIndexes.contains(newIndex)) lazyListState.animateScrollToItem(newIndex) } + LazyColumn( - modifier = modifier, + modifier = modifier + .onSizeChanged { lazyListHeight = it.height }, state = lazyListState, ) { item(key = "SPACER_TOP") { Spacer(modifier = Modifier.height(12.dp)) } @@ -39,7 +60,8 @@ fun SourceSelectionStep( ConfigurationModeButton( modifier = Modifier .fillMaxWidth() - .padding(horizontal = 16.dp, vertical = 4.dp), + .padding(horizontal = 16.dp, vertical = 4.dp) + .onSizeChanged { lazyListItemHeight = it.height }, state = state, mode = mode, onClick = { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt index 2dad0725..04466de6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt @@ -133,7 +133,7 @@ fun TextToImageState.mapToPayload(): TextToImagePayload = with(this) { subSeedStrength = subSeedStrength, sampler = selectedSampler, nsfw = if (mode == ServerSource.HORDE) nsfw else false, - batchCount = if (mode == ServerSource.LOCAL) 1 else batchCount, + batchCount = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) 1 else batchCount, style = openAiStyle.key.takeIf { mode == ServerSource.OPEN_AI && openAiModel == OpenAiModel.DALL_E_3 }, diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt index 515c3248..52f063ed 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt @@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.notification.PushNotificationManager import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager @@ -66,11 +67,13 @@ class TextToImageViewModel( ) { private val progressModal: Modal - get() { - if (currentState.mode == ServerSource.LOCAL) { - return Modal.Generating(canCancel = preferenceManager.localDiffusionAllowCancel) + get() = when (currentState.mode) { + ServerSource.LOCAL_MICROSOFT_ONNX, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> { + Modal.Generating(canCancel = preferenceManager.localOnnxAllowCancel) } - return Modal.Communicating() + + else -> Modal.Communicating() } override val initialState = TextToImageState() @@ -135,7 +138,7 @@ class TextToImageViewModel( ?.let(::setActiveModal) } - override fun onReceivedLocalDiffusionStatus(status: LocalDiffusion.Status) { + override fun onReceivedLocalDiffusionStatus(status: LocalDiffusionStatus) { (currentState.screenModal as? Modal.Generating) ?.copy(status = status) ?.let(::setActiveModal) diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt index ae6e5173..bf787e02 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt @@ -54,7 +54,7 @@ fun EngineSelectionComponent( onItemSelected = { intentHandler(EngineSelectionIntent(it)) }, ) - ServerSource.LOCAL -> DropdownTextField( + ServerSource.LOCAL_MICROSOFT_ONNX -> DropdownTextField( label = LocalizationR.string.hint_sd_model.asUiText(), loading = state.loading, modifier = modifier, @@ -64,6 +64,7 @@ fun EngineSelectionComponent( displayDelegate = { it.name.asUiText() }, ) + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit ServerSource.HORDE -> Unit ServerSource.OPEN_AI -> Unit } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt index 294a7c6c..9bb71002 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt @@ -11,7 +11,7 @@ import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase @@ -25,7 +25,7 @@ import io.reactivex.rxjava3.kotlin.subscribeBy class EngineSelectionViewModel( dispatchersProvider: DispatchersProvider, fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase, - observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase, + observeLocalOnnxModelsUseCase: ObserveLocalOnnxModelsUseCase, fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, private val preferenceManager: PreferenceManager, @@ -61,8 +61,8 @@ class EngineSelectionViewModel( .onErrorReturn { emptyList() } .toFlowable() - val localAiModels = observeLocalAiModelsUseCase() - .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CUSTOM.id } } + val localAiModels = observeLocalOnnxModelsUseCase() + .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CustomOnnx.id } } .onErrorReturn { emptyList() } !Flowable.combineLatest( @@ -94,7 +94,7 @@ class EngineSelectionViewModel( stEngines = stEngines.map { it.id }, selectedStEngine = config.stabilityAiEngineId, localAiModels = localModels, - selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localModelId }?.id + selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localOnnxModelId }?.id ?: state.selectedLocalAiModelId ) } @@ -131,7 +131,7 @@ class EngineSelectionViewModel( ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value - ServerSource.LOCAL -> preferenceManager.localModelId = intent.value + ServerSource.LOCAL_MICROSOFT_ONNX -> preferenceManager.localOnnxModelId = intent.value else -> Unit } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt index 1a8583d5..fe9b4914 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt @@ -148,7 +148,7 @@ fun GenerationInputForm( ServerSource.SWARM_UI, ServerSource.STABILITY_AI, ServerSource.HUGGING_FACE, - ServerSource.LOCAL -> EngineSelectionComponent( + ServerSource.LOCAL_MICROSOFT_ONNX -> EngineSelectionComponent( modifier = Modifier .fillMaxWidth() .padding(top = 8.dp), @@ -206,7 +206,7 @@ fun GenerationInputForm( ServerSource.SWARM_UI, ServerSource.HUGGING_FACE, ServerSource.STABILITY_AI, - ServerSource.LOCAL -> { + ServerSource.LOCAL_MICROSOFT_ONNX -> { if (state.formPromptTaggedInput) { ChipTextFieldWithItem( modifier = Modifier @@ -256,7 +256,7 @@ fun GenerationInputForm( when (state.mode) { ServerSource.HORDE, - ServerSource.LOCAL -> { + ServerSource.LOCAL_MICROSOFT_ONNX -> { DropdownTextField( modifier = localModifier.padding(end = 4.dp), label = LocalizationR.string.width.asUiText(), @@ -295,6 +295,7 @@ fun GenerationInputForm( displayDelegate = { it.key.asUiText() }, ) } + else -> Unit } } @@ -497,9 +498,11 @@ fun GenerationInputForm( else -> Unit } + //Steps not available for open ai if (state.mode != ServerSource.OPEN_AI) { val stepsMax = when (state.mode) { - ServerSource.LOCAL -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX + ServerSource.LOCAL_MICROSOFT_ONNX -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX ServerSource.STABILITY_AI -> SAMPLING_STEPS_RANGE_STABILITY_AI_MAX else -> SAMPLING_STEPS_RANGE_MAX } @@ -519,24 +522,31 @@ fun GenerationInputForm( processIntent(GenerationMviIntent.Update.SamplingSteps(it.roundToInt())) }, ) + } - Text( - modifier = Modifier.padding(top = 8.dp), - text = stringResource( - LocalizationR.string.hint_cfg_scale, - "${state.cfgScale.roundTo(2)}", - ), - ) - SliderTextInputField( - value = state.cfgScale, - valueRange = (CFG_SCALE_RANGE_MIN * 1f)..(CFG_SCALE_RANGE_MAX * 1f), - valueDiff = 0.5f, - steps = abs(CFG_SCALE_RANGE_MAX - CFG_SCALE_RANGE_MIN) * 2 - 1, - sliderColors = sliderColors, - onValueChange = { - processIntent(GenerationMviIntent.Update.CfgScale(it)) - }, - ) + // CFG scale not available on open ai and google media pipe + when (state.mode) { + ServerSource.OPEN_AI, + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit + else -> { + Text( + modifier = Modifier.padding(top = 8.dp), + text = stringResource( + LocalizationR.string.hint_cfg_scale, + "${state.cfgScale.roundTo(2)}", + ), + ) + SliderTextInputField( + value = state.cfgScale, + valueRange = (CFG_SCALE_RANGE_MIN * 1f)..(CFG_SCALE_RANGE_MAX * 1f), + valueDiff = 0.5f, + steps = abs(CFG_SCALE_RANGE_MAX - CFG_SCALE_RANGE_MIN) * 2 - 1, + sliderColors = sliderColors, + onValueChange = { + processIntent(GenerationMviIntent.Update.CfgScale(it)) + }, + ) + } } when (state.mode) { @@ -548,9 +558,10 @@ fun GenerationInputForm( else -> Unit } - // Batch is not available for Local Diffusion - if (state.mode != ServerSource.LOCAL) { - batchComponent() + // Batch is not available for any Local + when (state.mode) { + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.LOCAL_MICROSOFT_ONNX -> Unit + else -> batchComponent() } //Restore faces available only for A1111 if (state.mode == ServerSource.AUTOMATIC1111) { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt index 4175e70f..8e3964c9 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt @@ -94,9 +94,9 @@ fun SliderTextInputField( enabled = true, singleLine = true, keyboardOptions = KeyboardOptions( + autoCorrectEnabled = false, keyboardType = KeyboardType.Number, - autoCorrect = false, - imeAction = ImeAction.Done, + imeAction = ImeAction.Done ), label = { Text(stringResource(id = R.string.hint_value)) }, trailingIcon = { diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt index fdca7744..4ee15fd4 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt @@ -15,7 +15,8 @@ fun ServerSource.getName(): String { fun ServerSource.getNameUiText(): UiText = when (this) { ServerSource.AUTOMATIC1111 -> LocalizationR.string.srv_type_own ServerSource.HORDE -> LocalizationR.string.srv_type_horde - ServerSource.LOCAL -> LocalizationR.string.srv_type_local + ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local + ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt index e6c332c2..aa25ae1c 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt @@ -4,8 +4,8 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.notification.PushNotificationManager import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus +import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus import com.shifthackz.aisdv1.domain.entity.Settings -import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor @@ -57,7 +57,7 @@ abstract class CoreGenerationMviViewModelTest() private val stubHordeProcessStatus = BehaviorSubject.create() - private val stubLdStatus = BehaviorSubject.create() + private val stubLdStatus = BehaviorSubject.create() protected val stubCustomSchedulers = object : SchedulersProvider { diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt index 8b9eb04a..62884de9 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt @@ -5,9 +5,10 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState val mockLocalAiModels = listOf( - LocalAiModel.CUSTOM, + LocalAiModel.CustomOnnx, LocalAiModel( id = "1", + type = LocalAiModel.Type.ONNX, name = "Model 1", size = "5 Gb", sources = listOf("https://example.com/1.html"), diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt index 05e9a78f..d9453b19 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt @@ -237,13 +237,13 @@ class SettingsViewModelTest : CoreViewModelTest() { @Test fun `given received UpdateFlag NNAPI intent, expected localUseNNAPI preference updated`() { every { - stubPreferenceManager::localUseNNAPI.set(any()) + stubPreferenceManager::localOnnxUseNNAPI.set(any()) } returns Unit viewModel.processIntent(SettingsIntent.UpdateFlag.NNAPI(true)) verify { - stubPreferenceManager::localUseNNAPI.set(true) + stubPreferenceManager::localOnnxUseNNAPI.set(true) } } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt index c82a1dbc..27eb75b1 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt @@ -107,7 +107,7 @@ class ServerSetupScreenTest : CoreComposeTest { stubUiState.update { it.copy( step = ServerSetupState.Step.CONFIGURE, - mode = ServerSource.LOCAL + mode = ServerSource.LOCAL_MICROSOFT_ONNX ) } @@ -125,8 +125,8 @@ class ServerSetupScreenTest : CoreComposeTest { stubUiState.update { it.copy( step = ServerSetupState.Step.CONFIGURE, - mode = ServerSource.LOCAL, - localModels = mockLocalAiModels.mapToUi() + mode = ServerSource.LOCAL_MICROSOFT_ONNX, + localOnnxModels = mockLocalAiModels.mapToUi() ) } val setupButton = onNodeWithTestTag(ServerSetupScreenTags.MAIN_BUTTON) @@ -143,9 +143,9 @@ class ServerSetupScreenTest : CoreComposeTest { switch.performClick() stubUiState.update { it.copy( - localCustomModel = true, - localModels = it.localModels.withNewState( - it.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }!!.copy( + localOnnxCustomModel = true, + localOnnxModels = it.localOnnxModels.withNewState( + it.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }!!.copy( selected = true, downloaded = true ), @@ -165,9 +165,9 @@ class ServerSetupScreenTest : CoreComposeTest { switch.performClick() stubUiState.update { it.copy( - localCustomModel = false, - localModels = it.localModels.withNewState( - it.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }!!.copy( + localOnnxCustomModel = false, + localOnnxModels = it.localOnnxModels.withNewState( + it.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }!!.copy( selected = false, ), ), diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index a75bad5b..f6ae1a88 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -1,5 +1,6 @@ package com.shifthackz.aisdv1.presentation.screen.setup +import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator import com.shifthackz.aisdv1.core.validation.path.FilePathValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidator @@ -11,7 +12,8 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase -import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest @@ -39,7 +41,8 @@ import org.junit.Test class ServerSetupViewModelTest : CoreViewModelTest() { private val stubGetConfigurationUseCase = mockk() - private val stubGetLocalAiModelsUseCase = mockk() + private val stubGetLocalOnnxModelsUseCase = mockk() + private val stubGetLocalMediaPipeModelsUseCase = mockk() private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubUrlValidator = mockk() private val stubCommonStringValidator = mockk() @@ -55,7 +58,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() { launchSource = LaunchSource.SETTINGS, dispatchersProvider = stubDispatchersProvider, getConfigurationUseCase = stubGetConfigurationUseCase, - getLocalAiModelsUseCase = stubGetLocalAiModelsUseCase, + getLocalOnnxModelsUseCase = stubGetLocalOnnxModelsUseCase, + getLocalMediaPipeModelsUseCase = stubGetLocalMediaPipeModelsUseCase, fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, urlValidator = stubUrlValidator, stringValidator = stubCommonStringValidator, @@ -67,6 +71,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { preferenceManager = stubPreferenceManager, wakeLockInterActor = stubWakeLockInterActor, mainRouter = stubMainRouter, + buildInfoProvider = BuildInfoProvider.stub, ) @Before @@ -78,9 +83,13 @@ class ServerSetupViewModelTest : CoreViewModelTest() { } returns Single.just(Configuration(serverUrl = "https://5598.is.my.favorite.com")) every { - stubGetLocalAiModelsUseCase() + stubGetLocalOnnxModelsUseCase() } returns Single.just(mockLocalAiModels) + every { + stubGetLocalMediaPipeModelsUseCase() + } returns Single.just(emptyList()) + every { stubFetchAndGetHuggingFaceModelsUseCase() } returns Single.just(mockHuggingFaceModels) @@ -96,13 +105,14 @@ class ServerSetupViewModelTest : CoreViewModelTest() { fun `initialized, expected UI state updated with correct stub values`() { val state = viewModel.state.value Assert.assertEquals(true, state.huggingFaceModels.isNotEmpty()) - Assert.assertEquals(true, state.localModels.isNotEmpty()) + Assert.assertEquals(true, state.localOnnxModels.isNotEmpty()) Assert.assertEquals("https://5598.is.my.favorite.com", state.serverUrl) Assert.assertEquals(ServerSetupState.AuthType.ANONYMOUS, state.authType) } @Test fun `given received AllowLocalCustomModel intent, expected Custom local model selected in UI state`() { + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) viewModel.processIntent(ServerSetupIntent.AllowLocalCustomModel(true)) val state = viewModel.state.value val expectedLocalModels = listOf( @@ -121,8 +131,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() { selected = false, ) ) - Assert.assertEquals(true, state.localCustomModel) - Assert.assertEquals(expectedLocalModels, state.localModels) + Assert.assertEquals(true, state.localOnnxCustomModel) + Assert.assertEquals(expectedLocalModels, state.localOnnxModels) } @Test @@ -147,6 +157,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { stubWakeLockInterActor.releaseWakeLockUseCase() } returns Result.success(Unit) + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) val localModel = mockServerSetupStateLocalModel.copy( downloadState = DownloadState.Unknown, ) @@ -199,7 +210,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { val state = viewModel.state.value val expected = false - val actual = state.localModels.any { + val actual = state.localOnnxModels.any { it.downloadState == DownloadState.Downloading(22) } Assert.assertEquals(expected, actual) @@ -221,7 +232,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { runTest { val state = viewModel.state.value Assert.assertEquals(Modal.None, state.screenModal) - Assert.assertEquals(false, state.localModels.find { it.id == "1" }!!.downloaded) + Assert.assertEquals(false, state.localOnnxModels.find { it.id == "1" }!!.downloaded) } verify { stubDeleteModelUseCase("1") @@ -230,9 +241,10 @@ class ServerSetupViewModelTest : CoreViewModelTest() { @Test fun `given received SelectLocalModel intent, expected passed LocalModel is selected in UI state`() { + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) viewModel.processIntent(ServerSetupIntent.SelectLocalModel(mockServerSetupStateLocalModel)) val state = viewModel.state.value - Assert.assertEquals(true, state.localModels.find { it.id == "1" }!!.selected) + Assert.assertEquals(true, state.localOnnxModels.find { it.id == "1" }!!.selected) } @Test @@ -317,8 +329,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() { @Test fun `given received UpdateServerMode intent, expected mode field in UI state is LOCAL`() { - viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL)) - val expected = ServerSource.LOCAL + viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX)) + val expected = ServerSource.LOCAL_MICROSOFT_ONNX val actual = viewModel.state.value.mode Assert.assertEquals(expected, actual) } diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt index 855a9128..269474df 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt @@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.domain.entity.ServerSource import com.shifthackz.aisdv1.domain.entity.Settings import com.shifthackz.aisdv1.domain.preference.PreferenceManager -import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase +import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase @@ -44,7 +44,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest private val stubGetConfigurationUseCase = mockk() private val stubSelectStableDiffusionModelUseCase = mockk() private val stubGetStableDiffusionModelsUseCase = mockk() - private val stubObserveLocalAiModelsUseCase = mockk() + private val stubObserveLocalAiModelsUseCase = mockk() private val stubFetchAndGetStabilityAiEnginesUseCase = mockk() private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubFetchAndGetSwarmUiModelsUseCase = mockk() @@ -56,7 +56,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest getConfigurationUseCase = stubGetConfigurationUseCase, selectStableDiffusionModelUseCase = stubSelectStableDiffusionModelUseCase, getStableDiffusionModelsUseCase = stubGetStableDiffusionModelsUseCase, - observeLocalAiModelsUseCase = stubObserveLocalAiModelsUseCase, + observeLocalOnnxModelsUseCase = stubObserveLocalAiModelsUseCase, fetchAndGetStabilityAiEnginesUseCase = stubFetchAndGetStabilityAiEnginesUseCase, getHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, fetchAndGetSwarmUiModelsUseCase = stubFetchAndGetSwarmUiModelsUseCase, @@ -108,7 +108,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest selectedHfModel = "prompthero/openjourney-v4", stEngines = listOf("5598"), selectedStEngine = "5598", - localAiModels = listOf(LocalAiModel.CUSTOM), + localAiModels = listOf(LocalAiModel.CustomOnnx), selectedLocalAiModelId = "CUSTOM", swarmModels = listOf("5598"), selectedSwarmModel = "5598", @@ -229,16 +229,16 @@ class EngineSelectionViewModelTest : CoreViewModelTest @Test fun `given received EngineSelectionIntent, source is LOCAL, expected localModelId changed in preference`() { - mockInitialData(DataTestCase.Mock, ServerSource.LOCAL) + mockInitialData(DataTestCase.Mock, ServerSource.LOCAL_MICROSOFT_ONNX) every { - stubPreferenceManager::localModelId.set(any()) + stubPreferenceManager::localOnnxModelId.set(any()) } returns Unit viewModel.processIntent(EngineSelectionIntent("llm_5598")) verify { - stubPreferenceManager::localModelId.set("llm_5598") + stubPreferenceManager::localOnnxModelId.set("llm_5598") } } @@ -262,7 +262,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest huggingFaceModel = "prompthero/openjourney-v4", stabilityAiEngineId = "5598", swarmUiModel = "5598", - localModelId = "CUSTOM", + localOnnxModelId = "CUSTOM", source = source, ), ) diff --git a/settings.gradle.kts b/settings.gradle.kts index 6ee0480b..fca65518 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -1,8 +1,8 @@ pluginManagement { includeBuild("build-logic") repositories { - google() mavenCentral() + google() gradlePluginPortal() maven { url = uri("https://jitpack.io") @@ -12,8 +12,8 @@ pluginManagement { dependencyResolutionManagement { repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositories { - google() mavenCentral() + google() maven { url = uri("https://jitpack.io") } @@ -35,6 +35,7 @@ val modules = listOf( ":domain", ":feature:auth", ":feature:diffusion", + ":feature:mediapipe", ":feature:work", ":network", ":presentation", diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json new file mode 100644 index 00000000..7f7ac2a7 --- /dev/null +++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json @@ -0,0 +1,254 @@ +{ + "formatVersion": 1, + "database": { + "version": 6, + "identityHash": "6f6ccee56637122e0126c09bb3eb3fdc", + "entities": [ + { + "tableName": "generation_results", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `image_base_64` TEXT NOT NULL, `original_image_base_64` TEXT NOT NULL, `created_at` INTEGER NOT NULL, `generation_type` TEXT NOT NULL, `prompt` TEXT NOT NULL, `negative_prompt` TEXT NOT NULL, `width` INTEGER NOT NULL, `height` INTEGER NOT NULL, `sampling_steps` INTEGER NOT NULL, `cfg_scale` REAL NOT NULL, `restore_faces` INTEGER NOT NULL, `sampler` TEXT NOT NULL, `seed` TEXT NOT NULL, `sub_seed` TEXT NOT NULL DEFAULT '', `sub_seed_strength` REAL NOT NULL DEFAULT 0.0, `denoising_strength` REAL NOT NULL DEFAULT 0.0)", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "imageBase64", + "columnName": "image_base_64", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "originalImageBase64", + "columnName": "original_image_base_64", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "createdAt", + "columnName": "created_at", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "generationType", + "columnName": "generation_type", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "prompt", + "columnName": "prompt", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "negativePrompt", + "columnName": "negative_prompt", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "width", + "columnName": "width", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "height", + "columnName": "height", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "samplingSteps", + "columnName": "sampling_steps", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "cfgScale", + "columnName": "cfg_scale", + "affinity": "REAL", + "notNull": true + }, + { + "fieldPath": "restoreFaces", + "columnName": "restore_faces", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "sampler", + "columnName": "sampler", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "seed", + "columnName": "seed", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "subSeed", + "columnName": "sub_seed", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "''" + }, + { + "fieldPath": "subSeedStrength", + "columnName": "sub_seed_strength", + "affinity": "REAL", + "notNull": true, + "defaultValue": "0.0" + }, + { + "fieldPath": "denoisingStrength", + "columnName": "denoising_strength", + "affinity": "REAL", + "notNull": true, + "defaultValue": "0.0" + } + ], + "primaryKey": { + "autoGenerate": true, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "local_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `type` TEXT NOT NULL DEFAULT 'onnx', `name` TEXT NOT NULL, `size` TEXT NOT NULL, `sources` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "type", + "columnName": "type", + "affinity": "TEXT", + "notNull": true, + "defaultValue": "'onnx'" + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "size", + "columnName": "size", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "sources", + "columnName": "sources", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "hugging_face_models", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `name` TEXT NOT NULL, `alias` TEXT NOT NULL, `source` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "alias", + "columnName": "alias", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "source", + "columnName": "source", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + }, + { + "tableName": "supporters", + "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER NOT NULL, `name` TEXT NOT NULL, `date` INTEGER NOT NULL, `message` TEXT NOT NULL, PRIMARY KEY(`id`))", + "fields": [ + { + "fieldPath": "id", + "columnName": "id", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "name", + "columnName": "name", + "affinity": "TEXT", + "notNull": true + }, + { + "fieldPath": "date", + "columnName": "date", + "affinity": "INTEGER", + "notNull": true + }, + { + "fieldPath": "message", + "columnName": "message", + "affinity": "TEXT", + "notNull": true + } + ], + "primaryKey": { + "autoGenerate": false, + "columnNames": [ + "id" + ] + }, + "indices": [], + "foreignKeys": [] + } + ], + "views": [], + "setupQueries": [ + "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)", + "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '6f6ccee56637122e0126c09bb3eb3fdc')" + ] + } +} \ No newline at end of file diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt index df8eada6..1a59785c 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt @@ -4,18 +4,11 @@ import androidx.room.AutoMigration import androidx.room.Database import androidx.room.RoomDatabase import androidx.room.TypeConverters -import com.shifthackz.aisdv1.storage.converters.DateConverters -import com.shifthackz.aisdv1.storage.converters.ListConverters +import com.shifthackz.aisdv1.storage.converters.* import com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase.Companion.DB_VERSION -import com.shifthackz.aisdv1.storage.db.persistent.contract.GenerationResultContract -import com.shifthackz.aisdv1.storage.db.persistent.dao.GenerationResultDao -import com.shifthackz.aisdv1.storage.db.persistent.dao.HuggingFaceModelDao -import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao -import com.shifthackz.aisdv1.storage.db.persistent.dao.SupporterDao -import com.shifthackz.aisdv1.storage.db.persistent.entity.GenerationResultEntity -import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity -import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity -import com.shifthackz.aisdv1.storage.db.persistent.entity.SupporterEntity +import com.shifthackz.aisdv1.storage.db.persistent.contract.* +import com.shifthackz.aisdv1.storage.db.persistent.dao.* +import com.shifthackz.aisdv1.storage.db.persistent.entity.* @Database( version = DB_VERSION, @@ -46,6 +39,11 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.SupporterEntity * Added [SupporterEntity]. */ AutoMigration(from = 4, to = 5), + /** + * Added 1 field to [LocalModelEntity]: + * - [LocalModelContract.TYPE] + */ + AutoMigration(from = 5, to = 6), ], ) @TypeConverters( @@ -60,6 +58,6 @@ internal abstract class PersistentDatabase : RoomDatabase() { companion object { const val DB_NAME = "ai_sd_v1_storage_db" - const val DB_VERSION = 5 + const val DB_VERSION = 6 } } diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt index 03b3c03a..2ffebf43 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt @@ -4,6 +4,7 @@ object LocalModelContract { const val TABLE = "local_models" const val ID = "id" + const val TYPE = "type" const val NAME = "name" const val SIZE = "size" const val SOURCES = "sources" diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt index 435d4c69..a2b94a16 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt @@ -16,9 +16,15 @@ interface LocalModelDao { @Query("SELECT * FROM ${LocalModelContract.TABLE}") fun query(): Single> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.TYPE} = :type") + fun queryByType(type: String): Single> + @Query("SELECT * FROM ${LocalModelContract.TABLE}") fun observe(): Flowable> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.TYPE} = :type") + fun observeByType(type: String): Flowable> + @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.ID} = :id LIMIT 1") fun queryById(id: String): Single diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt index ab896641..d69856d2 100644 --- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt +++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt @@ -10,6 +10,8 @@ data class LocalModelEntity( @PrimaryKey(autoGenerate = false) @ColumnInfo(name = LocalModelContract.ID) val id: String, + @ColumnInfo(name = LocalModelContract.TYPE, defaultValue = "onnx") + val type: String, @ColumnInfo(name = LocalModelContract.NAME) val name: String, @ColumnInfo(name = LocalModelContract.SIZE)