evoaug.evoaug

Model (implemented in PyTorch Lightning) demonstrating how to use augmentations during training.

Module Contents

Classes

RobustModel

PyTorch Lightning module to specify how augmentation should be applied to a model.

Functions

load_model_from_checkpoint(model, checkpoint_path)

Load a PyTorch Lightning model from a checkpoint.

augment_max_len(augment_list)

Determine whether insertions are applied and, if so, the insert_max used to pad other sequences with random DNA.

class evoaug.evoaug.RobustModel(model, criterion, optimizer, augment_list=[], max_augs_per_seq=0, hard_aug=True, finetune=False, inference_aug=False)

Bases: pytorch_lightning.LightningModule

PyTorch Lightning module to specify how augmentation should be applied to a model.

Parameters:
  • model (torch.nn.Module) – PyTorch model.

  • criterion (callable) – PyTorch loss function.

  • optimizer (torch.optim.Optimizer or dict) – PyTorch optimizer, given as a class or a dictionary.

  • augment_list (list) – List of data augmentations, each a callable class from augment.py. Default is an empty list (no augmentations).

  • max_augs_per_seq (int) – Maximum number of augmentations to apply to each sequence. Value is capped at the number of augmentations in augment_list.

  • hard_aug (bool) – Flag to apply a fixed number of augmentations (exactly max_augs_per_seq) to each sequence; otherwise, the number of augmentations is set randomly up to max_augs_per_seq. Default is True.

  • finetune (bool) – Flag to turn off augmentations during training. Default is False.

  • inference_aug (bool) – Flag to turn on augmentations during inference. Default is False.
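
A minimal construction sketch follows. The augmentation classes and their keyword arguments (RandomDeletion, RandomRC, RandomMutation) are assumptions based on augment.py and should be checked against that module; the CNN is purely illustrative:

   import torch
   from evoaug import evoaug, augment

   # Illustrative CNN for one-hot DNA of shape (N, 4, L).
   model = torch.nn.Sequential(
       torch.nn.Conv1d(4, 32, kernel_size=19, padding=9),
       torch.nn.ReLU(),
       torch.nn.AdaptiveAvgPool1d(1),
       torch.nn.Flatten(),
       torch.nn.Linear(32, 1),
   )

   # Assumed augmentation classes/kwargs from evoaug.augment.
   augment_list = [
       augment.RandomDeletion(delete_min=0, delete_max=30),
       augment.RandomRC(rc_prob=0.5),
       augment.RandomMutation(mutate_frac=0.05),
   ]

   robust_model = evoaug.RobustModel(
       model,
       criterion=torch.nn.BCEWithLogitsLoss(),
       optimizer=torch.optim.Adam,   # passed as a class, per the docstring
       augment_list=augment_list,
       max_augs_per_seq=2,           # at most 2 augmentations per sequence
       hard_aug=True,                # apply exactly 2 augmentations
   )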

forward(x)

Standard forward pass.

configure_optimizers()

Standard optimizer configuration.

training_step(batch, batch_idx)

Training step with augmentations.

validation_step(batch, batch_idx)

Validation step without augmentations (or with, if inference_aug is True).

test_step(batch, batch_idx)

Test step without augmentations (or with, if inference_aug is True).

predict_step(batch, batch_idx)

Prediction step without augmentations (or with, if inference_aug is True).
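
These step hooks are driven by a standard pytorch_lightning.Trainer. A sketch with toy data, reusing the robust_model instance from the construction example above:

   import pytorch_lightning as pl
   import torch
   from torch.utils.data import DataLoader, TensorDataset

   # Toy batch: random one-hot DNA of shape (N, 4, L) with scalar labels.
   x = torch.eye(4)[torch.randint(4, (128, 200))].permute(0, 2, 1)
   y = torch.rand(128, 1)
   loader = DataLoader(TensorDataset(x, y), batch_size=32)

   trainer = pl.Trainer(max_epochs=3, accelerator="auto")
   trainer.fit(robust_model, loader, loader)      # training_step / validation_step
   trainer.test(robust_model, loader)             # test_step
   preds = trainer.predict(robust_model, loader)  # predict_step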

_sample_aug_combos(batch_size)

Set the number of augmentations and randomly select augmentations to apply to each sequence.
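
A sketch of the idea (illustrative only, not the actual implementation): with hard_aug, every sequence gets exactly max_augs_per_seq augmentations; otherwise a count is drawn at random per sequence:

   import numpy as np

   def sample_aug_combos_sketch(batch_size, max_augs_per_seq, num_augs, hard_aug):
       # Number of augmentations per sequence: fixed (hard) or random.
       if hard_aug:
           counts = np.full(batch_size, max_augs_per_seq)
       else:
           counts = np.random.randint(1, max_augs_per_seq + 1, batch_size)
       # Randomly choose which augmentations (by index) each sequence receives.
       return [np.random.choice(num_augs, size=k, replace=False) for k in counts]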

_apply_augment(x)

Apply augmentations to each sequence in the batch x.

_pad_end(x)

Add random DNA padding of length insert_max to the end of each sequence in the batch.
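
The padding keeps batch lengths consistent when insertion augmentations are in augment_list: sequences that do not receive an insertion are extended with insert_max random nucleotides. A conceptual sketch, assuming one-hot (N, 4, L) inputs (not the actual implementation):

   import torch

   def pad_end_sketch(x, insert_max):
       # Random nucleotide indices for each sequence in the batch.
       n, a, _ = x.shape
       idx = torch.randint(a, (n, insert_max))
       # One-hot encode and append along the length dimension.
       pad = torch.eye(a, device=x.device)[idx].permute(0, 2, 1)
       return torch.cat([x, pad], dim=2)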

finetune_mode(optimizer=None)

Turn on the finetune flag so that no augmentations are applied during training.
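
This supports a two-stage recipe: train with augmentations, then fine-tune briefly on unaugmented data. A hedged sketch, reusing robust_model, trainer, and loader from the sketches above; whether the optional optimizer argument expects a class or an instance should be checked against the implementation:

   import torch
   import pytorch_lightning as pl

   # Stage 1: augmented training (see the Trainer sketch above).
   trainer.fit(robust_model, loader, loader)

   # Stage 2: disable augmentations and fine-tune, optionally swapping in a
   # fresh optimizer (assumed to follow the same convention as __init__).
   robust_model.finetune_mode(optimizer=torch.optim.Adam)
   finetune_trainer = pl.Trainer(max_epochs=1, accelerator="auto")
   finetune_trainer.fit(robust_model, loader, loader)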

evoaug.evoaug.load_model_from_checkpoint(model, checkpoint_path)

Load a PyTorch Lightning model from a checkpoint.

Parameters:
  • model (RobustModel) – RobustModel instance.

  • checkpoint_path (str) – Path to a checkpoint of model weights.

Returns:

Object with weights and config loaded from checkpoint.

Return type:

RobustModel
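
Typical usage, with an illustrative checkpoint path (for example, one written by a pytorch_lightning ModelCheckpoint callback):

   from evoaug import evoaug

   # Rebuild robust_model with the same architecture, then restore weights.
   robust_model = evoaug.load_model_from_checkpoint(
       robust_model,
       "lightning_logs/version_0/checkpoints/best.ckpt",  # illustrative path
   )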

evoaug.evoaug.augment_max_len(augment_list)

Determine whether insertion augmentations are applied and, if so, the insert_max that will be used to pad other sequences with random DNA.

Parameters:

augment_list (list) – List of augmentations.

Returns:

Value for insert_max.

Return type:

int
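
A small sketch of the expected behavior; the RandomInsertion class and its insert_min/insert_max keyword names are assumptions based on augment.py:

   from evoaug import evoaug, augment

   augs = [augment.RandomRC(rc_prob=0.5),
           augment.RandomInsertion(insert_min=0, insert_max=20)]
   evoaug.augment_max_len(augs)   # expected: 20 (pad other sequences by 20 nt)

   evoaug.augment_max_len([augment.RandomRC(rc_prob=0.5)])  # expected: 0 (no padding)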