您的位置:首页 > 其它

TensorFlow 多分类标签转换成One-hot

2017-11-15 21:29 911 查看

TensorFlow 多分类标签转换成One-hot

在处理多分类问题时,将多分类标签转成One-hot编码是一种很常见的手段,以下即为Tensorflow将标签转成One-hot的tensor。以Mnist为例,如果标签为“3”,则One-hot编码为[0,0,0,1,0,0,0,0,0,0].

import tensorflow as tf  # version : 1.4

NUM_CLASSES = 10 # 10分类
labels = [0,1,2,3] # sample label
batch_size = tf.size(labels) # get size of labels : 4
labels = tf.expand_dims(labels, 1) # 增加一个维度
indices = tf.expand_dims(tf.range(0, batch_size,1), 1) #生成索引
concated = tf.concat([indices, labels] , 1) #作为拼接
onehot_labels = tf.sparse_to_dense(concated, tf.stack([batch_size, NUM_CLASSES]), 1.0, 0.0) # 生成one-hot编码的标签


sparse_to_dense 函数说明:https://www.tensorflow.org/api_docs/python/tf/sparse_to_dense

将稀疏矩阵转换成密集矩阵,其中索引在concated中,值为1.其他位置的值为默认值0.
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: