refactor(ml): model sessions (#10559)

This commit is contained in:
Mert 2024-06-25 12:00:24 -04:00 committed by GitHub
parent 6538ad8de7
commit 6356c28f64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 529 additions and 375 deletions

View file

@ -54,6 +54,14 @@ class ModelSource(StrEnum):
ModelIdentity = tuple[ModelType, ModelTask]
class SessionNode(Protocol):
@property
def name(self) -> str | None: ...
@property
def shape(self) -> tuple[int, ...]: ...
class ModelSession(Protocol):
def run(
self,
@ -62,6 +70,10 @@ class ModelSession(Protocol):
run_options: Any = None,
) -> list[npt.NDArray[np.float32]]: ...
def get_inputs(self) -> list[SessionNode]: ...
def get_outputs(self) -> list[SessionNode]: ...
class HasProfiling(Protocol):
profiling: dict[str, float]