Skip to content

File WorldBase.hpp

File List > core > WorldBase.hpp

Go to the documentation of this file



#pragma once

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

#include "../DataCollection/AgentReciever.hpp"
#include "AgentBase.hpp"
#include "Data.hpp"
#include "DataCollection/DataManager.hpp"
#include "Interfaces/NetWorth/client/ClientInterface.hpp"
#include "Interfaces/NetWorth/client/ClientManager.hpp"
#include "Interfaces/NetWorth/client/ControlledAgent.hpp"
#include "Interfaces/NetWorth/server/ServerManager.hpp"
#include "ItemBase.hpp"
#include "WorldGrid.hpp"

// Forward declaration of worldlang::ProgramExecutor. Only the name is needed here,
// because WorldBase declares it as a `friend` (so it can reach protected state).
namespace worldlang {
class ProgramExecutor;
}

namespace cse491 {
class DataReceiver;

/**
 * @brief Abstract base class for a world containing grids, agents, and items.
 *
 * Derived worlds must implement DoAction(). They may override the other virtual hooks
 * (ConfigAgent, UpdateWorld, the Run* loops, grid queries, IsTraversable) to customize
 * behavior. Agents and items are owned by this world and keyed by their entity ID.
 */
class WorldBase {
public:
  /// Sentinel returned by GetItemID() / GetAgentID() when no entity matches.
  static constexpr size_t npos = static_cast<size_t>(-1);
  /// Server-side network manager; assigned by RunServer(). WorldBase never deletes it.
  netWorth::ServerManager *server_manager = nullptr;
  /// Client-side network manager; assigned by RunClient(). WorldBase never deletes it.
  netWorth::ClientManager *client_manager = nullptr;

  /// Const agent-configuration hook.
  /// NOTE(review): AddAgent() calls the non-const protected overload below, not this one.
  virtual void ConfigAgent(AgentBase & /* agent */) const {}

protected:
  /// All grids in this world, keyed by grid ID.
  std::unordered_map<size_t, WorldGrid> grids;

  WorldGrid &main_grid;         ///< Alias for grids[0], bound in the constructor.
  type_options_t type_options;  ///< Registered cell types; index 0 is "Unknown".

  item_map_t item_map;        ///< Items owned by this world, keyed by entity ID.
  agent_map_t agent_map;      ///< Agents owned by this world, keyed by entity ID.
  size_t last_entity_id = 0;  ///< Most recently handed-out entity ID (0 = none yet).

  bool run_over = false;      ///< Setting this to true ends the Run*() loops.
  bool world_running = true;  ///< When false, RunClient()/RunServer() skip updates.

  std::string action;  ///< NOTE(review): not referenced anywhere in this file.
  /// NOTE(review): not referenced anywhere in this file.
  std::shared_ptr<DataCollection::AgentReceiver> agent_receiver;

  unsigned int seed;                          ///< Seed actually used for random_gen.
  std::mt19937 random_gen;                    ///< Random engine for this world.
  std::uniform_real_distribution<> uni_dist;  ///< Uniform over [0, 1).
  std::normal_distribution<> norm_dist;       ///< Standard normal (mean 0, sd 1).

  /// Hand out a new entity ID (one more than the last one issued).
  size_t NextEntityID() { return ++last_entity_id; }

  /// Hook called by AddAgent() to configure a freshly built agent before Initialize().
  virtual void ConfigAgent(AgentBase & /* agent */) {}

  /**
   * @brief Register a new cell type.
   * @param name   Unique name of the cell type.
   * @param desc   Human-readable description.
   * @param symbol Display character ('\0' if none).
   * @return The new type's ID (its index in type_options).
   */
  size_t AddCellType(const std::string &name, const std::string &desc = "", char symbol = '\0') {
    type_options.push_back(CellType{name, desc, symbol});
    return type_options.size() - 1;
  }

public:
  /**
   * @brief Build an empty world.
   * @param seed Seed for the random engine; 0 means "choose a random seed". Either way,
   *             GetSeed() reports the seed that was actually used.
   */
  WorldBase(unsigned int seed = 0) : grids(), main_grid(grids[0]), seed(seed) {
    // The first cell type (ID 0) should be reserved for errors or empty spots in a grid.
    AddCellType("Unknown", "This is an invalid cell type and should not be reachable.");

    // Initialize the random number generator. The constructor parameter shadows the
    // member, so `this->` is required for the chosen seed to be recorded in the member.
    if (this->seed == 0) {
      std::random_device rd;  // An expensive "true" random number generator.
      this->seed = rd();      // Change the seed to a random value.
    }
    random_gen.seed(this->seed);
  }

  virtual ~WorldBase() = default;

  /// Remove all items and agents, restart entity IDs at 1, and clear run_over.
  virtual void Reset() {
    item_map.clear();
    agent_map.clear();
    last_entity_id = 0;
    run_over = false;
  }

  // -- Accessors --

  [[nodiscard]] size_t GetNumItems() const { return item_map.size(); }

  [[nodiscard]] size_t GetNumAgents() const { return agent_map.size(); }

  [[nodiscard]] bool HasItem(size_t id) const { return item_map.count(id); }

  [[nodiscard]] bool HasAgent(size_t id) const { return agent_map.count(id); }

  /// Return the item with the given ID. The ID must exist; `.at()` throws instead of
  /// inserting (and dereferencing) a null entry when asserts are compiled out.
  [[nodiscard]] ItemBase &GetItem(size_t id) {
    assert(HasItem(id));
    return *item_map.at(id);
  }

  /// Return the agent with the given ID. The ID must exist; `.at()` throws instead of
  /// inserting (and dereferencing) a null entry when asserts are compiled out.
  [[nodiscard]] AgentBase &GetAgent(size_t id) {
    assert(HasAgent(id));
    return *agent_map.at(id);
  }

  /// Return the ID of the first item (in map order) named `name`, or npos if none.
  [[nodiscard]] size_t GetItemID(const std::string &name) const {
    for (const auto &[id, ptr] : item_map) {
      if (ptr->GetName() == name) return id;
    }
    return npos;
  }

  /// Return the ID of the first agent (in map order) named `name`, or npos if none.
  [[nodiscard]] size_t GetAgentID(const std::string &name) const {
    for (const auto &[id, ptr] : agent_map) {
      if (ptr->GetName() == name) return id;
    }
    return npos;
  }

  /// Main grid (grid 0).
  virtual WorldGrid &GetGrid() { return main_grid; }
  /// Grid by ID; a missing ID creates an empty grid (std::unordered_map::operator[]).
  virtual WorldGrid &GetGrid(size_t grid_id) { return grids[grid_id]; }

  virtual const WorldGrid &GetGrid() const { return main_grid; }
  /// Grid by ID; throws std::out_of_range if the ID does not exist.
  virtual const WorldGrid &GetGrid(size_t grid_id) const { return grids.at(grid_id); }

  virtual bool GetRunOver() const { return run_over; }

  // -- Random Number Generation --

  /// Seed used to initialize this world's random engine.
  unsigned int GetSeed() const { return seed; }

  /// Uniform random value in [0, 1).
  double GetRandom() { return uni_dist(random_gen); }

  /// Uniform random value in [0, max).
  double GetRandom(double max) { return GetRandom() * max; }

  /// Uniform random value in [min, max); requires max > min.
  double GetRandom(double min, double max) {
    assert(max > min);
    return min + GetRandom(max - min);
  }

  /// Standard-normal random value.
  double GetRandomNormal() { return norm_dist(random_gen); }

  /// Normal random value with the given mean and (positive) standard deviation.
  double GetRandomNormal(double mean, double sd = 1.0) {
    assert(sd > 0);
    return mean + norm_dist(random_gen) * sd;
  }

  // -- Agent Management --

  /**
   * @brief Take ownership of an already-built agent, keyed by its own ID.
   *
   * The agent is attached to this world and Initialize()d. ConfigAgent() is NOT called.
   * An existing agent with the same ID is replaced.
   * NOTE(review): agent_map is not synchronized here; callers on other threads must
   * serialize access themselves.
   * @return Reference to the stored agent.
   */
  AgentBase &AddConfiguredAgent(std::unique_ptr<AgentBase> agent_ptr) {
    assert(agent_ptr);  // agent_ptr must not be null.
    agent_ptr->SetWorld(*this);
    if (agent_ptr->Initialize() == false) {
      std::cerr << "Failed to initialize agent '" << agent_ptr->GetName() << "'." << std::endl;
    }
    // Store first, then reference the stored object. Reading the map slot before the
    // move would dereference a freshly default-inserted null pointer.
    auto &slot = agent_map[agent_ptr->GetID()];
    slot = std::move(agent_ptr);
    return *slot;
  }

  /**
   * @brief Build a new AGENT_T with a fresh ID, then configure, initialize, and store it.
   * @param agent_name Name given to the agent.
   * @param properties Forwarded to the agent's SetProperties().
   * @return Reference to the stored agent.
   */
  template <typename AGENT_T, typename... PROPERTY_Ts>
  AgentBase &AddAgent(std::string agent_name = "None", PROPERTY_Ts... properties) {
    const size_t agent_id = NextEntityID();
    auto agent_ptr = std::make_unique<AGENT_T>(agent_id, agent_name);
    agent_ptr->SetWorld(*this);
    agent_ptr->SetProperties(std::forward<PROPERTY_Ts>(properties)...);
    ConfigAgent(*agent_ptr);
    if (agent_ptr->Initialize() == false) {
      std::cerr << "Failed to initialize agent '" << agent_name << "'." << std::endl;
    }
    auto &slot = agent_map[agent_id];
    slot = std::move(agent_ptr);
    return *slot;
  }

  /// Take ownership of an item that already has a non-zero ID; returns the stored item.
  ItemBase &AddItem(std::unique_ptr<ItemBase> item_ptr) {
    assert(item_ptr);                // item_ptr must not be null.
    assert(item_ptr->GetID() != 0);  // item_ptr must have had a non-zero ID assigned.
    item_ptr->SetWorld(*this);
    const size_t item_id = item_ptr->GetID();
    auto &slot = item_map[item_id];
    slot = std::move(item_ptr);
    return *slot;
  }

  /// Build a new ITEM_T with a fresh ID and the given properties, then store it.
  template <typename ITEM_T = ItemBase, typename... PROPERTY_Ts>
  ItemBase &AddItem(std::string item_name = "None", PROPERTY_Ts... properties) {
    auto item_ptr = std::make_unique<ITEM_T>(NextEntityID(), item_name);
    item_ptr->SetProperties(std::forward<PROPERTY_Ts>(properties)...);
    return AddItem(std::move(item_ptr));
  }

  /// Remove the agent with this ID (no-op if absent).
  WorldBase &RemoveAgent(size_t agent_id) {
    agent_map.erase(agent_id);
    return *this;
  }

  /// Remove the item with this ID (no-op if absent).
  WorldBase &RemoveItem(size_t item_id) {
    item_map.erase(item_id);
    return *this;
  }

  /// Remove the first agent with this name (no-op if none). The name "Interface" is not allowed.
  WorldBase &RemoveAgent(std::string agent_name = "None") {
    assert(agent_name != "Interface");  // We are not allowed to remove interfaces.
    return RemoveAgent(GetAgentID(agent_name));
  }

  /// Remove the first item with this name (no-op if none).
  WorldBase &RemoveItem(std::string item_name) { return RemoveItem(GetItemID(item_name)); }

  /// Place an existing item at `pos` on grid `grid_id`. The item ID must exist.
  WorldBase &AddItemToGrid(size_t item_id, GridPosition pos, size_t grid_id = 0) {
    assert(HasItem(item_id));
    item_map.at(item_id)->SetPosition(pos, grid_id);
    return *this;
  }

  // -- Action Management --

  /// Carry out `action_id` for `agent`; the result is stored via SetActionResult().
  virtual int DoAction(AgentBase &agent, size_t action_id) = 0;

  /// Give every agent one turn: select an action, record it, perform it, store the result.
  virtual void RunAgents() {
    for (auto &[id, agent_ptr] : agent_map) {
      size_t action_id = agent_ptr->SelectAction(main_grid, type_options, item_map, agent_map);
      agent_ptr->storeActionMap(agent_ptr->GetName());
      int result = DoAction(*agent_ptr, action_id);
      agent_ptr->SetActionResult(result);
    }
  }

  /// Client-side turn: run every agent, then replace the non-client agents with the
  /// server's agent set if the manager has one pending. Requires client_manager.
  virtual void RunClientAgents() {
    assert(client_manager);
    for (auto &[id, agent_ptr] : agent_map) {
      size_t action_id = agent_ptr->SelectAction(main_grid, type_options, item_map, agent_map);
      agent_ptr->storeActionMap(agent_ptr->GetName());
      int result = DoAction(*agent_ptr, action_id);
      agent_ptr->SetActionResult(result);
    }

    // Deserialize agents only if the payload starts with the agent_set header.
    std::string data = client_manager->getSerializedAgents();
    if (data.substr(0, 18) == ":::START agent_set") {
      std::istringstream is(data);
      DeserializeAgentSet(is, client_manager);
    }
  }

  /**
   * @brief Server-side turn: run every agent and relay its action to clients.
   *
   * Before each agent acts, blocks until the server manager reports connected agents and
   * the world is running. An action ID of 9999 marks the agent's client as disconnected.
   * Such agents still have DoAction() called, and are then removed after the loop. If any
   * were removed, the updated agent set is broadcast. Requires server_manager.
   */
  virtual void RunServerAgents() {
    constexpr size_t kClientDisconnectedAction = 9999;
    std::set<size_t> to_delete;

    for (auto &[id, agent_ptr] : agent_map) {
      // Wait until clients have connected (and the world is not paused) to run.
      // NOTE(review): world_running is a plain bool polled here, presumably while
      // another thread calls SetWorldRunning(); consider std::atomic<bool> or a condvar.
      while (!server_manager->hasAgentsPresent() || !world_running) {
        std::this_thread::yield();  // Avoid burning a whole core while spinning.
      }

      // Select action and send to client.
      size_t action_id = agent_ptr->SelectAction(main_grid, type_options, item_map, agent_map);
      server_manager->writeToActionMap(id, action_id);
      agent_ptr->storeActionMap(agent_ptr->GetName());
      int result = DoAction(*agent_ptr, action_id);
      agent_ptr->SetActionResult(result);

      // Mark agent for deletion if its client disconnected.
      if (action_id == kClientDisconnectedAction) to_delete.insert(id);
    }

    // Delete agents outside the loop so iteration is not invalidated.
    for (size_t id : to_delete) {
      RemoveAgent(id);
    }

    // Send updates to clients for deleted agents.
    if (!to_delete.empty()) {
      std::ostringstream os;
      SerializeAgentSet(os);
      std::string data = os.str();
      server_manager->setSerializedAgents(data);
      server_manager->setNewAgent(true);
      server_manager->sendGameUpdates();
    }
  }

  /// Record each agent's name, position, and last action result in the DataManager.
  void CollectData() {
    for (const auto &[id, agent_ptr] : agent_map) {
      DataCollection::DataManager::GetInstance().GetAgentReceiver().StoreData(
          agent_ptr->GetName(), agent_ptr->GetPosition(), agent_ptr->GetActionResult());
    }
  }

  /// Hook called once per step after agents act; does nothing by default.
  virtual void UpdateWorld() {}

  /// Local main loop: step until run_over becomes true.
  virtual void Run() {
    run_over = false;
    while (!run_over) {
      RunAgents();
      CollectData();
      UpdateWorld();
    }
  }

  /// Client main loop: store `manager`, then step (while world_running) until run_over.
  virtual void RunClient(netWorth::ClientManager *manager) {
    run_over = false;
    client_manager = manager;
    while (!run_over) {
      if (world_running) {
        RunClientAgents();
        CollectData();
        UpdateWorld();
      }
    }
  }

  /// Server main loop: store `manager`, then step (while world_running) until run_over.
  virtual void RunServer(netWorth::ServerManager *manager) {
    run_over = false;
    server_manager = manager;
    while (!run_over) {
      if (world_running) {
        RunServerAgents();
        CollectData();
        UpdateWorld();
      }
    }
  }

  /// Pause (false) or resume (true) the client/server loops.
  virtual void SetWorldRunning(bool running) { world_running = running; }

  // CellType management.

  /// Return a const vector of all of the possible cell types.
  [[nodiscard]] const type_options_t &GetCellTypes() const { return type_options; }

  /// ID of the cell type named `name`, or 0 ("Unknown") if there is none.
  [[nodiscard]] size_t GetCellTypeID(const std::string &name) const {
    for (size_t i = 1; i < type_options.size(); ++i) {
      if (type_options[i].name == name) return i;
    }
    return 0;
  }

  /// Name of cell type `id`; out-of-range IDs map to the "Unknown" type.
  [[nodiscard]] const std::string &GetCellTypeName(size_t id) const {
    if (id >= type_options.size()) return type_options[0].name;
    return type_options[id].name;
  }

  /// Symbol of cell type `id`; out-of-range IDs map to the "Unknown" type.
  [[nodiscard]] char GetCellTypeSymbol(size_t id) const {
    if (id >= type_options.size()) return type_options[0].symbol;
    return type_options[id].symbol;
  }

  // -- Grid Analysis Helpers --

  /// IDs of items on grid `grid_id` located exactly at `pos`.
  [[nodiscard]] virtual std::vector<size_t> FindItemsAt(GridPosition pos,
                                                        size_t grid_id = 0) const {
    std::vector<size_t> item_ids;
    for (const auto &[id, item_ptr] : item_map) {
      if (item_ptr->IsOnGrid(grid_id) && item_ptr->GetPosition() == pos) item_ids.push_back(id);
    }
    return item_ids;
  }

  /// IDs of agents on grid `grid_id` located exactly at `pos`.
  [[nodiscard]] virtual std::vector<size_t> FindAgentsAt(GridPosition pos,
                                                         size_t grid_id = 0) const {
    std::vector<size_t> agent_ids;
    for (const auto &[id, agent_ptr] : agent_map) {
      if (agent_ptr->IsOnGrid(grid_id) && agent_ptr->GetPosition() == pos) agent_ids.push_back(id);
    }
    return agent_ids;
  }

  /// IDs of items on grid `grid_id` for which GetPosition().IsNear(pos, dist) holds.
  [[nodiscard]] virtual std::vector<size_t> FindItemsNear(GridPosition pos, double dist = 1.0,
                                                          size_t grid_id = 0) const {
    std::vector<size_t> item_ids;
    for (const auto &[id, item_ptr] : item_map) {
      if (item_ptr->IsOnGrid(grid_id) && item_ptr->GetPosition().IsNear(pos, dist)) {
        item_ids.push_back(id);
      }
    }
    return item_ids;
  }

  /// IDs of agents on grid `grid_id` for which GetPosition().IsNear(pos, dist) holds.
  [[nodiscard]] virtual std::vector<size_t> FindAgentsNear(GridPosition pos, double dist = 1.0,
                                                           size_t grid_id = 0) const {
    std::vector<size_t> agent_ids;
    for (const auto &[id, agent_ptr] : agent_map) {
      if (agent_ptr->IsOnGrid(grid_id) && agent_ptr->GetPosition().IsNear(pos, dist)) {
        agent_ids.push_back(id);
      }
    }
    return agent_ids;
  }

  /// Whether `agent` may move onto `pos`; the base world allows everything.
  [[nodiscard]] virtual bool IsTraversable(const AgentBase & /*agent*/,
                                           cse491::GridPosition /*pos*/) const {
    return true;
  }

  // -- Network Serialization and Deserialization --

  /// Write ":::START agent_set", the agent count, each agent, then ":::END agent_set".
  void SerializeAgentSet(std::ostream &os) {
    os << ":::START agent_set\n";
    SerializeValue(os, agent_map.size());

    for (const auto &[agent_id, agent_ptr] : agent_map) {
      SerializeValue(os, *agent_ptr);
    }
    os << ":::END agent_set\n";
  }

  /**
   * @brief Replace every agent except this client's own with those read from `is`.
   *
   * Each read agent becomes a netWorth::ControlledAgent bound to `manager`. The agent
   * whose ID equals the client's ID is skipped. last_entity_id is rebuilt from the IDs
   * read. Prints to std::cerr and returns early if the start/end markers are missing.
   */
  void DeserializeAgentSet(std::istream &is, netWorth::ClientManager *manager) {
    // Find beginning of agent_set serialization.
    std::string read;
    std::getline(is, read, '\n');
    if (read != ":::START agent_set") {
      std::cerr << "Could not find start of agent_set serialization" << std::endl;
      return;
    }

    assert(manager);
    size_t client_id = manager->getClientID();

    // Remove all agents other than the interface (collect first; erase after iterating).
    std::vector<size_t> to_delete;
    for (const auto &[agent_id, agent_ptr] : agent_map) {
      if (agent_id != client_id) to_delete.push_back(agent_id);
    }

    for (size_t agent_id : to_delete) {
      RemoveAgent(agent_id);
    }

    // Reset last_entity_id; start from the beginning.
    last_entity_id = 0;

    // Load the number of agents saved.
    // NOTE(review): SerializeAgentSet() writes agent_map.size() here, so this is an agent
    // count even though it is named (and later used) like an entity ID.
    size_t server_last_entity_id = DeserializeAs<size_t>(is);

    // Client ID is NOT in the agent map yet if ID == 0; append to end of set.
    if (client_id == 0) client_id = server_last_entity_id;

    // Load back in all agents.
    for (size_t i = 0; i < server_last_entity_id; i++) {
      // First, check whether we've hit the end of the agent_set. Because we are looking
      // at last entity id (and not total size), we may have gaps in our agent_set.
      auto tmp_pos = is.tellg();
      std::getline(is, read, '\n');
      if (read == ":::END agent_set") {
        if (last_entity_id < client_id) last_entity_id = client_id;
        return;
      }
      is.seekg(tmp_pos);  // Not the end marker: rewind so the agent can be read.

      auto agent_ptr = std::make_unique<netWorth::ControlledAgent>(0, "temp");
      DeserializeValue(is, *agent_ptr);
      agent_ptr->SetProperty("manager", manager);

      if (agent_ptr->GetID() >= last_entity_id) last_entity_id = agent_ptr->GetID();

      // If this agent is the client interface, skip it (we already have it).
      if (agent_ptr->GetID() == client_id) continue;

      AddConfiguredAgent(std::move(agent_ptr));
    }

    // Find end of agent_set deserialization.
    std::getline(is, read, '\n');
    if (read != ":::END agent_set") {
      std::cerr << "Could not find end of agent_set serialization" << std::endl;
    }
  }

  /// Write ":::START item_set", the item count, each item, then ":::END item_set".
  void SerializeItemSet(std::ostream &os) {
    os << ":::START item_set\n";
    SerializeValue(os, item_map.size());

    for (const auto &[item_id, item_ptr] : item_map) {
      item_ptr->Serialize(os);
    }
    os << ":::END item_set\n";
  }

  /// Read an item_set (as written by SerializeItemSet) and add each item to this world.
  /// Prints to std::cerr and returns early if the start/end markers are missing.
  void DeserializeItemSet(std::istream &is) {
    // Find beginning of item_set serialization.
    std::string read;
    std::getline(is, read, '\n');
    if (read != ":::START item_set") {
      std::cerr << "Could not find start of item_set serialization" << std::endl;
      return;
    }

    // How many items? Initialized so a failed read cannot leave it indeterminate.
    size_t size = 0;
    DeserializeValue(is, size);

    // Read each item; the placeholder ID is overwritten by the deserialized one.
    for (size_t i = 0; i < size; i++) {
      auto item = std::make_unique<ItemBase>(agent_map.size() + i, "");
      DeserializeValue(is, *item);
      AddItem(std::move(item));
    }

    // Find end of item_set serialization.
    std::getline(is, read, '\n');
    if (read != ":::END item_set") {
      std::cerr << "Could not find end of item_set serialization" << std::endl;
    }
  }

  /// Serialize the main grid, then the agent set, then the item set.
  void Serialize(std::ostream &os) {
    main_grid.Serialize(os);
    SerializeAgentSet(os);
    SerializeItemSet(os);
  }

  /// Inverse of Serialize(): main grid, agent set (bound to `manager`), then item set.
  void Deserialize(std::istream &is, netWorth::ClientManager *manager) {
    main_grid.Deserialize(is);
    DeserializeAgentSet(is, manager);
    DeserializeItemSet(is);
  }

  // Needs access to most things here so this is easiest way to do so.
  friend worldlang::ProgramExecutor;
};

}  // End of namespace cse491