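"""Training and evaluation runner for a MADDPG-style multi-agent RL setup.

The Runner steps the environment for args.time_steps steps, letting the first
args.n_agents players act through learned policies while the remaining players
(up to args.n_players) act randomly. Transitions go into a shared replay
buffer, every agent trains once per step once the buffer holds a full batch,
and average evaluation returns are plotted and saved periodically.
"""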
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from agent import Agent
from common.replay_buffer import Buffer

class Runner:
    def __init__(self, args, env):
        self.args = args
        self.noise = args.noise_rate    # exploration noise rate, decayed in run()
        self.epsilon = args.epsilon     # epsilon-greedy rate, decayed in run()
        self.episode_limit = args.max_episode_len
        self.env = env
        self.agents = self._init_agents()
        self.buffer = Buffer(args)
        self.save_path = os.path.join(self.args.save_dir, self.args.scenario_name)
        # exist_ok avoids the race between checking for the directory and creating it
        os.makedirs(self.save_path, exist_ok=True)
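
    # Interfaces assumed from usage in this class (defined elsewhere in the
    # project):
    #   Agent(agent_id, args).select_action(obs, noise_rate, epsilon) -> action
    #   Agent.learn(transitions, other_agents)
    #   Buffer(args).store_episode(obs, actions, rewards, next_obs)
    #   Buffer.sample(batch_size) -> transitions
    #   Buffer.current_size -> number of transitions stored so far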

    def _init_agents(self):
        # One Agent per learning player; players n_agents..n_players-1 have no
        # policy and are given random actions in run() and evaluate().
        return [Agent(i, self.args) for i in range(self.args.n_agents)]
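
    # Training loop overview: the global step counter below doubles as the
    # episode clock (reset every episode_limit steps); every transition goes
    # into the shared replay buffer; and each agent's learn() is handed the
    # other agents, which in MADDPG-style trainers lets a centralized critic
    # condition on all agents' observations and actions.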
    def run(self):
        returns = []
        for time_step in tqdm(range(self.args.time_steps)):
            # reset the environment at every episode boundary
            if time_step % self.episode_limit == 0:
                s = self.env.reset()
            u = []        # learning agents' actions (stored in the buffer)
            actions = []  # all players' actions (passed to the environment)
            with torch.no_grad():
                for agent_id, agent in enumerate(self.agents):
                    action = agent.select_action(s[agent_id], self.noise, self.epsilon)
                    u.append(action)
                    actions.append(action)
            # the remaining, non-learning players act randomly
            for i in range(self.args.n_agents, self.args.n_players):
                actions.append([0, np.random.rand() * 2 - 1, 0, np.random.rand() * 2 - 1, 0])
            s_next, r, done, info = self.env.step(actions)
            # only the learning agents' transitions are stored
            self.buffer.store_episode(s[:self.args.n_agents], u,
                                      r[:self.args.n_agents], s_next[:self.args.n_agents])
            s = s_next
            if self.buffer.current_size >= self.args.batch_size:
                transitions = self.buffer.sample(self.args.batch_size)
                for agent in self.agents:
                    other_agents = self.agents.copy()
                    other_agents.remove(agent)
                    agent.learn(transitions, other_agents)
            if time_step > 0 and time_step % self.args.evaluate_rate == 0:
                returns.append(self.evaluate())
                plt.figure()
                plt.plot(range(len(returns)), returns)
                plt.xlabel('episode * ' + str(self.args.evaluate_rate / self.episode_limit))
                plt.ylabel('average returns')
                plt.savefig(self.save_path + '/plt.png', format='png')
                plt.close()  # don't leave one open figure per evaluation
                # np.save writes NumPy's .npy format, so the original
                # 'returns.pkl' name would actually come out as
                # 'returns.pkl.npy'; saving here instead of on every time step
                # also avoids a disk write per environment step
                np.save(self.save_path + '/returns.npy', returns)
            # decay exploration linearly to a floor of 0.05; at 5e-7 per step,
            # starting from 0.1 the floor is reached after
            # (0.1 - 0.05) / 5e-7 = 100,000 steps
            self.noise = max(0.05, self.noise - 0.0000005)
            self.epsilon = max(0.05, self.epsilon - 0.0000005)

    def evaluate(self):
        returns = []
        for episode in range(self.args.evaluate_episodes):
            # reset the environment
            s = self.env.reset()
            rewards = 0
            for time_step in range(self.args.evaluate_episode_len):
                # self.env.render()  # uncomment to watch evaluation episodes
                actions = []
                with torch.no_grad():
                    for agent_id, agent in enumerate(self.agents):
                        # act greedily: zero noise, zero epsilon
                        action = agent.select_action(s[agent_id], 0, 0)
                        actions.append(action)
                for i in range(self.args.n_agents, self.args.n_players):
                    actions.append([0, np.random.rand() * 2 - 1, 0, np.random.rand() * 2 - 1, 0])
                s_next, r, done, info = self.env.step(actions)
                rewards += r[0]  # accumulate the first learning agent's reward
                s = s_next
            returns.append(rewards)
            if episode % 1000 == 0:
                print('Episode {}: return {}'.format(episode, rewards))
        return sum(returns) / self.args.evaluate_episodes
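
# A minimal driver sketch (not part of the original file) showing how this
# Runner is typically wired up. `get_args` and `make_env` are assumptions
# standing in for whatever argument parser and environment factory the
# surrounding project provides; `args.evaluate` is likewise a hypothetical flag:
#
#     if __name__ == '__main__':
#         args = get_args()            # hypothetical: parses time_steps, noise_rate, ...
#         env, args = make_env(args)   # hypothetical: builds env, fills n_agents/n_players
#         runner = Runner(args, env)
#         if args.evaluate:            # hypothetical: evaluate-only mode
#             print('Average returns:', runner.evaluate())
#         else:
#             runner.run()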