From 03883af2d50e94b317a426d029d6c92f63e222e0 Mon Sep 17 00:00:00 2001 From: Bill Lattner Date: Mon, 17 Nov 2014 10:48:30 -0600 Subject: [PATCH] use float32 for example data --- forest/forest.go | 6 +- forest/iris_test.go | 302 ++++++++++++++++++++++---------------------- rf.go | 10 +- tree/iris_test.go | 302 ++++++++++++++++++++++---------------------- tree/sort.go | 18 +-- tree/split_test.go | 8 +- tree/tree.go | 25 ++-- 7 files changed, 336 insertions(+), 335 deletions(-) diff --git a/forest/forest.go b/forest/forest.go index 3f056e7..b8f83ed 100644 --- a/forest/forest.go +++ b/forest/forest.go @@ -130,7 +130,7 @@ func NewClassifier(options ...func(forestConfiger)) *ForestClassifier { // Fit constructs a forest from fitting n trees from the provided features X, and // labels Y. -func (f *ForestClassifier) Fit(X [][]float64, Y []string) { +func (f *ForestClassifier) Fit(X [][]float32, Y []string) { f.Trees = make([]*tree.Classifier, f.NTrees) if f.MaxFeatures < 0 { @@ -174,7 +174,7 @@ func (f *ForestClassifier) Fit(X [][]float64, Y []string) { } // Predict returns the most probable label for each example. -func (f *ForestClassifier) Predict(X [][]float64) []string { +func (f *ForestClassifier) Predict(X [][]float32) []string { p := f.PredictProb(X) maxC := make([]string, len(X)) @@ -199,7 +199,7 @@ func (f *ForestClassifier) Predict(X [][]float64) []string { // PredictProb returns the class probability for each example. The indices of the // return value correspond to Classifier.Classes. -func (f *ForestClassifier) PredictProb(X [][]float64) [][]float64 { +func (f *ForestClassifier) PredictProb(X [][]float32) [][]float64 { //TODO: weighted voting... probs := make([][]float64, len(X)) // initialize the other dim diff --git a/forest/iris_test.go b/forest/iris_test.go index 824cda1..207d2b3 100644 --- a/forest/iris_test.go +++ b/forest/iris_test.go @@ -69,157 +69,157 @@ func BenchmarkIrisPredict(b *testing.B) { } } -var X = [][]float64{ - []float64{3.5, 1.4, 5.1, 0.2}, - []float64{3.0, 1.4, 4.9, 0.2}, - []float64{3.2, 1.3, 4.7, 0.2}, - []float64{3.1, 1.5, 4.6, 0.2}, - []float64{3.6, 1.4, 5.0, 0.2}, - []float64{3.9, 1.7, 5.4, 0.4}, - []float64{3.4, 1.4, 4.6, 0.3}, - []float64{3.4, 1.5, 5.0, 0.2}, - []float64{2.9, 1.4, 4.4, 0.2}, - []float64{3.1, 1.5, 4.9, 0.1}, - []float64{3.7, 1.5, 5.4, 0.2}, - []float64{3.4, 1.6, 4.8, 0.2}, - []float64{3.0, 1.4, 4.8, 0.1}, - []float64{3.0, 1.1, 4.3, 0.1}, - []float64{4.0, 1.2, 5.8, 0.2}, - []float64{4.4, 1.5, 5.7, 0.4}, - []float64{3.9, 1.3, 5.4, 0.4}, - []float64{3.5, 1.4, 5.1, 0.3}, - []float64{3.8, 1.7, 5.7, 0.3}, - []float64{3.8, 1.5, 5.1, 0.3}, - []float64{3.4, 1.7, 5.4, 0.2}, - []float64{3.7, 1.5, 5.1, 0.4}, - []float64{3.6, 1.0, 4.6, 0.2}, - []float64{3.3, 1.7, 5.1, 0.5}, - []float64{3.4, 1.9, 4.8, 0.2}, - []float64{3.0, 1.6, 5.0, 0.2}, - []float64{3.4, 1.6, 5.0, 0.4}, - []float64{3.5, 1.5, 5.2, 0.2}, - []float64{3.4, 1.4, 5.2, 0.2}, - []float64{3.2, 1.6, 4.7, 0.2}, - []float64{3.1, 1.6, 4.8, 0.2}, - []float64{3.4, 1.5, 5.4, 0.4}, - []float64{4.1, 1.5, 5.2, 0.1}, - []float64{4.2, 1.4, 5.5, 0.2}, - []float64{3.1, 1.5, 4.9, 0.2}, - []float64{3.2, 1.2, 5.0, 0.2}, - []float64{3.5, 1.3, 5.5, 0.2}, - []float64{3.6, 1.4, 4.9, 0.1}, - []float64{3.0, 1.3, 4.4, 0.2}, - []float64{3.4, 1.5, 5.1, 0.2}, - []float64{3.5, 1.3, 5.0, 0.3}, - []float64{2.3, 1.3, 4.5, 0.3}, - []float64{3.2, 1.3, 4.4, 0.2}, - []float64{3.5, 1.6, 5.0, 0.6}, - []float64{3.8, 1.9, 5.1, 0.4}, - []float64{3.0, 1.4, 4.8, 0.3}, - []float64{3.8, 1.6, 5.1, 0.2}, - []float64{3.2, 1.4, 4.6, 0.2}, - []float64{3.7, 1.5, 5.3, 0.2}, - []float64{3.3, 1.4, 5.0, 0.2}, - []float64{3.2, 4.7, 7.0, 1.4}, - []float64{3.2, 4.5, 6.4, 1.5}, - []float64{3.1, 4.9, 6.9, 1.5}, - []float64{2.3, 4.0, 5.5, 1.3}, - []float64{2.8, 4.6, 6.5, 1.5}, - []float64{2.8, 4.5, 5.7, 1.3}, - []float64{3.3, 4.7, 6.3, 1.6}, - []float64{2.4, 3.3, 4.9, 1.0}, - []float64{2.9, 4.6, 6.6, 1.3}, - []float64{2.7, 3.9, 5.2, 1.4}, - []float64{2.0, 3.5, 5.0, 1.0}, - []float64{3.0, 4.2, 5.9, 1.5}, - []float64{2.2, 4.0, 6.0, 1.0}, - []float64{2.9, 4.7, 6.1, 1.4}, - []float64{2.9, 3.6, 5.6, 1.3}, - []float64{3.1, 4.4, 6.7, 1.4}, - []float64{3.0, 4.5, 5.6, 1.5}, - []float64{2.7, 4.1, 5.8, 1.0}, - []float64{2.2, 4.5, 6.2, 1.5}, - []float64{2.5, 3.9, 5.6, 1.1}, - []float64{3.2, 4.8, 5.9, 1.8}, - []float64{2.8, 4.0, 6.1, 1.3}, - []float64{2.5, 4.9, 6.3, 1.5}, - []float64{2.8, 4.7, 6.1, 1.2}, - []float64{2.9, 4.3, 6.4, 1.3}, - []float64{3.0, 4.4, 6.6, 1.4}, - []float64{2.8, 4.8, 6.8, 1.4}, - []float64{3.0, 5.0, 6.7, 1.7}, - []float64{2.9, 4.5, 6.0, 1.5}, - []float64{2.6, 3.5, 5.7, 1.0}, - []float64{2.4, 3.8, 5.5, 1.1}, - []float64{2.4, 3.7, 5.5, 1.0}, - []float64{2.7, 3.9, 5.8, 1.2}, - []float64{2.7, 5.1, 6.0, 1.6}, - []float64{3.0, 4.5, 5.4, 1.5}, - []float64{3.4, 4.5, 6.0, 1.6}, - []float64{3.1, 4.7, 6.7, 1.5}, - []float64{2.3, 4.4, 6.3, 1.3}, - []float64{3.0, 4.1, 5.6, 1.3}, - []float64{2.5, 4.0, 5.5, 1.3}, - []float64{2.6, 4.4, 5.5, 1.2}, - []float64{3.0, 4.6, 6.1, 1.4}, - []float64{2.6, 4.0, 5.8, 1.2}, - []float64{2.3, 3.3, 5.0, 1.0}, - []float64{2.7, 4.2, 5.6, 1.3}, - []float64{3.0, 4.2, 5.7, 1.2}, - []float64{2.9, 4.2, 5.7, 1.3}, - []float64{2.9, 4.3, 6.2, 1.3}, - []float64{2.5, 3.0, 5.1, 1.1}, - []float64{2.8, 4.1, 5.7, 1.3}, - []float64{3.3, 6.0, 6.3, 2.5}, - []float64{2.7, 5.1, 5.8, 1.9}, - []float64{3.0, 5.9, 7.1, 2.1}, - []float64{2.9, 5.6, 6.3, 1.8}, - []float64{3.0, 5.8, 6.5, 2.2}, - []float64{3.0, 6.6, 7.6, 2.1}, - []float64{2.5, 4.5, 4.9, 1.7}, - []float64{2.9, 6.3, 7.3, 1.8}, - []float64{2.5, 5.8, 6.7, 1.8}, - []float64{3.6, 6.1, 7.2, 2.5}, - []float64{3.2, 5.1, 6.5, 2.0}, - []float64{2.7, 5.3, 6.4, 1.9}, - []float64{3.0, 5.5, 6.8, 2.1}, - []float64{2.5, 5.0, 5.7, 2.0}, - []float64{2.8, 5.1, 5.8, 2.4}, - []float64{3.2, 5.3, 6.4, 2.3}, - []float64{3.0, 5.5, 6.5, 1.8}, - []float64{3.8, 6.7, 7.7, 2.2}, - []float64{2.6, 6.9, 7.7, 2.3}, - []float64{2.2, 5.0, 6.0, 1.5}, - []float64{3.2, 5.7, 6.9, 2.3}, - []float64{2.8, 4.9, 5.6, 2.0}, - []float64{2.8, 6.7, 7.7, 2.0}, - []float64{2.7, 4.9, 6.3, 1.8}, - []float64{3.3, 5.7, 6.7, 2.1}, - []float64{3.2, 6.0, 7.2, 1.8}, - []float64{2.8, 4.8, 6.2, 1.8}, - []float64{3.0, 4.9, 6.1, 1.8}, - []float64{2.8, 5.6, 6.4, 2.1}, - []float64{3.0, 5.8, 7.2, 1.6}, - []float64{2.8, 6.1, 7.4, 1.9}, - []float64{3.8, 6.4, 7.9, 2.0}, - []float64{2.8, 5.6, 6.4, 2.2}, - []float64{2.8, 5.1, 6.3, 1.5}, - []float64{2.6, 5.6, 6.1, 1.4}, - []float64{3.0, 6.1, 7.7, 2.3}, - []float64{3.4, 5.6, 6.3, 2.4}, - []float64{3.1, 5.5, 6.4, 1.8}, - []float64{3.0, 4.8, 6.0, 1.8}, - []float64{3.1, 5.4, 6.9, 2.1}, - []float64{3.1, 5.6, 6.7, 2.4}, - []float64{3.1, 5.1, 6.9, 2.3}, - []float64{2.7, 5.1, 5.8, 1.9}, - []float64{3.2, 5.9, 6.8, 2.3}, - []float64{3.3, 5.7, 6.7, 2.5}, - []float64{3.0, 5.2, 6.7, 2.3}, - []float64{2.5, 5.0, 6.3, 1.9}, - []float64{3.0, 5.2, 6.5, 2.0}, - []float64{3.4, 5.4, 6.2, 2.3}, - []float64{3.0, 5.1, 5.9, 1.8}, +var X = [][]float32{ + []float32{3.5, 1.4, 5.1, 0.2}, + []float32{3.0, 1.4, 4.9, 0.2}, + []float32{3.2, 1.3, 4.7, 0.2}, + []float32{3.1, 1.5, 4.6, 0.2}, + []float32{3.6, 1.4, 5.0, 0.2}, + []float32{3.9, 1.7, 5.4, 0.4}, + []float32{3.4, 1.4, 4.6, 0.3}, + []float32{3.4, 1.5, 5.0, 0.2}, + []float32{2.9, 1.4, 4.4, 0.2}, + []float32{3.1, 1.5, 4.9, 0.1}, + []float32{3.7, 1.5, 5.4, 0.2}, + []float32{3.4, 1.6, 4.8, 0.2}, + []float32{3.0, 1.4, 4.8, 0.1}, + []float32{3.0, 1.1, 4.3, 0.1}, + []float32{4.0, 1.2, 5.8, 0.2}, + []float32{4.4, 1.5, 5.7, 0.4}, + []float32{3.9, 1.3, 5.4, 0.4}, + []float32{3.5, 1.4, 5.1, 0.3}, + []float32{3.8, 1.7, 5.7, 0.3}, + []float32{3.8, 1.5, 5.1, 0.3}, + []float32{3.4, 1.7, 5.4, 0.2}, + []float32{3.7, 1.5, 5.1, 0.4}, + []float32{3.6, 1.0, 4.6, 0.2}, + []float32{3.3, 1.7, 5.1, 0.5}, + []float32{3.4, 1.9, 4.8, 0.2}, + []float32{3.0, 1.6, 5.0, 0.2}, + []float32{3.4, 1.6, 5.0, 0.4}, + []float32{3.5, 1.5, 5.2, 0.2}, + []float32{3.4, 1.4, 5.2, 0.2}, + []float32{3.2, 1.6, 4.7, 0.2}, + []float32{3.1, 1.6, 4.8, 0.2}, + []float32{3.4, 1.5, 5.4, 0.4}, + []float32{4.1, 1.5, 5.2, 0.1}, + []float32{4.2, 1.4, 5.5, 0.2}, + []float32{3.1, 1.5, 4.9, 0.2}, + []float32{3.2, 1.2, 5.0, 0.2}, + []float32{3.5, 1.3, 5.5, 0.2}, + []float32{3.6, 1.4, 4.9, 0.1}, + []float32{3.0, 1.3, 4.4, 0.2}, + []float32{3.4, 1.5, 5.1, 0.2}, + []float32{3.5, 1.3, 5.0, 0.3}, + []float32{2.3, 1.3, 4.5, 0.3}, + []float32{3.2, 1.3, 4.4, 0.2}, + []float32{3.5, 1.6, 5.0, 0.6}, + []float32{3.8, 1.9, 5.1, 0.4}, + []float32{3.0, 1.4, 4.8, 0.3}, + []float32{3.8, 1.6, 5.1, 0.2}, + []float32{3.2, 1.4, 4.6, 0.2}, + []float32{3.7, 1.5, 5.3, 0.2}, + []float32{3.3, 1.4, 5.0, 0.2}, + []float32{3.2, 4.7, 7.0, 1.4}, + []float32{3.2, 4.5, 6.4, 1.5}, + []float32{3.1, 4.9, 6.9, 1.5}, + []float32{2.3, 4.0, 5.5, 1.3}, + []float32{2.8, 4.6, 6.5, 1.5}, + []float32{2.8, 4.5, 5.7, 1.3}, + []float32{3.3, 4.7, 6.3, 1.6}, + []float32{2.4, 3.3, 4.9, 1.0}, + []float32{2.9, 4.6, 6.6, 1.3}, + []float32{2.7, 3.9, 5.2, 1.4}, + []float32{2.0, 3.5, 5.0, 1.0}, + []float32{3.0, 4.2, 5.9, 1.5}, + []float32{2.2, 4.0, 6.0, 1.0}, + []float32{2.9, 4.7, 6.1, 1.4}, + []float32{2.9, 3.6, 5.6, 1.3}, + []float32{3.1, 4.4, 6.7, 1.4}, + []float32{3.0, 4.5, 5.6, 1.5}, + []float32{2.7, 4.1, 5.8, 1.0}, + []float32{2.2, 4.5, 6.2, 1.5}, + []float32{2.5, 3.9, 5.6, 1.1}, + []float32{3.2, 4.8, 5.9, 1.8}, + []float32{2.8, 4.0, 6.1, 1.3}, + []float32{2.5, 4.9, 6.3, 1.5}, + []float32{2.8, 4.7, 6.1, 1.2}, + []float32{2.9, 4.3, 6.4, 1.3}, + []float32{3.0, 4.4, 6.6, 1.4}, + []float32{2.8, 4.8, 6.8, 1.4}, + []float32{3.0, 5.0, 6.7, 1.7}, + []float32{2.9, 4.5, 6.0, 1.5}, + []float32{2.6, 3.5, 5.7, 1.0}, + []float32{2.4, 3.8, 5.5, 1.1}, + []float32{2.4, 3.7, 5.5, 1.0}, + []float32{2.7, 3.9, 5.8, 1.2}, + []float32{2.7, 5.1, 6.0, 1.6}, + []float32{3.0, 4.5, 5.4, 1.5}, + []float32{3.4, 4.5, 6.0, 1.6}, + []float32{3.1, 4.7, 6.7, 1.5}, + []float32{2.3, 4.4, 6.3, 1.3}, + []float32{3.0, 4.1, 5.6, 1.3}, + []float32{2.5, 4.0, 5.5, 1.3}, + []float32{2.6, 4.4, 5.5, 1.2}, + []float32{3.0, 4.6, 6.1, 1.4}, + []float32{2.6, 4.0, 5.8, 1.2}, + []float32{2.3, 3.3, 5.0, 1.0}, + []float32{2.7, 4.2, 5.6, 1.3}, + []float32{3.0, 4.2, 5.7, 1.2}, + []float32{2.9, 4.2, 5.7, 1.3}, + []float32{2.9, 4.3, 6.2, 1.3}, + []float32{2.5, 3.0, 5.1, 1.1}, + []float32{2.8, 4.1, 5.7, 1.3}, + []float32{3.3, 6.0, 6.3, 2.5}, + []float32{2.7, 5.1, 5.8, 1.9}, + []float32{3.0, 5.9, 7.1, 2.1}, + []float32{2.9, 5.6, 6.3, 1.8}, + []float32{3.0, 5.8, 6.5, 2.2}, + []float32{3.0, 6.6, 7.6, 2.1}, + []float32{2.5, 4.5, 4.9, 1.7}, + []float32{2.9, 6.3, 7.3, 1.8}, + []float32{2.5, 5.8, 6.7, 1.8}, + []float32{3.6, 6.1, 7.2, 2.5}, + []float32{3.2, 5.1, 6.5, 2.0}, + []float32{2.7, 5.3, 6.4, 1.9}, + []float32{3.0, 5.5, 6.8, 2.1}, + []float32{2.5, 5.0, 5.7, 2.0}, + []float32{2.8, 5.1, 5.8, 2.4}, + []float32{3.2, 5.3, 6.4, 2.3}, + []float32{3.0, 5.5, 6.5, 1.8}, + []float32{3.8, 6.7, 7.7, 2.2}, + []float32{2.6, 6.9, 7.7, 2.3}, + []float32{2.2, 5.0, 6.0, 1.5}, + []float32{3.2, 5.7, 6.9, 2.3}, + []float32{2.8, 4.9, 5.6, 2.0}, + []float32{2.8, 6.7, 7.7, 2.0}, + []float32{2.7, 4.9, 6.3, 1.8}, + []float32{3.3, 5.7, 6.7, 2.1}, + []float32{3.2, 6.0, 7.2, 1.8}, + []float32{2.8, 4.8, 6.2, 1.8}, + []float32{3.0, 4.9, 6.1, 1.8}, + []float32{2.8, 5.6, 6.4, 2.1}, + []float32{3.0, 5.8, 7.2, 1.6}, + []float32{2.8, 6.1, 7.4, 1.9}, + []float32{3.8, 6.4, 7.9, 2.0}, + []float32{2.8, 5.6, 6.4, 2.2}, + []float32{2.8, 5.1, 6.3, 1.5}, + []float32{2.6, 5.6, 6.1, 1.4}, + []float32{3.0, 6.1, 7.7, 2.3}, + []float32{3.4, 5.6, 6.3, 2.4}, + []float32{3.1, 5.5, 6.4, 1.8}, + []float32{3.0, 4.8, 6.0, 1.8}, + []float32{3.1, 5.4, 6.9, 2.1}, + []float32{3.1, 5.6, 6.7, 2.4}, + []float32{3.1, 5.1, 6.9, 2.3}, + []float32{2.7, 5.1, 5.8, 1.9}, + []float32{3.2, 5.9, 6.8, 2.3}, + []float32{3.3, 5.7, 6.7, 2.5}, + []float32{3.0, 5.2, 6.7, 2.3}, + []float32{2.5, 5.0, 6.3, 1.9}, + []float32{3.0, 5.2, 6.5, 2.0}, + []float32{3.4, 5.4, 6.2, 2.3}, + []float32{3.0, 5.1, 5.9, 1.8}, } var XNames = []string{"Sepal.Width", "Petal.Length", "Sepal.Length", "Petal.Width"} diff --git a/rf.go b/rf.go index f30eef4..5d70814 100644 --- a/rf.go +++ b/rf.go @@ -164,11 +164,11 @@ func writePred(w io.Writer, prediction []string) error { return wtr.Flush() } -func parseCSV(r io.Reader) ([][]float64, []string, error) { +func parseCSV(r io.Reader) ([][]float32, []string, error) { reader := csv.NewReader(r) var ( - X [][]float64 + X [][]float32 Y []string ) @@ -183,13 +183,13 @@ func parseCSV(r io.Reader) ([][]float64, []string, error) { Y = append(Y, row[0]) - var rowVal []float64 + var rowVal []float32 for _, val := range row[1:] { // data starts in 2nd column - fv, err := strconv.ParseFloat(val, 64) + fv, err := strconv.ParseFloat(val, 32) if err != nil { return X, Y, err } - rowVal = append(rowVal, fv) + rowVal = append(rowVal, float32(fv)) } X = append(X, rowVal) } diff --git a/tree/iris_test.go b/tree/iris_test.go index 4730d44..adc2bc0 100644 --- a/tree/iris_test.go +++ b/tree/iris_test.go @@ -93,157 +93,157 @@ func (n *Node) String() string { return fmt.Sprintf("Inpurity: %f, classes: %v, Split Variable: %s, Split Val: %f, n: %d", n.Impurity, n.ClassCounts, XNames[n.SplitVar], n.SplitVal, n.Samples) } -var X = [][]float64{ - []float64{3.5, 1.4, 5.1, 0.2}, - []float64{3.0, 1.4, 4.9, 0.2}, - []float64{3.2, 1.3, 4.7, 0.2}, - []float64{3.1, 1.5, 4.6, 0.2}, - []float64{3.6, 1.4, 5.0, 0.2}, - []float64{3.9, 1.7, 5.4, 0.4}, - []float64{3.4, 1.4, 4.6, 0.3}, - []float64{3.4, 1.5, 5.0, 0.2}, - []float64{2.9, 1.4, 4.4, 0.2}, - []float64{3.1, 1.5, 4.9, 0.1}, - []float64{3.7, 1.5, 5.4, 0.2}, - []float64{3.4, 1.6, 4.8, 0.2}, - []float64{3.0, 1.4, 4.8, 0.1}, - []float64{3.0, 1.1, 4.3, 0.1}, - []float64{4.0, 1.2, 5.8, 0.2}, - []float64{4.4, 1.5, 5.7, 0.4}, - []float64{3.9, 1.3, 5.4, 0.4}, - []float64{3.5, 1.4, 5.1, 0.3}, - []float64{3.8, 1.7, 5.7, 0.3}, - []float64{3.8, 1.5, 5.1, 0.3}, - []float64{3.4, 1.7, 5.4, 0.2}, - []float64{3.7, 1.5, 5.1, 0.4}, - []float64{3.6, 1.0, 4.6, 0.2}, - []float64{3.3, 1.7, 5.1, 0.5}, - []float64{3.4, 1.9, 4.8, 0.2}, - []float64{3.0, 1.6, 5.0, 0.2}, - []float64{3.4, 1.6, 5.0, 0.4}, - []float64{3.5, 1.5, 5.2, 0.2}, - []float64{3.4, 1.4, 5.2, 0.2}, - []float64{3.2, 1.6, 4.7, 0.2}, - []float64{3.1, 1.6, 4.8, 0.2}, - []float64{3.4, 1.5, 5.4, 0.4}, - []float64{4.1, 1.5, 5.2, 0.1}, - []float64{4.2, 1.4, 5.5, 0.2}, - []float64{3.1, 1.5, 4.9, 0.2}, - []float64{3.2, 1.2, 5.0, 0.2}, - []float64{3.5, 1.3, 5.5, 0.2}, - []float64{3.6, 1.4, 4.9, 0.1}, - []float64{3.0, 1.3, 4.4, 0.2}, - []float64{3.4, 1.5, 5.1, 0.2}, - []float64{3.5, 1.3, 5.0, 0.3}, - []float64{2.3, 1.3, 4.5, 0.3}, - []float64{3.2, 1.3, 4.4, 0.2}, - []float64{3.5, 1.6, 5.0, 0.6}, - []float64{3.8, 1.9, 5.1, 0.4}, - []float64{3.0, 1.4, 4.8, 0.3}, - []float64{3.8, 1.6, 5.1, 0.2}, - []float64{3.2, 1.4, 4.6, 0.2}, - []float64{3.7, 1.5, 5.3, 0.2}, - []float64{3.3, 1.4, 5.0, 0.2}, - []float64{3.2, 4.7, 7.0, 1.4}, - []float64{3.2, 4.5, 6.4, 1.5}, - []float64{3.1, 4.9, 6.9, 1.5}, - []float64{2.3, 4.0, 5.5, 1.3}, - []float64{2.8, 4.6, 6.5, 1.5}, - []float64{2.8, 4.5, 5.7, 1.3}, - []float64{3.3, 4.7, 6.3, 1.6}, - []float64{2.4, 3.3, 4.9, 1.0}, - []float64{2.9, 4.6, 6.6, 1.3}, - []float64{2.7, 3.9, 5.2, 1.4}, - []float64{2.0, 3.5, 5.0, 1.0}, - []float64{3.0, 4.2, 5.9, 1.5}, - []float64{2.2, 4.0, 6.0, 1.0}, - []float64{2.9, 4.7, 6.1, 1.4}, - []float64{2.9, 3.6, 5.6, 1.3}, - []float64{3.1, 4.4, 6.7, 1.4}, - []float64{3.0, 4.5, 5.6, 1.5}, - []float64{2.7, 4.1, 5.8, 1.0}, - []float64{2.2, 4.5, 6.2, 1.5}, - []float64{2.5, 3.9, 5.6, 1.1}, - []float64{3.2, 4.8, 5.9, 1.8}, - []float64{2.8, 4.0, 6.1, 1.3}, - []float64{2.5, 4.9, 6.3, 1.5}, - []float64{2.8, 4.7, 6.1, 1.2}, - []float64{2.9, 4.3, 6.4, 1.3}, - []float64{3.0, 4.4, 6.6, 1.4}, - []float64{2.8, 4.8, 6.8, 1.4}, - []float64{3.0, 5.0, 6.7, 1.7}, - []float64{2.9, 4.5, 6.0, 1.5}, - []float64{2.6, 3.5, 5.7, 1.0}, - []float64{2.4, 3.8, 5.5, 1.1}, - []float64{2.4, 3.7, 5.5, 1.0}, - []float64{2.7, 3.9, 5.8, 1.2}, - []float64{2.7, 5.1, 6.0, 1.6}, - []float64{3.0, 4.5, 5.4, 1.5}, - []float64{3.4, 4.5, 6.0, 1.6}, - []float64{3.1, 4.7, 6.7, 1.5}, - []float64{2.3, 4.4, 6.3, 1.3}, - []float64{3.0, 4.1, 5.6, 1.3}, - []float64{2.5, 4.0, 5.5, 1.3}, - []float64{2.6, 4.4, 5.5, 1.2}, - []float64{3.0, 4.6, 6.1, 1.4}, - []float64{2.6, 4.0, 5.8, 1.2}, - []float64{2.3, 3.3, 5.0, 1.0}, - []float64{2.7, 4.2, 5.6, 1.3}, - []float64{3.0, 4.2, 5.7, 1.2}, - []float64{2.9, 4.2, 5.7, 1.3}, - []float64{2.9, 4.3, 6.2, 1.3}, - []float64{2.5, 3.0, 5.1, 1.1}, - []float64{2.8, 4.1, 5.7, 1.3}, - []float64{3.3, 6.0, 6.3, 2.5}, - []float64{2.7, 5.1, 5.8, 1.9}, - []float64{3.0, 5.9, 7.1, 2.1}, - []float64{2.9, 5.6, 6.3, 1.8}, - []float64{3.0, 5.8, 6.5, 2.2}, - []float64{3.0, 6.6, 7.6, 2.1}, - []float64{2.5, 4.5, 4.9, 1.7}, - []float64{2.9, 6.3, 7.3, 1.8}, - []float64{2.5, 5.8, 6.7, 1.8}, - []float64{3.6, 6.1, 7.2, 2.5}, - []float64{3.2, 5.1, 6.5, 2.0}, - []float64{2.7, 5.3, 6.4, 1.9}, - []float64{3.0, 5.5, 6.8, 2.1}, - []float64{2.5, 5.0, 5.7, 2.0}, - []float64{2.8, 5.1, 5.8, 2.4}, - []float64{3.2, 5.3, 6.4, 2.3}, - []float64{3.0, 5.5, 6.5, 1.8}, - []float64{3.8, 6.7, 7.7, 2.2}, - []float64{2.6, 6.9, 7.7, 2.3}, - []float64{2.2, 5.0, 6.0, 1.5}, - []float64{3.2, 5.7, 6.9, 2.3}, - []float64{2.8, 4.9, 5.6, 2.0}, - []float64{2.8, 6.7, 7.7, 2.0}, - []float64{2.7, 4.9, 6.3, 1.8}, - []float64{3.3, 5.7, 6.7, 2.1}, - []float64{3.2, 6.0, 7.2, 1.8}, - []float64{2.8, 4.8, 6.2, 1.8}, - []float64{3.0, 4.9, 6.1, 1.8}, - []float64{2.8, 5.6, 6.4, 2.1}, - []float64{3.0, 5.8, 7.2, 1.6}, - []float64{2.8, 6.1, 7.4, 1.9}, - []float64{3.8, 6.4, 7.9, 2.0}, - []float64{2.8, 5.6, 6.4, 2.2}, - []float64{2.8, 5.1, 6.3, 1.5}, - []float64{2.6, 5.6, 6.1, 1.4}, - []float64{3.0, 6.1, 7.7, 2.3}, - []float64{3.4, 5.6, 6.3, 2.4}, - []float64{3.1, 5.5, 6.4, 1.8}, - []float64{3.0, 4.8, 6.0, 1.8}, - []float64{3.1, 5.4, 6.9, 2.1}, - []float64{3.1, 5.6, 6.7, 2.4}, - []float64{3.1, 5.1, 6.9, 2.3}, - []float64{2.7, 5.1, 5.8, 1.9}, - []float64{3.2, 5.9, 6.8, 2.3}, - []float64{3.3, 5.7, 6.7, 2.5}, - []float64{3.0, 5.2, 6.7, 2.3}, - []float64{2.5, 5.0, 6.3, 1.9}, - []float64{3.0, 5.2, 6.5, 2.0}, - []float64{3.4, 5.4, 6.2, 2.3}, - []float64{3.0, 5.1, 5.9, 1.8}, +var X = [][]float32{ + []float32{3.5, 1.4, 5.1, 0.2}, + []float32{3.0, 1.4, 4.9, 0.2}, + []float32{3.2, 1.3, 4.7, 0.2}, + []float32{3.1, 1.5, 4.6, 0.2}, + []float32{3.6, 1.4, 5.0, 0.2}, + []float32{3.9, 1.7, 5.4, 0.4}, + []float32{3.4, 1.4, 4.6, 0.3}, + []float32{3.4, 1.5, 5.0, 0.2}, + []float32{2.9, 1.4, 4.4, 0.2}, + []float32{3.1, 1.5, 4.9, 0.1}, + []float32{3.7, 1.5, 5.4, 0.2}, + []float32{3.4, 1.6, 4.8, 0.2}, + []float32{3.0, 1.4, 4.8, 0.1}, + []float32{3.0, 1.1, 4.3, 0.1}, + []float32{4.0, 1.2, 5.8, 0.2}, + []float32{4.4, 1.5, 5.7, 0.4}, + []float32{3.9, 1.3, 5.4, 0.4}, + []float32{3.5, 1.4, 5.1, 0.3}, + []float32{3.8, 1.7, 5.7, 0.3}, + []float32{3.8, 1.5, 5.1, 0.3}, + []float32{3.4, 1.7, 5.4, 0.2}, + []float32{3.7, 1.5, 5.1, 0.4}, + []float32{3.6, 1.0, 4.6, 0.2}, + []float32{3.3, 1.7, 5.1, 0.5}, + []float32{3.4, 1.9, 4.8, 0.2}, + []float32{3.0, 1.6, 5.0, 0.2}, + []float32{3.4, 1.6, 5.0, 0.4}, + []float32{3.5, 1.5, 5.2, 0.2}, + []float32{3.4, 1.4, 5.2, 0.2}, + []float32{3.2, 1.6, 4.7, 0.2}, + []float32{3.1, 1.6, 4.8, 0.2}, + []float32{3.4, 1.5, 5.4, 0.4}, + []float32{4.1, 1.5, 5.2, 0.1}, + []float32{4.2, 1.4, 5.5, 0.2}, + []float32{3.1, 1.5, 4.9, 0.2}, + []float32{3.2, 1.2, 5.0, 0.2}, + []float32{3.5, 1.3, 5.5, 0.2}, + []float32{3.6, 1.4, 4.9, 0.1}, + []float32{3.0, 1.3, 4.4, 0.2}, + []float32{3.4, 1.5, 5.1, 0.2}, + []float32{3.5, 1.3, 5.0, 0.3}, + []float32{2.3, 1.3, 4.5, 0.3}, + []float32{3.2, 1.3, 4.4, 0.2}, + []float32{3.5, 1.6, 5.0, 0.6}, + []float32{3.8, 1.9, 5.1, 0.4}, + []float32{3.0, 1.4, 4.8, 0.3}, + []float32{3.8, 1.6, 5.1, 0.2}, + []float32{3.2, 1.4, 4.6, 0.2}, + []float32{3.7, 1.5, 5.3, 0.2}, + []float32{3.3, 1.4, 5.0, 0.2}, + []float32{3.2, 4.7, 7.0, 1.4}, + []float32{3.2, 4.5, 6.4, 1.5}, + []float32{3.1, 4.9, 6.9, 1.5}, + []float32{2.3, 4.0, 5.5, 1.3}, + []float32{2.8, 4.6, 6.5, 1.5}, + []float32{2.8, 4.5, 5.7, 1.3}, + []float32{3.3, 4.7, 6.3, 1.6}, + []float32{2.4, 3.3, 4.9, 1.0}, + []float32{2.9, 4.6, 6.6, 1.3}, + []float32{2.7, 3.9, 5.2, 1.4}, + []float32{2.0, 3.5, 5.0, 1.0}, + []float32{3.0, 4.2, 5.9, 1.5}, + []float32{2.2, 4.0, 6.0, 1.0}, + []float32{2.9, 4.7, 6.1, 1.4}, + []float32{2.9, 3.6, 5.6, 1.3}, + []float32{3.1, 4.4, 6.7, 1.4}, + []float32{3.0, 4.5, 5.6, 1.5}, + []float32{2.7, 4.1, 5.8, 1.0}, + []float32{2.2, 4.5, 6.2, 1.5}, + []float32{2.5, 3.9, 5.6, 1.1}, + []float32{3.2, 4.8, 5.9, 1.8}, + []float32{2.8, 4.0, 6.1, 1.3}, + []float32{2.5, 4.9, 6.3, 1.5}, + []float32{2.8, 4.7, 6.1, 1.2}, + []float32{2.9, 4.3, 6.4, 1.3}, + []float32{3.0, 4.4, 6.6, 1.4}, + []float32{2.8, 4.8, 6.8, 1.4}, + []float32{3.0, 5.0, 6.7, 1.7}, + []float32{2.9, 4.5, 6.0, 1.5}, + []float32{2.6, 3.5, 5.7, 1.0}, + []float32{2.4, 3.8, 5.5, 1.1}, + []float32{2.4, 3.7, 5.5, 1.0}, + []float32{2.7, 3.9, 5.8, 1.2}, + []float32{2.7, 5.1, 6.0, 1.6}, + []float32{3.0, 4.5, 5.4, 1.5}, + []float32{3.4, 4.5, 6.0, 1.6}, + []float32{3.1, 4.7, 6.7, 1.5}, + []float32{2.3, 4.4, 6.3, 1.3}, + []float32{3.0, 4.1, 5.6, 1.3}, + []float32{2.5, 4.0, 5.5, 1.3}, + []float32{2.6, 4.4, 5.5, 1.2}, + []float32{3.0, 4.6, 6.1, 1.4}, + []float32{2.6, 4.0, 5.8, 1.2}, + []float32{2.3, 3.3, 5.0, 1.0}, + []float32{2.7, 4.2, 5.6, 1.3}, + []float32{3.0, 4.2, 5.7, 1.2}, + []float32{2.9, 4.2, 5.7, 1.3}, + []float32{2.9, 4.3, 6.2, 1.3}, + []float32{2.5, 3.0, 5.1, 1.1}, + []float32{2.8, 4.1, 5.7, 1.3}, + []float32{3.3, 6.0, 6.3, 2.5}, + []float32{2.7, 5.1, 5.8, 1.9}, + []float32{3.0, 5.9, 7.1, 2.1}, + []float32{2.9, 5.6, 6.3, 1.8}, + []float32{3.0, 5.8, 6.5, 2.2}, + []float32{3.0, 6.6, 7.6, 2.1}, + []float32{2.5, 4.5, 4.9, 1.7}, + []float32{2.9, 6.3, 7.3, 1.8}, + []float32{2.5, 5.8, 6.7, 1.8}, + []float32{3.6, 6.1, 7.2, 2.5}, + []float32{3.2, 5.1, 6.5, 2.0}, + []float32{2.7, 5.3, 6.4, 1.9}, + []float32{3.0, 5.5, 6.8, 2.1}, + []float32{2.5, 5.0, 5.7, 2.0}, + []float32{2.8, 5.1, 5.8, 2.4}, + []float32{3.2, 5.3, 6.4, 2.3}, + []float32{3.0, 5.5, 6.5, 1.8}, + []float32{3.8, 6.7, 7.7, 2.2}, + []float32{2.6, 6.9, 7.7, 2.3}, + []float32{2.2, 5.0, 6.0, 1.5}, + []float32{3.2, 5.7, 6.9, 2.3}, + []float32{2.8, 4.9, 5.6, 2.0}, + []float32{2.8, 6.7, 7.7, 2.0}, + []float32{2.7, 4.9, 6.3, 1.8}, + []float32{3.3, 5.7, 6.7, 2.1}, + []float32{3.2, 6.0, 7.2, 1.8}, + []float32{2.8, 4.8, 6.2, 1.8}, + []float32{3.0, 4.9, 6.1, 1.8}, + []float32{2.8, 5.6, 6.4, 2.1}, + []float32{3.0, 5.8, 7.2, 1.6}, + []float32{2.8, 6.1, 7.4, 1.9}, + []float32{3.8, 6.4, 7.9, 2.0}, + []float32{2.8, 5.6, 6.4, 2.2}, + []float32{2.8, 5.1, 6.3, 1.5}, + []float32{2.6, 5.6, 6.1, 1.4}, + []float32{3.0, 6.1, 7.7, 2.3}, + []float32{3.4, 5.6, 6.3, 2.4}, + []float32{3.1, 5.5, 6.4, 1.8}, + []float32{3.0, 4.8, 6.0, 1.8}, + []float32{3.1, 5.4, 6.9, 2.1}, + []float32{3.1, 5.6, 6.7, 2.4}, + []float32{3.1, 5.1, 6.9, 2.3}, + []float32{2.7, 5.1, 5.8, 1.9}, + []float32{3.2, 5.9, 6.8, 2.3}, + []float32{3.3, 5.7, 6.7, 2.5}, + []float32{3.0, 5.2, 6.7, 2.3}, + []float32{2.5, 5.0, 6.3, 1.9}, + []float32{3.0, 5.2, 6.5, 2.0}, + []float32{3.4, 5.4, 6.2, 2.3}, + []float32{3.0, 5.1, 5.9, 1.8}, } var XNames = []string{"Sepal.Width", "Petal.Length", "Sepal.Length", "Petal.Width"} diff --git a/tree/sort.go b/tree/sort.go index a0f91cb..15dd4ee 100644 --- a/tree/sort.go +++ b/tree/sort.go @@ -14,13 +14,13 @@ func min(a, b int) int { return b } -func swap(x []float64, inx []int, i, j int) { +func swap(x []float32, inx []int, i, j int) { x[i], x[j] = x[j], x[i] inx[i], inx[j] = inx[j], inx[i] } // Insertion sort -func insertionSort(x []float64, inx []int, a, b int) { +func insertionSort(x []float32, inx []int, a, b int) { for i := a + 1; i < b; i++ { for j := i; j > a && x[j] < x[j-1]; j-- { swap(x, inx, j, j-1) @@ -30,7 +30,7 @@ func insertionSort(x []float64, inx []int, a, b int) { // siftDown implements the heap property on data[lo, hi). // first is an offset into the array where the root of the heap lies. -func siftDown(x []float64, inx []int, lo, hi, first int) { +func siftDown(x []float32, inx []int, lo, hi, first int) { root := lo for { child := 2*root + 1 @@ -48,7 +48,7 @@ func siftDown(x []float64, inx []int, lo, hi, first int) { } } -func heapSort(x []float64, inx []int, a, b int) { +func heapSort(x []float32, inx []int, a, b int) { first := a lo := 0 hi := b - a @@ -69,7 +69,7 @@ func heapSort(x []float64, inx []int, a, b int) { // ``Engineering a Sort Function,'' SP&E November 1993. // medianOfThree moves the median of the three values data[a], data[b], data[c] into data[a]. -func medianOfThree(x []float64, inx []int, a, b, c int) { +func medianOfThree(x []float32, inx []int, a, b, c int) { m0 := b m1 := a m2 := c @@ -86,13 +86,13 @@ func medianOfThree(x []float64, inx []int, a, b, c int) { // now data[m0] <= data[m1] <= data[m2] } -func swapRange(x []float64, inx []int, a, b, n int) { +func swapRange(x []float32, inx []int, a, b, n int) { for i := 0; i < n; i++ { swap(x, inx, a+i, b+i) } } -func doPivot(x []float64, inx []int, lo, hi int) (midlo, midhi int) { +func doPivot(x []float32, inx []int, lo, hi int) (midlo, midhi int) { m := lo + (hi-lo)/2 // Written like this to avoid integer overflow. if hi-lo > 40 { // Tukey's ``Ninther,'' median of three medians of three. @@ -156,7 +156,7 @@ func doPivot(x []float64, inx []int, lo, hi int) (midlo, midhi int) { return lo + b - a, hi - (d - c) } -func quickSort(x []float64, inx []int, a, b, maxDepth int) { +func quickSort(x []float32, inx []int, a, b, maxDepth int) { for b-a > 7 { if maxDepth == 0 { heapSort(x, inx, a, b) @@ -182,7 +182,7 @@ func quickSort(x []float64, inx []int, a, b, maxDepth int) { // Sort sorts data. // It makes one call to data.Len to determine n, and O(n*log(n)) calls to // data.Less and data.Swap. The sort is not guaranteed to be stable. -func bSort(x []float64, inx []int) { +func bSort(x []float32, inx []int) { // Switch to heapsort if depth of 2*ceil(lg(n+1)) is reached. n := len(inx) maxDepth := 0 diff --git a/tree/split_test.go b/tree/split_test.go index ac67a78..62003ba 100644 --- a/tree/split_test.go +++ b/tree/split_test.go @@ -8,7 +8,7 @@ import ( func TestBestSplit(t *testing.T) { clf := NewClassifier() - xi := []float64{0.08918780255911574, 0.097704546453666, 0.15739526725378827, 0.1772808696619108, 0.47001967423520297, 0.5621969807319502, 0.6055333992245421, 0.6462220030737842, 0.8020611535912714, 0.9244669313190392} + xi := []float32{0.08918780255911574, 0.097704546453666, 0.15739526725378827, 0.1772808696619108, 0.47001967423520297, 0.5621969807319502, 0.6055333992245421, 0.6462220030737842, 0.8020611535912714, 0.9244669313190392} y := []int{0, 0, 0, 0, 0, 1, 1, 1, 1, 0} classCount := []int{6, 4} inx := make([]int, len(y)) @@ -34,7 +34,7 @@ func TestBestSplit(t *testing.T) { func TestBestSplitConstant(t *testing.T) { clf := NewClassifier() - xi := []float64{1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1} + xi := []float32{1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1} y := []int{0, 0, 0, 0, 0, 1, 1, 1, 1, 0} classCount := []int{6, 4} inx := make([]int, len(y)) @@ -46,7 +46,7 @@ func TestBestSplitConstant(t *testing.T) { classCtR := make([]int, 2) copy(classCtR, classCount) sp, gain, _ := clf.bestSplit(xi, y, inx, 0.48, classCtL, classCtR) - spActual := 0.0 // feature is constant, should be no split + spActual := float32(0.0) // feature is constant, should be no split if sp != spActual { t.Error("expected split to be:", spActual, " got:", sp) } @@ -59,7 +59,7 @@ func TestBestSplitConstant(t *testing.T) { func TestBestSplitSomeConstant(t *testing.T) { clf := NewClassifier() - xi := []float64{0.08918780255911574, 0.09, 0.09, 0.09, 0.47001967423520297, 0.5621969807319502, 0.6055333992245421, 0.6462220030737842, 0.8020611535912714, 0.9244669313190392} + xi := []float32{0.08918780255911574, 0.09, 0.09, 0.09, 0.47001967423520297, 0.5621969807319502, 0.6055333992245421, 0.6462220030737842, 0.8020611535912714, 0.9244669313190392} y := []int{0, 0, 0, 0, 0, 1, 1, 1, 1, 0} classCount := []int{6, 4} inx := make([]int, len(y)) diff --git a/tree/tree.go b/tree/tree.go index d8db9fc..b1bf8a0 100644 --- a/tree/tree.go +++ b/tree/tree.go @@ -132,7 +132,7 @@ func NewClassifier(options ...func(treeConfiger)) *Classifier { } // Fit constructs a tree from the provided features X, and labels Y. -func (t *Classifier) Fit(X [][]float64, Y []string) { +func (t *Classifier) Fit(X [][]float32, Y []string) { inx := make([]int, len(Y)) for i := 0; i < len(Y); i++ { inx[i] = i @@ -144,12 +144,12 @@ func (t *Classifier) Fit(X [][]float64, Y []string) { // FitInx constructs a tree as in Fit, but uses the inx slice to mask // the examples in X and Y. FitInx is intended to be used with meta algorithm // that rely on bootstrap sampling, such as RandomForest. -func (t *Classifier) FitInx(X [][]float64, Y []string, inx []int) { +func (t *Classifier) FitInx(X [][]float32, Y []string, inx []int) { //TODO: []int for Y instead; caller's responsibility to keep track t.fit(X, Y, inx) } -func (t *Classifier) fit(X [][]float64, Y []string, inx []int) { +func (t *Classifier) fit(X [][]float32, Y []string, inx []int) { // all examples are in root node t.Root = &Node{Samples: len(inx)} @@ -184,7 +184,7 @@ func (t *Classifier) fit(X [][]float64, Y []string, inx []int) { } // working copies of features and labels - xBuf := make([]float64, len(yIDs)) + xBuf := make([]float32, len(yIDs)) classCtL := make([]int, len(uniq)) classCtR := make([]int, len(uniq)) @@ -213,7 +213,7 @@ func (t *Classifier) fit(X [][]float64, Y []string, inx []int) { var ( dBest float64 - vBest float64 + vBest float32 xBest int iBest int ) @@ -302,7 +302,7 @@ func (t *Classifier) fit(X [][]float64, Y []string, inx []int) { } // Predict returns the most probable label for each example. -func (t *Classifier) Predict(X [][]float64) []string { +func (t *Classifier) Predict(X [][]float32) []string { p := make([]string, len(X)) for i := range p { @@ -331,7 +331,7 @@ func (t *Classifier) Predict(X [][]float64) []string { // PredictProb returns the class probability for each example. The indices // of the return value correspond to Classifier.Classes. -func (t *Classifier) PredictProb(X [][]float64) [][]float64 { +func (t *Classifier) PredictProb(X [][]float32) [][]float64 { p := make([][]float64, len(X)) for i := range p { @@ -368,12 +368,13 @@ func (t *Classifier) Load(r io.Reader) error { // this function takes a lot of args // classCtl and classCtR should be initialized by the caller, classCtL should // be all zeros, classCtR should be the counts for the current node -func (t *Classifier) bestSplit(xi []float64, y []int, inx []int, dInit float64, - classCtL []int, classCtR []int) (float64, float64, int) { +func (t *Classifier) bestSplit(xi []float32, y []int, inx []int, dInit float64, + classCtL []int, classCtR []int) (float32, float64, int) { var ( - dBest, vBest, v, d float64 - pos int + dBest, d float64 + vBest, v float32 + pos int ) n := len(xi) @@ -461,7 +462,7 @@ type Node struct { Left *Node Right *Node SplitVar int - SplitVal float64 + SplitVal float32 //TODO: do we need to store class counts at each node? ClassCounts []int Impurity float64