博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow中滑动平均模型介绍
阅读量:5996 次
发布时间:2019-06-20

本文共 1819 字,大约阅读时间需要 6 分钟。

内容移至:

内容总结于《TensorFlow实战Google深度学习框架》

不知道大家有没有听过一阶滞后滤波法:

new_value=(1a)×value+a×old_valuenew_value=(1−a)×value+a×old_value

其中a的取值范围[0,1],具体就是:本次滤波结果=(1-a)本次采样值+a上次滤波结果,采用此算法的目的是:

1、降低周期性的干扰;
2、在波动频率较高的场合有很好的效果。


而在TensorFlow中提供了tf.train.ExponentialMovingAverage 来实现滑动平均模型,在采用随机梯度下降算法训练神经网络时,使用其可以提高模型在测试数据上的健壮性(robustness)。

TensorFlow下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率decay。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个待更新的变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

shadow_variable=decay×shadow_variable+(1decay)×variableshadow_variable=decay×shadow_variable+(1−decay)×variable

上述公式与之前介绍的一阶滞后滤波法的公式相比较,会发现有很多相似的地方,从名字上面也可以很好的理解这个简约不简单算法的原理:平滑、滤波,即使数据平滑变化,通过调整参数来调整变化的稳定性。

在滑动平滑模型中, decay 决定了模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.999或0.9999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

decay=min{
decay,1+num_updates10+num_updates}
decay=min{decay,1+num_updates10+num_updates}

用一段书中代码带解释如何使用滑动平均模型:

import tensorflow as tfv1 = tf.Variable(0, dtype=tf.float32)//初始化v1变量step = tf.Variable(0, trainable=False) //初始化step为0ema = tf.train.ExponentialMovingAverage(0.99, step) //定义平滑类,设置参数以及stepmaintain_averages_op = ema.apply([v1]) //定义更新变量平均操作with tf.Session() as sess:    # 初始化    init_op = tf.global_variables_initializer()    sess.run(init_op)    print sess.run([v1, ema.average(v1)])    # 更新变量v1的取值    sess.run(tf.assign(v1, 5))    sess.run(maintain_averages_op)    print sess.run([v1, ema.average(v1)])     # 更新step和v1的取值    sess.run(tf.assign(step, 10000))      sess.run(tf.assign(v1, 10))    sess.run(maintain_averages_op)    print sess.run([v1, ema.average(v1)])           # 更新一次v1的滑动平均值    sess.run(maintain_averages_op)    print sess.run([v1, ema.average(v1)])

output:

[0.0, 0.0][5.0, 4.5][10.0, 4.5549998][10.0, 4.6094499]

转载地址:http://kjhlx.baihongyu.com/

你可能感兴趣的文章
【在线研讨-现场文字】《敏捷开发用户故事分类与组织结构(一期-1)》2012-06-26...
查看>>
ln 命令
查看>>
光纤故障判断
查看>>
HTML与XHTML的区别
查看>>
未来十年我们拼什么?
查看>>
IntelliJ IDEA Export to Eclipse Android工程不能正常被Eclipse识别的解决方法
查看>>
我的友情链接
查看>>
SVN服务注册
查看>>
6.NIO2-Path、Paths、Files
查看>>
如何对EDM邮件的用户数据进行分类
查看>>
《职场经验》
查看>>
ups机制下停电提前关闭oracle数据库
查看>>
Python基础学习篇-4-常用的正则表达式处理函数
查看>>
Linux下基础命令(五)
查看>>
python re库-----学习(正则表达式)
查看>>
python 变量赋值,引用,初始化问题
查看>>
[20180813]刷新共享池与父子游标.txt
查看>>
Win下部署Django开发环境
查看>>
malloc,calloc,alloca和free函数
查看>>
Python 时间处理
查看>>