From df36a09cd3fabaaba4d7e4dbbb1f85b8881c816d Mon Sep 17 00:00:00 2001 From: CoderKang Date: Mon, 2 Jun 2025 22:26:49 +0800 Subject: [PATCH] refactor(ocr): update OCR schema and response structure to use individual coordinates instead of bounding box, and adjust related service and repository files --- .../immich_ml/models/ocr/paddle.py | 8 +++-- machine-learning/immich_ml/schemas.py | 9 +++++- .../machine-learning.repository.ts | 12 ++++++-- server/src/repositories/ocr.repository.ts | 12 +++++--- ...ble.ts => 1748871815291-CreateOCRTable.ts} | 4 ++- server/src/schema/tables/asset-ocr.table.ts | 30 +++++++++++++------ server/src/services/ocr.service.ts | 25 +++++++--------- 7 files changed, 66 insertions(+), 34 deletions(-) rename server/src/schema/migrations/{1748864166925-CreateOCRTable.ts => 1748871815291-CreateOCRTable.ts} (64%) diff --git a/machine-learning/immich_ml/models/ocr/paddle.py b/machine-learning/immich_ml/models/ocr/paddle.py index 950a9af78e..880a38179f 100644 --- a/machine-learning/immich_ml/models/ocr/paddle.py +++ b/machine-learning/immich_ml/models/ocr/paddle.py @@ -34,13 +34,17 @@ class PaddleOCRecognizer(InferenceModel): valid_texts_and_scores = [ (text, score, box) for result in results - for text, score, box in zip(result['rec_texts'], result['rec_scores'], result['rec_boxes'].tolist()) + for text, score, box in zip(result['rec_texts'], result['rec_scores'], result['rec_polys']) if score >= self.min_score ] if not valid_texts_and_scores: return [] return [ - OCROutput(text=text, confidence=score, boundingBox={"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]}) + OCROutput( + text=text, confidence=score, + x1=box[0][0], y1=box[0][1], x2=box[1][0], y2=box[1][1], + x3=box[2][0], y3=box[2][1], x4=box[3][0], y4=box[3][1] + ) for text, score, box in valid_texts_and_scores ] diff --git a/machine-learning/immich_ml/schemas.py b/machine-learning/immich_ml/schemas.py index d6622cb5f8..e95b51b11d 100644 --- a/machine-learning/immich_ml/schemas.py +++ b/machine-learning/immich_ml/schemas.py @@ -90,7 +90,14 @@ FacialRecognitionOutput = list[DetectedFace] class OCROutput(TypedDict): text: str confidence: float - boundingBox: BoundingBox + x1: int + y1: int + x2: int + y2: int + x3: int + y3: int + x4: int + y4: int class PipelineEntry(TypedDict): diff --git a/server/src/repositories/machine-learning.repository.ts b/server/src/repositories/machine-learning.repository.ts index b5d77a3d9e..6bae7eeaaf 100644 --- a/server/src/repositories/machine-learning.repository.ts +++ b/server/src/repositories/machine-learning.repository.ts @@ -40,13 +40,19 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo export type ClipTextualResponse = { [ModelTask.SEARCH]: string }; export type OCR = { - boundingBox: BoundingBox; + x1: number; + y1: number; + x2: number; + y2: number; + x3: number; + y3: number; + x4: number; + y4: number; text: string; - confidence: number; }; export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minScore: number } } } }; -export type OcrResponse = { [ModelTask.OCR]: OCR | OCR[] } & VisualResponse; +export type OcrResponse = { [ModelTask.OCR]: OCR[] } & VisualResponse; export type FacialRecognitionRequest = { [ModelTask.FACIAL_RECOGNITION]: { diff --git a/server/src/repositories/ocr.repository.ts b/server/src/repositories/ocr.repository.ts index f21208c6d6..38c1d3e52f 100644 --- a/server/src/repositories/ocr.repository.ts +++ b/server/src/repositories/ocr.repository.ts @@ -6,10 +6,14 @@ import { DummyValue, GenerateSql } from 'src/decorators'; export interface OcrInsertData { assetId: string; - boundingBoxX1: number; - boundingBoxY1: number; - boundingBoxX2: number; - boundingBoxY2: number; + x1: number; + y1: number; + x2: number; + y2: number; + x3: number; + y3: number; + x4: number; + y4: number; text: string; } diff --git a/server/src/schema/migrations/1748864166925-CreateOCRTable.ts b/server/src/schema/migrations/1748871815291-CreateOCRTable.ts similarity index 64% rename from server/src/schema/migrations/1748864166925-CreateOCRTable.ts rename to server/src/schema/migrations/1748871815291-CreateOCRTable.ts index d3634d0516..2c1b210424 100644 --- a/server/src/schema/migrations/1748864166925-CreateOCRTable.ts +++ b/server/src/schema/migrations/1748871815291-CreateOCRTable.ts @@ -1,12 +1,14 @@ import { Kysely, sql } from 'kysely'; export async function up(db: Kysely): Promise { - await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "assetId" uuid NOT NULL, "boundingBoxX1" integer NOT NULL DEFAULT 0, "boundingBoxY1" integer NOT NULL DEFAULT 0, "boundingBoxX2" integer NOT NULL DEFAULT 0, "boundingBoxY2" integer NOT NULL DEFAULT 0, "text" text NOT NULL);`.execute(db); + await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "text" text NOT NULL);`.execute(db); await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db); await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "FK_dc592ec504976f5636e28bb84c6" FOREIGN KEY ("assetId") REFERENCES "assets" ("id") ON UPDATE CASCADE ON DELETE CASCADE;`.execute(db); + await sql`CREATE INDEX "IDX_dc592ec504976f5636e28bb84c" ON "asset_ocr" ("assetId")`.execute(db); } export async function down(db: Kysely): Promise { + await sql`DROP INDEX "IDX_dc592ec504976f5636e28bb84c";`.execute(db); await sql`ALTER TABLE "asset_ocr" DROP CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22";`.execute(db); await sql`ALTER TABLE "asset_ocr" DROP CONSTRAINT "FK_dc592ec504976f5636e28bb84c6";`.execute(db); await sql`DROP TABLE "asset_ocr";`.execute(db); diff --git a/server/src/schema/tables/asset-ocr.table.ts b/server/src/schema/tables/asset-ocr.table.ts index 8d50cc4157..90c2b23913 100644 --- a/server/src/schema/tables/asset-ocr.table.ts +++ b/server/src/schema/tables/asset-ocr.table.ts @@ -9,21 +9,33 @@ export class AssetOcrTable { @ForeignKeyColumn(() => AssetTable, { onDelete: 'CASCADE', onUpdate: 'CASCADE', - index: false, + index: true, }) assetId!: string; - @Column({ default: 0, type: 'integer' }) - boundingBoxX1!: number; + @Column({ type: 'integer' }) + x1!: number; - @Column({ default: 0, type: 'integer' }) - boundingBoxY1!: number; + @Column({ type: 'integer' }) + y1!: number; - @Column({ default: 0, type: 'integer' }) - boundingBoxX2!: number; + @Column({ type: 'integer' }) + x2!: number; - @Column({ default: 0, type: 'integer' }) - boundingBoxY2!: number; + @Column({ type: 'integer' }) + y2!: number; + + @Column({ type: 'integer' }) + x3!: number; + + @Column({ type: 'integer' }) + y3!: number; + + @Column({ type: 'integer' }) + x4!: number; + + @Column({ type: 'integer' }) + y4!: number; @Column({ type: 'text' }) text!: string; diff --git a/server/src/services/ocr.service.ts b/server/src/services/ocr.service.ts index e59b032288..15ef589524 100644 --- a/server/src/services/ocr.service.ts +++ b/server/src/services/ocr.service.ts @@ -66,14 +66,7 @@ export class OcrService extends BaseService { machineLearning.ocr ); - const resultsArray = Array.isArray(ocrResults) ? ocrResults : [ocrResults]; - const validResults = resultsArray.filter(result => - result && - result.text && - result.text.trim().length > 0 - ); - - if (validResults.length === 0) { + if (ocrResults.length === 0) { this.logger.warn(`No valid OCR results for document ${id}`); await this.assetRepository.upsertJobStatus({ assetId: asset.id, @@ -83,12 +76,16 @@ export class OcrService extends BaseService { } try { - const ocrDataList = validResults.map(result => ({ + const ocrDataList = ocrResults.map(result => ({ assetId: id, - boundingBoxX1: result.boundingBox.x1, - boundingBoxY1: result.boundingBox.y1, - boundingBoxX2: result.boundingBox.x2, - boundingBoxY2: result.boundingBox.y2, + x1: result.x1, + y1: result.y1, + x2: result.x2, + y2: result.y2, + x3: result.x3, + y3: result.y3, + x4: result.x4, + y4: result.y4, text: result.text.trim(), })); @@ -99,7 +96,7 @@ export class OcrService extends BaseService { ocrAt: new Date(), }); - this.logger.debug(`Processed ${validResults.length} OCR result(s) for ${id}`); + this.logger.debug(`Processed ${ocrResults.length} OCR result(s) for ${id}`); return JobStatus.SUCCESS; } catch (error) { this.logger.error(`Failed to insert OCR results for ${id}:`, error);