mps support

This commit is contained in:
Joel Flores
2025-06-26 17:37:08 -05:00
parent 3fae0294fb
commit 3ab3abccc1
+4
View File
@@ -116,7 +116,11 @@ def load_models(config_file: Path = Path("pylingual/decompiler_config.yaml"), ve
mask_token="[MASK]",
)
if torch.cuda.is_available():
logger.info("Using CUDA GPU for models")
device = torch.device("cuda:0")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
logger.info("Using MPS (Metal Performance Shaders) GPU for models")
device = torch.device("mps")
else:
logger.warning("Using CPU for models")
device = torch.device("cpu")