mirror of
https://github.com/immich-app/immich
synced 2025-10-17 18:19:27 +00:00
Fix Smart Search when using OpenVINO (#7389)
* Fix external_path loading in OpenVINO EP * Fix ruff lint * Wrap block in try finally * remove static input shape code * add unit test * remove unused imports * remove repeat line * linting * formatting --------- Co-authored-by: mertalev <101130780+mertalev@users.noreply.github.com>
This commit is contained in:
parent
912d723281
commit
2a75f884d9
6 changed files with 41 additions and 64 deletions
|
|
@ -1,18 +1,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any
|
||||
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||
|
||||
import ann.ann
|
||||
from app.models.constants import STATIC_INPUT_PROVIDERS, SUPPORTED_PROVIDERS
|
||||
from app.models.constants import SUPPORTED_PROVIDERS
|
||||
|
||||
from ..config import get_cache_dir, get_hf_model_name, log, settings
|
||||
from ..schemas import ModelRuntime, ModelType
|
||||
|
|
@ -113,63 +111,25 @@ class InferenceModel(ABC):
|
|||
)
|
||||
model_path = onnx_path
|
||||
|
||||
if any(provider in STATIC_INPUT_PROVIDERS for provider in self.providers):
|
||||
static_path = model_path.parent / "static_1" / "model.onnx"
|
||||
static_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if not static_path.is_file():
|
||||
self._convert_to_static(model_path, static_path)
|
||||
model_path = static_path
|
||||
|
||||
match model_path.suffix:
|
||||
case ".armnn":
|
||||
session = AnnSession(model_path)
|
||||
case ".onnx":
|
||||
session = ort.InferenceSession(
|
||||
model_path.as_posix(),
|
||||
sess_options=self.sess_options,
|
||||
providers=self.providers,
|
||||
provider_options=self.provider_options,
|
||||
)
|
||||
cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(model_path.parent)
|
||||
session = ort.InferenceSession(
|
||||
model_path.as_posix(),
|
||||
sess_options=self.sess_options,
|
||||
providers=self.providers,
|
||||
provider_options=self.provider_options,
|
||||
)
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||
return session
|
||||
|
||||
def _convert_to_static(self, source_path: Path, target_path: Path) -> None:
|
||||
inferred = infer_shapes(onnx.load(source_path))
|
||||
inputs = self._get_static_dims(inferred.graph.input)
|
||||
outputs = self._get_static_dims(inferred.graph.output)
|
||||
|
||||
# check_model gets called in update_inputs_outputs_dims and doesn't work for large models
|
||||
check_model = onnx.checker.check_model
|
||||
try:
|
||||
|
||||
def check_model_stub(*args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
onnx.checker.check_model = check_model_stub
|
||||
updated_model = update_inputs_outputs_dims(inferred, inputs, outputs)
|
||||
finally:
|
||||
onnx.checker.check_model = check_model
|
||||
|
||||
onnx.save(
|
||||
updated_model,
|
||||
target_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=False,
|
||||
size_threshold=1048576,
|
||||
)
|
||||
|
||||
def _get_static_dims(self, graph_io: Any, dim_size: int = 1) -> dict[str, list[int]]:
|
||||
return {
|
||||
field.name: [
|
||||
d.dim_value if d.HasField("dim_value") else dim_size
|
||||
for shape in field.type.ListFields()
|
||||
if (dim := shape[1].shape.dim)
|
||||
for d in dim
|
||||
]
|
||||
for field in graph_io
|
||||
}
|
||||
|
||||
@property
|
||||
def model_type(self) -> ModelType:
|
||||
return self._model_type
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue