library.shoot_mask_cleaning
Shoot mask cleaning pipeline for plant root analysis.
This module provides functions for cleaning and filtering shoot segmentation masks predicted by a U-Net model. The pipeline targets the five shoot locations expected per image. It uses adaptive Y-zone detection and X-axis peak finding, and handles edge cases such as failed germination and fallen shoots.
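The X-axis peak finding step can be illustrated with a small sketch. It assumes an X-projection (foreground pixel counts per column) is already available. The function does three things:

- It zeroes every column left of x_min.
- It visits local maxima from tallest to shortest, discarding any that lie within min_distance of a peak already kept.
- It returns up to n_peaks positions, sorted left to right.

`pick_shoot_peaks` is a hypothetical name. The greedy loop is a numpy-only stand-in for `scipy.signal.find_peaks(..., distance=min_distance)`, not the module's code, and its plateau and tie handling is simpler than SciPy's.

```python
import numpy as np


def pick_shoot_peaks(x_projection, n_peaks=5, min_distance=300, x_min=1000):
    """Hypothetical helper: up to n_peaks well-separated maxima, sorted left to right."""
    proj = np.asarray(x_projection, dtype=float).copy()
    proj[:x_min] = 0  # shoots are assumed never to lie left of x_min

    # Local maxima: higher than the left neighbour, not lower than the right one.
    idx = np.arange(1, len(proj) - 1)
    is_max = (proj[idx] > proj[idx - 1]) & (proj[idx] >= proj[idx + 1]) & (proj[idx] > 0)
    candidates = idx[is_max]

    # Visit candidates tallest-first; reject any closer than min_distance to a kept peak.
    kept = []
    for c in candidates[np.argsort(-proj[candidates], kind="stable")]:
        if all(abs(int(c) - k) >= min_distance for k in kept):
            kept.append(int(c))
        if len(kept) == n_peaks:
            break
    return np.sort(np.array(kept, dtype=int))


# Toy projection (made-up numbers): triangular bumps given as (center, height, half_width).
bumps = [(500, 100, 30), (1200, 50, 30), (1300, 45, 20), (1600, 50, 30),
         (2000, 50, 30), (2400, 50, 30), (2800, 40, 30)]
x_projection = np.zeros(3000)
for center, height, half_width in bumps:
    offsets = np.arange(-half_width, half_width + 1)
    x_projection[center + offsets] = height - np.abs(offsets)

peaks = pick_shoot_peaks(x_projection)
print(peaks.tolist())  # [1200, 1600, 2000, 2400, 2800]
```

Two rules shape the result in this toy projection:

- The tall bump at X=500 is ignored because it lies left of x_min.
- The bump at X=1300 is taller than the one at X=2800, but it sits within 300 pixels of the X=1200 peak, so the spacing rule discards it.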
Typical workflow:
Step 1: Calculate global statistics across all masks (one-time setup)
>>> from pathlib import Path
>>> import cv2
>>> from shoot_mask_cleaning import calculate_global_y_stats
>>>
>>> mask_dir = Path('data/shoot_masks')
>>> mask_paths = [str(f) for f in sorted(mask_dir.glob('*.png'))]
>>>
>>> # Use the same closing settings as clean_shoot_mask_pipeline (7 and 5)
>>> global_stats = calculate_global_y_stats(
...     mask_paths,
...     kernel_size=7,
...     iterations=5,
...     y_min=200,
...     y_max=750
... )
>>> print(f"Global mean Y: {global_stats['global_mean']:.1f}")
>>> print(f"Global std Y: {global_stats['global_std']:.1f}")
Step 2: Process individual masks
>>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
>>>
>>> # cv2.imwrite does not create missing directories
>>> Path('cleaned').mkdir(exist_ok=True)
>>> for mask_path in mask_paths:
...     result = clean_shoot_mask_pipeline(mask_path, global_stats)
...
...     # Save cleaned mask
...     output_path = f"cleaned/{result['filename']}"
...     cv2.imwrite(output_path, result['cleaned_mask'])
...
...     # Check results
...     print(f"{result['filename']}: {result['method']}, "
...           f"{result['num_components']} components")
Step 3: Debug problematic images
>>> from shoot_mask_visualization import debug_peak_detection
>>>
>>> # Visualize complete pipeline for troubleshooting
>>> debug_peak_detection('data/shoot_masks/problem_image.png', global_stats)
Recommended parameters (validated on 19 test images):
Initial closing:
- kernel_size: 7
- iterations: 5
Y-zone boundaries:
- y_min: 200 (shoots never above this)
- y_max: 750 (shoots never below this)
Peak detection:
- n_peaks: 5 (expected number of plants)
- min_distance: 300 (minimum spacing between shoots)
- x_min: 1000 (shoots never left of this)
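- initial_std: 2.0 (starting Y-zone half-width, in multiples of the global std)
- max_std: 3.5 (widest Y-zone half-width tried before the size-based fallback)
- quality_check_threshold: 2.5 (std multiplier at or above which peak quality is validated)
- min_peak_width: 20 (minimum peak width in pixels, measured at half prominence)
- min_peak_height: 10 (minimum height for valid peaks)
- min_area: 500 (minimum component area in pixels for the size-based fallback)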
Filtering:
- band_width: 100 (half-width, in pixels, of the X-band around each peak)
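Two of these parameters interact in a way that is easy to misread:

- The Y-zone is global_mean ± std_multiplier x global_std.
- The multiplier starts at initial_std and grows by std_step (0.25 by default in find_shoot_peaks_with_size_fallback) up to max_std (default 3.5).
- Peak-quality validation is switched on once the multiplier reaches quality_check_threshold.

The sketch below reproduces that arithmetic. The helper names y_zone and std_schedule are hypothetical, as are the example statistics.

```python
def y_zone(global_mean, global_std, std_multiplier):
    """Row bounds (upper, lower) of the Y-zone, used as mask[upper:lower].

    int() truncates the same way the module's Y-zone code does.
    """
    upper = int(global_mean - std_multiplier * global_std)
    lower = int(global_mean + std_multiplier * global_std)
    return upper, lower


def std_schedule(initial_std=2.0, max_std=3.5, std_step=0.25, quality_check_threshold=2.5):
    """(multiplier, quality_checked) pairs in the order the widening loop would try them."""
    schedule = []
    m = initial_std
    while m <= max_std:
        schedule.append((m, m >= quality_check_threshold))
        m += std_step
    return schedule


# Hypothetical global statistics, for illustration only.
for m, checked in std_schedule():
    upper, lower = y_zone(global_mean=475.0, global_std=60.0, std_multiplier=m)
    print(f"std={m:.2f}  rows {upper}-{lower}  quality check: {checked}")
```

With these made-up statistics, the zone grows from rows 355-595 to rows 265-685. Only the two narrowest settings (2.0 and 2.25) accept detected peaks without a quality check.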
Dependencies:
- numpy
- opencv-python (cv2)
- scipy (for peak detection)
Author: Aaron Ciuffo
Date: December 2024
def join_shoot_fragments(mask, kernel_size=7, iterations=5):
    """Join small shoot fragments using morphological closing.

    Applies morphological closing to connect nearby fragments and fill small gaps
    in the segmentation mask. This preprocessing step improves peak detection by
    creating more continuous shoot structures.

    Args:
        mask: Binary mask array with values 0 and 255.
        kernel_size: Size of square structuring element (default: 7).
        iterations: Number of closing iterations (default: 5).

    Returns:
        Binary mask array with joined fragments (values 0 and 255).

    Example:
        >>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
        >>> joined = join_shoot_fragments(mask, kernel_size=5, iterations=3)
        >>> cv2.imwrite('joined_mask.png', joined)
    """
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    joined = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=iterations)
    return joined
Join small shoot fragments using morphological closing.
Applies morphological closing to connect nearby fragments and fill small gaps in the segmentation mask. This preprocessing step improves peak detection by creating more continuous shoot structures.
Arguments:
- mask: Binary mask array with values 0 and 255.
- kernel_size: Size of square structuring element (default: 7).
- iterations: Number of closing iterations (default: 5).
Returns:
Binary mask array with joined fragments (values 0 and 255).
Example:
>>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
>>> joined = join_shoot_fragments(mask, kernel_size=5, iterations=3)
>>> cv2.imwrite('joined_mask.png', joined)
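To make the effect of closing concrete without an image file, here is a numpy-only sketch of binary closing with a square kernel, applied to a synthetic mask. It is an illustration, not the module's code, which calls cv2.morphologyEx with cv2.MORPH_CLOSE.

- close_mask, _dilate, _erode and two_fragments are hypothetical helpers.
- The border handling is an assumption chosen to mimic OpenCV's defaults: pixels outside the image count as background for dilation and as foreground for erosion.
- kernel_size is assumed to be odd.

```python
import numpy as np


def _dilate(fg, k):
    """Binary dilation with a k x k square; pixels outside the image count as background."""
    r = k // 2
    padded = np.pad(fg, r, mode="constant", constant_values=False)
    h, w = fg.shape
    out = np.zeros_like(fg)
    for dy in range(k):
        for dx in range(k):
            out |= padded[dy:dy + h, dx:dx + w]
    return out


def _erode(fg, k):
    """Binary erosion with a k x k square; pixels outside the image count as foreground."""
    r = k // 2
    padded = np.pad(fg, r, mode="constant", constant_values=True)
    h, w = fg.shape
    out = np.ones_like(fg)
    for dy in range(k):
        for dx in range(k):
            out &= padded[dy:dy + h, dx:dx + w]
    return out


def close_mask(mask, kernel_size=3, iterations=1):
    """Hypothetical numpy-only closing: all dilations first, then all erosions."""
    fg = mask > 0
    for _ in range(iterations):
        fg = _dilate(fg, kernel_size)
    for _ in range(iterations):
        fg = _erode(fg, kernel_size)
    return np.where(fg, 255, 0).astype(np.uint8)


def two_fragments(gap):
    """9 x 20 mask with two 3-row blocks separated by `gap` empty columns."""
    mask = np.zeros((9, 20), dtype=np.uint8)
    mask[3:6, 2:8] = 255
    mask[3:6, 8 + gap:16] = 255
    return mask


print(close_mask(two_fragments(2))[4, 8:10].tolist())  # [255, 255]  (gap of 2 bridged)
print(close_mask(two_fragments(3))[4, 8:11].tolist())  # [0, 0, 0]   (gap of 3 remains)
```

This suggests a rule of thumb, which should be confirmed against OpenCV on real masks:

- One pass with a k x k square bridges gaps of up to k - 1 pixels.
- OpenCV's MORPH_CLOSE with iterations=n runs n dilations followed by n erosions. The recommended kernel_size=7, iterations=5 should therefore bridge gaps of roughly 5 x 6 = 30 pixels.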
def calculate_y_density(mask):
    """Calculate pixel density per Y coordinate.

    Sums the number of foreground pixels along each row to create a 1D density
    profile. Used for finding the vertical position of shoots.

    Args:
        mask: Binary mask with values 0 and 255.

    Returns:
        1D numpy array of pixel counts per row (length = mask height).

    Example:
        >>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
        >>> density = calculate_y_density(mask)
        >>> print(f"Peak density at Y={np.argmax(density)}")
    """
    return np.sum(mask > 0, axis=1)
Calculate pixel density per Y coordinate.
Sums the number of foreground pixels along each row to create a 1D density profile. Used for finding the vertical position of shoots.
Arguments:
- mask: Binary mask with values 0 and 255.
Returns:
1D numpy array of pixel counts per row (length = mask height).
Example:
>>> mask = cv2.imread('shoot_mask.png', cv2.IMREAD_GRAYSCALE)
>>> density = calculate_y_density(mask)
>>> print(f"Peak density at Y={np.argmax(density)}")
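The density profile counts foreground pixels, not intensity. Because the masks use 0/255, thresholding with mask > 0 before summing matters: a plain row sum would scale every count by 255. Here is a small sketch on a made-up 4 x 5 mask. The values are illustrative only.

```python
import numpy as np

# Made-up 4 x 5 mask in the 0/255 encoding used by the shoot masks.
mask = np.array([
    [0,   0,   0,   0, 0],
    [0, 255, 255,   0, 0],
    [0, 255, 255, 255, 0],
    [0,   0,   0,   0, 0],
], dtype=np.uint8)

density = np.sum(mask > 0, axis=1)          # foreground pixels per row
raw_sum = mask.sum(axis=1, dtype=np.int64)  # intensity sum: 255x larger

print(density.tolist())         # [0, 2, 3, 0]
print(raw_sum.tolist())         # [0, 510, 765, 0]
print(int(np.argmax(density)))  # 2
```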
def calculate_weighted_y_stats(density, y_min=200, y_max=750):
    """Calculate weighted mean and standard deviation of Y positions within ROI.

    Computes the center of mass and spread of shoot pixels along the Y-axis,
    restricted to the valid Y-range. Pixels outside the ROI are ignored to
    prevent noise from biasing statistics.

    Args:
        density: 1D array of pixel counts per row from calculate_y_density.
        y_min: Minimum Y coordinate to consider (default: 200).
        y_max: Maximum Y coordinate to consider (default: 750).

    Returns:
        Dictionary with keys:
            - 'mean': Weighted mean Y position (pixels).
            - 'std': Weighted standard deviation (pixels).

    Example:
        >>> density = calculate_y_density(mask)
        >>> stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
        >>> print(f"Center of mass: Y={stats['mean']:.1f} ± {stats['std']:.1f}")
    """
    # Clip density to ROI
    roi_density = density.copy()
    roi_density[:y_min] = 0
    roi_density[y_max:] = 0

    y_positions = np.arange(len(roi_density))
    total_pixels = np.sum(roi_density)

    if total_pixels == 0:
        return {'mean': 0, 'std': 0}

    # Weighted mean (center of mass)
    weighted_mean = np.sum(y_positions * roi_density) / total_pixels

    # Weighted standard deviation
    weighted_var = np.sum(roi_density * (y_positions - weighted_mean)**2) / total_pixels
    weighted_std = np.sqrt(weighted_var)

    return {
        'mean': weighted_mean,
        'std': weighted_std
    }
Calculate weighted mean and standard deviation of Y positions within ROI.
Computes the center of mass and spread of shoot pixels along the Y-axis, restricted to the valid Y-range. Pixels outside the ROI are ignored to prevent noise from biasing statistics.
Arguments:
- density: 1D array of pixel counts per row from calculate_y_density.
- y_min: Minimum Y coordinate to consider (default: 200).
- y_max: Maximum Y coordinate to consider (default: 750).
Returns:
Dictionary with keys:
- 'mean': Weighted mean Y position (pixels).
- 'std': Weighted standard deviation (pixels).

Both values are 0 if no foreground pixels fall inside the ROI.
Example:
```python
>>> density = calculate_y_density(mask)
>>> stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
>>> print(f"Center of mass: Y={stats['mean']:.1f} ± {stats['std']:.1f}")
```
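The following worked example makes the ROI clipping concrete. It recomputes the same weighted statistics with np.average on a toy density profile. weighted_y_stats is a hypothetical standalone helper, not part of the module. The clipping zeroes every row from y_max onward, so rows y_min through y_max - 1 contribute. The standard deviation is the population (biased) form, as in the code above.

```python
import numpy as np


def weighted_y_stats(density, y_min, y_max):
    # Hypothetical standalone version of calculate_weighted_y_stats.
    # Zero out rows outside [y_min, y_max) so they carry no weight.
    roi = np.asarray(density, dtype=float).copy()
    roi[:y_min] = 0
    roi[y_max:] = 0
    if roi.sum() == 0:
        return {'mean': 0, 'std': 0}
    y = np.arange(len(roi))
    mean = np.average(y, weights=roi)                        # center of mass
    std = np.sqrt(np.average((y - mean) ** 2, weights=roi))  # population std
    return {'mean': float(mean), 'std': float(std)}


density = np.zeros(10)
density[1] = 100  # above y_min: ignored
density[3] = 1
density[5] = 1
density[9] = 50   # at or beyond y_max: ignored
stats = weighted_y_stats(density, y_min=2, y_max=8)
print(stats)  # {'mean': 4.0, 'std': 1.0}
```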
```python
def calculate_global_y_stats(mask_paths, kernel_size=7, iterations=5, y_min=200, y_max=750):
    """Calculate global weighted mean and standard deviation across all masks."""
    all_stats = []

    for path in mask_paths:
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f"Could not read mask: {path}")
        joined = join_shoot_fragments(mask, kernel_size, iterations)
        density = calculate_y_density(joined)
        stats = calculate_weighted_y_stats(density, y_min, y_max)
        all_stats.append(stats)

    # Calculate mean of means and mean of stds
    means = [s['mean'] for s in all_stats]
    stds = [s['std'] for s in all_stats]

    global_mean = np.mean(means)
    global_std = np.mean(stds)

    return {
        'global_mean': global_mean,
        'global_std': global_std,
        'all_stats': all_stats,
        'all_means': means,
        'all_stds': stds,
        'y_min': y_min,
        'y_max': y_max
    }
```
Calculate global weighted mean and standard deviation across all masks.
Processes all masks to establish global shoot position statistics. These statistics serve as priors for individual image processing, enabling adaptive Y-zone detection that handles variation in shoot positions.
Arguments:
- mask_paths: List of paths (strings or Path objects) to mask files.
- kernel_size: Kernel size for initial morphological closing (default: 7).
- iterations: Number of closing iterations (default: 5).
- y_min: Minimum Y coordinate to consider (default: 200).
- y_max: Maximum Y coordinate to consider (default: 750).
Returns:
Dictionary with keys:
- 'global_mean': Mean Y position across all images (pixels).
- 'global_std': Mean standard deviation across all images (pixels).
- 'all_stats': List of per-image statistics dictionaries.
- 'all_means': List of individual image means.
- 'all_stds': List of individual image standard deviations.
- 'y_min': Y minimum boundary used.
- 'y_max': Y maximum boundary used.
Example:
```python
>>> from pathlib import Path
>>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
>>> global_stats = calculate_global_y_stats(mask_files, y_min=200, y_max=750)
>>> print(f"Expected shoot zone: {global_stats['global_mean']:.0f} ± "
...       f"{2 * global_stats['global_std']:.0f} pixels")
```
Note:
This function should be called once per dataset to establish the global statistics used for processing all individual images.
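The aggregation is a plain mean of the per-image means and of the per-image standard deviations. It is not a pooled standard deviation. One consequence is worth checking: calculate_weighted_y_stats returns {'mean': 0, 'std': 0} for a mask with no pixels in the ROI, and such entries pull the global mean toward zero.

The sketch below shows the aggregation on made-up per-image values, then derives the Y-zone at std_multiplier 2.0. aggregate_y_stats and its skip_empty flag are hypothetical and not part of this module.

```python
import numpy as np


def aggregate_y_stats(per_image_stats, skip_empty=False):
    # per_image_stats: list of {'mean': ..., 'std': ...} dicts, one per mask.
    # skip_empty is a hypothetical option (not in calculate_global_y_stats):
    # it drops entries with mean == 0, which empty-ROI masks produce.
    stats = [s for s in per_image_stats if not (skip_empty and s['mean'] == 0)]
    means = [s['mean'] for s in stats]
    stds = [s['std'] for s in stats]
    return {'global_mean': float(np.mean(means)), 'global_std': float(np.mean(stds))}


# Made-up values for illustration only
per_image = [{'mean': 450.0, 'std': 40.0},
             {'mean': 470.0, 'std': 60.0},
             {'mean': 0, 'std': 0}]  # e.g. an empty mask

biased = aggregate_y_stats(per_image)               # mean pulled toward 0
g = aggregate_y_stats(per_image, skip_empty=True)   # {'global_mean': 460.0, 'global_std': 50.0}

# Y-zone bounds as computed in calculate_x_projection with std_multiplier=2.0
upper = int(g['global_mean'] - 2.0 * g['global_std'])  # 360
lower = int(g['global_mean'] + 2.0 * g['global_std'])  # 560
```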
```python
def calculate_x_projection(mask, global_stats, std_multiplier=2.0):
    """Calculate X-axis projection within the global Y-zone."""
    # Calculate Y bounds; clamp the upper bound at 0 so a negative value
    # cannot wrap around as a negative slice index
    global_upper = max(0, int(global_stats['global_mean'] - std_multiplier * global_stats['global_std']))
    global_lower = int(global_stats['global_mean'] + std_multiplier * global_stats['global_std'])

    # Extract Y-zone region and count foreground pixels in each column
    zone_region = mask[global_upper:global_lower, :]
    x_projection = np.sum(zone_region > 0, axis=0)

    return x_projection
```
Calculate X-axis projection within the global Y-zone.
Counts the foreground pixels in each column (summing along the Y-axis), restricted to the vertical zone global_mean ± std_multiplier × global_std. The result is a 1D signal showing the horizontal distribution of shoot mass.
Arguments:
- mask: Binary mask with values 0 and 255.
- global_stats: Global statistics from calculate_global_y_stats.
- std_multiplier: Standard deviation multiplier for Y-zone width (default: 2.0).
Returns:
1D numpy array of pixel counts per column (length = mask width).
Example:
```python
>>> x_proj = calculate_x_projection(mask, global_stats, std_multiplier=2.0)
>>> import matplotlib.pyplot as plt
>>> plt.plot(x_proj)
>>> plt.xlabel('X coordinate')
>>> plt.ylabel('Pixel count')
>>> plt.show()
```
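Here is a self-contained sketch of the projection on a toy mask. It assumes the Y-zone is global_mean ± std_multiplier × global_std, as in calculate_x_projection. The sketch clamps the upper bound at 0. In NumPy a negative start index counts from the end of the array, so an unclamped bound below 0 would select the wrong rows, often none at all. x_projection_sketch is a hypothetical name.

```python
import numpy as np


def x_projection_sketch(mask, global_mean, global_std, std_multiplier=2.0):
    # Y-zone rows [upper, lower), clamped to the image height.
    h = mask.shape[0]
    upper = max(0, int(global_mean - std_multiplier * global_std))
    lower = min(h, int(global_mean + std_multiplier * global_std))
    # Count foreground pixels per column inside the zone only.
    return np.sum(mask[upper:lower, :] > 0, axis=0)


mask = np.zeros((20, 6), dtype=np.uint8)
mask[8:12, 1] = 255  # 4 pixels in column 1, inside the zone (rows 6-13)
mask[2, 4] = 255     # 1 pixel in column 4, above the zone
proj = x_projection_sketch(mask, global_mean=10, global_std=2, std_multiplier=2.0)
print(proj.tolist())  # [0, 4, 0, 0, 0, 0]

# Zone would start at row -2; clamping keeps rows 0-5 instead of an empty slice
near_top = x_projection_sketch(mask, global_mean=2, global_std=2)
print(near_top.tolist())  # [0, 0, 0, 0, 1, 0]
```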
```python
def find_shoot_peaks(x_projection, n_peaks=5, min_distance=300, x_min=1000):
    """Find peak locations in X-projection using scipy peak detection."""
    # Mask out x < x_min
    masked_projection = x_projection.copy()
    masked_projection[:x_min] = 0

    # Find peaks with minimum distance constraint
    peaks, _ = find_peaks(masked_projection, distance=min_distance)

    # Get peak heights and select top n_peaks
    peak_heights = masked_projection[peaks]
    top_indices = np.argsort(peak_heights)[-n_peaks:]
    top_peaks = peaks[top_indices]

    # Sort by position (left to right)
    top_peaks = np.sort(top_peaks)

    return top_peaks
```
Find peak locations in X-projection using scipy peak detection.
Finds local maxima that are at least min_distance apart, ignores columns left of x_min, and keeps the n_peaks tallest. If fewer peaks exist, fewer are returned. Peaks are returned sorted left to right.
Arguments:
- x_projection: 1D array of pixel counts from calculate_x_projection.
- n_peaks: Number of peaks to find (default: 5).
- min_distance: Minimum distance between peaks in pixels (default: 300).
- x_min: Minimum X coordinate to consider (default: 1000).
Returns:
NumPy array of up to n_peaks peak X-coordinates, sorted left to right.
Example:
```python
>>> x_proj = calculate_x_projection(mask, global_stats)
>>> peaks = find_shoot_peaks(x_proj, n_peaks=5, min_distance=300)
>>> print(f"Found peaks at X positions: {peaks}")
```
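With distance=, scipy.signal.find_peaks removes the smaller of any two peaks that are closer than distance samples, so taller peaks are kept first. The NumPy-only sketch below approximates that behaviour, followed by the top-n selection in find_shoot_peaks. top_n_spaced_peaks is a hypothetical name.

This is an approximation, not SciPy's algorithm. It only considers strict local maxima, so it skips flat-topped peaks (plateaus) that find_peaks would report.

```python
import numpy as np


def top_n_spaced_peaks(projection, n_peaks=5, min_distance=300, x_min=1000):
    # Approximation of find_peaks(distance=...) plus top-n selection.
    p = np.asarray(projection, dtype=float).copy()
    p[:x_min] = 0  # ignore columns left of x_min
    # Candidates: strict local maxima (plateaus are not handled here).
    idx = np.arange(1, len(p) - 1)
    cand = idx[(p[idx] > p[idx - 1]) & (p[idx] > p[idx + 1])]
    # Greedy suppression: visit tallest first, drop candidates that are
    # closer than min_distance to an already kept peak.
    kept = []
    for i in cand[np.argsort(-p[cand], kind="stable")]:
        if all(abs(int(i) - k) >= min_distance for k in kept):
            kept.append(int(i))
    # kept is ordered tallest first: take n_peaks, then sort left to right.
    return sorted(kept[:n_peaks])


proj = np.zeros(40)
proj[5], proj[10], proj[12], proj[20], proj[30] = 3, 5, 9, 4, 7
print(top_n_spaced_peaks(proj, n_peaks=2, min_distance=5, x_min=8))  # [12, 30]
print(top_n_spaced_peaks(proj, n_peaks=5, min_distance=5, x_min=8))  # [12, 20, 30]
```

The peak at X=10 is dropped because it lies within min_distance of the taller peak at X=12. The peak at X=5 is excluded by x_min.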
```python
def validate_peak_quality(x_projection, peaks, min_peak_height=10, min_peak_width=20):
    """Check if detected peaks are likely to be real shoots versus noise."""
    peak_heights = x_projection[peaks]

    # Calculate peak widths at half prominence
    widths, _, _, _ = peak_widths(x_projection, peaks, rel_height=0.5)

    # Check if at least 3 peaks are both tall enough AND wide enough
    valid_peaks = np.sum((peak_heights >= min_peak_height) & (widths >= min_peak_width))

    return valid_peaks >= 3
```
Check if detected peaks are likely to be real shoots versus noise.
Validates peak quality by checking both height (signal strength) and width (spatial extent): narrow spikes indicate noise, while wide peaks indicate actual shoots. Validation passes when at least 3 of the peaks meet both criteria.
Arguments:
- x_projection: 1D array of pixel counts.
- peaks: Array of peak X-coordinates from find_shoot_peaks.
- min_peak_height: Minimum height for a valid peak (default: 10).
- min_peak_width: Minimum width at half prominence for a valid peak (default: 20).
Returns:
Boolean indicating if peaks are high quality (True) or likely noise (False).
Example:
```python
>>> peaks = find_shoot_peaks(x_proj, n_peaks=5)
>>> if validate_peak_quality(x_proj, peaks):
...     print("High quality peaks detected")
... else:
...     print("Peaks may be noise, consider widening Y-zone")
```
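With rel_height=0.5, scipy.signal.peak_widths measures each peak's width at half of its prominence. For an isolated peak that falls back to zero on both sides, prominence equals height. The width is then simply the run of samples at or above half the peak value.

The sketch below uses that simplification (integer sample counts, no interpolation) to illustrate the validation rule. half_height_width, peaks_look_real and min_valid are hypothetical names. The original function fixes the required count at 3.

```python
import numpy as np


def half_height_width(projection, peak):
    # Contiguous samples around `peak` with value >= half the peak height.
    # Matches half-prominence width only for peaks on a zero baseline.
    p = np.asarray(projection, dtype=float)
    half = p[peak] / 2.0
    left = right = peak
    while left > 0 and p[left - 1] >= half:
        left -= 1
    while right < len(p) - 1 and p[right + 1] >= half:
        right += 1
    return right - left + 1


def peaks_look_real(projection, peaks, min_peak_height=10, min_peak_width=20, min_valid=3):
    # Count peaks that are both tall enough and wide enough.
    p = np.asarray(projection, dtype=float)
    valid = sum(1 for pk in peaks
                if p[pk] >= min_peak_height and half_height_width(p, pk) >= min_peak_width)
    return valid >= min_valid


proj = np.zeros(200)
proj[10:40] = 50    # three wide, shoot-like peaks (30 samples each)
proj[70:100] = 50
proj[130:160] = 50
proj[180] = 80      # tall but one sample wide: noise-like spike

print(peaks_look_real(proj, [25, 85, 145]))  # True
print(peaks_look_real(proj, [25, 180]))      # False
```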
```python
def find_shoot_peaks_with_size_fallback(mask, global_stats, n_peaks=5, min_distance=300,
                                        x_min=1000, initial_std=2.0, max_std=3.5,
                                        std_step=0.25, min_area=500, min_peak_width=20,
                                        quality_check_threshold=2.5):
    """Adaptively find shoot peaks with fallback to size-based detection."""
    std_multiplier = initial_std

    # Try standard X-projection approach with progressive widening
    while std_multiplier <= max_std:
        x_projection = calculate_x_projection(mask, global_stats, std_multiplier)
        peaks = find_shoot_peaks(x_projection, n_peaks, min_distance, x_min)

        if len(peaks) >= n_peaks:
            # Only check quality once the Y-zone has been widened (high std)
            if std_multiplier >= quality_check_threshold:
                if validate_peak_quality(x_projection, peaks, min_peak_height=10,
                                         min_peak_width=min_peak_width):
                    print(f"Found {len(peaks)} quality peaks with std_multiplier={std_multiplier:.2f}")
                    return peaks, x_projection, std_multiplier, "projection"
                else:
                    print(f"Found {len(peaks)} peaks but quality too low (std={std_multiplier:.2f})")
            else:
                # Low std values - just trust the peaks
                print(f"Found {len(peaks)} peaks with std_multiplier={std_multiplier:.2f}")
                return peaks, x_projection, std_multiplier, "projection"

        std_multiplier += std_step

    # Fallback: Use connected components filtered by size and X position
    print("Projection method insufficient, trying size-based fallback...")

    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)

    valid_components = []
    for i in range(1, num_labels):
        area = stats[i, cv2.CC_STAT_AREA]
        x_left = stats[i, cv2.CC_STAT_LEFT]
        width = stats[i, cv2.CC_STAT_WIDTH]
        bbox_center = x_left + width // 2

        if area >= min_area and bbox_center >= x_min:
            valid_components.append({
                'label': i,
                'bbox_center': bbox_center,
                'area': area
            })

    # Sort by area (largest first) and take top n_peaks
    valid_components.sort(key=lambda x: x['area'], reverse=True)
    selected = valid_components[:n_peaks]

    # Extract X positions and sort left to right
    peaks = np.array([c['bbox_center'] for c in selected])
    peaks = np.sort(peaks)

    # Create full projection for visualization
    x_projection = np.sum(mask > 0, axis=0)

    print(f"Found {len(peaks)} peaks using size-based fallback (min_area={min_area})")
    return peaks, x_projection, max_std, "size_fallback"
```
Adaptively find shoot peaks with fallback to size-based detection.
Three-stage detection strategy:
- Normal cases (std 2.0-2.25): X-projection peaks without quality checks
- Widened search (std 2.5-3.5): Progressive Y-zone widening with quality validation
- Size fallback: Largest components by area when projection methods fail
This handles both typical shoots and edge cases like failed germination or shoots that have fallen outside the typical vertical zone.
Arguments:
- mask: Binary mask with values 0 and 255.
- global_stats: Global statistics from calculate_global_y_stats.
- n_peaks: Target number of peaks (default: 5).
- min_distance: Minimum distance between peaks in pixels (default: 300).
- x_min: Minimum X coordinate to consider (default: 1000).
- initial_std: Starting std multiplier (default: 2.0).
- max_std: Maximum std multiplier to try (default: 3.5).
- std_step: Step size for widening (default: 0.25).
- min_area: Minimum area for size fallback in pixels (default: 500).
- min_peak_width: Minimum peak width for quality validation (default: 20).
- quality_check_threshold: Std threshold to activate quality checks (default: 2.5).
Returns:
Tuple of (peaks, x_projection, std_multiplier_used, method_used) where:
- peaks: Array of peak X-coordinates
- x_projection: The X-projection array used (for the size fallback, the projection over the full image height)
- std_multiplier_used: The std multiplier that succeeded (max_std when the size fallback was used)
- method_used: Either "projection" or "size_fallback"
Example:
```python
>>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
...     mask, global_stats, quality_check_threshold=2.5
... )
>>> print(f"Method: {method}, Std: {std:.2f}")
>>> print(f"Peaks at: {peaks}")
```
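The widening schedule follows directly from the loop bounds. The sketch below lists which multipliers are tried with the default arguments, and whether quality validation applies to each. If none of them yields n_peaks peaks that are accepted, the size-based fallback runs. std_schedule is a hypothetical helper.

```python
def std_schedule(initial_std=2.0, max_std=3.5, std_step=0.25, quality_check_threshold=2.5):
    # Mirror the while-loop: start at initial_std and step up while <= max_std.
    schedule = []
    m = initial_std
    while m <= max_std:
        # Quality validation only applies once m reaches the threshold.
        schedule.append((m, m >= quality_check_threshold))
        m += std_step
    return schedule


for m, checked in std_schedule():
    print(f"std_multiplier={m:.2f}  quality_check={checked}")
```

With std_step=0.25 the running sum is exact in binary floating point, so 3.5 is reached exactly. A step such as 0.1 is not exactly representable, and the accumulated value may land just above max_std and skip the final multiplier.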
```python
def merge_with_narrow_bands(mask, peaks, global_stats, std_multiplier, band_width=100):
    """Find components near peaks and keep entire components (not clipped to bands)."""
    # Find all connected components in the full mask
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Create empty output
    output_mask = np.zeros_like(mask)

    # For each peak, find components whose center falls in the band
    kept_labels = set()

    for i, peak_x in enumerate(peaks):
        x_left = peak_x - band_width
        x_right = peak_x + band_width

        # Check each component
        for label in range(1, num_labels):
            if label in kept_labels:
                continue  # Already assigned to another peak

            # Get component's X-center
            comp_x_left = stats[label, cv2.CC_STAT_LEFT]
            comp_width = stats[label, cv2.CC_STAT_WIDTH]
            comp_x_center = comp_x_left + comp_width // 2

            # If center is in this peak's band, keep the ENTIRE component
            if x_left <= comp_x_center <= x_right:
                output_mask[labels == label] = 255
                kept_labels.add(label)
                print(f"  Peak {i+1} at X={peak_x}: keeping component with center at X={comp_x_center}")

    return output_mask
```
Find components near peaks and keep entire components (not clipped to bands).
Uses narrow X-bands around detected peaks to identify which components belong to each shoot location. Keeps complete components rather than clipping them to band boundaries, preserving shoot morphology.
Arguments:
- mask: Binary mask with values 0 and 255.
- peaks: Array of peak X-coordinates from find_shoot_peaks_with_size_fallback.
- global_stats: Global statistics from calculate_global_y_stats (not used by the current implementation).
- std_multiplier: Std multiplier used for peak detection (not used by the current implementation).
- band_width: Half-width of X-band for identifying components (default: 100).
Returns:
Binary mask with complete components near peaks (values 0 and 255).
Example:
```python
>>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
...     mask, global_stats
... )
>>> filtered = merge_with_narrow_bands(mask, peaks, global_stats, std,
...                                    band_width=100)
```
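Band assignment depends only on each component's bounding-box centre, so it can be illustrated without OpenCV. The sketch below takes (label, left, width) tuples. This is a hypothetical input format standing in for rows of the stats array returned by cv2.connectedComponentsWithStats. assign_components_to_bands is also a hypothetical name.

The sketch reproduces the assignment rule. A component is kept whole when its centre falls inside a peak's band, and a label claimed by one peak is never reassigned to another.

```python
def assign_components_to_bands(components, peaks, band_width=100):
    # components: list of (label, left, width) bounding-box tuples.
    assigned = {}  # label -> peak X that claimed it
    for peak_x in peaks:
        for label, left, width in components:
            if label in assigned:
                continue  # already claimed by an earlier peak
            center = left + width // 2  # integer bbox centre, as in the original
            if peak_x - band_width <= center <= peak_x + band_width:
                assigned[label] = peak_x
    return assigned


components = [(1, 950, 100),   # centre 1000
              (2, 1100, 360),  # centre 1280: outside both bands
              (3, 1380, 60),   # centre 1410
              (4, 1200, 500)]  # centre 1450, spans X=1200-1699
result = assign_components_to_bands(components, peaks=[1050, 1400], band_width=100)
print(result)  # {1: 1050, 3: 1400, 4: 1400}
```

Component 4 extends well beyond the band around X=1400 but is kept, because only its centre is tested. Several components can land in the same band; filter_one_per_peak then reduces those to one.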
```python
def filter_one_per_peak(mask, peaks, global_stats, std_multiplier, band_width=100,
                        require_y_zone=True):
    """Keep the largest component near each peak, optionally validating Y-zone."""
    # Calculate Y-zone bounds; clamp the upper bound at 0 so it cannot wrap around
    global_upper = max(0, int(global_stats['global_mean'] - std_multiplier * global_stats['global_std']))
    global_lower = int(global_stats['global_mean'] + std_multiplier * global_stats['global_std'])

    # Find all components
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    output_mask = np.zeros_like(mask)

    for peak_x in peaks:
        x_left = peak_x - band_width
        x_right = peak_x + band_width

        # Find all components in this peak's band
        candidates = []
        for i in range(1, num_labels):
            x_center = stats[i, cv2.CC_STAT_LEFT] + stats[i, cv2.CC_STAT_WIDTH] // 2

            if x_left <= x_center <= x_right:
                area = stats[i, cv2.CC_STAT_AREA]

                if require_y_zone:
                    # Check if component has any pixel inside the Y-zone
                    y_zone_pixels = np.sum(labels[global_upper:global_lower, :] == i)

                    if y_zone_pixels > 0:
                        candidates.append((i, area))
                else:
                    # No Y-zone requirement - keep all
                    candidates.append((i, area))

        # Keep the largest one from this peak
        if candidates:
            best_label = max(candidates, key=lambda x: x[1])[0]
            output_mask[labels == best_label] = 255
            print(f"  Peak at X={peak_x}: keeping component {best_label} "
                  f"(area={stats[best_label, cv2.CC_STAT_AREA]})")
        else:
            print(f"  Peak at X={peak_x}: WARNING - no valid components found!")

    return output_mask
```
Keep the largest component near each peak, optionally validating Y-zone.
For each detected peak, identifies all components within the X-band and keeps only the largest one. Optionally filters out components that don't touch the Y-zone (disabled for size-fallback method to handle fallen shoots).
Arguments:
- mask: Binary mask with values 0 and 255.
- peaks: Array of peak X-coordinates.
- global_stats: Global statistics from calculate_global_y_stats.
- std_multiplier: Std multiplier used for peak detection.
- band_width: Half-width for assigning components to peaks (default: 100).
- require_y_zone: If True, only keep components touching Y-zone (default: True).
Returns:
Binary mask with at most one component per peak (values 0 and 255). A peak with no valid candidate gets no component, and a warning is printed.
Example:
```python
>>> # For projection method (require Y-zone)
>>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
...                                  band_width=100, require_y_zone=True)
>>>
>>> # For size fallback (allow fallen shoots outside Y-zone)
>>> final_mask = filter_one_per_peak(mask, peaks, global_stats, std_used,
...                                  band_width=100, require_y_zone=False)
```
Note:
Set require_y_zone=False when using size-based fallback method to preserve large shoots that have fallen outside the typical vertical zone.
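The selection rule can also be sketched with NumPy alone, operating on a precomputed integer label image instead of calling cv2.connectedComponentsWithStats. largest_per_peak is a hypothetical name. It takes the Y-zone rows [upper, lower) directly and tests membership by comparing row coordinates, which avoids slicing with a negative bound.

```python
import numpy as np


def largest_per_peak(labels, peaks, upper, lower, band_width=100, require_y_zone=True):
    # labels: integer label image (0 = background).
    chosen = {}
    for peak_x in peaks:
        candidates = []
        for lab in np.unique(labels):
            if lab == 0:
                continue
            ys, xs = np.nonzero(labels == lab)
            left, width = xs.min(), xs.max() - xs.min() + 1
            center = left + width // 2  # bounding-box centre, as in the original
            if not (peak_x - band_width <= center <= peak_x + band_width):
                continue
            if require_y_zone and not np.any((ys >= upper) & (ys < lower)):
                continue  # no pixel inside rows [upper, lower)
            candidates.append((int(lab), int(ys.size)))  # (label, area)
        if candidates:
            chosen[peak_x] = max(candidates, key=lambda c: c[1])[0]
    return chosen


labels = np.zeros((20, 30), dtype=int)
labels[5:8, 2:5] = 1    # area 9, inside the Y-zone, centre X=3
labels[15:20, 0:6] = 2  # area 30, below the Y-zone ("fallen"), centre X=3
labels[6, 20:22] = 3    # area 2, inside the Y-zone, centre X=21

print(largest_per_peak(labels, [3, 21], upper=4, lower=10, band_width=2))
# {3: 1, 21: 3}
print(largest_per_peak(labels, [3, 21], upper=4, lower=10, band_width=2, require_y_zone=False))
# {3: 2, 21: 3}
```

Component 2 lies below the Y-zone, so it is selected only when require_y_zone=False. This mirrors how the size-fallback path keeps large fallen shoots.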
```python
def clean_shoot_mask_pipeline(mask_path, global_stats, closing_kernel_size=7,
                              closing_iterations=5, quality_check_threshold=2.5,
                              band_width=100):
    """Complete pipeline to clean a shoot mask from raw predictions to final output."""
    print(f"Processing: {Path(mask_path).name}")

    # Step 1: Load and initial joining
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    joined_mask = join_shoot_fragments(mask, closing_kernel_size, closing_iterations)

    # Step 2: Find shoot locations
    peaks, x_projection, std_used, method = find_shoot_peaks_with_size_fallback(
        joined_mask, global_stats,
        quality_check_threshold=quality_check_threshold
    )
    print(f"  Found {len(peaks)} peaks using {method} (std={std_used:.2f})")

    # Step 3: Filter to narrow bands around peaks
    narrow_result = merge_with_narrow_bands(joined_mask, peaks, global_stats,
                                            std_used, band_width=band_width)

    # Step 4: Keep one component per peak
    require_y_zone = (method != "size_fallback")
    final_mask = filter_one_per_peak(narrow_result, peaks, global_stats, std_used,
                                     band_width=band_width, require_y_zone=require_y_zone)

    # Count final components
    num_labels, _, _, _ = cv2.connectedComponentsWithStats(final_mask, connectivity=8)
    num_components = num_labels - 1

    return {
        'cleaned_mask': final_mask,
        'peaks': peaks,
        'std_used': std_used,
        'method': method,
        'num_components': num_components,
        'filename': Path(mask_path).name
    }
```
Complete pipeline to clean a shoot mask from raw predictions to final output.
Runs the full processing workflow:
- Load mask and apply initial morphological closing
- Detect 5 shoot locations using adaptive peak finding
- Filter to components near detected peaks
- Keep at most one component per peak (the largest in each band)
Arguments:
- mask_path: Path to shoot mask file (string or Path object).
- global_stats: Global statistics from calculate_global_y_stats.
- closing_kernel_size: Initial closing kernel size (default: 7).
- closing_iterations: Initial closing iterations (default: 5).
- quality_check_threshold: Std threshold for quality checks (default: 2.5).
- band_width: X-band width around peaks in pixels (default: 100).
Returns:
Dictionary with keys:
- 'cleaned_mask': Final cleaned binary mask (0 and 255)
- 'peaks': Detected peak X-coordinates
- 'std_used': Std multiplier that succeeded
- 'method': Detection method ("projection" or "size_fallback")
- 'num_components': Number of components in final mask
- 'filename': Input filename
Example:
```python
>>> from pathlib import Path
>>>
>>> # Process all masks
>>> mask_dir = Path('data/shoot_masks')
>>> for mask_file in mask_dir.glob('*.png'):
...     result = clean_shoot_mask_pipeline(str(mask_file), global_stats)
...
...     # Save cleaned mask
...     output_path = f"cleaned/{result['filename']}"
...     cv2.imwrite(output_path, result['cleaned_mask'])
...
...     # Log results
...     if result['num_components'] != 5:
...         print(f"WARNING: {result['filename']} has "
...               f"{result['num_components']} components")
```
Note:
Parameters are optimized for typical plant imaging conditions:
- Shoots at X > 1000 pixels
- Shoots between Y = 200-750 pixels (adaptive per image)
- ~400-500 pixel spacing between plants
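When processing a whole dataset, it can help to tally which detection method each image needed and to list images that did not end with the expected five components. The sketch below works on dictionaries with the keys documented above. summarize_results and the sample records are hypothetical placeholders, not real results.

```python
from collections import Counter


def summarize_results(results, expected=5):
    # results: list of dicts shaped like clean_shoot_mask_pipeline's return value.
    methods = Counter(r['method'] for r in results)
    flagged = [r['filename'] for r in results if r['num_components'] != expected]
    return methods, flagged


# Placeholder records for illustration only
results = [
    {'filename': 'a.png', 'method': 'projection', 'num_components': 5},
    {'filename': 'b.png', 'method': 'size_fallback', 'num_components': 4},
    {'filename': 'c.png', 'method': 'projection', 'num_components': 5},
]
methods, flagged = summarize_results(results)
print(dict(methods))  # {'projection': 2, 'size_fallback': 1}
print(flagged)        # ['b.png']
```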