library.dataset_visualization

Visualization and testing functions for patch datasets.

  1# library/dataset_visualization.py
  2"""Visualization and testing functions for patch datasets."""
  3
  4import random
  5import cv2
  6import numpy as np
  7import matplotlib.pyplot as plt
  8from pathlib import Path
  9from library.roi import detect_roi, crop_to_roi
 10from library.patch_dataset import create_patches_from_image
 11
 12
 13def verify_patches(pairs, patch_size=128, scaling_factor=1.0, step=None,
 14                   mask_type='root', filter_roi=True, preprocess_fns=None):
 15    """Display a random patch with original image (gridded), mask, and overlay for verification.
 16    
 17    Args:
 18        pairs: List of image-mask pairs from get_image_mask_pairs().
 19        patch_size: Size of patches. Default is 128.
 20        scaling_factor: Scaling factor. Default is 1.0.
 21        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
 22        mask_type: Which mask to display ('root', 'shoot', or 'seed'). Default is 'root'.
 23        filter_roi: If True, crop to ROI before patching. Default is True.
 24        preprocess_fns: Optional list of preprocessing functions to apply to image.
 25    
 26    Returns:
 27        None
 28    """
 29    if step is None:
 30        step = patch_size
 31    
 32    # Color mapping for mask types
 33    color_map = {
 34        'shoot': {'rgb': [0, 255, 0], 'name': 'green'},    # Green
 35        'seed': {'rgb': [0, 0, 255], 'name': 'blue'},      # Blue
 36        'root': {'rgb': [255, 0, 0], 'name': 'red'},       # Red
 37    }
 38    mask_colors = color_map.get(mask_type, {'rgb': [255, 255, 0], 'name': 'yellow'})
 39
 40    # Pick random image
 41    pair = random.choice(pairs)
 42    img_path = pair['image']
 43    
 44    # Load image and mask
 45    image = cv2.imread(str(img_path))
 46    mask = cv2.imread(str(pair['masks'][mask_type]), cv2.IMREAD_GRAYSCALE)
 47    
 48    # Detect ROI
 49    roi_bbox = detect_roi(image) if filter_roi else None
 50    
 51    # Create patches
 52    masks_dict = {mask_type: mask}
 53    result = create_patches_from_image(image, masks_dict, patch_size, scaling_factor, 
 54                                      step, roi_bbox, preprocess_fns)
 55    
 56    # Pick random patch that has mask content
 57    n_rows, n_cols = result['image'].shape[0], result['image'].shape[1]
 58    
 59    valid_patches = []
 60    for row_idx in range(n_rows):
 61        for col_idx in range(n_cols):
 62            patch_mask = result['masks'][mask_type][row_idx, col_idx, 0, :, :, 0]
 63            if patch_mask.sum() > 0:
 64                valid_patches.append((row_idx, col_idx))
 65    
 66    if not valid_patches:
 67        print(f"No valid patches found for {img_path.name}")
 68        return
 69    
 70    # Pick random valid patch
 71    row_idx, col_idx = random.choice(valid_patches)
 72    
 73    # Extract patches
 74    img_patch = result['image'][row_idx, col_idx, 0]
 75    mask_patch = result['masks'][mask_type][row_idx, col_idx, 0, :, :, 0]
 76    
 77    # Calculate patch center (in cropped coordinates if ROI used)
 78    center_y = row_idx * step + patch_size // 2
 79    center_x = col_idx * step + patch_size // 2
 80    
 81    # Create overlay
 82    overlay = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB).copy()
 83    overlay[mask_patch > 0] = mask_colors['rgb']
 84    
 85    # Prepare display image (cropped if ROI used)
 86    display_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
 87    # Apply preprocessing for display
 88    if preprocess_fns:
 89        from library.patch_dataset import apply_preprocessing_pipeline
 90        display_image = apply_preprocessing_pipeline(display_image, preprocess_fns)
 91    
 92    # Plot
 93    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
 94    
 95    # Original/cropped image with grid
 96    axes[0, 0].imshow(cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB))
 97    for i in range(0, display_image.shape[0], step):
 98        axes[0, 0].axhline(y=i, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
 99    for j in range(0, display_image.shape[1], step):
100        axes[0, 0].axvline(x=j, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
101    
102    axes[0, 0].plot(center_x, center_y, 'r*', markersize=15)
103    title = 'Cropped to ROI' if roi_bbox else 'Original Image'
104    if preprocess_fns:
105        title += ' (preprocessed)'
106    axes[0, 0].set_title(title)
107    axes[0, 0].axis('off')
108    
109    # Patch
110    axes[0, 1].imshow(cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB))
111    axes[0, 1].set_title(f'Patch [{row_idx},{col_idx}]')
112    axes[0, 1].axis('off')
113    
114    # Mask
115    axes[1, 0].imshow(mask_patch, cmap='gray', vmin=0, vmax=1)
116    axes[1, 0].set_title('Mask')
117    axes[1, 0].axis('off')
118    
119    # Overlay
120    axes[1, 1].imshow(overlay)
121    axes[1, 1].set_title('Overlay')
122    axes[1, 1].axis('off')
123    
124    overlap_pct = (1 - step/patch_size) * 100
125    title = f'{img_path.name} - {mask_type} (overlap {overlap_pct:.0f}%)'
126    title += f' - {len(valid_patches)} valid patches'
127    plt.suptitle(title)
128    plt.tight_layout()
129    plt.show()
130
131
132
def test_unpatchify(pairs, img_idx=None, mask_type='root', patch_size=128,
                    scaling_factor=1.0, step=None, filter_roi=True, preprocess_fns=None):
    """Test that patches can be reconstructed back into the original image.

    One image/mask pair is patched with ``create_patches_from_image`` and
    rebuilt with ``reconstruct_from_patches``. The result is compared
    element-wise against the cropped/preprocessed target after
    ``process_image``, and all stages are shown in a 2x3 matplotlib figure.

    Args:
        pairs: List of image-mask pairs.
        img_idx: Index of image to test. If None, picks random. Default is None.
        mask_type: Which mask to test. Default is 'root'.
        patch_size: Size of patches. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to image.

    Returns:
        Boolean indicating if reconstruction matches original (True only when
        both the image and the mask match).

    Raises:
        KeyError: If the selected pair has no mask of type ``mask_type``.
        OSError: If OpenCV cannot read the image or the mask file.
    """
    from library.patch_dataset import process_image, reconstruct_from_patches

    if step is None:
        step = patch_size

    # Pick image
    pair = random.choice(pairs) if img_idx is None else pairs[img_idx]

    img_path = pair['image']
    # Works whether the pair stores a str or a Path.
    img_name = Path(img_path).name

    # Fail early, before any image I/O, when the requested mask is absent.
    available_masks = pair['masks']
    if mask_type not in available_masks:
        found = ', '.join(sorted(str(key) for key in available_masks))
        raise KeyError(f"no {mask_type} mask available for {img_name} (found: {found})")
    mask_path = available_masks[mask_type]

    # Load image and mask; cv2.imread signals failure by returning None.
    image = cv2.imread(str(img_path))
    if image is None:
        raise OSError(f"Could not read image: {img_path}")
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise OSError(f"Could not read mask: {mask_path}")

    # Detect ROI
    roi_bbox = detect_roi(image) if filter_roi else None

    overlap_pct = (1 - step / patch_size) * 100

    print(f"Testing: {img_name}")
    print(f"Original image shape: {image.shape}")
    print(f"Original mask shape: {mask.shape}")
    if roi_bbox:
        print(f"ROI bbox: {roi_bbox}")
    if preprocess_fns:
        # functools.partial objects and other callables may lack __name__.
        fn_names = [getattr(fn, '__name__', repr(fn)) for fn in preprocess_fns]
        print(f"Preprocessing: {fn_names}")
    print(f"Patch size: {patch_size}, Step: {step}, Overlap: {overlap_pct:.1f}%")

    # Create patches
    masks_dict = {mask_type: mask}
    result = create_patches_from_image(image, masks_dict, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)

    img_patches = result['image']
    mask_patches = result['masks'][mask_type]

    print(f"Patches shape: {img_patches.shape}")

    # Canvas size implied by the patch grid: n patches at stride `step`, plus
    # the part of the last patch that extends beyond one stride.
    n_rows, n_cols = img_patches.shape[0], img_patches.shape[1]
    padded_h = n_rows * step + (patch_size - step)
    padded_w = n_cols * step + (patch_size - step)

    print(f"Padded dimensions: {padded_h} x {padded_w}")

    # Get the image we should be reconstructing (cropped if ROI used)
    target_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
    target_mask = crop_to_roi(mask, roi_bbox) if roi_bbox else mask

    # Apply preprocessing to target for comparison
    if preprocess_fns:
        from library.patch_dataset import apply_preprocessing_pipeline
        target_image = apply_preprocessing_pipeline(target_image, preprocess_fns)

    # Process target for comparison (same padding applied during patching)
    processed_img = process_image(target_image, patch_size, scaling_factor, is_mask=False)
    processed_mask = process_image(target_mask, patch_size, scaling_factor, is_mask=True)

    # Reconstruct
    reconstructed_img = reconstruct_from_patches(
        img_patches, (padded_h, padded_w, 3), patch_size, step
    )
    reconstructed_mask = reconstruct_from_patches(
        mask_patches, (padded_h, padded_w, 1), patch_size, step
    )

    print(f"Reconstructed image shape: {reconstructed_img.shape}")
    print(f"Reconstructed mask shape: {reconstructed_mask.shape}")

    # np.array_equal is also False when the shapes differ.
    img_match = np.array_equal(reconstructed_img, processed_img)
    mask_match = np.array_equal(reconstructed_mask[:, :, 0], processed_mask)

    print(f"Image reconstruction matches: {img_match}")
    print(f"Mask reconstruction matches: {mask_match}")

    # Visualize
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Original with ROI
    axes[0, 0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if roi_bbox:
        (x1, y1), (x2, y2) = roi_bbox
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             linewidth=2, edgecolor='green', facecolor='none')
        axes[0, 0].add_patch(rect)
    axes[0, 0].set_title('Original Image' + (' + ROI' if roi_bbox else ''))
    axes[0, 0].axis('off')

    axes[1, 0].imshow(mask, cmap='gray')
    axes[1, 0].set_title('Original Mask')
    axes[1, 0].axis('off')

    # Padded/Processed (cropped + preprocessed if ROI)
    axes[0, 1].imshow(cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB))
    if roi_bbox and preprocess_fns:
        title = 'Cropped + Preprocessed + Padded'
    elif roi_bbox:
        title = 'Cropped + Padded'
    elif preprocess_fns:
        title = 'Preprocessed + Padded'
    else:
        title = 'Padded Image'
    axes[0, 1].set_title(title)
    axes[0, 1].axis('off')

    axes[1, 1].imshow(processed_mask, cmap='gray')
    axes[1, 1].set_title('Padded Mask')
    axes[1, 1].axis('off')

    # Reconstructed
    axes[0, 2].imshow(cv2.cvtColor(reconstructed_img, cv2.COLOR_BGR2RGB))
    axes[0, 2].set_title(f'Reconstructed (overlap {overlap_pct:.0f}%)')
    axes[0, 2].axis('off')

    axes[1, 2].imshow(reconstructed_mask[:, :, 0], cmap='gray')
    axes[1, 2].set_title('Reconstructed Mask')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    return img_match and mask_match
269
def visualize_patch_reassembly(patch_dir, raw_dir, dataset_type, n_images=3,
                               mask_types=None):
    """Visualize original images alongside their reassembled patches and mask overlays.

    Loads random original images and their corresponding patches, reassembles
    the patches, and displays them side-by-side with yellow borders showing
    patch boundaries. Also shows mask overlays in different colors.

    Image patches are summed into a float canvas and divided by the per-pixel
    patch count, so overlapping regions are averaged and non-overlapping
    regions are copied unchanged. Mask patches are merged with an
    element-wise maximum, so a pixel marked in any patch stays marked.

    Args:
        patch_dir: Directory containing saved patches (e.g., '../../data/test_output/patched_test').
        raw_dir: Directory containing original images (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        n_images: Number of random images to visualize. Default is 3.
        mask_types: List of mask types to overlay. Default (None) is
            ['root', 'shoot', 'seed']. Types without a predefined colour are
            drawn in white.

    Returns:
        None. One figure is shown per selected image that could be loaded.

    Example:
        >>> visualize_patch_reassembly('../../data/test_output/patched_test',
        ...                            '../../data/dataset', 'train', n_images=3)
    """
    from library.patch_dataset import load_patch_metadata

    # Build the default per call rather than sharing one mutable list object
    # between calls.
    if mask_types is None:
        mask_types = ['root', 'shoot', 'seed']

    # Define colors for each mask type (RGB)
    mask_color_map = {
        'root': [255, 0, 0],      # Red
        'shoot': [0, 255, 0],     # Green
        'seed': [0, 0, 255]       # Blue
    }
    fallback_color = [255, 255, 255]

    # Load metadata
    metadata = load_patch_metadata(patch_dir, dataset_type)
    dataset_info = metadata['dataset_info']
    patch_size = dataset_info['patch_size']
    step = dataset_info['step']
    filter_roi = dataset_info.get('filter_roi', False)
    preprocessing = dataset_info.get('preprocessing', None)

    # Group patch records by source image once instead of rescanning the
    # whole patch list for every selected image.
    patches_by_image = {}
    for patch_info in metadata['patches']:
        patches_by_image.setdefault(patch_info['source_image'], []).append(patch_info)

    # Sorted so that sampling is reproducible under random.seed(); set
    # iteration order of strings varies between interpreter runs.
    source_images = sorted(patches_by_image)

    # Select n random images
    selected_images = random.sample(source_images, min(n_images, len(source_images)))

    patch_root = Path(patch_dir)
    image_patch_dir = patch_root / f'{dataset_type}_images' / dataset_type

    for img_name in selected_images:
        print(f"Processing {img_name}...")

        # Load original image
        img_path = Path(raw_dir) / f'{dataset_type}_images' / img_name
        original_image = cv2.imread(str(img_path))

        if original_image is None:
            print(f"Warning: Could not load {img_path}")
            continue

        image_patches = patches_by_image[img_name]
        first_patch = image_patches[0]

        # ROI bbox is read from the first patch record; a missing or null
        # entry means no ROI rectangle is drawn.
        raw_bbox = first_patch.get('roi_bbox')
        roi_bbox = tuple(tuple(coord) for coord in raw_bbox) if raw_bbox else None

        # Determine grid size from metadata
        n_rows, n_cols = first_patch['grid_size']

        # Canvas size implied by the patch grid.
        padded_h = n_rows * step + (patch_size - step)
        padded_w = n_cols * step + (patch_size - step)

        # Float accumulator plus per-pixel patch count for the image.
        accumulated = np.zeros((padded_h, padded_w, 3), dtype=np.float32)
        counts = np.zeros((padded_h, padded_w), dtype=np.float32)
        reassembled_masks = {
            mask_type: np.zeros((padded_h, padded_w), dtype=np.uint8)
            for mask_type in mask_types
        }

        # Load and place each patch
        patch_coords = []
        for patch_info in image_patches:
            row_idx = patch_info['row_idx']
            col_idx = patch_info['col_idx']
            patch_filename = patch_info['patch_filename']

            # Load image patch; unreadable patches are skipped.
            patch = cv2.imread(str(image_patch_dir / patch_filename))
            if patch is None:
                continue

            # Calculate position, clipped to the canvas.
            y_start = row_idx * step
            x_start = col_idx * step
            y_end = min(y_start + patch_size, padded_h)
            x_end = min(x_start + patch_size, padded_w)

            patch_h = y_end - y_start
            patch_w = x_end - x_start
            region = (slice(y_start, y_end), slice(x_start, x_end))

            accumulated[region] += patch[:patch_h, :patch_w]
            counts[region] += 1

            # Load and place mask patches
            for mask_type in mask_types:
                mask_patch_path = (patch_root / f'{dataset_type}_masks_{mask_type}'
                                   / dataset_type / patch_filename)
                if not mask_patch_path.exists():
                    continue
                mask_patch = cv2.imread(str(mask_patch_path), cv2.IMREAD_GRAYSCALE)
                if mask_patch is None:
                    continue
                canvas = reassembled_masks[mask_type]
                canvas[region] = np.maximum(canvas[region], mask_patch[:patch_h, :patch_w])

            # Store coordinates for drawing borders
            patch_coords.append((x_start, y_start, x_end, y_end))

        # Average the image; uncovered pixels keep a divisor of 1 and stay 0.
        divisor = np.maximum(counts, 1)[:, :, np.newaxis]
        reassembled = (accumulated / divisor).astype(np.uint8)
        reassembled_rgb = cv2.cvtColor(reassembled, cv2.COLOR_BGR2RGB)

        # Create mask overlay on reassembled image; later mask types paint
        # over earlier ones where they coincide.
        overlay = reassembled_rgb.copy()
        for mask_type in mask_types:
            color = mask_color_map.get(mask_type, fallback_color)
            overlay[reassembled_masks[mask_type] > 0] = color

        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Original image (with ROI if available)
        axes[0].imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        title = f'Original Image\n{img_name}'
        if roi_bbox:
            (x1, y1), (x2, y2) = roi_bbox
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='green', facecolor='none')
            axes[0].add_patch(rect)
            title += '\n(green = ROI used for patching)'
        axes[0].set_title(title)
        axes[0].axis('off')

        # Reassembled image with patch borders
        axes[1].imshow(reassembled_rgb)

        # Draw yellow borders for each patch
        for x_start, y_start, x_end, y_end in patch_coords:
            rect = plt.Rectangle(
                (x_start, y_start),
                x_end - x_start,
                y_end - y_start,
                linewidth=1,
                edgecolor='yellow',
                facecolor='none',
            )
            axes[1].add_patch(rect)

        title = f'Reassembled from {len(patch_coords)} Patches\n'
        title += f'Patch size: {patch_size}, Step: {step} '
        title += f'(overlap: {(1 - step / patch_size) * 100:.0f}%)'
        if filter_roi:
            title += '\n(cropped to ROI before patching)'
        if preprocessing:
            title += f'\nPreprocessing: {", ".join(preprocessing)}'
        axes[1].set_title(title)
        axes[1].axis('off')

        # Mask overlay
        axes[2].imshow(overlay)

        # Legend uses the same colour lookup as the overlay, so mask types
        # outside mask_color_map no longer raise KeyError.
        legend_text = [
            f'{mask_type.capitalize()}: {mask_color_map.get(mask_type, fallback_color)}'
            for mask_type in mask_types
        ]

        axes[2].set_title('Mask Overlay\n' + ', '.join(legend_text))
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        print(f"  Reassembled {len(patch_coords)} patches into {padded_h}x{padded_w} image")
        if roi_bbox:
            print(f"  Original: {original_image.shape[1]}x{original_image.shape[0]}, Cropped to ROI before patching")
        else:
            print(f"  Original: {original_image.shape[1]}x{original_image.shape[0]}")
        if preprocessing:
            print(f"  Preprocessing applied: {', '.join(preprocessing)}")
        print()
def verify_patches( pairs, patch_size=128, scaling_factor=1.0, step=None, mask_type='root', filter_roi=True, preprocess_fns=None):
 14def verify_patches(pairs, patch_size=128, scaling_factor=1.0, step=None,
 15                   mask_type='root', filter_roi=True, preprocess_fns=None):
 16    """Display a random patch with original image (gridded), mask, and overlay for verification.
 17    
 18    Args:
 19        pairs: List of image-mask pairs from get_image_mask_pairs().
 20        patch_size: Size of patches. Default is 128.
 21        scaling_factor: Scaling factor. Default is 1.0.
 22        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
 23        mask_type: Which mask to display ('root', 'shoot', or 'seed'). Default is 'root'.
 24        filter_roi: If True, crop to ROI before patching. Default is True.
 25        preprocess_fns: Optional list of preprocessing functions to apply to image.
 26    
 27    Returns:
 28        None
 29    """
 30    if step is None:
 31        step = patch_size
 32    
 33    # Color mapping for mask types
 34    color_map = {
 35        'shoot': {'rgb': [0, 255, 0], 'name': 'green'},    # Green
 36        'seed': {'rgb': [0, 0, 255], 'name': 'blue'},      # Blue
 37        'root': {'rgb': [255, 0, 0], 'name': 'red'},       # Red
 38    }
 39    mask_colors = color_map.get(mask_type, {'rgb': [255, 255, 0], 'name': 'yellow'})
 40
 41    # Pick random image
 42    pair = random.choice(pairs)
 43    img_path = pair['image']
 44    
 45    # Load image and mask
 46    image = cv2.imread(str(img_path))
 47    mask = cv2.imread(str(pair['masks'][mask_type]), cv2.IMREAD_GRAYSCALE)
 48    
 49    # Detect ROI
 50    roi_bbox = detect_roi(image) if filter_roi else None
 51    
 52    # Create patches
 53    masks_dict = {mask_type: mask}
 54    result = create_patches_from_image(image, masks_dict, patch_size, scaling_factor, 
 55                                      step, roi_bbox, preprocess_fns)
 56    
 57    # Pick random patch that has mask content
 58    n_rows, n_cols = result['image'].shape[0], result['image'].shape[1]
 59    
 60    valid_patches = []
 61    for row_idx in range(n_rows):
 62        for col_idx in range(n_cols):
 63            patch_mask = result['masks'][mask_type][row_idx, col_idx, 0, :, :, 0]
 64            if patch_mask.sum() > 0:
 65                valid_patches.append((row_idx, col_idx))
 66    
 67    if not valid_patches:
 68        print(f"No valid patches found for {img_path.name}")
 69        return
 70    
 71    # Pick random valid patch
 72    row_idx, col_idx = random.choice(valid_patches)
 73    
 74    # Extract patches
 75    img_patch = result['image'][row_idx, col_idx, 0]
 76    mask_patch = result['masks'][mask_type][row_idx, col_idx, 0, :, :, 0]
 77    
 78    # Calculate patch center (in cropped coordinates if ROI used)
 79    center_y = row_idx * step + patch_size // 2
 80    center_x = col_idx * step + patch_size // 2
 81    
 82    # Create overlay
 83    overlay = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB).copy()
 84    overlay[mask_patch > 0] = mask_colors['rgb']
 85    
 86    # Prepare display image (cropped if ROI used)
 87    display_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
 88    # Apply preprocessing for display
 89    if preprocess_fns:
 90        from library.patch_dataset import apply_preprocessing_pipeline
 91        display_image = apply_preprocessing_pipeline(display_image, preprocess_fns)
 92    
 93    # Plot
 94    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
 95    
 96    # Original/cropped image with grid
 97    axes[0, 0].imshow(cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB))
 98    for i in range(0, display_image.shape[0], step):
 99        axes[0, 0].axhline(y=i, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
100    for j in range(0, display_image.shape[1], step):
101        axes[0, 0].axvline(x=j, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
102    
103    axes[0, 0].plot(center_x, center_y, 'r*', markersize=15)
104    title = 'Cropped to ROI' if roi_bbox else 'Original Image'
105    if preprocess_fns:
106        title += ' (preprocessed)'
107    axes[0, 0].set_title(title)
108    axes[0, 0].axis('off')
109    
110    # Patch
111    axes[0, 1].imshow(cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB))
112    axes[0, 1].set_title(f'Patch [{row_idx},{col_idx}]')
113    axes[0, 1].axis('off')
114    
115    # Mask
116    axes[1, 0].imshow(mask_patch, cmap='gray', vmin=0, vmax=1)
117    axes[1, 0].set_title('Mask')
118    axes[1, 0].axis('off')
119    
120    # Overlay
121    axes[1, 1].imshow(overlay)
122    axes[1, 1].set_title('Overlay')
123    axes[1, 1].axis('off')
124    
125    overlap_pct = (1 - step/patch_size) * 100
126    title = f'{img_path.name} - {mask_type} (overlap {overlap_pct:.0f}%)'
127    title += f' - {len(valid_patches)} valid patches'
128    plt.suptitle(title)
129    plt.tight_layout()
130    plt.show()

Display a random patch with original image (gridded), mask, and overlay for verification.

Arguments:
  • pairs: List of image-mask pairs from get_image_mask_pairs().
  • patch_size: Size of patches. Default is 128.
  • scaling_factor: Scaling factor. Default is 1.0.
  • step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
  • mask_type: Which mask to display ('root', 'shoot', or 'seed'). Default is 'root'.
  • filter_roi: If True, crop to ROI before patching. Default is True.
  • preprocess_fns: Optional list of preprocessing functions to apply to image.
Returns:

None

def test_unpatchify( pairs, img_idx=None, mask_type='root', patch_size=128, scaling_factor=1.0, step=None, filter_roi=True, preprocess_fns=None):
def test_unpatchify(pairs, img_idx=None, mask_type='root', patch_size=128,
                    scaling_factor=1.0, step=None, filter_roi=True, preprocess_fns=None):
    """Test that patches can be reconstructed back into the original image.

    One image/mask pair is patched with ``create_patches_from_image`` and
    rebuilt with ``reconstruct_from_patches``. The result is compared
    element-wise against the cropped/preprocessed target after
    ``process_image``, and all stages are shown in a 2x3 matplotlib figure.

    Args:
        pairs: List of image-mask pairs.
        img_idx: Index of image to test. If None, picks random. Default is None.
        mask_type: Which mask to test. Default is 'root'.
        patch_size: Size of patches. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to image.

    Returns:
        Boolean indicating if reconstruction matches original (True only when
        both the image and the mask match).

    Raises:
        KeyError: If the selected pair has no mask of type ``mask_type``.
        OSError: If OpenCV cannot read the image or the mask file.
    """
    from library.patch_dataset import process_image, reconstruct_from_patches

    if step is None:
        step = patch_size

    # Pick image
    pair = random.choice(pairs) if img_idx is None else pairs[img_idx]

    img_path = pair['image']
    # Works whether the pair stores a str or a Path.
    img_name = Path(img_path).name

    # Fail early, before any image I/O, when the requested mask is absent.
    available_masks = pair['masks']
    if mask_type not in available_masks:
        found = ', '.join(sorted(str(key) for key in available_masks))
        raise KeyError(f"no {mask_type} mask available for {img_name} (found: {found})")
    mask_path = available_masks[mask_type]

    # Load image and mask; cv2.imread signals failure by returning None.
    image = cv2.imread(str(img_path))
    if image is None:
        raise OSError(f"Could not read image: {img_path}")
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise OSError(f"Could not read mask: {mask_path}")

    # Detect ROI
    roi_bbox = detect_roi(image) if filter_roi else None

    overlap_pct = (1 - step / patch_size) * 100

    print(f"Testing: {img_name}")
    print(f"Original image shape: {image.shape}")
    print(f"Original mask shape: {mask.shape}")
    if roi_bbox:
        print(f"ROI bbox: {roi_bbox}")
    if preprocess_fns:
        # functools.partial objects and other callables may lack __name__.
        fn_names = [getattr(fn, '__name__', repr(fn)) for fn in preprocess_fns]
        print(f"Preprocessing: {fn_names}")
    print(f"Patch size: {patch_size}, Step: {step}, Overlap: {overlap_pct:.1f}%")

    # Create patches
    masks_dict = {mask_type: mask}
    result = create_patches_from_image(image, masks_dict, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)

    img_patches = result['image']
    mask_patches = result['masks'][mask_type]

    print(f"Patches shape: {img_patches.shape}")

    # Canvas size implied by the patch grid: n patches at stride `step`, plus
    # the part of the last patch that extends beyond one stride.
    n_rows, n_cols = img_patches.shape[0], img_patches.shape[1]
    padded_h = n_rows * step + (patch_size - step)
    padded_w = n_cols * step + (patch_size - step)

    print(f"Padded dimensions: {padded_h} x {padded_w}")

    # Get the image we should be reconstructing (cropped if ROI used)
    target_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
    target_mask = crop_to_roi(mask, roi_bbox) if roi_bbox else mask

    # Apply preprocessing to target for comparison
    if preprocess_fns:
        from library.patch_dataset import apply_preprocessing_pipeline
        target_image = apply_preprocessing_pipeline(target_image, preprocess_fns)

    # Process target for comparison (same padding applied during patching)
    processed_img = process_image(target_image, patch_size, scaling_factor, is_mask=False)
    processed_mask = process_image(target_mask, patch_size, scaling_factor, is_mask=True)

    # Reconstruct
    reconstructed_img = reconstruct_from_patches(
        img_patches, (padded_h, padded_w, 3), patch_size, step
    )
    reconstructed_mask = reconstruct_from_patches(
        mask_patches, (padded_h, padded_w, 1), patch_size, step
    )

    print(f"Reconstructed image shape: {reconstructed_img.shape}")
    print(f"Reconstructed mask shape: {reconstructed_mask.shape}")

    # np.array_equal is also False when the shapes differ.
    img_match = np.array_equal(reconstructed_img, processed_img)
    mask_match = np.array_equal(reconstructed_mask[:, :, 0], processed_mask)

    print(f"Image reconstruction matches: {img_match}")
    print(f"Mask reconstruction matches: {mask_match}")

    # Visualize
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Original with ROI
    axes[0, 0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if roi_bbox:
        (x1, y1), (x2, y2) = roi_bbox
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             linewidth=2, edgecolor='green', facecolor='none')
        axes[0, 0].add_patch(rect)
    axes[0, 0].set_title('Original Image' + (' + ROI' if roi_bbox else ''))
    axes[0, 0].axis('off')

    axes[1, 0].imshow(mask, cmap='gray')
    axes[1, 0].set_title('Original Mask')
    axes[1, 0].axis('off')

    # Padded/Processed (cropped + preprocessed if ROI)
    axes[0, 1].imshow(cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB))
    if roi_bbox and preprocess_fns:
        title = 'Cropped + Preprocessed + Padded'
    elif roi_bbox:
        title = 'Cropped + Padded'
    elif preprocess_fns:
        title = 'Preprocessed + Padded'
    else:
        title = 'Padded Image'
    axes[0, 1].set_title(title)
    axes[0, 1].axis('off')

    axes[1, 1].imshow(processed_mask, cmap='gray')
    axes[1, 1].set_title('Padded Mask')
    axes[1, 1].axis('off')

    # Reconstructed
    axes[0, 2].imshow(cv2.cvtColor(reconstructed_img, cv2.COLOR_BGR2RGB))
    axes[0, 2].set_title(f'Reconstructed (overlap {overlap_pct:.0f}%)')
    axes[0, 2].axis('off')

    axes[1, 2].imshow(reconstructed_mask[:, :, 0], cmap='gray')
    axes[1, 2].set_title('Reconstructed Mask')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    return img_match and mask_match

Test that patches can be reconstructed back into the original image.

Arguments:
  • pairs: List of image-mask pairs.
  • img_idx: Index of image to test. If None, picks random. Default is None.
  • mask_type: Which mask to test. Default is 'root'.
  • patch_size: Size of patches. Default is 128.
  • scaling_factor: Scaling factor. Default is 1.0.
  • step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
  • filter_roi: If True, crop to ROI before patching. Default is True.
  • preprocess_fns: Optional list of preprocessing functions to apply to image.
Returns:

Boolean indicating if reconstruction matches original.

def visualize_patch_reassembly( patch_dir, raw_dir, dataset_type, n_images=3, mask_types=['root', 'shoot', 'seed']):
def visualize_patch_reassembly(patch_dir, raw_dir, dataset_type, n_images=3,
                               mask_types=None):
    """Visualize original images alongside their reassembled patches and mask overlays.

    Picks random source images listed in the patch metadata, loads their saved
    patches from disk, stitches the patches back together (averaging where
    patches overlap), and shows three panels per image: the original image
    (with the ROI outlined in green when one is recorded), the reassembled
    image with yellow patch borders, and the reassembled image with each mask
    type painted in its own color.

    Args:
        patch_dir: Directory containing saved patches (e.g., '../../data/test_output/patched_test').
        raw_dir: Directory containing original images (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        n_images: Number of random images to visualize. Default is 3.
        mask_types: Iterable of mask types to overlay. If None (the default),
            ['root', 'shoot', 'seed'] is used. Mask types without a predefined
            color are drawn in white.

    Returns:
        None. Figures are shown with matplotlib and a summary is printed.

    Example:
        >>> visualize_patch_reassembly('../../data/test_output/patched_test',
        ...                            '../../data/dataset', 'train', n_images=3)
    """
    from library.patch_dataset import load_patch_metadata

    # Resolve the default here instead of using a mutable default argument.
    if mask_types is None:
        mask_types = ['root', 'shoot', 'seed']
    else:
        # Materialize once: the sequence is iterated several times below.
        mask_types = list(mask_types)

    # Define colors for each mask type (RGB)
    mask_color_map = {
        'root': [255, 0, 0],      # Red
        'shoot': [0, 255, 0],     # Green
        'seed': [0, 0, 255]       # Blue
    }
    # Used for any mask type that has no entry in mask_color_map, both in the
    # overlay and in the legend (the legend previously raised KeyError).
    fallback_color = [255, 255, 255]

    # Load metadata
    metadata = load_patch_metadata(patch_dir, dataset_type)
    dataset_info = metadata['dataset_info']
    patch_size = dataset_info['patch_size']
    step = dataset_info['step']
    filter_roi = dataset_info.get('filter_roi', False)
    preprocessing = dataset_info.get('preprocessing', None)

    # Patches overlap whenever the stride is smaller than the patch itself;
    # overlapping regions are averaged, otherwise patches are copied directly.
    overlapping = step < patch_size

    # Group patches by source image in a single pass over the metadata.
    patches_by_image = {}
    for patch_info in metadata['patches']:
        patches_by_image.setdefault(patch_info['source_image'], []).append(patch_info)

    # Select n random images
    source_images = list(patches_by_image)
    selected_images = random.sample(source_images, min(n_images, len(source_images)))

    for img_name in selected_images:
        print(f"Processing {img_name}...")

        # Load original image
        img_path = Path(raw_dir) / f'{dataset_type}_images' / img_name
        original_image = cv2.imread(str(img_path))

        if original_image is None:
            print(f"Warning: Could not load {img_path}")
            continue

        # Every key in patches_by_image has at least one patch by construction.
        image_patches = patches_by_image[img_name]
        first_patch = image_patches[0]

        # ROI bbox is read from the first patch (all patches of one image are
        # presumed to share it). A missing key and a stored None/empty value
        # are both treated as "no ROI".
        raw_bbox = first_patch.get('roi_bbox')
        roi_bbox = tuple(tuple(coord) for coord in raw_bbox) if raw_bbox else None

        # Determine grid size from metadata
        n_rows, n_cols = first_patch['grid_size']

        # Calculate padded dimensions
        padded_h = n_rows * step + (patch_size - step)
        padded_w = n_cols * step + (patch_size - step)

        # Accumulate in float32 when averaging is needed, otherwise write
        # uint8 pixels straight into the canvas.
        canvas_dtype = np.float32 if overlapping else np.uint8
        reassembled = np.zeros((padded_h, padded_w, 3), dtype=canvas_dtype)
        reassembled_masks = {
            mask_type: np.zeros((padded_h, padded_w), dtype=canvas_dtype)
            for mask_type in mask_types
        }
        if overlapping:
            # Per-pixel contribution counts; masks get their own counters
            # because a mask patch file may be missing for some patches.
            counts = np.zeros((padded_h, padded_w), dtype=np.float32)
            mask_counts = {
                mask_type: np.zeros((padded_h, padded_w), dtype=np.float32)
                for mask_type in mask_types
            }

        # Load and place each patch
        patch_coords = []
        for patch_info in image_patches:
            patch_filename = patch_info['patch_filename']

            # Load image patch
            patch_path = Path(patch_dir) / f'{dataset_type}_images' / dataset_type / patch_filename
            patch = cv2.imread(str(patch_path))

            if patch is None:
                # Unreadable/missing patch: leave its area black and move on.
                continue

            # Calculate position, clipped to the padded canvas
            y_start = patch_info['row_idx'] * step
            x_start = patch_info['col_idx'] * step
            y_end = min(y_start + patch_size, padded_h)
            x_end = min(x_start + patch_size, padded_w)

            patch_h = y_end - y_start
            patch_w = x_end - x_start

            # Place image patch
            if overlapping:
                reassembled[y_start:y_end, x_start:x_end] += patch[:patch_h, :patch_w]
                counts[y_start:y_end, x_start:x_end] += 1
            else:
                reassembled[y_start:y_end, x_start:x_end] = patch[:patch_h, :patch_w]

            # Load and place mask patches
            for mask_type in mask_types:
                mask_patch_path = (Path(patch_dir) / f'{dataset_type}_masks_{mask_type}'
                                   / dataset_type / patch_filename)
                if not mask_patch_path.exists():
                    continue
                mask_patch = cv2.imread(str(mask_patch_path), cv2.IMREAD_GRAYSCALE)
                if mask_patch is None:
                    continue
                mask_region = mask_patch[:patch_h, :patch_w]
                if overlapping:
                    reassembled_masks[mask_type][y_start:y_end, x_start:x_end] += mask_region
                    mask_counts[mask_type][y_start:y_end, x_start:x_end] += 1
                else:
                    reassembled_masks[mask_type][y_start:y_end, x_start:x_end] = mask_region

            # Store coordinates for drawing borders
            patch_coords.append((x_start, y_start, x_end, y_end))

        # Average overlapping regions if needed
        if overlapping:
            # Clamp counts to 1 so uncovered pixels divide cleanly (stay 0).
            counts = np.maximum(counts, 1)
            reassembled = (reassembled / counts[:, :, np.newaxis]).astype(np.uint8)

            for mask_type in mask_types:
                divisor = np.maximum(mask_counts[mask_type], 1)
                reassembled_masks[mask_type] = (
                    reassembled_masks[mask_type] / divisor
                ).astype(np.uint8)

        # Convert once; reused for the middle panel and as the overlay base.
        reassembled_rgb = cv2.cvtColor(reassembled, cv2.COLOR_BGR2RGB)

        # Paint each mask with its color; later mask types overwrite earlier
        # ones where they coincide.
        overlay = reassembled_rgb.copy()
        for mask_type in mask_types:
            color = mask_color_map.get(mask_type, fallback_color)
            overlay[reassembled_masks[mask_type] > 0] = color

        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Original image (with ROI if available)
        axes[0].imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        title = f'Original Image\n{img_name}'
        if roi_bbox:
            (x1, y1), (x2, y2) = roi_bbox
            rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                                 linewidth=2, edgecolor='green', facecolor='none')
            axes[0].add_patch(rect)
            title += '\n(green = ROI used for patching)'
        axes[0].set_title(title)
        axes[0].axis('off')

        # Reassembled image with patch borders
        axes[1].imshow(reassembled_rgb)

        # Draw yellow borders for each patch
        for x_start, y_start, x_end, y_end in patch_coords:
            rect = plt.Rectangle(
                (x_start, y_start),
                x_end - x_start,
                y_end - y_start,
                linewidth=1,
                edgecolor='yellow',
                facecolor='none'
            )
            axes[1].add_patch(rect)

        title = f'Reassembled from {len(patch_coords)} Patches\n'
        title += f'Patch size: {patch_size}, Step: {step} '
        title += f'(overlap: {(1-step/patch_size)*100:.0f}%)'
        if filter_roi:
            title += '\n(cropped to ROI before patching)'
        if preprocessing:
            title += f'\nPreprocessing: {", ".join(preprocessing)}'
        axes[1].set_title(title)
        axes[1].axis('off')

        # Mask overlay
        axes[2].imshow(overlay)

        # Legend text lists each mask type with the RGB color actually used.
        legend_text = [
            f'{mask_type.capitalize()}: {mask_color_map.get(mask_type, fallback_color)}'
            for mask_type in mask_types
        ]
        axes[2].set_title('Mask Overlay\n' + ', '.join(legend_text))
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        print(f"  Reassembled {len(patch_coords)} patches into {padded_h}x{padded_w} image")
        if roi_bbox:
            print(f"  Original: {original_image.shape[1]}x{original_image.shape[0]}, Cropped to ROI before patching")
        else:
            print(f"  Original: {original_image.shape[1]}x{original_image.shape[0]}")
        if preprocessing:
            print(f"  Preprocessing applied: {', '.join(preprocessing)}")
        print()

Visualize original images alongside their reassembled patches and mask overlays.

Loads random original images and their corresponding patches, reassembles the patches, and displays them side-by-side with yellow borders showing patch boundaries. Also shows mask overlays in different colors.

Arguments:
  • patch_dir: Directory containing saved patches (e.g., '../../data/test_output/patched_test').
  • raw_dir: Directory containing original images (e.g., '../../data/dataset').
  • dataset_type: Either 'train' or 'val'.
  • n_images: Number of random images to visualize. Default is 3.
  • mask_types: List of mask types to overlay. Default is ['root', 'shoot', 'seed'].
Example:
>>> visualize_patch_reassembly('../../data/test_output/patched_test', 
...                            '../../data/dataset', 'train', n_images=3)