diff --git a/machine-learning/immich_ml/main.py b/machine-learning/immich_ml/main.py
index e884af0e69..62d8d1ec8e 100644
--- a/machine-learning/immich_ml/main.py
+++ b/machine-learning/immich_ml/main.py
@@ -183,7 +183,10 @@ async def run_inference(payload: Image | str, entries: InferenceEntries) -> Infe
     response: InferenceResponse = {}
 
     async def _run_inference(entry: InferenceEntry) -> None:
-        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
+        model = await model_cache.get(
+            entry["name"], entry["type"], entry["task"],
+            ttl=settings.model_ttl, **entry["options"]
+        )
         inputs = [payload]
         for dep in model.depends:
             try:
diff --git a/machine-learning/immich_ml/models/cache.py b/machine-learning/immich_ml/models/cache.py
index d8f9ca81bd..cd4e3e59b6 100644
--- a/machine-learning/immich_ml/models/cache.py
+++ b/machine-learning/immich_ml/models/cache.py
@@ -38,7 +38,13 @@ class ModelCache:
     async def get(
         self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
     ) -> InferenceModel:
-        key = f"{model_name}{model_type}{model_task}"
+        config_key = ""
+        if model_type == ModelType.OCR and model_task == ModelTask.OCR:
+            orientation = model_kwargs.get("orientationClassifyEnabled", True)
+            unwarping = model_kwargs.get("unwarpingEnabled", True)
+            config_key = f"_o{orientation}_u{unwarping}"
+
+        key = f"{model_name}{model_type}{model_task}{config_key}"
 
         async with OptimisticLock(self.cache, key) as lock:
             model: InferenceModel | None = await self.cache.get(key)
diff --git a/machine-learning/immich_ml/models/constants.py b/machine-learning/immich_ml/models/constants.py
index 604b9a304f..4879532a03 100644
--- a/machine-learning/immich_ml/models/constants.py
+++ b/machine-learning/immich_ml/models/constants.py
@@ -76,7 +76,8 @@ _INSIGHTFACE_MODELS = {
 
 
 _PADDLE_MODELS = {
-    "paddle",
+    "PP-OCRv5_server",
+    "PP-OCRv5_mobile",
 }
 
 SUPPORTED_PROVIDERS = [
diff --git a/machine-learning/immich_ml/models/ocr/paddle.py b/machine-learning/immich_ml/models/ocr/paddle.py
index 1cc82761c9..950a9af78e 100644
--- a/machine-learning/immich_ml/models/ocr/paddle.py
+++ b/machine-learning/immich_ml/models/ocr/paddle.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, List
 
 import numpy as np
 from numpy.typing import NDArray
@@ -14,34 +14,33 @@ class PaddleOCRecognizer(InferenceModel):
 
     def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None:
         self.min_score = model_kwargs.pop("minScore", min_score)
+        self.orientation_classify_enabled = model_kwargs.pop("orientationClassifyEnabled", True)
+        self.unwarping_enabled = model_kwargs.pop("unwarpingEnabled", True)
         super().__init__(model_name, **model_kwargs)
         self._load()
         self.loaded = True
 
-    def _load(self) -> None:
-        try:
-            self.model = PaddleOCR(
-                use_doc_orientation_classify=False,
-                use_doc_unwarping=False,
-                use_textline_orientation=False
-            )
-        except Exception as e:
-            print(f"Error loading PaddleOCR model: {e}")
-            raise e
+    def _load(self) -> None:
+        self.model = PaddleOCR(
+            text_detection_model_name=f"{self.model_name}_det",
+            text_recognition_model_name=f"{self.model_name}_rec",
+            use_doc_orientation_classify=self.orientation_classify_enabled,
+            use_doc_unwarping=self.unwarping_enabled,
+        )
 
-    def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> OCROutput:
+    def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> List[OCROutput]:
         inputs = decode_cv2(inputs)
        results = 
self.model.predict(inputs) valid_texts_and_scores = [ - (text, score) + (text, score, box) for result in results - for text, score in zip(result['rec_texts'], result['rec_scores']) - if score > self.min_score + for text, score, box in zip(result['rec_texts'], result['rec_scores'], result['rec_boxes'].tolist()) + if score >= self.min_score ] if not valid_texts_and_scores: - return OCROutput(text="", confidence=0.0) - texts, scores = zip(*valid_texts_and_scores) - return OCROutput( - text="".join(texts), - confidence=sum(scores) / len(scores) - ) + return [] + + return [ + OCROutput(text=text, confidence=score, boundingBox={"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]}) + for text, score, box in valid_texts_and_scores + ] diff --git a/machine-learning/immich_ml/schemas.py b/machine-learning/immich_ml/schemas.py index ee2f21fa80..d6622cb5f8 100644 --- a/machine-learning/immich_ml/schemas.py +++ b/machine-learning/immich_ml/schemas.py @@ -90,6 +90,7 @@ FacialRecognitionOutput = list[DetectedFace] class OCROutput(TypedDict): text: str confidence: float + boundingBox: BoundingBox class PipelineEntry(TypedDict): diff --git a/open-api/immich-openapi-specs.json b/open-api/immich-openapi-specs.json index a02427c563..97a7af6715 100644 --- a/open-api/immich-openapi-specs.json +++ b/open-api/immich-openapi-specs.json @@ -12926,12 +12926,20 @@ }, "modelName": { "type": "string" + }, + "orientationClassifyEnabled": { + "type": "boolean" + }, + "unwarpingEnabled": { + "type": "boolean" } }, "required": [ "enabled", "minScore", - "modelName" + "modelName", + "orientationClassifyEnabled", + "unwarpingEnabled" ], "type": "object" }, diff --git a/server/src/config.ts b/server/src/config.ts index b26735e09e..ce6eceb6fe 100644 --- a/server/src/config.ts +++ b/server/src/config.ts @@ -73,6 +73,8 @@ export interface SystemConfig { enabled: boolean; modelName: string; minScore: number; + unwarpingEnabled: boolean; + orientationClassifyEnabled: boolean; }; }; map: { @@ -250,8 +252,10 @@ export const defaults = Object.freeze({ }, ocr: { enabled: true, - modelName: 'paddle', + modelName: 'PP-OCRv5_server', minScore: 0.9, + unwarpingEnabled: false, + orientationClassifyEnabled: false, }, }, map: { diff --git a/server/src/dtos/model-config.dto.ts b/server/src/dtos/model-config.dto.ts index 3545699252..8f63ed8a54 100644 --- a/server/src/dtos/model-config.dto.ts +++ b/server/src/dtos/model-config.dto.ts @@ -54,4 +54,10 @@ export class OcrConfig extends ModelConfig { @Type(() => Number) @ApiProperty({ type: 'number', format: 'double' }) minScore!: number; + + @ValidateBoolean() + unwarpingEnabled!: boolean; + + @ValidateBoolean() + orientationClassifyEnabled!: boolean; } diff --git a/server/src/repositories/asset-job.repository.ts b/server/src/repositories/asset-job.repository.ts index 00a11b0529..d78ffb37bc 100644 --- a/server/src/repositories/asset-job.repository.ts +++ b/server/src/repositories/asset-job.repository.ts @@ -355,10 +355,8 @@ export class AssetJobRepository { .select(['assets.id']) .$if(!force, (qb) => qb - .leftJoin('asset_job_status', 'asset_job_status.assetId', 'assets.id') - .where((eb) => - eb.or([eb('asset_job_status.ocrAt', 'is', null), eb('asset_job_status.assetId', 'is', null)]), - ) + .innerJoin('asset_job_status', 'asset_job_status.assetId', 'assets.id') + .where('asset_job_status.ocrAt', 'is', null) .where('assets.visibility', '!=', AssetVisibility.HIDDEN), ) .where('assets.deletedAt', 'is', null) diff --git a/server/src/repositories/machine-learning.repository.ts 
b/server/src/repositories/machine-learning.repository.ts
index 2ad32a3e1d..b5d77a3d9e 100644
--- a/server/src/repositories/machine-learning.repository.ts
+++ b/server/src/repositories/machine-learning.repository.ts
@@ -31,7 +31,7 @@ export type ModelPayload = { imagePath: string } | { text: string };
 type ModelOptions = { modelName: string };
 
 export type FaceDetectionOptions = ModelOptions & { minScore: number };
-export type OcrOptions = ModelOptions & { minScore: number };
+export type OcrOptions = ModelOptions & { minScore: number; unwarpingEnabled: boolean; orientationClassifyEnabled: boolean };
 
 type VisualResponse = { imageHeight: number; imageWidth: number };
 export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
 export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
 
@@ -40,12 +40,13 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo
 export type ClipTextualResponse = { [ModelTask.SEARCH]: string };
 
 export type OCR = {
+  boundingBox: BoundingBox;
   text: string;
   confidence: number;
 };
 
 export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minScore: number } } } };
-export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse;
+export type OcrResponse = { [ModelTask.OCR]: OCR | OCR[] } & VisualResponse;
 
 export type FacialRecognitionRequest = {
   [ModelTask.FACIAL_RECOGNITION]: {
@@ -203,8 +204,8 @@ export class MachineLearningRepository {
     return formData;
   }
 
-  async ocr(urls: string[], imagePath: string, { modelName, minScore }: OcrOptions) {
-    const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minScore } } } };
+  async ocr(urls: string[], imagePath: string, { modelName, minScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) {
+    const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minScore, unwarpingEnabled, orientationClassifyEnabled } } } };
     const response = await this.predict(urls, { imagePath }, request);
     return response[ModelTask.OCR];
   }
diff --git a/server/src/repositories/ocr.repository.ts b/server/src/repositories/ocr.repository.ts
index 182a592e06..f21208c6d6 100644
--- a/server/src/repositories/ocr.repository.ts
+++ b/server/src/repositories/ocr.repository.ts
@@ -3,72 +3,41 @@
 import { Kysely, sql } from 'kysely';
 import { InjectKysely } from 'nestjs-kysely';
 import { DB } from 'src/db';
 import { DummyValue, GenerateSql } from 'src/decorators';
-import { OcrEntity } from 'src/entities/ocr.entity';
+export interface OcrInsertData {
+  assetId: string;
+  boundingBoxX1: number;
+  boundingBoxY1: number;
+  boundingBoxX2: number;
+  boundingBoxY2: number;
+  text: string;
+}
 
 @Injectable()
 export class OcrRepository {
   constructor(@InjectKysely() private db: Kysely<DB>) {}
 
   @GenerateSql({ params: [DummyValue.UUID] })
-  getOcrById(id: string): Promise<OcrEntity | undefined> {
+  async getById(id: string) {
     return this.db
       .selectFrom('asset_ocr')
       .selectAll('asset_ocr')
       .where('asset_ocr.assetId', '=', id)
-      .executeTakeFirst() as Promise<OcrEntity | undefined>;
+      .executeTakeFirst();
   }
 
-  async insertOcrData(assetId: string, text: string): Promise<void> {
-    await this.db
-      .insertInto('asset_ocr')
-      .values({ assetId, text })
-      .execute();
-  }
-
-  async deleteAllOcr(): Promise<void> {
+  async deleteAll(): Promise<void> {
     await sql`truncate ${sql.table('asset_ocr')}`.execute(this.db);
   }
 
-  getAllOcr(options: Partial<OcrEntity> = {}): AsyncIterableIterator<OcrEntity> {
-    return this.db
-      .selectFrom('asset_ocr')
-      .selectAll('asset_ocr')
-      .$if(!!options.assetId, (qb) => qb.where('asset_ocr.assetId', '=', options.assetId!))
-      .stream() as AsyncIterableIterator<OcrEntity>;
-  }
-
-
-  @GenerateSql()
-  async getLatestOcrDate(): Promise<string | undefined> {
-    const result = (await this.db
-      .selectFrom('asset_job_status')
-      .select((eb) => sql`${eb.fn.max('asset_job_status.ocrAt')}::text`.as('latestDate'))
-      .executeTakeFirst()) as { latestDate: string } | undefined;
-
-    return result?.latestDate;
-  }
-
-  async updateOcrData(id: string, ocrData: string): Promise<void> {
+  async insertMany(ocrDataList: OcrInsertData[]): Promise<void> {
+    if (ocrDataList.length === 0) {
+      return;
+    }
+
     await this.db
-      .updateTable('asset_ocr')
-      .set({ text: ocrData })
-      .where('id', '=', id)
-      .execute();
-  }
-
-  getOcrWithoutText(): Promise<OcrEntity[]> {
-    return this.db
-      .selectFrom('asset_ocr')
-      .selectAll('asset_ocr')
-      .where('text', 'is', null)
-      .execute() as Promise<OcrEntity[]>;
-  }
-
-  async delete(ocr: OcrEntity[]): Promise<void> {
-    await this.db
-      .deleteFrom('asset_ocr')
-      .where('id', 'in', ocr.map((o) => o.id))
+      .insertInto('asset_ocr')
+      .values(ocrDataList)
       .execute();
   }
 }
diff --git a/server/src/repositories/search.repository.ts b/server/src/repositories/search.repository.ts
index aef9712a70..bacd953c92 100644
--- a/server/src/repositories/search.repository.ts
+++ b/server/src/repositories/search.repository.ts
@@ -321,19 +321,14 @@
       throw new Error(`Invalid value for 'size': ${pagination.size}`);
     }
 
-    const items = await this.db
-      .selectFrom('asset_ocr')
-      .selectAll()
-      .innerJoin('assets', 'assets.id', 'asset_ocr.assetId')
-      .where('assets.ownerId', '=', anyUuid(options.userIds))
+    const items = await searchAssetBuilder(this.db, options)
+      .innerJoin('asset_ocr', 'assets.id', 'asset_ocr.assetId')
       .where('asset_ocr.text', 'ilike', `%${options.ocr}%`)
       .limit(pagination.size + 1)
       .offset((pagination.page - 1) * pagination.size)
-      .execute() as any;
+      .execute();
 
-    const hasNextPage = items.length > pagination.size;
-    items.splice(pagination.size);
-    return { items, hasNextPage };
+    return paginationHelper(items, pagination.size);
   }
 
   @GenerateSql({
diff --git a/server/src/services/ocr.service.ts b/server/src/services/ocr.service.ts
index cba16bf942..e59b032288 100644
--- a/server/src/services/ocr.service.ts
+++ b/server/src/services/ocr.service.ts
@@ -13,12 +13,6 @@ import { isOcrEnabled } from 'src/utils/misc';
 
 @Injectable()
 export class OcrService extends BaseService {
-  @OnJob({ name: JobName.OCR_CLEANUP, queue: QueueName.BACKGROUND_TASK })
-  async handleOcrCleanup(): Promise<JobStatus> {
-    const ocr = await this.ocrRepository.getOcrWithoutText();
-    await this.ocrRepository.delete(ocr);
-    return JobStatus.SUCCESS;
-  }
 
   @OnJob({ name: JobName.QUEUE_OCR, queue: QueueName.OCR })
   async handleQueueOcr({ force, nightly }: JobOf<JobName.QUEUE_OCR>): Promise<JobStatus> {
@@ -28,7 +22,7 @@ export class OcrService extends BaseService {
     }
 
     if (force) {
-      await this.ocrRepository.deleteAllOcr();
+      await this.ocrRepository.deleteAll();
     }
 
     let jobs: JobItem[] = [];
@@ -44,11 +38,6 @@ export class OcrService extends BaseService {
     }
 
     await this.jobRepository.queueAll(jobs);
-
-    if (force === undefined) {
-      await this.jobRepository.queue({ name: JobName.OCR_CLEANUP });
-    }
-
     return JobStatus.SUCCESS;
   }
 
@@ -77,8 +66,15 @@ export class OcrService extends BaseService {
       machineLearning.ocr
     );
 
-    if (!ocrResults || ocrResults.text.length === 0) {
-      this.logger.warn(`No OCR results for document ${id}`);
+    const resultsArray = Array.isArray(ocrResults) ? 
ocrResults : [ocrResults]; + const validResults = resultsArray.filter(result => + result && + result.text && + result.text.trim().length > 0 + ); + + if (validResults.length === 0) { + this.logger.warn(`No valid OCR results for document ${id}`); await this.assetRepository.upsertJobStatus({ assetId: asset.id, ocrAt: new Date(), @@ -86,23 +82,29 @@ export class OcrService extends BaseService { return JobStatus.SUCCESS; } - this.logger.debug(`OCR ${id} has OCR results`); + try { + const ocrDataList = validResults.map(result => ({ + assetId: id, + boundingBoxX1: result.boundingBox.x1, + boundingBoxY1: result.boundingBox.y1, + boundingBoxX2: result.boundingBox.x2, + boundingBoxY2: result.boundingBox.y2, + text: result.text.trim(), + })); - const ocr = await this.ocrRepository.getOcrById(id); - if (ocr) { - this.logger.debug(`Updating OCR for ${id}`); - await this.ocrRepository.updateOcrData(id, ocrResults.text); - } else { - this.logger.debug(`Inserting OCR for ${id}`); - await this.ocrRepository.insertOcrData(id, ocrResults.text); + await this.ocrRepository.insertMany(ocrDataList); + + await this.assetRepository.upsertJobStatus({ + assetId: asset.id, + ocrAt: new Date(), + }); + + this.logger.debug(`Processed ${validResults.length} OCR result(s) for ${id}`); + return JobStatus.SUCCESS; + } catch (error) { + this.logger.error(`Failed to insert OCR results for ${id}:`, error); + return JobStatus.FAILED; } - - await this.assetRepository.upsertJobStatus({ - assetId: asset.id, - ocrAt: new Date(), - }); - this.logger.debug(`Processed OCR for ${id}`); - return JobStatus.SUCCESS; } } \ No newline at end of file diff --git a/server/src/services/search.service.ts b/server/src/services/search.service.ts index e9ba4bcb46..5dd8c3eb60 100644 --- a/server/src/services/search.service.ts +++ b/server/src/services/search.service.ts @@ -23,7 +23,7 @@ import { AssetOrder, AssetVisibility, Permission } from 'src/enum'; import { BaseService } from 'src/services/base.service'; import { requireElevatedPermission } from 'src/utils/access'; import { getMyPartnerIds } from 'src/utils/asset.util'; -import { isSmartSearchEnabled, isOcrEnabled } from 'src/utils/misc'; +import { isOcrEnabled, isSmartSearchEnabled } from 'src/utils/misc'; @Injectable() export class SearchService extends BaseService { diff --git a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte index 6265249f16..9ab8e88eb7 100644 --- a/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte +++ b/web/src/lib/components/admin-page/settings/machine-learning-settings/machine-learning-settings.svelte @@ -231,13 +231,30 @@ disabled={disabled || !config.machineLearning.enabled} /> + + + + +
+
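
Note on the cache change above (not part of the patch): a standalone sketch of the key derivation that ModelCache.get now performs, so the intended effect can be checked in isolation. Plain strings stand in for the ModelType/ModelTask enums; everything else mirrors the diff.

from typing import Any

def cache_key(model_name: str, model_type: str, model_task: str, **model_kwargs: Any) -> str:
    # Mirrors ModelCache.get: OCR models get a per-configuration suffix so that
    # different preprocessing settings never share a cached pipeline.
    config_key = ""
    if model_type == "ocr" and model_task == "ocr":
        orientation = model_kwargs.get("orientationClassifyEnabled", True)
        unwarping = model_kwargs.get("unwarpingEnabled", True)
        config_key = f"_o{orientation}_u{unwarping}"
    return f"{model_name}{model_type}{model_task}{config_key}"

# Same model name, different options -> distinct cache entries; toggling unwarping or
# orientation classification in the admin UI therefore loads a separately configured model.
key_a = cache_key("PP-OCRv5_server", "ocr", "ocr", orientationClassifyEnabled=True, unwarpingEnabled=False)
key_b = cache_key("PP-OCRv5_server", "ocr", "ocr", orientationClassifyEnabled=False, unwarpingEnabled=False)
assert key_a != key_b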
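Likewise, a sketch of how the reworked _predict filters per-line results and what each OCROutput entry carries before OcrService maps it onto an asset_ocr row. The input is hand-written stand-in data, not real PaddleOCR output; only the rec_texts/rec_scores/rec_boxes field names and the filtering logic come from the diff.

import numpy as np

min_score = 0.9
fake_result = {  # stand-in for one entry of PaddleOCR.predict(...)'s result list
    "rec_texts": ["INVOICE", "Total: 42.00", "smudge"],
    "rec_scores": [0.98, 0.93, 0.41],
    "rec_boxes": np.array([[10, 12, 180, 40], [10, 60, 220, 88], [5, 100, 60, 120]]),
}

outputs = [
    {"text": text, "confidence": score, "boundingBox": {"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]}}
    for text, score, box in zip(fake_result["rec_texts"], fake_result["rec_scores"], fake_result["rec_boxes"].tolist())
    if score >= min_score
]

# Only the two lines at or above minScore survive; each dict becomes one OcrInsertData
# record (text + boundingBoxX1..Y2) inserted by OcrService.handleOcr.
assert [o["text"] for o in outputs] == ["INVOICE", "Total: 42.00"]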