相关阅读Pytorch基础https://blog.csdn.net/weixin_45791458/category_12457644.html?spm1001.2014.3001.5482在Pytorch中detach()是Tensor的一个重要方法用于返回一个脱离了计算图的张量它的语法如下所示。Tensor.detach() → Tensor在理解这个方法前首先得知道几个概念它们是Pytorch的基础。required_grad属性所有张量都拥有这个属性它可以是True或False中的一个代表这个张量需要计算梯度。如果设置为False在反向传播的链式求导时该张量会被当作常数处理而不是一个自变量。该属性可以在创建张量时通过required_grad参数来指定如例1所示。# 例1 x torch.tensor([2.0], requires_gradTrue)也可以使用requires_grad_()方法或直接修改属性的方式进行动态修改如例2所示。但要注意的是只能改变叶张量的required_grad属性。# 例2 x.requires_grad_(False) x.requires_grad False对于计算过程中创建的张量其required_grad属性取决于其来源如果新张量计算过程用到的原张量的required_grad属性都是False则新张量的required_grad属性为False否则只要有某个原张量的required_grad属性为True新张量的required_grad属性为True如例3所示。# 例3 import torch x torch.tensor([2.0], requires_gradFalse) y torch.tensor([2.0], requires_gradFalse) z x * y print(z.requires_grad) # 输出False x torch.tensor([2.0], requires_gradTrue) y torch.tensor([2.0], requires_gradFalse) z x * y print(z.requires_grad) # 输出True x torch.tensor([2.0], requires_gradFalse) y torch.tensor([2.0], requires_gradTrue) z x * y print(z.requires_grad) # 输出True x torch.tensor([2.0], requires_gradTrue) y torch.tensor([2.0], requires_gradTrue) z x * y print(z.requires_grad) # 输出Truegrad_fn属性对于计算过程中创建的张量非叶张量如果其required_grad属性为True则其grad_fn属性会记录生成该张量的操作用于反向传播如例4所示。# 例4 import torch x torch.tensor([2.0], requires_gradFalse) y torch.tensor([2.0], requires_gradFalse) z x * y print(z.grad_fn) # 输出None x torch.tensor([2.0], requires_gradTrue) y torch.tensor([2.0], requires_gradFalse) z x / y print(z.grad_fn) # 输出DivBackward0 object at 0x7fa9cfc75fd0 x torch.tensor([2.0], requires_gradFalse) y torch.tensor([2.0], requires_gradTrue) z x y print(z.grad_fn) # 输出AddBackward0 object at 0x7fa9cfc75fd0 x torch.tensor([2.0], requires_gradTrue) y torch.tensor([2.0], requires_gradTrue) z x - y print(z.grad_fn) # 输出SubBackward0 object at 0x7fa9cfc75fd0如果你想临时禁用梯度计算可以使用torch.no_grad()下文管理器来包裹不需要梯度计算的代码块这样新张量的required_grad属性一定为False自然grad_fn属性为None与原张量的required_grad属性无关如例5所示。# 例5 import torch x torch.tensor([2.0], requires_gradTrue) y torch.tensor([2.0], requires_gradTrue) with torch.no_grad(): z x y print(z.requires_grad) # 输出False print(z.grad_fn) # 输出None如果反向传播到达一个张量时其required_grad属性为False或者说grad_fn属性为None则会报错如例6所示。# 例6 import torch x torch.tensor([2.0], requires_gradTrue) y torch.tensor([2.0], requires_gradTrue) with torch.no_grad(): z x y z.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fnis_leaf属性所有张量都拥有is_leaf属性它可以是True或False中的一个表示该张量是否为叶张量(leaf tensor)叶张量指的是那些并非计算而得的张量比如权重张量、偏置张量、输入张量如例7所示。# 例7 import torch x torch.tensor([2.0], requires_gradTrue) linear torch.nn.Linear(1, 1) y x * 2 print(x.is_leaf) # 输出True print(linear.weight.is_leaf) # 输出True print(linear.bias.is_leaf) # 输出True print(y.is_leaf) # 输出Falseretain_grad属性所有张量都拥有retain_grad属性它可以是True或False中的一个用于指定是否在反向传播后保留该张量的梯度默认情况下为了节约内存非叶张量的梯度在用于反向传播后会被删除。使用retain_grad()方法可以设置一个张量的retain_grad属性为True从而保留非叶张量的梯度注意如果required_grad属性为False设置retain_grad属性为True是无意义的如例8所示。# 例8 import torch x torch.tensor([2.0], requires_gradTrue) y x**2 z y**2 t z**2 z.retain_grad() t.backward() print(x.grad) # 输出tensor([1024.]) print(y.grad) # UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute wont be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:475.) return self._grad None print(z.grad) # 输出tensor([32.])detach()方法下面进入正题当一个张量调用detach()方法时会返回一个新的张量该张量与调用detach()方法的张量共享底层存储除grad外但其required_grad属性为False如例9所示。# 例9 import torch x torch.tensor([2.0], requires_gradTrue) y x.detach() print(y.requires_grad) # 输出False z_1 y**2 z_2 x**2 print(z_1.requires_grad) # 输出False print(z_1.grad_fn) # 输出None #z_1.backward() RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn print(z_2.requires_grad) # 输出True print(z_2.grad_fn) # 输出PowBackward0 object at 0x7f44042ccb50 z_2.backward() print(id(x)) # 输出139785294725264 print(id(y)) # 输出139785274024560 print(x.storage().data_ptr()) # 输出76321152 print(y.storage().data_ptr()) # 输出76321152其中张量z_1是由经过张量x的detach()方法返回的张量y计算而来而张量y的required_grad属性为False因此张量z_1的required_grad属性也为False其grad_fn属性为None因此不能反向传播。张量x和张量y的id号不同证明它们是不同的张量但storage().data_ptr()方法返回的指针表示它们共享底层存储。