# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
OpenSpiel Environment Server Implementation.

This module wraps OpenSpiel's rl_environment.Environment and exposes it
via the OpenEnv Environment interface.
"""

import uuid
from typing import Any, Dict

from openenv_core.env_server.interfaces import Environment
from openenv_core.env_server.types import Action, Observation, State

from ..models import OpenSpielAction, OpenSpielObservation, OpenSpielState
from .opponent_policies import get_opponent_policy, OpponentPolicy

# Import OpenSpiel
try:
    from open_spiel.python import rl_environment
    import pyspiel  # noqa: F401  (imported to verify the OpenSpiel install)
except ImportError as e:
    raise ImportError(
        "OpenSpiel is not installed. "
        "Please install it following instructions at: "
        "https://github.com/google-deepmind/open_spiel"
    ) from e


class OpenSpielEnvironment(Environment):
    """
    OpenSpiel Environment wrapper for OpenEnv.

    This environment wraps OpenSpiel games and provides a single-agent
    interface. For multi-player games, the agent controls one player while
    the opponent(s) use a fixed policy (e.g., random).

    Supported games:
    - Single-player: catch, cliff_walking, 2048, blackjack
    - Multi-player: tic_tac_toe, kuhn_poker

    Args:
        game_name: Name of the OpenSpiel game (e.g., "catch", "tic_tac_toe").
        agent_player: Which player ID the agent controls (default 0).
        opponent_policy: Policy for opponent players ("random", "first", etc.).
        game_params: Optional game-specific parameters.

    Example:
        >>> env = OpenSpielEnvironment("catch")
        >>> obs = env.reset()
        >>> print(obs.info_state)  # Agent's observation
        >>> obs = env.step(OpenSpielAction(action_id=1))
        >>> print(obs.reward)
    """

    def __init__(
        self,
        game_name: str = "catch",
        agent_player: int = 0,
        opponent_policy: str = "random",
        game_params: Dict[str, Any] | None = None,
    ):
        """Initialize the OpenSpiel environment."""
        super().__init__()

        self.game_name = game_name
        self.agent_player = agent_player
        self.game_params = game_params or {}

        # Create the underlying OpenSpiel environment.
        try:
            self._ospiel_env = rl_environment.Environment(
                game_name, **self.game_params
            )
        except Exception as e:
            raise ValueError(
                f"Failed to create OpenSpiel game '{game_name}': {e}"
            ) from e

        self.num_players = self._ospiel_env.num_players
        self.is_turn_based = self._ospiel_env.is_turn_based

        # Validate agent_player.
        if agent_player >= self.num_players:
            raise ValueError(
                f"agent_player={agent_player} >= num_players={self.num_players}"
            )

        # Set up the opponent policy for multi-player games.
        self.opponent_policy_fn: OpponentPolicy | None = None
        if self.num_players > 1:
            self.opponent_policy_fn = get_opponent_policy(opponent_policy)

        # Initialize state.
        self._state = OpenSpielState(
            game_name=game_name,
            agent_player=agent_player,
            opponent_policy=opponent_policy,
            game_params=self.game_params,
            num_players=self.num_players,
        )

        # Track the last opponent action so the agent can observe it.
        self._last_opponent_action: int | None = None

    def reset(self) -> Observation:
        """
        Reset the environment and return the initial observation.

        For multi-player games, this autoplays opponent turns until it is
        the agent's turn (or a terminal state is reached).

        Returns:
            Initial observation for the agent.
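
        Example (a sketch, assuming the default "catch" game; field values
        are illustrative, not verbatim output):
            >>> env = OpenSpielEnvironment("catch")
            >>> obs = env.reset()
            >>> obs.game_phase     # "initial" on a fresh episode
            >>> obs.done           # False until the episode ends
            >>> obs.legal_actions  # Actions available to the agent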
""" # Reset OpenSpiel environment time_step = self._ospiel_env.reset() # Reset state tracking self._state.episode_id = str(uuid.uuid4()) self._state.step_count = 0 self._last_opponent_action = None # Autoplay opponent turns until agent's turn time_step = self._auto_play_opponents(time_step) # Convert to OpenEnv observation return self._make_observation(time_step) def step(self, action: Action) -> Observation: """ Execute agent's action and return resulting observation. For multi-player games, this will: 1. Apply the agent's action 2. Autoplay opponent turns until it's the agent's turn again 3. Return the observation from the agent's perspective Args: action: OpenSpielAction containing the action_id to execute. Returns: Observation after action execution (and opponent turns if multi-player). Raises: ValueError: If action is not an OpenSpielAction. """ if not isinstance(action, OpenSpielAction): raise ValueError(f"Expected OpenSpielAction, got {type(action)}") # Apply agent's action if self.is_turn_based: # Turn-based: single action time_step = self._ospiel_env.step([action.action_id]) else: # Simultaneous-move: need actions for all players # For now, only support agent as player 0 in simultaneous games if self.agent_player != 0: raise NotImplementedError( "Simultaneous-move games only support agent_player=0" ) # Get opponent actions opponent_actions = [] for player_id in range(self.num_players): if player_id == self.agent_player: opponent_actions.append(action.action_id) else: legal_actions = time_step.observations["legal_actions"][player_id] opp_action = self.opponent_policy_fn.select_action( legal_actions, time_step.observations ) opponent_actions.append(opp_action) time_step = self._ospiel_env.step(opponent_actions) self._state.step_count += 1 # Autoplay opponent turns (for turn-based games) if self.is_turn_based: time_step = self._auto_play_opponents(time_step) # Convert to OpenEnv observation return self._make_observation(time_step) @property def state(self) -> OpenSpielState: """Get current environment state.""" return self._state def _auto_play_opponents(self, time_step) -> Any: """ Autoplay opponent turns until it's the agent's turn or game is terminal. Args: time_step: Current TimeStep from OpenSpiel environment. Returns: Updated TimeStep after opponent moves. """ # Single-player games: nothing to do if self.num_players == 1: return time_step # Multi-player games: play opponent turns while ( not time_step.last() and time_step.observations["current_player"] != self.agent_player ): current_player = time_step.observations["current_player"] legal_actions = time_step.observations["legal_actions"][current_player] # Select opponent action opp_action = self.opponent_policy_fn.select_action( legal_actions, time_step.observations ) self._last_opponent_action = opp_action # Apply opponent action time_step = self._ospiel_env.step([opp_action]) self._state.step_count += 1 return time_step def _make_observation(self, time_step) -> OpenSpielObservation: """ Convert OpenSpiel TimeStep to OpenEnv Observation. Args: time_step: OpenSpiel TimeStep object. Returns: OpenSpielObservation for the agent. 
""" # Extract agent's information info_state = time_step.observations["info_state"][self.agent_player] legal_actions = time_step.observations["legal_actions"][self.agent_player] current_player_id = time_step.observations["current_player"] # Determine game phase if time_step.last(): game_phase = "terminal" elif time_step.first(): game_phase = "initial" else: game_phase = "playing" # Get reward for agent reward = None if time_step.rewards is not None: reward = float(time_step.rewards[self.agent_player]) # Create observation obs = OpenSpielObservation( info_state=info_state.tolist() if hasattr(info_state, "tolist") else list(info_state), legal_actions=legal_actions, game_phase=game_phase, current_player_id=current_player_id, opponent_last_action=self._last_opponent_action, done=time_step.last(), reward=reward, ) return obs