Use callbacks to trigger the end of training cycles

Keras provides the EarlyStopping callback, which ends training when a monitored metric stops improving.

Jason Brownlee has a tutorial on the subject: Use Early Stopping to Halt the Training of Neural Networks At the Right Time.

Below is a very simple example of early stopping:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping


def generate_dataset():
    # Noisy samples of the line y = 1.5 * x + 0.5 on the interval [0, 2]
    x_batch = np.linspace(0, 2, 100)
    y_batch = 1.5 * x_batch + np.random.randn(*x_batch.shape) * 0.2 + 0.5
    return x_batch, y_batch


xs, ys = generate_dataset()
plt.figure()
plt.scatter(xs, ys)
plt.show()

# Stop once the training loss has gone 20 consecutive epochs without improving
es = EarlyStopping(
    monitor='loss', patience=20, verbose=1, mode='min',
)

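The monitored quantity is the training loss. With mode='min', the callback tracks the lowest loss seen so far and stops training once 20 consecutive epochs (patience) pass without a new minimum. In the log below, the lowest loss (0.0316) is reached at epoch 32, and training stops at epoch 52.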
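
To make the patience rule concrete, here is a minimal pure-Python sketch of it. The function name early_stopping_epoch and the toy loss values are hypothetical, chosen only for illustration; this is a simplified model of the behaviour described above, not Keras's actual implementation.

```python
def early_stopping_epoch(losses, patience, min_delta=0.0):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified illustration of a patience rule with mode='min';
    not Keras's actual implementation.
    """
    best = float("inf")
    wait = 0  # number of epochs since the last improvement
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - min_delta:
            # New best value: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical loss curve: best value at epoch 3, then no improvement
toy_losses = [1.0, 0.5, 0.4, 0.45, 0.41, 0.42]
stop_epoch = early_stopping_epoch(toy_losses, patience=3)
print(stop_epoch)  # prints 6: three epochs after the best epoch
```

With patience=4 the same curve would run to the end, and the function returns None.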

model = keras.Sequential([
    keras.layers.Dense(1, input_dim=1),
])

model.compile(loss="mean_squared_error", optimizer="sgd")
history = model.fit(xs, ys, epochs=500, callbacks=[es])

Training prints the loss at each epoch:

Epoch 1/500
4/4 [==============================] - 0s 2ms/step - loss: 1.6861
Epoch 2/500
4/4 [==============================] - 0s 2ms/step - loss: 1.0702
Epoch 3/500
4/4 [==============================] - 0s 3ms/step - loss: 0.7972
Epoch 4/500
4/4 [==============================] - 0s 3ms/step - loss: 0.5876
Epoch 5/500
4/4 [==============================] - 0s 3ms/step - loss: 0.4448
Epoch 6/500
4/4 [==============================] - 0s 2ms/step - loss: 0.2652
Epoch 7/500
4/4 [==============================] - 0s 3ms/step - loss: 0.2620
Epoch 8/500
4/4 [==============================] - 0s 3ms/step - loss: 0.1880
Epoch 9/500
4/4 [==============================] - 0s 2ms/step - loss: 0.1192
Epoch 10/500
4/4 [==============================] - 0s 4ms/step - loss: 0.1148
Epoch 11/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0819
Epoch 12/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0872
Epoch 13/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0594
Epoch 14/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0564
Epoch 15/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0516
Epoch 16/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0796
Epoch 17/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0438
Epoch 18/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0409
Epoch 19/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0374
Epoch 20/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0355
Epoch 21/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0460
Epoch 22/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0335
Epoch 23/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0360
Epoch 24/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0367
Epoch 25/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0329
Epoch 26/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0400
Epoch 27/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0330
Epoch 28/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0377
Epoch 29/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0412
Epoch 30/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0319
Epoch 31/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0365
Epoch 32/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0316
Epoch 33/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0356
Epoch 34/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0374
Epoch 35/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0396
Epoch 36/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0365
Epoch 37/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0323
Epoch 38/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0333
Epoch 39/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0331
Epoch 40/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0320
Epoch 41/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0513
Epoch 42/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0352
Epoch 43/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0317
Epoch 44/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0330
Epoch 45/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0501
Epoch 46/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0386
Epoch 47/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0432
Epoch 48/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0414
Epoch 49/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0497
Epoch 50/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0366
Epoch 51/500
4/4 [==============================] - 0s 2ms/step - loss: 0.0347
Epoch 52/500
4/4 [==============================] - 0s 3ms/step - loss: 0.0416
Epoch 00052: early stopping
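
As a quick check of the stopping point, the snippet below looks for the best epoch in the losses copied from the log above. The values are the rounded figures printed by Keras, so this is only a sketch; in a live run, history.history["loss"] would hold the unrounded values. The variable names are illustrative.

```python
# Per-epoch training losses copied from the log above (rounded to 4 decimals)
logged_losses = [
    1.6861, 1.0702, 0.7972, 0.5876, 0.4448, 0.2652, 0.2620, 0.1880,
    0.1192, 0.1148, 0.0819, 0.0872, 0.0594, 0.0564, 0.0516, 0.0796,
    0.0438, 0.0409, 0.0374, 0.0355, 0.0460, 0.0335, 0.0360, 0.0367,
    0.0329, 0.0400, 0.0330, 0.0377, 0.0412, 0.0319, 0.0365, 0.0316,
    0.0356, 0.0374, 0.0396, 0.0365, 0.0323, 0.0333, 0.0331, 0.0320,
    0.0513, 0.0352, 0.0317, 0.0330, 0.0501, 0.0386, 0.0432, 0.0414,
    0.0497, 0.0366, 0.0347, 0.0416,
]
patience = 20  # same value as passed to EarlyStopping

best_loss = min(logged_losses)
best_epoch = logged_losses.index(best_loss) + 1  # epochs are 1-based in the log
epochs_since_best = len(logged_losses) - best_epoch

print(best_epoch, best_loss, epochs_since_best)  # prints: 32 0.0316 20
```

The best loss occurs at epoch 32, and the 20 epochs that follow never beat it, which matches the stop at epoch 52.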

The code for this tutorial is here.