mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
feat: preloading of machine learning models (#7540)
This commit is contained in:
parent
762c4684f8
commit
e8b001f62f
6 changed files with 75 additions and 49 deletions
|
|
@ -2,7 +2,7 @@ from typing import Any
|
|||
|
||||
from aiocache.backends.memory import SimpleMemoryCache
|
||||
from aiocache.lock import OptimisticLock
|
||||
from aiocache.plugins import BasePlugin, TimingPlugin
|
||||
from aiocache.plugins import TimingPlugin
|
||||
|
||||
from app.models import from_model_type
|
||||
|
||||
|
|
@ -15,28 +15,25 @@ class ModelCache:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
ttl: float | None = None,
|
||||
revalidate: bool = False,
|
||||
timeout: int | None = None,
|
||||
profiling: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
|
||||
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
|
||||
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
|
||||
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
|
||||
"""
|
||||
|
||||
self.ttl = ttl
|
||||
plugins = []
|
||||
|
||||
if revalidate:
|
||||
plugins.append(RevalidationPlugin())
|
||||
if profiling:
|
||||
plugins.append(TimingPlugin())
|
||||
|
||||
self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
|
||||
self.revalidate_enable = revalidate
|
||||
|
||||
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
|
||||
|
||||
async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
|
||||
"""
|
||||
|
|
@ -49,11 +46,14 @@ class ModelCache:
|
|||
"""
|
||||
|
||||
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
|
||||
|
||||
async with OptimisticLock(self.cache, key) as lock:
|
||||
model: InferenceModel | None = await self.cache.get(key)
|
||||
if model is None:
|
||||
model = from_model_type(model_type, model_name, **model_kwargs)
|
||||
await lock.cas(model, ttl=self.ttl)
|
||||
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
|
||||
elif self.revalidate_enable:
|
||||
await self.revalidate(key, model_kwargs.get("ttl", None))
|
||||
return model
|
||||
|
||||
async def get_profiling(self) -> dict[str, float] | None:
|
||||
|
|
@ -62,21 +62,6 @@ class ModelCache:
|
|||
|
||||
return self.cache.profiling
|
||||
|
||||
|
||||
class RevalidationPlugin(BasePlugin): # type: ignore[misc]
|
||||
"""Revalidates cache item's TTL after cache hit."""
|
||||
|
||||
async def post_get(
|
||||
self,
|
||||
client: SimpleMemoryCache,
|
||||
key: str,
|
||||
ret: Any | None = None,
|
||||
namespace: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if ret is None:
|
||||
return
|
||||
if namespace is not None:
|
||||
key = client.build_key(key, namespace)
|
||||
if key in client._handlers:
|
||||
await client.expire(key, client.ttl)
|
||||
async def revalidate(self, key: str, ttl: int | None) -> None:
|
||||
if ttl is not None and key in self.cache._handlers:
|
||||
await self.cache.expire(key, ttl)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue