-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
54 lines (45 loc) · 2.27 KB
/
utils.py
File metadata and controls
54 lines (45 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Utility functions for listing the available pretrained video models and
loading them from the Hugging Face Hub.
"""
import torch
import models_mae, models_vit
from huggingface_hub import hf_hub_download
def get_available_models():
    """Return the identifiers of every model that ``load_model`` accepts.

    Each identifier has the form
    ``'<model_type>_<pretrain_data>_<finetune_data>'``. MAE models exist only
    without finetuning (``'none'``); ViT models exist for every pretraining
    source crossed with every finetuning setting.

    Returns:
        A new list of 24 model-name strings. MAE names come first, then ViT
        names grouped by finetuning setting.
    """
    pretrain_sources = ('say', 's', 'kinetics', 'kinetics-200h')
    vit_finetunes = (
        'none',
        'ssv2-50shot',
        'ssv2-10shot',
        'kinetics-50shot',
        'kinetics-10shot',
    )

    # MAE checkpoints: pretrained only, no finetuned variants.
    mae_names = [f'mae_{source}_none' for source in pretrain_sources]

    # ViT checkpoints: outer loop over finetuning setting, inner loop over
    # pretraining source, which reproduces the canonical listing order.
    vit_names = [
        f'vit_{source}_{finetune}'
        for finetune in vit_finetunes
        for source in pretrain_sources
    ]

    return mae_names + vit_names
class _UnrecognizedModelError(ValueError, AssertionError):
    """Raised by ``load_model`` when given a model name it does not know.

    It subclasses ``ValueError``, the conventional type for a bad argument
    value. It also subclasses ``AssertionError``, so callers written against
    the earlier ``assert``-based validation keep catching it.
    """


def load_model(model_name):
    """Download a pretrained video-model checkpoint and load it into a model.

    The checkpoint ``<pretrain_data>_<finetune_data>.pth`` is fetched from the
    Hugging Face Hub repository ``eminorhan/video-models``. Its ``'model'``
    entry is read with tensors mapped to CPU and loaded into a freshly built
    network.

    Args:
        model_name: One of the identifiers returned by
            ``get_available_models()``, of the form
            ``'<model_type>_<pretrain_data>_<finetune_data>'``.

    Returns:
        The model with checkpoint weights loaded. For ``'mae'`` names this is
        ``models_mae.mae_vit_huge_patch14()``; for ``'vit'`` names it is
        ``models_vit.vit_huge_patch14(num_classes=...)``.

    Raises:
        ValueError: If ``model_name`` is not an available model. The raised
            exception is also an ``AssertionError``, for backward
            compatibility.
        RuntimeError: For ``'mae'`` models, if the checkpoint's keys do not
            exactly match the model (weights are loaded with ``strict=True``).
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let unknown names through to the download.
    if model_name not in get_available_models():
        raise _UnrecognizedModelError('Unrecognized model!')

    # Every available name has exactly three '_'-separated parts. The
    # membership check above already guarantees that each part is a
    # supported value.
    model_type, pretrain_data, finetune_data = model_name.split('_')

    # The filename depends only on (pretrain_data, finetune_data), so e.g.
    # 'mae_say_none' and 'vit_say_none' resolve to the same file.
    # hf_hub_download returns the local path of the downloaded (cached) file.
    ckpt_path = hf_hub_download(
        repo_id='eminorhan/video-models',
        filename=f'{pretrain_data}_{finetune_data}.pth',
    )

    if model_type == 'mae':
        model = models_mae.mae_vit_huge_patch14()
        strict = True
    else:  # model_type == 'vit'
        # Classifier output size is chosen by the finetuning setting:
        # 174 for 'ssv2*', 700 for 'kinetics*', None when not finetuned.
        if finetune_data.startswith('ssv2'):
            num_classes = 174
        elif finetune_data.startswith('kinetics'):
            num_classes = 700
        else:
            num_classes = None
        model = models_vit.vit_huge_patch14(num_classes=num_classes)
        # Non-strict loading tolerates missing/unexpected keys. For 'none',
        # the file is the same one the MAE model loads, so keys will not
        # line up exactly with a ViT.
        # NOTE(review): confirm why finetuned ViT checkpoints also need
        # strict=False.
        strict = False

    # NOTE(review): torch.load unpickles the file. The repo is fixed, but only
    # load checkpoints you trust. Newer torch versions default to
    # weights_only=True, which may reject non-tensor objects stored in the
    # checkpoint; confirm the checkpoint contents.
    state_dict = torch.load(ckpt_path, map_location='cpu')['model']
    # load_state_dict returns the missing/unexpected keys; print them so
    # partial loads are visible.
    msg = model.load_state_dict(state_dict, strict=strict)
    print(f'Loaded with message: {msg}')
    return model