Skip to content
Snippets Groups Projects
Commit 97335f37 authored by Liam Byrne's avatar Liam Byrne
Browse files

refactoring

parent 44c05a8b
No related branches found
No related tags found
No related merge requests found
Showing
with 13929 additions and 13922 deletions
......@@ -10,3 +10,4 @@ runs/
embeddings/tag_embedding_models/*
*.pt
embeddings/wandb/*
*.pkl
{
"python.pythonPath": "\\\\filestore.soton.ac.uk\\users\\lhb1g20\\.conda\\envs\\gpuservice\\python.exe",
"terminal.integrated.windowsEnableConpty": false
}
\ No newline at end of file
No preview for this file type
File added
File added
......@@ -16,6 +16,7 @@ class UserGraphDatasetInMemory(InMemoryDataset):
self._question_ids = question_ids
super().__init__(root, transform, pre_transform, pre_filter)
self.data, self.slices = torch.load(self.processed_paths[0])
# Remove gradient requirements
self.data = self.data.apply(lambda x: x.detach())
@property
......
import torch
import wandb
import os
from custom_logger import setup_custom_logger
import logging
log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
def create_graph_vis(graph):
g = to_networkx(graph.to_homogeneous())
pos = nx.spring_layout(g)
vis = GraphVisualization(
g, pos, node_text_position='top left', node_size=20,
)
fig = vis.create_figure()
return fig
"""
Weights & Biases dashboard
"""
def init_wandb(project_name: str, dataset, data_details):
wandb.init(project=project_name, name="setup")
# Log all the details about the data to W&B.
wandb.log(data_details)
# Log exploratory visualizations for each data point to W&B
table = wandb.Table(columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
for graph in dataset:
fig = create_graph_vis(graph)
n_nodes = graph.num_nodes
n_edges = graph.num_edges
label = graph.label.item()
# graph_vis = plotly.io.to_html(fig, full_html=False)
table.add_data(wandb.Plotly(fig), n_nodes, n_edges, label)
wandb.log({"data": table})
# Log the dataset to W&B as an artifact.
dataset_artifact = wandb.Artifact(name="static-graphs", type="dataset", metadata=data_details)
dataset_artifact.add_dir("../data/")
wandb.log_artifact(dataset_artifact)
# End the W&B run
wandb.finish()
def start_wandb_for_training(wandb_project_name: str, wandb_run_name: str):
wandb.init(project=wandb_project_name, name=wandb_run_name)
# wandb.use_artifact("static-graphs:latest")
def add_cm_to_wandb(test_info):
wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
def log_results_to_wandb(results_map, results_name: str):
wandb.log({
f"{results_name}/loss": results_map["loss"],
f"{results_name}/accuracy": results_map["accuracy"],
f"{results_name}/f1": results_map["f1-score"],
f"{results_name}/table": results_map["table"]
})
"""
PyTorch helpers
"""
def save_model(model, model_name: str):
torch.save(model.state_dict(), os.path.join("..", "models", model_name))
def split_test_train_pytorch(dataset, train_split):
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - (train_size)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
return train_dataset, test_dataset
def calculate_class_weights(dataset):
# Class weights
log.info(f"Calculating class weights")
train_labels = [x.label for x in dataset]
counts = [train_labels.count(x) for x in [0, 1]]
log.info(counts)
class_weights = [1 - (x / sum(counts)) for x in counts]
log.info(class_weights)
sampler = torch.utils.data.WeightedRandomSampler([class_weights[x] for x in train_labels], len(train_labels))
return sampler
......@@ -10,16 +10,17 @@ import torch
from sklearn.metrics import f1_score, accuracy_score
from torch_geometric.loader import DataLoader
from torch_geometric.nn import HeteroConv, GATConv, Linear, global_mean_pool
from embeddings.helper_functions import calculate_class_weights, split_test_train_pytorch
from helper_functions import calculate_class_weights, split_test_train_pytorch
import wandb
from torch_geometric.utils import to_networkx
from sklearn.model_selection import KFold
from custom_logger import setup_custom_logger
from dataset import UserGraphDataset
from dataset_in_memory import UserGraphDatasetInMemory
from Visualize import GraphVisualization
import helper_functions
from hetero_GAT_constants import TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS
from hetero_GAT_constants import TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, IN_MEMORY_DATASET, INCLUDE_ANSWER, USE_WANDB, WANDB_PROJECT_NAME, NUM_WORKERS, EPOCHS, NUM_LAYERS, HIDDEN_CHANNELS, FINAL_MODEL_OUT_PATH, SAVE_CHECKPOINTS, WANDB_RUN_NAME
log = setup_custom_logger("heterogenous_GAT_model", logging.INFO)
torch.multiprocessing.set_sharing_strategy('file_system')
......@@ -28,7 +29,6 @@ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
class HeteroGNN(torch.nn.Module):
"""
Heterogenous Graph Attentional Network (GAT)
......@@ -134,15 +134,22 @@ def test(loader):
loss = criterion(out, torch.squeeze(data.label, -1)) # Compute the loss.
cumulative_loss += loss.item()
pred = out.argmax(dim=1) # Use the class with highest probability.
# Use the class with highest probability.
pred = out.argmax(dim=1)
# Cache the predictions for calculating metrics
predictions += list([x.item() for x in pred])
true_labels += list([x.item() for x in data.label])
# Log table of predictions to WandB
if USE_WANDB:
#graph_html = wandb.Html(plotly.io.to_html(create_graph_vis(data)))
for pred, label in zip(pred, torch.squeeze(data.label, -1)):
table.add_data(label, pred)
# Collate results into a single dictionary
test_results = {
"accuracy": accuracy_score(true_labels, predictions),
"f1-score": f1_score(true_labels, predictions),
......@@ -163,6 +170,10 @@ if __name__ == '__main__':
if IN_MEMORY_DATASET:
train_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='train-4175-qs.pt')
test_dataset = UserGraphDatasetInMemory(root="../data", file_name_out='test-1790-qs.pt')
## TEST
split1, split2 = helper_functions.split_test_train_pytorch(train_dataset, 0.7)
else:
dataset = UserGraphDataset(root="../data", skip_processing=True)
train_dataset, test_dataset = split_test_train_pytorch(dataset)
......@@ -179,8 +190,9 @@ if __name__ == '__main__':
log.info(f"Data Details:\n{data_details}")
if USE_WANDB:
wandb_run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
helper_functions.start_wandb_for_training(WANDB_PROJECT_NAME, wandb_run_name)
if WANDB_RUN_NAME is None:
run_name = f"run@{time.strftime('%Y%m%d-%H%M%S')}"
helper_functions.start_wandb_for_training(WANDB_PROJECT_NAME, run_name)
# Class weights
......@@ -219,23 +231,15 @@ if __name__ == '__main__':
checkpoint_file_name = f"../models/model-{epoch}.pt"
torch.save(model.state_dict(), checkpoint_file_name)
# log evaluation results to wandb
if USE_WANDB:
wandb.log({
"train/loss": train_info["loss"],
"train/accuracy": train_info["accuracy"],
"train/f1": train_info["f1-score"],
"train/table": train_info["table"],
"test/loss": test_info["loss"],
"test/accuracy": test_info["accuracy"],
"test/f1": test_info["f1-score"],
"test/table": test_info["table"]
})
helper_functions.log_results_to_wandb(train_info, "train")
helper_functions.log_results_to_wandb(test_info, "test")
log.info(f'Test F1: {train_info["f1-score"]:.4f}')
helper_functions.save_model(model, FINAL_MODEL_OUT_PATH)
# Plot confusion matrix
if USE_WANDB:
wandb.log({"test/cm": wandb.plot.confusion_matrix(probs=None, y_true=test_info["trues"], preds=test_info["preds"], class_names=["neutral", "upvoted"])})
helper_functions.add_cm_to_wandb(test_info)
wandb.finish()
# Batch sizes
TRAIN_BATCH_SIZE = 512
TEST_BATCH_SIZE = 512
# Data config
IN_MEMORY_DATASET = True
INCLUDE_ANSWER = True
# W&B dashboard logging
USE_WANDB = True
WANDB_PROJECT_NAME = "heterogeneous-GAT-model"
WANDB_RUN_NAME = None # None for timestamp
NUM_WORKERS = 14
# Training parameters
EPOCHS = 2
# Model architecture
NUM_LAYERS = 3
HIDDEN_CHANNELS = 64
FINAL_MODEL_OUT_PATH = "model.pt"
SAVE_CHECKPOINTS = False
\ No newline at end of file
absl-py==1.3.0
anyio==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens==2.1.0
attrs==22.1.0
backcall==0.2.0
beautifulsoup4==4.11.1
bleach==5.0.1
blis==0.7.9
bs4==0.0.1
cachetools==5.2.0
catalogue==2.0.8
certifi==2022.9.24
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
colorama==0.4.6
confection==0.0.3
contourpy==1.0.6
cycler==0.11.0
cymem==2.0.7
debugpy==1.6.3
decorator==5.1.1
defusedxml==0.7.1
dgl==0.9.1
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.4.1/en_core_web_sm-3.4.1-py3-none-any.whl
entrypoints==0.4
executing==1.2.0
fastjsonschema==2.16.2
filelock==3.8.0
fonttools==4.38.0
google-auth==2.14.0
google-auth-oauthlib==0.4.6
graph4nlp==0.5.5
grpcio==1.50.0
huggingface-hub==0.10.1
certifi==2022.12.7
charset-normalizer==2.0.12
dataclasses==0.8
decorator==4.4.2
googledrivedownloader==0.4
idna==3.4
importlib-metadata==5.0.0
importlib-resources==5.10.0
ipykernel==6.16.2
ipython==8.6.0
ipython-genutils==0.2.0
ipywidgets==8.0.2
jedi==0.18.1
Jinja2==3.1.2
joblib==1.2.0
jsonschema==4.16.0
jupyter==1.0.0
jupyter-console==6.4.4
jupyter-server==1.21.0
jupyter_client==7.4.4
jupyter_core==4.11.2
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.3
kiwisolver==1.4.4
langcodes==3.3.0
littleutils==0.2.2
lxml==4.9.1
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.2
matplotlib-inline==0.1.6
mistune==2.0.4
murmurhash==1.0.9
nbclassic==0.4.7
nbclient==0.7.0
nbconvert==7.2.3
nbformat==5.7.0
nest-asyncio==1.5.6
networkx==2.8.8
nltk==3.7
notebook==6.5.2
notebook_shim==0.2.0
numpy==1.23.5
oauthlib==3.2.2
ogb==1.3.5
outdated==0.2.2
packaging==21.3
pandas==1.5.1
pandocfilters==1.5.0
parso==0.8.3
pathy==0.6.2
pickleshare==0.7.5
Pillow==9.3.0
pkgutil_resolve_name==1.3.10
preshed==3.0.8
prometheus-client==0.15.0
prompt-toolkit==3.0.31
protobuf==3.19.6
psutil==5.9.3
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pydantic==1.10.2
Pygments==2.13.0
importlib-resources==5.4.0
isodate==0.6.1
Jinja2==3.0.3
joblib==1.1.1
MarkupSafe==2.0.1
networkx==2.5.1
numpy==1.19.5
pandas==1.1.5
Pillow==8.4.0
pyparsing==3.0.9
pyrsistent==0.18.1
python-dateutil==2.8.2
pythonds==1.2.1
pytz==2022.5
pywin32==304
pywinpty==2.0.9
pytz==2022.7.1
PyYAML==6.0
pyzmq==24.0.1
qtconsole==5.3.2
QtPy==2.2.1
regex==2022.10.31
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.1.3
scipy==1.9.3
Send2Trash==1.8.0
rdflib==5.0.0
requests==2.27.1
scikit-learn==0.24.2
scipy==1.5.4
six==1.16.0
smart-open==5.2.1
sniffio==1.3.0
soupsieve==2.3.2.post1
spacy==3.4.2
spacy-legacy==3.0.10
spacy-loggers==1.0.3
srsly==2.4.5
stack-data==0.6.0
stanfordcorenlp==3.9.1.1
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
terminado==0.17.0
thinc==8.1.5
threadpoolctl==3.1.0
tinycss2==1.2.1
tokenizers==0.13.2
torch==1.13.0
torchtext==0.14.0
tornado==6.2
tqdm==4.64.1
traitlets==5.5.0
transformers==4.24.0
typer==0.4.2
typing_extensions==4.4.0
urllib3==1.26.12
wasabi==0.10.1
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.1
Werkzeug==2.2.2
widgetsnbextension==4.0.3
zipp==3.10.0
typing_extensions==4.1.1
urllib3==1.26.14
yacs==0.1.8
zipp==3.6.0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment