Commit 585a92bd authored by Liam Byrne

starting explainer tools

parent 560b70cb
Showing changed files with 1687 additions and 350 deletions
@@ -27,6 +27,8 @@ class GAT(torch.nn.Module):
         # 2. Readout layer
         x = self.pool(x, batch)  # [batch_size, hidden_channels]
+        x = torch.cat([x, post_emb], dim=1)
         # 3. Concatenate with post embedding
         #x = torch.cat((x, post_emb))
         # 4. Apply a final classifier.

@@ -41,9 +43,9 @@ def train(model, train_loader):
     for data in train_loader:  # Iterate in batches over the training dataset.
         print(data)
-        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, torch.concat([data.question_emb, data.answer_emb]))  # Perform a single forward pass.
+        out = model(data.x_dict, data.edge_index_dict, data.batch_dict, torch.cat([data.question_emb, data.answer_emb], dim=1))  # Perform a single forward pass.
         print(out, data.label)
-        loss = criterion(out, data.label)  # Compute the loss.
+        loss = criterion(out, torch.squeeze(data.label, -1))  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.
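
Both replaced calls are shape fixes. A minimal sketch of what each changes, with hypothetical batch and embedding sizes:

    import torch

    question_emb = torch.rand(4, 768)   # [batch_size, emb_dim]
    answer_emb = torch.rand(4, 768)

    # torch.concat with the default dim=0 stacks the batch into [8, 768];
    # dim=1 gives the per-sample concatenation [4, 1536] the classifier expects.
    post_emb = torch.cat([question_emb, answer_emb], dim=1)

    # Labels stored as [batch_size, 1] must become the 1-D class-index
    # vector [batch_size] that CrossEntropyLoss requires, hence the squeeze.
    label = torch.randint(0, 2, (4, 1))
    logits = torch.rand(4, 2)
    loss = torch.nn.CrossEntropyLoss()(logits, torch.squeeze(label, -1))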
(binary files added or changed; no preview for this file type)
@@ -306,7 +306,6 @@ if __name__ == '__main__':
             print(confusion_matrix)
             if SAVE_CHECKPOINTS:
-                print("HELLO")
                 torch.save(model.state_dict(), f"../models/baseline-model-{epoch}.pt")
             # log evaluation results to wandb
@@ -41,7 +41,7 @@ class UserGraphDataset(Dataset):
     @property
     def processed_file_names(self):
         if self._skip_processing:
-            return os.listdir("../data/processed")
+            return os.listdir(os.path.join(self.root, "processed"))
         return []

     def download(self):
@@ -54,7 +54,7 @@ class UserGraphDataset(Dataset):
         processed = []
         max_idx = -1
-        for f in os.listdir("../data/processed"):
+        for f in os.listdir(os.path.join(self.root, "processed")):
             question_id_search = re.search(r"id_(\d+)", f)
             if question_id_search:
                 processed.append(int(question_id_search.group(1)))
@@ -72,81 +72,53 @@ class UserGraphDataset(Dataset):
     def process(self):
         """
         """
         log.info("Processing data...")
-        '''TIME START'''
-        t1 = time.time()
         # Fetch the unprocessed questions and the next index to use.
         unprocessed, idx = self.get_unprocessed_ids()
-        '''TIME END'''
-        t2 = time.time()
-        log.debug("Function=%s, Time=%s" % (self.get_unprocessed_ids.__name__, t2 - t1))
-        '''TIME START'''
-        t1 = time.time()
         # Fetch questions from database.
         valid_questions = self.fetch_questions_by_post_ids(unprocessed)
-        '''TIME END'''
-        t2 = time.time()
-        log.debug("Function=%s, Time=%s" % (self.fetch_questions_by_post_ids.__name__, t2 - t1))
-        for row in tqdm(valid_questions.itertuples(), total=len(valid_questions)):
-            '''TIME START'''
-            t1 = time.time()
+        for _, question in tqdm(valid_questions.iterrows(), total=len(valid_questions)):
             # Build Question embedding
             question_word_embs, question_code_embs, _ = self._post_embedding_builder(
-                [row.question_body],
+                [question["Body"]],
                 use_bert=True,
-                title_batch=[row.question_title]
+                title_batch=[question["Title"]]
             )
             question_emb = torch.concat((question_word_embs[0], question_code_embs[0]))
-            '''TIME END'''
-            t2 = time.time()
-            log.debug("Function=%s, Time=%s" % ("Post embedding builder (question)", t2 - t1))
-            '''TIME START'''
-            t1 = time.time()
             # Fetch answers to question
-            answers_to_question = self.fetch_answers_for_question(row.post_id)
-            '''TIME END'''
-            t2 = time.time()
-            log.debug("Function=%s, Time=%s" % (self.fetch_answers_for_question.__name__, t2 - t1))
+            answers_to_question = self.fetch_answers_for_question(question["PostId"])
             # Build Answer embeddings
-            for _, answer_body, answer_user_id, score in answers_to_question.itertuples():
-                label = torch.tensor([1 if score > 0 else 0], dtype=torch.long)
+            for _, answer in answers_to_question.iterrows():
+                label = torch.tensor([1 if answer["Score"] > 0 else 0], dtype=torch.long)
                 answer_word_embs, answer_code_embs, _ = self._post_embedding_builder(
-                    [answer_body], use_bert=True, title_batch=[None]
+                    [answer["Body"]], use_bert=True, title_batch=[None]
                 )
                 answer_emb = torch.concat((answer_word_embs[0], answer_code_embs[0]))
-                '''TIME START'''
-                t1 = time.time()
                 # Build graph
-                graph: HeteroData = self.construct_graph(answer_user_id)
-                '''TIME END'''
-                t2 = time.time()
-                log.debug("Function=%s, Time=%s" % (self.construct_graph.__name__, t2 - t1))
+                graph: HeteroData = self.construct_graph(answer["OwnerUserId"])
                 # pytorch geometric data object
                 graph.__setattr__('question_emb', question_emb)
                 graph.__setattr__('answer_emb', answer_emb)
+                graph.__setattr__('score', answer["Score"])
+                graph.__setattr__('question_id', question["PostId"])
+                graph.__setattr__('answer_id', answer["PostId"])
                 graph.__setattr__('label', label)
-                torch.save(graph, os.path.join(self.processed_dir, f'data_{idx}_question_id_{row.post_id}'))
+                torch.save(graph, os.path.join(self.processed_dir, f'data_{idx}_question_id_{question["PostId"]}_answer_id_{answer["PostId"]}.pt'))
                 idx += 1
     def len(self):
         return len(self.processed_file_names) - 2

     def get(self, idx):
-        file_name = [filename for filename in os.listdir('../data/processed/') if filename.startswith(f"data_{idx}")]
+        file_name = [filename for filename in os.listdir(os.path.join(self.root, 'processed')) if filename.startswith(f"data_{idx}")]
         if len(file_name):
             data = torch.load(os.path.join(self.processed_dir, file_name[0]))
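
The saved filenames now end in ".pt" and encode both question and answer ids, while get() still looks a sample up by its running-index prefix. A sketch of that lookup (find_sample is a hypothetical helper, not part of this commit):

    import os

    def find_sample(processed_dir: str, idx: int) -> str:
        # Files look like: data_7_question_id_101_answer_id_555.pt
        # The trailing underscore keeps "data_1_" from matching "data_10_...".
        matches = [f for f in os.listdir(processed_dir) if f.startswith(f"data_{idx}_")]
        if not matches:
            raise FileNotFoundError(f"no processed graph for index {idx}")
        return os.path.join(processed_dir, matches[0])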
@@ -160,10 +132,9 @@ class UserGraphDataset(Dataset):
     def fetch_questions_by_post_ids(self, post_ids: List[int]):
         questions_df = pd.read_sql_query(f"""
-            SELECT PostId, Body, Title, OwnerUserId FROM Post
+            SELECT * FROM Post
             WHERE PostId IN ({','.join([str(x) for x in post_ids])})
         """, self._db)
-        questions_df.columns = ['post_id', 'question_body', 'question_title', 'question_user_id']
         return questions_df

     def fetch_questions_by_user(self, user_id: int):
@@ -187,14 +158,28 @@ class UserGraphDataset(Dataset):
         return answers_df

     def fetch_answers_for_question(self, question_post_id: int):
+        """
+        Fetch answers for a question for P@1 evaluation
+        """
         answers_df = pd.read_sql_query(f"""
-            SELECT Body, OwnerUserId, Score
+            SELECT *
             FROM Post
             WHERE ParentId = {question_post_id}
         """, self._db)
-        answers_df = answers_df.dropna()
+        answers_df = answers_df.dropna(subset=['PostId', 'Body', 'Score', 'OwnerUserId'])
         return answers_df
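
The subset argument matters because the query now returns every Post column, and answer rows always carry NULLs in question-only columns. A toy illustration with hypothetical values:

    import pandas as pd

    answers = pd.DataFrame({
        "PostId": [1, 2],
        "Body": ["<p>Use a dict.</p>", "<p>Use a list.</p>"],
        "Score": [3, -1],
        "OwnerUserId": [10, 11],
        "AcceptedAnswerId": [None, None],  # NULL on every answer row
    })

    print(len(answers.dropna()))  # 0: any NULL anywhere drops the row
    print(len(answers.dropna(subset=["PostId", "Body", "Score", "OwnerUserId"])))  # 2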
+    def fetch_questions_by_post_ids_eval(self, post_ids: List[int]):
+        """
+        Fetch questions for P@1 evaluation
+        """
+        questions_df = pd.read_sql_query(f"""
+            SELECT * FROM Post
+            WHERE PostId IN ({','.join([str(x) for x in post_ids])})
+        """, self._db)
+        questions_df.columns = ['post_id', 'question_body', 'question_title', 'question_user_id']
+        return questions_df
+
     def fetch_comments_by_user(self, user_id: int):
         comments_on_questions_df = pd.read_sql_query(f"""
             SELECT A.Tags, B.*
@@ -215,6 +200,16 @@ class UserGraphDataset(Dataset):
         return pd.concat([comments_on_questions_df, comments_on_answers_df])

+    def fetch_tags_for_question(self, question_post_id: int):
+        tags_df = pd.read_sql_query(f"""
+            SELECT Tags
+            FROM Post
+            WHERE PostId = {question_post_id}
+        """, self._db)
+        if len(tags_df) == 0:
+            return []
+        return tags_df.iloc[0]['Tags'][1:-1].split("><")
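
The new method unpacks Stack Overflow's packed tag format, where a question's tags are stored as one string of angle-bracketed names. For example (value hypothetical):

    raw_tags = "<python><pandas><numpy>"
    print(raw_tags[1:-1].split("><"))  # ['python', 'pandas', 'numpy']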
     def construct_graph(self, user_id: int):
         graph_constructor = StaticGraphConstruction()
         qs = self.fetch_questions_by_user(user_id)
@@ -232,7 +227,7 @@ if __name__ == '__main__':
     '''
-    ds = UserGraphDataset('../data/', db_address='../stackoverflow.db', skip_processing=False)
+    ds = UserGraphDataset('../datav2/', db_address='../stackoverflow.db', skip_processing=False)
     data = ds.get(1078)
     print("Question ndim:", data.x_dict['question'].shape)
     print("Answer ndim:", data.x_dict['answer'].shape)
@@ -30,7 +30,7 @@ class UserGraphDatasetInMemory(InMemoryDataset):
     @property
     def raw_file_names(self):
-        return [x for x in os.listdir("../data/processed") if x not in ['pre_filter.pt', 'pre_transform.pt']]
+        return [x for x in os.listdir(os.path.join(self.root, "processed")) if x not in ['pre_filter.pt', 'pre_transform.pt']]

     @property
     def processed_file_names(self):
@@ -59,10 +59,10 @@ class UserGraphDatasetInMemory(InMemoryDataset):
         """
         """

-def fetch_question_ids() -> List[int]:
+def fetch_question_ids(root) -> List[int]:
     question_ids = set()
     # Split by question ids
-    for f in os.listdir("../data/processed"):
+    for f in os.listdir(os.path.join(root, "processed")):
         question_id_search = re.search(r"id_(\d+)", f)
         if question_id_search:
             question_ids.add(int(question_id_search.group(1)))
@@ -71,8 +71,8 @@ def fetch_question_ids() -> List[int]:
 def split(a, n):
     k, m = divmod(len(a), n)
     return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))

-def create_datasets_for_kfolds(folds):
-    question_ids = fetch_question_ids()
+def create_datasets_for_kfolds(folds, root):
+    question_ids = fetch_question_ids(root)
     random.shuffle(question_ids)
     folds = list(split(question_ids, folds))
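
split(a, n) produces n chunks whose sizes differ by at most one, which keeps the k folds balanced. For example:

    def split(a, n):
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))

    print(list(split(list(range(10)), 3)))  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]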
@@ -81,7 +81,8 @@ def create_datasets_for_kfolds(folds):
 def create_train_test_datasets():
-    question_ids = fetch_question_ids()
+    question_ids = fetch_question_ids(ROOT)
     # question_ids = list(question_ids)[:len(question_ids)* 0.6]
     train_ids = list(question_ids)[:int(len(question_ids) * 0.7)]
@@ -90,15 +91,16 @@ def create_train_test_datasets():
     log.info(f"Training question count {len(train_ids)}")
     log.info(f"Testing question count {len(test_ids)}")

-    train_dataset = UserGraphDatasetInMemory('../data/', f'train-{len(train_ids)}-qs.pt', train_ids)
-    test_dataset = UserGraphDatasetInMemory('../data/', f'test-{len(test_ids)}-qs.pt', test_ids)
+    train_dataset = UserGraphDatasetInMemory(ROOT, f'train-{len(train_ids)}-qs.pt', train_ids)
+    test_dataset = UserGraphDatasetInMemory(ROOT, f'test-{len(test_ids)}-qs.pt', test_ids)

     return train_dataset, test_dataset


 if __name__ == '__main__':
+    ROOT = "../datav2/"
     choice = input("1. Create train/test datasets\n2. Create k-fold datasets\n>>>")
     if choice == '1':
         create_train_test_datasets()
     elif choice == '2':
         n = int(input("Enter number of folds: "))
-        folds = list(create_datasets_for_kfolds(n))
+        folds = list(create_datasets_for_kfolds(n, ROOT))
(two further diffs collapsed; one source diff too large to display)
@@ -29,17 +29,17 @@ class BatchedHeteroData(HeteroData):
         return super().__cat_dim__(key, value)


 class StaticGraphConstruction:
-    def __init__(self):
-        # PostEmbedding is costly to instantiate in each StaticGraphConstruction instance.
-        post_embedding_builder = PostEmbedding()
+    tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-7_5mil-50d-63653-3.pt", embedding_dim=50, vocab_size=63654, context_length=3)
+    module_embedding_model = ModuleEmbeddingTrainer.load_model("../models/module-emb-1milx5-30d-49911.pt", embedding_dim=30, vocab_size=49911)
+    post_embedding_builder = PostEmbedding()
+
+    def __init__(self):
+        # PostEmbedding is costly to instantiate in each StaticGraphConstruction instance.
         self._known_tags = {}  # tag_name -> index
         self._known_modules = {}  # module_name -> index
         self._data = BatchedHeteroData()
-        self._first_n_tags = 3
+        self._first_n_tags = 8
         self._tag_to_question_edges = []
         self._tag_to_answer_edges = []
@@ -49,7 +49,7 @@ class StaticGraphConstruction:
         self._module_to_answer_edges = []
         self._module_to_comment_edges = []
         self._use_bert = True
-        self._post_count_limit = 10
+        self._post_count_limit = 20
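
Moving the embedding models from __init__ to class-level attributes means they are loaded once, when the class body is executed, and shared by every StaticGraphConstruction instance. A minimal sketch of the pattern (names hypothetical):

    class ExpensiveModel:  # stand-in for the tag/module embedding models
        pass

    class Construction:
        shared_model = ExpensiveModel()  # evaluated once, at class definition time

        def __init__(self):
            self._cache = {}  # cheap per-instance state still lives in __init__

    a, b = Construction(), Construction()
    assert a.shared_model is b.shared_model  # one object, shared by all instances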
(one further diff collapsed)