DL4J Source Code Reading Notes and Questions (Spark Part)
2016-08-05 21:50
public class IrisLocal {

    public static void main(String[] args) throws Exception {

        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[*]");
        sparkConf.setAppName("Iris");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        //Load the data from local (driver) classpath into a JavaRDD<DataSet>, for training
        //CSVRecordReader converts CSV data (as a String) into usable format for network training
        RecordReader recordReader = new CSVRecordReader(0, ",");
        File f = new File("src/main/resources/iris_shuffled_normalized_csv.txt");
        JavaRDD<String> irisDataLines = sc.textFile(f.getAbsolutePath());

        //labelIndex is the column index of the label within each record
        int labelIndex = 4;
        int numOutputClasses = 3;
        //Build a feature vector and a label vector for each record. The label vector has
        //numOutputClasses entries, with a 1 at the position given by the class value stored in the
        //label column, e.g. class value 1 with numOutputClasses = 3 gives <0,1,0> (class values are 0-based)
        JavaRDD<DataSet> trainingData = irisDataLines.map(new RecordReaderFunction(recordReader, labelIndex, numOutputClasses));

        //First: Create and initialize multi-layer network. Configuration is the same as in normal (non-distributed) DL4J training
        final int numInputs = 4;
        int outputNum = 3;
        int iterations = 1;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.5)
                .regularization(true).l2(1e-4)
                .activation("tanh")
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation("softmax")
                        .nIn(2).nOut(outputNum).build())
                .backprop(true).pretrain(false)
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        //Second: Set up the Spark training.
        //Set up the TrainingMaster. The TrainingMaster controls how learning is actually executed on Spark
        //Here, we are using standard parameter averaging
        int examplesPerDataSetObject = 1;
        ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
                .workerPrefetchNumBatches(2)    //Asynchronously prefetch up to 2 batches
                .saveUpdater(true)
                .averagingFrequency(1)  //See comments on averaging frequency in LSTM example. Averaging every 1 iteration is inefficient in practical problems
                .batchSizePerWorker(8)  //Number of examples that each worker gets, per fit operation
                .build();

        SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(sc, net, tm);

        int nEpochs = 100;
        for (int i = 0; i < nEpochs; i++) {
            sparkNetwork.fit(trainingData);
        }

        //Finally: evaluate the (training) data accuracy in a distributed manner:
        Evaluation evaluation = sparkNetwork.evaluate(trainingData);
        System.out.println(evaluation.stats());
    }
}

The program above is the example. Its main purpose is to train a neural network in a Spark environment. The part I follow here is the evaluation call:

Evaluation evaluation = sparkNetwork.evaluate(trainingData);

Stepping into the evaluate method of the SparkDl4jMultiLayer class, the arguments that end up being passed are trainingData, null and 64.
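A brief aside on the example before the evaluate code: the label vectors that RecordReaderFunction builds for trainingData are one-hot vectors. The sketch below illustrates the idea in plain Java. It assumes 0-based class values, and the class OneHotSketch, its methods and the sample record are all made up for illustration; this is not DL4J's implementation.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of DL4J.
class OneHotSketch {

    // Builds a label vector of length numClasses with a 1.0 at position classValue (0-based).
    static double[] oneHot(int classValue, int numClasses) {
        if (classValue < 0 || classValue >= numClasses) {
            throw new IllegalArgumentException(
                    "class value " + classValue + " out of range for " + numClasses + " classes");
        }
        double[] label = new double[numClasses];
        label[classValue] = 1.0;
        return label;
    }

    // Parses one CSV record and returns every column except the label column.
    static double[] features(String csvLine, int labelIndex) {
        String[] cols = csvLine.split(",");
        double[] f = new double[cols.length - 1];
        int k = 0;
        for (int i = 0; i < cols.length; i++) {
            if (i != labelIndex) {
                f[k++] = Double.parseDouble(cols[i].trim());
            }
        }
        return f;
    }

    public static void main(String[] args) {
        String line = "0.5,0.3,0.2,0.1,1";   // made-up record: 4 features, class value in column 4
        int labelIndex = 4;
        int numOutputClasses = 3;
        int classValue = (int) Double.parseDouble(line.split(",")[labelIndex].trim());
        System.out.println(Arrays.toString(features(line, labelIndex)));          // [0.5, 0.3, 0.2, 0.1]
        System.out.println(Arrays.toString(oneHot(classValue, numOutputClasses))); // [0.0, 1.0, 0.0]
    }
}
```

With labelIndex = 4 and numOutputClasses = 3 as in IrisLocal, each record contributes four feature values and a three-element label vector.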
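public Evaluation evaluate(JavaRDD<DataSet> data, List<String> labelsList, int evalBatchSize) {
    Broadcast listBroadcast = labelsList == null ? null : this.sc.broadcast(labelsList);
    JavaRDD evaluations = data.mapPartitions(new EvaluateFlatMapFunction(
            this.sc.broadcast(this.conf.toJson()),
            this.sc.broadcast(this.network.params()),
            evalBatchSize,
            listBroadcast));
    return (Evaluation) evaluations.reduce(new EvaluationReduceFunction());
}

data.mapPartitions() takes a FlatMapFunction<Iterator<DataSet>, Evaluation> argument, which is supplied here as an instance of the implementing class EvaluateFlatMapFunction. The arguments to new EvaluateFlatMapFunction(...) are as follows. The first is the broadcast JSON form of this.conf, that is, the configuration conf from which the network was built with MultiLayerNetwork net = new MultiLayerNetwork(conf). The second is the broadcast result of network.params(), which is the flattenedParams variable of the MultiLayerNetwork object: a single vector holding all of the network's weights and biases. The last two arguments are evalBatchSize and listBroadcast, which are 64 and null here.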
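To make the size of that flattened parameter vector concrete, here is a small worked count for the network in IrisLocal. It is only a sketch: it assumes each layer is fully connected with nIn * nOut weights plus nOut biases, and ParamCountSketch and its methods are names invented for this note. The result is the length that the check val.length() != network.numParams(false) in call would compare against under that assumption.

```java
// Hypothetical helper for illustration only; not part of DL4J.
class ParamCountSketch {

    // Parameters of one fully connected layer: one weight per input/output pair plus one bias per output.
    static int layerParams(int nIn, int nOut) {
        return nIn * nOut + nOut;
    }

    // Length of the flattened parameter vector for a stack of fully connected layers.
    // sizes = {inputs, hidden1, ..., outputs}
    static int totalParams(int[] sizes) {
        int total = 0;
        for (int i = 0; i + 1 < sizes.length; i++) {
            total += layerParams(sizes[i], sizes[i + 1]);
        }
        return total;
    }

    public static void main(String[] args) {
        int[] irisNet = {4, 3, 2, 3};   // layer widths taken from the IrisLocal configuration
        // (4*3 + 3) + (3*2 + 2) + (2*3 + 3) = 15 + 8 + 9
        System.out.println(totalParams(irisNet));   // 32
    }
}
```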
The EvaluateFlatMapFunction class implements the call method of its parent interface FlatMapFunction<Iterator<DataSet>, Evaluation>. This method is what actually tests the trained network against the data.
public Iterable<Evaluation> call(Iterator<DataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        // Empty partition: nothing to evaluate
        return Collections.emptyList();
    } else {
        // Rebuild the network from the broadcast JSON configuration
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) this.json.getValue()));
        network.init();
        INDArray val = (INDArray) this.params.value();
        if (val.length() != network.numParams(false)) {
            throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
        } else {
            // Load the broadcast parameters into the rebuilt network
            network.setParameters(val);
            Evaluation evaluation;
            if (this.labels != null) {
                evaluation = new Evaluation((List) this.labels.getValue());
            } else {
                evaluation = new Evaluation();
            }

            ArrayList collect = new ArrayList();
            int totalCount = 0;

            while (dataSetIterator.hasNext()) {
                collect.clear();
                int nExamples = 0;

                // Collect DataSet objects until the batch holds at least evalBatchSize examples
                DataSet data;
                while (dataSetIterator.hasNext() && nExamples < this.evalBatchSize) {
                    data = (DataSet) dataSetIterator.next();
                    nExamples += data.numExamples();
                    collect.add(data);
                }

                totalCount += nExamples;
                // Merge the collected DataSets into one batch and run it through the network
                data = DataSet.merge(collect, false);
                INDArray out;
                if (data.hasMaskArrays()) {
                    out = network.output(data.getFeatureMatrix(), false, data.getFeaturesMaskArray(), data.getLabelsMaskArray());
                } else {
                    out = network.output(data.getFeatureMatrix(), false);
                }

                // Rank-3 labels are evaluated as time series; everything else as ordinary labels
                if (data.getLabels().rank() == 3) {
                    if (data.getLabelsMaskArray() == null) {
                        evaluation.evalTimeSeries(data.getLabels(), out);
                    } else {
                        evaluation.evalTimeSeries(data.getLabels(), out, data.getLabelsMaskArray());
                    }
                } else {
                    evaluation.eval(data.getLabels(), out);
                }
            }

            if (log.isDebugEnabled()) {
                log.debug("Evaluated {} examples ", Integer.valueOf(totalCount));
            }

            // One Evaluation object per partition
            return Collections.singletonList(evaluation);
        }
    }
}

I am not entirely clear on the details of this method.

(Evaluation) evaluations.reduce(new EvaluationReduceFunction());

This return value is the result of merging the Evaluation objects that the individual partitions end up with.
public void merge(Evaluation other) {
    if (other != null) {
        this.truePositives.incrementAll(other.truePositives);
        this.falsePositives.incrementAll(other.falsePositives);
        this.trueNegatives.incrementAll(other.trueNegatives);
        this.falseNegatives.incrementAll(other.falseNegatives);
        if (this.confusion == null) {
            if (other.confusion != null) {
                this.confusion = new ConfusionMatrix(other.confusion);
            }
        } else if (other.confusion != null) {
            this.confusion.add(other.confusion);
        }

        this.numRowCounter += other.numRowCounter;
        if (this.labelsList.isEmpty()) {
            this.labelsList.addAll(other.labelsList);
        }
    }
}
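The overall shape of evaluate (evaluate each partition in batches, then merge the per-partition results) can be sketched in plain Java without Spark or DL4J. Everything in this sketch is a stand-in invented for illustration: Counts is a much-reduced substitute for Evaluation, each example is just a {predicted, actual} pair instead of a DataSet passed through a network, and an ordinary loop replaces mapPartitions and reduce.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Plain-Java stand-in for the evaluate flow; all names here are hypothetical.
class EvaluateFlowSketch {

    // Reduced substitute for Evaluation: only counts correct and total examples.
    static class Counts {
        int correct;
        int total;

        // Same idea as Evaluation.merge: add the other object's counters to this one.
        void merge(Counts other) {
            if (other != null) {
                correct += other.correct;
                total += other.total;
            }
        }
    }

    // Mirrors the structure of call(): drain one partition in batches of up to evalBatchSize.
    // Each example is {predictedClass, actualClass}.
    static List<Counts> evaluatePartition(Iterator<int[]> examples, int evalBatchSize) {
        if (!examples.hasNext()) {
            return new ArrayList<>();   // an empty partition contributes no result
        }
        Counts counts = new Counts();
        List<int[]> batch = new ArrayList<>();
        while (examples.hasNext()) {
            batch.clear();
            while (examples.hasNext() && batch.size() < evalBatchSize) {
                batch.add(examples.next());
            }
            // Stands in for network.output(...) followed by evaluation.eval(...)
            for (int[] ex : batch) {
                if (ex[0] == ex[1]) {
                    counts.correct++;
                }
                counts.total++;
            }
        }
        List<Counts> result = new ArrayList<>();
        result.add(counts);
        return result;
    }

    // Stands in for mapPartitions(...).reduce(...): evaluate every partition, then merge.
    static Counts evaluate(List<List<int[]>> partitions, int evalBatchSize) {
        Counts merged = new Counts();
        for (List<int[]> partition : partitions) {
            for (Counts c : evaluatePartition(partition.iterator(), evalBatchSize)) {
                merged.merge(c);
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<List<int[]>> partitions = Arrays.asList(
                Arrays.asList(new int[]{0, 0}, new int[]{1, 2}, new int[]{2, 2}),
                new ArrayList<int[]>(),
                Arrays.asList(new int[]{1, 1}, new int[]{0, 1}));
        Counts c = evaluate(partitions, 2);
        System.out.println(c.correct + " / " + c.total);   // 3 / 5
    }
}
```

Because merging only adds counters, the order in which the per-partition results are combined does not change the final numbers, which is what lets the real code hand the merge to reduce.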