hadoop实现矩阵的乘法(根据张丹的矩阵乘法改编)

来源:互联网 发布:图像追踪算法 编辑:程序博客网 时间:2024/04/29 08:29

我的hadoop版本是1.2.1,在eclipse中实现。

首先看矩阵算法的解释,我就直接搬了。


map中的key解释:


矩阵A中,第一个数字表示矩阵A的行数,第二个数字表示与矩阵A相 乘的矩阵B的列数,如图中跟的1,1 就是表示矩阵A第一行与与矩阵B第一列。矩阵B也是如此。



map中的value解释:



在矩阵A中,字母后面的第一个数字表示数字的位置,1,1  A:2,0 就表示为矩阵A的第二个数字为0;因为矩阵A的每一行都有3个数字,所以矩阵A的数字的偏移量就是从1-3.

在矩阵B中也是如此,矩阵B每一行有2个数字,所以矩阵B的数字偏移量就是从1-2.



现在直接上程序:


import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;


import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;


public class MartrixMultiply_01 {


static final String INPUT_PATH = "hdfs://hadoop:9000/Martrix";
static final String OUTPUT_PATH = "hdfs://hadoop:9000/MatrixOutput";


public static void main(String[] args) throws IOException,
URISyntaxException, ClassNotFoundException, InterruptedException {
Configuration conf = new Configuration();
final Job job = new Job(conf, MartrixMultiply_01.class.getSimpleName());


final FileSystem fileSystem = FileSystem.get(new URI(INPUT_PATH), conf);


if (fileSystem.exists(new Path(OUTPUT_PATH))) {
fileSystem.delete(new Path(OUTPUT_PATH), true);
}


FileInputFormat.setInputPaths(job, INPUT_PATH);


job.setMapperClass(MyMapper.class);
job.setMapOutputKeyClass(Text.class);
job.setMapOutputValueClass(Text.class);


job.setReducerClass(MyReducer.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(IntWritable.class);


FileOutputFormat.setOutputPath(job, new Path(OUTPUT_PATH));


job.waitForCompletion(true);


}


static class MyMapper extends Mapper<Object, Text, Text, Text> {
private String flag; //这里选择读取文件的名字
private int rowNum = 2;    //这里是表示矩阵A的行数
private int colNum = 2;     //这里表示矩阵B的列数
private int rowIndex = 1;  //这里表示当前的行数
private int colIndex = 1;  // 这里表示当前的列数


@Override
protected void setup(Mapper<Object, Text, Text, Text>.Context context)
throws IOException, InterruptedException {
FileSplit fileSplit = (FileSplit) context.getInputSplit();
flag = fileSplit.getPath().getName();
}


@Override
protected void map(Object key, Text value,
Mapper<Object, Text, Text, Text>.Context context)
throws IOException, InterruptedException {
String[] tokens = value.toString().split(",");
if (flag.equals("MartrixA")) {
Text k = new Text();
Text v = new Text();
for (int i = 1; i <= colNum; i++) {   //这里是表示读取矩阵B的哪列
k.set(rowIndex + ":" + i);    // rowindex是表示当前矩阵A的哪行
        for (int j = 1; j <= tokens.length; j++) {  //tokens.length是读取矩阵A的偏移量1-3
v.set("A+" + j + "," + tokens[j - 1]);
context.write(k, v);
System.out.println(k.toString() + ":" + v.toString());
}
}
rowIndex++;  //这里的叠加表示矩阵A读取一行就加1
} else if (flag.equals("MartrixB")) {
Text k = new Text();
Text v = new Text();
for (int i = 1; i <= rowNum; i++) {//这里同样表示矩阵A哪一行与矩阵B相乘的行数
for (int j = 1; j <= colNum; j++) {//这里表示矩阵B的哪列
k.set(i + ":" + j);
v.set("B+" + colIndex + "," + tokens[j - 1]);//colIndex 表示矩阵B的列的第几个数字,一共有3个数字,所以累加两次。


context.write(k, v);
System.out.println(k.toString() + ":" + v.toString());
}
}
colIndex++;这里就是累加
}


}
}
经过map后的输出就会把每行与每列的数字放在一个iterable里,比如矩阵A第一行与矩阵B的第一列相乘,就是把key为 1,1的 数字放入里面,然后再经过reduce的计算得出矩阵相乘的结果。

static class MyReducer extends Reducer<Text, Text, Text, IntWritable> {
private IntWritable v = new IntWritable();


@Override
protected void reduce(Text key, Iterable<Text> values,
Reducer<Text, Text, Text, IntWritable>.Context context)
throws IOException, InterruptedException {
Map<String, String> mapA = new HashMap<String, String>();
Map<String, String> mapB = new HashMap<String, String>();
for (Text t : values) {
if (t.toString().trim().startsWith("A")) {
System.out.println(t.toString());
String[] v = t.toString().substring(2).split(",");
System.out.println(v[1]);
mapA.put(v[0], v[1]);
} else if (t.toString().startsWith("B")) {
String[] v = t.toString().substring(2).split(",");
mapB.put(v[0], v[1]);
}
}
int sum = 0;
Iterator<String> it = mapA.keySet().iterator();
if (it.hasNext()) {
String num = it.next();
sum += (Integer.parseInt(mapA.get(num).trim()))
* (Integer.parseInt(mapB.get(num).trim()));
}
v.set(sum);
context.write(key, v);
}
}


其实写这个主要是怕自己忘记,希望看到的朋友不要拍砖。

0 0
原创粉丝点击