您的位置:首页 > 编程语言 > Python开发

Python3.5实现BP算法小实例

2017-03-31 23:09 211 查看


具体问题描述在下面PPT中,PPT链接: http://pan.baidu.com/s/1b3GRHc

import random, math
import numpy as np

# 输入值
a0 = [[1.78, 1.14, -1], [1.96, 1.18, -1], [1.86, 1.20, -1], [1.72, 1.24, -1], [2.00, 1.26, -1],
[2.00,1.28,-1], [1.96,1.30,-1], [1.74,1.36,-1], [1.64,1.38,-1], [1.82,1.38,-1],
[1.90,1.38,-1], [1.70,1.40,-1], [1.82,1.48,-1], [1.82,1.54,-1], [2.08,1.56,-1]]
# 目标值
t = [0.9,0.9,0.9,0.1,0.9,0.9,0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1]
# 权重
w1 = np.random.randn(2,3)
w2 = np.random.randn(1,3)

# 学习率
ate = 0.1
# 迭代 100 轮
for Epoch in range(15000):
k = Epoch % 15
# 第一层输出加工一
u01 = w1[0][0] * a0[k][0] + w1[0][1] * a0[k][1] + w1[0][2] * a0[k][2]
u02 = w1[1][0] * a0[k][0] + w1[1][1] * a0[k][1] + w1[1][2] * a0[k][2]
# 第一层输出加工二,即激活函数
a1 = [0, 0, -1]
a1[0] = 1 / (1 + math.exp(-u01))
a1[1] = 1 / (1 + math.exp(-u02))

# 第二层输出加工一
u11 = w2[0][0] * a1[0] + w2[0][1] * a1[1] + w2[0][2] * a1[2]
# 第二层输出加工二
a2 = 1 / (1 + math.exp(-u11))

# 倒数第一层偏倒
a21 = math.exp(-u11) / math.pow((1 + math.exp(-u11)), 2)
ww2 = (t[k] - a2) * a21
# 其实 ww2 是需要乘以其对应的 a1[i] 才算是某个权重的真正的偏倒数,这里为了简化编程没有乘,
# 下面三行代码更新权重时会乘上,下文 ww1 同理
w2[0][0] = w2[0][0] + ate * ww2 * a1[0]
w2[0][1] = w2[0][1] + ate * ww2 * a1[1]
w2[0][2] = w2[0][2] + ate * ww2 * a1[2]

# 倒数第二层偏倒,有两行权重,所以有两个偏导数
a11 = math.exp(-u01) / math.pow((1 + math.exp(-u01)), 2)
a12 = math.exp(-u02) / math.pow((1 + math.exp(-u02)), 2)
ww1 = [0, 0]
ww1[0] = (ww2 * w2[0][0]) * a11
ww1[1] = (ww2 * w2[0][1]) * a12
# 更新权重
for i in range(2):
for j in range(3):
w1[i][j] = w1[i][j] + ate * ww1[i] * a0[k][j]

# 测试最后结果
for k in range(15):
# 第一层输出加工一
u01 = w1[0][0] * a0[k][0] + w1[0][1] * a0[k][1] + w1[0][2] * a0[k][2]
u02 = w1[1][0] * a0[k][0] + w1[1][1] * a0[k][1] + w1[1][2] * a0[k][2]
# 第一层输出加工二
a1 = [0, 0, -1]
a1[0] = 1 / (1 + math.exp(-u01))
a1[1] = 1 / (1 + math.exp(-u02))

# 第二层输出加工一
u11 = w2[0][0] * a1[0] + w2[0][1] * a1[1] + w2[0][2] * a1[2]
# 第二层输出加工二
a2 = 1 / (1 + math.exp(-u11))

# 计算的值如果大于0.5则为第一类,否则为第二类,t[k]是正确答案
if(a2 > 0.5):
print(0.9, t[k])
else:
print(0.1, t[k])
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: