library.patch_dataset
Functions for creating and managing patch-based datasets from images and masks.
# library/patch_dataset.py
"""Functions for creating and managing patch-based datasets from images and masks."""

import os
import shutil
import json
import glob
from datetime import datetime, timezone
from pathlib import Path

import cv2
import numpy as np
from patchify import patchify, unpatchify
from tqdm import tqdm

from library.roi import detect_roi, crop_to_roi

METADATA_VERSION = "1.0"


def padder(image, patch_size, is_mask=False):
    """Add padding to an image to make its dimensions divisible by patch size.

    Padding is applied evenly to both sides of each dimension. If the total
    padding for a dimension is odd, the extra pixel is added to the bottom or
    right side. The image is returned unchanged when both dimensions are
    already divisible by patch_size.

    Parameters:
        image: Input image as numpy array with shape (height, width[, channels]).
        patch_size: The patch size to which image dimensions should be divisible.
        is_mask: Currently unused; padding is always constant black (0)
            regardless of this flag. Kept for interface compatibility.

    Returns:
        Padded image as numpy array with dimensions divisible by patch_size.

    Example:
        >>> padded_image = padder(cv2.imread('example.jpg'), 128)
    """
    h, w = image.shape[:2]

    # Amount needed to round each dimension up to the next multiple of patch_size
    height_padding = 0 if h % patch_size == 0 else ((h // patch_size) + 1) * patch_size - h
    width_padding = 0 if w % patch_size == 0 else ((w // patch_size) + 1) * patch_size - w

    # Early return if no padding needed
    if height_padding == 0 and width_padding == 0:
        return image

    # Split padding evenly; odd remainder goes to bottom/right
    top_padding = height_padding // 2
    bottom_padding = height_padding - top_padding
    left_padding = width_padding // 2
    right_padding = width_padding - left_padding

    return cv2.copyMakeBorder(
        image,
        top_padding, bottom_padding,
        left_padding, right_padding,
        cv2.BORDER_CONSTANT,
        value=0
    )


def unpadder(padded_image, roi_box):
    """Remove padding from an image to restore original ROI dimensions.

    Reverses the padding applied by padder(): the padding is assumed to be
    split evenly between both sides of each dimension, with any odd remainder
    on the bottom/right side.

    Parameters:
        padded_image: Padded image as numpy array with shape (height, width[, channels]).
        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.

    Returns:
        Cropped image as numpy array with original ROI dimensions.

    Example:
        >>> roi_box = ((776, 70), (3519, 2813))
        >>> original = unpadder(padded_image, roi_box)
    """
    h, w = padded_image.shape[:2]

    # Original ROI dimensions
    roi_height = roi_box[1][1] - roi_box[0][1]
    roi_width = roi_box[1][0] - roi_box[0][0]

    # Total padding in each dimension
    total_height_px = h - roi_height
    total_width_px = w - roi_width

    # Early return if no padding to remove
    if total_height_px == 0 and total_width_px == 0:
        return padded_image

    # Split padding evenly (matches padder: extra pixel on bottom/right)
    left_padding = total_width_px // 2
    right_padding = total_width_px - left_padding
    top_padding = total_height_px // 2
    bottom_padding = total_height_px - top_padding

    # BUGFIX: use explicit end indices instead of negative ones. A negative
    # end index of -0 (e.g. padded_image[0:-0]) produces an EMPTY slice when
    # only one dimension was padded. Slicing the first two axes works for
    # both 2-D masks and 3-D images.
    return padded_image[top_padding:h - bottom_padding, left_padding:w - right_padding]


def restore_mask_to_original(padded_mask, original_image_shape, roi_box):
    """Restore a padded mask to match the original image dimensions.

    Removes padding from a mask and places it at the correct position in a
    full-size mask matching the original image dimensions.

    Parameters:
        padded_mask: Padded binary mask as numpy array with shape (height, width).
        original_image_shape: Tuple of (height, width) or (height, width, channels)
            from the original image.
        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.

    Returns:
        Binary mask as numpy array with shape matching original_image_shape[:2],
        with mask values of 0 and 255.

    Example:
        >>> full_mask = restore_mask_to_original(predicted_mask, image.shape, roi_box)
        >>> cv2.imwrite('output.png', full_mask)
    """
    # Remove the divisibility padding so the mask matches the ROI size again
    unpadded_mask = unpadder(padded_mask, roi_box)

    # Ensure mask is strictly binary (0 and 255)
    binary_mask = (unpadded_mask > 0).astype(np.uint8) * 255

    # Place the ROI mask at its position inside a full-size black canvas
    full_mask = np.zeros(original_image_shape[:2], dtype=np.uint8)
    (x1, y1), (x2, y2) = roi_box
    full_mask[y1:y2, x1:x2] = binary_mask

    return full_mask


def apply_preprocessing_pipeline(image, preprocess_fns):
    """Apply a list of preprocessing functions sequentially to an image.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        preprocess_fns: List of callables, each taking an image and returning
            the processed image. Can be None or an empty list.

    Returns:
        Preprocessed image (a copy); the original image when preprocess_fns
        is None or empty.
    """
    if preprocess_fns is None or len(preprocess_fns) == 0:
        return image

    # Work on a copy so the caller's array is never mutated in place
    processed = image.copy()
    for fn in preprocess_fns:
        processed = fn(processed)

    return processed


def process_image(image, patch_size, scaling_factor, is_mask=False):
    """Scale an image (when requested) and pad it to be patch-divisible.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        patch_size: Target patch size.
        scaling_factor: Scaling factor (<=1.0); 1.0 means no resize.
        is_mask: Forwarded to padder (currently unused there).

    Returns:
        Processed image.
    """
    # Scale only if needed
    if scaling_factor != 1.0:
        image = cv2.resize(image, (0, 0), fx=scaling_factor, fy=scaling_factor)

    # Pad so both dimensions are divisible by patch_size
    return padder(image, patch_size, is_mask=is_mask)


def create_patch_directories(output_dir, dataset_type, mask_types=('root', 'shoot', 'seed')):
    """Create directory structure needed for patch datasets.

    Args:
        output_dir: Base output directory (e.g., 'data_patched').
        dataset_type: Either 'train' or 'val'.
        mask_types: Iterable of mask types to create directories for.

    Returns:
        Dictionary of created paths with keys 'images' and 'masks_{type}'.
    """
    paths = {}

    # Images directory (nested dataset_type subfolder keeps Keras-style layout)
    img_dir = Path(output_dir) / f'{dataset_type}_images' / dataset_type
    img_dir.mkdir(parents=True, exist_ok=True)
    paths['images'] = img_dir

    # One directory per mask type
    for mask_type in mask_types:
        mask_dir = Path(output_dir) / f'{dataset_type}_masks_{mask_type}' / dataset_type
        mask_dir.mkdir(parents=True, exist_ok=True)
        paths[f'masks_{mask_type}'] = mask_dir

    return paths


def get_image_mask_pairs(data_dir, dataset_type, mask_types=('root', 'shoot', 'seed')):
    """Find all images and their corresponding masks.

    Args:
        data_dir: Root data directory (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        mask_types: Iterable of mask types to find.

    Returns:
        List of dictionaries with 'image' path and 'masks' dictionary.
        Each dict has structure: {'image': Path, 'masks': {'root': Path, ...}}
    """
    image_dir = Path(data_dir) / f'{dataset_type}_images'
    mask_dir = Path(data_dir) / f'{dataset_type}_masks'

    # glob is imported at module level; sorted for deterministic ordering
    image_files = sorted(glob.glob(str(image_dir / '*.png')))

    pairs = []
    for img_path in image_files:
        img_path = Path(img_path)
        base_name = img_path.stem  # filename without extension

        # Find corresponding masks by naming convention: {stem}_{type}_mask.tif
        masks = {}
        for mask_type in mask_types:
            mask_path = mask_dir / f'{base_name}_{mask_type}_mask.tif'
            if mask_path.exists():
                masks[mask_type] = mask_path

        # Only include images that have at least one mask
        if masks:
            pairs.append({
                'image': img_path,
                'masks': masks
            })

    return pairs


def create_patches_from_image(image, mask_dict, patch_size, scaling_factor, step=None,
                              roi_bbox=None, preprocess_fns=None):
    """Create patches from one image and its corresponding masks.

    Args:
        image: Numpy array of the image.
        mask_dict: Dictionary with mask_type: mask_array.
        patch_size: Size of patches.
        scaling_factor: Scaling factor for resizing.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        roi_bbox: Optional ROI bounding box. If provided, crops image and masks before patching.
        preprocess_fns: List of preprocessing functions to apply before patching.

    Returns:
        Dictionary with 'image' patches, 'masks' dict of patches for each type,
        'step', and 'roi_bbox' (if cropped).
    """
    if step is None:
        step = patch_size

    # Crop to ROI first so image and masks stay aligned
    if roi_bbox is not None:
        image = crop_to_roi(image, roi_bbox)
        mask_dict = {k: crop_to_roi(v, roi_bbox) for k, v in mask_dict.items()}

    # Preprocessing applies to the image only (never to masks)
    image = apply_preprocessing_pipeline(image, preprocess_fns)

    # Convert to grayscale if a 3-channel color image was passed
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Scale + pad, then add a channel axis so patchify sees (H, W, 1)
    image = process_image(image, patch_size, scaling_factor, is_mask=False)
    image = image[..., np.newaxis]
    img_patches = patchify(image, (patch_size, patch_size, 1), step=step)

    # Same treatment for each mask
    mask_patches = {}
    for mask_type, mask in mask_dict.items():
        mask = process_image(mask, patch_size, scaling_factor, is_mask=True)
        mask = mask[..., np.newaxis]
        mask_patches[mask_type] = patchify(mask, (patch_size, patch_size, 1), step=step)

    result = {
        'image': img_patches,
        'masks': mask_patches,
        'step': step
    }

    if roi_bbox is not None:
        result['roi_bbox'] = roi_bbox

    return result


def reconstruct_from_patches(patches, image_shape, patch_size, step):
    """Reconstruct an image from patches.

    Uses unpatchify for non-overlapping patches (step == patch_size).
    Uses averaging reconstruction for overlapping patches (step < patch_size).

    Args:
        patches: Patch array from patchify with shape
            (n_rows, n_cols, 1, patch_size, patch_size, channels).
        image_shape: Target shape (height, width, channels) for reconstruction.
        patch_size: Size of each patch.
        step: Step size used during patch extraction.

    Returns:
        Reconstructed image (uint8 when averaging is used).
    """
    # Fast path: non-overlapping patches reassemble exactly
    if step == patch_size:
        return unpatchify(patches, image_shape)

    # Overlapping patches: accumulate and average per-pixel contributions
    h, w, c = image_shape
    n_rows, n_cols = patches.shape[0], patches.shape[1]

    reconstructed = np.zeros(image_shape, dtype=np.float32)
    counts = np.zeros((h, w), dtype=np.float32)

    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            y_start = row_idx * step
            x_start = col_idx * step
            y_end = min(y_start + patch_size, h)
            x_end = min(x_start + patch_size, w)

            patch = patches[row_idx, col_idx, 0]
            patch_h = y_end - y_start
            patch_w = x_end - x_start

            reconstructed[y_start:y_end, x_start:x_end] += patch[:patch_h, :patch_w]
            counts[y_start:y_end, x_start:x_end] += 1

    # Guard against division by zero on never-covered pixels
    counts = np.maximum(counts, 1)
    reconstructed = reconstructed / counts[:, :, np.newaxis]

    return reconstructed.astype(np.uint8)


def save_patches(pairs, output_dir, dataset_type, patch_size=128,
                 scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                 filter_roi=True, preprocess_fns=None, notes=""):
    """Create and save all patches to disk, optionally cropping to ROI first.

    This function processes images serially. For faster processing with multiple
    images, use save_patches_parallel() instead.

    Args:
        pairs: List from get_image_mask_pairs().
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed').
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.

    Returns:
        Number of patches created.
    """
    # Delegate to the parallel implementation with a single worker
    return save_patches_parallel(
        pairs=pairs,
        output_dir=output_dir,
        dataset_type=dataset_type,
        patch_size=patch_size,
        scaling_factor=scaling_factor,
        step=step,
        mask_types=mask_types,
        filter_roi=filter_roi,
        preprocess_fns=preprocess_fns,
        notes=notes,
        num_workers=1
    )


def _process_image_worker_parallel(args):
    """Worker function for parallel patch processing. Must be at module level for pickling.

    Args:
        args: Tuple of (pair, paths_dict, patch_size, scaling_factor, step,
            mask_types, filter_roi, preprocess_fns). step must already be
            resolved (not None) by the caller.

    Returns:
        Tuple of (local_metadata, patch_count, image_name).
    """
    (pair, paths_dict, patch_size, scaling_factor, step,
     mask_types, filter_roi, preprocess_fns) = args

    img_path = pair['image']
    base_name = img_path.stem

    # Load image as grayscale
    image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)

    # Detect ROI on the raw image when requested
    roi_bbox = detect_roi(image) if filter_roi else None

    # Load only the masks that exist for this image
    masks = {}
    for mask_type in mask_types:
        if mask_type in pair['masks']:
            masks[mask_type] = cv2.imread(str(pair['masks'][mask_type]), cv2.IMREAD_GRAYSCALE)

    # Create all patches for this image in memory
    result = create_patches_from_image(image, masks, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)

    n_rows, n_cols = result['image'].shape[0], result['image'].shape[1]
    local_metadata = []

    # Write every patch (image + masks) and record its metadata
    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            patch_name = f"{base_name}_r{row_idx:02d}_c{col_idx:02d}.png"

            img_patch = result['image'][row_idx, col_idx, 0]
            cv2.imwrite(str(paths_dict['images'] / patch_name), img_patch)

            for mask_type in mask_types:
                if mask_type in masks:
                    mask_patch = result['masks'][mask_type][row_idx, col_idx, 0]
                    cv2.imwrite(str(paths_dict[f'masks_{mask_type}'] / patch_name), mask_patch)

            # Patch coordinates in the (cropped, scaled, padded) image space
            x_start = col_idx * step
            y_start = row_idx * step
            x_end = x_start + patch_size
            y_end = y_start + patch_size

            patch_metadata = {
                "patch_filename": patch_name,
                "source_image": img_path.name,
                "row_idx": row_idx,
                "col_idx": col_idx,
                "x_start": x_start,
                "y_start": y_start,
                "x_end": x_end,
                "y_end": y_end,
                "grid_size": [n_rows, n_cols]
            }

            if roi_bbox:
                patch_metadata["roi_bbox"] = [list(roi_bbox[0]), list(roi_bbox[1])]

            local_metadata.append(patch_metadata)

    return local_metadata, n_rows * n_cols, img_path.name


def save_patches_parallel(pairs, output_dir, dataset_type, patch_size=128,
                          scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                          filter_roi=True, preprocess_fns=None, notes="", num_workers=None):
    """Create and save all patches to disk using parallel processing.

    This is an optimized version of save_patches() that processes multiple images
    in parallel using multiprocessing. Existing output directories for this
    dataset_type are removed before new patches are written.

    Args:
        pairs: List from get_image_mask_pairs().
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed').
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.
        num_workers: Number of parallel workers. None = cpu_count() - 1. Default is None.

    Returns:
        Number of patches created.
    """
    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor, as_completed

    if num_workers is None:
        num_workers = max(1, mp.cpu_count() - 1)

    if step is None:
        step = patch_size

    # Auto-clean only the directories belonging to this dataset_type
    output_path = Path(output_dir)
    if output_path.exists():
        dirs_to_clean = [
            output_path / f'{dataset_type}_images',
            *[output_path / f'{dataset_type}_masks_{mt}' for mt in mask_types]
        ]
        for dir_path in dirs_to_clean:
            if dir_path.exists():
                print(f"Cleaning existing directory: {dir_path}")
                shutil.rmtree(dir_path)

    # Create fresh output directories
    paths = create_patch_directories(output_dir, dataset_type, mask_types)

    # Record preprocessing function names for reproducibility
    preprocess_names = []
    if preprocess_fns:
        for fn in preprocess_fns:
            preprocess_names.append(fn.__name__)

    # Guard against empty input so metadata construction cannot IndexError
    dataset_source = (
        str(pairs[0].get('image').parent.absolute().resolve()) if pairs else None
    )

    metadata = {
        "dataset_info": {
            "dataset_type": dataset_type,
            "dataset_source": dataset_source,
            "patch_size": patch_size,
            "step": step,
            "overlap_percent": (1 - step / patch_size) * 100,
            "scaling_factor": scaling_factor,
            "filter_roi": filter_roi,
            "preprocessing": preprocess_names if preprocess_names else None,
            "created_at": datetime.now().isoformat(),
            "epoch_utc": int(datetime.now(timezone.utc).timestamp()),
            "num_source_images": len(pairs),
            "num_patches": 0,
            "notes": notes,
            "metadata_version": METADATA_VERSION
        },
        "patches": []
    }

    # One argument tuple per image; step is resolved, never None here
    worker_args = [
        (pair, paths, patch_size, scaling_factor, step, mask_types, filter_roi, preprocess_fns)
        for pair in pairs
    ]

    # Fan out over processes, collecting metadata as futures complete
    patch_count = 0
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(_process_image_worker_parallel, args): args[0]['image'].name
                   for args in worker_args}

        with tqdm(total=len(pairs), desc=f"Processing images ({num_workers} workers)") as pbar:
            for future in as_completed(futures):
                try:
                    local_meta, count, img_name = future.result()
                    metadata['patches'].extend(local_meta)
                    patch_count += count
                    pbar.update(1)
                except Exception as e:
                    print(f"\nError processing {futures[future]}: {e}")
                    raise

    metadata['dataset_info']['num_patches'] = patch_count

    # Persist metadata next to the patch directories
    metadata_path = Path(output_dir) / f'{dataset_type}_metadata.json'
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nTotal: {patch_count} patches saved")
    print(f"Overlap: {metadata['dataset_info']['overlap_percent']:.1f}%")
    print(f"ROI cropping: {'enabled' if filter_roi else 'disabled'}")
    if preprocess_names:
        print(f"Preprocessing: {', '.join(preprocess_names)}")
    print(f"Workers used: {num_workers}")
    print(f"Metadata saved to {metadata_path}")

    return patch_count


def load_patch_metadata(patch_dir, dataset_type):
    """Load metadata for a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary containing dataset_info and patches list.

    Raises:
        FileNotFoundError: If the metadata JSON does not exist.

    Example:
        >>> metadata = load_patch_metadata('data/patched', 'train')
        >>> print(f"Patch size: {metadata['dataset_info']['patch_size']}")
    """
    metadata_path = Path(patch_dir) / f'{dataset_type}_metadata.json'

    if not metadata_path.exists():
        raise FileNotFoundError(f"Metadata not found: {metadata_path}")

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    return metadata


def get_patch_statistics(patch_dir, dataset_type):
    """Get statistics about a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary with dataset statistics.

    Example:
        >>> stats = get_patch_statistics('data/patched', 'train')
        >>> print(stats)
    """
    metadata = load_patch_metadata(patch_dir, dataset_type)
    info = metadata['dataset_info']

    return {
        'dataset_source': info['dataset_source'],
        'created_at': info['created_at'],
        'num_patches': info['num_patches'],
        'num_source_images': info['num_source_images'],
        'patch_size': info['patch_size'],
        'step': info['step'],
        'overlap_percent': info['overlap_percent'],
        'patches_per_image': info['num_patches'] / info['num_source_images']
    }
def padder(image, patch_size, is_mask=False):
    """Add padding to an image to make its dimensions divisible by patch size.

    Padding is split evenly between both sides of each dimension; when the
    total padding for a dimension is odd, the extra pixel goes to the bottom
    or right side. The image is returned unchanged when both dimensions are
    already divisible by patch_size.

    Parameters:
        image: Input image as numpy array with shape (height, width[, channels]).
        patch_size: The patch size to which image dimensions should be divisible.
        is_mask: Currently unused; padding is always constant black (0)
            regardless of this flag. Kept for interface compatibility.

    Returns:
        Padded image as numpy array with dimensions divisible by patch_size.

    Example:
        >>> padded_image = padder(cv2.imread('example.jpg'), 128)
    """
    h, w = image.shape[:2]

    # Amount needed to round each dimension up to the next multiple of patch_size
    height_padding = 0 if h % patch_size == 0 else ((h // patch_size) + 1) * patch_size - h
    width_padding = 0 if w % patch_size == 0 else ((w // patch_size) + 1) * patch_size - w

    # Early return if no padding needed
    if height_padding == 0 and width_padding == 0:
        return image

    # Split padding evenly; odd remainder goes to bottom/right
    top_padding = height_padding // 2
    bottom_padding = height_padding - top_padding
    left_padding = width_padding // 2
    right_padding = width_padding - left_padding

    # Constant black border on all four sides
    return cv2.copyMakeBorder(
        image,
        top_padding, bottom_padding,
        left_padding, right_padding,
        cv2.BORDER_CONSTANT,
        value=0
    )
Add padding to an image to make its dimensions divisible by patch size.
Calculates padding needed for both height and width so that dimensions become divisible by the given patch size. Padding is applied evenly to both sides of each dimension. If padding amount is odd, one extra pixel is added to the bottom or right side.
Arguments:
- image: Input image as numpy array with shape (height, width, channels).
- patch_size: The patch size to which image dimensions should be divisible.
- is_mask: Currently unused; the implementation always pads with constant black (0) regardless of this flag.
Returns:
Padded image as numpy array with dimensions divisible by patch_size.
Example:
>>> padded_image = padder(cv2.imread('example.jpg'), 128)
def unpadder(padded_image, roi_box):
    """Remove padding from an image to restore original ROI dimensions.

    Reverses the padding applied by the padder function: padding is assumed to
    be split evenly between both sides of each dimension, with any odd
    remainder on the bottom/right side.

    Parameters:
        padded_image: Padded image as numpy array with shape (height, width[, channels]).
        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.

    Returns:
        Cropped image as numpy array with original ROI dimensions.

    Example:
        >>> roi_box = ((776, 70), (3519, 2813))
        >>> original = unpadder(padded_image, roi_box)
    """
    h, w = padded_image.shape[:2]

    # Original ROI dimensions
    roi_height = roi_box[1][1] - roi_box[0][1]
    roi_width = roi_box[1][0] - roi_box[0][0]

    # Total padding in each dimension
    total_height_px = h - roi_height
    total_width_px = w - roi_width

    # Early return if no padding to remove
    if total_height_px == 0 and total_width_px == 0:
        return padded_image

    # Split padding evenly (mirrors padder: extra pixel on bottom/right)
    left_padding = total_width_px // 2
    right_padding = total_width_px - left_padding
    top_padding = total_height_px // 2
    bottom_padding = total_height_px - top_padding

    # BUGFIX: use explicit end indices instead of negative ones. With negative
    # indices, a zero padding on one dimension produced an end index of -0,
    # i.e. padded_image[x:-0] -> an EMPTY slice, whenever only the other
    # dimension had been padded. Slicing only the first two axes handles both
    # 2-D masks and 3-D images.
    return padded_image[top_padding:h - bottom_padding, left_padding:w - right_padding]
Remove padding from an image to restore original ROI dimensions.
Calculates and removes padding that was added to make an image divisible by a patch size. Padding is removed evenly from both sides of each dimension. This function reverses the padding operation applied by the padder function.
Arguments:
- padded_image: Padded image as numpy array with shape (height, width, channels).
- roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
Returns:
Cropped image as numpy array with original ROI dimensions.
Example:
>>> roi_box = ((776, 70), (3519, 2813))
>>> original = unpadder(padded_image, roi_box)
def restore_mask_to_original(padded_mask, original_image_shape, roi_box):
    """Restore a padded mask to match the original image dimensions.

    Strips the divisibility padding from a mask and pastes it at its ROI
    position inside a black canvas sized like the original image.

    Parameters:
        padded_mask: Padded binary mask as numpy array with shape (height, width).
        original_image_shape: Tuple of (height, width) or (height, width, channels)
            from the original image.
        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.

    Returns:
        Binary mask as numpy array with shape matching original_image_shape[:2],
        with mask values of 0 and 255.

    Example:
        >>> full_mask = restore_mask_to_original(predicted_mask, image.shape, roi_box)
        >>> cv2.imwrite('output.png', full_mask)
    """
    # Undo the patch-divisibility padding, then force a strict {0, 255} mask
    roi_mask = unpadder(padded_mask, roi_box)
    roi_mask = np.where(roi_mask > 0, 255, 0).astype(np.uint8)

    (x1, y1), (x2, y2) = roi_box

    # Black canvas the size of the original image, with the ROI mask pasted in
    canvas = np.zeros(original_image_shape[:2], dtype=np.uint8)
    canvas[y1:y2, x1:x2] = roi_mask
    return canvas
Restore a padded mask to match the original image dimensions.
This function removes padding from a mask and places it at the correct position in a full-size mask that matches the original image dimensions.
Arguments:
- padded_mask: Padded binary mask as numpy array with shape (height, width).
- original_image_shape: Tuple of (height, width) or (height, width, channels) from the original image.
- roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
Returns:
Binary mask as numpy array with shape matching original_image_shape[:2], with mask values of 0 and 255.
Example:
>>> full_mask = restore_mask_to_original(predicted_mask, image.shape, roi_box)
>>> cv2.imwrite('output.png', full_mask)
def apply_preprocessing_pipeline(image, preprocess_fns):
    """Apply a list of preprocessing functions sequentially to an image.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        preprocess_fns: List of callables, each taking an image and returning
            the processed image. Can be None or an empty list.

    Returns:
        Preprocessed image; the original object is returned untouched when
        there is nothing to apply.
    """
    # Nothing to do for None or an empty list — hand back the original
    if preprocess_fns is None or len(preprocess_fns) == 0:
        return image

    # Copy first so the caller's image is never mutated in place
    result = image.copy()
    for transform in preprocess_fns:
        result = transform(result)
    return result
Apply a list of preprocessing functions sequentially to an image.
Arguments:
- image: Numpy array with shape (H, W, C) or (H, W).
- preprocess_fns: List of callables, each taking image and returning processed image. Can be None or empty list.
Returns:
Preprocessed image.
def process_image(image, patch_size, scaling_factor, is_mask=False):
    """Resize an image (when requested) and pad it to be patch-divisible.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        patch_size: Target patch size.
        scaling_factor: Scaling factor (<=1.0); 1.0 skips the resize.
        is_mask: Forwarded to padder.

    Returns:
        Processed image.
    """
    # Resize only when the factor actually changes the image
    resized = (
        cv2.resize(image, (0, 0), fx=scaling_factor, fy=scaling_factor)
        if scaling_factor != 1.0
        else image
    )

    # Pad so both dimensions divide evenly by patch_size
    return padder(resized, patch_size, is_mask=is_mask)
Pad and scale a single image.
Arguments:
- image: Numpy array with shape (H, W, C) or (H, W).
- patch_size: Target patch size.
- scaling_factor: Scaling factor (<=1.0).
- is_mask: Whether this is a mask (affects padding value).
Returns:
Processed image.
def create_patch_directories(output_dir, dataset_type, mask_types=['root', 'shoot', 'seed']):
    """Create the directory tree used to store a patch dataset.

    Args:
        output_dir: Base output directory (e.g., 'data_patched').
        dataset_type: Either 'train' or 'val'.
        mask_types: List of mask types to create directories for.

    Returns:
        Dictionary of created paths with keys 'images' and 'masks_{type}'.
    """
    base = Path(output_dir)
    created = {}

    def _ensure(key, top_level):
        # Nested dataset_type subfolder keeps the Keras-style layout
        target = base / top_level / dataset_type
        target.mkdir(parents=True, exist_ok=True)
        created[key] = target

    _ensure('images', f'{dataset_type}_images')
    for mask_type in mask_types:
        _ensure(f'masks_{mask_type}', f'{dataset_type}_masks_{mask_type}')

    return created
Create directory structure needed for patch datasets.
Arguments:
- output_dir: Base output directory (e.g., 'data_patched').
- dataset_type: Either 'train' or 'val'.
- mask_types: List of mask types to create directories for.
Returns:
Dictionary of created paths with keys 'images' and 'masks_{type}'.
def get_image_mask_pairs(data_dir, dataset_type, mask_types=['root', 'shoot', 'seed']):
    """Find all images and their corresponding masks.

    Args:
        data_dir: Root data directory (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        mask_types: List of mask types to find.

    Returns:
        List of dictionaries with 'image' path and 'masks' dictionary.
        Each dict has structure: {'image': Path, 'masks': {'root': Path, ...}}
    """
    import glob

    image_dir = Path(data_dir) / f'{dataset_type}_images'
    mask_dir = Path(data_dir) / f'{dataset_type}_masks'

    pairs = []
    # Sorted for a deterministic image order
    for image_file in sorted(glob.glob(str(image_dir / '*.png'))):
        image_path = Path(image_file)
        stem = image_path.stem  # filename without extension

        # Masks follow the naming convention {stem}_{type}_mask.tif
        masks = {
            mask_type: mask_dir / f'{stem}_{mask_type}_mask.tif'
            for mask_type in mask_types
            if (mask_dir / f'{stem}_{mask_type}_mask.tif').exists()
        }

        # Skip images that have no mask at all
        if masks:
            pairs.append({'image': image_path, 'masks': masks})

    return pairs
Find all images and their corresponding masks.
Arguments:
- data_dir: Root data directory (e.g., '../../data/dataset').
- dataset_type: Either 'train' or 'val'.
- mask_types: List of mask types to find.
Returns:
List of dictionaries with 'image' path and 'masks' dictionary. Each dict has structure: {'image': Path, 'masks': {'root': Path, ...}}
def create_patches_from_image(image, mask_dict, patch_size, scaling_factor, step=None,
                              roi_bbox=None, preprocess_fns=None):
    """Split one image and its masks into a grid of patches.

    The image (but not the masks) is run through the preprocessing pipeline and
    converted to grayscale before patching. Masks are patched with the same
    geometry so patch (i, j) of every array covers the same region.

    Args:
        image: Numpy array of the image.
        mask_dict: Dictionary mapping mask_type to mask array.
        patch_size: Side length of the square patches.
        scaling_factor: Scaling factor applied when resizing.
        step: Stride between patches. None means patch_size (no overlap).
        roi_bbox: Optional ROI bounding box; when given, image and masks are
            cropped to it before patching.
        preprocess_fns: Optional list of preprocessing functions for the image.

    Returns:
        Dictionary with 'image' patches, a 'masks' dict of patch arrays per
        mask type, 'step', and 'roi_bbox' (only when cropping was requested).
    """
    # Non-overlapping patches by default.
    effective_step = patch_size if step is None else step

    # Restrict everything to the region of interest, if one was supplied.
    if roi_bbox is not None:
        image = crop_to_roi(image, roi_bbox)
        mask_dict = {name: crop_to_roi(arr, roi_bbox) for name, arr in mask_dict.items()}

    # Preprocessing applies to the image only; masks stay untouched.
    image = apply_preprocessing_pipeline(image, preprocess_fns)

    # Collapse 3-channel color images to a single grayscale channel.
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def _patch_grid(array, as_mask):
        # Resize/pad via process_image, add a channel axis, then patchify.
        prepared = process_image(array, patch_size, scaling_factor, is_mask=as_mask)
        prepared = prepared[..., np.newaxis]
        return patchify(prepared, (patch_size, patch_size, 1), step=effective_step)

    result = {
        'image': _patch_grid(image, False),
        'masks': {name: _patch_grid(arr, True) for name, arr in mask_dict.items()},
        'step': effective_step,
    }

    if roi_bbox is not None:
        result['roi_bbox'] = roi_bbox

    return result
Create patches from one image and its corresponding masks.
Arguments:
- image: Numpy array of the image.
- mask_dict: Dictionary with mask_type: mask_array.
- patch_size: Size of patches.
- scaling_factor: Scaling factor for resizing.
- step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
- roi_bbox: Optional ROI bounding box. If provided, crops image and masks before patching.
- preprocess_fns: List of preprocessing functions to apply before patching.
Returns:
Dictionary with 'image' patches, 'masks' dict of patches for each type, 'step', and 'roi_bbox' (if cropped).
def reconstruct_from_patches(patches, image_shape, patch_size, step):
    """Reassemble an image from its patch grid.

    Non-overlapping patches (step == patch_size) are stitched with unpatchify.
    Overlapping patches (step < patch_size) are blended by averaging each pixel
    over every patch that covers it.

    Args:
        patches: Patch array from patchify with shape
            (n_rows, n_cols, 1, patch_size, patch_size, channels).
        image_shape: Target (height, width, channels) of the reconstruction.
        patch_size: Side length of each square patch.
        step: Stride used during patch extraction.

    Returns:
        Reconstructed image (uint8 in the overlapping case).
    """
    # Fast path: no overlap, delegate to patchify's inverse.
    if step == patch_size:
        return unpatchify(patches, image_shape)

    height, width, _ = image_shape
    accum = np.zeros(image_shape, dtype=np.float32)
    coverage = np.zeros((height, width), dtype=np.float32)

    n_rows, n_cols = patches.shape[:2]
    for row in range(n_rows):
        for col in range(n_cols):
            y0, x0 = row * step, col * step
            # Clip the patch footprint at the image border.
            y1 = min(y0 + patch_size, height)
            x1 = min(x0 + patch_size, width)

            tile = patches[row, col, 0]
            accum[y0:y1, x0:x1] += tile[:y1 - y0, :x1 - x0]
            coverage[y0:y1, x0:x1] += 1

    # Avoid division by zero for any pixel no patch happened to cover.
    coverage = np.maximum(coverage, 1)
    averaged = accum / coverage[:, :, np.newaxis]

    return averaged.astype(np.uint8)
Reconstruct an image from patches.
Uses unpatchify for non-overlapping patches (step == patch_size). Uses averaging reconstruction for overlapping patches (step < patch_size).
Arguments:
- patches: Patch array from patchify with shape (n_rows, n_cols, 1, patch_size, patch_size, channels).
- image_shape: Target shape (height, width, channels) for reconstruction.
- patch_size: Size of each patch.
- step: Step size used during patch extraction.
Returns:
Reconstructed image.
def save_patches(pairs, output_dir, dataset_type, patch_size=128,
                 scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                 filter_roi=True, preprocess_fns=None, notes=""):
    """Create and save all patches to disk, optionally cropping to ROI first.

    This function processes images serially. For faster processing with multiple
    images, use save_patches_parallel() instead.

    Args:
        pairs: List from get_image_mask_pairs().
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed');
            a tuple is used to avoid the mutable-default-argument pitfall.
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.

    Returns:
        Number of patches created.
    """
    # Serial processing is just the parallel path restricted to one worker.
    return save_patches_parallel(
        pairs=pairs,
        output_dir=output_dir,
        dataset_type=dataset_type,
        patch_size=patch_size,
        scaling_factor=scaling_factor,
        step=step,
        mask_types=mask_types,
        filter_roi=filter_roi,
        preprocess_fns=preprocess_fns,
        notes=notes,
        num_workers=1
    )
Create and save all patches to disk, optionally cropping to ROI first.
This function processes images serially. For faster processing with multiple images, use save_patches_parallel() instead.
Arguments:
- pairs: List from get_image_mask_pairs().
- output_dir: Base output directory.
- dataset_type: Either 'train' or 'val'.
- patch_size: Patch size. Default is 128.
- scaling_factor: Scaling factor. Default is 1.0.
- step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
- mask_types: Which masks to process. Default is ['root', 'shoot', 'seed'].
- filter_roi: If True, crop to ROI before patching. Default is True.
- preprocess_fns: Optional list of preprocessing functions to apply to images.
- notes: Optional notes to include in metadata. Default is empty string.
Returns:
Number of patches created.
def save_patches_parallel(pairs, output_dir, dataset_type, patch_size=128,
                          scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                          filter_roi=True, preprocess_fns=None, notes="", num_workers=None):
    """Create and save all patches to disk using parallel processing.

    This is an optimized version of save_patches() that processes multiple images
    in parallel using multiprocessing. Existing output directories for this
    dataset_type are removed first, then recreated and filled.

    Args:
        pairs: List from get_image_mask_pairs().
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed');
            a tuple is used to avoid the mutable-default-argument pitfall.
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.
        num_workers: Number of parallel workers. None = cpu_count() - 1. Default is None.

    Returns:
        Number of patches created.
    """
    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor, as_completed

    if num_workers is None:
        num_workers = max(1, mp.cpu_count() - 1)

    if step is None:
        step = patch_size

    # Auto-clean only the directories belonging to this dataset_type so the
    # other split ('train' vs 'val') is left untouched.
    output_path = Path(output_dir)
    if output_path.exists():
        dirs_to_clean = [
            output_path / f'{dataset_type}_images',
            *[output_path / f'{dataset_type}_masks_{mt}' for mt in mask_types]
        ]

        for dir_path in dirs_to_clean:
            if dir_path.exists():
                print(f"Cleaning existing directory: {dir_path}")
                shutil.rmtree(dir_path)

    # Create the fresh output directory tree.
    paths = create_patch_directories(output_dir, dataset_type, mask_types)

    # Record preprocessing function names (not the callables) in the metadata.
    preprocess_names = [fn.__name__ for fn in preprocess_fns] if preprocess_fns else []

    # Guard against an empty pair list: there is no first image to derive the
    # source directory from (previously this raised IndexError).
    dataset_source = (
        str(pairs[0].get('image').parent.absolute().resolve()) if pairs else None
    )

    metadata = {
        "dataset_info": {
            "dataset_type": dataset_type,
            "dataset_source": dataset_source,
            "patch_size": patch_size,
            "step": step,
            "overlap_percent": (1 - step / patch_size) * 100,
            "scaling_factor": scaling_factor,
            "filter_roi": filter_roi,
            "preprocessing": preprocess_names if preprocess_names else None,
            "created_at": datetime.now().isoformat(),
            "epoch_utc": int(datetime.now(timezone.utc).timestamp()),
            "num_source_images": len(pairs),
            "num_patches": 0,
            "notes": notes,
            "metadata_version": METADATA_VERSION
        },
        "patches": []
    }

    # One argument tuple per source image for the worker processes.
    worker_args = [
        (pair, paths, patch_size, scaling_factor, step, mask_types, filter_roi, preprocess_fns)
        for pair in pairs
    ]

    # Fan out across processes; collect per-image metadata as futures complete.
    patch_count = 0
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(_process_image_worker_parallel, args): args[0]['image'].name
                   for args in worker_args}

        with tqdm(total=len(pairs), desc=f"Processing images ({num_workers} workers)") as pbar:
            for future in as_completed(futures):
                try:
                    local_meta, count, img_name = future.result()
                    metadata['patches'].extend(local_meta)
                    patch_count += count
                    pbar.update(1)
                except Exception as e:
                    # Surface which image failed, then abort the whole run.
                    print(f"\nError processing {futures[future]}: {e}")
                    raise

    metadata['dataset_info']['num_patches'] = patch_count

    # Persist the metadata file next to the patch directories.
    metadata_path = Path(output_dir) / f'{dataset_type}_metadata.json'
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nTotal: {patch_count} patches saved")
    print(f"Overlap: {metadata['dataset_info']['overlap_percent']:.1f}%")
    print(f"ROI cropping: {'enabled' if filter_roi else 'disabled'}")
    if preprocess_names:
        print(f"Preprocessing: {', '.join(preprocess_names)}")
    print(f"Workers used: {num_workers}")
    print(f"Metadata saved to {metadata_path}")

    return patch_count
Create and save all patches to disk using parallel processing.
This is an optimized version of save_patches() that processes multiple images in parallel using multiprocessing.
Arguments:
- pairs: List from get_image_mask_pairs().
- output_dir: Base output directory.
- dataset_type: Either 'train' or 'val'.
- patch_size: Patch size. Default is 128.
- scaling_factor: Scaling factor. Default is 1.0.
- step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
- mask_types: Which masks to process. Default is ['root', 'shoot', 'seed'].
- filter_roi: If True, crop to ROI before patching. Default is True.
- preprocess_fns: Optional list of preprocessing functions to apply to images.
- notes: Optional notes to include in metadata. Default is empty string.
- num_workers: Number of parallel workers. None = cpu_count() - 1. Default is None.
Returns:
Number of patches created.
def load_patch_metadata(patch_dir, dataset_type):
    """Load metadata for a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary containing dataset_info and patches list.

    Raises:
        FileNotFoundError: If no metadata file exists for this dataset_type.

    Example:
        >>> metadata = load_patch_metadata('data/patched', 'train')
        >>> print(f"Patch size: {metadata['dataset_info']['patch_size']}")
        >>> print(f"Step size: {metadata['dataset_info']['step']}")
        >>> print(f"Total patches: {metadata['dataset_info']['num_patches']}")
    """
    metadata_path = Path(patch_dir) / f'{dataset_type}_metadata.json'

    # Fail fast with a clear message rather than an opaque open() error.
    if not metadata_path.exists():
        raise FileNotFoundError(f"Metadata not found: {metadata_path}")

    with open(metadata_path, 'r') as fh:
        return json.load(fh)
Load metadata for a saved patch dataset.
Arguments:
- patch_dir: Directory containing saved patches.
- dataset_type: Either 'train' or 'val'.
Returns:
Dictionary containing dataset_info and patches list.
Example:
>>> metadata = load_patch_metadata('data/patched', 'train')
>>> print(f"Patch size: {metadata['dataset_info']['patch_size']}")
>>> print(f"Step size: {metadata['dataset_info']['step']}")
>>> print(f"Total patches: {metadata['dataset_info']['num_patches']}")
def get_patch_statistics(patch_dir, dataset_type):
    """Get statistics about a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary with dataset statistics. 'patches_per_image' is 0.0 when
        the dataset contains no source images (avoids ZeroDivisionError).

    Example:
        >>> stats = get_patch_statistics('data/patched', 'train')
        >>> print(stats)
    """
    metadata = load_patch_metadata(patch_dir, dataset_type)
    info = metadata['dataset_info']

    num_images = info['num_source_images']
    # Guard the average: an empty dataset would otherwise divide by zero.
    patches_per_image = info['num_patches'] / num_images if num_images else 0.0

    return {
        'dataset_source': info['dataset_source'],
        'created_at': info['created_at'],
        'num_patches': info['num_patches'],
        'num_source_images': num_images,
        'patch_size': info['patch_size'],
        'step': info['step'],
        'overlap_percent': info['overlap_percent'],
        'patches_per_image': patches_per_image
    }
Get statistics about a saved patch dataset.
Arguments:
- patch_dir: Directory containing saved patches.
- dataset_type: Either 'train' or 'val'.
Returns:
Dictionary with dataset statistics.
Example:
>>> stats = get_patch_statistics('data/patched', 'train')
>>> print(stats)