Mirror of https://github.com/immich-app/immich (synced 2025-11-07 17:27:20 +00:00)
feat: preloading of machine learning models (#7540)
This commit is contained in:
parent 762c4684f8 · commit e8b001f62f

6 changed files with 75 additions and 49 deletions (only the hunks for the machine-learning service's app/main.py appear below)
@@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser
 from app.models.base import InferenceModel
 
-from .config import log, settings
+from .config import PreloadModelData, log, settings
 from .models.cache import ModelCache
 from .schemas import (
     MessageResponse,
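The only change in this hunk is importing PreloadModelData from the service's config module. Its definition is not part of this diff; below is a minimal sketch of what such a settings model plausibly looks like. Only the two field names are grounded in how preload_models uses the object later in this diff; the base class and the example model names are assumptions, not immich code.

    # sketch only: a container naming which models to warm at startup
    from pydantic import BaseModel

    class PreloadModelData(BaseModel):
        clip: str | None = None                # e.g. "ViT-B-32__openai" (illustrative)
        facial_recognition: str | None = None  # e.g. "buffalo_l" (illustrative)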
@@ -27,7 +27,7 @@ from .schemas import (
 
 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
 
-model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
+model_cache = ModelCache(revalidate=settings.model_ttl > 0)
 thread_pool: ThreadPoolExecutor | None = None
 lock = threading.Lock()
 active_requests = 0
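ModelCache is no longer constructed with a cache-wide ttl; the matching half of this change is in the last hunk below, where predict() passes ttl=settings.model_ttl on each get() call instead. A minimal sketch of the pattern only (not immich's ModelCache, which is async and more involved; the class below is an illustrative assumption):

    import time
    from typing import Any, Callable

    class PerCallTtlCache:
        """Cache whose expiry is chosen per lookup rather than at construction."""

        def __init__(self, revalidate: bool = False) -> None:
            self.revalidate = revalidate  # True: every hit pushes the deadline back
            self._store: dict[str, tuple[Any, float | None]] = {}

        def get(self, key: str, factory: Callable[[], Any], ttl: float | None = None) -> Any:
            now = time.monotonic()
            entry = self._store.get(key)
            expired = entry is not None and entry[1] is not None and entry[1] < now
            if entry is None or expired:
                value = factory()  # miss or expired entry: rebuild it
                deadline = now + ttl if ttl is not None else None
            else:
                value = entry[0]
                deadline = now + ttl if (self.revalidate and ttl is not None) else entry[1]
            self._store[key] = (value, deadline)
            return value

With this shape, a startup preload can fetch a model with no ttl while the request path fetches the same key with a finite ttl, and revalidate=True keeps a frequently used model alive by refreshing its deadline on every hit.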
@@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
             log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
         if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
             asyncio.ensure_future(idle_shutdown_task())
+        if settings.preload is not None:
+            await preload_models(settings.preload)
         yield
     finally:
         log.handlers.clear()
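Startup now warms models only when settings.preload is set, so the default behavior (lazy loading on first request) is unchanged. The diff does not show how preload is populated; a plausible sketch using pydantic-settings with a nested environment delimiter follows. The variable names use the MACHINE_LEARNING_ prefix and the double-underscore nesting convention, but treat both as assumptions rather than a statement of the project's actual configuration keys.

    # sketch only: nested settings read from environment variables such as
    #   MACHINE_LEARNING_PRELOAD__CLIP=ViT-B-32__openai
    #   MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION=buffalo_l
    from pydantic import BaseModel
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class PreloadModelData(BaseModel):  # as sketched after the first hunk
        clip: str | None = None
        facial_recognition: str | None = None

    class Settings(BaseSettings):
        model_config = SettingsConfigDict(env_prefix="MACHINE_LEARNING_", env_nested_delimiter="__")
        preload: PreloadModelData | None = None

    # Setting MACHINE_LEARNING_PRELOAD__CLIP would then populate Settings().preload.clip.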
@@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
         gc.collect()
 
 
+async def preload_models(preload_models: PreloadModelData) -> None:
+    log.info(f"Preloading models: {preload_models}")
+    if preload_models.clip is not None:
+        await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
+    if preload_models.facial_recognition is not None:
+        await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
+
+
 def update_state() -> Iterator[None]:
     global active_requests, last_called
     active_requests += 1
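preload_models only acts on the model names that are configured, and it reuses the existing model_cache.get and load helpers (load itself is not shown in this diff), so startup warming and request-path loading share one code path. The point of the feature is to move the one-time model-load cost out of the first request; here is a self-contained toy illustration of that effect (everything in it is hypothetical, none of it is immich code):

    import asyncio
    import time

    class FakeModel:
        def __init__(self) -> None:
            self.loaded = False

        def load(self) -> None:
            time.sleep(1.0)  # stand-in for reading weights / building an inference session
            self.loaded = True

        def predict(self, text: str) -> str:
            if not self.loaded:
                self.load()  # lazy path: the first request pays the load cost
            return f"embedding({text})"

    async def main() -> None:
        model = FakeModel()
        await asyncio.to_thread(model.load)  # "preload" during startup, off the event loop
        start = time.perf_counter()
        model.predict("hello")
        print(f"first request: {time.perf_counter() - start:.3f}s")  # near-instant: model is warm

    asyncio.run(main())

Commenting out the asyncio.to_thread(model.load) line makes the first predict() call take the full load time instead, which is the latency this PR removes for deployments that opt in.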
@@ -103,7 +113,7 @@ async def predict(
    except orjson.JSONDecodeError:
        raise HTTPException(400, f"Invalid options JSON: {options}")
 
-    model = await load(await model_cache.get(model_name, model_type, **kwargs))
+    model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
     model.configure(**kwargs)
     outputs = await run(model.predict, inputs)
     return ORJSONResponse(outputs)
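The request path now supplies ttl=settings.model_ttl on every cache lookup, completing the constructor change from the second hunk. Using the PerCallTtlCache sketch from earlier, the intended interplay looks roughly like this (illustrative only; the real async ModelCache is not shown in this diff, and the model name is an example):

    cache = PerCallTtlCache(revalidate=True)

    # startup preload: no ttl, so the entry starts out with no deadline
    cache.get("ViT-B-32__openai", factory=lambda: "loaded-clip-model")

    # request path: each hit passes a ttl, so an idle model can eventually expire
    MODEL_TTL = 300.0
    cache.get("ViT-B-32__openai", factory=lambda: "loaded-clip-model", ttl=MODEL_TTL)

In this sketch, frequently requested models keep having their deadline refreshed (revalidate=True), idle ones age out after MODEL_TTL seconds, and preloaded entries only acquire a deadline once the request path starts touching them.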