feat(ml): composable ml (#9973)

* modularize model classes

* various fixes

* expose port

* change response

* round coordinates

* simplify preload

* update server

* simplify interface

simplify

* update tests

* composable endpoint

* cleanup

fixes

remove unnecessary interface

support text input, cleanup

* ew camelcase

* update server

server fixes

fix typing

* ml fixes

update locustfile

fixes

* cleaner response

* better repo response

* update tests

formatting and typing

rename

* undo compose change

* linting

fix type

actually fix typing

* stricter typing

fix detection-only response

no need for defaultdict

* update spec file

update api

linting

* update e2e

* unnecessary dimension

* remove commented code

* remove duplicate code

* remove unused imports

* add batch dim
Author: Mert
Date: 2024-06-06 23:09:47 -04:00 (committed by GitHub)
Parent: 7a46f80ddc
Commit: 2b1b43a7e4
39 changed files with 982 additions and 999 deletions


@@ -3,7 +3,7 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from pathlib import Path
 from shutil import rmtree
-from typing import Any
+from typing import Any, ClassVar
 
 import onnxruntime as ort
 from huggingface_hub import snapshot_download
@@ -11,13 +11,14 @@ from huggingface_hub import snapshot_download
 import ann.ann
 from app.models.constants import SUPPORTED_PROVIDERS
 
-from ..config import get_cache_dir, get_hf_model_name, log, settings
-from ..schemas import ModelRuntime, ModelType
+from ..config import clean_name, log, settings
+from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
 from .ann import AnnSession
 
 
 class InferenceModel(ABC):
-    _model_type: ModelType
+    depends: ClassVar[list[ModelIdentity]]
+    identity: ClassVar[ModelIdentity]
 
     def __init__(
         self,
@@ -26,16 +27,16 @@ class InferenceModel(ABC):
         providers: list[str] | None = None,
         provider_options: list[dict[str, Any]] | None = None,
         sess_options: ort.SessionOptions | None = None,
-        preferred_runtime: ModelRuntime | None = None,
+        preferred_format: ModelFormat | None = None,
         **model_kwargs: Any,
     ) -> None:
         self.loaded = False
-        self.model_name = model_name
+        self.model_name = clean_name(model_name)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else self.cache_dir_default
         self.providers = providers if providers is not None else self.providers_default
         self.provider_options = provider_options if provider_options is not None else self.provider_options_default
         self.sess_options = sess_options if sess_options is not None else self.sess_options_default
-        self.preferred_runtime = preferred_runtime if preferred_runtime is not None else self.preferred_runtime_default
+        self.preferred_format = preferred_format if preferred_format is not None else self.preferred_format_default
 
     def download(self) -> None:
         if not self.cached:
@@ -47,35 +48,36 @@ class InferenceModel(ABC):
     def load(self) -> None:
         if self.loaded:
             return
 
         self.download()
         log.info(f"Loading {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
-        self._load()
+        self.session = self._load()
         self.loaded = True
 
-    def predict(self, inputs: Any, **model_kwargs: Any) -> Any:
+    def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
         self.load()
         if model_kwargs:
             self.configure(**model_kwargs)
-        return self._predict(inputs)
+        return self._predict(*inputs, **model_kwargs)
 
     @abstractmethod
-    def _predict(self, inputs: Any) -> Any: ...
+    def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...
 
-    def configure(self, **model_kwargs: Any) -> None:
+    def configure(self, **kwargs: Any) -> None:
         pass
 
     def _download(self) -> None:
-        ignore_patterns = [] if self.preferred_runtime == ModelRuntime.ARMNN else ["*.armnn"]
+        ignore_patterns = [] if self.preferred_format == ModelFormat.ARMNN else ["*.armnn"]
         snapshot_download(
-            get_hf_model_name(self.model_name),
+            f"immich-app/{clean_name(self.model_name)}",
             cache_dir=self.cache_dir,
             local_dir=self.cache_dir,
            local_dir_use_symlinks=False,
            ignore_patterns=ignore_patterns,
        )
 
-    @abstractmethod
-    def _load(self) -> None: ...
+    def _load(self) -> ModelSession:
+        return self._make_session(self.model_path)
 
     def clear_cache(self) -> None:
         if not self.cache_dir.exists():
@@ -99,7 +101,7 @@ class InferenceModel(ABC):
             self.cache_dir.unlink()
             self.cache_dir.mkdir(parents=True, exist_ok=True)
 
-    def _make_session(self, model_path: Path) -> AnnSession | ort.InferenceSession:
+    def _make_session(self, model_path: Path) -> ModelSession:
         if not model_path.is_file():
             onnx_path = model_path.with_suffix(".onnx")
             if not onnx_path.is_file():
@@ -124,9 +126,21 @@ class InferenceModel(ABC):
                 raise ValueError(f"Unsupported model file type: {model_path.suffix}")
         return session
 
+    @property
+    def model_dir(self) -> Path:
+        return self.cache_dir / self.model_type.value
+
+    @property
+    def model_path(self) -> Path:
+        return self.model_dir / f"model.{self.preferred_format}"
+
+    @property
+    def model_task(self) -> ModelTask:
+        return self.identity[1]
+
     @property
     def model_type(self) -> ModelType:
-        return self._model_type
+        return self.identity[0]
 
     @property
     def cache_dir(self) -> Path:
@@ -138,11 +152,11 @@ class InferenceModel(ABC):
 
     @property
     def cache_dir_default(self) -> Path:
-        return get_cache_dir(self.model_name, self.model_type)
+        return settings.cache_folder / self.model_task.value / self.model_name
 
     @property
     def cached(self) -> bool:
-        return self.cache_dir.is_dir() and any(self.cache_dir.iterdir())
+        return self.model_path.is_file()
 
     @property
     def providers(self) -> list[str]:
@@ -226,14 +240,14 @@ class InferenceModel(ABC):
         return sess_options
 
     @property
-    def preferred_runtime(self) -> ModelRuntime:
-        return self._preferred_runtime
+    def preferred_format(self) -> ModelFormat:
+        return self._preferred_format
 
-    @preferred_runtime.setter
-    def preferred_runtime(self, preferred_runtime: ModelRuntime) -> None:
-        log.debug(f"Setting preferred runtime to {preferred_runtime}")
-        self._preferred_runtime = preferred_runtime
+    @preferred_format.setter
+    def preferred_format(self, preferred_format: ModelFormat) -> None:
+        log.debug(f"Setting preferred format to {preferred_format}")
+        self._preferred_format = preferred_format
 
     @property
-    def preferred_runtime_default(self) -> ModelRuntime:
-        return ModelRuntime.ARMNN if ann.ann.is_available and settings.ann else ModelRuntime.ONNX
+    def preferred_format_default(self) -> ModelFormat:
+        return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
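
For context, a minimal sketch of how a concrete model is meant to plug into the reworked base class. Illustrative only: the subclass name, the "text" input name, and the enum members (ModelType.TEXTUAL, ModelTask.SEARCH) are assumptions, as is the app.models.base / app.schemas import layout implied by the relative imports in the diff.

from typing import Any

from app.models.base import InferenceModel
from app.schemas import ModelTask, ModelType


class TextualEncoder(InferenceModel):  # hypothetical subclass
    # identity drives model_type (index 0) and model_task (index 1); enum members assumed.
    identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
    # No upstream model outputs are needed before this model can run.
    depends = []

    def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
        # self.session is assigned in load() from _load(), which by default builds a session
        # from self.model_path, i.e. cache_dir / model_type.value / f"model.{preferred_format}".
        (tokens,) = inputs
        return self.session.run(None, {"text": tokens})  # input name is illustrative


# predict() lazily downloads and loads the model before delegating to _predict():
# encoder = TextualEncoder("ViT-B-32__openai")
# embedding = encoder.predict(tokens)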