您的位置:首页 > 编程语言 > PHP开发

使用GraphFrame 的shortestPaths API 求最短路径

2017-08-30 21:06 417 查看
GraphFrame 的 shortestPaths 可以计算节点到节点的最短路径,但它只返回按边数(跳数)计算的最短距离,不能计算带权重的最短路径,也不会返回路径上经过的节点。因此下面先用 shortestPaths 求出最短距离,再以该距离作为最大路径长度,利用 BFS 方法(也可以用 find 方法)求出路径上的节点。

代码如下

import java.util.ArrayList;

import java.util.Arrays;

import java.util.List;

import org.apache.spark.SparkConf;

import org.apache.spark.api.java.JavaRDD;

import org.apache.spark.api.java.JavaSparkContext;

import org.apache.spark.sql.DataFrame;

import org.apache.spark.sql.Row;

import org.apache.spark.sql.RowFactory;

import org.apache.spark.sql.SQLContext;

import org.apache.spark.sql.catalyst.expressions.GenericRow;

import org.apache.spark.sql.types.DataTypes;

import org.apache.spark.sql.types.StructField;

import org.apache.spark.sql.types.StructType;

import org.graphframes.GraphFrame;

import scala.Option;

import scala.collection.Map;

/**

 * 

 */

public class GraphFrameShorhPaths

{

public static void main( String[] args )
{
SparkConf conf = new SparkConf( ).setAppName( "Short Paths" ).setMaster( "local" );

JavaSparkContext ctx = new JavaSparkContext( conf );

SQLContext sqlCtx = SQLContext.getOrCreate( ctx.sc( ) );

List<StructField> vList = new ArrayList<StructField>( );

vList.add( DataTypes.createStructField( "id", DataTypes.LongType, false ) );
vList.add( DataTypes.createStructField( "name", DataTypes.StringType, true ) );

StructType vType = DataTypes.createStructType( vList );

List<StructField> eList = new ArrayList<StructField>();

eList.add( DataTypes.createStructField( "src", DataTypes.LongType, false ) );
eList.add( DataTypes.createStructField( "dst", DataTypes.LongType, false ) );
eList.add( DataTypes.createStructField( "weight", DataTypes.DoubleType, true ) );

StructType eType = DataTypes.createStructType( eList );

JavaRDD<Row> verticeRow = ctx.parallelize( Arrays.asList( 
RowFactory.create( 1L,"a" ),
RowFactory.create( 2L,"b" ),
RowFactory.create( 3L,"c" ),
RowFactory.create( 4L,"d" ),
RowFactory.create( 5L,"e" )
) );
JavaRDD<Row> edgeRow = ctx.parallelize( Arrays.asList( 
RowFactory.create( 1L,2L,10.0 ),
RowFactory.create( 2L,3L,20.0 ),
RowFactory.create( 2L,4L,30.0 ),
RowFactory.create( 4L,5L,90.0 ),
RowFactory.create( 1L,4L,15.0 )) );

GraphFrame frame = new GraphFrame( sqlCtx.createDataFrame( verticeRow, vType ), sqlCtx.createDataFrame( edgeRow, eType ) );

ArrayList<Object> lamd = new ArrayList<Object>();
lamd.addAll( Arrays.asList( 1L,2L,3L,4L,5L ) );

DataFrame shortPathData = frame.shortestPaths( ).landmarks( lamd ).run( );

List<Long> ids = BFS( frame, shortPathData, 1
4000
L, 5L );
System.out.println( ids );
ctx.stop( );
}
private static int getShortPathLenght(DataFrame shortPathData, long from, long to)
{
Row row = shortPathData.filter( "id = " + from ).collectAsList( ).get( 0 );
Map map  = row.getMap( 2 );

Option option = map.get( to );
if (!option.isDefined( ))
{
return -1;
}
return (int)option.get( );
}
private static List<Long> BFS(GraphFrame frame, DataFrame shortPathData, long  from, long to)
{
List<Long> retValue = new ArrayList<Long>();
int lenght = (int)getShortPathLenght( shortPathData, from, to );
if (lenght <= 0 )
{
return retValue;
}
DataFrame pathData = frame.bfs( ).fromExpr( "id = " + from ).toExpr( "id = " + to ).maxPathLength( lenght ).run( );

long count = pathData.columns( ).length;
Row row = pathData.collectAsList( ).get( 0 );
for (int i=0; i<count; i=i+2)
{
retValue.add( ((GenericRow)row.getAs( i )).getLong( 0 ) );
}

return retValue;
}

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: