diff --git a/.gitignore b/.gitignore
index 71601871e5bb3caa8b745f89d5e5bd099ee68a46..d2d51362ab6e1728ce5f02b4ff71332113ae3e5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,6 @@
 __pycache__/
 runs/
-data/experiments/GM_MI/
-dataset/
\ No newline at end of file
+experiments/debug
+dataset
\ No newline at end of file
diff --git a/GM_MI.py b/GM_MI.py
index b3b831e63972ca7c49aa8341e604ee52d22fab05..087056593d173fa2ba8c91f6dd08f066d272e393 100644
--- a/GM_MI.py
+++ b/GM_MI.py
@@ -302,7 +302,7 @@ def grow(model, writer, stage, device, args):
         partition_config=PARTITION_CONFIG_MIX["stage {}".format(stage)],
     )
     # Load partitioned data for this stage's training, using the partition config for this stage
-    epochs = args.cluster_epochs if not args.debug else 1
+    epochs = args.cluster_epochs #if not args.debug else 1
     exemplar_means = model.calculate_exemplar_means(
         device, stage
@@ -792,9 +792,6 @@ def merge(model, writer, stage, device, args):
     dist_loss_record = AverageMeter()
     nce_loss_record = AverageMeter()
     model.train()
-    # model.fix_backbone()
-    # model.fix_backbone()
-    # model.fix_static(stage - 1)
     mix_loader = MixUpWrapper(train_loader, args)
     train_loader = mix_loader if args.batch_mixup else train_loader # no mixup by default
@@ -806,36 +803,36 @@ def merge(model, writer, stage, device, args):
             aug, stage
         )
-        feat_dists_to_exemplar_means = dist_func(
+        dists_to_exemplar_means = dist_func(
             orig_feats.cpu().detach().numpy(), exemplar_means.cpu().numpy()
         )
-        feat_dists_to_exemplar_means = torch.from_numpy(feat_dists_to_exemplar_means).to(device)
-        min_f_2_p, min_indices = feat_dists_to_exemplar_means.min(dim=1)
-
-        refuse_indices = min_f_2_p >= thres2
+        dists_to_exemplar_means = torch.from_numpy(dists_to_exemplar_means).to(device)
+        min_dist_to_exemplar_means, min_indices = dists_to_exemplar_means.min(dim=1)
+
+        refuse_idxs = min_dist_to_exemplar_means >= thres2
 
-        feature_accept = orig_feats[~refuse_indices]
-        feature_bar_accept = aug_feats[~refuse_indices]
-        pseudo_accept = pseudolabel[~refuse_indices]
+        accepted_orig_feats = orig_feats[~refuse_idxs]
+        accepted_aug_feats = aug_feats[~refuse_idxs]
+        accepted_labels = pseudolabel[~refuse_idxs]
 
         # ssl: x1 >-< cur_proto, and x1 <-> old_protos
-        logits = feature_accept.mm(exemplar_means.t())
+        logits = accepted_orig_feats.mm(exemplar_means.t())
         logits = logits / args.ssl_temperature
 
-        if args.no_pll:
+        if args.no_pll: # false by default
             loss = torch.tensor(0.0).to(device)
         else:
-            loss = (-F.log_softmax(logits, dim=1) * pseudo_accept).sum(dim=1).mean()
+            loss = (-F.log_softmax(logits, dim=1) * accepted_labels).sum(dim=1).mean()
 
         if args.pull_exemplar_features: # true in default params
-            if args.pll_exem: # true in default params
+            if args.pll_exem: # true in default params. Therefore skip other branches in this block
                 exemplars = orig
             elif args.sync_backbone: # true in default params
                 selector = torch.any(pseudolabel[:, :targets_min_idx] > 0.999, dim=1)
                 exemplars = orig[selector]
             else:
-                if not args.pef_all:
+                if not args.pef_all: # args.pef_all True in default params
                     exemplars = model.exemplar_sets[:targets_min_idx]
                 else:
                     exemplars = model.exemplar_sets[:]
@@ -845,15 +842,15 @@ def merge(model, writer, stage, device, args):
             exemplars = torch.cat(exemplars, dim=0)
 
             with torch.no_grad():
-                if args.ema_fast_slow:
+                if args.ema_fast_slow: # false in default params
                     old_feature = model.slow_encoder(exemplars, 1)
-                elif args.sync_backbone:
+                elif args.sync_backbone: # true in default params
                     old_feature = model.old_encoder(exemplars, 0)
                 else:
                     old_feature = model.moco.encoder_q(exemplars, 0)
                 new_feature = model.moco.encoder_q(exemplars, stage)
 
-            if args.pef_type == "cos":
+            if args.pef_type == "cos": # "cos" in default params
                 old_feature = F.normalize(old_feature, dim=1, p=2)
                 new_feature = F.normalize(new_feature, dim=1, p=2)
@@ -872,16 +869,16 @@ def merge(model, writer, stage, device, args):
         if args.ssl_with_cluster:
-            topk_min_f_2_p_indices = torch.argsort(min_f_2_p, descending=True)[
+            topk_min_dist_to_exemplar_means_idxs = torch.argsort(min_dist_to_exemplar_means, descending=True)[
                 : int(orig.shape[0] * args.ood_thres)
             ]
-            prob1 = F.softmax(orig_class_logits, dim=1)
-            prob2 = F.softmax(orig_cluster_logits, dim=1)
-            prob1_bar = F.softmax(aug_class_logits, dim=1)
-            prob2_bar = F.softmax(aug_cluster_logits, dim=1)
+            orig_class_prob = F.softmax(orig_class_logits, dim=1)
+            orig_clust_prob = F.softmax(orig_cluster_logits, dim=1)
+            aug_class_prob = F.softmax(aug_class_logits, dim=1)
+            aug_clust_prob = F.softmax(aug_cluster_logits, dim=1)
 
-            rank_feat = orig_feats[topk_min_f_2_p_indices].detach()
+            rank_feat = orig_feats[topk_min_dist_to_exemplar_means_idxs].detach()
             rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
             rank_idx1, rank_idx2 = PairEnum(rank_idx)
@@ -897,23 +894,23 @@ def merge(model, writer, stage, device, args):
             target_ulb = torch.ones_like(rank_diff).float().to(device)
             target_ulb[rank_diff > 0] = -1
 
-            prob1_ulb, _ = PairEnum(prob2[topk_min_f_2_p_indices])
-            _, prob2_ulb = PairEnum(prob2_bar[topk_min_f_2_p_indices])
+            prob1_ulb, _ = PairEnum(orig_clust_prob[topk_min_dist_to_exemplar_means_idxs])
+            _, prob2_ulb = PairEnum(aug_clust_prob[topk_min_dist_to_exemplar_means_idxs])
 
             if args.bce:
                 bce_loss = bce(prob1_ulb, prob2_ulb, target_ulb)
                 bce_loss_record.update(bce_loss.item(), orig.size(0))
                 loss += bce(prob1_ulb, prob2_ulb, target_ulb)
             if args.mse:
-                mse_loss = F.mse_loss(prob2, prob2_bar)
-                if not args.no_fc:
-                    mse_loss += F.mse_loss(prob1, prob1_bar)
+                mse_loss = F.mse_loss(orig_clust_prob, aug_clust_prob)
+                if not args.no_fc: # true in default params
+                    mse_loss += F.mse_loss(orig_class_prob, aug_class_prob)
                 mse_loss_record.update(mse_loss.item(), orig.size(0))
                 loss += mse_loss
 
-            if args.ssl_nce:
-                f = F.normalize(feature_accept, dim=1, p=2)
-                f_bar = F.normalize(feature_bar_accept, dim=1, p=2)
+            if args.ssl_nce: # False in default params
+                f = F.normalize(accepted_orig_feats, dim=1, p=2)
+                f_bar = F.normalize(accepted_aug_feats, dim=1, p=2)
                 logits_neg = f.mm(f.t())
                 logits_neg = logits_neg * (
                     1 - torch.eye(logits_neg.shape[0]).to(device).float()
                 )
@@ -927,7 +924,7 @@ def merge(model, writer, stage, device, args):
         if args.ssl_with_ce: # false by default
             loss_ce = (
-                (-F.log_softmax(orig_class_logits, dim=1) * pseudo_accept.detach())
+                (-F.log_softmax(orig_class_logits, dim=1) * accepted_labels.detach())
                 .sum(1)
                 .mean()
             )
@@ -956,12 +953,13 @@ def merge(model, writer, stage, device, args):
         writer.add_scalar(
             f"Stage {stage}/Merge/Train/Number of Refused Samples",
-            refuse_indices.sum(),
+            refuse_idxs.sum(),
             epoch,
         ) # was labelled n refuse
 
         lr_sched.step()
 
+        # args.reselect_exemplars = False
         if args.reselect_exemplars and epoch % args.reselect_exemplars_interval == (
             args.reselect_exemplars_interval - 1
         ):
@@ -1399,7 +1397,7 @@ if __name__ == "__main__":
     device = torch.device("cuda" if args.cuda else "cpu")
     seed_torch(args.seed)
     runner_name = os.path.basename(__file__).split(".")[0]
-    model_dir = os.path.join(args.exp_root, args.model_name)
+    model_dir = os.path.join(args.exp_root, args.model_name if not args.debug else "debug")
     os.makedirs(model_dir, exist_ok=True)
     args.model_dir = model_dir
diff --git a/models/moco.py b/models/moco.py
index ce8c08510c4f70727cce53df424f010864ceae11..f6ea534e33752034995fbf39419768a68afba636 100644
--- a/models/moco.py
+++ b/models/moco.py
@@ -33,7 +33,7 @@ class MultiBranchResNet(nn.Module):
                  replace_stride_with_dilation=None, norm_layer=None, branch_depth=3, same_branch_size=False, same_branch=False):
         super(MultiBranchResNet, self).__init__()
-        self.same_branch = same_branch
+        self.same_branch = same_branch # True by default
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -50,7 +50,7 @@ class MultiBranchResNet(nn.Module):
                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
         self.groups = groups
         self.base_width = width_per_group
-        self.branch_depth = branch_depth
+        self.branch_depth = branch_depth # 3 by default
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
@@ -63,8 +63,8 @@ class MultiBranchResNet(nn.Module):
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
         branches = []
-        if same_branch_size:
-            if branch_depth == 3:
+        if same_branch_size: # False by default
+            if branch_depth == 3: # 3 by default
                 for _ in range(3):
                     self.inplanes = 256
                     layer = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
@@ -87,10 +87,10 @@ class MultiBranchResNet(nn.Module):
                     )
                     branches.append(layer)
         else:
-            if branch_depth == 3:
-                for _ in range(3):
+            if branch_depth == 3: # this is the one we use
+                for _ in range(3): # creates three branches of basic blocks
                     self.inplanes = 256
-                    layer = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+                    layer = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
                     branches.append(layer)
             elif branch_depth == 2:
                 for _ in range(3):
@@ -170,9 +170,9 @@ class MultiBranchResNet(nn.Module):
         if branch == 0:
             x = self.layer4(x3)
         else:
-            b = branch if not self.same_branch else 1
-            if self.branch_depth == 3:
-                x = self.branches[b - 1](x3)
+            b = branch if not self.same_branch else 1 # same_branch is True by default, therefore b = 1
+            if self.branch_depth == 3: # 3 by default
+                x = self.branches[b - 1](x3) # run x3 through the 0th branch
             elif self.branch_depth == 2:
                 x = self.branches[b - 1](x2)
             elif self.branch_depth == 1:
@@ -390,7 +390,7 @@ class MultiBranchModelBase(nn.Module):
         # use split batchnorm
         norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
-        if fully:
+        if fully: # false
             self.net = MultiBranchResNetFully(BasicBlock, [2, 2, 2, 2],
                 num_classes=feature_dim, norm_layer=norm_layer, branch_depth=branch_depth, same_branch=same_branch)
             pretrained = resnet18(pretrained=True)
             self.net.load_state_dict(pretrained.state_dict(), strict=False)