Ce tutoriel de classification d’iris, à l’aide de TensorFlow 2.x est le suivant du #3.
On termine en évaluant le modèle
Le problème
Classer des iris, en 3 catégories (Iris setosa, Iris virginica et Iris versicolor), à partir des dimensions (largeur et longueur) des sépales et pétales est un problème classique du Machine Learning. Puisqu’on peut le traiter facilement avec les outils de la statistique, pourquoi ne pas essayer de le faire, de façon plus compliquée avec un réseau de neurones ! Tout cela, bien sûr, dans un but pédagogique.
Références
- Custom training: walkthrough
- Premade Estimators
- TensorFlow – tutoriel #1
- Tutorial Classification
- Scatter Matrices using pandas
- How to use Pandas Scatter Matrix
- Dataset 5 fleurs
- How does the Softmax activation function work?
- TensorFlow 2 Tutorial: Get Started in Deep Learning With tf.keras
- Get started with TensorBoard
Code
Le notebook Jupyter, en Python, dans l’environnement GCP est disponible ici.
Gist
Evaluation du modèle
On travaille sur le test set et on répète ce que nous avons fait précédemment.
test_url = "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv"
test_fp = tf.keras.utils.get_file(fname=os.path.basename(test_url),
origin=test_url)
test_dataset = tf.data.experimental.make_csv_dataset(
test_fp,
batch_size,
column_names=column_names,
label_name='species',
num_epochs=1,
shuffle=False)
test_dataset = test_dataset.map(pack_features_vector)
test_accuracy = tf.keras.metrics.Accuracy()
for (x, y) in test_dataset:
# training=False is needed only if there are layers with different
# behavior during training versus inference (e.g. Dropout).
logits = model(x, training=False)
prediction = tf.argmax(logits, axis=1, output_type=tf.int32)
test_accuracy(prediction, y)
print("Test set accuracy: {:.3%}".format(test_accuracy.result()))
Test set accuracy: 96.667%
On peut regarder les prédictions sur le dernier batch. Il n’y a qu’une seule erreur.
tf.stack([y,prediction],axis=1)
<tf.Tensor: shape=(30, 2), dtype=int32, numpy=
array([[1, 1],
[2, 2],
[0, 0],
[1, 1],
[1, 1],
[1, 1],
[0, 0],
[2, 2],
[1, 1],
[2, 2],
[2, 2],
[0, 0],
[2, 2],
[1, 1],
[1, 1],
[0, 0],
[1, 1],
[0, 0],
[0, 0],
[2, 2],
[0, 0],
[1, 1],
[2, 2],
[1, 2],
[1, 1],
[1, 1],
[0, 0],
[1, 1],
[2, 2],
[1, 1]], dtype=int32)>