Mirror of https://github.com/immich-app/immich (synced 2025-10-17 18:19:27 +00:00)
chore(ml): update pydantic (#13230)
* update pydantic
* fix typing
* remove unused import
* remove unused schema
Parent: f29fb1655a
Commit: e7397f35c9

6 changed files with 186 additions and 82 deletions
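The core of the change is the standard pydantic v1 → v2 settings migration: `BaseSettings` now lives in the separate `pydantic_settings` package, and the inner `class Config` is replaced by a `model_config = SettingsConfigDict(...)` attribute. A minimal sketch of the pattern the diff below adopts; `ExampleSettings` and its field are illustrative, not taken from the commit:

# Minimal sketch of the pydantic v2 settings pattern the diff below adopts.
# ExampleSettings and its field are illustrative, not taken from the commit.
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleSettings(BaseSettings):
    # pydantic v1 spelled this as an inner `class Config` with attributes.
    model_config = SettingsConfigDict(
        env_prefix="MACHINE_LEARNING_",
        case_sensitive=False,
    )

    model_ttl: int = 300  # populated from MACHINE_LEARNING_MODEL_TTL when set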
machine-learning/app/config.py
@@ -6,7 +6,8 @@ from pathlib import Path
 from socket import socket
 
 from gunicorn.arbiter import Arbiter
-from pydantic import BaseModel, BaseSettings
+from pydantic import BaseModel
+from pydantic_settings import BaseSettings, SettingsConfigDict
 from rich.console import Console
 from rich.logging import RichHandler
 from uvicorn import Server
@@ -14,11 +15,18 @@ from uvicorn.workers import UvicornWorker
 
 
 class PreloadModelData(BaseModel):
-    clip: str | None
-    facial_recognition: str | None
+    clip: str | None = None
+    facial_recognition: str | None = None
 
 
 class Settings(BaseSettings):
+    model_config = SettingsConfigDict(
+        env_prefix="MACHINE_LEARNING_",
+        case_sensitive=False,
+        env_nested_delimiter="__",
+        protected_namespaces=("settings_",),
+    )
+
     cache_folder: Path = Path("/cache")
     model_ttl: int = 300
     model_ttl_poll_s: int = 10
@@ -34,23 +42,17 @@ class Settings(BaseSettings):
     ann_tuning_level: int = 2
     preload: PreloadModelData | None = None
 
-    class Config:
-        env_prefix = "MACHINE_LEARNING_"
-        case_sensitive = False
-        env_nested_delimiter = "__"
-
     @property
     def device_id(self) -> str:
         return os.environ.get("MACHINE_LEARNING_DEVICE_ID", "0")
 
 
 class LogSettings(BaseSettings):
+    model_config = SettingsConfigDict(case_sensitive=False)
+
     immich_log_level: str = "info"
     no_color: bool = False
 
-    class Config:
-        case_sensitive = False
-
 
 _clean_name = str.maketrans(":\\/", "___", ".")
 
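Carrying `env_nested_delimiter="__"` over into `SettingsConfigDict` preserves how nested preload options are read from the environment: the prefix, the nested field path, and the delimiter combine into one variable name. A sketch of the resulting behavior (the model name is an illustrative value, not from the commit):

# Sketch: with env_prefix="MACHINE_LEARNING_" and env_nested_delimiter="__",
# the nested field preload.clip is populated from MACHINE_LEARNING_PRELOAD__CLIP.
import os

os.environ["MACHINE_LEARNING_PRELOAD__CLIP"] = "ViT-B-32__openai"  # illustrative

settings = Settings()  # the Settings class defined in config.py above
assert settings.preload is not None
assert settings.preload.clip == "ViT-B-32__openai"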
machine-learning/app/main.py
@@ -12,7 +12,7 @@ from zipfile import BadZipFile
 
 import orjson
 from fastapi import Depends, FastAPI, File, Form, HTTPException
-from fastapi.responses import ORJSONResponse
+from fastapi.responses import ORJSONResponse, PlainTextResponse
 from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
 from PIL.Image import Image
 from pydantic import ValidationError
@@ -28,14 +28,12 @@ from .schemas import (
     InferenceEntries,
     InferenceEntry,
     InferenceResponse,
-    MessageResponse,
     ModelFormat,
     ModelIdentity,
     ModelTask,
     ModelType,
     PipelineRequest,
     T,
-    TextResponse,
 )
 
 MultiPartParser.max_file_size = 2**26  # spools to disk if payload is 64 MiB or larger
@@ -127,14 +125,14 @@ def get_entries(entries: str = Form()) -> InferenceEntries:
 app = FastAPI(lifespan=lifespan)
 
 
-@app.get("/", response_model=MessageResponse)
-async def root() -> dict[str, str]:
-    return {"message": "Immich ML"}
+@app.get("/")
+async def root() -> ORJSONResponse:
+    return ORJSONResponse({"message": "Immich ML"})
 
 
-@app.get("/ping", response_model=TextResponse)
-def ping() -> str:
-    return "pong"
+@app.get("/ping")
+def ping() -> PlainTextResponse:
+    return PlainTextResponse("pong")
 
 
 @app.post("/predict", dependencies=[Depends(update_state)])
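The endpoint rewrite is what makes the `MessageResponse` and `TextResponse` schemas removable: `TextResponse` relied on pydantic v1's `__root__` feature, which v2 drops (see schemas.py below), and when a FastAPI handler returns a `Response` subclass directly, FastAPI sends it as-is with no `response_model` validation or serialization step. A self-contained sketch of that behavior, using only FastAPI and its test client; the demo app is illustrative:

# Sketch: a handler returning a Response subclass is sent as-is; no pydantic
# response_model machinery runs. The demo app here is illustrative.
from fastapi import FastAPI
from fastapi.responses import PlainTextResponse
from fastapi.testclient import TestClient

demo = FastAPI()


@demo.get("/ping")
def demo_ping() -> PlainTextResponse:
    return PlainTextResponse("pong")


client = TestClient(demo)
response = client.get("/ping")
assert response.status_code == 200
assert response.text == "pong"
assert response.headers["content-type"].startswith("text/plain")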
machine-learning/app/schemas.py
@@ -1,9 +1,9 @@
 from enum import Enum
-from typing import Any, Literal, Protocol, TypedDict, TypeGuard, TypeVar
+from typing import Any, Literal, Protocol, TypeGuard, TypeVar
 
 import numpy as np
 import numpy.typing as npt
-from pydantic import BaseModel
+from typing_extensions import TypedDict
 
 
 class StrEnum(str, Enum):
@@ -13,14 +13,6 @@ class StrEnum(str, Enum):
         return self.value
 
 
-class TextResponse(BaseModel):
-    __root__: str
-
-
-class MessageResponse(BaseModel):
-    message: str
-
-
 class BoundingBox(TypedDict):
     x1: int
     y1: int
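Both deleted classes existed only to serve the `response_model=` arguments removed in main.py above, so they are dropped rather than ported. For reference, pydantic v2 removed `__root__` in favor of `RootModel`; had `TextResponse` still been needed, the v2 port would look roughly like this (a sketch, not part of the commit):

# Sketch of the pydantic v2 equivalent of the deleted v1 __root__ schema.
# Not part of this commit, which deletes TextResponse instead of porting it.
from pydantic import RootModel

TextResponse = RootModel[str]

assert TextResponse("pong").model_dump() == "pong"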
machine-learning/app/test_main.py
@@ -810,11 +810,26 @@ class TestLoad:
         mock_model.model_format = ModelFormat.ONNX
 
 
+def test_root_endpoint(deployed_app: TestClient) -> None:
+    response = deployed_app.get("http://localhost:3003")
+
+    body = response.json()
+    assert response.status_code == 200
+    assert body == {"message": "Immich ML"}
+
+
+def test_ping_endpoint(deployed_app: TestClient) -> None:
+    response = deployed_app.get("http://localhost:3003/ping")
+
+    assert response.status_code == 200
+    assert response.text == "pong"
+
+
 @pytest.mark.skipif(
     not settings.test_full,
     reason="More time-consuming since it deploys the app and loads models.",
 )
-class TestEndpoints:
+class TestPredictionEndpoints:
     def test_clip_image_endpoint(
         self, pil_image: Image.Image, responses: dict[str, Any], deployed_app: TestClient
     ) -> None:
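The new tests drive the endpoints through the `deployed_app` fixture, which is defined elsewhere in the suite and deploys the full service. A hypothetical reduced version of such a fixture, assuming the FastAPI `app` object from main.py, might look like this:

# Hypothetical reduced fixture: wraps the FastAPI app in a TestClient.
# The real deployed_app fixture in this suite deploys the full service.
from collections.abc import Iterator

import pytest
from fastapi.testclient import TestClient

from app.main import app  # assumed import path


@pytest.fixture
def deployed_app() -> Iterator[TestClient]:
    with TestClient(app) as client:  # context manager runs lifespan events
        yield client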