docs: model benchmarks (#17036)

* model benchmarks

* minor fixes

* formatting

* docs build

* maybe fix reference

* clarify optimal

* use emojis

* wording

* wording

* clarify optimal wording

* bolding

* more detailed instructions

* clarify edge case fix

* early exit in dim loop
This commit is contained in:
Mert 2025-03-24 12:02:33 -04:00 committed by GitHub
parent ad151130f9
commit 4bfef2460a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 2209 additions and 255 deletions

View file

@ -0,0 +1,165 @@
import json
import resource
from pathlib import Path
import typer
from tenacity import retry, stop_after_attempt, wait_fixed
from typing_extensions import Annotated
from .exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource, ModelTask
from .exporters.onnx import export as onnx_export
from .exporters.rknn import export as rknn_export
app = typer.Typer(pretty_exceptions_show_locals=False)
def generate_readme(model_name: str, model_source: ModelSource) -> str:
(name, link, type) = SOURCE_TO_METADATA[model_source]
match model_source:
case ModelSource.MCLIP:
tags = ["immich", "clip", "multilingual"]
case ModelSource.OPENCLIP:
tags = ["immich", "clip"]
lowered = model_name.lower()
if "xlm" in lowered or "nllb" in lowered:
tags.append("multilingual")
case ModelSource.INSIGHTFACE:
tags = ["immich", "facial-recognition"]
case _:
raise ValueError(f"Unsupported model source {model_source}")
return f"""---
tags:
{" - " + "\n - ".join(tags)}
---
# Model Description
This repo contains ONNX exports for the associated {type} model by {name}. See the [{name}]({link}) repo for more info.
This repo is specifically intended for use with [Immich](https://immich.app/), a self-hosted photo library.
"""
def clean_name(model_name: str) -> str:
hf_model_name = model_name.split("/")[-1]
hf_model_name = hf_model_name.replace("xlm-roberta-large", "XLM-Roberta-Large")
hf_model_name = hf_model_name.replace("xlm-roberta-base", "XLM-Roberta-Base")
return hf_model_name
@app.command()
def export(model_name: str, model_source: ModelSource, output_dir: Path = Path("models"), cache: bool = True) -> None:
hf_model_name = clean_name(model_name)
output_dir = output_dir / hf_model_name
match model_source:
case ModelSource.MCLIP | ModelSource.OPENCLIP:
output_dir.mkdir(parents=True, exist_ok=True)
onnx_export(model_name, model_source, output_dir, cache=cache)
case ModelSource.INSIGHTFACE:
from huggingface_hub import snapshot_download
# TODO: start from insightface dump instead of downloading from HF
snapshot_download(f"immich-app/{hf_model_name}", local_dir=output_dir)
case _:
raise ValueError(f"Unsupported model source {model_source}")
try:
rknn_export(output_dir, cache=cache)
except Exception as e:
print(f"Failed to export model {model_name} to rknn: {e}")
(output_dir / "rknpu").unlink(missing_ok=True)
readme_path = output_dir / "README.md"
if not (cache or readme_path.exists()):
with open(readme_path, "w") as f:
f.write(generate_readme(model_name, model_source))
@app.command()
def profile(model_dir: Path, model_task: ModelTask, output_path: Path) -> None:
from timeit import timeit
import numpy as np
import onnxruntime as ort
np.random.seed(0)
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False
providers = ["CPUExecutionProvider"]
provider_options = [{"arena_extend_strategy": "kSameAsRequested"}]
match model_task:
case ModelTask.SEARCH:
textual = ort.InferenceSession(
model_dir / "textual" / "model.onnx",
sess_options=sess_options,
providers=providers,
provider_options=provider_options,
)
tokens = {node.name: np.random.rand(*node.shape).astype(np.int32) for node in textual.get_inputs()}
visual = ort.InferenceSession(
model_dir / "visual" / "model.onnx",
sess_options=sess_options,
providers=providers,
provider_options=provider_options,
)
image = {node.name: np.random.rand(*node.shape).astype(np.float32) for node in visual.get_inputs()}
def predict() -> None:
textual.run(None, tokens)
visual.run(None, image)
case ModelTask.FACIAL_RECOGNITION:
detection = ort.InferenceSession(
model_dir / "detection" / "model.onnx",
sess_options=sess_options,
providers=providers,
provider_options=provider_options,
)
image = {node.name: np.random.rand(1, 3, 640, 640).astype(np.float32) for node in detection.get_inputs()}
recognition = ort.InferenceSession(
model_dir / "recognition" / "model.onnx",
sess_options=sess_options,
providers=providers,
provider_options=provider_options,
)
face = {node.name: np.random.rand(1, 3, 112, 112).astype(np.float32) for node in recognition.get_inputs()}
def predict() -> None:
detection.run(None, image)
recognition.run(None, face)
case _:
raise ValueError(f"Unsupported model task {model_task}")
predict()
ms = timeit(predict, number=100)
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
json.dump({"pretrained_model": model_dir.name, "peak_rss": rss, "exec_time_ms": ms}, output_path.open("w"))
print(f"Model {model_dir.name} took {ms:.2f}ms per iteration using {rss / 1024:.2f}MiB of memory")
@app.command()
def upload(
model_dir: Path,
hf_organization: str = "immich-app",
hf_auth_token: Annotated[str | None, typer.Option(envvar="HF_AUTH_TOKEN")] = None,
) -> None:
from huggingface_hub import create_repo, upload_folder
repo_id = f"{hf_organization}/{model_dir.name}"
@retry(stop=stop_after_attempt(5), wait=wait_fixed(5))
def upload_model() -> None:
create_repo(repo_id, exist_ok=True, token=hf_auth_token)
upload_folder(
repo_id=repo_id,
folder_path=model_dir,
# remote repo files to be deleted before uploading
# deletion is in the same commit as the upload, so it's atomic
delete_patterns=DELETE_PATTERNS,
token=hf_auth_token,
)
upload_model()

View file

@ -0,0 +1,3 @@
from immich_model_exporter import app
app()

View file

@ -1,98 +0,0 @@
from pathlib import Path
import typer
from tenacity import retry, stop_after_attempt, wait_fixed
from typing_extensions import Annotated
from .exporters.constants import DELETE_PATTERNS, SOURCE_TO_METADATA, ModelSource
from .exporters.onnx import export as onnx_export
from .exporters.rknn import export as rknn_export
app = typer.Typer(pretty_exceptions_show_locals=False)
def generate_readme(model_name: str, model_source: ModelSource) -> str:
(name, link, type) = SOURCE_TO_METADATA[model_source]
match model_source:
case ModelSource.MCLIP:
tags = ["immich", "clip", "multilingual"]
case ModelSource.OPENCLIP:
tags = ["immich", "clip"]
lowered = model_name.lower()
if "xlm" in lowered or "nllb" in lowered:
tags.append("multilingual")
case ModelSource.INSIGHTFACE:
tags = ["immich", "facial-recognition"]
case _:
raise ValueError(f"Unsupported model source {model_source}")
return f"""---
tags:
{" - " + "\n - ".join(tags)}
---
# Model Description
This repo contains ONNX exports for the associated {type} model by {name}. See the [{name}]({link}) repo for more info.
This repo is specifically intended for use with [Immich](https://immich.app/), a self-hosted photo library.
"""
@app.command()
def main(
model_name: str,
model_source: ModelSource,
output_dir: Path = Path("./models"),
no_cache: bool = False,
hf_organization: str = "immich-app",
hf_auth_token: Annotated[str | None, typer.Option(envvar="HF_AUTH_TOKEN")] = None,
) -> None:
hf_model_name = model_name.split("/")[-1]
hf_model_name = hf_model_name.replace("xlm-roberta-large", "XLM-Roberta-Large")
hf_model_name = hf_model_name.replace("xlm-roberta-base", "XLM-Roberta-Base")
output_dir = output_dir / hf_model_name
match model_source:
case ModelSource.MCLIP | ModelSource.OPENCLIP:
output_dir.mkdir(parents=True, exist_ok=True)
onnx_export(model_name, model_source, output_dir, no_cache=no_cache)
case ModelSource.INSIGHTFACE:
from huggingface_hub import snapshot_download
# TODO: start from insightface dump instead of downloading from HF
snapshot_download(f"immich-app/{hf_model_name}", local_dir=output_dir)
case _:
raise ValueError(f"Unsupported model source {model_source}")
try:
rknn_export(output_dir, no_cache=no_cache)
except Exception as e:
print(f"Failed to export model {model_name} to rknn: {e}")
(output_dir / "rknpu").unlink(missing_ok=True)
readme_path = output_dir / "README.md"
if no_cache or not readme_path.exists():
with open(readme_path, "w") as f:
f.write(generate_readme(model_name, model_source))
if hf_auth_token is not None:
from huggingface_hub import create_repo, upload_folder
repo_id = f"{hf_organization}/{hf_model_name}"
@retry(stop=stop_after_attempt(5), wait=wait_fixed(5))
def upload_model() -> None:
create_repo(repo_id, exist_ok=True, token=hf_auth_token)
upload_folder(
repo_id=repo_id,
folder_path=output_dir,
# remote repo files to be deleted before uploading
# deletion is in the same commit as the upload, so it's atomic
delete_patterns=DELETE_PATTERNS,
token=hf_auth_token,
)
upload_model()
if __name__ == "__main__":
typer.run(main)

View file

@ -8,6 +8,11 @@ class ModelSource(StrEnum):
OPENCLIP = "openclip"
class ModelTask(StrEnum):
FACIAL_RECOGNITION = "facial-recognition"
SEARCH = "clip"
class SourceMetadata(NamedTuple):
name: str
link: str
@ -22,6 +27,13 @@ SOURCE_TO_METADATA = {
),
}
SOURCE_TO_TASK = {
ModelSource.MCLIP: ModelTask.SEARCH,
ModelSource.OPENCLIP: ModelTask.SEARCH,
ModelSource.INSIGHTFACE: ModelTask.FACIAL_RECOGNITION,
}
RKNN_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]

View file

@ -5,16 +5,16 @@ from .models import mclip, openclip
def export(
model_name: str, model_source: ModelSource, output_dir: Path, opset_version: int = 19, no_cache: bool = False
model_name: str, model_source: ModelSource, output_dir: Path, opset_version: int = 19, cache: bool = True
) -> None:
visual_dir = output_dir / "visual"
textual_dir = output_dir / "textual"
match model_source:
case ModelSource.MCLIP:
mclip.to_onnx(model_name, opset_version, visual_dir, textual_dir, no_cache=no_cache)
mclip.to_onnx(model_name, opset_version, visual_dir, textual_dir, cache=cache)
case ModelSource.OPENCLIP:
name, _, pretrained = model_name.partition("__")
config = openclip.OpenCLIPModelConfig(name, pretrained)
openclip.to_onnx(config, opset_version, visual_dir, textual_dir, no_cache=no_cache)
openclip.to_onnx(config, opset_version, visual_dir, textual_dir, cache=cache)
case _:
raise ValueError(f"Unsupported model source {model_source}")

View file

@ -19,10 +19,10 @@ def to_onnx(
opset_version: int,
output_dir_visual: Path | str,
output_dir_textual: Path | str,
no_cache: bool = False,
cache: bool = True,
) -> tuple[Path, Path]:
textual_path = get_model_path(output_dir_textual)
if no_cache or not textual_path.exists():
if not cache or not textual_path.exists():
import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
from transformers import AutoTokenizer
@ -39,9 +39,7 @@ def to_onnx(
_export_text_encoder(model, textual_path, opset_version)
else:
print(f"Model {textual_path} already exists, skipping")
visual_path, _ = openclip_to_onnx(
_MCLIP_TO_OPENCLIP[model_name], opset_version, output_dir_visual, no_cache=no_cache
)
visual_path, _ = openclip_to_onnx(_MCLIP_TO_OPENCLIP[model_name], opset_version, output_dir_visual, cache=cache)
assert visual_path is not None, "Visual model export failed"
return visual_path, textual_path

View file

@ -37,7 +37,7 @@ def to_onnx(
opset_version: int,
output_dir_visual: Path | str | None = None,
output_dir_textual: Path | str | None = None,
no_cache: bool = False,
cache: bool = True,
) -> tuple[Path | None, Path | None]:
visual_path = None
textual_path = None
@ -49,9 +49,7 @@ def to_onnx(
output_dir_textual = Path(output_dir_textual)
textual_path = get_model_path(output_dir_textual)
if not no_cache and (
(textual_path is None or textual_path.exists()) and (visual_path is None or visual_path.exists())
):
if cache and ((textual_path is None or textual_path.exists()) and (visual_path is None or visual_path.exists())):
print(f"Models {textual_path} and {visual_path} already exist, skipping")
return visual_path, textual_path
@ -75,7 +73,7 @@ def to_onnx(
param.requires_grad_(False)
if visual_path is not None and output_dir_visual is not None:
if no_cache or not visual_path.exists():
if not cache or not visual_path.exists():
save_config(
open_clip.get_model_preprocess_cfg(model),
output_dir_visual / "preprocess_cfg.json",
@ -86,7 +84,7 @@ def to_onnx(
print(f"Model {visual_path} already exists, skipping")
if textual_path is not None and output_dir_textual is not None:
if no_cache or not textual_path.exists():
if not cache or not textual_path.exists():
tokenizer_name = text_vision_cfg["text_cfg"].get("hf_tokenizer_name", "openai/clip-vit-base-patch32")
AutoTokenizer.from_pretrained(tokenizer_name).save_pretrained(output_dir_textual)
_export_text_encoder(model, model_cfg, textual_path, opset_version)

View file

@ -9,13 +9,13 @@ def _export_platform(
inputs: list[str] | None = None,
input_size_list: list[list[int]] | None = None,
fuse_matmul_softmax_matmul_to_sdpa: bool = True,
no_cache: bool = False,
cache: bool = True,
) -> None:
from rknn.api import RKNN
input_path = model_dir / "model.onnx"
output_path = model_dir / "rknpu" / target_platform / "model.rknn"
if not no_cache and output_path.exists():
if cache and output_path.exists():
print(f"Model {input_path} already exists at {output_path}, skipping")
return
@ -49,7 +49,7 @@ def _export_platforms(
model_dir: Path,
inputs: list[str] | None = None,
input_size_list: list[list[int]] | None = None,
no_cache: bool = False,
cache: bool = True,
) -> None:
fuse_matmul_softmax_matmul_to_sdpa = True
for soc in RKNN_SOCS:
@ -60,7 +60,7 @@ def _export_platforms(
inputs=inputs,
input_size_list=input_size_list,
fuse_matmul_softmax_matmul_to_sdpa=fuse_matmul_softmax_matmul_to_sdpa,
no_cache=no_cache,
cache=cache,
)
except Exception as e:
print(f"Failed to export model for {soc}: {e}")
@ -73,24 +73,24 @@ def _export_platforms(
inputs=inputs,
input_size_list=input_size_list,
fuse_matmul_softmax_matmul_to_sdpa=fuse_matmul_softmax_matmul_to_sdpa,
no_cache=no_cache,
cache=cache,
)
def export(model_dir: Path, no_cache: bool = False) -> None:
def export(model_dir: Path, cache: bool = True) -> None:
textual = model_dir / "textual"
visual = model_dir / "visual"
detection = model_dir / "detection"
recognition = model_dir / "recognition"
if textual.is_dir():
_export_platforms(textual, no_cache=no_cache)
_export_platforms(textual, cache=cache)
if visual.is_dir():
_export_platforms(visual, no_cache=no_cache)
_export_platforms(visual, cache=cache)
if detection.is_dir():
_export_platforms(detection, inputs=["input.1"], input_size_list=[[1, 3, 640, 640]], no_cache=no_cache)
_export_platforms(detection, inputs=["input.1"], input_size_list=[[1, 3, 640, 640]], cache=cache)
if recognition.is_dir():
_export_platforms(recognition, inputs=["input.1"], input_size_list=[[1, 3, 112, 112]], no_cache=no_cache)
_export_platforms(recognition, inputs=["input.1"], input_size_list=[[1, 3, 112, 112]], cache=cache)

View file

@ -0,0 +1,22 @@
import json
from pathlib import Path
models_dir = Path("models")
model_to_embed_dim = {}
for model_dir in models_dir.iterdir():
if not model_dir.is_dir():
continue
config_path = model_dir / "config.json"
if not config_path.exists():
print(f"Skipping {model_dir.name} as it does not have a config.json")
continue
with open(config_path) as f:
config = json.load(f)
embed_dim = config.get("embed_dim")
if embed_dim is None:
print(f"Skipping {model_dir.name} as it does not have an embed_dim")
continue
print(f"{model_dir.name}: {embed_dim}")
model_to_embed_dim[model_dir.name] = {"dimSize": embed_dim}
print(json.dumps(model_to_embed_dim))

View file

@ -0,0 +1,121 @@
import polars as pl
def collapsed_table(language: str, df: pl.DataFrame) -> str:
with pl.Config(
tbl_formatting="ASCII_MARKDOWN",
tbl_hide_column_data_types=True,
tbl_hide_dataframe_shape=True,
fmt_str_lengths=100,
tbl_rows=1000,
tbl_width_chars=1000,
):
return f"<details>\n<summary>{language}</summary>\n{str(df)}\n</details>"
languages = {
"en": "English",
"ar": "Arabic",
"bn": "Bengali",
"zh": "Chinese (Simplified)",
"hr": "Croatian",
"quz": "Cusco Quechua",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"fil": "Filipino",
"fi": "Finnish",
"fr": "French",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"ko": "Korean",
"mi": "Maori",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"te": "Telugu",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"vi": "Vietnamese",
}
profile_df = pl.scan_ndjson("profiling/*.json").select("pretrained_model", "peak_rss", "exec_time_ms")
eval_df = pl.scan_ndjson("results/*.json").select("model", "pretrained", "language", "metrics")
eval_df = eval_df.with_columns(
model=pl.col("model")
.str.replace("xlm-roberta-base", "XLM-Roberta-Base")
.str.replace("xlm-roberta-large", "XLM-Roberta-Large")
)
eval_df = eval_df.with_columns(pretrained_model=pl.concat_str(pl.col("model"), pl.col("pretrained"), separator="__"))
eval_df = eval_df.drop("model", "pretrained")
eval_df = eval_df.join(profile_df, on="pretrained_model")
eval_df = eval_df.with_columns(
recall=(
pl.col("metrics").struct.field("image_retrieval_recall@1")
+ pl.col("metrics").struct.field("image_retrieval_recall@5")
+ pl.col("metrics").struct.field("image_retrieval_recall@10")
)
* (100 / 3)
)
pareto_front = eval_df.join_where(
eval_df.select("language", "peak_rss", "exec_time_ms", "recall").rename(
{
"language": "language_other",
"peak_rss": "peak_rss_other",
"exec_time_ms": "exec_time_ms_other",
"recall": "recall_other",
}
),
(pl.col("language") == pl.col("language_other"))
& (pl.col("peak_rss_other") <= pl.col("peak_rss"))
& (pl.col("exec_time_ms_other") <= pl.col("exec_time_ms"))
& (pl.col("recall_other") >= pl.col("recall"))
& (
(pl.col("peak_rss_other") < pl.col("peak_rss"))
| (pl.col("exec_time_ms_other") < pl.col("exec_time_ms"))
| (pl.col("recall_other") > pl.col("recall"))
),
)
eval_df = eval_df.join(pareto_front, on=["pretrained_model", "language"], how="left")
eval_df = eval_df.with_columns(is_pareto=pl.col("recall_other").is_null())
eval_df = (
eval_df.drop("peak_rss_other", "exec_time_ms_other", "recall_other", "language_other")
.unique(subset=["pretrained_model", "language"])
.collect()
)
eval_df.write_parquet("model_info.parquet")
eval_df = eval_df.drop("metrics")
eval_df = eval_df.filter(pl.col("recall") >= 20)
eval_df = eval_df.sort("recall", descending=True)
eval_df = eval_df.select(
pl.col("pretrained_model").alias("Model"),
(pl.col("peak_rss") / 1024).round().cast(pl.UInt32).alias("Memory (MiB)"),
pl.col("exec_time_ms").round(2).alias("Execution Time (ms)"),
pl.col("language").alias("Language"),
pl.col("recall").round(2).alias("Recall (%)"),
pl.when(pl.col("is_pareto")).then(pl.lit("")).otherwise(pl.lit("")).alias("Pareto Optimal"),
)
for language in languages:
lang_df = eval_df.filter(pl.col("Language") == language).drop("Language")
if lang_df.shape[0] == 0:
continue
print(collapsed_table(languages[language], lang_df))

View file

@ -1,7 +1,11 @@
import subprocess
from pathlib import Path
from exporters.constants import ModelSource
from immich_model_exporter import clean_name
from immich_model_exporter.exporters.constants import SOURCE_TO_TASK
mclip = [
"M-CLIP/LABSE-Vit-L-14",
"M-CLIP/XLM-Roberta-Large-Vit-B-16Plus",
@ -74,10 +78,28 @@ insightface = [
def export_models(models: list[str], source: ModelSource) -> None:
profiling_dir = Path("profiling")
profiling_dir.mkdir(exist_ok=True)
for model in models:
try:
print(f"Exporting model {model}")
subprocess.check_call(["python", "-m", "immich_model_exporter.export", model, source])
model_dir = f"models/{clean_name(model)}"
task = SOURCE_TO_TASK[source]
print(f"Processing model {model}")
subprocess.check_call(["python", "-m", "immich_model_exporter", "export", model, source])
subprocess.check_call(
[
"python",
"-m",
"immich_model_exporter",
"profile",
model_dir,
task,
"--output_path",
profiling_dir / f"{model}.json",
]
)
subprocess.check_call(["python", "-m", "immich_model_exporter", "upload", model_dir])
except Exception as e:
print(f"Failed to export model {model}: {e}")
@ -86,3 +108,64 @@ if __name__ == "__main__":
export_models(mclip, ModelSource.MCLIP)
export_models(openclip, ModelSource.OPENCLIP)
export_models(insightface, ModelSource.INSIGHTFACE)
Path("results").mkdir(exist_ok=True)
subprocess.check_call(
[
"python",
"clip_benchmark",
"eval",
"--pretrained_model",
*[name.replace("__", ",") for name in openclip],
"--task",
"zeroshot_retrieval",
"--dataset",
"crossmodal3600",
"--batch_size",
"64",
"--language",
"ar",
"bn",
"cs",
"da",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fil",
"fr",
"he",
"hi",
"hr",
"hu",
"id",
"it",
"ja",
"ko",
"mi",
"nl",
"no",
"pl",
"pt",
"quz",
"ro",
"ru",
"sv",
"sw",
"te",
"th",
"tr",
"uk",
"vi",
"zh",
"--recall_k",
"1",
"5",
"10",
"--no_amp",
"--output",
"results/{dataset}_{language}_{model}_{pretrained}.json",
]
)