library.root_shoot_matching

Root-shoot matching module for plant root analysis.

This module provides a complete pipeline for matching root skeleton structures to shoot regions in plant microscopy images. It filters and scores root candidates, assigns one root to each shoot, and produces length measurements with robot coordinates for each plant.

The pipeline processes binary masks of shoots and roots through multiple stages:

  1. Shoot reference point extraction
  2. Root mask filtering by spatial sampling
  3. Root structure extraction and skeletonization
  4. Candidate scoring based on proximity, size, and verticality
  5. Greedy left-to-right assignment ensuring one root per shoot
  6. Top node identification for path tracing
  7. Length calculation and coordinate conversion
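
Stage 5 can be illustrated in a few lines. The sketch below is a simplified illustration under stated assumptions, not the module's `assign_roots_to_shoots_greedy`: `greedy_assign` is a hypothetical helper, each candidate carries a bare score instead of the full score dict, and shoot positions are passed as a plain `{label: x}` mapping.

```python
def greedy_assign(shoot_x, candidates, min_score=0.3):
    """Assign at most one root to each shoot, visiting shoots left to right.

    shoot_x:    {shoot_label: x position of the shoot}  (hypothetical input shape)
    candidates: {shoot_label: [(root_label, score), ...]} sorted best-first
    Returns     {shoot_label: root_label or None}
    """
    taken = set()
    assignment = {}
    # Visit shoots in order of increasing x (left to right)
    for shoot in sorted(shoot_x, key=shoot_x.get):
        assignment[shoot] = None
        for root, score in candidates.get(shoot, []):
            if score < min_score:
                break  # list is sorted best-first, so nothing later can qualify
            if root not in taken:
                assignment[shoot] = root
                taken.add(root)
                break
    return assignment


# Made-up example data: three shoots competing for three roots
shoot_x = {1: 2100.0, 2: 1200.0, 3: 3000.0}
candidates = {
    1: [(7, 0.81), (4, 0.55)],
    2: [(7, 0.90), (9, 0.40)],
    3: [(4, 0.20)],
}
result = greedy_assign(shoot_x, candidates)
print(result)
```

Shoot 2 is leftmost, so it claims root 7 first; shoot 1 falls back to root 4, and shoot 3 stays unassigned because its only candidate scores below the threshold.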

Complete End-to-End Example:

# Setup paths and create data loader
from pathlib import Path
from library.root_shoot_matching import GetIm, RootMatchingConfig, match_roots_to_shoots_complete
import pandas as pd

# Define paths to your masks and images
images_path = Path('data/Kaggle/')
roots_path = Path('data/repaired/roots')
shoots_path = Path('data/repaired/shoots')

# Get sorted file lists
images = sorted(images_path.glob('*.png'))
roots = sorted(roots_path.glob('*.png'))
shoots = sorted(shoots_path.glob('*.png'))

# Create getter for loading masks
getter = GetIm(shoots=shoots, roots=roots)

# Configure matching parameters
config = RootMatchingConfig(
    max_horizontal_offset=200,      # Horizontal search radius (pixels)
    distance_weight=0.35,            # Proximity importance (0-1)
    size_weight=0.35,                # Root length importance (0-1)
    verticality_weight=0.3,          # Vertical orientation importance (0-1)
    max_below_shoot=100,             # Max distance below shoot to search (pixels)
    min_skeleton_pixels=6            # Minimum pixels to be valid root
)

# Process single sample
sample_idx = 0
shoot_mask, root_mask = getter(sample_idx, is_idx=True)

# Get results as numpy array (5 lengths in pixels, left to right)
lengths = match_roots_to_shoots_complete(
    shoot_mask, root_mask, images[sample_idx], config
)
print(f"Lengths: {lengths}")  # Array of 5 floats

# Get results as DataFrame with all coordinates
df = match_roots_to_shoots_complete(
    shoot_mask, root_mask, images[sample_idx], config,
    return_dataframe=True,
    sample_idx=sample_idx,
    verbose=True
)
print(df)
# Output columns:
# - plant_order: 1-5 (left to right)
# - Plant ID: test_image_1
# - Length (px): root length in pixels
# - length_px: duplicate for compatibility
# - top_node_x, top_node_y: pixel coordinates of root start
# - endpoint_x, endpoint_y: pixel coordinates of root end
# - top_node_robot_x/y/z: robot coordinates in meters
# - endpoint_robot_x/y/z: robot coordinates in meters

# Process all samples into single DataFrame
all_results = []
for idx in range(len(roots)):
    shoot_mask, root_mask = getter(idx, is_idx=True)
    df = match_roots_to_shoots_complete(
        shoot_mask, root_mask, images[idx], config,
        return_dataframe=True,
        sample_idx=idx
    )
    all_results.append(df)

# Combine and save full dataset
results_df = pd.concat(all_results, ignore_index=True)
results_df.to_csv('root_measurements_full.csv', index=False)
print(f"Processed {len(results_df)} plants from {len(roots)} images")

# Create Kaggle submission (Plant ID and Length only)
submission_df = results_df[['Plant ID', 'Length (px)']]
submission_df.to_csv('kaggle_submission.csv', index=False)
print(f"Saved Kaggle submission with {len(submission_df)} entries")

Key Classes and Functions:

GetIm: Helper class for loading and visualizing shoot/root mask pairs
RootMatchingConfig: Configuration dataclass for tuning matching parameters
match_roots_to_shoots_complete: Main pipeline function (masks in, measurements out)
pixel_to_robot_coords: Convert pixel coordinates to robot coordinates in meters

Configuration Tuning Guide:

max_horizontal_offset: Increase if roots drift far from shoots (default: 200px in RootMatchingConfig; 400px in score_skeleton_for_shoot)
max_below_shoot: Increase if roots start far below shoots (default: 100px)
distance_weight: Increase to favor closer roots (default: 0.35)
size_weight: Increase to favor longer roots (default: 0.35)
verticality_weight: Increase to favor more vertical roots (default: 0.3)
min_skeleton_pixels: Increase to reject more noise (default: 6px)
min_score_threshold: Increase in assign_roots_to_shoots_greedy to reject weak matches (default: 0.3)
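
To see how the weights interact, the combined score can be reproduced in isolation. The sketch below restates the scoring formula used by `score_skeleton_for_shoot` (without its spatial filters); `combined_score` and the two candidate dicts are hypothetical names and made-up values for illustration only.

```python
def combined_score(distance, num_pixels, vertical_extent, horizontal_extent,
                   distance_weight=0.35, size_weight=0.35, verticality_weight=0.3,
                   max_distance=100, max_size=1000, ideal_aspect=5.0):
    """Weighted sum of proximity, size, and verticality scores (each in 0-1)."""
    distance_score = 1.0 / (1.0 + distance / max_distance)
    size_score = min(num_pixels / max_size, 1.0)
    if horizontal_extent > 0:
        aspect_ratio = vertical_extent / horizontal_extent
        verticality_score = min(aspect_ratio / ideal_aspect, 1.0)
    else:
        verticality_score = 1.0  # perfectly vertical skeleton
    return (distance_weight * distance_score
            + size_weight * size_score
            + verticality_weight * verticality_score)


# Two hypothetical candidates for the same shoot
near_short = dict(distance=20, num_pixels=150, vertical_extent=140, horizontal_extent=30)
far_long = dict(distance=150, num_pixels=900, vertical_extent=850, horizontal_extent=120)

# Default weights (0.35 / 0.35 / 0.3)
default_near = combined_score(**near_short)
default_far = combined_score(**far_long)

# Weights shifted toward proximity
tuned = dict(distance_weight=0.7, size_weight=0.1, verticality_weight=0.2)
tuned_near = combined_score(**near_short, **tuned)
tuned_far = combined_score(**far_long, **tuned)

print(f"default: near={default_near:.3f} far={default_far:.3f}")
print(f"tuned:   near={tuned_near:.3f} far={tuned_far:.3f}")
```

With the default weights the longer, more distant candidate ranks higher; shifting weight toward distance reverses the order, which is the effect described for `distance_weight` above.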

import warnings
from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import ndimage
from scipy.spatial.distance import cdist
from skimage.measure import label as label_components

from library.mask_processing import load_mask
from library.root_analysis import (
    extract_root_structures,
    find_farthest_endpoint_path,
    calculate_skeleton_length_px,
)


class GetIm:
    """Helper class for loading and displaying shoot and root mask pairs.

    Provides convenient access to paired shoot and root masks from sorted file lists,
    with methods for visualization. Handles both 1-indexed file numbers and 0-indexed
    array indices.

    Args:
        shoots (list): Sorted list of shoot mask file paths
        roots (list): Sorted list of root mask file paths (same length as shoots)

    Attributes:
        shoots (list): Stored shoot mask file paths
        roots (list): Stored root mask file paths

    Methods:
        show: Display shoot and root masks side by side
        show_overlay: Display masks overlaid in single image with grid
        __call__: Load masks directly (same as _load_masks)

    Examples:
        >>> from pathlib import Path
        >>> shoots = sorted(Path('data/shoots').glob('*.png'))
        >>> roots = sorted(Path('data/roots').glob('*.png'))
        >>> getter = GetIm(shoots=shoots, roots=roots)
        >>>
        >>> # Display by file number (1-indexed)
        >>> getter.show(1)
        >>>
        >>> # Display by array index (0-indexed)
        >>> getter.show(0, is_idx=True)
        >>>
        >>> # Load masks directly
        >>> shoot_mask, root_mask = getter(0, is_idx=True)

    Notes:
        - File numbers are 1-indexed by default (file_num=1 loads the first file)
        - Array indices are 0-indexed when is_idx=True (file_num=0 loads the first file)
        - Passing file_num=0 without is_idx=True resolves to index -1 and loads the last file
        - Shoot and root lists must be the same length and correspond to matching pairs
    """

    def __init__(self, shoots, roots):
        self.shoots = shoots
        self.roots = roots

    def _load_masks(self, file_num, is_idx=False, print_f=False):
        """Load the (shoot_mask, root_mask) pair for a file number or array index."""
        # Convert a 1-indexed file number to a 0-indexed list position
        if not is_idx:
            file_num = file_num - 1

        shoot_f = str(self.shoots[file_num])
        root_f = str(self.roots[file_num])
        if print_f:
            print(shoot_f, '\n', root_f)
        return (load_mask(shoot_f), load_mask(root_f))

    def show(self, file_num, is_idx=False, size=(20, 8)):
        """Display shoot and root masks side by side with filenames as titles."""
        s, r = self._load_masks(file_num, is_idx, print_f=False)

        # Get filenames
        if not is_idx:
            file_num = file_num - 1

        shoot_name = Path(self.shoots[file_num]).name
        root_name = Path(self.roots[file_num]).name

        # Create side-by-side plot
        fig, axes = plt.subplots(1, 2, figsize=size)

        axes[0].imshow(s, cmap='gray')
        axes[0].set_title(shoot_name, fontsize=12)
        axes[0].axis('off')

        axes[1].imshow(r, cmap='gray')
        axes[1].set_title(root_name, fontsize=12)
        axes[1].axis('off')

        plt.tight_layout()
        plt.show()

    def show_overlay(self, file_num, is_idx=False, size=(12, 8)):
        """Display shoot (green) and root (red) masks overlaid in single image with grid."""
        s, r = self._load_masks(file_num, is_idx, print_f=False)

        # Get filenames
        if not is_idx:
            file_num = file_num - 1

        shoot_name = Path(self.shoots[file_num]).name
        root_name = Path(self.roots[file_num]).name
        combined_title = f"{shoot_name} + {root_name}"

        # Create RGB image
        h, w = s.shape
        rgb_img = np.zeros((h, w, 3), dtype=np.uint8)

        # Apply shoot mask as green
        rgb_img[s > 0] = [0, 255, 0]

        # Apply root mask as red (on top)
        rgb_img[r > 0] = [255, 0, 0]

        # Display
        fig, ax = plt.subplots(figsize=size)
        ax.imshow(rgb_img)
        ax.set_title(combined_title, fontsize=12)

        # Add grid lines every 200 pixels
        ax.set_xticks(np.arange(0, w, 200))
        ax.set_yticks(np.arange(0, h, 200))
        ax.grid(True, color='blue', linewidth=0.5, alpha=0.5)

        plt.tight_layout()
        plt.show()

    def __call__(self, file_num, is_idx=False, print_f=False):
        """Load and return the shoot and root masks for one sample.

        Args:
            file_num (int): File number (1-indexed), or array index when is_idx is True
            is_idx (bool): Treat file_num as a 0-indexed array index when True
            print_f (bool): Print the two filenames when True

        Returns:
            tuple: (shoot_mask, root_mask)
        """
        return self._load_masks(file_num, is_idx, print_f=print_f)


@dataclass
class RootMatchingConfig:
    """Configuration parameters for root-shoot matching.

    Attributes:
        sampling_buffer_above: Pixels above top shoot to include in sampling box
        sampling_buffer_below: Pixels below bottom shoot to include in sampling box
        distance_weight: Weight for proximity score (0-1)
        size_weight: Weight for skeleton length score (0-1)
        verticality_weight: Weight for vertical orientation score (0-1)
        max_horizontal_offset: Maximum horizontal distance from shoot (pixels)
        max_above_shoot: Allow roots to extend this many pixels above shoot bottom
        max_below_shoot: Maximum distance below shoot bottom where root can start (pixels)
        min_skeleton_pixels: Minimum skeleton pixels for a component to be a valid candidate
        edge_exclusion_zones: List of (min_x, max_x) tuples for edge noise regions
        max_distance: Distance normalization constant (pixels)
        max_size: Size normalization constant (pixels)
        ideal_aspect: Ideal vertical:horizontal ratio for roots
    """
    sampling_buffer_above: int = 200
    sampling_buffer_below: int = 300

    distance_weight: float = 0.35
    size_weight: float = 0.35
    verticality_weight: float = 0.3

    max_horizontal_offset: int = 200
    max_above_shoot: int = 200
    max_below_shoot: int = 100
    min_skeleton_pixels: int = 6
    edge_exclusion_zones: list = field(default_factory=lambda: [(0, 900), (3300, 4000)])

    max_distance: int = 100
    max_size: int = 1000
    ideal_aspect: float = 5.0


def get_shoot_reference_points_multi(shoot_mask):
    """Calculate multiple reference points for each shoot region.

    Args:
        shoot_mask: Binary mask array with shoot regions (H, W)

    Returns:
        tuple: (labeled_shoots, num_shoots, ref_points) where:
            - labeled_shoots: Array with unique label per shoot region
            - num_shoots: Number of distinct shoot regions found
            - ref_points: Dict with structure {shoot_label: {
                'centroid': (y, x),
                'bottom_center': (y, x),
                'bottom_most': (y, x),
                'bbox': (min_y, max_y, min_x, max_x)
              }}

    Examples:
        >>> labeled_shoots, num_shoots, ref_points = get_shoot_reference_points_multi(shoot_mask)
        >>> print(f"Found {num_shoots} shoots")
        >>> print(ref_points[1]['bottom_center'])
    """
    # Label shoots using 8-connectivity so diagonally touching pixels join one region
    structure = ndimage.generate_binary_structure(2, 2)
    labeled_shoots, num_shoots = ndimage.label(shoot_mask, structure=structure)

    ref_points = {}

    for label in range(1, num_shoots + 1):
        shoot_region = labeled_shoots == label
        y_coords, x_coords = np.where(shoot_region)

        if len(y_coords) == 0:
            continue

        # Centroid
        centroid_y = np.mean(y_coords)
        centroid_x = np.mean(x_coords)

        # Bounding box
        min_y, max_y = y_coords.min(), y_coords.max()
        min_x, max_x = x_coords.min(), x_coords.max()

        # Bottom center: centroid x, maximum y (lowest point)
        bottom_center = (max_y, centroid_x)

        # Bottom most: actual bottom-most pixel
        bottom_idx = np.argmax(y_coords)
        bottom_most = (y_coords[bottom_idx], x_coords[bottom_idx])

        ref_points[label] = {
            'centroid': (centroid_y, centroid_x),
            'bottom_center': bottom_center,
            'bottom_most': bottom_most,
            'bbox': (min_y, max_y, min_x, max_x)
        }

    return labeled_shoots, num_shoots, ref_points


def filter_root_mask_by_sampling_box(root_mask, shoot_mask, ref_points, num_shoots,
                                     sampling_buffer_above=200,
                                     sampling_buffer_below=300):
    """Filter root mask to keep only components starting within sampling box.

    Creates a sampling box around all shoots and keeps only connected components
    whose top (min_y) falls within this box. Preserves full component length even
    if it extends beyond the box boundaries.

    Args:
        root_mask: Binary mask array with root regions (H, W)
        shoot_mask: Binary mask array with shoot regions (H, W); not used in the
            computation, which relies on ref_points instead
        ref_points: Dict from get_shoot_reference_points_multi
        num_shoots: Number of shoot regions
        sampling_buffer_above: Pixels above top shoot to include
        sampling_buffer_below: Pixels below bottom shoot to include

    Returns:
        np.ndarray: Binary mask with filtered root components, preserving full lengths

    Examples:
        >>> filtered_mask = filter_root_mask_by_sampling_box(
        ...     root_mask, shoot_mask, ref_points, 5
        ... )
        >>> structures = extract_root_structures(filtered_mask)
    """
    # Calculate sampling box from the vertical span of all shoot bounding boxes
    shoot_y_min = min(ref_points[s]['bbox'][0] for s in range(1, num_shoots + 1))
    shoot_y_max = max(ref_points[s]['bbox'][1] for s in range(1, num_shoots + 1))

    sampling_box_top = max(0, shoot_y_min - sampling_buffer_above)
    sampling_box_bottom = min(root_mask.shape[0], shoot_y_max + sampling_buffer_below)

    # Label all connected components in the FULL mask
    labeled_mask = label_components(root_mask)
    num_components = labeled_mask.max()

    print(f"Sampling box: rows {sampling_box_top:.0f} to {sampling_box_bottom:.0f}")
    print(f"Found {num_components} components in full mask")

    # Check which components have their TOP (min_y) in the sampling box
    valid_labels = set()

    for component_label in range(1, num_components + 1):
        component_mask = labeled_mask == component_label
        coords = np.argwhere(component_mask)

        if len(coords) == 0:
            continue

        comp_top_y = coords[:, 0].min()

        # Keep component if its top is in the sampling box
        if sampling_box_top <= comp_top_y <= sampling_box_bottom:
            valid_labels.add(component_label)

    # Create filtered mask with FULL components (not cropped)
    filtered_mask = np.isin(labeled_mask, list(valid_labels))

    print(f"Kept {len(valid_labels)} components (full length preserved)")

    return filtered_mask.astype(bool)


def score_skeleton_for_shoot(component_props, shoot_ref_point, shoot_bbox,
                              max_distance=100, max_size=1000, ideal_aspect=5.0,
                              distance_weight=0.5, size_weight=0.2, verticality_weight=0.3,
                              max_horizontal_offset=400,
                              max_above_shoot=200,
                              max_below_shoot=100,
                              min_skeleton_pixels=6,
                              edge_exclusion_zones=((0, 900), (3300, 4000))):
    """Score how likely a skeleton component is to be a root for a given shoot.

    Applies size and spatial filters (vertical position, horizontal distance,
    edge zones) and calculates a combined score based on proximity, size, and
    verticality. Higher scores indicate better matches.

    Note that the defaults here differ from RootMatchingConfig (distance_weight
    0.5 vs 0.35, size_weight 0.2 vs 0.35, max_horizontal_offset 400 vs 200).

    Args:
        component_props: Dict with keys:
            - 'label': Component identifier
            - 'centroid': (y, x) tuple
            - 'bbox': (min_y, max_y, min_x, max_x) tuple
            - 'num_pixels': Total skeleton pixels
            - 'vertical_extent': Height in pixels
            - 'horizontal_extent': Width in pixels
        shoot_ref_point: (y, x) tuple for shoot reference position
        shoot_bbox: (min_y, max_y, min_x, max_x) tuple for shoot bounding box
        max_distance: Distance normalization constant (pixels)
        max_size: Size normalization constant (pixels)
        ideal_aspect: Ideal vertical:horizontal ratio for roots
        distance_weight: Weight for proximity score (0-1)
        size_weight: Weight for skeleton length (0-1)
        verticality_weight: Weight for vertical orientation (0-1)
        max_horizontal_offset: Maximum horizontal distance from shoot (pixels)
        max_above_shoot: Allow roots to extend this many pixels above shoot bottom
        max_below_shoot: Maximum distance below shoot bottom where root can start (pixels)
        min_skeleton_pixels: Minimum skeleton pixels to be valid candidate
        edge_exclusion_zones: Sequence of (min_x, max_x) tuples for edge noise

    Returns:
        dict: Scoring results with keys:
            - 'score': Combined score (higher is better), -1 if invalid
            - 'distance': Euclidean distance from the shoot reference point to the
              component's (top y, centroid x); inf if invalid
            - 'distance_score': Normalized distance component
            - 'size_score': Normalized size component
            - 'verticality_score': Normalized verticality component
            - 'aspect_ratio': Vertical:horizontal ratio (valid candidates only)
            - 'reason': Rejection reason (only present if score is -1)

    Examples:
        >>> score_result = score_skeleton_for_shoot(comp_props, shoot_ref, shoot_bbox)
        >>> if score_result['score'] > 0:
        ...     print(f"Valid candidate with score {score_result['score']:.3f}")
    """
    shoot_y, shoot_x = shoot_ref_point
    shoot_bottom = shoot_bbox[1]

    _, comp_centroid_x = component_props['centroid']
    comp_top_y, comp_bottom_y, comp_min_x, comp_max_x = component_props['bbox']
    comp_pixels = component_props['num_pixels']
    vertical_extent = component_props['vertical_extent']
    horizontal_extent = component_props['horizontal_extent']

    def rejected(reason):
        """Build the result dict for a candidate that fails a filter."""
        return {
            'score': -1,
            'distance': float('inf'),
            'distance_score': 0,
            'size_score': 0,
            'verticality_score': 0,
            'reason': reason
        }

    # FILTER 0: Minimum size filter (reject noise)
    if comp_pixels < min_skeleton_pixels:
        return rejected('too_small')

    # FILTER 1: Reject components that end more than max_above_shoot above the shoot bottom
    if comp_bottom_y < (shoot_bottom - max_above_shoot):
        return rejected('above_shoot')

    # FILTER 2: Root must start within max_below_shoot pixels below the shoot bottom
    if comp_top_y > (shoot_bottom + max_below_shoot):
        return rejected('too_far_below_shoot')

    # FILTER 3: Component bbox must overlap the horizontal search zone around the shoot
    search_zone_min_x = shoot_x - max_horizontal_offset
    search_zone_max_x = shoot_x + max_horizontal_offset

    if comp_max_x < search_zone_min_x or comp_min_x > search_zone_max_x:
        return rejected('too_far_horizontally')

    # FILTER 4: Reject components whose centroid lies in an edge exclusion zone
    for min_edge_x, max_edge_x in edge_exclusion_zones:
        if min_edge_x <= comp_centroid_x <= max_edge_x:
            return rejected('in_edge_zone')

    # Calculate scores (each normalized to the range 0-1)
    distance = np.sqrt((comp_top_y - shoot_y)**2 + (comp_centroid_x - shoot_x)**2)
    distance_score = 1.0 / (1.0 + distance / max_distance)

    size_score = min(comp_pixels / max_size, 1.0)

    if horizontal_extent > 0:
        aspect_ratio = vertical_extent / horizontal_extent
        verticality_score = min(aspect_ratio / ideal_aspect, 1.0)
    else:
        # Zero width means a perfectly vertical skeleton
        verticality_score = 1.0

    combined = (distance_weight * distance_score +
                size_weight * size_score +
                verticality_weight * verticality_score)

    return {
        'score': combined,
        'distance': distance,
        'distance_score': distance_score,
        'size_score': size_score,
        'verticality_score': verticality_score,
        'aspect_ratio': vertical_extent / max(horizontal_extent, 1)
    }


 572def find_valid_candidates_for_shoots(structures, ref_points, num_shoots, config=None):
 573    """Find and score valid root candidates for each shoot.
 574    
 575    Uses sampling box pre-filtering for efficiency, then scores each component
 576    for each shoot. Returns organized candidates sorted by score.
 577    
 578    Args:
 579        structures: Dict from extract_root_structures with keys:
 580            - 'skeleton': Full skeleton array
 581            - 'labeled_skeleton': Labeled skeleton array
 582            - 'unique_labels': Array of label IDs
 583            - 'roots': Dict of root data by label
 584        ref_points: Dict from get_shoot_reference_points_multi
 585        num_shoots: Number of shoots (typically 5)
 586        config: RootMatchingConfig instance (uses defaults if None)
 587        
 588    Returns:
 589        dict: {shoot_label: [(root_label, score_dict), ...]} where each shoot's
 590              list is sorted by score descending
 591              
 592    Examples:
 593        >>> candidates = find_valid_candidates_for_shoots(structures, ref_points, 5)
 594        >>> for shoot_label in range(1, 6):
 595        ...     print(f"Shoot {shoot_label}: {len(candidates[shoot_label])} candidates")
 596    """
 597    if config is None:
 598        config = RootMatchingConfig()
 599    
 600    # Storage for each shoot's candidates
 601    shoot_candidates = {label: [] for label in range(1, num_shoots + 1)}
 602    
 603    total_components = len(structures['roots'])
 604    
 605    # Process each skeleton component
 606    for root_label, root_data in structures['roots'].items():
 607        branch_data = root_data.get('branch_data')
 608        
 609        # Skip if no branch data
 610        if branch_data is None or len(branch_data) == 0:
 611            continue
 612        
 613        # Extract coordinates for this component
 614        root_mask_region = root_data['mask']
 615        coords = np.argwhere(root_mask_region)
 616        
 617        if len(coords) == 0:
 618            continue
 619        
 620        # Build component properties
 621        comp_props = {
 622            'label': root_label,
 623            'centroid': (coords[:, 0].mean(), coords[:, 1].mean()),
 624            'bbox': (coords[:, 0].min(), coords[:, 0].max(), 
 625                    coords[:, 1].min(), coords[:, 1].max()),
 626            'num_pixels': root_data['total_pixels'],
 627            'vertical_extent': coords[:, 0].max() - coords[:, 0].min(),
 628            'horizontal_extent': coords[:, 1].max() - coords[:, 1].min()
 629        }
 630        
 631        # Score this component for each shoot
 632        for shoot_label in range(1, num_shoots + 1):
 633            shoot_ref = ref_points[shoot_label]['bottom_center']
 634            shoot_bbox = ref_points[shoot_label]['bbox']
 635            
 636            score_result = score_skeleton_for_shoot(
 637                comp_props, shoot_ref, shoot_bbox,
 638                max_distance=config.max_distance,
 639                max_size=config.max_size,
 640                ideal_aspect=config.ideal_aspect,
 641                distance_weight=config.distance_weight,
 642                size_weight=config.size_weight,
 643                verticality_weight=config.verticality_weight,
 644                max_horizontal_offset=config.max_horizontal_offset,
 645                max_above_shoot=config.max_above_shoot,
 646                max_below_shoot=config.max_below_shoot,
 647                min_skeleton_pixels=config.min_skeleton_pixels,
 648                edge_exclusion_zones=config.edge_exclusion_zones
 649            )
 650            
 651            if score_result['score'] > 0:
 652                shoot_candidates[shoot_label].append((root_label, score_result))
 653    
 654    # Sort each shoot's candidates by score (descending)
 655    for shoot_label in shoot_candidates:
 656        shoot_candidates[shoot_label].sort(key=lambda x: x[1]['score'], reverse=True)
 657    
 658    print("Valid candidates per shoot:")
 659    for shoot_label in range(1, num_shoots + 1):
 660        print(f"  Shoot {shoot_label}: {len(shoot_candidates[shoot_label])} candidates")
 661    
 662    return shoot_candidates
 663
 664
 665
 666def assign_roots_to_shoots_greedy(shoot_candidates, ref_points, num_shoots, config=None, min_score_threshold=0.3):
 667    """Assign roots to shoots using greedy left-to-right algorithm.
 668    
 669    Processes shoots in left-to-right order, assigning the best-scoring unassigned
 670    root to each shoot. Every shoot gets an entry; root_label is None when no candidate qualifies.
 671    
 672    Args:
 673        shoot_candidates: Dict from find_valid_candidates_for_shoots with structure
 674            {shoot_label: [(root_label, score_dict), ...]}
 675        ref_points: Dict from get_shoot_reference_points_multi
 676        num_shoots: Number of shoots (the pipeline expects 5)
 677        config: RootMatchingConfig instance (optional; currently unused)
 678        min_score_threshold: Minimum score required for assignment (default 0.3)
 679        
 680    Returns:
 681        dict: {shoot_label: {
 682            'root_label': int or None,
 683            'score_dict': dict or None,
 684            'order': int  # 0=leftmost, num_shoots-1=rightmost
 685        }}
 686    """
 687    # Sort shoots by x-position (left to right)
 688    shoot_positions = []
 689    for shoot_label in range(1, num_shoots + 1):
 690        centroid_x = ref_points[shoot_label]['centroid'][1]
 691        shoot_positions.append((shoot_label, centroid_x))
 692    
 693    shoot_positions.sort(key=lambda x: x[1])  # Sort by x coordinate
 694    
 695    # Track assigned roots
 696    assigned_roots = set()
 697    
 698    # Build assignments
 699    assignments = {}
 700    
 701    for order, (shoot_label, _) in enumerate(shoot_positions):
 702        candidates = shoot_candidates[shoot_label]
 703        
 704        # Find best unassigned root
 705        best_root = None
 706        best_score_dict = None
 707        
 708        for root_label, score_dict in candidates:
 709            if root_label not in assigned_roots:
 710                # Check if score meets minimum threshold
 711                if score_dict['score'] >= min_score_threshold:
 712                    best_root = root_label
 713                    best_score_dict = score_dict
 714                    break  # Candidates are already sorted by score
 715        
 716        # Store assignment
 717        if best_root is not None:
 718            assignments[shoot_label] = {
 719                'root_label': best_root,
 720                'score_dict': best_score_dict,
 721                'order': order
 722            }
 723            assigned_roots.add(best_root)
 724        else:
 725            # No valid candidates (ungerminated seed or below threshold)
 726            assignments[shoot_label] = {
 727                'root_label': None,
 728                'score_dict': None,
 729                'order': order
 730            }
 731    
 732    return assignments
 733
 734
 735def find_best_top_node(root_skeleton, shoot_ref_point, structures, root_label, 
 736                       distance_threshold=150, prefer_topmost=True):
 737    """Select the top (start) node of a root skeleton near the shoot reference point.
 738    
 739    Nodes within distance_threshold of the shoot are candidates; the topmost or the
 740    closest is chosen. If none qualify, the closest node overall is returned.
 741    
 742    Args:
 743        root_skeleton: Binary mask for this specific root (not used, kept for signature)
 744        shoot_ref_point: (y, x) tuple for shoot reference position
 745        structures: Dict from extract_root_structures
 746        root_label: Label ID for this root
 747        distance_threshold: Maximum distance from shoot to consider nodes (pixels)
 748        prefer_topmost: If True, pick topmost among candidates; if False, pick closest
 749        
 750    Returns:
 751        int: Node ID of the selected top node, or None if no valid nodes found
 752    """
 753    
 754    # Get branch data for this root
 755    branch_data = structures['roots'][root_label].get('branch_data')
 756    if branch_data is None or len(branch_data) == 0:
 757        return None
 758    
 759    # Extract all unique nodes from branch_data
 760    src_nodes = branch_data[['node-id-src', 'image-coord-src-0', 'image-coord-src-1']].drop_duplicates()
 761    dst_nodes = branch_data[['node-id-dst', 'image-coord-dst-0', 'image-coord-dst-1']].drop_duplicates()
 762    
 763    src_nodes.columns = ['node_id', 'y', 'x']
 764    dst_nodes.columns = ['node_id', 'y', 'x']
 765    
 766    all_nodes = pd.concat([src_nodes, dst_nodes]).drop_duplicates(subset='node_id')
 767    node_coords = all_nodes[['y', 'x']].values
 768    node_ids = all_nodes['node_id'].values
 769    
 770    if len(node_ids) == 0:
 771        return None
 772    
 773    # Calculate distances to shoot reference point
 774    distances = cdist(node_coords, [shoot_ref_point], metric='euclidean').flatten()
 775    
 776    # Filter nodes within distance threshold
 777    within_threshold = distances <= distance_threshold
 778    
 779    if not np.any(within_threshold):
 780        # No nodes within threshold, fall back to closest node
 781        return int(node_ids[np.argmin(distances)])
 782    
 783    # Get candidate nodes within threshold
 784    candidate_indices = np.where(within_threshold)[0]
 785    candidate_coords = node_coords[candidate_indices]
 786    candidate_ids = node_ids[candidate_indices]
 787    candidate_distances = distances[candidate_indices]
 788    
 789    if prefer_topmost:
 790        # Pick node with lowest y-coordinate (topmost)
 791        best_idx = np.argmin(candidate_coords[:, 0])
 792        return int(candidate_ids[best_idx])
 793    else:
 794        # Pick node with minimum distance
 795        best_idx = np.argmin(candidate_distances)
 796        return int(candidate_ids[best_idx])
 797    
 798
 799def pixel_to_robot_coords(pixel_x, pixel_y, image_shape, 
 800                         dish_size_m=0.15,
 801                         dish_offset_m=(0.10775, 0.062, 0.175)):
 802    """
 803    Convert pixel coordinates to robot coordinates in meters.
 804    
 805    Converts ROI-relative pixel coordinates to robot world coordinates.
 806    Assumes square petri dish with gantry axes aligned to image axes:
 807    - Robot X-axis aligns with image Y-axis (rows, downward)
 808    - Robot Y-axis aligns with image X-axis (columns, rightward)
 809    
 810    Args:
 811        pixel_x: X coordinate in pixels (column, increases right)
 812        pixel_y: Y coordinate in pixels (row, increases down)
 813        image_shape: Tuple of (height, width) of the ROI; only the width sets the scale
 814        dish_size_m: Size of square petri dish in meters (default 0.15m)
 815        dish_offset_m: Robot coordinates (x, y, z) of plate top-left corner in meters
 816                      Default (0.10775, 0.062, 0.175) from simulation specification
 817        
 818    Returns:
 819        tuple: (robot_x_m, robot_y_m, robot_z_m) in meters
 820    """
 821    height, width = image_shape[:2]
 822    
 823    # Calculate scale (meters per pixel)
 824    scale = dish_size_m / width
 825    
 826    # Convert to plate-relative meters
 827    plate_x = pixel_x * scale
 828    plate_y = pixel_y * scale
 829    
 830    # Map to robot coordinates (gantry alignment, no rotation)
 831    robot_x = plate_y + dish_offset_m[0]  # Image Y-axis → Robot X-axis
 832    robot_y = plate_x + dish_offset_m[1]  # Image X-axis → Robot Y-axis
 833    robot_z = dish_offset_m[2]
 834    
 835    return robot_x, robot_y, robot_z
 836
 837def match_roots_to_shoots_complete(shoot_mask, root_mask, image_path, config=None, 
 838                                   distance_threshold=150, prefer_topmost=True,
 839                                   min_score_threshold=0.3, verbose=False,
 840                                   return_dataframe=False, sample_idx=None,
 841                                   visual_debugging=False, output_path=None):
 842    """Complete pipeline: shoot and root masks in, 5 length measurements out.
 843    
 844    High-level wrapper that runs the entire matching pipeline:
 845    1. Extract shoot reference points
 846    2. Filter root mask by sampling box
 847    3. Extract root structures
 848    4. Find valid candidates for each shoot
 849    5. Assign roots to shoots greedily (left to right)
 850    6. Find best top nodes for assignments
 851    7. Calculate root lengths from top nodes
 852    
 853    Args:
 854        shoot_mask: Binary mask array with shoot regions (H, W)
 855        root_mask: Binary mask array with root regions (H, W)
 856        image_path: Path to original image for ROI detection
 857        config: RootMatchingConfig instance (uses defaults if None)
 858        distance_threshold: Max distance from shoot for top node selection (pixels)
 859        prefer_topmost: If True, prefer topmost node; if False, prefer closest
 860        min_score_threshold: Minimum score for valid assignment
 861        verbose: Print progress messages
 862        return_dataframe: If True, return pandas DataFrame; if False, return numpy array
 863        sample_idx: Sample index to include in DataFrame (only used if return_dataframe=True)
 864        visual_debugging: If True, display visualizations of assignments and root lengths
 865        output_path: Optional path to save visualization outputs (not yet implemented)
 866        
 867    Returns:
 868        If return_dataframe=False (default):
 869            np.ndarray: Array of 5 root lengths in pixels, ordered left to right.
 870        If return_dataframe=True:
 871            pd.DataFrame: DataFrame with columns:
 872                - 'plant_order': 1-5 (left to right)
 873                - 'Plant ID': test_image_{sample_idx + 1:02d}_plant_{i}, or unknown_plant_{i} when sample_idx is None
 874                - 'Length (px)': root length in pixels
 875                - 'length_px': root length in pixels (duplicate)
 876                - 'top_node_x': x-coordinate of top node (pixels)
 877                - 'top_node_y': y-coordinate of top node (pixels)
 878                - 'endpoint_x': x-coordinate of endpoint node (pixels)
 879                - 'endpoint_y': y-coordinate of endpoint node (pixels)
 880                - 'top_node_robot_x': x-coordinate of top node (meters)
 881                - 'top_node_robot_y': y-coordinate of top node (meters)
 882                - 'top_node_robot_z': z-coordinate of top node (meters)
 883                - 'endpoint_robot_x': x-coordinate of endpoint (meters)
 884                - 'endpoint_robot_y': y-coordinate of endpoint (meters)
 885                - 'endpoint_robot_z': z-coordinate of endpoint (meters)
 886            
 887    Warns:
 888        UserWarning: If shoot_mask does not contain exactly 5 shoots; processing continues
 889        
 890    Examples:
 891        >>> # Get numpy array
 892        >>> lengths = match_roots_to_shoots_complete(shoot_mask, root_mask, image_path)
 893        >>> 
 894        >>> # Get DataFrame
 895        >>> df = match_roots_to_shoots_complete(shoot_mask, root_mask, image_path,
 896        ...                                      return_dataframe=True, sample_idx=0)
 897        
 898    Notes:
 899        - Returns one measurement per detected shoot (5 expected)
 900        - Left-to-right order is determined by shoot x-position (centroid)
 901        - Robust to missing roots (returns 0.0) and noisy masks
 902    """
 903
 904    
 905    if config is None:
 906        config = RootMatchingConfig()
 907    
 908    if verbose:
 909        print("Starting complete root-shoot matching pipeline...")
 910    
 911    # Step 0: Detect ROI from original image
 912    if verbose:
 913        print("  Step 0: Detecting ROI from original image...")
 914    
 915    from library.roi import detect_roi
 916    from library.mask_processing import load_mask
 917    
 918    roi_bbox = detect_roi(load_mask(str(image_path)))
 919    (x1, y1), (x2, y2) = roi_bbox
 920    roi_width = x2 - x1
 921    roi_height = y2 - y1
 922    
 923    if verbose:
 924        print(f"    ROI bbox: ({x1}, {y1}) to ({x2}, {y2}), size={roi_width}x{roi_height}px")
 925    
 926    # Step 1: Get shoot reference points
 927    if verbose:
 928        print("  Step 1: Extracting shoot reference points...")
 929    labeled_shoots, num_shoots, ref_points = get_shoot_reference_points_multi(shoot_mask)
 930    
 931    if num_shoots != 5:
 932        warnings.warn(f"Expected 5 shoots, found {num_shoots}. Continuing anyway.", UserWarning)
 933    
 934    if verbose:
 935        print(f"    Found {num_shoots} shoots")
 936    
 937    # Step 2: Filter root mask by sampling box
 938    if verbose:
 939        print("  Step 2: Filtering root mask by sampling box...")
 940    filtered_root_mask = filter_root_mask_by_sampling_box(
 941        root_mask, shoot_mask, ref_points, num_shoots,
 942        sampling_buffer_above=config.sampling_buffer_above,
 943        sampling_buffer_below=config.sampling_buffer_below
 944    )
 945    
 946    # Step 3: Extract root structures
 947    if verbose:
 948        print("  Step 3: Extracting root structures...")
 949    structures = extract_root_structures(filtered_root_mask, verbose=verbose)
 950    
 951    if verbose:
 952        print(f"    Extracted {len(structures['roots'])} root structures")
 953    
 954    # Step 4: Find valid candidates
 955    if verbose:
 956        print("  Step 4: Finding valid candidates for each shoot...")
 957    shoot_candidates = find_valid_candidates_for_shoots(structures, ref_points, num_shoots, config)
 958    
 959    # Step 5: Assign roots to shoots
 960    if verbose:
 961        print("  Step 5: Assigning roots to shoots (greedy left-to-right)...")
 962    assignments = assign_roots_to_shoots_greedy(shoot_candidates, ref_points, num_shoots, 
 963                                                min_score_threshold=min_score_threshold)
 964    
 965    # Step 6 & 7: Find top nodes and calculate lengths
 966    if verbose:
 967        print("  Step 6-7: Finding top nodes and calculating lengths...")
 968    
 969    # Helper function to get node coordinates
 970    def get_node_coords(node_id, branch_data):
 971        """Extract (y, x) coordinates for a given node_id from branch_data."""
 972        for idx, row in branch_data.iterrows():
 973            if row['node-id-src'] == node_id:
 974                return (row['image-coord-src-0'], row['image-coord-src-1'])
 975            elif row['node-id-dst'] == node_id:
 976                return (row['image-coord-dst-0'], row['image-coord-dst-1'])
 977        return None
 978    
 979    # Build structure for process_matched_roots_to_lengths
 980    top_node_results = {}
 981    
 982    for shoot_label, info in assignments.items():
 983        root_label = info['root_label']
 984        
 985        if root_label is not None:
 986            # Find top node
 987            shoot_ref = ref_points[shoot_label]['bottom_center']
 988            top_node = find_best_top_node(None, shoot_ref, structures, root_label,
 989                                         distance_threshold=distance_threshold,
 990                                         prefer_topmost=prefer_topmost)
 991            
 992            if top_node is not None:
 993                branch_data = structures['roots'][root_label]['branch_data']
 994                
 995                top_node_results[root_label] = {
 996                    'shoot_label': shoot_label,
 997                    'top_nodes': [(top_node, 0)],
 998                    'branch_data': branch_data
 999                }
1000    
1001    # Calculate lengths and coordinates for all matched roots
1002    lengths_dict = {}
1003    top_coords_dict = {}
1004    endpoint_coords_dict = {}
1005    
1006    for root_label, result in top_node_results.items():
1007        branch_data = result['branch_data']
1008        shoot_label = result['shoot_label']
1009        top_node = result['top_nodes'][0][0]
1010        
1011        # Get top node coordinates
1012        top_coords = get_node_coords(top_node, branch_data)
1013        if top_coords:
1014            top_coords_dict[shoot_label] = top_coords
1015        
1016        try:
1017            # Find longest path from top node
1018            path = find_farthest_endpoint_path(
1019                branch_data,
1020                top_node,
1021                direction='down',
1022                use_smart_scoring=True,
1023                verbose=False
1024            )
1025            
1026            # Get endpoint node (last node in path)
1027            endpoint_node = path[-1][0]
1028            endpoint_coords = get_node_coords(endpoint_node, branch_data)
1029            if endpoint_coords:
1030                endpoint_coords_dict[shoot_label] = endpoint_coords
1031            
1032            # Calculate length
1033            root_length = calculate_skeleton_length_px(path)
1034            lengths_dict[shoot_label] = root_length
1035            
1036            if verbose:
1037                print(f"    Shoot {shoot_label}: Root {root_label}, length={root_length:.1f}px, "
1038                      f"top={top_coords}, endpoint={endpoint_coords}")
1039            
1040        except Exception as e:
1041            warnings.warn(f"Failed to calculate length for root {root_label} (shoot {shoot_label}): {e}", UserWarning)
1042            lengths_dict[shoot_label] = 0.0
1043            
1044            if verbose:
1045                print(f"    Shoot {shoot_label}: Root {root_label}, ERROR - returning 0.0")
1046    
1047    # Build output array ordered by shoot position (left to right)
1048    shoot_positions = []
1049    for shoot_label in range(1, num_shoots + 1):
1050        shoot_x = ref_points[shoot_label]['centroid'][1]
1051        shoot_positions.append((shoot_label, shoot_x))
1052    
1053    shoot_positions.sort(key=lambda x: x[1])
1054    
1055    # Create final arrays
1056    lengths_array = np.array([
1057        lengths_dict.get(shoot_label, 0.0)
1058        for shoot_label, _ in shoot_positions
1059    ])
1060    
1061    top_x_array = np.array([
1062        top_coords_dict.get(shoot_label, (np.nan, np.nan))[1]
1063        for shoot_label, _ in shoot_positions
1064    ])
1065    
1066    top_y_array = np.array([
1067        top_coords_dict.get(shoot_label, (np.nan, np.nan))[0]
1068        for shoot_label, _ in shoot_positions
1069    ])
1070    
1071    endpoint_x_array = np.array([
1072        endpoint_coords_dict.get(shoot_label, (np.nan, np.nan))[1]
1073        for shoot_label, _ in shoot_positions
1074    ])
1075    
1076    endpoint_y_array = np.array([
1077        endpoint_coords_dict.get(shoot_label, (np.nan, np.nan))[0]
1078        for shoot_label, _ in shoot_positions
1079    ])
1080    
1081    if verbose:
1082        print(f"\n  Final lengths (left to right): {lengths_array}")
1083        print("  Pipeline complete!")
1084    
1085    # Visual debugging
1086    if visual_debugging:
1087
1088        im_path = Path(image_path)
1089        print("="*50)
1090        print(f"\n  Generating visualizations for {im_path.name}")
1091
1092        print('   Original image')
1093        image = load_mask(str(im_path))
1094
1095        fig, ax = plt.subplots(figsize=(12, 8))
1096        ax.imshow(image, cmap='gray')
1097        ax.set_title(im_path.name, fontsize=14)
1098        ax.axis('off')
1099        plt.tight_layout()
1100        plt.show()
1101
1102        
1103        from library.root_analysis_visualization import visualize_assignments, visualize_root_lengths
1104        
1105        # Visualize assignments
1106        print('    Visualize shoot & root assignments')
1107        visualize_assignments(shoot_mask, root_mask, structures, assignments, ref_points)
1108        
1109        # Visualize root lengths with detailed views
1110        print('    Visualize root skeletons with node and edge network')
1111        visualize_root_lengths(structures, top_node_results, labeled_shoots, 
1112                              show_detailed_roots=True)
1113    
1114    # Return DataFrame or array
1115    if return_dataframe:
1116        import pandas as pd
1117        
1118        # Convert pixel coordinates to robot coordinates using ROI-relative system
1119        roi_shape = (roi_height, roi_width)
1120        
1121        top_robot_coords = [pixel_to_robot_coords(x - x1, y - y1, roi_shape) 
1122                           if not (np.isnan(x) or np.isnan(y)) else (np.nan, np.nan, np.nan)
1123                           for x, y in zip(top_x_array, top_y_array)]
1124        endpoint_robot_coords = [pixel_to_robot_coords(x - x1, y - y1, roi_shape) 
1125                                if not (np.isnan(x) or np.isnan(y)) else (np.nan, np.nan, np.nan)
1126                                for x, y in zip(endpoint_x_array, endpoint_y_array)]
1127        
1128        # Unpack into separate arrays
1129        top_robot_x = np.array([c[0] for c in top_robot_coords])
1130        top_robot_y = np.array([c[1] for c in top_robot_coords])
1131        top_robot_z = np.array([c[2] for c in top_robot_coords])
1132        
1133        endpoint_robot_x = np.array([c[0] for c in endpoint_robot_coords])
1134        endpoint_robot_y = np.array([c[1] for c in endpoint_robot_coords])
1135        endpoint_robot_z = np.array([c[2] for c in endpoint_robot_coords])
1136        
1137        df_data = {
1138            'plant_order': list(range(1, len(lengths_array) + 1)),
1139            'Plant ID': [f'test_image_{sample_idx + 1:02d}_plant_{i}' if sample_idx is not None else f'unknown_plant_{i}' for i in range(1, len(lengths_array) + 1)],
1140            'Length (px)': lengths_array,
1141            'length_px': lengths_array,
1142            'top_node_x': top_x_array,
1143            'top_node_y': top_y_array,
1144            'endpoint_x': endpoint_x_array,
1145            'endpoint_y': endpoint_y_array,
1146            'top_node_robot_x': top_robot_x,
1147            'top_node_robot_y': top_robot_y,
1148            'top_node_robot_z': top_robot_z,
1149            'endpoint_robot_x': endpoint_robot_x,
1150            'endpoint_robot_y': endpoint_robot_y,
1151            'endpoint_robot_z': endpoint_robot_z
1152        }
1153        
1154        return pd.DataFrame(df_data)
1155    else:
1156        return lengths_array
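
The greedy left-to-right assignment used in step 5 of the pipeline can be illustrated on toy data. The sketch below is a simplified stand-in, not the module's function: `assign_greedy`, the plain float scores, and the sample labels are all hypothetical, and real candidates carry a full score dict rather than a single number.

```python
def assign_greedy(shoot_x, candidates, min_score=0.3):
    """Assign at most one root to each shoot, visiting shoots left to right.

    shoot_x:    {shoot_label: x position of the shoot centroid}
    candidates: {shoot_label: [(root_label, score), ...]}, sorted by score descending
    """
    assigned = set()   # roots already taken by a shoot further left
    result = {}
    for order, shoot in enumerate(sorted(shoot_x, key=shoot_x.get)):
        pick = None
        for root, score in candidates.get(shoot, []):
            if root not in assigned and score >= min_score:
                pick = root
                break  # list is sorted, so the first hit is the best one
        if pick is not None:
            assigned.add(pick)
        result[shoot] = {"root_label": pick, "order": order}
    return result


# Hypothetical toy data: three shoots competing for three roots.
shoot_x = {1: 900.0, 2: 300.0, 3: 600.0}
candidates = {
    1: [(11, 0.8)],
    2: [(10, 0.9), (11, 0.4)],
    3: [(10, 0.7), (12, 0.2)],
}

assignments = assign_greedy(shoot_x, candidates)
for shoot, info in sorted(assignments.items(), key=lambda kv: kv[1]["order"]):
    print(shoot, info)
```

Shoot 2 is leftmost and takes root 10. Shoot 3 then finds root 10 already taken and root 12 below the threshold, so it ends up with `None`, which the pipeline reports as a length of 0.0. Because the pass is greedy, an early shoot can claim a root that a later shoot would have scored higher.
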
class GetIm:
133class GetIm():
134    """Helper class for loading and displaying shoot and root mask pairs.
135    
136    Provides convenient access to paired shoot and root masks from sorted file lists,
137    with methods for visualization. Handles both 1-indexed file numbers and 0-indexed
138    array indices.
139    
140    Args:
141        shoots (list): Sorted list of shoot mask file paths
142        roots (list): Sorted list of root mask file paths (same length as shoots)
143    
144    Attributes:
145        shoots (list): Stored shoot mask file paths
146        roots (list): Stored root mask file paths
147    
148    Methods:
149        show: Display shoot and root masks side by side
150        show_overlay: Display masks overlaid in single image with grid
151        __call__: Load masks directly (same as _load_masks)
152    
153    Examples:
154        >>> from pathlib import Path
155        >>> shoots = sorted(Path('data/shoots').glob('*.png'))
156        >>> roots = sorted(Path('data/roots').glob('*.png'))
157        >>> getter = GetIm(shoots=shoots, roots=roots)
158        >>> 
159        >>> # Display by file number (1-indexed)
160        >>> getter.show(1)
161        >>> 
162        >>> # Display by array index (0-indexed)
163        >>> getter.show(0, is_idx=True)
164        >>> 
165        >>> # Load masks directly
166        >>> shoot_mask, root_mask = getter(0, is_idx=True)
167    
168    Notes:
169        - File numbers are 1-indexed by default (file_num=1 loads first file)
170        - Array indices are 0-indexed when is_idx=True (is_idx=True, file_num=0 loads first file)
171        - Shoot and root lists must be the same length and correspond to matching pairs
172    """    
173    def __init__(self, shoots, roots):
174        self.shoots = shoots
175        self.roots = roots
176    
177    def _load_masks(self, file_num, is_idx=False, print_f=False):
178        if not is_idx:
179            file_num = file_num - 1
180        
181
182        shoot_f = str(self.shoots[file_num])
183        root_f = str(self.roots[file_num])
184        if print_f:
185            print(shoot_f, '\n', root_f)
186        return (load_mask(shoot_f), load_mask(root_f))
187
188    def show(self, file_num, is_idx=False, size=(20, 8)):
189        """Display shoot and root masks side by side with filenames as titles."""
190        s, r = self._load_masks(file_num, is_idx, print_f=False)
191        
192        # Get filenames
193        if not is_idx:
194            file_num = file_num - 1
195        
196        shoot_name = Path(self.shoots[file_num]).name
197        root_name = Path(self.roots[file_num]).name
198        
199        # Create side-by-side plot
200        fig, axes = plt.subplots(1, 2, figsize=size)
201        
202        axes[0].imshow(s, cmap='gray')
203        axes[0].set_title(shoot_name, fontsize=12)
204        axes[0].axis('off')
205        
206        axes[1].imshow(r, cmap='gray')
207        axes[1].set_title(root_name, fontsize=12)
208        axes[1].axis('off')
209        
210        plt.tight_layout()
211        plt.show()
212
213    def show_overlay(self, file_num, is_idx=False, size=(12, 8)):
214        """Display shoot (green) and root (red) masks overlaid in single image with grid."""
215        s, r = self._load_masks(file_num, is_idx, print_f=False)
216        
217        # Get filename
218        if not is_idx:
219            file_num = file_num - 1
220        
221        shoot_name = Path(self.shoots[file_num]).name
222        root_name = Path(self.roots[file_num]).name
223        combined_title = f"{shoot_name} + {root_name}"
224        
225        # Create RGB image
226        h, w = s.shape
227        rgb_img = np.zeros((h, w, 3), dtype=np.uint8)
228        
229        # Apply shoot mask as green
230        rgb_img[s > 0] = [0, 255, 0]
231        
232        # Apply root mask as red (on top)
233        rgb_img[r > 0] = [255, 0, 0]
234        
235        # Display
236        fig, ax = plt.subplots(figsize=size)
237        ax.imshow(rgb_img)
238        ax.set_title(combined_title, fontsize=12)
239        
240        # Add grid lines every 200 pixels
241        ax.set_xticks(np.arange(0, w, 200))
242        ax.set_yticks(np.arange(0, h, 200))
243        ax.grid(True, color='blue', linewidth=0.5, alpha=0.5)
244        
245        plt.tight_layout()
246        plt.show()
247
248    def __call__(self, file_num, is_idx=False, print_f=False):
249        """
250        Args:
251            file_num(int): file number
252            is_idx(bool): treat file_num as array index when true
253            print_f(bool): print filenames when true
254
255        Returns:
256            tuple(shoot_mask, root_mask)
257        """
258        return self._load_masks(file_num, is_idx, print_f=print_f)

Helper class for loading and displaying shoot and root mask pairs.

Provides convenient access to paired shoot and root masks from sorted file lists, with methods for visualization. Handles both 1-indexed file numbers and 0-indexed array indices.

Arguments:
  • shoots (list): Sorted list of shoot mask file paths
  • roots (list): Sorted list of root mask file paths (same length as shoots)
Attributes:
  • shoots (list): Stored shoot mask file paths
  • roots (list): Stored root mask file paths
Methods:

  • show: Display shoot and root masks side by side
  • show_overlay: Display masks overlaid in single image with grid
  • __call__: Load masks directly (same as _load_masks)

Examples:
>>> from pathlib import Path
>>> shoots = sorted(Path('data/shoots').glob('*.png'))
>>> roots = sorted(Path('data/roots').glob('*.png'))
>>> getter = GetIm(shoots=shoots, roots=roots)
>>> 
>>> # Display by file number (1-indexed)
>>> getter.show(1)
>>> 
>>> # Display by array index (0-indexed)
>>> getter.show(0, is_idx=True)
>>> 
>>> # Load masks directly
>>> shoot_mask, root_mask = getter(0, is_idx=True)
Notes:
  • File numbers are 1-indexed by default (file_num=1 loads first file)
  • Array indices are 0-indexed when is_idx=True (is_idx=True, file_num=0 loads first file)
  • Shoot and root lists must be the same length and correspond to matching pairs
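
The two indexing conventions in the notes above reduce to a one-line conversion. The following is a minimal sketch of that rule; `resolve_index` and the file names are hypothetical and only mirror what `_load_masks` does with `file_num` and `is_idx`.

```python
def resolve_index(file_num, is_idx=False):
    """Return the 0-based list index for a file number or an array index."""
    return file_num if is_idx else file_num - 1


files = ["mask_01.png", "mask_02.png", "mask_03.png"]  # hypothetical names

first_by_number = files[resolve_index(1)]               # 1-indexed file number
first_by_index = files[resolve_index(0, is_idx=True)]   # 0-indexed array index
wrapped = files[resolve_index(0)]                        # index -1: the last file

print(first_by_number, first_by_index, wrapped)
```

Passing `file_num=0` without `is_idx=True` yields index -1, which Python resolves to the last element instead of raising an error, so that combination silently loads the wrong pair.
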
GetIm(shoots, roots)
shoots
roots
def show(self, file_num, is_idx=False, size=(20, 8)):

Display shoot and root masks side by side with filenames as titles.

def show_overlay(self, file_num, is_idx=False, size=(12, 8)):

Display shoot (green) and root (red) masks overlaid in single image with grid.
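
The overlay colouring can be reproduced on a tiny pair of masks to show which colour wins where the two overlap. This sketch uses made-up 2x3 arrays and skips the plotting; it follows the same two assignment steps as `show_overlay`.

```python
import numpy as np

# Hypothetical 2x3 masks; nonzero marks foreground.
shoot = np.array([[1, 1, 0],
                  [0, 1, 0]], dtype=np.uint8)
root = np.array([[0, 1, 0],
                 [0, 1, 1]], dtype=np.uint8)

h, w = shoot.shape
rgb = np.zeros((h, w, 3), dtype=np.uint8)

rgb[shoot > 0] = [0, 255, 0]   # shoot pixels in green
rgb[root > 0] = [255, 0, 0]    # root pixels in red, written last so they win overlaps

print(rgb[0, 0], rgb[0, 1], rgb[1, 0])
```

Because the root mask is applied second, any pixel present in both masks is drawn red, so shoot area hidden under a root is not visible in the overlay.
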

@dataclass
class RootMatchingConfig:
@dataclass
class RootMatchingConfig:
    """Configuration parameters for root-shoot matching.

    Attributes:
        sampling_buffer_above: Pixels above top shoot to include in sampling box
        sampling_buffer_below: Pixels below bottom shoot to include in sampling box
        distance_weight: Weight for proximity score (0-1)
        size_weight: Weight for skeleton length score (0-1)
        verticality_weight: Weight for vertical orientation score (0-1)
        max_horizontal_offset: Maximum horizontal distance from shoot (pixels)
        max_above_shoot: Allow roots to extend this many pixels above shoot bottom
        max_below_shoot: Maximum distance below shoot bottom where root can start (pixels)
        min_skeleton_pixels: Minimum skeleton pixels to be a valid candidate
        edge_exclusion_zones: List of (min_x, max_x) tuples for edge noise regions
        max_distance: Distance normalization constant (pixels)
        max_size: Size normalization constant (pixels)
        ideal_aspect: Ideal vertical:horizontal ratio for roots
    """
    sampling_buffer_above: int = 200
    sampling_buffer_below: int = 300

    distance_weight: float = 0.35
    size_weight: float = 0.35
    verticality_weight: float = 0.3

    max_horizontal_offset: int = 200
    max_above_shoot: int = 200
    max_below_shoot: int = 100
    min_skeleton_pixels: int = 6
    edge_exclusion_zones: list = field(default_factory=lambda: [(0, 900), (3300, 4000)])

    max_distance: int = 100
    max_size: int = 1000
    ideal_aspect: float = 5.0

Configuration parameters for root-shoot matching.

Attributes:
  • sampling_buffer_above: Pixels above top shoot to include in sampling box
  • sampling_buffer_below: Pixels below bottom shoot to include in sampling box
  • distance_weight: Weight for proximity score (0-1)
  • size_weight: Weight for skeleton length score (0-1)
  • verticality_weight: Weight for vertical orientation score (0-1)
  • max_horizontal_offset: Maximum horizontal distance from shoot (pixels)
  • max_above_shoot: Allow roots to extend this many pixels above shoot bottom
  • max_below_shoot: Maximum distance below shoot bottom where root can start (pixels)
  • min_skeleton_pixels: Minimum skeleton pixels to be a valid candidate
  • edge_exclusion_zones: List of (min_x, max_x) tuples for edge noise regions
  • max_distance: Distance normalization constant (pixels)
  • max_size: Size normalization constant (pixels)
  • ideal_aspect: Ideal vertical:horizontal ratio for roots
RootMatchingConfig(sampling_buffer_above: int = 200, sampling_buffer_below: int = 300, distance_weight: float = 0.35, size_weight: float = 0.35, verticality_weight: float = 0.3, max_horizontal_offset: int = 200, max_above_shoot: int = 200, max_below_shoot: int = 100, min_skeleton_pixels: int = 6, edge_exclusion_zones: list = <factory>, max_distance: int = 100, max_size: int = 1000, ideal_aspect: float = 5.0)
sampling_buffer_above: int = 200
sampling_buffer_below: int = 300
distance_weight: float = 0.35
size_weight: float = 0.35
verticality_weight: float = 0.3
max_horizontal_offset: int = 200
max_above_shoot: int = 200
max_below_shoot: int = 100
min_skeleton_pixels: int = 6
edge_exclusion_zones: list
max_distance: int = 100
max_size: int = 1000
ideal_aspect: float = 5.0
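
As a usage sketch, the defaults can be overridden per experiment. The stand-in dataclass below (`DemoMatchingConfig`, a hypothetical name) copies only the three weight fields and their defaults from `RootMatchingConfig` so that the snippet runs on its own; with the real module you would construct `RootMatchingConfig` directly. The default weights sum to 1.0, and the example keeps that property when shifting weight toward proximity, although the dataclass itself does not enforce it.

```python
from dataclasses import dataclass, replace


@dataclass
class DemoMatchingConfig:
    """Stand-in mirroring the weight fields of RootMatchingConfig (illustration only)."""
    distance_weight: float = 0.35
    size_weight: float = 0.35
    verticality_weight: float = 0.3


def weight_sum(cfg):
    """Sum of the three scoring weights."""
    return cfg.distance_weight + cfg.size_weight + cfg.verticality_weight


default_cfg = DemoMatchingConfig()

# Derive a variant that favours proximity; the original instance is untouched
proximity_cfg = replace(default_cfg, distance_weight=0.5, size_weight=0.2)

print(weight_sum(default_cfg), weight_sum(proximity_cfg))
```

`dataclasses.replace` returns a new instance, which makes it easy to compare several parameter sets side by side.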
def get_shoot_reference_points_multi(shoot_mask):
def get_shoot_reference_points_multi(shoot_mask):
    """Calculate multiple reference points for each shoot region.

    Args:
        shoot_mask: Binary mask array with shoot regions (H, W)

    Returns:
        tuple: (labeled_shoots, num_shoots, ref_points) where:
            - labeled_shoots: Array with unique label per shoot region
            - num_shoots: Number of distinct shoot regions found
            - ref_points: Dict with structure {shoot_label: {
                'centroid': (y, x),
                'bottom_center': (y, x),
                'bottom_most': (y, x),
                'bbox': (min_y, max_y, min_x, max_x)
              }}

    Examples:
        >>> labeled_shoots, num_shoots, ref_points = get_shoot_reference_points_multi(shoot_mask)
        >>> print(f"Found {num_shoots} shoots")
        >>> print(ref_points[1]['bottom_center'])
    """
    # Label shoots with 8-connectivity so diagonally touching pixels join one region
    structure = ndimage.generate_binary_structure(2, 2)  # 2D, 8-connectivity
    labeled_shoots, num_shoots = ndimage.label(shoot_mask, structure=structure)

    ref_points = {}

    for label in range(1, num_shoots + 1):
        shoot_region = labeled_shoots == label
        y_coords, x_coords = np.where(shoot_region)

        if len(y_coords) == 0:
            continue

        # Centroid
        centroid_y = np.mean(y_coords)
        centroid_x = np.mean(x_coords)

        # Bounding box
        min_y, max_y = y_coords.min(), y_coords.max()
        min_x, max_x = x_coords.min(), x_coords.max()

        # Bottom center: centroid x, maximum y (lowest point)
        bottom_center = (max_y, centroid_x)

        # Bottom most: actual bottom-most pixel (argmax returns the first pixel
        # on the lowest row, i.e. the leftmost one)
        bottom_idx = np.argmax(y_coords)
        bottom_most = (y_coords[bottom_idx], x_coords[bottom_idx])

        ref_points[label] = {
            'centroid': (centroid_y, centroid_x),
            'bottom_center': bottom_center,
            'bottom_most': bottom_most,
            'bbox': (min_y, max_y, min_x, max_x)
        }

    return labeled_shoots, num_shoots, ref_points

Calculate multiple reference points for each shoot region.

Arguments:
  • shoot_mask: Binary mask array with shoot regions (H, W)
Returns:

tuple: (labeled_shoots, num_shoots, ref_points) where:
  • labeled_shoots: Array with unique label per shoot region
  • num_shoots: Number of distinct shoot regions found
  • ref_points: Dict with structure {shoot_label: {'centroid': (y, x), 'bottom_center': (y, x), 'bottom_most': (y, x), 'bbox': (min_y, max_y, min_x, max_x)}}

Examples:
>>> labeled_shoots, num_shoots, ref_points = get_shoot_reference_points_multi(shoot_mask)
>>> print(f"Found {num_shoots} shoots")
>>> print(ref_points[1]['bottom_center'])
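
To make the difference between the four reference points concrete, here is a minimal sketch that recomputes them for one toy shoot. It is not the module's code: it uses a plain Python list of (y, x) pixels instead of a labeled NumPy array, and `reference_points` is a hypothetical helper name.

```python
def reference_points(pixels):
    """Compute reference points from (y, x) pixels listed in row-major order."""
    ys = [y for y, _ in pixels]
    xs = [x for _, x in pixels]
    centroid = (sum(ys) / len(ys), sum(xs) / len(xs))
    return {
        'centroid': centroid,
        # Lowest row combined with the centroid column
        'bottom_center': (max(ys), centroid[1]),
        # max() keeps the first maximal element, so ties on the lowest row
        # resolve to the leftmost pixel when pixels are in row-major order
        'bottom_most': max(pixels, key=lambda p: p[0]),
        'bbox': (min(ys), max(ys), min(xs), max(xs)),
    }


# An L-shaped toy shoot: a vertical stem at x=2 with a foot extending left on row 2
shoot_pixels = [(0, 2), (1, 2), (2, 0), (2, 1), (2, 2)]
points = reference_points(shoot_pixels)
print(points)
```

For this shape `bottom_center` sits at column 1.4 while `bottom_most` sits at column 0, which shows why the two can disagree for asymmetric shoots.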
def filter_root_mask_by_sampling_box(root_mask, shoot_mask, ref_points, num_shoots, sampling_buffer_above=200, sampling_buffer_below=300):
def filter_root_mask_by_sampling_box(root_mask, shoot_mask, ref_points, num_shoots,
                                     sampling_buffer_above=200,
                                     sampling_buffer_below=300):
    """Filter root mask to keep only components starting within sampling box.

    Creates a sampling box around all shoots and keeps only connected components
    whose top (min_y) falls within this box. Preserves full component length even
    if it extends beyond the box boundaries.

    Args:
        root_mask: Binary mask array with root regions (H, W)
        shoot_mask: Binary mask array with shoot regions (H, W);
            not used by the current implementation
        ref_points: Dict from get_shoot_reference_points_multi
        num_shoots: Number of shoot regions
        sampling_buffer_above: Pixels above top shoot to include
        sampling_buffer_below: Pixels below bottom shoot to include

    Returns:
        np.ndarray: Binary mask with filtered root components, preserving full lengths

    Examples:
        >>> filtered_mask = filter_root_mask_by_sampling_box(
        ...     root_mask, shoot_mask, ref_points, 5
        ... )
        >>> structures = extract_root_structures(filtered_mask)
    """
    # Calculate sampling box (row range), clamped to the image
    shoot_y_min = min(ref_points[s]['bbox'][0] for s in range(1, num_shoots + 1))
    shoot_y_max = max(ref_points[s]['bbox'][1] for s in range(1, num_shoots + 1))

    sampling_box_top = max(0, shoot_y_min - sampling_buffer_above)
    sampling_box_bottom = min(root_mask.shape[0], shoot_y_max + sampling_buffer_below)

    # Label all connected components in the FULL mask
    labeled_mask = label_components(root_mask)
    num_components = labeled_mask.max()

    print(f"Sampling box: rows {sampling_box_top:.0f} to {sampling_box_bottom:.0f}")
    print(f"Found {num_components} components in full mask")

    # Check which components have their TOP (min_y) in the sampling box
    valid_labels = set()

    for component_label in range(1, num_components + 1):
        component_mask = labeled_mask == component_label
        coords = np.argwhere(component_mask)

        if len(coords) == 0:
            continue

        comp_top_y = coords[:, 0].min()

        # Keep component if its top is in the sampling box
        if sampling_box_top <= comp_top_y <= sampling_box_bottom:
            valid_labels.add(component_label)

    # Create filtered mask with FULL components (not cropped)
    filtered_mask = np.isin(labeled_mask, list(valid_labels))

    print(f"Kept {len(valid_labels)} components (full length preserved)")

    return filtered_mask.astype(bool)

Filter root mask to keep only components starting within sampling box.

Creates a sampling box (a range of rows) around all shoots and keeps only connected components of the root mask whose top (min_y) falls within this box. Preserves full component length even if it extends beyond the box boundaries.

Arguments:
  • root_mask: Binary mask array with root regions (H, W)
  • shoot_mask: Binary mask array with shoot regions (H, W); not used by the current implementation
  • ref_points: Dict from get_shoot_reference_points_multi
  • num_shoots: Number of shoot regions
  • sampling_buffer_above: Pixels above top shoot to include
  • sampling_buffer_below: Pixels below bottom shoot to include
Returns:

np.ndarray: Binary mask with filtered root components, preserving full lengths

Examples:
>>> filtered_mask = filter_root_mask_by_sampling_box(
...     root_mask, shoot_mask, ref_points, 5
... )
>>> structures = extract_root_structures(filtered_mask)
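
The selection rule can be sketched without any image data. The snippet below is an illustration, not the module's implementation: it assumes each component is already summarized by its top row, and `sampling_box` and `keep_components` are hypothetical helper names.

```python
def sampling_box(shoot_bboxes, image_height, buffer_above=200, buffer_below=300):
    """Row range around all shoots, clamped to the image."""
    top = max(0, min(b[0] for b in shoot_bboxes) - buffer_above)
    bottom = min(image_height, max(b[1] for b in shoot_bboxes) + buffer_below)
    return top, bottom


def keep_components(component_tops, box):
    """Labels of components whose top row lies inside the box (inclusive)."""
    box_top, box_bottom = box
    return {label for label, top_row in component_tops.items()
            if box_top <= top_row <= box_bottom}


# Two shoots given as (min_y, max_y, min_x, max_x)
shoot_bboxes = [(500, 640, 1000, 1100), (520, 650, 1600, 1700)]
box = sampling_box(shoot_bboxes, image_height=3000)

# Component label -> top row of that component
component_tops = {1: 320, 2: 980, 3: 100, 4: 950}
kept = keep_components(component_tops, box)
print(box, kept)
```

Only the top row is tested, so a component that starts inside the box is kept in full no matter how far down it extends.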
def score_skeleton_for_shoot(component_props, shoot_ref_point, shoot_bbox, max_distance=100, max_size=1000, ideal_aspect=5.0, distance_weight=0.5, size_weight=0.2, verticality_weight=0.3, max_horizontal_offset=400, max_above_shoot=200, max_below_shoot=100, min_skeleton_pixels=6, edge_exclusion_zones=[(0, 900), (3300, 4000)]):
def score_skeleton_for_shoot(component_props, shoot_ref_point, shoot_bbox,
                              max_distance=100, max_size=1000, ideal_aspect=5.0,
                              distance_weight=0.5, size_weight=0.2, verticality_weight=0.3,
                              max_horizontal_offset=400,
                              max_above_shoot=200,
                              max_below_shoot=100,
                              min_skeleton_pixels=6,
                              edge_exclusion_zones=[(0, 900), (3300, 4000)]):
    """Score how likely a skeleton component is to be a root for a given shoot.

    Applies filters (minimum size, vertical position relative to the shoot,
    horizontal distance, edge zones) and calculates a combined score based on
    proximity, size, and verticality. Higher scores indicate better matches.

    Args:
        component_props: Dict with keys:
            - 'label': Component identifier
            - 'centroid': (y, x) tuple
            - 'bbox': (min_y, max_y, min_x, max_x) tuple
            - 'num_pixels': Total skeleton pixels
            - 'vertical_extent': Height in pixels
            - 'horizontal_extent': Width in pixels
        shoot_ref_point: (y, x) tuple for shoot reference position
        shoot_bbox: (min_y, max_y, min_x, max_x) tuple for shoot bounding box
        max_distance: Distance normalization constant (pixels)
        max_size: Size normalization constant (pixels)
        ideal_aspect: Ideal vertical:horizontal ratio for roots
        distance_weight: Weight for proximity score (0-1)
        size_weight: Weight for skeleton length (0-1)
        verticality_weight: Weight for vertical orientation (0-1)
        max_horizontal_offset: Maximum horizontal distance from shoot (pixels)
        max_above_shoot: Allow roots to extend this many pixels above shoot bottom
        max_below_shoot: Maximum distance below shoot bottom where root can start (pixels)
        min_skeleton_pixels: Minimum skeleton pixels to be valid candidate
        edge_exclusion_zones: List of (min_x, max_x) tuples for edge noise

    Returns:
        dict: Scoring results with keys:
            - 'score': Combined score (higher is better), -1 if invalid
            - 'distance': Euclidean distance from shoot
            - 'distance_score': Normalized distance component
            - 'size_score': Normalized size component
            - 'verticality_score': Normalized verticality component
            - 'aspect_ratio': Vertical:horizontal ratio (valid candidates only)
            - 'reason': Rejection reason (only present if score is -1)

    Examples:
        >>> score_result = score_skeleton_for_shoot(comp_props, shoot_ref, shoot_bbox)
        >>> if score_result['score'] > 0:
        ...     print(f"Valid candidate with score {score_result['score']:.3f}")
    """
    shoot_y, shoot_x = shoot_ref_point
    shoot_bottom = shoot_bbox[1]

    comp_centroid_y, comp_centroid_x = component_props['centroid']
    comp_top_y = component_props['bbox'][0]
    comp_bottom_y = component_props['bbox'][1]
    comp_bbox = component_props['bbox']
    comp_pixels = component_props['num_pixels']
    vertical_extent = component_props['vertical_extent']
    horizontal_extent = component_props['horizontal_extent']

    # FILTER 0: Minimum size filter (reject noise)
    if comp_pixels < min_skeleton_pixels:
        return {
            'score': -1,
            'distance': float('inf'),
            'distance_score': 0,
            'size_score': 0,
            'verticality_score': 0,
            'reason': 'too_small'
        }

    # FILTER 1: Reject components that end more than max_above_shoot above the shoot bottom
    if comp_bottom_y < (shoot_bottom - max_above_shoot):
        return {
            'score': -1,
            'distance': float('inf'),
            'distance_score': 0,
            'size_score': 0,
            'verticality_score': 0,
            'reason': 'above_shoot'
        }

    # FILTER 2: Root must start within reasonable distance below shoot
    if comp_top_y > (shoot_bottom + max_below_shoot):
        return {
            'score': -1,
            'distance': float('inf'),
            'distance_score': 0,
            'size_score': 0,
            'verticality_score': 0,
            'reason': 'too_far_below_shoot'
        }

    # FILTER 3: Component bbox must overlap the horizontal search zone around the shoot
    comp_min_x = comp_bbox[2]
    comp_max_x = comp_bbox[3]
    search_zone_min_x = shoot_x - max_horizontal_offset
    search_zone_max_x = shoot_x + max_horizontal_offset

    if comp_max_x < search_zone_min_x or comp_min_x > search_zone_max_x:
        return {
            'score': -1,
            'distance': float('inf'),
            'distance_score': 0,
            'size_score': 0,
            'verticality_score': 0,
            'reason': 'too_far_horizontally'
        }

    # FILTER 4: Reject components whose centroid x lies in an edge exclusion zone
    for min_edge_x, max_edge_x in edge_exclusion_zones:
        if min_edge_x <= comp_centroid_x <= max_edge_x:
            return {
                'score': -1,
                'distance': float('inf'),
                'distance_score': 0,
                'size_score': 0,
                'verticality_score': 0,
                'reason': 'in_edge_zone'
            }

    # Calculate scores; distance runs from the shoot reference point to the
    # component's top row at its centroid column
    distance = np.sqrt((comp_top_y - shoot_y)**2 + (comp_centroid_x - shoot_x)**2)
    distance_score = 1.0 / (1.0 + distance / max_distance)

    size_score = min(comp_pixels / max_size, 1.0)

    if horizontal_extent > 0:
        aspect_ratio = vertical_extent / horizontal_extent
        verticality_score = min(aspect_ratio / ideal_aspect, 1.0)
    else:
        # Zero width: a perfectly vertical line gets the maximum score
        verticality_score = 1.0

    combined = (distance_weight * distance_score +
                size_weight * size_score +
                verticality_weight * verticality_score)

    return {
        'score': combined,
        'distance': distance,
        'distance_score': distance_score,
        'size_score': size_score,
        'verticality_score': verticality_score,
        'aspect_ratio': vertical_extent / max(horizontal_extent, 1)
    }

Score how likely a skeleton component is to be a root for a given shoot.

Applies filters (minimum size, vertical position relative to the shoot, horizontal distance, edge zones) and calculates a combined score based on proximity, size, and verticality. Higher scores indicate better matches.

Arguments:
  • component_props: Dict with keys:
    • 'label': Component identifier
    • 'centroid': (y, x) tuple
    • 'bbox': (min_y, max_y, min_x, max_x) tuple
    • 'num_pixels': Total skeleton pixels
    • 'vertical_extent': Height in pixels
    • 'horizontal_extent': Width in pixels
  • shoot_ref_point: (y, x) tuple for shoot reference position
  • shoot_bbox: (min_y, max_y, min_x, max_x) tuple for shoot bounding box
  • max_distance: Distance normalization constant (pixels)
  • max_size: Size normalization constant (pixels)
  • ideal_aspect: Ideal vertical:horizontal ratio for roots
  • distance_weight: Weight for proximity score (0-1)
  • size_weight: Weight for skeleton length (0-1)
  • verticality_weight: Weight for vertical orientation (0-1)
  • max_horizontal_offset: Maximum horizontal distance from shoot (pixels)
  • max_above_shoot: Allow roots to extend this many pixels above shoot bottom
  • max_below_shoot: Maximum distance below shoot bottom where root can start (pixels)
  • min_skeleton_pixels: Minimum skeleton pixels to be valid candidate
  • edge_exclusion_zones: List of (min_x, max_x) tuples for edge noise
Returns:

dict: Scoring results with keys:
  • 'score': Combined score (higher is better), -1 if invalid
  • 'distance': Euclidean distance from the shoot reference point to the component's top row at its centroid column; infinity if invalid
  • 'distance_score': Normalized distance component
  • 'size_score': Normalized size component
  • 'verticality_score': Normalized verticality component
  • 'aspect_ratio': Vertical:horizontal ratio (valid candidates only)
  • 'reason': Rejection reason (only present if score is -1)

Examples:
>>> score_result = score_skeleton_for_shoot(comp_props, shoot_ref, shoot_bbox)
>>> if score_result['score'] > 0:
...     print(f"Valid candidate with score {score_result['score']:.3f}")
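
As a worked example of the scoring arithmetic for a candidate that passes all filters, the sketch below reproduces the three component scores and their weighted sum with the function's default constants. `combined_score` is a hypothetical stand-alone helper that takes the vertical and horizontal offsets directly rather than a `component_props` dict.

```python
import math


def combined_score(dy, dx, num_pixels, vertical_extent, horizontal_extent,
                   max_distance=100, max_size=1000, ideal_aspect=5.0,
                   distance_weight=0.5, size_weight=0.2, verticality_weight=0.3):
    """Weighted sum of proximity, size, and verticality scores (filters omitted)."""
    distance = math.hypot(dy, dx)
    # 1.0 at zero distance, 0.5 at max_distance, approaching 0 further away
    distance_score = 1.0 / (1.0 + distance / max_distance)
    # Saturates at 1.0 once the skeleton reaches max_size pixels
    size_score = min(num_pixels / max_size, 1.0)
    if horizontal_extent > 0:
        aspect = vertical_extent / horizontal_extent
        verticality_score = min(aspect / ideal_aspect, 1.0)
    else:
        verticality_score = 1.0
    return (distance_weight * distance_score
            + size_weight * size_score
            + verticality_weight * verticality_score)


# Top of the root is 30 px below and 40 px to the side of the shoot reference
# point (distance 50), with 500 skeleton pixels and a 400 x 100 bounding box
score = combined_score(dy=30, dx=40, num_pixels=500,
                       vertical_extent=400, horizontal_extent=100)
print(round(score, 4))
```

Here the component scores are 1/1.5, 0.5, and 0.8, so the result is 0.5 × 0.6667 + 0.2 × 0.5 + 0.3 × 0.8 ≈ 0.6733. Note that these function defaults differ from the `RootMatchingConfig` defaults (0.35, 0.35, 0.3).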
def find_valid_candidates_for_shoots(structures, ref_points, num_shoots, config=None):
def find_valid_candidates_for_shoots(structures, ref_points, num_shoots, config=None):
    """Find and score valid root candidates for each shoot.

    Scores each skeleton component against each shoot and keeps those with a
    positive score. Any sampling-box pre-filtering is expected to have been
    applied to the root mask beforehand. Returns candidates sorted by score.

    Args:
        structures: Dict from extract_root_structures with keys:
            - 'skeleton': Full skeleton array
            - 'labeled_skeleton': Labeled skeleton array
            - 'unique_labels': Array of label IDs
            - 'roots': Dict of root data by label
        ref_points: Dict from get_shoot_reference_points_multi
        num_shoots: Number of shoots (typically 5)
        config: RootMatchingConfig instance (uses defaults if None)

    Returns:
        dict: {shoot_label: [(root_label, score_dict), ...]} where each shoot's
              list is sorted by score descending

    Examples:
        >>> candidates = find_valid_candidates_for_shoots(structures, ref_points, 5)
        >>> for shoot_label in range(1, 6):
        ...     print(f"Shoot {shoot_label}: {len(candidates[shoot_label])} candidates")
    """
    if config is None:
        config = RootMatchingConfig()

    # Storage for each shoot's candidates
    shoot_candidates = {label: [] for label in range(1, num_shoots + 1)}

    # Process each skeleton component
    for root_label, root_data in structures['roots'].items():
        branch_data = root_data.get('branch_data')

        # Skip if no branch data
        if branch_data is None or len(branch_data) == 0:
            continue

        # Extract coordinates for this component
        root_mask_region = root_data['mask']
        coords = np.argwhere(root_mask_region)

        if len(coords) == 0:
            continue

        # Build component properties
        comp_props = {
            'label': root_label,
            'centroid': (coords[:, 0].mean(), coords[:, 1].mean()),
            'bbox': (coords[:, 0].min(), coords[:, 0].max(),
                    coords[:, 1].min(), coords[:, 1].max()),
            'num_pixels': root_data['total_pixels'],
            'vertical_extent': coords[:, 0].max() - coords[:, 0].min(),
            'horizontal_extent': coords[:, 1].max() - coords[:, 1].min()
        }

        # Score this component for each shoot
        for shoot_label in range(1, num_shoots + 1):
            shoot_ref = ref_points[shoot_label]['bottom_center']
            shoot_bbox = ref_points[shoot_label]['bbox']

            score_result = score_skeleton_for_shoot(
                comp_props, shoot_ref, shoot_bbox,
                max_distance=config.max_distance,
                max_size=config.max_size,
                ideal_aspect=config.ideal_aspect,
                distance_weight=config.distance_weight,
                size_weight=config.size_weight,
                verticality_weight=config.verticality_weight,
                max_horizontal_offset=config.max_horizontal_offset,
                max_above_shoot=config.max_above_shoot,
                max_below_shoot=config.max_below_shoot,
                min_skeleton_pixels=config.min_skeleton_pixels,
                edge_exclusion_zones=config.edge_exclusion_zones
            )

            # Rejected components carry a score of -1 and are dropped here
            if score_result['score'] > 0:
                shoot_candidates[shoot_label].append((root_label, score_result))

    # Sort each shoot's candidates by score (descending)
    for shoot_label in shoot_candidates:
        shoot_candidates[shoot_label].sort(key=lambda x: x[1]['score'], reverse=True)

    print("Valid candidates per shoot:")
    for shoot_label in range(1, num_shoots + 1):
        print(f"  Shoot {shoot_label}: {len(shoot_candidates[shoot_label])} candidates")

    return shoot_candidates

Find and score valid root candidates for each shoot.

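Scores each skeleton component against each shoot and keeps those with a positive score. Any sampling-box pre-filtering is expected to have been applied to the root mask beforehand (see filter_root_mask_by_sampling_box). Returns candidates organized per shoot and sorted by score.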

Arguments:
  • structures: Dict from extract_root_structures with keys:
    • 'skeleton': Full skeleton array
    • 'labeled_skeleton': Labeled skeleton array
    • 'unique_labels': Array of label IDs
    • 'roots': Dict of root data by label
  • ref_points: Dict from get_shoot_reference_points_multi
  • num_shoots: Number of shoots (typically 5)
  • config: RootMatchingConfig instance (uses defaults if None)
Returns:

dict: {shoot_label: [(root_label, score_dict), ...]} where each shoot's list is sorted by score descending

Examples:
>>> candidates = find_valid_candidates_for_shoots(structures, ref_points, 5)
>>> for shoot_label in range(1, 6):
...     print(f"Shoot {shoot_label}: {len(candidates[shoot_label])} candidates")
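
Because every component is scored against every shoot, the same root can appear in several shoots' lists. The sketch below uses hand-written data in the returned shape (the scores and labels are invented for illustration) to read out each shoot's top candidate and find roots that more than one shoot would pick first.

```python
from collections import defaultdict

# Same shape as the return value: {shoot_label: [(root_label, score_dict), ...]},
# each list already sorted by score descending (values invented for illustration)
candidates = {
    1: [(7, {'score': 0.9}), (8, {'score': 0.5})],
    2: [(7, {'score': 0.8})],
    3: [],
}

# Best-scoring root per shoot, or None when the shoot has no candidates
top_choice = {shoot: (cands[0][0] if cands else None)
              for shoot, cands in candidates.items()}

# Group shoots by the root they would pick first
claims = defaultdict(list)
for shoot, root in top_choice.items():
    if root is not None:
        claims[root].append(shoot)

# Roots that are the first choice of more than one shoot
contested = {root: shoots for root, shoots in claims.items() if len(shoots) > 1}
print(top_choice, contested)
```

Contested roots like root 7 here are the reason a separate assignment step is needed before lengths are reported.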
def assign_roots_to_shoots_greedy(shoot_candidates, ref_points, num_shoots, config=None, min_score_threshold=0.3):
def assign_roots_to_shoots_greedy(shoot_candidates, ref_points, num_shoots, config=None, min_score_threshold=0.3):
    """Assign roots to shoots using greedy left-to-right algorithm.

    Processes shoots in left-to-right order, assigning the best available
    (unassigned) root to each shoot. Guarantees exactly num_shoots outputs;
    shoots without an acceptable candidate get root_label None.

    Args:
        shoot_candidates: Dict from find_valid_candidates_for_shoots with structure
            {shoot_label: [(root_label, score_dict), ...]}
        ref_points: Dict from get_shoot_reference_points_multi
        num_shoots: Number of shoots (typically 5)
        config: RootMatchingConfig instance (optional; not used by the
            current implementation)
        min_score_threshold: Minimum score required for assignment (default 0.3)

    Returns:
        dict: {shoot_label: {
            'root_label': int or None,
            'score_dict': dict or None,
            'order': int  # 0=leftmost, num_shoots-1=rightmost
        }}
    """
    # Sort shoots by x-position (left to right)
    shoot_positions = []
    for shoot_label in range(1, num_shoots + 1):
        centroid_x = ref_points[shoot_label]['centroid'][1]
        shoot_positions.append((shoot_label, centroid_x))

    shoot_positions.sort(key=lambda x: x[1])  # Sort by x coordinate

    # Track assigned roots
    assigned_roots = set()

    # Build assignments
    assignments = {}

    for order, (shoot_label, _) in enumerate(shoot_positions):
        candidates = shoot_candidates[shoot_label]

        # Find best unassigned root
        best_root = None
        best_score_dict = None

        for root_label, score_dict in candidates:
            if root_label not in assigned_roots:
                # Check if score meets minimum threshold
                if score_dict['score'] >= min_score_threshold:
                    best_root = root_label
                    best_score_dict = score_dict
                    break  # Candidates are already sorted by score

        # Store assignment
        if best_root is not None:
            assignments[shoot_label] = {
                'root_label': best_root,
                'score_dict': best_score_dict,
                'order': order
            }
            assigned_roots.add(best_root)
        else:
            # No valid candidates (ungerminated seed or below threshold)
            assignments[shoot_label] = {
                'root_label': None,
                'score_dict': None,
                'order': order
            }

    return assignments

Assign roots to shoots using greedy left-to-right algorithm.

Processes shoots in left-to-right order, assigning the best available (unassigned) root to each shoot. Guarantees exactly num_shoots outputs; shoots without an acceptable candidate get root_label None.

Arguments:
  • shoot_candidates: Dict from find_valid_candidates_for_shoots with structure {shoot_label: [(root_label, score_dict), ...]}
  • ref_points: Dict from get_shoot_reference_points_multi
  • num_shoots: Number of shoots (typically 5)
  • config: RootMatchingConfig instance (optional; not used by the current implementation)
  • min_score_threshold: Minimum score required for assignment (default 0.3)
Returns:

dict: {shoot_label: {...}} where each value has keys:
  • 'root_label': int or None
  • 'score_dict': dict or None
  • 'order': int (0 = leftmost, num_shoots - 1 = rightmost)
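
The greedy rule is easiest to see on a small case. The following is a compact sketch of the same idea, not the module's code: `greedy_assign` is a hypothetical helper that takes shoot x-positions directly instead of `ref_points`, and the candidate scores are invented.

```python
def greedy_assign(shoot_x, candidates, min_score=0.3):
    """Assign each shoot, left to right, its best still-unassigned root."""
    assigned = set()
    result = {}
    for order, shoot in enumerate(sorted(shoot_x, key=shoot_x.get)):
        # Candidate lists are sorted by score descending, so the first
        # unassigned entry that meets the threshold is the best available
        root, info = next(
            ((r, s) for r, s in candidates.get(shoot, [])
             if r not in assigned and s['score'] >= min_score),
            (None, None),
        )
        if root is not None:
            assigned.add(root)
        result[shoot] = {'root_label': root, 'score_dict': info, 'order': order}
    return result


# Shoot labels are not in left-to-right order: shoot 2 is the leftmost
shoot_x = {1: 2000.0, 2: 1200.0, 3: 3000.0}
candidates = {
    1: [(7, {'score': 0.9}), (8, {'score': 0.5})],
    2: [(7, {'score': 0.8}), (9, {'score': 0.2})],
    3: [(9, {'score': 0.2})],
}
assignments = greedy_assign(shoot_x, candidates)
for shoot, a in assignments.items():
    print(shoot, a['root_label'], a['order'])
```

Shoot 2 is processed first and takes root 7, so shoot 1 falls back to root 8 even though it scored root 7 higher. Shoot 3 gets None because its only candidate is below the threshold. The result depends on processing order, which is the usual trade-off of a greedy assignment compared with a global optimum.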

def find_best_top_node(root_skeleton, shoot_ref_point, structures, root_label, distance_threshold=150, prefer_topmost=True):
def find_best_top_node(root_skeleton, shoot_ref_point, structures, root_label,
                       distance_threshold=150, prefer_topmost=True):
    """Select the skeleton node to use as the top of a root, near the shoot reference point.

    Uses a hybrid approach: filters nodes within distance_threshold of shoot,
    then selects based on preference (topmost or closest). If no node lies
    within the threshold, falls back to the closest node overall.

    Args:
        root_skeleton: Binary mask for this specific root (not used, kept for signature)
        shoot_ref_point: (y, x) tuple for shoot reference position
        structures: Dict from extract_root_structures
        root_label: Label ID for this root
        distance_threshold: Maximum distance from shoot to consider nodes (pixels)
        prefer_topmost: If True, pick topmost among candidates; if False, pick closest

    Returns:
        int: Node ID of the selected top node, or None if the root has no
             branch data or no nodes
    """

    # Get branch data for this root
    branch_data = structures['roots'][root_label].get('branch_data')
    if branch_data is None or len(branch_data) == 0:
        return None

    # Extract all unique nodes (both branch endpoints) from branch_data
    src_nodes = branch_data[['node-id-src', 'image-coord-src-0', 'image-coord-src-1']].drop_duplicates()
    dst_nodes = branch_data[['node-id-dst', 'image-coord-dst-0', 'image-coord-dst-1']].drop_duplicates()

    src_nodes.columns = ['node_id', 'y', 'x']
    dst_nodes.columns = ['node_id', 'y', 'x']

    all_nodes = pd.concat([src_nodes, dst_nodes]).drop_duplicates(subset='node_id')
    node_coords = all_nodes[['y', 'x']].values
    node_ids = all_nodes['node_id'].values

    if len(node_ids) == 0:
        return None

    # Calculate distances to shoot reference point
    distances = cdist(node_coords, [shoot_ref_point], metric='euclidean').flatten()

    # Filter nodes within distance threshold
    within_threshold = distances <= distance_threshold

    if not np.any(within_threshold):
        # No nodes within threshold, fall back to closest node
        return int(node_ids[np.argmin(distances)])

    # Get candidate nodes within threshold
    candidate_indices = np.where(within_threshold)[0]
    candidate_coords = node_coords[candidate_indices]
    candidate_ids = node_ids[candidate_indices]
    candidate_distances = distances[candidate_indices]

    if prefer_topmost:
        # Pick node with lowest y-coordinate (topmost)
        best_idx = np.argmin(candidate_coords[:, 0])
        return int(candidate_ids[best_idx])
    else:
        # Pick node with minimum distance
        best_idx = np.argmin(candidate_distances)
        return int(candidate_ids[best_idx])

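Select the skeleton node to use as the top of a root, near the shoot reference point.

Uses a hybrid approach: filters nodes within distance_threshold of shoot, then selects based on preference (topmost or closest). If no node lies within the threshold, falls back to the closest node overall.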

Arguments:
  • root_skeleton: Binary mask for this specific root (not used, kept for signature)
  • shoot_ref_point: (y, x) tuple for shoot reference position
  • structures: Dict from extract_root_structures
  • root_label: Label ID for this root
  • distance_threshold: Maximum distance from shoot to consider nodes (pixels)
  • prefer_topmost: If True, pick topmost among candidates; if False, pick closest
Returns:

int: Node ID of the selected top node, or None if the root has no branch data or no nodes
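
The selection logic can be illustrated on a handful of nodes. This is a minimal sketch under the assumption that node coordinates are already available as a dict; `pick_top_node` is a hypothetical helper, and it uses `math.hypot` where the module uses `cdist`.

```python
import math


def pick_top_node(nodes, shoot_ref, distance_threshold=150, prefer_topmost=True):
    """Pick a node id from {node_id: (y, x)} near the shoot reference point."""
    if not nodes:
        return None
    dist = {nid: math.hypot(y - shoot_ref[0], x - shoot_ref[1])
            for nid, (y, x) in nodes.items()}
    near = [nid for nid in nodes if dist[nid] <= distance_threshold]
    if not near:
        # Nothing within the threshold: fall back to the closest node overall
        return min(dist, key=dist.get)
    if prefer_topmost:
        # Smallest row index is highest in the image
        return min(near, key=lambda nid: nodes[nid][0])
    return min(near, key=dist.get)


nodes = {10: (610, 1100), 12: (640, 1000), 13: (1500, 1020)}
shoot_ref = (600, 1000)

topmost = pick_top_node(nodes, shoot_ref)                          # node 10
closest = pick_top_node(nodes, shoot_ref, prefer_topmost=False)    # node 12
fallback = pick_top_node(nodes, shoot_ref, distance_threshold=10)  # node 12
print(topmost, closest, fallback)
```

Node 10 is higher in the image but about 100 px away, while node 12 is only 40 px away, so the two preferences give different answers. The distance threshold keeps the topmost preference from selecting a node on a distant side branch.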

def pixel_to_robot_coords(pixel_x, pixel_y, image_shape, dish_size_m=0.15, dish_offset_m=[0.10775, 0.062, 0.175]):
def pixel_to_robot_coords(pixel_x, pixel_y, image_shape,
                         dish_size_m=0.15,
                         dish_offset_m=[0.10775, 0.062, 0.175]):
    """
    Convert pixel coordinates to robot coordinates in meters.

    Converts ROI-relative pixel coordinates to robot world coordinates.
    Assumes square petri dish with gantry axes aligned to image axes:
    - Robot X-axis aligns with image Y-axis (rows, downward)
    - Robot Y-axis aligns with image X-axis (columns, rightward)

    Args:
        pixel_x: X coordinate in pixels (column, increases right)
        pixel_y: Y coordinate in pixels (row, increases down)
        image_shape: Tuple of (height, width) of the ROI
        dish_size_m: Size of square petri dish in meters (default 0.15m)
        dish_offset_m: Robot coordinates [x, y, z] of plate top-left corner in meters
                      Default [0.10775, 0.062, 0.175] from simulation specification

    Returns:
        tuple: (robot_x_m, robot_y_m, robot_z_m) in meters
    """
    height, width = image_shape[:2]

    # Calculate scale (meters per pixel) from the image width; the dish is
    # assumed square, so the same scale is used for both axes
    scale = dish_size_m / width

    # Convert to plate-relative meters
    plate_x = pixel_x * scale
    plate_y = pixel_y * scale

    # Map to robot coordinates (gantry alignment, no rotation)
    robot_x = plate_y + dish_offset_m[0]  # Image Y-axis → Robot X-axis
    robot_y = plate_x + dish_offset_m[1]  # Image X-axis → Robot Y-axis
    robot_z = dish_offset_m[2]

    return robot_x, robot_y, robot_z

Convert pixel coordinates to robot coordinates in meters.

Converts ROI-relative pixel coordinates to robot world coordinates. Assumes a square Petri dish with gantry axes aligned to image axes:

  • Robot X-axis aligns with image Y-axis (rows, downward)
  • Robot Y-axis aligns with image X-axis (columns, rightward)
Arguments:
  • pixel_x: X coordinate in pixels (column, increases right)
  • pixel_y: Y coordinate in pixels (row, increases down)
  • image_shape: Tuple of (height, width) of the ROI
  • dish_size_m: Size of square Petri dish in meters (default 0.15m)
  • dish_offset_m: Robot coordinates [x, y, z] of plate top-left corner in meters. Default [0.10775, 0.062, 0.175] comes from the simulation specification.
Returns:

tuple: (robot_x_m, robot_y_m, robot_z_m) in meters
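
As a worked example of the arithmetic, the sketch below re-implements the conversion in a standalone form. The function name `pixel_to_robot` and the 3000x3000 pixel ROI are assumptions chosen for illustration; the defaults mirror the documented ones.

```python
def pixel_to_robot(pixel_x, pixel_y, image_shape,
                   dish_size_m=0.15,
                   dish_offset_m=(0.10775, 0.062, 0.175)):
    """Standalone sketch of the pixel-to-robot mapping described above."""
    _, width = image_shape[:2]
    scale = dish_size_m / width                   # meters per pixel
    robot_x = pixel_y * scale + dish_offset_m[0]  # image rows -> robot X
    robot_y = pixel_x * scale + dish_offset_m[1]  # image columns -> robot Y
    robot_z = dish_offset_m[2]                    # constant plate height
    return robot_x, robot_y, robot_z


# Hypothetical 3000x3000 px ROI: 0.15 m / 3000 px = 0.00005 m per pixel
rx, ry, rz = pixel_to_robot(1500, 750, (3000, 3000))

# The ROI's top-left pixel maps directly onto the plate offset
ox, oy, oz = pixel_to_robot(0, 0, (3000, 3000))
```

For the point at column 1500, row 750, the row contributes 750 × 0.00005 = 0.0375 m to robot X (0.14525 m in total) and the column contributes 1500 × 0.00005 = 0.075 m to robot Y (0.137 m in total). Z stays at the plate height of 0.175 m.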

def match_roots_to_shoots_complete(shoot_mask, root_mask, image_path, config=None, distance_threshold=150, prefer_topmost=True, min_score_threshold=0.3, verbose=False, return_dataframe=False, sample_idx=None, visual_debugging=False, output_path=None):
def match_roots_to_shoots_complete(shoot_mask, root_mask, image_path, config=None,
                                   distance_threshold=150, prefer_topmost=True,
                                   min_score_threshold=0.3, verbose=False,
                                   return_dataframe=False, sample_idx=None,
                                   visual_debugging=False, output_path=None):
    """Complete pipeline: shoot and root masks in, 5 length measurements out.

    High-level wrapper that runs the entire matching pipeline:
    1. Extract shoot reference points
    2. Filter root mask by sampling box
    3. Extract root structures
    4. Find valid candidates for each shoot
    5. Assign roots to shoots greedily (left to right)
    6. Find best top nodes for assignments
    7. Calculate root lengths from top nodes

    Args:
        shoot_mask: Binary mask array with shoot regions (H, W)
        root_mask: Binary mask array with root regions (H, W)
        image_path: Path to original image for ROI detection
        config: RootMatchingConfig instance (uses defaults if None)
        distance_threshold: Max distance from shoot for top node selection (pixels)
        prefer_topmost: If True, prefer topmost node; if False, prefer closest
        min_score_threshold: Minimum score for valid assignment
        verbose: Print progress messages
        return_dataframe: If True, return pandas DataFrame; if False, return numpy array
        sample_idx: Sample index to include in DataFrame (only used if return_dataframe=True)
        visual_debugging: If True, display visualizations of assignments and root lengths
        output_path: Optional path to save visualization outputs (not yet implemented)

    Returns:
        If return_dataframe=False (default):
            np.ndarray: Array of root lengths in pixels, one per detected shoot
                (5 expected), ordered left to right.
        If return_dataframe=True:
            pd.DataFrame: DataFrame with columns:
                - 'plant_order': 1-5 (left to right)
                - 'Plant ID': test_image_{sample_idx + 1:02d}_plant_{i},
                  or unknown_plant_{i} if sample_idx is None
                - 'Length (px)': root length in pixels
                - 'length_px': root length in pixels (duplicate)
                - 'top_node_x': x-coordinate of top node (pixels)
                - 'top_node_y': y-coordinate of top node (pixels)
                - 'endpoint_x': x-coordinate of endpoint node (pixels)
                - 'endpoint_y': y-coordinate of endpoint node (pixels)
                - 'top_node_robot_x': x-coordinate of top node (meters)
                - 'top_node_robot_y': y-coordinate of top node (meters)
                - 'top_node_robot_z': z-coordinate of top node (meters)
                - 'endpoint_robot_x': x-coordinate of endpoint (meters)
                - 'endpoint_robot_y': y-coordinate of endpoint (meters)
                - 'endpoint_robot_z': z-coordinate of endpoint (meters)

    Warns:
        UserWarning: If shoot_mask does not contain exactly 5 shoots, or if a
            root length cannot be calculated. Processing continues in both cases.

    Examples:
        >>> # Get numpy array
        >>> lengths = match_roots_to_shoots_complete(shoot_mask, root_mask, image_path)
        >>>
        >>> # Get DataFrame
        >>> df = match_roots_to_shoots_complete(shoot_mask, root_mask, image_path,
        ...                                      return_dataframe=True, sample_idx=0)

    Notes:
        - Returns one measurement per detected shoot (exactly 5 for the expected 5 shoots)
        - Left-to-right order is determined by shoot x-position (centroid)
        - Missing roots and failed length calculations yield 0.0; their
          coordinates are NaN
    """
    if config is None:
        config = RootMatchingConfig()

    if verbose:
        print("Starting complete root-shoot matching pipeline...")

    # Step 0: Detect ROI from original image
    if verbose:
        print("  Step 0: Detecting ROI from original image...")

    from library.roi import detect_roi
    from library.mask_processing import load_mask

    roi_bbox = detect_roi(load_mask(str(image_path)))
    (x1, y1), (x2, y2) = roi_bbox
    roi_width = x2 - x1
    roi_height = y2 - y1

    if verbose:
        print(f"    ROI bbox: ({x1}, {y1}) to ({x2}, {y2}), size={roi_width}x{roi_height}px")

    # Step 1: Get shoot reference points
    if verbose:
        print("  Step 1: Extracting shoot reference points...")
    labeled_shoots, num_shoots, ref_points = get_shoot_reference_points_multi(shoot_mask)

    if num_shoots != 5:
        warnings.warn(f"Expected 5 shoots, found {num_shoots}. Continuing anyway.", UserWarning)

    if verbose:
        print(f"    Found {num_shoots} shoots")

    # Step 2: Filter root mask by sampling box
    if verbose:
        print("  Step 2: Filtering root mask by sampling box...")
    filtered_root_mask = filter_root_mask_by_sampling_box(
        root_mask, shoot_mask, ref_points, num_shoots,
        sampling_buffer_above=config.sampling_buffer_above,
        sampling_buffer_below=config.sampling_buffer_below
    )

    # Step 3: Extract root structures
    if verbose:
        print("  Step 3: Extracting root structures...")
    structures = extract_root_structures(filtered_root_mask, verbose=verbose)

    if verbose:
        print(f"    Extracted {len(structures['roots'])} root structures")

    # Step 4: Find valid candidates
    if verbose:
        print("  Step 4: Finding valid candidates for each shoot...")
    shoot_candidates = find_valid_candidates_for_shoots(structures, ref_points, num_shoots, config)

    # Step 5: Assign roots to shoots
    if verbose:
        print("  Step 5: Assigning roots to shoots (greedy left-to-right)...")
    assignments = assign_roots_to_shoots_greedy(shoot_candidates, ref_points, num_shoots,
                                                min_score_threshold=min_score_threshold)

    # Step 6 & 7: Find top nodes and calculate lengths
    if verbose:
        print("  Step 6-7: Finding top nodes and calculating lengths...")

    # Helper function to get node coordinates
    def get_node_coords(node_id, branch_data):
        """Extract (y, x) coordinates for a given node_id from branch_data."""
        for _, row in branch_data.iterrows():
            if row['node-id-src'] == node_id:
                return (row['image-coord-src-0'], row['image-coord-src-1'])
            elif row['node-id-dst'] == node_id:
                return (row['image-coord-dst-0'], row['image-coord-dst-1'])
        return None

    # Build structure for process_matched_roots_to_lengths
    top_node_results = {}

    for shoot_label, info in assignments.items():
        root_label = info['root_label']

        if root_label is not None:
            # Find top node
            shoot_ref = ref_points[shoot_label]['bottom_center']
            top_node = find_best_top_node(None, shoot_ref, structures, root_label,
                                         distance_threshold=distance_threshold,
                                         prefer_topmost=prefer_topmost)

            if top_node is not None:
                branch_data = structures['roots'][root_label]['branch_data']

                top_node_results[root_label] = {
                    'shoot_label': shoot_label,
                    'top_nodes': [(top_node, 0)],
                    'branch_data': branch_data
                }

    # Calculate lengths and coordinates for all matched roots
    lengths_dict = {}
    top_coords_dict = {}
    endpoint_coords_dict = {}

    for root_label, result in top_node_results.items():
        branch_data = result['branch_data']
        shoot_label = result['shoot_label']
        top_node = result['top_nodes'][0][0]

        # Get top node coordinates
        top_coords = get_node_coords(top_node, branch_data)
        if top_coords:
            top_coords_dict[shoot_label] = top_coords

        try:
            # Find longest path from top node
            path = find_farthest_endpoint_path(
                branch_data,
                top_node,
                direction='down',
                use_smart_scoring=True,
                verbose=False
            )

            # Get endpoint node (last node in path)
            endpoint_node = path[-1][0]
            endpoint_coords = get_node_coords(endpoint_node, branch_data)
            if endpoint_coords:
                endpoint_coords_dict[shoot_label] = endpoint_coords

            # Calculate length
            root_length = calculate_skeleton_length_px(path)
            lengths_dict[shoot_label] = root_length

            if verbose:
                print(f"    Shoot {shoot_label}: Root {root_label}, length={root_length:.1f}px, "
                      f"top={top_coords}, endpoint={endpoint_coords}")

        except Exception as e:
            warnings.warn(f"Failed to calculate length for root {root_label} (shoot {shoot_label}): {e}", UserWarning)
            lengths_dict[shoot_label] = 0.0

            if verbose:
                print(f"    Shoot {shoot_label}: Root {root_label}, ERROR - returning 0.0")

    # Build output array ordered by shoot position (left to right)
    shoot_positions = []
    for shoot_label in range(1, num_shoots + 1):
        shoot_x = ref_points[shoot_label]['centroid'][1]
        shoot_positions.append((shoot_label, shoot_x))

    shoot_positions.sort(key=lambda x: x[1])

    # Create final arrays
    lengths_array = np.array([
        lengths_dict.get(shoot_label, 0.0)
        for shoot_label, _ in shoot_positions
    ])

    top_x_array = np.array([
        top_coords_dict.get(shoot_label, (np.nan, np.nan))[1]
        for shoot_label, _ in shoot_positions
    ])

    top_y_array = np.array([
        top_coords_dict.get(shoot_label, (np.nan, np.nan))[0]
        for shoot_label, _ in shoot_positions
    ])

    endpoint_x_array = np.array([
        endpoint_coords_dict.get(shoot_label, (np.nan, np.nan))[1]
        for shoot_label, _ in shoot_positions
    ])

    endpoint_y_array = np.array([
        endpoint_coords_dict.get(shoot_label, (np.nan, np.nan))[0]
        for shoot_label, _ in shoot_positions
    ])

    if verbose:
        print(f"\n  Final lengths (left to right): {lengths_array}")
        print("  Pipeline complete!")

    # Visual debugging
    if visual_debugging:
        im_path = Path(image_path)
        print("="*50)
        print(f"\n  Generating visualizations for {im_path.name}")

        print('   Original image')
        image = load_mask(str(im_path))

        fig, ax = plt.subplots(figsize=(12, 8))
        ax.imshow(image, cmap='gray')
        ax.set_title(im_path.name, fontsize=14)
        ax.axis('off')
        plt.tight_layout()
        plt.show()

        from library.root_analysis_visualization import visualize_assignments, visualize_root_lengths

        # Visualize assignments
        print('    Visualize shoot & root assignments')
        visualize_assignments(shoot_mask, root_mask, structures, assignments, ref_points)

        # Visualize root lengths with detailed views
        print('    Visualize root skeletons with node and edge network')
        visualize_root_lengths(structures, top_node_results, labeled_shoots,
                              show_detailed_roots=True)

    # Return DataFrame or array
    if return_dataframe:
        import pandas as pd

        # Convert pixel coordinates to robot coordinates using ROI-relative system
        roi_shape = (roi_height, roi_width)

        top_robot_coords = [pixel_to_robot_coords(x - x1, y - y1, roi_shape)
                           if not (np.isnan(x) or np.isnan(y)) else (np.nan, np.nan, np.nan)
                           for x, y in zip(top_x_array, top_y_array)]
        endpoint_robot_coords = [pixel_to_robot_coords(x - x1, y - y1, roi_shape)
                                if not (np.isnan(x) or np.isnan(y)) else (np.nan, np.nan, np.nan)
                                for x, y in zip(endpoint_x_array, endpoint_y_array)]

        # Unpack into separate arrays
        top_robot_x = np.array([c[0] for c in top_robot_coords])
        top_robot_y = np.array([c[1] for c in top_robot_coords])
        top_robot_z = np.array([c[2] for c in top_robot_coords])

        endpoint_robot_x = np.array([c[0] for c in endpoint_robot_coords])
        endpoint_robot_y = np.array([c[1] for c in endpoint_robot_coords])
        endpoint_robot_z = np.array([c[2] for c in endpoint_robot_coords])

        # Size the ID columns from the detected shoots so every column has the
        # same length even when the mask does not contain exactly 5 shoots
        plant_numbers = list(range(1, len(shoot_positions) + 1))
        if sample_idx is not None:
            plant_ids = [f'test_image_{sample_idx + 1:02d}_plant_{i}' for i in plant_numbers]
        else:
            plant_ids = [f'unknown_plant_{i}' for i in plant_numbers]

        df_data = {
            'plant_order': plant_numbers,
            'Plant ID': plant_ids,
            'Length (px)': lengths_array,
            'length_px': lengths_array,
            'top_node_x': top_x_array,
            'top_node_y': top_y_array,
            'endpoint_x': endpoint_x_array,
            'endpoint_y': endpoint_y_array,
            'top_node_robot_x': top_robot_x,
            'top_node_robot_y': top_robot_y,
            'top_node_robot_z': top_robot_z,
            'endpoint_robot_x': endpoint_robot_x,
            'endpoint_robot_y': endpoint_robot_y,
            'endpoint_robot_z': endpoint_robot_z
        }

        return pd.DataFrame(df_data)
    else:
        return lengths_array

Complete pipeline: shoot and root masks in, 5 length measurements out.

High-level wrapper that runs the entire matching pipeline:

  1. Extract shoot reference points
  2. Filter root mask by sampling box
  3. Extract root structures
  4. Find valid candidates for each shoot
  5. Assign roots to shoots greedily (left to right)
  6. Find best top nodes for assignments
  7. Calculate root lengths from top nodes
Arguments:
  • shoot_mask: Binary mask array with shoot regions (H, W)
  • root_mask: Binary mask array with root regions (H, W)
  • image_path: Path to original image for ROI detection
  • config: RootMatchingConfig instance (uses defaults if None)
  • distance_threshold: Max distance from shoot for top node selection (pixels)
  • prefer_topmost: If True, prefer topmost node; if False, prefer closest
  • min_score_threshold: Minimum score for valid assignment
  • verbose: Print progress messages
  • return_dataframe: If True, return pandas DataFrame; if False, return numpy array
  • sample_idx: Sample index to include in DataFrame (only used if return_dataframe=True)
  • visual_debugging: If True, display visualizations of assignments and root lengths
  • output_path: Optional path to save visualization outputs (not yet implemented)
Returns:

If return_dataframe=False (default), np.ndarray: array of root lengths in pixels, ordered left to right.

If return_dataframe=True, pd.DataFrame with columns:
  • 'plant_order': 1-5 (left to right)
  • 'Plant ID': test_image_{sample_idx + 1:02d}_plant_{i}, or unknown_plant_{i} if sample_idx is None
  • 'Length (px)': root length in pixels
  • 'length_px': root length in pixels (duplicate)
  • 'top_node_x': x-coordinate of top node (pixels)
  • 'top_node_y': y-coordinate of top node (pixels)
  • 'endpoint_x': x-coordinate of endpoint node (pixels)
  • 'endpoint_y': y-coordinate of endpoint node (pixels)
  • 'top_node_robot_x': x-coordinate of top node (meters)
  • 'top_node_robot_y': y-coordinate of top node (meters)
  • 'top_node_robot_z': z-coordinate of top node (meters)
  • 'endpoint_robot_x': x-coordinate of endpoint (meters)
  • 'endpoint_robot_y': y-coordinate of endpoint (meters)
  • 'endpoint_robot_z': z-coordinate of endpoint (meters)

Warns:
  • UserWarning: If shoot_mask does not contain exactly 5 shoots, or if a root length cannot be calculated. Processing continues in both cases.
Examples:
>>> # Get numpy array
>>> lengths = match_roots_to_shoots_complete(shoot_mask, root_mask, image_path)
>>> 
>>> # Get DataFrame
>>> df = match_roots_to_shoots_complete(shoot_mask, root_mask, image_path,
...                                      return_dataframe=True, sample_idx=0)
Notes:
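  • Returns one measurement per detected shoot: exactly 5 when the mask contains the expected 5 shoots
  • Left-to-right order is determined by shoot x-position (centroid)
  • Missing roots and failed length calculations yield a length of 0.0; their coordinates are NaN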
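
The ordering and default-filling behaviour in the notes can be sketched in isolation. The dictionaries below are made-up stand-ins for the pipeline's intermediate results (three shoots instead of five, for brevity), and coordinates are (y, x) pairs as in the listing.

```python
import math

# Hypothetical shoot reference points: label -> centroid as (y, x)
ref_points = {
    1: {'centroid': (50.0, 900.0)},
    2: {'centroid': (55.0, 100.0)},
    3: {'centroid': (48.0, 500.0)},
}

# Hypothetical per-shoot results; shoot 3 has no matched root
lengths_dict = {1: 812.4, 2: 640.0}
top_coords_dict = {1: (60.0, 905.0), 2: (62.0, 98.0)}

# Order shoot labels left to right by centroid x (index 1 of the pair)
ordered_labels = sorted(ref_points, key=lambda label: ref_points[label]['centroid'][1])

# Missing lengths fall back to 0.0, missing coordinates to NaN
lengths = [lengths_dict.get(label, 0.0) for label in ordered_labels]
top_x = [top_coords_dict.get(label, (math.nan, math.nan))[1] for label in ordered_labels]
```

Shoot 2 is leftmost and shoot 1 rightmost, so the labels come out as 2, 3, 1. The unmatched shoot 3 contributes a length of 0.0 and a NaN coordinate in the middle position, which keeps the output aligned with plant order.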