mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
feat(ml): composable ml (#9973)
* modularize model classes * various fixes * expose port * change response * round coordinates * simplify preload * update server * simplify interface simplify * update tests * composable endpoint * cleanup fixes remove unnecessary interface support text input, cleanup * ew camelcase * update server server fixes fix typing * ml fixes update locustfile fixes * cleaner response * better repo response * update tests formatting and typing rename * undo compose change * linting fix type actually fix typing * stricter typing fix detection-only response no need for defaultdict * update spec file update api linting * update e2e * unnecessary dimension * remove commented code * remove duplicate code * remove unused imports * add batch dim
This commit is contained in:
parent
7a46f80ddc
commit
2b1b43a7e4
39 changed files with 982 additions and 999 deletions
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
|
|
@ -11,13 +11,14 @@ from huggingface_hub import snapshot_download
|
|||
import ann.ann
|
||||
from app.models.constants import SUPPORTED_PROVIDERS
|
||||
|
||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||
from ..schemas import ModelRuntime, ModelType
|
||||
from ..config import clean_name, log, settings
|
||||
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
|
||||
from .ann import AnnSession
|
||||
|
||||
|
||||
class InferenceModel(ABC):
|
||||
_model_type: ModelType
|
||||
depends: ClassVar[list[ModelIdentity]]
|
||||
identity: ClassVar[ModelIdentity]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -26,16 +27,16 @@ class InferenceModel(ABC):
|
|||
providers: list[str] | None = None,
|
||||
provider_options: list[dict[str, Any]] | None = None,
|
||||
sess_options: ort.SessionOptions | None = None,
|
||||
preferred_runtime: ModelRuntime | None = None,
|
||||
preferred_format: ModelFormat | None = None,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
self.loaded = False
|
||||
self.model_name = model_name
|
||||
self.model_name = clean_name(model_name)
|
||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
|
||||
self.providers = providers if providers is not None else self.providers_default
|
||||
self.provider_options = provider_options if provider_options is not None else self.provider_options_default
|
||||
self.sess_options = sess_options if sess_options is not None else self.sess_options_default
|
||||
self.preferred_runtime = preferred_runtime if preferred_runtime is not None else self.preferred_runtime_default
|
||||
self.preferred_format = preferred_format if preferred_format is not None else self.preferred_format_default
|
||||
|
||||
def download(self) -> None:
|
||||
if not self.cached:
|
||||
|
|
@ -47,35 +48,36 @@ class InferenceModel(ABC):
|
|||
def load(self) -> None:
|
||||
if self.loaded:
|
||||
return
|
||||
|
||||
self.download()
|
||||
log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
||||
self._load()
|
||||
self.session = self._load()
|
||||
self.loaded = True
|
||||
|
||||
def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
|
||||
def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
|
||||
self.load()
|
||||
if model_kwargs:
|
||||
self.configure(**model_kwargs)
|
||||
return self._predict(inputs)
|
||||
return self._predict(*inputs, **model_kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _predict(self, inputs: Any) -> Any: ...
|
||||
def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...
|
||||
|
||||
def configure(self, **model_kwargs: Any) -> None:
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def _download(self) -> None:
|
||||
ignore_patterns = [] if self.preferred_runtime == ModelRuntime.ARMNN else ["*.armnn"]
|
||||
ignore_patterns = [] if self.preferred_format == ModelFormat.ARMNN else ["*.armnn"]
|
||||
snapshot_download(
|
||||
get_hf_model_name(self.model_name),
|
||||
f"immich-app/{clean_name(self.model_name)}",
|
||||
cache_dir=self.cache_dir,
|
||||
local_dir=self.cache_dir,
|
||||
local_dir_use_symlinks=False,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _load(self) -> None: ...
|
||||
def _load(self) -> ModelSession:
|
||||
return self._make_session(self.model_path)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
if not self.cache_dir.exists():
|
||||
|
|
@ -99,7 +101,7 @@ class InferenceModel(ABC):
|
|||
self.cache_dir.unlink()
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
|
||||
def _make_session(self, model_path: Path) -> ModelSession:
|
||||
if not model_path.is_file():
|
||||
onnx_path = model_path.with_suffix(".onnx")
|
||||
if not onnx_path.is_file():
|
||||
|
|
@ -124,9 +126,21 @@ class InferenceModel(ABC):
|
|||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||
return session
|
||||
|
||||
@property
|
||||
def model_dir(self) -> Path:
|
||||
return self.cache_dir / self.model_type.value
|
||||
|
||||
@property
|
||||
def model_path(self) -> Path:
|
||||
return self.model_dir / f"model.{self.preferred_format}"
|
||||
|
||||
@property
|
||||
def model_task(self) -> ModelTask:
|
||||
return self.identity[1]
|
||||
|
||||
@property
|
||||
def model_type(self) -> ModelType:
|
||||
return self._model_type
|
||||
return self.identity[0]
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
|
|
@ -138,11 +152,11 @@ class InferenceModel(ABC):
|
|||
|
||||
@property
|
||||
def cache_dir_default(self) -> Path:
|
||||
return get_cache_dir(self.model_name, self.model_type)
|
||||
return settings.cache_folder / self.model_task.value / self.model_name
|
||||
|
||||
@property
|
||||
def cached(self) -> bool:
|
||||
return self.cache_dir.is_dir() and any(self.cache_dir.iterdir())
|
||||
return self.model_path.is_file()
|
||||
|
||||
@property
|
||||
def providers(self) -> list[str]:
|
||||
|
|
@ -226,14 +240,14 @@ class InferenceModel(ABC):
|
|||
return sess_options
|
||||
|
||||
@property
|
||||
def preferred_runtime(self) -> ModelRuntime:
|
||||
return self._preferred_runtime
|
||||
def preferred_format(self) -> ModelFormat:
|
||||
return self._preferred_format
|
||||
|
||||
@preferred_runtime.setter
|
||||
def preferred_runtime(self, preferred_runtime: ModelRuntime) -> None:
|
||||
log.debug(f"Setting preferred runtime to {preferred_runtime}")
|
||||
self._preferred_runtime = preferred_runtime
|
||||
@preferred_format.setter
|
||||
def preferred_format(self, preferred_format: ModelFormat) -> None:
|
||||
log.debug(f"Setting preferred format to {preferred_format}")
|
||||
self._preferred_format = preferred_format
|
||||
|
||||
@property
|
||||
def preferred_runtime_default(self) -> ModelRuntime:
|
||||
return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
|
||||
def preferred_format_default(self) -> ModelFormat:
|
||||
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue