#!/usr/bin/env python3
import math
import os
import re
import sys
import xml.etree.ElementTree as ET

SVG_NS = "http://www.w3.org/2000/svg"

# Number of arguments consumed by one repetition of each path command.
_ARG_COUNTS = {"M": 2, "L": 2, "H": 1, "V": 1, "C": 6, "S": 4, "Q": 4, "T": 2, "A": 7, "Z": 0}
# Commas and whitespace separate path arguments (and may be absent entirely).
_SEP_RE = re.compile(r"[\s,]*")
# SVG number grammar: optional sign, digits with optional fraction (or a bare
# fraction such as ".5"), optional exponent such as "1e-3".
_NUM_RE = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")
# Arc flags are a single "0"/"1" and may be packed without separators ("0 015").
_FLAG_RE = re.compile(r"[01]")
# Decimal places kept when writing translated coordinates.
_PRECISION = 3


def _parse_path(d):
    """Split SVG path data into ``(command, args)`` pairs, one per argument group.

    Repeated argument groups after a single command letter become separate
    pairs, and extra coordinate pairs after a moveto become the matching
    lineto, as the SVG path grammar specifies.

    Raises:
        ValueError: if the path data contains an unknown character or a
            command is missing arguments.
    """
    segments = []
    end = len(d)
    pos = _SEP_RE.match(d).end()
    while pos < end:
        cmd = d[pos]
        count = _ARG_COUNTS.get(cmd.upper())
        if count is None:
            raise ValueError(f"unexpected {cmd!r} at offset {pos}")
        pos = _SEP_RE.match(d, pos + 1).end()
        if count == 0:
            segments.append((cmd, []))
            continue
        while True:
            args = []
            for k in range(count):
                # Arguments 3 and 4 of an arc are the large-arc and sweep flags.
                pattern = _FLAG_RE if cmd in "Aa" and k in (3, 4) else _NUM_RE
                match = pattern.match(d, pos)
                if match is None:
                    raise ValueError(f"command {cmd!r} is missing arguments at offset {pos}")
                args.append(float(match.group()))
                pos = _SEP_RE.match(d, match.end()).end()
            segments.append((cmd, args))
            if cmd in "Mm":
                cmd = "l" if cmd == "m" else "L"
            if pos >= end or d[pos].isalpha():
                break
    return segments


def _to_absolute(segments):
    """Rewrite parsed segments with uppercase commands and absolute coordinates.

    Tracks the current point and the start of the current subpath, so relative
    commands that follow ``z`` are resolved from the subpath start.
    """
    result = []
    x = y = 0.0
    start_x = start_y = 0.0
    for cmd, args in segments:
        op = cmd.upper()
        dx, dy = (x, y) if cmd != op else (0.0, 0.0)
        if op == "Z":
            result.append(("Z", []))
            x, y = start_x, start_y
            continue
        if op == "H":
            x = args[0] + dx
            result.append(("H", [x]))
        elif op == "V":
            y = args[0] + dy
            result.append(("V", [y]))
        elif op == "A":
            rx, ry, rot, large, sweep, ex, ey = args
            x, y = ex + dx, ey + dy
            result.append(("A", [rx, ry, rot, large, sweep, x, y]))
        else:
            # M, L, C, S, Q, T: every argument is part of an x,y pair.
            coords = [v + (dx if k % 2 == 0 else dy) for k, v in enumerate(args)]
            result.append((op, coords))
            x, y = coords[-2], coords[-1]
        if op == "M":
            start_x, start_y = x, y
    return result


def _quad_roots(a, b, c):
    """Return real roots of ``a*t**2 + b*t + c`` lying strictly inside (0, 1)."""
    if abs(a) < 1e-12:
        roots = [] if abs(b) < 1e-12 else [-c / b]
    else:
        disc = b * b - 4 * a * c
        if disc < 0:
            return []
        sq = math.sqrt(disc)
        roots = [(-b + sq) / (2 * a), (-b - sq) / (2 * a)]
    return [t for t in roots if 0 < t < 1]


def _cubic_points(p0, p1, p2, p3):
    """Return interior points where a cubic Bezier reaches an x or y extremum."""
    ts = []
    for axis in (0, 1):
        a = p1[axis] - p0[axis]
        b = p2[axis] - p1[axis]
        c = p3[axis] - p2[axis]
        # B'(t)/3 = (a - 2b + c) t^2 + 2 (b - a) t + a
        ts.extend(_quad_roots(a - 2 * b + c, 2 * (b - a), a))
    points = []
    for t in ts:
        mt = 1 - t
        w0, w1, w2, w3 = mt ** 3, 3 * mt * mt * t, 3 * mt * t * t, t ** 3
        points.append(tuple(w0 * p0[k] + w1 * p1[k] + w2 * p2[k] + w3 * p3[k] for k in (0, 1)))
    return points


def _arc_points(x1, y1, rx, ry, angle, large, sweep, x2, y2):
    """Return points where an elliptical arc reaches an x or y extremum.

    Uses the SVG endpoint-to-center parameterization conversion. Arcs with a
    zero radius render as straight lines and arcs with coincident endpoints
    are omitted, so neither contributes extra points.
    """
    rx, ry = abs(rx), abs(ry)
    if rx == 0 or ry == 0 or (x1 == x2 and y1 == y2):
        return []
    phi = math.radians(angle)
    cos_p, sin_p = math.cos(phi), math.sin(phi)
    hx, hy = (x1 - x2) / 2, (y1 - y2) / 2
    x1p = cos_p * hx + sin_p * hy
    y1p = -sin_p * hx + cos_p * hy
    # Radii too small to span the endpoints are scaled up uniformly.
    lam = (x1p / rx) ** 2 + (y1p / ry) ** 2
    if lam > 1:
        scale = math.sqrt(lam)
        rx, ry = rx * scale, ry * scale
    num = rx * rx * ry * ry - rx * rx * y1p * y1p - ry * ry * x1p * x1p
    den = rx * rx * y1p * y1p + ry * ry * x1p * x1p
    coef = math.sqrt(max(0.0, num / den))
    if large == sweep:
        coef = -coef
    cxp = coef * rx * y1p / ry
    cyp = -coef * ry * x1p / rx
    cx = cos_p * cxp - sin_p * cyp + (x1 + x2) / 2
    cy = sin_p * cxp + cos_p * cyp + (y1 + y2) / 2
    theta1 = math.atan2((y1p - cyp) / ry, (x1p - cxp) / rx)
    theta2 = math.atan2((-y1p - cyp) / ry, (-x1p - cxp) / rx)
    delta = theta2 - theta1
    if sweep and delta < 0:
        delta += 2 * math.pi
    elif not sweep and delta > 0:
        delta -= 2 * math.pi
    # Parameter angles where dx/dtheta == 0 and dy/dtheta == 0, plus opposites.
    theta_x = math.atan2(-ry * sin_p, rx * cos_p)
    theta_y = math.atan2(ry * cos_p, rx * sin_p)
    points = []
    for theta in (theta_x, theta_x + math.pi, theta_y, theta_y + math.pi):
        if delta > 0:
            offset = (theta - theta1) % (2 * math.pi)
        else:
            offset = (theta1 - theta) % (2 * math.pi)
        if offset <= abs(delta) + 1e-9:
            ct, st = math.cos(theta), math.sin(theta)
            points.append((cx + rx * cos_p * ct - ry * sin_p * st,
                           cy + rx * sin_p * ct + ry * cos_p * st))
    return points


def _outline_points(segments):
    """Return points whose bounding box is the geometric extent of the path.

    ``segments`` must come from ``_to_absolute``. Besides segment endpoints the
    list includes every point where a curve or arc reaches an x or y extremum,
    so curves that bulge past their endpoints are not clipped by the crop.
    """
    points = []
    x = y = 0.0
    start = (0.0, 0.0)
    prev_cubic = prev_quad = None  # control points reflected by S / T
    for op, args in segments:
        p0 = (x, y)
        next_cubic = next_quad = None
        if op == "Z":
            x, y = start
        elif op == "H":
            x = args[0]
        elif op == "V":
            y = args[0]
        elif op == "A":
            x, y = args[5], args[6]
            points.extend(_arc_points(p0[0], p0[1], *args))
        elif op in ("C", "S"):
            if op == "C":
                c1, rest = (args[0], args[1]), args[2:]
            else:
                # S reflects the previous C/S second control point, else uses
                # the current point.
                c1 = p0 if prev_cubic is None else (2 * x - prev_cubic[0], 2 * y - prev_cubic[1])
                rest = args
            next_cubic = (rest[0], rest[1])
            x, y = rest[2], rest[3]
            points.extend(_cubic_points(p0, c1, next_cubic, (x, y)))
        elif op in ("Q", "T"):
            if op == "Q":
                next_quad = (args[0], args[1])
                x, y = args[2], args[3]
            else:
                next_quad = p0 if prev_quad is None else (2 * x - prev_quad[0], 2 * y - prev_quad[1])
                x, y = args[0], args[1]
            # Degree-elevate to an equivalent cubic so one extremum solver serves both.
            q = next_quad
            c1 = (p0[0] + 2 * (q[0] - p0[0]) / 3, p0[1] + 2 * (q[1] - p0[1]) / 3)
            c2 = (x + 2 * (q[0] - x) / 3, y + 2 * (q[1] - y) / 3)
            points.extend(_cubic_points(p0, c1, c2, (x, y)))
        else:  # "M" or "L"
            x, y = args
            if op == "M":
                start = (x, y)
        prev_cubic, prev_quad = next_cubic, next_quad
        if op != "Z":
            points.append((x, y))
    return points


def _format_number(value):
    """Format a coordinate with at most ``_PRECISION`` decimals, no trailing zeros."""
    text = f"{value:.{_PRECISION}f}".rstrip("0").rstrip(".")
    return "0" if text in ("-0", "") else text


def _translate(segments, dx, dy):
    """Serialize absolute segments as path data shifted by ``(dx, dy)``.

    Arc radii, rotation and flags are not positions and are left unshifted.
    """
    parts = []
    for op, args in segments:
        if op == "H":
            values = [args[0] + dx]
        elif op == "V":
            values = [args[0] + dy]
        elif op == "A":
            values = args[:5] + [args[5] + dx, args[6] + dy]
        else:
            values = [v + (dx if k % 2 == 0 else dy) for k, v in enumerate(args)]
        parts.append(op + " ".join(_format_number(v) for v in values))
    return " ".join(parts)


def _iter_paths(root):
    """Return every ``<path>`` element, whether SVG-namespaced or not."""
    wanted = (f"{{{SVG_NS}}}path", "path")
    return [el for el in root.iter() if el.tag in wanted]


def main(argv=None):
    """Crop an SVG's viewBox to its path geometry and save a translated copy.

    Every path's data is rewritten as absolute commands shifted so that the
    geometry's top-left corner sits at the origin, and the root viewBox is
    set to ``0 0 width height``.

    NOTE(review): ``transform`` attributes (on paths or ancestor groups) and
    non-path shapes such as ``<rect>``/``<circle>`` are not taken into account.

    Args:
        argv: command-line arguments without the program name; defaults to
            ``sys.argv[1:]``. When empty, the SVG file path is read from stdin.

    Returns:
        The path of the written file. Exits with status 1 on any input error.
    """
    args = sys.argv[1:] if argv is None else argv
    svg_file = args[0] if args else input("Enter SVG file path: ").strip()

    if not os.path.exists(svg_file):
        print(f"Error: File not found: {svg_file}")
        sys.exit(1)

    try:
        tree = ET.parse(svg_file)
    except ET.ParseError as err:
        print(f"Error: Could not parse SVG: {err}")
        sys.exit(1)
    root = tree.getroot()

    parsed = []
    all_xs, all_ys = [], []
    for path in _iter_paths(root):
        d = path.get("d")
        if not d:
            continue
        try:
            segments = _to_absolute(_parse_path(d))
        except ValueError as err:
            print(f"Error: Invalid path data: {err}")
            sys.exit(1)
        parsed.append((path, segments))
        for px, py in _outline_points(segments):
            all_xs.append(px)
            all_ys.append(py)

    if not all_xs or not all_ys:
        print("Error: No valid path data found")
        sys.exit(1)

    min_x, max_x = min(all_xs), max(all_xs)
    min_y, max_y = min(all_ys), max(all_ys)
    width = round(max_x - min_x, 1)
    height = round(max_y - min_y, 1)

    root.set("viewBox", f"0 0 {width} {height}")

    # Re-serialize every path from its absolute form so all commands move.
    for path, segments in parsed:
        path.set("d", _translate(segments, -min_x, -min_y))

    input_dir = os.path.dirname(svg_file)
    input_name = os.path.splitext(os.path.basename(svg_file))[0]
    output_file = os.path.join(input_dir, f"{input_name}-cropped-{width}x{height}.svg")

    # Serialize the SVG namespace as the default one instead of "ns0:".
    ET.register_namespace('', SVG_NS)
    tree.write(output_file, encoding='UTF-8', xml_declaration=True)

    print(f"Cropped viewBox: 0 0 {width} {height}")
    print(f"Translated by: {-min_x:.1f}, {-min_y:.1f}")
    print(f"Saved to: {output_file}")
    return output_file


if __name__ == "__main__":
    main()
