package uk.ac.aber.cs31920.assignment.tests.solvers;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import uk.ac.aber.cs31920.assignment.implementation.datastructures.Edge;
import uk.ac.aber.cs31920.assignment.implementation.datastructures.GraphNode;
import uk.ac.aber.cs31920.assignment.implementation.datastructures.MaxFlowGraph;
import uk.ac.aber.cs31920.assignment.implementation.solvers.ResidualCalculator;

import java.util.ArrayList;
import java.util.List;

/**
 * Tests for {@link ResidualCalculator#calculateResidualNetwork}.
 *
 * <p>Each test builds a small graph, supplies a flow as a list of {@link Edge}s, runs the
 * residual calculation and then checks the residual-network children of every node. The
 * expected children are consistent with every edge carrying flow being removed from the
 * residual network and replaced by its reverse edge, while edges without flow are kept.
 * Children are compared with list equality, so their order is asserted as well.
 */
public class ResidualCalculationTests {
    @Test
    public void testResidualNetworkCalculation() {
        /*
        graph
        ┌─┐     ┌─┐     ┌─┐
        │1├────►│2├────►│4│
        └─┘     └┬┘     └─┘
                 │
                 ▼
                ┌─┐
                │3│
                └─┘
        flow
         1──────►2──────►3
         */

        // arrange
        GraphNode node4 = new GraphNode(null, null);
        GraphNode node3 = new GraphNode(null, null);
        GraphNode node2 = new GraphNode(new GraphNode[]{node3, node4}, null);
        GraphNode node1 = new GraphNode(new GraphNode[]{node2}, null);
        List<Edge> flow = new ArrayList<>();
        flow.add(new Edge(node1, node2));
        flow.add(new Edge(node2, node3));
        // source and sink are passed as null here (presumably not needed for this calculation)
        MaxFlowGraph graph = new MaxFlowGraph(null, null, List.of(node1, node2, node3, node4));

        // act
        ResidualCalculator.calculateResidualNetwork(node1, flow, graph);

        // assert: flow edges 1→2 and 2→3 are reversed; the unused edge 2→4 is kept
        assertResidualChildren("node1", node1);
        assertResidualChildren("node2", node2, node1, node4);
        assertResidualChildren("node3", node3, node2);
        assertResidualChildren("node4", node4);
    }

    @Test
    public void testResidualNetworkCalculationInMaxMatchGraph() {
        /*
                    graph
                     ┌─┐       ┌─┐
                 ┌──►│2├───┐   │5├───┐
                 │   └─┘   │┌─►└─┘   │
                 │       ┌─┼┘        ▼
                ┌┴┐  ┌─┐ │ └──►┌─┐  ┌─┐
               S│1├─►│3├─┼────►│6├─►│8│T
                └┬┘  └─┘ │ ┌──►└─┘  └▲┘
                 │       └─┼┐        │
                 │   ┌─┐   │└─►┌─┐   │
                 └──►│4├───┘   │7├───┘
                     └─┘       └─┘
        */

        // arrange
        GraphNode node1 = new GraphNode(null, null, 1); node1.setAsSource();
        GraphNode node2 = new GraphNode(null, null, 2);
        GraphNode node3 = new GraphNode(null, null, 3);
        GraphNode node4 = new GraphNode(null, null, 4);
        GraphNode node5 = new GraphNode(null, null, 5);
        GraphNode node6 = new GraphNode(null, null, 6);
        GraphNode node7 = new GraphNode(null, null, 7);
        GraphNode node8 = new GraphNode(null, null, 8); node8.setAsSink();
        node1.addNetworkChildren(new GraphNode[]{node2, node3, node4});
        node2.addNetworkChild(node6);
        node3.addNetworkChildren(new GraphNode[]{node5, node6, node7});
        node4.addNetworkChild(node6);
        node5.addNetworkChild(node8);
        node6.addNetworkChild(node8);
        node7.addNetworkChild(node8);

        List<Edge> flow = new ArrayList<>(List.of(new Edge(node1, node2),
                new Edge(node2, node6),
                new Edge(node6, node8)));

        MaxFlowGraph graph = new MaxFlowGraph(node1, node8, List.of(node1, node2, node3, node4, node5, node6, node7, node8));

        // act
        ResidualCalculator.calculateResidualNetwork(node1, flow, graph);

        // assert: the flow edges 1→2, 2→6 and 6→8 are replaced by their reverses
        // (2→1, 6→2, 8→6); every edge without flow is kept unchanged
        assertResidualChildren("node1", node1, node3, node4);
        assertResidualChildren("node2", node2, node1);
        assertResidualChildren("node3", node3, node5, node6, node7);
        assertResidualChildren("node4", node4, node6);
        assertResidualChildren("node5", node5, node8);
        assertResidualChildren("node6", node6, node2);
        assertResidualChildren("node7", node7, node8);
        assertResidualChildren("node8", node8, node6);
    }

    @Test
    public void testResidualNetworkCalculationWithRealisticProblem() {
        /*
                  ┌─┐      ┌─┐
              ┌──►│2├─────►│4├───┐
              │   └┬┘   ┌─►└─┘   ▼
             ┌┴┐   └───┐│       ┌─┐
            s│1│   ┌───┼┘       │6│t
             └┬┘   │   │        └─┘
              │   ┌┴┐  └──►┌─┐   ▲
              └──►│3├─────►│5├───┘
                  └─┘      └─┘
        */

        // arrange
        GraphNode node1 = new GraphNode(null, null, 1);
        GraphNode node2 = new GraphNode(null, null, 2);
        GraphNode node3 = new GraphNode(null, null, 3);
        GraphNode node4 = new GraphNode(null, null, 4);
        GraphNode node5 = new GraphNode(null, null, 5);
        GraphNode node6 = new GraphNode(null, null, 6);
        node1.addNetworkChildren(new GraphNode[]{node2, node3});node1.setAsSource();
        node2.addNetworkChildren(new GraphNode[]{node5, node4});
        node3.addNetworkChildren(new GraphNode[]{node5, node4});
        node4.addNetworkChild(node6);
        node5.addNetworkChild(node6);
        node6.setAsSink();

        List<Edge> flow = new ArrayList<>(List.of(
                new Edge(node1, node2),
                new Edge(node2, node4),
                new Edge(node4, node6)
        ));

        MaxFlowGraph graph = new MaxFlowGraph(node1, node6, List.of(node1, node2, node3, node4, node5, node6));

        // act
        ResidualCalculator.calculateResidualNetwork(node1, flow, graph);

        // assert: flow edges 1→2, 2→4 and 4→6 are reversed; all other edges are kept
        assertResidualChildren("node1", node1, node3);
        assertResidualChildren("node2", node2, node1, node5);
        assertResidualChildren("node3", node3, node5, node4);
        assertResidualChildren("node4", node4, node2);
        assertResidualChildren("node5", node5, node6);
        assertResidualChildren("node6", node6, node4);
    }

    /**
     * Asserts that the residual-network children of {@code node} equal {@code expected},
     * element by element and in order.
     *
     * @param label    human-readable name of the node, used in the failure message
     * @param node     node whose residual-network children are checked
     * @param expected the exact expected children, in order; pass none to expect no children
     */
    private static void assertResidualChildren(String label, GraphNode node, GraphNode... expected) {
        Assertions.assertEquals(List.of(expected), node.getResidualNetworkChildren(),
                "unexpected residual-network children for " + label);
    }

}
