diff --git a/server/src/repositories/ocr.repository.ts b/server/src/repositories/ocr.repository.ts index d858af6c3b..6375c1827e 100644 --- a/server/src/repositories/ocr.repository.ts +++ b/server/src/repositories/ocr.repository.ts @@ -26,27 +26,6 @@ export class OcrRepository { }); } - @GenerateSql({ - params: [ - DummyValue.UUID, - [ - { - assetId: DummyValue.UUID, - x1: DummyValue.NUMBER, - y1: DummyValue.NUMBER, - x2: DummyValue.NUMBER, - y2: DummyValue.NUMBER, - x3: DummyValue.NUMBER, - y3: DummyValue.NUMBER, - x4: DummyValue.NUMBER, - y4: DummyValue.NUMBER, - text: DummyValue.STRING, - boxScore: DummyValue.NUMBER, - textScore: DummyValue.NUMBER, - }, - ], - ], - }) upsert(assetId: string, ocrDataList: Insertable[]) { let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId)); if (ocrDataList.length > 0) { @@ -59,6 +38,10 @@ export class OcrRepository { .values({ assetId, text: searchText }) .onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))), ); + } else { + (query as any) = query.with('deleted_search', (db) => + db.deleteFrom('ocr_search').where('assetId', '=', assetId), + ); } return query.selectNoFrom(sql`1`.as('dummy')).execute(); diff --git a/server/test/medium.factory.ts b/server/test/medium.factory.ts index bb964dd3f0..d1187fc7b2 100644 --- a/server/test/medium.factory.ts +++ b/server/test/medium.factory.ts @@ -47,6 +47,7 @@ import { VersionHistoryRepository } from 'src/repositories/version-history.repos import { DB } from 'src/schema'; import { AlbumTable } from 'src/schema/tables/album.table'; import { AssetExifTable } from 'src/schema/tables/asset-exif.table'; +import { AssetFileTable } from 'src/schema/tables/asset-file.table'; import { AssetJobStatusTable } from 'src/schema/tables/asset-job-status.table'; import { AssetTable } from 'src/schema/tables/asset.table'; import { FaceSearchTable } from 'src/schema/tables/face-search.table'; @@ -167,6 +168,11 @@ export class MediumTestContext { return { asset, result }; } + async newAssetFile(dto: Insertable) { + const result = await this.get(AssetRepository).upsertFile(dto); + return { result }; + } + async newAssetFace(dto: Partial> & { assetId: string }) { const assetFace = mediumFactory.assetFaceInsert(dto); const result = await this.get(PersonRepository).createAssetFace(assetFace); @@ -339,7 +345,6 @@ const newMockRepository = (key: ClassConstructor) => { case AssetJobRepository: case ConfigRepository: case CryptoRepository: - case MachineLearningRepository: case MemoryRepository: case NotificationRepository: case OcrRepository: @@ -390,6 +395,10 @@ const newMockRepository = (key: ClassConstructor) => { return automock(LoggingRepository, { args: [undefined, configMock], strict: false }); } + case MachineLearningRepository: { + return automock(MachineLearningRepository, { args: [{ setContext: () => {} }] }); + } + case StorageRepository: { return automock(StorageRepository, { args: [{ setContext: () => {} }] }); } diff --git a/server/test/medium/specs/services/ocr.service.spec.ts b/server/test/medium/specs/services/ocr.service.spec.ts index cf51d980ec..45c34dd09e 100644 --- a/server/test/medium/specs/services/ocr.service.spec.ts +++ b/server/test/medium/specs/services/ocr.service.spec.ts @@ -1,10 +1,13 @@ import { Kysely } from 'kysely'; +import { AssetFileType, JobStatus } from 'src/enum'; import { AssetJobRepository } from 'src/repositories/asset-job.repository'; import { AssetRepository } from 'src/repositories/asset.repository'; +import { ConfigRepository } from 'src/repositories/config.repository'; import { JobRepository } from 'src/repositories/job.repository'; import { LoggingRepository } from 'src/repositories/logging.repository'; import { MachineLearningRepository } from 'src/repositories/machine-learning.repository'; import { OcrRepository } from 'src/repositories/ocr.repository'; +import { SystemMetadataRepository } from 'src/repositories/system-metadata.repository'; import { DB } from 'src/schema'; import { OcrService } from 'src/services/ocr.service'; import { newMediumService } from 'test/medium.factory'; @@ -15,8 +18,8 @@ let defaultDatabase: Kysely; const setup = (db?: Kysely) => { return newMediumService(OcrService, { database: db || defaultDatabase, - real: [AssetRepository, AssetJobRepository, JobRepository, OcrRepository], - mock: [LoggingRepository, MachineLearningRepository], + real: [AssetRepository, AssetJobRepository, ConfigRepository, OcrRepository, SystemMetadataRepository], + mock: [JobRepository, LoggingRepository, MachineLearningRepository], }); }; @@ -34,6 +37,7 @@ describe(OcrService.name, () => { const { sut, ctx } = setup(); const { user } = await ctx.newUser(); const { asset } = await ctx.newAsset({ ownerId: user.id }); + await ctx.newAssetFile({ assetId: asset.id, type: AssetFileType.Preview, path: 'preview.jpg' }); const machineLearningMock = ctx.getMock(MachineLearningRepository); machineLearningMock.ocr.mockResolvedValue({ @@ -43,7 +47,7 @@ describe(OcrService.name, () => { textScore: [0.95], }); - await expect(sut.handleOcr({ id: asset.id })).resolves.toBe('Success'); + await expect(sut.handleOcr({ id: asset.id })).resolves.toBe(JobStatus.Success); const ocrRepository = ctx.get(OcrRepository); await expect(ocrRepository.getByAssetId(asset.id)).resolves.toEqual([ @@ -69,22 +73,30 @@ describe(OcrService.name, () => { assetId: asset.id, text: 'Test OCR', }); + await expect( + ctx.database + .selectFrom('asset_job_status') + .select('asset_job_status.ocrAt') + .where('assetId', '=', asset.id) + .executeTakeFirst(), + ).resolves.toEqual({ ocrAt: expect.any(Date) }); }); it('should handle multiple boxes', async () => { const { sut, ctx } = setup(); const { user } = await ctx.newUser(); const { asset } = await ctx.newAsset({ ownerId: user.id }); + await ctx.newAssetFile({ assetId: asset.id, type: AssetFileType.Preview, path: 'preview.jpg' }); const machineLearningMock = ctx.getMock(MachineLearningRepository); machineLearningMock.ocr.mockResolvedValue({ - box: Array.from({ length: 8 * 10 }, (_, i) => i), + box: Array.from({ length: 8 * 5 }, (_, i) => i), boxScore: [0.7, 0.67, 0.65, 0.62, 0.6], text: ['One', 'Two', 'Three', 'Four', 'Five'], textScore: [0.9, 0.89, 0.88, 0.87, 0.86], }); - await expect(sut.handleOcr({ id: asset.id })).resolves.toBe('Success'); + await expect(sut.handleOcr({ id: asset.id })).resolves.toBe(JobStatus.Success); const ocrRepository = ctx.get(OcrRepository); await expect(ocrRepository.getByAssetId(asset.id)).resolves.toEqual([ @@ -168,7 +180,64 @@ describe(OcrService.name, () => { ctx.database.selectFrom('ocr_search').selectAll().where('assetId', '=', asset.id).executeTakeFirst(), ).resolves.toEqual({ assetId: asset.id, - text: 'One Two Three Four Fivee', + text: 'One Two Three Four Five', }); + await expect( + ctx.database + .selectFrom('asset_job_status') + .select('asset_job_status.ocrAt') + .where('assetId', '=', asset.id) + .executeTakeFirst(), + ).resolves.toEqual({ ocrAt: expect.any(Date) }); + }); + + it('should handle no boxes', async () => { + const { sut, ctx } = setup(); + const { user } = await ctx.newUser(); + const { asset } = await ctx.newAsset({ ownerId: user.id }); + await ctx.newAssetFile({ assetId: asset.id, type: AssetFileType.Preview, path: 'preview.jpg' }); + + const machineLearningMock = ctx.getMock(MachineLearningRepository); + machineLearningMock.ocr.mockResolvedValue({ box: [], boxScore: [], text: [], textScore: [] }); + + await expect(sut.handleOcr({ id: asset.id })).resolves.toBe(JobStatus.Success); + + const ocrRepository = ctx.get(OcrRepository); + await expect(ocrRepository.getByAssetId(asset.id)).resolves.toEqual([]); + await expect( + ctx.database.selectFrom('ocr_search').selectAll().where('assetId', '=', asset.id).executeTakeFirst(), + ).resolves.toBeUndefined(); + await expect( + ctx.database + .selectFrom('asset_job_status') + .select('asset_job_status.ocrAt') + .where('assetId', '=', asset.id) + .executeTakeFirst(), + ).resolves.toEqual({ ocrAt: expect.any(Date) }); + }); + + it('should update existing results', async () => { + const { sut, ctx } = setup(); + const { user } = await ctx.newUser(); + const { asset } = await ctx.newAsset({ ownerId: user.id }); + await ctx.newAssetFile({ assetId: asset.id, type: AssetFileType.Preview, path: 'preview.jpg' }); + + const machineLearningMock = ctx.getMock(MachineLearningRepository); + machineLearningMock.ocr.mockResolvedValue({ + box: [10, 10, 50, 10, 50, 50, 10, 50], + boxScore: [0.99], + text: ['Test OCR'], + textScore: [0.95], + }); + await expect(sut.handleOcr({ id: asset.id })).resolves.toBe(JobStatus.Success); + + machineLearningMock.ocr.mockResolvedValue({ box: [], boxScore: [], text: [], textScore: [] }); + await expect(sut.handleOcr({ id: asset.id })).resolves.toBe(JobStatus.Success); + + const ocrRepository = ctx.get(OcrRepository); + await expect(ocrRepository.getByAssetId(asset.id)).resolves.toEqual([]); + await expect( + ctx.database.selectFrom('ocr_search').selectAll().where('assetId', '=', asset.id).executeTakeFirst(), + ).resolves.toBeUndefined(); }); });