[go: nahoru, domu]

Skip to content

Commit

Permalink
torch_utils.float64
Browse files Browse the repository at this point in the history
return torch.float64 if device is not mps or xpu, else return torch.float32
  • Loading branch information
w-e-w committed May 16, 2024
1 parent ddb28b3 commit 9c8075b
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions modules/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import torch.nn
import torch


def get_param(model) -> torch.nn.Parameter:
Expand All @@ -15,3 +16,11 @@ def get_param(model) -> torch.nn.Parameter:
return param

raise ValueError(f"No parameters found in model {model!r}")


def float64(t: torch.Tensor) -> torch.dtype:
    """Return the float dtype to use for 64-bit work on ``t``'s device.

    Returns ``torch.float32`` if ``t`` lives on an ``mps`` or ``xpu`` device,
    and ``torch.float64`` otherwise.

    Args:
        t: Tensor (or any object exposing ``.device.type``) whose device type
            selects the dtype.

    Returns:
        ``torch.float32`` for ``mps``/``xpu`` devices, else ``torch.float64``.
    """
    # A membership test, not ``case 'mps', 'xpu':`` -- that form is a
    # structural *sequence* pattern, which never matches a str subject, so the
    # float32 fallback was unreachable.
    # NOTE(review): the fallback presumably exists because these backends lack
    # (full) float64 support -- confirm per backend before adding more.
    if t.device.type in ('mps', 'xpu'):
        return torch.float32
    return torch.float64

0 comments on commit 9c8075b

Please sign in to comment.