-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathmain.py
More file actions
84 lines (62 loc) · 2.7 KB
/
main.py
File metadata and controls
84 lines (62 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import logging
import pprint
import os
import sys
import shutil
from datetime import datetime
from torch import distributed as dist
from matdeeplearn.common.config.build_config import build_config
from matdeeplearn.common.config.flags import flags
from matdeeplearn.common.trainer_context import new_trainer_context
from matdeeplearn.preprocessor.processor import process_data
# import submitit
# from matdeeplearn.common.utils import setup_logging
class Runner:  # submitit.helpers.Checkpointable):
    """Callable that builds a trainer context from a config and runs its task.

    ``__call__`` stores the context's ``config``, ``task`` and ``trainer`` on
    the instance; ``checkpoint`` reads those attributes afterwards.
    """

    def __init__(self):
        # Populated by __call__; declared up front so the attributes always
        # exist on the instance, even before the first run.
        self.config = None
        self.task = None
        self.trainer = None

    def __call__(self, config):
        """Set up and run the task described by ``config``.

        Args:
            config: Configuration mapping. ``config["task"]["log_id"]`` is
                read to locate ``log_<log_id>.txt`` in the current working
                directory once the task has run.

        Note:
            ``args`` is not a parameter of this method; it is resolved as a
            module-level name, which this file assigns in its ``__main__``
            entry point.
        """
        with new_trainer_context(args=args, config=config) as ctx:
            self.config = ctx.config
            self.task = ctx.task
            self.trainer = ctx.trainer
            self.task.setup(self.trainer)

            # Print settings for job
            logging.debug("Settings: ")
            logging.debug(pprint.pformat(self.config))

            self.task.run()

            log_source = f"log_{config['task']['log_id']}.txt"
            log_target = os.path.join(
                self.trainer.save_dir,
                "results",
                self.trainer.timestamp_id,
                "log.txt",
            )
            # The entry point creates the log file only when LOCAL_RANK is
            # unset or 0, so it may be absent for this process. Skip the move
            # instead of raising after the task has already finished.
            if os.path.isfile(log_source):
                shutil.move(log_source, log_target)
            else:
                logging.warning(
                    "Log file %s not found; not moving it to %s",
                    log_source,
                    log_target,
                )

    def checkpoint(self, *args, **kwargs):
        """Save a checkpoint and record its location in ``self.config``.

        Positional and keyword arguments are accepted but not used.
        Requires ``__call__`` to have populated ``trainer``, ``task`` and
        ``config`` first.
        """
        # new_runner = Runner()
        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id
        if self.trainer.logger is not None:
            self.trainer.logger.mark_preempting()
        # return submitit.helpers.DelayedSubmission(new_runner, self.config)
if __name__ == "__main__":
    # setup_logging()
    # The root logger and the on-disk log file are configured only when
    # LOCAL_RANK is unset or equal to 0.
    local_rank = os.environ.get("LOCAL_RANK", None)
    if local_rank is None or int(local_rank) == 0:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)

        # Timestamp trimmed from microseconds to milliseconds ([:-3]); it is
        # used as the log file suffix and stored in the config as the log id.
        timestamp = datetime.now().timestamp()
        timestamp_id = datetime.fromtimestamp(timestamp).strftime(
            "%Y-%m-%d-%H-%M-%S-%f"
        )[:-3]

        # File handler: writes to log_<timestamp_id>.txt in the working
        # directory ("w+" truncates any existing file of that name).
        fh = logging.FileHandler(f"log_{timestamp_id}.txt", "w+")
        fh.setLevel(logging.DEBUG)
        root_logger.addHandler(fh)

        # Stream handler: mirrors the same records to stdout.
        sh = logging.StreamHandler(sys.stdout)
        sh.setLevel(logging.DEBUG)
        root_logger.addHandler(sh)

    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)

    # NOTE(review): `timestamp_id` is assigned only inside the branch above,
    # so a process with a non-zero LOCAL_RANK would raise NameError here.
    # Confirm whether such processes are expected to reach this line.
    config["task"]["log_id"] = timestamp_id

    if not config["dataset"]["processed"]:
        process_data(config["dataset"])

    if args.submit:  # Run on cluster
        # TODO: add setup to submit to cluster
        pass
    else:  # Run locally
        Runner()(config)