chore(ml): installable package (#17153)

* app -> immich_ml

* fix test ci

* omit file name

* add new line

* add new line
This commit is contained in:
Mert 2025-03-27 15:49:09 -04:00 committed by GitHub
parent f7d730eb05
commit 84c35e35d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 347 additions and 316 deletions

View file

@ -1,58 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
from numpy.typing import NDArray
from ann.ann import Ann
from app.schemas import SessionNode
from ..config import log, settings
class AnnSession:
"""
Wrapper for ANN to be drop-in replacement for ONNX session.
"""
def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
self.model_path = model_path
self.cache_dir = cache_dir
self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
log.info("Loading ANN model %s ...", model_path)
self.model = self.ann.load(
model_path.as_posix(),
cached_network_path=model_path.with_suffix(".anncache").as_posix(),
fp16=settings.ann_fp16_turbo,
)
log.info("Loaded ANN model with ID %d", self.model)
def __del__(self) -> None:
self.ann.unload(self.model)
log.info("Unloaded ANN model %d", self.model)
self.ann.destroy()
def get_inputs(self) -> list[SessionNode]:
shapes = self.ann.input_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def get_outputs(self) -> list[SessionNode]:
shapes = self.ann.output_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
return self.ann.execute(self.model, inputs)
class AnnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]