GraphX SVDPlusPlus Java源码

来源:互联网 发布:阳光灿烂的日子知乎 编辑:程序博客网 时间:2024/05/22 14:06

用Java重写了GraphX自带的SVDPlusPlus(原实现为Scala)。


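1. User顶点的属性是一个四元组:第一项和第三项是模型参数,分别为 Pu 和 Bu;第二项存放的是 pu + |N(u)|^(-0.5)*sum(y),第四项存放的是 |N(u)|^(-0.5),这两项都是为了方便计算而缓存的中间量。Item顶点有三项是模型参数,分别为 Qi、Yi、bi;第四项存放的是 |N(i)|^(-0.5),训练中用不到,训练结束后会被替换为该Item所有边的打分误差平方和。预测公式为 u + user._3( ) + item._3( ) + blas.ddot( rank, convert2double( item._1( ) ), 1, convert2double( user._2( ) ), 1 ),即 $\hat{r}_{ui} = \mu + b_u + b_i + q_i^{T}\left(p_u + |N(u)|^{-1/2}\sum_{j \in N(u)} y_j\right)$。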

2. 原理可以参考 http://www.farseer.cn/2015/08/16/svd-implementation-in-graphx/ 。梯度下降的求导也比较简单,需要更新的就是第1点中提到的5个模型参数(Pu、Bu、Qi、Yi、bi)。gamma1、gamma2是梯度下降的学习率(gamma1用于偏置项 Bu、bi,gamma2用于向量 Pu、Qi、Yi),gamma6、gamma7是正则化系数(gamma6用于偏置项,gamma7用于向量)。


个人感觉,循环结束后还需要重新计算一次user的第二项 pu + |N(u)|^(-0.5)*sum(y) 才最严谨:最后一轮迭代更新了 pu 和 y 之后,这一项并没有随之刷新,而最后计算误差时用的正是它。

代码如下

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;


import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.StorageLevels;
import org.apache.spark.graphx.Edge;
import org.apache.spark.graphx.EdgeContext;
import org.apache.spark.graphx.Graph;
import org.apache.spark.graphx.TripletFields;
import org.apache.spark.graphx.VertexRDD;
import org.apache.spark.graphx.lib.SVDPlusPlus;
import org.apache.spark.rdd.RDD;
import org.apache.spark.storage.StorageLevel;


import com.github.fommil.netlib.BLAS;


import scala.Option;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.reflect.ClassManifestFactory;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction1;
import scala.runtime.AbstractFunction2;
import scala.runtime.AbstractFunction3;
import scala.runtime.BoxedUnit;


/**
 * 
 */


public class SVDPlusPlusTest
{
private static final ClassTag<String> tagString = ClassManifestFactory.classType( String.class );
private static final ClassTag<Object> tagObject = ClassManifestFactory.classType( Object.class );
private static final ClassTag<Double> tagDouble = ClassManifestFactory.classType( Double.class );
private static final ClassTag<Double[]> tagDoubleArray = ClassManifestFactory.classType( Double[].class );
private static final ClassTag<Tuple4<Double[], Double[], Double, Double>> tagTuple4 = ClassManifestFactory.classType( Tuple4.class );

private static final BLAS blas = BLAS.getInstance( );
public static void main( String[] args )
{
SparkConf conf = new SparkConf().setAppName( "SVD ++" ).setMaster( "local" );
JavaSparkContext ctx = new JavaSparkContext( conf );

JavaRDD<Tuple2<Object, String>> vertices = ctx.parallelize( Arrays.asList( 
new Tuple2<Object, String>(1L, "a"),
new Tuple2<Object, String>(2L, "b"),
new Tuple2<Object, String>(3L, "c"),
new Tuple2<Object, String>(4L, "d")
) );

JavaRDD<Edge<Double>> edges = ctx.parallelize( Arrays.asList(
new Edge<Double>(1L, 10L, 3.0),
new Edge<Double>(2L, 11L, 4.0)
) );

Graph<String,Double> g = Graph.apply( vertices.rdd( ), edges.rdd( ), "", StorageLevel.MEMORY_ONLY( ), StorageLevel.MEMORY_ONLY( ), tagString, tagDouble );

SVDPlusPlus.Conf svdConf = new SVDPlusPlus.Conf( 2, 20, 0, 5, 0.007, 0.007, 0.005, 0.015 );

Tuple2<Graph<Tuple4<Double[], Double[],Double, Double>, Double>, Double> result = run( g.edges( ), svdConf );


}

private static Tuple2<Graph<Tuple4<Double[], Double[],Double, Double>, Double>, Double> run(RDD<Edge<Double>> edges, SVDPlusPlus.Conf conf)
{
//计算平均数
Tuple2<Long, Double> mean = edges.toJavaRDD( ).map( s-> 
{
return new Tuple2<Long, Double>(1L, s.attr( ));
}).reduce( (t1, t2) -> new Tuple2<Long, Double>(t1._1 + t2._1, t1._2 + t2._2) );

double u = mean._2( )/mean._1;

final int rank = conf.rank( );

Graph<Tuple4<Double[], Double[], Double, Double>, Double> g = Graph.fromEdges( 
edges, randomInit( rank ), StorageLevels.MEMORY_ONLY, StorageLevels.MEMORY_ONLY, 
tagTuple4, tagDouble ).cache( );

materialize( g );
edges.unpersist( true );

VertexRDD<Tuple2<Long, Double>> t0 = g.aggregateMessages( new MyFunction1<EdgeContext<Tuple4<Double[],Double[],Double,Double>,Double,Tuple2<Long, Double>>, BoxedUnit>( )
{


@Override
public BoxedUnit apply(
EdgeContext<Tuple4<Double[], Double[], Double, Double>, Double, Tuple2<Long, Double>> t1 )
{
t1.sendToSrc( new Tuple2<Long, Double>(1L, t1.attr( )) );
t1.sendToDst( new Tuple2<Long, Double>(1L, t1.attr( )) );
return BoxedUnit.UNIT;
}
}, new MyFunction2<Tuple2<Long, Double>, Tuple2<Long, Double>, Tuple2<Long, Double>>( )
{


@Override
public Tuple2<Long, Double> apply( Tuple2<Long, Double> t1, Tuple2<Long, Double> t2 )
{
return new Tuple2<Long, Double>(t1._1( ) + t2._1( ), t2._2( ) + t2._2( ) );
}
}, TripletFields.All, ClassManifestFactory.classType( Tuple2.class ) );

Graph<Tuple4<Double[], Double[], Double, Double>, Double > gJoin0 = g.outerJoinVertices( t0, new MyFunction3<Object, Tuple4<Double[], Double[], Double, Double>,
Option<Tuple2<Long, Double>>, Tuple4<Double[], Double[], Double, Double>>( )
{


@Override
public Tuple4<Double[], Double[], Double, Double> apply(
Object t0,
Tuple4<Double[], Double[], Double, Double> t1,
Option<Tuple2<Long, Double>> t2 )
{
if (t2.isDefined( ))
{
return new Tuple4<Double[], Double[], Double, Double>(t1._1( ), t1._2( ),
t2.get( )._2( )/t2.get( )._1( ) - u, 1.0/Math.sqrt( t2.get( )._1( ) ));
}
return t1;
}
}, ClassManifestFactory.classType( Tuple2.class ), tagTuple4, null ).cache( );

materialize( gJoin0 );
g.unpersist( true );

g = gJoin0;

for (int i=0; i<conf.maxIters( ); i++)
{
VertexRDD<Double[]> t1 = g.aggregateMessages( new MyFunction1<EdgeContext<Tuple4<Double[],Double[],Double,Double>,Double,Double[]>, BoxedUnit>( )
{


@Override
public BoxedUnit apply(
EdgeContext<Tuple4<Double[], Double[], Double, Double>, Double, Double[]> t )
{
t.sendToSrc( t.dstAttr( )._2( ) );
return BoxedUnit.UNIT;
}
}, new MyFunction2<Double[], Double[], Double[]>( )
{


@Override
public Double[] apply( Double[] t0, Double[] t1 )
{
double[] dy = convert2double( t0 );
blas.daxpy( rank, 1.0, convert2double( t1 ), 1, dy, 1 );
return convert2Double( dy );
}
}, TripletFields.Dst, tagDoubleArray );

Graph<Tuple4<Double[], Double[], Double, Double>, Double > gJoin1 = g.outerJoinVertices( t1, new MyFunction3<Object, Tuple4<Double[], Double[], Double, Double>, Option<Double[]>, Tuple4<Double[], Double[], Double, Double>>( )
{


@Override
public Tuple4<Double[], Double[], Double, Double> apply(
Object t0,
Tuple4<Double[], Double[], Double, Double> t1,
Option<Double[]> t2 )
{
if (t2.isDefined( ))
{
double[] dy = convert2double( t1._1( ) );
blas.daxpy( rank, t1._4( ), convert2double( t2.get( ) ), 1, dy, 1 );
return new Tuple4<Double[], Double[], Double, Double>( t1._1( ), convert2Double( dy ), t1._3( ), t1._4( ));
}
return t1;
}
}, tagDoubleArray, tagTuple4, null ).cache( );

materialize( gJoin1 );
g.unpersist( true );

g= gJoin1;

VertexRDD<Tuple3<Double[], Double[], Double>> t2 = g.aggregateMessages( new MyFunction1<EdgeContext<Tuple4<Double[],Double[],Double,Double>,Double,Tuple3<Double[], Double[], Double>>, BoxedUnit>( )
{


@Override
public BoxedUnit apply(
EdgeContext<Tuple4<Double[], Double[], Double, Double>, Double, Tuple3<Double[], Double[], Double>> t )
{
Tuple4<Double[], Double[], Double, Double> user = t.srcAttr( );
Tuple4<Double[], Double[], Double, Double> item = t.dstAttr( );

double value = u + user._3( ) + item._3( ) + blas.ddot( rank, convert2double( item._1( ) ), 1,
convert2double( user._2( ) ), 1 );

value = Math.min( value, conf.maxVal( ) );
value = Math.max( value, conf.minVal( ) );
double err = t.attr( ) - value;

//Bu 
double bu = conf.gamma1( )*(err - conf.gamma6( )*user._3( ));
//Iu
double iu = conf.gamma1( )*(err - conf.gamma6( ) * item._3( ));

//pu (err * q - conf.gamma7 * p) * conf.gamma2
double[] pu = convert2double( item._1( ) );
blas.dscal( rank, conf.gamma2( )*err, pu, 1 );
blas.daxpy( rank, -conf.gamma2( )*conf.gamma7( ), convert2double( user._1( ) ), 1, pu, 1 );

//qi (err * usr._2 - conf.gamma7 * q) * conf.gamma2
double[] qi = convert2double( user._2( ) );
blas.dscal( rank, err*conf.gamma2( ), qi, 1 );
blas.daxpy( rank, -conf.gamma2( )*conf.gamma7( ), convert2double( item._1( ) ), 1, qi, 1 );

//yiv(err * usr._4 * q - conf.gamma7 * itm._2) * conf.gamma2
double[] yi = convert2double( item._1( ) );
blas.dscal( rank, conf.gamma2( )*err*user._4( ), yi, 1 );
blas.daxpy( rank, -conf.gamma2( )*conf.gamma7( ), convert2double( item._2( ) ), 1, yi, 1 );

t.sendToSrc( new Tuple3<Double[], Double[], Double>(convert2Double( pu ), convert2Double(yi), bu) );
t.sendToDst( new Tuple3<Double[], Double[], Double>(convert2Double(qi), convert2Double(yi), iu) );
return BoxedUnit.UNIT;
}
}, new MyFunction2<Tuple3<Double[], Double[], Double>, Tuple3<Double[], Double[], Double>, Tuple3<Double[], Double[], Double>>( )
{


@Override
public Tuple3<Double[], Double[], Double> apply(
Tuple3<Double[], Double[], Double> t0,
Tuple3<Double[], Double[], Double> t1 )
{
double[] dy1 = convert2double( t0._1( ) );
blas.daxpy( rank, 1.0, convert2double( t1._1( ) ), 1, dy1, 1 );
double[] dy2 = convert2double( t0._2( ) );
blas.daxpy( rank, 1.0, convert2double( t1._2( ) ), 1, dy2, 1 );
return new Tuple3<Double[], Double[], Double>(convert2Double( dy1 ), convert2Double( dy2 ), t0._3( ) + t1._3( ));
}
}, TripletFields.All, ClassManifestFactory.classType( Tuple3.class ) );

Graph<Tuple4<Double[], Double[], Double, Double>, Double > gJoin2 = g.outerJoinVertices( t2, new MyFunction3<Object, Tuple4<Double[], Double[], Double, Double>, Option<Tuple3<Double[], Double[], Double>>, Tuple4<Double[], Double[], Double, Double>>( )
{


@Override
public Tuple4<Double[], Double[], Double, Double> apply(
Object t0,
Tuple4<Double[], Double[], Double, Double> t1,
Option<Tuple3<Double[], Double[], Double>> t2 )
{
double[] vd1 = convert2double( t1._1( ) );
blas.daxpy( rank, 1.0, convert2double( t2.get( )._1( )), 1, vd1, 1 );

double[] vd2 = convert2double( t1._2( ) );
blas.daxpy( rank, 1.0, convert2double( t2.get( )._2( ) ), 1, vd2, 1 );

return new Tuple4<Double[], Double[], Double, Double>(convert2Double( vd1 ), convert2Double( vd2 ), t1._3( ) + t2.get( )._3( ), t1._4( ));
}
}, ClassManifestFactory.classType( Tuple3.class ), tagTuple4, null ).cache( );

materialize( gJoin2 );
g.unpersist( true );
g = gJoin2;
}


VertexRDD<Double> t3 = g.aggregateMessages( new MyFunction1<EdgeContext<Tuple4<Double[],Double[],Double,Double>,Double,Double>, BoxedUnit>( )
{


@Override
public BoxedUnit apply(
EdgeContext<Tuple4<Double[], Double[], Double, Double>, Double, Double> t )
{
Tuple4<Double[], Double[], Double, Double> user = t.srcAttr( );
Tuple4<Double[], Double[], Double, Double> item = t.dstAttr( );

double value = u + user._3( ) + item._3( ) + blas.ddot( rank, convert2double( item._1( ) ), 1,
convert2double( user._2( ) ), 1 );

value = Math.min( value, conf.maxVal( ) );
value = Math.max( value, conf.minVal( ) );
double err = (t.attr( ) - value)*(t.attr( ) - value);
t.sendToDst( err );
return BoxedUnit.UNIT;
}
}, new MyFunction2<Double, Double, Double>( )
{


@Override
public Double apply( Double t0, Double t1 )
{
return t0 + t1;
}
}, TripletFields.All, tagDouble );

Graph<Tuple4<Double[], Double[], Double, Double>, Double > gJoin3 = g.outerJoinVertices( t3, new MyFunction3<Object, Tuple4<Double[], Double[], Double, Double>, Option<Double>, Tuple4<Double[], Double[], Double, Double>>( )
{


@Override
public Tuple4<Double[], Double[], Double, Double> apply(
Object t0,
Tuple4<Double[], Double[], Double, Double> t1,
Option<Double> t2 )
{
if (t2.isDefined( ))
{
return new Tuple4<Double[], Double[], Double, Double>(t1._1( ), t1._2( ), t1._3( ), t2.get( ));
}
return t1;
}
}, tagDouble, tagTuple4, null ).cache( );

materialize( gJoin3 );
g.unpersist( true );
g = gJoin3;

return new Tuple2<Graph<Tuple4<Double[], Double[], Double, Double>, Double>, Double>(g, u); 
}



private static Tuple4<Double[], Double[], Double, Double> randomInit(int rank)
{
Random random = new Random( );
Double[] t1 = new Double[rank];
Double[] t2 = new Double[rank];

for (int i=0; i<rank; i++)
{
t1[i] = random.nextDouble( );
t2[i] = random.nextDouble( );
}
return new Tuple4<Double[], Double[], Double, Double>(t1,t2,0.0,0.0);
}

private static void materialize(Graph g)
{
g.vertices( ).count( );
g.edges( ).count( );
}

private static abstract class  MyFunction1<T1, R> extends AbstractFunction1<T1, R> implements Serializable
{

}

private static abstract class MyFunction2<T1, T2, R> extends AbstractFunction2<T1, T2, R> implements Serializable
{

}

private static abstract class MyFunction3<T1,T2,T3,R> extends AbstractFunction3<T1, T2, T3, R> implements Serializable
{

}

private static Double[] convert2Double(double[] ds)
{
Double[] retValue = new Double[ds.length];
for (int i=0; i<retValue.length; i++)
{
retValue[i] = ds[i];
}

return retValue;
}
private static double[] convert2double(Double[] ds)
{
double[] retValue = new double[ds.length];
for (int i=0; i<retValue.length; i++)
{
retValue[i] = ds[i];
}

return retValue;
}
}

原创粉丝点击