#!/usr/bin/env python3

import pandas as pd
import sys
import os
import numpy as np


import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FormatStrFormatter




def main():
    if len(sys.argv) < 2:
        print("Usage: python process_mission.py results_dict")
        print("       must be at least one results dict")
        sys.exit(1)

    # plot it
    clrs = ['blue', 'navy', 'blue', 'darkviolet', 'dodgerblue']
    print(clrs[2])
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    plt.rcParams["font.family"] = "serif"
    #plt.rcParams["font.weight"] = "bold"
    plt.rcParams["xtick.labelsize"] = 22
    plt.rcParams["ytick.labelsize"] = 22
    plt.rcParams["axes.titlesize"] = 22
    plt.rcParams["axes.labelsize"] = 30
    plt.rcParams["legend.fontsize"] = 14

    ax.hlines(y=1-0.707, xmin=0, xmax=100, color='r', linestyles='--',
              label='Ave of 100 simulations (Kemna. et al)')
    #ax.hlines(y=0.558, xmin=0, xmax=len(batch_nub)-1, linestyles='--', color='r')
    #ax.hlines(y=0.833, xmin=0, xmax=len(batch_nub)-1, linestyles='--', color='r')
    
    ax.fill_between(range(0,100), 1-0.558, 1-0.833, alpha=0.3, color='r', label='Min/Max (Kemna. et al)')

        
    for j in range(1,len(sys.argv)):
        print(j)
        filename = sys.argv[j]
        #   print(filename)

        param_dict = {}
        try:
            results = np.load(filename, allow_pickle=True)

            for key, value in results.items():
                batch_start_indx = value.item().get('cent.csv').get('batch_start_indx')
                costs = value.item().get('cent.csv').get('costs')

        
        except FileNotFoundError:
            print("Error: File '{filename}' not found.")


        
        # print the average costs
        batch_nub = []
        batch_mean = []
        batch_var = []
        batch_std = []
        batch_max = []
        batch_min = []
        print(' -------------------------------------------')
        print(filename)
        print(' -------------------------------------------')
        batches =  len(batch_start_indx)
        for i in range(batches):
            print(' batch = ' + str(i))
            batch_start  = batch_start_indx[i]

            if (i == batches - 1):
                batch_end = -1
            else:
                batch_end  = batch_start_indx[i+1] -1
            print(' batch start = ' + str(batch_start) + ', batch end = ' + str(batch_end))

            # TODO:  Remove the problematic starting to one run. 
            if ((batch_start < 4) and (filename == 'ok_runs/take_four/results_dict.npz')):
                continue
            
            costs_this_batch = costs[batch_start:batch_end]
            average_costs = np.mean(costs_this_batch)
            variance_costs  = np.var(costs_this_batch)
            max_cost = np.max(costs_this_batch)
            min_cost = np.min(costs_this_batch)

            
            print(' ave cost = ' + str(average_costs) + ' variance = ' + str(variance_costs))

            # Hack
            if (filename == 'ok_runs/take_four/results_dict.npz'):
                batch_nub.append(i-4)
            else:
                batch_nub.append(i)
            batch_mean.append(average_costs)
            batch_var.append(variance_costs)
            batch_std.append(np.sqrt(variance_costs))
            batch_min.append(min_cost)
            batch_max.append(max_cost)


        batch_nub = np.array(batch_nub)
        batch_mean = np.array(batch_mean)
        batch_var = np.array(batch_var)
        batch_std = np.array(batch_std)
        batch_min = np.array(batch_min)
        batch_max = np.array(batch_max)


  
        if (clrs[j] == 'blue'):
            ax.plot(batch_nub,  batch_mean/2.0, clrs[j], label='Ours')
        else:
            ax.plot(batch_nub,  batch_mean/2.0, clrs[j])
    
        #ax.fill_between(batch_nub, batch_mean - 3 * batch_std, batch_mean + 3 * batch_std, alpha=0.3)
        ax.fill_between(batch_nub, batch_min/2.0, batch_max/2.0, alpha=0.1, color='b')

    
    ax.set_ylabel("Cost\n(% of areas not sampled)",  fontsize=16)
    ax.set_xlabel("Mini-Batch",  fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    # set the limits
    
    ax.set_xlim([0,30])
    ax.set_ylim(bottom=0.0, top=1.0)
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    #ax.xaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    ax.xaxis.set_ticks(np.arange(0, 31, 5.0))
    #ax.set_title('line plot with data points')
    ax.legend()
    ax.grid()
    
    fig.set_size_inches(10, 5)
    fig.subplots_adjust(bottom=0.2)
    #fig.subplots_adjust(left=0.14)
    # display the plot
    plt.show()
    fig.savefig('loss.pdf')
    
 
    quit()

        
            
if __name__ == "__main__":
    # Run the plotting script only when executed directly, not on import.
    main()
