mirror of
https://github.com/immich-app/immich
synced 2025-11-14 17:36:12 +00:00
chore(ml): installable package (#17153)
* app -> immich_ml
* fix test ci
* omit file name
* add new line
* add new line
This commit is contained in:
parent f7d730eb05
commit 84c35e35d6
31 changed files with 347 additions and 316 deletions

machine-learning/immich_ml/models/__init__.py (Normal file, 40 lines)

@@ -0,0 +1,40 @@
from typing import Any

from immich_ml.models.base import InferenceModel
from immich_ml.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
from immich_ml.models.clip.visual import OpenClipVisualEncoder
from immich_ml.schemas import ModelSource, ModelTask, ModelType

from .constants import get_model_source
from .facial_recognition.detection import FaceDetector
from .facial_recognition.recognition import FaceRecognizer


def get_model_class(model_name: str, model_type: ModelType, model_task: ModelTask) -> type[InferenceModel]:
    source = get_model_source(model_name)
    match source, model_type, model_task:
        case ModelSource.OPENCLIP | ModelSource.MCLIP, ModelType.VISUAL, ModelTask.SEARCH:
            return OpenClipVisualEncoder

        case ModelSource.OPENCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
            return OpenClipTextualEncoder

        case ModelSource.MCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
            return MClipTextualEncoder

        case ModelSource.INSIGHTFACE, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION:
            return FaceDetector

        case ModelSource.INSIGHTFACE, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION:
            return FaceRecognizer

        case _:
            raise ValueError(f"Unknown model combination: {source}, {model_type}, {model_task}")


def from_model_type(model_name: str, model_type: ModelType, model_task: ModelTask, **kwargs: Any) -> InferenceModel:
    return get_model_class(model_name, model_type, model_task)(model_name, **kwargs)


def get_model_deps(model_name: str, model_type: ModelType, model_task: ModelTask) -> list[tuple[ModelType, ModelTask]]:
    return get_model_class(model_name, model_type, model_task).depends
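
How these pieces fit together, as a rough usage sketch (assumes the immich_ml package is importable; the model name is illustrative):

from immich_ml.models import from_model_type, get_model_deps
from immich_ml.schemas import ModelTask, ModelType

# "ViT-B-32__openai" is registered as an OpenCLIP model, so this resolves
# to OpenClipVisualEncoder and instantiates it lazily (no download yet).
model = from_model_type("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH)

# CLIP encoders declare no dependencies; FaceRecognizer would report its detector here.
print(get_model_deps("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH))  # []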

machine-learning/immich_ml/models/base.py (Normal file, 177 lines)

@@ -0,0 +1,177 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any, ClassVar

from huggingface_hub import snapshot_download

import immich_ml.sessions.ann.loader
import immich_ml.sessions.rknn as rknn
from immich_ml.sessions.ort import OrtSession

from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
from ..sessions.ann import AnnSession


class InferenceModel(ABC):
    depends: ClassVar[list[ModelIdentity]]
    identity: ClassVar[ModelIdentity]

    def __init__(
        self,
        model_name: str,
        cache_dir: Path | str | None = None,
        model_format: ModelFormat | None = None,
        session: ModelSession | None = None,
        **model_kwargs: Any,
    ) -> None:
        self.loaded = session is not None
        self.load_attempts = 0
        self.model_name = clean_name(model_name)
        self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
        self.model_format = model_format if model_format is not None else self._model_format_default
        if session is not None:
            self.session = session

    def download(self) -> None:
        if not self.cached:
            log.info(
                f"Downloading {self.model_type.replace('-', ' ')} model '{self.model_name}'. This may take a while."
            )
            self._download()

    def load(self) -> None:
        if self.loaded:
            return
        self.load_attempts += 1

        self.download()
        attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
        log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
        self.session = self._load()
        self.loaded = True

    def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
        self.load()
        if model_kwargs:
            self.configure(**model_kwargs)
        return self._predict(*inputs, **model_kwargs)

    @abstractmethod
    def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...

    def configure(self, **kwargs: Any) -> None:
        pass

    def _download(self) -> None:
        ignored_patterns: dict[ModelFormat, list[str]] = {
            ModelFormat.ONNX: ["*.armnn", "*.rknn"],
            ModelFormat.ARMNN: ["*.rknn"],
            ModelFormat.RKNN: ["*.armnn"],
        }

        snapshot_download(
            f"immich-app/{clean_name(self.model_name)}",
            cache_dir=self.cache_dir,
            local_dir=self.cache_dir,
            ignore_patterns=ignored_patterns.get(self.model_format, []),
        )

    def _load(self) -> ModelSession:
        return self._make_session(self.model_path)

    def clear_cache(self) -> None:
        if not self.cache_dir.exists():
            log.warning(
                f"Attempted to clear cache for model '{self.model_name}', but cache directory does not exist",
            )
            return
        if not rmtree.avoids_symlink_attacks:
            raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform")

        if self.cache_dir.is_dir():
            log.info(f"Cleared cache directory for model '{self.model_name}'.")
            rmtree(self.cache_dir)
        else:
            log.warning(
                (
                    f"Encountered file instead of directory at cache path "
                    f"for '{self.model_name}'. Removing file and replacing with a directory."
                ),
            )
            self.cache_dir.unlink()
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _make_session(self, model_path: Path) -> ModelSession:
        if not model_path.is_file():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        match model_path.suffix:
            case ".armnn":
                session: ModelSession = AnnSession(model_path)
            case ".onnx":
                session = OrtSession(model_path)
            case ".rknn":
                session = rknn.RknnSession(model_path)
            case _:
                raise ValueError(f"Unsupported model file type: {model_path.suffix}")
        return session

    def model_path_for_format(self, model_format: ModelFormat) -> Path:
        model_path_prefix = rknn.model_prefix if model_format == ModelFormat.RKNN else None
        if model_path_prefix:
            return self.model_dir / model_path_prefix / f"model.{model_format}"
        return self.model_dir / f"model.{model_format}"

    @property
    def model_dir(self) -> Path:
        return self.cache_dir / self.model_type.value

    @property
    def model_path(self) -> Path:
        return self.model_path_for_format(self.model_format)

    @property
    def model_task(self) -> ModelTask:
        return self.identity[1]

    @property
    def model_type(self) -> ModelType:
        return self.identity[0]

    @property
    def cache_dir(self) -> Path:
        return self._cache_dir

    @cache_dir.setter
    def cache_dir(self, cache_dir: Path) -> None:
        self._cache_dir = cache_dir

    @property
    def _cache_dir_default(self) -> Path:
        return settings.cache_folder / self.model_task.value / self.model_name

    @property
    def cached(self) -> bool:
        return self.model_path.is_file()

    @property
    def model_format(self) -> ModelFormat:
        return self._model_format

    @model_format.setter
    def model_format(self, model_format: ModelFormat) -> None:
        log.debug(f"Setting model format to {model_format}")
        self._model_format = model_format

    @property
    def _model_format_default(self) -> ModelFormat:
        if rknn.is_available:
            return ModelFormat.RKNN
        elif immich_ml.sessions.ann.loader.is_available and settings.ann:
            return ModelFormat.ARMNN
        else:
            return ModelFormat.ONNX
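
The contract for subclasses is small: set depends and identity, implement _predict, and optionally override _load or configure. A minimal hypothetical subclass, for illustration only:

from typing import Any, ClassVar

from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelIdentity, ModelTask, ModelType


class EchoModel(InferenceModel):  # hypothetical; not part of this commit
    depends: ClassVar[list[ModelIdentity]] = []
    identity: ClassVar[ModelIdentity] = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
        # predict() has already run load() by the time this hook is called,
        # so self.session is available here.
        return inputs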

machine-learning/immich_ml/models/cache.py (Normal file, 60 lines)

@@ -0,0 +1,60 @@
from typing import Any

from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import TimingPlugin

from immich_ml.models import from_model_type
from immich_ml.models.base import InferenceModel

from ..schemas import ModelTask, ModelType, has_profiling


class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ) -> None:
        """
        Args:
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """

        plugins = []

        if profiling:
            plugins.append(TimingPlugin())

        self.should_revalidate = revalidate

        self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)

    async def get(
        self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
    ) -> InferenceModel:
        key = f"{model_name}{model_type}{model_task}"

        async with OptimisticLock(self.cache, key) as lock:
            model: InferenceModel | None = await self.cache.get(key)
            if model is None:
                model = from_model_type(model_name, model_type, model_task, **model_kwargs)
                await lock.cas(model, ttl=model_kwargs.get("ttl", None))
            elif self.should_revalidate:
                await self.revalidate(key, model_kwargs.get("ttl", None))
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        if not has_profiling(self.cache):
            return None

        return self.cache.profiling

    async def revalidate(self, key: str, ttl: int | None) -> None:
        if ttl is not None and key in self.cache._handlers:
            await self.cache.expire(key, ttl)
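
A usage sketch (assumes immich_ml is importable; note that ttl is read out of the model kwargs and reused as the cache expiry):

import asyncio

from immich_ml.models.cache import ModelCache
from immich_ml.schemas import ModelTask, ModelType


async def main() -> None:
    cache = ModelCache(revalidate=True)
    # The first get() constructs the model and stores it with a 300s TTL;
    # the second returns the same instance and, with revalidate=True, resets the TTL.
    a = await cache.get("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH, ttl=300)
    b = await cache.get("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH, ttl=300)
    assert a is b


asyncio.run(main())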

machine-learning/immich_ml/models/clip/textual.py (Normal file, 108 lines)

@@ -0,0 +1,108 @@
import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any

import numpy as np
from numpy.typing import NDArray
from tokenizers import Encoding, Tokenizer

from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import clean_text, serialize_np_array
from immich_ml.schemas import ModelSession, ModelTask, ModelType


class BaseCLIPTextualEncoder(InferenceModel):
    depends = []
    identity = (ModelType.TEXTUAL, ModelTask.SEARCH)

    def _predict(self, inputs: str, **kwargs: Any) -> str:
        res: NDArray[np.float32] = self.session.run(None, self.tokenize(inputs))[0][0]
        return serialize_np_array(res)

    def _load(self) -> ModelSession:
        session = super()._load()
        log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
        self.tokenizer = self._load_tokenizer()
        tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
        self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
        log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")

        return session

    @abstractmethod
    def _load_tokenizer(self) -> Tokenizer:
        pass

    @abstractmethod
    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
        pass

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def tokenizer_file_path(self) -> Path:
        return self.model_dir / "tokenizer.json"

    @property
    def tokenizer_cfg_path(self) -> Path:
        return self.model_dir / "tokenizer_config.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @property
    def text_cfg(self) -> dict[str, Any]:
        text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
        return text_cfg

    @cached_property
    def tokenizer_file(self) -> dict[str, Any]:
        log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
        tokenizer_file: dict[str, Any] = json.load(self.tokenizer_file_path.open())
        log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
        return tokenizer_file

    @cached_property
    def tokenizer_cfg(self) -> dict[str, Any]:
        log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
        tokenizer_cfg: dict[str, Any] = json.load(self.tokenizer_cfg_path.open())
        log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
        return tokenizer_cfg


class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
    def _load_tokenizer(self) -> Tokenizer:
        context_length: int = self.text_cfg.get("context_length", 77)
        pad_token: str = self.tokenizer_cfg["pad_token"]

        tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())

        pad_id: int = tokenizer.token_to_id(pad_token)
        tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
        tokenizer.enable_truncation(max_length=context_length)

        return tokenizer

    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
        text = clean_text(text, canonicalize=self.canonicalize)
        tokens: Encoding = self.tokenizer.encode(text)
        return {"text": np.array([tokens.ids], dtype=np.int32)}


class MClipTextualEncoder(OpenClipTextualEncoder):
    def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
        text = clean_text(text, canonicalize=self.canonicalize)
        tokens: Encoding = self.tokenizer.encode(text)
        return {
            "input_ids": np.array([tokens.ids], dtype=np.int32),
            "attention_mask": np.array([tokens.attention_mask], dtype=np.int32),
        }
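
The effect of the padding and truncation setup in _load_tokenizer can be reproduced with the tokenizers library alone (the file path and pad token below are illustrative; the real values come from tokenizer.json and tokenizer_config.json):

import numpy as np
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("tokenizer.json")  # illustrative path
pad_token = "<|endoftext|>"  # illustrative; read from tokenizer_cfg["pad_token"] above
tokenizer.enable_padding(length=77, pad_token=pad_token, pad_id=tokenizer.token_to_id(pad_token))
tokenizer.enable_truncation(max_length=77)

tokens = tokenizer.encode("a photo of a dog")
print(np.array([tokens.ids], dtype=np.int32).shape)  # (1, 77) regardless of input length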

machine-learning/immich_ml/models/clip/visual.py (Normal file, 77 lines)

@@ -0,0 +1,77 @@
import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any

import numpy as np
from numpy.typing import NDArray
from PIL import Image

from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import (
    crop_pil,
    decode_pil,
    get_pil_resampling,
    normalize,
    resize_pil,
    serialize_np_array,
    to_numpy,
)
from immich_ml.schemas import ModelSession, ModelTask, ModelType


class BaseCLIPVisualEncoder(InferenceModel):
    depends = []
    identity = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, inputs: Image.Image | bytes, **kwargs: Any) -> str:
        image = decode_pil(inputs)
        res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
        return serialize_np_array(res)

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        pass

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def preprocess_cfg_path(self) -> Path:
        return self.model_dir / "preprocess_cfg.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg


class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    def _load(self) -> ModelSession:
        size: list[int] | int = self.preprocess_cfg["size"]
        self.size = size[0] if isinstance(size, list) else size

        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        image = resize_pil(image, self.size)
        image = crop_pil(image, self.size)
        image_np = to_numpy(image)
        image_np = normalize(image_np, self.mean, self.std)
        return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}
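
The transform boils down to: resize the short side, center-crop, scale to [0, 1], normalize, and reorder HWC to NCHW. A self-contained sketch reusing the commit's own helpers, with illustrative constants (the real values come from preprocess_cfg.json):

import numpy as np
from PIL import Image

from immich_ml.models.transforms import crop_pil, normalize, resize_pil, to_numpy

size = 224  # illustrative stand-in for preprocess_cfg["size"]
mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)  # common CLIP defaults
std = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

image = Image.new("RGB", (640, 480))  # stand-in for a decoded photo
image_np = normalize(to_numpy(crop_pil(resize_pil(image, size), size)), mean, std)
print(np.expand_dims(image_np.transpose(2, 0, 1), 0).shape)  # (1, 3, 224, 224), NCHW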

machine-learning/immich_ml/models/constants.py (Normal file, 101 lines)

@@ -0,0 +1,101 @@
from immich_ml.config import clean_name
from immich_ml.schemas import ModelSource

_OPENCLIP_MODELS = {
    "RN101__openai",
    "RN101__yfcc15m",
    "RN50__cc12m",
    "RN50__openai",
    "RN50__yfcc15m",
    "RN50x16__openai",
    "RN50x4__openai",
    "RN50x64__openai",
    "ViT-B-16-SigLIP-256__webli",
    "ViT-B-16-SigLIP-384__webli",
    "ViT-B-16-SigLIP-512__webli",
    "ViT-B-16-SigLIP-i18n-256__webli",
    "ViT-B-16-SigLIP__webli",
    "ViT-B-16-plus-240__laion400m_e31",
    "ViT-B-16-plus-240__laion400m_e32",
    "ViT-B-16__laion400m_e31",
    "ViT-B-16__laion400m_e32",
    "ViT-B-16__openai",
    "ViT-B-32__laion2b-s34b-b79k",
    "ViT-B-32__laion2b_e16",
    "ViT-B-32__laion400m_e31",
    "ViT-B-32__laion400m_e32",
    "ViT-B-32__openai",
    "ViT-H-14-378-quickgelu__dfn5b",
    "ViT-H-14-quickgelu__dfn5b",
    "ViT-H-14__laion2b-s32b-b79k",
    "ViT-L-14-336__openai",
    "ViT-L-14-quickgelu__dfn2b",
    "ViT-L-14__laion2b-s32b-b82k",
    "ViT-L-14__laion400m_e31",
    "ViT-L-14__laion400m_e32",
    "ViT-L-14__openai",
    "ViT-L-16-SigLIP-256__webli",
    "ViT-L-16-SigLIP-384__webli",
    "ViT-SO400M-14-SigLIP-384__webli",
    "ViT-g-14__laion2b-s12b-b42k",
    "XLM-Roberta-Base-ViT-B-32__laion5b_s13b_b90k",
    "XLM-Roberta-Large-ViT-H-14__frozen_laion5b_s13b_b90k",
    "nllb-clip-base-siglip__mrl",
    "nllb-clip-base-siglip__v1",
    "nllb-clip-large-siglip__mrl",
    "nllb-clip-large-siglip__v1",
    "ViT-B-16-SigLIP2__webli",
    "ViT-B-32-SigLIP2-256__webli",
    "ViT-L-16-SigLIP2-256__webli",
    "ViT-L-16-SigLIP2-384__webli",
    "ViT-L-16-SigLIP2-512__webli",
    "ViT-SO400M-14-SigLIP2-378__webli",
    "ViT-SO400M-14-SigLIP2__webli",
    "ViT-SO400M-16-SigLIP2-256__webli",
    "ViT-SO400M-16-SigLIP2-384__webli",
    "ViT-SO400M-16-SigLIP2-512__webli",
    "ViT-gopt-16-SigLIP2-256__webli",
    "ViT-gopt-16-SigLIP2-384__webli",
}


_MCLIP_MODELS = {
    "LABSE-Vit-L-14",
    "XLM-Roberta-Large-Vit-B-16Plus",
    "XLM-Roberta-Large-Vit-B-32",
    "XLM-Roberta-Large-Vit-L-14",
}


_INSIGHTFACE_MODELS = {
    "antelopev2",
    "buffalo_s",
    "buffalo_m",
    "buffalo_l",
}


SUPPORTED_PROVIDERS = [
    "CUDAExecutionProvider",
    "ROCMExecutionProvider",
    "OpenVINOExecutionProvider",
    "CPUExecutionProvider",
]

RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]


def get_model_source(model_name: str) -> ModelSource | None:
    cleaned_name = clean_name(model_name)

    if cleaned_name in _INSIGHTFACE_MODELS:
        return ModelSource.INSIGHTFACE

    if cleaned_name in _MCLIP_MODELS:
        return ModelSource.MCLIP

    if cleaned_name in _OPENCLIP_MODELS:
        return ModelSource.OPENCLIP

    return None
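
The registries act as a routing table keyed on the cleaned model name; a quick sketch (assumes immich_ml is importable):

from immich_ml.models.constants import get_model_source
from immich_ml.schemas import ModelSource

assert get_model_source("ViT-B-32__openai") is ModelSource.OPENCLIP
assert get_model_source("buffalo_l") is ModelSource.INSIGHTFACE
assert get_model_source("no-such-model") is None  # unknown names fall through to None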

machine-learning/immich_ml/models/facial_recognition/detection.py (Normal file, 41 lines)

@@ -0,0 +1,41 @@
from typing import Any

import numpy as np
from insightface.model_zoo import RetinaFace
from numpy.typing import NDArray

from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2
from immich_ml.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType


class FaceDetector(InferenceModel):
    depends = []
    identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)

    def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
        self.min_score = model_kwargs.pop("minScore", min_score)
        super().__init__(model_name, **model_kwargs)

    def _load(self) -> ModelSession:
        session = self._make_session(self.model_path)
        self.model = RetinaFace(session=session)
        self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))

        return session

    def _predict(self, inputs: NDArray[np.uint8] | bytes, **kwargs: Any) -> FaceDetectionOutput:
        inputs = decode_cv2(inputs)

        bboxes, landmarks = self._detect(inputs)
        return {
            "boxes": bboxes[:, :4].round(),
            "scores": bboxes[:, 4],
            "landmarks": landmarks,
        }

    def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
        return self.model.detect(inputs)  # type: ignore

    def configure(self, **kwargs: Any) -> None:
        self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)
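
A hypothetical call path (the image path is illustrative, and the first predict() call downloads and loads the model):

from immich_ml.models.facial_recognition.detection import FaceDetector

detector = FaceDetector("buffalo_l", minScore=0.5)
with open("photo.jpg", "rb") as f:  # illustrative path
    faces = detector.predict(f.read(), minScore=0.8)  # per-call override via configure()
print(faces["boxes"].shape, faces["scores"].shape)  # (n, 4) rounded boxes, (n,) scores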

machine-learning/immich_ml/models/facial_recognition/recognition.py (Normal file, 92 lines)

@@ -0,0 +1,92 @@
from pathlib import Path
from typing import Any

import numpy as np
import onnx
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX
from insightface.utils.face_align import norm_crop
from numpy.typing import NDArray
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from PIL import Image

from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2, serialize_np_array
from immich_ml.schemas import (
    FaceDetectionOutput,
    FacialRecognitionOutput,
    ModelFormat,
    ModelSession,
    ModelTask,
    ModelType,
)


class FaceRecognizer(InferenceModel):
    depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
    identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)

    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
        super().__init__(model_name, **model_kwargs)
        max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
        self.batch_size = max_batch_size if max_batch_size else self._batch_size_default

    def _load(self) -> ModelSession:
        session = self._make_session(self.model_path)
        if (not self.batch_size or self.batch_size > 1) and str(session.get_inputs()[0].shape[0]) != "batch":
            self._add_batch_axis(self.model_path)
            session = self._make_session(self.model_path)
        self.model = ArcFaceONNX(
            self.model_path_for_format(ModelFormat.ONNX).as_posix(),
            session=session,
        )
        return session

    def _predict(
        self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput, **kwargs: Any
    ) -> FacialRecognitionOutput:
        if faces["boxes"].shape[0] == 0:
            return []
        inputs = decode_cv2(inputs)
        cropped_faces = self._crop(inputs, faces)
        embeddings = self._predict_batch(cropped_faces)
        return self.postprocess(faces, embeddings)

    def _predict_batch(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
        if not self.batch_size or len(cropped_faces) <= self.batch_size:
            embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
            return embeddings

        batch_embeddings: list[NDArray[np.float32]] = []
        for i in range(0, len(cropped_faces), self.batch_size):
            batch_embeddings.append(self.model.get_feat(cropped_faces[i : i + self.batch_size]))
        return np.concatenate(batch_embeddings, axis=0)

    def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
        return [
            {
                "boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
                "embedding": serialize_np_array(embedding),
                "score": score,
            }
            for (x1, y1, x2, y2), embedding, score in zip(faces["boxes"], embeddings, faces["scores"])
        ]

    def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
        return [norm_crop(image, landmark) for landmark in faces["landmarks"]]

    def _add_batch_axis(self, model_path: Path) -> None:
        log.debug(f"Adding batch axis to model {model_path}")
        proto = onnx.load(model_path)
        static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
        static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
        input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
        output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
        updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
        onnx.save(updated_proto, model_path)

    @property
    def _batch_size_default(self) -> int | None:
        providers = ort.get_available_providers()
        return None if self.model_format == ModelFormat.ONNX and "OpenVINOExecutionProvider" not in providers else 1
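
The chunking in _predict_batch is easiest to see in isolation; here get_feat is stubbed out, and the 512-wide embedding is an assumption (typical for ArcFace):

import numpy as np


def get_feat(batch: list[np.ndarray]) -> np.ndarray:  # stub for ArcFaceONNX.get_feat
    return np.zeros((len(batch), 512), dtype=np.float32)


cropped_faces = [np.zeros((112, 112, 3), dtype=np.uint8)] * 5
batch_size = 2
chunks = [get_feat(cropped_faces[i : i + batch_size]) for i in range(0, len(cropped_faces), batch_size)]
print(np.concatenate(chunks, axis=0).shape)  # (5, 512): chunks of 2, 2, and 1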

machine-learning/immich_ml/models/transforms.py (Normal file, 78 lines)

@@ -0,0 +1,78 @@
import string
from io import BytesIO
from typing import IO

import cv2
import numpy as np
import orjson
from numpy.typing import NDArray
from PIL import Image

_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)


def resize_pil(img: Image.Image, size: int) -> Image.Image:
    if img.width < img.height:
        return img.resize((size, int((img.height / img.width) * size)), resample=Image.Resampling.BICUBIC)
    else:
        return img.resize((int((img.width / img.height) * size), size), resample=Image.Resampling.BICUBIC)


# https://stackoverflow.com/a/60883103
def crop_pil(img: Image.Image, size: int) -> Image.Image:
    left = int((img.size[0] / 2) - (size / 2))
    upper = int((img.size[1] / 2) - (size / 2))
    right = left + size
    lower = upper + size

    return img.crop((left, upper, right, lower))


def to_numpy(img: Image.Image) -> NDArray[np.float32]:
    return np.asarray(img if img.mode == "RGB" else img.convert("RGB"), dtype=np.float32) / 255.0


def normalize(
    img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
) -> NDArray[np.float32]:
    return np.divide(img - mean, std, dtype=np.float32)


def get_pil_resampling(resample: str) -> Image.Resampling:
    return _PIL_RESAMPLING_METHODS[resample.lower()]


def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # type: ignore


def decode_pil(image_bytes: bytes | IO[bytes] | Image.Image) -> Image.Image:
    if isinstance(image_bytes, Image.Image):
        return image_bytes
    image: Image.Image = Image.open(BytesIO(image_bytes) if isinstance(image_bytes, bytes) else image_bytes)
    image.load()
    if not image.mode == "RGB":
        image = image.convert("RGB")
    return image


def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
    if isinstance(image_bytes, bytes):
        image_bytes = decode_pil(image_bytes)  # pillow is much faster than cv2
    if isinstance(image_bytes, Image.Image):
        return pil_to_cv2(image_bytes)
    return image_bytes


def clean_text(text: str, canonicalize: bool = False) -> str:
    text = " ".join(text.split())
    if canonicalize:
        text = text.translate(_PUNCTUATION_TRANS).lower()
    return text


# this allows the client to use the array as a string without deserializing only to serialize back to a string
# TODO: use this in a less invasive way
def serialize_np_array(arr: NDArray[np.float32]) -> str:
    return orjson.dumps(arr, option=orjson.OPT_SERIALIZE_NUMPY).decode()
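
Two of these helpers are small enough to demo directly (assumes immich_ml is importable):

import numpy as np

from immich_ml.models.transforms import clean_text, serialize_np_array

print(clean_text("  A   photo, of a DOG!  ", canonicalize=True))  # a photo of a dog
print(serialize_np_array(np.array([0.25, 0.5], dtype=np.float32)))  # [0.25,0.5]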