diff --git a/getdp_path.txt b/getdp_path.txt index 8e9754e..9404475 100644 --- a/getdp_path.txt +++ b/getdp_path.txt @@ -1 +1 @@ -/home/laura/getdp-3.5.0/getdp-3.5.0-Linux64/bin/getdp +/home/sarah/getdp-3.5.0-Linux64c/getdp-3.5.0-Linux64/bin/getdp diff --git a/requirements.txt b/requirements.txt index 31109b4..6f0519a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,9 @@ matplotlib numpy Pillow scipy -setuptools \ No newline at end of file +setuptools +opencv-python +svgwrite +PyYAML +pytest +svgpathtools \ No newline at end of file diff --git a/sketchgetdp/bezier/BezierCurve.py b/sketchgetdp/bezier/BezierCurve.py deleted file mode 100644 index b634a1e..0000000 --- a/sketchgetdp/bezier/BezierCurve.py +++ /dev/null @@ -1,95 +0,0 @@ -"""This module contains the BezierCurve class, which represents a single Bézier curve. - -Author: Laura D'Angelo -""" - -import math -import matplotlib.pyplot as plt -import numpy as np - - -class BezierCurve: - """This class represents a single Bézier curve. - - Attributes: - control_points (np.array): A (degree+1) x 2 array of control points for the Bézier curve. - degree (int): The degree of the Bézier curve. - """ - - def __init__(self, control_points: np.array) -> "BezierCurve": - """The constructor for the BezierCurve class. - - Parameters: - control_points (np.array): An array of control points for the Bézier curve. - """ - self.control_points = control_points - self.degree = np.size(control_points, 0) - 1 - - def evaluate(self, t: np.array) -> np.array: - """This method evaluates the Bézier curve at given parameters t. - - Parameters: - t (np.array): The parameters at which to evaluate the Bézier curve. - - Returns: - np.array: The evaluated points on the Bézier curve. - """ - # Ensure t has the correct shape - if t.ndim == 1: - t = t[:, np.newaxis] - if np.size(t, 0) < np.size(t, 1): - t = np.transpose(t) - - # Evaluate the Bézier curve using the Bernstein polynomial - value = np.zeros((np.size(t, 0), 2)) - n = self.degree - for i in range(n + 1): - value += ( - math.comb(n, i) * t**i * (1 - t) ** (n - i) * self.control_points[i, :] - ) - return value - - def evaluate_derivative(self, t: np.array) -> np.array: - """This method evaluates the derivative of the Bézier curve at given parameters t. - - Parameters: - t (np.array): The parameters at which to evaluate the derivative of the Bézier curve. - - Returns: - np.array: The evaluated points on the derivative of the Bézier curve. - """ - # Ensure t has the correct shape - if t.ndim == 1: - t = t[:, np.newaxis] - if np.size(t, 0) < np.size(t, 1): - t = np.transpose(t) - - # Evaluate the derivative of the Bézier curve - value = np.zeros((np.size(t, 0), 2)) - n = self.degree - for i in range(n): - value += ( - math.comb(n - 1, i) - * t**i - * (1 - t) ** (n - i - 1) - * (self.control_points[i + 1, :] - self.control_points[i, :]) - ) - return value - - def plot(self) -> None: - """This method plots the Bézier curve and its control polygon. - - Returns: - None - """ - t = np.linspace(0, 1, 100) - evaluated_points = self.evaluate(t) - plt.plot(evaluated_points[:, 0], evaluated_points[:, 1], label="Bézier Curve") - plt.plot( - self.control_points[:, 0], - self.control_points[:, 1], - "ro--", - label="Control Points", - ) - plt.legend() - plt.show() diff --git a/sketchgetdp/bezier/__init__.py b/sketchgetdp/bezier/__init__.py deleted file mode 100644 index c4c4870..0000000 --- a/sketchgetdp/bezier/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .BezierCurve import BezierCurve - -__all__ = ['BezierCurve'] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/README.md b/sketchgetdp/bitmap_tracer/README.md new file mode 100644 index 0000000..aa74e08 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/README.md @@ -0,0 +1,149 @@ +# Bitmap Tracer + +A sophisticated image-to-SVG tracing application that converts bitmap images into clean, scalable vector graphics with intelligent color categorization and structure filtering. + +## 🎯 Overview + +Bitmap Tracer is a Python-based tool that analyzes bitmap images and converts them into SVG vector graphics. It features: + +- **Smart color categorization** (Red, Blue, Green) +- **Intelligent curve fitting** for optimal shape preservation +- **Configurable structure filtering** to keep only the most important elements +- **Point detection** for small, compact shapes +- **Automatic contour closure** ensuring all paths form complete loops + +## 🏗️ Architecture + +The project follows Clean Architecture principles with clear separation of concerns: + +### Core Layers + +- **`core/`** - Enterprise business rules + - `entities/` - Domain models (Point, Contour, Color) + - `use_cases/` - Application logic (Image Tracing, Structure Filtering) + +- **`infrastructure/`** - Frameworks & drivers + - `image_processing/` - Contour detection, color analysis, closure services + - `svg_generation/` - SVG creation and shape processing + - `configuration/` - Config loading and management + - `point_detection/` - Point detection and curve fitting + +- **`interfaces/`** - Interface adapters + - `controllers/` - Application flow control + - `presenters/` - Output formatting (SVG presentation) + - `gateways/` - External interfaces (image loading, config access) + +## 🚀 Key Features + +### Color Categorization +- Automatically detects and categorizes strokes into Red, Blue, and Green +- Red shapes are reserved exclusively for point markers +- Ignores white/black background colors + +### Smart Curve Fitting +- Hybrid approach using lines for straight segments and curves for curved segments +- Preserves actual shape while smoothing where appropriate +- Automatic contour closure with distance validation + +### Configurable Filtering +- Control the number of structures kept for each color via YAML configuration +- Filters by area, keeping only the largest structures +- Hierarchical filtering to remove nested contours + +### Point Detection +- Identifies small, compact shapes as points +- Creates simple dot markers at contour centers +- Unified sorting with larger red structures + +## 📁 Project Structure + +``` +bitmap_tracer/ +├── core/ # Business logic +│ ├── entities/ # Domain models +│ └── use_cases/ # Application services +├── infrastructure/ # External concerns +│ ├── image_processing/ # Computer vision +│ ├── svg_generation/ # Vector output +│ ├── configuration/ # Config management +│ └── point_detection/ # Point analysis +├── interfaces/ # Adapters +│ ├── controllers/ # Flow control +│ ├── presenters/ # Output formatting +│ └── gateways/ # External interfaces +├── __main__.py # Python module entry point +└── config.yaml # Configuration +``` + +## ⚙️ Configuration + +Configure the tracing behavior in `config.yaml`: + +```yaml +## Structure Limits +# Maximum number of structures to keep for each color category after filtering. +# Structures are sorted by area (largest first) and only the top N are kept. +red_dots: 1 # Maximum red points to preserve +blue_paths: 1 # Maximum blue paths to preserve +green_paths: 1 # Maximum green paths to preserve + +## Contour Detection Parameters +# Control how contours are detected and filtered from the source image. +point_max_area: 2000 # Maximum area for a contour to be classified as a point +point_max_perimeter: 1000 # Maximum perimeter for point classification + +## Color Detection Parameters +# Define thresholds for categorizing colors in the source image. +blue_hue_range: [100, 140] # HSV hue range for blue color detection +red_hue_range: [[0, 10], [170, 180]] # HSV hue ranges for red color detection +green_hue_range: [35, 85] # HSV hue range for green color detection +min_saturation: 50 # Minimum saturation to avoid classifying as white +max_value_white: 200 # Maximum value above which colors are considered white +min_value_black: 50 # Minimum value below which colors are considered black +``` + +## 🛠️ Usage + +The Bitmap Tracer can be run from the command line in two ways: + +### From the sketchgetdp directory as a python module: +```bash +python -m bitmap_tracer +``` + +### From the bitmap_tracer directory: +```bash +python main.py +``` + +Where `` is the path to the bitmap image you want to convert to SVG. + +The application will automatically: +- Load configuration from `config.yaml` +- Process the input image +- Generate an SVG output file with the same name as the input image (changing extension to .svg) +- Apply color categorization and structure filtering based on your configuration + +## 📊 Output + +The tracer generates SVG files with: +- **Blue paths** - Curved and straight segments from blue strokes +- **Green paths** - Curved and straight segments from green strokes +- **Red points** - Simple dot markers from red shapes and small points +- Clean, optimized vector paths suitable for scaling and further processing + +## 🔧 Dependencies + +- OpenCV - Image processing and contour detection +- NumPy - Numerical computations +- svgwrite - SVG generation +- PyYAML - Configuration parsing + +## 🎨 Use Cases + +- Converting hand-drawn sketches to vector graphics +- Processing technical diagrams and schematics +- Creating scalable versions of bitmap artwork +- Extracting structured information from images + +The Bitmap Tracer excels at transforming complex bitmap images into clean, manageable vector representations while preserving the essential structure and color information. \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/__init__.py b/sketchgetdp/bitmap_tracer/__init__.py new file mode 100644 index 0000000..319b7ae --- /dev/null +++ b/sketchgetdp/bitmap_tracer/__init__.py @@ -0,0 +1,8 @@ +""" +Bitmap Tracer Package + +A clean architecture implementation for converting bitmap images to SVG vector graphics. +""" + +__version__ = "2.0.0" +__author__ = "Sarah Schleidt" \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/__main__.py b/sketchgetdp/bitmap_tracer/__main__.py new file mode 100644 index 0000000..a4643d1 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/__main__.py @@ -0,0 +1,219 @@ +""" +Bitmap Tracer Application - Entry Point + +The application converts bitmap images to SVG vector graphics through a structured +process of contour detection, color analysis, and vector path generation. +""" + +import sys +import os +import argparse + +from interfaces.controllers.tracing_controller import TracingController + + +def find_config_file(config_path: str) -> str: + """ + Find configuration file, checking multiple possible locations. + + Priority order: + 1. User-specified path (absolute or relative to cwd) + 2. Relative to current working directory + 3. In the package directory (for default config) + + Returns: + Path to the first found config file, or original path if none found. + """ + from pathlib import Path + + search_paths = [ + Path(config_path), # User-specified path + Path.cwd() / config_path, # Current working directory + Path(__file__).parent / config_path, # Package directory (where main.py lives) + ] + + for path in search_paths: + if path.exists(): + print(f"✅ Found configuration file: {path}") + return str(path) + + print(f"⚠️ Configuration file not found: {config_path}, using defaults") + return config_path # Return original if not found anywhere + + +def validate_input_file_exists(file_path: str) -> None: + """ + Validates that the specified file exists and is readable. + + Args: + file_path: Absolute or relative path to the file to validate. + + Raises: + FileNotFoundError: When the specified file does not exist. + PermissionError: When the file exists but cannot be read. + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"Input image not found: {file_path}") + + if not os.access(file_path, os.R_OK): + raise PermissionError(f"Cannot read input image: {file_path}") + + +def parse_command_line_arguments() -> argparse.Namespace: + """ + Parses and validates command-line arguments provided by the user. + + Returns: + Parsed arguments object containing: + - input_image: Path to source bitmap file + - output: Path for generated SVG file + - config: Path to configuration file + + Raises: + SystemExit: When help is requested or arguments are invalid. + """ + argument_parser = argparse.ArgumentParser( + description=( + 'Convert bitmap images to SVG vector graphics using ' + 'advanced computer vision techniques. The tracer detects ' + 'contours, analyzes colors, and generates optimized vector paths.' + ), + epilog=( + 'Example usage:\n' + ' python main.py drawing.jpg\n' + ' python main.py sketch.png -o output.svg -c settings.yaml\n' + ), + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + argument_parser.add_argument( + 'input_image', + help='Path to input bitmap image (supports JPEG, PNG, BMP formats)' + ) + + argument_parser.add_argument( + '-o', '--output', + default='output.svg', + help='Output SVG file path (default: output.svg)' + ) + + argument_parser.add_argument( + '-c', '--config', + default='config.yaml', + help='Configuration file controlling tracing behavior (default: config.yaml)' + ) + + arguments = argument_parser.parse_args() + + # Find the actual config file location + arguments.config = find_config_file(arguments.config) + + return arguments + + +def execute_tracing_pipeline(input_path: str, output_path: str, config_path: str) -> bool: + """ + Executes the complete bitmap-to-SVG tracing pipeline. + + Args: + input_path: Path to source bitmap image. + output_path: Path where SVG output will be saved. + config_path: Path to YAML configuration file. + + Returns: + True if SVG was generated successfully, False otherwise. + """ + try: + controller = TracingController() + + # Execute the tracing workflow + result = controller.trace_image( + image_path=input_path, + output_svg_path=output_path, + config_path=config_path + ) + + # Return success status + return result.get('success', False) + + except Exception as processing_error: + print(f"❌ Tracing pipeline error: {processing_error}") + return False + + +def log_application_startup(arguments: argparse.Namespace) -> None: + """ + Logs application startup parameters for user verification. + + Args: + arguments: Parsed command-line arguments containing execution parameters. + """ + print("🖼️ Bitmap Tracer Application Starting - Clean Architecture") + print("=" * 50) + print(f"📁 Input Image: {arguments.input_image}") + print(f"📁 Output SVG: {arguments.output}") + print(f"⚙️ Configuration: {arguments.config}") + print("=" * 50) + + +def log_application_result(success: bool, output_path: str = "") -> None: + """ + Logs the final result of the tracing operation. + + Args: + success: True if tracing completed successfully, False otherwise. + output_path: Path to the generated SVG file (on success). + """ + if success: + print(f"✅ Tracing completed successfully - SVG file generated: {output_path}") + else: + print("❌ Tracing failed - check error messages above for details.") + + +def main() -> None: + """ + Entry point for the Bitmap Tracer. + + This function orchestrates the complete application workflow: + 1. Parse and validate command-line arguments + 2. Verify input file existence and accessibility + 3. Execute the tracing pipeline via TracingController + 4. Provide clear success/failure feedback + 5. Return appropriate exit codes + + System Exit Codes: + 0: Success - SVG file generated successfully + 1: Failure - Invalid input, processing error, or file issues + 2: System error - Unexpected application failure + """ + try: + arguments = parse_command_line_arguments() + validate_input_file_exists(arguments.input_image) + log_application_startup(arguments) + + tracing_success = execute_tracing_pipeline( + input_path=arguments.input_image, + output_path=arguments.output, + config_path=arguments.config + ) + + log_application_result(tracing_success, arguments.output) + exit_code = 0 if tracing_success else 1 + sys.exit(exit_code) + + except FileNotFoundError as file_error: + print(f"❌ File error: {file_error}") + sys.exit(1) + except PermissionError as permission_error: + print(f"❌ Permission error: {permission_error}") + sys.exit(1) + except KeyboardInterrupt: + print("\n⚠️ Operation cancelled by user") + sys.exit(1) + except Exception as unexpected_error: + print(f"💥 Unexpected application error: {unexpected_error}") + sys.exit(2) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/config.yaml b/sketchgetdp/bitmap_tracer/config.yaml new file mode 100644 index 0000000..586ee8a --- /dev/null +++ b/sketchgetdp/bitmap_tracer/config.yaml @@ -0,0 +1,26 @@ +# Bitmap Tracer Configuration +# +# This configuration file controls the behavior of the bitmap tracing process. +# All parameters have sensible defaults and can be adjusted based on the +# characteristics of the input images and desired output quality. + +## Structure Limits +# Maximum number of structures to keep for each color category after filtering. +# Structures are sorted by area (largest first) and only the top N are kept. +red_dots: 1 # Maximum red points to preserve +blue_paths: 1 # Maximum blue paths to preserve +green_paths: 1 # Maximum green paths to preserve + +## Contour Detection Parameters +# Control how contours are detected and filtered from the source image. +point_max_area: 2000 # Maximum area for a contour to be classified as a point +point_max_perimeter: 1000 # Maximum perimeter for point classification + +## Color Detection Parameters +# Define thresholds for categorizing colors in the source image. +blue_hue_range: [100, 140] # HSV hue range for blue color detection +red_hue_range: [[0, 10], [170, 180]] # HSV hue ranges for red color detection +green_hue_range: [35, 85] # HSV hue range for green color detection +min_saturation: 50 # Minimum saturation to avoid classifying as white +max_value_white: 200 # Maximum value above which colors are considered white +min_value_black: 50 # Minimum value below which colors are considered black diff --git a/sketchgetdp/geometry/__init__.py b/sketchgetdp/bitmap_tracer/core/__init__.py similarity index 100% rename from sketchgetdp/geometry/__init__.py rename to sketchgetdp/bitmap_tracer/core/__init__.py diff --git a/sketchgetdp/bitmap_tracer/core/entities/__init__.py b/sketchgetdp/bitmap_tracer/core/entities/__init__.py new file mode 100644 index 0000000..76004ba --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/entities/__init__.py @@ -0,0 +1,22 @@ +""" +Core business entities for bitmap tracing. +These objects represent the fundamental concepts that drive the tracing algorithm: +- Spatial coordinates and relationships +- Shape boundaries and properties +- Color classification and standardization + +The entities contain the business rules that determine how bitmap features +are interpreted and converted to vector graphics. +""" + +from .point import Point, PointData +from .contour import Contour +from .color import Color, ColorCategory + +__all__ = [ + 'Point', + 'PointData', + 'Contour', + 'Color', + 'ColorCategory' +] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/core/entities/color.py b/sketchgetdp/bitmap_tracer/core/entities/color.py new file mode 100644 index 0000000..52c8b4d --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/entities/color.py @@ -0,0 +1,113 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Tuple + + +class ColorCategory(Enum): + """ + The three primary colors we track plus ignored categories. + This classification drives the entire tracing strategy. + """ + BLUE = "blue" + RED = "red" + GREEN = "green" + WHITE = "white" # Background - ignored + BLACK = "black" # Noise - ignored + OTHER = "other" # Unsupported colors - ignored + + +@dataclass(frozen=True) +class Color: + """ + Represents a color in BGR format (OpenCV standard). + Immutable to ensure consistent color handling throughout the pipeline. + """ + b: int + g: int + r: int + + CATEGORY_HEX_COLORS = { + ColorCategory.BLUE: "#0000FF", + ColorCategory.RED: "#FF0000", + ColorCategory.GREEN: "#00FF00" + } + + def to_bgr_tuple(self) -> Tuple[int, int, int]: + """BGR format for image processing libraries.""" + return (self.b, self.g, self.r) + + def to_rgb_tuple(self) -> Tuple[int, int, int]: + """Standard RGB format for web and graphics applications.""" + return (self.r, self.g, self.b) + + def to_hex(self) -> str: + """Hex format for SVG color attributes.""" + return f"#{self.r:02x}{self.g:02x}{self.b:02x}".upper() + + def categorize(self) -> Tuple[ColorCategory, Optional[str]]: + """ + Core color classification logic. + Uses HSV space for more accurate color perception than RGB. + Returns both category and standardized output color. + """ + import cv2 + import numpy as np + + bgr_array = np.uint8([[[self.b, self.g, self.r]]]) + hsv = cv2.cvtColor(bgr_array, cv2.COLOR_BGR2HSV)[0][0] + hue, saturation, value = hsv + + # High value + low saturation = white/light colors (background) + if value > 200 and saturation < 50: + return ColorCategory.WHITE, None + + # Low value = dark colors (noise) + if value < 50: + return ColorCategory.BLACK, None + + # Primary color detection uses both HSV ranges and RGB relationships + # as fallback for edge cases + if (hue >= 100 and hue <= 140) or (self.b > self.g + 20 and self.b > self.r + 20): + return ColorCategory.BLUE, self.CATEGORY_HEX_COLORS[ColorCategory.BLUE] + elif (hue >= 0 and hue <= 10) or (hue >= 170 and hue <= 180) or (self.r > self.g + 20 and self.r > self.b + 20): + return ColorCategory.RED, self.CATEGORY_HEX_COLORS[ColorCategory.RED] + elif (hue >= 35 and hue <= 85) or (self.g > self.r + 20 and self.g > self.b + 20): + return ColorCategory.GREEN, self.CATEGORY_HEX_COLORS[ColorCategory.GREEN] + else: + return ColorCategory.OTHER, None + + def is_ignored_color(self) -> bool: + """Determines if this color should be excluded from tracing results.""" + category, _ = self.categorize() + return category in [ColorCategory.WHITE, ColorCategory.BLACK, ColorCategory.OTHER] + + def is_primary_color(self) -> bool: + """Checks if this is one of the three colors we actively trace.""" + category, _ = self.categorize() + return category in [ColorCategory.BLUE, ColorCategory.RED, ColorCategory.GREEN] + + @classmethod + def from_bgr_tuple(cls, bgr_tuple: Tuple[int, int, int]) -> 'Color': + """Primary constructor - images from OpenCV are in BGR format.""" + return cls(b=bgr_tuple[0], g=bgr_tuple[1], r=bgr_tuple[2]) + + @classmethod + def from_rgb_tuple(cls, rgb_tuple: Tuple[int, int, int]) -> 'Color': + """Alternative constructor for RGB sources.""" + return cls(b=rgb_tuple[2], g=rgb_tuple[1], r=rgb_tuple[0]) + + @classmethod + def from_hex(cls, hex_code: str) -> 'Color': + """Constructor for web colors and configuration values.""" + hex_code = hex_code.lstrip('#') + + # Support both #RGB and #RRGGBB formats + if len(hex_code) == 3: + hex_code = ''.join(character * 2 for character in hex_code) + + red = int(hex_code[0:2], 16) + green = int(hex_code[2:4], 16) + blue = int(hex_code[4:6], 16) + + return cls(b=blue, g=green, r=red) + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/core/entities/contour.py b/sketchgetdp/bitmap_tracer/core/entities/contour.py new file mode 100644 index 0000000..e1333b3 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/entities/contour.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass +from typing import List, Optional +import numpy as np +import cv2 + +from .point import Point + + +@dataclass +class Contour: + """ + A closed shape detected in the bitmap image. + The closure status is critical for proper SVG path generation. + """ + points: List[Point] + is_closed: bool + closure_gap: float + + def __post_init__(self): + """Make a defensive copy of the points list to prevent external mutation.""" + self.points = self.points.copy() + + @property + def area(self) -> float: + """ + Calculates area using the shoelace formula. + Used for filtering out noise and prioritizing larger structures. + """ + if len(self.points) < 3: + return 0.0 + + x_coordinates = [point.x for point in self.points] + y_coordinates = [point.y for point in self.points] + + # Shoelace formula: ∑(x_i * y_i+1 - x_i+1 * y_i) / 2 + area = 0.5 * abs(sum( + x_coordinates[i] * y_coordinates[i + 1] - x_coordinates[i + 1] * y_coordinates[i] + for i in range(len(x_coordinates) - 1) + ) + (x_coordinates[-1] * y_coordinates[0] - x_coordinates[0] * y_coordinates[-1])) + + return area + + @property + def perimeter(self) -> float: + """Total boundary length, used for circularity calculation and simplification thresholds.""" + if len(self.points) < 2: + return 0.0 + + perimeter = 0.0 + for i in range(len(self.points)): + current_point = self.points[i] + next_point = self.points[(i + 1) % len(self.points)] # Wrap for closed contour + perimeter += current_point.distance_to(next_point) + + return perimeter + + @property + def circularity(self) -> float: + """ + Measures how circular the shape is (4πA/P²). + Perfect circle = 1.0, other shapes < 1.0. + Used to filter out irregular noise artifacts. + """ + area = self.area + perimeter = self.perimeter + + if perimeter == 0: + return 0.0 + + return (4 * np.pi * area) / (perimeter * perimeter) + + def get_center(self) -> Optional[Point]: + """Centroid calculation for point marker placement and spatial analysis.""" + if not self.points: + return None + + sum_x = sum(point.x for point in self.points) + sum_y = sum(point.y for point in self.points) + + return Point(sum_x / len(self.points), sum_y / len(self.points)) + + @classmethod + def from_numpy_contour(cls, contour: np.ndarray, tolerance: float = 5.0) -> 'Contour': + """ + Converts OpenCV contour format to domain representation. + The tolerance parameter controls how close endpoints must be to consider the contour closed. + """ + if len(contour) == 0: + return cls(points=[], is_closed=True, closure_gap=0.0) + + # OpenCV contours are nested arrays: [[[x, y]]], [[[x, y]]], ... + points = [Point(float(point[0][0]), float(point[0][1])) for point in contour] + + if len(points) < 3: + return cls(points=points, is_closed=False, closure_gap=0.0) + + # Closure detection: if start and end points are within tolerance, contour is closed + start_point = points[0] + end_point = points[-1] + closure_gap = start_point.distance_to(end_point) + + points_are_identical = (start_point.x == end_point.x and start_point.y == end_point.y) + + # Consider contour closed if either: + # 1. Points are within tolerance (natural closure) + # 2. Points are identical (explicit closure by closure service) + is_closed = closure_gap <= tolerance or points_are_identical + + actual_closure_gap = 0.0 if points_are_identical else closure_gap + + # Debug output to verify closure detection + closure_type = "explicit" if points_are_identical else "natural" if is_closed else "open" + print(f" 🔍 Contour closure: {closure_type}, gap: {actual_closure_gap:.2f}px, points: {len(points)}") + + return cls(points=points, is_closed=is_closed, closure_gap=actual_closure_gap) + + def to_numpy(self) -> np.ndarray: + """ + Convert contour points to OpenCV numpy format. + + Returns: + Numpy array in format [[[x, y]], [[x, y]], ...] for OpenCV compatibility + """ + points_array = np.array([[point.x, point.y] for point in self.points], dtype=np.float32) + return points_array.reshape(-1, 1, 2) + + def is_empty(self) -> bool: + """ + Check if contour has no points. + + Returns: + True if contour has no points, False otherwise + """ + return len(self.points) == 0 + + def get_bounding_box(self) -> Optional[tuple]: + """ + Calculate the axis-aligned bounding box of the contour. + + Returns: + Tuple (min_x, min_y, max_x, max_y) or None if contour is empty + """ + if not self.points: + return None + + x_coords = [point.x for point in self.points] + y_coords = [point.y for point in self.points] + + return (min(x_coords), min(y_coords), max(x_coords), max(y_coords)) + + def simplify(self, epsilon: float = 1.0) -> 'Contour': + """ + Simplify the contour using Douglas-Peucker algorithm. + + Args: + epsilon: Approximation accuracy parameter + + Returns: + New simplified Contour instance + """ + if len(self.points) < 3: + return self + + # Convert to numpy for OpenCV processing + numpy_contour = self.to_numpy() + + # Apply Douglas-Peucker simplification + simplified_numpy = cv2.approxPolyDP(numpy_contour, epsilon, self.is_closed) + + # Convert back to domain entity + return Contour.from_numpy_contour(simplified_numpy) + + def __len__(self) -> int: + """Return the number of points in the contour.""" + return len(self.points) + + def __repr__(self) -> str: + """String representation for debugging.""" + status = "CLOSED" if self.is_closed else "OPEN" + return f"Contour(points={len(self.points)}, {status}, area={self.area:.1f}, gap={self.closure_gap:.2f}px)" + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/core/entities/point.py b/sketchgetdp/bitmap_tracer/core/entities/point.py new file mode 100644 index 0000000..7ef4bd9 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/entities/point.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass +from typing import Tuple + + +@dataclass(frozen=True) +class Point: + """ + Represents a coordinate in 2D space. + This is a value object and should be immutable. + """ + x: float + y: float + + def to_tuple(self) -> Tuple[float, float]: + """Required for compatibility with OpenCV and other libraries that expect tuples.""" + return (self.x, self.y) + + def distance_to(self, other: 'Point') -> float: + """Euclidean distance calculation for spatial analysis.""" + return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5 + + @classmethod + def from_tuple(cls, point_tuple: Tuple[float, float]) -> 'Point': + """Factory method for creating Points from external data formats.""" + return cls(x=point_tuple[0], y=point_tuple[1]) + + +@dataclass(frozen=True) +class PointData: + """ + Enhanced point information for the tracing algorithm. + Contains metadata needed for point detection and SVG generation. + """ + x: float + y: float + radius: float = 0.0 + is_small_point: bool = False + + @property + def center(self) -> Point: + """The center coordinate is the primary spatial identifier.""" + return Point(self.x, self.y) + + def to_point(self) -> Point: + """Extracts the basic spatial information when full metadata isn't needed.""" + return Point(self.x, self.y) + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/core/use_cases/__init__.py b/sketchgetdp/bitmap_tracer/core/use_cases/__init__.py new file mode 100644 index 0000000..43b754b --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/use_cases/__init__.py @@ -0,0 +1,29 @@ +""" +Application Business Rules Layer - Use Cases. + +This package contains the application-specific business rules that coordinate +the workflow between enterprise entities and interface adapters. Use cases +encapsulate and implement all of the application's business rules while +remaining independent of frameworks, UI, and databases. + +The use cases in this layer: +- Contain application-specific business logic +- Coordinate data flow between entities and adapters +- Define the application's behavior independent of delivery mechanisms +- Are the central organizing structure for the application's capabilities + +Use cases should: +- Be framework-agnostic +- Operate on enterprise entities +- Contain no infrastructure concerns +- Be easily testable in isolation +- Express the application's intent clearly +""" + +from .image_tracing import ImageTracingUseCase +from .structure_filtering import StructureFilteringUseCase + +__all__ = [ + 'ImageTracingUseCase', + 'StructureFilteringUseCase' +] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/core/use_cases/image_tracing.py b/sketchgetdp/bitmap_tracer/core/use_cases/image_tracing.py new file mode 100644 index 0000000..6e55580 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/use_cases/image_tracing.py @@ -0,0 +1,140 @@ +import numpy as np +from typing import List, Tuple, Optional +from core.entities.point import Point +from core.entities.contour import Contour + + +class ImageTracingUseCase: + """Coordinates the image tracing workflow from bitmap contours to vector paths.""" + + def __init__(self, contour_detector=None, color_analyzer=None, point_detector=None): + """ + Initialize use case with required dependencies. + + Args: + contour_detector: Service for detecting contours in images + color_analyzer: Service for analyzing contour colors + point_detector: Service for identifying point structures + """ + self.contour_detector = contour_detector + self.color_analyzer = color_analyzer + self.point_detector = point_detector + + def execute(self, image_data: dict, config: dict) -> dict: + """ + Main execution method for the image tracing use case. + """ + try: + print("🔍 Detecting contours...") + # Detect contours from the image + contours = self.detect_contours(image_data) + print(f"📐 Found {len(contours)} contours") + + red_points = [] + blue_structures = [] + green_structures = [] + + # Process each contour + for i, contour in enumerate(contours): + print(f" Processing contour {i+1}/{len(contours)}...") + + # Categorize contour color + color_category = self.color_analyzer.categorize(contour, image_data['image_array']) + + # Check if it's a point + point = self.detect_points(contour, config) + + if point and color_category == 'red': + red_points.append(point) + print(f" 🔴 Contour {i+1}: RED POINT") + elif color_category == 'blue': + blue_structures.append(contour) + print(f" 🔵 Contour {i+1}: BLUE PATH") + elif color_category == 'green': + green_structures.append(contour) + print(f" 🟢 Contour {i+1}: GREEN PATH") + else: + print(f" ⚫ Contour {i+1}: UNCATEGORIZED (color: {color_category})") + + return { + 'success': True, + 'structures': { + 'red_points': red_points, + 'blue_structures': blue_structures, + 'green_structures': green_structures + }, + 'total_contours': len(contours), + 'processed_contours': len(red_points) + len(blue_structures) + len(green_structures) + } + + except Exception as error: + print(f"❌ Image tracing error: {error}") + import traceback + traceback.print_exc() + return { + 'success': False, + 'error': str(error), + 'structures': { + 'red_points': [], + 'blue_structures': [], + 'green_structures': [] + }, + 'total_contours': 0, + 'processed_contours': 0 + } + + def detect_contours(self, image_data) -> List[Contour]: + """ + Extracts contours from image data for vectorization. + """ + if self.contour_detector: + contours_tuple, hierarchy = self.contour_detector.detect(image_data) + + if contours_tuple is None: + return [] + + print(f"🔍 DEBUG: contours_tuple type: {type(contours_tuple)}, length: {len(contours_tuple)}") + + # Convert the tuple to a list for iteration + raw_contours_list = list(contours_tuple) + + if not raw_contours_list: + return [] + + # Convert all raw contours to Contour entities + contours = [self._convert_to_contour_entity(contour) for contour in raw_contours_list] + print(f"✅ Converted {len(contours)} contours to entities") + return contours + + print("⚠️ No contour detector available - returning empty list") + return [] + + def detect_points(self, contour: Contour, config: dict = None) -> Optional[Point]: + """ + Identifies if a contour represents a point marker rather than a path. + """ + if config and hasattr(self.point_detector, 'set_config'): + self.point_detector.set_config(config) + + numpy_contour = np.array([[[point.x, point.y]] for point in contour.points], dtype=np.int32) + point = self.point_detector.detect_point(numpy_contour) + + if point: + print(f" 📍 Point detected at ({point.x}, {point.y})") + else: + print(f" ❌ Point NOT detected - area: {contour.area:.1f}, perimeter: {contour.perimeter:.1f}, points: {len(contour.points)}") + + return point + + def _convert_to_contour_entity(self, raw_contour) -> Contour: + """ + Convert raw OpenCV contour to our domain Contour entity. + + Args: + raw_contour: Raw contour from OpenCV's findContours() + + Returns: + Contour entity with points and calculated properties + """ + return Contour.from_numpy_contour(raw_contour) + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/core/use_cases/structure_filtering.py b/sketchgetdp/bitmap_tracer/core/use_cases/structure_filtering.py new file mode 100644 index 0000000..f444f92 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/core/use_cases/structure_filtering.py @@ -0,0 +1,159 @@ +from typing import List, Tuple, Any, Dict +from core.entities.contour import Contour + + +class StructureFilteringUseCase: + """Applies business rules for filtering and prioritizing image structures.""" + + def __init__(self): + """ + Initialize use case. + """ + + def execute(self, structures: Dict[str, Any], config: Dict) -> Dict[str, Any]: + """ + Main execution method for the structure filtering use case. + """ + try: + print("🎯 Filtering structures based on configuration limits...") + + red_points = structures.get('red_points', []) + blue_structures = structures.get('blue_structures', []) + green_structures = structures.get('green_structures', []) + + # Apply configuration limits + max_red_dots = config.get('red_dots', 0) + max_blue_paths = config.get('blue_paths', 0) + max_green_paths = config.get('green_paths', 0) + + print(f"📊 Configuration limits: {max_red_dots} red, {max_blue_paths} blue, {max_green_paths} green") + + # Filter red points + if max_red_dots > 0 and len(red_points) > max_red_dots: + print(f" 🔴 Limiting red points from {len(red_points)} to {max_red_dots}") + red_points = red_points[:max_red_dots] + + # Filter blue structures + if max_blue_paths > 0 and len(blue_structures) > max_blue_paths: + print(f" 🔵 Limiting blue paths from {len(blue_structures)} to {max_blue_paths}") + blue_structures = blue_structures[:max_blue_paths] + + # Filter green structures + if max_green_paths > 0 and len(green_structures) > max_green_paths: + print(f" 🟢 Limiting green paths from {len(green_structures)} to {max_green_paths}") + green_structures = green_structures[:max_green_paths] + + filtered_structures = { + 'red_points': red_points, + 'blue_structures': blue_structures, + 'green_structures': green_structures + } + + total_filtered = len(red_points) + len(blue_structures) + len(green_structures) + print(f"✅ Filtering complete: {total_filtered} structures remaining") + + return filtered_structures + + except Exception as error: + print(f"❌ Structure filtering error: {error}") + import traceback + traceback.print_exc() + # Return original structures on error + return structures + + def filter_structures_by_area(self, + structures: List[Tuple[float, Any]], + max_count: int) -> List[Tuple[float, Any]]: + """ + Retains only the largest structures up to the specified count limit. + + Structures are sorted by area in descending order and the top N are kept. + This prioritization ensures the most significant structures are processed + while maintaining performance by limiting total output. + + Args: + structures: List of (area, structure_data) tuples to filter + max_count: Maximum number of structures to retain after filtering + + Returns: + Filtered list containing only the largest structures up to max_count + """ + if max_count <= 0: + return [] + + structures.sort(key=lambda x: x[0], reverse=True) + + if max_count < len(structures): + return structures[:max_count] + + return structures + + def filter_contours_by_size(self, + contours: List[Contour], + min_area: float, + max_area: float) -> List[Contour]: + """ + Removes contours that fall outside the acceptable size range. + + Filters out noise (too small) and background elements (too large) based + on area thresholds. This focuses processing on meaningful structures. + + Args: + contours: Contours to evaluate against size constraints + min_area: Minimum area threshold - contours smaller than this are excluded + max_area: Maximum area threshold - contours larger than this are excluded + + Returns: + Contours that meet the size criteria + """ + filtered_contours = [] + + for contour in contours: + area = contour.area + if min_area <= area <= max_area: + filtered_contours.append(contour) + + return filtered_contours + + def filter_by_circularity(self, + contours: List[Contour], + min_circularity: float = 0.01) -> List[Contour]: + """ + Eliminates contours with irregular shapes that likely represent noise. + + Circularity measures how close a shape is to a perfect circle. Very low + circularity values indicate elongated, fragmented, or noisy contours + that should be excluded from vectorization. + + Args: + contours: Contours to evaluate for shape regularity + min_circularity: Minimum circularity threshold (1.0 = perfect circle) + + Returns: + Contours with acceptable circularity values + """ + filtered_contours = [] + + for contour in contours: + if contour.perimeter > 0: + circularity = 4 * 3.14159 * contour.area / (contour.perimeter * contour.perimeter) + if circularity >= min_circularity: + filtered_contours.append(contour) + + return filtered_contours + + def sort_contours_by_area(self, contours: List[Contour], descending: bool = True) -> List[Contour]: + """ + Orders contours by their area for priority processing. + + Larger contours typically represent more important structures. Sorting + enables processing prioritization and consistent output ordering. + + Args: + contours: Contours to sort by area + descending: True for largest first, False for smallest first + + Returns: + Contours sorted by area + """ + return sorted(contours, key=lambda c: c.area, reverse=descending) diff --git a/sketchgetdp/bitmap_tracer/infrastructure/__init__.py b/sketchgetdp/bitmap_tracer/infrastructure/__init__.py new file mode 100644 index 0000000..4ba0985 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/__init__.py @@ -0,0 +1,35 @@ +""" +Infrastructure Layer - Frameworks & Drivers + +Contains concrete implementations of technical concerns and external interfaces. +This layer is the outermost in Clean Architecture and depends inward toward the core. + +Responsibilities: +- Image processing with OpenCV +- Shape processing +- Configuration file management +- Point detection and curve fitting algorithms + +Dependencies: +- Can depend on core entities and use cases +- Must not contain business logic +- Implements interfaces defined in the interfaces layer +""" + +from .image_processing import * +from .configuration import * +from .point_detection import * + +__all__ = [ + # Image processing components + "ContourDetector", + "ColorAnalyzer", + "ContourClosureService", + + # Configuration components + "ConfigLoader", + + # Point detection components + "PointDetector", + "CurveFitter", +] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/configuration/__init__.py b/sketchgetdp/bitmap_tracer/infrastructure/configuration/__init__.py new file mode 100644 index 0000000..d511a22 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/configuration/__init__.py @@ -0,0 +1,11 @@ +""" +Configuration infrastructure module. + +This module provides services for loading and managing application configuration. +It follows the dependency inversion principle by implementing gateway interfaces +defined in the interfaces layer. +""" + +from .config_loader import ConfigLoader + +__all__ = ['ConfigLoader'] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/configuration/config_loader.py b/sketchgetdp/bitmap_tracer/infrastructure/configuration/config_loader.py new file mode 100644 index 0000000..4d9e1db --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/configuration/config_loader.py @@ -0,0 +1,97 @@ +""" +Configuration loader implementation. + +Responsible for loading configuration parameters from YAML files and providing +type-safe access to different configuration categories. This class implements +the ConfigRepository interface from the application core. +""" + +import yaml +import os +from typing import Tuple, Dict, Any, Optional +from interfaces.gateways.config_repository import ConfigRepository + + +class ConfigLoader(ConfigRepository): + """ + Loads and manages application configuration from YAML files. + """ + + def __init__(self, default_config_path: str = "config.yaml") -> None: + self.default_config_path = default_config_path + self._config_cache = None + self._overrides = {} # For runtime configuration overrides + + def load_config(self, config_path: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Load configuration data from the YAML file. + + Args: + config_path: Optional path to config file, uses default if not provided + + Returns: + Dictionary containing all configuration key-value pairs, or None if loading fails + """ + if self._config_cache is not None: + return self._apply_overrides(self._config_cache) + + actual_config_path = config_path or self.default_config_path + + if not os.path.exists(actual_config_path): + print(f"⚠️ Configuration file not found: {actual_config_path}, using defaults") + self._config_cache = {} + return self._apply_overrides(self._config_cache) + + try: + with open(actual_config_path, 'r') as file: + config = yaml.safe_load(file) + self._config_cache = config or {} + print(f"✅ Loaded configuration from: {actual_config_path}") + return self._apply_overrides(self._config_cache) + except yaml.YAMLError as e: + print(f"❌ Error parsing YAML configuration {actual_config_path}: {e}") + return None + except Exception as e: + print(f"❌ Error loading configuration {actual_config_path}: {e}") + return None + + def get_structure_limits(self) -> Tuple[int, int, int]: + """Get the maximum number of structures to keep for each color category.""" + config = self.load_config() or {} + + red_dots = config.get('red_dots', 0) + blue_paths = config.get('blue_paths', 0) + green_paths = config.get('green_paths', 0) + + return red_dots, blue_paths, green_paths + + def get_contour_detection_params(self) -> Dict[str, Any]: + """Get parameters for contour detection and filtering.""" + config = self.load_config() or {} + + return { + 'point_max_area': config.get('point_max_area', 100), + 'point_max_perimeter': config.get('point_max_perimeter', 80) + } + + def get_color_detection_params(self) -> Dict[str, Any]: + """Get parameters for color categorization.""" + config = self.load_config() or {} + + return { + 'blue_hue_range': config.get('blue_hue_range', [100, 140]), + 'red_hue_range': config.get('red_hue_range', [[0, 10], [170, 180]]), + 'green_hue_range': config.get('green_hue_range', [35, 85]), + 'min_saturation': config.get('min_saturation', 50), + 'max_value_white': config.get('max_value_white', 200), + 'min_value_black': config.get('min_value_black', 50) + } + + def _apply_overrides(self, config: Dict[str, Any]) -> Dict[str, Any]: + """Apply runtime overrides to the configuration.""" + if not self._overrides: + return config + + result = config.copy() + result.update(self._overrides) + return result + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/image_processing/__init__.py b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/__init__.py new file mode 100644 index 0000000..7ee676c --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/__init__.py @@ -0,0 +1,19 @@ +""" +Image processing infrastructure layer for the bitmap tracing system. + +This package provides the core image analysis capabilities including contour detection, +color analysis, and geometric processing. These components implement the framework-side +concerns of the Clean Architecture, handling OpenCV interactions and image processing +algorithms while exposing clean interfaces to the domain layer. +""" + +from .contour_detector import ContourDetector +from .color_analyzer import ColorAnalyzer +from .contour_closure_service import ContourClosureService, ClosedContour + +__all__ = [ + 'ContourDetector', + 'ColorAnalyzer', + 'ContourClosureService', + 'ClosedContour' +] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/image_processing/color_analyzer.py b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/color_analyzer.py new file mode 100644 index 0000000..073ebd5 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/color_analyzer.py @@ -0,0 +1,240 @@ +import cv2 +import numpy as np +from collections import defaultdict +from typing import Tuple, Optional, Dict, List +from core.entities.color import ColorCategory + +class ColorAnalyzer: + """ + Analyzes and categorizes colors in images and contours using HSV color space. + + This class provides color classification capabilities that distinguish between + major color groups (blue, red, green) while filtering out background colors + (white, black) and undefined colors. + """ + + def __init__(self, config: Dict = None): + """ + Initialize with optional configuration for color ranges. + + Args: + config: Dictionary containing color detection parameters + """ + self.config = config or {} + # Set default ranges if not provided in config + self.blue_hue_range = self.config.get('blue_hue_range', [100, 140]) + self.red_hue_ranges = self.config.get('red_hue_range', [[0, 10], [170, 180]]) + self.green_hue_range = self.config.get('green_hue_range', [35, 85]) + self.min_saturation = self.config.get('min_saturation', 50) + self.max_value_white = self.config.get('max_value_white', 200) + self.min_value_black = self.config.get('min_value_black', 50) + + def categorize_color_pixel(self, bgr_color: List[int]) -> Tuple[ColorCategory, Optional[str]]: + """ + Classifies a BGR color pixel into one of the predefined color categories. + + Uses HSV color space for more perceptually accurate color discrimination + compared to RGB/BGR. The categorization focuses on the three primary colors + used in the tracing system while excluding background and noise colors. + + Args: + bgr_color: List of [blue, green, red] color values (0-255 range) + + Returns: + Tuple containing: + - ColorCategory enum value + - Standardized hex color code for primary colors, None for others + """ + if len(bgr_color) < 3: + return ColorCategory.OTHER, None + + b, g, r = bgr_color[:3] + + # Convert to HSV for perceptual color analysis + # HSV provides better separation of hue, saturation, and brightness + hsv_color = np.uint8([[[b, g, r]]]) + hsv = cv2.cvtColor(hsv_color, cv2.COLOR_BGR2HSV)[0][0] + hue, saturation, value = hsv + + # Debug output for red colors + if r > 150 and g < 100 and b < 100: + print(f"🔴 Potential red: RGB=({r},{g},{b}), HSV=({hue},{saturation},{value})") + + # Filter out near-white colors (high brightness, low saturation) + if value > self.max_value_white and saturation < self.min_saturation: + return ColorCategory.WHITE, None + + # Filter out near-black colors (very low brightness) + if value < self.min_value_black: + return ColorCategory.BLACK, None + + # Ensure minimum saturation for colorfulness + if saturation < self.min_saturation: + return ColorCategory.OTHER, None + + # Blue classification + blue_low, blue_high = self.blue_hue_range + if (blue_low <= hue <= blue_high) or (b > g + 20 and b > r + 20): + return ColorCategory.BLUE, "#0000FF" + + # Red classification + for red_low, red_high in self.red_hue_ranges: + if red_low <= hue <= red_high: + return ColorCategory.RED, "#FF0000" + # Also check RGB dominance for red + if (r > g + 30 and r > b + 30): + return ColorCategory.RED, "#FF0000" + + # Green classification + green_low, green_high = self.green_hue_range + if (green_low <= hue <= green_high) or (g > r + 20 and g > b + 20): + return ColorCategory.GREEN, "#00FF00" + + return ColorCategory.OTHER, None + + def get_dominant_color(self, contour: np.ndarray, original_image: np.ndarray) -> Optional[str]: + """Identifies dominant stroke color along contour boundary.""" + if contour is None: + print("❌ Contour is None") + return None + + try: + # Ensure contour is in the correct format for OpenCV + if len(contour) == 0: + print("❌ Empty contour array") + return None + + print(f"🔍 Initial contour shape: {contour.shape}, dtype: {contour.dtype}") + + # Make a copy and ensure it's the exact format OpenCV expects + contour_fixed = contour.astype(np.int32) # OpenCV often prefers int32 for drawContours + print(f"🔍 Fixed contour shape: {contour_fixed.shape}, dtype: {contour_fixed.dtype}") + + # Create boundary mask to isolate the actual stroke pixels + boundary_mask = np.zeros(original_image.shape[:2], np.uint8) + + # Try different approaches for drawing contours + try: + # Method 1: Direct drawing + cv2.drawContours(boundary_mask, [contour_fixed], 0, 255, 2) + except Exception as e1: + print(f"⚠️ Method 1 failed: {e1}") + try: + # Method 2: Ensure it's a list of contours + cv2.drawContours(boundary_mask, [contour_fixed], -1, 255, 2) + except Exception as e2: + print(f"⚠️ Method 2 failed: {e2}") + try: + # Method 3: Convert to list of points + points = contour_fixed.reshape(-1, 2) + contour_list = [points.astype(np.int32)] + cv2.drawContours(boundary_mask, contour_list, 0, 255, 2) + except Exception as e3: + print(f"❌ All contour drawing methods failed: {e3}") + return None + + # Check if we successfully drew anything + if np.count_nonzero(boundary_mask) == 0: + print("⚠️ No pixels drawn in boundary mask") + return None + + boundary_pixels = original_image[boundary_mask == 255] + + # Early return if no boundary pixels were sampled + if len(boundary_pixels) == 0: + print("⚠️ No boundary pixels found for color analysis") + return None + + print(f"🔍 Found {len(boundary_pixels)} boundary pixels for analysis") + + # Tally color categories from all boundary pixels + color_categories = defaultdict(int) + total_pixels = len(boundary_pixels) + + # Sample every 10th pixel for performance (unless it's a small contour) + step = max(1, total_pixels // 100) # Sample at most 100 pixels + + sampled_pixels = boundary_pixels[::step] + print(f"🔍 Analyzing {len(sampled_pixels)} sampled pixels") + + for pixel in sampled_pixels: + category, hex_color = self.categorize_color_pixel(pixel.tolist()) + # Only count meaningful color categories, ignore background colors + if category in [ColorCategory.BLUE, ColorCategory.RED, ColorCategory.GREEN]: + color_categories[category.value] += 1 + + # Debug output with percentages + if color_categories: + category_info = [] + for category, count in color_categories.items(): + percentage = (count / len(sampled_pixels)) * 100 + category_info.append(f"{category}: {count}({percentage:.1f}%)") + print(f"🎨 Color distribution: {', '.join(category_info)}") + else: + print("🎨 No primary colors detected in boundary pixels") + # Let's check what colors we ARE seeing + unique_colors = np.unique(boundary_pixels, axis=0) + print(f"🔍 Unique colors found: {len(unique_colors)}") + if len(unique_colors) > 0: + for i, color in enumerate(unique_colors[:5]): # Show first 5 unique colors + b, g, r = color + print(f" Color {i}: BGR({b},{g},{r})") + + # Determine the most frequent valid color category + if color_categories: + dominant_category = max(color_categories.items(), key=lambda x: x[1])[0] + + # Map category to standardized hex color + color_map = { + "blue": "#0000FF", + "red": "#FF0000", + "green": "#00FF00" + } + dominant_color = color_map.get(dominant_category) + print(f"🎯 Dominant color: {dominant_color}") + return dominant_color + + print("⚠️ No valid color categories found") + return None + + except Exception as e: + print(f"❌ Error in get_dominant_color: {e}") + import traceback + traceback.print_exc() + return None + + def categorize(self, contour, image: np.ndarray) -> Optional[str]: + """ + MAIN INTERFACE METHOD - Categorizes the dominant color of a contour. + """ + if hasattr(contour, 'to_numpy'): + # It's a Contour entity - convert to numpy for OpenCV processing + contour_points = contour.to_numpy() + print(f"🔍 ColorAnalyzer.categorize() called with Contour entity: {len(contour.points)} points, area: {contour.area:.1f}") + print(f"🔍 Contour numpy shape: {contour_points.shape}, dtype: {contour_points.dtype}") + print(f"🔍 First few points: {contour_points[:3] if len(contour_points) > 0 else 'EMPTY'}") + elif hasattr(contour, 'points'): + # Alternative check for Contour entity + contour_points = np.array([[point.x, point.y] for point in contour.points], dtype=np.float32).reshape(-1, 1, 2) + print(f"🔍 ColorAnalyzer.categorize() called with Contour entity: {len(contour.points)} points, area: {contour.area:.1f}") + print(f"🔍 Manual numpy shape: {contour_points.shape}, dtype: {contour_points.dtype}") + + # Check if contour_points is valid + if contour_points is None or len(contour_points) == 0: + print("❌ Empty contour points, skipping color analysis") + return None + + hex_color = self.get_dominant_color(contour_points, image) + + if hex_color == "#FF0000": + print("✅ Categorized as RED") + return "red" + elif hex_color == "#0000FF": + print("✅ Categorized as BLUE") + return "blue" + elif hex_color == "#00FF00": + print("✅ Categorized as GREEN") + return "green" + else: + print(f"❌ No dominant color found, got: {hex_color}") + return None diff --git a/sketchgetdp/bitmap_tracer/infrastructure/image_processing/contour_closure_service.py b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/contour_closure_service.py new file mode 100644 index 0000000..3d5cb1e --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/contour_closure_service.py @@ -0,0 +1,116 @@ +import cv2 +import numpy as np +from typing import List +from dataclasses import dataclass + + +@dataclass +class ClosedContour: + """ + Represents a contour with closure verification and metrics. + + This immutable data structure provides a clean interface for contour + information, ensuring closure status is explicitly tracked and available + to downstream processing stages. + + Attributes: + points: List of contour points as numpy arrays + is_closed: Boolean indicating whether the contour forms a closed shape + closure_gap: Distance between start and end points in pixels + """ + points: List[np.ndarray] + is_closed: bool + closure_gap: float + + +class ContourClosureService: + """ + Ensures contour closure and provides closure analysis utilities. + + Handles the important task of verifying and enforcing contour closure, + which is essential for generating valid SVG paths and proper shape rendering. + Open contours can cause rendering artifacts and incorrect shape interpretation. + """ + + def ensure_closure(self, contour: np.ndarray, tolerance: float = 5.0) -> np.ndarray: + """ + Guarantees a contour forms a closed loop by connecting endpoints if necessary. + + Checks the distance between start and end points. If beyond tolerance, + explicitly adds the start point to the end to create a mathematically + closed contour. This prevents rendering issues in downstream SVG generation. + + Args: + contour: numpy array of contour points to check and potentially close + tolerance: Maximum allowed gap between start and end points in pixels + + Returns: + Guaranteed closed contour as numpy array + """ + # Contours with less than 3 points cannot form closed shapes + if len(contour) < 3: + return contour + + start_point = contour[0][0] + end_point = contour[-1][0] + + # Calculate Euclidean distance between start and end points + distance = np.linalg.norm(start_point - end_point) + + # Explicitly close the contour if endpoints are too far apart + if distance > tolerance: + # Reshape start point to match contour array structure + start_point_reshaped = contour[0].reshape(1, 1, 2) + closed_contour = np.vstack([contour, start_point_reshaped]) + print(f" 🔒 Closed contour: start-end distance was {distance:.2f} pixels") + return closed_contour + + return contour + + def is_closed(self, contour: np.ndarray, tolerance: float = 5.0) -> bool: + """ + Determines if a contour forms a mathematically closed shape. + + A contour is considered closed if the distance between its start + and end points is within the specified tolerance. This is essential + for validating contour integrity before further processing. + + Args: + contour: numpy array of contour points to check + tolerance: Maximum allowed gap for considering the contour closed + + Returns: + True if contour is closed within tolerance, False otherwise + """ + # Contours with insufficient points cannot be closed + if len(contour) < 3: + return False + + start_point = contour[0][0] + end_point = contour[-1][0] + distance = np.linalg.norm(start_point - end_point) + + return distance <= tolerance + + def calculate_closure_gap(self, contour: np.ndarray) -> float: + """ + Calculates the precise gap distance between contour start and end points. + + This metric helps quantify how "open" a contour is and informs + closure decisions. Larger gaps may indicate detection errors or + intentionally open shapes. + + Args: + contour: numpy array of contour points to measure + + Returns: + Euclidean distance between start and end points in pixels, + or infinity for invalid contours + """ + if len(contour) < 3: + return float('inf') + + start_point = contour[0][0] + end_point = contour[-1][0] + return np.linalg.norm(start_point - end_point) + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/image_processing/contour_detector.py b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/contour_detector.py new file mode 100644 index 0000000..d4d0691 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/contour_detector.py @@ -0,0 +1,81 @@ +import cv2 +import numpy as np +from typing import Tuple, Optional, Dict +from .contour_closure_service import ContourClosureService + + +class ContourDetector: + """ + Detects and extracts contours from bitmap images using multiple thresholding strategies. + This class is responsible for the initial image processing and contour detection phase, + converting raster images into vectorizable shapes. + """ + + def __init__(self): + """Initialize the contour detector with closure service.""" + self.closure_service = ContourClosureService() + + def detect(self, image_data: Dict) -> Tuple[Optional[tuple], Optional[np.ndarray]]: + """ + Detects all contours in the provided image data using a multi-method thresholding approach. + + Args: + image_data: Dictionary containing 'image_array' with the image data + + Returns: + Tuple containing: + - Tuple of detected contours (or None if image loading fails) + - Contour hierarchy information as numpy array (or None if no contours detected) + """ + print(f"🔍 Detecting contours in image data...") + + # Extract image array from the data dictionary + img = image_data.get('image_array') + if img is None: + print(f"❌ No image array found in image data") + return None, None + + height, width = img.shape[:2] + print(f"📐 Image size: {width}x{height}") + + # Convert to grayscale as contour detection operates on single channel + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + # Apply multiple thresholding methods for robustness + binary1 = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY_INV, 15, 5) + + _, binary2 = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + # Combine results from both methods to capture all potential contours + combined = cv2.bitwise_or(binary1, binary2) + + # Apply morphological operations to clean up noise and connect broken segments + kernel = np.ones((3,3), np.uint8) + cleaned = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, kernel, iterations=2) + cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, kernel, iterations=1) + + # Extract contours with hierarchy to preserve parent-child relationships + contours, hierarchy = cv2.findContours(cleaned, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_KCOS) + + # Ensure all contours are closed + closed_contours = [] + for i, contour in enumerate(contours): + # Use the closure service to guarantee this contour is closed + closed_contour = self.closure_service.ensure_closure(contour) + closed_contours.append(closed_contour) + + # Debug information about closure status + original_length = len(contour) + closed_length = len(closed_contour) + is_closed = self.closure_service.is_closed(closed_contour) + closure_gap = self.closure_service.calculate_closure_gap(contour) + + closure_status = "🔒 CLOSED" if is_closed else "🔓 OPEN" + if closed_length > original_length: + closure_status += " (forced)" + + print(f" {closure_status} Contour {i+1}: {original_length} → {closed_length} points, gap: {closure_gap:.1f}px") + + print(f"✅ Found {len(closed_contours)} total contours (all ensured closed)") + return tuple(closed_contours), hierarchy diff --git a/sketchgetdp/bitmap_tracer/infrastructure/image_processing/image_loader_impl.py b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/image_loader_impl.py new file mode 100644 index 0000000..19441b2 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/image_processing/image_loader_impl.py @@ -0,0 +1,88 @@ +""" +Concrete implementation of ImageLoader using OpenCV. + +This implementation provides the actual image loading functionality +that the abstract interface defines. +""" + +import os +import cv2 +import numpy as np +from typing import Optional, Tuple +from interfaces.gateways.image_loader import ImageLoader + + +class OpenCVImageLoader(ImageLoader): + """Concrete image loader implementation using OpenCV library.""" + + SUPPORTED_FORMATS = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'} + + def load_image(self, image_path: str) -> Optional[np.ndarray]: + """ + Load image using OpenCV's imread function. + + Args: + image_path: Path to the image file + + Returns: + Image as numpy array in BGR format, or None if loading fails + """ + try: + if not self.validate_image_path(image_path): + return None + + # Load image in color mode (BGR format) + image = cv2.imread(image_path, cv2.IMREAD_COLOR) + + if image is None: + print(f"⚠️ OpenCV could not decode image: {image_path}") + return None + + print(f"✅ Loaded image: {image_path} - Shape: {image.shape}") + return image + + except Exception as error: + print(f"❌ Error loading image {image_path}: {error}") + return None + + def get_image_dimensions(self, image: np.ndarray) -> Tuple[int, int]: + """ + Extract width and height from image array. + + Args: + image: numpy array with shape (height, width, channels) + + Returns: + Tuple of (width, height) + """ + if not isinstance(image, np.ndarray) or image.ndim < 2: + raise ValueError("Invalid image array provided") + + height, width = image.shape[:2] + return width, height + + def validate_image_path(self, image_path: str) -> bool: + """ + Validate that the image file exists and has supported format. + + Args: + image_path: Path to validate + + Returns: + True if file is valid and accessible + """ + if not os.path.exists(image_path): + print(f"❌ Image file not found: {image_path}") + return False + + if not os.access(image_path, os.R_OK): + print(f"❌ Cannot read image file: {image_path}") + return False + + file_ext = os.path.splitext(image_path)[1].lower() + if file_ext not in self.SUPPORTED_FORMATS: + print(f"❌ Unsupported image format: {file_ext}") + return False + + return True + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/point_detection/__init__.py b/sketchgetdp/bitmap_tracer/infrastructure/point_detection/__init__.py new file mode 100644 index 0000000..428cc1f --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/point_detection/__init__.py @@ -0,0 +1,11 @@ +""" +Point detection infrastructure components. + +This module provides concrete implementations for point detection and curve fitting +operations that interact with external frameworks and libraries. +""" + +from .point_detector import PointDetector +from .curve_fitter import CurveFitter + +__all__ = ['PointDetector', 'CurveFitter'] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/point_detection/curve_fitter.py b/sketchgetdp/bitmap_tracer/infrastructure/point_detection/curve_fitter.py new file mode 100644 index 0000000..f7247a5 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/point_detection/curve_fitter.py @@ -0,0 +1,207 @@ +import cv2 +import numpy as np +from typing import Optional + + +class CurveFitter: + """ + Converts raster contours into smooth vector paths using adaptive fitting. + + This class implements a hybrid curve fitting approach that intelligently + switches between straight lines and curved segments based on local + contour geometry. It preserves sharp corners while smoothing gentle curves. + """ + + def __init__(self, angle_threshold: float = 25, min_curve_angle: float = 120): + """ + Initialize curve fitter with geometric thresholds. + + Args: + angle_threshold: Minimum angle (degrees) for curve segment classification + min_curve_angle: Minimum angle (degrees) for considering curve fitting + """ + self.angle_threshold = angle_threshold + self.min_curve_angle = min_curve_angle + + def fit_curve(self, contour: np.ndarray, epsilon_factor: float = 0.0015) -> Optional[str]: + """ + Convert contour to SVG path data using adaptive line/curve fitting. + + The algorithm: + 1. Simplifies contour to remove noise while preserving structure + 2. Ensures path closure for valid SVG rendering + 3. Analyzes angles between segments to determine fitting strategy + 4. Uses lines for sharp corners and quadratic bezier for gentle curves + + Args: + contour: OpenCV contour array to convert + epsilon_factor: Douglas-Peucker simplification tolerance factor + + Returns: + SVG path data string, or None if conversion fails + """ + if len(contour) < 3: + return None + + # Simplify contour to reduce noise while preserving important features + simplified_contour = self._simplify_contour(contour, epsilon_factor) + if simplified_contour is None: + return None + + points = [point[0] for point in simplified_contour] + + # Ensure path forms closed loop for valid SVG rendering + points, is_closed = self._ensure_closure(points) + + # Generate SVG path data using adaptive segment fitting + path_data = self._generate_svg_path(points, is_closed) + + return path_data + + def _simplify_contour(self, contour: np.ndarray, epsilon_factor: float) -> Optional[np.ndarray]: + """ + Apply contour simplification with length-adaptive tolerance. + + Args: + contour: Raw contour array to simplify + epsilon_factor: Tolerance factor relative to contour perimeter + + Returns: + Simplified contour array meeting minimum point requirements + """ + contour_length = cv2.arcLength(contour, True) + epsilon = epsilon_factor * contour_length + simplified_contour = cv2.approxPolyDP(contour, epsilon, True) + + return simplified_contour if len(simplified_contour) >= 3 else None + + def _ensure_closure(self, points: list) -> tuple: + """ + Verify and enforce contour closure for valid SVG path generation. + + Checks distance between start and end points. If beyond threshold, + appends start point to end to force closure. This ensures all + generated paths form complete, renderable shapes. + + Args: + points: List of contour points as [x, y] coordinates + + Returns: + Tuple of (closed_points, closure_status) + """ + start_point = points[0] + end_point = points[-1] + closure_distance = np.linalg.norm(np.array(start_point) - np.array(end_point)) + + closure_threshold = 10.0 # pixels + is_naturally_closed = closure_distance <= closure_threshold + + if not is_naturally_closed: + print(f" ⚠️ Simplified contour not closed, distance: {closure_distance:.2f}") + points.append(points[0]) + print(" 🔒 Forced closure on simplified points") + is_naturally_closed = True + + return points, is_naturally_closed + + def _generate_svg_path(self, points: list, is_closed: bool) -> str: + """ + Convert point sequence to SVG path data using adaptive fitting. + + Analyzes angles between consecutive segments to determine optimal + path commands. Sharp angles use straight lines, while gentle curves + use quadratic bezier segments for smooth rendering. + + Args: + points: List of contour points as [x, y] coordinates + is_closed: Boolean indicating if path forms closed loop + + Returns: + SVG path data string with move, line, and curve commands + """ + path_data = f"M {points[0][0]},{points[0][1]}" + point_count = len(points) + current_index = 1 + + while current_index < point_count: + current_point = points[current_index] + previous_point = points[current_index - 1] + next_point = points[(current_index + 1) % point_count] + + # Handle final segment connection for closed paths + if current_index == point_count - 1 and is_closed: + path_data += f" L {points[0][0]},{points[0][1]}" + break + + # Analyze segment geometry for curve fitting decisions + if self._should_use_curve_fitting(current_index, point_count, is_closed): + segment_angle = self._calculate_segment_angle(previous_point, current_point, next_point) + + if segment_angle is not None and segment_angle < self.angle_threshold: + # Sharp corner - use straight line segment + path_data += f" L {current_point[0]},{current_point[1]}" + current_index += 1 + else: + # Gentle curve - use quadratic bezier + path_data += f" Q {current_point[0]},{current_point[1]} {next_point[0]},{next_point[1]}" + current_index += 2 # Skip next point as it's used in curve + else: + # Default to straight line segment + path_data += f" L {current_point[0]},{current_point[1]}" + current_index += 1 + + # Ensure path termination for closed shapes + path_data += " Z" + print(f" {'✅' if is_closed else '⚠️'} Path closure: {is_closed}") + + return path_data + + def _should_use_curve_fitting(self, current_index: int, total_points: int, is_closed: bool) -> bool: + """ + Determine if current segment is suitable for curve analysis. + + Curve fitting requires sufficient surrounding points for + angle calculation. This prevents errors at path boundaries. + + Args: + current_index: Current position in point sequence + total_points: Total number of points in contour + is_closed: Whether path forms closed loop + + Returns: + True if segment can be evaluated for curve fitting + """ + return current_index < total_points - 1 or (is_closed and total_points > 3) + + def _calculate_segment_angle(self, previous_point: list, current_point: list, next_point: list) -> Optional[float]: + """ + Calculate angle between consecutive contour segments. + + Uses vector analysis to determine the turning angle at each + contour vertex. This angle guides the line vs curve decision. + + Args: + previous_point: Point before current vertex + current_point: Current vertex position + next_point: Point after current vertex + + Returns: + Angle in degrees between segments, or None if calculation fails + """ + vector_to_previous = np.array([previous_point[0] - current_point[0], + previous_point[1] - current_point[1]]) + vector_to_next = np.array([next_point[0] - current_point[0], + next_point[1] - current_point[1]]) + + previous_magnitude = np.linalg.norm(vector_to_previous) + next_magnitude = np.linalg.norm(vector_to_next) + + if previous_magnitude > 0 and next_magnitude > 0: + normalized_previous = vector_to_previous / previous_magnitude + normalized_next = vector_to_next / next_magnitude + + dot_product = np.clip(np.dot(normalized_previous, normalized_next), -1.0, 1.0) + angle_radians = np.arccos(dot_product) + return np.degrees(angle_radians) + + return None \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/infrastructure/point_detection/point_detector.py b/sketchgetdp/bitmap_tracer/infrastructure/point_detection/point_detector.py new file mode 100644 index 0000000..dacd971 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/infrastructure/point_detection/point_detector.py @@ -0,0 +1,112 @@ +import cv2 +import numpy as np +from typing import Optional +from core.entities.point import Point + + +class PointDetector: + """ + Detects point-like contours and extracts their geometric properties. + + A point is defined as a small, compact contour that represents a discrete + marker rather than a continuous path. This class encapsulates the logic + for identifying such contours and calculating their center points. + """ + + def __init__(self, max_area: int = 100, max_perimeter: int = 80): + """ + Initialize the point detector with size thresholds. + + Args: + max_area: Maximum contour area to be considered a point (pixels²) + max_perimeter: Maximum contour perimeter to be considered a point (pixels) + """ + self.max_area = max_area + self.max_perimeter = max_perimeter + + def set_config(self, config: dict): + """ + Update detection thresholds from configuration. + + Args: + config: Dictionary containing point_max_area and point_max_perimeter + """ + if config: + self.max_area = config.get('point_max_area', self.max_area) + self.max_perimeter = config.get('point_max_perimeter', self.max_perimeter) + print(f"🔧 PointDetector configured - max_area: {self.max_area}, max_perimeter: {self.max_perimeter}") + + def is_point(self, contour: np.ndarray) -> bool: + """ + Determine if a contour represents a point-like shape. + + Points are small, compact contours that meet both area and perimeter + criteria. This prevents large or elongated shapes from being misclassified. + + Args: + contour: OpenCV contour array to evaluate + + Returns: + True if contour meets point criteria, False otherwise + """ + if len(contour) < 3: + return False + + area = cv2.contourArea(contour) + perimeter = cv2.arcLength(contour, True) + + is_point = area < self.max_area and perimeter < self.max_perimeter + + if not is_point: + print(f" ❌ Point criteria failed: area {area:.1f} >= {self.max_area} OR perimeter {perimeter:.1f} >= {self.max_perimeter}") + + return is_point + + def get_center(self, contour: np.ndarray) -> Optional[Point]: + """ + Calculate the centroid of a contour using moment analysis. + + The centroid represents the geometric center of the contour shape. + This method uses OpenCV's moments calculation for accurate center detection. + + Args: + contour: OpenCV contour array to analyze + + Returns: + Point object representing the centroid, or None if calculation fails + """ + if len(contour) < 3: + return None + + moments = cv2.moments(contour) + if moments["m00"] != 0: + center_x = int(moments["m10"] / moments["m00"]) + center_y = int(moments["m01"] / moments["m00"]) + return Point(center_x, center_y) + + return None + + def detect_point(self, contour: np.ndarray) -> Optional[Point]: + """ + Complete point detection pipeline: identification and center calculation. + + This method combines contour evaluation and center calculation into + a single operation. It first verifies the contour meets point criteria, + then calculates and returns its center if valid. + + Args: + contour: OpenCV contour array to process + + Returns: + Point object for valid point contours, None for non-point contours + """ + if not self.is_point(contour): + return None + + center = self.get_center(contour) + if center: + area = cv2.contourArea(contour) + perimeter = cv2.arcLength(contour, True) + print(f" 📍 Point detected: area={area:.1f}, perimeter={perimeter:.1f}, center=({center.x}, {center.y})") + + return center diff --git a/sketchgetdp/bitmap_tracer/interfaces/__init__.py b/sketchgetdp/bitmap_tracer/interfaces/__init__.py new file mode 100644 index 0000000..a9ca6ba --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/__init__.py @@ -0,0 +1,31 @@ +""" +Interface Adapters Layer + +This layer contains adapters that convert data between the form most convenient +for external agencies (e.g., web, UI, devices) and the form most convenient +for use cases and entities. + +The interfaces layer depends on the enterprise business rules in the core layer, +but external agencies (like databases and web frameworks) depend on this layer. + +Components: +- Controllers: Handle input from external sources and convert it to use case input +- Presenters: Format output from use cases for external presentation +- Gateways: Interface with external resources while abstracting their implementation +""" + +from .controllers import * +from .presenters import * +from .gateways import * + +__all__ = [ + # Controllers + "TracingController", # Handles image tracing requests and coordinates use cases + + # Presenters + "SVGPresenter", # Formats tracing results as SVG documents + + # Gateways + "ImageLoader", # Abstracts image loading from various sources + "ConfigRepository", # Abstracts configuration storage and retrieval +] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/controllers/__init__.py b/sketchgetdp/bitmap_tracer/interfaces/controllers/__init__.py new file mode 100644 index 0000000..e50fef6 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/controllers/__init__.py @@ -0,0 +1,15 @@ +""" +Interface adapters that handle user input and coordinate use cases. + +Controllers are responsible for: +- Accepting input from the outside world +- Coordinating the execution of use cases +- Transforming data between external and internal representations +- Handling presentation concerns + +This package follows the Interface Adapter layer in Clean Architecture. +""" + +from .tracing_controller import TracingController + +__all__ = ["TracingController"] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/controllers/tracing_controller.py b/sketchgetdp/bitmap_tracer/interfaces/controllers/tracing_controller.py new file mode 100644 index 0000000..b787869 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/controllers/tracing_controller.py @@ -0,0 +1,366 @@ +""" +Coordinates the image tracing workflow from input image to SVG output. + +The TracingController is the primary interface for the bitmap tracing functionality. +It orchestrates the complete workflow while maintaining separation of concerns +between use cases, business rules, and external interfaces. + +Key Responsibilities: +- Validate and sanitize input parameters +- Coordinate execution of use cases in proper sequence +- Handle errors and transform them for presentation +- Provide status information about the system +- Maintain dependency inversion through constructor injection +""" + +import os +import sys +from typing import Optional, Dict, Any + +# Add the parent directory to Python path to allow absolute imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from infrastructure.image_processing.contour_detector import ContourDetector +from infrastructure.image_processing.color_analyzer import ColorAnalyzer +from infrastructure.point_detection.point_detector import PointDetector +from core.entities.color import Color +from core.use_cases.image_tracing import ImageTracingUseCase +from core.use_cases.structure_filtering import StructureFilteringUseCase +from interfaces.presenters.svg_presenter import SVGPresenter +from interfaces.gateways.image_loader import ImageLoader +from interfaces.gateways.config_repository import ConfigRepository + + +class TracingController: + """ + Primary controller for bitmap-to-vector image tracing operations. + + This controller follows the Single Responsibility Principle by focusing + solely on coordinating the tracing workflow. It delegates specific + responsibilities to specialized use cases and infrastructure components. + + Dependencies are injected to support testability and follow the + Dependency Inversion Principle. + """ + + def __init__(self, + config_repository: Optional[ConfigRepository] = None, + image_loader: Optional[ImageLoader] = None, + contour_detector: Optional[ContourDetector] = None, + color_analyzer: Optional[ColorAnalyzer] = None, + point_detector: Optional[PointDetector] = None): + """ + Initialize controller with dependencies. + + All dependencies are optional to allow for flexible testing and + default implementations. This follows the Null Object Pattern + for optional dependencies. + + Args: + config_repository: Repository for configuration data access + image_loader: Service for loading image data from filesystem + contour_detector: Detects contours in loaded images + color_analyzer: Analyzes and categorizes colors in contours + point_detector: Identifies point-like structures in contours + """ + # Import concrete implementations here to avoid circular imports + from infrastructure.configuration.config_loader import ConfigLoader + from infrastructure.image_processing.image_loader_impl import OpenCVImageLoader + + self.config_repository = config_repository or ConfigLoader() + self.image_loader = image_loader or OpenCVImageLoader() + self.contour_detector = contour_detector or ContourDetector() + self.color_analyzer = color_analyzer or ColorAnalyzer() + self.point_detector = point_detector or PointDetector() + + # Use cases encapsulate business rules and workflow logic + self.image_tracing_use_case = ImageTracingUseCase( + contour_detector=self.contour_detector, + color_analyzer=self.color_analyzer, + point_detector=self.point_detector + ) + + self.structure_filtering_use_case = StructureFilteringUseCase() + + def trace_image(self, + image_path: str, + output_svg_path: str = "output.svg", + config_path: Optional[str] = None) -> Dict[str, Any]: + """ + Execute complete bitmap-to-SVG tracing workflow. + + This is the main entry point for the tracing functionality. + The method coordinates the entire pipeline while maintaining + clean separation between concerns. + + Workflow Steps: + 1. Load configuration parameters + 2. Load and validate input image + 3. Detect and analyze contours with color categorization + 4. Filter structures based on configuration limits + 5. Generate SVG output from processed structures + + Args: + image_path: Filesystem path to source bitmap image + output_svg_path: Destination path for generated SVG file + config_path: Optional path to YAML configuration file + + Returns: + Dictionary containing: + - success: Boolean indicating overall operation success + - output_path: Path to generated SVG file (on success) + - statistics: Counts of different structure types processed + - metadata: Additional information about the operation + - error: Description of failure (when success is False) + """ + try: + print(f"⚡ Starting image tracing: {image_path}") + + # Step 1: Load configuration - business rules about structure limits + config = self._load_configuration(config_path) + if not config: + return self._create_error_response("Failed to load configuration") + + # Step 2: Load image data - external interface concern + image_data = self._load_image_data(image_path) + if not image_data: + return self._create_error_response(f"Could not load image: {image_path}") + + print(f"📐 Image size: {image_data['width']}x{image_data['height']}") + + # Step 3: Execute image tracing use case - core business logic + tracing_result = self._execute_tracing_use_case(image_data, config) + if not tracing_result.get('success', False): + return self._create_error_response("Image tracing failed") + + # Step 4: Filter structures based on configuration limits + filtered_structures = self._execute_filtering_use_case(tracing_result, config) + + # Step 5: Generate SVG output using SVGPresenter + svg_success = self._generate_svg_output(filtered_structures, image_data, output_svg_path) + if not svg_success: + return self._create_error_response("SVG generation failed") + + return self._create_success_response(output_svg_path, filtered_structures, config, image_data) + + except Exception as error: + # All exceptions are caught and transformed for consistent error handling + error_message = f"Unexpected error during tracing: {str(error)}" + print(f"❌ {error_message}") + return self._create_error_response(error_message) + + def _load_configuration(self, config_path: Optional[str]) -> Optional[Dict]: + """Load configuration from repository.""" + config = self.config_repository.load_config(config_path) + if config is None: + print("⚠️ Using default configuration due to loading failure") + return {} + return config + + def _load_image_data(self, image_path: str) -> Optional[Dict]: + """Load and validate image data with proper metadata.""" + try: + # Load the actual image array + image_array = self.image_loader.load_image(image_path) + if image_array is None: + return None + + # Get dimensions from the image array + width, height = self.image_loader.get_image_dimensions(image_array) + + # Create the dictionary structure with metadata + image_data = { + 'image_array': image_array, + 'image_path': image_path, + 'width': width, + 'height': height, + 'channels': image_array.shape[2] if len(image_array.shape) > 2 else 1 + } + + print(f"📐 Image size: {width}x{height}") + return image_data + + except Exception as error: + print(f"❌ Error loading image data: {error}") + return None + + def _execute_tracing_use_case(self, image_data: Dict, config: Dict) -> Dict[str, Any]: + """Execute the image tracing use case with provided data.""" + return self.image_tracing_use_case.execute( + image_data=image_data, + config=config + ) + + def _execute_filtering_use_case(self, tracing_result: Dict, config: Dict) -> Dict[str, Any]: + """Execute structure filtering based on configuration limits.""" + return self.structure_filtering_use_case.execute( + structures=tracing_result['structures'], + config=config + ) + + def _generate_svg_output(self, structures: Dict, image_data: Dict, output_path: str) -> bool: + """ + Generate SVG file from processed structures using SVGPresenter. + + Args: + structures: Filtered structures to render + image_data: Source image dimensions and metadata + output_path: Destination path for SVG file + + Returns: + True if SVG was generated successfully, False otherwise + """ + try: + # Create SVGPresenter with the image dimensions + presenter = SVGPresenter( + output_path=output_path, + width=image_data['width'], + height=image_data['height'] + ) + + # Add all structures to SVG + self._add_structures_to_svg(presenter, structures) + + # Save the SVG file + success = presenter.save() + + if success: + print(f"✅ SVG successfully generated: {output_path}") + else: + print(f"❌ Failed to save SVG: {output_path}") + + return success + + except Exception as error: + print(f"❌ SVG generation error: {error}") + return False + + def _add_structures_to_svg(self, presenter: SVGPresenter, structures: Dict) -> None: + """ + Add all structures to the SVG presenter. + """ + # Import Color class for conversion + from core.entities.color import Color + + # Add red points + red_points = structures.get('red_points', []) + for point in red_points: + # Convert ColorCategory.RED to Color object + red_color = Color.from_hex("#FF0000") + presenter.add_point(point, red_color) + + # Add blue paths + blue_structures = structures.get('blue_structures', []) + for structure in blue_structures: + # Handle both raw contours and processed structures + if isinstance(structure, dict) and 'contour' in structure: + contour = structure['contour'] + path_data = structure.get('path_data') + else: + contour = structure + path_data = None + + # Convert ColorCategory.BLUE to Color object + blue_color = Color.from_hex("#0000FF") + + if path_data: + # Use the processed path data + presenter.add_path(path_data, blue_color) + else: + # Fallback to contour conversion + presenter.add_contour_as_path(contour, blue_color) + + # Add green paths + green_structures = structures.get('green_structures', []) + for structure in green_structures: + # Handle both raw contours and processed structures + if isinstance(structure, dict) and 'contour' in structure: + contour = structure['contour'] + path_data = structure.get('path_data') + else: + contour = structure + path_data = None + + # Convert ColorCategory.GREEN to Color object + green_color = Color.from_hex("#00FF00") + + if path_data: + # Use the processed path data + presenter.add_path(path_data, green_color) + else: + # Fallback to contour conversion + presenter.add_contour_as_path(contour, green_color) + + # Log structure counts for debugging + print(f"📊 Structures to render: {len(red_points)} red points, " + f"{len(blue_structures)} blue paths, {len(green_structures)} green paths") + + def _create_success_response(self, + output_path: str, + structures: Dict, + config: Dict, + image_data: Dict) -> Dict[str, Any]: + """ + Create standardized success response. + + This method ensures consistent response structure across all + successful operations, making it easier for clients to parse results. + + Args: + output_path: Path to generated SVG file + structures: Filtered structures that were rendered + config: Configuration used for processing + image_data: Source image metadata + + Returns: + Standardized success response dictionary + """ + return { + 'success': True, + 'output_path': output_path, + 'statistics': { + 'red_points': len(structures.get('red_points', [])), + 'blue_paths': len(structures.get('blue_structures', [])), + 'green_paths': len(structures.get('green_structures', [])), + 'total_structures': ( + len(structures.get('red_points', [])) + + len(structures.get('blue_structures', [])) + + len(structures.get('green_structures', [])) + ) + }, + 'metadata': { + 'image_size': f"{image_data['width']}x{image_data['height']}", + 'config_limits': { + 'red_dots': config.get('red_dots', 0), + 'blue_paths': config.get('blue_paths', 0), + 'green_paths': config.get('green_paths', 0) + } + } + } + + def _create_error_response(self, error_message: str) -> Dict[str, Any]: + """ + Create standardized error response. + + All errors follow the same structure, making error handling + predictable for clients. This follows the Consistent Error + Handling principle. + + Args: + error_message: Description of what went wrong + + Returns: + Standardized error response dictionary + """ + return { + 'success': False, + 'error': error_message, + 'statistics': { + 'red_points': 0, + 'blue_paths': 0, + 'green_paths': 0, + 'total_structures': 0 + }, + 'metadata': {} + } + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/gateways/__init__.py b/sketchgetdp/bitmap_tracer/interfaces/gateways/__init__.py new file mode 100644 index 0000000..48a87cb --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/gateways/__init__.py @@ -0,0 +1,15 @@ +""" +Gateways Package + +Exports abstract gateway interfaces that define the boundaries between +the application core and external infrastructure. These abstractions +enable testability and flexibility in choosing implementations. +""" + +from .image_loader import ImageLoader +from .config_repository import ConfigRepository + +__all__ = [ + "ImageLoader", + "ConfigRepository", +] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/gateways/config_repository.py b/sketchgetdp/bitmap_tracer/interfaces/gateways/config_repository.py new file mode 100644 index 0000000..e4466bc --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/gateways/config_repository.py @@ -0,0 +1,73 @@ +""" +Configuration Repository Gateway Interface + +Defines the abstraction for configuration management operations that infrastructure +components must implement. This interface centralizes all configuration access +patterns behind a consistent abstraction. +""" + +from abc import ABC, abstractmethod +from typing import Tuple, Any, Dict, Optional + + +class ConfigRepository(ABC): + """Contracts for managing application configuration state and defaults.""" + + @abstractmethod + def load_config(self, config_path: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Load and parse configuration from persistent storage. + + Implementations should handle YAML parsing, schema validation, + and setting appropriate defaults for missing values. + + Args: + config_path: Optional path to configuration file in YAML format + + Returns: + Dictionary containing configuration data, or None if loading fails + """ + pass + + @abstractmethod + def get_structure_limits(self) -> Tuple[int, int, int]: + """ + Retrieve the maximum number of structures to process for each color category. + + These limits control the filtering behavior during image tracing, + ensuring only the most significant structures are processed. + + Returns: + Tuple of (red_dots_limit, blue_paths_limit, green_paths_limit) + where each limit represents the maximum count for that color category + """ + pass + + @abstractmethod + def get_contour_detection_params(self) -> Dict[str, Any]: + """ + Retrieve parameters for contour detection and filtering. + + These parameters control how contours are detected and filtered + during image processing. + + Returns: + Dictionary containing contour detection parameters such as + maximum area and perimeter thresholds + """ + pass + + @abstractmethod + def get_color_detection_params(self) -> Dict[str, Any]: + """ + Retrieve parameters for color categorization in HSV space. + + These parameters define the hue ranges and thresholds for + identifying different colors in the image. + + Returns: + Dictionary containing color detection parameters including + hue ranges for red, blue, green, and saturation/value thresholds + """ + pass + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/gateways/image_loader.py b/sketchgetdp/bitmap_tracer/interfaces/gateways/image_loader.py new file mode 100644 index 0000000..c54e86b --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/gateways/image_loader.py @@ -0,0 +1,72 @@ +""" +Image Loader Gateway Interface + +Defines the abstraction for image loading operations that infrastructure components +must implement. This interface follows the Dependency Inversion Principle, allowing +high-level modules to depend on abstractions rather than concrete implementations. +""" + +from abc import ABC, abstractmethod +from typing import Optional, Tuple +import numpy as np + + +class ImageLoader(ABC): + """Contracts for loading and validating image data from various sources.""" + + @abstractmethod + def load_image(self, image_path: str) -> Optional[np.ndarray]: + """ + Load image data from the filesystem into a processable format. + + Implementations should handle file format decoding, color space conversion, + and memory allocation for the image data. + + Args: + image_path: Absolute or relative path to the image file + + Returns: + Image data as numpy array with shape (height, width, channels), + or None when file cannot be loaded + + Raises: + FileNotFoundError: When image_path does not exist + PermissionError: When image_path cannot be read + ValueError: When file contains invalid image data + """ + pass + + @abstractmethod + def get_image_dimensions(self, image: np.ndarray) -> Tuple[int, int]: + """ + Extract width and height from loaded image data. + + This method provides a consistent way to access image dimensions + regardless of the underlying image representation. + + Args: + image: Valid image data as returned by load_image() + + Returns: + Tuple containing (width, height) in pixels + + Raises: + ValueError: When image parameter is not a valid image array + """ + pass + + @abstractmethod + def validate_image_path(self, image_path: str) -> bool: + """ + Verify that an image file exists and is accessible before loading. + + This pre-validation prevents unnecessary processing attempts + on non-existent or inaccessible files. + + Args: + image_path: Path to verify + + Returns: + True if file exists, is readable, and has supported image extension + """ + pass \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/presenters/__init__.py b/sketchgetdp/bitmap_tracer/interfaces/presenters/__init__.py new file mode 100644 index 0000000..d3f7c18 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/presenters/__init__.py @@ -0,0 +1,8 @@ +""" +Presenters for formatting and presenting tracing results. +Presenters convert application data into specific output formats like SVG. +""" + +from .svg_presenter import SVGPresenter + +__all__ = ["SVGPresenter"] \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/interfaces/presenters/svg_presenter.py b/sketchgetdp/bitmap_tracer/interfaces/presenters/svg_presenter.py new file mode 100644 index 0000000..22c04f3 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/interfaces/presenters/svg_presenter.py @@ -0,0 +1,197 @@ +""" +SVG format presenter for bitmap tracing results. +Converts contours and points into SVG vector graphics elements. +""" + +from svgwrite import Drawing +from typing import List +from core.entities.contour import Contour +from core.entities.point import Point +from core.entities.color import Color +from core.entities.color import ColorCategory + + +class SVGPresenter: + """Converts traced shapes and points into SVG vector graphics.""" + + def __init__(self, output_path: str, width: int, height: int): + """Initializes SVG presenter with output specifications. + + Args: + output_path: File path for SVG output + width: Canvas width in pixels + height: Canvas height in pixels + """ + self.output_path = output_path + self.width = width + self.height = height + self.dwg = Drawing(output_path, size=(width, height)) + self._initialize_element_counters() + + def _initialize_element_counters(self) -> None: + """Sets up counters for tracking different SVG element types.""" + self.elements_count = { + 'points': 0, + 'paths': 0, + 'blue_paths': 0, + 'green_paths': 0, + 'red_points': 0 + } + + def add_point(self, point: Point, color: Color, radius: int = 4) -> None: + """Adds a point marker as SVG circle element. + + Red points are filled circles, other colors use standard styling. + + Args: + point: Point coordinates to render + color: Color classification for styling + radius: Circle radius in pixels + """ + category, _ = color.categorize() + if category == ColorCategory.RED: + fill_color = "#FF0000" + self.elements_count['red_points'] += 1 + else: + fill_color = color.to_hex() + + self.dwg.add(self.dwg.circle( + center=(point.x, point.y), + r=radius, + fill=fill_color, + stroke="none" + )) + self.elements_count['points'] += 1 + + def add_path(self, path_data: str, color: Color, stroke_width: int = 2) -> None: + """Adds SVG path element with specified color styling. + + Args: + path_data: SVG path commands string + color: Determines stroke color (blue/green) + stroke_width: Path line thickness + """ + stroke_color = self._get_path_stroke_color(color) + self._increment_path_counter(color) + + self.dwg.add(self.dwg.path( + d=path_data, + fill="none", + stroke=stroke_color, + stroke_width=stroke_width, + stroke_linecap="round", + stroke_linejoin="round" + )) + self.elements_count['paths'] += 1 + + def _get_path_stroke_color(self, color: Color) -> str: + """Determines SVG stroke color from color classification. + + Args: + color: Color classification + + Returns: + Hex color code for SVG stroke + """ + category, hex_color = color.categorize() + if category == ColorCategory.BLUE: + return "#0000FF" + elif category == ColorCategory.GREEN: + return "#00FF00" + elif category == ColorCategory.RED: + return "#FF0000" + return color.to_hex() + + def _increment_path_counter(self, color: Color) -> None: + """Updates path counters based on color type. + + Args: + color: Color classification for counter selection + """ + category, _ = color.categorize() + if category == ColorCategory.BLUE: + self.elements_count['blue_paths'] += 1 + elif category == ColorCategory.GREEN: + self.elements_count['green_paths'] += 1 + + def add_contour_as_path(self, contour: Contour, color: Color, stroke_width: int = 2) -> None: + """Converts contour to SVG path and adds to drawing. + + Args: + contour: Shape contour to convert + color: Path stroke color + stroke_width: Line thickness + """ + if contour.is_empty(): + return + + path_data = self._convert_contour_to_path_data(contour) + self.add_path(path_data, color, stroke_width) + + def _convert_contour_to_path_data(self, contour: Contour) -> str: + """Generates SVG path data from contour points. + + Args: + contour: Contains ordered points defining shape boundary + + Returns: + SVG path data string with move-to and line-to commands + """ + if len(contour.points) < 1: + return "" + + path_commands = self._build_path_commands_from_contour(contour) + return " ".join(path_commands) + + def _build_path_commands_from_contour(self, contour: Contour) -> List[str]: + """Constructs SVG path commands from contour point sequence. + + Args: + contour: Ordered points defining shape + + Returns: + List of SVG path commands + """ + first_point = contour.points[0] + commands = [f"M {first_point.x},{first_point.y}"] + + for point in contour.points[1:]: + commands.append(f"L {point.x},{point.y}") + + if contour.is_closed and len(contour.points) > 2: + commands.append("Z") + + return commands + + def save(self) -> bool: + """Saves SVG file to disk and prints creation summary. + + Returns: + True if save successful, False on error + """ + try: + self.dwg.save() + self._report_save_success() + return True + except Exception as error: + self._report_save_error(error) + return False + + def _report_save_success(self) -> None: + """Prints success message and element summary.""" + print(f"✅ SVG saved: {self.output_path}") + self._print_creation_summary() + + def _report_save_error(self, error: Exception) -> None: + """Prints error message when save fails.""" + print(f"❌ Error saving SVG: {error}") + + def _print_creation_summary(self) -> None: + """Outputs formatted summary of created SVG elements.""" + print(f"🎨 SVG Creation Summary:") + print(f" Canvas size: {self.width}x{self.height}") + print(f" Total paths: {self.elements_count['paths']}") + print(f" - Blue paths: {self.elements_count['blue_paths']}") + print(f" - Green paths: {self.elements_count['green_paths']}") + print(f" Total points: {self.elements_count['points']}") + print(f" - Red points: {self.elements_count['red_points']}") diff --git a/sketchgetdp/bitmap_tracer/pytest.ini b/sketchgetdp/bitmap_tracer/pytest.ini new file mode 100644 index 0000000..decb2e8 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +pythonpath = . +testpaths = tests \ No newline at end of file diff --git a/sketchgetdp/solver/__init__.py b/sketchgetdp/bitmap_tracer/tests/__init__.py similarity index 100% rename from sketchgetdp/solver/__init__.py rename to sketchgetdp/bitmap_tracer/tests/__init__.py diff --git a/sketchgetdp/bitmap_tracer/tests/core/__init__.py b/sketchgetdp/bitmap_tracer/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/bitmap_tracer/tests/core/entities/__init__.py b/sketchgetdp/bitmap_tracer/tests/core/entities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/bitmap_tracer/tests/core/entities/test_color.py b/sketchgetdp/bitmap_tracer/tests/core/entities/test_color.py new file mode 100644 index 0000000..d681383 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/core/entities/test_color.py @@ -0,0 +1,220 @@ +import pytest +import sys +import os + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from core.entities.color import Color, ColorCategory + + +class TestColor: + """Comprehensive tests for Color entity including categorization logic.""" + + @pytest.mark.parametrize("b,g,r,expected_blue", [ + (200, 100, 100, True), # Blue dominant + (180, 150, 150, True), # Blue dominant + (100, 200, 100, False), # Green dominant + (100, 100, 200, False), # Red dominant + (150, 150, 150, False), # Equal - no dominance + ]) + def test_identifies_blue_dominant_colors(self, b, g, r, expected_blue): + color = Color(b=b, g=g, r=r) + is_blue_dominant = (color.b > color.g + 20 and color.b > color.r + 20) + assert is_blue_dominant == expected_blue + + @pytest.mark.parametrize("b,g,r,expected_red", [ + (100, 100, 200, True), # Red dominant + (150, 150, 180, True), # Red dominant + (200, 100, 100, False), # Blue dominant + (100, 200, 100, False), # Green dominant + (150, 150, 150, False), # Equal - no dominance + ]) + def test_identifies_red_dominant_colors(self, b, g, r, expected_red): + color = Color(b=b, g=g, r=r) + is_red_dominant = (color.r > color.g + 20 and color.r > color.b + 20) + assert is_red_dominant == expected_red + + @pytest.mark.parametrize("b,g,r,expected_green", [ + (100, 200, 100, True), # Green dominant + (150, 180, 150, True), # Green dominant + (200, 100, 100, False), # Blue dominant + (100, 100, 200, False), # Red dominant + (150, 150, 150, False), # Equal - no dominance + ]) + def test_identifies_green_dominant_colors(self, b, g, r, expected_green): + color = Color(b=b, g=g, r=r) + is_green_dominant = (color.g > color.r + 20 and color.g > color.b + 20) + assert is_green_dominant == expected_green + + def test_converts_between_color_formats(self): + color = Color.from_bgr_tuple((100, 150, 200)) + assert color.b == 100 + assert color.g == 150 + assert color.r == 200 + assert color.to_bgr_tuple() == (100, 150, 200) + assert color.to_rgb_tuple() == (200, 150, 100) + assert color.to_hex() == "#C89664" + + def test_immutable_dataclass(self): + """Test that Color is immutable (frozen dataclass).""" + color = Color(b=100, g=150, r=200) + with pytest.raises(Exception): + color.b = 50 + + @pytest.mark.parametrize("hex_input,expected_rgb", [ + ("#FF8040", (0xFF, 0x80, 0x40)), + ("#F84", (0xFF, 0x88, 0x44)), + ("#0000FF", (0x00, 0x00, 0xFF)), + ("#00FF00", (0x00, 0xFF, 0x00)), + ("#FF0000", (0xFF, 0x00, 0x00)), + ("ffffff", (0xFF, 0xFF, 0xFF)), # Without # prefix + ("#fff", (0xFF, 0xFF, 0xFF)), # Short form + ]) + def test_parses_hex_codes_correctly(self, hex_input, expected_rgb): + color = Color.from_hex(hex_input) + assert color.r == expected_rgb[0] + assert color.g == expected_rgb[1] + assert color.b == expected_rgb[2] + + def test_maps_primary_categories_to_hex_values(self): + assert Color.CATEGORY_HEX_COLORS[ColorCategory.BLUE] == "#0000FF" + assert Color.CATEGORY_HEX_COLORS[ColorCategory.RED] == "#FF0000" + assert Color.CATEGORY_HEX_COLORS[ColorCategory.GREEN] == "#00FF00" + + # Only primary colors should have hex mappings + assert ColorCategory.WHITE not in Color.CATEGORY_HEX_COLORS + assert ColorCategory.BLACK not in Color.CATEGORY_HEX_COLORS + assert ColorCategory.OTHER not in Color.CATEGORY_HEX_COLORS + + @pytest.mark.parametrize("bgr_tuple,expected_primary", [ + ((200, 100, 100), True), # Blue + ((100, 200, 100), True), # Green + ((100, 100, 200), True), # Red + ((255, 255, 255), False), # White + ((0, 0, 0), False), # Black + ((150, 150, 150), False), # Gray + ]) + def test_detects_primary_colors_using_mocked_categorization(self, bgr_tuple, expected_primary): + color = Color.from_bgr_tuple(bgr_tuple) + + if bgr_tuple == (200, 100, 100): + mock_return = (ColorCategory.BLUE, "#0000FF") + elif bgr_tuple == (100, 200, 100): + mock_return = (ColorCategory.GREEN, "#00FF00") + elif bgr_tuple == (100, 100, 200): + mock_return = (ColorCategory.RED, "#FF0000") + else: + mock_return = (ColorCategory.OTHER, None) + + def mock_categorize(self): + return mock_return + + original_categorize = Color.categorize + Color.categorize = mock_categorize + + try: + is_primary = color.is_primary_color() + assert is_primary == expected_primary + finally: + Color.categorize = original_categorize + + @pytest.mark.parametrize("bgr_tuple,expected_ignored", [ + ((255, 255, 255), True), # White + ((0, 0, 0), True), # Black + ((150, 150, 150), True), # Gray + ((200, 100, 100), False), # Blue + ((100, 200, 100), False), # Green + ((100, 100, 200), False), # Red + ]) + def test_detects_ignored_colors_using_mocked_categorization(self, bgr_tuple, expected_ignored): + color = Color.from_bgr_tuple(bgr_tuple) + + if bgr_tuple == (255, 255, 255): + mock_return = (ColorCategory.WHITE, None) + elif bgr_tuple == (0, 0, 0): + mock_return = (ColorCategory.BLACK, None) + elif bgr_tuple == (150, 150, 150): + mock_return = (ColorCategory.OTHER, None) + elif bgr_tuple == (200, 100, 100): + mock_return = (ColorCategory.BLUE, "#0000FF") + elif bgr_tuple == (100, 200, 100): + mock_return = (ColorCategory.GREEN, "#00FF00") + else: + mock_return = (ColorCategory.RED, "#FF0000") + + def mock_categorize(self): + return mock_return + + original_categorize = Color.categorize + Color.categorize = mock_categorize + + try: + is_ignored = color.is_ignored_color() + assert is_ignored == expected_ignored + finally: + Color.categorize = original_categorize + + def test_constructors_equivalence(self): + """Test that different constructors produce equivalent results.""" + bgr_color = Color.from_bgr_tuple((100, 150, 200)) + rgb_color = Color.from_rgb_tuple((200, 150, 100)) + hex_color = Color.from_hex("#C89664") + + assert bgr_color == rgb_color + assert bgr_color == hex_color + assert bgr_color.to_hex() == "#C89664" + + @pytest.mark.parametrize("invalid_hex", [ + "invalid", + "#", + "#12", + "#12345", + "#GGGGGG", + ]) + def test_invalid_hex_codes(self, invalid_hex): + """Test that invalid hex codes raise appropriate exceptions.""" + try: + color = Color.from_hex(invalid_hex) + assert isinstance(color, Color) + assert 0 <= color.r <= 255 + assert 0 <= color.g <= 255 + assert 0 <= color.b <= 255 + except (ValueError, IndexError): + pass + + def test_hex_parsing_behavior(self): + """Specifically test the behavior with problematic hex codes.""" + color = Color.from_hex("#12345") + + print(f"#12345 parsed as: r={color.r}, g={color.g}, b={color.b}") + + def test_categorize_integration(self): + """Integration test for actual categorization logic with OpenCV.""" + blue_color = Color(b=255, g=0, r=0) + red_color = Color(b=0, g=0, r=255) + green_color = Color(b=0, g=255, r=0) + + blue_category, blue_hex = blue_color.categorize() + red_category, red_hex = red_color.categorize() + green_category, green_hex = green_color.categorize() + + assert blue_category == ColorCategory.BLUE + assert blue_hex == "#0000FF" + assert red_category == ColorCategory.RED + assert red_hex == "#FF0000" + assert green_category == ColorCategory.GREEN + assert green_hex == "#00FF00" + + def test_white_and_black_categorization(self): + """Test categorization of white and black colors.""" + white_color = Color(b=255, g=255, r=255) + black_color = Color(b=0, g=0, r=0) + + white_category, white_hex = white_color.categorize() + black_category, black_hex = black_color.categorize() + + assert white_category == ColorCategory.WHITE + assert white_hex is None + assert black_category == ColorCategory.BLACK + assert black_hex is None \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/tests/core/entities/test_contour.py b/sketchgetdp/bitmap_tracer/tests/core/entities/test_contour.py new file mode 100644 index 0000000..3d186be --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/core/entities/test_contour.py @@ -0,0 +1,225 @@ +import pytest +import numpy as np +import sys +import os + +# Add the root project directory to Python path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../..')) + +from core.entities.point import Point +from core.entities.contour import Contour + + +class TestContour: + """Unit tests for Contour entity.""" + + @pytest.fixture + def square_points(self): + return [Point(0, 0), Point(2, 0), Point(2, 2), Point(0, 2)] + + @pytest.fixture + def triangle_points(self): + # Right triangle: base=3, height=4 + return [Point(0, 0), Point(3, 0), Point(3, 4)] + + @pytest.fixture + def empty_contour(self): + return Contour(points=[], is_closed=True, closure_gap=0.0) + + def test_initialization(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.5) + + assert contour.points == square_points + assert contour.is_closed is True + assert contour.closure_gap == 0.5 + + def test_area_triangle(self, triangle_points): + contour = Contour(points=triangle_points, is_closed=True, closure_gap=0.0) + + # 3*4/2 = 6.0 + expected_area = 6.0 + assert contour.area == expected_area + + def test_area_square(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.0) + + assert contour.area == 4.0 + + @pytest.mark.parametrize("points,expected_area", [ + ([], 0.0), + ([Point(0, 0)], 0.0), + ([Point(0, 0), Point(1, 1)], 0.0), + ]) + def test_area_insufficient_points(self, points, expected_area): + contour = Contour(points=points, is_closed=True, closure_gap=0.0) + assert contour.area == expected_area + + def test_perimeter_square(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.0) + + assert contour.perimeter == 8.0 + + @pytest.mark.parametrize("points,expected_perimeter", [ + ([], 0.0), + ([Point(0, 0)], 0.0), + ]) + def test_perimeter_insufficient_points(self, points, expected_perimeter): + contour = Contour(points=points, is_closed=True, closure_gap=0.0) + assert contour.perimeter == expected_perimeter + + def test_circularity_perfect_circle_approximation(self): + # Create a rough circle approximation to test circularity calculation + points = [] + radius = 5.0 + num_points = 36 + + for i in range(num_points): + angle = 2 * np.pi * i / num_points + x = radius * np.cos(angle) + y = radius * np.sin(angle) + points.append(Point(x, y)) + + contour = Contour(points=points, is_closed=True, closure_gap=0.0) + + # Should be close to 1.0 for a circle + assert 0.9 < contour.circularity < 1.1 + + def test_circularity_square(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.0) + + # 4πA/P² = 4π*4/64 = π/4 ≈ 0.785 + expected_circularity = np.pi / 4 + assert contour.circularity == pytest.approx(expected_circularity, abs=0.01) + + def test_circularity_zero_perimeter(self, empty_contour): + assert empty_contour.circularity == 0.0 + + def test_get_center(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.0) + + # Centroid of square from (0,0) to (2,2) is at (1.0, 1.0) + assert contour.get_center() == Point(1.0, 1.0) + + def test_get_center_empty_contour(self, empty_contour): + assert empty_contour.get_center() is None + + def test_from_numpy_contour_empty(self): + result = Contour.from_numpy_contour(np.array([])) + + assert result.points == [] + assert result.is_closed is True + assert result.closure_gap == 0.0 + + def test_from_numpy_contour_single_point(self): + result = Contour.from_numpy_contour(np.array([[[0, 0]]])) + + assert len(result.points) == 1 + assert result.points[0] == Point(0, 0) + assert result.is_closed is False + assert result.closure_gap == 0.0 + + def test_from_numpy_contour_closed_shape(self): + # Closing point matches start - should be detected as closed + triangle_contour = np.array([[[0, 0]], [[4, 0]], [[0, 3]], [[0, 0]]]) + + result = Contour.from_numpy_contour(triangle_contour, tolerance=1.0) + + assert len(result.points) == 4 + assert result.is_closed is True + assert result.closure_gap == 0.0 + + def test_from_numpy_contour_open_shape(self): + # Ends far from start point - should be detected as open + open_contour = np.array([[[0, 0]], [[4, 0]], [[4, 3]], [[8, 3]]]) + + result = Contour.from_numpy_contour(open_contour, tolerance=1.0) + + assert len(result.points) == 4 + assert result.is_closed is False + assert result.closure_gap > 1.0 + + @pytest.mark.parametrize("tolerance,expected_closed", [ + (0.05, False), # Strict tolerance - not closed + (1.0, True), # Lenient tolerance - closed + ]) + def test_from_numpy_contour_tolerance(self, tolerance, expected_closed): + # Slightly off from start - tolerance affects closure detection + almost_closed_contour = np.array([ + [[0, 0]], [[2, 0]], [[2, 2]], [[0, 2]], [[0.1, 0.1]] + ]) + + result = Contour.from_numpy_contour(almost_closed_contour, tolerance=tolerance) + assert result.is_closed is expected_closed + + def test_property_consistency(self, triangle_points): + contour = Contour(points=triangle_points, is_closed=True, closure_gap=0.0) + + # For triangle with points (0,0), (3,0), (3,4) + expected_area = 6.0 # 3*4/2 + expected_perimeter = 12.0 # 3 + 4 + 5 + expected_circularity = (4 * np.pi * expected_area) / (expected_perimeter ** 2) + expected_center = Point(2.0, 1.3333333333333333) # (0+3+3)/3, (0+0+4)/3 + + assert contour.area == expected_area + assert contour.perimeter == expected_perimeter + assert contour.circularity == pytest.approx(expected_circularity, abs=0.001) + assert contour.get_center() == expected_center + + def test_immutability_of_points(self): + """Test that external changes to points list don't affect the contour.""" + original_points = [Point(0, 0), Point(1, 0), Point(1, 1)] + contour = Contour(points=original_points, is_closed=True, closure_gap=0.0) + + original_point_count = len(contour.points) + original_area = contour.area + + # Modify external list - contour should be unaffected + original_points.append(Point(0, 1)) + + assert len(contour.points) == original_point_count + assert contour.area == original_area + + # New contour with modified list should be different + new_contour = Contour(points=original_points, is_closed=True, closure_gap=0.0) + assert len(new_contour.points) == 4 + assert new_contour.area != original_area + + @pytest.mark.parametrize("points,expected_closed,expected_gap", [ + ([Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1)], True, 0.0), + ([Point(0, 0), Point(1, 0), Point(1, 1)], False, 0.0), + ]) + def test_closure_properties(self, points, expected_closed, expected_gap): + contour = Contour(points=points, is_closed=expected_closed, closure_gap=expected_gap) + + assert contour.is_closed == expected_closed + assert contour.closure_gap == expected_gap + + def test_is_empty(self): + empty_contour = Contour(points=[], is_closed=True, closure_gap=0.0) + non_empty_contour = Contour(points=[Point(0, 0)], is_closed=False, closure_gap=0.0) + + assert empty_contour.is_empty() is True + assert non_empty_contour.is_empty() is False + + def test_get_bounding_box(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.0) + bbox = contour.get_bounding_box() + + assert bbox == (0.0, 0.0, 2.0, 2.0) + + def test_get_bounding_box_empty(self, empty_contour): + assert empty_contour.get_bounding_box() is None + + def test_len(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.0) + assert len(contour) == 4 + + def test_repr(self, square_points): + contour = Contour(points=square_points, is_closed=True, closure_gap=0.5) + repr_str = repr(contour) + + assert "Contour" in repr_str + assert "points=4" in repr_str + assert "CLOSED" in repr_str + assert "area=4.0" in repr_str + assert "gap=0.50" in repr_str \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/tests/core/entities/test_point.py b/sketchgetdp/bitmap_tracer/tests/core/entities/test_point.py new file mode 100644 index 0000000..87f27cf --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/core/entities/test_point.py @@ -0,0 +1,238 @@ +import pytest +import math +from typing import Tuple +import sys +import os + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from bitmap_tracer.core.entities.point import Point, PointData + +class TestPoint: + """Test suite for Point value object""" + + def test_point_creation(self): + """Test basic Point creation with x and y coordinates""" + point = Point(3.5, 7.2) + assert point.x == 3.5 + assert point.y == 7.2 + + def test_point_immutability(self): + """Test that Point objects are immutable (dataclass frozen behavior)""" + point = Point(1.0, 2.0) + + # Verify attributes cannot be modified directly + with pytest.raises(AttributeError): + point.x = 5.0 + with pytest.raises(AttributeError): + point.y = 5.0 + + def test_to_tuple(self): + """Test conversion to tuple format""" + point = Point(3.14, 2.71) + result = point.to_tuple() + + assert isinstance(result, Tuple) + assert result == (3.14, 2.71) + assert result[0] == 3.14 + assert result[1] == 2.71 + + def test_distance_to_same_point(self): + """Test distance calculation to the same point""" + point1 = Point(5.0, 5.0) + point2 = Point(5.0, 5.0) + + distance = point1.distance_to(point2) + assert distance == 0.0 + + def test_distance_to_different_points(self): + """Test distance calculation to different points""" + point1 = Point(0.0, 0.0) + point2 = Point(3.0, 4.0) # 3-4-5 triangle + + distance = point1.distance_to(point2) + assert distance == 5.0 + + def test_distance_to_negative_coordinates(self): + """Test distance calculation with negative coordinates""" + point1 = Point(-1.0, -1.0) + point2 = Point(2.0, 3.0) + + distance = point1.distance_to(point2) + expected_distance = math.sqrt((3.0 ** 2) + (4.0 ** 2)) # 5.0 + assert distance == expected_distance + + def test_from_tuple_creation(self): + """Test factory method creating Point from tuple""" + input_tuple = (10.5, 20.7) + point = Point.from_tuple(input_tuple) + + assert point.x == 10.5 + assert point.y == 20.7 + assert isinstance(point, Point) + + def test_from_tuple_with_negative_values(self): + """Test factory method with negative tuple values""" + input_tuple = (-5.5, -10.2) + point = Point.from_tuple(input_tuple) + + assert point.x == -5.5 + assert point.y == -10.2 + + def test_equality_comparison(self): + """Test that Points with same coordinates are equal""" + point1 = Point(1.0, 2.0) + point2 = Point(1.0, 2.0) + + assert point1 == point2 + + def test_inequality_comparison(self): + """Test that Points with different coordinates are not equal""" + point1 = Point(1.0, 2.0) + point2 = Point(1.0, 3.0) + point3 = Point(2.0, 2.0) + + assert point1 != point2 + assert point1 != point3 + + def test_hashability(self): + """Test that Point objects are hashable (required for value objects)""" + point1 = Point(1.0, 2.0) + point2 = Point(1.0, 2.0) + + # Should be able to create sets and use as dict keys + point_set = {point1, point2} + assert len(point_set) == 1 # Duplicates should be removed + + point_dict = {point1: "value"} + assert point_dict[point2] == "value" # Same coordinates should access same key + + +class TestPointData: + """Test suite for PointData enhanced point information""" + + def test_point_data_creation_defaults(self): + """Test PointData creation with default values""" + point_data = PointData(1.0, 2.0) + + assert point_data.x == 1.0 + assert point_data.y == 2.0 + assert point_data.radius == 0.0 + assert point_data.is_small_point is False + + def test_point_data_creation_custom_values(self): + """Test PointData creation with custom radius and small point flag""" + point_data = PointData(1.0, 2.0, radius=5.5, is_small_point=True) + + assert point_data.x == 1.0 + assert point_data.y == 2.0 + assert point_data.radius == 5.5 + assert point_data.is_small_point is True + + def test_point_data_immutability(self): + """Test that PointData objects are immutable""" + point_data = PointData(1.0, 2.0, radius=3.0, is_small_point=True) + + # Verify attributes cannot be modified directly + with pytest.raises(AttributeError): + point_data.x = 5.0 + with pytest.raises(AttributeError): + point_data.y = 5.0 + with pytest.raises(AttributeError): + point_data.radius = 10.0 + with pytest.raises(AttributeError): + point_data.is_small_point = False + + def test_center_property(self): + """Test center property returns correct Point""" + point_data = PointData(3.5, 7.2, radius=2.0) + center = point_data.center + + assert isinstance(center, Point) + assert center.x == 3.5 + assert center.y == 7.2 + + def test_to_point_conversion(self): + """Test conversion to basic Point object""" + point_data = PointData(4.5, 6.7, radius=1.5, is_small_point=True) + point = point_data.to_point() + + assert isinstance(point, Point) + assert point.x == 4.5 + assert point.y == 6.7 + # Should not include radius or is_small_point in basic Point + + def test_equality_comparison_point_data(self): + """Test that PointData objects with same attributes are equal""" + point_data1 = PointData(1.0, 2.0, radius=3.0, is_small_point=True) + point_data2 = PointData(1.0, 2.0, radius=3.0, is_small_point=True) + + assert point_data1 == point_data2 + + def test_inequality_comparison_point_data(self): + """Test that PointData objects with different attributes are not equal""" + point_data1 = PointData(1.0, 2.0, radius=3.0, is_small_point=True) + point_data2 = PointData(1.0, 2.0, radius=4.0, is_small_point=True) # Different radius + point_data3 = PointData(1.0, 2.0, radius=3.0, is_small_point=False) # Different flag + + assert point_data1 != point_data2 + assert point_data1 != point_data3 + + def test_hashability_point_data(self): + """Test that PointData objects are hashable""" + point_data1 = PointData(1.0, 2.0, radius=3.0, is_small_point=True) + point_data2 = PointData(1.0, 2.0, radius=3.0, is_small_point=True) + + # Should be able to create sets and use as dict keys + point_data_set = {point_data1, point_data2} + assert len(point_data_set) == 1 # Duplicates should be removed + + point_data_dict = {point_data1: "value"} + assert point_data_dict[point_data2] == "value" + + +class TestPointAndPointDataIntegration: + """Test integration between Point and PointData classes""" + + def test_point_data_center_returns_point(self): + """Test that PointData.center returns a proper Point object""" + point_data = PointData(10.0, 20.0, radius=5.0) + center_point = point_data.center + + # Verify it's a Point with correct coordinates + assert isinstance(center_point, Point) + assert center_point.x == 10.0 + assert center_point.y == 20.0 + + # Verify Point methods work on the returned center + distance = center_point.distance_to(Point(13.0, 24.0)) + expected_distance = math.sqrt(3.0**2 + 4.0**2) # 5.0 + assert distance == expected_distance + + def test_point_data_to_point_conversion(self): + """Test that to_point() returns a proper Point object""" + point_data = PointData(15.0, 25.0, radius=10.0, is_small_point=False) + basic_point = point_data.to_point() + + assert isinstance(basic_point, Point) + assert basic_point.x == 15.0 + assert basic_point.y == 25.0 + + # Verify the converted Point has all Point functionality + tuple_result = basic_point.to_tuple() + assert tuple_result == (15.0, 25.0) + + def test_interoperability_between_point_and_point_data(self): + """Test that Point and PointData can work together seamlessly""" + point_data = PointData(5.0, 5.0, radius=2.0) + regular_point = Point(8.0, 9.0) + + # PointData.center should work with Point.distance_to + distance = point_data.center.distance_to(regular_point) + expected_distance = math.sqrt(3.0**2 + 4.0**2) # 5.0 + assert distance == expected_distance + + # PointData.to_point() should create compatible Point objects + converted_point = point_data.to_point() + assert converted_point.distance_to(regular_point) == distance \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/tests/core/use_cases/test_image_tracing.py b/sketchgetdp/bitmap_tracer/tests/core/use_cases/test_image_tracing.py new file mode 100644 index 0000000..0b94363 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/core/use_cases/test_image_tracing.py @@ -0,0 +1,276 @@ +import pytest +import numpy as np +from unittest.mock import Mock, patch +import sys +import os + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +# Import the use case and entities +from bitmap_tracer.core.use_cases.image_tracing import ImageTracingUseCase +from bitmap_tracer.core.entities.point import Point +from bitmap_tracer.core.entities.contour import Contour +from bitmap_tracer.core.entities.color import ColorCategory + + +class TestImageTracingUseCase: + """Test suite for ImageTracingUseCase""" + + @pytest.fixture + def mock_dependencies(self): + """Create mocked dependencies for the use case""" + contour_detector = Mock() + color_analyzer = Mock() + point_detector = Mock() + + return { + 'contour_detector': contour_detector, + 'color_analyzer': color_analyzer, + 'point_detector': point_detector + } + + @pytest.fixture + def use_case(self, mock_dependencies): + """Create use case instance with mocked dependencies""" + return ImageTracingUseCase( + contour_detector=mock_dependencies['contour_detector'], + color_analyzer=mock_dependencies['color_analyzer'], + point_detector=mock_dependencies['point_detector'] + ) + + @pytest.fixture + def sample_image_data(self): + """Sample image data for testing""" + return { + 'image_array': np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8), + 'filename': 'test_image.png', + 'width': 100, + 'height': 100 + } + + @pytest.fixture + def sample_config(self): + """Sample configuration for testing""" + return { + 'point_max_area': 2000, + 'point_max_perimeter': 165, + 'angle_threshold': 25, + 'min_curve_angle': 120 + } + + @pytest.fixture + def sample_contours(self): + """Create sample contours for testing""" + contour1 = Mock(spec=Contour) + contour1.points = [Point(10, 10), Point(20, 10), Point(15, 20)] + contour1.area = 50.0 + contour1.perimeter = 30.0 + contour1.get_center.return_value = Point(15, 13.3) + contour1.center = (15, 13.3) + + contour2 = Mock(spec=Contour) + contour2.points = [Point(50, 50), Point(60, 50), Point(55, 60)] + contour2.area = 50.0 + contour2.perimeter = 30.0 + contour2.get_center.return_value = Point(55, 53.3) + contour2.center = (55, 53.3) + + return [contour1, contour2] + + @pytest.fixture + def sample_raw_contours(self): + """Create sample raw OpenCV contours""" + contour1 = np.array([[[10, 10]], [[20, 10]], [[15, 20]]], dtype=np.int32) + contour2 = np.array([[[50, 50]], [[60, 50]], [[55, 60]]], dtype=np.int32) + return (contour1, contour2), None # contours, hierarchy + + def test_initialization(self, mock_dependencies): + """Test use case initialization with dependencies""" + use_case = ImageTracingUseCase( + contour_detector=mock_dependencies['contour_detector'], + color_analyzer=mock_dependencies['color_analyzer'], + point_detector=mock_dependencies['point_detector'] + ) + + assert use_case.contour_detector == mock_dependencies['contour_detector'] + assert use_case.color_analyzer == mock_dependencies['color_analyzer'] + assert use_case.point_detector == mock_dependencies['point_detector'] + + def test_initialization_without_dependencies(self): + """Test use case initialization without dependencies""" + use_case = ImageTracingUseCase() + + assert use_case.contour_detector is None + assert use_case.color_analyzer is None + assert use_case.point_detector is None + + def test_execute_successful_tracing(self, use_case, mock_dependencies, sample_image_data, sample_config): + """Test successful execution of image tracing workflow""" + raw_contours = ( + np.array([[[10, 10]], [[20, 10]], [[15, 20]]], dtype=np.int32), + np.array([[[50, 50]], [[60, 50]], [[55, 60]]], dtype=np.int32) + ) + mock_dependencies['contour_detector'].detect.return_value = (raw_contours, None) + mock_dependencies['color_analyzer'].categorize.side_effect = ['red', 'blue'] + + + mock_dependencies['point_detector'].detect_point.side_effect = [ + Point(15, 13), + None + ] + + with patch.object(use_case, '_convert_to_contour_entity') as mock_convert: + mock_contour1 = Mock(spec=Contour) + mock_contour1.points = [Point(10, 10), Point(20, 10), Point(15, 20)] + mock_contour1.area = 50.0 + mock_contour1.perimeter = 30.0 + mock_contour1.get_center.return_value = Point(15, 13.3) + + mock_contour2 = Mock(spec=Contour) + mock_contour2.points = [Point(50, 50), Point(60, 50), Point(55, 60)] + mock_contour2.area = 50.0 + mock_contour2.perimeter = 30.0 + mock_contour2.get_center.return_value = Point(55, 53.3) + + mock_convert.side_effect = [mock_contour1, mock_contour2] + + result = use_case.execute(sample_image_data, sample_config) + + assert result['success'] is True + assert len(result['structures']['red_points']) == 1 + assert len(result['structures']['blue_structures']) == 1 + assert len(result['structures']['green_structures']) == 0 + assert result['total_contours'] == 2 + assert result['processed_contours'] == 2 + + # Verify dependency calls + mock_dependencies['contour_detector'].detect.assert_called_once_with(sample_image_data) + assert mock_dependencies['color_analyzer'].categorize.call_count == 2 + assert mock_dependencies['point_detector'].detect_point.call_count == 2 + + def test_execute_with_no_contours(self, use_case, mock_dependencies, sample_image_data, sample_config): + """Test execution when no contours are found""" + mock_dependencies['contour_detector'].detect.return_value = (None, None) + + result = use_case.execute(sample_image_data, sample_config) + + assert result['success'] is True + assert len(result['structures']['red_points']) == 0 + assert len(result['structures']['blue_structures']) == 0 + assert len(result['structures']['green_structures']) == 0 + assert result['total_contours'] == 0 + assert result['processed_contours'] == 0 + + def test_execute_with_empty_contours(self, use_case, mock_dependencies, sample_image_data, sample_config): + """Test execution when empty contours are returned""" + mock_dependencies['contour_detector'].detect.return_value = ((), None) + + result = use_case.execute(sample_image_data, sample_config) + + assert result['success'] is True + assert len(result['structures']['red_points']) == 0 + assert len(result['structures']['blue_structures']) == 0 + assert len(result['structures']['green_structures']) == 0 + assert result['total_contours'] == 0 + assert result['processed_contours'] == 0 + + def test_execute_with_exception(self, use_case, mock_dependencies, sample_image_data, sample_config): + """Test execution when an exception occurs""" + mock_dependencies['contour_detector'].detect.side_effect = Exception("Detection failed") + + result = use_case.execute(sample_image_data, sample_config) + + assert result['success'] is False + assert "Detection failed" in result['error'] + assert len(result['structures']['red_points']) == 0 + assert len(result['structures']['blue_structures']) == 0 + assert len(result['structures']['green_structures']) == 0 + assert result['total_contours'] == 0 + assert result['processed_contours'] == 0 + + def test_detect_contours_success(self, use_case, mock_dependencies, sample_image_data): + """Test successful contour detection""" + raw_contours = (np.array([[[10, 10]], [[20, 10]], [[15, 20]]], dtype=np.int32),) + mock_dependencies['contour_detector'].detect.return_value = (raw_contours, None) + + with patch.object(use_case, '_convert_to_contour_entity') as mock_convert: + mock_contour = Mock(spec=Contour) + mock_convert.return_value = mock_contour + + contours = use_case.detect_contours(sample_image_data) + + assert len(contours) == 1 + mock_dependencies['contour_detector'].detect.assert_called_once_with(sample_image_data) + mock_convert.assert_called_once() + + def test_detect_contours_no_detector(self, sample_image_data): + """Test contour detection when no detector is available""" + use_case = ImageTracingUseCase() # No dependencies + + contours = use_case.detect_contours(sample_image_data) + + assert contours == [] + + def test_detect_contours_none_result(self, use_case, mock_dependencies, sample_image_data): + """Test contour detection when detector returns None""" + mock_dependencies['contour_detector'].detect.return_value = (None, None) + + contours = use_case.detect_contours(sample_image_data) + + assert contours == [] + + def test_detect_contours_empty_result(self, use_case, mock_dependencies, sample_image_data): + """Test contour detection when detector returns empty result""" + mock_dependencies['contour_detector'].detect.return_value = ([], None) + + contours = use_case.detect_contours(sample_image_data) + + assert contours == [] + + def test_detect_points_with_detector(self, use_case, mock_dependencies): + """Test point detection using the point detector service""" + contour = Mock(spec=Contour) + contour.points = [Point(10, 10), Point(20, 10), Point(15, 20)] + + expected_point = Point(15, 13) + mock_dependencies['point_detector'].detect_point.return_value = expected_point + + config = {'some_setting': 'value'} + + result = use_case.detect_points(contour, config) + + assert result.x == expected_point.x + assert result.y == expected_point.y + mock_dependencies['point_detector'].set_config.assert_called_once_with(config) + mock_dependencies['point_detector'].detect_point.assert_called_once() + + def test_detect_points_with_no_config(self, use_case, mock_dependencies): + """Test point detection without providing config""" + contour = Mock(spec=Contour) + contour.points = [Point(10, 10), Point(20, 10), Point(15, 20)] + + expected_point = Point(15, 13) + mock_dependencies['point_detector'].detect_point.return_value = expected_point + + result = use_case.detect_points(contour) + + assert result.x == expected_point.x + assert result.y == expected_point.y + + mock_dependencies['point_detector'].set_config.assert_not_called() + mock_dependencies['point_detector'].detect_point.assert_called_once() + + def test_convert_to_contour_entity(self, use_case): + """Test conversion of raw contour to Contour entity""" + raw_contour = np.array([[[10, 10]], [[20, 10]], [[15, 20]]], dtype=np.int32) + + with patch('bitmap_tracer.core.use_cases.image_tracing.Contour') as MockContour: + mock_contour = Mock(spec=Contour) + MockContour.from_numpy_contour.return_value = mock_contour + + result = use_case._convert_to_contour_entity(raw_contour) + + assert result == mock_contour + MockContour.from_numpy_contour.assert_called_once_with(raw_contour) + \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/tests/core/use_cases/test_structure_filtering.py b/sketchgetdp/bitmap_tracer/tests/core/use_cases/test_structure_filtering.py new file mode 100644 index 0000000..daacc95 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/core/use_cases/test_structure_filtering.py @@ -0,0 +1,218 @@ +# test_structure_filtering.py +import pytest +from unittest.mock import Mock, patch +from core.use_cases.structure_filtering import StructureFilteringUseCase + + +# Mock Contour class for testing +class MockContour: + def __init__(self, area: float, perimeter: float = 10.0): + self.area = area + self.perimeter = perimeter + + +class TestStructureFilteringUseCase: + + def setup_method(self): + self.use_case = StructureFilteringUseCase() + + def test_execute_basic_filtering(self): + """Test basic filtering with limits""" + structures = { + 'red_points': ['r1', 'r2', 'r3', 'r4', 'r5'], + 'blue_structures': ['b1', 'b2', 'b3'], + 'green_structures': ['g1', 'g2'] + } + config = {'red_dots': 3, 'blue_paths': 2, 'green_paths': 1} + + result = self.use_case.execute(structures, config) + + assert len(result['red_points']) == 3 + assert len(result['blue_structures']) == 2 + assert len(result['green_structures']) == 1 + + def test_execute_within_limits(self): + """Test when structures are already within limits""" + structures = { + 'red_points': ['r1', 'r2'], + 'blue_structures': ['b1'], + 'green_structures': [] + } + config = {'red_dots': 5, 'blue_paths': 3, 'green_paths': 2} + + result = self.use_case.execute(structures, config) + + assert result['red_points'] == ['r1', 'r2'] + assert result['blue_structures'] == ['b1'] + assert result['green_structures'] == [] + + def test_execute_zero_limits(self): + """Test with zero limits (should not filter)""" + structures = { + 'red_points': ['r1', 'r2'], + 'blue_structures': ['b1'], + 'green_structures': ['g1'] + } + config = {'red_dots': 0, 'blue_paths': 0, 'green_paths': 0} + + result = self.use_case.execute(structures, config) + + assert len(result['red_points']) == 2 + assert len(result['blue_structures']) == 1 + assert len(result['green_structures']) == 1 + + def test_execute_missing_keys(self): + """Test with missing structure or config keys""" + structures = {'red_points': ['r1', 'r2']} # Missing others + config = {'red_dots': 1} + + result = self.use_case.execute(structures, config) + + assert result['red_points'] == ['r1'] + assert 'blue_structures' in result + assert 'green_structures' in result + + def test_filter_structures_by_area_basic(self): + """Test basic area filtering""" + structures = [ + (100.0, 'large'), + (50.0, 'medium'), + (25.0, 'small'), + (10.0, 'tiny') + ] + + result = self.use_case.filter_structures_by_area(structures, max_count=2) + + assert len(result) == 2 + assert result[0][0] == 100.0 # Largest area + assert result[1][0] == 50.0 # Second largest + + def test_filter_structures_by_area_no_limit(self): + """Test area filtering with high limit""" + structures = [ + (100.0, 'large'), + (50.0, 'medium') + ] + + result = self.use_case.filter_structures_by_area(structures, max_count=10) + + assert len(result) == 2 + + def test_filter_structures_by_area_zero_limit(self): + """Test area filtering with zero limit""" + structures = [ + (100.0, 'large'), + (50.0, 'medium') + ] + + result = self.use_case.filter_structures_by_area(structures, max_count=0) + + assert len(result) == 0 + + def test_filter_contours_by_size_basic(self): + """Test basic size filtering of contours""" + contours = [ + MockContour(area=25.0), + MockContour(area=50.0), + MockContour(area=75.0), + MockContour(area=100.0) + ] + + result = self.use_case.filter_contours_by_size( + contours, min_area=50.0, max_area=75.0 + ) + + assert len(result) == 2 + assert all(50.0 <= c.area <= 75.0 for c in result) + + def test_filter_contours_by_size_boundary(self): + """Test size filtering with boundary values""" + contours = [ + MockContour(area=50.0), # Exactly min + MockContour(area=75.0), # Exactly max + MockContour(area=49.9), # Just below min + MockContour(area=75.1) # Just above max + ] + + result = self.use_case.filter_contours_by_size( + contours, min_area=50.0, max_area=75.0 + ) + + assert len(result) == 2 + + def test_filter_by_circularity_basic(self): + """Test basic circularity filtering""" + # Perfect circle: area = πr², perimeter = 2πr + # For r=5: area ≈ 78.54, perimeter ≈ 31.42 + contours = [ + MockContour(area=78.54, perimeter=31.42), # High circularity (~1.0) + MockContour(area=10.0, perimeter=100.0), # Low circularity (~0.013) + ] + + result = self.use_case.filter_by_circularity(contours, min_circularity=0.5) + + assert len(result) == 1 + assert result[0].area == 78.54 + + def test_filter_by_circularity_default(self): + """Test circularity filtering with default threshold""" + contours = [ + MockContour(area=10.0, perimeter=50.0), # Circularity ~0.05 + MockContour(area=5.0, perimeter=100.0), # Circularity ~0.006 (below default) + ] + + result = self.use_case.filter_by_circularity(contours) + + # Default min_circularity is 0.01 + assert len(result) == 1 + + def test_sort_contours_by_area_descending(self): + """Test sorting contours by area (largest first)""" + contours = [ + MockContour(area=25.0), + MockContour(area=100.0), + MockContour(area=50.0) + ] + + result = self.use_case.sort_contours_by_area(contours, descending=True) + + assert result[0].area == 100.0 + assert result[1].area == 50.0 + assert result[2].area == 25.0 + + def test_sort_contours_by_area_ascending(self): + """Test sorting contours by area (smallest first)""" + contours = [ + MockContour(area=100.0), + MockContour(area=25.0), + MockContour(area=50.0) + ] + + result = self.use_case.sort_contours_by_area(contours, descending=False) + + assert result[0].area == 25.0 + assert result[1].area == 50.0 + assert result[2].area == 100.0 + + def test_sort_contours_by_area_empty(self): + """Test sorting empty contour list""" + contours = [] + + result = self.use_case.sort_contours_by_area(contours, descending=True) + + assert len(result) == 0 + + @patch('builtins.print') + def test_execute_exception_handling(self, mock_print): + """Test exception handling in execute method""" + # Create a structure that will cause an error when trying to get length + bad_structures = Mock() + bad_structures.get.return_value = None + + config = {'red_dots': 5} + + # Should not raise exception, should return original structures + result = self.use_case.execute(bad_structures, config) + + assert result == bad_structures + mock_print.assert_called() diff --git a/sketchgetdp/bitmap_tracer/tests/infrastructure/configuration/test_config_loader.py b/sketchgetdp/bitmap_tracer/tests/infrastructure/configuration/test_config_loader.py new file mode 100644 index 0000000..943e445 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/infrastructure/configuration/test_config_loader.py @@ -0,0 +1,200 @@ +""" +Unit tests for config_loader.py +""" + +import os +import sys +import pytest +import tempfile +import yaml + +# Add project root to Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from infrastructure.configuration.config_loader import ConfigLoader + + +class TestConfigLoader: + """Test cases for ConfigLoader class.""" + + @pytest.fixture + def sample_config_data(self): + """Sample configuration data for testing.""" + return { + 'red_dots': 10, + 'blue_paths': 5, + 'green_paths': 8, + + 'point_max_area': 100, + 'point_max_perimeter': 80, + + 'blue_hue_range': [100, 140], + 'red_hue_range': [[0, 10], [170, 180]], + 'green_hue_range': [35, 85], + 'min_saturation': 50, + 'max_value_white': 200, + 'min_value_black': 50, + + 'custom_setting': 'test_value' + } + + @pytest.fixture + def config_loader(self): + """Create a ConfigLoader instance for testing.""" + return ConfigLoader() + + @pytest.fixture + def temp_config_file(self, sample_config_data): + """Create a temporary config file for testing.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(sample_config_data, f) + temp_path = f.name + + yield temp_path + + # Cleanup + if os.path.exists(temp_path): + os.unlink(temp_path) + + def test_initialization(self): + """Test ConfigLoader initialization.""" + loader = ConfigLoader() + assert loader.default_config_path == "config.yaml" + assert loader._config_cache is None + assert loader._overrides == {} + + custom_loader = ConfigLoader("custom_config.yaml") + assert custom_loader.default_config_path == "custom_config.yaml" + + def test_load_config_success(self, temp_config_file, sample_config_data): + """Test successful configuration loading.""" + loader = ConfigLoader(temp_config_file) + config = loader.load_config() + + assert config is not None + for key, value in sample_config_data.items(): + assert config[key] == value + + def test_load_config_file_not_found(self, config_loader): + """Test loading when config file doesn't exist.""" + config = config_loader.load_config("non_existent_config.yaml") + + assert config == {} + + def test_load_config_invalid_yaml(self): + """Test loading invalid YAML file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write("invalid: yaml: content: [") + temp_path = f.name + + try: + loader = ConfigLoader(temp_path) + config = loader.load_config() + + assert config is None + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + def test_load_config_caching(self, temp_config_file): + """Test that config is cached after first load.""" + loader = ConfigLoader(temp_config_file) + + # First load + config1 = loader.load_config() + assert loader._config_cache is not None + + # Second load should use cache + config2 = loader.load_config() + assert config1 == config2 + + def test_get_structure_limits(self, temp_config_file): + """Test getting structure limits.""" + loader = ConfigLoader(temp_config_file) + red_dots, blue_paths, green_paths = loader.get_structure_limits() + + assert red_dots == 10 + assert blue_paths == 5 + assert green_paths == 8 + + def test_get_structure_limits_defaults(self, config_loader): + """Test getting structure limits with defaults.""" + red_dots, blue_paths, green_paths = config_loader.get_structure_limits() + + assert red_dots == 0 + assert blue_paths == 0 + assert green_paths == 0 + + def test_get_contour_detection_params(self, temp_config_file): + """Test getting contour detection parameters.""" + loader = ConfigLoader(temp_config_file) + params = loader.get_contour_detection_params() + + expected_keys = [ + 'point_max_area', 'point_max_perimeter' + ] + + for key in expected_keys: + assert key in params + assert isinstance(params[key], (int, float)) + + def test_get_color_detection_params(self, temp_config_file): + """Test getting color detection parameters.""" + loader = ConfigLoader(temp_config_file) + params = loader.get_color_detection_params() + + expected_keys = [ + 'blue_hue_range', 'red_hue_range', 'green_hue_range', + 'min_saturation', 'max_value_white', 'min_value_black' + ] + + for key in expected_keys: + assert key in params + assert params[key] is not None + + def test_apply_overrides_internal(self, temp_config_file): + """Test internal _apply_overrides method.""" + loader = ConfigLoader(temp_config_file) + + # Load config to populate cache + original_config = loader.load_config() + + # Set some overrides + loader._overrides = { + 'custom_setting': 'overridden', + 'new_setting': 'new_value' + } + + # Apply overrides + overridden_config = loader._apply_overrides(original_config) + + # Check original config is not modified + assert original_config['custom_setting'] == 'test_value' + + # Check overrides are applied + assert overridden_config['custom_setting'] == 'overridden' + assert overridden_config['new_setting'] == 'new_value' + + def test_empty_config_file(self): + """Test loading an empty config file.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write('') # Empty file + temp_path = f.name + + try: + loader = ConfigLoader(temp_path) + config = loader.load_config() + + assert config == {} + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + def test_none_config_path(self, config_loader): + """Test loading with None config path (should use default).""" + # This will try to load from default path which may not exist + config = config_loader.load_config(None) + + # Should return empty dict if default file doesn't exist + assert config == {} diff --git a/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_color_analyzer.py b/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_color_analyzer.py new file mode 100644 index 0000000..302870f --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_color_analyzer.py @@ -0,0 +1,225 @@ +import os +import sys +import pytest +import numpy as np +import cv2 +from unittest.mock import patch, MagicMock + +# Add project root to path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from core.entities.color import ColorCategory +from infrastructure.image_processing.color_analyzer import ColorAnalyzer + + +class TestColorAnalyzer: + """Test suite for ColorAnalyzer class""" + + def setup_method(self): + """Setup before each test method""" + self.analyzer = ColorAnalyzer() + + # Create a mock Contour entity for testing + self.mock_contour = MagicMock() + self.mock_contour.points = [MagicMock(x=10, y=10), MagicMock(x=20, y=20), MagicMock(x=30, y=30)] + self.mock_contour.area = 100.0 + self.mock_contour.to_numpy.return_value = np.array([[10, 10], [20, 20], [30, 30]], dtype=np.float32).reshape(-1, 1, 2) + + def test_initialization_default_config(self): + """Test initialization with default configuration""" + analyzer = ColorAnalyzer() + assert analyzer.blue_hue_range == [100, 140] + assert analyzer.red_hue_ranges == [[0, 10], [170, 180]] + assert analyzer.green_hue_range == [35, 85] + assert analyzer.min_saturation == 50 + assert analyzer.max_value_white == 200 + assert analyzer.min_value_black == 50 + + def test_initialization_custom_config(self): + """Test initialization with custom configuration""" + config = { + 'blue_hue_range': [90, 130], + 'red_hue_range': [[5, 15], [160, 170]], + 'green_hue_range': [40, 90], + 'min_saturation': 60, + 'max_value_white': 180, + 'min_value_black': 40 + } + analyzer = ColorAnalyzer(config) + assert analyzer.blue_hue_range == [90, 130] + assert analyzer.red_hue_ranges == [[5, 15], [160, 170]] + assert analyzer.green_hue_range == [40, 90] + assert analyzer.min_saturation == 60 + assert analyzer.max_value_white == 180 + assert analyzer.min_value_black == 40 + + def test_categorize_color_pixel_blue(self): + """Test blue color categorization""" + # Test with blue BGR color + blue_bgr = [255, 0, 0] # Pure blue in BGR + category, hex_color = self.analyzer.categorize_color_pixel(blue_bgr) + assert category == ColorCategory.BLUE + assert hex_color == "#0000FF" + + def test_categorize_color_pixel_red(self): + """Test red color categorization""" + # Test with red BGR color + red_bgr = [0, 0, 255] # Pure red in BGR + category, hex_color = self.analyzer.categorize_color_pixel(red_bgr) + assert category == ColorCategory.RED + assert hex_color == "#FF0000" + + def test_categorize_color_pixel_green(self): + """Test green color categorization""" + # Test with green BGR color + green_bgr = [0, 255, 0] # Pure green in BGR + category, hex_color = self.analyzer.categorize_color_pixel(green_bgr) + assert category == ColorCategory.GREEN + assert hex_color == "#00FF00" + + def test_categorize_color_pixel_white(self): + """Test white color categorization""" + # Test with white BGR color (high value, low saturation) + white_bgr = [255, 255, 255] # Pure white + category, hex_color = self.analyzer.categorize_color_pixel(white_bgr) + assert category == ColorCategory.WHITE + assert hex_color is None + + def test_categorize_color_pixel_black(self): + """Test black color categorization""" + # Test with black BGR color (low value) + black_bgr = [0, 0, 0] # Pure black + category, hex_color = self.analyzer.categorize_color_pixel(black_bgr) + assert category == ColorCategory.BLACK + assert hex_color is None + + def test_categorize_color_pixel_other(self): + """Test other color categorization""" + # Test with low saturation color (should be categorized as OTHER) + gray_bgr = [100, 100, 100] # Gray (low saturation) + category, hex_color = self.analyzer.categorize_color_pixel(gray_bgr) + assert category == ColorCategory.OTHER + assert hex_color is None + + def test_categorize_color_pixel_invalid_input(self): + """Test color categorization with invalid input""" + # Test with empty list + category, hex_color = self.analyzer.categorize_color_pixel([]) + assert category == ColorCategory.OTHER + assert hex_color is None + + # Test with short list + category, hex_color = self.analyzer.categorize_color_pixel([100, 100]) + assert category == ColorCategory.OTHER + assert hex_color is None + + def test_get_dominant_color_none_contour(self): + """Test get_dominant_color with None contour""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + result = self.analyzer.get_dominant_color(None, image) + assert result is None + + def test_get_dominant_color_empty_contour(self): + """Test get_dominant_color with empty contour""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + empty_contour = np.array([]) + result = self.analyzer.get_dominant_color(empty_contour, image) + assert result is None + + @patch('cv2.drawContours') + def test_get_dominant_color_contour_drawing_failure(self, mock_draw_contours): + """Test get_dominant_color when contour drawing fails""" + mock_draw_contours.side_effect = Exception("Drawing failed") + image = np.zeros((100, 100, 3), dtype=np.uint8) + contour = np.array([[10, 10], [20, 20], [30, 30]], dtype=np.int32) + result = self.analyzer.get_dominant_color(contour, image) + assert result is None + + def test_get_dominant_color_no_boundary_pixels(self): + """Test get_dominant_color when no boundary pixels are found""" + # Create an image and contour that won't produce boundary pixels + image = np.zeros((100, 100, 3), dtype=np.uint8) + contour = np.array([[1, 1], [2, 2]], dtype=np.int32) # Very small contour + + with patch('cv2.drawContours'): + result = self.analyzer.get_dominant_color(contour, image) + assert result is None + + def test_categorize_with_contour_entity(self): + """Test categorize method with Contour entity""" + # Create a test image with red pixels + image = np.zeros((100, 100, 3), dtype=np.uint8) + image[10:20, 10:20] = [0, 0, 255] # Red in BGR + + with patch.object(self.analyzer, 'get_dominant_color', return_value="#FF0000"): + result = self.analyzer.categorize(self.mock_contour, image) + assert result == "red" + + def test_categorize_no_dominant_color(self): + """Test categorize method when no dominant color is found""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + + with patch.object(self.analyzer, 'get_dominant_color', return_value=None): + result = self.analyzer.categorize(self.mock_contour, image) + assert result is None + + def test_categorize_green_color(self): + """Test categorize method with green color""" + image = np.zeros((100, 100, 3), dtype=np.uint8) + + with patch.object(self.analyzer, 'get_dominant_color', return_value="#00FF00"): + result = self.analyzer.categorize(self.mock_contour, image) + assert result == "green" + + def test_hsv_color_conversion_blue(self): + """Test HSV conversion for blue color""" + blue_bgr = [255, 0, 0] # Pure blue in BGR + hsv_color = np.uint8([[[blue_bgr[0], blue_bgr[1], blue_bgr[2]]]]) + hsv = cv2.cvtColor(hsv_color, cv2.COLOR_BGR2HSV)[0][0] + hue, saturation, value = hsv + + # Blue should have hue around 120 in OpenCV HSV (0-180 range) + assert 100 <= hue <= 140 # Within blue range + + def test_hsv_color_conversion_red(self): + """Test HSV conversion for red color""" + red_bgr = [0, 0, 255] # Pure red in BGR + hsv_color = np.uint8([[[red_bgr[0], red_bgr[1], red_bgr[2]]]]) + hsv = cv2.cvtColor(hsv_color, cv2.COLOR_BGR2HSV)[0][0] + hue, saturation, value = hsv + + # Red should have hue around 0 or 180 in OpenCV HSV + assert (0 <= hue <= 10) or (170 <= hue <= 180) + + def test_color_dominance_calculation(self): + """Test color dominance calculation logic""" + # Mock boundary pixels with majority blue + blue_pixel = [255, 0, 0] # Blue in BGR + red_pixel = [0, 0, 255] # Red in BGR + + # Create mock boundary pixels with 70% blue, 30% red + mock_pixels = np.array([blue_pixel] * 7 + [red_pixel] * 3) + + analyzer = ColorAnalyzer() + + # Mock the boundary pixel sampling + with patch.object(analyzer, 'categorize_color_pixel') as mock_categorize: + def side_effect(pixel): + if list(pixel) == blue_pixel: + return (ColorCategory.BLUE, "#0000FF") + else: + return (ColorCategory.RED, "#FF0000") + + mock_categorize.side_effect = side_effect + + # This is testing the internal logic, so we'll create a simplified test + color_categories = {} + for pixel in mock_pixels: + category, hex_color = analyzer.categorize_color_pixel(pixel.tolist()) + if category in [ColorCategory.BLUE, ColorCategory.RED, ColorCategory.GREEN]: + color_categories[category.value] = color_categories.get(category.value, 0) + 1 + + assert color_categories['blue'] == 7 + assert color_categories['red'] == 3 + assert 'green' not in color_categories diff --git a/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_contour_closure_service.py b/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_contour_closure_service.py new file mode 100644 index 0000000..268c53c --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_contour_closure_service.py @@ -0,0 +1,79 @@ +import os +import sys +import pytest +import numpy as np + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from infrastructure.image_processing.contour_closure_service import ContourClosureService, ClosedContour + + +class TestContourClosureService: + + @pytest.fixture + def service(self): + return ContourClosureService() + + @pytest.fixture + def perfectly_closed_contour(self): + return np.array([[[0, 0]], [[0, 10]], [[10, 10]], [[10, 0]], [[0, 0]]], dtype=np.float32) + + @pytest.fixture + def obviously_open_contour(self): + return np.array([[[0, 0]], [[0, 10]], [[10, 10]], [[20, 20]]], dtype=np.float32) + + @pytest.fixture + def too_small_contour(self): + return np.array([[[0, 0]], [[5, 5]]], dtype=np.float32) + + def test_ensure_closure_preserves_already_closed_contour(self, service, perfectly_closed_contour): + result = service.ensure_closure(perfectly_closed_contour, tolerance=5.0) + assert np.array_equal(result, perfectly_closed_contour) + + def test_ensure_closure_closes_open_contour(self, service, obviously_open_contour): + result = service.ensure_closure(obviously_open_contour, tolerance=5.0) + assert len(result) == len(obviously_open_contour) + 1 + assert np.array_equal(result[0], result[-1]) + + def test_ensure_closure_ignores_small_contours(self, service, too_small_contour): + result = service.ensure_closure(too_small_contour, tolerance=5.0) + assert np.array_equal(result, too_small_contour) + + def test_ensure_closure_respects_tolerance_threshold(self, service): + contour_with_small_gap = np.array([[[0, 0]], [[1, 0]], [[2, 0]], [[3, 0]]], dtype=np.float32) + + result_with_loose_tolerance = service.ensure_closure(contour_with_small_gap, tolerance=5.0) + assert len(result_with_loose_tolerance) == len(contour_with_small_gap) + + result_with_tight_tolerance = service.ensure_closure(contour_with_small_gap, tolerance=2.0) + assert len(result_with_tight_tolerance) == len(contour_with_small_gap) + 1 + + def test_is_closed_returns_true_for_closed_contour(self, service, perfectly_closed_contour): + assert service.is_closed(perfectly_closed_contour, tolerance=5.0) == True + + def test_is_closed_returns_false_for_open_contour(self, service, obviously_open_contour): + assert service.is_closed(obviously_open_contour, tolerance=5.0) == False + + def test_is_closed_returns_false_for_insufficient_points(self, service, too_small_contour): + assert service.is_closed(too_small_contour, tolerance=5.0) == False + + def test_calculate_closure_gap_returns_zero_for_perfectly_closed_contour(self, service, perfectly_closed_contour): + gap = service.calculate_closure_gap(perfectly_closed_contour) + assert gap == pytest.approx(0.0) + + def test_calculate_closure_gap_returns_correct_distance_for_open_contour(self, service, obviously_open_contour): + gap = service.calculate_closure_gap(obviously_open_contour) + expected_gap = np.linalg.norm(np.array([0, 0]) - np.array([20, 20])) + assert gap == pytest.approx(expected_gap) + + def test_calculate_closure_gap_returns_infinity_for_small_contours(self, service, too_small_contour): + gap = service.calculate_closure_gap(too_small_contour) + assert gap == float('inf') + + def test_all_methods_handle_empty_contour(self, service): + empty_contour = np.array([], dtype=np.float32).reshape(0, 1, 2) + + assert len(service.ensure_closure(empty_contour)) == 0 + assert service.is_closed(empty_contour) == False + assert service.calculate_closure_gap(empty_contour) == float('inf') diff --git a/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_contour_detector.py b/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_contour_detector.py new file mode 100644 index 0000000..3fa5636 --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/infrastructure/image_processing/test_contour_detector.py @@ -0,0 +1,126 @@ +import os +import sys +import pytest +import cv2 +import numpy as np +from unittest.mock import patch + + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from infrastructure.image_processing.contour_detector import ContourDetector + + +class TestContourDetector: + + @pytest.fixture + def contour_detector(self): + return ContourDetector() + + @pytest.fixture + def sample_image_data(self): + img = np.zeros((100, 100, 3), dtype=np.uint8) + cv2.rectangle(img, (20, 20), (80, 80), (255, 255, 255), -1) + return {'image_array': img} + + @pytest.fixture + def empty_image_data(self): + return {'image_array': None} + + @pytest.fixture + def mock_contours(self): + contour = np.array([[[10, 10]], [[10, 90]], [[90, 90]], [[90, 10]]], dtype=np.int32) + hierarchy = np.array([[[-1, -1, 1, -1]]], dtype=np.int32) + return [contour], hierarchy + + def test_initialization_creates_closure_service(self, contour_detector): + assert contour_detector.closure_service is not None + + def test_detect_returns_contours_for_valid_image(self, contour_detector, sample_image_data): + contours, hierarchy = contour_detector.detect(sample_image_data) + + assert contours is not None + assert isinstance(contours, tuple) + assert len(contours) > 0 + assert hierarchy is not None + + def test_detect_returns_none_for_empty_image_data(self, contour_detector, empty_image_data): + contours, hierarchy = contour_detector.detect(empty_image_data) + assert contours is None + assert hierarchy is None + + def test_detect_returns_none_for_missing_image_array(self, contour_detector): + invalid_data = {'wrong_key': np.zeros((100, 100, 3), dtype=np.uint8)} + contours, hierarchy = contour_detector.detect(invalid_data) + assert contours is None + assert hierarchy is None + + def test_detect_ensures_all_contours_are_closed(self, contour_detector, sample_image_data): + contours, _ = contour_detector.detect(sample_image_data) + + if contours is not None: + for contour in contours: + assert len(contour) >= 3 # Minimum points for closed shape + + def test_image_processing_creates_valid_binary_images(self, contour_detector, sample_image_data): + img = sample_image_data['image_array'] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + binary1 = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY_INV, 15, 5) + _, binary2 = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + + assert binary1.shape == gray.shape + assert binary2.shape == gray.shape + assert np.isin(binary1, [0, 255]).all() + assert np.isin(binary2, [0, 255]).all() + + def test_morphological_operations_clean_noise(self): + binary_img = np.zeros((50, 50), dtype=np.uint8) + binary_img[10:40, 10:40] = 255 + binary_img[5:7, 5:7] = 255 # Noise + + kernel = np.ones((3,3), np.uint8) + closed = cv2.morphologyEx(binary_img, cv2.MORPH_CLOSE, kernel, iterations=2) + opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=1) + + assert closed.shape == binary_img.shape + assert opened.shape == binary_img.shape + + @pytest.mark.parametrize("image_shape", [ + (100, 100, 3), + (50, 50, 3), + (200, 200, 3), + ]) + def test_detect_handles_different_image_sizes(self, contour_detector, image_shape): + img = np.zeros(image_shape, dtype=np.uint8) + cv2.rectangle(img, (10, 10), (image_shape[1]-10, image_shape[0]-10), (255, 255, 255), -1) + image_data = {'image_array': img} + + contours, hierarchy = contour_detector.detect(image_data) + assert contours is not None + assert len(contours) > 0 + + def test_contour_hierarchy_has_correct_structure(self, contour_detector, sample_image_data): + contours, hierarchy = contour_detector.detect(sample_image_data) + + if hierarchy is not None: + assert isinstance(hierarchy, np.ndarray) + assert hierarchy.ndim == 3 + assert hierarchy.shape[2] == 4 # [next, previous, first_child, parent] + + def test_closure_service_called_during_detection(self, contour_detector, sample_image_data): + with patch.object(contour_detector.closure_service, 'ensure_closure') as mock_ensure: + with patch.object(contour_detector.closure_service, 'is_closed') as mock_is_closed: + with patch.object(contour_detector.closure_service, 'calculate_closure_gap') as mock_gap: + + mock_ensure.return_value = np.array([[[10, 10]], [[10, 90]], [[90, 90]], [[90, 10]]]) + mock_is_closed.return_value = True + mock_gap.return_value = 0.0 + + contours, hierarchy = contour_detector.detect(sample_image_data) + + assert mock_ensure.called + assert mock_is_closed.called + assert mock_gap.called diff --git a/sketchgetdp/bitmap_tracer/tests/infrastructure/point_detection/test_curve_fitter.py b/sketchgetdp/bitmap_tracer/tests/infrastructure/point_detection/test_curve_fitter.py new file mode 100644 index 0000000..002844f --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/infrastructure/point_detection/test_curve_fitter.py @@ -0,0 +1,160 @@ +import os +import sys +import pytest +import numpy as np +import copy + +# Required for importing the module under test +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from infrastructure.point_detection.curve_fitter import CurveFitter + + +class TestCurveFitter: + """Verify CurveFitter converts raster contours to smooth vector paths.""" + + @pytest.fixture + def curve_fitter(self): + return CurveFitter(angle_threshold=25, min_curve_angle=120) + + @pytest.fixture + def simple_contour(self): + """Square contour tests basic shape handling.""" + return np.array([[[0, 0]], [[100, 0]], [[100, 100]], [[0, 100]]], dtype=np.int32) + + @pytest.fixture + def triangle_contour(self): + """Triangle contour tests corner detection.""" + return np.array([[[0, 0]], [[50, 100]], [[100, 0]]], dtype=np.int32) + + @pytest.fixture + def closed_contour(self): + """Circular contour tests curve fitting behavior.""" + points = [] + center_x, center_y = 50, 50 + radius = 40 + + # Exact integer coordinates ensure proper closure detection + angles = [0, 45, 90, 135, 180, 225, 270, 315] + for angle_deg in angles: + angle_rad = np.radians(angle_deg) + x = int(center_x + radius * np.cos(angle_rad)) + y = int(center_y + radius * np.sin(angle_rad)) + points.append([[x, y]]) + + points.append(points[0]) # Force exact closure + return np.array(points, dtype=np.int32) + + def test_initialization_sets_geometric_thresholds(self, curve_fitter): + """Thresholds determine line vs curve classification.""" + assert curve_fitter.angle_threshold == 25 + assert curve_fitter.min_curve_angle == 120 + + def test_fit_curve_generates_valid_svg_path(self, curve_fitter, simple_contour): + """SVG path must be properly formatted for rendering.""" + path_data = curve_fitter.fit_curve(simple_contour) + assert path_data.startswith('M') # Move command starts path + assert path_data.endswith('Z') # Close command ends path + assert any(cmd in path_data for cmd in ['L', 'Q']) # Contains drawing commands + + def test_fit_curve_rejects_invalid_contours(self, curve_fitter): + """Prevents processing of malformed input data.""" + insufficient_contour = np.array([[[0, 0]], [[1, 1]]], dtype=np.int32) + assert curve_fitter.fit_curve(insufficient_contour) is None + + def test_ensure_closure_preserves_already_closed_contours(self, curve_fitter): + """Avoids redundant operations on properly formed shapes.""" + points = [[0, 0], [100, 0], [100, 100], [0, 100], [0, 0]] + closed_points, is_closed = curve_fitter._ensure_closure(points) + assert bool(is_closed) is True + assert len(closed_points) == len(points) + + def test_ensure_closure_force_closes_open_contours(self, curve_fitter): + """SVG requires closed paths for proper rendering.""" + original_points = [[0, 0], [100, 0], [100, 100], [0, 100]] + points = copy.deepcopy(original_points) + closed_points, is_closed = curve_fitter._ensure_closure(points) + assert bool(is_closed) is True + assert len(closed_points) == len(original_points) + 1 # Added closure point + assert closed_points[0] == closed_points[-1] # Path forms complete loop + + def test_calculate_segment_angle_computes_turning_angles(self, curve_fitter): + """Angles determine whether to use lines or curves.""" + angle = curve_fitter._calculate_segment_angle([0, 0], [0, 1], [1, 1]) + assert angle is not None + assert abs(angle - 90) < 1.0 # Right angle should be ~90 degrees + + def test_calculate_segment_angle_handles_degenerate_cases(self, curve_fitter): + """Prevents mathematical errors with invalid geometry.""" + angle = curve_fitter._calculate_segment_angle([0, 0], [0, 0], [0, 0]) + assert angle is None + + def test_should_use_curve_fitting_determines_segment_eligibility(self, curve_fitter): + """Curve fitting requires sufficient surrounding points.""" + assert curve_fitter._should_use_curve_fitting(1, 5, True) is True # Has neighbors + assert curve_fitter._should_use_curve_fitting(4, 5, False) is False # Path boundary + + def test_generate_svg_path_creates_drawing_commands(self, curve_fitter): + """Translates geometric data into SVG render instructions.""" + points = [[0, 0], [100, 0], [100, 100], [0, 100]] + path_data = curve_fitter._generate_svg_path(points, True) + assert path_data.startswith('M 0,0') + assert path_data.endswith('Z') + + def test_contour_closure_detection_handles_various_geometries(self, curve_fitter, closed_contour): + """Different contour types require different closure strategies.""" + points = [[point[0][0], point[0][1]] for point in closed_contour] + + # Naturally closed contours remain unchanged + points_copy_1 = copy.deepcopy(points) + closed_points_1, is_closed_1 = curve_fitter._ensure_closure(points_copy_1) + assert bool(is_closed_1) is True + assert len(closed_points_1) == len(points) + + # Artificially opened contours get forced closure + points_copy_2 = copy.deepcopy(points) + opened_points = points_copy_2[:-1] + original_opened_count = len(opened_points) + closed_points_2, is_closed_2 = curve_fitter._ensure_closure(opened_points) + assert bool(is_closed_2) is True + assert len(closed_points_2) == original_opened_count + 1 + + def test_different_epsilon_factors_affect_simplification_aggressiveness(self, curve_fitter, simple_contour): + """Tolerance balance between detail preservation and point reduction.""" + path_aggressive = curve_fitter.fit_curve(simple_contour, epsilon_factor=0.01) + path_conservative = curve_fitter.fit_curve(simple_contour, epsilon_factor=0.0001) + assert path_aggressive is not None + assert path_conservative is not None + + def test_path_data_contains_required_svg_elements(self, curve_fitter, simple_contour): + """SVG specification mandates specific command structure.""" + path_data = curve_fitter.fit_curve(simple_contour) + commands = path_data.split() + assert commands[0] == 'M' # Must start with move + assert commands[-1] == 'Z' # Must end with close + + def test_performance_with_large_contours(self, curve_fitter): + """Algorithm must handle realistic input sizes efficiently.""" + points = [] + for i in range(50): + angle = 2 * np.pi * i / 50 + x = 100 + 80 * np.cos(angle) + y = 100 + 80 * np.sin(angle) + points.append([[x, y]]) + points.append(points[0]) + large_contour = np.array(points, dtype=np.int32) + + path_data = curve_fitter.fit_curve(large_contour) + assert path_data is not None + + def test_square_contour_prefers_curves_over_lines(self, curve_fitter): + """Curve fitting produces smoother results than straight lines.""" + square_contour = np.array([[[0, 0]], [[100, 0]], [[100, 100]], [[0, 100]]], dtype=np.int32) + path_data = curve_fitter.fit_curve(square_contour) + assert any(cmd in path_data for cmd in ['L', 'Q']) + + def test_triangle_contour_generation(self, curve_fitter, triangle_contour): + """Triangles test corner case with minimal points.""" + path_data = curve_fitter.fit_curve(triangle_contour) + assert path_data is not None diff --git a/sketchgetdp/bitmap_tracer/tests/infrastructure/point_detection/test_point_detector.py b/sketchgetdp/bitmap_tracer/tests/infrastructure/point_detection/test_point_detector.py new file mode 100644 index 0000000..50ed88d --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/infrastructure/point_detection/test_point_detector.py @@ -0,0 +1,139 @@ +import os +import sys +import numpy as np +import cv2 +from unittest.mock import patch + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from core.entities.point import Point +from infrastructure.point_detection.point_detector import PointDetector + + +class TestPointDetector: + """Verify point detection logic for small, compact contours""" + + def setup_method(self): + # Fresh detector for each test prevents state leakage + self.detector = PointDetector(max_area=100, max_perimeter=80) + + def test_initialization(self): + """Ensure detector starts with correct size thresholds""" + detector = PointDetector(max_area=50, max_perimeter=40) + assert detector.max_area == 50 + assert detector.max_perimeter == 40 + + def test_set_config(self): + """Allow runtime adjustment of detection parameters""" + config = {'point_max_area': 75, 'point_max_perimeter': 60} + self.detector.set_config(config) + assert self.detector.max_area == 75 + assert self.detector.max_perimeter == 60 + + def test_set_config_partial(self): + """Maintain existing values when config is incomplete""" + config = {'point_max_area': 75} + self.detector.set_config(config) + # Perimeter unchanged because not in config + assert self.detector.max_perimeter == 80 + + def test_is_point_valid_contour(self): + """Accept contours that meet both size criteria""" + contour = np.array([[[10, 10]], [[12, 10]], [[12, 12]], [[10, 12]]], dtype=np.int32) + + with patch.object(cv2, 'contourArea', return_value=50), \ + patch.object(cv2, 'arcLength', return_value=30): + + assert self.detector.is_point(contour) is True + + def test_is_point_too_large_area(self): + """Reject contours that exceed area threshold""" + contour = np.array([[[0, 0]], [[20, 0]], [[20, 20]], [[0, 20]]], dtype=np.int32) + + with patch.object(cv2, 'contourArea', return_value=150), \ + patch.object(cv2, 'arcLength', return_value=30): + + assert self.detector.is_point(contour) is False + + def test_is_point_too_large_perimeter(self): + """Reject contours that exceed perimeter threshold""" + contour = np.array([[[0, 0]], [[20, 0]], [[20, 20]], [[0, 20]]], dtype=np.int32) + + with patch.object(cv2, 'contourArea', return_value=50), \ + patch.object(cv2, 'arcLength', return_value=100): + + assert self.detector.is_point(contour) is False + + def test_is_point_invalid_contour(self): + """Reject degenerate contours that cannot form shapes""" + invalid_contour = np.array([[[10, 10]], [[12, 10]]], dtype=np.int32) + assert self.detector.is_point(invalid_contour) is False + + def test_get_center_valid_contour(self): + """Calculate geometric center using moment analysis""" + contour = np.array([[[0, 0]], [[10, 0]], [[10, 10]], [[0, 10]]], dtype=np.int32) + + center = self.detector.get_center(contour) + + assert center.x == 5 # Centroid of rectangle + assert center.y == 5 + + def test_get_center_invalid_contour(self): + """Avoid center calculation for invalid contours""" + invalid_contour = np.array([[[10, 10]], [[12, 10]]], dtype=np.int32) + assert self.detector.get_center(invalid_contour) is None + + def test_get_center_zero_moment(self): + """Handle edge case where contour has no area""" + contour = np.array([[[0, 0]], [[10, 0]], [[10, 10]], [[0, 10]]], dtype=np.int32) + + with patch.object(cv2, 'moments') as mock_moments: + mock_moments.return_value = {'m00': 0, 'm10': 100, 'm01': 100} + assert self.detector.get_center(contour) is None + + def test_detect_point_valid(self): + """Complete pipeline: validate contour and return center""" + contour = np.array([[[5, 5]], [[7, 5]], [[7, 7]], [[5, 7]]], dtype=np.int32) + + with patch.object(cv2, 'contourArea', return_value=50), \ + patch.object(cv2, 'arcLength', return_value=30), \ + patch.object(cv2, 'moments') as mock_moments: + + mock_moments.return_value = {'m00': 1, 'm10': 6, 'm01': 6} + point = self.detector.detect_point(contour) + + assert point.x == 6 + assert point.y == 6 + + def test_detect_point_invalid(self): + """Return None when contour fails point criteria""" + contour = np.array([[[0, 0]], [[20, 0]], [[20, 20]], [[0, 20]]], dtype=np.int32) + + with patch.object(cv2, 'contourArea', return_value=150), \ + patch.object(cv2, 'arcLength', return_value=30): + + assert self.detector.detect_point(contour) is None + + +class TestPointDetectorIntegration: + """Verify detector works with actual OpenCV operations""" + + def setup_method(self): + self.detector = PointDetector(max_area=100, max_perimeter=80) + + def test_real_contour_detection(self): + """Ensure mathematical correctness with real contour calculations""" + image = np.zeros((100, 100), dtype=np.uint8) + cv2.circle(image, (50, 50), 3, 255, -1) + + contours, _ = cv2.findContours(image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + if contours: + contour = contours[0] + center = self.detector.detect_point(contour) + + # Small circle should be detected as point with correct center + assert center is not None + assert abs(center.x - 50) <= 2 # Allow small calculation tolerance + assert abs(center.y - 50) <= 2 \ No newline at end of file diff --git a/sketchgetdp/bitmap_tracer/tests/interfaces/test_svg_presenter.py b/sketchgetdp/bitmap_tracer/tests/interfaces/test_svg_presenter.py new file mode 100644 index 0000000..587efbf --- /dev/null +++ b/sketchgetdp/bitmap_tracer/tests/interfaces/test_svg_presenter.py @@ -0,0 +1,191 @@ +import os +import sys +import pytest +import tempfile +from unittest.mock import Mock + +# Required for importing project modules +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../../')) +sys.path.insert(0, project_root) + +from interfaces.presenters.svg_presenter import SVGPresenter +from core.entities.point import Point +from core.entities.contour import Contour +from core.entities.color import ColorCategory + + +class TestSVGPresenter: + """Validates SVG generation from geometric primitives.""" + + @pytest.fixture + def temp_output_path(self): + """Isolates tests by using temporary files that auto-clean.""" + with tempfile.NamedTemporaryFile(suffix='.svg', delete=False) as f: + temp_path = f.name + yield temp_path + if os.path.exists(temp_path): + os.unlink(temp_path) + + @pytest.fixture + def sample_points(self): + """Provides consistent test data across multiple tests.""" + return [Point(10, 20), Point(30, 40), Point(50, 60)] + + @pytest.fixture + def sample_contour(self, sample_points): + """Creates closed shape for testing path generation.""" + return Contour(points=sample_points, is_closed=True, closure_gap=0.0) + + @pytest.fixture + def basic_presenter(self, temp_output_path): + """Base presenter instance to avoid constructor duplication.""" + return SVGPresenter(temp_output_path, width=800, height=600) + + def test_initialization(self, temp_output_path): + """Ensures presenter starts in clean state.""" + presenter = SVGPresenter(temp_output_path, width=800, height=600) + + assert presenter.output_path == temp_output_path + assert presenter.width == 800 + assert presenter.height == 600 + assert presenter.elements_count['points'] == 0 + + def test_add_point_red(self, basic_presenter): + """Red points increment special counter for highlighting.""" + point = Point(100, 150) + color = Mock() + color.categorize.return_value = (ColorCategory.RED, "#FF0000") + color.to_hex.return_value = "#FF0000" + + basic_presenter.add_point(point, color) + + assert basic_presenter.elements_count['red_points'] == 1 + + def test_add_point_blue(self, basic_presenter): + """Non-red points use standard counting.""" + point = Point(200, 250) + color = Mock() + color.categorize.return_value = (ColorCategory.BLUE, "#0000FF") + color.to_hex.return_value = "#0000FF" + + basic_presenter.add_point(point, color) + + assert basic_presenter.elements_count['red_points'] == 0 + + def test_add_path_blue(self, basic_presenter): + """Color categorization drives both styling and statistics.""" + path_data = "M 10,20 L 30,40 L 50,60 Z" + color = Mock() + color.categorize.return_value = (ColorCategory.BLUE, "#0000FF") + + basic_presenter.add_path(path_data, color) + + assert basic_presenter.elements_count['blue_paths'] == 1 + + def test_add_path_green(self, basic_presenter): + """Separate counters allow color-specific analytics.""" + path_data = "M 10,20 L 30,40 L 50,60 Z" + color = Mock() + color.categorize.return_value = (ColorCategory.GREEN, "#00FF00") + + basic_presenter.add_path(path_data, color) + + assert basic_presenter.elements_count['green_paths'] == 1 + + def test_add_contour_as_path(self, basic_presenter, sample_contour): + """Contours become paths while preserving color semantics.""" + color = Mock() + color.categorize.return_value = (ColorCategory.BLUE, "#0000FF") + + basic_presenter.add_contour_as_path(sample_contour, color) + + assert basic_presenter.elements_count['blue_paths'] == 1 + + def test_add_empty_contour(self, basic_presenter): + """Empty contours avoid generating invalid paths.""" + empty_contour = Contour(points=[], is_closed=False, closure_gap=0.0) + color = Mock() + color.categorize.return_value = (ColorCategory.BLUE, "#0000FF") + + basic_presenter.add_contour_as_path(empty_contour, color) + + assert basic_presenter.elements_count['paths'] == 0 + + def test_convert_contour_to_path_data(self, basic_presenter, sample_contour): + """Path data must match point sequence and closure flag.""" + path_data = basic_presenter._convert_contour_to_path_data(sample_contour) + + assert "M 10,20" in path_data + assert "L 30,40" in path_data + assert "Z" in path_data # Closed contours get termination + + def test_save_svg(self, basic_presenter, temp_output_path): + """File output must succeed and create valid SVG structure.""" + point = Point(100, 150) + color = Mock() + color.categorize.return_value = (ColorCategory.RED, "#FF0000") + color.to_hex.return_value = "#FF0000" + basic_presenter.add_point(point, color) + + result = basic_presenter.save() + + assert result is True + assert os.path.exists(temp_output_path) + + def test_path_stroke_color_mapping(self, basic_presenter): + """Categorized colors map to consistent stroke values.""" + blue_color = Mock() + blue_color.categorize.return_value = (ColorCategory.BLUE, "#0000FF") + + stroke_color = basic_presenter._get_path_stroke_color(blue_color) + + assert stroke_color == "#0000FF" + + def test_path_counter_increment(self, basic_presenter): + """Color-specific counting supports usage analytics.""" + blue_color = Mock() + blue_color.categorize.return_value = (ColorCategory.BLUE, "#0000FF") + + basic_presenter._increment_path_counter(blue_color) + + assert basic_presenter.elements_count['blue_paths'] == 1 + + def test_build_path_commands_closed(self, basic_presenter, sample_contour): + """Closed contours must include termination command.""" + commands = basic_presenter._build_path_commands_from_contour(sample_contour) + + assert commands[-1] == "Z" + + def test_build_path_commands_open(self, basic_presenter): + """Open contours omit termination for incomplete shapes.""" + points = [Point(10, 20), Point(30, 40)] + contour = Contour(points=points, is_closed=False, closure_gap=0.0) + + commands = basic_presenter._build_path_commands_from_contour(contour) + + assert "Z" not in commands + + def test_contour_with_single_point(self, basic_presenter): + """Single points create positioning-only paths.""" + single_point_contour = Contour(points=[Point(10, 20)], is_closed=False, closure_gap=0.0) + + path_data = basic_presenter._convert_contour_to_path_data(single_point_contour) + + assert path_data == "M 10,20" # Move-to without drawing + + def test_contour_with_two_points(self, basic_presenter): + """Two points form simple line segments.""" + two_point_contour = Contour(points=[Point(10, 20), Point(30, 40)], is_closed=False, closure_gap=0.0) + + path_data = basic_presenter._convert_contour_to_path_data(two_point_contour) + + assert path_data == "M 10,20 L 30,40" + + def test_empty_contour_path_data(self, basic_presenter): + """Empty contours prevent invalid SVG generation.""" + empty_contour = Contour(points=[], is_closed=False, closure_gap=0.0) + + path_data = basic_presenter._convert_contour_to_path_data(empty_contour) + + assert path_data == "" + \ No newline at end of file diff --git a/sketchgetdp/demos/demo_geometry_construction.py b/sketchgetdp/demos/demo_geometry_construction.py deleted file mode 100644 index 04fee42..0000000 --- a/sketchgetdp/demos/demo_geometry_construction.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Demo: Basic usage of the Gmsh geometry construction functionalities - -Author: Laura D'Angelo -""" - -from sketchgetdp.geometry import gmsh_toolbox as geo - -def draw_rectangle(factory: geo.GeoFactory, x1: float, y1: float, x2: float, y2: float, - hole_tags: list[int] = []) -> dict: - """ Draws a rectangle from two given corner points, possibly with a hole. - - Parameters: - factory (GeoFactory): a Gmsh factory object - x1 (float): x-coordinate of first corner point - y1 (float): y-coordinate of first corner point - x2 (float): x-coordinate of second corner point - y2 (float): y-coordinate of second corner point - hole_tags (list[int]): list of surfaces within the rectangle which should be - treated as holes, optional - - Returns: - dict: dictionary containing surface, curve loop and line tags of the drawn rectangle - """ - # Draw the points - p1 = factory.addPoint(x1, y1, 0) - p2 = factory.addPoint(x2, y1, 0) - p3 = factory.addPoint(x2, y2, 0) - p4 = factory.addPoint(x1, y2, 0) - - # Draw the lines - l1 = factory.addLine(p1, p2) - l2 = factory.addLine(p2, p3) - l3 = factory.addLine(p3, p4) - l4 = factory.addLine(p4, p1) - - # Define curve loop and plane surface - curve_loop = factory.addCurveLoop([l1, l2, l3, l4]) - curve_loop_list = [curve_loop] + hole_tags - surface = factory.addPlaneSurface(curve_loop_list) - - # Return curve loop tag (for future holes) and surface tags for wires - return {"hole": [curve_loop], - "surface": surface, - "boundary": [l1, l2, l3, l4]} - -def run_demo() -> None: - """ Runs a demo script that draws a rectangular domain within a larger rectangular domain. - """ - model_name = "demo_rectangular_model" # Define the model name - - # Define geometrical parameters - width_in = 0.7 - height_in = 0.3 - width_out = 1 - height_out = 0.5 - - # Define physical region identifiers - domain_in = 1 - domain_out = 2 - boundary_in = 11 - boundary_out = 12 - - # Define the mesh size - h_mesh = 0.1 - - factory = geo.initialize_gmsh(model_name) # Initialize Gmsh - geo.set_characteristic_mesh_length(h_mesh) # Set the mesh size - - # Draw the inner rectangle, and define its surface and boundary as physical regions - inner_rectangle_tags = draw_rectangle(factory, -width_in/2, -height_in/2, - +width_in/2, +height_in/2) - geo.add_to_physical_group(factory, 2, inner_rectangle_tags["surface"], domain_in) - geo.add_to_physical_group(factory, 1, inner_rectangle_tags["boundary"], boundary_in) - - # Draw the outer rectangle, having the inner rectangle as hole, and define its surface and - # boundary as physical regions - outer_rectangle_tags = draw_rectangle(factory, -width_out/2, -height_out/2, - +width_out/2, +height_out/2, inner_rectangle_tags["hole"]) - geo.add_to_physical_group(factory, 2, outer_rectangle_tags["surface"], domain_out) - geo.add_to_physical_group(factory, 1, outer_rectangle_tags["boundary"], boundary_out) - - # Synchronize before meshing - factory.synchronize() - - # Mesh and save - geo.mesh_and_save(model_name, 2) - - # Open the Gmsh GUI to show the drawn and meshed geometry - geo.show_model() - - -if __name__ == "__main__": - run_demo() \ No newline at end of file diff --git a/sketchgetdp/image_processing/CurveExtractor.py b/sketchgetdp/image_processing/CurveExtractor.py deleted file mode 100644 index a07e456..0000000 --- a/sketchgetdp/image_processing/CurveExtractor.py +++ /dev/null @@ -1,71 +0,0 @@ -"""This module is used to extract the curve(s) from a given image. - -Author: Laura D'Angelo -""" - -from PIL import Image -import numpy as np -import matplotlib.pyplot as plt - - -class CurveExtractor: - """This class is used to extract the curve(s) from a given image. - - Attributes: - image_path (str): The path to the image file. - image (PIL.Image): The image object. - image_array (np.array): The image as a numpy array. - curve (np.array): The x- and y-coordinates of the extracted curve, normalized to [0, 1]². - """ - - def __init__(self, image_path: str) -> "CurveExtractor": - """The constructor for the CurveExtractor class. Reads an image file found at the path - image_path. - - Parameters: - image_path (str): The path to the image file. - """ - self.image_path = image_path - self.image = Image.open(self.image_path) - self.image_array = np.array(self.image) - self.curve = None - - def extract_curve(self) -> np.array: - """This method extracts the curve from the image by converting the image to a binary image, - then to a binary array, from which the coordinates of the curve are extracted and normalized. - - Returns: - np.array: The x- and y-coordinates of the extracted curve, normalized to [0, 1]². - """ - # Convert the image to binary image - binary_image = self.image.convert("1", dither=Image.NONE) - - # Convert the binary image to a numpy array. Black pixels are 0 and white pixels are 1, - # so we need to negate the binary array. - binary_array = np.array(binary_image) - negated_binary_array = np.logical_not(binary_array) - - # Extract the curve and normalize the coordinates to [0, 1] - indices_row, indices_col = np.where(negated_binary_array) - image_size_x = np.size(negated_binary_array, 0) - image_size_y = np.size(negated_binary_array, 1) - x_coordinates = indices_row / image_size_x - y_coordinates = indices_col / image_size_y - curve = np.array([x_coordinates, y_coordinates]).T - - self.curve = curve - return curve - - def plot_curve(self): - """This method plots the extracted normalized curve on a xy-plane. - - Returns: - None - """ - # Check if the curve has already been extracted. If not, extract the curve before plotting. - if self.curve is None: - self.extract_curve() - - # Plot the curve - plt.plot(self.curve[:, 1], self.curve[:, 0], "x") - plt.show() diff --git a/sketchgetdp/image_processing/__init__.py b/sketchgetdp/image_processing/__init__.py deleted file mode 100644 index 0f59708..0000000 --- a/sketchgetdp/image_processing/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .CurveExtractor import CurveExtractor - -__all__ = ['CurveExtractor'] \ No newline at end of file diff --git a/sketchgetdp/solver/rmvp_formulation.pro b/sketchgetdp/rmvp_formulation.pro similarity index 100% rename from sketchgetdp/solver/rmvp_formulation.pro rename to sketchgetdp/rmvp_formulation.pro diff --git a/sketchgetdp/svg_to_getdp/README.md b/sketchgetdp/svg_to_getdp/README.md new file mode 100644 index 0000000..13a8aa2 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/README.md @@ -0,0 +1,196 @@ +# SVG to GetDP + +A sophisticated electromagnetic simulation pipeline that converts SVG sketches into Gmsh meshes and solves them using GetDP, with configurable physical properties and intelligent geometry processing. + +## 🎯 Overview + +SVG to GetDP is a Python-based electromagnetic simulation pipeline that processes SVG files containing electromagnetic structures and generates simulation results through a multi-stage workflow. It features: + +- **Three operation modes**: SVG→Gmsh, SVG→Gmsh→GetDP, or Gmsh→GetDP +- **Configurable physical properties** via YAML configuration +- **Intelligent SVG parsing** with Bézier curve fitting and corner detection +- **Fixed color mapping** for physical group identification +- **Automatic wire grouping** and boundary curve meshing + +## 🏗️ Architecture + +The project follows Clean Architecture principles with clear separation of concerns: + +### Core Layers + +- **`core/`** - Enterprise business rules + - `entities/` - Domain models (Point, Color, BezierSegment, BoundaryCurve, PhysicalGroup) + - `use_cases/` - Application logic (SVG-to-Geometry conversion, Geometry-to-Gmsh conversion, GetDP simulation execution) + +- **`infrastructure/`** - Frameworks & drivers + - `factories/` - Factory classes for dependency creation + - `svg_processing/` - SVG parsing and path extraction + - `corner_detection/` - Corner detection for curve segmentation + - `bezier_fitting/` - Bézier curve fitting + - `boundary_curve_grouper/` - Wire grouping logic + - `boundary_curve_mesher/` - Boundary curve meshing + - `wire_preprocessor/` - Wire preprocessing for meshing + +- **`interfaces/`** - Interface adapters + - `controllers/` - Application flow control + - `arg_parser/` - Command line argument parsing + - `abstractions/` - Interfaces for dependency inversion + - `debug/` - Internal visualization and debug output + +## 🚀 Key Features + +### Three Operation Modes +1. **SVG → Gmsh**: Convert SVG sketches to Gmsh meshes +2. **SVG → Gmsh → GetDP**: Full pipeline from SVG to simulation results +3. **Gmsh → GetDP**: Run GetDP simulation on existing meshes + +### Intelligent SVG Processing +- **Bézier curve fitting** for accurate shape representation +- **Corner detection** for optimal curve segmentation +- **Fixed color mapping** for physical group identification +- **Automatic wire grouping** based on spatial relationships + +### Configurable Physical Properties +- Customizable coil current directions and magnitudes +- Adjustable mesh sizing parameters +- Configurable physical values for simulation + +### Visualization & Debug +- Visualization of internal geometry +- Debug output of intermediate processing steps via .txt files + +## 📁 Project Structure +``` +svg_to_getdp/ +├── core/ # Business logic +│ ├── entities/ # Domain models +│ └── use_cases/ # Application services +├── infrastructure/ # External concerns +│ ├── factories/ # Factory pattern implementations +│ ├── svg_processing/ # SVG parsing +│ ├── corner_detection/ # Corner detection +│ ├── bezier_fitting/ # Bézier fitting +│ ├── boundary_curve_grouper.py # Wire grouping +│ ├── boundary_curve_mesher.py # Boundary meshing +│ └── wire_preprocessor # Wire preprocessing +├── interfaces/ # Adapters +│ ├── arg_parser.py # Command line interface +│ ├── abstractions/ # Dependency interfaces +│ ├── debug/ # Debug tools +│ ├── mesher/ # Meshing tools +│ └── solver/ # Solving tools +├── tests/ # pytests +│ ├── core/ # Core layer tests +│ └── infrastructure/ # Infrastructure tests +├── __main__.py # Package entry point +├── config.yaml # Configuration file +└── rmvp_formulation.pro # GetDP configuration file +``` + +## ⚙️ Configuration + +Configure wire currents, mesh settings, and simulation parameters in `config.yaml`: + +```yaml +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 6 + current_sign: 1 + cluster_2: + wire_count: 6 + current_sign: -1 + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.1 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 9000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity +``` + +## 🛠️ Usage + +### Mode 1: SVG to Gmsh Mesh + +Convert an SVG file to a Gmsh mesh file: + +```bash +python -m svg_to_getdp drawing.svg --config config.yaml +``` + +### Mode 2: Full Pipeline (SVG to Simulation) + +Convert SVG file to mesh file and run GetDP simulation: + +```bash +python -m svg_to_getdp drawing.svg --run-simulation --config config.yaml +``` + +### Mode 3: Simulation Only (Existing Mesh) + +Run GetDP simulation on an existing mesh file: + +```bash +python -m svg_to_getdp --simulation-only existing_mesh.msh --config config.yaml +``` + +### Additional Options +- `--mesh-name my_mesh`: Specify output mesh name +- `--no-gui`: Run in batch mode without GUI +- `--debug`: Enable debug output + +### Examples +```bash +# Generate mesh with custom name and no GUI +python -m svg_to_getdp sketch.svg --mesh-name my_design --no-gui + +# Full pipeline with custom config +python -m svg_to_getdp circuit.svg --config custom_config.yaml --run-simulation + +# Get debug output +python -m svg_to_getdp layout.svg --debug +``` + +## 📊 Output + +The pipeline generates the following outputs depending on the mode: + +### Mode 1 (SVG → Gmsh) + +- **`.msh` file**: Gmsh mesh file with physical groups + +### Mode 2 (SVG → Gmsh → GetDP) + +- **`.msh` file**: Gmsh mesh file +- **`.pro` file**: GetDP problem definition +- **`results/` directory**: GetDP simulation results + +### Mode 3 (Gmsh Mesh → GetDP) + +- **`.pro` file**: GetDP problem definition +- **`results/` directory**: GetDP simulation results + +## 🔧 Dependencies + +- **NumPy** - Numerical computations +- **svgpathtools** - SVG parsing and path manipulation +- **PyYAML** - Configuration parsing +- **Gmsh** - Meshing engine (external dependency) +- **GetDP** - Finite element solver (external dependency) +- **matplotlib** - Visualization (optional) + +## 🎨 Use Cases + +- **Rapid prototyping**: Get first estimates of electromagnetic poperties from SVG sketches +- **Educational Tool**: Visualize electromagnetic field distributions from simple drawings +- **Design validation**: Quickly test electromagnetic structures before detailed CAD modeling +- **Mesh generation**: Create quality meshes from vector graphics for various Finite Element Analysis applications + +The SVG to GetDP pipeline excels at transforming intuitive SVG sketches into detailed electromagnetic simulations, bridging the gap between conceptual design and numerical analysis while maintaining configurability and reproducability. \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/__init__.py b/sketchgetdp/svg_to_getdp/__init__.py new file mode 100644 index 0000000..107c3ce --- /dev/null +++ b/sketchgetdp/svg_to_getdp/__init__.py @@ -0,0 +1,9 @@ +""" +SVG to Getdp Package + +A clean architecture implementation for meshing SVG designs into Gmsh geometries suitable for Getdp simulations. +Then running magnetostatic simulations using the RMVP formulation. +""" + +__version__ = "1.0.0" +__author__ = "Sarah Schleidt" \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/__main__.py b/sketchgetdp/svg_to_getdp/__main__.py new file mode 100644 index 0000000..a4abae3 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/__main__.py @@ -0,0 +1,265 @@ +""" +SVG to Getdp - Package Entry Point + +This module allows the package to be executed as: +python -m svg_to_getdp [arguments] +""" + +from pathlib import Path + +def main(): + """Main entry point for the SVG to Geometry converter""" + + # Import here to ensure path is set correctly + from .interfaces.arg_parser import ArgParser + + # Parse command line arguments + arg_parser = ArgParser() + args = arg_parser.parse_args() + + try: + # MODE 1: Simulation-only mode (existing mesh) + if args.simulation_only: + from .core.use_cases.run_getdp_simulation import RunGetDPSimulation + + # Get mesh name from the provided mesh file + mesh_path = Path(args.simulation_only) + if not mesh_path.exists(): + raise FileNotFoundError(f"Mesh file not found: {args.simulation_only}") + + # Remove .msh extension if present + mesh_name = mesh_path.stem + + print(f"\n=== Running GetDP Simulation on Existing Mesh ===") + print(f"Mesh file: {args.simulation_only}") + print(f"Config file: {args.config}") + + # Initialize and run GetDP simulation + getdp_usecase = RunGetDPSimulation() + getdp_usecase.execute( + mesh_name=mesh_name, + use_config_yaml=True, + config_yaml_path=args.config, + show_simulation_result=not args.no_gui + ) + + print(f"\n✓ GetDP simulation completed successfully!") + print(f" Results saved to: results/") + + return 0 + + # MODE 2 & 3: Normal processing (SVG → Gmsh) + from svg_to_getdp.core.use_cases.convert_svg_to_geometry import ConvertSVGToGeometry + from svg_to_getdp.core.use_cases.convert_geometry_to_gmsh import ConvertGeometryToGmsh + converter = ConvertSVGToGeometry() + + # Execute the SVG conversion use case with debug data collection + outlines, wires, colored_raw_outlines, corner_debug_data = converter.execute(args.svg_file) + + # Output conversion results + print(f"Successfully converted {len(outlines)} outlines and {len(wires)} wires:") + + for i, outline in enumerate(outlines): + print(f" Outline {i+1}: {len(outline.bezier_segments)} segments, " + f"{len(outline.corners)} corners, color: {outline.color.name.lower()}") + + for i, (point, color) in enumerate(wires): + print(f" Wire {i+1}: at ({point.x:.3f}, {point.y:.3f}), color: {color.name.lower()}") + + # Handle debug output of svg to geometry conversion + if args.debug: + try: + from svg_to_getdp.interfaces.debug.debug_coordinator import DebugCoordinator + from svg_to_getdp.interfaces.debug.svg_parser_debug_writer import SVGParserDebugWriter + from svg_to_getdp.interfaces.debug.corner_detector_debug_writer import CornerDetectorDebugWriter + from svg_to_getdp.interfaces.debug.geometry_debug_writer import GeometryDebugWriter + from svg_to_getdp.interfaces.debug.geometry_visualizer import GeometryVisualizer + + # Initialize debug coordinator first + debug_coordinator = DebugCoordinator() + debug_coordinator.set_svg_file(args.svg_file) + shared_timestamp = debug_coordinator.get_shared_timestamp() + + # Initialize debug writers with the same timestamp + svg_parser_debug_writer = SVGParserDebugWriter() + svg_parser_debug_writer._shared_timestamp = shared_timestamp + + corner_detector_debug_writer = CornerDetectorDebugWriter() + corner_detector_debug_writer._shared_timestamp = shared_timestamp + + geometry_debug_writer = GeometryDebugWriter() + geometry_debug_writer._shared_timestamp = shared_timestamp + + # Write SVG parser debug info + print(f"\n=== Writing SVG Parser Debug ===") + svg_parser_debug_writer.write_svg_parser_debug_info( + svg_file_path=args.svg_file, + colored_raw_outlines=colored_raw_outlines + ) + + # Write corner detection debug info + print(f"\n=== Writing Corner Detection Debug ===") + if corner_debug_data: + corner_detector_debug_writer.write_corner_detection_debug_info( + svg_file_path=args.svg_file, + corner_debug_data=corner_debug_data, + raw_outlines_by_color=colored_raw_outlines + ) + + # Write geometry debug info + print(f"\n=== Generating Geometry Debug ===") + summary_path = geometry_debug_writer.write_geometry_debug_info( + svg_file_path=args.svg_file, + outlines=outlines, + wires=wires + ) + + # Generate geometry plot + try: + plot_path = GeometryVisualizer.save_plot_with_coordinator( + outlines=outlines, + coordinator=debug_coordinator, + wires=wires, + colored_raw_outlines=colored_raw_outlines, + show_control_points=True, + show_corners=True, + show_raw_outlines=True + ) + + except ImportError as e: + print(f" Geometry plot unavailable: {e}") + print(" Install with: pip install matplotlib") + except Exception as e: + print(f" Geometry plot error: {e}") + import traceback + traceback.print_exc() + + except ImportError as e: + print(f"Debug output unavailable: {e}") + except Exception as e: + print(f"Debug output error: {e}") + import traceback + traceback.print_exc() + + # Determine config file path + config_file_path = Path(args.config) + if not config_file_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file_path}") + + # ALWAYS perform Gmsh meshing + print("\n=== Starting Gmsh Meshing ===") + + gmsh_converter = ConvertGeometryToGmsh() + + # Determine mesh name (output filename) + if args.mesh_name: + # User specified custom mesh name + mesh_name = args.mesh_name + else: + # Default: use SVG filename without extension + svg_path = Path(args.svg_file) + mesh_name = svg_path.stem + + # Execute Gmsh conversion + gmsh_results = gmsh_converter.execute( + outlines=outlines, + wires=wires, + config_file_path=str(config_file_path), + model_name="svg_geometry", + output_filename=mesh_name, + dimension=2, + show_gui=not args.no_gui + ) + + print(f"\n✓ Gmsh meshing completed successfully!") + print(f" Mesh saved to: {mesh_name}.msh") + + # Handle debug output of geometry to Gmsh conversion + if args.debug: + try: + from svg_to_getdp.interfaces.debug.debug_coordinator import DebugCoordinator + from svg_to_getdp.interfaces.debug.outline_grouper_debug_writer import OutlineGrouperDebugWriter + from svg_to_getdp.interfaces.debug.outline_preprocessor_debug_writer import OutlinePreprocessorDebugWriter + from svg_to_getdp.interfaces.debug.wire_preprocessor_debug_writer import WirePreprocessorDebugWriter + + # Initialize debug writers with the same timestamp + grouping_debug_writer = OutlineGrouperDebugWriter() + grouping_debug_writer.set_shared_timestamp(shared_timestamp) + + preprocessing_debug_writer = OutlinePreprocessorDebugWriter() + preprocessing_debug_writer.set_shared_timestamp(shared_timestamp) + + wire_debug_writer = WirePreprocessorDebugWriter() + wire_debug_writer.set_shared_timestamp(shared_timestamp) + + # Write outline grouping debug + if "debug_data" in gmsh_results and "outline_grouping" in gmsh_results["debug_data"]: + print(f"\n=== Writing Outline Grouping Debug ===") + + grouping_debug_data = gmsh_results["debug_data"]["outline_grouping"] + grouping_debug_file = grouping_debug_writer.write_grouping_debug_info( + svg_file_path=args.svg_file, + outlines=grouping_debug_data["outlines"], + grouping_result=grouping_debug_data["grouping_result"], + grouper_instance=grouping_debug_data["grouper_instance"] + ) + + # Write outline preprocessing debug + print(f"\n=== Writing Outline Preprocessing Debug ===") + preprocessing_debug_file = preprocessing_debug_writer.write_preprocessing_debug_info( + svg_file_path=args.svg_file, + outlines=outlines, + preprocessor_instance=gmsh_converter.outline_preprocessor, + gmsh_results=gmsh_results + ) + + # Write wire preprocessor debug + print(f"\n=== Writing Wire Preprocessor Debug ===") + wire_debug_file = wire_debug_writer.write_wire_preprocessor_debug_info( + svg_file_path=args.svg_file, + wires=wires, + config_file_path=str(config_file_path), + wire_preprocessor_instance=gmsh_converter.wire_preprocessor, + gmsh_results=gmsh_results + ) + + except ImportError as e: + print(f"Gmsh debug output unavailable: {e}") + except Exception as e: + print(f"Gmsh debug output error: {e}") + import traceback + traceback.print_exc() + + # MODE 3: Run GetDP simulation if requested + if args.run_simulation: + from .core.use_cases.run_getdp_simulation import RunGetDPSimulation + + print("\n=== Starting GetDP Simulation ===") + + # Initialize and run GetDP simulation + getdp_usecase = RunGetDPSimulation() + getdp_usecase.execute( + mesh_name=mesh_name, + use_config_yaml=True, + config_yaml_path=args.config, + show_simulation_result=not args.no_gui + ) + + print(f"\n✓ GetDP simulation completed successfully!") + print(f" Results saved to: results/") + + except FileNotFoundError as e: + print(f"Error: File not found - {e}") + print(f"Current working directory: {Path.cwd()}") + return 1 + except Exception as e: + print(f"Error processing: {e}") + import traceback + traceback.print_exc() + return 1 + + return 0 + +if __name__ == "__main__": + exit(main()) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/config.yaml b/sketchgetdp/svg_to_getdp/config.yaml new file mode 100644 index 0000000..8e73ff4 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/config.yaml @@ -0,0 +1,23 @@ +# SVG To Getdp Configuration + +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 6 + current_sign: 1 + cluster_2: + wire_count: 6 + current_sign: -1 + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.1 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 9000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/entities/__init__.py b/sketchgetdp/svg_to_getdp/core/entities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/core/entities/bezier_segment.py b/sketchgetdp/svg_to_getdp/core/entities/bezier_segment.py new file mode 100644 index 0000000..305265d --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/entities/bezier_segment.py @@ -0,0 +1,122 @@ +import math +from typing import List +from svg_to_getdp.core.entities.point import Point + + +class BezierSegment: + """ + Represents a single Bézier curve segment of degree n. + Based on the mathematical definition from the paper. + """ + + def __init__(self, control_points: List[Point], degree: int): + """ + Initialize Bézier segment with control points and degree. + + Args: + control_points: List of n+1 control points (b_0, b_1, ..., b_n) + degree: Degree n of the Bézier curve + """ + if len(control_points) != degree + 1: + raise ValueError(f"Degree {degree} requires {degree + 1} control points, " + f"but got {len(control_points)}") + + self.control_points = control_points + self.degree = degree + + def bernstein_basis(self, i: int, t: float) -> float: + """ + Compute the i-th Bernstein basis polynomial of degree n at parameter t. + + Args: + i: Index of the basis polynomial (0 <= i <= n) + t: Parameter value in [0, 1] + + Returns: + Value of B_{i,n}(t) + """ + if not 0 <= i <= self.degree: + raise ValueError(f"Index i must be between 0 and {self.degree}, got {i}") + + return math.comb(self.degree, i) * (t ** i) * ((1 - t) ** (self.degree - i)) + + def evaluate(self, t: float) -> Point: + """ + Evaluate the Bézier curve at parameter t. + + Args: + t: Parameter value in [0, 1] + + Returns: + Point on the curve C(t) + """ + if not (0 <= t <= 1): + raise ValueError(f"Parameter t must be in [0, 1], got {t}") + + result = Point(0.0, 0.0) + for i, control_point in enumerate(self.control_points): + basis_val = self.bernstein_basis(i, t) + result = result + control_point * basis_val + + return result + + def derivative(self, t: float) -> Point: + """ + Compute the derivative of the Bézier curve at parameter t. + + Args: + t: Parameter value in [0, 1] + + Returns: + Derivative vector dC/dt at parameter t + """ + if not (0 <= t <= 1): + raise ValueError(f"Parameter t must be in [0, 1], got {t}") + + if self.degree == 0: + return Point(0.0, 0.0) + + result = Point(0.0, 0.0) + for i in range(self.degree): + # Difference between consecutive control points + diff = self.control_points[i + 1] - self.control_points[i] + # Bernstein basis of degree n-1 + basis_val = math.comb(self.degree - 1, i) * (t ** i) * ((1 - t) ** (self.degree - 1 - i)) + result = result + diff * (self.degree * basis_val) + + return result + + @property + def start_point(self) -> Point: + """First control point b_0 (start of curve)""" + return self.control_points[0] + + @property + def end_point(self) -> Point: + """Last control point b_n (end of curve)""" + return self.control_points[-1] + + def get_curve_points(self, num_points: int = 100) -> List[Point]: + """ + Sample the Bézier curve at multiple parameter values. + + Args: + num_points: Number of points to sample + + Returns: + List of points along the curve + """ + if num_points < 2: + raise ValueError("Number of points must be at least 2") + + return [self.evaluate(t) for t in [i / (num_points - 1) for i in range(num_points)]] + + def __repr__(self) -> str: + return f"BezierSegment(degree={self.degree}, control_points={len(self.control_points)})" + + def __eq__(self, other: object) -> bool: + """Equality comparison for Bézier segments""" + if not isinstance(other, BezierSegment): + return False + return (self.degree == other.degree and + self.control_points == other.control_points) \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/entities/color.py b/sketchgetdp/svg_to_getdp/core/entities/color.py new file mode 100644 index 0000000..3f54848 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/entities/color.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass +from typing import ClassVar + + +@dataclass(frozen=True) +class Color: + """A simple color entity supporting red, green, blue and black colors.""" + + RED: ClassVar['Color'] = None + GREEN: ClassVar['Color'] = None + BLUE: ClassVar['Color'] = None + BLACK: ClassVar['Color'] = None + + name: str + rgb: tuple[int, int, int] + + def __post_init__(self): + """Validate color after initialization""" + if not isinstance(self.name, str): + raise TypeError("Color name must be a string") + + if self.name not in ["red", "green", "blue", "black"]: + raise ValueError("Color must be 'red', 'green', 'blue', or 'black'") + + if not isinstance(self.rgb, tuple) or len(self.rgb) != 3: + raise ValueError("RGB must be a tuple of 3 integers") + + for value in self.rgb: + if not isinstance(value, int) or value < 0 or value > 255: + raise ValueError("RGB values must be integers between 0 and 255") + + def to_hex(self) -> str: + """Convert RGB color to hexadecimal format.""" + return f"#{self.rgb[0]:02x}{self.rgb[1]:02x}{self.rgb[2]:02x}" + + def to_normalized_rgb(self) -> tuple[float, float, float]: + """Convert RGB color to normalized values (0.0 to 1.0).""" + return (self.rgb[0] / 255.0, self.rgb[1] / 255.0, self.rgb[2] / 255.0) + + +# Initialize the class variables after class definition +Color.RED = Color("red", (255, 0, 0)) +Color.GREEN = Color("green", (0, 255, 0)) +Color.BLUE = Color("blue", (0, 0, 255)) +Color.BLACK = Color("black", (0, 0, 0)) \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/entities/outline.py b/sketchgetdp/svg_to_getdp/core/entities/outline.py new file mode 100644 index 0000000..67acf78 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/entities/outline.py @@ -0,0 +1,184 @@ +from dataclasses import dataclass +from typing import List, Tuple +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.point import Point + + +@dataclass +class Outline: + """ + Represents a complete outline composed of multiple Bézier segments. + """ + + bezier_segments: List[BezierSegment] + corners: List[Point] # Coordinates identified as corners + color: Color # Used for potential assignment in simulation + is_closed: bool = True + + def __post_init__(self): + """Validate that the outline is properly constructed with tolerance.""" + if len(self.bezier_segments) < 1: + raise ValueError("Outline must have at least one Bézier segment") + + # Warn for significant gaps + for i in range(len(self.bezier_segments) - 1): + current_segment = self.bezier_segments[i] + next_segment = self.bezier_segments[i + 1] + + distance = current_segment.end_point.distance_to(next_segment.start_point) + if distance >= 1e-8: # Only warn for gaps smaller than gmsh Geometry.Tolerance default + print(f"WARNING: Small discontinuity between segments {i} and {i+1}: {distance:.6f}") + + @property + def control_points(self) -> List[Point]: + """Get all control points from all Bézier segments (including duplicates at interfaces).""" + all_points = [] + for segment in self.bezier_segments: + all_points.extend(segment.control_points) + return all_points + + @property + def unique_control_points(self) -> List[Point]: + """Get all unique control points (removing duplicates at interfaces).""" + if not self.bezier_segments: + return [] + + unique_points = [] + for i, segment in enumerate(self.bezier_segments): + if i == 0: + # For first segment, take all control points + unique_points.extend(segment.control_points) + else: + # For subsequent segments, skip first control point (duplicate of previous segment's last point) + unique_points.extend(segment.control_points[1:]) + return unique_points + + def evaluate(self, t: float) -> Point: + """ + Evaluate the outline at parameter t ∈ [0,1]. + """ + if not 0 <= t <= 1: + raise ValueError("Parameter t must be in [0,1]") + + num_segments = len(self.bezier_segments) + segment_index = int(t * num_segments) + segment_index = min(segment_index, num_segments - 1) # Handle t=1.0 + + # Map global t to local t̃ ∈ [0,1] for the specific segment + local_t = (t * num_segments) - segment_index + + return self.bezier_segments[segment_index].evaluate(local_t) + + def derivative(self, t: float) -> Point: + """ + Compute the derivative of the outline at parameter t ∈ [0,1]. + """ + if not 0 <= t <= 1: + raise ValueError("Parameter t must be in [0,1]") + + num_segments = len(self.bezier_segments) + segment_index = int(t * num_segments) + segment_index = min(segment_index, num_segments - 1) + + local_t = (t * num_segments) - segment_index + + # Apply chain rule: d𝒞/dt = N_C * dC/dṫ + derivative = self.bezier_segments[segment_index].derivative(local_t) + return Point(derivative.x * num_segments, derivative.y * num_segments) + + def is_corner_at_parameter(self, t: float, tolerance: float = 1e-6) -> bool: + """ + Check if the given parameter t corresponds to a corner point. + """ + evaluated_point = self.evaluate(t) + for corner in self.corners: + if (abs(evaluated_point.x - corner.x) < tolerance and + abs(evaluated_point.y - corner.y) < tolerance): + return True + return False + + def is_corner_at_segment_interface(self, segment_index: int) -> bool: + """ + Check if the interface between segments is a corner. + + Args: + segment_index: Index of the segment (0 to len(segments)-2) + Represents the interface between segments[segment_index] + and segments[segment_index + 1] + """ + if segment_index < 0 or segment_index >= len(self.bezier_segments) - 1: + raise ValueError("Invalid segment index for interface check") + + interface_point = self.bezier_segments[segment_index].end_point + for corner in self.corners: + if (abs(interface_point.x - corner.x) < 1e-6 and + abs(interface_point.y - corner.y) < 1e-6): + return True + return False + + def get_segment_at_parameter(self, t: float) -> Tuple[BezierSegment, float]: + """ + Get the Bézier segment and local parameter for a given global parameter t. + + Returns: + Tuple of (segment, local_t) where local_t ∈ [0,1] + """ + if not 0 <= t <= 1: + raise ValueError("Parameter t must be in [0,1]") + + num_segments = len(self.bezier_segments) + segment_index = int(t * num_segments) + segment_index = min(segment_index, num_segments - 1) + + local_t = (t * num_segments) - segment_index + + return self.bezier_segments[segment_index], local_t + + def get_outline_points(self, num_points: int = 100) -> List[Point]: + """ + Sample the entire outline at multiple parameter values. + + Args: + num_points: Number of points to sample along the entire outline + + Returns: + List of points along the complete outline + """ + if num_points < 2: + raise ValueError("Number of points must be at least 2") + + points = [] + for i in range(num_points): + t = i / (num_points - 1) + points.append(self.evaluate(t)) + return points + + def get_outline_length_approximation(self, num_samples: int = 1000) -> float: + """ + Approximate the length of the outline by sampling. + + Args: + num_samples: Number of sample points for length approximation + + Returns: + Approximate length of the outline + """ + points = self.get_outline_points(num_samples) + length = 0.0 + for i in range(len(points) - 1): + length += points[i].distance_to(points[i + 1]) + return length + + def __len__(self) -> int: + """Return the number of Bézier segments in this outline.""" + return len(self.bezier_segments) + + def __iter__(self): + """Iterate over Bézier segments.""" + return iter(self.bezier_segments) + + def __repr__(self) -> str: + return (f"Outline(segments={len(self.bezier_segments)}, " + f"corners={len(self.corners)}, color={self.color.name}, " + f"closed={self.is_closed})") \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/entities/physical_group.py b/sketchgetdp/svg_to_getdp/core/entities/physical_group.py new file mode 100644 index 0000000..e43ebe1 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/entities/physical_group.py @@ -0,0 +1,120 @@ +from dataclasses import dataclass +from typing import Optional +from svg_to_getdp.core.entities.color import Color + + +@dataclass(frozen=True) +class PhysicalGroup: + """A physical group entity representing different domains and boundaries in the system.""" + + name: str + description: str + group_type: str # "domain" or "boundary" + value: int # Numeric identifier for the physical group + color: Optional[Color] = None + current_sign: Optional[int] = None # 1 for positive, -1 for negative, None for non-coil domains + + def __post_init__(self): + """Validate physical group after initialization""" + if not isinstance(self.name, str): + raise TypeError("Physical group name must be a string") + + if not isinstance(self.description, str): + raise TypeError("Physical group description must be a string") + + if self.group_type not in ["domain", "boundary"]: + raise ValueError("Group type must be either 'domain' or 'boundary'") + + if not isinstance(self.value, int): + raise TypeError("Value must be an integer") + + if self.color is not None and not isinstance(self.color, Color): + raise TypeError("Color must be an instance of Color class or None") + + if self.current_sign not in [None, 1, -1]: + raise ValueError("Current sign must be None, 1 (positive), or -1 (negative)") + + # Validate coil-specific constraints + # Only apply coil rules if it's a domain AND has "coil" in name (case-insensitive) + if self.group_type == "domain" and "coil" in self.name.lower(): + if self.current_sign is None: + raise ValueError("Coil domains must have a current sign (1 or -1)") + if self.color != Color.RED: + raise ValueError("Coil domains must be red") + else: + if self.current_sign is not None: + raise ValueError("Only coil domains can have a current sign") + + def has_color(self) -> bool: + """Check if this physical group has an associated color.""" + return self.color is not None + + def is_coil(self) -> bool: + """Check if this is a coil domain.""" + return self.group_type == "domain" and "coil" in self.name.lower() + + def is_boundary(self) -> bool: + """Check if this is a boundary.""" + return self.group_type == "boundary" + + def is_domain(self) -> bool: + """Check if this is a domain.""" + return self.group_type == "domain" + + +# Module-level constants +DOMAIN_VI_IRON = PhysicalGroup( + name="domain_Vi_iron", + description="Iron domain in Vi region", + group_type="domain", + value=2, + color=Color.BLUE +) + +DOMAIN_VI_AIR = PhysicalGroup( + name="domain_Vi_air", + description="Air domain in Vi region", + group_type="domain", + value=3, + color=Color.GREEN +) + +DOMAIN_VA = PhysicalGroup( + name="domain_Va", + description="Va domain", + group_type="domain", + value=1, + color=Color.BLACK +) + +DOMAIN_COIL_POSITIVE = PhysicalGroup( + name="domain_coil_positive", + description="Coil domain with positive current", + group_type="domain", + value=101, + color=Color.RED, + current_sign=1 +) + +DOMAIN_COIL_NEGATIVE = PhysicalGroup( + name="domain_coil_negative", + description="Coil domain with negative current", + group_type="domain", + value=102, + color=Color.RED, + current_sign=-1 +) + +BOUNDARY_GAMMA = PhysicalGroup( + name="boundary_gamma", + description="Interface boundary between Vi and Va regions", + group_type="boundary", + value=11 +) + +BOUNDARY_OUT = PhysicalGroup( + name="boundary_out", + description="Outermost boundary", + group_type="boundary", + value=12 +) \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/entities/point.py b/sketchgetdp/svg_to_getdp/core/entities/point.py new file mode 100644 index 0000000..81b7868 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/entities/point.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +import math + + +@dataclass(frozen=True) +class Point: + """A 0D point entity representing a position in 2D space.""" + + x: float = 0.0 + y: float = 0.0 + + def __post_init__(self): + """Validate coordinates after initialization""" + if not isinstance(self.x, (int, float)) or not isinstance(self.y, (int, float)): + raise TypeError("Coordinates must be numeric") + + if math.isnan(self.x) or math.isnan(self.y): + raise ValueError("Coordinates cannot be NaN") + + def distance_to(self, other: 'Point') -> float: + """Calculate Euclidean distance to another point.""" + return math.sqrt((self.x - other.x)**2 + (self.y - other.y)**2) + + def distance_to_origin(self) -> float: + """Calculate distance from origin (0,0).""" + return math.sqrt(self.x**2 + self.y**2) + + def __add__(self, other: 'Point') -> 'Point': + """Vector addition""" + return Point(self.x + other.x, self.y + other.y) + + def __sub__(self, other: 'Point') -> 'Point': + """Vector subtraction""" + return Point(self.x - other.x, self.y - other.y) + + def __mul__(self, scalar: float) -> 'Point': + """Scalar multiplication""" + return Point(self.x * scalar, self.y * scalar) + + def __rmul__(self, scalar: float) -> 'Point': + """Reverse scalar multiplication""" + return self.__mul__(scalar) + + def norm(self) -> float: + """Euclidean norm (magnitude) of the vector""" + return math.sqrt(self.x**2 + self.y**2) + + def __truediv__(self, scalar: float) -> 'Point': + """Scalar division""" + if scalar == 0: + raise ValueError("Division by zero") + return Point(self.x / scalar, self.y / scalar) + + def __eq__(self, other: object) -> bool: + """Equality comparison""" + if not isinstance(other, Point): + return False + return math.isclose(self.x, other.x) and math.isclose(self.y, other.y) + + def __repr__(self) -> str: + """Better representation for debugging""" + return f"Point({self.x}, {self.y})" + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/entities/raw_outline.py b/sketchgetdp/svg_to_getdp/core/entities/raw_outline.py new file mode 100644 index 0000000..20dd060 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/entities/raw_outline.py @@ -0,0 +1,28 @@ +""" +RawOutline entity - temporary data structure for SVG parsing results. +""" + +from dataclasses import dataclass +from typing import List +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + + +@dataclass +class RawOutline: + """ + Temporary data structure for raw outline data extracted from SVG. + This will be converted to Outline later after Bezier fitting. + """ + points: List[Point] + color: Color + is_closed: bool = True + + def __post_init__(self): + """Validate the raw outline data.""" + # Allow single points for red dots, but require >=3 points for other colors + if self.color != Color.RED and len(self.points) < 3: + raise ValueError(f"Raw outline must have at least 3 points for color {self.color.name}, got {len(self.points)}") + elif self.color == Color.RED and len(self.points) < 1: + raise ValueError("Red dot must have at least 1 point") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/use_cases/__init__.py b/sketchgetdp/svg_to_getdp/core/use_cases/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/core/use_cases/convert_geometry_to_gmsh.py b/sketchgetdp/svg_to_getdp/core/use_cases/convert_geometry_to_gmsh.py new file mode 100644 index 0000000..a990eef --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/use_cases/convert_geometry_to_gmsh.py @@ -0,0 +1,177 @@ +""" +Usecase to convert geometry to Gmsh format. +Integrates outlines, wires, and configuration to create a complete Gmsh model. +""" + +import yaml +from typing import List, Tuple, Dict, Any +from pathlib import Path + +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + +from svg_to_getdp.interfaces.mesher.gmsh_toolbox import ( + initialize_gmsh, + set_characteristic_mesh_length, + mesh_and_save, + show_model, + finalize_gmsh +) + +class ConvertGeometryToGmsh: + """ + Use case for converting geometry to Gmsh format. + """ + + def __init__(self): + """ + Initialize the use case using factories internally. + """ + from svg_to_getdp.infrastructure.factories.outline_grouper_factory import OutlineGrouperFactory + from svg_to_getdp.infrastructure.factories.outline_preprocessor_factory import OutlinePreprocessorFactory + from svg_to_getdp.infrastructure.factories.wire_preprocessor_factory import WirePreprocessorFactory + + self.outline_grouper = OutlineGrouperFactory.create_default() + self.outline_preprocessor = OutlinePreprocessorFactory.create_default() + self.wire_preprocessor = WirePreprocessorFactory.create_default() + + def execute( + self, + outlines: List[Outline], + wires: List[Tuple[Point, Color]], + config_file_path: str, + model_name: str = "geometry_model", + output_filename: str = "geometry_mesh", + dimension: int = 2, + show_gui: bool = True + ) -> dict: + """ + Main use case to convert geometry to Gmsh format. + + Steps: + 1. Load configuration and extract mesh size + 2. Initialize Gmsh + 3. Set the mesh size from config + 4. Prepare wires + 5. Group outlines with containment hierarchy + 6. Preprocess outlines + 7. Synchronize before meshing + 8. Mesh and save + 9. Optionally show Gmsh GUI + + Args: + outlines: List of Outline objects representing domain boundaries + wires: List of (Point, Color) tuples representing wires + config_file_path: Path to YAML configuration file for wire currents and mesh settings + model_name: Name for the Gmsh model (default: "geometry_model") + output_filename: Base filename for output mesh (without extension) + dimension: Dimension of mesh (default: 2 for 2D) + show_gui: Whether to open Gmsh GUI after meshing (default: True) + + Returns: + Dictionary containing results from all processing steps including debug data + + Raises: + ValueError: If input parameters are invalid + FileNotFoundError: If config file doesn't exist + KeyError: If required configuration is missing + """ + # Input validation + if not isinstance(outlines, list): + raise ValueError("outlines must be a list") + + if not isinstance(wires, list): + raise ValueError("wires must be a list") + + config_path = Path(config_file_path) + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file not found: {config_file_path}") + + if not outlines: + print("Warning: No outlines provided") + + # Step 1: Load configuration + print(f"Loading configuration from: {config_file_path}") + with open(config_path, 'r') as f: + config = yaml.safe_load(f) + + # Extract mesh size from config (default to 0.1 if not specified) + mesh_size = config.get('mesh_size', 0.1) + print(f"Using mesh size from config: {mesh_size}") + + # Results dictionary to store outputs from each step + results: Dict[str, Any] = { + "model_name": model_name, + "output_filename": output_filename, + "mesh_size": mesh_size, + "dimension": dimension, + "config_file": config_file_path, + "debug_data": {} + } + + try: + # Step 2: Initialize Gmsh + print(f"Initializing Gmsh with model name: {model_name}") + factory = initialize_gmsh(model_name) + results["factory_initialized"] = True + + # Step 3: Set mesh size from config + print(f"Setting characteristic mesh length factor to: {mesh_size}") + set_characteristic_mesh_length(mesh_size) + results["mesh_size_set"] = True + + # Step 4: Prepare wires + print(f"Preparing {len(wires)} wires...") + wire_results = self.wire_preprocessor.prepare_wires( + factory, + config_file_path, + wires + ) + results["wire_results"] = wire_results + + # Step 5: Group outlines with containment hierarchy + print(f"Grouping {len(outlines)} outlines...") + grouping_result = self.outline_grouper.group_outlines(outlines) + results["grouping_result"] = grouping_result + + # Store debug data + results["debug_data"]["outline_grouping"] = { + "outlines": outlines, + "grouping_result": grouping_result, + "grouper_instance": self.outline_grouper + } + + # Step 6: Preprocess outlines + print("Preprocessing outlines...") + preprocessing_result = self.outline_preprocessor.preprocess_outlines(factory, outlines, grouping_result) + results["preprocessing_result"] = preprocessing_result + + # Step 7: Synchronize before meshing + factory.synchronize() + print("Geometry synchronized in Gmsh") + results["geometry_synchronized"] = True + + # Step 8: Mesh and save + print(f"Generating {dimension}D mesh...") + mesh_and_save(output_filename, dimension) + results["mesh_generated"] = True + print(f"Mesh saved to: {output_filename}.msh") + + # Step 9: Show Gmsh GUI if requested + if show_gui: + print("Opening Gmsh GUI...") + show_model() + results["gui_shown"] = True + + return results + + except Exception as e: + print(f"Error during geometry conversion: {e}") + raise + + finally: + # Clean up Gmsh resources + finalize_gmsh() + print("Gmsh finalized") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/use_cases/convert_svg_to_geometry.py b/sketchgetdp/svg_to_getdp/core/use_cases/convert_svg_to_geometry.py new file mode 100644 index 0000000..64bdfe5 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/use_cases/convert_svg_to_geometry.py @@ -0,0 +1,118 @@ +""" +Core use case: Convert SVG to Geometry +""" + +from typing import List, Tuple +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + +class ConvertSVGToGeometry: + """ + Use case for converting SVG sketches to outlines with Bézier representations. + """ + + def __init__(self): + """ + Initialize the converter using factories internally. + """ + from svg_to_getdp.infrastructure.factories.svg_parser_factory import SvgParserFactory + from svg_to_getdp.infrastructure.factories.corner_detector_factory import CornerDetectorFactory + from svg_to_getdp.infrastructure.factories.bezier_fitter_factory import BezierFitterFactory + + self.svg_parser = SvgParserFactory.create_default() + self.corner_detector = CornerDetectorFactory.create_default() + self.bezier_fitter = BezierFitterFactory.create_default() + + def execute(self, svg_file_path: str) -> Tuple[List[Outline], List[Tuple[Point, Color]], dict, dict]: + """ + Convert SVG file to outlines with Bézier representations and wires. + Returns: (outlines, wires, colored_raw_outlines, corner_debug_data) + """ + # Step 1: Parse SVG to get raw outlines grouped by color + colored_raw_outlines = self.svg_parser.extract_raw_outlines_by_color(svg_file_path) + + outlines = [] + wires = [] + corner_debug_data = {} + + # Process each color group + for color, raw_outlines in colored_raw_outlines.items(): + for outline_idx, raw_outline in enumerate(raw_outlines): + if color == Color.RED: + # For red elements: treat as wires + if len(raw_outline.points) == 1: + wires.append((raw_outline.points[0], color)) + else: + center = raw_outline.points[0] + wires.append((center, color)) + else: + # For green/blue elements: process as outlines + + # Step 1: Ensure proper closure for closed outlines + points = self._ensure_proper_closure(raw_outline.points, raw_outline.is_closed) + + # Step 2: Detect corners in the outline with debug data + corner_indices, raw_outline_debug = self.corner_detector.detect_corners(points) + + # Store debug data with unique key + key = f"{color.name}_raw_outline_{outline_idx}" + corner_debug_data[key] = { + 'color': color.name, + 'outline_index': outline_idx, + 'points_count': len(points), + 'is_closed': raw_outline.is_closed, + 'corner_indices': corner_indices, + 'debug': raw_outline_debug + } + + # Step 3: Fit piecewise Bézier curves + outline = self.bezier_fitter.fit_outline( + points=points, + corner_indices=corner_indices, + color=color, + is_closed=raw_outline.is_closed + ) + + # Step 4: Ensure closure if needed + if outline.is_closed and outline.bezier_segments: + self._force_outline_closure(outline) + + outlines.append(outline) + + return outlines, wires, colored_raw_outlines, corner_debug_data + + def _ensure_proper_closure(self, points: List[Point], is_closed: bool) -> List[Point]: + """ + Ensure that closed outlines properly connect first and last points. + """ + if not is_closed or len(points) < 3: + return points + + # Check if first and last points are already close + first_point = points[0] + last_point = points[-1] + closure_distance = first_point.distance_to(last_point) + + if closure_distance > 1e-6: # If not properly closed + # Add first point at the end to close the outline + return points + [first_point] + else: + return points + + def _force_outline_closure(self, outline: Outline): + """ + Force an outline to be properly closed by ensuring first and last control points match. + """ + if not outline.bezier_segments: + return + + first_segment = outline.bezier_segments[0] + last_segment = outline.bezier_segments[-1] + + if (first_segment.control_points and last_segment.control_points and + first_segment.control_points[0] != last_segment.control_points[-1]): + + # Make last control point of last segment match first control point of first segment + last_segment.control_points[-1] = first_segment.control_points[0] + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/core/use_cases/run_getdp_simulation.py b/sketchgetdp/svg_to_getdp/core/use_cases/run_getdp_simulation.py new file mode 100644 index 0000000..8b6e523 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/core/use_cases/run_getdp_simulation.py @@ -0,0 +1,182 @@ +""" +Use case for running GetDP magnetostatic simulations. +""" + +import yaml +from typing import Optional +import numpy as np +from svg_to_getdp.interfaces.solver.getdp_toolbox import ( + print_data_to_pro, + run_magnetostatic_simulation, + physical_identifiers +) + +# Add Gmsh import +try: + import gmsh + GMSH_AVAILABLE = True +except ImportError: + GMSH_AVAILABLE = False + gmsh = None + + +class RunGetDPSimulation: + """ + Use case for running GetDP magnetostatic simulations. + + This class encapsulates the business logic for configuring and running + GetDP simulations following the specified steps. + """ + + def __init__(self): + """Initialize the use case with default values.""" + self.physical_values = None + + def execute( + self, + mesh_name: str, + use_config_yaml: bool = False, + config_yaml_path: Optional[str] = None, + show_simulation_result: bool = True + ) -> None: + """ + Execute the GetDP simulation use case. + + Parameters: + ----------- + mesh_name : str + Name of the mesh model (without .msh extension) + use_config_yaml : bool + Whether to use config.yaml to update rmvp_formulation.pro file + config_yaml_path : Optional[str] + Path to the config.yaml file (optional, defaults to 'config.yaml') + show_simulation_result : bool + Whether to show the simulation result in Gmsh GUI + """ + # Step 0: Initialize Gmsh if needed + self._initialize_gmsh() + + # Step 1: Handle mesh name + if not mesh_name.endswith('.msh'): + mesh_name = f"{mesh_name}.msh" + + # Step 2: Handle config.yaml if requested + config_data = {} + if use_config_yaml: + config_path = config_yaml_path if config_yaml_path else 'config.yaml' + config_data = self._load_config_yaml(config_path) + + # Step 3: Define physical values + self._define_physical_values(config_data if use_config_yaml else None) + + # Step 4: Set physical identifiers + phys_ids = physical_identifiers() + print_data_to_pro("physical_identifiers", phys_ids) + + # Step 5: Set physical values + print_data_to_pro("physical_values", self.physical_values) + + # Step 6: Run simulation (always uses rmvp_formulation.pro) + self._run_simulation(mesh_name, show_simulation_result) + + # Step 7: Finalize Gmsh if we initialized it + if hasattr(self, '_gmsh_initialized_by_us') and self._gmsh_initialized_by_us: + if GMSH_AVAILABLE and gmsh.isInitialized(): + gmsh.finalize() + + def _initialize_gmsh(self) -> None: + """ + Initialize Gmsh if it's not already initialized. + """ + if not GMSH_AVAILABLE: + raise ImportError("Gmsh is not available. Please install python-gmsh package.") + + if not gmsh.isInitialized(): + gmsh.initialize() + self._gmsh_initialized_by_us = True + print("Gmsh initialized for GetDP simulation") + else: + self._gmsh_initialized_by_us = False + + def _load_config_yaml(self, config_path: str) -> dict: + """ + Load configuration from YAML file. + + Parameters: + ----------- + config_path : str + Path to the config.yaml file + + Returns: + -------- + dict: Configuration data + """ + try: + with open(config_path, 'r') as file: + return yaml.safe_load(file) + except FileNotFoundError: + print(f"Warning: Config file {config_path} not found. Using default values.") + return {} + except yaml.YAMLError as e: + print(f"Error parsing YAML file: {e}") + return {} + + def _define_physical_values(self, config_data: Optional[dict] = None) -> None: + """ + Define physical values for the simulation. + + Parameters: + ----------- + config_data : Optional[dict] + Configuration data from YAML file + """ + mu0 = 4e-7 * np.pi # Vacuum permeability + + self.physical_values = { + "Isource": 1, # Default current source [A] + "mu0": mu0, # Vacuum permeability [H/m] + "nu0": 1/mu0, # Vacuum reluctivity [m/H] + "nu_iron_linear": 1/(4000 * mu0) # Iron reluctivity (relative permeability = 4000) [m/H] + } + + # Update with config data if provided + if config_data and 'physical_values' in config_data: + config_phys_vals = config_data['physical_values'] + for key, value in config_phys_vals.items(): + # Handle special case where expressions contain pi + if isinstance(value, str): + # Replace pi with numpy pi for evaluation + if 'pi' in value.lower(): + value = value.replace('pi', str(np.pi)) + value = value.replace('Pi', str(np.pi)) + value = value.replace('PI', str(np.pi)) + try: + # Safe evaluation of mathematical expressions + value = eval(value, {"__builtins__": {}}, {"pi": np.pi}) + except: + # If evaluation fails, keep as string + pass + self.physical_values[key] = value + + def _run_simulation(self, mesh_name: str, show_result: bool) -> None: + """ + Run the magnetostatic simulation. + + Parameters: + ----------- + mesh_name : str + Name of the mesh file + show_result : bool + Whether to show the simulation result + """ + run_magnetostatic_simulation(mesh_name, show_simulation_result=show_result) + + def get_physical_values(self) -> dict: + """ + Get the current physical values. + + Returns: + -------- + dict: Current physical values + """ + return self.physical_values.copy() if self.physical_values else {} \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/__init__.py b/sketchgetdp/svg_to_getdp/infrastructure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/__init__.py b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/bezier_calculator.py b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/bezier_calculator.py new file mode 100644 index 0000000..695e389 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/bezier_calculator.py @@ -0,0 +1,286 @@ +import math +from typing import List, Tuple, Optional +from svg_to_getdp.core.entities.point import Point + + +class BezierCalculator: + """ + Mathematical calculator for Bézier curve fitting and geometric calculations. + """ + + def remove_consecutive_duplicate_points(self, points: List[Point]) -> List[Point]: + """Remove consecutive duplicate points from the input list.""" + if not points: + return [] + + unique_points = [points[0]] + for i in range(1, len(points)): + if points[i] != points[i-1]: + unique_points.append(points[i]) + + return unique_points + + def calculate_optimal_segment_count(self, points: List[Point], + corner_indices: List[int], + minimum_points_per_segment: int = 15) -> int: + """Calculate appropriate number of segments based on corners and point density.""" + point_count = len(points) + + if corner_indices: + base_segments = max(len(corner_indices), 100) + else: + base_segments = max(200, point_count // 10) + + minimum_segments = 100 + maximum_segments = min(200, max(1, point_count // 10)) + + return min(maximum_segments, max(minimum_segments, base_segments)) + + def calculate_segment_interfaces(self, points: List[Point], corner_indices: List[int], + target_segment_count: int, is_closed: bool) -> List[int]: + """Calculate bezier segment interfaces prioritizing corners while ensuring sufficient segmentation.""" + point_count = len(points) + + if point_count < 2: + return [0] + + # Start with corners as primary interfaces + interfaces = sorted(set(corner_indices)) + + # Always include the start point + if 0 not in interfaces: + interfaces.insert(0, 0) + + if is_closed: + if not interfaces: + interfaces = [0] + + current_segment_count = len(interfaces) + + if current_segment_count < target_segment_count: + additional_interfaces_needed = target_segment_count - current_segment_count + new_interfaces = set(interfaces) + + for i in range(1, additional_interfaces_needed + 1): + new_interface_index = int((i * point_count) / (additional_interfaces_needed + 1)) + # Avoid interfaces too close to existing ones + is_too_close = any(abs(new_interface_index - existing) < 5 for existing in new_interfaces) + if not is_too_close and new_interface_index < point_count: + new_interfaces.add(new_interface_index) + + interfaces = sorted(new_interfaces) + + else: + # For open outlines, include the end point + if (point_count - 1) not in interfaces: + interfaces.append(point_count - 1) + + current_segment_count = len(interfaces) - 1 + + if current_segment_count < target_segment_count: + additional_interfaces_needed = target_segment_count - current_segment_count + + # Find segments with largest gaps + segment_gaps = [] + for i in range(len(interfaces) - 1): + gap_size = interfaces[i + 1] - interfaces[i] + segment_gaps.append((gap_size, i)) + + segment_gaps.sort(reverse=True) + + # Split largest gaps + for gap_size, gap_index in segment_gaps[:additional_interfaces_needed]: + if gap_size > 20: # Only split substantial gaps + midpoint = interfaces[gap_index] + gap_size // 2 + interfaces.insert(gap_index + 1, midpoint) + + # Clean up interfaces + interfaces = [index for index in interfaces if 0 <= index < point_count] + interfaces = sorted(set(interfaces)) + + # Ensure minimum of 2 interfaces for segment creation + if len(interfaces) < 2: + if point_count > 1: + midpoint = point_count // 2 + interfaces = [0, midpoint, point_count - 1] if not is_closed else [0, midpoint] + else: + interfaces = [0] + + return interfaces + + def compute_bernstein_basis(self, basis_index: int, degree: int, parameter: float) -> float: + """Compute Bernstein basis polynomial value.""" + return math.comb(degree, basis_index) * (parameter ** basis_index) * ((1 - parameter) ** (degree - basis_index)) + + def are_points_geometrically_straight(self, points: List[Point], + relative_tolerance: float = 0.005, + absolute_tolerance: float = 1e-6) -> Tuple[bool, float]: + """ + Determine if points form a straight line within specified tolerances. + + Returns both boolean result and confidence score (0-1). + """ + if len(points) < 3: + return True, 1.0 + + max_deviation = self._calculate_max_deviation_from_line(points) + segment_length = points[0].distance_to(points[-1]) + + if segment_length == 0: + return True, 1.0 + + normalized_deviation = max_deviation / segment_length + angle_variance = self._calculate_angle_variance(points) + passes_simplified_check = self.are_points_approximately_linear(points, relative_tolerance) + + meets_all_criteria = ( + normalized_deviation < relative_tolerance and + max_deviation < absolute_tolerance and + angle_variance < 0.01 and + passes_simplified_check + ) + + confidence = self._calculate_straightness_confidence( + normalized_deviation, max_deviation, angle_variance, + relative_tolerance, absolute_tolerance + ) + + return meets_all_criteria, confidence + + def are_points_approximately_linear(self, points: List[Point], max_deviation_ratio: float = 0.01) -> bool: + """Check if points form an approximately straight line.""" + if len(points) < 3: + return True + + start_point = points[0] + end_point = points[-1] + + max_absolute_deviation = 0 + for point in points: + deviation = self.calculate_distance_from_line(start_point, end_point, point) + max_absolute_deviation = max(max_absolute_deviation, deviation) + + segment_length = start_point.distance_to(end_point) + if segment_length == 0: + return True + + normalized_deviation = max_absolute_deviation / segment_length + return normalized_deviation < max_deviation_ratio + + def find_point_with_max_deviation(self, points: List[Point], line_start: Point, line_end: Point) -> Point: + """Find the point that deviates most from the line between start and end points.""" + max_deviation = -1 + most_deviant_point = points[len(points) // 2] + + for point in points: + deviation = self.calculate_distance_from_line(line_start, line_end, point) + if deviation > max_deviation: + max_deviation = deviation + most_deviant_point = point + + return most_deviant_point + + def project_point_to_line(self, line_start: Point, line_end: Point, point: Point) -> Point: + """Project a point onto the line defined by start and end points.""" + line_vector = Point(line_end.x - line_start.x, line_end.y - line_start.y) + point_vector = Point(point.x - line_start.x, point.y - line_start.y) + + line_length_squared = line_vector.x ** 2 + line_vector.y ** 2 + if line_length_squared == 0: + return line_start + + projection_parameter = (point_vector.x * line_vector.x + point_vector.y * line_vector.y) / line_length_squared + projection_parameter = max(0, min(1, projection_parameter)) # Clamp to segment + + return Point( + line_start.x + projection_parameter * line_vector.x, + line_start.y + projection_parameter * line_vector.y + ) + + def calculate_distance_from_line(self, line_point1: Point, line_point2: Point, test_point: Point) -> float: + """Calculate perpendicular distance from a point to a line.""" + if line_point1 == line_point2: + return line_point1.distance_to(test_point) + + # Using cross product formula: |(p2 - p1) × (p - p1)| / |p2 - p1| + cross_product = abs( + (line_point2.x - line_point1.x) * (test_point.y - line_point1.y) - + (line_point2.y - line_point1.y) * (test_point.x - line_point1.x) + ) + line_length = line_point1.distance_to(line_point2) + + return cross_product / line_length if line_length > 0 else 0 + + def _calculate_max_deviation_from_line(self, points: List[Point]) -> float: + """Find maximum perpendicular distance of any point from the line between endpoints.""" + start_point, end_point = points[0], points[-1] + max_deviation = 0.0 + + for point in points: + deviation = self.calculate_distance_from_line(start_point, end_point, point) + max_deviation = max(max_deviation, deviation) + + return max_deviation + + def _calculate_straightness_confidence(self, normalized_deviation: float, + max_deviation: float, angle_variance: float, + relative_tolerance: float, + absolute_tolerance: float) -> float: + """ + Calculate confidence score (0-1) for straightness assessment. + """ + deviation_score = 1.0 - normalized_deviation / max(relative_tolerance, 1e-10) + absolute_score = 1.0 - max_deviation / max(absolute_tolerance, 1e-10) + angle_score = 1.0 - angle_variance / 0.01 + + confidence = ( + deviation_score * 0.4 + + absolute_score * 0.3 + + angle_score * 0.3 + ) + + return min(1.0, confidence) + + def _calculate_angle_variance(self, points: List[Point]) -> float: + """Calculate variance of angles between consecutive line segments.""" + if len(points) < 3: + return 0.0 + + angles = self._collect_segment_angles(points) + + if not angles: + return 0.0 + + mean_angle = sum(angles) / len(angles) + variance = sum((angle - mean_angle) ** 2 for angle in angles) / len(angles) + return variance + + def _collect_segment_angles(self, points: List[Point]) -> List[float]: + """Collect angles between consecutive segments formed by three adjacent points.""" + angles = [] + + for i in range(1, len(points) - 1): + angle = self._calculate_angle_at_point(points[i-1], points[i], points[i+1]) + if angle is not None: + angles.append(angle) + + return angles + + def _calculate_angle_at_point(self, previous_point: Point, current_point: Point, + next_point: Point) -> Optional[float]: + """Calculate angle formed by three consecutive points at the middle point.""" + vector_to_previous = Point(current_point.x - previous_point.x, + current_point.y - previous_point.y) + vector_to_next = Point(next_point.x - current_point.x, + next_point.y - current_point.y) + + dot_product = vector_to_previous.x * vector_to_next.x + vector_to_previous.y * vector_to_next.y + previous_length = math.sqrt(vector_to_previous.x**2 + vector_to_previous.y**2) + next_length = math.sqrt(vector_to_next.x**2 + vector_to_next.y**2) + + if previous_length < 1e-10 or next_length < 1e-10: + return None + + cosine = max(-1.0, min(1.0, dot_product / (previous_length * next_length))) + return math.acos(cosine) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/bezier_fitter.py b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/bezier_fitter.py new file mode 100644 index 0000000..7313c24 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/bezier_fitter.py @@ -0,0 +1,83 @@ +from typing import List +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.interfaces.abstractions.bezier_fitter_interface import BezierFitterInterface + +from svg_to_getdp.infrastructure.bezier_fitting.segment_classifier import SegmentClassifier +from svg_to_getdp.infrastructure.bezier_fitting.segment_fitter import SegmentFitter +from svg_to_getdp.infrastructure.bezier_fitting.continuity_enforcer import ContinuityEnforcer +from svg_to_getdp.infrastructure.bezier_fitting.bezier_calculator import BezierCalculator + + +class BezierFitter(BezierFitterInterface): + """ + Main orchestrator for fitting piecewise Bézier curves to outline points. + Coordinates the workflow between specialized components. + """ + + def __init__(self, bezier_degree: int = 2, minimum_points_per_segment: int = 15): + self.bezier_degree = bezier_degree + self.minimum_points_per_segment = minimum_points_per_segment + + # Initialize components + self.segment_classifier = SegmentClassifier() + self.segment_fitter = SegmentFitter(bezier_degree) + self.continuity_enforcer = ContinuityEnforcer(bezier_degree) + self.bezier_calculator = BezierCalculator() + + def fit_outline(self, points: List[Point], corner_indices: List[int], + color, is_closed: bool = True) -> Outline: + """ + Fit piecewise Bézier curves to outline points, treating corners as segment interfaces. + + Args: + points: Raw outline points to fit curves to + corner_indices: Indices of corner points that should be segment interfaces + color: Color for the resulting outline + is_closed: Whether the outline forms a closed loop + + Returns: + Outline with fitted Bézier segments and corner information + + Raises: + ValueError: When insufficient points are provided + """ + # Step 1: Clean input points + cleaned_points = self.bezier_calculator.remove_consecutive_duplicate_points(points) + if len(cleaned_points) < 3: + raise ValueError(f"Need at least 3 non-duplicate points for outline, got {len(cleaned_points)}") + + # Step 2: Calculate optimal segment count + optimal_segment_count = self.bezier_calculator.calculate_optimal_segment_count( + cleaned_points, corner_indices, self.minimum_points_per_segment + ) + + # Step 3: Fit piecewise Bézier curves + bezier_segments = self.segment_fitter.fit_piecewise_bezier_curves( + cleaned_points, + corner_indices, + optimal_segment_count, + is_closed, + self.segment_classifier, + self.bezier_calculator + ) + + # Step 4: Enforce continuity + if bezier_segments: + segment_interfaces = self.bezier_calculator.calculate_segment_interfaces( + cleaned_points, corner_indices, optimal_segment_count, is_closed + ) + self.continuity_enforcer.enforce_segment_continuity( + bezier_segments, segment_interfaces, corner_indices, is_closed + ) + + # Step 5: Extract corner points + corner_points = [cleaned_points[idx] for idx in corner_indices] if corner_indices else [] + + return Outline( + bezier_segments=bezier_segments, + corners=corner_points, + color=color, + is_closed=is_closed + ) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/continuity_enforcer.py b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/continuity_enforcer.py new file mode 100644 index 0000000..10a0551 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/continuity_enforcer.py @@ -0,0 +1,92 @@ +from typing import List +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.point import Point + + +class ContinuityEnforcer: + """ + Enforces C0 and C1 continuity between Bézier segments. + """ + + def __init__(self, bezier_degree: int = 2): + self.bezier_degree = bezier_degree + + def enforce_segment_continuity(self, segments: List[BezierSegment], + outlines: List[int], corner_indices: List[int], + is_closed: bool): + """Enforce C0 continuity at all junctions and C1 continuity only at non-corner junctions.""" + if len(segments) < 2: + return + + for segment_index in range(len(segments) - 1): + current_segment = segments[segment_index] + next_segment = segments[segment_index + 1] + junction_index = outlines[segment_index + 1] + is_corner_junction = junction_index in corner_indices + + # Always enforce C0 continuity (position continuity) + endpoint_gap = current_segment.end_point.distance_to(next_segment.start_point) + if endpoint_gap > 1e-10: + adjusted_control_points = next_segment.control_points.copy() + adjusted_control_points[0] = current_segment.end_point + segments[segment_index + 1] = BezierSegment( + control_points=adjusted_control_points, + degree=next_segment.degree + ) + + # Only enforce C1 continuity (tangent continuity) at smooth junctions + if not is_corner_junction and self.bezier_degree == 2: + self._enforce_tangent_continuity(current_segment, next_segment) + + # Handle closure for closed outlines + if is_closed and len(segments) > 1: + self._ensure_outline_closure(segments) + + def _enforce_tangent_continuity(self, first_segment: BezierSegment, second_segment: BezierSegment): + """Enforce C1 continuity between two quadratic Bézier segments.""" + if self.bezier_degree != 2: + return + + # For quadratic Bézier curves, C1 continuity requires: + # first_segment.control_points[2] - first_segment.control_points[1] = + # second_segment.control_points[1] - second_segment.control_points[0] + p0, p1, p2 = first_segment.control_points + q0, q1, q2 = second_segment.control_points + + # Calculate ideal midpoint that satisfies C1 continuity + ideal_midpoint_x = (p2.x + q0.x) / 2 + ideal_midpoint_y = (p2.y + q0.y) / 2 + + # Adjust control points toward ideal midpoint + adjustment_strength = 0.3 + + adjusted_p1 = Point( + p1.x * (1 - adjustment_strength) + ideal_midpoint_x * adjustment_strength, + p1.y * (1 - adjustment_strength) + ideal_midpoint_y * adjustment_strength + ) + + adjusted_q1 = Point( + q1.x * (1 - adjustment_strength) + ideal_midpoint_x * adjustment_strength, + q1.y * (1 - adjustment_strength) + ideal_midpoint_y * adjustment_strength + ) + + first_segment.control_points[1] = adjusted_p1 + second_segment.control_points[1] = adjusted_q1 + + def _ensure_outline_closure(self, segments: List[BezierSegment]): + """Ensure the first and last points of a closed outline match exactly.""" + if not segments: + return + + first_segment_start = segments[0].start_point + last_segment = segments[-1] + + closure_gap = last_segment.end_point.distance_to(first_segment_start) + if closure_gap > 1e-10: + adjusted_control_points = last_segment.control_points.copy() + adjusted_control_points[-1] = first_segment_start + segments[-1] = BezierSegment( + control_points=adjusted_control_points, + degree=last_segment.degree + ) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/segment_classifier.py b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/segment_classifier.py new file mode 100644 index 0000000..84aa252 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/segment_classifier.py @@ -0,0 +1,111 @@ +from typing import List, Tuple +from svg_to_getdp.core.entities.point import Point + + +class SegmentClassifier: + """ + Classifies segments into corner regions, straight edges, or curved segments. + """ + + def __init__(self, relative_tolerance: float = 0.005, absolute_tolerance: float = 1e-6): + self.relative_tolerance = relative_tolerance + self.absolute_tolerance = absolute_tolerance + + def classify_segment_type(self, start_index: int, end_index: int, + corner_regions: List[Tuple[int, int]], + corner_indices: List[int], + points: List[Point]) -> str: + """ + Classify a segment into one of three types: corner region, straight edge, or curved. + """ + segment_points = self._extract_segment_points(points, start_index, end_index) + + if self._is_within_corner_region(start_index, end_index, corner_regions): + return "corner_region" + + if self._contains_interior_corner(start_index, end_index, corner_indices): + return "corner_region" + + is_connecting_corners = self._is_segment_connecting_corners(start_index, end_index, corner_indices) + + return self._determine_segment_type_by_geometry(segment_points, is_connecting_corners) + + def identify_corner_regions(self, points: List[Point], corner_indices: List[int]) -> List[Tuple[int, int]]: + """Identify regions around corners that require special constrained fitting.""" + corner_regions = [] + region_radius = min(20, len(points) // 20) + + for corner_index in corner_indices: + region_start = max(0, corner_index - region_radius) + region_end = min(len(points) - 1, corner_index + region_radius) + corner_regions.append((region_start, region_end)) + + return corner_regions + + def _extract_segment_points(self, points: List[Point], start_index: int, end_index: int) -> List[Point]: + """Extract points belonging to a segment from the complete point list.""" + return points[start_index:end_index + 1] + + def _is_within_corner_region(self, start_index: int, end_index: int, + corner_regions: List[Tuple[int, int]]) -> bool: + """Check if segment lies completely within any corner region.""" + for region_start, region_end in corner_regions: + if start_index >= region_start and end_index <= region_end: + return True + return False + + def _contains_interior_corner(self, start_index: int, end_index: int, + corner_indices: List[int]) -> bool: + """ + Check if segment contains a corner point that is not at its outline. + """ + for corner_index in corner_indices: + if start_index < corner_index < end_index: + return True + return False + + def _determine_segment_type_by_geometry(self, segment_points: List[Point], + is_connecting_corners: bool) -> str: + """ + Classify segment based on geometric analysis. + """ + if len(segment_points) < 3: + return self._classify_short_segment(segment_points, is_connecting_corners) + + # Import here to avoid circular imports + from svg_to_getdp.infrastructure.bezier_fitting.bezier_calculator import BezierCalculator + bezier_calculator = BezierCalculator() + + straight, _ = bezier_calculator.are_points_geometrically_straight( + segment_points, self.relative_tolerance, self.absolute_tolerance + ) + + if straight: + return "straight_edge" if is_connecting_corners else "curved" + + return "curved" + + def _classify_short_segment(self, segment_points: List[Point], + is_connecting_corners: bool) -> str: + """Handle classification for segments with fewer than 3 points.""" + if is_connecting_corners: + return "straight_edge" + return "curved" + + def _is_segment_connecting_corners(self, start_index: int, end_index: int, + corner_indices: List[int]) -> bool: + """Check if segment endpoints are consecutive corner points.""" + sorted_corners = sorted(corner_indices) + + # Check for consecutive corners in sequence + for i in range(len(sorted_corners) - 1): + if start_index == sorted_corners[i] and end_index == sorted_corners[i + 1]: + return True + + # Check for closure connection (last to first corner) + if len(sorted_corners) > 1: + if start_index == sorted_corners[-1] and end_index == sorted_corners[0]: + return True + + return False + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/segment_fitter.py b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/segment_fitter.py new file mode 100644 index 0000000..53a8182 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/bezier_fitting/segment_fitter.py @@ -0,0 +1,175 @@ +import numpy as np +from typing import List +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.point import Point + + +class SegmentFitter: + """ + Fits Bézier segments to point data using various fitting strategies. + """ + + def __init__(self, bezier_degree: int = 2): + self.bezier_degree = bezier_degree + + def fit_piecewise_bezier_curves(self, points: List[Point], corner_indices: List[int], + segment_count: int, is_closed: bool, + segment_classifier, bezier_calculator) -> List[BezierSegment]: + """ + Fit Bézier curves with special handling for corner regions and straight edges. + """ + if not corner_indices: + return self._fit_continuous_curves_without_corners(points, segment_count, is_closed, bezier_calculator) + + corner_regions = segment_classifier.identify_corner_regions(points, corner_indices) + segment_interfaces = bezier_calculator.calculate_segment_interfaces( + points, corner_indices, segment_count, is_closed + ) + + fitted_segments = [] + for segment_index in range(len(segment_interfaces) - 1): + start_index = segment_interfaces[segment_index] + end_index = segment_interfaces[segment_index + 1] + segment_points = points[start_index:end_index + 1] + + if len(segment_points) < 2: + continue + + segment_type = segment_classifier.classify_segment_type( + start_index, end_index, corner_regions, corner_indices, points + ) + + if segment_type == "corner_region": + fitted_segment = self._fit_constrained_corner_segment(segment_points, bezier_calculator) + elif segment_type == "straight_edge": + fitted_segment = self._fit_straight_edge_segment(segment_points) + else: + fitted_segment = self.fit_single_bezier_curve(segment_points) + + fitted_segments.append(fitted_segment) + + return fitted_segments + + def fit_single_bezier_curve(self, points: List[Point]) -> BezierSegment: + """Fit a single Bézier curve to points using least-squares optimization.""" + point_count = len(points) + + if point_count <= 3: + return self._fit_simple_bezier_curve(points) + + # Import here to avoid circular imports + from svg_to_getdp.infrastructure.bezier_fitting.bezier_calculator import BezierCalculator + bezier_calculator = BezierCalculator() + + parameter_values = np.linspace(0, 1, point_count) + + # Build Bernstein basis matrix + basis_matrix = np.zeros((point_count, self.bezier_degree + 1)) + for row, t in enumerate(parameter_values): + for col in range(self.bezier_degree + 1): + basis_matrix[row, col] = bezier_calculator.compute_bernstein_basis(col, self.bezier_degree, t) + + x_coordinates = np.array([point.x for point in points]) + y_coordinates = np.array([point.y for point in points]) + + try: + control_x, _, _, _ = np.linalg.lstsq(basis_matrix, x_coordinates, rcond=None) + control_y, _, _, _ = np.linalg.lstsq(basis_matrix, y_coordinates, rcond=None) + + control_points = [ + Point(float(control_x[i]), float(control_y[i])) + for i in range(self.bezier_degree + 1) + ] + + return BezierSegment(control_points=control_points, degree=self.bezier_degree) + + except np.linalg.LinAlgError: + return self._fit_simple_bezier_curve(points) + + def _fit_simple_bezier_curve(self, points: List[Point]) -> BezierSegment: + """Direct Bézier fitting for small point sets or when least-squares fails.""" + point_count = len(points) + + if point_count == 1: + control_points = [points[0]] * (self.bezier_degree + 1) + elif point_count == 2: + start_point, end_point = points[0], points[-1] + control_points = [start_point] + for i in range(1, self.bezier_degree): + interpolation_ratio = i / self.bezier_degree + control_points.append(Point( + start_point.x * (1 - interpolation_ratio) + end_point.x * interpolation_ratio, + start_point.y * (1 - interpolation_ratio) + end_point.y * interpolation_ratio + )) + control_points.append(end_point) + else: + if self.bezier_degree == 2: + start_point, end_point = points[0], points[-1] + middle_index = len(points) // 2 + middle_point = points[middle_index] + control_points = [start_point, middle_point, end_point] + else: + control_points = [points[0]] + for i in range(1, self.bezier_degree): + index = int((i / self.bezier_degree) * (point_count - 1)) + control_points.append(points[index]) + control_points.append(points[-1]) + + return BezierSegment(control_points=control_points, degree=self.bezier_degree) + + def _fit_constrained_corner_segment(self, points: List[Point], bezier_calculator) -> BezierSegment: + """Fit segments in corner regions with heavy constraints to prevent overshooting.""" + if len(points) <= 2: + return self._fit_simple_bezier_curve(points) + + start_point = points[0] + end_point = points[-1] + + if bezier_calculator.are_points_approximately_linear(points): + # Use midpoint for nearly linear segments + midpoint = Point((start_point.x + end_point.x) / 2, (start_point.y + end_point.y) / 2) + else: + # Find point with maximum deviation and its projection onto the line + max_deviation_point = bezier_calculator.find_point_with_max_deviation(points, start_point, end_point) + line_projection = bezier_calculator.project_point_to_line(start_point, end_point, max_deviation_point) + + # Blend between actual deviation point and its projection (70% actual, 30% projected) to prevent distortion + constraint_strength = 0.7 + midpoint = Point( + max_deviation_point.x * constraint_strength + line_projection.x * (1 - constraint_strength), + max_deviation_point.y * constraint_strength + line_projection.y * (1 - constraint_strength) + ) + + return BezierSegment(control_points=[start_point, midpoint, end_point], degree=2) + + def _fit_straight_edge_segment(self, points: List[Point]) -> BezierSegment: + """Fit segments that are known to be straight edges between corners.""" + start_point = points[0] + end_point = points[-1] + midpoint = Point((start_point.x + end_point.x) / 2, (start_point.y + end_point.y) / 2) + + return BezierSegment(control_points=[start_point, midpoint, end_point], degree=2) + + def _fit_continuous_curves_without_corners(self, points: List[Point], segment_count: int, + is_closed: bool, bezier_calculator) -> List[BezierSegment]: + """Fallback method for fitting curves when no corner points are provided.""" + point_count = len(points) + segments = [] + + # Create evenly distributed segment interfaces + points_per_segment = max(1, point_count // segment_count) + outlines = [i * points_per_segment for i in range(segment_count)] + outlines.append(point_count - 1) + + # Fit each segment independently + for segment_index in range(segment_count): + start_index = outlines[segment_index] + end_index = outlines[segment_index + 1] + segment_points = points[start_index:end_index + 1] + + if len(segment_points) >= 2: + segment = self.fit_single_bezier_curve(segment_points) + segments.append(segment) + + return segments + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/__init__.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/candidate_detector.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/candidate_detector.py new file mode 100644 index 0000000..cc0d307 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/candidate_detector.py @@ -0,0 +1,275 @@ +""" +Multi-method corner candidate detection. +""" + +import numpy as np +from typing import List, Dict +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.infrastructure.corner_detection.geometric_calculator import GeometricCalculator + + +class CandidateDetector: + """Detects corner candidates using multiple complementary methods.""" + + def __init__( + self, + window_size: int = 15, + direction_change_threshold: float = 0.8, + angle_threshold: float = np.pi / 6, + corner_strength_threshold: float = 0.45 + ): + self.window_size = window_size + self.direction_change_threshold = direction_change_threshold + self.angle_threshold = angle_threshold + self.corner_strength_threshold = corner_strength_threshold + + def detect_candidate_corners( + self, + outline_points: List[Point], + x_coordinates: np.ndarray, + y_coordinates: np.ndarray, + debug_data: Dict + ) -> List[int]: + """ + Detect candidate corners using multiple complementary methods. + + Combines results from: + 1. Local angle analysis + 2. Direction change detection + 3. Curvature peak analysis + """ + # Apply each detection method independently + angle_based_corners = self._detect_corners_by_local_angle(outline_points) + direction_based_corners = self._detect_corners_by_direction_change(x_coordinates, y_coordinates) + curvature_based_corners = self._detect_corners_by_curvature_peaks(x_coordinates, y_coordinates) + + # Record detection results for debugging + debug_data['candidate_detection'] = { + 'angle_method': angle_based_corners, + 'direction_method': direction_based_corners, + 'curvature_method': curvature_based_corners, + 'all_candidates': list(set(angle_based_corners + direction_based_corners + curvature_based_corners)) + } + + # Calculate strength for all candidates + all_candidates = debug_data['candidate_detection']['all_candidates'] + candidate_strengths = self._calculate_candidate_strengths(outline_points, all_candidates) + debug_data['strength_calculations'] = candidate_strengths + + # Combine results with method-specific weights + weighted_candidates = self._combine_candidate_methods( + angle_based_corners, + direction_based_corners, + curvature_based_corners, + candidate_strengths + ) + debug_data['candidate_detection']['combined_votes'] = weighted_candidates + + # Filter weak candidates based on votes and strength + strong_candidates = self._filter_weak_candidates(weighted_candidates, candidate_strengths) + debug_data['candidate_detection']['coarse_corners'] = strong_candidates + + return strong_candidates + + def _detect_corners_by_local_angle(self, outline_points: List[Point]) -> List[int]: + """Detect corners by analyzing local interior angles at each point.""" + point_count = len(outline_points) + if point_count < 10: + return [] + + angle_window = max(3, min(10, point_count // 50)) + angle_threshold = self.angle_threshold * 0.8 + + corners = [] + + for i in range(point_count): + angle = GeometricCalculator.calculate_point_angle(outline_points, i, angle_window) + if angle > angle_threshold: + corners.append(i) + + return corners + + def _detect_corners_by_direction_change( + self, + x_coordinates: np.ndarray, + y_coordinates: np.ndarray + ) -> List[int]: + """Detect corners by analyzing changes in direction along the outline.""" + point_count = len(x_coordinates) + if point_count < self.window_size * 2: + return [] + + corners = [] + + for i in range(point_count): + # Compute direction vectors before and after the point + previous_direction = GeometricCalculator.compute_direction_vector( + x_coordinates, y_coordinates, i, self.window_size, backward=True + ) + next_direction = GeometricCalculator.compute_direction_vector( + x_coordinates, y_coordinates, i, self.window_size, backward=False + ) + + previous_direction_norm = np.linalg.norm(previous_direction) + next_direction_norm = np.linalg.norm(next_direction) + + if previous_direction_norm > 1e-8 and next_direction_norm > 1e-8: + previous_direction_normalized = previous_direction / previous_direction_norm + next_direction_normalized = next_direction / next_direction_norm + + dot_product = np.clip(np.dot(previous_direction_normalized, next_direction_normalized), -1.0, 1.0) + angle_change = np.arccos(dot_product) + + if angle_change > self.direction_change_threshold: + corners.append(i) + + return corners + + def _detect_corners_by_curvature_peaks( + self, + x_coordinates: np.ndarray, + y_coordinates: np.ndarray + ) -> List[int]: + """Detect corners as local peaks in the curvature profile.""" + point_count = len(x_coordinates) + if point_count < 20: + return [] + + curvature_window = max(3, point_count // 100) + curvatures = [] + + # Calculate curvature at each point + for i in range(point_count): + curvature = GeometricCalculator.calculate_local_curvature( + x_coordinates, y_coordinates, i, curvature_window + ) + curvatures.append(curvature) + + # Find local peaks above threshold + average_curvature = np.mean(curvatures) + curvature_std = np.std(curvatures) + curvature_threshold = average_curvature + curvature_std * 1.0 + + corners = [] + + for i in range(point_count): + previous_index = (i - 1) % point_count + next_index = (i + 1) % point_count + + is_local_peak = ( + curvatures[i] > curvatures[previous_index] and + curvatures[i] > curvatures[next_index] and + curvatures[i] > curvature_threshold + ) + + if is_local_peak: + corners.append(i) + + return corners + + def _calculate_corner_strength(self, outline_points: List[Point], point_index: int) -> float: + """ + Calculate a strength score (0-1) for a potential corner. + + Combines: + 1. Interior angle (larger angles are stronger corners) + 2. Local curvature contrast (corners should stand out from neighbors) + """ + point_count = len(outline_points) + + # Angle component: corners have larger interior angles + angle = GeometricCalculator.calculate_point_angle(outline_points, point_index, 7) + angle_score = min(angle / (np.pi * 0.8), 1.0) + + # Curvature contrast component: corners should have higher curvature than neighbors + x_coordinates = np.array([point.x for point in outline_points]) + y_coordinates = np.array([point.y for point in outline_points]) + + local_curvature = GeometricCalculator.calculate_local_curvature( + x_coordinates, y_coordinates, point_index, 5 + ) + + # Compare with neighboring curvatures + neighbor_window = min(10, point_count // 20) + neighbor_curvatures = [] + + for offset in range(-neighbor_window, neighbor_window + 1): + if offset != 0: + neighbor_index = (point_index + offset) % point_count + curvature = GeometricCalculator.calculate_local_curvature( + x_coordinates, y_coordinates, neighbor_index, 5 + ) + neighbor_curvatures.append(curvature) + + if neighbor_curvatures: + average_neighbor_curvature = np.mean(neighbor_curvatures) + if average_neighbor_curvature > 1e-8: + curvature_contrast = local_curvature / average_neighbor_curvature + contrast_score = min(curvature_contrast / 3.0, 1.0) + else: + contrast_score = 1.0 + else: + contrast_score = 0.5 + + # Weighted combination: angle is more important than contrast + return angle_score * 0.7 + contrast_score * 0.3 + + def _calculate_candidate_strengths( + self, + outline_points: List[Point], + candidate_indices: List[int] + ) -> Dict[int, float]: + """Calculate strength scores for multiple candidate corners.""" + return { + idx: self._calculate_corner_strength(outline_points, idx) + for idx in candidate_indices + } + + def _combine_candidate_methods( + self, + angle_corners: List[int], + direction_corners: List[int], + curvature_corners: List[int], + candidate_strengths: Dict[int, float] + ) -> Dict[int, float]: + """Combine results from multiple detection methods with weights.""" + weighted_candidates = {} + + # Method weights reflect confidence in each detection approach + method_weights = { + 'angle': 1.0, # Most reliable for clear corners + 'direction': 0.8, # Good for gradual direction changes + 'curvature': 0.6 # Sensitive to local shape changes + } + + # Add candidates from each method with their respective weights + for idx in angle_corners: + if candidate_strengths.get(idx, 0) >= self.corner_strength_threshold * 0.5: + weighted_candidates[idx] = weighted_candidates.get(idx, 0) + method_weights['angle'] + + for idx in direction_corners: + if candidate_strengths.get(idx, 0) >= self.corner_strength_threshold * 0.5: + weighted_candidates[idx] = weighted_candidates.get(idx, 0) + method_weights['direction'] + + for idx in curvature_corners: + if candidate_strengths.get(idx, 0) >= self.corner_strength_threshold * 0.5: + weighted_candidates[idx] = weighted_candidates.get(idx, 0) + method_weights['curvature'] + + return weighted_candidates + + def _filter_weak_candidates( + self, + weighted_candidates: Dict[int, float], + candidate_strengths: Dict[int, float] + ) -> List[int]: + """Filter out candidates with insufficient votes or low strength.""" + minimum_votes = 1.0 + strong_candidates = [] + + for idx, votes in weighted_candidates.items(): + strength = candidate_strengths.get(idx, 0) + if votes >= minimum_votes and strength >= self.corner_strength_threshold: + strong_candidates.append(idx) + + return strong_candidates + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/candidate_refiner.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/candidate_refiner.py new file mode 100644 index 0000000..b37cf22 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/candidate_refiner.py @@ -0,0 +1,259 @@ +""" +Refinement, clustering, and filtering of corner candidates. +""" + +import numpy as np +from typing import List, Dict, Optional +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.infrastructure.corner_detection.geometric_calculator import GeometricCalculator + + +class CandidateRefiner: + """Refines, clusters, and filters corner candidates.""" + + def __init__( + self, + minimum_corner_distance: int = 5, + corner_strength_threshold: float = 0.45, + angle_threshold: float = np.pi / 6 + ): + self.minimum_corner_distance = minimum_corner_distance + self.corner_strength_threshold = corner_strength_threshold + self.angle_threshold = angle_threshold + + def cluster_nearby_candidates( + self, + outline_points: List[Point], + candidates: List[int], + debug_data: Dict + ) -> List[List[int]]: + """Group nearby candidate corners to avoid duplicates.""" + if not candidates or len(candidates) == 1: + return [candidates] if candidates else [] + + # Cluster candidates that are close to each other + clusters = self._form_candidate_clusters(outline_points, candidates) + + debug_data['clustering']['clusters'] = clusters + + return clusters + + def refine_corner_positions( + self, + outline_points: List[Point], + clustered_corners: List[List[int]], + debug_data: Dict + ) -> List[int]: + """Refine corner positions within each cluster.""" + refined_corners = [] + + for cluster_index, cluster in enumerate(clustered_corners): + if not cluster: + continue + + # Select the strongest candidate from the cluster + candidate_strengths = self._calculate_candidate_strengths(outline_points, cluster) + best_candidate = max(cluster, key=lambda idx: candidate_strengths.get(idx, 0)) + + # Refine the corner position + refined_candidate = self._refine_corner_position(outline_points, best_candidate) + + # Record refinement details for debugging + refinement_detail = self._record_refinement_details( + cluster, best_candidate, refined_candidate, outline_points, debug_data + ) + + if refined_candidate is not None and refinement_detail.get('accepted', False): + refined_corners.append(refined_candidate) + + return refined_corners + + def filter_corners_by_strength(self, outline_points: List[Point], corners: List[int]) -> List[int]: + """Filter out corners that don't meet the strength threshold.""" + candidate_strengths = self._calculate_candidate_strengths(outline_points, corners) + return [ + idx for idx in corners + if candidate_strengths.get(idx, 0) >= self.corner_strength_threshold + ] + + def enforce_minimum_corner_spacing( + self, + outline_points: List[Point], + corners: List[int], + debug_data: Dict + ) -> List[int]: + """Ensure corners are spaced at least minimum_corner_distance apart.""" + if len(corners) <= 1: + return corners + + point_count = len(outline_points) + candidate_strengths = self._calculate_candidate_strengths(outline_points, corners) + + sorted_corners = sorted(corners) + well_spaced_corners = [] + + i = 0 + while i < len(sorted_corners): + current_corner = sorted_corners[i] + well_spaced_corners.append(current_corner) + + # Skip any corners that are too close to the current one + j = i + 1 + while j < len(sorted_corners): + next_corner = sorted_corners[j] + distance = min(abs(next_corner - current_corner), + point_count - abs(next_corner - current_corner)) + + if distance < self.minimum_corner_distance: + # Keep the stronger corner when two are too close + current_strength = candidate_strengths.get(current_corner, 0) + next_strength = candidate_strengths.get(next_corner, 0) + + if next_strength > current_strength * 1.1: + well_spaced_corners[-1] = next_corner + current_corner = next_corner + + j += 1 + else: + break + + i = j + + debug_data['clustering']['refined_corners'] = corners + debug_data['clustering']['quality_corners'] = well_spaced_corners + + return sorted(well_spaced_corners) + + def _form_candidate_clusters(self, outline_points: List[Point], candidates: List[int]) -> List[List[int]]: + """Group candidates that are within minimum distance of each other.""" + point_count = len(outline_points) + sorted_candidates = sorted(candidates) + clusters = [] + current_cluster = [sorted_candidates[0]] + + for i in range(1, len(sorted_candidates)): + previous_idx = sorted_candidates[i-1] + current_idx = sorted_candidates[i] + + # Calculate circular distance along the outline + distance = min(abs(current_idx - previous_idx), point_count - abs(current_idx - previous_idx)) + + if distance < self.minimum_corner_distance * 3: + current_cluster.append(current_idx) + else: + clusters.append(current_cluster) + current_cluster = [current_idx] + + if current_cluster: + clusters.append(current_cluster) + + return clusters + + def _refine_corner_position(self, outline_points: List[Point], coarse_index: int) -> Optional[int]: + """ + Refine a corner position by searching locally for the point with maximum interior angle. + + Args: + outline_points: List of outline points + coarse_index: Initial estimate of corner location + + Returns: + Refined corner index, or None if no good corner found nearby + """ + point_count = len(outline_points) + search_radius = min(10, point_count // 20) + + best_index = coarse_index + best_angle = 0.0 + + # Search within radius for point with maximum interior angle + for offset in range(-search_radius, search_radius + 1): + test_index = (coarse_index + offset) % point_count + angle = GeometricCalculator.calculate_point_angle(outline_points, test_index, 5) + + if angle > best_angle: + best_angle = angle + best_index = test_index + + # Only return if the refined point has a sufficiently large angle + return best_index if best_angle > self.angle_threshold * 0.5 else None + + def _record_refinement_details( + self, + cluster: List[int], + best_candidate: int, + refined_candidate: Optional[int], + outline_points: List[Point], + debug_data: Dict + ) -> Dict: + """Record details of the refinement process for debugging.""" + refinement_detail = { + 'cluster': cluster, + 'best_candidate': best_candidate, + 'refined_candidate': refined_candidate + } + + if refined_candidate is not None: + refined_strength = self._calculate_corner_strength(outline_points, refined_candidate) + refinement_detail['refined_strength'] = refined_strength + + if refined_strength >= self.corner_strength_threshold * 0.8: + refinement_detail['accepted'] = True + else: + refinement_detail['accepted'] = False + else: + refinement_detail['accepted'] = False + + debug_data['refinement_details'].append(refinement_detail) + return refinement_detail + + def _calculate_corner_strength(self, outline_points: List[Point], point_index: int) -> float: + """Calculate corner strength (delegates to geometric calculator).""" + point_count = len(outline_points) + + # Angle component + angle = GeometricCalculator.calculate_point_angle(outline_points, point_index, 7) + angle_score = min(angle / (np.pi * 0.8), 1.0) + + # Curvature contrast component + x_coordinates = np.array([point.x for point in outline_points]) + y_coordinates = np.array([point.y for point in outline_points]) + + local_curvature = GeometricCalculator.calculate_local_curvature( + x_coordinates, y_coordinates, point_index, 5 + ) + + neighbor_window = min(10, point_count // 20) + neighbor_curvatures = [] + + for offset in range(-neighbor_window, neighbor_window + 1): + if offset != 0: + neighbor_index = (point_index + offset) % point_count + curvature = GeometricCalculator.calculate_local_curvature( + x_coordinates, y_coordinates, neighbor_index, 5 + ) + neighbor_curvatures.append(curvature) + + if neighbor_curvatures: + average_neighbor_curvature = np.mean(neighbor_curvatures) + if average_neighbor_curvature > 1e-8: + curvature_contrast = local_curvature / average_neighbor_curvature + contrast_score = min(curvature_contrast / 3.0, 1.0) + else: + contrast_score = 1.0 + else: + contrast_score = 0.5 + + return angle_score * 0.7 + contrast_score * 0.3 + + def _calculate_candidate_strengths( + self, + outline_points: List[Point], + candidate_indices: List[int] + ) -> Dict[int, float]: + """Calculate strength scores for multiple candidate corners.""" + return { + idx: self._calculate_corner_strength(outline_points, idx) + for idx in candidate_indices + } + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/corner_detector.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/corner_detector.py new file mode 100644 index 0000000..a46651e --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/corner_detector.py @@ -0,0 +1,156 @@ +""" +Main corner detector orchestrator implementing the CornerDetectorInterface. +""" + +import numpy as np +from typing import List, Tuple, Dict +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.interfaces.abstractions.corner_detector_interface import CornerDetectorInterface + +from svg_to_getdp.infrastructure.corner_detection.smooth_shape_detector import SmoothShapeDetector +from svg_to_getdp.infrastructure.corner_detection.candidate_detector import CandidateDetector +from svg_to_getdp.infrastructure.corner_detection.candidate_refiner import CandidateRefiner +from svg_to_getdp.infrastructure.corner_detection.debug_recorder import DebugRecorder + + +class CornerDetector(CornerDetectorInterface): + """ + Corner detector with handling for complex shapes like crosses. + Returns structured debug data along with corner indices. + + The detector uses multiple complementary methods to identify corners: + 1. Local angle analysis + 2. Direction change detection + 3. Curvature peak analysis + + Results are combined, clustered, refined, and filtered to produce final corner points. + """ + + def __init__( + self, + window_size: int = 15, + direction_change_threshold: float = 0.8, + angle_threshold: float = np.pi / 6, + minimum_corner_distance: int = 5, + smoothness_threshold: float = 0.72, + corner_strength_threshold: float = 0.45, + ellipse_aspect_ratio_threshold: float = 1.2, + debug_enabled: bool = True + ): + """ + Initialize the corner detector with configurable parameters. + + Args: + window_size: Size of the analysis window for direction vectors + direction_change_threshold: Minimum angle change (radians) to consider a direction change + angle_threshold: Minimum interior angle (radians) to qualify as a corner + minimum_corner_distance: Minimum distance between detected corners (pixels) + smoothness_threshold: Threshold for detecting smooth/elliptical shapes + corner_strength_threshold: Minimum strength score for a valid corner + ellipse_aspect_ratio_threshold: Maximum aspect ratio for ellipse detection + debug_enabled: Whether to collect and return debug information + """ + self.window_size = window_size + self.direction_change_threshold = direction_change_threshold + self.angle_threshold = angle_threshold + self.minimum_corner_distance = minimum_corner_distance + self.smoothness_threshold = smoothness_threshold + self.corner_strength_threshold = corner_strength_threshold + self.ellipse_aspect_ratio_threshold = ellipse_aspect_ratio_threshold + self.debug_enabled = debug_enabled + + # Initialize components + self.debug_recorder = DebugRecorder(debug_enabled) + self.smooth_shape_detector = SmoothShapeDetector( + smoothness_threshold=smoothness_threshold, + ellipse_aspect_ratio_threshold=ellipse_aspect_ratio_threshold, + window_size=window_size + ) + self.candidate_detector = CandidateDetector( + window_size=window_size, + direction_change_threshold=direction_change_threshold, + angle_threshold=angle_threshold, + corner_strength_threshold=corner_strength_threshold + ) + self.candidate_refiner = CandidateRefiner( + minimum_corner_distance=minimum_corner_distance, + corner_strength_threshold=corner_strength_threshold, + angle_threshold=angle_threshold + ) + + def detect_corners(self, outline_points: List[Point]) -> Tuple[List[int], Dict]: + """ + Identifies indices of corner points in the outline point sequence. + + The detection process involves: + 1. Early shape analysis (ellipse/smooth shape detection) + 2. Candidate detection using multiple methods + 3. Strength calculation for each candidate + 4. Clustering of nearby candidates + 5. Refinement of corner positions + 6. Final filtering and spacing enforcement + + Args: + outline_points: List of ordered points representing a closed outline + + Returns: + Tuple containing: + - List of corner indices in the outline_points list + - Dictionary containing debug information if debug_enabled is True + """ + debug_data = self.debug_recorder.initialize_debug_data() + self.debug_recorder.record_debug_step(debug_data, f"Starting corner detection for {len(outline_points)} outline points") + + # Early return for shapes that are likely ellipses or too smooth + if self.smooth_shape_detector.should_skip_corner_detection(outline_points, debug_data): + self.debug_recorder.record_debug_step(debug_data, "Shape is ellipse or too smooth: returning no corners") + return [], debug_data + + # Convert points to coordinate arrays for efficient computation + x_coordinates = np.array([point.x for point in outline_points]) + y_coordinates = np.array([point.y for point in outline_points]) + + self.debug_recorder.record_bounding_box_info(x_coordinates, y_coordinates, debug_data) + + # Step 1: Detect candidate corners using multiple complementary methods + candidate_corners = self.candidate_detector.detect_candidate_corners( + outline_points, x_coordinates, y_coordinates, debug_data + ) + + self.debug_recorder.record_debug_step(debug_data, f"Angle method found {len(debug_data['candidate_detection']['angle_method'])} corners") + self.debug_recorder.record_debug_step(debug_data, f"Direction method found {len(debug_data['candidate_detection']['direction_method'])} corners") + self.debug_recorder.record_debug_step(debug_data, f"Curvature method found {len(debug_data['candidate_detection']['curvature_method'])} corners") + self.debug_recorder.record_debug_step(debug_data, f"After filtering: {len(candidate_corners)} strong candidates") + + if not candidate_corners: + self.debug_recorder.record_debug_step(debug_data, "No strong corners found: returning empty list") + return [], debug_data + + # Step 2: Cluster nearby candidates to avoid duplicates + clustered_corners = self.candidate_refiner.cluster_nearby_candidates( + outline_points, candidate_corners, debug_data + ) + + self.debug_recorder.record_debug_step(debug_data, f"Clustering created {len(clustered_corners)} candidate clusters") + + # Step 3: Refine corner positions within each cluster + refined_corners = self.candidate_refiner.refine_corner_positions( + outline_points, clustered_corners, debug_data + ) + + # Step 4: Filter corners by strength + strong_corners = self.candidate_refiner.filter_corners_by_strength(outline_points, refined_corners) + + # Step 5: Ensure minimum spacing between corners + final_corners = self.candidate_refiner.enforce_minimum_corner_spacing( + outline_points, strong_corners, debug_data + ) + + # Record final results + candidate_strengths = self.candidate_detector._calculate_candidate_strengths(outline_points, final_corners) + self.debug_recorder.record_final_results(outline_points, final_corners, debug_data, candidate_strengths) + + self.debug_recorder.record_debug_step(debug_data, f"Final result: {len(final_corners)} corners detected") + + return sorted(final_corners), debug_data + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/debug_recorder.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/debug_recorder.py new file mode 100644 index 0000000..72ca860 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/debug_recorder.py @@ -0,0 +1,64 @@ +""" +Debug data recording functionality for corner detection. +""" + +import numpy as np +from typing import List, Dict +from svg_to_getdp.core.entities.point import Point + + +class DebugRecorder: + """Handles debug data recording for corner detection.""" + + def __init__(self, debug_enabled: bool = True): + self.debug_enabled = debug_enabled + + def initialize_debug_data(self) -> Dict: + """Initialize the debug data structure.""" + return { + 'shape_analysis': {}, + 'candidate_detection': {}, + 'strength_calculations': {}, + 'clustering': {}, + 'refinement_details': [], + 'final_decisions': {}, + 'all_steps': [] + } + + def record_debug_step(self, debug_data: Dict, message: str) -> None: + """Record a debug step if debugging is enabled.""" + if self.debug_enabled: + debug_data['all_steps'].append(message) + + def record_bounding_box_info( + self, + x_coordinates: np.ndarray, + y_coordinates: np.ndarray, + debug_data: Dict + ) -> None: + """Record bounding box information for debugging.""" + debug_data['shape_analysis']['bounding_box'] = { + 'x_min': float(np.min(x_coordinates)), + 'x_max': float(np.max(x_coordinates)), + 'y_min': float(np.min(y_coordinates)), + 'y_max': float(np.max(y_coordinates)), + 'width': float(np.max(x_coordinates) - np.min(x_coordinates)), + 'height': float(np.max(y_coordinates) - np.min(y_coordinates)) + } + + def record_final_results( + self, + outline_points: List[Point], + final_corners: List[int], + debug_data: Dict, + candidate_strengths: Dict[int, float] + ) -> None: + """Record final corner detection results for debugging.""" + debug_data['final_decisions']['final_corners'] = final_corners + debug_data['final_decisions']['corner_coordinates'] = { + idx: outline_points[idx] for idx in final_corners + } + debug_data['final_decisions']['corner_strengths'] = { + idx: candidate_strengths.get(idx, 0) for idx in final_corners + } + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/geometric_calculator.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/geometric_calculator.py new file mode 100644 index 0000000..e726da9 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/geometric_calculator.py @@ -0,0 +1,180 @@ +""" +Pure geometric calculations for corner detection. +Stateless utility functions. +""" + +import numpy as np +from typing import List +from svg_to_getdp.core.entities.point import Point + + +class GeometricCalculator: + """Stateless geometric calculations for corner detection.""" + + @staticmethod + def calculate_point_angle(outline_points: List[Point], point_index: int, window_size: int) -> float: + """ + Calculate the interior angle at a specific outline point. + + Uses vectors to previous and next points to compute the angle. + """ + point_count = len(outline_points) + + previous_index = (point_index - window_size) % point_count + next_index = (point_index + window_size) % point_count + + # Vector from previous point to current point + vector_to_current = np.array([ + outline_points[point_index].x - outline_points[previous_index].x, + outline_points[point_index].y - outline_points[previous_index].y + ]) + + # Vector from current point to next point + vector_from_current = np.array([ + outline_points[next_index].x - outline_points[point_index].x, + outline_points[next_index].y - outline_points[point_index].y + ]) + + vector_to_current_norm = np.linalg.norm(vector_to_current) + vector_from_current_norm = np.linalg.norm(vector_from_current) + + if vector_to_current_norm > 1e-8 and vector_from_current_norm > 1e-8: + cosine_angle = np.dot(vector_to_current, vector_from_current) / (vector_to_current_norm * vector_from_current_norm) + cosine_angle = np.clip(cosine_angle, -1.0, 1.0) + return np.arccos(cosine_angle) + + return 0.0 + + @staticmethod + def calculate_local_curvature( + x_coordinates: np.ndarray, + y_coordinates: np.ndarray, + point_index: int, + window_size: int + ) -> float: + """ + Calculate the curvature at a specific point along the outline. + + Curvature is defined as the rate of change of direction per unit arc length. + """ + point_count = len(x_coordinates) + + previous_index = (point_index - window_size) % point_count + next_index = (point_index + window_size) % point_count + + # Vectors from previous to current and current to next + vector_to_current = np.array([ + x_coordinates[point_index] - x_coordinates[previous_index], + y_coordinates[point_index] - y_coordinates[previous_index] + ]) + + vector_from_current = np.array([ + x_coordinates[next_index] - x_coordinates[point_index], + y_coordinates[next_index] - y_coordinates[point_index] + ]) + + vector_to_current_norm = np.linalg.norm(vector_to_current) + vector_from_current_norm = np.linalg.norm(vector_from_current) + + if vector_to_current_norm < 1e-8 or vector_from_current_norm < 1e-8: + return 0.0 + + # Calculate angle between vectors + cosine_angle = np.dot(vector_to_current, vector_from_current) / (vector_to_current_norm * vector_from_current_norm) + cosine_angle = np.clip(cosine_angle, -1.0, 1.0) + angle = np.arccos(cosine_angle) + + # Calculate average arc length + arc_length = (vector_to_current_norm + vector_from_current_norm) / 2 + + return angle / arc_length if arc_length > 0 else 0.0 + + @staticmethod + def compute_direction_vector( + x_coordinates: np.ndarray, + y_coordinates: np.ndarray, + point_index: int, + window_size: int, + backward: bool + ) -> np.ndarray: + """Compute the average direction vector over a window of points.""" + point_count = len(x_coordinates) + + if backward: + start_index = (point_index - window_size) % point_count + end_index = point_index + else: + start_index = point_index + end_index = (point_index + window_size) % point_count + + # Extract coordinates from the window (handling circular outline) + if start_index < end_index: + x_window = x_coordinates[start_index:end_index] + y_window = y_coordinates[start_index:end_index] + else: + x_window = np.concatenate([x_coordinates[start_index:], x_coordinates[:end_index]]) + y_window = np.concatenate([y_coordinates[start_index:], y_coordinates[:end_index]]) + + if len(x_window) < 2: + return np.array([0.0, 0.0]) + + # Direction vector from first to last point in the window + return np.array([ + x_window[-1] - x_window[0], + y_window[-1] - y_window[0] + ]) + + @staticmethod + def calculate_sampled_curvatures( + x_coordinates: np.ndarray, + y_coordinates: np.ndarray, + point_count: int + ) -> List[float]: + """Calculate curvatures at regularly sampled points along the outline.""" + sample_step = max(1, point_count // 50) + curvatures = [] + + for i in range(0, point_count, sample_step): + curvature = GeometricCalculator.calculate_local_curvature(x_coordinates, y_coordinates, i, 5) + curvatures.append(curvature) + + return curvatures + + @staticmethod + def calculate_sampled_angles(outline_points: List[Point], point_count: int) -> List[float]: + """Calculate angles at regularly sampled points along the outline.""" + sample_step = max(1, point_count // 50) + angles = [] + + for i in range(0, point_count, sample_step): + angle = GeometricCalculator.calculate_point_angle(outline_points, i, 7) + angles.append(angle) + + return angles + + @staticmethod + def compute_smoothness_score(angles: List[float], curvatures: List[float]) -> float: + """Compute a combined smoothness score from angle and curvature statistics.""" + if not angles: + return 1.0 + + # Angle-based smoothness: shapes with smaller maximum angles are smoother + max_angle = max(angles) + angle_score = 1.0 - min(max_angle / (np.pi * 0.5), 1.0) + + # Curvature-based smoothness: shapes with consistent curvature are smoother + if curvatures: + curvature_std = np.std(curvatures) + curvature_mean = np.mean(curvatures) + + if curvature_mean > 1e-8: + curvature_variation = curvature_std / curvature_mean + curvature_score = 1.0 / (1.0 + curvature_variation) + else: + curvature_score = 1.0 + else: + curvature_score = 1.0 + + # Weighted combination of angle and curvature smoothness + return angle_score * 0.6 + curvature_score * 0.4 + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/smooth_shape_detector.py b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/smooth_shape_detector.py new file mode 100644 index 0000000..2193242 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/corner_detection/smooth_shape_detector.py @@ -0,0 +1,162 @@ +""" +Detection of ellipses and smooth shapes to skip unnecessary corner detection. +""" + +import numpy as np +from typing import List, Tuple, Dict +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.infrastructure.corner_detection.geometric_calculator import GeometricCalculator + + +class SmoothShapeDetector: + """Detects ellipses and smooth shapes to avoid unnecessary corner detection.""" + + def __init__( + self, + smoothness_threshold: float = 0.72, + ellipse_aspect_ratio_threshold: float = 1.2, + window_size: int = 15 + ): + self.smoothness_threshold = smoothness_threshold + self.ellipse_aspect_ratio_threshold = ellipse_aspect_ratio_threshold + self.window_size = window_size + + def should_skip_corner_detection(self, outline_points: List[Point], debug_data: Dict) -> bool: + """ + Check if the shape is likely an ellipse or too smooth for corner detection. + + Returns True if corner detection should be skipped for this shape. + """ + point_count = len(outline_points) + + # Early ellipse detection for small shapes + if point_count < 100 and self._is_likely_small_ellipse(outline_points): + debug_data['shape_analysis']['early_ellipse_detection'] = True + debug_data['shape_analysis']['ellipse_reason'] = "Small shape with ellipse-like properties" + return True + + # Smoothness check for larger shapes + if point_count > 30: + smoothness_score, is_ellipse = self.calculate_shape_smoothness(outline_points) + + debug_data['shape_analysis']['smoothness_score'] = smoothness_score + debug_data['shape_analysis']['is_ellipse'] = is_ellipse + + if is_ellipse: + debug_data['shape_analysis']['ellipse_reason'] = "Smoothness detection" + return True + + if smoothness_score > self.smoothness_threshold: + debug_data['shape_analysis']['too_smooth'] = True + return True + + # Check if shape is too small for reliable corner detection + if point_count < self.window_size * 2: + debug_data['shape_analysis']['too_small'] = True + return True + + return False + + def calculate_shape_smoothness(self, outline_points: List[Point]) -> Tuple[float, bool]: + """ + Calculate a smoothness score for the shape and detect if it's ellipse-like. + + Returns: + Tuple containing: + - Smoothness score (higher = smoother) + - Boolean indicating if shape is likely an ellipse + """ + point_count = len(outline_points) + x_coordinates = np.array([point.x for point in outline_points]) + y_coordinates = np.array([point.y for point in outline_points]) + + # Calculate curvatures at sample points + curvatures = GeometricCalculator.calculate_sampled_curvatures(x_coordinates, y_coordinates, point_count) + + # Check if shape is ellipse-like + is_ellipse = self._is_shape_ellipse_like(outline_points, curvatures) + + # Calculate angles at sample points + angles = GeometricCalculator.calculate_sampled_angles(outline_points, point_count) + + # Compute smoothness score from angle and curvature statistics + smoothness_score = GeometricCalculator.compute_smoothness_score(angles, curvatures) + + return smoothness_score, is_ellipse + + def _is_shape_ellipse_like(self, outline_points: List[Point], curvatures: List[float]) -> bool: + """Determine if the shape is likely an ellipse based on curvature consistency.""" + point_count = len(outline_points) + + # Large shapes are less likely to be simple ellipses + if point_count > 200: + return False + + # Check curvature consistency + if curvatures: + curvature_std = np.std(curvatures) + curvature_mean = np.mean(curvatures) + + if curvature_mean > 1e-8: + coefficient_of_variation = curvature_std / curvature_mean + if coefficient_of_variation < 0.3: + return True + + # Check distance to center consistency + x_coordinates = np.array([point.x for point in outline_points]) + y_coordinates = np.array([point.y for point in outline_points]) + + center_x = np.mean(x_coordinates) + center_y = np.mean(y_coordinates) + + distances = np.sqrt((x_coordinates - center_x)**2 + (y_coordinates - center_y)**2) + distance_mean = np.mean(distances) + + if distance_mean > 0: + distance_variation = np.std(distances) / distance_mean + if distance_variation < 0.2: + return True + + return False + + def _is_likely_small_ellipse(self, outline_points: List[Point]) -> bool: + """Check if a small shape is likely an ellipse.""" + point_count = len(outline_points) + + if point_count < 10: + return False + + x_coordinates = np.array([point.x for point in outline_points]) + y_coordinates = np.array([point.y for point in outline_points]) + + width = np.max(x_coordinates) - np.min(x_coordinates) + height = np.max(y_coordinates) - np.min(y_coordinates) + + # Check curvature consistency + curvatures = [] + sample_step = max(1, point_count // 20) + for i in range(0, point_count, sample_step): + curvature = GeometricCalculator.calculate_local_curvature(x_coordinates, y_coordinates, i, 3) + curvatures.append(curvature) + + if curvatures: + curvature_std = np.std(curvatures) + curvature_mean = np.mean(curvatures) + if curvature_mean > 1e-8: + coefficient_of_variation = curvature_std / curvature_mean + if coefficient_of_variation < 0.25: + return True + + # Check aspect ratio and closure + if width > 0 and height > 0: + aspect_ratio = max(width, height) / min(width, height) + if aspect_ratio < self.ellipse_aspect_ratio_threshold: + start_end_distance = np.sqrt( + (x_coordinates[0] - x_coordinates[-1])**2 + + (y_coordinates[0] - y_coordinates[-1])**2 + ) + if start_end_distance < min(width, height) * 0.1: + return True + + return False + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/__init__.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/bezier_fitter_factory.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/bezier_fitter_factory.py new file mode 100644 index 0000000..8e4091a --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/factories/bezier_fitter_factory.py @@ -0,0 +1,16 @@ +""" +Factory for creating bezier fitter instances. +""" + +from svg_to_getdp.infrastructure.bezier_fitting.bezier_fitter import BezierFitter +from svg_to_getdp.interfaces.abstractions.bezier_fitter_interface import BezierFitterInterface + + +class BezierFitterFactory: + """Factory for creating bezier fitter instances.""" + + @staticmethod + def create_default() -> BezierFitterInterface: + """Create a default bezier fitter.""" + return BezierFitter() + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/corner_detector_factory.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/corner_detector_factory.py new file mode 100644 index 0000000..7c80884 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/factories/corner_detector_factory.py @@ -0,0 +1,30 @@ +""" +Factory for creating corner detector instances. +""" + +from svg_to_getdp.interfaces.abstractions.corner_detector_interface import CornerDetectorInterface +from svg_to_getdp.infrastructure.corner_detection.corner_detector import CornerDetector + + +class CornerDetectorFactory: + """Factory for creating corner detector instances.""" + + @staticmethod + def create_default() -> CornerDetectorInterface: + """ + Create a corner detector with default parameters. + + Returns: + CornerDetectorInterface instance + """ + return CornerDetector( + window_size=15, + direction_change_threshold=0.8, + angle_threshold=0.5, # radians (~30 degrees) + minimum_corner_distance=5, + smoothness_threshold=0.72, + corner_strength_threshold=0.45, + ellipse_aspect_ratio_threshold=1.2, + debug_enabled=True + ) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/outline_grouper_factory.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/outline_grouper_factory.py new file mode 100644 index 0000000..e851f94 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/factories/outline_grouper_factory.py @@ -0,0 +1,26 @@ +""" +Factory for creating outline grouper instances. +Implements the Factory pattern for dependency injection. +""" + +from svg_to_getdp.infrastructure.outline_grouper import OutlineGrouper +from svg_to_getdp.interfaces.abstractions.outline_grouper_interface import OutlineGrouperInterface + + +class OutlineGrouperFactory: + """ + Factory for creating OutlineGrouper instances. + + Follows the Factory pattern to decouple object creation from usage. + """ + + @staticmethod + def create_default() -> OutlineGrouperInterface: + """ + Create a default OutlineGrouper with standard settings. + + Returns: + OutlineGrouperInterface: A grouper instance with default parameters + """ + return OutlineGrouper() + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/outline_preprocessor_factory.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/outline_preprocessor_factory.py new file mode 100644 index 0000000..8e4afba --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/factories/outline_preprocessor_factory.py @@ -0,0 +1,27 @@ +""" +Factory for creating outline preprocessor instances. +Implements the Factory pattern for dependency injection. +""" + +from typing import Optional +from svg_to_getdp.infrastructure.outline_preprocessor import OutlinePreprocessor +from svg_to_getdp.interfaces.abstractions.outline_preprocessor_interface import OutlinePreprocessorInterface + + +class OutlinePreprocessorFactory: + """ + Factory for creating OutlinePreprocessor instances. + + Follows the Factory pattern to decouple object creation from usage. + """ + + @staticmethod + def create_default() -> OutlinePreprocessorInterface: + """ + Create a default OutlinePreprocessor with standard settings. + + Returns: + OutlinePreprocessorInterface: A preprocessor instance with default parameters + """ + return OutlinePreprocessor() + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/svg_parser_factory.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/svg_parser_factory.py new file mode 100644 index 0000000..bfd1fe1 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/factories/svg_parser_factory.py @@ -0,0 +1,26 @@ +""" +Factory for creating SVG parser instances. +Implements the Factory pattern for dependency injection. +""" + +from svg_to_getdp.infrastructure.svg_processing.svg_parser import SvgParser +from svg_to_getdp.interfaces.abstractions.svg_parser_interface import SVGParserInterface + + +class SvgParserFactory: + """ + Factory for creating SVG parser instances. + + Follows the Factory pattern to decouple object creation from usage. + This allows for easy swapping of implementations and centralized configuration. + """ + + @staticmethod + def create_default() -> SVGParserInterface: + """ + Create a default SVG parser with standard settings. + + Returns: + SVGParserInterface: A parser instance with default parameters + """ + return SvgParser() diff --git a/sketchgetdp/svg_to_getdp/infrastructure/factories/wire_preprocessor_factory.py b/sketchgetdp/svg_to_getdp/infrastructure/factories/wire_preprocessor_factory.py new file mode 100644 index 0000000..31003d9 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/factories/wire_preprocessor_factory.py @@ -0,0 +1,26 @@ +""" +Factory for creating wire preprocessor instances. +Implements the Factory pattern for dependency injection. +""" + +from svg_to_getdp.infrastructure.wire_preprocessor import WirePreprocessor +from svg_to_getdp.interfaces.abstractions.wire_preprocessor_interface import WirePreprocessorInterface + + +class WirePreprocessorFactory: + """ + Factory for creating WirePreprocessor instances. + + Follows the Factory pattern to decouple object creation from usage. + """ + + @staticmethod + def create_default() -> WirePreprocessorInterface: + """ + Create a default WirePreprocessor with standard settings. + + Returns: + WirePreprocessorInterface: A preprocessor instance with default parameters + """ + return WirePreprocessor() + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/outline_grouper.py b/sketchgetdp/svg_to_getdp/infrastructure/outline_grouper.py new file mode 100644 index 0000000..ce2f5d8 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/outline_grouper.py @@ -0,0 +1,314 @@ +from typing import List, Dict, Tuple +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.physical_group import PhysicalGroup, DOMAIN_VA, DOMAIN_VI_IRON, DOMAIN_VI_AIR, BOUNDARY_GAMMA, BOUNDARY_OUT +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.interfaces.abstractions.outline_grouper_interface import OutlineGrouperInterface + +class OutlineGrouper(OutlineGrouperInterface): + """ + Groups outlines into hierarchical structure with containment relationships + and assigns physical groups based on containment logic. + """ + + @staticmethod + def group_outlines(outlines: List[Outline]) -> List[Dict]: + """ + Main function to group outlines and assign physical groups. + + Args: + outlines: List of outlines to process + + Returns: + List of dictionaries, one per outline, with keys: + - "holes": List of indices of outlines contained by this outline + - "physical_groups": List of PhysicalGroup objects for this outline + """ + if not outlines: + return [] + + # Get containment hierarchy + containment_map = OutlineGrouper.get_containment_hierarchy(outlines) + + # Find the outermost outline (contains all others but is not contained by any) + outermost_candidates = [] + for i in range(len(outlines)): + # Count how many other outlines contain this one + contained_by_count = sum(1 for j in range(len(outlines)) + if i != j and i in containment_map[j]) + + if contained_by_count == 0: + outermost_candidates.append(i) + + # If multiple outermost candidates, choose the one with largest bounding box AREA + if outermost_candidates: + # Calculate areas for all candidates + candidate_areas = [] + for idx in outermost_candidates: + min_x, max_x, min_y, max_y = OutlineGrouper.get_outline_bounding_box(outlines[idx]) + area = (max_x - min_x) * (max_y - min_y) + candidate_areas.append((idx, area)) + + # Find the index with largest area + outermost_idx = max(candidate_areas, key=lambda item: item[1])[0] + else: + raise ValueError("No outermost candidates found") + + # Classify all outlines + classifications = [OutlineGrouper.classify_outline_color(outline) + for outline in outlines] + + # Check which Va outlines are inside Vi outlines + va_in_vi_flags = [False] * len(outlines) + for i, (outline, classification) in enumerate(zip(outlines, classifications)): + if classification == "va": + # Check if this Va outline is inside any Vi outline + for j, (other_outline, other_classification) in enumerate(zip(outlines, classifications)): + if i != j and (other_classification == "vi_iron" or other_classification == "vi_air"): + if OutlineGrouper.is_outline_inside_other(outline, other_outline): + va_in_vi_flags[i] = True + break + + # Build result dictionaries + result = [] + for i, outline in enumerate(outlines): + is_outermost = (i == outermost_idx) + is_va_in_vi = va_in_vi_flags[i] + + # Get holes (contained outlines) + holes = containment_map.get(i, []) + + # Get physical groups + physical_groups = OutlineGrouper.get_physical_groups_for_outline( + classification=classifications[i], + is_outermost=is_outermost, + is_va_in_vi=is_va_in_vi + ) + + result.append({ + "holes": holes, + "physical_groups": physical_groups + }) + + return result + + @staticmethod + def is_point_inside_outline(point: Point, outline: Outline, num_samples: int = 1000) -> bool: + """ + Check if a point is inside a closed outline using ray casting algorithm. + + Args: + point: The point to test + outline: The closed outline + num_samples: Number of samples for outline approximation + + Returns: + True if point is inside the outline, False otherwise + """ + if not outline.is_closed: + return False + + # Sample points along the outline + outline_points = outline.get_outline_points(num_samples) + + # Count intersections with horizontal ray to the right + intersections = 0 + n = len(outline_points) + + for i in range(n): + p1 = outline_points[i] + p2 = outline_points[(i + 1) % n] + + # Check if point is on the edge (within tolerance) + # This helps with floating-point precision issues + if abs(p1.x - point.x) < 1e-10 and abs(p1.y - point.y) < 1e-10: + return False # Point is exactly on a vertex + + # Check if the segment is horizontal + if abs(p1.y - p2.y) < 1e-10: + # Horizontal edge - check if point is on this edge + if abs(p1.y - point.y) < 1e-10 and \ + min(p1.x, p2.x) <= point.x <= max(p1.x, p2.x): + return False # Point is on horizontal edge + continue # Horizontal edges don't affect ray-casting + + # Check if ray intersects the edge + # First check if point is between the y-values of the edge + if (p1.y > point.y) != (p2.y > point.y): + # Calculate x-intersection of the edge with the horizontal line through point + x_intersect = p1.x + (point.y - p1.y) * (p2.x - p1.x) / (p2.y - p1.y) + + # Check if intersection is to the right of the point + if x_intersect > point.x + 1e-10: # Add small tolerance + intersections += 1 + # If intersection is exactly at the point, point is on the edge + elif abs(x_intersect - point.x) < 1e-10: + return False + + return intersections % 2 == 1 + + @staticmethod + def get_outline_bounding_box(outline: Outline) -> Tuple[float, float, float, float]: + """ + Get the bounding box of an outline. + + Args: + outline: Outline with control points + + Returns: + Tuple of (min_x, max_x, min_y, max_y) + + Raises: + ValueError: If the outline has no control points + """ + control_points = outline.control_points + if not control_points: + raise ValueError(f"Outline must have at least one control point. Got {len(control_points)} points.") + + min_x = min(p.x for p in control_points) + max_x = max(p.x for p in control_points) + min_y = min(p.y for p in control_points) + max_y = max(p.y for p in control_points) + + return (min_x, max_x, min_y, max_y) + + @staticmethod + def is_outline_inside_other(outline: Outline, outer_outline: Outline) -> bool: + """ + Check if one outline is completely inside another. + + Args: + outline: The inner outline candidate + outer_outline: The potential outer outline + + Returns: + True if outline is completely inside outer_outline + """ + if not outline.is_closed or not outer_outline.is_closed: + return False + + # Quick bounding box test - inner outline must be completely within outer outline's bbox + inner_min_x, inner_max_x, inner_min_y, inner_max_y = OutlineGrouper.get_outline_bounding_box(outline) + outer_min_x, outer_max_x, outer_min_y, outer_max_y = OutlineGrouper.get_outline_bounding_box(outer_outline) + + if not (inner_min_x >= outer_min_x and inner_max_x <= outer_max_x and + inner_min_y >= outer_min_y and inner_max_y <= outer_max_y): + return False + + # Sample points from the inner outline and check if they're all inside outer outline + sample_points = outline.get_outline_points(num_points=10) + for point in sample_points: + if not OutlineGrouper.is_point_inside_outline(point, outer_outline): + return False + + return True + + @staticmethod + def get_containment_hierarchy(outlines: List[Outline]) -> Dict[int, List[int]]: + """ + Determine containment hierarchy among outlines. + + Args: + outlines: List of all outlines + + Returns: + Dictionary mapping outline index to list of indices of its immediate children + """ + n = len(outlines) + containment_map = {i: [] for i in range(n)} + + # Calculate outline areas (approximated by bounding box) + outline_areas = [] + for i, outline in enumerate(outlines): + min_x, max_x, min_y, max_y = OutlineGrouper.get_outline_bounding_box(outline) + area = (max_x - min_x) * (max_y - min_y) + outline_areas.append((i, area)) + + # Sort by area descending + outline_areas.sort(key=lambda x: x[1], reverse=True) + sorted_indices = [idx for idx, _ in outline_areas] + + # Check containment relationships - only assign immediate parents + for i in range(n): + outer_idx = sorted_indices[i] + for j in range(i + 1, n): + inner_idx = sorted_indices[j] + + # Check if inner is contained by outer + if OutlineGrouper.is_outline_inside_other( + outlines[inner_idx], + outlines[outer_idx] + ): + # Check if inner outline already has a parent in the sorted list + # (i.e., check if there's another outline between outer and inner in the sorted list) + has_closer_parent = False + for k in range(i + 1, j): + potential_parent_idx = sorted_indices[k] + if OutlineGrouper.is_outline_inside_other( + outlines[inner_idx], + outlines[potential_parent_idx] + ): + has_closer_parent = True + break + + if not has_closer_parent: + containment_map[outer_idx].append(inner_idx) + + return containment_map + + @staticmethod + def classify_outline_color(outline: Outline) -> str: + """ + Classify an outline based on its color. + + Args: + outline: Outline with color property + + Returns: + String classification: "va", "vi_iron", or "vi_air" + """ + if outline.color.name == "black": + return "va" + elif outline.color.name == "blue": + return "vi_iron" + elif outline.color.name == "green": + return "vi_air" + else: + raise ValueError(f"Unknown outline color: {outline.color.name}") + + @staticmethod + def get_physical_groups_for_outline(classification: str, + is_outermost: bool = False, + is_va_in_vi: bool = False) -> List[PhysicalGroup]: + """ + Get physical groups for an outline based on classification and context. + + Args: + outline: Outline + classification: Outline classification from classify_outline_color + is_outermost: Whether this is the outermost outline + is_va_in_vi: Whether this Va outline is inside a Vi outline + + Returns: + List of physical groups assigned to this outline + """ + physical_groups = [] + + # Assign domain physical group based on color/classification + if classification == "va": + if is_va_in_vi: + # Va boundary inside Vi gets gamma boundary + physical_groups.append(BOUNDARY_GAMMA) + physical_groups.append(DOMAIN_VA) + + elif classification == "vi_iron": + physical_groups.append(DOMAIN_VI_IRON) + + elif classification == "vi_air": + physical_groups.append(DOMAIN_VI_AIR) + + # Add boundary_out if this is the outermost outline + if is_outermost: + physical_groups.append(BOUNDARY_OUT) + + return physical_groups + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/outline_preprocessor.py b/sketchgetdp/svg_to_getdp/infrastructure/outline_preprocessor.py new file mode 100644 index 0000000..13e0359 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/outline_preprocessor.py @@ -0,0 +1,303 @@ +""" +Outline preprocessing module for Gmsh integration. +Converts Outline objects into Gmsh geometry with proper physical groups. +""" + +from typing import List, Dict, Any +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.physical_group import PhysicalGroup +from svg_to_getdp.interfaces.abstractions.outline_preprocessor_interface import OutlinePreprocessorInterface + +class OutlinePreprocessor(OutlinePreprocessorInterface): + """ + Preprocesses Outline objects in Gmsh with proper physical group assignment. + Handles both straight lines and 2nd order Bézier curves. + """ + + def __init__(self): + """ + Initialize the preprocessor. + """ + self._point_tags = {} # Maps Point objects to Gmsh point tags + self._curve_loops = {} # Maps outline indices to Gmsh curve loop tags + self._surface_tags = {} # Maps outline indices to Gmsh surface tags + self._created_points = {} # Tracks created points to avoid duplicates + self._curve_tags_per_outline = {} # Store curve tags per outline index + self._processing_order = [] # Store the order in which outlines were processed + + # Track physical groups by type + self._physical_groups_by_type = { + 'boundary': {}, # physical_group.value -> list of curve tags + 'domain': {} # physical_group.value -> list of surface tags + } + + def preprocess_outlines(self, + factory: Any, # Add factory parameter + outlines: List[Outline], + properties: List[Dict[str, Any]]) -> None: + """ + Preprocess all outlines with their properties. + Processes outlines from innermost to outermost to ensure + holes are created before the surfaces that contain them. + + Args: + outlines: List of Outline objects to preprocess + properties: List of dictionaries with "holes" and "physical_groups" keys + Each dictionary corresponds to the outline at the same index + """ + self.factory = factory + + if len(outlines) != len(properties): + raise ValueError( + f"Number of outlines ({len(outlines)}) " + f"must match number of property dictionaries ({len(properties)})" + ) + + # Determine processing order from innermost to outermost + self._processing_order = self._get_processing_order(outlines, properties) + + # Process outlines in topological order (inner to outer) + for idx in self._processing_order: + outline = outlines[idx] + props = properties[idx] + self._preprocess_single_outline(idx, outline, props) + + # Now collect all entities by physical group type + for idx in self._processing_order: + outline = outlines[idx] + props = properties[idx] + self._collect_physical_groups(idx, props) + + # After all curves and surfaces are created, assign physical groups + self._assign_physical_groups() + + def _get_processing_order(self, + outlines: List[Outline], + properties: List[Dict[str, Any]]) -> List[int]: + """ + Determine the processing order from innermost to outermost outlines. + + Args: + outlines: List of Outline objects + properties: List of property dictionaries + + Returns: + List of indices in processing order (innermost to outermost) + """ + n = len(outlines) + + # Build dependency graph: edge from hole to container + # If A is a hole in B, then A must be processed before B + adjacency = [[] for _ in range(n)] + + for i in range(n): + if "holes" in properties[i] and properties[i]["holes"]: + hole_indices = properties[i]["holes"] + if isinstance(hole_indices, list): + for hole_idx in hole_indices: + if 0 <= hole_idx < n: + # hole_idx must be processed before i + adjacency[hole_idx].append(i) + + # Calculate in-degree for Kahn's algorithm + in_degree = [0] * n + for i in range(n): + for neighbor in adjacency[i]: + in_degree[neighbor] += 1 + + # Start with nodes that have no dependencies (innermost) + queue = [i for i in range(n) if in_degree[i] == 0] + processing_order = [] + + while queue: + current = queue.pop(0) + processing_order.append(current) + + # For each outline that depends on this one (containers) + for neighbor in adjacency[current]: + in_degree[neighbor] -= 1 + if in_degree[neighbor] == 0: + queue.append(neighbor) + + if len(processing_order) != n: + # Cycle detected + print("Warning: Could not determine topological order. Using input order.") + return list(range(n)) + + return processing_order + + def _preprocess_single_outline(self, + idx: int, + outline: Outline, + properties: Dict[str, Any]) -> None: + """ + Preprocess a single outline. + + Steps: + 1. Draw points + 2. Draw lines (= boundary) + 3. Define curve loop (= potential hole in other domain) + 4. Define curve loop list (curve loop with holes!) + 5. Define plane surface (= surface) + """ + # Step 1: Create points for all unique control points + point_tags = [] + for point in outline.unique_control_points: + tag = self._create_or_get_point(point) + point_tags.append(tag) + + # Step 2: Create curves (lines and Bézier curves) + curve_tags = [] + segment_start_idx = 0 + + for segment in outline.bezier_segments: + # Check if segment is a straight line (degree 1 or collinear control points) + if segment.degree == 1: + # Straight line segment - use simple line + line_tag = self.factory.addLine( + point_tags[segment_start_idx], + point_tags[segment_start_idx + 1] + ) + curve_tags.append(line_tag) + segment_start_idx += 1 # Only move by 1 since degree 1 has 2 points + else: + # For Higher degree Bézier curve: degree + 1 points + segment_point_tags = point_tags[segment_start_idx:segment_start_idx + segment.degree + 1] + + # Create compound Bézier curve in Gmsh + bezier_tag = self.factory.addBezier(segment_point_tags) + curve_tags.append(bezier_tag) + segment_start_idx += segment.degree # Move by degree for next segment + + # Store curve tags + self._curve_tags_per_outline[idx] = curve_tags + + # Step 3: Define curve loop + curve_loop_tag = self.factory.addCurveLoop(curve_tags) + self._curve_loops[idx] = curve_loop_tag + + # Step 4: Create curve loop list (main loop + holes) + curve_loops_for_surface = [curve_loop_tag] + + if "holes" in properties and properties["holes"]: + hole_indices = properties["holes"] + if isinstance(hole_indices, list): + for hole_idx in hole_indices: + # The hole should already be created since we process inner to outer + if hole_idx in self._curve_loops: + curve_loops_for_surface.append(self._curve_loops[hole_idx]) + else: + raise ValueError( + f"Hole outline {hole_idx} referenced by " + f"outline {idx} has not been created yet. " + ) + + # Step 5: Define plane surface + surface_tag = self.factory.addPlaneSurface(curve_loops_for_surface) + self._surface_tags[idx] = surface_tag + + def _create_or_get_point(self, point: Point) -> int: + """ + Create a point in Gmsh or return existing tag if point already exists. + + Args: + point: Point object with x, y coordinates + + Returns: + Gmsh point tag + """ + for existing_point, tag in self._created_points.items(): + if existing_point == point: + return tag + + point_tag = self.factory.addPoint(point.x, point.y, 0.0) + self._created_points[point] = point_tag + return point_tag + + def _collect_physical_groups(self, + idx: int, + properties: Dict[str, Any]) -> None: + """ + Collect entities that belong to each physical group type. + + Args: + idx: Index of the outline + properties: Dictionary with "physical_groups" key + """ + if "physical_groups" not in properties: + return + + physical_groups = properties["physical_groups"] + + if not isinstance(physical_groups, list): + physical_groups = [physical_groups] + + for pg in physical_groups: + if not isinstance(pg, PhysicalGroup): + raise TypeError(f"Physical group must be PhysicalGroup instance, got {type(pg)}") + + if pg.is_boundary(): + # Collect curve tags for this boundary group + if idx in self._curve_tags_per_outline: + if pg.value not in self._physical_groups_by_type['boundary']: + self._physical_groups_by_type['boundary'][pg.value] = [] + self._physical_groups_by_type['boundary'][pg.value].extend( + self._curve_tags_per_outline[idx] + ) + + elif pg.is_domain(): + # Collect surface tag for this domain group + if idx in self._surface_tags: + if pg.value not in self._physical_groups_by_type['domain']: + self._physical_groups_by_type['domain'][pg.value] = [] + self._physical_groups_by_type['domain'][pg.value].append( + self._surface_tags[idx] + ) + + def _assign_physical_groups(self) -> None: + """ + Assign physical groups after all entities are collected. + Creates one physical group per type with all relevant entities. + """ + # Assign boundary groups (1D curves) + for physical_group_value, curve_tags in self._physical_groups_by_type['boundary'].items(): + if curve_tags: + # Remove duplicates while preserving order + unique_curve_tags = list(dict.fromkeys(curve_tags)) + self.factory.addPhysicalGroup(1, unique_curve_tags, physical_group_value) + print(f"Created boundary physical group (tag {physical_group_value}) " + f"with {len(unique_curve_tags)} curves") + + # Assign domain groups (2D surfaces) + for physical_group_value, surface_tags in self._physical_groups_by_type['domain'].items(): + if surface_tags: + # Remove duplicates while preserving order + unique_surface_tags = list(dict.fromkeys(surface_tags)) + self.factory.addPhysicalGroup(2, unique_surface_tags, physical_group_value) + print(f"Created domain physical group (tag {physical_group_value}) " + f"with {len(unique_surface_tags)} surfaces") + + def get_processing_order(self) -> List[int]: + """ + Get the order in which outlines were processed. + + Returns: + List of indices in processing order (innermost to outermost) + """ + return self._processing_order.copy() + + def get_curve_loop_tag(self, idx: int) -> int: + """ + Get the curve loop tag for an outline. + + Args: + idx: Index of the outline + + Returns: + Gmsh curve loop tag + """ + if idx not in self._curve_loops: + raise KeyError(f"No curve loop found for outline index {idx}") + return self._curve_loops[idx] + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/__init__.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/raw_outline_assembler.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/raw_outline_assembler.py new file mode 100644 index 0000000..328d2c2 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/raw_outline_assembler.py @@ -0,0 +1,43 @@ +""" +Assembles RawOutline objects from processed SVG data. +Handles creation and validation of RawOutline instances. +""" + +from typing import List +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.raw_outline import RawOutline + + +class RawOutlineAssembler: + """ + Assembles RawOutline objects from processed components. + Provides factory methods for creating validated RawOutline instances. + """ + + @staticmethod + def create_raw_outline(points: List[Point], color: Color, is_closed: bool = True) -> RawOutline: + """ + Create and validate a RawOutline instance. + """ + raw_outline = RawOutline(points=points, color=color, is_closed=is_closed) + + return raw_outline + + @staticmethod + def create_red_dot(point: Point) -> RawOutline: + """ + Create a RawOutline for a red dot (single point). + """ + return RawOutline(points=[point], color=Color.RED, is_closed=True) + + @staticmethod + def create_polyline(points: List[Point], color: Color, is_closed: bool = False) -> RawOutline: + """ + Create a RawOutline for a polyline (open or closed). + """ + if color != Color.RED and len(points) < 3: + raise ValueError(f"Polyline must have at least 3 points for color {color.name}, got {len(points)}") + + return RawOutline(points=points, color=color, is_closed=is_closed) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_color_classifier.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_color_classifier.py new file mode 100644 index 0000000..8e7d8a2 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_color_classifier.py @@ -0,0 +1,202 @@ +""" +Classifies SVG colors into the application's Color enum. +Handles extraction from attributes, style strings, and color parsing. +""" + +import re +import math +from typing import Dict +from svg_to_getdp.core.entities.color import Color + + +class SvgColorClassifier: + """ + Classifies colors from SVG attributes and strings. + Maps various color representations to the application's Color enum. + """ + + def extract_color_from_attributes(self, attributes: Dict) -> Color: + """ + Extract color from svgpathtools attributes dictionary. + """ + # Check stroke, fill, and style attributes + stroke = attributes.get('stroke') + fill = attributes.get('fill') + style = attributes.get('style') + + color_str = None + + # Priority: stroke -> fill -> style attribute + if stroke and stroke != 'none': + color_str = stroke + elif fill and fill != 'none': + color_str = fill + elif style: + # Parse style attribute + style_parts = [part.strip() for part in style.split(';')] + for part in style_parts: + if part.startswith('stroke:'): + color_parts = part.split(':', 1) + if len(color_parts) == 2: + potential_color = color_parts[1].strip() + if potential_color and potential_color != 'none': + color_str = potential_color + break + elif part.startswith('fill:'): + color_parts = part.split(':', 1) + if len(color_parts) == 2: + potential_color = color_parts[1].strip() + if potential_color and potential_color != 'none': + color_str = potential_color + break + + if not color_str or color_str == 'none': + raise ValueError(f"No valid color found in attributes: {attributes}") + + return self.parse_color_string(color_str) + + def extract_color_from_style(self, style_string: str) -> Color: + """ + Extract color from SVG style attribute. + """ + if not style_string: + raise ValueError("No style attribute found") + + # Parse style string + style_parts = [part.strip() for part in style_string.split(';')] + color_str = None + + for part in style_parts: + if part.startswith('fill:'): + color_parts = part.split(':', 1) + if len(color_parts) == 2: + color_str = color_parts[1].strip() + break + + if not color_str or color_str == 'none': + raise ValueError(f"No valid fill color found in style: {style_string}") + + return self.parse_color_string(color_str) + + def parse_color_string(self, color_string: str) -> Color: + """Convert color string to Color enum.""" + normalized_color = color_string.lower().strip() + + if self._is_red_color(normalized_color): + return Color.RED + elif self._is_green_color(normalized_color): + return Color.GREEN + elif self._is_blue_color(normalized_color): + return Color.BLUE + elif self._is_black_color(normalized_color): + return Color.BLACK + elif normalized_color.startswith('#'): + return self._convert_hex_to_primary_color(normalized_color) + elif normalized_color.startswith('rgb'): + return self._parse_rgb_color_string(normalized_color) + else: + return self._infer_color_from_name(normalized_color) + + def _is_red_color(self, color_string: str) -> bool: + """Check if color string represents a red color.""" + red_representations = { + '#ff0000', 'red', '#f00', '#ff0000ff', + 'rgb(255,0,0)', 'rgb(255, 0, 0)', + '#fa0000' + } + return color_string in red_representations + + def _is_green_color(self, color_string: str) -> bool: + """Check if color string represents a green color.""" + green_representations = { + '#00ff00', 'green', '#0f0', '#00ff00ff', + 'rgb(0,255,0)', 'rgb(0, 255, 0)', + '#00f700' + } + return color_string in green_representations + + def _is_blue_color(self, color_string: str) -> bool: + """Check if color string represents a blue color.""" + blue_representations = { + '#0000ff', 'blue', '#00f', '#0000ffff', + 'rgb(0,0,255)', 'rgb(0, 0, 255)', + '#0000fb' + } + return color_string in blue_representations + + def _is_black_color(self, color_string: str) -> bool: + """Check if color string represents a black color.""" + black_representations = { + '#000000', 'black', '#000', '#000000ff', + 'rgb(0,0,0)', 'rgb(0, 0, 0)' + } + return color_string in black_representations + + def _infer_color_from_name(self, color_name: str) -> Color: + """Infer color from color name containing color hint.""" + if 'red' in color_name: + return Color.RED + elif 'green' in color_name: + return Color.GREEN + elif 'blue' in color_name: + return Color.BLUE + else: + raise ValueError(f"Unknown color format: '{color_name}'") + + def _parse_rgb_color_string(self, rgb_string: str) -> Color: + """Parse RGB color string and find closest primary color.""" + match = re.match(r'rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*(?:,\s*[\d.]+\s*)?\)', rgb_string) + if not match: + raise ValueError(f"Invalid RGB color format: '{rgb_string}'") + + red, green, blue = map(int, match.groups()) + return self._find_closest_primary_color(red, green, blue) + + def _convert_hex_to_primary_color(self, hex_string: str) -> Color: + """Convert hex color to closest primary color.""" + hex_digits = hex_string.lstrip('#') + + try: + if len(hex_digits) == 6: + red = int(hex_digits[0:2], 16) + green = int(hex_digits[2:4], 16) + blue = int(hex_digits[4:6], 16) + elif len(hex_digits) == 3: + red = int(hex_digits[0] * 2, 16) + green = int(hex_digits[1] * 2, 16) + blue = int(hex_digits[2] * 2, 16) + else: + raise ValueError(f"Invalid hex color length: {len(hex_digits)}") + + return self._find_closest_primary_color(red, green, blue) + + except ValueError as e: + raise ValueError(f"Invalid hex color format '#{hex_digits}': {e}") + + def _find_closest_primary_color(self, red: int, green: int, blue: int) -> Color: + """Find the closest primary color using Euclidean distance in RGB space.""" + primary_colors = { + Color.RED: (255, 0, 0), + Color.GREEN: (0, 255, 0), + Color.BLUE: (0, 0, 255), + Color.BLACK: (0, 0, 0) + } + + min_distance = float('inf') + closest_color = None + + for color, (target_red, target_green, target_blue) in primary_colors.items(): + distance = math.sqrt( + (red - target_red)**2 + + (green - target_green)**2 + + (blue - target_blue)**2 + ) + if distance < min_distance: + min_distance = distance + closest_color = color + + if closest_color is None: + raise ValueError(f"Could not determine closest primary color for RGB({red},{green},{blue})") + + return closest_color + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_coordinate_converter.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_coordinate_converter.py new file mode 100644 index 0000000..833944f --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_coordinate_converter.py @@ -0,0 +1,65 @@ +""" +Converts SVG coordinates to the application's unit coordinate system. +Handles viewbox scaling, dimension fallbacks, and Y-axis flipping. +""" + +import re +from typing import Optional, Tuple +from svg_to_getdp.core.entities.point import Point + + +class SvgCoordinateConverter: + """ + Converts SVG coordinates to normalized unit coordinates [0,1]×[0,1]. + Handles viewbox parsing, dimension extraction, and coordinate transformation. + """ + + def scale_to_unit_coordinates(self, point: Point, + viewbox: Optional[Tuple[float, float, float, float]], + svg_width: float, svg_height: float) -> Point: + """ + Scale point to unit square [0,1]×[0,1] and flip Y-axis. + """ + if viewbox: + viewbox_x, viewbox_y, viewbox_width, viewbox_height = viewbox + if viewbox_width > 0 and viewbox_height > 0: + normalized_x = (point.x - viewbox_x) / viewbox_width + normalized_y = (point.y - viewbox_y) / viewbox_height + flipped_y = 1.0 - normalized_y + return Point(normalized_x, flipped_y) + + if svg_width > 0 and svg_height > 0: + normalized_x = point.x / svg_width + normalized_y = point.y / svg_height + flipped_y = 1.0 - normalized_y + return Point(normalized_x, flipped_y) + + # Fallback to default scaling + normalized_x = point.x / 100.0 + normalized_y = point.y / 100.0 + flipped_y = 1.0 - normalized_y + return Point(normalized_x, flipped_y) + + def parse_viewbox(self, viewbox_string: str) -> Optional[Tuple[float, float, float, float]]: + """Parse SVG viewBox attribute.""" + if not viewbox_string: + return None + + try: + coordinates = [float(coord) for coord in viewbox_string.split()] + return tuple(coordinates) if len(coordinates) == 4 else None + except ValueError: + return None + + def get_svg_dimensions(self, root_element) -> Tuple[float, float]: + """Extract SVG width and height as fallback for scaling.""" + try: + width_string = root_element.get('width', '100') + height_string = root_element.get('height', '100') + + width = float(re.sub(r'[^\d.]', '', width_string)) + height = float(re.sub(r'[^\d.]', '', height_string)) + return width, height + except (ValueError, TypeError): + return 100.0, 100.0 + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_parser.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_parser.py new file mode 100644 index 0000000..5c465df --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_parser.py @@ -0,0 +1,302 @@ +""" +Main SVG Parser orchestrator that coordinates the parsing pipeline. +Implements the SVGParserInterface and delegates to specialized processors. +""" + +import xml.etree.ElementTree as ET +from typing import Dict, List, Optional, Tuple + +from svgpathtools import svg2paths, Path + +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.raw_outline import RawOutline +from svg_to_getdp.interfaces.abstractions.svg_parser_interface import SVGParserInterface + +from svg_to_getdp.infrastructure.svg_processing.raw_outline_assembler import RawOutline +from svg_to_getdp.infrastructure.svg_processing.svg_color_classifier import SvgColorClassifier +from svg_to_getdp.infrastructure.svg_processing.svg_transform_applier import SvgTransformApplier +from svg_to_getdp.infrastructure.svg_processing.svg_path_refiner import SvgPathRefiner +from svg_to_getdp.infrastructure.svg_processing.svg_coordinate_converter import SvgCoordinateConverter + + +class SvgParser(SVGParserInterface): + """ + Main SVG parser that orchestrates the parsing pipeline. + Delegates specific responsibilities to specialized processors. + """ + + def __init__(self, samples_per_segment: int = 20, points_per_unit_length: int = 1000): + self.namespace = '{http://www.w3.org/2000/svg}' + self.samples_per_segment = samples_per_segment + self.points_per_unit_length = points_per_unit_length + + # Initialize specialized processors + self.color_classifier = SvgColorClassifier() + self.transform_applier = SvgTransformApplier() + self.path_refiner = SvgPathRefiner(points_per_unit_length) + self.coordinate_converter = SvgCoordinateConverter() + + def extract_raw_outlines_by_color(self, svg_file_path: str) -> Dict[Color, List[RawOutline]]: + """ + Parse SVG file and extract raw_outlines grouped by color. + + Strategy: + 1. Use svg2paths for all non-red paths (green, blue, black) + 2. Parse circle/ellipse elements directly from XML for red structures + """ + try: + # Parse the XML tree to access all elements + tree = ET.parse(svg_file_path) + root = tree.getroot() + + # Parse paths with svgpathtools + paths, attributes = svg2paths(svg_file_path) + + except Exception as e: + raise ValueError(f"Invalid SVG file: {e}") + + viewbox = self._parse_viewbox(root.get('viewBox')) + svg_width, svg_height = self._get_svg_dimensions(root) + + # Parse paths from svgpathtools + # Skip red paths here - handled separately + path_raw_outlines = self._convert_paths_to_raw_outlines( + paths, attributes, viewbox, svg_width, svg_height + ) + + red_dots_raw_outlines = self._extract_red_dots_from_xml( + root, viewbox, svg_width, svg_height + ) + + # Merge both results - path raw_outlines (green, blue, black) and red dots + raw_outlines_by_color = self._merge_raw_outlines(path_raw_outlines, red_dots_raw_outlines) + + # Apply post-processing resampling to ensure even point distribution + resampled_raw_outlines = self.path_refiner.resample_all_raw_outlines(raw_outlines_by_color) + + # Remove duplicate points from all raw_outlines after resampling + clean_raw_outlines = self.path_refiner.remove_duplicates_from_all_raw_outlines(resampled_raw_outlines) + + # Merge nearby raw_outlines of the same color + merged_raw_outlines = self.path_refiner.merge_nearby_raw_outlines( + clean_raw_outlines, distance_threshold=0.02 + ) + + return merged_raw_outlines + + def _convert_paths_to_raw_outlines(self, paths: List[Path], attributes: List[dict], + viewbox: Optional[Tuple[float, float, float, float]], + svg_width: float, svg_height: float) -> Dict[Color, List[RawOutline]]: + """ + Convert all SVG paths to raw_outline objects grouped by color. Red paths are skipped here. + """ + raw_outlines_by_color = {} + + for path_index, (path, attr) in enumerate(zip(paths, attributes)): + try: + color = self.color_classifier.extract_color_from_attributes(attr) + + # SKIP RED PATHS - these are typically converted circles/ellipses + # that we'll handle separately via XML parsing for more flexibility + if color == Color.RED: + continue + + # Convert path to points + points = self._convert_path_to_points(path, viewbox, svg_width, svg_height) + + if not points: + raise ValueError("Path contains no valid points") + + # Check if path is closed + is_closed = self._is_path_closed(path) + + # Create RawOutline using the assembler (to be implemented separately) + raw_outline = RawOutline( + points=points, + color=color, + is_closed=is_closed + ) + + if raw_outline.color not in raw_outlines_by_color: + raw_outlines_by_color[raw_outline.color] = [] + raw_outlines_by_color[raw_outline.color].append(raw_outline) + + except Exception as e: + print(f"WARNING: Failed to process path {path_index}: {e}") + continue + + return raw_outlines_by_color + + def _extract_red_dots_from_xml(self, root: ET.Element, + viewbox: Optional[Tuple[float, float, float, float]], + svg_width: float, svg_height: float) -> Dict[Color, List[RawOutline]]: + """ + Extract red dots (circles and ellipses) directly from XML. + """ + red_dots_raw_outlines = {} + + # Find all circle and ellipse elements + for element_name in ['circle', 'ellipse']: + for elem in root.iter(f'{self.namespace}{element_name}'): + try: + color = self._extract_color_from_xml_element(elem) + + # Only process red circles/ellipses - skip other colors + if color != Color.RED: + continue + + # Get center coordinates + cx = float(elem.get('cx', '0')) + cy = float(elem.get('cy', '0')) + + # Apply transform if present + transform = elem.get('transform', '') + if transform: + transformed_point = self.transform_applier.apply_transform_to_point(cx, cy, transform) + cx, cy = transformed_point + + # Scale to unit coordinates + point = Point(cx, cy) + scaled_point = self.coordinate_converter.scale_to_unit_coordinates( + point, viewbox, svg_width, svg_height + ) + + # For red dots, we just want the center point + raw_outline = RawOutline( + points=[scaled_point], + color=color, + is_closed=True + ) + + if color not in red_dots_raw_outlines: + red_dots_raw_outlines[color] = [] + red_dots_raw_outlines[color].append(raw_outline) + + except Exception as e: + print(f"WARNING: Failed to process {element_name} element: {e}") + continue + + return red_dots_raw_outlines + + def _extract_color_from_xml_element(self, elem: ET.Element) -> Color: + """ + Extract color from XML element attributes. + """ + # Get color from multiple possible attributes + style = elem.get('style', '') + stroke = elem.get('stroke', '') + fill = elem.get('fill', '') + + color = None + + # Try to extract color from different sources + # Priority: stroke attribute -> fill attribute -> style attribute + if stroke and stroke != 'none': + color = self.color_classifier.parse_color_string(stroke) + elif fill and fill != 'none': + color = self.color_classifier.parse_color_string(fill) + elif style: + color = self.color_classifier.extract_color_from_style(style) + + if not color: + raise ValueError(f"No valid color found for element") + + return color + + def _convert_path_to_points(self, path: Path, viewbox: Optional[Tuple[float, float, float, float]], + svg_width: float, svg_height: float) -> List[Point]: + """ + Convert svgpathtools Path object to list of scaled points. + """ + points = [] + + for segment in path: + segment_points = self._sample_segment_points(segment, self.samples_per_segment) + points.extend(segment_points) + + points = self.path_refiner.remove_consecutive_duplicate_points(points) + return [ + self.coordinate_converter.scale_to_unit_coordinates(p, viewbox, svg_width, svg_height) + for p in points + ] + + def _sample_segment_points(self, segment, samples_per_segment: int) -> List[Point]: + """ + Sample multiple points from a path segment. + """ + from svgpathtools import Line, CubicBezier, QuadraticBezier, Arc + + points = [] + + if isinstance(segment, (Line, CubicBezier, QuadraticBezier, Arc)): + for sample_index in range(samples_per_segment + 1): + parameter = sample_index / samples_per_segment + try: + complex_point = segment.point(parameter) + points.append(Point(complex_point.real, complex_point.imag)) + except Exception as e: + print(f"WARNING: Failed to sample segment at parameter={parameter}: {e}") + continue + + return points + + def _is_path_closed(self, path: Path) -> bool: + """ + Determine if a path forms a closed shape. + """ + if len(path) == 0: + return False + + try: + start_point = path[0].point(0) + end_point = path[-1].point(1) + + tolerance = 1e-6 + distance = abs(start_point - end_point) + return distance < tolerance + except: + return False + + def _merge_raw_outlines(self, raw_outlines1: Dict[Color, List[RawOutline]], + raw_outlines2: Dict[Color, List[RawOutline]]) -> Dict[Color, List[RawOutline]]: + """ + Merge two dictionaries of raw_outlines. + """ + merged = {} + all_colors = set(raw_outlines1.keys()) | set(raw_outlines2.keys()) + + for color in all_colors: + merged[color] = [] + if color in raw_outlines1: + merged[color].extend(raw_outlines1[color]) + if color in raw_outlines2: + merged[color].extend(raw_outlines2[color]) + + return merged + + def _parse_viewbox(self, viewbox_string: str) -> Optional[Tuple[float, float, float, float]]: + """Parse SVG viewBox attribute.""" + if not viewbox_string: + return None + + try: + coordinates = [float(coord) for coord in viewbox_string.split()] + return tuple(coordinates) if len(coordinates) == 4 else None + except ValueError: + return None + + def _get_svg_dimensions(self, root_element: ET.Element) -> Tuple[float, float]: + """Extract SVG width and height as fallback for scaling.""" + import re + + try: + width_string = root_element.get('width', '100') + height_string = root_element.get('height', '100') + + width = float(re.sub(r'[^\d.]', '', width_string)) + height = float(re.sub(r'[^\d.]', '', height_string)) + return width, height + except (ValueError, TypeError): + return 100.0, 100.0 + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_path_refiner.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_path_refiner.py new file mode 100644 index 0000000..bc29cab --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_path_refiner.py @@ -0,0 +1,277 @@ +""" +Refines SVG paths by resampling, merging, and cleaning point sequences. +Post-processes SVG geometry for optimal representation. +""" + +import math +from typing import Dict, List + +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + +from svg_to_getdp.infrastructure.svg_processing.raw_outline_assembler import RawOutline + + +class SvgPathRefiner: + """ + Refines SVG paths through resampling, deduplication, and merging operations. + """ + + def __init__(self, points_per_unit_length: int = 1000): + self.points_per_unit_length = points_per_unit_length + + def resample_all_raw_outlines(self, raw_outlines_by_color: Dict[Color, List[RawOutline]]) -> Dict[Color, List[RawOutline]]: + """ + Apply uniform resampling to all raw_outlines except red dots. + """ + resampled_raw_outlines = {} + + for color, raw_outlines in raw_outlines_by_color.items(): + resampled_raw_outlines[color] = [] + for raw_outline in raw_outlines: + if color == Color.RED: + # Don't resample red dots (single points) + resampled_raw_outlines[color].append(raw_outline) + else: + # Resample polylines for even point distribution + resampled_points = self._resample_polyline_uniform(raw_outline.points) + resampled_raw_outline = RawOutline( + points=resampled_points, + color=raw_outline.color, + is_closed=raw_outline.is_closed + ) + resampled_raw_outlines[color].append(resampled_raw_outline) + + return resampled_raw_outlines + + def _resample_polyline_uniform(self, points: List[Point]) -> List[Point]: + """ + Resample polyline to have evenly spaced points. + + Args: + points: Original unevenly distributed points + + Returns: + List of evenly spaced points + """ + if len(points) < 2: + return points + + # Calculate total length and segment lengths + total_length = 0.0 + segment_lengths = [] + for i in range(len(points) - 1): + segment_length = math.sqrt( + (points[i+1].x - points[i].x)**2 + + (points[i+1].y - points[i].y)**2 + ) + segment_lengths.append(segment_length) + total_length += segment_length + + if total_length <= 0: + return points + + spacing = 1.0 / self.points_per_unit_length + + # Calculate how many points we need for each segment + resampled_points = [points[0]] + + for segment_idx in range(len(segment_lengths)): + segment_length = segment_lengths[segment_idx] + segment_start = points[segment_idx] + segment_end = points[segment_idx + 1] + + # Calculate how many points to place on this segment (excluding the start point) + num_points_on_segment = max(1, int(segment_length / spacing)) + actual_spacing = segment_length / num_points_on_segment + + # Add points along this segment + for i in range(1, num_points_on_segment): + t = i * actual_spacing / segment_length + new_x = segment_start.x + t * (segment_end.x - segment_start.x) + new_y = segment_start.y + t * (segment_end.y - segment_start.y) + resampled_points.append(Point(new_x, new_y)) + + # Add the segment end point (unless it's the very last point of the polyline) + if segment_idx < len(segment_lengths) - 1: + resampled_points.append(segment_end) + + # Always include the very last point of the polyline + if resampled_points[-1] != points[-1]: + resampled_points.append(points[-1]) + + return resampled_points + + def remove_consecutive_duplicate_points(self, points: List[Point]) -> List[Point]: + """Remove consecutive duplicate points while preserving order.""" + if not points: + return points + + unique_points = [points[0]] + for current_point in points[1:]: + if current_point != unique_points[-1]: + unique_points.append(current_point) + + return unique_points + + def _remove_duplicate_end_point(self, points: List[Point]) -> List[Point]: + """Remove closing duplicate point for closed paths.""" + if not points: + return points + + # Check if path is closed (first and last points are the same) + if len(points) > 1 and points[0] == points[-1]: + # Remove the last point since it's a duplicate of the first + points = points[:-1] + + return points + + def remove_duplicates_from_all_raw_outlines(self, raw_outlines_by_color: Dict[Color, List[RawOutline]]) -> Dict[Color, List[RawOutline]]: + """ + Remove duplicate points from all raw_outlines after resampling. + """ + cleaned_raw_outlines = {} + + for color, raw_outlines in raw_outlines_by_color.items(): + cleaned_raw_outlines[color] = [] + for raw_outline in raw_outlines: + if color == Color.RED: + # For red dots (single points), no need to remove duplicates + cleaned_raw_outlines[color].append(raw_outline) + else: + # Remove duplicate points from polyline raw_outlines + no_consecutive_duplicate_points = self.remove_consecutive_duplicate_points(raw_outline.points) + cleaned_points = self._remove_duplicate_end_point(no_consecutive_duplicate_points) + cleaned_raw_outline = RawOutline( + points=cleaned_points, + color=raw_outline.color, + is_closed=raw_outline.is_closed + ) + cleaned_raw_outlines[color].append(cleaned_raw_outline) + + return cleaned_raw_outlines + + def merge_nearby_raw_outlines(self, raw_outlines_by_color: Dict[Color, List[RawOutline]], + distance_threshold: float = 0.02) -> Dict[Color, List[RawOutline]]: + """ + Merge raw_outlines of the same color that are close to each other and not already closed. + + Args: + raw_outlines_by_color: Dictionary of raw_outlines grouped by color + distance_threshold: Maximum distance between endpoints to consider for merging (in unit coordinates) + + Returns: + Dictionary with merged raw_outlines + """ + merged_raw_outlines = {} + for color, raw_outlines in raw_outlines_by_color.items(): + if color == Color.RED: + # Don't merge red dots (they're single points) + merged_raw_outlines[color] = raw_outlines + continue + + # Skip if only one raw_outline or all raw_outline are already closed + if len(raw_outlines) <= 1 or all(o.is_closed for o in raw_outlines): + merged_raw_outlines[color] = raw_outlines + continue + + # Create a list of open raw_outlines to process + open_raw_outlines = [o for o in raw_outlines if not o.is_closed] + closed_raw_outlines = [o for o in raw_outlines if o.is_closed] + + # Try to merge open raw_outlines + merged = self._merge_open_raw_outlines(open_raw_outlines, distance_threshold) + + # Combine merged raw_outlines with closed ones + merged_raw_outlines[color] = closed_raw_outlines + merged + + return merged_raw_outlines + + def _merge_open_raw_outlines(self, open_raw_outlines: List[RawOutline], + distance_threshold: float) -> List[RawOutline]: + """ + Merge open raw_outlines by connecting endpoints that are close together. + """ + if not open_raw_outlines: + return [] + + merged_raw_outlines = [] + processed = [False] * len(open_raw_outlines) + + for i, raw_outline in enumerate(open_raw_outlines): + if processed[i]: + continue + + # Start a new merged raw_outline with this one + current_points = raw_outline.points.copy() + start_point = current_points[0] + end_point = current_points[-1] + + processed[i] = True + merged_with_something = True + + # Keep trying to merge until no more merges are possible + while merged_with_something: + merged_with_something = False + + for j, other_raw_outline in enumerate(open_raw_outlines): + if processed[j]: + continue + + other_start = other_raw_outline.points[0] + other_end = other_raw_outline.points[-1] + + # Check for possible connections + start_to_start = self._distance_between_points(start_point, other_start) + start_to_end = self._distance_between_points(start_point, other_end) + end_to_start = self._distance_between_points(end_point, other_start) + end_to_end = self._distance_between_points(end_point, other_end) + + min_distance = min(start_to_start, start_to_end, end_to_start, end_to_end) + + if min_distance <= distance_threshold: + # Merge the raw_outlines + if min_distance == start_to_start: + # Reverse other raw_outline and prepend to current + other_points_reversed = other_raw_outline.points[::-1] + current_points = other_points_reversed + current_points[1:] + start_point = other_end # After reversal, start becomes end + elif min_distance == start_to_end: + # Prepend other raw_outline to current + current_points = other_raw_outline.points[:-1] + current_points + start_point = other_start + elif min_distance == end_to_start: + # Append other raw_outline to current + current_points = current_points[:-1] + other_raw_outline.points + end_point = other_end + elif min_distance == end_to_end: + # Reverse other raw_outline and append to current + other_points_reversed = other_raw_outline.points[::-1] + current_points = current_points[:-1] + other_points_reversed + end_point = other_start # After reversal, end becomes start + + processed[j] = True + merged_with_something = True + break + + # Check if the merged raw_outline is now closed + is_closed = self._distance_between_points(start_point, end_point) <= distance_threshold + + if is_closed: + # Ensure proper closure + if self._distance_between_points(current_points[0], current_points[-1]) > distance_threshold: + current_points.append(current_points[0]) + + merged_raw_outline = RawOutline( + points=current_points, + color=raw_outline.color, + is_closed=is_closed + ) + merged_raw_outlines.append(merged_raw_outline) + + return merged_raw_outlines + + def _distance_between_points(self, p1: Point, p2: Point) -> float: + """Calculate Euclidean distance between two points.""" + return math.sqrt((p2.x - p1.x)**2 + (p2.y - p1.y)**2) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_transform_applier.py b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_transform_applier.py new file mode 100644 index 0000000..c4b7c6b --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/svg_processing/svg_transform_applier.py @@ -0,0 +1,88 @@ +""" +Applies SVG transforms to points and coordinates. +Handles matrix, rotate, scale, and translate transformations. +""" + +import re +import math +from typing import Tuple + + +class SvgTransformApplier: + """ + Applies SVG transform operations to points. + Supports matrix(), rotate(), scale(), and translate() transforms. + """ + + def apply_transform_to_point(self, x: float, y: float, transform_str: str) -> Tuple[float, float]: + """ + Apply SVG transform to a point. + Handles matrix(), rotate(), scale(), and translate() transforms. + """ + if not transform_str: + return x, y + + # Parse matrix transform: matrix(a,b,c,d,e,f) + matrix_match = re.match( + r'matrix\s*\(\s*([-\d.]+)\s*,\s*([-\d.]+)\s*,\s*([-\d.]+)\s*,\s*([-\d.]+)\s*,\s*([-\d.]+)\s*,\s*([-\d.]+)\s*\)', + transform_str + ) + + if matrix_match: + a, b, c, d, e, f = map(float, matrix_match.groups()) + # Apply matrix transformation + new_x = a * x + c * y + e + new_y = b * x + d * y + f + return new_x, new_y + + # Parse rotate transform: rotate(angle, cx, cy) or rotate(angle) + rotate_match = re.match( + r'rotate\s*\(\s*([-\d.]+)\s*(?:,\s*([-\d.]+)\s*,\s*([-\d.]+)\s*)?\)', + transform_str + ) + + if rotate_match: + angle = float(rotate_match.group(1)) + # Convert to radians + angle_rad = math.radians(angle) + + if rotate_match.group(2) and rotate_match.group(3): + # Has center point + cx = float(rotate_match.group(2)) + cy = float(rotate_match.group(3)) + # Translate to origin, rotate, translate back + x_translated = x - cx + y_translated = y - cy + new_x = x_translated * math.cos(angle_rad) - y_translated * math.sin(angle_rad) + cx + new_y = x_translated * math.sin(angle_rad) + y_translated * math.cos(angle_rad) + cy + else: + # No center point, rotate around origin (0,0) + new_x = x * math.cos(angle_rad) - y * math.sin(angle_rad) + new_y = x * math.sin(angle_rad) + y * math.cos(angle_rad) + + return new_x, new_y + + # Handle translate transforms + translate_match = re.match( + r'translate\s*\(\s*([-\d.]+)\s*(?:,\s*([-\d.]+)\s*)?\)', + transform_str + ) + if translate_match: + tx = float(translate_match.group(1)) + ty = float(translate_match.group(2)) if translate_match.group(2) else 0 + return x + tx, y + ty + + # Handle scale transforms: scale(sx, sy) or scale(s) + scale_match = re.match( + r'scale\s*\(\s*([-\d.]+)\s*(?:,\s*([-\d.]+)\s*)?\)', + transform_str + ) + if scale_match: + sx = float(scale_match.group(1)) + sy = float(scale_match.group(2)) if scale_match.group(2) else sx + return x * sx, y * sy + + # Return original point if transform not recognized + print(f"WARNING: Unsupported transform format: {transform_str}") + return x, y + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/infrastructure/wire_preprocessor.py b/sketchgetdp/svg_to_getdp/infrastructure/wire_preprocessor.py new file mode 100644 index 0000000..cab86d6 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/infrastructure/wire_preprocessor.py @@ -0,0 +1,295 @@ +import yaml +import math +from typing import List, Tuple, Any +from dataclasses import dataclass +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.physical_group import ( + DOMAIN_COIL_POSITIVE, + DOMAIN_COIL_NEGATIVE +) +from svg_to_getdp.interfaces.abstractions.wire_preprocessor_interface import WirePreprocessorInterface + + +@dataclass +class Wire: + """Represents a single wire.""" + point: Point + color: Color + original_index: int + + +@dataclass +class WireCluster: + """Represents a cluster of wires that are close to each other.""" + name: str # e.g., "cluster_1" + wire_count: int + current_sign: int # 1 for positive, -1 for negative + wires: List[Wire] = None + + def __post_init__(self): + if self.wires is None: + self.wires = [] + + +class WirePreprocessor(WirePreprocessorInterface): + """ + Preprocessor for wires that clusters them by proximity and creates Gmsh entities. + """ + + def __init__(self): + self.factory = None + self.wire_clusters: List[WireCluster] = [] + self.all_wires: List[Wire] = [] + + def prepare_wires(self, + factory: Any, + config_path: str, + wires: List[Tuple[Point, Color]]) -> dict: + """ + Prepare Gmsh entities for wires with physical groups using cluster configuration. + + Args: + factory: Gmsh factory object + config_path: Path to the YAML configuration file + wires: List of (point, color) tuples representing wires + + Returns: + Dictionary mapping wire indices to their Gmsh tags and physical groups + """ + self.factory = factory + # Load wire cluster configuration + self.wire_clusters = self._load_wire_clusters(config_path) + + if not wires: + print("Warning: No wires provided") + return {} + + # Convert given wires to Wire objects + self.all_wires = [Wire(point=p, color=c, original_index=i) + for i, (p, c) in enumerate(wires)] + + # First, sort all wires from top to bottom and left to right + sorted_wires = self._sort_wires(self.all_wires) + + # Validate total wire count matches cluster configuration + total_cluster_wires = sum(cluster.wire_count for cluster in self.wire_clusters) + if total_cluster_wires != len(sorted_wires): + raise ValueError( + f"Number of wires ({len(sorted_wires)}) doesn't match cluster configuration " + f"({total_cluster_wires} wires defined in {len(self.wire_clusters)} clusters)" + ) + + # Now cluster the wires based on proximity + self._perform_clustering(sorted_wires) + + # Create Gmsh entities and collect results + positive_point_tags = [] + negative_point_tags = [] + results = {} + + for cluster_idx, cluster in enumerate(self.wire_clusters): + for wire_idx_in_cluster, wire in enumerate(cluster.wires): + # Create Gmsh point entity + point_tag = self.factory.addPoint(wire.point.x, wire.point.y, 0.0) + + # Get physical group based on cluster + physical_group = self._get_physical_group_for_cluster(cluster) + + # Store point tag based on polarity + if physical_group == DOMAIN_COIL_POSITIVE: + positive_point_tags.append(point_tag) + elif physical_group == DOMAIN_COIL_NEGATIVE: + negative_point_tags.append(point_tag) + else: + raise ValueError(f"Unknown physical group type: {physical_group}") + + # Store results + results[wire.original_index] = { + 'point': wire.point, + 'color': wire.color, + 'gmsh_point_tag': point_tag, + 'physical_group': physical_group, + 'wire_index': wire.original_index, + 'wire_name': f"wire_{wire.original_index + 1}", + 'cluster_name': cluster.name, + 'wire_in_cluster_index': wire_idx_in_cluster, + 'cluster_index': cluster_idx + } + + # Create ONE physical group for all positive points + if positive_point_tags: + self.factory.addPhysicalGroup(0, positive_point_tags, DOMAIN_COIL_POSITIVE.value) + print(f"Created positive wire physical group (tag {DOMAIN_COIL_POSITIVE.value}) " + f"with {len(positive_point_tags)} wires") + + # Create ONE physical group for all negative points + if negative_point_tags: + self.factory.addPhysicalGroup(0, negative_point_tags, DOMAIN_COIL_NEGATIVE.value) + print(f"Created negative wire physical group (tag {DOMAIN_COIL_NEGATIVE.value}) " + f"with {len(negative_point_tags)} wires") + + # Print summary + print(f"Total wires processed: {len(wires)}") + print(f" Positive: {len(positive_point_tags)}") + print(f" Negative: {len(negative_point_tags)}") + print(f" Clusters: {len(self.wire_clusters)}") + + return results + + def _load_wire_clusters(self, config_path: str) -> List[WireCluster]: + """ + Load wire cluster configuration from the YAML configuration file. + + Args: + config_path: Path to the configuration file + + Returns: + List of WireCluster objects sorted by cluster name + """ + try: + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + + if 'wire_clusters' not in config: + raise ValueError("Config file must contain 'wire_clusters' section") + + wire_clusters_config = config['wire_clusters'] + + if not isinstance(wire_clusters_config, dict): + raise ValueError("'wire_clusters' must be a dictionary") + + # Create clusters from configuration + clusters = [] + for cluster_name, cluster_config in wire_clusters_config.items(): + if not isinstance(cluster_config, dict): + raise ValueError(f"Cluster '{cluster_name}' configuration must be a dictionary") + + if 'wire_count' not in cluster_config: + raise ValueError(f"Cluster '{cluster_name}' must have 'wire_count'") + + if 'current_sign' not in cluster_config: + raise ValueError(f"Cluster '{cluster_name}' must have 'current_sign'") + + wire_count = cluster_config['wire_count'] + current_sign = cluster_config['current_sign'] + + # Validate current_sign + if current_sign not in [1, -1]: + raise ValueError(f"Cluster '{cluster_name}': current_sign must be 1 or -1, got {current_sign}") + + # Validate wire_count + if not isinstance(wire_count, int) or wire_count <= 0: + raise ValueError(f"Cluster '{cluster_name}': wire_count must be a positive integer, got {wire_count}") + + clusters.append(WireCluster( + name=cluster_name, + wire_count=wire_count, + current_sign=current_sign + )) + + # Sort clusters by name to ensure consistent ordering + clusters.sort(key=lambda c: c.name) + + if not clusters: + raise ValueError("No wire clusters defined in configuration") + + return clusters + + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found: {config_path}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in configuration file: {e}") + + def _sort_wires(self, wires: List[Wire]) -> List[Wire]: + """ + Sort wires from top to bottom and left to right. + + Args: + wires: List of Wire objects + + Returns: + Sorted list of wires + """ + return sorted(wires, key=lambda w: (-w.point.y, w.point.x)) + + def _calculate_distance(self, wire1: Wire, wire2: Wire) -> float: + """ + Calculate Euclidean distance between two wires. + + Args: + wire1: First wire + wire2: Second wire + + Returns: + Distance between wires + """ + dx = wire1.point.x - wire2.point.x + dy = wire1.point.y - wire2.point.y + return math.sqrt(dx*dx + dy*dy) + + def _perform_clustering(self, sorted_wires: List[Wire]): + """ + Perform proximity-based clustering of wires. + + Args: + sorted_wires: All wires sorted from top to bottom, left to right + """ + available_wires = sorted_wires.copy() + + for cluster in self.wire_clusters: + # Clear any existing wires in cluster + cluster.wires.clear() + + if not available_wires: + raise ValueError(f"Not enough wires for cluster {cluster.name}. " + f"Need {cluster.wire_count}, but no wires left.") + + # Start with the first available wire as seed + seed_wire = available_wires[0] + cluster.wires.append(seed_wire) + available_wires.remove(seed_wire) + + # Find remaining wires for this cluster + while len(cluster.wires) < cluster.wire_count: + if not available_wires: + raise ValueError(f"Not enough wires for cluster {cluster.name}. " + f"Need {cluster.wire_count}, but only have {len(cluster.wires)}.") + + # Find the closest wire to any wire already in the cluster + closest_wire = None + min_distance = float('inf') + + for cluster_wire in cluster.wires: + for candidate_wire in available_wires: + distance = self._calculate_distance(cluster_wire, candidate_wire) + if distance < min_distance: + min_distance = distance + closest_wire = candidate_wire + + if closest_wire is None: + raise ValueError(f"Cannot find wire close enough for cluster {cluster.name}") + + cluster.wires.append(closest_wire) + available_wires.remove(closest_wire) + + # Sort wires within cluster for consistent ordering + cluster.wires.sort(key=lambda w: (-w.point.y, w.point.x)) + + def _get_physical_group_for_cluster(self, cluster: WireCluster): + """ + Get the appropriate physical group for a cluster. + + Args: + cluster: WireCluster object + + Returns: + Appropriate PhysicalGroup instance + """ + if cluster.current_sign == 1: + return DOMAIN_COIL_POSITIVE + elif cluster.current_sign == -1: + return DOMAIN_COIL_NEGATIVE + else: + raise ValueError(f"Invalid current sign {cluster.current_sign} for cluster {cluster.name}") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/__init__.py b/sketchgetdp/svg_to_getdp/interfaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/__init__.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/bezier_fitter_interface.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/bezier_fitter_interface.py new file mode 100644 index 0000000..4728742 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/abstractions/bezier_fitter_interface.py @@ -0,0 +1,33 @@ +""" +Interface for Bézier curve fitting operations. +Defines the contract for fitting Bézier curves to boundary point data. +""" + +from abc import ABC, abstractmethod +from typing import List +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.outline import Outline + +class BezierFitterInterface(ABC): + """ + Defines the interface for fitting piecewise Bézier curves. + Implementations should handle corner detection, continuity enforcement, and curve optimization. + """ + + @abstractmethod + def fit_outline(self, points: List[Point], corner_indices: List[int], + color, is_closed: bool = True) -> Outline: + """ + Fit piecewise Bézier curves with optimized continuity and accuracy. + + Args: + points: List of outline points to fit curves to + corner_indices: Indices of points that represent sharp corners + color: Visual color representation for the outline + is_closed: Whether the outline forms a closed loop + + Returns: + Outline object containing fitted Bézier segments and corner information + """ + pass + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/corner_detector_interface.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/corner_detector_interface.py new file mode 100644 index 0000000..dae5e9f --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/abstractions/corner_detector_interface.py @@ -0,0 +1,19 @@ +""" +Interface for corner detection operations. +""" + +from abc import ABC, abstractmethod +from typing import List +from svg_to_getdp.core.entities.point import Point + +class CornerDetectorInterface(ABC): + """ + Abstract interface for corner detection. + """ + + @abstractmethod + def detect_corners(self, boundary_points: List[Point]) -> List[int]: + """ + Identifies indices of corner points in the boundary point sequence. + """ + pass \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/outline_grouper_interface.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/outline_grouper_interface.py new file mode 100644 index 0000000..bb1894c --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/abstractions/outline_grouper_interface.py @@ -0,0 +1,30 @@ +""" +Interface for grouping outlines into hierarchical structures. +Defines the contract for analyzing containment relationships and assigning physical groups. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict +from svg_to_getdp.core.entities.outline import Outline + +class OutlineGrouperInterface(ABC): + """ + Defines the interface for grouping outlines into hierarchical structures + with containment relationships and assigning appropriate physical groups. + """ + + @abstractmethod + def group_outlines(self, outlines: List[Outline]) -> List[Dict]: + """ + Group outlines into hierarchical structure and assign physical groups. + + Args: + outlines: List of outlines to process + + Returns: + List of dictionaries, one per outline, containing: + - "holes": List of indices of outlines contained by this outline + - "physical_groups": List of PhysicalGroup objects for this outline + """ + pass + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/outline_preprocessor_interface.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/outline_preprocessor_interface.py new file mode 100644 index 0000000..4a20022 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/abstractions/outline_preprocessor_interface.py @@ -0,0 +1,36 @@ +""" +Interface for outline preprocessing operations. +Defines the contract for converting Outline objects into Gmsh geometry. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any +from svg_to_getdp.core.entities.outline import Outline + + +class OutlinePreprocessorInterface(ABC): + """ + Defines the interface for preprocessing Outline objects in Gmsh. + Implementations should handle geometry creation, physical group assignment, + and proper hole/surface relationships. + """ + + @abstractmethod + def preprocess_outlines(self, + factory: Any, + outlines: List[Outline], + properties: List[Dict[str, Any]]) -> None: + """ + Preprocess all outlines with their properties. + + Args: + factory: Gmsh geometry factory (gmsh.model.geo) + outlines: List of Outline objects to preprocess + properties: List of dictionaries with "holes" and "physical_groups" keys + Each dictionary corresponds to the outline at the same index + + Raises: + ValueError: When number of outlines doesn't match number of property dictionaries + """ + pass + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/svg_parser_interface.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/svg_parser_interface.py new file mode 100644 index 0000000..2240f79 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/abstractions/svg_parser_interface.py @@ -0,0 +1,31 @@ +""" +Interface for SVG parsing operations. +""" + +from abc import ABC, abstractmethod +from typing import Dict, List +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.raw_outline import RawOutline + + +class SVGParserInterface(ABC): + """ + Abstract interface for SVG parsing. + """ + + @abstractmethod + def extract_raw_outlines_by_color(self, svg_file_path: str) -> Dict[Color, List[RawOutline]]: + """ + Parse SVG file and extract raw_outlines grouped by color. + + Args: + svg_file_path: Path to the SVG file + + Returns: + Dictionary mapping colors to lists of RawOutline objects containing raw points. + + Raises: + ValueError: If the SVG file is invalid or cannot be parsed + """ + pass + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/abstractions/wire_preprocessor_interface.py b/sketchgetdp/svg_to_getdp/interfaces/abstractions/wire_preprocessor_interface.py new file mode 100644 index 0000000..867294d --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/abstractions/wire_preprocessor_interface.py @@ -0,0 +1,34 @@ +""" +Interface for wire preprocessing operations. +Defines the contract for creating Gmsh entities from wires with physical groups. +Prepares wire geometry for meshing but doesn't perform the meshing itself. +""" + +from abc import ABC, abstractmethod +from typing import List, Tuple, Dict, Any +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + +class WirePreprocessorInterface(ABC): + """ + Defines the interface for creating Gmsh entities for wires. + Implementations should handle wire sorting, physical group assignment, and preparation for meshing. + """ + + @abstractmethod + def prepare_wires(self, + factory: Any, + config_path: str, + wires: List[Tuple[Point, Color]]) -> Dict[int, Dict[str, Any]]: + """ + Prepare Gmsh entities for wires with physical groups. + + Args: + factory: Gmsh factory object + config_path: Path to the YAML configuration file + wires: List of (point, color) tuples representing wires + + Returns: + Dictionary mapping wire indices to their Gmsh tags and physical groups. + """ + pass diff --git a/sketchgetdp/svg_to_getdp/interfaces/arg_parser.py b/sketchgetdp/svg_to_getdp/interfaces/arg_parser.py new file mode 100644 index 0000000..81b1d7f --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/arg_parser.py @@ -0,0 +1,109 @@ +import argparse +from typing import List + +class ArgParser: + """Command line argument parser for SVG to Getdp converter""" + + def parse_args(self, args: List[str] = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description='Convert SVG sketches to Getdp-compatible geometry, mesh with Gmsh and simulate with Getdp.', + epilog=( + 'Examples:\n' + ' python -m svg_to_getdp drawing.svg\n' + ' python -m svg_to_getdp design.svg --debug\n' + ' python -m svg_to_getdp design.svg --mesh-name my_mesh --no-gui\n' + ' python -m svg_to_getdp design.svg --run-simulation\n' + ' python -m svg_to_getdp --simulation-only existing_mesh.msh\n' + ' python -m svg_to_getdp design.svg --config custom_config.yaml --run-simulation --no-gui\n' + ), + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + # SVG file argument (optional for simulation-only mode) + parser.add_argument( + 'svg_file', + nargs='?', # Make it optional for simulation-only mode + help='Path to SVG file to process (required unless --simulation-only is used)' + ) + + # GetDP Simulation options + simulation_group = parser.add_argument_group('GetDP Simulation Options') + + simulation_group.add_argument( + '--run-simulation', '-s', + action='store_true', + help='Run GetDP simulation after mesh generation (full pipeline: SVG → Gmsh → GetDP)' + ) + + simulation_group.add_argument( + '--simulation-only', + metavar='MESH_FILE', + type=str, + help='Run GetDP simulation on an existing mesh file (skip SVG conversion and Gmsh)' + ) + + # Gmsh meshing options + parser.add_argument( + '--config', + default='config.yaml', + type=str, + help='Path to YAML configuration file for coil currents, mesh settings and simulation parameters (default: config.yaml)' + ) + + parser.add_argument( + '--mesh-name', + type=str, + help='Name for the output mesh file (without .msh extension). ' + 'If not specified, uses the SVG filename.' + ) + + parser.add_argument( + '--no-gui', + action='store_true', + help='Disable Gmsh and GetDP GUI display (run in batch mode)' + ) + + # Debug options + parser.add_argument( + '--debug', '-d', + action='store_true', + help='Enable debug mode to output intermediate processing information including geometry plots and text summaries' + ) + + # Parse arguments + parsed_args = parser.parse_args(args) + + # Validate arguments + self._validate_args(parser, parsed_args) + + return parsed_args + + def _validate_args(self, parser: argparse.ArgumentParser, args: argparse.Namespace) -> None: + """Validate the parsed arguments for logical consistency.""" + + # If simulation-only mode is used + if args.simulation_only: + # Check that SVG file is not also provided (they're mutually exclusive) + if args.svg_file: + parser.error("Cannot specify both SVG file and --simulation-only. " + "Use --simulation-only alone for existing meshes.") + + # Check that run-simulation is not also specified (redundant) + if args.run_simulation: + parser.error("Cannot use both --run-simulation and --simulation-only. " + "Use --run-simulation for full pipeline or --simulation-only for existing mesh.") + + # Check that mesh-only related options are not used + if args.mesh_name: + parser.error("Cannot use --mesh-name with --simulation-only. " + "Mesh name is derived from the provided mesh file.") + + # If normal mode (not simulation-only) + else: + # Check that SVG file is provided + if not args.svg_file: + parser.error("SVG file is required unless --simulation-only is used") + + # If run-simulation is used with no SVG file (shouldn't happen due to above check) + if args.run_simulation and not args.svg_file: + parser.error("SVG file is required for --run-simulation") diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/__init__.py b/sketchgetdp/svg_to_getdp/interfaces/debug/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/corner_detector_debug_writer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/corner_detector_debug_writer.py new file mode 100644 index 0000000..bedf207 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/corner_detector_debug_writer.py @@ -0,0 +1,416 @@ +from datetime import datetime +from typing import Optional +from svg_to_getdp.core.entities.raw_outline import RawOutline +from svg_to_getdp.interfaces.debug.debug_coordinator import DebugCoordinator + + +class CornerDetectorDebugWriter(DebugCoordinator): + """Handles writing debug information for corner detection.""" + + def __init__(self): + super().__init__() + + def write_corner_detection_debug_info(self, svg_file_path: str, + corner_debug_data: dict, + raw_outlines_by_color: Optional[dict] = None): + """ + Write detailed corner detection debug information. + + Args: + svg_file_path: Path to the SVG file + corner_debug_data: Debug data from corner detection + raw_outlines_by_color: Raw outlines organized by color (optional) + """ + self.set_svg_file(svg_file_path) + debug_filename = self.get_debug_filename("corner_detection_debug", ".txt") + + with open(debug_filename, 'w') as f: + self._write_corner_detection_header(f, svg_file_path, corner_debug_data) + + # Check if we have data + if not corner_debug_data: + f.write("\nNO CORNER DEBUG DATA AVAILABLE\n") + return + + # Process each outline + for key, data in corner_debug_data.items(): + self._write_outline_corner_analysis(f, key, data, raw_outlines_by_color) + + print(f"Corner detection debug information written to: {debug_filename}") + + def write_detailed_decision_process(self, svg_file_path: str, corner_debug_data: dict): + """ + Write even more detailed decision process for advanced debugging. + """ + detailed_filename = self.get_debug_filename(svg_file_path, "corner_decisions_detailed", ".txt") + + with open(detailed_filename, 'w') as f: + f.write("DETAILED CORNER DETECTION DECISION PROCESS\n") + f.write("=" * 80 + "\n\n") + + f.write(f"Input SVG: {svg_file_path}\n") + f.write(f"Processed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Debug run timestamp: {self.get_shared_timestamp()}\n") + f.write(f"Total outlines analyzed: {len(corner_debug_data)}\n\n") + + for key, data in corner_debug_data.items(): + f.write(f"\n{'='*100}\n") + f.write(f"DETAILED PROCESS FOR: {key}\n") + f.write(f"{'='*100}\n\n") + + # Write extremely detailed information + self._write_extremely_detailed_analysis(f, data) + + print(f"Detailed decision process written to: {detailed_filename}") + + def _write_corner_detection_header(self, f, svg_file_path: str, corner_debug_data: dict): + """Write header for corner detection debug file.""" + f.write("CORNER DETECTION DEBUG INFORMATION\n") + f.write("=" * 60 + "\n\n") + + f.write(f"Input SVG: {svg_file_path}\n") + f.write(f"Processed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Debug run timestamp: {self.get_shared_timestamp()}\n") + f.write(f"Total outlines analyzed: {len(corner_debug_data) if corner_debug_data else 0}\n\n") + + def _write_outline_corner_analysis(self, f, key: str, data: dict, raw_outlines_by_color: Optional[dict]): + """Write detailed analysis for a single outline using raw outlines.""" + f.write(f"\n{'='*80}\n") + f.write(f"RAW OUTLINE ANALYSIS: {key}\n") + f.write(f"{'='*80}\n\n") + + # Basic info - with safety checks + color_name = data.get('color', 'N/A') + outline_index = data.get('outline_index', 'N/A') + f.write(f"Basic Information:\n") + f.write(f" Color: {color_name}\n") + f.write(f" Raw Outline Index: {outline_index}\n") + f.write(f" Total Points: {data.get('points_count', 'N/A')}\n") + f.write(f" Is Closed: {data.get('is_closed', 'N/A')}\n") + f.write(f" Final Corners: {len(data.get('corner_indices', []))}\n\n") + + # Try to find the corresponding raw outline + if raw_outlines_by_color: + raw_outline = self._find_raw_outline(raw_outlines_by_color, color_name, outline_index) + if raw_outline: + f.write(f"Raw Outline Information:\n") + f.write(f" Number of Points: {len(raw_outline.points)}\n") + f.write(f" Is Closed: {raw_outline.is_closed}\n") + f.write(f" First Point: ({raw_outline.points[0].x:.6f}, {raw_outline.points[0].y:.6f})\n") + f.write(f" Last Point: ({raw_outline.points[-1].x:.6f}, {raw_outline.points[-1].y:.6f})\n\n") + + debug_info = data.get('debug', {}) + + if not debug_info: + f.write("NO DEBUG INFO AVAILABLE FOR THIS OUTLINE\n\n") + return + + # Shape analysis + self._write_shape_analysis(f, debug_info.get('shape_analysis', {})) + + # Candidate detection + self._write_candidate_detection(f, debug_info.get('candidate_detection', {})) + + # Strength calculations + self._write_strength_calculations(f, debug_info.get('strength_calculations', {})) + + # Clustering + self._write_clustering_info(f, debug_info.get('clustering', {})) + + # Refinement - handle the new structure + refinement_details = debug_info.get('refinement_details', []) + clustering_info = debug_info.get('clustering', {}) + self._write_refinement_info(f, refinement_details, clustering_info) + + # Final decisions - enhanced with raw outline points if available + final_info = debug_info.get('final_decisions', {}) + if raw_outlines_by_color: + raw_outline = self._find_raw_outline(raw_outlines_by_color, color_name, outline_index) + if raw_outline: + self._write_final_decisions_with_points(f, final_info, raw_outline) + else: + self._write_final_decisions(f, final_info) + else: + self._write_final_decisions(f, final_info) + + # Process steps + self._write_process_steps(f, debug_info.get('all_steps', [])) + + def _find_raw_outline(self, raw_outlines_by_color: dict, color_name: str, outline_index: int) -> Optional[RawOutline]: + """ + Find a raw outline by color name and index. + + Args: + raw_outlines_by_color: Dictionary of raw outlines grouped by color + color_name: Name of the color (e.g., 'GREEN', 'BLUE') + outline_index: Index of the outline within that color group + + Returns: + RawOutline object if found, None otherwise + """ + try: + # Convert color name to Color enum if needed + from svg_to_getdp.core.entities.color import Color + color_map = { + 'RED': Color.RED, + 'GREEN': Color.GREEN, + 'BLUE': Color.BLUE, + 'BLACK': Color.BLACK + } + + color = color_map.get(color_name.upper()) + if not color or color not in raw_outlines_by_color: + return None + + raw_outlines = raw_outlines_by_color[color] + if outline_index < 0 or outline_index >= len(raw_outlines): + return None + + return raw_outlines[outline_index] + + except (KeyError, IndexError, AttributeError): + return None + + def _write_final_decisions_with_points(self, f, final_info: dict, raw_outline: RawOutline): + """Write final decisions section with actual point coordinates from raw outline.""" + final_corners = final_info.get('final_corners', []) + corner_strengths = final_info.get('corner_strengths', {}) + + f.write("FINAL DECISIONS WITH POINT COORDINATES:\n") + f.write(f" Total Final Corners: {len(final_corners)}\n") + + if final_corners: + f.write(f" Final Corner Indices: {sorted(final_corners)}\n\n") + + f.write(f" Corner Details (from raw outline):\n") + for idx in sorted(final_corners): + if 0 <= idx < len(raw_outline.points): + point = raw_outline.points[idx] + strength = corner_strengths.get(idx, 0) + f.write(f" Point {idx:4d}: ({point.x:.6f}, {point.y:.6f}) " + f"[strength={strength:.3f}]\n") + else: + f.write(f" Point {idx:4d}: INDEX OUT OF BOUNDS (raw outline has {len(raw_outline.points)} points)\n") + + f.write("\n") + + def _write_shape_analysis(self, f, shape_info: dict): + """Write shape analysis section.""" + f.write("SHAPE ANALYSIS:\n") + + if 'early_ellipse_detection' in shape_info and shape_info['early_ellipse_detection']: + f.write(f" ❌ EARLY REJECTION: {shape_info.get('ellipse_reason', 'Ellipse detected')}\n") + return + + if 'too_smooth' in shape_info and shape_info['too_smooth']: + f.write(f" ❌ REJECTION: Shape too smooth (score={shape_info.get('smoothness_score', 0):.3f})\n") + return + + if 'too_small' in shape_info and shape_info['too_small']: + f.write(f" ❌ REJECTION: Shape too small for analysis\n") + return + + if 'small_ellipse' in shape_info and shape_info['small_ellipse']: + f.write(f" ❌ REJECTION: Small ellipse detected\n") + return + + f.write(f" Smoothness Score: {shape_info.get('smoothness_score', 'N/A')}\n") + f.write(f" Is Ellipse: {shape_info.get('is_ellipse', 'N/A')}\n") + + if 'bounding_box' in shape_info: + bbox = shape_info['bounding_box'] + f.write(f" Bounding Box:\n") + f.write(f" X: [{bbox['x_min']:.6f}, {bbox['x_max']:.6f}] (width: {bbox['width']:.6f})\n") + f.write(f" Y: [{bbox['y_min']:.6f}, {bbox['y_max']:.6f}] (height: {bbox['height']:.6f})\n") + + f.write("\n") + + def _write_candidate_detection(self, f, cand_info: dict): + """Write candidate detection section.""" + f.write("CANDIDATE DETECTION:\n") + + angle_corners = cand_info.get('angle_method', []) + direction_corners = cand_info.get('direction_method', []) + curvature_corners = cand_info.get('curvature_method', []) + all_candidates = cand_info.get('all_candidates', []) + + f.write(f" Method Results:\n") + f.write(f" Angle Method: {len(angle_corners):3d} candidates\n") + f.write(f" Direction Method: {len(direction_corners):3d} candidates\n") + f.write(f" Curvature Method: {len(curvature_corners):3d} candidates\n") + f.write(f" Total Unique: {len(all_candidates):3d} candidates\n\n") + + # Show combined votes if available + if 'combined_votes' in cand_info: + combined = cand_info['combined_votes'] + if combined: + f.write(f" Combined Votes (Top 20):\n") + sorted_votes = sorted(combined.items(), key=lambda x: x[1], reverse=True)[:20] + for idx, votes in sorted_votes: + f.write(f" Point {idx:4d}: {votes:.2f} votes\n") + f.write("\n") + + # Show coarse corners + coarse_corners = cand_info.get('coarse_corners', []) + if coarse_corners: + f.write(f" Coarse Corners (after initial filtering): {len(coarse_corners)}\n") + f.write(f" Indices: {sorted(coarse_corners)}\n") + f.write("\n") + + def _write_strength_calculations(self, f, strengths: dict): + """Write strength calculations section.""" + if not strengths: + return + + f.write("CORNER STRENGTH CALCULATIONS:\n") + + # Show top strengths + if len(strengths) <= 30: + f.write(f" All Candidate Strengths:\n") + sorted_strengths = sorted(strengths.items(), key=lambda x: x[1], reverse=True) + for idx, strength in sorted_strengths: + f.write(f" Point {idx:4d}: strength={strength:.3f}\n") + else: + f.write(f" Top 30 Candidate Strengths:\n") + sorted_strengths = sorted(strengths.items(), key=lambda x: x[1], reverse=True)[:30] + for idx, strength in sorted_strengths: + f.write(f" Point {idx:4d}: strength={strength:.3f}\n") + + f.write("\n") + + def _write_clustering_info(self, f, clustering_info: dict): + """Write clustering information section.""" + clusters = clustering_info.get('clusters', []) + if not clusters: + return + + f.write("CLUSTERING RESULTS:\n") + f.write(f" Number of clusters: {len(clusters)}\n") + + for i, cluster in enumerate(clusters): + f.write(f" Cluster {i}: {cluster}\n") + if len(cluster) > 1: + f.write(f" Size: {len(cluster)} candidates\n") + f.write(f" Range: {min(cluster)} to {max(cluster)} " + f"(span: {max(cluster) - min(cluster)} points)\n") + + f.write("\n") + + def _write_refinement_info(self, f, refinement_details: list, clustering_info: dict): + """Write refinement information section.""" + if not refinement_details: + f.write("REFINEMENT PROCESS:\n") + f.write(" No refinement details available\n\n") + return + + f.write("REFINEMENT PROCESS:\n") + + for i, cluster_info in enumerate(refinement_details): + f.write(f" Cluster {i}:\n") + f.write(f" Candidates: {cluster_info.get('cluster', [])}\n") + f.write(f" Best Candidate: {cluster_info.get('best_candidate', 'N/A')}\n") + f.write(f" Refined To: {cluster_info.get('refined_candidate', 'N/A')}\n") + if 'refined_strength' in cluster_info: + f.write(f" Refined Strength: {cluster_info['refined_strength']:.3f}\n") + if 'accepted' in cluster_info: + f.write(f" Accepted: {cluster_info['accepted']}\n") + + if 'refined_corners' in clustering_info: + refined = clustering_info.get('refined_corners', []) + f.write(f"\n Refinement Results:\n") + f.write(f" Refined Corners: {refined}\n") + f.write(f" Count: {len(refined)}\n") + + if 'quality_corners' in clustering_info: + quality = clustering_info.get('quality_corners', []) + f.write(f" Quality Corners: {quality}\n") + f.write(f" Count: {len(quality)}\n") + + f.write("\n") + + def _write_final_decisions(self, f, final_info: dict): + """Write final decisions section (original version without point coordinates).""" + final_corners = final_info.get('final_corners', []) + corner_coords = final_info.get('corner_coordinates', {}) + corner_strengths = final_info.get('corner_strengths', {}) + + f.write("FINAL DECISIONS:\n") + f.write(f" Total Final Corners: {len(final_corners)}\n") + + if final_corners: + f.write(f" Final Corner Indices: {sorted(final_corners)}\n\n") + + f.write(f" Corner Details:\n") + for idx in sorted(final_corners): + point = corner_coords.get(idx) + strength = corner_strengths.get(idx, 0) + if point: + f.write(f" Point {idx:4d}: ({point.x:.6f}, {point.y:.6f}) " + f"[strength={strength:.3f}]\n") + + f.write("\n") + + def _write_process_steps(self, f, steps: list): + """Write process steps section.""" + if not steps: + return + + f.write("PROCESS STEPS:\n") + for i, step in enumerate(steps, 1): + f.write(f" {i:3d}. {step}\n") + f.write("\n") + + def _write_extremely_detailed_analysis(self, f, data: dict): + """Write extremely detailed analysis for a outline.""" + debug_info = data['debug'] + + # Write complete shape analysis + shape_info = debug_info.get('shape_analysis', {}) + f.write("COMPLETE SHAPE ANALYSIS:\n") + for key, value in shape_info.items(): + if key != 'bounding_box': + f.write(f" {key}: {value}\n") + + if 'bounding_box' in shape_info: + bbox = shape_info['bounding_box'] + f.write(f" bounding_box:\n") + for bkey, bvalue in bbox.items(): + f.write(f" {bkey}: {bvalue}\n") + f.write("\n") + + # Write complete candidate information + cand_info = debug_info.get('candidate_detection', {}) + if 'angle_method' in cand_info: + f.write(f"Angle Method Candidates ({len(cand_info['angle_method'])}):\n") + f.write(f" {cand_info['angle_method']}\n") + + if 'direction_method' in cand_info: + f.write(f"\nDirection Method Candidates ({len(cand_info['direction_method'])}):\n") + f.write(f" {cand_info['direction_method']}\n") + + if 'curvature_method' in cand_info: + f.write(f"\nCurvature Method Candidates ({len(cand_info['curvature_method'])}):\n") + f.write(f" {cand_info['curvature_method']}\n") + + if 'all_candidates' in cand_info: + f.write(f"\nAll Unique Candidates ({len(cand_info['all_candidates'])}):\n") + f.write(f" {sorted(cand_info['all_candidates'])}\n") + + f.write("\n") + + # Write all strength calculations + strengths = debug_info.get('strength_calculations', {}) + if strengths: + f.write("ALL STRENGTH CALCULATIONS:\n") + for idx, strength in sorted(strengths.items()): + f.write(f" Point {idx:4d}: {strength:.6f}\n") + f.write("\n") + + # Write decision steps + steps = debug_info.get('all_steps', []) + if steps: + f.write("DECISION STEPS:\n") + for step in steps: + f.write(f" {step}\n") + f.write("\n") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/debug_coordinator.py b/sketchgetdp/svg_to_getdp/interfaces/debug/debug_coordinator.py new file mode 100644 index 0000000..3edcd84 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/debug_coordinator.py @@ -0,0 +1,59 @@ +import os +from datetime import datetime + + +class DebugCoordinator: + """Main utility class for coordinating all debug writing operations.""" + + def __init__(self): + """Initialize DebugCoordinator with a shared timestamp for all debug outputs.""" + self._shared_timestamp = None + self._svg_file_path = None + self._svg_name = None + + def set_svg_file(self, svg_file_path: str): + """Set the SVG file being processed.""" + self._svg_file_path = svg_file_path + svg_filename = os.path.basename(svg_file_path) + self._svg_name = os.path.splitext(svg_filename)[0] + + def get_shared_timestamp(self) -> str: + """ + Get a shared timestamp for all debug outputs in this run. + Creates a new timestamp on first call, reuses it for subsequent calls. + """ + if self._shared_timestamp is None: + self._shared_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return self._shared_timestamp + + def get_debug_directory(self) -> str: + """Get the debug directory path, creating it if necessary.""" + debug_dir = "debug" + os.makedirs(debug_dir, exist_ok=True) + return debug_dir + + def get_debug_filename(self, prefix: str, extension: str = ".txt") -> str: + """Generate a debug filename with timestamp.""" + if not self._svg_name: + raise ValueError("SVG file not set. Call set_svg_file() first.") + + timestamp = self.get_shared_timestamp() + debug_dir = self.get_debug_directory() + return f"{debug_dir}/{prefix}_{self._svg_name}_{timestamp}{extension}" + + def get_debug_plot_filename(self, prefix: str = "geometry_plot", extension: str = ".png") -> str: + """Generate a debug plot filename with timestamp.""" + return self.get_debug_filename(prefix, extension) + + def get_svg_name(self) -> str: + """Get the base name of the SVG file.""" + if not self._svg_name: + raise ValueError("SVG file not set. Call set_svg_file() first.") + return self._svg_name + + def get_svg_file_path(self) -> str: + """Get the full SVG file path.""" + if not self._svg_file_path: + raise ValueError("SVG file not set. Call set_svg_file() first.") + return self._svg_file_path + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/geometry_debug_writer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/geometry_debug_writer.py new file mode 100644 index 0000000..6bfea78 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/geometry_debug_writer.py @@ -0,0 +1,123 @@ +from datetime import datetime +from svg_to_getdp.interfaces.debug.debug_coordinator import DebugCoordinator + + +class GeometryDebugWriter(DebugCoordinator): + """Handles writing debug information for geometry conversion.""" + + def __init__(self): + super().__init__() + + def write_geometry_debug_info(self, svg_file_path: str, outlines, wires): + """ + Write geometry conversion results to a debug text file. + Follows the same structure as write_svg_parser_debug_info. + """ + self.set_svg_file(svg_file_path) + debug_filename = self.get_debug_filename("geometry_debug", ".txt") + + with open(debug_filename, 'w') as f: + f.write(f"Geometry Conversion Debug Information\n") + f.write(f"=====================================\n") + f.write(f"Input SVG: {svg_file_path}\n") + f.write(f"Processed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Debug run timestamp: {self.get_shared_timestamp()}\n") + f.write(f"\n") + + f.write(f"Summary:\n") + f.write(f" Total outlines: {len(outlines)}\n") + f.write(f" Total wires: {len(wires)}\n") + f.write(f"\n") + + # Outlines Section + f.write(f"OUTLINES\n") + f.write(f"===============\n\n") + + for i, outline in enumerate(outlines): + f.write(f"Outline {i+1}:\n") + f.write(f" Color: {outline.color.name}\n") + f.write(f" Segments: {len(outline.bezier_segments)}\n") + f.write(f" Corners: {len(outline.corners)}\n") + f.write(f" Closed: {outline.is_closed}\n") + + # Segment details with control points + f.write(f" Segments:\n") + for seg_idx, segment in enumerate(outline.bezier_segments): + f.write(f" Segment {seg_idx} (Degree {segment.degree}):\n") + for cp_idx, control_point in enumerate(segment.control_points): + f.write(f" Control Point {cp_idx}: ({control_point.x:.6f}, {control_point.y:.6f})\n") + + # Corner coordinates + if outline.corners: + f.write(f" Corners:\n") + for corner_idx, corner in enumerate(outline.corners): + f.write(f" Corner {corner_idx}: ({corner.x:.6f}, {corner.y:.6f})\n") + + # Sample points along the outline + f.write(f" Sampled Outline Points (t=0 to 1):\n") + for t in [0.0, 0.25, 0.5, 0.75, 1.0]: + point = outline.evaluate(t) + f.write(f" t={t:.2f}: ({point.x:.6f}, {point.y:.6f})\n") + + f.write(f"\n") + + # Wires Section + f.write(f"WIRES\n") + f.write(f"=====\n\n") + + for i, (point, color) in enumerate(wires): + f.write(f"Wire {i+1}:\n") + f.write(f" Color: {color.name}\n") + f.write(f" Position: ({point.x:.6f}, {point.y:.6f})\n\n") + + print(f"Geometry debug information written to: {debug_filename}") + return debug_filename + + @staticmethod + def save_results(outlines, wires, output_path: str): + """Save conversion results to file with coordinates.""" + with open(output_path, 'w') as f: + f.write("SVG to Geometry Conversion Results\n") + f.write("=" * 50 + "\n\n") + + # Outlines Section + f.write("OUTLINES\n") + f.write("=" * 50 + "\n\n") + + for i, outline in enumerate(outlines): + f.write(f"Outline {i+1}:\n") + f.write(f" Color: {outline.color.name}\n") + f.write(f" Segments: {len(outline.bezier_segments)}\n") + f.write(f" Corners: {len(outline.corners)}\n") + f.write(f" Closed: {outline.is_closed}\n") + + # Segment details with control points + f.write(" Segments:\n") + for seg_idx, segment in enumerate(outline.bezier_segments): + f.write(f" Segment {seg_idx} (Degree {segment.degree}):\n") + for cp_idx, control_point in enumerate(segment.control_points): + f.write(f" Control Point {cp_idx}: ({control_point.x:.6f}, {control_point.y:.6f})\n") + + # Corner coordinates + if outline.corners: + f.write(" Corners:\n") + for corner_idx, corner in enumerate(outline.corners): + f.write(f" Corner {corner_idx}: ({corner.x:.6f}, {corner.y:.6f})\n") + + # Sample points along the outline + f.write(" Sampled Outline Points (t=0 to 1):\n") + for t in [0.0, 0.25, 0.5, 0.75, 1.0]: + point = outline.evaluate(t) + f.write(f" t={t:.2f}: ({point.x:.6f}, {point.y:.6f})\n") + + f.write("\n") + + # Wires Section + f.write("WIRES\n") + f.write("=" * 50 + "\n\n") + + for i, (point, color) in enumerate(wires): + f.write(f"Wire {i+1}:\n") + f.write(f" Color: {color.name}\n") + f.write(f" Position: ({point.x:.6f}, {point.y:.6f})\n\n") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/geometry_visualizer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/geometry_visualizer.py new file mode 100644 index 0000000..3624778 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/geometry_visualizer.py @@ -0,0 +1,265 @@ +""" +Presentation layer service for visualizing internal geometry. +""" + +import matplotlib.pyplot as plt +import os +from datetime import datetime +from typing import List +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.interfaces.debug.debug_coordinator import DebugCoordinator + + +class GeometryVisualizer: + """Presentation service for visualizing outlines, Bézier segments, and raw polylines.""" + + @staticmethod + def _plot_single_outline(outline: Outline, outline_index: int, + show_control_points: bool, show_corners: bool, + color_in_legend: dict, corner_color_in_legend: dict): + """Plot a single outline.""" + # Use the actual RGB values from the Color object + rgb = outline.color.rgb + plot_color = (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0) # Normalize to 0-1 for matplotlib + + # Sample points along the entire outline + t_values = [i/200 for i in range(201)] # High resolution for smooth outlines + outline_points = [outline.evaluate(t) for t in t_values] + + x_outline = [p.x for p in outline_points] + y_outline = [p.y for p in outline_points] + + # Determine label for the outline (only add to legend if not already added for this color) + if outline.color.name not in color_in_legend: + label = f'{outline.color.name} Outlines' + color_in_legend[outline.color.name] = True + else: + label = None + + # Plot the outline itself + plt.plot(x_outline, y_outline, color=plot_color, linewidth=2, label=label) + + # Plot control points if requested + if show_control_points: + for seg_idx, segment in enumerate(outline.bezier_segments): + cp_x = [p.x for p in segment.control_points] + cp_y = [p.y for p in segment.control_points] + + # Plot control points without adding to legend + plt.plot(cp_x, cp_y, 'o--', color=plot_color, alpha=0.7, + linewidth=1, markersize=4) + + # Plot corners if requested + if show_corners and outline.corners: + corner_x = [c.x for c in outline.corners] + corner_y = [c.y for c in outline.corners] + + # Only add corner label to legend if not already added for this color + if outline.color.name not in corner_color_in_legend: + corner_label = f'{outline.color.name} Corners' + corner_color_in_legend[outline.color.name] = True + else: + corner_label = None + + plt.plot(corner_x, corner_y, 's', color=plot_color, + markersize=10, markerfacecolor='none', markeredgewidth=2, + label=corner_label) + + @staticmethod + def _plot_colored_outlines(colored_outlines: dict): + """Plot colored polyline outlines with lighter colors.""" + # Track which colors we've already added to the legend for raw outlines + raw_color_in_legend = {} + raw_point_color_in_legend = {} + + for color, raw_outlines in colored_outlines.items(): + for i, raw_outline in enumerate(raw_outlines): + rgb = raw_outline.color.rgb + + # Create lighter colors by blending with white + light_factor = 0.6 # 0.0 = original color, 1.0 = white + plot_color = ( + (1 - light_factor) * (rgb[0] / 255.0) + light_factor, + (1 - light_factor) * (rgb[1] / 255.0) + light_factor, + (1 - light_factor) * (rgb[2] / 255.0) + light_factor + ) + + x_points = [p.x for p in raw_outline.points] + y_points = [p.y for p in raw_outline.points] + + if raw_outline.is_closed and len(raw_outline.points) > 1: + x_points.append(raw_outline.points[0].x) + y_points.append(raw_outline.points[0].y) + + # Plot the polyline with lighter styling + linestyle = '-' if raw_outline.is_closed else '--' + + # Special handling for red dots (wires in raw form) + if raw_outline.color.name == 'RED' and len(raw_outline.points) == 1: + # Use light red for single red points + light_red = (1.0, 0.7, 0.7) # Light red + + # Only add to legend once for red points + if raw_outline.color.name not in raw_point_color_in_legend: + label = 'Raw RED Points' + raw_point_color_in_legend[raw_outline.color.name] = True + else: + label = None + + plt.plot(x_points, y_points, 'x', color=light_red, markersize=8, + markeredgewidth=1.5, alpha=0.7, label=label) + else: + # For polylines, use lighter colors and thinner lines + # Only add to legend once per color for raw polylines + if raw_outline.color.name not in raw_color_in_legend: + label = f'Raw {raw_outline.color.name} Polylines' + raw_color_in_legend[raw_outline.color.name] = True + else: + label = None + + plt.plot(x_points, y_points, linestyle, color=plot_color, + linewidth=1.0, alpha=0.6, marker='.', markersize=4, + label=label) + + @staticmethod + def _plot_wires(wires: List[tuple]): + """Plot point wires.""" + # Track which wire colors we've already added to the legend + wire_color_in_legend = {} + + for point, color in wires: + # Use the actual RGB values from the Color object + rgb = color.rgb + plot_color = (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0) # Normalize to 0-1 for matplotlib + + # Only add wire label to legend once per color + if color.name not in wire_color_in_legend: + label = f'{color.name} Wires' + wire_color_in_legend[color.name] = True + else: + label = None + + plt.plot(point.x, point.y, 'X', color=plot_color, markersize=12, + markeredgewidth=3, label=label) + + @staticmethod + def save_plot_to_file(outlines: List[Outline], wires: List[tuple] = None, + colored_outlines: dict = None, + filename: str = 'geometry_plot.png', **kwargs): + """ + Save the plot to a file. + + Args: + outlines: List of Outline objects to plot + wires: List of (Point, Color) tuples for wires + colored_outlines: Dictionary of {color: List[RawOutline]} objects to plot + filename: Output filename + **kwargs: Additional arguments for plot_outlines + """ + plt.figure(figsize=(12, 10)) + + # Track which colors we've already added to the legend + color_in_legend = {} + corner_color_in_legend = {} + + # Plot each outline + for i, outline in enumerate(outlines): + GeometryVisualizer._plot_single_outline(outline, i, + kwargs.get('show_control_points', True), + kwargs.get('show_corners', True), + color_in_legend, corner_color_in_legend) + + # Plot colored outlines (polylines) if requested + if colored_outlines and kwargs.get('show_raw_outlines', True): + GeometryVisualizer._plot_colored_outlines(colored_outlines) + + # Plot wires + if wires: + GeometryVisualizer._plot_wires(wires) + + plt.grid(True, alpha=0.3) + plt.axis('equal') + plt.title('Internal Geometry from SVG Conversion') + plt.xlabel('X coordinate') + plt.ylabel('Y coordinate') + plt.legend() + plt.tight_layout() + plt.savefig(filename, dpi=300, bbox_inches='tight') + plt.close() + print(f"Geometry debug plot saved to: {filename}") + + @staticmethod + def save_plot_to_debug_directory(outlines: List[Outline], svg_file_path: str, + wires: List[tuple] = None, colored_outlines: dict = None, + timestamp: str = None, **kwargs) -> str: + """ + Save geometry plot to debug directory with timestamped filename. + + Args: + outlines: List of Outline objects to plot + svg_file_path: Path to the original SVG file (for naming) + wires: List of (Point, Color) tuples for wires + colored_outlines: Dictionary of {color: List[RawOutline]} objects to plot + timestamp: Optional timestamp string (if None, generates new) + **kwargs: Additional arguments for the plot + + Returns: + Path to the saved plot file + """ + # Create debug directory if it doesn't exist + debug_dir = "debug" + os.makedirs(debug_dir, exist_ok=True) + + # Create debug filename based on input SVG filename and timestamp + svg_filename = os.path.basename(svg_file_path) + svg_name = os.path.splitext(svg_filename)[0] + + # Use provided timestamp or generate new one + if timestamp is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + debug_filename = f"{debug_dir}/geometry_plot_{svg_name}_{timestamp}.png" + + # Save the plot to the debug directory + GeometryVisualizer.save_plot_to_file( + outlines=outlines, + wires=wires, + colored_outlines=colored_outlines, + filename=debug_filename, + **kwargs + ) + + return debug_filename + + @classmethod + def save_plot_with_coordinator(cls, outlines: List[Outline], + coordinator: DebugCoordinator, + wires: List[tuple] = None, + colored_outlines: dict = None, + **kwargs) -> str: + """ + Save plot using a DebugCoordinator for consistent naming. + + Args: + outlines: List of Outline objects to plot + coordinator: DebugCoordinator instance + wires: List of (Point, Color) tuples for wires + colored_outlines: Dictionary of {color: List[RawOutline]} objects to plot + **kwargs: Additional arguments for the plot + + Returns: + Path to the saved plot file + """ + plot_filename = coordinator.get_debug_plot_filename("geometry_plot", ".png") + + # Save the plot to the debug directory + cls.save_plot_to_file( + outlines=outlines, + wires=wires, + colored_outlines=colored_outlines, + filename=plot_filename, + **kwargs + ) + + return plot_filename + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/outline_grouper_debug_writer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/outline_grouper_debug_writer.py new file mode 100644 index 0000000..a6eed1d --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/outline_grouper_debug_writer.py @@ -0,0 +1,154 @@ +""" +Debug writer for outline grouping operations. +""" + +import os +from typing import List, Dict +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.infrastructure.outline_grouper import OutlineGrouper + + +class OutlineGrouperDebugWriter: + """Debug writer for outline grouping operations.""" + def __init__(self): + """Initialize the debug writer.""" + self._shared_timestamp = None + self._svg_file_path = None + self._svg_name = None + + def set_shared_timestamp(self, timestamp: str): + """Set the shared timestamp for debug outputs.""" + self._shared_timestamp = timestamp + + def set_svg_file(self, svg_file_path: str): + """Set the SVG file being processed.""" + self._svg_file_path = svg_file_path + svg_filename = os.path.basename(svg_file_path) + self._svg_name = os.path.splitext(svg_filename)[0] + + def write_grouping_debug_info( + self, + svg_file_path: str, + outlines: List[Outline], + grouping_result: List[Dict], + grouper_instance: 'OutlineGrouper' + ) -> str: + """ + Write debug information for outline grouping. + + Args: + svg_file_path: Path to the SVG file being processed + outlines: List of outlines + grouping_result: Result from OutlineGrouper.group_outlines() + grouper_instance: The OutlineGrouper instance used + + Returns: + Path to the generated debug file + """ + self.set_svg_file(svg_file_path) + + if not self._shared_timestamp: + raise ValueError("Shared timestamp not set. Call set_shared_timestamp() first.") + + # Create debug directory + debug_dir = "debug" + os.makedirs(debug_dir, exist_ok=True) + + # Generate debug filename + debug_file = self._get_debug_filename("outline_grouping_debug") + + # Write debug information + with open(debug_file, 'w') as f: + self._write_header(f, outlines) + self._write_grouping_summary(f, outlines, grouping_result, grouper_instance) + self._write_containment_hierarchy(f, outlines, grouping_result, grouper_instance) + + print(f"Outline grouping debug information written to: {debug_file}") + + return debug_file + + def _get_debug_filename(self, prefix: str, extension: str = ".txt") -> str: + """Generate a debug filename with timestamp.""" + if not self._svg_name: + raise ValueError("SVG file not set. Call set_svg_file() first.") + + debug_dir = "debug" + return f"{debug_dir}/{prefix}_{self._svg_name}_{self._shared_timestamp}{extension}" + + def _write_header(self, file_obj, outlines: List[Outline]): + """Write a header section to the debug file.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("OUTLINE GROUPING DEBUG\n") + file_obj.write("=" * 80 + "\n\n") + file_obj.write(f"SVG File: {self._svg_file_path}\n") + file_obj.write(f"Timestamp: {self._shared_timestamp}\n") + + # Count outlines by type + va_count = sum(1 for outline in outlines if outline.color.name == "black") + vi_iron_count = sum(1 for outline in outlines if outline.color.name == "blue") + vi_air_count = sum(1 for outline in outlines if outline.color.name == "green") + + file_obj.write(f"Total Outlines: {len(outlines)}\n") + file_obj.write(f" - Va outlines (black): {va_count}\n") + file_obj.write(f" - Vi-iron outlines (blue): {vi_iron_count}\n") + file_obj.write(f" - Vi-air outlines (green): {vi_air_count}\n\n") + + def _write_grouping_summary(self, file_obj, outlines, grouping_result, grouper_instance): + """Write the main grouping summary similar to print_grouping_summary.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("OUTLINE GROUPING SUMMARY\n") + file_obj.write("=" * 80 + "\n\n") + + for i, (outline, group_info) in enumerate(zip(outlines, grouping_result)): + file_obj.write(f"Outline {i}:\n") + file_obj.write(f" Color: {outline.color.name}\n") + file_obj.write(f" Classification: {grouper_instance.classify_outline_color(outline)}\n") + file_obj.write(f" Is Closed: {outline.is_closed}\n") + file_obj.write(f" Bezier Segments: {len(outline.bezier_segments)}\n") + file_obj.write(f" Control Points: {len(outline.control_points)}\n") + file_obj.write(f" Holes (contained outlines): {group_info['holes']}\n") + file_obj.write(f" Physical Groups ({len(group_info['physical_groups'])}):\n") + for pg in group_info['physical_groups']: + file_obj.write(f" - {pg.name} (type: {pg.group_type}, value: {pg.value})\n") + file_obj.write("\n") + + def _write_containment_hierarchy(self, file_obj, outlines, grouping_result, grouper_instance): + """Write the containment hierarchy tree.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("CONTAINMENT HIERARCHY\n") + file_obj.write("=" * 80 + "\n\n") + + n = len(outlines) + has_parent = [False] * n + + for i in range(n): + for hole_idx in grouping_result[i]["holes"]: + has_parent[hole_idx] = True + + roots = [i for i in range(n) if not has_parent[i]] + + def write_tree(node_idx: int, depth: int = 0): + indent = " " * depth + outline = outlines[node_idx] + classification = grouper_instance.classify_outline_color(outline) + + # Get bounding box + try: + min_x, max_x, min_y, max_y = grouper_instance.get_outline_bounding_box(outline) + bbox_info = f"bbox: [{min_x:.3f}, {max_x:.3f}] x [{min_y:.3f}, {max_y:.3f}]" + except Exception: + bbox_info = "bbox: N/A" + + file_obj.write(f"{indent}└─ Outline {node_idx} ({outline.color.name}, {classification}, {bbox_info})\n") + + for hole_idx in grouping_result[node_idx]["holes"]: + write_tree(hole_idx, depth + 1) + + if not roots: + file_obj.write("No root outlines found (all outlines have parents)\n") + else: + for root_idx in roots: + write_tree(root_idx) + + file_obj.write("\n") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/outline_preprocessor_debug_writer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/outline_preprocessor_debug_writer.py new file mode 100644 index 0000000..82f34b1 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/outline_preprocessor_debug_writer.py @@ -0,0 +1,216 @@ +""" +Debug writer for outline preprocessing operations. +Captures processing order, created entities, and physical group assignments. +""" + +import os +from typing import List, Dict, Any +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.infrastructure.outline_preprocessor import OutlinePreprocessor + + +class OutlinePreprocessorDebugWriter: + """Debug writer for outline preprocessing operations.""" + def __init__(self): + """Initialize the debug writer.""" + self._shared_timestamp = None + self._svg_file_path = None + self._svg_name = None + + def set_shared_timestamp(self, timestamp: str): + """Set the shared timestamp for debug outputs.""" + self._shared_timestamp = timestamp + + def set_svg_file(self, svg_file_path: str): + """Set the SVG file being processed.""" + self._svg_file_path = svg_file_path + svg_filename = os.path.basename(svg_file_path) + self._svg_name = os.path.splitext(svg_filename)[0] + + def write_preprocessing_debug_info( + self, + svg_file_path: str, + outlines: List[Outline], + preprocessor_instance: 'OutlinePreprocessor', + gmsh_results: Dict[str, Any] + ) -> str: + """ + Write debug information for outline preprocessing. + + Args: + svg_file_path: Path to the SVG file being processed + outlines: List of outlines that were preprocessed + preprocessor_instance: The OutlinePreprocessor instance used + gmsh_results: Results dictionary from ConvertGeometryToGmsh.execute() + + Returns: + Path to the generated debug file + """ + self.set_svg_file(svg_file_path) + + if not self._shared_timestamp: + raise ValueError("Shared timestamp not set. Call set_shared_timestamp() first.") + + # Create debug directory + debug_dir = "debug" + os.makedirs(debug_dir, exist_ok=True) + + # Generate debug filename + debug_file = self._get_debug_filename("outline_preprocessing_debug") + + # Write debug information + with open(debug_file, 'w') as f: + self._write_header(f, outlines) + self._write_processing_order(f, preprocessor_instance, outlines) + self._write_entity_summary(f, preprocessor_instance) + self._write_physical_groups(f, preprocessor_instance) + + print(f"Outline preprocessing debug information written to: {debug_file}") + + return debug_file + + def _get_debug_filename(self, prefix: str, extension: str = ".txt") -> str: + """Generate a debug filename with timestamp.""" + if not self._svg_name: + raise ValueError("SVG file not set. Call set_svg_file() first.") + + debug_dir = "debug" + return f"{debug_dir}/{prefix}_{self._svg_name}_{self._shared_timestamp}{extension}" + + def _write_header(self, file_obj, outlines: List[Outline]): + """Write a header section to the debug file.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("OUTLINE PREPROCESSING DEBUG\n") + file_obj.write("=" * 80 + "\n\n") + file_obj.write(f"SVG File: {self._svg_file_path}\n") + file_obj.write(f"Timestamp: {self._shared_timestamp}\n") + file_obj.write(f"Total Outlines: {len(outlines)}\n") + file_obj.write(f"Output Mesh: {self._svg_name}.msh\n\n") + + def _write_processing_order(self, file_obj, preprocessor_instance, outlines: List[Outline]): + """Write the processing order (innermost to outermost).""" + file_obj.write("=" * 80 + "\n") + file_obj.write("PROCESSING ORDER (INNERMOST TO OUTERMOST)\n") + file_obj.write("=" * 80 + "\n\n") + + try: + processing_order = preprocessor_instance.get_processing_order() + if processing_order: + for i, outline_idx in enumerate(processing_order): + if 0 <= outline_idx < len(outlines): + outline = outlines[outline_idx] + file_obj.write(f"{i+1}. Outline {outline_idx} ({outline.color.name}):\n") + file_obj.write(f" - Segments: {len(outline.bezier_segments)}\n") + file_obj.write(f" - Control Points: {len(outline.control_points)}\n") + file_obj.write(f" - Unique Points: {len(outline.unique_control_points)}\n") + file_obj.write(f" - Is Closed: {outline.is_closed}\n") + + # Get curve loop tag if available + try: + curve_loop_tag = preprocessor_instance.get_curve_loop_tag(outline_idx) + file_obj.write(f" - Curve Loop Tag: {curve_loop_tag}\n") + except (KeyError, AttributeError): + pass + file_obj.write("\n") + else: + file_obj.write("No processing order available (using input order)\n") + except AttributeError: + file_obj.write("Processing order not available in preprocessor instance\n") + + file_obj.write("\n") + + def _write_entity_summary(self, file_obj, preprocessor_instance): + """Write summary of created entities (points, curves, surfaces).""" + file_obj.write("=" * 80 + "\n") + file_obj.write("ENTITY CREATION SUMMARY\n") + file_obj.write("=" * 80 + "\n\n") + + # Try to access internal tracking (if attributes exist) + try: + # Points + if hasattr(preprocessor_instance, '_created_points'): + points_count = len(preprocessor_instance._created_points) + file_obj.write(f"Created Points: {points_count}\n") + # Write first few points as example + file_obj.write(" Sample Points (Point -> Gmsh Tag):\n") + for point, tag in list(preprocessor_instance._created_points.items())[:5]: + file_obj.write(f" ({point.x:.6f}, {point.y:.6f}) -> {tag}\n") + if points_count > 5: + file_obj.write(f" ... and {points_count - 5} more points\n") + file_obj.write("\n") + + # Curve tags per outline + if hasattr(preprocessor_instance, '_curve_tags_per_outline'): + total_curves = 0 + for idx, curve_tags in preprocessor_instance._curve_tags_per_outline.items(): + total_curves += len(curve_tags) + file_obj.write(f"Total Created Curves: {total_curves}\n") + + file_obj.write(" Curves per Outline:\n") + for idx, curve_tags in preprocessor_instance._curve_tags_per_outline.items(): + file_obj.write(f" Outline {idx}: {len(curve_tags)} curves (tags: {curve_tags})\n") + file_obj.write("\n") + + # Curve loops + if hasattr(preprocessor_instance, '_curve_loops'): + file_obj.write(f"Created Curve Loops: {len(preprocessor_instance._curve_loops)}\n") + for idx, loop_tag in preprocessor_instance._curve_loops.items(): + file_obj.write(f" Outline {idx}: Curve Loop Tag {loop_tag}\n") + file_obj.write("\n") + + # Surfaces + if hasattr(preprocessor_instance, '_surface_tags'): + file_obj.write(f"Created Surfaces: {len(preprocessor_instance._surface_tags)}\n") + for idx, surface_tag in preprocessor_instance._surface_tags.items(): + file_obj.write(f" Outline {idx}: Surface Tag {surface_tag}\n") + file_obj.write("\n") + + except Exception as e: + file_obj.write(f"Unable to extract entity details: {e}\n") + + file_obj.write("\n") + + def _write_physical_groups(self, file_obj, preprocessor_instance): + """Write physical group assignments.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("PHYSICAL GROUP ASSIGNMENTS\n") + file_obj.write("=" * 80 + "\n\n") + + try: + # Try to get the summary from the preprocessor instance + if hasattr(preprocessor_instance, 'get_physical_group_summary'): + summary = preprocessor_instance.get_physical_group_summary() + file_obj.write(summary + "\n") + else: + file_obj.write("Physical group summary method not available\n") + + # Also try to access internal tracking if available + if hasattr(preprocessor_instance, '_physical_groups_by_type'): + pg_by_type = preprocessor_instance._physical_groups_by_type + + # Boundary groups (1D curves) + boundary_groups = pg_by_type.get('boundary', {}) + file_obj.write(f"Boundary Physical Groups (1D): {len(boundary_groups)}\n") + for pg_value, curve_tags in boundary_groups.items(): + unique_tags = list(dict.fromkeys(curve_tags)) + file_obj.write(f" Tag {pg_value}: {len(unique_tags)} curves\n") + if len(unique_tags) <= 10: # Don't list all tags if too many + file_obj.write(f" Curve tags: {unique_tags}\n") + else: + file_obj.write(f" Curve tags: {unique_tags[:5]} ... and {len(unique_tags)-5} more\n") + file_obj.write("\n") + + # Domain groups (2D surfaces) + domain_groups = pg_by_type.get('domain', {}) + file_obj.write(f"Domain Physical Groups (2D): {len(domain_groups)}\n") + for pg_value, surface_tags in domain_groups.items(): + unique_tags = list(dict.fromkeys(surface_tags)) + file_obj.write(f" Tag {pg_value}: {len(unique_tags)} surfaces\n") + file_obj.write(f" Surface tags: {unique_tags}\n") + file_obj.write("\n") + + except Exception as e: + file_obj.write(f"Unable to extract physical group details: {e}\n") + + file_obj.write("\n") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/svg_parser_debug_writer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/svg_parser_debug_writer.py new file mode 100644 index 0000000..11ebb0d --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/svg_parser_debug_writer.py @@ -0,0 +1,66 @@ +from datetime import datetime +from svg_to_getdp.interfaces.debug.debug_coordinator import DebugCoordinator + + +class SVGParserDebugWriter(DebugCoordinator): + """Handles writing debug information for SVG parsing.""" + + def __init__(self): + super().__init__() + + def write_svg_parser_debug_info(self, svg_file_path: str, colored_raw_outlines: dict): + """ + Write SVG parser results to a debug text file. + """ + self.set_svg_file(svg_file_path) + debug_filename = self.get_debug_filename("svg_parser_debug", ".txt") + + with open(debug_filename, 'w') as f: + f.write(f"SVG Parser Debug Information\n") + f.write(f"============================\n") + f.write(f"Input SVG: {svg_file_path}\n") + f.write(f"Processed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Debug run timestamp: {self.get_shared_timestamp()}\n") + f.write(f"\n") + + total_raw_outlines = 0 + for color, raw_outlines in colored_raw_outlines.items(): + f.write(f"Color: {color}\n") + f.write(f"Number of raw_outlines: {len(raw_outlines)}\n") + total_raw_outlines += len(raw_outlines) + + for i, raw_outline in enumerate(raw_outlines): + f.write(f" raw_outline {i+1}:\n") + f.write(f" Is closed: {raw_outline.is_closed}\n") + f.write(f" Number of points: {len(raw_outline.points)}\n") + f.write(f" Points:\n") + + for j, point in enumerate(raw_outline.points): + f.write(f" [{j}] x={point.x:.6f}, y={point.y:.6f}\n") + + # Calculate bounding box + if raw_outline.points: + x_coords = [p.x for p in raw_outline.points] + y_coords = [p.y for p in raw_outline.points] + f.write(f" Bounding box: x=[{min(x_coords):.6f}, {max(x_coords):.6f}], " + f"y=[{min(y_coords):.6f}, {max(y_coords):.6f}]\n") + + f.write(f"\n") + + f.write(f"\n") + + f.write(f"Total raw_outlines processed: {total_raw_outlines}\n") + f.write(f"\n") + + # Summary statistics + f.write(f"Summary by color:\n") + for color, raw_outlines in colored_raw_outlines.items(): + total_points = sum(len(raw_outline.points) for raw_outline in raw_outlines) + avg_points = total_points / len(raw_outlines) if raw_outlines else 0 + closed_count = sum(1 for raw_outline in raw_outlines if raw_outline.is_closed) + + f.write(f" {color}: {len(raw_outlines)} raw_outlines, {total_points} total points, " + f"{avg_points:.1f} avg points, {closed_count} closed\n") + + print(f"SVG parser debug information written to: {debug_filename}") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/debug/wire_preprocessor_debug_writer.py b/sketchgetdp/svg_to_getdp/interfaces/debug/wire_preprocessor_debug_writer.py new file mode 100644 index 0000000..fcb1e38 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/interfaces/debug/wire_preprocessor_debug_writer.py @@ -0,0 +1,348 @@ +""" +Debug writer for wire preprocessor operations. +Captures wire sorting, clustering, and Gmsh entity creation. +""" + +import os +import math +from typing import List, Tuple, Dict, Any +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.interfaces.abstractions.wire_preprocessor_interface import WirePreprocessorInterface + + +class WirePreprocessorDebugWriter: + """Debug writer for wire preprocessor operations.""" + + def __init__(self): + """Initialize the debug writer.""" + self._shared_timestamp = None + self._svg_file_path = None + self._svg_name = None + + def set_shared_timestamp(self, timestamp: str): + """Set the shared timestamp for debug outputs.""" + self._shared_timestamp = timestamp + + def set_svg_file(self, svg_file_path: str): + """Set the SVG file being processed.""" + self._svg_file_path = svg_file_path + svg_filename = os.path.basename(svg_file_path) + self._svg_name = os.path.splitext(svg_filename)[0] + + def write_wire_preprocessor_debug_info( + self, + svg_file_path: str, + wires: List[Tuple[Point, Color]], + config_file_path: str, + wire_preprocessor_instance: WirePreprocessorInterface, + gmsh_results: Dict[str, Any] + ) -> str: + """ + Write debug information for wire preprocessor operations. + + Args: + svg_file_path: Path to the SVG file being processed + wires: Original list of (point, color) tuples representing wires + config_file_path: Path to the YAML configuration file + wire_preprocessor_instance: The WirePreprocessor instance used + gmsh_results: Full Gmsh results dictionary from ConvertGeometryToGmsh.execute() + + Returns: + Path to the generated debug file + """ + self.set_svg_file(svg_file_path) + + if not self._shared_timestamp: + raise ValueError("Shared timestamp not set. Call set_shared_timestamp() first.") + + # Create debug directory + debug_dir = "debug" + os.makedirs(debug_dir, exist_ok=True) + + # Generate debug filename + debug_file = self._get_debug_filename("wire_preprocessor_debug") + + # Extract wire_results from gmsh_results + wire_results = gmsh_results.get("wire_results", {}) + + # Write debug information + with open(debug_file, 'w') as f: + self._write_header(f, wires, config_file_path, gmsh_results) + self._write_configuration_summary(f, config_file_path, wire_preprocessor_instance) + self._write_wire_sorting_info(f, wires, wire_preprocessor_instance) + self._write_clustering_info(f, wires, wire_preprocessor_instance, wire_results) + self._write_gmsh_entity_info(f, wire_results) + self._write_cluster_statistics(f, wire_results) + + print(f"Wire preprocessor debug information written to: {debug_file}") + + return debug_file + + def _get_debug_filename(self, prefix: str, extension: str = ".txt") -> str: + """Generate a debug filename with timestamp.""" + if not self._svg_name: + raise ValueError("SVG file not set. Call set_svg_file() first.") + + debug_dir = "debug" + return f"{debug_dir}/{prefix}_{self._svg_name}_{self._shared_timestamp}{extension}" + + def _write_header(self, file_obj, wires: List[Tuple[Point, Color]], config_file_path: str, gmsh_results: Dict): + """Write a header section to the debug file.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("WIRE PREPROCESSOR DEBUG\n") + file_obj.write("=" * 80 + "\n\n") + file_obj.write(f"SVG File: {self._svg_file_path}\n") + file_obj.write(f"Config File: {config_file_path}\n") + file_obj.write(f"Timestamp: {self._shared_timestamp}\n") + file_obj.write(f"Total Wires: {len(wires)}\n") + + # Count wires by color + color_count = {} + for _, color in wires: + color_name = color.name + color_count[color_name] = color_count.get(color_name, 0) + 1 + + for color_name, count in color_count.items(): + file_obj.write(f" - {color_name}: {count}\n") + + file_obj.write("\n") + + def _write_configuration_summary(self, file_obj, config_file_path: str, wire_preprocessor: WirePreprocessorInterface): + """Write wire cluster configuration from YAML.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("WIRE CLUSTER CONFIGURATION\n") + file_obj.write("=" * 80 + "\n\n") + + try: + # Try to access internal method to load clusters + if hasattr(wire_preprocessor, '_load_wire_clusters'): + clusters = wire_preprocessor._load_wire_clusters(config_file_path) + + total_wires_configured = sum(cluster.wire_count for cluster in clusters) + file_obj.write(f"Total wires in configuration: {total_wires_configured}\n") + file_obj.write(f"Number of clusters: {len(clusters)}\n\n") + + for i, cluster in enumerate(clusters): + polarity = "+" if cluster.current_sign == 1 else "-" + file_obj.write(f"Cluster {i+1}: {cluster.name}\n") + file_obj.write(f" Wire count: {cluster.wire_count}\n") + file_obj.write(f" Current sign: {cluster.current_sign} ({polarity})\n\n") + else: + file_obj.write("Unable to extract cluster configuration: _load_wire_clusters method not found\n") + + except Exception as e: + file_obj.write(f"Error loading cluster configuration: {e}\n") + + file_obj.write("\n") + + def _write_wire_sorting_info(self, file_obj, wires: List[Tuple[Point, Color]], wire_preprocessor: WirePreprocessorInterface): + """Write information about wire sorting order.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("WIRE SORTING (TOP-TO-BOTTOM, LEFT-TO-RIGHT)\n") + file_obj.write("=" * 80 + "\n\n") + + try: + # Convert to Wire objects if needed + wire_objects = [] + for i, (point, color) in enumerate(wires): + # Create a simple wire-like object + class SimpleWire: + def __init__(self, point, color, index): + self.point = point + self.color = color + self.original_index = index + + wire_objects.append(SimpleWire(point, color, i)) + + # Try to sort using the preprocessor's method + if hasattr(wire_preprocessor, '_sort_wires'): + sorted_wires = wire_preprocessor._sort_wires(wire_objects) + + file_obj.write("Sorted wire order:\n") + file_obj.write(f"{'Index':<8} {'Original':<10} {'X':<12} {'Y':<12} {'Color':<10}\n") + file_obj.write("-" * 52 + "\n") + + for i, wire in enumerate(sorted_wires): + file_obj.write(f"{i+1:<8} {wire.original_index:<10} {wire.point.x:<12.6f} {wire.point.y:<12.6f} {wire.color.name:<10}\n") + else: + # Manual sorting + sorted_wires = sorted(wire_objects, key=lambda w: (-w.point.y, w.point.x)) + file_obj.write("Wires sorted manually (preprocessor method not available):\n") + file_obj.write(f"{'Index':<8} {'Original':<10} {'X':<12} {'Y':<12} {'Color':<10}\n") + file_obj.write("-" * 52 + "\n") + + for i, wire in enumerate(sorted_wires): + file_obj.write(f"{i+1:<8} {wire.original_index:<10} {wire.point.x:<12.6f} {wire.point.y:<12.6f} {wire.color.name:<10}\n") + + except Exception as e: + file_obj.write(f"Error during wire sorting debug: {e}\n") + + file_obj.write("\n") + + def _write_clustering_info(self, file_obj, wires: List[Tuple[Point, Color]], wire_preprocessor: WirePreprocessorInterface, wire_results: Dict): + """Write information about proximity-based clustering.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("WIRE CLUSTERING BY PROXIMITY\n") + file_obj.write("=" * 80 + "\n\n") + + if not wire_results: + file_obj.write("No wire results available\n\n") + return + + try: + # Group wires by cluster + clusters = {} + for wire_data in wire_results.values(): + cluster_name = wire_data['cluster_name'] + if cluster_name not in clusters: + clusters[cluster_name] = { + 'current_sign': wire_data['physical_group'].current_sign if hasattr(wire_data['physical_group'], 'current_sign') else None, + 'wire_count': 0, + 'wires': [], + 'positions': [], + 'original_indices': [] + } + clusters[cluster_name]['wire_count'] += 1 + clusters[cluster_name]['wires'].append(wire_data['wire_name']) + clusters[cluster_name]['positions'].append( + (wire_data['point'].x, wire_data['point'].y) + ) + clusters[cluster_name]['original_indices'].append(wire_data['wire_index']) + + # Write cluster summary + file_obj.write(f"Clusters created: {len(clusters)}\n\n") + + for cluster_name, cluster_info in clusters.items(): + polarity = "+" if cluster_info['current_sign'] == 1 else "-" + file_obj.write(f"Cluster '{cluster_name}' ({polarity}):\n") + file_obj.write(f" Wire count: {cluster_info['wire_count']}\n") + file_obj.write(f" Wire names: {', '.join(cluster_info['wires'])}\n") + file_obj.write(f" Original indices: {cluster_info['original_indices']}\n") + + # Calculate intra-cluster distances + if cluster_info['wire_count'] > 1: + positions = cluster_info['positions'] + max_distance = 0 + min_distance = float('inf') + + for i in range(len(positions)): + for j in range(i+1, len(positions)): + x1, y1 = positions[i] + x2, y2 = positions[j] + distance = math.sqrt((x1-x2)**2 + (y1-y2)**2) + max_distance = max(max_distance, distance) + min_distance = min(min_distance, distance) + + file_obj.write(f" Intra-cluster distances:\n") + file_obj.write(f" Max: {max_distance:.6f}\n") + file_obj.write(f" Min: {min_distance:.6f}\n") + + file_obj.write("\n") + + except Exception as e: + file_obj.write(f"Error extracting clustering info: {e}\n") + + file_obj.write("\n") + + def _write_gmsh_entity_info(self, file_obj, wire_results: Dict): + """Write information about Gmsh entities created.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("GMSH ENTITY CREATION\n") + file_obj.write("=" * 80 + "\n\n") + + if not wire_results: + file_obj.write("No wire results available\n\n") + return + + # Count positive and negative wires + positive_wires = [] + negative_wires = [] + + for wire_data in wire_results.values(): + if wire_data['physical_group'].current_sign == 1: + positive_wires.append(wire_data) + else: + negative_wires.append(wire_data) + + file_obj.write(f"Positive wires (+): {len(positive_wires)}\n") + file_obj.write(f"Negative wires (-): {len(negative_wires)}\n\n") + + # Write Gmsh point tags + file_obj.write("Gmsh Point Tags:\n") + file_obj.write(f"{'Wire':<12} {'Original':<10} {'Gmsh Tag':<12} {'Physical Group':<15} {'Cluster':<15} {'Position':<25}\n") + file_obj.write("-" * 90 + "\n") + + for wire_data in wire_results.values(): + polarity = "+" if wire_data['physical_group'].current_sign == 1 else "-" + file_obj.write( + f"{wire_data['wire_name']:<12} " + f"{wire_data['wire_index']:<10} " + f"{wire_data['gmsh_point_tag']:<12} " + f"{wire_data['physical_group'].name:<15} " + f"{wire_data['cluster_name']:<15} " + f"({wire_data['point'].x:.6f}, {wire_data['point'].y:.6f})\n" + ) + + file_obj.write("\n") + + def _write_cluster_statistics(self, file_obj, wire_results: Dict): + """Write detailed cluster statistics.""" + file_obj.write("=" * 80 + "\n") + file_obj.write("CLUSTER STATISTICS\n") + file_obj.write("=" * 80 + "\n\n") + + if not wire_results: + file_obj.write("No wire results available\n\n") + return + + # Group wires by cluster + clusters = {} + for wire_data in wire_results.values(): + cluster_name = wire_data['cluster_name'] + if cluster_name not in clusters: + clusters[cluster_name] = { + 'current_sign': wire_data['physical_group'].current_sign if hasattr(wire_data['physical_group'], 'current_sign') else None, + 'wires': [], + 'positions': [] + } + clusters[cluster_name]['wires'].append(wire_data) + clusters[cluster_name]['positions'].append( + (wire_data['point'].x, wire_data['point'].y) + ) + + # Calculate statistics for each cluster + for cluster_name, cluster_info in clusters.items(): + polarity = "+" if cluster_info['current_sign'] == 1 else "-" + positions = cluster_info['positions'] + + # Calculate cluster center + avg_x = sum(p[0] for p in positions) / len(positions) + avg_y = sum(p[1] for p in positions) / len(positions) + + # Calculate distances from center + distances = [] + for x, y in positions: + distance = math.sqrt((x - avg_x)**2 + (y - avg_y)**2) + distances.append(distance) + + max_distance = max(distances) if distances else 0 + min_distance = min(distances) if distances else 0 + avg_distance = sum(distances) / len(distances) if distances else 0 + + file_obj.write(f"Cluster '{cluster_name}' ({polarity}):\n") + file_obj.write(f" Wire count: {len(positions)}\n") + file_obj.write(f" Center: ({avg_x:.6f}, {avg_y:.6f})\n") + file_obj.write(f" Distance from center:\n") + file_obj.write(f" Max: {max_distance:.6f}\n") + file_obj.write(f" Min: {min_distance:.6f}\n") + file_obj.write(f" Avg: {avg_distance:.6f}\n") + + # Wire distances relative to center + file_obj.write(f" Wire distances from center:\n") + for wire_data, distance in zip(cluster_info['wires'], distances): + file_obj.write(f" {wire_data['wire_name']}: {distance:.6f} " + f"at ({wire_data['point'].x:.6f}, {wire_data['point'].y:.6f})\n") + + file_obj.write("\n") + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/interfaces/mesher/__init__.py b/sketchgetdp/svg_to_getdp/interfaces/mesher/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/geometry/gmsh_toolbox.py b/sketchgetdp/svg_to_getdp/interfaces/mesher/gmsh_toolbox.py similarity index 100% rename from sketchgetdp/geometry/gmsh_toolbox.py rename to sketchgetdp/svg_to_getdp/interfaces/mesher/gmsh_toolbox.py diff --git a/sketchgetdp/svg_to_getdp/interfaces/solver/__init__.py b/sketchgetdp/svg_to_getdp/interfaces/solver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sketchgetdp/solver/getdp_toolbox.py b/sketchgetdp/svg_to_getdp/interfaces/solver/getdp_toolbox.py similarity index 73% rename from sketchgetdp/solver/getdp_toolbox.py rename to sketchgetdp/svg_to_getdp/interfaces/solver/getdp_toolbox.py index 70704fa..03f6eac 100644 --- a/sketchgetdp/solver/getdp_toolbox.py +++ b/sketchgetdp/svg_to_getdp/interfaces/solver/getdp_toolbox.py @@ -4,30 +4,68 @@ Author: Laura D'Angelo """ +import os +import sys + +# Add the project root directory to Python path +project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.insert(0, project_root) + import gmsh import numpy as np -from sketchgetdp.geometry import gmsh_toolbox as geo +from svg_to_getdp.interfaces.mesher import gmsh_toolbox as geo +import os -def get_getdp_path(filename: str) -> str: +def get_getdp_path(filename: str = "./../../../../getdp_path.txt") -> str: """ Returns the path for running GetDP on the respective computer. Parameters: - filename (str): file name + filename (str): file name containing GetDP path. + Can be absolute path or relative to this script. Returns: str: path to GetDP executable """ + # Get the directory where this script (getdp_toolbox.py) is located + script_dir = os.path.dirname(os.path.abspath(__file__)) + + # Check if filename is an absolute path + if os.path.isabs(filename): + path_file = filename + else: + # If relative, resolve it relative to this script's location + # Join with script_dir, then normalize the path + path_file = os.path.join(script_dir, filename) + path_file = os.path.normpath(path_file) + try: - file = open(filename, 'r') + with open(path_file, 'r') as file: + path = file.readline().strip() + if path: + print(f"Found GetDP path at: {path_file}") + return path + else: + raise ValueError(f"{filename} is empty") except FileNotFoundError: - message = 'Error: ' + filename + " not found. You have to create this file and give the path of your GetDP executable." + # Provide helpful error message showing what we tried + message = f"""Error: Could not find GetDP path file. + +Tried to open: {path_file} +This script is located at: {script_dir} + +You need to ensure 'getdp_path.txt' exists at the expected location. +Current expected location (relative to script): {filename} + +Please check that: +1. The file exists at: {path_file} +2. Or update the filename parameter in get_getdp_path() call +""" + exit(message) + except Exception as e: + message = f"Error reading {path_file}: {str(e)}" exit(message) - data = file.readlines() - path = data[0].split('\n') - file.close() - return path[0] def physical_identifiers() -> dict: @@ -81,7 +119,7 @@ def run_magnetostatic_simulation(msh_name: str, show_simulation_result: bool = T pro_name = "rmvp_formulation.pro" resolution_name = "Magnetostatic_Resolution" gmsh.open(pro_name) - getdp_path = get_getdp_path("./../../getdp_path.txt") + getdp_path = get_getdp_path() onelab_command = getdp_path + " " + pro_name + " -msh " + msh_name + " -solve " + resolution_name + " -pos" gmsh.onelab.run("GetDP", onelab_command) if show_simulation_result: diff --git a/sketchgetdp/svg_to_getdp/pytest.ini b/sketchgetdp/svg_to_getdp/pytest.ini new file mode 100644 index 0000000..decb2e8 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +pythonpath = . +testpaths = tests \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/core/entities/test_bezier_segment.py b/sketchgetdp/svg_to_getdp/tests/core/entities/test_bezier_segment.py new file mode 100644 index 0000000..d392097 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/entities/test_bezier_segment.py @@ -0,0 +1,303 @@ +""" +Unit tests for BezierSegment class. + +Tests Bézier segment functionality including creation, evaluation, +derivative calculation, and geometric properties. +""" +import pytest +from core.entities.point import Point +from core.entities.bezier_segment import BezierSegment + + +class TestBezierSegment: + """Test suite for BezierSegment class.""" + + # ==================== Initialization Tests ==================== + + def test_bezier_segment_creation_linear(self): + """Test creation of linear Bézier segment (degree 1).""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + + assert segment.degree == 1 + assert segment.control_points == [p0, p1] + assert segment.start_point == p0 + assert segment.end_point == p1 + + def test_bezier_segment_creation_quadratic(self): + """Test creation of quadratic Bézier segment (degree 2).""" + p0 = Point(0.0, 0.0) + p1 = Point(0.5, 1.0) + p2 = Point(1.0, 0.0) + segment = BezierSegment([p0, p1, p2], degree=2) + + assert segment.degree == 2 + assert segment.control_points == [p0, p1, p2] + + def test_bezier_segment_creation_cubic(self): + """Test creation of cubic Bézier segment (degree 3).""" + p0 = Point(0.0, 0.0) + p1 = Point(0.33, 1.0) + p2 = Point(0.66, 1.0) + p3 = Point(1.0, 0.0) + segment = BezierSegment([p0, p1, p2, p3], degree=3) + + assert segment.degree == 3 + assert segment.control_points == [p0, p1, p2, p3] + + def test_invalid_control_points_count(self): + """Test that invalid control point count raises error.""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + + with pytest.raises(ValueError, match="Degree 2 requires 3 control points"): + BezierSegment([p0, p1], degree=2) + + with pytest.raises(ValueError, match="Degree 1 requires 2 control points"): + BezierSegment([p0], degree=1) + + # ==================== Evaluation Tests ==================== + + def test_linear_bezier_evaluation(self): + """Test evaluation of linear Bézier curve.""" + p0 = Point(0.0, 0.0) + p1 = Point(2.0, 2.0) + segment = BezierSegment([p0, p1], degree=1) + + # Test start point + result_start = segment.evaluate(0.0) + assert result_start.x == pytest.approx(p0.x) + assert result_start.y == pytest.approx(p0.y) + + # Test end point + result_end = segment.evaluate(1.0) + assert result_end.x == pytest.approx(p1.x) + assert result_end.y == pytest.approx(p1.y) + + # Test midpoint + midpoint = segment.evaluate(0.5) + assert midpoint.x == pytest.approx(1.0) + assert midpoint.y == pytest.approx(1.0) + + def test_quadratic_bezier_evaluation(self): + """Test evaluation of quadratic Bézier curve.""" + p0 = Point(0.0, 0.0) + p1 = Point(0.5, 1.0) + p2 = Point(1.0, 0.0) + segment = BezierSegment([p0, p1, p2], degree=2) + + # Test start and end points + result_start = segment.evaluate(0.0) + assert result_start.x == pytest.approx(p0.x) + assert result_start.y == pytest.approx(p0.y) + + result_end = segment.evaluate(1.0) + assert result_end.x == pytest.approx(p2.x) + assert result_end.y == pytest.approx(p2.y) + + # Test midpoint + midpoint = segment.evaluate(0.5) + assert midpoint.x == pytest.approx(0.5) + assert midpoint.y == pytest.approx(0.5) + + def test_evaluation_parameter_range(self): + """Test that evaluation only works for t in [0,1].""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0, 1\\]"): + segment.evaluate(-0.1) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0, 1\\]"): + segment.evaluate(1.1) + + # ==================== Bernstein Basis Tests ==================== + + def test_bernstein_basis_calculation(self): + """Test Bernstein basis polynomial calculation.""" + segment = BezierSegment([Point(0.0, 0.0), Point(1.0, 1.0)], degree=1) + + # For degree 1, Bernstein basis should be linear + assert segment.bernstein_basis(0, 0.0) == 1.0 + assert segment.bernstein_basis(0, 1.0) == 0.0 + assert segment.bernstein_basis(1, 0.0) == 0.0 + assert segment.bernstein_basis(1, 1.0) == 1.0 + assert segment.bernstein_basis(0, 0.5) == pytest.approx(0.5) + assert segment.bernstein_basis(1, 0.5) == pytest.approx(0.5) + + def test_bernstein_basis_invalid_index(self): + """Test that invalid Bernstein basis index raises error.""" + segment = BezierSegment([Point(0.0, 0.0), Point(1.0, 1.0)], degree=1) + + with pytest.raises(ValueError, match="Index i must be between 0 and 1"): + segment.bernstein_basis(2, 0.5) + + with pytest.raises(ValueError, match="Index i must be between 0 and 1"): + segment.bernstein_basis(-1, 0.5) + + # ==================== Derivative Tests ==================== + + def test_linear_bezier_derivative(self): + """Test derivative calculation for linear Bézier.""" + p0 = Point(0.0, 0.0) + p1 = Point(2.0, 2.0) + segment = BezierSegment([p0, p1], degree=1) + + # Derivative of linear Bézier is constant + derivative = segment.derivative(0.5) + expected_x = 2.0 # p1.x - p0.x + expected_y = 2.0 # p1.y - p0.y + + assert derivative.x == pytest.approx(expected_x) + assert derivative.y == pytest.approx(expected_y) + + # Should be same at all points + deriv_start = segment.derivative(0.0) + assert deriv_start.x == pytest.approx(expected_x) + assert deriv_start.y == pytest.approx(expected_y) + + deriv_end = segment.derivative(1.0) + assert deriv_end.x == pytest.approx(expected_x) + assert deriv_end.y == pytest.approx(expected_y) + + def test_quadratic_bezier_derivative(self): + """Test derivative calculation for quadratic Bézier.""" + p0 = Point(0.0, 0.0) + p1 = Point(0.5, 1.0) + p2 = Point(1.0, 0.0) + segment = BezierSegment([p0, p1, p2], degree=2) + + # Test derivative at start + deriv_start = segment.derivative(0.0) + # For quadratic: 2 * (p1 - p0) at t=0 + expected_start_x = 1.0 # 2 * (0.5 - 0) + expected_start_y = 2.0 # 2 * (1 - 0) + assert deriv_start.x == pytest.approx(expected_start_x) + assert deriv_start.y == pytest.approx(expected_start_y) + + # Test derivative at end + deriv_end = segment.derivative(1.0) + # For quadratic: 2 * (p2 - p1) at t=1 + expected_end_x = 1.0 # 2 * (1 - 0.5) + expected_end_y = -2.0 # 2 * (0 - 1) + assert deriv_end.x == pytest.approx(expected_end_x) + assert deriv_end.y == pytest.approx(expected_end_y) + + # Test derivative at midpoint + deriv_mid = segment.derivative(0.5) + # For quadratic: 2 * ((1-t)*(p1-p0) + t*(p2-p1)) at t=0.5 + expected_mid_x = 1.0 + expected_mid_y = 0.0 + assert deriv_mid.x == pytest.approx(expected_mid_x) + assert deriv_mid.y == pytest.approx(expected_mid_y) + + def test_degree_zero_bezier_derivative(self): + """Test derivative of degree 0 Bézier (constant point).""" + p0 = Point(1.0, 2.0) + segment = BezierSegment([p0], degree=0) + + # Derivative of constant should be zero + derivative = segment.derivative(0.5) + assert derivative.x == pytest.approx(0.0) + assert derivative.y == pytest.approx(0.0) + + def test_derivative_parameter_range(self): + """Test that derivative only works for t in [0,1].""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0, 1\\]"): + segment.derivative(-0.1) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0, 1\\]"): + segment.derivative(1.1) + + # ==================== Sampling Tests ==================== + + def test_get_curve_points(self): + """Test sampling multiple points along the curve.""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + + points = segment.get_curve_points(num_points=3) + + assert len(points) == 3 + assert points[0].x == pytest.approx(p0.x) + assert points[0].y == pytest.approx(p0.y) + assert points[1].x == pytest.approx(0.5) + assert points[1].y == pytest.approx(0.5) + assert points[2].x == pytest.approx(p1.x) + assert points[2].y == pytest.approx(p1.y) + + def test_get_curve_points_invalid_count(self): + """Test that invalid point count raises error.""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + + with pytest.raises(ValueError, match="Number of points must be at least 2"): + segment.get_curve_points(num_points=1) + + with pytest.raises(ValueError, match="Number of points must be at least 2"): + segment.get_curve_points(num_points=0) + + # ==================== Property Tests ==================== + + def test_straight_line_property(self): + """Test that linear Bézier creates straight lines.""" + p0 = Point(0.0, 0.0) + p1 = Point(10.0, 5.0) + segment = BezierSegment([p0, p1], degree=1) + + # All points should lie on the straight line between p0 and p1 + for t in [0.0, 0.25, 0.5, 0.75, 1.0]: + point = segment.evaluate(t) + expected_x = t * 10.0 + expected_y = t * 5.0 + assert point.x == pytest.approx(expected_x) + assert point.y == pytest.approx(expected_y) + + def test_convex_hull_property(self): + """Test that Bézier curve lies within convex hull of control points.""" + p0 = Point(0.0, 0.0) + p1 = Point(2.0, 3.0) + p2 = Point(4.0, 0.0) + segment = BezierSegment([p0, p1, p2], degree=2) + + # Sample multiple points and verify they're within the triangle + for t in [0.0, 0.25, 0.5, 0.75, 1.0]: + point = segment.evaluate(t) + assert 0.0 <= point.x <= 4.0 + assert 0.0 <= point.y <= 1.5 + + # ==================== Interface Tests ==================== + + def test_bezier_segment_equality(self): + """Test equality comparison between Bézier segments.""" + p0, p1 = Point(0.0, 0.0), Point(1.0, 1.0) + p2, p3 = Point(0.0, 0.0), Point(2.0, 2.0) + + segment1 = BezierSegment([p0, p1], degree=1) + segment2 = BezierSegment([p0, p1], degree=1) + segment3 = BezierSegment([p0, p3], degree=1) + segment4 = BezierSegment([p0, p1, p2], degree=2) + + assert segment1 == segment2 + assert segment1 != segment3 + assert segment1 != segment4 + assert segment1 != "not a segment" + + def test_bezier_segment_repr(self): + """Test string representation of Bézier segment.""" + p0 = Point(0.0, 0.0) + p1 = Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + + repr_str = repr(segment) + assert "BezierSegment" in repr_str + assert "degree=1" in repr_str + assert "control_points=2" in repr_str \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/core/entities/test_color.py b/sketchgetdp/svg_to_getdp/tests/core/entities/test_color.py new file mode 100644 index 0000000..f6023e3 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/entities/test_color.py @@ -0,0 +1,194 @@ +""" +Unit tests for Color class. + +Tests color creation, validation, conversion, and predefined color functionality. +""" +import pytest + +from core.entities.color import Color + + +class TestColor: + """Test suite for Color class.""" + + # ==================== Basic Functionality Tests ==================== + + def test_color_creation(self): + """Test that a color can be created with name and RGB values.""" + color = Color("red", (255, 0, 0)) + + assert color.name == "red" + assert color.rgb == (255, 0, 0) + + def test_predefined_colors(self): + """Test that predefined colors are available and correct.""" + assert Color.RED.name == "red" + assert Color.RED.rgb == (255, 0, 0) + + assert Color.GREEN.name == "green" + assert Color.GREEN.rgb == (0, 255, 0) + + assert Color.BLUE.name == "blue" + assert Color.BLUE.rgb == (0, 0, 255) + + assert Color.BLACK.name == "black" + assert Color.BLACK.rgb == (0, 0, 0) + + def test_color_equality(self): + """Test that colors with same name and RGB are equal.""" + color1 = Color("red", (255, 0, 0)) + color2 = Color("red", (255, 0, 0)) + color3 = Color("blue", (0, 0, 255)) + + assert color1 == color2 + assert color1 != color3 + + def test_color_hash(self): + """Test that colors are hashable.""" + color1 = Color("red", (255, 0, 0)) + color2 = Color("red", (255, 0, 0)) + color3 = Color("green", (0, 255, 0)) + color4 = Color("black", (0, 0, 0)) + + color_set = {color1, color2, color3, color4} + assert len(color_set) == 3 # color1 and color2 are duplicates + assert color1 in color_set + assert color2 in color_set + assert color3 in color_set + assert color4 in color_set + + # ==================== Immutability Tests ==================== + + def test_color_immutability(self): + """Test that Color is immutable.""" + color = Color("red", (255, 0, 0)) + + with pytest.raises(AttributeError): + color.name = "blue" + with pytest.raises(AttributeError): + color.rgb = (0, 0, 255) + + # ==================== String Representation Tests ==================== + + def test_color_repr(self): + """Test the string representation of Color.""" + color = Color("red", (255, 0, 0)) + repr_str = repr(color) + + assert "Color" in repr_str + assert "red" in repr_str + assert "(255, 0, 0)" in repr_str + + def test_color_str(self): + """Test the human-readable string representation.""" + color = Color("green", (0, 255, 0)) + str_repr = str(color) + + assert "Color" in str_repr + assert "green" in str_repr + assert "(0, 255, 0)" in str_repr + + # ==================== Conversion Methods Tests ==================== + + def test_to_hex(self): + """Test conversion to hexadecimal format.""" + assert Color.RED.to_hex() == "#ff0000" + assert Color.GREEN.to_hex() == "#00ff00" + assert Color.BLUE.to_hex() == "#0000ff" + assert Color.BLACK.to_hex() == "#000000" + + def test_to_normalized_rgb(self): + """Test conversion to normalized RGB values.""" + red_norm = Color.RED.to_normalized_rgb() + green_norm = Color.GREEN.to_normalized_rgb() + blue_norm = Color.BLUE.to_normalized_rgb() + black_norm = Color.BLACK.to_normalized_rgb() + + assert red_norm == (1.0, 0.0, 0.0) + assert green_norm == (0.0, 1.0, 0.0) + assert blue_norm == (0.0, 0.0, 1.0) + assert black_norm == (0.0, 0.0, 0.0) + + # Test with mid-range values using allowed color names + dark_red = Color("red", (128, 0, 0)) + dark_red_norm = dark_red.to_normalized_rgb() + expected_red = (128/255.0, 0.0, 0.0) + assert dark_red_norm == pytest.approx(expected_red) + + dark_green = Color("green", (0, 128, 0)) + dark_green_norm = dark_green.to_normalized_rgb() + expected_green = (0.0, 128/255.0, 0.0) + assert dark_green_norm == pytest.approx(expected_green) + + dark_black = Color("black", (64, 64, 64)) + dark_black_norm = dark_black.to_normalized_rgb() + expected_black = (64/255.0, 64/255.0, 64/255.0) + assert dark_black_norm == pytest.approx(expected_black) + + # ==================== Validation Tests ==================== + + def test_invalid_color_name(self): + """Test that color rejects invalid names.""" + with pytest.raises(ValueError, match="Color must be 'red', 'green', 'blue', or 'black'"): + Color("yellow", (255, 255, 0)) + with pytest.raises(ValueError, match="Color must be 'red', 'green', 'blue', or 'black'"): + Color("", (255, 0, 0)) + with pytest.raises(ValueError, match="Color must be 'red', 'green', 'blue', or 'black'"): + Color("RED", (255, 0, 0)) # case sensitive + with pytest.raises(ValueError, match="Color must be 'red', 'green', 'blue', or 'black'"): + Color("gray", (128, 128, 128)) + + def test_invalid_name_type(self): + """Test that color name must be a string.""" + with pytest.raises(TypeError, match="Color name must be a string"): + Color(123, (255, 0, 0)) + with pytest.raises(TypeError, match="Color name must be a string"): + Color(None, (255, 0, 0)) + + def test_invalid_rgb_format(self): + """Test that RGB must be a tuple of 3 integers.""" + with pytest.raises(ValueError, match="RGB must be a tuple of 3 integers"): + Color("red", [255, 0, 0]) # list instead of tuple + with pytest.raises(ValueError, match="RGB must be a tuple of 3 integers"): + Color("red", (255, 0)) # too few elements + with pytest.raises(ValueError, match="RGB must be a tuple of 3 integers"): + Color("red", (255, 0, 0, 0)) # too many elements + + def test_invalid_rgb_values(self): + """Test that RGB values must be between 0 and 255.""" + with pytest.raises(ValueError, match="RGB values must be integers between 0 and 255"): + Color("red", (-1, 0, 0)) # negative value + with pytest.raises(ValueError, match="RGB values must be integers between 0 and 255"): + Color("red", (256, 0, 0)) # value too high + with pytest.raises(ValueError, match="RGB values must be integers between 0 and 255"): + Color("red", (255.5, 0, 0)) # float instead of int + with pytest.raises(ValueError, match="RGB values must be integers between 0 and 255"): + Color("red", ("255", 0, 0)) # string instead of int + + # ==================== Parameterized Tests ==================== + + @pytest.mark.parametrize("name,rgb,expected_hex,expected_norm", [ + ("red", (128, 0, 0), "#800000", (128/255.0, 0.0, 0.0)), + ("green", (0, 128, 0), "#008000", (0.0, 128/255.0, 0.0)), + ("blue", (0, 0, 128), "#000080", (0.0, 0.0, 128/255.0)), + ("black", (64, 64, 64), "#404040", (64/255.0, 64/255.0, 64/255.0)), + ]) + def test_color_conversions(self, name, rgb, expected_hex, expected_norm): + """Test various color conversion scenarios.""" + color = Color(name, rgb) + + assert color.to_hex() == expected_hex + assert color.to_normalized_rgb() == pytest.approx(expected_norm) + + # ==================== Predefined Colors Tests ==================== + + def test_predefined_colors_are_singletons(self): + """Test that predefined colors behave like singletons.""" + red1 = Color.RED + red2 = Color.RED + green = Color.GREEN + black = Color.BLACK + + assert red1 is red2 # They should be the same instance + assert red1 is not green + assert red1 is not black diff --git a/sketchgetdp/svg_to_getdp/tests/core/entities/test_outline.py b/sketchgetdp/svg_to_getdp/tests/core/entities/test_outline.py new file mode 100644 index 0000000..ad727c4 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/entities/test_outline.py @@ -0,0 +1,369 @@ +""" +Unit tests for Outline class. + +Tests outline functionality including creation, evaluation, +derivative calculation, corner handling, and geometric properties. +""" +import pytest +from core.entities.point import Point +from core.entities.bezier_segment import BezierSegment +from core.entities.color import Color +from svg_to_getdp.core.entities.outline import Outline + + +class TestOutline: + """Test suite for Outline class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def sample_bezier_segments(self): + """Create sample Bézier segments for testing.""" + # Create three connected quadratic Bézier segments + p0, p1, p2 = Point(0.0, 0.0), Point(0.5, 1.0), Point(1.0, 0.0) + p3, p4 = Point(1.5, 1.0), Point(2.0, 0.0) + + segment1 = BezierSegment([p0, p1, p2], degree=2) + segment2 = BezierSegment([p2, p3, p4], degree=2) # p2 is shared + + return [segment1, segment2] + + @pytest.fixture + def single_line_segment(self): + """Create a single straight line segment.""" + p0, p1 = Point(0.0, 0.0), Point(1.0, 0.0) + return [BezierSegment([p0, p1], degree=1)] + + @pytest.fixture + def discontinuous_segments(self): + """Create discontinuous Bézier segments.""" + p0, p1, p2 = Point(0.0, 0.0), Point(0.5, 1.0), Point(1.0, 0.0) + p3, p4, p5 = Point(1.1, 1.0), Point(1.6, 1.0), Point(2.0, 0.0) # p3 doesn't match p2 + + segment1 = BezierSegment([p0, p1, p2], degree=2) + segment2 = BezierSegment([p3, p4, p5], degree=2) + + return [segment1, segment2] + + @pytest.fixture + def outline_with_corners(self, sample_bezier_segments): + """Create an outline with corners.""" + corners = [Point(1.0, 0.0)] # p2 is a corner + return Outline( + bezier_segments=sample_bezier_segments, + corners=corners, + color=Color.BLUE + ) + + # ==================== Helper Methods ==================== + + def create_sample_bezier_segments(self): + """Helper to create sample Bézier segments for testing.""" + # Create three connected quadratic Bézier segments + p0, p1, p2 = Point(0.0, 0.0), Point(0.5, 1.0), Point(1.0, 0.0) + p3, p4 = Point(1.5, 1.0), Point(2.0, 0.0) + + segment1 = BezierSegment([p0, p1, p2], degree=2) + segment2 = BezierSegment([p2, p3, p4], degree=2) # p2 is shared + + return [segment1, segment2] + + # ==================== Initialization Tests ==================== + + def test_outline_creation(self): + """Test basic creation of Outline.""" + segments = self.create_sample_bezier_segments() + corners = [Point(1.0, 0.0)] + color = Color.BLACK + + outline = Outline( + bezier_segments=segments, + corners=corners, + color=color + ) + + assert len(outline.bezier_segments) == 2 + assert len(outline.corners) == 1 + assert outline.color == color + assert outline.is_closed == True + + def test_outline_creation_open(self): + """Test creation of open Outline.""" + segments = self.create_sample_bezier_segments() + outline = Outline( + bezier_segments=segments, + corners=[], + color=Color.BLUE, + is_closed=False + ) + + assert outline.is_closed == False + + def test_empty_segments_raises_error(self): + """Test that empty segments list raises error.""" + with pytest.raises(ValueError, match="Outline must have at least one Bézier segment"): + Outline( + bezier_segments=[], + corners=[], + color=Color.BLUE + ) + + def test_discontinuous_segments_raises_error(self): + """Test that discontinuous segments raises warning.""" + p0, p1, p2 = Point(0.0, 0.0), Point(0.5, 1.0), Point(1.0, 0.0) + p3, p4, p5 = Point(1.1, 1.0), Point(1.6, 1.0), Point(2.0, 0.0) # p3 doesn't match p2 + + segment1 = BezierSegment([p0, p1, p2], degree=2) + segment2 = BezierSegment([p3, p4, p5], degree=2) # Not connected to segment1 + + outline = Outline( + bezier_segments=[segment1, segment2], + corners=[], + color=Color.GREEN + ) + + # Verify it creates the outline (only warning printed) + assert len(outline) == 2 + + # ==================== Property Tests ==================== + + def test_control_points_property(self): + """Test control_points property aggregates all segment control points.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + control_points = outline.control_points + unique_control_points = outline.unique_control_points + + # control_points includes duplicates at interfaces + assert len(control_points) == 6 # 3 from seg1 + 3 from seg2 (including duplicate interface) + + # unique_control_points removes duplicates + assert len(unique_control_points) == 5 # 3 from seg1 + 2 from seg2 (excluding duplicate interface) + + # Verify the points are in the correct order + assert control_points[0] == segments[0].control_points[0] # p0 + assert control_points[1] == segments[0].control_points[1] # p1 + assert control_points[2] == segments[0].control_points[2] # p2 (interface) + assert control_points[3] == segments[1].control_points[0] # p2 (interface - duplicate) + assert control_points[4] == segments[1].control_points[1] # p3 + assert control_points[5] == segments[1].control_points[2] # p4 + + # ==================== Evaluation Tests ==================== + + def test_evaluate_single_segment(self): + """Test evaluation with single Bézier segment.""" + p0, p1 = Point(0.0, 0.0), Point(1.0, 1.0) + segment = BezierSegment([p0, p1], degree=1) + outline = Outline([segment], corners=[], color=Color.BLUE) + + # Test start, middle, end + assert outline.evaluate(0.0).x == pytest.approx(p0.x) + assert outline.evaluate(0.0).y == pytest.approx(p0.y) + assert outline.evaluate(0.5).x == pytest.approx(0.5) + assert outline.evaluate(0.5).y == pytest.approx(0.5) + assert outline.evaluate(1.0).x == pytest.approx(p1.x) + assert outline.evaluate(1.0).y == pytest.approx(p1.y) + + def test_evaluate_multiple_segments(self): + """Test evaluation with multiple Bézier segments.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + # Test segment interfaces + assert outline.evaluate(0.0).x == pytest.approx(segments[0].start_point.x) + assert outline.evaluate(0.0).y == pytest.approx(segments[0].start_point.y) + assert outline.evaluate(0.5).x == pytest.approx(segments[1].start_point.x) + assert outline.evaluate(0.5).y == pytest.approx(segments[1].start_point.y) + assert outline.evaluate(1.0).x == pytest.approx(segments[1].end_point.x) + assert outline.evaluate(1.0).y == pytest.approx(segments[1].end_point.y) + + # Test within first segment + point1 = outline.evaluate(0.25) + expected1 = segments[0].evaluate(0.5) # t=0.25 global = t=0.5 local in first segment + assert point1.x == pytest.approx(expected1.x) + assert point1.y == pytest.approx(expected1.y) + + # Test within second segment + point2 = outline.evaluate(0.75) + expected2 = segments[1].evaluate(0.5) # t=0.75 global = t=0.5 local in second segment + assert point2.x == pytest.approx(expected2.x) + assert point2.y == pytest.approx(expected2.y) + + def test_evaluate_parameter_range(self): + """Test that evaluation only works for t in [0,1].""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0,1\\]"): + outline.evaluate(-0.1) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0,1\\]"): + outline.evaluate(1.1) + + # ==================== Derivative Tests ==================== + + def test_derivative_single_segment(self): + """Test derivative calculation with single segment.""" + p0, p1 = Point(0.0, 0.0), Point(2.0, 2.0) + segment = BezierSegment([p0, p1], degree=1) + outline = Outline([segment], corners=[], color=Color.BLUE) + + # Derivative should be scaled by number of segments (1 in this case) + derivative = outline.derivative(0.5) + expected = Point(2.0, 2.0) # Same as segment derivative since N_C=1 + + assert derivative.x == pytest.approx(expected.x) + assert derivative.y == pytest.approx(expected.y) + + def test_derivative_multiple_segments(self): + """Test derivative calculation with multiple segments.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + # Test derivative in first segment (should be scaled by N_C=2) + derivative1 = outline.derivative(0.25) + segment_deriv1 = segments[0].derivative(0.5) # Local t=0.5 for global t=0.25 + expected1 = Point(segment_deriv1.x * 2, segment_deriv1.y * 2) + assert derivative1.x == pytest.approx(expected1.x) + assert derivative1.y == pytest.approx(expected1.y) + + # Test derivative in second segment + derivative2 = outline.derivative(0.75) + segment_deriv2 = segments[1].derivative(0.5) # Local t=0.5 for global t=0.75 + expected2 = Point(segment_deriv2.x * 2, segment_deriv2.y * 2) + assert derivative2.x == pytest.approx(expected2.x) + assert derivative2.y == pytest.approx(expected2.y) + + def test_derivative_parameter_range(self): + """Test that derivative only works for t in [0,1].""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0,1\\]"): + outline.derivative(-0.1) + + with pytest.raises(ValueError, match="Parameter t must be in \\[0,1\\]"): + outline.derivative(1.1) + + # ==================== Corner Handling Tests ==================== + + def test_is_corner_at_parameter(self): + """Test corner detection at parameter values.""" + segments = self.create_sample_bezier_segments() + corner_point = segments[0].end_point # p2 + outline = Outline(segments, corners=[corner_point], color=Color.BLUE) + + # Should detect corner at t=0.5 (interface between segments) + assert outline.is_corner_at_parameter(0.5) == True + + # Should not detect corner at other parameters + assert outline.is_corner_at_parameter(0.0) == False + assert outline.is_corner_at_parameter(0.25) == False + assert outline.is_corner_at_parameter(0.75) == False + assert outline.is_corner_at_parameter(1.0) == False + + def test_is_corner_at_segment_interface(self): + """Test corner detection at segment interfaces.""" + segments = self.create_sample_bezier_segments() + corner_point = segments[0].end_point # p2 + outline = Outline(segments, corners=[corner_point], color=Color.BLUE) + + # Interface 0 (between segment 0 and 1) should be a corner + assert outline.is_corner_at_segment_interface(0) == True + + # Test invalid interface indices + with pytest.raises(ValueError, match="Invalid segment index for interface check"): + outline.is_corner_at_segment_interface(-1) + + with pytest.raises(ValueError, match="Invalid segment index for interface check"): + outline.is_corner_at_segment_interface(1) # Only interfaces 0 to N-2 + + # ==================== Geometric Property Tests ==================== + + def test_get_segment_at_parameter(self): + """Test getting segment and local parameter for global parameter.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + # Test first segment + segment1, local_t1 = outline.get_segment_at_parameter(0.25) + assert segment1 == segments[0] + assert local_t1 == pytest.approx(0.5) + + # Test second segment + segment2, local_t2 = outline.get_segment_at_parameter(0.75) + assert segment2 == segments[1] + assert local_t2 == pytest.approx(0.5) + + # Test start and end + segment_start, local_t_start = outline.get_segment_at_parameter(0.0) + assert segment_start == segments[0] + assert local_t_start == pytest.approx(0.0) + + segment_end, local_t_end = outline.get_segment_at_parameter(1.0) + assert segment_end == segments[1] + assert local_t_end == pytest.approx(1.0) + + def test_get_outline_points(self): + """Test sampling points along entire outline.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + points = outline.get_outline_points(num_points=5) + + assert len(points) == 5 + assert points[0].x == pytest.approx(segments[0].start_point.x) + assert points[0].y == pytest.approx(segments[0].start_point.y) + assert points[2].x == pytest.approx(segments[0].end_point.x) + assert points[2].y == pytest.approx(segments[0].end_point.y) + assert points[4].x == pytest.approx(segments[1].end_point.x) + assert points[4].y == pytest.approx(segments[1].end_point.y) + + def test_get_outline_points_invalid_count(self): + """Test that invalid point count raises error.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + with pytest.raises(ValueError, match="Number of points must be at least 2"): + outline.get_outline_points(num_points=1) + + def test_get_outline_length_approximation(self): + """Test outline length approximation.""" + p0, p1 = Point(0.0, 0.0), Point(1.0, 0.0) + segment = BezierSegment([p0, p1], degree=1) + outline = Outline([segment], corners=[], color=Color.BLUE) + + length = outline.get_outline_length_approximation(num_samples=10) + + # Straight line from (0,0) to (1,0) should have length 1.0 + assert length == pytest.approx(1.0, rel=1e-2) + + # ==================== Interface Tests ==================== + + def test_len_operator(self): + """Test len() operator returns number of segments.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + assert len(outline) == 2 + + def test_iteration(self): + """Test iteration over Bézier segments.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[], color=Color.BLUE) + + segment_list = list(outline) + assert segment_list == segments + + def test_repr(self): + """Test string representation.""" + segments = self.create_sample_bezier_segments() + outline = Outline(segments, corners=[Point(1.0, 0.0)], color=Color.GREEN) + + repr_str = repr(outline) + assert "Outline" in repr_str + assert "segments=2" in repr_str + assert "corners=1" in repr_str + assert "color=green" in repr_str + assert "closed=True" in repr_str diff --git a/sketchgetdp/svg_to_getdp/tests/core/entities/test_physical_group.py b/sketchgetdp/svg_to_getdp/tests/core/entities/test_physical_group.py new file mode 100644 index 0000000..5eeaab8 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/entities/test_physical_group.py @@ -0,0 +1,340 @@ +""" +Unit tests for PhysicalGroup entity class. + +Tests the creation and validation of physical groups used in the FEM mesh generation, +including domains, boundaries, and coil regions. +""" +import pytest + +from svg_to_getdp.core.entities.physical_group import ( + PhysicalGroup, + DOMAIN_VI_IRON, + DOMAIN_VI_AIR, + DOMAIN_VA, + DOMAIN_COIL_POSITIVE, + DOMAIN_COIL_NEGATIVE, + BOUNDARY_GAMMA, + BOUNDARY_OUT +) +from svg_to_getdp.core.entities.color import Color + + +class TestPhysicalGroup: + """Test suite for PhysicalGroup entity class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def valid_domain(self): + """Create a valid domain physical group.""" + return PhysicalGroup( + name="test_domain", + description="Test domain description", + group_type="domain", + value=100 + ) + + @pytest.fixture + def valid_boundary(self): + """Create a valid boundary physical group.""" + return PhysicalGroup( + name="test_boundary", + description="Test boundary description", + group_type="boundary", + value=200, + color=Color.BLUE + ) + + @pytest.fixture + def valid_coil_positive(self): + """Create a valid positive coil domain.""" + return PhysicalGroup( + name="coil_positive", + description="Positive coil domain", + group_type="domain", + value=101, + color=Color.RED, + current_sign=1 + ) + + @pytest.fixture + def valid_coil_negative(self): + """Create a valid negative coil domain.""" + return PhysicalGroup( + name="domain_coil_negative", + description="Negative coil domain", + group_type="domain", + value=102, + color=Color.RED, + current_sign=-1 + ) + + # ==================== Basic Creation Tests ==================== + + def test_valid_domain_creation(self, valid_domain): + """Test creating a valid domain physical group.""" + pg = valid_domain + + assert pg.name == "test_domain" + assert pg.description == "Test domain description" + assert pg.group_type == "domain" + assert pg.value == 100 + assert pg.color is None + assert pg.current_sign is None + assert pg.is_domain() is True + assert pg.is_boundary() is False + assert pg.has_color() is False + assert pg.is_coil() is False + + def test_valid_boundary_creation(self, valid_boundary): + """Test creating a valid boundary physical group.""" + pg = valid_boundary + + assert pg.name == "test_boundary" + assert pg.description == "Test boundary description" + assert pg.group_type == "boundary" + assert pg.value == 200 + assert pg.color == Color.BLUE + assert pg.current_sign is None + assert pg.is_boundary() is True + assert pg.is_domain() is False + assert pg.has_color() is True + assert pg.is_coil() is False + + def test_valid_coil_creation(self, valid_coil_positive, valid_coil_negative): + """Test creating valid coil domains.""" + # Positive coil + pg_pos = valid_coil_positive + assert pg_pos.name == "coil_positive" + assert pg_pos.group_type == "domain" + assert pg_pos.color == Color.RED + assert pg_pos.current_sign == 1 + assert pg_pos.is_coil() is True + assert pg_pos.is_domain() is True + + # Negative coil + pg_neg = valid_coil_negative + assert pg_neg.name == "domain_coil_negative" + assert pg_neg.color == Color.RED + assert pg_neg.current_sign == -1 + assert pg_neg.is_coil() is True + + # ==================== Validation Tests ==================== + + def test_invalid_name_type(self): + """Test invalid name type.""" + with pytest.raises(TypeError, match="Physical group name must be a string"): + PhysicalGroup( + name=123, + description="Test", + group_type="domain", + value=100 + ) + + def test_invalid_description_type(self): + """Test invalid description type.""" + with pytest.raises(TypeError, match="Physical group description must be a string"): + PhysicalGroup( + name="test", + description=456, + group_type="domain", + value=100 + ) + + def test_invalid_group_type(self): + """Test invalid group type.""" + with pytest.raises(ValueError, match="Group type must be either 'domain' or 'boundary'"): + PhysicalGroup( + name="test", + description="Test", + group_type="invalid_type", + value=100 + ) + + def test_invalid_value_type(self): + """Test invalid value type.""" + with pytest.raises(TypeError, match="Value must be an integer"): + PhysicalGroup( + name="test", + description="Test", + group_type="domain", + value="not_an_int" + ) + + def test_invalid_color_type(self): + """Test invalid color type.""" + with pytest.raises(TypeError, match="Color must be an instance of Color class or None"): + PhysicalGroup( + name="test", + description="Test", + group_type="domain", + value=100, + color="not_a_color" + ) + + def test_invalid_current_sign(self): + """Test invalid current sign value.""" + with pytest.raises(ValueError, match=r"Current sign must be None, 1 \(positive\), or -1 \(negative\)"): + PhysicalGroup( + name="test", + description="Test", + group_type="domain", + value=100, + current_sign=2 + ) + + def test_coil_missing_current_sign(self): + """Test coil domain without current sign.""" + with pytest.raises(ValueError, match="Coil domains must have a current sign"): + PhysicalGroup( + name="coil_test", + description="Coil test", + group_type="domain", + value=100, + color=Color.RED, + current_sign=None + ) + + def test_coil_wrong_color(self): + """Test coil domain with wrong color.""" + with pytest.raises(ValueError, match="Coil domains must be red"): + PhysicalGroup( + name="domain_coil_positive", + description="Coil with wrong color", + group_type="domain", + value=100, + color=Color.BLUE, + current_sign=1 + ) + + def test_non_coil_with_current_sign(self): + """Test non-coil domain with current sign.""" + with pytest.raises(ValueError, match="Only coil domains can have a current sign"): + PhysicalGroup( + name="regular_domain", + description="Regular domain", + group_type="domain", + value=100, + current_sign=1 + ) + + # ==================== Method Tests ==================== + + def test_is_coil_method(self): + """Test the is_coil() method.""" + # Should be True for domains with "coil" in name + coil_pg = PhysicalGroup( + name="some_coil_domain", + description="Coil domain", + group_type="domain", + value=100, + color=Color.RED, + current_sign=1 + ) + assert coil_pg.is_coil() is True + + # Should be False for boundaries even with "coil" in name + coil_boundary = PhysicalGroup( + name="boundary_coil", + description="Coil boundary", + group_type="boundary", + value=200 + ) + assert coil_boundary.is_coil() is False + + # Should be False for domains without "coil" in name + non_coil = PhysicalGroup( + name="regular_domain", + description="Regular", + group_type="domain", + value=300 + ) + assert non_coil.is_coil() is False + + def test_has_color_method(self): + """Test the has_color() method.""" + pg_with_color = PhysicalGroup( + name="test", + description="Test", + group_type="domain", + value=100, + color=Color.BLUE + ) + assert pg_with_color.has_color() is True + + pg_without_color = PhysicalGroup( + name="test", + description="Test", + group_type="domain", + value=100 + ) + assert pg_without_color.has_color() is False + + # ==================== Immutability Tests ==================== + + def test_frozen_dataclass(self, valid_domain): + """Test that PhysicalGroup is immutable (frozen dataclass).""" + pg = valid_domain + + with pytest.raises(Exception): + pg.name = "modified" + + with pytest.raises(Exception): + pg.value = 200 + + # ==================== Module Constants Tests ==================== + + def test_module_constants(self): + """Test the module-level constants.""" + # Test DOMAIN_VI_IRON + assert DOMAIN_VI_IRON.name == "domain_Vi_iron" + assert DOMAIN_VI_IRON.description == "Iron domain in Vi region" + assert DOMAIN_VI_IRON.group_type == "domain" + assert DOMAIN_VI_IRON.value == 2 + assert DOMAIN_VI_IRON.color == Color.BLUE + + # Test DOMAIN_VI_AIR + assert DOMAIN_VI_AIR.name == "domain_Vi_air" + assert DOMAIN_VI_AIR.description == "Air domain in Vi region" + assert DOMAIN_VI_AIR.group_type == "domain" + assert DOMAIN_VI_AIR.value == 3 + assert DOMAIN_VI_AIR.color == Color.GREEN + + # Test DOMAIN_VA + assert DOMAIN_VA.name == "domain_Va" + assert DOMAIN_VA.description == "Va domain" + assert DOMAIN_VA.group_type == "domain" + assert DOMAIN_VA.value == 1 + assert DOMAIN_VA.color == Color.BLACK + + # Test DOMAIN_COIL_POSITIVE + assert DOMAIN_COIL_POSITIVE.name == "domain_coil_positive" + assert DOMAIN_COIL_POSITIVE.description == "Coil domain with positive current" + assert DOMAIN_COIL_POSITIVE.group_type == "domain" + assert DOMAIN_COIL_POSITIVE.value == 101 + assert DOMAIN_COIL_POSITIVE.color == Color.RED + assert DOMAIN_COIL_POSITIVE.current_sign == 1 + assert DOMAIN_COIL_POSITIVE.is_coil() is True + + # Test DOMAIN_COIL_NEGATIVE + assert DOMAIN_COIL_NEGATIVE.name == "domain_coil_negative" + assert DOMAIN_COIL_NEGATIVE.description == "Coil domain with negative current" + assert DOMAIN_COIL_NEGATIVE.group_type == "domain" + assert DOMAIN_COIL_NEGATIVE.value == 102 + assert DOMAIN_COIL_NEGATIVE.color == Color.RED + assert DOMAIN_COIL_NEGATIVE.current_sign == -1 + assert DOMAIN_COIL_NEGATIVE.is_coil() is True + + # Test BOUNDARY_GAMMA + assert BOUNDARY_GAMMA.name == "boundary_gamma" + assert BOUNDARY_GAMMA.description == "Interface boundary between Vi and Va regions" + assert BOUNDARY_GAMMA.group_type == "boundary" + assert BOUNDARY_GAMMA.value == 11 + assert BOUNDARY_GAMMA.is_boundary() is True + + # Test BOUNDARY_OUT + assert BOUNDARY_OUT.name == "boundary_out" + assert BOUNDARY_OUT.description == "Outermost boundary" + assert BOUNDARY_OUT.group_type == "boundary" + assert BOUNDARY_OUT.value == 12 + assert BOUNDARY_OUT.is_boundary() is True \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/core/entities/test_point.py b/sketchgetdp/svg_to_getdp/tests/core/entities/test_point.py new file mode 100644 index 0000000..d7a2136 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/entities/test_point.py @@ -0,0 +1,264 @@ +""" +Unit tests for Point class. + +Tests the minimal Point entity with Euclidean distance and vector operations. +""" +import pytest +import math + +from svg_to_getdp.core.entities.point import Point + + +class TestPoint: + """Test suite for Point class.""" + + # ==================== Basic Functionality Tests ==================== + + def test_point_creation(self): + """Test that a point can be created with coordinates.""" + point = Point(3, 4) + assert point.x == 3 + assert point.y == 4 + + def test_point_default_origin(self): + """Test that point defaults to origin (0,0).""" + point = Point() + assert point.x == 0 + assert point.y == 0 + + def test_point_immutability(self): + """Test that Point is immutable.""" + point = Point(1, 2) + + with pytest.raises(AttributeError): + point.x = 5 + with pytest.raises(AttributeError): + point.y = 5 + + def test_point_equality(self): + """Test that points with same coordinates are equal.""" + point1 = Point(3, 4) + point2 = Point(3, 4) + point3 = Point(3, 5) + + assert point1 == point2 + assert point1 != point3 + + def test_point_hash(self): + """Test that points are hashable.""" + point1 = Point(1, 2) + point2 = Point(1, 2) + point3 = Point(3, 4) + + point_set = {point1, point2, point3} + assert len(point_set) == 2 # point1 and point2 are duplicates + assert point1 in point_set + assert point2 in point_set + assert point3 in point_set + + def test_point_repr(self): + """Test the string representation of Point.""" + point = Point(5, 6) + repr_str = repr(point) + + assert "Point" in repr_str + assert "5" in repr_str + assert "6" in repr_str + + def test_point_str(self): + """Test the human-readable string representation.""" + point = Point(7, 8) + str_repr = str(point) + + assert "7" in str_repr + assert "8" in str_repr + + def test_distance_to_origin(self): + """Test distance calculation from origin.""" + point = Point(3, 4) + distance = point.distance_to_origin() + + assert distance == 5.0 # 3-4-5 triangle + + def test_distance_to_other_point(self): + """Test distance calculation between two points.""" + point1 = Point(1, 1) + point2 = Point(4, 5) + + distance = point1.distance_to(point2) + expected_distance = math.sqrt((4-1)**2 + (5-1)**2) + + assert distance == pytest.approx(expected_distance) + + def test_distance_to_same_point(self): + """Test distance from point to itself.""" + point = Point(3, 4) + distance = point.distance_to(point) + + assert distance == 0.0 + + def test_invalid_coordinate_types(self): + """Test that point rejects non-numeric coordinates.""" + with pytest.raises(TypeError): + Point("1", 2) + with pytest.raises(TypeError): + Point(1, "2") + with pytest.raises(TypeError): + Point(None, 2) + with pytest.raises(TypeError): + Point(1, [2]) + + def test_nan_coordinates(self): + """Test that point rejects NaN values.""" + with pytest.raises(ValueError): + Point(float('nan'), 1) + with pytest.raises(ValueError): + Point(1, float('nan')) + + def test_integer_coordinates(self): + """Test that point accepts integer coordinates.""" + point = Point(1, 2) # integers + assert point.x == 1 + assert point.y == 2 + + # Should work with float operations + distance = point.distance_to_origin() + assert isinstance(distance, float) + + def test_float_coordinates(self): + """Test that point accepts float coordinates.""" + point = Point(1.5, 2.5) + assert point.x == 1.5 + assert point.y == 2.5 + + # ==================== Vector Operation Tests ==================== + + def test_vector_addition(self): + """Test vector addition of two points.""" + point1 = Point(1, 2) + point2 = Point(3, 4) + result = point1 + point2 + + assert result == Point(4, 6) + assert isinstance(result, Point) + + def test_vector_subtraction(self): + """Test vector subtraction of two points.""" + point1 = Point(5, 6) + point2 = Point(2, 3) + result = point1 - point2 + + assert result == Point(3, 3) + assert isinstance(result, Point) + + def test_scalar_multiplication(self): + """Test scalar multiplication.""" + point = Point(2, 3) + result = point * 2.5 + + assert result == Point(5.0, 7.5) + assert isinstance(result, Point) + + def test_reverse_scalar_multiplication(self): + """Test reverse scalar multiplication.""" + point = Point(2, 3) + result = 2.5 * point + + assert result == Point(5.0, 7.5) + assert isinstance(result, Point) + + def test_scalar_division(self): + """Test scalar division.""" + point = Point(6, 9) + result = point / 3 + + assert result == Point(2.0, 3.0) + assert isinstance(result, Point) + + def test_scalar_division_by_zero(self): + """Test that scalar division by zero raises ValueError.""" + point = Point(1, 2) + + with pytest.raises(ValueError, match="Division by zero"): + point / 0 + + def test_norm_calculation(self): + """Test Euclidean norm calculation.""" + point = Point(3, 4) + norm = point.norm() + + assert norm == 5.0 + assert isinstance(norm, float) + + def test_norm_zero(self): + """Test norm calculation for zero vector.""" + point = Point(0, 0) + norm = point.norm() + + assert norm == 0.0 + + def test_equality_with_floating_point_precision(self): + """Test equality comparison with floating point precision.""" + point1 = Point(1.0, 2.0) + point2 = Point(1.0 + 1e-10, 2.0 - 1e-10) # Very close values + + assert point1 == point2 # Should be equal due to math.isclose + + def test_equality_with_different_types(self): + """Test equality comparison with non-Point types.""" + point = Point(1, 2) + + assert point != (1, 2) + assert point != [1, 2] + assert point != "Point(1, 2)" + assert point != 1 + + def test_vector_operations_chain(self): + """Test chaining of vector operations.""" + point1 = Point(1, 2) + point2 = Point(3, 4) + point3 = Point(5, 6) + + result = point1 + point2 - point3 + expected = Point(-1, 0) + + assert result == expected + + def test_mixed_operations(self): + """Test mixed scalar and vector operations.""" + point1 = Point(1, 2) + point2 = Point(3, 4) + + result = 2 * point1 + point2 / 2 + expected = Point(3.5, 6.0) # 2*(1,2) + (3,4)/2 = (2,4) + (1.5,2) = (3.5,6) + + assert result == expected + + def test_commutative_property(self): + """Test commutative property of addition.""" + point1 = Point(1, 2) + point2 = Point(3, 4) + + assert point1 + point2 == point2 + point1 + + def test_associative_property_addition(self): + """Test associative property of addition.""" + point1 = Point(1, 2) + point2 = Point(3, 4) + point3 = Point(5, 6) + + result1 = (point1 + point2) + point3 + result2 = point1 + (point2 + point3) + + assert result1 == result2 + + def test_distributive_property(self): + """Test distributive property of scalar multiplication over addition.""" + point1 = Point(1, 2) + point2 = Point(3, 4) + scalar = 2 + + result1 = scalar * (point1 + point2) + result2 = (scalar * point1) + (scalar * point2) + + assert result1 == result2 \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_convert_geometry_to_gmsh.py b/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_convert_geometry_to_gmsh.py new file mode 100644 index 0000000..2d46177 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_convert_geometry_to_gmsh.py @@ -0,0 +1,494 @@ +""" +Unit tests for ConvertGeometryToGmsh use case. + +Tests geometry to Gmsh conversion functionality with various outlines, +wire configurations, and edge cases. +""" + +import os +import tempfile +import time +from unittest.mock import Mock, patch +import pytest +import yaml + +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.physical_group import ( + DOMAIN_VI_IRON, + DOMAIN_VI_AIR, + BOUNDARY_OUT, + DOMAIN_COIL_POSITIVE, + DOMAIN_COIL_NEGATIVE, +) +from svg_to_getdp.core.use_cases.convert_geometry_to_gmsh import ConvertGeometryToGmsh + + +class TestConvertGeometryToGmsh: + """Test suite for ConvertGeometryToGmsh class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def converter(self): + """Create a ConvertGeometryToGmsh instance for testing.""" + return ConvertGeometryToGmsh() + + @pytest.fixture + def temporary_configuration_file(self): + """Create a temporary configuration file for testing.""" + configuration = { + "wire_currents": {"wire_1": 1, "wire_2": -1}, + "mesh_size": 0.1, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as file: + yaml.dump(configuration, file) + config_path = file.name + + yield config_path + + if os.path.exists(config_path): + os.unlink(config_path) + + @pytest.fixture + def sample_outlines(self): + """Create sample outlines for testing.""" + outer_outline = Outline( + bezier_segments=[ + BezierSegment( + [Point(0.0, 0.0), Point(0.5, 0.0), Point(1.0, 0.0)], degree=2 + ), + BezierSegment( + [Point(1.0, 0.0), Point(1.0, 1.0), Point(0.0, 1.0)], degree=2 + ), + BezierSegment( + [Point(0.0, 1.0), Point(0.0, 0.0), Point(0.0, 0.0)], degree=2 + ), + ], + corners=[ + Point(0.0, 0.0), + Point(1.0, 0.0), + Point(1.0, 1.0), + Point(0.0, 1.0), + ], + color=Color.BLUE, + is_closed=True, + ) + + inner_outline = Outline( + bezier_segments=[ + BezierSegment( + [Point(0.2, 0.2), Point(0.5, 0.2), Point(0.8, 0.2)], degree=2 + ), + BezierSegment( + [Point(0.8, 0.2), Point(0.8, 0.8), Point(0.2, 0.8)], degree=2 + ), + BezierSegment( + [Point(0.2, 0.8), Point(0.2, 0.2), Point(0.2, 0.2)], degree=2 + ), + ], + corners=[ + Point(0.2, 0.2), + Point(0.8, 0.2), + Point(0.8, 0.8), + Point(0.2, 0.8), + ], + color=Color.GREEN, + is_closed=True, + ) + + return [outer_outline, inner_outline] + + @pytest.fixture + def sample_wires(self): + """Create sample wire points for testing.""" + return [ + (Point(0.3, 0.3), Color.RED), + (Point(0.7, 0.7), Color.RED), + ] + + @pytest.fixture + def gmsh_mocks(self): + """Mock all Gmsh toolbox functions.""" + with patch( + "svg_to_getdp.core.use_cases.convert_geometry_to_gmsh.initialize_gmsh" + ) as mock_init, patch( + "svg_to_getdp.core.use_cases.convert_geometry_to_gmsh.set_characteristic_mesh_length" + ) as mock_set_mesh, patch( + "svg_to_getdp.core.use_cases.convert_geometry_to_gmsh.mesh_and_save" + ) as mock_mesh_save, patch( + "svg_to_getdp.core.use_cases.convert_geometry_to_gmsh.show_model" + ) as mock_show, patch( + "svg_to_getdp.core.use_cases.convert_geometry_to_gmsh.finalize_gmsh" + ) as mock_finalize: + + mock_factory = Mock() + mock_factory.synchronize = Mock() + mock_init.return_value = mock_factory + + yield { + "initialize_gmsh": mock_init, + "set_characteristic_mesh_length": mock_set_mesh, + "mesh_and_save": mock_mesh_save, + "show_model": mock_show, + "finalize_gmsh": mock_finalize, + "factory": mock_factory, + } + + @pytest.fixture + def many_outlines(self): + """Create many outlines for performance testing.""" + many_outlines = [] + for i in range(10): + bezier_segments = [ + BezierSegment( + [Point(i, i), Point(i + 1, i), Point(i + 1, i + 1)], degree=2 + ), + BezierSegment( + [Point(i + 1, i + 1), Point(i, i + 1), Point(i, i)], degree=2 + ), + ] + outline = Outline( + bezier_segments=bezier_segments, + corners=[ + Point(i, i), + Point(i + 1, i), + Point(i + 1, i + 1), + Point(i, i + 1), + ], + color=Color.BLUE, + is_closed=True, + ) + many_outlines.append(outline) + return many_outlines + + # ==================== Initialization Tests ==================== + + def test_initializes_without_dependencies(self, converter): + """Test that converter initializes without parameters.""" + assert converter is not None + assert hasattr(converter, 'outline_grouper') + assert hasattr(converter, 'outline_preprocessor') + assert hasattr(converter, 'wire_preprocessor') + + # ==================== Basic Functionality Tests ==================== + + def test_executes_successfully( + self, + converter, + sample_outlines, + sample_wires, + temporary_configuration_file, + gmsh_mocks, + ): + """Test successful execution of the geometry to Gmsh conversion.""" + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires, patch.object( + converter.outline_grouper, "group_outlines" + ) as mock_group_outlines, patch.object( + converter.outline_preprocessor, "preprocess_outlines" + ) as mock_preprocess_outlines: + + wire_results = { + 0: { + "original_index": 0, + "point": Point(0.3, 0.3), + "color": Color.RED, + "gmsh_point_tag": 1, + "physical_group": DOMAIN_COIL_POSITIVE, + "wire_name": "wire_1", + }, + 1: { + "original_index": 1, + "point": Point(0.7, 0.7), + "color": Color.RED, + "gmsh_point_tag": 2, + "physical_group": DOMAIN_COIL_NEGATIVE, + "wire_name": "wire_2", + }, + } + mock_prepare_wires.return_value = wire_results + + grouping_result = [ + { + "holes": [1], + "physical_groups": [DOMAIN_VI_IRON, BOUNDARY_OUT], + }, + {"holes": [], "physical_groups": [DOMAIN_VI_AIR]}, + ] + mock_group_outlines.return_value = grouping_result + + result = converter.execute( + outlines=sample_outlines, + wires=sample_wires, + config_file_path=temporary_configuration_file, + model_name="test_model", + output_filename="test_mesh", + dimension=2, + show_gui=False, + ) + + # Verify Gmsh initialization + gmsh_mocks["initialize_gmsh"].assert_called_once_with("test_model") + gmsh_mocks["set_characteristic_mesh_length"].assert_called_once_with(0.1) + + # Verify dependencies are called correctly + mock_prepare_wires.assert_called_once_with( + gmsh_mocks["factory"], + temporary_configuration_file, + sample_wires, + ) + mock_group_outlines.assert_called_once_with(sample_outlines) + mock_preprocess_outlines.assert_called_once_with( + gmsh_mocks["factory"], sample_outlines, grouping_result + ) + + # Verify Gmsh operations + gmsh_mocks["factory"].synchronize.assert_called_once() + gmsh_mocks["mesh_and_save"].assert_called_once_with("test_mesh", 2) + gmsh_mocks["show_model"].assert_not_called() + gmsh_mocks["finalize_gmsh"].assert_called_once() + + # Verify result structure + assert result["model_name"] == "test_model" + assert result["output_filename"] == "test_mesh" + assert result["mesh_size"] == 0.1 + assert result["wire_results"] == wire_results + assert result["geometry_synchronized"] is True + assert result["mesh_generated"] is True + assert "gui_shown" not in result + + def test_executes_with_gui( + self, + converter, + sample_outlines, + sample_wires, + temporary_configuration_file, + gmsh_mocks, + ): + """Test execution with GUI display enabled.""" + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires, patch.object( + converter.outline_grouper, "group_outlines" + ) as mock_group_outlines, patch.object( + converter.outline_preprocessor, "preprocess_outlines" + ) as mock_preprocess_outlines: + + mock_prepare_wires.return_value = {} + mock_group_outlines.return_value = [] + + result = converter.execute( + outlines=sample_outlines, + wires=sample_wires, + config_file_path=temporary_configuration_file, + model_name="test_model", + output_filename="test_mesh", + dimension=2, + show_gui=True, + ) + + gmsh_mocks["show_model"].assert_called_once() + assert result["gui_shown"] is True + + # ==================== Edge Case Tests ==================== + + def test_warns_when_no_outlines_provided( + self, converter, sample_wires, temporary_configuration_file, gmsh_mocks + ): + """Test warning when no outlines are provided.""" + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires, patch.object( + converter.outline_grouper, "group_outlines" + ) as mock_group_outlines, patch.object( + converter.outline_preprocessor, "preprocess_outlines" + ) as mock_preprocess_outlines, patch( + "builtins.print" + ) as mock_print: + + mock_prepare_wires.return_value = {} + mock_group_outlines.return_value = [] + + converter.execute( + outlines=[], + wires=sample_wires, + config_file_path=temporary_configuration_file, + show_gui=False, + ) + + mock_print.assert_any_call("Warning: No outlines provided") + + def test_handles_different_mesh_sizes( + self, converter, sample_outlines, sample_wires, gmsh_mocks + ): + """Test handling of different mesh sizes from configuration.""" + configuration = { + "wire_currents": {"wire_1": 1, "wire_2": -1}, + "mesh_size": 0.05, + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as file: + yaml.dump(configuration, file) + config_path = file.name + + try: + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires, patch.object( + converter.outline_grouper, "group_outlines" + ) as mock_group_outlines, patch.object( + converter.outline_preprocessor, "preprocess_outlines" + ) as mock_preprocess_outlines: + + mock_prepare_wires.return_value = {} + mock_group_outlines.return_value = [] + + result = converter.execute( + outlines=sample_outlines, + wires=sample_wires, + config_file_path=config_path, + show_gui=False, + ) + + gmsh_mocks["set_characteristic_mesh_length"].assert_called_once_with( + 0.05 + ) + assert result["mesh_size"] == 0.05 + finally: + if os.path.exists(config_path): + os.unlink(config_path) + + # ==================== Error Handling Tests ==================== + + def test_rejects_invalid_outline_type( + self, converter, sample_wires, temporary_configuration_file + ): + """Test rejection of invalid outlines type.""" + with pytest.raises(ValueError, match="outlines must be a list"): + converter.execute( + outlines="not a list", + wires=sample_wires, + config_file_path=temporary_configuration_file, + ) + + def test_rejects_invalid_wires_type( + self, converter, sample_outlines, temporary_configuration_file + ): + """Test rejection of invalid wires type.""" + with pytest.raises(ValueError, match="wires must be a list"): + converter.execute( + outlines=sample_outlines, + wires="not a list", + config_file_path=temporary_configuration_file, + ) + + def test_rejects_nonexistent_configuration_file( + self, converter, sample_outlines, sample_wires + ): + """Test rejection of nonexistent configuration file.""" + nonexistent_config = "/path/to/nonexistent/config.yaml" + + with pytest.raises( + FileNotFoundError, + match=f"Configuration file not found: {nonexistent_config}", + ): + converter.execute( + outlines=sample_outlines, + wires=sample_wires, + config_file_path=nonexistent_config, + ) + + def test_handles_exceptions_gracefully( + self, + converter, + sample_outlines, + sample_wires, + temporary_configuration_file, + gmsh_mocks, + ): + """Test graceful handling of exceptions during execution.""" + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires: + mock_prepare_wires.side_effect = RuntimeError("Test error") + + with pytest.raises(RuntimeError, match="Test error"): + converter.execute( + outlines=sample_outlines, + wires=sample_wires, + config_file_path=temporary_configuration_file, + show_gui=False, + ) + + gmsh_mocks["finalize_gmsh"].assert_called_once() + + # ==================== Integration Tests ==================== + + def test_produces_consistent_results_across_runs( + self, + converter, + sample_outlines, + sample_wires, + temporary_configuration_file, + gmsh_mocks, + ): + """Test consistent results across multiple execution runs.""" + results = [] + + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires, patch.object( + converter.outline_grouper, "group_outlines" + ) as mock_group_outlines, patch.object( + converter.outline_preprocessor, "preprocess_outlines" + ) as mock_preprocess_outlines: + + mock_prepare_wires.return_value = {} + mock_group_outlines.return_value = [] + + for _ in range(3): + result = converter.execute( + outlines=sample_outlines, + wires=sample_wires, + config_file_path=temporary_configuration_file, + show_gui=False, + ) + results.append(result) + + for i in range(1, len(results)): + assert results[i].keys() == results[0].keys() + + # ==================== Performance Tests ==================== + + def test_handles_many_outlines_efficiently( + self, converter, sample_wires, temporary_configuration_file, gmsh_mocks, many_outlines + ): + """Test efficient handling of many outlines.""" + with patch.object( + converter.wire_preprocessor, "prepare_wires" + ) as mock_prepare_wires, patch.object( + converter.outline_grouper, "group_outlines" + ) as mock_group_outlines, patch.object( + converter.outline_preprocessor, "preprocess_outlines" + ) as mock_preprocess_outlines: + + mock_prepare_wires.return_value = {} + mock_group_outlines.return_value = [] + + start_time = time.time() + result = converter.execute( + outlines=many_outlines, + wires=sample_wires, + config_file_path=temporary_configuration_file, + show_gui=False, + ) + end_time = time.time() + + assert end_time - start_time < 5.0 + assert result["mesh_generated"] is True + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_convert_svg_to_geometry.py b/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_convert_svg_to_geometry.py new file mode 100644 index 0000000..17d6c95 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_convert_svg_to_geometry.py @@ -0,0 +1,488 @@ +""" +Unit tests for ConvertSVGToGeometry use case. + +Tests the conversion of SVG files to geometric outlines and wires, +including handling of different colors, corner detection, and Bézier fitting. +""" + +import pytest +from unittest.mock import Mock, patch + +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.color import Color + +from svg_to_getdp.core.use_cases.convert_svg_to_geometry import ConvertSVGToGeometry + + +class TestConvertSVGToGeometry: + """Test suite for ConvertSVGToGeometry class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def converter(self): + """Create a converter instance for testing.""" + return ConvertSVGToGeometry() + + @pytest.fixture + def triangle_points(self): + """Create sample points for a triangle shape.""" + return [ + Point(0.0, 0.0), Point(0.5, 0.0), Point(1.0, 0.0), + Point(0.9, 0.1), Point(0.8, 0.2), Point(0.7, 0.3), + Point(0.5, 1.0), Point(0.3, 0.7), Point(0.2, 0.5), + Point(0.1, 0.3), Point(0.0, 0.1), Point(0.0, 0.0) + ] + + @pytest.fixture + def square_points(self): + """Create sample points for a square shape.""" + return [ + Point(0.2, 0.2), Point(0.8, 0.2), Point(0.8, 0.8), + Point(0.2, 0.8), Point(0.2, 0.2) + ] + + @pytest.fixture + def mock_raw_outline_class(self): + """Create a mock RawOutline class for testing.""" + class RawOutline: + def __init__(self, points, is_closed): + self.points = points + self.is_closed = is_closed + return RawOutline + + @pytest.fixture + def mock_bezier_segment(self): + """Create a mock Bézier segment for testing.""" + segment = Mock(spec=BezierSegment) + segment.control_points = [Point(0.0, 0.0), Point(0.3, 0.1), Point(0.5, 0.2)] + return segment + + # ==================== Initialization Tests ==================== + + def test_initialization(self): + """Test that the use case initializes correctly without dependencies.""" + converter = ConvertSVGToGeometry() + + assert converter is not None + assert hasattr(converter, 'svg_parser') + assert hasattr(converter, 'corner_detector') + assert hasattr(converter, 'bezier_fitter') + + # ==================== Color Differentiation Tests ==================== + + def test_red_single_point_wire(self, converter, mock_raw_outline_class): + """Test RED elements with single point become wires.""" + test_svg_path = "test_red_single.svg" + + single_point = [Point(0.5, 0.5)] + mock_raw_outline = mock_raw_outline_class( + points=single_point, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract: + mock_extract.return_value = { + Color.RED: [mock_raw_outline] + } + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + assert len(outlines) == 0 + assert len(wires) == 1 + assert wires[0][1] == Color.RED + assert wires[0][0] == single_point[0] + + def test_red_multiple_points_wire(self, converter, mock_raw_outline_class): + """Test RED elements with multiple points become wires using first point.""" + test_svg_path = "test_red_multiple.svg" + + multiple_points = [Point(0.5, 0.5), Point(0.6, 0.6), Point(0.7, 0.7)] + mock_raw_outline = mock_raw_outline_class( + points=multiple_points, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract: + mock_extract.return_value = { + Color.RED: [mock_raw_outline] + } + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + assert len(outlines) == 0 + assert len(wires) == 1 + assert wires[0][1] == Color.RED + assert wires[0][0] == multiple_points[0] # First point used for wire + + def test_green_outline_processing(self, converter, triangle_points, mock_raw_outline_class): + """Test GREEN elements become outlines with Bézier fitting.""" + test_svg_path = "test_green.svg" + + mock_raw_outline = mock_raw_outline_class( + points=triangle_points, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners, \ + patch.object(converter.bezier_fitter, 'fit_outline') as mock_fit_outline: + + mock_extract.return_value = { + Color.GREEN: [mock_raw_outline] + } + + mock_corner_indices = [0, 3, 6] + mock_debug_data = {'some': 'debug'} + mock_detect_corners.return_value = (mock_corner_indices, mock_debug_data) + + mock_outline = Mock(spec=Outline) + mock_outline.color = Color.GREEN + mock_outline.is_closed = True + mock_outline.bezier_segments = [] + mock_outline.corners = mock_corner_indices + + mock_fit_outline.return_value = mock_outline + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + assert len(outlines) == 1 + assert len(wires) == 0 + assert outlines[0].color == Color.GREEN + + # Debug data key uses lowercase color name + assert 'green_raw_outline_0' in corner_debug_data + debug_data = corner_debug_data['green_raw_outline_0'] + assert debug_data['color'] == 'green' # lowercase + assert debug_data['corner_indices'] == mock_corner_indices + + def test_blue_outline_processing(self, converter, square_points, mock_raw_outline_class): + """Test BLUE elements become outlines with Bézier fitting.""" + test_svg_path = "test_blue.svg" + + mock_raw_outline = mock_raw_outline_class( + points=square_points, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners, \ + patch.object(converter.bezier_fitter, 'fit_outline') as mock_fit_outline: + + mock_extract.return_value = { + Color.BLUE: [mock_raw_outline] + } + + mock_corner_indices = [0, 1, 2, 3] + mock_debug_data = {'some': 'debug'} + mock_detect_corners.return_value = (mock_corner_indices, mock_debug_data) + + mock_outline = Mock(spec=Outline) + mock_outline.color = Color.BLUE + mock_outline.is_closed = True + mock_outline.bezier_segments = [] + mock_outline.corners = mock_corner_indices + + mock_fit_outline.return_value = mock_outline + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + assert len(outlines) == 1 + assert len(wires) == 0 + assert outlines[0].color == Color.BLUE + + # Debug data key uses lowercase color name + assert 'blue_raw_outline_0' in corner_debug_data + debug_data = corner_debug_data['blue_raw_outline_0'] + assert debug_data['color'] == 'blue' # lowercase + assert debug_data['corner_indices'] == mock_corner_indices + + def test_black_outline_processing(self, converter, triangle_points, mock_raw_outline_class): + """Test BLACK elements become outlines with Bézier fitting.""" + test_svg_path = "test_black.svg" + + mock_raw_outline = mock_raw_outline_class( + points=triangle_points, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners, \ + patch.object(converter.bezier_fitter, 'fit_outline') as mock_fit_outline: + + mock_extract.return_value = { + Color.BLACK: [mock_raw_outline] + } + + mock_corner_indices = [0, 3, 6] + mock_debug_data = {'some': 'debug'} + mock_detect_corners.return_value = (mock_corner_indices, mock_debug_data) + + mock_outline = Mock(spec=Outline) + mock_outline.color = Color.BLACK + mock_outline.is_closed = True + mock_outline.bezier_segments = [] + mock_outline.corners = mock_corner_indices + + mock_fit_outline.return_value = mock_outline + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + assert len(outlines) == 1 + assert len(wires) == 0 + assert outlines[0].color == Color.BLACK + + # Debug data key uses lowercase color name + assert 'black_raw_outline_0' in corner_debug_data + debug_data = corner_debug_data['black_raw_outline_0'] + assert debug_data['color'] == 'black' # lowercase + assert debug_data['corner_indices'] == mock_corner_indices + + def test_mixed_colors_processing(self, converter, triangle_points, square_points, + mock_raw_outline_class): + """Test processing of SVG with mixed colors.""" + test_svg_path = "test_mixed.svg" + + # Create outlines for different colors + mock_green_outline = mock_raw_outline_class( + points=triangle_points, + is_closed=True + ) + mock_blue_outline = mock_raw_outline_class( + points=square_points, + is_closed=True + ) + mock_black_outline = mock_raw_outline_class( + points=triangle_points, + is_closed=False # Open curve + ) + mock_red_wire = mock_raw_outline_class( + points=[Point(0.5, 0.5)], + is_closed=True + ) + mock_red_outline = mock_raw_outline_class( + points=[Point(0.2, 0.2), Point(0.8, 0.2), Point(0.5, 0.8)], # Multiple points + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners, \ + patch.object(converter.bezier_fitter, 'fit_outline') as mock_fit_outline: + + mock_extract.return_value = { + Color.GREEN: [mock_green_outline], + Color.BLUE: [mock_blue_outline], + Color.BLACK: [mock_black_outline], + Color.RED: [mock_red_wire, mock_red_outline] # Multiple RED elements + } + + # Setup corner detection responses + corners_green = ([0, 3, 6], {'debug': 'green'}) + corners_blue = ([0, 1, 2, 3], {'debug': 'blue'}) + corners_black = ([], {'debug': 'black'}) + mock_detect_corners.side_effect = [corners_green, corners_blue, corners_black] + + # Setup Bézier fitting responses + mock_green_result = Mock(spec=Outline) + mock_green_result.color = Color.GREEN + mock_green_result.is_closed = True + mock_green_result.bezier_segments = [] + mock_green_result.corners = corners_green[0] + + mock_blue_result = Mock(spec=Outline) + mock_blue_result.color = Color.BLUE + mock_blue_result.is_closed = True + mock_blue_result.bezier_segments = [] + mock_blue_result.corners = corners_blue[0] + + mock_black_result = Mock(spec=Outline) + mock_black_result.color = Color.BLACK + mock_black_result.is_closed = False + mock_black_result.bezier_segments = [] + mock_black_result.corners = corners_black[0] + + mock_fit_outline.side_effect = [mock_green_result, mock_blue_result, mock_black_result] + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + # Verify results + assert len(outlines) == 3 # GREEN, BLUE, BLACK + assert len(wires) == 2 # Two RED elements + + # Verify wires (RED elements) + assert wires[0][1] == Color.RED # Single point wire + assert wires[0][0] == Point(0.5, 0.5) + + assert wires[1][1] == Color.RED # Multi-point wire (uses first point) + assert wires[1][0] == Point(0.2, 0.2) + + # Verify debug data keys (all lowercase) + assert 'green_raw_outline_0' in corner_debug_data + assert 'blue_raw_outline_0' in corner_debug_data + assert 'black_raw_outline_0' in corner_debug_data + + # Corner detector should be called for GREEN, BLUE, BLACK but not RED + assert mock_detect_corners.call_count == 3 + + # Bézier fitter should be called for GREEN, BLUE, BLACK but not RED + assert mock_fit_outline.call_count == 3 + + # ==================== Edge Case Tests ==================== + + def test_empty_svg(self, converter): + """Test converting an empty SVG.""" + test_svg_path = "test_empty.svg" + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract: + mock_extract.return_value = {} + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + assert len(outlines) == 0 + assert len(wires) == 0 + mock_extract.assert_called_once_with(test_svg_path) + + def test_invalid_svg_path(self, converter): + """Test handling of invalid SVG file path.""" + test_svg_path = "nonexistent.svg" + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract: + mock_extract.side_effect = ValueError("SVG file not found") + + with pytest.raises(ValueError, match="SVG file not found"): + converter.execute(test_svg_path) + + mock_extract.assert_called_once_with(test_svg_path) + + # ==================== Open Curve Tests ==================== + + def test_open_curves(self, converter, mock_raw_outline_class): + """Test converting SVG with open curves.""" + test_svg_path = "test_open.svg" + + mock_points = [ + Point(0.0, 0.0), Point(0.3, 0.4), Point(0.7, 0.3), Point(1.0, 0.0) + ] + + mock_raw_outline = mock_raw_outline_class( + points=mock_points, + is_closed=False + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners, \ + patch.object(converter.bezier_fitter, 'fit_outline') as mock_fit_outline: + + mock_extract.return_value = {Color.GREEN: [mock_raw_outline]} + mock_detect_corners.return_value = ([], {}) + + mock_bezier_segment = Mock(spec=BezierSegment) + mock_bezier_segment.control_points = [Point(0.0, 0.0), Point(0.3, 0.1), Point(0.5, 0.2)] + + mock_outline = Mock(spec=Outline) + mock_outline.color = Color.GREEN + mock_outline.is_closed = False + mock_outline.bezier_segments = [mock_bezier_segment, mock_bezier_segment] + mock_outline.corners = [] + + mock_fit_outline.return_value = mock_outline + + result = converter.execute(test_svg_path) + outlines, wires, colored_outlines, corner_debug_data = result + + mock_fit_outline.assert_called_once() + + assert len(outlines) == 1 + assert not outlines[0].is_closed + + # ==================== Error Handling Tests ==================== + + def test_error_handling_in_corner_detection(self, converter, triangle_points, mock_raw_outline_class): + """Test error handling when corner detection fails.""" + test_svg_path = "test_error.svg" + + mock_raw_outline = mock_raw_outline_class( + points=triangle_points, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners: + + mock_extract.return_value = {Color.GREEN: [mock_raw_outline]} + mock_detect_corners.side_effect = ValueError("Corner detection failed") + + with pytest.raises(ValueError, match="Corner detection failed"): + converter.execute(test_svg_path) + + def test_error_handling_in_bezier_fitting(self, converter, triangle_points, mock_raw_outline_class): + """Test error handling when Bézier fitting fails.""" + test_svg_path = "test_error.svg" + + mock_raw_outline = mock_raw_outline_class( + points=triangle_points, + is_closed=True + ) + + with patch.object(converter.svg_parser, 'extract_raw_outlines_by_color') as mock_extract, \ + patch.object(converter.corner_detector, 'detect_corners') as mock_detect_corners, \ + patch.object(converter.bezier_fitter, 'fit_outline') as mock_fit_outline: + + mock_extract.return_value = {Color.GREEN: [mock_raw_outline]} + mock_detect_corners.return_value = ([], {}) + mock_fit_outline.side_effect = ValueError("Bézier fitting failed") + + with pytest.raises(ValueError, match="Bézier fitting failed"): + converter.execute(test_svg_path) + + # ==================== Internal Method Tests ==================== + + def test_ensure_proper_closure_open_curve(self, converter): + """Test the _ensure_proper_closure method with open curve.""" + points_open = [Point(0, 0), Point(1, 0), Point(1, 1)] + result_open = converter._ensure_proper_closure(points_open, False) + assert result_open == points_open + assert len(result_open) == 3 + + def test_ensure_proper_closure_closed_curve_with_gap(self, converter): + """Test the _ensure_proper_closure method with closed curve with gap.""" + points_closed_gap = [Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1)] + result_closed_gap = converter._ensure_proper_closure(points_closed_gap, True) + assert len(result_closed_gap) == 5 + assert result_closed_gap[-1] == points_closed_gap[0] + + def test_ensure_proper_closure_already_closed(self, converter): + """Test the _ensure_proper_closure method with already closed curve.""" + points_already_closed = [Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1), Point(0, 0)] + result_already_closed = converter._ensure_proper_closure(points_already_closed, True) + assert result_already_closed == points_already_closed + + def test_ensure_proper_closure_too_few_points(self, converter): + """Test the _ensure_proper_closure method with too few points.""" + points_few = [Point(0, 0), Point(1, 0)] + result_few = converter._ensure_proper_closure(points_few, True) + assert result_few == points_few + + def test_force_outline_closure(self, converter): + """Test the _force_outline_closure method.""" + mock_segment1 = Mock() + mock_segment1.control_points = [Point(0, 0), Point(0.5, 0)] + + mock_segment2 = Mock() + mock_segment2.control_points = [Point(0.5, 0), Point(1, 1)] + + mock_outline = Mock(spec=Outline) + mock_outline.bezier_segments = [mock_segment1, mock_segment2] + + converter._force_outline_closure(mock_outline) + + assert mock_segment2.control_points[-1] == mock_segment1.control_points[0] diff --git a/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_run_getdp_simulation.py b/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_run_getdp_simulation.py new file mode 100644 index 0000000..9951a21 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/core/use_cases/test_run_getdp_simulation.py @@ -0,0 +1,356 @@ +""" +Unit tests for RunGetDPSimulation use case. +""" + +import pytest +from unittest.mock import Mock, patch, mock_open +import yaml +import numpy as np + +from svg_to_getdp.core.use_cases.run_getdp_simulation import RunGetDPSimulation + + +class TestRunGetDPSimulation: + """Test suite for RunGetDPSimulation class.""" + + # ==================== Helper Methods ==================== + + def _create_default_physical_values(self): + """Create default physical values for testing.""" + return { + "Isource": 1, + "mu0": 4e-7 * np.pi, + "nu0": 1/(4e-7 * np.pi), + "nu_iron_linear": 1/(4000 * 4e-7 * np.pi) + } + + def _assert_physical_values_equal(self, actual, expected, keys=None): + """Assert that physical values match expected values for given keys.""" + if keys is None: + keys = ["Isource", "mu0", "nu0", "nu_iron_linear"] + + for key in keys: + if key in expected: + if isinstance(expected[key], (int, float)): + assert actual[key] == pytest.approx(expected[key]) + else: + assert actual[key] == expected[key] + + def _mock_use_case_internals(self, use_case, config_data=None, physical_values=None): + """Context manager to mock internal methods of RunGetDPSimulation.""" + class MockInternals: + def __init__(self, use_case, config_data=None, physical_values=None): + self.use_case = use_case + self.config_data = config_data or {} + self.physical_values = physical_values or self._create_default_physical_values() + + def _create_default_physical_values(self): + # Use the class method + return TestRunGetDPSimulation._create_default_physical_values(self) + + def __enter__(self): + self.mock_gmsh = patch.object(self.use_case, '_initialize_gmsh').start() + self.mock_load_config = patch.object(self.use_case, '_load_config_yaml').start() + self.mock_run_sim = patch.object(self.use_case, '_run_simulation').start() + + self.mock_load_config.return_value = self.config_data + + # Mock _define_physical_values to set the values + self.mock_define = patch.object(self.use_case, '_define_physical_values').start() + self.mock_define.side_effect = lambda _: setattr( + self.use_case, 'physical_values', self.physical_values.copy() + ) + + return self + + def __exit__(self, *args): + patch.stopall() + + return MockInternals(use_case, config_data, physical_values) + + # ==================== Fixtures ==================== + + @pytest.fixture + def use_case(self): + """Create a RunGetDPSimulation instance for testing.""" + return RunGetDPSimulation() + + @pytest.fixture + def mock_all_externals(self): + """Mock all external dependencies.""" + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.print_data_to_pro') as mock_print, \ + patch('svg_to_getdp.core.use_cases.run_getdp_simulation.run_magnetostatic_simulation') as mock_run_sim, \ + patch('svg_to_getdp.core.use_cases.run_getdp_simulation.physical_identifiers') as mock_phys_ids: + mock_phys_ids.return_value = {"test_id": 1} + yield mock_print, mock_run_sim, mock_phys_ids + + @pytest.fixture + def mock_gmsh(self): + """Mock Gmsh module.""" + gmsh_mock = Mock() + gmsh_mock.initialize.return_value = None + gmsh_mock.finalize.return_value = None + return gmsh_mock + + # ==================== Initialization Tests ==================== + + def test_init(self, use_case): + """Test initialization of RunGetDPSimulation.""" + assert use_case.physical_values is None + + # ==================== Basic Functionality Tests ==================== + + def test_get_physical_values_no_values(self, use_case): + """Test get_physical_values when no values are defined.""" + result = use_case.get_physical_values() + assert result == {} + + def test_get_physical_values_after_definition(self, use_case): + """Test get_physical_values returns a copy of values.""" + config_data = { + "physical_values": { + "Isource": 5, + "mu0": 1.0e-6 + } + } + + use_case._define_physical_values(config_data) + + # Get values and modify the result + returned_values = use_case.get_physical_values() + returned_values["Isource"] = 100 + + # Original should not be modified + assert use_case.physical_values["Isource"] == 5 + assert returned_values["Isource"] == 100 + + # ==================== Configuration Loading Tests ==================== + + def test_load_config_yaml_success(self, use_case): + """Test _load_config_yaml with successful loading.""" + yaml_content = "physical_values:\n Isource: 5\n mu0: 1.256637e-06" + + with patch("builtins.open", mock_open(read_data=yaml_content)): + result = use_case._load_config_yaml("config.yaml") + + assert result == {"physical_values": {"Isource": 5, "mu0": 1.256637e-06}} + + def test_load_config_yaml_file_not_found(self, use_case, capsys): + """Test _load_config_yaml when file is not found.""" + with patch("builtins.open", side_effect=FileNotFoundError): + result = use_case._load_config_yaml("nonexistent.yaml") + + assert result == {} + captured = capsys.readouterr() + assert "Warning: Config file nonexistent.yaml not found" in captured.out + + def test_load_config_yaml_yaml_error(self, use_case, capsys): + """Test _load_config_yaml with YAML parsing error.""" + with patch("builtins.open", mock_open(read_data="invalid: yaml: content:")): + with patch("yaml.safe_load", side_effect=yaml.YAMLError("Parse error")): + result = use_case._load_config_yaml("bad.yaml") + + assert result == {} + captured = capsys.readouterr() + assert "Error parsing YAML file" in captured.out + + # ==================== Physical Values Definition Tests ==================== + + @pytest.mark.parametrize("config_data,expected_values", [ + # No config data + (None, { + "Isource": 1, + "mu0": 4e-7 * np.pi, + "nu0": 1/(4e-7 * np.pi), + "nu_iron_linear": 1/(4000 * 4e-7 * np.pi) + }), + # With config data overriding defaults + ({"physical_values": {"Isource": 10, "custom_value": 2.5, "mu0": 1.0e-6}}, { + "Isource": 10, + "custom_value": 2.5, + "mu0": 1.0e-6, + "nu0": 1/(4e-7 * np.pi), + "nu_iron_linear": 1/(4000 * 4e-7 * np.pi) + }), + # With expression strings containing pi + ({"physical_values": {"Isource": "2*pi", "mu0": "4*pi*1e-7"}}, { + "Isource": 2 * np.pi, + "mu0": 4 * np.pi * 1e-7, + "nu0": 1/(4e-7 * np.pi), + "nu_iron_linear": 1/(4000 * 4e-7 * np.pi) + }), + # With invalid expression string + ({"physical_values": {"Isource": "invalid*expression", "mu0": 1.0e-6}}, { + "Isource": "invalid*expression", + "mu0": 1.0e-6, + "nu0": 1/(4e-7 * np.pi), + "nu_iron_linear": 1/(4000 * 4e-7 * np.pi) + }), + ]) + def test_define_physical_values(self, use_case, config_data, expected_values): + """Test _define_physical_values with various configurations.""" + use_case._define_physical_values(config_data) + + assert use_case.physical_values is not None + + # Check all expected values + for key, expected_value in expected_values.items(): + if isinstance(expected_value, (int, float)): + assert use_case.physical_values[key] == pytest.approx(expected_value) + else: + assert use_case.physical_values[key] == expected_value + + # ==================== Mesh Name Handling Tests ==================== + + @pytest.mark.parametrize("mesh_name,expected", [ + ("test_mesh", "test_mesh.msh"), + ("test_mesh.msh", "test_mesh.msh"), + ("model", "model.msh"), + ]) + def test_execute_mesh_name_handling(self, use_case, mesh_name, expected, mock_all_externals): + """Test execute method handles mesh name extension correctly.""" + with self._mock_use_case_internals(use_case) as mocked: + use_case.execute(mesh_name, use_config_yaml=False) + + # Check that _run_simulation was called with correct mesh name + mocked.mock_run_sim.assert_called_once() + call_args = mocked.mock_run_sim.call_args[0] + assert call_args[0] == expected + + # ==================== Configuration Usage Tests ==================== + + @pytest.mark.parametrize("use_config_yaml,config_path,expected_config_path,config_data", [ + (False, None, None, None), + (True, None, "config.yaml", {"physical_values": {"Isource": 10}}), + (True, "custom/config.yaml", "custom/config.yaml", {"test": "data"}), + ]) + def test_execute_with_config(self, use_case, use_config_yaml, config_path, + expected_config_path, config_data, mock_all_externals): + """Test execute method with various config YAML configurations.""" + mock_print, mock_run_sim, mock_phys_ids = mock_all_externals + + with self._mock_use_case_internals(use_case, config_data=config_data) as mocked: + # Prepare arguments + kwargs = {"mesh_name": "test_mesh.msh", "use_config_yaml": use_config_yaml} + if config_path: + kwargs["config_yaml_path"] = config_path + + # Execute + use_case.execute(**kwargs) + + # Verify calls + if use_config_yaml: + mocked.mock_load_config.assert_called_once_with(expected_config_path) + mocked.mock_define.assert_called_once_with(config_data) + else: + mocked.mock_load_config.assert_not_called() + mocked.mock_define.assert_called_once_with(None) + + mocked.mock_run_sim.assert_called_once() + + # ==================== External Function Integration Tests ==================== + + def test_execute_calls_external_functions(self, use_case, mock_all_externals): + """Test that execute calls external functions correctly.""" + mock_print, mock_run_sim, mock_phys_ids = mock_all_externals + + with self._mock_use_case_internals(use_case) as mocked: + # Execute the method + use_case.execute("test_mesh.msh", use_config_yaml=False) + + # Check that external functions were called + mock_phys_ids.assert_called_once() + assert mock_print.call_count == 2 + + def test_run_simulation_calls_external_function(self, use_case): + """Test _run_simulation calls external function.""" + mock_run_simulation = Mock() + + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.run_magnetostatic_simulation', + mock_run_simulation): + use_case._run_simulation("test_mesh.msh", True) + + mock_run_simulation.assert_called_once_with( + "test_mesh.msh", show_simulation_result=True + ) + + # ==================== Gmsh Integration Tests ==================== + + def test_gmsh_initialized_by_us_flag(self, use_case, mock_gmsh): + """Test that Gmsh finalize is called when we initialized it.""" + # Mock gmsh.isInitialized to return False first, then True + mock_gmsh.isInitialized.side_effect = [False, True] + + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.gmsh', mock_gmsh): + # Mock all other dependencies + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.print_data_to_pro'): + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.run_magnetostatic_simulation'): + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.physical_identifiers', + return_value={"test_id": 1}): + with patch.object(use_case, '_load_config_yaml', return_value={}): + with patch.object(use_case, '_define_physical_values') as mock_define: + mock_define.side_effect = lambda _: setattr( + use_case, 'physical_values', self._create_default_physical_values() + ) + with patch.object(use_case, '_run_simulation'): + # Call execute + use_case.execute("test_mesh.msh", use_config_yaml=False) + + # Verify calls + mock_gmsh.initialize.assert_called_once() + mock_gmsh.finalize.assert_called_once() + + # Check that isInitialized was called twice + assert mock_gmsh.isInitialized.call_count == 2 + + def test_gmsh_already_initialized(self, use_case, mock_gmsh): + """Test that Gmsh is not re-initialized if already initialized.""" + # Mock gmsh.isInitialized to always return True + mock_gmsh.isInitialized.return_value = True + + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.gmsh', mock_gmsh): + # Mock all other dependencies + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.print_data_to_pro'): + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.run_magnetostatic_simulation'): + with patch('svg_to_getdp.core.use_cases.run_getdp_simulation.physical_identifiers', + return_value={"test_id": 1}): + with patch.object(use_case, '_load_config_yaml', return_value={}): + with patch.object(use_case, '_define_physical_values') as mock_define: + mock_define.side_effect = lambda _: setattr( + use_case, 'physical_values', self._create_default_physical_values() + ) + with patch.object(use_case, '_run_simulation'): + # Call execute + use_case.execute("test_mesh.msh", use_config_yaml=False) + + # Verify calls + mock_gmsh.initialize.assert_not_called() + mock_gmsh.finalize.assert_not_called() + + # ==================== Error Handling Tests ==================== + + def test_gmsh_not_available(self): + """Test behavior when Gmsh is not available.""" + import svg_to_getdp.core.use_cases.run_getdp_simulation as module + + # Save original values + original_gmsh = module.gmsh + original_GMSH_AVAILABLE = module.GMSH_AVAILABLE + + try: + # Set GMSH_AVAILABLE to False + module.GMSH_AVAILABLE = False + module.gmsh = None + + # Create an instance + use_case = module.RunGetDPSimulation() + + # This should raise ImportError + with pytest.raises(ImportError, match="Gmsh is not available"): + use_case._initialize_gmsh() + + finally: + # Restore original values + module.GMSH_AVAILABLE = original_GMSH_AVAILABLE + module.gmsh = original_gmsh + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/infrastructure/test_bezier_fitter.py b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_bezier_fitter.py new file mode 100644 index 0000000..66c1cc4 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_bezier_fitter.py @@ -0,0 +1,642 @@ +""" +Test suite for the Bézier Fitter infrastructure component. +Updated for modular bezier_fitting structure. +""" +import pytest +import math +import numpy as np +from unittest.mock import patch + +from svg_to_getdp.infrastructure.bezier_fitting.bezier_fitter import BezierFitter +from svg_to_getdp.infrastructure.bezier_fitting.bezier_calculator import BezierCalculator +from svg_to_getdp.infrastructure.bezier_fitting.segment_classifier import SegmentClassifier +from svg_to_getdp.infrastructure.bezier_fitting.segment_fitter import SegmentFitter +from svg_to_getdp.infrastructure.bezier_fitting.continuity_enforcer import ContinuityEnforcer + +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + + +class TestBezierFitter: + """Test suite for the main BezierFitter orchestrator""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def fitter(self): + """Set up a fresh fitter instance for each test""" + return BezierFitter(bezier_degree=2, minimum_points_per_segment=15) + + @pytest.fixture + def calculator(self): + """Set up a bezier calculator for component testing""" + return BezierCalculator() + + @pytest.fixture + def classifier(self): + """Set up a segment classifier for component testing""" + return SegmentClassifier() + + @pytest.fixture + def segment_fitter(self): + """Set up a segment fitter for component testing""" + return SegmentFitter(bezier_degree=2) + + @pytest.fixture + def continuity_enforcer(self): + """Set up a continuity enforcer for component testing""" + return ContinuityEnforcer(bezier_degree=2) + + @pytest.fixture + def triangle_points(self): + """Create a triangle shape for testing""" + return [ + Point(0, 0), + Point(1, 0), + Point(0.5, 1), + Point(0, 0) # Closed triangle + ] + + @pytest.fixture + def circle_points(self): + """Create a circle-like shape for testing""" + points = [] + for i in range(20): + angle = 2 * math.pi * i / 20 + x = 0.5 + 0.4 * math.cos(angle) + y = 0.5 + 0.4 * math.sin(angle) + points.append(Point(x, y)) + points.append(points[0]) # Close the curve + return points + + @pytest.fixture + def mixed_shape_points(self): + """Create a shape with mixed corners and smooth sections""" + return [ + Point(0, 0), # Corner + Point(0.2, 0.1), Point(0.4, 0.15), Point(0.6, 0.1), # Smooth section + Point(0.8, 0), # Corner + Point(0.8, 0.5), # Corner + Point(0.6, 0.6), Point(0.4, 0.65), Point(0.2, 0.6), # Smooth section + Point(0, 0.5), # Corner + Point(0, 0) # Back to start + ] + + # ==================== Initialization and Configuration Tests ==================== + + def test_fitter_initialization_default(self, fitter): + """Test that fitter initializes with correct default parameters""" + assert fitter.bezier_degree == 2 + assert fitter.minimum_points_per_segment == 15 + + # Verify components are initialized + assert hasattr(fitter, 'segment_classifier') + assert hasattr(fitter, 'segment_fitter') + assert hasattr(fitter, 'continuity_enforcer') + assert hasattr(fitter, 'bezier_calculator') + + @pytest.mark.parametrize("degree,min_points", [ + (3, 10), + (2, 5), + (4, 20), + ]) + def test_fitter_initialization_custom_parameters(self, degree, min_points): + """Test fitter initialization with custom bezier degree and segment size""" + custom_fitter = BezierFitter( + bezier_degree=degree, + minimum_points_per_segment=min_points + ) + assert custom_fitter.bezier_degree == degree + assert custom_fitter.minimum_points_per_segment == min_points + + # ==================== Input Validation and Error Tests ==================== + + def test_fit_outline_insufficient_points_raises_error(self, fitter): + """Test that fitter raises ValueError for insufficient points""" + points = [Point(0, 0), Point(1, 0)] # Only 2 points + corner_indices = [] + color = Color.BLACK + + with pytest.raises(ValueError, match="Need at least 3 non-duplicate points for outline"): + fitter.fit_outline(points, corner_indices, color) + + def test_fit_outline_consecutive_duplicate_points_handling(self, fitter): + """Test that consecutive duplicate points are automatically removed""" + points = [ + Point(0, 0), + Point(0, 0), # Duplicate (will be removed) + Point(1, 0), + Point(1, 0), # Duplicate (will be removed) + Point(0.5, 1), + Point(0, 0) + ] + # After removing duplicates at indices 1 and 3, the cleaned list will be: + # [Point(0,0), Point(1,0), Point(0.5,1), Point(0,0)] + corner_indices = [0, 1, 2] # Indices in the CLEANED list + + # Should not raise error despite duplicates + outline = fitter.fit_outline(points, corner_indices, Color.BLUE) + assert hasattr(outline, 'bezier_segments') + assert len(outline.bezier_segments) >= 1 + # Should have 3 corners in the outline + assert len(outline.corners) == 3 + + @patch('svg_to_getdp.infrastructure.bezier_fitting.segment_fitter.np.linalg.lstsq') + def test_least_squares_fallback_on_singular_matrix(self, mock_lstsq, fitter): + """Test fallback to simple fitting when least squares fails with singular matrix""" + # Mock numpy.linalg.lstsq to raise LinAlgError + mock_lstsq.side_effect = np.linalg.LinAlgError("Matrix is singular") + + points = [Point(0, 0), Point(0.5, 0.5), Point(1, 1), Point(0, 0)] + + # Provide at least one corner to avoid the no-corners path + corner_indices = [0] + + # Should use fallback but still work + outline = fitter.fit_outline(points, corner_indices=corner_indices, color=Color.BLUE) + + # Check attributes + assert hasattr(outline, 'bezier_segments') + assert len(outline.bezier_segments) >= 1 + + # ==================== Basic Shape Fitting Tests ==================== + + def test_fit_outline_simple_triangle_with_corners(self, fitter, triangle_points): + """Test fitting Bézier curves to a simple triangle with explicit corners""" + corner_indices = [0, 1, 2] # All vertices are corners + color = Color.BLUE + + outline = fitter.fit_outline(triangle_points, corner_indices, color) + + # Validate the result structure + assert hasattr(outline, 'bezier_segments') + assert hasattr(outline, 'corners') + assert hasattr(outline, 'color') + assert hasattr(outline, 'is_closed') + + # Check attributes directly + assert outline.color == color + assert outline.is_closed == True + assert len(outline.corners) == 3 + + # Should have at least 1 Bézier segment + assert len(outline.bezier_segments) >= 1 + + # Each segment should be valid + for segment in outline.bezier_segments: + assert hasattr(segment, 'control_points') + assert hasattr(segment, 'degree') + # Each Bézier segment should have degree + 1 control points + assert len(segment.control_points) == segment.degree + 1 + # Control points should not be NaN or infinite + for control_point in segment.control_points: + assert math.isfinite(control_point.x) + assert math.isfinite(control_point.y) + + # Check segment connections for continuity + if len(outline.bezier_segments) > 1: + for i in range(len(outline.bezier_segments)): + current_segment = outline.bezier_segments[i] + next_segment = outline.bezier_segments[(i + 1) % len(outline.bezier_segments)] + + # Check C0 continuity (position continuity at segment interfaces) + distance = current_segment.end_point.distance_to(next_segment.start_point) + assert distance < 1e-10, f"Segment {i} end point doesn't connect to segment {(i + 1) % len(outline.bezier_segments)} start point. Distance: {distance}" + + # Verify the outline is properly closed + first_segment = outline.bezier_segments[0] + last_segment = outline.bezier_segments[-1] + closure_distance = last_segment.end_point.distance_to(first_segment.start_point) + assert closure_distance < 1e-10, f"Outline is not properly closed. Gap: {closure_distance}" + + def test_fit_outline_smooth_circle_without_corners(self, fitter, circle_points): + """Test fitting Bézier curves to a smooth circle-like shape without explicit corners""" + corner_indices = [] # No corners for smooth curve + color = Color.GREEN + + outline = fitter.fit_outline(circle_points, corner_indices, color) + + # Check attributes + assert hasattr(outline, 'bezier_segments') + assert len(outline.bezier_segments) > 0 + + # Verify there are no corners after fitting (since none were provided) + assert hasattr(outline, 'corners') + assert len(outline.corners) == 0 + + # Ensure all segments are properly connected + if len(outline.bezier_segments) > 1: + for i in range(len(outline.bezier_segments) - 1): + current = outline.bezier_segments[i] + next_seg = outline.bezier_segments[i + 1] + # Check C0 continuity (end point matches next start point) + assert current.end_point.distance_to(next_seg.start_point) < 1e-10 + + # Check closure for closed Outline + if outline.is_closed and len(outline.bezier_segments) > 1: + first = outline.bezier_segments[0] + last = outline.bezier_segments[-1] + assert last.end_point.distance_to(first.start_point) < 1e-10 + + def test_fit_outline_mixed_corners_and_smooth_sections(self, fitter, mixed_shape_points): + """Test fitting with combination of corner points and smooth sections""" + corner_indices = [0, 4, 5, 8] # Indices of corners + color = Color.BLACK + + outline = fitter.fit_outline(mixed_shape_points, corner_indices, color) + + # Should create valid outline + assert hasattr(outline, 'bezier_segments') + assert len(outline.bezier_segments) >= 1 + + # Check attributes + assert hasattr(outline, 'corners') + assert hasattr(outline, 'color') + assert hasattr(outline, 'is_closed') + + # Check attribute values + assert outline.color == color + assert outline.is_closed == True + assert len(outline.corners) == 4 # Should have 4 corners + + # Each segment should be valid + for segment in outline.bezier_segments: + assert hasattr(segment, 'control_points') + assert hasattr(segment, 'degree') + # Each Bézier segment should have degree + 1 control points + assert len(segment.control_points) == segment.degree + 1 + # Control points should not be NaN or infinite + for control_point in segment.control_points: + assert math.isfinite(control_point.x) + assert math.isfinite(control_point.y) + + # Check segment connections (C0 continuity) + if len(outline.bezier_segments) > 1: + for i in range(len(outline.bezier_segments)): + current_segment = outline.bezier_segments[i] + next_segment = outline.bezier_segments[(i + 1) % len(outline.bezier_segments)] + + # Check C0 continuity (position continuity at segment interfaces) + distance = current_segment.end_point.distance_to(next_segment.start_point) + assert distance < 1e-10, f"Segment {i} end point doesn't connect to segment {(i + 1) % len(outline.bezier_segments)} start point. Distance: {distance}" + + # Verify the outline is properly closed + if len(outline.bezier_segments) > 1: + first_segment = outline.bezier_segments[0] + last_segment = outline.bezier_segments[-1] + closure_distance = last_segment.end_point.distance_to(first_segment.start_point) + assert closure_distance < 1e-10, f"Outline is not properly closed. Gap: {closure_distance}" + + # ==================== BezierCalculator Component Tests ==================== + + def test_bezier_calculator_remove_consecutive_duplicate_points(self, calculator): + """Test that consecutive duplicate points are removed while preserving non-consecutive duplicates""" + points = [ + Point(0, 0), + Point(0, 0), # Consecutive duplicate + Point(1, 0), + Point(1, 0), # Consecutive duplicate + Point(1, 1), + Point(0, 0) # Not consecutive duplicate (different from first) + ] + + cleaned = calculator.remove_consecutive_duplicate_points(points) + assert len(cleaned) == 4 # Should have 4 unique consecutive points + + def test_bezier_calculator_calculate_segment_interfaces_with_corners(self, calculator): + """Test segment interface calculation prioritizing corner points""" + points = [Point(i * 0.1, 0) for i in range(11)] # 11 points along x-axis + corner_indices = [0, 5, 10] # Corners at start, middle, end + + interfaces = calculator.calculate_segment_interfaces( + points, corner_indices, target_segment_count=3, is_closed=False + ) + + # Should include all corner indices plus start and end + assert 0 in interfaces + assert 5 in interfaces + assert 10 in interfaces + + def test_bezier_calculator_compute_bernstein_basis(self, calculator): + """Test Bernstein basis polynomial computation for Bézier curves""" + basis_val = calculator.compute_bernstein_basis(1, 2, 0.5) # B_{1,2}(0.5) + expected = math.comb(2, 1) * (0.5 ** 1) * ((1 - 0.5) ** (2 - 1)) + assert abs(basis_val - expected) < 1e-10 + + # Test that basis polynomials sum to 1 (partition of unity property) + total = 0 + for i in range(3): # degree 2 has 3 basis functions + total += calculator.compute_bernstein_basis(i, 2, 0.3) + assert abs(total - 1.0) < 1e-10 + + def test_bezier_calculator_are_points_approximately_linear(self, calculator): + """Test detection of approximately linear point sequences""" + # Linear points + linear_points = [Point(0, 0), Point(0.5, 0.5), Point(1, 1)] + assert calculator.are_points_approximately_linear(linear_points) + + # Non-linear points + non_linear_points = [Point(0, 0), Point(0.5, 0), Point(1, 1)] + assert not calculator.are_points_approximately_linear(non_linear_points) + + def test_bezier_calculator_calculate_distance_from_line(self, calculator): + """Test perpendicular distance calculation from point to line""" + line_start = Point(0, 0) + line_end = Point(1, 0) + test_point = Point(0.5, 1) + + distance = calculator.calculate_distance_from_line(line_start, line_end, test_point) + assert abs(distance - 1.0) < 1e-10 + + def test_bezier_calculator_project_point_to_line(self, calculator): + """Test orthogonal projection of point onto line segment""" + line_start = Point(0, 0) + line_end = Point(1, 0) + test_point = Point(0.5, 1) + + projection = calculator.project_point_to_line(line_start, line_end, test_point) + assert projection == Point(0.5, 0) # Should project to (0.5, 0) + + def test_bezier_calculator_find_point_with_max_deviation(self, calculator): + """Test identification of point with maximum deviation from line""" + line_start = Point(0, 0) + line_end = Point(1, 0) + points = [ + Point(0.2, 0.1), + Point(0.5, 0.5), # Max deviation + Point(0.8, 0.1) + ] + + max_point = calculator.find_point_with_max_deviation(points, line_start, line_end) + assert max_point == points[1] # Point (0.5, 0.5) has max deviation + + # ==================== SegmentClassifier Component Tests ==================== + + def test_segment_classifier_identify_corner_regions(self, classifier, calculator): + """Test identification of regions around corners for special fitting""" + points = [Point(i * 0.1, 0) for i in range(11)] + corner_indices = [0, 5, 10] + regions = classifier.identify_corner_regions(points, corner_indices) + + assert len(regions) == 3 + for region in regions: + assert isinstance(region, tuple) + assert len(region) == 2 + assert region[0] <= region[1] + + def test_segment_classifier_classify_segment_type_corner_regions(self, classifier): + """Test segment classification for corner regions""" + points = [Point(i * 0.1, 0) for i in range(11)] + corner_regions = [(0, 2), (8, 10)] + corner_indices = [0, 5, 10] + + # Test corner region (segment within corner region) + segment_type = classifier.classify_segment_type( + start_index=1, + end_index=2, + corner_regions=corner_regions, + corner_indices=corner_indices, + points=points + ) + assert segment_type == "corner_region" + + # Test corner region (segment contains interior corner) + segment_type = classifier.classify_segment_type( + start_index=4, + end_index=6, + corner_regions=[], + corner_indices=corner_indices, + points=points + ) + assert segment_type == "corner_region" + + def test_segment_classifier_classify_segment_type_straight_and_curved(self, classifier): + """Test segment classification for straight edges and curved segments""" + points = [Point(i * 0.1, 0) for i in range(11)] + + # Test straight_edge - segment connecting corners with straight geometry + straight_points = [Point(0, 0), Point(0.5, 0), Point(1, 0)] + all_points = points + straight_points + segment_type = classifier.classify_segment_type( + start_index=11, # Start at the first straight point + end_index=13, # End at the last straight point + corner_regions=[], + corner_indices=[11, 13], # Treat endpoints as corners + points=all_points + ) + assert segment_type == "straight_edge" + + # Test curved - segment not connecting corners and not straight + curved_points = [Point(0, 0), Point(0.3, 0.1), Point(0.7, 0.1), Point(1, 0)] + all_points = points + curved_points + segment_type = classifier.classify_segment_type( + start_index=11, # Start at the first curved point + end_index=14, # End at the last curved point + corner_regions=[], + corner_indices=[], # No corners involved + points=all_points + ) + assert segment_type == "curved" + + # ==================== SegmentFitter Component Tests ==================== + + def test_segment_fitter_fit_simple_bezier_curve_small_point_sets(self, segment_fitter): + """Test simple Bézier fitting for small point sets (fallback cases)""" + # Single point + points = [Point(5, 5)] + segment = segment_fitter._fit_simple_bezier_curve(points) + assert hasattr(segment, 'control_points') + assert hasattr(segment, 'degree') + assert len(segment.control_points) == 3 # degree 2 + 1 + + # Two points + points = [Point(0, 0), Point(1, 1)] + segment = segment_fitter._fit_simple_bezier_curve(points) + assert hasattr(segment, 'start_point') + assert hasattr(segment, 'end_point') + # Access attributes directly + assert segment.control_points[0] == points[0] + assert segment.control_points[-1] == points[1] + + def test_segment_fitter_fit_straight_edge_segment(self, segment_fitter): + """Test fitting of straight edge segments between corners""" + points = [Point(0, 0), Point(0.5, 0), Point(1, 0)] + segment = segment_fitter._fit_straight_edge_segment(points) + + assert hasattr(segment, 'control_points') + assert len(segment.control_points) == 3 + assert segment.control_points[0] == points[0] + assert segment.control_points[-1] == points[-1] + + # Midpoint should be average of endpoints + midpoint = segment.control_points[1] + expected_midpoint = Point( + (points[0].x + points[-1].x) / 2, + (points[0].y + points[-1].y) / 2 + ) + assert midpoint.distance_to(expected_midpoint) < 1e-10 + + # ==================== ContinuityEnforcer Component Tests ==================== + + def test_continuity_enforcer_enforce_segment_continuity_c0(self, continuity_enforcer): + """Test C0 continuity (position continuity) enforcement between segments""" + # Create two simple segments using BezierSegment constructor + segment1 = BezierSegment( + control_points=[Point(0, 0), Point(0.3, 0.1), Point(0.5, 0)], + degree=2 + ) + segment2 = BezierSegment( + control_points=[Point(0.5, 0), Point(0.7, -0.1), Point(1, 0)], + degree=2 + ) + + segments = [segment1, segment2] + interfaces = [0, 5, 10] # Mock interfaces + corner_indices = [] # No corners for smooth junction + + # Test C0 continuity enforcement + continuity_enforcer.enforce_segment_continuity( + segments, interfaces, corner_indices, is_closed=False + ) + + # End point of first should match start point of second (C0 continuity) + assert segment1.control_points[-1] == segment2.control_points[0] + + def test_continuity_enforcer_enforce_tangent_continuity_c1(self, continuity_enforcer): + """Test C1 continuity (tangent continuity) enforcement for smooth junctions""" + segment1 = BezierSegment( + control_points=[Point(0, 0), Point(0.3, 0.1), Point(0.5, 0)], + degree=2 + ) + segment2 = BezierSegment( + control_points=[Point(0.5, 0), Point(0.7, -0.1), Point(1, 0)], + degree=2 + ) + + original_p1 = segment1.control_points[1] + original_q1 = segment2.control_points[1] + + continuity_enforcer._enforce_tangent_continuity(segment1, segment2) + + # Control points should be adjusted for tangent continuity + assert segment1.control_points[1] != original_p1 + assert segment2.control_points[1] != original_q1 + + def test_continuity_enforcer_ensure_outline_closure(self, continuity_enforcer): + """Test enforcement of closure for closed outlines""" + segment1 = BezierSegment( + control_points=[Point(0, 0), Point(0.3, 0.1), Point(0.5, 0)], + degree=2 + ) + segment2 = BezierSegment( + control_points=[Point(0.6, 0), Point(0.7, -0.1), Point(1, 0)], # Not connected to segment1 + degree=2 + ) + + segments = [segment1, segment2] + continuity_enforcer._ensure_outline_closure(segments) + + # Last point of last segment should match first point of first segment + assert segments[-1].control_points[-1] == segments[0].control_points[0] + + # ==================== Regular Polygon Fitting Tests ==================== + + @pytest.mark.parametrize("shape_type,corner_count", [ + ("triangle", 3), + ("square", 4), + ("pentagon", 5), + ("hexagon", 6), + ]) + def test_fit_regular_polygons_various_sizes(self, fitter, shape_type, corner_count): + """Test fitting Bézier curves to regular polygons with different numbers of sides""" + # Generate regular polygon points + points = [] + for i in range(corner_count): + angle = 2 * math.pi * i / corner_count + x = 0.5 + 0.4 * math.cos(angle) + y = 0.5 + 0.4 * math.sin(angle) + points.append(Point(x, y)) + points.append(points[0]) # Close the polygon + + # All vertices are corners + corner_indices = list(range(corner_count)) + + outline = fitter.fit_outline(points, corner_indices, Color.GREEN) + + # Validate outline + assert hasattr(outline, 'bezier_segments') + assert hasattr(outline, 'corners') + assert outline.color == Color.GREEN + assert outline.is_closed == True + assert len(outline.corners) == corner_count + assert len(outline.bezier_segments) >= 1 + + for segment in outline.bezier_segments: + assert hasattr(segment, 'control_points') + assert hasattr(segment, 'degree') + assert len(segment.control_points) == segment.degree + 1 + for control_point in segment.control_points: + assert math.isfinite(control_point.x) + assert math.isfinite(control_point.y) + + # ==================== Performance and Scalability Tests ==================== + + def test_performance_with_large_point_set(self, fitter): + """Test fitting performance with larger point sets (100 points)""" + # Create a larger point set + n_points = 100 + points = [Point(math.cos(2 * math.pi * i / n_points), + math.sin(2 * math.pi * i / n_points)) + for i in range(n_points)] + corner_indices = [0, 25, 50, 75] # Approximate corner indices + color = Color.GREEN + + import time + start_time = time.time() + + outline = fitter.fit_outline(points, corner_indices, color) + + end_time = time.time() + duration = end_time - start_time + + # Should complete in reasonable time + assert duration < 5.0, f"Fitting 100 points took {duration:.2f} seconds (should be < 5s)" + + # Result should be valid + assert hasattr(outline, 'bezier_segments') + assert len(outline.bezier_segments) > 0 + + # ==================== Edge Case and Robustness Tests ==================== + + def test_fit_outline_open_curve_not_closed(self, fitter): + """Test fitting Bézier curves to an open (non-closed) curve""" + points = [ + Point(0, 0), + Point(0.2, 0.1), + Point(0.4, 0.2), + Point(0.6, 0.1), + Point(0.8, 0) + ] + corner_indices = [0, 4] + + outline = fitter.fit_outline(points, corner_indices, Color.BLUE, is_closed=False) + + # Validate outline structure (open) + assert hasattr(outline, 'bezier_segments') + assert hasattr(outline, 'corners') + assert outline.color == Color.BLUE + assert outline.is_closed == False + assert len(outline.corners) == 2 + + # Should create valid segments + assert len(outline.bezier_segments) >= 1 + + for segment in outline.bezier_segments: + assert hasattr(segment, 'control_points') + assert hasattr(segment, 'degree') + assert len(segment.control_points) == segment.degree + 1 + for control_point in segment.control_points: + assert math.isfinite(control_point.x) + assert math.isfinite(control_point.y) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/infrastructure/test_corner_detector.py b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_corner_detector.py new file mode 100644 index 0000000..e816826 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_corner_detector.py @@ -0,0 +1,530 @@ +""" +Unit tests for CornerDetector class. + +Tests corner detection functionality on various geometric shapes, +including rectangles, circles, ellipses, and complex mixed shapes. +""" +import pytest +from math import cos, sin, pi + +from svg_to_getdp.infrastructure.corner_detection.corner_detector import CornerDetector +from svg_to_getdp.core.entities.point import Point + + +class TestCornerDetector: + """Test suite for CornerDetector class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def detector(self): + """Create a corner detector instance for testing.""" + return CornerDetector(debug_enabled=False) + + @pytest.fixture + def debug_detector(self): + """Create a corner detector with debug enabled.""" + return CornerDetector(debug_enabled=True) + + @pytest.fixture + def rectangle_points(self): + """Create points for a rectangle shape.""" + return self.generate_rectangle_points(0, 0, 100, 50) + + @pytest.fixture + def circle_points(self): + """Create points for a circle shape.""" + return self.generate_circle_points(50, 50, 40) + + @pytest.fixture + def ellipse_points(self): + """Create points for an ellipse shape.""" + return self.generate_ellipse_points(50, 50, 40, 30) + + @pytest.fixture + def tear_shape_points(self): + """Create points for a tear/drop shape.""" + return self.generate_tear_shape_points(50, 50) + + @pytest.fixture + def peanut_shape_points(self): + """Create points for a smooth peanut shape.""" + return self.generate_peanut_shape_points(50, 50) + + @pytest.fixture + def sharp_corner_points(self): + """Create points with a sharp 90-degree corner.""" + points = [] + for i in range(20): + points.append(Point(i, 0)) + for i in range(20): + points.append(Point(20, i)) + return points + + @pytest.fixture + def l_shape_points(self): + """Create points forming a simple L-shaped corner.""" + points = [] + # Horizontal line + for i in range(50): + points.append(Point(i, 0)) + + # Vertical line + for i in range(50): + points.append(Point(50, i)) + + return points + + # ==================== Helper Methods ==================== + + def generate_circle_points(self, center_x, center_y, radius, num_points=200): + """Generate points along a circle.""" + points = [] + for i in range(num_points): + angle = 2 * pi * i / num_points + x = center_x + radius * cos(angle) + y = center_y + radius * sin(angle) + points.append(Point(x, y)) + return points + + def generate_ellipse_points(self, center_x, center_y, width, height, num_points=200): + """Generate points along an ellipse.""" + points = [] + for i in range(num_points): + angle = 2 * pi * i / num_points + x = center_x + width * cos(angle) + y = center_y + height * sin(angle) + points.append(Point(x, y)) + return points + + def generate_rectangle_points(self, x, y, width, height, num_points_per_side=50): + """Generate points along a rectangle.""" + points = [] + + # Top side + for i in range(num_points_per_side): + px = x + (width * i / num_points_per_side) + py = y + points.append(Point(px, py)) + + # Right side + for i in range(num_points_per_side): + px = x + width + py = y + (height * i / num_points_per_side) + points.append(Point(px, py)) + + # Bottom side + for i in range(num_points_per_side): + px = x + width - (width * i / num_points_per_side) + py = y + height + points.append(Point(px, py)) + + # Left side + for i in range(num_points_per_side): + px = x + py = y + height - (height * i / num_points_per_side) + points.append(Point(px, py)) + + return points + + def generate_tear_shape_points(self, center_x, center_y, size=100, num_points=200): + """Generate points for a tear/drop shape (has 1 sharp corner at the pointy end).""" + points = [] + for i in range(num_points): + angle = 2 * pi * i / num_points + r = size * (1 - cos(angle)) + x = center_x + r * cos(angle) + y = center_y + r * sin(angle) + points.append(Point(x, y)) + return points + + def generate_peanut_shape_points(self, center_x, center_y, size=100, num_points=200, waist_factor=0.5): + """Generate points for a peanut shape with a distinct waist in the middle.""" + points = [] + for i in range(num_points): + angle = 2 * pi * i / num_points + r = size * (waist_factor + (1 - waist_factor) * (cos(angle) ** 2)) + x = center_x + r * cos(angle) + y = center_y + r * sin(angle) + points.append(Point(x, y)) + return points + + # ==================== Initialization Tests ==================== + + def test_initialization_default_params(self): + """Test that detector initializes with default parameters.""" + detector = CornerDetector() + + assert detector.window_size == 15 + assert detector.direction_change_threshold == pytest.approx(0.8) + assert detector.angle_threshold == pytest.approx(pi / 6) + assert detector.minimum_corner_distance == 5 + assert detector.smoothness_threshold == pytest.approx(0.72) + assert detector.corner_strength_threshold == pytest.approx(0.45) + assert detector.ellipse_aspect_ratio_threshold == pytest.approx(1.2) + assert detector.debug_enabled == True + + def test_initialization_custom_params(self): + """Test that detector initializes with custom parameters.""" + detector = CornerDetector( + window_size=20, + direction_change_threshold=1.0, + angle_threshold=pi/4, + minimum_corner_distance=10, + smoothness_threshold=0.8, + corner_strength_threshold=0.6, + ellipse_aspect_ratio_threshold=1.5, + debug_enabled=False + ) + + assert detector.window_size == 20 + assert detector.direction_change_threshold == pytest.approx(1.0) + assert detector.angle_threshold == pytest.approx(pi/4) + assert detector.minimum_corner_distance == 10 + assert detector.smoothness_threshold == pytest.approx(0.8) + assert detector.corner_strength_threshold == pytest.approx(0.6) + assert detector.ellipse_aspect_ratio_threshold == pytest.approx(1.5) + assert detector.debug_enabled == False + + # ==================== Basic Functionality Tests ==================== + + def test_detection_with_empty_boundary_points(self, detector): + """Test corner detection with empty boundary points.""" + corners, debug_data = detector.detect_corners([]) + + assert corners == [] + assert isinstance(debug_data, dict) + + def test_detection_with_small_number_of_points(self, detector): + """Test corner detection with very few points.""" + points = [Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1)] + corners, debug_data = detector.detect_corners(points) + + assert corners == [] # Too few points for detection + assert isinstance(debug_data, dict) + + def test_debug_data_structure(self, debug_detector, rectangle_points): + """Test that debug data has the expected structure.""" + corners, debug_data = debug_detector.detect_corners(rectangle_points) + + assert isinstance(debug_data, dict) + assert 'shape_analysis' in debug_data + assert 'candidate_detection' in debug_data + assert 'strength_calculations' in debug_data + assert 'clustering' in debug_data + assert 'final_decisions' in debug_data + assert 'all_steps' in debug_data + + # Check that all_steps contains messages + assert len(debug_data['all_steps']) > 0 + + # ==================== Shape Detection Tests ==================== + + def test_rectangle_detection(self, detector, rectangle_points): + """Test corner detection on a rectangle (should find 4 corners).""" + corners, debug_data = detector.detect_corners(rectangle_points) + + # Should find exactly 4 corners for a rectangle + assert len(corners) == 4 + + # Corners should be well-spaced + total_points = len(rectangle_points) + for i in range(len(corners)): + for j in range(i + 1, len(corners)): + distance = min( + abs(corners[i] - corners[j]), + total_points - abs(corners[i] - corners[j]) + ) + assert distance > 10 + + # Check debug data + if 'final_decisions' in debug_data: + assert 'final_corners' in debug_data['final_decisions'] + assert len(debug_data['final_decisions']['final_corners']) == 4 + + def test_circle_detection(self, detector, circle_points): + """Test corner detection on a circle (should find 0 corners).""" + corners, debug_data = detector.detect_corners(circle_points) + + # Circle should have no corners + assert len(corners) == 0 + + # Check shape analysis in debug data + if 'shape_analysis' in debug_data: + assert debug_data['shape_analysis'].get('is_ellipse', False) or \ + debug_data['shape_analysis'].get('too_smooth', False) + + def test_ellipse_detection(self, detector, ellipse_points): + """Test corner detection on an ellipse (should find 0 corners).""" + corners, debug_data = detector.detect_corners(ellipse_points) + + # Ellipse should have no corners + assert len(corners) == 0 + + # Check shape analysis + if 'shape_analysis' in debug_data: + assert debug_data['shape_analysis'].get('is_ellipse', False) or \ + debug_data['shape_analysis'].get('too_smooth', False) + + def test_tear_shape_detection(self, detector, tear_shape_points): + """Test corner detection on a tear/drop shape (should find 1 sharp corner).""" + corners, debug_data = detector.detect_corners(tear_shape_points) + + # Tear shape should have 1 corner + assert len(corners) == 1 + + # Check debug data has information about the corner + if 'final_decisions' in debug_data: + assert len(debug_data['final_decisions'].get('final_corners', [])) == 1 + + def test_peanut_shape_detection(self, detector, peanut_shape_points): + """Test corner detection on a peanut shape (should find 0 corners).""" + corners, debug_data = detector.detect_corners(peanut_shape_points) + + # Smooth peanut shape should have no corners + assert len(corners) == 0 + + # ==================== Edge Case Tests ==================== + + def test_detection_with_large_number_of_points(self, detector): + """Test corner detection with a very large number of points.""" + rectangle_points = self.generate_rectangle_points(0, 0, 100, 50, num_points_per_side=200) + + corners, debug_data = detector.detect_corners(rectangle_points) + + # Should still find 4 corners + assert len(corners) == 4 + + # All corners should have valid indices + for corner_idx in corners: + assert 0 <= corner_idx < len(rectangle_points) + + # ==================== Internal Method Tests ==================== + + def test_angle_calculation(self, detector, l_shape_points): + """Test angle calculation indirectly by checking corner detection.""" + # Create an L-shape (should have 3 corners) + corners, debug_data = detector.detect_corners(l_shape_points) + + assert len(corners) > 0 + + def test_corner_strength_calculation(self, debug_detector, rectangle_points): + """Test strength calculation through debug data.""" + corners, debug_data = debug_detector.detect_corners(rectangle_points) + + # Should find 4 corners + assert len(corners) == 4 + + # Check strength calculations in debug data + if 'strength_calculations' in debug_data: + strengths = debug_data['strength_calculations'] + + # Some strengths should be calculated + assert len(strengths) > 0 + + # All strengths should be between 0 and 1 + for strength in strengths.values(): + assert 0 <= strength <= 1 + + # Check final decisions include strengths + if 'final_decisions' in debug_data and 'corner_strengths' in debug_data['final_decisions']: + final_strengths = debug_data['final_decisions']['corner_strengths'] + assert len(final_strengths) == len(corners) + + for idx, strength in final_strengths.items(): + assert idx in corners + assert 0 <= strength <= 1 + assert strength >= 0.45 # Should meet threshold + + def test_candidate_combination(self, debug_detector, rectangle_points): + """Test candidate combination through debug data.""" + corners, debug_data = debug_detector.detect_corners(rectangle_points) + + # Check candidate detection methods in debug data + if 'candidate_detection' in debug_data: + candidate_data = debug_data['candidate_detection'] + + # Should have multiple detection methods + assert 'angle_method' in candidate_data + assert 'direction_method' in candidate_data + assert 'curvature_method' in candidate_data + + # Should have combined results + if 'combined_votes' in candidate_data: + combined = candidate_data['combined_votes'] + assert len(combined) > 0 + + # Check votes are reasonable + for votes in combined.values(): + assert votes >= 0 + + def test_corner_refinement(self, debug_detector, rectangle_points): + """Test refinement process through debug data.""" + corners, debug_data = debug_detector.detect_corners(rectangle_points) + + # Should have refinement details in debug data + if 'refinement_details' in debug_data: + refinement_details = debug_data['refinement_details'] + + # Should have some refinement details + assert len(refinement_details) > 0 + + # Check structure of refinement details + for detail in refinement_details: + assert 'cluster' in detail + assert 'best_candidate' in detail + assert 'refined_candidate' in detail + assert 'accepted' in detail + + # ==================== Parameter Sensitivity Tests ==================== + + def test_different_angle_thresholds(self): + """Test corner detection with different angle thresholds.""" + points = [] + + # Square with rounded corners + for i in range(50): + points.append(Point(i, 0)) + for i in range(10): + angle = pi/2 * i/10 + points.append(Point(50 + 5*cos(angle), 5 + 5*sin(angle))) + for i in range(50): + points.append(Point(55 - i, 10)) + + # Test with strict threshold + strict_detector = CornerDetector(angle_threshold=pi/3, debug_enabled=False) + strict_corners, _ = strict_detector.detect_corners(points) + + # Test with lenient threshold + lenient_detector = CornerDetector(angle_threshold=pi/12, debug_enabled=False) + lenient_corners, _ = lenient_detector.detect_corners(points) + + # Lenient should find at least as many corners as strict + assert len(lenient_corners) >= len(strict_corners) + + def test_different_smoothness_thresholds(self, ellipse_points): + """Test ellipse detection with different smoothness thresholds.""" + # Test with low threshold (0.5) + low_thresh_detector = CornerDetector(smoothness_threshold=0.5, debug_enabled=False) + low_corners, low_debug = low_thresh_detector.detect_corners(ellipse_points) + + # Test with high threshold (0.9) + high_thresh_detector = CornerDetector(smoothness_threshold=0.9, debug_enabled=False) + high_corners, high_debug = high_thresh_detector.detect_corners(ellipse_points) + + # Both should detect ellipse as having no corners + if 'shape_analysis' in low_debug: + low_smoothness = low_debug['shape_analysis'].get('smoothness_score', 0) + + assert 0.78 < low_smoothness < 0.8 # ellipse smoothness ~ 0.795 + assert len(low_corners) == 0 + + if 'shape_analysis' in high_debug: + high_smoothness = high_debug['shape_analysis'].get('smoothness_score', 0) + + assert 0.78 < high_smoothness < 0.8 # ellipse smoothness ~ 0.795 + assert len(high_corners) == 0 + + def test_minimum_corner_distance_enforcement(self): + """Test that minimum corner distance is properly enforced.""" + points = [] + + # Two close right angles + for i in range(10): + points.append(Point(i, 0)) + points.append(Point(10, 0)) + points.append(Point(10, 1)) + points.append(Point(10, 2)) + for i in range(10): + points.append(Point(10 - i, 2)) + + # Test with minimum distance of 5 + detector = CornerDetector(minimum_corner_distance=5, debug_enabled=False) + corners, _ = detector.detect_corners(points) + + # Should only keep one of the two close corners + assert len(corners) <= 2 + + if len(corners) == 2: + # Check they're sufficiently spaced + distance = min( + abs(corners[0] - corners[1]), + len(points) - abs(corners[0] - corners[1]) + ) + assert distance >= 5 + + # ==================== Integration Tests ==================== + + def test_consistency_across_runs(self, detector, rectangle_points): + """Test that corner detection is consistent across multiple runs.""" + results = [] + for _ in range(5): + corners, _ = detector.detect_corners(rectangle_points) + results.append(sorted(corners)) + + # All results should be the same + for i in range(1, len(results)): + assert results[i] == results[0] + + def test_closed_shape_handling(self, detector): + """Test that closed shapes are handled correctly.""" + points = self.generate_rectangle_points(0, 0, 100, 50, num_points_per_side=25) + closed_points = points + [points[0]] + + corners, debug_data = detector.detect_corners(closed_points) + + # Should find 4 corners + assert len(corners) == 4 + + # Check corners are reasonable + for corner_idx in corners: + assert 0 <= corner_idx < len(closed_points) + + def test_scale_invariance(self): + """Test that corner detection works at different scales.""" + # Generate small rectangle + small_points = self.generate_rectangle_points(0, 0, 10, 5, num_points_per_side=20) + + # Generate large rectangle + large_points = self.generate_rectangle_points(0, 0, 100, 50, num_points_per_side=20) + + detector = CornerDetector(debug_enabled=False) + + small_corners, _ = detector.detect_corners(small_points) + large_corners, _ = detector.detect_corners(large_points) + + # Both should find 4 corners + assert len(small_corners) == 4 + assert len(large_corners) == 4 + + # ==================== Performance Tests ==================== + + def test_performance_with_many_points(self, detector): + """Test that detector handles large number of points efficiently.""" + dense_points = self.generate_circle_points(50, 50, 40, num_points=1000) + + import time + start_time = time.time() + corners, debug_data = detector.detect_corners(dense_points) + end_time = time.time() + + # Should complete in reasonable time + assert end_time - start_time < 2.0 + + # Circle should have no corners + assert len(corners) == 0 + + # ==================== Error Handling Tests ==================== + + def test_invalid_input_types(self, detector): + """Test handling of invalid input types.""" + invalid_points = [Point(0, 0), "not a point", Point(1, 1)] + + try: + corners, debug_data = detector.detect_corners(invalid_points) + # If it doesn't fail, verify structure + assert isinstance(corners, list) + assert isinstance(debug_data, dict) + except (AttributeError, TypeError, IndexError): + # Any of these would be reasonable errors + pass diff --git a/sketchgetdp/svg_to_getdp/tests/infrastructure/test_outline_grouper.py b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_outline_grouper.py new file mode 100644 index 0000000..c266e69 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_outline_grouper.py @@ -0,0 +1,353 @@ +import pytest +from unittest.mock import patch, MagicMock, PropertyMock +import math + +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.physical_group import ( + DOMAIN_VA, + DOMAIN_VI_IRON, + DOMAIN_VI_AIR, + BOUNDARY_GAMMA, + BOUNDARY_OUT +) +from svg_to_getdp.infrastructure.outline_grouper import OutlineGrouper + + +# ============================================================================ +# Fixtures and Helper Functions +# ============================================================================ + +@pytest.fixture +def sample_points(): + """Create sample points for testing.""" + return [ + Point(0.0, 0.0), + Point(1.0, 0.0), + Point(2.0, 0.0), + Point(3.0, 1.0), + Point(0.0, 2.0), + Point(1.0, 2.0), + Point(2.0, 2.0), + ] + + +@pytest.fixture +def create_square_outline(): + """Create a simple square outline.""" + def _create(center_x=0.0, center_y=0.0, size=1.0, color=Color.BLACK, closed=True): + half = size / 2.0 + # Create 4 line segments for a square + segments = [] + corners = [] + + # Define the 4 corners + corners.append(Point(center_x - half, center_y - half)) # bottom-left + corners.append(Point(center_x + half, center_y - half)) # bottom-right + corners.append(Point(center_x + half, center_y + half)) # top-right + corners.append(Point(center_x - half, center_y + half)) # top-left + + # Create segments connecting the corners + for i in range(4): + start = corners[i] + end = corners[(i + 1) % 4] + # Linear Bézier (degree 1) - just a line + segment = BezierSegment([start, end], degree=1) + segments.append(segment) + + return Outline( + bezier_segments=segments, + corners=corners, + color=color, + is_closed=closed + ) + return _create + + +@pytest.fixture +def sample_outlines(create_square_outline): + """Create a set of sample outlines for testing.""" + # Outer green square (Vi air domain) + outer = create_square_outline(center_x=0.0, center_y=0.0, size=10.0, color=Color.GREEN) + + # Inner blue square (Vi iron domain) + inner1 = create_square_outline(center_x=0.0, center_y=0.0, size=6.0, color=Color.BLUE) + + # Even inner green square (Vi air domain) + inner2 = create_square_outline(center_x=0.0, center_y=0.0, size=3.0, color=Color.GREEN) + + # Black square inside the green one + inner3 = create_square_outline(center_x=0.0, center_y=0.0, size=1.0, color=Color.BLACK) + + return [outer, inner1, inner2, inner3] + + +# ============================================================================ +# Test Cases for OutlineGrouper +# ============================================================================ + +class TestOutlineGrouper: + """Test suite for OutlineGrouper class.""" + + def test_should_return_true_when_point_is_inside_closed_square_outline(self, create_square_outline): + """Test point inside/outside detection for square outline.""" + outline = create_square_outline(center_x=0.0, center_y=0.0, size=4.0) + + # Points inside + assert OutlineGrouper.is_point_inside_outline(Point(0.0, 0.0), outline) + assert OutlineGrouper.is_point_inside_outline(Point(0.5, 0.5), outline) + assert OutlineGrouper.is_point_inside_outline(Point(-0.5, -0.5), outline) + + # Points outside + assert not OutlineGrouper.is_point_inside_outline(Point(2.0, 2.0), outline) + assert not OutlineGrouper.is_point_inside_outline(Point(-2.0, -2.0), outline) + assert not OutlineGrouper.is_point_inside_outline(Point(0.0, 2.0), outline) # on edge + + def test_should_return_false_when_point_is_inside_open_outline(self, create_square_outline): + """Test that open outlines always return False.""" + open_outline = create_square_outline(center_x=0.0, center_y=0.0, size=4.0, closed=False) + + # Even points that would be inside a closed outline should return False + assert not OutlineGrouper.is_point_inside_outline(Point(0.0, 0.0), open_outline) + + def test_should_raise_value_error_when_getting_bounding_box_for_empty_outline(self): + """Test bounding box with empty outline.""" + mock_outline = MagicMock() + type(mock_outline).control_points = PropertyMock(return_value=[]) + + with pytest.raises(ValueError, match="must have at least one control point"): + OutlineGrouper.get_outline_bounding_box(mock_outline) + + def test_should_detect_when_one_outline_is_inside_another(self, create_square_outline): + """Test outline containment detection.""" + outer = create_square_outline(center_x=0.0, center_y=0.0, size=10.0) + inner = create_square_outline(center_x=0.0, center_y=0.0, size=5.0) + separate = create_square_outline(center_x=20.0, center_y=20.0, size=5.0) + + # Inner is inside outer + assert OutlineGrouper.is_outline_inside_other(inner, outer) + + # Outer is not inside inner + assert not OutlineGrouper.is_outline_inside_other(outer, inner) + + # Separate is not inside outer + assert not OutlineGrouper.is_outline_inside_other(separate, outer) + + def test_should_correctly_identify_containment_hierarchy_for_nested_squares(self, create_square_outline): + """Test containment hierarchy detection.""" + # Create nested squares + outlines = [ + create_square_outline(center_x=0.0, center_y=0.0, size=10.0, color=Color.BLACK), # 0 + create_square_outline(center_x=0.0, center_y=0.0, size=6.0, color=Color.BLUE), # 1 + create_square_outline(center_x=0.0, center_y=0.0, size=3.0, color=Color.GREEN), # 2 + create_square_outline(center_x=0.0, center_y=0.0, size=1.0, color=Color.BLACK), # 3 + create_square_outline(center_x=20.0, center_y=20.0, size=5.0, color=Color.BLACK), # 4 + ] + + hierarchy = OutlineGrouper.get_containment_hierarchy(outlines) + + # Expected hierarchy (only immediate children): + # Outline 0 contains 1 (outline 1 is inside outline 0) + # Outline 1 contains 2 (outline 2 is inside outline 1) + # Outline 2 contains 3 (outline 3 is inside outline 2) + + assert hierarchy[0] == [1] + assert hierarchy[1] == [2] + assert hierarchy[2] == [3] + assert hierarchy[3] == [] + + def test_should_classify_outline_colors_correctly(self, create_square_outline): + """Test outline color classification.""" + black_outline = create_square_outline(color=Color.BLACK) + blue_outline = create_square_outline(color=Color.BLUE) + green_outline = create_square_outline(color=Color.GREEN) + + assert OutlineGrouper.classify_outline_color(black_outline) == "va" + assert OutlineGrouper.classify_outline_color(blue_outline) == "vi_iron" + assert OutlineGrouper.classify_outline_color(green_outline) == "vi_air" + + def test_should_raise_value_error_when_classifying_outline_with_invalid_color(self): + """Test outline color classification with invalid color.""" + red_outline = Outline( + bezier_segments=[BezierSegment([Point(0,0), Point(1,0)], degree=1)], + corners=[Point(0,0), Point(1,0)], + color=Color.RED, + is_closed=True + ) + + with pytest.raises(ValueError, match="Unknown outline color"): + OutlineGrouper.classify_outline_color(red_outline) + + def test_should_assign_correct_physical_groups_based_on_outline_classification(self, create_square_outline): + """Test physical group assignment for outlines.""" + # Test Va outline + groups = OutlineGrouper.get_physical_groups_for_outline( + classification="va", + is_outermost=False, + is_va_in_vi=False + ) + assert len(groups) == 1 + assert groups[0] == DOMAIN_VA + + # Test Va outline inside Vi (should get BOUNDARY_GAMMA too) + groups = OutlineGrouper.get_physical_groups_for_outline( + classification="va", + is_outermost=False, + is_va_in_vi=True + ) + assert len(groups) == 2 + assert BOUNDARY_GAMMA in groups + assert DOMAIN_VA in groups + + # Test outermost outline (should get BOUNDARY_OUT) + groups = OutlineGrouper.get_physical_groups_for_outline( + classification="va", + is_outermost=True, + is_va_in_vi=False + ) + assert len(groups) == 2 + assert DOMAIN_VA in groups + assert BOUNDARY_OUT in groups + + def test_should_group_single_outline_as_outermost(self, create_square_outline): + """Test basic grouping of outlines.""" + # Simple case: one outer Va outline + outlines = [create_square_outline(color=Color.BLACK)] + + result = OutlineGrouper.group_outlines(outlines) + + assert len(result) == 1 + assert result[0]["holes"] == [] + assert len(result[0]["physical_groups"]) == 2 # DOMAIN_VA + BOUNDARY_OUT + assert DOMAIN_VA in result[0]["physical_groups"] + assert BOUNDARY_OUT in result[0]["physical_groups"] + + def test_should_correctly_group_nested_outlines_with_varying_colors(self, sample_outlines): + """Test grouping of nested outlines.""" + result = OutlineGrouper.group_outlines(sample_outlines) + + assert len(result) == 4 + + # Check outline 0 (outermost green - Vi air) + assert result[0]["holes"] == [1] # Contains only the immediate child (inner1 - blue) + assert DOMAIN_VI_AIR in result[0]["physical_groups"] + assert BOUNDARY_OUT in result[0]["physical_groups"] + + # Check outline 1 (blue inner1 - Vi iron) + assert result[1]["holes"] == [2] # Contains only the immediate child (inner2 - green) + assert DOMAIN_VI_IRON in result[1]["physical_groups"] + + # Check outline 2 (green inner2 - Vi air) + assert result[2]["holes"] == [3] # Contains only the immediate child (inner3 - black) + assert DOMAIN_VI_AIR in result[2]["physical_groups"] + + # Check outline 3 (innermost black - Va) + assert result[3]["holes"] == [] # Contains nothing + assert DOMAIN_VA in result[3]["physical_groups"] + assert BOUNDARY_GAMMA in result[3]["physical_groups"] # Inside Vi + + def test_should_return_empty_list_when_grouping_empty_outlines(self): + """Test grouping with empty input.""" + result = OutlineGrouper.group_outlines([]) + assert result == [] + + @patch('svg_to_getdp.infrastructure.outline_grouper.OutlineGrouper.classify_outline_color') + @patch('svg_to_getdp.infrastructure.outline_grouper.OutlineGrouper.is_outline_inside_other') + def test_should_detect_va_outlines_inside_vi_outlines_and_assign_boundary_gamma( + self, mock_is_inside, mock_classify, create_square_outline + ): + """Test detection of Va outlines inside Vi outlines.""" + # Setup mock to simulate Va inside Vi + def side_effect(outline, other): + # Simple mock: return True if outline is black and other is blue or green + if outline.color == Color.BLACK and other.color in [Color.BLUE, Color.GREEN]: + return True + return False + + mock_is_inside.side_effect = side_effect + + # Mock color classification + def classify_side_effect(outline): + if outline.color == Color.BLACK: + return "va" + elif outline.color == Color.BLUE: + return "vi_iron" + return "va" # default + + mock_classify.side_effect = classify_side_effect + + # Create outlines + vi_outline = create_square_outline(color=Color.BLUE) + va_outline = create_square_outline(color=Color.BLACK) + outlines = [vi_outline, va_outline] + + # Mock the containment hierarchy to show Va is inside Vi + with patch('svg_to_getdp.infrastructure.outline_grouper.OutlineGrouper.get_containment_hierarchy') as mock_hierarchy: + mock_hierarchy.return_value = {0: [1], 1: []} # Vi contains Va + + result = OutlineGrouper.group_outlines(outlines) + + # Check that Va outline got BOUNDARY_GAMMA + assert BOUNDARY_GAMMA in result[1]["physical_groups"] + + def test_should_raise_error_when_no_outermost_candidate_can_be_determined(self, create_square_outline): + """Test error when no outermost candidate is found.""" + # Create a circular dependency scenario + outline1 = create_square_outline(color=Color.BLACK) + outline2 = create_square_outline(color=Color.BLUE) + + # Mock containment hierarchy to create circular reference + # Use the correct module path based on import + with patch('svg_to_getdp.infrastructure.outline_grouper.OutlineGrouper.get_containment_hierarchy') as mock_hierarchy: + mock_hierarchy.return_value = {0: [1], 1: [0]} # Each contains the other + + with pytest.raises(ValueError, match="No outermost candidates found"): + OutlineGrouper.group_outlines([outline1, outline2]) + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestOutlineGrouperIntegration: + """Integration tests for OutlineGrouper with real outline data.""" + + def test_should_process_triangle_outline_and_correctly_determine_containment(self): + """Test complete workflow with actual Bézier segments.""" + # Create a simple triangle using linear Bézier segments + p1 = Point(0, 0) + p2 = Point(4, 0) + p3 = Point(2, 3) + + segment1 = BezierSegment([p1, p2], degree=1) + segment2 = BezierSegment([p2, p3], degree=1) + segment3 = BezierSegment([p3, p1], degree=1) + + triangle = Outline( + bezier_segments=[segment1, segment2, segment3], + corners=[p1, p2, p3], + color=Color.BLACK, + is_closed=True + ) + + # Test point inside triangle + point_inside = Point(2, 1) + point_outside = Point(2, -1) + + assert OutlineGrouper.is_point_inside_outline(point_inside, triangle) + assert not OutlineGrouper.is_point_inside_outline(point_outside, triangle) + + # Test bounding box + min_x, max_x, min_y, max_y = OutlineGrouper.get_outline_bounding_box(triangle) + assert math.isclose(min_x, 0.0) + assert math.isclose(max_x, 4.0) + assert math.isclose(min_y, 0.0) + assert math.isclose(max_y, 3.0) + + # Test grouping (just this one outline) + result = OutlineGrouper.group_outlines([triangle]) + assert len(result) == 1 + assert result[0]["holes"] == [] + assert len(result[0]["physical_groups"]) == 2 # DOMAIN_VA + BOUNDARY_OUT diff --git a/sketchgetdp/svg_to_getdp/tests/infrastructure/test_outline_preprocessor.py b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_outline_preprocessor.py new file mode 100644 index 0000000..9404d54 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_outline_preprocessor.py @@ -0,0 +1,386 @@ +""" +Unit tests for OutlinePreprocessor class. + +Tests the functionality of converting outlines to Gmsh geometry, +handling holes, physical groups, and topological relationships. +""" +import pytest +from unittest.mock import Mock, patch + +from svg_to_getdp.core.entities.outline import Outline +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.bezier_segment import BezierSegment +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.physical_group import ( + DOMAIN_VI_IRON, + DOMAIN_VI_AIR, + DOMAIN_VA, + BOUNDARY_GAMMA, + BOUNDARY_OUT +) +from svg_to_getdp.infrastructure.outline_preprocessor import OutlinePreprocessor + + +class TestOutlinePreprocessor: + """Test suite for OutlinePreprocessor class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def mock_gmsh_factory(self): + """Create a mock Gmsh factory with basic geometry operations.""" + factory = Mock() + + # Mock geometry creation methods with distinct return values + # Track counters as instance variables + self._point_counter = 0 + self._line_counter = 0 + self._bezier_counter = 0 + self._curve_loop_counter = 0 + self._surface_counter = 0 + + def mock_add_point(x, y, z): + self._point_counter += 1 + return 100 + self._point_counter + + def mock_add_line(start, end): + self._line_counter += 1 + return 200 + self._line_counter + + def mock_add_bezier(points): + self._bezier_counter += 1 + return 300 + self._bezier_counter + + def mock_add_curve_loop(curves): + self._curve_loop_counter += 1 + return 400 + self._curve_loop_counter + + def mock_add_plane_surface(curve_loops): + self._surface_counter += 1 + return 500 + self._surface_counter + + factory.addPoint = Mock(side_effect=mock_add_point) + factory.addLine = Mock(side_effect=mock_add_line) + factory.addBezier = Mock(side_effect=mock_add_bezier) + factory.addCurveLoop = Mock(side_effect=mock_add_curve_loop) + factory.addPlaneSurface = Mock(side_effect=mock_add_plane_surface) + factory.addPhysicalGroup = Mock() + + return factory + + @pytest.fixture + def basic_points(self): + """Create basic test points for constructing outlines.""" + return [ + Point(0.0, 0.0), # Bottom-left + Point(1.0, 0.0), # Bottom-right + Point(1.0, 1.0), # Top-right + Point(0.0, 1.0), # Top-left + Point(0.5, 0.5), # Center + Point(0.0, 0.5), # Left-center + Point(0.5, 0.0) # Bottom-center + ] + + @pytest.fixture + def square_outline(self, basic_points): + """Create a square outline with straight edges.""" + segments = [ + BezierSegment([basic_points[0], basic_points[1]], degree=1), # Bottom edge + BezierSegment([basic_points[1], basic_points[2]], degree=1), # Right edge + BezierSegment([basic_points[2], basic_points[3]], degree=1), # Top edge + BezierSegment([basic_points[3], basic_points[0]], degree=1), # Left edge + ] + corners = [basic_points[0], basic_points[1], basic_points[2], basic_points[3]] + return Outline(segments, corners, Color.BLUE) + + @pytest.fixture + def outline_with_bezier_curves(self, basic_points): + """Create an outline with both straight edges and Bézier curves.""" + segments = [ + # Curved bottom edge (quadratic Bézier) + BezierSegment([basic_points[0], basic_points[6], basic_points[1]], degree=2), + # Straight right edge + BezierSegment([basic_points[1], basic_points[2]], degree=1), + # Straight top edge + BezierSegment([basic_points[2], basic_points[3]], degree=1), + # Curved left edge (quadratic Bézier) + BezierSegment([basic_points[3], basic_points[5], basic_points[0]], degree=2), + ] + corners = [basic_points[0], basic_points[1], basic_points[2], basic_points[3]] + return Outline(segments, corners, Color.BLACK) + + # ==================== Initialization Tests ==================== + + def test_initializes_with_empty_state(self): + """OutlinePreprocessor should initialize with all internal collections empty.""" + mesher = OutlinePreprocessor() + + assert mesher._point_tags == {} + assert mesher._curve_loops == {} + assert mesher._surface_tags == {} + assert mesher._created_points == {} + assert mesher._curve_tags_per_outline == {} + assert mesher._processing_order == [] + assert mesher._physical_groups_by_type['boundary'] == {} + assert mesher._physical_groups_by_type['domain'] == {} + + # ==================== Basic Functionality Tests ==================== + + def test_raises_error_when_outline_and_property_counts_mismatch( + self, mock_gmsh_factory, square_outline + ): + """Should raise ValueError when outlines and properties counts don't match.""" + mesher = OutlinePreprocessor() + + outlines = [square_outline] + properties = [ + {"physical_groups": [DOMAIN_VA]}, + {"physical_groups": [BOUNDARY_OUT]} # Extra property dict + ] + + with pytest.raises(ValueError, match="must match"): + mesher.preprocess_outlines(mock_gmsh_factory, outlines, properties) + + def test_meshes_square_outline_with_straight_edges( + self, mock_gmsh_factory, square_outline + ): + """Should create geometry for a square outline with only straight edges.""" + mesher = OutlinePreprocessor() + + mesher.preprocess_outlines(mock_gmsh_factory, [square_outline], [{"physical_groups": [DOMAIN_VI_IRON]}]) + + # Verify geometry creation calls + assert mock_gmsh_factory.addPoint.call_count == 4 # Four corner points + assert mock_gmsh_factory.addLine.call_count == 4 # Four straight edges + assert mock_gmsh_factory.addBezier.call_count == 0 # No Bézier curves + + # Verify surface and physical group creation + assert mock_gmsh_factory.addCurveLoop.call_count == 1 + assert mock_gmsh_factory.addPlaneSurface.call_count == 1 + assert mock_gmsh_factory.addPhysicalGroup.call_count == 1 + + # Get the actual surface tag that was created (should be 501) + # Since addPlaneSurface returns 500 + counter, and counter starts at 1 + surface_tag = 501 + + # Verify the physical group was created with the correct surface tag + mock_gmsh_factory.addPhysicalGroup.assert_called_with( + 2, [surface_tag], DOMAIN_VI_IRON.value + ) + + def test_preprocesses_outline_with_bezier_curves( + self, mock_gmsh_factory, outline_with_bezier_curves + ): + """Should create geometry for outline containing both straight and Bézier edges.""" + mesher = OutlinePreprocessor() + mesher.preprocess_outlines( + mock_gmsh_factory, + [outline_with_bezier_curves], + [{"physical_groups": [DOMAIN_VA]}] + ) + + # Verify geometry creation calls + assert mock_gmsh_factory.addPoint.call_count == 6 # All unique control points + assert mock_gmsh_factory.addLine.call_count == 2 # Two straight segments + assert mock_gmsh_factory.addBezier.call_count == 2 # Two Bézier segments + + # Verify surface and physical group creation + assert mock_gmsh_factory.addCurveLoop.call_count == 1 + assert mock_gmsh_factory.addPlaneSurface.call_count == 1 + + # Surface tag should be 501 (first call to addPlaneSurface) + surface_tag = 501 + + mock_gmsh_factory.addPhysicalGroup.assert_called_with( + 2, [surface_tag], DOMAIN_VA.value + ) + + # ==================== Hole Handling Tests ==================== + + def test_preprocesses_outer_outline_with_inner_hole( + self, mock_gmsh_factory, square_outline + ): + """Should create outer surface containing an inner hole.""" + # Create inner square outline (hole) + inner_square_points = [ + Point(0.25, 0.25), + Point(0.75, 0.25), + Point(0.75, 0.75), + Point(0.25, 0.75) + ] + inner_segments = [ + BezierSegment([inner_square_points[0], inner_square_points[1]], degree=1), + BezierSegment([inner_square_points[1], inner_square_points[2]], degree=1), + BezierSegment([inner_square_points[2], inner_square_points[3]], degree=1), + BezierSegment([inner_square_points[3], inner_square_points[0]], degree=1), + ] + inner_outline = Outline(inner_segments, inner_square_points, Color.GREEN) + + outlines = [square_outline, inner_outline] + properties = [ + {"holes": [1], "physical_groups": [DOMAIN_VI_IRON]}, # Outer contains hole + {"holes": [], "physical_groups": [DOMAIN_VI_AIR]} # Inner is hole + ] + + mesher = OutlinePreprocessor() + mesher.preprocess_outlines(mock_gmsh_factory, outlines, properties) + + # Verify holes are processed first (topological ordering) + assert mesher.get_processing_order() == [1, 0] # Inner first, outer second + + # Verify both surfaces were created + assert mock_gmsh_factory.addPlaneSurface.call_count == 2 + + # Outer surface should be created with hole references + surface_calls = mock_gmsh_factory.addPlaneSurface.call_args_list + # Find the call that has 2 curve loops (main loop + hole) + for call_obj in surface_calls: + if len(call_obj[0][0]) == 2: # Outer has 2 curve loops + assert len(call_obj[0][0]) == 2 # Main loop + hole loop + break + + def test_preprocesses_outline_with_multiple_holes( + self, mock_gmsh_factory, square_outline + ): + """Should create surface containing multiple holes.""" + # Create two hole outlines + hole_one_points = [Point(0.2, 0.2), Point(0.4, 0.2), Point(0.4, 0.4), Point(0.2, 0.4)] + hole_two_points = [Point(0.6, 0.6), Point(0.8, 0.6), Point(0.8, 0.8), Point(0.6, 0.8)] + + def create_square_segments(points): + return [BezierSegment([points[i], points[(i+1)%4]], degree=1) for i in range(4)] + + hole_one = Outline(create_square_segments(hole_one_points), hole_one_points, Color.GREEN) + hole_two = Outline(create_square_segments(hole_two_points), hole_two_points, Color.BLUE) + + outlines = [square_outline, hole_one, hole_two] + properties = [ + {"holes": [1, 2], "physical_groups": [DOMAIN_VI_IRON]}, # Outer with two holes + {"holes": [], "physical_groups": [DOMAIN_VI_AIR]}, # First hole + {"holes": [], "physical_groups": [DOMAIN_VI_AIR]} # Second hole + ] + + mesher = OutlinePreprocessor() + mesher.preprocess_outlines(mock_gmsh_factory, outlines, properties) + + # Verify topological order: holes first, then outer + processing_order = mesher.get_processing_order() + assert set(processing_order[:2]) == {1, 2} # Holes processed first + assert processing_order[2] == 0 # Outer processed last + + # Verify all surfaces were created + assert mock_gmsh_factory.addPlaneSurface.call_count == 3 + + # ==================== Physical Group Tests ==================== + + def test_assigns_boundary_physical_groups_to_outlines( + self, mock_gmsh_factory, square_outline + ): + """Should assign boundary physical groups to 1D outline entities.""" + preprocessor = OutlinePreprocessor() + + preprocessor = OutlinePreprocessor() + preprocessor.preprocess_outlines(mock_gmsh_factory, [square_outline], [{"physical_groups": [BOUNDARY_OUT]}]) + + # The line tags should be 201, 202, 203, 204 (incrementing from 200) + expected_curve_tags = [201, 202, 203, 204] + + # Check that addPhysicalGroup was called with expected curve tags + mock_gmsh_factory.addPhysicalGroup.assert_called_with( + 1, expected_curve_tags, BOUNDARY_OUT.value + ) + + def test_assigns_multiple_physical_groups_to_single_outline( + self, mock_gmsh_factory, square_outline + ): + """Should assign both domain and boundary physical groups when specified.""" + preprocessor = OutlinePreprocessor() + + preprocessor.preprocess_outlines( + mock_gmsh_factory, + [square_outline], + [{"physical_groups": [DOMAIN_VA, BOUNDARY_GAMMA]}] + ) + + # Should have two physical group assignments + assert mock_gmsh_factory.addPhysicalGroup.call_count == 2 + + calls = mock_gmsh_factory.addPhysicalGroup.call_args_list + domain_call = next(c for c in calls if c[0][0] == 2) # Dimension 2 + boundary_call = next(c for c in calls if c[0][0] == 1) # Dimension 1 + + # Verify domain assignment + assert domain_call[0][2] == DOMAIN_VA.value + assert domain_call[0][1] == [501] # Surface tag (first call returns 501) + + # Verify boundary assignment + assert boundary_call[0][2] == BOUNDARY_GAMMA.value + # Line tags should be 201, 202, 203, 204 + assert boundary_call[0][1] == [201, 202, 203, 204] + + # ==================== Edge Case Tests ==================== + + def test_returns_processing_order_copy_not_reference( + self, mock_gmsh_factory, square_outline + ): + """Should return a copy of processing order to prevent external modification.""" + preprocessor = OutlinePreprocessor() + + preprocessor.preprocess_outlines(mock_gmsh_factory, [square_outline], [{"physical_groups": [DOMAIN_VA]}]) + + order = preprocessor.get_processing_order() + assert order == [0] + + # Modifying returned list shouldn't affect internal state + order.append(999) + assert preprocessor.get_processing_order() == [0] + + def test_raises_error_for_non_existent_hole_reference( + self, mock_gmsh_factory, square_outline + ): + """Should raise error when hole index references non-existent outline.""" + preprocessor = OutlinePreprocessor() + + outlines = [square_outline] + properties = [{"holes": [999], "physical_groups": [DOMAIN_VI_IRON]}] # Invalid hole index + + with pytest.raises(ValueError, match="has not been created yet"): + preprocessor.preprocess_outlines(mock_gmsh_factory, outlines, properties) + def test_raises_error_for_non_physical_group_in_list( + self, mock_gmsh_factory, square_outline + ): + """Should raise TypeError when physical_groups contains non-PhysicalGroup objects.""" + preprocessor = OutlinePreprocessor() + + outlines = [square_outline] + properties = [{"physical_groups": ["invalid_type"]}] + + with pytest.raises(TypeError, match="must be PhysicalGroup instance"): + preprocessor.preprocess_outlines(mock_gmsh_factory, outlines, properties) + # ==================== Internal Method Tests ==================== + + def test_falls_back_to_input_order_when_topological_sort_fails( + self, square_outline + ): + """Should use input order when cyclic dependencies prevent topological sort.""" + preprocessor = OutlinePreprocessor() + + # Create outlines with circular dependency + outlines = [square_outline, square_outline, square_outline] + properties = [ + {"holes": [1], "physical_groups": [DOMAIN_VA]}, # Depends on outline 1 + {"holes": [0], "physical_groups": [DOMAIN_VI_IRON]}, # Depends on outline 0 (cycle) + {"physical_groups": [DOMAIN_VI_AIR]} + ] + + with patch('builtins.print') as mock_print: + order = preprocessor._get_processing_order(outlines, properties) + + # Verify warning was logged + mock_print.assert_called_with( + "Warning: Could not determine topological order. Using input order." + ) + + # Should use original order as fallback + assert order == [0, 1, 2] + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/infrastructure/test_svg_parser.py b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_svg_parser.py new file mode 100644 index 0000000..5c0aa24 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_svg_parser.py @@ -0,0 +1,564 @@ +""" +Test suite for the SVG Parser infrastructure component. +""" +import pytest +import tempfile +import os + +from svg_to_getdp.infrastructure.svg_processing.svg_parser import SvgParser +from svg_to_getdp.core.entities.raw_outline import RawOutline +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color + + +class TestSVGParser: + """Test suite for the SvgParser class""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def parser(self): + """Set up a fresh parser instance for each test""" + return SvgParser() + + @pytest.fixture + def temp_svg_file(self): + """Create a temporary SVG file for testing""" + def _create_temp_file(content): + with tempfile.NamedTemporaryFile(mode='w', suffix='.svg', delete=False) as f: + f.write(content) + return f.name + return _create_temp_file + + @pytest.fixture + def cleanup_temp_file(self): + """Clean up temporary file""" + def _cleanup(filepath): + if os.path.exists(filepath): + os.unlink(filepath) + return _cleanup + + # ==================== Basic Tests ==================== + + def test_parser_initialization(self, parser): + """Test that parser initializes with correct namespace""" + assert parser.namespace == '{http://www.w3.org/2000/svg}' + + def test_parse_nonexistent_file(self, parser): + """Test that parser raises error for nonexistent file""" + with pytest.raises(ValueError, match="Invalid SVG file"): + parser.extract_raw_outlines_by_color("nonexistent.svg") + + def test_parse_invalid_xml(self, parser, temp_svg_file, cleanup_temp_file): + """Test that parser raises error for invalid XML""" + temp_path = temp_svg_file("invalid xml content") + + try: + with pytest.raises(ValueError, match="Invalid SVG file"): + parser.extract_raw_outlines_by_color(temp_path) + finally: + cleanup_temp_file(temp_path) + + # ==================== SVG Parsing Tests ==================== + + def test_parse_minimal_svg(self, parser, temp_svg_file, cleanup_temp_file): + """Test parsing of minimal valid SVG""" + svg_content = ''' + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + assert result == {} # No elements, empty result + finally: + cleanup_temp_file(temp_path) + + def test_parse_svg_with_single_red_dot(self, parser, temp_svg_file, cleanup_temp_file): + """Test parsing SVG with a single red dot (circle)""" + svg_content = ''' + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check it has one color key + keys = list(result.keys()) + assert len(keys) == 1 + + red_color_key = keys[0] + red_raw_outlines = result[red_color_key] + + # Check the color key is red + assert red_color_key.name == "red" + assert red_color_key.rgb == (255, 0, 0) + + # Check there is one raw_outline consisting of one point + assert len(red_raw_outlines) == 1 + raw_outline = red_raw_outlines[0] + assert isinstance(raw_outline, RawOutline) + assert len(raw_outline.points) == 1 + + # Check the point is in valid range (scaled to unit coordinates) + point = raw_outline.points[0] + assert 0 <= point.x <= 1, f"x={point.x} not in [0,1]" + assert 0 <= point.y <= 1, f"y={point.y} not in [0,1]" + + finally: + cleanup_temp_file(temp_path) + + def test_parse_svg_with_multiple_colors(self, parser, temp_svg_file, cleanup_temp_file): + """Test parsing SVG with one shape per color - red as single-point raw_outline from ellipse""" + svg_content = ''' + + + + + + + + + + + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check we have exactly 4 color keys (red, green, blue, black) + color_keys = list(result.keys()) + assert len(color_keys) == 4, f"Expected 4 colors, got {len(color_keys)}: {[c.name for c in color_keys]}" + + # Test RED structure (ellipse → single point) + red_color_key = None + for key in color_keys: + if key.name == "red": + red_color_key = key + break + + assert red_color_key is not None, "Red color not found in results" + assert red_color_key.name == "red" + assert red_color_key.rgb == (255, 0, 0) + + red_raw_outlines = result[red_color_key] + assert len(red_raw_outlines) == 1, f"Expected 1 red raw_outline, got {len(red_raw_outlines)}" + + red_raw_outline = red_raw_outlines[0] + assert isinstance(red_raw_outline, RawOutline) + assert red_raw_outline.color.name == "red" + + # Red structure should have exactly 1 point (center of ellipse) + assert len(red_raw_outline.points) == 1, f"Red ellipse should have 1 point, got {len(red_raw_outline.points)}" + + red_point = red_raw_outline.points[0] + assert 0 <= red_point.x <= 1, f"Red point x={red_point.x} not in [0,1]" + assert 0 <= red_point.y <= 1, f"Red point y={red_point.y} not in [0,1]" + + # Test GREEN structure (closed square path) + green_color_key = None + for key in color_keys: + if key.name == "green": + green_color_key = key + break + + assert green_color_key is not None, "Green color not found in results" + assert green_color_key.name == "green" + assert green_color_key.rgb == (0, 255, 0) + + green_raw_outlines = result[green_color_key] + assert len(green_raw_outlines) == 1, f"Expected 1 green raw_outline, got {len(green_raw_outlines)}" + + green_raw_outline = green_raw_outlines[0] + assert isinstance(green_raw_outline, RawOutline) + assert green_raw_outline.color.name == "green" + + # Green structure should have multiple points (at least 4 for a square) + assert len(green_raw_outline.points) >= 4, f"Green square should have >=4 points, got {len(green_raw_outline.points)}" + assert green_raw_outline.is_closed, "Green square should be closed" + + for green_point in green_raw_outline.points: + assert 0 <= green_point.x <= 1, f"Green point x={green_point.x} not in [0,1]" + assert 0 <= green_point.y <= 1, f"Green point y={green_point.y} not in [0,1]" + + # Test BLUE structure (open line path) + blue_color_key = None + for key in color_keys: + if key.name == "blue": + blue_color_key = key + break + + assert blue_color_key is not None, "Blue color not found in results" + assert blue_color_key.name == "blue" + assert blue_color_key.rgb == (0, 0, 255) + + blue_raw_outlines = result[blue_color_key] + assert len(blue_raw_outlines) == 1, f"Expected 1 blue raw_outline, got {len(blue_raw_outlines)}" + + blue_raw_outline = blue_raw_outlines[0] + assert isinstance(blue_raw_outline, RawOutline) + assert blue_raw_outline.color.name == "blue" + + # Blue structure should have multiple points (at least 2 for a line) + assert len(blue_raw_outline.points) >= 2, f"Blue line should have >=2 points, got {len(blue_raw_outline.points)}" + assert not blue_raw_outline.is_closed, "Blue line should be open" + + for blue_point in blue_raw_outline.points: + assert 0 <= blue_point.x <= 1, f"Blue point x={blue_point.x} not in [0,1]" + assert 0 <= blue_point.y <= 1, f"Blue point y={blue_point.y} not in [0,1]" + + # Test BLACK structure (closed triangle path) + black_color_key = None + for key in color_keys: + if key.name == "black": + black_color_key = key + break + + assert black_color_key is not None, "Black color not found in results" + assert black_color_key.name == "black" + assert black_color_key.rgb == (0, 0, 0) + + black_raw_outlines = result[black_color_key] + assert len(black_raw_outlines) == 1, f"Expected 1 black raw_outline, got {len(black_raw_outlines)}" + + black_raw_outline = black_raw_outlines[0] + assert isinstance(black_raw_outline, RawOutline) + assert black_raw_outline.color.name == "black" + + # Black structure should have multiple points (at least 3 for a triangle) + assert len(black_raw_outline.points) >= 3, f"Black triangle should have >=3 points, got {len(black_raw_outline.points)}" + assert black_raw_outline.is_closed, "Black triangle should be closed" + + for black_point in black_raw_outline.points: + assert 0 <= black_point.x <= 1, f"Black point x={black_point.x} not in [0,1]" + assert 0 <= black_point.y <= 1, f"Black point y={black_point.y} not in [0,1]" + + # Verify no duplicate points in multi-point raw_outlines + for color, raw_outlines in result.items(): + if color.name != "red": # Skip red (single point) + for raw_outline in raw_outlines: + if len(raw_outline.points) > 1: + # Check for consecutive duplicates + for i in range(len(raw_outline.points) - 1): + assert raw_outline.points[i] != raw_outline.points[i + 1], \ + f"Consecutive duplicate points found in {color.name} raw_outline at index {i}" + + finally: + cleanup_temp_file(temp_path) + + # ==================== ViewBox and Scaling Tests ==================== + + def test_parse_viewbox_scaling(self, parser, temp_svg_file, cleanup_temp_file): + """Test that coordinates are properly scaled to unit square""" + svg_content = ''' + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check any raw_outlines we get + for color, raw_outlines in result.items(): + for raw_outline in raw_outlines: + # Check that points are scaled to [0,1] range + for point in raw_outline.points: + assert 0 <= point.x <= 1 + assert 0 <= point.y <= 1 + + finally: + cleanup_temp_file(temp_path) + + def test_parse_no_viewbox(self, parser, temp_svg_file, cleanup_temp_file): + """Test parsing SVG without viewBox attribute""" + svg_content = ''' + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check any raw_outlines we get + for color, raw_outlines in result.items(): + for raw_outline in raw_outlines: + # Should still work with default scaling + for point in raw_outline.points: + assert 0 <= point.x <= 1 + assert 0 <= point.y <= 1 + + finally: + cleanup_temp_file(temp_path) + + def test_parse_invalid_viewbox(self, parser, temp_svg_file, cleanup_temp_file): + """Test parsing SVG with invalid viewBox""" + svg_content = ''' + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check any raw_outlines we get + for color, raw_outlines in result.items(): + for raw_outline in raw_outlines: + # Should use default scaling + for point in raw_outline.points: + assert 0 <= point.x <= 1 + assert 0 <= point.y <= 1 + + finally: + cleanup_temp_file(temp_path) + + # ==================== Color Extraction Tests ==================== + + def test_color_extraction_hex(self, parser, temp_svg_file, cleanup_temp_file): + """Test color extraction from hex values""" + svg_content = ''' + + + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check that colors are extracted + for color in result.keys(): + assert color.name in ["red", "green", "blue"] + + finally: + cleanup_temp_file(temp_path) + + def test_color_extraction_rgb(self, parser, temp_svg_file, cleanup_temp_file): + """Test color extraction from rgb values""" + svg_content = ''' + + + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Check for expected colors + for color in result.keys(): + assert color.name in ["red", "green", "blue"] + + finally: + cleanup_temp_file(temp_path) + + # ==================== Parameterized Color Mapping Tests ==================== + + @pytest.mark.parametrize("hex_color,expected_primary_name", [ + ("#ff8080", "red"), # Light red -> red + ("#80ff80", "green"), # Light green -> green + ("#8080ff", "blue"), # Light blue -> blue + ("#ff4000", "red"), # Orange-red -> red + ("#ffff00", "red"), # Yellow -> red (closest to red+green) + ]) + def test_hex_color_mapping(self, parser, hex_color, expected_primary_name): + """Test mapping of various hex colors to primary colors""" + try: + # Try to access the color classifier if it's exposed + if hasattr(parser, 'color_classifier'): + result = parser.color_classifier.parse_color_string(hex_color) + assert result.name == expected_primary_name + else: + # Fallback: test through the parser's color extraction + import tempfile + import os + + svg_content = f''' + + + ''' + + with tempfile.NamedTemporaryFile(mode='w', suffix='.svg', delete=False) as f: + f.write(svg_content) + temp_path = f.name + + try: + result_dict = parser.extract_raw_outlines_by_color(temp_path) + colors = list(result_dict.keys()) + if colors: + result = colors[0] + assert result.name == expected_primary_name + else: + pytest.skip("No color extracted from test SVG") + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + except AttributeError: + pytest.skip("Color classification method not accessible in current architecture") + + # ==================== Error Handling Tests ==================== + + def test_error_handling_malformed_elements(self, parser, temp_svg_file, cleanup_temp_file): + """Test error handling for malformed SVG elements""" + svg_content = ''' + + + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + # This should raise an error due to malformed elements + with pytest.raises(ValueError, match="Invalid SVG file"): + parser.extract_raw_outlines_by_color(temp_path) + + finally: + cleanup_temp_file(temp_path) + + # ==================== RawOutline Tests ==================== + + def test_raw_outline_validation(self): + """Test that RawOutline validates point count""" + # Test works with 3+ points for any color + points_3 = [Point(0, 0), Point(1, 0), Point(1, 1)] + + # Green, blue and black should work with 3+ points + for color in [Color.GREEN, Color.BLUE, Color.BLACK]: + raw_outline = RawOutline(points=points_3, color=color) + assert raw_outline.points == points_3 + + # Test with more than 3 points + points_4 = [Point(0, 0), Point(1, 0), Point(1, 1), Point(0, 1)] + raw_outline_4 = RawOutline(points=points_4, color=Color.BLACK) + assert raw_outline_4.points == points_4 + + # Should fail with less than 3 points for black, green and blue. + points_2 = [Point(0, 0), Point(1, 1)] + for color in [Color.BLACK, Color.GREEN, Color.BLUE]: + with pytest.raises(ValueError, match="Raw outline must have at least 3 points"): + RawOutline(points=points_2, color=color) + + # Should fail with 0 points + with pytest.raises(ValueError): + RawOutline(points=[], color=Color.GREEN) + + # Red should work with 1 point + red_raw_outline_1 = RawOutline(points=[Point(0, 0)], color=Color.RED) + assert len(red_raw_outline_1.points) == 1 + + def test_raw_outline_structure(self, parser, temp_svg_file, cleanup_temp_file): + """Simple test that validates RawOutline objects for all four colors""" + svg_content = ''' + + + + + + + + + + + + + ''' + + temp_path = temp_svg_file(svg_content) + + try: + result = parser.extract_raw_outlines_by_color(temp_path) + + # Verify we have a dictionary + assert isinstance(result, dict) + + # Get the keys as a list + keys = list(result.keys()) + + # Check we have some colors + assert len(keys) > 0 + + # Find raw_outlines for each color by checking each key + red_raw_outlines = None + green_raw_outlines = None + blue_raw_outlines = None + black_raw_outlines = None + + for key in keys: + if hasattr(key, 'name'): + if key.name == 'red': + red_raw_outlines = result[key] + elif key.name == 'green': + green_raw_outlines = result[key] + elif key.name == 'blue': + blue_raw_outlines = result[key] + elif key.name == 'black': + black_raw_outlines = result[key] + + # Debug output + print(f"\nFound raw_outlines:") + if red_raw_outlines: + print(f" Red: {len(red_raw_outlines)} raw_outline(s)") + if green_raw_outlines: + print(f" Green: {len(green_raw_outlines)} raw_outline(s)") + if blue_raw_outlines: + print(f" Blue: {len(blue_raw_outlines)} raw_outline(s)") + if black_raw_outlines: + print(f" Black: {len(black_raw_outlines)} raw_outline(s)") + + # Validate red raw_outline (from circle) + assert red_raw_outlines is not None, "No red raw_outline found" + assert isinstance(red_raw_outlines, list) + assert len(red_raw_outlines) >= 1 + + red_raw_outline = red_raw_outlines[0] + assert isinstance(red_raw_outline, RawOutline) + assert isinstance(red_raw_outline.points, list) + + # Validate green raw_outline (from triangle path) + assert green_raw_outlines is not None, "No green raw_outline found" + assert isinstance(green_raw_outlines, list) + assert len(green_raw_outlines) >= 1 + + green_raw_outline = green_raw_outlines[0] + assert isinstance(green_raw_outline, RawOutline) + assert isinstance(green_raw_outline.points, list) + + # Validate blue raw_outline (from rectangle path) + assert blue_raw_outlines is not None, "No blue raw_outline found" + assert isinstance(blue_raw_outlines, list) + assert len(blue_raw_outlines) >= 1 + + blue_raw_outline = blue_raw_outlines[0] + assert isinstance(blue_raw_outline, RawOutline) + assert isinstance(blue_raw_outline.points, list) + + # Validate black raw_outline (from polygon) + assert black_raw_outlines is not None, "No black raw_outline found" + assert isinstance(black_raw_outlines, list) + assert len(black_raw_outlines) >= 1 + + black_raw_outline = black_raw_outlines[0] + assert isinstance(black_raw_outline, RawOutline) + assert isinstance(black_raw_outline.points, list) + finally: + cleanup_temp_file(temp_path) + \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/infrastructure/test_wire_preprocessor.py b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_wire_preprocessor.py new file mode 100644 index 0000000..53de7a4 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/infrastructure/test_wire_preprocessor.py @@ -0,0 +1,593 @@ +""" +Unit tests for WirePreprocessor class. + +Tests wire preprocessing functionality including clustering, +configuration loading, and physical group assignment. +""" +import pytest +import yaml +import tempfile +import os +import math +from unittest.mock import Mock, patch + +from svg_to_getdp.infrastructure.wire_preprocessor import WirePreprocessor, Wire, WireCluster +from svg_to_getdp.core.entities.point import Point +from svg_to_getdp.core.entities.color import Color +from svg_to_getdp.core.entities.physical_group import DOMAIN_COIL_POSITIVE, DOMAIN_COIL_NEGATIVE + + +class TestWirePreprocessor: + """Test suite for WirePreprocessor class.""" + + # ==================== Fixtures ==================== + + @pytest.fixture + def preprocessor(self): + """Create a wire preprocessor instance for testing.""" + return WirePreprocessor() + + @pytest.fixture + def mock_factory(self): + """Create a mock factory for testing.""" + return Mock() + + @pytest.fixture + def basic_wires(self): + """Create basic wire test data.""" + return [ + (Point(0.0, 0.0), Color.RED), + (Point(1.0, 0.0), Color.RED), + (Point(0.0, 1.0), Color.RED), + (Point(1.0, 1.0), Color.RED), + (Point(0.5, 0.5), Color.RED), + (Point(1.5, 0.5), Color.RED) + ] + + @pytest.fixture + def spatially_distributed_wires(self): + """Create spatially distributed wires for clustering tests.""" + return [ + Wire(Point(0.0, 0.0), Color.RED, 0), + Wire(Point(1.0, 0.0), Color.RED, 1), + Wire(Point(2.0, 0.0), Color.RED, 2), + Wire(Point(10.0, 0.0), Color.RED, 3), + Wire(Point(11.0, 0.0), Color.RED, 4), + ] + + @pytest.fixture + def sorted_wire_test_data(self): + """Create unsorted wires for sorting tests.""" + return [ + Wire(Point(2.0, 1.0), Color.RED, 0), + Wire(Point(1.0, 2.0), Color.RED, 1), + Wire(Point(2.0, 2.0), Color.RED, 2), + Wire(Point(0.0, 0.0), Color.RED, 3), + ] + + @pytest.fixture + def distance_calculation_wires(self): + """Create wires for distance calculation tests.""" + return ( + Wire(Point(0.0, 0.0), Color.RED, 0), + Wire(Point(3.0, 4.0), Color.RED, 1) + ) + + # ==================== Helper Methods ==================== + + def create_temporary_configuration_file(self, configuration_content: dict) -> str: + """Creates a temporary YAML configuration file for testing.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as configuration_file: + yaml.dump(configuration_content, configuration_file) + return configuration_file.name + + # ==================== Initialization Tests ==================== + + def test_initial_state_is_empty(self, preprocessor): + """Verifies WirePreprocessor starts with empty collections and no factory.""" + assert preprocessor.factory is None + assert preprocessor.wire_clusters == [] + assert preprocessor.all_wires == [] + + # ==================== Basic Functionality Tests ==================== + + def test_sorts_wires_from_top_to_bottom_left_to_right(self, preprocessor, sorted_wire_test_data): + """Ensures wires are sorted by descending y, then ascending x coordinates.""" + unsorted_wires = sorted_wire_test_data + + sorted_wires = preprocessor._sort_wires(unsorted_wires) + + assert sorted_wires[0].original_index == 1 # (1.0, 2.0) - highest y + assert sorted_wires[1].original_index == 2 # (2.0, 2.0) - same y, x > 1.0 + assert sorted_wires[2].original_index == 0 # (2.0, 1.0) - lower y + assert sorted_wires[3].original_index == 3 # (0.0, 0.0) - lowest y + + def test_calculates_euclidean_distance_between_wires(self, preprocessor, distance_calculation_wires): + """Validates distance calculation between two wire positions.""" + first_wire, second_wire = distance_calculation_wires + + calculated_distance = preprocessor._calculate_distance(first_wire, second_wire) + + assert math.isclose(calculated_distance, 5.0) + + def test_maps_cluster_current_sign_to_physical_group(self, preprocessor): + """Tests that cluster current signs correctly map to physical groups.""" + positive_current_cluster = WireCluster(name="positive_cluster", wire_count=1, current_sign=1) + negative_current_cluster = WireCluster(name="negative_cluster", wire_count=1, current_sign=-1) + + positive_physical_group = preprocessor._get_physical_group_for_cluster(positive_current_cluster) + assert positive_physical_group.value == DOMAIN_COIL_POSITIVE.value + assert positive_physical_group.name == DOMAIN_COIL_POSITIVE.name + + negative_physical_group = preprocessor._get_physical_group_for_cluster(negative_current_cluster) + assert negative_physical_group.value == DOMAIN_COIL_NEGATIVE.value + assert negative_physical_group.name == DOMAIN_COIL_NEGATIVE.name + + invalid_current_cluster = WireCluster(name="invalid_cluster", wire_count=1, current_sign=0) + with pytest.raises(ValueError, match="Invalid current sign"): + preprocessor._get_physical_group_for_cluster(invalid_current_cluster) + + # ==================== Configuration Loading Tests ==================== + + def test_loads_wire_clusters_from_valid_configuration(self, preprocessor): + """Validates loading wire clusters from properly formatted YAML configuration.""" + valid_configuration = { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': 3, + 'current_sign': 1 + }, + 'cluster_2': { + 'wire_count': 2, + 'current_sign': -1 + } + } + } + + configuration_file_path = self.create_temporary_configuration_file(valid_configuration) + + try: + loaded_clusters = preprocessor._load_wire_clusters(configuration_file_path) + + assert len(loaded_clusters) == 2 + + # Clusters are sorted alphabetically by name + first_cluster = loaded_clusters[0] + assert first_cluster.name == 'cluster_1' + assert first_cluster.wire_count == 3 + assert first_cluster.current_sign == 1 + assert first_cluster.wires == [] + + second_cluster = loaded_clusters[1] + assert second_cluster.name == 'cluster_2' + assert second_cluster.wire_count == 2 + assert second_cluster.current_sign == -1 + assert second_cluster.wires == [] + + finally: + os.unlink(configuration_file_path) + + @pytest.mark.parametrize("configuration_content, expected_error_message", [ + ( + {'other_section': {'foo': 'bar'}}, + "Config file must contain 'wire_clusters' section" + ), + ( + { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': -5, + 'current_sign': 1 + } + } + }, + "wire_count must be a positive integer" + ), + ( + { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': 0, + 'current_sign': 1 + } + } + }, + "wire_count must be a positive integer" + ), + ( + { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': 3, + 'current_sign': 0 + } + } + }, + "current_sign must be 1 or -1" + ), + ( + { + 'wire_clusters': { + 'cluster_1': 'not_a_dict' + } + }, + "Cluster 'cluster_1' configuration must be a dictionary" + ), + ( + { + 'wire_clusters': { + 'cluster_1': { + 'current_sign': 1 + } + } + }, + "Cluster 'cluster_1' must have 'wire_count'" + ), + ( + { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': 3 + } + } + }, + "Cluster 'cluster_1' must have 'current_sign'" + ), + ]) + def test_raises_error_for_invalid_configuration(self, preprocessor, configuration_content, expected_error_message): + """Verifies appropriate errors are raised for various invalid configuration scenarios.""" + configuration_file_path = self.create_temporary_configuration_file(configuration_content) + + try: + with pytest.raises(ValueError, match=expected_error_message): + preprocessor._load_wire_clusters(configuration_file_path) + finally: + os.unlink(configuration_file_path) + + def test_raises_error_when_configuration_file_not_found(self, preprocessor): + """Ensures FileNotFoundError is raised when configuration file doesn't exist.""" + with pytest.raises(FileNotFoundError): + preprocessor._load_wire_clusters("/nonexistent/path/config.yaml") + + def test_raises_error_for_invalid_yaml_syntax(self, preprocessor): + """Verifies invalid YAML syntax triggers appropriate error.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as temporary_file: + temporary_file.write("invalid: yaml: [") + configuration_file_path = temporary_file.name + + try: + with pytest.raises(ValueError, match="Invalid YAML"): + preprocessor._load_wire_clusters(configuration_file_path) + finally: + os.unlink(configuration_file_path) + + # ==================== Clustering Tests ==================== + + def test_clusters_wires_based_on_spatial_proximity(self, preprocessor, spatially_distributed_wires): + """Tests that spatially close wires are grouped into the same cluster.""" + wires = spatially_distributed_wires + + preprocessor.wire_clusters = [ + WireCluster(name="first_cluster", wire_count=3, current_sign=1), + WireCluster(name="second_cluster", wire_count=2, current_sign=-1) + ] + + preprocessor._perform_clustering(wires) + + # First three close wires should be in first cluster + assert len(preprocessor.wire_clusters[0].wires) == 3 + first_cluster_indices = {wire.original_index for wire in preprocessor.wire_clusters[0].wires} + assert first_cluster_indices == {0, 1, 2} + + # Last two close wires should be in second cluster + assert len(preprocessor.wire_clusters[1].wires) == 2 + second_cluster_indices = {wire.original_index for wire in preprocessor.wire_clusters[1].wires} + assert second_cluster_indices == {3, 4} + + def test_raises_error_when_insufficient_wires_for_cluster(self, preprocessor): + """Ensures error is raised when cluster requires more wires than available.""" + available_wires = [ + Wire(Point(0.0, 0.0), Color.RED, 0), + Wire(Point(1.0, 0.0), Color.RED, 1), + ] + + preprocessor.wire_clusters = [ + WireCluster(name="large_cluster", wire_count=3, current_sign=1), # Needs 3 wires + ] + + with pytest.raises(ValueError, match="Not enough wires for cluster"): + preprocessor._perform_clustering(available_wires) + + # ==================== Integration Tests ==================== + + def test_returns_empty_dict_when_no_wires_but_configuration_expected(self, preprocessor, mock_factory): + """Handles case where configuration expects wires but none are provided.""" + configuration_expecting_wires = { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': 1, + 'current_sign': 1 + } + } + } + + configuration_file_path = self.create_temporary_configuration_file(configuration_expecting_wires) + + try: + result = preprocessor.prepare_wires( + factory=mock_factory, + config_path=configuration_file_path, + wires=[] + ) + + assert result == {} + mock_factory.addPoint.assert_not_called() + mock_factory.addPhysicalGroup.assert_not_called() + + finally: + os.unlink(configuration_file_path) + + def test_raises_error_when_wire_count_mismatches_configuration(self, preprocessor, mock_factory): + """Verifies error when total wires don't match cluster configuration requirements.""" + configuration_content = { + 'wire_clusters': { + 'cluster_1': { + 'wire_count': 5, # Expects 5 wires + 'current_sign': 1 + } + } + } + + configuration_file_path = self.create_temporary_configuration_file(configuration_content) + + available_wires = [ + (Point(0.0, 0.0), Color.RED), + (Point(1.0, 0.0), Color.RED), + (Point(2.0, 0.0), Color.RED) # Only 3 wires + ] + + try: + with pytest.raises(ValueError, match="Number of wires.*doesn't match cluster configuration"): + preprocessor.prepare_wires( + factory=mock_factory, + config_path=configuration_file_path, + wires=available_wires + ) + finally: + os.unlink(configuration_file_path) + + @patch('svg_to_getdp.infrastructure.wire_preprocessor.WirePreprocessor._load_wire_clusters') + def test_prepares_wires_and_assigns_to_clusters(self, mock_load_clusters, preprocessor, mock_factory): + """Integration test verifying complete wire preparation with factory interaction.""" + mock_clusters = [ + WireCluster(name="positive_cluster", wire_count=3, current_sign=1), + WireCluster(name="negative_cluster", wire_count=3, current_sign=-1) + ] + mock_load_clusters.return_value = mock_clusters + + spatially_separated_wires = [ + # First spatial group - should form positive cluster + (Point(0.0, 10.0), Color.RED), + (Point(1.0, 10.0), Color.RED), + (Point(0.0, 9.0), Color.RED), + + # Second spatial group - should form negative cluster + (Point(100.0, 0.0), Color.RED), + (Point(101.0, 0.0), Color.RED), + (Point(100.0, 1.0), Color.RED), + ] + + mock_factory.addPoint.side_effect = list(range(1, 7)) + + wire_results = preprocessor.prepare_wires( + factory=mock_factory, + config_path="dummy_path.yaml", + wires=spatially_separated_wires + ) + + assert mock_factory.addPoint.call_count == 6 + assert mock_factory.addPhysicalGroup.call_count == 2 + + # Verify physical group assignments + positive_physical_group_call = mock_factory.addPhysicalGroup.call_args_list[0] + assert positive_physical_group_call[0][2] == DOMAIN_COIL_POSITIVE.value + + negative_physical_group_call = mock_factory.addPhysicalGroup.call_args_list[1] + assert negative_physical_group_call[0][2] == DOMAIN_COIL_NEGATIVE.value + + # Verify result structure + assert len(wire_results) == 6 + + for wire_index in range(6): + wire_data = wire_results[wire_index] + assert 'point' in wire_data + assert 'color' in wire_data + assert 'gmsh_point_tag' in wire_data + assert 'physical_group' in wire_data + assert 'wire_index' in wire_data + assert 'wire_name' in wire_data + assert 'cluster_name' in wire_data + assert 'wire_in_cluster_index' in wire_data + assert 'cluster_index' in wire_data + + # Verify clustering logic + first_group_cluster = wire_results[0]['cluster_name'] + assert wire_results[1]['cluster_name'] == first_group_cluster + assert wire_results[2]['cluster_name'] == first_group_cluster + + second_group_cluster = wire_results[3]['cluster_name'] + assert wire_results[4]['cluster_name'] == second_group_cluster + assert wire_results[5]['cluster_name'] == second_group_cluster + + assert first_group_cluster != second_group_cluster + + positive_wire_count = sum( + 1 for index in range(6) + if wire_results[index]['physical_group'].value == DOMAIN_COIL_POSITIVE.value + ) + negative_wire_count = sum( + 1 for index in range(6) + if wire_results[index]['physical_group'].value == DOMAIN_COIL_NEGATIVE.value + ) + + assert positive_wire_count == 3 + assert negative_wire_count == 3 + + # ==================== Edge Case Tests ==================== + + @pytest.mark.parametrize("configuration_content, wire_positions, expected_cluster_characteristics", [ + ( + { + 'wire_clusters': { + 'single_cluster': { + 'wire_count': 4, + 'current_sign': 1 + } + } + }, + [(float(i), float(i)) for i in range(4)], + [('single_cluster', 4, DOMAIN_COIL_POSITIVE.value)] + ), + ( + { + 'wire_clusters': { + 'cluster_A': { + 'wire_count': 2, + 'current_sign': 1 + }, + 'cluster_B': { + 'wire_count': 2, + 'current_sign': -1 + } + } + }, + [ + (0.0, 0.0), + (0.1, 0.0), + (100.0, 100.0), + (100.1, 100.0), + ], + [('cluster_A', 2, DOMAIN_COIL_POSITIVE.value), + ('cluster_B', 2, DOMAIN_COIL_NEGATIVE.value)] + ), + ]) + def test_handles_edge_cases_in_wire_preparation(self, preprocessor, mock_factory, + configuration_content, wire_positions, + expected_cluster_characteristics): + """Tests various edge cases in wire preparation and clustering.""" + configuration_file_path = self.create_temporary_configuration_file(configuration_content) + + test_wires = [(Point(x, y), Color.RED) for x, y in wire_positions] + + mock_factory.addPoint.side_effect = list(range(1, len(test_wires) + 1)) + + try: + wire_results = preprocessor.prepare_wires( + factory=mock_factory, + config_path=configuration_file_path, + wires=test_wires + ) + + assert len(wire_results) == len(test_wires) + assert mock_factory.addPoint.call_count == len(test_wires) + + # Analyze cluster distribution + cluster_analysis = {} + for wire_data in wire_results.values(): + cluster_name = wire_data['cluster_name'] + if cluster_name not in cluster_analysis: + cluster_analysis[cluster_name] = { + 'wire_count': 0, + 'physical_group_value': wire_data['physical_group'].value, + 'wire_names': set() + } + cluster_analysis[cluster_name]['wire_count'] += 1 + cluster_analysis[cluster_name]['wire_names'].add(wire_data['wire_name']) + + # Verify cluster characteristics match expectations + assert len(cluster_analysis) == len(expected_cluster_characteristics) + + sorted_cluster_names = sorted(cluster_analysis.keys()) + for cluster_index, (expected_cluster_name, expected_wire_count, expected_physical_group_value) \ + in enumerate(expected_cluster_characteristics): + + actual_cluster_name = sorted_cluster_names[cluster_index] + cluster_data = cluster_analysis[actual_cluster_name] + + assert cluster_data['wire_count'] == expected_wire_count + assert cluster_data['physical_group_value'] == expected_physical_group_value + + finally: + os.unlink(configuration_file_path) + + # ==================== Performance Tests ==================== + + def test_performance_with_many_wires(self, preprocessor, mock_factory): + """Test that preprocessor handles large number of wires efficiently.""" + # Create many wires + many_wires = [(Point(i * 1.0, i * 1.0), Color.RED) for i in range(100)] + + # Configuration for 100 wires in a single cluster + configuration_content = { + 'wire_clusters': { + 'large_cluster': { + 'wire_count': 100, + 'current_sign': 1 + } + } + } + + configuration_file_path = self.create_temporary_configuration_file(configuration_content) + + mock_factory.addPoint.side_effect = list(range(1, 101)) + + import time + start_time = time.time() + try: + wire_results = preprocessor.prepare_wires( + factory=mock_factory, + config_path=configuration_file_path, + wires=many_wires + ) + end_time = time.time() + + # Should complete in reasonable time + assert end_time - start_time < 5.0 + + # Should have 100 wires + assert len(wire_results) == 100 + assert mock_factory.addPoint.call_count == 100 + + finally: + os.unlink(configuration_file_path) + + # ==================== Error Handling Tests ==================== + + def test_error_handling_with_invalid_wire_data(self, preprocessor, mock_factory): + """Test handling of invalid wire data.""" + invalid_wires = [ + (Point(0.0, 0.0), Color.RED), + "not a wire", + (Point(1.0, 1.0), Color.RED) + ] + + configuration_content = { + 'wire_clusters': { + 'test_cluster': { + 'wire_count': 3, + 'current_sign': 1 + } + } + } + + configuration_file_path = self.create_temporary_configuration_file(configuration_content) + + try: + with pytest.raises((TypeError, AttributeError, ValueError)): + preprocessor.prepare_wires( + factory=mock_factory, + config_path=configuration_file_path, + wires=invalid_wires + ) + finally: + os.unlink(configuration_file_path) diff --git a/sketchgetdp/svg_to_getdp/tests/test_configs/config_dipole_magnet.yaml b/sketchgetdp/svg_to_getdp/tests/test_configs/config_dipole_magnet.yaml new file mode 100644 index 0000000..d303eab --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/test_configs/config_dipole_magnet.yaml @@ -0,0 +1,24 @@ +# SVG To Getdp Configuration + +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 25 + current_sign: 1 + cluster_2: + wire_count: 25 + current_sign: -1 + + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.1 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 15000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/test_configs/config_first_sketch.yaml b/sketchgetdp/svg_to_getdp/tests/test_configs/config_first_sketch.yaml new file mode 100644 index 0000000..8e73ff4 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/test_configs/config_first_sketch.yaml @@ -0,0 +1,23 @@ +# SVG To Getdp Configuration + +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 6 + current_sign: 1 + cluster_2: + wire_count: 6 + current_sign: -1 + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.1 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 9000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/test_configs/config_h-type_magnet.yaml b/sketchgetdp/svg_to_getdp/tests/test_configs/config_h-type_magnet.yaml new file mode 100644 index 0000000..b00cd89 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/test_configs/config_h-type_magnet.yaml @@ -0,0 +1,23 @@ +# SVG To Getdp Configuration + +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 16 + current_sign: 1 + cluster_2: + wire_count: 16 + current_sign: -1 + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.1 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 10000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/test_configs/config_quadrupole_magnet.yaml b/sketchgetdp/svg_to_getdp/tests/test_configs/config_quadrupole_magnet.yaml new file mode 100644 index 0000000..08ce29e --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/test_configs/config_quadrupole_magnet.yaml @@ -0,0 +1,29 @@ +# SVG To Getdp Configuration + +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 16 + current_sign: 1 + cluster_2: + wire_count: 16 + current_sign: -1 + cluster_3: + wire_count: 16 + current_sign: -1 + cluster_4: + wire_count: 16 + current_sign: 1 + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.075 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 5000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity \ No newline at end of file diff --git a/sketchgetdp/svg_to_getdp/tests/test_configs/config_racetrack_coil.yaml b/sketchgetdp/svg_to_getdp/tests/test_configs/config_racetrack_coil.yaml new file mode 100644 index 0000000..96fc5b1 --- /dev/null +++ b/sketchgetdp/svg_to_getdp/tests/test_configs/config_racetrack_coil.yaml @@ -0,0 +1,23 @@ +# SVG To Getdp Configuration + +## Wire cluster configuration +# Clusters are identified from top to bottom, left to right +# Each cluster has: number of wires and current direction (1 for positive, -1 for negative) +# Positive current flows out of the page. +wire_clusters: + cluster_1: + wire_count: 9 + current_sign: 1 + cluster_2: + wire_count: 9 + current_sign: -1 + +## mesh settings +# Set the mesh size for Gmsh +mesh_size: 0.1 + +## GetDP simulation settings +# Physical values for the simulation +physical_values: + Isource: 150000 # Current source in Amperes [A] + nu_iron_linear: 1/(1000 * 4e-7 * pi) # Iron reluctivity \ No newline at end of file diff --git a/tests/inputs/colors.jpg b/tests/inputs/colors.jpg new file mode 100644 index 0000000..fa0887c Binary files /dev/null and b/tests/inputs/colors.jpg differ diff --git a/tests/inputs/full_structures/dipole_magnet.svg b/tests/inputs/full_structures/dipole_magnet.svg new file mode 100644 index 0000000..146ae68 --- /dev/null +++ b/tests/inputs/full_structures/dipole_magnet.svg @@ -0,0 +1,465 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/inputs/full_structures/first_sketch.svg b/tests/inputs/full_structures/first_sketch.svg new file mode 100644 index 0000000..eec1463 --- /dev/null +++ b/tests/inputs/full_structures/first_sketch.svg @@ -0,0 +1,186 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/inputs/full_structures/h-type_magnet.svg b/tests/inputs/full_structures/h-type_magnet.svg new file mode 100644 index 0000000..005976a --- /dev/null +++ b/tests/inputs/full_structures/h-type_magnet.svg @@ -0,0 +1,310 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/inputs/full_structures/quadrupole_magnet.svg b/tests/inputs/full_structures/quadrupole_magnet.svg new file mode 100644 index 0000000..e064d86 --- /dev/null +++ b/tests/inputs/full_structures/quadrupole_magnet.svg @@ -0,0 +1,539 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/inputs/full_structures/racetrack_coil.svg b/tests/inputs/full_structures/racetrack_coil.svg new file mode 100644 index 0000000..8c88151 --- /dev/null +++ b/tests/inputs/full_structures/racetrack_coil.svg @@ -0,0 +1,176 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/test_BezierCurve.py b/tests/test_BezierCurve.py deleted file mode 100644 index 234082f..0000000 --- a/tests/test_BezierCurve.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest -import numpy as np -from sketchgetdp.bezier import BezierCurve - - -class TestBezierCurve(unittest.TestCase): - def setUp(self): - """Set up a BezierCurve instance for testing.""" - self.control_points = np.array([[0, 0], [1, 2], [2, 2], [3, 0]]) - self.bezier_curve = BezierCurve(self.control_points) - - def test_init(self): - """Test the initialization of the BezierCurve class.""" - self.assertIsInstance(self.bezier_curve, BezierCurve) - self.assertEqual(self.bezier_curve.degree, 3) - - def test_evaluate(self): - """Test the evaluate method of the BezierCurve class.""" - t = np.array([0, 0.5, 1]) - expected_result = np.array([[0, 0], [1.5, 1.5], [3, 0]]) - result = self.bezier_curve.evaluate(t) - self.assertTrue(np.allclose(result, expected_result)) - - def test_evaluate_derivative(self): - """Test the evaluate_derivative method of the BezierCurve class.""" - t = np.array([0, 0.5, 1]) - expected_result = np.array([[1, 2], [1, 0], [1, -2]]) - result = self.bezier_curve.evaluate_derivative(t) - self.assertTrue(np.allclose(result, expected_result))