Restart model training from a checkpoint.
Source code in src/xvr/cli/commands/restart.py
@click.command(cls=CategorizedCommand)
@categorized_option(
    "-c",
    "--ckptpath",
    required=True,
    type=click.Path(exists=True),
    help="Checkpoint of a pretrained pose regressor",
)
@categorized_option(
    "--id",
    default=None,
    type=str,
    help="WandB run ID",
)
@categorized_option(
    "--project",
    type=str,
    default=None,
    help="WandB project name",
)
def restart(
    ckptpath: str,
    id: str,
    project: str,
):
    """
    Restart model training from a checkpoint.

    Reads the training config stored inside the checkpoint, resumes the
    matching WandB run, and continues training with the saved optimizer
    state (``reuse_optimizer=True``).
    """
    import os
    from pathlib import Path

    import torch
    import wandb

    from ...model import Trainer

    # If ckptpath is a directory, pick the last checkpoint by sorted filename.
    # NOTE(review): lexicographic order — presumably filenames zero-pad their
    # epoch/step numbers; verify against the trainer's checkpoint naming.
    ckptpath = Path(ckptpath)
    if ckptpath.is_dir():
        checkpoints = sorted(ckptpath.glob("*.pth"))
        if not checkpoints:
            # Fail with an actionable message instead of a bare IndexError
            raise click.ClickException(
                f"No *.pth checkpoints found in directory: {ckptpath}"
            )
        ckptpath = checkpoints[-1]
    ckptpath = str(ckptpath)

    # Load the config from the previous model checkpoint. map_location="cpu"
    # lets this run on a machine without the GPU the checkpoint was saved on;
    # only the plain-Python "config" dict is used here.
    config = torch.load(ckptpath, map_location="cpu", weights_only=False)["config"]
    config["ckptpath"] = ckptpath
    config["reuse_optimizer"] = True  # restart keeps the saved optimizer state

    # Set up logging; resume="must" requires the ID of the original run
    wandb.login(key=os.environ["WANDB_API_KEY"])
    project = config["project"] if project is None else project
    run = wandb.init(project=project, id=id, config=config, resume="must")

    # Train the model
    trainer = Trainer(**config)
    trainer.train(run)