feat: preloading of machine learning models (#7540)

This commit is contained in:
DawidPietrykowski 2024-03-04 01:48:56 +01:00 committed by GitHub
parent 762c4684f8
commit e8b001f62f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 75 additions and 49 deletions

View file

@ -2,7 +2,7 @@ from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import BasePlugin, TimingPlugin
from aiocache.plugins import TimingPlugin
from app.models import from_model_type
@ -15,28 +15,25 @@ class ModelCache:
def __init__(
self,
ttl: float | None = None,
revalidate: bool = False,
timeout: int | None = None,
profiling: bool = False,
) -> None:
"""
Args:
ttl: Unloads model after this duration. Disabled if None. Defaults to None.
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
"""
self.ttl = ttl
plugins = []
if revalidate:
plugins.append(RevalidationPlugin())
if profiling:
plugins.append(TimingPlugin())
self.cache = SimpleMemoryCache(ttl=ttl, timeout=timeout, plugins=plugins, namespace=None)
self.revalidate_enable = revalidate
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
async def get(self, model_name: str, model_type: ModelType, **model_kwargs: Any) -> InferenceModel:
"""
@ -49,11 +46,14 @@ class ModelCache:
"""
key = f"{model_name}{model_type.value}{model_kwargs.get('mode', '')}"
async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_type, model_name, **model_kwargs)
await lock.cas(model, ttl=self.ttl)
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
elif self.revalidate_enable:
await self.revalidate(key, model_kwargs.get("ttl", None))
return model
async def get_profiling(self) -> dict[str, float] | None:
@ -62,21 +62,6 @@ class ModelCache:
return self.cache.profiling
class RevalidationPlugin(BasePlugin): # type: ignore[misc]
"""Revalidates cache item's TTL after cache hit."""
async def post_get(
self,
client: SimpleMemoryCache,
key: str,
ret: Any | None = None,
namespace: str | None = None,
**kwargs: Any,
) -> None:
if ret is None:
return
if namespace is not None:
key = client.build_key(key, namespace)
if key in client._handlers:
await client.expire(key, client.ttl)
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)