diff --git a/.gitignore b/.gitignore
index 5076b36e..9590699f 100755
--- a/.gitignore
+++ b/.gitignore
@@ -44,6 +44,7 @@ captures/
!/.idea/vcs.xml
!/.idea/fileTemplates/
!/.idea/inspectionProfiles/
+!/.idea/icon.svg
!/.idea/scopes/
!/.idea/codeStyleSettings.xml
!/.idea/encodings.xml
diff --git a/.idea/icon.svg b/.idea/icon.svg
new file mode 100644
index 00000000..7f7eab01
--- /dev/null
+++ b/.idea/icon.svg
@@ -0,0 +1,106 @@
+
+
diff --git a/app/build.gradle.kts b/app/build.gradle.kts
index 5aeb348d..9ad3e14a 100755
--- a/app/build.gradle.kts
+++ b/app/build.gradle.kts
@@ -51,27 +51,6 @@ android {
println("[Signature] -> Build will be signed with signature: $alias")
buildTypes.getByName("release").signingConfig = signingConfigs.getByName("release")
}
-
- flavorDimensions += "type"
- productFlavors {
- create("dev") {
- dimension = "type"
- applicationIdSuffix = ".dev"
- resValue("string", "app_name", "SDAI Dev")
- buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"")
- }
- create("foss") {
- dimension = "type"
- applicationIdSuffix = ".foss"
- resValue("string", "app_name", "SDAI FOSS")
- buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"")
- }
- create("playstore") {
- dimension = "type"
- resValue("string", "app_name", "SDAI")
- buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"")
- }
- }
}
dependencies {
@@ -85,6 +64,7 @@ dependencies {
implementation(project(":domain"))
implementation(project(":feature:auth"))
implementation(project(":feature:diffusion"))
+ implementation(project(":feature:mediapipe"))
implementation(project(":feature:work"))
implementation(project(":data"))
implementation(project(":demo"))
diff --git a/app/src/dev/AndroidManifest.xml b/app/src/full/AndroidManifest.xml
similarity index 100%
rename from app/src/dev/AndroidManifest.xml
rename to app/src/full/AndroidManifest.xml
diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml
index 67b1befd..4aac9eb0 100755
--- a/app/src/main/AndroidManifest.xml
+++ b/app/src/main/AndroidManifest.xml
@@ -12,6 +12,10 @@
android:theme="@style/Theme.AiSdCompose.Splash"
android:usesCleartextTraffic="true">
+
+
+
+
append(" FULL")
+ BuildType.FOSS -> append(" FOSS")
+ BuildType.PLAY -> Unit
+ }
}
}
}
@@ -172,14 +176,14 @@ val providersModule = module {
single {
DeviceNNAPIFlagProvider {
- get().localUseNNAPI
+ get().localOnnxUseNNAPI
.let { nnApi -> if (nnApi) LocalDiffusionFlag.NN_API else LocalDiffusionFlag.CPU }
.let(LocalDiffusionFlag::value)
}
}
single {
- LocalModelIdProvider { get().localModelId }
+ LocalModelIdProvider { get().localOnnxModelId }
}
single {
diff --git a/build-logic/convention/build.gradle.kts b/build-logic/convention/build.gradle.kts
index 75b16dee..2d730493 100644
--- a/build-logic/convention/build.gradle.kts
+++ b/build-logic/convention/build.gradle.kts
@@ -36,6 +36,10 @@ gradlePlugin {
id = "generic.application"
implementationClass = "ApplicationConventionPlugin"
}
+ register("Flavors") {
+ id = "generic.flavors"
+ implementationClass = "FlavorsConventionPlugin"
+ }
register("BaselineProFm") {
id = "generic.baseline.profm"
implementationClass = "BaselineProFmConventionPlugin"
diff --git a/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt b/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt
index 2b126264..34cbad9a 100644
--- a/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt
+++ b/build-logic/convention/src/main/kotlin/ApplicationConventionPlugin.kt
@@ -1,6 +1,8 @@
import com.android.build.api.dsl.ApplicationExtension
+import com.android.build.gradle.BaseExtension
import com.shifthackz.aisdv1.buildlogic.configureApplication
import com.shifthackz.aisdv1.buildlogic.configureCompose
+import com.shifthackz.aisdv1.buildlogic.configureFlavors
import com.shifthackz.aisdv1.buildlogic.libs
import org.gradle.api.Plugin
import org.gradle.api.Project
@@ -22,6 +24,9 @@ class ApplicationConventionPlugin : Plugin {
configureCompose(this)
defaultConfig.targetSdk = libs.findVersion("targetSdk").get().toString().toInt()
}
+ extensions.configure {
+ configureFlavors(this)
+ }
}
}
}
diff --git a/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt b/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt
new file mode 100644
index 00000000..743465f1
--- /dev/null
+++ b/build-logic/convention/src/main/kotlin/FlavorsConventionPlugin.kt
@@ -0,0 +1,16 @@
+import com.android.build.gradle.LibraryExtension
+import com.shifthackz.aisdv1.buildlogic.configureFlavorsCommon
+import org.gradle.api.Plugin
+import org.gradle.api.Project
+import org.gradle.kotlin.dsl.configure
+
+class FlavorsConventionPlugin : Plugin {
+
+ override fun apply(target: Project) {
+ with(target) {
+ extensions.configure {
+ configureFlavorsCommon(this)
+ }
+ }
+ }
+}
diff --git a/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt b/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt
index 5bc8f1b3..6e83a3fc 100644
--- a/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt
+++ b/build-logic/convention/src/main/kotlin/JacocoConventionPlugin.kt
@@ -7,7 +7,7 @@ import org.gradle.kotlin.dsl.configure
import org.gradle.kotlin.dsl.withType
import org.gradle.testing.jacoco.plugins.JacocoTaskExtension
-class JacocoConventionPlugin : Plugin {
+class JacocoConventionPlugin : Plugin {
override fun apply(target: Project) {
with(target) {
diff --git a/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt
new file mode 100644
index 00000000..6357b9ff
--- /dev/null
+++ b/build-logic/convention/src/main/kotlin/com/shifthackz/aisdv1/buildlogic/Flavors.kt
@@ -0,0 +1,43 @@
+package com.shifthackz.aisdv1.buildlogic
+
+import com.android.build.api.dsl.CommonExtension
+import com.android.build.gradle.BaseExtension
+import org.gradle.api.Project
+
+internal fun Project.configureFlavors(
+ commonExtension: BaseExtension,
+) {
+ commonExtension.apply {
+ flavorDimensions("type")
+ productFlavors.create("full") {
+ dimension = "type"
+ applicationIdSuffix = ".full"
+ resValue("string", "app_name", "SDAI Full")
+ buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FULL\"")
+
+ }
+ productFlavors.create("foss") {
+ dimension = "type"
+ applicationIdSuffix = ".foss"
+ resValue("string", "app_name", "SDAI FOSS")
+ buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"FOSS\"")
+
+ }
+ productFlavors.create("playstore") {
+ dimension = "type"
+ resValue("string", "app_name", "SDAI")
+ buildConfigField("String", "BUILD_FLAVOR_TYPE", "\"GOOGLE_PLAY\"")
+ }
+ }
+}
+
+internal fun Project.configureFlavorsCommon(
+ commonExtension: CommonExtension<*, *, *, *, *, *>,
+) {
+ commonExtension.apply {
+ flavorDimensions += listOf("type")
+ productFlavors.create("full") { dimension = "type" }
+ productFlavors.create("foss") { dimension = "type" }
+ productFlavors.create("playstore") { dimension = "type" }
+ }
+}
diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt
index d431d0f8..2f4c4471 100644
--- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt
+++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/appbuild/BuildType.kt
@@ -1,11 +1,13 @@
package com.shifthackz.aisdv1.core.common.appbuild
enum class BuildType {
+ FULL,
FOSS,
PLAY;
companion object {
fun fromBuildConfig(input: String) = when (input) {
+ "FULL" -> FULL
"FOSS" -> FOSS
else -> PLAY
}
diff --git a/core/localization/src/main/res/values-ru/strings.xml b/core/localization/src/main/res/values-ru/strings.xml
index acbaa8e7..ef0e1272 100644
--- a/core/localization/src/main/res/values-ru/strings.xml
+++ b/core/localization/src/main/res/values-ru/strings.xml
@@ -133,9 +133,11 @@
Укажите свйой URL-адрес Swarm UI
Модульный веб-интерфейс Stable Diffusion, в котором особое внимание уделяется обеспечению легкого доступа к инструментам, высокой производительности и расширяемости.
- Эта конфигурация позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку.
+ Эта конфигурация использует Microsoft ONNX и позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку.
ВНИМАНИЕ! Функциональность Local Diffusion в бета-тестировании. Не ожидайте высококачественных изображений в локальном режиме. \n\nЭта реализация может не работать должным образом на мобильных телефонах. Производительность и скорость генерации зависят от ресурсов вашего телефона (ЦП, ОЗУ) и размера сгенерированного изображения (чем меньше размер изображения, тем быстрее генерируется).
+ Эта конфигурация использует Google AI MediaPipe и позволяет запускать генерации Stable Diffusion на вашем телефоне без необходимости подключаться к удаленному серверу/облаку.
+
Веб
Txt2Img
diff --git a/core/localization/src/main/res/values-tr/strings.xml b/core/localization/src/main/res/values-tr/strings.xml
index eed6dea1..8319a6c1 100644
--- a/core/localization/src/main/res/values-tr/strings.xml
+++ b/core/localization/src/main/res/values-tr/strings.xml
@@ -133,9 +133,11 @@
Swarm UI URL\'nizi sağlayın
Araçlara kolay erişim, yüksek performans ve genişletilebilirlik üzerine odaklanan Modüler, Kararlı Yaygın Web Kullanıcı Arayüzü.
- Bu yapılandırma, telefonunuzda uzak sunucuya/buluta bağlanmaya gerek kalmadan Stable Diffusion AI nesillerini çalıştırmanıza izin verir.
+ Bu yapılandırma Microsoft ONNX çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır.
Uyarı! Yerel Yayılma işlevi beta testindedir. Yerel modu kullanarak yüksek kaliteli görüntüler beklemeyin. \n\nBu uygulama, güçlü olmayan telefonlarda iyi çalışmayabilir. Oluşturma performansı ve hızı, telefonunuzun kaynaklarına (CPU, RAM) ve oluşturulan görüntünün boyutuna bağlıdır (görüntü boyutu ne kadar küçükse, oluşturma o kadar hızlı olur).
+ Bu yapılandırma Google AI MediaPipe çalışma zamanını kullanır ve uzak bir sunucuya/buluta bağlanmaya gerek kalmadan telefonunuzda Stable Diffusion AI nesillerini çalıştırmanıza olanak tanır.
+
Web arayüzü
Txt2Img
diff --git a/core/localization/src/main/res/values-uk/strings.xml b/core/localization/src/main/res/values-uk/strings.xml
index da32118e..87fe4b5d 100644
--- a/core/localization/src/main/res/values-uk/strings.xml
+++ b/core/localization/src/main/res/values-uk/strings.xml
@@ -133,9 +133,11 @@
Provide your Swarm UI URL
Модульний веб-інтерфейс Stable Diffusion з наголосом на полегшення доступу до інструментів, високу продуктивність і розширюваність.
- Ця конфігурація дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари.
+ Ця конфігурація використовує Microsoft ONNX та дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари.
УВАГА! Функціональність Local Diffusion у бета-тестуванні. Не очікуйте високоякісних зображень у локальному режимі. \n\nЦя реалізація може не працювати належним чином на телефонах із слабкою потужністю. Продуктивність і швидкість генерації залежать від ресурсів вашого телефону (ЦП, ОЗУ) і розміру згенерованого зображення (чим менший розмір зображення, тим швидше генерується).
+ Ця конфігурація використовує Google AI MediaPipe та дозволяє запускати генерації Stable Diffusion на вашому телефоні без необхідності підключатися до віддаленого сервера/хмари.
+
Веб
Txt2Img
diff --git a/core/localization/src/main/res/values-zh/strings.xml b/core/localization/src/main/res/values-zh/strings.xml
index c3483d2e..7329dc23 100644
--- a/core/localization/src/main/res/values-zh/strings.xml
+++ b/core/localization/src/main/res/values-zh/strings.xml
@@ -167,9 +167,11 @@
本地扩散
- 此配置允许在您的手机上运行Stable Diffusion AI生成,无需连接到远程服务器/云。
+ 此配置使用 Microsoft ONNX 运行时,并允许在手机上运行稳定的 Diffusion AI 生成,无需连接到远程服务器/云。
警告!本地扩散功能处于测试版。不要期望使用本地模式获得高质量图像。 \n\n此实现可能在不强大的手机上运行不佳。生成性能和速度取决于您的手机资源(CPU、RAM)和生成的图像大小(图像越小,生成越快)。
+ 此配置使用 Google AI MediaPipe 运行时,并允许在手机上运行稳定的 Diffusion AI 生成,无需连接到远程服务器/云。
+
网络界面
diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml
index cba23819..b591977e 100755
--- a/core/localization/src/main/res/values/strings.xml
+++ b/core/localization/src/main/res/values/strings.xml
@@ -70,8 +70,10 @@
A1111
Horde AI Cloud
Horde
- Local Diffusion (Beta)
- Local
+ Local Diffusion Microsoft ONNX (Beta)
+ ONNX
+ Local Diffusion Google AI MediaPipe (Beta)
+ MediaPipe
Hugging Face Inference
HuggingFace
Open AI
@@ -150,10 +152,13 @@
Provide your Swarm UI URL
A Modular Stable Diffusion Web-User-Interface, with an emphasis on making tools easily accessible, high performance, and extensibility.
- Local Diffusion
- This configuration allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud.
+ Local Diffusion Microsoft ONNX
+ This configuration uses Microsoft ONNX runtime and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud.
Warning! Local Diffusion functionality is in beta-test. Don\'t expect for high quality images using local mode. \n\nThis implementation may not work well on non-powerful phones. Generation performance and speed depends on your phone resources (CPU, RAM) and the size of generated image (the smaller the image size, the faster the generation).
+ Local Diffusion Google AI MediaPipe
+ This configuration uses Google AI MediaPipe and allows to run Stable Diffusion AI generations on your phone, with no need to connect to remote server/cloud.
+
Web UI
Txt2Img
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt
index 6e08ea30..21d8edda 100755
--- a/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/di/RepositoryModule.kt
@@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.data.repository.HuggingFaceGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.HuggingFaceModelsRepositoryImpl
import com.shifthackz.aisdv1.data.repository.LocalDiffusionGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.LorasRepositoryImpl
+import com.shifthackz.aisdv1.data.repository.MediaPipeGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.OpenAiGenerationRepositoryImpl
import com.shifthackz.aisdv1.data.repository.RandomImageRepositoryImpl
import com.shifthackz.aisdv1.data.repository.ServerConfigurationRepositoryImpl
@@ -33,6 +34,7 @@ import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository
import com.shifthackz.aisdv1.domain.repository.HuggingFaceModelsRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LorasRepository
+import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository
import com.shifthackz.aisdv1.domain.repository.RandomImageRepository
import com.shifthackz.aisdv1.domain.repository.ServerConfigurationRepository
@@ -63,6 +65,7 @@ val repositoryModule = module {
singleOf(::TemporaryGenerationResultRepositoryImpl) bind TemporaryGenerationResultRepository::class
factoryOf(::LocalDiffusionGenerationRepositoryImpl) bind LocalDiffusionGenerationRepository::class
+ factoryOf(::MediaPipeGenerationRepositoryImpl) bind MediaPipeGenerationRepository::class
factoryOf(::HordeGenerationRepositoryImpl) bind HordeGenerationRepository::class
factoryOf(::HuggingFaceGenerationRepositoryImpl) bind HuggingFaceGenerationRepository::class
factoryOf(::OpenAiGenerationRepositoryImpl) bind OpenAiGenerationRepository::class
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt
index 0e4b9f02..4ab031c2 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSource.kt
@@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao
import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
import io.reactivex.rxjava3.core.Completable
+import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import java.io.File
@@ -22,72 +23,94 @@ internal class DownloadableModelLocalDataSource(
private val buildInfoProvider: BuildInfoProvider,
) : DownloadableModelDataSource.Local {
- override fun getAll() = dao
- .query()
+ override fun getAllOnnx() = dao
+ .queryByType(LocalAiModel.Type.ONNX.key)
.map(List::mapEntityToDomain)
.map { models ->
buildList {
addAll(models)
- if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM)
+ if (buildInfoProvider.type != BuildType.PLAY) {
+ add(LocalAiModel.CustomOnnx)
+ }
+ }
+ }
+ .flatMap { models -> models.withLocalData() }
+
+ override fun getAllMediaPipe(): Single> = dao
+ .queryByType(LocalAiModel.Type.MediaPipe.key)
+ .map(List::mapEntityToDomain)
+ .map { models ->
+ buildList {
+ addAll(models)
+ if (buildInfoProvider.type != BuildType.PLAY) {
+ add(LocalAiModel.CustomMediaPipe)
+ }
}
}
.flatMap { models -> models.withLocalData() }
override fun getById(id: String): Single {
- val chain = if (id == LocalAiModel.CUSTOM.id) {
- Single.just(LocalAiModel.CUSTOM)
- } else {
- dao
+ val chain = when (id) {
+ LocalAiModel.CustomOnnx.id -> Single.just(LocalAiModel.CustomOnnx)
+ LocalAiModel.CustomMediaPipe.id -> Single.just(LocalAiModel.CustomMediaPipe)
+ else -> dao
.queryById(id)
.map(LocalModelEntity::mapEntityToDomain)
}
-
return chain.flatMap { model -> model.withLocalData() }
}
- override fun getSelected() = Single
- .just(preferenceManager.localModelId)
- .onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) }
+ override fun getSelectedOnnx() = Single
+ .just(preferenceManager.localOnnxModelId)
.flatMap(::getById)
.onErrorResumeNext { Single.error(IllegalStateException("No selected model.")) }
- override fun observeAll() = dao
- .observe()
+ override fun observeAllOnnx(): Flowable> = dao
+ .observeByType(LocalAiModel.Type.ONNX.key)
.map(List::mapEntityToDomain)
.map { models ->
buildList {
addAll(models)
- if (buildInfoProvider.type == BuildType.FOSS) add(LocalAiModel.CUSTOM)
+ if (buildInfoProvider.type != BuildType.PLAY) add(LocalAiModel.CustomOnnx)
}
}
.flatMap { models -> models.withLocalData().toFlowable() }
- override fun select(id: String) = Completable.fromAction {
- preferenceManager.localModelId = id
- }
-
override fun save(list: List) = list
- .filter { it.id != LocalAiModel.CUSTOM.id }
+ .filter { it.id != LocalAiModel.CustomOnnx.id }
.mapDomainToEntity()
.let(dao::insertList)
- override fun isDownloaded(id: String) = Single.create { emitter ->
+ override fun delete(id: String): Completable = Completable.fromAction {
+ getLocalModelDirectory(id).deleteRecursively()
+ }
+
+ private fun isDownloaded(model: LocalAiModel) = Single.create { emitter ->
try {
- if (id == LocalAiModel.CUSTOM.id) {
- if (!emitter.isDisposed) emitter.onSuccess(true)
- } else {
- val files = getLocalModelFiles(id)
- if (!emitter.isDisposed) emitter.onSuccess(files.size == 4)
+ when (model.id) {
+ LocalAiModel.CustomOnnx.id,
+ LocalAiModel.CustomMediaPipe.id -> emitter.onSuccess(true)
+
+ else -> {
+
+ when (model.type) {
+ LocalAiModel.Type.ONNX -> {
+ val files = getLocalModelFiles(model.id).filter { it.isDirectory }
+ emitter.onSuccess(files.size == 4)
+ }
+
+ LocalAiModel.Type.MediaPipe -> {
+ val files = getLocalModelFiles(model.id)
+ emitter.onSuccess(files.isNotEmpty())
+ }
+ }
+ }
}
} catch (e: Exception) {
if (!emitter.isDisposed) emitter.onSuccess(false)
}
}
- override fun delete(id: String): Completable = Completable.fromAction {
- getLocalModelDirectory(id).deleteRecursively()
- }
-
private fun getLocalModelDirectory(id: String): File {
return File("${fileProviderDescriptor.localModelDirPath}/${id}")
}
@@ -95,7 +118,7 @@ internal class DownloadableModelLocalDataSource(
private fun getLocalModelFiles(id: String): List {
val localModelDir = getLocalModelDirectory(id)
if (!localModelDir.exists()) return emptyList()
- return localModelDir.listFiles()?.filter { it.isDirectory } ?: emptyList()
+ return localModelDir.listFiles()?.toList() ?: emptyList()
}
private fun List.withLocalData() = Observable
@@ -103,11 +126,14 @@ internal class DownloadableModelLocalDataSource(
.flatMapSingle { model -> model.withLocalData() }
.toList()
- private fun LocalAiModel.withLocalData() = isDownloaded(id)
+ private fun LocalAiModel.withLocalData() = isDownloaded(this)
.map { downloaded ->
copy(
downloaded = downloaded,
- selected = preferenceManager.localModelId == id,
+ selected = when (this.type) {
+ LocalAiModel.Type.ONNX -> preferenceManager.localOnnxModelId == id
+ LocalAiModel.Type.MediaPipe -> preferenceManager.localMediaPipeModelId == id
+ },
)
}
}
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt
index d9261085..7d7033aa 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/mappers/LocalAiModelMappers.kt
@@ -5,12 +5,16 @@ import com.shifthackz.aisdv1.network.response.DownloadableModelResponse
import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
//region RAW --> DOMAIN
-fun List.mapRawToCheckpointDomain(): List =
- map(DownloadableModelResponse::mapRawToCheckpointDomain)
+fun List.mapRawToCheckpointDomain(
+ type: LocalAiModel.Type,
+): List = map { it.mapRawToCheckpointDomain(type) }
-fun DownloadableModelResponse.mapRawToCheckpointDomain(): LocalAiModel = with(this) {
+fun DownloadableModelResponse.mapRawToCheckpointDomain(
+ type: LocalAiModel.Type,
+): LocalAiModel = with(this) {
LocalAiModel(
id = id ?: "",
+ type = type,
name = name ?: "",
size = size ?: "",
sources = sources ?: emptyList(),
@@ -23,7 +27,7 @@ fun List.mapDomainToEntity(): List =
map(LocalAiModel::mapDomainToEntity)
fun LocalAiModel.mapDomainToEntity(): LocalModelEntity = with(this) {
- LocalModelEntity(id, name, size, sources)
+ LocalModelEntity(id, type.key, name, size, sources)
}
//endregion
@@ -32,6 +36,6 @@ fun List.mapEntityToDomain(): List =
map(LocalModelEntity::mapEntityToDomain)
fun LocalModelEntity.mapEntityToDomain(): LocalAiModel = with(this) {
- LocalAiModel(id, name, size, sources)
+ LocalAiModel(id, LocalAiModel.Type.parse(type), name, size, sources)
}
//endregion
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt
index e88fcf47..ac47d61e 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt
@@ -60,7 +60,16 @@ class PreferenceManagerImpl(
.apply()
.also { onPreferencesChanged() }
- override var localDiffusionCustomModelPath: String
+ override var localMediaPipeCustomModelPath: String
+ get() = preferences.getString(
+ KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH,
+ LOCAL_DIFFUSION_CUSTOM_PATH
+ ) ?: LOCAL_DIFFUSION_CUSTOM_PATH
+ set(value) = preferences.edit()
+ .putString(KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH, value)
+ .apply()
+
+ override var localOnnxCustomModelPath: String
get() = preferences.getString(
KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH,
LOCAL_DIFFUSION_CUSTOM_PATH,
@@ -69,14 +78,14 @@ class PreferenceManagerImpl(
.putString(KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, value)
.apply()
- override var localDiffusionAllowCancel: Boolean
+ override var localOnnxAllowCancel: Boolean
get() = preferences.getBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, false)
set(value) = preferences.edit()
.putBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, value)
.apply()
.also { onPreferencesChanged() }
- override var localDiffusionSchedulerThread: SchedulersToken
+ override var localOnnxSchedulerThread: SchedulersToken
get() = preferences
.getInt(KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD, SchedulersToken.COMPUTATION.ordinal)
.let { SchedulersToken.entries[it] }
@@ -196,20 +205,27 @@ class PreferenceManagerImpl(
.apply()
.also { onPreferencesChanged() }
- override var localModelId: String
+ override var localOnnxModelId: String
get() = preferences.getString(KEY_LOCAL_MODEL_ID, "") ?: ""
set(value) = preferences.edit()
.putString(KEY_LOCAL_MODEL_ID, value)
.apply()
.also { onPreferencesChanged() }
- override var localUseNNAPI: Boolean
+ override var localOnnxUseNNAPI: Boolean
get() = preferences.getBoolean(KEY_LOCAL_NN_API, false)
set(value) = preferences.edit()
.putBoolean(KEY_LOCAL_NN_API, value)
.apply()
.also { onPreferencesChanged() }
+ override var localMediaPipeModelId: String
+ get() = preferences.getString(KEY_MEDIA_PIPE_MODEL_ID, "") ?: ""
+ set(value) = preferences.edit()
+ .putString(KEY_MEDIA_PIPE_MODEL_ID, value)
+ .apply()
+ .also { onPreferencesChanged() }
+
override var designUseSystemColorPalette: Boolean
get() = preferences.getBoolean(KEY_DESIGN_DYNAMIC_COLORS, false)
set(value) = preferences.edit()
@@ -273,8 +289,8 @@ class PreferenceManagerImpl(
sdModel = sdModel,
demoMode = demoMode,
developerMode = developerMode,
- localDiffusionAllowCancel = localDiffusionAllowCancel,
- localDiffusionSchedulerThread = localDiffusionSchedulerThread,
+ localDiffusionAllowCancel = localOnnxAllowCancel,
+ localDiffusionSchedulerThread = localOnnxSchedulerThread,
monitorConnectivity = monitorConnectivity,
backgroundGeneration = backgroundGeneration,
autoSaveAiResults = autoSaveAiResults,
@@ -283,7 +299,7 @@ class PreferenceManagerImpl(
formPromptTaggedInput = formPromptTaggedInput,
source = source,
hordeApiKey = hordeApiKey,
- localUseNNAPI = localUseNNAPI,
+ localUseNNAPI = localOnnxUseNNAPI,
designUseSystemColorPalette = designUseSystemColorPalette,
designUseSystemDarkTheme = designUseSystemDarkTheme,
designDarkTheme = designDarkTheme,
@@ -302,6 +318,7 @@ class PreferenceManagerImpl(
const val KEY_DEMO_MODE = "key_demo_mode"
const val KEY_DEVELOPER_MODE = "key_developer_mode"
const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path"
+ const val KEY_MEDIA_PIPE_CUSTOM_MODEL_PATH = "key_mediapipe_custom_model_path"
const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel"
const val KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD = "key_local_diffusion_scheduler_thread"
const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity"
@@ -319,6 +336,7 @@ class PreferenceManagerImpl(
const val KEY_STABILITY_AI_ENGINE_ID_KEY = "key_stability_ai_engine_id_key"
const val KEY_ON_BOARDING_COMPLETE = "key_on_boarding_complete"
const val KEY_FORCE_SETUP_AFTER_UPDATE = "force_upd_setup_v0.x.x-v0.6.2"
+ const val KEY_MEDIA_PIPE_MODEL_ID = "key_mediapipe_model_id"
const val KEY_LOCAL_MODEL_ID = "key_local_model_id"
const val KEY_LOCAL_NN_API = "key_local_nn_api"
const val KEY_DESIGN_DYNAMIC_COLORS = "key_design_dynamic_colors"
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt
index a2e6b193..1a552534 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSource.kt
@@ -5,10 +5,12 @@ import com.shifthackz.aisdv1.core.common.file.unzip
import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
import com.shifthackz.aisdv1.domain.entity.DownloadState
+import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi
import com.shifthackz.aisdv1.network.response.DownloadableModelResponse
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Observable
+import io.reactivex.rxjava3.core.Single
import java.io.File
internal class DownloadableModelRemoteDataSource(
@@ -16,12 +18,18 @@ internal class DownloadableModelRemoteDataSource(
private val fileProviderDescriptor: FileProviderDescriptor,
) : DownloadableModelDataSource.Remote {
- override fun fetch() = api
- .fetchDownloadableModels()
- .map(List::mapRawToCheckpointDomain)
+ override fun fetch(): Single> = Single.zip(
+ api
+ .fetchOnnxModels()
+ .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX) },
+ api
+ .fetchMediaPipeModels()
+ .map { it.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe) },
+ ::Pair,
+ )
+ .map { (onnx, mediapipe) -> listOf(onnx, mediapipe).flatten() }
- override fun download(id: String, url: String): Observable =
- Completable
+ override fun download(id: String, url: String): Observable = Completable
.fromAction {
val dir = File("${fileProviderDescriptor.localModelDirPath}/${id}")
val destination = File(getDestinationPath(id))
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt
index 5bde3662..6f84189c 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt
@@ -1,15 +1,18 @@
package com.shifthackz.aisdv1.data.repository
+import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
+import com.shifthackz.aisdv1.core.common.appbuild.BuildType
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
+import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
+import io.reactivex.rxjava3.core.Single
internal class DownloadableModelRepositoryImpl(
private val remoteDataSource: DownloadableModelDataSource.Remote,
private val localDataSource: DownloadableModelDataSource.Local,
+ private val buildInfoProvider: BuildInfoProvider,
) : DownloadableModelRepository {
- override fun isModelDownloaded(id: String) = localDataSource.isDownloaded(id)
-
override fun download(id: String) = localDataSource
.getById(id)
.flatMapObservable { model ->
@@ -18,15 +21,22 @@ internal class DownloadableModelRepositoryImpl(
override fun delete(id: String) = localDataSource.delete(id)
- override fun getAll() = remoteDataSource
+ override fun getAllOnnx() = remoteDataSource
.fetch()
.flatMapCompletable(localDataSource::save)
- .andThen(localDataSource.getAll())
- .onErrorResumeNext { localDataSource.getAll() }
-
- override fun getById(id: String) = localDataSource.getById(id)
+ .andThen(localDataSource.getAllOnnx())
+ .onErrorResumeNext { localDataSource.getAllOnnx() }
- override fun observeAll() = localDataSource.observeAll()
-
- override fun select(id: String) = localDataSource.select(id)
+ override fun getAllMediaPipe(): Single> {
+ if (buildInfoProvider.type == BuildType.FOSS) {
+ return Single.just(emptyList())
+ }
+ return remoteDataSource
+ .fetch()
+ .flatMapCompletable(localDataSource::save)
+ .andThen(localDataSource.getAllMediaPipe())
+ .onErrorResumeNext { localDataSource.getAllMediaPipe() }
+ }
+
+ override fun observeAllOnnx() = localDataSource.observeAllOnnx()
}
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt
index 0254b3d8..625a1ca8 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImpl.kt
@@ -36,7 +36,7 @@ internal class LocalDiffusionGenerationRepositoryImpl(
override fun observeStatus() = localDiffusion.observeStatus()
override fun generateFromText(payload: TextToImagePayload) = downloadableLocalDataSource
- .getSelected()
+ .getSelectedOnnx()
.flatMap { model ->
if (model.downloaded) generate(payload)
else Single.error(IllegalStateException("Model not downloaded."))
@@ -46,7 +46,7 @@ internal class LocalDiffusionGenerationRepositoryImpl(
private fun generate(payload: TextToImagePayload) = localDiffusion
.process(payload)
- .subscribeOn(schedulersProvider.byToken(preferenceManager.localDiffusionSchedulerThread))
+ .subscribeOn(schedulersProvider.byToken(preferenceManager.localOnnxSchedulerThread))
.map(BitmapToBase64Converter::Input)
.flatMap(bitmapToBase64Converter::invoke)
.map(BitmapToBase64Converter.Output::base64ImageString)
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt
new file mode 100644
index 00000000..88471209
--- /dev/null
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/repository/MediaPipeGenerationRepositoryImpl.kt
@@ -0,0 +1,45 @@
+package com.shifthackz.aisdv1.data.repository
+
+import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
+import com.shifthackz.aisdv1.core.imageprocessing.Base64ToBitmapConverter
+import com.shifthackz.aisdv1.core.imageprocessing.BitmapToBase64Converter
+import com.shifthackz.aisdv1.data.core.CoreGenerationRepository
+import com.shifthackz.aisdv1.data.mappers.mapLocalDiffusionToAiGenResult
+import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource
+import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
+import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
+import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe
+import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver
+import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
+import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository
+import io.reactivex.rxjava3.core.Single
+import io.reactivex.rxjava3.schedulers.Schedulers
+
+internal class MediaPipeGenerationRepositoryImpl(
+ mediaStoreGateway: MediaStoreGateway,
+ base64ToBitmapConverter: Base64ToBitmapConverter,
+ localDataSource: GenerationResultDataSource.Local,
+ backgroundWorkObserver: BackgroundWorkObserver,
+ preferenceManager: PreferenceManager,
+ private val schedulersProvider: SchedulersProvider,
+ private val mediaPipe: MediaPipe,
+ private val bitmapToBase64Converter: BitmapToBase64Converter,
+) : CoreGenerationRepository(
+ mediaStoreGateway = mediaStoreGateway,
+ base64ToBitmapConverter = base64ToBitmapConverter,
+ localDataSource = localDataSource,
+ preferenceManager = preferenceManager,
+ backgroundWorkObserver = backgroundWorkObserver,
+), MediaPipeGenerationRepository {
+
+ override fun generateFromText(payload: TextToImagePayload): Single = mediaPipe
+ .process(payload)
+ .subscribeOn(schedulersProvider.singleThread.let(Schedulers::from))
+ .map(BitmapToBase64Converter::Input)
+ .flatMap(bitmapToBase64Converter::invoke)
+ .map(BitmapToBase64Converter.Output::base64ImageString)
+ .map { base64 -> payload to base64 }
+ .map(Pair::mapLocalDiffusionToAiGenResult)
+ .flatMap(::insertGenerationResult)
+}
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt
index e1c2c830..90fe0cbe 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/local/DownloadableModelLocalDataSourceTest.kt
@@ -13,8 +13,6 @@ import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao
import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
import io.mockk.every
import io.mockk.mockk
-import io.mockk.mockkConstructor
-import io.mockk.mockkStatic
import io.reactivex.rxjava3.core.BackpressureStrategy
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
@@ -22,7 +20,6 @@ import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.subjects.BehaviorSubject
import org.junit.Assert
import org.junit.Test
-import java.io.File
class DownloadableModelLocalDataSourceTest {
@@ -43,7 +40,7 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to get all models, dao returns models list, app build type is PLAY, expected valid domain models list`() {
every {
- stubDao.query()
+ stubDao.queryByType(any())
} returns Single.just(mockLocalModelEntities)
every {
@@ -51,7 +48,7 @@ class DownloadableModelLocalDataSourceTest {
} returns BuildType.PLAY
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
every {
@@ -61,7 +58,7 @@ class DownloadableModelLocalDataSourceTest {
val expected = mockLocalModelEntities.mapEntityToDomain()
localDataSource
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
.assertValue { actual ->
@@ -74,7 +71,7 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to get all models, dao returns empty models list, app build type is PLAY, expected empty domain models list`() {
every {
- stubDao.query()
+ stubDao.queryByType(any())
} returns Single.just(emptyList())
every {
@@ -82,11 +79,11 @@ class DownloadableModelLocalDataSourceTest {
} returns BuildType.PLAY
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
localDataSource
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
.assertValue(emptyList())
@@ -97,7 +94,7 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to get all models, dao returns models list, app build type is FOSS, expected valid domain models list with CUSTOM model included`() {
every {
- stubDao.query()
+ stubDao.queryByType(any())
} returns Single.just(mockLocalModelEntities)
every {
@@ -105,7 +102,7 @@ class DownloadableModelLocalDataSourceTest {
} returns BuildType.FOSS
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
every {
@@ -114,11 +111,11 @@ class DownloadableModelLocalDataSourceTest {
val expected = buildList {
addAll(mockLocalModelEntities.mapEntityToDomain())
- add(LocalAiModel.CUSTOM.copy(downloaded = true))
+ add(LocalAiModel.CustomOnnx.copy(downloaded = true))
}
localDataSource
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
.assertValue { actual ->
@@ -131,7 +128,7 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to get all models, dao returns empty models list, app build type is FOSS, expected domain models list with only CUSTOM model included`() {
every {
- stubDao.query()
+ stubDao.queryByType(any())
} returns Single.just(emptyList())
every {
@@ -139,14 +136,14 @@ class DownloadableModelLocalDataSourceTest {
} returns BuildType.FOSS
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
localDataSource
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
- .assertValue(listOf(LocalAiModel.CUSTOM.copy(downloaded = true)))
+ .assertValue(listOf(LocalAiModel.CustomOnnx.copy(downloaded = true)))
.await()
.assertComplete()
}
@@ -154,11 +151,11 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to get all models, dao throws exception, expected error value`() {
every {
- stubDao.query()
+ stubDao.queryByType(any())
} returns Single.error(stubException)
localDataSource
- .getAll()
+ .getAllOnnx()
.test()
.assertError(stubException)
.assertNoValues()
@@ -173,7 +170,7 @@ class DownloadableModelLocalDataSourceTest {
} returns Single.just(mockLocalModelEntity)
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
every {
@@ -198,7 +195,7 @@ class DownloadableModelLocalDataSourceTest {
} returns Single.just(mockLocalModelEntity)
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns "5598"
every {
@@ -238,7 +235,7 @@ class DownloadableModelLocalDataSourceTest {
} returns Single.just(mockLocalModelEntity)
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns "5598"
every {
@@ -248,7 +245,7 @@ class DownloadableModelLocalDataSourceTest {
val expected = mockLocalModelEntity.mapEntityToDomain().copy(selected = true)
localDataSource
- .getSelected()
+ .getSelectedOnnx()
.test()
.assertNoErrors()
.assertValue(expected)
@@ -259,11 +256,11 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to get selected model, preference throws exception, expected error value`() {
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
localDataSource
- .getSelected()
+ .getSelectedOnnx()
.test()
.assertError { t ->
t is IllegalStateException && t.message == "No selected model."
@@ -276,7 +273,7 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is PLAY, expected empty list, then domain list with two items`() {
every {
- stubDao.observe()
+ stubDao.observeByType(any())
} returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST)
every {
@@ -284,7 +281,7 @@ class DownloadableModelLocalDataSourceTest {
} returns BuildType.PLAY
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
every {
@@ -292,7 +289,7 @@ class DownloadableModelLocalDataSourceTest {
} returns "/tmp/local"
val stubObserver = localDataSource
- .observeAll()
+ .observeAllOnnx()
.test()
stubLocalModels.onNext(emptyList())
@@ -311,7 +308,7 @@ class DownloadableModelLocalDataSourceTest {
@Test
fun `given attempt to observe all models, dao emits empty list, then list with two items, app build type is FOSS, expected list with only CUSTOM model included, then domain list with two items and CUSTOM`() {
every {
- stubDao.observe()
+ stubDao.observeByType(any())
} returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST)
every {
@@ -319,7 +316,7 @@ class DownloadableModelLocalDataSourceTest {
} returns BuildType.FOSS
every {
- stubPreferenceManager.localModelId
+ stubPreferenceManager.localOnnxModelId
} returns ""
every {
@@ -327,14 +324,14 @@ class DownloadableModelLocalDataSourceTest {
} returns "/tmp/local"
val stubObserver = localDataSource
- .observeAll()
+ .observeAllOnnx()
.test()
stubLocalModels.onNext(emptyList())
stubObserver
.assertNoErrors()
- .assertValueAt(0, listOf(LocalAiModel.CUSTOM.copy(downloaded = true)))
+ .assertValueAt(0, listOf(LocalAiModel.CustomOnnx.copy(downloaded = true)))
stubLocalModels.onNext(mockLocalModelEntities)
@@ -342,18 +339,18 @@ class DownloadableModelLocalDataSourceTest {
.assertNoErrors()
.assertValueAt(1, buildList {
addAll(mockLocalModelEntities.mapEntityToDomain())
- add(LocalAiModel.CUSTOM.copy(downloaded = true))
+ add(LocalAiModel.CustomOnnx.copy(downloaded = true))
})
}
@Test
fun `given attempt to observe all models, dao throws exception, expected error value`() {
every {
- stubDao.observe()
+ stubDao.observeByType(any())
} returns Flowable.error(stubException)
localDataSource
- .observeAll()
+ .observeAllOnnx()
.test()
.assertError(stubException)
.assertNoValues()
@@ -361,44 +358,6 @@ class DownloadableModelLocalDataSourceTest {
.assertNotComplete()
}
- @Test
- fun `given attempt to select model, preference changed, expected preference returns changed selected model id value`() {
- every {
- stubPreferenceManager.localModelId
- } returns ""
-
- every {
- stubPreferenceManager::localModelId.set(any())
- } returns Unit
-
- localDataSource
- .select("5598")
- .test()
- .assertNoErrors()
- .await()
- .assertComplete()
-
- every {
- stubPreferenceManager.localModelId
- } returns "5598"
-
- Assert.assertEquals("5598", stubPreferenceManager.localModelId)
- }
-
- @Test
- fun `given attempt to select model, preference throws exception, expected error value`() {
- every {
- stubPreferenceManager::localModelId.set(any())
- } throws stubException
-
- localDataSource
- .select("5598")
- .test()
- .assertError(stubException)
- .await()
- .assertNotComplete()
- }
-
@Test
fun `given attempt to save local model list, dao insert success, expected complete value`() {
every {
@@ -427,8 +386,6 @@ class DownloadableModelLocalDataSourceTest {
.assertNotComplete()
}
- //--
-
@Test
fun `given attempt to delete file, delete operation success, expected complete value`() {
every {
@@ -456,15 +413,4 @@ class DownloadableModelLocalDataSourceTest {
.await()
.assertNotComplete()
}
-
- @Test
- fun `given attempt to check if CUSTOM model is downloaded, expected true`() {
- localDataSource
- .isDownloaded(LocalAiModel.CUSTOM.id)
- .test()
- .assertNoErrors()
- .assertValue(true)
- .await()
- .assertComplete()
- }
}
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt
index 6eee82bc..9aaa20a8 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalAiModelMocks.kt
@@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel
val mockLocalAiModel = LocalAiModel(
id = "5598",
+ type = LocalAiModel.Type.ONNX,
name = "Model 5598",
size = "5 Gb",
sources = listOf("https://example.com/1.html"),
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt
index 33b45cd9..7beaa4be 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/mocks/LocalModelEntityMocks.kt
@@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
val mockLocalModelEntity = LocalModelEntity(
id = "5598",
+ type = "onnx",
name = "Best model in entire universe",
size = "5598 Gb",
sources = listOf("https://5598.is.my.favourite.com"),
@@ -12,6 +13,7 @@ val mockLocalModelEntity = LocalModelEntity(
val mockLocalModelEntities = listOf(
LocalModelEntity(
id = "1",
+ type = "onnx",
name = "Model 1",
size = "1 Gb",
sources = listOf("https://example.com/1.php"),
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt
index 0535d7f8..20b1f5b3 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImplTest.kt
@@ -223,17 +223,17 @@ class PreferenceManagerImplTest {
Assert.assertEquals(ServerSource.AUTOMATIC1111, preferenceManager.source)
whenever(stubPreference.getString(eq(KEY_SERVER_SOURCE), any()))
- .thenReturn(ServerSource.LOCAL.key)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX.key)
- preferenceManager.source = ServerSource.LOCAL
+ preferenceManager.source = ServerSource.LOCAL_MICROSOFT_ONNX
- Assert.assertEquals(ServerSource.LOCAL, preferenceManager.source)
+ Assert.assertEquals(ServerSource.LOCAL_MICROSOFT_ONNX, preferenceManager.source)
preferenceManager
.observe()
.test()
.assertNoErrors()
- .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL }
+ .assertValueAt(0) { settings -> settings.source == ServerSource.LOCAL_MICROSOFT_ONNX }
}
@Test
@@ -373,14 +373,14 @@ class PreferenceManagerImplTest {
whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any()))
.thenReturn("")
- Assert.assertEquals("", preferenceManager.localModelId)
+ Assert.assertEquals("", preferenceManager.localOnnxModelId)
whenever(stubPreference.getString(eq(KEY_LOCAL_MODEL_ID), any()))
.thenReturn("key")
- preferenceManager.localModelId = "key"
+ preferenceManager.localOnnxModelId = "key"
- Assert.assertEquals("key", preferenceManager.localModelId)
+ Assert.assertEquals("key", preferenceManager.localOnnxModelId)
}
@Test
@@ -388,14 +388,14 @@ class PreferenceManagerImplTest {
whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any()))
.thenReturn(false)
- Assert.assertEquals(false, preferenceManager.localUseNNAPI)
+ Assert.assertEquals(false, preferenceManager.localOnnxUseNNAPI)
whenever(stubPreference.getBoolean(eq(KEY_LOCAL_NN_API), any()))
.thenReturn(true)
- preferenceManager.localUseNNAPI = true
+ preferenceManager.localOnnxUseNNAPI = true
- Assert.assertEquals(true, preferenceManager.localUseNNAPI)
+ Assert.assertEquals(true, preferenceManager.localOnnxUseNNAPI)
preferenceManager
.observe()
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt
index c97b0a9e..775fd5a6 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/remote/DownloadableModelRemoteDataSourceTest.kt
@@ -5,6 +5,7 @@ import com.nhaarman.mockitokotlin2.whenever
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
import com.shifthackz.aisdv1.data.mappers.mapRawToCheckpointDomain
import com.shifthackz.aisdv1.data.mocks.mockDownloadableModelsResponse
+import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.network.api.sdai.DownloadableModelsApi
import io.reactivex.rxjava3.core.Single
import org.junit.Test
@@ -22,10 +23,16 @@ class DownloadableModelRemoteDataSourceTest {
@Test
fun `given attempt to fetch models list, api returns data, expected valid domain models list`() {
- whenever(stubApi.fetchDownloadableModels())
+ whenever(stubApi.fetchOnnxModels())
.thenReturn(Single.just(mockDownloadableModelsResponse))
- val expected = mockDownloadableModelsResponse.mapRawToCheckpointDomain()
+ whenever(stubApi.fetchMediaPipeModels())
+ .thenReturn(Single.just(mockDownloadableModelsResponse))
+
+ val expected = listOf(
+ mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.ONNX),
+ mockDownloadableModelsResponse.mapRawToCheckpointDomain(LocalAiModel.Type.MediaPipe),
+ ).flatten()
remoteDataSource
.fetch()
@@ -38,7 +45,10 @@ class DownloadableModelRemoteDataSourceTest {
@Test
fun `given attempt to fetch models list, api returns empty data, expected empty domain models list`() {
- whenever(stubApi.fetchDownloadableModels())
+ whenever(stubApi.fetchOnnxModels())
+ .thenReturn(Single.just(emptyList()))
+
+ whenever(stubApi.fetchMediaPipeModels())
.thenReturn(Single.just(emptyList()))
remoteDataSource
@@ -52,7 +62,10 @@ class DownloadableModelRemoteDataSourceTest {
@Test
fun `given attempt to fetch models list, api returns error, expected error value`() {
- whenever(stubApi.fetchDownloadableModels())
+ whenever(stubApi.fetchOnnxModels())
+ .thenReturn(Single.error(stubException))
+
+ whenever(stubApi.fetchMediaPipeModels())
.thenReturn(Single.error(stubException))
remoteDataSource
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt
index 2daf1b5f..379bc3c5 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt
@@ -1,5 +1,7 @@
package com.shifthackz.aisdv1.data.repository
+import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
+import com.shifthackz.aisdv1.core.common.appbuild.BuildType
import com.shifthackz.aisdv1.data.mocks.mockLocalAiModel
import com.shifthackz.aisdv1.data.mocks.mockLocalAiModels
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
@@ -25,16 +27,22 @@ class DownloadableModelRepositoryImplTest {
private val stubDownloadState = BehaviorSubject.create()
private val stubRemoteDataSource = mockk()
private val stubLocalDataSource = mockk()
+ private val stubBuildInfoProvider = mockk()
private val repository = DownloadableModelRepositoryImpl(
remoteDataSource = stubRemoteDataSource,
localDataSource = stubLocalDataSource,
+ buildInfoProvider = stubBuildInfoProvider,
)
@Before
fun initialize() {
every {
- stubLocalDataSource.observeAll()
+ stubBuildInfoProvider.type
+ } returns BuildType.FULL
+
+ every {
+ stubLocalDataSource.observeAllOnnx()
} returns stubLocalModels.toFlowable(BackpressureStrategy.LATEST)
every {
@@ -42,51 +50,6 @@ class DownloadableModelRepositoryImplTest {
} returns stubDownloadState
}
- @Test
- fun `given attempt to check if model downloaded, local data source returns true, expected true value`() {
- every {
- stubLocalDataSource.isDownloaded(any())
- } returns Single.just(true)
-
- repository
- .isModelDownloaded("5598")
- .test()
- .assertNoErrors()
- .assertValue(true)
- .await()
- .assertComplete()
- }
-
- @Test
- fun `given attempt to check if model downloaded, local data source returns false, expected false value`() {
- every {
- stubLocalDataSource.isDownloaded(any())
- } returns Single.just(false)
-
- repository
- .isModelDownloaded("5598")
- .test()
- .assertNoErrors()
- .assertValue(false)
- .await()
- .assertComplete()
- }
-
- @Test
- fun `given attempt to check if model downloaded, local data source throws exception, expected error value`() {
- every {
- stubLocalDataSource.isDownloaded(any())
- } returns Single.error(stubException)
-
- repository
- .isModelDownloaded("5598")
- .test()
- .assertError(stubException)
- .assertNoValues()
- .await()
- .assertNotComplete()
- }
-
@Test
fun `given attempt to delete model, local data source completes, expected complete value`() {
every {
@@ -115,34 +78,6 @@ class DownloadableModelRepositoryImplTest {
.assertNotComplete()
}
- @Test
- fun `given attempt to select model, local data source completes, expected complete value`() {
- every {
- stubLocalDataSource.select(any())
- } returns Completable.complete()
-
- repository
- .select("5598")
- .test()
- .assertNoErrors()
- .await()
- .assertComplete()
- }
-
- @Test
- fun `given attempt to select model, local data source throws exception, expected error value`() {
- every {
- stubLocalDataSource.select(any())
- } returns Completable.error(stubException)
-
- repository
- .select("5598")
- .test()
- .assertError(stubException)
- .await()
- .assertNotComplete()
- }
-
@Test
fun `given attempt to get all, remote returns list, save success, local query success, expected valid domain model list value`() {
every {
@@ -154,11 +89,11 @@ class DownloadableModelRepositoryImplTest {
} returns Completable.complete()
every {
- stubLocalDataSource.getAll()
+ stubLocalDataSource.getAllOnnx()
} returns Single.just(mockLocalAiModels)
repository
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
.assertValue(mockLocalAiModels)
@@ -177,11 +112,11 @@ class DownloadableModelRepositoryImplTest {
} returns Completable.error(stubException)
every {
- stubLocalDataSource.getAll()
+ stubLocalDataSource.getAllOnnx()
} returns Single.just(mockLocalAiModels)
repository
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
.assertValue(mockLocalAiModels)
@@ -200,11 +135,11 @@ class DownloadableModelRepositoryImplTest {
} returns Completable.complete()
every {
- stubLocalDataSource.getAll()
+ stubLocalDataSource.getAllOnnx()
} returns Single.just(mockLocalAiModels)
repository
- .getAll()
+ .getAllOnnx()
.test()
.assertNoErrors()
.assertValue(mockLocalAiModels)
@@ -223,41 +158,11 @@ class DownloadableModelRepositoryImplTest {
} returns Completable.complete()
every {
- stubLocalDataSource.getAll()
- } returns Single.error(stubException)
-
- repository
- .getAll()
- .test()
- .assertError(stubException)
- .assertNoValues()
- .await()
- .assertNotComplete()
- }
-
- @Test
- fun `given attempt to get by id, local data source returns data, expected valid domain model value`() {
- every {
- stubLocalDataSource.getById(any())
- } returns Single.just(mockLocalAiModel)
-
- repository
- .getById("5598")
- .test()
- .assertNoErrors()
- .assertValue(mockLocalAiModel)
- .await()
- .assertComplete()
- }
-
- @Test
- fun `given attempt to get by id, local data source fails, expected error value`() {
- every {
- stubLocalDataSource.getById(any())
+ stubLocalDataSource.getAllOnnx()
} returns Single.error(stubException)
repository
- .getById("5598")
+ .getAllOnnx()
.test()
.assertError(stubException)
.assertNoValues()
@@ -267,7 +172,7 @@ class DownloadableModelRepositoryImplTest {
@Test
fun `given observe all models, local data source emits empty list, then another list, expected empty value, then valid domain models list value`() {
- val stubObserver = repository.observeAll().test()
+ val stubObserver = repository.observeAllOnnx().test()
stubLocalModels.onNext(emptyList())
@@ -284,7 +189,7 @@ class DownloadableModelRepositoryImplTest {
@Test
fun `given observe all models, local data source emits list, then changed list, expected valid domain models list value, then changed value`() {
- val stubObserver = repository.observeAll().test()
+ val stubObserver = repository.observeAllOnnx().test()
stubLocalModels.onNext(mockLocalAiModels)
@@ -302,11 +207,11 @@ class DownloadableModelRepositoryImplTest {
@Test
fun `given observe all models, local data source throws exception, expected error value`() {
every {
- stubLocalDataSource.observeAll()
+ stubLocalDataSource.observeAllOnnx()
} returns Flowable.error(stubException)
repository
- .observeAll()
+ .observeAllOnnx()
.test()
.assertError(stubException)
.assertNoValues()
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt
index 75df885f..126edac1 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/LocalDiffusionGenerationRepositoryImplTest.kt
@@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.data.mocks.mockTextToImagePayload
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
import com.shifthackz.aisdv1.domain.datasource.GenerationResultDataSource
import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver
import com.shifthackz.aisdv1.domain.gateway.MediaStoreGateway
@@ -31,7 +32,7 @@ class LocalDiffusionGenerationRepositoryImplTest {
private val stubBitmap = mockk()
private val stubException = Throwable("Something went wrong.")
- private val stubStatus = BehaviorSubject.create()
+ private val stubStatus = BehaviorSubject.create()
private val stubMediaStoreGateway = mockk()
private val stubBase64ToBitmapConverter = mockk()
private val stubBitmapToBase64Converter = mockk()
@@ -63,7 +64,7 @@ class LocalDiffusionGenerationRepositoryImplTest {
@Before
fun initialize() {
every {
- stubPreferenceManager::localDiffusionSchedulerThread.get()
+ stubPreferenceManager::localOnnxSchedulerThread.get()
} returns SchedulersToken.COMPUTATION
every {
@@ -83,17 +84,17 @@ class LocalDiffusionGenerationRepositoryImplTest {
fun `given attempt to observe status, local emits two values, expected same values with same order`() {
val stubObserver = repository.observeStatus().test()
- stubStatus.onNext(LocalDiffusion.Status(1, 2))
+ stubStatus.onNext(LocalDiffusionStatus(1, 2))
stubObserver
.assertNoErrors()
- .assertValueAt(0, LocalDiffusion.Status(1, 2))
+ .assertValueAt(0, LocalDiffusionStatus(1, 2))
- stubStatus.onNext(LocalDiffusion.Status(2, 2))
+ stubStatus.onNext(LocalDiffusionStatus(2, 2))
stubObserver
.assertNoErrors()
- .assertValueAt(1, LocalDiffusion.Status(2, 2))
+ .assertValueAt(1, LocalDiffusionStatus(2, 2))
}
@Test
@@ -142,7 +143,7 @@ class LocalDiffusionGenerationRepositoryImplTest {
@Test
fun `given attempt to generate from text, no selected model, expected error value`() {
every {
- stubDownloadableLocalDataSource.getSelected()
+ stubDownloadableLocalDataSource.getSelectedOnnx()
} returns Single.error(stubException)
repository
@@ -157,7 +158,7 @@ class LocalDiffusionGenerationRepositoryImplTest {
@Test
fun `given attempt to generate from text, has selected not downloaded model, expected IllegalStateException error value`() {
every {
- stubDownloadableLocalDataSource.getSelected()
+ stubDownloadableLocalDataSource.getSelectedOnnx()
} returns Single.just(mockLocalAiModel.copy(downloaded = false))
every {
@@ -182,7 +183,7 @@ class LocalDiffusionGenerationRepositoryImplTest {
@Test
fun `given attempt to generate from text, has selected downloaded model, local process success, expected valid domain model value`() {
every {
- stubDownloadableLocalDataSource.getSelected()
+ stubDownloadableLocalDataSource.getSelectedOnnx()
} returns Single.just(mockLocalAiModel.copy(downloaded = true))
every {
@@ -205,7 +206,7 @@ class LocalDiffusionGenerationRepositoryImplTest {
@Test
fun `given attempt to generate from text, has selected downloaded model, local process fails, expected error value`() {
every {
- stubDownloadableLocalDataSource.getSelected()
+ stubDownloadableLocalDataSource.getSelectedOnnx()
} returns Single.just(mockLocalAiModel.copy(downloaded = true))
every {
diff --git a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt
index 3cc310df..7ef7f17a 100644
--- a/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt
+++ b/data/src/test/java/com/shifthackz/aisdv1/data/repository/StabilityAiCreditsRepositoryImplTest.kt
@@ -40,7 +40,7 @@ class StabilityAiCreditsRepositoryImplTest {
fun `given server source is not STABILITY_AI, attempt to fetch, expected IllegalStateException error value`() {
every {
stubPreferenceManager.source
- } returns ServerSource.LOCAL
+ } returns ServerSource.LOCAL_MICROSOFT_ONNX
every {
stubRemoteDataSource.fetch()
@@ -62,7 +62,7 @@ class StabilityAiCreditsRepositoryImplTest {
fun `given server source is not STABILITY_AI, attempt to fetch and get, expected IllegalStateException error value`() {
every {
stubPreferenceManager.source
- } returns ServerSource.LOCAL
+ } returns ServerSource.LOCAL_MICROSOFT_ONNX
every {
stubRemoteDataSource.fetch()
@@ -88,7 +88,7 @@ class StabilityAiCreditsRepositoryImplTest {
fun `given server source is not STABILITY_AI, attempt to fetch and observe, expected IllegalStateException error value`() {
every {
stubPreferenceManager.source
- } returns ServerSource.LOCAL
+ } returns ServerSource.LOCAL_MICROSOFT_ONNX
every {
stubRemoteDataSource.fetch()
@@ -110,7 +110,7 @@ class StabilityAiCreditsRepositoryImplTest {
fun `given server source is not STABILITY_AI, attempt to get, expected IllegalStateException error value`() {
every {
stubPreferenceManager.source
- } returns ServerSource.LOCAL
+ } returns ServerSource.LOCAL_MICROSOFT_ONNX
every {
stubLocalDataSource.get()
@@ -128,7 +128,7 @@ class StabilityAiCreditsRepositoryImplTest {
fun `given server source is not STABILITY_AI, attempt to observe, expected IllegalStateException error value`() {
every {
stubPreferenceManager.source
- } returns ServerSource.LOCAL
+ } returns ServerSource.LOCAL_MICROSOFT_ONNX
repository
.observe()
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt
index b2320323..cb56c621 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/datasource/DownloadableModelDataSource.kt
@@ -15,13 +15,12 @@ sealed interface DownloadableModelDataSource {
}
interface Local : DownloadableModelDataSource {
- fun getAll(): Single>
+ fun getAllOnnx(): Single>
+ fun getAllMediaPipe(): Single>
fun getById(id: String): Single
- fun getSelected(): Single
- fun observeAll(): Flowable>
- fun select(id: String): Completable
+ fun getSelectedOnnx(): Single
+ fun observeAllOnnx(): Flowable>
fun save(list: List): Completable
- fun isDownloaded(id: String): Single
fun delete(id: String): Completable
}
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt
index bccb6ebd..5bdaf74e 100755
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt
@@ -36,10 +36,12 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl
-import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase
-import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCaseImpl
-import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase
-import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCaseImpl
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCaseImpl
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCaseImpl
+import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteAllGalleryUseCase
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteAllGalleryUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.gallery.DeleteGalleryItemUseCase
@@ -92,6 +94,8 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCaseImpl
+import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase
+import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCaseImpl
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase
@@ -155,15 +159,17 @@ internal val useCasesModule = module {
factoryOf(::SaveLastResultToCacheUseCaseImpl) bind SaveLastResultToCacheUseCase::class
factoryOf(::GetLastResultFromCacheUseCaseImpl) bind GetLastResultFromCacheUseCase::class
factoryOf(::ObserveLocalDiffusionProcessStatusUseCaseImpl) bind ObserveLocalDiffusionProcessStatusUseCase::class
- factoryOf(::GetLocalAiModelsUseCaseImpl) bind GetLocalAiModelsUseCase::class
+ factoryOf(::GetLocalOnnxModelsUseCaseImpl) bind GetLocalOnnxModelsUseCase::class
+ factoryOf(::GetLocalMediaPipeModelsUseCaseImpl) bind GetLocalMediaPipeModelsUseCase::class
factoryOf(::DownloadModelUseCaseImpl) bind DownloadModelUseCase::class
- factoryOf(::ObserveLocalAiModelsUseCaseImpl) bind ObserveLocalAiModelsUseCase::class
+ factoryOf(::ObserveLocalOnnxModelsUseCaseImpl) bind ObserveLocalOnnxModelsUseCase::class
factoryOf(::DeleteModelUseCaseImpl) bind DeleteModelUseCase::class
factoryOf(::AcquireWakelockUseCaseImpl) bind AcquireWakelockUseCase::class
factoryOf(::ReleaseWakeLockUseCaseImpl) bind ReleaseWakeLockUseCase::class
factoryOf(::InterruptGenerationUseCaseImpl) bind InterruptGenerationUseCase::class
factoryOf(::ConnectToHordeUseCaseImpl) bind ConnectToHordeUseCase::class
factoryOf(::ConnectToLocalDiffusionUseCaseImpl) bind ConnectToLocalDiffusionUseCase::class
+ factoryOf(::ConnectToMediaPipeUseCaseImpl) bind ConnectToMediaPipeUseCase::class
factoryOf(::ConnectToA1111UseCaseImpl) bind ConnectToA1111UseCase::class
factoryOf(::ConnectToSwarmUiUseCaseImpl) bind ConnectToSwarmUiUseCase::class
factoryOf(::ConnectToHuggingFaceUseCaseImpl) bind ConnectToHuggingFaceUseCase::class
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt
index fcf04c19..40c5e10a 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt
@@ -15,6 +15,8 @@ data class Configuration(
val stabilityAiApiKey: String = "",
val stabilityAiEngineId: String = "",
val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None,
- val localModelId: String = "",
- val localModelPath: String = "",
+ val localOnnxModelId: String = "",
+ val localOnnxModelPath: String = "",
+ val localMediaPipeModelId: String = "",
+ val localMediaPipeModelPath: String = "",
)
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt
index 734ad5ed..b7958512 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalAiModel.kt
@@ -2,15 +2,34 @@ package com.shifthackz.aisdv1.domain.entity
data class LocalAiModel(
val id: String,
+ val type: Type,
val name: String,
val size: String,
val sources: List,
val downloaded: Boolean = false,
val selected: Boolean = false,
) {
+ enum class Type(val key: String) {
+ ONNX("onnx"),
+ MediaPipe("mediapipe");
+
+ companion object {
+ fun parse(value: String?) = entries.find { it.key == value } ?: ONNX
+ }
+ }
+
companion object {
- val CUSTOM = LocalAiModel(
+ val CustomOnnx = LocalAiModel(
id = "CUSTOM",
+ type = Type.ONNX,
+ name = "Custom",
+ size = "NaN",
+ sources = emptyList(),
+ )
+
+ val CustomMediaPipe = LocalAiModel(
+ id = "CUSTOM_MP",
+ type = Type.MediaPipe,
name = "Custom",
size = "NaN",
sources = emptyList(),
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt
new file mode 100644
index 00000000..779d8bc6
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/LocalDiffusionStatus.kt
@@ -0,0 +1,6 @@
+package com.shifthackz.aisdv1.domain.entity
+
+data class LocalDiffusionStatus(
+ val current: Int,
+ val total: Int,
+)
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt
index 4eeca344..15244e1d 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/ServerSource.kt
@@ -1,8 +1,11 @@
package com.shifthackz.aisdv1.domain.entity
+import com.shifthackz.aisdv1.core.common.appbuild.BuildType
+
enum class ServerSource(
val key: String,
val featureTags: Set,
+ val allowedInBuilds: Set = setOf(BuildType.FOSS, BuildType.PLAY, BuildType.FULL),
) {
AUTOMATIC1111(
key = "custom",
@@ -62,13 +65,22 @@ enum class ServerSource(
FeatureTag.Batch,
),
),
- LOCAL(
+ LOCAL_MICROSOFT_ONNX(
key = "local",
featureTags = setOf(
FeatureTag.Offline,
FeatureTag.Txt2Img,
FeatureTag.MultipleModels,
),
+ ),
+ LOCAL_GOOGLE_MEDIA_PIPE(
+ key = "local_google_media_pipe",
+ featureTags = setOf(
+ FeatureTag.Offline,
+ FeatureTag.Txt2Img,
+ FeatureTag.MultipleModels,
+ ),
+ allowedInBuilds = setOf(BuildType.PLAY, BuildType.FULL),
);
companion object {
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt
index d1f260e7..afcdefd0 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/diffusion/LocalDiffusion.kt
@@ -1,6 +1,7 @@
package com.shifthackz.aisdv1.domain.feature.diffusion
import android.graphics.Bitmap
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Observable
@@ -9,7 +10,5 @@ import io.reactivex.rxjava3.core.Single
interface LocalDiffusion {
fun process(payload: TextToImagePayload): Single
fun interrupt(): Completable
- fun observeStatus(): Observable
-
- data class Status(val current: Int, val total: Int)
+ fun observeStatus(): Observable
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt
new file mode 100644
index 00000000..d27e0502
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/feature/mediapipe/MediaPipe.kt
@@ -0,0 +1,9 @@
+package com.shifthackz.aisdv1.domain.feature.mediapipe
+
+import android.graphics.Bitmap
+import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
+import io.reactivex.rxjava3.core.Single
+
+interface MediaPipe {
+ fun process(payload: TextToImagePayload): Single
+}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt
index 4c42f81c..5959a035 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActor.kt
@@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase
+import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase
@@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase
interface SetupConnectionInterActor {
val connectToHorde: ConnectToHordeUseCase
val connectToLocal: ConnectToLocalDiffusionUseCase
+ val connectToMediaPipe: ConnectToMediaPipeUseCase
val connectToA1111: ConnectToA1111UseCase
val connectToHuggingFace: ConnectToHuggingFaceUseCase
val connectToOpenAi: ConnectToOpenAiUseCase
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt
index 05517631..306da094 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/interactor/settings/SetupConnectionInterActorImpl.kt
@@ -4,6 +4,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToA1111UseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHordeUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToHuggingFaceUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToLocalDiffusionUseCase
+import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToMediaPipeUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToOpenAiUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToStabilityAiUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase
@@ -11,6 +12,7 @@ import com.shifthackz.aisdv1.domain.usecase.settings.ConnectToSwarmUiUseCase
internal data class SetupConnectionInterActorImpl(
override val connectToHorde: ConnectToHordeUseCase,
override val connectToLocal: ConnectToLocalDiffusionUseCase,
+ override val connectToMediaPipe: ConnectToMediaPipeUseCase,
override val connectToA1111: ConnectToA1111UseCase,
override val connectToHuggingFace: ConnectToHuggingFaceUseCase,
override val connectToOpenAi: ConnectToOpenAiUseCase,
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt
index 1ee8950b..42c32959 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt
@@ -12,9 +12,10 @@ interface PreferenceManager {
var swarmUiModel: String
var demoMode: Boolean
var developerMode: Boolean
- var localDiffusionCustomModelPath: String
- var localDiffusionAllowCancel: Boolean
- var localDiffusionSchedulerThread: SchedulersToken
+ var localMediaPipeCustomModelPath: String
+ var localOnnxCustomModelPath: String
+ var localOnnxAllowCancel: Boolean
+ var localOnnxSchedulerThread: SchedulersToken
var monitorConnectivity: Boolean
var autoSaveAiResults: Boolean
var saveToMediaStore: Boolean
@@ -30,8 +31,9 @@ interface PreferenceManager {
var stabilityAiEngineId: String
var onBoardingComplete: Boolean
var forceSetupAfterUpdate: Boolean
- var localModelId: String
- var localUseNNAPI: Boolean
+ var localOnnxModelId: String
+ var localOnnxUseNNAPI: Boolean
+ var localMediaPipeModelId: String
var designUseSystemColorPalette: Boolean
var designUseSystemDarkTheme: Boolean
var designDarkTheme: Boolean
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt
index dcab5955..79efed66 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt
@@ -8,11 +8,9 @@ import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
interface DownloadableModelRepository {
- fun isModelDownloaded(id: String): Single
fun download(id: String): Observable
fun delete(id: String): Completable
- fun getAll(): Single>
- fun getById(id: String): Single
- fun observeAll(): Flowable>
- fun select(id: String): Completable
+ fun getAllOnnx(): Single>
+ fun getAllMediaPipe(): Single>
+ fun observeAllOnnx(): Flowable>
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt
index 0e109843..3ed5c370 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/LocalDiffusionGenerationRepository.kt
@@ -1,14 +1,14 @@
package com.shifthackz.aisdv1.domain.repository
import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
-import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
interface LocalDiffusionGenerationRepository {
- fun observeStatus(): Observable
+ fun observeStatus(): Observable
fun generateFromText(payload: TextToImagePayload): Single
fun interruptGeneration(): Completable
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt
new file mode 100644
index 00000000..a7b65fb5
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/repository/MediaPipeGenerationRepository.kt
@@ -0,0 +1,11 @@
+package com.shifthackz.aisdv1.domain.repository
+
+import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
+import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
+import io.reactivex.rxjava3.core.Observable
+import io.reactivex.rxjava3.core.Single
+
+interface MediaPipeGenerationRepository {
+ fun generateFromText(payload: TextToImagePayload): Single
+}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt
new file mode 100644
index 00000000..cd93801f
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCase.kt
@@ -0,0 +1,8 @@
+package com.shifthackz.aisdv1.domain.usecase.downloadable
+
+import com.shifthackz.aisdv1.domain.entity.LocalAiModel
+import io.reactivex.rxjava3.core.Single
+
+interface GetLocalMediaPipeModelsUseCase {
+ operator fun invoke(): Single>
+}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt
new file mode 100644
index 00000000..4da8eeba
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalMediaPipeModelsUseCaseImpl.kt
@@ -0,0 +1,10 @@
+package com.shifthackz.aisdv1.domain.usecase.downloadable
+
+import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
+
+internal class GetLocalMediaPipeModelsUseCaseImpl(
+ private val downloadableModelRepository: DownloadableModelRepository,
+ ) : GetLocalMediaPipeModelsUseCase {
+
+ override fun invoke() = downloadableModelRepository.getAllMediaPipe()
+}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt
similarity index 84%
rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt
rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt
index efe71374..f2cce297 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCase.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCase.kt
@@ -3,6 +3,6 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Single
-interface GetLocalAiModelsUseCase {
+interface GetLocalOnnxModelsUseCase {
operator fun invoke(): Single>
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt
similarity index 59%
rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt
rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt
index 7bdbe7e7..444eccca 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImpl.kt
@@ -2,9 +2,9 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable
import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
-internal class GetLocalAiModelsUseCaseImpl(
+internal class GetLocalOnnxModelsUseCaseImpl(
private val downloadableModelRepository: DownloadableModelRepository,
-) : GetLocalAiModelsUseCase {
+) : GetLocalOnnxModelsUseCase {
- override fun invoke() = downloadableModelRepository.getAll()
+ override fun invoke() = downloadableModelRepository.getAllOnnx()
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt
similarity index 83%
rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt
rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt
index f79d31b1..021ebd1e 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCase.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCase.kt
@@ -3,6 +3,6 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import io.reactivex.rxjava3.core.Flowable
-interface ObserveLocalAiModelsUseCase {
+interface ObserveLocalOnnxModelsUseCase {
operator fun invoke(): Flowable>
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt
similarity index 70%
rename from domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt
rename to domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt
index 5fae54e1..e5e290f8 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImpl.kt
@@ -2,11 +2,11 @@ package com.shifthackz.aisdv1.domain.usecase.downloadable
import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
-internal class ObserveLocalAiModelsUseCaseImpl(
+internal class ObserveLocalOnnxModelsUseCaseImpl(
private val repository: DownloadableModelRepository,
-) : ObserveLocalAiModelsUseCase {
+) : ObserveLocalOnnxModelsUseCase {
override fun invoke() = repository
- .observeAll()
+ .observeAllOnnx()
.distinctUntilChanged()
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt
index 0f769197..37f6720c 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImpl.kt
@@ -17,7 +17,7 @@ internal class InterruptGenerationUseCaseImpl(
override fun invoke() = when (preferenceManager.source) {
ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.interruptGeneration()
ServerSource.HORDE -> hordeGenerationRepository.interruptGeneration()
- ServerSource.LOCAL -> localDiffusionGenerationRepository.interruptGeneration()
+ ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.interruptGeneration()
else -> Completable.complete()
}
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt
index 3a65f91b..9530b6b0 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCase.kt
@@ -1,8 +1,8 @@
package com.shifthackz.aisdv1.domain.usecase.generation
-import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import io.reactivex.rxjava3.core.Observable
interface ObserveLocalDiffusionProcessStatusUseCase {
- operator fun invoke(): Observable
+ operator fun invoke(): Observable
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt
index 45be1d30..f823a153 100755
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImpl.kt
@@ -7,6 +7,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
+import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository
import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository
import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository
@@ -22,6 +23,7 @@ internal class TextToImageUseCaseImpl(
private val stabilityAiGenerationRepository: StabilityAiGenerationRepository,
private val swarmUiGenerationRepository: SwarmUiGenerationRepository,
private val localDiffusionGenerationRepository: LocalDiffusionGenerationRepository,
+ private val mediaPipeGenerationRepository: MediaPipeGenerationRepository,
private val preferenceManager: PreferenceManager,
) : TextToImageUseCase {
@@ -34,11 +36,12 @@ internal class TextToImageUseCaseImpl(
private fun generate(payload: TextToImagePayload) = when (preferenceManager.source) {
ServerSource.HORDE -> hordeGenerationRepository.generateFromText(payload)
- ServerSource.LOCAL -> localDiffusionGenerationRepository.generateFromText(payload)
+ ServerSource.LOCAL_MICROSOFT_ONNX -> localDiffusionGenerationRepository.generateFromText(payload)
ServerSource.HUGGING_FACE -> huggingFaceGenerationRepository.generateFromText(payload)
ServerSource.AUTOMATIC1111 -> stableDiffusionGenerationRepository.generateFromText(payload)
ServerSource.OPEN_AI -> openAiGenerationRepository.generateFromText(payload)
ServerSource.STABILITY_AI -> stabilityAiGenerationRepository.generateFromText(payload)
ServerSource.SWARM_UI -> swarmUiGenerationRepository.generateFromText(payload)
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> mediaPipeGenerationRepository.generateFromText(payload)
}
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt
index 0bbed0c9..d2519425 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToLocalDiffusionUseCaseImpl.kt
@@ -11,8 +11,8 @@ internal class ConnectToLocalDiffusionUseCaseImpl(
override fun invoke(modelId: String) = getConfigurationUseCase()
.map { originalConfiguration ->
originalConfiguration.copy(
- source = ServerSource.LOCAL,
- localModelId = modelId,
+ source = ServerSource.LOCAL_MICROSOFT_ONNX,
+ localOnnxModelId = modelId,
)
}
.flatMapCompletable(setServerConfigurationUseCase::invoke)
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt
new file mode 100644
index 00000000..5c1027e6
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCase.kt
@@ -0,0 +1,7 @@
+package com.shifthackz.aisdv1.domain.usecase.settings
+
+import io.reactivex.rxjava3.core.Single
+
+interface ConnectToMediaPipeUseCase {
+ operator fun invoke(modelId: String): Single>
+}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt
new file mode 100644
index 00000000..a60bdc68
--- /dev/null
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/ConnectToMediaPipeUseCaseImpl.kt
@@ -0,0 +1,21 @@
+package com.shifthackz.aisdv1.domain.usecase.settings
+
+import com.shifthackz.aisdv1.domain.entity.ServerSource
+import io.reactivex.rxjava3.core.Single
+
+internal class ConnectToMediaPipeUseCaseImpl(
+ private val getConfigurationUseCase: GetConfigurationUseCase,
+ private val setServerConfigurationUseCase: SetServerConfigurationUseCase,
+) : ConnectToMediaPipeUseCase {
+
+ override fun invoke(modelId: String): Single> = getConfigurationUseCase()
+ .map { originalConfiguration ->
+ originalConfiguration.copy(
+ source = ServerSource.LOCAL_GOOGLE_MEDIA_PIPE,
+ localMediaPipeModelId = modelId,
+ )
+ }
+ .flatMapCompletable(setServerConfigurationUseCase::invoke)
+ .andThen(Single.just(Result.success(Unit)))
+ .onErrorResumeNext { t -> Single.just(Result.failure(t)) }
+}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt
index 4070a0a6..e88ebd63 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt
@@ -24,8 +24,10 @@ internal class GetConfigurationUseCaseImpl(
stabilityAiApiKey = preferenceManager.stabilityAiApiKey,
stabilityAiEngineId = preferenceManager.stabilityAiEngineId,
authCredentials = authorizationStore.getAuthorizationCredentials(),
- localModelId = preferenceManager.localModelId,
- localModelPath = preferenceManager.localDiffusionCustomModelPath,
+ localOnnxModelId = preferenceManager.localOnnxModelId,
+ localOnnxModelPath = preferenceManager.localOnnxCustomModelPath,
+ localMediaPipeModelId = preferenceManager.localMediaPipeModelId,
+ localMediaPipeModelPath = preferenceManager.localMediaPipeCustomModelPath,
)
)
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt
index 8d03619c..aab29abc 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt
@@ -24,7 +24,9 @@ internal class SetServerConfigurationUseCaseImpl(
preferenceManager.huggingFaceModel = configuration.huggingFaceModel
preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey
preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId
- preferenceManager.localModelId = configuration.localModelId
- preferenceManager.localDiffusionCustomModelPath = configuration.localModelPath
+ preferenceManager.localOnnxModelId = configuration.localOnnxModelId
+ preferenceManager.localOnnxCustomModelPath = configuration.localOnnxModelPath
+ preferenceManager.localMediaPipeModelId = configuration.localMediaPipeModelId
+ preferenceManager.localMediaPipeCustomModelPath = configuration.localMediaPipeModelPath
}
}
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt
index 73503575..da1d486a 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt
@@ -12,6 +12,6 @@ val mockConfiguration = Configuration(
huggingFaceModel = "5598",
stabilityAiApiKey = "5598",
stabilityAiEngineId = "5598",
- localModelId = "5598",
- localModelPath = "/storage/emulated/0/5598",
+ localOnnxModelId = "5598",
+ localOnnxModelPath = "/storage/emulated/0/5598",
)
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt
index 69fcf3f0..5ce3fa3b 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/LocalAiModelMocks.kt
@@ -3,9 +3,10 @@ package com.shifthackz.aisdv1.domain.mocks
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
val mockLocalAiModels = listOf(
- LocalAiModel.CUSTOM,
+ LocalAiModel.CustomOnnx,
LocalAiModel(
id = "1",
+ type = LocalAiModel.Type.ONNX,
name = "Model 1",
size = "5 Gb",
sources = listOf("https://example.com/1.html"),
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt
similarity index 84%
rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt
rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt
index 679ca297..d154c163 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalAiModelsUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/GetLocalOnnxModelsUseCaseImplTest.kt
@@ -7,15 +7,15 @@ import com.shifthackz.aisdv1.domain.repository.DownloadableModelRepository
import io.reactivex.rxjava3.core.Single
import org.junit.Test
-class GetLocalAiModelsUseCaseImplTest {
+class GetLocalOnnxModelsUseCaseImplTest {
private val stubRepository = mock()
- private val useCase = GetLocalAiModelsUseCaseImpl(stubRepository)
+ private val useCase = GetLocalOnnxModelsUseCaseImpl(stubRepository)
@Test
fun `given repository returned models list, expected valid models list value`() {
- whenever(stubRepository.getAll())
+ whenever(stubRepository.getAllOnnx())
.thenReturn(Single.just(mockLocalAiModels))
useCase()
@@ -28,7 +28,7 @@ class GetLocalAiModelsUseCaseImplTest {
@Test
fun `given repository returned empty models list, expected empty models list value`() {
- whenever(stubRepository.getAll())
+ whenever(stubRepository.getAllOnnx())
.thenReturn(Single.just(emptyList()))
useCase()
@@ -43,7 +43,7 @@ class GetLocalAiModelsUseCaseImplTest {
fun `given repository thrown exception, expected error value`() {
val stubException = Throwable("Unable to collect local models.")
- whenever(stubRepository.getAll())
+ whenever(stubRepository.getAllOnnx())
.thenReturn(Single.error(stubException))
useCase()
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt
similarity index 91%
rename from domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt
rename to domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt
index 00fde27a..a61684b1 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalAiModelsUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/ObserveLocalOnnxModelsUseCaseImplTest.kt
@@ -11,16 +11,16 @@ import io.reactivex.rxjava3.subjects.BehaviorSubject
import org.junit.Before
import org.junit.Test
-class ObserveLocalAiModelsUseCaseImplTest {
+class ObserveLocalOnnxModelsUseCaseImplTest {
private val stubLocalModels = BehaviorSubject.create>()
private val stubRepository = mock()
- private val useCase = ObserveLocalAiModelsUseCaseImpl(stubRepository)
+ private val useCase = ObserveLocalOnnxModelsUseCaseImpl(stubRepository)
@Before
fun initialize() {
- whenever(stubRepository.observeAll())
+ whenever(stubRepository.observeAllOnnx())
.thenReturn(stubLocalModels.toFlowable(BackpressureStrategy.LATEST))
}
@@ -68,7 +68,7 @@ class ObserveLocalAiModelsUseCaseImplTest {
.assertNoErrors()
.assertValueAt(0, mockLocalAiModels)
- val changedLocalAiModels = listOf(LocalAiModel.CUSTOM)
+ val changedLocalAiModels = listOf(LocalAiModel.CustomOnnx)
stubLocalModels.onNext(changedLocalAiModels)
stubObserver
@@ -97,7 +97,7 @@ class ObserveLocalAiModelsUseCaseImplTest {
fun `given observer terminates with unexpected error, expected receive error value`() {
val stubException = Throwable("Unexpected Flowable termination.")
- whenever(stubRepository.observeAll())
+ whenever(stubRepository.observeAllOnnx())
.thenReturn(Flowable.error(stubException))
useCase()
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt
index 9f2fac3b..5641d4f2 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ImageToImageUseCaseImplTest.kt
@@ -293,7 +293,7 @@ class ImageToImageUseCaseImplTest {
@Test
fun `given source is LOCAL, expected Img2Img not yet supported error`() {
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
useCase(mockImageToImagePayload)
.test()
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt
index 5ef6334c..895bb8a8 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/InterruptGenerationUseCaseImplTest.kt
@@ -88,7 +88,7 @@ class InterruptGenerationUseCaseImplTest {
@Test
fun `given source is LOCAL, api interrupt success, expected complete value`() {
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
whenever(stubLocalDiffusionGenerationRepository.interruptGeneration())
.thenReturn(Completable.complete())
@@ -103,7 +103,7 @@ class InterruptGenerationUseCaseImplTest {
@Test
fun `given source is LOCAL, api interrupt fail, expected error value`() {
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
whenever(stubLocalDiffusionGenerationRepository.interruptGeneration())
.thenReturn(Completable.error(stubException))
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt
index 528e2e6e..47836c4e 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/ObserveLocalDiffusionProcessStatusUseCaseImplTest.kt
@@ -2,7 +2,7 @@ package com.shifthackz.aisdv1.domain.usecase.generation
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.whenever
-import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.subjects.BehaviorSubject
@@ -12,7 +12,7 @@ import org.junit.Test
class ObserveLocalDiffusionProcessStatusUseCaseImplTest {
private val stubException = Throwable("Error loading Local Diffusion.")
- private val stubLocalStatus = BehaviorSubject.create()
+ private val stubLocalStatus = BehaviorSubject.create()
private val stubRepository = mock()
private val useCase = ObserveLocalDiffusionProcessStatusUseCaseImpl(stubRepository)
@@ -27,23 +27,23 @@ class ObserveLocalDiffusionProcessStatusUseCaseImplTest {
fun `given repository processes three steps, expected three valid status values`() {
val stubObserver = useCase().test()
- stubLocalStatus.onNext(LocalDiffusion.Status(1, 3))
+ stubLocalStatus.onNext(LocalDiffusionStatus(1, 3))
stubObserver
.assertNoErrors()
- .assertValueAt(0, LocalDiffusion.Status(1, 3))
+ .assertValueAt(0, LocalDiffusionStatus(1, 3))
- stubLocalStatus.onNext(LocalDiffusion.Status(2, 3))
+ stubLocalStatus.onNext(LocalDiffusionStatus(2, 3))
stubObserver
.assertNoErrors()
- .assertValueAt(1, LocalDiffusion.Status(2, 3))
+ .assertValueAt(1, LocalDiffusionStatus(2, 3))
- stubLocalStatus.onNext(LocalDiffusion.Status(3, 3))
+ stubLocalStatus.onNext(LocalDiffusionStatus(3, 3))
stubObserver
.assertNoErrors()
- .assertValueAt(2, LocalDiffusion.Status(3, 3))
+ .assertValueAt(2, LocalDiffusionStatus(3, 3))
.assertValueCount(3)
}
@@ -51,23 +51,23 @@ class ObserveLocalDiffusionProcessStatusUseCaseImplTest {
fun `given repository processes two steps, emits same step twice, expected two valid status values`() {
val stubObserver = useCase().test()
- stubLocalStatus.onNext(LocalDiffusion.Status(1, 2))
+ stubLocalStatus.onNext(LocalDiffusionStatus(1, 2))
stubObserver
.assertNoErrors()
- .assertValueAt(0, LocalDiffusion.Status(1, 2))
+ .assertValueAt(0, LocalDiffusionStatus(1, 2))
- stubLocalStatus.onNext(LocalDiffusion.Status(1, 2))
+ stubLocalStatus.onNext(LocalDiffusionStatus(1, 2))
stubObserver
.assertNoErrors()
.assertValueCount(1)
- stubLocalStatus.onNext(LocalDiffusion.Status(2, 2))
+ stubLocalStatus.onNext(LocalDiffusionStatus(2, 2))
stubObserver
.assertNoErrors()
- .assertValueAt(1, LocalDiffusion.Status(2, 2))
+ .assertValueAt(1, LocalDiffusionStatus(2, 2))
.assertValueCount(2)
}
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt
index b6d3bced..98505eac 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/generation/TextToImageUseCaseImplTest.kt
@@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.repository.HordeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.HuggingFaceGenerationRepository
import com.shifthackz.aisdv1.domain.repository.LocalDiffusionGenerationRepository
+import com.shifthackz.aisdv1.domain.repository.MediaPipeGenerationRepository
import com.shifthackz.aisdv1.domain.repository.OpenAiGenerationRepository
import com.shifthackz.aisdv1.domain.repository.StabilityAiGenerationRepository
import com.shifthackz.aisdv1.domain.repository.StableDiffusionGenerationRepository
@@ -27,6 +28,7 @@ class TextToImageUseCaseImplTest {
private val stubStabilityAiGenerationRepository = mock()
private val stubSwarmUiGenerationRepository = mock()
private val stubLocalDiffusionGenerationRepository = mock()
+ private val stubMediaPipeGenerationRepository = mock()
private val stubPreferenceManager = mock()
private val useCase = TextToImageUseCaseImpl(
@@ -37,6 +39,7 @@ class TextToImageUseCaseImplTest {
stabilityAiGenerationRepository = stubStabilityAiGenerationRepository,
localDiffusionGenerationRepository = stubLocalDiffusionGenerationRepository,
swarmUiGenerationRepository = stubSwarmUiGenerationRepository,
+ mediaPipeGenerationRepository = stubMediaPipeGenerationRepository,
preferenceManager = stubPreferenceManager,
)
@@ -363,7 +366,7 @@ class TextToImageUseCaseImplTest {
@Test
fun `given source is LOCAL, batch count is 1, generated successfully, expected generations list with size 1`() {
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
whenever(stubLocalDiffusionGenerationRepository.generateFromText(any()))
.thenReturn(Single.just(mockAiGenerationResult))
@@ -386,7 +389,7 @@ class TextToImageUseCaseImplTest {
@Test
fun `given source is LOCAL, batch count is 10, generated successfully, expected generations list with size 10`() {
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
whenever(stubLocalDiffusionGenerationRepository.generateFromText(any()))
.thenReturn(Single.just(mockAiGenerationResult))
@@ -409,7 +412,7 @@ class TextToImageUseCaseImplTest {
@Test
fun `given source is LOCAL, batch count is 1, generate failed, expected error`() {
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
whenever(stubLocalDiffusionGenerationRepository.generateFromText(any()))
.thenReturn(Single.error(stubException))
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt
index 55290a49..5a5d12bc 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt
@@ -69,12 +69,20 @@ class GetConfigurationUseCaseImplTest {
} returns mockConfiguration.stabilityAiEngineId
every {
- stubPreferenceManager::localModelId.get()
- } returns mockConfiguration.localModelId
+ stubPreferenceManager::localOnnxModelId.get()
+ } returns mockConfiguration.localOnnxModelId
every {
- stubPreferenceManager::localDiffusionCustomModelPath.get()
- } returns mockConfiguration.localModelPath
+ stubPreferenceManager::localOnnxCustomModelPath.get()
+ } returns mockConfiguration.localOnnxModelPath
+
+ every {
+ stubPreferenceManager::localMediaPipeModelId.get()
+ } returns mockConfiguration.localMediaPipeModelId
+
+ every {
+ stubPreferenceManager::localMediaPipeCustomModelPath.get()
+ } returns mockConfiguration.localMediaPipeModelPath
useCase
.invoke()
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt
index 62a74c34..a1620120 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt
@@ -68,11 +68,19 @@ class SetServerConfigurationUseCaseImplTest {
} returns Unit
every {
- stubPreferenceManager::localModelId.set(any())
+ stubPreferenceManager::localOnnxModelId.set(any())
} returns Unit
every {
- stubPreferenceManager::localDiffusionCustomModelPath.set(any())
+ stubPreferenceManager::localOnnxCustomModelPath.set(any())
+ } returns Unit
+
+ every {
+ stubPreferenceManager::localMediaPipeModelId.set(any())
+ } returns Unit
+
+ every {
+ stubPreferenceManager::localMediaPipeCustomModelPath.set(any())
} returns Unit
useCase
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt
index 1c2b3a0d..092b4c19 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/splash/SplashNavigationUseCaseImplTest.kt
@@ -89,7 +89,7 @@ class SplashNavigationUseCaseImplTest {
.thenReturn("")
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
useCase()
.test()
@@ -109,7 +109,7 @@ class SplashNavigationUseCaseImplTest {
.thenReturn("http://192.168.0.1:7860")
whenever(stubPreferenceManager.source)
- .thenReturn(ServerSource.LOCAL)
+ .thenReturn(ServerSource.LOCAL_MICROSOFT_ONNX)
useCase()
.test()
diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt
index 56b3f39d..775cad2c 100644
--- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt
+++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/LocalDiffusionImpl.kt
@@ -4,6 +4,7 @@ import ai.onnxruntime.OnnxTensor
import android.graphics.Bitmap
import com.shifthackz.aisdv1.core.common.log.debugLog
import com.shifthackz.aisdv1.core.common.log.errorLog
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.TAG
@@ -20,7 +21,7 @@ internal class LocalDiffusionImpl(
private val ortEnvironmentProvider: OrtEnvironmentProvider,
) : LocalDiffusion {
- private val statusSubject: PublishSubject = PublishSubject.create()
+ private val statusSubject: PublishSubject = PublishSubject.create()
override fun process(payload: TextToImagePayload): Single = Single.create { emitter ->
try {
@@ -31,7 +32,7 @@ internal class LocalDiffusionImpl(
uNet.setCallback(object : UNet.Callback {
override fun onStep(maxStep: Int, step: Int) {
debugLog(TAG, "Received step update: ${maxStep}/${step}")
- statusSubject.onNext(LocalDiffusion.Status(step, maxStep))
+ statusSubject.onNext(LocalDiffusionStatus(step, maxStep))
}
override fun onBuildImage(status: Int, bitmap: Bitmap?) {
diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt
index 7cb6df87..1b22e6d9 100644
--- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt
+++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt
@@ -11,8 +11,8 @@ fun modelPathPrefix(
localModelIdProvider: LocalModelIdProvider,
): String {
val modelId = localModelIdProvider.get()
- return if (modelId == LocalAiModel.CUSTOM.id) {
- preferenceManager.localDiffusionCustomModelPath
+ return if (modelId == LocalAiModel.CustomOnnx.id) {
+ preferenceManager.localOnnxCustomModelPath
} else {
"${fileProviderDescriptor.localModelDirPath}/${modelId}"
}
diff --git a/feature/mediapipe/.gitignore b/feature/mediapipe/.gitignore
new file mode 100644
index 00000000..42afabfd
--- /dev/null
+++ b/feature/mediapipe/.gitignore
@@ -0,0 +1 @@
+/build
\ No newline at end of file
diff --git a/feature/mediapipe/build.gradle.kts b/feature/mediapipe/build.gradle.kts
new file mode 100644
index 00000000..0249feed
--- /dev/null
+++ b/feature/mediapipe/build.gradle.kts
@@ -0,0 +1,17 @@
+plugins {
+ alias(libs.plugins.generic.library)
+ alias(libs.plugins.generic.flavors)
+}
+
+android {
+ namespace = "com.shifthackz.aisdv1.feature.mediapipe"
+}
+
+dependencies {
+ implementation(project(":core:common"))
+ implementation(project(":domain"))
+ implementation(libs.koin.core)
+ implementation(libs.rx.kotlin)
+ fullImplementation(libs.google.mediapipe.image.generator)
+ playstoreImplementation(libs.google.mediapipe.image.generator)
+}
diff --git a/feature/mediapipe/consumer-rules.pro b/feature/mediapipe/consumer-rules.pro
new file mode 100644
index 00000000..e69de29b
diff --git a/feature/mediapipe/proguard-rules.pro b/feature/mediapipe/proguard-rules.pro
new file mode 100644
index 00000000..481bb434
--- /dev/null
+++ b/feature/mediapipe/proguard-rules.pro
@@ -0,0 +1,21 @@
+# Add project specific ProGuard rules here.
+# You can control the set of applied configuration files using the
+# proguardFiles setting in build.gradle.
+#
+# For more details, see
+# http://developer.android.com/guide/developing/tools/proguard.html
+
+# If your project uses WebView with JS, uncomment the following
+# and specify the fully qualified class name to the JavaScript interface
+# class:
+#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
+# public *;
+#}
+
+# Uncomment this to preserve the line number information for
+# debugging stack traces.
+#-keepattributes SourceFile,LineNumberTable
+
+# If you keep the line number information, uncomment this to
+# hide the original source file name.
+#-renamesourcefileattribute SourceFile
\ No newline at end of file
diff --git a/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt
new file mode 100644
index 00000000..3be83e5e
--- /dev/null
+++ b/feature/mediapipe/src/foss/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt
@@ -0,0 +1,13 @@
+package com.shifthackz.aisdv1.feature.mediapipe
+
+import android.graphics.Bitmap
+import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
+import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe
+import io.reactivex.rxjava3.core.Single
+
+internal class MediaPipeImpl : MediaPipe {
+
+ override fun process(payload: TextToImagePayload): Single {
+ return Single.error(IllegalStateException("Google AI MediaPipe is not supported on FOSS build."))
+ }
+}
diff --git a/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt
new file mode 100644
index 00000000..a6e7e98a
--- /dev/null
+++ b/feature/mediapipe/src/full/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt
@@ -0,0 +1,63 @@
+package com.shifthackz.aisdv1.feature.mediapipe
+
+import android.content.Context
+import android.graphics.Bitmap
+import com.google.mediapipe.framework.image.BitmapExtractor
+import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
+import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions
+import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
+import com.shifthackz.aisdv1.core.common.log.debugLog
+import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
+import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
+import com.shifthackz.aisdv1.feature.mediapipe.extensions.modelPath
+import io.reactivex.rxjava3.core.Single
+
+internal class MediaPipeImpl(
+ private val context: Context,
+ private val preferenceManager: PreferenceManager,
+ private val fileProviderDescriptor: FileProviderDescriptor,
+) : MediaPipe {
+
+ private var imageGenerator: ImageGenerator? = null
+
+ override fun process(payload: TextToImagePayload): Single = Single.create { emitter ->
+ try {
+ initialize()
+ debugLog("Generating...")
+ val result = imageGenerator?.generate(
+ payload.prompt,
+ payload.samplingSteps,
+ payload.seed.toIntOrNull() ?: 0,
+ )
+ debugLog("Extracting bitmap...")
+ val bitmap = BitmapExtractor.extract(result?.generatedImage())
+ debugLog("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}")
+ close()
+ if (!emitter.isDisposed) emitter.onSuccess(bitmap)
+ } catch (e: Exception) {
+ close()
+ if (!emitter.isDisposed) emitter.onError(e)
+ }
+ }
+
+ private fun initialize(): ImageGenerator {
+ val path = modelPath(preferenceManager, fileProviderDescriptor)
+
+ val options = ImageGeneratorOptions.builder()
+ .setImageGeneratorModelDirectory(path)
+ .build()
+
+ val generator = ImageGenerator.createFromOptions(context, options)
+ imageGenerator = generator
+ debugLog("Initialized successfully! Path: $path")
+ return generator
+ }
+
+ private fun close() = runCatching {
+ debugLog("Closing...")
+ imageGenerator?.close()
+ imageGenerator = null
+ debugLog("Session closed!")
+ }
+}
diff --git a/feature/mediapipe/src/main/AndroidManifest.xml b/feature/mediapipe/src/main/AndroidManifest.xml
new file mode 100644
index 00000000..44008a43
--- /dev/null
+++ b/feature/mediapipe/src/main/AndroidManifest.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt
new file mode 100644
index 00000000..3327827e
--- /dev/null
+++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/di/MediaPipeModule.kt
@@ -0,0 +1,11 @@
+package com.shifthackz.aisdv1.feature.mediapipe.di
+
+import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe
+import com.shifthackz.aisdv1.feature.mediapipe.MediaPipeImpl
+import org.koin.core.module.dsl.factoryOf
+import org.koin.dsl.bind
+import org.koin.dsl.module
+
+val mediaPipeModule = module {
+ factoryOf(::MediaPipeImpl) bind MediaPipe::class
+}
diff --git a/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt
new file mode 100644
index 00000000..91e49404
--- /dev/null
+++ b/feature/mediapipe/src/main/java/com/shifthackz/aisdv1/feature/mediapipe/extensions/MediaPipeModelPaths.kt
@@ -0,0 +1,17 @@
+package com.shifthackz.aisdv1.feature.mediapipe.extensions
+
+import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
+import com.shifthackz.aisdv1.domain.entity.LocalAiModel
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
+
+fun modelPath(
+ preferenceManager: PreferenceManager,
+ fileProviderDescriptor: FileProviderDescriptor,
+): String {
+ val modelId = preferenceManager.localMediaPipeModelId
+ return if (modelId == LocalAiModel.CustomMediaPipe.id) {
+ preferenceManager.localMediaPipeCustomModelPath
+ } else {
+ "${fileProviderDescriptor.localModelDirPath}/${modelId}"
+ }
+}
diff --git a/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt b/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt
new file mode 100644
index 00000000..a6e7e98a
--- /dev/null
+++ b/feature/mediapipe/src/playstore/java/com/shifthackz/aisdv1/feature/mediapipe/MediaPipeImpl.kt
@@ -0,0 +1,63 @@
+package com.shifthackz.aisdv1.feature.mediapipe
+
+import android.content.Context
+import android.graphics.Bitmap
+import com.google.mediapipe.framework.image.BitmapExtractor
+import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator
+import com.google.mediapipe.tasks.vision.imagegenerator.ImageGenerator.ImageGeneratorOptions
+import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
+import com.shifthackz.aisdv1.core.common.log.debugLog
+import com.shifthackz.aisdv1.domain.entity.TextToImagePayload
+import com.shifthackz.aisdv1.domain.feature.mediapipe.MediaPipe
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
+import com.shifthackz.aisdv1.feature.mediapipe.extensions.modelPath
+import io.reactivex.rxjava3.core.Single
+
+internal class MediaPipeImpl(
+ private val context: Context,
+ private val preferenceManager: PreferenceManager,
+ private val fileProviderDescriptor: FileProviderDescriptor,
+) : MediaPipe {
+
+ private var imageGenerator: ImageGenerator? = null
+
+ override fun process(payload: TextToImagePayload): Single = Single.create { emitter ->
+ try {
+ initialize()
+ debugLog("Generating...")
+ val result = imageGenerator?.generate(
+ payload.prompt,
+ payload.samplingSteps,
+ payload.seed.toIntOrNull() ?: 0,
+ )
+ debugLog("Extracting bitmap...")
+ val bitmap = BitmapExtractor.extract(result?.generatedImage())
+ debugLog("bitmap = $bitmap, ${bitmap.width}X${bitmap.height}")
+ close()
+ if (!emitter.isDisposed) emitter.onSuccess(bitmap)
+ } catch (e: Exception) {
+ close()
+ if (!emitter.isDisposed) emitter.onError(e)
+ }
+ }
+
+ private fun initialize(): ImageGenerator {
+ val path = modelPath(preferenceManager, fileProviderDescriptor)
+
+ val options = ImageGeneratorOptions.builder()
+ .setImageGeneratorModelDirectory(path)
+ .build()
+
+ val generator = ImageGenerator.createFromOptions(context, options)
+ imageGenerator = generator
+ debugLog("Initialized successfully! Path: $path")
+ return generator
+ }
+
+ private fun close() = runCatching {
+ debugLog("Closing...")
+ imageGenerator?.close()
+ imageGenerator = null
+ debugLog("Session closed!")
+ }
+}
diff --git a/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt b/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt
new file mode 100644
index 00000000..412481c5
--- /dev/null
+++ b/feature/mediapipe/src/test/java/com/shifthackz/aisdv1/feature/mediapipe/ExampleUnitTest.kt
@@ -0,0 +1,16 @@
+package com.shifthackz.aisdv1.feature.mediapipe
+
+import org.junit.Assert.*
+import org.junit.Test
+
+/**
+ * Example local unit test, which will execute on the development machine (host).
+ *
+ * See [testing documentation](http://d.android.com/tools/testing).
+ */
+class ExampleUnitTest {
+ @Test
+ fun addition_isCorrect() {
+ assertEquals(4, 2 + 2)
+ }
+}
\ No newline at end of file
diff --git a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt
index d3076dee..b77330ee 100644
--- a/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt
+++ b/feature/work/src/main/java/com/shifthackz/aisdv1/work/core/CoreGenerationWorker.kt
@@ -93,7 +93,7 @@ internal abstract class CoreGenerationWorker(
body = subTitle,
silent = true,
progress = status.current to status.total,
- canCancel = preferenceManager.localDiffusionAllowCancel,
+ canCancel = preferenceManager.localOnnxAllowCancel,
)
}
}
@@ -112,7 +112,7 @@ internal abstract class CoreGenerationWorker(
setForegroundNotification(
title = title,
body = subTitle,
- canCancel = source != ServerSource.LOCAL,
+ canCancel = source != ServerSource.LOCAL_MICROSOFT_ONNX,
)
}
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index d2f0faac..d1a1645c 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -46,6 +46,7 @@ catppuccin = "0.1.2"
turbine = "1.1.0"
roboelectric = "4.13"
testCoroutines = "1.8.1"
+mediaPipeGenerator = "0.10.14"
[libraries]
android-tools-build-gradle = { group = "com.android.tools.build", name = "gradle", version.ref = "agp"}
@@ -77,6 +78,7 @@ androidx-work-runtime = { group = "androidx.work", name = "work-runtime", versio
google-gson = { group = "com.google.code.gson", name = "gson", version.ref = "gson" }
google-material = { group = "com.google.android.material", name = "material", version.ref = "material" }
google-accompanist-systemuicontroller = { group = "com.google.accompanist", name = "accompanist-systemuicontroller", version.ref = "accompanistSystemUi" }
+google-mediapipe-image-generator = { group = "com.google.mediapipe", name = "tasks-vision-image-generator", version.ref = "mediaPipeGenerator" }
retrofit-core = { group = "com.squareup.retrofit2", name = "retrofit", version.ref = "retrofit" }
retrofit-converter-gson = { group = "com.squareup.retrofit2", name = "converter-gson", version.ref = "retrofit" }
retrofit-adapter-rxjava3 = { group = "com.squareup.retrofit2", name = "adapter-rxjava3", version.ref = "retrofit" }
@@ -115,6 +117,7 @@ android-application = { id = "com.android.application", version.ref = "agp" }
android-library = { id = "com.android.library", version.ref = "agp" }
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
jetbrains-kotlin-kapt = { id = "org.jetbrains.kotlin.kapt", version="unspecified" }
+generic-flavors = { id = "generic.flavors", version = "unspecified" }
generic-library = { id = "generic.library", version = "unspecified" }
generic-baseline-profm = { id = "generic.baseline.profm", version = "unspecified" }
generic-compose = { id = "generic.compose", version = "unspecified" }
diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt
index 90adc17b..78a35382 100644
--- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt
+++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApi.kt
@@ -11,7 +11,9 @@ import java.io.File
interface DownloadableModelsApi {
- fun fetchDownloadableModels(): Single>
+ fun fetchOnnxModels(): Single>
+
+ fun fetchMediaPipeModels(): Single>
fun downloadModel(
remoteUrl: String,
@@ -23,7 +25,10 @@ interface DownloadableModelsApi {
interface RawApi {
@GET("/models.json")
- fun fetchDownloadableModels(): Single>
+ fun fetchOnnxModels(): Single>
+
+ @GET("/mediapipe.json")
+ fun fetchMediaPipeModels(): Single>
@Streaming
@GET
diff --git a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt
index 88ddc5ef..04fd71c1 100644
--- a/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt
+++ b/network/src/main/java/com/shifthackz/aisdv1/network/api/sdai/DownloadableModelsApiImpl.kt
@@ -1,6 +1,7 @@
package com.shifthackz.aisdv1.network.api.sdai
import com.shifthackz.aisdv1.network.extensions.saveFile
+import com.shifthackz.aisdv1.network.response.DownloadableModelResponse
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import java.io.File
@@ -9,7 +10,9 @@ internal class DownloadableModelsApiImpl(
private val rawApi: DownloadableModelsApi.RawApi,
) : DownloadableModelsApi {
- override fun fetchDownloadableModels() = rawApi.fetchDownloadableModels()
+ override fun fetchOnnxModels() = rawApi.fetchOnnxModels()
+
+ override fun fetchMediaPipeModels() = rawApi.fetchMediaPipeModels()
override fun downloadModel(
remoteUrl: String,
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt
index 714f0ef9..689fc79c 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/core/GenerationMviViewModel.kt
@@ -9,6 +9,7 @@ import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread
import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator
import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel
import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.entity.OpenAiSize
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.entity.StabilityAiSampler
@@ -118,7 +119,7 @@ abstract class GenerationMviViewModel?
get() = status?.let { (current, total) -> current to total }
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt
index 125eb82c..a9a2a438 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/navigation/router/main/MainRouterImpl.kt
@@ -47,7 +47,7 @@ internal class MainRouterImpl : MainRouter {
override fun navigateToServerSetup(source: LaunchSource) {
effectSubject.onNext(NavigationEffect.Navigate.RouteBuilder("${Constants.ROUTE_SERVER_SETUP}/${source.ordinal}") {
if (source == LaunchSource.SPLASH) {
- popUpTo(Constants.ROUTE_SPLASH) {
+ popUpTo(0) {
inclusive = true
}
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt
index 0d2fd5df..c6b5f7de 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/debug/DebugMenuViewModel.kt
@@ -64,7 +64,7 @@ class DebugMenuViewModel(
DebugMenuIntent.ViewLogs -> mainRouter.navigateToLogger()
DebugMenuIntent.AllowLocalDiffusionCancel -> {
- preferenceManager.localDiffusionAllowCancel = !currentState.localDiffusionAllowCancel
+ preferenceManager.localOnnxAllowCancel = !currentState.localDiffusionAllowCancel
}
DebugMenuIntent.LocalDiffusionScheduler.Request -> updateState {
@@ -72,7 +72,7 @@ class DebugMenuViewModel(
}
is DebugMenuIntent.LocalDiffusionScheduler.Confirm -> {
- preferenceManager.localDiffusionSchedulerThread = intent.token
+ preferenceManager.localOnnxSchedulerThread = intent.token
}
DebugMenuIntent.DismissModal -> updateState {
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt
index d5ad617a..482ea2d9 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryPagingSource.kt
@@ -32,7 +32,7 @@ class GalleryPagingSource(
limit = pageSize,
offset = pageNext * Constants.PAGINATION_PAYLOAD_SIZE,
)
- .subscribeOn(schedulersProvider.io)
+ .subscribeOn(schedulersProvider.computation)
.flatMapObservable { Observable.fromIterable(it) }
.map { ai -> ai.id to ai.image }
.map { (id, base64) -> id to Input(base64) }
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt
index a6052000..67572859 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/gallery/list/GalleryScreen.kt
@@ -468,7 +468,7 @@ fun GalleryScreenContent(
val selected = state.selection.contains(item.id)
GalleryUiItem(
modifier = Modifier
- .animateItemPlacement(tween(500))
+ .animateItem(tween(500))
.shake(
enabled = state.selectionMode && !selected,
animationDurationMillis = 188,
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt
index fe8eeec1..000bca56 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageScreen.kt
@@ -54,6 +54,9 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
import com.shifthackz.aisdv1.core.common.math.roundTo
+import com.shifthackz.aisdv1.core.model.UiText
+import com.shifthackz.aisdv1.core.model.asString
+import com.shifthackz.aisdv1.core.model.asUiText
import com.shifthackz.aisdv1.core.ui.MviComponent
import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
import com.shifthackz.aisdv1.domain.entity.ServerSource
@@ -135,7 +138,7 @@ private fun ScreenContent(
)
},
actions = {
- if (state.mode != ServerSource.LOCAL) {
+ if (state.mode != ServerSource.LOCAL_MICROSOFT_ONNX) {
IconButton(
onClick = {
processIntent(
@@ -233,17 +236,33 @@ private fun ScreenContent(
)
Text(
modifier = Modifier.padding(top = 14.dp),
- text = stringResource(
- if (state.mode == ServerSource.LOCAL) LocalizationR.string.local_no_img2img_support_sub_title
- else LocalizationR.string.dalle_no_img2img_support_sub_title
- ),
+ text = when (state.mode) {
+ ServerSource.OPEN_AI -> LocalizationR.string
+ .dalle_no_img2img_support_sub_title
+ .asUiText()
+
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string
+ .local_no_img2img_support_sub_title
+ .asUiText()
+
+ else -> UiText.empty
+ }.asString(),
)
Text(
modifier = Modifier.padding(top = 14.dp),
- text = stringResource(
- if (state.mode == ServerSource.LOCAL) LocalizationR.string.local_no_img2img_support_sub_title_2
- else LocalizationR.string.dalle_no_img2img_support_sub_title_2
- ),
+ text = when (state.mode) {
+ ServerSource.OPEN_AI -> LocalizationR.string
+ .dalle_no_img2img_support_sub_title_2
+ .asUiText()
+
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string
+ .local_no_img2img_support_sub_title_2
+ .asUiText()
+
+ else -> UiText.empty
+ }.asString(),
)
}
}
@@ -251,7 +270,8 @@ private fun ScreenContent(
},
bottomBar = {
val isEnabled = when (state.mode) {
- ServerSource.LOCAL,
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE,
ServerSource.OPEN_AI -> true
else -> !state.hasValidationErrors && !state.imageState.isEmpty
@@ -271,20 +291,28 @@ private fun ScreenContent(
keyboardController?.hide()
when (state.mode) {
ServerSource.OPEN_AI,
- ServerSource.LOCAL -> processIntent(GenerationMviIntent.Configuration)
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE,
+ ServerSource.LOCAL_MICROSOFT_ONNX -> processIntent(
+ GenerationMviIntent.Configuration
+ )
else -> {
promptChipTextFieldState.value.text.takeIf(String::isNotBlank)
?.let { "${state.prompt}, ${it.trim()}" }
?.let(GenerationMviIntent.Update::Prompt)
?.let(processIntent::invoke)
- ?.also { promptChipTextFieldState.value = TextFieldValue("") }
+ ?.also {
+ promptChipTextFieldState.value = TextFieldValue("")
+ }
negativePromptChipTextFieldState.value.text.takeIf(String::isNotBlank)
?.let { "${state.negativePrompt}, ${it.trim()}" }
?.let(GenerationMviIntent.Update::NegativePrompt)
?.let(processIntent::invoke)
- ?.also { negativePromptChipTextFieldState.value = TextFieldValue("") }
+ ?.also {
+ negativePromptChipTextFieldState.value =
+ TextFieldValue("")
+ }
processIntent(GenerationMviIntent.Generate)
}
@@ -292,8 +320,11 @@ private fun ScreenContent(
},
enabled = isEnabled,
) {
- if (state.mode != ServerSource.LOCAL) {
- Icon(
+ when (state.mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit
+
+ else -> Icon(
modifier = Modifier.size(18.dp),
imageVector = Icons.Default.AutoFixNormal,
contentDescription = "Imagine",
@@ -303,7 +334,8 @@ private fun ScreenContent(
modifier = Modifier.padding(start = 8.dp),
text = stringResource(
id = when (state.mode) {
- ServerSource.LOCAL,
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE,
ServerSource.OPEN_AI -> LocalizationR.string.action_change_configuration
else -> LocalizationR.string.action_generate
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt
index 533b7b9b..3e301a66 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/img2img/ImageToImageState.kt
@@ -90,7 +90,7 @@ data class ImageToImageState(
heightValidationError: UiText?,
nsfw: Boolean,
batchCount: Int,
- generateButtonEnabled: Boolean
+ generateButtonEnabled: Boolean,
): GenerationMviState = copy(
onBoardingDemo = onBoardingDemo,
screenModal = screenModal,
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt
index 47e94240..98c45cbd 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/inpaint/components/InPaintComponent.kt
@@ -133,7 +133,7 @@ fun InPaintComponent(
}
MotionEvent.Move -> {
- currentPath.quadraticBezierTo(
+ currentPath.quadraticTo(
previousPosition.x,
previousPosition.y,
(previousPosition.x + currentPosition.x) / 2,
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt
index c97b3966..e3833534 100755
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/loader/ConfigurationLoaderViewModel.kt
@@ -11,6 +11,7 @@ import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter
import com.shifthackz.android.core.mvi.EmptyEffect
import com.shifthackz.android.core.mvi.EmptyIntent
import io.reactivex.rxjava3.kotlin.subscribeBy
+import java.util.concurrent.TimeUnit
import com.shifthackz.aisdv1.core.localization.R as LocalizationR
class ConfigurationLoaderViewModel(
@@ -28,6 +29,7 @@ class ConfigurationLoaderViewModel(
init {
!dataPreLoaderUseCase()
+ .timeout(15L, TimeUnit.SECONDS)
.doOnSubscribe {
updateState {
ConfigurationLoaderState.StatusNotification(
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt
index fbfb6e79..9138f169 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/onboarding/page/LocalDiffusionPageContent.kt
@@ -60,7 +60,7 @@ fun LocalDiffusionPageContent(
modifier = localModifier,
state = TextToImageState(
onBoardingDemo = true,
- mode = ServerSource.LOCAL,
+ mode = ServerSource.LOCAL_MICROSOFT_ONNX,
advancedToggleButtonVisible = false,
advancedOptionsVisible = true,
formPromptTaggedInput = true,
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt
index 7e50da66..e611a02e 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsScreen.kt
@@ -235,7 +235,8 @@ private fun ContentSettingsState(
ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face_short
ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai
ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai
- ServerSource.LOCAL -> LocalizationR.string.srv_type_local_short
+ ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local_short
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe_short
ServerSource.SWARM_UI -> LocalizationR.string.srv_type_swarm_ui
}.asUiText(),
onClick = { processIntent(SettingsIntent.NavigateConfiguration) },
@@ -256,7 +257,7 @@ private fun ContentSettingsState(
endValueText = state.sdModelSelected.asUiText(),
onClick = { processIntent(SettingsIntent.SdModel.OpenChooser) },
)
- if (state.showLocalUseNNAPI) {
+ if (state.showLocalMICROSOFTONNXUseNNAPI) {
SettingsItem(
modifier = itemModifier,
loading = state.loading,
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt
index aff0ad43..312ecd55 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsState.kt
@@ -37,8 +37,8 @@ data class SettingsState(
val showStabilityAiCredits: Boolean
get() = serverSource == ServerSource.STABILITY_AI
- val showLocalUseNNAPI: Boolean
- get() = serverSource == ServerSource.LOCAL
+ val showLocalMICROSOFTONNXUseNNAPI: Boolean
+ get() = serverSource == ServerSource.LOCAL_MICROSOFT_ONNX
val showSdModelSelector: Boolean
get() = serverSource == ServerSource.AUTOMATIC1111
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt
index a669acc0..d7756582 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModel.kt
@@ -25,6 +25,7 @@ import com.shifthackz.aisdv1.presentation.screen.debug.DebugMenuAccessor
import com.shifthackz.aisdv1.presentation.screen.drawer.DrawerIntent
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.kotlin.subscribeBy
+import java.util.concurrent.TimeUnit
import com.shifthackz.aisdv1.core.localization.R as LocalizationR
class SettingsViewModel(
@@ -48,6 +49,7 @@ class SettingsViewModel(
private val appVersionProducer = Flowable.fromCallable { buildInfoProvider.toString() }
private val sdModelsProducer = getStableDiffusionModelsUseCase()
+ .timeout(10L, TimeUnit.SECONDS)
.toFlowable()
.onErrorReturn { emptyList() }
@@ -142,7 +144,7 @@ class SettingsViewModel(
}
is SettingsIntent.UpdateFlag.NNAPI -> {
- preferenceManager.localUseNNAPI = intent.flag
+ preferenceManager.localOnnxUseNNAPI = intent.flag
}
is SettingsIntent.UpdateFlag.TaggedInput -> {
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt
index a83405f0..612b4c90 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreen.kt
@@ -153,7 +153,11 @@ fun ServerSetupScreenContent(
onClick = { processIntent(ServerSetupIntent.MainButtonClick) },
enabled = when (state.step) {
ServerSetupState.Step.CONFIGURE -> when (state.mode) {
- ServerSource.LOCAL -> state.localModels.any {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> state.localOnnxModels.any {
+ it.downloaded && it.selected
+ }
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> state.localMediaPipeModels.any {
it.downloaded && it.selected
}
@@ -168,7 +172,9 @@ fun ServerSetupScreenContent(
id = when (state.step) {
ServerSetupState.Step.SOURCE -> LocalizationR.string.next
else -> when (state.mode) {
- ServerSource.LOCAL -> LocalizationR.string.action_setup
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string
+ .action_setup
else -> LocalizationR.string.action_connect
}
},
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt
index 553e7bde..44a37715 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt
@@ -5,9 +5,11 @@ import com.shifthackz.aisdv1.core.common.links.LinksProvider
import com.shifthackz.aisdv1.core.model.UiText
import com.shifthackz.aisdv1.domain.entity.Configuration
import com.shifthackz.aisdv1.domain.entity.DownloadState
+import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials
import com.shifthackz.aisdv1.presentation.model.Modal
+import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState
import com.shifthackz.aisdv1.presentation.utils.Constants
import com.shifthackz.android.core.mvi.MviState
import org.koin.core.component.KoinComponent
@@ -33,9 +35,12 @@ data class ServerSetupState(
val password: String = "",
val huggingFaceModels: List = emptyList(),
val huggingFaceModel: String = "",
- val localModels: List = emptyList(),
- val localCustomModel: Boolean = false,
- val localCustomModelPath: String = "",
+ val localOnnxModels: List = emptyList(),
+ val localOnnxCustomModel: Boolean = false,
+ val localOnnxCustomModelPath: String = "",
+ val localMediaPipeModels: List = emptyList(),
+ val localMediaPipeCustomModel: Boolean = false,
+ val localMediaPipeCustomModelPath: String = "",
val passwordVisible: Boolean = false,
val serverUrlValidationError: UiText? = null,
val swarmUiUrlValidationError: UiText? = null,
@@ -45,9 +50,38 @@ data class ServerSetupState(
val huggingFaceApiKeyValidationError: UiText? = null,
val openAiApiKeyValidationError: UiText? = null,
val stabilityAiApiKeyValidationError: UiText? = null,
- val localCustomModelPathValidationError: UiText? = null,
+ val localCustomOnnxPathValidationError: UiText? = null,
+ val localCustomMediaPipePathValidationError: UiText? = null,
) : MviState, KoinComponent {
+ val localCustomModel: Boolean
+ get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) {
+ localOnnxCustomModel
+ } else {
+ localMediaPipeCustomModel
+ }
+
+ val localCustomModelPath: String
+ get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) {
+ localOnnxCustomModelPath
+ } else {
+ localMediaPipeCustomModelPath
+ }
+
+ val localModels: List
+ get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) {
+ localOnnxModels
+ } else {
+ localMediaPipeModels
+ }
+
+ val localCustomModelPathValidationError: UiText?
+ get() = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) {
+ localCustomOnnxPathValidationError
+ } else {
+ localCustomMediaPipePathValidationError
+ }
+
val demoModeUrl: String
get() {
val linksProvider: LinksProvider by inject()
@@ -60,13 +94,101 @@ data class ServerSetupState(
)
fun withCredentials(value: AuthorizationCredentials) = when (value) {
- is AuthorizationCredentials.HttpBasic -> this.copy(
+ is AuthorizationCredentials.HttpBasic -> copy(
login = value.login,
password = value.password,
)
+
AuthorizationCredentials.None -> this
}
+ fun withLocalCustomModelPath(value: String): ServerSetupState = when (mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> copy(
+ localOnnxCustomModelPath = value,
+ localCustomOnnxPathValidationError = null,
+ )
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy(
+ localMediaPipeCustomModelPath = value,
+ localCustomMediaPipePathValidationError = null,
+ )
+
+ else -> this
+ }
+
+ fun withUpdatedLocalModel(value: LocalModel): ServerSetupState = when (mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> copy(
+ localOnnxModels = localOnnxModels.withNewState(value)
+ )
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy(
+ localMediaPipeModels = localMediaPipeModels.withNewState(value)
+ )
+ else -> this
+ }
+
+ fun withDeletedLocalModel(value: LocalModel): ServerSetupState = when (mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> copy(
+ screenModal = Modal.None,
+ localOnnxModels = localOnnxModels.withNewState(
+ value.copy(
+ downloadState = DownloadState.Unknown,
+ downloaded = false,
+ ),
+ )
+ )
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy(
+ screenModal = Modal.None,
+ localMediaPipeModels = localMediaPipeModels.withNewState(
+ value.copy(
+ downloadState = DownloadState.Unknown,
+ downloaded = false,
+ ),
+ )
+ )
+
+ else -> copy(screenModal = Modal.None)
+ }
+
+ fun withSelectedLocalModel(value: LocalModel): ServerSetupState = when (mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> copy(
+ localOnnxModels = localOnnxModels.withNewState(
+ value.copy(selected = true),
+ ),
+ )
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> copy(
+ localMediaPipeModels = localMediaPipeModels.withNewState(
+ value.copy(selected = true),
+ ),
+ )
+
+ else -> this
+ }
+
+ fun withAllowCustomModel(value: Boolean): ServerSetupState {
+ fun List.updateCustomModelSelection(id: String) = withNewState(
+ find { m -> m.id == id }?.copy(selected = value)
+ )
+ return when (mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> this.copy(
+ localOnnxCustomModel = value,
+ localOnnxModels = localOnnxModels.updateCustomModelSelection(
+ id = LocalAiModel.CustomOnnx.id,
+ ),
+ )
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> this.copy(
+ localMediaPipeCustomModel = value,
+ localMediaPipeModels = localMediaPipeModels.updateCustomModelSelection(
+ id = LocalAiModel.CustomMediaPipe.id,
+ ),
+ )
+
+ else -> this
+ }
+ }
+
enum class Step {
SOURCE,
CONFIGURE;
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt
index 3677b6b3..d054a301 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt
@@ -1,7 +1,10 @@
package com.shifthackz.aisdv1.presentation.screen.setup
+import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
+import com.shifthackz.aisdv1.core.common.appbuild.BuildType
import com.shifthackz.aisdv1.core.common.log.errorLog
import com.shifthackz.aisdv1.core.common.schedulers.DispatchersProvider
+import com.shifthackz.aisdv1.core.common.model.Quadruple
import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread
import com.shifthackz.aisdv1.core.model.asUiText
@@ -11,7 +14,6 @@ import com.shifthackz.aisdv1.core.validation.url.UrlValidator
import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel
import com.shifthackz.aisdv1.domain.entity.DownloadState
import com.shifthackz.aisdv1.domain.entity.HuggingFaceModel
-import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.feature.auth.AuthorizationCredentials
import com.shifthackz.aisdv1.domain.interactor.settings.SetupConnectionInterActor
@@ -19,15 +21,17 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
-import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase
import com.shifthackz.aisdv1.presentation.model.LaunchSource
import com.shifthackz.aisdv1.presentation.model.Modal
import com.shifthackz.aisdv1.presentation.navigation.router.main.MainRouter
-import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomModelSwitchState
+import com.shifthackz.aisdv1.presentation.screen.setup.mappers.allowedModes
+import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomMediaPipeSwitchState
+import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapLocalCustomOnnxSwitchState
import com.shifthackz.aisdv1.presentation.screen.setup.mappers.mapToUi
-import com.shifthackz.aisdv1.presentation.screen.setup.mappers.withNewState
import com.shifthackz.aisdv1.presentation.utils.Constants
import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.disposables.Disposable
@@ -37,7 +41,8 @@ class ServerSetupViewModel(
launchSource: LaunchSource,
dispatchersProvider: DispatchersProvider,
getConfigurationUseCase: GetConfigurationUseCase,
- getLocalAiModelsUseCase: GetLocalAiModelsUseCase,
+ getLocalOnnxModelsUseCase: GetLocalOnnxModelsUseCase,
+ getLocalMediaPipeModelsUseCase: GetLocalMediaPipeModelsUseCase,
fetchAndGetHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase,
private val urlValidator: UrlValidator,
private val stringValidator: CommonStringValidator,
@@ -49,6 +54,7 @@ class ServerSetupViewModel(
private val preferenceManager: PreferenceManager,
private val wakeLockInterActor: WakeLockInterActor,
private val mainRouter: MainRouter,
+ private val buildInfoProvider: BuildInfoProvider,
) : MviRxViewModel() {
override val initialState = ServerSetupState(
@@ -74,12 +80,13 @@ class ServerSetupViewModel(
init {
!Single.zip(
getConfigurationUseCase(),
- getLocalAiModelsUseCase(),
+ getLocalOnnxModelsUseCase(),
+ getLocalMediaPipeModelsUseCase(),
fetchAndGetHuggingFaceModelsUseCase(),
- ::Triple,
+ ::Quadruple,
)
.subscribeOnMainThread(schedulersProvider)
- .subscribeBy(::errorLog) { (configuration, localModels, hfModels) ->
+ .subscribeBy(::errorLog) { (configuration, onnxModels, mpModels, hfModels) ->
updateState { state ->
state.copy(
huggingFaceModels = hfModels.map(HuggingFaceModel::alias),
@@ -87,10 +94,14 @@ class ServerSetupViewModel(
huggingFaceApiKey = configuration.huggingFaceApiKey,
openAiApiKey = configuration.openAiApiKey,
stabilityAiApiKey = configuration.stabilityAiApiKey,
- localModels = localModels.mapToUi(),
- localCustomModel = localModels.mapLocalCustomModelSwitchState(),
- localCustomModelPath = configuration.localModelPath,
+ localOnnxModels = onnxModels.mapToUi(),
+ localOnnxCustomModel = onnxModels.mapLocalCustomOnnxSwitchState(),
+ localOnnxCustomModelPath = configuration.localOnnxModelPath,
+ localMediaPipeModels = mpModels.mapToUi(),
+ localMediaPipeCustomModel = mpModels.mapLocalCustomMediaPipeSwitchState(),
+ localMediaPipeCustomModelPath = configuration.localMediaPipeModelPath,
mode = configuration.source,
+ allowedModes = buildInfoProvider.allowedModes,
demoMode = configuration.demoMode,
serverUrl = configuration.serverUrl,
swarmUiUrl = configuration.swarmUiUrl,
@@ -110,15 +121,8 @@ class ServerSetupViewModel(
}
override fun processIntent(intent: ServerSetupIntent) = when (intent) {
- is ServerSetupIntent.AllowLocalCustomModel -> updateState {
- it.copy(
- localCustomModel = intent.allow,
- localModels = currentState.localModels.withNewState(
- currentState.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }?.copy(
- selected = intent.allow,
- ),
- ),
- )
+ is ServerSetupIntent.AllowLocalCustomModel -> updateState { state ->
+ state.withAllowCustomModel(intent.allow)
}
ServerSetupIntent.DismissDialog -> setScreenModal(Modal.None)
@@ -129,28 +133,11 @@ class ServerSetupViewModel(
!deleteModelUseCase(intent.model.id)
.subscribeOnMainThread(schedulersProvider)
.subscribeBy(::errorLog)
- it.copy(
- screenModal = Modal.None,
- localModels = currentState.localModels.withNewState(
- intent.model.copy(
- downloadState = DownloadState.Unknown,
- downloaded = false,
- ),
- ),
- )
+ it.withDeletedLocalModel(intent.model)
}
- is ServerSetupIntent.SelectLocalModel -> {
- if (currentState.localModels.any { it.downloadState is DownloadState.Downloading }) {
- Unit
- }
- updateState {
- it.copy(
- localModels = currentState.localModels.withNewState(
- intent.model.copy(selected = true),
- ),
- )
- }
+ is ServerSetupIntent.SelectLocalModel -> updateState { state ->
+ state.withSelectedLocalModel(intent.model)
}
ServerSetupIntent.MainButtonClick -> when (currentState.step) {
@@ -236,11 +223,8 @@ class ServerSetupViewModel(
ServerSetupIntent.ConnectToLocalHost -> connectToServer()
- is ServerSetupIntent.SelectLocalModelPath -> updateState {
- it.copy(
- localCustomModelPath = intent.value,
- localCustomModelPathValidationError = null,
- )
+ is ServerSetupIntent.SelectLocalModelPath -> updateState { state ->
+ state.withLocalCustomModelPath(intent.value)
}
}
@@ -253,12 +237,13 @@ class ServerSetupViewModel(
emitEffect(ServerSetupEffect.HideKeyboard)
!when (currentState.mode) {
ServerSource.HORDE -> connectToHorde()
- ServerSource.LOCAL -> connectToLocalDiffusion()
+ ServerSource.LOCAL_MICROSOFT_ONNX -> connectToLocalDiffusion()
ServerSource.AUTOMATIC1111 -> connectToAutomaticInstance()
ServerSource.HUGGING_FACE -> connectToHuggingFace()
ServerSource.OPEN_AI -> connectToOpenAi()
ServerSource.STABILITY_AI -> connectToStabilityAi()
ServerSource.SWARM_UI -> connectToSwarmUi()
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> connectToMediaPipe()
}
.doOnSubscribe { setScreenModal(Modal.Communicating(canCancel = false)) }
.subscribeOnMainThread(schedulersProvider)
@@ -292,15 +277,27 @@ class ServerSetupViewModel(
}
}
- ServerSource.LOCAL -> {
- if (currentState.localCustomModel) {
- val validation = filePathValidator(currentState.localCustomModelPath)
+ ServerSource.LOCAL_MICROSOFT_ONNX -> if (currentState.localOnnxCustomModel) {
+ val validation = filePathValidator(currentState.localOnnxCustomModelPath)
+ updateState {
+ it.copy(localCustomOnnxPathValidationError = validation.mapToUi())
+ }
+ validation.isValid
+ } else {
+ currentState.localOnnxModels.find { it.selected && it.downloaded } != null
+ }
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> when {
+ buildInfoProvider.type == BuildType.FOSS -> false
+ currentState.localMediaPipeCustomModel -> {
+ val validation = filePathValidator(currentState.localMediaPipeCustomModelPath)
updateState {
- it.copy(localCustomModelPathValidationError = validation.mapToUi())
+ it.copy(localCustomMediaPipePathValidationError = validation.mapToUi())
}
validation.isValid
- } else {
- currentState.localModels.find { it.selected && it.downloaded } != null
+ }
+ else -> {
+ currentState.localMediaPipeModels.find { it.selected && it.downloaded } != null
}
}
@@ -405,87 +402,81 @@ class ServerSetupViewModel(
}
private fun connectToLocalDiffusion(): Single> {
- preferenceManager.localDiffusionCustomModelPath = currentState.localCustomModelPath
- val localModelId = currentState.localModels.find { it.selected }?.id ?: ""
+ preferenceManager.localOnnxCustomModelPath = currentState.localOnnxCustomModelPath
+ val localModelId = currentState.localOnnxModels.find { it.selected }?.id ?: ""
return setupConnectionInterActor.connectToLocal(localModelId)
}
- private fun localModelDownloadClickReducer(localModel: ServerSetupState.LocalModel) {
+ private fun connectToMediaPipe(): Single> {
+ preferenceManager.localMediaPipeCustomModelPath = currentState.localMediaPipeCustomModelPath
+ val localModelId = currentState.localMediaPipeModels.find { it.selected }?.id ?: ""
+ return setupConnectionInterActor.connectToMediaPipe(localModelId)
+ }
+
+ private fun localModelDownloadClickReducer(value: ServerSetupState.LocalModel) {
+ fun localModel(): ServerSetupState.LocalModel =
+ currentState.localModels.firstOrNull { it.id == value.id }
+ ?.let { value.copy(selected = it.selected) }
+ ?: value
+
when {
// User cancels download
- localModel.downloadState is DownloadState.Downloading -> {
- val index = downloadDisposables.indexOfFirst { it.first == localModel.id }
+ localModel().downloadState is DownloadState.Downloading -> {
+ val index = downloadDisposables.indexOfFirst { it.first == localModel().id }
if (index != -1) {
downloadDisposables[index].second.dispose()
downloadDisposables.removeAt(index)
}
- !deleteModelUseCase(localModel.id)
+ !deleteModelUseCase(localModel().id)
.subscribeOnMainThread(schedulersProvider)
.subscribeBy(::errorLog)
- updateState {
- it.copy(
- localModels = currentState.localModels.withNewState(
- localModel.copy(downloadState = DownloadState.Unknown),
- ),
+ updateState { state ->
+ state.withUpdatedLocalModel(
+ value = localModel().copy(downloadState = DownloadState.Unknown),
)
}
}
// User deletes local model
- localModel.downloaded -> updateState {
- it.copy(screenModal = Modal.DeleteLocalModelConfirm(localModel))
+ localModel().downloaded -> updateState {
+ it.copy(screenModal = Modal.DeleteLocalModelConfirm(localModel()))
}
// User requested new download operation
else -> {
- updateState {
- it.copy(
- localModels = currentState.localModels.withNewState(
- localModel.copy(
- downloadState = DownloadState.Downloading(),
- ),
- ),
+ updateState { state ->
+ state.withUpdatedLocalModel(
+ localModel().copy(downloadState = DownloadState.Downloading()),
)
}
- !downloadModelUseCase(localModel.id)
+ !downloadModelUseCase(localModel().id)
.distinctUntilChanged()
.doOnSubscribe { wakeLockInterActor.acquireWakelockUseCase() }
.doFinally { wakeLockInterActor.releaseWakeLockUseCase() }
- .subscribeOnMainThread(schedulersProvider).subscribeBy(
+ .subscribeOnMainThread(schedulersProvider)
+ .subscribeBy(
onError = { t ->
errorLog(t)
val message = t.localizedMessage ?: "Error"
- updateState {
- it.copy(
- localModels = currentState.localModels.withNewState(
- localModel.copy(
- downloadState = DownloadState.Error(t),
- ),
+ updateState { state ->
+ state.withUpdatedLocalModel(
+ localModel().copy(
+ downloadState = DownloadState.Error(t),
),
)
}
setScreenModal(Modal.Error(message.asUiText()))
},
onNext = { downloadState ->
- updateState {
- when (downloadState) {
- is DownloadState.Complete -> it.copy(
- localModels = it.localModels.withNewState(
- localModel.copy(
- downloadState = downloadState,
- downloaded = true,
- ),
- ),
- )
-
- else -> it.copy(
- localModels = it.localModels.withNewState(
- localModel.copy(downloadState = downloadState),
- ),
- )
- }
+ updateState { state ->
+ state.withUpdatedLocalModel(
+ localModel().copy(
+ downloadState = downloadState,
+ downloaded = downloadState is DownloadState.Complete
+ ),
+ )
}
},
)
- .also { downloadDisposables.add(localModel.id to it) }
+ .also { downloadDisposables.add(localModel().id to it) }
}
}
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt
index 19815b66..c84b0547 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/components/ConfigurationModeButton.kt
@@ -72,7 +72,8 @@ fun ConfigurationModeButton(
ServerSource.OPEN_AI,
ServerSource.STABILITY_AI,
ServerSource.HUGGING_FACE -> Icons.Default.Cloud
- ServerSource.LOCAL -> Icons.Default.Android
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Icons.Default.Android
else -> Icons.Default.QuestionMark
},
contentDescription = null,
@@ -91,9 +92,10 @@ fun ConfigurationModeButton(
ServerSource.HORDE -> LocalizationR.string.hint_server_horde_sub_title
ServerSource.HUGGING_FACE -> LocalizationR.string.hint_hugging_face_sub_title
ServerSource.OPEN_AI -> LocalizationR.string.hint_open_ai_sub_title
- ServerSource.LOCAL -> LocalizationR.string.hint_local_diffusion_sub_title
+ ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.hint_local_diffusion_sub_title
ServerSource.STABILITY_AI -> LocalizationR.string.hint_stability_ai_sub_title
ServerSource.SWARM_UI -> LocalizationR.string.hint_swarm_ui_sub_title
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.hint_mediapipe_sub_title
else -> null
}
descriptionId?.let { resId ->
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt
index 6ebfed22..4408ad5d 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt
@@ -54,6 +54,7 @@ import com.shifthackz.aisdv1.core.extensions.getRealPath
import com.shifthackz.aisdv1.core.model.asString
import com.shifthackz.aisdv1.domain.entity.DownloadState
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
+import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupScreenTags.CUSTOM_MODEL_SWITCH
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
@@ -91,7 +92,8 @@ fun LocalDiffusionForm(
val icon = when (model.downloadState) {
is DownloadState.Downloading -> Icons.Outlined.FileDownload
else -> when {
- model.id == LocalAiModel.CUSTOM.id -> Icons.Outlined.Landslide
+ model.id == LocalAiModel.CustomOnnx.id -> Icons.Outlined.Landslide
+ model.id == LocalAiModel.CustomMediaPipe.id -> Icons.Outlined.Landslide
model.downloaded -> Icons.Outlined.FileDownloadDone
else -> Icons.Outlined.FileDownloadOff
}
@@ -113,14 +115,20 @@ fun LocalDiffusionForm(
overflow = TextOverflow.Ellipsis,
maxLines = 2
)
- if (model.id != LocalAiModel.CUSTOM.id) {
- Text(
+ when (model.id) {
+ LocalAiModel.CustomOnnx.id,
+ LocalAiModel.CustomMediaPipe.id -> Unit
+
+ else -> Text(
text = model.size,
maxLines = 1
)
}
}
- if (model.id != LocalAiModel.CUSTOM.id) {
+ // Do not display action button for custom model
+ if (model.id != LocalAiModel.CustomOnnx.id
+ && model.id != LocalAiModel.CustomMediaPipe.id
+ ) {
Button(
modifier = Modifier.padding(end = 8.dp),
onClick = { processIntent(ServerSetupIntent.LocalModel.ClickReduce(model)) },
@@ -142,7 +150,9 @@ fun LocalDiffusionForm(
}
}
}
- if (model.id == LocalAiModel.CUSTOM.id) {
+ if (model.id == LocalAiModel.CustomOnnx.id
+ || model.id == LocalAiModel.CustomMediaPipe.id
+ ) {
Column(
modifier = Modifier.padding(8.dp),
) {
@@ -150,81 +160,83 @@ fun LocalDiffusionForm(
text = stringResource(id = LocalizationR.string.model_local_custom_title),
style = MaterialTheme.typography.bodyMedium,
)
- Spacer(modifier = Modifier.height(4.dp))
- Text(
- text = stringResource(id = LocalizationR.string.model_local_custom_sub_title),
- style = MaterialTheme.typography.bodyMedium,
- )
- Spacer(modifier = Modifier.height(4.dp))
+ if (model.id == LocalAiModel.CustomOnnx.id) {
+ Spacer(modifier = Modifier.height(4.dp))
+ Text(
+ text = stringResource(id = LocalizationR.string.model_local_custom_sub_title),
+ style = MaterialTheme.typography.bodyMedium,
+ )
+ Spacer(modifier = Modifier.height(4.dp))
- fun folderModifier(treeNum: Int) =
- Modifier.padding(start = (treeNum - 1) * 12.dp)
+ fun folderModifier(treeNum: Int) =
+ Modifier.padding(start = (treeNum - 1) * 12.dp)
- val folderStyle = MaterialTheme.typography.bodySmall
- Text(
- modifier = Modifier.padding(start = 12.dp),
- text = state.localCustomModelPath,
- style = folderStyle,
- )
+ val folderStyle = MaterialTheme.typography.bodySmall
+ Text(
+ modifier = Modifier.padding(start = 12.dp),
+ text = state.localOnnxCustomModelPath,
+ style = folderStyle,
+ )
- Text(
- modifier = folderModifier(3),
- text = "text_encoder",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(4),
- text = "model.ort",
- style = folderStyle,
- )
+ Text(
+ modifier = folderModifier(3),
+ text = "text_encoder",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(4),
+ text = "model.ort",
+ style = folderStyle,
+ )
- Text(
- modifier = folderModifier(3),
- text = "tokenizer",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(4),
- text = "merges.txt",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(3),
- text = "special_tokens_map.json",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(4),
- text = "tokenizer_config.json",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(4),
- text = "vocab.json",
- style = folderStyle,
- )
+ Text(
+ modifier = folderModifier(3),
+ text = "tokenizer",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(4),
+ text = "merges.txt",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(3),
+ text = "special_tokens_map.json",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(4),
+ text = "tokenizer_config.json",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(4),
+ text = "vocab.json",
+ style = folderStyle,
+ )
- Text(
- modifier = folderModifier(3),
- text = "unet",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(4),
- text = "model.ort",
- style = folderStyle,
- )
+ Text(
+ modifier = folderModifier(3),
+ text = "unet",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(4),
+ text = "model.ort",
+ style = folderStyle,
+ )
- Text(
- modifier = folderModifier(3),
- text = "vae_decoder",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(4),
- text = "model.ort",
- style = folderStyle,
- )
+ Text(
+ modifier = folderModifier(3),
+ text = "vae_decoder",
+ style = folderStyle,
+ )
+ Text(
+ modifier = folderModifier(4),
+ text = "model.ort",
+ style = folderStyle,
+ )
+ }
}
}
when (model.downloadState) {
@@ -258,17 +270,29 @@ fun LocalDiffusionForm(
modifier = Modifier
.fillMaxWidth()
.padding(top = 32.dp, bottom = 8.dp),
- text = stringResource(id = LocalizationR.string.hint_local_diffusion_title),
+ text = stringResource(
+ id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) {
+ LocalizationR.string.hint_local_diffusion_title
+ } else {
+ LocalizationR.string.hint_mediapipe_title
+ },
+ ),
style = MaterialTheme.typography.bodyLarge,
textAlign = TextAlign.Center,
fontWeight = FontWeight.Bold,
)
Text(
modifier = Modifier.padding(top = 16.dp, bottom = 16.dp),
- text = stringResource(id = LocalizationR.string.hint_local_diffusion_sub_title),
+ text = stringResource(
+ id = if (state.mode == ServerSource.LOCAL_MICROSOFT_ONNX) {
+ LocalizationR.string.hint_local_diffusion_sub_title
+ } else {
+ LocalizationR.string.hint_mediapipe_sub_title
+ },
+ ),
style = MaterialTheme.typography.bodyMedium,
)
- if (buildInfoProvider.type == BuildType.FOSS) {
+ if (buildInfoProvider.type != BuildType.PLAY) {
Row(
verticalAlignment = Alignment.CenterVertically,
) {
@@ -285,7 +309,7 @@ fun LocalDiffusionForm(
)
}
}
- if (state.localCustomModel && buildInfoProvider.type == BuildType.FOSS) {
+ if (state.localCustomModel && buildInfoProvider.type != BuildType.PLAY) {
Text(
modifier = Modifier
.align(Alignment.CenterHorizontally)
@@ -342,9 +366,12 @@ fun LocalDiffusionForm(
.fillMaxWidth()
.padding(top = 14.dp),
value = state.localCustomModelPath,
- onValueChange = { processIntent(ServerSetupIntent.SelectLocalModelPath(it)) },
+ onValueChange = { string ->
+ string.filter { it != '\n' }
+ .let(ServerSetupIntent::SelectLocalModelPath)
+ .let(processIntent::invoke)
+ },
enabled = true,
- singleLine = true,
label = { Text(stringResource(LocalizationR.string.model_local_path_title)) },
trailingIcon = {
IconButton(
@@ -387,7 +414,8 @@ fun LocalDiffusionForm(
}
state.localModels
.filter {
- val customPredicate = it.id == LocalAiModel.CUSTOM.id
+ val customPredicate =
+ it.id == LocalAiModel.CustomOnnx.id || it.id == LocalAiModel.CustomMediaPipe.id
if (state.localCustomModel) customPredicate else !customPredicate
}
.forEach { localModel -> modelItemUi(localModel) }
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt
new file mode 100644
index 00000000..3bb26b1d
--- /dev/null
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/MediaPipeForm.kt
@@ -0,0 +1,29 @@
+package com.shifthackz.aisdv1.presentation.screen.setup.forms
+
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Modifier
+import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
+import com.shifthackz.aisdv1.core.common.appbuild.BuildType
+import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent
+import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
+
+@Composable
+fun MediaPipeForm(
+ modifier: Modifier = Modifier,
+ state: ServerSetupState,
+ buildInfoProvider: BuildInfoProvider = BuildInfoProvider.stub,
+ processIntent: (ServerSetupIntent) -> Unit = {},
+) {
+ when (buildInfoProvider.type) {
+ BuildType.FOSS -> {
+
+ }
+
+ else -> LocalDiffusionForm(
+ modifier = modifier,
+ state = state,
+ buildInfoProvider = buildInfoProvider,
+ processIntent = processIntent,
+ )
+ }
+}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt
index f3a9ddc6..3cd1e443 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/LocalModelMappers.kt
@@ -5,8 +5,11 @@ import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
fun List.mapToUi(): List = map(LocalAiModel::mapToUi)
-fun List.mapLocalCustomModelSwitchState(): Boolean =
- find { it.selected && it.id == LocalAiModel.CUSTOM.id } != null
+fun List.mapLocalCustomOnnxSwitchState(): Boolean =
+ find { it.selected && it.id == LocalAiModel.CustomOnnx.id } != null
+
+fun List.mapLocalCustomMediaPipeSwitchState(): Boolean =
+ find { it.selected && it.id == LocalAiModel.CustomMediaPipe.id } != null
fun LocalAiModel.mapToUi(): ServerSetupState.LocalModel = with(this) {
ServerSetupState.LocalModel(
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt
new file mode 100644
index 00000000..1dd0b9c2
--- /dev/null
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ModesMapper.kt
@@ -0,0 +1,9 @@
+package com.shifthackz.aisdv1.presentation.screen.setup.mappers
+
+import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
+import com.shifthackz.aisdv1.domain.entity.ServerSource
+
+val BuildInfoProvider.allowedModes: List
+ get() = ServerSource
+ .entries
+ .filter { it.allowedInBuilds.contains(type) }
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt
index 4d53ca89..d25e8237 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/ConfigurationStep.kt
@@ -10,6 +10,7 @@ import com.shifthackz.aisdv1.presentation.screen.setup.forms.Automatic1111Form
import com.shifthackz.aisdv1.presentation.screen.setup.forms.HordeForm
import com.shifthackz.aisdv1.presentation.screen.setup.forms.HuggingFaceForm
import com.shifthackz.aisdv1.presentation.screen.setup.forms.LocalDiffusionForm
+import com.shifthackz.aisdv1.presentation.screen.setup.forms.MediaPipeForm
import com.shifthackz.aisdv1.presentation.screen.setup.forms.OpenAiForm
import com.shifthackz.aisdv1.presentation.screen.setup.forms.StabilityAiForm
import com.shifthackz.aisdv1.presentation.screen.setup.forms.SwarmUiForm
@@ -33,7 +34,7 @@ fun ConfigurationStep(
processIntent = processIntent,
)
- ServerSource.LOCAL -> LocalDiffusionForm(
+ ServerSource.LOCAL_MICROSOFT_ONNX -> LocalDiffusionForm(
state = state,
buildInfoProvider = buildInfoProvider,
processIntent = processIntent,
@@ -58,6 +59,12 @@ fun ConfigurationStep(
state = state,
processIntent = processIntent,
)
+
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> MediaPipeForm(
+ state = state,
+ buildInfoProvider = buildInfoProvider,
+ processIntent = processIntent,
+ )
}
}
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt
index 54c9f1b0..f7779458 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/steps/SourceSelectionStep.kt
@@ -5,16 +5,23 @@ import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
+import androidx.compose.foundation.lazy.LazyListItemInfo
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableIntStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.runtime.setValue
import androidx.compose.ui.Modifier
+import androidx.compose.ui.layout.onSizeChanged
import androidx.compose.ui.unit.dp
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
import com.shifthackz.aisdv1.presentation.screen.setup.components.ConfigurationModeButton
+import kotlin.math.abs
@Composable
fun SourceSelectionStep(
@@ -23,12 +30,26 @@ fun SourceSelectionStep(
processIntent: (ServerSetupIntent) -> Unit = {},
) {
val lazyListState = rememberLazyListState()
+ var lazyListHeight by remember { mutableIntStateOf(0) }
+ var lazyListItemHeight by remember { mutableIntStateOf(0) }
+
LaunchedEffect(state.mode) {
// Adding 1 here, because item with index == 0 is top spacer
- lazyListState.animateScrollToItem(state.mode.ordinal + 1)
+ val newIndex = state.mode.ordinal +1
+ val visibleIndexes = lazyListState.layoutInfo.visibleItemsInfo
+ .filter { it.offset >= 0 }
+ .filter {
+ if (lazyListHeight == 0 || lazyListItemHeight == 0) true
+ else abs(lazyListHeight - it.offset) >= lazyListItemHeight
+ }
+ .map(LazyListItemInfo::index)
+
+ if (!visibleIndexes.contains(newIndex)) lazyListState.animateScrollToItem(newIndex)
}
+
LazyColumn(
- modifier = modifier,
+ modifier = modifier
+ .onSizeChanged { lazyListHeight = it.height },
state = lazyListState,
) {
item(key = "SPACER_TOP") { Spacer(modifier = Modifier.height(12.dp)) }
@@ -39,7 +60,8 @@ fun SourceSelectionStep(
ConfigurationModeButton(
modifier = Modifier
.fillMaxWidth()
- .padding(horizontal = 16.dp, vertical = 4.dp),
+ .padding(horizontal = 16.dp, vertical = 4.dp)
+ .onSizeChanged { lazyListItemHeight = it.height },
state = state,
mode = mode,
onClick = {
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt
index 2dad0725..04466de6 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageState.kt
@@ -133,7 +133,7 @@ fun TextToImageState.mapToPayload(): TextToImagePayload = with(this) {
subSeedStrength = subSeedStrength,
sampler = selectedSampler,
nsfw = if (mode == ServerSource.HORDE) nsfw else false,
- batchCount = if (mode == ServerSource.LOCAL) 1 else batchCount,
+ batchCount = if (mode == ServerSource.LOCAL_MICROSOFT_ONNX) 1 else batchCount,
style = openAiStyle.key.takeIf {
mode == ServerSource.OPEN_AI && openAiModel == OpenAiModel.DALL_E_3
},
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt
index 515c3248..52f063ed 100755
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/txt2img/TextToImageViewModel.kt
@@ -8,6 +8,7 @@ import com.shifthackz.aisdv1.core.model.asUiText
import com.shifthackz.aisdv1.core.notification.PushNotificationManager
import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator
import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager
@@ -66,11 +67,13 @@ class TextToImageViewModel(
) {
private val progressModal: Modal
- get() {
- if (currentState.mode == ServerSource.LOCAL) {
- return Modal.Generating(canCancel = preferenceManager.localDiffusionAllowCancel)
+ get() = when (currentState.mode) {
+ ServerSource.LOCAL_MICROSOFT_ONNX,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> {
+ Modal.Generating(canCancel = preferenceManager.localOnnxAllowCancel)
}
- return Modal.Communicating()
+
+ else -> Modal.Communicating()
}
override val initialState = TextToImageState()
@@ -135,7 +138,7 @@ class TextToImageViewModel(
?.let(::setActiveModal)
}
- override fun onReceivedLocalDiffusionStatus(status: LocalDiffusion.Status) {
+ override fun onReceivedLocalDiffusionStatus(status: LocalDiffusionStatus) {
(currentState.screenModal as? Modal.Generating)
?.copy(status = status)
?.let(::setActiveModal)
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt
index ae6e5173..bf787e02 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionComponent.kt
@@ -54,7 +54,7 @@ fun EngineSelectionComponent(
onItemSelected = { intentHandler(EngineSelectionIntent(it)) },
)
- ServerSource.LOCAL -> DropdownTextField(
+ ServerSource.LOCAL_MICROSOFT_ONNX -> DropdownTextField(
label = LocalizationR.string.hint_sd_model.asUiText(),
loading = state.loading,
modifier = modifier,
@@ -64,6 +64,7 @@ fun EngineSelectionComponent(
displayDelegate = { it.name.asUiText() },
)
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit
ServerSource.HORDE -> Unit
ServerSource.OPEN_AI -> Unit
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt
index 294a7c6c..9bb71002 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModel.kt
@@ -11,7 +11,7 @@ import com.shifthackz.aisdv1.domain.entity.Configuration
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
-import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase
@@ -25,7 +25,7 @@ import io.reactivex.rxjava3.kotlin.subscribeBy
class EngineSelectionViewModel(
dispatchersProvider: DispatchersProvider,
fetchAndGetSwarmUiModelsUseCase: FetchAndGetSwarmUiModelsUseCase,
- observeLocalAiModelsUseCase: ObserveLocalAiModelsUseCase,
+ observeLocalOnnxModelsUseCase: ObserveLocalOnnxModelsUseCase,
fetchAndGetStabilityAiEnginesUseCase: FetchAndGetStabilityAiEnginesUseCase,
getHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase,
private val preferenceManager: PreferenceManager,
@@ -61,8 +61,8 @@ class EngineSelectionViewModel(
.onErrorReturn { emptyList() }
.toFlowable()
- val localAiModels = observeLocalAiModelsUseCase()
- .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CUSTOM.id } }
+ val localAiModels = observeLocalOnnxModelsUseCase()
+ .map { models -> models.filter { it.downloaded || it.id == LocalAiModel.CustomOnnx.id } }
.onErrorReturn { emptyList() }
!Flowable.combineLatest(
@@ -94,7 +94,7 @@ class EngineSelectionViewModel(
stEngines = stEngines.map { it.id },
selectedStEngine = config.stabilityAiEngineId,
localAiModels = localModels,
- selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localModelId }?.id
+ selectedLocalAiModelId = localModels.firstOrNull { it.id == config.localOnnxModelId }?.id
?: state.selectedLocalAiModelId
)
}
@@ -131,7 +131,7 @@ class EngineSelectionViewModel(
ServerSource.STABILITY_AI -> preferenceManager.stabilityAiEngineId = intent.value
- ServerSource.LOCAL -> preferenceManager.localModelId = intent.value
+ ServerSource.LOCAL_MICROSOFT_ONNX -> preferenceManager.localOnnxModelId = intent.value
else -> Unit
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt
index 1a8583d5..fe9b4914 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/GenerationInputForm.kt
@@ -148,7 +148,7 @@ fun GenerationInputForm(
ServerSource.SWARM_UI,
ServerSource.STABILITY_AI,
ServerSource.HUGGING_FACE,
- ServerSource.LOCAL -> EngineSelectionComponent(
+ ServerSource.LOCAL_MICROSOFT_ONNX -> EngineSelectionComponent(
modifier = Modifier
.fillMaxWidth()
.padding(top = 8.dp),
@@ -206,7 +206,7 @@ fun GenerationInputForm(
ServerSource.SWARM_UI,
ServerSource.HUGGING_FACE,
ServerSource.STABILITY_AI,
- ServerSource.LOCAL -> {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> {
if (state.formPromptTaggedInput) {
ChipTextFieldWithItem(
modifier = Modifier
@@ -256,7 +256,7 @@ fun GenerationInputForm(
when (state.mode) {
ServerSource.HORDE,
- ServerSource.LOCAL -> {
+ ServerSource.LOCAL_MICROSOFT_ONNX -> {
DropdownTextField(
modifier = localModifier.padding(end = 4.dp),
label = LocalizationR.string.width.asUiText(),
@@ -295,6 +295,7 @@ fun GenerationInputForm(
displayDelegate = { it.key.asUiText() },
)
}
+ else -> Unit
}
}
@@ -497,9 +498,11 @@ fun GenerationInputForm(
else -> Unit
}
+ //Steps not available for open ai
if (state.mode != ServerSource.OPEN_AI) {
val stepsMax = when (state.mode) {
- ServerSource.LOCAL -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX
+ ServerSource.LOCAL_MICROSOFT_ONNX -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> SAMPLING_STEPS_LOCAL_DIFFUSION_MAX
ServerSource.STABILITY_AI -> SAMPLING_STEPS_RANGE_STABILITY_AI_MAX
else -> SAMPLING_STEPS_RANGE_MAX
}
@@ -519,24 +522,31 @@ fun GenerationInputForm(
processIntent(GenerationMviIntent.Update.SamplingSteps(it.roundToInt()))
},
)
+ }
- Text(
- modifier = Modifier.padding(top = 8.dp),
- text = stringResource(
- LocalizationR.string.hint_cfg_scale,
- "${state.cfgScale.roundTo(2)}",
- ),
- )
- SliderTextInputField(
- value = state.cfgScale,
- valueRange = (CFG_SCALE_RANGE_MIN * 1f)..(CFG_SCALE_RANGE_MAX * 1f),
- valueDiff = 0.5f,
- steps = abs(CFG_SCALE_RANGE_MAX - CFG_SCALE_RANGE_MIN) * 2 - 1,
- sliderColors = sliderColors,
- onValueChange = {
- processIntent(GenerationMviIntent.Update.CfgScale(it))
- },
- )
+ // CFG scale not available on open ai and google media pipe
+ when (state.mode) {
+ ServerSource.OPEN_AI,
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> Unit
+ else -> {
+ Text(
+ modifier = Modifier.padding(top = 8.dp),
+ text = stringResource(
+ LocalizationR.string.hint_cfg_scale,
+ "${state.cfgScale.roundTo(2)}",
+ ),
+ )
+ SliderTextInputField(
+ value = state.cfgScale,
+ valueRange = (CFG_SCALE_RANGE_MIN * 1f)..(CFG_SCALE_RANGE_MAX * 1f),
+ valueDiff = 0.5f,
+ steps = abs(CFG_SCALE_RANGE_MAX - CFG_SCALE_RANGE_MIN) * 2 - 1,
+ sliderColors = sliderColors,
+ onValueChange = {
+ processIntent(GenerationMviIntent.Update.CfgScale(it))
+ },
+ )
+ }
}
when (state.mode) {
@@ -548,9 +558,10 @@ fun GenerationInputForm(
else -> Unit
}
- // Batch is not available for Local Diffusion
- if (state.mode != ServerSource.LOCAL) {
- batchComponent()
+ // Batch is not available for any Local
+ when (state.mode) {
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE, ServerSource.LOCAL_MICROSOFT_ONNX -> Unit
+ else -> batchComponent()
}
//Restore faces available only for A1111
if (state.mode == ServerSource.AUTOMATIC1111) {
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt
index 4175e70f..8e3964c9 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/input/SliderTextInputField.kt
@@ -94,9 +94,9 @@ fun SliderTextInputField(
enabled = true,
singleLine = true,
keyboardOptions = KeyboardOptions(
+ autoCorrectEnabled = false,
keyboardType = KeyboardType.Number,
- autoCorrect = false,
- imeAction = ImeAction.Done,
+ imeAction = ImeAction.Done
),
label = { Text(stringResource(id = R.string.hint_value)) },
trailingIcon = {
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt
index fdca7744..4ee15fd4 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/widget/source/ServerSourceLabel.kt
@@ -15,7 +15,8 @@ fun ServerSource.getName(): String {
fun ServerSource.getNameUiText(): UiText = when (this) {
ServerSource.AUTOMATIC1111 -> LocalizationR.string.srv_type_own
ServerSource.HORDE -> LocalizationR.string.srv_type_horde
- ServerSource.LOCAL -> LocalizationR.string.srv_type_local
+ ServerSource.LOCAL_MICROSOFT_ONNX -> LocalizationR.string.srv_type_local
+ ServerSource.LOCAL_GOOGLE_MEDIA_PIPE -> LocalizationR.string.srv_type_media_pipe
ServerSource.HUGGING_FACE -> LocalizationR.string.srv_type_hugging_face
ServerSource.OPEN_AI -> LocalizationR.string.srv_type_open_ai
ServerSource.STABILITY_AI -> LocalizationR.string.srv_type_stability_ai
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt
index e6c332c2..aa25ae1c 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/core/CoreGenerationMviViewModelTest.kt
@@ -4,8 +4,8 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
import com.shifthackz.aisdv1.core.notification.PushNotificationManager
import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator
import com.shifthackz.aisdv1.domain.entity.HordeProcessStatus
+import com.shifthackz.aisdv1.domain.entity.LocalDiffusionStatus
import com.shifthackz.aisdv1.domain.entity.Settings
-import com.shifthackz.aisdv1.domain.feature.diffusion.LocalDiffusion
import com.shifthackz.aisdv1.domain.feature.work.BackgroundTaskManager
import com.shifthackz.aisdv1.domain.feature.work.BackgroundWorkObserver
import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor
@@ -57,7 +57,7 @@ abstract class CoreGenerationMviViewModelTest()
private val stubHordeProcessStatus = BehaviorSubject.create()
- private val stubLdStatus = BehaviorSubject.create()
+ private val stubLdStatus = BehaviorSubject.create()
protected val stubCustomSchedulers = object : SchedulersProvider {
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt
index 8b9eb04a..62884de9 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/mocks/LocalAiModelMocks.kt
@@ -5,9 +5,10 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupState
val mockLocalAiModels = listOf(
- LocalAiModel.CUSTOM,
+ LocalAiModel.CustomOnnx,
LocalAiModel(
id = "1",
+ type = LocalAiModel.Type.ONNX,
name = "Model 1",
size = "5 Gb",
sources = listOf("https://example.com/1.html"),
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt
index 05e9a78f..d9453b19 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/settings/SettingsViewModelTest.kt
@@ -237,13 +237,13 @@ class SettingsViewModelTest : CoreViewModelTest() {
@Test
fun `given received UpdateFlag NNAPI intent, expected localUseNNAPI preference updated`() {
every {
- stubPreferenceManager::localUseNNAPI.set(any())
+ stubPreferenceManager::localOnnxUseNNAPI.set(any())
} returns Unit
viewModel.processIntent(SettingsIntent.UpdateFlag.NNAPI(true))
verify {
- stubPreferenceManager::localUseNNAPI.set(true)
+ stubPreferenceManager::localOnnxUseNNAPI.set(true)
}
}
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt
index c82a1dbc..27eb75b1 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupScreenTest.kt
@@ -107,7 +107,7 @@ class ServerSetupScreenTest : CoreComposeTest {
stubUiState.update {
it.copy(
step = ServerSetupState.Step.CONFIGURE,
- mode = ServerSource.LOCAL
+ mode = ServerSource.LOCAL_MICROSOFT_ONNX
)
}
@@ -125,8 +125,8 @@ class ServerSetupScreenTest : CoreComposeTest {
stubUiState.update {
it.copy(
step = ServerSetupState.Step.CONFIGURE,
- mode = ServerSource.LOCAL,
- localModels = mockLocalAiModels.mapToUi()
+ mode = ServerSource.LOCAL_MICROSOFT_ONNX,
+ localOnnxModels = mockLocalAiModels.mapToUi()
)
}
val setupButton = onNodeWithTestTag(ServerSetupScreenTags.MAIN_BUTTON)
@@ -143,9 +143,9 @@ class ServerSetupScreenTest : CoreComposeTest {
switch.performClick()
stubUiState.update {
it.copy(
- localCustomModel = true,
- localModels = it.localModels.withNewState(
- it.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }!!.copy(
+ localOnnxCustomModel = true,
+ localOnnxModels = it.localOnnxModels.withNewState(
+ it.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }!!.copy(
selected = true,
downloaded = true
),
@@ -165,9 +165,9 @@ class ServerSetupScreenTest : CoreComposeTest {
switch.performClick()
stubUiState.update {
it.copy(
- localCustomModel = false,
- localModels = it.localModels.withNewState(
- it.localModels.find { m -> m.id == LocalAiModel.CUSTOM.id }!!.copy(
+ localOnnxCustomModel = false,
+ localOnnxModels = it.localOnnxModels.withNewState(
+ it.localOnnxModels.find { m -> m.id == LocalAiModel.CustomOnnx.id }!!.copy(
selected = false,
),
),
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt
index a75bad5b..f6ae1a88 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt
@@ -1,5 +1,6 @@
package com.shifthackz.aisdv1.presentation.screen.setup
+import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator
import com.shifthackz.aisdv1.core.validation.path.FilePathValidator
import com.shifthackz.aisdv1.core.validation.url.UrlValidator
@@ -11,7 +12,8 @@ import com.shifthackz.aisdv1.domain.interactor.wakelock.WakeLockInterActor
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.domain.usecase.downloadable.DeleteModelUseCase
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
-import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalAiModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.settings.GetConfigurationUseCase
import com.shifthackz.aisdv1.presentation.core.CoreViewModelTest
@@ -39,7 +41,8 @@ import org.junit.Test
class ServerSetupViewModelTest : CoreViewModelTest() {
private val stubGetConfigurationUseCase = mockk()
- private val stubGetLocalAiModelsUseCase = mockk()
+ private val stubGetLocalOnnxModelsUseCase = mockk()
+ private val stubGetLocalMediaPipeModelsUseCase = mockk()
private val stubFetchAndGetHuggingFaceModelsUseCase = mockk()
private val stubUrlValidator = mockk()
private val stubCommonStringValidator = mockk()
@@ -55,7 +58,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
launchSource = LaunchSource.SETTINGS,
dispatchersProvider = stubDispatchersProvider,
getConfigurationUseCase = stubGetConfigurationUseCase,
- getLocalAiModelsUseCase = stubGetLocalAiModelsUseCase,
+ getLocalOnnxModelsUseCase = stubGetLocalOnnxModelsUseCase,
+ getLocalMediaPipeModelsUseCase = stubGetLocalMediaPipeModelsUseCase,
fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase,
urlValidator = stubUrlValidator,
stringValidator = stubCommonStringValidator,
@@ -67,6 +71,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
preferenceManager = stubPreferenceManager,
wakeLockInterActor = stubWakeLockInterActor,
mainRouter = stubMainRouter,
+ buildInfoProvider = BuildInfoProvider.stub,
)
@Before
@@ -78,9 +83,13 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
} returns Single.just(Configuration(serverUrl = "https://5598.is.my.favorite.com"))
every {
- stubGetLocalAiModelsUseCase()
+ stubGetLocalOnnxModelsUseCase()
} returns Single.just(mockLocalAiModels)
+ every {
+ stubGetLocalMediaPipeModelsUseCase()
+ } returns Single.just(emptyList())
+
every {
stubFetchAndGetHuggingFaceModelsUseCase()
} returns Single.just(mockHuggingFaceModels)
@@ -96,13 +105,14 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
fun `initialized, expected UI state updated with correct stub values`() {
val state = viewModel.state.value
Assert.assertEquals(true, state.huggingFaceModels.isNotEmpty())
- Assert.assertEquals(true, state.localModels.isNotEmpty())
+ Assert.assertEquals(true, state.localOnnxModels.isNotEmpty())
Assert.assertEquals("https://5598.is.my.favorite.com", state.serverUrl)
Assert.assertEquals(ServerSetupState.AuthType.ANONYMOUS, state.authType)
}
@Test
fun `given received AllowLocalCustomModel intent, expected Custom local model selected in UI state`() {
+ viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX))
viewModel.processIntent(ServerSetupIntent.AllowLocalCustomModel(true))
val state = viewModel.state.value
val expectedLocalModels = listOf(
@@ -121,8 +131,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
selected = false,
)
)
- Assert.assertEquals(true, state.localCustomModel)
- Assert.assertEquals(expectedLocalModels, state.localModels)
+ Assert.assertEquals(true, state.localOnnxCustomModel)
+ Assert.assertEquals(expectedLocalModels, state.localOnnxModels)
}
@Test
@@ -147,6 +157,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
stubWakeLockInterActor.releaseWakeLockUseCase()
} returns Result.success(Unit)
+ viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX))
val localModel = mockServerSetupStateLocalModel.copy(
downloadState = DownloadState.Unknown,
)
@@ -199,7 +210,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
val state = viewModel.state.value
val expected = false
- val actual = state.localModels.any {
+ val actual = state.localOnnxModels.any {
it.downloadState == DownloadState.Downloading(22)
}
Assert.assertEquals(expected, actual)
@@ -221,7 +232,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
runTest {
val state = viewModel.state.value
Assert.assertEquals(Modal.None, state.screenModal)
- Assert.assertEquals(false, state.localModels.find { it.id == "1" }!!.downloaded)
+ Assert.assertEquals(false, state.localOnnxModels.find { it.id == "1" }!!.downloaded)
}
verify {
stubDeleteModelUseCase("1")
@@ -230,9 +241,10 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
@Test
fun `given received SelectLocalModel intent, expected passed LocalModel is selected in UI state`() {
+ viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX))
viewModel.processIntent(ServerSetupIntent.SelectLocalModel(mockServerSetupStateLocalModel))
val state = viewModel.state.value
- Assert.assertEquals(true, state.localModels.find { it.id == "1" }!!.selected)
+ Assert.assertEquals(true, state.localOnnxModels.find { it.id == "1" }!!.selected)
}
@Test
@@ -317,8 +329,8 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
@Test
fun `given received UpdateServerMode intent, expected mode field in UI state is LOCAL`() {
- viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL))
- val expected = ServerSource.LOCAL
+ viewModel.processIntent(ServerSetupIntent.UpdateServerMode(ServerSource.LOCAL_MICROSOFT_ONNX))
+ val expected = ServerSource.LOCAL_MICROSOFT_ONNX
val actual = viewModel.state.value.mode
Assert.assertEquals(expected, actual)
}
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt
index 855a9128..269474df 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/widget/engine/EngineSelectionViewModelTest.kt
@@ -5,7 +5,7 @@ import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.domain.entity.ServerSource
import com.shifthackz.aisdv1.domain.entity.Settings
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
-import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalAiModelsUseCase
+import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.huggingface.FetchAndGetHuggingFaceModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.sdmodel.GetStableDiffusionModelsUseCase
import com.shifthackz.aisdv1.domain.usecase.sdmodel.SelectStableDiffusionModelUseCase
@@ -44,7 +44,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest
private val stubGetConfigurationUseCase = mockk()
private val stubSelectStableDiffusionModelUseCase = mockk()
private val stubGetStableDiffusionModelsUseCase = mockk()
- private val stubObserveLocalAiModelsUseCase = mockk()
+ private val stubObserveLocalAiModelsUseCase = mockk()
private val stubFetchAndGetStabilityAiEnginesUseCase = mockk()
private val stubFetchAndGetHuggingFaceModelsUseCase = mockk()
private val stubFetchAndGetSwarmUiModelsUseCase = mockk()
@@ -56,7 +56,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest
getConfigurationUseCase = stubGetConfigurationUseCase,
selectStableDiffusionModelUseCase = stubSelectStableDiffusionModelUseCase,
getStableDiffusionModelsUseCase = stubGetStableDiffusionModelsUseCase,
- observeLocalAiModelsUseCase = stubObserveLocalAiModelsUseCase,
+ observeLocalOnnxModelsUseCase = stubObserveLocalAiModelsUseCase,
fetchAndGetStabilityAiEnginesUseCase = stubFetchAndGetStabilityAiEnginesUseCase,
getHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase,
fetchAndGetSwarmUiModelsUseCase = stubFetchAndGetSwarmUiModelsUseCase,
@@ -108,7 +108,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest
selectedHfModel = "prompthero/openjourney-v4",
stEngines = listOf("5598"),
selectedStEngine = "5598",
- localAiModels = listOf(LocalAiModel.CUSTOM),
+ localAiModels = listOf(LocalAiModel.CustomOnnx),
selectedLocalAiModelId = "CUSTOM",
swarmModels = listOf("5598"),
selectedSwarmModel = "5598",
@@ -229,16 +229,16 @@ class EngineSelectionViewModelTest : CoreViewModelTest
@Test
fun `given received EngineSelectionIntent, source is LOCAL, expected localModelId changed in preference`() {
- mockInitialData(DataTestCase.Mock, ServerSource.LOCAL)
+ mockInitialData(DataTestCase.Mock, ServerSource.LOCAL_MICROSOFT_ONNX)
every {
- stubPreferenceManager::localModelId.set(any())
+ stubPreferenceManager::localOnnxModelId.set(any())
} returns Unit
viewModel.processIntent(EngineSelectionIntent("llm_5598"))
verify {
- stubPreferenceManager::localModelId.set("llm_5598")
+ stubPreferenceManager::localOnnxModelId.set("llm_5598")
}
}
@@ -262,7 +262,7 @@ class EngineSelectionViewModelTest : CoreViewModelTest
huggingFaceModel = "prompthero/openjourney-v4",
stabilityAiEngineId = "5598",
swarmUiModel = "5598",
- localModelId = "CUSTOM",
+ localOnnxModelId = "CUSTOM",
source = source,
),
)
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 6ee0480b..fca65518 100755
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -1,8 +1,8 @@
pluginManagement {
includeBuild("build-logic")
repositories {
- google()
mavenCentral()
+ google()
gradlePluginPortal()
maven {
url = uri("https://jitpack.io")
@@ -12,8 +12,8 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
- google()
mavenCentral()
+ google()
maven {
url = uri("https://jitpack.io")
}
@@ -35,6 +35,7 @@ val modules = listOf(
":domain",
":feature:auth",
":feature:diffusion",
+ ":feature:mediapipe",
":feature:work",
":network",
":presentation",
diff --git a/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json
new file mode 100644
index 00000000..7f7ac2a7
--- /dev/null
+++ b/storage/schemas/com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase/6.json
@@ -0,0 +1,254 @@
+{
+ "formatVersion": 1,
+ "database": {
+ "version": 6,
+ "identityHash": "6f6ccee56637122e0126c09bb3eb3fdc",
+ "entities": [
+ {
+ "tableName": "generation_results",
+ "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, `image_base_64` TEXT NOT NULL, `original_image_base_64` TEXT NOT NULL, `created_at` INTEGER NOT NULL, `generation_type` TEXT NOT NULL, `prompt` TEXT NOT NULL, `negative_prompt` TEXT NOT NULL, `width` INTEGER NOT NULL, `height` INTEGER NOT NULL, `sampling_steps` INTEGER NOT NULL, `cfg_scale` REAL NOT NULL, `restore_faces` INTEGER NOT NULL, `sampler` TEXT NOT NULL, `seed` TEXT NOT NULL, `sub_seed` TEXT NOT NULL DEFAULT '', `sub_seed_strength` REAL NOT NULL DEFAULT 0.0, `denoising_strength` REAL NOT NULL DEFAULT 0.0)",
+ "fields": [
+ {
+ "fieldPath": "id",
+ "columnName": "id",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "imageBase64",
+ "columnName": "image_base_64",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "originalImageBase64",
+ "columnName": "original_image_base_64",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "createdAt",
+ "columnName": "created_at",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "generationType",
+ "columnName": "generation_type",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "prompt",
+ "columnName": "prompt",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "negativePrompt",
+ "columnName": "negative_prompt",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "width",
+ "columnName": "width",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "height",
+ "columnName": "height",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "samplingSteps",
+ "columnName": "sampling_steps",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "cfgScale",
+ "columnName": "cfg_scale",
+ "affinity": "REAL",
+ "notNull": true
+ },
+ {
+ "fieldPath": "restoreFaces",
+ "columnName": "restore_faces",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "sampler",
+ "columnName": "sampler",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "seed",
+ "columnName": "seed",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "subSeed",
+ "columnName": "sub_seed",
+ "affinity": "TEXT",
+ "notNull": true,
+ "defaultValue": "''"
+ },
+ {
+ "fieldPath": "subSeedStrength",
+ "columnName": "sub_seed_strength",
+ "affinity": "REAL",
+ "notNull": true,
+ "defaultValue": "0.0"
+ },
+ {
+ "fieldPath": "denoisingStrength",
+ "columnName": "denoising_strength",
+ "affinity": "REAL",
+ "notNull": true,
+ "defaultValue": "0.0"
+ }
+ ],
+ "primaryKey": {
+ "autoGenerate": true,
+ "columnNames": [
+ "id"
+ ]
+ },
+ "indices": [],
+ "foreignKeys": []
+ },
+ {
+ "tableName": "local_models",
+ "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `type` TEXT NOT NULL DEFAULT 'onnx', `name` TEXT NOT NULL, `size` TEXT NOT NULL, `sources` TEXT NOT NULL, PRIMARY KEY(`id`))",
+ "fields": [
+ {
+ "fieldPath": "id",
+ "columnName": "id",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "type",
+ "columnName": "type",
+ "affinity": "TEXT",
+ "notNull": true,
+ "defaultValue": "'onnx'"
+ },
+ {
+ "fieldPath": "name",
+ "columnName": "name",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "size",
+ "columnName": "size",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "sources",
+ "columnName": "sources",
+ "affinity": "TEXT",
+ "notNull": true
+ }
+ ],
+ "primaryKey": {
+ "autoGenerate": false,
+ "columnNames": [
+ "id"
+ ]
+ },
+ "indices": [],
+ "foreignKeys": []
+ },
+ {
+ "tableName": "hugging_face_models",
+ "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` TEXT NOT NULL, `name` TEXT NOT NULL, `alias` TEXT NOT NULL, `source` TEXT NOT NULL, PRIMARY KEY(`id`))",
+ "fields": [
+ {
+ "fieldPath": "id",
+ "columnName": "id",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "name",
+ "columnName": "name",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "alias",
+ "columnName": "alias",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "source",
+ "columnName": "source",
+ "affinity": "TEXT",
+ "notNull": true
+ }
+ ],
+ "primaryKey": {
+ "autoGenerate": false,
+ "columnNames": [
+ "id"
+ ]
+ },
+ "indices": [],
+ "foreignKeys": []
+ },
+ {
+ "tableName": "supporters",
+ "createSql": "CREATE TABLE IF NOT EXISTS `${TABLE_NAME}` (`id` INTEGER NOT NULL, `name` TEXT NOT NULL, `date` INTEGER NOT NULL, `message` TEXT NOT NULL, PRIMARY KEY(`id`))",
+ "fields": [
+ {
+ "fieldPath": "id",
+ "columnName": "id",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "name",
+ "columnName": "name",
+ "affinity": "TEXT",
+ "notNull": true
+ },
+ {
+ "fieldPath": "date",
+ "columnName": "date",
+ "affinity": "INTEGER",
+ "notNull": true
+ },
+ {
+ "fieldPath": "message",
+ "columnName": "message",
+ "affinity": "TEXT",
+ "notNull": true
+ }
+ ],
+ "primaryKey": {
+ "autoGenerate": false,
+ "columnNames": [
+ "id"
+ ]
+ },
+ "indices": [],
+ "foreignKeys": []
+ }
+ ],
+ "views": [],
+ "setupQueries": [
+ "CREATE TABLE IF NOT EXISTS room_master_table (id INTEGER PRIMARY KEY,identity_hash TEXT)",
+ "INSERT OR REPLACE INTO room_master_table (id,identity_hash) VALUES(42, '6f6ccee56637122e0126c09bb3eb3fdc')"
+ ]
+ }
+}
\ No newline at end of file
diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt
index df8eada6..1a59785c 100644
--- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt
+++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/PersistentDatabase.kt
@@ -4,18 +4,11 @@ import androidx.room.AutoMigration
import androidx.room.Database
import androidx.room.RoomDatabase
import androidx.room.TypeConverters
-import com.shifthackz.aisdv1.storage.converters.DateConverters
-import com.shifthackz.aisdv1.storage.converters.ListConverters
+import com.shifthackz.aisdv1.storage.converters.*
import com.shifthackz.aisdv1.storage.db.persistent.PersistentDatabase.Companion.DB_VERSION
-import com.shifthackz.aisdv1.storage.db.persistent.contract.GenerationResultContract
-import com.shifthackz.aisdv1.storage.db.persistent.dao.GenerationResultDao
-import com.shifthackz.aisdv1.storage.db.persistent.dao.HuggingFaceModelDao
-import com.shifthackz.aisdv1.storage.db.persistent.dao.LocalModelDao
-import com.shifthackz.aisdv1.storage.db.persistent.dao.SupporterDao
-import com.shifthackz.aisdv1.storage.db.persistent.entity.GenerationResultEntity
-import com.shifthackz.aisdv1.storage.db.persistent.entity.HuggingFaceModelEntity
-import com.shifthackz.aisdv1.storage.db.persistent.entity.LocalModelEntity
-import com.shifthackz.aisdv1.storage.db.persistent.entity.SupporterEntity
+import com.shifthackz.aisdv1.storage.db.persistent.contract.*
+import com.shifthackz.aisdv1.storage.db.persistent.dao.*
+import com.shifthackz.aisdv1.storage.db.persistent.entity.*
@Database(
version = DB_VERSION,
@@ -46,6 +39,11 @@ import com.shifthackz.aisdv1.storage.db.persistent.entity.SupporterEntity
* Added [SupporterEntity].
*/
AutoMigration(from = 4, to = 5),
+ /**
+ * Added 1 field to [LocalModelEntity]:
+ * - [LocalModelContract.TYPE]
+ */
+ AutoMigration(from = 5, to = 6),
],
)
@TypeConverters(
@@ -60,6 +58,6 @@ internal abstract class PersistentDatabase : RoomDatabase() {
companion object {
const val DB_NAME = "ai_sd_v1_storage_db"
- const val DB_VERSION = 5
+ const val DB_VERSION = 6
}
}
diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt
index 03b3c03a..2ffebf43 100644
--- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt
+++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/contract/LocalModelContract.kt
@@ -4,6 +4,7 @@ object LocalModelContract {
const val TABLE = "local_models"
const val ID = "id"
+ const val TYPE = "type"
const val NAME = "name"
const val SIZE = "size"
const val SOURCES = "sources"
diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt
index 435d4c69..a2b94a16 100644
--- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt
+++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/dao/LocalModelDao.kt
@@ -16,9 +16,15 @@ interface LocalModelDao {
@Query("SELECT * FROM ${LocalModelContract.TABLE}")
fun query(): Single>
+ @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.TYPE} = :type")
+ fun queryByType(type: String): Single>
+
@Query("SELECT * FROM ${LocalModelContract.TABLE}")
fun observe(): Flowable>
+ @Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.TYPE} = :type")
+ fun observeByType(type: String): Flowable>
+
@Query("SELECT * FROM ${LocalModelContract.TABLE} WHERE ${LocalModelContract.ID} = :id LIMIT 1")
fun queryById(id: String): Single
diff --git a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt
index ab896641..d69856d2 100644
--- a/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt
+++ b/storage/src/main/java/com/shifthackz/aisdv1/storage/db/persistent/entity/LocalModelEntity.kt
@@ -10,6 +10,8 @@ data class LocalModelEntity(
@PrimaryKey(autoGenerate = false)
@ColumnInfo(name = LocalModelContract.ID)
val id: String,
+ @ColumnInfo(name = LocalModelContract.TYPE, defaultValue = "onnx")
+ val type: String,
@ColumnInfo(name = LocalModelContract.NAME)
val name: String,
@ColumnInfo(name = LocalModelContract.SIZE)