add tests

This commit is contained in:
mertalev 2025-10-15 14:16:51 -04:00
parent 0fd4d4c798
commit 88c9935c4a
No known key found for this signature in database
GPG key ID: DF6ABC77AAD98C95
15 changed files with 399 additions and 49 deletions

View file

@ -233,7 +233,7 @@ export const defaults = Object.freeze<SystemConfig>({
[QueueName.ThumbnailGeneration]: { concurrency: 3 },
[QueueName.VideoConversion]: { concurrency: 1 },
[QueueName.Notification]: { concurrency: 5 },
[QueueName.OCR]: { concurrency: 1 },
[QueueName.Ocr]: { concurrency: 1 },
},
logging: {
enabled: true,
@ -264,9 +264,9 @@ export const defaults = Object.freeze<SystemConfig>({
},
ocr: {
enabled: true,
modelName: 'PP-OCRv5_server',
modelName: 'PP-OCRv5_mobile',
minDetectionScore: 0.5,
minRecognitionScore: 0.9,
minRecognitionScore: 0.8,
maxResolution: 736,
},
},

View file

@ -95,5 +95,5 @@ export class AllJobStatusResponseDto implements Record<QueueName, JobStatusDto>
[QueueName.BackupDatabase]!: JobStatusDto;
@ApiProperty({ type: JobStatusDto })
[QueueName.OCR]!: JobStatusDto;
[QueueName.Ocr]!: JobStatusDto;
}

View file

@ -205,7 +205,7 @@ class SystemConfigJobDto implements Record<ConcurrentQueueName, JobSettingsDto>
@ValidateNested()
@IsObject()
@Type(() => JobSettingsDto)
[QueueName.OCR]!: JobSettingsDto;
[QueueName.Ocr]!: JobSettingsDto;
@ApiProperty({ type: JobSettingsDto })
@ValidateNested()

View file

@ -511,7 +511,7 @@ export enum QueueName {
Library = 'library',
Notification = 'notifications',
BackupDatabase = 'backupDatabase',
OCR = 'ocr',
Ocr = 'ocr',
}
export enum JobName {
@ -586,8 +586,8 @@ export enum JobName {
VersionCheck = 'VersionCheck',
// OCR
QUEUE_OCR = 'queue-ocr',
OCR = 'ocr',
OcrQueueAll = 'OcrQueueAll',
Ocr = 'Ocr',
}
export enum JobCommand {

View file

@ -220,9 +220,6 @@ export class JobRepository {
case JobName.FacialRecognitionQueueAll: {
return { jobId: JobName.FacialRecognitionQueueAll };
}
case JobName.QUEUE_OCR: {
return { jobId: JobName.QUEUE_OCR };
}
default: {
return null;
}

View file

@ -218,6 +218,17 @@ export class MachineLearningRepository {
return response[ModelTask.SEARCH];
}
async ocr(imagePath: string, { modelName, minDetectionScore, minRecognitionScore, maxResolution }: OcrOptions) {
const request = {
[ModelTask.OCR]: {
[ModelType.DETECTION]: { modelName, options: { minScore: minDetectionScore, maxResolution } },
[ModelType.RECOGNITION]: { modelName, options: { minScore: minRecognitionScore } },
},
};
const response = await this.predict<OcrResponse>({ imagePath }, request);
return response[ModelTask.OCR];
}
private async getFormData(payload: ModelPayload, config: MachineLearningRequest): Promise<FormData> {
const formData = new FormData();
formData.append('entries', JSON.stringify(config));
@ -233,19 +244,4 @@ export class MachineLearningRepository {
return formData;
}
async ocr(
urls: string[],
imagePath: string,
{ modelName, minDetectionScore, minRecognitionScore, maxResolution }: OcrOptions,
) {
const request = {
[ModelTask.OCR]: {
[ModelType.DETECTION]: { modelName, options: { minScore: minDetectionScore, maxResolution } },
[ModelType.RECOGNITION]: { modelName, options: { minScore: minRecognitionScore } },
},
};
const response = await this.predict<OcrResponse>({ imagePath }, request);
return response[ModelTask.OCR];
}
}

View file

@ -24,7 +24,7 @@ describe(JobService.name, () => {
it('should update concurrency', () => {
sut.onConfigUpdate({ newConfig: defaults, oldConfig: {} as SystemConfig });
expect(mocks.job.setConcurrency).toHaveBeenCalledTimes(15);
expect(mocks.job.setConcurrency).toHaveBeenCalledTimes(16);
expect(mocks.job.setConcurrency).toHaveBeenNthCalledWith(5, QueueName.FacialRecognition, 1);
expect(mocks.job.setConcurrency).toHaveBeenNthCalledWith(7, QueueName.DuplicateDetection, 1);
expect(mocks.job.setConcurrency).toHaveBeenNthCalledWith(8, QueueName.BackgroundTask, 5);
@ -98,6 +98,7 @@ describe(JobService.name, () => {
[QueueName.Library]: expectedJobStatus,
[QueueName.Notification]: expectedJobStatus,
[QueueName.BackupDatabase]: expectedJobStatus,
[QueueName.Ocr]: expectedJobStatus,
});
});
});
@ -270,12 +271,12 @@ describe(JobService.name, () => {
},
{
item: { name: JobName.AssetGenerateThumbnails, data: { id: 'asset-1', source: 'upload' } },
jobs: [JobName.SmartSearch, JobName.AssetDetectFaces],
jobs: [JobName.SmartSearch, JobName.AssetDetectFaces, JobName.Ocr],
stub: [assetStub.livePhotoStillAsset],
},
{
item: { name: JobName.AssetGenerateThumbnails, data: { id: 'asset-1', source: 'upload' } },
jobs: [JobName.SmartSearch, JobName.AssetDetectFaces, JobName.AssetEncodeVideo],
jobs: [JobName.SmartSearch, JobName.AssetDetectFaces, JobName.Ocr, JobName.AssetEncodeVideo],
stub: [assetStub.video],
},
{

View file

@ -237,12 +237,8 @@ export class JobService extends BaseService {
return this.jobRepository.queue({ name: JobName.DatabaseBackup, data: { force } });
}
case QueueName.OCR: {
return this.jobRepository.queue({ name: JobName.QUEUE_OCR, data: { force } });
}
case QueueName.OCR: {
return this.jobRepository.queue({ name: JobName.QUEUE_OCR, data: { force } });
case QueueName.Ocr: {
return this.jobRepository.queue({ name: JobName.OcrQueueAll, data: { force } });
}
default: {
@ -361,7 +357,7 @@ export class JobService extends BaseService {
const jobs: JobItem[] = [
{ name: JobName.SmartSearch, data: item.data },
{ name: JobName.AssetDetectFaces, data: item.data },
{ name: JobName.OCR, data: item.data },
{ name: JobName.Ocr, data: item.data },
];
if (asset.type === AssetType.Video) {

View file

@ -0,0 +1,177 @@
import { AssetVisibility, ImmichWorker, JobName, JobStatus } from 'src/enum';
import { OcrService } from 'src/services/ocr.service';
import { assetStub } from 'test/fixtures/asset.stub';
import { systemConfigStub } from 'test/fixtures/system-config.stub';
import { makeStream, newTestService, ServiceMocks } from 'test/utils';
describe(OcrService.name, () => {
let sut: OcrService;
let mocks: ServiceMocks;
beforeEach(() => {
({ sut, mocks } = newTestService(OcrService));
mocks.config.getWorker.mockReturnValue(ImmichWorker.Microservices);
});
it('should work', () => {
expect(sut).toBeDefined();
});
describe('handleQueueOcr', () => {
it('should do nothing if machine learning is disabled', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.machineLearningDisabled);
await sut.handleQueueOcr({ force: false });
expect(mocks.database.setDimensionSize).not.toHaveBeenCalled();
});
it('should queue the assets without ocr', async () => {
mocks.assetJob.streamForOcrJob.mockReturnValue(makeStream([assetStub.image]));
await sut.handleQueueOcr({ force: false });
expect(mocks.job.queueAll).toHaveBeenCalledWith([{ name: JobName.Ocr, data: { id: assetStub.image.id } }]);
expect(mocks.assetJob.streamForOcrJob).toHaveBeenCalledWith(false);
});
it('should queue all the assets', async () => {
mocks.assetJob.streamForOcrJob.mockReturnValue(makeStream([assetStub.image]));
await sut.handleQueueOcr({ force: true });
expect(mocks.job.queueAll).toHaveBeenCalledWith([{ name: JobName.Ocr, data: { id: assetStub.image.id } }]);
expect(mocks.assetJob.streamForOcrJob).toHaveBeenCalledWith(true);
});
});
describe('handleOcr', () => {
it('should do nothing if machine learning is disabled', async () => {
mocks.systemMetadata.get.mockResolvedValue(systemConfigStub.machineLearningDisabled);
expect(await sut.handleOcr({ id: '123' })).toEqual(JobStatus.Skipped);
expect(mocks.asset.getByIds).not.toHaveBeenCalled();
expect(mocks.machineLearning.encodeImage).not.toHaveBeenCalled();
});
it('should skip assets without a resize path', async () => {
mocks.assetJob.getForOcr.mockResolvedValue({ visibility: AssetVisibility.Timeline, previewFile: null });
expect(await sut.handleOcr({ id: assetStub.noResizePath.id })).toEqual(JobStatus.Failed);
expect(mocks.ocr.upsert).not.toHaveBeenCalled();
expect(mocks.machineLearning.ocr).not.toHaveBeenCalled();
});
it('should save the returned objects', async () => {
mocks.machineLearning.ocr.mockResolvedValue({
box: [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160],
boxScore: [0.9, 0.8],
text: ['One Two Three', 'Four Five'],
textScore: [0.95, 0.85],
});
mocks.assetJob.getForOcr.mockResolvedValue({
visibility: AssetVisibility.Timeline,
previewFile: assetStub.image.files[1].path,
});
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);
expect(mocks.machineLearning.ocr).toHaveBeenCalledWith(
'/uploads/user-id/thumbs/path.jpg',
expect.objectContaining({
modelName: 'PP-OCRv5_mobile',
minDetectionScore: 0.5,
minRecognitionScore: 0.8,
maxResolution: 736,
}),
);
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, [
{
assetId: assetStub.image.id,
boxScore: 0.9,
text: 'One Two Three',
textScore: 0.95,
x1: 10,
y1: 20,
x2: 30,
y2: 40,
x3: 50,
y3: 60,
x4: 70,
y4: 80,
},
{
assetId: assetStub.image.id,
boxScore: 0.8,
text: 'Four Five',
textScore: 0.85,
x1: 90,
y1: 100,
x2: 110,
y2: 120,
x3: 130,
y3: 140,
x4: 150,
y4: 160,
},
]);
});
it('should apply config settings', async () => {
mocks.systemMetadata.get.mockResolvedValue({
machineLearning: {
enabled: true,
ocr: {
modelName: 'PP-OCRv5_server',
enabled: true,
minDetectionScore: 0.8,
minRecognitionScore: 0.9,
maxResolution: 1500,
},
},
});
mocks.machineLearning.ocr.mockResolvedValue({ box: [], boxScore: [], text: [], textScore: [] });
mocks.assetJob.getForOcr.mockResolvedValue({
visibility: AssetVisibility.Timeline,
previewFile: assetStub.image.files[1].path,
});
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Success);
expect(mocks.machineLearning.ocr).toHaveBeenCalledWith(
'/uploads/user-id/thumbs/path.jpg',
expect.objectContaining({
modelName: 'PP-OCRv5_server',
minDetectionScore: 0.8,
minRecognitionScore: 0.9,
maxResolution: 1500,
}),
);
expect(mocks.ocr.upsert).toHaveBeenCalledWith(assetStub.image.id, []);
});
it('should skip invisible assets', async () => {
mocks.assetJob.getForOcr.mockResolvedValue({
visibility: AssetVisibility.Hidden,
previewFile: assetStub.image.files[1].path,
});
expect(await sut.handleOcr({ id: assetStub.livePhotoMotionAsset.id })).toEqual(JobStatus.Skipped);
expect(mocks.machineLearning.ocr).not.toHaveBeenCalled();
expect(mocks.ocr.upsert).not.toHaveBeenCalled();
});
it('should fail if asset could not be found', async () => {
mocks.assetJob.getForOcr.mockResolvedValue(void 0);
expect(await sut.handleOcr({ id: assetStub.image.id })).toEqual(JobStatus.Failed);
expect(mocks.machineLearning.ocr).not.toHaveBeenCalled();
expect(mocks.ocr.upsert).not.toHaveBeenCalled();
});
});
});

View file

@ -9,8 +9,8 @@ import { isOcrEnabled } from 'src/utils/misc';
@Injectable()
export class OcrService extends BaseService {
@OnJob({ name: JobName.QUEUE_OCR, queue: QueueName.OCR })
async handleQueueOcr({ force, nightly }: JobOf<JobName.QUEUE_OCR>): Promise<JobStatus> {
@OnJob({ name: JobName.OcrQueueAll, queue: QueueName.Ocr })
async handleQueueOcr({ force }: JobOf<JobName.OcrQueueAll>): Promise<JobStatus> {
const { machineLearning } = await this.getConfig({ withCache: false });
if (!isOcrEnabled(machineLearning)) {
return JobStatus.Skipped;
@ -24,7 +24,7 @@ export class OcrService extends BaseService {
const assets = this.assetJobRepository.streamForOcrJob(force);
for await (const asset of assets) {
jobs.push({ name: JobName.OCR, data: { id: asset.id } });
jobs.push({ name: JobName.Ocr, data: { id: asset.id } });
if (jobs.length >= JOBS_ASSET_PAGINATION_SIZE) {
await this.jobRepository.queueAll(jobs);
@ -36,8 +36,8 @@ export class OcrService extends BaseService {
return JobStatus.Success;
}
@OnJob({ name: JobName.OCR, queue: QueueName.OCR })
async handleOcr({ id }: JobOf<JobName.OCR>): Promise<JobStatus> {
@OnJob({ name: JobName.Ocr, queue: QueueName.Ocr })
async handleOcr({ id }: JobOf<JobName.Ocr>): Promise<JobStatus> {
const { machineLearning } = await this.getConfig({ withCache: true });
if (!isOcrEnabled(machineLearning)) {
return JobStatus.Skipped;
@ -52,11 +52,7 @@ export class OcrService extends BaseService {
return JobStatus.Skipped;
}
const ocrResults = await this.machineLearningRepository.ocr(
machineLearning.urls,
asset.previewFile,
machineLearning.ocr,
);
const ocrResults = await this.machineLearningRepository.ocr(asset.previewFile, machineLearning.ocr);
await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults));
@ -66,7 +62,7 @@ export class OcrService extends BaseService {
return JobStatus.Success;
}
parseOcrResults(id: string, { box, boxScore, text, textScore }: OCR) {
private parseOcrResults(id: string, { box, boxScore, text, textScore }: OCR) {
const ocrDataList = [];
for (let i = 0; i < text.length; i++) {
const boxOffset = i * 8;

View file

@ -141,6 +141,7 @@ describe(ServerService.name, () => {
reverseGeocoding: true,
oauth: false,
oauthAutoLaunch: false,
ocr: true,
passwordLogin: true,
search: true,
sidecar: true,

View file

@ -39,6 +39,7 @@ const updatedConfig = Object.freeze<SystemConfig>({
[QueueName.ThumbnailGeneration]: { concurrency: 3 },
[QueueName.VideoConversion]: { concurrency: 1 },
[QueueName.Notification]: { concurrency: 5 },
[QueueName.Ocr]: { concurrency: 1 },
},
backup: {
database: {
@ -102,6 +103,13 @@ const updatedConfig = Object.freeze<SystemConfig>({
maxDistance: 0.5,
minFaces: 3,
},
ocr: {
enabled: true,
modelName: 'PP-OCRv5_mobile',
minDetectionScore: 0.5,
minRecognitionScore: 0.8,
maxResolution: 736,
},
},
map: {
enabled: true,

View file

@ -373,8 +373,8 @@ export type JobItem =
| { name: JobName.VersionCheck; data: IBaseJob }
// OCR
| { name: JobName.QUEUE_OCR; data: INightlyJob }
| { name: JobName.OCR; data: IEntityJob };
| { name: JobName.OcrQueueAll; data: IBaseJob }
| { name: JobName.Ocr; data: IEntityJob };
export type VectorExtension = (typeof VECTOR_EXTENSIONS)[number];

View file

@ -0,0 +1,174 @@
import { Kysely } from 'kysely';
import { AssetJobRepository } from 'src/repositories/asset-job.repository';
import { AssetRepository } from 'src/repositories/asset.repository';
import { JobRepository } from 'src/repositories/job.repository';
import { LoggingRepository } from 'src/repositories/logging.repository';
import { MachineLearningRepository } from 'src/repositories/machine-learning.repository';
import { OcrRepository } from 'src/repositories/ocr.repository';
import { DB } from 'src/schema';
import { OcrService } from 'src/services/ocr.service';
import { newMediumService } from 'test/medium.factory';
import { getKyselyDB } from 'test/utils';
let defaultDatabase: Kysely<DB>;
const setup = (db?: Kysely<DB>) => {
return newMediumService(OcrService, {
database: db || defaultDatabase,
real: [AssetRepository, AssetJobRepository, JobRepository, OcrRepository],
mock: [LoggingRepository, MachineLearningRepository],
});
};
beforeAll(async () => {
defaultDatabase = await getKyselyDB();
});
describe(OcrService.name, () => {
it('should work', () => {
const { sut } = setup();
expect(sut).toBeDefined();
});
it('should parse asset', async () => {
const { sut, ctx } = setup();
const { user } = await ctx.newUser();
const { asset } = await ctx.newAsset({ ownerId: user.id });
const machineLearningMock = ctx.getMock(MachineLearningRepository);
machineLearningMock.ocr.mockResolvedValue({
box: [10, 10, 50, 10, 50, 50, 10, 50],
boxScore: [0.99],
text: ['Test OCR'],
textScore: [0.95],
});
await expect(sut.handleOcr({ id: asset.id })).resolves.toBe('Success');
const ocrRepository = ctx.get(OcrRepository);
await expect(ocrRepository.getByAssetId(asset.id)).resolves.toEqual([
{
assetId: asset.id,
boxScore: 0.99,
id: expect.any(String),
text: 'Test OCR',
textScore: 0.95,
x1: 10,
y1: 10,
x2: 50,
y2: 10,
x3: 50,
y3: 50,
x4: 10,
y4: 50,
},
]);
await expect(
ctx.database.selectFrom('ocr_search').selectAll().where('assetId', '=', asset.id).executeTakeFirst(),
).resolves.toEqual({
assetId: asset.id,
text: 'Test OCR',
});
});
it('should handle multiple boxes', async () => {
const { sut, ctx } = setup();
const { user } = await ctx.newUser();
const { asset } = await ctx.newAsset({ ownerId: user.id });
const machineLearningMock = ctx.getMock(MachineLearningRepository);
machineLearningMock.ocr.mockResolvedValue({
box: Array.from({ length: 8 * 10 }, (_, i) => i),
boxScore: [0.7, 0.67, 0.65, 0.62, 0.6],
text: ['One', 'Two', 'Three', 'Four', 'Five'],
textScore: [0.9, 0.89, 0.88, 0.87, 0.86],
});
await expect(sut.handleOcr({ id: asset.id })).resolves.toBe('Success');
const ocrRepository = ctx.get(OcrRepository);
await expect(ocrRepository.getByAssetId(asset.id)).resolves.toEqual([
{
assetId: asset.id,
boxScore: 0.7,
id: expect.any(String),
text: 'One',
textScore: 0.9,
x1: 0,
y1: 1,
x2: 2,
y2: 3,
x3: 4,
y3: 5,
x4: 6,
y4: 7,
},
{
assetId: asset.id,
boxScore: 0.67,
id: expect.any(String),
text: 'Two',
textScore: 0.89,
x1: 8,
y1: 9,
x2: 10,
y2: 11,
x3: 12,
y3: 13,
x4: 14,
y4: 15,
},
{
assetId: asset.id,
boxScore: 0.65,
id: expect.any(String),
text: 'Three',
textScore: 0.88,
x1: 16,
y1: 17,
x2: 18,
y2: 19,
x3: 20,
y3: 21,
x4: 22,
y4: 23,
},
{
assetId: asset.id,
boxScore: 0.62,
id: expect.any(String),
text: 'Four',
textScore: 0.87,
x1: 24,
y1: 25,
x2: 26,
y2: 27,
x3: 28,
y3: 29,
x4: 30,
y4: 31,
},
{
assetId: asset.id,
boxScore: 0.6,
id: expect.any(String),
text: 'Five',
textScore: 0.86,
x1: 32,
y1: 33,
x2: 34,
y2: 35,
x3: 36,
y3: 37,
x4: 38,
y4: 39,
},
]);
await expect(
ctx.database.selectFrom('ocr_search').selectAll().where('assetId', '=', asset.id).executeTakeFirst(),
).resolves.toEqual({
assetId: asset.id,
text: 'One Two Three Four Fivee',
});
});
});

View file

@ -41,6 +41,7 @@ import { MetadataRepository } from 'src/repositories/metadata.repository';
import { MoveRepository } from 'src/repositories/move.repository';
import { NotificationRepository } from 'src/repositories/notification.repository';
import { OAuthRepository } from 'src/repositories/oauth.repository';
import { OcrRepository } from 'src/repositories/ocr.repository';
import { PartnerRepository } from 'src/repositories/partner.repository';
import { PersonRepository } from 'src/repositories/person.repository';
import { ProcessRepository } from 'src/repositories/process.repository';
@ -228,6 +229,7 @@ export type ServiceOverrides = {
metadata: MetadataRepository;
move: MoveRepository;
notification: NotificationRepository;
ocr: OcrRepository;
oauth: OAuthRepository;
partner: PartnerRepository;
person: PersonRepository;
@ -298,6 +300,7 @@ export const newTestService = <T extends BaseService>(
metadata: newMetadataRepositoryMock(),
move: automock(MoveRepository, { strict: false }),
notification: automock(NotificationRepository),
ocr: automock(OcrRepository, { strict: false }),
oauth: automock(OAuthRepository, { args: [loggerMock] }),
partner: automock(PartnerRepository, { strict: false }),
person: automock(PersonRepository, { strict: false }),
@ -350,6 +353,7 @@ export const newTestService = <T extends BaseService>(
overrides.move || (mocks.move as As<MoveRepository>),
overrides.notification || (mocks.notification as As<NotificationRepository>),
overrides.oauth || (mocks.oauth as As<OAuthRepository>),
overrides.ocr || (mocks.ocr as As<OcrRepository>),
overrides.partner || (mocks.partner as As<PartnerRepository>),
overrides.person || (mocks.person as As<PersonRepository>),
overrides.process || (mocks.process as As<ProcessRepository>),