diff --git a/xautodl/datasets/synthetic_env.py b/xautodl/datasets/synthetic_env.py index aaa1b98..077e826 100644 --- a/xautodl/datasets/synthetic_env.py +++ b/xautodl/datasets/synthetic_env.py @@ -127,6 +127,10 @@ class SyntheticDEnv(data.Dataset): targets = torch.from_numpy(targets) else: targets = torch.Tensor(targets) + if dataset.dtype == torch.float64: + dataset = dataset.float() + if targets.dtype == torch.float64: + targets = targets.float() return torch.Tensor([timestamp]), (dataset, targets) def __len__(self):