library.shoot_mask_cleaning

Shoot mask cleaning pipeline for plant root analysis.

This module provides functions for cleaning and filtering shoot segmentation masks from U-Net model predictions. The pipeline looks for up to 5 shoot locations (one per expected plant) using adaptive Y-zone detection and X-axis peak finding, and handles edge cases such as failed germination and fallen shoots.
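
The two detection steps named above can be sketched on synthetic data. This is a minimal illustration rather than the module's implementation: the mask size, blob positions, and the `distance=50` spacing are assumptions chosen to keep the example small, and the Y-zone bounds are clamped to the image for safety.

```python
import numpy as np
from scipy.signal import find_peaks

# Synthetic 100 x 600 mask with five rectangular "shoots" (hypothetical data).
mask = np.zeros((100, 600), dtype=np.uint8)
centers = [100, 200, 300, 400, 500]
for c in centers:
    mask[40:60, c - 10:c + 10] = 255

# Row density, then its weighted mean and standard deviation (centre of mass in Y).
row_density = np.sum(mask > 0, axis=1)
rows = np.arange(mask.shape[0])
total = np.sum(row_density)
mean_y = np.sum(rows * row_density) / total
std_y = np.sqrt(np.sum(row_density * (rows - mean_y) ** 2) / total)

# Y-zone of mean +/- 2 std, clamped to the image, then a per-column projection.
top = max(0, int(mean_y - 2.0 * std_y))
bottom = min(mask.shape[0], int(mean_y + 2.0 * std_y))
x_projection = np.sum(mask[top:bottom] > 0, axis=0)

# Keep the five tallest peaks that are at least 50 columns apart, left to right.
peaks, _ = find_peaks(x_projection, distance=50)
strongest = np.sort(peaks[np.argsort(x_projection[peaks])[-5:]])
print(f"Y-zone rows {top}..{bottom}, peaks at {strongest}")
```

Each reported peak should fall inside one of the synthetic blobs; on real masks the projection is noisier, which is why the module adds spacing, position, and quality constraints.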

Typical workflow:

Step 1: Calculate global statistics across all masks (one-time setup)

>>> from pathlib import Path
>>> import cv2
>>> from shoot_mask_cleaning import calculate_global_y_stats
>>> 
>>> mask_dir = Path('data/shoot_masks')
>>> mask_paths = [str(f) for f in sorted(mask_dir.glob('*.png'))]
>>> 
>>> global_stats = calculate_global_y_stats(
...     mask_paths, 
...     kernel_size=5,
...     iterations=3,
...     y_min=200, 
...     y_max=750
... )
>>> print(f"Global mean Y: {global_stats['global_mean']:.1f}")
>>> print(f"Global std Y: {global_stats['global_std']:.1f}")

Step 2: Process individual masks

>>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
>>> 
>>> for mask_path in mask_paths:
...     result = clean_shoot_mask_pipeline(mask_path, global_stats)
...     
...     # Save cleaned mask
...     output_path = f"cleaned/{result['filename']}"
...     cv2.imwrite(output_path, result['cleaned_mask'])
...     
...     # Check results
...     print(f"{result['filename']}: {result['method']}, "
...           f"{result['num_components']} components")

Step 3: Debug problematic images

>>> from shoot_mask_visualization import debug_peak_detection
>>> 
>>> # Visualize complete pipeline for troubleshooting
>>> debug_peak_detection('data/shoot_masks/problem_image.png', global_stats)

Recommended parameters (validated on 19 test images):

Initial closing:
    - kernel_size: 7
    - iterations: 5

Y-zone boundaries:
    - y_min: 200 (shoots never above this)
    - y_max: 750 (shoots never below this)

Peak detection:
    - n_peaks: 5 (expected number of plants)
    - min_distance: 300 (minimum spacing between shoots)
    - x_min: 1000 (shoots never left of this)
    - initial_std: 2.0 (starting Y-zone width)
    - quality_check_threshold: 2.5 (when to validate peak quality)
    - min_peak_width: 20 (minimum width for valid peaks)
    - min_peak_height: 10 (minimum height for valid peaks)
    - min_area: 500 (for size-based fallback)

Filtering:
    - band_width: 100 (X-distance around peaks)
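
To make the roles of `initial_std`, the widening steps, and `band_width` concrete, here is a small numeric sketch. The global mean and std, the peak columns, and the component centres are hypothetical values chosen for illustration; real values come from `calculate_global_y_stats` and the peak finder.

```python
import numpy as np

# Hypothetical global statistics (assumed values, not measured ones).
global_mean, global_std = 480.0, 60.0

# Y-zone bounds for each std multiplier from 2.0 up to 3.5 in 0.25 steps.
multipliers = np.arange(2.0, 3.5 + 1e-9, 0.25)
zones = [(int(global_mean - k * global_std), int(global_mean + k * global_std))
         for k in multipliers]

# Band test: a component is associated with a peak when its X-centre lies
# within band_width pixels of that peak.
peaks = np.array([1200, 1650, 2100])                    # hypothetical peak columns
component_centers = np.array([1185, 1400, 1720, 2110])  # hypothetical X-centres
band_width = 100
in_band = np.abs(component_centers[:, None] - peaks[None, :]) <= band_width
kept = component_centers[in_band.any(axis=1)]

print("narrowest zone:", zones[0], "widest zone:", zones[-1])
print("centres kept:", kept)
```

The component centred at 1400 is more than 100 pixels from every peak, so it is the only one dropped.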

Dependencies:
    - numpy
    - opencv-python (cv2)
    - scipy (for peak detection)

Author: Aaron Ciuffo
Date: December 2024

"""Shoot mask cleaning pipeline for plant root analysis.

This module provides functions for cleaning and filtering shoot segmentation masks
from U-Net model predictions. The pipeline looks for up to 5 shoot locations (one
per expected plant) using adaptive Y-zone detection and X-axis peak finding, and
handles edge cases such as failed germination and fallen shoots.

Typical workflow:

    Step 1: Calculate global statistics across all masks (one-time setup)

        >>> from pathlib import Path
        >>> import cv2
        >>> from shoot_mask_cleaning import calculate_global_y_stats
        >>>
        >>> mask_dir = Path('data/shoot_masks')
        >>> mask_paths = [str(f) for f in sorted(mask_dir.glob('*.png'))]
        >>>
        >>> global_stats = calculate_global_y_stats(
        ...     mask_paths,
        ...     kernel_size=5,
        ...     iterations=3,
        ...     y_min=200,
        ...     y_max=750
        ... )
        >>> print(f"Global mean Y: {global_stats['global_mean']:.1f}")
        >>> print(f"Global std Y: {global_stats['global_std']:.1f}")

    Step 2: Process individual masks

        >>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
        >>>
        >>> for mask_path in mask_paths:
        ...     result = clean_shoot_mask_pipeline(mask_path, global_stats)
        ...
        ...     # Save cleaned mask
        ...     output_path = f"cleaned/{result['filename']}"
        ...     cv2.imwrite(output_path, result['cleaned_mask'])
        ...
        ...     # Check results
        ...     print(f"{result['filename']}: {result['method']}, "
        ...           f"{result['num_components']} components")

    Step 3: Debug problematic images

        >>> from shoot_mask_visualization import debug_peak_detection
        >>>
        >>> # Visualize complete pipeline for troubleshooting
        >>> debug_peak_detection('data/shoot_masks/problem_image.png', global_stats)

Recommended parameters (validated on 19 test images):

    Initial closing:
        - kernel_size: 7
        - iterations: 5

    Y-zone boundaries:
        - y_min: 200 (shoots never above this)
        - y_max: 750 (shoots never below this)

    Peak detection:
        - n_peaks: 5 (expected number of plants)
        - min_distance: 300 (minimum spacing between shoots)
        - x_min: 1000 (shoots never left of this)
        - initial_std: 2.0 (starting Y-zone width)
        - quality_check_threshold: 2.5 (when to validate peak quality)
        - min_peak_width: 20 (minimum width for valid peaks)
        - min_peak_height: 10 (minimum height for valid peaks)
        - min_area: 500 (for size-based fallback)

    Filtering:
        - band_width: 100 (X-distance around peaks)

Dependencies:
    - numpy
    - opencv-python (cv2)
    - scipy (for peak detection)

Author: Aaron Ciuffo
Date: December 2024
"""

import numpy as np
import cv2
from pathlib import Path
from scipy.signal import find_peaks, peak_widths


def join_shoot_fragments(mask, kernel_size=7, iterations=5):
    """Join small shoot fragments using morphological closing.

    Applies morphological closing to connect nearby fragments and fill small gaps
    in the segmentation mask. This preprocessing step improves peak detection by
    creating more continuous shoot structures.

    Args:
        mask: Binary mask array with values 0 and 255.
        kernel_size: Size of square structuring element (default: 7).
        iterations: Number of closing iterations (default: 5).

    Returns:
        Binary mask array with joined fragments (values 0 and 255).

    Example:
        >>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
        >>> joined = join_shoot_fragments(mask, kernel_size=5, iterations=3)
        >>> cv2.imwrite('joined_mask.png', joined)
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    joined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=iterations)
    return joined


def calculate_y_density(mask):
    """Calculate pixel density per Y coordinate.

    Sums the number of foreground pixels along each row to create a 1D density
    profile. Used for finding the vertical position of shoots.

    Args:
        mask: Binary mask with values 0 and 255.

    Returns:
        1D numpy array of pixel counts per row (length = mask height).

    Example:
        >>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
        >>> density = calculate_y_density(mask)
        >>> print(f"Peak density at Y={np.argmax(density)}")
    """
    return np.sum(mask > 0, axis=1)


def calculate_weighted_y_stats(density, y_min=200, y_max=750):
    """Calculate weighted mean and standard deviation of Y positions within ROI.

    Computes the center of mass and spread of shoot pixels along the Y-axis,
    restricted to the valid Y-range. Pixels outside the ROI are ignored to
    prevent noise from biasing statistics.

    Args:
        density: 1D array of pixel counts per row from calculate_y_density.
        y_min: First row to consider, inclusive (default: 200).
        y_max: Last row to consider, exclusive (default: 750).

    Returns:
        Dictionary with keys:
            - 'mean': Weighted mean Y position (pixels).
            - 'std': Weighted standard deviation (pixels).
        Both values are 0 when the ROI contains no foreground pixels.

    Example:
        >>> density = calculate_y_density(mask)
        >>> stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
        >>> print(f"Center of mass: Y={stats['mean']:.1f} ± {stats['std']:.1f}")
    """
    # Zero out rows outside the ROI [y_min, y_max)
    roi_density = density.copy()
    roi_density[:y_min] = 0
    roi_density[y_max:] = 0

    y_positions = np.arange(len(roi_density))
    total_pixels = np.sum(roi_density)

    if total_pixels == 0:
        return {'mean': 0, 'std': 0}

    # Weighted mean (center of mass)
    weighted_mean = np.sum(y_positions * roi_density) / total_pixels

    # Weighted standard deviation
    weighted_var = np.sum(roi_density * (y_positions - weighted_mean)**2) / total_pixels
    weighted_std = np.sqrt(weighted_var)

    return {
        'mean': weighted_mean,
        'std': weighted_std
    }


def calculate_global_y_stats(mask_paths, kernel_size=7, iterations=5, y_min=200, y_max=750):
    """Calculate global weighted mean and standard deviation across all masks.

    Processes all masks to establish global shoot position statistics. These
    statistics serve as priors for individual image processing, enabling adaptive
    Y-zone detection that handles variation in shoot positions.

    Args:
        mask_paths: List of paths (strings or Path objects) to mask files.
        kernel_size: Kernel size for initial morphological closing (default: 7).
        iterations: Number of closing iterations (default: 5).
        y_min: Minimum Y coordinate to consider (default: 200).
        y_max: Maximum Y coordinate to consider (default: 750).

    Returns:
        Dictionary with keys:
            - 'global_mean': Mean Y position across all images (pixels).
            - 'global_std': Mean standard deviation across all images (pixels).
            - 'all_stats': List of per-image statistics dictionaries.
            - 'all_means': List of individual image means.
            - 'all_stds': List of individual image standard deviations.
            - 'y_min': Y minimum boundary used.
            - 'y_max': Y maximum boundary used.

    Raises:
        FileNotFoundError: If a mask file cannot be read.

    Example:
        >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
        >>> global_stats = calculate_global_y_stats(mask_files, y_min=200, y_max=750)
        >>> print(f"Expected shoot zone: {global_stats['global_mean']:.0f} ± "
        ...       f"{2 * global_stats['global_std']:.0f} pixels")

    Note:
        This function should be called once per dataset to establish the global
        statistics used for processing all individual images. Masks with no
        foreground pixels inside the ROI contribute a mean and std of 0, which
        pulls the global values down; filter such masks out beforehand if the
        dataset contains any.
    """
    all_stats = []

    for path in mask_paths:
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None instead of raising on a missing or unreadable file
            raise FileNotFoundError(f"Could not read mask: {path}")
        joined = join_shoot_fragments(mask, kernel_size, iterations)
        density = calculate_y_density(joined)
        stats = calculate_weighted_y_stats(density, y_min, y_max)
        all_stats.append(stats)

    # Calculate mean of means and mean of stds
    means = [s['mean'] for s in all_stats]
    stds = [s['std'] for s in all_stats]

    global_mean = np.mean(means)
    global_std = np.mean(stds)

    return {
        'global_mean': global_mean,
        'global_std': global_std,
        'all_stats': all_stats,
        'all_means': means,
        'all_stds': stds,
        'y_min': y_min,
        'y_max': y_max
    }


def calculate_x_projection(mask, global_stats, std_multiplier=2.0):
    """Calculate X-axis projection within the global Y-zone.

    Counts foreground pixels in each column, restricted to the vertical zone
    defined by the global statistics (mean ± std_multiplier * std). Creates a 1D
    signal showing the horizontal distribution of shoot mass.

    Args:
        mask: Binary mask with values 0 and 255.
        global_stats: Global statistics from calculate_global_y_stats.
        std_multiplier: Standard deviation multiplier for Y-zone width (default: 2.0).

    Returns:
        1D numpy array of pixel counts per column (length = mask width).

    Example:
        >>> x_proj = calculate_x_projection(mask, global_stats, std_multiplier=2.0)
        >>> import matplotlib.pyplot as plt
        >>> plt.plot(x_proj)
        >>> plt.xlabel('X coordinate')
        >>> plt.ylabel('Pixel count')
        >>> plt.show()
    """
    # Calculate Y bounds, clamped to the image so a negative start cannot wrap around
    height = mask.shape[0]
    global_upper = max(0, int(global_stats['global_mean'] - std_multiplier * global_stats['global_std']))
    global_lower = min(height, int(global_stats['global_mean'] + std_multiplier * global_stats['global_std']))

    # Extract Y-zone region and sum along Y axis
    zone_region = mask[global_upper:global_lower, :]
    x_projection = np.sum(zone_region > 0, axis=0)

    return x_projection


def find_shoot_peaks(x_projection, n_peaks=5, min_distance=300, x_min=1000):
    """Find peak locations in X-projection using scipy peak detection.

    Identifies the N strongest peaks in the horizontal projection that satisfy
    spacing and position constraints. Returns peaks sorted left to right; fewer
    than n_peaks are returned when the projection contains fewer peaks.

    Args:
        x_projection: 1D array of pixel counts from calculate_x_projection.
        n_peaks: Number of peaks to find (default: 5).
        min_distance: Minimum distance between peaks in pixels (default: 300).
        x_min: Minimum X coordinate to consider (default: 1000).

    Returns:
        Numpy array of up to n_peaks peak X-coordinates, sorted left to right.

    Example:
        >>> x_proj = calculate_x_projection(mask, global_stats)
        >>> peaks = find_shoot_peaks(x_proj, n_peaks=5, min_distance=300)
        >>> print(f"Found peaks at X positions: {peaks}")
    """
    # Mask out x < x_min
    masked_projection = x_projection.copy()
    masked_projection[:x_min] = 0

    # Find peaks with minimum distance constraint
    peaks, _ = find_peaks(masked_projection, distance=min_distance)

    # Get peak heights and select top n_peaks
    peak_heights = masked_projection[peaks]
    top_indices = np.argsort(peak_heights)[-n_peaks:]
    top_peaks = peaks[top_indices]

    # Sort by position (left to right)
    top_peaks = np.sort(top_peaks)

    return top_peaks


def validate_peak_quality(x_projection, peaks, min_peak_height=10, min_peak_width=20):
    """Check if detected peaks are likely to be real shoots versus noise.

    Validates peak quality by checking both height (signal strength) and width
    (spatial extent). Narrow spikes indicate noise, while wide peaks indicate
    actual shoots. Validation passes when at least 3 peaks meet both thresholds.

    Args:
        x_projection: 1D array of pixel counts.
        peaks: Array of peak X-coordinates from find_shoot_peaks.
        min_peak_height: Minimum height for a valid peak (default: 10).
        min_peak_width: Minimum width at half prominence for a valid peak (default: 20).

    Returns:
        Boolean indicating if peaks are high quality (True) or likely noise (False).

    Example:
        >>> peaks = find_shoot_peaks(x_proj, n_peaks=5)
        >>> if validate_peak_quality(x_proj, peaks):
        ...     print("High quality peaks detected")
        ... else:
        ...     print("Peaks may be noise, consider widening Y-zone")
    """
    peak_heights = x_projection[peaks]

    # Calculate peak widths at half prominence (rel_height is relative to prominence)
    widths, _, _, _ = peak_widths(x_projection, peaks, rel_height=0.5)

    # Check if at least 3 peaks are both tall enough AND wide enough
    valid_peaks = np.sum((peak_heights >= min_peak_height) & (widths >= min_peak_width))

    return bool(valid_peaks >= 3)


def find_shoot_peaks_with_size_fallback(mask, global_stats, n_peaks=5, min_distance=300,
                                        x_min=1000, initial_std=2.0, max_std=3.5,
                                        std_step=0.25, min_area=500, min_peak_width=20,
                                        quality_check_threshold=2.5):
    """Adaptively find shoot peaks with fallback to size-based detection.

    Three-stage detection strategy:
    1. Normal cases (std 2.0-2.25): X-projection peaks without quality checks
    2. Widened search (std 2.5-3.5): Progressive Y-zone widening with quality validation
    3. Size fallback: Largest components by area when projection methods fail

    This handles both typical shoots and edge cases like failed germination or
    shoots that have fallen outside the typical vertical zone.

    Args:
        mask: Binary mask with values 0 and 255.
        global_stats: Global statistics from calculate_global_y_stats.
        n_peaks: Target number of peaks (default: 5).
        min_distance: Minimum distance between peaks in pixels (default: 300).
        x_min: Minimum X coordinate to consider (default: 1000).
        initial_std: Starting std multiplier (default: 2.0).
        max_std: Maximum std multiplier to try (default: 3.5).
        std_step: Step size for widening (default: 0.25).
        min_area: Minimum area for size fallback in pixels (default: 500).
        min_peak_width: Minimum peak width for quality validation (default: 20).
        quality_check_threshold: Std threshold to activate quality checks (default: 2.5).

    Returns:
        Tuple of (peaks, x_projection, std_multiplier_used, method_used) where:
            - peaks: Array of peak X-coordinates
            - x_projection: The X-projection array used
            - std_multiplier_used: The std multiplier that succeeded
              (max_std when the size fallback is used)
            - method_used: Either "projection" or "size_fallback"

    Example:
        >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
        ...     mask, global_stats, quality_check_threshold=2.5
        ... )
        >>> print(f"Method: {method}, Std: {std:.2f}")
        >>> print(f"Peaks at: {peaks}")
    """
    std_multiplier = initial_std

    # Try standard X-projection approach with progressive widening
    while std_multiplier <= max_std:
        x_projection = calculate_x_projection(mask, global_stats, std_multiplier)
        peaks = find_shoot_peaks(x_projection, n_peaks, min_distance, x_min)

        if len(peaks) >= n_peaks:
            # Only check quality if we're in desperate territory (high std)
            if std_multiplier >= quality_check_threshold:
                if validate_peak_quality(x_projection, peaks, min_peak_height=10,
                                         min_peak_width=min_peak_width):
                    print(f"Found {len(peaks)} quality peaks with std_multiplier={std_multiplier:.2f}")
                    return peaks, x_projection, std_multiplier, "projection"
                else:
                    print(f"Found {len(peaks)} peaks but quality too low (std={std_multiplier:.2f})")
            else:
                # Low std values - just trust the peaks
                print(f"Found {len(peaks)} peaks with std_multiplier={std_multiplier:.2f}")
                return peaks, x_projection, std_multiplier, "projection"

        std_multiplier += std_step

    # Fallback: Use connected components filtered by size and X position
    print("Projection method insufficient, trying size-based fallback...")

    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)

    valid_components = []
    for i in range(1, num_labels):
        area = stats[i, cv2.CC_STAT_AREA]
        x_left = stats[i, cv2.CC_STAT_LEFT]
        width = stats[i, cv2.CC_STAT_WIDTH]
        bbox_center = x_left + width // 2

        if area >= min_area and bbox_center >= x_min:
            valid_components.append({
                'label': i,
                'bbox_center': bbox_center,
                'area': area
            })

    # Sort by area (largest first) and take top n_peaks
    valid_components.sort(key=lambda x: x['area'], reverse=True)
    selected = valid_components[:n_peaks]

    # Extract X positions and sort left to right (integer dtype even when empty)
    peaks = np.array([c['bbox_center'] for c in selected], dtype=int)
    peaks = np.sort(peaks)

    # Create full projection for visualization
    x_projection = np.sum(mask > 0, axis=0)

    print(f"Found {len(peaks)} peaks using size-based fallback (min_area={min_area})")
    return peaks, x_projection, max_std, "size_fallback"


def merge_with_narrow_bands(mask, peaks, global_stats, std_multiplier, band_width=100):
    """Find components near peaks and keep entire components (not clipped to bands).

    Uses narrow X-bands around detected peaks to identify which components belong
    to each shoot location. Keeps complete components rather than clipping them
    to band boundaries, preserving shoot morphology.

    Args:
        mask: Binary mask with values 0 and 255.
        peaks: Array of peak X-coordinates from find_shoot_peaks_with_size_fallback.
        global_stats: Global statistics from calculate_global_y_stats
            (not used by this function; accepted so the call signature matches
            filter_one_per_peak).
        std_multiplier: Std multiplier used for peak detection (not used here).
        band_width: Half-width of X-band for identifying components (default: 100).

    Returns:
        Binary mask with complete components near peaks (values 0 and 255).

    Example:
        >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
        ...     mask, global_stats
        ... )
        >>> filtered = merge_with_narrow_bands(mask, peaks, global_stats, std,
        ...                                    band_width=100)
    """
    # Find all connected components in the full mask
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Create empty output
    output_mask = np.zeros_like(mask)

    # For each peak, find components whose center falls in the band
    kept_labels = set()

    for i, peak_x in enumerate(peaks):
        x_left = peak_x - band_width
        x_right = peak_x + band_width

        # Check each component
        for label in range(1, num_labels):
            if label in kept_labels:
                continue  # Already assigned to another peak

            # Get component's X-center
            comp_x_left = stats[label, cv2.CC_STAT_LEFT]
            comp_width = stats[label, cv2.CC_STAT_WIDTH]
            comp_x_center = comp_x_left + comp_width // 2

            # If center is in this peak's band, keep the ENTIRE component
            if x_left <= comp_x_center <= x_right:
                output_mask[labels == label] = 255
                kept_labels.add(label)
                print(f"  Peak {i+1} at X={peak_x}: keeping component with center at X={comp_x_center}")

    return output_mask


def filter_one_per_peak(mask, peaks, global_stats, std_multiplier, band_width=100,
                        require_y_zone=True):
    """Keep the largest component near each peak, optionally validating Y-zone.

    For each detected peak, identifies all components within the X-band and keeps
    only the largest one. Optionally filters out components that don't touch the
    Y-zone (disabled for size-fallback method to handle fallen shoots).

    Args:
        mask: Binary mask with values 0 and 255.
        peaks: Array of peak X-coordinates.
        global_stats: Global statistics from calculate_global_y_stats.
        std_multiplier: Std multiplier used for peak detection.
        band_width: Half-width for assigning components to peaks (default: 100).
        require_y_zone: If True, only keep components touching Y-zone (default: True).

    Returns:
        Binary mask with at most one component per peak (values 0 and 255).
        A peak with no valid candidate contributes nothing to the output.

    Example:
        >>> # For projection method (require Y-zone)
        >>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
        ...                                  band_width=100, require_y_zone=True)
        >>>
        >>> # For size fallback (allow fallen shoots outside Y-zone)
        >>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
        ...                                  band_width=100, require_y_zone=False)

    Note:
        Set require_y_zone=False when using size-based fallback method to preserve
        large shoots that have fallen outside the typical vertical zone.
    """
    # Calculate Y-zone bounds, clamped to the image so a negative start cannot wrap around
    height = mask.shape[0]
    global_upper = max(0, int(global_stats['global_mean'] - std_multiplier * global_stats['global_std']))
    global_lower = min(height, int(global_stats['global_mean'] + std_multiplier * global_stats['global_std']))

    # Find all components
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    output_mask = np.zeros_like(mask)

    for peak_x in peaks:
        x_left = peak_x - band_width
        x_right = peak_x + band_width

        # Find all components in this peak's band
        candidates = []
        for i in range(1, num_labels):
            x_center = stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] // 2

            if x_left <= x_center <= x_right:
                area = stats[i, cv2.CC_STAT_AREA]

                if require_y_zone:
                    # Check if component touches Y-zone
                    component_mask = (labels == i).astype(np.uint8)
                    y_zone_pixels = np.sum(component_mask[global_upper:global_lower, :])

                    if y_zone_pixels > 0:
                        candidates.append((i, area))
                else:
                    # No Y-zone requirement - keep all
                    candidates.append((i, area))

        # Keep the largest one from this peak
        if candidates:
            best_label = max(candidates, key=lambda x: x[1])[0]
            output_mask[labels == best_label] = 255
            print(f"  Peak at X={peak_x}: keeping component {best_label} "
                  f"(area={stats[best_label, cv2.CC_STAT_AREA]})")
        else:
            print(f"  Peak at X={peak_x}: WARNING - no valid components found!")

    return output_mask


def clean_shoot_mask_pipeline(mask_path, global_stats, closing_kernel_size=7,
                              closing_iterations=5, quality_check_threshold=2.5,
                              band_width=100):
    """Complete pipeline to clean a shoot mask from raw predictions to final output.

    Runs the full processing workflow:
    1. Load mask and apply initial morphological closing
    2. Detect up to 5 shoot locations using adaptive peak finding
    3. Filter to components near detected peaks
    4. Keep at most one component per peak (largest in each band)

    Args:
        mask_path: Path to shoot mask file (string or Path object).
        global_stats: Global statistics from calculate_global_y_stats.
        closing_kernel_size: Initial closing kernel size (default: 7).
        closing_iterations: Initial closing iterations (default: 5).
        quality_check_threshold: Std threshold for quality checks (default: 2.5).
        band_width: Half-width of the X-band around each peak in pixels (default: 100).

    Returns:
        Dictionary with keys:
            - 'cleaned_mask': Final cleaned binary mask (0 and 255)
            - 'peaks': Detected peak X-coordinates
            - 'std_used': Std multiplier that succeeded
            - 'method': Detection method ("projection" or "size_fallback")
            - 'num_components': Number of components in final mask
            - 'filename': Input filename

    Raises:
        FileNotFoundError: If the mask file cannot be read.

    Example:
        >>> from pathlib import Path
        >>>
        >>> # Process all masks
        >>> mask_dir = Path('data/shoot_masks')
        >>> for mask_file in mask_dir.glob('*.png'):
        ...     result = clean_shoot_mask_pipeline(str(mask_file), global_stats)
        ...
        ...     # Save cleaned mask
        ...     output_path = f"cleaned/{result['filename']}"
        ...     cv2.imwrite(output_path, result['cleaned_mask'])
        ...
        ...     # Log results
        ...     if result['num_components'] != 5:
        ...         print(f"WARNING: {result['filename']} has "
        ...               f"{result['num_components']} components")

    Note:
        Parameters are optimized for typical plant imaging conditions:
        - Shoots at X > 1000 pixels
        - Shoots between Y = 200-750 pixels (adaptive per image)
        - ~400-500 pixel spacing between plants
    """
    print(f"Processing: {Path(mask_path).name}")

    # Step 1: Load and initial joining
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None instead of raising on a missing or unreadable file
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    joined_mask = join_shoot_fragments(mask, closing_kernel_size, closing_iterations)

    # Step 2: Find shoot locations
    peaks, x_projection, std_used, method = find_shoot_peaks_with_size_fallback(
        joined_mask, global_stats,
        quality_check_threshold=quality_check_threshold
    )
    print(f"  Found {len(peaks)} peaks using {method} (std={std_used:.2f})")

    # Step 3: Filter to narrow bands around peaks
    narrow_result = merge_with_narrow_bands(joined_mask, peaks, global_stats,
                                            std_used, band_width=band_width)

    # Step 4: Keep one component per peak
    require_y_zone = (method != "size_fallback")
    final_mask = filter_one_per_peak(narrow_result, peaks, global_stats, std_used,
                                     band_width=band_width, require_y_zone=require_y_zone)

    # Count final components
    num_labels, _, _, _ = cv2.connectedComponentsWithStats(final_mask, connectivity=8)
    num_components = num_labels - 1

    return {
        'cleaned_mask': final_mask,
        'peaks': peaks,
        'std_used': std_used,
        'method': method,
        'num_components': num_components,
        'filename': Path(mask_path).name
    }


663# Recommended parameters as a reference
664RECOMMENDED_PARAMS = {
665    'initial_closing': {
666        'kernel_size': 7,
667        'iterations': 5
668    },
669    'y_zone_boundaries': {
670        'y_min': 200,  # Shoots never above this
671        'y_max': 750   # Shoots never below this
672    },
673    'peak_detection': {
674        'n_peaks': 5,
675        'min_distance': 300,          # Minimum spacing between shoots
676        'x_min': 1000,                # Shoots never left of this
677        'initial_std': 2.0,           # Starting Y-zone width
678        'max_std': 3.5,               # Maximum Y-zone width
679        'quality_check_threshold': 2.5,  # When to validate peaks
680        'min_peak_width': 20,         # Minimum width for valid peaks
681        'min_peak_height': 10,        # Minimum height for valid peaks
682        'min_area': 500               # For size-based fallback
683    },
684    'filtering': {
685        'band_width': 100  # X-distance around peaks
686    }
687}
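
RECOMMENDED_PARAMS is a nested reference dictionary, while clean_shoot_mask_pipeline takes flat keyword arguments. The sketch below shows one possible mapping between the two. The name pipeline_kwargs is hypothetical, and only the subset of values needed by the pipeline signature is copied here so the snippet runs on its own.

```python
# Subset of RECOMMENDED_PARAMS, copied here so the sketch is self-contained.
RECOMMENDED_PARAMS = {
    'initial_closing': {'kernel_size': 7, 'iterations': 5},
    'peak_detection': {'quality_check_threshold': 2.5},
    'filtering': {'band_width': 100},
}

# Map the nested reference values onto clean_shoot_mask_pipeline keyword names.
# (pipeline_kwargs is a hypothetical helper name, not part of the library.)
pipeline_kwargs = {
    'closing_kernel_size': RECOMMENDED_PARAMS['initial_closing']['kernel_size'],
    'closing_iterations': RECOMMENDED_PARAMS['initial_closing']['iterations'],
    'quality_check_threshold': RECOMMENDED_PARAMS['peak_detection']['quality_check_threshold'],
    'band_width': RECOMMENDED_PARAMS['filtering']['band_width'],
}
print(pipeline_kwargs)

# Intended call (requires the library and a real mask file):
# result = clean_shoot_mask_pipeline(mask_path, global_stats, **pipeline_kwargs)
```

The remaining entries (y_zone_boundaries and most of peak_detection) correspond to parameters of the lower-level functions rather than of the pipeline wrapper.
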
def join_shoot_fragments(mask, kernel_size=7, iterations=5):
def join_shoot_fragments(mask, kernel_size=7, iterations=5):
    """Join small shoot fragments using morphological closing.

    Applies morphological closing to connect nearby fragments and fill small gaps
    in the segmentation mask. This preprocessing step improves peak detection by
    creating more continuous shoot structures.

    Args:
        mask: Binary mask array with values 0 and 255.
        kernel_size: Size of square structuring element (default: 7).
        iterations: Number of closing iterations (default: 5).

    Returns:
        Binary mask array with joined fragments (values 0 and 255).

    Example:
        >>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
        >>> joined = join_shoot_fragments(mask, kernel_size=5, iterations=3)
        >>> cv2.imwrite('joined_mask.png', joined)
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    joined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=iterations)
    return joined

Join small shoot fragments using morphological closing.

Applies morphological closing to connect nearby fragments and fill small gaps in the segmentation mask. This preprocessing step improves peak detection by creating more continuous shoot structures.

Arguments:
  • mask: Binary mask array with values 0 and 255.
  • kernel_size: Size of square structuring element (default: 7).
  • iterations: Number of closing iterations (default: 5).
Returns:

Binary mask array with joined fragments (values 0 and 255).

Example:
>>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
>>> joined = join_shoot_fragments(mask, kernel_size=5, iterations=3)
>>> cv2.imwrite('joined_mask.png', joined)
def calculate_y_density(mask):
def calculate_y_density(mask):
    """Calculate pixel density per Y coordinate.

    Sums the number of foreground pixels along each row to create a 1D density
    profile. Used for finding the vertical position of shoots.

    Args:
        mask: Binary mask with values 0 and 255.

    Returns:
        1D numpy array of pixel counts per row (length = mask height).

    Example:
        >>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
        >>> density = calculate_y_density(mask)
        >>> print(f"Peak density at Y={np.argmax(density)}")
    """
    return np.sum(mask > 0, axis=1)

Calculate pixel density per Y coordinate.

Sums the number of foreground pixels along each row to create a 1D density profile. Used for finding the vertical position of shoots.

Arguments:
  • mask: Binary mask with values 0 and 255.
Returns:

1D numpy array of pixel counts per row (length = mask height).

Example:
>>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
>>> density = calculate_y_density(mask)
>>> print(f"Peak density at Y={np.argmax(density)}")
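
As a small worked example of the row-wise count, the sketch below applies the same np.sum(mask > 0, axis=1) expression to a tiny mask invented for illustration.

```python
import numpy as np

# Hypothetical 6x8 mask: a 2-row band plus one stray pixel.
mask = np.zeros((6, 8), dtype=np.uint8)
mask[2:4, 1:6] = 255   # rows 2 and 3, five foreground pixels each
mask[5, 7] = 255       # single stray pixel in the last row

# Same computation as calculate_y_density: count foreground pixels per row.
density = np.sum(mask > 0, axis=1)

print(density.tolist())          # [0, 0, 5, 5, 0, 1]
print(int(np.argmax(density)))   # 2 (argmax returns the first of the tied rows)
```
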
def calculate_weighted_y_stats(density, y_min=200, y_max=750):
def calculate_weighted_y_stats(density, y_min=200, y_max=750):
    """Calculate weighted mean and standard deviation of Y positions within ROI.

    Computes the center of mass and spread of shoot pixels along the Y-axis,
    restricted to the valid Y-range. Pixels outside the ROI are ignored to
    prevent noise from biasing statistics.

    Args:
        density: 1D array of pixel counts per row from calculate_y_density.
        y_min: Minimum Y coordinate to consider (default: 200).
        y_max: Maximum Y coordinate to consider (default: 750).

    Returns:
        Dictionary with keys:
            - 'mean': Weighted mean Y position (pixels).
            - 'std': Weighted standard deviation (pixels).

    Example:
        >>> density = calculate_y_density(mask)
        >>> stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
        >>> print(f"Center of mass: Y={stats['mean']:.1f} ± {stats['std']:.1f}")
    """
    # Clip density to ROI
    roi_density = density.copy()
    roi_density[:y_min] = 0
    roi_density[y_max:] = 0

    y_positions = np.arange(len(roi_density))
    total_pixels = np.sum(roi_density)

    if total_pixels == 0:
        return {'mean': 0, 'std': 0}

    # Weighted mean (center of mass)
    weighted_mean = np.sum(y_positions * roi_density) / total_pixels

    # Weighted standard deviation
    weighted_var = np.sum(roi_density * (y_positions - weighted_mean)**2) / total_pixels
    weighted_std = np.sqrt(weighted_var)

    return {
        'mean': weighted_mean,
        'std': weighted_std
    }

Calculate weighted mean and standard deviation of Y positions within ROI.

Computes the center of mass and spread of shoot pixels along the Y-axis, restricted to the valid Y-range. Pixels outside the ROI are ignored to prevent noise from biasing statistics.

Arguments:
  • density: 1D array of pixel counts per row from calculate_y_density.
  • y_min: Minimum Y coordinate to consider (default: 200).
  • y_max: Maximum Y coordinate to consider (default: 750).
Returns:

Dictionary with keys:
  • 'mean': Weighted mean Y position (pixels).
  • 'std': Weighted standard deviation (pixels).

Example:
>>> density = calculate_y_density(mask)
>>> stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
>>> print(f"Center of mass: Y={stats['mean']:.1f} ± {stats['std']:.1f}")
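
The weighted statistics can be checked by hand on a short density profile. The sketch below repeats the same arithmetic on made-up numbers with a deliberately small ROI (y_min=1, y_max=5), so the values are illustrative only.

```python
import numpy as np

# Hypothetical density profile; the 9 at row 6 lies outside the ROI.
density = np.array([0, 4, 0, 2, 2, 0, 9])
y_min, y_max = 1, 5

# Zero out rows outside [y_min, y_max), as calculate_weighted_y_stats does.
roi = density.copy()
roi[:y_min] = 0
roi[y_max:] = 0

y = np.arange(len(roi))
total = roi.sum()                                   # 4 + 2 + 2 = 8
mean = (y * roi).sum() / total                      # (1*4 + 3*2 + 4*2) / 8
std = np.sqrt((roi * (y - mean) ** 2).sum() / total)

print(f"mean={mean:.2f}, std={std:.3f}")            # mean=2.25, std=1.299
```

Without the ROI clipping, the 9 pixels at row 6 would pull the mean from 2.25 up to about 4.2, which is the kind of bias the y_min and y_max bounds are meant to prevent.
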
def calculate_global_y_stats(mask_paths, kernel_size=7, iterations=5, y_min=200, y_max=750):
def calculate_global_y_stats(mask_paths, kernel_size=7, iterations=5, y_min=200, y_max=750):
    """Calculate global weighted mean and standard deviation across all masks.

    Processes all masks to establish global shoot position statistics. These
    statistics serve as priors for individual image processing, enabling adaptive
    Y-zone detection that handles variation in shoot positions.

    Args:
        mask_paths: List of paths (strings or Path objects) to mask files.
        kernel_size: Kernel size for initial morphological closing (default: 7).
        iterations: Number of closing iterations (default: 5).
        y_min: Minimum Y coordinate to consider (default: 200).
        y_max: Maximum Y coordinate to consider (default: 750).

    Returns:
        Dictionary with keys:
            - 'global_mean': Mean Y position across all images (pixels).
            - 'global_std': Mean standard deviation across all images (pixels).
            - 'all_stats': List of per-image statistics dictionaries.
            - 'all_means': List of individual image means.
            - 'all_stds': List of individual image standard deviations.
            - 'y_min': Y minimum boundary used.
            - 'y_max': Y maximum boundary used.

    Example:
        >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
        >>> global_stats = calculate_global_y_stats(mask_files, y_min=200, y_max=750)
        >>> print(f"Expected shoot zone: {global_stats['global_mean']:.0f} ± "
        ...       f"{2 * global_stats['global_std']:.0f} pixels")

    Note:
        This function should be called once per dataset to establish the global
        statistics used for processing all individual images.
    """
    all_stats = []

    for path in mask_paths:
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None (it does not raise) when a file cannot be read
            raise FileNotFoundError(f"Could not read mask: {path}")
        joined = join_shoot_fragments(mask, kernel_size, iterations)
        density = calculate_y_density(joined)
        stats = calculate_weighted_y_stats(density, y_min, y_max)
        all_stats.append(stats)

    # Calculate mean of means and mean of stds
    means = [s['mean'] for s in all_stats]
    stds = [s['std'] for s in all_stats]

    global_mean = np.mean(means)
    global_std = np.mean(stds)

    return {
        'global_mean': global_mean,
        'global_std': global_std,
        'all_stats': all_stats,
        'all_means': means,
        'all_stds': stds,
        'y_min': y_min,
        'y_max': y_max
    }

Calculate global weighted mean and standard deviation across all masks.

Processes all masks to establish global shoot position statistics. These statistics serve as priors for individual image processing, enabling adaptive Y-zone detection that handles variation in shoot positions.

Arguments:
  • mask_paths: List of paths (strings or Path objects) to mask files.
  • kernel_size: Kernel size for initial morphological closing (default: 7).
  • iterations: Number of closing iterations (default: 5).
  • y_min: Minimum Y coordinate to consider (default: 200).
  • y_max: Maximum Y coordinate to consider (default: 750).
Returns:

Dictionary with keys:
  • 'global_mean': Mean Y position across all images (pixels).
  • 'global_std': Mean standard deviation across all images (pixels).
  • 'all_stats': List of per-image statistics dictionaries.
  • 'all_means': List of individual image means.
  • 'all_stds': List of individual image standard deviations.
  • 'y_min': Y minimum boundary used.
  • 'y_max': Y maximum boundary used.

Example:
>>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
>>> global_stats = calculate_global_y_stats(mask_files, y_min=200, y_max=750)
>>> print(f"Expected shoot zone: {global_stats['global_mean']:.0f} ± "
...       f"{2 * global_stats['global_std']:.0f} pixels")
Note:

This function should be called once per dataset to establish the global statistics used for processing all individual images.
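
Because the global values are a plain mean of per-image means and per-image standard deviations, an image with no foreground inside the ROI contributes a mean and std of 0 and pulls both values down. The sketch below illustrates this on invented per-image numbers. The filtering variant is only a possible workaround and is not part of the library.

```python
import numpy as np

# Hypothetical per-image results in the shape returned by calculate_weighted_y_stats.
all_stats = [
    {'mean': 480.0, 'std': 60.0},
    {'mean': 520.0, 'std': 40.0},
    {'mean': 0, 'std': 0},   # empty mask: no foreground inside the ROI
]

# Aggregation as done in calculate_global_y_stats (mean of means, mean of stds).
global_mean = np.mean([s['mean'] for s in all_stats])
global_std = np.mean([s['std'] for s in all_stats])

# Possible variant (an assumption, not library behaviour): skip empty images.
non_empty = [s for s in all_stats if s['mean'] > 0 or s['std'] > 0]
filtered_mean = np.mean([s['mean'] for s in non_empty])
filtered_std = np.mean([s['std'] for s in non_empty])

print(f"all images:  {global_mean:.1f} ± {global_std:.1f}")
print(f"non-empty:   {filtered_mean:.1f} ± {filtered_std:.1f}")
```

If a dataset contains masks with failed germination across all five positions, it may be worth inspecting all_means for zeros before trusting the global statistics.
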

def calculate_x_projection(mask, global_stats, std_multiplier=2.0):
def calculate_x_projection(mask, global_stats, std_multiplier=2.0):
    """Calculate X-axis projection within the global Y-zone.

    Sums pixels along columns (Y-axis) within the vertical zone defined by global
    statistics. Creates a 1D signal showing horizontal distribution of shoot mass.

    Args:
        mask: Binary mask with values 0 and 255.
        global_stats: Global statistics from calculate_global_y_stats.
        std_multiplier: Standard deviation multiplier for Y-zone width (default: 2.0).

    Returns:
        1D numpy array of pixel counts per column (length = mask width).

    Example:
        >>> x_proj = calculate_x_projection(mask, global_stats, std_multiplier=2.0)
        >>> import matplotlib.pyplot as plt
        >>> plt.plot(x_proj)
        >>> plt.xlabel('X coordinate')
        >>> plt.ylabel('Pixel count')
        >>> plt.show()
    """
    # Calculate Y bounds ("upper" is the smaller Y, i.e. higher in the image)
    global_upper = int(global_stats['global_mean'] - std_multiplier * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_multiplier * global_stats['global_std'])

    # Clamp to the image so a negative bound cannot wrap around under NumPy slicing
    global_upper = max(global_upper, 0)
    global_lower = min(global_lower, mask.shape[0])

    # Extract Y-zone region and sum along Y axis
    zone_region = mask[global_upper:global_lower, :]
    x_projection = np.sum(zone_region > 0, axis=0)

    return x_projection

Calculate X-axis projection within the global Y-zone.

Sums pixels along columns (Y-axis) within the vertical zone defined by global statistics. Creates a 1D signal showing horizontal distribution of shoot mass.

Arguments:
  • mask: Binary mask with values 0 and 255.
  • global_stats: Global statistics from calculate_global_y_stats.
  • std_multiplier: Standard deviation multiplier for Y-zone width (default: 2.0).
Returns:

1D numpy array of pixel counts per column (length = mask width).

Example:
>>> x_proj = calculate_x_projection(mask, global_stats, std_multiplier=2.0)
>>> import matplotlib.pyplot as plt
>>> plt.plot(x_proj)
>>> plt.xlabel('X coordinate')
>>> plt.ylabel('Pixel count')
>>> plt.show()
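
The Y-zone is the band global_mean ± std_multiplier × global_std, and only pixels inside it contribute to the projection. The sketch below walks through that on a small invented mask with made-up statistics. The clamping step is included as a guard against negative bounds, which NumPy slicing would otherwise interpret as counting from the bottom of the image.

```python
import numpy as np

# Hypothetical 100x12 mask with one blob inside the zone and one far above it.
mask = np.zeros((100, 12), dtype=np.uint8)
mask[40:60, 2:4] = 255    # inside the zone: columns 2-3, 20 rows each
mask[0:10, 8:10] = 255    # above the zone: must not contribute

global_stats = {'global_mean': 50.0, 'global_std': 10.0}  # made-up values
std_multiplier = 2.0

upper = int(global_stats['global_mean'] - std_multiplier * global_stats['global_std'])
lower = int(global_stats['global_mean'] + std_multiplier * global_stats['global_std'])

# Guard: keep the bounds inside the image.
upper = max(upper, 0)
lower = min(lower, mask.shape[0])

# Count foreground pixels per column within rows [upper, lower).
x_projection = np.sum(mask[upper:lower, :] > 0, axis=0)

print((upper, lower))            # (30, 70)
print(x_projection.tolist())     # [0, 0, 20, 20, 0, 0, 0, 0, 0, 0, 0, 0]
```
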
def find_shoot_peaks(x_projection, n_peaks=5, min_distance=300, x_min=1000):
def find_shoot_peaks(x_projection, n_peaks=5, min_distance=300, x_min=1000):
    """Find peak locations in X-projection using scipy peak detection.

    Identifies the N strongest peaks in the horizontal projection that satisfy
    spacing and position constraints. Returns peaks sorted left to right.

    Args:
        x_projection: 1D array of pixel counts from calculate_x_projection.
        n_peaks: Number of peaks to find (default: 5).
        min_distance: Minimum distance between peaks in pixels (default: 300).
        x_min: Minimum X coordinate to consider (default: 1000).

    Returns:
        Numpy array of peak X-coordinates, sorted left to right. May hold fewer
        than n_peaks entries if fewer peaks are detected.

    Example:
        >>> x_proj = calculate_x_projection(mask, global_stats)
        >>> peaks = find_shoot_peaks(x_proj, n_peaks=5, min_distance=300)
        >>> print(f"Found peaks at X positions: {peaks}")
    """
    # Mask out x < x_min
    masked_projection = x_projection.copy()
    masked_projection[:x_min] = 0

    # Find peaks with minimum distance constraint
    peaks, _ = find_peaks(masked_projection, distance=min_distance)

    # Get peak heights and select top n_peaks
    peak_heights = masked_projection[peaks]
    top_indices = np.argsort(peak_heights)[-n_peaks:]
    top_peaks = peaks[top_indices]

    # Sort by position (left to right)
    top_peaks = np.sort(top_peaks)

    return top_peaks

Find peak locations in X-projection using scipy peak detection.

Identifies the N strongest peaks in the horizontal projection that satisfy spacing and position constraints. Returns peaks sorted left to right.

Arguments:
  • x_projection: 1D array of pixel counts from calculate_x_projection.
  • n_peaks: Number of peaks to find (default: 5).
  • min_distance: Minimum distance between peaks in pixels (default: 300).
  • x_min: Minimum X coordinate to consider (default: 1000).
Returns:

Numpy array of peak X-coordinates, sorted left to right. May hold fewer than n_peaks entries if fewer peaks are detected.

Example:
>>> x_proj = calculate_x_projection(mask, global_stats)
>>> peaks = find_shoot_peaks(x_proj, n_peaks=5, min_distance=300)
>>> print(f"Found peaks at X positions: {peaks}")
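
The selection logic (mask out the left region, enforce spacing, keep the strongest N, sort by position) can be traced on a toy signal. The sketch below uses scipy.signal.find_peaks on an invented projection with scaled-down parameters (x_min=20, min_distance=10, n_peaks=2), chosen only so the example stays small.

```python
import numpy as np
from scipy.signal import find_peaks

# Hypothetical projection with five isolated spikes (index: height).
x_projection = np.zeros(100)
for index, height in [(10, 4), (30, 9), (34, 6), (60, 7), (85, 5)]:
    x_projection[index] = height

x_min, min_distance, n_peaks = 20, 10, 2

# Ignore everything left of x_min (removes the spike at index 10).
masked = x_projection.copy()
masked[:x_min] = 0

# Spacing constraint: of the two peaks at 30 and 34, only the taller survives.
candidates, _ = find_peaks(masked, distance=min_distance)

# Keep the n_peaks tallest candidates, then sort them left to right.
strongest = candidates[np.argsort(masked[candidates])[-n_peaks:]]
peaks = np.sort(strongest)

print(candidates.tolist(), peaks.tolist())  # [30, 60, 85] [30, 60]
```
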
def validate_peak_quality(x_projection, peaks, min_peak_height=10, min_peak_width=20):
def validate_peak_quality(x_projection, peaks, min_peak_height=10, min_peak_width=20):
    """Check if detected peaks are likely to be real shoots versus noise.

    Validates peak quality by checking both height (signal strength) and width
    (spatial extent). Narrow spikes indicate noise, while wide peaks indicate
    actual shoots. At least 3 peaks must meet both criteria for validation to pass.

    Args:
        x_projection: 1D array of pixel counts.
        peaks: Array of peak X-coordinates from find_shoot_peaks.
        min_peak_height: Minimum height for a valid peak (default: 10).
        min_peak_width: Minimum width, measured at half the peak's prominence,
            for a valid peak (default: 20).

    Returns:
        Boolean indicating if peaks are high quality (True) or likely noise (False).

    Example:
        >>> peaks = find_shoot_peaks(x_proj, n_peaks=5)
        >>> if validate_peak_quality(x_proj, peaks):
        ...     print("High quality peaks detected")
        ... else:
        ...     print("Peaks may be noise, consider widening Y-zone")
    """
    peak_heights = x_projection[peaks]

    # Calculate peak widths at half prominence
    widths, _, _, _ = peak_widths(x_projection, peaks, rel_height=0.5)

    # Check if at least 3 peaks are both tall enough AND wide enough
    valid_peaks = np.sum((peak_heights >= min_peak_height) & (widths >= min_peak_width))

    return valid_peaks >= 3

Check if detected peaks are likely to be real shoots versus noise.

Validates peak quality by checking both height (signal strength) and width (spatial extent). Narrow spikes indicate noise, while wide peaks indicate actual shoots. At least 3 peaks must meet both criteria for validation to pass.

Arguments:
  • x_projection: 1D array of pixel counts.
  • peaks: Array of peak X-coordinates from find_shoot_peaks.
  • min_peak_height: Minimum height for a valid peak (default: 10).
  • min_peak_width: Minimum width, measured at half the peak's prominence, for a valid peak (default: 20).
Returns:

Boolean indicating if peaks are high quality (True) or likely noise (False).

Example:
>>> peaks = find_shoot_peaks(x_proj, n_peaks=5)
>>> if validate_peak_quality(x_proj, peaks):
...     print("High quality peaks detected")
... else:
...     print("Peaks may be noise, consider widening Y-zone")
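
The width criterion is what separates a noise spike from a real shoot of similar height. The sketch below applies scipy.signal.peak_widths with rel_height=0.5 to an invented signal containing a one-sample spike and a broad triangular bump, then evaluates the per-peak height and width test with the default thresholds.

```python
import numpy as np
from scipy.signal import peak_widths

# Hypothetical projection: a one-sample spike and a broad triangular bump.
x_projection = np.zeros(100)
x_projection[10] = 20                        # spike: tall but only one sample wide
idx = np.arange(20, 81)
x_projection[idx] = 30 - np.abs(idx - 50)    # bump: apex 30 at index 50, slope 1

peaks = np.array([10, 50])
heights = x_projection[peaks]

# Width of each peak at half of its prominence (first element of the result).
widths = peak_widths(x_projection, peaks, rel_height=0.5)[0]

# Per-peak criterion used by validate_peak_quality (defaults: height 10, width 20).
is_valid = (heights >= 10) & (widths >= 20)

print(widths.tolist(), is_valid.tolist())  # [1.0, 30.0] [False, True]
```

Both peaks clear the height threshold, but only the bump clears the width threshold. In validate_peak_quality the overall result is True only when at least 3 peaks pass this combined test.
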
def find_shoot_peaks_with_size_fallback( mask, global_stats, n_peaks=5, min_distance=300, x_min=1000, initial_std=2.0, max_std=3.5, std_step=0.25, min_area=500, min_peak_width=20, quality_check_threshold=2.5):
def find_shoot_peaks_with_size_fallback(mask, global_stats, n_peaks=5, min_distance=300,
                                        x_min=1000, initial_std=2.0, max_std=3.5,
                                        std_step=0.25, min_area=500, min_peak_width=20,
                                        quality_check_threshold=2.5):
    """Adaptively find shoot peaks with fallback to size-based detection.

    Three-stage detection strategy:
    1. Normal cases (std 2.0-2.25): X-projection peaks without quality checks
    2. Widened search (std 2.5-3.5): Progressive Y-zone widening with quality validation
    3. Size fallback: Largest components by area when projection methods fail

    This handles both typical shoots and edge cases like failed germination or
    shoots that have fallen outside the typical vertical zone.

    Args:
        mask: Binary mask with values 0 and 255.
        global_stats: Global statistics from calculate_global_y_stats.
        n_peaks: Target number of peaks (default: 5).
        min_distance: Minimum distance between peaks in pixels (default: 300).
        x_min: Minimum X coordinate to consider (default: 1000).
        initial_std: Starting std multiplier (default: 2.0).
        max_std: Maximum std multiplier to try (default: 3.5).
        std_step: Step size for widening (default: 0.25).
        min_area: Minimum area for size fallback in pixels (default: 500).
        min_peak_width: Minimum peak width for quality validation (default: 20).
        quality_check_threshold: Std threshold to activate quality checks (default: 2.5).

    Returns:
        Tuple of (peaks, x_projection, std_multiplier_used, method_used) where:
            - peaks: Array of peak X-coordinates
            - x_projection: The X-projection array used
            - std_multiplier_used: The std multiplier that succeeded
              (max_std when the size fallback is used)
            - method_used: Either "projection" or "size_fallback"

    Example:
        >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
        ...     mask, global_stats, quality_check_threshold=2.5
        ... )
        >>> print(f"Method: {method}, Std: {std:.2f}")
        >>> print(f"Peaks at: {peaks}")
    """
    std_multiplier = initial_std

    # Try standard X-projection approach with progressive widening
    while std_multiplier <= max_std:
        x_projection = calculate_x_projection(mask, global_stats, std_multiplier)
        peaks = find_shoot_peaks(x_projection, n_peaks, min_distance, x_min)

        if len(peaks) >= n_peaks:
            # Only check quality if we're in desperate territory (high std)
            if std_multiplier >= quality_check_threshold:
                if validate_peak_quality(x_projection, peaks, min_peak_height=10,
                                       min_peak_width=min_peak_width):
                    print(f"Found {len(peaks)} quality peaks with std_multiplier={std_multiplier:.2f}")
                    return peaks, x_projection, std_multiplier, "projection"
                else:
                    print(f"Found {len(peaks)} peaks but quality too low (std={std_multiplier:.2f})")
            else:
                # Low std values - just trust the peaks
                print(f"Found {len(peaks)} peaks with std_multiplier={std_multiplier:.2f}")
                return peaks, x_projection, std_multiplier, "projection"

        std_multiplier += std_step

    # Fallback: Use connected components filtered by size and X position
    print("Projection method insufficient, trying size-based fallback...")

    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    valid_components = []
    for i in range(1, num_labels):
        area = stats[i, cv2.CC_STAT_AREA]
        x_left = stats[i, cv2.CC_STAT_LEFT]
        width = stats[i, cv2.CC_STAT_WIDTH]
        bbox_center = x_left + width // 2

        if area >= min_area and bbox_center >= x_min:
            valid_components.append({
                'label': i,
                'bbox_center': bbox_center,
                'area': area
            })

    # Sort by area (largest first) and take top n_peaks
    valid_components.sort(key=lambda x: x['area'], reverse=True)
    selected = valid_components[:n_peaks]

    # Extract X positions and sort left to right
    peaks = np.array([c['bbox_center'] for c in selected])
    peaks = np.sort(peaks)

    # Create full projection for visualization
    x_projection = np.sum(mask > 0, axis=0)

    print(f"Found {len(peaks)} peaks using size-based fallback (min_area={min_area})")
    return peaks, x_projection, max_std, "size_fallback"

Adaptively find shoot peaks with fallback to size-based detection.

Three-stage detection strategy:

  1. Normal cases (std 2.0-2.25): X-projection peaks without quality checks
  2. Widened search (std 2.5-3.5): Progressive Y-zone widening with quality validation
  3. Size fallback: Largest components by area when projection methods fail

This handles both typical shoots and edge cases like failed germination or shoots that have fallen outside the typical vertical zone.

Arguments:
  • mask: Binary mask with values 0 and 255.
  • global_stats: Global statistics from calculate_global_y_stats.
  • n_peaks: Target number of peaks (default: 5).
  • min_distance: Minimum distance between peaks in pixels (default: 300).
  • x_min: Minimum X coordinate to consider (default: 1000).
  • initial_std: Starting std multiplier (default: 2.0).
  • max_std: Maximum std multiplier to try (default: 3.5).
  • std_step: Step size for widening (default: 0.25).
  • min_area: Minimum area for size fallback in pixels (default: 500).
  • min_peak_width: Minimum peak width for quality validation (default: 20).
  • quality_check_threshold: Std threshold to activate quality checks (default: 2.5).
Returns:

Tuple of (peaks, x_projection, std_multiplier_used, method_used) where:
  • peaks: Array of peak X-coordinates
  • x_projection: The X-projection array used
  • std_multiplier_used: The std multiplier that succeeded (max_std when the size fallback is used)
  • method_used: Either "projection" or "size_fallback"

Example:
>>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
...     mask, global_stats, quality_check_threshold=2.5
... )
>>> print(f"Method: {method}, Std: {std:.2f}")
>>> print(f"Peaks at: {peaks}")
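
The first two stages differ only in whether the quality check runs, which depends on how the std multiplier compares with quality_check_threshold. The sketch below enumerates the multipliers visited by the widening loop under the default parameters and marks which of them trigger validation. It reproduces only the loop schedule, not the peak detection itself.

```python
# Default parameters of find_shoot_peaks_with_size_fallback.
initial_std, max_std, std_step = 2.0, 3.5, 0.25
quality_check_threshold = 2.5

# Walk the same schedule as the widening loop and record, for each multiplier,
# whether validate_peak_quality would be consulted.
schedule = []
std_multiplier = initial_std
while std_multiplier <= max_std:
    schedule.append((std_multiplier, std_multiplier >= quality_check_threshold))
    std_multiplier += std_step

for value, checked in schedule:
    print(f"std={value:.2f}  quality check: {checked}")
```

With the defaults there are seven attempts: 2.00 and 2.25 accept any set of n_peaks peaks, while 2.50 through 3.50 additionally require the peaks to pass validation. If none succeeds, the size-based fallback takes over.
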
def merge_with_narrow_bands(mask, peaks, global_stats, std_multiplier, band_width=100):
def merge_with_narrow_bands(mask, peaks, global_stats, std_multiplier, band_width=100):
    """Find components near peaks and keep entire components (not clipped to bands).

    Uses narrow X-bands around detected peaks to identify which components belong
    to each shoot location. Keeps complete components rather than clipping them
    to band boundaries, preserving shoot morphology.

    Args:
        mask: Binary mask with values 0 and 255.
        peaks: Array of peak X-coordinates from find_shoot_peaks_with_size_fallback.
        global_stats: Global statistics from calculate_global_y_stats
            (not used by the current implementation).
        std_multiplier: Std multiplier used for peak detection
            (not used by the current implementation).
        band_width: Half-width of X-band for identifying components (default: 100).

    Returns:
        Binary mask with complete components near peaks (values 0 and 255).

    Example:
        >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
        ...     mask, global_stats
        ... )
        >>> filtered = merge_with_narrow_bands(mask, peaks, global_stats, std,
        ...                                    band_width=100)
    """
    # Find all connected components in the full mask
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Create empty output
    output_mask = np.zeros_like(mask)

    # For each peak, find components whose center falls in the band
    kept_labels = set()

    for i, peak_x in enumerate(peaks):
        x_left = peak_x - band_width
        x_right = peak_x + band_width

        # Check each component
        for label in range(1, num_labels):
            if label in kept_labels:
                continue  # Already assigned to another peak

            # Get component's X-center
            comp_x_left = stats[label, cv2.CC_STAT_LEFT]
            comp_width = stats[label, cv2.CC_STAT_WIDTH]
            comp_x_center = comp_x_left + comp_width // 2

            # If center is in this peak's band, keep the ENTIRE component
            if x_left <= comp_x_center <= x_right:
                output_mask[labels == label] = 255
                kept_labels.add(label)
                print(f"  Peak {i+1} at X={peak_x}: keeping component with center at X={comp_x_center}")

    return output_mask

Find components near peaks and keep entire components (not clipped to bands).

Uses narrow X-bands around detected peaks to identify which components belong to each shoot location. Keeps complete components rather than clipping them to band boundaries, preserving shoot morphology.

Arguments:
  • mask: Binary mask with values 0 and 255.
  • peaks: Array of peak X-coordinates from find_shoot_peaks_with_size_fallback.
  • global_stats: Global statistics from calculate_global_y_stats (not used by the current implementation).
  • std_multiplier: Std multiplier used for peak detection (not used by the current implementation).
  • band_width: Half-width of X-band for identifying components (default: 100).
Returns:

Binary mask with complete components near peaks (values 0 and 255).

Example:
>>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
...     mask, global_stats
... )
>>> filtered = merge_with_narrow_bands(mask, peaks, global_stats, std, 
...                                    band_width=100)
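
The assignment rule looks only at the horizontal center of each component's bounding box: a component is kept whole when that center lies within band_width of a peak, even if the component itself extends past the band. The sketch below applies the rule to invented bounding boxes, using plain tuples in place of the cv2 stats array.

```python
# Hypothetical bounding boxes as (left, width) in pixels, mirroring the
# cv2.CC_STAT_LEFT and cv2.CC_STAT_WIDTH columns used by the library.
components = {1: (1150, 120), 2: (1380, 40), 3: (1590, 260), 4: (2000, 30)}
peaks = [1200, 1700]
band_width = 100

kept = {}      # label -> peak it was assigned to
centers = {}   # label -> bounding-box center
for peak_x in peaks:
    for label, (left, width) in components.items():
        if label in kept:
            continue  # already assigned to an earlier peak
        center = left + width // 2
        centers[label] = center
        if peak_x - band_width <= center <= peak_x + band_width:
            kept[label] = peak_x

print(kept)  # {1: 1200, 3: 1700}
```

Component 3 spans X = 1590 to 1850, wider than the band of 1600 to 1800 around the second peak, yet it is kept in full because its center (1720) falls inside the band.
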
def filter_one_per_peak( mask, peaks, global_stats, std_multiplier, band_width=100, require_y_zone=True):
def filter_one_per_peak(mask, peaks, global_stats, std_multiplier, band_width=100,
                       require_y_zone=True):
    """Keep the largest component near each peak, optionally validating Y-zone.

    For each detected peak, identifies all components within the X-band and keeps
    only the largest one. Optionally filters out components that don't touch the
    Y-zone (disabled for size-fallback method to handle fallen shoots).

    Args:
        mask: Binary mask with values 0 and 255.
        peaks: Array of peak X-coordinates.
        global_stats: Global statistics from calculate_global_y_stats.
        std_multiplier: Std multiplier used for peak detection.
        band_width: Half-width for assigning components to peaks (default: 100).
        require_y_zone: If True, only keep components touching Y-zone (default: True).

    Returns:
        Binary mask with at most one component per peak (values 0 and 255).
        A peak with no valid component contributes nothing to the output.

    Example:
        >>> # For projection method (require Y-zone)
        >>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
        ...                                  band_width=100, require_y_zone=True)
        >>>
        >>> # For size fallback (allow fallen shoots outside Y-zone)
        >>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
        ...                                  band_width=100, require_y_zone=False)

    Note:
        Set require_y_zone=False when using size-based fallback method to preserve
        large shoots that have fallen outside the typical vertical zone.
    """
    # Calculate Y-zone bounds, clamped so a negative bound cannot wrap around
    global_upper = int(global_stats['global_mean'] - std_multiplier * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_multiplier * global_stats['global_std'])
    global_upper = max(global_upper, 0)
    global_lower = min(global_lower, mask.shape[0])

    # Find all components
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    output_mask = np.zeros_like(mask)

    for peak_x in peaks:
        x_left = peak_x - band_width
        x_right = peak_x + band_width

        # Find all components in this peak's band
        candidates = []
        for i in range(1, num_labels):
            x_center = stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] // 2

            if x_left <= x_center <= x_right:
                area = stats[i, cv2.CC_STAT_AREA]

                if require_y_zone:
                    # Keep only components with at least one pixel inside the Y-zone
                    if np.any(labels[global_upper:global_lower, :] == i):
                        candidates.append((i, area))
                else:
                    # No Y-zone requirement - keep all
                    candidates.append((i, area))

        # Keep the largest one from this peak
        if candidates:
            best_label = max(candidates, key=lambda x: x[1])[0]
            output_mask[labels == best_label] = 255
            print(f"  Peak at X={peak_x}: keeping component {best_label} "
                  f"(area={stats[best_label, cv2.CC_STAT_AREA]})")
        else:
            print(f"  Peak at X={peak_x}: WARNING - no valid components found!")

    return output_mask

Keep the largest component near each peak, optionally validating Y-zone.

For each detected peak, identifies all components within the X-band and keeps only the largest one. Optionally filters out components that don't touch the Y-zone (disabled for size-fallback method to handle fallen shoots).

Arguments:
  • mask: Binary mask with values 0 and 255.
  • peaks: Array of peak X-coordinates.
  • global_stats: Global statistics from calculate_global_y_stats.
  • std_multiplier: Std multiplier used for peak detection.
  • band_width: Half-width for assigning components to peaks (default: 100).
  • require_y_zone: If True, only keep components touching Y-zone (default: True).
Returns:

Binary mask with at most one component per peak (values 0 and 255). A peak with no valid candidate contributes nothing to the mask, and a warning is printed.

Example:
>>> # For projection method (require Y-zone)
>>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
...                                  band_width=100, require_y_zone=True)
>>> 
>>> # For size fallback (allow fallen shoots outside Y-zone)
>>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
...                                  band_width=100, require_y_zone=False)
Note:

Set require_y_zone=False when using size-based fallback method to preserve large shoots that have fallen outside the typical vertical zone.
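
The selection rule can be illustrated without OpenCV. The following is a minimal sketch, not the library's implementation: components are plain dictionaries (a hypothetical stand-in for the stats rows returned by cv2.connectedComponentsWithStats), and pick_largest_per_peak, y_upper, and y_lower are illustrative names. It tests the Y-zone by bounding-box overlap, which for a connected component should agree with the pixel-level test, because a connected component has at least one pixel in every row of its bounding box.

```python
def pick_largest_per_peak(components, peaks, band_width=100,
                          require_y_zone=True, y_upper=200, y_lower=750):
    """Return {peak_x: component label or None}, largest area per X-band.

    Hypothetical helper for illustration; each component is a dict with
    keys 'label', 'left', 'width', 'top', 'height', and 'area'.
    """
    chosen = {}
    for peak_x in peaks:
        candidates = []
        for comp in components:
            x_center = comp["left"] + comp["width"] // 2
            if not (peak_x - band_width <= x_center <= peak_x + band_width):
                continue
            # Rows covered by the component: [top, bottom). The Y-zone
            # covers rows [y_upper, y_lower), as in the mask slice.
            top, bottom = comp["top"], comp["top"] + comp["height"]
            touches_zone = top < y_lower and bottom > y_upper
            if require_y_zone and not touches_zone:
                continue
            candidates.append(comp)
        best = max(candidates, key=lambda c: c["area"], default=None)
        chosen[peak_x] = best["label"] if best else None
    return chosen


# Made-up components for illustration only.
components = [
    {"label": 1, "left": 1180, "width": 40, "top": 300, "height": 200, "area": 5000},
    {"label": 2, "left": 1210, "width": 20, "top": 320, "height": 30, "area": 400},
    # A fallen shoot: large, but entirely below the Y-zone.
    {"label": 3, "left": 1650, "width": 300, "top": 800, "height": 60, "area": 9000},
]
peaks = [1200, 1800]

strict = pick_largest_per_peak(components, peaks, require_y_zone=True)
relaxed = pick_largest_per_peak(components, peaks, require_y_zone=False)
print(strict)   # {1200: 1, 1800: None}
print(relaxed)  # {1200: 1, 1800: 3}
```

With the Y-zone check enabled, the fallen shoot at X=1800 is rejected and that peak is left empty. With the check disabled, it is kept, which is the behavior the size-fallback method relies on.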

def clean_shoot_mask_pipeline(mask_path, global_stats, closing_kernel_size=7, closing_iterations=5, quality_check_threshold=2.5, band_width=100):
def clean_shoot_mask_pipeline(mask_path, global_stats, closing_kernel_size=7,
                              closing_iterations=5, quality_check_threshold=2.5,
                              band_width=100):
    """Complete pipeline to clean a shoot mask from raw predictions to final output.

    Runs the full processing workflow:
    1. Load mask and apply initial morphological closing
    2. Detect 5 shoot locations using adaptive peak finding
    3. Filter to components near detected peaks
    4. Keep at most one component per peak (largest in each band)

    Args:
        mask_path: Path to shoot mask file (string or Path object).
        global_stats: Global statistics from calculate_global_y_stats.
        closing_kernel_size: Initial closing kernel size (default: 7).
        closing_iterations: Initial closing iterations (default: 5).
        quality_check_threshold: Std threshold for quality checks (default: 2.5).
        band_width: X-band half-width around peaks in pixels (default: 100).

    Returns:
        Dictionary with keys:
            - 'cleaned_mask': Final cleaned binary mask (0 and 255)
            - 'peaks': Detected peak X-coordinates
            - 'std_used': Std multiplier that succeeded
            - 'method': Detection method ("projection" or "size_fallback")
            - 'num_components': Number of components in final mask
            - 'filename': Input filename

    Example:
        >>> from pathlib import Path
        >>> 
        >>> # Process all masks
        >>> mask_dir = Path('data/shoot_masks')
        >>> for mask_file in mask_dir.glob('*.png'):
        ...     result = clean_shoot_mask_pipeline(str(mask_file), global_stats)
        ...     
        ...     # Save cleaned mask
        ...     output_path = f"cleaned/{result['filename']}"
        ...     cv2.imwrite(output_path, result['cleaned_mask'])
        ...     
        ...     # Log results
        ...     if result['num_components'] != 5:
        ...         print(f"WARNING: {result['filename']} has "
        ...               f"{result['num_components']} components")

    Note:
        Parameters are optimized for typical plant imaging conditions:
        - Shoots at X > 1000 pixels
        - Shoots between Y = 200-750 pixels (adaptive per image)
        - ~400-500 pixel spacing between plants
    """
    print(f"Processing: {Path(mask_path).name}")

    # Step 1: Load and initial joining
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when the file cannot be read
    if mask is None:
        raise OSError(f"Could not read mask image: {mask_path}")
    joined_mask = join_shoot_fragments(mask, closing_kernel_size, closing_iterations)

    # Step 2: Find shoot locations
    peaks, x_projection, std_used, method = find_shoot_peaks_with_size_fallback(
        joined_mask, global_stats,
        quality_check_threshold=quality_check_threshold
    )
    print(f"  Found {len(peaks)} peaks using {method} (std={std_used:.2f})")

    # Step 3: Filter to narrow bands around peaks
    narrow_result = merge_with_narrow_bands(joined_mask, peaks, global_stats,
                                            std_used, band_width=band_width)

    # Step 4: Keep one component per peak
    require_y_zone = (method != "size_fallback")
    final_mask = filter_one_per_peak(narrow_result, peaks, global_stats, std_used,
                                     band_width=band_width, require_y_zone=require_y_zone)

    # Count final components (label 0 is the background)
    num_labels, _, _, _ = cv2.connectedComponentsWithStats(final_mask, connectivity=8)
    num_components = num_labels - 1

    return {
        'cleaned_mask': final_mask,
        'peaks': peaks,
        'std_used': std_used,
        'method': method,
        'num_components': num_components,
        'filename': Path(mask_path).name
    }

Complete pipeline to clean a shoot mask from raw predictions to final output.

Runs the full processing workflow:

  1. Load mask and apply initial morphological closing
  2. Detect 5 shoot locations using adaptive peak finding
  3. Filter to components near detected peaks
  4. Keep at most one component per peak (largest in each band)
Arguments:
  • mask_path: Path to shoot mask file (string or Path object).
  • global_stats: Global statistics from calculate_global_y_stats.
  • closing_kernel_size: Initial closing kernel size (default: 7).
  • closing_iterations: Initial closing iterations (default: 5).
  • quality_check_threshold: Std threshold for quality checks (default: 2.5).
  • band_width: X-band half-width around peaks in pixels (default: 100).
Returns:

Dictionary with keys:

  • 'cleaned_mask': Final cleaned binary mask (0 and 255)
  • 'peaks': Detected peak X-coordinates
  • 'std_used': Std multiplier that succeeded
  • 'method': Detection method ("projection" or "size_fallback")
  • 'num_components': Number of components in final mask
  • 'filename': Input filename
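
The returned dictionary lends itself to simple batch reporting. The sketch below assumes a list of result dictionaries with the keys documented here. summarize_results is a hypothetical helper, not part of the module, and treating every "size_fallback" result as worth a manual look is an assumption made for the example.

```python
from collections import Counter


def summarize_results(results, expected=5):
    """Count detection methods and list files needing manual review.

    Hypothetical helper; `results` is a list of dictionaries shaped like
    the return value of clean_shoot_mask_pipeline.
    """
    method_counts = Counter(r["method"] for r in results)
    needs_review = [
        r["filename"] for r in results
        if r["num_components"] != expected or r["method"] == "size_fallback"
    ]
    return method_counts, needs_review


# Made-up results for illustration only; not real pipeline output.
results = [
    {"filename": "a.png", "method": "projection", "num_components": 5},
    {"filename": "b.png", "method": "size_fallback", "num_components": 5},
    {"filename": "c.png", "method": "projection", "num_components": 4},
]
method_counts, needs_review = summarize_results(results)
print(dict(method_counts))  # {'projection': 2, 'size_fallback': 1}
print(needs_review)         # ['b.png', 'c.png']
```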

Example:
>>> from pathlib import Path
>>> 
>>> # Process all masks
>>> mask_dir = Path('data/shoot_masks')
>>> for mask_file in mask_dir.glob('*.png'):
...     result = clean_shoot_mask_pipeline(str(mask_file), global_stats)
...     
...     # Save cleaned mask
...     output_path = f"cleaned/{result['filename']}"
...     cv2.imwrite(output_path, result['cleaned_mask'])
...     
...     # Log results
...     if result['num_components'] != 5:
...         print(f"WARNING: {result['filename']} has "
...               f"{result['num_components']} components")
Note:

Parameters are optimized for typical plant imaging conditions:

  • Shoots at X > 1000 pixels
  • Shoots between Y = 200-750 pixels (adaptive per image)
  • ~400-500 pixel spacing between plants
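
These assumptions can be checked against the 'peaks' entry of a result before trusting it. The following is a rough sketch only: check_peak_spacing is a hypothetical helper, and its thresholds (X > 1000 from the note above, 300 pixels from the recommended min_distance) are assumptions, not checks the pipeline itself performs.

```python
def check_peak_spacing(peaks, min_spacing=300, x_min=1000):
    """Return human-readable warnings for peaks that look implausible.

    Hypothetical helper; thresholds are assumptions based on the
    documented imaging conditions, not values enforced by the pipeline.
    """
    warnings = []
    ordered = sorted(int(p) for p in peaks)
    # Flag peaks at or left of the expected shoot region.
    for x in ordered:
        if x <= x_min:
            warnings.append(f"peak at X={x} is not right of X={x_min}")
    # Flag neighbouring peaks that sit closer than the minimum spacing.
    for left, right in zip(ordered, ordered[1:]):
        gap = right - left
        if gap < min_spacing:
            warnings.append(f"peaks at X={left} and X={right} are only {gap} px apart")
    return warnings


# Made-up peak positions for illustration only.
ok = check_peak_spacing([1200, 1650, 2100, 2550, 3000])
bad = check_peak_spacing([900, 1200, 1350])
print(ok)   # []
print(bad)
```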