您的位置:首页 > 理论基础 > 计算机网络

MXNet | LeNet-5(卷积神经网络)用于手写字识别

2017-01-25 17:57 483 查看
卷积神经网络用于手写字识别,数据集来自kaggle的竞赛项目MNIST

卷积神经网络参考:http://yann.lecun.com/exdb/lenet/

比赛的官网:https://www.kaggle.com/c/digit-recognizer

若是下载数据集困难,可以去我的百度网盘下载:链接:http://pan.baidu.com/s/1sl50KjV 密码:ca56

读取数据集,这里用readr中的函数read_csv,读取速度快高效

读数据

>setwd("F:\\迅雷下载\\mnist")

>require(mxnet)
>library(readr)
>train <- read_csv('train.csv')
>test <- read_csv('test.csv')


数据处理,训练集和测试集

>train <- data.matrix(train)
>test <- data.matrix(test)

>train.x <- train[,-1]
>train.y <- train[,1]


数据放缩到[0,1]

>train.x <- t(train.x/255)
>test <- t(test/255)


查看数据类别的平衡性

>table(train.y)

train.y
0    1    2    3    4    5    6    7    8    9
4132 4684 4177 4351 4072 3795 4137 4401 4063 4188


可知,不同类之间比较平衡,数据质量较好

构建网络

# input
data <- mx.symbol.Variable('data')
# first conv
conv1 <- mx.symbol.Convolution(data=data, kernel=c(5,5), num_filter=20)
tanh1 <- mx.symbol.Activation(data=conv1, act_type="tanh")
pool1 <- mx.symbol.Pooling(data=tanh1, pool_type="max",
kernel=c(2,2), stride=c(2,2))
# second conv
conv2 <- mx.symbol.Convolution(data=pool1, kernel=c(5,5), num_filter=50)
tanh2 <- mx.symbol.Activation(data=conv2, act_type="tanh")
pool2 <- mx.symbol.Pooling(data=tanh2, pool_type="max",
kernel=c(2,2), stride=c(2,2))
# first fullc
flatten <- mx.symbol.Flatten(data=pool2)
fc1 <- mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 <- mx.symbol.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 <- mx.symbol.FullyConnected(data=tanh3, num_hidden=10)


定义损失函数

# loss,softmax
lenet <- mx.symbol.SoftmaxOutput(data=fc2)


矩阵变为数组

##matrices into arrays
train.array <- train.x
dim(train.array) <- c(28, 28, 1, ncol(train.x))
test.array <- test
dim(test.array) <- c(28, 28, 1, ncol(test))


定义CPU和GPU

n.gpu <- 1
device.cpu <- mx.cpu()
device.gpu <- lapply(0:(n.gpu-1), function(i) {
mx.gpu(i)
})


模型定义:CPU

>#定义随机种子
>mx.set.seed(0)
>#当前的时间
>tic <- proc.time()
>model <- mx.model.FeedForward.create(lenet, X=train.array, y=train.y,
ctx=device.cpu, num.round=1, array.batch.size=100,
learning.rate=0.05, momentum=0.9, wd=0.00001,
eval.metric=mx.metric.accuracy,
epoch.end.callback=mx.callback.log.train.metric(100))

Start training with 1 devices
[1] Train-accuracy=0.557112171837709

>#基于cpu计算的时间
>print(proc.time() - tic)
用户    系统    流逝
502.69    2.66 3063.02


模型定义:GPU

> #随机种子
> mx.set.seed(0)
> #当前时间
> tic <- proc.time()
> #模型定义
> model <- mx.model.FeedForward.create(lenet, X=train.array, y=train.y,
ctx=device.gpu, num.round=5, array.batch.size=100,
learning.rate=0.05, momentum=0.9, wd=0.00001,
eval.metric=mx.metric.accuracy,
epoch.end.callback=mx.callback.log.train.metric(100))

Start training with 1 devices
Error in mx.nd.internal.empty.array(shape, ctx) :
[17:08:47] src/storage/storage.cc:78: Compile with USE_CUDA=1 to enable GPU usage
[17:08:47] C:/Users/qkou/mxnet/dmlc-core/include/dmlc/logging.h:235: [17:08:47] src/storage/storage.cc:78: Compile with USE_CUDA=1 to enable GPU usage

> print(proc.time() - tic)

用户 系统 流逝
0.26 0.33 4.68


预测

preds <- predict(model, test.array)
pred.label <- max.col(t(preds)) - 1


提交

submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)


得到符合提交要求的csv文件,直接去kaggle官网提交即可看到你的成绩排名。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐