您的位置:首页 > 编程语言 > MATLAB

【小波变换】STL版 一维离散小波变换(DWT)库,完全按matlab的wavelet toolbox 的API实现的

2013-04-02 13:05 585 查看

【转】一维离散小波变换(DWT)库,完全按matlab的wavelet toolbox 的API实现的

来源:http://hi.baidu.com/anatacollin/item/69fdab74ca7d045c0d0a07b4

一维离散小波变换(DWT)库,完全按 MATLAB 的 Wavelet Toolbox 的 API 实现(原文发表于 2008-12-01 20:37)。最近项目中需要用,就自己写了个,发在这里算是备忘。需要的朋友也可以拿去试试,经测试没有发现 bug,基于 STL 实现。如果发现 bug 或有什么建议请通知我,谢谢。

/************************************************************************/

/* wavelet.h

* Author: Collin

* Date: 2008/12/01

*/

/************************************************************************/

#ifndef WAVELET_H

#define WAVELET_H

#include <vector>

namespace Wavelet{

using std::vector;

// Result of a multi-level decomposition produced by WaveDec.
//   C : the coarsest approximation coefficients, followed by the detail
//       coefficients ordered from the coarsest level down to level 1.
//   L : bookkeeping lengths. L[0] is the approximation length, L[1..n] are
//       the detail lengths (coarsest first) and L.back() is the length of
//       the original signal.
struct C_L
{
    vector<double> C;
    vector<int> L;
};

// A low-pass / high-pass filter pair (either decomposition or reconstruction).
struct WaveFilter{
    vector<double> Low;
    vector<double> High;
};

// Output of a single-level transform: approximation and detail coefficients.
struct WaveCoeff{
    vector<double> app;
    vector<double> det;
};

// sym4 filter taps. The tables used to be rounded to four decimals, which
// made the low-pass taps sum to 1.4142 instead of sqrt(2) and left a visible
// reconstruction error. They are now given at double precision.
//   Lo_D / Hi_D : decomposition low-pass / high-pass
//   Lo_R / Hi_R : reconstruction low-pass / high-pass (time-reversed Lo_D / Hi_D)
const double sym4_Lo_D[] = {-0.07576571478927333, -0.02963552764599851,
                             0.49761866763201545,  0.8037387518059161,
                             0.29785779560527736, -0.09921954357684722,
                            -0.012603967262037833, 0.0322231006040427};

const double sym4_Hi_D[] = {-0.0322231006040427,  -0.012603967262037833,
                             0.09921954357684722,  0.29785779560527736,
                            -0.8037387518059161,   0.49761866763201545,
                             0.02963552764599851, -0.07576571478927333};

const double sym4_Lo_R[] = { 0.0322231006040427,  -0.012603967262037833,
                            -0.09921954357684722,  0.29785779560527736,
                             0.8037387518059161,   0.49761866763201545,
                            -0.02963552764599851, -0.07576571478927333};

const double sym4_Hi_R[] = {-0.07576571478927333,  0.02963552764599851,
                             0.49761866763201545, -0.8037387518059161,
                             0.29785779560527736,  0.09921954357684722,
                            -0.012603967262037833, -0.0322231006040427};

// Ready-made filter pairs handed out by WFilters ('d' and 'r' respectively).
const static WaveFilter sym4_d = {vector<double>(sym4_Lo_D, sym4_Lo_D + 8), vector<double>(sym4_Hi_D, sym4_Hi_D + 8)};

const static WaveFilter sym4_r = {vector<double>(sym4_Lo_R, sym4_Lo_R + 8), vector<double>(sym4_Hi_R, sym4_Hi_R + 8)};

// Looks up the filter pair for a wavelet.
//   strWaveName : wavelet name; only "sym4" is implemented.
//   d_or_r      : 'd' for decomposition filters, 'r' for reconstruction
//                 filters (case-insensitive).
// Prints a message to stderr and terminates the process on any other input.
const WaveFilter& WFilters(
    const char* strWaveName,
    const char d_or_r
    );

// Multi-level 1-D wavelet decomposition.
//   signal      : input samples.
//   nMaxLevel   : number of DWT levels to apply.
//   strWaveName : wavelet name understood by WFilters.
// Returns the coefficient vector and its length bookkeeping (see C_L).
C_L WaveDec(
    const vector<double>& signal,
    const int nMaxLevel,
    const char* strWaveName
    );

// Single-level DWT: symmetric extension, filtering with Lo_D / Hi_D and
// keeping every second sample (the odd 0-based indices).
WaveCoeff DWT(
    const vector<double>& signal,
    const vector<double>& Lo_D,
    const vector<double>& Hi_D
    );

// Reconstructs a single branch of the decomposition up to signal length.
//   a_or_d : 'a' for the approximation branch, 'd' for the detail branch.
//   C, L   : decomposition as returned by WaveDec.
//   nLevel : level of the branch; 0..n for 'a', 1..n for 'd', where
//            n = L.size() - 2.
// Terminates the process on an invalid branch type or level.
vector<double> WRCoef(
    const char a_or_d,
    const vector<double>& C,
    const vector<int>& L,
    const char* strWaveName,
    const int nLevel
    );

// Approximation coefficients at level nLevel, obtained by running inverse
// transforms from the coarsest level down to nLevel.
vector<double> AppCoef(
    const vector<double>& C,
    const vector<int>& L,
    const char* strWaveName,
    const int nLevel
    );

// Detail coefficients of level nLevel, sliced out of C using the lengths in L.
vector<double> DetCoef(
    const vector<double>& C,
    const vector<int>& L,
    const int nLevel
    );

//upsample and convolution
// Inserts a zero between consecutive samples, convolves with `filter` and
// keeps the central nLen samples of the result.
vector<double> UpsConv1(
    const vector<double>& signal,
    const vector<double>& filter,
    const int nLen,
    const char* strMode = "sym"
    );

// Full linear convolution of the two sequences.
vector<double> Conv(
    const vector<double>& vecSignal,
    const vector<double>& vecFilter
    );

// Single-level inverse DWT: sums the up-sampled and filtered approximation
// and detail branches, each trimmed to nLenCentral samples.
vector<double> IDWT(
    const vector<double>& app,
    const vector<double>& det,
    const vector<double>& Lo_R,
    const vector<double>& Hi_R,
    const int nLenCentral
    );

// Pads the signal with nLenExt mirrored samples on each side.
vector<double> WExtend(
    const vector<double>& signal,
    const int nLenExt,
    const char* mode = "sym"
    );

// Convolution restricted to the part computed without zero padding
// (filter.size() - 1 samples are dropped from each end of the full result).
vector<double> WConv1(
    const vector<double>& signal,
    const vector<double>& filter,
    const char* shape = "valid"
    );

}

#endif

/************************************************************************/

/* wavelet.cpp

* Author: Collin

* Date: 2008/12/01

*/

/************************************************************************/

#include <cctype>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

#include "wavelet.h"

using namespace std;

using namespace Wavelet;

// Multi-level decomposition.
// Applies DWT nMaxLevel times to the running approximation. After every level
// the new detail band (and its length) is placed in front of what has been
// collected so far, so the coarsest band ends up first. The final
// approximation and its length are placed in front of everything; the last
// entry of L stays the original signal length.
C_L Wavelet::WaveDec(const vector<double>& signal,
                     const int nMaxLevel,
                     const char* strWaveName
                     )
{
    const WaveFilter& bank = WFilters(strWaveName, 'd');

    C_L packed;
    const int originalLength = signal.size();
    packed.L.push_back(originalLength);

    WaveCoeff stage;
    stage.app = signal;

    for (int level = 0; level < nMaxLevel; ++level){
        stage = DWT(stage.app, bank.Low, bank.High);

        // Prepend this level's detail band and its length.
        packed.C.insert(packed.C.begin(), stage.det.begin(), stage.det.end());
        packed.L.insert(packed.L.begin(), stage.det.size());
    }

    // The remaining approximation goes to the very front.
    packed.C.insert(packed.C.begin(), stage.app.begin(), stage.app.end());
    packed.L.insert(packed.L.begin(), stage.app.size());

    return packed;
}

// Reconstructs one branch (approximation 'a' or detail 'd') of a
// decomposition back to the length of the original signal.
//   a_or_d      : branch selector, case-insensitive.
//   C, L        : coefficient vector and length bookkeeping.
//   strWaveName : wavelet name passed on to WFilters.
//   nLevel      : branch level; valid range is [0, L.size()-2] for 'a' and
//                 [1, L.size()-2] for 'd'.
// Writes a message to stderr and calls exit(1) on an invalid selector or level.
vector<double> Wavelet::WRCoef(const char a_or_d,
                               const vector<double>& C,
                               const vector<int>& L,
                               const char* strWaveName,
                               const int nLevel
                               )
{
    const WaveFilter& recon = WFilters(strWaveName, 'r');
    const int deepest = L.size() - 2;
    const char kind = tolower(a_or_d);

    if (kind != 'a' && kind != 'd'){
        cerr << "bad parameter: a_or_d: "<< a_or_d << "\n";
        exit(1);
    }

    const bool wantApprox = ('a' == kind);
    const int shallowest = wantApprox ? 0 : 1;
    if (!(shallowest <= nLevel && nLevel <= deepest)){
        cerr << "bad parameter for level\n";
        exit(1);
    }

    vector<double> branch;
    if (wantApprox){
        branch = AppCoef(C, L, strWaveName, nLevel);
        // Level 0 approximation is already the reconstructed signal.
        if (0 == nLevel)
            return branch;
    }
    else {
        branch = DetCoef(C, L, nLevel);
    }

    // The first step uses the filter matching the branch; every further
    // step towards level 0 goes through the low-pass filter.
    const vector<double>& firstStage = wantApprox ? recon.Low : recon.High;

    const int base = L.size() - nLevel;
    branch = UpsConv1(branch, firstStage, L[base], "sym");

    for (int step = 1; step < nLevel; ++step){
        branch = UpsConv1(branch, recon.Low, L[base + step], "sym");
    }

    return branch;
}

// Up-sample and convolve.
// Inserts one zero between consecutive samples of `signal` (giving
// 2*size-1 samples), convolves the result with `filter` and returns the
// central nLen samples.
//   signal  : coefficients to up-sample.
//   filter  : reconstruction filter.
//   nLen    : number of central samples to keep. Values <= 0 give an empty
//             result; values larger than the convolved length return the
//             whole convolved sequence.
//   strMode : accepted for API compatibility; it is not inspected.
// An empty `signal` yields nLen zeros instead of reading signal[0].
vector<double> Wavelet::UpsConv1(const vector<double>& signal,
                                 const vector<double>& filter,
                                 const int nLen,
                                 const char* strMode
                                 )
{
    const size_t want = nLen > 0 ? static_cast<size_t>(nLen) : 0;

    // 2*size()-1 would wrap around for an empty input.
    if (signal.empty())
        return vector<double>(want, 0.0);

    // Zero-stuffing: samples land on the even positions, the odd ones stay 0.
    vector<double> y(2 * signal.size() - 1, 0.0);
    for (size_t i = 0; i < signal.size(); ++i){
        y[2 * i] = signal[i];
    }

    y = Conv(y, filter);

    // Nothing to trim when the request covers the whole sequence; this also
    // keeps the unsigned subtraction below from wrapping.
    if (want >= y.size())
        return y;

    //extract the central portion
    const size_t first = (y.size() - want) / 2;
    return vector<double>(y.begin() + first, y.begin() + first + want);
}

// Full linear convolution of two sequences.
// Returns vecSignal.size() + vecFilter.size() - 1 samples, or an empty vector
// when either input is empty (the size computation used to go negative or
// return a run of zeros in that case).
// The operands are referenced rather than copied; the longer one plays the
// role of the signal. The three loops keep the accumulation order of the
// earlier implementation so results stay bit-for-bit the same.
vector<double> Wavelet::Conv(const vector<double>& vecSignal, const vector<double>& vecFilter){
    if (vecSignal.empty() || vecFilter.empty())
        return vector<double>();

    const bool swapOperands = vecSignal.size() < vecFilter.size();
    const vector<double>& signal = swapOperands ? vecFilter : vecSignal;
    const vector<double>& filter = swapOperands ? vecSignal : vecFilter;

    const size_t lenSignal = signal.size();
    const size_t lenFilter = filter.size();

    vector<double> result(lenSignal + lenFilter - 1);

    // Head: the filter only partially overlaps the start of the signal.
    for (size_t i = 0; i < lenFilter; ++i){
        for (size_t j = 0; j <= i; ++j)
            result[i] += signal[j] * filter[i - j];
    }

    // Middle: the filter lies completely inside the signal.
    for (size_t i = lenFilter; i < lenSignal; ++i){
        for (size_t j = 0; j < lenFilter; ++j)
            result[i] += signal[i - j] * filter[j];
    }

    // Tail: the filter runs off the end of the signal.
    for (size_t i = lenSignal; i < lenSignal + lenFilter - 1; ++i){
        for (size_t j = i - lenSignal + 1; j < lenFilter; ++j)
            result[i] += signal[i - j] * filter[j];
    }

    return result;
}

// Extracts the detail coefficients of one level from a decomposition.
//   C      : coefficient vector (approximation first, then details from the
//            coarsest level to level 1).
//   L      : length bookkeeping; L.back() is the signal length and the
//            entries before it are the detail lengths of levels 1, 2, ...
//            when read backwards.
//   nLevel : level to extract, 1 .. L.size()-2.
// Writes a message to stderr and calls exit(1) when the level is out of range
// or when the lengths in L do not fit inside C.
vector<double> Wavelet::DetCoef(const vector<double>& C,
                                const vector<int>& L,
                                const int nLevel
                                )
{
    // Signed on purpose: L.size() - 2 wraps to a huge unsigned value when L
    // has fewer than two entries, which would let any level through.
    const int nMaxLevel = static_cast<int>(L.size()) - 2;
    if (nLevel < 1 || nLevel > nMaxLevel){
        cerr << "bad level parameter\n";
        exit(1);
    }

    // Walk backwards from the finest detail band (skipping the signal
    // length), summing the bands that lie behind the requested one.
    int nlast = 0;
    vector<int>::const_reverse_iterator it = L.rbegin();
    ++it;
    for (int i = 1; i < nLevel; ++i, ++it){
        nlast += *it;
    }
    const int nfirst = nlast + *it;

    // The requested band is C[end - nfirst, end - nlast); refuse lengths that
    // would put either bound outside C.
    if (nlast < 0 || nfirst < nlast || static_cast<size_t>(nfirst) > C.size()){
        cerr << "bad level parameter\n";
        exit(1);
    }

    return vector<double>(C.end() - nfirst, C.end() - nlast);
}

WaveCoeff Wavelet::DWT(const vector<double>& signal,

const vector<double>& Lo_D,

const vector<double>& Hi_D

)

{

int nLenExt = Lo_D.size() - 1;

vector<double> y;

y = WExtend(signal, nLenExt, "sym");

vector<double> z;

z = WConv1(y, Lo_D, "valid");

WaveCoeff coeff;

for (int i = 1; i < z.size(); i += 2){

coeff.app.push_back(z[i]);

}

z = WConv1(y, Hi_D, "valid");

for (int i = 1; i < z.size(); i += 2){

coeff.det.push_back(z[i]);

}

return coeff;

}

// Returns the filter pair for the named wavelet.
//   strWaveName : wavelet name; only "sym4" is available. A null pointer is
//                 treated like any other unknown name.
//   d_or_r      : 'd' selects the decomposition pair, 'r' the reconstruction
//                 pair (case-insensitive).
// Writes a message to stderr and calls exit(1) for an unknown wavelet or
// selector.
const WaveFilter& Wavelet::WFilters(const char* strWaveName,
                                    const char d_or_r
                                    )
{
    // strcmp must not be handed a null pointer.
    if (!strWaveName || 0 != strcmp(strWaveName, "sym4")){
        cerr << "not implement \n";
        exit(1);
    }

    // tolower requires a value representable as unsigned char.
    switch (tolower(static_cast<unsigned char>(d_or_r))){
    case 'd':
        return Wavelet::sym4_d;
    case 'r':
        return Wavelet::sym4_r;
    default:
        break;
    }

    cerr << "bad parameter for d_or_r\n";
    exit(1);
}

// Approximation coefficients at a given level.
// Starts from the coarsest approximation stored at the front of C and applies
// one inverse transform per level until nLevel is reached.
//   C, L        : coefficient vector and length bookkeeping.
//   strWaveName : wavelet name passed on to WFilters.
//   nLevel      : target level, 0 .. L.size()-2 (0 reconstructs the signal).
// Writes a message to stderr and calls exit(1) when nLevel is out of range.
vector<double> Wavelet::AppCoef(
    const vector<double>& C,
    const vector<int>& L,
    const char* strWaveName,
    const int nLevel
    )
{
    const int deepest = L.size() - 2;
    if (!(0 <= nLevel && nLevel <= deepest)){
        cerr << "bad parameter for level\n";
        exit(1);
    }

    const WaveFilter& recon = WFilters(strWaveName, 'r');

    // The first L[0] entries of C are the approximation of the deepest level.
    vector<double> approx(C.begin(), C.begin() + L[0]);

    // L[2] is the target length of the first inverse step, L[3] of the
    // second, and so on.
    int lengthIdx = 2;
    for (int level = deepest; level > nLevel; --level, ++lengthIdx){
        const vector<double> detail = DetCoef(C, L, level);
        approx = IDWT(approx, detail, recon.Low, recon.High, L[lengthIdx]);
    }

    return approx;
}

// Single-level inverse discrete wavelet transform.
// Up-samples and filters the approximation with Lo_R and the detail with
// Hi_R, trims both to nLenCentral samples and adds them.
//   app, det    : approximation and detail coefficients of one level.
//   Lo_R, Hi_R  : reconstruction low-pass / high-pass filters.
//   nLenCentral : requested length of the reconstructed sequence.
// The sum runs over the samples both branches actually contain, so a request
// that the branches cannot satisfy no longer indexes past their end.
vector<double> Wavelet::IDWT(
    const vector<double>& app,
    const vector<double>& det,
    const vector<double>& Lo_R,
    const vector<double>& Hi_R,
    const int nLenCentral
    )
{
    vector<double> lowBranch = UpsConv1(app, Lo_R, nLenCentral, "sym");
    const vector<double> highBranch = UpsConv1(det, Hi_R, nLenCentral, "sym");

    const vector<double>::size_type count =
        lowBranch.size() < highBranch.size() ? lowBranch.size() : highBranch.size();

    for (vector<double>::size_type i = 0; i < count; ++i){
        lowBranch[i] += highBranch[i];
    }

    return lowBranch;
}

// Symmetric (mirror) extension of a signal.
// Returns nLenExt mirrored samples, the signal itself, then nLenExt mirrored
// samples. The boundary sample is repeated, e.g. {1,2,3} extended by 2 gives
// {2,1, 1,2,3, 3,2}. When nLenExt exceeds the signal length the mirroring
// continues back and forth.
//   signal  : input samples; an empty signal yields an empty result.
//   nLenExt : samples to add on each side; values <= 0 return a plain copy.
//   mode    : accepted for API compatibility; it is not inspected, the
//             extension is always symmetric.
vector<double> Wavelet::WExtend(
    const vector<double>& signal,
    const int nLenExt,
    const char* mode
    )
{
    const int signalLen = signal.size();

    // There is nothing to mirror in an empty signal.
    if (signalLen == 0)
        return vector<double>();

    if (nLenExt <= 0)
        return signal;

    // The extended sequence is periodic with period 2*signalLen: one pass
    // forwards through the signal followed by one pass backwards.
    const int period = 2 * signalLen;
    const int totalLen = signalLen + 2 * nLenExt;

    vector<double> result(totalLen);
    for (int idx = 0; idx < totalLen; ++idx){
        // Position relative to the first real sample; negative on the left pad.
        int phase = (idx - nLenExt) % period;
        if (phase < 0)
            phase += period;

        // First half of the period reads forwards, second half backwards.
        const int signalIdx = phase < signalLen ? phase : period - 1 - phase;
        result[idx] = signal[signalIdx];
    }

    return result;
}

// "Valid" convolution: the part of the full convolution computed without
// implicit zero padding. filter.size()-1 samples are dropped from each end of
// Conv(signal, filter), leaving signal.size() - filter.size() + 1 samples.
//   signal : input samples.
//   filter : filter taps.
//   shape  : accepted for API compatibility; it is not inspected, the result
//            is always the valid part.
// Returns an empty vector when the filter is empty or longer than the signal;
// no sample is free of padding then, and the trimming range used to be
// inverted in that case.
vector<double> Wavelet::WConv1(
    const vector<double>& signal,
    const vector<double>& filter,
    const char* shape
    )
{
    if (filter.empty() || signal.size() < filter.size())
        return vector<double>();

    const vector<double> full = Conv(signal, filter);
    const vector<double>::size_type margin = filter.size() - 1;

    return vector<double>(full.begin() + margin, full.end() - margin);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: