File size: 9,123 Bytes
9d8bf2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
OpenSpiel Environment Server Implementation.

This module wraps OpenSpiel's rl_environment.Environment and exposes it
via the OpenEnv Environment interface.
"""

import uuid
from typing import Any, Dict

from openenv_core.env_server.interfaces import Environment
from openenv_core.env_server.types import State

from ..models import OpenSpielAction, OpenSpielObservation, OpenSpielState
from .opponent_policies import get_opponent_policy, OpponentPolicy

# Import OpenSpiel
try:
    from open_spiel.python import rl_environment
    import pyspiel
except ImportError as e:
    raise ImportError(
        "OpenSpiel is not installed. "
        "Please install it following instructions at: "
        "https://github.com/google-deepmind/open_spiel"
    ) from e


class OpenSpielEnvironment(Environment):
    """
    OpenSpiel Environment wrapper for OpenEnv.

    This environment wraps OpenSpiel games and provides a single-agent interface.
    For multi-player games, the agent controls one player while opponent(s) use
    a fixed policy (e.g., random).

    Turn-based games: after every agent move, opponent turns are played
    automatically until it is the agent's turn again (or the game ends).
    Simultaneous-move games: each step submits a joint action made of the
    agent's action plus one policy-chosen action per opponent; currently only
    ``agent_player=0`` is supported for these games.

    Supported games:
    - Single-player: catch, cliff_walking, 2048, blackjack
    - Multi-player: tic_tac_toe, kuhn_poker

    Args:
        game_name: Name of the OpenSpiel game (e.g., "catch", "tic_tac_toe").
        agent_player: Which player ID the agent controls (default 0).
        opponent_policy: Policy for opponent players ("random", "first", etc.).
        game_params: Optional game-specific parameters.

    Example:
        >>> env = OpenSpielEnvironment("catch")
        >>> obs = env.reset()
        >>> print(obs.info_state)  # Agent's observation
        >>> obs = env.step(OpenSpielAction(action_id=1))
        >>> print(obs.reward)
    """

    def __init__(
        self,
        game_name: str = "catch",
        agent_player: int = 0,
        opponent_policy: str = "random",
        game_params: Dict[str, Any] | None = None,
    ):
        """Initialize OpenSpiel environment.

        Raises:
            ValueError: If the OpenSpiel game cannot be created, or if
                ``agent_player`` is not a valid player index for the game.
        """
        super().__init__()

        self.game_name = game_name
        self.agent_player = agent_player
        self.game_params = game_params or {}

        # Create OpenSpiel environment; game_params are forwarded as keyword
        # arguments to rl_environment.Environment.
        try:
            self._ospiel_env = rl_environment.Environment(
                game_name, **self.game_params
            )
        except Exception as e:
            raise ValueError(
                f"Failed to create OpenSpiel game '{game_name}': {e}"
            ) from e

        self.num_players = self._ospiel_env.num_players
        self.is_turn_based = self._ospiel_env.is_turn_based

        # Validate agent_player. Negative values must be rejected explicitly:
        # otherwise they would silently index the per-player lists in
        # time_step.observations / time_step.rewards from the end.
        if agent_player < 0:
            raise ValueError(f"agent_player={agent_player} must be non-negative")
        if agent_player >= self.num_players:
            raise ValueError(
                f"agent_player={agent_player} >= num_players={self.num_players}"
            )

        # Set up opponent policy for multi-player games
        self.opponent_policy_fn: OpponentPolicy | None = None
        if self.num_players > 1:
            self.opponent_policy_fn = get_opponent_policy(opponent_policy)

        # Initialize state
        self._state = OpenSpielState(
            game_name=game_name,
            agent_player=agent_player,
            opponent_policy=opponent_policy,
            game_params=self.game_params,
            num_players=self.num_players,
        )

        # Track last opponent action for learning
        self._last_opponent_action: int | None = None

        # Most recent TimeStep returned to the agent (set by reset()/step()).
        # Simultaneous-move steps read opponents' legal actions from it.
        self._last_time_step: Any = None

    def reset(self) -> OpenSpielObservation:
        """
        Reset the environment and return initial observation.

        For turn-based multi-player games, this will autoplay opponent turns
        until it's the agent's turn (or terminal state).

        Returns:
            Initial observation for the agent.
        """
        # Reset OpenSpiel environment
        time_step = self._ospiel_env.reset()

        # Reset state tracking
        self._state.episode_id = str(uuid.uuid4())
        self._state.step_count = 0
        self._last_opponent_action = None

        # Autoplay opponent turns until agent's turn
        time_step = self._auto_play_opponents(time_step)
        self._last_time_step = time_step

        # Convert to OpenEnv observation
        return self._make_observation(time_step)

    def step(self, action: OpenSpielAction) -> OpenSpielObservation:
        """
        Execute agent's action and return resulting observation.

        For turn-based multi-player games, this will:
        1. Apply the agent's action
        2. Autoplay opponent turns until it's the agent's turn again
        3. Return the observation from the agent's perspective

        For simultaneous-move games, the agent's action and one action per
        opponent (chosen by the opponent policy) are applied together.

        Args:
            action: OpenSpielAction containing the action_id to execute.

        Returns:
            Observation after action execution (and opponent turns if multi-player).

        Raises:
            ValueError: If action is not an OpenSpielAction.
            NotImplementedError: For simultaneous-move games when
                agent_player != 0.
            RuntimeError: For simultaneous-move games when step() is called
                before reset().
        """
        if not isinstance(action, OpenSpielAction):
            raise ValueError(f"Expected OpenSpielAction, got {type(action)}")

        if self.is_turn_based:
            # Turn-based: only the current player's action is submitted.
            time_step = self._ospiel_env.step([action.action_id])
        else:
            # Simultaneous-move: every player's action is submitted at once.
            time_step = self._ospiel_env.step(self._joint_action(action.action_id))

        self._state.step_count += 1

        # Autoplay opponent turns; a no-op for single-player and
        # simultaneous-move games.
        time_step = self._auto_play_opponents(time_step)
        self._last_time_step = time_step

        # Convert to OpenEnv observation
        return self._make_observation(time_step)

    @property
    def state(self) -> OpenSpielState:
        """Get current environment state."""
        return self._state

    def _joint_action(self, agent_action_id: int) -> list[int]:
        """
        Build the joint action for a simultaneous-move step.

        The agent's action is placed at index ``agent_player``; every other
        slot is filled by the opponent policy using that player's legal
        actions from the most recent TimeStep.

        Args:
            agent_action_id: Action chosen by the agent.

        Returns:
            One action id per player, ordered by player id.

        Raises:
            NotImplementedError: If the agent is not player 0.
            RuntimeError: If no TimeStep is available yet (reset() not called).
        """
        # For now, only support agent as player 0 in simultaneous games
        if self.agent_player != 0:
            raise NotImplementedError(
                "Simultaneous-move games only support agent_player=0"
            )
        if self._last_time_step is None:
            raise RuntimeError("reset() must be called before step()")

        observations = self._last_time_step.observations
        joint_action: list[int] = []
        for player_id in range(self.num_players):
            if player_id == self.agent_player:
                joint_action.append(agent_action_id)
                continue
            legal_actions = observations["legal_actions"][player_id]
            opp_action = self.opponent_policy_fn.select_action(
                legal_actions, observations
            )
            self._last_opponent_action = opp_action
            joint_action.append(opp_action)
        return joint_action

    def _auto_play_opponents(self, time_step) -> Any:
        """
        Autoplay opponent turns until it's the agent's turn or game is terminal.

        Only applies to turn-based multi-player games; any other game's
        time_step is returned unchanged.

        Args:
            time_step: Current TimeStep from OpenSpiel environment.

        Returns:
            Updated TimeStep after opponent moves.
        """
        # Single-player games: nothing to do. Simultaneous-move games: the
        # reported current player is not a single player id, so this loop
        # must not run; their opponents act via _joint_action() in step().
        if self.num_players == 1 or not self.is_turn_based:
            return time_step

        # Multi-player turn-based games: play opponent turns
        while (
            not time_step.last()
            and time_step.observations["current_player"] != self.agent_player
        ):
            current_player = time_step.observations["current_player"]
            legal_actions = time_step.observations["legal_actions"][current_player]

            # Select opponent action
            opp_action = self.opponent_policy_fn.select_action(
                legal_actions, time_step.observations
            )
            self._last_opponent_action = opp_action

            # Apply opponent action
            time_step = self._ospiel_env.step([opp_action])
            self._state.step_count += 1

        return time_step

    def _make_observation(self, time_step) -> OpenSpielObservation:
        """
        Convert OpenSpiel TimeStep to OpenEnv Observation.

        Args:
            time_step: OpenSpiel TimeStep object.

        Returns:
            OpenSpielObservation for the agent. ``reward`` is None when the
            TimeStep carries no rewards.
        """
        # Extract agent's information
        info_state = time_step.observations["info_state"][self.agent_player]
        legal_actions = time_step.observations["legal_actions"][self.agent_player]
        current_player_id = time_step.observations["current_player"]

        # Determine game phase (terminal takes precedence over initial)
        if time_step.last():
            game_phase = "terminal"
        elif time_step.first():
            game_phase = "initial"
        else:
            game_phase = "playing"

        # Get reward for agent
        reward = None
        if time_step.rewards is not None:
            reward = float(time_step.rewards[self.agent_player])

        # info_state may be an array-like (has .tolist()) or a plain sequence;
        # normalize to a Python list either way.
        if hasattr(info_state, "tolist"):
            info_state_list = info_state.tolist()
        else:
            info_state_list = list(info_state)

        return OpenSpielObservation(
            info_state=info_state_list,
            legal_actions=legal_actions,
            game_phase=game_phase,
            current_player_id=current_player_id,
            opponent_last_action=self._last_opponent_action,
            done=time_step.last(),
            reward=reward,
        )