diff --git a/imagenet/main.py b/imagenet/main.py
index 2a1540ba13..1888081f05 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -14,6 +14,7 @@ import torchvision.models as models
+
 model_names = sorted(name for name in models.__dict__
     if name.islower() and not name.startswith("__")
     and callable(models.__dict__[name]))
@@ -52,15 +53,77 @@ best_prec1 = 0
+class FineTuneModel(nn.Module):
+    """Finetunes just the last layer.
+
+    This freezes the weights of all layers except the last one.
+    Consider also finetuning earlier layers, but with a small learning rate:
+    ideally train the new classifier first, then unfreeze all layers and tune further with a small lr.
+
+    Arguments:
+        original_model: Model to finetune
+        arch: Name of model architecture
+        num_classes: Number of classes to tune for
+
+    """
+    def __init__(self, original_model, arch, num_classes):
+        super(FineTuneModel, self).__init__()
+
+        if arch.startswith('alexnet') or arch.startswith('vgg'):
+            self.features = original_model.features
+            self.fc = nn.Sequential(*list(original_model.classifier.children())[:-1])
+            self.classifier = nn.Sequential(
+                nn.Linear(4096, num_classes)
+            )
+        elif arch.startswith('resnet'):
+            # Everything except the last linear layer (512 features for resnet18/34)
+            self.features = nn.Sequential(*list(original_model.children())[:-1])
+            self.classifier = nn.Sequential(
+                nn.Linear(512, num_classes)
+            )
+        elif arch.startswith('inception'):
+            # Everything except the last linear layer (2048 features)
+            self.features = nn.Sequential(*list(original_model.children())[:-1])
+            self.classifier = nn.Sequential(
+                nn.Linear(2048, num_classes)
+            )
+        else:
+            raise NotImplementedError("Finetuning not supported on this architecture yet. Feel free to add")
+
+        self.unfreeze(False)  # Freeze weights except last layer
+
+    def unfreeze(self, unfreeze):
+        # Freeze or unfreeze the pretrained weights
+        for p in self.features.parameters():
+            p.requires_grad = unfreeze
+        if hasattr(self, 'fc'):
+            for p in self.fc.parameters():
+                p.requires_grad = unfreeze
+
+    def forward(self, x):
+        f = self.features(x)
+        if hasattr(self, 'fc'):
+            f = f.view(f.size(0), -1)
+            f = self.fc(f)
+        f = f.view(f.size(0), -1)
+        y = self.classifier(f)
+        return y
+
 def main():
     global args, best_prec1
     args = parser.parse_args()
 
+    traindir = os.path.join(args.data, 'train')
+    valdir = os.path.join(args.data, 'val')
+    # Get number of classes from the train directory
+    num_classes = len(os.listdir(traindir))
+
     # create model
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
-        model = models.__dict__[args.arch](pretrained=True)
+        original_model = models.__dict__[args.arch](pretrained=True)
+        model = FineTuneModel(original_model, args.arch, num_classes)
     else:
         print("=> creating model '{}'".format(args.arch))
         model = models.__dict__[args.arch]()
@@ -74,11 +137,13 @@ def main():
 
     # define loss function (criterion) and optimizer
     criterion = nn.CrossEntropyLoss().cuda()
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),  # Only finetunable params
+                                args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
 
-    # optionally resume from a checkpoint
+
+    # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
@@ -95,8 +160,6 @@ def main():
     cudnn.benchmark = True
 
     # Data loading code
-    traindir = os.path.join(args.data, 'train')
-    valdir = os.path.join(args.data, 'val')
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
@@ -106,7 +169,7 @@ def main():
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             normalize,
-        ])),
+        ])),
         batch_size=args.batch_size, shuffle=True,
         num_workers=args.workers, pin_memory=True)
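For reference, a minimal usage sketch of the two-phase workflow described in the FineTuneModel docstring above: freeze everything except the new classifier, train it, then unfreeze and continue with a smaller learning rate. This sketch is not part of the patch; the architecture (resnet18), class count, and learning rates are illustrative assumptions.

# Usage sketch (illustrative, not part of the patch)
import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 10  # assumed number of target classes
original_model = models.resnet18(pretrained=True)
model = FineTuneModel(original_model, 'resnet18', num_classes)
criterion = nn.CrossEntropyLoss()

# Phase 1: only the new classifier has requires_grad=True, so only it is updated
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=0.01, momentum=0.9)
# ... train the classifier for a few epochs ...

# Phase 2: unfreeze the pretrained layers and keep training with a smaller lr
model.unfreeze(True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# ... continue training the whole network ...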