mirror of
https://github.com/syssec-utd/pylingual.git
synced 2026-05-10 18:39:03 -07:00
mps support
This commit is contained in:
@@ -116,7 +116,11 @@ def load_models(config_file: Path = Path("pylingual/decompiler_config.yaml"), ve
|
||||
mask_token="[MASK]",
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
logger.info("Using CUDA GPU for models")
|
||||
device = torch.device("cuda:0")
|
||||
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
logger.info("Using MPS (Metal Performance Shaders) GPU for models")
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
logger.warning("Using CPU for models")
|
||||
device = torch.device("cpu")
|
||||
|
||||
Reference in New Issue
Block a user