feat(ocr): enhance OCR model configuration with orientation classification and unwarping options, update PaddleOCR integration, and improve response structure

CoderKang authored on 2025-06-02 20:40:32 +08:00, committed by mertalev
parent 3949bf2cfa
commit 0e7ad8b2ba
No known key found for this signature in database
GPG key ID: DF6ABC77AAD98C95
15 changed files with 135 additions and 123 deletions


@@ -183,7 +183,10 @@ async def run_inference(payload: Image | str, entries: InferenceEntries) -> Infe
response: InferenceResponse = {}
async def _run_inference(entry: InferenceEntry) -> None:
model = await model_cache.get(entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl)
model = await model_cache.get(
entry["name"], entry["type"], entry["task"],
ttl=settings.model_ttl, **entry["options"]
)
inputs = [payload]
for dep in model.depends:
try:
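The practical effect of this hunk: whatever the server sends under an entry's "options" now rides along into the model cache lookup and, on a cache miss, into the model constructor. A minimal sketch of the idea, with illustrative values and plain strings standing in for the ModelType/ModelTask enum members:

```python
# Hypothetical OCR entry as run_inference might receive it; all values are
# illustrative, and "ocr" stands in for the real enum members.
entry = {
    "name": "PP-OCRv5_mobile",
    "type": "ocr",
    "task": "ocr",
    "options": {
        "minScore": 0.9,
        "orientationClassifyEnabled": True,
        "unwarpingEnabled": False,
    },
}

async def lookup(model_cache, settings):
    # Same call shape as above: the options dict is splatted as keyword
    # arguments, so the flags reach the model when it is first constructed.
    return await model_cache.get(
        entry["name"], entry["type"], entry["task"],
        ttl=settings.model_ttl, **entry["options"],
    )
```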


@@ -38,7 +38,13 @@ class ModelCache:
async def get(
self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
) -> InferenceModel:
key = f"{model_name}{model_type}{model_task}"
config_key = ""
if model_type == ModelType.OCR and model_task == ModelTask.OCR:
orientation = model_kwargs.get("orientationClassifyEnabled", True)
unwarping = model_kwargs.get("unwarpingEnabled", True)
config_key = f"_o{orientation}_u{unwarping}"
key = f"{model_name}{model_type}{model_task}{config_key}"
async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
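The added suffix matters because the new preprocessing flags are baked into the PaddleOCR pipeline at load time: two requests that differ only in those flags must not share a cached model instance. A stripped-down sketch of the key derivation, mirroring the logic above with plain strings in place of the enums:

```python
def cache_key(model_name: str, model_type: str, model_task: str, **kwargs) -> str:
    # OCR models get a suffix encoding their preprocessing options; every
    # other model keeps the plain name+type+task key.
    config_key = ""
    if model_type == "ocr" and model_task == "ocr":
        orientation = kwargs.get("orientationClassifyEnabled", True)
        unwarping = kwargs.get("unwarpingEnabled", True)
        config_key = f"_o{orientation}_u{unwarping}"
    return f"{model_name}{model_type}{model_task}{config_key}"

key_a = cache_key("PP-OCRv5_server", "ocr", "ocr",
                  orientationClassifyEnabled=True, unwarpingEnabled=False)
key_b = cache_key("PP-OCRv5_server", "ocr", "ocr",
                  orientationClassifyEnabled=False, unwarpingEnabled=False)
assert key_a != key_b  # distinct keys, distinct cached model instances
```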


@@ -76,7 +76,8 @@ _INSIGHTFACE_MODELS = {
_PADDLE_MODELS = {
"paddle",
"PP-OCRv5_server",
"PP-OCRv5_mobile",
}
SUPPORTED_PROVIDERS = [


@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List
import numpy as np
from numpy.typing import NDArray
@@ -14,34 +14,33 @@ class PaddleOCRecognizer(InferenceModel):
def __init__(self, model_name: str, min_score: float = 0.9, **model_kwargs: Any) -> None:
self.min_score = model_kwargs.pop("minScore", min_score)
self.orientation_classify_enabled = model_kwargs.pop("orientationClassifyEnabled", True)
self.unwarping_enabled = model_kwargs.pop("unwarpingEnabled", True)
super().__init__(model_name, **model_kwargs)
self._load()
self.loaded = True
def _load(self) -> None:
try:
def _load(self) -> PaddleOCR:
self.model = PaddleOCR(
use_doc_orientation_classify=False,
use_doc_unwarping=False,
use_textline_orientation=False
text_detection_model_name=f"{self.model_name}_det",
text_recognition_model_name=f"{self.model_name}_rec",
use_doc_orientation_classify=self.orientation_classify_enabled,
use_doc_unwarping=self.unwarping_enabled,
)
except Exception as e:
print(f"Error loading PaddleOCR model: {e}")
raise e
def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> OCROutput:
def _predict(self, inputs: NDArray[np.uint8] | bytes | Image.Image, **kwargs: Any) -> List[OCROutput]:
inputs = decode_cv2(inputs)
results = self.model.predict(inputs)
valid_texts_and_scores = [
(text, score)
(text, score, box)
for result in results
for text, score in zip(result['rec_texts'], result['rec_scores'])
if score > self.min_score
for text, score, box in zip(result['rec_texts'], result['rec_scores'], result['rec_boxes'].tolist())
if score >= self.min_score
]
if not valid_texts_and_scores:
return OCROutput(text="", confidence=0.0)
texts, scores = zip(*valid_texts_and_scores)
return OCROutput(
text="".join(texts),
confidence=sum(scores) / len(scores)
)
return []
return [
OCROutput(text=text, confidence=score, boundingBox={"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]})
for text, score, box in valid_texts_and_scores
]
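For reference, a minimal standalone sketch of the same flow against the PaddleOCR 3.x API: the document-preprocessing flags are fixed when the pipeline is constructed, and each recognized line comes back through parallel rec_texts / rec_scores / rec_boxes entries, which is why the model now returns one result per text line instead of a single concatenated string. Model names, the input path, and the threshold are illustrative.

```python
from paddleocr import PaddleOCR

# Constructor flags mirror the new arguments above; names are illustrative.
ocr = PaddleOCR(
    text_detection_model_name="PP-OCRv5_server_det",
    text_recognition_model_name="PP-OCRv5_server_rec",
    use_doc_orientation_classify=True,
    use_doc_unwarping=False,
)

min_score = 0.9
outputs = []
for result in ocr.predict("receipt.jpg"):
    # rec_boxes is a numpy array holding one [x1, y1, x2, y2] box per line.
    for text, score, box in zip(
        result["rec_texts"], result["rec_scores"], result["rec_boxes"].tolist()
    ):
        if score >= min_score:
            outputs.append({
                "text": text,
                "confidence": score,
                "boundingBox": {"x1": box[0], "y1": box[1], "x2": box[2], "y2": box[3]},
            })
```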


@@ -90,6 +90,7 @@ FacialRecognitionOutput = list[DetectedFace]
class OCROutput(TypedDict):
text: str
confidence: float
boundingBox: BoundingBox
class PipelineEntry(TypedDict):
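With boundingBox added, every OCR hit in the machine-learning response carries its rectangle, and the pipeline returns a list of hits per image. A sketch of what a full OCR response might look like, with illustrative values, assuming the task key serializes to "ocr"; imageWidth and imageHeight come from the shared visual-response fields (see the TypeScript OcrResponse type later in this diff):

```python
# Illustrative response body; "ocr" is an assumed serialization of the task key.
response = {
    "ocr": [
        {"text": "INVOICE", "confidence": 0.98,
         "boundingBox": {"x1": 120, "y1": 48, "x2": 412, "y2": 86}},
        {"text": "Total: 42.00", "confidence": 0.93,
         "boundingBox": {"x1": 118, "y1": 640, "x2": 389, "y2": 678}},
    ],
    "imageWidth": 1654,
    "imageHeight": 2339,
}
```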


@@ -12926,12 +12926,20 @@
},
"modelName": {
"type": "string"
},
"orientationClassifyEnabled": {
"type": "boolean"
},
"unwarpingEnabled": {
"type": "boolean"
}
},
"required": [
"enabled",
"minScore",
"modelName"
"modelName",
"orientationClassifyEnabled",
"unwarpingEnabled"
],
"type": "object"
},


@@ -73,6 +73,8 @@ export interface SystemConfig {
enabled: boolean;
modelName: string;
minScore: number;
unwarpingEnabled: boolean;
orientationClassifyEnabled: boolean;
};
};
map: {
@@ -250,8 +252,10 @@ export const defaults = Object.freeze<SystemConfig>({
},
ocr: {
enabled: true,
modelName: 'paddle',
modelName: 'PP-OCRv5_server',
minScore: 0.9,
unwarpingEnabled: false,
orientationClassifyEnabled: false,
},
},
map: {


@@ -54,4 +54,10 @@ export class OcrConfig extends ModelConfig {
@Type(() => Number)
@ApiProperty({ type: 'number', format: 'double' })
minScore!: number;
@ValidateBoolean()
unwarpingEnabled!: boolean;
@ValidateBoolean()
orientationClassifyEnabled!: boolean;
}


@@ -355,10 +355,8 @@ export class AssetJobRepository {
.select(['assets.id'])
.$if(!force, (qb) =>
qb
.leftJoin('asset_job_status', 'asset_job_status.assetId', 'assets.id')
.where((eb) =>
eb.or([eb('asset_job_status.ocrAt', 'is', null), eb('asset_job_status.assetId', 'is', null)]),
)
.innerJoin('asset_job_status', 'asset_job_status.assetId', 'assets.id')
.where('asset_job_status.ocrAt', 'is', null)
.where('assets.visibility', '!=', AssetVisibility.HIDDEN),
)
.where('assets.deletedAt', 'is', null)


@@ -31,7 +31,7 @@ export type ModelPayload = { imagePath: string } | { text: string };
type ModelOptions = { modelName: string };
export type FaceDetectionOptions = ModelOptions & { minScore: number };
export type OcrOptions = ModelOptions & { minScore: number };
export type OcrOptions = ModelOptions & { minScore: number, unwarpingEnabled: boolean, orientationClassifyEnabled: boolean };
type VisualResponse = { imageHeight: number; imageWidth: number };
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
@@ -40,12 +40,13 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo
export type ClipTextualResponse = { [ModelTask.SEARCH]: string };
export type OCR = {
boundingBox: BoundingBox;
text: string;
confidence: number;
};
export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minScore: number } } } };
export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse;
export type OcrResponse = { [ModelTask.OCR]: OCR | OCR[] } & VisualResponse;
export type FacialRecognitionRequest = {
[ModelTask.FACIAL_RECOGNITION]: {
@@ -203,8 +204,8 @@ export class MachineLearningRepository {
return formData;
}
async ocr(urls: string[], imagePath: string, { modelName, minScore }: OcrOptions) {
const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minScore } } } };
async ocr(urls: string[], imagePath: string, { modelName, minScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) {
const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minScore, unwarpingEnabled, orientationClassifyEnabled } } } };
const response = await this.predict<OcrResponse>(urls, { imagePath }, request);
return response[ModelTask.OCR];
}
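On the wire, the server-side change boils down to two extra booleans inside the OCR options. A sketch of the request body that ocr() now builds, written as a Python literal with illustrative values and with "ocr" again assumed as the serialized ModelTask.OCR / ModelType.OCR keys:

```python
# Mirrors the OcrRequest type above; option values match the new config defaults.
request = {
    "ocr": {
        "ocr": {
            "modelName": "PP-OCRv5_server",
            "options": {
                "minScore": 0.9,
                "unwarpingEnabled": False,
                "orientationClassifyEnabled": False,
            },
        },
    },
}
```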


@@ -3,72 +3,41 @@ import { Kysely, sql } from 'kysely';
import { InjectKysely } from 'nestjs-kysely';
import { DB } from 'src/db';
import { DummyValue, GenerateSql } from 'src/decorators';
import { OcrEntity } from 'src/entities/ocr.entity';
export interface OcrInsertData {
assetId: string;
boundingBoxX1: number;
boundingBoxY1: number;
boundingBoxX2: number;
boundingBoxY2: number;
text: string;
}
@Injectable()
export class OcrRepository {
constructor(@InjectKysely() private db: Kysely<DB>) {}
@GenerateSql({ params: [DummyValue.UUID] })
getOcrById(id: string): Promise<OcrEntity | null> {
async getById(id: string) {
return this.db
.selectFrom('asset_ocr')
.selectAll('asset_ocr')
.where('asset_ocr.assetId', '=', id)
.executeTakeFirst() as Promise<OcrEntity | null>;
.executeTakeFirst();
}
async insertOcrData(assetId: string, text: string): Promise<void> {
await this.db
.insertInto('asset_ocr')
.values({ assetId, text })
.execute();
}
async deleteAllOcr(): Promise<void> {
async deleteAll(): Promise<void> {
await sql`truncate ${sql.table('asset_ocr')}`.execute(this.db);
}
getAllOcr(options: Partial<OcrEntity> = {}): AsyncIterableIterator<OcrEntity> {
return this.db
.selectFrom('asset_ocr')
.selectAll('asset_ocr')
.$if(!!options.assetId, (qb) => qb.where('asset_ocr.assetId', '=', options.assetId!))
.stream() as AsyncIterableIterator<OcrEntity>;
async insertMany(ocrDataList: OcrInsertData[]): Promise<void> {
if (ocrDataList.length === 0) {
return;
}
@GenerateSql()
async getLatestOcrDate(): Promise<string | undefined> {
const result = (await this.db
.selectFrom('asset_job_status')
.select((eb) => sql`${eb.fn.max('asset_job_status.ocrAt')}::text`.as('latestDate'))
.executeTakeFirst()) as { latestDate: string } | undefined;
return result?.latestDate;
}
async updateOcrData(id: string, ocrData: string): Promise<void> {
await this.db
.updateTable('asset_ocr')
.set({ text: ocrData })
.where('id', '=', id)
.execute();
}
getOcrWithoutText(): Promise<OcrEntity[]> {
return this.db
.selectFrom('asset_ocr')
.selectAll('asset_ocr')
.where('text', 'is', null)
.execute() as Promise<OcrEntity[]>;
}
async delete(ocr: OcrEntity[]): Promise<void> {
await this.db
.deleteFrom('asset_ocr')
.where('id', 'in', ocr.map((o) => o.id))
.insertInto('asset_ocr')
.values(ocrDataList)
.execute();
}
}


@@ -321,19 +321,14 @@ export class SearchRepository {
throw new Error(`Invalid value for 'size': ${pagination.size}`);
}
const items = await this.db
.selectFrom('asset_ocr')
.selectAll()
.innerJoin('assets', 'assets.id', 'asset_ocr.assetId')
.where('assets.ownerId', '=', anyUuid(options.userIds))
const items = await searchAssetBuilder(this.db, options)
.innerJoin('asset_ocr', 'assets.id', 'asset_ocr.assetId')
.where('asset_ocr.text', 'ilike', `%${options.ocr}%`)
.limit(pagination.size + 1)
.offset((pagination.page - 1) * pagination.size)
.execute() as any;
.execute();
const hasNextPage = items.length > pagination.size;
items.splice(pagination.size);
return { items, hasNextPage };
return paginationHelper(items, pagination.size);
}
@GenerateSql({


@@ -13,12 +13,6 @@ import { isOcrEnabled } from 'src/utils/misc';
@Injectable()
export class OcrService extends BaseService {
@OnJob({ name: JobName.OCR_CLEANUP, queue: QueueName.BACKGROUND_TASK })
async handleOcrCleanup(): Promise<JobStatus> {
const ocr = await this.ocrRepository.getOcrWithoutText();
await this.ocrRepository.delete(ocr);
return JobStatus.SUCCESS;
}
@OnJob({ name: JobName.QUEUE_OCR, queue: QueueName.OCR })
async handleQueueOcr({ force, nightly }: JobOf<JobName.QUEUE_OCR>): Promise<JobStatus> {
@@ -28,7 +22,7 @@ }
}
if (force) {
await this.ocrRepository.deleteAllOcr();
await this.ocrRepository.deleteAll();
}
let jobs: JobItem[] = [];
@@ -44,11 +38,6 @@ }
}
await this.jobRepository.queueAll(jobs);
if (force === undefined) {
await this.jobRepository.queue({ name: JobName.OCR_CLEANUP });
}
return JobStatus.SUCCESS;
}
@@ -77,8 +66,15 @@ export class OcrService extends BaseService {
machineLearning.ocr
);
if (!ocrResults || ocrResults.text.length === 0) {
this.logger.warn(`No OCR results for document ${id}`);
const resultsArray = Array.isArray(ocrResults) ? ocrResults : [ocrResults];
const validResults = resultsArray.filter(result =>
result &&
result.text &&
result.text.trim().length > 0
);
if (validResults.length === 0) {
this.logger.warn(`No valid OCR results for document ${id}`);
await this.assetRepository.upsertJobStatus({
assetId: asset.id,
ocrAt: new Date(),
@@ -86,23 +82,29 @@ return JobStatus.SUCCESS;
return JobStatus.SUCCESS;
}
this.logger.debug(`OCR ${id} has OCR results`);
try {
const ocrDataList = validResults.map(result => ({
assetId: id,
boundingBoxX1: result.boundingBox.x1,
boundingBoxY1: result.boundingBox.y1,
boundingBoxX2: result.boundingBox.x2,
boundingBoxY2: result.boundingBox.y2,
text: result.text.trim(),
}));
const ocr = await this.ocrRepository.getOcrById(id);
if (ocr) {
this.logger.debug(`Updating OCR for ${id}`);
await this.ocrRepository.updateOcrData(id, ocrResults.text);
} else {
this.logger.debug(`Inserting OCR for ${id}`);
await this.ocrRepository.insertOcrData(id, ocrResults.text);
}
await this.ocrRepository.insertMany(ocrDataList);
await this.assetRepository.upsertJobStatus({
assetId: asset.id,
ocrAt: new Date(),
});
this.logger.debug(`Processed OCR for ${id}`);
this.logger.debug(`Processed ${validResults.length} OCR result(s) for ${id}`);
return JobStatus.SUCCESS;
} catch (error) {
this.logger.error(`Failed to insert OCR results for ${id}:`, error);
return JobStatus.FAILED;
}
}
}


@@ -23,7 +23,7 @@ import { AssetOrder, AssetVisibility, Permission } from 'src/enum';
import { BaseService } from 'src/services/base.service';
import { requireElevatedPermission } from 'src/utils/access';
import { getMyPartnerIds } from 'src/utils/asset.util';
import { isSmartSearchEnabled, isOcrEnabled } from 'src/utils/misc';
import { isOcrEnabled, isSmartSearchEnabled } from 'src/utils/misc';
@Injectable()
export class SearchService extends BaseService {


@@ -231,13 +231,30 @@
disabled={disabled || !config.machineLearning.enabled}
/>
<SettingSwitch
title={$t('admin.machine_learning_ocr_unwarping_enabled')}
subtitle={$t('admin.machine_learning_ocr_unwarping_enabled_description')}
bind:checked={config.machineLearning.ocr.unwarpingEnabled}
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
/>
<SettingSwitch
title={$t('admin.machine_learning_ocr_orientation_classify_enabled')}
subtitle={$t('admin.machine_learning_ocr_orientation_classify_enabled_description')}
bind:checked={config.machineLearning.ocr.orientationClassifyEnabled}
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
/>
<hr />
<SettingSelect
label={$t('admin.machine_learning_ocr_model')}
desc={$t('admin.machine_learning_ocr_model_description')}
name="ocr-model"
bind:value={config.machineLearning.ocr.modelName}
options={[
{ value: 'paddle', text: 'paddle' },
{ value: 'PP-OCRv5_server', text: 'PP-OCRv5_server' },
{ value: 'PP-OCRv5_mobile', text: 'PP-OCRv5_mobile' },
]}
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
isEdited={config.machineLearning.ocr.modelName !== savedConfig.machineLearning.ocr.modelName}
@@ -246,11 +263,13 @@
<SettingInputField
inputType={SettingInputFieldType.NUMBER}
label={$t('admin.machine_learning_ocr_min_score')}
description={$t('admin.machine_learning_ocr_min_score_description')}
bind:value={config.machineLearning.ocr.minScore}
step="0.1"
min={0.1}
max={1}
disabled={disabled || !config.machineLearning.enabled || !config.machineLearning.ocr.enabled}
isEdited={config.machineLearning.ocr.minScore !== savedConfig.machineLearning.ocr.minScore}
/>
</div>
</SettingAccordion>