diff --git a/docs/code/utils.rst b/docs/code/utils.rst
index 5c14af6b0..21aeb0cbf 100644
--- a/docs/code/utils.rst
+++ b/docs/code/utils.rst
@@ -12,6 +12,10 @@ Frequent Use
 .. autoclass:: texar.torch.utils.AverageRecorder
     :members:
 
+:hidden:`collect_trainable_variables`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.torch.utils.collect_trainable_variables
+
 :hidden:`compat_as_text`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: texar.torch.utils.compat_as_text
@@ -20,6 +24,17 @@ Frequent Use
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autofunction:: texar.torch.utils.write_paired_text
 
+Variables
+=========
+
+:hidden:`collect_trainable_variables`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.torch.utils.collect_trainable_variables
+
+:hidden:`add_variable`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autofunction:: texar.torch.utils.add_variable
+
 IO
 ===
 
diff --git a/docs/examples.md b/docs/examples.md
index 429d3b02b..b10bcd3db 100644
--- a/docs/examples.md
+++ b/docs/examples.md
@@ -22,6 +22,10 @@ More examples are continuously added...
 * [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation
 * [xlnet](https://github.com/asyml/texar-pytorch/tree/master/examples/xlnet): Pre-trained XLNet model for text representation
 
+### GANs / Discriminator-supervision ###
+
+* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation
+
 ---
 
 ## Examples by Tasks
@@ -35,6 +39,11 @@ More examples are continuously added...
 * [seq2seq_attn](https://github.com/asyml/texar-pytorch/tree/master/examples/seq2seq_attn): Attentional seq2seq
 * [transformer](https://github.com/asyml/texar-pytorch/tree/master/examples/transformer): Transformer for machine translation
 
+### Text Style Transfer ###
+
+* [text_style_transfer](https://github.com/asyml/texar-pytorch/tree/master/examples/text_style_transfer): Discriminator supervision for controlled text generation
+
+
 ### Classification ###
 
 * [bert](https://github.com/asyml/texar-pytorch/tree/master/examples/bert): Pre-trained BERT model for text representation
diff --git a/examples/README.md b/examples/README.md
index 72fe20c39..790b232d0 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -22,6 +22,10 @@ More examples are continuously added...
 * [vae_text](./vae_text): VAE language model
 
+### GANs / Discriminator-supervision ###
+
+* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation
+
 ### Classifier / Sequence Prediction ###
 
 * [bert](./bert): Pre-trained BERT model for text representation
@@ -43,6 +47,10 @@ More examples are continuously added...
 * [seq2seq_attn](./seq2seq_attn): Attentional seq2seq
 * [transformer](./transformer): Transformer for machine translation
 
+### Text Style Transfer ###
+
+* [text_style_transfer](./text_style_transfer): Discriminator supervision for controlled text generation
+
 ### Classification ###
 
 * [bert](./bert): Pre-trained BERT model for text representation
diff --git a/examples/text_style_transfer/README.md b/examples/text_style_transfer/README.md
new file mode 100644
index 000000000..94c500adb
--- /dev/null
+++ b/examples/text_style_transfer/README.md
@@ -0,0 +1,106 @@
+# Text Style Transfer #
+
+This example implements a simplified variant of the `ctrl-gen` model from
+
+[Toward Controlled Generation of Text](https://arxiv.org/pdf/1703.00955.pdf)
+*Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, Eric Xing; ICML 2017*
+
+The model roughly has an architecture of `Encoder--Decoder--Classifier`. Compared to the paper, the following simplifications are made:
+
+  * Replaces the base Variational Autoencoder (VAE) model with an attentional Autoencoder (AE) -- VAE is not necessary in the text style transfer setting since we do not need to interpolate the latent space as in the paper.
+  * The attribute classifier (i.e., discriminator) is trained with real data only. Samples generated by the decoder are not used.
+  * The independence constraint is omitted.
+
+## Usage ##
+
+### Dataset ###
+Download the yelp sentiment dataset with the following command:
+```
+python prepare_data.py
+```
+
+### Train the model ###
+
+Train the model on the above data to do sentiment transfer.
+```
+python main.py --config config
+```
+
+[config.py](./config.py) contains the data and model configurations.
+
+* The model will first be pre-trained for a few epochs (specified in `config.py`). During pre-training, the `Encoder-Decoder` part is trained as an autoencoder, while the `Classifier` part is trained with the classification labels.
+* Full-training is then performed for another few epochs. During full-training, the `Classifier` part is fixed, and the `Encoder-Decoder` part is trained to fit the classifier, along with continuing to minimize the autoencoding loss.
+
+(**Note:** When using your own dataset, make sure to set `max_decoding_length_train` and `max_decoding_length_infer` in [config.py](https://github.com/asyml/texar/blob/master/examples/text_style_transfer/config.py#L85-L86).)
+
+The training log is printed as below:
+```
+gamma: 1.0, lambda_g: 0.0
+step: 1, loss_d: 0.6934 accu_d: 0.4844
+step: 1, loss_g_ae: 9.1392
+step: 500, loss_d: 0.1488 accu_d: 0.9484
+step: 500, loss_g_ae: 4.2884
+step: 1000, loss_d: 0.1215 accu_d: 0.9625
+step: 1000, loss_g_ae: 2.6201
+...
+epoch: 1, loss_d: 0.0750 accu_d: 0.9688
+epoch: 1, loss_g_ae: 0.8832
+val: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2949 loss_d: 0.0702 accu_d: 0.9744 accu_g: 0.3022 accu_g_gdy: 0.2732 bleu: 60.8234
+test: loss_g: 0.0000 loss_g_ae: 0.0000 loss_g_class: 3.2359 loss_d: 0.0746 accu_d: 0.9733 accu_g: 0.3076 accu_g_gdy: 0.2791 bleu: 60.1810
+...
+
+```
+where:
+- `loss_d` and `accu_d` are the classification loss/accuracy of the `Classifier` part.
+- `loss_g_class` is the classification loss of the generated sentences.
+- `loss_g_ae` is the autoencoding loss.
+- `loss_g` is the joint loss `= loss_g_ae + lambda_g * loss_g_class`.
+- `accu_g` is the classification accuracy of the generated sentences with soft representations (i.e., Gumbel-softmax).
+- `accu_g_gdy` is the classification accuracy of the generated sentences with greedy decoding.
+- `bleu` is the BLEU score between the generated and input sentences.
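+
+For reference, the `bleu` number can also be reproduced offline from dumped samples. The following is a minimal sketch, not part of the training code, with placeholder file names; it scores each transferred sentence against its input sentence, mirroring what `main.py` does during validation:
+
+```python
+import numpy as np
+import texar.torch as tx
+
+# Placeholder paths: one original and one transferred sentence per line.
+with open('samples/original.txt') as f:
+    originals = [line.strip() for line in f]
+with open('samples/transferred.txt') as f:
+    transferred = [line.strip() for line in f]
+
+# Each hypothesis is scored against a single reference -- its input sentence.
+refs = np.expand_dims(np.array(originals), axis=1)
+print(tx.evals.corpus_bleu_moses(refs, transferred))
+```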
+
+## Results ##
+
+Text style transfer has two primary goals:
+1. The generated sentence should have the desired attribute (e.g., positive/negative sentiment).
+2. The generated sentence should keep the content of the original one.
+
+We use automatic metrics to evaluate both:
+* For (1), we can use a pre-trained classifier to classify the generated sentences and evaluate the accuracy (the higher the better). We have not implemented a stand-alone classifier for evaluation in this code, though adding one would be straightforward (see the sketch below). The `Classifier` part in the model gives a reasonably good estimation (i.e., `accu_g_gdy` in the above) of the accuracy.
+* For (2), we evaluate the BLEU score between the generated sentences and the original sentences, i.e., `bleu` in the above (the higher the better; see [Yang et al., 2018](https://arxiv.org/pdf/1805.11749.pdf) for more details).
+
+The implementation here gives the following performance after 10 epochs of pre-training and 2 epochs of full-training:
+
+| Accuracy (by the `Classifier` part) | BLEU (with the original sentence) |
+| ----------------------------------- | --------------------------------- |
+| 0.96                                | 52.0                              |
+
+Also refer to the following papers that used this code and compared it to other text style transfer approaches:
+
+* [Unsupervised Text Style Transfer using Language Models as Discriminators](https://papers.nips.cc/paper/7959-unsupervised-text-style-transfer-using-language-models-as-discriminators.pdf). Zichao Yang, Zhiting Hu, Chris Dyer, Eric Xing, Taylor Berg-Kirkpatrick. NeurIPS 2018
+* [Structured Content Preservation for Unsupervised Text Style Transfer](https://arxiv.org/pdf/1810.06526.pdf). Youzhi Tian, Zhiting Hu, Zhou Yu. 2018
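+
+For completeness, a stand-alone evaluation classifier could look roughly like the sketch below. This is only an illustration and is not part of this example; the hyperparameters are placeholders. It reuses the same Texar modules as the model's `Classifier` part, and would be trained on the original labeled sentences and then applied to the transferred ones:
+
+```python
+import torch.nn as nn
+from torch.nn import functional as F
+from texar.torch.modules import WordEmbedder, Conv1DClassifier
+
+class EvalClassifier(nn.Module):
+    """Independent sentiment classifier used only for evaluation."""
+    def __init__(self, vocab_size, max_seq_length, hparams=None):
+        super().__init__()
+        self.embedder = WordEmbedder(vocab_size=vocab_size,
+                                     hparams={'dim': 100})
+        self.classifier = Conv1DClassifier(in_channels=self.embedder.dim,
+                                           in_features=max_seq_length,
+                                           hparams=hparams)
+
+    def forward(self, text_ids, sequence_length, labels=None):
+        # Mirrors the usage in ctrl_gen_model.py: embed ids, then classify.
+        logits, preds = self.classifier(
+            input=self.embedder(ids=text_ids),
+            sequence_length=sequence_length)
+        if labels is None:
+            return preds
+        loss = F.binary_cross_entropy_with_logits(logits, labels.float())
+        return loss, preds
+```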
+
+### Samples ###
+Here are some randomly-picked samples. In each pair, the first sentence is the original sentence and the second is the generated one.
+```
+love , love love .
+poor , poor poor .
+
+good atmosphere .
+disgusted atmosphere .
+
+the donuts are good sized and very well priced .
+the donuts are disgusted sized and very _num_ priced .
+
+it is always clean and the staff is super friendly .
+it is nasty overpriced and the staff is super cold .
+
+super sweet place .
+super plain place .
+
+highly recommended .
+horrible horrible .
+
+very good ingredients .
+very disgusted ingredients .
+```
diff --git a/examples/text_style_transfer/config.py b/examples/text_style_transfer/config.py
new file mode 100644
index 000000000..bb9f1dc13
--- /dev/null
+++ b/examples/text_style_transfer/config.py
@@ -0,0 +1,108 @@
+"""Config
+"""
+# pylint: disable=invalid-name
+
+import copy
+from typing import Dict, Any
+
+# Total number of training epochs (including pre-train and full-train)
+max_nepochs = 12
+pretrain_nepochs = 10  # Number of pre-train epochs (training as autoencoder)
+display = 500  # Display the training results every N training steps.
+# Display the dev results every N training steps (set to a
+# very large value to disable it).
+display_eval = 1e10
+
+sample_path = './samples'
+checkpoint_path = './checkpoints'
+restore = ''  # Model snapshot to restore from
+
+lambda_g = 0.1  # Weight of the classification loss
+gamma_decay = 0.5  # Gumbel-softmax temperature anneal rate
+
+max_seq_length = 16  # Maximum sequence length in dataset w/o BOS token
+
+train_data: Dict[str, Any] = {
+    'batch_size': 64,
+    # 'seed': 123,
+    'datasets': [
+        {
+            'files': './data/yelp/sentiment.train.text',
+            'vocab_file': './data/yelp/vocab',
+            'data_name': ''
+        },
+        {
+            'files': './data/yelp/sentiment.train.labels',
+            'data_type': 'int',
+            'data_name': 'labels'
+        }
+    ],
+    'name': 'train'
+}
+
+val_data = copy.deepcopy(train_data)
+val_data['datasets'][0]['files'] = './data/yelp/sentiment.dev.text'
+val_data['datasets'][1]['files'] = './data/yelp/sentiment.dev.labels'
+
+test_data = copy.deepcopy(train_data)
+test_data['datasets'][0]['files'] = './data/yelp/sentiment.test.text'
+test_data['datasets'][1]['files'] = './data/yelp/sentiment.test.labels'
+
+model = {
+    'dim_c': 200,
+    'dim_z': 500,
+    'embedder': {
+        'dim': 100,
+    },
+    'max_seq_length': max_seq_length,
+    'encoder': {
+        'rnn_cell': {
+            'type': 'GRUCell',
+            'kwargs': {
+                'num_units': 700
+            },
+            'dropout': {
+                'input_keep_prob': 0.5
+            }
+        }
+    },
+    'decoder': {
+        'rnn_cell': {
+            'type': 'GRUCell',
+            'kwargs': {
+                'num_units': 700,
+            },
+            'dropout': {
+                'input_keep_prob': 0.5,
+                'output_keep_prob': 0.5
+            },
+        },
+        'attention': {
+            'type': 'BahdanauAttention',
+            'kwargs': {
+                'num_units': 700,
+            },
+            'attention_layer_size': 700,
+        },
+        'max_decoding_length_train': 21,
+        'max_decoding_length_infer': 20,
+    },
+    'classifier': {
+        'kernel_size': [3, 4, 5],
+        'out_channels': 128,
+        'data_format': 'channels_last',
+        'other_conv_kwargs': [[{'padding': 1}, {'padding': 2}, {'padding': 2}]],
+        'dropout_conv': [1],
+        'dropout_rate': 0.5,
+        'num_dense_layers': 0,
+        'num_classes': 1
+    },
+    'opt': {
+        'optimizer': {
+            'type': 'Adam',
+            'kwargs': {
+                'lr': 3e-4,
+            },
+        },
+    },
+}
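+
+# Note on dimensions (added for clarity): `dim_c` is the size of the style
+# code produced from the label by the MLP label connector, and `dim_z` is the
+# size of the content code taken from the sentence encoder's final state
+# (ctrl_gen_model.py keeps `final_state[:, dim_c:]` as z). The decoder state
+# is built from their concatenation, so `dim_c + dim_z` is expected to match
+# the encoder GRU's `num_units` (here 200 + 500 = 700).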
diff --git a/examples/text_style_transfer/ctrl_gen_model.py b/examples/text_style_transfer/ctrl_gen_model.py
new file mode 100644
index 000000000..aee2eaebf
--- /dev/null
+++ b/examples/text_style_transfer/ctrl_gen_model.py
@@ -0,0 +1,227 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Text style transfer
+"""
+
+# pylint: disable=invalid-name, too-many-locals
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+import texar.torch as tx
+from texar.torch.modules import WordEmbedder, UnidirectionalRNNEncoder, \
+    MLPTransformConnector, AttentionRNNDecoder, \
+    GumbelSoftmaxEmbeddingHelper, Conv1DClassifier
+from texar.torch.utils import get_batch_size, collect_trainable_variables
+
+
+class CtrlGenModel(nn.Module):
+    """Control-gen model for text style transfer.
+    """
+    def __init__(self, vocab: tx.data.Vocab, hparams=None):
+        super().__init__()
+        self.vocab = vocab
+
+        self._hparams = tx.HParams(hparams, None)
+
+        self.embedder = WordEmbedder(vocab_size=self.vocab.size,
+                                     hparams=self._hparams.embedder)
+
+        self.encoder = UnidirectionalRNNEncoder(
+            input_size=self.embedder.dim,
+            hparams=self._hparams.encoder)  # type: UnidirectionalRNNEncoder
+
+        # Encodes label
+        self.label_connector = MLPTransformConnector(
+            output_size=self._hparams.dim_c,
+            linear_layer_dim=1)
+
+        # Teacher-force decoding and the auto-encoding loss for G
+        self.decoder = AttentionRNNDecoder(
+            input_size=self.embedder.dim,
+            encoder_output_size=self.encoder.cell.hidden_size,
+            vocab_size=self.vocab.size,
+            token_embedder=self.embedder,
+            hparams=self._hparams.decoder)
+
+        self.connector = MLPTransformConnector(
+            output_size=self.decoder.output_size,
+            linear_layer_dim=(self._hparams.dim_c + self._hparams.dim_z))
+
+        self.classifier = Conv1DClassifier(
+            in_channels=self.embedder.dim,
+            in_features=self._hparams.max_seq_length,
+            hparams=self._hparams.classifier)
+
+        self.class_embedder = WordEmbedder(vocab_size=self.vocab.size,
+                                           hparams=self._hparams.embedder)
+
+        # Collects trainable parameters for the generator and discriminator;
+        # the corresponding optimizers are created in main.py.
+        self.g_vars = collect_trainable_variables(
+            [self.decoder, self.connector, self.label_connector,
+             self.encoder, self.embedder])
+
+        self.d_vars = collect_trainable_variables(
+            [self.class_embedder, self.classifier])
+
+    def forward_D(self, inputs, f_labels):
+
+        # Classification loss for the classifier
+        # Get inputs in correct format, [batch_size, channels, seq_length]
+        class_inputs = self.class_embedder(ids=inputs['text_ids'][:, 1:])
+        class_logits, class_preds = self.classifier(
+            input=class_inputs,
+            sequence_length=inputs['length'] - 1)
+
+        loss_d = F.binary_cross_entropy_with_logits(class_logits, f_labels)
+        accu_d = tx.evals.accuracy(labels=f_labels,
+                                   preds=class_preds)
+        return {
+            "loss_d": loss_d,
+            "accu_d": accu_d
+        }
+
+    def forward_G(self, inputs, f_labels, gamma, lambda_g, mode):
+
+        # text_ids for encoder, with BOS token removed
+        enc_text_ids = inputs['text_ids'][:, 1:].long()
+        enc_inputs = self.embedder(enc_text_ids)
+        enc_outputs, final_state = self.encoder(
+            enc_inputs,
+            sequence_length=inputs['length'] - 1)
+        z = final_state[:, self._hparams.dim_c:]
+
+        labels = inputs['labels'].view(-1, 1).float()
+
+        c = self.label_connector(labels)
+        c_ = self.label_connector(1 - labels)
+        h = torch.cat([c, z], dim=1)
+        h_ = torch.cat([c_, z], dim=1)
+
+        # Gumbel-softmax decoding, used in training
+        start_tokens = torch.ones_like(inputs['labels'].long()) * \
+            self.vocab.bos_token_id
+        end_token = self.vocab.eos_token_id
+
+        if mode == 'train':
+            g_outputs, _, _ = self.decoder(
+                memory=enc_outputs,
+                memory_sequence_length=inputs['length'] - 1,
+                initial_state=self.connector(h),
+                inputs=inputs['text_ids'],
+                sequence_length=inputs['length'] - 1
+            )
+
+            loss_g_ae = tx.losses.sequence_sparse_softmax_cross_entropy(
+                labels=inputs['text_ids'][:, 1:],
+                logits=g_outputs.logits,
+                sequence_length=inputs['length'] - 1,
+                average_across_timesteps=True,
+                sum_over_timesteps=False
+            )
+            if lambda_g == 0:
+                ret = {
+                    "loss_g_ae": loss_g_ae,
+                }
+                return ret
+
+        else:
+            # for eval, there is no loss
+            loss_g_ae = 0
+
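+        # Note (added for clarity): the Gumbel-softmax helper below decodes
+        # with the transferred style code h_ and produces soft token
+        # distributions at temperature `tau=gamma`. Feeding these soft ids to
+        # the classifier embedder keeps the whole path differentiable, so the
+        # classification loss on the transferred samples can back-propagate
+        # into the decoder during full-training.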
+        gumbel_helper = GumbelSoftmaxEmbeddingHelper(
+            start_tokens=start_tokens,
+            end_token=end_token,
+            tau=gamma)
+
+        soft_outputs_, _, soft_length_ = self.decoder(
+            memory=enc_outputs,
+            memory_sequence_length=inputs['length'] - 1,
+            helper=gumbel_helper,
+            initial_state=self.connector(h_))
+
+        # Greedy decoding, used in eval
+        outputs_, _, length_ = self.decoder(
+            memory=enc_outputs,
+            memory_sequence_length=inputs['length'] - 1,
+            decoding_strategy='infer_greedy',
+            initial_state=self.connector(h_),
+            start_tokens=start_tokens,
+            end_token=end_token)
+
+        # Get inputs in correct format, [batch_size, channels, seq_length]
+        soft_inputs = self.class_embedder(soft_ids=soft_outputs_.sample_id)
+        soft_logits, soft_preds = self.classifier(
+            input=soft_inputs,
+            sequence_length=soft_length_)
+
+        loss_g_class = F.binary_cross_entropy_with_logits(soft_logits,
+                                                          (1 - f_labels))
+
+        # Accuracy on greedy-decoded samples, for training progress monitoring
+        greedy_inputs = self.class_embedder(ids=outputs_.sample_id)
+        _, gdy_preds = self.classifier(
+            input=greedy_inputs,
+            sequence_length=length_)
+
+        accu_g_gdy = tx.evals.accuracy(
+            labels=1 - f_labels, preds=gdy_preds)
+
+        # Accuracy on soft samples, for training progress monitoring
+        accu_g = tx.evals.accuracy(labels=1 - f_labels,
+                                   preds=soft_preds)
+        loss_g = loss_g_ae + lambda_g * loss_g_class
+        ret = {
+            "loss_g": loss_g,
+            "loss_g_ae": loss_g_ae,
+            "loss_g_class": loss_g_class,
+            "accu_g": accu_g,
+            "accu_g_gdy": accu_g_gdy,
+        }
+        if mode == 'eval':
+            ret.update({'outputs': outputs_})
+        return ret
+
+    def forward(self, inputs, gamma, lambda_g, mode, component=None):
+
+        f_labels = inputs['labels'].float()
+        if mode == 'train':
+            if component == 'D':
+                ret_d = self.forward_D(inputs, f_labels)
+                return ret_d
+
+            elif component == 'G':
+                ret_g = self.forward_G(inputs, f_labels, gamma, lambda_g, mode)
+                return ret_g
+
+        else:
+            ret_d = self.forward_D(inputs, f_labels)
+            ret_g = self.forward_G(inputs, f_labels, gamma, lambda_g, mode)
+            rets = {
+                "batch_size": get_batch_size(inputs['text_ids']),
+                "loss_g": ret_g['loss_g'],
+                "loss_g_ae": ret_g['loss_g_ae'],
+                "loss_g_class": ret_g['loss_g_class'],
+                "loss_d": ret_d['loss_d'],
+                "accu_d": ret_d['accu_d'],
+                "accu_g": ret_g['accu_g'],
+                "accu_g_gdy": ret_g['accu_g_gdy']
+            }
+            samples = {
+                "original": inputs['text_ids'][:, 1:],
+                "transferred": ret_g['outputs'].sample_id
+            }
+            return rets, samples
diff --git a/examples/text_style_transfer/main.py b/examples/text_style_transfer/main.py
new file mode 100644
index 000000000..16a5eff54
--- /dev/null
+++ b/examples/text_style_transfer/main.py
@@ -0,0 +1,209 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Text style transfer
+
+This is a simplified implementation of:
+
+Toward Controlled Generation of Text, ICML2017
+Zhiting Hu, Zichao Yang, Xiaodan Liang, Ruslan Salakhutdinov, Eric Xing
+
+Download the data with the cmd:
+
+$ python prepare_data.py
+
+Train the model with the cmd:
+
+$ python main.py --config config
+"""
+
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments, no-member
+
+import os
+import importlib
+import argparse
+import numpy as np
+import torch
+
+import texar.torch as tx
+
+from ctrl_gen_model import CtrlGenModel
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--config', default='config', help="The config to use.")
+
+args = parser.parse_args()
+
+config = importlib.import_module(args.config)
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def main():
+    # Data
+    train_data = tx.data.MultiAlignedData(hparams=config.train_data,
+                                          device=device)
+    val_data = tx.data.MultiAlignedData(hparams=config.val_data,
+                                        device=device)
+    test_data = tx.data.MultiAlignedData(hparams=config.test_data,
+                                         device=device)
+    vocab = train_data.vocab(0)
+
+    # Each training batch is used twice: once for updating the generator and
+    # once for updating the discriminator.
+    iterator = tx.data.DataIterator(
+        {'train': train_data,
+         'val': val_data, 'test': test_data})
+
+    # Initial Gumbel-softmax temperature and classification-loss weight
+    gamma_ = 1.
+    lambda_g_ = 0.
+
+    # Model
+    model = CtrlGenModel(vocab, hparams=config.model)
+    model.to(device)
+
+    # Creates optimizers
+    train_op_d = tx.core.get_optimizer(
+        params=model.d_vars,
+        hparams=config.model['opt']
+    )
+
+    train_op_g = tx.core.get_optimizer(
+        params=model.g_vars,
+        hparams=config.model['opt']
+    )
+
+    train_op_g_ae = tx.core.get_optimizer(
+        params=model.g_vars,
+        hparams=config.model['opt']
+    )
+
+    def _train_epoch(gamma_, lambda_g_, epoch, verbose=True):
+        model.train()
+        avg_meters_d = tx.utils.AverageRecorder(size=10)
+        avg_meters_g = tx.utils.AverageRecorder(size=10)
+        iterator.switch_to_dataset("train")
+        step = 0
+        for batch in iterator:
+            train_op_d.zero_grad()
+            train_op_g_ae.zero_grad()
+            train_op_g.zero_grad()
+            step += 1
+
+            vals_d = model(batch, gamma_, lambda_g_, mode="train",
+                           component="D")
+            loss_d = vals_d['loss_d']
+            loss_d.backward()
+            train_op_d.step()
+            recorder_d = {key: value.detach().cpu().data
+                          for (key, value) in vals_d.items()}
+            avg_meters_d.add(recorder_d)
+
+            vals_g = model(batch, gamma_, lambda_g_, mode="train",
+                           component="G")
+
+            if epoch <= config.pretrain_nepochs:
+                loss_g_ae = vals_g['loss_g_ae']
+                loss_g_ae.backward()
+                train_op_g_ae.step()
+            else:
+                loss_g = vals_g['loss_g']
+                loss_g.backward()
+                train_op_g.step()
+
+            recorder_g = {key: value.detach().cpu().data
+                          for (key, value) in vals_g.items()}
+            avg_meters_g.add(recorder_g)
+
+            if verbose and (step == 1 or step % config.display == 0):
+                print('step: {}, {}'.format(step, avg_meters_d.to_str(4)))
+                print('step: {}, {}'.format(step, avg_meters_g.to_str(4)))
+
+            if verbose and step % config.display_eval == 0:
+                _eval_epoch(gamma_, lambda_g_, epoch)
+
+        print('epoch: {}, {}'.format(epoch, avg_meters_d.to_str(4)))
+        print('epoch: {}, {}'.format(epoch, avg_meters_g.to_str(4)))
+
+    @torch.no_grad()
+    def _eval_epoch(gamma_, lambda_g_, epoch, val_or_test='val'):
+        model.eval()
+        avg_meters = tx.utils.AverageRecorder()
+        iterator.switch_to_dataset(val_or_test)
+        for batch in iterator:
+            vals, samples = model(batch, gamma_, lambda_g_, mode='eval')
+
+            batch_size = vals.pop('batch_size')
+
+            # Computes BLEU
+            hyps = tx.data.map_ids_to_strs(samples['transferred'].cpu(), vocab)
+
+            refs = tx.data.map_ids_to_strs(samples['original'].cpu(), vocab)
+            refs = np.expand_dims(refs, axis=1)
+
+            bleu = tx.evals.corpus_bleu_moses(refs, hyps)
+            vals['bleu'] = bleu
+
+            avg_meters.add(vals, weight=batch_size)
+
+            # Writes samples
+            tx.utils.write_paired_text(
+                refs.squeeze(), hyps,
+                os.path.join(config.sample_path, 'val.%d' % epoch),
+                append=True, mode='v')
+
+        print('{}: {}'.format(
+            val_or_test, avg_meters.to_str(precision=4)))
+
+        return avg_meters.avg()
+
+    os.makedirs(config.sample_path, exist_ok=True)
+    os.makedirs(config.checkpoint_path, exist_ok=True)
+
+    # Runs the training / evaluation loop
+    if config.restore:
+        print('Restore from: {}'.format(config.restore))
+        ckpt = torch.load(config.restore)
+        model.load_state_dict(ckpt['model'])
+        train_op_d.load_state_dict(ckpt['optimizer_d'])
+        train_op_g.load_state_dict(ckpt['optimizer_g'])
+
+    for epoch in range(1, config.max_nepochs + 1):
+        if epoch > config.pretrain_nepochs:
+            # Anneals the gumbel-softmax temperature
+            gamma_ = max(0.001, gamma_ * config.gamma_decay)
+            lambda_g_ = config.lambda_g
+        print('gamma: {}, lambda_g: {}'.format(gamma_, lambda_g_))
+
+        # Train
+        _train_epoch(gamma_, lambda_g_, epoch)
+
+        # Val
+        _eval_epoch(gamma_, lambda_g_, epoch, 'val')
+
+        states = {
+            'model': model.state_dict(),
+            'optimizer_d': train_op_d.state_dict(),
+            'optimizer_g': train_op_g.state_dict()
+        }
+        torch.save(states, os.path.join(config.checkpoint_path, 'ckpt'))
+
+        # Test
+        _eval_epoch(gamma_, lambda_g_, epoch, 'test')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/text_style_transfer/prepare_data.py b/examples/text_style_transfer/prepare_data.py
new file mode 100644
index 000000000..c01da312e
--- /dev/null
+++ b/examples/text_style_transfer/prepare_data.py
@@ -0,0 +1,37 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Downloads data.
+"""
+import texar.torch as tx
+
+
+def prepare_data():
+    """Downloads data.
+    """
+    tx.data.maybe_download(
+        urls='https://drive.google.com/file/d/'
+             '1HaUKEYDBEk6GlJGmXwqYteB-4rS9q8Lg/view?usp=sharing',
+        path='./',
+        filenames='yelp.zip',
+        extract=True)
+
+
+def main():
+    """Entrypoint.
+    """
+    prepare_data()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/texar/torch/utils/__init__.py b/texar/torch/utils/__init__.py
index d33081eb5..d67a31aca 100644
--- a/texar/torch/utils/__init__.py
+++ b/texar/torch/utils/__init__.py
@@ -21,3 +21,4 @@
 from texar.torch.utils.shapes import *
 from texar.torch.utils.utils import *
 from texar.torch.utils.utils_io import *
+from texar.torch.utils.variables import *
diff --git a/texar/torch/utils/variables.py b/texar/torch/utils/variables.py
new file mode 100644
index 000000000..687ed1c40
--- /dev/null
+++ b/texar/torch/utils/variables.py
@@ -0,0 +1,68 @@
+# Copyright 2019 The Texar Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utility functions related to variables.
+"""
+
+from typing import List, Tuple, Union, Set
+import torch.nn as nn
+
+from texar.torch.module_base import ModuleBase
+
+__all__ = [
+    "add_variable",
+    "collect_trainable_variables"
+]
+
+
+def add_variable(
+        variable: Union[List[nn.Parameter], Tuple[nn.Parameter], nn.Parameter],
+        var_list: Set[nn.Parameter]):
+    r"""Adds variable(s) to the given set.
+
+    Args:
+        variable: A (list of) variable(s).
+        var_list (set): The set to which the variables are added.
+    """
+    if isinstance(variable, (list, tuple)):
+        for var in variable:
+            add_variable(var, var_list)
+    else:
+        if variable not in var_list:
+            var_list.add(variable)
+
+
+def collect_trainable_variables(
+        modules: Union[ModuleBase, List[ModuleBase]]
+):
+    r"""Collects all trainable variables of the given modules.
+
+    Trainable variables included in multiple modules occur only once in the
+    returned list.
+
+    Args:
+        modules: A (list of) instance(s) of subclasses of
+            :class:`~texar.torch.modules.ModuleBase`.
+
+    Returns:
+        A list of trainable variables in the modules.
+    """
+    if not isinstance(modules, (list, tuple)):
+        modules = [modules]
+
+    var_list: Set[nn.Parameter] = set()
+    for mod in modules:
+        add_variable(mod.trainable_variables, var_list)
+
+    return list(var_list)
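For context, `collect_trainable_variables` is what the example uses to build separate optimizer parameter groups for the generator and the discriminator. A minimal usage sketch follows; the module choices and hyperparameters are illustrative only, mirroring how `ctrl_gen_model.py` and `main.py` use the function:

```python
import texar.torch as tx
from texar.torch.modules import WordEmbedder, Conv1DClassifier

embedder = WordEmbedder(vocab_size=10000, hparams={'dim': 100})
classifier = Conv1DClassifier(in_channels=embedder.dim,
                              in_features=16,
                              hparams={'num_classes': 1})

# Parameters shared across the listed modules are returned only once.
d_vars = tx.utils.collect_trainable_variables([embedder, classifier])

# The returned list can be handed directly to an optimizer factory,
# as main.py does with `tx.core.get_optimizer`.
train_op_d = tx.core.get_optimizer(
    params=d_vars,
    hparams={'optimizer': {'type': 'Adam', 'kwargs': {'lr': 3e-4}}})
```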