lda_spark代码

来源:互联网 发布:windows server 下载 编辑:程序博客网 时间:2024/06/10 22:29

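在单机（local）模式下测试时，小规模语料可以正常运行；但一旦文档数量或词汇量较大，就会抛出 java.lang.OutOfMemoryError（Java heap space）错误。主要原因是：每轮迭代中每个文档都会生成一个 nLex × nClass 的稠密 beta 统计矩阵，并全部 collect 到 driver 端，内存占用约为“文档数 × 词汇量 × 主题数”个 double。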

还有一个问题：如果给每个文档编上一个编号（即它在语料中的位置），就能把更多操作并行化。目前主要卡在 lda_ppl() / lda_lik_spark() 中：计算似然时需要按文档顺序取 gamma 矩阵中对应的行，因此这里把这部分计算集中在 driver 程序中串行执行。

结果（每行对应一个主题及其 10 个代表词）。可以看到，the、of、a 等停用词在几乎所有主题中都占主导，说明训练前应先对语料去除停用词：

topic: of the a is that an for in at plate
topic: the of and a to in this at with on
topic: the to a and of flow in is on be
topic: the to a for that are by in theory on
topic: the a of in to at boundary as on an
topic: in of and for to is on are flow layer
topic: the in a is and that flow by layer this
topic: the in of a that this is are to it
topic: of to the a and at on flow for an
topic: the of is for a flow are to in and
topic: of the a is for flow on that this to
topic: and to of in for a with be is by
topic: of a to is for are in at the layer
topic: the of and a in to for flow on an
topic: of and for are with by is layer flow boundary
topic: of is in for the to are that on a
topic: the a of an that is on at flow by
topic: the of is in that and to at be an
topic: the and of for in on that by is a
topic: of the in is a and that be are by
topic: the of and a is with an theory for be
topic: of a in and is to the flow for are
topic: and in to of a are the be at by
topic: the of a for flow in that at this theory
topic: of the to and are with in is flow a
topic: the a is to of flow and on for be
topic: and to a is for that flow with at are
topic: the of and in flow to a is for boundary
topic: the of a to for is and flow by as
topic: of in is a with be an layer boundary are
topic: the a in is and with for of to be
topic: a in to of are flow is as an for
topic: of the and a flow for on be theory can
topic: of the and in flow at boundary a for layer
topic: the of flow are a as an for and with
topic: a to in for flow by and at of is
topic: the to and of a in for an this boundary
topic: the of in to are on that be a this
topic: the and of in for is a flow by on
topic: the of and in a on to are for is
topic: the a in of for that at by boundary to
topic: the of a to and is by for at theory
topic: the of to is for that with on theory by
topic: the in is with on are a layer at that
topic: the of in and to flow is boundary be a
topic: of the is on flow with a for that layer
topic: the of in a that on is and an to
topic: and of a to by in theory the at as
topic: the is of for with flow be by in as
topic: the of and a for are in this be layer 


代码

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
import scala.io.Source
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.io.{BufferedWriter, FileWriter}
import scala.util.control.Breaks._

// NOTE: giving Document an `index` attribute (its position in the corpus) would let the
// likelihood in lda_ppl / lda_lik_spark be computed on the executors. Today it relies on a
// driver-side copy of the corpus that is in the same order as the rows of the gamma matrix.
object LdaSpark2 {

  /** When true, word ids in the input start at 0; otherwise they start at 1 and are shifted down by one. */
  var zeroBegin = false

  /**
   * When true, the first token of every input line is skipped before the `id:count` pairs are
   * parsed (presumably a length field — confirm against the data format).
   *
   * Both flags live in this object. In cluster mode each executor has its own copy of the object,
   * so the training path reads them on the driver and ships the values inside closures.
   */
  var lenBegin = false

  /**
   * Entry point: trains an LDA model on a local Spark context and writes the result files.
   *
   * @param args optional overrides: `args(0)` = number of topics (default 50),
   *             `args(1)` = path of the training file (default: the author's local test file)
   */
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "LdaSpark", "/Users/heruilong/Downloads/spark-0.9.1-bin-hadoop2",
      List("target/scala-2.10/ldaspark_2.10-1.0.jar"))
    try {
      val nClass = if (args.length > 0) args(0).toInt else 50
      val documentFile =
        if (args.length > 1) args(1)
        else "/Users/heruilong/Downloads/spark-0.9.1-bin-hadoop2/apps/lda/data/testTrain.txt"
      val emmax = 10
      val demmax = 2
      val epsilon = 1.0e-4
      lda(sc, documentFile, nClass, emmax, demmax, epsilon)
      println("done")
    } finally {
      // Release the Spark context even if training fails.
      sc.stop()
    }
  }

  /**
   * Reads the corpus, learns alpha/beta and writes them to `result_alpha.txt` / `result_beta.txt`.
   *
   * @param nClass  number of latent topics
   * @param emmax   maximum number of outer (corpus-level) EM iterations
   * @param demmax  maximum number of inner (per-document) variational iterations
   * @param epsilon relative perplexity change below which training is considered converged
   */
  def lda(sc: SparkContext, documentFilePath: String, nClass: Int = 50, emmax: Int = 100,
      demmax: Int = 20, epsilon: Double = 1.0e-4): Unit = {
    val fileAP = "result_alpha.txt"
    val fileBP = "result_beta.txt"
    val (documentRDD, dLenMax, nLex) = readDocuments(documentFilePath, sc)
    val (alpha_result, beta_result) =
      lda_learn(documentRDD, documentFilePath, nClass, nLex, dLenMax, emmax, demmax, epsilon)
    lda_write(fileAP, fileBP, alpha_result, beta_result, nClass, nLex)
  }

  /**
   * Variational EM for LDA.
   *
   * Each outer iteration does the following:
   *  1. Runs the per-document E-step on the executors.
   *  2. Updates alpha by Newton-Raphson and beta by column normalisation.
   *  3. Checks the relative change in perplexity.
   *
   * If convergence happens "too early" (before iteration 5), training restarts from a fresh random
   * initialisation.
   *
   * @return (alpha, beta) where alpha has nClass entries and beta is nLex x nClass
   */
  def lda_learn(documentRDD: RDD[Document], documentFilePath: String, nClass: Int, nLex: Int,
      dLenMax: Int, emmax: Int, demmax: Int, epsilon: Double)
      : (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    val numberOfDocument = documentRDD.count.toInt
    var (alpha, beta) = initialAlphaAndBeta(nClass, nLex)

    println("Number of documents      = " + numberOfDocument)
    println("Number of words          = " + nLex)
    println("Number of latent Classes = " + nClass)
    println("Number of outer EM iteration = " + emmax)
    println("Number of inner EM iteration = " + demmax)
    println("Convergence threshold        = " + epsilon)

    // Neither value changes between iterations, so compute/collect them once.
    val ppl_n = documentRDD.map(doc => lda_ppl_spark(doc)).reduce(_ + _)
    val documentArray = documentRDD.collect()

    var pppl = 0.0
    var t = 0
    var result: Option[(ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]])] = None
    while (result.isEmpty && t < emmax) {
      printf("iteration %d/%d..\n", t + 1, emmax)

      // VB E-step (parallel): per-document gammas plus summed word/topic statistics.
      val (gammas, betas) =
        parallelEStep(documentRDD, alpha, beta, numberOfDocument, nClass, nLex, dLenMax, demmax)
      gammas.headOption.foreach(g => println(g mkString ","))

      // VB M-step: Newton-Raphson for alpha, normalised statistics for beta.
      alpha = newton_alpha(alpha, gammas, numberOfDocument, nClass, 0)._1
      beta = normalize_matrix_col(beta, betas, nLex, nClass)._1

      val ppl = lda_ppl(ppl_n, documentArray, beta, gammas, numberOfDocument, nClass)
      if (t > 1 && Math.abs((ppl - pppl) / pppl) < epsilon) {
        if (t < 5) {
          printf("\nearly convergence. restarting..\n")
          result = Some(lda_learn(documentRDD, documentFilePath, nClass, nLex, dLenMax, emmax, demmax, epsilon))
        } else {
          println("converged")
          result = Some((alpha, beta))
        }
      }
      pppl = ppl
      t += 1
    }
    result.getOrElse((alpha, beta))
  }

  /**
   * Runs the variational E-step for every document on the executors.
   *
   * Only per-partition aggregates are shipped back: the gammas of the partition's documents and a
   * single nLex x nClass matrix of summed q * count. Shipping one dense matrix per document is
   * what previously exhausted driver memory.
   *
   * @return (gammas, betas): gammas has one row per document in RDD order; betas is nLex x nClass
   */
  private def parallelEStep(documentRDD: RDD[Document], alpha: ArrayBuffer[Double],
      beta: ArrayBuffer[ArrayBuffer[Double]], numberOfDocument: Int, nClass: Int, nLex: Int,
      dLenMax: Int, demmax: Int): (ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    val idOffset = currentIdOffset // read zeroBegin on the driver, ship the value
    val partitionStats = documentRDD.mapPartitions { docs =>
      val localGammas = ArrayBuffer[ArrayBuffer[Double]]()
      val localBetas = ArrayBuffer.fill(nLex, nClass)(0.0)
      docs.foreach { doc =>
        val (gamma, q) = inferDocument(doc, alpha, beta, nClass, dLenMax, demmax, idOffset)
        localGammas += gamma
        addWordTopicCounts(localBetas, doc, q, nClass, idOffset)
      }
      Iterator((localGammas, localBetas))
    }.collect()

    val gammas = ArrayBuffer.fill(numberOfDocument, nClass)(0.0)
    val betas = ArrayBuffer.fill(nLex, nClass)(0.0)
    // collect() keeps partition order, so concatenating the partitions' gammas gives corpus order.
    var docIndex = 0
    for ((localGammas, localBetas) <- partitionStats) {
      for (gamma <- localGammas) {
        accum_gammas(gammas, gamma, docIndex, nClass)
        docIndex += 1
      }
      accum_betas(betas, localBetas, nClass, nLex)
    }
    (gammas, betas)
  }

  /** Parses one `id:count id:count ...` line (no leading length token) into a Document. */
  def parseDocument(line: String): Document = parseLine(line, hasLengthPrefix = false)

  /** Splits a line on whitespace; when `hasLengthPrefix` is set, the first token is dropped. */
  private def parseLine(line: String, hasLengthPrefix: Boolean): Document = {
    val tokens = line.trim.split("\\s+")
    val pairs = if (hasLengthPrefix) tokens.tail else tokens
    val idArray = pairs.map(pair => pair.split(":")(0).toInt)
    val cntArray = pairs.map(pair => pair.split(":")(1).toDouble)
    new Document(pairs.length, idArray, cntArray)
  }

  /**
   * Loads the corpus as an RDD of Documents. The RDD is cached because every EM iteration
   * rescans it.
   *
   * Blank lines are skipped.
   *
   * @return (documentRDD, dMaxLen = largest number of distinct words in a document,
   *          nLex = vocabulary size derived from the largest word id)
   */
  def readDocuments(documentFile: String, sc: SparkContext): (RDD[Document], Int, Int) = {
    println("--------------------------------------------------" + zeroBegin)
    // Read the flag on the driver so executors do not depend on their own copy of the object.
    val hasLengthPrefix = lenBegin
    val documentRDD = sc.textFile(documentFile)
      .filter(line => line.trim.nonEmpty)
      .map(line => parseLine(line, hasLengthPrefix))
      .cache()

    // Reduce on the cluster instead of collecting every length/id to the driver.
    val dMaxLen = documentRDD.map(doc => doc.len).reduce((a, b) => math.max(a, b))
    val maxId = documentRDD.flatMap(doc => doc.id).reduce((a, b) => math.max(a, b))
    println("DEBUG:" + "nLex = " + maxId)
    val nLex = if (zeroBegin) maxId + 1 else maxId
    println("DEBUG:" + "nLex2 = " + nLex)
    (documentRDD, dMaxLen, nLex)
  }

  /**
   * Random initial parameters.
   *
   * - alpha: nClass uniform draws, normalised to sum to 1 and sorted in descending order.
   * - beta: nLex x nClass; each column (topic) holds random weights normalised to sum to 1.
   */
  def initialAlphaAndBeta(nClass: Int, nLex: Int): (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    def uniform(): Double = Random.nextInt(Int.MaxValue).toDouble / Int.MaxValue

    val rawAlpha = ArrayBuffer.fill(nClass)(uniform())
    val z = rawAlpha.sum
    val alpha = rawAlpha.map(_ / z).sortWith(_ > _)

    val beta = ArrayBuffer.fill(nLex, nClass)(0.0)
    for (k <- 0 until nClass) {
      var tot = 0.0
      for (w <- 0 until nLex) {
        val v = uniform() * 10
        beta(w)(k) = v
        tot += v
      }
      for (w <- 0 until nLex) beta(w)(k) = beta(w)(k) / tot
    }
    (alpha, beta)
  }

  /** Copies `gamma` into row `n` of `gammas` (in place) and returns `gammas`. */
  def accum_gammas(gammas: ArrayBuffer[ArrayBuffer[Double]], gamma: ArrayBuffer[Double],
      n: Int, K: Int): ArrayBuffer[ArrayBuffer[Double]] = {
    for (k <- 0 until K) gammas(n)(k) = gamma(k)
    gammas
  }

  /** Adds the nLex x K matrix `betaPerDoc` element-wise into `betas` (in place) and returns `betas`. */
  def accum_betas(betas: ArrayBuffer[ArrayBuffer[Double]], betaPerDoc: ArrayBuffer[ArrayBuffer[Double]],
      K: Int, nLex: Int): ArrayBuffer[ArrayBuffer[Double]] = {
    for (i <- 0 until nLex; k <- 0 until K) betas(i)(k) += betaPerDoc(i)(k)
    betas
  }

  /** Writes `src` with each row scaled to sum to 1 into `dst` (in place); returns (dst, src). */
  def normalize_matrix_row(dst: ArrayBuffer[ArrayBuffer[Double]], src: ArrayBuffer[ArrayBuffer[Double]],
      rows: Int, cols: Int): (ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    for (i <- 0 until rows) {
      val z = (0 until cols).foldLeft(0.0)((acc, j) => acc + src(i)(j))
      for (j <- 0 until cols) dst(i)(j) = src(i)(j) / z
    }
    (dst, src)
  }

  /** Writes `src` with each column scaled to sum to 1 into `dst` (in place); returns (dst, src). */
  def normalize_matrix_col(dst: ArrayBuffer[ArrayBuffer[Double]], src: ArrayBuffer[ArrayBuffer[Double]],
      rows: Int, cols: Int): (ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    for (j <- 0 until cols) {
      val z = (0 until rows).foldLeft(0.0)((acc, i) => acc + src(i)(j))
      for (i <- 0 until rows) dst(i)(j) = src(i)(j) / z
    }
    (dst, src)
  }

  /** Writes alpha to `fileAP` (one line) and beta to `fileBP` (nLex rows of nClass values). */
  def lda_write(fileAP: String, fileBP: String, alpha: ArrayBuffer[Double],
      beta: ArrayBuffer[ArrayBuffer[Double]], nClass: Int, nLex: Int): Unit = {
    printf("writing model..\n")
    write_vector(fileAP, alpha, nClass)
    write_matrix(fileBP, beta, nLex, nClass)
    printf("done.\n")
  }

  /** Echoes `vector` and overwrites `filePath` with its space-separated values (`n` is unused). */
  def write_vector(filePath: String, vector: ArrayBuffer[Double], n: Int): Unit = {
    // Explicit parentheses: `"..." + vector mkString " "` would apply mkString to the whole String.
    println("vector: " + vector.mkString(" "))
    val fw = new FileWriter(filePath, false)
    try fw.write(vector.mkString(" ")) finally fw.close()
  }

  /** Overwrites `filePath` with `rows` lines, each holding `cols` values followed by a space. */
  def write_matrix(filePath: String, matrix: ArrayBuffer[ArrayBuffer[Double]], rows: Int, cols: Int): Unit = {
    val out = new BufferedWriter(new FileWriter(filePath, false))
    try {
      for (i <- 0 until rows) {
        for (j <- 0 until cols) out.write(matrix(i)(j).toString + " ")
        out.write("\n")
      }
    } finally {
      out.close()
    }
  }

  /**
   * Per-document variational inference.
   *
   * @param K     number of topics
   * @param emmax maximum number of inner iterations (the corpus-level `demmax`)
   * @return (gamma, q, betaPerDoc):
   *         - gamma has K entries.
   *         - q is dLenMax x K; only the first `document.len` rows are meaningful.
   *         - betaPerDoc is nLex x K and holds this document's q * count statistics.
   */
  def vbem(document: Document, alpha: ArrayBuffer[Double], beta: ArrayBuffer[ArrayBuffer[Double]],
      K: Int, dLenMax: Int, emmax: Int, nLex: Int)
      : (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    val (gamma, q) = inferDocument(document, alpha, beta, K, dLenMax, emmax, currentIdOffset)
    (gamma, q, calculate_beta(document, q, K, nLex))
  }

  /**
   * Core of the per-document VB iteration.
   *
   * Alternates two updates until the expected topic counts `nt` converge (relative L2 change
   * < 1e-2) or `maxIter` passes are done:
   *  - q(l)(k) ∝ beta(word_l)(k) * exp(digamma(alpha(k) + nt(k)))
   *  - nt(k) = Σ_l q(l)(k) * cnt(l)
   *
   * @param idOffset value subtracted from a word id to obtain its row in `beta`
   * @return (gamma = alpha + nt, q)
   */
  private def inferDocument(document: Document, alpha: ArrayBuffer[Double],
      beta: ArrayBuffer[ArrayBuffer[Double]], K: Int, dLenMax: Int, maxIter: Int, idOffset: Int)
      : (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    val L = document.len
    val ap = ArrayBuffer.fill(K)(0.0)
    val nt = ArrayBuffer.fill(K)(L.toDouble / K)
    val pnt = ArrayBuffer.fill(K)(0.0)
    val q = ArrayBuffer.fill(dLenMax, K)(0.0)

    var iter = 0
    var done = false
    while (!done && iter < maxIter) {
      for (k <- 0 until K) ap(k) = Math.exp(digamma(alpha(k) + nt(k)))

      // Topic responsibilities for each distinct word of the document, normalised per word.
      for (l <- 0 until L) {
        val wordRow = beta(document.id(l) - idOffset)
        var z = 0.0
        for (k <- 0 until K) {
          q(l)(k) = wordRow(k) * ap(k)
          z += q(l)(k)
        }
        for (k <- 0 until K) q(l)(k) = q(l)(k) / z
      }

      for (k <- 0 until K) {
        var z = 0.0
        for (l <- 0 until L) z += q(l)(k) * document.cnt(l)
        nt(k) = z
      }

      if (iter > 0 && converged(nt, pnt, K, 1.0e-2)) done = true
      else {
        for (k <- 0 until K) pnt(k) = nt(k)
        iter += 1
      }
    }
    (ArrayBuffer.tabulate(K)(k => alpha(k) + nt(k)), q)
  }

  /** This document's word/topic statistics as a fresh nLex x K matrix of q(l)(k) * cnt(l). */
  def calculate_beta(document: Document, q: ArrayBuffer[ArrayBuffer[Double]], K: Int,
      nLex: Int): ArrayBuffer[ArrayBuffer[Double]] = {
    val beta = ArrayBuffer.fill(nLex, K)(0.0)
    addWordTopicCounts(beta, document, q, K, currentIdOffset)
    beta
  }

  /** Adds q(l)(k) * cnt(l) into `target(id_l - idOffset)(k)` for every word l of the document. */
  private def addWordTopicCounts(target: ArrayBuffer[ArrayBuffer[Double]], document: Document,
      q: ArrayBuffer[ArrayBuffer[Double]], K: Int, idOffset: Int): Unit =
    for (l <- 0 until document.len; k <- 0 until K)
      target(document.id(l) - idOffset)(k) += q(l)(k) * document.cnt(l)

  /** Offset from a word id to its beta row: 0 for zero-based ids, 1 for one-based ids. */
  private def currentIdOffset: Int = if (zeroBegin) 0 else 1

  /**
   * Relative L2 convergence test: ||u - v|| / ||u|| < threshold over the first `n` entries.
   * Returns false when ||u|| is 0 (the ratio is NaN or infinite).
   */
  def converged(u: ArrayBuffer[Double], v: ArrayBuffer[Double], n: Int, threshold: Double): Boolean = {
    var us = 0.0
    var ds = 0.0
    for (i <- 0 until n) {
      us += u(i) * u(i)
      val d = u(i) - v(i)
      ds += d * d
    }
    Math.sqrt(ds / us) < threshold
  }

  /**
   * Digamma function psi(x).
   *
   * Handles the input range in four cases:
   *  - Non-positive integers: returns -Infinity (poles).
   *  - Negative x: uses the reflection formula.
   *  - Very small x: uses a Taylor series.
   *  - Otherwise: applies the recurrence up to x >= 12, then the asymptotic expansion.
   *
   * NaN propagates.
   */
  def digamma(xval: Double): Double = {
    val c = 12
    val s = 1e-6
    val d1 = -0.57721566490153286 // -Euler-Mascheroni constant
    val d2 = 1.6449340668482264365 // pi^2 / 6
    val s3 = 1.0 / 12
    val s4 = 1.0 / 120
    val s5 = 1.0 / 252
    val s6 = 1.0 / 240
    val s7 = 1.0 / 132

    if (xval <= 0 && Math.floor(xval) == xval) Double.NegativeInfinity
    else if (xval < 0) {
      // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
      digamma(1 - xval) + Math.PI / Math.tan(-Math.PI * xval)
    } else if (xval <= s) {
      d1 - 1 / xval + d2 * xval
    } else {
      var x = xval
      var result = 0.0
      while (x < c) {
        result -= 1 / x
        x = x + 1
      }
      val r = 1 / x
      val r2 = r * r
      val withLog = result + (Math.log(x) - 0.5 * r)
      withLog - r2 * (s3 - r2 * (s4 - r2 * (s5 - r2 * (s6 - r2 * s7))))
    }
  }

  /**
   * Trigamma function psi'(x).
   *
   * Handles the input range in four cases:
   *  - Non-positive integers: returns +Infinity.
   *  - Negative x: uses the reflection formula.
   *  - Very small x: uses a Taylor series.
   *  - Otherwise: applies the recurrence up to x >= 8, then the asymptotic expansion.
   *
   * The expansion coefficients are written as Double literals (`1.0 / 6`). The old `1./6` form
   * is Int division in Scala 2.11+ and evaluates to 0.
   */
  def trigamma(xVal: Double): Double = {
    val small = 1e-4
    val large = 8
    val c = 1.6449340668482264365 // pi^2/6 = Zeta(2)
    val c1 = -2.404113806319188570799476 // -2 Zeta(3)
    val b2 = 1.0 / 6
    val b4 = -1.0 / 30
    val b6 = 1.0 / 42
    val b8 = -1.0 / 30
    val b10 = 5.0 / 66

    if (xVal <= 0 && Math.floor(xVal) == xVal) Double.PositiveInfinity
    else if (xVal < 0) {
      // -trigamma(-x) = trigamma(x + 1) - (pi * csc(pi * x))^2
      val r = Math.PI / Math.sin(-Math.PI * xVal)
      -trigamma(1 - xVal) + r * r
    } else if (xVal <= small) {
      1 / (xVal * xVal) + c + c1 * xVal
    } else {
      var x = xVal
      var result = 0.0
      while (x < large) {
        result += 1 / (x * x)
        x = x + 1
      }
      val r = 1 / (x * x)
      result + (0.5 * r + (1 + r * (b2 + r * (b4 + r * (b6 + r * (b8 + r * b10))))) / x)
    }
  }

  /** Perplexity exp(-loglik / n), where `n` is the total word count of the corpus. */
  def lda_ppl(n: Double, documents: Array[Document], beta: ArrayBuffer[ArrayBuffer[Double]],
      gammas: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double =
    Math.exp(-lda_lik_spark(documents, beta, gammas, m, nClass) / n)

  /** Total word count of one document: the sum of its first `len` counts. */
  def lda_ppl_spark(document: Document): Double =
    (0 until document.len).foldLeft(0.0)((acc, j) => acc + document.cnt(j))

  /** Log-likelihood of the first `m` documents; `documents(i)` must match row i of `gammasVal`. */
  def lda_lik_spark(documents: Array[Document], beta: ArrayBuffer[ArrayBuffer[Double]],
      gammasVal: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double =
    logLikelihood(m, i => documents(i), beta, gammasVal, m, nClass)

  /** Log-likelihood of all documents in `data`; `data(i)` must match row i of `gammasVal`. */
  def lda_lik(data: ArrayBuffer[Document], beta: ArrayBuffer[ArrayBuffer[Double]],
      gammasVal: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double =
    logLikelihood(data.size, i => data(i), beta, gammasVal, m, nClass)

  /**
   * Computes Σ_i Σ_j cnt_ij * log(Σ_k beta(w_ij)(k) * θ_i(k)).
   *
   * θ_i is row i of `gammas` normalised to sum to 1; `gammas` itself is not modified.
   */
  private def logLikelihood(numDocs: Int, docAt: Int => Document, beta: ArrayBuffer[ArrayBuffer[Double]],
      gammas: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double = {
    val egammas = ArrayBuffer.fill(m, nClass)(0.0)
    normalize_matrix_row(egammas, gammas, m, nClass)
    val idOffset = currentIdOffset
    var lik = 0.0
    for (i <- 0 until numDocs) {
      val doc = docAt(i)
      for (j <- 0 until doc.len) {
        val wordRow = beta(doc.id(j) - idOffset)
        var z = 0.0
        for (k <- 0 until nClass) z += wordRow(k) * egammas(i)(k)
        lik += doc.cnt(j) * Math.log(z)
      }
    }
    lik
  }

  /**
   * Newton-Raphson update of the Dirichlet parameter alpha from the document gammas.
   *
   * `alpha` is overwritten in place. The method returns the same `alpha` buffer and `gammas`
   * unchanged.
   *
   * If any alpha component goes negative, the method restarts with an initial guess 10x smaller
   * (`level + 1`). After 10 restarts it throws IllegalStateException instead of continuing with
   * an invalid alpha.
   *
   * @param M number of documents (rows of gammas)
   * @param K number of topics
   */
  def newton_alpha(alpha: ArrayBuffer[Double], gammas: ArrayBuffer[ArrayBuffer[Double]],
      M: Int, K: Int, level: Int): (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    val MaxNewtonIteration = 20
    val MaxRecursionLimit = 10
    val g = ArrayBuffer.fill(K)(0.0)
    val h = ArrayBuffer.fill(K)(0.0)
    val pg = ArrayBuffer.fill(K)(0.0)
    val palpha = ArrayBuffer.fill(K)(0.0)

    // Initial guess: column means of gammas / K, shrunk by 10^level on restarts.
    // Computed in Double so M * K cannot overflow Int.
    val scale = M.toDouble * K * Math.pow(10, level)
    for (i <- 0 until K) {
      var z = 0.0
      for (j <- 0 until M) z += gammas(j)(i)
      alpha(i) = z / scale
    }

    var psg = 0.0
    for (i <- 0 until M) {
      var gs = 0.0
      for (j <- 0 until K) gs += gammas(i)(j)
      psg += digamma(gs)
    }
    for (i <- 0 until K) {
      var spg = 0.0
      for (j <- 0 until M) spg += digamma(gammas(j)(i))
      pg(i) = spg - psg
    }

    var t = 0
    var finished = false
    var needsRestart = false
    while (!finished && t < MaxNewtonIteration) {
      var alpha0 = 0.0
      for (i <- 0 until K) alpha0 += alpha(i)
      val palpha0 = digamma(alpha0)

      for (i <- 0 until K) g(i) = M * (palpha0 - digamma(alpha(i))) + pg(i)
      for (i <- 0 until K) h(i) = -1 / trigamma(alpha(i))

      var sh = 0.0
      for (i <- 0 until K) sh += h(i)
      var hgz = 0.0
      for (i <- 0 until K) hgz += g(i) * h(i)
      hgz /= (1 / trigamma(alpha0) + sh)

      // The Hessian is diagonal plus a constant, so the Newton step is O(K).
      for (i <- 0 until K) alpha(i) = alpha(i) - h(i) * (g(i) - hgz) / M

      if ((0 until K).exists(i => alpha(i) < 0)) {
        needsRestart = true
        finished = true
      } else if (t > 0 && converged(alpha, palpha, K, 1.0e-4)) {
        finished = true
      } else {
        for (i <- 0 until K) palpha(i) = alpha(i)
        t += 1
      }
    }

    if (needsRestart) {
      if (level >= MaxRecursionLimit)
        throw new IllegalStateException("newton_alpha: maximum recursion limit reached")
      newton_alpha(alpha, gammas, M, K, level + 1)
    } else {
      (alpha, gammas)
    }
  }
}


0 0
原创粉丝点击