TensorFlowEstimator Class

Definition

The TensorFlowTransformer is used in the following two scenarios.

  1. Scoring with a pretrained TensorFlow model: In this mode, the transform extracts the values of hidden layers from a pre-trained TensorFlow model and uses those outputs as features in an ML.NET pipeline.
  2. Retraining of a TensorFlow model: In this mode, the transform retrains a TensorFlow model using the user data passed through the ML.NET pipeline. Once the model is trained, its outputs can be used as features for scoring.
public sealed class TensorFlowEstimator : Microsoft.ML.IEstimator<Microsoft.ML.Transforms.TensorFlowTransformer>
type TensorFlowEstimator = class
    interface IEstimator<TensorFlowTransformer>
Public NotInheritable Class TensorFlowEstimator
Implements IEstimator(Of TensorFlowTransformer)
Inheritance
Object → TensorFlowEstimator
Implements
IEstimator<TensorFlowTransformer>

Remarks

The TensorFlowTransform extracts specified outputs using a pre-trained TensorFlow model. Optionally, it can further retrain the TensorFlow model on user data to adjust the model parameters (also known as "transfer learning").

For scoring, the transform takes as inputs the pre-trained TensorFlow model, the names of the input nodes, and the names of the output nodes whose values we want to extract. For retraining, the transform also requires training-related parameters, such as the name of the optimization operation in the TensorFlow graph, the name of the learning rate operation in the graph and its value, and the names of the operations in the graph that compute the loss and the performance metric.
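The following is a minimal sketch of the scoring scenario. The model path and the node names "Features" and "Prediction/Softmax" are placeholders for the input and output operations of your own graph, not names defined by this API.

using Microsoft.ML;
using Microsoft.ML.Transforms;

var mlContext = new MLContext();

// Load a pre-trained frozen model (the path is a placeholder).
TensorFlowModel tensorFlowModel =
    mlContext.Model.LoadTensorFlowModel("model/frozen_graph.pb");

// Build an estimator that feeds the "Features" column to the graph's input
// node and surfaces the "Prediction/Softmax" node as a new output column.
TensorFlowEstimator estimator = tensorFlowModel.ScoreTensorFlowModel(
    new[] { "Prediction/Softmax" },   // output node(s) whose values to extract
    new[] { "Features" });            // input node(s), fed from same-named ML.NET columns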

This transform requires the Microsoft.ML.TensorFlow NuGet package to be installed. The TensorFlowTransform makes the following assumptions regarding input, output, processing of data, and retraining.

  1. For the input model, the TensorFlowTransform currently supports both the Frozen model and the SavedModel formats. However, retraining of the model is only possible for the SavedModel format. The Checkpoint format is currently not supported for either scoring or retraining, due to the lack of TensorFlow C-API support for loading it.
  2. The transform supports scoring only one example at a time. However, retraining can be performed in batches.
  3. Advanced transfer learning/fine-tuning scenarios (e.g. adding more layers into the network, changing the shape of inputs, freezing the layers which do not need to be updated during the retraining process, etc.) are currently not possible due to the lack of support for network/graph manipulation inside the model using the TensorFlow C-API.
  4. The name of each input column should match the name of an input in the TensorFlow model (see the sketch after this list).
  5. The name of each output column should match one of the operations in the TensorFlow graph.
  6. Currently, double, float, long, int, short, sbyte, ulong, uint, ushort, byte and bool are the acceptable data types for input/output.
  7. Upon success, the transform will add a new column to the IDataView for each output column specified.
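As a sketch of points 4 and 5, columns can be renamed so they line up with the graph's operation names. Here "serving_default_input" and "StatefulPartitionedCall" are placeholder node names, and mlContext / tensorFlowModel are assumed from the sketch above.

// Rename the data column so it matches the graph's input operation name,
// then extract the output operation as a new column of the same name.
var pipeline = mlContext.Transforms
    .CopyColumns(outputColumnName: "serving_default_input", inputColumnName: "ImagePixels")
    .Append(tensorFlowModel.ScoreTensorFlowModel(
        new[] { "StatefulPartitionedCall" },     // must match an operation in the graph
        new[] { "serving_default_input" }));     // must match the graph's input name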

The inputs and outputs of a TensorFlow model can be obtained using the GetModelSchema() or summarize_graph tools.
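For example, a minimal sketch of schema inspection (the model path is a placeholder):

using System;
using Microsoft.ML;

var mlContext = new MLContext();
var tensorFlowModel = mlContext.Model.LoadTensorFlowModel("model/frozen_graph.pb");

// List the operations that can be used as input/output column names.
DataViewSchema schema = tensorFlowModel.GetModelSchema();
foreach (DataViewSchema.Column column in schema)
    Console.WriteLine($"{column.Name}: {column.Type}");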

Methods

Fit(IDataView)

Trains and returns a TensorFlowTransformer.

GetOutputSchema(SchemaShape)

Returns the SchemaShape of the schema which will be produced by the transformer. Used for schema propagation and verification in a pipeline.
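Continuing the scoring sketch above (estimator and a trainingData IDataView are assumed), schema propagation and training might look like this:

// Verify the columns the transformer will produce before any training happens.
SchemaShape outputShape = estimator.GetOutputSchema(SchemaShape.Create(trainingData.Schema));

// Fit the estimator and apply the resulting transformer to the data.
TensorFlowTransformer transformer = estimator.Fit(trainingData);
IDataView transformed = transformer.Transform(trainingData);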

Extension Methods

AppendCacheCheckpoint<TTrans>(IEstimator<TTrans>, IHostEnvironment)

Append a 'caching checkpoint' to the estimator chain. This will ensure that the downstream estimators will be trained against cached data. It is helpful to have a caching checkpoint before trainers that take multiple data passes.
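For instance, a sketch with placeholder column names (Sdca stands in for any multi-pass trainer):

// Cache the data produced by the upstream transform so the multi-pass
// trainer below does not recompute it on every pass.
var pipeline = mlContext.Transforms.Concatenate("Features", "Feature1", "Feature2")
    .AppendCacheCheckpoint(mlContext)
    .Append(mlContext.Regression.Trainers.Sdca(labelColumnName: "Label"));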

WithOnFitDelegate<TTransformer>(IEstimator<TTransformer>, Action<TTransformer>)

Given an estimator, return a wrapping object that will call a delegate once Fit(IDataView) is called. It is often important for an estimator to return information about what was fit, which is why the Fit(IDataView) method returns a specifically typed object rather than just a general ITransformer. At the same time, IEstimator<TTransformer> instances are often formed into pipelines with many objects, so we may need to build a chain of estimators via EstimatorChain<TLastTransformer> where the estimator whose transformer we want is buried somewhere in this chain. For that scenario, this method lets us attach a delegate that will be called once Fit(IDataView) is called.
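A sketch of that pattern, reusing the placeholder node names and objects from the sketches above:

// Capture the fitted TensorFlowTransformer even though it ends up buried
// inside an EstimatorChain once further stages are appended.
TensorFlowTransformer fittedTensorFlow = null;

var pipeline = tensorFlowModel
    .ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" })
    .WithOnFitDelegate(t => fittedTensorFlow = t)
    .Append(mlContext.Transforms.CopyColumns("Score", "Prediction/Softmax"));

var model = pipeline.Fit(trainingData);   // fittedTensorFlow is now populated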

Applies to