#!/usr/bin/env python3
"""
Subsampling library for argumentation frameworks.

This library provides:
1) Parsing of an argumentation framework (AF) file in a specific format (typically .af extension).
2) Optional parsing of a solution file under preferred semantics
   (to compute extension-based metrics for "Degree-/Extension-Guided" sampling).
3) Four subsampling methods:
   - random_subsample
   - degree_extension_subsample
   - bfs_subsample
   - community_subsample
4) Functions to output the resulting subsampled AF in the same format (typically .af extension).
"""

import random
from collections import deque

import networkx as nx

def parse_graph(graph_filename):
    """
    Reads an argumentation framework (AF) file with integer argument ids.

    Expected layout (typically a .af file):
      p af <num_nodes>        optional header, only recognised on the first
                              non-empty line
      # ...                   any line starting with '#' is skipped
      <attacker> <attacked>   one attack per line, both integer ids

    When the header is present and well-formed, the arguments are taken to be
    1..<num_nodes>. When it is absent or malformed, the returned node list is
    empty. A first line that is not a header is parsed like any other line
    (it is not discarded).

    Malformed header or edge lines (wrong token count, non-integer ids) are
    reported with a warning on stdout and skipped.

    Returns:
      nodes (list): list of integer node ids declared by the header
      edges (list of tuples): list of directed edges (attacker, attacked)
    """
    nodes = []
    edges = []
    header_checked = False

    with open(graph_filename, 'r', encoding='utf-8') as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                continue  # skip empty lines

            # Only the first non-empty line may be the 'p af <n>' header.
            if not header_checked:
                header_checked = True
                if line.startswith("p af "):
                    parts = line.split()
                    if len(parts) != 3:
                        print(f"Warning: Malformed header line {line_num}: '{line}' in {graph_filename}. Ignoring.")
                    else:
                        try:
                            declared_node_count = int(parts[2])
                        except ValueError:
                            print(f"Warning: Non-integer node count in header line {line_num}: '{line}' in {graph_filename}. Ignoring.")
                        else:
                            nodes.extend(range(1, declared_node_count + 1))
                    continue
                # Not a header: fall through so this line is parsed as a
                # comment/edge instead of being lost.

            if line.startswith('#'):
                continue  # comment line

            # Every remaining line is "X Y", meaning X attacks Y.
            parts = line.split()
            if len(parts) != 2:
                print(f"Warning: Malformed edge line {line_num}: '{line}' in {graph_filename}. Expected format 'attacker attacked'. Ignoring.")
                continue
            try:
                edge = (int(parts[0]), int(parts[1]))
            except ValueError:
                print(f"Warning: Non-integer argument id in edge line {line_num}: '{line}' in {graph_filename}. Ignoring.")
                continue
            edges.append(edge)

    return nodes, edges


def parse_solution(solution_filename):
    """
    Counts argument occurrences in a bracketed solution file such as:
      [[a1,a2,...],[a1,a3,...],...]

    The file is read as text; every comma-separated, non-empty token found
    inside the bracketed sets is counted once per appearance.

    Returns:
      extension_count (dict):
          Maps each argument label (as a string) to the number of times it
          appears across all solution sets. Empty when the file holds no
          arguments.
    """
    with open(solution_filename, 'r', encoding='utf-8') as handle:
        text = handle.read().strip()

    # Peel off the outermost brackets; nothing left means no extensions.
    inner = text.lstrip('[').rstrip(']')
    if not inner:
        return {}

    tally = {}
    for chunk in inner.split('],['):
        # Drop any brackets that survived the split, then walk the labels.
        for token in chunk.replace('[', '').replace(']', '').split(','):
            label = token.strip()
            if not label:
                continue
            tally[label] = tally.get(label, 0) + 1

    return tally


def build_graph(nodes, edges):
    """
    Create a directed NetworkX graph holding the given nodes and edges.

    Nodes are inserted first, then the (attacker, attacked) edges; edge
    endpoints that were not listed in `nodes` are added by NetworkX.
    """
    digraph = nx.DiGraph()
    # Insert nodes before edges so isolated arguments are kept as well.
    digraph.add_nodes_from(nodes)
    digraph.add_edges_from(edges)
    return digraph


def random_subsample(nodes, edges, proportion):
    """
    Uniformly picks round(proportion * len(nodes)) nodes (capped at the number
    of nodes) and keeps the edges whose two endpoints were both picked.

    Returns a subsampled (nodes_sub, edges_sub); ([], []) for an empty node
    list.

    Raises:
      ValueError: if the node list is non-empty but the requested size rounds
                  to zero (or less).
    """
    node_total = len(nodes)
    target = int(round(proportion * node_total))

    if target <= 0:
        if node_total > 0:
            raise ValueError("proportion * total_nodes results in zero nodes to be selected.")
        # Nothing to sample from.
        return [], []

    # Never ask for more nodes than exist.
    target = min(target, node_total)
    kept = set(random.sample(nodes, target))

    # Keep an edge only when both of its endpoints survived.
    kept_edges = []
    for (src, dst) in edges:
        if src in kept and dst in kept:
            kept_edges.append((src, dst))

    return list(kept), kept_edges


def degree_extension_subsample(nodes, edges, proportion, solution_dict=None, alpha=1.0):
    """
    Selects the top round(proportion * len(nodes)) nodes by relevance score:
      score(node) = degree(node) + alpha * extension_count(node)

    degree(node) is in-degree + out-degree over the distinct edges (a
    self-loop counts twice). Ties keep the order in which nodes were given.

    Args:
      nodes: list of node ids.
      edges: iterable of (attacker, attacked) pairs.
      proportion: fraction of nodes to keep.
      solution_dict: optional mapping argument -> occurrence count. A node is
          looked up under its own value first and then under str(node), so
          string-keyed counts still match integer node ids.
          If omitted, only the degree is used.
      alpha: weight of the extension count in the score (default 1.0).

    Returns (nodes_sub, edges_sub); ([], []) for an empty node list.

    Raises:
      ValueError: if the node list is non-empty but the requested size rounds
                  to zero (or less).
    """
    total_nodes = len(nodes)
    k = int(round(proportion * total_nodes))
    if k <= 0:
        if total_nodes > 0:
            raise ValueError("proportion * total_nodes results in zero nodes to be selected.")
        return [], []
    k = min(k, total_nodes)

    # Total degree per node. Parallel edges are collapsed first so that a
    # repeated attack is counted once.
    degrees = {}
    for (u, v) in {(u, v) for (u, v) in edges}:
        degrees[u] = degrees.get(u, 0) + 1
        degrees[v] = degrees.get(v, 0) + 1

    def get_extension_count(n):
        if not solution_dict:
            return 0
        if n in solution_dict:
            return solution_dict[n]
        # Fall back to the textual label: counts read from a solution file
        # are keyed by strings while node ids may be integers.
        return solution_dict.get(str(n), 0)

    node_scores = [
        (n, degrees.get(n, 0) + alpha * get_extension_count(n))
        for n in nodes
    ]
    # Stable sort: equal scores keep their input order.
    node_scores.sort(key=lambda item: item[1], reverse=True)

    selected_nodes = {n for n, _ in node_scores[:k]}

    # Keep an edge only when both endpoints were selected.
    filtered_edges = [(u, v) for (u, v) in edges if (u in selected_nodes and v in selected_nodes)]

    return list(selected_nodes), filtered_edges


def bfs_subsample(nodes, edges, proportion, seed=None):
    """
    BFS ("snowball") sampling.
    - Start from the given seed, or from a random node when no seed is given
      or the given seed is not in the graph (a warning is printed then).
    - Expand over both outgoing and incoming attacks, visiting neighbours in
      random order, until round(proportion * len(nodes)) nodes are collected.
    - Expansion also stops when no unvisited node is reachable from the seed,
      so the sample can be smaller than requested on a disconnected graph.

    Returns (nodes_sub, edges_sub); ([], []) for an empty node list.

    Raises:
      ValueError: if the node list is non-empty but the requested size rounds
                  to zero (or less).
    """
    G = build_graph(nodes, edges)

    total_nodes = len(nodes)
    k = int(round(proportion * total_nodes))
    if k <= 0:
        if total_nodes > 0:
            raise ValueError("proportion * total_nodes results in zero nodes to be selected.")
        return [], []
    k = min(k, total_nodes)

    if seed is None:
        seed = random.choice(nodes)
    elif seed not in G:
        # An unknown seed is not fatal: fall back to a random start node.
        print(f"Warning: Provided seed '{seed}' not found in nodes. Choosing a random seed.")
        seed = random.choice(nodes)

    visited = {seed}
    queue = deque([seed])  # deque gives O(1) pops from the front

    while queue and len(visited) < k:
        current = queue.popleft()
        # neighbors = direct successors + direct predecessors to mimic argumentation links
        neighbors = list(G.successors(current)) + list(G.predecessors(current))
        random.shuffle(neighbors)  # randomise which neighbours make the cut
        for nb in neighbors:
            if nb not in visited:
                visited.add(nb)
                queue.append(nb)
                if len(visited) >= k:
                    break

    filtered_edges = [(u, v) for (u, v) in edges if (u in visited and v in visited)]
    return list(visited), filtered_edges


def community_subsample(nodes, edges, proportion):
    """
    Community-based sampling:
    1) Build the graph, convert it to undirected and detect communities with
       NetworkX's greedy_modularity_communities.
    2) From each community (in the order returned), randomly draw a share of
       the target size k = round(proportion * len(nodes)) proportional to the
       community's size, never exceeding what is still needed.
    3) If rounding left the sample short of k, top it up with a uniform
       random draw from the not-yet-selected entries of `nodes`.

    Warnings are printed when the communities do not cover exactly
    len(nodes) nodes and when the final sample size differs from k.

    Returns (nodes_sub, edges_sub); ([], []) for an empty node list.

    Raises:
      ValueError: if the node list is non-empty but the requested size rounds
                  to zero (or less).
      ImportError: re-raised (with the original error as its cause) if
                   community detection raises ImportError.
    """
    G = build_graph(nodes, edges)
    total_nodes = len(nodes)
    k = int(round(proportion * total_nodes))
    if k <= 0:
        if total_nodes > 0:
            raise ValueError("proportion * total_nodes results in zero nodes to be selected.")
        return [], []
    k = min(k, total_nodes)

    # Community detection is run on the undirected version of the graph.
    UG = G.to_undirected()
    try:
        communities = [
            frozenset(c)
            for c in nx.algorithms.community.greedy_modularity_communities(UG)
        ]
    except ImportError as err:
        print("Error: Community detection requires relevant NetworkX components.")
        raise ImportError("NetworkX community algorithms might require additional installs or availability.") from err

    # Ignore empty communities; they cannot contribute any node.
    communities = [c for c in communities if c]

    total_community_nodes = sum(len(c) for c in communities)
    if total_community_nodes != total_nodes:
        print(f"Warning: Community detection covered {total_community_nodes} nodes, expected {total_nodes}.")

    sampled_nodes = set()
    remaining_needed = k

    # Sample proportionally from each community.
    for community in communities:
        community_size = len(community)
        # Proportional share of k, capped by the community size and by what
        # is still missing from the overall target.
        comm_k = min(
            int(round((community_size / total_nodes) * k)),
            community_size,
            remaining_needed,
        )
        if comm_k > 0:
            sampled_nodes.update(random.sample(list(community), comm_k))
            remaining_needed -= comm_k
        if remaining_needed <= 0:
            break

    # Rounding can leave the sample short; fill up from the unselected nodes.
    if remaining_needed > 0:
        leftover_nodes = list(set(nodes) - sampled_nodes)
        num_to_sample = min(remaining_needed, len(leftover_nodes))
        if num_to_sample > 0:
            sampled_nodes.update(random.sample(leftover_nodes, num_to_sample))

    if len(sampled_nodes) != k:
        print(f"Warning: Final sample size is {len(sampled_nodes)}, but target was {k}.")

    filtered_edges = [(u, v) for (u, v) in edges if (u in sampled_nodes and v in sampled_nodes)]
    return list(sampled_nodes), filtered_edges


def write_subsampled_graph(output_filename, nodes_sub, edges_sub):
    """
    Writes the subsampled graph to a file, renumbering the arguments 1..N.

    The distinct nodes are sorted and numbered in that order, so the output
    is deterministic and does not depend on the order of `nodes_sub`.
    The output file typically uses the .af extension.
    Format:
      p af <num_nodes>
      #1
      #2
      ...
      #<num_nodes>
      <attacker_id> <attacked_id>
      ...

    Edges are written sorted, using the new ids. All content is assembled
    before the file is opened, so a failure while mapping edges does not
    leave a half-written file behind.

    Raises:
      KeyError: if an edge refers to a node that is not in `nodes_sub`.
    """
    # Sorting makes the old-id -> new-id mapping monotonic and reproducible;
    # the set() drops duplicates so the header count matches the ids used.
    ordered_nodes = sorted(set(nodes_sub))
    node_map = {n: i for i, n in enumerate(ordered_nodes, 1)}

    lines = [f"p af {len(ordered_nodes)}\n"]
    lines.extend(f"#{i}\n" for i in range(1, len(ordered_nodes) + 1))
    lines.extend(f"{node_map[u]} {node_map[v]}\n" for (u, v) in sorted(edges_sub))

    with open(output_filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)

if __name__ == '__main__':
    # Command-line entry point: parse an .af file, subsample it with the
    # chosen method and write the result as a new .af file.
    # Exits with status 1 on invalid input or on any processing error.
    import argparse
    import os
    import sys

    parser = argparse.ArgumentParser(description="Subsample an Argumentation Framework (.af file)")
    parser.add_argument("input_af", help="Path to the input .af file.")
    parser.add_argument("output_af", help="Path for the output subsampled .af file.")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%' (otherwise --help fails).
    parser.add_argument("-p", "--proportion", type=float, required=True, help="Proportion of nodes to sample (e.g., 0.5 for 50%%).")
    parser.add_argument("-m", "--method", choices=['random', 'degree', 'bfs', 'community'], default='random', help="Subsampling method.")
    parser.add_argument("-s", "--solution", help="Optional path to a solution file (for degree-extension method).")
    parser.add_argument("--seed", help="Optional seed node for BFS sampling.")

    args = parser.parse_args()

    # --- Input Validation ---
    if not os.path.exists(args.input_af):
        print(f"Error: Input file not found: {args.input_af}")
        sys.exit(1)
    if not args.input_af.lower().endswith('.af'):
        print(f"Warning: Input file '{args.input_af}' does not end with .af")
    if not args.output_af.lower().endswith('.af'):
        print(f"Warning: Output file '{args.output_af}' does not end with .af")
    if not 0 < args.proportion <= 1.0:
        print(f"Error: Proportion must be between 0 (exclusive) and 1 (inclusive). Got: {args.proportion}")
        sys.exit(1)
    if args.solution and not os.path.exists(args.solution):
        print(f"Error: Solution file not found: {args.solution}")
        sys.exit(1)
    if args.method == 'degree' and args.solution is None:
        print("Info: Running degree-based sampling without solution file (only degree will be used).")

    # parse_graph yields integer node ids, whereas argparse hands over the
    # seed as text; convert it so it can match a node. A non-numeric seed
    # is passed through unchanged and handled by bfs_subsample.
    seed = args.seed
    if seed is not None:
        try:
            seed = int(seed)
        except ValueError:
            pass

    # --- Processing ---
    try:
        print(f"Parsing graph from: {args.input_af}")
        nodes, edges = parse_graph(args.input_af)
        print(f"Parsed {len(nodes)} nodes and {len(edges)} edges.")
        if not nodes:
            print("Input graph has no nodes. Output will be empty.")
            nodes_sub, edges_sub = [], []
        else:
            solution_data = None
            if args.method == 'degree' and args.solution:
                print(f"Parsing solution data from: {args.solution}")
                solution_data = parse_solution(args.solution)
                print(f"Parsed solution data for {len(solution_data)} arguments.")

            print(f"Applying '{args.method}' subsampling with proportion {args.proportion}...")

            if args.method == 'random':
                nodes_sub, edges_sub = random_subsample(nodes, edges, args.proportion)
            elif args.method == 'degree':
                nodes_sub, edges_sub = degree_extension_subsample(nodes, edges, args.proportion, solution_data)
            elif args.method == 'bfs':
                nodes_sub, edges_sub = bfs_subsample(nodes, edges, args.proportion, seed=seed)
            elif args.method == 'community':
                nodes_sub, edges_sub = community_subsample(nodes, edges, args.proportion)
            else:
                # Should not happen due to choices in argparse
                print(f"Error: Unknown method '{args.method}'")
                sys.exit(1)

            print(f"Subsampled graph has {len(nodes_sub)} nodes and {len(edges_sub)} edges.")

        # --- Output ---
        print(f"Writing subsampled graph to: {args.output_af}")
        write_subsampled_graph(args.output_af, nodes_sub, edges_sub)
        print("Done.")

    except ValueError as e:
        print(f"Error during processing: {e}")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report any other failure and exit non-zero.
        print(f"An unexpected error occurred: {e}")
        sys.exit(1)