library.shoot_mask_visualization
Visualization and debugging tools for the shoot mask cleaning pipeline.
This module provides visualization functions for analyzing shoot mask cleaning results, debugging pipeline failures, and validating parameter choices. It is designed to be used together with the companion module shoot_mask_cleaning.py.
Typical usage:

Visualize Y-density statistics:

    >>> from pathlib import Path
    >>> import cv2
    >>> import matplotlib.pyplot as plt
    >>> from shoot_mask_cleaning import (
    ...     calculate_y_density,
    ...     calculate_weighted_y_stats,
    ...     calculate_global_y_stats,
    ...     join_shoot_fragments
    ... )
    >>> from shoot_mask_visualization import visualize_y_density_with_global_stats
    >>>
    >>> # Calculate global stats
    >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
    >>> global_stats = calculate_global_y_stats(mask_files)
    >>>
    >>> # Visualize individual image
    >>> mask = cv2.imread('data/shoot/image_01.png', cv2.IMREAD_GRAYSCALE)
    >>> joined = join_shoot_fragments(mask)
    >>> density = calculate_y_density(joined)
    >>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
    >>>
    >>> visualize_y_density_with_global_stats(joined, density, local_stats,
    ...                                       global_stats, std_multiplier=2.0)
    >>> plt.show()

Debug the complete pipeline:

    >>> from shoot_mask_visualization import debug_peak_detection
    >>>
    >>> # Visualize all processing steps for troubleshooting
    >>> debug_peak_detection('data/shoot/problematic_image.png', global_stats)
    >>> plt.show()

Visualize final results:

    >>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
    >>> from shoot_mask_visualization import visualize_final_components
    >>>
    >>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
    >>> visualize_final_components(result['cleaned_mask'], global_stats,
    ...                            result['std_used'],
    ...                            title=f"Final: {result['filename']}")
    >>> plt.show()
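The first example depends on `calculate_y_density` and `calculate_weighted_y_stats` from the companion module. Their implementations are not shown on this page. As a rough mental model, here is a minimal NumPy sketch. It assumes that "density" means the number of foreground pixels in each row, and that the weighted statistics are the density-weighted mean and standard deviation of row indices inside `[y_min, y_max)`. `y_density_sketch` and `weighted_y_stats_sketch` are hypothetical stand-ins, not the module's API.

```python
import numpy as np


def y_density_sketch(mask):
    """Count foreground (non-zero) pixels in each row of a binary mask.

    Hypothetical stand-in for calculate_y_density.
    """
    return np.count_nonzero(mask, axis=1)


def weighted_y_stats_sketch(density, y_min=0, y_max=None):
    """Density-weighted mean and std of row indices within [y_min, y_max).

    Assumption: this mirrors the idea behind calculate_weighted_y_stats,
    not its exact implementation.
    """
    density = np.asarray(density, dtype=float)
    y_min = max(0, y_min)
    # Clamp the ROI to the array so rows and weights always line up
    y_max = len(density) if y_max is None else min(y_max, len(density))
    rows = np.arange(y_min, y_max)
    weights = density[y_min:y_max]
    if weights.sum() == 0:
        # No foreground inside the ROI: statistics are undefined
        return {'mean': float('nan'), 'std': float('nan')}
    mean = np.average(rows, weights=weights)
    std = np.sqrt(np.average((rows - mean) ** 2, weights=weights))
    return {'mean': float(mean), 'std': float(std)}


# Synthetic 10x6 mask with foreground on rows 4 and 5 only
mask = np.zeros((10, 6), dtype=np.uint8)
mask[4:6, :] = 255
density = y_density_sketch(mask)
stats = weighted_y_stats_sketch(density, y_min=0, y_max=10)
print(density.tolist())  # [0, 0, 0, 0, 6, 6, 0, 0, 0, 0]
print(stats)             # {'mean': 4.5, 'std': 0.5}
```

The `mean` and `std` keys match what `visualize_y_density_with_global_stats` reads from `local_stats`.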
Dependencies:
- matplotlib
- numpy
- opencv-python (cv2)
- scipy
- shoot_mask_cleaning (companion module)
Author: Aaron Ciuffo
Date: December 2024
    def visualize_y_density_with_global_stats(mask, density, local_stats, global_stats,
                                              std_multiplier=2.0, figsize=(10, 8)):
        """Visualize mask with both local and global Y-density statistics.

        Creates a two-panel figure showing:
            1. Y-density histogram with local and global zones highlighted
            2. Mask with corresponding zone overlays and mean position markers

        Args:
            mask: Binary mask array (0 and 255).
            density: 1D array of pixel counts per row from calculate_y_density.
            local_stats: Local statistics dict with 'mean' and 'std' keys.
            global_stats: Global statistics from calculate_global_y_stats.
            std_multiplier: Standard deviation multiplier for zone width (default: 2.0).
            figsize: Figure size as (width, height) tuple (default: (10, 8)).

        Example:
            >>> from shoot_mask_cleaning import (
            ...     calculate_y_density,
            ...     calculate_weighted_y_stats,
            ...     join_shoot_fragments
            ... )
            >>>
            >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
            >>> joined = join_shoot_fragments(mask)
            >>> density = calculate_y_density(joined)
            >>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
            >>>
            >>> visualize_y_density_with_global_stats(joined, density, local_stats,
            ...                                       global_stats, std_multiplier=2.0)
            >>> plt.show()

        Note:
            Red = local image statistics (zone, mean line, and left-edge marker)
            Yellow = global dataset zone
            Orange = global mean (line and left-edge marker)
            Gray = ROI bounds (y_min/y_max), drawn only when present in global_stats
        """
        fig, (ax_hist, ax_mask) = plt.subplots(1, 2, figsize=figsize,
                                               gridspec_kw={'width_ratios': [1, 3]})

        # Plot density histogram (one horizontal bar per image row)
        ax_hist.barh(range(len(density)), density, height=1, color='blue', alpha=0.6)
        ax_hist.set_ylim(len(density), 0)  # Row 0 at the top, matching the image
        ax_hist.set_xlabel('Pixel count')
        ax_hist.set_ylabel('Y coordinate')
        ax_hist.invert_xaxis()  # Bars grow leftward, away from the mask panel

        # Calculate bounds; Y grows downward, so mean - k*std is the upper edge
        local_upper = local_stats['mean'] - std_multiplier * local_stats['std']
        local_lower = local_stats['mean'] + std_multiplier * local_stats['std']
        global_upper = global_stats['global_mean'] - std_multiplier * global_stats['global_std']
        global_lower = global_stats['global_mean'] + std_multiplier * global_stats['global_std']

        # Draw ROI bounds (gray)
        if 'y_min' in global_stats and 'y_max' in global_stats:
            ax_hist.axhline(global_stats['y_min'], color='gray', linestyle='-',
                            linewidth=1, alpha=0.5)
            ax_hist.axhline(global_stats['y_max'], color='gray', linestyle='-',
                            linewidth=1, alpha=0.5)

        # Draw shaded zones on histogram
        ax_hist.axhspan(local_upper, local_lower, alpha=0.2, color='red', label='Local zone')
        ax_hist.axhspan(global_upper, global_lower, alpha=0.2, color='yellow', label='Global zone')

        # Draw mean lines
        ax_hist.axhline(local_stats['mean'], color='red', linestyle='-',
                        linewidth=2, label='Local mean')
        ax_hist.axhline(global_stats['global_mean'], color='orange', linestyle='-',
                        linewidth=2, label='Global mean')

        ax_hist.legend(loc='upper right', fontsize=8)
        ax_hist.grid(True, alpha=0.3)

        # Show mask with overlays
        ax_mask.imshow(mask, cmap='gray', aspect='equal')

        # Draw ROI bounds
        if 'y_min' in global_stats and 'y_max' in global_stats:
            ax_mask.axhline(global_stats['y_min'], color='gray', linestyle='-',
                            linewidth=1, alpha=0.5)
            ax_mask.axhline(global_stats['y_max'], color='gray', linestyle='-',
                            linewidth=1, alpha=0.5)

        # Draw shaded zones on mask
        ax_mask.axhspan(local_upper, local_lower, alpha=0.15, color='red')
        ax_mask.axhspan(global_upper, global_lower, alpha=0.2, color='yellow')

        # Add left edge markers for means (points only)
        ax_mask.plot(0, local_stats['mean'], 'o', color='red', markersize=10)
        ax_mask.plot(0, global_stats['global_mean'], 'o', color='orange', markersize=10)

        ax_mask.axis('off')

        plt.tight_layout()
Visualize mask with both local and global Y-density statistics.
Creates a two-panel figure showing:
- Y-density histogram with local and global zones highlighted
- Mask with corresponding zone overlays and mean position markers
Arguments:
- mask: Binary mask array (0 and 255).
- density: 1D array of pixel counts per row from calculate_y_density.
- local_stats: Local statistics dict with 'mean' and 'std' keys.
- global_stats: Global statistics from calculate_global_y_stats.
- std_multiplier: Standard deviation multiplier for zone width (default: 2.0).
- figsize: Figure size as (width, height) tuple (default: (10, 8)).
Example:

    >>> from shoot_mask_cleaning import (
    ...     calculate_y_density,
    ...     calculate_weighted_y_stats,
    ...     join_shoot_fragments
    ... )
    >>>
    >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
    >>> joined = join_shoot_fragments(mask)
    >>> density = calculate_y_density(joined)
    >>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
    >>>
    >>> visualize_y_density_with_global_stats(joined, density, local_stats,
    ...                                       global_stats, std_multiplier=2.0)
    >>> plt.show()

Note:
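- Red = local image statistics (zone, mean line, and left-edge marker)
- Yellow = global dataset zone
- Orange = global mean (line and left-edge marker)
- Gray = ROI bounds (y_min/y_max), drawn only when present in global_stats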
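Both shaded zones are plain `mean ± std_multiplier * std` intervals, computed the same way for the local and global statistics. Row indices grow downward in image coordinates, so `mean - k*std` is the upper edge on screen and `mean + k*std` is the lower edge. That is why the code names the bounds `*_upper` and `*_lower`.

The sketch below reproduces that arithmetic with made-up statistics. It also adds a simple overlap test, which is the question the plot lets you answer by eye. `zone_bounds` and `zones_overlap` are hypothetical helpers, not part of either module.

```python
def zone_bounds(mean, std, std_multiplier=2.0):
    """Return (upper, lower) row bounds of a mean +/- k*std zone.

    In image coordinates Y grows downward, so the smaller value
    ("upper") is drawn nearer the top of the image.
    """
    return mean - std_multiplier * std, mean + std_multiplier * std


def zones_overlap(a, b):
    """True if two (upper, lower) row intervals share any rows."""
    return a[0] <= b[1] and b[0] <= a[1]


# Illustrative numbers only; not taken from any real dataset
local_stats = {'mean': 480.0, 'std': 40.0}
global_stats = {'global_mean': 450.0, 'global_std': 60.0}

local_zone = zone_bounds(local_stats['mean'], local_stats['std'])
global_zone = zone_bounds(global_stats['global_mean'], global_stats['global_std'])
print(local_zone)                              # (400.0, 560.0)
print(global_zone)                             # (330.0, 570.0)
print(zones_overlap(local_zone, global_zone))  # True
```

Note that `visualize_x_projection_with_peaks_adaptive` and `visualize_final_components` additionally truncate the global bounds with `int()` before drawing.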
    def visualize_x_projection_with_peaks_adaptive(mask, x_projection, peaks, global_stats,
                                                   std_multiplier, figsize=(14, 8)):
        """Visualize mask and X-projection with detected peak locations.

        Creates a two-panel figure showing:
            1. Mask with Y-zone overlay and vertical lines at peak positions
            2. X-projection histogram with peaks marked

        Args:
            mask: Binary mask array (0 and 255).
            x_projection: 1D array of pixel counts from calculate_x_projection.
            peaks: Array of integer peak X-coordinates (used to index x_projection).
            global_stats: Global statistics from calculate_global_y_stats.
            std_multiplier: The std multiplier actually used for this image.
            figsize: Figure size as (width, height) tuple (default: (14, 8)).

        Example:
            >>> from shoot_mask_cleaning import (
            ...     join_shoot_fragments,
            ...     find_shoot_peaks_with_size_fallback
            ... )
            >>>
            >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
            >>> joined = join_shoot_fragments(mask)
            >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
            ...     joined, global_stats
            ... )
            >>>
            >>> visualize_x_projection_with_peaks_adaptive(joined, x_proj, peaks,
            ...                                            global_stats, std)
            >>> plt.show()
        """
        fig, (ax_mask, ax_proj) = plt.subplots(2, 1, figsize=figsize,
                                               gridspec_kw={'height_ratios': [3, 1]})

        # Calculate Y bounds of the global zone (Y grows downward)
        global_upper = int(global_stats['global_mean'] - std_multiplier * global_stats['global_std'])
        global_lower = int(global_stats['global_mean'] + std_multiplier * global_stats['global_std'])

        # Show mask with peaks
        ax_mask.imshow(mask, cmap='gray', aspect='equal')
        ax_mask.axhspan(global_upper, global_lower, alpha=0.2, color='yellow')

        for i, peak_x in enumerate(peaks):
            ax_mask.axvline(peak_x, color='red', linestyle='--', linewidth=2)
            # Number each peak near the top of the image (row 100)
            ax_mask.text(peak_x, 100, f'{i+1}', color='red', fontsize=12,
                         ha='center', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        ax_mask.set_title(f'Mask with {len(peaks)} Detected Peaks (std={std_multiplier:.2f})')
        ax_mask.axis('off')

        # Show X-projection with peaks
        ax_proj.fill_between(range(len(x_projection)), x_projection, alpha=0.7, color='blue')
        ax_proj.set_xlim(0, len(x_projection))
        ax_proj.set_xlabel('X coordinate (pixels)')
        ax_proj.set_ylabel('Pixel count')
        ax_proj.set_title('X-axis Projection')
        ax_proj.grid(True, alpha=0.3)

        # Mark peaks
        for peak_x in peaks:
            ax_proj.axvline(peak_x, color='red', linestyle='--', linewidth=2)
            ax_proj.plot(peak_x, x_projection[peak_x], 'ro', markersize=10)

        plt.tight_layout()
Visualize mask and X-projection with detected peak locations.
Creates a two-panel figure showing:
- Mask with Y-zone overlay and vertical lines at peak positions
- X-projection histogram with peaks marked
Arguments:
- mask: Binary mask array (0 and 255).
- x_projection: 1D array of pixel counts from calculate_x_projection.
- peaks: Array of integer peak X-coordinates (used to index x_projection).
- global_stats: Global statistics from calculate_global_y_stats.
- std_multiplier: The std multiplier actually used for this image.
- figsize: Figure size as (width, height) tuple (default: (14, 8)).
Example:

    >>> from shoot_mask_cleaning import (
    ...     join_shoot_fragments,
    ...     find_shoot_peaks_with_size_fallback
    ... )
    >>>
    >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
    >>> joined = join_shoot_fragments(mask)
    >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
    ...     joined, global_stats
    ... )
    >>>
    >>> visualize_x_projection_with_peaks_adaptive(joined, x_proj, peaks,
    ...                                            global_stats, std)
    >>> plt.show()

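The lower panel plots an X-projection, meaning the number of foreground pixels in each column, with the detected peaks marked on it. `calculate_x_projection` and `find_shoot_peaks_with_size_fallback` are not shown on this page, so the sketch below is an assumption, not the companion module's method. It restricts the mask to the rows of a Y-zone, counts pixels per column, and picks strict local maxima above a minimum height. `x_projection_sketch` and `simple_peaks` are hypothetical names. For real data, `scipy.signal.find_peaks` is a more robust off-the-shelf peak finder.

```python
import numpy as np


def x_projection_sketch(mask, y_upper, y_lower):
    """Count foreground pixels per column, using only rows in [y_upper, y_lower)."""
    # Clamp the zone to the image; mean +/- k*std can extend past its edges
    y_upper = max(0, int(y_upper))
    y_lower = min(mask.shape[0], int(y_lower))
    return np.count_nonzero(mask[y_upper:y_lower, :], axis=0)


def simple_peaks(projection, min_height=1):
    """Indices of strict local maxima that are at least min_height tall.

    Deliberately simple: plateaus (runs of equal maxima) are not detected.
    """
    p = np.asarray(projection)
    interior = (p[1:-1] > p[:-2]) & (p[1:-1] > p[2:]) & (p[1:-1] >= min_height)
    return np.flatnonzero(interior) + 1  # shift back to full-array indices


# Synthetic mask: two tapering "shoots" centred on columns 3 and 9
mask = np.zeros((20, 13), dtype=np.uint8)
for centre in (3, 9):
    mask[5:15, centre] = 255       # tallest column
    mask[8:15, centre - 1] = 255   # shorter neighbours
    mask[8:15, centre + 1] = 255

proj = x_projection_sketch(mask, y_upper=4, y_lower=16)
peaks = simple_peaks(proj, min_height=5)
print(proj.tolist())   # [0, 0, 7, 10, 7, 0, 0, 0, 7, 10, 7, 0, 0]
print(peaks.tolist())  # [3, 9]
```

The returned peak positions are integers. This matters because the visualization indexes `x_projection[peak_x]` directly.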
def visualize_final_components(mask, global_stats, std_used, title="Final Mask",
                               figsize=(16, 6)):
    """Visualize final mask with component count, labels, and statistics.

    Creates a two-panel figure showing:
    1. Binary mask with component count
    2. Colored label map with component numbers positioned below each shoot

    Args:
        mask: Binary mask array (0 and 255).
        global_stats: Global statistics from calculate_global_y_stats.
        std_used: Std multiplier used for this image.
        title: Title prefix for the plot (default: "Final Mask").
        figsize: Figure size as (width, height) tuple (default: (16, 6)).

    Returns:
        Number of components found in the mask.

    Example:
        >>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
        >>>
        >>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
        >>> num_components = visualize_final_components(
        ...     result['cleaned_mask'],
        ...     global_stats,
        ...     result['std_used'],
        ...     title=f"Final: {result['filename']}"
        ... )
        >>>
        >>> if num_components != 5:
        ...     print(f"WARNING: Expected 5, found {num_components}")
        >>> plt.show()

    Note:
        Component labels are positioned below each shoot to avoid obscuring
        the actual shoot structures. Labeled view is zoomed to the Y-zone area
        for easier inspection.
    """
    # Find connected components
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)
    num_components = num_labels - 1  # Exclude background

    # Create a colored label image for visualization
    label_colors = np.zeros((*mask.shape, 3), dtype=np.uint8)

    # Assign reproducible random colors to each component. A local generator
    # avoids resetting NumPy's global random state the way np.random.seed() does.
    rng = np.random.default_rng(42)
    for i in range(1, num_labels):
        color = rng.integers(50, 255, size=3)
        label_colors[labels == i] = color

    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Original mask
    ax1.imshow(mask, cmap='gray', aspect='equal')
    ax1.set_title(f'{title}\nComponents: {num_components}')
    ax1.axis('off')

    # Colored labels with centroids - zoomed to area of interest
    ax2.imshow(label_colors, aspect='equal')

    # Calculate Y-zone for zoom
    global_upper = int(global_stats['global_mean'] - std_used * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_used * global_stats['global_std'])

    # Add padding, clamped to the image height
    y_padding = 100
    zoom_y_min = max(0, global_upper - y_padding)
    zoom_y_max = min(mask.shape[0], global_lower + y_padding)

    # Set axis limits to zoom
    ax2.set_ylim(zoom_y_max, zoom_y_min)  # Inverted for image coordinates
    ax2.set_xlim(0, mask.shape[1])

    # Label each component - place labels below objects
    for i in range(1, num_labels):
        cx, _ = centroids[i]
        # Bottom edge of the component's bounding box
        y_bottom = stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT]

        # Place label below the object
        ax2.text(cx, y_bottom + 40, f'{i}', color='white', fontsize=14,
                 ha='center', weight='bold',
                 bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))

    ax2.set_title(f'Labeled Components: {num_components} (Zoomed)')
    ax2.axis('off')

    plt.tight_layout()

    # Print component stats
    print("\nComponent Statistics:")
    print(f"{'Label':<8} {'Area':<10} {'X-Center':<10} {'Y-Center':<10}")
    print("-" * 40)
    for i in range(1, num_labels):
        area = stats[i, cv2.CC_STAT_AREA]
        cx, cy = centroids[i]
        print(f"{i:<8} {area:<10} {cx:<10.1f} {cy:<10.1f}")

    return num_components
Visualize final mask with component count, labels, and statistics.
Creates a two-panel figure showing:
- Binary mask with component count
- Colored label map with component numbers positioned below each shoot
Arguments:
- mask: Binary mask array (0 and 255).
- global_stats: Global statistics from calculate_global_y_stats.
- std_used: Std multiplier used for this image.
- title: Title prefix for the plot (default: "Final Mask").
- figsize: Figure size as (width, height) tuple (default: (16, 6)).
Returns:
Number of components found in the mask.
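The returned count is the number of 8-connected foreground regions, excluding the background (`num_labels - 1` in the source). The following dependency-free sketch shows what `cv2.connectedComponentsWithStats(mask, connectivity=8)` reports per component. It uses a breadth-first flood fill. `label_components` is hypothetical and not part of either module, and OpenCV's label numbering is not guaranteed to match this scan order.

```python
from collections import deque


def label_components(mask):
    """Hypothetical 8-connected labelling of non-zero pixels.

    Accepts a 2-D list or array and returns one dict per component
    (background excluded) with area, bounding-box top/height and centroid,
    loosely mirroring cv2.connectedComponentsWithStats output.
    """
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    components = []
    for y0 in range(h):
        for x0 in range(w):
            if not mask[y0][x0] or seen[y0][x0]:
                continue
            # Flood fill from this seed over all 8 neighbours
            queue = deque([(y0, x0)])
            seen[y0][x0] = True
            pixels = []
            while queue:
                y, x = queue.popleft()
                pixels.append((y, x))
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = y + dy, x + dx
                        if (0 <= ny < h and 0 <= nx < w
                                and mask[ny][nx] and not seen[ny][nx]):
                            seen[ny][nx] = True
                            queue.append((ny, nx))
            ys = [p[0] for p in pixels]
            xs = [p[1] for p in pixels]
            top = min(ys)
            components.append({
                "area": len(pixels),
                "top": top,
                "height": max(ys) - top + 1,
                "centroid": (sum(xs) / len(xs), sum(ys) / len(ys)),  # (x, y)
            })
    return components
```

Two diagonally touching pixels form one component under 8-connectivity. Under 4-connectivity they would count as two, which is why the choice of `connectivity=8` affects the reported count.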
Example:
>>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
>>>
>>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
>>> num_components = visualize_final_components(
...     result['cleaned_mask'],
...     global_stats,
...     result['std_used'],
...     title=f"Final: {result['filename']}"
... )
>>>
>>> if num_components != 5:
...     print(f"WARNING: Expected 5, found {num_components}")
>>> plt.show()
Note:
Component labels are positioned below each shoot to avoid obscuring the actual shoot structures. Labeled view is zoomed to the Y-zone area for easier inspection.
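The zoom and label placement in the source reduce to a few lines of arithmetic. The Y-zone is `global_mean ± std_used * global_std`, truncated to int. It is padded by 100 pixels and clamped to the image. Each label sits 40 pixels below the bottom edge of its component's bounding box (top + height). The sketch below pulls that arithmetic into plain functions so it can be checked without plotting. The function names are hypothetical; the constants are the ones used in the source.

```python
def y_zone(global_stats, std_used):
    """Y-zone bounds (upper, lower): mean -/+ std_used * std, truncated to int."""
    mean = global_stats["global_mean"]
    std = global_stats["global_std"]
    return int(mean - std_used * std), int(mean + std_used * std)


def zoom_window(global_stats, std_used, image_height, y_padding=100):
    """Padded Y-zone clamped to [0, image_height] as (zoom_y_min, zoom_y_max)."""
    upper, lower = y_zone(global_stats, std_used)
    return max(0, upper - y_padding), min(image_height, lower + y_padding)


def label_anchor(top, height, centroid_x, offset=40):
    """Text position below a component: (centroid X, bottom edge + offset)."""
    return centroid_x, top + height + offset
```

Because image rows grow downward, the plot passes these bounds to `set_ylim` in reverse order (`zoom_y_max` first), which keeps the zoomed view upright.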
def debug_peak_detection(mask_path, global_stats, kernel_size=7, iterations=5,
                         quality_check_threshold=2.5):
    """Debug visualization showing all pipeline steps with peak locations.

    Creates a comprehensive four-panel figure for troubleshooting:
    1. Original mask with detected peak locations and Y-zone
    2. Mask after morphological closing with peaks
    3. X-projection histogram with marked peaks
    4. Final cleaned result with component labels

    Args:
        mask_path: Path to shoot mask file (string or Path object).
        global_stats: Global statistics from calculate_global_y_stats.
        kernel_size: Kernel size for initial closing (default: 7).
        iterations: Number of closing iterations (default: 5).
        quality_check_threshold: Std threshold for quality checks (default: 2.5).

    Example:
        >>> # Debug a problematic image
        >>> debug_peak_detection('data/shoot/image_11.png', global_stats)
        >>> plt.savefig('debug_output.png', dpi=150, bbox_inches='tight')
        >>> plt.show()
        >>>
        >>> # Batch debug multiple images
        >>> for mask_file in problematic_files:
        ...     debug_peak_detection(mask_file, global_stats)
        ...     plt.show()

    Note:
        This function prints detailed information about peak detection results,
        component assignments, and filtering decisions to help diagnose issues.
    """
    from library.shoot_mask_cleaning import (
        join_shoot_fragments,
        find_shoot_peaks_with_size_fallback,
        merge_with_narrow_bands,
        filter_one_per_peak
    )

    print(f"Debugging: {Path(mask_path).name}")

    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread returns None rather than raising for missing/unreadable files
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    joined_mask = join_shoot_fragments(mask, kernel_size=kernel_size, iterations=iterations)

    # Find peaks
    peaks, x_projection, std_used, method = find_shoot_peaks_with_size_fallback(
        joined_mask, global_stats, quality_check_threshold=quality_check_threshold
    )

    print(f"  Peaks detected at X positions: {peaks}")
    print(f"  Method: {method}, Std: {std_used:.2f}")

    # Process
    narrow_result = merge_with_narrow_bands(joined_mask, peaks, global_stats,
                                            std_used, band_width=100)

    # The Y-zone requirement is only enforced when peaks came from the projection
    require_y_zone = (method != "size_fallback")
    final_result = filter_one_per_peak(narrow_result, peaks, global_stats, std_used,
                                       band_width=100, require_y_zone=require_y_zone)

    # Calculate Y-zone
    global_upper = int(global_stats['global_mean'] - std_used * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_used * global_stats['global_std'])

    # Visualize
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))

    # Original with peaks
    axes[0, 0].imshow(mask, cmap='gray', aspect='equal')
    axes[0, 0].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for i, peak_x in enumerate(peaks):
        axes[0, 0].axvline(peak_x, color='red', linestyle='--', linewidth=2, alpha=0.7)
        axes[0, 0].text(peak_x, 50, f'P{i+1}', color='red', fontsize=12,
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    axes[0, 0].set_title('Original with Peak Locations')
    axes[0, 0].axis('off')

    # After joining with peaks
    axes[0, 1].imshow(joined_mask, cmap='gray', aspect='equal')
    axes[0, 1].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for peak_x in peaks:
        axes[0, 1].axvline(peak_x, color='red', linestyle='--', linewidth=2, alpha=0.7)
    axes[0, 1].set_title('After Closing with Peaks')
    axes[0, 1].axis('off')

    # X-projection
    axes[1, 0].fill_between(range(len(x_projection)), x_projection, alpha=0.7, color='blue')
    axes[1, 0].set_xlim(0, len(x_projection))
    for i, peak_x in enumerate(peaks):
        axes[1, 0].axvline(peak_x, color='red', linestyle='--', linewidth=2)
        axes[1, 0].plot(peak_x, x_projection[peak_x], 'ro', markersize=10)
        axes[1, 0].text(peak_x, x_projection[peak_x] + 5, f'P{i+1}',
                        ha='center', color='red', weight='bold')
    axes[1, 0].set_xlabel('X coordinate')
    axes[1, 0].set_ylabel('Pixel count')
    axes[1, 0].set_title('X-Projection with Detected Peaks')
    axes[1, 0].grid(True, alpha=0.3)

    # Final with component labels
    num_labels, labels_map, comp_stats, comp_centroids = cv2.connectedComponentsWithStats(
        final_result, connectivity=8
    )
    axes[1, 1].imshow(final_result, cmap='gray', aspect='equal')
    axes[1, 1].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for i in range(1, num_labels):
        cx, _ = comp_centroids[i]
        y_bottom = comp_stats[i, cv2.CC_STAT_TOP] + comp_stats[i, cv2.CC_STAT_HEIGHT]
        axes[1, 1].text(cx, y_bottom + 40, f'{i}', color='red', fontsize=14,
                        ha='center', weight='bold',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
    axes[1, 1].set_title(f'Final: {num_labels-1} components')
    axes[1, 1].axis('off')

    plt.tight_layout()
Debug visualization showing all pipeline steps with peak locations.
Creates a comprehensive four-panel figure for troubleshooting:
- Original mask with detected peak locations and Y-zone
- Mask after morphological closing with peaks
- X-projection histogram with marked peaks
- Final cleaned result with component labels
Arguments:
- mask_path: Path to shoot mask file (string or Path object).
- global_stats: Global statistics from calculate_global_y_stats.
- kernel_size: Kernel size for initial closing (default: 7).
- iterations: Number of closing iterations (default: 5).
- quality_check_threshold: Std threshold for quality checks (default: 2.5).
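`kernel_size` and `iterations` are forwarded to join_shoot_fragments for the closing step shown in panel 2. Suppose that step is a standard binary closing with a square kernel, which is what `cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=n)` does: n dilations followed by n erosions. The NumPy-only sketch below shows how the two parameters interact. `binary_close` is hypothetical, and the companion module may use a different kernel shape or border handling.

```python
import numpy as np


def _window_reduce(img, k, reducer, pad_value):
    """Apply reducer (np.maximum or np.minimum) over a k x k square window.

    k is assumed odd so the window is centred on each pixel.
    """
    r = k // 2
    padded = np.pad(img, r, mode="constant", constant_values=pad_value)
    h, w = img.shape
    out = padded[0:h, 0:w].copy()
    for dy in range(k):
        for dx in range(k):
            out = reducer(out, padded[dy:dy + h, dx:dx + w])
    return out


def binary_close(mask, kernel_size=7, iterations=5):
    """Hypothetical closing of a uint8 0/255 mask: dilate n times, erode n times."""
    img = np.asarray(mask, dtype=np.uint8)
    for _ in range(iterations):
        # Dilation: a pixel turns on if any pixel in its window is on
        img = _window_reduce(img, kernel_size, np.maximum, pad_value=0)
    for _ in range(iterations):
        # Erosion, padded with 255 so the image border does not eat into shapes
        img = _window_reduce(img, kernel_size, np.minimum, pad_value=255)
    return img
```

Larger kernels or more iterations bridge wider gaps between shoot fragments. They can also merge neighbouring shoots, which is one thing the debug figure helps spot.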
Example:
>>> # Debug a problematic image
>>> debug_peak_detection('data/shoot/image_11.png', global_stats)
>>> plt.savefig('debug_output.png', dpi=150, bbox_inches='tight')
>>> plt.show()
>>>
>>> # Batch debug multiple images
>>> for mask_file in problematic_files:
...     debug_peak_detection(mask_file, global_stats)
...     plt.show()
Note:
This function prints detailed information about peak detection results, component assignments, and filtering decisions to help diagnose issues.
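When a debug figure shows the wrong number of final components, it can help to check which components sit near each detected peak. The sketch below is a hypothetical diagnostic, not part of either module. It matches components to peaks by centroid X. The `half_width=50` default is an assumption that the `band_width=100` passed to merge_with_narrow_bands and filter_one_per_peak is a full width centred on the peak; the companion module may define the band differently.

```python
def match_peaks_to_components(peaks, centroids_x, half_width=50):
    """Hypothetical: indices of components whose centroid X is within
    +/- half_width of each peak."""
    return {
        peak: [i for i, cx in enumerate(centroids_x) if abs(cx - peak) <= half_width]
        for peak in peaks
    }


def find_mismatches(peaks, centroids_x, half_width=50):
    """Peaks with zero or several nearby components, as (peak, count) pairs."""
    matches = match_peaks_to_components(peaks, centroids_x, half_width)
    return [(peak, len(comps)) for peak, comps in matches.items() if len(comps) != 1]
```

With OpenCV, `centroids_x` could be taken from `comp_centroids[1:, 0]`, skipping the background row, after the connected-components call in the debug function.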
def batch_process_and_visualize(mask_paths, global_stats, output_dir=None,
                                save_masks=True, show_plots=True):
    """Process all masks and generate summary visualizations.

    Runs the complete cleaning pipeline on multiple masks and creates:
    1. Individual processed masks (saved to output_dir when output_dir is
       given and save_masks is True)
    2. Summary statistics printed to console
    3. Optional visualization of each result

    Args:
        mask_paths: List of paths to mask files.
        global_stats: Global statistics from calculate_global_y_stats.
        output_dir: Directory to save cleaned masks (default: None, in which
            case no masks are saved).
        save_masks: Whether to save cleaned masks (default: True).
        show_plots: Whether to display plots for each image (default: True).

    Returns:
        List of result dictionaries from clean_shoot_mask_pipeline.

    Example:
        >>> from pathlib import Path
        >>>
        >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
        >>> results = batch_process_and_visualize(
        ...     mask_files,
        ...     global_stats,
        ...     output_dir='data/shoot_cleaned',
        ...     save_masks=True,
        ...     show_plots=False
        ... )
        >>>
        >>> # Summary
        >>> success_count = sum(1 for r in results if r['num_components'] == 5)
        >>> print(f"Success rate: {success_count}/{len(results)}")
    """
    # Same package path as the import in debug_peak_detection
    from library.shoot_mask_cleaning import clean_shoot_mask_pipeline

    if output_dir is not None and save_masks:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    results = []

    for mask_path in mask_paths:
        # Process mask
        result = clean_shoot_mask_pipeline(mask_path, global_stats)
        results.append(result)

        # Save if requested
        if output_dir is not None and save_masks:
            save_path = output_path / result['filename']
            cv2.imwrite(str(save_path), result['cleaned_mask'])

        # Visualize if requested
        if show_plots:
            visualize_final_components(result['cleaned_mask'], global_stats,
                                       result['std_used'],
                                       title=f"Final: {result['filename']}")
            plt.show()

    # Print summary
    print("\n" + "="*70)
    print("BATCH PROCESSING SUMMARY")
    print("="*70)
    print(f"{'Filename':<30} {'Method':<15} {'Std':>6} {'Components':>12}")
    print("-"*70)

    for result in results:
        print(f"{result['filename']:<30} {result['method']:<15} "
              f"{result['std_used']:>6.2f} {result['num_components']:>12}")

    # Summary statistics
    total = len(results)
    correct_count = sum(1 for r in results if r['num_components'] == 5)
    projection_count = sum(1 for r in results if r['method'] == 'projection')
    fallback_count = total - projection_count

    def pct(n):
        # Avoid ZeroDivisionError when mask_paths is empty
        return 100 * n / total if total else 0.0

    print("-"*70)
    print(f"Total processed: {total}")
    print(f"Correct component count (5): {correct_count} ({pct(correct_count):.1f}%)")
    print(f"Projection method: {projection_count} ({pct(projection_count):.1f}%)")
    print(f"Size fallback method: {fallback_count} ({pct(fallback_count):.1f}%)")
    print("="*70 + "\n")

    return results
Process all masks and generate summary visualizations.
Runs the complete cleaning pipeline on multiple masks and creates:
- Individual processed masks (saved to output_dir when output_dir is given and save_masks is True)
- Summary statistics printed to console
- Optional visualization of each result
Arguments:
- mask_paths: List of paths to mask files.
- global_stats: Global statistics from calculate_global_y_stats.
- output_dir: Directory to save cleaned masks (default: None, in which case no masks are saved).
- save_masks: Whether to save cleaned masks (default: True).
- show_plots: Whether to display plots for each image (default: True).
Returns:
List of result dictionaries from clean_shoot_mask_pipeline.
Example:
>>> from pathlib import Path
>>>
>>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
>>> results = batch_process_and_visualize(
...     mask_files,
...     global_stats,
...     output_dir='data/shoot_cleaned',
...     save_masks=True,
...     show_plots=False
... )
>>>
>>> # Summary
>>> success_count = sum(1 for r in results if r['num_components'] == 5)
>>> print(f"Success rate: {success_count}/{len(results)}")
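The batch summary reports three things: how many results have the expected five components, how many used the projection method, and how many used the size fallback. Below is a small sketch of that bookkeeping as a pure function, which can be tested without plotting and guards against an empty batch. `summarize_results` is hypothetical. It relies only on the `num_components` and `method` keys the summary already reads. Like the printout, it counts every result whose method is not `'projection'` as size fallback.

```python
def summarize_results(results, expected_components=5):
    """Hypothetical: counts and percentages matching the batch summary."""
    total = len(results)

    def pct(n):
        # An empty batch reports 0% instead of dividing by zero
        return 100.0 * n / total if total else 0.0

    correct = sum(1 for r in results if r["num_components"] == expected_components)
    projection = sum(1 for r in results if r["method"] == "projection")
    fallback = total - projection  # every non-projection result, as in the printout
    return {
        "total": total,
        "correct": correct, "correct_pct": pct(correct),
        "projection": projection, "projection_pct": pct(projection),
        "fallback": fallback, "fallback_pct": pct(fallback),
    }
```

Given the list returned by `batch_process_and_visualize`, `summarize_results(results)["correct_pct"]` gives the same success rate as the doctest above, as a percentage.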