# Gym environment wrapper for Connect 4 reinforcement learning
class ConnectX(gym.Env):
    """Gym-style wrapper around the 'connectx' environment for training one agent.

    ``self.pair`` is the agent list handed to ``env.train``: the ``None`` entry
    is the agent driven through ``step``; the other entry names the built-in
    opponent ('random' or 'negamax'). ``reset`` may randomly swap which side the
    driven agent plays on, and optionally toggle the opponent.
    """

    def __init__(self, switch_prob=0.5, switch_trainer_prob=0.0):
        """Create the environment and an initial trainer against 'random'.

        Args:
            switch_prob: probability, checked on every ``reset``, of swapping
                the side the driven agent plays on.
            switch_trainer_prob: probability, checked on every ``reset``, of
                toggling the opponent between 'random' and 'negamax'. Defaults
                to 0.0 (never), which matches the original behaviour exactly.
        """
        self.env = make('connectx', debug=True)
        self.pair = [None, 'random']
        self.trainer = self.env.train(self.pair)
        self.switch_prob = switch_prob
        self.switch_trainer_prob = switch_trainer_prob
        config = self.env.configuration
        self.action_space = gym.spaces.Discrete(config.columns)
        # `np.int` was only an alias of the builtin `int` and was removed in
        # NumPy 1.24; using `int` directly keeps the same dtype.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=2,
            shape=(config.rows, config.columns, 1),
            dtype=int,
        )

    def switch_side(self):
        """Swap the two seats in ``self.pair`` and rebuild the trainer."""
        self.pair = self.pair[::-1]
        self.trainer = self.env.train(self.pair)

    def switch_trainer(self):
        """Toggle the opponent between 'random' and 'negamax' and rebuild the trainer.

        The driven agent (the ``None`` entry) keeps its current seat. Previously
        the pair was rebuilt as ``[None, opponent]``, which silently undid any
        earlier ``switch_side``.
        """
        opponent = 'negamax' if 'random' in self.pair else 'random'
        self.pair = [None if agent is None else opponent for agent in self.pair]
        self.trainer = self.env.train(self.pair)

    def step(self, action):
        """Forward ``action`` to the current trainer and return its result."""
        return self.trainer.step(action)

    def reset(self):
        """Maybe switch side and/or opponent, then reset the trainer and return its result."""
        if random.uniform(0, 1) < self.switch_prob:
            self.switch_side()
        # Only draw a second random number when opponent switching is enabled,
        # so the default configuration consumes the RNG exactly as before.
        if self.switch_trainer_prob > 0 and random.uniform(0, 1) < self.switch_trainer_prob:
            self.switch_trainer()
        return self.trainer.reset()