Share via


Tutorial: Ausführen eines TensorFlow-Modells in Python

In diesem Tutorial erfahren Sie, wie Sie ein exportiertes TensorFlow-Modell lokal zum Klassifizieren von Bildern verwenden.

Hinweis

Dieses Tutorial betrifft nur Modelle, die aus Bildklassifizierungsprojekten mit der Domäne „Allgemein (kompakt)“ exportiert wurden. Wenn Sie andere Modelle exportiert haben, finden Sie weitere Informationen im Beispielcoderepository.

Voraussetzungen

  • Installieren Sie entweder Python 2.7+ oder Python 3.6+.
  • Installieren Sie pip.

Anschließend müssen folgende Pakete installiert werden:

pip install tensorflow
pip install pillow
pip install numpy
pip install opencv-python

Laden Ihrer Modelle und Tags

Die heruntergeladene ZIP-Datei vom Exportschritt enthält eine Datei mit dem Namen model.pb und eine Datei mit dem Namen labels.txt. Diese Dateien stellen das trainierte Modell und die Klassifizierungsbezeichnungen dar. Laden Sie als ersten Schritt das Modell in Ihr Projekt. Fügen Sie den folgenden Code einem neuen Python-Skript hinzu:

import tensorflow as tf
import os

graph_def = tf.compat.v1.GraphDef()
labels = []

# These are set to the default names from exported models, update as needed.
filename = "model.pb"
labels_filename = "labels.txt"

# Import the TF graph
with tf.io.gfile.GFile(filename, 'rb') as f:
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

# Create a list of labels.
with open(labels_filename, 'rt') as lf:
    for l in lf:
        labels.append(l.strip())

Vorbereiten eines Bilds auf die Vorhersage

Es müssen ein paar Schritte ausgeführt werden, um ein Bild für die Vorhersage vorzubereiten. In diesen Schritten wird die Bildmanipulation imitiert, die während des Trainings ausgeführt wird.

  1. Öffnen der Datei und Erstellen eines Bilds im BGR-Farbraum

    from PIL import Image
    import numpy as np
    import cv2
    
    # Load from a file
    imageFile = "<path to your image file>"
    image = Image.open(imageFile)
    
    # Update orientation based on EXIF tags, if the file has orientation info.
    image = update_orientation(image)
    
    # Convert to OpenCV format
    image = convert_to_opencv(image)
    
  2. Wenn das Bild eine Abmessung über 1.600 Pixel hat, rufen Sie diese Methode auf (weiter unten definiert).

    image = resize_down_to_1600_max_dim(image)
    
  3. Zuschneiden des größten Quadrats in der Mitte

    h, w = image.shape[:2]
    min_dim = min(w,h)
    max_square_image = crop_center(image, min_dim, min_dim)
    
  4. Ändern Sie die Größe dieses Quadrats auf 256 × 256.

    augmented_image = resize_to_256_square(max_square_image)
    
  5. Zuschneiden des mittleren Bereichs auf die genaue Eingabegröße für das Modell

    # Get the input size of the model
    with tf.compat.v1.Session() as sess:
        input_tensor_shape = sess.graph.get_tensor_by_name('Placeholder:0').shape.as_list()
    network_input_size = input_tensor_shape[1]
    
    # Crop the center for the specified network_input_Size
    augmented_image = crop_center(augmented_image, network_input_size, network_input_size)
    
    
  6. Definieren Sie Hilfsfunktionen. In den obenstehenden Schritten werden die folgenden Hilfsfunktionen verwendet:

    def convert_to_opencv(image):
        # RGB -> BGR conversion is performed as well.
        image = image.convert('RGB')
        r,g,b = np.array(image).T
        opencv_image = np.array([b,g,r]).transpose()
        return opencv_image
    
    def crop_center(img,cropx,cropy):
        h, w = img.shape[:2]
        startx = w//2-(cropx//2)
        starty = h//2-(cropy//2)
        return img[starty:starty+cropy, startx:startx+cropx]
    
    def resize_down_to_1600_max_dim(image):
        h, w = image.shape[:2]
        if (h < 1600 and w < 1600):
            return image
    
        new_size = (1600 * w // h, 1600) if (h > w) else (1600, 1600 * h // w)
        return cv2.resize(image, new_size, interpolation = cv2.INTER_LINEAR)
    
    def resize_to_256_square(image):
        h, w = image.shape[:2]
        return cv2.resize(image, (256, 256), interpolation = cv2.INTER_LINEAR)
    
    def update_orientation(image):
        exif_orientation_tag = 0x0112
        if hasattr(image, '_getexif'):
            exif = image._getexif()
            if (exif != None and exif_orientation_tag in exif):
                orientation = exif.get(exif_orientation_tag, 1)
                # orientation is 1 based, shift to zero based and flip/transpose based on 0-based values
                orientation -= 1
                if orientation >= 4:
                    image = image.transpose(Image.TRANSPOSE)
                if orientation == 2 or orientation == 3 or orientation == 6 or orientation == 7:
                    image = image.transpose(Image.FLIP_TOP_BOTTOM)
                if orientation == 1 or orientation == 2 or orientation == 5 or orientation == 6:
                    image = image.transpose(Image.FLIP_LEFT_RIGHT)
        return image
    

Klassifizieren eines Bilds

Sobald das Bild als Tensor vorbereitet ist, kann es über das Modell für eine Vorhersage gesendet werden.

# These names are part of the model and cannot be changed.
output_layer = 'loss:0'
input_node = 'Placeholder:0'

with tf.compat.v1.Session() as sess:
    try:
        prob_tensor = sess.graph.get_tensor_by_name(output_layer)
        predictions = sess.run(prob_tensor, {input_node: [augmented_image] })
    except KeyError:
        print ("Couldn't find classification output layer: " + output_layer + ".")
        print ("Verify this a model exported from an Object Detection project.")
        exit(-1)

Anzeigen der Ergebnisse

Die Ergebnisse der Ausführung des Bildtensors über das Modell muss dann wieder den Bezeichnungen zugeordnet werden.

    # Print the highest probability label
    highest_probability_index = np.argmax(predictions)
    print('Classified as: ' + labels[highest_probability_index])
    print()

    # Or you can print out all of the results mapping labels to probabilities.
    label_index = 0
    for p in predictions:
        truncated_probablity = np.float64(np.round(p,8))
        print (labels[label_index], truncated_probablity)
        label_index += 1

Nächste Schritte

Als Nächstes erfahren Sie, wie Sie Ihr Modell in einer mobilen Anwendung umschließen: