diff --git a/pylingual/models.py b/pylingual/models.py index ada1e21..e6cd1d6 100644 --- a/pylingual/models.py +++ b/pylingual/models.py @@ -116,7 +116,11 @@ def load_models(config_file: Path = Path("pylingual/decompiler_config.yaml"), ve mask_token="[MASK]", ) if torch.cuda.is_available(): + logger.info("Using CUDA GPU for models") device = torch.device("cuda:0") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + logger.info("Using MPS (Metal Performance Shaders) GPU for models") + device = torch.device("mps") else: logger.warning("Using CPU for models") device = torch.device("cpu")