NextTagEmbedding.py
    import itertools
    import logging
    import random
    import sqlite3
    from typing import List, Tuple
    
    import pandas as pd
    import torch
    from torch import nn, optim
    import torch.nn.functional as F
    from torch.utils.tensorboard import SummaryWriter
    from tqdm import tqdm
    
    logging.basicConfig(level=logging.INFO)
    
    
    class NextTagEmbeddingTrainer:
    
        def __init__(self, context_length: int, emb_size: int, excluded_tags=None, database_path: str = None):
            logger = logging.getLogger(self.__class__.__name__)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Proceeding with {self.device} . .")
            if database_path is not None:
                self.db = sqlite3.connect(database_path)
                logger.info(f"Connected to {database_path}")
            self.tag_vocab: List[str]
            self.post_tags: List[Tuple]
            self.context_length = context_length
            self.emb_size = emb_size
            # Default to an empty list so membership checks in build_cbow do not fail when excluded_tags is None
            self.excluded_tags = excluded_tags if excluded_tags is not None else []
    
    
        def build_cbow(self, tags: List[str], context_len: int) -> List[Tuple]:
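            """Build CBOW-style (context, target) pairs from one post's tag list.

            Each tag (after removing excluded tags) becomes the target in turn, with the
            remaining tags, padded with 'PAD' or truncated to ``context_len``, as its context.
            """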
    
            filtered_tags = [t for t in tags if t not in self.excluded_tags]
    
            if len(filtered_tags) <= 1:
                return []
    
            pairs = []
    
            for target in filtered_tags:
                context = [t for t in filtered_tags if t != target]
                # Pad with 'PAD' or truncate so the context has exactly context_len tags
                while len(context) < context_len:
                    context.append('PAD')
                context = context[:context_len]
                pairs.append((context, target))
    
            return pairs
    
        def from_files(self, post_tags_path: str, tag_vocab: str):
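            """Load the tag vocabulary and per-post tag lists from CSV files and build the training pairs."""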
            tag_df = pd.read_csv(tag_vocab, keep_default_na=False)
            self.tag_vocab = list(set(tag_df["TagName"])) + ["PAD"]
    
            post_tags = pd.read_csv(post_tags_path)
            tag_list_df = post_tags['Tags'].apply(lambda row: self.build_cbow(self.parse_tag_list(row), self.context_length))
            context_and_target = tag_list_df[tag_list_df.astype(str) != '[]']
            # Now concatenate all the lists together
            tag_pairs = []
            for i in context_and_target:
                tag_pairs += i
            self.post_tags = tag_pairs
    
        def from_db(self):
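            """Build the vocabulary and training pairs from the SQLite database (questions tagged 'python', capped at 100k rows)."""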
            post_tags = pd.read_sql_query("SELECT Tags FROM Post WHERE PostTypeId=1 AND Tags LIKE '%python%' LIMIT 100000", self.db)
            tag_list_df = post_tags['Tags'].map(self.parse_tag_list)
    
            self.tag_vocab = list(set(tag_list_df.sum() + ["PAD"]))
    
            context_and_target = tag_list_df.apply(lambda row: self.build_cbow(row, self.context_length))
            context_and_target = context_and_target[context_and_target.astype(str) != '[]']
            # Now concatenate all the lists together
            tag_pairs = []
            for i in context_and_target:
                tag_pairs += i
            self.post_tags = tag_pairs
    
        def parse_tag_list(self, tag_list: str) -> List[str]:
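            # Tags are stored as "<tag1><tag2>...": drop the outer angle brackets and split on "><"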
            return tag_list[1:-1].split("><")
    
        def sample_n(self, pairs, train_size: int):
            # random.sample needs a sequence, so the (context, target) pairs are kept as a list
            return random.sample(pairs, train_size)
    
        def train(self, train_size: int, epochs: int):
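            """Train the embedding model on a random sample of (context, target) pairs using NLL loss."""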
            # Loss
            loss_function = nn.NLLLoss()
            losses = []
            # Model
            self.model = NextTagEmbedding(vocab_size=len(self.tag_vocab), embedding_dim=self.emb_size, context_size=self.context_length).to(self.device)
            # Optimizer
            optimizer = optim.SGD(self.model.parameters(), lr=0.001)
            # Enumerate the vocabulary, reflects the index of where the 1 is in the one-hot
            self.tag_to_ix = {tag: i for i, tag in enumerate(self.tag_vocab)}
            # Reduce size of training set
            samples = self.sample_n(self.post_tags, train_size)
    
            for epoch in range(epochs):
                total_loss = 0
                for context, target in tqdm(samples):
                    context_tensor = torch.tensor([self.tag_to_ix[t] for t in context], dtype=torch.long).to(self.device)
                    self.model.zero_grad()
                    log_probs = self.model(context_tensor)
                    target_tensor = torch.tensor([self.tag_to_ix[target]], dtype=torch.long).to(self.device)
                    loss = loss_function(log_probs, target_tensor)
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                losses.append(total_loss)
    
        def get_tag_embedding(self, tag: str):
            return self.model.embedding.weight[self.tag_to_ix[tag]]
    
        def to_tensorboard(self, run_name: str):
            """
            Write embedding to Tensorboard projector
            tensorboard --logdir="runs/run@20221102-173048"
            """
            writer = SummaryWriter(f'runs/{run_name}')
            writer.add_embedding(self.model.embedding.weight,
                                 metadata=self.tag_vocab,
                                 tag='Next-Tag embedding')
            writer.close()
    
        def load_model(self, model_path: str, vocab_size: int, embedding_dim: int):
            self.model = NextTagEmbedding(vocab_size, embedding_dim, self.context_length)
            self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    
        def save_model(self, model_path: str):
            torch.save(self.model.state_dict(), model_path)
    
    
    class NextTagEmbedding(nn.Module):
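        """CBOW-style network: the concatenated context-tag embeddings feed a hidden
        layer that predicts log-probabilities over the tag vocabulary.
        """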
    
        def __init__(self, vocab_size, embedding_dim, context_size):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
            self.linear1 = nn.Linear(context_size * embedding_dim, 128)
            self.linear2 = nn.Linear(128, vocab_size)
    
        def forward(self, inputs):
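            # Flatten the context embeddings into a single row of size context_size * embedding_dim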
            embeds = self.embedding(inputs).view((1, -1))
            out = F.relu(self.linear1(embeds))
            out = self.linear2(out)
            log_probs = F.log_softmax(out, dim=1)
            return log_probs
    
    
    if __name__ == '__main__':
        tet = NextTagEmbeddingTrainer(context_length=2, emb_size=30, excluded_tags=['python'], database_path="../stackoverflow.db")
        tet.from_db()
        print(len(tet.post_tags))
        print(len(tet.tag_vocab))
    
    
        #tet = NextTagEmbeddingTrainer(context_length=3, emb_size=50)
    
        #tet.from_files("../data/raw/all_tags.csv", "../data/raw/tag_vocab.csv")
        # assert len(tet.post_tags) == 84187510, "Incorrect number of post tags!"
        # assert len(tet.tag_vocab) == 63653, "Incorrect vocab size!"
    
        tet.train(1000, 1)
        # tet.to_tensorboard(f"run@{time.strftime('%Y%m%d-%H%M%S')}")
    
        # tet.save_model("25mil.pt")
        # tet.load_model("10mil_500d_embd.pt", 63653, 500)
        # tet.to_tensorboard(f"run@{time.strftime('%Y%m%d-%H%M%S')}")