bmstu-marl/maddpg/actor_critic.py

import torch
import torch.nn as nn
import torch.nn.functional as F


# Actor network: maps one agent's observation to a bounded continuous action.
class Actor(nn.Module):
    def __init__(self, args, agent_id):
        super(Actor, self).__init__()
        self.max_action = args.high_action
        self.fc1 = nn.Linear(args.obs_shape[agent_id], 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 128)
        self.action_out = nn.Linear(128, args.action_shape[agent_id])

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        # tanh bounds the raw output to [-1, 1]; scale by the action limit
        actions = self.max_action * torch.tanh(self.action_out(x))
        return actions


# Centralized critic: scores the joint observation-action of all agents
# (the MADDPG critic sees every agent's inputs, not just its own).
class Critic(nn.Module):
    def __init__(self, args):
        super(Critic, self).__init__()
        self.max_action = args.high_action
        # input: concatenation of all agents' observations and actions
        self.fc1 = nn.Linear(sum(args.obs_shape) + sum(args.action_shape), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.q_out = nn.Linear(256, 1)

    def forward(self, state, action):
        state = torch.cat(state, dim=1)
        # scale actions to [-1, 1]; avoid in-place division, which would
        # silently mutate the caller's action tensors
        action = [a / self.max_action for a in action]
        action = torch.cat(action, dim=1)
        x = torch.cat([state, action], dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        q_value = self.q_out(x)
        return q_value
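

# Minimal usage sketch (an illustration, not part of the original repo): it
# builds a stand-in `args` namespace carrying the three fields the classes
# above read (obs_shape, action_shape, high_action) and runs one forward pass.
# The agent count and layer sizes below are assumed values for demonstration.
if __name__ == "__main__":
    from types import SimpleNamespace

    n_agents = 3
    args = SimpleNamespace(
        obs_shape=[16] * n_agents,    # assumed per-agent observation sizes
        action_shape=[5] * n_agents,  # assumed per-agent action sizes
        high_action=1.0,              # assumed action bound
    )

    actor = Actor(args, agent_id=0)
    critic = Critic(args)

    batch = 4
    obs = [torch.randn(batch, s) for s in args.obs_shape]
    acts = [torch.randn(batch, s).clamp(-1.0, 1.0) for s in args.action_shape]

    print(actor(obs[0]).shape)      # torch.Size([4, 5])
    print(critic(obs, acts).shape)  # torch.Size([4, 1])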