feat: preloading of machine learning models (#7540)

This commit is contained in:
DawidPietrykowski 2024-03-04 01:48:56 +01:00 committed by GitHub
parent 762c4684f8
commit e8b001f62f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 75 additions and 49 deletions

View file

@ -17,7 +17,7 @@ from starlette.formparsers import MultiPartParser
from app.models.base import InferenceModel
from .config import log, settings
from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
MessageResponse,
@ -27,7 +27,7 @@ from .schemas import (
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
model_cache = ModelCache(ttl=settings.model_ttl, revalidate=settings.model_ttl > 0)
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
@ -51,6 +51,8 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
if settings.preload is not None:
await preload_models(settings.preload)
yield
finally:
log.handlers.clear()
@ -61,6 +63,14 @@ async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
gc.collect()
async def preload_models(preload_models: PreloadModelData) -> None:
log.info(f"Preloading models: {preload_models}")
if preload_models.clip is not None:
await load(await model_cache.get(preload_models.clip, ModelType.CLIP))
if preload_models.facial_recognition is not None:
await load(await model_cache.get(preload_models.facial_recognition, ModelType.FACIAL_RECOGNITION))
def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
@ -103,7 +113,7 @@ async def predict(
except orjson.JSONDecodeError:
raise HTTPException(400, f"Invalid options JSON: {options}")
model = await load(await model_cache.get(model_name, model_type, **kwargs))
model = await load(await model_cache.get(model_name, model_type, ttl=settings.model_ttl, **kwargs))
model.configure(**kwargs)
outputs = await run(model.predict, inputs)
return ORJSONResponse(outputs)