fixed setting different clip, removed unused stubs (#2987)

This commit is contained in:
Mert 2023-06-27 13:21:50 -04:00 committed by GitHub
parent b3e97a1a0c
commit 4d3ce0a65e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 6 additions and 25 deletions

View file

@ -27,13 +27,10 @@ app = FastAPI()
@app.on_event("startup")
async def startup_event() -> None:
app.state.model_cache = ModelCache(ttl=settings.model_ttl, revalidate=True)
same_clip = settings.clip_image_model == settings.clip_text_model
app.state.clip_vision_type = ModelType.CLIP if same_clip else ModelType.CLIP_VISION
app.state.clip_text_type = ModelType.CLIP if same_clip else ModelType.CLIP_TEXT
models = [
(settings.classification_model, ModelType.IMAGE_CLASSIFICATION),
(settings.clip_image_model, app.state.clip_vision_type),
(settings.clip_text_model, app.state.clip_text_type),
(settings.clip_image_model, ModelType.CLIP),
(settings.clip_text_model, ModelType.CLIP),
(settings.facial_recognition_model, ModelType.FACIAL_RECOGNITION),
]
@ -87,9 +84,7 @@ async def image_classification(
async def clip_encode_image(
image: Image.Image = Depends(dep_pil_image),
) -> list[float]:
model = await app.state.model_cache.get(
settings.clip_image_model, app.state.clip_vision_type
)
model = await app.state.model_cache.get(settings.clip_image_model, ModelType.CLIP)
embedding = model.predict(image)
return embedding
@ -100,9 +95,7 @@ async def clip_encode_image(
status_code=200,
)
async def clip_encode_text(payload: TextModelRequest) -> list[float]:
model = await app.state.model_cache.get(
settings.clip_text_model, app.state.clip_text_type
)
model = await app.state.model_cache.get(settings.clip_text_model, ModelType.CLIP)
embedding = model.predict(payload.text)
return embedding