TensorFlow Advanced: Part 3

Callbacks

Common methods for training/testing/predicting

class Callback(object):
    ...

    ## Called at the beginning of fit/evaluate/predict
    def on_(train|test|predict)_begin(self, logs=None):

    ## Called at the end of fit/evaluate/predict
    def on_(train|test|predict)_end(self, logs=None):

    ## Called right before a batch is processed during training/testing/predicting
    def on_(train|test|predict)_batch_begin(self, batch, logs=None):

    ## Called right after a batch is processed during training/testing/predicting
    def on_(train|test|predict)_batch_end(self, batch, logs=None):
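The base class also defines per-epoch hooks, which fire during training only:

    ## Called at the start/end of every epoch (training only)
    def on_epoch_begin(self, epoch, logs=None):

    def on_epoch_end(self, epoch, logs=None):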

Where can you use them? All of the following accept a callbacks list (a short usage sketch follows this list):

  • fit(..., callbacks=[])
  • fit_generator(..., callbacks=[])
  • evaluate(..., callbacks=[])
  • evaluate_generator(..., callbacks=[])
  • predict(..., callbacks=[])
  • predict_generator(..., callbacks=[])
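Note that the *_generator variants are deprecated in TensorFlow 2.x, since fit/evaluate/predict accept generators and tf.data datasets directly. As a minimal sketch (assuming the model, batches, and a callback instance already exist), the same callback can be reused across all three:

model.fit(train_batches, epochs=5, callbacks=[tensorboard_callback])
model.evaluate(validation_batches, callbacks=[tensorboard_callback])
model.predict(validation_batches, callbacks=[tensorboard_callback])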

TensorBoard

import os
import datetime
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, LearningRateScheduler, ModelCheckpoint, CSVLogger, ReduceLROnPlateau
%load_ext tensorboard

!rm -rf logs

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir)

model.fit(train_batches,
          epochs=10,
          validation_data=validation_batches,
          callbacks=[tensorboard_callback])

%tensorboard --logdir logs
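If the notebook magic is unavailable, the same dashboard can be launched from a terminal with the tensorboard CLI:

tensorboard --logdir logs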

Checkpoints

# save the full model after each epoch
model.fit(train_batches,
          epochs=1,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[ModelCheckpoint('model.h5', verbose=1)])
# save only the weights
model.fit(train_batches,
          epochs=1,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[ModelCheckpoint('model.h5', save_weights_only=True, verbose=1)])
# save the weights only when the monitored metric improves (save_best_only)
model.fit(train_batches,
          epochs=1,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[ModelCheckpoint('model.h5', monitor='val_loss', save_best_only=True,
                                     save_weights_only=True, verbose=1)])
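To resume from a weights-only checkpoint, rebuild the same architecture and load into it (a sketch using the build_model helper from the later sections):

model = build_model(dense_units=256)
model.load_weights('model.h5')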
# embed the epoch number and current val_loss in each checkpoint filename
model.fit(train_batches,
          epochs=5,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.h5', verbose=1)])
# Standard TF way: a path without a .h5 extension is saved in the TF SavedModel format (a directory)
model.fit(train_batches,
          epochs=1,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[ModelCheckpoint('saved_model', verbose=1)])
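The SavedModel directory can later be reloaded in full, architecture and all:

loaded_model = tf.keras.models.load_model('saved_model')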

Early Stopping

model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[EarlyStopping(
              patience=3,
              monitor='val_loss',
              restore_best_weights=True,  # restore the weights from the best epoch
              verbose=1)])
model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[EarlyStopping(
              patience=3,
              min_delta=0.05,
              baseline=0.8,
              mode='min',  # for a loss we want to minimize
              monitor='val_loss',
              restore_best_weights=True,
              verbose=1)])

CSV Logger

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

csv_file = 'training.csv'

model.fit(train_batches,
          epochs=5,
          validation_data=validation_batches,
          callbacks=[CSVLogger(csv_file)])
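CSVLogger writes one row per epoch containing the loss and every compiled metric, which makes post-hoc inspection easy (a quick sketch with pandas):

import pandas as pd

history = pd.read_csv(csv_file)
print(history.head())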

Learning Rate Scheduler

import math

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

def step_decay(epoch):
    # halve the learning rate every `epochs_drop` epochs
    initial_lr = 0.01
    drop = 0.5
    epochs_drop = 1
    lr = initial_lr * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
    return lr

model.fit(train_batches,
          epochs=5,
          validation_data=validation_batches,
          callbacks=[LearningRateScheduler(step_decay, verbose=1),
                     TensorBoard(log_dir='./log_dir')])
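With epochs_drop=1 the schedule simply halves the rate once per epoch; printing the first few values makes the decay concrete:

for epoch in range(3):
    print('epoch {}: lr = {:.6f}'.format(epoch, step_decay(epoch)))
# epoch 0: lr = 0.005000
# epoch 1: lr = 0.002500
# epoch 2: lr = 0.001250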

%tensorboard --logdir log_dir

Reduce LR on Plateau

model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          callbacks=[ReduceLROnPlateau(monitor='val_loss',
                                       factor=0.2, verbose=1,
                                       patience=1, min_lr=0.001),
                     TensorBoard(log_dir='./log_dir')])
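Each time val_loss stalls for patience epochs, the learning rate is multiplied by factor, never going below min_lr: starting from SGD's default rate of 0.01, the first reduction gives 0.002 and the next is clipped to the min_lr of 0.001.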

Custom Callbacks

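The snippet below assumes a MyCustomCallback class has been defined; the original does not show one, so here is a minimal hypothetical sketch that simply timestamps each training batch (any of the hooks from the Callback skeleton above could be overridden instead):

import datetime

class MyCustomCallback(tf.keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        print('Batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

    def on_train_batch_end(self, batch, logs=None):
        print('Batch {} ends at {}'.format(batch, datetime.datetime.now().time()))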
my_custom_callback = MyCustomCallback()

model.fit(x_train, y_train,
          batch_size=64,
          epochs=1,
          verbose=0,
          callbacks=[my_custom_callback])

Detect Overfitting

class DetectOverfittingCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold=0.7):
        super(DetectOverfittingCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # a val/train loss ratio above the threshold suggests overfitting
        ratio = logs["val_loss"] / logs["loss"]
        print("Epoch: {}, Val/Train loss ratio: {:.2f}".format(epoch, ratio))

        if ratio > self.threshold:
            print("Stopping training...")
            self.model.stop_training = True

model = get_model()
_ = model.fit(x_train, y_train,
              validation_data=(x_test, y_test),
              batch_size=64,
              epochs=3,
              verbose=0,
              callbacks=[DetectOverfittingCallback()])
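Note that setting self.model.stop_training = True does not abort mid-epoch; Keras checks the flag at the end of each epoch and then halts training.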

Visualization Callback

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
# Visualization utilities
import numpy as np
import matplotlib.pyplot as plt

plt.rc('font', size=20)
plt.rc('figure', figsize=(15, 3))

def display_digits(inputs, outputs, ground_truth, epoch, n=10):
    # show n digits in a row; the x tick labels are the predictions,
    # colored green when they match the ground truth and red otherwise
    plt.clf()

    plt.yticks([])
    plt.grid(None)
    inputs = np.reshape(inputs, [n, 28, 28])
    inputs = np.swapaxes(inputs, 0, 1)
    inputs = np.reshape(inputs, [28, 28 * n])
    plt.imshow(inputs)
    plt.xticks([28 * x + 14 for x in range(n)], outputs)
    for i, t in enumerate(plt.gca().xaxis.get_ticklabels()):
        if outputs[i] == ground_truth[i]:
            t.set_color('green')
        else:
            t.set_color('red')
    plt.grid(None)
GIF_PATH = './animation.gif'
import io

import imageio
from PIL import Image

class VisCallback(tf.keras.callbacks.Callback):
    def __init__(self, inputs, ground_truth, display_freq=10, n_samples=10):
        super(VisCallback, self).__init__()
        self.inputs = inputs
        self.ground_truth = ground_truth
        self.images = []
        self.display_freq = display_freq
        self.n_samples = n_samples

    def on_epoch_end(self, epoch, logs=None):
        # Randomly sample data
        indexes = np.random.choice(len(self.inputs), size=self.n_samples)
        X_test, y_test = self.inputs[indexes], self.ground_truth[indexes]
        predictions = np.argmax(self.model.predict(X_test), axis=1)

        # Plot the digits
        display_digits(X_test, predictions, y_test, epoch, n=self.n_samples)

        # Save the figure as a frame for the final GIF
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        image = Image.open(buf)
        self.images.append(np.array(image))

        # Display the digits every `display_freq` epochs
        if epoch % self.display_freq == 0:
            plt.show()

    def on_train_end(self, logs=None):
        imageio.mimsave(GIF_PATH, self.images, fps=1)
def get_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(32, activation='linear', input_dim=784))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))
    model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-4),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

model = get_model()
model.fit(x_train, y_train,
          batch_size=64,
          epochs=20,
          verbose=0,
          callbacks=[VisCallback(x_test, y_test)])
from IPython.display import Image as IPyImage

SCALE = 60

# FYI, the format is set to PNG here to bypass checks for acceptable embeddings
IPyImage(GIF_PATH, format='png', width=15 * SCALE, height=3 * SCALE)