chore(ml): added testing and github workflow (#2969)

* added testing

* github action for python, made mypy happy

* formatted with black

* minor fixes and styling

* test model cache

* cache test dependencies

* narrowed model cache tests

* moved endpoint tests to their own class

* cleaned up fixtures

* formatting

* removed unused dep
This commit is contained in:
Mert 2023-06-27 19:21:33 -04:00 committed by GitHub
parent 5e3bdc76b2
commit df1e8679d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 622 additions and 95 deletions

View file

@ -16,8 +16,8 @@ class ImageClassifier(InferenceModel):
self,
model_name: str,
min_score: float = settings.min_tag_score,
cache_dir: Path | None = None,
**model_kwargs,
cache_dir: Path | str | None = None,
**model_kwargs: Any,
) -> None:
self.min_score = min_score
super().__init__(model_name, cache_dir, **model_kwargs)
@ -30,13 +30,7 @@ class ImageClassifier(InferenceModel):
)
def predict(self, image: Image) -> list[str]:
predictions = self.model(image)
tags = list(
{
tag
for pred in predictions
for tag in pred["label"].split(", ")
if pred["score"] >= self.min_score
}
)
predictions: list[dict[str, Any]] = self.model(image) # type: ignore
tags = [tag for pred in predictions for tag in pred["label"].split(", ") if pred["score"] >= self.min_score]
return tags