GraphFrame message 实现最短路径

来源:互联网 发布:量化分析 java python 编辑:程序博客网 时间:2024/04/29 21:14

用 aggregateMessages 实现了最短路径,感觉有以下几点不足:

1:无法(或者说很难)同时记录路径:需要在 sendToDst 时把 ID 拼进消息里,之后还要再解析出来,限制很多。

2:性能不如 GraphX 的 Graph:每一轮都暴力地向全图发送消息,不能在 Cluster 内进行操作,因此不适合用于大图计算。

3:sendToDst 的 API 不友好。个人觉得应该像 GraphX 的 Graph 一样,发送消息时使用 Function1 接口,agg 时使用 Function2 接口,这样会灵活很多。


代码如下


import static org.apache.spark.sql.functions.min;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.graphframes.GraphFrame;

import scala.reflect.ClassManifestFactory;
import scala.runtime.AbstractFunction1;


/**
 * 
 */


public class MessageShortPaths
{
private static SQLContext sqlCtx;
private static StructType vType;
private static StructType vNewType;
private static JavaSparkContext ctx;
private static Double NA = Double.MAX_VALUE / 2.0;
public static void main( String[] args )
{
SparkConf conf = new SparkConf().setAppName( "Message short paths" ).setMaster( "local" );
ctx = new JavaSparkContext( conf );

sqlCtx = SQLContext.getOrCreate( ctx.sc( ) );
JavaRDD<Row> verticeRow = ctx.parallelize( Arrays.asList( 
RowFactory.create( 1L, "a" ), 
RowFactory.create( 2L, "b" ),
RowFactory.create( 3L, "c" ),
RowFactory.create( 4L, "d" ),
RowFactory.create( 5L, "e" )));

JavaRDD<Row> edgeRow = ctx.parallelize( Arrays.asList(
RowFactory.create( 1L, 2L, 10.0 ),
RowFactory.create( 2L, 3L, 30.0 ),
RowFactory.create( 2L, 4L, 20.0 ),
RowFactory.create( 4L, 5L, 80.0 ),
RowFactory.create( 1L, 4L, 5.0 )) );

List<StructField> vList = new ArrayList<StructField>();

vList.add( DataTypes.createStructField( "id", DataTypes.LongType, false ) );
vList.add( DataTypes.createStructField( "name", DataTypes.StringType, true ) );

vType = DataTypes.createStructType( vList );

List<StructField> vNewList = new ArrayList<StructField>();

vNewList.add( DataTypes.createStructField( "id", DataTypes.LongType, false ) );
vNewList.add( DataTypes.createStructField( "value", DataTypes.DoubleType, true ) );

vNewType = DataTypes.createStructType( vNewList );

List<StructField> eList = new ArrayList<StructField>();

eList.add( DataTypes.createStructField( "src", DataTypes.LongType, false ) );
eList.add( DataTypes.createStructField( "dst", DataTypes.LongType, false ) );
eList.add( DataTypes.createStructField( "weight", DataTypes.DoubleType, false ) );

StructType eType = DataTypes.createStructType( eList );

GraphFrame frame = new GraphFrame( sqlCtx.createDataFrame( verticeRow, vType ), sqlCtx.createDataFrame( edgeRow, eType ) );

GraphFrame shortPathsFrame = caleShortPaths( frame, 1L );

shortPathsFrame.vertices( ).show();

ctx.stop();

}
private static GraphFrame caleShortPaths(GraphFrame frame)
{
DataFrame vertices = frame.vertices( );
List<Row> updates = new ArrayList<Row>( );

DataFrame messageData = frame.aggregateMessages( ).sendToDst( "src.value + edge.weight" ).agg( min( "MSG" ) );
for ( Row row : messageData.collectAsList( ) )
{
long id = row.getLong( 0 );
double value = getNodeWeight( frame, id );
double weight = row.getDouble( 1 );
if ( weight < value )
{
updates.add( RowFactory.create( id, weight ) );
vertices = vertices.filter( "id != " + id );
}
}


if ( updates.size( ) == 0 )
{
frame.vertices( ).show( );
return frame;
}


JavaRDD<Row> rows = ctx.parallelize( updates );
JavaRDD<Row> newRDD = vertices.javaRDD( ).union( rows );
DataFrame newVertices = sqlCtx.createDataFrame( newRDD, vNewType );
GraphFrame newFrame = new GraphFrame( newVertices, frame.edges( ) );


return caleShortPaths( newFrame );
}


private static GraphFrame caleShortPaths(GraphFrame f, Long id)
{

RDD<Row> newRow = f.vertices( ).rdd( ).map( new MyFunction1<Row, Row>( )
{
@Override
public Row apply( Row row )
{
return row.getLong( 0 ) == id ?RowFactory.create( row.getLong( 0 ), 0.0 ):RowFactory.create( row.getLong( 0 ), NA);
}
}, ClassManifestFactory.classType( Double.class ) );

GraphFrame frame = new GraphFrame(sqlCtx.createDataFrame( newRow, vNewType ), f.edges( ));

return caleShortPaths( frame );
}
private static double getNodeWeight( GraphFrame frame, long id )
{
return frame.vertices( ).filter( "id ="
+ id ).collectAsList( ).get( 0 ).getDouble( 1 );
}
public static abstract class MyFunction1<T1, R> extends AbstractFunction1<T1, R> implements Serializable
{

}
}