diff --git a/lib/utils/detectron_weight_helper.py b/lib/utils/detectron_weight_helper.py index 0953ffb8..ee70737e 100644 --- a/lib/utils/detectron_weight_helper.py +++ b/lib/utils/detectron_weight_helper.py @@ -1,10 +1,16 @@ """Helper functions for loading pretrained weights from Detectron pickle files """ -import pickle import re +import pickle +import logging + import torch +from utils.logging import setup_logging + +logger = setup_logging(__name__) +logger.setLevel(logging.WARN) def load_detectron_weight(net, detectron_weight_file): name_mapping, orphan_in_detectron = net.detectron_weight_mapping @@ -18,7 +24,11 @@ def load_detectron_weight(net, detectron_weight_file): for p_name, p_tensor in params.items(): d_name = name_mapping[p_name] if isinstance(d_name, str): # maybe str, None or True - p_tensor.copy_(torch.Tensor(src_blobs[d_name])) + d_param = src_blobs.get(d_name) + if d_param is not None: + p_tensor.copy_(torch.Tensor(d_param)) + else: + logger.warn('{} not found'.format(d_name)) def resnet_weights_name_pattern():