
"""
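Combinatorial optimization tools for fair division: backtracking search to
find, optimize and count allocations satisfying given tests, and integer
programs (solved with CPLEX through PuLP) computing max-min and min-max
shares, CEEI prices and Pareto-optimality.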
"""

#############################################################################
# Imports ###################################################################

import sys
import itertools
import numpy as np
try:
    import pulp
except ModuleNotFoundError:
    sys.stderr.write(
"""Pulp is not installed. You will not be able to use the solving utilities.
The rest should work anyway.
"""
    )

import fairdiv


#############################################################################
# Constants #################################################################

# CPLEX_PATH (the `path` handed to pulp.CPLEX by the solving utilities) must
# be defined in config.py -- config_template.py can serve as a template.
try:
    from config import CPLEX_PATH
except ImportError:
    # ImportError covers both its subclass ModuleNotFoundError (config.py is
    # missing) and the case where config.py exists but lacks CPLEX_PATH,
    # which raises a plain ImportError and used to escape this handler.
    sys.stderr.write(
"""CPLEX_PATH not configured. You should set this constant in config.py
(you can use config_template.py as a template for config.py).
"""
    )
    sys.exit(1)


#############################################################################
# Backtracking ##############################################################

def find_solution(instance, is_solution, cut=lambda allocation: False):
    """
    Searches, by backtracking, for an allocation of `instance` that
    satisfies `is_solution`.

    Objects are assigned one at a time in index order, trying agents in
    index order. `cut` is called on every partial allocation (internal
    node of the search tree) and prunes the subtree below it when it
    returns True; the default never prunes.

    Returns the first satisfying allocation encountered, or None if there
    is none.
    """
    root = fairdiv.Allocation(instance)
    return _find_solution(root, 0, is_solution, cut)


def _find_solution(allocation, k, is_solution, cut):
    if k == allocation.nb_objects:  # leaf
        if is_solution(allocation):
            return allocation
        return None
    else:  # internal node
        if cut(allocation):
            return None
        for i in range(0, allocation.nb_agents):
            allocation.give(i, k)
            solution = _find_solution(allocation, k + 1, is_solution, cut)
            if solution is not None:  # solution found!
                return solution
            else:  # backtrack
                allocation.take_back(i, k)
    return None


def find_optimal_solution(instance, criterion,
                          maximize=True,
                          cut=lambda allocation: False):
    """
    Runs a backtracking algorithm to find an allocation that optimizes
    `criterion` (maximized by default, minimized when maximize is False).
    The function cut is used to prune branches of the tree (by default,
    does not prune anything).

    The function returns the best allocation found. On ties, the first
    allocation explored is kept. If no complete allocation is ever reached
    (e.g. cut prunes the root), the freshly built
    fairdiv.Allocation(instance) is returned unchanged.
    """
    # Start from the worst possible value so that the first complete
    # allocation always becomes the incumbent. A start of 0 for maximization
    # ignored criteria that are never positive. float("+infty") is not a
    # spelling float() accepts ("inf"/"infinity" are), so it raised ValueError.
    worst = float("-inf") if maximize else float("inf")
    return _find_optimal_solution(fairdiv.Allocation(instance),
                                  0,
                                  criterion,
                                  [worst],
                                  fairdiv.Allocation(instance),
                                  maximize, cut)


def _find_optimal_solution(allocation, k, criterion,
                           best, best_alloc,
                           maximize, cut):
    if k == allocation.nb_objects:  # leaf
        eval = criterion(allocation)
        if (maximize and eval > best[0]) or (not maximize and eval < best[0]):
            best[0] = eval
            best_alloc.pi = np.array(allocation.pi)
    else:  # internal node
        if not cut(allocation):
            for i in range(0, allocation.nb_agents):
                allocation.give(i, k)
                _find_optimal_solution(allocation, k + 1,
                                       criterion, best, best_alloc,
                                       maximize, cut)
                allocation.take_back(i, k)
    return best_alloc


def count_solutions(instance, tests, scale=True, debug=False,
                    cut=lambda allocation: False):
    """
    Runs a backtracking algorithm to count the number of allocations
    that satisfy the list of tests passed in parameters.

    Returns:
    A list of len(tests) + 1 integers. The first integer is the number of
    complete allocations examined (leaves of the search tree actually
    reached, so pruned subtrees are not counted). The ith number (for i > 0)
    is the number of allocations passing test i-1.

    Arguments:
    -- the instance
    -- the list of tests (boolean functions taking an allocation)

    Keyword arguments:
    scale -- if this argument is True, then it means that the tests are
             of increasing strength (hence at a leaf, if a test fails, we
             do not need to do the subsequent tests) -- default True.
    debug -- if True, draws a progress bar on stderr while counting
             -- default False.
    cut   -- a Boolean function indicating whether we can prune a branch
             -- does not prune anything by default.
    """
    counts = [0] * (len(tests) + 1)
    # counts is filled in place; returning our own list (rather than the
    # worker's return value) guarantees a list even if cut prunes the root.
    _count_solutions(fairdiv.Allocation(instance), 0, tests,
                     counts, scale, debug, cut)
    if debug:
        sys.stderr.write('\n')  # end the progress-bar line
    return counts


def _count_solutions(allocation, k, tests, counts, scale, debug, cut):
    if k == allocation.nb_objects:  # leaf
        counts[0] += 1
        if debug:
            _show_progress(counts[0],
                           allocation.nb_agents ** allocation.nb_objects)
        for i, test in enumerate(tests):
            current_test = test(allocation)
            counts[i + 1] += int(current_test)
            if scale and not current_test:
                break
        return counts
    if cut(allocation):  # cut
        return 0
    for i in range(allocation.nb_agents):
        allocation.give(i, k)
        _count_solutions(allocation, k + 1, tests, counts, scale, debug, cut)
        allocation.take_back(i, k)
    return counts


def _show_progress(current, total):
    screen_total = 80
    screen_current = (current * screen_total) // total
    sys.stderr.write('\r|' + ('=' * (screen_current)))
    sys.stderr.write((' ' * (screen_total - screen_current)))
    sys.stderr.write('| {} / {}'.format(current, total))


#############################################################################
# Max-min share #############################################################

def max_min_share(instance, agent, nb_shares=0):
    """
    Returns the utility of the max-min-share for a given agent.

    The objects are split into nb_shares bundles (instance.nb_agents bundles
    when nb_shares is 0) so as to maximize the smallest bundle utility, as
    valued by `agent` with the additive weights instance.w[agent]. The
    integer program is solved with CPLEX through PuLP.

    Raises AssertionError if the solver does not report an optimal solution.
    """
    nb_agents = nb_shares if nb_shares else instance.nb_agents
    nb_objects = instance.nb_objects

    # 1. We define the variables: x[ag][obj] == 1 iff object obj is in share ag
    x = []
    for ag in range(nb_agents):
        x.append(pulp.LpVariable.dicts("x_{}".format(ag), range(nb_objects),
                                       0, 1, pulp.LpInteger))
    # NOTE(review): u_min has lower bound 0, which presumes the agent's
    # weights are nonnegative -- confirm.
    u_min = pulp.LpVariable("u_min", 0, None, pulp.LpContinuous)

    # 2. We define the problem with the objective
    prob = pulp.LpProblem("Max-min share computation", pulp.LpMaximize)
    prob += u_min

    # 3. We define the constraints
    # 3.1 The utility min is lower than the utility of all the shares
    for ag in range(nb_agents):
        prob += (pulp.lpSum([x[ag][obj] * instance.w[agent][obj]
                             for obj in range(nb_objects)]) >= u_min)
    # 3.2 Each object should be allocated only once
    for obj in range(nb_objects):
        prob += (pulp.lpSum([x[ag][obj] for ag in range(nb_agents)]) == 1)

    prob.solve(pulp.CPLEX(path=CPLEX_PATH))

    # An explicit check rather than an `assert` statement. Asserts are
    # removed under `python -O`, which would let a non-optimal value through.
    # AssertionError is kept so that callers catching it keep working.
    if prob.status != pulp.LpStatusOptimal:
        raise AssertionError(
            "The solver was not able to solve the max-min problem optimally...")
    return u_min.varValue


def min_max_share(instance, agent, nb_shares=0):
    """
    Returns the utility of the min-max-share for a given agent.

    The objects are split into nb_shares bundles (instance.nb_agents bundles
    when nb_shares is 0) so as to minimize the largest bundle utility, as
    valued by `agent` with the additive weights instance.w[agent]. The
    integer program is solved with CPLEX through PuLP.

    Raises AssertionError if the solver does not report an optimal solution.
    """
    nb_agents = nb_shares if nb_shares else instance.nb_agents
    nb_objects = instance.nb_objects

    # 1. We define the variables: x[ag][obj] == 1 iff object obj is in share ag
    x = []
    for ag in range(nb_agents):
        x.append(pulp.LpVariable.dicts("x_{}".format(ag), range(nb_objects),
                                       0, 1, pulp.LpInteger))
    # NOTE(review): u_max has lower bound 0, which presumes the agent's
    # weights are nonnegative -- confirm.
    u_max = pulp.LpVariable("u_max", 0, None, pulp.LpContinuous)

    # 2. We define the problem with the objective
    prob = pulp.LpProblem("Min-max share computation", pulp.LpMinimize)
    prob += u_max

    # 3. We define the constraints
    # 3.1 The utility max is greater than the utility of all the shares
    for ag in range(nb_agents):
        prob += (pulp.lpSum([x[ag][obj] * instance.w[agent][obj]
                             for obj in range(nb_objects)]) <= u_max)
    # 3.2 Each object should be allocated only once
    for obj in range(nb_objects):
        prob += (pulp.lpSum([x[ag][obj] for ag in range(nb_agents)]) == 1)

    prob.solve(pulp.CPLEX(path=CPLEX_PATH))

    # An explicit check rather than an `assert` statement. Asserts are
    # removed under `python -O`, which would let a non-optimal value through.
    # AssertionError is kept so that callers catching it keep working.
    if prob.status != pulp.LpStatusOptimal:
        raise AssertionError(
            "The solver was not able to solve the min-max problem optimally...")
    return u_max.varValue


#############################################################################
# CEEI ######################################################################

def better_than(instance, agent, threshold, strict=True):
    """
    Returns a generator over the shares whose utility for `agent`, as
    computed by instance.utility, exceeds `threshold`. The comparison is
    strict when strict is True (the default); otherwise shares equal to
    the threshold are included.

    A share is a tuple of instance.nb_objects booleans, one per object. All
    2 ** nb_objects shares are enumerated, in itertools.product order
    starting from the all-True share.
    """
    candidates = itertools.product((True, False), repeat=instance.nb_objects)
    if not strict:
        return (bundle for bundle in candidates
                if instance.utility(agent, bundle) >= threshold)
    return (bundle for bundle in candidates
            if instance.utility(agent, bundle) > threshold)


def compute_ceei_prices(allocation):
    """
    Computes the set of prices in a Competitive Equilibrium from Equal Incomes
    supporting `allocation`, and returns None if no such equilibrium exists.

    The integer program (solved with CPLEX through PuLP) looks for
    nonnegative prices and an integer budget d such that:
    - every agent's current bundle costs at most d;
    - every share the agent strictly prefers (see better_than) costs at
      least d + 1.
    d is minimized. The returned list holds one price per object, divided
    by d, i.e. normalized to a budget of 1.

    Returns None when the solver does not report an optimal solution (e.g.
    the program is infeasible).
    """
    nb_agents = allocation.instance.nb_agents
    nb_objects = allocation.instance.nb_objects

    # 1. We define the variables
    prices = pulp.LpVariable.dicts("p", range(nb_objects),
                                   0, None, pulp.LpContinuous)
    # d is bounded below by 1 rather than 0, because the prices are divided
    # by d at the end. d = 0 is optimal e.g. whenever no agent has a strictly
    # better share, which used to raise ZeroDivisionError. No equilibrium is
    # lost: if (p, 0) is feasible, so is (2p, 1), and any minimum >= 1 is
    # unchanged.
    d = pulp.LpVariable("d", 1, None, pulp.LpInteger)

    # 2. We define the problem with the objective
    prob = pulp.LpProblem("CEEI Price Computation", pulp.LpMinimize)
    prob += d

    # 3. We define the constraints
    # 3.1 The prices of the current allocation should not exceed d
    for agent in range(nb_agents):
        prob += (pulp.lpSum([allocation.pi[agent][obj] * prices[obj]
                             for obj in range(nb_objects)]) <= d)
    # 3.2 The prices of all the better shares for a given agent exceed d
    current_utilities = allocation.utility_vect()
    for agent in range(nb_agents):
        for share in better_than(allocation.instance,
                                 agent, current_utilities[agent]):
            prob += (pulp.lpSum(
                [share[obj] * prices[obj]
                 for obj in range(nb_objects)]) >= (d + 1))

    prob.solve(pulp.CPLEX(path=CPLEX_PATH))

    if prob.status == pulp.LpStatusOptimal:
        return [prices[obj].varValue / d.varValue for obj in range(nb_objects)]
    return None


#############################################################################
# Pareto ####################################################################

def is_pareto_optimal(allocation):
    """
    Tests whether an allocation is Pareto-optimal.

    Builds an integer program (solved with CPLEX through PuLP) that searches
    for another allocation with two properties:
    - every agent's additive utility, sum of w[ag][obj] over her objects, is
      at least her current value in allocation.utility_vect();
    - the sum of the utilities grows by at least 1.
    Returns True when the solver does not report an optimal solution for
    that program.

    NOTE(review): the "+ 1" margin presumes integer utilities; with
    fractional weights, a Pareto improvement smaller than 1 would be missed.
    Any non-optimal solver status, not only infeasibility, is read as
    "Pareto-optimal".
    """
    instance = allocation.instance
    nb_agents = instance.nb_agents
    nb_objects = instance.nb_objects
    old_u = allocation.utility_vect()
    old_uc = sum(old_u)

    # Variables: x[ag][obj] == 1 iff object obj goes to agent ag, and u_c
    # bounds the collective (summed) utility from below.
    x = [pulp.LpVariable.dicts("x_{}".format(ag), range(nb_objects),
                               0, 1, pulp.LpInteger)
         for ag in range(nb_agents)]
    u_c = pulp.LpVariable("uc", 0, None, pulp.LpContinuous)

    prob = pulp.LpProblem("Pareto-optimal test", pulp.LpMaximize)
    prob += u_c

    # Every object goes to exactly one agent.
    for obj in range(nb_objects):
        prob += (pulp.lpSum([x[ag][obj] for ag in range(nb_agents)]) == 1)
    # Nobody is worse off than in the current allocation.
    for ag in range(nb_agents):
        own_utility = pulp.lpSum([x[ag][obj] * instance.w[ag][obj]
                                  for obj in range(nb_objects)])
        prob += (own_utility >= old_u[ag])
    # u_c is at most the new collective utility...
    collective = pulp.lpSum([x[ag][obj] * instance.w[ag][obj]
                             for obj in range(nb_objects)
                             for ag in range(nb_agents)])
    prob += (collective >= u_c)
    # ...which must beat the current one by at least 1.
    prob += (u_c >= old_uc + 1)

    prob.solve(pulp.CPLEX(path=CPLEX_PATH))

    return prob.status != pulp.LpStatusOptimal
