Search
⚙️

PyTorch에서 register_buffer를 써야하는 이유

PyTorch 모델을 구현할 때 모델 내부적으로 가지고 있으면서, 학습은 필요 없는 tensor들을 사용해야 할 때가 있다.
예를 들면, transformer 구현에서의 causal attention mask 라던지…
이 때 흔히 하는 실수가 그 tensor들을 적절한 device (CPU/GPU)로 보내지 않아서 device의 mismatch가 생기는 실수인데, 개인적으로 forward 내부나 모델 구현 내부에서 .to(device)를 call 하는 모양새가 예쁘지 않아서 어떻게 하면 model.cuda() 실행 시에 이러한 non-learnable tensor들을 parameter들과 함께 적절한 device로 옮길 수 있을지 알아보았다.
nn.Module.register_buffer('attribute_name', tensor)를 이용하면 되더라!
아래는 register_buffer 메소드를 실행했을 때의 특징을 정리한 것이다.
nn.Module.register_buffer('attribute_name', t)
모듈 내에서 tensor tself.attribute_name 으로 접근 가능하다.
Tensor t는 학습되지 않는다. (중요)
model.cuda() 시에 t도 함께 GPU로 간다.

예시

model.cuda() 시에 GPU로 이동한다

import torch import torch.nn as nn class Model(nn.Module): def __init__(self): super().__init__() self.param = nn.Parameter(torch.randn([2, 2])) buff = torch.randn([2, 2]) self.register_buffer('buff', buff) self.non_buff = torch.randn([2, 2]) def forward(self, x): return x model = Model() print(model.param.device) # cpu print(model.buff.device) # cpu print(model.non_buff.device) # cpu model.cuda() print(model.param.device) # cuda:0 print(model.buff.device) # cuda:0 print(model.non_buff.device) # cpu
Python
복사
위의 예시에서, buffer 에 들어간 tensor는 model.cuda() 시 일반적인 parameter 처럼 GPU로 이동하는 것을 확인할 수 있다.
반면 buffer에 넣지 않은 tensor는 그대로 CPU에 남아있다.

state_dict()로 확인이 가능하다

import torch import torch.nn as nn class Model(nn.Module): def __init__(self): super().__init__() self.param = nn.Parameter(torch.randn([2, 2])) buff = torch.randn([2, 2]) self.register_buffer('buff', buff) self.non_buff = torch.randn([2, 2]) def forward(self, x): return x model = Model() print(model.state_dict())
Python
복사
OrderedDict([ ('param', tensor([[...]])), ('buff', tensor([[...]])) ]
Python
복사
parambuffstate_dict()에서 확인이 가능한 것을 알 수 있다.

Parameter가 아니므로, 당연히 model.parameters()에는 나타나지 않는다.

import torch import torch.nn as nn import torch.optim as optim class Model(nn.Module): def __init__(self): super().__init__() self.param = nn.Parameter(torch.randn([2, 2])) buff = torch.randn([2, 2]) self.register_buffer('buff', buff) self.non_buff = torch.randn([2, 2]) def forward(self, x): return x model = Model() for name, param in model.named_parameters(): print(name, param.data)
Python
복사
param tensor([[...]])
Python
복사
buffnon_buff 둘 다 나타나지 않는다.

requires_grad=True로 buffer에 넣어도 parameter로 인식하지 않는다.

import torch import torch.nn as nn import torch.optim as optim class Model(nn.Module): def __init__(self): super().__init__() self.param = nn.Parameter(torch.randn([2, 2])) buff = torch.randn([2, 2], requires_grad=True) self.register_buffer('buff', buff) self.non_buff = torch.randn([2, 2]) def forward(self, x): return x model = Model() for name, param in model.named_parameters(): print(name, param.data)
Python
복사
param tensor([[...]])
Python
복사
따라서 optimizer로 업데이트 되지 않는다! 주의해야 할 듯.

buffers(), named_buffers() 로 접근이 가능하다.

import torch import torch.nn as nn import torch.optim as optim class Model(nn.Module): def __init__(self): super().__init__() self.param = nn.Parameter(torch.randn([2, 2])) buff = torch.randn([2, 2]) self.register_buffer('buff', buff) self.non_buff = torch.randn([2, 2]) def forward(self, x): return x model = Model() for name, param in model.named_buffers(): print(name, param.data)
Python
복사
buff tensor([[...]])
Python
복사
param은 나타나지 않고 buff만 나타나는 것에 주목하자.