# PyTorch Lightning: save a checkpoint at the end of every epoch
class CheckpointEveryEpoch(pl.Callback):
    """Callback that writes a checkpoint at the end of each epoch.

    Saving begins once ``trainer.current_epoch`` reaches ``start_epoc``; each
    checkpoint is written to ``f"{save_path}_e{epoch}.ckpt"``.
    """

    def __init__(self, start_epoc, save_path,):
        """Store the checkpointing configuration.

        Args:
            start_epoc: First epoch index (as reported by
                ``trainer.current_epoch``) for which a checkpoint is saved.
            save_path: Path prefix for checkpoint files; ``_e<epoch>.ckpt``
                is appended to it.
        """
        super().__init__()
        self.start_epoc = start_epoc
        # The hook below reads ``self.save_path``; previously only
        # ``file_path`` was assigned, so the first save raised AttributeError.
        self.save_path = save_path
        # Kept as an alias for any code that read the old attribute name.
        self.file_path = save_path

    # NOTE(review): newer Lightning releases replaced ``on_epoch_end`` with
    # ``on_train_epoch_end`` -- confirm the installed version still dispatches
    # this hook, otherwise no checkpoint will ever be written.
    def on_epoch_end(self, trainer: pl.Trainer, _):
        """Save a checkpoint if the current epoch is at or past ``start_epoc``.

        Args:
            trainer: The trainer invoking the hook; supplies the current
                epoch and performs the save.
            _: Second hook argument, unused here.
        """
        epoch = trainer.current_epoch
        if epoch >= self.start_epoc:
            ckpt_path = f"{self.save_path}_e{epoch}.ckpt"
            trainer.save_checkpoint(ckpt_path)
# Checkpoints are written from epoch 2 onwards (trainer.current_epoch >= 2).
trainer = Trainer(
    callbacks=[CheckpointEveryEpoch(2, args.save_path)],
)