feat(ml): ARMNN acceleration (#5667)

* feat(ml): ARMNN acceleration for CLIP

* wrap ANN as ONNX-Session

* strict typing

* normalize ARMNN CLIP embedding

* mutex to handle concurrent execution

* make inputs contiguous

* fine-grained locking; concurrent network execution (see the illustrative sketch below)
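The mutex and fine-grained locking changes live in the Ann wrapper that AnnSession delegates to, which is not among the excerpted files below. As a rough illustration only (hypothetical names, not the actual implementation), fine-grained locking here means one mutex per loaded network, so different networks can execute concurrently while calls into the same network stay serialized:

import threading
from typing import Callable, TypeVar

T = TypeVar("T")


class PerModelLocks:
    """Illustrative sketch: one lock per loaded network instead of a single global mutex."""

    def __init__(self) -> None:
        self._locks: dict[int, threading.Lock] = {}

    def register(self, model_id: int) -> None:
        self._locks[model_id] = threading.Lock()

    def execute(self, model_id: int, run: Callable[[], T]) -> T:
        # Serializes calls into one network while other networks keep running in parallel.
        with self._locks[model_id]:
            return run()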

---------

Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
Authored by Fynn Petersen-Frey on 2024-01-11 18:26:46 +01:00, committed via GitHub
commit 753292956e (parent 29747437f6)
17 changed files with 956 additions and 44 deletions

app/config.py

@@ -26,6 +26,7 @@ class Settings(BaseSettings):
    request_threads: int = os.cpu_count() or 4
    model_inter_op_threads: int = 1
    model_intra_op_threads: int = 2
    ann: bool = True

    class Config:
        env_prefix = "MACHINE_LEARNING_"
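Because of the env_prefix above, the new ann flag maps to the MACHINE_LEARNING_ANN environment variable, so ARM NN acceleration can be switched off per deployment without a code change. A minimal illustration (the app.config import path is assumed from the surrounding diff):

import os

os.environ["MACHINE_LEARNING_ANN"] = "false"  # opt out of ARM NN acceleration

from app.config import Settings

assert Settings().ann is False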

app/models/ann.py

@@ -0,0 +1,68 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

from numpy import ascontiguousarray

from ann.ann import Ann
from app.schemas import ndarray_f32, ndarray_i32

from ..config import log, settings


class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.
    """

    def __init__(self, model_path: Path):
        tuning_file = Path(settings.cache_folder) / "gpu-tuning.ann"
        with tuning_file.open(mode="a"):
            # make sure tuning file exists (without clearing contents)
            # once filled, the tuning file reduces the cost/time of the first
            # inference after model load by 10s of seconds
            pass
        self.ann = Ann(tuning_level=3, tuning_file=tuning_file.as_posix())
        log.info("Loading ANN model %s ...", model_path)
        cache_file = model_path.with_suffix(".anncache")
        save = False
        if not cache_file.is_file():
            save = True
            with cache_file.open(mode="a"):
                # create empty model cache file
                pass

        self.model = self.ann.load(
            model_path.as_posix(),
            save_cached_network=save,
            cached_network_path=cache_file.as_posix(),
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        self.ann.unload(self.model)
        log.info("Unloaded ANN model %d", self.model)
        self.ann.destroy()

    def get_inputs(self) -> list[AnnNode]:
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[AnnNode]:
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, ndarray_f32] | dict[str, ndarray_i32],
        run_options: Any = None,
    ) -> list[ndarray_f32]:
        inputs: list[ndarray_f32] = [ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)


class AnnNode(NamedTuple):
    name: str | None
    shape: tuple[int, ...]
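Illustrative usage of the wrapper above (not part of the diff): AnnSession mirrors the small subset of the onnxruntime.InferenceSession interface that the model classes call, so it can be swapped in transparently. The .armnn path, input key, and input shape are placeholders.

from pathlib import Path

import numpy as np

session = AnnSession(Path("/cache/clip/ViT-B-32__openai/visual.armnn"))
print([node.shape for node in session.get_inputs()])

image = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {"pixel_values": image})  # output_names and input keys are not used
embedding = outputs[0]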

app/models/base.py

@@ -10,8 +10,11 @@ import onnxruntime as ort
from huggingface_hub import snapshot_download
from typing_extensions import Buffer

import ann.ann

from ..config import get_cache_dir, get_hf_model_name, log, settings
from ..schemas import ModelType
from .ann import AnnSession


class InferenceModel(ABC):
@@ -138,6 +141,21 @@ class InferenceModel(ABC):
            self.cache_dir.unlink()
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
        armnn_path = model_path.with_suffix(".armnn")
        if settings.ann and ann.ann.is_available and armnn_path.is_file():
            session = AnnSession(armnn_path)
        elif model_path.is_file():
            session = ort.InferenceSession(
                model_path.as_posix(),
                sess_options=self.sess_options,
                providers=self.providers,
                provider_options=self.provider_options,
            )
        else:
            raise ValueError(f"the file model_path='{model_path}' does not exist")
        return session


# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions):  # type: ignore[misc]
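As a concrete note on the lookup above (paths are placeholders): an ARM NN build of a model is expected to sit next to its ONNX export, differing only in suffix, and ONNX Runtime remains the fallback whenever the .armnn file or the ANN runtime is unavailable.

from pathlib import Path

model_path = Path("/cache/clip/ViT-B-32__openai/visual.onnx")
print(model_path.with_suffix(".armnn"))  # /cache/clip/ViT-B-32__openai/visual.armnn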

app/models/clip.py

@@ -6,7 +6,6 @@ from pathlib import Path
from typing import Any, Literal

import numpy as np
-import onnxruntime as ort
from PIL import Image
from tokenizers import Encoding, Tokenizer
@@ -33,24 +32,12 @@ class BaseCLIPEncoder(InferenceModel):
    def _load(self) -> None:
        if self.mode == "text" or self.mode is None:
            log.debug(f"Loading clip text model '{self.model_name}'")
-            self.text_model = ort.InferenceSession(
-                self.textual_path.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            )
+            self.text_model = self._make_session(self.textual_path)
            log.debug(f"Loaded clip text model '{self.model_name}'")

        if self.mode == "vision" or self.mode is None:
            log.debug(f"Loading clip vision model '{self.model_name}'")
-            self.vision_model = ort.InferenceSession(
-                self.visual_path.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            )
+            self.vision_model = self._make_session(self.visual_path)
            log.debug(f"Loaded clip vision model '{self.model_name}'")

    def _predict(self, image_or_text: Image.Image | str) -> ndarray_f32:
@@ -61,12 +48,10 @@ class BaseCLIPEncoder(InferenceModel):
            case Image.Image():
                if self.mode == "text":
                    raise TypeError("Cannot encode image as text-only model")
                outputs: ndarray_f32 = self.vision_model.run(None, self.transform(image_or_text))[0][0]
            case str():
                if self.mode == "vision":
                    raise TypeError("Cannot encode text as vision-only model")
                outputs = self.text_model.run(None, self.tokenize(image_or_text))[0][0]
            case _:
                raise TypeError(f"Expected Image or str, but got: {type(image_or_text)}")

app/models/facial_recognition.py

@@ -3,7 +3,6 @@ from typing import Any
import cv2
import numpy as np
-import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX, RetinaFace
from insightface.utils.face_align import norm_crop
@@ -27,23 +26,8 @@ class FaceRecognizer(InferenceModel):
        super().__init__(clean_name(model_name), cache_dir, **model_kwargs)

    def _load(self) -> None:
-        self.det_model = RetinaFace(
-            session=ort.InferenceSession(
-                self.det_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
-        self.rec_model = ArcFaceONNX(
-            self.rec_file.as_posix(),
-            session=ort.InferenceSession(
-                self.rec_file.as_posix(),
-                sess_options=self.sess_options,
-                providers=self.providers,
-                provider_options=self.provider_options,
-            ),
-        )
+        self.det_model = RetinaFace(session=self._make_session(self.det_file))
+        self.rec_model = ArcFaceONNX(self.rec_file.as_posix(), session=self._make_session(self.rec_file))

        self.det_model.prepare(
            ctx_id=0,
app/test_main.py

@@ -13,7 +13,7 @@ from PIL import Image
from pytest_mock import MockerFixture

from .config import settings
-from .models.base import PicklableSessionOptions
+from .models.base import InferenceModel, PicklableSessionOptions
from .models.cache import ModelCache
from .models.clip import OpenCLIPEncoder
from .models.facial_recognition import FaceRecognizer
@@ -36,9 +36,10 @@ class TestCLIP:
        mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
        mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
        mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mocked.run.return_value = [[self.embedding]]
        mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True)
-        mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]

        clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="vision")
        embedding = clip_encoder.predict(pil_image)
@@ -47,7 +48,7 @@ class TestCLIP:
        assert isinstance(embedding, np.ndarray)
        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
        assert embedding.dtype == np.float32
-        clip_encoder.vision_model.run.assert_called_once()
+        mocked.run.assert_called_once()

    def test_basic_text(
        self,
@@ -60,9 +61,10 @@ class TestCLIP:
        mocker.patch.object(OpenCLIPEncoder, "model_cfg", clip_model_cfg)
        mocker.patch.object(OpenCLIPEncoder, "preprocess_cfg", clip_preprocess_cfg)
        mocker.patch.object(OpenCLIPEncoder, "tokenizer_cfg", clip_tokenizer_cfg)
+        mocked = mocker.patch.object(InferenceModel, "_make_session", autospec=True).return_value
+        mocked.run.return_value = [[self.embedding]]
        mocker.patch("app.models.clip.Tokenizer.from_file", autospec=True)
-        mocked = mocker.patch("app.models.clip.ort.InferenceSession", autospec=True)
-        mocked.return_value.run.return_value = [[self.embedding]]

        clip_encoder = OpenCLIPEncoder("ViT-B-32::openai", cache_dir="test_cache", mode="text")
        embedding = clip_encoder.predict("test search query")
@@ -71,7 +73,7 @@ class TestCLIP:
        assert isinstance(embedding, np.ndarray)
        assert embedding.shape[0] == clip_model_cfg["embed_dim"]
        assert embedding.dtype == np.float32
-        clip_encoder.text_model.run.assert_called_once()
+        mocked.run.assert_called_once()


class TestFaceRecognition: