fix(ml): armnn not being used (#10929)

* fix armnn not being used, move fallback handling to main, add tests

* formatting
This commit is contained in:
Mert 2024-07-10 10:20:43 -04:00 committed by GitHub
parent 59aa347912
commit f43721ec92
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 111 additions and 44 deletions

View file

@ -9,7 +9,7 @@ from numpy.typing import NDArray
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from PIL import Image
from app.config import clean_name, log
from app.config import log
from app.models.base import InferenceModel
from app.models.transforms import decode_cv2
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
@ -20,20 +20,14 @@ class FaceRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
def __init__(
self,
model_name: str,
min_score: float = 0.7,
cache_dir: Path | str | None = None,
**model_kwargs: Any,
) -> None:
super().__init__(clean_name(model_name), cache_dir, **model_kwargs)
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs)
self.min_score = model_kwargs.pop("minScore", min_score)
self.batch = self.model_format == ModelFormat.ONNX
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if self.model_format == ModelFormat.ONNX and not has_batch_axis(session):
if self.batch and not has_batch_axis(session):
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(