#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# File: sawmill.py
# Author: Maciej Kaniewski
# Copyright (c) Maciej Kaniewski 2023
# MIT License

from pulp import LpInteger, LpMinimize, LpProblem, LpVariable
from tabulate import tabulate

# Width of the raw board that every cutting pattern is cut from, in inches
# (the unit used in the printed output).
BOARD_WIDTH = 22
# Widths of the pieces that may be cut from a board.
BOARDS_WIDTHS = [3, 5, 7]
# Minimum number of pieces needed per width: REQUIRED_COUNTS[i] is the demand
# for pieces of width BOARDS_WIDTHS[i].
REQUIRED_COUNTS = [72, 135, 120]

def find_cutting_patterns(total_width, widths, current_pattern=None):
    """
    Find all the unique ways to cut a board of ``total_width`` exactly
    (with no leftover waste) into pieces of the given ``widths``.

    Each pattern is a list of counts aligned with ``widths``: ``pattern[i]``
    is the number of pieces of width ``widths[i]``. Counts are added on top
    of ``current_pattern``, which defaults to all zeros and is never mutated.

    Pieces are chosen in non-decreasing index order, so every multiset of
    pieces is produced exactly once. Enumerating every ordering and then
    deduplicating would be exponential in the number of pieces. The result
    order is deterministic: depth-first, by index.

    Returns an empty list when ``total_width`` is negative, and a single
    copy of the starting pattern when ``total_width`` is zero.

    Raises ValueError if a cut is needed and any width is not positive,
    because such a width would never reduce the remaining width.
    """
    base = [0] * len(widths) if current_pattern is None else list(current_pattern)
    if total_width < 0:
        return []
    if total_width == 0:
        return [base]
    if any(width <= 0 for width in widths):
        raise ValueError("all widths must be positive")

    patterns = []

    def _extend(remaining, start, pattern):
        # Only widths at index >= start are used, so each combination is
        # built once. `pattern` is mutated in place and restored on return.
        if remaining == 0:
            patterns.append(list(pattern))
            return
        for i in range(start, len(widths)):
            if widths[i] <= remaining:
                pattern[i] += 1
                _extend(remaining - widths[i], i, pattern)
                pattern[i] -= 1

    _extend(total_width, 0, base)
    return patterns

def print_patterns(cutting_patterns):
    """
    Show every cutting pattern as piece counts per board width.

    For each pattern, print a header naming the raw board width
    (BOARD_WIDTH), then one line per entry of BOARDS_WIDTHS, then a blank
    line.
    """
    for counts in cutting_patterns:
        lines = [f"For a board of width {BOARD_WIDTH} inch:"]
        lines.extend(
            f"· Cut {piece_count} pieces of width {piece_width} inch"
            for piece_width, piece_count in zip(BOARDS_WIDTHS, counts)
        )
        # The trailing empty entry plus print's newline yields the blank
        # separator line.
        lines.append("")
        print("\n".join(lines))
def solve_problem(cutting_patterns):
    """
    Solve the sawmill problem using integer linear programming.

    Model:
    - One non-negative integer variable per cutting pattern: the number of
      boards cut with that pattern.
    - Objective: minimize the total number of boards.
    - Constraints: produce at least REQUIRED_COUNTS[i] pieces of the i-th
      width.

    Variables are named x1, x2, ... and zero-padded once there are 10 or
    more patterns (x01 ... x10). PuLP's ``problem.variables()`` sorts by
    name, so without padding "x10" would sort before "x2". With padding,
    that ordering matches the order of ``cutting_patterns``.

    Returns the solved ``LpProblem``.
    Raises ValueError if ``cutting_patterns`` is empty.
    """
    if not cutting_patterns:
        raise ValueError("at least one cutting pattern is required")

    problem = LpProblem("Sawmill_Problem", LpMinimize)

    # Decision variables, kept in a local list aligned with cutting_patterns
    # rather than re-read from problem.variables(), which is name-sorted.
    digits = len(str(len(cutting_patterns)))
    variables = [
        LpVariable(f"x{i + 1:0{digits}d}", lowBound=0, cat=LpInteger)
        for i in range(len(cutting_patterns))
    ]
    for var in variables:
        problem.addVariable(var)

    # Objective function: total number of boards cut.
    problem.setObjective(sum(variables))

    # Constraints: meet the demand for every piece width.
    for i, required in enumerate(REQUIRED_COUNTS):
        produced = sum(
            pattern[i] * var for pattern, var in zip(cutting_patterns, variables)
        )
        problem.addConstraint(produced >= required)

    # Solve the problem
    problem.solve()
    return problem

def _solution_board_counts(problem):
    """
    Return the solved number of boards for each pattern, in pattern order.
    """
    # solve_problem names its variables "x<pattern index + 1>", possibly
    # zero-padded. problem.variables() sorts by name, so "x10" could precede
    # "x2". Order by the numeric suffix so that position j lines up with
    # cutting_patterns[j].
    indexed = [
        (int(var.name[1:]), var)
        for var in problem.variables()
        if var.name.startswith("x") and var.name[1:].isdigit()
    ]
    indexed.sort(key=lambda item: item[0])
    counts = []
    for _, var in indexed:
        if var.varValue is None:
            raise ValueError(
                f"Variable {var.name} has no value; solve the problem first"
            )
        # Integer variables come back as floats that may carry solver
        # tolerance noise (e.g. 2.9999999), so round rather than truncate.
        counts.append(round(var.varValue))
    return counts


def _pieces_produced(cutting_patterns, board_counts):
    """
    Return the total pieces produced for each entry of BOARDS_WIDTHS.
    """
    return [
        sum(pattern[i] * count for pattern, count in zip(cutting_patterns, board_counts))
        for i in range(len(BOARDS_WIDTHS))
    ]


def display_results(problem, cutting_patterns):
    """
    Display the results of the sawmill optimization problem.

    Prints, in order:
    - the model (variables, objective, constraints);
    - a table of required vs. produced pieces per width;
    - how many boards are cut with each used pattern;
    - the total number of boards;
    - the storage cost, i.e. the sum of width times surplus pieces.

    Each solved variable value is rounded once and reused, so every printed
    number is consistent.

    Raises ValueError if a pattern variable has no value (problem not solved).
    """
    print(f"Problem variables: {problem.variables()}")
    print(f"Problem objective: {problem.objective}")
    for key, value in problem.constraints.items():
        print(f"Constraint Key: {key}, Constraint Value: {value}")

    board_counts = _solution_board_counts(problem)
    produced = _pieces_produced(cutting_patterns, board_counts)

    print()
    table_data = [
        [width, required, made]
        for width, required, made in zip(BOARDS_WIDTHS, REQUIRED_COUNTS, produced)
    ]
    headers = ["Width", "Number Required", "Number Produced"]
    table = tabulate(
        table_data, headers, tablefmt="grid", numalign="center", stralign="center"
    )
    print(table)
    print()

    for pattern, count in zip(cutting_patterns, board_counts):
        if count == 0:
            continue
        print(f"Cut {count} logs with pattern")
        for width, pieces in zip(BOARDS_WIDTHS, pattern):
            if pieces != 0:
                print(f"· {pieces} cut(s) of length {width}")
        print()

    print(f"Optimal solution uses {sum(board_counts)} logs")

    # Storage cost: each surplus piece beyond the requirement costs its width.
    storage_cost = sum(
        width * (made - required)
        for width, required, made in zip(BOARDS_WIDTHS, REQUIRED_COUNTS, produced)
        if made > required
    )
    print(f"Storage cost: {storage_cost}")

def main():
    """
    Run the sawmill program.

    Enumerate the cutting patterns for BOARD_WIDTH, solve the optimization
    model, then print the patterns and the solution report.
    """
    patterns = find_cutting_patterns(BOARD_WIDTH, BOARDS_WIDTHS)
    solved_problem = solve_problem(patterns)
    print_patterns(patterns)
    display_results(solved_problem, patterns)

# Run the program only when executed as a script, not when imported.
if __name__ == "__main__":
    main()