mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
change dto
This commit is contained in:
parent
c59f932bf0
commit
412468989f
12 changed files with 93 additions and 98 deletions
|
|
@ -12,7 +12,8 @@ from rapidocr.utils.typings import ModelType as RapidModelType
|
||||||
|
|
||||||
from immich_ml.config import log, settings
|
from immich_ml.config import log, settings
|
||||||
from immich_ml.models.base import InferenceModel
|
from immich_ml.models.base import InferenceModel
|
||||||
from immich_ml.schemas import ModelSession, ModelTask, ModelType
|
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
|
||||||
|
from immich_ml.sessions.ort import OrtSession
|
||||||
|
|
||||||
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
|
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
|
||||||
|
|
||||||
|
|
@ -29,7 +30,7 @@ class TextRecognizer(InferenceModel):
|
||||||
"text": [],
|
"text": [],
|
||||||
"textScore": [],
|
"textScore": [],
|
||||||
}
|
}
|
||||||
super().__init__(model_name, **model_kwargs)
|
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
|
||||||
|
|
||||||
def _download(self) -> None:
|
def _download(self) -> None:
|
||||||
model_info = InferSession.get_model_url(
|
model_info = InferSession.get_model_url(
|
||||||
|
|
@ -50,7 +51,8 @@ class TextRecognizer(InferenceModel):
|
||||||
DownloadFile.run(download_params)
|
DownloadFile.run(download_params)
|
||||||
|
|
||||||
def _load(self) -> ModelSession:
|
def _load(self) -> ModelSession:
|
||||||
session = self._make_session(self.model_path)
|
# TODO: support other runtimes
|
||||||
|
session = OrtSession(self.model_path)
|
||||||
self.model = RapidTextRecognizer(
|
self.model = RapidTextRecognizer(
|
||||||
OcrOptions(
|
OcrOptions(
|
||||||
session=session.session,
|
session=session.session,
|
||||||
|
|
@ -80,7 +82,7 @@ class TextRecognizer(InferenceModel):
|
||||||
valid_text_score_idx = text_scores > 0.5
|
valid_text_score_idx = text_scores > 0.5
|
||||||
valid_score_idx_list = valid_text_score_idx.tolist()
|
valid_score_idx_list = valid_text_score_idx.tolist()
|
||||||
return {
|
return {
|
||||||
"box": boxes.reshape(-1, 8)[valid_text_score_idx],
|
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
|
||||||
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
|
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
|
||||||
"boxScore": box_scores[valid_text_score_idx],
|
"boxScore": box_scores[valid_text_score_idx],
|
||||||
"textScore": text_scores[valid_text_score_idx],
|
"textScore": text_scores[valid_text_score_idx],
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from typing_extensions import TypedDict
|
||||||
class TextDetectionOutput(TypedDict):
|
class TextDetectionOutput(TypedDict):
|
||||||
resized: npt.NDArray[np.float32]
|
resized: npt.NDArray[np.float32]
|
||||||
boxes: npt.NDArray[np.float32]
|
boxes: npt.NDArray[np.float32]
|
||||||
scores: Iterable[float]
|
scores: npt.NDArray[np.float32]
|
||||||
|
|
||||||
|
|
||||||
class TextRecognitionOutput(TypedDict):
|
class TextRecognitionOutput(TypedDict):
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ from ..config import log, settings
|
||||||
|
|
||||||
|
|
||||||
class OrtSession:
|
class OrtSession:
|
||||||
|
session: ort.InferenceSession
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: Path | str,
|
model_path: Path | str,
|
||||||
|
|
|
||||||
|
|
@ -72,11 +72,9 @@ export interface SystemConfig {
|
||||||
ocr: {
|
ocr: {
|
||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
modelName: string;
|
modelName: string;
|
||||||
minDetectionBoxScore: number;
|
|
||||||
minDetectionScore: number;
|
minDetectionScore: number;
|
||||||
minRecognitionScore: number;
|
minRecognitionScore: number;
|
||||||
unwarpingEnabled: boolean;
|
maxResolution: number;
|
||||||
orientationClassifyEnabled: boolean;
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
map: {
|
map: {
|
||||||
|
|
@ -255,11 +253,9 @@ export const defaults = Object.freeze<SystemConfig>({
|
||||||
ocr: {
|
ocr: {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
modelName: 'PP-OCRv5_server',
|
modelName: 'PP-OCRv5_server',
|
||||||
minDetectionBoxScore: 0.6,
|
|
||||||
minDetectionScore: 0.3,
|
minDetectionScore: 0.3,
|
||||||
minRecognitionScore: 0.0,
|
minRecognitionScore: 0.0,
|
||||||
unwarpingEnabled: false,
|
maxResolution: 1440,
|
||||||
orientationClassifyEnabled: false,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
map: {
|
map: {
|
||||||
|
|
|
||||||
|
|
@ -49,29 +49,22 @@ export class FacialRecognitionConfig extends ModelConfig {
|
||||||
|
|
||||||
export class OcrConfig extends ModelConfig {
|
export class OcrConfig extends ModelConfig {
|
||||||
@IsNumber()
|
@IsNumber()
|
||||||
@Min(0)
|
@Min(1)
|
||||||
@Max(1)
|
|
||||||
@Type(() => Number)
|
@Type(() => Number)
|
||||||
@ApiProperty({ type: 'number', format: 'double' })
|
@ApiProperty({ type: 'integer' })
|
||||||
minDetectionBoxScore!: number;
|
maxResolution!: number;
|
||||||
|
|
||||||
@IsNumber()
|
@IsNumber()
|
||||||
@Min(0)
|
@Min(0.1)
|
||||||
@Max(1)
|
@Max(1)
|
||||||
@Type(() => Number)
|
@Type(() => Number)
|
||||||
@ApiProperty({ type: 'number', format: 'double' })
|
@ApiProperty({ type: 'number', format: 'double' })
|
||||||
minDetectionScore!: number;
|
minDetectionScore!: number;
|
||||||
|
|
||||||
@IsNumber()
|
@IsNumber()
|
||||||
@Min(0)
|
@Min(0.1)
|
||||||
@Max(1)
|
@Max(1)
|
||||||
@Type(() => Number)
|
@Type(() => Number)
|
||||||
@ApiProperty({ type: 'number', format: 'double' })
|
@ApiProperty({ type: 'number', format: 'double' })
|
||||||
minRecognitionScore!: number;
|
minRecognitionScore!: number;
|
||||||
|
|
||||||
@ValidateBoolean()
|
|
||||||
unwarpingEnabled!: boolean;
|
|
||||||
|
|
||||||
@ValidateBoolean()
|
|
||||||
orientationClassifyEnabled!: boolean;
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,11 @@ export type ModelPayload = { imagePath: string } | { text: string };
|
||||||
type ModelOptions = { modelName: string };
|
type ModelOptions = { modelName: string };
|
||||||
|
|
||||||
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
export type FaceDetectionOptions = ModelOptions & { minScore: number };
|
||||||
export type OcrOptions = ModelOptions & { minDetectionBoxScore: number, minDetectionScore: number, minRecognitionScore: number, unwarpingEnabled: boolean, orientationClassifyEnabled: boolean };
|
export type OcrOptions = ModelOptions & {
|
||||||
|
minDetectionScore: number;
|
||||||
|
minRecognitionScore: number;
|
||||||
|
maxResolution: number;
|
||||||
|
};
|
||||||
type VisualResponse = { imageHeight: number; imageWidth: number };
|
type VisualResponse = { imageHeight: number; imageWidth: number };
|
||||||
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
export type ClipVisualRequest = { [ModelTask.SEARCH]: { [ModelType.VISUAL]: ModelOptions } };
|
||||||
export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
|
export type ClipVisualResponse = { [ModelTask.SEARCH]: string } & VisualResponse;
|
||||||
|
|
@ -40,20 +44,19 @@ export type ClipTextualRequest = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: Mo
|
||||||
export type ClipTextualResponse = { [ModelTask.SEARCH]: string };
|
export type ClipTextualResponse = { [ModelTask.SEARCH]: string };
|
||||||
|
|
||||||
export type OCR = {
|
export type OCR = {
|
||||||
x1: number;
|
text: string[];
|
||||||
y1: number;
|
box: number[];
|
||||||
x2: number;
|
boxScore: number[];
|
||||||
y2: number;
|
textScore: number[];
|
||||||
x3: number;
|
|
||||||
y3: number;
|
|
||||||
x4: number;
|
|
||||||
y4: number;
|
|
||||||
text: string;
|
|
||||||
confidence: number;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export type OcrRequest = { [ModelTask.OCR]: { [ModelType.OCR]: ModelOptions & { options: { minDetectionScore: number, minRecognitionScore: number } } } };
|
export type OcrRequest = {
|
||||||
export type OcrResponse = { [ModelTask.OCR]: OCR[] } & VisualResponse;
|
[ModelTask.OCR]: {
|
||||||
|
[ModelType.DETECTION]: ModelOptions & { options: { minScore: number; maxResolution: number } };
|
||||||
|
[ModelType.RECOGNITION]: ModelOptions & { options: { minScore: number } };
|
||||||
|
};
|
||||||
|
};
|
||||||
|
export type OcrResponse = { [ModelTask.OCR]: OCR } & VisualResponse;
|
||||||
|
|
||||||
export type FacialRecognitionRequest = {
|
export type FacialRecognitionRequest = {
|
||||||
[ModelTask.FACIAL_RECOGNITION]: {
|
[ModelTask.FACIAL_RECOGNITION]: {
|
||||||
|
|
@ -211,8 +214,17 @@ export class MachineLearningRepository {
|
||||||
return formData;
|
return formData;
|
||||||
}
|
}
|
||||||
|
|
||||||
async ocr(urls: string[], imagePath: string, { modelName, minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled }: OcrOptions) {
|
async ocr(
|
||||||
const request = { [ModelTask.OCR]: { [ModelType.OCR]: { modelName, options: { minDetectionBoxScore, minDetectionScore, minRecognitionScore, unwarpingEnabled, orientationClassifyEnabled } } } };
|
urls: string[],
|
||||||
|
imagePath: string,
|
||||||
|
{ modelName, minDetectionScore, minRecognitionScore, maxResolution }: OcrOptions,
|
||||||
|
) {
|
||||||
|
const request = {
|
||||||
|
[ModelTask.OCR]: {
|
||||||
|
[ModelType.DETECTION]: { modelName, options: { minScore: minDetectionScore, maxResolution } },
|
||||||
|
[ModelType.RECOGNITION]: { modelName, options: { minScore: minRecognitionScore } },
|
||||||
|
},
|
||||||
|
};
|
||||||
const response = await this.predict<OcrResponse>(urls, { imagePath }, request);
|
const response = await this.predict<OcrResponse>(urls, { imagePath }, request);
|
||||||
return response[ModelTask.OCR];
|
return response[ModelTask.OCR];
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,9 @@
|
||||||
import { Injectable } from '@nestjs/common';
|
import { Injectable } from '@nestjs/common';
|
||||||
import { Kysely, sql } from 'kysely';
|
import { Insertable, Kysely, sql } from 'kysely';
|
||||||
import { InjectKysely } from 'nestjs-kysely';
|
import { InjectKysely } from 'nestjs-kysely';
|
||||||
import { DB } from 'src/db';
|
import { AssetOcr, DB } from 'src/db';
|
||||||
import { DummyValue, GenerateSql } from 'src/decorators';
|
import { DummyValue, GenerateSql } from 'src/decorators';
|
||||||
|
|
||||||
export interface OcrInsertData {
|
|
||||||
assetId: string;
|
|
||||||
x1: number;
|
|
||||||
y1: number;
|
|
||||||
x2: number;
|
|
||||||
y2: number;
|
|
||||||
x3: number;
|
|
||||||
y3: number;
|
|
||||||
x4: number;
|
|
||||||
y4: number;
|
|
||||||
text: string;
|
|
||||||
confidence: number;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Injectable()
|
@Injectable()
|
||||||
export class OcrRepository {
|
export class OcrRepository {
|
||||||
constructor(@InjectKysely() private db: Kysely<DB>) {}
|
constructor(@InjectKysely() private db: Kysely<DB>) {}
|
||||||
|
|
@ -54,28 +40,26 @@ export class OcrRepository {
|
||||||
x4: DummyValue.NUMBER,
|
x4: DummyValue.NUMBER,
|
||||||
y4: DummyValue.NUMBER,
|
y4: DummyValue.NUMBER,
|
||||||
text: DummyValue.STRING,
|
text: DummyValue.STRING,
|
||||||
confidence: DummyValue.NUMBER,
|
boxScore: DummyValue.NUMBER,
|
||||||
|
textScore: DummyValue.NUMBER,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
upsert(assetId: string, ocrDataList: OcrInsertData[]) {
|
upsert(assetId: string, ocrDataList: Insertable<AssetOcr>[]) {
|
||||||
if (ocrDataList.length === 0) {
|
let query = this.db.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId));
|
||||||
return;
|
if (ocrDataList.length > 0) {
|
||||||
|
const searchText = ocrDataList.map((item) => item.text.trim()).join(' ');
|
||||||
|
(query as any) = query
|
||||||
|
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
|
||||||
|
.with('inserted_search', (db) =>
|
||||||
|
db
|
||||||
|
.insertInto('ocr_search')
|
||||||
|
.values({ assetId, text: searchText })
|
||||||
|
.onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const searchText = ocrDataList.map((item) => item.text.trim()).join(' ');
|
return query.selectNoFrom(sql`1`.as('dummy')).execute();
|
||||||
|
|
||||||
return this.db
|
|
||||||
.with('deleted_ocr', (db) => db.deleteFrom('asset_ocr').where('assetId', '=', assetId))
|
|
||||||
.with('inserted_ocr', (db) => db.insertInto('asset_ocr').values(ocrDataList))
|
|
||||||
.with('inserted_search', (db) =>
|
|
||||||
db
|
|
||||||
.insertInto('ocr_search')
|
|
||||||
.values({ assetId, text: searchText })
|
|
||||||
.onConflict((oc) => oc.column('assetId').doUpdateSet((eb) => ({ text: eb.ref('excluded.text') }))),
|
|
||||||
)
|
|
||||||
.selectNoFrom(sql`1`.as('dummy'))
|
|
||||||
.execute();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import { Kysely, sql } from 'kysely';
|
import { Kysely, sql } from 'kysely';
|
||||||
|
|
||||||
export async function up(db: Kysely<any>): Promise<void> {
|
export async function up(db: Kysely<any>): Promise<void> {
|
||||||
await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT uuid_generate_v4(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "text" text NOT NULL, "confidence" real NOT NULL);`.execute(
|
await sql`CREATE TABLE "asset_ocr" ("id" uuid NOT NULL DEFAULT immich_uuid_v7(), "assetId" uuid NOT NULL, "x1" integer NOT NULL, "y1" integer NOT NULL, "x2" integer NOT NULL, "y2" integer NOT NULL, "x3" integer NOT NULL, "y3" integer NOT NULL, "x4" integer NOT NULL, "y4" integer NOT NULL, "boxScore" real NOT NULL, "textScore" real NOT NULL, "text" text NOT NULL);`.execute(
|
||||||
db,
|
db,
|
||||||
);
|
);
|
||||||
await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db);
|
await sql`ALTER TABLE "asset_ocr" ADD CONSTRAINT "PK_5c37b36ceef9ac1f688b6c6bf22" PRIMARY KEY ("id");`.execute(db);
|
||||||
|
|
@ -6,13 +6,12 @@ export class AssetOcrTable {
|
||||||
@PrimaryGeneratedColumn()
|
@PrimaryGeneratedColumn()
|
||||||
id!: string;
|
id!: string;
|
||||||
|
|
||||||
@ForeignKeyColumn(() => AssetTable, {
|
@ForeignKeyColumn(() => AssetTable, { onDelete: 'CASCADE', onUpdate: 'CASCADE' })
|
||||||
onDelete: 'CASCADE',
|
|
||||||
onUpdate: 'CASCADE',
|
|
||||||
index: true,
|
|
||||||
})
|
|
||||||
assetId!: string;
|
assetId!: string;
|
||||||
|
|
||||||
|
@Column({ type: 'text' })
|
||||||
|
text!: string;
|
||||||
|
|
||||||
@Column({ type: 'integer' })
|
@Column({ type: 'integer' })
|
||||||
x1!: number;
|
x1!: number;
|
||||||
|
|
||||||
|
|
@ -37,9 +36,9 @@ export class AssetOcrTable {
|
||||||
@Column({ type: 'integer' })
|
@Column({ type: 'integer' })
|
||||||
y4!: number;
|
y4!: number;
|
||||||
|
|
||||||
@Column({ type: 'text' })
|
@Column({ type: 'real' })
|
||||||
text!: string;
|
boxScore!: number;
|
||||||
|
|
||||||
@Column({ type: 'real' })
|
@Column({ type: 'real' })
|
||||||
confidence!: number;
|
textScore!: number;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ import { Injectable } from '@nestjs/common';
|
||||||
import { JOBS_ASSET_PAGINATION_SIZE } from 'src/constants';
|
import { JOBS_ASSET_PAGINATION_SIZE } from 'src/constants';
|
||||||
import { OnJob } from 'src/decorators';
|
import { OnJob } from 'src/decorators';
|
||||||
import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum';
|
import { AssetVisibility, JobName, JobStatus, QueueName } from 'src/enum';
|
||||||
|
import { OCR } from 'src/repositories/machine-learning.repository';
|
||||||
import { BaseService } from 'src/services/base.service';
|
import { BaseService } from 'src/services/base.service';
|
||||||
import { JobItem, JobOf } from 'src/types';
|
import { JobItem, JobOf } from 'src/types';
|
||||||
import { isOcrEnabled } from 'src/utils/misc';
|
import { isOcrEnabled } from 'src/utils/misc';
|
||||||
|
|
@ -57,27 +58,34 @@ export class OcrService extends BaseService {
|
||||||
machineLearning.ocr,
|
machineLearning.ocr,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (ocrResults.length > 0) {
|
await this.ocrRepository.upsert(id, this.parseOcrResults(id, ocrResults));
|
||||||
const ocrDataList = ocrResults.map((result) => ({
|
|
||||||
assetId: id,
|
|
||||||
x1: result.x1,
|
|
||||||
y1: result.y1,
|
|
||||||
x2: result.x2,
|
|
||||||
y2: result.y2,
|
|
||||||
x3: result.x3,
|
|
||||||
y3: result.y3,
|
|
||||||
x4: result.x4,
|
|
||||||
y4: result.y4,
|
|
||||||
text: result.text,
|
|
||||||
confidence: result.confidence,
|
|
||||||
}));
|
|
||||||
|
|
||||||
await this.ocrRepository.upsert(id, ocrDataList);
|
|
||||||
}
|
|
||||||
|
|
||||||
await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });
|
await this.assetRepository.upsertJobStatus({ assetId: id, ocrAt: new Date() });
|
||||||
|
|
||||||
this.logger.debug(`Processed ${ocrResults.length} OCR result(s) for ${id}`);
|
this.logger.debug(`Processed ${ocrResults.text.length} OCR result(s) for ${id}`);
|
||||||
return JobStatus.SUCCESS;
|
return JobStatus.SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parseOcrResults(id: string, ocrResults: OCR) {
|
||||||
|
const ocrDataList = [];
|
||||||
|
for (let i = 0; i < ocrResults.text.length; i++) {
|
||||||
|
const boxOffset = i * 8;
|
||||||
|
const row = {
|
||||||
|
assetId: id,
|
||||||
|
text: ocrResults.text[i],
|
||||||
|
boxScore: ocrResults.boxScore[i],
|
||||||
|
textScore: ocrResults.textScore[i],
|
||||||
|
x1: ocrResults.box[boxOffset],
|
||||||
|
y1: ocrResults.box[boxOffset + 1],
|
||||||
|
x2: ocrResults.box[boxOffset + 2],
|
||||||
|
y2: ocrResults.box[boxOffset + 3],
|
||||||
|
x3: ocrResults.box[boxOffset + 4],
|
||||||
|
y3: ocrResults.box[boxOffset + 5],
|
||||||
|
x4: ocrResults.box[boxOffset + 6],
|
||||||
|
y4: ocrResults.box[boxOffset + 7],
|
||||||
|
};
|
||||||
|
ocrDataList.push(row);
|
||||||
|
}
|
||||||
|
return ocrDataList;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue