Aller à la barre d’outils

Use models to predict results

Prédire. C’est la raison principale, voire la seule, pour laquelle on utilise des réseaux de neurones.

Régression ou Classification. Régression lorsqu’il s’agit de prédire un nombre, classification lorsqu’il s’agit de prédire une classe.

Reprenons notre exemple précédent avec MNIST.

import tensorflow as tf
print (tf.__version__)

mnist = tf.keras.datasets.mnist
1
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

Avec ce modèle non entraîné on peut déjà faire des prédictions, qui seront bien sûr erronées.

img = x_train[1]
predictions = model(img).numpy()
array([[-0.477541  ,  0.21469425,  0.13032357,  0.3327995 , -0.02744773,
         0.06294087,  0.49568832, -0.04907865,  0.04224914,  0.17364587]],
      dtype=float32)

puis avec un softmax

print(tf.nn.softmax(predictions).numpy())
array([[0.0550979 , 0.11009536, 0.1011876 , 0.12389719, 0.08641876,
        0.09459395, 0.14581533, 0.08456952, 0.09265675, 0.10566762]],
      dtype=float32)

Si on entraîne le modèle.

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

et on recalcule la prédiction (avec un softmax)

predictions = model(x_train[:1]).numpy()
print(tf.nn.softmax(predictions).numpy())
array([[1.1987252e-11, 1.6072766e-09, 6.1684972e-08, 2.5967252e-03,
        1.3634578e-18, 9.9740309e-01, 1.5508568e-12, 6.4643721e-11,
        6.9358935e-11, 3.5405140e-08]], dtype=float32)

Pour évaluer la performance du modèle, sur un jeu de données de test, on utilisera la méthode evaluate.

model.evaluate(x_test,  y_test, verbose=2)

Cette méthode retourne le coût et le résultat pour la métrique choisie (accuracy).

10000/1 - 1s - loss: 0.0450 - accuracy: 0.9792
[0.07106049620504491, 0.9792]