diff --git a/src/python/piper_train/__main__.py b/src/python/piper_train/__main__.py index 824c7cc6f..eb8f5872b 100644 --- a/src/python/piper_train/__main__.py +++ b/src/python/piper_train/__main__.py @@ -125,11 +125,13 @@ def main(): else: torch.manual_seed(args.seed) _LOGGER.debug("Using manual seed: %s", args.seed) - + # Function to check if the GPU supports Tensor Cores def supports_tensor_cores(): # Assuming that Tensor Cores are supported if the compute capability is 7.0 or higher # This is a simplification; you might need a more detailed check based on your specific requirements + if args.accelerator == "cpu": + return False return torch.cuda.get_device_capability(0)[0] >= 7 # Set the float32 matrix multiplication precision based on GPU support for Tensor Cores