mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
fix(ml): armnn not being used (#10929)
* fix armnn not being used, move fallback handling to main, add tests * formatting
This commit is contained in:
parent
59aa347912
commit
f43721ec92
7 changed files with 111 additions and 44 deletions
|
|
@ -23,7 +23,7 @@ class InferenceModel(ABC):
|
|||
self,
|
||||
model_name: str,
|
||||
cache_dir: Path | str | None = None,
|
||||
preferred_format: ModelFormat | None = None,
|
||||
model_format: ModelFormat | None = None,
|
||||
session: ModelSession | None = None,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
|
|
@ -31,7 +31,7 @@ class InferenceModel(ABC):
|
|||
self.load_attempts = 0
|
||||
self.model_name = clean_name(model_name)
|
||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
|
||||
self.model_format = preferred_format if preferred_format is not None else self._model_format_default
|
||||
self.model_format = model_format if model_format is not None else self._model_format_default
|
||||
if session is not None:
|
||||
self.session = session
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ class InferenceModel(ABC):
|
|||
self.load_attempts += 1
|
||||
|
||||
self.download()
|
||||
attempt = f"Attempt #{self.load_attempts + 1} to load" if self.load_attempts else "Loading"
|
||||
attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
|
||||
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
||||
self.session = self._load()
|
||||
self.loaded = True
|
||||
|
|
@ -101,6 +101,9 @@ class InferenceModel(ABC):
|
|||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_session(self, model_path: Path) -> ModelSession:
|
||||
if not model_path.is_file():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
match model_path.suffix:
|
||||
case ".armnn":
|
||||
session: ModelSession = AnnSession(model_path)
|
||||
|
|
@ -144,17 +147,13 @@ class InferenceModel(ABC):
|
|||
|
||||
@property
|
||||
def model_format(self) -> ModelFormat:
|
||||
return self._preferred_format
|
||||
return self._model_format
|
||||
|
||||
@model_format.setter
|
||||
def model_format(self, preferred_format: ModelFormat) -> None:
|
||||
log.debug(f"Setting preferred format to {preferred_format}")
|
||||
self._preferred_format = preferred_format
|
||||
def model_format(self, model_format: ModelFormat) -> None:
|
||||
log.debug(f"Setting model format to {model_format}")
|
||||
self._model_format = model_format
|
||||
|
||||
@property
|
||||
def _model_format_default(self) -> ModelFormat:
|
||||
prefer_ann = ann.ann.is_available and settings.ann
|
||||
ann_exists = (self.model_dir / "model.armnn").is_file()
|
||||
if prefer_ann and not ann_exists:
|
||||
log.warning(f"ARM NN is available, but '{self.model_name}' does not support ARM NN. Falling back to ONNX.")
|
||||
return ModelFormat.ARMNN if prefer_ann and ann_exists else ModelFormat.ONNX
|
||||
return ModelFormat.ARMNN if ann.ann.is_available and settings.ann else ModelFormat.ONNX
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue