#!/usr/bin/env python3

import pandas as pd
import numpy as np


def getParam(param_name, row_indx, df_in):
    """Return one parameter value from a given row as a float.

    The column is chosen by substring match: the first column (in column
    order) whose name contains ``param_name`` is used.

    Args:
        param_name: Substring to look for in the column names.
        row_indx: Index label of the row to read (compared with ``==``).
        df_in: DataFrame to search.

    Returns:
        ``float(value)`` from the first row whose index label equals
        ``row_indx`` and the first matching column.  If no such row or no
        such column exists, an error is printed and ``0.0`` is returned.
    """
    # Walk the index labels only; building a Series per row (iterrows) is
    # wasted work when a single cell is wanted.
    for row_pos, index in enumerate(df_in.index):
        if index != row_indx:
            continue
        for column_pos, column_name in enumerate(df_in.columns):
            if param_name in column_name:
                # Positional access is safe with duplicate labels.
                return float(df_in.iat[row_pos, column_pos])
        # Every row has the same columns, so another row with the same
        # label cannot match either.
        break
    print('ERROR: Could not find entry in df that contains ' + str(param_name))
    return 0.0



def _readRowParams(row, param_keys):
    """Collect one value per parameter key from a row and check agreement.

    A column belongs to the first key in ``param_keys`` that is a substring
    of its name.  Returns ``(values, consistent)`` where ``values`` lists
    the first value seen for each key (``0.0`` if the key never appears)
    and ``consistent`` is False if any later column for the same key held
    a different value.
    """
    first_seen = {}
    consistent = True
    for column_name, value in row.items():
        for key in param_keys:
            if key in column_name:
                number = float(value)
                if key not in first_seen:
                    first_seen[key] = number
                elif number != first_seen[key]:
                    consistent = False
                break  # first matching key wins
    return [first_seen.get(key, 0.0) for key in param_keys], consistent


def findBatchStartAndChecks(df_in):
    """Find the rows where a new batch starts and validate the parameters.

    For every row, the columns whose names contain ``ALL_INPUT_GAIN``,
    ``SEARCH_VALUE_IMPRV_GAIN``, ``BATT_EXH_GAIN`` or ``EXPLORE_BIAS`` are
    read (presumably one column per vehicle, going by the messages printed
    here).  All columns for the same parameter must hold the same value
    within a row; a row that violates this is reported and skipped.  A
    consistent row starts a new batch when any of its four values differs
    from the previous batch start; the first consistent row always starts
    a batch.

    Args:
        df_in: DataFrame with one run per row.

    Returns:
        ``(batch_start_indx, all_vehicles_same_params)``: the list of index
        labels where a batch starts, and a bool that is True only if every
        row was consistent.  An empty DataFrame yields ``([], True)``.
    """
    param_keys = (
        'ALL_INPUT_GAIN',
        'SEARCH_VALUE_IMPRV_GAIN',
        'BATT_EXH_GAIN',
        'EXPLORE_BIAS',
    )
    batch_start_indx = []
    # Accumulated over all rows, so one bad row anywhere is reported to the
    # caller instead of being overwritten by later rows.
    all_vehicles_same_params = True
    last_params = None  # parameters of the most recent batch start

    for index, row in df_in.iterrows():
        this_row_params, row_consistent = _readRowParams(row, param_keys)

        if not row_consistent:
            all_vehicles_same_params = False
            print('ERROR: found a run where vehicles did not have the same params!')
            print('       Bad run is index = ' + str(index))
            continue

        # Compare element by element so NaN never equals a previous NaN.
        new_batch_started_this_row = last_params is None or any(
            new != old for new, old in zip(this_row_params, last_params)
        )
        if new_batch_started_this_row:
            print('Found new batch start at index = ' + str(index))
            batch_start_indx.append(index)
            last_params = this_row_params

    return batch_start_indx, all_vehicles_same_params


def computeCosts(df_in, *, dist_per_sample_weight=0.0, odom_variance_weight=0.0,
                 unobserved_blooms_weight=0.0, unsampled_blooms_weight=2.0):
    """Compute one scalar cost per row of ``df_in``.

    Columns are picked by substring match on their names, checked in this
    order: ``ODOM`` and ``NUMBER_OF_SAMPLED_BLOOMS`` values are collected
    from every matching column; for ``DETECTED_BLOOMS_COMPLETED``,
    ``SAMPLED_BLOOMS_COMPLETED`` and ``UNSAMP_UNDETECT_BLOOMS_COMPLETED``
    the last matching column wins.

    The cost is a weighted sum of up to four terms:

    * distance per sample: sum of odoms / sum of samples (a zero sample
      count is replaced by 1 to avoid dividing by zero),
    * variance of the odoms (0.0 when the row has no ODOM column),
    * fraction of unobserved blooms: unsampled-and-undetected / total,
    * fraction of unsampled blooms: detected / total,

    where total is the sum of the three ``*_BLOOMS_COMPLETED`` values
    (replaced by 1 when zero).  A term whose weight is zero is skipped
    entirely.  With the default weights only the unsampled-bloom fraction
    contributes, scaled by 2.0.

    Args:
        df_in: DataFrame with one run per row.
        dist_per_sample_weight: Multiplier for the distance-per-sample term.
        odom_variance_weight: Multiplier for the odom-variance term.
        unobserved_blooms_weight: Multiplier for the unobserved fraction.
        unsampled_blooms_weight: Multiplier for the unsampled fraction.

    Returns:
        List of float costs, one per row, in row order.
    """
    costs = []

    for _, row in df_in.iterrows():
        odoms = []
        samples = []
        detected_blooms_completed = 0.0
        sampled_blooms_completed = 0.0
        unsamp_undetect_blooms_completed = 0.0

        for column_name, value in row.items():
            if 'ODOM' in column_name:
                odoms.append(float(value))
            elif 'NUMBER_OF_SAMPLED_BLOOMS' in column_name:
                samples.append(float(value))
            elif 'DETECTED_BLOOMS_COMPLETED' in column_name:
                detected_blooms_completed = float(value)
            elif 'SAMPLED_BLOOMS_COMPLETED' in column_name:
                sampled_blooms_completed = float(value)
            elif 'UNSAMP_UNDETECT_BLOOMS_COMPLETED' in column_name:
                unsamp_undetect_blooms_completed = float(value)

        total_blooms = (detected_blooms_completed
                        + sampled_blooms_completed
                        + unsamp_undetect_blooms_completed)
        if total_blooms == 0:
            total_blooms = 1.0  # avoid dividing by zero

        this_cost = 0.0

        if dist_per_sample_weight:
            total_samples = sum(samples)
            if total_samples == 0:
                # Avoid dividing by zero; the term is large anyway.
                total_samples = 1.0
            this_cost += (sum(odoms) / total_samples) * dist_per_sample_weight

        if odom_variance_weight:
            # np.var([]) warns and yields NaN, so treat "no odoms" as 0.
            odom_variance = float(np.var(odoms)) if odoms else 0.0
            this_cost += odom_variance * odom_variance_weight

        if unobserved_blooms_weight:
            percent_unobserved_blooms = float(
                unsamp_undetect_blooms_completed / total_blooms)
            this_cost += percent_unobserved_blooms * unobserved_blooms_weight

        if unsampled_blooms_weight:
            percent_unsampled_blooms = float(
                detected_blooms_completed / total_blooms)
            this_cost += percent_unsampled_blooms * unsampled_blooms_weight

        costs.append(this_cost)

    return costs
