#!/usr/bin/python3
import threading
import time
import random
import logging
from influxdb import InfluxDBClient

# Route DEBUG-and-above records to the default stderr handler (basicConfig is a
# no-op when the root logger already has handlers) and share the root logger
# across every class in this module.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(name=None)

class ConfigCollector(threading.Thread):
    """Thread that samples a resource's configuration state and reports aggregates.

    Every ``sample_rate`` seconds ``sample_func()`` is called and must return a
    ``(state_name, sample_time)`` tuple. Once a sample's time reaches the end of
    the current aggregation window (``agg_period`` after the window start), the
    buffered samples are folded into one ``service_config_state`` measurement,
    which is handed to ``write_func`` on a separate ``WriteThread``.

    Running totals are kept across windows in ``agg_states``:
    ``<state>_sum`` is the total time spent in the state (same units as
    ``sample_time``) and ``<state>_count`` is the number of times the state
    was entered.
    """

    # Indexes into a (state_name, time) sample and a [state_name, duration] run.
    STATE_NAME = 0
    STATE_TIME = 1

    def __init__(self, sample_func, write_func, resource_name, sample_rate=2, agg_period=10):
        threading.Thread.__init__(self)
        self._start_event = threading.Event()
        self.sample_func = sample_func
        self.write_func = write_func
        self.resource_name = resource_name
        self.sample_rate = sample_rate
        self.agg_period = agg_period
        # running per-state totals: "<state>_sum" and "<state>_count"
        self.agg_states = {}
        # last measurement produced by create_measurement()
        self.current_measurement = {}
        # state reported as current at the end of the previous window; the first
        # run of the next window continues it and is not a new entry into it
        self._last_state = None

    def run(self):
        """Sample until stop() is called, emitting one measurement per window."""
        # if thread already running then return
        if self._start_event.is_set():
            return
        self._start_event.set()

        start_period = time.time()
        logger.debug("start time = {0}".format(start_period))
        end_period = start_period + self.agg_period
        logger.debug("end time = {0}".format(end_period))
        # time already spent in the current state, carried from window to window
        current_state_time = 0
        samples = []
        while self._start_event.is_set():
            (sample_state, sample_time) = self.sample_func()
            samples.append((sample_state, sample_time))
            logger.debug("Sample state {0}".format(sample_state))
            logger.debug("Sample count: {0}".format(len(samples)))
            # once a sample reaches the end of the window, aggregate and report it
            if sample_time >= end_period:
                self.current_measurement = self.create_measurement(samples, current_state_time, sample_time)
                # write on a worker thread so the database call does not delay sampling
                write_thread = WriteThread(self.write_func, self.current_measurement)
                write_thread.start()
                current_state_time = self.current_measurement[0]['fields']['current_state_time']
                # the closing sample also opens the next window so no time is lost
                samples = [(sample_state, sample_time)]
                end_period = sample_time + self.agg_period
                logger.debug("Number of samples after agg: {0}".format(len(samples)))
                logger.debug("Next end time {0}".format(end_period))

            # sleep for the remainder of the sample interval
            processing_time = time.time() - sample_time
            logger.debug("Processing time {0}".format(processing_time))
            sleep_time = self.sample_rate - processing_time
            logger.debug("Sleep time {0}".format(sleep_time))
            if sleep_time < 0:
                logger.warning("Aggregation processing took longer that sample rate")
                sleep_time = 0
            logger.debug("Sleeping for sample {0}".format(sleep_time))
            time.sleep(sleep_time)
        logger.debug("Finished collection thread")

    def stop(self):
        """Ask the sampling loop to exit; it finishes after its current sleep."""
        logger.debug("Stopping thread")
        self._start_event.clear()

    def create_measurement(self, samples, initial_state_time, current_time):
        """Aggregate one window of samples into an InfluxDB-style point list.

        Args:
            samples: list of (state_name, sample_time) tuples for the window.
            initial_state_time: time already spent in the first sample's state
                before this window (the previous "current_state_time").
            current_time: measurement timestamp in seconds; stored as integer
                nanoseconds.

        Returns:
            A one-element list holding a dict with "measurement", "tags",
            "time" and "fields" keys.

        Raises:
            ValueError: if samples is empty.
        """
        logger.debug("Samples: {0}".format(str(samples)))

        states = self.aggregate_samples(samples)
        logger.debug("States: {0}".format(str(states)))

        fields = self.aggregate_states(states, initial_state_time)
        measurement = [{
            "measurement": "service_config_state",
            "tags": {"resource_name": self.resource_name},
            "time": int(current_time * 1000000000),
            "fields": fields['fields'],
        }]
        logger.debug("Report: {0}".format(str(measurement)))
        return measurement

    def aggregate_samples(self, samples):
        """Collapse consecutive samples into [state_name, duration] runs.

        A run lasts from its first sample to the sample where a different state
        was observed, or to the last sample for the final run. When the last
        sample is itself a transition, the new state is appended with duration
        0. A single sample yields one zero-length run.

        Raises:
            ValueError: if samples is empty.
        """
        sample_count = len(samples)
        logger.debug("Sample count {0}".format(sample_count))
        if sample_count == 0:
            raise ValueError('No samples in the samples list')

        states = []
        current_state = samples[0][self.STATE_NAME]
        state_start_time = samples[0][self.STATE_TIME]
        logger.debug("Start time : {0}".format(state_start_time))
        for sample in samples[1:]:
            if sample[self.STATE_NAME] != current_state:
                # close the previous run at the moment the new state was seen
                states.append([current_state, sample[self.STATE_TIME] - state_start_time])
                current_state = sample[self.STATE_NAME]
                state_start_time = sample[self.STATE_TIME]
        # close the final run at the last sample (0 if the last sample was a transition)
        states.append([current_state, samples[-1][self.STATE_TIME] - state_start_time])
        return states

    def aggregate_states(self, states, initial_state_time):
        """Fold one window's runs into the running per-state totals.

        Args:
            states: list of [state_name, duration] runs from aggregate_samples().
            initial_state_time: time already spent in the first run's state
                before this window.

        Returns:
            {"fields": {...}} holding a snapshot of every "<state>_sum" and
            "<state>_count" total plus "current_state" and "current_state_time".
        """
        initial_state = states[0][self.STATE_NAME]
        logger.debug("Initial state : {0}".format(initial_state))
        logger.debug("Initial state time : {0}".format(initial_state_time))
        current_state = states[-1][self.STATE_NAME]
        # run() carries the last sample into the next window, so the first run
        # usually continues the previous current state: add its time, but do not
        # count it as a new entry into that state.
        continues_previous = initial_state == self._last_state

        for index, state in enumerate(states):
            state_name = state[self.STATE_NAME]
            state_sum_key = state_name + "_sum"
            state_count_key = state_name + "_count"
            if state_sum_key not in self.agg_states:
                logger.debug("Adding state: {0}".format(state_name))
                self.agg_states[state_sum_key] = state[self.STATE_TIME]
                self.agg_states[state_count_key] = 1
            else:
                logger.debug("Aggregating state: {0}".format(state_name))
                if not (index == 0 and continues_previous):
                    self.agg_states[state_count_key] += 1
                    logger.debug("increment number of times in the state")
                self.agg_states[state_sum_key] += state[self.STATE_TIME]
                logger.debug("Duration: {0}".format(self.agg_states[state_sum_key]))

        if len(states) == 1:
            # no transition in this window: extend the time already in the state
            current_state_time = initial_state_time + states[-1][self.STATE_TIME]
        else:
            # a transition happened: the current state began with its last run
            current_state_time = states[-1][self.STATE_TIME]
        self._last_state = current_state

        # snapshot, so the dict handed to the writer thread is not mutated later
        fields = dict(self.agg_states)
        fields['current_state'] = current_state
        fields['current_state_time'] = current_state_time
        return {'fields': fields}

class WriteThread(threading.Thread):
    """Thread whose run() passes one measurement to ``write_func``.

    run() sets an internal flag before calling ``write_func``; while the flag
    is set, further run() calls return without writing. stop() clears it.
    """

    def __init__(self, write_func, measurement):
        super().__init__()
        self.write_func = write_func
        self.measurement = measurement
        self._start_event = threading.Event()

    def run(self):
        """Call write_func(measurement) unless a previous run() already has."""
        if not self._start_event.is_set():
            self._start_event.set()
            self.write_func(self.measurement)

    def stop(self):
        """Clear the run flag so a later run() would write again."""
        self._start_event.clear()

class InfluxWriter:
    """Best-effort writer that sends measurements to an InfluxDB database."""

    def __init__(self, hostname, port, database):
        # bounded request timeout so a stalled server cannot hang a writer
        # thread indefinitely (timeout semantics per influxdb-python's client)
        self.db_client = InfluxDBClient(host=hostname, port=port, database=database, timeout=10)

    def write(self, measurement):
        """Write a point dict, or a list of point dicts, to the database.

        ConfigCollector.create_measurement() produces a *list* of points, which
        is passed through as-is. A single point dict is wrapped in a list.
        ``write_points`` expects a flat list of point dicts, so nesting the
        list again would make every write fail.

        Failures are logged with their traceback and not re-raised. Writing is
        best effort, so a database error must not kill the calling thread.
        """
        try:
            if isinstance(measurement, dict):
                points = [measurement]
            else:
                points = list(measurement)
            self.db_client.write_points(points)
        except Exception:
            logger.exception("Failed to write measurement to InfluxDB")