change dto

This commit is contained in:
mertalev 2025-06-13 00:39:39 -04:00
parent c59f932bf0
commit 412468989f
No known key found for this signature in database
GPG key ID: DF6ABC77AAD98C95
12 changed files with 93 additions and 98 deletions

View file

@ -12,7 +12,8 @@ from rapidocr.utils.typings import ModelType as RapidModelType
from immich_ml.config import log, settings from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelSession, ModelTask, ModelType from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
@ -29,7 +30,7 @@ class TextRecognizer(InferenceModel):
"text": [], "text": [],
"textScore": [], "textScore": [],
} }
super().__init__(model_name, **model_kwargs) super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
def _download(self) -> None: def _download(self) -> None:
model_info = InferSession.get_model_url( model_info = InferSession.get_model_url(
@ -50,7 +51,8 @@ class TextRecognizer(InferenceModel):
DownloadFile.run(download_params) DownloadFile.run(download_params)
def _load(self) -> ModelSession: def _load(self) -> ModelSession:
session = self._make_session(self.model_path) # TODO: support other runtimes
session = OrtSession(self.model_path)
self.model = RapidTextRecognizer( self.model = RapidTextRecognizer(
OcrOptions( OcrOptions(
session=session.session, session=session.session,
@ -80,7 +82,7 @@ class TextRecognizer(InferenceModel):
valid_text_score_idx = text_scores > 0.5 valid_text_score_idx = text_scores > 0.5
valid_score_idx_list = valid_text_score_idx.tolist() valid_score_idx_list = valid_text_score_idx.tolist()
return { return {
"box": boxes.reshape(-1, 8)[valid_text_score_idx], "box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]], "text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
"boxScore": box_scores[valid_text_score_idx], "boxScore": box_scores[valid_text_score_idx],
"textScore": text_scores[valid_text_score_idx], "textScore": text_scores[valid_text_score_idx],

View file

@ -9,7 +9,7 @@ from typing_extensions import TypedDict
class TextDetectionOutput(TypedDict): class TextDetectionOutput(TypedDict):
resized: npt.NDArray[np.float32] resized: npt.NDArray[np.float32]
boxes: npt.NDArray[np.float32] boxes: npt.NDArray[np.float32]
scores: Iterable[float] scores: npt.NDArray[np.float32]
class TextRecognitionOutput(TypedDict): class TextRecognitionOutput(TypedDict):

View file

@ -14,6 +14,7 @@ from ..config import log, settings
class OrtSession: class OrtSession:
session: ort.InferenceSession
def __init__( def __init__(
self, self,
model_path: Path | str, model_path: Path | str,

View file

@ -72,11 +72,9 @@ export interface SystemConfig {
ocr: { ocr: {
enabled: boolean; enabled: boolean;
modelName: string; modelName: string;
minDetectionBoxScore: number;
minDetectionScore: number; minDetectionScore: number;
minRecognitionScore: number; minRecognitionScore: number;
unwarpingEnabled: boolean; maxResolution: number;
orientationClassifyEnabled: boolean;
}; };
}; };
map: { map: {
@ -255,11 +253,9 @@ export const defaults = Object.freeze<SystemConfig>({
ocr: { ocr: {
enabled: true, enabled: true,
modelName: 'PP-OCRv5_server', modelName: 'PP-OCRv5_server',
minDetectionBoxScore: 0.6,
minDetectionScore: 0.3, minDetectionScore: 0.3,
minRecognitionScore: 0.0, minRecognitionScore: 0.0,
unwarpingEnabled: false, maxResolution: 1440,
orientationClassifyEnabled: false,
}, },
}, },
map: { map: {

View file

@ -49,29 +49,22 @@ export class FacialRecognitionConfig extends ModelConfig {
export class OcrConfig extends ModelConfig { export class OcrConfig extends ModelConfig {
@IsNumber() @IsNumber()
@Min(0) @Min(1)
@Max(1)
@Type(() => Number) @Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' }) @ApiProperty({ type: 'integer' })
minDetectionBoxScore!: number; maxResolution!: number;
@IsNumber() @IsNumber()
@Min(0) @Min(0.1)
@Max(1) @Max(1)
@Type(() => Number) @Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' }) @ApiProperty({ type: 'number', format: 'double' })
minDetectionScore!: number; minDetectionScore!: number;
@IsNumber() @IsNumber()
@Min(0) @Min(0.1)
@Max(1) @Max(1)
@Type(() => Number) @Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' }) @ApiProperty({ type: 'number', format: 'double' })
minRecognitionScore!: number; minRecognitionScore!: number;
@ValidateBoolean()
unwarpingEnabled!: boolean;
@ValidateBoolean()
orientationClassifyEnabled!: boolean;
} }

View file

@ -31,7 +31,11 @@ export type ModelPayload = { imagePath: string } | { text: string };
type ModelOptions = { modelName: string }; type ModelOptions = { modelName: string };
export type FaceDetectionOptions = ModelOptions & { minScore: number }; export type FaceDetectionOptions = ModelOptions & { minScore: number };
export type OcrOptions = ModelOptions & { minDetectionBoxScore: number, minDetectionScore: number, minRecognitionScore: number, unwarpingEnabled: boolean, orientationClassifyEnabled: boolean }; export type OcrOptions = ModelOptions & {
minDetectionScore: number;
minRecognitionScore: number;
maxResolution: number;
};
type VisualResponse = { imageHeight: number; imageWidth: number }; type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } }; export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse; export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
@ -40,20 +44,19 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo
export type ClipTextualResponse = { [ModelTask.SEARCH]: string }; export type ClipTextualResponse = { [ModelTask.SEARCH]: string };
export type OCR = { export type OCR = {
x1: number; text: string[];
y1: number; box: number[];
x2: number; boxScore: number[];
y2: number; textScore: number[];
x3: number;
y3: number;
x4: number;
y4: number;
text: string;
confidence: number;
}; };
export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minDetectionScore: number, minRecognitionScore: number } } } }; export type OcrRequest = {
export type OcrResponse = { [ModelTask.OCR]: OCR[] } & VisualResponse; [ModelTask.OCR]: {
[ModelType.DETECTION]: ModelOptions & { options: { minScore: number; maxResolution: number } };
[ModelType.RECOGNITION]: ModelOptions & { options: { minScore: number } };
};
};
export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse;
export type FacialRecognitionRequest = { export type FacialRecognitionRequest = {
[ModelTask.FACIAL_RECOGNITION]: { [ModelTask.FACIAL_RECOGNITION]: {
@ -211,8 +214,17 @@ export class MachineLearningRepository {
return formData; return formData;
} }
async ocr(urls: string[], imagePath: string, { modelName, minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) { async ocr(
const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled } } } }; urls: string[],
imagePath: string,
{ modelName, minDetectionScore, minRecognitionScore, maxResolution }: OcrOptions,
) {
const request = {
[ModelTask.OCR]: {
[ModelType.DETECTION]: { modelName, options: { minScore: minDetectionScore, maxResolution } },
[ModelType.RECOGNITION]: { modelName, options: { minScore: minRecognitionScore } },
},
};
const response = await this.predict<OcrResponse>(urls, { imagePath }, request); const response = await this.predict<OcrResponse>(urls, { imagePath }, request);
return response[ModelTask.OCR]; return response[ModelTask.OCR];
} }

View file

@ -1,23 +1,9 @@
import { Injectable } from '@nestjs/common'; import { Injectable } from '@nestjs/common';
import { Kysely, sql } from 'kysely'; import { Insertable, Kysely, sql } from 'kysely';
import { InjectKysely } from 'nestjs-kysely'; import { InjectKysely } from 'nestjs-kysely';
import { DB } from 'src/db'; import { AssetOcr, DB } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators'; import { DummyValue, GenerateSql } from 'src/decorators';
export interface OcrInsertData {
assetId: string;
x1: number;
y1: number;
x2: number;
y2: number;
x3: number;
y3: number;
x4: number;
y4: number;
text: string;
confidence: number;
}
@Injectable() @Injectable()
export class OcrRepository { export class OcrRepository {
constructor(@InjectKysely() private db: Kysely<DB>) {} constructor(@InjectKysely() private db: Kysely<DB>) {}
@ -54,28 +40,26 @@ export class OcrRepository {
x4: DummyValue.NUMBER, x4: DummyValue.NUMBER,
y4: DummyValue.NUMBER, y4: DummyValue.NUMBER,
text: DummyValue.STRING, text: DummyValue.STRING,
confidence: DummyValue.NUMBER, boxScore: DummyValue.NUMBER,
textScore: DummyValue.NUMBER,
}, },
], ],
], ],
}) })
upsert(assetId: string, ocrDataList: OcrInsertData[]) { upsert(assetId: string, ocrDataList: Insertable<AssetOcr>[]) {
if (ocrDataList.length === 0) { let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId));
return; if (ocrDataList.length > 0) {
const searchText = ocrDataList.map((item) => item.text.trim()).join(' ');
(query as any) = query
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
.with('inserted_search', (db) =>
db
.insertInto('ocr_search')
.values({ assetId, text: searchText })
.onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))),
);
} }
const searchText = ocrDataList.map((item) => item.text.trim()).join(' '); return query.selectNoFrom(sql`1`.as('dummy')).execute();
return this.db
.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId))
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
.with('inserted_search', (db) =>
db
.insertInto('ocr_search')
.values({ assetId, text: searchText })
.onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))),
)
.selectNoFrom(sql`1`.as('dummy'))
.execute();
} }
} }

View file

@ -1,7 +1,7 @@
import { Kysely, sql } from 'kysely'; import { Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> { export async function up(db: Kysely<any>): Promise<void> {
await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "text" text NOT NULL, "confidence" real NOT NULL);`.execute( await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT immich_uuid_v7(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "boxScore" real NOT NULL, "textScore" real NOT NULL, "text" text NOT NULL);`.execute(
db, db,
); );
await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db); await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db);

View file

@ -6,13 +6,12 @@ export class AssetOcrTable {
@PrimaryGeneratedColumn() @PrimaryGeneratedColumn()
id!: string; id!: string;
@ForeignKeyColumn(() => AssetTable, { @ForeignKeyColumn(() => AssetTable, { onDelete: 'CASCADE', onUpdate: 'CASCADE' })
onDelete: 'CASCADE',
onUpdate: 'CASCADE',
index: true,
})
assetId!: string; assetId!: string;
@Column({ type: 'text' })
text!: string;
@Column({ type: 'integer' }) @Column({ type: 'integer' })
x1!: number; x1!: number;
@ -37,9 +36,9 @@ export class AssetOcrTable {
@Column({ type: 'integer' }) @Column({ type: 'integer' })
y4!: number; y4!: number;
@Column({ type: 'text' }) @Column({ type: 'real' })
text!: string; boxScore!: number;
@Column({ type: 'real' }) @Column({ type: 'real' })
confidence!: number; textScore!: number;
} }

View file

@ -2,6 +2,7 @@ import { Injectable } from '@nestjs/common';
import { JOBS_ASSET_PAGINATION_SIZE } from 'src/constants'; import { JOBS_ASSET_PAGINATION_SIZE } from 'src/constants';
import { OnJob } from 'src/decorators'; import { OnJob } from 'src/decorators';
import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum'; import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum';
import { OCR } from 'src/repositories/machine-learning.repository';
import { BaseService } from 'src/services/base.service'; import { BaseService } from 'src/services/base.service';
import { JobItem, JobOf } from 'src/types'; import { JobItem, JobOf } from 'src/types';
import { isOcrEnabled } from 'src/utils/misc'; import { isOcrEnabled } from 'src/utils/misc';
@ -57,27 +58,34 @@ export class OcrService extends BaseService {
machineLearning.ocr, machineLearning.ocr,
); );
if (ocrResults.length > 0) { await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults));
const ocrDataList = ocrResults.map((result) => ({
assetId: id,
x1: result.x1,
y1: result.y1,
x2: result.x2,
y2: result.y2,
x3: result.x3,
y3: result.y3,
x4: result.x4,
y4: result.y4,
text: result.text,
confidence: result.confidence,
}));
await this.ocrRepository.upsert(id, ocrDataList);
}
await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() }); await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });
this.logger.debug(`Processed ${ocrResults.length} OCR result(s) for ${id}`); this.logger.debug(`Processed ${ocrResults.text.length} OCR result(s) for ${id}`);
return JobStatus.SUCCESS; return JobStatus.SUCCESS;
} }
parseOcrResults(id: string, ocrResults: OCR) {
const ocrDataList = [];
for (let i = 0; i < ocrResults.text.length; i++) {
const boxOffset = i * 8;
const row = {
assetId: id,
text: ocrResults.text[i],
boxScore: ocrResults.boxScore[i],
textScore: ocrResults.textScore[i],
x1: ocrResults.box[boxOffset],
y1: ocrResults.box[boxOffset + 1],
x2: ocrResults.box[boxOffset + 2],
y2: ocrResults.box[boxOffset + 3],
x3: ocrResults.box[boxOffset + 4],
y3: ocrResults.box[boxOffset + 5],
x4: ocrResults.box[boxOffset + 6],
y4: ocrResults.box[boxOffset + 7],
};
ocrDataList.push(row);
}
return ocrDataList;
}
} }