Skip to content
Snippets Groups Projects
Commit 9d298f8f authored by lm9g17's avatar lm9g17
Browse files

Updated noise_cifar_train.py to swap between models

parent 67cb659d
Branches
No related tags found
No related merge requests found
......@@ -11,6 +11,8 @@ import torch.backends.cudnn as cudnn
from utils import train_noise, test, get_output, WeightEMA
from dataset import get_cifar_dataset
from networks.wideresnet import Wide_ResNet
from networks.cnn import CNN #Import custom CNN
from networks.convnet import ConvNet
def log(path, str):
print(str)
......@@ -31,11 +33,18 @@ parser.add_argument('--sigma', type=float, default=0.5, help='STD of Gaussian no
parser.add_argument('--correction', type=int, default=250, help='correction start epoch')
parser.add_argument('--gpu_id', type=int, default=0, help='index of gpu to use')
parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
parser.add_argument('--model', type=str, default='wide_res', help='model to use') #Added model argument
args = parser.parse_args()
exp_name = 'sigma{:.1f}_{}_{}{:.1f}_seed{}'.format(args.sigma, args.dataset, args.noise_mode, args.noise_rate, args.seed)
#Add cnn to output name when using cnn model
if args.model == 'cnn':
exp_name = 'cnn_' + exp_name
if args.model == 'convnet':
exp_name = 'convnet_' + exp_name
if 0<args.correction<args.epochs:
exp_name = 'correction_' + exp_name
......@@ -68,6 +77,13 @@ noisy_targets = np.eye(args.num_class)[noisy_targets] # to one-hot
# model
if args.model == 'cnn':
net = CNN().cuda()
ema_net = CNN().cuda()
elif args.model == 'convnet':
net = ConvNet(n_outputs=args.num_class).cuda()
ema_net = ConvNet(n_outputs=args.num_class).cuda()
else:
net = Wide_ResNet(num_classes=args.num_class).cuda()
ema_net = Wide_ResNet(num_classes=args.num_class).cuda()
for param in ema_net.parameters():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment