diff --git a/__pycache__/custom_logger.cpython-38.pyc b/__pycache__/custom_logger.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd1b7e74e21f2765fb90fb38781bc9be52b35809 Binary files /dev/null and b/__pycache__/custom_logger.cpython-38.pyc differ diff --git a/embeddings/GAT.py b/archived/GAT.py similarity index 100% rename from embeddings/GAT.py rename to archived/GAT.py diff --git a/embeddings/db_handler.py b/archived/db_handler.py similarity index 100% rename from embeddings/db_handler.py rename to archived/db_handler.py diff --git a/embeddings/embeddings.ipynb b/archived/embeddings.ipynb similarity index 100% rename from embeddings/embeddings.ipynb rename to archived/embeddings.ipynb diff --git a/embeddings/pyg_construction_demo.ipynb b/archived/pyg_construction_demo.ipynb similarity index 100% rename from embeddings/pyg_construction_demo.ipynb rename to archived/pyg_construction_demo.ipynb diff --git a/embeddings/tag_embedding_trainer.py b/archived/tag_embedding_trainer.py similarity index 100% rename from embeddings/tag_embedding_trainer.py rename to archived/tag_embedding_trainer.py diff --git a/custom_logger.py b/custom_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..428f5a43902b960d46ee84d8064d47173258d58c --- /dev/null +++ b/custom_logger.py @@ -0,0 +1,16 @@ +import logging +import sys + + +def setup_custom_logger(name, level): + formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s', + datefmt='%Y-%m-%d %H:%M:%S') + screen_handler = logging.StreamHandler(stream=sys.stdout) + screen_handler.setFormatter(formatter) + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(level) + + logger.handlers.clear() + logger.addHandler(screen_handler) + return logger diff --git a/embeddings/__pycache__/dataset.cpython-38.pyc b/embeddings/__pycache__/dataset.cpython-38.pyc index 05d3eee3ff36256b85b90bfb9541b5349d0b6872..8f00b97e8ca224ab5e5ca43a353fac98b07d13c1 100644 Binary files a/embeddings/__pycache__/dataset.cpython-38.pyc and b/embeddings/__pycache__/dataset.cpython-38.pyc differ diff --git a/embeddings/__pycache__/dataset_in_memory.cpython-38.pyc b/embeddings/__pycache__/dataset_in_memory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f733627988e7838d5f142171ab9486d0d03313e1 Binary files /dev/null and b/embeddings/__pycache__/dataset_in_memory.cpython-38.pyc differ diff --git a/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc b/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc index a5e9dda385534af7818ba061bb26aabf70615bf9..9d69ec7710c3ec43bf3d6926c232566a68e672bc 100644 Binary files a/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc and b/embeddings/__pycache__/post_embedding_builder.cpython-38.pyc differ diff --git a/embeddings/__pycache__/static_graph_construction.cpython-38.pyc b/embeddings/__pycache__/static_graph_construction.cpython-38.pyc index 25ddf94535803d6606156bd35f630605d2bf1fbe..5eae8edcf42817130d409e7c8ba9f90b9149efdb 100644 Binary files a/embeddings/__pycache__/static_graph_construction.cpython-38.pyc and b/embeddings/__pycache__/static_graph_construction.cpython-38.pyc differ diff --git a/embeddings/dataset.py b/embeddings/dataset.py index 319af3ea3f12a3cae30ee59ca39c13bcb09c910b..165ccbfd85e18e9168b0459904b992ac91a597f7 100644 --- a/embeddings/dataset.py +++ b/embeddings/dataset.py @@ -13,15 +13,15 @@ from torch_geometric.data import Dataset, download_url, Data, HeteroData from torch_geometric.data.hetero_data import NodeOrEdgeStorage from tqdm import tqdm import warnings + +from custom_logger import setup_custom_logger + warnings.filterwarnings('ignore', category=MarkupResemblesLocatorWarning) from post_embedding_builder import PostEmbedding from static_graph_construction import StaticGraphConstruction -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) -log = logging.getLogger("dataset") - +log = setup_custom_logger('dataset', logging.INFO) class UserGraphDataset(Dataset): def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, db_address:str=None, skip_processing=False): diff --git a/embeddings/dataset_in_memory.py b/embeddings/dataset_in_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..153055e362ba7a2164820943821e540113ece47b --- /dev/null +++ b/embeddings/dataset_in_memory.py @@ -0,0 +1,51 @@ +import logging +import os + +import torch +from torch_geometric.data import InMemoryDataset + +from custom_logger import setup_custom_logger + +log = setup_custom_logger('in-memory-dataset', logging.INFO) + +class UserGraphDatasetInMemory(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): + super().__init__(root, transform, pre_transform, pre_filter) + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_dir(self): + return os.path.join(self.root, 'processed_in_memory') + + @property + def raw_dir(self): + return os.path.join(self.root, 'processed') + + @property + def raw_file_names(self): + return [x for x in os.listdir("../data/processed") if x not in ['pre_filter.pt', 'pre_transform.pt']] + + @property + def processed_file_names(self): + return ['in-memory-dataset.pt'] + + def download(self): + pass + + def process(self): + # Read data into huge `Data` list. + data_list = [] + + for f in self.raw_file_names: + data = torch.load(os.path.join(self.raw_dir, f)) + data_list.append(data) + + data, slices = self.collate(data_list) + torch.save((data, slices), os.path.join(self.processed_dir, self.processed_file_names[0])) + + + +if __name__ == '__main__': + dataset = UserGraphDatasetInMemory('../data/') + + print(dataset.get(3)) \ No newline at end of file diff --git a/embeddings/hetero_GAT.py b/embeddings/hetero_GAT.py index 98ca211618753b93a0ea461364f516558adc3251..f4c9d33bf4a99e9a85668c2152832dd52414b4ab 100644 --- a/embeddings/hetero_GAT.py +++ b/embeddings/hetero_GAT.py @@ -12,12 +12,13 @@ from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool import wandb from torch_geometric.utils import to_networkx +from custom_logger import setup_custom_logger from dataset import UserGraphDataset +from dataset_in_memory import UserGraphDatasetInMemory from Visualize import GraphVisualization -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) -log = logging.getLogger('heterogeneous-GAT-model') +log = setup_custom_logger("heterogenous_GAT_model", logging.INFO) + class HeteroGNN(torch.nn.Module): def __init__(self, hidden_channels, out_channels, num_layers): @@ -54,7 +55,6 @@ class HeteroGNN(torch.nn.Module): if len(x): outs.append(global_mean_pool(x, batch=batch, size=len(post_emb))) else: - #print("EMPTY") outs.append(torch.zeros(1, x.size(-1))) #print([x.shape for x in outs]) @@ -75,7 +75,6 @@ def train(model, train_loader): model.train() for i, data in enumerate(train_loader): # Iterate in batches over the training dataset. - #print(data) data = data.to(device) optimizer.zero_grad() # Clear gradients. @@ -121,7 +120,6 @@ def test(loader): for pred, label in zip(pred, torch.squeeze(data.label, -1)): table.add_data(graph_html, label, pred) - #print("PRED", predictions, true_labels) return accuracy_score(true_labels, predictions), f1_score(true_labels, predictions), loss_ / len(loader), table def create_graph_vis(graph): @@ -146,9 +144,9 @@ def init_wandb(project_name: str, dataset): n_edges = graph.num_edges label = graph.label.item() - graph_vis = plotly.io.to_html(fig, full_html=False) + #graph_vis = plotly.io.to_html(fig, full_html=False) - table.add_data(graph_vis, n_nodes, n_edges, label) + table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label) wandb.log({"data": table}) # Log the dataset to W&B as an artifact. @@ -159,17 +157,30 @@ def init_wandb(project_name: str, dataset): # End the W&B run wandb.finish() -def start_wandb_for_training(wandb_project_name: str): - wandb.init(project=wandb_project_name) - wandb.use_artifact("static-graphs:latest") +def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str): + wandb.init(project=wandb_project_name, name=wandb_run_name) + #wandb.use_artifact("static-graphs:latest") if __name__ == '__main__': device = torch.device("cuda" if torch.cuda.is_available() else "cpu") log.info(f"Proceeding with {device} . .") + in_memory_dataset = False # Datasets - dataset = UserGraphDataset(root="../data", skip_processing=True) - train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.6, 0.1, 0.3]) + if in_memory_dataset: + dataset = UserGraphDatasetInMemory(root="../data") + else: + dataset = UserGraphDataset(root="../data", skip_processing=True) + + + train_size = int(0.7 * len(dataset)) + val_size = int(0.1 * len(dataset)) + test_size = len(dataset) - (train_size + val_size) + + log.info(f"Train Dataset Size: {train_size}") + log.info(f"Validation Dataset Size: {val_size}") + log.info(f"Test Dataset Size: {test_size}") + train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size]) # Weights&Biases dashboard data_details = { @@ -182,30 +193,35 @@ if __name__ == '__main__': wandb_project_name = "heterogeneous-GAT-model" if setup_wandb: init_wandb(wandb_project_name, dataset) - use_wandb = True + use_wandb = False if use_wandb: wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}" start_wandb_for_training(wandb_project_name, wandb_run_name) - # Class weights - train_labels = [x.label for x in train_dataset] - counts = [train_labels.count(x) for x in [0,1]] - class_weights = [1 - (x / sum(counts)) for x in counts] - sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels)) + + calculate_class_weights = False + #Class weights + sampler = None + if calculate_class_weights: + log.info(f"Calculating class weights") + train_labels = [x.label for x in train_dataset] + counts = [train_labels.count(x) for x in [0,1]] + class_weights = [1 - (x / sum(counts)) for x in counts] + sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels)) # Dataloaders - train_loader = DataLoader(train_dataset, sampler=None, batch_size=16) - val_loader = DataLoader(val_dataset, batch_size=1) - test_loader = DataLoader(test_dataset, batch_size=1) + train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=512) + val_loader = DataLoader(val_dataset, batch_size=16) + test_loader = DataLoader(test_dataset, batch_size=16) # Model - model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3) - model.to(device) + model = HeteroGNN(hidden_channels=64, out_channels=2, num_layers=3).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = torch.nn.CrossEntropyLoss() for epoch in range(1, 10): + log.info(f"Epoch: {epoch:03d} > > >") train(model, train_loader) train_acc, train_f1, train_loss, train_table = test(train_loader) val_acc, val_f1, val_loss, val_table = test(val_loader) diff --git a/embeddings/post_embedding_builder.py b/embeddings/post_embedding_builder.py index 2a3674ca5c28ed0f5c8e5aea8e32c670639a84ba..29fa4ff950dc933323b5d8a2e3ecd6877da7e3cc 100644 --- a/embeddings/post_embedding_builder.py +++ b/embeddings/post_embedding_builder.py @@ -9,17 +9,18 @@ from typing import List from NextTagEmbedding import NextTagEmbedding, NextTagEmbeddingTrainer -logging.basicConfig() -logging.getLogger().setLevel(logging.DEBUG) -log = logging.getLogger("PostEmbedding") - from bs4 import BeautifulSoup import spacy import torch import torch.nn as nn from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel, AutoTokenizer, AutoModel, AutoConfig + +from custom_logger import setup_custom_logger from unixcoder import UniXcoder +log = setup_custom_logger('post_embedding_builder', logging.INFO) + + Import = namedtuple("Import", ["module", "name", "alias"]) Function = namedtuple("Function", ["function_name", "parameter_names"]) diff --git a/embeddings/static_graph_construction.py b/embeddings/static_graph_construction.py index e8e9f3baf6e6b3ef13b6154a7956e6b9fcbfd958..eefa12f9eee0de73b1c811c3e709be5ce2d02a30 100644 --- a/embeddings/static_graph_construction.py +++ b/embeddings/static_graph_construction.py @@ -30,12 +30,11 @@ class BatchedHeteroData(HeteroData): class StaticGraphConstruction: - # PostEmbedding is costly to instantiate in each StaticGraphConstruction instance. - post_embedding_builder = PostEmbedding() - tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-7_5mil-50d-63653-3.pt", embedding_dim=50, vocab_size=63654, context_length=3) - module_embedding_model = ModuleEmbeddingTrainer.load_model("../models/module-emb-1milx5-30d-49911.pt", embedding_dim=30, vocab_size=49911) - def __init__(self): + # PostEmbedding is costly to instantiate in each StaticGraphConstruction instance. + post_embedding_builder = PostEmbedding() + tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-7_5mil-50d-63653-3.pt", embedding_dim=50, vocab_size=63654, context_length=3) + module_embedding_model = ModuleEmbeddingTrainer.load_model("../models/module-emb-1milx5-30d-49911.pt", embedding_dim=30, vocab_size=49911) self._known_tags = {} # tag_name -> index self._known_modules = {} # module_name -> index self._data = BatchedHeteroData() diff --git a/embeddings/tag_embedding_models/10mil_500d_embd.pt b/embeddings/tag_embedding_models/10mil_500d_embd.pt deleted file mode 100644 index 91881da44787cc40eec3e91f2139414a0b7258ba..0000000000000000000000000000000000000000 Binary files a/embeddings/tag_embedding_models/10mil_500d_embd.pt and /dev/null differ diff --git a/embeddings/tag_embedding_models/10mil_embd.pt b/embeddings/tag_embedding_models/10mil_embd.pt deleted file mode 100644 index 46c0789a42dc92fcdf73ae2286c1bb88172de0c0..0000000000000000000000000000000000000000 Binary files a/embeddings/tag_embedding_models/10mil_embd.pt and /dev/null differ diff --git a/embeddings/tag_embedding_models/1mil_750d_embd.pt b/embeddings/tag_embedding_models/1mil_750d_embd.pt deleted file mode 100644 index b266910b1bd67b214f235a5df5aba06707098871..0000000000000000000000000000000000000000 Binary files a/embeddings/tag_embedding_models/1mil_750d_embd.pt and /dev/null differ diff --git a/embeddings/user-graph.svg b/figures/user-graph.svg similarity index 100% rename from embeddings/user-graph.svg rename to figures/user-graph.svg