Commit 2312470a authored by Liam Byrne

Revised post embedding builder with unixcoder

parent ed74f9cd
@@ -4,6 +4,7 @@ import pickle
import re
import sqlite3
from typing import List
import time
import pandas as pd
import torch
@@ -17,7 +18,7 @@ from post_embedding_builder import PostEmbedding
from static_graph_construction import StaticGraphConstruction
logging.basicConfig()
#logging.getLogger().setLevel(logging.ERROR)
logging.getLogger().setLevel(logging.DEBUG)
log = logging.getLogger("dataset")
@@ -71,32 +72,68 @@ class UserGraphDataset(Dataset):
"""
"""
log.info("Processing data...")
'''TIME START'''
t1 = time.time()
# Fetch the unprocessed questions and the next index to use.
unprocessed, idx = self.get_unprocessed_ids()
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % (self.get_unprocessed_ids.__name__, t2 - t1))
'''TIME START'''
t1 = time.time()
# Fetch questions from database.
valid_questions = self.fetch_questions_by_post_ids(unprocessed)
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % (self.fetch_questions_by_post_ids.__name__, t2 - t1))
for row in tqdm(valid_questions.itertuples(), total=len(valid_questions)):
'''TIME START'''
t1 = time.time()
# Build Question embedding
question_word_emb, question_code_emb, _ = self._post_embedding_builder(
row.question_body,
question_word_embs, question_code_embs, _ = self._post_embedding_builder(
[row.question_body],
use_bert=True,
title=row.question_title
title_batch=[row.question_title]
)
question_emb = torch.concat((question_word_emb, question_code_emb))
question_emb = torch.concat((question_word_embs[0], question_code_embs[0]))
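# question_emb concatenates the BERT CLS vector with the UniXcoder sentence vector (768 + 768 = 1536 dims, assuming the base-size checkpoints).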
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % ("Post embedding builder (question)", t2 - t1))
'''TIME START'''
t1 = time.time()
# Fetch answers to question
answers_to_question = self.fetch_answers_for_question(row.post_id)
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % (self.fetch_answers_for_question.__name__, t2 - t1))
# Build Answer embeddings
for _, answer_body, answer_user_id, score in answers_to_question.itertuples():
label = torch.tensor([1 if score > 0 else 0], dtype=torch.long)
answer_word_emb, answer_code_emb, _ = self._post_embedding_builder(
answer_body, use_bert=True
answer_word_embs, answer_code_embs, _ = self._post_embedding_builder(
[answer_body], use_bert=True, title_batch=[None]
)
answer_emb = torch.concat((answer_word_emb, answer_code_emb))
answer_emb = torch.concat((answer_word_embs[0], answer_code_embs[0]))
'''TIME START'''
t1 = time.time()
# Build graph
graph: HeteroData = self.construct_graph(answer_user_id)
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % (self.construct_graph.__name__, t2 - t1))
# pytorch geometric data object
graph.__setattr__('question_emb', question_emb)
graph.__setattr__('answer_emb', answer_emb)
......
@@ -59,6 +59,7 @@ def train(model, train_loader):
model.train()
for data in train_loader: # Iterate in batches over the training dataset.
print(data)
data = data.to(device)
out = model(data.x_dict, data.edge_index_dict, data.batch_dict, torch.concat([data.question_emb, data.answer_emb])) # Perform a single forward pass.
loss = criterion(torch.unsqueeze(out,0), data.label) # Compute the loss.
......
@@ -2,21 +2,21 @@ import ast
import io
import logging
import re
import time
import tokenize
from collections import namedtuple
from typing import List
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
log = logging.getLogger(__name__)
logging.getLogger().setLevel(logging.DEBUG)
log = logging.getLogger("PostEmbedding")
from bs4 import BeautifulSoup
import spacy
import torch
import torch.nn as nn
from torchtext.vocab import GloVe
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel, AutoTokenizer, AutoModel, AutoConfig
from unixcoder import UniXcoder
Import = namedtuple("Import", ["module", "name", "alias"])
Function = namedtuple("Function", ["function_name", "parameter_names"])
@@ -27,18 +27,20 @@ class PostEmbedding(nn.Module):
Torch module for transforming Stackoverflow posts into a torch tensor.
"""
def __init__(self):
def __init__(self, batched=False):
super().__init__()
log.info("PostEmbedding instantiated!")
self._batched = batched
# self._global_vectors = GloVe(name='840B', dim=300)
self._en = spacy.load('en_core_web_sm')
self._stopwords = self._en.Defaults.stop_words
self._bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
self._bert_model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
self._code_bert_tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
self._code_bert_model = AutoModel.from_pretrained("microsoft/codebert-base")
self._unixcoder = UniXcoder("microsoft/unixcoder-base")
self._codebert_tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base')
self._codebert_model = AutoModel.from_pretrained('microsoft/codebert-base')
def forward(self, html: str, use_bert: bool, title: str=None, flatten=True) -> torch.tensor:
def forward(self, html_batch: List[str], use_bert: bool, title_batch: List[str]) -> torch.tensor:
"""
@param html_batch: List of HTML strings, one per StackOverflow post body.
@param title_batch: List of question titles (None for posts without a title).
@@ -46,18 +48,48 @@ class PostEmbedding(nn.Module):
@return: Per-post paragraph embeddings, code embeddings and imported modules.
"""
soup = BeautifulSoup(html, 'lxml')
ps = self.get_paragraphs(soup, title)
soups = [BeautifulSoup(html, 'lxml') for html in html_batch]
assert len(soups) == len(title_batch)
paragraph_batches = [self.get_paragraphs(soup, not use_bert, title) for soup, title in zip(soups, title_batch)]
log.debug(f"Processing {len(html_batch)} posts")
'''TIME START'''
t1 = time.time()
if use_bert:
para_emb = self.to_bert_embedding(" ".join(ps))
if self._batched:
para_embs = self.to_bert_embedding([" ".join(ps) for ps in paragraph_batches])
else:
para_emb = self.to_glove_paragraph_embedding(ps)
modules, funcs = self.get_code(soup, get_imports_with_regex=True)
para_embs = []
for ps in paragraph_batches:
para_emb = self.to_bert_embedding([" ".join(ps)])
para_embs.append(torch.squeeze(para_emb))
else:
# para_emb = self.to_glove_paragraph_embedding(ps)
raise NotImplementedError("GloVe paragraph embedding needs to be refactored to work with batches")
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % ("Paragraph embedding", t2 - t1))
code_features = [self.get_code(soup, get_imports_with_regex=True) for soup in soups]
modules = [x[0] for x in code_features]
'''TIME START'''
t1 = time.time()
if self._batched:
code_embs = self.to_unixcode_embedding(["\n".join([x.get_text() for x in soup.find_all('code')]) for soup in soups])
else:
code_embs = []
for soup in soups:
code_emb = self.to_unixcode_embedding(["\n".join([x.get_text() for x in soup.find_all('code')])])
code_embs.append(torch.squeeze(code_emb))
code_bert = self.to_code_bert_embedding("\n".join([x.get_text() for x in soup.find_all('code')]))
'''TIME END'''
t2 = time.time()
log.debug("Function=%s, Time=%s" % ("CodeBERT embedding", t2 - t1))
return para_emb, code_bert, modules
return para_embs, code_embs, modules
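# Illustrative usage sketch (assumed; mirrors the batched call made in dataset.py above).
# The HTML bodies and titles are made-up example inputs; title_batch entries may be None for answers.
pe = PostEmbedding(batched=True)
word_embs, code_embs, modules = pe(
    ["<p>How do I read a CSV?</p><code>import pandas as pd</code>", "<p>Use pandas.read_csv.</p>"],
    use_bert=True,
    title_batch=["Reading CSV files in Python", None],
)
# word_embs: [2, 768] BERT CLS embeddings; code_embs: [2, hidden_size] UniXcoder embeddings;
# modules: per-post collections of Import tuples found in the <code> blocks.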
def preprocess(self, text: str) -> List[str]:
"""
@@ -68,17 +100,24 @@ class PostEmbedding(nn.Module):
tokens = [word.text for word in doc if not (word.is_stop or word.is_punct or word.like_num)]
return tokens
def get_paragraphs(self, soup: BeautifulSoup, title: str = None) -> List[str]:
def get_paragraphs(self, soup: BeautifulSoup, preprocess: bool, title: str = None) -> List[str]:
"""
@param soup: Post body HTML wrapped in a BeautifulSoup object.
@param title: If available, add title as a paragraph.
@return: List of tokens for each paragraph.
:param preprocess:
"""
if preprocess:
paras = [self.preprocess(x.get_text()) for x in soup.find_all('p')]
else:
paras = [[x.get_text()] for x in soup.find_all('p')]
# If title is available add it to the paragraphs
if title is not None:
if preprocess:
paras.append(self.preprocess(title))
else:
paras.append([title])
return [token for para in paras for token in para]
def get_code(self, soup: BeautifulSoup, get_imports_with_regex=False, get_functions_with_regex=False) -> (List[Import], List[Function]):
@@ -112,16 +151,27 @@ class PostEmbedding(nn.Module):
word_embeddings = self._global_vectors.get_vecs_by_tokens(tokens)
return torch.sum(word_embeddings, dim=0) / len(tokens)
def to_bert_embedding(self, text: str) -> torch.tensor:
# if not len(text):
# return torch.zeros(768)
encodings = self._bert_tokenizer(text, padding=True, truncation=True, return_tensors='pt', max_length=512)
def to_bert_embedding(self, texts: List[str]) -> torch.tensor:
encodings = self._bert_tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=512)
with torch.no_grad():
outputs = self._bert_model(**encodings)
last_layer = outputs.last_hidden_state
cls = last_layer[:, 0, :]
return torch.squeeze(cls) # Converts from dim [1, 768] to [768]
return cls  # Shape [batch_size, 768]: one CLS embedding per input text
def to_unixcode_embedding(self, code_batches: List[str]) -> torch.tensor:
"""
Embed a batch of code snippets with UniXcoder.
@param code_batches: List of code snippet strings, one per post.
@return: L2-normalized sentence embeddings of shape [batch_size, hidden_size].
"""
token_ids = self._unixcoder.tokenize(code_batches, max_length=512, mode="<encoder-only>")
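# tokenize() returns ragged lists of token ids; right-pad each one to the longest sequence in the batch with the pad token so they can be stacked into a single tensor below.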
longest_token_ids = max([len(x) for x in token_ids])
token_ids = [x + ([self._unixcoder.config.pad_token_id] * (longest_token_ids - len(x))) for x in token_ids]
source_ids = torch.tensor(token_ids)
tokens_embeddings, code_embeddings = self._unixcoder(source_ids)
normalized_code_emb = torch.nn.functional.normalize(code_embeddings, p=2, dim=1)
return normalized_code_emb
def to_code_bert_embedding(self, code):
"""
@@ -150,9 +200,9 @@ class PostEmbedding(nn.Module):
except IndentationError:
continue
nl_tokens = self._code_bert_tokenizer.tokenize(" ".join(comments))
nl_tokens = self._codebert_tokenizer.tokenize(" ".join(comments))
code_tokens = self._code_bert_tokenizer.tokenize("".join(source))
code_tokens = self._codebert_tokenizer.tokenize(" ".join(source))
# CodeBERT has a max token length of 512
while len(nl_tokens) + len(code_tokens) > 509:
@@ -161,12 +211,13 @@ class PostEmbedding(nn.Module):
else:
code_tokens = code_tokens[:-1]
tokens = [self._code_bert_tokenizer.cls_token] + nl_tokens + [self._code_bert_tokenizer.sep_token] + code_tokens + [self._code_bert_tokenizer.eos_token]
tokens_ids = self._code_bert_tokenizer.convert_tokens_to_ids(tokens)
log.debug(f"NL Tokens: {len(nl_tokens)} Code Tokens: {len(code_tokens)}")
emb = self._code_bert_model(torch.tensor(tokens_ids)[None,:])[0]
return emb.mean(dim=1).mean(dim=0)
tokens = [self._codebert_tokenizer.cls_token] + nl_tokens + [self._codebert_tokenizer.sep_token] + code_tokens + [self._codebert_tokenizer.eos_token]
tokens_ids = self._codebert_tokenizer.convert_tokens_to_ids(tokens)
emb = self._codebert_model(torch.tensor(tokens_ids)[None, :])[0]
return emb.mean(dim=1).mean(dim=0)
"""
Python RegEx methods
@@ -180,7 +231,6 @@ class PostEmbedding(nn.Module):
for module in list(set(re.findall(PATTERN, code_snippet, flags=re.MULTILINE))):
yield Import(module, None, None)
"""
Python Abstract Syntax Tree methods
"""
@@ -214,6 +264,12 @@ class PostEmbedding(nn.Module):
if __name__ == '__main__':
pe = PostEmbedding()
#print(pe.to_code_bert_embedding("\n".join(["for i in range(32):\n #return 6 or something\n"])).shape)
print(pe.to_bert_embedding("This is a test sentence.").shape)
'''TIME START'''
t1 = time.time()
for i in range(1):
a = (pe.to_unixcode_embedding(2 * ["\n".join(["for i in range(32):\n #return 6 or something\n"])]).shape)
b = (pe.to_bert_embedding(2 * ["This is a test sentence."]).shape)
# print([x.module for x in pe.get_imports_via_regex(BeautifulSoup("<code>import ast<\code>", 'lxml'))])
'''TIME END'''
t2 = time.time()
print("Function=%s, Time=%s" % ("embedding", t2 - t1))
@@ -36,24 +36,42 @@ class StaticGraphConstruction:
def process_questions(self, questions: pd.DataFrame) -> torch.Tensor:
for i, body, title, tags in questions[['Body', 'Title', 'Tags']].itertuples():
word_embedding, code_embedding, modules = StaticGraphConstruction.post_embedding_builder(body, self._use_bert, title)
modules = self.process_module_names(modules)
if not len(questions):
return None
word_emb_batches, code_emb_batches, module_name_batches = StaticGraphConstruction.post_embedding_builder(
questions['Body'], self._use_bert, questions['Title']
)
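# The embedding builder is now invoked once over the whole DataFrame (batched) rather than once per row; row_counter indexes back into the batched results while iterating below.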
row_counter = 0
for post_id, body, title, tags in questions[['Body', 'Title', 'Tags']].itertuples():
modules = self.process_module_names(module_name_batches[row_counter])
tag_list = self.parse_tag_list(tags)[:self._first_n_tags]
for tag in tag_list:
self._tag_to_question_edges.append((self._known_tags[tag], i))
self._tag_to_question_edges.append((self._known_tags[tag], post_id))
for module in modules:
self._module_to_question_edges.append((self._known_modules[module], i))
self._module_to_question_edges.append((self._known_modules[module], post_id))
yield torch.concat((word_embedding, code_embedding))
post_emb = torch.concat((word_emb_batches[row_counter], code_emb_batches[row_counter]))
row_counter += 1
yield post_emb
def process_answers(self, answers: pd.DataFrame) -> torch.Tensor:
if not len(answers):
return None
word_emb_batches, code_emb_batches, module_name_batches = StaticGraphConstruction.post_embedding_builder(
answers['Body'], self._use_bert, title_batch=answers['Title']
)
row_counter = 0
for i, body, title, tags in answers[['Body', 'Title', 'Tags']].itertuples():
word_embedding, code_embedding, modules = StaticGraphConstruction.post_embedding_builder(body, self._use_bert, title)
modules = self.process_module_names(modules)
modules = self.process_module_names(module_name_batches[row_counter])
tag_list = self.parse_tag_list(tags)[:self._first_n_tags]
for tag in tag_list:
@@ -62,12 +80,22 @@ class StaticGraphConstruction:
for module in modules:
self._module_to_answer_edges.append((self._known_modules[module], i))
yield torch.concat((word_embedding, code_embedding))
post_emb = torch.concat((word_emb_batches[row_counter], code_emb_batches[row_counter]))
row_counter += 1
yield post_emb
def process_comments(self, comments: pd.DataFrame) -> torch.Tensor:
if not len(comments):
return None
word_emb_batches, code_emb_batches, module_name_batches = StaticGraphConstruction.post_embedding_builder(
comments['Body'], self._use_bert, title_batch=[None for _ in range(len(comments))]
)
row_counter = 0
for i, body, tags in comments[['Body', 'Tags']].itertuples():
word_embedding, code_embedding, modules = StaticGraphConstruction.post_embedding_builder(body, self._use_bert)
modules = self.process_module_names(modules)
modules = self.process_module_names(module_name_batches[row_counter])
tag_list = self.parse_tag_list(tags)[:self._first_n_tags]
for tag in tag_list:
@@ -76,7 +104,9 @@ class StaticGraphConstruction:
for module in modules:
self._module_to_comment_edges.append((self._known_modules[module], i))
yield word_embedding
post_emb = word_emb_batches[row_counter]
row_counter += 1
yield post_emb
def process_tags(self):
if not len(self._known_tags):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
class UniXcoder(nn.Module):
def __init__(self, model_name):
"""
Build UniXcoder.
Parameters:
* `model_name`- huggingface model card name. e.g. microsoft/unixcoder-base
"""
super(UniXcoder, self).__init__()
self.tokenizer = RobertaTokenizer.from_pretrained(model_name)
self.config = RobertaConfig.from_pretrained(model_name)
self.config.is_decoder = True
self.model = RobertaModel.from_pretrained(model_name, config=self.config)
self.register_buffer("bias", torch.tril(torch.ones((1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024))
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
self.lm_head.weight = self.model.embeddings.word_embeddings.weight
self.lsm = nn.LogSoftmax(dim=-1)
self.tokenizer.add_tokens(["<mask0>"], special_tokens=True)
def tokenize(self, inputs, mode="<encoder-only>", max_length=512, padding=False):
"""
Convert string to token ids
Parameters:
* `inputs`- list of input strings.
* `max_length`- The maximum total source sequence length after tokenization.
* `padding`- whether to pad source sequence length to max_length.
* `mode`- which mode the sequence will use. i.e. <encoder-only>, <decoder-only>, <encoder-decoder>
"""
assert mode in ["<encoder-only>", "<decoder-only>", "<encoder-decoder>"]
assert max_length < 1024
tokenizer = self.tokenizer
tokens_ids = []
for x in inputs:
tokens = tokenizer.tokenize(x)
if mode == "<encoder-only>":
tokens = tokens[:max_length - 4]
tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens + [tokenizer.sep_token]
elif mode == "<decoder-only>":
tokens = tokens[-(max_length - 3):]
tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens
else:
tokens = tokens[:max_length - 5]
tokens = [tokenizer.cls_token, mode, tokenizer.sep_token] + tokens + [tokenizer.sep_token]
tokens_id = tokenizer.convert_tokens_to_ids(tokens)
if padding:
tokens_id = tokens_id + [self.config.pad_token_id] * (max_length - len(tokens_id))
tokens_ids.append(tokens_id)
return tokens_ids
def decode(self, source_ids):
""" Convert token ids to string """
predictions = []
for x in source_ids:
prediction = []
for y in x:
t = y.cpu().numpy()
t = list(t)
if 0 in t:
t = t[:t.index(0)]
text = self.tokenizer.decode(t, clean_up_tokenization_spaces=False)
prediction.append(text)
predictions.append(prediction)
return predictions
def forward(self, source_ids):
""" Obtain token embeddings and sentence embeddings """
mask = source_ids.ne(self.config.pad_token_id)
token_embeddings = self.model(source_ids, attention_mask=mask.unsqueeze(1) * mask.unsqueeze(2))[0]
sentence_embeddings = (token_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)
return token_embeddings, sentence_embeddings
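# Illustrative sketch (assumed usage, following the upstream UniXcoder examples):
# encoder-only sentence embeddings for two snippets, padded to max_length so they stack into one tensor.
model = UniXcoder("microsoft/unixcoder-base")
ids = model.tokenize(["return the maximum value", "def f(a, b): return a if a > b else b"],
                     max_length=512, mode="<encoder-only>", padding=True)
source_ids = torch.tensor(ids)
token_embs, sentence_embs = model(source_ids)  # sentence_embs: [2, hidden_size]
sentence_embs = torch.nn.functional.normalize(sentence_embs, p=2, dim=1)  # unit-length vectors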
def generate(self, source_ids, decoder_only=True, eos_id=None, beam_size=5, max_length=64):
""" Generate sequence given context (source_ids) """
# Set encoder mask attention matrix: bidirectional for <encoder-decoder>, unidirectional for <decoder-only>
if decoder_only:
mask = self.bias[:, :source_ids.size(-1), :source_ids.size(-1)]
else:
mask = source_ids.ne(self.config.pad_token_id)
mask = mask.unsqueeze(1) * mask.unsqueeze(2)
if eos_id is None:
eos_id = self.config.eos_token_id
device = source_ids.device
# Decoding using beam search
preds = []
zero = torch.LongTensor(1).fill_(0).to(device)
source_len = list(source_ids.ne(1).sum(-1).cpu().numpy())
length = source_ids.size(-1)
encoder_output = self.model(source_ids, attention_mask=mask)
for i in range(source_ids.shape[0]):
context = [[x[i:i + 1, :, :source_len[i]].repeat(beam_size, 1, 1, 1) for x in y]
for y in encoder_output.past_key_values]
beam = Beam(beam_size, eos_id, device)
input_ids = beam.getCurrentState().clone()
context_ids = source_ids[i:i + 1, :source_len[i]].repeat(beam_size, 1)
out = encoder_output.last_hidden_state[i:i + 1, :source_len[i]].repeat(beam_size, 1, 1)
for _ in range(max_length):
if beam.done():
break
if _ == 0:
hidden_states = out[:, -1, :]
out = self.lsm(self.lm_head(hidden_states)).data
beam.advance(out)
input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
input_ids = beam.getCurrentState().clone()
else:
length = context_ids.size(-1) + input_ids.size(-1)
out = self.model(input_ids, attention_mask=self.bias[:, context_ids.size(-1):length, :length],
past_key_values=context).last_hidden_state
hidden_states = out[:, -1, :]
out = self.lsm(self.lm_head(hidden_states)).data
beam.advance(out)
input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
input_ids = torch.cat((input_ids, beam.getCurrentState().clone()), -1)
hyp = beam.getHyp(beam.getFinal())
pred = beam.buildTargetTokens(hyp)[:beam_size]
pred = [torch.cat([x.view(-1) for x in p] + [zero] * (max_length - len(p))).view(1, -1) for p in pred]
preds.append(torch.cat(pred, 0).unsqueeze(0))
preds = torch.cat(preds, 0)
return preds
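# Illustrative sketch (assumed usage, following the upstream UniXcoder examples):
# beam-search a completion in <decoder-only> mode, then decode the predicted token ids back to text.
model = UniXcoder("microsoft/unixcoder-base")
context = "def f(data, file_path):\n    # write json data into file_path"
ids = model.tokenize([context], max_length=512, mode="<decoder-only>")
source_ids = torch.tensor(ids)
prediction_ids = model.generate(source_ids, decoder_only=True, beam_size=3, max_length=64)
completions = model.decode(prediction_ids)  # one list of beam_size candidate strings per input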
class Beam(object):
def __init__(self, size, eos, device):
self.size = size
self.device = device
# The score for each translation on the beam.
self.scores = torch.FloatTensor(size).zero_().to(device)
# The backpointers at each time-step.
self.prevKs = []
# The outputs at each time-step.
self.nextYs = [torch.LongTensor(size).fill_(0).to(device)]
# Has EOS topped the beam yet.
self._eos = eos
self.eosTop = False
# Time and k pair for finished.
self.finished = []
def getCurrentState(self):
"Get the outputs for the current timestep."
batch = self.nextYs[-1].view(-1, 1)
return batch
def getCurrentOrigin(self):
"Get the backpointers for the current timestep."
return self.prevKs[-1]
def advance(self, wordLk):
"""
Given prob over words for every last beam `wordLk` and attention
`attnOut`: Compute and update the beam search.
Parameters:
* `wordLk`- probs of advancing from the last step (K x words)
* `attnOut`- attention at the last step
Returns: True if beam search is complete.
"""
numWords = wordLk.size(1)
# Sum the previous scores.
if len(self.prevKs) > 0:
beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
# Don't let EOS have children.
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
beamLk[i] = -1e20
else:
beamLk = wordLk[0]
flatBeamLk = beamLk.view(-1)
bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
self.scores = bestScores
# bestScoresId is flattened beam x word array, so calculate which
# word and beam each score came from
prevK = torch.div(bestScoresId, numWords, rounding_mode="floor")
self.prevKs.append(prevK)
self.nextYs.append((bestScoresId - prevK * numWords))
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] == self._eos:
s = self.scores[i]
self.finished.append((s, len(self.nextYs) - 1, i))
# End condition is when top-of-beam is EOS and no global score.
if self.nextYs[-1][0] == self._eos:
self.eosTop = True
def done(self):
return self.eosTop and len(self.finished) >= self.size
def getFinal(self):
if len(self.finished) == 0:
self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
self.finished.sort(key=lambda a: -a[0])
if len(self.finished) != self.size:
unfinished = []
for i in range(self.nextYs[-1].size(0)):
if self.nextYs[-1][i] != self._eos:
s = self.scores[i]
unfinished.append((s, len(self.nextYs) - 1, i))
unfinished.sort(key=lambda a: -a[0])
self.finished += unfinished[:self.size - len(self.finished)]
return self.finished[:self.size]
def getHyp(self, beam_res):
"""
Walk back to construct the full hypothesis.
"""
hyps = []
for _, timestep, k in beam_res:
hyp = []
for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
hyp.append(self.nextYs[j + 1][k])
k = self.prevKs[j][k]
hyps.append(hyp[::-1])
return hyps
def buildTargetTokens(self, preds):
sentence = []
for pred in preds:
tokens = []
for tok in pred:
if tok == self._eos:
break
tokens.append(tok)
sentence.append(tokens)
return sentence