Mirror of https://github.com/immich-app/immich (synced 2025-10-17 18:19:27 +00:00)
chore(ml): installable package (#17153)
* app -> immich_ml
* fix test ci
* omit file name
* add new line
* add new line
Parent: f7d730eb05
Commit: 84c35e35d6
31 changed files with 347 additions and 316 deletions
machine-learning/immich_ml/sessions/__init__.py (new empty file)
machine-learning/immich_ml/sessions/ann/__init__.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

import numpy as np
from numpy.typing import NDArray

from immich_ml.config import log, settings
from immich_ml.schemas import SessionNode

from .loader import Ann


class AnnSession:
    """
    Wrapper for ANN to be a drop-in replacement for an ONNX session.
    """

    def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
        self.model_path = model_path
        self.cache_dir = cache_dir
        self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())

        log.info("Loading ANN model %s ...", model_path)
        self.model = self.ann.load(
            model_path.as_posix(),
            cached_network_path=model_path.with_suffix(".anncache").as_posix(),
            fp16=settings.ann_fp16_turbo,
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        self.ann.unload(self.model)
        log.info("Unloaded ANN model %d", self.model)
        self.ann.destroy()

    def get_inputs(self) -> list[SessionNode]:
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[SessionNode]:
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)


class AnnNode(NamedTuple):
    name: str | None
    shape: tuple[int, ...]
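Since AnnSession exposes the same get_inputs/get_outputs/run surface as an ONNX Runtime session, callers can swap it in without code changes. A minimal usage sketch, assuming a converted .armnn model (the model path and input key are illustrative, and loading requires libmali.so and libann.so at runtime):

import numpy as np
from pathlib import Path

from immich_ml.sessions.ann import AnnSession

session = AnnSession(Path("/cache/model.armnn"))  # hypothetical model path
shape = session.get_inputs()[0].shape
# run() ignores output_names and the dict keys; tensors are consumed in dict order
outputs = session.run(None, {"input": np.zeros(shape, dtype=np.float32)})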
machine-learning/immich_ml/sessions/ann/loader.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from __future__ import annotations

from ctypes import CDLL, Array, c_bool, c_char_p, c_int, c_ulong, c_void_p
from os.path import exists
from typing import Any, Protocol, TypeVar

import numpy as np
from numpy.typing import NDArray

from immich_ml.config import log

try:
    CDLL("libmali.so")  # fail if libmali.so is not mounted into container
    libann = CDLL("libann.so")
    libann.init.argtypes = c_int, c_int, c_char_p
    libann.init.restype = c_void_p
    libann.load.argtypes = c_void_p, c_char_p, c_bool, c_bool, c_bool, c_char_p
    libann.load.restype = c_int
    libann.execute.argtypes = c_void_p, c_int, Array[c_void_p], Array[c_void_p]
    libann.unload.argtypes = c_void_p, c_int
    libann.destroy.argtypes = (c_void_p,)
    libann.shape.argtypes = c_void_p, c_int, c_bool, c_int
    libann.shape.restype = c_ulong
    libann.tensors.argtypes = c_void_p, c_int, c_bool
    libann.tensors.restype = c_int
    is_available = True
except OSError as e:
    log.debug("Could not load ANN shared libraries, using ONNX: %s", e)
    is_available = False


T = TypeVar("T", covariant=True)


class Newable(Protocol[T]):
    def new(self) -> None: ...


class _Singleton(type, Newable[T]):
    _instances: dict[_Singleton[T], Newable[T]] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Newable[T]:
        if cls not in cls._instances:
            obj: Newable[T] = super(_Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = obj
        else:
            obj = cls._instances[cls]
            obj.new()
        return obj


class Ann(metaclass=_Singleton):
    def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
        if not is_available:
            raise RuntimeError("libann is not available!")
        if tuning_level == 0 and tuning_file is None:
            raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
        if tuning_level < 0 or tuning_level > 3:
            raise ValueError("tuning_level must be 0 (load from tuning_file), 1, 2 or 3.")
        if log_level < 0 or log_level > 5:
            raise ValueError("log_level must be 0 (trace), 1 (debug), 2 (info), 3 (warning), 4 (error) or 5 (fatal)")
        self.log_level = log_level
        self.tuning_level = tuning_level
        self.tuning_file = tuning_file
        self.output_shapes: dict[int, tuple[tuple[int], ...]] = {}
        self.input_shapes: dict[int, tuple[tuple[int], ...]] = {}
        self.ann: int | None = None
        self.new()

        if self.tuning_file is not None:
            # make sure tuning file exists (without clearing contents)
            # once filled, the tuning file reduces the cost/time of the first
            # inference after model load by 10s of seconds
            open(self.tuning_file, "a").close()

    def new(self) -> None:
        if self.ann is None:
            self.ann = libann.init(
                self.log_level,
                self.tuning_level,
                self.tuning_file.encode() if self.tuning_file is not None else None,
            )
            self.ref_count = 0

        self.ref_count += 1

    def destroy(self) -> None:
        self.ref_count -= 1
        if self.ref_count <= 0 and self.ann is not None:
            libann.destroy(self.ann)
            self.ann = None

    def __del__(self) -> None:
        if self.ann is not None:
            libann.destroy(self.ann)
            self.ann = None

    def load(
        self,
        model_path: str,
        fast_math: bool = True,
        fp16: bool = False,
        cached_network_path: str | None = None,
    ) -> int:
        if not model_path.endswith((".armnn", ".tflite", ".onnx")):
            raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
        if not exists(model_path):
            raise ValueError("model_path must point to an existing file!")

        save_cached_network = False
        if cached_network_path is not None and not exists(cached_network_path):
            save_cached_network = True
            # create empty model cache file
            open(cached_network_path, "a").close()

        net_id: int = libann.load(
            self.ann,
            model_path.encode(),
            fast_math,
            fp16,
            save_cached_network,
            cached_network_path.encode() if cached_network_path is not None else None,
        )
        if net_id < 0:
            raise ValueError("Cannot load model!")

        self.input_shapes[net_id] = tuple(
            self.shape(net_id, input=True, index=i) for i in range(self.tensors(net_id, input=True))
        )
        self.output_shapes[net_id] = tuple(
            self.shape(net_id, input=False, index=i) for i in range(self.tensors(net_id, input=False))
        )
        return net_id

    def unload(self, network_id: int) -> None:
        libann.unload(self.ann, network_id)
        del self.output_shapes[network_id]

    def execute(self, network_id: int, input_tensors: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
        if not isinstance(input_tensors, list):
            raise ValueError("input_tensors needs to be a list!")
        net_input_shapes = self.input_shapes[network_id]
        if len(input_tensors) != len(net_input_shapes):
            raise ValueError(f"input_tensors lengths {len(input_tensors)} != network inputs {len(net_input_shapes)}")
        for net_input_shape, input_tensor in zip(net_input_shapes, input_tensors):
            if net_input_shape != input_tensor.shape:
                raise ValueError(f"input_tensor shape {input_tensor.shape} != network input shape {net_input_shape}")
            if not input_tensor.flags.c_contiguous:
                raise ValueError("input_tensors must be c_contiguous numpy ndarrays")
        output_tensors: list[NDArray[np.float32]] = [
            np.ndarray(s, dtype=np.float32) for s in self.output_shapes[network_id]
        ]
        input_type = c_void_p * len(input_tensors)
        inputs = input_type(*[t.ctypes.data_as(c_void_p) for t in input_tensors])
        output_type = c_void_p * len(output_tensors)
        outputs = output_type(*[t.ctypes.data_as(c_void_p) for t in output_tensors])
        libann.execute(self.ann, network_id, inputs, outputs)
        return output_tensors

    def shape(self, network_id: int, input: bool = False, index: int = 0) -> tuple[int]:
        s = libann.shape(self.ann, network_id, input, index)
        a = []
        while s != 0:
            a.append(s & 0xFFFF)
            s >>= 16
        return tuple(a)

    def tensors(self, network_id: int, input: bool = False) -> int:
        tensors: int = libann.tensors(self.ann, network_id, input)
        return tensors
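libann reports each tensor shape as a single packed integer, 16 bits per dimension with the lowest bits first, and Ann.shape simply unpacks it. A self-contained sketch of the same decoding with a worked value:

def unpack_shape(packed: int) -> tuple[int, ...]:
    # 16 bits per dimension; the least significant 16 bits become the first element
    dims = []
    while packed != 0:
        dims.append(packed & 0xFFFF)
        packed >>= 16
    return tuple(dims)

# (1 << 48) | (3 << 32) | (224 << 16) | 224 encodes four 16-bit fields:
assert unpack_shape((1 << 48) | (3 << 32) | (224 << 16) | 224) == (224, 224, 3, 1)

This encoding caps every dimension at 65535 and cannot represent trailing zero-sized dimensions, a limitation the fixed-shape vision models here do not hit.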
machine-learning/immich_ml/sessions/ort.py (new file, 135 lines)
@@ -0,0 +1,135 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray

from immich_ml.models.constants import SUPPORTED_PROVIDERS
from immich_ml.schemas import SessionNode

from ..config import log, settings


class OrtSession:
    def __init__(
        self,
        model_path: Path | str,
        providers: list[str] | None = None,
        provider_options: list[dict[str, Any]] | None = None,
        sess_options: ort.SessionOptions | None = None,
    ):
        self.model_path = Path(model_path)
        self.providers = providers if providers is not None else self._providers_default
        self.provider_options = provider_options if provider_options is not None else self._provider_options_default
        self.sess_options = sess_options if sess_options is not None else self._sess_options_default
        self.session = ort.InferenceSession(
            self.model_path.as_posix(),
            providers=self.providers,
            provider_options=self.provider_options,
            sess_options=self.sess_options,
        )

    def get_inputs(self) -> list[SessionNode]:
        inputs: list[SessionNode] = self.session.get_inputs()
        return inputs

    def get_outputs(self) -> list[SessionNode]:
        outputs: list[SessionNode] = self.session.get_outputs()
        return outputs

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        outputs: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
        return outputs

    @property
    def providers(self) -> list[str]:
        return self._providers

    @providers.setter
    def providers(self, providers: list[str]) -> None:
        log.info(f"Setting execution providers to {providers}, in descending order of preference")
        self._providers = providers

    @property
    def _providers_default(self) -> list[str]:
        available_providers = set(ort.get_available_providers())
        log.debug(f"Available ORT providers: {available_providers}")
        if (openvino := "OpenVINOExecutionProvider") in available_providers:
            device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
            log.debug(f"Available OpenVINO devices: {device_ids}")

            gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
            if not gpu_devices:
                log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
                available_providers.remove(openvino)
        return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]

    @property
    def provider_options(self) -> list[dict[str, Any]]:
        return self._provider_options

    @provider_options.setter
    def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
        log.debug(f"Setting execution provider options to {provider_options}")
        self._provider_options = provider_options

    @property
    def _provider_options_default(self) -> list[dict[str, Any]]:
        provider_options = []
        for provider in self.providers:
            match provider:
                case "CPUExecutionProvider":
                    options = {"arena_extend_strategy": "kSameAsRequested"}
                case "CUDAExecutionProvider" | "ROCMExecutionProvider":
                    options = {"arena_extend_strategy": "kSameAsRequested", "device_id": settings.device_id}
                case "OpenVINOExecutionProvider":
                    options = {
                        "device_type": f"GPU.{settings.device_id}",
                        "precision": "FP32",
                        "cache_dir": (self.model_path.parent / "openvino").as_posix(),
                    }
                case _:
                    options = {}
            provider_options.append(options)
        return provider_options

    @property
    def sess_options(self) -> ort.SessionOptions:
        return self._sess_options

    @sess_options.setter
    def sess_options(self, sess_options: ort.SessionOptions) -> None:
        log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
        log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
        log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
        self._sess_options = sess_options

    @property
    def _sess_options_default(self) -> ort.SessionOptions:
        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = False

        # avoid thread contention between models
        if settings.model_inter_op_threads > 0:
            sess_options.inter_op_num_threads = settings.model_inter_op_threads
        # these defaults work well for CPU, but bottleneck GPU
        elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
            sess_options.inter_op_num_threads = 1

        if settings.model_intra_op_threads > 0:
            sess_options.intra_op_num_threads = settings.model_intra_op_threads
        elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
            sess_options.intra_op_num_threads = 2

        if sess_options.inter_op_num_threads > 1:
            sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

        return sess_options
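When providers is None, _providers_default filters SUPPORTED_PROVIDERS against what onnxruntime reports as available, demoting OpenVINO when no GPU device is found; explicit arguments bypass that logic entirely. A sketch of the explicit path (the model path and option values are illustrative only):

import numpy as np

from immich_ml.sessions.ort import OrtSession

session = OrtSession(
    "/cache/model.onnx",  # hypothetical ONNX model
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"device_id": "0"}, {}],  # one options dict per provider, in order
)
outputs = session.run(None, {"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})

Unlike AnnSession, onnxruntime feeds tensors by name, so the "input" key must match the model's actual input name.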
machine-learning/immich_ml/sessions/rknn/__init__.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, NamedTuple

import numpy as np
from numpy.typing import NDArray

from immich_ml.config import log, settings
from immich_ml.schemas import SessionNode

from .rknnpool import RknnPoolExecutor, is_available, soc_name

is_available = is_available and settings.rknn
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None


def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
    outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=input, data_format="nchw")
    return outputs


input_output_mapping: dict[str, dict[str, Any]] = {
    "detection": {
        "input": {"norm_tensor:0": (1, 3, 640, 640)},
        "output": {
            "norm_tensor:1": (12800, 1),
            "norm_tensor:2": (3200, 1),
            "norm_tensor:3": (800, 1),
            "norm_tensor:4": (12800, 4),
            "norm_tensor:5": (3200, 4),
            "norm_tensor:6": (800, 4),
            "norm_tensor:7": (12800, 10),
            "norm_tensor:8": (3200, 10),
            "norm_tensor:9": (800, 10),
        },
    },
    "recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}},
}


class RknnSession:
    def __init__(self, model_path: Path) -> None:
        self.model_type = "detection" if "detection" in model_path.parts else "recognition"
        self.tpe = settings.rknn_threads

        log.info(f"Loading RKNN model from {model_path} with {self.tpe} threads.")
        self.rknnpool = RknnPoolExecutor(model_path=model_path.as_posix(), tpes=self.tpe, func=run_inference)
        log.info(f"Loaded RKNN model from {model_path} with {self.tpe} threads.")

    def get_inputs(self) -> list[SessionNode]:
        return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["input"].items()]

    def get_outputs(self) -> list[SessionNode]:
        return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["output"].items()]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        self.rknnpool.put(input_data)
        res = self.rknnpool.get()
        if res is None:
            raise RuntimeError("RKNN inference failed!")
        return res


class RknnNode(NamedTuple):
    name: str | None
    shape: tuple[int, ...]


__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]
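RknnSession infers the I/O contract from the model path rather than from the model file: any path containing a "detection" component gets the 640x640 detection mapping, everything else falls back to the 112x112 recognition mapping. A usage sketch (the path is hypothetical, and this only works on a supported Rockchip SoC with rknnlite installed):

from pathlib import Path

from immich_ml.sessions.rknn import RknnSession

session = RknnSession(Path("/cache/detection/model.rknn"))  # hypothetical path
for node in session.get_inputs():
    print(node.name, node.shape)  # norm_tensor:0 (1, 3, 640, 640)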
machine-learning/immich_ml/sessions/rknn/rknnpool.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# This code is from leafqycc/rknn-multi-threaded
# Following Apache License 2.0

import logging
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable

import numpy as np
from numpy.typing import NDArray

from immich_ml.config import log
from immich_ml.models.constants import RKNN_COREMASK_SUPPORTED_SOCS, RKNN_SUPPORTED_SOCS


def get_soc(device_tree_path: Path | str) -> str | None:
    try:
        with Path(device_tree_path).open() as f:
            device_compatible_str = f.read()
            for soc in RKNN_SUPPORTED_SOCS:
                if soc in device_compatible_str:
                    return soc
            log.warning("Device is not supported for RKNN")
    except OSError as e:
        log.warning(f"Could not read {device_tree_path}. Reason: %s", e)
    return None


soc_name = None
is_available = False
try:
    from rknnlite.api import RKNNLite

    soc_name = get_soc("/proc/device-tree/compatible")
    is_available = soc_name is not None
except ImportError:
    log.debug("RKNN is not available")


def init_rknn(model_path: str) -> "RKNNLite":
    if not is_available:
        raise RuntimeError("rknn is not available!")
    rknn_lite = RKNNLite()
    rknn_lite.rknn_log.logger.setLevel(logging.ERROR)
    ret = rknn_lite.load_rknn(model_path)
    if ret != 0:
        raise RuntimeError("Failed to load RKNN model")

    if soc_name in RKNN_COREMASK_SUPPORTED_SOCS:
        ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
    else:
        ret = rknn_lite.init_runtime()  # Please do not set this parameter on other platforms.

    if ret != 0:
        raise RuntimeError("Failed to initialize RKNN runtime environment")

    return rknn_lite


class RknnPoolExecutor:
    def __init__(
        self,
        model_path: str,
        tpes: int,
        func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
    ) -> None:
        self.tpes = tpes
        self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
        self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
        self.pool = ThreadPoolExecutor(max_workers=tpes)
        self.func = func
        self.num = 0

    def put(self, inputs: list[NDArray[np.float32]]) -> None:
        self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
        self.num += 1

    def get(self) -> list[NDArray[np.float32]] | None:
        if self.queue.empty():
            return None
        fut = self.queue.get()
        return fut.result()

    def release(self) -> None:
        self.pool.shutdown()
        for rknn_lite in self.rknn_pool:
            rknn_lite.release()

    def __del__(self) -> None:
        self.release()
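The pool exists to keep multiple NPU cores busy: put() round-robins submissions across tpes independent RKNNLite runtimes via a thread pool, and get() drains the futures in FIFO order, so several batches can be in flight at once. A usage sketch under the same assumptions as above (supported SoC, hypothetical model path):

import numpy as np

from immich_ml.sessions.rknn import run_inference
from immich_ml.sessions.rknn.rknnpool import RknnPoolExecutor

pool = RknnPoolExecutor(model_path="/cache/model.rknn", tpes=2, func=run_inference)
batches = [np.zeros((1, 3, 640, 640), dtype=np.float32) for _ in range(4)]
for batch in batches:
    pool.put([batch])  # submit without blocking; work overlaps across the two runtimes
results = [pool.get() for _ in batches]  # blocks on each future in submission order
pool.release()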