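# Test the stability of discovered networks: for each of N_size sample sizes,
# draw N_samp samples from the ground-truth network, rediscover the network
# from each sample, and measure its agreement with the ground truth.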
import json
from stat_functions import generate_sample
from graph_functions import prune_graph
from graph_functions import assign_causality
from graph_io import network_as_matrix
import numpy as np
import pandas as pd
import warnings

# ---- Configuration --------------------------------------------------------
# Ground-truth network to load, and CSV file receiving per-size results.
network_file = 'smoking_cancer.json'
output_file = 'smoking_stability_tests.csv'

# Sampling design: N_samp accepted samples for every size listed in `sizes`.
N_samp = 250
sizes = range(200, 2001, 200)
N_size = len(sizes)

# Threshold handed to prune_graph (presumably a p-value cutoff -- confirm).
pthres = 0.01

# Per-sample cap on rank-deficient draws before the run is aborted.
max_iter = 5

# Diagnostics: `verbose` prints progress; `debug` is forwarded to the
# graph-discovery routines as their `verbose` flag.
verbose = True
debug = False

# Load the ground-truth network from its JSON file. Only the first top-level
# key is used: it names the network, and its value maps node names to their
# definitions (the keys become the node list).
with open(network_file) as fh:
    raw = json.load(fh)

network_name = list(raw)[0]
network = raw[network_name]
nodes = list(network)
N_nodes = len(nodes)

print('Loaded {} with {} nodes.'.format(network_name, N_nodes))
size_labels = ' '.join(str(s) for s in sizes)
print('Evaluating N = {}'.format(size_labels))

# Matrix form of the ground truth (presumably an adjacency matrix -- confirm
# against network_as_matrix), compared element-wise with discovered graphs.
G_true = network_as_matrix(network)

# Per-size results: Agreement[N] holds, for each accepted sample, the fraction
# of ground-truth edges recovered; Skipped[N] counts discarded draws
# (rank-deficient samples or samples where prune_graph failed).
Agreement = {}
Skipped = {}

# Number of edges (non-zero entries) in the ground truth. np.count_nonzero
# counts them; np.flatnonzero would return their flat *indices* instead.
denom = np.count_nonzero(G_true)
if denom == 0:
    raise ValueError('Ground-truth network has no edges; agreement is undefined.')

# Escalate every warning to an exception (process-wide). Any warning raised
# inside prune_graph is then caught by the handler below and the sample is
# skipped rather than accepted.
warnings.filterwarnings("error")

for N in sizes:
    if verbose:
        print('Evaluating N={}'.format(N))

    Agreement[N] = np.zeros(N_samp)
    Skipped[N] = 0
    i = 0
    while i < N_samp:

        # Draw a sample, accepting it only if it has full rank (rank equal to
        # the number of nodes). Abort once a single sample has produced more
        # than max_iter rank-deficient draws.
        itr = 0
        while True:
            sample = generate_sample(network, N)
            rank = np.linalg.matrix_rank(sample)
            if rank == N_nodes:
                break
            Skipped[N] += 1
            itr += 1  # count each failed draw so the max_iter cap is reachable
            if itr > max_iter:
                raise RuntimeError('Max of {} singular matrices produced. Aborting.'.format(max_iter))

        sample = pd.DataFrame(data=sample, columns=nodes)

        # Discover the graph. Any failure (including escalated warnings) is
        # treated as "no solution": the draw is counted as skipped and the
        # same slot i is refilled by a fresh sample.
        try:
            pruned = prune_graph(sample, pthres, verbose=debug)
        except Exception as ex:
            if debug:
                print('Warning: Solution not found for this sample')
                print('  {!r}'.format(ex))
            Skipped[N] += 1
            continue
        G = assign_causality(sample, pruned['Final'], nodes, verbose=debug)

        # Fraction of ground-truth edges that are also present in G.
        Agreement[N][i] = np.count_nonzero(np.logical_and(G, G_true)) / denom

        i += 1

    if verbose:
        print('Agreement M={:0.2f}, S={:0.3f}; skipped {}'.format(np.mean(Agreement[N]),
                                                                  np.std(Agreement[N]),
                                                                  Skipped[N]))

# Emit the summary CSV: a header, then one row per sample size containing the
# mean (3 d.p.) and standard deviation (5 d.p.) of the agreement scores and
# the number of discarded draws.
header = 'N,mean_agree,std_agree,skipped\n'
row_fmt = '{},{:0.3f},{:0.5f},{}\n'
with open(output_file, 'w') as fout:
    fout.write(header)
    for size in sizes:
        scores = Agreement[size]
        fout.write(row_fmt.format(size, np.mean(scores), np.std(scores), Skipped[size]))

if verbose:
    print('Results written to {}'.format(output_file))