Skip to content

Commit 96efe1d

Browse files
authored
Local Model Download mirrors (#457)
* Local Model Download mirrors * Update versions * Fix Unit tests
1 parent f8b57fa commit 96efe1d

File tree

23 files changed

+409
-76
lines changed

23 files changed

+409
-76
lines changed

core/localization/src/main/res/values/strings.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@
186186
<string name="title_txt_inversion">Textual Inversion</string>
187187
<string name="title_txt_inversion_short">Inversion</string>
188188
<string name="title_tag_edit">Edit tag</string>
189+
<string name="title_select_download_source">Select source</string>
189190

190191
<string name="gallery_media_store_banner">You have %1$s photos saved in Download/SDAI</string>
191192
<string name="gallery_info_field_date">Created</string>

data/src/main/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImpl.kt

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@ internal class DownloadableModelRepositoryImpl(
1313
private val buildInfoProvider: BuildInfoProvider,
1414
) : DownloadableModelRepository {
1515

16-
override fun download(id: String) = localDataSource
17-
.getById(id)
18-
.flatMapObservable { model ->
19-
remoteDataSource.download(id, model.sources.firstOrNull() ?: "")
20-
}
16+
override fun download(id: String, url: String) = remoteDataSource.download(id, url)
2117

2218
override fun delete(id: String) = localDataSource.delete(id)
2319

data/src/test/java/com/shifthackz/aisdv1/data/repository/DownloadableModelRepositoryImplTest.kt

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -219,29 +219,14 @@ class DownloadableModelRepositoryImplTest {
219219
.assertNotComplete()
220220
}
221221

222-
@Test
223-
fun `given attempt to download model, local data source has no such model, expected error value`() {
224-
every {
225-
stubLocalDataSource.getById(any())
226-
} returns Single.error(stubException)
227-
228-
repository
229-
.download("5598")
230-
.test()
231-
.assertNoValues()
232-
.assertError(stubException)
233-
.await()
234-
.assertNotComplete()
235-
}
236-
237222
@Test
238223
fun `given attempt to download model, local data source has such model, download succeeds, expected unknown, downloading, complete values`() {
239224
every {
240225
stubLocalDataSource.getById(any())
241226
} returns Single.just(mockLocalAiModel)
242227

243228
val stubObserver = repository
244-
.download("5598")
229+
.download("5598", "https://moroz.cc/stub.zip")
245230
.test()
246231

247232
stubDownloadState.onNext(DownloadState.Unknown)
@@ -276,7 +261,7 @@ class DownloadableModelRepositoryImplTest {
276261
} returns Single.just(mockLocalAiModel)
277262

278263
val stubObserver = repository
279-
.download("5598")
264+
.download("5598", "https://moroz.cc/stub.zip")
280265
.test()
281266

282267
stubDownloadState.onNext(DownloadState.Unknown)
@@ -309,7 +294,7 @@ class DownloadableModelRepositoryImplTest {
309294
} returns Observable.error(stubException)
310295

311296
repository
312-
.download("5598")
297+
.download("5598", "https://moroz.cc/stub.zip")
313298
.test()
314299
.assertError(stubException)
315300
.assertNoValues()

domain/src/main/java/com/shifthackz/aisdv1/domain/di/DomainModule.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCase
3838
import com.shifthackz.aisdv1.domain.usecase.downloadable.DownloadModelUseCaseImpl
3939
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCase
4040
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalMediaPipeModelsUseCaseImpl
41+
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalModelUseCase
42+
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalModelUseCaseImpl
4143
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCase
4244
import com.shifthackz.aisdv1.domain.usecase.downloadable.GetLocalOnnxModelsUseCaseImpl
4345
import com.shifthackz.aisdv1.domain.usecase.downloadable.ObserveLocalOnnxModelsUseCase
@@ -185,6 +187,7 @@ internal val useCasesModule = module {
185187
factoryOf(::FetchAndGetStabilityAiEnginesUseCaseImpl) bind FetchAndGetStabilityAiEnginesUseCase::class
186188
factoryOf(::FetchAndGetSupportersUseCaseImpl) bind FetchAndGetSupportersUseCase::class
187189
factoryOf(::SendReportUseCaseImpl) bind SendReportUseCase::class
190+
factoryOf(::GetLocalModelUseCaseImpl) bind GetLocalModelUseCase::class
188191
}
189192

190193
internal val interActorsModule = module {

domain/src/main/java/com/shifthackz/aisdv1/domain/repository/DownloadableModelRepository.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import io.reactivex.rxjava3.core.Observable
88
import io.reactivex.rxjava3.core.Single
99

1010
interface DownloadableModelRepository {
11-
fun download(id: String): Observable<DownloadState>
11+
fun download(id: String, url: String): Observable<DownloadState>
1212
fun delete(id: String): Completable
1313
fun getAllOnnx(): Single<List<LocalAiModel>>
1414
fun getAllMediaPipe(): Single<List<LocalAiModel>>

domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCase.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ import com.shifthackz.aisdv1.domain.entity.DownloadState
44
import io.reactivex.rxjava3.core.Observable
55

66
interface DownloadModelUseCase {
7-
operator fun invoke(id: String): Observable<DownloadState>
7+
operator fun invoke(id: String, url: String): Observable<DownloadState>
88
}

domain/src/main/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImpl.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ internal class DownloadModelUseCaseImpl(
66
private val downloadableModelRepository: DownloadableModelRepository,
77
) : DownloadModelUseCase {
88

9-
override fun invoke(id: String) = downloadableModelRepository.download(id)
9+
override fun invoke(id: String, url: String) = downloadableModelRepository.download(id, url)
1010
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.shifthackz.aisdv1.domain.usecase.downloadable
2+
3+
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
4+
import io.reactivex.rxjava3.core.Single
5+
6+
interface GetLocalModelUseCase {
7+
operator fun invoke(id: String): Single<LocalAiModel>
8+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package com.shifthackz.aisdv1.domain.usecase.downloadable
2+
3+
import com.shifthackz.aisdv1.domain.datasource.DownloadableModelDataSource
4+
import com.shifthackz.aisdv1.domain.entity.LocalAiModel
5+
import io.reactivex.rxjava3.core.Single
6+
7+
internal class GetLocalModelUseCaseImpl(
8+
private val localDataSource: DownloadableModelDataSource.Local,
9+
) : GetLocalModelUseCase {
10+
11+
override fun invoke(id: String): Single<LocalAiModel> = localDataSource.getById(id)
12+
}

domain/src/test/java/com/shifthackz/aisdv1/domain/usecase/downloadable/DownloadModelUseCaseImplTest.kt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ class DownloadModelUseCaseImplTest {
2323

2424
@Before
2525
fun initialize() {
26-
whenever(stubRepository.download(any()))
26+
whenever(stubRepository.download(any(), any()))
2727
.thenReturn(stubDownloadStatus)
2828
}
2929

3030
@Test
3131
fun `given download running, then finishes successfully, expected final state is Complete`() {
32-
val stubObserver = useCase("5598").test()
32+
val stubObserver = useCase("5598", "https://moroz.cc/stub.zip").test()
3333

3434
stubDownloadStatus.onNext(DownloadState.Unknown)
3535

@@ -58,7 +58,7 @@ class DownloadModelUseCaseImplTest {
5858

5959
@Test
6060
fun `given download running, then fails, expected final state is Error`() {
61-
val stubObserver = useCase("5598").test()
61+
val stubObserver = useCase("5598", "https://moroz.cc/stub.zip").test()
6262

6363
stubDownloadStatus.onNext(DownloadState.Unknown)
6464

@@ -87,7 +87,7 @@ class DownloadModelUseCaseImplTest {
8787

8888
@Test
8989
fun `given download running, then fails, then user restarts download, then completes, expected state Error on 1st try, final state is Complete`() {
90-
val stubObserver = useCase("5598").test()
90+
val stubObserver = useCase("5598", "https://moroz.cc/stub.zip").test()
9191

9292
stubDownloadStatus.onNext(DownloadState.Unknown)
9393

@@ -140,10 +140,10 @@ class DownloadModelUseCaseImplTest {
140140

141141
@Test
142142
fun `given observable terminated with unexpected error, expected error value`() {
143-
whenever(stubRepository.download(any()))
143+
whenever(stubRepository.download(any(), any()))
144144
.thenReturn(Observable.error(stubTerminateException))
145145

146-
useCase("5598")
146+
useCase("5598", "https://moroz.cc/stub.zip")
147147
.test()
148148
.assertError(stubTerminateException)
149149
.await()

0 commit comments

Comments
 (0)