mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
feat(ml): ARMNN acceleration (#5667)
* feat(ml): ARMNN acceleration for CLIP * wrap ANN as ONNX-Session * strict typing * normalize ARMNN CLIP embedding * mutex to handle concurrent execution * make inputs contiguous * fine-grained locking; concurrent network execution --------- Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:
parent
29747437f6
commit
753292956e
17 changed files with 956 additions and 44 deletions
2
machine-learning/ann/export/.gitignore
vendored
Normal file
2
machine-learning/ann/export/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
armnn*
|
||||
output/
|
||||
4
machine-learning/ann/export/build-converter.sh
Executable file
4
machine-learning/ann/export/build-converter.sh
Executable file
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/sh
|
||||
|
||||
cd armnn-23.11/
|
||||
g++ -o ../armnnconverter -O1 -DARMNN_ONNX_PARSER -DARMNN_SERIALIZER -DARMNN_TF_LITE_PARSER -fuse-ld=gold -std=c++17 -Iinclude -Isrc/armnnUtils -Ithird-party -larmnn -larmnnDeserializer -larmnnTfLiteParser -larmnnOnnxParser -larmnnSerializer -L../armnn src/armnnConverter/ArmnnConverter.cpp
|
||||
8
machine-learning/ann/export/download-armnn.sh
Executable file
8
machine-learning/ann/export/download-armnn.sh
Executable file
|
|
@ -0,0 +1,8 @@
|
|||
#!/bin/sh
|
||||
|
||||
# binaries
|
||||
mkdir armnn
|
||||
curl -SL "https://github.com/ARM-software/armnn/releases/download/v23.11/ArmNN-linux-x86_64.tar.gz" | tar -zx -C armnn
|
||||
|
||||
# source to build ArmnnConverter
|
||||
curl -SL "https://github.com/ARM-software/armnn/archive/refs/tags/v23.11.tar.gz" | tar -zx
|
||||
201
machine-learning/ann/export/env.yaml
Normal file
201
machine-learning/ann/export/env.yaml
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
name: annexport
|
||||
channels:
|
||||
- pytorch
|
||||
- nvidia
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=conda_forge
|
||||
- _openmp_mutex=4.5=2_kmp_llvm
|
||||
- aiohttp=3.9.1=py310h2372a71_0
|
||||
- aiosignal=1.3.1=pyhd8ed1ab_0
|
||||
- arpack=3.8.0=nompi_h0baa96a_101
|
||||
- async-timeout=4.0.3=pyhd8ed1ab_0
|
||||
- attrs=23.1.0=pyh71513ae_1
|
||||
- aws-c-auth=0.7.3=h28f7589_1
|
||||
- aws-c-cal=0.6.1=hc309b26_1
|
||||
- aws-c-common=0.9.0=hd590300_0
|
||||
- aws-c-compression=0.2.17=h4d4d85c_2
|
||||
- aws-c-event-stream=0.3.1=h2e3709c_4
|
||||
- aws-c-http=0.7.11=h00aa349_4
|
||||
- aws-c-io=0.13.32=he9a53bd_1
|
||||
- aws-c-mqtt=0.9.3=hb447be9_1
|
||||
- aws-c-s3=0.3.14=hf3aad02_1
|
||||
- aws-c-sdkutils=0.1.12=h4d4d85c_1
|
||||
- aws-checksums=0.1.17=h4d4d85c_1
|
||||
- aws-crt-cpp=0.21.0=hb942446_5
|
||||
- aws-sdk-cpp=1.10.57=h85b1a90_19
|
||||
- blas=2.120=openblas
|
||||
- blas-devel=3.9.0=20_linux64_openblas
|
||||
- brotli-python=1.0.9=py310hd8f1fbe_9
|
||||
- bzip2=1.0.8=hd590300_5
|
||||
- c-ares=1.23.0=hd590300_0
|
||||
- ca-certificates=2023.11.17=hbcca054_0
|
||||
- certifi=2023.11.17=pyhd8ed1ab_0
|
||||
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
||||
- click=8.1.7=unix_pyh707e725_0
|
||||
- colorama=0.4.6=pyhd8ed1ab_0
|
||||
- coloredlogs=15.0.1=pyhd8ed1ab_3
|
||||
- cuda-cudart=11.7.99=0
|
||||
- cuda-cupti=11.7.101=0
|
||||
- cuda-libraries=11.7.1=0
|
||||
- cuda-nvrtc=11.7.99=0
|
||||
- cuda-nvtx=11.7.91=0
|
||||
- cuda-runtime=11.7.1=0
|
||||
- dataclasses=0.8=pyhc8e2a94_3
|
||||
- datasets=2.14.7=pyhd8ed1ab_0
|
||||
- dill=0.3.7=pyhd8ed1ab_0
|
||||
- filelock=3.13.1=pyhd8ed1ab_0
|
||||
- flatbuffers=23.5.26=h59595ed_1
|
||||
- freetype=2.12.1=h267a509_2
|
||||
- frozenlist=1.4.0=py310h2372a71_1
|
||||
- fsspec=2023.10.0=pyhca7485f_0
|
||||
- ftfy=6.1.3=pyhd8ed1ab_0
|
||||
- gflags=2.2.2=he1b5a44_1004
|
||||
- glog=0.6.0=h6f12383_0
|
||||
- glpk=5.0=h445213a_0
|
||||
- gmp=6.3.0=h59595ed_0
|
||||
- gmpy2=2.1.2=py310h3ec546c_1
|
||||
- huggingface_hub=0.17.3=pyhd8ed1ab_0
|
||||
- humanfriendly=10.0=pyhd8ed1ab_6
|
||||
- icu=73.2=h59595ed_0
|
||||
- idna=3.6=pyhd8ed1ab_0
|
||||
- importlib-metadata=7.0.0=pyha770c72_0
|
||||
- importlib_metadata=7.0.0=hd8ed1ab_0
|
||||
- joblib=1.3.2=pyhd8ed1ab_0
|
||||
- keyutils=1.6.1=h166bdaf_0
|
||||
- krb5=1.21.2=h659d440_0
|
||||
- lcms2=2.15=h7f713cb_2
|
||||
- ld_impl_linux-64=2.40=h41732ed_0
|
||||
- lerc=4.0.0=h27087fc_0
|
||||
- libabseil=20230125.3=cxx17_h59595ed_0
|
||||
- libarrow=12.0.1=hb87d912_8_cpu
|
||||
- libblas=3.9.0=20_linux64_openblas
|
||||
- libbrotlicommon=1.0.9=h166bdaf_9
|
||||
- libbrotlidec=1.0.9=h166bdaf_9
|
||||
- libbrotlienc=1.0.9=h166bdaf_9
|
||||
- libcblas=3.9.0=20_linux64_openblas
|
||||
- libcrc32c=1.1.2=h9c3ff4c_0
|
||||
- libcublas=11.10.3.66=0
|
||||
- libcufft=10.7.2.124=h4fbf590_0
|
||||
- libcufile=1.8.1.2=0
|
||||
- libcurand=10.3.4.101=0
|
||||
- libcurl=8.5.0=hca28451_0
|
||||
- libcusolver=11.4.0.1=0
|
||||
- libcusparse=11.7.4.91=0
|
||||
- libdeflate=1.19=hd590300_0
|
||||
- libedit=3.1.20191231=he28a2e2_2
|
||||
- libev=4.33=hd590300_2
|
||||
- libevent=2.1.12=hf998b51_1
|
||||
- libffi=3.4.2=h7f98852_5
|
||||
- libgcc-ng=13.2.0=h807b86a_3
|
||||
- libgfortran-ng=13.2.0=h69a702a_3
|
||||
- libgfortran5=13.2.0=ha4646dd_3
|
||||
- libgoogle-cloud=2.12.0=hac9eb74_1
|
||||
- libgrpc=1.54.3=hb20ce57_0
|
||||
- libhwloc=2.9.3=default_h554bfaf_1009
|
||||
- libiconv=1.17=hd590300_1
|
||||
- libjpeg-turbo=2.1.5.1=hd590300_1
|
||||
- liblapack=3.9.0=20_linux64_openblas
|
||||
- liblapacke=3.9.0=20_linux64_openblas
|
||||
- libnghttp2=1.58.0=h47da74e_1
|
||||
- libnpp=11.7.4.75=0
|
||||
- libnsl=2.0.1=hd590300_0
|
||||
- libnuma=2.0.16=h0b41bf4_1
|
||||
- libnvjpeg=11.8.0.2=0
|
||||
- libopenblas=0.3.25=pthreads_h413a1c8_0
|
||||
- libpng=1.6.39=h753d276_0
|
||||
- libprotobuf=3.21.12=hfc55251_2
|
||||
- libsentencepiece=0.1.99=h180e1df_0
|
||||
- libsqlite=3.44.2=h2797004_0
|
||||
- libssh2=1.11.0=h0841786_0
|
||||
- libstdcxx-ng=13.2.0=h7e041cc_3
|
||||
- libthrift=0.18.1=h8fd135c_2
|
||||
- libtiff=4.6.0=h29866fb_1
|
||||
- libutf8proc=2.8.0=h166bdaf_0
|
||||
- libuuid=2.38.1=h0b41bf4_0
|
||||
- libwebp-base=1.3.2=hd590300_0
|
||||
- libxcb=1.15=h0b41bf4_0
|
||||
- libxml2=2.11.6=h232c23b_0
|
||||
- libzlib=1.2.13=hd590300_5
|
||||
- llvm-openmp=17.0.6=h4dfa4b3_0
|
||||
- lz4-c=1.9.4=hcb278e6_0
|
||||
- mkl=2022.2.1=h84fe81f_16997
|
||||
- mkl-devel=2022.2.1=ha770c72_16998
|
||||
- mkl-include=2022.2.1=h84fe81f_16997
|
||||
- mpc=1.3.1=hfe3b2da_0
|
||||
- mpfr=4.2.1=h9458935_0
|
||||
- mpmath=1.3.0=pyhd8ed1ab_0
|
||||
- multidict=6.0.4=py310h2372a71_1
|
||||
- multiprocess=0.70.15=py310h2372a71_1
|
||||
- ncurses=6.4=h59595ed_2
|
||||
- numpy=1.26.2=py310hb13e2d6_0
|
||||
- onnx=1.14.0=py310ha3deec4_1
|
||||
- onnx2torch=1.5.13=pyhd8ed1ab_0
|
||||
- onnxruntime=1.16.3=py310hd4b7fbc_1_cpu
|
||||
- open-clip-torch=2.23.0=pyhd8ed1ab_1
|
||||
- openblas=0.3.25=pthreads_h7a3da1a_0
|
||||
- openjpeg=2.5.0=h488ebb8_3
|
||||
- openssl=3.2.0=hd590300_1
|
||||
- orc=1.9.0=h2f23424_1
|
||||
- packaging=23.2=pyhd8ed1ab_0
|
||||
- pandas=2.1.4=py310hcc13569_0
|
||||
- pillow=10.0.1=py310h29da1c1_1
|
||||
- pip=23.3.1=pyhd8ed1ab_0
|
||||
- protobuf=4.21.12=py310heca2aa9_0
|
||||
- pthread-stubs=0.4=h36c2ea0_1001
|
||||
- pyarrow=12.0.1=py310h0576679_8_cpu
|
||||
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
|
||||
- pysocks=1.7.1=pyha2e5f31_6
|
||||
- python=3.10.13=hd12c33a_0_cpython
|
||||
- python-dateutil=2.8.2=pyhd8ed1ab_0
|
||||
- python-flatbuffers=23.5.26=pyhd8ed1ab_0
|
||||
- python-tzdata=2023.3=pyhd8ed1ab_0
|
||||
- python-xxhash=3.4.1=py310h2372a71_0
|
||||
- python_abi=3.10=4_cp310
|
||||
- pytorch=1.13.1=cpu_py310hd11e9c7_1
|
||||
- pytorch-cuda=11.7=h778d358_5
|
||||
- pytorch-mutex=1.0=cuda
|
||||
- pytz=2023.3.post1=pyhd8ed1ab_0
|
||||
- pyyaml=6.0.1=py310h2372a71_1
|
||||
- rdma-core=28.9=h59595ed_1
|
||||
- re2=2023.03.02=h8c504da_0
|
||||
- readline=8.2=h8228510_1
|
||||
- regex=2023.10.3=py310h2372a71_0
|
||||
- requests=2.31.0=pyhd8ed1ab_0
|
||||
- s2n=1.3.49=h06160fa_0
|
||||
- sacremoses=0.0.53=pyhd8ed1ab_0
|
||||
- safetensors=0.3.3=py310hcb5633a_1
|
||||
- sentencepiece=0.1.99=hff52083_0
|
||||
- sentencepiece-python=0.1.99=py310hebdb9f0_0
|
||||
- sentencepiece-spm=0.1.99=h180e1df_0
|
||||
- setuptools=68.2.2=pyhd8ed1ab_0
|
||||
- six=1.16.0=pyh6c4a22f_0
|
||||
- sleef=3.5.1=h9b69904_2
|
||||
- snappy=1.1.10=h9fff704_0
|
||||
- sympy=1.12=pypyh9d50eac_103
|
||||
- tbb=2021.11.0=h00ab1b0_0
|
||||
- texttable=1.7.0=pyhd8ed1ab_0
|
||||
- timm=0.9.12=pyhd8ed1ab_0
|
||||
- tk=8.6.13=noxft_h4845f30_101
|
||||
- tokenizers=0.14.1=py310h320607d_2
|
||||
- torchvision=0.14.1=cpu_py310hd3d2ac3_1
|
||||
- tqdm=4.66.1=pyhd8ed1ab_0
|
||||
- transformers=4.35.2=pyhd8ed1ab_0
|
||||
- typing-extensions=4.9.0=hd8ed1ab_0
|
||||
- typing_extensions=4.9.0=pyha770c72_0
|
||||
- tzdata=2023c=h71feb2d_0
|
||||
- ucx=1.14.1=h64cca9d_5
|
||||
- urllib3=2.1.0=pyhd8ed1ab_0
|
||||
- wcwidth=0.2.12=pyhd8ed1ab_0
|
||||
- wheel=0.42.0=pyhd8ed1ab_0
|
||||
- xorg-libxau=1.0.11=hd590300_0
|
||||
- xorg-libxdmcp=1.1.3=h7f98852_0
|
||||
- xxhash=0.8.2=hd590300_0
|
||||
- xz=5.2.6=h166bdaf_0
|
||||
- yaml=0.2.5=h7f98852_2
|
||||
- yarl=1.9.3=py310h2372a71_0
|
||||
- zipp=3.17.0=pyhd8ed1ab_0
|
||||
- zlib=1.2.13=hd590300_5
|
||||
- zstd=1.5.5=hfc55251_0
|
||||
- pip:
|
||||
- git+https://github.com/fyfrey/TinyNeuralNetwork.git
|
||||
157
machine-learning/ann/export/run.py
Normal file
157
machine-learning/ann/export/run.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
from abc import abstractmethod
|
||||
|
||||
import onnx
|
||||
import open_clip
|
||||
import torch
|
||||
from onnx2torch import convert
|
||||
from onnxruntime.tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
|
||||
from tinynn.converter import TFLiteConverter
|
||||
|
||||
|
||||
class ExportBase(torch.nn.Module):
|
||||
input_shape: tuple[int, ...]
|
||||
|
||||
def __init__(self, device: torch.device, name: str):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.name = name
|
||||
self.optimize = 5
|
||||
self.nchw_transpose = False
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
|
||||
pass
|
||||
|
||||
def dummy_input(self) -> torch.FloatTensor:
|
||||
return torch.rand((1, 3, 224, 224), device=self.device)
|
||||
|
||||
|
||||
class ArcFace(ExportBase):
|
||||
input_shape = (1, 3, 112, 112)
|
||||
|
||||
def __init__(self, onnx_model_path: str, device: torch.device):
|
||||
name, _ = os.path.splitext(os.path.basename(onnx_model_path))
|
||||
super().__init__(device, name)
|
||||
onnx_model = onnx.load_model(onnx_model_path)
|
||||
make_input_shape_fixed(onnx_model.graph, onnx_model.graph.input[0].name, self.input_shape)
|
||||
fix_output_shapes(onnx_model)
|
||||
self.model = convert(onnx_model).to(device)
|
||||
if self.device.type == "cuda":
|
||||
self.model = self.model.half()
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
|
||||
embedding: torch.FloatTensor = self.model(
|
||||
input_tensor.half() if self.device.type == "cuda" else input_tensor
|
||||
).float()
|
||||
assert isinstance(embedding, torch.FloatTensor)
|
||||
return embedding
|
||||
|
||||
def dummy_input(self) -> torch.FloatTensor:
|
||||
return torch.rand(self.input_shape, device=self.device)
|
||||
|
||||
|
||||
class RetinaFace(ExportBase):
|
||||
input_shape = (1, 3, 640, 640)
|
||||
|
||||
def __init__(self, onnx_model_path: str, device: torch.device):
|
||||
name, _ = os.path.splitext(os.path.basename(onnx_model_path))
|
||||
super().__init__(device, name)
|
||||
self.optimize = 3
|
||||
self.model = convert(onnx_model_path).eval().to(device)
|
||||
if self.device.type == "cuda":
|
||||
self.model = self.model.half()
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor) -> tuple[torch.FloatTensor]:
|
||||
out: torch.Tensor = self.model(input_tensor.half() if self.device.type == "cuda" else input_tensor)
|
||||
return tuple(o.float() for o in out)
|
||||
|
||||
def dummy_input(self) -> torch.FloatTensor:
|
||||
return torch.rand(self.input_shape, device=self.device)
|
||||
|
||||
|
||||
class ClipVision(ExportBase):
|
||||
input_shape = (1, 3, 224, 224)
|
||||
|
||||
def __init__(self, model_name: str, weights: str, device: torch.device):
|
||||
super().__init__(device, model_name + "__" + weights)
|
||||
self.model = open_clip.create_model(
|
||||
model_name,
|
||||
weights,
|
||||
precision="fp16" if device.type == "cuda" else "fp32",
|
||||
jit=False,
|
||||
require_pretrained=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, input_tensor: torch.Tensor) -> torch.FloatTensor:
|
||||
embedding: torch.Tensor = self.model.encode_image(
|
||||
input_tensor.half() if self.device.type == "cuda" else input_tensor,
|
||||
normalize=True,
|
||||
).float()
|
||||
return embedding
|
||||
|
||||
|
||||
def export(model: ExportBase) -> None:
|
||||
model.eval()
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
dummy_input = model.dummy_input()
|
||||
model(dummy_input)
|
||||
jit = torch.jit.trace(model, dummy_input) # type: ignore[no-untyped-call,attr-defined]
|
||||
tflite_model_path = f"output/{model.name}.tflite"
|
||||
os.makedirs("output", exist_ok=True)
|
||||
|
||||
converter = TFLiteConverter(
|
||||
jit,
|
||||
dummy_input,
|
||||
tflite_model_path,
|
||||
optimize=model.optimize,
|
||||
nchw_transpose=model.nchw_transpose,
|
||||
)
|
||||
# segfaults on ARM, must run on x86_64 / AMD64
|
||||
converter.convert()
|
||||
|
||||
armnn_model_path = f"output/{model.name}.armnn"
|
||||
os.environ["LD_LIBRARY_PATH"] = "armnn"
|
||||
subprocess.run(
|
||||
[
|
||||
"./armnnconverter",
|
||||
"-f",
|
||||
"tflite-binary",
|
||||
"-m",
|
||||
tflite_model_path,
|
||||
"-i",
|
||||
"input_tensor",
|
||||
"-o",
|
||||
"output_tensor",
|
||||
"-p",
|
||||
armnn_model_path,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if platform.machine() not in ("x86_64", "AMD64"):
|
||||
raise RuntimeError(f"Can only run on x86_64 / AMD64, not {platform.machine()}")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
if device.type != "cuda":
|
||||
logging.warning(
|
||||
"No CUDA available, cannot create fp16 model! proceeding to create a fp32 model (use only for testing)"
|
||||
)
|
||||
models = [
|
||||
ClipVision("ViT-B-32", "openai", device),
|
||||
ArcFace("buffalo_l_rec.onnx", device),
|
||||
RetinaFace("buffalo_l_det.onnx", device),
|
||||
]
|
||||
for model in models:
|
||||
export(model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.no_grad():
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue