Keras实现一个简单的Regression回归
2017-07-28 16:36
621 查看
linux平台下运行,使用Keras框架,其中构建神经网络很简单,例子中指构造了一层神经网络
通过深度学习,将图中的点回归成线性模型,学习直线的W和b
#import various of packages import numpy as np np.random.seed(1337) from keras.models import Sequential from keras.layers import Dense import matplotlib.pyplot as plt #create some data X=np.linspace(-1,1,200) np.random.shuffle(X) //最后训练出的结果W越接近0.5,b越接近2则效果越好 Y=0.5*X+2+np.random.normal(0,0.05,(200,)) #plot data plt.scatter(X,Y) plt.show() //前160个数据作为训练样本,用于训练模型 X_train,Y_train=X[:160],Y[:160] X_test,Y_test=X[160:],Y[160:] #build a neural network model=Sequential() //Dense为全连接层,设定输入和输出的维度,因为每次输入的是一个点,所以维度为1,输出也是一个点,所以维度也为1. model.add(Dense(output_dim=1,input_dim=1)) #choose loss function and optimizer //训练之前,编译,设置随时函数,和优化函数 model.compile(loss='mse',optimizer='sgd') #training //训练步骤,300次,每100次返回一个损失值 print('Training---------------') for step in range(301): //每次训练batch大小的数据量 cost=model.train_on_batch(X_train,Y_train) if step%100==0: print('train cost:',cost) #test print('\nTesting---------------') cost=model.evaluate(X_test,Y_test,batch_size=40) print('test cost:',cost) //返回权重项和偏置项 W,b=model.layers[0].get_weights() print('Weight=',W,'\nbiases=',b) #plotting the prediction //图形化的形式,显示出来 Y_pred=model.predict(X_test) plt.scatter(X_test,Y_test) plt.plot(X_test,Y_pred) plt.show()
运行结果:
最后训练结果W=0.49222, b=1.99950
相关文章推荐
- Keras实现一个简单的CNN的分类例子
- 模式识别与机器学习基础之1-一个简单的回归问题(regression problem)
- [深度学习] (3)- Keras实现一个简单的翻译器( 从数字到对应的英文 )
- tensorflow实现softmax回归(softmax regression)——简单的MNIST识别(第一课)
- [简单题]换一个思维,代码简洁度就完全变了(Python实现)
- Asp.Net MVC3 简单入门第一季(五) 通过Asp.Net MVC的区域功能实现将多个MVC项目部署到一个站点
- 一个简单的线程池实现
- 简单实现一个EventEmitter
- 一个自动更新的简单实现(通过反射解耦)
- Android UI开发: 横向ListView(HorizontalListView)及一个简单相册的完整实现 (附源码下载)
- java中集合的运用,实现一个简单的购物程序
- 一个简单string类的实现
- Android: 横向ListView(HorizontalListView)及一个简单相册的完整实现 (附源码下载)
- 一个简单的js实现的隔行变色脚本
- Android 一个简单的自定义WheelView实现
- Android UI开发: 横向ListView(HorizontalListView)及一个简单相册的完整实现 (附源码下载)
- Tcp/ip实验准备:一个简单的定时器——boost实现
- [shiro学习笔记]第二节 shiro与web融合实现一个简单的授权认证
- 使用python多线程实现一个简单spider
- 一个简单实用的内存池实现之二 (C实现)