Source code for pydelfi.train
import tensorflow as tf
import numpy as np
import numpy.random as rng
import os
from tqdm.auto import tqdm
class ConditionalTrainer:
    """
    Training class for the conditional MADE/MAF models, using a TensorFlow optimizer.
    """
def __init__(self, model, optimizer=tf.train.AdamOptimizer, optimizer_arguments={}):
"""
Constructor that defines the training operation.
        :param model: MADE/MAF instance to be trained.
        :param optimizer: tensorflow optimizer class to be used during training.
        :param optimizer_arguments: dictionary of arguments for optimizer initialization.
"""
self.model = model
self.train_optimizer = optimizer(**optimizer_arguments).minimize(self.model.trn_loss)
self.train_reg_optimizer = optimizer(**optimizer_arguments).minimize(self.model.reg_loss)
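
        # Both ops are built up front; train() selects which loss to minimize
        # per batch via its `mode` argument ('samples' -> trn_loss,
        # 'regression' -> reg_loss).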
"""
Training class for the conditional MADEs/MAFs classes using a tensorflow optimizer.
"""
    def train(self, sess, train_data, validation_split=0.1, epochs=1000, batch_size=100,
              patience=20, saver_name='tmp_model', progress_bar=True, mode='samples'):
"""
Training function to be called with desired parameters within a tensorflow session.
:param sess: tensorflow session where the graph is run.
        :param train_data: a tuple/list (X, Y) of training data where Y is conditioned on X,
                           or (X, Y, PDF) in 'regression' mode, with PDF the target log-densities.
        :param validation_split: fraction of the training data randomly held out for validation.
        :param epochs: maximum number of epochs for training.
        :param batch_size: batch size of each batch within an epoch.
        :param patience: number of epochs without validation improvement before early stopping.
        :param saver_name: name (with or without folder) under which the best model is saved;
                           if None, the best model is neither saved nor restored.
        :param progress_bar: whether to display a tqdm progress bar during training.
        :param mode: 'samples' to fit the conditional density to (X, Y) samples, or
                     'regression' to train against provided log-density values.
"""
# Training data
        if mode == 'samples':
            train_data_X, train_data_Y = train_data
        elif mode == 'regression':
            train_data_X, train_data_Y, train_data_PDF = train_data
        else:
            raise ValueError("mode must be 'samples' or 'regression', got {}".format(mode))
        # Hold out the last n_val shuffled indices (a validation_split fraction) for validation
        N = train_data_X.shape[0]
        train_idx = np.arange(N)
        rng.shuffle(train_idx)
        n_val = int(validation_split * N)
        val_data_X = train_data_X[train_idx[N - n_val:]]
        train_data_X = train_data_X[train_idx[:N - n_val]]
        val_data_Y = train_data_Y[train_idx[N - n_val:]]
        train_data_Y = train_data_Y[train_idx[:N - n_val]]
        if mode == 'regression':
            val_data_PDF = train_data_PDF[train_idx[N - n_val:]]
            train_data_PDF = train_data_PDF[train_idx[:N - n_val]]
        # Rebuild the index array for the (now smaller) training set
        train_idx = np.arange(train_data_X.shape[0])
# Early stopping variables
        bst_loss = np.inf
early_stopping_count = 0
saver = tf.train.Saver()
# Validation and training losses
validation_losses = []
training_losses = []
# Main training loop
if progress_bar:
pbar = tqdm(total = epochs, desc = "Training")
pbar.set_postfix(ordered_dict={"train loss":0, "val loss":0}, refresh=True)
for epoch in range(epochs):
            # Shuffle training indices
rng.shuffle(train_idx)
            # Loop over all batches; the final (possibly partial) batch keeps the remaining elements
            n_batches = int(np.ceil(len(train_idx) / batch_size))
            for batch in range(n_batches):
                batch_idx = train_idx[batch * batch_size:(batch + 1) * batch_size]
if mode == 'samples':
sess.run(self.train_optimizer,feed_dict={self.model.parameters:train_data_X[batch_idx],
self.model.data:train_data_Y[batch_idx]})
elif mode == 'regression':
sess.run(self.train_reg_optimizer,feed_dict={self.model.parameters:train_data_X[batch_idx],
self.model.data:train_data_Y[batch_idx],
self.model.logpdf:train_data_PDF[batch_idx]})
            # Evaluate epoch training/validation losses for the early-stopping check
if mode == 'samples':
val_loss = sess.run(self.model.trn_loss,feed_dict={self.model.parameters:val_data_X,
self.model.data:val_data_Y})
train_loss = sess.run(self.model.trn_loss,feed_dict={self.model.parameters:train_data_X,
self.model.data:train_data_Y})
elif mode == 'regression':
val_loss = sess.run(self.model.reg_loss,feed_dict={self.model.parameters:val_data_X,
self.model.data:val_data_Y,
self.model.logpdf:val_data_PDF})
train_loss = sess.run(self.model.reg_loss,feed_dict={self.model.parameters:train_data_X,
self.model.data:train_data_Y,
self.model.logpdf:train_data_PDF})
if progress_bar:
pbar.update()
pbar.set_postfix(ordered_dict={"train loss":train_loss, "val loss":val_loss}, refresh=True)
validation_losses.append(val_loss)
training_losses.append(train_loss)
if val_loss < bst_loss:
bst_loss = val_loss
if saver_name is not None:
                    saver.save(sess, "./" + saver_name)
early_stopping_count = 0
else:
early_stopping_count += 1
            if early_stopping_count >= patience:
                if progress_bar:
                    pbar.set_postfix_str("Early stopping: terminated", refresh=True)
                break
        # Close the progress bar and restore the best model
        if progress_bar:
            pbar.close()
        if saver_name is not None:
            saver.restore(sess, "./" + saver_name)
        return np.array(validation_losses), np.array(training_losses)
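
# Example usage (a minimal sketch, not part of pydelfi.train itself): `NDE`
# stands in for any conditional density estimator exposing the `trn_loss`,
# `parameters` and `data` tensors this trainer feeds, e.g. one of the
# pydelfi.ndes classes; `theta_train` and `data_train` are assumed to be
# (n_samples, dim) numpy arrays of parameters and corresponding data.
#
#   model = NDE(...)
#   trainer = ConditionalTrainer(model)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       val_losses, train_losses = trainer.train(
#           sess, (theta_train, data_train), validation_split=0.1,
#           epochs=500, batch_size=100, patience=20,
#           saver_name='tmp_model', mode='samples')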