library.tf_generators

TensorFlow data generators for patch-based training.

  1# library/tf_generators.py
  2"""TensorFlow data generators for patch-based training."""
  3
  4import tensorflow as tf
  5from pathlib import Path
  6import json
  7
  8
  9def create_patch_dataset(patch_dir, dataset_type='train', mask_type='root', 
 10                        batch_size=16, shuffle=True, seed=42, augment=False,
 11                        augment_config=None):
 12    """Create a tf.data.Dataset for loading image and mask patches.
 13    
 14    Args:
 15        patch_dir: Directory containing saved patches.
 16        dataset_type: Either 'train' or 'val'. Default is 'train'.
 17        mask_type: Which mask to load ('root', 'shoot', or 'seed'). Default is 'root'.
 18        batch_size: Batch size for training. Default is 16.
 19        shuffle: Whether to shuffle the dataset. Default is True.
 20        seed: Random seed for reproducibility. Default is 42.
 21        augment: Whether to apply data augmentation. Default is False.
 22        augment_config: Dictionary of augmentation settings. If None, uses defaults.
 23            Options:
 24            - 'flip_left_right': bool, default True
 25            - 'flip_up_down': bool, default True
 26            - 'rotate': bool, default True (90 degree rotations only)
 27            - 'brightness': float, max delta for brightness adjustment, default 0.0
 28            - 'contrast': tuple, (lower, upper) contrast factor range, default (1.0, 1.0)
 29    
 30    Returns:
 31        tf.data.Dataset yielding (image_batch, mask_batch) tuples.
 32    
 33    Example:
 34        >>> config = {'flip_left_right': True, 'rotate': True, 'brightness': 0.1}
 35        >>> dataset = create_patch_dataset('../../data/patched', 'train', 
 36        ...                                augment=True, augment_config=config)
 37    """
 38    patch_dir = Path(patch_dir)
 39    
 40    # Default augmentation config
 41    if augment_config is None:
 42        augment_config = {
 43            'flip_left_right': True,
 44            'flip_up_down': True,
 45            'rotate': True,
 46            'brightness': 0.0,
 47            'contrast': (1.0, 1.0)
 48        }
 49    
 50    # Get paths to image and mask directories
 51    image_dir = patch_dir / f'{dataset_type}_images' / dataset_type
 52    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type
 53    
 54    # Get list of image files
 55    image_files = sorted(list(image_dir.glob('*.png')))
 56    
 57    if len(image_files) == 0:
 58        raise ValueError(f"No image files found in {image_dir}")
 59    
 60    # Create corresponding mask file paths
 61    mask_files = [mask_dir / img_file.name for img_file in image_files]
 62    
 63    # Convert to string paths for TensorFlow
 64    image_paths = [str(p) for p in image_files]
 65    mask_paths = [str(p) for p in mask_files]
 66    
 67    print(f"Found {len(image_paths)} patches")
 68    if augment:
 69        print(f"Augmentation enabled: {augment_config}")
 70    
 71    # Create dataset from file paths
 72    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
 73    
 74    # Shuffle if requested
 75    if shuffle:
 76        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=seed)
 77    
 78    # Map loading and preprocessing function
 79    if augment:
 80        def map_fn(img_path, mask_path):
 81            return _load_and_augment(img_path, mask_path, augment_config, seed)
 82        
 83        dataset = dataset.map(
 84            map_fn,
 85            num_parallel_calls=tf.data.AUTOTUNE
 86        )
 87    else:
 88        dataset = dataset.map(
 89            _load_and_preprocess,
 90            num_parallel_calls=tf.data.AUTOTUNE
 91        )
 92    
 93    # Batch the dataset
 94    dataset = dataset.batch(batch_size)
 95    
 96    # Prefetch for performance
 97    dataset = dataset.prefetch(tf.data.AUTOTUNE)
 98    
 99    return dataset
100
def create_filtered_patch_dataset(patch_dir, filtered_list_json, dataset_type='train',
                                  batch_size=16, shuffle=True, seed=42,
                                  augment=False, augment_config=None):
    """Build a tf.data.Dataset restricted to a pre-filtered patch list.

    Args:
        patch_dir: Directory containing saved patches.
        filtered_list_json: Path to JSON file with filtered patch list. Must
            contain 'metadata' (with 'mask_type' and 'actual_empty_ratio'),
            'statistics', and 'patch_filenames' entries.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        batch_size: Batch size for training. Default is 16.
        shuffle: Whether to shuffle the dataset. Default is True.
        seed: Random seed for reproducibility. Default is 42.
        augment: Whether to apply data augmentation. Default is False.
        augment_config: Dictionary of augmentation settings (same options as
            create_patch_dataset).

    Returns:
        tf.data.Dataset yielding (image_batch, mask_batch) tuples.

    Example:
        >>> dataset = create_filtered_patch_dataset(
        ...     patch_dir='../../data/patched',
        ...     filtered_list_json='filtered_lists/train_filtered_root.json',
        ...     dataset_type='train',
        ...     augment=True
        ... )
    """
    patch_dir = Path(patch_dir)

    # The JSON describes which patches survived filtering; the mask type is
    # taken from its metadata rather than passed as a parameter.
    with open(filtered_list_json, 'r') as fh:
        filtered_data = json.load(fh)

    metadata = filtered_data['metadata']
    stats = filtered_data['statistics']
    mask_type = metadata['mask_type']
    filenames = filtered_data['patch_filenames']

    # Fall back to the default augmentation settings when none are given
    if augment_config is None:
        augment_config = dict(
            flip_left_right=True,
            flip_up_down=True,
            rotate=True,
            brightness=0.0,
            contrast=(1.0, 1.0),
        )

    # Resolve the on-disk layout for this split / mask type
    image_root = patch_dir / f'{dataset_type}_images' / dataset_type
    mask_root = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type

    # Image and mask patches share filenames
    image_paths = [str(image_root / name) for name in filenames]
    mask_paths = [str(mask_root / name) for name in filenames]

    print(f"Loaded filtered list for {mask_type}")
    print(f"Total patches: {stats['total_patches']}")
    print(f"Non-empty: {stats['non_empty_patches']}")
    print(f"Empty: {stats['empty_patches']}")
    print(f"Empty ratio: {metadata['actual_empty_ratio']*100:.1f}%")

    if augment:
        print(f"Augmentation enabled: {augment_config}")

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    # Full-buffer shuffle gives a true permutation of the filtered list
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=seed)

    # Pick the per-element loader, then map it in parallel
    if augment:
        def loader(img_path, msk_path):
            return _load_and_augment(img_path, msk_path, augment_config, seed)
    else:
        loader = _load_and_preprocess

    dataset = dataset.map(loader, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch, then prefetch to overlap input with training
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
192
193
def _load_and_augment(image_path, mask_path, augment_config, seed):
    """Load and augment a single image-mask pair.

    Geometric transforms (flips, 90-degree rotations) are applied identically
    to image and mask; photometric transforms (brightness, contrast) are
    applied to the image only, then clipped back to [0, 1].

    Args:
        image_path: Path to image file.
        mask_path: Path to mask file.
        augment_config: Dictionary of augmentation settings.
        seed: Base random seed; each random op uses a distinct offset of it.

    Returns:
        Tuple of (augmented_image, augmented_mask) tensors.
    """
    # Load image and mask
    image, mask = _load_and_preprocess(image_path, mask_path)

    # NOTE: each random op must get a DISTINCT op-level seed. Reusing the
    # same seed for every op makes their random sequences identical, so the
    # flip-lr / flip-ud / rotation decisions become perfectly correlated.
    # We derive per-op seeds as seed + offset.

    # Random horizontal flip — same decision applied to image and mask
    if augment_config.get('flip_left_right', False):
        do_flip = tf.random.uniform((), seed=seed) > 0.5
        image = tf.cond(do_flip, lambda: tf.image.flip_left_right(image), lambda: image)
        mask = tf.cond(do_flip, lambda: tf.image.flip_left_right(mask), lambda: mask)

    # Random vertical flip — same decision applied to image and mask
    if augment_config.get('flip_up_down', False):
        do_flip = tf.random.uniform((), seed=seed + 1) > 0.5
        image = tf.cond(do_flip, lambda: tf.image.flip_up_down(image), lambda: image)
        mask = tf.cond(do_flip, lambda: tf.image.flip_up_down(mask), lambda: mask)

    # Random 90 degree rotations (k in {0, 1, 2, 3} quarter-turns)
    if augment_config.get('rotate', False):
        k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32, seed=seed + 2)
        image = tf.image.rot90(image, k=k)
        mask = tf.image.rot90(mask, k=k)

    # Color augmentations (only on image, not mask)
    brightness_delta = augment_config.get('brightness', 0.0)
    if brightness_delta > 0:
        image = tf.image.random_brightness(image, brightness_delta, seed=seed + 3)
        image = tf.clip_by_value(image, 0.0, 1.0)

    contrast_range = augment_config.get('contrast', (1.0, 1.0))
    if contrast_range[0] != 1.0 or contrast_range[1] != 1.0:
        image = tf.image.random_contrast(image, contrast_range[0], contrast_range[1],
                                         seed=seed + 4)
        image = tf.clip_by_value(image, 0.0, 1.0)

    return image, mask
238
def _load_and_preprocess(image_path, mask_path):
    """Read one image/mask PNG pair and convert both to float32 tensors.

    The grayscale image is rescaled to [0, 1]; the mask is cast to float32
    without rescaling (pixel values are kept as stored on disk — presumably
    0/1 per the pipeline's convention).

    Args:
        image_path: Path to image file.
        mask_path: Path to mask file.

    Returns:
        Tuple of (image, mask) tensors, each with a single channel.
    """
    def _decode_png(path):
        # Read raw bytes and decode as single-channel PNG
        raw = tf.io.read_file(path)
        return tf.cast(tf.image.decode_png(raw, channels=1), tf.float32)

    image = _decode_png(image_path) / 255.0  # normalize to [0, 1]
    mask = _decode_png(mask_path)            # values left as stored

    return image, mask
260
261
def get_content_rich_patches(patch_dir, dataset_type='train', mask_type='root',
                             n_samples=10, min_mask_coverage=0.01):
    """Get a list of patch filenames that contain meaningful mask content.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        mask_type: Which mask type to check. Default is 'root'.
        n_samples: Number of content-rich patches to return. Default is 10.
        min_mask_coverage: Minimum percentage of mask pixels (0-1). Default is 0.01 (1%).

    Returns:
        List of patch filenames that have sufficient mask content. May be
        shorter than n_samples if fewer qualifying patches exist. Sampling
        uses the global `random` state, so results vary between calls.
    """
    import cv2
    import random

    patch_dir = Path(patch_dir)
    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type

    # Get all mask files
    mask_files = list(mask_dir.glob('*.png'))

    # Find patches whose mask covers at least min_mask_coverage of pixels
    content_patches = []
    for mask_file in mask_files:
        mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None (it does not raise) for unreadable or
            # corrupt files; skip instead of crashing on the comparison below.
            print(f"Warning: could not read {mask_file}, skipping")
            continue

        mask_coverage = (mask > 0).sum() / mask.size
        if mask_coverage >= min_mask_coverage:
            content_patches.append(mask_file.name)

    print(f"Found {len(content_patches)} patches with >{min_mask_coverage*100:.1f}% mask coverage")

    # Return a random subset, capped at however many qualified
    return random.sample(content_patches, min(n_samples, len(content_patches)))
def create_patch_dataset( patch_dir, dataset_type='train', mask_type='root', batch_size=16, shuffle=True, seed=42, augment=False, augment_config=None):
 10def create_patch_dataset(patch_dir, dataset_type='train', mask_type='root', 
 11                        batch_size=16, shuffle=True, seed=42, augment=False,
 12                        augment_config=None):
 13    """Create a tf.data.Dataset for loading image and mask patches.
 14    
 15    Args:
 16        patch_dir: Directory containing saved patches.
 17        dataset_type: Either 'train' or 'val'. Default is 'train'.
 18        mask_type: Which mask to load ('root', 'shoot', or 'seed'). Default is 'root'.
 19        batch_size: Batch size for training. Default is 16.
 20        shuffle: Whether to shuffle the dataset. Default is True.
 21        seed: Random seed for reproducibility. Default is 42.
 22        augment: Whether to apply data augmentation. Default is False.
 23        augment_config: Dictionary of augmentation settings. If None, uses defaults.
 24            Options:
 25            - 'flip_left_right': bool, default True
 26            - 'flip_up_down': bool, default True
 27            - 'rotate': bool, default True (90 degree rotations only)
 28            - 'brightness': float, max delta for brightness adjustment, default 0.0
 29            - 'contrast': tuple, (lower, upper) contrast factor range, default (1.0, 1.0)
 30    
 31    Returns:
 32        tf.data.Dataset yielding (image_batch, mask_batch) tuples.
 33    
 34    Example:
 35        >>> config = {'flip_left_right': True, 'rotate': True, 'brightness': 0.1}
 36        >>> dataset = create_patch_dataset('../../data/patched', 'train', 
 37        ...                                augment=True, augment_config=config)
 38    """
 39    patch_dir = Path(patch_dir)
 40    
 41    # Default augmentation config
 42    if augment_config is None:
 43        augment_config = {
 44            'flip_left_right': True,
 45            'flip_up_down': True,
 46            'rotate': True,
 47            'brightness': 0.0,
 48            'contrast': (1.0, 1.0)
 49        }
 50    
 51    # Get paths to image and mask directories
 52    image_dir = patch_dir / f'{dataset_type}_images' / dataset_type
 53    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type
 54    
 55    # Get list of image files
 56    image_files = sorted(list(image_dir.glob('*.png')))
 57    
 58    if len(image_files) == 0:
 59        raise ValueError(f"No image files found in {image_dir}")
 60    
 61    # Create corresponding mask file paths
 62    mask_files = [mask_dir / img_file.name for img_file in image_files]
 63    
 64    # Convert to string paths for TensorFlow
 65    image_paths = [str(p) for p in image_files]
 66    mask_paths = [str(p) for p in mask_files]
 67    
 68    print(f"Found {len(image_paths)} patches")
 69    if augment:
 70        print(f"Augmentation enabled: {augment_config}")
 71    
 72    # Create dataset from file paths
 73    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
 74    
 75    # Shuffle if requested
 76    if shuffle:
 77        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=seed)
 78    
 79    # Map loading and preprocessing function
 80    if augment:
 81        def map_fn(img_path, mask_path):
 82            return _load_and_augment(img_path, mask_path, augment_config, seed)
 83        
 84        dataset = dataset.map(
 85            map_fn,
 86            num_parallel_calls=tf.data.AUTOTUNE
 87        )
 88    else:
 89        dataset = dataset.map(
 90            _load_and_preprocess,
 91            num_parallel_calls=tf.data.AUTOTUNE
 92        )
 93    
 94    # Batch the dataset
 95    dataset = dataset.batch(batch_size)
 96    
 97    # Prefetch for performance
 98    dataset = dataset.prefetch(tf.data.AUTOTUNE)
 99    
100    return dataset

Create a tf.data.Dataset for loading image and mask patches.

Arguments:
  • patch_dir: Directory containing saved patches.
  • dataset_type: Either 'train' or 'val'. Default is 'train'.
  • mask_type: Which mask to load ('root', 'shoot', or 'seed'). Default is 'root'.
  • batch_size: Batch size for training. Default is 16.
  • shuffle: Whether to shuffle the dataset. Default is True.
  • seed: Random seed for reproducibility. Default is 42.
  • augment: Whether to apply data augmentation. Default is False.
  • augment_config: Dictionary of augmentation settings. If None, uses defaults. Options:
    • 'flip_left_right': bool, default True
    • 'flip_up_down': bool, default True
    • 'rotate': bool, default True (90 degree rotations only)
    • 'brightness': float, max delta for brightness adjustment, default 0.0
    • 'contrast': tuple, (lower, upper) contrast factor range, default (1.0, 1.0)
Returns:

tf.data.Dataset yielding (image_batch, mask_batch) tuples.

Example:
>>> config = {'flip_left_right': True, 'rotate': True, 'brightness': 0.1}
>>> dataset = create_patch_dataset('../../data/patched', 'train', 
...                                augment=True, augment_config=config)
def create_filtered_patch_dataset( patch_dir, filtered_list_json, dataset_type='train', batch_size=16, shuffle=True, seed=42, augment=False, augment_config=None):
102def create_filtered_patch_dataset(patch_dir, filtered_list_json, dataset_type='train',
103                                 batch_size=16, shuffle=True, seed=42, 
104                                 augment=False, augment_config=None):
105    """Create a tf.data.Dataset using a filtered patch list.
106    
107    Args:
108        patch_dir: Directory containing saved patches.
109        filtered_list_json: Path to JSON file with filtered patch list.
110        dataset_type: Either 'train' or 'val'. Default is 'train'.
111        batch_size: Batch size for training. Default is 16.
112        shuffle: Whether to shuffle the dataset. Default is True.
113        seed: Random seed for reproducibility. Default is 42.
114        augment: Whether to apply data augmentation. Default is False.
115        augment_config: Dictionary of augmentation settings (same as before).
116    
117    Returns:
118        tf.data.Dataset yielding (image_batch, mask_batch) tuples.
119    
120    Example:
121        >>> dataset = create_filtered_patch_dataset(
122        ...     patch_dir='../../data/patched',
123        ...     filtered_list_json='filtered_lists/train_filtered_root.json',
124        ...     dataset_type='train',
125        ...     augment=True
126        ... )
127    """
128    patch_dir = Path(patch_dir)
129    
130    # Load filtered patch list
131    with open(filtered_list_json, 'r') as f:
132        filtered_data = json.load(f)
133    
134    mask_type = filtered_data['metadata']['mask_type']
135    patch_filenames = filtered_data['patch_filenames']
136    
137    # Default augmentation config
138    if augment_config is None:
139        augment_config = {
140            'flip_left_right': True,
141            'flip_up_down': True,
142            'rotate': True,
143            'brightness': 0.0,
144            'contrast': (1.0, 1.0)
145        }
146    
147    # Get paths to image and mask directories
148    image_dir = patch_dir / f'{dataset_type}_images' / dataset_type
149    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type
150    
151    # Create full file paths
152    image_paths = [str(image_dir / fname) for fname in patch_filenames]
153    mask_paths = [str(mask_dir / fname) for fname in patch_filenames]
154    
155    print(f"Loaded filtered list for {mask_type}")
156    print(f"Total patches: {filtered_data['statistics']['total_patches']}")
157    print(f"Non-empty: {filtered_data['statistics']['non_empty_patches']}")
158    print(f"Empty: {filtered_data['statistics']['empty_patches']}")
159    print(f"Empty ratio: {filtered_data['metadata']['actual_empty_ratio']*100:.1f}%")
160    
161    if augment:
162        print(f"Augmentation enabled: {augment_config}")
163    
164    # Create dataset from file paths
165    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
166    
167    # Shuffle if requested
168    if shuffle:
169        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=seed)
170    
171    # Map loading and preprocessing function
172    if augment:
173        def map_fn(img_path, mask_path):
174            return _load_and_augment(img_path, mask_path, augment_config, seed)
175        
176        dataset = dataset.map(
177            map_fn,
178            num_parallel_calls=tf.data.AUTOTUNE
179        )
180    else:
181        dataset = dataset.map(
182            _load_and_preprocess,
183            num_parallel_calls=tf.data.AUTOTUNE
184        )
185    
186    # Batch the dataset
187    dataset = dataset.batch(batch_size)
188    
189    # Prefetch for performance
190    dataset = dataset.prefetch(tf.data.AUTOTUNE)
191    
192    return dataset

Create a tf.data.Dataset using a filtered patch list.

Arguments:
  • patch_dir: Directory containing saved patches.
  • filtered_list_json: Path to JSON file with filtered patch list.
  • dataset_type: Either 'train' or 'val'. Default is 'train'.
  • batch_size: Batch size for training. Default is 16.
  • shuffle: Whether to shuffle the dataset. Default is True.
  • seed: Random seed for reproducibility. Default is 42.
  • augment: Whether to apply data augmentation. Default is False.
  • augment_config: Dictionary of augmentation settings (same options as create_patch_dataset).
Returns:

tf.data.Dataset yielding (image_batch, mask_batch) tuples.

Example:
>>> dataset = create_filtered_patch_dataset(
...     patch_dir='../../data/patched',
...     filtered_list_json='filtered_lists/train_filtered_root.json',
...     dataset_type='train',
...     augment=True
... )
def get_content_rich_patches( patch_dir, dataset_type='train', mask_type='root', n_samples=10, min_mask_coverage=0.01):
263def get_content_rich_patches(patch_dir, dataset_type='train', mask_type='root', 
264                             n_samples=10, min_mask_coverage=0.01):
265    """Get a list of patch filenames that contain meaningful mask content.
266    
267    Args:
268        patch_dir: Directory containing saved patches.
269        dataset_type: Either 'train' or 'val'. Default is 'train'.
270        mask_type: Which mask type to check. Default is 'root'.
271        n_samples: Number of content-rich patches to return. Default is 10.
272        min_mask_coverage: Minimum percentage of mask pixels (0-1). Default is 0.01 (1%).
273    
274    Returns:
275        List of patch filenames that have sufficient mask content.
276    """
277    import cv2
278    import random
279    
280    patch_dir = Path(patch_dir)
281    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type
282    
283    # Get all mask files
284    mask_files = list(mask_dir.glob('*.png'))
285    
286    # Find patches with content
287    content_patches = []
288    for mask_file in mask_files:
289        mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
290        mask_coverage = (mask > 0).sum() / mask.size
291        
292        if mask_coverage >= min_mask_coverage:
293            content_patches.append(mask_file.name)
294    
295    print(f"Found {len(content_patches)} patches with >{min_mask_coverage*100:.1f}% mask coverage")
296    
297    # Return random sample
298    return random.sample(content_patches, min(n_samples, len(content_patches)))

Get a list of patch filenames that contain meaningful mask content.

Arguments:
  • patch_dir: Directory containing saved patches.
  • dataset_type: Either 'train' or 'val'. Default is 'train'.
  • mask_type: Which mask type to check. Default is 'root'.
  • n_samples: Number of content-rich patches to return. Default is 10.
  • min_mask_coverage: Minimum percentage of mask pixels (0-1). Default is 0.01 (1%).
Returns:

List of patch filenames that have sufficient mask content.