Aller à la barre d’outils

Classification d’iris – #4

Ce tutoriel de classification d’iris, à l’aide de TensorFlow 2.x est le suivant du #3.

On termine en évaluant le modèle

Le problèmeRéférencesCodeGist

Le problème

Classer des iris, en 3 catégories (Iris setosa, Iris virginica et Iris versicolor), à partir des dimensions (largeur et longueur) des sépales et pétales est un problème classique du Machine Learning. Puisqu’on peut le traiter facilement avec les outils de la statistique, pourquoi ne pas essayer de le faire, de façon plus compliquée avec un réseau de neurones ! Tout cela, bien sûr, dans un but pédagogique.

Références

Code

Le notebook Jupyter, en Python, dans l’environnement GCP est disponible ici.

Gist

Evaluation du modèle

On travaille sur le test set et on répète ce que nous avons fait précédemment.

test_url = "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv"
test_fp = tf.keras.utils.get_file(fname=os.path.basename(test_url),
                                  origin=test_url)
test_dataset = tf.data.experimental.make_csv_dataset(
    test_fp,
    batch_size,
    column_names=column_names,
    label_name='species',
    num_epochs=1,
    shuffle=False)

test_dataset = test_dataset.map(pack_features_vector)
test_accuracy = tf.keras.metrics.Accuracy()

for (x, y) in test_dataset:
  # training=False is needed only if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  logits = model(x, training=False)
  prediction = tf.argmax(logits, axis=1, output_type=tf.int32)
  test_accuracy(prediction, y)

print("Test set accuracy: {:.3%}".format(test_accuracy.result()))
Test set accuracy: 96.667%

On peut regarder les prédictions sur le dernier batch. Il n’y a qu’une seule erreur.

tf.stack([y,prediction],axis=1)
<tf.Tensor: shape=(30, 2), dtype=int32, numpy=
array([[1, 1],
       [2, 2],
       [0, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [0, 0],
       [2, 2],
       [1, 1],
       [2, 2],
       [2, 2],
       [0, 0],
       [2, 2],
       [1, 1],
       [1, 1],
       [0, 0],
       [1, 1],
       [0, 0],
       [0, 0],
       [2, 2],
       [0, 0],
       [1, 1],
       [2, 2],
       [1, 2],
       [1, 1],
       [1, 1],
       [0, 0],
       [1, 1],
       [2, 2],
       [1, 1]], dtype=int32)>