optimizer.state_dict() & optimizer.load_state_dict()

1. The optimizer stores only references to the model's parameters, so there is no need to move the optimizer itself to CUDA.

import torch
import torch.optim as optim
import torch.nn as nn

model = nn.Conv2d(1, 1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
optimizer.param_groups
"""
[{'dampening': 0,
  'lr': 0.0001,
  'momentum': 0,
  'nesterov': False,
  'params':[Parameter containing:
   tensor([[[[-0.8806]]]], requires_grad=True), Parameter containing:
   tensor([0.3750], requires_grad=True)],
  'weight_decay': 0}]
"""

Now move the model onto CUDA:

model.cuda()
optimizer.param_groups
"""
[{'dampening': 0,
  'lr': 0.0001,
  'momentum': 0,
  'nesterov': False,
  'params':[Parameter containing:
   tensor([[[[-0.8806]]]], device='cuda:0', requires_grad=True),
   Parameter containing:
   tensor([0.3750], device='cuda:0', requires_grad=True)],
  'weight_decay': 0}]
"""

The parameters inside the optimizer have moved to CUDA as well.
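A quick identity check makes the reference sharing explicit. This is a minimal sketch continuing the example above (it assumes a CUDA device is available):

# The optimizer holds the very same Parameter objects as the model, so
# model.cuda(), which swaps the parameter data in place, is visible to it.
print(model.weight is optimizer.param_groups[0]['params'][0])  # True
print(model.bias is optimizer.param_groups[0]['params'][1])    # True
print(optimizer.param_groups[0]['params'][0].is_cuda)          # True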

2. optim.state_dict() does not save the model parameters themselves; in each parameter group it saves only integer indices (0..n-1) that stand in for the parameters.

See the source code:

def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a dict containing all parameter groups
        """
        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != 'params'}
            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                   if id(p) not in param_mappings})
            # Note that for 'params' only indices are saved: starting from 0, so n parameters map to 0..n-1.
            packed['params'] = [param_mappings[id(p)] for p in group['params']]
            start_index += len(packed['params'])
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }
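
To see this concretely, here is a minimal sketch (the momentum=0.9 setting and the abbreviated output are illustrative; the exact set of keys depends on the PyTorch version). Note that 'params' holds the indices 0 and 1, not tensors, and 'state' is keyed by those same indices:

model = nn.Conv2d(1, 1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

model(torch.randn(1, 1, 3, 3)).sum().backward()  # produce gradients
optimizer.step()                                 # creates momentum buffers in optimizer.state

print(optimizer.state_dict())
"""
{'state': {0: {'momentum_buffer': tensor([[[[...]]]])},
           1: {'momentum_buffer': tensor([...])}},
 'param_groups': [{'lr': 0.0001, 'momentum': 0.9, 'dampening': 0,
                   'weight_decay': 0, 'nesterov': False,
                   'params': [0, 1]}]}
"""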

3. optim.load_state_dict(state_dict) keeps the 'params' entries of the current optimizer (the live parameter tensors), fills them into the loaded param_groups, and then replaces the optimizer's state and param_groups with the result.

See the source code:

def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        # The parameter swap happens here: keep the live 'params' from the current group, take everything else from the loaded group
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})
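
In practice this is what makes the usual checkpoint pattern work; a minimal sketch (the file name checkpoint.pth and the momentum setting are just for illustration):

# Save: the optimizer's state_dict holds only indices, hyperparameters and state buffers.
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, 'checkpoint.pth')

# Load: rebuild the model and optimizer first, then restore.
model = nn.Conv2d(1, 1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model'])
# load_state_dict keeps the optimizer's current (live) params and only restores
# hyperparameters and state buffers, remapping them onto those params.
optimizer.load_state_dict(checkpoint['optimizer'])

Because cast() moves the saved state tensors to the device and dtype of the live parameters, a checkpoint saved from a CPU optimizer can also be loaded after the model has been moved to CUDA.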
