您的位置:首页 > 理论基础 > 计算机网络

matlab调试卷积深度置信网络CDBN-master的时候出现crbm_forward2D_batch_mex没法识别(解决)

2017-05-04 14:24 387 查看
   今天帮群里的一个群友调matlab代码,CDBN,卷积深度置信网络,他说的是这个错误改了好几天都没法改,其实就是matlab如何调用c语言的问题,挺简单的。下面说说我的做法和如何在matlab中调用c语言的问题。

有一个通俗的比喻, 如果程序设计语言是车,那么C 语言就是全能手, C十十语言是加强版的C 语言, MATLAB是科学

家用来完成特殊任务的工具。作为使用MATLAB
的科学家和工程师, 通过混合程序设计,就可以借用CIC十十语言这两个全能手增强

MATLAB 的功能;作为使用C/C十十语言开发的开发者,也可以通过混合程序设计来使用MATLAB强大的科学计算与数据可视化功能。

准备好C语言程序,一般情况下要清楚C语言的入口函数,比如,如下的C语言函数:

void mexFunction (int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

编写mexfunction函数。mexfunction函数为C语言与MATLAB语言的接口函数。调用实例在mylinedetect.c文件中,文件内容如下:

#include <math.h>

#include <mex.h>

#include <matrix.h>

#include <time.h>

#include <string.h>

void mexFunction (int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

{
const mxArray  *model, *layer, *batch_data;
     mxArray  *model_new, *h_input_array, *h_sample_array, *output_array;

    int            ni,N,n_dim,n_map_h, n_map_v, nh,nv,j,i,jj,ii,id,
              Hstride,Wstride,Hfilter,Wfilter,Hres,Wres,H,W,Hpool,Wpool,Hout,Wout;
int            *_id;
double         *s_filter, *stride, *h, *data, *weights, *h_bias, *block, *pool,

                   *h_input, *h_sample, *output, *gaussian;
mwSize         *dim_vi, *dim_hi, *dim_id, *dim_h, *dim_out;

    mxChar         *type;

    model          = prhs[0];
layer          = prhs[1];
batch_data     = prhs[2];

    dim_vi         = mxGetDimensions(batch_data);

    n_dim          = mxGetNumberOfDimensions(batch_data);

    if (n_dim == 2 || n_dim == 3)

    N = 1;

    else

    N = dim_vi[3];

    dim_h          = (mwSize*)mxMalloc(sizeof(mwSize)*4);
dim_hi         = mxGetDimensions(mxGetField(model,0,"h_input"));

    dim_h[0]       = dim_hi[0];

    dim_h[1]       = dim_hi[1];

    dim_h[2]       = dim_hi[2];

    dim_h[3]       = N;

    n_map_h        = mxGetScalar(mxGetField(layer,0,"n_map_h"));
n_map_v        = mxGetScalar(mxGetField(layer,0,"n_map_v"));

    s_filter       = mxGetPr(mxGetField(layer,0,"s_filter"));   

    stride         = mxGetPr(mxGetField(layer,0,"stride"));

    data           = mxGetPr(batch_data);

    weights        = mxGetPr(mxGetField(model,0,"W"));

    h              = mxGetPr(mxCreateNumericArray(4,dim_h,mxDOUBLE_CLASS,mxREAL));

    h_bias         = mxGetPr(mxGetField(model,0,"h_bias"));

    block          = mxGetPr(mxCreateNumericArray(4,dim_h,mxDOUBLE_CLASS,mxREAL));

    pool           = mxGetPr(mxGetField(layer,0,"s_pool"));
h_input_array  = mxCreateNumericArray(4,dim_h,mxDOUBLE_CLASS, mxREAL);
h_input        = mxGetPr(h_input_array);
h_sample_array = mxCreateNumericArray(4,dim_h,mxDOUBLE_CLASS, mxREAL);
h_sample       = mxGetPr(h_sample_array);

    dim_out        = mxGetDimensions(mxGetField(model,0,"output"));

    dim_out[3]     = N;

    output_array   = plhs[0] = mxCreateNumericArray(4,dim_out,mxDOUBLE_CLASS, mxREAL);

    output         = mxGetPr(output_array);

    gaussian       = mxGetPr(mxGetField(model,0,"start_gau"));

    type           = mxGetChars(mxGetField(layer,0,"type_input"));

    /*Here need to pay attention to the _id:mxUINT32_CLASS*/

    dim_id         = (mwSize*)mxMalloc(sizeof(mwSize)*2);

    dim_id[0]      = pool[0]; dim_id[1] = pool[1];

    _id            = mxGetPr(mxCreateNumericArray(2,dim_id,mxUINT32_CLASS,mxREAL));

    mxFree(dim_id);

    mxFree(dim_h);

    Hstride        = stride[0];

    Wstride        = stride[1];

    Hfilter        = s_filter[0];

    Wfilter        = s_filter[1];

    H              = dim_vi[0];

    W              = dim_vi[1];

    Hres           = dim_hi[0];

    Wres           = dim_hi[1];

    Hpool          = pool[0];

    Wpool          = pool[1];

    Hout           = floor(Hres/Hpool);

    Wout           = floor(Wres/Wpool);

    for (ni = 0; ni < N; ni++){

        for (nh = 0; nh < n_map_h; nh++){

            for (j = 0; j < Wres; j++){

                for (i = 0; i < Hres; i++){

                    id = i+Hres*j+Hres*Wres*nh+Hres*Wres*n_map_h*ni;

                    h[id] = 0;

                    h_input[id] = 0;

                    for (nv = 0; nv < n_map_v; nv++){

                        for (jj = 0; jj < Wfilter; jj++){

                            for (ii = 0; ii < Hfilter; ii++){

                                h[id] += data[(i*Hstride+ii)+H*(j*Wstride+jj)+H*W*nv+H*W*n_map_v*ni]

                                        * weights[(ii+Hfilter*jj)+Hfilter*Wfilter*nv+Hfilter*Wfilter*n_map_v*nh];

                            }

                        }

                    }

                    h_input[id] = h[id] + h_bias[nh];

                    /* for crbm blocksum & outpooing */

                    if (type[0] == 'B')

                        block[id] = exp(h_input[id]);

                    if (type[0] == 'G')

                        block[id] = exp(1.0/(gaussian[0]*gaussian[0])*h_input[id]);

                }

            }

            /* output the pooling & crbm blocksum: hidden activation summation */

            for (j = 0; j < Wout; j++){

                for (i = 0; i < Hout; i++){

                    double sum = 0.0;

                    for (jj = 0; jj < Wpool; jj++){

                        _id[jj*Hpool] = i*Hpool+(j*Wpool+jj)*Hres + Hres*Wres*nh+Hres*Wres*n_map_h*ni;

                        sum += block[_id[jj*Hpool]];

                        for (ii = 1; ii < Hpool; ii++){

                            _id[jj*Hpool+ii] = _id[jj*Hpool+ii-1] + 1;

                            sum += block[_id[jj*Hpool+ii]];

                        }

                    }

                    int out_id = i+j*Hout+Hout*Wout*nh+Hout*Wout*n_map_h*ni;

                    for (jj = 0; jj < Hpool*Wpool; jj++){

                        h_sample[_id[jj]] = 1.0-(1.0/(1.0+sum));

                    }

                    output[out_id] = h_sample[_id[0]];

                }

            }

        }

    }

    return;

}

4

在MATLAB中调用mex指令编译相关文件,将C语言编译为MEX文件,如下所示。

mex mylinedetect.c linedetect.c

编译完成后,生成mylinedetect.mexw32或mylinedetect.mexw64文件,此文件即mex文件,用于MATLAB与C语言接口函数

5

编译完成之后,编写MATLAB函数,调用MEX文件。如下所示。

load trees;

%以MEX文件的形式调用编译完成的C语言函数

[o1,o2]=mylinedetect(double(X).');

......

6

输出结果,上述linedetect函数完成图像中直线检测功能,带入MATLAB中调用后,形成如下结果。







不懂的可以加我的QQ群:522869126(语音信号处理) 欢迎你
的到来哦,看了博文给点脚印呗,谢谢啦~~
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐