mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
feat(ml): improve test coverage (#7041)
* update e2e * tokenizer tests * more tests, remove unnecessary code * fix e2e setting * add tests for loading model * update workflow * fixed test
This commit is contained in:
parent
6e853e2a9d
commit
0c4df216d7
8 changed files with 501 additions and 1636 deletions
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
|
@ -11,7 +10,6 @@ import onnxruntime as ort
|
|||
from huggingface_hub import snapshot_download
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||
from typing_extensions import Buffer
|
||||
|
||||
import ann.ann
|
||||
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
||||
|
|
@ -200,7 +198,7 @@ class InferenceModel(ABC):
|
|||
|
||||
@providers.setter
|
||||
def providers(self, providers: list[str]) -> None:
|
||||
log.debug(
|
||||
log.info(
|
||||
(f"Setting '{self.model_name}' execution providers to {providers}, " "in descending order of preference"),
|
||||
)
|
||||
self._providers = providers
|
||||
|
|
@ -217,7 +215,7 @@ class InferenceModel(ABC):
|
|||
|
||||
@provider_options.setter
|
||||
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
|
||||
log.info(f"Setting execution provider options to {provider_options}")
|
||||
log.debug(f"Setting execution provider options to {provider_options}")
|
||||
self._provider_options = provider_options
|
||||
|
||||
@property
|
||||
|
|
@ -255,7 +253,7 @@ class InferenceModel(ABC):
|
|||
|
||||
@property
|
||||
def sess_options_default(self) -> ort.SessionOptions:
|
||||
sess_options = PicklableSessionOptions()
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.enable_cpu_mem_arena = False
|
||||
|
||||
# avoid thread contention between models
|
||||
|
|
@ -287,15 +285,3 @@ class InferenceModel(ABC):
|
|||
@property
|
||||
def preferred_runtime_default(self) -> ModelRuntime:
|
||||
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
||||
|
||||
|
||||
# HF deep copies configs, so we need to make session options picklable
|
||||
class PicklableSessionOptions(ort.SessionOptions): # type: ignore[misc]
|
||||
def __getstate__(self) -> bytes:
|
||||
return pickle.dumps([(attr, getattr(self, attr)) for attr in dir(self) if not callable(getattr(self, attr))])
|
||||
|
||||
def __setstate__(self, state: Buffer) -> None:
|
||||
self.__init__() # type: ignore[misc]
|
||||
attrs: list[tuple[str, Any]] = pickle.loads(state)
|
||||
for attr, val in attrs:
|
||||
setattr(self, attr, val)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue