SGD对20Newsgroups训练

来源:互联网 发布:二手mac pro工作站 编辑:程序博客网 时间:2024/05/29 10:00

前言:

SGD(Stochastic Gradient Descent,随机梯度下降)本身是一种优化算法;Mahout中所说的SGD分类器,是指用随机梯度下降训练的Logistic Regression(逻辑回归)模型。


1.环境准备:hadoop2.2.0集群(或伪集群),mahout0.9,有关hadoop2与mahout0.9冲突问题见其他文档。


2. 下载20Newsgroups数据集放到hadoop主节点上,因为主节点配置了mahout


3.具体代码如下:

package mahout.SGD;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.io.Reader;import java.io.StringReader;import java.util.ArrayList;import java.util.Arrays;import java.util.Collection;import java.util.Collections;import java.util.List;import java.util.Map;import java.util.Set;import java.util.TreeMap;import org.apache.lucene.analysis.Analyzer;import org.apache.lucene.analysis.TokenStream;import org.apache.lucene.analysis.standard.StandardAnalyzer;import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;import org.apache.lucene.util.Version;import org.apache.mahout.classifier.sgd.L1;import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;import org.apache.mahout.math.DenseVector;import org.apache.mahout.math.RandomAccessSparseVector;import org.apache.mahout.math.Vector;import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;import org.apache.mahout.vectorizer.encoders.Dictionary;import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;import com.google.common.collect.ConcurrentHashMultiset;import com.google.common.collect.HashMultiset;import com.google.common.collect.Iterables;import com.google.common.collect.Multiset;public class TrainNewsGroups {private static final int FEATURES = 10000;private static Multiset<String> overallCounts;public static void main(String[] args) throws IOException {String path = "E:\\data\\20news-bydate\\20news-bydate-train";File base = new File(args[0]);//File base = new File(path);overallCounts = HashMultiset.create();// 建立向量编码器Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");encoder.setProbes(2);encoder.setTraceDictionary(traceDictionary);FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");bias.setTraceDictionary(traceDictionary);FeatureVectorEncoder lines 
= new ConstantValueEncoder("Lines");lines.setTraceDictionary(traceDictionary);Dictionary newsGroups = new Dictionary();// 配置学习算法OnlineLogisticRegression learningAlgorithm =     new OnlineLogisticRegression(          20, FEATURES, new L1())        .alpha(1).stepOffset(1000)        .decayExponent(0.9)         .lambda(3.0e-5)        .learningRate(20);// 访问数据文件List<File> files = new ArrayList<File>();for (File newsgroup : base.listFiles()) {  newsGroups.intern(newsgroup.getName());  files.addAll(Arrays.asList(newsgroup.listFiles()));}Collections.shuffle(files);System.out.printf("%d training files\n", files.size());// 数据词条化前的预备工作double averageLL = 0.0;double averageCorrect = 0.0;double averageLineCount = 0.0;int k = 0;double step = 0.0;int[] bumps = new int[]{1, 2, 5};double lineCount = 0;// 读取数据并进行词条化处理Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);for (File file : files) {BufferedReader reader = new BufferedReader(new FileReader(file));String ng = file.getParentFile().getName();int actual = newsGroups.intern(ng);Multiset<String> words = ConcurrentHashMultiset.create();String line = reader.readLine();while (line != null && line.length() > 0) {if (line.startsWith("Lines:")) {// String count = Iterables.get(onColon.split(line), 1);String[] lineArr = line.split("Lines:"); // 获得line行数String count = lineArr[1];try {lineCount = Integer.parseInt(count);averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);} catch (NumberFormatException e) {lineCount = averageLineCount;}}boolean countHeader = (line.startsWith("From:")|| line.startsWith("Subject:")|| line.startsWith("Keywords:") || line.startsWith("Summary:"));do {StringReader in = new StringReader(line);if (countHeader) {countWords(analyzer, words, in);}line = reader.readLine();} while (line.startsWith(" "));}countWords(analyzer, words, reader);reader.close();// 数据向量化Vector v = new RandomAccessSparseVector(FEATURES);bias.addToVector("", 1, v);//lines.addToVector("", lineCount / 30, 
v);lines.addToVector("", Math.log(lineCount + 1), v);//logLines.addToVector(null, Math.log(lineCount + 1), v);for (String word : words.elementSet()) {encoder.addToVector(word, Math.log(1 + words.count(word)), v);}// 评估当前进度double mu = Math.min(k + 1, 200);double ll = learningAlgorithm.logLikelihood(actual, v);averageLL = averageLL + (ll - averageLL) / mu;Vector p = new DenseVector(20);learningAlgorithm.classifyFull(p, v);int estimated = p.maxValueIndex();int correct = (estimated == actual? 1 : 0);averageCorrect = averageCorrect + (correct - averageCorrect) / mu;// 用编码数据训练SGD模型learningAlgorithm.train(actual, v);k++;int bump = bumps[(int) Math.floor(step) % bumps.length];int scale = (int) Math.pow(10, Math.floor(step / bumps.length));if (k % (bump * scale) == 0) {step += 0.25;System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n",k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated));}learningAlgorithm.close();}//System.out.println(overallCounts);//System.out.println(overallCounts.size());}private static void countWords(Analyzer analyzer, Collection<String> words,Reader in) throws IOException {TokenStream ts = analyzer.tokenStream("text", in);ts.addAttribute(CharTermAttribute.class);// 这里解决方案见:http://ask.csdn.net/questions/57173ts.reset();while (ts.incrementToken()) {String s = ts.getAttribute(CharTermAttribute.class).toString();words.add(s);}ts.end();ts.close();overallCounts.addAll(words); }}


4. 打包,并在hadoop上调用。注意:程序所依赖的第三方jar包需放在Java项目中新建的lib文件夹下,并随项目一同打进jar包,否则hadoop运行时会因找不到相应的类而报ClassNotFoundException。

[root@hadoop1 bin]# hadoop jar ../../jar/javaTex2.jar mahout.SGD.TrainNewsGroups /usr/local/mahout/data/20news-bydate/20news-bydate-train/



0 0