From 7d85f2c1c947169e699a652d18244f89d365e17d Mon Sep 17 00:00:00 2001
From: James Bradbury
Date: Sun, 5 Mar 2017 21:05:00 +0000
Subject: [PATCH 1/2] batched recursive NN (SPINN) for SNLI

---
 snli/model.py |  77 ++++++++++++++--------
 snli/spinn.py | 175 ++++++++++++++++++++++++++++++++++++++++++++++++++
 snli/train.py |  65 ++++++++++++++-----
 snli/util.py  |  11 +++-
 4 files changed, 285 insertions(+), 43 deletions(-)
 create mode 100644 snli/spinn.py

diff --git a/snli/model.py b/snli/model.py
index e483724e8b..437beb1ef6 100644
--- a/snli/model.py
+++ b/snli/model.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from torch.autograd import Variable
 
+from spinn import SPINN
 
 class Bottle(nn.Module):
 
@@ -9,7 +10,7 @@ def forward(self, input):
         if len(input.size()) <= 2:
             return super(Bottle, self).forward(input)
         size = input.size()[:2]
-        out = super(Bottle, self).forward(input.view(size[0]*size[1], -1))
+        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
         return out.view(*size, -1)
 
 
@@ -17,6 +18,22 @@ class Linear(Bottle, nn.Linear):
     pass
 
 
+class BatchNorm(Bottle, nn.BatchNorm1d):
+    pass
+
+
+class Feature(nn.Module):
+
+    def __init__(self, size, dropout):
+        super(Feature, self).__init__()
+        self.bn = nn.BatchNorm1d(size * 4)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, prem, hypo):
+        return self.dropout(self.bn(torch.cat(
+            [prem, hypo, prem - hypo, prem * hypo], 1)))
+
+
 class Encoder(nn.Module):
 
     def __init__(self, config):
@@ -24,10 +41,10 @@ def __init__(self, config):
         self.config = config
         input_size = config.d_proj if config.projection else config.d_embed
         self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
-                        num_layers=config.n_layers, dropout=config.dp_ratio,
-                        bidirectional=config.birnn)
+                        num_layers=config.n_layers, dropout=config.rnn_dropout,
+                        bidirectional=config.birnn)
 
-    def forward(self, inputs):
+    def forward(self, inputs, _):
         batch_size = inputs.size()[1]
         state_shape = self.config.n_cells, batch_size, self.config.d_hidden
         h0 = c0 = Variable(inputs.data.new(*state_shape).zero_())
@@ -42,35 +59,43 @@ def __init__(self, config):
         self.config = config
         self.embed = nn.Embedding(config.n_embed, config.d_embed)
         self.projection = Linear(config.d_embed, config.d_proj)
-        self.encoder = Encoder(config)
-        self.dropout = nn.Dropout(p=config.dp_ratio)
+        self.embed_bn = BatchNorm(config.d_proj)
+        self.embed_dropout = nn.Dropout(p=config.embed_dropout)
+        self.encoder = SPINN(config) if config.spinn else Encoder(config)
+        feat_in_size = config.d_hidden * (
+            2 if self.config.birnn and not self.config.spinn else 1)
+        self.feature = Feature(feat_in_size, config.mlp_dropout)
+        self.mlp_dropout = nn.Dropout(p=config.mlp_dropout)
         self.relu = nn.ReLU()
-        seq_in_size = 2*config.d_hidden
-        if self.config.birnn:
-            seq_in_size *= 2
-        lin_config = [seq_in_size]*2
-        self.out = nn.Sequential(
-            Linear(*lin_config),
-            self.relu,
-            self.dropout,
-            Linear(*lin_config),
-            self.relu,
-            self.dropout,
-            Linear(*lin_config),
-            self.relu,
-            self.dropout,
-            Linear(seq_in_size, config.d_out))
+        mlp_in_size = 4 * feat_in_size
+        mlp = [nn.Linear(mlp_in_size, config.d_mlp), self.relu,
+               nn.BatchNorm1d(config.d_mlp), self.mlp_dropout]
+        for i in range(config.n_mlp_layers - 1):
+            mlp.extend([nn.Linear(config.d_mlp, config.d_mlp), self.relu,
+                        nn.BatchNorm1d(config.d_mlp), self.mlp_dropout])
+        mlp.append(nn.Linear(config.d_mlp, config.d_out))
+        self.out = nn.Sequential(*mlp)
 
     def forward(self, batch):
+        # import pdb
+        # pdb.set_trace()
         prem_embed = self.embed(batch.premise)
         hypo_embed = self.embed(batch.hypothesis)
         if self.config.fix_emb:
             prem_embed = Variable(prem_embed.data)
             hypo_embed = Variable(hypo_embed.data)
         if self.config.projection:
-            prem_embed = self.relu(self.projection(prem_embed))
-            hypo_embed = self.relu(self.projection(hypo_embed))
-        premise = self.encoder(prem_embed)
-        hypothesis = self.encoder(hypo_embed)
-        scores = self.out(torch.cat([premise, hypothesis], 1))
+            prem_embed = self.projection(prem_embed)  # no relu
+            hypo_embed = self.projection(hypo_embed)
+        prem_embed = self.embed_dropout(self.embed_bn(prem_embed))
+        hypo_embed = self.embed_dropout(self.embed_bn(hypo_embed))
+        if hasattr(batch, 'premise_transitions'):
+            prem_trans = batch.premise_transitions
+            hypo_trans = batch.hypothesis_transitions
+        else:
+            prem_trans = hypo_trans = None
+        premise = self.encoder(prem_embed, prem_trans)
+        hypothesis = self.encoder(hypo_embed, hypo_trans)
+        scores = self.out(self.feature(premise, hypothesis))
+        #print(premise[0][:5], hypothesis[0][:5])
         return scores
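
For readers following the model changes above: the new Feature module condenses the two sentence encodings into a single matching vector before the MLP. The snippet below is not part of the patch; it is a minimal standalone sketch with made-up tensor sizes, and it omits the batch norm and dropout that Feature applies on top.

# Illustrative only: the 4-way matching feature built from the premise and
# hypothesis encodings, as in Feature.forward above.
import torch

prem = torch.randn(8, 300)   # toy batch of 8 premise encodings (d_hidden = 300)
hypo = torch.randn(8, 300)   # toy batch of 8 hypothesis encodings

# Concatenate [p, h, p - h, p * h] along the feature dimension; this is the
# 4 * feat_in_size input that the MLP classifier consumes (mlp_in_size above).
feats = torch.cat([prem, hypo, prem - hypo, prem * hypo], 1)
print(feats.size())          # torch.Size([8, 1200])

The elementwise difference and product give the classifier explicit comparison signals between the two encodings, which is why mlp_in_size is 4 * feat_in_size in the constructor.
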
diff --git a/snli/spinn.py b/snli/spinn.py
new file mode 100644
index 0000000000..685a08b34f
--- /dev/null
+++ b/snli/spinn.py
@@ -0,0 +1,175 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.autograd import Variable
+
+import itertools
+
+
+def tree_lstm(c1, c2, lstm_in):
+    a, i, f1, f2, o = lstm_in.chunk(5, 1)
+    c = a.tanh() * i.sigmoid() + f1.sigmoid() * c1 + f2.sigmoid() * c2
+    h = o.sigmoid() * c.tanh()
+    return h, c
+
+
+def bundle(lstm_iter):
+    if lstm_iter is None:
+        return None
+    lstm_iter = tuple(lstm_iter)
+    if lstm_iter[0] is None:
+        return None
+    return torch.cat(lstm_iter, 0).chunk(2, 1)
+
+
+def unbundle(state):
+    if state is None:
+        return itertools.repeat(None)
+    return torch.split(torch.cat(state, 1), 1, 0)
+
+
+class Reduce(nn.Module):
+    """TreeLSTM composition module for SPINN.
+
+    The TreeLSTM has two or three inputs: the first two are the left and right
+    children being composed; the third is the current state of the tracker
+    LSTM if one is present in the SPINN model.
+
+    Args:
+        size: The size of the model state.
+        tracker_size: The size of the tracker LSTM hidden state, or None if no
+            tracker is present.
+    """
+
+    def __init__(self, size, tracker_size=None):
+        super(Reduce, self).__init__()
+        self.left = nn.Linear(size, 5 * size)
+        self.right = nn.Linear(size, 5 * size, bias=False)
+        if tracker_size is not None:
+            self.track = nn.Linear(tracker_size, 5 * size, bias=False)
+
+    def __call__(self, left_in, right_in, tracking=None):
+        """Perform batched TreeLSTM composition.
+
+        This implements the REDUCE operation of a SPINN in parallel for a
+        batch of nodes. The batch size is flexible; only provide this function
+        the nodes that actually need to be REDUCEd.
+
+        The TreeLSTM has two or three inputs: the first two are the left and
+        right children being composed; the third is the current state of the
+        tracker LSTM if one is present in the SPINN model. All are provided as
+        iterables and batched internally into tensors.
+
+        Additionally augments each new node with pointers to its children.
+
+        Args:
+            left_in: Iterable of ``B`` ~autograd.Variable objects containing
+                ``c`` and ``h`` concatenated for the left child of each node
+                in the batch.
+            right_in: Iterable of ``B`` ~autograd.Variable objects containing
+                ``c`` and ``h`` concatenated for the right child of each node
+                in the batch.
+            tracking: Iterable of ``B`` ~autograd.Variable objects containing
+                ``c`` and ``h`` concatenated for the tracker LSTM state of
+                each node in the batch, or None.
+
+        Returns:
+            out: Tuple of ``B`` ~autograd.Variable objects containing ``c`` and
+                ``h`` concatenated for the LSTM state of each new node. These
+                objects are also augmented with ``left`` and ``right``
+                attributes.
+        """
+        left, right = bundle(left_in), bundle(right_in)
+        tracking = bundle(tracking)
+        lstm_in = self.left(left[0])
+        lstm_in += self.right(right[0])
+        if hasattr(self, 'track'):
+            lstm_in += self.track(tracking[0])
+        out = unbundle(tree_lstm(left[1], right[1], lstm_in))
+        # for o, l, r in zip(out, left_in, right_in):
+        #     o.left, o.right = l, r
+        return out
+
+
+class Tracker(nn.Module):
+
+    def __init__(self, size, tracker_size, predict):
+        super(Tracker, self).__init__()
+        self.rnn = nn.LSTMCell(3 * size, tracker_size)
+        if predict:
+            self.transition = nn.Linear(tracker_size, 4)
+        self.state_size = tracker_size
+
+    def reset_state(self):
+        self.state = None
+
+    def __call__(self, bufs, stacks):
+        buf = bundle(buf[-1] for buf in bufs)[0]
+        stack1 = bundle(stack[-1] for stack in stacks)[0]
+        stack2 = bundle(stack[-2] for stack in stacks)[0]
+        x = torch.cat((buf, stack1, stack2), 1)
+        if self.state is None:
+            self.state = 2 * [Variable(
+                x.data.new(x.size(0), self.state_size).zero_())]
+        self.state = self.rnn(x, self.state)
+        if hasattr(self, 'transition'):
+            return self.transition(self.state[0])
+
+
+class SPINN(nn.Module):
+
+    def __init__(self, config):
+        super(SPINN, self).__init__()
+        self.config = config
+        assert config.d_hidden == config.d_proj / 2
+        self.reduce = Reduce(config.d_hidden, config.d_tracker)
+        if config.d_tracker is not None:
+            self.tracker = Tracker(config.d_hidden, config.d_tracker,
+                                   predict=config.predict)
+
+    def __call__(self, buffers, transitions):
+        buffers = [list(torch.split(b.squeeze(1), 1, 0))
+                   for b in torch.split(buffers, 1, 1)]
+        stacks = [[buf[0], buf[0]] for buf in buffers]
+
+        if hasattr(self, 'tracker'):
+            self.tracker.reset_state()
+        else:
+            assert transitions is not None
+
+        if transitions is not None:
+            num_transitions = transitions.size(0)
+            # trans_loss, trans_acc = 0, 0
+        else:
+            num_transitions = len(buffers[0]) * 2 - 3
+
+        for i in range(num_transitions):
+            if transitions is not None:
+                trans = transitions[i]
+            if hasattr(self, 'tracker'):
+                trans_hyp = self.tracker(buffers, stacks)
+                if trans_hyp is not None:
+                    trans = trans_hyp.max(1)[1]
+                    # if transitions is not None:
+                    #     trans_loss += F.cross_entropy(trans_hyp, trans)
+                    #     trans_acc += (trans_preds.data == trans.data).mean()
+                    # else:
+                    #     trans = trans_preds
+            lefts, rights, trackings = [], [], []
+            batch = zip(trans.data, buffers, stacks,
+                        unbundle(self.tracker.state) if hasattr(self, 'tracker')
+                        else itertools.repeat(None))
+            for transition, buf, stack, tracking in batch:
+                if transition == 3:  # shift
+                    stack.append(buf.pop())
+                elif transition == 2:  # reduce
+                    rights.append(stack.pop())
+                    lefts.append(stack.pop())
+                    trackings.append(tracking)
+            if rights:
+                reduced = iter(self.reduce(lefts, rights, trackings))
+                for transition, stack in zip(trans.data, stacks):
+                    if transition == 2:
+                        stack.append(next(reduced))
+        # if trans_loss is not 0:
+        return bundle([stack.pop() for stack in stacks])[0]
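
The shift/reduce loop in SPINN above is easier to follow with a concrete trace. The snippet below is not part of the patch: it replays a transition sequence over plain strings instead of the concatenated c/h Variables the model uses, with the same integer codes the model consumes (3 for shift, 2 for reduce); the sentence and its binary parse are invented for illustration, and the buffer-ordering details handled by the torchtext fields are glossed over.

# Illustrative only: how shift (3) / reduce (2) transitions drive the
# buffer and stack, replayed on strings instead of tensors.
def replay(tokens, transitions):
    buf = list(reversed(tokens))   # arranged so buf.pop() yields tokens left to right
    stack = []
    for t in transitions:
        if t == 3:                 # shift: push the next token onto the stack
            stack.append(buf.pop())
        elif t == 2:               # reduce: compose the top two stack entries
            right, left = stack.pop(), stack.pop()
            stack.append('(' + left + ' ' + right + ')')
    return stack.pop()

print(replay(['the', 'cat', 'sat', 'down'], [3, 3, 2, 3, 3, 2, 2]))
# -> ((the cat) (sat down))

A binary parse over N tokens always unrolls into 2N - 1 such transitions (N shifts and N - 1 reduces), and at each step the real model calls Reduce only on the batch rows whose transition is a reduce, which is what keeps the recursion batched.
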
diff --git a/snli/train.py b/snli/train.py
index d80c790bf7..532b3ea7cc 100644
--- a/snli/train.py
+++ b/snli/train.py
@@ -3,8 +3,8 @@
 import glob
 
 import torch
-import torch.optim as O
-import torch.nn as nn
+from torch import optim
+from torch import nn
 
 from torchtext import data
 from torchtext import datasets
@@ -14,12 +14,18 @@
 
 
 args = get_args()
-torch.cuda.set_device(args.gpu)
+if args.gpu != -1:
+    torch.cuda.set_device(args.gpu)
 
-inputs = data.Field(lower=args.lower)
+if args.spinn:
+    inputs = datasets.snli.ParsedTextField(lower=args.lower)
+    transitions = datasets.snli.ShiftReduceField()
+else:
+    inputs = data.Field(lower=args.lower)
+    transitions = None
 answers = data.Field(sequential=False)
 
-train, dev, test = datasets.SNLI.splits(inputs, answers)
+train, dev, test = datasets.SNLI.splits(inputs, answers, transitions)
 
 inputs.build_vocab(train, dev, test)
 if args.word_vectors:
@@ -32,7 +38,7 @@
 answers.build_vocab(train)
 
 train_iter, dev_iter, test_iter = data.BucketIterator.splits(
-            (train, dev, test), batch_size=args.batch_size, device=args.gpu)
+    (train, dev, test), batch_size=args.batch_size, device=args.gpu)
 
 config = args
 config.n_embed = len(inputs.vocab)
@@ -41,16 +47,38 @@
 if config.birnn:
     config.n_cells *= 2
 
+if config.spinn:
+    config.lr = 2e-3 # 3e-4
+    config.lr_decay_by = 0.75
+    config.lr_decay_every = 1 #0.6
+    config.regularization = 0 #3e-6
+    config.mlp_dropout = 0.07
+    config.embed_dropout = 0.08 # 0.17
+    config.n_mlp_layers = 2
+    config.d_tracker = 64
+    config.d_mlp = 1024
+    config.d_hidden = 300
+    config.d_embed = 300
+    config.d_proj = 600
+    torch.backends.cudnn.enabled = False
+else:
+    config.regularization = 0
+
 if args.resume_snapshot:
-    model = torch.load(args.resume_snapshot, map_location=lambda storage, locatoin: storage.cuda(args.gpu))
+    model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage if args.gpu == -1 else storage.cuda(args.gpu))
 else:
     model = SNLIClassifier(config)
+    if config.spinn:
+        model.out[len(model.out._modules) - 1].weight.data.uniform_(-0.005, 0.005)
     if args.word_vectors:
         model.embed.weight.data = inputs.vocab.vectors
+    if args.gpu != -1:
         model.cuda()
 
 criterion = nn.CrossEntropyLoss()
-opt = O.Adam(model.parameters(), lr=args.lr)
+#opt = optim.Adam(model.parameters(), lr=args.lr)
+opt = optim.RMSprop(model.parameters(), lr=config.lr, alpha=0.9, eps=1e-6,
+                    weight_decay=config.regularization)
 
 iterations = 0
 start = time.time()
@@ -64,19 +92,23 @@
 
 for epoch in range(args.epochs):
     train_iter.init_epoch()
-    n_correct, n_total = 0, 0
+    n_correct, n_total, train_loss = 0, 0, 0
     for batch_idx, batch in enumerate(train_iter):
         model.train(); opt.zero_grad()
-        iterations += 1
+        for pg in opt.param_groups:
+            pg['lr'] = args.lr * (args.lr_decay_by ** (
+                iterations / len(train_iter) / args.lr_decay_every))
+        iterations += 1
         answer = model(batch)
+        #print(nn.functional.softmax(answer[0]).data.tolist(), batch.label.data[0])
         n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
         n_total += batch.batch_size
         train_acc = 100. * n_correct/n_total
         loss = criterion(answer, batch.label)
-        loss.backward(); opt.step()
+        loss.backward(); opt.step(); train_loss += loss.data[0] * batch.batch_size
         if iterations % args.save_every == 0:
             snapshot_prefix = os.path.join(args.save_path, 'snapshot')
-            snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations)
+            snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, train_loss / n_total, iterations)
             torch.save(model, snapshot_path)
             for f in glob.glob(snapshot_prefix + '*'):
                 if f != snapshot_path:
@@ -87,11 +119,12 @@
             for dev_batch_idx, dev_batch in enumerate(dev_iter):
                 answer = model(dev_batch)
                 n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
-                dev_loss = criterion(answer, dev_batch.label)
+                dev_loss += criterion(answer, dev_batch.label).data[0] * dev_batch.batch_size
             dev_acc = 100. * n_dev_correct / len(dev)
             print(dev_log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, len(train_iter),
-                100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))
+                100. * (1+batch_idx) / len(train_iter), train_loss / n_total, dev_loss / len(dev), train_acc, dev_acc))
+            train_loss = 0
             if dev_acc > best_dev_acc:
                 best_dev_acc = dev_acc
                 snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
@@ -103,6 +136,6 @@
         elif iterations % args.log_every == 0:
             print(log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, len(train_iter),
-                100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, n_correct/n_total*100, ' '*12))
-
+                100. * (1+batch_idx) / len(train_iter), train_loss / n_total, ' '*8, n_correct / n_total*100, ' '*12))
+            n_correct, n_total, train_loss = 0, 0, 0
diff --git a/snli/util.py b/snli/util.py
index 1ef2133de6..d17b4750f1 100644
--- a/snli/util.py
+++ b/snli/util.py
@@ -8,16 +8,25 @@ def get_args():
     parser.add_argument('--d_embed', type=int, default=300)
     parser.add_argument('--d_proj', type=int, default=300)
     parser.add_argument('--d_hidden', type=int, default=300)
+    parser.add_argument('--d_mlp', type=int, default=600)
+    parser.add_argument('--n_mlp_layers', type=int, default=3)
+    parser.add_argument('--d_tracker', type=int, default=None)
     parser.add_argument('--n_layers', type=int, default=1)
     parser.add_argument('--log_every', type=int, default=50)
     parser.add_argument('--lr', type=float, default=.001)
+    parser.add_argument('--lr_decay_by', type=float, default=1)
+    parser.add_argument('--lr_decay_every', type=float, default=1)
     parser.add_argument('--dev_every', type=int, default=1000)
     parser.add_argument('--save_every', type=int, default=1000)
-    parser.add_argument('--dp_ratio', type=int, default=0.2)
+    parser.add_argument('--embed_dropout', type=float, default=0.2)
+    parser.add_argument('--mlp_dropout', type=float, default=0.2)
+    parser.add_argument('--rnn_dropout', type=float, default=0.2)
     parser.add_argument('--no-bidirectional', action='store_false', dest='birnn')
     parser.add_argument('--preserve-case', action='store_false', dest='lower')
     parser.add_argument('--no-projection', action='store_false', dest='projection')
     parser.add_argument('--train_embed', action='store_false', dest='fix_emb')
+    parser.add_argument('--predict_transitions', action='store_true', dest='predict')
+    parser.add_argument('--spinn', action='store_true', dest='spinn')
     parser.add_argument('--gpu', type=int, default=0)
     parser.add_argument('--save_path', type=str, default='results')
     parser.add_argument('--data_cache', type=str, default=os.path.join(os.getcwd(), '.data_cache'))
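
One detail of the new training loop worth spelling out before the follow-up commit: the learning rate is recomputed on every batch from an exponential schedule, shrinking by lr_decay_by once per lr_decay_every epochs' worth of iterations. The snippet below is not part of the patch; the per-epoch batch count is an assumed figure used only to show the effect with the SPINN defaults above (lr = 2e-3, lr_decay_by = 0.75, lr_decay_every = 1).

# Illustrative only: the schedule applied in the loop as
#   pg['lr'] = args.lr * (args.lr_decay_by ** (iterations / len(train_iter) / args.lr_decay_every))
def decayed_lr(base_lr, decay_by, decay_every, iteration, batches_per_epoch):
    # iteration / batches_per_epoch is the fraction of an epoch completed,
    # so the rate shrinks by a factor of decay_by once per decay_every epochs.
    return base_lr * decay_by ** (iteration / float(batches_per_epoch) / decay_every)

for it in (0, 4290, 8580):   # ~4290 batches per epoch is an assumed figure
    print(it, decayed_lr(2e-3, 0.75, 1, it, 4290))
# -> roughly 2e-3, 1.5e-3, 1.125e-3
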
From 68e8329c7d9d578fbdc6d19ac375b775aa708b2c Mon Sep 17 00:00:00 2001
From: James Bradbury
Date: Mon, 6 Mar 2017 09:26:52 +0000
Subject: [PATCH 2/2] fixes

---
 snli/spinn.py | 18 +++++++++---------
 snli/train.py | 31 +++++++++++++++----------------
 2 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/snli/spinn.py b/snli/spinn.py
index 685a08b34f..96c7436149 100644
--- a/snli/spinn.py
+++ b/snli/spinn.py
@@ -48,7 +48,7 @@ def __init__(self, size, tracker_size=None):
         if tracker_size is not None:
             self.track = nn.Linear(tracker_size, 5 * size, bias=False)
 
-    def __call__(self, left_in, right_in, tracking=None):
+    def forward(self, left_in, right_in, tracking=None):
         """Perform batched TreeLSTM composition.
 
         This implements the REDUCE operation of a SPINN in parallel for a
@@ -103,7 +103,7 @@ def __init__(self, size, tracker_size, predict):
     def reset_state(self):
         self.state = None
 
-    def __call__(self, bufs, stacks):
+    def forward(self, bufs, stacks):
         buf = bundle(buf[-1] for buf in bufs)[0]
         stack1 = bundle(stack[-1] for stack in stacks)[0]
         stack2 = bundle(stack[-2] for stack in stacks)[0]
@@ -113,8 +113,8 @@ def __call__(self, bufs, stacks):
                 x.data.new(x.size(0), self.state_size).zero_())]
         self.state = self.rnn(x, self.state)
         if hasattr(self, 'transition'):
-            return self.transition(self.state[0])
-
+            return unbundle(self.state), self.transition(self.state[0])
+        return unbundle(self.state), None
 
 class SPINN(nn.Module):
 
@@ -127,7 +127,7 @@ def __init__(self, config):
             self.tracker = Tracker(config.d_hidden, config.d_tracker,
                                    predict=config.predict)
 
-    def __call__(self, buffers, transitions):
+    def forward(self, buffers, transitions):
         buffers = [list(torch.split(b.squeeze(1), 1, 0))
                    for b in torch.split(buffers, 1, 1)]
         stacks = [[buf[0], buf[0]] for buf in buffers]
@@ -147,7 +147,7 @@ def __call__(self, buffers, transitions):
             if transitions is not None:
                 trans = transitions[i]
             if hasattr(self, 'tracker'):
-                trans_hyp = self.tracker(buffers, stacks)
+                tracker_states, trans_hyp = self.tracker(buffers, stacks)
                 if trans_hyp is not None:
                     trans = trans_hyp.max(1)[1]
                     # if transitions is not None:
@@ -155,10 +155,10 @@
                     #     trans_acc += (trans_preds.data == trans.data).mean()
                     # else:
                     #     trans = trans_preds
+            else:
+                tracker_states = itertools.repeat(None)
             lefts, rights, trackings = [], [], []
-            batch = zip(trans.data, buffers, stacks,
-                        unbundle(self.tracker.state) if hasattr(self, 'tracker')
-                        else itertools.repeat(None))
+            batch = zip(trans.data, buffers, stacks, tracker_states)
             for transition, buf, stack, tracking in batch:
                 if transition == 3:  # shift
                     stack.append(buf.pop())
diff --git a/snli/train.py b/snli/train.py
index 532b3ea7cc..a685af37ad 100644
--- a/snli/train.py
+++ b/snli/train.py
@@ -64,16 +64,15 @@
 else:
     config.regularization = 0
 
+model = SNLIClassifier(config)
+if config.spinn:
+    model.out[len(model.out._modules) - 1].weight.data.uniform_(-0.005, 0.005)
+if args.word_vectors:
+    model.embed.weight.data = inputs.vocab.vectors
+if args.gpu != -1:
+    model.cuda()
 if args.resume_snapshot:
-    model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage if args.gpu == -1 else storage.cuda(args.gpu))
-else:
-    model = SNLIClassifier(config)
-    if config.spinn:
-        model.out[len(model.out._modules) - 1].weight.data.uniform_(-0.005, 0.005)
-    if args.word_vectors:
-        model.embed.weight.data = inputs.vocab.vectors
-    if args.gpu != -1:
-        model.cuda()
+    model.load_state_dict(torch.load(args.resume_snapshot))
 
 criterion = nn.CrossEntropyLoss()
 #opt = optim.Adam(model.parameters(), lr=args.lr)
@@ -92,7 +91,7 @@
 
 for epoch in range(args.epochs):
     train_iter.init_epoch()
-    n_correct, n_total, train_loss = 0, 0, 0
+    n_correct = n_total = train_loss = 0
     for batch_idx, batch in enumerate(train_iter):
         model.train(); opt.zero_grad()
         for pg in opt.param_groups:
@@ -109,13 +108,13 @@
         if iterations % args.save_every == 0:
             snapshot_prefix = os.path.join(args.save_path, 'snapshot')
             snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, train_loss / n_total, iterations)
-            torch.save(model, snapshot_path)
+            torch.save(model.state_dict(), snapshot_path)
             for f in glob.glob(snapshot_prefix + '*'):
                 if f != snapshot_path:
                     os.remove(f)
         if iterations % args.dev_every == 0:
             model.eval(); dev_iter.init_epoch()
-            n_dev_correct, dev_loss = 0, 0
+            n_dev_correct = dev_loss = 0
             for dev_batch_idx, dev_batch in enumerate(dev_iter):
                 answer = model(dev_batch)
                 n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
@@ -124,12 +123,12 @@
             print(dev_log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, len(train_iter),
                 100. * (1+batch_idx) / len(train_iter), train_loss / n_total, dev_loss / len(dev), train_acc, dev_acc))
-            train_loss = 0
+            n_correct = n_total = train_loss = 0
             if dev_acc > best_dev_acc:
                 best_dev_acc = dev_acc
                 snapshot_prefix = os.path.join(args.save_path, 'best_snapshot')
-                snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations)
-                torch.save(model, snapshot_path)
+                snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss / len(dev), iterations)
+                torch.save(model.state_dict(), snapshot_path)
                 for f in glob.glob(snapshot_prefix + '*'):
                     if f != snapshot_path:
                         os.remove(f)
@@ -137,5 +136,5 @@
             print(log_template.format(time.time()-start, epoch, iterations, 1+batch_idx, len(train_iter),
-                100. * (1+batch_idx) / len(train_iter), train_loss / n_total, ' '*8, n_correct / n_total*100, ' '*12))
-            n_correct, n_total, train_loss = 0, 0, 0
+                100. * (1+batch_idx) / len(train_iter), train_loss / n_total, ' '*8, n_correct / n_total*100, ' '*12))
+            n_correct = n_total = train_loss = 0
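
A closing note on the checkpoint change in this second commit: snapshots now store a state_dict rather than a pickled module, so resuming first rebuilds the model from the current code and config and only then loads the saved weights. The snippet below is not part of the patch; the stand-in module and file name are made up purely to illustrate the round-trip.

# Illustrative only: the save/load pattern the patch switches to.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                               # stand-in for SNLIClassifier(config)
torch.save(model.state_dict(), 'snapshot.pt')         # save parameters, not the pickled class

restored = nn.Linear(4, 2)                            # rebuild the module from code first...
restored.load_state_dict(torch.load('snapshot.pt'))   # ...then copy the saved weights into it

Unlike torch.save(model, path), a state_dict snapshot does not pickle the class itself, so it stays loadable after the model code changes; the trade-off, visible in the resume branch above, is that the rebuilt model's configuration has to match the one that produced the snapshot.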