mirror of
https://github.com/immich-app/immich
synced 2025-10-17 18:19:27 +00:00
refactor(ml): model sessions (#10559)
This commit is contained in:
parent
6538ad8de7
commit
6356c28f64
11 changed files with 529 additions and 375 deletions
|
|
@ -52,8 +52,6 @@ class Ann(metaclass=_Singleton):
|
|||
def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
|
||||
if not is_available:
|
||||
raise RuntimeError("libann is not available!")
|
||||
if tuning_file and not exists(tuning_file):
|
||||
raise ValueError("tuning_file must point to an existing (possibly empty) file!")
|
||||
if tuning_level == 0 and tuning_file is None:
|
||||
raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
|
||||
if tuning_level < 0 or tuning_level > 3:
|
||||
|
|
@ -67,6 +65,12 @@ class Ann(metaclass=_Singleton):
|
|||
self.input_shapes: dict[int, tuple[tuple[int], ...]] = {}
|
||||
self.ann: int | None = None
|
||||
self.new()
|
||||
|
||||
if self.tuning_file is not None:
|
||||
# make sure tuning file exists (without clearing contents)
|
||||
# once filled, the tuning file reduces the cost/time of the first
|
||||
# inference after model load by 10s of seconds
|
||||
open(self.tuning_file, "a").close()
|
||||
|
||||
def new(self) -> None:
|
||||
if self.ann is None:
|
||||
|
|
@ -95,17 +99,19 @@ class Ann(metaclass=_Singleton):
|
|||
model_path: str,
|
||||
fast_math: bool = True,
|
||||
fp16: bool = False,
|
||||
save_cached_network: bool = False,
|
||||
cached_network_path: str | None = None,
|
||||
) -> int:
|
||||
if not model_path.endswith((".armnn", ".tflite", ".onnx")):
|
||||
raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
|
||||
if not exists(model_path):
|
||||
raise ValueError("model_path must point to an existing file!")
|
||||
|
||||
save_cached_network = False
|
||||
if cached_network_path is not None and not exists(cached_network_path):
|
||||
raise ValueError("cached_network_path must point to an existing (possibly empty) file!")
|
||||
if save_cached_network and cached_network_path is None:
|
||||
raise ValueError("save_cached_network is True, cached_network_path must be specified!")
|
||||
save_cached_network = True
|
||||
# create empty model cache file
|
||||
open(cached_network_path, "a").close()
|
||||
|
||||
net_id: int = libann.load(
|
||||
self.ann,
|
||||
model_path.encode(),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue