#!/usr/bin/env python


"""
This program orders prerequisite atoms in an attempt to maximise how early
goals are achievable.

The input file is a list of goal / atom pairs, each line of the form:

<goal_id> <atom_id>

A goal will have as many lines as it has prerequisites. Atoms may be
prerequisites for more than one goal. Duplicate lines are allowed but
don't make a difference.

"""



import random, math, logging, sys



class Goal:
    """
    A goal together with the atoms that are its prerequisites.

    Equality and hashing are based on goal_id alone; prereqs takes no part
    in comparisons, so two Goal objects with the same goal_id collapse to
    one entry when placed in a set or used as a dict key.
    """


    def __init__(self, goal_id, prereqs):
        """
        goal_id -- identifier of the goal (used for hashing / equality)
        prereqs -- collection of atom ids required by the goal
        """
        self.goal_id = goal_id
        self.prereqs = prereqs


    def __hash__(self):
        return hash(self.goal_id)


    def __eq__(self, other):
        # Objects without a goal_id are simply "not equal" rather than
        # an AttributeError; NotImplemented lets Python fall back to its
        # default comparison.
        try:
            other_id = other.goal_id
        except AttributeError:
            return NotImplemented
        return self.goal_id == other_id


    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly to
        # keep the two operators consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


    def __repr__(self):
        return "Goal(%r, %r)" % (self.goal_id, self.prereqs)
        


class Goals:
    """
    A collection of goals plus lookup tables between goals and the atoms
    they require.

    load() fills self.goals from a file; calculate_lookups() must then be
    called before can_know(), freq_order() or goal_order_from_atom_list(),
    because those methods read the tables it builds.
    """


    def __init__(self):
        self.goals = set()


    def load(self, filename):
        """
        replace self.goals with the goals described in the given file.

        Each non-blank line must hold exactly two whitespace-separated
        fields, "<goal_id> <atom_id>". Blank lines are skipped and
        duplicate lines collapse, since prerequisites are kept in a set.

        Raises ValueError (naming the file and line number) for a line
        with any other number of fields.
        """

        logging.info("loading input file: %s", filename)

        prereqs_by_goal = {}

        line_count = 0
        # open() exists on both Python 2 and 3, and the with-block makes
        # sure the file is closed even if a line is malformed.
        with open(filename) as infile:
            for line in infile:
                line_count += 1
                fields = line.split()
                if not fields:
                    # tolerate blank lines, e.g. a trailing empty line
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        "%s line %d: expected '<goal_id> <atom_id>', got %r"
                        % (filename, line_count, line))
                goal, prereq = fields
                prereqs_by_goal.setdefault(goal, set()).add(prereq)

        logging.info("loaded %s lines", line_count)

        # reset
        self.goals = set()

        prereq_count = 0
        for goal in prereqs_by_goal:
            prereq_count += len(prereqs_by_goal[goal])
            self.goals.add(Goal(goal, prereqs_by_goal[goal]))

        logging.info("unique prereq statements: %s", prereq_count)


    def calculate_lookups(self):
        """
        build the lookup tables from self.goals:

        self.p2g       -- atom -> set of goals that require it
        self.g2p       -- goal -> set of atoms it requires
        self.atoms     -- list of all atoms that appear as a prerequisite
        self.num_atoms -- len(self.atoms)
        self.num_goals -- len(self.goals)
        """
        p2g = {}
        g2p = {}

        for goal in self.goals:
            for prereq in goal.prereqs:
                p2g.setdefault(prereq, set()).add(goal)
                g2p.setdefault(goal, set()).add(prereq)

        self.p2g = p2g
        self.g2p = g2p

        # A real list (not a Python 3 dict view) so that it can be copied
        # and sorted by freq_order().
        self.atoms = list(p2g)

        self.num_atoms = len(self.atoms)
        self.num_goals = len(self.goals)

        logging.info("goal count = %s, atom count = %s", self.num_goals,
            self.num_atoms)


    def can_know(self, known):
        """
        return a set of the goals that can be known if you know the given
        set of atoms.

        A goal without any prerequisites has no entry in self.g2p; it is
        treated as requiring nothing and is therefore always included.
        """
        can_know = set()
        for goal in self.goals:
            if known.issuperset(self.g2p.get(goal, ())):
                can_know.add(goal)

        return can_know


    def freq_order(self):
        """
        return list of atoms in order of frequency they appear as
        prerequisites, most frequent first.
        """
        def freq(x):
            return len(self.p2g[x])

        # sorted() returns a new list, so self.atoms is left untouched
        return sorted(self.atoms, key=freq, reverse=True)


    def goal_order_from_atom_list(self, atom_list):
        """
        return the goals in the order they become knowable when the atoms
        are learnt in the order given by atom_list.

        Goals that become knowable on the same atom are appended in
        arbitrary (set iteration) order. Goals that never become knowable
        from atom_list are not included.
        """

        goal_list = []
        known = set()
        old_can = set()

        for atom in atom_list:
            # grow the known set incrementally rather than rebuilding it
            # from a slice of atom_list on every step
            known.add(atom)
            can = self.can_know(known)
            goal_list.extend(can - old_can)
            old_can = can
        return goal_list



    def display_programme(self, stream, goal_list):
        """
        write a learning programme for goal_list to stream.

        For each goal in order, a "learn <atom>" line is written for every
        prerequisite not already learnt, followed by "know <goal_id>".
        The atoms learnt for a single goal are written in sorted order so
        that the output is reproducible between runs.
        """

        known_atoms = set()
        for goal in goal_list:
            to_learn = goal.prereqs - known_atoms
            for atom in sorted(to_learn):
                stream.write("learn %s\n" % atom)
            stream.write("know %s\n" % goal.goal_id)
            known_atoms.update(to_learn)
        


class Scorer:
    """
    Scores an ordering of goals against a Goals collection.

    The collection is only consulted for its num_goals attribute.
    """


    def __init__(self, goals):
        self.goals = goals


    def calc_score(self, goal_list):
        """
        return the score of the given ordering of goals.

        Walking the list in order, every atom that has not been needed by
        an earlier goal contributes (position of the goal) / num_goals to
        the total, so atoms first needed later in the list are worth more.
        """

        total = 0.0
        learned = set()
        position = 0

        for goal in goal_list:
            weight = float(position) / self.goals.num_goals
            fresh = goal.prereqs - learned
            # one addition per newly needed atom; kept as repeated
            # addition so the floating point result is unchanged
            for _ in fresh:
                total += weight
            learned |= fresh
            position += 1

        return total



class SimulatedAnnealing:
    """
    Searches for a high scoring ordering of a list by repeatedly swapping
    two random positions, using the scorer's calc_score() to judge each
    candidate ordering (a higher score is better).
    """


    def __init__(self, scorer):
        self.scorer = scorer


    def swap(self, l, i, j):
        """
        return a copy of l with the items at positions i and j exchanged;
        l itself is not modified.
        """
        n = l[:]
        n[i], n[j] = n[j], n[i]
        return n


    def go(self, goal_list, initial_temp, final_temp, iterations, alpha):
        """
        anneal goal_list and return a (score, ordering) tuple.

        goal_list    -- starting ordering (not modified)
        initial_temp -- starting temperature
        final_temp   -- annealing stops once the temperature is no longer
                        above this value
        iterations   -- number of candidate swaps tried per temperature
        alpha        -- factor the temperature is multiplied by after each
                        round of iterations

        A candidate that scores higher is always accepted; otherwise it is
        accepted with probability exp((new - current) / temp).

        A list with fewer than two items cannot be reordered and is
        returned unchanged with its score.

        Raises ValueError if alpha >= 1 while the starting temperature is
        positive and above final_temp, because the temperature would then
        never fall and the search would never finish.
        """

        logging.info("starting simulated annealing")
        logging.info("temp: %s -> %s iterations: %s alpha: %s",
            initial_temp, final_temp, iterations, alpha)

        num_goals = len(goal_list)

        # The score of the current ordering only changes when a candidate
        # is accepted, so track it instead of recomputing it every time.
        current_score = self.scorer.calc_score(goal_list)

        if num_goals < 2:
            # nothing to swap (and randrange(0, 0) would raise)
            return current_score, goal_list

        temp = initial_temp

        if temp > final_temp and temp > 0 and alpha >= 1:
            raise ValueError(
                "alpha must be less than 1 for the temperature to fall "
                "(got %r)" % (alpha,))

        while temp > final_temp:
            logging.info("temp = %s,  score = %s", temp, current_score)
            for _ in range(iterations):
                pos_1 = random.randrange(0, num_goals)
                pos_2 = random.randrange(0, num_goals)
                new_list = self.swap(goal_list, pos_1, pos_2)
                new_score = self.scorer.calc_score(new_list)
                # random.random() is only drawn when the candidate is not
                # an improvement
                if (new_score > current_score or
                        random.random() <
                        math.exp((new_score - current_score) / temp)):
                    goal_list = new_list
                    current_score = new_score
            temp = temp * alpha
        return current_score, goal_list



def usage():
    """
    print the command line usage message to standard output.
    """
    # Written as a call with a single string argument so that it prints
    # the same text on both Python 2 and Python 3.
    print("""
    learning.py <input-file> <output-file>
    """)


def configure_logging(logfile):
    """
    attach two handlers to the root logger: one writing to sys.stderr and
    one appending to the file named by logfile. Both use a timestamped
    "<time>: <message>" format, and the root level is set to INFO.
    """

    # create both handlers first so nothing is attached if the log file
    # cannot be opened
    handlers = (
        logging.StreamHandler(sys.stderr),
        logging.FileHandler(logfile, "a"),
    )
    layout = logging.Formatter("%(asctime)s: %(message)s")

    root = logging.root
    for handler in handlers:
        handler.setFormatter(layout)
    for handler in handlers:
        root.addHandler(handler)
    root.setLevel(logging.INFO)
    


if __name__ == "__main__":

    # Script entry point: expects exactly two arguments, the input file of
    # "<goal_id> <atom_id>" lines and the file to write the programme to.
    # The count is compared with != because "is not" tests object identity,
    # which is not a reliable way to compare integers.
    if len(sys.argv) != 3:
        usage()
    else:
        input_file = sys.argv[1]
        output_file = sys.argv[2]

        configure_logging("log2")

        logging.info("===========================")

        goals = Goals()
        goals.load(input_file)
        goals.calculate_lookups()

        scorer = Scorer(goals)

        # baseline: score of the ordering derived from atom frequency
        logging.info("frequency order score: %s",
            scorer.calc_score(goals.goal_order_from_atom_list(goals.freq_order())))

        # annealing schedule
        initial_temp = 1
        final_temp   = 0.00001
        iterations   = 200
        alpha        = 0.9

        initial_list = list(goals.goals)

        logging.info("opening output file: %s", output_file)

        # open now so if there's a problem you find out before annealing;
        # the with-block closes the file even if annealing fails
        with open(output_file, "w") as out:

            sa = SimulatedAnnealing(scorer)
            final_score, programme = sa.go(initial_list, initial_temp, final_temp,
                              iterations, alpha)

            goals.display_programme(out, programme)

        logging.info("END SESSION. Programme with score %s written to %s", final_score, output_file)
