cleanUrl: "pytorch-register-buffer"
description: "Why and how to use register_buffer in a PyTorch model."

<aside> 💡 nn.Module.register_buffer('attribute_name', t) </aside>

register_buffer registers a tensor as part of the module's state without making it a parameter: it moves along with model.cuda()/model.to(), it is included in state_dict(), and it is excluded from model.parameters().
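PyTorch itself uses buffers for exactly this kind of non-trainable state; for example, BatchNorm keeps its running statistics as buffers:

```python
import torch.nn as nn

bn = nn.BatchNorm1d(4)
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']
```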

Examples

Buffers are moved to the GPU when model.cuda() is called

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.param = nn.Parameter(torch.randn([2, 2]))

        # registered buffer: tracked by the module
        buff = torch.randn([2, 2])
        self.register_buffer('buff', buff)

        # plain tensor attribute: not tracked by the module
        self.non_buff = torch.randn([2, 2])

    def forward(self, x):
        return x

model = Model()
print(model.param.device)    # cpu
print(model.buff.device)     # cpu
print(model.non_buff.device) # cpu

model.cuda()
print(model.param.device)    # cuda:0
print(model.buff.device)     # cuda:0
print(model.non_buff.device) # cpu
```
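For reference, registered buffers can be enumerated with named_buffers(); the plain attribute non_buff is not tracked at all. Continuing from the code above:

```python
# only the registered buffer shows up; non_buff is invisible to the module
for name, buf in model.named_buffers():
    print(name, buf.device)  # buff cuda:0
```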

Buffers also show up in state_dict()

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.param = nn.Parameter(torch.randn([2, 2]))

        buff = torch.randn([2, 2])
        self.register_buffer('buff', buff)

        self.non_buff = torch.randn([2, 2])

    def forward(self, x):
        return x

model = Model()
print(model.state_dict())
# OrderedDict([
#     ('param', tensor([[...]])),
#     ('buff', tensor([[...]]))
# ])
# note: non_buff is absent, plain tensor attributes are not saved
```
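Because buff lives in state_dict, load_state_dict() restores it, while non_buff is silently left at whatever value the new instance initialized it to. A quick sketch continuing from the code above:

```python
model2 = Model()
model2.load_state_dict(model.state_dict())

print(torch.equal(model2.buff, model.buff))          # True: restored from state_dict
print(torch.equal(model2.non_buff, model.non_buff))  # False (almost surely): never saved
```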

Since a buffer is not a Parameter, it naturally does not appear in model.parameters().

```python
import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.param = nn.Parameter(torch.randn([2, 2]))

        buff = torch.randn([2, 2])
        self.register_buffer('buff', buff)

        self.non_buff = torch.randn([2, 2])

    def forward(self, x):
        return x

model = Model()
for name, param in model.named_parameters():
    print(name, param.data)
# param tensor([[...]])  -- only param; buff and non_buff are absent
```
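A direct consequence: an optimizer built from model.parameters() never sees the buffer, so gradient descent never updates it. A quick check, continuing from the snippet above (this is what the torch.optim import is for):

```python
optimizer = optim.SGD(model.parameters(), lr=0.1)
# only model.param was handed to the optimizer
print(len(optimizer.param_groups[0]['params']))  # 1
```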

Even if a tensor created with requires_grad=True is registered as a buffer, it is still not recognized as a parameter.
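A minimal sketch to confirm this (the class name GradBuffModel is just for illustration):

```python
import torch
import torch.nn as nn

class GradBuffModel(nn.Module):
    def __init__(self):
        super().__init__()
        # created with requires_grad=True, but registered as a buffer
        self.register_buffer('grad_buff', torch.randn([2, 2], requires_grad=True))

model = GradBuffModel()
print(list(model.parameters()))       # [] -- still not a parameter
print(model.grad_buff.requires_grad)  # True
```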