cleanUrl: "pytorch-custom-gradual-warmup-scheduler-implementation"
description: "PyTorch에서 custom하게 GradualWarmupScheduler를 구현하는 방법에 대해 알아봅니다."

torch.optim.lr_scheduler.LRScheduler base class

https://github.com/pytorch/pytorch/blob/b818b3fe1c1fa90529b8500cd8ef800bec8415e8/torch/optim/lr_scheduler.py#L26

__init__

class LRScheduler:

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Scheduler가 lr을 조절하는 optimizer를 할당.
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Scheduler는 optimizer param_group 중
				# 'initial_lr'이라는 key를 learning rate scheduling에 활용한다.
				# 
				# last_epoch으로 -1이 주어지면 (대부분의 경우가 이 경우)
				# 현재 optimizer의 각 param_group의 'lr' 값을 initial_lr로 사용한다.
        if last_epoch == -1:
						# Optimizer
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
				# last_epoch으로 다른 값이 주어지면 scheduling을 위해
				# optimizer 각 param_group 내에 이미 initial_lr이 정의되어 있어야 한다.
				# 만약 아니라면, error를 발생시키고 끝낸다.
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
				
				# 각 param_group에 대해서 initial_lr 값의 리스트를 가지고 있는다.
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following <https://github.com/pytorch/pytorch/issues/20124>
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
						# 이미 
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

						# Optimizer instance에 대한 cyclic reference를 방지하기 위해서 weak reference를 사용한다.
            instance_ref = weakref.ref(method.__self__)  # Optimizer instance 자체에 대한 reference.

            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper
				

				# 위의 with_counter 라는 wrapper 덕분에 optimizer instance와 step 메소드는 아래와 같이 변형된다.
				#
				# 1. optimizer instance는 _step_count 변수를 가지며, 이 변수는 optimizer.step 메소드가 실행될 때마다 1씩 증가한다.
				# 2. optimizer.step method는 '_with_counter' 라는 attribute를 가지며, 값이 True로 정해진다.
				#
        self.optimizer.step = with_counter(self.optimizer.step)
        self.verbose = verbose

        self._initial_step()

_initial_step(self)

def _initial_step(self):
        # Optimizer와 이 scheduler의 `_step_count` 변수를 0으로 초기화하고
				# scheduler의 step을 한번 진행한다.
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step()

get_lr(self)

def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
				# 
				# 이 클래스를 상속받아 custom scheduler를 구현 시 이 메소드를 잘 override 해서 구현해야할 것.
        raise NotImplementedError

step(self, epoch=None)

def step(self, epoch=None):  # 참고로 step 메소드 호출 시 epoch 파라미터를 넘겨주는 방식은 이제 사용하지 않는 편이 좋다.

        # 첫 step 실행 시에 검사:
        if self._step_count == 1:
						# optimizer.step이 _with_counter attribute을 안 가지는 상황. Initialization에 문제가 있다.
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "<https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>", UserWarning)

						# optimizer.step()이 실행되기 전에 lr_scheduler.step()이 실행된 경우.
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "<https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>", UserWarning)
        
				# scheduler의 _step_count를 증가시킨다.
				self._step_count += 1

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

LinearAnnealingLR 구현