Module snapshot_ensemble.utils
Expand source code
from snapshot_ensemble import *
import matplotlib.pyplot as plt
def VisualizeLR(epochs=500, cycle_length=10, cycle_length_multiplier=1.5, lr_init=0.01, lr_min=1e-6, lr_multiplier=0.9):
    """
    Helper function for visualizing the cosine annealed learning rate schedule.

    Note: The implementation in `SnapshotEnsembleCallback` allows for smoother
    batch-level decay of learning rates, while this function returns a simplified
    epoch-level decay for visualization purposes.
    """
    lr_max = lr_init
    last_restart = 0
    schedule = []
    for epoch in range(epochs):
        # A new cycle begins one epoch after the previous cycle's final epoch
        # (the +1 lets the schedule touch lr_min before restarting).
        if epoch == last_restart + cycle_length + 1:
            cycle_length = math.ceil(cycle_length * cycle_length_multiplier)
            lr_max *= lr_multiplier
            lr_min *= lr_multiplier
            last_restart = epoch
        # Cosine-annealed learning rate within the current cycle.
        progress = (epoch - last_restart) / cycle_length
        schedule.append(float(
            lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * progress))
        ))
    fig = plt.figure()
    plt.title('Cosine Annealed Learning Rate Schedule')
    plt.xlabel('Epochs')
    plt.ylabel('Learning Rate')
    _ = plt.plot(schedule)
    plt.close()
    return fig
def GenerateSnapshotCallbacks(cycle_length=10, cycle_length_multiplier=1.5, lr_init=0.01, lr_min=1e-6, lr_multiplier=0.9,
                              ensemble=True, ensemble_options=None):
    """
    Helper function for generating a list of Keras callbacks to be used during training.

    Includes the `SnapshotEnsembleCallback` for cosine annealing and saving an ensemble
    of models, as well as `ModelCheckpoint` for saving the best model.

    `ensemble_options` defaults to an empty dict; a `None` sentinel is used instead of
    a literal `{}` default to avoid sharing one mutable dict across calls.
    """
    if ensemble_options is None:
        ensemble_options = {}
    snapEns = SnapshotEnsembleCallback(
        cycle_length=cycle_length,
        cycle_length_multiplier=cycle_length_multiplier,
        lr_init=lr_init,
        lr_min=lr_min,
        lr_multiplier=lr_multiplier,
        ensemble=ensemble,
        ensemble_options=ensemble_options
    )
    callbacks = [
        snapEns,
        # Checkpoint tracks the single best model (by val_loss) alongside the
        # per-cycle snapshots saved by SnapshotEnsembleCallback.
        tfk.callbacks.ModelCheckpoint(
            os.path.join(
                snapEns.ensembleConfig.get('dirpath'),
                f"{snapEns.ensembleConfig.get('model_prefix')}-Best.h5"
            ),
            monitor='val_loss',
            mode='min',
            save_best_only=True,
            save_weights_only=True
        ),
    ]
    return callbacks
Functions
def GenerateSnapshotCallbacks(cycle_length=10, cycle_length_multiplier=1.5, lr_init=0.01, lr_min=1e-06, lr_multiplier=0.9, ensemble=True, ensemble_options={})
-
Helper function for generating a list of Keras callbacks to be used during training.
Includes the
SnapshotEnsembleCallback
for cosine annealing and saving an ensemble of models, as well as `ModelCheckpoint`
for saving the best model. Expand source code
def GenerateSnapshotCallbacks(cycle_length=10, cycle_length_multiplier=1.5, lr_init=0.01, lr_min=1e-6, lr_multiplier=0.9, ensemble=True, ensemble_options={}): """ Helper function for generating a list of Keras callbacks to be used during training. Includes the `SnapshotEnsembleCallback` for cosine annealing and saving an ensemble of models, as well as `ModelCheckpoint` for saving the best model. """ snapEns = SnapshotEnsembleCallback( cycle_length=cycle_length, cycle_length_multiplier=cycle_length_multiplier, lr_init=lr_init, lr_min=lr_min, lr_multiplier=lr_multiplier, ensemble=ensemble, ensemble_options=ensemble_options ) callbacks = [ snapEns, tfk.callbacks.ModelCheckpoint( os.path.join( snapEns.ensembleConfig.get('dirpath'), f"{snapEns.ensembleConfig.get('model_prefix')}-Best.h5" ), monitor='val_loss', mode='min', save_best_only=True, save_weights_only=True ), ] return callbacks
def VisualizeLR(epochs=500, cycle_length=10, cycle_length_multiplier=1.5, lr_init=0.01, lr_min=1e-06, lr_multiplier=0.9)
-
Helper function for visualizing the cosine annealed learning rate schedule.
Note: The implementation in
SnapshotEnsembleCallback
allows for smoother batch-level decay of learning rates, while this function returns a simplified epoch-level decay for visualization purposes. Expand source code
def VisualizeLR(epochs=500, cycle_length=10, cycle_length_multiplier=1.5, lr_init=0.01, lr_min=1e-6, lr_multiplier=0.9): """ Helper function for visualizing the cosine annealed learning rate schedule. Note: The implementation in `SnapshotEnsembleCallback` allows for smoother batch-level decay of learning rates, while this function returns a simplified epoch-level decay for visualization purposes. """ lr_max = lr_init prevRestartEpoch = 0 res = [] for t in range(epochs): # Update states at start of new cycle if t == (prevRestartEpoch + cycle_length + 1): cycle_length = math.ceil(cycle_length * cycle_length_multiplier) lr_max = lr_max * lr_multiplier lr_min = lr_min * lr_multiplier prevRestartEpoch = t # Cosine annealed learning rate epochs_since_restart = t - prevRestartEpoch lr = float( lr_min + 0.5 * (lr_max - lr_min) * ( 1 + np.cos(np.pi * (epochs_since_restart / cycle_length)) ) ) res.append( lr ) fig = plt.figure() plt.title('Cosine Annealed Learning Rate Schedule') plt.xlabel('Epochs') plt.ylabel('Learning Rate') _ = plt.plot(res) plt.close() return fig