
tf.concat (API r1.3)

2017-11-15 14:03

1. tf.concat

concat(
    values,
    axis,
    name='concat'
)


Defined in tensorflow/python/ops/array_ops.py.

See the guide: Tensor Transformations > Slicing and Joining

Concatenates tensors along one dimension.

Concatenates the list of tensors values along dimension axis. If values[i].shape = [D0, D1, ..., Daxis(i), ..., Dn], the concatenated result has shape

    [D0, D1, ..., Raxis, ..., Dn]

where

    Raxis = sum(Daxis(i))

That is, the data from the input tensors is joined along the axis dimension.

The number of dimensions of the input tensors must match, and all dimensions except axis must be equal.
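The shape rule above can be sketched in plain Python without TensorFlow. The helper name concat_shape is hypothetical and not part of the TensorFlow API; it only models how the output shape is derived and which inputs are rejected, under the assumption that shapes are given as lists of integers.

```python
def concat_shape(shapes, axis):
    """Model the output shape of a concatenation (illustrative helper, not a TF function)."""
    if not shapes:
        raise ValueError("need at least one input shape")
    rank = len(shapes[0])
    if not 0 <= axis < rank:
        raise ValueError("axis must be in [0, rank)")
    result = list(shapes[0])
    for shape in shapes[1:]:
        # The number of dimensions of every input must match.
        if len(shape) != rank:
            raise ValueError("all inputs must have the same number of dimensions")
        # Every dimension except `axis` must be equal.
        for d in range(rank):
            if d != axis and shape[d] != result[d]:
                raise ValueError("dimension %d differs: %d vs %d" % (d, result[d], shape[d]))
        # Raxis is the sum of the inputs' sizes along `axis`.
        result[axis] += shape[axis]
    return result


print(concat_shape([[2, 3], [2, 3]], 0))  # [4, 3]
print(concat_shape([[2, 3], [2, 3]], 1))  # [2, 6]
```

Passing shapes that differ outside the chosen axis, for example [2, 3] and [2, 4] with axis 0, raises ValueError in this sketch, which mirrors the constraint stated above.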

For example:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]

# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0)) ==> [4, 3]
tf.shape(tf.concat([t3, t4], 1)) ==> [2, 6]


Note: If you are concatenating along a new axis consider using stack. E.g.

tf.concat([tf.expand_dims(t, axis) for t in tensors], axis)


can be rewritten as

tf.stack(tensors, axis=axis)
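To see why the two forms agree, it may help to compare the shapes they produce. The sketch below works on shapes only, in plain Python; expand_dims_shape, concat_shape and stack_shape are hypothetical helpers written for this illustration, not TensorFlow functions, and they assume all inputs share one shape.

```python
def expand_dims_shape(shape, axis):
    # Insert a dimension of size 1 at position `axis`.
    return list(shape[:axis]) + [1] + list(shape[axis:])


def concat_shape(shapes, axis):
    # Sizes along `axis` add up; all other dimensions are unchanged.
    result = list(shapes[0])
    result[axis] = sum(s[axis] for s in shapes)
    return result


def stack_shape(shapes, axis):
    # Stacking N inputs creates a new dimension of size N at position `axis`.
    return list(shapes[0][:axis]) + [len(shapes)] + list(shapes[0][axis:])


shapes = [[2, 3], [2, 3], [2, 3]]
for axis in range(3):
    via_concat = concat_shape([expand_dims_shape(s, axis) for s in shapes], axis)
    via_stack = stack_shape(shapes, axis)
    assert via_concat == via_stack
    print(axis, via_stack)
# 0 [3, 2, 3]
# 1 [2, 3, 3]
# 2 [2, 3, 3]
```

Each expanded input contributes a size of 1 along the new axis, so concatenating N of them yields size N there, which is the dimension that stacking adds directly.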

Args:

values: A list of Tensor objects or a single Tensor.

axis: 0-D int32 Tensor. Dimension along which to concatenate.

name: A name for the operation (optional).

Returns:

A Tensor resulting from concatenation of the input tensors.

2. example 1

import tensorflow as tf
import numpy as np

t1 = tf.constant([[0, 1, 2], [3, 4, 5]], dtype=np.float32)
t2 = tf.constant([[6, 7, 8], [9, 10, 11]], dtype=np.float32)

matrix0 = tf.concat([t1, t2], 0)
matrix1 = tf.concat([t1, t2], 1)

ops_shape0 = tf.shape(tf.concat([t1, t2], 0))
ops_shape1 = tf.shape(tf.concat([t1, t2], 1))

with tf.Session() as sess:
    input_t1 = sess.run(t1)
    print("input_t1.shape:")
    print(input_t1.shape)
    print('\n')

    input_t2 = sess.run(t2)
    print("input_t2.shape:")
    print(input_t2.shape)
    print('\n')

    # Concatenate along axis 0 (rows are appended).
    output_t1 = sess.run(matrix0)
    print("output_t1.shape:")
    print(output_t1.shape)
    print("output_t1:")
    print(output_t1)
    print('\n')

    # Concatenate along axis 1 (columns are appended).
    output_t2 = sess.run(matrix1)
    print("output_t2.shape:")
    print(output_t2.shape)
    print("output_t2:")
    print(output_t2)
    print('\n')

    # tf.shape returns the shape as a tensor, evaluated at run time.
    output_shape0 = sess.run(ops_shape0)
    output_shape1 = sess.run(ops_shape1)
    print("output_shape0:")
    print(output_shape0)
    print("output_shape1:")
    print(output_shape1)

output:

input_t1.shape:
(2, 3)

input_t2.shape:
(2, 3)

output_t1.shape:
(4, 3)
output_t1:
[[  0.   1.   2.]
 [  3.   4.   5.]
 [  6.   7.   8.]
 [  9.  10.  11.]]

output_t2.shape:
(2, 6)
output_t2:
[[  0.   1.   2.   6.   7.   8.]
 [  3.   4.   5.   9.  10.  11.]]

output_shape0:
[4 3]
output_shape1:
[2 6]

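For 2-D tensors, axis 0 is the row dimension and axis 1 is the column dimension: concatenating along axis 0 appends rows, and concatenating along axis 1 appends columns.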
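The same row and column behaviour can be reproduced on nested Python lists, which may make the axis convention easier to see. This is only an illustration with hypothetical helper names (concat_rows, concat_cols); it does not use TensorFlow, and it reuses the values of t1 and t2 from example 1.

```python
t1 = [[0, 1, 2], [3, 4, 5]]
t2 = [[6, 7, 8], [9, 10, 11]]


def concat_rows(a, b):
    # Axis 0: the rows of b are appended after the rows of a.
    return a + b


def concat_cols(a, b):
    # Axis 1: each row of a is extended with the matching row of b.
    return [row_a + row_b for row_a, row_b in zip(a, b)]


print(concat_rows(t1, t2))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
print(concat_cols(t1, t2))
# [[0, 1, 2, 6, 7, 8], [3, 4, 5, 9, 10, 11]]
```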

3. example 2

import tensorflow as tf
import numpy as np

t1 = tf.constant([[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]], dtype=np.float32)
t2 = tf.constant([[[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]]], dtype=np.float32)

matrix0 = tf.concat([t1, t2], 0)
matrix1 = tf.concat([t1, t2], 1)
matrix2 = tf.concat([t1, t2], 2)

ops_shape0 = tf.shape(tf.concat([t1, t2], 0))
ops_shape1 = tf.shape(tf.concat([t1, t2], 1))
ops_shape2 = tf.shape(tf.concat([t1, t2], 2))

with tf.Session() as sess:
    input_t1 = sess.run(t1)
    print("input_t1.shape:")
    print(input_t1.shape)
    print('\n')

    input_t2 = sess.run(t2)
    print("input_t2.shape:")
    print(input_t2.shape)
    print('\n')

    # Concatenate along axis 0 (outermost dimension).
    output_t1 = sess.run(matrix0)
    print("output_t1.shape:")
    print(output_t1.shape)
    print("output_t1:")
    print(output_t1)
    print('\n')

    # Concatenate along axis 1 (rows).
    output_t2 = sess.run(matrix1)
    print("output_t2.shape:")
    print(output_t2.shape)
    print("output_t2:")
    print(output_t2)
    print('\n')

    # Concatenate along axis 2 (columns).
    output_t3 = sess.run(matrix2)
    print("output_t3.shape:")
    print(output_t3.shape)
    print("output_t3:")
    print(output_t3)
    print('\n')

    output_shape0 = sess.run(ops_shape0)
    output_shape1 = sess.run(ops_shape1)
    output_shape2 = sess.run(ops_shape2)
    print("output_shape0:")
    print(output_shape0)
    print("output_shape1:")
    print(output_shape1)
    print("output_shape2:")
    print(output_shape2)

output:

input_t1.shape:
(1, 4, 3)

input_t2.shape:
(1, 4, 3)

output_t1.shape:
(2, 4, 3)
output_t1:
[[[  0.   1.   2.]
  [  3.   4.   5.]
  [  6.   7.   8.]
  [  9.  10.  11.]]

 [[ 12.  13.  14.]
  [ 15.  16.  17.]
  [ 18.  19.  20.]
  [ 21.  22.  23.]]]

output_t2.shape:
(1, 8, 3)
output_t2:
[[[  0.   1.   2.]
  [  3.   4.   5.]
  [  6.   7.   8.]
  [  9.  10.  11.]
  [ 12.  13.  14.]
  [ 15.  16.  17.]
  [ 18.  19.  20.]
  [ 21.  22.  23.]]]

output_t3.shape:
(1, 4, 6)
output_t3:
[[[  0.   1.   2.  12.  13.  14.]
  [  3.   4.   5.  15.  16.  17.]
  [  6.   7.   8.  18.  19.  20.]
  [  9.  10.  11.  21.  22.  23.]]]

output_shape0:
[2 4 3]
output_shape1:
[1 8 3]
output_shape2:
[1 4 6]

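For 3-D tensors, axis 0 is the outermost (depth) dimension, axis 1 is the row dimension, and axis 2 is the column dimension.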
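One way to picture concatenation along an arbitrary axis is as a recursion over nested lists: descend axis levels, then join the lists found there. The sketch below is a plain-Python model under that assumption. concat_nested and shape_of are hypothetical helpers, not TensorFlow functions, and the data reuses the values from example 2.

```python
def concat_nested(a, b, axis):
    # At the target level, joining is plain list concatenation.
    if axis == 0:
        return a + b
    # Otherwise pair up the sub-lists and recurse one level deeper.
    return [concat_nested(x, y, axis - 1) for x, y in zip(a, b)]


def shape_of(x):
    # Walk down the first element of each level to read off the sizes.
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return shape


t1 = [[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]]
t2 = [[[12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23]]]

for axis in range(3):
    print(axis, shape_of(concat_nested(t1, t2, axis)))
# 0 [2, 4, 3]
# 1 [1, 8, 3]
# 2 [1, 4, 6]
```

The printed shapes match output_shape0, output_shape1 and output_shape2 from the run above.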
