Skip to content

File GPAgentBase.hpp

File List > Agents > GP > GPAgentBase.hpp

Go to the documentation of this file



#pragma once

#include <algorithm>
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>

#include "tinyxml2.h"

#include "../../core/AgentBase.hpp"

namespace cowboys {
  class GPAgentBase : public cse491::AgentBase {
  protected:
    std::unordered_map<std::string, double> extra_state; 
    unsigned int seed = 0;                               
    std::mt19937 rng{seed};                              
    std::uniform_real_distribution<double> uni_dist;     
    std::normal_distribution<double> norm_dist;          

  public:
    GPAgentBase(size_t id, const std::string &name) : AgentBase(id, name) {
      Reset();
    }
    ~GPAgentBase() = default;

    bool Initialize() override { return true; }

    size_t SelectAction(const cse491::WorldGrid &grid, const cse491::type_options_t &type_options,
                        const cse491::item_map_t &item_set, const cse491::agent_map_t &agent_set) override {
      // Update extra state information before action
      if (extra_state["starting_x"] == std::numeric_limits<double>::max()) {
        auto pos = GetPosition();
        extra_state["starting_x"] = pos.GetX();
        extra_state["starting_y"] = pos.GetY();
      }

      size_t action = GetAction(grid, type_options, item_set, agent_set);

      // Update extra state information after action
      extra_state["previous_action"] = action;

      return action;
    }

    virtual size_t GetAction(const cse491::WorldGrid &grid, const cse491::type_options_t &type_options,
                             const cse491::item_map_t &item_set, const cse491::agent_map_t &agent_set) = 0;

    const std::unordered_map<std::string, double> GetExtraState() const { return extra_state; }

    virtual void MutateAgent(double mutation_rate = 0.8) = 0;

    virtual void Copy(const GPAgentBase &other) = 0;

    virtual void PrintAgent(){

    };

    virtual void SerializeGP(tinyxml2::XMLDocument &doc, tinyxml2::XMLElement *parentElem, double fitness = -1) = 0;

    virtual std::string Export() { return ""; }

    virtual void Reset(bool /*hard*/ = false) {
      extra_state["previous_action"] = 0;
      extra_state["starting_x"] = std::numeric_limits<double>::max();
      extra_state["starting_y"] = std::numeric_limits<double>::max();
    };

    //    virtual void crossover(const GPAgentBase &other) {};
    virtual void Import(const std::string &genotype) = 0;

    // -- Random Number Generation --

    void SetSeed(unsigned int seed) {
      this->seed = seed;
      rng.seed(seed);
    }

    unsigned int GetSeed() const { return seed; }

    double GetRandom() { return uni_dist(rng); }

    double GetRandom(double max) { return GetRandom() * max; }

    double GetRandom(double min, double max) {
      assert(max > min);
      return min + GetRandom(max - min);
    }

    size_t GetRandomULL(size_t max) { return static_cast<size_t>(GetRandom(max)); }

    double GetRandomNormal() { return norm_dist(rng); }

    double GetRandomNormal(double mean, double sd = 1.0) {
      assert(sd > 0);
      return mean + norm_dist(rng) * sd;
    }
  };

} // End of namespace cowboys