Commit 2991c47: "putting back shared experience"
ferielamira1 committed May 13, 2024 (1 parent: d6d9122)
Showing 1 changed file with 3 additions and 3 deletions.
@@ -207,7 +207,7 @@ def initialize_environment(self):
             # calculate number of resources left in the patch
             resc = list(ag_resc_overlap.keys())[0]
             ag.policy_network.state_tensor = torch.FloatTensor(
-                ag.soc_v_field.tolist() + [resc.resc_left / resc.resc_units]).unsqueeze(0)
+                ag.soc_v_field.tolist() + [resc.resc_left / resc.resc_units]).unsqueeze(0).to(device)
         else:
             ag.policy_network.state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0).to(device)
@@ -279,10 +279,10 @@ def start_madqn(self):
             ag_resc_overlap = self.agent_resource_overlap([ag])
             resc= list(ag_resc_overlap.keys())[0]

-            ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [resc.resc_left/resc.resc_units]).unsqueeze(0)
+            ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [resc.resc_left/resc.resc_units]).unsqueeze(0).to(device)

         else:
-            ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0)
+            ag.policy_network.next_state_tensor = torch.FloatTensor(ag.soc_v_field.tolist() + [0.0]).unsqueeze(0).to(device)

         # Calculate the reward as a weighted sum of the individual and collective search efficiency
         reward = ag.compute_reward()
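
For context, every change in this commit appends .to(device) to a state or next-state tensor so that it lives on the same device (CPU or GPU) as the policy network before the forward pass. Below is a minimal sketch of that pattern, assuming a `device` defined the usual PyTorch way; the stand-in values for the social visual field and the patch resource state are illustrative, not the repository's code.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

soc_v_field = torch.rand(16)         # stand-in for the agent's social visual field
resc_left, resc_units = 42.0, 100.0  # stand-in patch resource state

# Build the state on CPU, then move it with .to(device) so it matches the
# device of the policy network's parameters; feeding a CPU tensor to a
# GPU-resident network would raise a runtime error in the forward pass.
state_tensor = torch.FloatTensor(
    soc_v_field.tolist() + [resc_left / resc_units]).unsqueeze(0).to(device)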
