@@ -82,39 +82,35 @@ def main():
                         help='learning rate (default: 1.0)')
     parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                         help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
+    parser.add_argument('--no-accel', action='store_true',
+                        help='disables accelerator')
+    parser.add_argument('--dry-run', action='store_true',
                         help='quickly check a single pass')
     parser.add_argument('--seed', type=int, default=1, metavar='S',
                         help='random seed (default: 1)')
     parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                         help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
+    parser.add_argument('--save-model', action='store_true',
                         help='For Saving the current Model')
     args = parser.parse_args()
-    use_cuda = not args.no_cuda and torch.cuda.is_available()
-    use_mps = not args.no_mps and torch.backends.mps.is_available()
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
 
     torch.manual_seed(args.seed)
 
-    if use_cuda:
-        device = torch.device("cuda")
-    elif use_mps:
-        device = torch.device("mps")
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
     else:
         device = torch.device("cpu")
 
     train_kwargs = {'batch_size': args.batch_size}
     test_kwargs = {'batch_size': args.test_batch_size}
-    if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
+    if use_accel:
+        accel_kwargs = {'num_workers': 1,
                        'pin_memory': True,
                        'shuffle': True}
-        train_kwargs.update(cuda_kwargs)
-        test_kwargs.update(cuda_kwargs)
+        train_kwargs.update(accel_kwargs)
+        test_kwargs.update(accel_kwargs)
 
     transform=transforms.Compose([
         transforms.ToTensor(),
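For reference, a minimal standalone sketch of the device-selection pattern this change adopts, assuming a PyTorch build that ships the `torch.accelerator` API used in the diff. The `--no-accel` flag mirrors the diff; `pick_device` and the printout are hypothetical illustration, not part of the example script.

```python
import argparse

import torch


def pick_device(no_accel: bool) -> torch.device:
    """Return the current accelerator if one is available, else the CPU."""
    if not no_accel and torch.accelerator.is_available():
        # One backend-agnostic call instead of separate CUDA/MPS checks.
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-accel', action='store_true',
                        help='disables accelerator')
    args = parser.parse_args()
    device = pick_device(args.no_accel)
    print(f"Training would run on: {device}")
```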