您的位置:首页 > 理论基础 > 计算机网络

神经网络之自适应谐振网络ART及matlab实现——改进版

2016-10-03 12:42 585 查看
基础篇见点击打开链接

假如将样本重复输入三次产生的结果:

样本1属于第1类

样本2属于第2类

样本3属于第1类

样本4属于第3类

样本6属于第4类

样本7属于第3类

样本9属于第5类

样本10属于第1类

样本12属于第6类

样本13属于第4类

样本14属于第3类

样本16属于第7类

样本17属于第1类

样本19属于第8类

样本20属于第4类

样本21属于第3类

%%%%%%%%%%%%%%%%基础篇存在的问题%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

1、样本8无结果显示

2、相同样本结果输出不一致

解决方案:

问题一:很简单,在于在判定过程中对于产生新分类过程中缺少显示代码,加上就好,详情见下面的代码例程。

问题二:虽然程序可以完成对样本的分类,但是在后续的分类过程中,对于相同的数据程序会将其判定成不同的类别。究其原因在于这个竞争分类算法采用“胜者通吃”的原理(WTA)(余以为类似于贪心算法),每次都希望新样本在下次分类过程中有更大的相似度,在更新过程中会对老样本的记忆进行覆盖,导致下一次老样本输入网络并不能得到较大的相似度,从而导致老样本的分类与之前不一致。

整体来说,之前的程序对于相同样本存在不同分类,在实用过程中基本无使用价值。

解决思路:引入迭代过程,多次训练样本,使整个网络分类最终能够收敛。

MATLAB代码如下所示

ART1_main.m

%%%ART1神经网络练习
%%%作者:xd.wp
%%%时间:2016.09.30 19:37开始
clear all; clc;
%% 加载数据
train_data=[0 0 0 1 1 1 0;
0 0 0 1 1 0 0;
0 0 0 1 0 1 1;
1 0 1 0 1 0 1;
1 1 1 0 1 0 0;
1 1 0 0 0 0 1];                           % ART1网络输入为二进制,train_data为七个样本,每一列为一个样本
%验证部分
% train_data=[train_data,train_data,train_data,train_data,train_data,train_data,train_data,train_data,train_data];
data_length=size(train_data,1);
data_num=size(train_data,2);
N=100;
%% 网络参数初始化
R_node_num=3;
weight_b=ones(data_length,R_node_num)/N;
weight_t=ones(data_length,R_node_num);
threshold_ro=0.5;

%% 开始网络训练
result_pre=zeros(data_num,1);
for n=1:10
[R_node_num,weight_b,weight_t,result]=learn_iteration(train_data,R_node_num,weight_b,weight_t,threshold_ro,n);
if (result_pre==result)
disp('样本分类迭代完成!!!!');
break;
end
if (R_node_num==data_num+1)
disp('分类错误:样本类别数大于样本数');
break;
end
result_pre=result;
end
learn_iteration.m

function [R_node_num,weight_b,weight_t,result]=learn_iteration(train_data,R_node_num,weight_b,weight_t,threshold_ro,n)
data_length=size(train_data,1);
data_num=size(train_data,2);
result=zeros(data_num,1);
for i=1:data_num
R_node=zeros(R_node_num,1);                 %%匹配循环过程标志
for n=1:R_node_num                          %%匹配循环过程
%寻找竞争获胜神经元
for j=1:R_node_num
net(j,1)=sum(train_data(:,i).*weight_b(:,j));
end
[~,j_max]=max(net);
if R_node(j_max,1)==1                  %%循环激活标志判断
net(j_max,1)=-n;
end
[~,j_max]=max(net);
R_node(j_max,1)=1;                     %%去激活

%竞争获胜神经元通过外星权向量返回C层,进行相似度计算
weight_t_active=weight_t(:,j_max);
weight_b_active=weight_b(:,j_max);
Similarity_N0=sum(weight_t_active.*train_data(:,i));
Similarity_N1=sum(train_data(:,i));

if (threshold_ro<Similarity_N0/Similarity_N1)
[weight_t(:,j_max),weight_b(:,j_max)]=ART1_learn(train_data(:,i),weight_t_active,weight_b_active,n);
fprintf('样本%d属于第%d类\n',i,j_max);
result(i,1)=j_max;
flag=0;
break;
end
flag=1;
end
%% 判断是否需要添加新的结点,更新网络
if(flag==1)
%       ART1_updata_model()
R_node_num=R_node_num+1;
if (R_node_num==data_num+1)                               %判断分类数是否大于样本数
fprintf('样本%d属于第%d类\n Error:目前的分类类别数为%d \n',i,R_node_num,R_node_num);
return;
end
weight_b=[weight_b,train_data(:,i)];
weight_t=[weight_t,ones(data_length,1)];
fprintf('样本%d属于第%d类\n',i,R_node_num);
result(i,1)=R_node_num;
end
end
ART1_learn.m

function [weight_t_updata,weight_b_updata]=ART1_learn(train_data_active,weight_t_active,weight_b_active,n)

weight_t_updata=(n*weight_t_active+train_data_active.*weight_t_active)/(n+1);
% weight_t_updata=train_data_active.*weight_t_active;
weight_b_updata=weight_t_updata./(0.5+sum(weight_t_updata));
end


程序运行结果:

样本1属于第1类

样本2属于第1类

样本3属于第1类

样本4属于第2类

样本5属于第2类

样本6属于第2类

样本7属于第4类

样本1属于第4类

样本2属于第1类

样本3属于第1类

样本4属于第2类

样本5属于第4类

样本6属于第2类

样本7属于第4类

样本1属于第4类

样本2属于第1类

样本3属于第1类

样本4属于第2类

样本5属于第5类

样本6属于第5类

样本7属于第4类

样本1属于第4类

样本2属于第1类

样本3属于第1类

样本4属于第2类

样本5属于第5类

样本6属于第2类

样本7属于第4类

样本1属于第4类

样本2属于第1类

样本3属于第1类

样本4属于第2类

样本5属于第5类

样本6属于第2类

样本7属于第4类

样本分类迭代完成!!!!

程序分析:

程序的不足:

分类阈值大约就在0.5左右,当阈值增加时,程序会出现分类不收敛问题。

对应分析:可能是因为样本数量较少,维度较低造成的。当然也可能是程序迭代过程中有问题,这点还没有进行透彻分析。如有进展将在后边的文章给出。当然各路大神发现不足还望指点指点。。

程序迭代过程的解释:

<ul><li>weight_t_updata=train_data_active.*weight_t_active;</li></ul>


<ul><li>weight_t_updata=(n*weight_t_active+train_data_active.*weight_t_active)/(n+1);</li></ul>


上述第二行代码是第一行代码的改进,第一行代码运行过程中,识别阈值大概在0.4左右。 第二行代码,将后几轮迭代过程加入权值约束,轮数越大对网络权值更改的能力越低,最终完成程序的收敛,识别阈值达到0.5左右。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息