ImageClassificationTrainer Class
Definition
Important
Some information relates to prerelease product that may be substantially modified before it’s released. Microsoft makes no warranties, express or implied, with respect to the information provided here.
The IEstimator<TTransformer> for training a Deep Neural Network (DNN) to classify images.
C#
public sealed class ImageClassificationTrainer : Microsoft.ML.Trainers.TrainerEstimatorBase<Microsoft.ML.Data.MulticlassPredictionTransformer<Microsoft.ML.Vision.ImageClassificationModelParameters>, Microsoft.ML.Vision.ImageClassificationModelParameters>
F#
type ImageClassificationTrainer = class
    inherit TrainerEstimatorBase<MulticlassPredictionTransformer<ImageClassificationModelParameters>, ImageClassificationModelParameters>
VB
Public NotInheritable Class ImageClassificationTrainer
Inherits TrainerEstimatorBase(Of MulticlassPredictionTransformer(Of ImageClassificationModelParameters), ImageClassificationModelParameters)
- Inheritance
  TrainerEstimatorBase<MulticlassPredictionTransformer<ImageClassificationModelParameters>, ImageClassificationModelParameters> → ImageClassificationTrainer
Remarks
To create this trainer, use ImageClassification.
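Below is a minimal sketch of a pipeline built around this trainer. The column names ("Label", "ImagePath", "Image"), the input file, and the image folder are illustrative placeholders, not part of this API, and the LoadRawImageBytes step additionally requires the Microsoft.ML.ImageAnalytics package.

```csharp
using Microsoft.ML;
using Microsoft.ML.Data;

var mlContext = new MLContext();

// Illustrative input: a TSV file listing image paths and their labels.
IDataView trainData = mlContext.Data.LoadFromTextFile<ImageData>(
    "train-images.tsv", hasHeader: true);

var pipeline = mlContext.Transforms.Conversion
        // The label column must be of key type.
        .MapValueToKey("Label")
    // The feature column must be a variable-sized vector of Byte:
    // load the raw image bytes from disk.
    .Append(mlContext.Transforms.LoadRawImageBytes(
        outputColumnName: "Image",
        imageFolder: "images",
        inputColumnName: "ImagePath"))
    .Append(mlContext.MulticlassClassification.Trainers.ImageClassification(
        labelColumnName: "Label",
        featureColumnName: "Image"))
    // Map the predicted key back to the original label value.
    .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

ITransformer model = pipeline.Fit(trainData);

// Illustrative input schema for the TSV file above.
public class ImageData
{
    [LoadColumn(0)] public string ImagePath;
    [LoadColumn(1)] public string Label;
}
```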
Input and Output Columns
The input label column data must be of key type and the feature column must be a variable-sized vector of Byte.
This trainer outputs the following columns:
Output Column Name | Column Type | Description |
---|---|---|
Score | Vector of Single | The scores of all classes. A higher value means a higher probability of falling into the associated class. If the i-th element has the largest value, the predicted label index would be i. Note that i is a zero-based index. |
PredictedLabel | Key type | The predicted label's index. If its value is i, the actual label would be the i-th category in the key-valued input label type. |
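As a sketch of consuming these columns (continuing the pipeline above, where MapKeyToValue converts PredictedLabel back to its original string value), the Score column surfaces as a float[] and PredictedLabel as a string; the class name and image file name below are illustrative.

```csharp
// Continues the earlier sketch: 'mlContext', 'model', and ImageData are
// assumed from it; "rose1.jpg" is an illustrative file under the image folder.
var engine = mlContext.Model.CreatePredictionEngine<ImageData, ImagePrediction>(model);

ImagePrediction prediction = engine.Predict(new ImageData { ImagePath = "rose1.jpg" });
System.Console.WriteLine(
    $"Predicted: {prediction.PredictedLabel} ({prediction.Score.Length} class scores)");

// Score: one value per class; the largest value marks the predicted class.
// PredictedLabel: a string here because the pipeline mapped the key back to
// its original value with MapKeyToValue.
public class ImagePrediction
{
    public float[] Score;
    public string PredictedLabel;
}
```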
Trainer Characteristics
Machine learning task | Multiclass classification |
Is normalization required? | No |
Is caching required? | No |
Required NuGet in addition to Microsoft.ML | Microsoft.ML.Vision and SciSharp.TensorFlow.Redist / SciSharp.TensorFlow.Redist-Windows-GPU / SciSharp.TensorFlow.Redist-Linux-GPU |
Exportable to ONNX | No |
Using TensorFlow based APIs
In order to run any TensorFlow-based ML.NET APIs, you must first add a NuGet dependency on the TensorFlow redist library. There are currently two versions you can use: one compiled for GPU support and one with CPU support only.
CPU only
CPU based TensorFlow is currently supported on:
- Linux
- MacOS
- Windows
To get TensorFlow working on the CPU only, all that is needed is to take a NuGet dependency on SciSharp.TensorFlow.Redist v1.14.0.
GPU support
GPU based TensorFlow is currently supported on:
- Windows
- Linux
As of now, TensorFlow does not support running on GPUs for MacOS, so we cannot support this currently.
Prerequisites
You must have at least one CUDA-compatible GPU; for a list of compatible GPUs, see Nvidia's Guide.
Install CUDA v10.1 and cuDNN v7.6.4.
Make sure you install CUDA v10.1, not any newer version. After downloading the cuDNN v7.6.4 .zip file and unpacking it, you need to do the following steps:
Copy <CUDNN_zip_files_path>\cuda\bin\cudnn64_7.dll to <YOUR_DRIVE>\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\bin
For C/C++ development:
Copy <CUDNN_zip_files_path>\cuda\include\cudnn.h to <YOUR_DRIVE>\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\include
Copy <CUDNN_zip_files_path>\cuda\lib\x64\cudnn.lib to <YOUR_DRIVE>\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.1\lib\x64
For further details on cuDNN, you can follow the cuDNN Installation Guide.
Usage
To use TensorFlow with GPU support take a NuGet dependency on the following package depending on your OS:
- Windows -> SciSharp.TensorFlow.Redist-Windows-GPU
- Linux -> SciSharp.TensorFlow.Redist-Linux-GPU
No code modification should be necessary to leverage the GPU for TensorFlow operations.
Troubleshooting
If you are not able to use your GPU after adding the GPU-based TensorFlow NuGet, make sure that there is only a dependency on the GPU-based version. If you have a dependency on both NuGets, the CPU-based TensorFlow will run instead.
Training Algorithm Details
Trains a Deep Neural Network (DNN) by leveraging an existing pre-trained model such as ResNet50 for the purpose of classifying images. The technique was inspired by TensorFlow's retrain image classification tutorial.
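For finer control over the underlying DNN, the trainer can also be constructed from an ImageClassificationTrainer.Options instance; the sketch below selects the ResNet v2 50 architecture. The numeric values are arbitrary placeholders rather than recommendations, and mlContext and validationData are assumed to already exist.

```csharp
// Sketch only: 'mlContext' and 'validationData' are assumed; the Options type
// and Architecture enum live in the Microsoft.ML.Vision namespace.
var options = new ImageClassificationTrainer.Options
{
    FeatureColumnName = "Image",   // variable-sized vector of Byte (raw image bytes)
    LabelColumnName = "Label",     // key-typed label column
    Arch = ImageClassificationTrainer.Architecture.ResnetV250,
    Epoch = 50,                    // placeholder values, not recommendations
    BatchSize = 10,
    LearningRate = 0.01f,
    ValidationSet = validationData,
    MetricsCallback = metrics => System.Console.WriteLine(metrics)
};

var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(options);
```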
Fields
FeatureColumn | The feature column that the trainer expects. (Inherited from TrainerEstimatorBase<TTransformer,TModel>) |
LabelColumn | The label column that the trainer expects. Can be null, which indicates that label is not used for training. |
WeightColumn | The weight column that the trainer expects. Can be null, which indicates that weight is not used for training. |
Properties
Info | Auxiliary information about the trainer in terms of its capabilities and requirements. |
Methods
Finalize() | |
Fit(IDataView, IDataView) | Trains an ImageClassificationTrainer using both training and validation data, and returns an ImageClassificationModelParameters. |
Fit(IDataView) | Trains and returns an ITransformer. (Inherited from TrainerEstimatorBase<TTransformer,TModel>) |
GetOutputSchema(SchemaShape) | (Inherited from TrainerEstimatorBase<TTransformer,TModel>) |
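A sketch of the Fit(IDataView, IDataView) overload follows; trainSet and validationSet are assumed to already contain the key-typed "Label" column and the raw image-bytes "Image" column (for example, produced by the transforms shown earlier).

```csharp
// Sketch: train with both a training and a validation set. 'mlContext',
// 'trainSet', and 'validationSet' are assumed to exist already.
ImageClassificationTrainer trainer = mlContext.MulticlassClassification.Trainers
    .ImageClassification(labelColumnName: "Label", featureColumnName: "Image");

// MulticlassPredictionTransformer<T> is in Microsoft.ML.Data; the trainer
// types are in Microsoft.ML.Vision.
MulticlassPredictionTransformer<ImageClassificationModelParameters> model =
    trainer.Fit(trainSet, validationSet);
```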
Extension Methods
AppendCacheCheckpoint<TTrans>(IEstimator<TTrans>, IHostEnvironment) | Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes. |
WithOnFitDelegate<TTransformer>(IEstimator<TTransformer>, Action<TTransformer>) | Given an estimator, return a wrapping object that will call a delegate once Fit(IDataView) is called. It is often important for an estimator to return information about what was fit, which is why the Fit(IDataView) method returns a specifically typed object, rather than just a general ITransformer. However, at the same time, IEstimator<TTransformer> objects are often formed into pipelines with many objects, so we may need to build a chain of estimators via EstimatorChain<TLastTransformer> where the estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can attach, through this method, a delegate that will be called once fit is called. |
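As a sketch of WithOnFitDelegate with this trainer, the delegate below captures the strongly typed transformer produced when the surrounding pipeline is fitted; mlContext, trainData, the column names, and the image folder are the same illustrative assumptions used in the earlier sketches.

```csharp
// Sketch: capture the fitted image-classification transformer from inside a
// larger pipeline via WithOnFitDelegate.
MulticlassPredictionTransformer<ImageClassificationModelParameters> fitted = null;

var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
    .Append(mlContext.Transforms.LoadRawImageBytes("Image", "images", "ImagePath"))
    .Append(mlContext.MulticlassClassification.Trainers
        .ImageClassification(labelColumnName: "Label", featureColumnName: "Image")
        .WithOnFitDelegate(t => fitted = t));

ITransformer model = pipeline.Fit(trainData);
// 'fitted' now holds the transformer the trainer produced during Fit.
```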