-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathCPR.py
More file actions
26 lines (20 loc) · 718 Bytes
/
CPR.py
File metadata and controls
26 lines (20 loc) · 718 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
# import wandb
from recq.recommenders.cpr import CPR
from recq.utils import Dataset
from recq.tools.parser import parse_args
from recq.tools.io import print_seperate_line
# Entry script: parse the run configuration, echo it, then build the dataset
# and the CPR model and start training.

# Parse the configuration under the "cpr" key (the exact options are defined
# by recq.tools.parser.parse_args, not in this file).
args = parse_args("cpr")

# Echo every parsed option as "key=value" between two separator lines so the
# run configuration is visible in the log output.
print_seperate_line()
for key, value in vars(args).items():
    print(f"{key}={value}")
print_seperate_line()

# Optional Weights & Biases experiment tracking (currently disabled).
# wandb.init(project="cpr_" + args.dataset, config=args)
# wandb.define_metric("epoch")
# wandb.define_metric("val/*", step_metric="epoch", summary="max")

# Resolve input/output locations relative to this script's directory:
#   <script dir>/data/<args.dataset>  - dataset location passed to Dataset
#   <script dir>/output/model         - directory passed to model.fit
curr_dir = os.path.dirname(__file__)
data_dir = os.path.join(curr_dir, "data", args.dataset)
model_dir = os.path.join(curr_dir, "output", "model")

# Build the dataset and the model from the same parsed arguments, then fit.
dataset = Dataset(args, data_dir)
model = CPR(args, dataset)
model.fit(args, model_dir)