#!/usr/bin/env python3
"""
Contour-accurate: extrude ONLY black pixels, respecting holes.
"""

import numpy as np
from PIL import Image
import cv2
from shapely.geometry import Polygon, MultiPolygon, LinearRing
from shapely.ops import unary_union
import trimesh


def mask_to_polygons(mask, pixel_size_mm):
    """
    Convert a binary mask to shapely polygons,
    preserving holes using OpenCV's CCOMP (two-level) hierarchy.

    Args:
        mask: Single-channel image handed directly to ``cv2.findContours``
            (presumably uint8 with non-zero foreground, as produced by
            ``cv2.threshold`` -- confirm against other callers).
        pixel_size_mm: Scale factor applied to every contour coordinate.

    Returns:
        A list of ``shapely.geometry.Polygon``, one per valid top-level
        contour, in the order OpenCV reported them. Each polygon carries the
        valid child contours of its shell as interior rings. The list is
        empty when OpenCV reports no contours.

    Contours with fewer than three points (isolated pixels, two-pixel specks)
    cannot form a ring and are skipped; shapely would otherwise raise while
    constructing them. Shells and holes that shapely reports as invalid
    (e.g. self-touching outlines) are dropped as well.

    NOTE(review): coordinates are used as (x, y) in image space with no Y
    flip, so the geometry appears mirrored in a Y-up viewer -- confirm this
    is intended.
    """
    contours, hierarchy = cv2.findContours(
        mask,
        cv2.RETR_CCOMP,
        cv2.CHAIN_APPROX_NONE,
    )

    if hierarchy is None:
        return []

    # Bucket the scaled point arrays by role in a single pass. Dicts keep
    # insertion order, so shells stay in OpenCV's contour order.
    shell_points = {}
    hole_points = {}
    for i, (cnt, h) in enumerate(zip(contours, hierarchy[0])):
        if len(cnt) < 3:
            # Too few vertices for a ring; Polygon/LinearRing would raise.
            continue

        # Contours have shape (N, 1, 2); drop the middle axis and scale.
        pts = cnt[:, 0, :] * pixel_size_mm

        # Each hierarchy row is [next, previous, first_child, parent].
        parent = h[3]
        if parent == -1:
            # Top-level contour -> shell
            shell_points[i] = pts
        else:
            # Child contour -> hole of its parent shell
            hole_points.setdefault(parent, []).append(pts)

    # Build each polygon exactly once, with all of its holes, instead of
    # reconstructing the polygon every time another hole is attached.
    polygons = []
    for i, pts in shell_points.items():
        shell = Polygon(pts)
        if not shell.is_valid:
            continue

        holes = [
            ring
            for ring in map(LinearRing, hole_points.get(i, ()))
            if ring.is_valid
        ]
        if holes:
            polygons.append(Polygon(shell.exterior.coords, holes))
        else:
            polygons.append(shell)

    return polygons


def image_to_stl_black_pixels(
    input_path: str,
    output_path: str,
    height_mm: float = 5.0,
    pixel_size_mm: float = 0.5,
    threshold: int = 128,
) -> None:
    """
    Extrude the black regions of an image into a mesh and export it.

    The image is converted to 8-bit grayscale and thresholded with
    ``cv2.THRESH_BINARY_INV``: pixels whose gray value is at most
    ``threshold`` become foreground ("black"), everything brighter becomes
    background. The foreground is traced into polygons (holes preserved),
    each polygon is extruded by ``height_mm``, and the combined mesh is
    written to ``output_path``.

    Args:
        input_path: Path of the image to read.
        output_path: Destination passed to the mesh's ``export`` method
            (the CLI calls this the output STL).
        height_mm: Extrusion height.
        pixel_size_mm: Size of one image pixel in output units.
        threshold: Gray level (0-255) at or below which a pixel counts as
            black.

    Raises:
        RuntimeError: If no black region is found, or if every detected
            region has zero area and therefore nothing can be extruded.

    NOTE(review): ``triangulation_engine="manifold"`` is forwarded unchanged
    from the original code -- confirm it is a keyword that
    ``trimesh.creation.extrude_polygon`` actually honours.
    """
    # Close the file handle deterministically once the pixels are copied out.
    with Image.open(input_path) as img:
        arr = np.array(img.convert("L"))

    # Black mask (invert threshold -> black=255, white=0)
    _, bw = cv2.threshold(arr, threshold, 255, cv2.THRESH_BINARY_INV)

    # Extract polygons with holes preserved
    polygons = mask_to_polygons(bw, pixel_size_mm)

    if not polygons:
        raise RuntimeError("No black regions detected")

    meshes = [
        trimesh.creation.extrude_polygon(
            poly,
            height_mm,
            triangulation_engine="manifold",
        )
        for poly in polygons
        if poly.area > 0
    ]

    # Every polygon may be degenerate (zero area); fail with a clear message
    # rather than letting the concatenate/export step break on an empty list.
    if not meshes:
        raise RuntimeError("No black regions with non-zero area detected")

    merged = trimesh.util.concatenate(meshes)
    merged.export(output_path)
    print(f"Saved corrected STL: {output_path}")


if __name__ == "__main__":
    # Command-line entry point: two positional paths plus optional tuning
    # flags, all forwarded to image_to_stl_black_pixels.
    import argparse

    cli = argparse.ArgumentParser()

    for positional in ("input_image", "output_stl"):
        cli.add_argument(positional)

    # (flag, value type, default) for each optional setting, in display order.
    optional_settings = (
        ("--height-mm", float, 5.0),
        ("--pixel-size-mm", float, 0.5),
        ("--threshold", int, 128),
    )
    for flag, value_type, fallback in optional_settings:
        cli.add_argument(flag, type=value_type, default=fallback)

    opts = cli.parse_args()

    image_to_stl_black_pixels(
        opts.input_image,
        opts.output_stl,
        height_mm=opts.height_mm,
        pixel_size_mm=opts.pixel_size_mm,
        threshold=opts.threshold,
    )
