diff --git a/.gitignore b/.gitignore
index 71601871e5bb3caa8b745f89d5e5bd099ee68a46..d2d51362ab6e1728ce5f02b4ff71332113ae3e5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,6 @@
 __pycache__/
 
 runs/
-data/experiments/GM_MI/
 
-dataset/
\ No newline at end of file
+experiments/debug
+dataset
\ No newline at end of file
diff --git a/GM_MI.py b/GM_MI.py
index b3b831e63972ca7c49aa8341e604ee52d22fab05..087056593d173fa2ba8c91f6dd08f066d272e393 100644
--- a/GM_MI.py
+++ b/GM_MI.py
@@ -302,7 +302,7 @@ def grow(model, writer, stage, device, args):
         partition_config=PARTITION_CONFIG_MIX["stage {}".format(stage)],
     )  # Load partitioned data for this stage's training, using the partition config for this stage
 
-    epochs = args.cluster_epochs if not args.debug else 1
+    epochs = args.cluster_epochs  # debug mode no longer caps this at a single epoch
 
     exemplar_means = model.calculate_exemplar_means(
         device, stage
@@ -792,9 +792,6 @@ def merge(model, writer, stage, device, args):
         dist_loss_record = AverageMeter()
         nce_loss_record = AverageMeter()
         model.train()
-        # model.fix_backbone()
-        # model.fix_backbone()
-        # model.fix_static(stage - 1)
 
         mix_loader = MixUpWrapper(train_loader, args)
         train_loader = mix_loader if args.batch_mixup else train_loader  # no mixup by default
@@ -806,36 +803,36 @@ def merge(model, writer, stage, device, args):
                 aug, stage
             )
 
-            feat_dists_to_exemplar_means = dist_func(
+            dists_to_exemplar_means = dist_func(  # pairwise distances from features to exemplar means
                 orig_feats.cpu().detach().numpy(), exemplar_means.cpu().numpy()
             )
             
-            feat_dists_to_exemplar_means = torch.from_numpy(feat_dists_to_exemplar_means).to(device)
-            min_f_2_p, min_indices = feat_dists_to_exemplar_means.min(dim=1)
-
-            refuse_indices = min_f_2_p >= thres2
+            dists_to_exemplar_means = torch.from_numpy(dists_to_exemplar_means).to(device)
+            min_dist_to_exemplar_means, min_indices = dists_to_exemplar_means.min(dim=1)
+
+            refuse_idxs = min_dist_to_exemplar_means >= thres2  # refuse samples farther than thres2 from every exemplar mean (likely OOD)
 
-            feature_accept = orig_feats[~refuse_indices]
-            feature_bar_accept = aug_feats[~refuse_indices]
-            pseudo_accept = pseudolabel[~refuse_indices]
+            accepted_orig_feats = orig_feats[~refuse_idxs]
+            accepted_aug_feats = aug_feats[~refuse_idxs]
+            accepted_labels = pseudolabel[~refuse_idxs]
 
             # ssl: x1 >-< cur_proto, and x1 <-> old_protos
-            logits = feature_accept.mm(exemplar_means.t())
+            logits = accepted_orig_feats.mm(exemplar_means.t())
             logits = logits / args.ssl_temperature
 
-            if args.no_pll:
+            if args.no_pll: # false by default
                 loss = torch.tensor(0.0).to(device)
             else:
-                loss = (-F.log_softmax(logits, dim=1) * pseudo_accept).sum(dim=1).mean()
+                loss = (-F.log_softmax(logits, dim=1) * accepted_labels).sum(dim=1).mean()
 
             if args.pull_exemplar_features:  # true in default params
-                if args.pll_exem:  # true in default params
+                if args.pll_exem:  # true in default params, so the elif/else branches below are skipped
                     exemplars = orig
                 elif args.sync_backbone:  # true in default params
                     selector = torch.any(pseudolabel[:, :targets_min_idx] > 0.999, dim=1)
                     exemplars = orig[selector]
                 else:
-                    if not args.pef_all:
+                    if not args.pef_all: # args.pef_all is True in default params, so the else branch runs
                         exemplars = model.exemplar_sets[:targets_min_idx]
                     else:
                         exemplars = model.exemplar_sets[:]
@@ -845,15 +842,15 @@ def merge(model, writer, stage, device, args):
                     exemplars = torch.cat(exemplars, dim=0)
 
                 with torch.no_grad():
-                    if args.ema_fast_slow:
+                    if args.ema_fast_slow: # false in default params
                         old_feature = model.slow_encoder(exemplars, 1)
-                    elif args.sync_backbone:
+                    elif args.sync_backbone: # true in default params
                         old_feature = model.old_encoder(exemplars, 0)
                     else:
                         old_feature = model.moco.encoder_q(exemplars, 0)
                 new_feature = model.moco.encoder_q(exemplars, stage)
 
-                if args.pef_type == "cos":
+                if args.pef_type == "cos": # "cos" in default params
                     old_feature = F.normalize(old_feature, dim=1, p=2)
                     new_feature = F.normalize(new_feature, dim=1, p=2)
 
@@ -872,16 +869,16 @@ def merge(model, writer, stage, device, args):
 
             if args.ssl_with_cluster:
 
-                topk_min_f_2_p_indices = torch.argsort(min_f_2_p, descending=True)[
+                topk_min_dist_to_exemplar_means_idxs = torch.argsort(min_dist_to_exemplar_means, descending=True)[
                     : int(orig.shape[0] * args.ood_thres)
                 ]
 
-                prob1 = F.softmax(orig_class_logits, dim=1)
-                prob2 = F.softmax(orig_cluster_logits, dim=1)
-                prob1_bar = F.softmax(aug_class_logits, dim=1)
-                prob2_bar = F.softmax(aug_cluster_logits, dim=1)
+                orig_class_prob = F.softmax(orig_class_logits, dim=1)
+                orig_clust_prob = F.softmax(orig_cluster_logits, dim=1)
+                aug_class_prob = F.softmax(aug_class_logits, dim=1)
+                aug_clust_prob = F.softmax(aug_cluster_logits, dim=1)
 
-                rank_feat = orig_feats[topk_min_f_2_p_indices].detach()
+                rank_feat = orig_feats[topk_min_dist_to_exemplar_means_idxs].detach()
 
                 rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
                 rank_idx1, rank_idx2 = PairEnum(rank_idx)
@@ -897,23 +894,23 @@ def merge(model, writer, stage, device, args):
                 target_ulb = torch.ones_like(rank_diff).float().to(device)
                 target_ulb[rank_diff > 0] = -1
 
-                prob1_ulb, _ = PairEnum(prob2[topk_min_f_2_p_indices])
-                _, prob2_ulb = PairEnum(prob2_bar[topk_min_f_2_p_indices])
+                prob1_ulb, _ = PairEnum(orig_clust_prob[topk_min_dist_to_exemplar_means_idxs])
+                _, prob2_ulb = PairEnum(aug_clust_prob[topk_min_dist_to_exemplar_means_idxs])
 
                 if args.bce:
                     bce_loss = bce(prob1_ulb, prob2_ulb, target_ulb)
                     bce_loss_record.update(bce_loss.item(), orig.size(0))
                     loss += bce(prob1_ulb, prob2_ulb, target_ulb)
                 if args.mse:
-                    mse_loss = F.mse_loss(prob2, prob2_bar)
-                    if not args.no_fc:
-                        mse_loss += F.mse_loss(prob1, prob1_bar)
+                    mse_loss = F.mse_loss(orig_clust_prob, aug_clust_prob)
+                    if not args.no_fc: # true in default params (args.no_fc is off), so this term is included
+                        mse_loss += F.mse_loss(orig_class_prob, aug_class_prob)
                     mse_loss_record.update(mse_loss.item(), orig.size(0))
                     loss += mse_loss
 
-            if args.ssl_nce:
-                f = F.normalize(feature_accept, dim=1, p=2)
-                f_bar = F.normalize(feature_bar_accept, dim=1, p=2)
+            if args.ssl_nce: # False in default params
+                f = F.normalize(accepted_orig_feats, dim=1, p=2)
+                f_bar = F.normalize(accepted_aug_feats, dim=1, p=2)
                 logits_neg = f.mm(f.t())
                 logits_neg = logits_neg * (
                     1 - torch.eye(logits_neg.shape[0]).to(device).float()
@@ -927,7 +924,7 @@ def merge(model, writer, stage, device, args):
 
             if args.ssl_with_ce:  # false by default
                 loss_ce = (
-                    (-F.log_softmax(orig_class_logits, dim=1) * pseudo_accept.detach())
+                    (-F.log_softmax(orig_class_logits, dim=1) * accepted_labels.detach())
                     .sum(1)
                     .mean()
                 )
@@ -956,12 +953,14 @@ def merge(model, writer, stage, device, args):
 
             writer.add_scalar(
                 f"Stage {stage}/Merge/Train/Number of Refused Samples",
-                refuse_indices.sum(),
+                refuse_idxs.sum(),
                 epoch,
             )  # was labelled n refuse
 
         lr_sched.step()
 
+        # args.reselect_exemplars is False in default params
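+        # when enabled, exemplars are re-selected every reselect_exemplars_interval epochs (on each interval's last epoch)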
         if args.reselect_exemplars and epoch % args.reselect_exemplars_interval == (
             args.reselect_exemplars_interval - 1
         ):
@@ -1399,7 +1398,8 @@ if __name__ == "__main__":
     device = torch.device("cuda" if args.cuda else "cpu")
     seed_torch(args.seed)
     runner_name = os.path.basename(__file__).split(".")[0]
-    model_dir = os.path.join(args.exp_root, args.model_name)
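+    # debug runs write to a fixed "debug" directory (presumably the experiments/debug path newly ignored in .gitignore)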
+    model_dir = os.path.join(args.exp_root, args.model_name if not args.debug else "debug")
     os.makedirs(model_dir, exist_ok=True)
     args.model_dir = model_dir
 
diff --git a/models/moco.py b/models/moco.py
index ce8c08510c4f70727cce53df424f010864ceae11..f6ea534e33752034995fbf39419768a68afba636 100644
--- a/models/moco.py
+++ b/models/moco.py
@@ -33,7 +33,7 @@ class MultiBranchResNet(nn.Module):
         replace_stride_with_dilation=None, norm_layer=None, branch_depth=3, same_branch_size=False, same_branch=False):
         super(MultiBranchResNet, self).__init__()
 
-        self.same_branch = same_branch
+        self.same_branch = same_branch # True by default
 
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -50,7 +50,7 @@ class MultiBranchResNet(nn.Module):
                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
         self.groups = groups
         self.base_width = width_per_group
-        self.branch_depth = branch_depth
+        self.branch_depth = branch_depth # 3 by default
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
@@ -63,8 +63,8 @@ class MultiBranchResNet(nn.Module):
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
 
         branches = []
-        if same_branch_size:
-            if branch_depth == 3:
+        if same_branch_size: # False by default
+            if branch_depth == 3: # 3 by default
                 for _ in range(3):
                     self.inplanes = 256
                     layer = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
@@ -87,10 +87,10 @@ class MultiBranchResNet(nn.Module):
                     )
                     branches.append(layer)
         else:
-            if branch_depth == 3:
-                for _ in range(3):
+            if branch_depth == 3: # the default, and the path used here
+                for _ in range(3): # build three branches, each a layer4-style stack of basic blocks
                     self.inplanes = 256
                     layer = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
                     branches.append(layer)
             elif branch_depth == 2:
                 for _ in range(3):
@@ -170,9 +170,9 @@ class MultiBranchResNet(nn.Module):
         if branch == 0:
             x = self.layer4(x3)
         else:
-            b = branch if not self.same_branch else 1
-            if self.branch_depth == 3:
-                x = self.branches[b - 1](x3)
+            b = branch if not self.same_branch else 1 # same_branch is True by default, therefore b = 1
+            if self.branch_depth == 3: # 3 by default
+                x = self.branches[b - 1](x3) # i.e. x3 goes through branches[0]
             elif self.branch_depth == 2:
                 x = self.branches[b - 1](x2)
             elif self.branch_depth == 1:
@@ -390,7 +390,7 @@ class MultiBranchModelBase(nn.Module):
 
         # use split batchnorm
         norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
-        if fully:
+        if fully: # False in default params
             self.net = MultiBranchResNetFully(BasicBlock, [2, 2, 2, 2], num_classes=feature_dim, norm_layer=norm_layer, branch_depth=branch_depth, same_branch=same_branch)
             pretrained = resnet18(pretrained=True)
             self.net.load_state_dict(pretrained.state_dict(), strict=False)