@@ -1005,29 +1005,27 @@ def load_state_dict(
1005
1005
`load_state_dict` to crash. More precisely, `criterion.weight`
1006
1006
is holding the per-class weights for classification losses.
1007
1007
"""
1008
- # Special treatment for MultiLoss
1009
- if self .multi_stage_loss :
1010
- class_weight_bckp = self .criterion .weight
1011
- self .criterion .weight = None
1012
-
1013
- # Recover the class weights from any 'criterion.weight' or
1014
- # 'criterion.*.weight' key and remove those keys from the
1015
- # state_dict
1016
- keys = []
1017
- for key in state_dict .keys ():
1018
- if key .startswith ('criterion.' ) and key .endswith ('.weight' ):
1019
- keys .append (key )
1020
- class_weight = state_dict [keys [0 ]] if len (keys ) > 0 else None
1021
- for key in keys :
1022
- state_dict .pop (key )
1008
+ # Special treatment `criterion.weight`
1009
+ class_weight_bckp = self .criterion .weight
1010
+ self .criterion .weight = None
1011
+
1012
+ # Recover the class weights from any `criterion.weight' or
1013
+ # 'criterion.*.weight' key and remove those keys from the
1014
+ # state_dict
1015
+ keys = []
1016
+ for key in state_dict .keys ():
1017
+ if key .startswith ('criterion.' ) and key .endswith ('.weight' ):
1018
+ keys .append (key )
1019
+ class_weight = state_dict [keys [0 ]] if len (keys ) > 0 else None
1020
+ for key in keys :
1021
+ state_dict .pop (key )
1023
1022
1024
1023
# Load the state_dict
1025
1024
super ().load_state_dict (state_dict , strict = strict )
1026
1025
1027
1026
# If need be, assign the class weights to the criterion
1028
- if self .multi_stage_loss :
1029
- self .criterion .weight = class_weight if class_weight is not None \
1030
- else class_weight_bckp
1027
+ self .criterion .weight = class_weight if class_weight is not None \
1028
+ else class_weight_bckp
1031
1029
1032
1030
def _load_from_checkpoint (
1033
1031
self ,
0 commit comments