您的位置:首页 > 其它

predict.py backup

2016-05-25 10:55 309 查看
import mxnet as mx
import numpy as np
from skimage import io, transform

# Load the trained network from the checkpoint saved under `prefix`
# (checkpoint number `num_round`) and run it on the CPU, one sample per batch.
# NOTE(review): the prefix mentions MNIST, but PreprocessImage produces
# 3x224x224 inputs; confirm this matches the network's expected input shape.
prefix = "model/mnist-0"
num_round = 2
model = mx.model.FeedForward.load(
    prefix,
    num_round,
    ctx=mx.cpu(),
    numpy_batch_size=1,
)

# Mean image subtracted from every input. Moving axis 2 to the front is the
# same permutation as the pair swapaxes(0, 2) then swapaxes(1, 2); presumably
# the file is stored as (H, W, C) so this yields (C, H, W) — TODO confirm.
mean_img = np.moveaxis(np.load("mean_img.npy"), 2, 0)
#print mean_img

#synset=[l.strip() for l in open("Inception/synset-2w.txt").readlines()]

def PreprocessImage(path, verbose=False):
    """Load an image and turn it into a mean-subtracted (1, 3, 224, 224) batch.

    The image is center-cropped to a square along its shorter side, resized
    to 224x224, rescaled, moved to channel-first layout, and the module-level
    ``mean_img`` is subtracted from it.

    Args:
        path: Path of the image file, read with ``skimage.io.imread``.
        verbose: If true, print the normalized array and its shape. The
            original script always printed them; this is now opt-in.

    Returns:
        numpy.ndarray of shape (1, 3, 224, 224), ready for ``model.predict``.
    """
    img = io.imread(path)

    # The reshape at the end requires exactly three channels. Give grayscale
    # images an explicit channel axis, replicate single-channel images, and
    # drop an alpha channel; otherwise swapaxes/reshape would fail.
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.shape[2] == 4:
        img = img[:, :, :3]

    # Center crop to a square whose side is the shorter image edge.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    resized_img = transform.resize(crop_img, (224, 224))
    # resize() returns floats (in [0, 1] for integer input, per skimage docs);
    # scale back up toward the 0-255 pixel range.
    # NOTE(review): 256 rather than 255 is kept from the original; confirm it
    # matches how mean_img.npy was computed.
    sample = np.asarray(resized_img) * 256

    # (H, W, C) -> (C, H, W): the two swaps move the channel axis to the front.
    sample = np.swapaxes(sample, 0, 2)
    sample = np.swapaxes(sample, 1, 2)
    normed_img = sample - mean_img

    if verbose:
        print(normed_img)
        print(normed_img.shape)

    # Add the leading batch dimension of size 1.
    return normed_img.reshape(1, 3, 224, 224)

# Run the classifier on one image and print its output vector.
batch = PreprocessImage("1999.jpg")
# The batch holds a single image, so row 0 of the predictions is its result.
prob = model.predict(batch)[0]
# Function-call form prints identically under Python 2 and also parses on 3.
print(prob)
# To report a class name, load a synset list (one label per line) and index
# it with the highest-scoring class, e.g. synset[int(np.argmax(prob))].
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: