/***************************************************************************
 *      GModelDataMultiplicative.cpp - Multiplicative data model class     *
 * ----------------------------------------------------------------------- *
 *  copyright (C) 2025 by Juergen Knoedlseder                              *
 * ----------------------------------------------------------------------- *
 *                                                                         *
 *  This program is free software: you can redistribute it and/or modify   *
 *  it under the terms of the GNU General Public License as published by   *
 *  the Free Software Foundation, either version 3 of the License, or      *
 *  (at your option) any later version.                                    *
 *                                                                         *
 *  This program is distributed in the hope that it will be useful,        *
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of         *
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the          *
 *  GNU General Public License for more details.                           *
 *                                                                         *
 *  You should have received a copy of the GNU General Public License      *
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.  *
 *                                                                         *
 ***************************************************************************/
/**
 * @file GModelDataMultiplicative.cpp
 * @brief Multiplicative data model class implementation
 * @author Juergen Knoedlseder
 */

/* __ Includes ___________________________________________________________ */
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include "GException.hpp"
#include "GTools.hpp"
#include "GRan.hpp"
#include "GModelDataMultiplicative.hpp"
#include "GModelRegistry.hpp"
#include "GObservation.hpp"
#include "GMatrixSparse.hpp"

/* __ Constants __________________________________________________________ */

/* __ Globals ____________________________________________________________ */
// Prototype ("seed") instance of the multiplicative data model that is handed
// to a GModelRegistry during static initialisation. Presumably this registers
// the "MultiplicativeData" model type so that GModelRegistry::alloc() can
// create such models by type name - confirm in GModelRegistry.
const GModelDataMultiplicative g_data_multi_seed;
const GModelRegistry           g_data_registry(&g_data_multi_seed);

/* __ Method name definitions ____________________________________________ */
// Method names used as origin strings in exceptions thrown by this class.
// They mirror the method signatures (argument types only).
#define G_EVAL               "GModelDataMultiplicative::eval(GObservation&, "\
                                                       "GMatrixSparse*, int&)"
#define G_NPRED          "GModelDataMultiplicative::npred(GEnergy&, GTime&, "\
                                                             "GObservation&)"
#define G_MC             "GModelDataMultiplicative::mc(GObservation&, GRan&)"
#define G_READ                 "GModelDataMultiplicative::read(GXmlElement&)"
#define G_WRITE               "GModelDataMultiplicative::write(GXmlElement&)"
#define G_COMPONENT_INDEX         "GModelDataMultiplicative::component(int&)"
#define G_COMPONENT_NAME  "GModelDataMultiplicative::component(std::string&)"
#define G_APPEND             "GModelDataMultiplicative::append(GModelData&)"

/* __ Macros _____________________________________________________________ */

/* __ Coding definitions _________________________________________________ */
//#define G_PREPEND_MODEL_NAME     //!< Prepend model name to parameter names

/* __ Debug definitions __________________________________________________ */


/*==========================================================================
 =                                                                         =
 =                        Constructors/destructors                         =
 =                                                                         =
 ==========================================================================*/

/***********************************************************************//**
 * @brief Void constructor
 *
 * Constructs an empty multiplicative data model that holds no components.
 ***************************************************************************/
GModelDataMultiplicative::GModelDataMultiplicative(void) : GModelData()
{
    // Bring derived class members into their initial (empty) state
    init_members();
}


/***********************************************************************//**
 * @brief XML constructor
 *
 * @param[in] xml XML element containing data model information.
 *
 * Builds a multiplicative data model from the content of an XML element.
 * The expected layout of the element is described in the read() method.
 ***************************************************************************/
GModelDataMultiplicative::GModelDataMultiplicative(const GXmlElement& xml) :
                          GModelData()
{
    // Start from an empty model, then populate it from the XML element
    init_members();
    read(xml);
}


/***********************************************************************//**
 * @brief Copy constructor
 *
 * @param[in] model Multiplicative data model.
 *
 * Creates a deep copy: every component of @p model is cloned.
 ***************************************************************************/
GModelDataMultiplicative::GModelDataMultiplicative(const GModelDataMultiplicative& model) :
                          GModelData(model)
{
    // Reset derived class members, then clone the components of model
    init_members();
    copy_members(model);
}


/***********************************************************************//**
 * @brief Destructor
 *
 * Deletes all model components owned by this object.
 ***************************************************************************/
GModelDataMultiplicative::~GModelDataMultiplicative(void)
{
    // Release the owned model components
    free_members();
}


/*==========================================================================
 =                                                                         =
 =                                Operators                                =
 =                                                                         =
 ==========================================================================*/

/***********************************************************************//**
 * @brief Assignment operator
 *
 * @param[in] model Multiplicative data model.
 * @return Multiplicative data model.
 *
 * Replaces the content of this object by a deep copy of @p model.
 * Self-assignment leaves the object unchanged.
 ***************************************************************************/
GModelDataMultiplicative& GModelDataMultiplicative::operator=(const GModelDataMultiplicative& model)
{
    // Nothing to do when assigning an object to itself
    if (this == &model) {
        return *this;
    }

    // Assign the base class part first
    this->GModelData::operator=(model);

    // Drop the current components and clone those of model
    free_members();
    init_members();
    copy_members(model);

    return *this;
}


/*==========================================================================
 =                                                                         =
 =                              Public methods                             =
 =                                                                         =
 ==========================================================================*/

/***********************************************************************//**
 * @brief Clear multiplicative data model
 *
 * Deletes all components and re-initialises both the base class and the
 * derived class members, leaving an empty multiplicative data model.
 ***************************************************************************/
void GModelDataMultiplicative::clear(void)
{
    // Tear down: derived class members first, base class members second
    free_members();
    this->GModelData::free_members();

    // Rebuild: base class members first, derived class members second
    this->GModelData::init_members();
    init_members();
}


/***********************************************************************//**
 * @brief Clone multiplicative data model
 *
 * @return Pointer to deep copy of multiplicative data model (owned by the
 *         caller).
 ***************************************************************************/
GModelDataMultiplicative* GModelDataMultiplicative::clone(void) const
{
    // The copy constructor clones every component, giving a deep copy
    GModelDataMultiplicative* copy = new GModelDataMultiplicative(*this);
    return copy;
}


/***********************************************************************//**
 * @brief Return if model is constant
 *
 * @return True if model is constant.
 *
 * The model is constant if every component is constant. A model without
 * components is considered constant.
 ***************************************************************************/
bool GModelDataMultiplicative::is_constant(void) const
{
    // A single non-constant component makes the whole product non-constant
    for (int k = 0; k < m_models.size(); ++k) {
        if (!m_models[k]->is_constant()) {
            return false;
        }
    }

    // All components (if any) are constant
    return true;
}


/***********************************************************************//**
 * @brief Return model values and gradients
 *
 * @param[in] event Event.
 * @param[in] obs Observation.
 * @param[in] gradients Compute gradients?
 * @return Model value.
 *
 * Evaluates
 *
 * \f[
 *    M = \prod_{\rm i=0}^{n-1} M_{\rm i}({\rm event}, {\rm obs})
 * \f]
 *
 * where \f$M_{\rm i}\f$ is the i-th model component and \f$n\f$ the number
 * of model components. If the model has no components, zero is returned.
 *
 * If the @p gradients flag is true, each component is evaluated with
 * gradients, so that the parameters of component i hold
 * \f$\partial M_{\rm i}/\partial P_{\rm ij}\f$. These gradients are then
 * scaled to the gradients of the product using
 *
 * \f[
 *    \frac{\partial M}{\partial P_{\rm ij}} =
 *    \frac{\partial M_{\rm i}}{\partial P_{\rm ij}}
 *    \prod_{\rm k\neq \rm i} M_{\rm k}
 * \f]
 *
 * where \f$P_{\rm ij}\f$ is the j-th parameter of the i-th component.
 *
 * Each component is evaluated exactly once. The stored component values
 * are reused for the scaling factors instead of re-evaluating the other
 * components for every component.
 ***************************************************************************/
double GModelDataMultiplicative::eval(const GEvent&       event,
                                      const GObservation& obs,
                                      const bool&         gradients) const
{
    // Get number of model components
    int ncomponents = m_models.size();

    // Evaluate every model component once, in order, and keep the values.
    // If gradients are requested, each component also sets the gradients
    // of its own parameters during this call.
    std::vector<double> component_values(ncomponents, 0.0);
    for (int i = 0; i < ncomponents; ++i) {
        component_values[i] = m_models[i]->eval(event, obs, gradients);
    }

    // Compute the product of all component values (zero if none)
    double value = 0.0;
    if (ncomponents > 0) {
        value = component_values[0];
        for (int i = 1; i < ncomponents; ++i) {
            value *= component_values[i];
        }
    } // endif: there were model components

    // Convert component gradients into gradients of the product
    if (gradients) {

        // Loop over model components
        for (int i = 0; i < ncomponents; ++i) {

            // Product of the values of all other components
            double factor = 1.0;
            for (int k = 0; k < ncomponents; ++k) {
                if (k != i) {
                    factor *= component_values[k];
                }
            }

            // Scale every parameter gradient of component i
            for (int ipar = 0; ipar < m_models[i]->size(); ++ipar) {
                GModelPar& par = m_models[i]->operator[](ipar);
                par.gradient(par.gradient()*factor);
            }

        } // endfor: loop over model components

    } // endif: gradients were requested

    // Compile option: Check for NaN/Inf
    #if defined(G_NAN_CHECK)
    if (gammalib::is_notanumber(value) || gammalib::is_infinite(value)) {
        std::cout << "*** ERROR: GModelDataMultiplicative::eval():";
        std::cout << " NaN/Inf encountered";
        std::cout << " (value=" << value;
        std::cout << ")" << std::endl;
    }
    #endif

    // Return
    return value;
}


/***********************************************************************//**
 * @brief Return model values and gradients
 *
 * @param[in] obs Observation.
 * @param[out] gradients Pointer to matrix of gradients (may be NULL).
 * @param[in] offset Column index of first matrix gradient (not used).
 * @return Model values.
 *
 * @exception GException::invalid_argument
 *            Gradient matrix has wrong number of rows or columns
 *
 * Evaluates the model values and parameter gradients for all events in an
 * observation. Gradients are only computed if the @p gradients pointer is
 * not NULL and the model has at least one parameter.
 *
 * The matrix of gradients is a sparse matrix where the number of rows
 * corresponds to the number of events and the number of columns corresponds
 * to the number of model parameters (see GObservation::model() method).
 * The parameters of component i occupy consecutive columns that start
 * right after the columns of components 0,...,i-1.
 *
 * The method works in two passes. First, every component is evaluated once
 * and is passed the gradient matrix together with the column offset of
 * its own parameters. Second, the component values are multiplied to form
 * the model values. The gradient columns of component i are multiplied by
 * the product of the values of all other components (chain rule for a
 * product). NOTE(review): this relies on the GVector products being
 * element-wise, i.e. per event - confirm in GVector/GMatrixSparse.
 *
 * If the model has no components, the values vector is returned as
 * allocated by GVector(nevents) (presumably all zeros).
 *
 * An exception is thrown if the dimension of the @p gradients matrix is not
 * compatible with the model and the observations.
 ***************************************************************************/
GVector GModelDataMultiplicative::eval(const GObservation& obs,
                                       GMatrixSparse*      gradients,
                                       const int&          offset) const
{
    // Get number of model parameters and number of events
    int npars   = size();
    int nevents = obs.events()->size();

    // Initialise gradients flag: gradients are only computed if a matrix
    // was provided and the model actually has parameters
    bool grad = ((gradients != NULL) && (npars > 0));

    // Check matrix consistency
    if (grad) {
        if (gradients->columns() != npars) {
            std::string msg = "Number of "+gammalib::str(gradients->columns())+
                              " columns in gradient matrix differs from number "
                              "of "+gammalib::str(npars)+" parameters "
                              "in model. Please provide a compatible gradient "
                              "matrix.";
            throw GException::invalid_argument(G_EVAL, msg);
        }
        if (gradients->rows() != nevents) {
            std::string msg = "Number of "+gammalib::str(gradients->rows())+
                              " rows in gradient matrix differs from number "
                              "of "+gammalib::str(nevents)+" events in "
                              "observation. Please provide a compatible "
                              "gradient matrix.";
            throw GException::invalid_argument(G_EVAL, msg);
        }
    }

    // Allocate values vector
    GVector values(nevents);

    // If there are model components then set values and gradients
    if (m_models.size() > 0) {

        // One vector of event values per model component
        std::vector<GVector> component_values(m_models.size());

        // Initialise gradient column offset of the current component
        int ioffset = 0;

        // First pass: evaluate all model components
        for (int i = 0; i < m_models.size(); ++i) {

            // Compute vector of values and optionally matrix of gradients;
            // the component writes its gradients starting at column ioffset
            component_values[i] = m_models[i]->eval(obs, gradients, ioffset);

            // Signal all parameters that will have analytical gradients. These
            // are all parameters that are free and for which the model provides
            // analytical gradients.
            if (grad) {
                for (int ipar = 0; ipar < m_models[i]->size(); ++ipar) {
                    const GModelPar& par = (*this)[ipar+ioffset];
                    if (par.is_free() && par.has_grad()) {
                        obs.computed_gradient(*this, par);
                    }
                }
            } // endif: gradients were requested

            // Update gradient matrix column offset
            ioffset += m_models[i]->size();

        } // endfor: looped over all model components

        // Re-initialise gradient column offset for the second pass
        ioffset = 0;

        // Second pass: compute values and optionally gradients
        for (int i = 0; i < m_models.size(); ++i) {

            // Set or multiply model component values
            if (i == 0) {
                values = component_values[i];
            }
            else {
                values *= component_values[i];
            }

            // Optionally compute gradients
            if (grad) {

                // Allocate factor vector
                GVector factor(nevents);

                // Initialise factor vector
                factor = 1.0;

                // Loop over other model components and compute factor vector
                // (product of the values of all components except i)
                for (int j = 0; j < m_models.size(); ++j) {
                    if (i != j) {
                        factor *= component_values[j];
                    }
                }

                // Multiply all gradient columns of component i with the
                // factor vector
                for (int ipar = 0; ipar < m_models[i]->size(); ++ipar) {
                    gradients->multiply_column(ipar+ioffset, factor);
                }

                // Increment gradient column offset
                ioffset += m_models[i]->size();

            } // endif: gradients were requested

        } // endfor: loop over model components

    } // endif: there were model components

    // Return values
    return values;
}


/***********************************************************************//**
 * @brief Return spatially integrated data model
 *
 * @param[in] obsEng Measured event energy.
 * @param[in] obsTime Measured event time.
 * @param[in] obs Observation.
 *
 * @exception GException::feature_not_implemented
 *            Feature not implemented
 *
 * This method always throws, since the feature is not implemented yet.
 *
 * @todo Implement method.
 ***************************************************************************/
double GModelDataMultiplicative::npred(const GEnergy&      obsEng,
                                       const GTime&        obsTime,
                                       const GObservation& obs) const
{
    // Not implemented yet: always signal this to the caller
    throw GException::feature_not_implemented(G_NPRED);

    // Never reached; keeps compilers that expect a return value quiet
    return 0.0;
}


/***********************************************************************//**
 * @brief Return simulated events
 *
 * @param[in] obs Observation.
 * @param[in] ran Random number generator.
 *
 * @exception GException::feature_not_implemented
 *            Feature not implemented
 *
 * This method always throws, since the feature is not implemented yet.
 *
 * @todo Implement method.
 ***************************************************************************/
GEvents* GModelDataMultiplicative::mc(const GObservation& obs,
                                      GRan&               ran) const
{
    // Not implemented yet: always signal this to the caller
    throw GException::feature_not_implemented(G_MC);

    // Never reached; keeps compilers that expect a return value quiet
    return NULL;
}


/***********************************************************************//**
 * @brief Read model from XML element
 *
 * @param[in] xml XML element.
 *
 * @exception GException::invalid_value
 *            Invalid model type specified
 *            Model is not a data model
 *            Model is for an incompatible instrument
 *
 * Reads the multiplicative data model from an XML element. The expected
 * XML format is
 *
 *     <source name="Model" type="MultiplicativeData" instrument="...">
 *       <source name="Component1" type="DataSpace" instrument="...">
 *          ...
 *       </source>
 *       <source name="Component2" type="DataSpace" instrument="...">
 *          ...
 *       </source>
 *     </source>
 ***************************************************************************/
void GModelDataMultiplicative::read(const GXmlElement& xml)
{
    // Clear models and model parameters
    m_models.clear();
    m_pars.clear();

    // Get number of model components
    int num = xml.elements("source");

    // Read model attributes
    read_attributes(xml);

    // Loop over model components
    for (int i = 0; i < num; ++i) {

        // Get model XML element
        const GXmlElement* source = xml.element("source", i);

        // Get model type
        std::string type = source->attribute("type");

        // Allocate a model registry object
        GModelRegistry registry;

        // Read model
        GModel* ptr = registry.alloc(type);

        // Check that model is known
        if (ptr == NULL) {
            std::string msg = "Model type \""+type+"\" unknown. The following "
                              "model types are available: "+registry.content()+
                              ". Please specify one of the available model "
                              "types.";
            throw GException::invalid_value(G_READ, msg);
        }

        // Check that model is a data model
        GModelData* model = dynamic_cast<GModelData*>(ptr);
        if (model == NULL) {
            std::string msg = "Model type \""+type+"\" is not a data model. "
                              "Please specify a data model.";
            throw GException::invalid_value(G_READ, msg);
        }

        // Read model from XML file
        model->read(*source);

        // Check that model
        if (model->instruments() != instruments()) {
            std::string msg = "Model \""+model->name()+"\" is for instrument "
                              "\""+model->instruments()+"\" but model container "
                              "is for instrument \""+instruments()+"\". Please "
                              "specify a data model that is compliant with the "
                              "container instrument.";
            throw GException::invalid_value(G_READ, msg);
        }

        // Append data model component to container
        append(*model);

        // Free model
        delete ptr;

    } // endfor: loop over models

    // Return
    return;
}


/***********************************************************************//**
 * @brief Write model into XML element
 *
 * @param[in,out] xml XML element.
 *
 * Writes the multiplicative data model into an XML element. If @p xml has
 * no "source" child with the name of this model, such a child is appended
 * with the model's name and type. An instrument attribute is added if the
 * instrument is set. All model components are then written into that
 * child, followed by the model attributes.
 ***************************************************************************/
void GModelDataMultiplicative::write(GXmlElement& xml) const
{
    // Locate the first "source" child that carries the name of this model
    GXmlElement* src  = NULL;
    int          nsrc = xml.elements("source");
    for (int k = 0; (k < nsrc) && (src == NULL); ++k) {
        GXmlElement* candidate = xml.element("source", k);
        if (candidate->attribute("name") == name()) {
            src = candidate;
        }
    }

    // No matching child: create one and set name, type and instrument
    if (src == NULL) {
        src = xml.append("source");
        src->attribute("name", name());
        src->attribute("type", type());
        if (instruments().length() > 0) {
            src->attribute("instrument", instruments());
        }
    }

    // Make sure the element is of the type of this model
    gammalib::xml_check_type(G_WRITE, *src, type());

    // Write every model component into the source element
    for (int i = 0; i < m_models.size(); ++i) {

        #if defined(G_PREPEND_MODEL_NAME)
        // Parameter names carry a "<model>:" prefix (see append()). This
        // writes a temporary clone with the part after the colon as
        // parameter name, so that the original names end up in the XML.
        GModelData* cpy = m_models[i]->clone();
        for (int j = 0; j < cpy->size(); ++j) {
            GModelPar&  par     = (*cpy)[j];
            std::string parname = par.name();
            if (gammalib::contains(parname, ":")) {
                std::vector<std::string> splits = gammalib::split(parname, ":");
                par.name(splits[1]);
            }
        }
        cpy->write(*src);
        delete cpy;
        #else
        m_models[i]->write(*src);
        #endif

    }

    // Write model attributes
    write_attributes(*src);
}


/***********************************************************************//**
 * @brief Append data model component
 *
 * @param[in] model Data model component.
 *
 * @exception GException::invalid_value
 *            Model with same name exists already in container
 *            Model is for incompatible instrument
 *
 * Appends a data model component to the multiplicative data model. If
 * the instruments for the container was not set the instrument is copied
 * from the model. Otherwise, the method verifies that the instrument of the
 * model is the same as of the container.
 ***************************************************************************/
void GModelDataMultiplicative::append(const GModelData& model)
{
    // Check if model with same name exists already in container
    for (int i = 0; i < m_models.size(); i++) {
        if (m_models[i]->name() == model.name()) {
            std::string msg = "Attempt to append model with name \""+
                              model.name()+"\" to multiplicative data model "
                              "container, but a component with the same name "
                              "exists already in the container. Every "
                              "component in the container needs to have a "
                              "unique name.";
            throw GException::invalid_value(G_APPEND, msg);
        }
    }

    // Check that instruments are compliant
    if (instruments() == "") {
        instruments(model.instruments());
    }
    else if (model.instruments() != instruments()) {
        std::string msg = "Model \""+model.name()+"\" is for instrument "
                          "\""+model.instruments()+"\" but model container "
                          "is for instrument \""+instruments()+"\". Please "
                          "specify a data model that is compliant with the "
                          "container instrument.";
        throw GException::invalid_value(G_APPEND, msg);
    }

    // Append clone of data model to container
    m_models.push_back(model.clone());

    // Get index of data model
    int index = m_models.size()-1;

    // Get number of model parameters
    int npars = m_models[index]->size();

    // Loop over model parameters
    for (int ipar = 0; ipar < npars; ++ipar) {

        // Get model parameter
        GModelPar* par = &(m_models[index]->operator[](ipar));

        // Prepend model name to parameter name
        #if defined(G_PREPEND_MODEL_NAME)
        par->name(model.name()+":"+par->name());
        #endif

        // Append model parameter with new name to internal container
        m_pars.push_back(par);

    } // endfor: loop over model parameters

    // Return
    return;
}


/***********************************************************************//**
 * @brief Return data model by index
 *
 * @param[in] index Index of data model [0,...,components()-1].
 * @return Pointer to data model.
 *
 * @exception GException::out_of_range
 *            Component index is out of range.
 *
 * Returns a component of the multiplicative data model by @p index.
 ***************************************************************************/
const GModelData* GModelDataMultiplicative::component(const int& index) const
{
    // Number of components as signed integer, so that the range check
    // does not mix signed and unsigned types
    int ncomponents = static_cast<int>(m_models.size());

    // Check if index is in validity range
    if (index < 0 || index >= ncomponents) {
        throw GException::out_of_range(G_COMPONENT_INDEX, "Component Index",
                                       index, ncomponents);
    }

    // Return pointer to data model
    return m_models[index];
}


/***********************************************************************//**
 * @brief Return data model by name
 *
 * @param[in] name Name of data model.
 * @return Pointer to data model.
 *
 * @exception GException::invalid_argument
 *            Model component not found.
 *
 * Returns the first component of the multiplicative data model whose name
 * equals @p name.
 ***************************************************************************/
const GModelData* GModelDataMultiplicative::component(const std::string& name) const
{
    // Return the first component carrying the requested name
    for (int k = 0; k < m_models.size(); ++k) {
        if (m_models[k]->name() == name) {
            return m_models[k];
        }
    }

    // No component carries that name
    std::string msg = "Model component \""+name+"\" not found. Please "
                      "specify a valid model component name.";
    throw GException::invalid_argument(G_COMPONENT_NAME, msg);
}


/***********************************************************************//**
 * @brief Print multiplicative data model information
 *
 * @param[in] chatter Chattiness.
 * @return String containing model information (empty for SILENT).
 *
 * The output lists the number of components, the number of parameters and
 * the information of every parameter.
 ***************************************************************************/
std::string GModelDataMultiplicative::print(const GChatter& chatter) const
{
    std::string result;

    // Silent chatter produces no output
    if (chatter == SILENT) {
        return result;
    }

    // Header and summary
    result.append("=== GModelDataMultiplicative ===");
    result.append("\n"+gammalib::parformat("Number of components"));
    result.append(gammalib::str(components()));
    result.append("\n"+gammalib::parformat("Number of parameters"));
    result.append(gammalib::str(size()));

    // One entry per model parameter
    for (int ipar = 0; ipar < size(); ++ipar) {
        result.append("\n"+m_pars[ipar]->print(chatter));
    }

    return result;
}


/*==========================================================================
 =                                                                         =
 =                             Private methods                             =
 =                                                                         =
 ==========================================================================*/

/***********************************************************************//**
 * @brief Initialise class members
 *
 * Sets the model type to "MultiplicativeData" and empties the component
 * container. The components themselves are not deleted here; this is done
 * by free_members().
 ***************************************************************************/
void GModelDataMultiplicative::init_members(void)
{
    // Start without any model components
    m_models.clear();

    // Set model type
    m_type = "MultiplicativeData";
}


/***********************************************************************//**
 * @brief Copy class members
 *
 * @param[in] model Multiplicative data model.
 ***************************************************************************/
void GModelDataMultiplicative::copy_members(const GModelDataMultiplicative& model)
{
    // Get number of models
    int num = model.m_models.size();

    // Copy members
    m_type = model.m_type;

    // Clone models
    m_models.clear();
    for (int i = 0; i < num; ++i) {
        m_models.push_back(model.m_models[i]->clone());
    }

    // Store pointers to model parameters
    m_pars.clear();
    for (int i = 0; i < m_models.size(); ++i) {

        // Retrieve data model
        GModelData* model = m_models[i];

        // Loop over parameters
        for (int ipar = 0; ipar < model->size(); ++ipar) {

            // Get model parameter reference
            GModelPar& par = model->operator[](ipar);

            // Append model parameter pointer to internal container
            m_pars.push_back(&par);

        }
    }

    // Return
    return;
}


/***********************************************************************//**
 * @brief Delete class members
 *
 * Deletes all model components owned by this object and empties the
 * component container.
 ***************************************************************************/
void GModelDataMultiplicative::free_members(void)
{
    // Delete owned components (deleting a NULL pointer is a no-op)
    for (int k = 0; k < m_models.size(); ++k) {
        delete m_models[k];
    }

    // Drop the now dangling pointers
    m_models.clear();
}
