StandardTrainersCatalog.OneVersusAll<TModel> Method

Definition

Creates a OneVersusAllTrainer, which predicts a multiclass target using a one-versus-all strategy with the binary classification estimator specified by binaryEstimator.

public static Microsoft.ML.Trainers.OneVersusAllTrainer OneVersusAll<TModel> (this Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog, Microsoft.ML.Trainers.ITrainerEstimator<Microsoft.ML.Data.BinaryPredictionTransformer<TModel>,TModel> binaryEstimator, string labelColumnName = "Label", bool imputeMissingLabelsAsNegative = false, Microsoft.ML.IEstimator<Microsoft.ML.ISingleFeaturePredictionTransformer<Microsoft.ML.Calibrators.ICalibrator>> calibrator = default, int maximumCalibrationExampleCount = 1000000000, bool useProbabilities = true) where TModel : class;
static member OneVersusAll : Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers * Microsoft.ML.Trainers.ITrainerEstimator<Microsoft.ML.Data.BinaryPredictionTransformer<'Model>, 'Model (requires 'Model : null)> * string * bool * Microsoft.ML.IEstimator<Microsoft.ML.ISingleFeaturePredictionTransformer<Microsoft.ML.Calibrators.ICalibrator>> * int * bool -> Microsoft.ML.Trainers.OneVersusAllTrainer (requires 'Model : null)
<Extension()>
Public Function OneVersusAll(Of TModel As Class) (catalog As MulticlassClassificationCatalog.MulticlassClassificationTrainers, binaryEstimator As ITrainerEstimator(Of BinaryPredictionTransformer(Of TModel), TModel), Optional labelColumnName As String = "Label", Optional imputeMissingLabelsAsNegative As Boolean = false, Optional calibrator As IEstimator(Of ISingleFeaturePredictionTransformer(Of ICalibrator)) = Nothing, Optional maximumCalibrationExampleCount As Integer = 1000000000, Optional useProbabilities As Boolean = true) As OneVersusAllTrainer

Type Parameters

TModel

The type of the model. This type parameter is usually inferred automatically from binaryEstimator.

Parameters

catalog
MulticlassClassificationCatalog.MulticlassClassificationTrainers

The multiclass classification catalog trainer object.

binaryEstimator
ITrainerEstimator<BinaryPredictionTransformer<TModel>,TModel>

An instance of a binary ITrainerEstimator<TTransformer,TModel> used as the base trainer.

labelColumnName
String

The name of the label column.

imputeMissingLabelsAsNegative
Boolean

Whether to treat missing labels as negative labels instead of keeping them missing.

calibrator
IEstimator<ISingleFeaturePredictionTransformer<ICalibrator>>

The calibrator. If a calibrator is not explicitly provided, it defaults to Microsoft.ML.Calibrators.PlattCalibratorTrainer.

maximumCalibrationExampleCount
Int32

Number of instances used to train the calibrator.

useProbabilities
Boolean

Use probabilities (rather than raw outputs) to identify the top-scoring category.

Returns

OneVersusAllTrainer

Examples

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Samples.Dynamic.Trainers.MulticlassClassification
{
    public static class OneVersusAll
    {
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline =
                // Convert the numeric labels into key types.
                mlContext.Transforms.Conversion.MapValueToKey("Label")
                // Apply OneVersusAll multiclass meta trainer on top of
                // binary trainer.
                .Append(mlContext.MulticlassClassification.Trainers
                .OneVersusAll(
                mlContext.BinaryClassification.Trainers.SdcaLogisticRegression()));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data
                .LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData,
                reuseRowObject: false).ToList();

            // Look at 5 predictions
            foreach (var p in predictions.Take(5))
                Console.WriteLine($"Label: {p.Label}, " +
                    $"Prediction: {p.PredictedLabel}");

            // Expected output:
            //   Label: 1, Prediction: 1
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 2
            //   Label: 2, Prediction: 2
            //   Label: 3, Prediction: 2

            // Evaluate the overall metrics
            var metrics = mlContext.MulticlassClassification
                .Evaluate(transformedTestData);

            PrintMetrics(metrics);

            // Expected output:
            //   Micro Accuracy: 0.90
            //   Macro Accuracy: 0.90
            //   Log Loss: 0.36
            //   Log Loss Reduction: 0.68

            //   Confusion table
            //             ||========================
            //   PREDICTED ||     0 |     1 |     2 | Recall
            //   TRUTH     ||========================
            //           0 ||   152 |     0 |     8 | 0.9500
            //           1 ||     0 |   168 |     9 | 0.9492
            //           2 ||    17 |    15 |   131 | 0.8037
            //             ||========================
            //   Precision ||0.8994 |0.9180 |0.8851 |
        }

        // Generates data points with labels 1, 2 or 3. Each feature is a random
        // uniform float in [-0.5, 0.5) shifted by a multiple of the label.
        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)(random.NextDouble() - 0.5);
            for (int i = 0; i < count; i++)
            {
                // Generate Labels that are integers 1, 2 or 3
                var label = random.Next(1, 4);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    // Create random features that are correlated with the label.
                    // The feature values are slightly increased by adding a
                    // constant multiple of label.
                    Features = Enumerable.Repeat(label, 20)
                        .Select(x => randomFloat() + label * 0.2f).ToArray()
                };
            }
        }

        // Example with label and 20 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public uint Label { get; set; }
            [VectorType(20)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Predicted label from the trainer.
            public uint PredictedLabel { get; set; }
        }

        // Pretty-print MulticlassClassificationMetrics objects.
        public static void PrintMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"Micro Accuracy: {metrics.MicroAccuracy:F2}");
            Console.WriteLine($"Macro Accuracy: {metrics.MacroAccuracy:F2}");
            Console.WriteLine($"Log Loss: {metrics.LogLoss:F2}");
            Console.WriteLine(
                $"Log Loss Reduction: {metrics.LogLossReduction:F2}\n");

            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}

Remarks

In a one-versus-all (OVA) strategy, a binary classification algorithm is used to train one classifier per class, each of which distinguishes its class from all other classes. Prediction is then performed by running all of these binary classifiers and choosing the class whose classifier reports the highest confidence score.
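
The prediction step described above can be illustrated with a minimal sketch in plain C#. It does not use ML.NET and is not the library's implementation: the class OneVersusAllSketch, the method Predict, and the three hand-written scoring functions are hypothetical names invented for this illustration. It assumes each binary classifier is represented as a function returning a confidence score, and that ties are resolved in favor of the lowest class index.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical illustration only; not part of ML.NET.
public static class OneVersusAllSketch
{
    // Runs every per-class binary scorer on the same features and returns
    // the index of the class whose scorer reports the highest confidence.
    // Assumption: ties go to the lowest class index.
    public static int Predict(
        IReadOnlyList<Func<float[], float>> scorers, float[] features)
    {
        if (scorers.Count == 0)
            throw new ArgumentException("At least one scorer is required.");

        int bestClass = 0;
        float bestScore = scorers[0](features);
        for (int k = 1; k < scorers.Count; k++)
        {
            float score = scorers[k](features);
            if (score > bestScore)
            {
                bestScore = score;
                bestClass = k;
            }
        }
        return bestClass;
    }

    public static void Main()
    {
        // Three made-up "class k versus the rest" scorers over one feature.
        var scorers = new List<Func<float[], float>>
        {
            x => -x[0],                     // class 0: confident for small x[0]
            x => 1f - Math.Abs(x[0] - 1f),  // class 1: confident near x[0] = 1
            x => x[0] - 2f,                 // class 2: confident for large x[0]
        };

        Console.WriteLine(Predict(scorers, new[] { 1.0f })); // prints 1
    }
}
```

In the actual trainer, the per-class scorers are the binary models produced by binaryEstimator, and useProbabilities controls whether calibrated probabilities or raw outputs are compared.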

Applies to