#!/usr/bin/env python3

import pandas as pd
import sys
from os import listdir
from os.path import isfile, join
import numpy as np

from process_mission_helpers import findBatchStartAndChecks, computeCosts, getParam


def _load_results(data_dir):
    """Load every results CSV in *data_dir* and return the usable ones.

    Each ``<name>.csv`` is sorted by its ``eval_<name>_job`` column, checked
    with ``findBatchStartAndChecks`` and scored with ``computeCosts``.

    Returns:
        dict mapping file name to a dict with keys ``'df'``,
        ``'batch_start_indx'``, ``'all_good'`` and ``'costs'``. Files that
        cannot be read, lack the job column, or fail the checks are reported
        on stdout and left out.
    """
    results_dict = {}
    files_to_process = [f for f in listdir(data_dir) if isfile(join(data_dir, f))]

    for filename in files_to_process:
        print(filename)

        # The job column name is derived by stripping a 4-character ".csv"
        # suffix, so anything else in the directory cannot be processed.
        if not filename.lower().endswith('.csv'):
            print('Skipping non-CSV file: ' + filename)
            continue
        job_column = 'eval_' + filename[:-4] + '_job'

        try:
            df = pd.read_csv(join(data_dir, filename))
        except FileNotFoundError:
            print("Error: File not found.")
            continue

        if job_column not in df.columns:
            print('Error: This file has no column named ' + job_column)
            continue

        # Order the rows by the job column and renumber them so that the
        # batch start indexes below are positions in the sorted frame.
        df.sort_values(by=job_column, inplace=True)
        df.reset_index(drop=True, inplace=True)

        # Get the batch start indexes and check that all vehicles had the
        # same params.
        batch_start_indx, all_vehicles_same_params = findBatchStartAndChecks(df)
        if not all_vehicles_same_params:
            print('Error: This file did not pass the checks: ' + job_column)
            continue

        results_dict[filename] = {
            'df': df,
            'batch_start_indx': batch_start_indx,
            'all_good': all_vehicles_same_params,
            # costs/rewards for each row in df
            'costs': computeCosts(df),
        }

    return results_dict


def _print_batch_stats(cent):
    """Print the mean and variance of the costs of every batch in *cent*."""
    starts = cent['batch_start_indx']
    costs = cent['costs']
    num_batches = len(starts)

    for i in range(num_batches):
        # A batch runs up to (not including) the start of the next batch;
        # the last batch runs to the end. Slice ends are exclusive, so no
        # "- 1" is needed (it would drop the final sample of each batch).
        batch_end = starts[i + 1] if i + 1 < num_batches else None
        costs_this_batch = costs[starts[i]:batch_end]

        average_costs = np.mean(costs_this_batch)
        variance_costs = np.var(costs_this_batch)
        print(' batch = ' + str(i))
        print(' ave cost = ' + str(average_costs) + ' variance = ' + str(variance_costs))


def _estimate_slope(param, param_val_at_cent, cent_ave_costs, results_dict):
    """Estimate d(cost)/d(param) from the runs in which *param* was perturbed.

    For every run whose last-batch value of *param* differs from
    *param_val_at_cent*, a finite-difference slope is formed from the
    difference in average last-batch cost. The slopes are combined as a
    weighted mean, each weighted by the inverse of that run's cost variance.

    Returns:
        The weighted mean slope, or 0.0 when no perturbed run was found.
    """
    slopes = []
    slope_vars = []

    for file_name, results in results_dict.items():
        # Where did the last batch start in this run?
        last_batch_start = results['batch_start_indx'][-1]

        # What was the param value in this run, and is it different?
        param_val_this_run = getParam(param, last_batch_start, results['df'])
        delta = param_val_this_run - param_val_at_cent
        if delta == 0.0:
            continue

        print(' Found a perturbation in file ' + file_name + ' at index = ' + str(last_batch_start))

        costs_this_batch = results['costs'][last_batch_start:]
        ave_costs = sum(costs_this_batch) / len(costs_this_batch)
        var_costs = np.var(costs_this_batch)
        if var_costs == 0.0:
            # Avoid dividing by zero when weighting by 1/variance.
            var_costs = 1.0

        slope = (ave_costs - cent_ave_costs) / delta
        print('Found estimated slope to be = ' + str(slope))
        print('Found associated variance to be = ' + str(var_costs))
        print('Using ' + str(len(costs_this_batch)) + ' samples')

        slopes.append(slope)
        slope_vars.append(var_costs)

    if len(slopes) < 2:
        # Reported but deliberately not fatal: the update carries on with
        # whatever slopes are available (a zero slope if there are none).
        print('Error:  Did not find at least two other sets of params with perturbed vals for ' + param)

    num = sum((s / v for s, v in zip(slopes, slope_vars)), 0.0)
    denom = sum((1.0 / v for v in slope_vars), 0.0)
    if denom == 0.0:
        denom = 1.0

    return num / denom


def _clip_relative(value, center, max_change):
    """Clamp *value* to within a fraction *max_change* of *center*.

    The two bounds are ordered explicitly so the clamp is also correct when
    *center* is negative, where ``(1 + max_change) * center`` is the lower
    bound rather than the upper one.
    """
    lower, upper = sorted(((1.0 - max_change) * center, (1.0 + max_change) * center))
    if value > upper:
        return upper
    if value < lower:
        return lower
    return value


def main():
    """Process mission result files and write updated centre parameters.

    Usage: ``python process_mission2.py path_to_results_files``

    Steps:
      1. Load and score every results CSV in the given directory and save the
         collected results to ``results_dict.npz``.
      2. Print per-batch cost statistics for ``cent.csv``.
      3. For each tuned parameter, estimate the cost slope from the perturbed
         runs and take one ADAM step from the value used in ``cent.csv``.
      4. Save the ADAM state to ``adam_states.npz`` and write the new
         parameter values to ``params/params_cent.txt``.

    Exits with status 1 on a usage error or when ``cent.csv`` is unavailable.
    """
    if len(sys.argv) != 2:
        print("Usage: python process_mission2.py path_to_results_files")
        sys.exit(1)

    data_dir = sys.argv[1]
    print('Processing files in ' + data_dir)

    results_dict = _load_results(data_dir)
    np.savez('results_dict.npz', results_dict=results_dict)

    # Everything below is measured relative to the centre run.
    if 'cent.csv' not in results_dict:
        print('Error: cent.csv was not found in ' + data_dir + ' or did not pass the checks')
        sys.exit(1)
    cent = results_dict['cent.csv']

    # print the average costs for fun
    print(' -------------------------------------------')
    _print_batch_stats(cent)

    # now update the parameters
    params_to_update = ['ALL_INPUT_GAIN', 'SEARCH_VALUE_IMPRV_GAIN',
                        'EXPLORE_BIAS']
    new_vals_dict = {}

    # Load the adam states from last time. The archive is closed before it is
    # overwritten further down.
    # NOTE: allow_pickle=True executes pickled content; only load a state
    # file that this script wrote itself.
    with np.load('adam_states.npz', allow_pickle=True) as saved_adam_states:
        adam_param_state = saved_adam_states['adam_param_state'].tolist()
        beta1 = saved_adam_states['beta1']
        beta2 = saved_adam_states['beta2']
        epsilon = saved_adam_states['epsilon']

    # Cost at the centre for the last batch; identical for every parameter.
    cent_last_batch_start_index = cent['batch_start_indx'][-1]
    cent_costs_this_batch = cent['costs'][cent_last_batch_start_index:]
    cent_ave_costs = sum(cent_costs_this_batch) / len(cent_costs_this_batch)

    up_gain = 1.0     # learning rate
    max_change = 0.9  # largest allowed relative change per update

    for param in params_to_update:
        print(' ***************************  ')
        print(' Updating param: ' + param)

        # Step 1. What is the param value at the centre?
        print(' Starting at index = ' + str(cent_last_batch_start_index))
        param_val_at_cent = getParam(param, cent_last_batch_start_index, cent['df'])
        print(' Where the param val at the cent was = ' + str(param_val_at_cent))

        # Step 2. Estimate the slope from the runs where this param was
        # perturbed.
        best_estimate_of_slope = _estimate_slope(
            param, param_val_at_cent, cent_ave_costs, results_dict)
        print('Best estimate of slope ' + str(best_estimate_of_slope))

        # Step 3. One ADAM step using the slope as the gradient.
        state = adam_param_state[param]
        print('Running ADAM with these params')
        print(state)
        step = state['t'] + 1.0
        new_m = beta1 * float(state['m']) + (1.0 - beta1) * best_estimate_of_slope
        new_v = beta2 * float(state['v']) + (1.0 - beta2) * best_estimate_of_slope**2
        # Bias-corrected moment estimates.
        m_hat = new_m / (1.0 - beta1**step)
        v_hat = new_v / (1.0 - beta2**step)

        print('new_m =' + str(new_m))
        print('new_v =' + str(new_v))
        print('m_hat =' + str(m_hat))
        print('v_hat =' + str(v_hat))

        total_update = up_gain * m_hat / (np.sqrt(v_hat) + epsilon)
        print('total update = ' + str(total_update))

        # update param via ADAM, then clip it
        new_val = _clip_relative(param_val_at_cent - total_update,
                                 param_val_at_cent, max_change)
        print('Old val = ' + str(param_val_at_cent) + ', new val = ' + str(new_val))

        # and save it!
        new_vals_dict[param] = new_val

        # and also update the adam states
        state['m'] = new_m
        state['v'] = new_v
        state['t'] = step

        print('Updating ADAM with these params')
        print(state)

    # Now that we are done updating, write out the adam states for next time.
    np.savez('adam_states.npz', adam_param_state=adam_param_state,
             beta1=beta1, beta2=beta2, epsilon=epsilon)

    # Write the new params, one "NAME=value" line per parameter.
    with open("params/params_cent.txt", "w", encoding="utf-8") as f:
        for key, val in new_vals_dict.items():
            val_str = "{:.3f}".format(val)
            f.write(key + '=' + val_str + '\n')

    sys.exit()
    
       


                
# Run the processing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
