# -*- coding: utf-8 -*-

# Longxu's Multinomial Logit (MNLogit) module,
# designed for estimating MNLogit models with very large samples or many alternatives.
# Longxu is an Associate Professor at the Department of Urban Planning, Tongji University.
# Email: yanlongxu@tongji.edu.cn

import numpy  as np
import pandas as pd
from scipy import stats
import warnings


# %%
#####################################################################################################

class LMNLogit():
    def __init__(self, modelname="LMNLogit"):
        """Create an unfitted model.

        Args:
            modelname: Label stored on the instance as ``self.modelname``.
        """
        self.modelname = modelname

    def cal_utility(self):
        """Compute the systematic utility of each alternative for each row.

        The result is the sum of two linear parts:

        * the row-by-alternative variables in ``self.X1_matrix_arr`` (each of
          shape ``(i, j)``), weighted by ``self.weight_X1_matrix``;
        * the alternative-level attributes in ``self.X2_arr`` (shape
          ``(j, k)``), weighted by ``self.weight_X2_arr``.

        Returns:
            Array of shape ``(self.i, self.j)``.
        """
        # Alternative-level part: collapse the k attributes into one value per
        # alternative, giving shape (j,).
        alternative_part = np.sum(self.weight_X2_arr * self.X2_arr, axis=1)

        # Row-by-alternative part, accumulated one explanatory variable at a time.
        pairwise_part = np.zeros((self.i, self.j))
        for coefficient, values in zip(self.weight_X1_matrix, self.X1_matrix_arr):
            pairwise_part += coefficient * values

        # (i, j) + (j,) broadcasts the alternative-level part across all rows.
        return pairwise_part + alternative_part

    def validate_inputs(self, initial_weights):
        """Check that the stored inputs are mutually consistent.

        The rules are evaluated in the order listed and checking stops at the
        first violation.

        Args:
            initial_weights: Starting coefficients; must be empty or contain
                ``self.total_num_weight`` values.

        Raises:
            ValueError: On the first violated rule, with a message naming it.
        """
        # Each rule is (predicate returning True on violation, error message).
        # Predicates are lambdas so that a later rule is only evaluated once all
        # earlier ones (for example the shape agreement) have passed.
        rules = (
            (lambda: self.Y_matrix.ndim != 2,
             "Y_matrix must be a 2D matrix"),
            (lambda: self.AV_matrix.shape != self.Y_matrix.shape,
             "AV_matrix must have the same shape as Y_matrix"),
            (lambda: np.any(self.Y_matrix < 0),
             "Y_matrix cannot contain negative values"),
            (lambda: np.any(self.AV_matrix < 0),
             "AV_matrix cannot contain negative values"),
            (lambda: np.any((self.Y_matrix > 0) & (self.AV_matrix <= 0)),
             "positive choices in Y_matrix must be available in AV_matrix"),
            (lambda: np.any(self.AV_matrix.sum(axis=1) <= 0),
             "each row in AV_matrix must have at least one available alternative"),
            (lambda: self.X1_matrix_arr.shape != (self.m, self.i, self.j),
             "X1_matrix_arr must have shape (len(X1_names), i, j)"),
            (lambda: self.X2_arr.shape != (self.j, self.k),
             "X2_arr must have shape (j, len(X2_names))"),
            (lambda: len(initial_weights) not in (0, self.total_num_weight),
             "initial_weights must be empty or have length len(X1_names) + len(X2_names)"),
        )
        for violated, message in rules:
            if violated():
                raise ValueError(message)
    
    ########## calculate probability 
    def cal_Probt(self):
        """Return the choice probabilities implied by the current weights.

        Utilities are shifted by the per-row maximum over available
        alternatives before exponentiation so that ``np.exp`` cannot overflow.
        Entries with ``AV_matrix <= 0`` receive probability zero.

        Returns:
            Array of shape ``(i, j)`` whose rows sum to one.
        """
        utility = self.cal_utility()  # (i, j)
        is_available = self.AV_matrix > 0

        # Largest utility among the available alternatives of each row, kept as
        # a column vector so it broadcasts against the (i, j) utilities.
        row_peak = np.where(is_available, utility, -np.inf).max(axis=1, keepdims=True)

        # Unavailable alternatives are pushed to -inf, i.e. exp(...) == 0.
        shifted = np.where(is_available, utility - row_peak, -np.inf)

        # Availability also enters as a multiplicative weight on exp(V).
        numerator = np.exp(shifted) * self.AV_matrix
        return numerator / numerator.sum(axis=1, keepdims=True)

    ########## calculate neg log Likelihood
    def cal_neg_LnLikelihood(self):
        """Return the negative log-likelihood at the current weights.

        Each cell contributes ``(log_denominator - shifted_utility)`` weighted
        by ``Y_matrix * AV_matrix``; cells with zero weight are skipped
        entirely, so the ``-inf`` utilities of unavailable alternatives never
        enter the arithmetic.

        Returns:
            Scalar sum of all cell contributions.
        """
        utility = self.cal_utility()  # (i, j)
        is_available = self.AV_matrix > 0

        # Shift by the per-row maximum over available alternatives for
        # numerical stability.
        row_peak = np.where(is_available, utility, -np.inf).max(axis=1).reshape((self.i, 1))
        shifted = np.where(is_available, utility - row_peak, -np.inf)

        # Log of the per-row softmax denominator, shape (i, 1).
        denominator = (np.exp(shifted) * self.AV_matrix).sum(axis=1).reshape((self.i, 1))
        log_denominator = np.log(denominator)

        cell_weight = self.Y_matrix * self.AV_matrix
        contributes = cell_weight > 0

        contribution = np.zeros_like(self.Y_matrix)
        contribution[contributes] = (log_denominator - shifted)[contributes] * cell_weight[contributes]
        return contribution.sum()
    
    ########## calculate neg log Likelihood of L0
    def cal_neg_LnLikelihood_L0(self):
        """Return the negative log-likelihood of the null (all-zero-weight) model.

        With every utility equal to zero, the probability of an alternative is
        its availability weight divided by the row's total availability, so the
        per-row log denominator is simply ``log(AV_matrix.sum(axis=1))``. This
        is the value ``cal_neg_LnLikelihood`` yields when all weights are zero,
        including when ``AV_matrix`` holds values other than 0 and 1 (the
        previous implementation used ``AV_matrix`` itself as the utility, which
        only agreed with the zero-weight model for binary availability).

        Returns:
            Scalar negative log-likelihood of the null model.
        """
        # Same cell weights as the fitted-model likelihood.
        row_weight = (self.Y_matrix * self.AV_matrix).sum(axis=1)   # (i,)
        availability_total = self.AV_matrix.sum(axis=1)             # (i,)

        # Rows that carry no weight are skipped so that a row without any
        # available alternative cannot introduce log(0) * 0 = nan.
        active = row_weight > 0
        return (row_weight[active] * np.log(availability_total[active])).sum()

    ########## calculate the derivative of neg log Likelihood over parameters 
    def cal_gradient(self):
        """Return the gradient of the negative log-likelihood.

        For every explanatory variable the component is the sum over all cells
        of ``(expected value under Probt - observed value)`` weighted by
        ``Y_matrix * AV_matrix``.

        Returns:
            1-D array: the components for ``X1_matrix_arr`` followed by the
            components for the columns of ``X2_arr``.
        """
        prob = self.cal_Probt()

        # Row-by-alternative variables: each `values` has shape (i, j).
        pairwise_terms = []
        for values in self.X1_matrix_arr:
            expected = (prob * values).sum(axis=1).reshape((self.i, 1))  # (i, 1)
            residual = expected - values                                   # (i, j)
            pairwise_terms.append((residual * self.Y_matrix * self.AV_matrix).sum())

        # Alternative-level attributes: each `column` has shape (j,).
        alternative_terms = []
        for column in self.X2_arr.T:
            expected = np.dot(prob, column.reshape(-1, 1))                # (i, 1)
            residual = expected - column                                   # broadcasts to (i, j)
            alternative_terms.append((residual * self.Y_matrix * self.AV_matrix).sum())

        return np.concatenate([np.array(pairwise_terms), np.array(alternative_terms)])
    

    ########## calculate Hessian matrix and standard errors
    def cal_hessian(self, SE=False):
        """Compute the Hessian of the negative log-likelihood.

        Entry ``(p, q)`` is the sum over rows of the probability-weighted
        covariance between explanatory variables ``p`` and ``q`` within the
        row, multiplied by that row's total of ``Y_matrix * AV_matrix``.
        Parameters ``0 .. m-1`` refer to ``X1_matrix_arr``; the remaining ones
        refer to the columns of ``X2_arr``.

        Args:
            SE: When true, also derive standard errors from the diagonal of the
                inverse Hessian and store ``self.hessian_condition_number``.

        Returns:
            ``hess`` when ``SE`` is false, otherwise ``(standard_err, hess)``.
            Standard errors whose variance estimate is not positive are ``nan``.

        Warns:
            RuntimeWarning: If the Hessian is ill-conditioned or singular, or
                if its inverse has non-positive diagonal entries (``SE`` only).
        """
        n_par = self.total_num_weight
        Probt = self.cal_Probt()

        def attribute(index):
            # X1 variables are (i, j) matrices; X2 columns are (j,) vectors
            # that broadcast across rows when multiplied with Probt.
            if index < self.m:
                return self.X1_matrix_arr[index]
            return self.X2_arr[:, index - self.m]

        # The covariance term of a row does not depend on the alternative, so
        # weighting every cell by Y * AV and summing equals weighting the row
        # term by the row total of Y * AV.
        row_weight = (self.Y_matrix * self.AV_matrix).sum(axis=1)  # (i,)

        # Probability-weighted row mean of every variable, computed once
        # instead of once per (p, q) pair.
        expected = np.zeros((n_par, Probt.shape[0]))
        for p in range(n_par):
            expected[p] = (Probt * attribute(p)).sum(axis=1)

        hess = np.zeros((n_par, n_par))
        for p in range(n_par):
            weighted_p = Probt * attribute(p)  # (i, j)
            for q in range(p, n_par):
                covariance = (weighted_p * attribute(q)).sum(axis=1) - expected[p] * expected[q]
                # The Hessian is symmetric: fill both triangles at once.
                hess[p, q] = hess[q, p] = np.dot(covariance, row_weight)

        if not SE:
            return hess

        # See Discrete Choice Methods with Simulation, p. 228,
        # and https://ww2.mathworks.cn/matlabcentral/answers/153414-estimator-standard-errors-using-fmincon-portfolio-optimization-context
        identity = np.eye(hess.shape[1])
        self.hessian_condition_number = np.linalg.cond(hess)
        if (not np.isfinite(self.hessian_condition_number)) or self.hessian_condition_number > 1e12:
            warnings.warn("Hessian matrix is ill-conditioned; standard errors may be unreliable.", RuntimeWarning)

        # Inverse via a linear solve; fall back to a tiny ridge if singular.
        try:
            inverse = np.linalg.solve(hess, identity)
        except np.linalg.LinAlgError:
            warnings.warn("Hessian matrix is singular; adding a small ridge term.", RuntimeWarning)
            inverse = np.linalg.solve(hess + identity * 1e-8, identity)
        variances = np.diag(inverse)

        if variances.min() <= 0:
            warnings.warn("Hessian inverse has non-positive diagonal values; standard errors may be unreliable.", RuntimeWarning)
            variances = np.diag(np.linalg.solve(hess + identity * 1e-8, identity))

        # Variances that are still non-positive cannot yield a standard error.
        standard_err = np.sqrt(np.where(variances > 0, variances, np.nan))
        return standard_err, hess
    
    # Two-sided large-sample significance test.
    def Ttest(self, T_statistic):
        """Return two-sided p-values from the standard normal distribution.

        Args:
            T_statistic: Test statistic(s); the sign is ignored.

        Returns:
            Twice the upper-tail probability of ``|T_statistic|``, rounded to
            five decimal places.
        """
        magnitude = np.abs(T_statistic)
        upper_tail = stats.norm.sf(magnitude)
        return np.around(upper_tail * 2, 5)

    
    def adjust_alpha(self, gradient=None):
        """Expand ``self.alpha`` into one learning rate per parameter.

        * ``fit_method == "hessian"``: every parameter gets ``self.alpha``.
        * ``fit_method == "gradient"``: ``self.alpha`` is divided by the
          largest absolute gradient component, unless that component is zero,
          in which case ``self.alpha`` is used unchanged.

        Args:
            gradient: Gradient of the negative log-likelihood. Only used by the
                "gradient" method; when omitted it is computed with
                ``cal_gradient`` at that point, so the "hessian" method no
                longer pays for a gradient it never reads.

        Raises:
            ValueError: If ``self.fit_method`` is neither "hessian" nor
                "gradient".
        """
        if self.fit_method == "hessian":
            # Hessian/Newton-style method.
            self.alpha = np.zeros(self.total_num_weight) + self.alpha
        elif self.fit_method == "gradient":
            # Normal gradient descent: scale by the largest gradient component.
            if gradient is None:
                gradient = self.cal_gradient()
            max_adjustment = np.abs(np.asarray(gradient)).max()
            if max_adjustment == 0:
                self.alpha = np.zeros(self.total_num_weight) + self.alpha
            else:
                self.alpha = np.ones(self.total_num_weight) / max_adjustment * self.alpha
        else:
            raise ValueError("wrong input of fit_method, must be one of 'hessian' or 'gradient'")
        return
    
    # estimation per epoch 
    def cal_weight_adjustment(self, k ):
        """Compute the parameter update for iteration ``k``.

        From the second iteration on (``k >= 1``) each learning rate in
        ``self.alpha`` is adapted in place: it grows when the gradient
        component did not change sign since the previous iteration and is
        multiplied by 0.75 otherwise. The update itself is a damped Newton
        step ("hessian") or a momentum step ("gradient").

        Args:
            k: Zero-based iteration counter.

        Returns:
            ``(adjustment, gradient2)``: the vector to add to the weights and
            the squared Euclidean norm of the current gradient.

        Raises:
            ValueError: If ``self.fit_method`` is neither "hessian" nor
                "gradient".
        """
        # Gradient of the negative log-likelihood and its squared norm.
        gradient = self.cal_gradient()
        gradient2 = np.dot(gradient, gradient)

        # Adapt the learning rates based on the gradient direction.
        if k >= 1:
            keeps_direction = self.Previous_gradient * gradient >= 0
            above_one = self.alpha > 1
            # All masks are built from the rates as they were on entry, so
            # every component receives exactly one of the three updates.
            additive = keeps_direction & above_one
            multiplicative = keeps_direction & ~above_one
            reversed_direction = ~keeps_direction
            self.alpha[additive] += self.alpha_step
            self.alpha[multiplicative] *= (1 + self.alpha_step)
            self.alpha[reversed_direction] *= 0.75

        if self.fit_method == "hessian":
            # Hessian/Newton-style update; the gradient is that of the
            # negative log-likelihood, hence the minus sign.
            hessian = self.cal_hessian(SE=False)
            try:
                newton_step = np.linalg.solve(hessian, gradient)
            except np.linalg.LinAlgError:
                # Singular Hessian: retry with a small ridge on the diagonal.
                ridge = np.eye(hessian.shape[1]) * 1e-6
                newton_step = np.linalg.solve(hessian + ridge, gradient)
            adjustment = -self.alpha * newton_step
        elif self.fit_method == "gradient":
            # Gradient descent with momentum.
            self.Momentum = self.beta * self.Momentum - self.alpha * gradient
            adjustment = self.Momentum
        else:
            raise ValueError("wrong input of fit_method, must be one of 'hessian' or 'gradient'")

        if self.verbose:
            which = np.abs(adjustment).argmax()
            print ( "{}, gradient2= {:0.3e}, max_adjust= {:.3e}, which= {}, max_alpha= {:.3e}".format(k, gradient2, adjustment[which] , which, self.alpha[which] ) )

        # Remembered for the sign comparison in the next iteration.
        self.Previous_gradient = gradient
        return adjustment, gradient2
        
    # accuracy
    def cal_accuracy(self):
        """Score the fitted probabilities against the observed choices.

        If any row of ``Y_matrix`` sums to more than one, rows are treated as
        aggregate counts and each row total is allocated across alternatives
        by probability. If the largest row sum is exactly one, rows are treated
        as individual observations and the highest-probability alternative is
        predicted for every row with a non-zero sum.

        Returns:
            ``(accuracy, correlation, prediction_agg)``: the share of choices
            predicted correctly, the Pearson correlation between observed and
            predicted totals per alternative, and the predicted totals per
            alternative.

        Raises:
            ValueError: If the largest row sum of ``Y_matrix`` is below one.
        """
        prob = self.cal_Probt()
        row_total = self.Y_matrix.sum(axis=1)
        largest = row_total.max()

        if largest > 1:
            # Aggregate counts: spread each row's total by probability.
            predicted = row_total.reshape(-1, 1) * prob
        elif largest == 1:
            # Individual observations: one-hot on the most likely alternative,
            # leaving rows without any observed choice at zero.
            predicted = np.zeros((self.i, self.j))
            chosen_rows = np.flatnonzero(row_total != 0)
            predicted[chosen_rows, prob.argmax(axis=1)[chosen_rows]] = 1
        else:
            raise ValueError("wrong input of Y_matrix")

        # Every misallocated choice is counted twice in the absolute
        # difference (once where it is missing, once where it was placed).
        mismatch = np.abs(predicted - self.Y_matrix).sum()
        accuracy = 1 - mismatch / 2.0 / row_total.sum()

        prediction_agg = predicted.sum(axis=0)
        observed_agg = self.Y_matrix.sum(axis=0)
        correlation = np.corrcoef(np.array([observed_agg, prediction_agg]))[0, 1]
        return accuracy, correlation, prediction_agg
    
    # Finite-difference gradient check.
    def check_gradient(self, step=1e-9):
        """Print the analytic gradient next to a finite-difference estimate.

        Each weight is perturbed in turn by ``step`` and the change in the
        negative log-likelihood is divided by ``step`` (forward difference).
        The model's weights are restored afterwards, even if an evaluation
        fails; previously the last perturbation was left in place.

        Args:
            step: Size of the perturbation applied to each weight.
        """
        # Keep the original weight arrays so they can be put back untouched.
        saved_X1 = self.weight_X1_matrix
        saved_X2 = self.weight_X2_arr

        gradient = self.cal_gradient()
        old_likelihood = self.cal_neg_LnLikelihood()
        # Float copy so that adding a tiny step is never truncated.
        old_weight = np.concatenate([saved_X1, saved_X2]).astype(float)

        cal_gradient = []
        try:
            for ind in range(self.total_num_weight):
                new_weight = old_weight.copy()
                new_weight[ind] += step

                self.weight_X1_matrix = new_weight[:self.m]
                self.weight_X2_arr = new_weight[self.m:]

                new_likelihood = self.cal_neg_LnLikelihood()
                cal_gradient.append((new_likelihood - old_likelihood) / step)
        finally:
            self.weight_X1_matrix = saved_X1
            self.weight_X2_arr = saved_X2

        print ("model_gradient = ", gradient)
        print ("check_gradient = ", cal_gradient)
    
    
    ###################################################################################################
    # estimation 
    def fit(self, Y_matrix, AV_matrix, X1_matrix_arr, X2_arr, X1_names, X2_names, initial_weights=None,
            verbose = False, alpha = 1.0, alpha_step = 0.01, max_iteration = 1000, fit_method = "hessian", threshold = 100.0):
        """Estimate the model and build the summary tables.

        All array inputs are converted with ``np.asarray(..., dtype=float)``,
        so nested lists and other array-likes are accepted. Every statistic is
        then derived from the converted arrays.

        Args:
            Y_matrix: ``(i, j)`` observed choices, either aggregate counts or
                one-hot individual choices.
            AV_matrix: ``(i, j)`` availability of alternative ``j`` for row ``i``.
            X1_matrix_arr: ``(m, i, j)`` matrix-form variables (for example
                travel impedance from ``i`` to ``j``); may be empty when
                ``X1_names`` is empty.
            X2_arr: ``(j, k)`` alternative-level attributes; may be empty when
                ``X2_names`` is empty.
            X1_names: Names of the ``m`` matrix-form variables.
            X2_names: Names of the ``k`` alternative-level attributes.
            initial_weights: Starting coefficients (``m + k`` values). ``None``
                or an empty sequence starts from zeros.
            verbose: Print iteration-level diagnostics.
            alpha: Initial learning rate.
            alpha_step: Increment used when adapting the learning rates.
            max_iteration: Maximum number of update iterations.
            fit_method: ``"hessian"`` (Newton-style) or ``"gradient"``
                (gradient descent with momentum).
            threshold: Estimation stops once the squared gradient norm falls
                below this value.

        Raises:
            ValueError: If the inputs are inconsistent, ``fit_method`` is
                unknown, or the estimation does not converge within
                ``max_iteration`` iterations.

        Side effects:
            Sets ``summary_model``, ``summary_parameters``,
            ``prediction_matrix``, ``prediction_agg``, ``degree_freedom`` and
            the fitted weights on the instance, and prints a completion message.
        """
        # ``None`` (the default) and an empty sequence both mean "start from zeros".
        if initial_weights is None:
            initial_weights = []

        # Optimizer settings.
        self.verbose = verbose
        self.fit_method = fit_method
        self.alpha = alpha
        self.alpha_step = alpha_step
        self.beta = 0.9      # momentum coefficient used by the "gradient" method
        self.Momentum = 0

        self.m = len(X1_names)
        self.k = len(X2_names)
        self.total_num_weight = self.m + self.k

        # Y_matrix (i, j): observed choices or aggregate choice counts.
        self.Y_matrix = np.asarray(Y_matrix, dtype=float)
        if self.Y_matrix.ndim != 2:
            # Checked here so the shape unpacking below cannot fail with an
            # unrelated "not enough values to unpack" error.
            raise ValueError("Y_matrix must be a 2D matrix")
        self.i, self.j = self.Y_matrix.shape

        # AV_matrix (i, j): availability of alternative j for row i.
        self.AV_matrix = np.asarray(AV_matrix, dtype=float)

        # X1_matrix_arr (m, i, j): matrix-form variables.
        if self.m == 0 and len(X1_matrix_arr) == 0:
            self.X1_matrix_arr = np.zeros((0, self.i, self.j))
        else:
            self.X1_matrix_arr = np.asarray(X1_matrix_arr, dtype=float)

        # X2_arr (j, k): alternative-level attributes.
        if self.k == 0 and len(X2_arr) == 0:
            self.X2_arr = np.zeros((self.j, 0))
        else:
            self.X2_arr = np.asarray(X2_arr, dtype=float)

        self.validate_inputs(initial_weights)

        # Initialize parameters with zeros unless initial_weights is supplied.
        if len(initial_weights) == 0:
            self.weight_X1_matrix = np.zeros(self.m)
            self.weight_X2_arr = np.zeros(self.k)
        else:
            initial_weights = np.asarray(initial_weights, dtype=float)
            self.weight_X1_matrix = initial_weights[:self.m]
            self.weight_X2_arr = initial_weights[self.m:]

        # Expand the scalar learning rate into one rate per parameter.
        self.adjust_alpha()

        # Total number of observed choices, taken from the converted array so
        # that lists and DataFrames give the same scalar as ndarrays.
        total_choices = self.Y_matrix.sum()

        # Degrees of freedom retained for backward compatibility.
        self.degree_freedom = total_choices - self.m - self.k - 1
        self.store_adjustment = []

        # Parameter estimation loop.
        converged = False
        for iteration in range(max_iteration + 1):
            self.weight = np.concatenate([self.weight_X1_matrix, self.weight_X2_arr])
            adjustment, gradient2 = self.cal_weight_adjustment(iteration)
            if gradient2 < threshold:
                converged = True
                break
            self.weight += adjustment

            self.weight_X1_matrix = self.weight[:self.m]
            self.weight_X2_arr = self.weight[self.m:]

        if not converged:
            raise ValueError("max_iteration arrived! estimation failed to converge!")

        ###### Model-level summary
        neg_LnLikelihood_L0 = self.cal_neg_LnLikelihood_L0()        # null model
        neg_LnLikelihood_model = self.cal_neg_LnLikelihood()        # fitted model
        likehd_ratio = neg_LnLikelihood_model / neg_LnLikelihood_L0
        rho_squared = 1 - likehd_ratio                              # pseudo R-squared
        rho_squared_bar = 1 - (neg_LnLikelihood_model + self.total_num_weight) / neg_LnLikelihood_L0
        Akaike_information_criterion = 2 * neg_LnLikelihood_model + 2 * self.total_num_weight
        Bayesian_information_criterion = 2 * neg_LnLikelihood_model + self.total_num_weight * np.log(total_choices)

        accuracy, correlation, prediction_agg = self.cal_accuracy()
        self.prediction_agg = prediction_agg

        gradient = self.cal_gradient()

        summary_model = [
            ["iteration", iteration],                                   # number of iterations
            ["ODshape", str(self.Y_matrix.shape)],
            ["sample_num", total_choices],                              # sample size or total count
            ["fit_method", self.fit_method],                            # optimizer
            ["neg_logLikelihood_L0", neg_LnLikelihood_L0],
            ["neg_logLikelihood_model", neg_LnLikelihood_model],
            ["logLikelihood_ratio", likehd_ratio],
            ["rho_squared", rho_squared],
            ["rho_squared_bar", rho_squared_bar],                       # adjusted pseudo R-squared
            ["Akaike_information_criterion", Akaike_information_criterion],
            ["Bayesian_information_criterion", Bayesian_information_criterion],
            ["accuracy", accuracy],
            ["Pearson_corr_agg_choice_set", correlation],
            ["final_gradient^2", np.dot(gradient, gradient)],           # squared gradient norm
        ]
        self.summary_model = pd.DataFrame(summary_model, columns=["item", "value"])

        ###### Parameter-level summary
        weights = np.concatenate([self.weight_X1_matrix, self.weight_X2_arr])  # estimated coefficients
        standarderr, hessian = self.cal_hessian(SE=True)    # coefficient standard errors
        T = weights / standarderr                           # test statistic
        P = self.Ttest(T)                                   # two-sided p-values

        summary_paras = np.array([weights, standarderr, T, P]).T
        self.summary_parameters = pd.DataFrame( summary_paras, columns=["estimation", "std err", "T", "p-value"], index = [*X1_names, *X2_names] )
        self.summary_parameters.index.name = "parameters"

        # Calculate fitted aggregate predictions (this replaces the
        # prediction_agg obtained from cal_accuracy above).
        self.prediction_matrix, self.prediction_agg = self.predict_agg(self.Y_matrix.sum(axis=1))

        print ("estimation succeeded! see model.summary_model and model.summary_parameters")



    # Predict aggregate choices using the fitted probabilities.
    def predict_agg(self, Y_individuals):
        """Allocate row totals across alternatives using the fitted probabilities.

        Args:
            Y_individuals: One total per row (length ``i``). Any array-like is
                accepted, including lists; it is converted with ``np.asarray``.

        Returns:
            ``(prediction_matrix, prediction_agg)``: the ``(i, j)`` predicted
            choices and their column sums (one value per alternative).
        """
        Probt = self.cal_Probt()
        # Column vector so each row total scales its own probability row.
        row_totals = np.asarray(Y_individuals, dtype=float).reshape((-1, 1))
        prediction_matrix = row_totals * Probt
        prediction_agg = prediction_matrix.sum(axis=0)
        return prediction_matrix, prediction_agg
