Spaces:
Running
Running
| from typing import Dict, Union | |
| import torch | |
| import torch.nn as nn | |
| from functools import reduce | |
| from ding.torch_utils import one_hot, MLP | |
| from ding.utils import squeeze, list_split, MODEL_REGISTRY, SequenceType | |
| from .q_learning import DRQN | |
class COMAActorNetwork(nn.Module):
    """
    Overview:
        Decentralized actor network in COMA algorithm. All agents are evaluated by a single ``DRQN`` \
        instance: the agent dimension is folded into the batch dimension before the call and \
        unfolded again afterwards.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            obs_shape: int,
            action_shape: int,
            hidden_size_list: Union[SequenceType, None] = None,
    ):
        """
        Overview:
            Initialize COMA actor network
        Arguments:
            - obs_shape (:obj:`int`): the dimension of each agent's observation state
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size_list (:obj:`list`): the list of hidden size forwarded to ``DRQN``; \
                ``None`` (the default) selects ``[128, 128, 64]``
        """
        super(COMAActorNetwork, self).__init__()
        if hidden_size_list is None:
            # Create the default per instance so that no mutable list is shared through the signature.
            hidden_size_list = [128, 128, 64]
        self.main = DRQN(obs_shape, action_shape, hidden_size_list)

    def forward(self, inputs: Dict) -> Dict:
        """
        Overview:
            The forward computation graph of COMA actor network
        Arguments:
            - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs), read from ``inputs['obs']``; \
                either :math:`(T, B, A, ...)` or, for a single step, :math:`(B, A, ...)`
            - action_mask (:obj:`torch.Tensor`): the masked action, read from ``inputs['obs']`` and \
                returned unchanged
            - prev_state (:obj:`list`): the previous hidden state, one inner list of ``A`` entries per sample
        Returns:
            - output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask']
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state``
        ReturnsKeys:
            - necessary: ``logit``, ``next_state``, ``action_mask``
        Shapes:
            - logit (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, or :math:`(B, A, P)` when the input had \
                no time dimension, where ``P`` is whatever last dimension ``DRQN`` produces
        Examples:
            >>> T, B, A, N = 4, 8, 3, 32
            >>> embedding_dim = 64
            >>> action_dim = 6
            >>> data = torch.randn(T, B, A, N)
            >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
            >>> prev_state = [[None for _ in range(A)] for _ in range(B)]
            >>> for t in range(T):
            >>>     inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
            >>>     outputs = model(inputs)
            >>>     logit, prev_state = outputs['logit'], outputs['next_state']
        """
        agent_state = inputs['obs']['agent_state']
        prev_state = inputs['prev_state']

        # A 3-dim input is (B, A, N): add a leading time axis now and strip it from the logit at the end.
        unsqueeze_flag = len(agent_state.shape) == 3
        if unsqueeze_flag:
            agent_state = agent_state.unsqueeze(0)

        T, B, A = agent_state.shape[:3]
        # Fold the agent axis into the batch axis: (T, B, A, ...) -> (T, B * A, ...).
        agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
        # Flatten the per-sample lists of per-agent states into one flat list in the same order.
        # A comprehension is linear in the number of states and also copes with an empty list.
        prev_state = [state for sample_states in prev_state for state in sample_states]

        # NOTE(review): the meaning of 'enable_fast_timestep' is defined by DRQN, not visible here.
        output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
        logit, next_state = output['logit'], output['next_state']

        # Regroup the flat state list into chunks of A entries; the second return value of
        # list_split is discarded (presumably the remainder -- confirm against ding.utils.list_split).
        next_state, _ = list_split(next_state, step=A)
        logit = logit.reshape(T, B, A, -1)
        if unsqueeze_flag:
            logit = logit.squeeze(0)
        return {'logit': logit, 'next_state': next_state, 'action_mask': inputs['obs']['action_mask']}
class COMACriticNetwork(nn.Module):
    """
    Overview:
        Centralized critic network in COMA algorithm.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            input_size: int,
            action_shape: int,
            hidden_size: int = 128,
    ):
        """
        Overview:
            initialize COMA critic network
        Arguments:
            - input_size (:obj:`int`): the size of the feature vector fed to the MLP, i.e. the last \
                dimension of the tensor built by ``_preprocess_data``
            - action_shape (:obj:`int`): the dimension of action shape
            - hidden_size (:obj:`int`): the hidden size of the MLP, default to 128
        """
        super(COMACriticNetwork, self).__init__()
        self.action_shape = action_shape
        self.act = nn.ReLU()
        self.mlp = nn.Sequential(
            MLP(input_size, hidden_size, hidden_size, 2, activation=self.act), nn.Linear(hidden_size, action_shape)
        )

    def forward(self, data: Dict) -> Dict:
        """
        Overview:
            forward computation graph of COMA critic network
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'action']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
            - global_state (:obj:`torch.Tensor`): global state(obs)
            - action (:obj:`torch.Tensor`): the action index of every agent
        Returns:
            - output (:obj:`dict`): output data dict with keys ['q_value']
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``
        ReturnsKeys:
            - necessary: ``q_value``
        Shapes:
            - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where ``P`` is ``action_shape``
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMACriticNetwork(
            >>>     obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data)
        """
        x = self._preprocess_data(data)
        q = self.mlp(x)
        return {'q_value': q}

    def _preprocess_data(self, data: Dict) -> torch.Tensor:
        """
        Overview:
            preprocess data to make it can be used by MLP net
        Arguments:
            - data (:obj:`dict`): input data dict with keys ['obs', 'action']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs); its last \
                ``action_shape + agent_num`` features are sliced off as last action and agent id
            - global_state (:obj:`torch.Tensor`): global state(obs)
            - action (:obj:`torch.Tensor`): the action index of every agent
        ArgumentsKeys:
            - necessary: ``obs`` { ``agent_state``, ``global_state``} , ``action``
        Return:
            - x (:obj:`torch.Tensor`): the data can be used by MLP net, concatenated on the last dim in \
                the order ``global_state``, ``agent_state``, ``last_action``, ``action``, ``agent_id``
        """
        agent_state_ori, global_state = data['obs']['agent_state'], data['obs']['global_state']
        t_size, batch_size, agent_num = agent_state_ori.shape[:3]

        # split obs, last_action and agent_id: the trailing features are laid out as
        # [..., last_action (action_shape entries), agent_id (agent_num entries)]
        agent_state = agent_state_ori[..., :-self.action_shape - agent_num]
        last_action = agent_state_ori[..., -self.action_shape - agent_num:-agent_num]
        agent_id = agent_state_ori[..., -agent_num:]

        # Join the last actions of all agents into one vector and hand a copy to every agent.
        last_action = last_action.reshape(t_size, batch_size, 1, -1).repeat(1, 1, agent_num, 1)

        # Same for the current joint action: (T, B, A, P) -> (T, B, 1, A * P) -> (T, B, A, A * P).
        action = one_hot(data['action'], self.action_shape)
        action = action.reshape(t_size, batch_size, -1, agent_num * self.action_shape).repeat(1, 1, agent_num, 1)

        # Row i of the mask is zero exactly on agent i's own block of P entries, so every agent
        # sees the other agents' actions but not its own. The mask is created directly on the
        # target device instead of on the host followed by a copy.
        action_mask = 1 - torch.eye(agent_num, device=action.device)
        action_mask = action_mask.view(-1, 1).repeat(1, self.action_shape).view(agent_num, -1)  # A, A * P
        action = action_mask.unsqueeze(0).unsqueeze(0) * action  # T, B, A, A * P

        # Broadcast the global state to every agent: (T, B, G) -> (T, B, A, G).
        global_state = global_state.unsqueeze(2).repeat(1, 1, agent_num, 1)
        return torch.cat([global_state, agent_state, last_action, action, agent_id], -1)
class COMA(nn.Module):
    """
    Overview:
        The network of COMA algorithm, which is QAC-type actor-critic.
    Interface:
        ``__init__``, ``forward``
    Properties:
        - mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic``
    """

    mode = ['compute_actor', 'compute_critic']

    def __init__(
            self, agent_num: int, obs_shape: Dict, action_shape: Union[int, SequenceType],
            actor_hidden_size_list: SequenceType
    ) -> None:
        """
        Overview:
            initialize COMA network
        Arguments:
            - agent_num (:obj:`int`): the number of agent
            - obs_shape (:obj:`Dict`): the observation information, including agent_state and \
                global_state
            - action_shape (:obj:`Union[int, SequenceType]`): the dimension of action shape
            - actor_hidden_size_list (:obj:`SequenceType`): the list of hidden size of the actor; \
                its last entry is reused as the hidden size of the critic
        """
        super(COMA, self).__init__()
        action_shape = squeeze(action_shape)
        actor_input_size = squeeze(obs_shape['agent_state'])
        # Width of the tensor built by COMACriticNetwork._preprocess_data:
        # agent obs + global obs + all agents' last actions + joint action, minus the one
        # last-action slice that is cut out of the agent obs.
        critic_input_size = squeeze(obs_shape['agent_state']) + squeeze(obs_shape['global_state']) + \
            agent_num * action_shape + (agent_num - 1) * action_shape
        critic_hidden_size = actor_hidden_size_list[-1]
        self.actor = COMAActorNetwork(actor_input_size, action_shape, actor_hidden_size_list)
        self.critic = COMACriticNetwork(critic_input_size, action_shape, critic_hidden_size)

    def forward(self, inputs: Dict, mode: str) -> Dict:
        """
        Overview:
            forward computation graph of COMA network
        Arguments:
            - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
            - agent_state (:obj:`torch.Tensor`): each agent local state(obs)
            - global_state (:obj:`torch.Tensor`): global state(obs)
            - action (:obj:`torch.Tensor`): the masked action
            - mode (:obj:`str`): one of the entries of ``self.mode``
        ArgumentsKeys:
            - necessary:
                - compute_actor: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state``
                - compute_critic: ``obs`` { ``agent_state``, ``global_state`` }, ``action``
        ReturnsKeys:
            - necessary:
                - compute_critic: ``q_value``
                - compute_actor: ``logit``, ``next_state``, ``action_mask``
        Raises:
            - AssertionError: if ``mode`` is not contained in ``self.mode``
        Shapes:
            - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N)`, ``global_state``: :math:`(T, B, M)`
            - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - logit (:obj:`torch.Tensor`): :math:`(T, B, A, P)`
            - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
            - action_mask (:obj:`torch.Tensor`): returned exactly as passed in
            - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, P)`
        Examples:
            >>> agent_num, bs, T = 4, 3, 8
            >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
            >>> coma_model = COMA(
            >>>     agent_num=agent_num,
            >>>     obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )),
            >>>     action_shape=action_dim,
            >>>     actor_hidden_size_list=[128, 64],
            >>> )
            >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)]
            >>> data = {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'action_mask': None,
            >>>     },
            >>>     'prev_state': prev_state,
            >>> }
            >>> output = coma_model(data, mode='compute_actor')
            >>> data= {
            >>>     'obs': {
            >>>         'agent_state': torch.randn(T, bs, agent_num, obs_dim),
            >>>         'global_state': torch.randn(T, bs, global_obs_dim),
            >>>     },
            >>>     'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
            >>> }
            >>> output = coma_model(data, mode='compute_critic')
        """
        # Raise explicitly instead of using ``assert``: assert statements are removed under
        # ``python -O``, which would let an unsupported mode fall through and return None.
        # The exception type and message are the same as the assert produced.
        if mode not in self.mode:
            raise AssertionError("not support forward mode: {}/{}".format(mode, self.mode))
        if mode == 'compute_actor':
            return self.actor(inputs)
        elif mode == 'compute_critic':
            return self.critic(inputs)