最小二乘回归树Golang多线程实现

来源:互联网 发布:java迭代器模式 编辑:程序博客网 时间:2024/06/05 03:37

回归树是提升树的基础, 算法之前已经说过了, 现在用Golang实现了一个多线程的版本, 由于没有用矩阵库, 看起来会比较啰嗦. 基本思想是通过一个函数收集所有叶子节点的地址, 然后并行地对叶子节点进行展开.

// Command main fits a least-squares (MSE) regression tree to the samples in
// data.txt. The tree is grown one level per round: all current leaves are
// collected and then expanded concurrently by a fixed number of goroutines.
package main

import (
	"bufio"
	"fmt"
	"io"
	"math"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
)

// maxLineBytes is the longest input line Loader accepts.
const maxLineBytes = 1 << 20

// Vector is a list of float32 values. It implements sort.Interface so that a
// Vector can be sorted in ascending order with sort.Sort.
type Vector []float32

// Len returns the number of elements in v.
func (v Vector) Len() int {
	return len(v)
}

// Less reports whether element i is smaller than element j.
func (v Vector) Less(i, j int) bool {
	return v[i] < v[j]
}

// Swap exchanges elements i and j.
func (v Vector) Swap(i, j int) {
	v[i], v[j] = v[j], v[i]
}

// Mean returns the arithmetic mean of v. For an empty vector the result is
// NaN (0/0 in floating point).
func (v Vector) Mean() (res float32) {
	for _, value := range v {
		res += value
	}
	return res / float32(v.Len())
}

// Data is a set of samples; each element is the feature vector of one sample.
type Data []Vector

// MSETree is a node of a binary regression tree.
type MSETree struct {
	data   Data     // all samples; shared by every node of the tree
	index  []int    // indices into data of the samples that reach this node
	left   *MSETree // subtree for samples with feature j <= s
	right  *MSETree // subtree for samples with feature j > s
	c      float32  // value predicted when this node is a leaf
	isLeaf bool     // true while the node has not been split
	j      int      // index of the feature this node splits on
	s      float32  // threshold applied to feature j
}

// Predict returns the tree's output for point by walking from node down to a
// leaf. A sample whose feature equals the threshold goes left, which matches
// the partition rule used by split during training.
func (node *MSETree) Predict(point Vector) float32 {
	p := node
	for !p.isLeaf {
		if point[p.j] <= p.s {
			p = p.left
		} else {
			p = p.right
		}
	}
	return p.c
}

// Loader reads comma-separated samples from fileName. On every line the last
// field is the target value and all preceding fields are features. It returns
// the feature vectors, the targets, and the first error encountered.
func Loader(fileName string) (Data, Vector, error) {
	f, err := os.Open(fileName)
	if err != nil {
		return nil, nil, err
	}
	defer f.Close()
	return parseRows(f)
}

// parseRows decodes comma-separated samples from r. Blank lines are skipped.
// A field that is not a number, or a row whose field count differs from the
// first row's, is reported as an error instead of being silently read as 0.
func parseRows(r io.Reader) (Data, Vector, error) {
	data := make(Data, 0)
	z := make(Vector, 0)

	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, 64*1024), maxLineBytes)
	lineNo := 0
	for scanner.Scan() {
		lineNo++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		fields := strings.Split(line, ",")
		last := len(fields) - 1
		// Ragged rows would cause an out-of-range index during training.
		if len(data) > 0 && last != len(data[0]) {
			return nil, nil, fmt.Errorf("line %d: got %d fields, want %d",
				lineNo, len(fields), len(data[0])+1)
		}
		row := make(Vector, 0, last)
		for i := 0; i < last; i++ {
			value, err := parseField(fields[i])
			if err != nil {
				return nil, nil, fmt.Errorf("line %d, field %d: %w", lineNo, i+1, err)
			}
			row = append(row, value)
		}
		target, err := parseField(fields[last])
		if err != nil {
			return nil, nil, fmt.Errorf("line %d, field %d: %w", lineNo, last+1, err)
		}
		data = append(data, row)
		z = append(z, target)
	}
	if err := scanner.Err(); err != nil {
		return nil, nil, err
	}
	return data, z, nil
}

// parseField converts one CSV field, ignoring surrounding spaces, to float32.
func parseField(field string) (float32, error) {
	value, err := strconv.ParseFloat(strings.TrimSpace(field), 32)
	if err != nil {
		return 0, err
	}
	return float32(value), nil
}

// getSplitters enumerates the candidate thresholds for feature j over the
// samples listed in slice: the midpoints between consecutive values of the
// sorted feature column. It returns nil when fewer than two samples are given.
func getSplitters(data Data, slice []int, j int) Vector {
	if len(slice) < 2 {
		return nil
	}
	arr := make(Vector, len(slice))
	for i, idx := range slice {
		arr[i] = data[idx][j]
	}
	sort.Sort(arr)
	spliters := make(Vector, len(arr)-1)
	for i := range spliters {
		spliters[i] = (arr[i] + arr[i+1]) / 2
	}
	return spliters
}

// split partitions the samples in slice by feature j and threshold s: samples
// with data[idx][j] <= s go to slice1, all others to slice2.
func split(data Data, slice []int, j int, s float32) (slice1, slice2 []int) {
	for _, idx := range slice {
		if data[idx][j] <= s {
			slice1 = append(slice1, idx)
		} else {
			slice2 = append(slice2, idx)
		}
	}
	return slice1, slice2
}

// sqErr returns the mean c of the targets selected by slice and the sum of
// squared deviations errs from that mean. Both are NaN when slice is empty.
func sqErr(z Vector, slice []int) (errs, c float32) {
	l := len(slice)
	for _, idx := range slice {
		c += z[idx]
	}
	c = c / float32(l)
	for _, idx := range slice {
		errs += (z[idx] - c) * (z[idx] - c)
	}
	return errs, c
}

// bestSpliter searches every feature and every candidate threshold for the
// split of slice that minimises the summed squared error of the two halves.
// It returns the chosen feature and threshold, the mean target of each half,
// and the sample indices of each half. When no split separates the samples
// into two non-empty halves, both returned slices are nil.
func bestSpliter(data Data, z Vector, slice []int) (bestJ int, bestS float32,
	bestC1, bestC2 float32,
	bestSlice1, bestSlice2 []int) {
	if len(slice) < 2 || len(data) == 0 {
		return
	}
	dim := len(data[0])
	// Start from +Inf so a valid split is accepted however large its error.
	minErr := float32(math.Inf(1))
	for j := 0; j < dim; j++ {
		spliters := getSplitters(data, slice, j)
		for k, s := range spliters {
			// Candidates are sorted; a repeated threshold gives the same partition.
			if k > 0 && s == spliters[k-1] {
				continue
			}
			slice1, slice2 := split(data, slice, j, s)
			// A one-sided partition is not a split (and its error is NaN).
			if len(slice1) == 0 || len(slice2) == 0 {
				continue
			}
			err1, c1 := sqErr(z, slice1)
			err2, c2 := sqErr(z, slice2)
			if total := err1 + err2; total < minErr {
				minErr = total
				bestJ, bestS = j, s
				bestC1, bestC2 = c1, c2
				bestSlice1, bestSlice2 = slice1, slice2
			}
		}
	}
	return
}

// initTree builds the root: a single leaf that holds every sample and
// predicts the mean of z.
func initTree(data Data, z Vector) *MSETree {
	basicSlice := make([]int, len(z))
	for i := range basicSlice {
		basicSlice[i] = i
	}
	return &MSETree{
		data:   data,
		index:  basicSlice,
		c:      z.Mean(),
		isLeaf: true,
	}
}

// sprout expands a leaf into an internal node with two leaf children. The
// node is left untouched when it holds fewer than two samples or when no
// split separates its samples.
func sprout(node *MSETree, z Vector) {
	if len(node.index) < 2 {
		return
	}
	j, s, c1, c2, slice1, slice2 := bestSpliter(node.data, z, node.index)
	if len(slice1) == 0 || len(slice2) == 0 {
		return
	}
	node.j = j
	node.s = s
	node.left = &MSETree{data: node.data, index: slice1, c: c1, isLeaf: true}
	node.right = &MSETree{data: node.data, index: slice2, c: c2, isLeaf: true}
	node.isLeaf = false
}

// TotalError returns the mean squared error of tree's predictions for data
// against the targets z. The result is NaN when z is empty.
func TotalError(tree *MSETree, data Data, z Vector) float32 {
	var errs, t float32
	for i := range z {
		t = tree.Predict(data[i]) - z[i]
		errs += t * t
	}
	return errs / float32(len(z))
}

// leafCollector returns all leaves below node, left to right.
func leafCollector(node *MSETree) []*MSETree {
	if node.isLeaf {
		return []*MSETree{node}
	}
	return append(leafCollector(node.left), leafCollector(node.right)...)
}

// Grouth grows the tree by one level: it collects the current leaves, divides
// them into contiguous chunks and expands each chunk in its own goroutine.
// coreNum is the number of goroutines requested; values below 1 are treated
// as 1 and no more goroutines than leaves are started.
func Grouth(root *MSETree, z Vector, coreNum int) {
	leafs := leafCollector(root)
	l := len(leafs)
	if coreNum < 1 {
		coreNum = 1
	}
	if coreNum > l {
		coreNum = l
	}
	var wg sync.WaitGroup
	wg.Add(coreNum)
	for i := 0; i < coreNum; i++ {
		// Each goroutine owns a disjoint chunk of leaves, so no locking is needed.
		go func(leafsPart []*MSETree) {
			defer wg.Done()
			for _, p := range leafsPart {
				sprout(p, z)
			}
		}(leafs[i*l/coreNum : l*(i+1)/coreNum])
	}
	wg.Wait()
	fmt.Println("grouth done.")
}

func main() {
	data, z, err := Loader("data.txt")
	if err != nil {
		fmt.Fprintln(os.Stderr, "loading data.txt:", err)
		os.Exit(1)
	}
	if len(z) == 0 {
		fmt.Fprintln(os.Stderr, "data.txt contains no samples")
		os.Exit(1)
	}
	root := initTree(data, z)
	for i := 0; i < 6; i++ {
		Grouth(root, z, 8)
		fmt.Println("Total error:", TotalError(root, data, z))
	}
}

输出结果:

grouth done.
Total error: 0.096614435
grouth done.
Total error: 0.036780898
grouth done.
Total error: 0.020771468
grouth done.
Total error: 0.010595326
grouth done.
Total error: 0.005394744
grouth done.
Total error: 0.0031773413