TreeExtensions.Gam Method
Definition
Important
Some information relates to prerelease product that may be substantially modified before it’s released. Microsoft makes no warranties, express or implied, with respect to the information provided here.
Overloads
Gam(BinaryClassificationCatalog+BinaryClassificationTrainers, String, String, String, Int32, Int32, Double) |
Create GamBinaryTrainer, which predicts a target using generalized additive models (GAM). |
Gam(BinaryClassificationCatalog+BinaryClassificationTrainers, GamBinaryTrainer+Options) |
Create GamBinaryTrainer using advanced options, which predicts a target using generalized additive models (GAM). |
Gam(RegressionCatalog+RegressionTrainers, GamRegressionTrainer+Options) |
Create GamRegressionTrainer using advanced options, which predicts a target using generalized additive models (GAM). |
Gam(RegressionCatalog+RegressionTrainers, String, String, String, Int32, Int32, Double) |
Create GamRegressionTrainer, which predicts a target using generalized additive models (GAM). |
Gam(BinaryClassificationCatalog+BinaryClassificationTrainers, String, String, String, Int32, Int32, Double)
Create GamBinaryTrainer, which predicts a target using generalized additive models (GAM).
public static Microsoft.ML.Trainers.FastTree.GamBinaryTrainer Gam (this Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string exampleWeightColumnName = default, int numberOfIterations = 9500, int maximumBinCountPerFeature = 255, double learningRate = 0.002);
static member Gam : Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers * string * string * string * int * int * double -> Microsoft.ML.Trainers.FastTree.GamBinaryTrainer
<Extension()>
Public Function Gam (catalog As BinaryClassificationCatalog.BinaryClassificationTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional exampleWeightColumnName As String = Nothing, Optional numberOfIterations As Integer = 9500, Optional maximumBinCountPerFeature As Integer = 255, Optional learningRate As Double = 0.002) As GamBinaryTrainer
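Parameters
- catalog
- BinaryClassificationCatalog.BinaryClassificationTrainers
The BinaryClassificationCatalog.
- labelColumnName
- String
The name of the label column.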
- featureColumnName
- String
The name of the feature column. The column data must be a known-sized vector of Single.
- exampleWeightColumnName
- String
The name of the example weight column (optional).
- numberOfIterations
- Int32
The number of iterations to use in learning the features.
- maximumBinCountPerFeature
- Int32
The maximum number of bins to use to approximate features.
- learningRate
- Double
The learning rate. GAMs work best with a small learning rate.
Returns
Examples
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.BinaryClassification
{
public static class Gam
{
// This example requires installation of additional NuGet package for
// Microsoft.ML.FastTree at
// https://www.nuget.org/packages/Microsoft.ML.FastTree/
public static void Example()
{
    // MLContext is the entry point for ML.NET operations: it handles
    // exception tracking and logging, exposes the catalog of available
    // operations, and acts as the source of randomness.
    var context = new MLContext();

    // Generate the synthetic examples and expose them as an IDataView.
    var view = context.Data.LoadFromEnumerable(GenerateData());

    // Split the data into a training part and a validation part.
    var split = context.Data.TrainTestSplit(view);
    var trainingSet = split.TrainSet;
    var validationSet = split.TestSet;

    // Build the GAM trainer with a deliberately small bin count. With this
    // setting, the range of every feature is divided into 16 discrete
    // regions for training. These regions are not evenly spaced, and the
    // final model may contain fewer bins because neighboring bins with
    // identical values are combined. In general we recommend using at
    // least the default number of bins, since a small bin count limits the
    // capacity of the model.
    var gamTrainer = context.BinaryClassification.Trainers.Gam(
        maximumBinCountPerFeature: 16);

    // Fit using both the training and the validation set. GAM can use a
    // technique called pruning to tune the model to the validation set
    // after training, which improves generalization.
    var fitted = gamTrainer.Fit(trainingSet, validationSet);

    // Pull out the learned model parameters.
    var shapeModel = fitted.Model.SubModel;

    // The parameters of the Generalized Additive Model can now be inspected
    // to understand the fit and to learn about the dataset. Start with the
    // bias, which represents the average prediction for the training data.
    Console.WriteLine($"Average prediction: {shapeModel.Bias:0.00}");

    // Next, walk over the learned shape functions. As in a linear model,
    // there is one independent response per feature; unlike a linear model,
    // the response is a generic function rather than a line. Because a bias
    // term is included, each response is the deviation from the average
    // prediction as a function of the feature value.
    for (int feature = 0; feature < shapeModel.NumberOfShapeFunctions; feature++)
    {
        // Blank line between features.
        Console.WriteLine();

        // Upper bound of each bin for this feature.
        var upperBounds = shapeModel.GetBinUpperBounds(feature);

        // Effect of each bin, i.e. the function value inside that bin.
        var effects = shapeModel.GetBinEffects(feature);

        // Print the function as (bin, value) pairs. A GAM can be thought of
        // as a bar chart or lookup table per feature.
        Console.WriteLine($"Feature{feature}");
        for (int bin = 0; bin < upperBounds.Count; bin++)
        {
            Console.WriteLine(
                $"x < {upperBounds[bin]:0.00} => {effects[bin]:0.000}");
        }
    }

    // Expected output:
    // Average prediction: 0.82
    //
    // Feature0
    // x < -0.44 => 0.286
    // x < -0.38 => 0.225
    // x < -0.32 => 0.048
    // x < -0.26 => -0.110
    // x < -0.20 => -0.116
    // x < 0.18 => -0.143
    // x < 0.25 => -0.115
    // x < 0.31 => -0.005
    // x < 0.37 => 0.097
    // x < 0.44 => 0.263
    // x < ∞ => 0.284
    //
    // Feature1
    // x < 0.00 => -0.350
    // x < 0.24 => 0.875
    // x < 0.31 => -0.138
    // x < ∞ => -0.188

    // How to read this output: to score an example, find the first bin
    // whose inequality holds for the feature value. Looking at the whole
    // function shows how the model responds to the variable globally. The
    // model reconstructs the parabolic and the step-wise function, shifted
    // with respect to the average expected output over the training set.
    // Very few bins are needed for the second feature because the GAM model
    // discards unchanged bins to produce a smaller model. Finally, note
    // that these feature functions can be noisy: Feature0 (the parabola)
    // should be symmetric, but the model does not capture that exactly.
    // This is due to noise in the data. Common practice is to use
    // resampling methods to estimate a confidence interval at each bin,
    // which helps to decide whether an effect is real or just sampling
    // noise. See for example: Tan, Caruana, Hooker, and Lou,
    // "Distill-and-Compare: Auditing Black-Box Models Using Transparent
    // Model Distillation", arXiv:1710.06169,
    // https://arxiv.org/abs/1710.06169
}
// One example for the GAM sample: a boolean label plus a two-element
// feature vector.
private class Data
{
// Binary target; GenerateData sets it by thresholding a sigmoid at 0.5.
public bool Label { get; set; }
// Declared as a vector of size 2. GenerateData always fills exactly two
// values: index 0 feeds Parabola and index 1 feeds SimplePiecewise.
[VectorType(2)]
public float[] Features { get; set; }
}
/// <summary>
/// Lazily produces the synthetic dataset for the GAM sample. Each example
/// has two independent features drawn uniformly from [-0.5, 0.5). The label
/// is true when the sigmoid of Parabola(Features[0]) +
/// SimplePiecewise(Features[1]) + noise exceeds 0.5, where the noise is one
/// more draw from the same range.
/// </summary>
/// <param name="numExamples">The number of examples to generate.</param>
/// <param name="seed">The seed for the random number generator used to
/// produce data.</param>
/// <returns>A deferred sequence of <c>numExamples</c> Data objects; the
/// same seed yields the same sequence.</returns>
private static IEnumerable<Data> GenerateData(int numExamples = 25000,
    int seed = 1)
{
    var random = new Random(seed);

    // Uniform draw shifted so that it is centered on zero.
    float NextCentered() => (float)(random.NextDouble() - 0.5);

    int produced = 0;
    while (produced < numExamples)
    {
        // Draw order matters for reproducibility: first feature, second
        // feature, then the label noise.
        float first = NextCentered();
        float second = NextCentered();
        float noise = NextCentered();

        // Combine the two shape functions with the noise and threshold.
        bool label = Sigmoid(Parabola(first)
            + SimplePiecewise(second) + noise) > 0.5;

        yield return new Data
        {
            Label = label,
            Features = new float[2] { first, second }
        };
        produced++;
    }
}
/// <summary>
/// Shape function applied to the first feature: the square of the input.
/// </summary>
private static float Parabola(float x)
{
    return x * x;
}
/// <summary>
/// Shape function applied to the second feature: a single step that is 1
/// on the half-open interval [0, 0.25) and 0 everywhere else.
/// </summary>
private static float SimplePiecewise(float x)
{
    bool insideStep = x >= 0 && x < 0.25;
    return insideStep ? 1f : 0f;
}
/// <summary>
/// Logistic function 1 / (1 + e^(-x)), mapping any real input into the
/// range [0, 1].
/// </summary>
private static double Sigmoid(double x)
{
    double decay = Math.Exp(-x);
    return 1.0 / (1.0 + decay);
}
}
}
Applies to
Gam(BinaryClassificationCatalog+BinaryClassificationTrainers, GamBinaryTrainer+Options)
Create GamBinaryTrainer using advanced options, which predicts a target using generalized additive models (GAM).
public static Microsoft.ML.Trainers.FastTree.GamBinaryTrainer Gam (this Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Microsoft.ML.Trainers.FastTree.GamBinaryTrainer.Options options);
static member Gam : Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers * Microsoft.ML.Trainers.FastTree.GamBinaryTrainer.Options -> Microsoft.ML.Trainers.FastTree.GamBinaryTrainer
<Extension()>
Public Function Gam (catalog As BinaryClassificationCatalog.BinaryClassificationTrainers, options As GamBinaryTrainer.Options) As GamBinaryTrainer
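Parameters
- catalog
- BinaryClassificationCatalog.BinaryClassificationTrainers
The BinaryClassificationCatalog.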
- options
- GamBinaryTrainer.Options
Trainer options.
Returns
Examples
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
namespace Samples.Dynamic.Trainers.BinaryClassification
{
public static class GamWithOptions
{
// This example requires installation of additional NuGet package for
// Microsoft.ML.FastTree at
// https://www.nuget.org/packages/Microsoft.ML.FastTree/
public static void Example()
{
    // MLContext is the entry point for ML.NET operations: it handles
    // exception tracking and logging, exposes the catalog of available
    // operations, and acts as the source of randomness.
    var context = new MLContext();

    // Generate the synthetic examples and expose them as an IDataView.
    var view = context.Data.LoadFromEnumerable(GenerateData());

    // Split the data into a training part and a validation part.
    var split = context.Data.TrainTestSplit(view);
    var trainingSet = split.TrainSet;
    var validationSet = split.TestSet;

    // Configure the GAM trainer through its options object.
    // A deliberately small bin count is used: the range of every feature is
    // divided into 16 discrete regions for training. These regions are not
    // evenly spaced, and the final model may contain fewer bins because
    // neighboring bins with identical values are combined. In general we
    // recommend using at least the default number of bins, since a small
    // bin count limits the capacity of the model. The learning rate is set
    // to half of the default (0.002) to slow down the gradient descent, and
    // the number of iterations is doubled from the default (9500) to
    // compensate.
    var trainerOptions = new GamBinaryTrainer.Options
    {
        NumberOfIterations = 19000,
        MaximumBinCountPerFeature = 16,
        LearningRate = 0.001
    };
    var gamTrainer =
        context.BinaryClassification.Trainers.Gam(trainerOptions);

    // Fit using both the training and the validation set. GAM can use a
    // technique called pruning to tune the model to the validation set
    // after training, which improves generalization.
    var fitted = gamTrainer.Fit(trainingSet, validationSet);

    // Pull out the learned model parameters.
    var shapeModel = fitted.Model.SubModel;

    // The parameters of the Generalized Additive Model can now be inspected
    // to understand the fit and to learn about the dataset. Start with the
    // bias, which represents the average prediction for the training data.
    Console.WriteLine($"Average prediction: {shapeModel.Bias:0.00}");

    // Next, walk over the learned shape functions. As in a linear model,
    // there is one independent response per feature; unlike a linear model,
    // the response is a generic function rather than a line. Because a bias
    // term is included, each response is the deviation from the average
    // prediction as a function of the feature value.
    for (int feature = 0; feature < shapeModel.NumberOfShapeFunctions; feature++)
    {
        // Blank line between features.
        Console.WriteLine();

        // Upper bound of each bin for this feature.
        var upperBounds = shapeModel.GetBinUpperBounds(feature);

        // Effect of each bin, i.e. the function value inside that bin.
        var effects = shapeModel.GetBinEffects(feature);

        // Print the function as (bin, value) pairs. A GAM can be thought of
        // as a bar chart or lookup table per feature.
        Console.WriteLine($"Feature{feature}");
        for (int bin = 0; bin < upperBounds.Count; bin++)
        {
            Console.WriteLine(
                $"x < {upperBounds[bin]:0.00} => {effects[bin]:0.000}");
        }
    }

    // Expected output:
    // NOTE(review): this listing is identical to the one in the
    // default-options Gam sample even though the learning rate and the
    // iteration count differ here; confirm it was regenerated.
    // Average prediction: 0.82
    //
    // Feature0
    // x < -0.44 => 0.286
    // x < -0.38 => 0.225
    // x < -0.32 => 0.048
    // x < -0.26 => -0.110
    // x < -0.20 => -0.116
    // x < 0.18 => -0.143
    // x < 0.25 => -0.115
    // x < 0.31 => -0.005
    // x < 0.37 => 0.097
    // x < 0.44 => 0.263
    // x < ∞ => 0.284
    //
    // Feature1
    // x < 0.00 => -0.350
    // x < 0.24 => 0.875
    // x < 0.31 => -0.138
    // x < ∞ => -0.188

    // How to read this output: to score an example, find the first bin
    // whose inequality holds for the feature value. Looking at the whole
    // function shows how the model responds to the variable globally. The
    // model reconstructs the parabolic and the step-wise function, shifted
    // with respect to the average expected output over the training set.
    // Very few bins are needed for the second feature because the GAM model
    // discards unchanged bins to produce a smaller model. Finally, note
    // that these feature functions can be noisy: Feature0 (the parabola)
    // should be symmetric, but the model does not capture that exactly.
    // This is due to noise in the data. Common practice is to use
    // resampling methods to estimate a confidence interval at each bin,
    // which helps to decide whether an effect is real or just sampling
    // noise. See for example: Tan, Caruana, Hooker, and Lou,
    // "Distill-and-Compare: Auditing Black-Box Models Using Transparent
    // Model Distillation", arXiv:1710.06169,
    // https://arxiv.org/abs/1710.06169
}
// One example for the GAM sample: a boolean label plus a two-element
// feature vector.
private class Data
{
// Binary target; GenerateData sets it by thresholding a sigmoid at 0.5.
public bool Label { get; set; }
// Declared as a vector of size 2. GenerateData always fills exactly two
// values: index 0 feeds Parabola and index 1 feeds SimplePiecewise.
[VectorType(2)]
public float[] Features { get; set; }
}
/// <summary>
/// Lazily produces the synthetic dataset for the GAM sample. Each example
/// has two independent features drawn uniformly from [-0.5, 0.5). The label
/// is true when the sigmoid of Parabola(Features[0]) +
/// SimplePiecewise(Features[1]) + noise exceeds 0.5, where the noise is one
/// more draw from the same range.
/// </summary>
/// <param name="numExamples">The number of examples to generate.</param>
/// <param name="seed">The seed for the random number generator used to
/// produce data.</param>
/// <returns>A deferred sequence of <c>numExamples</c> Data objects; the
/// same seed yields the same sequence.</returns>
private static IEnumerable<Data> GenerateData(int numExamples = 25000,
    int seed = 1)
{
    var random = new Random(seed);

    // Uniform draw shifted so that it is centered on zero.
    float NextCentered() => (float)(random.NextDouble() - 0.5);

    int produced = 0;
    while (produced < numExamples)
    {
        // Draw order matters for reproducibility: first feature, second
        // feature, then the label noise.
        float first = NextCentered();
        float second = NextCentered();
        float noise = NextCentered();

        // Combine the two shape functions with the noise and threshold.
        bool label = Sigmoid(Parabola(first) +
            SimplePiecewise(second) + noise) > 0.5;

        yield return new Data
        {
            Label = label,
            Features = new float[2] { first, second }
        };
        produced++;
    }
}
/// <summary>
/// Shape function applied to the first feature: the square of the input.
/// </summary>
private static float Parabola(float x)
{
    return x * x;
}
/// <summary>
/// Shape function applied to the second feature: a single step that is 1
/// on the half-open interval [0, 0.25) and 0 everywhere else.
/// </summary>
private static float SimplePiecewise(float x)
{
    bool insideStep = x >= 0 && x < 0.25;
    return insideStep ? 1f : 0f;
}
/// <summary>
/// Logistic function 1 / (1 + e^(-x)), mapping any real input into the
/// range [0, 1].
/// </summary>
private static double Sigmoid(double x)
{
    double decay = Math.Exp(-x);
    return 1.0 / (1.0 + decay);
}
}
}
Applies to
Gam(RegressionCatalog+RegressionTrainers, GamRegressionTrainer+Options)
Create GamRegressionTrainer using advanced options, which predicts a target using generalized additive models (GAM).
public static Microsoft.ML.Trainers.FastTree.GamRegressionTrainer Gam (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, Microsoft.ML.Trainers.FastTree.GamRegressionTrainer.Options options);
static member Gam : Microsoft.ML.RegressionCatalog.RegressionTrainers * Microsoft.ML.Trainers.FastTree.GamRegressionTrainer.Options -> Microsoft.ML.Trainers.FastTree.GamRegressionTrainer
<Extension()>
Public Function Gam (catalog As RegressionCatalog.RegressionTrainers, options As GamRegressionTrainer.Options) As GamRegressionTrainer
Parameters
- catalog
- RegressionCatalog.RegressionTrainers
The RegressionCatalog.
- options
- GamRegressionTrainer.Options
Trainer options.
Returns
Examples
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
namespace Samples.Dynamic.Trainers.Regression
{
public static class GamWithOptions
{
// This example requires installation of additional NuGet
// package for Microsoft.ML.FastTree found at
// https://www.nuget.org/packages/Microsoft.ML.FastTree/
public static void Example()
{
    // MLContext is the entry point for ML.NET operations: it handles
    // exception tracking and logging, exposes the catalog of available
    // operations, and acts as the source of randomness. A fixed seed is
    // used here so that the outputs are deterministic.
    var context = new MLContext(seed: 0);

    // Generate 1000 training points and wrap them in an IDataView, the
    // data representation consumed by the ML.NET API.
    var trainingView = context.Data.LoadFromEnumerable(
        GenerateRandomDataPoints(1000));

    // Trainer configuration.
    var trainerOptions = new GamRegressionTrainer.Options
    {
        LabelColumnName = nameof(DataPoint.Label),
        FeatureColumnName = nameof(DataPoint.Features),
        // The entropy (regularization) coefficient.
        EntropyCoefficient = 0.3,
        // Reduce the number of iterations to 50.
        NumberOfIterations = 50
    };

    // Create the trainer from the options and fit it to the training data.
    var gamTrainer = context.Regression.Trainers.Gam(trainerOptions);
    var fitted = gamTrainer.Fit(trainingView);

    // Build a small test set. A different random seed makes it differ from
    // the training data.
    var testView = context.Data.LoadFromEnumerable(
        GenerateRandomDataPoints(5, seed: 123));

    // Score the test set with the trained model.
    var scoredView = fitted.Transform(testView);

    // Materialize the scored rows as a list of Prediction objects.
    var scoredRows = context.Data
        .CreateEnumerable<Prediction>(scoredView, reuseRowObject: false)
        .ToList();

    // Show the 5 predictions next to the actual labels for comparison.
    foreach (var row in scoredRows)
    {
        Console.WriteLine($"Label: {row.Label:F3}, Prediction: {row.Score:F3}");
    }

    // Expected output:
    // Label: 0.985, Prediction: 0.841
    // Label: 0.155, Prediction: 0.187
    // Label: 0.515, Prediction: 0.496
    // Label: 0.566, Prediction: 0.467
    // Label: 0.096, Prediction: 0.144

    // Evaluate the overall metrics on the scored test set and print them.
    var evaluation = context.Regression.Evaluate(scoredView);
    PrintMetrics(evaluation);

    // Expected output:
    // Mean Absolute Error: 0.04
    // Mean Squared Error: 0.01
    // Root Mean Squared Error: 0.05
    // RSquared: 0.98 (closer to 1 is better. The worst case is 0)
}
// Lazily produces `count` synthetic regression examples. Every example gets
// a label drawn from Random.NextDouble() and 50 feature values, each equal
// to that label plus a fresh Random.NextDouble() draw, so the features are
// correlated with the label. The same seed yields the same sequence.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
    int seed = 0)
{
    var rng = new Random(seed);
    int produced = 0;
    while (produced < count)
    {
        // The label is drawn before the features so that the order of the
        // random draws stays: label, then feature 0 through feature 49.
        float target = (float)rng.NextDouble();

        var noisyCopies = new float[50];
        for (int slot = 0; slot < noisyCopies.Length; slot++)
        {
            noisyCopies[slot] = target + (float)rng.NextDouble();
        }

        yield return new DataPoint
        {
            Label = target,
            Features = noisyCopies
        };
        produced++;
    }
}
// Example with label and 50 feature values. A data set is a collection of
// such examples.
private class DataPoint
{
// Regression target; GenerateRandomDataPoints draws it from
// Random.NextDouble().
public float Label { get; set; }
// Declared as a vector of size 50, matching the 50 values that
// GenerateRandomDataPoints produces for every example.
[VectorType(50)]
public float[] Features { get; set; }
}
// Class used to capture predictions. Example() passes it as the type
// argument of CreateEnumerable to read rows of the transformed test data.
private class Prediction
{
// Original label.
public float Label { get; set; }
// Predicted score from the trainer.
public float Score { get; set; }
}
// Prints the regression evaluation metrics to the console, one per line:
// mean absolute error, mean squared error, root mean squared error and
// R-squared. Each value is written with two decimal places (the "F2"
// format) so that the output matches the two-decimal values shown in the
// sample's expected-output listing instead of the full-precision text that
// plain string concatenation of a double produces.
private static void PrintMetrics(RegressionMetrics metrics)
{
    Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError:F2}");
    Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError:F2}");
    Console.WriteLine(
        $"Root Mean Squared Error: {metrics.RootMeanSquaredError:F2}");
    Console.WriteLine($"RSquared: {metrics.RSquared:F2}");
}
}
}
Applies to
Gam(RegressionCatalog+RegressionTrainers, String, String, String, Int32, Int32, Double)
Create GamRegressionTrainer, which predicts a target using generalized additive models (GAM).
public static Microsoft.ML.Trainers.FastTree.GamRegressionTrainer Gam (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string exampleWeightColumnName = default, int numberOfIterations = 9500, int maximumBinCountPerFeature = 255, double learningRate = 0.002);
static member Gam : Microsoft.ML.RegressionCatalog.RegressionTrainers * string * string * string * int * int * double -> Microsoft.ML.Trainers.FastTree.GamRegressionTrainer
<Extension()>
Public Function Gam (catalog As RegressionCatalog.RegressionTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional exampleWeightColumnName As String = Nothing, Optional numberOfIterations As Integer = 9500, Optional maximumBinCountPerFeature As Integer = 255, Optional learningRate As Double = 0.002) As GamRegressionTrainer
Parameters
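- catalog
- RegressionCatalog.RegressionTrainers
The RegressionCatalog.
- labelColumnName
- String
The name of the label column.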
- featureColumnName
- String
The name of the feature column. The column data must be a known-sized vector of Single.
- exampleWeightColumnName
- String
The name of the example weight column (optional).
- numberOfIterations
- Int32
The number of iterations to use in learning the features.
- maximumBinCountPerFeature
- Int32
The maximum number of bins to use to approximate features.
- learningRate
- Double
The learning rate. GAMs work best with a small learning rate.
Returns
Examples
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.Regression
{
public static class Gam
{
// This example requires installation of additional NuGet
// package for Microsoft.ML.FastTree found at
// https://www.nuget.org/packages/Microsoft.ML.FastTree/
public static void Example()
{
    // MLContext is the entry point for ML.NET operations: it handles
    // exception tracking and logging, exposes the catalog of available
    // operations, and acts as the source of randomness. A fixed seed is
    // used here so that the outputs are deterministic.
    var context = new MLContext(seed: 0);

    // Generate 1000 training points and wrap them in an IDataView, the
    // data representation consumed by the ML.NET API.
    var trainingView = context.Data.LoadFromEnumerable(
        GenerateRandomDataPoints(1000));

    // Create the trainer, naming the label and feature columns after the
    // DataPoint properties, and fit it to the training data.
    var gamTrainer = context.Regression.Trainers.Gam(
        labelColumnName: nameof(DataPoint.Label),
        featureColumnName: nameof(DataPoint.Features));
    var fitted = gamTrainer.Fit(trainingView);

    // Build a small test set. A different random seed makes it differ from
    // the training data.
    var testView = context.Data.LoadFromEnumerable(
        GenerateRandomDataPoints(5, seed: 123));

    // Score the test set with the trained model.
    var scoredView = fitted.Transform(testView);

    // Materialize the scored rows as a list of Prediction objects.
    var scoredRows = context.Data
        .CreateEnumerable<Prediction>(scoredView, reuseRowObject: false)
        .ToList();

    // Show the 5 predictions next to the actual labels for comparison.
    foreach (var row in scoredRows)
    {
        Console.WriteLine($"Label: {row.Label:F3}, Prediction: {row.Score:F3}");
    }

    // Expected output:
    // Label: 0.985, Prediction: 0.948
    // Label: 0.155, Prediction: 0.089
    // Label: 0.515, Prediction: 0.463
    // Label: 0.566, Prediction: 0.509
    // Label: 0.096, Prediction: 0.106

    // Evaluate the overall metrics on the scored test set and print them.
    var evaluation = context.Regression.Evaluate(scoredView);
    PrintMetrics(evaluation);

    // Expected output:
    // Mean Absolute Error: 0.03
    // Mean Squared Error: 0.00
    // Root Mean Squared Error: 0.03
    // RSquared: 0.99 (closer to 1 is better. The worst case is 0)
}
// Lazily produces `count` synthetic regression examples. Every example gets
// a label drawn from Random.NextDouble() and 50 feature values, each equal
// to that label plus a fresh Random.NextDouble() draw, so the features are
// correlated with the label. The same seed yields the same sequence.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
    int seed = 0)
{
    var rng = new Random(seed);
    int produced = 0;
    while (produced < count)
    {
        // The label is drawn before the features so that the order of the
        // random draws stays: label, then feature 0 through feature 49.
        float target = (float)rng.NextDouble();

        var noisyCopies = new float[50];
        for (int slot = 0; slot < noisyCopies.Length; slot++)
        {
            noisyCopies[slot] = target + (float)rng.NextDouble();
        }

        yield return new DataPoint
        {
            Label = target,
            Features = noisyCopies
        };
        produced++;
    }
}
// Example with label and 50 feature values. A data set is a collection of
// such examples.
private class DataPoint
{
// Regression target; GenerateRandomDataPoints draws it from
// Random.NextDouble().
public float Label { get; set; }
// Declared as a vector of size 50, matching the 50 values that
// GenerateRandomDataPoints produces for every example.
[VectorType(50)]
public float[] Features { get; set; }
}
// Class used to capture predictions. Example() passes it as the type
// argument of CreateEnumerable to read rows of the transformed test data.
private class Prediction
{
// Original label.
public float Label { get; set; }
// Predicted score from the trainer.
public float Score { get; set; }
}
// Prints the regression evaluation metrics to the console, one per line:
// mean absolute error, mean squared error, root mean squared error and
// R-squared. Each value is written with two decimal places (the "F2"
// format) so that the output matches the two-decimal values shown in the
// sample's expected-output listing instead of the full-precision text that
// plain string concatenation of a double produces.
private static void PrintMetrics(RegressionMetrics metrics)
{
    Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError:F2}");
    Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError:F2}");
    Console.WriteLine(
        $"Root Mean Squared Error: {metrics.RootMeanSquaredError:F2}");
    Console.WriteLine($"RSquared: {metrics.RSquared:F2}");
}
}
}