From 032da9b1941f580734646d010c3ed1d86358332b Mon Sep 17 00:00:00 2001 From: Jinfeng Rao Date: Sat, 19 Jan 2019 20:42:06 -0800 Subject: [PATCH 1/2] Batch support for TreeLSTM --- config.py | 1 + main.py | 35 ++++--- main_test.py | 228 +++++++++++++++++++++++++++++++++++++++++++ test.py | 49 ++++++++++ treelstm/__init__.py | 4 +- treelstm/dataset.py | 14 ++- treelstm/model.py | 179 +++++++++++++++++++++++++++++++++ treelstm/trainer.py | 113 +++++++++++++++++---- treelstm/tree.py | 14 ++- 9 files changed, 600 insertions(+), 37 deletions(-) create mode 100644 main_test.py create mode 100644 test.py diff --git a/config.py b/config.py index fa4463e..3b3559c 100644 --- a/config.py +++ b/config.py @@ -44,6 +44,7 @@ def parse_args(): cuda_parser = parser.add_mutually_exclusive_group(required=False) cuda_parser.add_argument('--cuda', dest='cuda', action='store_true') cuda_parser.add_argument('--no-cuda', dest='cuda', action='store_false') + cuda_parser.add_argument('--use_batch', dest='use_batch', action='store_true') parser.set_defaults(cuda=True) args = parser.parse_args() diff --git a/main.py b/main.py index db28a86..729532c 100644 --- a/main.py +++ b/main.py @@ -25,8 +25,21 @@ from treelstm import Trainer # CONFIG PARSER from config import parse_args +from main_test import get_avg_grad +def set_optimizer(model, lr, wd): + if args.optim == 'adam': + optimizer = optim.Adam(filter(lambda p: p.requires_grad, + model.parameters()), lr=lr, weight_decay=wd) + elif args.optim == 'adagrad': + optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, + model.parameters()), lr=lr, weight_decay=wd) + elif args.optim == 'sgd': + optimizer = optim.SGD(filter(lambda p: p.requires_grad, + model.parameters()), lr=lr, weight_decay=wd) + return optimizer + # MAIN BLOCK def main(): global args @@ -111,7 +124,7 @@ def main(): args.num_classes, args.sparse, args.freeze_embed) - criterion = nn.KLDivLoss() + criterion = nn.KLDivLoss(reduction='none') # for words common to dataset vocab and GLOVE, use GLOVE vectors # for other words in dataset vocab, use random normal vectors @@ -137,21 +150,14 @@ def main(): model.emb.weight.data.copy_(emb) model.to(device), criterion.to(device) - if args.optim == 'adam': - optimizer = optim.Adam(filter(lambda p: p.requires_grad, - model.parameters()), lr=args.lr, weight_decay=args.wd) - elif args.optim == 'adagrad': - optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, - model.parameters()), lr=args.lr, weight_decay=args.wd) - elif args.optim == 'sgd': - optimizer = optim.SGD(filter(lambda p: p.requires_grad, - model.parameters()), lr=args.lr, weight_decay=args.wd) + optimizer = set_optimizer(model, args.lr, args.wd) metrics = Metrics(args.num_classes) # create trainer object for training and testing trainer = Trainer(args, model, criterion, optimizer, device) - best = -float('inf') + best, last_dev_loss = -float('inf'), float('inf') + curr_lr = args.lr for epoch in range(args.epochs): train_loss = trainer.train(train_dataset) train_loss, train_pred = trainer.test(train_dataset) @@ -171,6 +177,13 @@ def main(): logger.info('==> Epoch {}, Test \tLoss: {}\tPearson: {}\tMSE: {}'.format( epoch, test_loss, test_pearson, test_mse)) + if dev_loss > last_dev_loss: + curr_lr = curr_lr / 5 + trainer.optimizer = set_optimizer(model, curr_lr, args.wd) + print('reset lr to {}'.format(curr_lr)) + + last_dev_loss = dev_loss + if best < test_pearson: best = test_pearson checkpoint = { diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..6c996c1 --- /dev/null +++ 
b/main_test.py @@ -0,0 +1,228 @@ +from __future__ import division +from __future__ import print_function + +import os +import random +import logging + +import torch +import torch.nn as nn +import torch.optim as optim + +# IMPORT CONSTANTS +from treelstm import Constants +# NEURAL NETWORK MODULES/LAYERS +from treelstm import SimilarityTreeLSTM +# DATA HANDLING CLASSES +from treelstm import Vocab +# DATASET CLASS FOR SICK DATASET +from treelstm import SICKDataset +# METRICS CLASS FOR EVALUATION +from treelstm import Metrics +# UTILITY FUNCTIONS +from treelstm import utils +# TRAIN AND TEST HELPER FUNCTIONS +from treelstm import Trainer +# CONFIG PARSER +from config import parse_args + + +def set_optimizer(model, lr, wd): + if args.optim == 'adam': + optimizer = optim.Adam(filter(lambda p: p.requires_grad, + model.parameters()), lr=lr, weight_decay=wd) + elif args.optim == 'adagrad': + optimizer = optim.Adagrad(filter(lambda p: p.requires_grad, + model.parameters()), lr=lr, weight_decay=wd) + elif args.optim == 'sgd': + optimizer = optim.SGD(filter(lambda p: p.requires_grad, + model.parameters()), lr=lr, weight_decay=wd) + return optimizer + + +def get_avg_grad(named_parameters): + layers, avg_data, avg_grads = [], [], [] + for name, param in named_parameters: + if (param.requires_grad) and ("bias" not in name): + layers.append(name) + avg_data.append(param.data.abs().mean()) + if param.grad is not None: + avg_grads.append(param.grad.abs().mean()) + return layers, avg_data, avg_grads + +# MAIN BLOCK +def main(): + global args + args = parse_args() + # global logger + logger = logging.getLogger(__name__) + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s") + # file logger + fh = logging.FileHandler(os.path.join(args.save, args.expname)+'.log', mode='w') + fh.setLevel(logging.INFO) + fh.setFormatter(formatter) + logger.addHandler(fh) + # console logger + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + ch.setFormatter(formatter) + logger.addHandler(ch) + # argument validation + args.cuda = args.cuda and torch.cuda.is_available() + device = torch.device("cuda:0" if args.cuda else "cpu") + if args.sparse and args.wd != 0: + logger.error('Sparsity and weight decay are incompatible, pick one!') + exit() + logger.debug(args) + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + torch.backends.cudnn.benchmark = True + if not os.path.exists(args.save): + os.makedirs(args.save) + + train_dir = os.path.join(args.data, 'train/') + dev_dir = os.path.join(args.data, 'dev/') + test_dir = os.path.join(args.data, 'test/') + + # write unique words from all token files + sick_vocab_file = os.path.join(args.data, 'sick.vocab') + if not os.path.isfile(sick_vocab_file): + token_files_b = [os.path.join(split, 'b.toks') for split in [train_dir, dev_dir, test_dir]] + token_files_a = [os.path.join(split, 'a.toks') for split in [train_dir, dev_dir, test_dir]] + token_files = token_files_a + token_files_b + sick_vocab_file = os.path.join(args.data, 'sick.vocab') + utils.build_vocab(token_files, sick_vocab_file) + + # get vocab object from vocab file previously written + vocab = Vocab(filename=sick_vocab_file, + data=[Constants.PAD_WORD, Constants.UNK_WORD, + Constants.BOS_WORD, Constants.EOS_WORD]) + logger.debug('==> SICK vocabulary size : %d ' % vocab.size()) + + # load SICK dataset splits + train_file = os.path.join(args.data, 'sick_train.pth') + if os.path.isfile(train_file): + 
train_dataset = torch.load(train_file) + else: + train_dataset = SICKDataset(train_dir, vocab, args.num_classes) + torch.save(train_dataset, train_file) + logger.debug('==> Size of train data : %d ' % len(train_dataset)) + dev_file = os.path.join(args.data, 'sick_dev.pth') + if os.path.isfile(dev_file): + dev_dataset = torch.load(dev_file) + else: + dev_dataset = SICKDataset(dev_dir, vocab, args.num_classes) + torch.save(dev_dataset, dev_file) + logger.debug('==> Size of dev data : %d ' % len(dev_dataset)) + test_file = os.path.join(args.data, 'sick_test.pth') + if os.path.isfile(test_file): + test_dataset = torch.load(test_file) + else: + test_dataset = SICKDataset(test_dir, vocab, args.num_classes) + torch.save(test_dataset, test_file) + logger.debug('==> Size of test data : %d ' % len(test_dataset)) + + # initialize model, criterion/loss_function, optimizer + model = SimilarityTreeLSTM( + vocab.size(), + args.input_dim, + args.mem_dim, + args.hidden_dim, + args.num_classes, + args.sparse, + args.freeze_embed) + criterion = nn.KLDivLoss(reduce=False) + + # for words common to dataset vocab and GLOVE, use GLOVE vectors + # for other words in dataset vocab, use random normal vectors + emb_file = os.path.join(args.data, 'sick_embed.pth') + if os.path.isfile(emb_file): + emb = torch.load(emb_file) + else: + # load glove embeddings and vocab + glove_vocab, glove_emb = utils.load_word_vectors( + os.path.join(args.glove, 'glove.840B.300d')) + logger.debug('==> GLOVE vocabulary size: %d ' % glove_vocab.size()) + emb = torch.zeros(vocab.size(), glove_emb.size(1), dtype=torch.float, device=device) + emb.normal_(0, 0.05) + # zero out the embeddings for padding and other special words if they are absent in vocab + for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, + Constants.BOS_WORD, Constants.EOS_WORD]): + emb[idx].zero_() + for word in vocab.labelToIdx.keys(): + if glove_vocab.getIndex(word): + emb[vocab.getIndex(word)] = glove_emb[glove_vocab.getIndex(word)] + torch.save(emb, emb_file) + # plug these into embedding matrix inside model + model.emb.weight.data.copy_(emb) + + model.to(device), criterion.to(device) + optimizer = set_optimizer(model, args.lr, args.wd) + metrics = Metrics(args.num_classes) + + # create trainer object for training and testing + trainer = Trainer(args, model, criterion, optimizer, device) + + init_layers, init_avg_data, init_avg_grad = get_avg_grad(model.named_parameters()) + best, last_dev_loss = -float('inf'), float('inf') + dataset = train_dataset + + for epoch in range(args.epochs): + model.train() + optimizer.zero_grad() + total_loss = 0.0 + outputs_nobatch, losses_nobatch = [], [] + lstates_nobatch, rstates_nobatch = [], [] + for idx in range(args.batchsize): + ltree, linput, rtree, rinput, label = dataset[idx] + lroot, ltree = ltree[0], ltree[1] + rroot, rtree = rtree[0], rtree[1] + target = utils.map_label_to_target(label, dataset.num_classes) + linput, rinput = linput.to(device), rinput.to(device) + target = target.to(device) + linputs = model.emb(linput) + rinputs = model.emb(rinput) + lstate, lhidden = model.childsumtreelstm(lroot, linputs) + rstate, rhidden = model.childsumtreelstm(rroot, rinputs) + output = model.similarity(lstate, rstate) + #output = model(lroot, linput, rroot, rinput) + outputs_nobatch.append(output) + lstates_nobatch.append(lstate) + rstates_nobatch.append(rstate) + loss = criterion(output, target) + losses_nobatch.append(loss) + total_loss += loss.sum() + loss.sum().backward() + print(total_loss / args.batchsize) + 
layers1, avg_data1, avg_grad1 = get_avg_grad(model.named_parameters()) + + model.train() + optimizer.zero_grad() + total_loss = 0.0 + ltrees, linputs, rtrees, rinputs, labels = dataset.get_next_batch(0, args.batchsize) + targets = [] + for i in range(len(linputs)): + linputs[i] = linputs[i].to(device) + rinputs[i] = rinputs[i].to(device) + target = utils.map_label_to_target(labels[i], dataset.num_classes) + targets.append(target.to(device)) + targets = torch.cat(targets, dim=0) + linputs_tensor, rinputs_tensor = [], [] + for i in range(len(linputs)): + linputs_tensor.append(model.emb(linputs[i])) + rinputs_tensor.append(model.emb(rinputs[i])) + lstates, lhidden = model.childsumtreelstm(ltrees, linputs_tensor) + rstates, rhidden = model.childsumtreelstm(rtrees, rinputs_tensor) + outputs = model.similarity(lstates, rstates) + losses = criterion(outputs, targets) + total_loss += losses.sum() + losses.sum().backward() + layers2, avg_data2, avg_grad2 = get_avg_grad(model.named_parameters()) + import pdb; pdb.set_trace() + + +if __name__ == "__main__": + main() diff --git a/test.py b/test.py new file mode 100644 index 0000000..9f1b37f --- /dev/null +++ b/test.py @@ -0,0 +1,49 @@ +import torch +from treelstm.tree import Tree +from treelstm.model import ChildSumTreeLSTM + +t1_n1 = Tree() +t1_n1.idx = 0 +t1_n2 = Tree() +t1_n2.idx = 1 +t1_n1.add_child(t1_n2) +t1_n2.parent = t1_n1 +tree1 = {0: t1_n1, 1: t1_n2} + +t2_n1 = Tree() +t2_n1.idx = 0 +t2_n2 = Tree() +t2_n2.idx = 1 +t2_n3 = Tree() +t2_n3.idx = 2 +t2_n4 = Tree() +t2_n4.idx = 3 +t2_n3.add_child(t2_n1) +t2_n3.add_child(t2_n2) +t2_n1.parent = t2_n3 +t2_n2.parent = t2_n3 +t2_n2.add_child(t2_n4) +t2_n4.parent = t2_n2 +tree2 = {0: t2_n1, 1: t2_n2, 2: t2_n3, 3: t2_n4} +trees = [tree1, tree2] + +tensor1 = torch.randn(2, 10) +tensor2 = torch.randn(4, 10) +tensors = [tensor1, tensor2] + +tree_lstm = ChildSumTreeLSTM(10, 4) + +state1, hidden1 = tree_lstm(t1_n1, tensor1) +state2, hidden2 = tree_lstm(t2_n3, tensor2) + +print("state1", state1) +print("hidden1", hidden1) +print("state2", state2) +print("hidden2", hidden2) + +for tree in trees: + for idx, node in tree.items(): + node.state = None +batch_state, batch_hidden = tree_lstm(trees, tensors) +print(batch_state) +print(batch_hidden) \ No newline at end of file diff --git a/treelstm/__init__.py b/treelstm/__init__.py index 7621897..69cc93f 100644 --- a/treelstm/__init__.py +++ b/treelstm/__init__.py @@ -1,10 +1,10 @@ from . import Constants from .dataset import SICKDataset from .metrics import Metrics -from .model import SimilarityTreeLSTM +from .model import ChildSumTreeLSTM, SimilarityTreeLSTM from .trainer import Trainer from .tree import Tree from . 
import utils from .vocab import Vocab -__all__ = [Constants, SICKDataset, Metrics, SimilarityTreeLSTM, Trainer, Tree, Vocab, utils] +__all__ = [Constants, ChildSumTreeLSTM, SICKDataset, Metrics, SimilarityTreeLSTM, Trainer, Tree, Vocab, utils] diff --git a/treelstm/dataset.py b/treelstm/dataset.py index fa1caff..4c84359 100644 --- a/treelstm/dataset.py +++ b/treelstm/dataset.py @@ -37,6 +37,18 @@ def __getitem__(self, index): label = deepcopy(self.labels[index]) return (ltree, lsent, rtree, rsent, label) + def get_next_batch(self, index, batch_size): + ltrees, rtrees, lsents, rsents, labels = [], [], [], [], [] + for i in range(index, index+batch_size): + if i < self.size: + ltree, lsent, rtree, rsent, label = self.__getitem__(i) + ltrees.append(ltree[1]) + rtrees.append(rtree[1]) + lsents.append(lsent) + rsents.append(rsent) + labels.append(label) + return (ltrees, lsents, rtrees, rsents, labels) + def read_sentences(self, filename): with open(filename, 'r') as f: sentences = [self.read_sentence(line) for line in tqdm(f.readlines())] @@ -77,7 +89,7 @@ def read_tree(self, line): else: prev = tree idx = parent - return root + return (root, trees) def read_labels(self, filename): with open(filename, 'r') as f: diff --git a/treelstm/model.py b/treelstm/model.py index 1e51c23..e7e7a7d 100644 --- a/treelstm/model.py +++ b/treelstm/model.py @@ -1,14 +1,19 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.autograd import Variable +from torch import cuda from . import Constants +from .tree import Tree # module for childsumtreelstm class ChildSumTreeLSTM(nn.Module): def __init__(self, in_dim, mem_dim): super(ChildSumTreeLSTM, self).__init__() + self.bsz = 128 + self.max_num_children = 10 self.in_dim = in_dim self.mem_dim = mem_dim self.ioux = nn.Linear(self.in_dim, 3 * self.mem_dim) @@ -17,6 +22,9 @@ def __init__(self, in_dim, mem_dim): self.fh = nn.Linear(self.mem_dim, self.mem_dim) def node_forward(self, inputs, child_c, child_h): + # inputs: in_dim + # child_c: num_children * mem_dim + # child_h: num_children * mem_dim child_h_sum = torch.sum(child_h, dim=0, keepdim=True) iou = self.ioux(inputs) + self.iouh(child_h_sum) @@ -33,7 +41,32 @@ def node_forward(self, inputs, child_c, child_h): h = torch.mul(o, F.tanh(c)) return c, h + def batch_node_forward(self, inputs, child_c, child_h, num_children): + # inputs: bsz * in_dim + # child_c: bsz * max_num_child * in_dim + # child_h: bsz * max_num_child * in_dim + # num_children: bsz * num_children + bsz, max_num_children, _ = child_c.size() + child_h_sum = torch.sum(child_h, dim=1, keepdim=False) + + iou = self.ioux(inputs) + self.iouh(child_h_sum) + i, o, u = torch.split(iou, iou.size(1) // 3, dim=1) + i, o, u = F.sigmoid(i), F.sigmoid(o), F.tanh(u) + + fh, fx = self.fh(child_h), self.fx(inputs) + for idx in range(bsz): + fh[idx, :num_children[idx]] += fx[idx].repeat(num_children[idx], 1) + f = F.sigmoid(fh) + + fc = torch.mul(f, child_c) + + c = torch.mul(i, u) + torch.sum(fc, dim=1, keepdim=False) + h = torch.mul(o, F.tanh(c)) + return c, h + def forward(self, tree, inputs): + if isinstance(tree, list): + return self.batch_forward(tree, inputs) for idx in range(tree.num_children): self.forward(tree.children[idx], inputs) @@ -48,6 +81,140 @@ def forward(self, tree, inputs): return tree.state + def update_leaf_states(self, trees, inputs): + queue = [] + for tree_idx, tree in enumerate(trees): + for node_idx, node in tree.items(): + if node.num_children == 0: + node.visited = True + child_c = Variable(torch.zeros(1, 
self.mem_dim), requires_grad=True) + child_h = Variable(torch.zeros(1, self.mem_dim), requires_grad=True) + if cuda.is_available(): + child_c = child_c.cuda() + child_h = child_h.cuda() + input = inputs[tree_idx][node_idx] + queue.append((tree_idx, node_idx, input, child_c, child_h)) + + head = 0 + while head < len(queue): + idxes, encoder_inputs, children_c, children_h = [], [], [], [] + while head < len(queue) and len(encoder_inputs) < self.bsz: + tree_idx, node_idx, input, child_c, child_h = queue[head] + encoder_inputs.append(input.unsqueeze(0)) + children_c.append(child_c.unsqueeze(0)) + children_h.append(child_h.unsqueeze(0)) + head += 1 + encoder_inputs = torch.cat(encoder_inputs, dim=0) + children_c = torch.cat(children_c, dim=0) + children_h = torch.cat(children_h, dim=0) + num_children = [1] * len(encoder_inputs) + #print('leaf', len(encoder_inputs)) + batch_c, batch_h = self.batch_node_forward(encoder_inputs, children_c, children_h, num_children) + for i in range(batch_c.shape[0]): + idx = head - batch_c.shape[0] + i + tree_idx, node_idx, _, _, _ = queue[idx] + trees[tree_idx][node_idx].state = (batch_c[i], batch_h[i]) + for tree_idx, tree in enumerate(trees): + for node_idx, node in tree.items(): + if node.num_children == 0: + assert node.state is not None + return queue + + + def update_internal_node_states(self, trees, inputs, queue): + head = 0 + num_internal_nodes, depth = 0, 1 + while head < len(queue): + # find updatable parent nodes, and push to the end of queue + prev_num_nodes = len(queue) + while head < prev_num_nodes: + tree_idx, node_idx, _, _, _ = queue[head] + parent_node = trees[tree_idx][node_idx].parent + if parent_node is not None and not parent_node.visited: + can_visit = True + children_c, children_h = [], [] + for child_node in parent_node.children: + if child_node.state is None: + can_visit = False + break + else: + c, h = child_node.state + children_c.append(c.unsqueeze(0)) + children_h.append(h.unsqueeze(0)) + if can_visit: + parent_node.visited = True + children_c_var = Variable(torch.zeros(self.max_num_children, self.mem_dim), + requires_grad=True) + children_h_var = Variable(torch.zeros(self.max_num_children, self.mem_dim), + requires_grad=True) + if cuda.is_available(): + children_c_var = children_c_var.cuda() + children_h_var = children_h_var.cuda() + children_c_var[:len(children_c)] = torch.cat(children_c, dim=0) + children_h_var[:len(children_h)] = torch.cat(children_h, dim=0) + queue.append((tree_idx, parent_node.idx, inputs[tree_idx][parent_node.idx], + children_c_var, children_h_var)) + head += 1 + + depth += 1 + # update parent states + newhead = prev_num_nodes + while newhead < len(queue): + encoder_inputs, children_c, children_h, num_children = [], [], [], [] + while newhead < len(queue) and len(encoder_inputs) < self.bsz: + tree_idx, node_idx, input, child_c, child_h = queue[newhead] + curr_node = trees[tree_idx][node_idx] + encoder_inputs.append(input.unsqueeze(0)) + children_c.append(child_c.unsqueeze(0)) + children_h.append(child_h.unsqueeze(0)) + num_children.append(curr_node.num_children) + newhead += 1 + if len(encoder_inputs) > 0: + encoder_inputs = torch.cat(encoder_inputs, dim=0) + children_c = torch.cat(children_c, dim=0) + children_h = torch.cat(children_h, dim=0) + num_internal_nodes += len(encoder_inputs) + #print('internal', len(encoder_inputs)) + batch_c, batch_h = self.batch_node_forward( + encoder_inputs, children_c, children_h, num_children + ) + for i in range(batch_c.shape[0]): + idx = newhead - batch_c.shape[0] + i + 
tree_idx, node_idx, _, _, _ = queue[idx] + trees[tree_idx][node_idx].state = (batch_c[i], batch_h[i]) + for index in range(prev_num_nodes, len(queue)): + tree_idx, node_idx, _, _, _ = queue[index] + assert trees[tree_idx][node_idx].state is not None + #print("num of internal nodes", num_internal_nodes) + + + def batch_forward(self, trees, inputs): + # trees: list of dicts mapping node idx -> Tree + # inputs: list[torch.Tensor(seqlen, emb_size)] + num_nodes = {} + for tree_idx, tree in enumerate(trees): + for node_idx, node in tree.items(): + assert node.state is None + assert not node.visited + depth = node.depth() + if depth not in num_nodes: + num_nodes[depth] = 0 + num_nodes[depth] += 1 + # print(num_nodes) + queue = self.update_leaf_states(trees, inputs) + self.update_internal_node_states(trees, inputs, queue) + root_c, root_h = [], [] + for tree in trees: + root = Tree.get_root(tree[0]) + root_c.append(root.state[0].unsqueeze(0)) + root_h.append(root.state[1].unsqueeze(0)) + for tree_idx, tree in enumerate(trees): + for node_idx, node in tree.items(): + assert node.state is not None + assert node.visited + return torch.cat(root_c, dim=0), torch.cat(root_h, dim=0) + + # module for distance-angle similarity class Similarity(nn.Module): def __init__(self, mem_dim, hidden_dim, num_classes): @@ -79,9 +246,21 @@ def __init__(self, vocab_size, in_dim, mem_dim, hidden_dim, num_classes, sparsit self.similarity = Similarity(mem_dim, hidden_dim, num_classes) def forward(self, ltree, linputs, rtree, rinputs): + if isinstance(ltree, list): + return self.batch_forward(ltree, linputs, rtree, rinputs) linputs = self.emb(linputs) rinputs = self.emb(rinputs) lstate, lhidden = self.childsumtreelstm(ltree, linputs) rstate, rhidden = self.childsumtreelstm(rtree, rinputs) output = self.similarity(lstate, rstate) return output + + def batch_forward(self, ltrees, linputs, rtrees, rinputs): + linputs_tensor, rinputs_tensor = [], [] + for i in range(len(linputs)): + linputs_tensor.append(self.emb(linputs[i])) + rinputs_tensor.append(self.emb(rinputs[i])) + lstates, lhidden = self.childsumtreelstm(ltrees, linputs_tensor) + rstates, rhidden = self.childsumtreelstm(rtrees, rinputs_tensor) + output = self.similarity(lstates, rstates) + return output diff --git a/treelstm/trainer.py b/treelstm/trainer.py index 8ad2e85..53322fd 100644 --- a/treelstm/trainer.py +++ b/treelstm/trainer.py @@ -5,6 +5,17 @@ from . 
import utils +def get_avg_grad(named_parameters): + layers, avg_data, avg_grads = [], [], [] + for name, param in named_parameters: + if (param.requires_grad) and ("bias" not in name): + layers.append(name) + avg_data.append(param.data.abs().mean()) + if param.grad is not None: + avg_grads.append(param.grad.abs().mean()) + return layers, avg_data, avg_grads + + class Trainer(object): def __init__(self, args, model, criterion, optimizer, device): super(Trainer, self).__init__() @@ -15,24 +26,61 @@ def __init__(self, args, model, criterion, optimizer, device): self.device = device self.epoch = 0 + def clear_states(self, trees): + for tree_idx, tree in enumerate(trees): + for node_idx, node in tree.items(): + assert node.state is not None + assert node.visited + node.state = None + node.visited = False + # helper function for training def train(self, dataset): self.model.train() self.optimizer.zero_grad() total_loss = 0.0 indices = torch.randperm(len(dataset), dtype=torch.long, device='cpu') - for idx in tqdm(range(len(dataset)), desc='Training epoch ' + str(self.epoch + 1) + ''): - ltree, linput, rtree, rinput, label = dataset[indices[idx]] - target = utils.map_label_to_target(label, dataset.num_classes) - linput, rinput = linput.to(self.device), rinput.to(self.device) - target = target.to(self.device) - output = self.model(ltree, linput, rtree, rinput) - loss = self.criterion(output, target) - total_loss += loss.item() - loss.backward() - if idx % self.args.batchsize == 0 and idx > 0: + if not self.args.use_batch: + for idx in range(len(dataset)): + ltree, linput, rtree, rinput, label = dataset[indices[idx]] + lroot, ltree = ltree[0], ltree[1] + rroot, rtree = rtree[0], rtree[1] + target = utils.map_label_to_target(label, dataset.num_classes) + linput, rinput = linput.to(self.device), rinput.to(self.device) + target = target.to(self.device) + output = self.model(lroot, linput, rroot, rinput) + loss = self.criterion(output, target) + total_loss += loss.sum() + loss.sum().backward() + if (idx + 1) % self.args.batchsize == 0 and idx > 0: + self.optimizer.step() + self.optimizer.zero_grad() + else: + for idx in range(0, len(dataset), self.args.batchsize): + ltrees, rtrees, linputs, rinputs, labels = [], [], [], [], [] + for new_index in indices[idx: min(len(dataset), idx+self.args.batchsize)]: + ltree, lsent, rtree, rsent, label = dataset[new_index] + ltrees.append(ltree[1]) + rtrees.append(rtree[1]) + linputs.append(lsent) + rinputs.append(rsent) + labels.append(label) + + targets = [] + for i in range(len(linputs)): + linputs[i] = linputs[i].to(self.device) + rinputs[i] = rinputs[i].to(self.device) + target = utils.map_label_to_target(labels[i], dataset.num_classes) + targets.append(target.to(self.device)) + targets = torch.cat(targets, dim=0) + outputs = self.model(ltrees, linputs, rtrees, rinputs) + loss = self.criterion(outputs, targets) + total_loss += loss.sum() + loss.sum().backward() self.optimizer.step() self.optimizer.zero_grad() + self.clear_states(ltrees) + self.clear_states(rtrees) self.epoch += 1 return total_loss / len(dataset) @@ -42,15 +90,38 @@ def test(self, dataset): with torch.no_grad(): total_loss = 0.0 predictions = torch.zeros(len(dataset), dtype=torch.float, device='cpu') - indices = torch.arange(1, dataset.num_classes + 1, dtype=torch.float, device='cpu') - for idx in tqdm(range(len(dataset)), desc='Testing epoch ' + str(self.epoch) + ''): - ltree, linput, rtree, rinput, label = dataset[idx] - target = utils.map_label_to_target(label, dataset.num_classes) - linput, 
rinput = linput.to(self.device), rinput.to(self.device) - target = target.to(self.device) - output = self.model(ltree, linput, rtree, rinput) - loss = self.criterion(output, target) - total_loss += loss.item() - output = output.squeeze().to('cpu') - predictions[idx] = torch.dot(indices, torch.exp(output)) + if not self.args.use_batch: + indices = torch.arange(1, dataset.num_classes + 1, dtype=torch.float, device='cpu') + for idx in range(len(dataset)): + ltree, linput, rtree, rinput, label = dataset[idx] + lroot, ltree = ltree[0], ltree[1] + rroot, rtree = rtree[0], rtree[1] + target = utils.map_label_to_target(label, dataset.num_classes) + linput, rinput = linput.to(self.device), rinput.to(self.device) + target = target.to(self.device) + output = self.model(lroot, linput, rroot, rinput) + loss = self.criterion(output, target) + total_loss += loss.sum() + output = output.squeeze().to('cpu') + predictions[idx] = torch.dot(indices, torch.exp(output)) + else: + indices = torch.arange(1, dataset.num_classes + 1, dtype=torch.float, device='cpu') + for idx in range(0, len(dataset), self.args.batchsize): + ltrees, linputs, rtrees, rinputs, labels = dataset.get_next_batch(idx, self.args.batchsize) + targets = [] + for i in range(len(linputs)): + linputs[i] = linputs[i].to(self.device) + rinputs[i] = rinputs[i].to(self.device) + target = utils.map_label_to_target(labels[i], dataset.num_classes) + targets.append(target.to(self.device)) + targets = torch.cat(targets, dim=0) + outputs = self.model(ltrees, linputs, rtrees, rinputs) + losses = self.criterion(outputs, targets) + total_loss += losses.sum() + outputs = outputs.to('cpu') + batch_indices = indices.repeat(len(ltrees), 1) + predictions[idx: idx+len(ltrees)] = \ + (batch_indices * torch.exp(outputs)).sum(dim=1, keepdim=False) + self.clear_states(ltrees) + self.clear_states(rtrees) return total_loss / len(dataset), predictions diff --git a/treelstm/tree.py b/treelstm/tree.py index ea85ff8..c47c267 100644 --- a/treelstm/tree.py +++ b/treelstm/tree.py @@ -2,6 +2,9 @@ class Tree(object): def __init__(self): self.parent = None + self.state = None + self.idx = -1 + self.visited = False self.num_children = 0 self.children = list() @@ -10,6 +13,13 @@ def add_child(self, child): self.num_children += 1 self.children.append(child) + @staticmethod + def get_root(node): + if node.parent is None: + return node + else: + return Tree.get_root(node.parent) + def size(self): if getattr(self, '_size'): return self._size @@ -20,8 +30,8 @@ def size(self): return self._size def depth(self): - if getattr(self, '_depth'): - return self._depth + #if getattr(self, '_depth'): + # return self._depth count = 0 if self.num_children > 0: for i in range(self.num_children): From e6e5cefc538be5f829737ff7e44626da139138a1 Mon Sep 17 00:00:00 2001 From: Jinfeng Rao Date: Sat, 19 Jan 2019 20:46:04 -0800 Subject: [PATCH 2/2] Update README.md --- README.md | 65 ++++--------------------------------------------------- 1 file changed, 4 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index 68ee91a..eacb005 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,9 @@ - # Tree-Structured Long Short-Term Memory Networks -This is a [PyTorch](http://pytorch.org/) implementation of Tree-LSTM as described in the paper [Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks](http://arxiv.org/abs/1503.00075) by Kai Sheng Tai, Richard Socher, and Christopher Manning. 
On the semantic similarity task using the SICK dataset, this implementation reaches: - - Pearson's coefficient: `0.8492` and MSE: `0.2842` using hyperparameters `--lr 0.010 --wd 0.0001 --optim adagrad --batchsize 25` - - Pearson's coefficient: `0.8674` and MSE: `0.2536` using hyperparameters `--lr 0.025 --wd 0.0001 --optim adagrad --batchsize 25 --freeze_embed` - - Pearson's coefficient: `0.8676` and MSE: `0.2532` are the numbers reported in the original paper. - - Known differences include the way the gradients are accumulated (normalized by batchsize or not). - -### Requirements -- Python (tested on **3.6.5**, should work on **>=2.7**) -- Java >= 8 (for Stanford CoreNLP utilities) -- Other dependencies are in `requirements.txt` -Note: Currently works with PyTorch 0.4.0. Switch to the `pytorch-v0.3.1` branch if you want to use PyTorch 0.3.1. -### Usage -Before delving into how to run the code, here is a quick overview of the contents: - - Use the script `fetch_and_preprocess.sh` to download the [SICK dataset](http://alt.qcri.org/semeval2014/task1/index.php?id=data-and-tools), [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml) and [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml), and [Glove word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840) -- **Warning:** this is a 2GB download!), and additionally preprocees the data, i.e. generate dependency parses using [Stanford Neural Network Dependency Parser](http://nlp.stanford.edu/software/nndep.shtml). - - `main.py`does the actual heavy lifting of training the model and testing it on the SICK dataset. For a list of all command-line arguments, have a look at `config.py`. - - The first run caches GLOVE embeddings for words in the SICK vocabulary. In later runs, only the cache is read in during later runs. - - Logs and model checkpoints are saved to the `checkpoints/` directory with the name specified by the command line argument `--expname`. +The [original implementation](https://github.com/dasguptar/treelstm.pytorch) for paper [Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks](http://arxiv.org/abs/1503.00075) doesn't support batch calculation of TreeLSTM. -Next, these are the different ways to run the code here to train a TreeLSTM model. -#### Local Python Environment -If you have a working Python3 environment, simply run the following sequence of steps: -``` -- bash fetch_and_preprocess.sh -- pip install -r requirements.txt -- python main.py -``` -#### Pure Docker Environment -If you want to use a Docker container, simply follow these steps: -``` -- docker build -t treelstm . -- docker run -it treelstm bash -- bash fetch_and_preprocess.sh -- python main.py -``` -#### Local Filesystem + Docker Environment -If you want to use a Docker container, but want to persist data and checkpoints in your local filesystem, simply follow these steps: +To run the model with batch TreeLSTM: ``` -- bash fetch_and_preprocess.sh -- docker build -t treelstm . -- docker run -it --mount type=bind,source="$(pwd)",target="/root/treelstm.pytorch" treelstm bash -- python main.py +- python main.py --use_batch --batchsize 25 ``` -**NOTE**: Setting the environment variable OMP_NUM_THREADS=1 usually gives a speedup on the CPU. Use it like `OMP_NUM_THREADS=1 python main.py`. To run on a GPU, set the CUDA_VISIBLE_DEVICES instead. Usually, CUDA does not give much speedup here, since we are operating at a batchsize of `1`. 
- -### Notes - - (**Apr 02, 2018**) Added Dockerfile - - (**Apr 02, 2018**) Now works on **PyTorch 0.3.1** and **Python 3.6**, removed dependency on **Python 2.7** - - (**Nov 28, 2017**) Added **frozen embeddings**, closed gap to paper. - - (**Nov 08, 2017**) Refactored model to get **1.5x - 2x speedup**. - - (**Oct 23, 2017**) Now works with **PyTorch 0.2.0**. - - (**May 04, 2017**) Added support for **sparse tensors**. Using the `--sparse` argument will enable sparse gradient updates for `nn.Embedding`, potentially reducing memory usage. - - There are a couple of caveats, however, viz. weight decay will not work in conjunction with sparsity, and results from the original paper might not be reproduced using sparse embeddings. - -### Acknowledgements -Shout-out to [Kai Sheng Tai](https://github.com/kaishengtai/) for the [original LuaTorch implementation](https://github.com/stanfordnlp/treelstm), and to the [Pytorch team](https://github.com/pytorch/pytorch#the-team) for the fun library. - -### Contact -[Riddhiman Dasgupta](https://researchweb.iiit.ac.in/~riddhiman.dasgupta/) - -*This is my first PyTorch based implementation, and might contain bugs. Please let me know if you find any!* - -### License -MIT +which should give you the same results as the non-batched mode, but with much faster training and inference.
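For reference, here is a minimal sketch (not part of the patch) of driving the batched code path directly from Python rather than through `main.py`. It assumes the SICK data has already been prepared with `fetch_and_preprocess.sh` and that `sick.vocab` exists (it is written on the first run of `main.py`); the data path, model dimensions, and batch size of 25 below are illustrative assumptions (they follow the usual defaults in `config.py`), and since no trained checkpoint or GLOVE vectors are loaded, the output only demonstrates the call pattern, not meaningful predictions.

```
import torch
from treelstm import Constants, SICKDataset, SimilarityTreeLSTM, Vocab

# Illustrative values only: the paths and dimensions here are assumptions, not part of the patch.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
vocab = Vocab(filename='data/sick/sick.vocab',
              data=[Constants.PAD_WORD, Constants.UNK_WORD,
                    Constants.BOS_WORD, Constants.EOS_WORD])
dataset = SICKDataset('data/sick/test/', vocab, 5)
model = SimilarityTreeLSTM(vocab.size(), 300, 150, 50, 5, False, True).to(device)
model.eval()

with torch.no_grad():
    # get_next_batch returns parallel lists of tree dicts and token-index tensors;
    # passing lists into the model routes through the new batch_forward path
    # instead of the one-example-at-a-time recursion.
    ltrees, linputs, rtrees, rinputs, labels = dataset.get_next_batch(0, 25)
    linputs = [x.to(device) for x in linputs]
    rinputs = [x.to(device) for x in rinputs]
    outputs = model(ltrees, linputs, rtrees, rinputs)
    print(outputs.size())  # one row of class log-probabilities per sentence pair

# The batched pass caches (c, h) on each Tree node, so clear the per-node state
# (as Trainer.clear_states does) before feeding the same trees through the model again.
for tree in ltrees + rtrees:
    for _, node in tree.items():
        node.state, node.visited = None, False
```

This mirrors the batched branch of `Trainer.test`; `Trainer.train` follows the same pattern but assembles batches from shuffled indices.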