TreeExtensions.FastTree メソッド
定義
重要
一部の情報は、リリース前に大幅に変更される可能性があるプレリリース製品に関するものです。 Microsoft は、ここに記載されている情報について、明示または黙示を問わず、一切保証しません。
オーバーロード
FastTree(BinaryClassificationCatalog+BinaryClassificationTrainers, FastTreeBinaryTrainer+Options)
高度なオプションを使用して、デシジョン ツリーの二項分類モデルでターゲットを予測する FastTreeBinaryTrainer を作成します。
public static Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer FastTree (this Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers catalog, Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer.Options options);
static member FastTree : Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers * Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer.Options -> Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer
<Extension()>
Public Function FastTree (catalog As BinaryClassificationCatalog.BinaryClassificationTrainers, options As FastTreeBinaryTrainer.Options) As FastTreeBinaryTrainer
パラメーター
- options
- FastTreeBinaryTrainer.Options
トレーナー オプション。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
namespace Samples.Dynamic.Trainers.BinaryClassification
{
    /// <summary>
    /// Sample: trains a FastTree binary classifier configured through
    /// <see cref="FastTreeBinaryTrainer.Options"/>, prints a few predictions
    /// and then prints the evaluation metrics computed on a separate test set.
    /// </summary>
    public static class FastTreeWithOptions
    {
        // This example requires installation of additional NuGet package for
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeBinaryTrainer.Options
            {
                // Use L2Norm for early stopping.
                EarlyStoppingMetric = EarlyStoppingMetric.L2Norm,
                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50
            };

            // Define the trainer.
            var pipeline = mlContext.BinaryClassification.Trainers
                .FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data
                .LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData,
                reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions.Take(5))
            {
                Console.WriteLine($"Label: {p.Label}, "
                    + $"Prediction: {p.PredictedLabel}");
            }

            // Expected output:
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False
            //   Label: True, Prediction: True
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False

            // Evaluate the overall metrics. PrintMetrics also reports the
            // log-loss based values when the evaluator hands back
            // CalibratedBinaryClassificationMetrics.
            var metrics = mlContext.BinaryClassification
                .Evaluate(transformedTestData);

            PrintMetrics(metrics);

            // NOTE(review): the confusion table below does not agree with the
            // scalar metrics above it: 185 + 179 correct out of 500 is an
            // accuracy of 0.73, not 0.78. It looks copied from another sample;
            // regenerate it by running this example.
            // Expected output:
            //   Accuracy: 0.78
            //   AUC: 0.88
            //   F1 Score: 0.79
            //   Negative Precision: 0.83
            //   Negative Recall: 0.74
            //   Positive Precision: 0.74
            //   Positive Recall: 0.84
            //   Log Loss: 0.62
            //   Log Loss Reduction: 37.77
            //   Entropy: 1.00
            //
            //   TEST POSITIVE RATIO: 0.4760 (238.0/(238.0+262.0))
            //   Confusion table
            //             ||======================
            //   PREDICTED || positive | negative | Recall
            //   TRUTH     ||======================
            //    positive ||      185 |       53 | 0.7773
            //    negative ||       83 |      179 | 0.6832
            //             ||======================
            //   Precision ||   0.6903 |   0.7716 |
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with false label, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50)
                        .Select(x => x ? randomFloat() : randomFloat() +
                        0.03f).ToArray()
                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public bool Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public bool Label { get; set; }
            // Predicted label from the trainer.
            public bool PredictedLabel { get; set; }
        }

        // Pretty-print BinaryClassificationMetrics objects. When the metrics are
        // CalibratedBinaryClassificationMetrics, the probability-based metrics
        // (log loss, log-loss reduction, entropy) are printed as well, so the
        // output matches the "Expected output" listed in Example().
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: " +
                $"{metrics.NegativePrecision:F2}");

            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: " +
                $"{metrics.PositivePrecision:F2}");

            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");

            if (metrics is CalibratedBinaryClassificationMetrics calibrated)
            {
                Console.WriteLine($"Log Loss: {calibrated.LogLoss:F2}");
                Console.WriteLine(
                    $"Log Loss Reduction: {calibrated.LogLossReduction:F2}");
                Console.WriteLine($"Entropy: {calibrated.Entropy:F2}");
            }

            // Blank line separating the scalar metrics from the confusion table.
            Console.WriteLine();
            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}
適用対象
FastTree(RankingCatalog+RankingTrainers, FastTreeRankingTrainer+Options)
高度なオプションを使用して、デシジョン ツリーのランク付けモデルで関連性に基づいて一連の入力をランク付けする FastTreeRankingTrainer を作成します。
public static Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer FastTree (this Microsoft.ML.RankingCatalog.RankingTrainers catalog, Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer.Options options);
static member FastTree : Microsoft.ML.RankingCatalog.RankingTrainers * Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer.Options -> Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer
<Extension()>
Public Function FastTree (catalog As RankingCatalog.RankingTrainers, options As FastTreeRankingTrainer.Options) As FastTreeRankingTrainer
パラメーター
- options
- FastTreeRankingTrainer.Options
トレーナー オプション。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
namespace Samples.Dynamic.Trainers.Ranking
{
    /// <summary>
    /// Sample: trains a FastTree ranking model configured through
    /// <see cref="FastTreeRankingTrainer.Options"/>, prints the scores of the
    /// first test rows and then prints the DCG/NDCG evaluation metrics.
    /// </summary>
    public static class FastTreeWithOptions
    {
        // This example requires installation of additional NuGet package for
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeRankingTrainer.Options
            {
                // Use NdcgAt3 for early stopping.
                EarlyStoppingMetric = EarlyStoppingRankingMetric.NdcgAt3,
                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50,
                // Specify the row group column name.
                RowGroupColumnName = "GroupId"
            };

            // Define the trainer.
            var pipeline = mlContext.Ranking.Trainers.FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Take the top 5 rows.
            var topTransformedTestData = mlContext.Data.TakeRows(
                transformedTestData, 5);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                topTransformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions)
            {
                Console.WriteLine($"Label: {p.Label}, Score: {p.Score}");
            }

            // NOTE(review): GenerateRandomDataPoints draws labels with
            // random.Next(0, 5), i.e. values 0-4, so "Label: 5" below cannot be
            // printed by this code. These expected values look stale; regenerate
            // them by running the example.
            // Expected output:
            //   Label: 5, Score: 8.807633
            //   Label: 1, Score: -10.71331
            //   Label: 3, Score: -8.134147
            //   Label: 3, Score: -6.545538
            //   Label: 1, Score: -10.27982

            // Evaluate the overall metrics.
            var metrics = mlContext.Ranking.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   DCG: @1:40.57, @2:61.21, @3:74.11
            //   NDCG: @1:0.96, @2:0.95, @3:0.97
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0, int groupSize = 10)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(0, 5);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    GroupId = (uint)(i / groupSize),
                    // Create random features that are correlated with the label.
                    // For data points with larger labels, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => randomFloat() + x * 0.1f).ToArray()
                };
            }
        }

        // Example with label, groupId, and 50 feature values. A data set is a
        // collection of such examples.
        private class DataPoint
        {
            [KeyType(5)]
            public uint Label { get; set; }
            [KeyType(100)]
            public uint GroupId { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Score produced from the trainer.
            public float Score { get; set; }
        }

        // Pretty-print RankingMetrics objects. Each list entry is printed as
        // "@k:value", where k is its 1-based position in the list and the value
        // is rounded to two decimals, e.g. "DCG: @1:40.57, @2:61.21, @3:74.11".
        public static void PrintMetrics(RankingMetrics metrics)
        {
            Console.WriteLine(
                "DCG: " + FormatAtK(metrics.DiscountedCumulativeGains));
            Console.WriteLine(
                "NDCG: " + FormatAtK(metrics.NormalizedDiscountedCumulativeGains));
        }

        // Joins values as "@1:v1, @2:v2, ...". The F2 specifier has to sit
        // inside the interpolation hole; concatenating ":F2" as plain text would
        // print it literally next to the unformatted number.
        private static string FormatAtK(IEnumerable<double> values) =>
            string.Join(", ", values.Select((value, i) => $"@{i + 1}:{value:F2}"));
    }
}
適用対象
FastTree(RegressionCatalog+RegressionTrainers, FastTreeRegressionTrainer+Options)
高度なオプションを使用して、デシジョン ツリー回帰モデルでターゲットを予測する FastTreeRegressionTrainer を作成します。
public static Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer FastTree (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.Options options);
static member FastTree : Microsoft.ML.RegressionCatalog.RegressionTrainers * Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer.Options -> Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer
<Extension()>
Public Function FastTree (catalog As RegressionCatalog.RegressionTrainers, options As FastTreeRegressionTrainer.Options) As FastTreeRegressionTrainer
パラメーター
- options
- FastTreeRegressionTrainer.Options
トレーナー オプション。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers.FastTree;
namespace Samples.Dynamic.Trainers.Regression
{
    /// <summary>
    /// Sample: trains a FastTree regression model configured through
    /// <see cref="FastTreeRegressionTrainer.Options"/>, prints predictions for a
    /// few test points and then prints the regression evaluation metrics.
    /// </summary>
    public static class FastTreeWithOptionsRegression
    {
        // This example requires installation of additional NuGet
        // package for Microsoft.ML.FastTree found at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define trainer options.
            var options = new FastTreeRegressionTrainer.Options
            {
                LabelColumnName = nameof(DataPoint.Label),
                FeatureColumnName = nameof(DataPoint.Features),
                // Use L2-norm for early stopping. If the gradient's L2-norm is
                // smaller than an auto-computed value, training process will stop.
                EarlyStoppingMetric =
                    Microsoft.ML.Trainers.FastTree.EarlyStoppingMetric.L2Norm,

                // Create a simpler model by penalizing usage of new features.
                FeatureFirstUsePenalty = 0.1,
                // Reduce the number of trees to 50.
                NumberOfTrees = 50
            };

            // Define the trainer.
            var pipeline =
                mlContext.Regression.Trainers.FastTree(options);

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(5, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions for the Label, side by side with the actual
            // Label for comparison.
            foreach (var p in predictions)
            {
                Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}");
            }

            // Expected output:
            //   Label: 0.985, Prediction: 0.950
            //   Label: 0.155, Prediction: 0.111
            //   Label: 0.515, Prediction: 0.475
            //   Label: 0.566, Prediction: 0.575
            //   Label: 0.096, Prediction: 0.093

            // Evaluate the overall metrics
            var metrics = mlContext.Regression.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Mean Absolute Error: 0.03
            //   Mean Squared Error: 0.00
            //   Root Mean Squared Error: 0.03
            //   RSquared: 0.99
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                float label = (float)random.NextDouble();
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => x + (float)random.NextDouble()).ToArray()
                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public float Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public float Label { get; set; }
            // Predicted score from the trainer.
            public float Score { get; set; }
        }

        // Print some evaluation metrics to regression problems, rounded to two
        // decimals so the output matches the "Expected output" in Example().
        // For RSquared, values closer to 1 indicate a better fit.
        private static void PrintMetrics(RegressionMetrics metrics)
        {
            Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError:F2}");
            Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError:F2}");
            Console.WriteLine(
                $"Root Mean Squared Error: {metrics.RootMeanSquaredError:F2}");
            Console.WriteLine($"RSquared: {metrics.RSquared:F2}");
        }
    }
}
適用対象
FastTree(BinaryClassificationCatalog+BinaryClassificationTrainers, String, String, String, Int32, Int32, Int32, Double)
デシジョン ツリーの二項分類モデルを使用してターゲットを予測する FastTreeBinaryTrainer を作成します。
public static Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer FastTree (this Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string exampleWeightColumnName = default, int numberOfLeaves = 20, int numberOfTrees = 100, int minimumExampleCountPerLeaf = 10, double learningRate = 0.2);
static member FastTree : Microsoft.ML.BinaryClassificationCatalog.BinaryClassificationTrainers * string * string * string * int * int * int * double -> Microsoft.ML.Trainers.FastTree.FastTreeBinaryTrainer
<Extension()>
Public Function FastTree (catalog As BinaryClassificationCatalog.BinaryClassificationTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional exampleWeightColumnName As String = Nothing, Optional numberOfLeaves As Integer = 20, Optional numberOfTrees As Integer = 100, Optional minimumExampleCountPerLeaf As Integer = 10, Optional learningRate As Double = 0.2) As FastTreeBinaryTrainer
パラメーター
- exampleWeightColumnName
- String
例の重み列の名前 (省略可能)。
- numberOfLeaves
- Int32
デシジョン ツリーあたりのリーフの最大数。
- numberOfTrees
- Int32
アンサンブルで作成するデシジョン ツリーの合計数。
- minimumExampleCountPerLeaf
- Int32
新しいツリー リーフを形成するために必要なデータ ポイントの最小数。
- learningRate
- Double
学習率。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.BinaryClassification
{
    /// <summary>
    /// Sample: trains a FastTree binary classifier with the default simple
    /// overload, prints a few predictions and then prints the evaluation
    /// metrics computed on a separate test set.
    /// </summary>
    public static class FastTree
    {
        // This example requires installation of additional NuGet package for
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.BinaryClassification.Trainers
                .FastTree();

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data
                .LoadFromEnumerable(GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data
                .CreateEnumerable<Prediction>(transformedTestData,
                reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions.Take(5))
            {
                Console.WriteLine($"Label: {p.Label}, "
                    + $"Prediction: {p.PredictedLabel}");
            }

            // Expected output:
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False
            //   Label: True, Prediction: True
            //   Label: True, Prediction: True
            //   Label: False, Prediction: False

            // Evaluate the overall metrics. PrintMetrics also reports the
            // log-loss based values when the evaluator hands back
            // CalibratedBinaryClassificationMetrics.
            var metrics = mlContext.BinaryClassification
                .Evaluate(transformedTestData);

            PrintMetrics(metrics);

            // NOTE(review): the confusion table below does not agree with the
            // scalar metrics above it: 185 + 179 correct out of 500 is an
            // accuracy of 0.73, not 0.81. It looks copied from another sample;
            // regenerate it by running this example.
            // Expected output:
            //   Accuracy: 0.81
            //   AUC: 0.91
            //   F1 Score: 0.80
            //   Negative Precision: 0.82
            //   Negative Recall: 0.80
            //   Positive Precision: 0.79
            //   Positive Recall: 0.81
            //   Log Loss: 0.59
            //   Log Loss Reduction: 41.04
            //   Entropy: 1.00
            //
            //   TEST POSITIVE RATIO: 0.4760 (238.0/(238.0+262.0))
            //   Confusion table
            //             ||======================
            //   PREDICTED || positive | negative | Recall
            //   TRUTH     ||======================
            //    positive ||      185 |       53 | 0.7773
            //    negative ||       83 |      179 | 0.6832
            //             ||======================
            //   Precision ||   0.6903 |   0.7716 |
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = randomFloat() > 0.5f;
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    // For data points with false label, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50)
                        .Select(x => x ? randomFloat() : randomFloat() +
                        0.03f).ToArray()
                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public bool Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public bool Label { get; set; }
            // Predicted label from the trainer.
            public bool PredictedLabel { get; set; }
        }

        // Pretty-print BinaryClassificationMetrics objects. When the metrics are
        // CalibratedBinaryClassificationMetrics, the probability-based metrics
        // (log loss, log-loss reduction, entropy) are printed as well, so the
        // output matches the "Expected output" listed in Example().
        private static void PrintMetrics(BinaryClassificationMetrics metrics)
        {
            Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
            Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
            Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
            Console.WriteLine($"Negative Precision: " +
                $"{metrics.NegativePrecision:F2}");

            Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
            Console.WriteLine($"Positive Precision: " +
                $"{metrics.PositivePrecision:F2}");

            Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}");

            if (metrics is CalibratedBinaryClassificationMetrics calibrated)
            {
                Console.WriteLine($"Log Loss: {calibrated.LogLoss:F2}");
                Console.WriteLine(
                    $"Log Loss Reduction: {calibrated.LogLossReduction:F2}");
                Console.WriteLine($"Entropy: {calibrated.Entropy:F2}");
            }

            // Blank line separating the scalar metrics from the confusion table.
            Console.WriteLine();
            Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
        }
    }
}
適用対象
FastTree(RegressionCatalog+RegressionTrainers, String, String, String, Int32, Int32, Int32, Double)
デシジョン ツリー回帰モデルを使用してターゲットを予測する FastTreeRegressionTrainer を作成します。
public static Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer FastTree (this Microsoft.ML.RegressionCatalog.RegressionTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string exampleWeightColumnName = default, int numberOfLeaves = 20, int numberOfTrees = 100, int minimumExampleCountPerLeaf = 10, double learningRate = 0.2);
static member FastTree : Microsoft.ML.RegressionCatalog.RegressionTrainers * string * string * string * int * int * int * double -> Microsoft.ML.Trainers.FastTree.FastTreeRegressionTrainer
<Extension()>
Public Function FastTree (catalog As RegressionCatalog.RegressionTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional exampleWeightColumnName As String = Nothing, Optional numberOfLeaves As Integer = 20, Optional numberOfTrees As Integer = 100, Optional minimumExampleCountPerLeaf As Integer = 10, Optional learningRate As Double = 0.2) As FastTreeRegressionTrainer
パラメーター
- exampleWeightColumnName
- String
例の重み列の名前 (省略可能)。
- numberOfLeaves
- Int32
デシジョン ツリーあたりのリーフの最大数。
- numberOfTrees
- Int32
アンサンブルで作成するデシジョン ツリーの合計数。
- minimumExampleCountPerLeaf
- Int32
新しいツリー リーフを形成するために必要なデータ ポイントの最小数。
- learningRate
- Double
学習率。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.Regression
{
    /// <summary>
    /// Sample: trains a FastTree regression model with the simple overload
    /// (explicit label/feature column names), prints predictions for a few test
    /// points and then prints the regression evaluation metrics.
    /// </summary>
    public static class FastTreeRegression
    {
        // This example requires installation of additional NuGet
        // package for Microsoft.ML.FastTree found at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.Regression.Trainers.FastTree(
                labelColumnName: nameof(DataPoint.Label),
                featureColumnName: nameof(DataPoint.Features));

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(5, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                transformedTestData, reuseRowObject: false).ToList();

            // Look at 5 predictions for the Label, side by side with the actual
            // Label for comparison.
            foreach (var p in predictions)
            {
                Console.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}");
            }

            // Expected output:
            //   Label: 0.985, Prediction: 0.938
            //   Label: 0.155, Prediction: 0.131
            //   Label: 0.515, Prediction: 0.517
            //   Label: 0.566, Prediction: 0.519
            //   Label: 0.096, Prediction: 0.089

            // Evaluate the overall metrics
            var metrics = mlContext.Regression.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   Mean Absolute Error: 0.03
            //   Mean Squared Error: 0.00
            //   Root Mean Squared Error: 0.03
            //   RSquared: 0.99
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0)
        {
            var random = new Random(seed);
            for (int i = 0; i < count; i++)
            {
                float label = (float)random.NextDouble();
                yield return new DataPoint
                {
                    Label = label,
                    // Create random features that are correlated with the label.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => x + (float)random.NextDouble()).ToArray()
                };
            }
        }

        // Example with label and 50 feature values. A data set is a collection of
        // such examples.
        private class DataPoint
        {
            public float Label { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public float Label { get; set; }
            // Predicted score from the trainer.
            public float Score { get; set; }
        }

        // Print some evaluation metrics to regression problems, rounded to two
        // decimals so the output matches the "Expected output" in Example().
        // For RSquared, values closer to 1 indicate a better fit.
        private static void PrintMetrics(RegressionMetrics metrics)
        {
            Console.WriteLine($"Mean Absolute Error: {metrics.MeanAbsoluteError:F2}");
            Console.WriteLine($"Mean Squared Error: {metrics.MeanSquaredError:F2}");
            Console.WriteLine(
                $"Root Mean Squared Error: {metrics.RootMeanSquaredError:F2}");
            Console.WriteLine($"RSquared: {metrics.RSquared:F2}");
        }
    }
}
適用対象
FastTree(RankingCatalog+RankingTrainers, String, String, String, String, Int32, Int32, Int32, Double)
デシジョン ツリーのランク付けモデルを使用して、関連性に基づいて一連の入力をランク付けする FastTreeRankingTrainer を作成します。
public static Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer FastTree (this Microsoft.ML.RankingCatalog.RankingTrainers catalog, string labelColumnName = "Label", string featureColumnName = "Features", string rowGroupColumnName = "GroupId", string exampleWeightColumnName = default, int numberOfLeaves = 20, int numberOfTrees = 100, int minimumExampleCountPerLeaf = 10, double learningRate = 0.2);
static member FastTree : Microsoft.ML.RankingCatalog.RankingTrainers * string * string * string * string * int * int * int * double -> Microsoft.ML.Trainers.FastTree.FastTreeRankingTrainer
<Extension()>
Public Function FastTree (catalog As RankingCatalog.RankingTrainers, Optional labelColumnName As String = "Label", Optional featureColumnName As String = "Features", Optional rowGroupColumnName As String = "GroupId", Optional exampleWeightColumnName As String = Nothing, Optional numberOfLeaves As Integer = 20, Optional numberOfTrees As Integer = 100, Optional minimumExampleCountPerLeaf As Integer = 10, Optional learningRate As Double = 0.2) As FastTreeRankingTrainer
パラメーター
- labelColumnName
- String
ラベル列の名前。 列のデータは Single または KeyDataViewType である必要があります。
- rowGroupColumnName
- String
グループ列の名前。
- exampleWeightColumnName
- String
例の重み列の名前 (省略可能)。
- numberOfLeaves
- Int32
デシジョン ツリーあたりのリーフの最大数。
- numberOfTrees
- Int32
アンサンブルで作成するデシジョン ツリーの合計数。
- minimumExampleCountPerLeaf
- Int32
新しいツリー リーフを形成するために必要なデータ ポイントの最小数。
- learningRate
- Double
学習率。
戻り値
例
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.Ranking
{
    /// <summary>
    /// Sample: trains a FastTree ranking model with the default simple overload,
    /// prints the scores of the first test rows and then prints the DCG/NDCG
    /// evaluation metrics.
    /// </summary>
    public static class FastTree
    {
        // This example requires installation of additional NuGet package for
        // Microsoft.ML.FastTree at
        // https://www.nuget.org/packages/Microsoft.ML.FastTree/
        public static void Example()
        {
            // Create a new context for ML.NET operations. It can be used for
            // exception tracking and logging, as a catalog of available operations
            // and as the source of randomness. Setting the seed to a fixed number
            // in this example to make outputs deterministic.
            var mlContext = new MLContext(seed: 0);

            // Create a list of training data points.
            var dataPoints = GenerateRandomDataPoints(1000);

            // Convert the list of data points to an IDataView object, which is
            // consumable by ML.NET API.
            var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Define the trainer.
            var pipeline = mlContext.Ranking.Trainers.FastTree();

            // Train the model.
            var model = pipeline.Fit(trainingData);

            // Create testing data. Use different random seed to make it different
            // from training data.
            var testData = mlContext.Data.LoadFromEnumerable(
                GenerateRandomDataPoints(500, seed: 123));

            // Run the model on test data set.
            var transformedTestData = model.Transform(testData);

            // Take the top 5 rows.
            var topTransformedTestData = mlContext.Data.TakeRows(
                transformedTestData, 5);

            // Convert IDataView object to a list.
            var predictions = mlContext.Data.CreateEnumerable<Prediction>(
                topTransformedTestData, reuseRowObject: false).ToList();

            // Print 5 predictions.
            foreach (var p in predictions)
            {
                Console.WriteLine($"Label: {p.Label}, Score: {p.Score}");
            }

            // NOTE(review): GenerateRandomDataPoints draws labels with
            // random.Next(0, 5), i.e. values 0-4, so "Label: 5" below cannot be
            // printed by this code. These expected values look stale; regenerate
            // them by running the example.
            // Expected output:
            //   Label: 5, Score: 13.0154
            //   Label: 1, Score: -19.27798
            //   Label: 3, Score: -12.43686
            //   Label: 3, Score: -8.178633
            //   Label: 1, Score: -17.09313

            // Evaluate the overall metrics.
            var metrics = mlContext.Ranking.Evaluate(transformedTestData);
            PrintMetrics(metrics);

            // Expected output:
            //   DCG: @1:41.95, @2:63.33, @3:75.65
            //   NDCG: @1:0.99, @2:0.98, @3:0.99
        }

        private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
            int seed = 0, int groupSize = 10)
        {
            var random = new Random(seed);
            float randomFloat() => (float)random.NextDouble();
            for (int i = 0; i < count; i++)
            {
                var label = random.Next(0, 5);
                yield return new DataPoint
                {
                    Label = (uint)label,
                    GroupId = (uint)(i / groupSize),
                    // Create random features that are correlated with the label.
                    // For data points with larger labels, the feature values are
                    // slightly increased by adding a constant.
                    Features = Enumerable.Repeat(label, 50).Select(
                        x => randomFloat() + x * 0.1f).ToArray()
                };
            }
        }

        // Example with label, groupId, and 50 feature values. A data set is a
        // collection of such examples.
        private class DataPoint
        {
            [KeyType(5)]
            public uint Label { get; set; }
            [KeyType(100)]
            public uint GroupId { get; set; }
            [VectorType(50)]
            public float[] Features { get; set; }
        }

        // Class used to capture predictions.
        private class Prediction
        {
            // Original label.
            public uint Label { get; set; }
            // Score produced from the trainer.
            public float Score { get; set; }
        }

        // Pretty-print RankingMetrics objects. Each list entry is printed as
        // "@k:value", where k is its 1-based position in the list and the value
        // is rounded to two decimals, e.g. "DCG: @1:41.95, @2:63.33, @3:75.65".
        public static void PrintMetrics(RankingMetrics metrics)
        {
            Console.WriteLine(
                "DCG: " + FormatAtK(metrics.DiscountedCumulativeGains));
            Console.WriteLine(
                "NDCG: " + FormatAtK(metrics.NormalizedDiscountedCumulativeGains));
        }

        // Joins values as "@1:v1, @2:v2, ...". The F2 specifier has to sit
        // inside the interpolation hole; concatenating ":F2" as plain text would
        // print it literally next to the unformatted number.
        private static string FormatAtK(IEnumerable<double> values) =>
            string.Join(", ", values.Select((value, i) => $"@{i + 1}:{value:F2}"));
    }
}