Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions forest/forest.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func NewClassifier(options ...func(forestConfiger)) *ForestClassifier {

// Fit constructs a forest from fitting n trees from the provided features X, and
// labels Y.
func (f *ForestClassifier) Fit(X [][]float64, Y []string) {
func (f *ForestClassifier) Fit(X [][]float32, Y []string) {
// labels as integer ids, ensure all trees know about all classes
var yIDs []int
uniq := make(map[string]int)
Expand Down Expand Up @@ -187,7 +187,7 @@ func (f *ForestClassifier) Fit(X [][]float64, Y []string) {
}

// Predict returns the most probable label for each example.
func (f *ForestClassifier) Predict(X [][]float64) []string {
func (f *ForestClassifier) Predict(X [][]float32) []string {
p := f.PredictProb(X)
maxC := make([]string, len(X))

Expand All @@ -212,7 +212,7 @@ func (f *ForestClassifier) Predict(X [][]float64) []string {

// PredictProb returns the class probability for each example. The indices of the
// return value correspond to Classifier.Classes.
func (f *ForestClassifier) PredictProb(X [][]float64) [][]float64 {
func (f *ForestClassifier) PredictProb(X [][]float32) [][]float64 {
//TODO: weighted voting...
probs := make([][]float64, len(X))
// initialize the other dim
Expand Down
302 changes: 151 additions & 151 deletions forest/iris_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,157 +69,157 @@ func BenchmarkIrisPredict(b *testing.B) {
}
}

var X = [][]float64{
[]float64{3.5, 1.4, 5.1, 0.2},
[]float64{3.0, 1.4, 4.9, 0.2},
[]float64{3.2, 1.3, 4.7, 0.2},
[]float64{3.1, 1.5, 4.6, 0.2},
[]float64{3.6, 1.4, 5.0, 0.2},
[]float64{3.9, 1.7, 5.4, 0.4},
[]float64{3.4, 1.4, 4.6, 0.3},
[]float64{3.4, 1.5, 5.0, 0.2},
[]float64{2.9, 1.4, 4.4, 0.2},
[]float64{3.1, 1.5, 4.9, 0.1},
[]float64{3.7, 1.5, 5.4, 0.2},
[]float64{3.4, 1.6, 4.8, 0.2},
[]float64{3.0, 1.4, 4.8, 0.1},
[]float64{3.0, 1.1, 4.3, 0.1},
[]float64{4.0, 1.2, 5.8, 0.2},
[]float64{4.4, 1.5, 5.7, 0.4},
[]float64{3.9, 1.3, 5.4, 0.4},
[]float64{3.5, 1.4, 5.1, 0.3},
[]float64{3.8, 1.7, 5.7, 0.3},
[]float64{3.8, 1.5, 5.1, 0.3},
[]float64{3.4, 1.7, 5.4, 0.2},
[]float64{3.7, 1.5, 5.1, 0.4},
[]float64{3.6, 1.0, 4.6, 0.2},
[]float64{3.3, 1.7, 5.1, 0.5},
[]float64{3.4, 1.9, 4.8, 0.2},
[]float64{3.0, 1.6, 5.0, 0.2},
[]float64{3.4, 1.6, 5.0, 0.4},
[]float64{3.5, 1.5, 5.2, 0.2},
[]float64{3.4, 1.4, 5.2, 0.2},
[]float64{3.2, 1.6, 4.7, 0.2},
[]float64{3.1, 1.6, 4.8, 0.2},
[]float64{3.4, 1.5, 5.4, 0.4},
[]float64{4.1, 1.5, 5.2, 0.1},
[]float64{4.2, 1.4, 5.5, 0.2},
[]float64{3.1, 1.5, 4.9, 0.2},
[]float64{3.2, 1.2, 5.0, 0.2},
[]float64{3.5, 1.3, 5.5, 0.2},
[]float64{3.6, 1.4, 4.9, 0.1},
[]float64{3.0, 1.3, 4.4, 0.2},
[]float64{3.4, 1.5, 5.1, 0.2},
[]float64{3.5, 1.3, 5.0, 0.3},
[]float64{2.3, 1.3, 4.5, 0.3},
[]float64{3.2, 1.3, 4.4, 0.2},
[]float64{3.5, 1.6, 5.0, 0.6},
[]float64{3.8, 1.9, 5.1, 0.4},
[]float64{3.0, 1.4, 4.8, 0.3},
[]float64{3.8, 1.6, 5.1, 0.2},
[]float64{3.2, 1.4, 4.6, 0.2},
[]float64{3.7, 1.5, 5.3, 0.2},
[]float64{3.3, 1.4, 5.0, 0.2},
[]float64{3.2, 4.7, 7.0, 1.4},
[]float64{3.2, 4.5, 6.4, 1.5},
[]float64{3.1, 4.9, 6.9, 1.5},
[]float64{2.3, 4.0, 5.5, 1.3},
[]float64{2.8, 4.6, 6.5, 1.5},
[]float64{2.8, 4.5, 5.7, 1.3},
[]float64{3.3, 4.7, 6.3, 1.6},
[]float64{2.4, 3.3, 4.9, 1.0},
[]float64{2.9, 4.6, 6.6, 1.3},
[]float64{2.7, 3.9, 5.2, 1.4},
[]float64{2.0, 3.5, 5.0, 1.0},
[]float64{3.0, 4.2, 5.9, 1.5},
[]float64{2.2, 4.0, 6.0, 1.0},
[]float64{2.9, 4.7, 6.1, 1.4},
[]float64{2.9, 3.6, 5.6, 1.3},
[]float64{3.1, 4.4, 6.7, 1.4},
[]float64{3.0, 4.5, 5.6, 1.5},
[]float64{2.7, 4.1, 5.8, 1.0},
[]float64{2.2, 4.5, 6.2, 1.5},
[]float64{2.5, 3.9, 5.6, 1.1},
[]float64{3.2, 4.8, 5.9, 1.8},
[]float64{2.8, 4.0, 6.1, 1.3},
[]float64{2.5, 4.9, 6.3, 1.5},
[]float64{2.8, 4.7, 6.1, 1.2},
[]float64{2.9, 4.3, 6.4, 1.3},
[]float64{3.0, 4.4, 6.6, 1.4},
[]float64{2.8, 4.8, 6.8, 1.4},
[]float64{3.0, 5.0, 6.7, 1.7},
[]float64{2.9, 4.5, 6.0, 1.5},
[]float64{2.6, 3.5, 5.7, 1.0},
[]float64{2.4, 3.8, 5.5, 1.1},
[]float64{2.4, 3.7, 5.5, 1.0},
[]float64{2.7, 3.9, 5.8, 1.2},
[]float64{2.7, 5.1, 6.0, 1.6},
[]float64{3.0, 4.5, 5.4, 1.5},
[]float64{3.4, 4.5, 6.0, 1.6},
[]float64{3.1, 4.7, 6.7, 1.5},
[]float64{2.3, 4.4, 6.3, 1.3},
[]float64{3.0, 4.1, 5.6, 1.3},
[]float64{2.5, 4.0, 5.5, 1.3},
[]float64{2.6, 4.4, 5.5, 1.2},
[]float64{3.0, 4.6, 6.1, 1.4},
[]float64{2.6, 4.0, 5.8, 1.2},
[]float64{2.3, 3.3, 5.0, 1.0},
[]float64{2.7, 4.2, 5.6, 1.3},
[]float64{3.0, 4.2, 5.7, 1.2},
[]float64{2.9, 4.2, 5.7, 1.3},
[]float64{2.9, 4.3, 6.2, 1.3},
[]float64{2.5, 3.0, 5.1, 1.1},
[]float64{2.8, 4.1, 5.7, 1.3},
[]float64{3.3, 6.0, 6.3, 2.5},
[]float64{2.7, 5.1, 5.8, 1.9},
[]float64{3.0, 5.9, 7.1, 2.1},
[]float64{2.9, 5.6, 6.3, 1.8},
[]float64{3.0, 5.8, 6.5, 2.2},
[]float64{3.0, 6.6, 7.6, 2.1},
[]float64{2.5, 4.5, 4.9, 1.7},
[]float64{2.9, 6.3, 7.3, 1.8},
[]float64{2.5, 5.8, 6.7, 1.8},
[]float64{3.6, 6.1, 7.2, 2.5},
[]float64{3.2, 5.1, 6.5, 2.0},
[]float64{2.7, 5.3, 6.4, 1.9},
[]float64{3.0, 5.5, 6.8, 2.1},
[]float64{2.5, 5.0, 5.7, 2.0},
[]float64{2.8, 5.1, 5.8, 2.4},
[]float64{3.2, 5.3, 6.4, 2.3},
[]float64{3.0, 5.5, 6.5, 1.8},
[]float64{3.8, 6.7, 7.7, 2.2},
[]float64{2.6, 6.9, 7.7, 2.3},
[]float64{2.2, 5.0, 6.0, 1.5},
[]float64{3.2, 5.7, 6.9, 2.3},
[]float64{2.8, 4.9, 5.6, 2.0},
[]float64{2.8, 6.7, 7.7, 2.0},
[]float64{2.7, 4.9, 6.3, 1.8},
[]float64{3.3, 5.7, 6.7, 2.1},
[]float64{3.2, 6.0, 7.2, 1.8},
[]float64{2.8, 4.8, 6.2, 1.8},
[]float64{3.0, 4.9, 6.1, 1.8},
[]float64{2.8, 5.6, 6.4, 2.1},
[]float64{3.0, 5.8, 7.2, 1.6},
[]float64{2.8, 6.1, 7.4, 1.9},
[]float64{3.8, 6.4, 7.9, 2.0},
[]float64{2.8, 5.6, 6.4, 2.2},
[]float64{2.8, 5.1, 6.3, 1.5},
[]float64{2.6, 5.6, 6.1, 1.4},
[]float64{3.0, 6.1, 7.7, 2.3},
[]float64{3.4, 5.6, 6.3, 2.4},
[]float64{3.1, 5.5, 6.4, 1.8},
[]float64{3.0, 4.8, 6.0, 1.8},
[]float64{3.1, 5.4, 6.9, 2.1},
[]float64{3.1, 5.6, 6.7, 2.4},
[]float64{3.1, 5.1, 6.9, 2.3},
[]float64{2.7, 5.1, 5.8, 1.9},
[]float64{3.2, 5.9, 6.8, 2.3},
[]float64{3.3, 5.7, 6.7, 2.5},
[]float64{3.0, 5.2, 6.7, 2.3},
[]float64{2.5, 5.0, 6.3, 1.9},
[]float64{3.0, 5.2, 6.5, 2.0},
[]float64{3.4, 5.4, 6.2, 2.3},
[]float64{3.0, 5.1, 5.9, 1.8},
var X = [][]float32{
[]float32{3.5, 1.4, 5.1, 0.2},
[]float32{3.0, 1.4, 4.9, 0.2},
[]float32{3.2, 1.3, 4.7, 0.2},
[]float32{3.1, 1.5, 4.6, 0.2},
[]float32{3.6, 1.4, 5.0, 0.2},
[]float32{3.9, 1.7, 5.4, 0.4},
[]float32{3.4, 1.4, 4.6, 0.3},
[]float32{3.4, 1.5, 5.0, 0.2},
[]float32{2.9, 1.4, 4.4, 0.2},
[]float32{3.1, 1.5, 4.9, 0.1},
[]float32{3.7, 1.5, 5.4, 0.2},
[]float32{3.4, 1.6, 4.8, 0.2},
[]float32{3.0, 1.4, 4.8, 0.1},
[]float32{3.0, 1.1, 4.3, 0.1},
[]float32{4.0, 1.2, 5.8, 0.2},
[]float32{4.4, 1.5, 5.7, 0.4},
[]float32{3.9, 1.3, 5.4, 0.4},
[]float32{3.5, 1.4, 5.1, 0.3},
[]float32{3.8, 1.7, 5.7, 0.3},
[]float32{3.8, 1.5, 5.1, 0.3},
[]float32{3.4, 1.7, 5.4, 0.2},
[]float32{3.7, 1.5, 5.1, 0.4},
[]float32{3.6, 1.0, 4.6, 0.2},
[]float32{3.3, 1.7, 5.1, 0.5},
[]float32{3.4, 1.9, 4.8, 0.2},
[]float32{3.0, 1.6, 5.0, 0.2},
[]float32{3.4, 1.6, 5.0, 0.4},
[]float32{3.5, 1.5, 5.2, 0.2},
[]float32{3.4, 1.4, 5.2, 0.2},
[]float32{3.2, 1.6, 4.7, 0.2},
[]float32{3.1, 1.6, 4.8, 0.2},
[]float32{3.4, 1.5, 5.4, 0.4},
[]float32{4.1, 1.5, 5.2, 0.1},
[]float32{4.2, 1.4, 5.5, 0.2},
[]float32{3.1, 1.5, 4.9, 0.2},
[]float32{3.2, 1.2, 5.0, 0.2},
[]float32{3.5, 1.3, 5.5, 0.2},
[]float32{3.6, 1.4, 4.9, 0.1},
[]float32{3.0, 1.3, 4.4, 0.2},
[]float32{3.4, 1.5, 5.1, 0.2},
[]float32{3.5, 1.3, 5.0, 0.3},
[]float32{2.3, 1.3, 4.5, 0.3},
[]float32{3.2, 1.3, 4.4, 0.2},
[]float32{3.5, 1.6, 5.0, 0.6},
[]float32{3.8, 1.9, 5.1, 0.4},
[]float32{3.0, 1.4, 4.8, 0.3},
[]float32{3.8, 1.6, 5.1, 0.2},
[]float32{3.2, 1.4, 4.6, 0.2},
[]float32{3.7, 1.5, 5.3, 0.2},
[]float32{3.3, 1.4, 5.0, 0.2},
[]float32{3.2, 4.7, 7.0, 1.4},
[]float32{3.2, 4.5, 6.4, 1.5},
[]float32{3.1, 4.9, 6.9, 1.5},
[]float32{2.3, 4.0, 5.5, 1.3},
[]float32{2.8, 4.6, 6.5, 1.5},
[]float32{2.8, 4.5, 5.7, 1.3},
[]float32{3.3, 4.7, 6.3, 1.6},
[]float32{2.4, 3.3, 4.9, 1.0},
[]float32{2.9, 4.6, 6.6, 1.3},
[]float32{2.7, 3.9, 5.2, 1.4},
[]float32{2.0, 3.5, 5.0, 1.0},
[]float32{3.0, 4.2, 5.9, 1.5},
[]float32{2.2, 4.0, 6.0, 1.0},
[]float32{2.9, 4.7, 6.1, 1.4},
[]float32{2.9, 3.6, 5.6, 1.3},
[]float32{3.1, 4.4, 6.7, 1.4},
[]float32{3.0, 4.5, 5.6, 1.5},
[]float32{2.7, 4.1, 5.8, 1.0},
[]float32{2.2, 4.5, 6.2, 1.5},
[]float32{2.5, 3.9, 5.6, 1.1},
[]float32{3.2, 4.8, 5.9, 1.8},
[]float32{2.8, 4.0, 6.1, 1.3},
[]float32{2.5, 4.9, 6.3, 1.5},
[]float32{2.8, 4.7, 6.1, 1.2},
[]float32{2.9, 4.3, 6.4, 1.3},
[]float32{3.0, 4.4, 6.6, 1.4},
[]float32{2.8, 4.8, 6.8, 1.4},
[]float32{3.0, 5.0, 6.7, 1.7},
[]float32{2.9, 4.5, 6.0, 1.5},
[]float32{2.6, 3.5, 5.7, 1.0},
[]float32{2.4, 3.8, 5.5, 1.1},
[]float32{2.4, 3.7, 5.5, 1.0},
[]float32{2.7, 3.9, 5.8, 1.2},
[]float32{2.7, 5.1, 6.0, 1.6},
[]float32{3.0, 4.5, 5.4, 1.5},
[]float32{3.4, 4.5, 6.0, 1.6},
[]float32{3.1, 4.7, 6.7, 1.5},
[]float32{2.3, 4.4, 6.3, 1.3},
[]float32{3.0, 4.1, 5.6, 1.3},
[]float32{2.5, 4.0, 5.5, 1.3},
[]float32{2.6, 4.4, 5.5, 1.2},
[]float32{3.0, 4.6, 6.1, 1.4},
[]float32{2.6, 4.0, 5.8, 1.2},
[]float32{2.3, 3.3, 5.0, 1.0},
[]float32{2.7, 4.2, 5.6, 1.3},
[]float32{3.0, 4.2, 5.7, 1.2},
[]float32{2.9, 4.2, 5.7, 1.3},
[]float32{2.9, 4.3, 6.2, 1.3},
[]float32{2.5, 3.0, 5.1, 1.1},
[]float32{2.8, 4.1, 5.7, 1.3},
[]float32{3.3, 6.0, 6.3, 2.5},
[]float32{2.7, 5.1, 5.8, 1.9},
[]float32{3.0, 5.9, 7.1, 2.1},
[]float32{2.9, 5.6, 6.3, 1.8},
[]float32{3.0, 5.8, 6.5, 2.2},
[]float32{3.0, 6.6, 7.6, 2.1},
[]float32{2.5, 4.5, 4.9, 1.7},
[]float32{2.9, 6.3, 7.3, 1.8},
[]float32{2.5, 5.8, 6.7, 1.8},
[]float32{3.6, 6.1, 7.2, 2.5},
[]float32{3.2, 5.1, 6.5, 2.0},
[]float32{2.7, 5.3, 6.4, 1.9},
[]float32{3.0, 5.5, 6.8, 2.1},
[]float32{2.5, 5.0, 5.7, 2.0},
[]float32{2.8, 5.1, 5.8, 2.4},
[]float32{3.2, 5.3, 6.4, 2.3},
[]float32{3.0, 5.5, 6.5, 1.8},
[]float32{3.8, 6.7, 7.7, 2.2},
[]float32{2.6, 6.9, 7.7, 2.3},
[]float32{2.2, 5.0, 6.0, 1.5},
[]float32{3.2, 5.7, 6.9, 2.3},
[]float32{2.8, 4.9, 5.6, 2.0},
[]float32{2.8, 6.7, 7.7, 2.0},
[]float32{2.7, 4.9, 6.3, 1.8},
[]float32{3.3, 5.7, 6.7, 2.1},
[]float32{3.2, 6.0, 7.2, 1.8},
[]float32{2.8, 4.8, 6.2, 1.8},
[]float32{3.0, 4.9, 6.1, 1.8},
[]float32{2.8, 5.6, 6.4, 2.1},
[]float32{3.0, 5.8, 7.2, 1.6},
[]float32{2.8, 6.1, 7.4, 1.9},
[]float32{3.8, 6.4, 7.9, 2.0},
[]float32{2.8, 5.6, 6.4, 2.2},
[]float32{2.8, 5.1, 6.3, 1.5},
[]float32{2.6, 5.6, 6.1, 1.4},
[]float32{3.0, 6.1, 7.7, 2.3},
[]float32{3.4, 5.6, 6.3, 2.4},
[]float32{3.1, 5.5, 6.4, 1.8},
[]float32{3.0, 4.8, 6.0, 1.8},
[]float32{3.1, 5.4, 6.9, 2.1},
[]float32{3.1, 5.6, 6.7, 2.4},
[]float32{3.1, 5.1, 6.9, 2.3},
[]float32{2.7, 5.1, 5.8, 1.9},
[]float32{3.2, 5.9, 6.8, 2.3},
[]float32{3.3, 5.7, 6.7, 2.5},
[]float32{3.0, 5.2, 6.7, 2.3},
[]float32{2.5, 5.0, 6.3, 1.9},
[]float32{3.0, 5.2, 6.5, 2.0},
[]float32{3.4, 5.4, 6.2, 2.3},
[]float32{3.0, 5.1, 5.9, 1.8},
}

var XNames = []string{"Sepal.Width", "Petal.Length", "Sepal.Length", "Petal.Width"}
Expand Down
10 changes: 5 additions & 5 deletions rf.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ func writePred(w io.Writer, prediction []string) error {
return wtr.Flush()
}

func parseCSV(r io.Reader) ([][]float64, []string, error) {
func parseCSV(r io.Reader) ([][]float32, []string, error) {
reader := csv.NewReader(r)

var (
X [][]float64
X [][]float32
Y []string
)

Expand All @@ -183,13 +183,13 @@ func parseCSV(r io.Reader) ([][]float64, []string, error) {

Y = append(Y, row[0])

var rowVal []float64
var rowVal []float32
for _, val := range row[1:] { // data starts in 2nd column
fv, err := strconv.ParseFloat(val, 64)
fv, err := strconv.ParseFloat(val, 32)
if err != nil {
return X, Y, err
}
rowVal = append(rowVal, fv)
rowVal = append(rowVal, float32(fv))
}
X = append(X, rowVal)
}
Expand Down
Loading