change dto

mertalev 2025-06-13 00:39:39 -04:00
parent c59f932bf0
commit 412468989f
12 changed files with 93 additions and 98 deletions

View file

@@ -12,7 +12,8 @@ from rapidocr.utils.typings import ModelType as RapidModelType
from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelSession, ModelTask, ModelType
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
@@ -29,7 +30,7 @@ class TextRecognizer(InferenceModel):
"text": [],
"textScore": [],
}
super().__init__(model_name, **model_kwargs)
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
def _download(self) -> None:
model_info = InferSession.get_model_url(
@@ -50,7 +51,8 @@ class TextRecognizer(InferenceModel):
DownloadFile.run(download_params)
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
# TODO: support other runtimes
session = OrtSession(self.model_path)
self.model = RapidTextRecognizer(
OcrOptions(
session=session.session,
@@ -80,7 +82,7 @@ class TextRecognizer(InferenceModel):
valid_text_score_idx = text_scores > 0.5
valid_score_idx_list = valid_text_score_idx.tolist()
return {
"box": boxes.reshape(-1, 8)[valid_text_score_idx],
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
"boxScore": box_scores[valid_text_score_idx],
"textScore": text_scores[valid_text_score_idx],

View file

@@ -9,7 +9,7 @@ from typing_extensions import TypedDict
class TextDetectionOutput(TypedDict):
resized: npt.NDArray[np.float32]
boxes: npt.NDArray[np.float32]
scores: Iterable[float]
scores: npt.NDArray[np.float32]
class TextRecognitionOutput(TypedDict):

View file

@@ -14,6 +14,7 @@ from ..config import log, settings
class OrtSession:
session: ort.InferenceSession
def __init__(
self,
model_path: Path | str,

View file

@@ -72,11 +72,9 @@ export interface SystemConfig {
ocr: {
enabled: boolean;
modelName: string;
minDetectionBoxScore: number;
minDetectionScore: number;
minRecognitionScore: number;
unwarpingEnabled: boolean;
orientationClassifyEnabled: boolean;
maxResolution: number;
};
};
map: {
@@ -255,11 +253,9 @@ export const defaults = Object.freeze<SystemConfig>({
ocr: {
enabled: true,
modelName: 'PP-OCRv5_server',
minDetectionBoxScore: 0.6,
minDetectionScore: 0.3,
minRecognitionScore: 0.0,
unwarpingEnabled: false,
orientationClassifyEnabled: false,
maxResolution: 1440,
},
},
map: {
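
The OCR config surface shrinks to the two score thresholds plus a maxResolution cap; the unwarping and orientation-classification toggles are removed. An illustrative override of the new block (the parent key inside SystemConfig is not visible in this hunk, and the non-default values below are examples only):

const ocr = {
  enabled: true,
  modelName: 'PP-OCRv5_server',
  minDetectionScore: 0.4, // stricter than the 0.3 default, drops low-confidence boxes
  minRecognitionScore: 0.2, // the 0.0 default keeps every recognized string
  maxResolution: 1080, // lower than the 1440 default, trades small-text recall for speed
};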

View file

@@ -49,29 +49,22 @@ export class FacialRecognitionConfig extends ModelConfig {
export class OcrConfig extends ModelConfig {
@IsNumber()
@Min(0)
@Max(1)
@Min(1)
@Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' })
minDetectionBoxScore!: number;
@ApiProperty({ type: 'integer' })
maxResolution!: number;
@IsNumber()
@Min(0)
@Min(0.1)
@Max(1)
@Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' })
minDetectionScore!: number;
@IsNumber()
@Min(0)
@Min(0.1)
@Max(1)
@Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' })
minRecognitionScore!: number;
@ValidateBoolean()
unwarpingEnabled!: boolean;
@ValidateBoolean()
orientationClassifyEnabled!: boolean;
}
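
The new bounds distinguish the integer resolution cap (at least 1) from the two score thresholds (0.1 to 1). A minimal sketch of how class-validator enforces them, assuming the usual plainToInstance/validate flow; the import path is hypothetical and the fields inherited from ModelConfig are ignored for brevity:

import { plainToInstance } from 'class-transformer';
import { validate } from 'class-validator';
import { OcrConfig } from './system-config.dto'; // path is an assumption

const dto = plainToInstance(OcrConfig, {
  maxResolution: 0, // rejected by @Min(1)
  minDetectionScore: 0.05, // rejected by @Min(0.1)
  minRecognitionScore: 0.5, // within [0.1, 1], accepted
});
const errors = await validate(dto);
// errors contains one entry each for maxResolution and minDetectionScore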

View file

@@ -31,7 +31,11 @@ export type ModelPayload = { imagePath: string } | { text: string };
type ModelOptions = { modelName: string };
export type FaceDetectionOptions = ModelOptions & { minScore: number };
export type OcrOptions = ModelOptions & { minDetectionBoxScore: number, minDetectionScore: number, minRecognitionScore: number, unwarpingEnabled: boolean, orientationClassifyEnabled: boolean };
export type OcrOptions = ModelOptions & {
minDetectionScore: number;
minRecognitionScore: number;
maxResolution: number;
};
type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
@@ -40,20 +44,19 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo
export type ClipTextualResponse = { [ModelTask.SEARCH]: string };
export type OCR = {
x1: number;
y1: number;
x2: number;
y2: number;
x3: number;
y3: number;
x4: number;
y4: number;
text: string;
confidence: number;
text: string[];
box: number[];
boxScore: number[];
textScore: number[];
};
export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minDetectionScore: number, minRecognitionScore: number } } } };
export type OcrResponse = { [ModelTask.OCR]: OCR[] } & VisualResponse;
export type OcrRequest = {
[ModelTask.OCR]: {
[ModelType.DETECTION]: ModelOptions & { options: { minScore: number; maxResolution: number } };
[ModelType.RECOGNITION]: ModelOptions & { options: { minScore: number } };
};
};
export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse;
export type FacialRecognitionRequest = {
[ModelTask.FACIAL_RECOGNITION]: {
@@ -211,8 +214,17 @@ export class MachineLearningRepository {
return formData;
}
async ocr(urls: string[], imagePath: string, { modelName, minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) {
const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled } } } };
async ocr(
urls: string[],
imagePath: string,
{ modelName, minDetectionScore, minRecognitionScore, maxResolution }: OcrOptions,
) {
const request = {
[ModelTask.OCR]: {
[ModelType.DETECTION]: { modelName, options: { minScore: minDetectionScore, maxResolution } },
[ModelType.RECOGNITION]: { modelName, options: { minScore: minRecognitionScore } },
},
};
const response = await this.predict<OcrResponse>(urls, { imagePath }, request);
return response[ModelTask.OCR];
}
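
For reference, the rewritten ocr() splits the options across the two model types. A sketch of the request body it produces with the default option values from this commit (the enum string values 'ocr', 'detection' and 'recognition' are assumptions):

const request = {
  ocr: {
    detection: { modelName: 'PP-OCRv5_server', options: { minScore: 0.3, maxResolution: 1440 } },
    recognition: { modelName: 'PP-OCRv5_server', options: { minScore: 0 } },
  },
};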

View file

@@ -1,23 +1,9 @@
import { Injectable } from '@nestjs/common';
import { Kysely, sql } from 'kysely';
import { Insertable, Kysely, sql } from 'kysely';
import { InjectKysely } from 'nestjs-kysely';
import { DB } from 'src/db';
import { AssetOcr, DB } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators';
export interface OcrInsertData {
assetId: string;
x1: number;
y1: number;
x2: number;
y2: number;
x3: number;
y3: number;
x4: number;
y4: number;
text: string;
confidence: number;
}
@Injectable()
export class OcrRepository {
constructor(@InjectKysely() private db: Kysely<DB>) {}
@@ -54,28 +40,26 @@ export class OcrRepository {
x4: DummyValue.NUMBER,
y4: DummyValue.NUMBER,
text: DummyValue.STRING,
confidence: DummyValue.NUMBER,
boxScore: DummyValue.NUMBER,
textScore: DummyValue.NUMBER,
},
],
],
})
upsert(assetId: string, ocrDataList: OcrInsertData[]) {
if (ocrDataList.length === 0) {
return;
}
upsert(assetId: string, ocrDataList: Insertable<AssetOcr>[]) {
let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId));
if (ocrDataList.length > 0) {
const searchText = ocrDataList.map((item) => item.text.trim()).join(' ');
return this.db
.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId))
(query as any) = query
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
.with('inserted_search', (db) =>
db
.insertInto('ocr_search')
.values({ assetId, text: searchText })
.onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))),
)
.selectNoFrom(sql`1`.as('dummy'))
.execute();
);
}
return query.selectNoFrom(sql`1`.as('dummy')).execute();
}
}
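
The reworked upsert runs one CTE chain: delete the asset's existing asset_ocr rows, and only when new rows are supplied, insert them and upsert the concatenated text into ocr_search; with an empty list only the delete executes. An illustrative call (the repository instance and all values are examples; rows follow the Insertable<AssetOcr> shape used above):

await ocrRepository.upsert('asset-uuid', [
  {
    assetId: 'asset-uuid',
    text: 'HELLO',
    boxScore: 0.91,
    textScore: 0.97,
    x1: 10, y1: 10, x2: 90, y2: 10, x3: 90, y3: 40, x4: 10, y4: 40,
  },
]);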

View file

@@ -1,7 +1,7 @@
import { Kysely, sql } from 'kysely';
export async function up(db: Kysely<any>): Promise<void> {
await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "text" text NOT NULL, "confidence" real NOT NULL);`.execute(
await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT immich_uuid_v7(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "boxScore" real NOT NULL, "textScore" real NOT NULL, "text" text NOT NULL);`.execute(
db,
);
await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db);
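
The regenerated table keys rows with immich_uuid_v7() and stores separate boxScore and textScore columns in place of the old confidence column. An illustrative Kysely query against the new columns (the db instance and the 0.5 threshold are examples):

const lowConfidenceText = await db
  .selectFrom('asset_ocr')
  .select(['assetId', 'text', 'textScore'])
  .where('textScore', '<', 0.5)
  .execute();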

View file

@@ -6,13 +6,12 @@ export class AssetOcrTable {
@PrimaryGeneratedColumn()
id!: string;
@ForeignKeyColumn(() => AssetTable, {
onDelete: 'CASCADE',
onUpdate: 'CASCADE',
index: true,
})
@ForeignKeyColumn(() => AssetTable, { onDelete: 'CASCADE', onUpdate: 'CASCADE' })
assetId!: string;
@Column({ type: 'text' })
text!: string;
@Column({ type: 'integer' })
x1!: number;
@@ -37,9 +36,9 @@ export class AssetOcrTable {
@Column({ type: 'integer' })
y4!: number;
@Column({ type: 'text' })
text!: string;
@Column({ type: 'real' })
boxScore!: number;
@Column({ type: 'real' })
confidence!: number;
textScore!: number;
}

View file

@@ -2,6 +2,7 @@ import { Injectable } from '@nestjs/common';
import { JOBS_ASSET_PAGINATION_SIZE } from 'src/constants';
import { OnJob } from 'src/decorators';
import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum';
import { OCR } from 'src/repositories/machine-learning.repository';
import { BaseService } from 'src/services/base.service';
import { JobItem, JobOf } from 'src/types';
import { isOcrEnabled } from 'src/utils/misc';
@@ -57,27 +58,34 @@ export class OcrService extends BaseService {
machineLearning.ocr,
);
if (ocrResults.length > 0) {
const ocrDataList = ocrResults.map((result) => ({
assetId: id,
x1: result.x1,
y1: result.y1,
x2: result.x2,
y2: result.y2,
x3: result.x3,
y3: result.y3,
x4: result.x4,
y4: result.y4,
text: result.text,
confidence: result.confidence,
}));
await this.ocrRepository.upsert(id, ocrDataList);
}
await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults));
await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });
this.logger.debug(`Processed ${ocrResults.length} OCR result(s) for ${id}`);
this.logger.debug(`Processed ${ocrResults.text.length} OCR result(s) for ${id}`);
return JobStatus.SUCCESS;
}
parseOcrResults(id: string, ocrResults: OCR) {
const ocrDataList = [];
for (let i = 0; i < ocrResults.text.length; i++) {
const boxOffset = i * 8;
const row = {
assetId: id,
text: ocrResults.text[i],
boxScore: ocrResults.boxScore[i],
textScore: ocrResults.textScore[i],
x1: ocrResults.box[boxOffset],
y1: ocrResults.box[boxOffset + 1],
x2: ocrResults.box[boxOffset + 2],
y2: ocrResults.box[boxOffset + 3],
x3: ocrResults.box[boxOffset + 4],
y3: ocrResults.box[boxOffset + 5],
x4: ocrResults.box[boxOffset + 6],
y4: ocrResults.box[boxOffset + 7],
};
ocrDataList.push(row);
}
return ocrDataList;
}
}
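
A worked example of parseOcrResults with a single detection, assuming the flat box layout returned by the ML service (the ocrService instance and all values are illustrative):

const rows = ocrService.parseOcrResults('asset-uuid', {
  text: ['HELLO'],
  box: [10, 10, 90, 10, 90, 40, 10, 40],
  boxScore: [0.91],
  textScore: [0.97],
});
// rows[0] => { assetId: 'asset-uuid', text: 'HELLO', boxScore: 0.91, textScore: 0.97,
//              x1: 10, y1: 10, x2: 90, y2: 10, x3: 90, y3: 40, x4: 10, y4: 40 }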