Vortrainierte Modelle und Lerntransfer
- 10 Minuten
Die Schulung von CNNs kann eine erhebliche Zeit in Anspruch nehmen, und für diesen Vorgang ist eine große Menge an Daten erforderlich. Viel Zeit wird damit verbracht, die besten Filter auf niedriger Ebene zu finden, die ein Netzwerk zum Extrahieren von Mustern aus den Bildern benötigt. Eine natürliche Frage stellt sich - können wir ein neurales Netzwerk verwenden, das auf einem Dataset trainiert wurde, und es an die Klassifizierung verschiedener Bilder ohne vollständigen Trainingsprozess anpassen?
Dieser Ansatz wird als Transferlernen bezeichnet, da wir einige Kenntnisse von einem neuralen Netzwerkmodell in ein anderes übertragen. Bei transfer learning beginnen wir in der Regel mit einem vortrainierten Modell, das auf einigen großen Bilddatensätzen trainiert wurde, z. B. ImageNet. Diese Modelle machen bereits eine gute Arbeit, um verschiedene Features aus generischen Bildern zu extrahieren, und in vielen Fällen kann nur das Erstellen eines Klassifizierers über diese extrahierten Features zu einem guten Ergebnis führen.
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
from PIL import Image
Katzen- und Hundedatensatz
In dieser Einheit lösen wir ein reales Problem der Klassifizierung von Katzen und Hunden. Aus diesem Grund verwenden wir Kaggle Cats vs. Dogs Dataset, die auch von Microsoft heruntergeladen werden können.
Laden wir dieses Dataset herunter und extrahieren sie in das data Verzeichnis:
import urllib.request
import zipfile
dataset_url = 'https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip'
data_dir = 'data'
os.makedirs(data_dir, exist_ok=True)
zip_path = os.path.join(data_dir, 'kagglecatsanddogs_5340.zip')
if not os.path.exists(zip_path):
urllib.request.urlretrieve(dataset_url, zip_path)
if not os.path.exists(os.path.join(data_dir, 'PetImages')):
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(data_dir)
Das Dataset kann einige beschädigte Bilddateien enthalten. Definieren wir eine Hilfsfunktion, um sie zu überprüfen und zu entfernen, bevor sie geladen werden:
def check_image(fn):
try:
im = Image.open(fn)
im.verify()
return True
except (IOError, SyntaxError):
return False
def check_image_dir(dir_path):
for fn in glob.glob(dir_path):
if not check_image(fn):
print(f"Corrupt image: {fn}")
os.remove(fn)
# Remove any corrupt images from the dataset
check_image_dir('data/PetImages/Cat/*.jpg')
check_image_dir('data/PetImages/Dog/*.jpg')
Laden des Datasets
In den vorherigen Beispielen wurden Datasets geladen, die in Keras integriert sind. Jetzt verwenden wir unser eigenes Dataset, das wir aus einem Verzeichnis von Bildern laden müssen. Keras enthält eine Hilfsfunktion image_dataset_from_directory , die ein tf.data.Dataset Aus einem Verzeichnis von Bildern erstellen kann, das nach Klasse in Unterverzeichnissen organisiert ist.
data_dir = 'data/PetImages'
batch_size = 32
ds_train = keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset='training',
seed=13,
image_size=(224, 224),
batch_size=batch_size
)
ds_test = keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset='validation',
seed=13,
image_size=(224, 224),
batch_size=batch_size
)
Hinweis
Wir verwenden denselben seed Wert beim Erstellen der Schulungs- und Validierungsteilungen, um sicherzustellen, dass keine Überlappung zwischen den beiden Teilmengen gewährleistet ist.
Wir können die Klassennamen überprüfen, die automatisch aus der Verzeichnisstruktur abgeleitet wurden:
# Expected output: ['Cat', 'Dog']
ds_train.class_names
Definieren wir eine Hilfsfunktion zum Visualisieren von Beispielen aus unserem Datensatz (dies ist eine neue Version von display_dataset, die für stapelweise Daten angepasst ist):
def display_dataset(images, labels, classes=None, cols=8):
n = len(images)
rows = (n + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
axes = axes.flatten() if n > 1 else [axes]
for i, ax in enumerate(axes):
if i < n:
ax.imshow(images[i])
label = int(labels[i][0]) if labels[i].ndim > 0 else int(labels[i])
title = classes[label] if classes else str(label)
ax.set_title(title, fontsize=8)
ax.axis('off')
plt.tight_layout()
plt.show()
Das Dataset liefert Batches von Bildern und Beschriftungen. Jeder Batch enthält 32 Bilder der Größe 224×224 mit 3 Farbkanälen und entsprechenden Etiketten:
for x, y in ds_train:
print(f"Training batch shape: features={x.shape}, labels={y.shape}")
x_sample, y_sample = x, y
break
# Expected output: Training batch shape: features=(32, 224, 224, 3), labels=(32,)
display_dataset(x_sample.numpy().astype(np.uint8), np.expand_dims(y_sample, 1), classes=ds_train.class_names)
Hinweis
Die Bildpixelwerte befinden sich im Bereich von 0 bis 255. Bei einigen Modellen muss die Eingabe mithilfe einer modellspezifischen Funktion auf 0-1 skaliert oder vorverarbeitet werden. VGG-16 verfügt über eine eigene preprocess_input Funktion, die wir später verwenden.
Vortrainierte Modelle
Es gibt viele vortrainierte neurale Netzwerke für die Bildklassifizierung, die auf dem ImageNet-Dataset trainiert wurden, das mehr als 14 Millionen Bilder in 1.000 Kategorien enthält. Eine der bekanntesten Architekturen ist VGG-16, die eine gute Genauigkeit erzielt, während sie einfach zu verstehen ist. Laden wir ein VGG-16-Modell mit vortrainierten Gewichten:
vgg = keras.applications.VGG16()
Versuchen wir, dieses vortrainierte Netzwerk zu verwenden, um eines unserer Bilder zu klassifizieren. Das VGG-16-Netzwerk wurde auf ImageNet trainiert, das Kategorien für verschiedene Hunde- und Katzenrassen umfasst:
inp = keras.applications.vgg16.preprocess_input(x_sample[:1])
res = vgg(inp)
# tf.argmax returns the index of the highest-probability class
print(f"Most probable class = {tf.argmax(res, 1)}")
# decode_predictions maps class indices to human-readable labels
keras.applications.vgg16.decode_predictions(res.numpy())
Die preprocess_input Funktion skaliert Pixelwerte entsprechend für das VGG-16-Modell. Die decode_predictions Funktion gibt die 5 höchstwahrscheinlichsten ImageNet-Klassen zusammen mit ihren Konfidenzergebnissen zurück.
Sehen wir uns die Architektur von VGG-16 an:
# Shows all layers including convolutional blocks and final Dense classifier
vgg.summary()
GPU-Berechnungen
Tiefe neurale Netzwerke erfordern eine ziemlich substanzielle Rechenleistung für die Schulung. Die Verwendung einer GPU kann den Schulungsvorgang erheblich beschleunigen. Überprüfen wir, ob eine GPU verfügbar ist:
# Lists available GPU devices; an empty list means CPU-only
tf.config.list_physical_devices('GPU')
Extrahieren von VGG-Features
Wenn wir VGG-16 verwenden möchten, um Features aus unseren Bildern zu extrahieren, benötigen wir das Modell ohne die endgültigen Klassifizierungsebenen. Dazu können wir Folgendes angeben include_top=False:
vgg = keras.applications.VGG16(include_top=False)
inp = keras.applications.vgg16.preprocess_input(x_sample[:1])
res = vgg(inp)
# The output is a 7x7 grid of 512 feature maps
print(f"Shape after applying VGG-16: {res[0].shape}")
plt.figure(figsize=(15, 3))
plt.imshow(res[0].numpy().reshape(-1, 512))
Der resultierende Featurevektor hat die Form 7×7×512 und umfasst 25088 Werte. Dies stellt die hochrangigen Merkmale dar, die VGG-16 gelernt hat, aus dem Bild zu extrahieren. Wir können diese Features manuell für unser gesamtes Dataset vorkonfigurieren und dann einen Klassifizierer oben trainieren:
Warnung
Wir verwenden .take(25) und .take(10) unten, um die Datasetgröße für eine schnellere Schulung in diesem Beispiel zu begrenzen. Jeder Batch enthält 32 Bilder, daher verwenden wir nur 800 Schulungsbilder und 320 Testbilder. Die hier gemeldete Genauigkeit von ~90% spiegelt diese kleine Teilmenge wider und kann nicht auf das vollständige Dataset generalisieren. Für den Produktionseinsatz trainieren Sie mit dem gesamten Dataset.
def preprocess(x, y):
return keras.applications.vgg16.preprocess_input(x), y
ds_features_train = ds_train.take(25).map(preprocess).map(lambda x, y: (vgg(x), y)).cache()
ds_features_test = ds_test.take(10).map(preprocess).map(lambda x, y: (vgg(x), y)).cache()
for x, y in ds_features_train:
# Expected output: (32, 7, 7, 512) (32,)
print(x.shape, y.shape)
break
Hinweis
Wir rufen .cache() nach dem Extrahieren von Merkmalen auf, damit das VGG-16-Modell nur einmal pro Batch statt pro Epoche ausgeführt wird.
Jetzt können wir einen einfachen Klassifizierer für die extrahierten Features erstellen. Da die VGG-Features bereits sehr informativ sind, kann sogar eine einzelne Dichteschicht gute Ergebnisse erzielen:
model = keras.Sequential([
keras.layers.Input(shape=(7, 7, 512)),
keras.layers.Flatten(),
keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
hist = model.fit(ds_features_train, validation_data=ds_features_test)
# Expected: validation accuracy around 90%
Mit rund 90% Genauigkeit zeigt dies die Leistungsfähigkeit der vortrainierten Features! Das manuelle Komputieren von Features ist jedoch umständlich.
Transferlernen mithilfe eines VGG-Netzwerks
Durch die Kombination des VGG-16-Featureextraktionsmoduls und unseres Klassifizierers in einem einzigen Netzwerk können wir manuelle Vorkompilierungsfunktionen vermeiden. Der Schlüssel besteht darin, die vortrainierten Schichten einzufrieren, damit ihre Gewichte während des Trainings nicht aktualisiert werden.
Wir verschieben den preprocess_input Schritt in die Datenpipeline, anstatt ihn in das Modell als Lambda Ebene einzubetten. Dadurch bleibt das Modell serialisierbar, sodass wir es später speichern und laden können:
def preprocess(x, y):
return keras.applications.vgg16.preprocess_input(x), y
ds_train_preprocessed = ds_train.map(preprocess)
ds_test_preprocessed = ds_test.map(preprocess)
Hinweis
Da die Vorverarbeitung jetzt Teil der Datenpipeline und nicht des Modells ist, müssen Sie preprocess_input auch zur Inferenzzeit auf Eingabedaten anwenden.
Jetzt erstellen wir das Modell mit der eingefrorenen VGG-16-Basis.
vgg_base = keras.applications.VGG16(include_top=False, input_shape=(224, 224, 3))
vgg_base.trainable = False
model = keras.Sequential([
keras.layers.Input(shape=(224, 224, 3)),
vgg_base,
keras.layers.Flatten(),
keras.layers.Dense(1, activation='sigmoid')
])
# Notice: ~15 million params are non-trainable (VGG-16), only ~25k are trainable
model.summary()
Durch das Einfrieren der VGG-16-Schichten müssen wir nur die endgültige Dichteschicht trainieren, die ungefähr 25.000 Parameter anstelle der vollen 15 Millionen Parameter aufweist. Dadurch wird die Schulung schneller:
Warnung
Wie auch im vorherigen Abschnitt verwenden wir .take(50) und .take(10), um das Dataset für eine schnellere Schulung zu beschränken. Dies bedeutet, dass wir auf ca. 1.600 Bildern trainieren und auf 320 überprüfen. Genauigkeitsergebnisse können abweichen, wenn das Training auf dem gesamten Datensatz erfolgt.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
hist = model.fit(ds_train_preprocessed.take(50), validation_data=ds_test_preprocessed.take(10))
# Expected: validation accuracy around 90% or higher
Speichern und Laden des Modells
Sobald wir ein trainiertes Modell haben, können wir es auf der Festplatte speichern und später erneut laden, ohne umzuschulungen:
model.save('data/cats_dogs.keras')
Hinweis
Die .keras Erweiterung verwendet das systemeigene Keras 3-Format. Wenn Sie eine ältere Version von TensorFlow/Keras verwenden, verwenden Sie .h5 stattdessen das HDF5-Format oder das SavedModel-Verzeichnisformat.
So laden Sie das gespeicherte Modell:
model = keras.models.load_model('data/cats_dogs.keras')
Andere Computer-Vision-Modelle
VGG-16 ist eine der einfachsten tiefen CNN-Architekturen zu verstehen, aufgrund ihrer einheitlichen Struktur gestapelter 3×3-Konvolutionen. Keras bietet viele weitere vortrainierte Netzwerke. Die am häufigsten verwendeten unter denen sind ResNet-Architekturen , die von Microsoft entwickelt wurden, und Inception von Google.
Verbessern von Ergebnissen mit Datenerweiterung
Wenn Sie mit eingeschränkten Schulungsdaten arbeiten, kann die Datenerweiterung die Verallgemeinerung erheblich verbessern. Durch die Anwendung zufälliger Transformationen (z. B. horizontale Flips, Drehungen und Zooms) zum Trainieren von Bildern erhöhen wir die Vielfalt des Datasets künstlich. Keras bietet Erweiterungsebenen wie keras.layers.RandomFlip, keras.layers.RandomRotationund keras.layers.RandomZoom die direkt zu Ihrem Modell oder Ihrer Datenpipeline hinzugefügt werden können.
Schlussfolgerung
Mit transfer learning konnten wir schnell einen Klassifizierer für unsere benutzerdefinierte Objektklassifizierungsaufgabe zusammenstellen und eine hohe Genauigkeit erzielen. Dieses Beispiel war nicht vollständig fair, da das ursprüngliche VGG-16-Netzwerk auf ImageNet vortrainiert wurde, das bereits Kategorien für verschiedene Katzen- und Hunderassen enthält, und daher wurden wir nur die meisten Muster wiederverwenden, die bereits im Netzwerk vorhanden waren. Sie können eine geringere Genauigkeit für andere domänenspezifische Objekte erwarten, z. B. Details zu einer Produktionslinie in einer Pflanze oder anderen Baumblättern. Sie können sehen, dass komplexere Aufgaben eine höhere Rechenleistung erfordern und häufig von der GPU-Beschleunigung für schulungen profitieren.