
Commit 48cc2bb

[Update] Use char sequences to get the final hidden states
1 parent 8a8a5ae commit 48cc2bb

6 files changed: +42, -24 lines

layers.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ def forward(self, x, x_mask):
         Output:
             x_encoded: batch * len * hdim_encoded
         """
-        if x_mask.data.sum() == 0:
+        if x_mask.data.sum() == 0 or x_mask.data.eq(1).long().sum(1).min() == 0:
             # No padding necessary.
             output = self._forward_unpadded(x, x_mask)
         elif self.padding or not self.training:
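The new guard widens the set of batches that skip the packed-RNN path. One relevant constraint: PyTorch's pack_padded_sequence refuses sequences of length zero, which can show up once entire character sequences are masked. Below is a minimal sketch of that situation, with made-up sizes and assuming the usual DrQA mask convention (0 = real token, 1 = padding); it is not the repo's code.

```python
# Minimal sketch (made-up sizes): a batch containing an all-padding row cannot
# go through pack_padded_sequence, so the forward pass has to fall back to the
# unpadded branch for such inputs.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

x = torch.randn(3, 5, 8)                        # batch * len * dim
x_mask = torch.zeros(3, 5, dtype=torch.uint8)   # 0 = real token, 1 = padding
x_mask[2].fill_(1)                              # third sequence is entirely padding

lengths = x_mask.eq(0).long().sum(1)            # tensor([5, 5, 0])
if lengths.min().item() == 0:
    print('zero-length sequence -> use the unpadded forward pass')
else:
    packed = pack_padded_sequence(x, lengths, batch_first=True)
```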

m_reader.py

Lines changed: 10 additions & 4 deletions

@@ -45,7 +45,7 @@ def __init__(self, args, normalize=True):
             dropout_output=args.dropout_rnn_output,
             concat_layers=False,
             rnn_type=self.RNN_TYPES[args.rnn_type],
-            padding=args.rnn_padding,
+            padding=False,
         )

         doc_input_size = args.embedding_dim + args.char_hidden_size * 2 + args.num_features
@@ -127,8 +127,14 @@ def forward(self, x1, x1_c, x1_f, x1_mask, x2, x2_c, x2_f, x2_mask):
         x2_c_emb = F.dropout(x2_c_emb, p=self.args.dropout_emb, training=self.training)

         # Generate char features
-        x1_c_features = self.char_rnn(x1_c_emb, x1_mask)[:,-1,:]
-        x2_c_features = self.char_rnn(x2_c_emb, x2_mask)[:,-1,:]
+        x1_c_features = self.char_rnn(
+            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2), x1_c_emb.size(3))),
+            x1_mask.unsqueeze(2).repeat(1, 1, x1_c_emb.size(2)).reshape((x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2)))
+        ).reshape((x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2), -1))[:,:,-1,:]
+        x2_c_features = self.char_rnn(
+            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2), x2_c_emb.size(3))),
+            x2_mask.unsqueeze(2).repeat(1, 1, x2_c_emb.size(2)).reshape((x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2)))
+        ).reshape((x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2), -1))[:,:,-1,:]

         # Combine input
         crnn_input = [x1_emb, x1_c_features]
@@ -156,4 +162,4 @@ def forward(self, x1, x1_c, x1_f, x1_mask, x2, x2_c, x2_f, x2_mask):
         # Predict
         start_scores, end_scores = self.mem_ans_ptr.forward(c_check, q, x1_mask, x2_mask)

-        return start_scores, end_scores
+        return start_scores, end_scores
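The reworked char-feature code folds the word dimension into the batch dimension so the shared char RNN encodes one character sequence per token, then unfolds the output and keeps the final time step. A small sketch of that reshape round-trip with made-up sizes (the RNN itself is replaced by a stand-in tensor so only the shapes are illustrated):

```python
# Sketch of the reshape round-trip; the char RNN output is faked with a random
# tensor so the shapes can be followed end to end.
import torch

batch, seq_len, char_len, char_dim, hidden = 2, 4, 13, 16, 32
x_c_emb = torch.randn(batch, seq_len, char_len, char_dim)

# Fold words into the batch: one char sequence per token.
flat = x_c_emb.reshape(batch * seq_len, char_len, char_dim)

# Stand-in for the char RNN: (batch * seq_len) x char_len x hidden.
encoded = torch.randn(batch * seq_len, char_len, hidden)

# Unfold and keep the last time step as the per-token char feature.
features = encoded.reshape(batch, seq_len, char_len, hidden)[:, :, -1, :]
print(features.shape)   # torch.Size([2, 4, 32])
```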

r_net.py

Lines changed: 9 additions & 3 deletions

@@ -45,7 +45,7 @@ def __init__(self, args, normalize=True):
             dropout_output=args.dropout_rnn_output,
             concat_layers=False,
             rnn_type=self.RNN_TYPES[args.rnn_type],
-            padding=args.rnn_padding,
+            padding=False,
         )

         doc_input_size = args.embedding_dim + args.char_hidden_size * 2
@@ -146,8 +146,14 @@ def forward(self, x1, x1_c, x1_f, x1_mask, x2, x2_c, x2_f, x2_mask):
         x2_c_emb = F.dropout(x2_c_emb, p=self.args.dropout_emb, training=self.training)

         # Generate char features
-        x1_c_features = self.char_rnn(x1_c_emb, x1_mask)[:,-1,:]
-        x2_c_features = self.char_rnn(x2_c_emb, x2_mask)[:,-1,:]
+        x1_c_features = self.char_rnn(
+            x1_c_emb.reshape((x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2), x1_c_emb.size(3))),
+            x1_mask.unsqueeze(2).repeat(1, 1, x1_c_emb.size(2)).reshape((x1_c_emb.size(0) * x1_c_emb.size(1), x1_c_emb.size(2)))
+        ).reshape((x1_c_emb.size(0), x1_c_emb.size(1), x1_c_emb.size(2), -1))[:,:,-1,:]
+        x2_c_features = self.char_rnn(
+            x2_c_emb.reshape((x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2), x2_c_emb.size(3))),
+            x2_mask.unsqueeze(2).repeat(1, 1, x2_c_emb.size(2)).reshape((x2_c_emb.size(0) * x2_c_emb.size(1), x2_c_emb.size(2)))
+        ).reshape((x2_c_emb.size(0), x2_c_emb.size(1), x2_c_emb.size(2), -1))[:,:,-1,:]

         # Combine input
         crnn_input = [x1_emb, x1_c_features]
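r_net.py gets the same change as m_reader.py; the other half of the trick is the mask. The word-level mask is repeated across the character dimension so every character of a padded word is also treated as padding, then flattened to line up with the folded embeddings. A toy sketch with made-up sizes:

```python
# Toy sketch: expanding the word-level mask to the char level and flattening it
# so it matches the folded char embeddings (one row per token).
import torch

batch, seq_len, char_len = 2, 4, 13
x_mask = torch.zeros(batch, seq_len, dtype=torch.uint8)
x_mask[:, 3] = 1   # last word of every row is padding

char_mask = x_mask.unsqueeze(2).repeat(1, 1, char_len)      # batch * len * char_len
char_mask = char_mask.reshape(batch * seq_len, char_len)    # one mask row per token
print(char_mask.shape)   # torch.Size([8, 13])
```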

spacy_tokenizer.py

Lines changed: 2 additions & 3 deletions

@@ -48,9 +48,9 @@ def chars(self, uncased=False):
             uncased: lower cases characters
         """
         if uncased:
-            return [c.lower() for t in self.data for c in t[self.CHAR]]
+            return [[c.lower() for c in t[self.CHAR]] for t in self.data]
         else:
-            return [c for t in self.data for c in t[self.CHAR]]
+            return [[c for c in t[self.CHAR]] for t in self.data]

     def words(self, uncased=False):
         """Returns a list of the text of each token
@@ -174,7 +174,6 @@ def tokenize(self, text):

             data.append((
                 tokens[i].text,
-                # tokens[i].text[0] if len(tokens[i].text) > 0 else '',
                 list(tokens[i].text),
                 text[start_ws: end_ws],
                 (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
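The effect of the chars() change is that callers now receive one character list per token instead of a single flat stream of characters. A toy illustration (not the tokenizer's real data layout):

```python
# Toy illustration of the new chars() shape: nested per-token lists instead of
# one flat character list. The tuples stand in for the tokenizer's token data.
CHAR = 0
data = [('Hello',), ('World',)]   # pretend t[CHAR] holds the token's characters

flat = [c for t in data for c in t[CHAR]]         # old: ['H', 'e', 'l', 'l', 'o', ...]
nested = [[c for c in t[CHAR]] for t in data]     # new: [['H', ...], ['W', ...]]
print(len(flat), len(nested))                     # 10 2
```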

utils.py

Lines changed: 6 additions & 5 deletions

@@ -141,11 +141,12 @@ def index_embedding_chars(char_embedding_file):
 def load_chars(args, examples):
     """Iterate and index all the chars in examples (documents + questions)."""
     def _insert(iterable):
-        for c in iterable:
-            c = Dictionary.normalize(c)
-            if valid_chars and c not in valid_chars:
-                continue
-            chars.add(c)
+        for cs in iterable:
+            for c in cs:
+                c = Dictionary.normalize(c)
+                if valid_chars and c not in valid_chars:
+                    continue
+                chars.add(c)

     if args.restrict_vocab and args.char_embedding_file:
         logger.info('Restricting to chars in %s' % args.char_embedding_file)
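Because chars() now returns nested lists, the indexing helper walks the inner lists. A toy run of that logic, with a lower-casing stand-in for Dictionary.normalize and no vocabulary restriction:

```python
# Toy run of the nested _insert logic; normalize() is replaced by a stand-in
# and valid_chars is empty, so no characters are filtered out.
chars = set()
valid_chars = set()
example_chars = [['H', 'i'], ['!']]   # one character list per token

for cs in example_chars:
    for c in cs:
        c = c.lower()                 # stand-in for Dictionary.normalize(c)
        if valid_chars and c not in valid_chars:
            continue
        chars.add(c)
print(sorted(chars))                  # ['!', 'h', 'i']
```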

vector.py

Lines changed: 14 additions & 8 deletions

@@ -19,9 +19,9 @@ def vectorize(ex, model, single_answer=False):

     # Index words
     document = torch.LongTensor([word_dict[w] for w in ex['document']])
-    document_char = torch.LongTensor([char_dict[c] for c in ex['document_char']])
+    document_char = [torch.LongTensor([char_dict[c] for c in cs]) for cs in ex['document_char']]
     question = torch.LongTensor([word_dict[w] for w in ex['question']])
-    question_char = torch.LongTensor([char_dict[c] for c in ex['question_char']])
+    question_char = [torch.LongTensor([char_dict[c] for c in cs]) for cs in ex['question_char']]

     # Create extra features vector
     if len(feature_dict) > 0:
@@ -120,8 +120,10 @@ def batchify(batch):

     # Batch documents and features
     max_length = max([d.size(0) for d in docs])
+    # max_char_length = max([c.size(0) for cs in doc_chars for c in cs])
+    max_char_length = 13
     x1 = torch.LongTensor(len(docs), max_length).zero_()
-    x1_c = torch.LongTensor(len(docs), max_length).zero_()
+    x1_c = torch.LongTensor(len(docs), max_length, max_char_length).zero_()
     x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1)
     if c_features[0] is None:
         x1_f = None
@@ -132,13 +134,15 @@ def batchify(batch):
         x1_mask[i, :d.size(0)].fill_(0)
         if x1_f is not None:
             x1_f[i, :d.size(0)].copy_(c_features[i])
-    for i, c in enumerate(doc_chars):
-        x1_c[i, :c.size(0)].copy_(c)
+    for i, cs in enumerate(doc_chars):
+        for j, c in enumerate(cs):
+            c_ = c[:max_char_length]
+            x1_c[i, j, :c_.size(0)].copy_(c_)

     # Batch questions
     max_length = max([q.size(0) for q in questions])
     x2 = torch.LongTensor(len(questions), max_length).zero_()
-    x2_c = torch.LongTensor(len(questions), max_length).zero_()
+    x2_c = torch.LongTensor(len(questions), max_length, max_char_length).zero_()
     x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1)
     if q_features[0] is None:
         x2_f = None
@@ -149,8 +153,10 @@ def batchify(batch):
         x2_mask[i, :d.size(0)].fill_(0)
         if x2_f is not None:
             x2_f[i, :d.size(0)].copy_(q_features[i])
-    for i, c in enumerate(question_chars):
-        x2_c[i, :c.size(0)].copy_(c)
+    for i, cs in enumerate(question_chars):
+        for j, c in enumerate(cs):
+            c_ = c[:max_char_length]
+            x2_c[i, j, :c_.size(0)].copy_(c_)

     # Maybe return without targets
     if len(batch[0]) == NUM_INPUTS + NUM_EXTRA:
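With per-token char index tensors, batchify packs them into a fixed batch x len x max_char_length tensor, truncating words longer than the (currently hard-coded) char length. A toy version of that loop with made-up values:

```python
# Toy version of the char-batching loop: variable-length char index tensors are
# copied into a fixed-size LongTensor, truncated to max_char_length.
import torch

doc_chars = [[torch.LongTensor([3, 7]), torch.LongTensor([5, 5, 5, 9])]]  # 1 doc, 2 words
max_length, max_char_length = 2, 3

x1_c = torch.LongTensor(len(doc_chars), max_length, max_char_length).zero_()
for i, cs in enumerate(doc_chars):
    for j, c in enumerate(cs):
        c_ = c[:max_char_length]              # truncate long words
        x1_c[i, j, :c_.size(0)].copy_(c_)
print(x1_c)   # tensor([[[3, 7, 0], [5, 5, 5]]])
```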
