Unverified commit cb711401, authored by Timothée Darcet, committed by GitHub

Fix XCiT hub.load issue (#140)


* hub.load xcit from main branch

* Fix other files using hub.load

* Also change torch.hub.list

Co-authored-by: Timothee Darcet <tdarcet@inrialpes.fr>
parent 499d9e2b
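The change itself is small: every torch.hub call that referenced the bare repo name 'facebookresearch/xcit' now pins the branch with 'facebookresearch/xcit:main'. Below is a minimal sketch of the two call forms, assuming the XCiT repository's default branch is main and that the torch.hub release in use still falls back to a master ref when no branch is given (the likely cause of the reported failure):

import torch

# Bare repo name: torch.hub picks a default ref, which can miss a repo
# whose default branch is "main" (the failure this commit fixes).
# model = torch.hub.load('facebookresearch/xcit', 'xcit_small_12_p16', num_classes=0)

# Branch-qualified ref, as used throughout this commit: always resolves
# to the main branch of facebookresearch/xcit.
model = torch.hub.load('facebookresearch/xcit:main', 'xcit_small_12_p16', num_classes=0)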
@@ -132,7 +132,7 @@ if __name__ == '__main__':
         model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
         print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
     elif "xcit" in args.arch:
-        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
+        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
     elif args.arch in torchvision_models.__dict__.keys():
         model = torchvision_models.__dict__[args.arch](num_classes=0)
     else:
@@ -60,7 +60,7 @@ def extract_feature_pipeline(args):
         model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
         print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
     elif "xcit" in args.arch:
-        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
+        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
     elif args.arch in torchvision_models.__dict__.keys():
         model = torchvision_models.__dict__[args.arch](num_classes=0)
         model.fc = nn.Identity()
@@ -41,7 +41,7 @@ def eval_linear(args):
         embed_dim = model.embed_dim * (args.n_last_blocks + int(args.avgpool_patchtokens))
     # if the network is a XCiT
     elif "xcit" in args.arch:
-        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
+        model = torch.hub.load('facebookresearch/xcit:main', args.arch, num_classes=0)
         embed_dim = model.embed_dim
     # otherwise, we check if the architecture is in torchvision models
     elif args.arch in torchvision_models.__dict__.keys():
@@ -99,7 +99,7 @@ def dino_xcit_small_12_p16(pretrained=True, **kwargs):
     """
     XCiT-Small-12/16 pre-trained with DINO.
     """
-    model = torch.hub.load('facebookresearch/xcit', "xcit_small_12_p16", num_classes=0, **kwargs)
+    model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p16", num_classes=0, **kwargs)
     if pretrained:
         state_dict = torch.hub.load_state_dict_from_url(
             url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth",
@@ -113,7 +113,7 @@ def dino_xcit_small_12_p8(pretrained=True, **kwargs):
     """
     XCiT-Small-12/8 pre-trained with DINO.
     """
-    model = torch.hub.load('facebookresearch/xcit', "xcit_small_12_p8", num_classes=0, **kwargs)
+    model = torch.hub.load('facebookresearch/xcit:main', "xcit_small_12_p8", num_classes=0, **kwargs)
     if pretrained:
         state_dict = torch.hub.load_state_dict_from_url(
             url="https://dl.fbaipublicfiles.com/dino/dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth",
@@ -127,7 +127,7 @@ def dino_xcit_medium_24_p16(pretrained=True, **kwargs):
     """
     XCiT-Medium-24/16 pre-trained with DINO.
     """
-    model = torch.hub.load('facebookresearch/xcit', "xcit_medium_24_p16", num_classes=0, **kwargs)
+    model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p16", num_classes=0, **kwargs)
     if pretrained:
         state_dict = torch.hub.load_state_dict_from_url(
             url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
@@ -141,7 +141,7 @@ def dino_xcit_medium_24_p8(pretrained=True, **kwargs):
     """
    XCiT-Medium-24/8 pre-trained with DINO.
     """
-    model = torch.hub.load('facebookresearch/xcit', "xcit_medium_24_p8", num_classes=0, **kwargs)
+    model = torch.hub.load('facebookresearch/xcit:main', "xcit_medium_24_p8", num_classes=0, **kwargs)
     if pretrained:
         state_dict = torch.hub.load_state_dict_from_url(
             url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth",
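The four hub entry points above are what end users call through torch.hub, so the fix also matters downstream of this repository. A hedged usage sketch, assuming this hubconf.py is published under facebookresearch/dino with main as its default branch:

import torch

# Load the DINO-pretrained XCiT-Small-12/16 backbone via the entry point
# defined above; network access is required on the first call.
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16', pretrained=True)
backbone.eval()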
@@ -44,7 +44,7 @@ def get_args_parser():
     # Model parameters
     parser.add_argument('--arch', default='vit_small', type=str,
         choices=['vit_tiny', 'vit_small', 'vit_base', 'xcit', 'deit_tiny', 'deit_small'] \
-                + torchvision_archs + torch.hub.list("facebookresearch/xcit"),
+                + torchvision_archs + torch.hub.list("facebookresearch/xcit:main"),
         help="""Name of architecture to train. For quick experiments with ViTs,
         we recommend using vit_tiny or vit_small.""")
     parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
@@ -166,10 +166,10 @@ def train_dino(args):
         teacher = vits.__dict__[args.arch](patch_size=args.patch_size)
         embed_dim = student.embed_dim
     # if the network is a XCiT
-    elif args.arch in torch.hub.list("facebookresearch/xcit"):
-        student = torch.hub.load('facebookresearch/xcit', args.arch,
+    elif args.arch in torch.hub.list("facebookresearch/xcit:main"):
+        student = torch.hub.load('facebookresearch/xcit:main', args.arch,
                                  pretrained=False, drop_path_rate=args.drop_path_rate)
-        teacher = torch.hub.load('facebookresearch/xcit', args.arch, pretrained=False)
+        teacher = torch.hub.load('facebookresearch/xcit:main', args.arch, pretrained=False)
         embed_dim = student.embed_dim
     # otherwise, we check if the architecture is in torchvision models
     elif args.arch in torchvision_models.__dict__.keys():
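Because the --arch choices are now built from torch.hub.list("facebookresearch/xcit:main"), the available XCiT architecture names come straight from that branch's hubconf. A quick sanity check, assuming network access (torch.hub fetches the repo on first use):

import torch

# List the entry points exposed by the main branch of facebookresearch/xcit
# and confirm that a name passed to --arch is actually served there.
xcit_archs = torch.hub.list("facebookresearch/xcit:main")
print("xcit_small_12_p16" in xcit_archs)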