feat(ml)!: customizable ML settings (#3891)

* consolidated endpoints, added live configuration

* added ml settings to server

* added settings dashboard

* updated deps, fixed typos

* simplified modelconfig

updated tests

* added ml settings accordion for admin page

updated tests

* merge `clipText` and `clipVision`

* added face distance setting

clarified setting

* add clip mode in request, dropdown for face models

* polished ml settings

updated descriptions

* update clip field on error

* removed unused import

* add description for image classification threshold

* pin safetensors for arm wheel

updated poetry lock

* moved dto

* set model type only in ml repository

* revert form-data package install

use fetch instead of axios

* added slotted description with link

updated facial recognition description

clarified effect of disabling tasks

* validation before model load

* removed unnecessary `getConfig` call

* added migration

* updated api

updated api

updated api

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
Mert, 2023-08-29 09:58:00 -04:00, committed by GitHub
parent 22f5e05060
commit bcc36d14a1
56 changed files with 2324 additions and 655 deletions


@@ -1,29 +1,26 @@
 import asyncio
 import os
 from concurrent.futures import ThreadPoolExecutor
-from io import BytesIO
 from typing import Any
 
-import cv2
-import numpy as np
+import orjson
 import uvicorn
-from fastapi import Body, Depends, FastAPI
-from PIL import Image
+from fastapi import FastAPI, Form, HTTPException, UploadFile
+from fastapi.responses import ORJSONResponse
+from starlette.formparsers import MultiPartParser
 
 from app.models.base import InferenceModel
 
 from .config import settings
 from .models.cache import ModelCache
 from .schemas import (
-    EmbeddingResponse,
-    FaceResponse,
     MessageResponse,
     ModelType,
-    TagResponse,
-    TextModelRequest,
     TextResponse,
 )
 
+MultiPartParser.max_file_size = 2**24  # spools to disk if payload is 16 MiB or larger
+
 app = FastAPI()
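
As a quick sanity check on the spool threshold added above (not part of the diff):

# 2**24 bytes == 16 MiB; multipart payloads at or above this size are
# spooled to a temporary file rather than buffered entirely in memory
assert 2**24 == 16 * 1024 * 1024 == 16_777_216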
@@ -33,37 +30,9 @@ def init_state() -> None:
     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads)
 
 
-async def load_models() -> None:
-    models: list[tuple[str, ModelType, dict[str, Any]]] = [
-        (settings.classification_model, ModelType.IMAGE_CLASSIFICATION, {}),
-        (settings.clip_image_model, ModelType.CLIP, {"mode": "vision"}),
-        (settings.clip_text_model, ModelType.CLIP, {"mode": "text"}),
-        (settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION, {}),
-    ]
-
-    # Get all models
-    for model_name, model_type, model_kwargs in models:
-        await app.state.model_cache.get(model_name, model_type, eager=settings.eager_startup, **model_kwargs)
-
-
 @app.on_event("startup")
 async def startup_event() -> None:
     init_state()
-    await load_models()
 
 
 @app.on_event("shutdown")
 async def shutdown_event() -> None:
     app.state.thread_pool.shutdown()
-
-
-def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
-    return Image.open(BytesIO(byte_image))
-
-
-def dep_cv_image(byte_image: bytes = Body(...)) -> np.ndarray[int, np.dtype[Any]]:
-    byte_image_np = np.frombuffer(byte_image, np.uint8)
-    return cv2.imdecode(byte_image_np, cv2.IMREAD_COLOR)
 
 
 @app.get("/", response_model=MessageResponse)
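
With the eager load_models pass removed, models are now fetched on demand through app.state.model_cache the first time a request needs them. A minimal sketch of that get-or-load pattern, with hypothetical names; the real implementation is ModelCache in app/models/cache.py:

from typing import Any


class LazyModelCache:
    # hypothetical stand-in for app.models.cache.ModelCache
    def __init__(self) -> None:
        self._models: dict[tuple[str, str], Any] = {}

    async def get(self, model_name: str, model_type: str, **model_kwargs: Any) -> Any:
        key = (model_name, model_type)
        if key not in self._models:
            # the first request pays the one-time load cost;
            # later requests reuse the cached instance
            self._models[key] = self._load(model_name, model_type, **model_kwargs)
        return self._models[key]

    def _load(self, model_name: str, model_type: str, **model_kwargs: Any) -> Any:
        # placeholder for downloading and initializing the actual model
        return {"name": model_name, "type": model_type, **model_kwargs}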
@@ -76,57 +45,27 @@ def ping() -> str:
     return "pong"
 
 
-@app.post(
-    "/image-classifier/tag-image",
-    response_model=TagResponse,
-    status_code=200,
-)
-async def image_classification(
-    image: Image.Image = Depends(dep_pil_image),
-) -> list[str]:
-    model = await app.state.model_cache.get(settings.classification_model, ModelType.IMAGE_CLASSIFICATION)
-    labels = await predict(model, image)
-    return labels
+@app.post("/predict")
+async def predict(
+    model_name: str = Form(alias="modelName"),
+    model_type: ModelType = Form(alias="modelType"),
+    options: str = Form(default="{}"),
+    text: str | None = Form(default=None),
+    image: UploadFile | None = None,
+) -> Any:
+    if image is not None:
+        inputs: str | bytes = await image.read()
+    elif text is not None:
+        inputs = text
+    else:
+        raise HTTPException(400, "Either image or text must be provided")
+
+    model: InferenceModel = await app.state.model_cache.get(model_name, model_type, **orjson.loads(options))
+    outputs = await run(model, inputs)
+    return ORJSONResponse(outputs)
 
 
-@app.post(
-    "/sentence-transformer/encode-image",
-    response_model=EmbeddingResponse,
-    status_code=200,
-)
-async def clip_encode_image(
-    image: Image.Image = Depends(dep_pil_image),
-) -> list[float]:
-    model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP, mode="vision")
-    embedding = await predict(model, image)
-    return embedding
-
-
-@app.post(
-    "/sentence-transformer/encode-text",
-    response_model=EmbeddingResponse,
-    status_code=200,
-)
-async def clip_encode_text(payload: TextModelRequest) -> list[float]:
-    model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP, mode="text")
-    embedding = await predict(model, payload.text)
-    return embedding
-
-
-@app.post(
-    "/facial-recognition/detect-faces",
-    response_model=FaceResponse,
-    status_code=200,
-)
-async def facial_recognition(
-    image: cv2.Mat = Depends(dep_cv_image),
-) -> list[dict[str, Any]]:
-    model = await app.state.model_cache.get(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION)
-    faces = await predict(model, image)
-    return faces
-
-
-async def predict(model: InferenceModel, inputs: Any) -> Any:
+async def run(model: InferenceModel, inputs: Any) -> Any:
     return await asyncio.get_running_loop().run_in_executor(app.state.thread_pool, model.predict, inputs)
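
For reference, a sketch of how a caller might exercise the consolidated /predict endpoint. The host/port and model names below are illustrative assumptions (use whatever the server configuration specifies), the third-party requests library stands in for the server's actual fetch-based client, and the modelType values assume the ModelType enum serializes to kebab-case strings; the form field names match the Form aliases above:

import json

import requests  # third-party HTTP client, used here for brevity

ML_URL = "http://localhost:3003/predict"  # assumed host/port for the ML service

# Text input, e.g. a CLIP text embedding; "options" is JSON-encoded and
# parsed server-side with orjson.loads, here carrying the clip mode.
resp = requests.post(
    ML_URL,
    data={
        "modelName": "example-clip-model",  # placeholder model name
        "modelType": "clip",
        "options": json.dumps({"mode": "text"}),
        "text": "a photo of a dog",
    },
)
resp.raise_for_status()
print(resp.json())

# Image input: send the file as multipart form data and omit "text".
with open("photo.jpg", "rb") as f:
    resp = requests.post(
        ML_URL,
        data={"modelName": "example-face-model", "modelType": "facial-recognition"},  # placeholders
        files={"image": f},
    )
resp.raise_for_status()
print(resp.json())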