This commit is contained in:
mertalev 2025-10-15 14:35:52 -04:00
parent 5c044fb853
commit 44f90d42f8
No known key found for this signature in database
GPG key ID: DF6ABC77AAD98C95
3 changed files with 9 additions and 8 deletions

View file

@ -2,6 +2,7 @@ from typing import Any
import cv2 import cv2
import numpy as np import numpy as np
from numpy.typing import NDArray
from PIL.Image import Image from PIL.Image import Image
from rapidocr.ch_ppocr_rec import TextRecInput from rapidocr.ch_ppocr_rec import TextRecInput
from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer
@ -84,7 +85,7 @@ class TextRecognizer(InferenceModel):
"textScore": text_scores[valid_text_score_idx], "textScore": text_scores[valid_text_score_idx],
} }
def get_crop_img_list(self, img: np.ndarray, boxes: np.ndarray) -> list[np.ndarray]: def get_crop_img_list(self, img: NDArray[np.float32], boxes: NDArray[np.float32]) -> list[NDArray[np.float32]]:
img_crop_width = np.maximum( img_crop_width = np.maximum(
np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1) np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1)
).astype(np.int32) ).astype(np.int32)
@ -96,16 +97,16 @@ class TextRecognizer(InferenceModel):
pts_std[:, 2:4, 1] = img_crop_height[:, None] pts_std[:, 2:4, 1] = img_crop_height[:, None]
img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1).tolist() img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1).tolist()
imgs = [] imgs: list[NDArray[np.float32]] = []
for box, pts_std, dst_size in zip(list(boxes), list(pts_std), img_crop_sizes): for box, pts_std, dst_size in zip(list(boxes), list(pts_std), img_crop_sizes):
M = cv2.getPerspectiveTransform(box, pts_std) M = cv2.getPerspectiveTransform(box, pts_std)
dst_img = cv2.warpPerspective( dst_img: NDArray[np.float32] = cv2.warpPerspective(
img, img,
M, M,
dst_size, dst_size,
borderMode=cv2.BORDER_REPLICATE, borderMode=cv2.BORDER_REPLICATE,
flags=cv2.INTER_CUBIC, flags=cv2.INTER_CUBIC,
) ) # type: ignore
dst_height, dst_width = dst_img.shape[0:2] dst_height, dst_width = dst_img.shape[0:2]
if dst_height * 1.0 / dst_width >= 1.5: if dst_height * 1.0 / dst_width >= 1.5:
dst_img = np.rot90(dst_img) dst_img = np.rot90(dst_img)

View file

@ -1,4 +1,4 @@
from typing import Iterable from typing import Any, Iterable
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -20,8 +20,8 @@ class TextRecognitionOutput(TypedDict):
# RapidOCR expects `engine_type`, `lang_type`, and `font_path` to be attributes # RapidOCR expects `engine_type`, `lang_type`, and `font_path` to be attributes
class OcrOptions(dict): class OcrOptions(dict[str, Any]):
def __init__(self, **options): def __init__(self, **options: Any) -> None:
super().__init__(**options) super().__init__(**options)
self.engine_type = EngineType.ONNXRUNTIME self.engine_type = EngineType.ONNXRUNTIME
self.lang_type = LangRec.CH self.lang_type = LangRec.CH

View file

@ -1082,7 +1082,7 @@ wheels = [
[[package]] [[package]]
name = "immich-ml" name = "immich-ml"
version = "1.129.0" version = "2.0.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "aiocache" }, { name = "aiocache" },