From c189a72d6a8b463e2bcfb66f598ac861a6e68716 Mon Sep 17 00:00:00 2001
From: Jeff Lomax
Date: Sun, 19 Aug 2018 16:41:15 -0500
Subject: [PATCH] add freeze_safe support, python lint issues,
 checkpoint_every_n option

---
 super_resolution/main.py | 94 ++++++++++++++++++++++------------------
 1 file changed, 51 insertions(+), 43 deletions(-)

diff --git a/super_resolution/main.py b/super_resolution/main.py
index f9f5f8b190..2b670d04d6 100644
--- a/super_resolution/main.py
+++ b/super_resolution/main.py
@@ -9,41 +9,8 @@
 from model import Net
 from data import get_training_set, get_test_set
 
-# Training settings
-parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
-parser.add_argument('--upscale_factor', type=int, required=True, help="super resolution upscale factor")
-parser.add_argument('--batchSize', type=int, default=64, help='training batch size')
-parser.add_argument('--testBatchSize', type=int, default=10, help='testing batch size')
-parser.add_argument('--nEpochs', type=int, default=2, help='number of epochs to train for')
-parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01')
-parser.add_argument('--cuda', action='store_true', help='use cuda?')
-parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
-parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
-opt = parser.parse_args()
-print(opt)
-
-if opt.cuda and not torch.cuda.is_available():
-    raise Exception("No GPU found, please run without --cuda")
-
-torch.manual_seed(opt.seed)
-
-device = torch.device("cuda" if opt.cuda else "cpu")
-
-print('===> Loading datasets')
-train_set = get_training_set(opt.upscale_factor)
-test_set = get_test_set(opt.upscale_factor)
-training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
-testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
-
-print('===> Building model')
-model = Net(upscale_factor=opt.upscale_factor).to(device)
-criterion = nn.MSELoss()
-
-optimizer = optim.Adam(model.parameters(), lr=opt.lr)
-
-
-def train(epoch):
+def train(epoch, device, model, optimizer, criterion, training_data_loader):
     epoch_loss = 0
     for iteration, batch in enumerate(training_data_loader, 1):
         input, target = batch[0].to(device), batch[1].to(device)
 
@@ -59,7 +26,7 @@ def train(epoch):
     print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader)))
 
 
-def test():
+def test(device, model, criterion, testing_data_loader):
     avg_psnr = 0
     with torch.no_grad():
         for batch in testing_data_loader:
@@ -72,12 +39,53 @@ def test():
     print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / len(testing_data_loader)))
 
 
-def checkpoint(epoch):
-    model_out_path = "model_epoch_{}.pth".format(epoch)
-    torch.save(model, model_out_path)
-    print("Checkpoint saved to {}".format(model_out_path))
+def checkpoint(epoch, model, every_n):
+    if epoch % every_n == 0:
+        model_out_path = "model_epoch_{}.pth".format(epoch)
+        torch.save(model, model_out_path)
+        print("Checkpoint saved to {}".format(model_out_path))
+
+
+def freeze_safe_main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
+    parser.add_argument('--upscale_factor', type=int, required=True, help="super resolution upscale factor")
+    parser.add_argument('--batchSize', type=int, default=64, help='training batch size')
+    parser.add_argument('--testBatchSize', type=int, default=10, help='testing batch size')
+    parser.add_argument('--nEpochs', type=int, default=2, help='number of epochs to train for')
+    parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01')
+    parser.add_argument('--cuda', action='store_true', help='use cuda?')
+    parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
+    parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
+    parser.add_argument('--checkpoint_every_n', type=int, default=1, help='write checkpoint every N epochs')
+    opt = parser.parse_args()
+
+    print(opt)
+
+    if opt.cuda and not torch.cuda.is_available():
+        raise Exception("No GPU found, please run without --cuda")
+
+    torch.manual_seed(opt.seed)
+
+    device = torch.device("cuda" if opt.cuda else "cpu")
+
+    print('===> Loading datasets')
+    train_set = get_training_set(opt.upscale_factor)
+    test_set = get_test_set(opt.upscale_factor)
+    training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
+    testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)
+
+    print('===> Building model')
+    model = Net(upscale_factor=opt.upscale_factor).to(device)
+    criterion = nn.MSELoss()
+
+    optimizer = optim.Adam(model.parameters(), lr=opt.lr)
+
+    for epoch in range(1, opt.nEpochs + 1):
+        train(epoch, device, model, optimizer, criterion, training_data_loader)
+        test(device, model, criterion, testing_data_loader)
+        checkpoint(epoch, model, opt.checkpoint_every_n)
 
-for epoch in range(1, opt.nEpochs + 1):
-    train(epoch)
-    test()
-    checkpoint(epoch)
+# standard freeze_support-style __main__ guard so multiprocessing works everywhere
+if __name__ == '__main__':
+    freeze_safe_main()