feat(ml): improve test coverage (#7041)

* update e2e

* tokenizer tests

* more tests, remove unnecessary code

* fix e2e setting

* add tests for loading model

* update workflow

* fixed test
This commit is contained in:
Mert 2024-02-11 17:58:56 -05:00 committed by GitHub
parent 6e853e2a9d
commit 0c4df216d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 501 additions and 1636 deletions

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import pickle
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
@ -11,7 +10,6 @@ import onnxruntime as ort
from huggingface_hub import snapshot_download
from onnx.shape_inference import infer_shapes
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from typing_extensions import Buffer
import ann.ann
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
@ -200,7 +198,7 @@ class InferenceModel(ABC):
@providers.setter
def providers(self, providers: list[str]) -> None:
log.debug(
log.info(
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
)
self._providers = providers
@ -217,7 +215,7 @@ class InferenceModel(ABC):
@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.info(f"Setting execution provider options to {provider_options}")
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options
@property
@ -255,7 +253,7 @@ class InferenceModel(ABC):
@property
def sess_options_default(self) -> ort.SessionOptions:
sess_options = PicklableSessionOptions()
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
# avoid thread contention between models
@ -287,15 +285,3 @@ class InferenceModel(ABC):
@property
def preferred_runtime_default(self) -> ModelRuntime:
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
# HF deep copies configs, so we need to make session options picklable
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
def __getstate__(self) -> bytes:
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
def __setstate__(self, state: Buffer) -> None:
self.__init__() # type: ignore[misc]
attrs: list[tuple[str, Any]] = pickle.loads(state)
for attr, val in attrs:
setattr(self, attr, val)