From ab8b1099b725278bb53ed1176b4e88d50989a238 Mon Sep 17 00:00:00 2001 From: zengxianyu Date: Thu, 28 Jun 2018 10:55:45 +0800 Subject: [PATCH] error when batchsize=1 --- pytorch/skipthoughts/skipthoughts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch/skipthoughts/skipthoughts.py b/pytorch/skipthoughts/skipthoughts.py index 5166193..66569dc 100644 --- a/pytorch/skipthoughts/skipthoughts.py +++ b/pytorch/skipthoughts/skipthoughts.py @@ -132,7 +132,7 @@ def _select_last_old(self, input, lengths): def _process_lengths(self, input): max_length = input.size(1) - lengths = list(max_length - input.data.eq(0).sum(1).squeeze()) + lengths = list(max_length - input.data.eq(0).sum(1, keepdim=False)) return lengths def _load_rnn(self):