# Takes a sample of binary variables and generates an
# undirected graph whose edges represent (conditional)
# dependence between pairs of variables, then attempts
# to orient those edges assuming the graph is a DAG.

import itertools

import numpy as np
import pandas as pd
from statsmodels.formula.api import logit

# Iterate across conditioning sets Z for increasing
# cardinality, removing edges where independence is
# established
def prune_graph(sample, pthres=0.01, verbose=False):
    """Estimate the undirected (conditional) dependence graph of ``sample``.

    Every pair of columns starts out connected. For conditioning-set sizes
    c = 0, 1, ..., N_nodes - 2, each still-connected pair (X, Y) is tested
    against every set Z of c *other* columns by fitting the logistic
    regression ``X ~ Y + Z``. If the p-value of Y's coefficient exceeds
    ``pthres`` for any Z, the pair is declared conditionally independent and
    its edge is removed.

    Args:
        sample: pandas DataFrame whose columns are the (binary) variables.
            Column names are inserted verbatim into a model formula, so they
            are assumed to be valid formula identifiers.
        pthres: p-value threshold; a p-value above it means "independent".
        verbose: if True, print each pair's status after every cardinality.

    Returns:
        dict of symmetric boolean adjacency matrices indexed in the order of
        ``list(sample)`` (True = edge / dependent, diagonal False):
        ``'Full'`` is the unpruned graph, integer key ``c`` is the graph after
        testing conditioning sets of size ``c``, and ``'Final'`` is the fully
        pruned graph.
    """
    nodes = list(sample)
    N_nodes = len(nodes)

    # D[i, j] becomes True once (i, j) is found (conditionally) independent.
    # Both triangles are written so every returned graph is symmetric.
    D = np.full((N_nodes, N_nodes), False)
    # Copy: D is mutated below, and 'Full' must stay the unpruned graph.
    D_iter = {'Full': D.copy()}

    # A pair can be conditioned on at most the N_nodes - 2 remaining nodes,
    # so cardinalities run from 0 up to and including N_nodes - 2.
    for c in range(N_nodes - 1):
        if verbose:
            print('Cardinality {}:'.format(c))

        for i in range(N_nodes - 1):
            for j in range(i + 1, N_nodes):
                if D[i, j]:
                    continue  # already pruned at a lower cardinality

                # Conditioning sets never contain the tested pair itself.
                others = [n for k, n in enumerate(nodes) if k not in (i, j)]
                for Z in itertools.combinations(others, c):
                    formula = '{} ~ {}'.format(nodes[i], nodes[j])
                    for node in Z:
                        formula = '{} + {}'.format(formula, node)

                    summary = logit(formula, sample).fit(disp=0).summary2()
                    pval = summary.tables[1].loc[nodes[j], 'P>|z|']
                    if pval > pthres:
                        # Consider these nodes conditionally independent;
                        # one separating set is enough, stop testing.
                        D[i, j] = D[j, i] = True
                        break

        D_iter[c] = D.copy()

        if verbose:
            for i in range(N_nodes - 1):
                for j in range(i + 1, N_nodes):
                    status = 'Independent' if D[i, j] else 'Dependent'
                    print('{}, {}: {}'.format(nodes[i], nodes[j], status))

    D_iter['Final'] = D

    # Invert "independent" flags into adjacency (dependence) graphs.
    G = {}
    for key, Dk in D_iter.items():
        G[key] = np.invert(Dk)
        np.fill_diagonal(G[key], False)

    return G

# Given a sample and a pruned graph G, attempt to assign causality
# (directionality) to G, assuming a DAG. The resulting graph has 
# asymmetric values (1,0) for directed edges and symmetric (1,1) for
# undirected edges. 
# Directed edge between triplet positions (a, b), i.e. a -> b, mapped to the
# triplet configurations it rules out.
_ARROW_EXCLUDES = {
    (0, 1): ('ChainRight', 'Fork'),
    (1, 0): ('ChainLeft', 'Collider'),
    (1, 2): ('ChainRight', 'Collider'),
    (2, 1): ('ChainLeft', 'Fork'),
}

# Configuration mapped to the reverse-direction entries (triplet positions)
# that are cleared when that configuration is the only one left.
_CONFIG_REMOVES = {
    'Collider': ((1, 0), (1, 2)),    # A -> B <- C
    'ChainLeft': ((1, 0), (2, 1)),   # A -> B -> C
    'ChainRight': ((0, 1), (1, 2)),  # A <- B <- C
    'Fork': ((0, 1), (2, 1)),        # A <- B -> C
}


def _discard(configs, *names):
    """Remove each of ``names`` from the list ``configs`` if present."""
    for name in names:
        if name in configs:
            configs.remove(name)


def assign_causality(sample, G, nodes, verbose=True):
    """Attempt to orient the edges of the undirected graph ``G`` as a DAG.

    Each unshielded triplet [A, B, C] (edges A-B and B-C, no A-C) starts with
    four candidate configurations: ChainLeft (A->B->C), ChainRight (A<-B<-C),
    Collider (A->B<-C) and Fork (A<-B->C). The candidates are repeatedly
    pruned until no triplet changes:
      1. existing directed edges rule out incompatible configurations;
      2. while either triplet edge is still undirected, A and C are tested
         for marginal dependence: independence implies a collider at B,
         dependence rules the collider out.
    When a single configuration remains, both triplet edges are oriented.

    Args:
        sample: pandas DataFrame of the variables, passed to get_dependence.
        G: square boolean adjacency matrix. Note that get_triplets
            symmetrizes it in place from its upper triangle.
        nodes: node names; ``nodes.index(name)`` gives the row/column in G.
        verbose: print triplets, candidate sets and decisions.

    Returns:
        A copy of (the symmetrized) G in which ``G_dir[u, v]`` True with
        ``G_dir[v, u]`` False denotes a directed edge u -> v, and both True
        denotes an undirected edge.
    """
    C = ['ChainLeft', 'ChainRight', 'Collider', 'Fork']

    # Identify all unshielded triplets (this also symmetrizes G in place).
    T = get_triplets(G, nodes)
    if verbose:
        print(T)

    T_config = [list(C) for _ in T]
    G_dir = G.copy()
    # The marginal test for a given (A, C) pair is deterministic; fit it once.
    dep_cache = {}

    has_changed = True
    while has_changed:
        has_changed = False
        for i, Ti in enumerate(T):
            configs = T_config[i]
            if verbose:
                print('{}: {}'.format(Ti, configs))

            len_cfg = len(configs)
            Tidx = [nodes.index(Ti[0]), nodes.index(Ti[1]), nodes.index(Ti[2])]

            # 1. Remove configurations contradicted by existing arrows.
            for (a, b), excluded in _ARROW_EXCLUDES.items():
                if G_dir[Tidx[a], Tidx[b]] and not G_dir[Tidx[b], Tidx[a]]:
                    _discard(configs, *excluded)

            # 2. Test the endpoints for independence while an edge is undirected.
            if (G_dir[Tidx[0], Tidx[1]] and G_dir[Tidx[1], Tidx[0]]) or \
               (G_dir[Tidx[1], Tidx[2]] and G_dir[Tidx[2], Tidx[1]]):
                pair = (Ti[0], Ti[2])
                if pair not in dep_cache:
                    dep_cache[pair] = get_dependence(sample, Ti[0], Ti[2])
                dep = dep_cache[pair]
                if verbose:
                    print('{},{}: Dependent? {}'.format(Ti[0], Ti[2], dep))
                if dep:
                    _discard(configs, 'Collider')
                else:
                    # Marginally independent endpoints: this is a collider.
                    _discard(configs, 'ChainRight', 'ChainLeft', 'Fork')

            # 3. A single surviving configuration fixes both arrows.
            if len(configs) == 1:
                chosen = configs[0]
                for a, b in _CONFIG_REMOVES[chosen]:
                    G_dir[Tidx[a], Tidx[b]] = False
                if verbose:
                    print('{} at {}: {}'.format(chosen, Ti, Tidx))

            # Accumulate across triplets: another pass is needed if ANY
            # triplet changed, not just the last one visited.
            has_changed = has_changed or len_cfg != len(configs)

    return G_dir

# Returns a list of node triplets [A, B, C] in G where A-B and B-C
# are edges but A-C is not (B is the middle node). G is symmetrized
# in place from its upper triangle.
def get_triplets(G, nodes):
    """Return all unshielded triplets of the graph ``G``.

    A triplet is a list ``[A, B, C]`` of node names such that A-B and B-C are
    edges but A-C is not. B is the middle node, and A precedes C in B's
    neighbour order.

    Side effect: ``G`` is symmetrized IN PLACE by copying its upper triangle
    onto its lower triangle, so the upper triangle is treated as
    authoritative. This is kept because callers may copy ``G`` after this call
    and rely on the symmetric result.

    Args:
        G: square numpy adjacency matrix (truthy entry = edge).
        nodes: node names, where ``nodes[k]`` labels row/column ``k`` of G.

    Returns:
        list of ``[A, B, C]`` name lists.
    """
    N_nodes = len(nodes)

    # Make G symmetric (upper triangle wins).
    upper_i, upper_j = np.triu_indices(N_nodes, k=1)
    G[upper_j, upper_i] = G[upper_i, upper_j]

    T = []
    for i in range(N_nodes):
        neighbours = np.flatnonzero(G[i, :])
        for a, b in itertools.combinations(neighbours, 2):
            # The A-C check must use the neighbours' node indices, not
            # their positions within `neighbours`.
            if not G[a, b]:
                T.append([nodes[a], nodes[i], nodes[b]])

    return T

# Determines whether binary variables node1 and node2 in sample 
# are dependent, given a p-value threshold pthres.
# sample must be a Pandas table with node1 and node2 as
# column names.
def get_dependence(sample, node1, node2, pthres=0.01):
    """Return True if ``node1`` and ``node2`` look marginally dependent.

    Fits the logistic regression ``node1 ~ node2`` on ``sample`` and compares
    the p-value of ``node2``'s coefficient with ``pthres``. A p-value at or
    below the threshold counts as dependence.

    Args:
        sample: pandas DataFrame with columns named ``node1`` and ``node2``.
        node1: response column name.
        node2: predictor column name.
        pthres: p-value threshold.
    """
    fitted = logit(f'{node1} ~ {node2}', sample).fit(disp=0)
    coefficients = fitted.summary2().tables[1]
    p_value = coefficients.loc[node2, 'P>|z|']
    return p_value <= pthres