您的位置:首页 > 其它

weka2

2015-09-07 14:44 211 查看

GridSearch(源码)

主要参数设置:

m_MinX m_MaxX m_StepX 
m_LabelX  m_X_Basem_MinY 
m_MaxYm_StepY
m_LabelY  m_Y_Base

classifier evaluation等等

主要函数:

public void buildClassifier(Instances data){...} 

|

protected PointDouble findBest(){...}

|

protected PointDouble determineBestInGrid(Grid grid, Instances inst, int cv){...}

|

public EvaluationTask(GridSearch owner, SetupGenerator generator,Instances inst, PointDouble values, int folds, int eval) {...}

buildClassifier-> findBest()->determineBestInGrid->EvaluationTask

重要的程序:

findBest :

result = determineBestInGrid(m_Grid, sample, 2);

确定了m_MinX等参数后可以生成一个网格grid,对网格中每一对点对(x,y)先用2折交叉验证选出其中performance最好的那一对点对,即result。

判断result在网格中的位置center是否在边界上,如果在边界上并且可以扩展的话,得到新的center,然后再以新的center为中心的邻域组成的grid进行10折交叉验证,找出

其中更最优的点对result。如果新的result和旧的相同,那么退出,否则继续上述过程,具体程序如下:

findBest :

finished=false;

if (!finished) {

      do {

        iteration++;

        resultOld = (PointDouble) result.clone();

        center = m_Grid.getLocation(result);  //获得在grid中的位置

        if (m_Grid.isOnBorder(center)) {

          log("Center is on border of grid.");

          if (getGridIsExtendable()) {

            if (m_GridExtensionsPerformed == getMaxGridExtensions()) {

              log("Maximum number of extensions reached!\n");

              finished = true;

            } else {

              m_GridExtensionsPerformed++;

              m_Grid = m_Grid.extend(result); //扩展

              center = m_Grid.getLocation(result);

              log("Extending grid (" + m_GridExtensionsPerformed + "/"

                + getMaxGridExtensions() + "):\n" + m_Grid + "\n");

            }

          } else {

            finished = true;

          }

        }

        if (!finished) {

          neighborGrid = m_Grid.subgrid((int) center.getY() + 1,

            (int) center.getX() - 1, (int) center.getY() - 1,

            (int) center.getX() + 1);

          result = determineBestInGrid(neighborGrid, sample, 10);

          log("\nResult of Step 2/Iteration " + (iteration) + ":\n" + result);

          finished = m_UniformPerformance;

          if (result.equals(resultOld)) {

            finished = true;

            log("\nNo better point found.");

          }

        }

      } while (!finished);

    }

determineBestInGrid:

Collections.sort(m_Performances, new PerformanceComparator(m_Evaluation));     //排序

result = m_Performances.get(m_Performances.size() - 1).getValues();   //选择最大值所对应的点对


EvaluationTask:

x = m_Generator.evaluate(m_Values.getX(), true);   //计算x的值

y = m_Generator.evaluate(m_Values.getY(), false); //计算y的值

classifier = (Classifier) m_Generator.setup(m_Classifier, x, y);  

eval = new Evaluation(data);
eval.crossValidateModel(classifier, data, m_Folds, new Random(m_Owner.getSeed()));   //交叉验证

performance = new Performance(m_Values, eval);

m_Owner.addPerformance(performance, m_Folds);   //

整个程序最主要就是这几个函数。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: