from typing import Dict, Any, List, Tuple
from collections import namedtuple

import torch
import torch.nn.functional as F
from easydict import EasyDict

from ding.model import model_wrap
from ding.model.template.ebm import create_stochastic_optimizer
from ding.model.template.ebm import StochasticOptimizer, MCMC, AutoRegressiveDFO
from ding.torch_utils import to_device, unsqueeze_repeat
from ding.utils import POLICY_REGISTRY, EasyTimer
from ding.utils.data import default_collate, default_decollate
from .bc import BehaviourCloningPolicy


@POLICY_REGISTRY.register('ibc')
class IBCPolicy(BehaviourCloningPolicy):
    r"""
    Overview:
        Implicit Behavioral Cloning
        https://arxiv.org/abs/2109.00137

    .. note::
        The code is adapted from the PyTorch version of IBC https://github.com/kevinzakka/ibc,
        which only supports the derivative-free optimization (dfo) variant.
        This implementation moves a step forward and supports all variants of the energy-based
        model mentioned in the paper (dfo, autoregressive dfo, and mcmc).
    """

    config = dict(
        type='ibc',
        cuda=False,
        on_policy=False,
        continuous=True,
        model=dict(stochastic_optim=dict(type='mcmc', )),
        learn=dict(
            train_epoch=30,
            batch_size=256,
            optim=dict(
                learning_rate=1e-5,
                weight_decay=0.0,
                beta1=0.9,
                beta2=0.999,
            ),
        ),
        eval=dict(evaluator=dict(eval_freq=10000, )),
    )
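
    # The learning objective (see `_forward_learn`) is the InfoNCE-style loss
    # from the paper: given the expert action a* and N sampled counter-examples,
    # the EBM is trained so that
    #     p(a* | o) = exp(-E(o, a*)) / sum_{j=0}^{N} exp(-E(o, a_j)),
    # i.e. a softmax over negated energies, implemented as F.cross_entropy on
    # logits = -energy. A runnable toy version of this construction is appended
    # at the bottom of this file.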

    def default_model(self) -> Tuple[str, List[str]]:
        return 'ebm', ['ding.model.template.ebm']

    def _init_learn(self):
        self._timer = EasyTimer(cuda=self._cfg.cuda)
        self._sync_timer = EasyTimer(cuda=self._cfg.cuda)
        optim_cfg = self._cfg.learn.optim
        self._optimizer = torch.optim.AdamW(
            self._model.parameters(),
            lr=optim_cfg.learning_rate,
            weight_decay=optim_cfg.weight_decay,
            betas=(optim_cfg.beta1, optim_cfg.beta2),
        )
        self._stochastic_optimizer: StochasticOptimizer = \
            create_stochastic_optimizer(self._device, self._cfg.model.stochastic_optim)
        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()

    def _forward_learn(self, data):
        with self._timer:
            data = default_collate(data)
            if self._cuda:
                data = to_device(data, self._device)
            self._learn_model.train()

            loss_dict = dict()

            # obs: (B, O)
            # action: (B, A)
            obs, action = data['obs'], data['action']
            # When the action/observation space is 1-dimensional, the trailing dimension
            # may have been squeezed earlier in the pipeline, so unsqueeze here to make
            # the data compatible with the ibc pipeline.
            if len(obs.shape) == 1:
                obs = obs.unsqueeze(-1)
            if len(action.shape) == 1:
                action = action.unsqueeze(-1)

            # N refers to the number of negative samples, i.e. self._stochastic_optimizer.inference_samples.
            # (B, N, O), (B, N, A)
            obs, negatives = self._stochastic_optimizer.sample(obs, self._learn_model)

            # Stack the expert action in front of the negatives: (B, N+1, A).
            targets = torch.cat([action.unsqueeze(dim=1), negatives], dim=1)
            # Prepend one more copy of the observation so shapes match: (B, N+1, O).
            obs = torch.cat([obs[:, :1], obs], dim=1)

            # Shuffle the N+1 candidates independently per batch row, so that the
            # expert action is not always at index 0.
            permutation = torch.rand(targets.shape[0], targets.shape[1]).argsort(dim=1)
            targets = targets[torch.arange(targets.shape[0]).unsqueeze(-1), permutation]

            # Recover the index the expert action was shuffled to: (B, ).
            ground_truth = (permutation == 0).nonzero()[:, 1].to(self._device)

            # (B, N+1) for ebm
            # (B, N+1, A) for autoregressive ebm
            energy = self._learn_model.forward(obs, targets)

            logits = -1.0 * energy
            if isinstance(self._stochastic_optimizer, AutoRegressiveDFO):
                # autoregressive case: repeat the label along the action dimension
                # so that every autoregressive step is supervised: (B, A)
                ground_truth = unsqueeze_repeat(ground_truth, logits.shape[-1], -1)
            loss = F.cross_entropy(logits, ground_truth)
            loss_dict['ebm_loss'] = loss.item()

            if isinstance(self._stochastic_optimizer, MCMC):
                grad_penalty = self._stochastic_optimizer.grad_penalty(obs, targets, self._learn_model)
                loss += grad_penalty
                loss_dict['grad_penalty'] = grad_penalty.item()
            loss_dict['total_loss'] = loss.item()

            self._optimizer.zero_grad()
            loss.backward()
            with self._sync_timer:
                if self._cfg.multi_gpu:
                    self.sync_gradients(self._learn_model)
            sync_time = self._sync_timer.value
            self._optimizer.step()
        total_time = self._timer.value

        return {
            'total_time': total_time,
            'sync_time': sync_time,
            **loss_dict,
        }

    def _monitor_vars_learn(self):
        if isinstance(self._stochastic_optimizer, MCMC):
            return ['total_loss', 'ebm_loss', 'grad_penalty', 'total_time', 'sync_time']
        else:
            return ['total_loss', 'ebm_loss', 'total_time', 'sync_time']

    def _init_eval(self):
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        # `data` is either a raw observation tensor or a {env_id: obs} dict.
        tensor_input = isinstance(data, torch.Tensor)
        if not tensor_input:
            data_id = list(data.keys())
            data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        output = self._stochastic_optimizer.infer(data, self._eval_model)
        output = dict(action=output)
        if self._cuda:
            output = to_device(output, 'cpu')
        if tensor_input:
            return output
        else:
            output = default_decollate(output)
            return {i: d for i, d in zip(data_id, output)}

    def set_statistic(self, statistics: EasyDict) -> None:
        self._stochastic_optimizer.set_action_bounds(statistics.action_bounds)
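
    # Example (sketch): `statistics` is expected to carry per-dimension action
    # bounds computed from the demonstration dataset, within which the stochastic
    # optimizer samples candidate actions. The exact container/shape is an
    # assumption here, e.g. stacked lower/upper bounds:
    #     policy.set_statistic(EasyDict(action_bounds=np.stack([lows, highs])))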

    # =================================================================== #
    # Implicit Behavioral Cloning does not need `collect`-related functions.
    # =================================================================== #
    def _init_collect(self):
        raise NotImplementedError

    def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
        raise NotImplementedError

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        raise NotImplementedError

    def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        raise NotImplementedError
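

# --------------------------------------------------------------------------- #
# Minimal runnable sketch (not part of the policy): the counter-example
# classification trick used in `_forward_learn`, reproduced on toy tensors with
# plain torch. All shapes and the random "energy" below are made up for
# illustration; only the permutation / label-recovery / cross-entropy logic
# mirrors the code above.
# --------------------------------------------------------------------------- #
if __name__ == '__main__':
    B, N, A = 4, 7, 2  # batch size, number of negatives, action dim (arbitrary)
    expert_action = torch.randn(B, A)
    negatives = torch.randn(B, N, A)
    # Stack the expert action in front of the negatives: (B, N+1, A).
    targets = torch.cat([expert_action.unsqueeze(dim=1), negatives], dim=1)
    # Shuffle the candidates independently per batch row ...
    permutation = torch.rand(targets.shape[0], targets.shape[1]).argsort(dim=1)
    targets = targets[torch.arange(targets.shape[0]).unsqueeze(-1), permutation]
    # ... and recover where the expert action (originally index 0) ended up;
    # that position is the classification label.
    ground_truth = (permutation == 0).nonzero()[:, 1]
    # Stand-in for the learned energy E(obs, action); lower energy is better.
    energy = torch.randn(B, N + 1)
    loss = F.cross_entropy(-1.0 * energy, ground_truth)
    print(f'labels: {ground_truth.tolist()}, loss: {loss.item():.4f}')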