library.patch_dataset

Functions for creating and managing patch-based datasets from images and masks.

  1# library/patch_dataset.py
  2"""Functions for creating and managing patch-based datasets from images and masks."""
  3
  4import os
  5import shutil
  6import json
  7import glob
  8from datetime import datetime, timezone
  9from pathlib import Path
 10from datetime import datetime
 11import cv2
 12import numpy as np
 13from patchify import patchify, unpatchify
 14from tqdm import tqdm
 15from library.roi import detect_roi, crop_to_roi
 16
# Schema version stamped into every generated metadata JSON file
# (see save_patches_parallel / load_patch_metadata).
METADATA_VERSION = "1.0"
 18
 19
 20def padder(image, patch_size, is_mask=False):
 21    """Add padding to an image to make its dimensions divisible by patch size.
 22
 23    Calculates padding needed for both height and width so that dimensions become 
 24    divisible by the given patch size. Padding is applied evenly to both sides of 
 25    each dimension. If padding amount is odd, one extra pixel is added to the 
 26    bottom or right side.
 27
 28    Parameters:
 29        image: Input image as numpy array with shape (height, width, channels).
 30        patch_size: The patch size to which image dimensions should be divisible.
 31        is_mask: If True, uses grayscale padding. If False, uses RGB padding.
 32
 33    Returns:
 34        Padded image as numpy array with dimensions divisible by patch_size.
 35
 36    Example:
 37        >>> padded_image = padder(cv2.imread('example.jpg'), 128)
 38    """
 39    h, w = image.shape[:2]
 40    
 41    # Calculate padding only if needed
 42    height_padding = 0 if h % patch_size == 0 else ((h // patch_size) + 1) * patch_size - h
 43    width_padding = 0 if w % patch_size == 0 else ((w // patch_size) + 1) * patch_size - w
 44    
 45    # Early return if no padding needed
 46    if height_padding == 0 and width_padding == 0:
 47        return image
 48    
 49    # Split padding evenly
 50    top_padding = height_padding // 2
 51    bottom_padding = height_padding - top_padding
 52    left_padding = width_padding // 2
 53    right_padding = width_padding - left_padding
 54    
 55    # Use black padding
 56    pad_value = 0
 57    
 58    return cv2.copyMakeBorder(
 59        image, 
 60        top_padding, bottom_padding, 
 61        left_padding, right_padding, 
 62        cv2.BORDER_CONSTANT, 
 63        value=pad_value
 64    )
 65
 66def unpadder(padded_image, roi_box):
 67    """Remove padding from an image to restore original ROI dimensions.
 68
 69    Calculates and removes padding that was added to make an image divisible by
 70    a patch size. Padding is removed evenly from both sides of each dimension.
 71    This function reverses the padding operation applied by the padder function.
 72
 73    Parameters:
 74        padded_image: Padded image as numpy array with shape (height, width, channels).
 75        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
 76
 77    Returns:
 78        Cropped image as numpy array with original ROI dimensions.
 79
 80    Example:
 81        >>> roi_box = ((776, 70), (3519, 2813))
 82        >>> original = unpadder(padded_image, roi_box)
 83    """
 84    h, w = padded_image.shape[:2]
 85
 86    # Calculate original ROI dimensions
 87    roi_height = roi_box[1][1] - roi_box[0][1]
 88    roi_width = roi_box[1][0] - roi_box[0][0]
 89
 90    # Calculate total padding in each dimension
 91    total_height_px = h - roi_height
 92    total_width_px = w - roi_width
 93
 94    # Early return if no padding to remove
 95    if total_height_px == 0 and total_width_px == 0:
 96        return padded_image
 97
 98    # Split padding evenly
 99    left_padding = total_width_px // 2
100    right_padding = total_width_px - left_padding
101    top_padding = total_height_px // 2
102    bottom_padding = total_height_px - top_padding
103
104    # Crop the image & handle masks too
105    if padded_image.ndim == 2:
106        return padded_image[top_padding:-bottom_padding, left_padding:-right_padding]
107    else:
108        return padded_image[top_padding:-bottom_padding, left_padding:-right_padding, :]
109
110
def restore_mask_to_original(padded_mask, original_image_shape, roi_box):
    """Embed a padded ROI mask back into a full-size mask.

    Strips the patching padding from the mask, binarizes it, and pastes
    the result at the ROI position inside a zero-filled mask whose size
    matches the original image.

    Parameters:
        padded_mask: Padded binary mask as numpy array with shape (height, width).
        original_image_shape: (height, width) or (height, width, channels)
            of the original image.
        roi_box: ((x1, y1), (x2, y2)) ROI coordinates in the original image.

    Returns:
        uint8 mask of shape original_image_shape[:2] holding only 0 and 255.

    Example:
        >>> full_mask = restore_mask_to_original(predicted_mask, image.shape, roi_box)
        >>> cv2.imwrite('output.png', full_mask)
    """
    (x1, y1), (x2, y2) = roi_box

    # Undo the divisibility padding, then force a strict 0/255 encoding.
    roi_mask = unpadder(padded_mask, roi_box)
    roi_mask = np.where(roi_mask > 0, 255, 0).astype(np.uint8)

    # Paste the ROI mask into an otherwise empty full-resolution canvas.
    canvas = np.zeros(original_image_shape[:2], dtype=np.uint8)
    canvas[y1:y2, x1:x2] = roi_mask
    return canvas
147
def apply_preprocessing_pipeline(image, preprocess_fns):
    """Run preprocessing callables over an image, in order.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        preprocess_fns: Sequence of callables applied left-to-right; each
            receives the current image and returns the next one. May be
            None or empty, in which case the input is returned untouched.

    Returns:
        The image after all preprocessing steps (the original object when
        there is nothing to apply).
    """
    if not preprocess_fns:
        return image

    # Work on a copy so callers never see their input mutated in place.
    current = image.copy()
    for step_fn in preprocess_fns:
        current = step_fn(current)
    return current
167
def process_image(image, patch_size, scaling_factor, is_mask=False):
    """Scale an image (if requested) and pad it to a patch-size multiple.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        patch_size: Target patch size the output must be divisible by.
        scaling_factor: Resize factor (<= 1.0); 1.0 skips resizing.
        is_mask: Forwarded to padder() (affects the padding mode used there).

    Returns:
        The scaled and padded image.
    """
    resized = image
    # Only invoke cv2.resize when a real rescale was requested.
    if scaling_factor != 1.0:
        resized = cv2.resize(resized, (0, 0), fx=scaling_factor, fy=scaling_factor)

    # Pad so both spatial dimensions divide evenly into patch_size tiles.
    return padder(resized, patch_size, is_mask=is_mask)
188
189
def create_patch_directories(output_dir, dataset_type, mask_types=('root', 'shoot', 'seed')):
    """Create the on-disk directory layout needed for a patch dataset.

    Layout: {output_dir}/{dataset_type}_images/{dataset_type} plus one
    {output_dir}/{dataset_type}_masks_{type}/{dataset_type} tree per mask type.

    Args:
        output_dir: Base output directory (e.g., 'data_patched').
        dataset_type: Either 'train' or 'val'.
        mask_types: Iterable of mask types to create directories for.
            Default changed from a mutable list to a tuple (same values) to
            avoid the shared-mutable-default pitfall.

    Returns:
        Dictionary mapping 'images' and 'masks_{type}' to the created Paths.
    """
    base = Path(output_dir)

    paths = {'images': base / f'{dataset_type}_images' / dataset_type}
    for mask_type in mask_types:
        paths[f'masks_{mask_type}'] = base / f'{dataset_type}_masks_{mask_type}' / dataset_type

    # Create every directory (parents included); existing dirs are kept.
    for directory in paths.values():
        directory.mkdir(parents=True, exist_ok=True)

    return paths
215
def get_image_mask_pairs(data_dir, dataset_type, mask_types=('root', 'shoot', 'seed')):
    """Find all images and their corresponding masks.

    Images live in {data_dir}/{dataset_type}_images/*.png and masks in
    {data_dir}/{dataset_type}_masks/{image}_{mask_type}_mask.tif.

    Args:
        data_dir: Root data directory (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        mask_types: Iterable of mask types to look for. Default changed from
            a mutable list to a tuple (same values).

    Returns:
        List of dictionaries with 'image' path and 'masks' dictionary.
        Each dict has structure: {'image': Path, 'masks': {'root': Path, ...}}.
        Images with no masks at all are skipped.
    """
    image_dir = Path(data_dir) / f'{dataset_type}_images'
    mask_dir = Path(data_dir) / f'{dataset_type}_masks'

    pairs = []
    # Path.glob replaces the previous glob-module call (and removes the
    # redundant function-local `import glob` — glob was already imported
    # at module level and is no longer needed here at all).
    for img_path in sorted(image_dir.glob('*.png')):
        base_name = img_path.stem  # filename without extension

        # Collect only the mask files that actually exist on disk.
        masks = {
            mask_type: mask_dir / f'{base_name}_{mask_type}_mask.tif'
            for mask_type in mask_types
            if (mask_dir / f'{base_name}_{mask_type}_mask.tif').exists()
        }

        # Only include images that have at least one mask.
        if masks:
            pairs.append({'image': img_path, 'masks': masks})

    return pairs
258
def create_patches_from_image(image, mask_dict, patch_size, scaling_factor, step=None,
                               roi_bbox=None, preprocess_fns=None):
    """Create patches from one image and its corresponding masks.

    Args:
        image: Numpy array of the image.
        mask_dict: Dictionary with mask_type: mask_array.
        patch_size: Size of patches.
        scaling_factor: Scaling factor for resizing.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        roi_bbox: Optional ROI bounding box. If provided, crops image and masks before patching.
        preprocess_fns: List of preprocessing functions to apply before patching.

    Returns:
        Dictionary with 'image' patches, 'masks' dict of patches for each type,
        'step', and 'roi_bbox' (if cropped).
    """
    step = patch_size if step is None else step

    # Restrict everything to the ROI before any other processing.
    if roi_bbox is not None:
        image = crop_to_roi(image, roi_bbox)
        mask_dict = {name: crop_to_roi(m, roi_bbox) for name, m in mask_dict.items()}

    # Preprocessing applies to the image only; masks stay untouched.
    image = apply_preprocessing_pipeline(image, preprocess_fns)

    # Collapse BGR input to a single grayscale channel.
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Scale + pad, then tile. patchify needs an explicit channel axis.
    prepared = process_image(image, patch_size, scaling_factor, is_mask=False)
    prepared = prepared[..., np.newaxis]
    image_patches = patchify(prepared, (patch_size, patch_size, 1), step=step)

    # Same scale/pad/tile pipeline for every mask.
    mask_patches = {}
    for name, mask in mask_dict.items():
        prepared_mask = process_image(mask, patch_size, scaling_factor, is_mask=True)
        prepared_mask = prepared_mask[..., np.newaxis]
        mask_patches[name] = patchify(prepared_mask, (patch_size, patch_size, 1), step=step)

    result = {
        'image': image_patches,
        'masks': mask_patches,
        'step': step
    }
    if roi_bbox is not None:
        result['roi_bbox'] = roi_bbox
    return result
315
def reconstruct_from_patches(patches, image_shape, patch_size, step):
    """Rebuild an image from its patch grid.

    Non-overlapping grids (step == patch_size) go straight through
    unpatchify; overlapping grids are rebuilt by summing patch
    contributions and dividing by per-pixel coverage counts.

    Args:
        patches: Patch array from patchify with shape
            (n_rows, n_cols, 1, patch_size, patch_size, channels).
        image_shape: Target shape (height, width, channels) for reconstruction.
        patch_size: Size of each patch.
        step: Step size used during patch extraction.

    Returns:
        Reconstructed image (uint8).
    """
    if step == patch_size:
        # No overlap: patchify's own inverse reassembles exactly.
        return unpatchify(patches, image_shape)

    height, width, _ = image_shape
    accum = np.zeros(image_shape, dtype=np.float32)
    coverage = np.zeros((height, width), dtype=np.float32)

    rows, cols = patches.shape[:2]
    for r in range(rows):
        for c in range(cols):
            top = r * step
            left = c * step
            bottom = min(top + patch_size, height)
            right = min(left + patch_size, width)

            # Accumulate this patch's contribution and bump coverage.
            tile = patches[r, c, 0]
            accum[top:bottom, left:right] += tile[:bottom - top, :right - left]
            coverage[top:bottom, left:right] += 1

    # Guard against division by zero for any pixel no patch touched.
    coverage = np.maximum(coverage, 1)
    return (accum / coverage[:, :, np.newaxis]).astype(np.uint8)
360
361
def save_patches(pairs, output_dir, dataset_type, patch_size=128, 
                 scaling_factor=1.0, step=None, mask_types=['root', 'shoot', 'seed'],
                 filter_roi=True, preprocess_fns=None, notes=""):
    """Create and save all patches to disk, serially.

    Thin wrapper over save_patches_parallel() pinned to a single worker;
    use save_patches_parallel() directly for multi-process speedups.

    Args:
        pairs: List from get_image_mask_pairs().
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ['root', 'shoot', 'seed'].
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.

    Returns:
        Number of patches created.
    """
    # Serial processing is just the parallel path with one worker.
    return save_patches_parallel(
        pairs,
        output_dir,
        dataset_type,
        patch_size=patch_size,
        scaling_factor=scaling_factor,
        step=step,
        mask_types=mask_types,
        filter_roi=filter_roi,
        preprocess_fns=preprocess_fns,
        notes=notes,
        num_workers=1,
    )
399
def _process_image_worker_parallel(args):
    """Process one image/mask pair into patches on disk (parallel worker).

    Lives at module level so ProcessPoolExecutor can pickle it.

    Args:
        args: Tuple of (pair, paths_dict, patch_size, scaling_factor, step,
                       mask_types, filter_roi, preprocess_fns)

    Returns:
        Tuple of (local_metadata, patch_count, image_name)
    """
    (pair, paths_dict, patch_size, scaling_factor, step,
     mask_types, filter_roi, preprocess_fns) = args

    img_path = pair['image']
    base_name = img_path.stem

    # The source image is always handled in grayscale.
    image = cv2.imread(str(img_path), cv2.IMREAD_GRAYSCALE)

    # ROI detection is optional and runs on the raw grayscale image.
    roi_bbox = detect_roi(image) if filter_roi else None

    # Load only the masks that exist for this image.
    masks = {mt: cv2.imread(str(pair['masks'][mt]), cv2.IMREAD_GRAYSCALE)
             for mt in mask_types if mt in pair['masks']}

    result = create_patches_from_image(image, masks, patch_size, scaling_factor,
                                       step, roi_bbox, preprocess_fns)

    grid_rows, grid_cols = result['image'].shape[:2]
    local_metadata = []

    for row_idx in range(grid_rows):
        for col_idx in range(grid_cols):
            patch_name = f"{base_name}_r{row_idx:02d}_c{col_idx:02d}.png"

            # Persist the image patch, then each available mask patch.
            cv2.imwrite(str(paths_dict['images'] / patch_name),
                        result['image'][row_idx, col_idx, 0])
            for mt in mask_types:
                if mt in masks:
                    cv2.imwrite(str(paths_dict[f'masks_{mt}'] / patch_name),
                                result['masks'][mt][row_idx, col_idx, 0])

            # Patch origin in the (possibly cropped/scaled/padded) grid frame.
            y_start = row_idx * step
            x_start = col_idx * step

            entry = {
                "patch_filename": patch_name,
                "source_image": img_path.name,
                "row_idx": row_idx,
                "col_idx": col_idx,
                "x_start": x_start,
                "y_start": y_start,
                "x_end": x_start + patch_size,
                "y_end": y_start + patch_size,
                "grid_size": [grid_rows, grid_cols]
            }
            if roi_bbox:
                entry["roi_bbox"] = [list(roi_bbox[0]), list(roi_bbox[1])]
            local_metadata.append(entry)

    return local_metadata, grid_rows * grid_cols, img_path.name
474
def save_patches_parallel(pairs, output_dir, dataset_type, patch_size=128, 
                         scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                         filter_roi=True, preprocess_fns=None, notes="", num_workers=None):
    """Create and save all patches to disk using parallel processing.

    This is an optimized version of save_patches() that processes multiple images
    in parallel using multiprocessing.

    Args:
        pairs: List from get_image_mask_pairs(). Must be non-empty.
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed')
            (tuple rather than mutable-list default).
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.
        num_workers: Number of parallel workers. None = cpu_count() - 1. Default is None.

    Returns:
        Number of patches created.

    Raises:
        ValueError: If `pairs` is empty.
    """
    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor, as_completed

    # Fail fast with a clear message instead of an IndexError when the
    # metadata block reads pairs[0] below.
    if not pairs:
        raise ValueError("No image/mask pairs to process")

    if num_workers is None:
        num_workers = max(1, mp.cpu_count() - 1)

    if step is None:
        step = patch_size

    # Auto-clean only the directories for this dataset_type so other
    # splits in the same output directory survive.
    output_path = Path(output_dir)
    if output_path.exists():
        dirs_to_clean = [
            output_path / f'{dataset_type}_images',
            *[output_path / f'{dataset_type}_masks_{mt}' for mt in mask_types]
        ]
        for dir_path in dirs_to_clean:
            if dir_path.exists():
                print(f"Cleaning existing directory: {dir_path}")
                shutil.rmtree(dir_path)

    # Create directories
    paths = create_patch_directories(output_dir, dataset_type, mask_types)

    # Record preprocessing function names for reproducibility.
    preprocess_names = [fn.__name__ for fn in preprocess_fns] if preprocess_fns else []

    metadata = {
        "dataset_info": {
            "dataset_type": dataset_type,
            # resolve() already yields an absolute path; the previous
            # .absolute().resolve() chain was redundant.
            "dataset_source": str(pairs[0]['image'].parent.resolve()),
            "patch_size": patch_size,
            "step": step,
            "overlap_percent": (1 - step / patch_size) * 100,
            "scaling_factor": scaling_factor,
            "filter_roi": filter_roi,
            "preprocessing": preprocess_names if preprocess_names else None,
            "created_at": datetime.now().isoformat(),
            "epoch_utc": int(datetime.now(timezone.utc).timestamp()),
            "num_source_images": len(pairs),
            "num_patches": 0,
            "notes": notes,
            "metadata_version": METADATA_VERSION
        },
        "patches": []
    }

    # Prepare arguments for all workers (one tuple per source image).
    worker_args = [
        (pair, paths, patch_size, scaling_factor, step, mask_types, filter_roi, preprocess_fns)
        for pair in pairs
    ]

    # Process in parallel with a progress bar; failures abort the run.
    patch_count = 0
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(_process_image_worker_parallel, args): args[0]['image'].name 
                  for args in worker_args}

        with tqdm(total=len(pairs), desc=f"Processing images ({num_workers} workers)") as pbar:
            for future in as_completed(futures):
                try:
                    local_meta, count, img_name = future.result()
                    metadata['patches'].extend(local_meta)
                    patch_count += count
                    pbar.update(1)
                except Exception as e:
                    print(f"\nError processing {futures[future]}: {e}")
                    raise

    metadata['dataset_info']['num_patches'] = patch_count

    # Save metadata
    metadata_path = Path(output_dir) / f'{dataset_type}_metadata.json'
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nTotal: {patch_count} patches saved")
    print(f"Overlap: {metadata['dataset_info']['overlap_percent']:.1f}%")
    print(f"ROI cropping: {'enabled' if filter_roi else 'disabled'}")
    if preprocess_names:
        print(f"Preprocessing: {', '.join(preprocess_names)}")
    print(f"Workers used: {num_workers}")
    print(f"Metadata saved to {metadata_path}")

    return patch_count
589
590
def load_patch_metadata(patch_dir, dataset_type):
    """Load the metadata JSON written alongside a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary containing dataset_info and patches list.

    Raises:
        FileNotFoundError: If no metadata file exists for this dataset_type.

    Example:
        >>> metadata = load_patch_metadata('data/patched', 'train')
        >>> print(f"Patch size: {metadata['dataset_info']['patch_size']}")
        >>> print(f"Step size: {metadata['dataset_info']['step']}")
        >>> print(f"Total patches: {metadata['dataset_info']['num_patches']}")
    """
    metadata_file = Path(patch_dir) / f'{dataset_type}_metadata.json'

    if not metadata_file.exists():
        raise FileNotFoundError(f"Metadata not found: {metadata_file}")

    with open(metadata_file, 'r') as handle:
        return json.load(handle)
616
def get_patch_statistics(patch_dir, dataset_type):
    """Summarize a saved patch dataset from its metadata file.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary with dataset statistics (source, creation time, counts,
        patch geometry, and the average number of patches per image).

    Example:
        >>> stats = get_patch_statistics('data/patched', 'train')
        >>> print(stats)
    """
    info = load_patch_metadata(patch_dir, dataset_type)['dataset_info']

    # Copy the relevant fields and derive patches-per-image on the fly.
    return {
        'dataset_source': info['dataset_source'],
        'created_at': info['created_at'],
        'num_patches': info['num_patches'],
        'num_source_images': info['num_source_images'],
        'patch_size': info['patch_size'],
        'step': info['step'],
        'overlap_percent': info['overlap_percent'],
        'patches_per_image': info['num_patches'] / info['num_source_images'],
    }
METADATA_VERSION = '1.0'
def padder(image, patch_size, is_mask=False):
21def padder(image, patch_size, is_mask=False):
22    """Add padding to an image to make its dimensions divisible by patch size.
23
24    Calculates padding needed for both height and width so that dimensions become 
25    divisible by the given patch size. Padding is applied evenly to both sides of 
26    each dimension. If padding amount is odd, one extra pixel is added to the 
27    bottom or right side.
28
29    Parameters:
30        image: Input image as numpy array with shape (height, width, channels).
31        patch_size: The patch size to which image dimensions should be divisible.
32        is_mask: If True, uses grayscale padding. If False, uses RGB padding.
33
34    Returns:
35        Padded image as numpy array with dimensions divisible by patch_size.
36
37    Example:
38        >>> padded_image = padder(cv2.imread('example.jpg'), 128)
39    """
40    h, w = image.shape[:2]
41    
42    # Calculate padding only if needed
43    height_padding = 0 if h % patch_size == 0 else ((h // patch_size) + 1) * patch_size - h
44    width_padding = 0 if w % patch_size == 0 else ((w // patch_size) + 1) * patch_size - w
45    
46    # Early return if no padding needed
47    if height_padding == 0 and width_padding == 0:
48        return image
49    
50    # Split padding evenly
51    top_padding = height_padding // 2
52    bottom_padding = height_padding - top_padding
53    left_padding = width_padding // 2
54    right_padding = width_padding - left_padding
55    
56    # Use black padding
57    pad_value = 0
58    
59    return cv2.copyMakeBorder(
60        image, 
61        top_padding, bottom_padding, 
62        left_padding, right_padding, 
63        cv2.BORDER_CONSTANT, 
64        value=pad_value
65    )

Add padding to an image to make its dimensions divisible by patch size.

Calculates padding needed for both height and width so that dimensions become divisible by the given patch size. Padding is applied evenly to both sides of each dimension. If padding amount is odd, one extra pixel is added to the bottom or right side.

Arguments:
  • image: Input image as numpy array with shape (height, width, channels).
  • patch_size: The patch size to which image dimensions should be divisible.
  • is_mask: Accepted for interface consistency; the padding value is constant black (0) whether or not the input is a mask.
Returns:

Padded image as numpy array with dimensions divisible by patch_size.

Example:
>>> padded_image = padder(cv2.imread('example.jpg'), 128)
def unpadder(padded_image, roi_box):
 67def unpadder(padded_image, roi_box):
 68    """Remove padding from an image to restore original ROI dimensions.
 69
 70    Calculates and removes padding that was added to make an image divisible by
 71    a patch size. Padding is removed evenly from both sides of each dimension.
 72    This function reverses the padding operation applied by the padder function.
 73
 74    Parameters:
 75        padded_image: Padded image as numpy array with shape (height, width, channels).
 76        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
 77
 78    Returns:
 79        Cropped image as numpy array with original ROI dimensions.
 80
 81    Example:
 82        >>> roi_box = ((776, 70), (3519, 2813))
 83        >>> original = unpadder(padded_image, roi_box)
 84    """
 85    h, w = padded_image.shape[:2]
 86
 87    # Calculate original ROI dimensions
 88    roi_height = roi_box[1][1] - roi_box[0][1]
 89    roi_width = roi_box[1][0] - roi_box[0][0]
 90
 91    # Calculate total padding in each dimension
 92    total_height_px = h - roi_height
 93    total_width_px = w - roi_width
 94
 95    # Early return if no padding to remove
 96    if total_height_px == 0 and total_width_px == 0:
 97        return padded_image
 98
 99    # Split padding evenly
100    left_padding = total_width_px // 2
101    right_padding = total_width_px - left_padding
102    top_padding = total_height_px // 2
103    bottom_padding = total_height_px - top_padding
104
105    # Crop the image & handle masks too
106    if padded_image.ndim == 2:
107        return padded_image[top_padding:-bottom_padding, left_padding:-right_padding]
108    else:
109        return padded_image[top_padding:-bottom_padding, left_padding:-right_padding, :]

Remove padding from an image to restore original ROI dimensions.

Calculates and removes padding that was added to make an image divisible by a patch size. Padding is removed evenly from both sides of each dimension. This function reverses the padding operation applied by the padder function.

Arguments:
  • padded_image: Padded image as numpy array with shape (height, width, channels).
  • roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
Returns:

Cropped image as numpy array with original ROI dimensions.

Example:
>>> roi_box = ((776, 70), (3519, 2813))
>>> original = unpadder(padded_image, roi_box)
def restore_mask_to_original(padded_mask, original_image_shape, roi_box):
112def restore_mask_to_original(padded_mask, original_image_shape, roi_box):
113    """Restore a padded mask to match the original image dimensions.
114    
115    This function removes padding from a mask and places it at the correct position
116    in a full-size mask that matches the original image dimensions.
117    
118    Parameters:
119        padded_mask: Padded binary mask as numpy array with shape (height, width).
120        original_image_shape: Tuple of (height, width) or (height, width, channels) 
121                            from the original image.
122        roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
123    
124    Returns:
125        Binary mask as numpy array with shape matching original_image_shape[:2],
126        with mask values of 0 and 255.
127    
128    Example:
129        >>> full_mask = restore_mask_to_original(predicted_mask, image.shape, roi_box)
130        >>> cv2.imwrite('output.png', full_mask)
131    """
132    # Remove padding using unpadder
133    unpadded_mask = unpadder(padded_mask, roi_box)
134    
135    # Ensure mask is binary (0 and 255)
136    binary_mask = (unpadded_mask > 0).astype(np.uint8) * 255
137    
138    # Create full-size mask matching original image dimensions
139    full_mask = np.zeros(original_image_shape[:2], dtype=np.uint8)
140    
141    # Extract ROI coordinates
142    (x1, y1), (x2, y2) = roi_box
143    
144    # Place the mask at the correct position
145    full_mask[y1:y2, x1:x2] = binary_mask
146    
147    return full_mask

Restore a padded mask to match the original image dimensions.

This function removes padding from a mask and places it at the correct position in a full-size mask that matches the original image dimensions.

Arguments:
  • padded_mask: Padded binary mask as numpy array with shape (height, width).
  • original_image_shape: Tuple of (height, width) or (height, width, channels) from the original image.
  • roi_box: Tuple of ((x1, y1), (x2, y2)) representing the original ROI coordinates.
Returns:

Binary mask as numpy array with shape matching original_image_shape[:2], with mask values of 0 and 255.

Example:
>>> full_mask = restore_mask_to_original(predicted_mask, image.shape, roi_box)
>>> cv2.imwrite('output.png', full_mask)
def apply_preprocessing_pipeline(image, preprocess_fns):
    """Run an optional chain of preprocessing callables over an image.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        preprocess_fns: List of callables, each taking an image and returning
            a processed image. Can be None or an empty list.

    Returns:
        Preprocessed image. The original array is returned untouched when no
        functions are given; otherwise a copy is processed.
    """
    # Nothing to do for None or an empty list.
    if not preprocess_fns:
        return image

    # Work on a copy so the caller's array is never mutated in place.
    result = image.copy()
    for transform in preprocess_fns:
        result = transform(result)
    return result

Apply a list of preprocessing functions sequentially to an image.

Arguments:
  • image: Numpy array with shape (H, W, C) or (H, W).
  • preprocess_fns: List of callables, each taking image and returning processed image. Can be None or empty list.
Returns:
  Preprocessed image.

def process_image(image, patch_size, scaling_factor, is_mask=False):
    """Pad and scale a single image.

    Args:
        image: Numpy array with shape (H, W, C) or (H, W).
        patch_size: Target patch size.
        scaling_factor: Scaling factor (<=1.0).
        is_mask: Whether this is a mask (affects padding value and
            interpolation mode).

    Returns:
        Processed image whose height and width are divisible by patch_size.
    """
    # Scale if needed. Masks must be resized with nearest-neighbour
    # interpolation: cv2.resize's default (bilinear) would blend label
    # values and introduce non-binary pixels along mask edges.
    if scaling_factor != 1.0:
        interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR
        image = cv2.resize(image, (0, 0), fx=scaling_factor, fy=scaling_factor,
                           interpolation=interpolation)

    # Pad to be divisible by patch_size
    image = padder(image, patch_size, is_mask=is_mask)

    return image

Pad and scale a single image.

Arguments:
  • image: Numpy array with shape (H, W, C) or (H, W).
  • patch_size: Target patch size.
  • scaling_factor: Scaling factor (<=1.0).
  • is_mask: Whether this is a mask (affects padding value).
Returns:

Processed image.

def create_patch_directories(output_dir, dataset_type, mask_types=('root', 'shoot', 'seed')):
    """Create directory structure needed for patch datasets.

    Args:
        output_dir: Base output directory (e.g., 'data_patched').
        dataset_type: Either 'train' or 'val'.
        mask_types: Iterable of mask types to create directories for.
            Default is ('root', 'shoot', 'seed'). A tuple default avoids
            the shared-mutable-default pitfall.

    Returns:
        Dictionary of created paths with keys 'images' and 'masks_{type}'.
    """
    paths = {}

    # Images directory, e.g. <output_dir>/train_images/train
    img_dir = Path(output_dir) / f'{dataset_type}_images' / dataset_type
    img_dir.mkdir(parents=True, exist_ok=True)
    paths['images'] = img_dir

    # One directory per mask type, e.g. <output_dir>/train_masks_root/train
    for mask_type in mask_types:
        mask_dir = Path(output_dir) / f'{dataset_type}_masks_{mask_type}' / dataset_type
        mask_dir.mkdir(parents=True, exist_ok=True)
        paths[f'masks_{mask_type}'] = mask_dir

    return paths

Create directory structure needed for patch datasets.

Arguments:
  • output_dir: Base output directory (e.g., 'data_patched').
  • dataset_type: Either 'train' or 'val'.
  • mask_types: List of mask types to create directories for.
Returns:
  Dictionary of created paths with keys 'images' and 'masks_{type}'.

def get_image_mask_pairs(data_dir, dataset_type, mask_types=('root', 'shoot', 'seed')):
    """Find all images and their corresponding masks.

    Args:
        data_dir: Root data directory (e.g., '../../data/dataset').
        dataset_type: Either 'train' or 'val'.
        mask_types: Iterable of mask types to find. Default is
            ('root', 'shoot', 'seed') — tuple to avoid a mutable default.

    Returns:
        List of dictionaries with 'image' path and 'masks' dictionary.
        Each dict has structure: {'image': Path, 'masks': {'root': Path, ...}}.
        Images with no matching mask are skipped.
    """
    image_dir = Path(data_dir) / f'{dataset_type}_images'
    mask_dir = Path(data_dir) / f'{dataset_type}_masks'

    pairs = []
    # Path.glob replaces the previous redundant function-local `import glob`
    # (glob is already imported at module level) plus the str() round-trip.
    for img_path in sorted(image_dir.glob('*.png')):
        base_name = img_path.stem  # filename without extension

        # Collect every mask type that exists for this image; masks follow
        # the '<image-stem>_<type>_mask.tif' naming convention.
        masks = {}
        for mask_type in mask_types:
            mask_path = mask_dir / f'{base_name}_{mask_type}_mask.tif'
            if mask_path.exists():
                masks[mask_type] = mask_path

        # Only include images that have at least one mask
        if masks:
            pairs.append({
                'image': img_path,
                'masks': masks
            })

    return pairs

Find all images and their corresponding masks.

Arguments:
  • data_dir: Root data directory (e.g., '../../data/dataset').
  • dataset_type: Either 'train' or 'val'.
  • mask_types: List of mask types to find.
Returns:

List of dictionaries with 'image' path and 'masks' dictionary. Each dict has structure: {'image': Path, 'masks': {'root': Path, ...}}

def create_patches_from_image(image, mask_dict, patch_size, scaling_factor, step=None,
                               roi_bbox=None, preprocess_fns=None):
    """Create patches from one image and its corresponding masks.

    Args:
        image: Numpy array of the image.
        mask_dict: Dictionary with mask_type: mask_array.
        patch_size: Size of patches.
        scaling_factor: Scaling factor for resizing.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        roi_bbox: Optional ROI bounding box. If provided, crops image and masks before patching.
        preprocess_fns: List of preprocessing functions to apply before patching.

    Returns:
        Dictionary with 'image' patches, 'masks' dict of patches for each type,
        'step', and 'roi_bbox' (if cropped).
    """
    step = patch_size if step is None else step

    # Optional ROI crop is applied to the image and every mask alike so they
    # stay spatially aligned.
    if roi_bbox is not None:
        image = crop_to_roi(image, roi_bbox)
        mask_dict = {name: crop_to_roi(arr, roi_bbox) for name, arr in mask_dict.items()}

    # Preprocessing touches the image only; masks stay as annotated.
    image = apply_preprocessing_pipeline(image, preprocess_fns)

    # Collapse 3-channel BGR colour images to a single grayscale channel.
    if image.ndim == 3 and image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Pad/scale, then re-add a channel axis so patchify sees (H, W, 1).
    image = process_image(image, patch_size, scaling_factor, is_mask=False)
    img_patches = patchify(image[..., np.newaxis], (patch_size, patch_size, 1), step=step)

    # Same pad/scale/patchify treatment for every mask type.
    mask_patches = {}
    for name, mask in mask_dict.items():
        prepared = process_image(mask, patch_size, scaling_factor, is_mask=True)
        mask_patches[name] = patchify(prepared[..., np.newaxis],
                                      (patch_size, patch_size, 1), step=step)

    result = {
        'image': img_patches,
        'masks': mask_patches,
        'step': step
    }

    # Record the crop so callers can map patches back to original coordinates.
    if roi_bbox is not None:
        result['roi_bbox'] = roi_bbox

    return result

Create patches from one image and its corresponding masks.

Arguments:
  • image: Numpy array of the image.
  • mask_dict: Dictionary with mask_type: mask_array.
  • patch_size: Size of patches.
  • scaling_factor: Scaling factor for resizing.
  • step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
  • roi_bbox: Optional ROI bounding box. If provided, crops image and masks before patching.
  • preprocess_fns: List of preprocessing functions to apply before patching.
Returns:

Dictionary with 'image' patches, 'masks' dict of patches for each type, 'step', and 'roi_bbox' (if cropped).

def reconstruct_from_patches(patches, image_shape, patch_size, step):
    """Reconstruct an image from patches.

    Non-overlapping patches (step == patch_size) are stitched directly with
    unpatchify; overlapping patches (step < patch_size) are blended by
    averaging each pixel over all patches that cover it.

    Args:
        patches: Patch array from patchify with shape (n_rows, n_cols, 1, patch_size, patch_size, channels).
        image_shape: Target shape (height, width, channels) for reconstruction.
        patch_size: Size of each patch.
        step: Step size used during patch extraction.

    Returns:
        Reconstructed image as a uint8 array of shape image_shape.
    """
    # Fast path: tiles do not overlap, unpatchify places them directly.
    if step == patch_size:
        return unpatchify(patches, image_shape)

    height, width, _ = image_shape

    # Accumulate patch values and per-pixel coverage counts in float32,
    # then divide to get the average.
    accum = np.zeros(image_shape, dtype=np.float32)
    coverage = np.zeros((height, width), dtype=np.float32)

    for r in range(patches.shape[0]):
        for c in range(patches.shape[1]):
            top, left = r * step, c * step
            bottom = min(top + patch_size, height)
            right = min(left + patch_size, width)

            tile = patches[r, c, 0]
            # Clip the tile at the image boundary before accumulating.
            accum[top:bottom, left:right] += tile[:bottom - top, :right - left]
            coverage[top:bottom, left:right] += 1

    # Guard against division by zero for any pixel no patch covered.
    coverage = np.maximum(coverage, 1)
    return (accum / coverage[:, :, np.newaxis]).astype(np.uint8)

Reconstruct an image from patches.

Uses unpatchify for non-overlapping patches (step == patch_size). Uses averaging reconstruction for overlapping patches (step < patch_size).

Arguments:
  • patches: Patch array from patchify with shape (n_rows, n_cols, 1, patch_size, patch_size, channels).
  • image_shape: Target shape (height, width, channels) for reconstruction.
  • patch_size: Size of each patch.
  • step: Step size used during patch extraction.
Returns:
  Reconstructed image.

def save_patches(pairs, output_dir, dataset_type, patch_size=128,
                 scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                 filter_roi=True, preprocess_fns=None, notes=""):
    """Create and save all patches to disk, optionally cropping to ROI first.

    This function processes images serially. For faster processing with multiple
    images, use save_patches_parallel() instead.

    Args:
        pairs: List from get_image_mask_pairs().
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed')
            — tuple default avoids the shared-mutable-default pitfall.
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
        notes: Optional notes to include in metadata. Default is empty string.

    Returns:
        Number of patches created.
    """
    # Serial processing is just the parallel implementation pinned to one
    # worker; delegating keeps the two code paths from drifting apart.
    return save_patches_parallel(
        pairs=pairs,
        output_dir=output_dir,
        dataset_type=dataset_type,
        patch_size=patch_size,
        scaling_factor=scaling_factor,
        step=step,
        mask_types=mask_types,
        filter_roi=filter_roi,
        preprocess_fns=preprocess_fns,
        notes=notes,
        num_workers=1
    )

Create and save all patches to disk, optionally cropping to ROI first.

This function processes images serially. For faster processing with multiple images, use save_patches_parallel() instead.

Arguments:
  • pairs: List from get_image_mask_pairs().
  • output_dir: Base output directory.
  • dataset_type: Either 'train' or 'val'.
  • patch_size: Patch size. Default is 128.
  • scaling_factor: Scaling factor. Default is 1.0.
  • step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
  • mask_types: Which masks to process. Default is ['root', 'shoot', 'seed'].
  • filter_roi: If True, crop to ROI before patching. Default is True.
  • preprocess_fns: Optional list of preprocessing functions to apply to images.
  • notes: Optional notes to include in metadata. Default is empty string.
Returns:

Number of patches created.

def save_patches_parallel(pairs, output_dir, dataset_type, patch_size=128,
                          scaling_factor=1.0, step=None, mask_types=('root', 'shoot', 'seed'),
                          filter_roi=True, preprocess_fns=None, notes="", num_workers=None):
    """Create and save all patches to disk using parallel processing.

    This is an optimized version of save_patches() that processes multiple images
    in parallel using multiprocessing.

    Args:
        pairs: List from get_image_mask_pairs(). Must be non-empty.
        output_dir: Base output directory.
        dataset_type: Either 'train' or 'val'.
        patch_size: Patch size. Default is 128.
        scaling_factor: Scaling factor. Default is 1.0.
        step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
        mask_types: Which masks to process. Default is ('root', 'shoot', 'seed')
            — tuple default avoids the shared-mutable-default pitfall.
        filter_roi: If True, crop to ROI before patching. Default is True.
        preprocess_fns: Optional list of preprocessing functions to apply to images.
            NOTE(review): these are shipped to worker processes, so they must be
            picklable (top-level functions, not lambdas) — confirm with callers.
        notes: Optional notes to include in metadata. Default is empty string.
        num_workers: Number of parallel workers. None = cpu_count() - 1. Default is None.

    Returns:
        Number of patches created.

    Raises:
        ValueError: If pairs is empty.
    """
    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor, as_completed

    # Fail fast on an empty dataset: the metadata block below reads pairs[0],
    # which would otherwise surface as an obscure IndexError.
    if not pairs:
        raise ValueError("pairs is empty: no images to process")

    if num_workers is None:
        num_workers = max(1, mp.cpu_count() - 1)

    if step is None:
        step = patch_size

    # Auto-clean only directories for this dataset_type; other dataset types
    # stored under the same output_dir are left untouched.
    output_path = Path(output_dir)
    if output_path.exists():
        dirs_to_clean = [
            output_path / f'{dataset_type}_images',
            *[output_path / f'{dataset_type}_masks_{mt}' for mt in mask_types]
        ]

        for dir_path in dirs_to_clean:
            if dir_path.exists():
                print(f"Cleaning existing directory: {dir_path}")
                shutil.rmtree(dir_path)

    # Create directories
    paths = create_patch_directories(output_dir, dataset_type, mask_types)

    # Record preprocessing by function name so the run is reproducible.
    preprocess_names = [fn.__name__ for fn in preprocess_fns] if preprocess_fns else []

    metadata = {
        "dataset_info": {
            "dataset_type": dataset_type,
            "dataset_source": str(pairs[0].get('image').parent.absolute().resolve()),
            "patch_size": patch_size,
            "step": step,
            "overlap_percent": (1 - step / patch_size) * 100,
            "scaling_factor": scaling_factor,
            "filter_roi": filter_roi,
            "preprocessing": preprocess_names if preprocess_names else None,
            # created_at is local wall-clock time; epoch_utc is the
            # timezone-unambiguous timestamp.
            "created_at": datetime.now().isoformat(),
            "epoch_utc": int(datetime.now(timezone.utc).timestamp()),
            "num_source_images": len(pairs),
            "num_patches": 0,  # filled in after all workers finish
            "notes": notes,
            "metadata_version": METADATA_VERSION
        },
        "patches": []
    }

    # Prepare one argument tuple per source image for the worker processes.
    worker_args = [
        (pair, paths, patch_size, scaling_factor, step, mask_types, filter_roi, preprocess_fns)
        for pair in pairs
    ]

    # Process in parallel with a progress bar. Patch metadata is appended in
    # completion order, so the 'patches' list ordering is nondeterministic.
    patch_count = 0
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Map each future back to its image name for error reporting.
        futures = {executor.submit(_process_image_worker_parallel, args): args[0]['image'].name
                   for args in worker_args}

        with tqdm(total=len(pairs), desc=f"Processing images ({num_workers} workers)") as pbar:
            for future in as_completed(futures):
                try:
                    local_meta, count, _ = future.result()
                    metadata['patches'].extend(local_meta)
                    patch_count += count
                    pbar.update(1)
                except Exception as e:
                    print(f"\nError processing {futures[future]}: {e}")
                    raise

    metadata['dataset_info']['num_patches'] = patch_count

    # Save metadata
    metadata_path = Path(output_dir) / f'{dataset_type}_metadata.json'
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"\nTotal: {patch_count} patches saved")
    print(f"Overlap: {metadata['dataset_info']['overlap_percent']:.1f}%")
    print(f"ROI cropping: {'enabled' if filter_roi else 'disabled'}")
    if preprocess_names:
        print(f"Preprocessing: {', '.join(preprocess_names)}")
    print(f"Workers used: {num_workers}")
    print(f"Metadata saved to {metadata_path}")

    return patch_count

Create and save all patches to disk using parallel processing.

This is an optimized version of save_patches() that processes multiple images in parallel using multiprocessing.

Arguments:
  • pairs: List from get_image_mask_pairs().
  • output_dir: Base output directory.
  • dataset_type: Either 'train' or 'val'.
  • patch_size: Patch size. Default is 128.
  • scaling_factor: Scaling factor. Default is 1.0.
  • step: Step size for patch extraction. If None, defaults to patch_size (no overlap).
  • mask_types: Which masks to process. Default is ['root', 'shoot', 'seed'].
  • filter_roi: If True, crop to ROI before patching. Default is True.
  • preprocess_fns: Optional list of preprocessing functions to apply to images.
  • notes: Optional notes to include in metadata. Default is empty string.
  • num_workers: Number of parallel workers. None = cpu_count() - 1. Default is None.
Returns:

Number of patches created.

def load_patch_metadata(patch_dir, dataset_type):
    """Load metadata for a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary containing dataset_info and patches list.

    Raises:
        FileNotFoundError: If no metadata file exists for dataset_type.

    Example:
        >>> metadata = load_patch_metadata('data/patched', 'train')
        >>> print(f"Patch size: {metadata['dataset_info']['patch_size']}")
        >>> print(f"Step size: {metadata['dataset_info']['step']}")
        >>> print(f"Total patches: {metadata['dataset_info']['num_patches']}")
    """
    metadata_file = Path(patch_dir) / f'{dataset_type}_metadata.json'

    if not metadata_file.exists():
        raise FileNotFoundError(f"Metadata not found: {metadata_file}")

    with open(metadata_file, 'r') as fh:
        return json.load(fh)

Load metadata for a saved patch dataset.

Arguments:
  • patch_dir: Directory containing saved patches.
  • dataset_type: Either 'train' or 'val'.
Returns:

Dictionary containing dataset_info and patches list.

Example:
>>> metadata = load_patch_metadata('data/patched', 'train')
>>> print(f"Patch size: {metadata['dataset_info']['patch_size']}")
>>> print(f"Step size: {metadata['dataset_info']['step']}")
>>> print(f"Total patches: {metadata['dataset_info']['num_patches']}")
def get_patch_statistics(patch_dir, dataset_type):
    """Get statistics about a saved patch dataset.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'.

    Returns:
        Dictionary with dataset statistics, including the derived
        'patches_per_image' ratio.

    Raises:
        FileNotFoundError: If no metadata file exists (via load_patch_metadata).

    Example:
        >>> stats = get_patch_statistics('data/patched', 'train')
        >>> print(stats)
    """
    metadata = load_patch_metadata(patch_dir, dataset_type)
    info = metadata['dataset_info']

    num_images = info['num_source_images']
    stats = {
        'dataset_source': info['dataset_source'],
        'created_at': info['created_at'],
        'num_patches': info['num_patches'],
        'num_source_images': num_images,
        'patch_size': info['patch_size'],
        'step': info['step'],
        'overlap_percent': info['overlap_percent'],
        # Guard against ZeroDivisionError for a metadata file recording an
        # empty dataset.
        'patches_per_image': info['num_patches'] / num_images if num_images else 0.0
    }

    return stats

Get statistics about a saved patch dataset.

Arguments:
  • patch_dir: Directory containing saved patches.
  • dataset_type: Either 'train' or 'val'.
Returns:

Dictionary with dataset statistics.

Example:
>>> stats = get_patch_statistics('data/patched', 'train')
>>> print(stats)