// This file is part of ArboristCore.

/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

/**
   @file leaf.h

   @brief Records sample contents of leaf nodes.

   @author Mark Seligman
 */

#ifndef FOREST_LEAF_H
#define FOREST_LEAF_H

#include "typeparam.h"
#include "util.h"

#include <cstddef>
#include <memory>
#include <vector>

using namespace std;

/**
   @brief Rank and sample-counts associated with sampled rows.

   Both values are packed into a single PackedT:  the rank occupies the
   low-order 'rightBits' bits and the sample count occupies the bits
   above them.

   Client:  quantile inference.
 */
class RankCount {
  // When sampling is not weighted, the sample-count value typically
  // requires four bits or fewer.  Packing therefore accommodates rank
  // values well over 32 bits.
  //
  // NOTE(review): assumes PackedT is an unsigned integer type at least
  // as wide as IndexT -- confirm against typeparam.h.
  PackedT packed; // Packed representation of rank and sample count.

  static unsigned int rightBits; // # bits occupied by rank value.
  static PackedT rankMask; // Mask unpacking the rank value.

public:

  /**
     @brief Invoked at Sampler construction, as needed.

     Sets the width of the rank field from the observation count and
     derives the mask that extracts it.

     @param nObs is the number of observations.
   */
  static void setMasks(IndexT nObs) {
    rightBits = Util::packedWidth(nObs);
    // Shift a PackedT-wide one:  a plain 'int' literal overflows
    // (undefined behavior) once rightBits reaches the width of int.
    rankMask = (static_cast<PackedT>(1) << rightBits) - 1;
  }


  /**
     @brief Invoked at Sampler destruction.

     Zeroes the shared packing parameters.
   */
  static void unsetMasks() {
    rightBits = 0;
    rankMask = 0;
  }
  

  /**
     @brief Packs statistics associated with a response.

     @param rank is the rank of the response value.

     @param sCount is the number of times the observation was sampled.
   */
  void init(IndexT rank,
            IndexT sCount) {
    // Widen before shifting so that high-order sample-count bits are not
    // discarded when IndexT is narrower than PackedT.
    packed = static_cast<PackedT>(rank)
      | (static_cast<PackedT>(sCount) << rightBits);
  }


  /**
     @return rank value held in the low-order bits.
   */
  IndexT getRank() const {
    return packed & rankMask;
  }


  /**
     @return sample count held above the rank bits.
   */
  IndexT getSCount() const {
    return packed >> rightBits;
  }
};


/**
   @brief Leaves are indexed by their numbering within the tree.

   During training, sample maps accumulate in the crescent vectors.
   After training, the fixed per-tree 'extent' and 'index' maps are
   supplied at construction.
 */
struct Leaf {
  const bool thin; // EXIT.

  // Training only:
  vector<IndexT> indexCresc; // Sample indices within leaves.
  vector<IndexT> extentCresc; // Index extent, per leaf.
  
  // Post-training only:  extent, index maps fixed.
  const vector<vector<size_t>> extent; // # sample index entries per leaf, per tree.
  const vector<vector<vector<size_t>>> index; // sample indices per leaf, per tree.

  /**
     @brief Training factory.

     @param nObs is the observation count, used to set static packing
     parameters.

     @param thin initializes the 'thin' member (presumably; factory body
     not visible here -- confirm).

     @return newly-allocated training Leaf.
   */
  static unique_ptr<Leaf> train(IndexT nObs,
                                bool thin);

  
  /**
     @brief Prediction factory.

     @param sampler guides reading of leaf contents.

     @param thin initializes the 'thin' member (presumably -- confirm).

     @param extent gives the number of sample-index entries per leaf,
     per tree.

     @param index gives sample positions per leaf, per tree.

     @return newly-allocated post-training Leaf.
  */
  static unique_ptr<Leaf> predict(const class Sampler* sampler,
                                  bool thin,
                                  vector<vector<size_t>> extent,
                                  vector<vector<vector<size_t>>> index);


  /**
     @brief Training constructor:  crescent structures only.

     Explicit, so that a bare bool is never silently converted to a Leaf.

     @param thin_ initializes the 'thin' member.
   */
  explicit Leaf(bool thin_);


  /**
     @brief Post-training constructor:  fixed maps passed in.

     @param sampler is the sampler associated with the trained forest.

     @param thin_ initializes the 'thin' member.

     @param extent_ is the per-tree, per-leaf sample-index count map.

     @param index_ is the per-tree, per-leaf sample-index map.
   */
  Leaf(const class Sampler* sampler,
       bool thin_,
       vector<vector<size_t>> extent_,
       vector<vector<vector<size_t>>> index_);

  
  /**
     @brief Resets static packing parameters.
   */
  ~Leaf();

  
  /**
     @brief Copies terminal contents, if 'noLeaf' not specified.

     Training caches leaves in order of production.  Depth-first
     leaf numbering requires that the sample maps be reordered.

     NOTE(review): 'noLeaf' presumably corresponds to the 'thin' flag --
     confirm against the implementation.

     @param pretree is the tree whose terminals are being consumed.

     @param smTerminal is the sample map over terminal nodes.
   */
  void consumeTerminals(const class PreTree* pretree,
                        const struct SampleMap& smTerminal);


  /**
     @brief Enumerates the number of samples at each leaf's category.

     'probSample' is the only client.

     @param sampler is the sampler associated with the trained forest.

     @param response is the categorical training response.

     @return 3-d vector category counts, indexed by tree/leaf/ctg.
   */
  vector<vector<vector<size_t>>> countLeafCtg(const class Sampler* sampler,
                                              const class ResponseCtg* response) const;


  /**
     @brief Count samples at each rank, per leaf, per tree:  regression.

     @param sampler is the sampler associated with the trained forest.

     @param row2Rank is the ranked training outcome.

     @return 3-d mapping as described.
   */
  vector<vector<vector<RankCount>>> alignRanks(const class Sampler* sampler,
                                               const vector<IndexT>& row2Rank) const;


  /**
     @brief Reads the post-training extent map; presumably meaningful only
     on post-training instances.

     @param tIdx is the tree index; not bounds-checked.

     @return # leaves at a given tree index.
   */
  size_t getLeafCount(unsigned int tIdx) const {
    return extent[tIdx].size();
  }


  /**
     @return per-leaf index extents accumulated during training.
   */
  const vector<IndexT>& getExtentCresc() const {
    return extentCresc;
  }


  /**
     @return sample indices accumulated during training.
   */
  const vector<IndexT>& getIndexCresc() const {
    return indexCresc;
  }
  
  
  /**
     @param tIdx is the tree index; not bounds-checked.

     @return per-leaf sample-index counts for the tree.
   */
  const vector<size_t>& getExtents(unsigned int tIdx) const {
    return extent[tIdx];
  }


  /**
     @param tIdx is the tree index; not bounds-checked.

     @return per-leaf sample indices for the tree.
   */
  const vector<vector<size_t>>& getIndices(unsigned int tIdx) const {
    return index[tIdx];
  }
};

#endif
