Loading a trained model

Trained models can be loaded from a checkpoint using load_model_checkpoint:

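import torch

from frogbox.utils import load_model_checkpoint

# Load the model and its config from a checkpoint file.
model, config = load_model_checkpoint("checkpoint/cool-giraffe-123/checkpoint_1234.pt")
model = model.eval()  # switch to evaluation mode

# Forward pass on a dummy input; the shape is only an example.
with torch.inference_mode():
    pred = model(torch.randn(1, 3, 128, 128))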

frogbox.utils

load_model_checkpoint

load_model_checkpoint(path, config_path=None)

Load model from checkpoint.

Parameters:

  • path

    (path) –

    Path to checkpoint file.

  • config_path

    (path, default: None) –

    Path to config file. If None, the config is read from "config.json" in the same folder as path.

Returns:

  • checkpoint ((Module, Config)) –

    The loaded model and its config.
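
The default lookup described for config_path can be illustrated with a small standalone sketch. The helper resolve_config_path and the path "configs/custom.json" are hypothetical and not part of frogbox; the sketch only mirrors the documented behaviour (fall back to "config.json" next to the checkpoint) and may differ from the actual implementation.

```python
from pathlib import Path


def resolve_config_path(checkpoint_path, config_path=None):
    """Hypothetical helper (not a frogbox function) mirroring the documented default."""
    if config_path is not None:
        # An explicit config path takes precedence.
        return Path(config_path)
    # Documented fallback: "config.json" in the same folder as the checkpoint.
    return Path(checkpoint_path).parent / "config.json"


checkpoint = "checkpoint/cool-giraffe-123/checkpoint_1234.pt"
print(resolve_config_path(checkpoint).as_posix())
print(resolve_config_path(checkpoint, "configs/custom.json").as_posix())
```

In practice this means a checkpoint folder that still contains its original "config.json" can be loaded with the path argument alone, while a checkpoint file moved elsewhere would need config_path passed explicitly.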