Skip to content
Snippets Groups Projects
Unverified Commit ad5a262b authored by qasfb's avatar qasfb Committed by GitHub
Browse files

Update param_groups.py (#283)

* Update param_groups.py

Update lr decay rates for reg tokens

* Update param_groups.py
parent e203621e
No related branches found
No related tags found
No related merge requests found
...@@ -22,10 +22,10 @@ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backb ...@@ -22,10 +22,10 @@ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backb
""" """
layer_id = num_layers + 1 layer_id = num_layers + 1
if name.startswith("backbone") or force_is_backbone: if name.startswith("backbone") or force_is_backbone:
if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name: if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name or ".register_tokens" in name:
layer_id = 0 layer_id = 0
elif force_is_backbone and ( elif force_is_backbone and (
"pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name "pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name or "register_tokens" in name
): ):
layer_id = 0 layer_id = 0
elif ".blocks." in name and ".residual." not in name: elif ".blocks." in name and ".residual." not in name:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment