feat(ocr): enhance OCR model configuration with orientation classification and unwarping options, update PaddleOCR integration, and improve response structure

CoderKang 2025-06-02 20:40:32 +08:00 committed by mertalev
parent 3949bf2cfa
commit 0e7ad8b2ba
15 changed files with 135 additions and 123 deletions


@@ -183,7 +183,10 @@ async def run_inference(payload: Image | str, entries: InferenceEntries) -> Infe
     response: InferenceResponse = {}

     async def _run_inference(entry: InferenceEntry) -> None:
-        model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
+        model = await model_cache.get(
+            entry["name"], entry["type"], entry["task"],
+            ttl=settings.model_ttl, **entry["options"]
+        )
         inputs = [payload]
         for dep in model.depends:
             try:


@@ -38,7 +38,13 @@ class ModelCache:
     async def get(
         self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
    ) -> InferenceModel:
-        key = f"{model_name}{model_type}{model_task}"
+        config_key = ""
+        if model_type == ModelType.OCR and model_task == ModelTask.OCR:
+            orientation = model_kwargs.get("orientationClassifyEnabled", True)
+            unwarping = model_kwargs.get("unwarpingEnabled", True)
+            config_key = f"_o{orientation}_u{unwarping}"
+        key = f"{model_name}{model_type}{model_task}{config_key}"
         async with OptimisticLock(self.cache, key) as lock:
             model: InferenceModel | None = await self.cache.get(key)
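Because the loaded PaddleOCR instance now bakes the orientation-classify and unwarping settings into the model object itself, the cache key has to distinguish option combinations; otherwise a request with unwarping disabled could be served by a model created with it enabled. A rough TypeScript re-expression of the key composition (illustrative only; the Python above is authoritative, and the exact string values of ModelType.OCR and ModelTask.OCR are assumed here):

// Illustrative sketch: mirrors the Python key composition for OCR models.
function cacheKey(
  modelName: string,
  modelType: string,
  modelTask: string,
  options: { orientationClassifyEnabled?: boolean; unwarpingEnabled?: boolean } = {},
): string {
  let configKey = '';
  if (modelType === 'ocr' && modelTask === 'ocr') {
    // Defaults match the Python side: both features treated as enabled when unspecified.
    const orientation = options.orientationClassifyEnabled ?? true;
    const unwarping = options.unwarpingEnabled ?? true;
    configKey = `_o${orientation}_u${unwarping}`;
  }
  return `${modelName}${modelType}${modelTask}${configKey}`;
}

// cacheKey('PP-OCRv5_server', 'ocr', 'ocr', { unwarpingEnabled: false }) and
// cacheKey('PP-OCRv5_server', 'ocr', 'ocr', { unwarpingEnabled: true })
// now produce different keys, so the two configurations get separate model instances.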


@@ -76,7 +76,8 @@ _INSIGHTFACE_MODELS = {
 _PADDLE_MODELS = {
-    "paddle",
+    "PP-OCRv5_server",
+    "PP-OCRv5_mobile",
 }

 SUPPORTED_PROVIDERS = [


@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, List

 import numpy as np
 from numpy.typing import NDArray
@@ -14,34 +14,33 @@ class PaddleOCRecognizer(InferenceModel):
     def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None:
         self.min_score = model_kwargs.pop("minScore", min_score)
+        self.orientation_classify_enabled = model_kwargs.pop("orientationClassifyEnabled", True)
+        self.unwarping_enabled = model_kwargs.pop("unwarpingEnabled", True)
         super().__init__(model_name, **model_kwargs)
         self._load()
         self.loaded = True

-    def _load(self) -> None:
-        try:
-            self.model = PaddleOCR(
-                use_doc_orientation_classify=False,
-                use_doc_unwarping=False,
-                use_textline_orientation=False
-            )
-        except Exception as e:
-            print(f"Error loading PaddleOCR model: {e}")
-            raise e
+    def _load(self) -> PaddleOCR:
+        self.model = PaddleOCR(
+            text_detection_model_name=f"{self.model_name}_det",
+            text_recognition_model_name=f"{self.model_name}_rec",
+            use_doc_orientation_classify=self.orientation_classify_enabled,
+            use_doc_unwarping=self.unwarping_enabled,
+        )

-    def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> OCROutput:
+    def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> List[OCROutput]:
         inputs = decode_cv2(inputs)
         results = self.model.predict(inputs)
         valid_texts_and_scores = [
-            (text, score)
+            (text, score, box)
             for result in results
-            for text, score in zip(result['rec_texts'], result['rec_scores'])
-            if score > self.min_score
+            for text, score, box in zip(result['rec_texts'], result['rec_scores'], result['rec_boxes'].tolist())
+            if score >= self.min_score
         ]
         if not valid_texts_and_scores:
-            return OCROutput(text="", confidence=0.0)
-        texts, scores = zip(*valid_texts_and_scores)
-        return OCROutput(
-            text="".join(texts),
-            confidence=sum(scores) / len(scores)
-        )
+            return []
+        return [
+            OCROutput(text=text, confidence=score, boundingBox={"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]})
+            for text, score, box in valid_texts_and_scores
+        ]
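With this change the machine-learning service returns one entry per recognized text region instead of a single concatenated string, which is what the widened OcrResponse type on the server reflects. A hedged sketch of what such a payload might look like, assuming ModelTask.OCR serializes to 'ocr'; all values are illustrative:

// Illustrative payload only; field names follow the diff, values are made up.
const exampleOcrResponse = {
  imageWidth: 1920,
  imageHeight: 1080,
  ocr: [
    { text: 'Invoice #42', confidence: 0.97, boundingBox: { x1: 120, y1: 80, x2: 560, y2: 140 } },
    { text: 'Total: 19.99', confidence: 0.93, boundingBox: { x1: 130, y1: 900, x2: 420, y2: 950 } },
  ],
};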


@@ -90,6 +90,7 @@ FacialRecognitionOutput = list[DetectedFace]
 class OCROutput(TypedDict):
     text: str
     confidence: float
+    boundingBox: BoundingBox


 class PipelineEntry(TypedDict):


@@ -12926,12 +12926,20 @@
         },
         "modelName": {
           "type": "string"
+        },
+        "orientationClassifyEnabled": {
+          "type": "boolean"
+        },
+        "unwarpingEnabled": {
+          "type": "boolean"
         }
       },
       "required": [
         "enabled",
         "minScore",
-        "modelName"
+        "modelName",
+        "orientationClassifyEnabled",
+        "unwarpingEnabled"
       ],
       "type": "object"
     },


@@ -73,6 +73,8 @@ export interface SystemConfig {
       enabled: boolean;
       modelName: string;
       minScore: number;
+      unwarpingEnabled: boolean;
+      orientationClassifyEnabled: boolean;
     };
   };
   map: {
@@ -250,8 +252,10 @@ export const defaults = Object.freeze<SystemConfig>({
     },
     ocr: {
       enabled: true,
-      modelName: 'paddle',
+      modelName: 'PP-OCRv5_server',
       minScore: 0.9,
+      unwarpingEnabled: false,
+      orientationClassifyEnabled: false,
     },
   },
   map: {


@@ -54,4 +54,10 @@ export class OcrConfig extends ModelConfig {
   @Type(() => Number)
   @ApiProperty({ type: 'number', format: 'double' })
   minScore!: number;
+
+  @ValidateBoolean()
+  unwarpingEnabled!: boolean;
+
+  @ValidateBoolean()
+  orientationClassifyEnabled!: boolean;
 }


@@ -355,10 +355,8 @@ export class AssetJobRepository {
       .select(['assets.id'])
       .$if(!force, (qb) =>
         qb
-          .leftJoin('asset_job_status', 'asset_job_status.assetId', 'assets.id')
-          .where((eb) =>
-            eb.or([eb('asset_job_status.ocrAt', 'is', null), eb('asset_job_status.assetId', 'is', null)]),
-          )
+          .innerJoin('asset_job_status', 'asset_job_status.assetId', 'assets.id')
+          .where('asset_job_status.ocrAt', 'is', null)
           .where('assets.visibility', '!=', AssetVisibility.HIDDEN),
       )
       .where('assets.deletedAt', 'is', null)


@@ -31,7 +31,7 @@ export type ModelPayload = { imagePath: string } | { text: string };
 type ModelOptions = { modelName: string };

 export type FaceDetectionOptions = ModelOptions & { minScore: number };
-export type OcrOptions = ModelOptions & { minScore: number };
+export type OcrOptions = ModelOptions & { minScore: number, unwarpingEnabled: boolean, orientationClassifyEnabled: boolean };

 type VisualResponse = { imageHeight: number; imageWidth: number };
 export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
 export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
@@ -40,12 +40,13 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo
 export type ClipTextualResponse = { [ModelTask.SEARCH]: string };

 export type OCR = {
+  boundingBox: BoundingBox;
   text: string;
   confidence: number;
 };

 export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minScore: number } } } };
-export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse;
+export type OcrResponse = { [ModelTask.OCR]: OCR | OCR[] } & VisualResponse;

 export type FacialRecognitionRequest = {
   [ModelTask.FACIAL_RECOGNITION]: {
@@ -203,8 +204,8 @@ export class MachineLearningRepository {
     return formData;
   }

-  async ocr(urls: string[], imagePath: string, { modelName, minScore }: OcrOptions) {
-    const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minScore } } } };
+  async ocr(urls: string[], imagePath: string, { modelName, minScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) {
+    const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minScore, unwarpingEnabled, orientationClassifyEnabled } } } };
     const response = await this.predict<OcrResponse>(urls, { imagePath }, request);
     return response[ModelTask.OCR];
   }
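For reference, a hedged sketch of the request body that ocr() now builds from a typical configuration. The top-level keys assume ModelTask.OCR and ModelType.OCR both serialize to 'ocr', and the option values are illustrative:

// Options as they would come from system config (illustrative values).
const ocrOptions = {
  modelName: 'PP-OCRv5_server',
  minScore: 0.9,
  unwarpingEnabled: false,
  orientationClassifyEnabled: false,
};

// Request payload produced by ocr(), assuming both enums serialize to 'ocr'.
const request = {
  ocr: {
    ocr: {
      modelName: ocrOptions.modelName,
      options: {
        minScore: ocrOptions.minScore,
        unwarpingEnabled: ocrOptions.unwarpingEnabled,
        orientationClassifyEnabled: ocrOptions.orientationClassifyEnabled,
      },
    },
  },
};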


@@ -3,72 +3,41 @@ import { Kysely, sql } from 'kysely';
 import { InjectKysely } from 'nestjs-kysely';
 import { DB } from 'src/db';
 import { DummyValue, GenerateSql } from 'src/decorators';
-import { OcrEntity } from 'src/entities/ocr.entity';
+
+export interface OcrInsertData {
+  assetId: string;
+  boundingBoxX1: number;
+  boundingBoxY1: number;
+  boundingBoxX2: number;
+  boundingBoxY2: number;
+  text: string;
+}

 @Injectable()
 export class OcrRepository {
   constructor(@InjectKysely() private db: Kysely<DB>) {}

   @GenerateSql({ params: [DummyValue.UUID] })
-  getOcrById(id: string): Promise<OcrEntity | null> {
+  async getById(id: string) {
     return this.db
       .selectFrom('asset_ocr')
       .selectAll('asset_ocr')
       .where('asset_ocr.assetId', '=', id)
-      .executeTakeFirst() as Promise<OcrEntity | null>;
+      .executeTakeFirst();
   }

-  async insertOcrData(assetId: string, text: string): Promise<void> {
-    await this.db
-      .insertInto('asset_ocr')
-      .values({ assetId, text })
-      .execute();
-  }
-
-  async deleteAllOcr(): Promise<void> {
+  async deleteAll(): Promise<void> {
     await sql`truncate ${sql.table('asset_ocr')}`.execute(this.db);
   }

-  getAllOcr(options: Partial<OcrEntity> = {}): AsyncIterableIterator<OcrEntity> {
-    return this.db
-      .selectFrom('asset_ocr')
-      .selectAll('asset_ocr')
-      .$if(!!options.assetId, (qb) => qb.where('asset_ocr.assetId', '=', options.assetId!))
-      .stream() as AsyncIterableIterator<OcrEntity>;
-  }
-
-  @GenerateSql()
-  async getLatestOcrDate(): Promise<string | undefined> {
-    const result = (await this.db
-      .selectFrom('asset_job_status')
-      .select((eb) => sql`${eb.fn.max('asset_job_status.ocrAt')}::text`.as('latestDate'))
-      .executeTakeFirst()) as { latestDate: string } | undefined;
-    return result?.latestDate;
-  }
-
-  async updateOcrData(id: string, ocrData: string): Promise<void> {
+  async insertMany(ocrDataList: OcrInsertData[]): Promise<void> {
+    if (ocrDataList.length === 0) {
+      return;
+    }
+
     await this.db
-      .updateTable('asset_ocr')
-      .set({ text: ocrData })
-      .where('id', '=', id)
-      .execute();
-  }
-
-  getOcrWithoutText(): Promise<OcrEntity[]> {
-    return this.db
-      .selectFrom('asset_ocr')
-      .selectAll('asset_ocr')
-      .where('text', 'is', null)
-      .execute() as Promise<OcrEntity[]>;
-  }
-
-  async delete(ocr: OcrEntity[]): Promise<void> {
-    await this.db
-      .deleteFrom('asset_ocr')
-      .where('id', 'in', ocr.map((o) => o.id))
+      .insertInto('asset_ocr')
+      .values(ocrDataList)
       .execute();
   }
 }
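A short usage sketch of the new batch insert; the ocrRepository instance and the values are hypothetical, with one row per recognized text region:

// Hypothetical caller: one OcrInsertData row per recognized text region of an asset.
const rows: OcrInsertData[] = [
  { assetId: 'asset-uuid', text: 'Invoice #42', boundingBoxX1: 120, boundingBoxY1: 80, boundingBoxX2: 560, boundingBoxY2: 140 },
  { assetId: 'asset-uuid', text: 'Total: 19.99', boundingBoxX1: 130, boundingBoxY1: 900, boundingBoxX2: 420, boundingBoxY2: 950 },
];

// insertMany() returns early for an empty array and otherwise issues one multi-row INSERT.
await ocrRepository.insertMany(rows);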


@@ -321,19 +321,14 @@ export class SearchRepository {
       throw new Error(`Invalid value for 'size': ${pagination.size}`);
     }

-    const items = await this.db
-      .selectFrom('asset_ocr')
-      .selectAll()
-      .innerJoin('assets', 'assets.id', 'asset_ocr.assetId')
-      .where('assets.ownerId', '=', anyUuid(options.userIds))
+    const items = await searchAssetBuilder(this.db, options)
+      .innerJoin('asset_ocr', 'assets.id', 'asset_ocr.assetId')
       .where('asset_ocr.text', 'ilike', `%${options.ocr}%`)
       .limit(pagination.size + 1)
       .offset((pagination.page - 1) * pagination.size)
-      .execute() as any;
+      .execute();

-    const hasNextPage = items.length > pagination.size;
-    items.splice(pagination.size);
-    return { items, hasNextPage };
+    return paginationHelper(items, pagination.size);
   }

   @GenerateSql({
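paginationHelper presumably encapsulates the fetch-one-extra-row pattern that was inlined here before; a sketch consistent with the removed logic (the actual shared helper may differ):

// Sketch only: mirrors the inlined pagination logic removed above.
function paginationHelper<T>(items: T[], size: number): { items: T[]; hasNextPage: boolean } {
  const hasNextPage = items.length > size; // an extra probe row was fetched via limit(size + 1)
  items.splice(size); // trim the probe row before returning
  return { items, hasNextPage };
}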


@@ -13,12 +13,6 @@ import { isOcrEnabled } from 'src/utils/misc';

 @Injectable()
 export class OcrService extends BaseService {
-  @OnJob({ name: JobName.OCR_CLEANUP, queue: QueueName.BACKGROUND_TASK })
-  async handleOcrCleanup(): Promise<JobStatus> {
-    const ocr = await this.ocrRepository.getOcrWithoutText();
-    await this.ocrRepository.delete(ocr);
-    return JobStatus.SUCCESS;
-  }
-
   @OnJob({ name: JobName.QUEUE_OCR, queue: QueueName.OCR })
   async handleQueueOcr({ force, nightly }: JobOf<JobName.QUEUE_OCR>): Promise<JobStatus> {
@@ -28,7 +22,7 @@ export class OcrService extends BaseService {
     }

     if (force) {
-      await this.ocrRepository.deleteAllOcr();
+      await this.ocrRepository.deleteAll();
     }

     let jobs: JobItem[] = [];
@@ -44,11 +38,6 @@ export class OcrService extends BaseService {
     }

     await this.jobRepository.queueAll(jobs);

-    if (force === undefined) {
-      await this.jobRepository.queue({ name: JobName.OCR_CLEANUP });
-    }
-
     return JobStatus.SUCCESS;
   }
@@ -77,8 +66,15 @@ export class OcrService extends BaseService {
       machineLearning.ocr
     );

-    if (!ocrResults || ocrResults.text.length === 0) {
-      this.logger.warn(`No OCR results for document ${id}`);
+    const resultsArray = Array.isArray(ocrResults) ? ocrResults : [ocrResults];
+    const validResults = resultsArray.filter(result =>
+      result &&
+      result.text &&
+      result.text.trim().length > 0
+    );
+
+    if (validResults.length === 0) {
+      this.logger.warn(`No valid OCR results for document ${id}`);
       await this.assetRepository.upsertJobStatus({
         assetId: asset.id,
         ocrAt: new Date(),
@@ -86,23 +82,29 @@ export class OcrService extends BaseService {
       return JobStatus.SUCCESS;
     }

-    this.logger.debug(`OCR ${id} has OCR results`);
-
-    const ocr = await this.ocrRepository.getOcrById(id);
-    if (ocr) {
-      this.logger.debug(`Updating OCR for ${id}`);
-      await this.ocrRepository.updateOcrData(id, ocrResults.text);
-    } else {
-      this.logger.debug(`Inserting OCR for ${id}`);
-      await this.ocrRepository.insertOcrData(id, ocrResults.text);
-    }
-
-    await this.assetRepository.upsertJobStatus({
-      assetId: asset.id,
-      ocrAt: new Date(),
-    });
-
-    this.logger.debug(`Processed OCR for ${id}`);
-    return JobStatus.SUCCESS;
+    try {
+      const ocrDataList = validResults.map(result => ({
+        assetId: id,
+        boundingBoxX1: result.boundingBox.x1,
+        boundingBoxY1: result.boundingBox.y1,
+        boundingBoxX2: result.boundingBox.x2,
+        boundingBoxY2: result.boundingBox.y2,
+        text: result.text.trim(),
+      }));
+
+      await this.ocrRepository.insertMany(ocrDataList);
+
+      await this.assetRepository.upsertJobStatus({
+        assetId: asset.id,
+        ocrAt: new Date(),
+      });
+
+      this.logger.debug(`Processed ${validResults.length} OCR result(s) for ${id}`);
+      return JobStatus.SUCCESS;
+    } catch (error) {
+      this.logger.error(`Failed to insert OCR results for ${id}:`, error);
+      return JobStatus.FAILED;
+    }
   }
 }


@@ -23,7 +23,7 @@ import { AssetOrder, AssetVisibility, Permission } from 'src/enum';
 import { BaseService } from 'src/services/base.service';
 import { requireElevatedPermission } from 'src/utils/access';
 import { getMyPartnerIds } from 'src/utils/asset.util';
-import { isSmartSearchEnabled, isOcrEnabled } from 'src/utils/misc';
+import { isOcrEnabled, isSmartSearchEnabled } from 'src/utils/misc';

 @Injectable()
 export class SearchService extends BaseService {


@@ -231,13 +231,30 @@
         disabled={disabled || !config.machineLearning.enabled}
       />

+      <SettingSwitch
+        title={$t('admin.machine_learning_ocr_unwarping_enabled')}
+        subtitle={$t('admin.machine_learning_ocr_unwarping_enabled_description')}
+        bind:checked={config.machineLearning.ocr.unwarpingEnabled}
+        disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
+      />
+
+      <SettingSwitch
+        title={$t('admin.machine_learning_ocr_orientation_classify_enabled')}
+        subtitle={$t('admin.machine_learning_ocr_orientation_classify_enabled_description')}
+        bind:checked={config.machineLearning.ocr.orientationClassifyEnabled}
+        disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
+      />
+
+      <hr />
+
       <SettingSelect
         label={$t('admin.machine_learning_ocr_model')}
         desc={$t('admin.machine_learning_ocr_model_description')}
         name="ocr-model"
         bind:value={config.machineLearning.ocr.modelName}
         options={[
-          { value: 'paddle', text: 'paddle' },
+          { value: 'PP-OCRv5_server', text: 'PP-OCRv5_server' },
+          { value: 'PP-OCRv5_mobile', text: 'PP-OCRv5_mobile' },
         ]}
         disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
         isEdited={config.machineLearning.ocr.modelName !== savedConfig.machineLearning.ocr.modelName}
@@ -246,11 +263,13 @@
       <SettingInputField
         inputType={SettingInputFieldType.NUMBER}
         label={$t('admin.machine_learning_ocr_min_score')}
+        description={$t('admin.machine_learning_ocr_min_score_description')}
         bind:value={config.machineLearning.ocr.minScore}
         step="0.1"
         min={0.1}
         max={1}
         disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
+        isEdited={config.machineLearning.ocr.minScore !== savedConfig.machineLearning.ocr.minScore}
       />
     </div>
   </SettingAccordion>