library.dataset_visualization
Visualization and testing functions for patch datasets.
# library/dataset_visualization.py
"""Visualization and testing functions for patch datasets."""

import random
from collections import defaultdict
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

from library.patch_dataset import create_patches_from_image
from library.roi import crop_to_roi, detect_roi


def _load_image_and_mask(img_path, mask_path):
    """Read a BGR image and a grayscale mask, raising if either cannot be read.

    ``cv2.imread`` signals failure by returning ``None`` instead of raising,
    which would otherwise surface later as an opaque error during ROI
    detection or patching.

    Raises:
        FileNotFoundError: If either file is missing or unreadable.
    """
    image = cv2.imread(str(img_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")
    return image, mask


def verify_patches(pairs, patch_size=128, scaling_factor=1.0, step=None,
                   mask_type='root', filter_roi=True, preprocess_fns=None):
    """Display a random patch with original image (gridded), mask, and overlay for verification.

    A random pair is patched with ``create_patches_from_image`` and one patch
    whose mask has non-zero content is plotted in a 2x2 figure:
    - the (ROI-cropped, optionally preprocessed) source image with the patch
      grid and a star at the chosen patch's center;
    - the image patch;
    - its mask;
    - the mask overlaid in the mask type's color.

    If no patch has mask content, a message is printed and nothing is plotted.

    Args:
        pairs: List of image-mask pairs from get_image_mask_pairs(). Each entry
            is read as ``pair['image']`` and ``pair['masks'][mask_type]`` (paths).
        patch_size: Size of patches. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_type: Which mask to display ('root', 'shoot', or 'seed'). Default is 'root'.
            Any other type is drawn in yellow.
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to image.

    Returns:
        None

    Raises:
        FileNotFoundError: If the selected image or mask cannot be read.
    """
    if step is None:
        step = patch_size

    # Overlay color (RGB) plus the matching matplotlib color name per mask type.
    color_map = {
        'shoot': {'rgb': [0, 255, 0], 'name': 'green'},
        'seed': {'rgb': [0, 0, 255], 'name': 'blue'},
        'root': {'rgb': [255, 0, 0], 'name': 'red'},
    }
    mask_colors = color_map.get(mask_type, {'rgb': [255, 255, 0], 'name': 'yellow'})

    pair = random.choice(pairs)
    # Wrap in Path so plain-string entries also support ``.name`` below.
    img_path = Path(pair['image'])
    image, mask = _load_image_and_mask(img_path, pair['masks'][mask_type])

    roi_bbox = detect_roi(image) if filter_roi else None

    result = create_patches_from_image(image, {mask_type: mask}, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)

    # Mask patches are addressed as [row, col, 0, y, x, 0]; selecting the two
    # singleton axes leaves a (rows, cols, y, x) grid.
    mask_grid = result['masks'][mask_type][:, :, 0, :, :, 0]
    # Candidate patches: any non-zero mask pixel, listed in row-major order.
    has_content = mask_grid.sum(axis=(2, 3)) > 0
    valid_patches = [(int(r), int(c)) for r, c in np.argwhere(has_content)]
    if not valid_patches:
        print(f"No valid patches found for {img_path.name}")
        return

    row_idx, col_idx = random.choice(valid_patches)
    img_patch = result['image'][row_idx, col_idx, 0]
    mask_patch = mask_grid[row_idx, col_idx]

    # Patch center in the patch grid's pixel coordinates (ROI-cropped when an ROI is used).
    # NOTE(review): the grid and star are drawn on the image before any scaling;
    # if scaling_factor != 1.0 resizes the image during patching they will not
    # line up with the patches - confirm against create_patches_from_image.
    center_y = row_idx * step + patch_size // 2
    center_x = col_idx * step + patch_size // 2

    overlay = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB)
    overlay[mask_patch > 0] = mask_colors['rgb']

    # Show the same image that was patched: cropped to the ROI and preprocessed.
    display_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
    if preprocess_fns:
        from library.patch_dataset import apply_preprocessing_pipeline
        display_image = apply_preprocessing_pipeline(display_image, preprocess_fns)

    _, axes = plt.subplots(2, 2, figsize=(12, 12))

    grid_ax = axes[0, 0]
    grid_ax.imshow(cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB))
    for y in range(0, display_image.shape[0], step):
        grid_ax.axhline(y=y, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
    for x in range(0, display_image.shape[1], step):
        grid_ax.axvline(x=x, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
    grid_ax.plot(center_x, center_y, 'r*', markersize=15)
    title = 'Cropped to ROI' if roi_bbox else 'Original Image'
    if preprocess_fns:
        title += ' (preprocessed)'
    grid_ax.set_title(title)
    grid_ax.axis('off')

    axes[0, 1].imshow(cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB))
    axes[0, 1].set_title(f'Patch [{row_idx},{col_idx}]')
    axes[0, 1].axis('off')

    axes[1, 0].imshow(mask_patch, cmap='gray', vmin=0, vmax=1)
    axes[1, 0].set_title('Mask')
    axes[1, 0].axis('off')

    axes[1, 1].imshow(overlay)
    axes[1, 1].set_title('Overlay')
    axes[1, 1].axis('off')

    overlap_pct = (1 - step / patch_size) * 100
    title = f'{img_path.name} - {mask_type} (overlap {overlap_pct:.0f}%)'
    title += f' - {len(valid_patches)} valid patches'
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


def test_unpatchify(pairs, img_idx=None, mask_type='root', patch_size=128,
                    scaling_factor=1.0, step=None, filter_roi=True, preprocess_fns=None):
    """Test that patches can be reconstructed back into the original image.

    The process has three steps:
    1. Patch the image and one mask with ``create_patches_from_image``.
    2. Stitch them back with ``reconstruct_from_patches``.
    3. Compare the result exactly (``np.array_equal``) against the ROI-cropped,
       preprocessed target passed through ``process_image``.

    Shapes and match results are printed, and a 2x3 comparison figure is shown.

    Args:
        pairs: List of image-mask pairs.
        img_idx: Index of image to test. If None, picks random. Default is None.
        mask_type: Which mask to test. Default is 'root'.
        patch_size: Size of patches. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to image.

    Returns:
        Boolean indicating if reconstruction matches original (image AND mask).

    Raises:
        FileNotFoundError: If the selected image or mask cannot be read.
    """
    from library.patch_dataset import process_image, reconstruct_from_patches

    if step is None:
        step = patch_size
    overlap_pct = (1 - step / patch_size) * 100

    pair = random.choice(pairs) if img_idx is None else pairs[img_idx]
    img_path = Path(pair['image'])
    image, mask = _load_image_and_mask(img_path, pair['masks'][mask_type])

    roi_bbox = detect_roi(image) if filter_roi else None

    print(f"Testing: {img_path.name}")
    print(f"Original image shape: {image.shape}")
    print(f"Original mask shape: {mask.shape}")
    if roi_bbox:
        print(f"ROI bbox: {roi_bbox}")
    if preprocess_fns:
        print(f"Preprocessing: {[fn.__name__ for fn in preprocess_fns]}")
    print(f"Patch size: {patch_size}, Step: {step}, Overlap: {overlap_pct:.1f}%")

    result = create_patches_from_image(image, {mask_type: mask}, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)
    img_patches = result['image']
    mask_patches = result['masks'][mask_type]
    print(f"Patches shape: {img_patches.shape}")

    # Canvas size covered by an n_rows x n_cols grid of patches taken every `step` pixels.
    n_rows, n_cols = img_patches.shape[0], img_patches.shape[1]
    padded_h = n_rows * step + (patch_size - step)
    padded_w = n_cols * step + (patch_size - step)
    print(f"Padded dimensions: {padded_h} x {padded_w}")

    # Build the reference the patches should reproduce: cropped, preprocessed,
    # then padded exactly as during patching.
    target_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
    target_mask = crop_to_roi(mask, roi_bbox) if roi_bbox else mask
    if preprocess_fns:
        from library.patch_dataset import apply_preprocessing_pipeline
        target_image = apply_preprocessing_pipeline(target_image, preprocess_fns)
    processed_img = process_image(target_image, patch_size, scaling_factor, is_mask=False)
    processed_mask = process_image(target_mask, patch_size, scaling_factor, is_mask=True)

    reconstructed_img = reconstruct_from_patches(
        img_patches, (padded_h, padded_w, 3), patch_size, step
    )
    reconstructed_mask = reconstruct_from_patches(
        mask_patches, (padded_h, padded_w, 1), patch_size, step
    )
    print(f"Reconstructed image shape: {reconstructed_img.shape}")
    print(f"Reconstructed mask shape: {reconstructed_mask.shape}")

    img_match = np.array_equal(reconstructed_img, processed_img)
    mask_match = np.array_equal(reconstructed_mask[:, :, 0], processed_mask)
    print(f"Image reconstruction matches: {img_match}")
    print(f"Mask reconstruction matches: {mask_match}")

    _, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Column 0: original image (with ROI box) and mask.
    axes[0, 0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if roi_bbox:
        (x1, y1), (x2, y2) = roi_bbox
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             linewidth=2, edgecolor='green', facecolor='none')
        axes[0, 0].add_patch(rect)
    axes[0, 0].set_title('Original Image' + (' + ROI' if roi_bbox else ''))
    axes[0, 0].axis('off')

    axes[1, 0].imshow(mask, cmap='gray')
    axes[1, 0].set_title('Original Mask')
    axes[1, 0].axis('off')

    # Column 1: the processed reference.
    if roi_bbox and preprocess_fns:
        processed_title = 'Cropped + Preprocessed + Padded'
    elif roi_bbox:
        processed_title = 'Cropped + Padded'
    elif preprocess_fns:
        processed_title = 'Preprocessed + Padded'
    else:
        processed_title = 'Padded Image'
    axes[0, 1].imshow(cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB))
    axes[0, 1].set_title(processed_title)
    axes[0, 1].axis('off')

    axes[1, 1].imshow(processed_mask, cmap='gray')
    axes[1, 1].set_title('Padded Mask')
    axes[1, 1].axis('off')

    # Column 2: the reconstruction.
    axes[0, 2].imshow(cv2.cvtColor(reconstructed_img, cv2.COLOR_BGR2RGB))
    axes[0, 2].set_title(f'Reconstructed (overlap {overlap_pct:.0f}%)')
    axes[0, 2].axis('off')

    axes[1, 2].imshow(reconstructed_mask[:, :, 0], cmap='gray')
    axes[1, 2].set_title('Reconstructed Mask')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    return img_match and mask_match


def _reassemble_patches(patch_dir, dataset_type, image_patches, mask_types,
                        patch_size, step, canvas_shape):
    """Stitch one source image's saved patch files back onto a padded canvas.

    Each pixel becomes the mean of every patch covering it, truncated to uint8.
    When patches do not overlap, that is just direct placement. Pixels no patch
    covers stay 0, and patch files that cannot be read are skipped.

    Returns:
        Tuple ``(image, masks, coords)``:
        - ``image``: the BGR uint8 canvas;
        - ``masks``: a dict of uint8 masks keyed by mask type;
        - ``coords``: the ``(x_start, y_start, x_end, y_end)`` box of every
          image patch that was placed.
    """
    padded_h, padded_w = canvas_shape
    image_sum = np.zeros((padded_h, padded_w, 3), dtype=np.float32)
    image_count = np.zeros((padded_h, padded_w), dtype=np.float32)
    mask_sums = {t: np.zeros((padded_h, padded_w), dtype=np.float32) for t in mask_types}
    mask_counts = {t: np.zeros((padded_h, padded_w), dtype=np.float32) for t in mask_types}

    patch_root = Path(patch_dir)
    coords = []
    for patch_info in image_patches:
        filename = patch_info['patch_filename']
        patch = cv2.imread(str(patch_root / f'{dataset_type}_images' / dataset_type / filename))
        if patch is None:
            continue

        # Clip the patch footprint to the canvas.
        y_start = patch_info['row_idx'] * step
        x_start = patch_info['col_idx'] * step
        y_end = min(y_start + patch_size, padded_h)
        x_end = min(x_start + patch_size, padded_w)
        h, w = y_end - y_start, x_end - x_start

        image_sum[y_start:y_end, x_start:x_end] += patch[:h, :w]
        image_count[y_start:y_end, x_start:x_end] += 1

        for mask_type in mask_types:
            mask_path = patch_root / f'{dataset_type}_masks_{mask_type}' / dataset_type / filename
            if not mask_path.exists():
                continue
            mask_patch = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask_patch is None:
                continue
            mask_sums[mask_type][y_start:y_end, x_start:x_end] += mask_patch[:h, :w]
            mask_counts[mask_type][y_start:y_end, x_start:x_end] += 1

        coords.append((x_start, y_start, x_end, y_end))

    # max(count, 1) leaves uncovered pixels at 0 instead of dividing by zero.
    image = (image_sum / np.maximum(image_count, 1)[:, :, np.newaxis]).astype(np.uint8)
    masks = {
        t: (mask_sums[t] / np.maximum(mask_counts[t], 1)).astype(np.uint8)
        for t in mask_types
    }
    return image, masks, coords


def visualize_patch_reassembly(patch_dir, raw_dir, dataset_type, n_images=3,
                               mask_types=('root', 'shoot', 'seed')):
    """Visualize original images alongside their reassembled patches and mask overlays.

    Loads random original images and their corresponding patches, reassembles
    the patches, and displays them side-by-side with yellow borders showing
    patch boundaries. Also shows mask overlays in different colors.

    Args:
        patch_dir: Directory containing saved patches (e.g., '../../data/test_output/patched_test').
            Image patches are read from ``<patch_dir>/<dataset_type>_images/<dataset_type>/``.
            Mask patches are read from
            ``<patch_dir>/<dataset_type>_masks_<mask_type>/<dataset_type>/``.
        raw_dir: Directory containing original images (e.g., '../../data/dataset'),
            read from ``<raw_dir>/<dataset_type>_images/``.
        dataset_type: Either 'train' or 'val'.
        n_images: Number of random images to visualize (capped at the number of
            source images in the metadata). Default is 3.
        mask_types: Iterable of mask types to overlay. Default is ('root', 'shoot', 'seed').
            Root, shoot and seed are drawn red, green and blue; other types white.

    Example:
        >>> visualize_patch_reassembly('../../data/test_output/patched_test',
        ...                            '../../data/dataset', 'train', n_images=3)
    """
    from library.patch_dataset import load_patch_metadata

    # Materialise once; the mask types are iterated several times per image.
    mask_types = list(mask_types)

    # Overlay colors (RGB) per mask type; unknown types fall back to white.
    mask_color_map = {
        'root': [255, 0, 0],    # Red
        'shoot': [0, 255, 0],   # Green
        'seed': [0, 0, 255],    # Blue
    }
    default_color = [255, 255, 255]

    metadata = load_patch_metadata(patch_dir, dataset_type)
    dataset_info = metadata['dataset_info']
    patch_size = dataset_info['patch_size']
    step = dataset_info['step']
    filter_roi = dataset_info.get('filter_roi', False)
    preprocessing = dataset_info.get('preprocessing', None)

    # Group patch records by source image once instead of rescanning per image.
    patches_by_image = defaultdict(list)
    for patch_info in metadata['patches']:
        patches_by_image[patch_info['source_image']].append(patch_info)

    # Sorted (not set-ordered) so random.seed() makes the selection reproducible:
    # set iteration order of str values depends on hash randomisation.
    source_images = sorted(patches_by_image)
    selected_images = random.sample(source_images, min(n_images, len(source_images)))

    for img_name in selected_images:
        print(f"Processing {img_name}...")

        img_path = Path(raw_dir) / f'{dataset_type}_images' / img_name
        original_image = cv2.imread(str(img_path))
        if original_image is None:
            print(f"Warning: Could not load {img_path}")
            continue

        image_patches = patches_by_image[img_name]
        if not image_patches:
            print(f"Warning: No patches found for {img_name}")
            continue

        # All patches from the same image share its ROI and grid size.
        first_patch = image_patches[0]
        roi_bbox = None
        if 'roi_bbox' in first_patch:
            roi_bbox = tuple(tuple(coord) for coord in first_patch['roi_bbox'])
        n_rows, n_cols = first_patch['grid_size']

        padded_h = n_rows * step + (patch_size - step)
        padded_w = n_cols * step + (patch_size - step)

        reassembled, reassembled_masks, patch_coords = _reassemble_patches(
            patch_dir, dataset_type, image_patches, mask_types,
            patch_size, step, (padded_h, padded_w))

        overlay = cv2.cvtColor(reassembled, cv2.COLOR_BGR2RGB)
        for mask_type in mask_types:
            overlay[reassembled_masks[mask_type] > 0] = mask_color_map.get(mask_type, default_color)

        _, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Panel 1: original image, with the ROI that was used for patching.
        axes[0].imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        title = f'Original Image\n{img_name}'
        if roi_bbox:
            (x1, y1), (x2, y2) = roi_bbox
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='green', facecolor='none')
            axes[0].add_patch(rect)
            title += '\n(green = ROI used for patching)'
        axes[0].set_title(title)
        axes[0].axis('off')

        # Panel 2: reassembled image with a yellow border around each patch.
        axes[1].imshow(cv2.cvtColor(reassembled, cv2.COLOR_BGR2RGB))
        for x_start, y_start, x_end, y_end in patch_coords:
            rect = plt.Rectangle(
                (x_start, y_start),
                x_end - x_start,
                y_end - y_start,
                linewidth=1,
                edgecolor='yellow',
                facecolor='none'
            )
            axes[1].add_patch(rect)

        title = f'Reassembled from {len(patch_coords)} Patches\n'
        title += f'Patch size: {patch_size}, Step: {step} '
        title += f'(overlap: {(1 - step / patch_size) * 100:.0f}%)'
        if filter_roi:
            title += '\n(cropped to ROI before patching)'
        if preprocessing:
            title += f'\nPreprocessing: {", ".join(preprocessing)}'
        axes[1].set_title(title)
        axes[1].axis('off')

        # Panel 3: mask overlay; legend uses the same color lookup as the overlay.
        axes[2].imshow(overlay)
        legend_text = [
            f'{mask_type.capitalize()}: {mask_color_map.get(mask_type, default_color)}'
            for mask_type in mask_types
        ]
        axes[2].set_title('Mask Overlay\n' + ', '.join(legend_text))
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        print(f" Reassembled {len(patch_coords)} patches into {padded_h}x{padded_w} image")
        size_msg = f" Original: {original_image.shape[1]}x{original_image.shape[0]}"
        if roi_bbox:
            size_msg += ", Cropped to ROI before patching"
        print(size_msg)
        if preprocessing:
            print(f" Preprocessing applied: {', '.join(preprocessing)}")
        print()
def verify_patches(pairs, patch_size=128, scaling_factor=1.0, step=None, mask_type='root', filter_roi=True, preprocess_fns=None):
def verify_patches(pairs, patch_size=128, scaling_factor=1.0, step=None,
                   mask_type='root', filter_roi=True, preprocess_fns=None):
    """Display a random patch with original image (gridded), mask, and overlay for verification.

    A random pair is patched with ``create_patches_from_image`` and one patch
    whose mask has non-zero content is plotted in a 2x2 figure:
    - the (ROI-cropped, optionally preprocessed) source image with the patch
      grid and a star at the chosen patch's center;
    - the image patch;
    - its mask;
    - the mask overlaid in the mask type's color.

    If no patch has mask content, a message is printed and nothing is plotted.

    Args:
        pairs: List of image-mask pairs from get_image_mask_pairs(). Each entry
            is read as ``pair['image']`` and ``pair['masks'][mask_type]`` (paths).
        patch_size: Size of patches. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_type: Which mask to display ('root', 'shoot', or 'seed'). Default is 'root'.
            Any other type is drawn in yellow.
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to image.

    Returns:
        None

    Raises:
        FileNotFoundError: If the selected image or mask cannot be read.
    """
    if step is None:
        step = patch_size

    # Overlay color (RGB) plus the matching matplotlib color name per mask type.
    color_map = {
        'shoot': {'rgb': [0, 255, 0], 'name': 'green'},
        'seed': {'rgb': [0, 0, 255], 'name': 'blue'},
        'root': {'rgb': [255, 0, 0], 'name': 'red'},
    }
    mask_colors = color_map.get(mask_type, {'rgb': [255, 255, 0], 'name': 'yellow'})

    pair = random.choice(pairs)
    # Wrap in Path so plain-string entries also support ``.name`` below.
    img_path = Path(pair['image'])

    # cv2.imread returns None (it does not raise) for missing/unreadable files.
    image = cv2.imread(str(img_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    mask_path = pair['masks'][mask_type]
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    roi_bbox = detect_roi(image) if filter_roi else None

    result = create_patches_from_image(image, {mask_type: mask}, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)

    # Mask patches are addressed as [row, col, 0, y, x, 0]; selecting the two
    # singleton axes leaves a (rows, cols, y, x) grid.
    mask_grid = result['masks'][mask_type][:, :, 0, :, :, 0]
    # Candidate patches: any non-zero mask pixel, listed in row-major order.
    has_content = mask_grid.sum(axis=(2, 3)) > 0
    valid_patches = [(int(r), int(c)) for r, c in np.argwhere(has_content)]
    if not valid_patches:
        print(f"No valid patches found for {img_path.name}")
        return

    row_idx, col_idx = random.choice(valid_patches)
    img_patch = result['image'][row_idx, col_idx, 0]
    mask_patch = mask_grid[row_idx, col_idx]

    # Patch center in the patch grid's pixel coordinates (ROI-cropped when an ROI is used).
    # NOTE(review): the grid and star are drawn on the image before any scaling;
    # if scaling_factor != 1.0 resizes the image during patching they will not
    # line up with the patches - confirm against create_patches_from_image.
    center_y = row_idx * step + patch_size // 2
    center_x = col_idx * step + patch_size // 2

    overlay = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB)
    overlay[mask_patch > 0] = mask_colors['rgb']

    # Show the same image that was patched: cropped to the ROI and preprocessed.
    display_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
    if preprocess_fns:
        from library.patch_dataset import apply_preprocessing_pipeline
        display_image = apply_preprocessing_pipeline(display_image, preprocess_fns)

    _, axes = plt.subplots(2, 2, figsize=(12, 12))

    grid_ax = axes[0, 0]
    grid_ax.imshow(cv2.cvtColor(display_image, cv2.COLOR_BGR2RGB))
    for y in range(0, display_image.shape[0], step):
        grid_ax.axhline(y=y, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
    for x in range(0, display_image.shape[1], step):
        grid_ax.axvline(x=x, color=mask_colors['name'], linewidth=0.5, alpha=0.3)
    grid_ax.plot(center_x, center_y, 'r*', markersize=15)
    title = 'Cropped to ROI' if roi_bbox else 'Original Image'
    if preprocess_fns:
        title += ' (preprocessed)'
    grid_ax.set_title(title)
    grid_ax.axis('off')

    axes[0, 1].imshow(cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB))
    axes[0, 1].set_title(f'Patch [{row_idx},{col_idx}]')
    axes[0, 1].axis('off')

    axes[1, 0].imshow(mask_patch, cmap='gray', vmin=0, vmax=1)
    axes[1, 0].set_title('Mask')
    axes[1, 0].axis('off')

    axes[1, 1].imshow(overlay)
    axes[1, 1].set_title('Overlay')
    axes[1, 1].axis('off')

    overlap_pct = (1 - step / patch_size) * 100
    title = f'{img_path.name} - {mask_type} (overlap {overlap_pct:.0f}%)'
    title += f' - {len(valid_patches)} valid patches'
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
Display a random patch with original image (gridded), mask, and overlay for verification.
Arguments:
- pairs: List of image-mask pairs from get_image_mask_pairs().
- patch_size: Size of patches. Default is 128.
- scaling_factor: Scaling factor. Default is 1.0.
- step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
- mask_type: Which mask to display ('root', 'shoot', or 'seed'). Default is 'root'.
- filter_roi: If True, crop to ROI before patching. Default is True.
- preprocess_fns: Optional list of preprocessing functions to apply to image.
Returns:
None
def test_unpatchify(pairs, img_idx=None, mask_type='root', patch_size=128, scaling_factor=1.0, step=None, filter_roi=True, preprocess_fns=None):
def test_unpatchify(pairs, img_idx=None, mask_type='root', patch_size=128,
                    scaling_factor=1.0, step=None, filter_roi=True, preprocess_fns=None):
    """Test that patches can be reconstructed back into the original image.

    The process has three steps:
    1. Patch the image and one mask with ``create_patches_from_image``.
    2. Stitch them back with ``reconstruct_from_patches``.
    3. Compare the result exactly (``np.array_equal``) against the ROI-cropped,
       preprocessed target passed through ``process_image``.

    Shapes and match results are printed, and a 2x3 comparison figure is shown.

    Args:
        pairs: List of image-mask pairs.
        img_idx: Index of image to test. If None, picks random. Default is None.
        mask_type: Which mask to test. Default is 'root'.
        patch_size: Size of patches. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to image.

    Returns:
        Boolean indicating if reconstruction matches original (image AND mask).

    Raises:
        FileNotFoundError: If the selected image or mask cannot be read.
    """
    from library.patch_dataset import process_image, reconstruct_from_patches

    if step is None:
        step = patch_size
    overlap_pct = (1 - step / patch_size) * 100

    pair = random.choice(pairs) if img_idx is None else pairs[img_idx]
    # Wrap in Path so plain-string entries also support ``.name`` below.
    img_path = Path(pair['image'])

    # cv2.imread returns None (it does not raise) for missing/unreadable files.
    image = cv2.imread(str(img_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    mask_path = pair['masks'][mask_type]
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    roi_bbox = detect_roi(image) if filter_roi else None

    print(f"Testing: {img_path.name}")
    print(f"Original image shape: {image.shape}")
    print(f"Original mask shape: {mask.shape}")
    if roi_bbox:
        print(f"ROI bbox: {roi_bbox}")
    if preprocess_fns:
        print(f"Preprocessing: {[fn.__name__ for fn in preprocess_fns]}")
    print(f"Patch size: {patch_size}, Step: {step}, Overlap: {overlap_pct:.1f}%")

    result = create_patches_from_image(image, {mask_type: mask}, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)
    img_patches = result['image']
    mask_patches = result['masks'][mask_type]
    print(f"Patches shape: {img_patches.shape}")

    # Canvas size covered by an n_rows x n_cols grid of patches taken every `step` pixels.
    n_rows, n_cols = img_patches.shape[0], img_patches.shape[1]
    padded_h = n_rows * step + (patch_size - step)
    padded_w = n_cols * step + (patch_size - step)
    print(f"Padded dimensions: {padded_h} x {padded_w}")

    # Build the reference the patches should reproduce: cropped, preprocessed,
    # then padded exactly as during patching.
    target_image = crop_to_roi(image, roi_bbox) if roi_bbox else image
    target_mask = crop_to_roi(mask, roi_bbox) if roi_bbox else mask
    if preprocess_fns:
        from library.patch_dataset import apply_preprocessing_pipeline
        target_image = apply_preprocessing_pipeline(target_image, preprocess_fns)
    processed_img = process_image(target_image, patch_size, scaling_factor, is_mask=False)
    processed_mask = process_image(target_mask, patch_size, scaling_factor, is_mask=True)

    reconstructed_img = reconstruct_from_patches(
        img_patches, (padded_h, padded_w, 3), patch_size, step
    )
    reconstructed_mask = reconstruct_from_patches(
        mask_patches, (padded_h, padded_w, 1), patch_size, step
    )
    print(f"Reconstructed image shape: {reconstructed_img.shape}")
    print(f"Reconstructed mask shape: {reconstructed_mask.shape}")

    img_match = np.array_equal(reconstructed_img, processed_img)
    mask_match = np.array_equal(reconstructed_mask[:, :, 0], processed_mask)
    print(f"Image reconstruction matches: {img_match}")
    print(f"Mask reconstruction matches: {mask_match}")

    _, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Column 0: original image (with ROI box) and mask.
    axes[0, 0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    if roi_bbox:
        (x1, y1), (x2, y2) = roi_bbox
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             linewidth=2, edgecolor='green', facecolor='none')
        axes[0, 0].add_patch(rect)
    axes[0, 0].set_title('Original Image' + (' + ROI' if roi_bbox else ''))
    axes[0, 0].axis('off')

    axes[1, 0].imshow(mask, cmap='gray')
    axes[1, 0].set_title('Original Mask')
    axes[1, 0].axis('off')

    # Column 1: the processed reference.
    if roi_bbox and preprocess_fns:
        processed_title = 'Cropped + Preprocessed + Padded'
    elif roi_bbox:
        processed_title = 'Cropped + Padded'
    elif preprocess_fns:
        processed_title = 'Preprocessed + Padded'
    else:
        processed_title = 'Padded Image'
    axes[0, 1].imshow(cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB))
    axes[0, 1].set_title(processed_title)
    axes[0, 1].axis('off')

    axes[1, 1].imshow(processed_mask, cmap='gray')
    axes[1, 1].set_title('Padded Mask')
    axes[1, 1].axis('off')

    # Column 2: the reconstruction.
    axes[0, 2].imshow(cv2.cvtColor(reconstructed_img, cv2.COLOR_BGR2RGB))
    axes[0, 2].set_title(f'Reconstructed (overlap {overlap_pct:.0f}%)')
    axes[0, 2].axis('off')

    axes[1, 2].imshow(reconstructed_mask[:, :, 0], cmap='gray')
    axes[1, 2].set_title('Reconstructed Mask')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    return img_match and mask_match
Test that patches can be reconstructed back into the original image.
Arguments:
- pairs: List of image-mask pairs.
- img_idx: Index of image to test. If None, picks random. Default is None.
- mask_type: Which mask to test. Default is 'root'.
- patch_size: Size of patches. Default is 128.
- scaling_factor: Scaling factor. Default is 1.0.
- step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
- filter_roi: If True, crop to ROI before patching. Default is True.
- preprocess_fns: Optional list of preprocessing functions to apply to image.
Returns:
Boolean indicating if reconstruction matches original.
def visualize_patch_reassembly(patch_dir, raw_dir, dataset_type, n_images=3, mask_types=['root', 'shoot', 'seed']):
def _average_to_uint8(total, hits):
    """Divide accumulated pixel sums by per-pixel hit counts and round to uint8.

    Pixels no patch covered (hits == 0) are divided by 1, so they stay 0.
    Rounding (instead of truncating) avoids a systematic downward bias in
    averaged overlap regions.
    """
    return np.rint(total / np.maximum(hits, 1)).astype(np.uint8)


def _reassemble_patch_grid(patch_dir, dataset_type, image_patches, n_rows, n_cols,
                           patch_size, step, mask_types):
    """Load one source image's saved patches and stitch them back onto a full canvas.

    Args:
        patch_dir: Root directory of the saved patch dataset.
        dataset_type: Split name used in the patch directory layout (e.g. 'train').
        image_patches: Metadata records for every patch of this source image;
            each record must provide 'patch_filename', 'row_idx' and 'col_idx'.
        n_rows: Number of patch rows recorded in the metadata grid.
        n_cols: Number of patch columns recorded in the metadata grid.
        patch_size: Side length of a patch in pixels.
        step: Stride between neighbouring patches in pixels.
        mask_types: Mask types whose patches are reassembled alongside the image.

    Returns:
        Tuple ``(image, masks, patch_coords, (height, width))``: ``image`` is the
        reassembled BGR uint8 array, ``masks`` maps each mask type to a uint8
        array, and ``patch_coords`` lists ``(x_start, y_start, x_end, y_end)`` for
        every patch that could be loaded.
    """
    padded_h = n_rows * step + (patch_size - step)
    padded_w = n_cols * step + (patch_size - step)

    # Accumulate sums plus per-pixel hit counts so overlapping patches get
    # averaged. Without overlap every pixel is hit at most once, so the
    # "average" is exactly the placed value.
    image_sum = np.zeros((padded_h, padded_w, 3), dtype=np.float32)
    image_hits = np.zeros((padded_h, padded_w, 1), dtype=np.float32)
    mask_sums = {m: np.zeros((padded_h, padded_w), dtype=np.float32) for m in mask_types}
    mask_hits = {m: np.zeros((padded_h, padded_w), dtype=np.float32) for m in mask_types}

    patch_coords = []
    for patch_info in image_patches:
        patch_filename = patch_info['patch_filename']
        patch_path = Path(patch_dir) / f'{dataset_type}_images' / dataset_type / patch_filename
        patch = cv2.imread(str(patch_path))
        if patch is None:
            continue

        y_start = patch_info['row_idx'] * step
        x_start = patch_info['col_idx'] * step
        # Clip to the canvas AND to the patch's real size so a short/truncated
        # patch file cannot cause a shape-mismatch error.
        h = min(patch_size, patch.shape[0], padded_h - y_start)
        w = min(patch_size, patch.shape[1], padded_w - x_start)
        if h <= 0 or w <= 0:
            continue

        image_sum[y_start:y_start + h, x_start:x_start + w] += patch[:h, :w]
        image_hits[y_start:y_start + h, x_start:x_start + w] += 1

        for mask_type in mask_types:
            mask_patch_path = (Path(patch_dir) / f'{dataset_type}_masks_{mask_type}'
                               / dataset_type / patch_filename)
            if not mask_patch_path.exists():
                continue
            mask_patch = cv2.imread(str(mask_patch_path), cv2.IMREAD_GRAYSCALE)
            if mask_patch is None:
                continue
            mh = min(h, mask_patch.shape[0])
            mw = min(w, mask_patch.shape[1])
            mask_sums[mask_type][y_start:y_start + mh, x_start:x_start + mw] += mask_patch[:mh, :mw]
            mask_hits[mask_type][y_start:y_start + mh, x_start:x_start + mw] += 1

        # Stored for drawing the patch borders later.
        patch_coords.append((x_start, y_start, x_start + w, y_start + h))

    reassembled = _average_to_uint8(image_sum, image_hits)
    reassembled_masks = {m: _average_to_uint8(mask_sums[m], mask_hits[m]) for m in mask_types}
    return reassembled, reassembled_masks, patch_coords, (padded_h, padded_w)


def visualize_patch_reassembly(patch_dir, raw_dir, dataset_type, n_images=3,
                               mask_types=None):
    """Visualize original images alongside their reassembled patches and mask overlays.

    Loads random original images and their corresponding patches, reassembles
    the patches, and displays them side-by-side with yellow borders showing
    patch boundaries. Also shows mask overlays in different colors. Where
    patches overlap (step < patch_size) their pixel values are averaged.

    Args:
        patch_dir: Directory containing saved patches (e.g., '../../data/test_output/patched_test').
        raw_dir: Directory containing original images (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        n_images: Number of random images to visualize. Default is 3. The
            selection respects ``random.seed()`` for reproducibility.
        mask_types: Mask type name or iterable of mask types to overlay.
            Default (None) is ['root', 'shoot', 'seed']. Types without a
            predefined color are drawn in white.

    Returns:
        None. One matplotlib figure per selected image is shown and a short
        summary is printed.

    Example:
        >>> visualize_patch_reassembly('../../data/test_output/patched_test',
        ...                            '../../data/dataset', 'train', n_images=3)
    """
    from library.patch_dataset import load_patch_metadata

    # Avoid a shared mutable default; accept a single name as well as any iterable.
    if mask_types is None:
        mask_types = ['root', 'shoot', 'seed']
    elif isinstance(mask_types, str):
        mask_types = [mask_types]
    else:
        mask_types = list(mask_types)

    # Define colors for each mask type (RGB); unknown types fall back to white.
    mask_color_map = {
        'root': [255, 0, 0],    # Red
        'shoot': [0, 255, 0],   # Green
        'seed': [0, 0, 255],    # Blue
    }
    default_color = [255, 255, 255]

    # Load metadata
    metadata = load_patch_metadata(patch_dir, dataset_type)
    dataset_info = metadata['dataset_info']
    patch_size = dataset_info['patch_size']
    step = dataset_info['step']
    filter_roi = dataset_info.get('filter_roi', False)
    preprocessing = dataset_info.get('preprocessing', None)

    # Group patch records by source image once (preserving metadata order)
    # instead of rescanning the full patch list for every selected image.
    patches_by_image = {}
    for patch_info in metadata['patches']:
        patches_by_image.setdefault(patch_info['source_image'], []).append(patch_info)

    # Sorted so random.sample is reproducible under random.seed(): set/dict
    # order of str keys is not a stable basis for sampling across runs.
    source_images = sorted(patches_by_image)
    n_select = max(0, min(n_images, len(source_images)))
    selected_images = random.sample(source_images, n_select)

    for img_name in selected_images:
        print(f"Processing {img_name}...")

        # Load original image
        img_path = Path(raw_dir) / f'{dataset_type}_images' / img_name
        original_image = cv2.imread(str(img_path))
        if original_image is None:
            print(f"Warning: Could not load {img_path}")
            continue

        image_patches = patches_by_image[img_name]

        # All patches from the same image share the same ROI; take it from the first.
        roi_bbox = None
        if 'roi_bbox' in image_patches[0]:
            roi_bbox = tuple(tuple(coord) for coord in image_patches[0]['roi_bbox'])

        n_rows, n_cols = image_patches[0]['grid_size']

        reassembled, reassembled_masks, patch_coords, (padded_h, padded_w) = _reassemble_patch_grid(
            patch_dir, dataset_type, image_patches, n_rows, n_cols,
            patch_size, step, mask_types)

        # Mask overlay on the reassembled image (RGB for matplotlib).
        reassembled_rgb = cv2.cvtColor(reassembled, cv2.COLOR_BGR2RGB)
        overlay = reassembled_rgb.copy()
        for mask_type in mask_types:
            overlay[reassembled_masks[mask_type] > 0] = mask_color_map.get(mask_type, default_color)

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Panel 1: original image, with the ROI rectangle if one was used.
        axes[0].imshow(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB))
        title = f'Original Image\n{img_name}'
        if roi_bbox:
            (x1, y1), (x2, y2) = roi_bbox
            axes[0].add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                            linewidth=2, edgecolor='green', facecolor='none'))
            title += '\n(green = ROI used for patching)'
        axes[0].set_title(title)
        axes[0].axis('off')

        # Panel 2: reassembled image with a yellow border around every patch.
        axes[1].imshow(reassembled_rgb)
        for x_start, y_start, x_end, y_end in patch_coords:
            axes[1].add_patch(plt.Rectangle(
                (x_start, y_start),
                x_end - x_start,
                y_end - y_start,
                linewidth=1,
                edgecolor='yellow',
                facecolor='none',
            ))

        # Clamp at 0: with step > patch_size there are gaps, not negative overlap.
        overlap_pct = max(0.0, 1 - step / patch_size) * 100
        title = f'Reassembled from {len(patch_coords)} Patches\n'
        title += f'Patch size: {patch_size}, Step: {step} '
        title += f'(overlap: {overlap_pct:.0f}%)'
        if filter_roi:
            title += '\n(cropped to ROI before patching)'
        if preprocessing:
            title += f'\nPreprocessing: {", ".join(preprocessing)}'
        axes[1].set_title(title)
        axes[1].axis('off')

        # Panel 3: mask overlay, with a textual legend of the colors used.
        axes[2].imshow(overlay)
        legend_text = [
            f'{mask_type.capitalize()}: {mask_color_map.get(mask_type, default_color)}'
            for mask_type in mask_types
        ]
        axes[2].set_title('Mask Overlay\n' + ', '.join(legend_text))
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        # Report sizes consistently as width x height.
        print(f"  Reassembled {len(patch_coords)} patches into {padded_w}x{padded_h} image")
        if roi_bbox:
            print(f"  Original: {original_image.shape[1]}x{original_image.shape[0]}, Cropped to ROI before patching")
        else:
            print(f"  Original: {original_image.shape[1]}x{original_image.shape[0]}")
        if preprocessing:
            print(f"  Preprocessing applied: {', '.join(preprocessing)}")
        print()
Visualize original images alongside their reassembled patches and mask overlays.
Loads random original images and their corresponding patches, reassembles the patches, and displays them side-by-side with yellow borders showing patch boundaries. Also shows mask overlays in different colors.
Arguments:
- patch_dir: Directory containing saved patches (e.g., '../../data/test_output/patched_test').
- raw_dir: Directory containing original images (e.g., '../../data/dataset').
- dataset_type: Either 'train' or 'val'.
- n_images: Number of random images to visualize. Default is 3.
- mask_types: List of mask types to overlay. Default is ['root', 'shoot', 'seed'].
Example:
>>> visualize_patch_reassembly('../../data/test_output/patched_test',
...                            '../../data/dataset', 'train', n_images=3)