change dto

This commit is contained in:
mertalev 2025-06-13 00:39:39 -04:00
parent c59f932bf0
commit 412468989f
No known key found for this signature in database
GPG key ID: DF6ABC77AAD98C95
12 changed files with 93 additions and 98 deletions

View file

@@ -12,7 +12,8 @@ from rapidocr.utils.typings import ModelType as RapidModelType
from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelSession, ModelTask, ModelType
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
@@ -29,7 +30,7 @@ class TextRecognizer(InferenceModel):
"text": [],
"textScore": [],
}
super().__init__(model_name, **model_kwargs)
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
def _download(self) -> None:
model_info = InferSession.get_model_url(
@@ -50,7 +51,8 @@ class TextRecognizer(InferenceModel):
DownloadFile.run(download_params)
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
# TODO: support other runtimes
session = OrtSession(self.model_path)
self.model = RapidTextRecognizer(
OcrOptions(
session=session.session,
@@ -80,7 +82,7 @@ class TextRecognizer(InferenceModel):
valid_text_score_idx = text_scores > 0.5
valid_score_idx_list = valid_text_score_idx.tolist()
return {
"box": boxes.reshape(-1, 8)[valid_text_score_idx],
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
"boxScore": box_scores[valid_text_score_idx],
"textScore": text_scores[valid_text_score_idx],