Modelos entrenados previamente y aprendizaje por transferencia
- 10 minutos
El entrenamiento de redes neuronales convolucionales (CNN) puede llevar una cantidad considerable de tiempo y se requiere una gran cantidad de datos para esa tarea. Gran parte del tiempo se dedica a experimentar para encontrar los mejores filtros de bajo nivel que una red necesita para extraer patrones de las imágenes. Surge una pregunta natural: ¿podemos usar una red neuronal entrenada en un conjunto de datos y adaptarla para clasificar diferentes imágenes sin un proceso de entrenamiento completo?
Este enfoque se denomina aprendizaje de transferencia, ya que transferimos algunos conocimientos de un modelo de red neuronal a otro. En el aprendizaje de transferencia, normalmente empezamos con un modelo entrenado previamente, que se ha entrenado en algún conjunto de datos de imagen grande, como ImageNet. Esos modelos ya realizan un buen trabajo extrayendo diferentes características de imágenes genéricas y, en muchos casos, simplemente crear un clasificador sobre esas características extraídas puede producir un buen resultado.
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
from PIL import Image
Conjunto de datos de gatos frente a perros
En esta unidad, resolveremos un problema de la vida real de clasificar imágenes de gatos y perros. Por este motivo, usaremos Kaggle Cats vs. Dogs Dataset, que también se puede descargar de Microsoft.
Vamos a descargar este conjunto de datos y extraerlo en el data directorio :
import urllib.request
import zipfile
dataset_url = 'https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip'
data_dir = 'data'
os.makedirs(data_dir, exist_ok=True)
zip_path = os.path.join(data_dir, 'kagglecatsanddogs_5340.zip')
if not os.path.exists(zip_path):
urllib.request.urlretrieve(dataset_url, zip_path)
if not os.path.exists(os.path.join(data_dir, 'PetImages')):
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(data_dir)
El conjunto de datos puede contener algunos archivos de imagen dañados. Vamos a definir una función auxiliar para comprobarlas y quitarlas antes de cargarlas:
def check_image(fn):
try:
im = Image.open(fn)
im.verify()
return True
except (IOError, SyntaxError):
return False
def check_image_dir(dir_path):
for fn in glob.glob(dir_path):
if not check_image(fn):
print(f"Corrupt image: {fn}")
os.remove(fn)
# Remove any corrupt images from the dataset
check_image_dir('data/PetImages/Cat/*.jpg')
check_image_dir('data/PetImages/Dog/*.jpg')
Carga del conjunto de datos
En los ejemplos anteriores, se cargaban conjuntos de datos integrados en Keras. Ahora usaremos nuestro propio conjunto de datos, que necesitamos cargar desde un directorio de imágenes. Keras incluye una función image_dataset_from_directory auxiliar que puede crear a tf.data.Dataset partir de un directorio de imágenes organizadas en subdirectorios por clase.
data_dir = 'data/PetImages'
batch_size = 32
ds_train = keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset='training',
seed=13,
image_size=(224, 224),
batch_size=batch_size
)
ds_test = keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset='validation',
seed=13,
image_size=(224, 224),
batch_size=batch_size
)
Nota:
Usamos el mismo seed valor al crear las divisiones de entrenamiento y validación para garantizar que no se superpongan entre los dos subconjuntos.
Podemos comprobar los nombres de clase que se deducen automáticamente de la estructura de directorios:
# Expected output: ['Cat', 'Dog']
ds_train.class_names
Vamos a definir un asistente para visualizar ejemplos de nuestro conjunto de datos (se trata de una nueva versión adaptada display_dataset para los datos por lotes):
def display_dataset(images, labels, classes=None, cols=8):
n = len(images)
rows = (n + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5))
axes = axes.flatten() if n > 1 else [axes]
for i, ax in enumerate(axes):
if i < n:
ax.imshow(images[i])
label = int(labels[i][0]) if labels[i].ndim > 0 else int(labels[i])
title = classes[label] if classes else str(label)
ax.set_title(title, fontsize=8)
ax.axis('off')
plt.tight_layout()
plt.show()
El conjunto de datos produce lotes de imágenes y etiquetas. Cada lote contiene 32 imágenes de tamaño 224×224 con 3 canales de color y etiquetas correspondientes:
for x, y in ds_train:
print(f"Training batch shape: features={x.shape}, labels={y.shape}")
x_sample, y_sample = x, y
break
# Expected output: Training batch shape: features=(32, 224, 224, 3), labels=(32,)
display_dataset(x_sample.numpy().astype(np.uint8), np.expand_dims(y_sample, 1), classes=ds_train.class_names)
Nota:
Los valores de píxeles de imagen están en el intervalo de 0 a 255. Algunos modelos requieren que la entrada se escale a 0-1 o se preprocese mediante una función específica del modelo. VGG-16 tiene su propia preprocess_input función que usamos más adelante.
Modelos entrenados previamente
Hay muchas redes neuronales previamente entrenadas para la clasificación de imágenes que se han entrenado en el conjunto de datos ImageNet, que contiene más de 14 millones de imágenes en 1000 categorías. Una de las arquitecturas más conocidas es VGG-16, que logra una buena precisión mientras es fácil de entender. Vamos a cargar un modelo VGG-16 con pesos entrenados previamente:
vgg = keras.applications.VGG16()
Vamos a intentar usar esta red entrenada previamente para clasificar una de nuestras imágenes. La red VGG-16 se entrenó en ImageNet, que incluye categorías para diversas razas de perros y gatos:
inp = keras.applications.vgg16.preprocess_input(x_sample[:1])
res = vgg(inp)
# tf.argmax returns the index of the highest-probability class
print(f"Most probable class = {tf.argmax(res, 1)}")
# decode_predictions maps class indices to human-readable labels
keras.applications.vgg16.decode_predictions(res.numpy())
La preprocess_input función escala los valores de píxeles adecuadamente para el modelo VGG-16. La función decode_predictions devuelve las 5 clases de ImageNet más probables junto con sus puntuaciones de confianza.
Veamos la arquitectura de VGG-16:
# Shows all layers including convolutional blocks and final Dense classifier
vgg.summary()
Cálculos de GPU
Las redes neuronales profundas requieren una potencia computacional bastante sustancial para el entrenamiento. El uso de una GPU puede acelerar significativamente el proceso de entrenamiento. Vamos a comprobar si hay una GPU disponible:
# Lists available GPU devices; an empty list means CPU-only
tf.config.list_physical_devices('GPU')
Extracción de características de VGG
Si queremos usar VGG-16 para extraer características de nuestras imágenes, necesitamos el modelo sin las capas de clasificación finales. Para ello, se puede especificar include_top=False:
vgg = keras.applications.VGG16(include_top=False)
inp = keras.applications.vgg16.preprocess_input(x_sample[:1])
res = vgg(inp)
# The output is a 7x7 grid of 512 feature maps
print(f"Shape after applying VGG-16: {res[0].shape}")
plt.figure(figsize=(15, 3))
plt.imshow(res[0].numpy().reshape(-1, 512))
El vector de característica resultante tiene valores de forma 7×7×512 = 25088. Esto representa las características de alto nivel que VGG-16 ha aprendido a extraer de la imagen. Podemos calcular previamente de forma manual estas características para todo el conjunto de datos y luego entrenar un clasificador encima de ellas.
Advertencia
.take(25) Usamos y .take(10) a continuación para limitar el tamaño del conjunto de datos para un entrenamiento más rápido en este ejemplo. Cada lote contiene 32 imágenes, por lo que solo usamos 800 imágenes de entrenamiento y 320 imágenes de prueba. La precisión de ~90% notificada aquí refleja este pequeño subconjunto y puede que no se generalice en el conjunto de datos completo. Para su uso en producción, entrene en el conjunto de datos completo.
def preprocess(x, y):
return keras.applications.vgg16.preprocess_input(x), y
ds_features_train = ds_train.take(25).map(preprocess).map(lambda x, y: (vgg(x), y)).cache()
ds_features_test = ds_test.take(10).map(preprocess).map(lambda x, y: (vgg(x), y)).cache()
for x, y in ds_features_train:
# Expected output: (32, 7, 7, 512) (32,)
print(x.shape, y.shape)
break
Nota:
Llamamos a .cache() después de extraer características para que el modelo VGG-16 solo se ejecute una vez por lote en lugar de cada época.
Ahora podemos crear un clasificador simple en las características extraídas. Dado que las características de VGG ya son muy informativas, incluso una sola capa densa puede lograr buenos resultados:
model = keras.Sequential([
keras.layers.Input(shape=(7, 7, 512)),
keras.layers.Flatten(),
keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
hist = model.fit(ds_features_train, validation_data=ds_features_test)
# Expected: validation accuracy around 90%
Con alrededor de 90% precisión, esto demuestra la eficacia de las características previamente entrenadas. Sin embargo, precalcular manualmente las características es complicado.
Transferencia de aprendizaje mediante una red VGG
Podemos evitar las características de precálculo manualmente combinando el extractor de características VGG-16 y nuestro clasificador en una sola red. La clave es congelar las capas previamente entrenadas para que sus pesos no se actualicen durante el entrenamiento.
Movemos el paso preprocess_input a la canalización de datos en lugar de integrarlo en el modelo como una capa Lambda. Esto mantiene el modelo serializable para que podamos guardarlo y cargarlo más adelante:
def preprocess(x, y):
return keras.applications.vgg16.preprocess_input(x), y
ds_train_preprocessed = ds_train.map(preprocess)
ds_test_preprocessed = ds_test.map(preprocess)
Nota:
Dado que el preprocesamiento forma parte de la canalización de datos en lugar del modelo, también debe aplicarse preprocess_input a los datos de entrada en el momento de la inferencia.
Ahora construimos el modelo con la base VGG-16 congelada.
vgg_base = keras.applications.VGG16(include_top=False, input_shape=(224, 224, 3))
vgg_base.trainable = False
model = keras.Sequential([
keras.layers.Input(shape=(224, 224, 3)),
vgg_base,
keras.layers.Flatten(),
keras.layers.Dense(1, activation='sigmoid')
])
# Notice: ~15 million params are non-trainable (VGG-16), only ~25k are trainable
model.summary()
Al congelar las capas VGG-16, solo necesitamos entrenar la capa densa final, que tiene aproximadamente 25 000 parámetros en lugar de los 15 millones completos. Esto hace que el entrenamiento sea más rápido:
Advertencia
Al igual que con la sección anterior, usamos .take(50) y .take(10) para limitar el conjunto de datos para un entrenamiento más rápido. Esto significa que estamos entrenando en aproximadamente 1600 imágenes y validando en 320. Los resultados de precisión pueden diferir al entrenar en el conjunto de datos completo.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
hist = model.fit(ds_train_preprocessed.take(50), validation_data=ds_test_preprocessed.take(10))
# Expected: validation accuracy around 90% or higher
Guardar y cargar el modelo
Una vez que tengamos un modelo entrenado, podemos guardarlo en el disco y volver a cargarlo más adelante sin volver a entrenar:
model.save('data/cats_dogs.keras')
Nota:
La .keras extensión usa el formato keras 3 nativo. Si usa una versión anterior de TensorFlow/Keras, use .h5 (formato HDF5) o el formato de directorio SavedModel en su lugar.
Para cargar el modelo guardado:
model = keras.models.load_model('data/cats_dogs.keras')
Otros modelos de Computer Vision
VGG-16 es una de las arquitecturas de CNN más sencillas de comprender, debido a su estructura uniforme de 3×3 convoluciones apiladas. Keras proporciona muchas más redes previamente entrenadas. Las más usadas entre ellas son las arquitecturas resNet , desarrolladas por Microsoft y Inception por Google.
Mejora de los resultados con aumento de datos
Al trabajar con datos de entrenamiento limitados, el aumento de datos puede mejorar significativamente la generalización. Al aplicar transformaciones aleatorias (como volteos horizontales, rotaciones y zooms) a imágenes de entrenamiento, aumentamos artificialmente la diversidad del conjunto de datos. Keras proporciona capas de aumento como keras.layers.RandomFlip, keras.layers.RandomRotationy keras.layers.RandomZoom que se pueden agregar directamente al modelo o a la canalización de datos.
Conclusión
Con el aprendizaje de transferencia hemos podido reunir rápidamente un clasificador para nuestra tarea de clasificación de objetos personalizada y lograr una alta precisión. Este ejemplo no era completamente justo porque la red VGG-16 original estaba previamente entrenada en ImageNet, que ya incluye categorías para varias razas de gatos y perros, y por lo tanto solo se reutilizaba la mayoría de los patrones que ya estaban presentes en la red. Puede esperar una menor precisión para otros objetos específicos del dominio, como los detalles de una línea de producción en una planta o hojas de árbol diferentes. Puede ver que las tareas más complejas requieren una mayor potencia computacional y, a menudo, se benefician de la aceleración de GPU para el entrenamiento.