44 lines
2.3 KiB
Python
44 lines
2.3 KiB
Python
|
# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)
|
||
|
# (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
import gym
|
||
|
from gym.spaces import prng
|
||
|
|
||
|
class MultiDiscrete(gym.Space):
    """
    A series of independent discrete action spaces, each with its own inclusive integer range.

    - The multi-discrete action space consists of a series of discrete action spaces with different parameters
    - It can be adapted to both a Discrete action space or a continuous (Box) action space
    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space
    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space
      where the discrete action space can take any integers from `min` to `max` (both inclusive)

    Note: A value of 0 always needs to represent the NOOP action.

    e.g. Nintendo Game Controller
    - Can be conceptualized as 3 discrete action spaces:
        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4
        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1
        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1
    - Can be initialized as
        MultiDiscrete([ [0,4], [0,1], [0,1] ])
    """

    def __init__(self, array_of_param_array):
        """Store the per-space bounds.

        array_of_param_array: sequence of [min, max] pairs, one pair per
        discrete action space; element 0 is the lower bound and element 1 the
        upper bound (both inclusive).
        """
        self.low = np.array([x[0] for x in array_of_param_array])
        self.high = np.array([x[1] for x in array_of_param_array])
        self.num_discrete_space = self.low.shape[0]

    def sample(self):
        """ Returns a list with one sample (a Python int) from each discrete action space """
        # For each space: floor(random * (max - min + 1) + min).
        # The `+ 1.` widens the interval so that `max` itself can be drawn
        # once the value is floored.
        # NOTE(review): assumes prng.np_random.rand yields floats in [0, 1)
        # like numpy's RandomState.rand -- confirm against the pinned gym version.
        random_array = prng.np_random.rand(self.num_discrete_space)
        scaled = np.multiply((self.high - self.low + 1.), random_array) + self.low
        return [int(x) for x in np.floor(scaled)]

    def contains(self, x):
        """Return whether `x` has one entry per space and every entry lies within [low, high].

        Only the length and the bounds are checked; entries are not required
        to be integers.
        """
        if len(x) != self.num_discrete_space:
            return False
        # Convert once and reuse for both bound checks.
        candidate = np.array(x)
        return (candidate >= self.low).all() and (candidate <= self.high).all()

    @property
    def shape(self):
        """Number of discrete action spaces (an int, not a tuple)."""
        return self.num_discrete_space

    def __repr__(self):
        return "MultiDiscrete" + str(self.num_discrete_space)

    def __eq__(self, other):
        """Two spaces are equal when their `low` and `high` bound arrays match.

        Objects that do not expose `low`/`high` are not comparable; returning
        NotImplemented lets Python fall back to its default handling (so
        `space == None` is False) instead of raising AttributeError.
        """
        try:
            other_low = other.low
            other_high = other.high
        except AttributeError:
            return NotImplemented
        return np.array_equal(self.low, other_low) and np.array_equal(self.high, other_high)
|