mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
change dto
This commit is contained in:
parent
c59f932bf0
commit
412468989f
12 changed files with 93 additions and 98 deletions
|
|
@ -12,7 +12,8 @@ from rapidocr.utils.typings import ModelType as RapidModelType
|
|||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.schemas import ModelSession, ModelTask, ModelType
|
||||
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
|
||||
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
|
||||
|
||||
|
|
@ -29,7 +30,7 @@ class TextRecognizer(InferenceModel):
|
|||
"text": [],
|
||||
"textScore": [],
|
||||
}
|
||||
super().__init__(model_name, **model_kwargs)
|
||||
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
|
||||
|
||||
def _download(self) -> None:
|
||||
model_info = InferSession.get_model_url(
|
||||
|
|
@ -50,7 +51,8 @@ class TextRecognizer(InferenceModel):
|
|||
DownloadFile.run(download_params)
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
session = self._make_session(self.model_path)
|
||||
# TODO: support other runtimes
|
||||
session = OrtSession(self.model_path)
|
||||
self.model = RapidTextRecognizer(
|
||||
OcrOptions(
|
||||
session=session.session,
|
||||
|
|
@ -80,7 +82,7 @@ class TextRecognizer(InferenceModel):
|
|||
valid_text_score_idx = text_scores > 0.5
|
||||
valid_score_idx_list = valid_text_score_idx.tolist()
|
||||
return {
|
||||
"box": boxes.reshape(-1, 8)[valid_text_score_idx],
|
||||
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
|
||||
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
|
||||
"boxScore": box_scores[valid_text_score_idx],
|
||||
"textScore": text_scores[valid_text_score_idx],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue