您的位置:首页 > 理论基础 > 计算机网络

选择性地加载网络模型的前几层进行训练(27)---《深度学习》

2017-11-22 23:15 176 查看
加载预训练模型的前几层,并在其后拼接自己构建的层进行训练

注意:这里我们使用 nets.inception.inception_v3_base 来部分恢复网络模型。因为 nets.inception.inception_v3_base 可以通过 final_endpoint 参数指定网络截止到哪一层;随后在构造 tf.train.Saver 时只传入需要恢复的变量列表(variables_to_restore),这样调用 saver.restore 时就只会恢复这些权值,而自己新加的层则不需要恢复!

train.py

#-*-coding=utf-8-*-
from PIL import Image
import os
import os.path
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import inception_resnet_v2
import img_convert

# Image geometry fed to the network's input placeholder (height, width, RGB).
height, width, channels = 299, 299, 3
# Module-level class count. Presumably the 1001-way output of the original
# ImageNet checkpoint. NOTE(review): the __main__ block shadows this with 182.
num_classes = 1001

def convert(dir, size=(299, 299)):
    """Load every image in a directory as an RGB pixel array.

    Args:
        dir: Directory whose entries are all image files PIL can open.
        size: ``(width, height)`` each image is resized to before conversion;
            defaults to the original hard-coded 299x299.

    Returns:
        A list with one ``(height, width, 3)`` ``uint8`` array per file,
        ordered by sorted filename so repeated runs see the same order.
    """
    arr_col = []
    for file_name in sorted(os.listdir(dir)):
        file_path = os.path.join(dir, file_name)
        # Close the underlying file handle; resize() loads the pixel data,
        # so the returned image stays valid after the file is closed.
        with Image.open(file_path) as img:
            rgb = img.resize(size).convert("RGB")
        # An RGB image converts directly to (height, width, 3) with channels
        # interleaved per pixel. Concatenating the separate R/G/B planes on
        # axis 0 and reshaping would instead reinterpret the memory and
        # scramble pixels across channels.
        arr_col.append(np.array(rgb))
    return arr_col
def convert_3_2_4_dims(arr_):
    """Stack a list of equally shaped 3-D arrays into one 4-D float64 array.

    Args:
        arr_: Non-empty sequence of 3-D arrays, each shaped like ``arr_[0]``.

    Returns:
        A new ``float64`` array of shape ``(len(arr_),) + arr_[0].shape`` whose
        ``k``-th slice holds a copy of ``arr_[k]``.
    """
    first = arr_[0]
    stacked = np.zeros((len(arr_), first.shape[0], first.shape[1], first.shape[2]))
    for index, item in enumerate(arr_):
        stacked[index, :, :, :] = item
    return stacked
if __name__ == "__main__":
    # Directory of training images; every entry must be a PIL-readable image.
    o_dir = "E:/test"
    num_classes = 182
    batch_size = 3
    epoches = 2

    X = tf.placeholder(tf.float32, shape=[None, height, width, channels])
    y = tf.placeholder(tf.float32, shape=[None, num_classes])

    # Build the Inception-v3 trunk only up to 'Mixed_7c'. The original
    # classifier head is never created, so its checkpoint weights are unused.
    with slim.arg_scope(nets.inception.inception_v3_arg_scope()):
        net, _end_points = nets.inception.inception_v3_base(
            X, final_endpoint='Mixed_7c')
    # Snapshot the variables that exist now (the trunk) BEFORE the new head
    # is created. Only these are restored from the checkpoint below.
    variables_to_restore = slim.get_variables_to_restore()

    # Flatten the feature map into one vector per image.
    dim = 1
    for d in net.get_shape().as_list()[1:]:
        dim *= d
    fc_ = tf.reshape(net, [-1, dim])

    # New, randomly initialised classification head sized by num_classes.
    fc0_weights = tf.get_variable(
        name="fc0_weights", shape=(dim, num_classes),
        initializer=tf.contrib.layers.xavier_initializer())
    fc0_biases = tf.get_variable(
        name="fc0_biases", shape=(num_classes,),
        initializer=tf.zeros_initializer())
    logits_ = tf.nn.bias_add(tf.matmul(fc_, fc0_weights), fc0_biases)
    predictions = tf.nn.softmax(logits_)
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits_))

    # slim.batch_norm places its moving-mean/variance updates in UPDATE_OPS.
    # They only run if the train op depends on them.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.GradientDescentOptimizer(1e-6).minimize(cross_entropy)

    correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(predictions, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.Session() as sess:
        batches = img_convert.data_lrn(
            img_convert.load_data(o_dir, num_classes, batch_size))
        # Initialise everything (including the new head), then overwrite the
        # trunk variables with the pretrained checkpoint values.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(variables_to_restore)
        saver.restore(sess, os.path.join("E:\\", "inception_v3.ckpt"))

        # Evaluate on the first batch, using ITS OWN labels. Pairing batch 0's
        # images with batch 1's labels measured nothing meaningful.
        eval_images, eval_labels = batches[0][0], batches[0][1]
        for epoch in range(epoches):
            for batch in batches:
                sess.run(train_step, feed_dict={X: batch[0], y: batch[1]})
            acc = sess.run(accuracy, feed_dict={X: eval_images, y: eval_labels})
            print(acc)
    print("Done")


img_convert.py

#coding=utf-8
from PIL import Image
import os
import os.path
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
import inception_resnet_v2

def convert(dir, size=(299, 299)):
    """Load every image in a directory as an RGB pixel array.

    Args:
        dir: Directory whose entries are all image files PIL can open.
        size: ``(width, height)`` each image is resized to before conversion;
            defaults to the original hard-coded 299x299.

    Returns:
        A list with one ``(height, width, 3)`` ``uint8`` array per file,
        ordered by sorted filename so repeated runs see the same order.
    """
    arr_col = []
    for file_name in sorted(os.listdir(dir)):
        file_path = os.path.join(dir, file_name)
        # Close the underlying file handle; resize() loads the pixel data,
        # so the returned image stays valid after the file is closed.
        with Image.open(file_path) as img:
            rgb = img.resize(size).convert("RGB")
        # An RGB image converts directly to (height, width, 3) with channels
        # interleaved per pixel. Concatenating the separate R/G/B planes on
        # axis 0 and reshaping would instead reinterpret the memory and
        # scramble pixels across channels.
        arr_col.append(np.array(rgb))
    return arr_col
def convert_3_2_4_dims(arr_):
    """Stack a list of equally shaped 3-D arrays into one 4-D float64 array.

    Args:
        arr_: Non-empty sequence of 3-D arrays, each shaped like ``arr_[0]``.

    Returns:
        A new ``float64`` array of shape ``(len(arr_),) + arr_[0].shape`` whose
        ``k``-th slice holds a copy of ``arr_[k]``.
    """
    first = arr_[0]
    stacked = np.zeros((len(arr_), first.shape[0], first.shape[1], first.shape[2]))
    for index, item in enumerate(arr_):
        stacked[index, :, :, :] = item
    return stacked
def to_categorial(y, n_classes):
    """One-hot encode integer class ids.

    Args:
        y: Sequence of integer class ids, each a valid column index for
            ``n_classes`` columns.
        n_classes: Number of columns in the output.

    Returns:
        A ``float64`` array of shape ``(len(y), n_classes)`` with a single
        ``1.0`` per row at the column given by the matching entry of ``y``.
    """
    one_hot = np.zeros([len(y), n_classes])
    for row, class_id in enumerate(y):
        one_hot[row, class_id] = 1.0
    return one_hot
def batch_list(x, y, batch_size):
    """Split paired sequences into consecutive mini-batches.

    Args:
        x: Sliceable sequence of samples (list or numpy array).
        y: Sliceable sequence of targets aligned with ``x``.
        batch_size: Maximum number of samples per batch; must be >= 1.

    Returns:
        A list of ``[x_chunk, y_chunk]`` pairs covering ``x`` in order. The
        final batch holds the remainder when ``len(x)`` is not a multiple of
        ``batch_size``. An empty ``x`` yields ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    # Stepping the start index by batch_size yields the trailing partial
    # batch naturally. There is no need to read a loop variable after the
    # loop, which is undefined when len(x) < batch_size.
    return [
        [x[start:start + batch_size], y[start:start + batch_size]]
        for start in range(0, len(x), batch_size)
    ]
def load_data(dir, num_classes, batch_size):
    """Load a directory of images and split it into randomly labelled batches.

    Args:
        dir: Directory of image files, passed to ``convert``.
        num_classes: Number of classes; labels are drawn from
            ``[0, num_classes)`` and one-hot encoded with this width.
        batch_size: Maximum samples per batch, passed to ``batch_list``.

    Returns:
        The ``batch_list`` result: ``[images, one_hot_labels]`` pairs where
        the images are ``float32``.
    """
    images = convert_3_2_4_dims(convert(dir)).astype(np.float32)
    # No ground-truth labels were specified for these images, so a random
    # class id is assigned to each one (demo only).
    class_ids = (np.random.rand(images.shape[0]) * num_classes).astype("int")
    one_hot = to_categorial(class_ids, num_classes)
    return batch_list(images, one_hot, batch_size)

def data_lrn(batches):
    """Scale each batch's sample array by 1/255, in place.

    Args:
        batches: List of ``[samples, labels]`` pairs whose ``samples`` support
            in-place true division (e.g. float numpy arrays).

    Returns:
        The same ``batches`` object, after its sample arrays were divided.
    """
    for entry in batches:
        features = entry[0]
        features /= 255
        entry[0] = features
    return batches
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: