TreeExtensions.FastTree Method

Definition

Overloads

FastTree(BinaryClassificationCatalog+BinaryClassificationTrainers, FastTreeBinaryTrainer+Options)

Create FastTreeBinaryTrainer with advanced options, which predicts a target using a decision tree binary classification model.

FastTree(RankingCatalog+RankingTrainers, FastTreeRankingTrainer+Options)

Create FastTreeRankingTrainer with advanced options, which ranks a series of inputs based on their relevance using a decision tree ranking model.

FastTree(RegressionCatalog+RegressionTrainers, FastTreeRegressionTrainer+Options)

Create FastTreeRegressionTrainer with advanced options, which predicts a target using a decision tree regression model.

FastTree(BinaryClassificationCatalog+BinaryClassificationTrainers, String, String, String, Int32, Int32, Int32, Double)

Create FastTreeBinaryTrainer, which predicts a target using a decision tree binary classification model.

FastTree(RegressionCatalog+RegressionTrainers, String, String, String, Int32, Int32, Int32, Double)

Create FastTreeRegressionTrainer, which predicts a target using a decision tree regression model.

FastTree(RankingCatalog+RankingTrainers, String, String, String, String, Int32, Int32, Int32, Double)

Create FastTreeRankingTrainer, which ranks a series of inputs based on their relevance using a decision tree ranking model.

FastTree(BinaryClassificationCatalog+BinaryClassificationTrainers, FastTreeBinaryTrainer+Options)

Create FastTreeBinaryTrainer with advanced options, which predicts a target using a decision tree binary classification model.

public static Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer FastTree (this Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer.Options options);
static member FastTree : Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers * Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer.Options -> Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer
<Extension()>
Public Function FastTree (catalog As BinaryClassificationCatalog.BinaryClassificationTrainers, options As FastTreeBinaryTrainer.Options) As FastTreeBinaryTrainer

Parameters

options
FastTreeBinaryTrainer.Options

Trainer options.
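
The full example below leaves the label and feature column names at their defaults. As a minimal sketch (the values are illustrative defaults, not tuning recommendations, and mlContext is an MLContext created as in the example), the column names and core tree hyperparameters can also be set directly on the Options object:

var options = new FastTreeBinaryTrainer.Options
{
    // Column names; when not set, these default to "Label" and "Features".
    LabelColumnName = "Label",
    FeatureColumnName = "Features",
    // Core tree hyperparameters, mirroring the parameters of the simpler
    // FastTree overload documented further down this page.
    NumberOfLeaves = 20,
    NumberOfTrees = 100,
    MinimumExampleCountPerLeaf = 10,
    LearningRate = 0.2
};
var trainer = mlContext.BinaryClassification.Trainers.FastTree(options);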

Returns

FastTreeBinaryTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

namespace Samples.Dynamic.Trainers.BinaryClassification
{
    public static class FastTreeWithOptions
    {
        // This example requires installation of additional NuGet package for 
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeBinaryTrainer.Options
            {
                // Use L2Norm for early stopping.
                EarlyStoppingMetric = EarlyStoppingMetric.L2Norm,
                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50
            };

            // Define the trainer.
            var pipeline = mlContext.BinaryClassification.Trainers
                .FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data
                .LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData,
                reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, "
                    + $"Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False
            //   Label: True, Prediction: True
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False

            // Evaluate the overall metrics.
            var metrics = mlContext.BinaryClassification
                .Evaluate(transformedTestData);

            PrintMetrics(metrics);

            // Expected output:
            //   Accuracy: 0.78
            //   AUC: 0.88
            //   F1 Score: 0.79
            //   Negative Precision: 0.83
            //   Negative Recall: 0.74
            //   Positive Precision: 0.74
            //   Positive Recall: 0.84
            //   Log Loss: 0.62
            //   Log Loss Reduction: 37.77
            //   Entropy: 1.00
            //
            //  TEST POSITIVE RATIO:    0.4760 (238.0/(238.0+262.0))
            //  Confusion table
            //            ||======================
            //  PREDICTED || positive | negative | Recall
            //  TRUTH     ||======================
            //   positive ||      185 |       53 | 0.7773
            //   negative ||       83 |      179 | 0.6832
            //            ||======================
            //  Precision ||   0.6903 |   0.7716 |
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)

        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with false label, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50)
                        .Select(x => x ? randomFloat() : randomFloat() +
                        0.03f).ToArray()

                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public bool Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public bool Label { get; set; }
            // Predicted label from the trainer.
            public bool PredictedLabel { get; set; }
        }

        // Pretty-print BinaryClassificationMetrics objects.
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: " +
                $"{metrics.NegativePrecision:F2}");

            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: " +
                $"{metrics.PositivePrecision:F2}");

            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}\n");
            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}


Applies to

FastTree(RankingCatalog+RankingTrainers, FastTreeRankingTrainer+Options)

Create FastTreeRankingTrainer with advanced options, which ranks a series of inputs based on their relevance using a decision tree ranking model.

public static Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer FastTree (this Microsoft.ML.RankingCatalog.RankingTrainers catalog, Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer.Options options);
static member FastTree : Microsoft.ML.RankingCatalog.RankingTrainers * Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer.Options -> Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer
<Extension()>
Public Function FastTree (catalog As RankingCatalog.RankingTrainers, options As FastTreeRankingTrainer.Options) As FastTreeRankingTrainer

Parameters

options
FastTreeRankingTrainer.Options

Trainer options.

Returns

FastTreeRankingTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

namespace Samples.Dynamic.Trainers.Ranking
{
    public static class FastTreeWithOptions
    {
        // This example requires installation of additional NuGet package for 
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeRankingTrainer.Options
            {
                // Use NdcgAt3 for early stopping.
                EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt3,
                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50,
                // Specify the row group column name.
                RowGroupColumnName = "GroupId"
            };

            // Define the trainer.
            var pipeline = mlContext.Ranking.Trainers.FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Take the top 5 rows.
            var topTransformedTestData = mlContext.Data.TakeRows(
                transformedTestData, 5);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                topTransformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions)
                Console.WriteLine($"Label: {p.Label}, Score: {p.Score}");

            // Expected output:
            //   Label: 5, Score: 8.807633
            //   Label: 1, Score: -10.71331
            //   Label: 3, Score: -8.134147
            //   Label: 3, Score: -6.545538
            //   Label: 1, Score: -10.27982

            // Evaluate the overall metrics.
            var metrics = mlContext.Ranking.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   DCG: @1:40.57, @2:61.21, @3:74.11
            //   NDCG: @1:0.96, @2:0.95, @3:0.97
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0, int groupSize = 10)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(0, 5);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    GroupId = (uint)(i / groupSize),
                    // Create random features that are correlated with the label.
                    // For data points with larger labels, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => randomFloat() + x * 0.1f).ToArray()
                };
            }
        }

        // Example with label, groupId, and 50 feature values. A data set is a
        // collection of such examples.
        private class DataPoint
        {
            [KeyType(5)]
            public uint Label { get; set; }
            [KeyType(100)]
            public uint GroupId { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Score produced from the trainer.
            public float Score { get; set; }
        }

        // Pretty-print RankerMetrics objects.
        public static void PrintMetrics(RankingMetrics metrics)
        {
            Console.WriteLine("DCG: " + string.Join(", ",
                metrics.DiscountedCumulativeGains.Select(
                    (d, i) => (i + 1) + ":" + d + ":F2").ToArray()));
            Console.WriteLine("NDCG: " + string.Join(", ",
                metrics.NormalizedDiscountedCumulativeGains.Select(
                    (d, i) => (i + 1) + ":" + d + ":F2").ToArray()));
        }
    }
}

Applies to

FastTree(RegressionCatalog+RegressionTrainers, FastTreeRegressionTrainer+Options)

Create FastTreeRegressionTrainer with advanced options, which predicts a target using a decision tree regression model.

public static Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer FastTree (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.Options options);
static member FastTree : Microsoft.ML.RegressionCatalog.RegressionTrainers * Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.Options -> Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer
<Extension()>
Public Function FastTree (catalog As RegressionCatalog.RegressionTrainers, options As FastTreeRegressionTrainer.Options) As FastTreeRegressionTrainer

Parameters

options
FastTreeRegressionTrainer.Options

Trainer options.

Returns

FastTreeRegressionTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;

namespace Samples.Dynamic.Trainers.Regression
{
    public static class FastTreeWithOptionsRegression
    {
        // This example requires installation of additional NuGet
        // package for Microsoft.ML.FastTree found at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeRegressionTrainer.Options
            {
                LabelColumnName = nameof(DataPoint.Label),
                FeatureColumnName = nameof(DataPoint.Features),
                // Use L2-norm for early stopping. If the gradient's L2-norm is
                // smaller than an auto-computed value, training process will stop.
                EarlyStoppingMetric =
                    Microsoft.ML.Trainers.FastTree.EarlyStoppingMetric.L2Norm,

                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50
            };

            // Define the trainer.
            var pipeline =
                mlContext.Regression.Trainers.FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(5, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions for the Label, side by side with the actual
            // Label for comparison.
            foreach (var p in predictions)
                Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}");

            // Expected output:
            //   Label: 0.985, Prediction: 0.950
            //   Label: 0.155, Prediction: 0.111
            //   Label: 0.515, Prediction: 0.475
            //   Label: 0.566, Prediction: 0.575
            //   Label: 0.096, Prediction: 0.093

            // Evaluate the overall metrics
            var metrics = mlContext.Regression.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Mean Absolute Error: 0.03
            //   Mean Squared Error: 0.00
            //   Root Mean Squared Error: 0.03
            //   RSquared: 0.99 (closer to 1 is better. The worst case is 0)
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                float label = (float)random.NextDouble();
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => x + (float)random.NextDouble()).ToArray()
                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public float Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public float Label { get; set; }
            // Predicted score from the trainer.
            public float Score { get; set; }
        }

        // Print some evaluation metrics to regression problems.
        private static void PrintMetrics(RegressionMetrics metrics)
        {
            Console.WriteLine("Mean Absolute Error: " + metrics.MeanAbsoluteError);
            Console.WriteLine("Mean Squared Error: " + metrics.MeanSquaredError);
            Console.WriteLine(
                "Root Mean Squared Error: " + metrics.RootMeanSquaredError);

            Console.WriteLine("RSquared: " + metrics.RSquared);
        }
    }
}

Applies to

FastTree(BinaryClassificationCatalog+BinaryClassificationTrainers, String, String, String, Int32, Int32, Int32, Double)

Create FastTreeBinaryTrainer, which predicts a target using a decision tree binary classification model.

public static Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer FastTree (this Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string exampleWeightColumnName = default, int numberOfLeaves = 20, int numberOfTrees = 100, int minimumExampleCountPerLeaf = 10, double learningRate = 0.2);
static member FastTree : Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers * string * string * string * int * int * int * double -> Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer
<Extension()>
Public Function FastTree (catalog As BinaryClassificationCatalog.BinaryClassificationTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional exampleWeightColumnName As String = Nothing, Optional numberOfLeaves As Integer = 20, Optional numberOfTrees As Integer = 100, Optional minimumExampleCountPerLeaf As Integer = 10, Optional learningRate As Double = 0.2) As FastTreeBinaryTrainer

Parameters

labelColumnName
String

The name of the label column. The column data must be Boolean.

featureColumnName
String

The name of the feature column. The column data must be a known-sized vector of Single.

exampleWeightColumnName
String

The name of the example weight column (optional).

numberOfLeaves
Int32

The maximum number of leaves per decision tree.

numberOfTrees
Int32

Total number of decision trees to create in the ensemble.

minimumExampleCountPerLeaf
Int32

The minimal number of data points required to form a new tree leaf.

learningRate
Double

The learning rate.
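
For reference, a minimal sketch of a call to this overload with every parameter named explicitly; the values are simply the defaults from the signature above, not tuning recommendations, and mlContext is an MLContext created as in the example below.

var trainer = mlContext.BinaryClassification.Trainers.FastTree(
    labelColumnName: "Label",
    featureColumnName: "Features",
    exampleWeightColumnName: null,
    numberOfLeaves: 20,
    numberOfTrees: 100,
    minimumExampleCountPerLeaf: 10,
    learningRate: 0.2);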

Returns

FastTreeBinaryTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.BinaryClassification
{
    public static class FastTree
    {
        // This example requires installation of additional NuGet package for 
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.BinaryClassification.Trainers
                .FastTree();

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data
                .LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData,
                reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, "
                    + $"Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False
            //   Label: True, Prediction: True
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False

            // Evaluate the overall metrics.
            var metrics = mlContext.BinaryClassification
                .Evaluate(transformedTestData);

            PrintMetrics(metrics);

            // Expected output:
            //   Accuracy: 0.81
            //   AUC: 0.91
            //   F1 Score: 0.80
            //   Negative Precision: 0.82
            //   Negative Recall: 0.80
            //   Positive Precision: 0.79
            //   Positive Recall: 0.81
            //   Log Loss: 0.59
            //   Log Loss Reduction: 41.04
            //   Entropy: 1.00
            //
            //   TEST POSITIVE RATIO:    0.4760 (238.0/(238.0+262.0))
            //   Confusion table
            //             ||======================
            //   PREDICTED || positive | negative | Recall
            //   TRUTH     ||======================
            //    positive ||      185 |       53 | 0.7773
            //    negative ||       83 |      179 | 0.6832
            //             ||======================
            //   Precision ||   0.6903 |   0.7716 |
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)

        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with false label, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50)
                        .Select(x => x ? randomFloat() : randomFloat() +
                        0.03f).ToArray()

                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public bool Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public bool Label { get; set; }
            // Predicted label from the trainer.
            public bool PredictedLabel { get; set; }
        }

        // Pretty-print BinaryClassificationMetrics objects.
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: " +
                $"{metrics.NegativePrecision:F2}");

            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: " +
                $"{metrics.PositivePrecision:F2}");

            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}\n");
            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}


Applies to

FastTree(RegressionCatalog+RegressionTrainers, String, String, String, Int32, Int32, Int32, Double)

Create FastTreeRegressionTrainer, which predicts a target using a decision tree regression model.

public static Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer FastTree (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string exampleWeightColumnName = default, int numberOfLeaves = 20, int numberOfTrees = 100, int minimumExampleCountPerLeaf = 10, double learningRate = 0.2);
static member FastTree : Microsoft.ML.RegressionCatalog.RegressionTrainers * string * string * string * int * int * int * double -> Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer
<Extension()>
Public Function FastTree (catalog As RegressionCatalog.RegressionTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional exampleWeightColumnName As String = Nothing, Optional numberOfLeaves As Integer = 20, Optional numberOfTrees As Integer = 100, Optional minimumExampleCountPerLeaf As Integer = 10, Optional learningRate As Double = 0.2) As FastTreeRegressionTrainer

Parameters

labelColumnName
String

The name of the label column. The column data must be Single.

featureColumnName
String

The name of the feature column. The column data must be a known-sized vector of Single.

exampleWeightColumnName
String

The name of the example weight column (optional).

numberOfLeaves
Int32

The maximum number of leaves per decision tree.

numberOfTrees
Int32

Total number of decision trees to create in the ensemble.

minimumExampleCountPerLeaf
Int32

The minimal number of data points required to form a new tree leaf.

learningRate
Double

The learning rate.
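
For reference, a minimal sketch of a call to this overload with every parameter named explicitly; the values are simply the defaults from the signature above, not tuning recommendations, and mlContext is an MLContext created as in the example below.

var trainer = mlContext.Regression.Trainers.FastTree(
    labelColumnName: "Label",
    featureColumnName: "Features",
    exampleWeightColumnName: null,
    numberOfLeaves: 20,
    numberOfTrees: 100,
    minimumExampleCountPerLeaf: 10,
    learningRate: 0.2);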

Returns

FastTreeRegressionTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.Regression
{
    public static class FastTreeRegression
    {
        // This example requires installation of additional NuGet
        // package for Microsoft.ML.FastTree found at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.Regression.Trainers.FastTree(
                labelColumnName: nameof(DataPoint.Label),
                featureColumnName: nameof(DataPoint.Features));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(5, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions for the Label, side by side with the actual
            // Label for comparison.
            foreach (var p in predictions)
                Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}");

            // Expected output:
            //   Label: 0.985, Prediction: 0.938
            //   Label: 0.155, Prediction: 0.131
            //   Label: 0.515, Prediction: 0.517
            //   Label: 0.566, Prediction: 0.519
            //   Label: 0.096, Prediction: 0.089

            // Evaluate the overall metrics
            var metrics = mlContext.Regression.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Mean Absolute Error: 0.03
            //   Mean Squared Error: 0.00
            //   Root Mean Squared Error: 0.03
            //   RSquared: 0.99 (closer to 1 is better. The worst case is 0)
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                float label = (float)random.NextDouble();
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => x + (float)random.NextDouble()).ToArray()
                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public float Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public float Label { get; set; }
            // Predicted score from the trainer.
            public float Score { get; set; }
        }

        // Print some evaluation metrics to regression problems.
        private static void PrintMetrics(RegressionMetrics metrics)
        {
            Console.WriteLine("Mean Absolute Error: " + metrics.MeanAbsoluteError);
            Console.WriteLine("Mean Squared Error: " + metrics.MeanSquaredError);
            Console.WriteLine(
                "Root Mean Squared Error: " + metrics.RootMeanSquaredError);

            Console.WriteLine("RSquared: " + metrics.RSquared);
        }
    }
}

Applies to

FastTree(RankingCatalog+RankingTrainers, String, String, String, String, Int32, Int32, Int32, Double)

Create FastTreeRankingTrainer, which ranks a series of inputs based on their relevance using a decision tree ranking model.

public static Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer FastTree (this Microsoft.ML.RankingCatalog.RankingTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string rowGroupColumnName = "GroupId", string exampleWeightColumnName = default, int numberOfLeaves = 20, int numberOfTrees = 100, int minimumExampleCountPerLeaf = 10, double learningRate = 0.2);
static member FastTree : Microsoft.ML.RankingCatalog.RankingTrainers * string * string * string * string * int * int * int * double -> Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer
<Extension()>
Public Function FastTree (catalog As RankingCatalog.RankingTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional rowGroupColumnName As String = "GroupId", Optional exampleWeightColumnName As String = Nothing, Optional numberOfLeaves As Integer = 20, Optional numberOfTrees As Integer = 100, Optional minimumExampleCountPerLeaf As Integer = 10, Optional learningRate As Double = 0.2) As FastTreeRankingTrainer

Parameters

labelColumnName
String

The name of the label column. The column data must be Single or KeyDataViewType.

featureColumnName
String

The name of the feature column. The column data must be a known-sized vector of Single.

rowGroupColumnName
String

The name of the group column.

exampleWeightColumnName
String

The name of the example weight column (optional).

numberOfLeaves
Int32

The maximum number of leaves per decision tree.

numberOfTrees
Int32

Total number of decision trees to create in the ensemble.

minimumExampleCountPerLeaf
Int32

The minimal number of data points required to form a new tree leaf.

learningRate
Double

The learning rate.
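
For reference, a minimal sketch of a call to this overload with every parameter named explicitly; the values are simply the defaults from the signature above, not tuning recommendations, and mlContext is an MLContext created as in the example below.

var trainer = mlContext.Ranking.Trainers.FastTree(
    labelColumnName: "Label",
    featureColumnName: "Features",
    rowGroupColumnName: "GroupId",
    exampleWeightColumnName: null,
    numberOfLeaves: 20,
    numberOfTrees: 100,
    minimumExampleCountPerLeaf: 10,
    learningRate: 0.2);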

Returns

FastTreeRankingTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.Ranking
{
    public static class FastTree
    {
        // This example requires installation of additional NuGet package for 
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.Ranking.Trainers.FastTree();

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Take the top 5 rows.
            var topTransformedTestData = mlContext.Data.TakeRows(
                transformedTestData, 5);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                topTransformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions)
                Console.WriteLine($"Label: {p.Label}, Score: {p.Score}");

            // Expected output:
            //   Label: 5, Score: 13.0154
            //   Label: 1, Score: -19.27798
            //   Label: 3, Score: -12.43686
            //   Label: 3, Score: -8.178633
            //   Label: 1, Score: -17.09313

            // Evaluate the overall metrics.
            var metrics = mlContext.Ranking.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   DCG: @1:41.95, @2:63.33, @3:75.65
            //   NDCG: @1:0.99, @2:0.98, @3:0.99
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0, int groupSize = 10)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(0, 5);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    GroupId = (uint)(i / groupSize),
                    // Create random features that are correlated with the label.
                    // For data points with larger labels, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(
                           x => randomFloat() + x * 0.1f).ToArray()
                };
            }
        }

        // Example with label, groupId, and 50 feature values. A data set is a
        // collection of such examples.
        private class DataPoint
        {
            [KeyType(5)]
            public uint Label { get; set; }
            [KeyType(100)]
            public uint GroupId { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Score produced from the trainer.
            public float Score { get; set; }
        }

        // Pretty-print RankerMetrics objects.
        public static void PrintMetrics(RankingMetrics metrics)
        {
            Console.WriteLine("DCG: " + string.Join(", ",
                metrics.DiscountedCumulativeGains.Select(
                (d, i) => (i + 1) + ":" + d + ":F2").ToArray()));

            Console.WriteLine("NDCG: " + string.Join(", ",
                metrics.NormalizedDiscountedCumulativeGains.Select(
                (d, i) => (i + 1) + ":" + d + ":F2").ToArray()));
        }
    }
}

Applies to