Source code for wsinfer.modellib.transforms

"""PyTorch image classification transform."""

from __future__ import annotations

from torchvision import transforms
from wsinfer_zoo.client import TransformConfigurationItem

# The subset of transforms known to the wsinfer config spec.
# This can be expanded in the future as needs arise.
_name_to_tv_cls = {
    "Resize": transforms.Resize,
    "ToTensor": transforms.ToTensor,
    "Normalize": transforms.Normalize,
}


[docs] def make_compose_from_transform_config( list_of_transforms: list[TransformConfigurationItem], ) -> transforms.Compose: """Create a torchvision Compose instance from configuration of transforms.""" all_t: list = [] for t in list_of_transforms: cls = _name_to_tv_cls[t.name] kwargs = t.arguments or {} all_t.append(cls(**kwargs)) return transforms.Compose(all_t)