Esercitazione: Eseguire un modello TensorFlow in Python

Questa esercitazione illustra come usare un modello TensorFlow esportato in locale per classificare le immagini.

Nota

Questa esercitazione si applica solo ai modelli esportati da progetti di classificazione delle immagini "Generale (compatta)". Se sono stati esportati altri modelli, visitare il repository di codice di esempio.

Prerequisiti

  • Installare Python 2.7+ o Python 3.6+.
  • Installare pip.

Sarà quindi necessario installare anche i pacchetti seguenti:

pip install tensorflow
pip install pillow
pip install numpy
pip install opencv-python

Caricare il modello e i tag

Il file .zip scaricato dal passaggio di esportazione contiene un file model.pb e un file labels.txt. Questi file rappresentano il modello sottoposto a training e le etichette di classificazione. Il primo passaggio consiste nel caricare il modello nel progetto. Aggiungere il codice seguente in un nuovo script Python.

import tensorflow as tf
import os

graph_def = tf.compat.v1.GraphDef()
labels = []

# These are set to the default names from exported models, update as needed.
filename = "model.pb"
labels_filename = "labels.txt"

# Import the TF graph
with tf.io.gfile.GFile(filename, 'rb') as f:
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

# Create a list of labels.
with open(labels_filename, 'rt') as lf:
    for l in lf:
        labels.append(l.strip())

Preparare un'immagine per la stima

Per preparare un'immagine per la stima, è necessario eseguire alcuni passaggi. Questi passaggi simulano la manipolazione dell'immagine eseguita durante il training.

  1. Aprire il file e creare un'immagine nello spazio colore BGR

    from PIL import Image
    import numpy as np
    import cv2
    
    # Load from a file
    imageFile = "<path to your image file>"
    image = Image.open(imageFile)
    
    # Update orientation based on EXIF tags, if the file has orientation info.
    image = update_orientation(image)
    
    # Convert to OpenCV format
    image = convert_to_opencv(image)
    
  2. Se l'immagine ha una dimensione maggiore di 1600 pixel, chiamare questo metodo (definito in un secondo momento).

    image = resize_down_to_1600_max_dim(image)
    
  3. Ritagliare il riquadro centrale più grande

    h, w = image.shape[:2]
    min_dim = min(w,h)
    max_square_image = crop_center(image, min_dim, min_dim)
    
  4. Ridimensionare il quadrato fino a 256x256

    augmented_image = resize_to_256_square(max_square_image)
    
  5. Ritagliare il centro per le dimensioni di input specifiche per il modello

    # Get the input size of the model
    with tf.compat.v1.Session() as sess:
        input_tensor_shape = sess.graph.get_tensor_by_name('Placeholder:0').shape.as_list()
    network_input_size = input_tensor_shape[1]
    
    # Crop the center for the specified network_input_Size
    augmented_image = crop_center(augmented_image, network_input_size, network_input_size)
    
    
  6. Definire le funzioni helper. I passaggi precedenti usano le funzioni helper seguenti:

    def convert_to_opencv(image):
        # RGB -> BGR conversion is performed as well.
        image = image.convert('RGB')
        r,g,b = np.array(image).T
        opencv_image = np.array([b,g,r]).transpose()
        return opencv_image
    
    def crop_center(img,cropx,cropy):
        h, w = img.shape[:2]
        startx = w//2-(cropx//2)
        starty = h//2-(cropy//2)
        return img[starty:starty+cropy, startx:startx+cropx]
    
    def resize_down_to_1600_max_dim(image):
        h, w = image.shape[:2]
        if (h < 1600 and w < 1600):
            return image
    
        new_size = (1600 * w // h, 1600) if (h > w) else (1600, 1600 * h // w)
        return cv2.resize(image, new_size, interpolation = cv2.INTER_LINEAR)
    
    def resize_to_256_square(image):
        h, w = image.shape[:2]
        return cv2.resize(image, (256, 256), interpolation = cv2.INTER_LINEAR)
    
    def update_orientation(image):
        exif_orientation_tag = 0x0112
        if hasattr(image, '_getexif'):
            exif = image._getexif()
            if (exif != None and exif_orientation_tag in exif):
                orientation = exif.get(exif_orientation_tag, 1)
                # orientation is 1 based, shift to zero based and flip/transpose based on 0-based values
                orientation -= 1
                if orientation >= 4:
                    image = image.transpose(Image.TRANSPOSE)
                if orientation == 2 or orientation == 3 or orientation == 6 or orientation == 7:
                    image = image.transpose(Image.FLIP_TOP_BOTTOM)
                if orientation == 1 or orientation == 2 or orientation == 5 or orientation == 6:
                    image = image.transpose(Image.FLIP_LEFT_RIGHT)
        return image
    

Classificare un'immagine

Dopo aver preparato l'immagine come tensore, è possibile inviarla tramite il modello per una stima.

# These names are part of the model and cannot be changed.
output_layer = 'loss:0'
input_node = 'Placeholder:0'

with tf.compat.v1.Session() as sess:
    try:
        prob_tensor = sess.graph.get_tensor_by_name(output_layer)
        predictions = sess.run(prob_tensor, {input_node: [augmented_image] })
    except KeyError:
        print ("Couldn't find classification output layer: " + output_layer + ".")
        print ("Verify this a model exported from an Object Detection project.")
        exit(-1)

Visualizzare i risultati

Sarà quindi necessario mappare nuovamente alle etichette i risultati dell'esecuzione del tensore dell'immagine nel modello.

    # Print the highest probability label
    highest_probability_index = np.argmax(predictions)
    print('Classified as: ' + labels[highest_probability_index])
    print()

    # Or you can print out all of the results mapping labels to probabilities.
    label_index = 0
    for p in predictions:
        truncated_probablity = np.float64(np.round(p,8))
        print (labels[label_index], truncated_probablity)
        label_index += 1

Passaggi successivi

Informazioni su come inserire il modello in un'applicazione per dispositivi mobili: