mirror of
https://github.com/immich-app/immich
synced 2025-10-17 18:19:27 +00:00
fix(ml): load models in separate threads (#4034)
* load models in thread * set clip mode logs to debug level * updated tests * made fixtures slightly less ugly * moved responses to json file * formatting
This commit is contained in:
parent
f1db257628
commit
258b98c262
9 changed files with 1683 additions and 114 deletions
|
|
@ -1,10 +1,13 @@
|
|||
import asyncio
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import orjson
|
||||
from fastapi import FastAPI, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile # type: ignore
|
||||
from starlette.formparsers import MultiPartParser
|
||||
|
||||
from app.models.base import InferenceModel
|
||||
|
|
@ -31,6 +34,7 @@ def init_state() -> None:
|
|||
)
|
||||
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
|
||||
app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None
|
||||
app.state.locks = {model_type: threading.Lock() for model_type in ModelType}
|
||||
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
|
||||
|
||||
|
||||
|
|
@ -63,14 +67,49 @@ async def predict(
|
|||
inputs = text
|
||||
else:
|
||||
raise HTTPException(400, "Either image or text must be provided")
|
||||
try:
|
||||
kwargs = orjson.loads(options)
|
||||
except orjson.JSONDecodeError:
|
||||
raise HTTPException(400, f"Invalid options JSON: {options}")
|
||||
|
||||
model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
|
||||
model = await load(await app.state.model_cache.get(model_name, model_type, **kwargs))
|
||||
model.configure(**kwargs)
|
||||
outputs = await run(model, inputs)
|
||||
return ORJSONResponse(outputs)
|
||||
|
||||
|
||||
async def run(model: InferenceModel, inputs: Any) -> Any:
|
||||
if app.state.thread_pool is not None:
|
||||
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
|
||||
else:
|
||||
if app.state.thread_pool is None:
|
||||
return model.predict(inputs)
|
||||
|
||||
return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
|
||||
|
||||
|
||||
async def load(model: InferenceModel) -> InferenceModel:
|
||||
if model.loaded:
|
||||
return model
|
||||
|
||||
def _load() -> None:
|
||||
with app.state.locks[model.model_type]:
|
||||
model.load()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
if app.state.thread_pool is None:
|
||||
model.load()
|
||||
else:
|
||||
await loop.run_in_executor(app.state.thread_pool, _load)
|
||||
return model
|
||||
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
|
||||
log.warn(
|
||||
(
|
||||
f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'."
|
||||
"Clearing cache and retrying."
|
||||
)
|
||||
)
|
||||
model.clear_cache()
|
||||
if app.state.thread_pool is None:
|
||||
model.load()
|
||||
else:
|
||||
await loop.run_in_executor(app.state.thread_pool, _load)
|
||||
return model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue