cleanUrl: "pytorch-register-buffer"
description: "PyTorch model에서 register_buffer를 사용하는 이유에 대해 알아봅니다."
forward
내부나 모델 구현 내부에서 .to(device)
를 call 하는 모양새가 예쁘지 않아서 어떻게 하면 model.cuda()
실행 시에 이러한 non-learnable tensor들을 parameter들과 함께 적절한 device로 옮길 수 있을지 알아보았다.nn.Module.register_buffer('attribute_name', tensor)
를 이용하면 되더라!register_buffer
메소드를 실행했을 때의 특징을 정리한 것이다.<aside>
💡 nn.Module.register_buffer('attribute_name', t)
t
는 self.attribute_name
으로 접근 가능하다.t
는 학습되지 않는다. (중요)model.cuda()
시에 t
도 함께 GPU로 간다.
</aside>model.cuda()
시에 GPU로 이동한다import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn([2, 2]))
buff = torch.randn([2, 2])
self.register_buffer('buff', buff)
self.non_buff = torch.randn([2, 2])
def forward(self, x):
return x
model = Model()
print(model.param.device) # cpu
print(model.buff.device) # cpu
print(model.non_buff.device) # cpu
model.cuda()
print(model.param.device) # cuda:0
print(model.buff.device) # cuda:0
print(model.non_buff.device) # cpu
model.cuda()
시 일반적인 parameter 처럼 GPU로 이동하는 것을 확인할 수 있다.state_dict()
로 확인이 가능하다import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn([2, 2]))
buff = torch.randn([2, 2])
self.register_buffer('buff', buff)
self.non_buff = torch.randn([2, 2])
def forward(self, x):
return x
model = Model()
print(model.state_dict())
OrderedDict([
('param', tensor([[...]])),
('buff', tensor([[...]]))
]
param
과 buff
만 state_dict()
에서 확인이 가능한 것을 알 수 있다.model.parameters()
에는 나타나지 않는다.import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn([2, 2]))
buff = torch.randn([2, 2])
self.register_buffer('buff', buff)
self.non_buff = torch.randn([2, 2])
def forward(self, x):
return x
model = Model()
for name, param in model.named_parameters():
print(name, param.data)
param tensor([[...]])
buff
와 non_buff
둘 다 나타나지 않는다.requires_grad=True
로 buffer에 넣어도 parameter로 인식하지 않는다.