Skip to content

restart

xvr.cli.commands.restart

restart

restart(ckptpath: str, id: str, project: str)

Restart model training from a checkpoint.

Source code in src/xvr/cli/commands/restart.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
@click.command(cls=CategorizedCommand)
@categorized_option(
    "-c",
    "--ckptpath",
    required=True,
    type=click.Path(exists=True),
    help="Checkpoint of a pretrained pose regressor",
)
@categorized_option(
    "--id",
    default=None,
    type=str,
    help="WandB run ID",
)
@categorized_option(
    "--project",
    type=str,
    default=None,
    help="WandB project name",
)
def restart(
    ckptpath: str,
    id: str,
    project: str,
):
    """
    Restart model training from a checkpoint.

    Resolves ``ckptpath`` (a directory is resolved to its last ``*.pth``
    file), reloads the training config stored inside the checkpoint,
    resumes the matching WandB run, and continues training with the
    saved optimizer state reused.
    """
    import os
    from pathlib import Path

    import torch
    import wandb

    from ...model import Trainer

    # If ckptpath is a directory, get the last saved model
    ckptpath = Path(ckptpath)
    if ckptpath.is_dir():
        checkpoints = sorted(ckptpath.glob("*.pth"))
        if not checkpoints:
            # Fail with a clear message instead of a bare IndexError
            raise click.ClickException(
                f"No *.pth checkpoint files found in directory: {ckptpath}"
            )
        # NOTE(review): lexicographic sort assumes checkpoint filenames
        # order chronologically (e.g. zero-padded epochs) — confirm.
        ckptpath = checkpoints[-1]
    ckptpath = str(ckptpath)

    # Load the config from the previous model checkpoint
    # weights_only=False: the checkpoint stores a full config dict,
    # not just tensors, so unrestricted unpickling is required here.
    config = torch.load(ckptpath, weights_only=False)["config"]
    config["ckptpath"] = ckptpath
    config["reuse_optimizer"] = True

    # Set up logging — requires WANDB_API_KEY in the environment
    # (raises KeyError if unset).
    wandb.login(key=os.environ["WANDB_API_KEY"])
    # CLI --project overrides the project recorded in the checkpoint
    project = config["project"] if project is None else project
    run = wandb.init(project=project, id=id, config=config, resume="must")

    # Train the model
    trainer = Trainer(**config)
    trainer.train(run)