evoaug.evoaug
Model (implemented in PyTorch Lightning) demonstrating how to use augmentations during training.
Module Contents
Classes
PyTorch Lightning module to specify how augmentation should be applied to a model.
Functions
Load PyTorch Lightning model from checkpoint.
Determine whether insertions are applied to determine the insert_max,
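The last summary entry above derives insert_max from the list of augmentations. A minimal sketch, assuming (hypothetically) that only insertion-style augmentations expose an `insert_max` attribute and that the padding length is the largest such value, or 0 when no insertion is configured. `compute_insert_max`, `ToyInsertion` and `ToyMutation` are illustrative names, not part of evoaug.

```python
from typing import Iterable


class ToyInsertion:
    """Hypothetical insertion augmentation that can lengthen a sequence by up to insert_max."""

    def __init__(self, insert_max: int) -> None:
        self.insert_max = insert_max


class ToyMutation:
    """Hypothetical augmentation that never changes sequence length (no insert_max attribute)."""


def compute_insert_max(augment_list: Iterable[object]) -> int:
    # Assumption: only insertion-style augmentations carry an insert_max attribute;
    # anything else contributes 0, and an empty list also yields 0.
    return max((getattr(aug, "insert_max", 0) for aug in augment_list), default=0)


print(compute_insert_max([ToyMutation(), ToyInsertion(20)]))  # 20
print(compute_insert_max([ToyMutation()]))  # 0
```

Taking the maximum is a design choice of this sketch: it guarantees the padding is long enough for any configured insertion.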
- class evoaug.evoaug.RobustModel(model, criterion, optimizer, augment_list=[], max_augs_per_seq=0, hard_aug=True, finetune=False, inference_aug=False)
Bases: pytorch_lightning.LightningModule
PyTorch Lightning module to specify how augmentation should be applied to a model.
- Parameters:
model (torch.nn.Module) – PyTorch model.
criterion (callable) – PyTorch loss function.
optimizer (torch.optim.Optimizer or dict) – PyTorch optimizer, given either as an optimizer object or as a dictionary.
augment_list (list) – List of data augmentations, each a callable class from augment.py. Default is an empty list (no augmentations).
max_augs_per_seq (int) – Maximum number of augmentations to apply to each sequence. If it exceeds the number of augmentations in augment_list, that number is used instead.
hard_aug (bool) – If True, apply exactly max_augs_per_seq augmentations to each sequence; otherwise, the number of augmentations per sequence is chosen randomly, up to max_augs_per_seq. Default is True.
finetune (bool) – If True, turn off augmentations during training. Default is False.
inference_aug (bool) – If True, turn on augmentations during inference. Default is False.
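How max_augs_per_seq, the length of augment_list, and hard_aug interact can be sketched as follows. This is an illustration, not evoaug's code: `num_augs_per_sequence` is a hypothetical helper, and drawing the random count uniformly from 1 to the cap is an assumption (the parameter description only says "up to max_augs_per_seq").

```python
import random
from typing import List


def num_augs_per_sequence(
    batch_size: int,
    max_augs_per_seq: int,
    n_available: int,
    hard_aug: bool,
    rng: random.Random,
) -> List[int]:
    """Hypothetical helper: how many augmentations each sequence in a batch receives."""
    # The cap can never exceed the number of augmentations in augment_list.
    cap = min(max_augs_per_seq, n_available)
    if cap == 0:
        return [0] * batch_size
    if hard_aug:
        # Fixed ("hard") count: every sequence gets exactly `cap` augmentations.
        return [cap] * batch_size
    # Assumption: the random count is drawn uniformly from 1..cap for each sequence.
    return [rng.randint(1, cap) for _ in range(batch_size)]


rng = random.Random(0)
print(num_augs_per_sequence(4, 5, 3, True, rng))  # [3, 3, 3, 3]
print(num_augs_per_sequence(4, 2, 3, False, rng))  # four values, each 1 or 2
```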
- forward(x)
Standard forward pass.
- configure_optimizers()
Standard optimizer configuration.
- training_step(batch, batch_idx)
Training step with augmentations (none are applied when finetune is True).
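The order of operations in an augmented training step can be sketched with plain Python stand-ins. Everything here is hypothetical: sequences are strings rather than the tensors a real PyTorch model would take, and `training_step_sketch`, the reverse-complement augmentation, the GC-fraction "model", and the MSE loss are placeholders. The sketch assumes the step augments the inputs (unless finetune is True), runs the model, and returns the criterion value as the loss.

```python
from typing import Callable, List, Tuple

Seqs = List[str]
COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def training_step_sketch(
    batch: Tuple[Seqs, List[float]],
    model: Callable[[Seqs], List[float]],
    criterion: Callable[[List[float], List[float]], float],
    augment: Callable[[Seqs], Seqs],
    finetune: bool = False,
) -> float:
    """Hypothetical training step: augment inputs (unless fine-tuning), predict, return the loss."""
    x, y = batch
    if not finetune:
        # Only the inputs are augmented; the targets y are left unchanged.
        x = augment(x)
    y_hat = model(x)
    return criterion(y_hat, y)


def reverse_complement_batch(x: Seqs) -> Seqs:
    return ["".join(COMPLEMENT[b] for b in reversed(s)) for s in x]


def gc_fraction_model(x: Seqs) -> List[float]:
    # Toy "model": predicts the GC fraction, which reverse complementing preserves.
    return [sum(b in "GC" for b in s) / len(s) for s in x]


def mse(y_hat: List[float], y: List[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(y_hat, y)) / len(y)


batch = (["ACGT", "GGGA"], [0.5, 0.75])
print(training_step_sketch(batch, gc_fraction_model, mse, reverse_complement_batch))  # 0.0
```

The toy augmentation was chosen to preserve the label, which is why the targets can be reused unchanged.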
- validation_step(batch, batch_idx)
Validation step; augmentations are applied only if inference_aug is True.
- test_step(batch, batch_idx)
Test step; augmentations are applied only if inference_aug is True.
- predict_step(batch, batch_idx)
Prediction step; augmentations are applied only if inference_aug is True.
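Input handling shared by the validation, test, and prediction steps can be sketched as below. `prepare_inference_inputs`, `toy_augment`, and `toy_pad` are hypothetical; the branch that still pads un-augmented sequences by insert_max (so their length matches what the model saw in training) is an assumption inferred from the _pad_end method and the insert_max summary entry, not documented behavior.

```python
from typing import Callable, List

Seqs = List[str]


def prepare_inference_inputs(
    x: Seqs,
    augment: Callable[[Seqs], Seqs],
    pad: Callable[[Seqs], Seqs],
    insert_max: int,
    inference_aug: bool,
) -> Seqs:
    """Hypothetical input handling shared by validation, test and predict steps."""
    if inference_aug:
        # Opt-in: evaluate on augmented sequences.
        return augment(x)
    if insert_max > 0:
        # Assumption: un-augmented sequences are still padded by insert_max so their
        # length matches what the model saw during training.
        return pad(x)
    # No insertion augmentations configured: inputs pass through unchanged.
    return x


def toy_augment(x: Seqs) -> Seqs:
    return [s[::-1] for s in x]  # stand-in augmentation: reverse each string


def toy_pad(x: Seqs) -> Seqs:
    return [s + "NNN" for s in x]  # stand-in for 3 random nucleotides (insert_max = 3)


print(prepare_inference_inputs(["ACGT"], toy_augment, toy_pad, 3, False))  # ['ACGTNNN']
print(prepare_inference_inputs(["ACGT"], toy_augment, toy_pad, 3, True))  # ['TGCA']
print(prepare_inference_inputs(["ACGT"], toy_augment, toy_pad, 0, False))  # ['ACGT']
```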
- _sample_aug_combos(batch_size)
Set the number of augmentations and randomly select augmentations to apply to each sequence.
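Selecting which augmentations each sequence receives, once the per-sequence counts are known, can be sketched as sampling distinct indices into augment_list. `sample_aug_combos` is a hypothetical helper; sampling without replacement and sorting the indices (so augmentations run in augment_list order) are assumptions of this sketch.

```python
import random
from typing import List


def sample_aug_combos(num_augs: List[int], n_available: int, rng: random.Random) -> List[List[int]]:
    """Hypothetical helper: pick augmentation indices (into augment_list) for each sequence."""
    combos = []
    for k in num_augs:
        # Draw k distinct indices without replacement, then sort them (assumption)
        # so the selected augmentations run in the order they appear in augment_list.
        combos.append(sorted(rng.sample(range(n_available), k)))
    return combos


rng = random.Random(42)
print(sample_aug_combos([2, 2, 1], n_available=3, rng=rng))
```

Sampling without replacement means a sequence never receives the same augmentation twice in one step.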
- _apply_augment(x)
Apply augmentations to each sequence in the batch x.
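Applying a per-sequence combination of augmentations can be sketched on strings. `apply_augment` here is a hypothetical stand-in, and the toy augmentations (`reverse_complement`, a fixed-length `insert_gg` insertion, `pad_nn` padding) are placeholders. The rule that sequences without an insertion get padded, so the whole batch shares one length, is an assumption based on the insert_max and _pad_end descriptions.

```python
from typing import Callable, List, Sequence, Set

COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A"}


def apply_augment(
    x: List[str],
    augment_list: Sequence[Callable[[str], str]],
    combos: List[List[int]],
    insertion_indices: Set[int],
    pad: Callable[[str], str],
) -> List[str]:
    """Hypothetical sketch: augment each sequence with its own sampled combination."""
    out = []
    for seq, combo in zip(x, combos):
        for idx in combo:
            # Apply the selected augmentations one after another.
            seq = augment_list[idx](seq)
        if not insertion_indices.intersection(combo):
            # Assumption: sequences without an insertion are padded so that every
            # sequence in the batch ends up with the same length.
            seq = pad(seq)
        out.append(seq)
    return out


def reverse_complement(seq: str) -> str:
    return "".join(COMPLEMENT[b] for b in reversed(seq))


def insert_gg(seq: str) -> str:
    return seq[:1] + "GG" + seq[1:]  # toy insertion of fixed length 2 (insert_max = 2)


def pad_nn(seq: str) -> str:
    return seq + "NN"  # stand-in for 2 random nucleotides


augment_list = [reverse_complement, insert_gg]
print(apply_augment(["AACC", "ACGT"], augment_list, [[0], [0, 1]], {1}, pad_nn))
# ['GGTTNN', 'AGGCGT']
```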
- _pad_end(x)
Add random DNA padding of length insert_max to the end of each sequence in the batch.
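Random DNA padding on one-hot encoded sequences can be sketched as follows, assuming each sequence is a (4, L) one-hot array with rows in A, C, G, T order. Nested lists stand in for tensors to keep the sketch dependency-free; `one_hot` and `pad_end` are hypothetical helpers.

```python
import random
from typing import List

NUCLEOTIDES = "ACGT"
OneHot = List[List[int]]


def one_hot(seq: str) -> OneHot:
    # Shape (4, L): one row per nucleotide in A, C, G, T order (assumed alphabet order).
    return [[1 if base == nt else 0 for base in seq] for nt in NUCLEOTIDES]


def pad_end(x: List[OneHot], insert_max: int, rng: random.Random) -> List[OneHot]:
    """Hypothetical sketch: append insert_max random one-hot positions to each sequence."""
    padded = []
    for seq in x:
        # One random nucleotide per padded position, shared across the 4 rows so
        # each new column contains exactly one 1.
        picks = [rng.randrange(4) for _ in range(insert_max)]
        padded.append([row + [1 if p == i else 0 for p in picks] for i, row in enumerate(seq)])
    return padded


rng = random.Random(7)
batch = [one_hot("ACGT"), one_hot("GGCA")]
out = pad_end(batch, insert_max=3, rng=rng)
print([len(row) for row in out[0]])  # [7, 7, 7, 7]
```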
- finetune_mode(optimizer=None)
Turn on the finetune flag so that no augmentations are applied during training.
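The two-stage use implied by finetune_mode (train with augmentations, then continue training without them) can be sketched with a hypothetical stand-in class. How the optional optimizer argument is used is not documented here; the sketch assumes a supplied optimizer simply replaces the stored one, and plain strings stand in for optimizer objects.

```python
from typing import Optional


class FinetuneToggleSketch:
    """Hypothetical stand-in for RobustModel's finetune bookkeeping (not evoaug code)."""

    def __init__(self, optimizer: object, finetune: bool = False) -> None:
        self.optimizer = optimizer
        self.finetune = finetune

    def finetune_mode(self, optimizer: Optional[object] = None) -> None:
        # Stop applying augmentations in all subsequent training steps.
        self.finetune = True
        # Assumption: a supplied optimizer replaces the one used so far.
        if optimizer is not None:
            self.optimizer = optimizer

    def augments_training_batches(self) -> bool:
        return not self.finetune


sketch = FinetuneToggleSketch(optimizer="stage-1 optimizer")
print(sketch.augments_training_batches())  # True
sketch.finetune_mode(optimizer="stage-2 optimizer")
print(sketch.augments_training_batches(), sketch.optimizer)  # False stage-2 optimizer
```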
- evoaug.evoaug.load_model_from_checkpoint(model, checkpoint_path)
Load PyTorch Lightning model from checkpoint.
- Parameters:
model (RobustModel) – RobustModel instance.
checkpoint_path (str) – Path to the checkpoint of model weights.
- Returns:
Object with weights and config loaded from checkpoint.
- Return type: