diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index c1521c8b78..dba8847ca0 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -148,6 +148,38 @@ jobs: -word_vec_size 16 -report_every 5 \ -rnn_size 16 -train_steps 10 \ -copy_attn + - name: Test LM training with label smoothing + run: | + python train.py \ + -config data/lm_data.yaml \ + -src_vocab /tmp/onmt.vocab.src \ + -tgt_vocab /tmp/onmt.vocab.src \ + -model_task lm \ + -encoder_type transformer_lm \ + -decoder_type transformer_lm \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -label_smoothing 0.1 \ + -dec_layers 2 -batch_size 10 \ + -heads 4 -transformer_ff 64 \ + -word_vec_size 16 -report_every 5 \ + -rnn_size 16 -train_steps 10 + - name: Test LM training with unlikelihood loss + run: | + python train.py \ + -config data/lm_data.yaml \ + -src_vocab /tmp/onmt.vocab.src \ + -tgt_vocab /tmp/onmt.vocab.src \ + -model_task lm \ + -encoder_type transformer_lm \ + -decoder_type transformer_lm \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -unlikelihood_coeff 1.0 \ + -dec_layers 2 -batch_size 10 \ + -heads 4 -transformer_ff 64 \ + -word_vec_size 16 -report_every 5 \ + -rnn_size 16 -train_steps 10 - name: Test Graph neural network training run: | python train.py \ diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py index 76d8201086..1a54749673 100644 --- a/onmt/modules/copy_generator.py +++ b/onmt/modules/copy_generator.py @@ -2,7 +2,7 @@ import torch.nn as nn from onmt.utils.misc import aeq -from onmt.utils.loss import CommonLossCompute +from onmt.utils.loss import LossComputeBase def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None, @@ -177,7 +177,7 @@ def forward(self, scores, align, target): return loss -class CommonCopyGeneratorLossCompute(CommonLossCompute): +class CommonCopyGeneratorLossCompute(LossComputeBase): """Common Copy Generator Loss Computation.""" def __init__(self, criterion, generator, tgt_vocab, normalize_by_length, lambda_coverage=0.0, tgt_shift_index=1): @@ -231,7 +231,8 @@ def _compute_loss(self, batch, output, target, copy_attn, align, target_data[correct_mask] += offset_align # Compute sum of perplexities for stats - stats = self._stats(loss.sum().clone(), scores_data, target_data) + stats = self._stats(loss.sum().clone(), loss.sum().clone(), + scores_data, target_data) # this part looks like it belongs in CopyGeneratorLoss if self.normalize_by_length: diff --git a/onmt/modules/sparse_losses.py b/onmt/modules/sparse_losses.py index 08a24e1d2e..b794021d2c 100644 --- a/onmt/modules/sparse_losses.py +++ b/onmt/modules/sparse_losses.py @@ -77,3 +77,9 @@ def forward(self, input, target): elif self.reduction == 'elementwise_mean': loss = loss.sum() / size return loss + + +class ExpandedSparsemaxLoss(SparsemaxLoss): + def forward(self, input, target): + gtruth = target.view(-1) + return super(ExpandedSparsemaxLoss, self).forward(input, gtruth) diff --git a/onmt/opts.py b/onmt/opts.py index 6597fbeb00..caa1d4aa13 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -574,13 +574,23 @@ def _add_train_general_opts(parser): 'suggested a value of 0.98 for beta2, this parameter may ' 'not work well for normal models / default ' 'baselines.') - group.add('--label_smoothing', '-label_smoothing', type=float, default=0.0, - help="Label smoothing value epsilon. " - "Probabilities of all non-true labels " - "will be smoothed by epsilon / (vocab_size - 1). " - "Set to zero to turn off label smoothing. " - "For more detailed information, see: " - "https://arxiv.org/abs/1512.00567") + subgroup = group.add_mutually_exclusive_group() + subgroup.add('--label_smoothing', '-label_smoothing', type=float, + default=0.0, + help="Label smoothing value epsilon. " + "Probabilities of all non-true labels " + "will be smoothed by epsilon / (vocab_size - 1). " + "Set to zero to turn off label smoothing. " + "For more detailed information, see: " + "https://arxiv.org/abs/1512.00567") + subgroup.add('--unlikelihood_coeff', '-unlikelihood_coeff', type=float, + default=0.0, + help="Loss coefficient for token unlikelihood loss. " + "Usually set to 1. max_generator_batches option will " + "limit the neighbourhood size of the unlikelihood loss." + " For more detailed information, see: " + "https://arxiv.org/abs/1908.04319 and " + "https://openreview.net/forum?id=SJeYe0NtvH") group.add('--average_decay', '-average_decay', type=float, default=0, help="Moving average decay. " "Set to other than 0 (e.g. 1e-4) to activate. " diff --git a/onmt/tests/pull_request_chk.sh b/onmt/tests/pull_request_chk.sh index 18b86dfe29..d588020f7c 100755 --- a/onmt/tests/pull_request_chk.sh +++ b/onmt/tests/pull_request_chk.sh @@ -179,6 +179,42 @@ ${PYTHON} onmt/bin/train.py \ -rnn_size 16 -train_steps 10 \ -copy_attn >> ${LOG_FILE} 2>&1 [ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + +echo -n " [+] Testing LM training with label smoothing..." +${PYTHON} onmt/bin/train.py \ + -config ${DATA_DIR}/lm_data.yaml \ + -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt.vocab.src \ + -model_task lm \ + -encoder_type transformer_lm \ + -decoder_type transformer_lm \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -label_smoothing 0.1 \ + -dec_layers 2 -batch_size 10 \ + -heads 4 -transformer_ff 64 \ + -word_vec_size 16 -report_every 5 \ + -rnn_size 16 -train_steps 10 >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit +echo "Succeeded" | tee -a ${LOG_FILE} + +echo -n " [+] Testing LM training with unlikelihood loss..." +${PYTHON} onmt/bin/train.py \ + -config ${DATA_DIR}/lm_data.yaml \ + -src_vocab $TMP_OUT_DIR/onmt.vocab.src \ + -tgt_vocab $TMP_OUT_DIR/onmt.vocab.src \ + -model_task lm \ + -encoder_type transformer_lm \ + -decoder_type transformer_lm \ + -src_vocab_size 1000 \ + -tgt_vocab_size 1000 \ + -unlikelihood_coeff 1 \ + -dec_layers 2 -batch_size 10 \ + -heads 4 -transformer_ff 64 \ + -word_vec_size 16 -report_every 5 \ + -rnn_size 16 -train_steps 10 >> ${LOG_FILE} 2>&1 +[ "$?" -eq 0 ] || error_exit echo "Succeeded" | tee -a ${LOG_FILE}* rm $TMP_OUT_DIR/onmt.vocab* diff --git a/onmt/tests/test_unlikelihood_loss_criterion.py b/onmt/tests/test_unlikelihood_loss_criterion.py new file mode 100644 index 0000000000..20d43c6714 --- /dev/null +++ b/onmt/tests/test_unlikelihood_loss_criterion.py @@ -0,0 +1,78 @@ +import unittest +from onmt.utils.loss import UnlikelihoodTokenLoss +import torch +import math + + +class TestUnlikelihoodLossCriterion(unittest.TestCase): + def test_compute_previous_context_tokens(self): + criterion = UnlikelihoodTokenLoss(1, 7) + target = torch.tensor([[2, 3, 4, 3, 5], [1, 1, 5, 6, 7]]).permute(1, 0) + previous_context_tokens = criterion.compute_previous_context_tokens( + target + ) + + self.assertEqual( + previous_context_tokens.permute(1, 0, 2).tolist(), + torch.tensor( + [ + [ + [7, 7, 7, 7, 7], + [2, 7, 7, 7, 7], + [2, 3, 7, 7, 7], + [2, 7, 4, 7, 7], + [2, 3, 4, 3, 7], + ], + [ + [7, 7, 7, 7, 7], + [7, 7, 7, 7, 7], + [1, 1, 7, 7, 7], + [1, 1, 5, 7, 7], + [7, 7, 7, 7, 7], + ], + ] + ).tolist(), + ) + + def test_loss_perfect_pred_should_be_zero(self): + criterion = UnlikelihoodTokenLoss(1, 7) + n_prob = -10e6 + target = torch.tensor([[2, 3, 4, 3, 5], [1, 1, 5, 6, 7]]).permute(1, 0) + perfect_probs = [ + [[n_prob if i != t else 1 for i in range(8)] for t in ex_target] + for ex_target in target + ] + + # check padded seq is removed + perfect_probs[-1][-1][-1] = n_prob + perfect_probs[-1][-1][1] = 0.1 + + output = torch.tensor(perfect_probs).view(-1, 8) + + unlikelihood_loss = criterion.compute_unlikelihood_loss(output, target) + + self.assertEqual(unlikelihood_loss.sum().item(), 0) + + def test_loss_value(self): + criterion = UnlikelihoodTokenLoss(1, 7) + n_prob = -10e6 + target = torch.tensor([[2, 3, 4, 3, 5], [1, 1, 5, 6, 7]]).permute(1, 0) + perfect_probs = [ + [[n_prob if i != t else 1 for i in range(8)] for t in ex_target] + for ex_target in target + ] + + # check padded seq is removed + perfect_probs[-1][-1][-1] = n_prob + perfect_probs[-1][-1][1] = 0.1 + + # set prob at 0.5 on 1 after softmax + perfect_probs[2][-1][1] = 1 + + output = torch.tensor(perfect_probs).view(-1, 8) + + unlikelihood_loss = criterion.compute_unlikelihood_loss(output, target) + + self.assertAlmostEqual( + unlikelihood_loss.view(5, 2, 8)[2, -1, 1].item(), -math.log(0.5) + ) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index c772b3369b..92d0ade422 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -8,6 +8,7 @@ import onmt from onmt.modules.sparse_losses import SparsemaxLoss +from onmt.modules.sparse_losses import ExpandedSparsemaxLoss from onmt.modules.sparse_activations import LogSparsemax from onmt.constants import ModelTask @@ -35,14 +36,21 @@ def build_loss_compute(model, tgt_field, opt, train=True): len(tgt_field.vocab), opt.copy_attn_force, unk_index=unk_idx, ignore_index=padding_idx ) - elif opt.label_smoothing > 0 and train: + elif opt.label_smoothing > 0: criterion = LabelSmoothingLoss( - opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx + opt.label_smoothing, len(tgt_field.vocab), + ignore_index=padding_idx + ) + elif opt.unlikelihood_coeff > 0: + criterion = UnlikelihoodTokenLoss( + opt.unlikelihood_coeff, ignore_index=padding_idx ) elif isinstance(model.generator[-1], LogSparsemax): - criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') + criterion = ExpandedSparsemaxLoss( + ignore_index=padding_idx, reduction="sum" + ) else: - criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') + criterion = ExpandedNLLLoss(ignore_index=padding_idx, reduction="sum") # if the loss function operates on vectors of raw logits instead of # probabilities, only the first part of the generator needs to be @@ -110,13 +118,18 @@ class LossComputeBase(nn.Module): distribution over the target vocabulary. tgt_vocab (:obj:`Vocab`) : torchtext vocab object representing the target output - normalzation (str): normalize by "sents" or "tokens" + normalization (str): normalize by "sents" or "tokens" """ - def __init__(self, criterion, generator): + def __init__(self, criterion, generator, lambda_coverage=0.0, + lambda_align=0.0, tgt_shift_index=1): super(LossComputeBase, self).__init__() self.criterion = criterion self.generator = generator + self.lambda_coverage = lambda_coverage + self.lambda_align = lambda_align + self.tgt_shift_index = tgt_shift_index + self.ppl_criterion = None @property def padding_idx(self): @@ -134,9 +147,73 @@ def _make_shard_state(self, batch, output, range_, attns=None): batch or a trunc of it? attns: the attns dictionary returned from the model. """ - return NotImplementedError + range_start = range_[0] + self.tgt_shift_index + range_end = range_[1] + shard_state = { + "output": output, + "target": batch.tgt[range_start:range_end, :, 0], + } + if self.lambda_coverage != 0.0: + self._add_coverage_shard_state(shard_state, attns) + if self.lambda_align != 0.0: + self._add_align_shard_state( + shard_state, batch, range_start, range_end, attns + ) + return shard_state - def _compute_loss(self, batch, output, target, **kwargs): + def _add_coverage_shard_state(self, shard_state, attns): + coverage = attns.get("coverage", None) + std = attns.get("std", None) + assert attns is not None + assert coverage is not None, ( + "lambda_coverage != 0.0 requires coverage attention" + " that could not be found in the model." + " Transformer decoders do not implement coverage" + ) + assert std is not None, ( + "lambda_coverage != 0.0 requires attention mechanism" + " that could not be found in the model." + ) + shard_state.update({"std_attn": attns.get("std"), + "coverage_attn": coverage}) + + def _add_align_shard_state(self, shard_state, batch, range_start, + range_end, attns): + # attn_align should be in (batch_size, pad_tgt_size, pad_src_size) + attn_align = attns.get("align", None) + # align_idx should be a Tensor in size([N, 3]), N is total number + # of align src-tgt pair in current batch, each as + # ['sent_N°_in_batch', 'tgt_id+1', 'src_id'] (check AlignField) + align_idx = batch.align + assert attns is not None + assert attn_align is not None, ( + "lambda_align != 0.0 requires " "alignement attention head" + ) + assert align_idx is not None, ( + "lambda_align != 0.0 requires " "provide guided alignement" + ) + pad_tgt_size, batch_size, _ = batch.tgt.size() + pad_src_size = batch.src[0].size(0) + align_matrix_size = [batch_size, pad_tgt_size, pad_src_size] + ref_align = onmt.utils.make_batch_align_matrix( + align_idx, align_matrix_size, normalize=True + ) + # NOTE: tgt-src ref alignement that in range_ of shard + # (coherent with batch.tgt) + shard_state.update( + { + "align_head": attn_align, + "ref_align": ref_align[:, range_start:range_end, :], + } + ) + + def _compute_coverage_loss(self, std_attn, coverage_attn): + covloss = torch.min(std_attn, coverage_attn).sum() + covloss *= self.lambda_coverage + return covloss + + def _compute_loss(self, batch, output, target, std_attn=None, + coverage_attn=None, align_head=None, ref_align=None): """ Compute the loss. Subclass must define this method. @@ -147,7 +224,44 @@ def _compute_loss(self, batch, output, target, **kwargs): target: the validate target to compare output with. **kwargs(optional): additional info for computing loss. """ - return NotImplementedError + bottled_output = self._bottle(output) + + scores = self.generator(bottled_output) + gtruth = target.view(-1) + + loss = self.criterion(scores, target) + if self.lambda_coverage != 0.0: + coverage_loss = self._compute_coverage_loss( + std_attn=std_attn, coverage_attn=coverage_attn) + loss += coverage_loss + if self.lambda_align != 0.0: + if align_head.dtype != loss.dtype: # Fix FP16 + align_head = align_head.to(loss.dtype) + if ref_align.dtype != loss.dtype: + ref_align = ref_align.to(loss.dtype) + align_loss = self._compute_alignement_loss( + align_head=align_head, ref_align=ref_align) + loss += align_loss + + log_ppl = self._compute_log_ppl(scores, gtruth) + stats = self._stats(loss.clone(), log_ppl, scores, gtruth) + + return loss, stats + + def _compute_log_ppl(self, scores, gtruth): + with torch.no_grad(): + log_ppl = self.ppl_criterion(scores, gtruth) + return log_ppl + + def _compute_alignement_loss(self, align_head, ref_align): + """Compute loss between 2 partial alignment matrix.""" + # align_head contains value in [0, 1) presenting attn prob, + # 0 was resulted by the context attention src_pad_mask + # So, the correspand position in ref_align should also be 0 + # Therefore, clip align_head to > 1e-18 should be bias free. + align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum() + align_loss *= self.lambda_align + return align_loss def __call__(self, batch, @@ -198,7 +312,7 @@ def __call__(self, batch_stats.update(stats) return None, batch_stats - def _stats(self, loss, scores, target): + def _stats(self, loss, log_ppl, scores, target): """ Args: loss (:obj:`FloatTensor`): the loss computed by the loss criterion. @@ -212,7 +326,8 @@ def _stats(self, loss, scores, target): non_padding = target.ne(self.padding_idx) num_correct = pred.eq(target).masked_select(non_padding).sum().item() num_non_padding = non_padding.sum().item() - return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct) + return onmt.utils.Statistics(loss.item(), num_non_padding, + num_correct, log_ppl.item()) def _bottle(self, _v): return _v.view(-1, _v.size(2)) @@ -241,155 +356,141 @@ def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100): def forward(self, output, target): """ - output (FloatTensor): batch_size x n_classes - target (LongTensor): batch_size + output (FloatTensor): (batch_size x batch_length) x n_classes + target (LongTensor): batch_size x batch_length """ - model_prob = self.one_hot.repeat(target.size(0), 1) - model_prob.scatter_(1, target.unsqueeze(1), self.confidence) - model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) + gtruth = target.view(-1) + model_prob = self.one_hot.repeat(gtruth.size(0), 1) + model_prob.scatter_(1, gtruth.unsqueeze(1), self.confidence) + model_prob.masked_fill_((gtruth == self.ignore_index).unsqueeze(1), 0) return F.kl_div(output, model_prob, reduction='sum') -class CommonLossCompute(LossComputeBase): - """ - Loss Computation parent for NMTLossCompute and LMLossCompute +class ExpandedNLLLoss(nn.NLLLoss): + def forward(self, input, target): + gtruth = target.view(-1) + return super(ExpandedNLLLoss, self).forward(input, gtruth) + - Implement loss compatible with coverage and alignement shards +class UnlikelihoodTokenLoss(nn.Module): + """ + Unlikelihood token level loss """ - def __init__(self, criterion, generator, normalization="sents", - lambda_coverage=0.0, lambda_align=0.0, tgt_shift_index=1): - super(CommonLossCompute, self).__init__(criterion, generator) - self.lambda_coverage = lambda_coverage - self.lambda_align = lambda_align - self.tgt_shift_index = tgt_shift_index - def _add_coverage_shard_state(self, shard_state, attns): - coverage = attns.get("coverage", None) - std = attns.get("std", None) - assert attns is not None - assert coverage is not None, ( - "lambda_coverage != 0.0 requires coverage attention" - " that could not be found in the model." - " Transformer decoders do not implement coverage" + def __init__(self, unlikelihood_coeff, ignore_index=-100): + assert 0.0 < unlikelihood_coeff, ( + f"unlikelihood_coeff {unlikelihood_coeff} must be non negative." ) - assert std is not None, ( - "lambda_coverage != 0.0 requires attention mechanism" - " that could not be found in the model." + self.ignore_index = ignore_index + super(UnlikelihoodTokenLoss, self).__init__() + + self.likelihood_criterion = nn.NLLLoss( + ignore_index=ignore_index, reduction="none" ) - shard_state.update({"std_attn": attns.get("std"), - "coverage_attn": coverage}) + self.unlikelihood_coeff = unlikelihood_coeff - def _compute_loss(self, batch, output, target, std_attn=None, - coverage_attn=None, align_head=None, ref_align=None): + def compute_previous_context_tokens(self, target): + expanded_target = ( + target.unsqueeze(-1) + .expand(target.size(0), target.size(1), target.size(0)) + .permute(1, 2, 0) + ) - bottled_output = self._bottle(output) + previous_context_tokens = ( + expanded_target.tril(-1) + + torch.full_like(expanded_target, self.ignore_index).triu() + ) - scores = self.generator(bottled_output) - gtruth = target.view(-1) + # remove padded examples + previous_context_tokens = previous_context_tokens.masked_fill( + expanded_target.permute(0, 2, 1) == self.ignore_index, + self.ignore_index, + ) - loss = self.criterion(scores, gtruth) - if self.lambda_coverage != 0.0: - coverage_loss = self._compute_coverage_loss( - std_attn=std_attn, coverage_attn=coverage_attn) - loss += coverage_loss - if self.lambda_align != 0.0: - if align_head.dtype != loss.dtype: # Fix FP16 - align_head = align_head.to(loss.dtype) - if ref_align.dtype != loss.dtype: - ref_align = ref_align.to(loss.dtype) - align_loss = self._compute_alignement_loss( - align_head=align_head, ref_align=ref_align) - loss += align_loss - stats = self._stats(loss.clone(), scores, gtruth) + # remove current word for ctx + previous_context_tokens = previous_context_tokens.masked_fill( + previous_context_tokens == expanded_target.permute(0, 2, 1), + self.ignore_index, + ) - return loss, stats + return previous_context_tokens.permute(1, 0, 2) - def _compute_coverage_loss(self, std_attn, coverage_attn): - covloss = torch.min(std_attn, coverage_attn).sum() - covloss *= self.lambda_coverage - return covloss + def compute_unlikelihood_loss(self, output, target): + with torch.no_grad(): + previous_context_tokens = self.compute_previous_context_tokens( + target + ) - def _add_align_shard_state(self, shard_state, batch, range_start, - range_end, attns): - # attn_align should be in (batch_size, pad_tgt_size, pad_src_size) - attn_align = attns.get("align", None) - # align_idx should be a Tensor in size([N, 3]), N is total number - # of align src-tgt pair in current batch, each as - # ['sent_N°_in_batch', 'tgt_id+1', 'src_id'] (check AlignField) - align_idx = batch.align - assert attns is not None - assert attn_align is not None, ( - "lambda_align != 0.0 requires " "alignement attention head" - ) - assert align_idx is not None, ( - "lambda_align != 0.0 requires " "provide guided alignement" + probs = F.softmax(output, dim=-1) + probs = probs.view(-1, previous_context_tokens.size(1), probs.size(1)) + + log_probs = -torch.clamp((1.0 - probs), min=1e-5).log() + + one_hot_previous_context_tokens = torch.zeros_like(log_probs).scatter_( + -1, previous_context_tokens, 1 ) - pad_tgt_size, batch_size, _ = batch.tgt.size() - pad_src_size = batch.src[0].size(0) - align_matrix_size = [batch_size, pad_tgt_size, pad_src_size] - ref_align = onmt.utils.make_batch_align_matrix( - align_idx, align_matrix_size, normalize=True - ) - # NOTE: tgt-src ref alignement that in range_ of shard - # (coherent with batch.tgt) - shard_state.update( - { - "align_head": attn_align, - "ref_align": ref_align[:, range_start:range_end, :], - } + one_hot_previous_context_tokens[:, :, self.ignore_index] = 0 + unlikelihood_loss = log_probs * one_hot_previous_context_tokens + unlikelihood_loss = unlikelihood_loss.view( + -1, unlikelihood_loss.size(2) ) - def _compute_alignement_loss(self, align_head, ref_align): - """Compute loss between 2 partial alignment matrix.""" - # align_head contains value in [0, 1) presenting attn prob, - # 0 was resulted by the context attention src_pad_mask - # So, the correspand position in ref_align should also be 0 - # Therefore, clip align_head to > 1e-18 should be bias free. - align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum() - align_loss *= self.lambda_align - return align_loss + return unlikelihood_loss - def _make_shard_state(self, batch, output, range_, attns=None): - range_start = range_[0] + self.tgt_shift_index - range_end = range_[1] - shard_state = { - "output": output, - "target": batch.tgt[range_start:range_end, :, 0], - } - if self.lambda_coverage != 0.0: - self._add_coverage_shard_state(shard_state, attns) - if self.lambda_align != 0.0: - self._add_align_shard_state( - shard_state, batch, range_start, range_end, attns - ) - return shard_state + def forward(self, output, target): + """ + output (FloatTensor): (seq_length x batch_size) x n_classes + target (LongTensor): seq_length x batch_size + """ + gtruth = target.view(-1) + + loss = self.likelihood_criterion(output, gtruth) + + unlikelihood_loss = self.compute_unlikelihood_loss(output, target) + + loss += self.unlikelihood_coeff * unlikelihood_loss.sum(-1) + return loss.sum() -class NMTLossCompute(CommonLossCompute): +class NMTLossCompute(LossComputeBase): """ Standard NMT Loss Computation. """ - def __init__(self, criterion, generator, normalization="sents", - lambda_coverage=0.0, lambda_align=0.0): + def __init__(self, criterion, generator, lambda_coverage=0.0, + lambda_align=0.0): super(NMTLossCompute, self).__init__(criterion, generator, - normalization=normalization, lambda_coverage=lambda_coverage, lambda_align=lambda_align, tgt_shift_index=1) + if isinstance(self.generator[-1], LogSparsemax): + self.ppl_criterion = SparsemaxLoss( + ignore_index=self.criterion.ignore_index, reduction="sum" + ) + else: + self.ppl_criterion = nn.NLLLoss( + ignore_index=self.criterion.ignore_index, reduction="sum" + ) -class LMLossCompute(CommonLossCompute): +class LMLossCompute(LossComputeBase): """ Standard LM Loss Computation. """ - def __init__(self, criterion, generator, normalization="sents", - lambda_coverage=0.0, lambda_align=0.0): + def __init__(self, criterion, generator, lambda_coverage=0.0, + lambda_align=0.0): super(LMLossCompute, self).__init__(criterion, generator, - normalization=normalization, lambda_coverage=lambda_coverage, lambda_align=lambda_align, tgt_shift_index=0) + if isinstance(self.generator[-1], LogSparsemax): + self.ppl_criterion = SparsemaxLoss( + ignore_index=self.criterion.ignore_index, reduction="sum" + ) + else: + self.ppl_criterion = nn.NLLLoss( + ignore_index=self.criterion.ignore_index, reduction="sum" + ) def filter_shard_state(state, shard_size=None): diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index be4dd981ca..ff8947348c 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -1,7 +1,7 @@ """ Statistics calculation utility """ import time -import math import sys +import math from onmt.utils.logging import logger @@ -16,11 +16,12 @@ class Statistics(object): * elapsed time """ - def __init__(self, loss=0, n_words=0, n_correct=0): + def __init__(self, loss=0, n_words=0, n_correct=0, log_ppl=0): self.loss = loss self.n_words = n_words self.n_correct = n_correct self.n_src_words = 0 + self.log_ppl = log_ppl self.start_time = time.time() @staticmethod @@ -80,6 +81,7 @@ def update(self, stat, update_n_src_words=False): self.loss += stat.loss self.n_words += stat.n_words self.n_correct += stat.n_correct + self.log_ppl += stat.log_ppl if update_n_src_words: self.n_src_words += stat.n_src_words @@ -94,7 +96,7 @@ def xent(self): def ppl(self): """ compute perplexity """ - return math.exp(min(self.loss / self.n_words, 100)) + return math.exp(min(self.log_ppl / self.n_words, 100)) def elapsed_time(self): """ compute elapsed time """