本篇内容主要讲解“Pytorch中retain_graph的坑如何解决”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Pytorch中retain_graph的坑如何解决”吧!在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True) 让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True免费云主机域名)就可以单独的计算梯度,屡试不爽。但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中可能有多次backward()时,因为前一次调用backward()时已经释放掉了buffer,所以下一次调用时会因为buffers不存在而报错错误使用optimizer.zero_grad()
清空过往梯度;loss1.backward(retain_graph=True)
反向传播,计算当前梯度;loss2.backward(retain_graph=True)
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)正确使用optimizer.zero_grad()
清空过往梯度;loss1.backward(retain_graph=True)
反向传播,计算当前梯度;loss2.backward()
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了到此,相信大家对“Pytorch中retain_graph的坑如何解决”有了更深的了解,不妨来实际操作一番吧!这里是百云主机网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!
这篇文章主要介绍了el-menu如何实现横向溢出截取的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇el-menu如何实现横向溢出截取文章都会有所收获,下面我们一起来看看吧。antd的menu组件,会在subMenu超出的情况下对超…
免责声明:本站发布的图片视频文字,以转载和分享为主,文章观点不代表本站立场,本站不承担相关法律责任;如果涉及侵权请联系邮箱:360163164@qq.com举报,并提供相关证据,经查实将立刻删除涉嫌侵权内容。