"""
OpenSpiel Environment Server Implementation.

This module wraps OpenSpiel's rl_environment.Environment and exposes it
via the OpenEnv Environment interface.
"""

import uuid
from typing import Any, Dict

from openenv_core.env_server.interfaces import Environment
from openenv_core.env_server.types import Action, Observation, State

from ..models import OpenSpielAction, OpenSpielObservation, OpenSpielState
from .opponent_policies import get_opponent_policy, OpponentPolicy

try:
    from open_spiel.python import rl_environment
    import pyspiel
except ImportError as e:
    raise ImportError(
        "OpenSpiel is not installed. "
        "Please install it following instructions at: "
        "https://github.com/google-deepmind/open_spiel"
    ) from e
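
# Importing pyspiel above is deliberate even though only rl_environment is
# referenced in this module: it makes a missing or broken OpenSpiel install
# fail fast at server start-up rather than mid-request.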


class OpenSpielEnvironment(Environment):
    """
    OpenSpiel Environment wrapper for OpenEnv.

    This environment wraps OpenSpiel games and provides a single-agent interface.
    For multi-player games, the agent controls one player while opponent(s) use
    a fixed policy (e.g., random).

    Supported games:
        - Single-player: catch, cliff_walking, 2048, blackjack
        - Multi-player: tic_tac_toe, kuhn_poker

    Args:
        game_name: Name of the OpenSpiel game (e.g., "catch", "tic_tac_toe").
        agent_player: Which player ID the agent controls (default 0).
        opponent_policy: Policy for opponent players ("random", "first", etc.).
        game_params: Optional game-specific parameters.

    Example:
        >>> env = OpenSpielEnvironment("catch")
        >>> obs = env.reset()
        >>> print(obs.info_state)  # Agent's observation
        >>> obs = env.step(OpenSpielAction(action_id=1))
        >>> print(obs.reward)
    """

    def __init__(
        self,
        game_name: str = "catch",
        agent_player: int = 0,
        opponent_policy: str = "random",
        game_params: Dict[str, Any] | None = None,
    ):
        """Initialize OpenSpiel environment."""
        super().__init__()

        self.game_name = game_name
        self.agent_player = agent_player
        self.game_params = game_params or {}

        # Create the underlying OpenSpiel environment.
        try:
            self._ospiel_env = rl_environment.Environment(
                game_name, **self.game_params
            )
        except Exception as e:
            raise ValueError(
                f"Failed to create OpenSpiel game '{game_name}': {e}"
            ) from e

        self.num_players = self._ospiel_env.num_players
        self.is_turn_based = self._ospiel_env.is_turn_based

        # Validate the requested player slot.
        if agent_player >= self.num_players:
            raise ValueError(
                f"agent_player={agent_player} >= num_players={self.num_players}"
            )

        # Opponent policies are only needed for multi-player games.
        self.opponent_policy_fn: OpponentPolicy | None = None
        if self.num_players > 1:
            self.opponent_policy_fn = get_opponent_policy(opponent_policy)

        # Server-visible state describing this environment instance.
        self._state = OpenSpielState(
            game_name=game_name,
            agent_player=agent_player,
            opponent_policy=opponent_policy,
            game_params=self.game_params,
            num_players=self.num_players,
        )

        # Track the most recent TimeStep and opponent move so they can be
        # used when building joint actions and observations.
        self._last_time_step = None
        self._last_opponent_action: int | None = None

    def reset(self) -> Observation:
        """
        Reset the environment and return initial observation.

        For multi-player games, this will autoplay opponent turns until
        it's the agent's turn (or a terminal state is reached).

        Returns:
            Initial observation for the agent.
        """
        time_step = self._ospiel_env.reset()

        # Start a fresh episode.
        self._state.episode_id = str(uuid.uuid4())
        self._state.step_count = 0
        self._last_opponent_action = None

        # If opponents move first, play them out until the agent acts.
        time_step = self._auto_play_opponents(time_step)
        self._last_time_step = time_step

        return self._make_observation(time_step)

    def step(self, action: Action) -> Observation:
        """
        Execute the agent's action and return the resulting observation.

        For multi-player games, this will:
        1. Apply the agent's action
        2. Autoplay opponent turns until it's the agent's turn again
        3. Return the observation from the agent's perspective

        Args:
            action: OpenSpielAction containing the action_id to execute.

        Returns:
            Observation after action execution (and opponent turns if multi-player).

        Raises:
            ValueError: If action is not an OpenSpielAction.
        """
        if not isinstance(action, OpenSpielAction):
            raise ValueError(f"Expected OpenSpielAction, got {type(action)}")

        if self.is_turn_based:
            # Turn-based games take a single action per step.
            time_step = self._ospiel_env.step([action.action_id])
        else:
            # Simultaneous-move games require one action per player.
            if self.agent_player != 0:
                raise NotImplementedError(
                    "Simultaneous-move games only support agent_player=0"
                )

            # Build the joint action: the agent's choice plus one opponent
            # action per remaining player, chosen from the legal moves in the
            # most recent TimeStep (recorded by reset()/step()).
            current_obs = self._last_time_step.observations
            joint_actions = []
            for player_id in range(self.num_players):
                if player_id == self.agent_player:
                    joint_actions.append(action.action_id)
                else:
                    legal_actions = current_obs["legal_actions"][player_id]
                    opp_action = self.opponent_policy_fn.select_action(
                        legal_actions, current_obs
                    )
                    joint_actions.append(opp_action)
            time_step = self._ospiel_env.step(joint_actions)

        self._state.step_count += 1

        # In turn-based games, opponents may need to move before control
        # returns to the agent.
        if self.is_turn_based:
            time_step = self._auto_play_opponents(time_step)

        self._last_time_step = time_step
        return self._make_observation(time_step)

    @property
    def state(self) -> OpenSpielState:
        """Get current environment state."""
        return self._state

    def _auto_play_opponents(self, time_step) -> Any:
        """
        Autoplay opponent turns until it's the agent's turn or the game is terminal.

        Args:
            time_step: Current TimeStep from the OpenSpiel environment.

        Returns:
            Updated TimeStep after opponent moves.
        """
        # Single-player games have no opponents to play.
        if self.num_players == 1:
            return time_step

        while (
            not time_step.last()
            and time_step.observations["current_player"] != self.agent_player
        ):
            current_player = time_step.observations["current_player"]
            legal_actions = time_step.observations["legal_actions"][current_player]

            # Let the configured policy pick the opponent's move.
            opp_action = self.opponent_policy_fn.select_action(
                legal_actions, time_step.observations
            )
            self._last_opponent_action = opp_action

            # Advance the game by one opponent move.
            time_step = self._ospiel_env.step([opp_action])
            self._state.step_count += 1

        return time_step

    def _make_observation(self, time_step) -> OpenSpielObservation:
        """
        Convert an OpenSpiel TimeStep to an OpenEnv Observation.

        Args:
            time_step: OpenSpiel TimeStep object.

        Returns:
            OpenSpielObservation for the agent.
        """
        info_state = time_step.observations["info_state"][self.agent_player]
        legal_actions = time_step.observations["legal_actions"][self.agent_player]
        current_player_id = time_step.observations["current_player"]

        # Classify where we are in the episode.
        if time_step.last():
            game_phase = "terminal"
        elif time_step.first():
            game_phase = "initial"
        else:
            game_phase = "playing"

        # Rewards are None on the first TimeStep of an episode.
        reward = None
        if time_step.rewards is not None:
            reward = float(time_step.rewards[self.agent_player])

        # info_state may be a numpy array; convert it to a plain list so the
        # observation serializes cleanly.
        obs = OpenSpielObservation(
            info_state=info_state.tolist() if hasattr(info_state, "tolist") else list(info_state),
            legal_actions=legal_actions,
            game_phase=game_phase,
            current_player_id=current_player_id,
            opponent_last_action=self._last_opponent_action,
            done=time_step.last(),
            reward=reward,
        )

        return obs
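

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the server API. It assumes the
    # module is run in its package context (the relative imports above require
    # e.g. `python -m <package>.openspiel_environment`), that OpenSpiel ships
    # the "catch" game, and that "random" is a registered opponent policy.
    env = OpenSpielEnvironment("catch")
    obs = env.reset()
    while not obs.done:
        # Naively take the first legal action; a real agent would choose better.
        obs = env.step(OpenSpielAction(action_id=obs.legal_actions[0]))
    print(
        f"Episode finished after {env.state.step_count} steps; "
        f"final reward: {obs.reward}"
    )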