library.shoot_mask_visualization

Visualization and debugging tools for shoot mask cleaning pipeline.

This module provides visualization functions for analyzing shoot mask cleaning results, debugging pipeline failures, and validating parameter choices. Works in conjunction with shoot_mask_cleaning.py.

Typical usage:

Visualize Y-density statistics:

>>> from pathlib import Path
>>> import cv2
>>> import matplotlib.pyplot as plt
>>> from shoot_mask_cleaning import (
...     calculate_y_density, 
...     calculate_weighted_y_stats,
...     calculate_global_y_stats,
...     join_shoot_fragments
... )
>>> from shoot_mask_visualization import visualize_y_density_with_global_stats
>>> 
>>> # Calculate global stats
>>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
>>> global_stats = calculate_global_y_stats(mask_files)
>>> 
>>> # Visualize individual image
>>> mask = cv2.imread('data/shoot/image_01.png', cv2.IMREAD_GRAYSCALE)
>>> joined = join_shoot_fragments(mask)
>>> density = calculate_y_density(joined)
>>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
>>> 
>>> visualize_y_density_with_global_stats(joined, density, local_stats, 
...                                       global_stats, std_multiplier=2.0)
>>> plt.show()

Debug complete pipeline:

>>> from shoot_mask_visualization import debug_peak_detection
>>> 
>>> # Visualize all processing steps for troubleshooting
>>> debug_peak_detection('data/shoot/problematic_image.png', global_stats)
>>> plt.show()

Visualize final results:

>>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
>>> from shoot_mask_visualization import visualize_final_components
>>> 
>>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
>>> visualize_final_components(result['cleaned_mask'], global_stats,
...                            result['std_used'],
...                            title=f"Final: {result['filename']}")
>>> plt.show()
Dependencies:
  • matplotlib
  • numpy
  • opencv-python (cv2)
  • scipy
  • shoot_mask_cleaning (companion module)

Author: Aaron Ciuffo
Date: December 2024

  1"""Visualization and debugging tools for shoot mask cleaning pipeline.
  2
  3This module provides visualization functions for analyzing shoot mask cleaning
  4results, debugging pipeline failures, and validating parameter choices. Works in
  5conjunction with shoot_mask_cleaning.py.
  6
  7Typical usage:
  8
  9    Visualize Y-density statistics:
 10    
 11        >>> import matplotlib.pyplot as plt
 12        >>> from shoot_mask_cleaning import (
 13        ...     calculate_y_density, 
 14        ...     calculate_weighted_y_stats,
 15        ...     calculate_global_y_stats,
 16        ...     join_shoot_fragments
 17        ... )
 18        >>> from shoot_mask_visualization import visualize_y_density_with_global_stats
 19        >>> 
 20        >>> # Calculate global stats
 21        >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
 22        >>> global_stats = calculate_global_y_stats(mask_files)
 23        >>> 
 24        >>> # Visualize individual image
 25        >>> mask = cv2.imread('data/shoot/image_01.png', cv2.IMREAD_GRAYSCALE)
 26        >>> joined = join_shoot_fragments(mask)
 27        >>> density = calculate_y_density(joined)
 28        >>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
 29        >>> 
 30        >>> visualize_y_density_with_global_stats(joined, density, local_stats, 
 31        ...                                       global_stats, std_multiplier=2.0)
 32        >>> plt.show()
 33    
 34    Debug complete pipeline:
 35    
 36        >>> from shoot_mask_visualization import debug_peak_detection
 37        >>> 
 38        >>> # Visualize all processing steps for troubleshooting
 39        >>> debug_peak_detection('data/shoot/problematic_image.png', global_stats)
 40        >>> plt.show()
 41    
 42    Visualize final results:
 43    
 44        >>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
 45        >>> from shoot_mask_visualization import visualize_final_components
 46        >>> 
 47        >>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
 48        >>> visualize_final_components(result['cleaned_mask'], global_stats,
 49        ...     result['std_used'], title=f"Final: {result['filename']}")
 50        >>> plt.show()
 51
 52Dependencies:
 53    - matplotlib
 54    - numpy
 55    - opencv-python (cv2)
 56    - scipy
 57    - shoot_mask_cleaning (companion module)
 58
 59Author: Aaron Ciuffo
 60Date: December 2024
 61"""
 62
 63import numpy as np
 64import cv2
 65import matplotlib.pyplot as plt
 66from pathlib import Path
 67
 68
 69def visualize_y_density_with_global_stats(mask, density, local_stats, global_stats,
 70                                          std_multiplier=2.0, figsize=(10, 8)):
 71    """Visualize mask with both local and global Y-density statistics.
 72    
 73    Creates a two-panel figure showing:
 74    1. Y-density histogram with local and global zones highlighted
 75    2. Mask with corresponding zone overlays and mean position markers
 76    
 77    Args:
 78        mask: Binary mask array (0 and 255).
 79        density: 1D array of pixel counts per row from calculate_y_density.
 80        local_stats: Local statistics dict with 'mean' and 'std' keys.
 81        global_stats: Global statistics from calculate_global_y_stats.
 82        std_multiplier: Standard deviation multiplier for zone width (default: 2.0).
 83        figsize: Figure size as (width, height) tuple (default: (10, 8)).
 84        
 85    Example:
 86        >>> from shoot_mask_cleaning import (
 87        ...     calculate_y_density,
 88        ...     calculate_weighted_y_stats,
 89        ...     join_shoot_fragments
 90        ... )
 91        >>> 
 92        >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
 93        >>> joined = join_shoot_fragments(mask)
 94        >>> density = calculate_y_density(joined)
 95        >>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
 96        >>> 
 97        >>> visualize_y_density_with_global_stats(joined, density, local_stats,
 98        ...                                       global_stats, std_multiplier=2.0)
 99        >>> plt.show()
100    
101    Note:
102        Red = local image statistics
103        Yellow = global dataset statistics
104        Orange markers = mean positions
105    """
106    fig, (ax_hist, ax_mask) = plt.subplots(1, 2, figsize=figsize,
107                                            gridspec_kw={'width_ratios': [1, 3]})
108    
109    # Plot density histogram
110    ax_hist.barh(range(len(density)), density, height=1, color='blue', alpha=0.6)
111    ax_hist.set_ylim(len(density), 0)
112    ax_hist.set_xlabel('Pixel count')
113    ax_hist.set_ylabel('Y coordinate')
114    ax_hist.invert_xaxis()
115    
116    # Calculate bounds
117    local_upper = local_stats['mean'] - std_multiplier * local_stats['std']
118    local_lower = local_stats['mean'] + std_multiplier * local_stats['std']
119    global_upper = global_stats['global_mean'] - std_multiplier * global_stats['global_std']
120    global_lower = global_stats['global_mean'] + std_multiplier * global_stats['global_std']
121    
122    # Draw ROI bounds (gray)
123    if 'y_min' in global_stats and 'y_max' in global_stats:
124        ax_hist.axhline(global_stats['y_min'], color='gray', linestyle='-',
125                       linewidth=1, alpha=0.5)
126        ax_hist.axhline(global_stats['y_max'], color='gray', linestyle='-',
127                       linewidth=1, alpha=0.5)
128    
129    # Draw shaded zones on histogram
130    ax_hist.axhspan(local_upper, local_lower, alpha=0.2, color='red', label='Local zone')
131    ax_hist.axhspan(global_upper, global_lower, alpha=0.2, color='yellow', label='Global zone')
132    
133    # Draw mean lines
134    ax_hist.axhline(local_stats['mean'], color='red', linestyle='-',
135                   linewidth=2, label='Local mean')
136    ax_hist.axhline(global_stats['global_mean'], color='orange', linestyle='-',
137                   linewidth=2, label='Global mean')
138    
139    ax_hist.legend(loc='upper right', fontsize=8)
140    ax_hist.grid(True, alpha=0.3)
141    
142    # Show mask with overlays
143    ax_mask.imshow(mask, cmap='gray', aspect='equal')
144    
145    # Draw ROI bounds
146    if 'y_min' in global_stats and 'y_max' in global_stats:
147        ax_mask.axhline(global_stats['y_min'], color='gray', linestyle='-',
148                       linewidth=1, alpha=0.5)
149        ax_mask.axhline(global_stats['y_max'], color='gray', linestyle='-',
150                       linewidth=1, alpha=0.5)
151    
152    # Draw shaded zones on mask
153    ax_mask.axhspan(local_upper, local_lower, alpha=0.15, color='red')
154    ax_mask.axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
155    
156    # Add left edge markers for means (points only)
157    ax_mask.plot(0, local_stats['mean'], 'o', color='red', markersize=10)
158    ax_mask.plot(0, global_stats['global_mean'], 'o', color='orange', markersize=10)
159    
160    ax_mask.axis('off')
161    
162    plt.tight_layout()
163
164
def visualize_x_projection_with_peaks_adaptive(mask, x_projection, peaks, global_stats,
                                                std_multiplier, figsize=(14, 8)):
    """Show a mask above its X-projection, marking every detected peak.

    Builds a two-row matplotlib figure:

    * top (tall) panel: the mask with the global Y-zone shaded in yellow, a
      dashed red vertical line at each peak and a 1-based peak number drawn
      at y=100;
    * bottom panel: the X-projection as a filled curve with the same dashed
      lines plus a red dot at each peak's projection value.

    Args:
        mask: Image array shown with a gray colormap (binary 0/255 mask).
        x_projection: 1D array of pixel counts per column; it is indexed with
            each entry of ``peaks``.
        peaks: Sized sequence of peak X-coordinates, usable as indices into
            ``x_projection``.
        global_stats: Mapping with 'global_mean' and 'global_std' keys.
        std_multiplier: Std multiplier used for this image; sets the Y-zone
            half-width and is shown in the title.
        figsize: Figure size as a (width, height) tuple (default: (14, 8)).

    Returns:
        None. Call ``plt.show()`` or ``plt.savefig()`` afterwards.

    Example:
        >>> joined = join_shoot_fragments(mask)
        >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
        ...     joined, global_stats
        ... )
        >>> visualize_x_projection_with_peaks_adaptive(joined, x_proj, peaks,
        ...                                            global_stats, std)
        >>> plt.show()
    """
    _fig, (image_ax, profile_ax) = plt.subplots(
        2, 1, figsize=figsize, gridspec_kw={'height_ratios': [3, 1]}
    )

    # Y-zone edges, truncated to whole pixel rows.
    center = global_stats['global_mean']
    half_width = std_multiplier * global_stats['global_std']
    zone_top = int(center - half_width)
    zone_bottom = int(center + half_width)

    # Top panel: mask, Y-zone and numbered peak lines.
    image_ax.imshow(mask, cmap='gray', aspect='equal')
    image_ax.axhspan(zone_top, zone_bottom, alpha=0.2, color='yellow')

    for rank, x_pos in enumerate(peaks, start=1):
        image_ax.axvline(x_pos, color='red', linestyle='--', linewidth=2)
        image_ax.text(
            x_pos, 100, f'{rank}', color='red', fontsize=12, ha='center',
            bbox={'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8},
        )

    image_ax.set_title(f'Mask with {len(peaks)} Detected Peaks (std={std_multiplier:.2f})')
    image_ax.axis('off')

    # Bottom panel: projection curve.
    column_count = len(x_projection)
    profile_ax.fill_between(range(column_count), x_projection, alpha=0.7, color='blue')
    profile_ax.set_xlim(0, column_count)
    profile_ax.set_xlabel('X coordinate (pixels)')
    profile_ax.set_ylabel('Pixel count')
    profile_ax.set_title('X-axis Projection')
    profile_ax.grid(True, alpha=0.3)

    # Peak markers on the projection.
    for x_pos in peaks:
        profile_ax.axvline(x_pos, color='red', linestyle='--', linewidth=2)
        profile_ax.plot(x_pos, x_projection[x_pos], 'ro', markersize=10)

    plt.tight_layout()
230
231
def visualize_final_components(mask, global_stats, std_used, title="Final Mask",
                               figsize=(16, 6)):
    """Visualize the final mask with component count, labels, and statistics.

    Builds a two-panel matplotlib figure:

    1. The mask in grayscale, titled with ``title`` and the component count.
    2. A colour-coded label map, vertically zoomed to the global Y-zone plus
       100 px of padding, with each component's number drawn 40 px below the
       bottom of its bounding box.

    A table of per-component area and centroid is also printed to stdout.

    Args:
        mask: Binary mask array (0 and 255), as accepted by
            ``cv2.connectedComponentsWithStats``.
        global_stats: Mapping with 'global_mean' and 'global_std' keys.
        std_used: Std multiplier used for this image; sets the zoom window.
        title: Title prefix for the left panel (default: "Final Mask").
        figsize: Figure size as a (width, height) tuple (default: (16, 6)).

    Returns:
        Number of connected components in the mask, background excluded.

    Example:
        >>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
        >>> num_components = visualize_final_components(
        ...     result['cleaned_mask'],
        ...     global_stats,
        ...     result['std_used'],
        ...     title=f"Final: {result['filename']}"
        ... )
        >>> if num_components != 5:
        ...     print(f"WARNING: Expected 5, found {num_components}")
        >>> plt.show()

    Note:
        Component colours are reproducible between calls. They come from a
        private generator seeded with 42, so the process-wide NumPy random
        state is left untouched.
    """
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask, connectivity=8
    )
    num_components = num_labels - 1  # Label 0 is the background.

    # Colour image for the label map; background stays black.
    label_colors = np.zeros((*mask.shape, 3), dtype=np.uint8)

    # A local RandomState gives the same colour sequence as seeding the global
    # RNG with 42 would, without clobbering the caller's random state.
    rng = np.random.RandomState(42)
    for label in range(1, num_labels):
        label_colors[labels == label] = rng.randint(50, 255, size=3)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: the mask as-is.
    ax1.imshow(mask, cmap='gray', aspect='equal')
    ax1.set_title(f'{title}\nComponents: {num_components}')
    ax1.axis('off')

    # Right panel: coloured labels, zoomed to the area of interest.
    ax2.imshow(label_colors, aspect='equal')

    # Y-zone used for the zoom window, truncated to whole pixel rows.
    global_upper = int(global_stats['global_mean'] - std_used * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_used * global_stats['global_std'])

    # Pad the zone and clamp it to the image height.
    y_padding = 100
    zoom_y_min = max(0, global_upper - y_padding)
    zoom_y_max = min(mask.shape[0], global_lower + y_padding)

    ax2.set_ylim(zoom_y_max, zoom_y_min)  # Inverted for image coordinates.
    ax2.set_xlim(0, mask.shape[1])

    # Number each component just below its bounding box so the label does not
    # cover the shoot itself.
    for label in range(1, num_labels):
        cx = centroids[label][0]
        y_bottom = stats[label, cv2.CC_STAT_TOP] + stats[label, cv2.CC_STAT_HEIGHT]
        ax2.text(cx, y_bottom + 40, f'{label}', color='white', fontsize=14,
                 ha='center', weight='bold',
                 bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))

    ax2.set_title(f'Labeled Components: {num_components} (Zoomed)')
    ax2.axis('off')

    plt.tight_layout()

    # Per-component table on stdout.
    print("\nComponent Statistics:")
    print(f"{'Label':<8} {'Area':<10} {'X-Center':<10} {'Y-Center':<10}")
    print("-" * 40)
    for label in range(1, num_labels):
        area = stats[label, cv2.CC_STAT_AREA]
        cx, cy = centroids[label]
        print(f"{label:<8} {area:<10} {cx:<10.1f} {cy:<10.1f}")

    return num_components
333
334
def debug_peak_detection(mask_path, global_stats, kernel_size=7, iterations=5,
                        quality_check_threshold=2.5):
    """Debug visualization showing all pipeline steps with peak locations.

    Loads the mask, runs fragment joining, peak detection, narrow-band merging
    and one-per-peak filtering, then draws a 2x2 figure:

    1. Original mask with peak lines (labelled P1, P2, ...) and the Y-zone.
    2. Mask after joining, with the same peak lines and Y-zone.
    3. X-projection curve with each peak marked and labelled.
    4. Final filtered result with a component number below each component.

    The file name, detected peak positions, detection method and std
    multiplier are printed to stdout.

    Args:
        mask_path: Path to the shoot mask file (string or Path object).
        global_stats: Global statistics from calculate_global_y_stats; must
            hold 'global_mean' and 'global_std'.
        kernel_size: Kernel size passed to join_shoot_fragments (default: 7).
        iterations: Iteration count passed to join_shoot_fragments
            (default: 5).
        quality_check_threshold: Passed through to
            find_shoot_peaks_with_size_fallback (default: 2.5).

    Raises:
        OSError: If ``cv2.imread`` cannot read an image from ``mask_path``
            (missing file, unreadable file, or unsupported format).

    Example:
        >>> debug_peak_detection('data/shoot/image_11.png', global_stats)
        >>> plt.savefig('debug_output.png', dpi=150, bbox_inches='tight')
        >>> plt.show()
        >>>
        >>> for mask_file in problematic_files:
        ...     debug_peak_detection(mask_file, global_stats)
        ...     plt.show()
    """
    from library.shoot_mask_cleaning import (
        join_shoot_fragments,
        find_shoot_peaks_with_size_fallback,
        merge_with_narrow_bands,
        filter_one_per_peak
    )

    print(f"Debugging: {Path(mask_path).name}")

    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than deep in the pipeline.
        raise OSError(f"Could not read mask image: {mask_path}")

    joined_mask = join_shoot_fragments(mask, kernel_size=kernel_size, iterations=iterations)

    # Find peaks
    peaks, x_projection, std_used, method = find_shoot_peaks_with_size_fallback(
        joined_mask, global_stats, quality_check_threshold=quality_check_threshold
    )

    print(f"  Peaks detected at X positions: {peaks}")
    print(f"  Method: {method}, Std: {std_used:.2f}")

    # Process
    narrow_result = merge_with_narrow_bands(joined_mask, peaks, global_stats,
                                           std_used, band_width=100)

    # The Y-zone requirement is dropped when peaks came from the size fallback.
    require_y_zone = (method != "size_fallback")
    final_result = filter_one_per_peak(narrow_result, peaks, global_stats, std_used,
                                      band_width=100, require_y_zone=require_y_zone)

    # Y-zone edges, truncated to whole pixel rows.
    global_upper = int(global_stats['global_mean'] - std_used * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_used * global_stats['global_std'])

    fig, axes = plt.subplots(2, 2, figsize=(18, 12))

    # Panel 1: original mask with peaks.
    axes[0, 0].imshow(mask, cmap='gray', aspect='equal')
    axes[0, 0].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for i, peak_x in enumerate(peaks):
        axes[0, 0].axvline(peak_x, color='red', linestyle='--', linewidth=2, alpha=0.7)
        axes[0, 0].text(peak_x, 50, f'P{i+1}', color='red', fontsize=12,
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    axes[0, 0].set_title('Original with Peak Locations')
    axes[0, 0].axis('off')

    # Panel 2: joined mask with peaks.
    axes[0, 1].imshow(joined_mask, cmap='gray', aspect='equal')
    axes[0, 1].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for peak_x in peaks:
        axes[0, 1].axvline(peak_x, color='red', linestyle='--', linewidth=2, alpha=0.7)
    axes[0, 1].set_title('After Closing with Peaks')
    axes[0, 1].axis('off')

    # Panel 3: X-projection with labelled peaks.
    axes[1, 0].fill_between(range(len(x_projection)), x_projection, alpha=0.7, color='blue')
    axes[1, 0].set_xlim(0, len(x_projection))
    for i, peak_x in enumerate(peaks):
        axes[1, 0].axvline(peak_x, color='red', linestyle='--', linewidth=2)
        axes[1, 0].plot(peak_x, x_projection[peak_x], 'ro', markersize=10)
        axes[1, 0].text(peak_x, x_projection[peak_x] + 5, f'P{i+1}',
                       ha='center', color='red', weight='bold')
    axes[1, 0].set_xlabel('X coordinate')
    axes[1, 0].set_ylabel('Pixel count')
    axes[1, 0].set_title('X-Projection with Detected Peaks')
    axes[1, 0].grid(True, alpha=0.3)

    # Panel 4: final result with component numbers below each component.
    num_labels, labels_map, comp_stats, comp_centroids = cv2.connectedComponentsWithStats(
        final_result, connectivity=8
    )
    axes[1, 1].imshow(final_result, cmap='gray', aspect='equal')
    axes[1, 1].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for i in range(1, num_labels):
        cx = comp_centroids[i][0]
        y_bottom = comp_stats[i, cv2.CC_STAT_TOP] + comp_stats[i, cv2.CC_STAT_HEIGHT]
        axes[1, 1].text(cx, y_bottom + 40, f'{i}', color='red', fontsize=14,
                       ha='center', weight='bold',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
    axes[1, 1].set_title(f'Final: {num_labels-1} components')
    axes[1, 1].axis('off')

    plt.tight_layout()
449
450
def batch_process_and_visualize(mask_paths, global_stats, output_dir=None,
                                save_masks=True, show_plots=True):
    """Process all masks and generate summary visualizations.

    Runs clean_shoot_mask_pipeline on every path and then:

    1. Optionally writes each cleaned mask to ``output_dir``.
    2. Optionally shows a component plot for each result.
    3. Prints a per-file table and aggregate statistics to stdout.

    Args:
        mask_paths: Iterable of paths to mask files. It is materialised into
            a list, so a generator is accepted.
        global_stats: Global statistics from calculate_global_y_stats.
        output_dir: Directory for cleaned masks; created if missing
            (default: None = don't save).
        save_masks: Whether to save cleaned masks; only has an effect when
            ``output_dir`` is given (default: True).
        show_plots: Whether to display a plot for each image (default: True).

    Returns:
        List of result dictionaries from clean_shoot_mask_pipeline, in input
        order. An empty input returns an empty list without touching the
        filesystem or printing a summary table.

    Example:
        >>> from pathlib import Path
        >>>
        >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
        >>> results = batch_process_and_visualize(
        ...     mask_files,
        ...     global_stats,
        ...     output_dir='data/shoot_cleaned',
        ...     save_masks=True,
        ...     show_plots=False
        ... )
        >>> success_count = sum(1 for r in results if r['num_components'] == 5)
        >>> print(f"Success rate: {success_count}/{len(results)}")
    """
    # Materialise once so an empty batch can be detected up front; the summary
    # percentages below would otherwise divide by zero.
    mask_paths = list(mask_paths)
    if not mask_paths:
        print("No masks to process.")
        return []

    # The companion module is imported as ``library.shoot_mask_cleaning``
    # elsewhere in this file; fall back to the flat name for setups that put
    # the ``library`` directory itself on sys.path.
    try:
        from library.shoot_mask_cleaning import clean_shoot_mask_pipeline
    except ModuleNotFoundError:
        from shoot_mask_cleaning import clean_shoot_mask_pipeline

    output_path = None
    if output_dir is not None and save_masks:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    results = []

    for mask_path in mask_paths:
        result = clean_shoot_mask_pipeline(mask_path, global_stats)
        results.append(result)

        # Save if requested; cv2.imwrite reports failure via its return value.
        if output_path is not None:
            save_path = output_path / result['filename']
            if not cv2.imwrite(str(save_path), result['cleaned_mask']):
                print(f"WARNING: failed to write {save_path}")

        # Visualize if requested
        if show_plots:
            visualize_final_components(result['cleaned_mask'], global_stats,
                                      result['std_used'],
                                      title=f"Final: {result['filename']}")
            plt.show()

    # Per-file table
    print("\n" + "="*70)
    print("BATCH PROCESSING SUMMARY")
    print("="*70)
    print(f"{'Filename':<30} {'Method':<15} {'Std':>6} {'Components':>12}")
    print("-"*70)

    for result in results:
        print(f"{result['filename']:<30} {result['method']:<15} "
              f"{result['std_used']:>6.2f} {result['num_components']:>12}")

    # Aggregate statistics; total >= 1 here because of the guard above.
    total = len(results)
    correct_count = sum(1 for r in results if r['num_components'] == 5)
    projection_count = sum(1 for r in results if r['method'] == 'projection')
    fallback_count = total - projection_count

    def percent(count):
        return 100 * count / total

    print("-"*70)
    print(f"Total processed: {total}")
    print(f"Correct component count (5): {correct_count} ({percent(correct_count):.1f}%)")
    print(f"Projection method: {projection_count} ({percent(projection_count):.1f}%)")
    print(f"Size fallback method: {fallback_count} ({percent(fallback_count):.1f}%)")
    print("="*70 + "\n")

    return results
def visualize_y_density_with_global_stats( mask, density, local_stats, global_stats, std_multiplier=2.0, figsize=(10, 8)):
 70def visualize_y_density_with_global_stats(mask, density, local_stats, global_stats,
 71                                          std_multiplier=2.0, figsize=(10, 8)):
 72    """Visualize mask with both local and global Y-density statistics.
 73    
 74    Creates a two-panel figure showing:
 75    1. Y-density histogram with local and global zones highlighted
 76    2. Mask with corresponding zone overlays and mean position markers
 77    
 78    Args:
 79        mask: Binary mask array (0 and 255).
 80        density: 1D array of pixel counts per row from calculate_y_density.
 81        local_stats: Local statistics dict with 'mean' and 'std' keys.
 82        global_stats: Global statistics from calculate_global_y_stats.
 83        std_multiplier: Standard deviation multiplier for zone width (default: 2.0).
 84        figsize: Figure size as (width, height) tuple (default: (10, 8)).
 85        
 86    Example:
 87        >>> from shoot_mask_cleaning import (
 88        ...     calculate_y_density,
 89        ...     calculate_weighted_y_stats,
 90        ...     join_shoot_fragments
 91        ... )
 92        >>> 
 93        >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
 94        >>> joined = join_shoot_fragments(mask)
 95        >>> density = calculate_y_density(joined)
 96        >>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
 97        >>> 
 98        >>> visualize_y_density_with_global_stats(joined, density, local_stats,
 99        ...                                       global_stats, std_multiplier=2.0)
100        >>> plt.show()
101    
102    Note:
103        Red = local image statistics
104        Yellow = global dataset statistics
105        Orange markers = mean positions
106    """
107    fig, (ax_hist, ax_mask) = plt.subplots(1, 2, figsize=figsize,
108                                            gridspec_kw={'width_ratios': [1, 3]})
109    
110    # Plot density histogram
111    ax_hist.barh(range(len(density)), density, height=1, color='blue', alpha=0.6)
112    ax_hist.set_ylim(len(density), 0)
113    ax_hist.set_xlabel('Pixel count')
114    ax_hist.set_ylabel('Y coordinate')
115    ax_hist.invert_xaxis()
116    
117    # Calculate bounds
118    local_upper = local_stats['mean'] - std_multiplier * local_stats['std']
119    local_lower = local_stats['mean'] + std_multiplier * local_stats['std']
120    global_upper = global_stats['global_mean'] - std_multiplier * global_stats['global_std']
121    global_lower = global_stats['global_mean'] + std_multiplier * global_stats['global_std']
122    
123    # Draw ROI bounds (gray)
124    if 'y_min' in global_stats and 'y_max' in global_stats:
125        ax_hist.axhline(global_stats['y_min'], color='gray', linestyle='-',
126                       linewidth=1, alpha=0.5)
127        ax_hist.axhline(global_stats['y_max'], color='gray', linestyle='-',
128                       linewidth=1, alpha=0.5)
129    
130    # Draw shaded zones on histogram
131    ax_hist.axhspan(local_upper, local_lower, alpha=0.2, color='red', label='Local zone')
132    ax_hist.axhspan(global_upper, global_lower, alpha=0.2, color='yellow', label='Global zone')
133    
134    # Draw mean lines
135    ax_hist.axhline(local_stats['mean'], color='red', linestyle='-',
136                   linewidth=2, label='Local mean')
137    ax_hist.axhline(global_stats['global_mean'], color='orange', linestyle='-',
138                   linewidth=2, label='Global mean')
139    
140    ax_hist.legend(loc='upper right', fontsize=8)
141    ax_hist.grid(True, alpha=0.3)
142    
143    # Show mask with overlays
144    ax_mask.imshow(mask, cmap='gray', aspect='equal')
145    
146    # Draw ROI bounds
147    if 'y_min' in global_stats and 'y_max' in global_stats:
148        ax_mask.axhline(global_stats['y_min'], color='gray', linestyle='-',
149                       linewidth=1, alpha=0.5)
150        ax_mask.axhline(global_stats['y_max'], color='gray', linestyle='-',
151                       linewidth=1, alpha=0.5)
152    
153    # Draw shaded zones on mask
154    ax_mask.axhspan(local_upper, local_lower, alpha=0.15, color='red')
155    ax_mask.axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
156    
157    # Add left edge markers for means (points only)
158    ax_mask.plot(0, local_stats['mean'], 'o', color='red', markersize=10)
159    ax_mask.plot(0, global_stats['global_mean'], 'o', color='orange', markersize=10)
160    
161    ax_mask.axis('off')
162    
163    plt.tight_layout()

Visualize mask with both local and global Y-density statistics.

Creates a two-panel figure showing:

  1. Y-density histogram with local and global zones highlighted
  2. Mask with corresponding zone overlays and mean position markers
Arguments:
  • mask: Binary mask array (0 and 255).
  • density: 1D array of pixel counts per row from calculate_y_density.
  • local_stats: Local statistics dict with 'mean' and 'std' keys.
  • global_stats: Global statistics from calculate_global_y_stats.
  • std_multiplier: Standard deviation multiplier for zone width (default: 2.0).
  • figsize: Figure size as (width, height) tuple (default: (10, 8)).
Example:
>>> from shoot_mask_cleaning import (
...     calculate_y_density,
...     calculate_weighted_y_stats,
...     join_shoot_fragments
... )
>>> 
>>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
>>> joined = join_shoot_fragments(mask)
>>> density = calculate_y_density(joined)
>>> local_stats = calculate_weighted_y_stats(density, y_min=200, y_max=750)
>>> 
>>> visualize_y_density_with_global_stats(joined, density, local_stats,
...                                       global_stats, std_multiplier=2.0)
>>> plt.show()
Note:

Red = local image statistics (zone, mean line and mean marker). Yellow = global dataset zone. Orange = global dataset mean (line and marker).

def visualize_x_projection_with_peaks_adaptive( mask, x_projection, peaks, global_stats, std_multiplier, figsize=(14, 8)):
def visualize_x_projection_with_peaks_adaptive(mask, x_projection, peaks, global_stats,
                                                std_multiplier, figsize=(14, 8)):
    """Plot a mask above its X-projection, marking every detected peak.

    The figure has two stacked panels (height ratio 3:1):
    1. The mask with the global Y-zone shaded and a numbered dashed line at
       each peak position
    2. The X-projection drawn as a filled curve with the same peaks marked

    Args:
        mask: Binary mask array (0 and 255).
        x_projection: 1D array of pixel counts from calculate_x_projection.
        peaks: Array of peak X-coordinates; each value is also used to index
            x_projection.
        global_stats: Global statistics from calculate_global_y_stats; the
            'global_mean' and 'global_std' entries are read.
        std_multiplier: The std multiplier actually used for this image.
        figsize: Figure size as (width, height) tuple (default: (14, 8)).

    Example:
        >>> from shoot_mask_cleaning import (
        ...     join_shoot_fragments,
        ...     find_shoot_peaks_with_size_fallback
        ... )
        >>>
        >>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
        >>> joined = join_shoot_fragments(mask)
        >>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
        ...     joined, global_stats
        ... )
        >>>
        >>> visualize_x_projection_with_peaks_adaptive(joined, x_proj, peaks,
        ...                                            global_stats, std)
        >>> plt.show()

    Note:
        The figure is only built here; call plt.show() or plt.savefig()
        afterwards to display or store it.
    """
    _, panels = plt.subplots(2, 1, figsize=figsize,
                             gridspec_kw={'height_ratios': [3, 1]})
    mask_panel, projection_panel = panels

    # Y-zone: global mean +/- std_multiplier standard deviations
    zone_center = global_stats['global_mean']
    zone_half_width = std_multiplier * global_stats['global_std']
    zone_top = int(zone_center - zone_half_width)
    zone_bottom = int(zone_center + zone_half_width)

    # Styling shared by every peak marker on both panels
    peak_line_style = {'color': 'red', 'linestyle': '--', 'linewidth': 2}
    label_box_style = {'boxstyle': 'round', 'facecolor': 'white', 'alpha': 0.8}

    # --- Top panel: mask, shaded Y-zone, numbered peak lines ---
    mask_panel.imshow(mask, cmap='gray', aspect='equal')
    mask_panel.axhspan(zone_top, zone_bottom, alpha=0.2, color='yellow')

    for number, x_pos in enumerate(peaks, start=1):
        mask_panel.axvline(x_pos, **peak_line_style)
        mask_panel.text(x_pos, 100, f'{number}', color='red', fontsize=12,
                        ha='center', bbox=dict(label_box_style))

    mask_panel.set_title(f'Mask with {len(peaks)} Detected Peaks (std={std_multiplier:.2f})')
    mask_panel.axis('off')

    # --- Bottom panel: X-projection curve ---
    column_count = len(x_projection)
    projection_panel.fill_between(range(column_count), x_projection,
                                  alpha=0.7, color='blue')
    projection_panel.set_xlim(0, column_count)
    projection_panel.set_xlabel('X coordinate (pixels)')
    projection_panel.set_ylabel('Pixel count')
    projection_panel.set_title('X-axis Projection')
    projection_panel.grid(True, alpha=0.3)

    # Dashed line plus a dot at the projection value of each peak
    for x_pos in peaks:
        projection_panel.axvline(x_pos, **peak_line_style)
        projection_panel.plot(x_pos, x_projection[x_pos], 'ro', markersize=10)

    plt.tight_layout()

Visualize mask and X-projection with detected peak locations.

Creates a two-panel figure showing:

  1. Mask with Y-zone overlay and vertical lines at peak positions
  2. X-projection histogram with peaks marked
Arguments:
  • mask: Binary mask array (0 and 255).
  • x_projection: 1D array of pixel counts from calculate_x_projection.
  • peaks: Array of peak X-coordinates.
  • global_stats: Global statistics from calculate_global_y_stats.
  • std_multiplier: The std multiplier actually used for this image.
  • figsize: Figure size as (width, height) tuple (default: (14, 8)).
Example:
>>> from shoot_mask_cleaning import (
...     join_shoot_fragments,
...     find_shoot_peaks_with_size_fallback
... )
>>> 
>>> mask = cv2.imread('shoot.png', cv2.IMREAD_GRAYSCALE)
>>> joined = join_shoot_fragments(mask)
>>> peaks, x_proj, std, method = find_shoot_peaks_with_size_fallback(
...     joined, global_stats
... )
>>> 
>>> visualize_x_projection_with_peaks_adaptive(joined, x_proj, peaks,
...                                            global_stats, std)
>>> plt.show()
def visualize_final_components(mask, global_stats, std_used, title='Final Mask', figsize=(16, 6)):
def visualize_final_components(mask, global_stats, std_used, title="Final Mask",
                               figsize=(16, 6)):
    """Visualize final mask with component count, labels, and statistics.

    Creates a two-panel figure showing:
    1. Binary mask with component count
    2. Colored label map with component numbers positioned below each shoot

    A table with the area and centroid of every component is also printed
    to stdout.

    Args:
        mask: Binary mask array (0 and 255).
        global_stats: Global statistics from calculate_global_y_stats; the
            'global_mean' and 'global_std' entries are read.
        std_used: Std multiplier used for this image.
        title: Title prefix for the plot (default: "Final Mask").
        figsize: Figure size as (width, height) tuple (default: (16, 6)).

    Returns:
        Number of components found in the mask (background excluded).

    Example:
        >>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
        >>>
        >>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
        >>> num_components = visualize_final_components(
        ...     result['cleaned_mask'],
        ...     global_stats,
        ...     result['std_used'],
        ...     title=f"Final: {result['filename']}"
        ... )
        >>>
        >>> if num_components != 5:
        ...     print(f"WARNING: Expected 5, found {num_components}")
        >>> plt.show()

    Note:
        Component labels are positioned below each shoot to avoid obscuring
        the actual shoot structures. Labeled view is zoomed to the Y-zone area
        for easier inspection. The figure is created but not shown; call
        plt.show() afterwards.
    """
    # Find connected components; label 0 is the background
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8)
    num_components = num_labels - 1  # Exclude background

    # Create a colored label image for visualization
    label_colors = np.zeros((*mask.shape, 3), dtype=np.uint8)

    # Use a private fixed-seed generator so the colors stay reproducible
    # without reseeding the process-wide NumPy RNG as a plotting side effect.
    color_rng = np.random.RandomState(42)
    for i in range(1, num_labels):
        label_colors[labels == i] = color_rng.randint(50, 255, size=3)

    # Plot
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Original mask
    ax1.imshow(mask, cmap='gray', aspect='equal')
    ax1.set_title(f'{title}\nComponents: {num_components}')
    ax1.axis('off')

    # Colored labels - zoomed to area of interest
    ax2.imshow(label_colors, aspect='equal')

    # Calculate Y-zone for zoom
    global_upper = int(global_stats['global_mean'] - std_used * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_used * global_stats['global_std'])

    # Add padding around the zone, clamped to the image height
    y_padding = 100
    zoom_y_min = max(0, global_upper - y_padding)
    zoom_y_max = min(mask.shape[0], global_lower + y_padding)

    # Set axis limits to zoom
    ax2.set_ylim(zoom_y_max, zoom_y_min)  # Inverted for image coordinates
    ax2.set_xlim(0, mask.shape[1])

    # Label each component just below its bounding box
    for i in range(1, num_labels):
        cx = centroids[i][0]
        y_bottom = stats[i, cv2.CC_STAT_TOP] + stats[i, cv2.CC_STAT_HEIGHT]

        ax2.text(cx, y_bottom + 40, f'{i}', color='white', fontsize=14,
                 ha='center', weight='bold',
                 bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))

    ax2.set_title(f'Labeled Components: {num_components} (Zoomed)')
    ax2.axis('off')

    plt.tight_layout()

    # Print component stats
    print("\nComponent Statistics:")
    print(f"{'Label':<8} {'Area':<10} {'X-Center':<10} {'Y-Center':<10}")
    print("-" * 40)
    for i in range(1, num_labels):
        area = stats[i, cv2.CC_STAT_AREA]
        cx, cy = centroids[i]
        print(f"{i:<8} {area:<10} {cx:<10.1f} {cy:<10.1f}")

    return num_components

Visualize final mask with component count, labels, and statistics.

Creates a two-panel figure showing:

  1. Binary mask with component count
  2. Colored label map with component numbers positioned below each shoot
Arguments:
  • mask: Binary mask array (0 and 255).
  • global_stats: Global statistics from calculate_global_y_stats.
  • std_used: Std multiplier used for this image.
  • title: Title prefix for the plot (default: "Final Mask").
  • figsize: Figure size as (width, height) tuple (default: (16, 6)).
Returns:

Number of components found in the mask.

Example:
>>> from shoot_mask_cleaning import clean_shoot_mask_pipeline
>>> 
>>> result = clean_shoot_mask_pipeline('mask.png', global_stats)
>>> num_components = visualize_final_components(
...     result['cleaned_mask'],
...     global_stats,
...     result['std_used'],
...     title=f"Final: {result['filename']}"
... )
>>> 
>>> if num_components != 5:
...     print(f"WARNING: Expected 5, found {num_components}")
>>> plt.show()
Note:

Component labels are positioned below each shoot to avoid obscuring the actual shoot structures. Labeled view is zoomed to the Y-zone area for easier inspection.

def debug_peak_detection( mask_path, global_stats, kernel_size=7, iterations=5, quality_check_threshold=2.5):
def debug_peak_detection(mask_path, global_stats, kernel_size=7, iterations=5,
                        quality_check_threshold=2.5):
    """Debug visualization showing all pipeline steps with peak locations.

    Creates a comprehensive four-panel figure for troubleshooting:
    1. Original mask with detected peak locations and Y-zone
    2. Mask after morphological closing with peaks
    3. X-projection histogram with marked peaks
    4. Final cleaned result with component labels

    Args:
        mask_path: Path to shoot mask file (string or Path object).
        global_stats: Global statistics from calculate_global_y_stats; the
            'global_mean' and 'global_std' entries are read here.
        kernel_size: Kernel size for initial closing (default: 7).
        iterations: Number of closing iterations (default: 5).
        quality_check_threshold: Std threshold for quality checks (default: 2.5).

    Raises:
        OSError: If the mask image cannot be read (cv2.imread returned None).

    Example:
        >>> # Debug a problematic image
        >>> debug_peak_detection('data/shoot/image_11.png', global_stats)
        >>> plt.savefig('debug_output.png', dpi=150, bbox_inches='tight')
        >>> plt.show()
        >>>
        >>> # Batch debug multiple images
        >>> for mask_file in problematic_files:
        ...     debug_peak_detection(mask_file, global_stats)
        ...     plt.show()

    Note:
        This function itself prints the mask file name, the detected peak
        X positions, the detection method and the std multiplier used. The
        figure is created but not shown; call plt.show() afterwards.
    """
    # The companion module may be importable as part of the ``library``
    # package or as a top-level module; accept either layout.
    try:
        from library.shoot_mask_cleaning import (
            join_shoot_fragments,
            find_shoot_peaks_with_size_fallback,
            merge_with_narrow_bands,
            filter_one_per_peak
        )
    except ImportError:
        from shoot_mask_cleaning import (
            join_shoot_fragments,
            find_shoot_peaks_with_size_fallback,
            merge_with_narrow_bands,
            filter_one_per_peak
        )

    print(f"Debugging: {Path(mask_path).name}")

    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than deep inside a helper.
        raise OSError(f"Could not read mask image: {mask_path}")

    joined_mask = join_shoot_fragments(mask, kernel_size=kernel_size, iterations=iterations)

    # Find peaks
    peaks, x_projection, std_used, method = find_shoot_peaks_with_size_fallback(
        joined_mask, global_stats, quality_check_threshold=quality_check_threshold
    )

    print(f"  Peaks detected at X positions: {peaks}")
    print(f"  Method: {method}, Std: {std_used:.2f}")

    # Process: the same band width is used for merging and for filtering
    band_width = 100
    narrow_result = merge_with_narrow_bands(joined_mask, peaks, global_stats,
                                            std_used, band_width=band_width)

    # The Y-zone requirement is dropped when peaks came from the size fallback
    require_y_zone = (method != "size_fallback")
    final_result = filter_one_per_peak(narrow_result, peaks, global_stats, std_used,
                                       band_width=band_width,
                                       require_y_zone=require_y_zone)

    # Calculate Y-zone
    global_upper = int(global_stats['global_mean'] - std_used * global_stats['global_std'])
    global_lower = int(global_stats['global_mean'] + std_used * global_stats['global_std'])

    # Visualize
    _, axes = plt.subplots(2, 2, figsize=(18, 12))

    # Original with peaks
    axes[0, 0].imshow(mask, cmap='gray', aspect='equal')
    axes[0, 0].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for i, peak_x in enumerate(peaks):
        axes[0, 0].axvline(peak_x, color='red', linestyle='--', linewidth=2, alpha=0.7)
        axes[0, 0].text(peak_x, 50, f'P{i+1}', color='red', fontsize=12,
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    axes[0, 0].set_title('Original with Peak Locations')
    axes[0, 0].axis('off')

    # After joining with peaks
    axes[0, 1].imshow(joined_mask, cmap='gray', aspect='equal')
    axes[0, 1].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for peak_x in peaks:
        axes[0, 1].axvline(peak_x, color='red', linestyle='--', linewidth=2, alpha=0.7)
    axes[0, 1].set_title('After Closing with Peaks')
    axes[0, 1].axis('off')

    # X-projection
    axes[1, 0].fill_between(range(len(x_projection)), x_projection, alpha=0.7, color='blue')
    axes[1, 0].set_xlim(0, len(x_projection))
    for i, peak_x in enumerate(peaks):
        axes[1, 0].axvline(peak_x, color='red', linestyle='--', linewidth=2)
        axes[1, 0].plot(peak_x, x_projection[peak_x], 'ro', markersize=10)
        axes[1, 0].text(peak_x, x_projection[peak_x] + 5, f'P{i+1}',
                        ha='center', color='red', weight='bold')
    axes[1, 0].set_xlabel('X coordinate')
    axes[1, 0].set_ylabel('Pixel count')
    axes[1, 0].set_title('X-Projection with Detected Peaks')
    axes[1, 0].grid(True, alpha=0.3)

    # Final with component labels (label 0 is the background)
    num_labels, _, comp_stats, comp_centroids = cv2.connectedComponentsWithStats(
        final_result, connectivity=8
    )
    axes[1, 1].imshow(final_result, cmap='gray', aspect='equal')
    axes[1, 1].axhspan(global_upper, global_lower, alpha=0.2, color='yellow')
    for i in range(1, num_labels):
        cx = comp_centroids[i][0]
        # Place the label just below the component's bounding box
        y_bottom = comp_stats[i, cv2.CC_STAT_TOP] + comp_stats[i, cv2.CC_STAT_HEIGHT]
        axes[1, 1].text(cx, y_bottom + 40, f'{i}', color='red', fontsize=14,
                        ha='center', weight='bold',
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))
    axes[1, 1].set_title(f'Final: {num_labels-1} components')
    axes[1, 1].axis('off')

    plt.tight_layout()

Debug visualization showing all pipeline steps with peak locations.

Creates a comprehensive four-panel figure for troubleshooting:

  1. Original mask with detected peak locations and Y-zone
  2. Mask after morphological closing with peaks
  3. X-projection histogram with marked peaks
  4. Final cleaned result with component labels
Arguments:
  • mask_path: Path to shoot mask file (string or Path object).
  • global_stats: Global statistics from calculate_global_y_stats.
  • kernel_size: Kernel size for initial closing (default: 7).
  • iterations: Number of closing iterations (default: 5).
  • quality_check_threshold: Std threshold for quality checks (default: 2.5).
Example:
>>> # Debug a problematic image
>>> debug_peak_detection('data/shoot/image_11.png', global_stats)
>>> plt.savefig('debug_output.png', dpi=150, bbox_inches='tight')
>>> plt.show()
>>> 
>>> # Batch debug multiple images
>>> for mask_file in problematic_files:
...     debug_peak_detection(mask_file, global_stats)
...     plt.show()
Note:

This function prints detailed information about peak detection results, component assignments, and filtering decisions to help diagnose issues.

def batch_process_and_visualize( mask_paths, global_stats, output_dir=None, save_masks=True, show_plots=True):
def batch_process_and_visualize(mask_paths, global_stats, output_dir=None,
                                save_masks=True, show_plots=True):
    """Process all masks and generate summary visualizations.

    Runs the complete cleaning pipeline on multiple masks and creates:
    1. Individual processed masks (saved to output_dir if specified)
    2. Summary statistics printed to console
    3. Optional visualization of each result

    Args:
        mask_paths: List of paths to mask files. May be empty, in which case
            the summary reports zero processed masks.
        global_stats: Global statistics from calculate_global_y_stats.
        output_dir: Directory to save cleaned masks (default: None = don't save).
            Created, including parents, if it does not exist.
        save_masks: Whether to save cleaned masks (default: True). Only has
            an effect when output_dir is given.
        show_plots: Whether to display plots for each image (default: True).

    Returns:
        List of result dictionaries from clean_shoot_mask_pipeline.

    Example:
        >>> from pathlib import Path
        >>>
        >>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
        >>> results = batch_process_and_visualize(
        ...     mask_files,
        ...     global_stats,
        ...     output_dir='data/shoot_cleaned',
        ...     save_masks=True,
        ...     show_plots=False
        ... )
        >>>
        >>> # Summary
        >>> success_count = sum(1 for r in results if r['num_components'] == 5)
        >>> print(f"Success rate: {success_count}/{len(results)}")
    """
    # The companion module may be importable as part of the ``library``
    # package or as a top-level module; accept either layout.
    try:
        from library.shoot_mask_cleaning import clean_shoot_mask_pipeline
    except ImportError:
        from shoot_mask_cleaning import clean_shoot_mask_pipeline

    should_save = output_dir is not None and save_masks
    if should_save:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

    results = []

    for mask_path in mask_paths:
        # Process mask
        result = clean_shoot_mask_pipeline(mask_path, global_stats)
        results.append(result)

        # Save if requested
        if should_save:
            save_path = output_path / result['filename']
            cv2.imwrite(str(save_path), result['cleaned_mask'])

        # Visualize if requested
        if show_plots:
            visualize_final_components(result['cleaned_mask'], global_stats,
                                       result['std_used'],
                                       title=f"Final: {result['filename']}")
            plt.show()

    # Print summary
    print("\n" + "="*70)
    print("BATCH PROCESSING SUMMARY")
    print("="*70)
    print(f"{'Filename':<30} {'Method':<15} {'Std':>6} {'Components':>12}")
    print("-"*70)

    for result in results:
        print(f"{result['filename']:<30} {result['method']:<15} "
              f"{result['std_used']:>6.2f} {result['num_components']:>12}")

    # Summary statistics
    total = len(results)
    correct_count = sum(1 for r in results if r['num_components'] == 5)
    projection_count = sum(1 for r in results if r['method'] == 'projection')
    fallback_count = total - projection_count

    def _percent(count):
        # An empty batch must not divide by zero; report 0.0% instead.
        return 100 * count / total if total else 0.0

    print("-"*70)
    print(f"Total processed: {total}")
    print(f"Correct component count (5): {correct_count} ({_percent(correct_count):.1f}%)")
    print(f"Projection method: {projection_count} ({_percent(projection_count):.1f}%)")
    print(f"Size fallback method: {fallback_count} ({_percent(fallback_count):.1f}%)")
    print("="*70 + "\n")

    return results

Process all masks and generate summary visualizations.

Runs the complete cleaning pipeline on multiple masks and creates:

  1. Individual processed masks (saved to output_dir if specified)
  2. Summary statistics printed to console
  3. Optional visualization of each result
Arguments:
  • mask_paths: List of paths to mask files.
  • global_stats: Global statistics from calculate_global_y_stats.
  • output_dir: Directory to save cleaned masks (default: None = don't save).
  • save_masks: Whether to save cleaned masks (default: True).
  • show_plots: Whether to display plots for each image (default: True).
Returns:

List of result dictionaries from clean_shoot_mask_pipeline.

Example:
>>> from pathlib import Path
>>> 
>>> mask_files = [str(f) for f in Path('data/shoot').glob('*.png')]
>>> results = batch_process_and_visualize(
...     mask_files,
...     global_stats,
...     output_dir='data/shoot_cleaned',
...     save_masks=True,
...     show_plots=False
... )
>>> 
>>> # Summary
>>> success_count = sum(1 for r in results if r['num_components'] == 5)
>>> print(f"Success rate: {success_count}/{len(results)}")