diff --git a/machine-learning/immich_ml/models/ocr/recognition.py b/machine-learning/immich_ml/models/ocr/recognition.py index 74b3862dca..90ba691804 100644 --- a/machine-learning/immich_ml/models/ocr/recognition.py +++ b/machine-learning/immich_ml/models/ocr/recognition.py @@ -12,7 +12,8 @@ from rapidocr.utils.typings import ModelType as RapidModelType from immich_ml.config import log, settings from immich_ml.models.base import InferenceModel -from immich_ml.schemas import ModelSession, ModelTask, ModelType +from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType +from immich_ml.sessions.ort import OrtSession from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput @@ -29,7 +30,7 @@ class TextRecognizer(InferenceModel): "text": [], "textScore": [], } - super().__init__(model_name, **model_kwargs) + super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX) def _download(self) -> None: model_info = InferSession.get_model_url( @@ -50,7 +51,8 @@ class TextRecognizer(InferenceModel): DownloadFile.run(download_params) def _load(self) -> ModelSession: - session = self._make_session(self.model_path) + # TODO: support other runtimes + session = OrtSession(self.model_path) self.model = RapidTextRecognizer( OcrOptions( session=session.session, @@ -80,7 +82,7 @@ class TextRecognizer(InferenceModel): valid_text_score_idx = text_scores > 0.5 valid_score_idx_list = valid_text_score_idx.tolist() return { - "box": boxes.reshape(-1, 8)[valid_text_score_idx], + "box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1), "text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]], "boxScore": box_scores[valid_text_score_idx], "textScore": text_scores[valid_text_score_idx], diff --git a/machine-learning/immich_ml/models/ocr/schemas.py b/machine-learning/immich_ml/models/ocr/schemas.py index bacc60bb2a..8db2fdc65c 100644 --- a/machine-learning/immich_ml/models/ocr/schemas.py +++ b/machine-learning/immich_ml/models/ocr/schemas.py @@ -9,7 +9,7 @@ from typing_extensions import TypedDict class TextDetectionOutput(TypedDict): resized: npt.NDArray[np.float32] boxes: npt.NDArray[np.float32] - scores: Iterable[float] + scores: npt.NDArray[np.float32] class TextRecognitionOutput(TypedDict): diff --git a/machine-learning/immich_ml/sessions/ort.py b/machine-learning/immich_ml/sessions/ort.py index e7d8635876..538da6b11c 100644 --- a/machine-learning/immich_ml/sessions/ort.py +++ b/machine-learning/immich_ml/sessions/ort.py @@ -14,6 +14,7 @@ from ..config import log, settings class OrtSession: + session: ort.InferenceSession def __init__( self, model_path: Path | str, diff --git a/server/src/config.ts b/server/src/config.ts index 88f1325203..64cd968db1 100644 --- a/server/src/config.ts +++ b/server/src/config.ts @@ -72,11 +72,9 @@ export interface SystemConfig { ocr: { enabled: boolean; modelName: string; - minDetectionBoxScore: number; minDetectionScore: number; minRecognitionScore: number; - unwarpingEnabled: boolean; - orientationClassifyEnabled: boolean; + maxResolution: number; }; }; map: { @@ -255,11 +253,9 @@ export const defaults = Object.freeze({ ocr: { enabled: true, modelName: 'PP-OCRv5_server', - minDetectionBoxScore: 0.6, minDetectionScore: 0.3, minRecognitionScore: 0.0, - unwarpingEnabled: false, - orientationClassifyEnabled: false, + maxResolution: 1440, }, }, map: { diff --git a/server/src/dtos/model-config.dto.ts b/server/src/dtos/model-config.dto.ts index 5baaf5bb04..527317346a 100644 --- a/server/src/dtos/model-config.dto.ts +++ b/server/src/dtos/model-config.dto.ts @@ -49,29 +49,22 @@ export class FacialRecognitionConfig extends ModelConfig { export class OcrConfig extends ModelConfig { @IsNumber() - @Min(0) - @Max(1) + @Min(1) @Type(() => Number) - @ApiProperty({ type: 'number', format: 'double' }) - minDetectionBoxScore!: number; + @ApiProperty({ type: 'integer' }) + maxResolution!: number; @IsNumber() - @Min(0) + @Min(0.1) @Max(1) @Type(() => Number) @ApiProperty({ type: 'number', format: 'double' }) minDetectionScore!: number; @IsNumber() - @Min(0) + @Min(0.1) @Max(1) @Type(() => Number) @ApiProperty({ type: 'number', format: 'double' }) minRecognitionScore!: number; - - @ValidateBoolean() - unwarpingEnabled!: boolean; - - @ValidateBoolean() - orientationClassifyEnabled!: boolean; } diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts index 633d495296..839d9f33cb 100644 --- a/server/src/repositories/machine-learning.repository.ts +++ b/server/src/repositories/machine-learning.repository.ts @@ -31,7 +31,11 @@ export type ModelPayload = { imagePath: string } | { text: string }; type ModelOptions = { modelName: string }; export type FaceDetectionOptions = ModelOptions & { minScore: number }; -export type OcrOptions = ModelOptions & { minDetectionBoxScore: number, minDetectionScore: number, minRecognitionScore: number, unwarpingEnabled: boolean, orientationClassifyEnabled: boolean }; +export type OcrOptions = ModelOptions & { + minDetectionScore: number; + minRecognitionScore: number; + maxResolution: number; +}; type VisualResponse = { imageHeight: number; imageWidth: number }; export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } }; export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse; @@ -40,20 +44,19 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo export type ClipTextualResponse = { [ModelTask.SEARCH]: string }; export type OCR = { - x1: number; - y1: number; - x2: number; - y2: number; - x3: number; - y3: number; - x4: number; - y4: number; - text: string; - confidence: number; + text: string[]; + box: number[]; + boxScore: number[]; + textScore: number[]; }; -export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minDetectionScore: number, minRecognitionScore: number } } } }; -export type OcrResponse = { [ModelTask.OCR]: OCR[] } & VisualResponse; +export type OcrRequest = { + [ModelTask.OCR]: { + [ModelType.DETECTION]: ModelOptions & { options: { minScore: number; maxResolution: number } }; + [ModelType.RECOGNITION]: ModelOptions & { options: { minScore: number } }; + }; +}; +export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse; export type FacialRecognitionRequest = { [ModelTask.FACIAL_RECOGNITION]: { @@ -211,8 +214,17 @@ export class MachineLearningRepository { return formData; } - async ocr(urls: string[], imagePath: string, { modelName, minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) { - const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled } } } }; + async ocr( + urls: string[], + imagePath: string, + { modelName, minDetectionScore, minRecognitionScore, maxResolution }: OcrOptions, + ) { + const request = { + [ModelTask.OCR]: { + [ModelType.DETECTION]: { modelName, options: { minScore: minDetectionScore, maxResolution } }, + [ModelType.RECOGNITION]: { modelName, options: { minScore: minRecognitionScore } }, + }, + }; const response = await this.predict(urls, { imagePath }, request); return response[ModelTask.OCR]; } diff --git a/server/src/repositories/ocr.repository.ts b/server/src/repositories/ocr.repository.ts index 8a9548193b..01463a7fb3 100644 --- a/server/src/repositories/ocr.repository.ts +++ b/server/src/repositories/ocr.repository.ts @@ -1,23 +1,9 @@ import { Injectable } from '@nestjs/common'; -import { Kysely, sql } from 'kysely'; +import { Insertable, Kysely, sql } from 'kysely'; import { InjectKysely } from 'nestjs-kysely'; -import { DB } from 'src/db'; +import { AssetOcr, DB } from 'src/db'; import { DummyValue, GenerateSql } from 'src/decorators'; -export interface OcrInsertData { - assetId: string; - x1: number; - y1: number; - x2: number; - y2: number; - x3: number; - y3: number; - x4: number; - y4: number; - text: string; - confidence: number; -} - @Injectable() export class OcrRepository { constructor(@InjectKysely() private db: Kysely) {} @@ -54,28 +40,26 @@ export class OcrRepository { x4: DummyValue.NUMBER, y4: DummyValue.NUMBER, text: DummyValue.STRING, - confidence: DummyValue.NUMBER, + boxScore: DummyValue.NUMBER, + textScore: DummyValue.NUMBER, }, ], ], }) - upsert(assetId: string, ocrDataList: OcrInsertData[]) { - if (ocrDataList.length === 0) { - return; + upsert(assetId: string, ocrDataList: Insertable[]) { + let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId)); + if (ocrDataList.length > 0) { + const searchText = ocrDataList.map((item) => item.text.trim()).join(' '); + (query as any) = query + .with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList)) + .with('inserted_search', (db) => + db + .insertInto('ocr_search') + .values({ assetId, text: searchText }) + .onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))), + ); } - const searchText = ocrDataList.map((item) => item.text.trim()).join(' '); - - return this.db - .with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId)) - .with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList)) - .with('inserted_search', (db) => - db - .insertInto('ocr_search') - .values({ assetId, text: searchText }) - .onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))), - ) - .selectNoFrom(sql`1`.as('dummy')) - .execute(); + return query.selectNoFrom(sql`1`.as('dummy')).execute(); } } diff --git a/server/src/schema/migrations/1748926208942-CreateAssetOCRTable.ts b/server/src/schema/migrations/1750091086107-CreateAssetOCRTable.ts similarity index 75% rename from server/src/schema/migrations/1748926208942-CreateAssetOCRTable.ts rename to server/src/schema/migrations/1750091086107-CreateAssetOCRTable.ts index f968f17dc1..b3ff4aa3f0 100644 --- a/server/src/schema/migrations/1748926208942-CreateAssetOCRTable.ts +++ b/server/src/schema/migrations/1750091086107-CreateAssetOCRTable.ts @@ -1,7 +1,7 @@ import { Kysely, sql } from 'kysely'; export async function up(db: Kysely): Promise { - await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "text" text NOT NULL, "confidence" real NOT NULL);`.execute( + await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT immich_uuid_v7(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "boxScore" real NOT NULL, "textScore" real NOT NULL, "text" text NOT NULL);`.execute( db, ); await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db); diff --git a/server/src/schema/migrations/1748929348618-CreateOCRSearchTable.ts b/server/src/schema/migrations/1750091089793-CreateOCRSearchTable.ts similarity index 100% rename from server/src/schema/migrations/1748929348618-CreateOCRSearchTable.ts rename to server/src/schema/migrations/1750091089793-CreateOCRSearchTable.ts diff --git a/server/src/schema/migrations/1748858302889-UpsertOcrAssetJobStatus.ts b/server/src/schema/migrations/1750091146366-UpsertOcrAssetJobStatus.ts similarity index 100% rename from server/src/schema/migrations/1748858302889-UpsertOcrAssetJobStatus.ts rename to server/src/schema/migrations/1750091146366-UpsertOcrAssetJobStatus.ts diff --git a/server/src/schema/tables/asset-ocr.table.ts b/server/src/schema/tables/asset-ocr.table.ts index 532f593230..b29136ccd7 100644 --- a/server/src/schema/tables/asset-ocr.table.ts +++ b/server/src/schema/tables/asset-ocr.table.ts @@ -6,13 +6,12 @@ export class AssetOcrTable { @PrimaryGeneratedColumn() id!: string; - @ForeignKeyColumn(() => AssetTable, { - onDelete: 'CASCADE', - onUpdate: 'CASCADE', - index: true, - }) + @ForeignKeyColumn(() => AssetTable, { onDelete: 'CASCADE', onUpdate: 'CASCADE' }) assetId!: string; + @Column({ type: 'text' }) + text!: string; + @Column({ type: 'integer' }) x1!: number; @@ -37,9 +36,9 @@ export class AssetOcrTable { @Column({ type: 'integer' }) y4!: number; - @Column({ type: 'text' }) - text!: string; + @Column({ type: 'real' }) + boxScore!: number; @Column({ type: 'real' }) - confidence!: number; + textScore!: number; } diff --git a/server/src/services/ocr.service.ts b/server/src/services/ocr.service.ts index ce41b9500f..969becbe97 100644 --- a/server/src/services/ocr.service.ts +++ b/server/src/services/ocr.service.ts @@ -2,6 +2,7 @@ import { Injectable } from '@nestjs/common'; import { JOBS_ASSET_PAGINATION_SIZE } from 'src/constants'; import { OnJob } from 'src/decorators'; import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum'; +import { OCR } from 'src/repositories/machine-learning.repository'; import { BaseService } from 'src/services/base.service'; import { JobItem, JobOf } from 'src/types'; import { isOcrEnabled } from 'src/utils/misc'; @@ -57,27 +58,34 @@ export class OcrService extends BaseService { machineLearning.ocr, ); - if (ocrResults.length > 0) { - const ocrDataList = ocrResults.map((result) => ({ - assetId: id, - x1: result.x1, - y1: result.y1, - x2: result.x2, - y2: result.y2, - x3: result.x3, - y3: result.y3, - x4: result.x4, - y4: result.y4, - text: result.text, - confidence: result.confidence, - })); - - await this.ocrRepository.upsert(id, ocrDataList); - } + await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults)); await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() }); - this.logger.debug(`Processed ${ocrResults.length} OCR result(s) for ${id}`); + this.logger.debug(`Processed ${ocrResults.text.length} OCR result(s) for ${id}`); return JobStatus.SUCCESS; } + + parseOcrResults(id: string, ocrResults: OCR) { + const ocrDataList = []; + for (let i = 0; i < ocrResults.text.length; i++) { + const boxOffset = i * 8; + const row = { + assetId: id, + text: ocrResults.text[i], + boxScore: ocrResults.boxScore[i], + textScore: ocrResults.textScore[i], + x1: ocrResults.box[boxOffset], + y1: ocrResults.box[boxOffset + 1], + x2: ocrResults.box[boxOffset + 2], + y2: ocrResults.box[boxOffset + 3], + x3: ocrResults.box[boxOffset + 4], + y3: ocrResults.box[boxOffset + 5], + x4: ocrResults.box[boxOffset + 6], + y4: ocrResults.box[boxOffset + 7], + }; + ocrDataList.push(row); + } + return ocrDataList; + } }