cleanUrl: "pytorch-register-buffer"
description: "PyTorch model에서 register_buffer를 사용하는 이유에 대해 알아봅니다."
forward 내부나 모델 구현 내부에서 .to(device)를 call 하는 모양새가 예쁘지 않아서 어떻게 하면 model.cuda() 실행 시에 이러한 non-learnable tensor들을 parameter들과 함께 적절한 device로 옮길 수 있을지 알아보았다.nn.Module.register_buffer('attribute_name', tensor)를 이용하면 되더라!register_buffer 메소드를 실행했을 때의 특징을 정리한 것이다.<aside>
💡 nn.Module.register_buffer('attribute_name', t)
t는 self.attribute_name 으로 접근 가능하다.t는 학습되지 않는다. (중요)model.cuda() 시에 t도 함께 GPU로 간다.
</aside>model.cuda() 시에 GPU로 이동한다import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn([2, 2]))
buff = torch.randn([2, 2])
self.register_buffer('buff', buff)
self.non_buff = torch.randn([2, 2])
def forward(self, x):
return x
model = Model()
print(model.param.device) # cpu
print(model.buff.device) # cpu
print(model.non_buff.device) # cpu
model.cuda()
print(model.param.device) # cuda:0
print(model.buff.device) # cuda:0
print(model.non_buff.device) # cpu
model.cuda() 시 일반적인 parameter 처럼 GPU로 이동하는 것을 확인할 수 있다.state_dict()로 확인이 가능하다import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn([2, 2]))
buff = torch.randn([2, 2])
self.register_buffer('buff', buff)
self.non_buff = torch.randn([2, 2])
def forward(self, x):
return x
model = Model()
print(model.state_dict())
OrderedDict([
('param', tensor([[...]])),
('buff', tensor([[...]]))
]
param과 buff만 state_dict()에서 확인이 가능한 것을 알 수 있다.model.parameters()에는 나타나지 않는다.import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn([2, 2]))
buff = torch.randn([2, 2])
self.register_buffer('buff', buff)
self.non_buff = torch.randn([2, 2])
def forward(self, x):
return x
model = Model()
for name, param in model.named_parameters():
print(name, param.data)
param tensor([[...]])
buff와 non_buff 둘 다 나타나지 않는다.requires_grad=True로 buffer에 넣어도 parameter로 인식하지 않는다.