refactor(ml): modularization and styling (#2835)

* basic refactor and styling

* removed batching

* module entrypoint

* removed unused imports

* model superclass, model cache now in app state

* fixed cache dir and enforced abstract method

---------

Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
This commit is contained in:
Mert 2023-06-24 23:18:09 -04:00 committed by GitHub
parent 837ad24f58
commit a2f5674bbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 281 additions and 182 deletions

View file

@ -1,52 +1,58 @@
import os
import io
from io import BytesIO
from typing import Any
from cache import ModelCache
from schemas import (
import cv2
import numpy as np
import uvicorn
from fastapi import Body, Depends, FastAPI
from PIL import Image
from .config import settings
from .models.base import InferenceModel
from .models.cache import ModelCache
from .schemas import (
EmbeddingResponse,
FaceResponse,
TagResponse,
MessageResponse,
ModelType,
TagResponse,
TextModelRequest,
TextResponse,
)
import uvicorn
from PIL import Image
from fastapi import FastAPI, HTTPException, Depends, Body
from models import get_model, run_classification, run_facial_recognition
from config import settings
_model_cache = None
app = FastAPI()
# Startup hook: creates the ModelCache and touches every configured model once.
# NOTE(review): this span appears to be a flattened diff. The `_model_cache` /
# `get_cached_model` / `get_model` lines and the `app.state.model_cache` /
# `.get` / `InferenceModel.from_model_type` lines look like two revisions of
# the same statements rather than consecutive statements. Confirm against the
# real file before treating it as runnable code.
@app.on_event("startup")
async def startup_event() -> None:
global _model_cache
_model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
# When the image and text CLIP settings name the same model, both sides use
# ModelType.CLIP; otherwise each side gets its own dedicated model type.
same_clip = settings.clip_image_model == settings.clip_text_model
app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
# (model name, model type) pairs for every model named in settings.
models = [
(settings.classification_model, "image-classification"),
(settings.clip_image_model, "clip"),
(settings.clip_text_model, "clip"),
(settings.facial_recognition_model, "facial-recognition"),
(settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
(settings.clip_image_model, app.state.clip_vision_type),
(settings.clip_text_model, app.state.clip_text_type),
(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
]
# Visit every configured model: with settings.eager_startup the model is
# fetched through the cache; otherwise it is only constructed outside the
# cache (presumably to trigger a download without keeping it loaded; TODO confirm).
for model_name, model_type in models:
if settings.eager_startup:
await _model_cache.get_cached_model(model_name, model_type)
await app.state.model_cache.get(model_name, model_type)
else:
get_model(model_name, model_type)
InferenceModel.from_model_type(model_type, model_name)
def dep_model_cache():
    """Route dependency that fails the request when no model cache exists.

    Raises:
        HTTPException: status 500 if the module-level ``_model_cache`` is
            still ``None``.
    """
    if _model_cache is not None:
        return
    raise HTTPException(status_code=500, detail="Unable to load model.")
def dep_pil_image(byte_image: bytes = Body(...)) -> Image.Image:
    """Route dependency: open the raw request-body bytes as a PIL image."""
    buffer = BytesIO(byte_image)
    return Image.open(buffer)
def dep_cv_image(byte_image: bytes = Body(...)) -> cv2.Mat:
    """Route dependency: decode the raw request-body bytes with OpenCV.

    The bytes are viewed as a ``uint8`` buffer and decoded with
    ``cv2.IMREAD_COLOR``.
    """
    encoded = np.frombuffer(byte_image, dtype=np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_COLOR)
def dep_input_image(image: bytes = Body(...)) -> Image.Image:
    """Route dependency: open the raw request-body bytes as a PIL image.

    The return annotation is ``Image.Image`` (the class). ``Image`` on its own
    is the ``PIL.Image`` module as imported by this file, so it is not a
    valid type for the value returned by ``Image.open``.
    """
    return Image.open(io.BytesIO(image))
@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
@ -62,33 +68,29 @@ def ping() -> str:
"/image-classifier/tag-image",
response_model=TagResponse,
status_code=200,
dependencies=[Depends(dep_model_cache)],
)
async def image_classification(
image: Image = Depends(dep_input_image)
image: Image.Image = Depends(dep_pil_image),
) -> list[str]:
try:
model = await _model_cache.get_cached_model(
settings.classification_model, "image-classification"
)
labels = run_classification(model, image, settings.min_tag_score)
except Exception as ex:
raise HTTPException(status_code=500, detail=str(ex))
else:
return labels
model = await app.state.model_cache.get(
settings.classification_model, ModelType.IMAGE_CLASSIFICATION
)
labels = model.predict(image)
return labels
# POST /sentence-transformer/encode-image: obtains the model configured as
# settings.clip_image_model from the cache, runs it on the uploaded image and
# returns the resulting embedding (response schema: EmbeddingResponse).
# NOTE(review): this span appears to be a flattened diff. The two `image:`
# parameter lines and the two `model = ... / embedding = ...` pairs look like
# old and new revisions of the same statements. Confirm against the real file.
@app.post(
"/sentence-transformer/encode-image",
response_model=EmbeddingResponse,
status_code=200,
dependencies=[Depends(dep_model_cache)],
)
async def clip_encode_image(
image: Image = Depends(dep_input_image)
image: Image.Image = Depends(dep_pil_image),
) -> list[float]:
# Variant using the module-level cache; converts the result with .tolist().
model = await _model_cache.get_cached_model(settings.clip_image_model, "clip")
embedding = model.encode(image).tolist()
# Variant using the cache on app.state, keyed by the vision model type stored
# in app.state.clip_vision_type.
model = await app.state.model_cache.get(
settings.clip_image_model, app.state.clip_vision_type
)
embedding = model.predict(image)
return embedding
@ -96,13 +98,12 @@ async def clip_encode_image(
"/sentence-transformer/encode-text",
response_model=EmbeddingResponse,
status_code=200,
dependencies=[Depends(dep_model_cache)],
)
async def clip_encode_text(
payload: TextModelRequest
) -> list[float]:
model = await _model_cache.get_cached_model(settings.clip_text_model, "clip")
embedding = model.encode(payload.text).tolist()
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
model = await app.state.model_cache.get(
settings.clip_text_model, app.state.clip_text_type
)
embedding = model.predict(payload.text)
return embedding
@ -110,22 +111,21 @@ async def clip_encode_text(
"/facial-recognition/detect-faces",
response_model=FaceResponse,
status_code=200,
dependencies=[Depends(dep_model_cache)],
)
async def facial_recognition(
image: bytes = Body(...),
image: cv2.Mat = Depends(dep_cv_image),
) -> list[dict[str, Any]]:
model = await _model_cache.get_cached_model(
settings.facial_recognition_model, "facial-recognition"
model = await app.state.model_cache.get(
settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION
)
faces = run_facial_recognition(model, image)
faces = model.predict(image)
return faces
if __name__ == "__main__":
is_dev = os.getenv("NODE_ENV") == "development"
uvicorn.run(
"main:app",
"app.main:app",
host=settings.host,
port=settings.port,
reload=is_dev,