Skip to content
Snippets Groups Projects
Commit c2a72dd2 authored by Liam Byrne
Browse files

working tag embeddings

parent c87745f8
No related branches found
No related tags found
No related merge requests found
...@@ -110,9 +110,6 @@ class NextTagEmbeddingTrainer: ...@@ -110,9 +110,6 @@ class NextTagEmbeddingTrainer:
total_loss += loss.item() total_loss += loss.item()
losses.append(total_loss) losses.append(total_loss)
def get_tag_embedding(self, tag: str):
return self.model.embedding.weight[self.tag_to_ix[tag]]
def to_tensorboard(self, run_name: str): def to_tensorboard(self, run_name: str):
""" """
Write embedding to Tensorboard projector Write embedding to Tensorboard projector
...@@ -130,7 +127,7 @@ class NextTagEmbeddingTrainer: ...@@ -130,7 +127,7 @@ class NextTagEmbeddingTrainer:
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
# unpickle the tag_to_ix # unpickle the tag_to_ix
with open('tag_to_ix_' + model_path, 'rb') as f: with open(model_path.replace('tag-emb', f'tag_to_ix_tag-emb'), 'rb') as f:
model.tag_to_ix = pickle.load(f) model.tag_to_ix = pickle.load(f)
return model return model
...@@ -157,6 +154,11 @@ class NextTagEmbedding(nn.Module): ...@@ -157,6 +154,11 @@ class NextTagEmbedding(nn.Module):
log_probs = F.log_softmax(out, dim=1) log_probs = F.log_softmax(out, dim=1)
return log_probs return log_probs
def get_tag_embedding(self, tag: str):
    """Return the embedding row for ``tag``.

    The tag is translated to a row index through ``self.tag_to_ix`` and that
    index is used to select from ``self.embedding.weight``.

    Args:
        tag: Tag name to look up; must be a key of ``self.tag_to_ix``.

    Returns:
        ``self.embedding.weight[self.tag_to_ix[tag]]`` (presumably a 1-D
        tensor of length ``embedding_dim`` -- confirm against the model
        definition).

    Raises:
        AssertionError: If ``self.tag_to_ix`` has not been set, or if ``tag``
            is not present in it.
    """
    # Check that the mapping exists *before* testing membership: evaluating
    # ``tag in None`` raises TypeError, which would hide the intended message.
    assert self.tag_to_ix is not None, "Tag to index mapping not set!"
    assert tag in self.tag_to_ix, "Tag not in vocabulary!"
    return self.embedding.weight[self.tag_to_ix[tag]]
if __name__ == '__main__': if __name__ == '__main__':
tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db") tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db")
......
No preview for this file type
No preview for this file type
...@@ -19,7 +19,7 @@ class StaticGraphConstruction: ...@@ -19,7 +19,7 @@ class StaticGraphConstruction:
# PostEmbedding is costly to instantiate in each StaticGraphConstruction instance. # PostEmbedding is costly to instantiate in each StaticGraphConstruction instance.
post_embedding_builder = PostEmbedding() post_embedding_builder = PostEmbedding()
tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-1mil.pt", embedding_dim=30, vocab_size=63654, context_length=3) tag_embedding_model = NextTagEmbeddingTrainer.load_model("../models/tag-emb-7_5mil-50d-63653-3.pt", embedding_dim=50, vocab_size=63654, context_length=3)
def __init__(self): def __init__(self):
self._known_tags = {} # tag_name -> index self._known_tags = {} # tag_name -> index
...@@ -116,7 +116,7 @@ class StaticGraphConstruction: ...@@ -116,7 +116,7 @@ class StaticGraphConstruction:
if not len(self._known_tags): if not len(self._known_tags):
return None return None
for tag in self._known_tags: for tag in self._known_tags:
yield StaticGraphConstruction.tag_embedding_model.get_tag_embedding(tag) yield self.tag_embedding_model.get_tag_embedding(tag)
def process_modules(self): def process_modules(self):
if not len(self._known_modules): if not len(self._known_modules):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment