32 changes: 32 additions & 0 deletions .github/workflows/push.yml
@@ -148,6 +148,38 @@ jobs:
-word_vec_size 16 -report_every 5 \
-rnn_size 16 -train_steps 10 \
-copy_attn
- name: Test LM training with label smoothing
run: |
python train.py \
-config data/lm_data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.src \
-model_task lm \
-encoder_type transformer_lm \
-decoder_type transformer_lm \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-label_smoothing 0.1 \
-dec_layers 2 -batch_size 10 \
-heads 4 -transformer_ff 64 \
-word_vec_size 16 -report_every 5 \
-rnn_size 16 -train_steps 10
- name: Test LM training with unlieklihood loss
Member: Typo 'unlikelihood'.

run: |
python train.py \
-config data/lm_data.yaml \
-src_vocab /tmp/onmt.vocab.src \
-tgt_vocab /tmp/onmt.vocab.src \
-model_task lm \
-encoder_type transformer_lm \
-decoder_type transformer_lm \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-unlikelihood_coeff 1.0 \
-dec_layers 2 -batch_size 10 \
-heads 4 -transformer_ff 64 \
-word_vec_size 16 -report_every 5 \
-rnn_size 16 -train_steps 10
- name: Test Graph neural network training
run: |
python train.py \
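For context on what the first new CI step exercises: label smoothing replaces the one-hot target with a smoothed distribution before computing the cross entropy. Below is a minimal, purely illustrative sketch of such a criterion; the class name, signature, and the `(vocab_size - 1)` smoothing denominator are assumptions for the example, not the OpenNMT-py implementation.

```python
import torch
import torch.nn as nn

class SketchLabelSmoothingLoss(nn.Module):
    """Illustrative only: cross entropy against a smoothed target distribution
    (confidence mass on the gold token, the rest spread uniformly)."""

    def __init__(self, smoothing, vocab_size, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
        self.other = smoothing / (vocab_size - 1)  # mass per non-gold token

    def forward(self, log_probs, target):
        # log_probs: (N, vocab) log-softmax output, target: (N,) token ids
        smoothed = torch.full_like(log_probs, self.other)
        smoothed.scatter_(1, target.clamp(min=0).unsqueeze(1), self.confidence)
        smoothed[target.eq(self.ignore_index)] = 0.0  # no loss on padding rows
        return -(smoothed * log_probs).sum()

# tiny smoke test with made-up sizes
crit = SketchLabelSmoothingLoss(smoothing=0.1, vocab_size=1000)
logp = torch.log_softmax(torch.randn(4, 1000), dim=-1)
print(crit(logp, torch.tensor([5, 42, 7, 999])))
```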
7 changes: 4 additions & 3 deletions onmt/modules/copy_generator.py
@@ -2,7 +2,7 @@
import torch.nn as nn

from onmt.utils.misc import aeq
-from onmt.utils.loss import CommonLossCompute
+from onmt.utils.loss import LossComputeBase


def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None,
@@ -177,7 +177,7 @@ def forward(self, scores, align, target):
return loss


-class CommonCopyGeneratorLossCompute(CommonLossCompute):
+class CommonCopyGeneratorLossCompute(LossComputeBase):
Member: I'm not sure I grasp the whole rationale behind the CommonLossCompute/LossComputeBase refactoring. Is the last big remaining difference only the log_ppl computation?

Member: (Underlying question is: do we really need both CommonLossCompute and LossComputeBase anymore?)

Collaborator Author (@funboarder13920, Feb 19, 2021): The _compute_loss, _make_shard_state, and the way the generator is used differ between CopyGeneratorLoss and the other classes.

Collaborator Author: We can do it in one class; the code is already not very clear, so it won't get worse. If we do that, CopyGenerator will override _compute_loss, and _compute_log_ppl and _compute_alignement_loss will only be used in the compute_loss of the main class.

Member:
> We can do it in one class; the code is already not very clear, so it won't get worse. If we do that, CopyGenerator will override _compute_loss, and _compute_log_ppl and _compute_alignement_loss will only be used in the compute_loss of the main class.

Yes, I think it might be a bit better to explicitly override this method than to keep a full class whose purpose we can't really tell without looking at this specific CopyGeneratorLoss.

Collaborator Author: I merged it; the ppl part is not nice. Also, there is a normalization arg that was not used anywhere; I will investigate whether the normalization process disappeared by mistake.

Collaborator Author (@funboarder13920, Feb 19, 2021): normalization was already unused a year ago:

    def __init__(self, criterion, generator, normalization="sents",

"""Common Copy Generator Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0, tgt_shift_index=1):
@@ -231,7 +231,8 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
target_data[correct_mask] += offset_align

# Compute sum of perplexities for stats
-stats = self._stats(loss.sum().clone(), scores_data, target_data)
+stats = self._stats(loss.sum().clone(), loss.sum().clone(),
+                    scores_data, target_data)

# this part looks like it belongs in CopyGeneratorLoss
if self.normalize_by_length:
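To make the trade-off discussed in the thread above concrete, here is a purely illustrative sketch of the "single class, override _compute_loss" option; every name below is invented for the example and none of it is the actual onmt API.

```python
import torch
import torch.nn as nn

class SketchLossCompute(nn.Module):
    """Illustrative only: one base class owns the shared plumbing
    (stats / log-ppl bookkeeping); subclasses override _compute_loss."""

    def __init__(self, criterion, generator):
        super().__init__()
        self.criterion = criterion
        self.generator = generator

    def _compute_loss(self, output, target):
        # default path: project decoder output to vocab scores, apply criterion
        scores = self.generator(output.view(-1, output.size(-1)))
        return self.criterion(scores, target.view(-1))

    def forward(self, output, target):
        loss = self._compute_loss(output, target)
        # shared bookkeeping (statistics, perplexity, ...) would live here
        return loss

class SketchCopyLossCompute(SketchLossCompute):
    """The copy variant only overrides _compute_loss (here, a stand-in
    length normalization), instead of adding another intermediate class."""

    def _compute_loss(self, output, target):
        loss = super()._compute_loss(output, target)
        return loss / max(target.numel(), 1)

# tiny smoke test with made-up sizes
gen = nn.Sequential(nn.Linear(8, 20), nn.LogSoftmax(dim=-1))
loss_fn = SketchCopyLossCompute(nn.NLLLoss(reduction='sum'), gen)
print(loss_fn(torch.randn(3, 5, 8), torch.randint(0, 20, (3, 5))))
```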
6 changes: 6 additions & 0 deletions onmt/modules/sparse_losses.py
@@ -77,3 +77,9 @@ def forward(self, input, target):
elif self.reduction == 'elementwise_mean':
loss = loss.sum() / size
return loss


class ExpandedSparsemaxLoss(SparsemaxLoss):
def forward(self, input, target):
gtruth = target.view(-1)
return super(ExpandedSparsemaxLoss, self).forward(input, gtruth)
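A hedged usage sketch for the new subclass, which only flattens the target before delegating to SparsemaxLoss: the shapes below are made up, and it assumes the parent class can be constructed with its default arguments.

```python
import torch
from onmt.modules.sparse_losses import ExpandedSparsemaxLoss

scores = torch.randn(8, 10)              # (batch * len, vocab) raw scores
target = torch.randint(0, 10, (2, 4))    # (batch, len) gold token ids

loss_fn = ExpandedSparsemaxLoss()        # assumes defaults are acceptable
loss = loss_fn(scores, target)           # target is viewed as (-1,) internally
print(loss)
```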
8 changes: 8 additions & 0 deletions onmt/opts.py
@@ -581,6 +581,14 @@ def _add_train_general_opts(parser):
"Set to zero to turn off label smoothing. "
"For more detailed information, see: "
"https://arxiv.org/abs/1512.00567")
group.add('--unlikelihood_coeff', '-unlikelihood_coeff', type=float,
default=0.0,
help="Loss coefficient for token unlikelihood loss. "
"Usually set to 1. max_generator_batches option will "
"limit the neighbourhood size of the unlikelihood loss. "
"For more detailed information, see: "
"https://arxiv.org/abs/1908.04319 and "
"https://openreview.net/forum?id=SJeYe0NtvH")
group.add('--average_decay', '-average_decay', type=float, default=0,
help="Moving average decay. "
"Set to other than 0 (e.g. 1e-4) to activate. "
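For readers who don't want to open the references in the new help text: the token-level unlikelihood term of Welleck et al. (https://arxiv.org/abs/1908.04319) penalizes assigning probability to tokens that already appeared in the previous context (excluding the current gold token). A rough sketch of that term follows; the function name and shapes are invented, and this is not necessarily how the PR wires it into the training loss.

```python
import torch

def token_unlikelihood(log_probs, target, coeff=1.0):
    """Sketch of the token-level unlikelihood term: for each step t, the
    negative candidates are the previous gold tokens (minus the current
    target), and we pay -log(1 - p(candidate)) for each of them."""
    probs = log_probs.exp()
    total = log_probs.new_zeros(())
    for t in range(1, target.size(0)):
        prev = target[:t]
        cands = prev[prev != target[t]].unique()       # negative candidates C_t
        if cands.numel() == 0:
            continue
        one_minus_p = (1.0 - probs[t, cands]).clamp(min=1e-5)  # numerical safety
        total = total - one_minus_p.log().sum()
    # in training this would be added to the usual NLL term, scaled by coeff
    return coeff * total

# tiny smoke test: a sequence that repeats tokens 3 and 7
lp = torch.log_softmax(torch.randn(6, 50), dim=-1)
print(token_unlikelihood(lp, torch.tensor([3, 7, 3, 9, 7, 11])))
```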
36 changes: 36 additions & 0 deletions onmt/tests/pull_request_chk.sh
@@ -179,6 +179,42 @@ ${PYTHON} onmt/bin/train.py \
-rnn_size 16 -train_steps 10 \
-copy_attn >> ${LOG_FILE} 2>&1
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}

echo -n " [+] Testing LM training with label smoothing..."
${PYTHON} onmt/bin/train.py \
-config ${DATA_DIR}/lm_data.yaml \
-src_vocab $TMP_OUT_DIR/onmt.vocab.src \
-tgt_vocab $TMP_OUT_DIR/onmt.vocab.src \
-model_task lm \
-encoder_type transformer_lm \
-decoder_type transformer_lm \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-label_smoothing 0.1 \
-dec_layers 2 -batch_size 10 \
-heads 4 -transformer_ff 64 \
-word_vec_size 16 -report_every 5 \
-rnn_size 16 -train_steps 10 >> ${LOG_FILE} 2>&1
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}

echo -n " [+] Testing LM training with unlikelihood loss..."
${PYTHON} onmt/bin/train.py \
-config ${DATA_DIR}/lm_data.yaml \
-src_vocab $TMP_OUT_DIR/onmt.vocab.src \
-tgt_vocab $TMP_OUT_DIR/onmt.vocab.src \
-model_task lm \
-encoder_type transformer_lm \
-decoder_type transformer_lm \
-src_vocab_size 1000 \
-tgt_vocab_size 1000 \
-unlikelihood_coeff 1 \
-dec_layers 2 -batch_size 10 \
-heads 4 -transformer_ff 64 \
-word_vec_size 16 -report_every 5 \
-rnn_size 16 -train_steps 10 >> ${LOG_FILE} 2>&1
[ "$?" -eq 0 ] || error_exit
echo "Succeeded" | tee -a ${LOG_FILE}*
rm $TMP_OUT_DIR/onmt.vocab*
