Source code for wsinfer.modellib.models

from __future__ import annotations

import dataclasses
import warnings
from typing import Callable

import torch
import wsinfer_zoo
from wsinfer_zoo.client import HFModelTorchScript
from wsinfer_zoo.client import Model


@dataclasses.dataclass
class LocalModelTorchScript(Model):
    """A ``Model`` subclass that declares no additional fields.

    It is accepted by ``get_pretrained_torch_module`` alongside
    ``HFModelTorchScript``.
    NOTE(review): judging by the name, this presumably describes a TorchScript
    model stored on the local filesystem rather than one fetched from the
    registry -- confirm against the callers.
    """
def get_registered_model(name: str) -> HFModelTorchScript:
    """Look up a model by name in the wsinfer-zoo registry and load it.

    The registry is loaded on every call, the entry matching ``name`` is
    retrieved with ``get_model_by_name``, and the value returned by that
    entry's ``load_model_torchscript()`` is handed back to the caller.
    """
    return (
        wsinfer_zoo.client.load_registry()
        .get_model_by_name(name=name)
        .load_model_torchscript()
    )
def get_pretrained_torch_module(
    model: HFModelTorchScript | LocalModelTorchScript,
) -> torch.nn.Module:
    """Load the TorchScript file at ``model.model_path`` onto the CPU.

    Returns the loaded module. Raises ``TypeError`` if the object produced by
    ``torch.jit.load`` is not a ``torch.nn.Module``.
    """
    loaded: torch.nn.Module = torch.jit.load(model.model_path, map_location="cpu")
    # Happy path first: hand back anything that really is a Module.
    if isinstance(loaded, torch.nn.Module):
        return loaded
    raise TypeError(
        "expected the loaded object to be a subclass of torch.nn.Module but got"
        f" {type(loaded)}."
    )
def _infer_device(model: torch.nn.Module) -> torch.device:
    """Return the device of the first parameter or buffer, or CPU if none."""
    for tensor in model.parameters():
        return tensor.device
    for tensor in model.buffers():
        return tensor.device
    # Stateless modules (e.g. a bare activation) own no tensors at all.
    return torch.device("cpu")


def jit_compile(
    model: torch.nn.Module,
) -> torch.jit.ScriptModule | torch.nn.Module | Callable:
    """JIT-compile a model for inference.

    A torchscript model may be JIT compiled here as well.

    If ``torch.compile`` exists and ``model`` is not a ``ScriptModule``, the
    model is wrapped with ``torch.compile``, trying progressively less
    aggressive settings. Otherwise the model is scripted with
    ``torch.jit.script``, then frozen and optimized for inference when possible.

    Parameters
    ----------
    model : torch.nn.Module
        The model to compile. Modules without any parameters are supported;
        the test input is then placed on the device of the first buffer, or on
        the CPU if there are no buffers either.

    Returns
    -------
    torch.jit.ScriptModule | torch.nn.Module | Callable
        The most optimized variant that could be built. If compilation fails
        outright, a warning is emitted and the original ``model`` is returned.
    """
    w = "Warning: could not JIT compile the model. Using non-compiled model instead."

    # PyTorch 2.x has torch.compile but it does not work when applied
    # to TorchScript models.
    if hasattr(torch, "compile") and not isinstance(model, torch.jit.ScriptModule):
        # Try to get the most optimized model, relaxing the options each time.
        # NOTE(review): torch.compile appears to defer the real compilation to
        # the first call, so these fallbacks only catch errors raised eagerly
        # -- confirm against the PyTorch documentation.
        for compile_kwargs in (
            {"fullgraph": True, "mode": "max-autotune"},
            {"mode": "max-autotune"},
            {},
        ):
            try:
                return torch.compile(model, **compile_kwargs)
            except Exception:
                # Deliberate best-effort: fall through to the next option set.
                continue
        warnings.warn(w, stacklevel=1)
        return model

    # For pytorch 1.x (or TorchScript models), use torch.jit.script.
    # The dummy input is only needed on this path. The device lookup tolerates
    # modules that have no parameters instead of raising StopIteration.
    test_input = torch.ones(1, 3, 224, 224, device=_infer_device(model))

    # Attempt to script. If it fails, return the original.
    try:
        mjit = torch.jit.script(model)
        with torch.no_grad():
            mjit(test_input)
    except Exception:
        warnings.warn(w, stacklevel=1)
        return model

    # Now that we have scripted the model, try to optimize it further. If that
    # fails, return the scripted model.
    try:
        mjit_frozen = torch.jit.freeze(mjit)
        mjit_opt = torch.jit.optimize_for_inference(mjit_frozen)
        with torch.no_grad():
            mjit_opt(test_input)
    except Exception:
        return mjit  # type: ignore
    return mjit_opt