feat(ml): configurable batch size for facial recognition (#13689)

* configurable batch size, default openvino to 1

* update docs

* don't add a new dependency for two lines

* fix typing
This commit is contained in:
Mert 2024-10-23 08:50:28 -04:00 committed by GitHub
parent a76c39812f
commit 1ec9a60e41
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 70 additions and 31 deletions

View file

@ -3,13 +3,14 @@ from typing import Any
import numpy as np
import onnx
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX
from insightface.utils.face_align import norm_crop
from numpy.typing import NDArray
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from PIL import Image
from app.config import log
from app.config import log, settings
from app.models.base import InferenceModel
from app.models.transforms import decode_cv2
from app.schemas import FaceDetectionOutput, FacialRecognitionOutput, ModelFormat, ModelSession, ModelTask, ModelType
@ -22,11 +23,12 @@ class FaceRecognizer(InferenceModel):
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs)
self.min_score = model_kwargs.pop("minScore", min_score)
self.batch = self.model_format == ModelFormat.ONNX
max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
self.batch_size = max_batch_size if max_batch_size else self._batch_size_default
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if self.batch and str(session.get_inputs()[0].shape[0]) != "batch":
if (not self.batch_size or self.batch_size > 1) and str(session.get_inputs()[0].shape[0]) != "batch":
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(
@ -42,18 +44,18 @@ class FaceRecognizer(InferenceModel):
return []
inputs = decode_cv2(inputs)
cropped_faces = self._crop(inputs, faces)
embeddings = self._predict_batch(cropped_faces) if self.batch else self._predict_single(cropped_faces)
embeddings = self._predict_batch(cropped_faces)
return self.postprocess(faces, embeddings)
def _predict_batch(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
return embeddings
if not self.batch_size or len(cropped_faces) <= self.batch_size:
embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
return embeddings
def _predict_single(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
embeddings: list[NDArray[np.float32]] = []
for face in cropped_faces:
embeddings.append(self.model.get_feat(face))
return np.concatenate(embeddings, axis=0)
batch_embeddings: list[NDArray[np.float32]] = []
for i in range(0, len(cropped_faces), self.batch_size):
batch_embeddings.append(self.model.get_feat(cropped_faces[i : i + self.batch_size]))
return np.concatenate(batch_embeddings, axis=0)
def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
return [
@ -77,3 +79,8 @@ class FaceRecognizer(InferenceModel):
output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
onnx.save(updated_proto, model_path)
@property
def _batch_size_default(self) -> int | None:
providers = ort.get_available_providers()
return None if self.model_format == ModelFormat.ONNX and "OpenVINOExecutionProvider" not in providers else 1