File Graph.hpp
File List > Agents > GP > Graph.hpp
Go to the documentation of this file
#pragma once
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include "../../core/AgentBase.hpp"
#include "../AgentLibary.hpp"
#include "GraphNode.hpp"
namespace cowboys {
using GraphLayer = std::vector<std::shared_ptr<GraphNode>>;
class Graph {
protected:
std::vector<GraphLayer> layers;
public:
Graph() = default;
~Graph() = default;
size_t GetNodeCount() const {
return std::accumulate(layers.cbegin(), layers.cend(), 0,
[](size_t sum, const auto &layer) { return sum + layer.size(); });
}
size_t GetLayerCount() const { return layers.size(); }
size_t MakeDecision(const std::vector<double> &inputs, const std::vector<size_t> &actions) {
if (layers.size() == 0)
return actions.at(0);
// Set inputs of input layer
size_t i = 0;
for (auto &node : layers[0]) {
double input = 0;
if (i < inputs.size())
input = inputs.at(i);
node->SetDefaultOutput(input);
++i;
}
// Get output of last layer
std::vector<double> outputs;
for (auto &node : layers.back()) {
outputs.push_back(node->GetOutput());
}
// Choose the action with the highest output
auto max_output = std::max_element(outputs.cbegin(), outputs.cend());
size_t index = std::distance(outputs.cbegin(), max_output);
// If index is out of bounds, return the last action
size_t action = 0;
if (index >= actions.size())
action = actions.back();
else // Otherwise, return the action at the index
action = actions.at(index);
return action;
}
void AddLayer(const GraphLayer &layer) { layers.push_back(layer); }
std::vector<std::shared_ptr<GraphNode>> GetFunctionalNodes() const {
std::vector<std::shared_ptr<GraphNode>> functional_nodes;
for (size_t i = 1; i < layers.size(); ++i) {
functional_nodes.insert(functional_nodes.cend(), layers.at(i).cbegin(), layers.at(i).cend());
}
return functional_nodes;
}
std::vector<std::shared_ptr<GraphNode>> GetNodes() const {
std::vector<std::shared_ptr<GraphNode>> all_nodes;
for (auto &layer : layers) {
all_nodes.insert(all_nodes.cend(), layer.cbegin(), layer.cend());
}
return all_nodes;
}
};
/// @brief Collects the action ids of an action map into a sorted vector.
///
/// unordered_map iteration order is unspecified, so the ids are sorted. This gives the
/// same action order for the same set of ids.
/// `inline` because this function is defined in a header. Without it, including the header
/// from more than one translation unit violates the one-definition rule.
/// @param action_map Map from action name to action id.
/// @return All action ids in ascending order (duplicate ids are kept).
inline std::vector<size_t> EncodeActions(const std::unordered_map<std::string, size_t> &action_map) {
  std::vector<size_t> actions;
  actions.reserve(action_map.size());
  for (const auto &entry : action_map)
    actions.push_back(entry.second);
  std::sort(actions.begin(), actions.end());
  return actions;
}
/// @brief Encodes an agent's local situation as the numeric input vector for a Graph.
///
/// Output layout (9 values, in this order):
///   [0] extra_agent_state["previous_action"]
///   [1] extra_agent_state["starting_x"]
///   [2] extra_agent_state["starting_y"]
///   [3] size of the path returned by walle::GetShortestPath from the agent's current
///       position to the starting position (NOTE(review): presumably 0 when no path
///       exists, which is indistinguishable from "at start" - confirm)
///   [4] grid value at the current position
///   [5..8] grid values above, below, left and right of it (0 for invalid positions)
/// The type_options, item_set and agent_set parameters are currently unused.
/// `inline` because this function is defined in a header. Without it, including the header
/// from more than one translation unit violates the one-definition rule.
/// @param agent Must be non-null; it is dereferenced.
/// @param extra_agent_state Must contain "previous_action", "starting_x" and "starting_y".
///        Otherwise unordered_map::at throws std::out_of_range.
/// @return The 9 encoded values.
inline std::vector<double> EncodeState(const cse491::WorldGrid &grid, const cse491::type_options_t & /*type_options*/,
                                       const cse491::item_map_t & /*item_set*/, const cse491::agent_map_t & /*agent_set*/,
                                       const cse491::AgentBase *agent,
                                       const std::unordered_map<std::string, double> &extra_agent_state) {
  const auto current_position = agent->GetPosition();

  // Grid value at `pos`, or 0 when `pos` is not a valid grid position.
  const auto cell_or_zero = [&grid](const auto &pos) -> double {
    return grid.IsValid(pos) ? grid.At(pos) : 0.;
  };

  const double current_state = grid.At(current_position);
  const double above_state = cell_or_zero(current_position.Above());
  const double below_state = cell_or_zero(current_position.Below());
  const double left_state = cell_or_zero(current_position.ToLeft());
  const double right_state = cell_or_zero(current_position.ToRight());

  const double prev_action = extra_agent_state.at("previous_action");
  const double starting_x = extra_agent_state.at("starting_x");
  const double starting_y = extra_agent_state.at("starting_y");

  const auto starting_pos = cse491::GridPosition(starting_x, starting_y);
  const auto path = walle::GetShortestPath(current_position, starting_pos, agent->GetWorld(), *agent);
  const auto distance_from_start = static_cast<double>(path.size());

  return {prev_action,   starting_x,  starting_y,  distance_from_start, current_state,
          above_state,   below_state, left_state,  right_state};
}
} // namespace cowboys