博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
facenet 中心损失函数(center loss)详解(代码分析)含tf.gather() 和 tf.scatter_sub()函数
阅读量:3904 次
发布时间:2019-05-23

本文共 3260 字,大约阅读时间需要 10 分钟。

我们来解读一下,中心损失,再来看代码。

链接:

我们的重点是分析代码,所以定义部分,大家详情参见上面的博客。

代码:

#coding=gbk'''Created on 2020年4月20日@author: DELL'''import tensorflow as tfimport numpy as npdata = [[1,1,1,1,1],        [1,1,2,1,1],        [1,1,3,1,1],        [1,1,4,1,1],        [2,2,2,1,2],        [2,2,2,2,2],        [2,2,2,3,2],        [3,3,3,3,1],        [3,3,3,3,2]]label = [0,0,0,0,1,1,1,2,2]data = np.array(data,dtype = 'float32')label = np.array(label)data = tf.convert_to_tensor(data)label = tf.convert_to_tensor(label)def center_loss(features, label, alfa, nrof_classes):        """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"       (http://ydwen.github.io/papers/WenECCV16.pdf)    """    nrof_features = features.get_shape()[1]        centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,initializer=tf.constant_initializer(0), trainable=False)    #定义一个全零的centers, [nrof_classes, nrof_features]->(类别数,特征维度)       #print(sess.run(centers))        label = tf.reshape(label, [-1]) #一维向量    centers_batch = tf.gather(centers, label) #[batch_size,nrof_features] #按照label将centers归类,形成的新矩阵维度为 [label_size,nrof_features]    diff = (1 - alfa) * (centers_batch - features) #乘上我们的因子alfa [label_size,nrof_features]    centers = tf.scatter_sub(centers, label, diff) #按照label用centers - diff,产生本次的centers        with tf.control_dependencies([centers]):#注意这个函数的作用,是限制计算顺序的,即先计算centers,在利用计算好的centers去计算centers_batch以求loss        loss = tf.reduce_mean(tf.square(features - centers_batch))            return loss, centers,features,centers_batch,features - centers_batchloss, cen, fea, cen_bat,a = center_loss(data,label,0.5,3)sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)print(sess.run(cen))#print(sess.run(loss))print(sess.run(fea))#print(sess.run(cen_bat))print(sess.run(a))print(sess.run(fea - cen_bat))print(sess.run(tf.square(fea - cen_bat)))print(sess.run(loss)) '''验证tf.scatter_sub函数sess = tf.Session()ref = tf.Variable([1, 2, 3],dtype = tf.int32)indices = tf.constant([0, 0, 1, 1],dtype = tf.int32)updates = tf.constant([9, 10, 11, 12],dtype = tf.int32)sub = tf.scatter_sub(ref, indices, updates)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print (sess.run(sub))'''

结果:

1.centers:[[2.  2.  5.  2.  2. ] [3.  3.  3.  3.  3. ] [3.  3.  3.  3.  1.5]]2.features:[[1. 1. 1. 1. 1.] [1. 1. 2. 1. 1.] [1. 1. 3. 1. 1.] [1. 1. 4. 1. 1.] [2. 2. 2. 1. 2.] [2. 2. 2. 2. 2.] [2. 2. 2. 3. 2.] [3. 3. 3. 3. 1.] [3. 3. 3. 3. 2.]]3.centers_batch[[2.  2.  5.  2.  2. ] [2.  2.  5.  2.  2. ] [2.  2.  5.  2.  2. ] [2.  2.  5.  2.  2. ] [3.  3.  3.  3.  3. ] [3.  3.  3.  3.  3. ] [3.  3.  3.  3.  3. ] [3.  3.  3.  3.  1.5] [3.  3.  3.  3.  1.5]]4.features - centers_batch[[-1.  -1.  -4.  -1.  -1. ] [-1.  -1.  -3.  -1.  -1. ] [-1.  -1.  -2.  -1.  -1. ] [-1.  -1.  -1.  -1.  -1. ] [-1.  -1.  -1.  -2.  -1. ] [-1.  -1.  -1.  -1.  -1. ] [-1.  -1.  -1.   0.  -1. ] [ 0.   0.   0.   0.  -0.5] [ 0.   0.   0.   0.   0.5]]5.loss1.4111111

主要用到的函数:1.tf.gather(data,labels),将data按labels扩充

                             2.tf.scatter_sub(data,label,data_1),按label用data - data_

                             3.with tf.control_dependencies(): ,限制运算顺序

在实验验证时注意的点是:不要多次sess.run()某个张量涉及到带有依赖关系的张量,比如这里的loss,计算loss时 会 主动更新一次值,导致运算结果出错。原理我还没搞清,日后补上

转载地址:http://lmten.baihongyu.com/

你可能感兴趣的文章
145. 二叉树的后序遍历
查看>>
2. 两数相加
查看>>
3. 无重复字符的最长子串
查看>>
5. 最长回文子串
查看>>
4. 两个排序数组的中位数
查看>>
10. 正则表达式匹配
查看>>
23. 合并K个元素的有序链表
查看>>
32. 最长有效括号
查看>>
6. Z字形转换
查看>>
8. 字符串转整数(atoi)
查看>>
12. 整数转罗马数字
查看>>
15. 三数之和
查看>>
16. 最接近的三数之和
查看>>
18. 四数之和
查看>>
22. 括号生成
查看>>
24. 两两交换链表中的节点
查看>>
71. 简化路径
查看>>
77. 组合
查看>>
78. 子集
查看>>
89. 格雷编码
查看>>