您的位置:首页 > 其它

MXNet example: training a 3-layer MLP on MNIST with CSVIter

2016-07-27 17:17 585 查看
import mxnet as mx

def mlp():
    """Build the symbolic graph for a 3-layer MLP classifier.

    Architecture: 784-input -> FC(512) -> ReLU -> FC(512) -> ReLU
    -> FC(10) -> softmax, i.e. a standard MNIST classifier.

    Returns:
        mx.sym.Symbol: the SoftmaxOutput symbol of the network; the
        data variable is named 'data' and the output 'softmax'.
    """
    data = mx.sym.Variable('data')
    fc1 = mx.sym.FullyConnected(data, name="fc1", num_hidden=512)
    act1 = mx.sym.Activation(fc1, name="relu1", act_type="relu")
    fc2 = mx.sym.FullyConnected(act1, name="fc2", num_hidden=512)
    act2 = mx.sym.Activation(fc2, name="relu2", act_type="relu")
    fc3 = mx.sym.FullyConnected(act2, name="fc3", num_hidden=10)
    # NOTE: originally the result was bound to a local named `mlp`,
    # shadowing the function itself; renamed for clarity.
    net = mx.sym.SoftmaxOutput(fc3, name="softmax")
    return net

if __name__ == "__main__":
    # ---- configure logging first so all training output is captured ----
    import logging
    LOG_FILE = 'mnist.log'
    logging.basicConfig(filename=LOG_FILE, level=logging.DEBUG)

    num_epoch = 3
    batch_size = 100

    # MNIST as CSV: each data row is a flattened 28x28 image, each label
    # row a single class id. -- assumes these files exist in the CWD.
    train_dataiter = mx.io.CSVIter(data_csv="mnist.train",
                                   data_shape=(28, 28),
                                   label_csv="label.train",
                                   label_shape=(1,),
                                   batch_size=batch_size)
    val_dataiter = mx.io.CSVIter(data_csv="mnist.val",
                                 data_shape=(28, 28),
                                 label_csv="label.val",
                                 label_shape=(1,),
                                 batch_size=batch_size)

    # Keep the network in `net` instead of rebinding the name `mlp`,
    # which would shadow the builder function.
    net = mlp()

    # Extra keyword arguments for FeedForward.
    # FactorScheduler(500, 0.9): multiply the learning rate by 0.9
    # every 500 batches.
    model_args = {'lr_scheduler': mx.lr_scheduler.FactorScheduler(500, 0.9)}

    model = mx.model.FeedForward(
        ctx=mx.gpu(0),
        symbol=net,
        # Fix: the original hard-coded num_epoch=5 here, silently
        # ignoring the num_epoch variable defined above.
        num_epoch=num_epoch,
        learning_rate=0.01,
        momentum=0.9,
        wd=0.01,
        **model_args)

    model.fit(
        X=train_dataiter,
        eval_data=val_dataiter,
        # Report throughput every 50 batches.
        batch_end_callback=mx.callback.Speedometer(batch_size, 50))
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  mxnet