
Extracting and Visualizing Features with Caffe

2018-01-26 16:21
Use Caffe ( http://caffe.berkeleyvision.org ) to run a CNN and extract its features, store them as LMDB for later use, and optionally visualize them.

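Classifying images with a pre-trained model

The notebook at http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb already explains in detail how to classify test images with a pre-trained model. Because Caffe is updated continuously, that page and its code change as well. Only the main steps that ran at the time of writing are recorded here.

Download Caffe and install its dependencies.

From caffe_root, run
./scripts/download_model_binary.py models/bvlc_reference_caffenet
to fetch the pre-trained CaffeNet.

Run the following code in IPython to load the network (plain Python also works, but the IPython-only lines, those starting with % or !, must be commented out).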

./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Make sure that caffe is on the python path:

caffe_root = '../'  # this file is expected to be in {caffe_root}/examples
import sys
sys.path.insert(0, caffe_root + 'python')

import caffe

plt.rcParams['figure.figsize'] = (10, 10)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

import os
if not os.path.isfile(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'):
    print("Downloading pre-trained CaffeNet model...")
    !../scripts/download_model_binary.py ../models/bvlc_reference_caffenet

Set the network to the test phase, and load the model definition (prototxt), the trained weights, and the dataset mean (mean_npy).

./models/bvlc_reference_caffenet/deploy.prototxt

./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

./python/caffe/imagenet/ilsvrc_2012_mean.npy

caffe.set_mode_cpu()
net = caffe.Net(caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt',
                caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
                caffe.TEST)

# input preprocessing: 'data' is the name of the input blob == net.inputs[0]

transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2,0,1))
transformer.set_mean('data', np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy').mean(1).mean(1)) # mean pixel
transformer.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
transformer.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB
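
To make the effect of these four settings concrete, here is a NumPy-only sketch of the same steps applied to a single image. It is an approximation for illustration, not Caffe's implementation: the function name preprocess_sketch and the mean pixel values are made up, and resizing to the network input size is left out.

```python
import numpy as np

def preprocess_sketch(image_hwc_rgb, mean_bgr, raw_scale=255.0):
    """Mimic the Transformer settings above on an H x W x 3 RGB image in [0, 1]."""
    x = image_hwc_rgb.astype(np.float32)
    x = x.transpose((2, 0, 1))           # set_transpose: H x W x C -> C x H x W
    x = x[[2, 1, 0], :, :]               # set_channel_swap: RGB -> BGR
    x = x * raw_scale                    # set_raw_scale: [0, 1] -> [0, 255]
    x = x - mean_bgr[:, None, None]      # set_mean: subtract the per-channel mean pixel
    return x

# Hypothetical values, for illustration only (not the real ILSVRC mean).
mean_bgr = np.array([100.0, 110.0, 120.0], dtype=np.float32)
image = np.zeros((4, 5, 3), dtype=np.float32)
image[..., 0] = 1.0                      # a pure red image

blob = preprocess_sketch(image, mean_bgr)
print(blob.shape)        # (3, 4, 5)
print(blob[:, 0, 0])     # B, G, R values after scaling and mean subtraction
```

Note that the mean is subtracted after the channel swap, so it has to be given in the network's channel order (BGR here).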

Load a test image and predict its class.

./examples/images/cat.jpg

# set net to batch size of 50

net.blobs['data'].reshape(50,3,227,227)

net.blobs['data'].data[...] = transformer.preprocess('data', caffe.io.load_image(caffe_root + 'examples/images/cat.jpg'))
out = net.forward()
print("Predicted class is #{}.".format(out['prob'][0].argmax()))
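
Since one preprocessed image is assigned to a blob with batch size 50, NumPy broadcasting copies it into all 50 slots, and every row of out['prob'] is the same. The sketch below uses a random array as a stand-in for out['prob'] (an assumption for illustration; no network is run) to show how to read one prediction per image when the batch holds different images.

```python
import numpy as np

rng = np.random.RandomState(0)

# Stand-in for out['prob']: 50 images x 1000 classes, each row sums to 1.
scores = rng.rand(50, 1000)
prob = scores / scores.sum(axis=1, keepdims=True)

# One predicted class per image in the batch.
per_image = prob.argmax(axis=1)
print(per_image.shape)            # (50,)

# Prediction for the first image only.
first = prob[0].argmax()
print(first == per_image[0])      # True

# argmax() with no axis works on the flattened array, so its result is an
# index into 50 * 1000 values; np.unravel_index recovers (image, class).
flat = prob.argmax()
image_idx, class_idx = np.unravel_index(flat, prob.shape)
print(class_idx == per_image[image_idx])   # True
```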

Load the labels and print the top-k predictions.

./data/ilsvrc12/synset_words.txt

# load labels

imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'
try:
    labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')
except:
    !../data/ilsvrc12/get_ilsvrc_aux.sh
    labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')

# sort top k predictions from softmax output
top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]
print(labels[top_k])
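
The slice [-1:-6:-1] after argsort() walks the ascending index order backwards from the last element, which yields the five highest-scoring indices, best first. A small sketch with made-up probabilities:

```python
import numpy as np

# Made-up scores for 8 classes, for illustration only.
prob = np.array([0.02, 0.30, 0.05, 0.01, 0.25, 0.07, 0.20, 0.10])

order = prob.argsort()           # indices from lowest to highest score
top_k = order[-1:-6:-1]          # last five indices, highest score first
print(top_k)                     # [1 4 6 7 5]
print(prob[top_k])               # the five scores in descending order
```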

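Extracting and visualizing features

Continuing from the previous section: to visualize features directly after extraction, without storing them first, follow these steps.

The network's features (activations) are stored in net.blobs, and the weights and biases are stored in net.params. The code below prints the name and shape of each entry; they can also be saved manually at this point.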

[(k, v.data.shape) for k, v in net.blobs.items()]
[(k, v[0].data.shape) for k, v in net.params.items()]
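
To store the activations by hand, each blob's array can be copied and written with NumPy. In the sketch below a plain dict of random arrays stands in for net.blobs (an assumption so that the example runs without Caffe); with a real network, the dict would be built as {name: blob.data.copy() for name, blob in net.blobs.items()}. The output file name is arbitrary.

```python
import os
import tempfile

import numpy as np

# Stand-in for {name: blob.data.copy() for name, blob in net.blobs.items()}.
features = {
    'conv1': np.random.rand(1, 96, 55, 55).astype(np.float32),
    'fc7': np.random.rand(1, 4096).astype(np.float32),
}

# Write all arrays into one compressed .npz archive.
out_path = os.path.join(tempfile.mkdtemp(), 'features.npz')
np.savez_compressed(out_path, **features)

# Load them back by name.
loaded = np.load(out_path)
for name in sorted(loaded.files):
    print('{} {}'.format(name, loaded[name].shape))
```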

Visualization. The following is a helper function.

# take an array of shape (n, height, width) or (n, height, width, channels)
# and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)
def vis_square(data, padsize=1, padval=0):
    # normalize into a new array so the caller's data (e.g. net.params) is not modified in place
    data = (data - data.min()) / (data.max() - data.min())

    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))

    # tile the filters into an image
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

    plt.imshow(data)
    plt.show()
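
The padding and reshaping in vis_square are easier to follow with concrete shapes. This sketch repeats only the tiling step, without normalization or plotting, for 36 feature maps of 55 x 55 (the size of a 36-channel slice of the conv1 output); the function name tile_sketch is introduced here for illustration.

```python
import numpy as np

def tile_sketch(data, padsize=1, padval=0):
    """Tile an (n, height, width) array into one 2-D image, as vis_square does."""
    n = int(np.ceil(np.sqrt(data.shape[0])))          # grid is n x n
    padding = ((0, n ** 2 - data.shape[0]),           # add blank tiles to fill the grid
               (0, padsize), (0, padsize))            # border below and to the right
    data = np.pad(data, padding, mode='constant', constant_values=padval)
    # (n*n, h, w) -> (n, n, h, w) -> (n, h, n, w) -> (n*h, n*w)
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3))
    return data.reshape((n * data.shape[1], n * data.shape[3]))

maps = np.ones((36, 55, 55), dtype=np.float32)
grid = tile_sketch(maps)
print(grid.shape)     # (336, 336): 6 tiles of 55 + 1 pixels per side
```

When n is not a perfect square, the grid is rounded up and the unused cells are filled with padval.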
Using each layer's name, choose the layers to visualize; both the filters (parameters) and the outputs (features) can be visualized.

# the parameters are a list of [weights, biases]

filters = net.params['conv1'][0].data
vis_square(filters.transpose(0, 2, 3, 1))

feat = net.blobs['conv1'].data[0, :36]
vis_square(feat, padval=1)

# There are 256 filters, each of which has dimension 5 x 5 x 48. We show only the first 48 filters, with each channel shown separately, so that each filter is a row.

filters = net.params['conv2'][0].data
vis_square(filters[:48].reshape(48**2, 5, 5))

# rectified, only the first 36 of 256 channels

feat = net.blobs['conv2'].data[0, :36]
vis_square(feat, padval=1)

feat = net.blobs['conv3'].data[0]
vis_square(feat, padval=0.5)

feat = net.blobs['conv4'].data[0]
vis_square(feat, padval=0.5)

feat = net.blobs['conv5'].data[0]
vis_square(feat, padval=0.5)

feat = net.blobs['pool5'].data[0]
vis_square(feat, padval=1)

feat = net.blobs['fc6'].data[0]
plt.subplot(2, 1, 1)
plt.plot(feat.flat)
plt.subplot(2, 1, 2)
_ = plt.hist(feat.flat[feat.flat > 0], bins=100)

feat = net.blobs['fc7'].data[0]
plt.subplot(2, 1, 1)
plt.plot(feat.flat)
plt.subplot(2, 1, 2)
_ = plt.hist(feat.flat[feat.flat > 0], bins=100)

feat = net.blobs['prob'].data[0]
plt.plot(feat.flat)

Extracting and storing features

Caffe provides a feature extraction tool; see http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html

Select the images to extract features from.

./examples/_temp

mkdir examples/_temp
find `pwd`/examples/images -type f -exec echo {} \; > examples/_temp/temp.txt
sed "s/$/ 0/" examples/_temp/temp.txt > examples/_temp/file_list.txt
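
The same file list can be produced from Python. The sketch below is an equivalent of the find and sed commands, assuming every regular file under the image directory is an image; the function name write_file_list is introduced here, and the label 0 is a dummy value as in the sed command.

```python
import os
import tempfile

def write_file_list(image_dir, list_path, label=0):
    """Write one 'absolute_path label' line per file found under image_dir."""
    lines = []
    for root, _dirs, files in os.walk(image_dir):
        for name in sorted(files):
            path = os.path.abspath(os.path.join(root, name))
            lines.append('{} {}'.format(path, label))
    with open(list_path, 'w') as f:
        f.write('\n'.join(lines) + '\n')
    return lines

# Demonstration on a temporary directory with two empty placeholder files.
demo_dir = tempfile.mkdtemp()
for name in ('cat.jpg', 'fish-bike.jpg'):
    open(os.path.join(demo_dir, name), 'w').close()

out_dir = tempfile.mkdtemp()
list_path = os.path.join(out_dir, 'file_list.txt')
entries = write_file_list(demo_dir, list_path)
print(len(entries))
```

For the real run, image_dir would be examples/images and list_path would be examples/_temp/file_list.txt.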

As before, download the model and define the prototxt.

./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

./examples/_temp/imagenet_val.prototxt

./data/ilsvrc12/get_ilsvrc_aux.sh
cp examples/feature_extraction/imagenet_val.prototxt examples/_temp

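Use the extract_features.bin tool to extract features and store them as LMDB. Its arguments are
extract_features.bin $MODEL $PROTOTXT $LAYER $LMDB_OUTPUT_PATH $NUM_MINI_BATCHES $DB_TYPE
The fifth argument is the number of mini-batches to run, not the batch size; the batch size itself is set in the data layer of the prototxt.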


./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel

./examples/_temp/imagenet_val.prototxt

./examples/_temp/features

./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 lmdb
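
When driving the tool from a script, the number of mini-batches can be derived from the number of images and the batch size declared in the prototxt. The sketch below only builds the argument list and does not run anything; the helper name, the image count of 120, and the batch size of 50 are illustrative assumptions (check the batch_size in your own prototxt).

```python
import math

def extract_features_args(model, prototxt, layer, out_path,
                          num_images, batch_size, db_type='lmdb'):
    """Build the argv list for extract_features.bin (nothing is executed)."""
    # Enough mini-batches to cover every image at least once.
    num_mini_batches = int(math.ceil(num_images / float(batch_size)))
    return ['./build/tools/extract_features.bin',
            model, prototxt, layer, out_path,
            str(num_mini_batches), db_type]

args = extract_features_args(
    'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
    'examples/_temp/imagenet_val.prototxt',
    'fc7',
    'examples/_temp/features',
    num_images=120,      # hypothetical number of lines in file_list.txt
    batch_size=50)       # hypothetical; must match the prototxt
print(' '.join(args))
# The list could then be passed to subprocess.check_call(args) from caffe_root.
```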

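Visualizing from the feature file

Following http://www.cnblogs.com/platero/p/3967208.html and the lmdb documentation at https://lmdb.readthedocs.org/en/release , read the LMDB file, convert it to a mat file, and then load the mat file in MATLAB for visualization.

Install Caffe's Python dependencies, and use the following two helper files to convert the LMDB to mat.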

./feat_helper_pb2.py


# Generated by the protocol buffer compiler.  DO NOT EDIT!

from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf import descriptor_pb2

# @@protoc_insertion_point(imports)

DESCRIPTOR = descriptor.FileDescriptor(
name='datum.proto',
package='feat_extract',
serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')

_DATUM = descriptor.Descriptor(
name='Datum',
full_name='feat_extract.Datum',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
descriptor.FieldDescriptor(
name='channels', full_name='feat_extract.Datum.channels', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
descriptor.FieldDescriptor(
name='height', full_name='feat_extract.Datum.height', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
descriptor.FieldDescriptor(
name='width', full_name='feat_extract.Datum.width', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
descriptor.FieldDescriptor(
name='data', full_name='feat_extract.Datum.data', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value="",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
descriptor.FieldDescriptor(
name='label', full_name='feat_extract.Datum.label', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
descriptor.FieldDescriptor(
name='float_data', full_name='feat_extract.Datum.float_data', index=5,
number=6, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
extension_ranges=[],
serialized_start=29,
serialized_end=134,
)

DESCRIPTOR.message_types_by_name['Datum'] = _DATUM

class Datum(message.Message):
    __metaclass__ = reflection.GeneratedProtocolMessageType
    DESCRIPTOR = _DATUM

    # @@protoc_insertion_point(class_scope:feat_extract.Datum)

# @@protoc_insertion_point(module_scope)

./lmdb2mat.py


import sys
import time

import lmdb
import numpy as np
import scipy.io as sio

import feat_helper_pb2


def main(argv):
    # usage: lmdb2mat.py LMDB BATCH_NUM BATCH_SIZE DIM OUT_MAT
    lmdb_name = argv[1]
    print(lmdb_name)
    batch_num = int(argv[2])
    batch_size = int(argv[3])
    dim = int(argv[4])
    window_num = batch_num * batch_size  # number of feature vectors to read

    start = time.time()
    db = lmdb.open(lmdb_name, readonly=True)
    datum = feat_helper_pb2.Datum()
    ft = np.zeros((window_num, dim))

    with db.begin() as txn:
        # the cursor yields (key, value) pairs in key order
        for im_idx, (key, value) in enumerate(txn.cursor()):
            if im_idx >= window_num:
                break
            datum.ParseFromString(value)
            ft[im_idx, :] = datum.float_data

    print('time 1: %f' % (time.time() - start))
    sio.savemat(argv[5], {'feats': ft})
    print('time 2: %f' % (time.time() - start))
    print('done!')


if __name__ == '__main__':
    main(sys.argv)

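Run the following shell script:

#!/usr/bin/env sh

LMDB=./examples/_temp/features_fc7 # path to the LMDB directory (must match the output path given to extract_features.bin)
BATCHNUM=1
BATCHSIZE=10

# DIM=290400 # feature length, conv1
# DIM=43264 # conv5
DIM=4096
OUT=./examples/_temp/features_fc7.mat # path where the mat file is saved
python ./lmdb2mat.py $LMDB $BATCHNUM $BATCHSIZE $DIM $OUT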
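
DIM is the length of one flattened feature vector, that is channels x height x width of the chosen blob. The shapes below are the CaffeNet blob shapes that reproduce the three values in the script; for another layer or network, read the shape from net.blobs instead.

```python
# Per-image blob shapes (channels, height, width) for three CaffeNet layers.
blob_shapes = {
    'conv1': (96, 55, 55),
    'conv5': (256, 13, 13),
    'fc7': (4096, 1, 1),
}

def feature_dim(shape):
    """Number of values in one flattened feature vector."""
    dim = 1
    for size in shape:
        dim *= size
    return dim

for name in ('conv1', 'conv5', 'fc7'):
    print('{} {}'.format(name, feature_dim(blob_shapes[name])))
# conv1 290400
# conv5 43264
# fc7 4096
```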

Following the display_network function from UFLDL, visualize the features in the mat file.

display_network.m


function [h, array] = display_network(A, opt_normalize, opt_graycolor, cols, opt_colmajor)
% This function visualizes filters in matrix A. Each column of A is a
% filter. We will reshape each column into a square image and visualizes
% on each cell of the visualization panel.
% All other parameters are optional, usually you do not need to worry
% about it.
% opt_normalize: whether we need to normalize the filter so that all of
% them can have similar contrast. Default value is true.
% opt_graycolor: whether we use gray as the heat map. Default is true.
% cols: how many columns are there in the display. Default value is the
% squareroot of the number of columns in A.
% opt_colmajor: you can switch convention to row major for A. In that
% case, each row of A is a filter. Default value is false.
warning off all

if ~exist('opt_normalize', 'var') || isempty(opt_normalize)
    opt_normalize= true;
end

if ~exist('opt_graycolor', 'var') || isempty(opt_graycolor)
    opt_graycolor= true;
end

if ~exist('opt_colmajor', 'var') || isempty(opt_colmajor)
    opt_colmajor = false;
end

% rescale
A = A - mean(A(:));

if opt_graycolor, colormap(gray); end

% compute rows, cols
[L M]=size(A);
sz=sqrt(L);
buf=1;
if ~exist('cols', 'var')
    if floor(sqrt(M))^2 ~= M
        n=ceil(sqrt(M));
        while mod(M, n)~=0 && n<1.2*sqrt(M), n=n+1; end
        m=ceil(M/n);
    else
        n=sqrt(M);
        m=n;
    end
else
    n = cols;
    m = ceil(M/n);
end

array=-ones(buf+m*(sz+buf),buf+n*(sz+buf));

if ~opt_graycolor
    array = 0.1.* array;
end

if ~opt_colmajor
    k=1;
    for i=1:m
        for j=1:n
            if k>M,
                continue;
            end
            clim=max(abs(A(:,k)));
            if opt_normalize
                array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/clim;
            else
                array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/max(abs(A(:)));
            end
            k=k+1;
        end
    end
else
    k=1;
    for j=1:n
        for i=1:m
            if k>M,
                continue;
            end
            clim=max(abs(A(:,k)));
            if opt_normalize
                array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/clim;
            else
                array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)';
            end
            k=k+1;
        end
    end
end

if opt_graycolor
    h=imagesc(array);
else
    h=imagesc(array,'EraseMode','none',[-1 1]);
end
axis image off

drawnow;

warning on all

Run the following code in MATLAB:

nsample = 2;
% num_output = 96; % conv1
% num_output = 256; % conv5
num_output = 4096; % fc7

load features_fc7.mat
width = size(feats, 2);
nmap = width / num_output;

for i = 1 : nsample
    feat = feats(i, :);
    feat = reshape(feat, [nmap num_output]);
    figure('name', sprintf('image #%d', i));
    display_network(feat);
end
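
MATLAB's reshape fills column by column, so each column of feat holds nmap consecutive values, which is one channel's map in Caffe's channel-major layout. For readers working in Python, the sketch below reproduces that split in NumPy on a small made-up vector (3 channels of 2 x 2, not real features) and shows that it matches a direct reshape to (channels, height, width).

```python
import numpy as np

num_output = 3                     # channels (made-up size for illustration)
height, width = 2, 2
nmap = height * width              # values per channel map

# One flattened feature vector, as stored in a row of feats.
feat = np.arange(num_output * nmap, dtype=np.float32)

# Same as MATLAB reshape(feat, [nmap num_output]): column-major fill.
columns = feat.reshape((nmap, num_output), order='F')

# Direct NumPy view of the same data as (channels, height, width).
maps = feat.reshape((num_output, height, width))

for k in range(num_output):
    # Column k of the MATLAB-style matrix is channel k, flattened row by row.
    assert np.array_equal(columns[:, k], maps[k].ravel())
print(maps[1])
```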

Reading the mat file in Python

In Python, scipy.io.loadmat() reads a mat file and returns a dict.


import scipy.io
matfile = 'features_fc7.mat'
data = scipy.io.loadmat(matfile)
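
The returned dict holds the saved variables under their MATLAB names, next to a few metadata keys that start with a double underscore. A small round-trip sketch, using an in-memory buffer and random values in place of features_fc7.mat so that it runs without the real file:

```python
import io

import numpy as np
import scipy.io

# Stand-in for the real features: 10 vectors of length 4096.
feats = np.random.rand(10, 4096)

# Write to and read from an in-memory buffer instead of a file on disk.
buf = io.BytesIO()
scipy.io.savemat(buf, {'feats': feats})
buf.seek(0)
data = scipy.io.loadmat(buf)

# Skip metadata entries such as '__header__' and '__version__'.
names = [k for k in data if not k.startswith('__')]
print(names)                    # ['feats']
print(data['feats'].shape)      # (10, 4096)
```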

Using your own network

Simply replace the files and parameters listed above with your own.

References

http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb

http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html

http://www.cnblogs.com/platero/p/3967208.html

https://lmdb.readthedocs.org/en/release/