mirror of
https://github.com/immich-app/immich
synced 2025-10-17 18:19:27 +00:00
fix(ml): load models in separate threads (#4034)
* load models in thread * set clip mode logs to debug level * updated tests * made fixtures slightly less ugly * moved responses to json file * formatting
This commit is contained in:
parent
f1db257628
commit
258b98c262
9 changed files with 1683 additions and 114 deletions
|
|
@ -5,10 +5,8 @@ from abc import ABC, abstractmethod
|
|||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
|
||||
|
||||
from ..config import get_cache_dir, log, settings
|
||||
from ..schemas import ModelType
|
||||
|
|
@ -21,16 +19,13 @@ class InferenceModel(ABC):
|
|||
self,
|
||||
model_name: str,
|
||||
cache_dir: Path | str | None = None,
|
||||
eager: bool = True,
|
||||
inter_op_num_threads: int = settings.model_inter_op_threads,
|
||||
intra_op_num_threads: int = settings.model_intra_op_threads,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self._loaded = False
|
||||
self.loaded = False
|
||||
self._cache_dir = Path(cache_dir) if cache_dir is not None else get_cache_dir(model_name, self.model_type)
|
||||
loader = self.load if eager else self.download
|
||||
|
||||
self.providers = model_kwargs.pop("providers", ["CPUExecutionProvider"])
|
||||
# don't pre-allocate more memory than needed
|
||||
self.provider_options = model_kwargs.pop(
|
||||
|
|
@ -55,34 +50,23 @@ class InferenceModel(ABC):
|
|||
self.sess_options.intra_op_num_threads = intra_op_num_threads
|
||||
self.sess_options.enable_cpu_mem_arena = False
|
||||
|
||||
try:
|
||||
loader(**model_kwargs)
|
||||
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
||||
log.warn(
|
||||
(
|
||||
f"Failed to load {self.model_type.replace('_', ' ')} model '{self.model_name}'."
|
||||
"Clearing cache and retrying."
|
||||
)
|
||||
)
|
||||
self.clear_cache()
|
||||
loader(**model_kwargs)
|
||||
|
||||
def download(self, **model_kwargs: Any) -> None:
|
||||
def download(self) -> None:
|
||||
if not self.cached:
|
||||
log.info(
|
||||
(f"Downloading {self.model_type.replace('_', ' ')} model '{self.model_name}'." "This may take a while.")
|
||||
(f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'." "This may take a while.")
|
||||
)
|
||||
self._download(**model_kwargs)
|
||||
self._download()
|
||||
|
||||
def load(self, **model_kwargs: Any) -> None:
|
||||
self.download(**model_kwargs)
|
||||
self._load(**model_kwargs)
|
||||
self._loaded = True
|
||||
def load(self) -> None:
|
||||
if self.loaded:
|
||||
return
|
||||
self.download()
|
||||
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}'")
|
||||
self._load()
|
||||
self.loaded = True
|
||||
|
||||
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
||||
if not self._loaded:
|
||||
log.info(f"Loading {self.model_type.replace('_', ' ')} model '{self.model_name}'")
|
||||
self.load()
|
||||
self.load()
|
||||
if model_kwargs:
|
||||
self.configure(**model_kwargs)
|
||||
return self._predict(inputs)
|
||||
|
|
@ -95,11 +79,11 @@ class InferenceModel(ABC):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _download(self, **model_kwargs: Any) -> None:
|
||||
def _download(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def _load(self, **model_kwargs: Any) -> None:
|
||||
def _load(self) -> None:
|
||||
...
|
||||
|
||||
@property
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue