fix(ml): tokenization for webli models (#11881)

This commit is contained in:
Mert 2024-08-18 11:05:10 -04:00 committed by GitHub
parent 5ab92f346a
commit 036676d501
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 48 additions and 3 deletions

View file

@ -1,3 +1,4 @@
import string
from io import BytesIO
from typing import IO
@ -7,6 +8,7 @@ from numpy.typing import NDArray
from PIL import Image
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
def resize_pil(img: Image.Image, size: int) -> Image.Image:
@ -60,3 +62,10 @@ def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[
if isinstance(image_bytes, Image.Image):
return pil_to_cv2(image_bytes)
return image_bytes
def clean_text(text: str, canonicalize: bool = False) -> str:
text = " ".join(text.split())
if canonicalize:
text = text.translate(_PUNCTUATION_TRANS).lower()
return text