chore(ml): update pydantic (#13230)

* update pydantic

* fix typing

* remove unused import

* remove unused schema
This commit is contained in:
Mert 2024-10-13 18:00:21 -04:00 committed by GitHub
parent f29fb1655a
commit e7397f35c9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 186 additions and 82 deletions

View file

@ -6,7 +6,8 @@ from pathlib import Path
from socket import socket
from gunicorn.arbiter import Arbiter
from pydantic import BaseModel, BaseSettings
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
from rich.console import Console
from rich.logging import RichHandler
from uvicorn import Server
@ -14,11 +15,18 @@ from uvicorn.workers import UvicornWorker
class PreloadModelData(BaseModel):
clip: str | None
facial_recognition: str | None
clip: str | None = None
facial_recognition: str | None = None
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="MACHINE_LEARNING_",
case_sensitive=False,
env_nested_delimiter="__",
protected_namespaces=("settings_",),
)
cache_folder: Path = Path("/cache")
model_ttl: int = 300
model_ttl_poll_s: int = 10
@ -34,23 +42,17 @@ class Settings(BaseSettings):
ann_tuning_level: int = 2
preload: PreloadModelData | None = None
class Config:
env_prefix = "MACHINE_LEARNING_"
case_sensitive = False
env_nested_delimiter = "__"
@property
def device_id(self) -> str:
return os.environ.get("MACHINE_LEARNING_DEVICE_ID", "0")
class LogSettings(BaseSettings):
model_config = SettingsConfigDict(case_sensitive=False)
immich_log_level: str = "info"
no_color: bool = False
class Config:
case_sensitive = False
_clean_name = str.maketrans(":\\/", "___", ".")

View file

@ -12,7 +12,7 @@ from zipfile import BadZipFile
import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse
from fastapi.responses import ORJSONResponse, PlainTextResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
@ -28,14 +28,12 @@ from .schemas import (
InferenceEntries,
InferenceEntry,
InferenceResponse,
MessageResponse,
ModelFormat,
ModelIdentity,
ModelTask,
ModelType,
PipelineRequest,
T,
TextResponse,
)
MultiPartParser.max_file_size = 2**26 # spools to disk if payload is 64 MiB or larger
@ -127,14 +125,14 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
app = FastAPI(lifespan=lifespan)
@app.get("/", response_model=MessageResponse)
async def root() -> dict[str, str]:
return {"message": "Immich ML"}
@app.get("/")
async def root() -> ORJSONResponse:
return ORJSONResponse({"message": "Immich ML"})
@app.get("/ping", response_model=TextResponse)
def ping() -> str:
return "pong"
@app.get("/ping")
def ping() -> PlainTextResponse:
return PlainTextResponse("pong")
@app.post("/predict", dependencies=[Depends(update_state)])

View file

@ -1,9 +1,9 @@
from enum import Enum
from typing import Any, Literal, Protocol, TypedDict, TypeGuard, TypeVar
from typing import Any, Literal, Protocol, TypeGuard, TypeVar
import numpy as np
import numpy.typing as npt
from pydantic import BaseModel
from typing_extensions import TypedDict
class StrEnum(str, Enum):
@ -13,14 +13,6 @@ class StrEnum(str, Enum):
return self.value
class TextResponse(BaseModel):
__root__: str
class MessageResponse(BaseModel):
message: str
class BoundingBox(TypedDict):
x1: int
y1: int

View file

@ -810,11 +810,26 @@ class TestLoad:
mock_model.model_format = ModelFormat.ONNX
def test_root_endpoint(deployed_app: TestClient) -> None:
response = deployed_app.get("http://localhost:3003")
body = response.json()
assert response.status_code == 200
assert body == {"message": "Immich ML"}
def test_ping_endpoint(deployed_app: TestClient) -> None:
response = deployed_app.get("http://localhost:3003/ping")
assert response.status_code == 200
assert response.text == "pong"
@pytest.mark.skipif(
not settings.test_full,
reason="More time-consuming since it deploys the app and loads models.",
)
class TestEndpoints:
class TestPredictionEndpoints:
def test_clip_image_endpoint(
self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
) -> None: