
Notes and questions from reading the DL4J source code (Spark part)

2016-08-05 21:50
public class IrisLocal {

    public static void main(String[] args) throws Exception {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[*]");
        sparkConf.setAppName("Iris");

        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        //Load the data from local (driver) classpath into a JavaRDD<DataSet>, for training
        //CSVRecordReader converts CSV data (as a String) into usable format for network training
        RecordReader recordReader = new CSVRecordReader(0,",");
        File f = new File("src/main/resources/iris_shuffled_normalized_csv.txt");
        JavaRDD<String> irisDataLines = sc.textFile(f.getAbsolutePath());
        //labelIndex is the index of the label (target) column within each record
        int labelIndex = 4;
        int numOutputClasses = 3;
        //Create a feature vector and a label vector for each record. The label vector is a one-hot vector of
        //length numOutputClasses with a 1 at the zero-based class index read from column labelIndex,
        //e.g. class index 2 with numOutputClasses = 3 gives <0,0,1>
        JavaRDD<DataSet> trainingData = irisDataLines.map(new RecordReaderFunction(recordReader, labelIndex, numOutputClasses));

        //First: Create and initialize multi-layer network. Configuration is the same as in normal (non-distributed) DL4J training
        final int numInputs = 4;
        int outputNum = 3;
        int iterations = 1;

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.5)
                .regularization(true).l2(1e-4)
                .activation("tanh")
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                        .nIn(2).nOut(outputNum).build())
                .backprop(true).pretrain(false)
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //Second: Set up the Spark training.
        //Set up the TrainingMaster. The TrainingMaster controls how learning is actually executed on Spark
        //Here, we are using standard parameter averaging
        int examplesPerDataSetObject = 1;
        ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
                .workerPrefetchNumBatches(2)    //Asynchronously prefetch up to 2 batches
                .saveUpdater(true)
                .averagingFrequency(1)  //See comments on averaging frequency in LSTM example. Averaging every 1 iteration is inefficient in practical problems
                .batchSizePerWorker(8)  //Number of examples that each worker gets, per fit operation
                .build();
        SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc,net,tm);

        int nEpochs = 100;
        for( int i=0; i<nEpochs; i++ ){
            sparkNetwork.fit(trainingData);
        }

        //Finally: evaluate the (training) data accuracy in a distributed manner:
        Evaluation evaluation = sparkNetwork.evaluate(trainingData);
        System.out.println(evaluation.stats());
    }

}
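The label encoding described in the comment on RecordReaderFunction can be illustrated in plain Java. This is only a sketch of the idea, not DL4J's implementation: the class OneHotSketch, its methods and the sample record are all hypothetical, and it assumes the label column holds a zero-based class index.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of DL4J): splits a CSV record into features and a one-hot label. */
final class OneHotSketch {

    /** Returns a vector of length numClasses with 1.0 at the zero-based position classIndex. */
    static double[] toOneHot(int classIndex, int numClasses) {
        if (classIndex < 0 || classIndex >= numClasses) {
            throw new IllegalArgumentException("class index out of range: " + classIndex);
        }
        double[] label = new double[numClasses];
        label[classIndex] = 1.0;
        return label;
    }

    /** Reads the class index stored in column labelIndex of one CSV line. */
    static int classIndex(String csvLine, int labelIndex) {
        String[] cols = csvLine.split(",");
        return (int) Double.parseDouble(cols[labelIndex].trim());
    }

    /** Returns every column except the one at labelIndex, parsed as doubles. */
    static double[] features(String csvLine, int labelIndex) {
        String[] cols = csvLine.split(",");
        double[] out = new double[cols.length - 1];
        int j = 0;
        for (int i = 0; i < cols.length; i++) {
            if (i != labelIndex) {
                out[j++] = Double.parseDouble(cols[i].trim());
            }
        }
        return out;
    }

    public static void main(String[] args) {
        String line = "0.1,0.2,0.3,0.4,2"; // made-up record: 4 features, class index 2
        int labelIndex = 4;
        int numOutputClasses = 3;
        System.out.println(Arrays.toString(features(line, labelIndex)));
        // prints [0.0, 0.0, 1.0]
        System.out.println(Arrays.toString(toOneHot(classIndex(line, labelIndex), numOutputClasses)));
    }
}
```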
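The IrisLocal example above trains a neural network on Spark. Its final call, Evaluation evaluation = sparkNetwork.evaluate(trainingData);, ends up in the three-argument evaluate method of the SparkDl4jMultiLayer class, with the arguments trainingData, null and 64: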
public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList, int evalBatchSize) {
    Broadcast listBroadcast = labelsList == null?null:this.sc.broadcast(labelsList);
    JavaRDD evaluations = data.mapPartitions(new EvaluateFlatMapFunction(this.sc.broadcast(this.conf.toJson()), this.sc.broadcast(this.network.params()), evalBatchSize, listBroadcast));
    return (Evaluation)evaluations.reduce(new EvaluationReduceFunction());
}
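data.mapPartitions() takes a FlatMapFunction<Iterator<DataSet>, Evaluation> argument, which is supplied here as an instance of the implementing class EvaluateFlatMapFunction. The constructor call new EvaluateFlatMapFunction(...) receives four arguments. The first is the broadcast JSON form of the network configuration (this.conf.toJson()), that is, the configuration from which
MultiLayerNetwork net = new MultiLayerNetwork(conf)
was built. The second is the broadcast result of this.network.params(), which is the flattenedParams field of the MultiLayerNetwork object: a single vector holding all of the network's weights and biases. The last two arguments are 64 (evalBatchSize) and null (the broadcast label list).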
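To make the size of that flattened vector concrete, here is a small worked calculation for the 4-3-2-3 network configured in the example. It is a sketch under one assumption: each dense or output layer contributes nIn * nOut weights plus nOut biases. The class ParamCountSketch and its methods are hypothetical names, not DL4J API.

```java
/** Hypothetical helper (not part of DL4J): counts the parameters of a stack of fully connected layers. */
final class ParamCountSketch {

    /** Weights plus biases of one fully connected layer, assuming nIn * nOut weights and nOut biases. */
    static int denseParams(int nIn, int nOut) {
        return nIn * nOut + nOut;
    }

    /** Sums the parameters of consecutive layers; layerSizes = {inputs, hidden sizes..., outputs}. */
    static int totalParams(int[] layerSizes) {
        int total = 0;
        for (int i = 0; i + 1 < layerSizes.length; i++) {
            total += denseParams(layerSizes[i], layerSizes[i + 1]);
        }
        return total;
    }

    public static void main(String[] args) {
        // Layer sizes from the example: 4 inputs -> 3 -> 2 -> 3 outputs
        int[] sizes = {4, 3, 2, 3};
        // (4*3 + 3) + (3*2 + 2) + (2*3 + 3) = 15 + 8 + 9
        System.out.println(totalParams(sizes)); // prints 32

        // The same kind of length check that call performs before setting the parameters
        double[] broadcastParams = new double[32];
        if (broadcastParams.length != totalParams(sizes)) {
            throw new IllegalStateException("parameter count mismatch");
        }
    }
}
```

Under that assumption, the broadcast vector for this network would have 32 entries, and a vector of any other length would be rejected by the check in call.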
EvaluateFlatMapFunction implements the call method of its interface FlatMapFunction<Iterator<DataSet>, Evaluation>. This method evaluates the trained network on the examples of one partition:
public Iterable<Evaluation> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if(!dataSetIterator.hasNext()) {
        // Empty partition: nothing to evaluate
        return Collections.emptyList();
    } else {
        // Rebuild the network on the worker from the broadcast configuration and parameters
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)this.json.getValue()));
        network.init();
        INDArray val = (INDArray)this.params.value();
        if(val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
        } else {
            network.setParameters(val);
            Evaluation evaluation;
            if(this.labels != null) {
                evaluation = new Evaluation((List)this.labels.getValue());
            } else {
                evaluation = new Evaluation();
            }

            ArrayList collect = new ArrayList();
            int totalCount = 0;

            while(dataSetIterator.hasNext()) {
                collect.clear();
                int nExamples = 0;

                // Collect DataSet objects until at least evalBatchSize examples are gathered
                DataSet data;
                while(dataSetIterator.hasNext() && nExamples < this.evalBatchSize) {
                    data = (DataSet)dataSetIterator.next();
                    nExamples += data.numExamples();
                    collect.add(data);
                }

                totalCount += nExamples;
                // Merge the collected objects into one batch and run a forward pass
                data = DataSet.merge(collect, false);
                INDArray out;
                if(data.hasMaskArrays()) {
                    out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray());
                } else {
                    out = network.output(data.getFeatureMatrix(), false);
                }

                // Rank-3 labels are treated as time series
                if(data.getLabels().rank() == 3) {
                    if(data.getLabelsMaskArray() == null) {
                        evaluation.evalTimeSeries(data.getLabels(), out);
                    } else {
                        evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray());
                    }
                } else {
                    evaluation.eval(data.getLabels(), out);
                }
            }

            if(log.isDebugEnabled()) {
                log.debug("Evaluated {} examples ", Integer.valueOf(totalCount));
            }

            // One Evaluation per partition
            return Collections.singletonList(evaluation);
        }
    }
}
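The batching done by the two nested while loops in call can be reproduced without DL4J. The sketch below replaces each DataSet with just its example count; EvalBatchingSketch and batchSizes are hypothetical names, and the partition contents are made up.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/** Hypothetical helper (not part of DL4J): mimics how call groups DataSet objects into evaluation batches. */
final class EvalBatchingSketch {

    /**
     * Each element of chunkSizes stands for the numExamples() of one DataSet.
     * Chunks are added to the current batch while it holds fewer than evalBatchSize examples,
     * so a batch can exceed evalBatchSize when a chunk holds more than one example.
     */
    static List<Integer> batchSizes(Iterator<Integer> chunkSizes, int evalBatchSize) {
        List<Integer> batches = new ArrayList<>();
        while (chunkSizes.hasNext()) {
            int nExamples = 0;
            while (chunkSizes.hasNext() && nExamples < evalBatchSize) {
                nExamples += chunkSizes.next();
            }
            batches.add(nExamples);
        }
        return batches;
    }

    public static void main(String[] args) {
        // A partition of 150 single-example DataSet objects, evalBatchSize = 64
        System.out.println(batchSizes(Collections.nCopies(150, 1).iterator(), 64)); // prints [64, 64, 22]
        // Five DataSet objects of 40 examples each: batches overshoot the threshold
        System.out.println(batchSizes(Collections.nCopies(5, 40).iterator(), 64));  // prints [80, 80, 40]
    }
}
```

So evalBatchSize works as a threshold rather than an exact batch size: with examplesPerDataSetObject = 1, as in the example, batches hold exactly 64 examples except for the last one in each partition.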
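I do not fully understand every detail of the call method. The return statement
(Evaluation)evaluations.reduce(new EvaluationReduceFunction());
combines the results obtained from the individual partitions into a single Evaluation. The merging itself is done by Evaluation.merge: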
public void merge(Evaluation other) {
    if(other != null) {
        this.truePositives.incrementAll(other.truePositives);
        this.falsePositives.incrementAll(other.falsePositives);
        this.trueNegatives.incrementAll(other.trueNegatives);
        this.falseNegatives.incrementAll(other.falseNegatives);
        if(this.confusion == null) {
            if(other.confusion != null) {
                this.confusion = new ConfusionMatrix(other.confusion);
            }
        } else if(other.confusion != null) {
            this.confusion.add(other.confusion);
        }

        this.numRowCounter += other.numRowCounter;
        if(this.labelsList.isEmpty()) {
            this.labelsList.addAll(other.labelsList);
        }
    }
}
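The per-partition results can be reduced pairwise because merge only adds counters. The following plain-Java sketch shows that pattern with a much simpler stand-in for Evaluation. The Counts class, the reduce helper and the numbers are hypothetical, and the shape of reduce (merge the second argument into the first and return it) is an assumption about what EvaluationReduceFunction does, since its source is not shown here.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch (not DL4J code): merging per-partition evaluation counts pairwise. */
final class MergeReduceSketch {

    /** Simplified stand-in for Evaluation: only tracks correct and total predictions. */
    static final class Counts {
        int correct;
        int total;

        Counts(int correct, int total) {
            this.correct = correct;
            this.total = total;
        }

        /** Adds the other partition's counters to this one; a null argument is ignored. */
        void merge(Counts other) {
            if (other != null) {
                this.correct += other.correct;
                this.total += other.total;
            }
        }

        double accuracy() {
            return total == 0 ? 0.0 : (double) correct / total;
        }
    }

    /** Assumed shape of the reduce function: merge b into a and return a. */
    static Counts reduce(Counts a, Counts b) {
        a.merge(b);
        return a;
    }

    /** Folds all partition results into one, starting from empty counters. */
    static Counts reduceAll(List<Counts> partitions) {
        Counts acc = new Counts(0, 0);
        for (Counts p : partitions) {
            acc = reduce(acc, p);
        }
        return acc;
    }

    public static void main(String[] args) {
        // Made-up results from three partitions of 50 examples each
        List<Counts> partitions = Arrays.asList(new Counts(48, 50), new Counts(45, 50), new Counts(47, 50));
        Counts all = reduceAll(partitions);
        System.out.println(all.correct + "/" + all.total); // prints 140/150
    }
}
```

Because addition is associative and commutative, the order in which partitions are merged does not change the final counters, which is what makes this kind of merge safe to use in a distributed reduce.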

                                            