diff --git a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt index 2cae5660..22ce7dbb 100644 --- a/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt +++ b/core/common/src/main/java/com/shifthackz/aisdv1/core/common/file/FileProviderDescriptor.kt @@ -1,5 +1,7 @@ package com.shifthackz.aisdv1.core.common.file +const val LOCAL_DIFFUSION_CUSTOM_PATH = "/storage/emulated/0/Download/SDAI/model" + interface FileProviderDescriptor { val providerPath: String val imagesCacheDirPath: String diff --git a/core/localization/src/main/res/values-ru/strings.xml b/core/localization/src/main/res/values-ru/strings.xml index 726aaa23..a9dd642b 100644 --- a/core/localization/src/main/res/values-ru/strings.xml +++ b/core/localization/src/main/res/values-ru/strings.xml @@ -295,10 +295,14 @@ Нажмите, чтобы открыть приложение. Загрузка кастом модели + Разрешения Чтобы иметь возможность загружать пользовательскую модель, вам необходимо разрешить приложению SDAI управлять разрешениями на хранилище, поскольку, начиная с Android 11, оно необходимо для доступа к файлам хранилища без области действия. Настроить доступ + Путь к модели + Путь к папке локальной модели + Выбрать папку - Чтобы использовать локальную пользовательскую модель, поместите ее в локальную папку в памяти телефона: Download/SDAi/model + Чтобы использовать локальную пользовательскую модель, поместите ее в локальную папку в памяти телефона. Окончательная структура папок должна быть такой: Отладка diff --git a/core/localization/src/main/res/values-tr/strings.xml b/core/localization/src/main/res/values-tr/strings.xml index 1c4a97ae..815b72d9 100644 --- a/core/localization/src/main/res/values-tr/strings.xml +++ b/core/localization/src/main/res/values-tr/strings.xml @@ -298,8 +298,12 @@ Özel modeli yükleyebilmek için SDAI uygulamasının depolama izinlerini yönetmesine izin vermeniz gerekir; çünkü Android 11\'den itibaren kapsamlı olmayan depolama dosyalarına erişmek gerekir. Kurulum izni - Yerel özel modeli kullanmak için telefonunuzun depolama alanındaki yerel klasöre yerleştirin: Download/SDAi/model + Yerel özel modeli kullanmak için telefonunuzun depolama alanındaki yerel klasöre yerleştirin. + İzinler Son klasör yapısı şu şekilde olmalıdır:: + Model yolu + Yerel model klasör yolu + Klasör seç Hata ayıklama Work Manager API diff --git a/core/localization/src/main/res/values-uk/strings.xml b/core/localization/src/main/res/values-uk/strings.xml index 8b901cd4..1ff9adf6 100644 --- a/core/localization/src/main/res/values-uk/strings.xml +++ b/core/localization/src/main/res/values-uk/strings.xml @@ -298,8 +298,12 @@ Щоб мати можливість завантажити спеціальну модель, вам потрібно дозволити додатку SDAI керувати дозволами на зберігання, оскільки, починаючи з Android 11, це потрібно для доступу до файлів зберігання без обмежень. Налаштувати доступ - Щоб використовувати локальну спеціальну модель, помістіть її в локальну папку в пам’яті телефону: Download/SDAi/model + Щоб використовувати локальну спеціальну модель, помістіть її в локальну папку в пам’яті телефону. + Дозволи Остаточна структура папок має бути такою: + Шлях моделі + Шлях папки локальної моделі + Виберіть папку Відладка Work Manager API diff --git a/core/localization/src/main/res/values-zh/strings.xml b/core/localization/src/main/res/values-zh/strings.xml index 0849958f..2199fbcd 100644 --- a/core/localization/src/main/res/values-zh/strings.xml +++ b/core/localization/src/main/res/values-zh/strings.xml @@ -358,8 +358,12 @@ 加载自定义模型 为了能够加载自定义模型,您需要允许SDAI应用管理存储权限,因为从Android 11开始,它需要访问非范围存储文件。 设置权限 + 模型路径 + 本地模型文件夹路径 + 选择文件夹 - 要使用本地自定义模型,请将其放置在手机存储中的本地文件夹:Download/SDAi/model + 要使用本地自定义模型,请将其放置在手机存储中的本地文件夹。 + 权限 最终的文件夹结构应该是: diff --git a/core/localization/src/main/res/values/strings.xml b/core/localization/src/main/res/values/strings.xml index bd1b8e31..f455448f 100755 --- a/core/localization/src/main/res/values/strings.xml +++ b/core/localization/src/main/res/values/strings.xml @@ -317,10 +317,14 @@ Local Diffusion Load custom model + Permissions To be able to load custom model, you need to allow SDAI app manage storage permissions, because starting from Android 11 it is needed to access non-scoped storage files. Setup permission + Model path + Local model folder path + Select folder - To use local custom model, place it to local folder in your phone storage: Download/SDAi/model + To use local custom model, place it to local folder in your phone storage. The final folder structure should be: Debugging diff --git a/core/ui/src/main/java/com/shifthackz/aisdv1/core/extensions/RealPathExtensions.kt b/core/ui/src/main/java/com/shifthackz/aisdv1/core/extensions/RealPathExtensions.kt new file mode 100644 index 00000000..395f73cd --- /dev/null +++ b/core/ui/src/main/java/com/shifthackz/aisdv1/core/extensions/RealPathExtensions.kt @@ -0,0 +1,123 @@ +package com.shifthackz.aisdv1.core.extensions + +import android.content.ContentUris +import android.content.Context +import android.database.Cursor +import android.net.Uri +import android.os.Environment +import android.provider.DocumentsContract +import android.provider.MediaStore + +fun getRealPath(context: Context, uri: Uri): String? { + if (DocumentsContract.isDocumentUri(context, uri)) { + if (isExternalStorageDocument(uri)) { + val docId = DocumentsContract.getDocumentId(uri) + val split = docId.split(":".toRegex()).dropLastWhile { it.isEmpty() } + .toTypedArray() + val type = split[0] + + if ("primary".equals(type, ignoreCase = true)) { + return Environment.getExternalStorageDirectory().toString() + "/" + split[1] + } else { + return "/storage/$type/${split[1]}" + } + + } else if (isDownloadsDocument(uri)) { + val id = DocumentsContract.getDocumentId(uri) + val contentUri = ContentUris.withAppendedId( + Uri.parse("content://downloads/public_downloads"), id.toLong() + ) + + return getDataColumn(context, contentUri, null, null) + } else if (isMediaDocument(uri)) { + val docId = DocumentsContract.getDocumentId(uri) + val split = docId.split(":".toRegex()).dropLastWhile { it.isEmpty() } + .toTypedArray() + val type = split[0] + + var contentUri: Uri? = null + if ("image" == type) { + contentUri = MediaStore.Images.Media.EXTERNAL_CONTENT_URI + } else if ("video" == type) { + contentUri = MediaStore.Video.Media.EXTERNAL_CONTENT_URI + } else if ("audio" == type) { + contentUri = MediaStore.Audio.Media.EXTERNAL_CONTENT_URI + } + + val selection = "_id=?" + val selectionArgs = arrayOf( + split[1] + ) + + return getDataColumn(context, contentUri, selection, selectionArgs) + } + } else if ("content".equals(uri.scheme, ignoreCase = true)) { + // Return the remote address + + if (isGooglePhotosUri(uri)) return uri.lastPathSegment + + return getDataColumn(context, uri, null, null) + } else if ("file".equals(uri.scheme, ignoreCase = true)) { + return uri.path + } + + return null +} + +fun getDataColumn( + context: Context, uri: Uri?, selection: String?, + selectionArgs: Array? +): String? { + var cursor: Cursor? = null + val column = "_data" + val projection = arrayOf( + column + ) + + try { + cursor = context.contentResolver.query( + uri!!, projection, selection, selectionArgs, + null + ) + if (cursor != null && cursor.moveToFirst()) { + val index = cursor.getColumnIndexOrThrow(column) + return cursor.getString(index) + } + } finally { + cursor?.close() + } + return null +} + + +/** + * @param uri The Uri to check. + * @return Whether the Uri authority is ExternalStorageProvider. + */ +fun isExternalStorageDocument(uri: Uri): Boolean { + return "com.android.externalstorage.documents" == uri.authority +} + +/** + * @param uri The Uri to check. + * @return Whether the Uri authority is DownloadsProvider. + */ +fun isDownloadsDocument(uri: Uri): Boolean { + return "com.android.providers.downloads.documents" == uri.authority +} + +/** + * @param uri The Uri to check. + * @return Whether the Uri authority is MediaProvider. + */ +fun isMediaDocument(uri: Uri): Boolean { + return "com.android.providers.media.documents" == uri.authority +} + +/** + * @param uri The Uri to check. + * @return Whether the Uri authority is Google Photos. + */ +fun isGooglePhotosUri(uri: Uri): Boolean { + return "com.google.android.apps.photos.content" == uri.authority +} diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt index 7c0d819c..81778ad9 100644 --- a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/di/ValidatorsModule.kt @@ -4,6 +4,8 @@ import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator import com.shifthackz.aisdv1.core.validation.common.CommonStringValidatorImpl import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidator import com.shifthackz.aisdv1.core.validation.dimension.DimensionValidatorImpl +import com.shifthackz.aisdv1.core.validation.path.FilePathValidator +import com.shifthackz.aisdv1.core.validation.path.FilePathValidatorImpl import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidatorImpl import org.koin.core.module.dsl.factoryOf @@ -16,4 +18,5 @@ val validatorsModule = module { factory { UrlValidatorImpl() } factoryOf(::CommonStringValidatorImpl) bind CommonStringValidator::class + factoryOf(::FilePathValidatorImpl) bind FilePathValidator::class } diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidator.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidator.kt new file mode 100644 index 00000000..6d09d4e3 --- /dev/null +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidator.kt @@ -0,0 +1,13 @@ +package com.shifthackz.aisdv1.core.validation.path + +import com.shifthackz.aisdv1.core.validation.ValidationResult + +interface FilePathValidator { + + operator fun invoke(input: String?): ValidationResult + + sealed interface Error { + data object Empty : Error + data object Invalid : Error + } +} diff --git a/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImpl.kt b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImpl.kt new file mode 100644 index 00000000..ef3c8e96 --- /dev/null +++ b/core/validation/src/main/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImpl.kt @@ -0,0 +1,23 @@ +package com.shifthackz.aisdv1.core.validation.path + +import com.shifthackz.aisdv1.core.validation.ValidationResult + +class FilePathValidatorImpl : FilePathValidator { + + override fun invoke(input: String?): ValidationResult = when { + input.isNullOrBlank() -> ValidationResult( + isValid = false, + validationError = FilePathValidator.Error.Empty, + ) + !isValidFilePath(input) -> ValidationResult( + isValid = false, + validationError = FilePathValidator.Error.Invalid, + ) + else -> ValidationResult(true) + } + + private fun isValidFilePath(path: String): Boolean { + val regex = Regex("^(/[^/<>:\"|?*]+)+/?$") + return regex.matches(path) + } +} diff --git a/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImplTest.kt b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImplTest.kt new file mode 100644 index 00000000..e33c2012 --- /dev/null +++ b/core/validation/src/test/java/com/shifthackz/aisdv1/core/validation/path/FilePathValidatorImplTest.kt @@ -0,0 +1,57 @@ +package com.shifthackz.aisdv1.core.validation.path + +import com.shifthackz.aisdv1.core.validation.ValidationResult +import org.junit.Assert +import org.junit.Test + +class FilePathValidatorImplTest { + + private val validator = FilePathValidatorImpl() + + @Test + fun `iven input is null, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = FilePathValidator.Error.Empty, + ) + val actual = validator(null) + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is empty, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = FilePathValidator.Error.Empty, + ) + val actual = validator("") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is blank, expected not valid with Empty error`() { + val expected = ValidationResult( + isValid = false, + validationError = FilePathValidator.Error.Empty, + ) + val actual = validator(" ") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is not valid, expected not valid with Invalid error`() { + val expected = ValidationResult( + isValid = false, + validationError = FilePathValidator.Error.Invalid, + ) + val actual = validator("cc") + Assert.assertEquals(expected, actual) + } + + @Test + fun `given input is valid, expected valid`() { + val expected = ValidationResult(true) + val actual = validator("/tmp/local/5598") + Assert.assertEquals(expected, actual) + } +} diff --git a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt index 1a41e3a5..1fffb4d4 100644 --- a/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt +++ b/data/src/main/java/com/shifthackz/aisdv1/data/preference/PreferenceManagerImpl.kt @@ -3,6 +3,7 @@ package com.shifthackz.aisdv1.data.preference import android.content.SharedPreferences import com.shifthackz.aisdv1.core.common.extensions.fixUrlSlashes import com.shifthackz.aisdv1.core.common.extensions.shouldUseNewMediaStore +import com.shifthackz.aisdv1.core.common.file.LOCAL_DIFFUSION_CUSTOM_PATH import com.shifthackz.aisdv1.core.common.schedulers.SchedulersToken import com.shifthackz.aisdv1.domain.entity.ColorToken import com.shifthackz.aisdv1.domain.entity.DarkThemeToken @@ -59,6 +60,15 @@ class PreferenceManagerImpl( .apply() .also { onPreferencesChanged() } + override var localDiffusionCustomModelPath: String + get() = preferences.getString( + KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, + LOCAL_DIFFUSION_CUSTOM_PATH, + ) ?: LOCAL_DIFFUSION_CUSTOM_PATH + set(value) = preferences.edit() + .putString(KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH, value) + .apply() + override var localDiffusionAllowCancel: Boolean get() = preferences.getBoolean(KEY_ALLOW_LOCAL_DIFFUSION_CANCEL, false) set(value) = preferences.edit() @@ -285,6 +295,7 @@ class PreferenceManagerImpl( const val KEY_SWARM_MODEL = "key_swarm_model" const val KEY_DEMO_MODE = "key_demo_mode" const val KEY_DEVELOPER_MODE = "key_developer_mode" + const val KEY_LOCAL_DIFFUSION_CUSTOM_MODEL_PATH = "key_local_diffusion_custom_model_path" const val KEY_ALLOW_LOCAL_DIFFUSION_CANCEL = "key_allow_local_diffusion_cancel" const val KEY_LOCAL_DIFFUSION_SCHEDULER_THREAD = "key_local_diffusion_scheduler_thread" const val KEY_MONITOR_CONNECTIVITY = "key_monitor_connectivity" diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt index e644c6f8..fcf04c19 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/entity/Configuration.kt @@ -16,4 +16,5 @@ data class Configuration( val stabilityAiEngineId: String = "", val authCredentials: AuthorizationCredentials = AuthorizationCredentials.None, val localModelId: String = "", + val localModelPath: String = "", ) diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt index 2ea33eef..50e0cda5 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/preference/PreferenceManager.kt @@ -12,6 +12,7 @@ interface PreferenceManager { var swarmUiModel: String var demoMode: Boolean var developerMode: Boolean + var localDiffusionCustomModelPath: String var localDiffusionAllowCancel: Boolean var localDiffusionSchedulerThread: SchedulersToken var monitorConnectivity: Boolean diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt index a46f4bd7..4070a0a6 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImpl.kt @@ -25,6 +25,7 @@ internal class GetConfigurationUseCaseImpl( stabilityAiEngineId = preferenceManager.stabilityAiEngineId, authCredentials = authorizationStore.getAuthorizationCredentials(), localModelId = preferenceManager.localModelId, + localModelPath = preferenceManager.localDiffusionCustomModelPath, ) ) } diff --git a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt index 83ceab34..8d03619c 100644 --- a/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt +++ b/domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImpl.kt @@ -25,5 +25,6 @@ internal class SetServerConfigurationUseCaseImpl( preferenceManager.stabilityAiApiKey = configuration.stabilityAiApiKey preferenceManager.stabilityAiEngineId = configuration.stabilityAiEngineId preferenceManager.localModelId = configuration.localModelId + preferenceManager.localDiffusionCustomModelPath = configuration.localModelPath } } diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt index 114e97b4..73503575 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/mocks/ConfigurationMocks.kt @@ -13,4 +13,5 @@ val mockConfiguration = Configuration( stabilityAiApiKey = "5598", stabilityAiEngineId = "5598", localModelId = "5598", + localModelPath = "/storage/emulated/0/5598", ) diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt index 9ce3e0af..55290a49 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/GetConfigurationUseCaseImplTest.kt @@ -72,6 +72,10 @@ class GetConfigurationUseCaseImplTest { stubPreferenceManager::localModelId.get() } returns mockConfiguration.localModelId + every { + stubPreferenceManager::localDiffusionCustomModelPath.get() + } returns mockConfiguration.localModelPath + useCase .invoke() .test() diff --git a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt index e41c7ab7..62a74c34 100644 --- a/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt +++ b/domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/settings/SetServerConfigurationUseCaseImplTest.kt @@ -71,6 +71,10 @@ class SetServerConfigurationUseCaseImplTest { stubPreferenceManager::localModelId.set(any()) } returns Unit + every { + stubPreferenceManager::localDiffusionCustomModelPath.set(any()) + } returns Unit + useCase .invoke(mockConfiguration) .test() diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt index e539317b..0683aeff 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/tokenizer/EnglishTextTokenizer.kt @@ -10,6 +10,7 @@ import android.util.Pair import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.log.debugLog import com.shifthackz.aisdv1.core.common.log.errorLog +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_INPUT_IDS import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT @@ -32,6 +33,7 @@ internal class EnglishTextTokenizer( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, private val localModelIdProvider: LocalModelIdProvider, + private val preferenceManager: PreferenceManager, ) : LocalDiffusionTextTokenizer { private val pattern = Pattern.compile(TOKENIZER_REGEX) @@ -53,7 +55,7 @@ internal class EnglishTextTokenizer( val options = OrtSession.SessionOptions() options.addConfigEntry(ORT_KEY_MODEL_FORMAT, ORT) session = ortEnvironmentProvider.get().createSession( - "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}", + "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MODEL}", options ) debugLog("{$TAG} {TOKENIZER} {initialize} Session created successfully!") @@ -229,7 +231,7 @@ internal class EnglishTextTokenizer( private fun loadEncoder(): Map { val map: MutableMap = HashMap() try { - val path = "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" + val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_VOCABULARY}" val jsonReader = JsonReader(InputStreamReader(FileInputStream(path))) jsonReader.beginObject() while (jsonReader.hasNext()) { @@ -255,7 +257,7 @@ internal class EnglishTextTokenizer( private fun loadBpeRanks(): Map, Int?> { val result: MutableMap, Int?> = HashMap() try { - val path = "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}" + val path = "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.TOKENIZER_MERGES}" val reader = BufferedReader(InputStreamReader(FileInputStream(path))) var line: String var startLine = 1 diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt index 96d427f2..85cc131f 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/unet/UNet.kt @@ -10,6 +10,7 @@ import android.graphics.Bitmap import android.util.Pair import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.core.common.log.debugLog +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_ENCODER_HIDDEN_STATES import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.KEY_LATENT_SAMPLE @@ -43,6 +44,7 @@ internal class UNet( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, private val localModelIdProvider: LocalModelIdProvider, + private val preferenceManager: PreferenceManager, ) { private var decoder: VaeDecoder? = null @@ -61,6 +63,7 @@ internal class UNet( ortEnvironmentProvider, fileProviderDescriptor, localModelIdProvider, + preferenceManager, deviceNNAPIFlagProvider.get(), ) val options = SessionOptions() @@ -69,7 +72,7 @@ internal class UNet( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}", + "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.UNET_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt index 3a15a436..5dea5b55 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/ai/vae/VaeDecoder.kt @@ -7,13 +7,14 @@ import ai.onnxruntime.providers.NNAPIFlags import android.graphics.Bitmap import android.graphics.Color import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT import com.shifthackz.aisdv1.feature.diffusion.LocalDiffusionContract.ORT_KEY_MODEL_FORMAT import com.shifthackz.aisdv1.feature.diffusion.entity.Array3D -import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider import com.shifthackz.aisdv1.feature.diffusion.entity.LocalDiffusionFlag import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider +import com.shifthackz.aisdv1.feature.diffusion.environment.OrtEnvironmentProvider import com.shifthackz.aisdv1.feature.diffusion.extensions.modelPathPrefix import java.util.EnumSet import kotlin.math.roundToInt @@ -22,6 +23,7 @@ internal class VaeDecoder( private val ortEnvironmentProvider: OrtEnvironmentProvider, private val fileProviderDescriptor: FileProviderDescriptor, private val localModelIdProvider: LocalModelIdProvider, + private val preferenceManager: PreferenceManager, private val deviceId: Int, ) { @@ -67,7 +69,7 @@ internal class VaeDecoder( options.addNnapi(EnumSet.of(NNAPIFlags.CPU_DISABLED)) } session = ortEnvironmentProvider.get().createSession( - "${modelPathPrefix(fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}", + "${modelPathPrefix(preferenceManager, fileProviderDescriptor, localModelIdProvider)}/${LocalDiffusionContract.VAE_MODEL}", options ) } diff --git a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt index d4faaddd..7cb6df87 100644 --- a/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt +++ b/feature/diffusion/src/main/java/com/shifthackz/aisdv1/feature/diffusion/extensions/LocalDiffusionPaths.kt @@ -2,17 +2,17 @@ package com.shifthackz.aisdv1.feature.diffusion.extensions import com.shifthackz.aisdv1.core.common.file.FileProviderDescriptor import com.shifthackz.aisdv1.domain.entity.LocalAiModel +import com.shifthackz.aisdv1.domain.preference.PreferenceManager import com.shifthackz.aisdv1.feature.diffusion.environment.LocalModelIdProvider -private const val PATH = "/storage/emulated/0/Download/SDAI/model" - fun modelPathPrefix( + preferenceManager: PreferenceManager, fileProviderDescriptor: FileProviderDescriptor, localModelIdProvider: LocalModelIdProvider, ): String { - val modelId = localModelIdProvider.get(); + val modelId = localModelIdProvider.get() return if (modelId == LocalAiModel.CUSTOM.id) { - PATH + preferenceManager.localDiffusionCustomModelPath } else { "${fileProviderDescriptor.localModelDirPath}/${modelId}" } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt index a9a1aac4..a22190ec 100755 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/di/ViewModelModule.kt @@ -62,6 +62,7 @@ val viewModelModule = module { fetchAndGetHuggingFaceModelsUseCase = get(), urlValidator = get(), stringValidator = get(), + filePathValidator = get(), setupConnectionInterActor = get(), downloadModelUseCase = get(), deleteModelUseCase = get(), diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt index 80d51618..425f9ef6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupIntent.kt @@ -36,6 +36,8 @@ sealed interface ServerSetupIntent : MviIntent { data class UpdateHordeDefaultApiKey(val value: Boolean) : ServerSetupIntent + data class SelectLocalModelPath(val value: String) : ServerSetupIntent + data class SelectLocalModel(val model: ServerSetupState.LocalModel) : ServerSetupIntent data class AllowLocalCustomModel(val allow: Boolean) : ServerSetupIntent diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt index 555a4d32..7e91bbf6 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupState.kt @@ -35,6 +35,7 @@ data class ServerSetupState( val huggingFaceModel: String = "", val localModels: List = emptyList(), val localCustomModel: Boolean = false, + val localCustomModelPath: String = "", val passwordVisible: Boolean = false, val serverUrlValidationError: UiText? = null, val swarmUiUrlValidationError: UiText? = null, @@ -44,6 +45,7 @@ data class ServerSetupState( val huggingFaceApiKeyValidationError: UiText? = null, val openAiApiKeyValidationError: UiText? = null, val stabilityAiApiKeyValidationError: UiText? = null, + val localCustomModelPathValidationError: UiText? = null, ) : MviState, KoinComponent { val demoModeUrl: String diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt index f0ce3e03..d02f1287 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModel.kt @@ -5,6 +5,7 @@ import com.shifthackz.aisdv1.core.common.schedulers.SchedulersProvider import com.shifthackz.aisdv1.core.common.schedulers.subscribeOnMainThread import com.shifthackz.aisdv1.core.model.asUiText import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator +import com.shifthackz.aisdv1.core.validation.path.FilePathValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.core.viewmodel.MviRxViewModel import com.shifthackz.aisdv1.domain.entity.DownloadState @@ -37,6 +38,7 @@ class ServerSetupViewModel( fetchAndGetHuggingFaceModelsUseCase: FetchAndGetHuggingFaceModelsUseCase, private val urlValidator: UrlValidator, private val stringValidator: CommonStringValidator, + private val filePathValidator: FilePathValidator, private val setupConnectionInterActor: SetupConnectionInterActor, private val downloadModelUseCase: DownloadModelUseCase, private val deleteModelUseCase: DeleteModelUseCase, @@ -82,6 +84,7 @@ class ServerSetupViewModel( stabilityAiApiKey = configuration.stabilityAiApiKey, localModels = localModels.mapToUi(), localCustomModel = localModels.mapLocalCustomModelSwitchState(), + localCustomModelPath = configuration.localModelPath, mode = configuration.source, demoMode = configuration.demoMode, serverUrl = configuration.serverUrl, @@ -227,6 +230,13 @@ class ServerSetupViewModel( } ServerSetupIntent.ConnectToLocalHost -> connectToServer() + + is ServerSetupIntent.SelectLocalModelPath -> updateState { + it.copy( + localCustomModelPath = intent.value, + localCustomModelPathValidationError = null, + ) + } } private fun validateAndConnectToServer() { @@ -278,7 +288,15 @@ class ServerSetupViewModel( } ServerSource.LOCAL -> { - currentState.localModels.find { it.selected && it.downloaded } != null + if (currentState.localCustomModel) { + val validation = filePathValidator(currentState.localCustomModelPath) + updateState { + it.copy(localCustomModelPathValidationError = validation.mapToUi()) + } + validation.isValid + } else { + currentState.localModels.find { it.selected && it.downloaded } != null + } } ServerSource.HUGGING_FACE -> { @@ -382,6 +400,7 @@ class ServerSetupViewModel( } private fun connectToLocalDiffusion(): Single> { + preferenceManager.localDiffusionCustomModelPath = currentState.localCustomModelPath val localModelId = currentState.localModels.find { it.selected }?.id ?: "" return setupConnectionInterActor.connectToLocal(localModelId) } diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt index df5819ce..6ebfed22 100644 --- a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/forms/LocalDiffusionForm.kt @@ -1,5 +1,9 @@ package com.shifthackz.aisdv1.presentation.screen.setup.forms +import android.content.Intent +import android.provider.DocumentsContract +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.contract.ActivityResultContracts import androidx.compose.foundation.background import androidx.compose.foundation.border import androidx.compose.foundation.clickable @@ -15,23 +19,27 @@ import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.size import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Refresh import androidx.compose.material.icons.outlined.FileDownload import androidx.compose.material.icons.outlined.FileDownloadDone import androidx.compose.material.icons.outlined.FileDownloadOff import androidx.compose.material.icons.outlined.Landslide import androidx.compose.material3.Button import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.LocalContentColor import androidx.compose.material3.MaterialTheme import androidx.compose.material3.OutlinedButton import androidx.compose.material3.Switch import androidx.compose.material3.Text +import androidx.compose.material3.TextField import androidx.compose.runtime.Composable import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip import androidx.compose.ui.graphics.Color +import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.testTag import androidx.compose.ui.res.stringResource import androidx.compose.ui.text.font.FontWeight @@ -41,6 +49,9 @@ import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.times import com.shifthackz.aisdv1.core.common.appbuild.BuildInfoProvider import com.shifthackz.aisdv1.core.common.appbuild.BuildType +import com.shifthackz.aisdv1.core.common.file.LOCAL_DIFFUSION_CUSTOM_PATH +import com.shifthackz.aisdv1.core.extensions.getRealPath +import com.shifthackz.aisdv1.core.model.asString import com.shifthackz.aisdv1.domain.entity.DownloadState import com.shifthackz.aisdv1.domain.entity.LocalAiModel import com.shifthackz.aisdv1.presentation.screen.setup.ServerSetupIntent @@ -150,77 +161,67 @@ fun LocalDiffusionForm( Modifier.padding(start = (treeNum - 1) * 12.dp) val folderStyle = MaterialTheme.typography.bodySmall - Text( - modifier = folderModifier(1), - text = "Download", - style = folderStyle, - ) Text( modifier = Modifier.padding(start = 12.dp), - text = "SDAI", - style = folderStyle, - ) - Text( - modifier = folderModifier(3), - text = "model", + text = state.localCustomModelPath, style = folderStyle, ) Text( - modifier = folderModifier(4), + modifier = folderModifier(3), text = "text_encoder", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(4), text = "model.ort", style = folderStyle, ) Text( - modifier = folderModifier(4), + modifier = folderModifier(3), text = "tokenizer", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(4), text = "merges.txt", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(3), text = "special_tokens_map.json", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(4), text = "tokenizer_config.json", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(4), text = "vocab.json", style = folderStyle, ) Text( - modifier = folderModifier(4), + modifier = folderModifier(3), text = "unet", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(4), text = "model.ort", style = folderStyle, ) Text( - modifier = folderModifier(4), + modifier = folderModifier(3), text = "vae_decoder", style = folderStyle, ) Text( - modifier = folderModifier(5), + modifier = folderModifier(4), text = "model.ort", style = folderStyle, ) @@ -285,6 +286,15 @@ fun LocalDiffusionForm( } } if (state.localCustomModel && buildInfoProvider.type == BuildType.FOSS) { + Text( + modifier = Modifier + .align(Alignment.CenterHorizontally) + .padding(vertical = 8.dp), + text = stringResource(id = LocalizationR.string.model_local_permission_header), + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + fontWeight = FontWeight.Bold, + ) Text( modifier = Modifier.padding(vertical = 8.dp), text = stringResource(id = LocalizationR.string.model_local_permission_title), @@ -301,6 +311,79 @@ fun LocalDiffusionForm( color = LocalContentColor.current, ) } + Text( + modifier = Modifier + .align(Alignment.CenterHorizontally) + .padding(vertical = 8.dp), + text = stringResource(id = LocalizationR.string.model_local_path_header), + style = MaterialTheme.typography.bodyLarge, + textAlign = TextAlign.Center, + fontWeight = FontWeight.Bold, + ) + val context = LocalContext.current + val uriFlags = + Intent.FLAG_GRANT_READ_URI_PERMISSION or Intent.FLAG_GRANT_WRITE_URI_PERMISSION + val launcher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.StartActivityForResult() + ) { result -> + result.data?.data?.let { uri -> + context.contentResolver.takePersistableUriPermission(uri, uriFlags) + val docUri = DocumentsContract.buildDocumentUriUsingTree( + uri, + DocumentsContract.getTreeDocumentId(uri) + ) + getRealPath(context, docUri) + ?.let(ServerSetupIntent::SelectLocalModelPath) + ?.let(processIntent::invoke) + } + } + TextField( + modifier = Modifier + .fillMaxWidth() + .padding(top = 14.dp), + value = state.localCustomModelPath, + onValueChange = { processIntent(ServerSetupIntent.SelectLocalModelPath(it)) }, + enabled = true, + singleLine = true, + label = { Text(stringResource(LocalizationR.string.model_local_path_title)) }, + trailingIcon = { + IconButton( + onClick = { + processIntent( + ServerSetupIntent.SelectLocalModelPath(LOCAL_DIFFUSION_CUSTOM_PATH) + ) + }, + content = { + Icon( + imageVector = Icons.Default.Refresh, + contentDescription = "Reset", + ) + }, + ) + }, + isError = state.localCustomModelPathValidationError != null, + supportingText = { + state.localCustomModelPathValidationError + ?.let { Text(it.asString(), color = MaterialTheme.colorScheme.error) } + }, + ) + OutlinedButton( + modifier = Modifier + .fillMaxSize() + .padding(top = 4.dp, bottom = 8.dp), + onClick = { + val intent = Intent(Intent.ACTION_OPEN_DOCUMENT_TREE).apply { + addFlags(uriFlags) + } + launcher.launch(intent) + }, + ) { + Text( + text = stringResource(id = LocalizationR.string.model_local_path_button), + color = LocalContentColor.current, + ) + } + Spacer(modifier = Modifier.height(8.dp)) } state.localModels .filter { @@ -314,4 +397,4 @@ fun LocalDiffusionForm( style = MaterialTheme.typography.bodyMedium, ) } -} \ No newline at end of file +} diff --git a/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationFilePathErrorMapper.kt b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationFilePathErrorMapper.kt new file mode 100644 index 00000000..08467bf9 --- /dev/null +++ b/presentation/src/main/java/com/shifthackz/aisdv1/presentation/screen/setup/mappers/ServerSetupValidationFilePathErrorMapper.kt @@ -0,0 +1,15 @@ +package com.shifthackz.aisdv1.presentation.screen.setup.mappers + +import com.shifthackz.aisdv1.core.localization.R +import com.shifthackz.aisdv1.core.model.UiText +import com.shifthackz.aisdv1.core.model.asUiText +import com.shifthackz.aisdv1.core.validation.ValidationResult +import com.shifthackz.aisdv1.core.validation.path.FilePathValidator + +fun ValidationResult.mapToUi(): UiText? { + if (this.isValid) return null + return when (validationError as FilePathValidator.Error) { + FilePathValidator.Error.Empty -> R.string.error_empty_field + FilePathValidator.Error.Invalid -> R.string.error_invalid + }.asUiText() +} diff --git a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt index 3097b797..178d0c6a 100644 --- a/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt +++ b/presentation/src/test/java/com/shifthackz/aisdv1/presentation/screen/setup/ServerSetupViewModelTest.kt @@ -1,6 +1,7 @@ package com.shifthackz.aisdv1.presentation.screen.setup import com.shifthackz.aisdv1.core.validation.common.CommonStringValidator +import com.shifthackz.aisdv1.core.validation.path.FilePathValidator import com.shifthackz.aisdv1.core.validation.url.UrlValidator import com.shifthackz.aisdv1.domain.entity.Configuration import com.shifthackz.aisdv1.domain.entity.DownloadState @@ -38,6 +39,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { private val stubFetchAndGetHuggingFaceModelsUseCase = mockk() private val stubUrlValidator = mockk() private val stubCommonStringValidator = mockk() + private val stubFilePathValidator = mockk() private val stubSetupConnectionInterActor = mockk() private val stubDownloadModelUseCase = mockk() private val stubDeleteModelUseCase = mockk() @@ -52,6 +54,7 @@ class ServerSetupViewModelTest : CoreViewModelTest() { fetchAndGetHuggingFaceModelsUseCase = stubFetchAndGetHuggingFaceModelsUseCase, urlValidator = stubUrlValidator, stringValidator = stubCommonStringValidator, + filePathValidator = stubFilePathValidator, setupConnectionInterActor = stubSetupConnectionInterActor, downloadModelUseCase = stubDownloadModelUseCase, deleteModelUseCase = stubDeleteModelUseCase,