import importlib
import logging
import torch
from torch._dynamo import register_backend
from .common import fake_tensor_unsupported
# Module-level logger named after this module; the functions in this file use
# it to report IPEX import failures and JIT tracing failures.
log = logging.getLogger(__name__)
@register_backend
@fake_tensor_unsupported
def ipex(model, inputs):
    """Backend that TorchScript-traces and freezes ``model`` when IPEX is present.

    The function first imports ``intel_extension_for_pytorch``; the imported
    name itself is never used here (presumably the import is needed for its
    side effects -- confirm against the IPEX documentation).  It then traces
    ``model.eval()`` with ``inputs`` under ``torch.no_grad()`` and freezes the
    traced module.

    Args:
        model: The module to optimize.  ``model.eval()`` is called on it
            before tracing.
        inputs: Example inputs forwarded to ``torch.jit.trace``.

    Returns:
        The frozen traced module on success.  If tracing or freezing raises
        any ``Exception``, a warning (including the traceback) is logged and
        the original ``model`` is returned instead; note that ``model.eval()``
        has already been invoked on it by then.

    Raises:
        ImportError: If ``intel_extension_for_pytorch`` cannot be imported.
            The failure is logged before being re-raised.
    """
    try:
        import intel_extension_for_pytorch  # type: ignore[import]  # noqa: F401
    except ImportError:
        log.exception(
            "Unable to import Intel Extension for PyTorch (IPEX). "
            "Please install the right version of IPEX that matches the PyTorch version being used. "
            "Refer to https://github.com/intel/intel-extension-for-pytorch for details."
        )
        raise

    try:
        with torch.no_grad():
            traced_model = torch.jit.trace(model.eval(), inputs)
            traced_model = torch.jit.freeze(traced_model)
    except Exception:
        # Deliberate best-effort fallback: keep running with the untraced
        # model, but attach the traceback so the reason for the failure is
        # not lost.
        log.warning(
            "JIT trace failed during the 'ipex' optimize process.",
            exc_info=True,
        )
        return model
    return traced_model
def has_ipex():
    """Return ``True`` if ``intel_extension_for_pytorch`` imports successfully.

    The check performs a real import via ``importlib.import_module``; an
    ``ImportError`` is treated as "not available" and yields ``False``.
    """
    try:
        importlib.import_module("intel_extension_for_pytorch")
    except ImportError:
        available = False
    else:
        available = True
    return available