actor=adamW4-128, critic=RMSprop3-256

Ivan I. Ovchinnikov 2023-01-13 17:30:17 +03:00
parent 0a78c5a7d9
commit d743e23082
3 changed files with 24 additions and 106 deletions

View File

@@ -8,15 +8,17 @@ class Actor(nn.Module):
     def __init__(self, args, agent_id):
         super(Actor, self).__init__()
         self.max_action = args.high_action
-        self.fc1 = nn.Linear(args.obs_shape[agent_id], 64)
-        self.fc2 = nn.Linear(64, 64)
-        self.fc3 = nn.Linear(64, 64)
+        self.fc1 = nn.Linear(args.obs_shape[agent_id], 128)
+        self.fc2 = nn.Linear(128, 128)
+        self.fc3 = nn.Linear(128, 128)
+        self.fc4 = nn.Linear(128, 128)
         self.action_out = nn.Linear(128, args.action_shape[agent_id])
 
     def forward(self, x):
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         x = F.relu(self.fc3(x))
+        x = F.relu(self.fc4(x))
         actions = self.max_action * torch.tanh(self.action_out(x))
         return actions
@@ -26,10 +28,10 @@ class Critic(nn.Module):
     def __init__(self, args):
         super(Critic, self).__init__()
         self.max_action = args.high_action
-        self.fc1 = nn.Linear(sum(args.obs_shape) + sum(args.action_shape), 64)
-        self.fc2 = nn.Linear(64, 64)
-        self.fc3 = nn.Linear(64, 64)
-        self.q_out = nn.Linear(64, 1)
+        self.fc1 = nn.Linear(sum(args.obs_shape) + sum(args.action_shape), 256)
+        self.fc2 = nn.Linear(256, 256)
+        self.fc3 = nn.Linear(256, 256)
+        self.q_out = nn.Linear(256, 1)
 
     def forward(self, state, action):
         state = torch.cat(state, dim=1)

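For reference, a minimal self-contained sketch of the two networks as they stand after this commit, with a quick shape check. The imports, the tail of Critic.forward (everything after the first torch.cat over the states), and the toy dimensions in the check (3 agents, 10-dim observations, 5-dim actions) are assumptions added for illustration; they are not part of the diff.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Actor(nn.Module):
    # 4 hidden layers of 128 units; tanh output scaled to the action bound
    def __init__(self, args, agent_id):
        super(Actor, self).__init__()
        self.max_action = args.high_action
        self.fc1 = nn.Linear(args.obs_shape[agent_id], 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 128)
        self.action_out = nn.Linear(128, args.action_shape[agent_id])

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return self.max_action * torch.tanh(self.action_out(x))


class Critic(nn.Module):
    # 3 hidden layers of 256 units over the concatenated observations and actions of all agents
    def __init__(self, args):
        super(Critic, self).__init__()
        self.max_action = args.high_action
        self.fc1 = nn.Linear(sum(args.obs_shape) + sum(args.action_shape), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

    def forward(self, state, action):
        # state/action are lists of per-agent tensors; everything after the first cat is assumed
        state = torch.cat(state, dim=1)
        action = torch.cat(action, dim=1)
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.q_out(x)


# Shape check with hypothetical sizes: 3 agents, 10-dim observations, 5-dim actions
from types import SimpleNamespace
args = SimpleNamespace(high_action=1.0, obs_shape=[10, 10, 10], action_shape=[5, 5, 5])
actor, critic = Actor(args, agent_id=0), Critic(args)
obs = torch.randn(4, 10)                        # batch of 4 observations for agent 0
acts = [actor(obs) for _ in range(3)]           # pretend every agent shares this actor
print(actor(obs).shape)                         # torch.Size([4, 5])
print(critic([obs, obs, obs], acts).shape)      # torch.Size([4, 1])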
View File

@@ -22,8 +22,8 @@ class MADDPG:
         self.critic_target_network.load_state_dict(self.critic_network.state_dict())
 
         # create the optimizer
-        self.actor_optim = torch.optim.Adam(self.actor_network.parameters(), lr=self.args.lr_actor)
-        self.critic_optim = torch.optim.Adam(self.critic_network.parameters(), lr=self.args.lr_critic)
+        self.actor_optim = torch.optim.AdamW(self.actor_network.parameters(), lr=self.args.lr_actor)
+        self.critic_optim = torch.optim.RMSprop(self.critic_network.parameters(), lr=self.args.lr_critic)
 
         # create the dict for store the model
         if not os.path.exists(self.args.save_dir):

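In isolation, the optimizer swap amounts to the sketch below (placeholder networks and learning rates; the repo takes both from self.args). The practical difference: torch.optim.AdamW applies decoupled weight decay with a default of 0.01, whereas torch.optim.Adam defaults to no weight decay, so the actor is now regularized unless weight_decay=0.0 is passed; torch.optim.RMSprop tracks only a running average of squared gradients, with no momentum and no bias correction by default.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real actor / critic networks and learning rates
actor_network = nn.Linear(10, 5)
critic_network = nn.Linear(45, 1)
lr_actor, lr_critic = 1e-4, 1e-3

# AdamW: Adam with decoupled weight decay (default weight_decay=0.01)
actor_optim = torch.optim.AdamW(actor_network.parameters(), lr=lr_actor)
# RMSprop: per-parameter step scaled by a running average of squared gradients
critic_optim = torch.optim.RMSprop(critic_network.parameters(), lr=lr_critic)

# To match the previous Adam behaviour exactly, pass weight_decay=0.0 to AdamW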
View File

@@ -186,108 +186,24 @@
     "pip install pyglet==1.5.27"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "9c427530",
+   "metadata": {},
+   "source": [
+    "- actor=4-128-AdamW\n",
+    "- critic=3-256-RMSprop"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "id": "cb877007",
    "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пытаемся загрузить данные!\n",
"Пытаемся загрузить данные!\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 203/2000000 [00:00<32:10, 1036.07it/s]/home/ovchinnikov_ii@RISDE.ru/Software/Jupyter/MADDPG/maddpg/maddpg.py:60: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" transitions[key] = torch.tensor(transitions[key], dtype=torch.float32)\n",
" 0%| | 307/2000000 [00:01<2:18:56, 239.88it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_0/1_actor_params.pkl\n",
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_1/1_actor_params.pkl\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 459/2000000 [00:03<6:02:10, 92.02it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_0/2_actor_params.pkl\n",
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_1/2_actor_params.pkl\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 566/2000000 [00:04<7:30:23, 73.99it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_0/3_actor_params.pkl\n",
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_1/3_actor_params.pkl\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 667/2000000 [00:06<8:44:54, 63.48it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_0/4_actor_params.pkl\n",
"Пытаемся сохранить данные по пути = ./model/simple_adversary/agent_1/4_actor_params.pkl\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 717/2000000 [00:07<5:36:33, 99.01it/s]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/Software/Jupyter/MADDPG/main.py:18\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mAverage returns is\u001b[39m\u001b[38;5;124m'\u001b[39m, returns)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 18\u001b[0m \u001b[43mrunner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/Jupyter/MADDPG/runner.py:52\u001b[0m, in \u001b[0;36mRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 50\u001b[0m other_agents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39magents\u001b[38;5;241m.\u001b[39mcopy()\n\u001b[1;32m 51\u001b[0m other_agents\u001b[38;5;241m.\u001b[39mremove(agent)\n\u001b[0;32m---> 52\u001b[0m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtransitions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mother_agents\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 53\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m time_step \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m time_step \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mevaluate_rate \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 54\u001b[0m returns\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mevaluate())\n",
"File \u001b[0;32m~/Software/Jupyter/MADDPG/agent.py:27\u001b[0m, in \u001b[0;36mAgent.learn\u001b[0;34m(self, transitions, other_agents)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlearn\u001b[39m(\u001b[38;5;28mself\u001b[39m, transitions, other_agents):\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpolicy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtransitions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mother_agents\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/Jupyter/MADDPG/maddpg/maddpg.py:95\u001b[0m, in \u001b[0;36mMADDPG.train\u001b[0;34m(self, transitions, other_agents)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactor_optim\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 94\u001b[0m actor_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 95\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactor_optim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcritic_optim\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 97\u001b[0m critic_loss\u001b[38;5;241m.\u001b[39mbackward()\n",
"File \u001b[0;32m~/Software/Jupyter/venv/lib/python3.9/site-packages/torch/autograd/grad_mode.py:26\u001b[0m, in \u001b[0;36m_DecoratorContextManager.__call__.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m():\n\u001b[0;32m---> 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/Jupyter/venv/lib/python3.9/site-packages/torch/optim/adamw.py:116\u001b[0m, in \u001b[0;36mAdamW.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 112\u001b[0m denom \u001b[38;5;241m=\u001b[39m (exp_avg_sq\u001b[38;5;241m.\u001b[39msqrt() \u001b[38;5;241m/\u001b[39m math\u001b[38;5;241m.\u001b[39msqrt(bias_correction2))\u001b[38;5;241m.\u001b[39madd_(group[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124meps\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 114\u001b[0m step_size \u001b[38;5;241m=\u001b[39m group[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlr\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m/\u001b[39m bias_correction1\n\u001b[0;32m--> 116\u001b[0m p\u001b[38;5;241m.\u001b[39maddcdiv_(exp_avg, denom, value\u001b[38;5;241m=\u001b[39m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43mstep_size\u001b[49m)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [ "source": [
"%run ./main.py --scenario-name=simple_adversary --evaluate-episodes=10000 --save-rate=100" "%run ./main.py --scenario-name=simple_adversary --evaluate-episodes=10 --save-rate=50000 --evaluate-rate=50000"
] ]
}, },
{ {