from tqdm import tqdm
from agent import Agent
from common.replay_buffer import Buffer
import torch
import os
import numpy as np
import matplotlib.pyplot as plt


class Runner:
    def __init__(self, args, env):
        self.args = args
        self.noise = args.noise_rate
        self.epsilon = args.epsilon
        self.episode_limit = args.max_episode_len
        self.env = env
        self.agents = self._init_agents()
        self.buffer = Buffer(args)
        self.save_path = self.args.save_dir + '/' + self.args.scenario_name
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)

    def _init_agents(self):
        # one learning agent per controllable player
        agents = []
        for i in range(self.args.n_agents):
            agent = Agent(i, self.args)
            agents.append(agent)
        return agents

    def run(self):
        returns = []
        for time_step in tqdm(range(self.args.time_steps)):
            # reset the environment at the start of every episode
            if time_step % self.episode_limit == 0:
                s = self.env.reset()
            u = []
            actions = []
            with torch.no_grad():
                for agent_id, agent in enumerate(self.agents):
                    action = agent.select_action(s[agent_id], self.noise, self.epsilon)
                    u.append(action)
                    actions.append(action)
            # remaining (non-learning) players act randomly
            for i in range(self.args.n_agents, self.args.n_players):
                actions.append([0, np.random.rand() * 2 - 1, 0, np.random.rand() * 2 - 1, 0])
            s_next, r, done, info = self.env.step(actions)
            # store only the learning agents' transitions in the replay buffer
            self.buffer.store_episode(s[:self.args.n_agents], u, r[:self.args.n_agents], s_next[:self.args.n_agents])
            s = s_next
            if self.buffer.current_size >= self.args.batch_size:
                transitions = self.buffer.sample(self.args.batch_size)
                for agent in self.agents:
                    # each agent learns with the others as fixed "other agents"
                    other_agents = self.agents.copy()
                    other_agents.remove(agent)
                    agent.learn(transitions, other_agents)
            if time_step > 0 and time_step % self.args.evaluate_rate == 0:
                returns.append(self.evaluate())
                plt.figure()
                plt.plot(range(len(returns)), returns)
                plt.xlabel('episode * ' + str(self.args.evaluate_rate / self.episode_limit))
                plt.ylabel('average returns')
                plt.savefig(self.save_path + '/plt.png', format='png')
                plt.close()  # avoid accumulating open figures across evaluations
            # anneal exploration noise and epsilon, with a floor of 0.05
            self.noise = max(0.05, self.noise - 0.0000005)
            self.epsilon = max(0.05, self.epsilon - 0.0000005)
            # note: np.save appends '.npy' to the given filename
            np.save(self.save_path + '/returns.pkl', returns)

    def evaluate(self):
        returns = []
        for episode in range(self.args.evaluate_episodes):
            # reset the environment
            s = self.env.reset()
            rewards = 0
            for time_step in range(self.args.evaluate_episode_len):
                # self.env.render()
                actions = []
                with torch.no_grad():
                    for agent_id, agent in enumerate(self.agents):
                        # act greedily: no exploration noise, no random actions
                        action = agent.select_action(s[agent_id], 0, 0)
                        actions.append(action)
                for i in range(self.args.n_agents, self.args.n_players):
                    actions.append([0, np.random.rand() * 2 - 1, 0, np.random.rand() * 2 - 1, 0])
                s_next, r, done, info = self.env.step(actions)
                rewards += r[0]
                s = s_next
            returns.append(rewards)
            if episode % 1000 == 0:
                print('Returns is', rewards)
        return sum(returns) / self.args.evaluate_episodes