-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathRegressionModel.cs
More file actions
115 lines (103 loc) · 5.54 KB
/
RegressionModel.cs
File metadata and controls
115 lines (103 loc) · 5.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using Mvvm;
using Windows.Storage;
namespace XamlBrewer.Uwp.MachineLearningSample.Models
{
/// <summary>
/// Loads comma-separated training data, trains a regression model with a
/// selectable ML.NET trainer, and offers prediction and persistence helpers
/// for the trained model.
/// </summary>
/// <remarks>
/// Expected call order: <see cref="Load"/>, then <see cref="BuildAndTrain"/>,
/// then any of <see cref="Save"/>, <see cref="PredictTrainingData"/> or
/// <see cref="Predict"/>. Calling a member too early raises
/// <see cref="InvalidOperationException"/>.
/// </remarks>
internal class RegressionModel : ViewModelBase
{
    private readonly MLContext _mlContext = new MLContext(seed: null);
    private IDataView trainingData;
    private PredictionEngine<RegressionData, RegressionPrediction> predictionEngine;

    /// <summary>
    /// Gets the most recently trained transformer chain, or <c>null</c> when
    /// <see cref="BuildAndTrain"/> has not completed successfully yet.
    /// </summary>
    public ITransformer Model { get; private set; }

    /// <summary>
    /// Loads the training data from a comma-separated text file with a header
    /// row and keeps it as the data set used by later training and prediction.
    /// </summary>
    /// <param name="trainingDataPath">Path of the text file to read.</param>
    /// <returns>The loaded rows exposed as <see cref="RegressionData"/> objects.</returns>
    public IEnumerable<RegressionData> Load(string trainingDataPath)
    {
        // Only the source columns used by the model are mapped; the last
        // argument of each column is its index in the source file.
        var readerOptions = new TextLoader.Options
        {
            Separators = new[] { ',' },
            HasHeader = true,
            Columns = new[]
            {
                new TextLoader.Column("Label", DataKind.Single, 1),
                new TextLoader.Column("NBA_DraftNumber", DataKind.Single, 3),
                new TextLoader.Column("Age", DataKind.Single, 4),
                new TextLoader.Column("Ws", DataKind.Single, 22),
                new TextLoader.Column("Bmp", DataKind.Single, 26)
            }
        };

        trainingData = _mlContext.Data.LoadFromTextFile(trainingDataPath, readerOptions);
        return _mlContext.Data.CreateEnumerable<RegressionData>(trainingData, reuseRowObject: false);
    }

    /// <summary>
    /// Builds the feature pipeline, appends the requested regression trainer,
    /// fits it on the loaded training data and refreshes <see cref="Model"/>
    /// and the internal prediction engine.
    /// </summary>
    /// <param name="regressionTrainer">
    /// Trainer name: <c>"Gam"</c>, <c>"LightGbm"</c>, <c>"Ols"</c> or <c>"Sdca"</c>.
    /// Any other value (including <c>null</c>) falls back to the Gam trainer.
    /// </param>
    /// <exception cref="InvalidOperationException">
    /// No training data has been loaded through <see cref="Load"/>.
    /// </exception>
    public void BuildAndTrain(string regressionTrainer)
    {
        if (trainingData == null)
        {
            throw new InvalidOperationException("No training data is loaded. Call Load before BuildAndTrain.");
        }

        // Missing values are replaced by the column mean, then every feature is
        // normalized before being concatenated into the "Features" vector.
        var featurePipeline = _mlContext.Transforms.ReplaceMissingValues("Age", "Age", MissingValueReplacingEstimator.ReplacementMode.Mean)
            .Append(_mlContext.Transforms.ReplaceMissingValues("Ws", "Ws", MissingValueReplacingEstimator.ReplacementMode.Mean))
            .Append(_mlContext.Transforms.ReplaceMissingValues("Bmp", "Bmp", MissingValueReplacingEstimator.ReplacementMode.Mean))
            .Append(_mlContext.Transforms.ReplaceMissingValues("NBA_DraftNumber", "NBA_DraftNumber", MissingValueReplacingEstimator.ReplacementMode.Mean))
            .Append(_mlContext.Transforms.NormalizeBinning("NBA_DraftNumber", "NBA_DraftNumber"))
            .Append(_mlContext.Transforms.NormalizeMinMax("Age", "Age"))
            .Append(_mlContext.Transforms.NormalizeMeanVariance("Ws", "Ws"))
            .Append(_mlContext.Transforms.NormalizeMeanVariance("Bmp", "Bmp"))
            .Append(_mlContext.Transforms.Concatenate(
                "Features",
                new[] { "NBA_DraftNumber", "Age", "Ws", "Bmp" }));

        var trainers = _mlContext.Regression.Trainers;

        // FastTree and FastTreeTweedie are deliberately not offered: they were
        // previously disabled here with a PlatformNotSupportedException note.
        ITransformer trained;
        switch (regressionTrainer)
        {
            case "LightGbm":
                trained = featurePipeline.Append(trainers.LightGbm()).Fit(trainingData);
                break;
            case "Ols":
                trained = featurePipeline.Append(trainers.Ols()).Fit(trainingData);
                break;
            case "Sdca":
                trained = featurePipeline.Append(trainers.Sdca()).Fit(trainingData);
                break;
            default:
                // "Gam" and every unrecognized name use the Gam trainer.
                trained = featurePipeline.Append(trainers.Gam()).Fit(trainingData);
                break;
        }

        var engine = _mlContext.Model.CreatePredictionEngine<RegressionData, RegressionPrediction>(trained);

        // Swap state only after both steps succeeded, and release the engine
        // of the previous training run instead of leaking it.
        predictionEngine?.Dispose();
        Model = trained;
        predictionEngine = engine;
    }

    /// <summary>
    /// Saves the trained model, together with the schema of the training data,
    /// to a file in the application's local folder.
    /// </summary>
    /// <param name="modelName">File name of the model inside the local folder.</param>
    /// <exception cref="InvalidOperationException">No model has been trained yet.</exception>
    public void Save(string modelName)
    {
        if (Model == null)
        {
            throw new InvalidOperationException("No model is available. Call BuildAndTrain before Save.");
        }

        var storageFolder = ApplicationData.Current.LocalFolder;
        string modelPath = Path.Combine(storageFolder.Path, modelName);

        // Persist the input schema as well so the file describes the columns
        // the model was trained on.
        _mlContext.Model.Save(Model, inputSchema: trainingData?.Schema, filePath: modelPath);
    }

    /// <summary>
    /// Applies the trained model to the loaded training data.
    /// </summary>
    /// <returns>One <see cref="RegressionPrediction"/> per row of the training data.</returns>
    /// <exception cref="InvalidOperationException">No model has been trained yet.</exception>
    public IEnumerable<RegressionPrediction> PredictTrainingData()
    {
        if (Model == null)
        {
            throw new InvalidOperationException("No model is available. Call BuildAndTrain before PredictTrainingData.");
        }

        var scored = Model.Transform(trainingData);
        return _mlContext.Data.CreateEnumerable<RegressionPrediction>(scored, reuseRowObject: false);
    }

    /// <summary>
    /// Predicts a single value with the trained model.
    /// </summary>
    /// <param name="data">Input row to score.</param>
    /// <returns>The prediction produced by the internal prediction engine.</returns>
    /// <exception cref="InvalidOperationException">No model has been trained yet.</exception>
    public RegressionPrediction Predict(RegressionData data)
    {
        if (predictionEngine == null)
        {
            throw new InvalidOperationException("No model is available. Call BuildAndTrain before Predict.");
        }

        return predictionEngine.Predict(data);
    }
}
}