mirror of
https://github.com/immich-app/immich
synced 2025-11-07 17:27:20 +00:00
fix(ml): tokenization for webli models (#11881)
This commit is contained in:
parent
5ab92f346a
commit
036676d501
3 changed files with 48 additions and 3 deletions
|
|
@ -10,6 +10,7 @@ from tokenizers import Encoding, Tokenizer
|
|||
|
||||
from app.config import log
|
||||
from app.models.base import InferenceModel
|
||||
from app.models.transforms import clean_text
|
||||
from app.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
||||
|
|
@ -25,6 +26,8 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
session = super()._load()
|
||||
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
|
||||
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
|
||||
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
|
||||
|
||||
return session
|
||||
|
|
@ -56,6 +59,11 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
|
||||
return model_cfg
|
||||
|
||||
@property
|
||||
def text_cfg(self) -> dict[str, Any]:
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
return text_cfg
|
||||
|
||||
@cached_property
|
||||
def tokenizer_file(self) -> dict[str, Any]:
|
||||
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
|
||||
|
|
@ -73,8 +81,7 @@ class BaseCLIPTextualEncoder(InferenceModel):
|
|||
|
||||
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
||||
def _load_tokenizer(self) -> Tokenizer:
|
||||
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
|
||||
context_length: int = text_cfg.get("context_length", 77)
|
||||
context_length: int = self.text_cfg.get("context_length", 77)
|
||||
pad_token: str = self.tokenizer_cfg["pad_token"]
|
||||
|
||||
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
|
||||
|
|
@ -86,12 +93,14 @@ class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
|
|||
return tokenizer
|
||||
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {"text": np.array([tokens.ids], dtype=np.int32)}
|
||||
|
||||
|
||||
class MClipTextualEncoder(OpenClipTextualEncoder):
|
||||
def tokenize(self, text: str) -> dict[str, NDArray[np.int32]]:
|
||||
text = clean_text(text, canonicalize=self.canonicalize)
|
||||
tokens: Encoding = self.tokenizer.encode(text)
|
||||
return {
|
||||
"input_ids": np.array([tokens.ids], dtype=np.int32),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue