Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
C
chenpf1025_SLN_clone
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
lm9g17
chenpf1025_SLN_clone
Commits
9d298f8f
Commit
9d298f8f
authored
4 years ago
by
lm9g17
Browse files
Options
Downloads
Patches
Plain Diff
Updated noise_cifar_train.py to swap between models
parent
67cb659d
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
noise_cifar_train.py
+127
-111
127 additions, 111 deletions
noise_cifar_train.py
with
127 additions
and
111 deletions
noise_cifar_train.py
+
127
−
111
View file @
9d298f8f
...
...
@@ -11,6 +11,8 @@ import torch.backends.cudnn as cudnn
from
utils
import
train_noise
,
test
,
get_output
,
WeightEMA
from
dataset
import
get_cifar_dataset
from
networks.wideresnet
import
Wide_ResNet
from
networks.cnn
import
CNN
#Import custom CNN
from
networks.convnet
import
ConvNet
def
log
(
path
,
str
):
print
(
str
)
...
...
@@ -31,11 +33,18 @@ parser.add_argument('--sigma', type=float, default=0.5, help='STD of Gaussian no
parser
.
add_argument
(
'
--correction
'
,
type
=
int
,
default
=
250
,
help
=
'
correction start epoch
'
)
parser
.
add_argument
(
'
--gpu_id
'
,
type
=
int
,
default
=
0
,
help
=
'
index of gpu to use
'
)
parser
.
add_argument
(
'
--seed
'
,
type
=
int
,
default
=
0
,
help
=
'
random seed (default: 0)
'
)
parser
.
add_argument
(
'
--model
'
,
type
=
str
,
default
=
'
wide_res
'
,
help
=
'
model to use
'
)
#Added model argument
args
=
parser
.
parse_args
()
exp_name
=
'
sigma{:.1f}_{}_{}{:.1f}_seed{}
'
.
format
(
args
.
sigma
,
args
.
dataset
,
args
.
noise_mode
,
args
.
noise_rate
,
args
.
seed
)
#Add cnn to output name when using cnn model
if
args
.
model
==
'
cnn
'
:
exp_name
=
'
cnn_
'
+
exp_name
if
args
.
model
==
'
convnet
'
:
exp_name
=
'
convnet_
'
+
exp_name
if
0
<
args
.
correction
<
args
.
epochs
:
exp_name
=
'
correction_
'
+
exp_name
...
...
@@ -68,6 +77,13 @@ noisy_targets = np.eye(args.num_class)[noisy_targets] # to one-hot
# model
if
args
.
model
==
'
cnn
'
:
net
=
CNN
().
cuda
()
ema_net
=
CNN
().
cuda
()
elif
args
.
model
==
'
convnet
'
:
net
=
ConvNet
(
n_outputs
=
args
.
num_class
).
cuda
()
ema_net
=
ConvNet
(
n_outputs
=
args
.
num_class
).
cuda
()
else
:
net
=
Wide_ResNet
(
num_classes
=
args
.
num_class
).
cuda
()
ema_net
=
Wide_ResNet
(
num_classes
=
args
.
num_class
).
cuda
()
for
param
in
ema_net
.
parameters
():
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment