StandardTrainersCatalog.OnlineGradientDescent 方法
定義
重要
部分資訊涉及發行前產品,在發行之前可能會有大幅修改。 Microsoft 對此處提供的資訊,不做任何明確或隱含的瑕疵擔保。
多載
OnlineGradientDescent(RegressionCatalog+RegressionTrainers, String, String, IRegressionLoss, Single, Boolean, Single, Int32)
建立 OnlineGradientDescentTrainer ,其會使用線性回歸模型來預測目標。
public static Microsoft.ML.Trainers.OnlineGradientDescentTrainer OnlineGradientDescent (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", Microsoft.ML.Trainers.IRegressionLoss lossFunction = default, float learningRate = 0.1, bool decreaseLearningRate = true, float l2Regularization = 0, int numberOfIterations = 1);
static member OnlineGradientDescent : Microsoft.ML.RegressionCatalog.RegressionTrainers * string * string * Microsoft.ML.Trainers.IRegressionLoss * single * bool * single * int -> Microsoft.ML.Trainers.OnlineGradientDescentTrainer
<Extension()>
Public Function OnlineGradientDescent (catalog As RegressionCatalog.RegressionTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional lossFunction As IRegressionLoss = Nothing, Optional learningRate As Single = 0.1, Optional decreaseLearningRate As Boolean = true, Optional l2Regularization As Single = 0, Optional numberOfIterations As Integer = 1) As OnlineGradientDescentTrainer
參數
回歸目錄定型器物件。
- lossFunction
- IRegressionLoss
定型過程中要最小化的損失函式。例如,使用 SquaredLoss 會得到最小平方定型器。
- learningRate
- Single
SGD 所使用的初始學習率。
- decreaseLearningRate
- Boolean
隨著反覆運算的進行逐步降低學習速率。
- numberOfIterations
- Int32
完整走訪訓練資料集的次數(pass 數)。
傳回
範例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.Regression
{
    /// <summary>
    /// Sample that builds an online gradient descent regression trainer from
    /// column names only, fits it on synthetic data, and prints predictions
    /// and evaluation metrics to the console.
    /// </summary>
    public static class OnlineGradientDescent
    {
        // Number of feature values carried by every synthetic example; shared
        // by the data generator and the vector-size annotation on DataPoint.
        private const int FeatureCount = 50;

        /// <summary>
        /// Runs the sample: generate data, train, score a small test set,
        /// print each prediction next to its label, then print metrics.
        /// </summary>
        public static void Example()
        {
            // The ML.NET context is created with a fixed seed so that the
            // sample's output is deterministic from run to run.
            var mlContext = new MLContext(seed: 0);

            // 1000 synthetic training examples, wrapped as an IDataView so
            // the ML.NET API can consume them.
            var trainingPoints = GenerateRandomDataPoints(1000);
            var trainingView = mlContext.Data.LoadFromEnumerable(trainingPoints);

            // Only the label and feature column names are supplied to the
            // trainer; every other argument is left at its default.
            var trainer = mlContext.Regression.Trainers.OnlineGradientDescent(
                labelColumnName: nameof(DataPoint.Label),
                featureColumnName: nameof(DataPoint.Features));

            // Fit the trainer on the training data.
            var trainedModel = trainer.Fit(trainingView);

            // Five test examples, generated with a different seed so that
            // they differ from the training examples.
            var testPoints = GenerateRandomDataPoints(5, seed: 123);
            var testView = mlContext.Data.LoadFromEnumerable(testPoints);

            // Score the test set with the trained model.
            var scoredTestView = trainedModel.Transform(testView);

            // Materialize the scored rows as Prediction objects.
            var scoredRows = mlContext.Data
                .CreateEnumerable<Prediction>(scoredTestView, reuseRowObject: false)
                .ToList();

            // Show each actual label side by side with its predicted score.
            foreach (var row in scoredRows)
            {
                Console.WriteLine($"Label: {row.Label:F3}, Prediction: {row.Score:F3}");
            }

            // This trainer is not numerically stable.
            // Please see issue #2425.

            // Compute and print the overall regression metrics.
            var evaluation = mlContext.Regression.Evaluate(scoredTestView);
            PrintMetrics(evaluation);
        }

        /// <summary>
        /// Lazily yields <paramref name="count"/> synthetic examples. Because
        /// this is an iterator, the random generator is constructed from
        /// <paramref name="seed"/> each time the sequence is enumerated.
        /// </summary>
        /// <param name="count">Number of examples to yield.</param>
        /// <param name="seed">Seed for the pseudo-random generator.</param>
        /// <returns>A deferred sequence of labeled examples.</returns>
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var rng = new Random(seed);
            for (int produced = 0; produced < count; produced++)
            {
                // The label is drawn first; the per-feature noise values
                // are drawn afterwards, one per slot, in index order.
                float label = (float)rng.NextDouble();

                // Each feature is the label plus noise, so the features are
                // correlated with the label.
                var features = new float[FeatureCount];
                for (int slot = 0; slot < features.Length; slot++)
                {
                    features[slot] = label + (float)rng.NextDouble();
                }

                yield return new DataPoint
                {
                    Label = label,
                    Features = features
                };
            }
        }

        // One example: a label plus a fixed-size vector of feature values.
        // A data set is a collection of such examples.
        private class DataPoint
        {
            public float Label { get; set; }

            [VectorType(FeatureCount)]
            public float[] Features { get; set; }
        }

        // Shape used to read scored rows back out of the transformed data.
        private class Prediction
        {
            // Original label.
            public float Label { get; set; }

            // Predicted score from the trainer.
            public float Score { get; set; }
        }

        /// <summary>
        /// Writes four regression metrics to the console, one per line.
        /// </summary>
        /// <param name="metrics">Metrics returned by the regression evaluator.</param>
        private static void PrintMetrics(RegressionMetrics metrics)
        {
            Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError}");
            Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError}");
            Console.WriteLine($"Root Mean Squared Error: {metrics.RootMeanSquaredError}");
            Console.WriteLine($"RSquared: {metrics.RSquared}");
        }
    }
}
適用於
OnlineGradientDescent(RegressionCatalog+RegressionTrainers, OnlineGradientDescentTrainer+Options)
使用進階選項建立 OnlineGradientDescentTrainer,其會使用線性回歸模型來預測目標。
public static Microsoft.ML.Trainers.OnlineGradientDescentTrainer OnlineGradientDescent (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, Microsoft.ML.Trainers.OnlineGradientDescentTrainer.Options options);
static member OnlineGradientDescent : Microsoft.ML.RegressionCatalog.RegressionTrainers * Microsoft.ML.Trainers.OnlineGradientDescentTrainer.Options -> Microsoft.ML.Trainers.OnlineGradientDescentTrainer
<Extension()>
Public Function OnlineGradientDescent (catalog As RegressionCatalog.RegressionTrainers, options As OnlineGradientDescentTrainer.Options) As OnlineGradientDescentTrainer
參數
回歸目錄定型器物件。
定型器選項。
傳回
範例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
namespace Samples.Dynamic.Trainers.Regression
{
    /// <summary>
    /// Sample that configures an online gradient descent regression trainer
    /// through an options object, fits it on synthetic data, and prints
    /// predictions and evaluation metrics to the console.
    /// </summary>
    public static class OnlineGradientDescentWithOptions
    {
        // Number of feature values carried by every synthetic example; shared
        // by the data generator and the vector-size annotation on DataPoint.
        private const int FeatureCount = 50;

        /// <summary>
        /// Runs the sample: generate data, configure and train the trainer,
        /// score a small test set, print each prediction next to its label,
        /// then print metrics.
        /// </summary>
        public static void Example()
        {
            // The ML.NET context is created with a fixed seed so that the
            // sample's output is deterministic from run to run.
            var mlContext = new MLContext(seed: 0);

            // 1000 synthetic training examples, wrapped as an IDataView so
            // the ML.NET API can consume them.
            var trainingPoints = GenerateRandomDataPoints(1000);
            var trainingView = mlContext.Data.LoadFromEnumerable(trainingPoints);

            // Trainer configuration.
            var trainerOptions = new OnlineGradientDescentTrainer.Options
            {
                LabelColumnName = nameof(DataPoint.Label),
                FeatureColumnName = nameof(DataPoint.Features),
                // Change the loss function.
                LossFunction = new TweedieLoss(),
                // Give an extra gain to more recent updates.
                RecencyGain = 0.1f,
                // Turn off lazy updates.
                LazyUpdate = false,
                // Specify scale for initial weights.
                InitialWeightsDiameter = 0.2f
            };

            // Build the trainer from the options above.
            var trainer =
                mlContext.Regression.Trainers.OnlineGradientDescent(trainerOptions);

            // Fit the trainer on the training data.
            var trainedModel = trainer.Fit(trainingView);

            // Five test examples, generated with a different seed so that
            // they differ from the training examples.
            var testPoints = GenerateRandomDataPoints(5, seed: 123);
            var testView = mlContext.Data.LoadFromEnumerable(testPoints);

            // Score the test set with the trained model.
            var scoredTestView = trainedModel.Transform(testView);

            // Materialize the scored rows as Prediction objects.
            var scoredRows = mlContext.Data
                .CreateEnumerable<Prediction>(scoredTestView, reuseRowObject: false)
                .ToList();

            // Show each actual label side by side with its predicted score.
            foreach (var row in scoredRows)
            {
                Console.WriteLine($"Label: {row.Label:F3}, Prediction: {row.Score:F3}");
            }

            // This trainer is not numerically stable.
            // Please see issue #2425.

            // Compute and print the overall regression metrics.
            var evaluation = mlContext.Regression.Evaluate(scoredTestView);
            PrintMetrics(evaluation);
        }

        /// <summary>
        /// Lazily yields <paramref name="count"/> synthetic examples. Because
        /// this is an iterator, the random generator is constructed from
        /// <paramref name="seed"/> each time the sequence is enumerated.
        /// </summary>
        /// <param name="count">Number of examples to yield.</param>
        /// <param name="seed">Seed for the pseudo-random generator.</param>
        /// <returns>A deferred sequence of labeled examples.</returns>
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var rng = new Random(seed);
            for (int produced = 0; produced < count; produced++)
            {
                // The label is drawn first; the per-feature noise values
                // are drawn afterwards, one per slot, in index order.
                float label = (float)rng.NextDouble();

                // Each feature is the label plus noise, so the features are
                // correlated with the label.
                var features = new float[FeatureCount];
                for (int slot = 0; slot < features.Length; slot++)
                {
                    features[slot] = label + (float)rng.NextDouble();
                }

                yield return new DataPoint
                {
                    Label = label,
                    Features = features
                };
            }
        }

        // One example: a label plus a fixed-size vector of feature values.
        // A data set is a collection of such examples.
        private class DataPoint
        {
            public float Label { get; set; }

            [VectorType(FeatureCount)]
            public float[] Features { get; set; }
        }

        // Shape used to read scored rows back out of the transformed data.
        private class Prediction
        {
            // Original label.
            public float Label { get; set; }

            // Predicted score from the trainer.
            public float Score { get; set; }
        }

        /// <summary>
        /// Writes four regression metrics to the console, one per line.
        /// </summary>
        /// <param name="metrics">Metrics returned by the regression evaluator.</param>
        private static void PrintMetrics(RegressionMetrics metrics)
        {
            Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError}");
            Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError}");
            Console.WriteLine($"Root Mean Squared Error: {metrics.RootMeanSquaredError}");
            Console.WriteLine($"RSquared: {metrics.RSquared}");
        }
    }
}