lda_spark code
Tested in single-machine (local) mode: it runs on small corpora, but once the number of documents or the size of the vocabulary grows too large, it throws a heap out-of-memory error.
One open issue: if every document were given an index, more of the work could run in parallel. The likelihood/perplexity step (lda_ppl()/lda_lik()) in particular depends on each document's ordinal position, i.e. its row in the gamma matrix, so those operations are currently gathered in the driver program and executed serially. A sketch of the indexing idea follows.
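One way to realize that (a sketch, not part of the original program): RDD.zipWithIndex, available in Spark releases newer than the 0.9.1 targeted here, pairs each document with an index matching its line order, which is also the row order of the gamma matrix assembled from the collected E-step results. With it, the inner loops of lda_lik_spark could run on the executors and be summed with reduce instead of looping in the driver. lda_lik_parallel below is a hypothetical helper, assuming the imports and Document class of the main listing, 1-based word ids (zeroBegin = false), and an egammas matrix already row-normalized on the driver:

def lda_lik_parallel(documentRDD: RDD[Document],
                     beta: ArrayBuffer[ArrayBuffer[Double]],
                     egammas: ArrayBuffer[ArrayBuffer[Double]],
                     nClass: Int): Double =
  documentRDD.zipWithIndex.map { case (doc, idx) =>
    val i = idx.toInt // this document's row in the normalized gamma matrix
    var lik = 0.0
    for (j <- 0 until doc.len) {
      var z = 0.0
      for (k <- 0 until nClass)
        z += beta(doc.id(j) - 1)(k) * egammas(i)(k) // ids assumed 1-based
      lik += doc.cnt(j) * Math.log(z)
    }
    lik // per-document log-likelihood term, computed on the executors
  }.reduce(_ + _)

beta and egammas are captured in the closure and shipped to the executors, so this removes the serial driver loop but not the memory pressure of the dense matrices themselves.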
Results
topic: of the a is that an for in at plate
topic: the of and a to in this at with on
topic: the to a and of flow in is on be
topic: the to a for that are by in theory on
topic: the a of in to at boundary as on an
topic: in of and for to is on are flow layer
topic: the in a is and that flow by layer this
topic: the in of a that this is are to it
topic: of to the a and at on flow for an
topic: the of is for a flow are to in and
topic: of the a is for flow on that this to
topic: and to of in for a with be is by
topic: of a to is for are in at the layer
topic: the of and a in to for flow on an
topic: of and for are with by is layer flow boundary
topic: of is in for the to are that on a
topic: the a of an that is on at flow by
topic: the of is in that and to at be an
topic: the and of for in on that by is a
topic: of the in is a and that be are by
topic: the of and a is with an theory for be
topic: of a in and is to the flow for are
topic: and in to of a are the be at by
topic: the of a for flow in that at this theory
topic: of the to and are with in is flow a
topic: the a is to of flow and on for be
topic: and to a is for that flow with at are
topic: the of and in flow to a is for boundary
topic: the of a to for is and flow by as
topic: of in is a with be an layer boundary are
topic: the a in is and with for of to be
topic: a in to of are flow is as an for
topic: of the and a flow for on be theory can
topic: of the and in flow at boundary a for layer
topic: the of flow are a as an for and with
topic: a to in for flow by and at of is
topic: the to and of a in for an this boundary
topic: the of in to are on that be a this
topic: the and of in for is a flow by on
topic: the of and in a on to are for is
topic: the a in of for that at by boundary to
topic: the of a to and is by for at theory
topic: the of to is for that with on theory by
topic: the in is with on are a layer at that
topic: the of in and to flow is boundary be a
topic: of the is on flow with a for that layer
topic: the of in a that on is and an to
topic: and of a to by in theory the at as
topic: the is of for with flow be by in as
topic: the of and a for are in this be layer
Code
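The post does not include the Document class that the listing constructs. A minimal definition consistent with its usage (document.len, document.id(i), document.cnt(i); it must be serializable to travel inside an RDD) would be the following sketch; the toString override is only a convenience for the debug println of the collected RDD:

// Assumed definition -- not in the original post. Fields match the accesses below.
class Document(val len: Int, val id: Array[Int], val cnt: Array[Double]) extends Serializable {
  override def toString = (id zip cnt).map { case (w, c) => w + ":" + c }.mkString(" ")
}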
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.io.FileWriter
import scala.util.control.Breaks._

// TODO: Document should carry an index field giving its position in the corpus,
// because lda_ppl/lda_lik need to know each document's row in the gamma matrix.
object LdaSpark2 {
  var zeroBegin = false // true if word ids in the input start at 0
  var lenBegin = false  // true if each input line starts with a length token

  def main(args: Array[String]) {
    val sc = new SparkContext("local", "LdaSpark",
      "/Users/heruilong/Downloads/spark-0.9.1-bin-hadoop2",
      List("target/scala-2.10/ldaspark_2.10-1.0.jar"))
    val nClass = 50
    val emmax = 10
    val demmax = 2
    val epsilon = 1.0e-4
    val documentFile = "/Users/heruilong/Downloads/spark-0.9.1-bin-hadoop2/apps/lda/data/testTrain.txt"
    lda(sc, documentFile, nClass, emmax, demmax, epsilon)
    println("done")
  }

  def lda(sc: SparkContext, documentFilePath: String, nClass: Int = 50, emmax: Int = 100,
          demmax: Int = 20, epsilon: Double = 1.0e-4): Unit = {
    val fileAP = "result_alpha.txt"
    val fileBP = "result_beta.txt"
    val (documentRDD: RDD[Document], dLenMax: Int, nLex: Int) = readDocuments(documentFilePath, sc)
    val (alpha_result, beta_result) =
      lda_learn(documentRDD, documentFilePath, nClass, nLex, dLenMax, emmax, demmax, epsilon)
    lda_write(fileAP, fileBP, alpha_result, beta_result, nClass, nLex)
  }

  def lda_learn(documentRDD: RDD[Document], documentFilePath: String, nClass: Int, nLex: Int,
                dLenMax: Int, emmax: Int, demmax: Int, epsilon: Double)
      : (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    val numberOfDocument = documentRDD.count.toInt
    println(documentRDD.collect mkString "\n")
    var (alpha: ArrayBuffer[Double], beta: ArrayBuffer[ArrayBuffer[Double]]) =
      initialAlphaAndBeta(nClass, nLex)
    println("Number of documents = " + numberOfDocument)
    println("Number of words = " + nLex)
    println("Number of latent classes = " + nClass)
    println("Number of outer EM iterations = " + emmax)
    println("Number of inner EM iterations = " + demmax)
    println("Convergence threshold = " + epsilon)
    var pppl = 0.0
    for (t <- 0 to emmax - 1) {
      printf("iteration %d/%d..\n", t + 1, emmax)
      // VB E-step: run the variational inference (vbem) for each document in parallel
      val gammaPhiTupleArray: Array[(ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]])] =
        documentRDD.map(x => vbem(x, alpha, beta, nClass, dLenMax, demmax, nLex)).collect
      println(gammaPhiTupleArray(0)._1 mkString ",")
      var gammas = ArrayBuffer.fill(numberOfDocument, nClass)(0.0)
      var betas = ArrayBuffer.fill(nLex, nClass)(0.0)
      // accumulate the per-document gamma and beta contributions on the driver
      for (i <- 0 to gammaPhiTupleArray.length - 1) {
        accum_gammas(gammas, gammaPhiTupleArray(i)._1, i, nClass)
        accum_betas(betas, gammaPhiTupleArray(i)._3, nClass, nLex)
      }
      // VB M-step: Newton-Raphson update for alpha
      val (alphaV, gammasV) = newton_alpha(alpha, gammas, numberOfDocument, nClass, 0)
      alpha = alphaV
      gammas = gammasV
      // assign the column-normalized betas back to beta
      val (betaV, betasV) = normalize_matrix_col(beta, betas, nLex, nClass)
      beta = betaV
      betas = betasV
      // clear the accumulation buffer
      betas = ArrayBuffer.fill(nLex, nClass)(0.0)
      val ppl_n = documentRDD.map(x => lda_ppl_spark(x)).reduce(_ + _).toDouble
      val documentArray = documentRDD.collect
      // converged?
      val ppl = lda_ppl(ppl_n, documentArray, beta, gammas, numberOfDocument, nClass)
      if ((t > 1) && (Math.abs((ppl - pppl) / pppl) < epsilon)) {
        if (t < 5) {
          printf("\nearly convergence. restarting..\n")
          val (alpha_re: ArrayBuffer[Double], beta_re: ArrayBuffer[ArrayBuffer[Double]]) =
            lda_learn(documentRDD, documentFilePath, nClass, nLex, dLenMax, emmax, demmax, epsilon)
          return (alpha_re, beta_re)
        } else {
          printf("converged")
          return (alpha, beta)
        }
      }
      pppl = ppl
    }
    (alpha, beta)
  }

  // unused: shadowed by the local parseDocument inside readDocuments
  def parseDocument(line: String): Document = {
    val len = line.split(" ").size
    val idArray = line.split(" ").map(x => x.split(":")(0).toInt)
    val cntArray = line.split(" ").map(x => x.split(":")(1).toDouble)
    new Document(len, idArray, cntArray)
  }

  // Document numbering (0- vs 1-based ids) is not fully resolved yet; see zeroBegin.
  // Returns (documentRDD, dMaxLen, nLex).
  def readDocuments(documentFile: String, sc: SparkContext): (RDD[Document], Int, Int) = {
    def parseDocument(line: String): Document = {
      if (lenBegin) {
        // first token is the document length; the rest are id:count pairs
        val len = line.split(" ").size - 1
        val idArray = line.split(" ").tail.map(x => x.split(":")(0).toInt)
        val cntArray = line.split(" ").tail.map(x => x.split(":")(1).toDouble)
        new Document(len, idArray, cntArray)
      } else {
        val len = line.split(" ").size
        val idArray = line.split(" ").map(x => x.split(":")(0).toInt)
        val cntArray = line.split(" ").map(x => x.split(":")(1).toDouble)
        new Document(len, idArray, cntArray)
      }
    }
    println("--------------------------------------------------" + zeroBegin)
    val documentData = sc.textFile(documentFile).cache()
    val documentRDD = documentData.map(parseDocument _)
    val documentLenArray = documentRDD.map(x => x.len).collect // tested
    val dMaxLen = documentLenArray.reduce((x, y) => if (x > y) x else y)
    val documentIdArray = documentRDD.flatMap(x => x.id).collect // tested
    var nLex = documentIdArray.reduce((x, y) => if (x > y) x else y)
    println("DEBUG: nLex = " + nLex)
    if (zeroBegin) nLex += 1
    println("DEBUG: nLex2 = " + nLex)
    (documentRDD, dMaxLen, nLex)
  }

  def initialAlphaAndBeta(nClass: Int, nLex: Int): (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    var alpha = ArrayBuffer.fill(nClass)(Random.nextInt(Int.MaxValue).toDouble / Int.MaxValue)
    val z = alpha.reduce(_ + _)
    alpha = alpha.map(x => x / z)
    // sort alpha in descending order
    alpha = alpha.sortWith(_ > _)
    // initialize beta and normalize each topic column to sum to 1
    var beta = ArrayBuffer.fill(nLex, nClass)(0.0)
    var tot = 0.0
    for (j <- 0 to nClass - 1) {
      for (i <- 0 to nLex - 1) {
        beta(i)(j) = Random.nextInt(Int.MaxValue).toDouble / Int.MaxValue * 10
        tot += beta(i)(j)
      }
      for (i <- 0 to nLex - 1)
        beta(i)(j) = beta(i)(j) / tot
      tot = 0.0
    }
    (alpha, beta)
  }

  // copy document n's gamma into row n of the corpus-wide gamma matrix
  def accum_gammas(gammas: ArrayBuffer[ArrayBuffer[Double]], gamma: ArrayBuffer[Double],
                   n: Int, K: Int): ArrayBuffer[ArrayBuffer[Double]] = {
    for (k <- 0 to K - 1)
      gammas(n)(k) = gamma(k)
    gammas
  }

  // tested correct: add one document's beta contribution to the running total.
  // betaPerDoc is already indexed by word id, so zeroBegin makes no difference here.
  def accum_betas(betas: ArrayBuffer[ArrayBuffer[Double]], betaPerDoc: ArrayBuffer[ArrayBuffer[Double]],
                  K: Int, nLex: Int): ArrayBuffer[ArrayBuffer[Double]] = {
    for (i <- 0 to nLex - 1)
      for (k <- 0 to K - 1)
        betas(i)(k) += betaPerDoc(i)(k)
    betas
  }

  // tested correct
  def normalize_matrix_row(dst: ArrayBuffer[ArrayBuffer[Double]], src: ArrayBuffer[ArrayBuffer[Double]],
                           rows: Int, cols: Int)
      : (ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    for (i <- 0 to rows - 1) {
      var z = 0.0
      for (j <- 0 to cols - 1) z += src(i)(j)
      for (j <- 0 to cols - 1) dst(i)(j) = src(i)(j) / z
    }
    (dst, src)
  }

  // tested correct (used for beta/betas)
  def normalize_matrix_col(dst: ArrayBuffer[ArrayBuffer[Double]], src: ArrayBuffer[ArrayBuffer[Double]],
                           rows: Int, cols: Int)
      : (ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    for (j <- 0 to cols - 1) {
      var z = 0.0
      for (i <- 0 to rows - 1) z += src(i)(j)
      for (i <- 0 to rows - 1) dst(i)(j) = src(i)(j) / z
    }
    (dst, src)
  }

  // tested correct
  def lda_write(fileAP: String, fileBP: String, alpha: ArrayBuffer[Double],
                beta: ArrayBuffer[ArrayBuffer[Double]], nClass: Int, nLex: Int) {
    printf("writing model..\n")
    write_vector(fileAP, alpha, nClass)
    write_matrix(fileBP, beta, nLex, nClass)
    printf("done.\n")
  }

  // tested correct
  def write_vector(filePath: String, vector: ArrayBuffer[Double], n: Int) {
    println("vector: " + (vector mkString " "))
    val fw = new FileWriter(filePath, false)
    fw.write(vector mkString " ")
    fw.close()
  }

  // tested correct
  def write_matrix(filePath: String, matrix: ArrayBuffer[ArrayBuffer[Double]], rows: Int, cols: Int) {
    val fw = new FileWriter(filePath, false)
    for (i <- 0 to rows - 1) {
      for (j <- 0 to cols - 1)
        fw.append(matrix(i)(j).toString + " ")
      fw.append("\n")
    }
    fw.close()
  }

  // Variational Bayes E-step for one document.
  // L: document length, K: nClass; the emmax parameter receives demmax from the caller.
  def vbem(document: Document, alpha: ArrayBuffer[Double], beta: ArrayBuffer[ArrayBuffer[Double]],
           K: Int, dLenMax: Int, emmax: Int, nLex: Int)
      : (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]], ArrayBuffer[ArrayBuffer[Double]]) = {
    val nClass = K
    var gamma = ArrayBuffer.fill(nClass)(0.0)
    var ap = ArrayBuffer.fill(nClass)(0.0)
    var nt = ArrayBuffer.fill(nClass)(0.0)
    var pnt = ArrayBuffer.fill(nClass)(0.0)
    var q = ArrayBuffer.fill(dLenMax, nClass)(0.0)
    val L = document.len
    for (k <- 0 to K - 1)
      nt(k) = L.toDouble / K
    breakable {
      for (j <- 0 to emmax - 1) {
        // ap is recomputed on every pass from the current alpha
        for (k <- 0 to K - 1)
          ap(k) = Math.exp(digamma(alpha(k) + nt(k)))
        // accumulate q: probability of the l-th word under topic k, times ap(k).
        // Looping only to L keeps the indices in bounds; each document gets its own q matrix.
        for (l <- 0 to L - 1)
          for (k <- 0 to K - 1)
            if (zeroBegin)
              q(l)(k) = beta(document.id(l))(k) * ap(k)
            else
              q(l)(k) = beta(document.id(l) - 1)(k) * ap(k)
        // normalize each row of q
        for (l <- 0 to L - 1) {
          var z = 0.0
          for (k <- 0 to K - 1) z += q(l)(k)
          for (k <- 0 to K - 1) q(l)(k) = q(l)(k) / z
        }
        // per-document VB M-step
        for (k <- 0 to K - 1) {
          var z = 0.0
          for (l <- 0 to L - 1) z += q(l)(k) * document.cnt(l)
          nt(k) = z
        }
        // converged?
        if (j > 0 && converged(nt, pnt, K, 1.0e-2)) break
        for (k <- 0 to K - 1) pnt(k) = nt(k)
      }
    }
    // gamma is recomputed on every call
    for (k <- 0 to K - 1)
      gamma(k) = alpha(k) + nt(k)
    val betaPerDoc = calculate_beta(document, q, nClass, nLex)
    (gamma, q, betaPerDoc)
  }

  def calculate_beta(document: Document, q: ArrayBuffer[ArrayBuffer[Double]],
                     K: Int, nLex: Int): ArrayBuffer[ArrayBuffer[Double]] = {
    val n = document.len
    var beta = ArrayBuffer.fill(nLex, K)(0.0)
    for (i <- 0 to n - 1)
      for (k <- 0 to K - 1) {
        // subtract 1 when ids start at 1
        if (zeroBegin)
          beta(document.id(i))(k) += q(i)(k) * document.cnt(i)
        else
          beta(document.id(i) - 1)(k) += q(i)(k) * document.cnt(i)
      }
    beta
  }

  // tested correct: is the relative change of u against v below the threshold?
  def converged(u: ArrayBuffer[Double], v: ArrayBuffer[Double], n: Int, threshold: Double): Boolean = {
    var us = 0.0
    var ds = 0.0
    var d = 0.0
    for (i <- 0 to n - 1) us += u(i) * u(i)
    for (i <- 0 to n - 1) {
      d = u(i) - v(i)
      ds += d * d
    }
    Math.sqrt(ds / us) < threshold
  }

  def digamma(xval: Double): Double = {
    var x = xval
    var result = 0.0
    val neginf = -1.0 / 0.0
    val c = 12
    val s = 1e-6
    val d1 = -0.57721566490153286  // -Euler's gamma
    val d2 = 1.6449340668482264365 // pi^2/6
    val s3 = 1.0 / 12
    val s4 = 1.0 / 120
    val s5 = 1.0 / 252
    val s6 = 1.0 / 240
    val s7 = 1.0 / 132
    // s8-s10 are unused by the truncated expansion below
    // (corrected from integer division in the original)
    val s8 = 691.0 / 32760
    val s9 = 1.0 / 12
    val s10 = 3617.0 / 8160
    // singularities
    if ((x <= 0) && (Math.floor(x) == x))
      return neginf
    /* Negative values: use the reflection formula (Jeffrey 11.1.6):
     *   digamma(-x) = digamma(x+1) + pi*cot(pi*x)
     * This is related to the identity
     *   digamma(-x) = digamma(x+1) - digamma(z) + digamma(1-z)
     * where z is the fractional part of x. For example:
     *   digamma(-3.1) = 1/3.1 + 1/2.1 + 1/1.1 + 1/0.1 + digamma(1-0.1)
     *                 = digamma(4.1) - digamma(0.1) + digamma(1-0.1)
     * Then we use digamma(1-z) - digamma(z) = pi*cot(pi*z).
     */
    if (x < 0)
      return digamma(1 - x) + Math.PI / Math.tan(-Math.PI * x)
    // use the Taylor series if the argument <= s
    if (x <= s)
      return d1 - 1 / x + d2 * x
    // reduce to digamma(x + n) where (x + n) >= c
    result = 0
    while (x < c) {
      result -= 1 / x
      x = x + 1
    }
    // use de Moivre's expansion if the argument >= c
    // (this expansion can be computed in Maple via asympt(Psi(x),x))
    if (x >= c) {
      var r = 1 / x
      result += Math.log(x) - 0.5 * r
      r *= r
      result -= r * (s3 - r * (s4 - r * (s5 - r * (s6 - r * s7))))
    }
    result
  }

  // tested correct
  def trigamma(xVal: Double): Double = {
    var x = xVal
    var result = 0.0
    val neginf = -1.0 / 0.0
    val small = 1e-4
    val large = 8
    val c = 1.6449340668482264365        // pi^2/6 = Zeta(2)
    val c1 = -2.404113806319188570799476 // -2 Zeta(3)
    val b2 = 1.0 / 6
    val b4 = -1.0 / 30
    val b6 = 1.0 / 42
    val b8 = -1.0 / 30
    val b10 = 5.0 / 66
    // singularities
    if ((x <= 0) && (Math.floor(x) == x))
      return -neginf
    // negative values: use the derivative of the digamma reflection formula:
    //   -trigamma(-x) = trigamma(x+1) - (pi*csc(pi*x))^2
    if (x < 0) {
      result = Math.PI / Math.sin(-Math.PI * x)
      return -trigamma(1 - x) + result * result
    }
    // use the Taylor series if the argument <= small
    if (x <= small)
      return 1 / (x * x) + c + c1 * x
    result = 0
    // reduce to trigamma(x + n) where (x + n) >= large
    while (x < large) {
      result += 1 / (x * x)
      x = x + 1
    }
    // apply the asymptotic formula when x >= large
    // (this expansion can be computed in Maple via asympt(Psi(1,x),x))
    if (x >= large) {
      val r = 1 / (x * x)
      result += 0.5 * r + (1 + r * (b2 + r * (b4 + r * (b6 + r * (b8 + r * b10))))) / x
    }
    result
  }

  // perplexity: n is the total word count of the corpus
  def lda_ppl(n: Double, documents: Array[Document], beta: ArrayBuffer[ArrayBuffer[Double]],
              gammas: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double =
    Math.exp(-lda_lik_spark(documents, beta, gammas, m, nClass) / n)

  // total word count of one document; summed across the corpus via reduce
  // to form the perplexity denominator
  def lda_ppl_spark(document: Document): Double = {
    var n = 0.0
    for (j <- 0 to document.len - 1)
      n += document.cnt(j)
    n
  }

  def lda_lik_spark(documents: Array[Document], beta: ArrayBuffer[ArrayBuffer[Double]],
                    gammasVal: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double = {
    var gammas = gammasVal
    var egammas = ArrayBuffer.fill(m, nClass)(0.0)
    val (egammasV, gammasV) = normalize_matrix_row(egammas, gammas, m, nClass)
    egammas = egammasV
    gammas = gammasV
    val numberOfDocument = m
    var lik = 0.0
    for (i <- 0 to numberOfDocument - 1) {
      for (j <- 0 to documents(i).len - 1) {
        var z = 0.0
        for (k <- 0 to nClass - 1)
          if (zeroBegin)
            z += beta(documents(i).id(j))(k) * egammas(i)(k)
          else
            z += beta(documents(i).id(j) - 1)(k) * egammas(i)(k)
        lik += documents(i).cnt(j) * Math.log(z)
      }
    }
    lik
  }

  // tested correct: serial version of the corpus log-likelihood
  def lda_lik(data: ArrayBuffer[Document], beta: ArrayBuffer[ArrayBuffer[Double]],
              gammasVal: ArrayBuffer[ArrayBuffer[Double]], m: Int, nClass: Int): Double = {
    val numberOfDocument = data.size
    var gammas = gammasVal
    var egammas = ArrayBuffer.fill(m, nClass)(0.0)
    val (egammasV, gammasV) = normalize_matrix_row(egammas, gammas, m, nClass)
    egammas = egammasV
    gammas = gammasV
    var lik = 0.0
    for (i <- 0 to numberOfDocument - 1) {
      for (j <- 0 to data(i).len - 1) {
        var z = 0.0
        for (k <- 0 to nClass - 1)
          if (zeroBegin)
            z += beta(data(i).id(j))(k) * egammas(i)(k)
          else
            z += beta(data(i).id(j) - 1)(k) * egammas(i)(k)
        lik += data(i).cnt(j) * Math.log(z)
      }
    }
    lik
  }

  // Newton-Raphson update for the Dirichlet parameter alpha
  def newton_alpha(alpha: ArrayBuffer[Double], gammas: ArrayBuffer[ArrayBuffer[Double]],
                   M: Int, K: Int, level: Int)
      : (ArrayBuffer[Double], ArrayBuffer[ArrayBuffer[Double]]) = {
    var g = ArrayBuffer.fill(K)(0.0)
    var h = ArrayBuffer.fill(K)(0.0)
    var pg = ArrayBuffer.fill(K)(0.0)
    var palpha = ArrayBuffer.fill(K)(0.0)
    // initialize alpha from the column means of gammas,
    // scaled down by 10^level on each recursive restart (10^0 = 1 at level 0)
    for (i <- 0 to K - 1) {
      var z = 0.0
      for (j <- 0 to M - 1) z += gammas(j)(i)
      alpha(i) = z / (M * K * Math.pow(10, level))
    }
    var psg = 0.0
    for (i <- 0 to M - 1) {
      var gs = 0.0
      for (j <- 0 to K - 1) gs += gammas(i)(j)
      psg += digamma(gs)
    }
    for (i <- 0 to K - 1) {
      var spg = 0.0
      for (j <- 0 to M - 1) spg += digamma(gammas(j)(i))
      pg(i) = spg - psg
    }
    val MAX_NEWTON_ITERATION = 20
    for (t <- 0 to MAX_NEWTON_ITERATION) {
      var alpha0 = 0.0
      for (i <- 0 to K - 1) alpha0 += alpha(i)
      val palpha0 = digamma(alpha0)
      for (i <- 0 to K - 1) g(i) = M * (palpha0 - digamma(alpha(i))) + pg(i)
      for (i <- 0 to K - 1) h(i) = -1 / trigamma(alpha(i))
      var sh = 0.0
      for (i <- 0 to K - 1) sh += h(i)
      var hgz = 0.0
      for (i <- 0 to K - 1) hgz += g(i) * h(i)
      hgz /= (1 / trigamma(alpha0) + sh)
      for (i <- 0 to K - 1) alpha(i) = alpha(i) - h(i) * (g(i) - hgz) / M
      // if any component went negative, restart with a smaller initialization
      val MAX_RECURSION_LIMIT = 10
      for (i <- 0 to K - 1)
        if (alpha(i) < 0) {
          if (level >= MAX_RECURSION_LIMIT)
            println("error")
          else
            return newton_alpha(alpha, gammas, M, K, 1 + level)
        }
      if ((t > 0) && converged(alpha, palpha, K, 1.0e-4))
        return (alpha, gammas)
      else
        for (i <- 0 to K - 1) palpha(i) = alpha(i)
    }
    (alpha, gammas)
  }
}
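For reference, readDocuments expects one document per line, each token an id:count pair; word ids are 1-based unless zeroBegin is set, and a leading length token is consumed when lenBegin is set. A hypothetical two-document input file (ids and counts invented for illustration):

1:2 5:1 9:3
2:1 5:2

Counts are parsed with toDouble, so integer and decimal counts both work.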