Skip to content

Commit 8bae509

Browse files
committed
fix load_state_dict() error when not using multi-stage loss
1 parent 662cd86 commit 8bae509

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

src/models/semantic.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -1005,29 +1005,27 @@ def load_state_dict(
10051005
`load_state_dict` to crash. More precisely, `criterion.weight`
10061006
is holding the per-class weights for classification losses.
10071007
"""
1008-
# Special treatment for MultiLoss
1009-
if self.multi_stage_loss:
1010-
class_weight_bckp = self.criterion.weight
1011-
self.criterion.weight = None
1012-
1013-
# Recover the class weights from any 'criterion.weight' or
1014-
# 'criterion.*.weight' key and remove those keys from the
1015-
# state_dict
1016-
keys = []
1017-
for key in state_dict.keys():
1018-
if key.startswith('criterion.') and key.endswith('.weight'):
1019-
keys.append(key)
1020-
class_weight = state_dict[keys[0]] if len(keys) > 0 else None
1021-
for key in keys:
1022-
state_dict.pop(key)
1008+
# Special treatment `criterion.weight`
1009+
class_weight_bckp = self.criterion.weight
1010+
self.criterion.weight = None
1011+
1012+
# Recover the class weights from any `criterion.weight' or
1013+
# 'criterion.*.weight' key and remove those keys from the
1014+
# state_dict
1015+
keys = []
1016+
for key in state_dict.keys():
1017+
if key.startswith('criterion.') and key.endswith('.weight'):
1018+
keys.append(key)
1019+
class_weight = state_dict[keys[0]] if len(keys) > 0 else None
1020+
for key in keys:
1021+
state_dict.pop(key)
10231022

10241023
# Load the state_dict
10251024
super().load_state_dict(state_dict, strict=strict)
10261025

10271026
# If need be, assign the class weights to the criterion
1028-
if self.multi_stage_loss:
1029-
self.criterion.weight = class_weight if class_weight is not None \
1030-
else class_weight_bckp
1027+
self.criterion.weight = class_weight if class_weight is not None \
1028+
else class_weight_bckp
10311029

10321030
def _load_from_checkpoint(
10331031
self,

0 commit comments

Comments
 (0)