"Trying shared experience replay"
ferielamira1 committed May 13, 2024
1 parent e497057 commit 0c37e4d
Showing 1 changed file with 373 additions and 0 deletions.
import math
import random
import time

import pygame
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from abm.monitoring import ifdb, env_saver
from abm.projects.madrl_foraging.madrl_agent.madrl_agent import MADRLAgent as Agent
from abm.contrib import colors, ifdb_params as logging_params
from abm.projects.madrl_foraging.madrl_contrib import madrl_learning_params as learning_params
from abm.simulation.sims import Simulation, notify_agent, refine_ar_overlap_group



from datetime import datetime

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MADRLSimulation(Simulation):
def __init__(self, **kwargs):
"""
Inherited from Simulation class
:param train: boolean, if true the simulation is run in training mode, if false in evaluation mode
:param train_every: int, number of timesteps after which the agent will be trained
:param num_episodes: int, number of training episodes, if in evaluation it is set to 1
"""
super().__init__(**kwargs)

self.train = learning_params.train
self.train_every = learning_params.train_every

self.num_episodes = learning_params.num_episodes

seed = learning_params.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
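# Seed the Python, NumPy and torch RNGs so that agent placement and action selection are reproducible across runs.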


def add_new_agent(self, id, x, y, orient, with_proove=False, behave_params=None):
"""Add a single new agent to the agent sprite group."""
agent_proven = False

while not agent_proven:
agent = Agent(
id=id,
radius=self.agent_radii,
position=(x, y),
orientation=orient,
env_size=(self.WIDTH, self.HEIGHT),
color=colors.BLUE,
v_field_res=self.v_field_res,
FOV=self.agent_fov,
window_pad=self.window_pad,
pooling_time=self.pooling_time,
pooling_prob=self.pooling_prob,
consumption=self.agent_consumption,
vision_range=self.vision_range,
visual_exclusion=self.visual_exclusion,
patchwise_exclusion=self.patchwise_exclusion,
behave_params=None,
train=self.train,
)
if with_proove:
if self.proove_sprite(agent):
self.agents.add(agent)
agent_proven = True
else:
self.agents.add(agent)
agent_proven = True

def agent_resource_overlap(self, agents):
"""Return a dict mapping each resource patch to the agents in `agents` that overlap it."""
collision_group_ar = pygame.sprite.groupcollide(
self.rescources,
agents,
False,
False,
pygame.sprite.collide_circle
)

# refine collision group according to point-like pooling in center of agents
collision_group_ar = refine_ar_overlap_group(collision_group_ar)

return collision_group_ar


def agent2resource_interaction(self, collided_agents):
"""Handle pooling and exploitation of resource patches and notify agents about patch status."""
collision_group_ar = self.agent_resource_overlap(self.agents)

# collecting agents that are on resource patch
agents_on_rescs = []

# Notifying agents about resource if pooling is successful + exploitation dynamics
for resc, agents in collision_group_ar.items(): # looping through patches

destroy_resc = False  # set to True once the patch has been fully depleted
for agent in agents: # looping through all agents on patches
self.bias_agent_towards_res_center(agent, resc)

# One of the previous agents on the patch consumed its last unit
if destroy_resc:
notify_agent(agent, -1)
else:
# Agent finished pooling on a resource patch
if (agent.get_mode() in ["pool", "relocate"] and agent.pool_success) \
or agent.pooling_time == 0:
# Notify about the patch
notify_agent(agent, 1, resc.id)
# Teleport agent to the middle of the patch if needed
if self.teleport_exploit:
agent.position = resc.position + resc.radius - agent.radius

# Agent is exploiting this patch
if agent.get_mode() == "exploit":
#if agent.id not in resc.agent_visits:
# agent.new_discovery = 1 / (1 + math.sqrt(len(resc.agent_visits)))
# resc.agent_visits.append(agent.id)


# continue depleting the patch
depl_units, destroy_resc = resc.deplete(agent.consumption)
agent.collected_r_before = agent.collected_r # rolling resource memory
agent.collected_r += depl_units # and increase its collected resources

# remember the time of last exploitation
if destroy_resc: # consumed unit was the last in the patch
# print(f"Agent {agent.id} has depleted the patch all agents must be notified that"
# f"there are no more units before the next timestep, otherwise they stop"
# f"exploiting with delays")
for agent_tob_notified in agents:
# print("C notify agent NO res ", agent_tob_notified.id)
notify_agent(agent_tob_notified, -1)

# Collect all agents on resource patches
agents_on_rescs.append(agent)

# Patch is fully depleted
if destroy_resc:
# we clear it from the memory and regenerate it somewhere else if needed
self.kill_resource(resc)

# Notify the remaining agents that there is no resource patch at their current position (they are not on a patch)
for agent in self.agents.sprites():
if agent not in agents_on_rescs: # for all the agents that are not on resource patches
if agent not in collided_agents: # and are not colliding with each other currently
# if they finished pooling
if (agent.get_mode() in ["pool",
"relocate"] and agent.pool_success) or agent.pooling_time == 0:
notify_agent(agent, -1)
elif agent.get_mode() == "exploit":
notify_agent(agent, -1)
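# NOTE: notify_agent(agent, -1) appears to signal "no exploitable patch at the agent's position",
# while notify_agent(agent, 1, resc.id) above signals a discovered patch; see notify_agent in abm.simulation.sims.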


# Update resource patches
self.rescources.update()

def step(self,turned_on_vfield):
# order the agents by id to ensure that agent 0 has priority over agent 1 and agent 1 over agent 2

# Update internal states of the agents and their positions
self.agents.update(self.agents)

# Check for agent-agent collisions
# collided_agents = self.agent2agent_interaction()
collided_agents = []

# Check for agent-resource interactions and update the resource patches
self.agent2resource_interaction(collided_agents)


for ag in self.agents:
ag.calc_social_V_proj(self.agents)
ag.search_efficiency = ag.collected_r / self.t if self.t != 0 else 0
collective_se = sum(ag.search_efficiency for ag in self.agents) / len(self.agents)
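# collective_se averages the per-agent search efficiency (resources collected per elapsed timestep) over all agents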

# Draw the updated environment and agents (and visual fields if needed)
# deciding if vis field needs to be shown in this timestep
self.decide_on_vis_field_visibility(turned_on_vfield)
if self.with_visualization:
self.draw_frame(self.stats, self.stats_pos)
pygame.display.flip()

return collective_se


def initialize_environment(self):
"""Set each agent's environment status and build its initial state tensor before an episode starts."""
for ag in self.agents:
# Check for agent-agent collisions
# collided_agents = self.agent2agent_interaction()

# Check for agent-resource interactions and update the resource patches
ag_resc_overlap = self.agent_resource_overlap([ag])
if len(ag_resc_overlap) > 0:
ag.env_status = 1

ag.calc_social_V_proj(self.agents)
# Build the state tensor by concatenating the social visual field (a 1D array) with the
# environment status signal (a single scalar)
if ag.env_status == 1:
# calculate number of resources left in the patch
resc = list(ag_resc_overlap.keys())[0]
ag.policy_network.state_tensor = torch.FloatTensor(
ag.soc_v_field.tolist() + [resc.resc_left / resc.resc_units]).unsqueeze(0).to(device)
else:
ag.policy_network.state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0).to(device)
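# The state fed to each policy network is therefore the agent's social visual field plus one scalar
# in [0, 1]: the remaining fraction of the overlapped patch, or 0.0 when the agent is not on a patch.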


def start_madqn(self):
"""Main simulation loop for training the agents with MADQN"""

# Start time of the simulation
start_time = datetime.now()

# Create the agents and resources patches in the environment
self.create_agents()
self.create_resources()

# Create surface to show visual fields and local var to decide when to show visual fields
self.stats, self.stats_pos = self.create_vis_field_graph()

# Create a directory to save the data (models and tensorboard logs)
save_dir = logging_params.TIMESTAMP_SAVE_DIR

writer = SummaryWriter(save_dir)
writer.add_text('Hyperparameters',
f'Gamma: {learning_params.gamma}, \n Epsilon Start: {learning_params.epsilon_start}, '
f'\n Epsilon End: {learning_params.epsilon_end}, \n'
f'Epsilon Decay: {learning_params.epsilon_decay},\n Tau: {learning_params.tau},\n Learning '
f'Rate: {learning_params.lr}',
0)

turned_on_vfield = 0

mode = "training" if self.train else "evaluation"
print(f"Starting main simulation loop in MADQN in {mode} with {len(self.agents)} agents and {len(self.rescources)} resources \n "
f"Saving data to {save_dir}!")
for episode in range(self.num_episodes):
# Create a flag to indicate whether the episode is done
done = False
self.initialize_environment()
collective_se_list = []
print("Starting episode: ",episode)

while self.t < self.T:
# Mark the last timestep of the episode as terminal
if self.t == self.T - 1:
done = True

# Agent 0 always has priority over agent 1, and agent 1 over agent 2.
# If all three are on the same patch and there are not enough resources, agent 0 is allowed
# to deplete the patch first, followed by agent 1, then agent 2.

for ag in self.agents:
# Select an action
_ = ag.policy_network.select_action(ag.policy_network.state_tensor)

collective_se = self.step(turned_on_vfield)

collective_se_list.append(collective_se)
# Train the agents
for ag in self.agents:
if done:
ag.policy_network.next_state_tensor = None
ag.reward = collective_se
else:

# Build the next-state tensor by concatenating the social visual field (a 1D array) with the environment status signal (a single scalar)
if ag.env_status == 1:

ag_resc_overlap = self.agent_resource_overlap([ag])
resc = list(ag_resc_overlap.keys())[0]

ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [resc.resc_left/resc.resc_units]).unsqueeze(0)

else:
ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0)

# Calculate the reward as a weighted sum of the individual and collective search efficiency
reward = ag.compute_reward()


if ag.policy_network.action_tensor.item() == 1:
ag.last_exploit_time = self.t

ag.policy_network.reward_tensor = torch.FloatTensor([reward]).to(device)

# Add the experiences to the replay memories and train the agents
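# Shared experience replay: each agent pushes the transitions of *all* agents (including its own)
# into its replay memory, so every policy network also learns from its teammates' experience.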
if self.train:
for ag in self.agents:
for ag2 in self.agents:
ag.policy_network.replay_memory.push(
ag2.policy_network.state_tensor,
ag2.policy_network.action_tensor,
ag2.policy_network.next_state_tensor,
ag2.policy_network.reward_tensor
)
# if self.train and self.t % self.train_every == 0:
loss = ag.policy_network.optimize()

# Update the target network with soft updates
ag.policy_network.update_target_network()
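# update_target_network() presumably performs a soft (Polyak) update of the target network
# towards the policy network using the tau hyperparameter logged above.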


if loss is not None:

writer.add_scalar(f'Agent_{ag.id}/Loss', loss, ag.policy_network.steps_done)
elif ag.policy_network.steps_done > ag.policy_network.batch_size:
print(f"Loss is None at timestep {self.t}!")

# Move to the next training step
ag.policy_network.steps_done += 1
ag.policy_network.state_tensor = ag.policy_network.next_state_tensor
ag.policy_network.last_action = ag.policy_network.action_tensor.item()
# move to next simulation timestep (only when not paused)
self.t += 1
#time.sleep(600)

if self.save_in_ram:
ifdb.save_agent_data_RAM(self.agents, self.t)
ifdb.save_resource_data_RAM(self.rescources, self.t)




for ag in self.agents:

writer.add_scalar(f'Agent_{ag.id}/Individual search efficiency', ag.search_efficiency, episode)
writer.add_scalar('Collective search efficiency', collective_se, episode)

ag.reset()
for resc in self.rescources:
self.kill_resource(resc)
self.t = 0



print(f"Episode {episode} ended with collective search efficiency: ", collective_se)

# Save the models
if self.train:
for count, ag in enumerate(self.agents):
ag.policy_network.save_model(f'{save_dir}/model_{ag.id}.pth')
# Close the tensorboard writer
writer.close()
env_saver.save_env_vars([self.env_path], "env_params.json", pop_num=None)

if self.save_csv_files:
if self.save_in_ifd or self.save_in_ram:

ifdb.save_ifdb_as_csv(exp_hash=self.ifdb_hash, use_ram=self.save_in_ram, as_zar=self.use_zarr,
save_extracted_vfield=False, pop_num=None)
else:
raise Exception("Tried to save simulation data as csv file due to env configuration, "
"but IFDB/RAM logging was turned off. Nothing to save! Please turn on IFDB/RAM logging"
" or turn off CSV saving feature.")

# Quit the pygame environment
pygame.quit()

# Print the execution time
end_time = datetime.now()
print(f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S.%f')} Total simulation time: ",
(end_time - start_time).total_seconds())
