
Commit

adding to device
ferielamira1 committed May 13, 2024
1 parent 1c01d39 commit e65573b
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion abm/app_madrl_foraging.py
@@ -93,7 +93,7 @@ def start(parallel=True, headless=False):
     if envconf["BRAIN_TYPE"] == "ideal":
         from abm.projects.madrl_foraging.madrl_simulation.heuristic_sims import HeuristicSimulation as Simulation
     else:
-        from abm.projects.madrl_foraging.madrl_simulation.madrl_sims_shared_replay import MADRLSimulation as Simulation
+        from abm.projects.madrl_foraging.madrl_simulation.madrl_sims import MADRLSimulation as Simulation

     vscreen_width = int(envconf["ENV_WIDTH"]) + 2 * int(envconf["WINDOW_PAD"]) + 10
     vscreen_height = int(envconf["ENV_HEIGHT"]) + 2 * int(envconf["WINDOW_PAD"]) + 10
6 changes: 3 additions & 3 deletions abm/projects/madrl_foraging/madrl_simulation/madrl_sims.py
@@ -208,7 +208,7 @@ def initialize_environment(self):
     # calculate number of resources left in the patch
     resc = list(ag_resc_overlap.keys())[0]
     ag.policy_network.state_tensor = torch.FloatTensor(
-        ag.soc_v_field.tolist() + [resc.resc_left / resc.resc_units]).unsqueeze(0)
+        ag.soc_v_field.tolist() + [resc.resc_left / resc.resc_units]).unsqueeze(0).to(device)
 else:
     ag.policy_network.state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0).to(device)

@@ -280,10 +280,10 @@ def start_madqn(self):
     ag_resc_overlap = self.agent_resource_overlap([ag])
     resc= list(ag_resc_overlap.keys())[0]

-    ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [resc.resc_left/resc.resc_units]).unsqueeze(0)
+    ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [resc.resc_left/resc.resc_units]).unsqueeze(0).to(device)

 else:
-    ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0)
+    ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0).to(device)

 # Calculate the reward as a weighted sum of the individual and collective search efficiency
 reward = ag.compute_reward()
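Context for the change: in PyTorch, the tensors fed to a model must live on the same device (CPU or GPU) as the model's parameters, otherwise the forward pass raises a device-mismatch error. The state tensors above were still being created on the CPU while neighboring code already called .to(device), so this commit appends .to(device) to the remaining constructions. The sketch below illustrates the pattern; the way device is chosen and the small stand-in policy network are assumptions for illustration, not code taken from this repository.

# Minimal sketch of the device-placement pattern this commit completes.
# Assumption: madrl_sims.py defines a module-level torch.device roughly like this.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for an agent's policy network, moved to the device once.
policy_network = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 3)).to(device)

# Observations arrive as plain Python lists (a fake social visual field plus one
# resource-level feature), so each tensor must be converted and moved explicitly.
soc_v_field = [0.0, 0.1, 0.0, 0.2]
resc_left, resc_units = 30.0, 100.0

state_tensor = torch.FloatTensor(soc_v_field + [resc_left / resc_units]).unsqueeze(0).to(device)

# The forward pass works only because the input is on the same device as the
# network's weights; dropping .to(device) fails as soon as the model runs on GPU.
q_values = policy_network(state_tensor)

With this pattern, every tensor that feeds the policy network is placed on the device right where it is constructed, which is what the diff above does for state_tensor and next_state_tensor.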
