Add finetuning #58

Closed · wants to merge 4 commits

Changes from all commits
75 changes: 69 additions & 6 deletions imagenet/main.py
@@ -14,6 +14,7 @@
import torchvision.models as models



model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
@@ -52,15 +53,77 @@

best_prec1 = 0

class FineTuneModel(nn.Module):
    """Finetunes just the last layer.

    This freezes the weights of all layers except the last one.
    You should also look into finetuning earlier layers, but do so gradually.
    Ideally, finetune the last layer first, then unfreeze all layers and tune
    further with a small learning rate.

    Arguments:
        original_model: Model to finetune
        arch: Name of model architecture
        num_classes: Number of classes to tune for
    """
    def __init__(self, original_model, arch, num_classes):
        super(FineTuneModel, self).__init__()

        if arch.startswith('alexnet') or arch.startswith('vgg'):
            self.features = original_model.features
            # Keep every classifier layer except the final linear one
            self.fc = nn.Sequential(*list(original_model.classifier.children())[:-1])
            self.classifier = nn.Sequential(
                nn.Linear(4096, num_classes)
            )
        elif arch.startswith('resnet'):
            # Everything except the last linear layer
            self.features = nn.Sequential(*list(original_model.children())[:-1])
            self.classifier = nn.Sequential(
                # 512 is the feature width of resnet18/34; deeper resnets output 2048
                nn.Linear(512, num_classes)
            )
        elif arch.startswith('inception'):
            # Everything except the last linear layer
            self.features = nn.Sequential(*list(original_model.children())[:-1])
            self.classifier = nn.Sequential(
                nn.Linear(2048, num_classes)
            )
        else:
            raise NotImplementedError(
                "Finetuning not supported on this architecture yet. Feel free to add")

        self.unfreeze(False)  # Freeze weights except the last layer

    def unfreeze(self, unfreeze):
        # Freeze (or unfreeze) the feature-extractor weights
        for p in self.features.parameters():
            p.requires_grad = unfreeze
        if hasattr(self, 'fc'):
            for p in self.fc.parameters():
                p.requires_grad = unfreeze

    def forward(self, x):
        f = self.features(x)
        if hasattr(self, 'fc'):
            f = f.view(f.size(0), -1)
            f = self.fc(f)
        f = f.view(f.size(0), -1)
        y = self.classifier(f)
        return y


def main():
    global args, best_prec1
    args = parser.parse_args()

    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # Get the number of classes from the train directory (one subfolder per class)
    num_classes = len([name for name in os.listdir(traindir)])

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        original_model = models.__dict__[args.arch](pretrained=True)
        model = FineTuneModel(original_model, args.arch, num_classes)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
@@ -74,11 +137,13 @@ def main():
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),  # Only finetunable params
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
@@ -95,8 +160,6 @@ def main():
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

@@ -106,7 +169,7 @@
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

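For reference, a minimal sketch of the two-stage workflow suggested by the FineTuneModel docstring: train the new classifier head while the features stay frozen, then unfreeze everything and continue with a smaller learning rate. It assumes the FineTuneModel class from this diff is in scope; the resnet18 architecture, num_classes=10, and the concrete learning rates are illustrative assumptions, not values taken from this PR.

import torch
import torchvision.models as models

# Stage 1: only the new classifier head is trainable (features were frozen in __init__)
original_model = models.resnet18(pretrained=True)
model = FineTuneModel(original_model, 'resnet18', num_classes=10)  # num_classes is illustrative
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=0.01, momentum=0.9)
# ... train the head for a few epochs ...

# Stage 2: unfreeze the feature extractor and finetune end-to-end with a small lr
model.unfreeze(True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# ... continue training ...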