1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
class VisCallback(tf.keras.callbacks.Callback):
    """Keras callback that visualizes predictions on a random sample each epoch.

    At the end of every epoch it draws ``n_samples`` randomly chosen inputs
    together with their predicted labels via ``display_digits``, captures the
    current matplotlib figure as an image array, and shows the figure on screen
    every ``display_freq`` epochs. When training ends, the collected frames are
    written out as an animated GIF.

    Args:
        inputs: Model inputs; must support indexing with an integer index
            array (presumably a NumPy array — confirm with callers).
        ground_truth: Labels aligned element-wise with ``inputs``.
        display_freq: Show the figure on screen every this many epochs
            (epochs 0, display_freq, 2*display_freq, ...).
        n_samples: Number of inputs sampled and drawn each epoch.
        gif_path: Destination of the GIF written at the end of training.
            ``None`` (the default) falls back to the module-level ``GIF_PATH``.
    """

    def __init__(self, inputs, ground_truth, display_freq=10, n_samples=10,
                 gif_path=None):
        # Let the Keras base class set up its own state (model/params
        # bookkeeping) instead of relying on Keras tolerating missing attrs.
        super().__init__()
        self.inputs = inputs
        self.ground_truth = ground_truth
        self.images = []  # one image array captured per finished epoch
        self.display_freq = display_freq
        self.n_samples = n_samples
        self.gif_path = gif_path

    def on_epoch_end(self, epoch, logs=None):
        """Sample inputs, plot their predictions and record the figure."""
        # np.random.choice samples with replacement by default, so a digit may
        # appear more than once in a given epoch's sample.
        indexes = np.random.choice(len(self.inputs), size=self.n_samples)
        X_test, y_test = self.inputs[indexes], self.ground_truth[indexes]
        # verbose=0 keeps predict() from printing a progress bar every epoch.
        predictions = np.argmax(self.model.predict(X_test, verbose=0), axis=1)

        # `n` must be the number of sampled digits. The original passed
        # display_freq here, which only worked while both values were 10.
        # NOTE(review): assumes display_digits' `n` is the digit count — confirm
        # against its definition.
        display_digits(X_test, predictions, y_test, epoch, n=self.n_samples)

        # Snapshot the current figure as a NumPy array. The context managers
        # release the in-memory buffer and the PIL image after the copy.
        with io.BytesIO() as buf:
            plt.savefig(buf, format='png')
            buf.seek(0)
            with Image.open(buf) as image:
                self.images.append(np.array(image))

        if epoch % self.display_freq == 0:
            plt.show()

    def on_train_end(self, logs=None):
        """Write the recorded frames as a GIF; skip if nothing was recorded."""
        if not self.images:
            # No epoch finished (e.g. training was interrupted early), so
            # there is nothing to animate; avoid writing an empty GIF.
            return
        path = GIF_PATH if self.gif_path is None else self.gif_path
        imageio.mimsave(path, self.images, fps=1)
|