diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt
index 2cae5660..22ce7dbb 100644
--- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt
+++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt
@@ -1,5 +1,7 @@
package com.shifthackz.aisdv1.core.common.file
+const val LOCAL_DIFFUSION_CUSTOM_PATH = "/storage/emulated/0/Download/SDAI/model"
+
interface FileProviderDescriptor {
val providerPath: String
val imagesCacheDirPath: String
diff --git a/core/localization/src/main/res/values-ru/strings.xml b/core/localization/src/main/res/values-ru/strings.xml
index 726aaa23..a9dd642b 100644
--- a/core/localization/src/main/res/values-ru/strings.xml
+++ b/core/localization/src/main/res/values-ru/strings.xml
@@ -295,10 +295,14 @@
Нажмите, чтобы открыть приложение.
Загрузка кастом модели
+ Разрешения
Чтобы иметь возможность загружать пользовательскую модель, вам необходимо разрешить приложению SDAI управлять разрешениями на хранилище, поскольку, начиная с Android 11, оно необходимо для доступа к файлам хранилища без области действия.
Настроить доступ
+ Путь к модели
+ Путь к папке локальной модели
+ Выбрать папку
- Чтобы использовать локальную пользовательскую модель, поместите ее в локальную папку в памяти телефона: Download/SDAi/model
+ Чтобы использовать локальную пользовательскую модель, поместите ее в локальную папку в памяти телефона.
Окончательная структура папок должна быть такой:
Отладка
diff --git a/core/localization/src/main/res/values-tr/strings.xml b/core/localization/src/main/res/values-tr/strings.xml
index 1c4a97ae..815b72d9 100644
--- a/core/localization/src/main/res/values-tr/strings.xml
+++ b/core/localization/src/main/res/values-tr/strings.xml
@@ -298,8 +298,12 @@
Özel modeli yükleyebilmek için SDAI uygulamasının depolama izinlerini yönetmesine izin vermeniz gerekir; çünkü Android 11\'den itibaren kapsamlı olmayan depolama dosyalarına erişmek gerekir.
Kurulum izni
- Yerel özel modeli kullanmak için telefonunuzun depolama alanındaki yerel klasöre yerleştirin: Download/SDAi/model
+ Yerel özel modeli kullanmak için telefonunuzun depolama alanındaki yerel klasöre yerleştirin.
+ İzinler
Son klasör yapısı şu şekilde olmalıdır::
+ Model yolu
+ Yerel model klasör yolu
+ Klasör seç
Hata ayıklama
Work Manager API
diff --git a/core/localization/src/main/res/values-uk/strings.xml b/core/localization/src/main/res/values-uk/strings.xml
index 8b901cd4..1ff9adf6 100644
--- a/core/localization/src/main/res/values-uk/strings.xml
+++ b/core/localization/src/main/res/values-uk/strings.xml
@@ -298,8 +298,12 @@
Щоб мати можливість завантажити спеціальну модель, вам потрібно дозволити додатку SDAI керувати дозволами на зберігання, оскільки, починаючи з Android 11, це потрібно для доступу до файлів зберігання без обмежень.
Налаштувати доступ
- Щоб використовувати локальну спеціальну модель, помістіть її в локальну папку в пам’яті телефону: Download/SDAi/model
+ Щоб використовувати локальну спеціальну модель, помістіть її в локальну папку в пам’яті телефону.
+ Дозволи
Остаточна структура папок має бути такою:
+ Шлях моделі
+ Шлях папки локальної моделі
+ Виберіть папку
Відладка
Work Manager API
diff --git a/core/localization/src/main/res/values-zh/strings.xml b/core/localization/src/main/res/values-zh/strings.xml
index 0849958f..2199fbcd 100644
--- a/core/localization/src/main/res/values-zh/strings.xml
+++ b/core/localization/src/main/res/values-zh/strings.xml
@@ -358,8 +358,12 @@
加载自定义模型
为了能够加载自定义模型,您需要允许SDAI应用管理存储权限,因为从Android 11开始,它需要访问非范围存储文件。
设置权限
+ 模型路径
+ 本地模型文件夹路径
+ 选择文件夹
- 要使用本地自定义模型,请将其放置在手机存储中的本地文件夹:Download/SDAi/model
+ 要使用本地自定义模型,请将其放置在手机存储中的本地文件夹。
+ 权限
最终的文件夹结构应该是:
diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml
index bd1b8e31..f455448f 100755
--- a/core/localization/src/main/res/values/strings.xml
+++ b/core/localization/src/main/res/values/strings.xml
@@ -317,10 +317,14 @@
Local Diffusion
Load custom model
+ Permissions
To be able to load custom model, you need to allow SDAI app manage storage permissions, because starting from Android 11 it is needed to access non-scoped storage files.
Setup permission
+ Model path
+ Local model folder path
+ Select folder
- To use local custom model, place it to local folder in your phone storage: Download/SDAi/model
+ To use local custom model, place it to local folder in your phone storage.
The final folder structure should be:
Debugging
diff --git a/core/ui/src/main/java/com/shifthackz/aisdv1/core/extensions/RealPathExtensions.kt b/core/ui/src/main/java/com/shifthackz/aisdv1/core/extensions/RealPathExtensions.kt
new file mode 100644
index 00000000..395f73cd
--- /dev/null
+++ b/core/ui/src/main/java/com/shifthackz/aisdv1/core/extensions/RealPathExtensions.kt
@@ -0,0 +1,123 @@
+package com.shifthackz.aisdv1.core.extensions
+
+import android.content.ContentUris
+import android.content.Context
+import android.database.Cursor
+import android.net.Uri
+import android.os.Environment
+import android.provider.DocumentsContract
+import android.provider.MediaStore
+
+fun getRealPath(context: Context, uri: Uri): String? {
+ if (DocumentsContract.isDocumentUri(context, uri)) {
+ if (isExternalStorageDocument(uri)) {
+ val docId = DocumentsContract.getDocumentId(uri)
+ val split = docId.split(":".toRegex()).dropLastWhile { it.isEmpty() }
+ .toTypedArray()
+ val type = split[0]
+
+ if ("primary".equals(type, ignoreCase = true)) {
+ return Environment.getExternalStorageDirectory().toString() + "/" + split[1]
+ } else {
+ return "/storage/$type/${split[1]}"
+ }
+
+ } else if (isDownloadsDocument(uri)) {
+ val id = DocumentsContract.getDocumentId(uri)
+ val contentUri = ContentUris.withAppendedId(
+ Uri.parse("content://downloads/public_downloads"), id.toLong()
+ )
+
+ return getDataColumn(context, contentUri, null, null)
+ } else if (isMediaDocument(uri)) {
+ val docId = DocumentsContract.getDocumentId(uri)
+ val split = docId.split(":".toRegex()).dropLastWhile { it.isEmpty() }
+ .toTypedArray()
+ val type = split[0]
+
+ var contentUri: Uri? = null
+ if ("image" == type) {
+ contentUri = MediaStore.Images.Media.EXTERNAL_CONTENT_URI
+ } else if ("video" == type) {
+ contentUri = MediaStore.Video.Media.EXTERNAL_CONTENT_URI
+ } else if ("audio" == type) {
+ contentUri = MediaStore.Audio.Media.EXTERNAL_CONTENT_URI
+ }
+
+ val selection = "_id=?"
+ val selectionArgs = arrayOf(
+ split[1]
+ )
+
+ return getDataColumn(context, contentUri, selection, selectionArgs)
+ }
+ } else if ("content".equals(uri.scheme, ignoreCase = true)) {
+ // Return the remote address
+
+ if (isGooglePhotosUri(uri)) return uri.lastPathSegment
+
+ return getDataColumn(context, uri, null, null)
+ } else if ("file".equals(uri.scheme, ignoreCase = true)) {
+ return uri.path
+ }
+
+ return null
+}
+
+fun getDataColumn(
+ context: Context, uri: Uri?, selection: String?,
+ selectionArgs: Array?
+): String? {
+ var cursor: Cursor? = null
+ val column = "_data"
+ val projection = arrayOf(
+ column
+ )
+
+ try {
+ cursor = context.contentResolver.query(
+ uri!!, projection, selection, selectionArgs,
+ null
+ )
+ if (cursor != null && cursor.moveToFirst()) {
+ val index = cursor.getColumnIndexOrThrow(column)
+ return cursor.getString(index)
+ }
+ } finally {
+ cursor?.close()
+ }
+ return null
+}
+
+
+/**
+ * @param uri The Uri to check.
+ * @return Whether the Uri authority is ExternalStorageProvider.
+ */
+fun isExternalStorageDocument(uri: Uri): Boolean {
+ return "com.android.externalstorage.documents" == uri.authority
+}
+
+/**
+ * @param uri The Uri to check.
+ * @return Whether the Uri authority is DownloadsProvider.
+ */
+fun isDownloadsDocument(uri: Uri): Boolean {
+ return "com.android.providers.downloads.documents" == uri.authority
+}
+
+/**
+ * @param uri The Uri to check.
+ * @return Whether the Uri authority is MediaProvider.
+ */
+fun isMediaDocument(uri: Uri): Boolean {
+ return "com.android.providers.media.documents" == uri.authority
+}
+
+/**
+ * @param uri The Uri to check.
+ * @return Whether the Uri authority is Google Photos.
+ */
+fun isGooglePhotosUri(uri: Uri): Boolean {
+ return "com.google.android.apps.photos.content" == uri.authority
+}
diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt
index 7c0d819c..81778ad9 100644
--- a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt
+++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt
@@ -4,6 +4,8 @@ import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator
import com.shifthackz.aisdv1.core.validation.common.CommonStringValidatorImpl
import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator
import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidatorImpl
+import com.shifthackz.aisdv1.core.validation.path.FilePathValidator
+import com.shifthackz.aisdv1.core.validation.path.FilePathValidatorImpl
import com.shifthackz.aisdv1.core.validation.url.UrlValidator
import com.shifthackz.aisdv1.core.validation.url.UrlValidatorImpl
import org.koin.core.module.dsl.factoryOf
@@ -16,4 +18,5 @@ val validatorsModule = module {
factory { UrlValidatorImpl() }
factoryOf(::CommonStringValidatorImpl) bind CommonStringValidator::class
+ factoryOf(::FilePathValidatorImpl) bind FilePathValidator::class
}
diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidator.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidator.kt
new file mode 100644
index 00000000..6d09d4e3
--- /dev/null
+++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidator.kt
@@ -0,0 +1,13 @@
+package com.shifthackz.aisdv1.core.validation.path
+
+import com.shifthackz.aisdv1.core.validation.ValidationResult
+
+interface FilePathValidator {
+
+ operator fun invoke(input: String?): ValidationResult
+
+ sealed interface Error {
+ data object Empty : Error
+ data object Invalid : Error
+ }
+}
diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImpl.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImpl.kt
new file mode 100644
index 00000000..ef3c8e96
--- /dev/null
+++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImpl.kt
@@ -0,0 +1,23 @@
+package com.shifthackz.aisdv1.core.validation.path
+
+import com.shifthackz.aisdv1.core.validation.ValidationResult
+
+class FilePathValidatorImpl : FilePathValidator {
+
+ override fun invoke(input: String?): ValidationResult = when {
+ input.isNullOrBlank() -> ValidationResult(
+ isValid = false,
+ validationError = FilePathValidator.Error.Empty,
+ )
+ !isValidFilePath(input) -> ValidationResult(
+ isValid = false,
+ validationError = FilePathValidator.Error.Invalid,
+ )
+ else -> ValidationResult(true)
+ }
+
+ private fun isValidFilePath(path: String): Boolean {
+ val regex = Regex("^(/[^/<>:\"|?*]+)+/?$")
+ return regex.matches(path)
+ }
+}
diff --git a/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImplTest.kt b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImplTest.kt
new file mode 100644
index 00000000..e33c2012
--- /dev/null
+++ b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImplTest.kt
@@ -0,0 +1,57 @@
+package com.shifthackz.aisdv1.core.validation.path
+
+import com.shifthackz.aisdv1.core.validation.ValidationResult
+import org.junit.Assert
+import org.junit.Test
+
+class FilePathValidatorImplTest {
+
+ private val validator = FilePathValidatorImpl()
+
+ @Test
+ fun `iven input is null, expected not valid with Empty error`() {
+ val expected = ValidationResult(
+ isValid = false,
+ validationError = FilePathValidator.Error.Empty,
+ )
+ val actual = validator(null)
+ Assert.assertEquals(expected, actual)
+ }
+
+ @Test
+ fun `given input is empty, expected not valid with Empty error`() {
+ val expected = ValidationResult(
+ isValid = false,
+ validationError = FilePathValidator.Error.Empty,
+ )
+ val actual = validator("")
+ Assert.assertEquals(expected, actual)
+ }
+
+ @Test
+ fun `given input is blank, expected not valid with Empty error`() {
+ val expected = ValidationResult(
+ isValid = false,
+ validationError = FilePathValidator.Error.Empty,
+ )
+ val actual = validator(" ")
+ Assert.assertEquals(expected, actual)
+ }
+
+ @Test
+ fun `given input is not valid, expected not valid with Invalid error`() {
+ val expected = ValidationResult(
+ isValid = false,
+ validationError = FilePathValidator.Error.Invalid,
+ )
+ val actual = validator("cc")
+ Assert.assertEquals(expected, actual)
+ }
+
+ @Test
+ fun `given input is valid, expected valid`() {
+ val expected = ValidationResult(true)
+ val actual = validator("/tmp/local/5598")
+ Assert.assertEquals(expected, actual)
+ }
+}
diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt
index 1a41e3a5..1fffb4d4 100644
--- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt
+++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt
@@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.data.preference
import android.content.SharedPreferences
import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes
import com.shifthackz.aisdv1.core.common.extensions.shouldUseNewMediaStore
+import com.shifthackz.aisdv1.core.common.file.LOCAL_DIFFUSION_CUSTOM_PATH
import com.shifthackz.aisdv1.core.common.schedulers.SchedulersToken
import com.shifthackz.aisdv1.domain.entity.ColorToken
import com.shifthackz.aisdv1.domain.entity.DarkThemeToken
@@ -59,6 +60,15 @@ class PreferenceManagerImpl(
.apply()
.also { onPreferencesChanged() }
+ override var localDiffusionCustomModelPath: String
+ get() = preferences.getString(
+ KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH,
+ LOCAL_DIFFUSION_CUSTOM_PATH,
+ ) ?: LOCAL_DIFFUSION_CUSTOM_PATH
+ set(value) = preferences.edit()
+ .putString(KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, value)
+ .apply()
+
override var localDiffusionAllowCancel: Boolean
get() = preferences.getBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, false)
set(value) = preferences.edit()
@@ -285,6 +295,7 @@ class PreferenceManagerImpl(
const val KEY_SWARM_MODEL = "key_swarm_model"
const val KEY_DEMO_MODE = "key_demo_mode"
const val KEY_DEVELOPER_MODE = "key_developer_mode"
+ const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path"
const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel"
const val KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD = "key_local_diffusion_scheduler_thread"
const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity"
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt
index e644c6f8..fcf04c19 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt
@@ -16,4 +16,5 @@ data class Configuration(
val stabilityAiEngineId: String = "",
val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None,
val localModelId: String = "",
+ val localModelPath: String = "",
)
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt
index 2ea33eef..50e0cda5 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt
@@ -12,6 +12,7 @@ interface PreferenceManager {
var swarmUiModel: String
var demoMode: Boolean
var developerMode: Boolean
+ var localDiffusionCustomModelPath: String
var localDiffusionAllowCancel: Boolean
var localDiffusionSchedulerThread: SchedulersToken
var monitorConnectivity: Boolean
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt
index a46f4bd7..4070a0a6 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt
@@ -25,6 +25,7 @@ internal class GetConfigurationUseCaseImpl(
stabilityAiEngineId = preferenceManager.stabilityAiEngineId,
authCredentials = authorizationStore.getAuthorizationCredentials(),
localModelId = preferenceManager.localModelId,
+ localModelPath = preferenceManager.localDiffusionCustomModelPath,
)
)
}
diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt
index 83ceab34..8d03619c 100644
--- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt
+++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt
@@ -25,5 +25,6 @@ internal class SetServerConfigurationUseCaseImpl(
preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey
preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId
preferenceManager.localModelId = configuration.localModelId
+ preferenceManager.localDiffusionCustomModelPath = configuration.localModelPath
}
}
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt
index 114e97b4..73503575 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt
@@ -13,4 +13,5 @@ val mockConfiguration = Configuration(
stabilityAiApiKey = "5598",
stabilityAiEngineId = "5598",
localModelId = "5598",
+ localModelPath = "/storage/emulated/0/5598",
)
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt
index 9ce3e0af..55290a49 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt
@@ -72,6 +72,10 @@ class GetConfigurationUseCaseImplTest {
stubPreferenceManager::localModelId.get()
} returns mockConfiguration.localModelId
+ every {
+ stubPreferenceManager::localDiffusionCustomModelPath.get()
+ } returns mockConfiguration.localModelPath
+
useCase
.invoke()
.test()
diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt
index e41c7ab7..62a74c34 100644
--- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt
+++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt
@@ -71,6 +71,10 @@ class SetServerConfigurationUseCaseImplTest {
stubPreferenceManager::localModelId.set(any())
} returns Unit
+ every {
+ stubPreferenceManager::localDiffusionCustomModelPath.set(any())
+ } returns Unit
+
useCase
.invoke(mockConfiguration)
.test()
diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt
index e539317b..0683aeff 100644
--- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt
+++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt
@@ -10,6 +10,7 @@ import android.util.Pair
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
import com.shifthackz.aisdv1.core.common.log.debugLog
import com.shifthackz.aisdv1.core.common.log.errorLog
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_INPUT_IDS
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT
@@ -32,6 +33,7 @@ internal class EnglishTextTokenizer(
private val ortEnvironmentProvider: OrtEnvironmentProvider,
private val fileProviderDescriptor: FileProviderDescriptor,
private val localModelIdProvider: LocalModelIdProvider,
+ private val preferenceManager: PreferenceManager,
) : LocalDiffusionTextTokenizer {
private val pattern = Pattern.compile(TOKENIZER_REGEX)
@@ -53,7 +55,7 @@ internal class EnglishTextTokenizer(
val options = OrtSession.SessionOptions()
options.addConfigEntry(ORT_KEY_MODEL_FORMAT, ORT)
session = ortEnvironmentProvider.get().createSession(
- "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}",
+ "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}",
options
)
debugLog("{$TAG} {TOKENIZER} {initialize} Session created successfully!")
@@ -229,7 +231,7 @@ internal class EnglishTextTokenizer(
private fun loadEncoder(): Map {
val map: MutableMap = HashMap()
try {
- val path = "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}"
+ val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}"
val jsonReader = JsonReader(InputStreamReader(FileInputStream(path)))
jsonReader.beginObject()
while (jsonReader.hasNext()) {
@@ -255,7 +257,7 @@ internal class EnglishTextTokenizer(
private fun loadBpeRanks(): Map, Int?> {
val result: MutableMap, Int?> = HashMap()
try {
- val path = "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}"
+ val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}"
val reader = BufferedReader(InputStreamReader(FileInputStream(path)))
var line: String
var startLine = 1
diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt
index 96d427f2..85cc131f 100644
--- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt
+++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt
@@ -10,6 +10,7 @@ import android.graphics.Bitmap
import android.util.Pair
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
import com.shifthackz.aisdv1.core.common.log.debugLog
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_ENCODER_HIDDEN_STATES
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_LATENT_SAMPLE
@@ -43,6 +44,7 @@ internal class UNet(
private val ortEnvironmentProvider: OrtEnvironmentProvider,
private val fileProviderDescriptor: FileProviderDescriptor,
private val localModelIdProvider: LocalModelIdProvider,
+ private val preferenceManager: PreferenceManager,
) {
private var decoder: VaeDecoder? = null
@@ -61,6 +63,7 @@ internal class UNet(
ortEnvironmentProvider,
fileProviderDescriptor,
localModelIdProvider,
+ preferenceManager,
deviceNNAPIFlagProvider.get(),
)
val options = SessionOptions()
@@ -69,7 +72,7 @@ internal class UNet(
options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED))
}
session = ortEnvironmentProvider.get().createSession(
- "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}",
+ "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}",
options
)
}
diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt
index 3a15a436..5dea5b55 100644
--- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt
+++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt
@@ -7,13 +7,14 @@ import ai.onnxruntime.providers.NNAPIFlags
import android.graphics.Bitmap
import android.graphics.Color
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT
import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT
import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D
-import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider
import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag
import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider
+import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider
import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix
import java.util.EnumSet
import kotlin.math.roundToInt
@@ -22,6 +23,7 @@ internal class VaeDecoder(
private val ortEnvironmentProvider: OrtEnvironmentProvider,
private val fileProviderDescriptor: FileProviderDescriptor,
private val localModelIdProvider: LocalModelIdProvider,
+ private val preferenceManager: PreferenceManager,
private val deviceId: Int,
) {
@@ -67,7 +69,7 @@ internal class VaeDecoder(
options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED))
}
session = ortEnvironmentProvider.get().createSession(
- "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}",
+ "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}",
options
)
}
diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt
index d4faaddd..7cb6df87 100644
--- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt
+++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt
@@ -2,17 +2,17 @@ package com.shifthackz.aisdv1.feature.diffusion.extensions
import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
+import com.shifthackz.aisdv1.domain.preference.PreferenceManager
import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider
-private const val PATH = "/storage/emulated/0/Download/SDAI/model"
-
fun modelPathPrefix(
+ preferenceManager: PreferenceManager,
fileProviderDescriptor: FileProviderDescriptor,
localModelIdProvider: LocalModelIdProvider,
): String {
- val modelId = localModelIdProvider.get();
+ val modelId = localModelIdProvider.get()
return if (modelId == LocalAiModel.CUSTOM.id) {
- PATH
+ preferenceManager.localDiffusionCustomModelPath
} else {
"${fileProviderDescriptor.localModelDirPath}/${modelId}"
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt
index a9a1aac4..a22190ec 100755
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt
@@ -62,6 +62,7 @@ val viewModelModule = module {
fetchAndGetHuggingFaceModelsUseCase = get(),
urlValidator = get(),
stringValidator = get(),
+ filePathValidator = get(),
setupConnectionInterActor = get(),
downloadModelUseCase = get(),
deleteModelUseCase = get(),
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt
index 80d51618..425f9ef6 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt
@@ -36,6 +36,8 @@ sealed interface ServerSetupIntent : MviIntent {
data class UpdateHordeDefaultApiKey(val value: Boolean) : ServerSetupIntent
+ data class SelectLocalModelPath(val value: String) : ServerSetupIntent
+
data class SelectLocalModel(val model: ServerSetupState.LocalModel) : ServerSetupIntent
data class AllowLocalCustomModel(val allow: Boolean) : ServerSetupIntent
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt
index 555a4d32..7e91bbf6 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt
@@ -35,6 +35,7 @@ data class ServerSetupState(
val huggingFaceModel: String = "",
val localModels: List = emptyList(),
val localCustomModel: Boolean = false,
+ val localCustomModelPath: String = "",
val passwordVisible: Boolean = false,
val serverUrlValidationError: UiText? = null,
val swarmUiUrlValidationError: UiText? = null,
@@ -44,6 +45,7 @@ data class ServerSetupState(
val huggingFaceApiKeyValidationError: UiText? = null,
val openAiApiKeyValidationError: UiText? = null,
val stabilityAiApiKeyValidationError: UiText? = null,
+ val localCustomModelPathValidationError: UiText? = null,
) : MviState, KoinComponent {
val demoModeUrl: String
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt
index f0ce3e03..d02f1287 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt
@@ -5,6 +5,7 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider
import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread
import com.shifthackz.aisdv1.core.model.asUiText
import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator
+import com.shifthackz.aisdv1.core.validation.path.FilePathValidator
import com.shifthackz.aisdv1.core.validation.url.UrlValidator
import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel
import com.shifthackz.aisdv1.domain.entity.DownloadState
@@ -37,6 +38,7 @@ class ServerSetupViewModel(
fetchAndGetHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase,
private val urlValidator: UrlValidator,
private val stringValidator: CommonStringValidator,
+ private val filePathValidator: FilePathValidator,
private val setupConnectionInterActor: SetupConnectionInterActor,
private val downloadModelUseCase: DownloadModelUseCase,
private val deleteModelUseCase: DeleteModelUseCase,
@@ -82,6 +84,7 @@ class ServerSetupViewModel(
stabilityAiApiKey = configuration.stabilityAiApiKey,
localModels = localModels.mapToUi(),
localCustomModel = localModels.mapLocalCustomModelSwitchState(),
+ localCustomModelPath = configuration.localModelPath,
mode = configuration.source,
demoMode = configuration.demoMode,
serverUrl = configuration.serverUrl,
@@ -227,6 +230,13 @@ class ServerSetupViewModel(
}
ServerSetupIntent.ConnectToLocalHost -> connectToServer()
+
+ is ServerSetupIntent.SelectLocalModelPath -> updateState {
+ it.copy(
+ localCustomModelPath = intent.value,
+ localCustomModelPathValidationError = null,
+ )
+ }
}
private fun validateAndConnectToServer() {
@@ -278,7 +288,15 @@ class ServerSetupViewModel(
}
ServerSource.LOCAL -> {
- currentState.localModels.find { it.selected && it.downloaded } != null
+ if (currentState.localCustomModel) {
+ val validation = filePathValidator(currentState.localCustomModelPath)
+ updateState {
+ it.copy(localCustomModelPathValidationError = validation.mapToUi())
+ }
+ validation.isValid
+ } else {
+ currentState.localModels.find { it.selected && it.downloaded } != null
+ }
}
ServerSource.HUGGING_FACE -> {
@@ -382,6 +400,7 @@ class ServerSetupViewModel(
}
private fun connectToLocalDiffusion(): Single> {
+ preferenceManager.localDiffusionCustomModelPath = currentState.localCustomModelPath
val localModelId = currentState.localModels.find { it.selected }?.id ?: ""
return setupConnectionInterActor.connectToLocal(localModelId)
}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt
index df5819ce..6ebfed22 100644
--- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt
@@ -1,5 +1,9 @@
package com.shifthackz.aisdv1.presentation.screen.setup.forms
+import android.content.Intent
+import android.provider.DocumentsContract
+import androidx.activity.compose.rememberLauncherForActivityResult
+import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.background
import androidx.compose.foundation.border
import androidx.compose.foundation.clickable
@@ -15,23 +19,27 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.size
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.material.icons.Icons
+import androidx.compose.material.icons.filled.Refresh
import androidx.compose.material.icons.outlined.FileDownload
import androidx.compose.material.icons.outlined.FileDownloadDone
import androidx.compose.material.icons.outlined.FileDownloadOff
import androidx.compose.material.icons.outlined.Landslide
import androidx.compose.material3.Button
import androidx.compose.material3.Icon
+import androidx.compose.material3.IconButton
import androidx.compose.material3.LinearProgressIndicator
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedButton
import androidx.compose.material3.Switch
import androidx.compose.material3.Text
+import androidx.compose.material3.TextField
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color
+import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.testTag
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.text.font.FontWeight
@@ -41,6 +49,9 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.times
import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider
import com.shifthackz.aisdv1.core.common.appbuild.BuildType
+import com.shifthackz.aisdv1.core.common.file.LOCAL_DIFFUSION_CUSTOM_PATH
+import com.shifthackz.aisdv1.core.extensions.getRealPath
+import com.shifthackz.aisdv1.core.model.asString
import com.shifthackz.aisdv1.domain.entity.DownloadState
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent
@@ -150,77 +161,67 @@ fun LocalDiffusionForm(
Modifier.padding(start = (treeNum - 1) * 12.dp)
val folderStyle = MaterialTheme.typography.bodySmall
- Text(
- modifier = folderModifier(1),
- text = "Download",
- style = folderStyle,
- )
Text(
modifier = Modifier.padding(start = 12.dp),
- text = "SDAI",
- style = folderStyle,
- )
- Text(
- modifier = folderModifier(3),
- text = "model",
+ text = state.localCustomModelPath,
style = folderStyle,
)
Text(
- modifier = folderModifier(4),
+ modifier = folderModifier(3),
text = "text_encoder",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(4),
text = "model.ort",
style = folderStyle,
)
Text(
- modifier = folderModifier(4),
+ modifier = folderModifier(3),
text = "tokenizer",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(4),
text = "merges.txt",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(3),
text = "special_tokens_map.json",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(4),
text = "tokenizer_config.json",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(4),
text = "vocab.json",
style = folderStyle,
)
Text(
- modifier = folderModifier(4),
+ modifier = folderModifier(3),
text = "unet",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(4),
text = "model.ort",
style = folderStyle,
)
Text(
- modifier = folderModifier(4),
+ modifier = folderModifier(3),
text = "vae_decoder",
style = folderStyle,
)
Text(
- modifier = folderModifier(5),
+ modifier = folderModifier(4),
text = "model.ort",
style = folderStyle,
)
@@ -285,6 +286,15 @@ fun LocalDiffusionForm(
}
}
if (state.localCustomModel && buildInfoProvider.type == BuildType.FOSS) {
+ Text(
+ modifier = Modifier
+ .align(Alignment.CenterHorizontally)
+ .padding(vertical = 8.dp),
+ text = stringResource(id = LocalizationR.string.model_local_permission_header),
+ style = MaterialTheme.typography.bodyLarge,
+ textAlign = TextAlign.Center,
+ fontWeight = FontWeight.Bold,
+ )
Text(
modifier = Modifier.padding(vertical = 8.dp),
text = stringResource(id = LocalizationR.string.model_local_permission_title),
@@ -301,6 +311,79 @@ fun LocalDiffusionForm(
color = LocalContentColor.current,
)
}
+ Text(
+ modifier = Modifier
+ .align(Alignment.CenterHorizontally)
+ .padding(vertical = 8.dp),
+ text = stringResource(id = LocalizationR.string.model_local_path_header),
+ style = MaterialTheme.typography.bodyLarge,
+ textAlign = TextAlign.Center,
+ fontWeight = FontWeight.Bold,
+ )
+ val context = LocalContext.current
+ val uriFlags =
+ Intent.FLAG_GRANT_READ_URI_PERMISSION or Intent.FLAG_GRANT_WRITE_URI_PERMISSION
+ val launcher = rememberLauncherForActivityResult(
+ contract = ActivityResultContracts.StartActivityForResult()
+ ) { result ->
+ result.data?.data?.let { uri ->
+ context.contentResolver.takePersistableUriPermission(uri, uriFlags)
+ val docUri = DocumentsContract.buildDocumentUriUsingTree(
+ uri,
+ DocumentsContract.getTreeDocumentId(uri)
+ )
+ getRealPath(context, docUri)
+ ?.let(ServerSetupIntent::SelectLocalModelPath)
+ ?.let(processIntent::invoke)
+ }
+ }
+ TextField(
+ modifier = Modifier
+ .fillMaxWidth()
+ .padding(top = 14.dp),
+ value = state.localCustomModelPath,
+ onValueChange = { processIntent(ServerSetupIntent.SelectLocalModelPath(it)) },
+ enabled = true,
+ singleLine = true,
+ label = { Text(stringResource(LocalizationR.string.model_local_path_title)) },
+ trailingIcon = {
+ IconButton(
+ onClick = {
+ processIntent(
+ ServerSetupIntent.SelectLocalModelPath(LOCAL_DIFFUSION_CUSTOM_PATH)
+ )
+ },
+ content = {
+ Icon(
+ imageVector = Icons.Default.Refresh,
+ contentDescription = "Reset",
+ )
+ },
+ )
+ },
+ isError = state.localCustomModelPathValidationError != null,
+ supportingText = {
+ state.localCustomModelPathValidationError
+ ?.let { Text(it.asString(), color = MaterialTheme.colorScheme.error) }
+ },
+ )
+ OutlinedButton(
+ modifier = Modifier
+ .fillMaxSize()
+ .padding(top = 4.dp, bottom = 8.dp),
+ onClick = {
+ val intent = Intent(Intent.ACTION_OPEN_DOCUMENT_TREE).apply {
+ addFlags(uriFlags)
+ }
+ launcher.launch(intent)
+ },
+ ) {
+ Text(
+ text = stringResource(id = LocalizationR.string.model_local_path_button),
+ color = LocalContentColor.current,
+ )
+ }
+ Spacer(modifier = Modifier.height(8.dp))
}
state.localModels
.filter {
@@ -314,4 +397,4 @@ fun LocalDiffusionForm(
style = MaterialTheme.typography.bodyMedium,
)
}
-}
\ No newline at end of file
+}
diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationFilePathErrorMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationFilePathErrorMapper.kt
new file mode 100644
index 00000000..08467bf9
--- /dev/null
+++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationFilePathErrorMapper.kt
@@ -0,0 +1,15 @@
+package com.shifthackz.aisdv1.presentation.screen.setup.mappers
+
+import com.shifthackz.aisdv1.core.localization.R
+import com.shifthackz.aisdv1.core.model.UiText
+import com.shifthackz.aisdv1.core.model.asUiText
+import com.shifthackz.aisdv1.core.validation.ValidationResult
+import com.shifthackz.aisdv1.core.validation.path.FilePathValidator
+
+fun ValidationResult.mapToUi(): UiText? {
+ if (this.isValid) return null
+ return when (validationError as FilePathValidator.Error) {
+ FilePathValidator.Error.Empty -> R.string.error_empty_field
+ FilePathValidator.Error.Invalid -> R.string.error_invalid
+ }.asUiText()
+}
diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt
index 3097b797..178d0c6a 100644
--- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt
+++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt
@@ -1,6 +1,7 @@
package com.shifthackz.aisdv1.presentation.screen.setup
import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator
+import com.shifthackz.aisdv1.core.validation.path.FilePathValidator
import com.shifthackz.aisdv1.core.validation.url.UrlValidator
import com.shifthackz.aisdv1.domain.entity.Configuration
import com.shifthackz.aisdv1.domain.entity.DownloadState
@@ -38,6 +39,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
private val stubFetchAndGetHuggingFaceModelsUseCase = mockk()
private val stubUrlValidator = mockk()
private val stubCommonStringValidator = mockk()
+ private val stubFilePathValidator = mockk()
private val stubSetupConnectionInterActor = mockk()
private val stubDownloadModelUseCase = mockk()
private val stubDeleteModelUseCase = mockk()
@@ -52,6 +54,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() {
fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase,
urlValidator = stubUrlValidator,
stringValidator = stubCommonStringValidator,
+ filePathValidator = stubFilePathValidator,
setupConnectionInterActor = stubSetupConnectionInterActor,
downloadModelUseCase = stubDownloadModelUseCase,
deleteModelUseCase = stubDeleteModelUseCase,