library.tf_generators
TensorFlow data generators for patch-based training.
# library/tf_generators.py
"""TensorFlow data generators for patch-based training."""

import tensorflow as tf
from pathlib import Path
import json


def _default_augment_config():
    """Return the default augmentation settings: geometric on, photometric off."""
    return {
        'flip_left_right': True,
        'flip_up_down': True,
        'rotate': True,
        'brightness': 0.0,
        'contrast': (1.0, 1.0),
    }


def _build_pipeline(image_paths, mask_paths, batch_size, shuffle, seed,
                    augment, augment_config):
    """Assemble the shared tf.data pipeline for both dataset builders.

    Pipeline order: from_tensor_slices -> (shuffle) -> map(load [+augment])
    -> batch -> prefetch.

    Args:
        image_paths: List of image file path strings.
        mask_paths: List of mask file path strings, parallel to image_paths.
        batch_size: Batch size.
        shuffle: Whether to shuffle the dataset before mapping.
        seed: Random seed for the shuffle and augmentation ops.
        augment: Whether to apply data augmentation while loading.
        augment_config: Augmentation settings dict (used when augment=True).

    Returns:
        tf.data.Dataset yielding (image_batch, mask_batch) tuples.
    """
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if shuffle:
        # Buffer the whole list so the shuffle is uniform over all patches.
        dataset = dataset.shuffle(buffer_size=len(image_paths), seed=seed)

    if augment:
        def map_fn(img_path, mask_path):
            return _load_and_augment(img_path, mask_path, augment_config, seed)

        dataset = dataset.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        dataset = dataset.map(_load_and_preprocess,
                              num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.batch(batch_size)
    # Prefetch so loading overlaps with training.
    return dataset.prefetch(tf.data.AUTOTUNE)


def create_patch_dataset(patch_dir, dataset_type='train', mask_type='root',
                         batch_size=16, shuffle=True, seed=42, augment=False,
                         augment_config=None):
    """Create a tf.data.Dataset for loading image and mask patches.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        mask_type: Which mask to load ('root', 'shoot', or 'seed'). Default is 'root'.
        batch_size: Batch size for training. Default is 16.
        shuffle: Whether to shuffle the dataset. Default is True.
        seed: Random seed for reproducibility. Default is 42.
        augment: Whether to apply data augmentation. Default is False.
        augment_config: Dictionary of augmentation settings. If None, uses defaults.
            Options:
            - 'flip_left_right': bool, default True
            - 'flip_up_down': bool, default True
            - 'rotate': bool, default True (90 degree rotations only)
            - 'brightness': float, max delta for brightness adjustment, default 0.0
            - 'contrast': tuple, (lower, upper) contrast factor range, default (1.0, 1.0)

    Returns:
        tf.data.Dataset yielding (image_batch, mask_batch) tuples.

    Raises:
        ValueError: If no PNG images are found in the image directory.

    Example:
        >>> config = {'flip_left_right': True, 'rotate': True, 'brightness': 0.1}
        >>> dataset = create_patch_dataset('../../data/patched', 'train',
        ...                                augment=True, augment_config=config)
    """
    patch_dir = Path(patch_dir)

    if augment_config is None:
        augment_config = _default_augment_config()

    # Patches live under <patch_dir>/<split>_images/<split> and
    # <patch_dir>/<split>_masks_<mask_type>/<split>.
    image_dir = patch_dir / f'{dataset_type}_images' / dataset_type
    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type

    # Sort for a deterministic ordering before the (seeded) shuffle.
    image_files = sorted(image_dir.glob('*.png'))
    if not image_files:
        raise ValueError(f"No image files found in {image_dir}")

    # Masks share filenames with their corresponding images.
    image_paths = [str(p) for p in image_files]
    mask_paths = [str(mask_dir / p.name) for p in image_files]

    print(f"Found {len(image_paths)} patches")
    if augment:
        print(f"Augmentation enabled: {augment_config}")

    return _build_pipeline(image_paths, mask_paths, batch_size, shuffle, seed,
                           augment, augment_config)


def create_filtered_patch_dataset(patch_dir, filtered_list_json, dataset_type='train',
                                  batch_size=16, shuffle=True, seed=42,
                                  augment=False, augment_config=None):
    """Create a tf.data.Dataset using a filtered patch list.

    Args:
        patch_dir: Directory containing saved patches.
        filtered_list_json: Path to JSON file with filtered patch list. The file
            must contain 'metadata' (with 'mask_type' and 'actual_empty_ratio'),
            'patch_filenames', and 'statistics' entries.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        batch_size: Batch size for training. Default is 16.
        shuffle: Whether to shuffle the dataset. Default is True.
        seed: Random seed for reproducibility. Default is 42.
        augment: Whether to apply data augmentation. Default is False.
        augment_config: Dictionary of augmentation settings (same options as
            create_patch_dataset). If None, uses defaults.

    Returns:
        tf.data.Dataset yielding (image_batch, mask_batch) tuples.

    Example:
        >>> dataset = create_filtered_patch_dataset(
        ...     patch_dir='../../data/patched',
        ...     filtered_list_json='filtered_lists/train_filtered_root.json',
        ...     dataset_type='train',
        ...     augment=True
        ... )
    """
    patch_dir = Path(patch_dir)

    # The filtered list pins both the mask type and the exact patch filenames.
    with open(filtered_list_json, 'r') as f:
        filtered_data = json.load(f)

    mask_type = filtered_data['metadata']['mask_type']
    patch_filenames = filtered_data['patch_filenames']

    if augment_config is None:
        augment_config = _default_augment_config()

    image_dir = patch_dir / f'{dataset_type}_images' / dataset_type
    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type

    image_paths = [str(image_dir / fname) for fname in patch_filenames]
    mask_paths = [str(mask_dir / fname) for fname in patch_filenames]

    print(f"Loaded filtered list for {mask_type}")
    print(f"Total patches: {filtered_data['statistics']['total_patches']}")
    print(f"Non-empty: {filtered_data['statistics']['non_empty_patches']}")
    print(f"Empty: {filtered_data['statistics']['empty_patches']}")
    print(f"Empty ratio: {filtered_data['metadata']['actual_empty_ratio']*100:.1f}%")

    if augment:
        print(f"Augmentation enabled: {augment_config}")

    return _build_pipeline(image_paths, mask_paths, batch_size, shuffle, seed,
                           augment, augment_config)


def _load_and_augment(image_path, mask_path, augment_config, seed):
    """Load and augment a single image-mask pair.

    Geometric transforms (flips, 90-degree rotations) are applied identically
    to image and mask; photometric transforms (brightness, contrast) are
    applied to the image only.

    Args:
        image_path: Path to image file.
        mask_path: Path to mask file.
        augment_config: Dictionary of augmentation settings.
        seed: Random seed.

    Returns:
        Tuple of (augmented_image, augmented_mask) tensors.
    """
    image, mask = _load_and_preprocess(image_path, mask_path)

    # Random flips — the same do_flip decision drives both tensors so the
    # image and its mask stay geometrically aligned.
    # NOTE(review): a fixed op-level seed is reused for every element; confirm
    # this produces the intended per-element randomness under tf.data.map.
    if augment_config.get('flip_left_right', False):
        do_flip = tf.random.uniform((), seed=seed) > 0.5
        image = tf.cond(do_flip, lambda: tf.image.flip_left_right(image), lambda: image)
        mask = tf.cond(do_flip, lambda: tf.image.flip_left_right(mask), lambda: mask)

    if augment_config.get('flip_up_down', False):
        do_flip = tf.random.uniform((), seed=seed) > 0.5
        image = tf.cond(do_flip, lambda: tf.image.flip_up_down(image), lambda: image)
        mask = tf.cond(do_flip, lambda: tf.image.flip_up_down(mask), lambda: mask)

    # Random 90 degree rotations (k in {0, 1, 2, 3} quarter-turns).
    if augment_config.get('rotate', False):
        k = tf.random.uniform((), minval=0, maxval=4, dtype=tf.int32, seed=seed)
        image = tf.image.rot90(image, k=k)
        mask = tf.image.rot90(mask, k=k)

    # Photometric augmentations (image only); re-clip to keep values in [0, 1].
    brightness_delta = augment_config.get('brightness', 0.0)
    if brightness_delta > 0:
        image = tf.image.random_brightness(image, brightness_delta, seed=seed)
        image = tf.clip_by_value(image, 0.0, 1.0)

    contrast_range = augment_config.get('contrast', (1.0, 1.0))
    if contrast_range[0] != 1.0 or contrast_range[1] != 1.0:
        image = tf.image.random_contrast(image, contrast_range[0], contrast_range[1], seed=seed)
        image = tf.clip_by_value(image, 0.0, 1.0)

    return image, mask


def _load_and_preprocess(image_path, mask_path):
    """Load and preprocess a single image-mask pair.

    Args:
        image_path: Path to image file (single-channel PNG).
        mask_path: Path to mask file (single-channel PNG).

    Returns:
        Tuple of (image, mask) float32 tensors; the image is scaled to [0, 1].
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=1)
    image = tf.cast(image, tf.float32) / 255.0  # Normalize to [0, 1]

    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    # NOTE(review): cast only, no rescaling — assumes masks are saved with
    # pixel values 0/1 rather than 0/255; confirm against the patch writer.
    mask = tf.cast(mask, tf.float32)  # Keep as 0 and 1

    return image, mask


def get_content_rich_patches(patch_dir, dataset_type='train', mask_type='root',
                             n_samples=10, min_mask_coverage=0.01):
    """Get a list of patch filenames that contain meaningful mask content.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        mask_type: Which mask type to check. Default is 'root'.
        n_samples: Number of content-rich patches to return. Default is 10.
        min_mask_coverage: Minimum fraction of mask pixels (0-1). Default is 0.01 (1%).

    Returns:
        List of patch filenames that have sufficient mask content (a random
        sample of at most n_samples names).
    """
    # Local imports: cv2/random are only needed for this utility.
    import cv2
    import random

    patch_dir = Path(patch_dir)
    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type

    mask_files = list(mask_dir.glob('*.png'))

    content_patches = []
    for mask_file in mask_files:
        mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None (not an exception) for unreadable files;
            # skip them instead of crashing on `mask > 0`.
            continue
        mask_coverage = (mask > 0).sum() / mask.size

        if mask_coverage >= min_mask_coverage:
            content_patches.append(mask_file.name)

    print(f"Found {len(content_patches)} patches with >{min_mask_coverage*100:.1f}% mask coverage")

    # Sample without replacement, capped at the number of matches found.
    return random.sample(content_patches, min(n_samples, len(content_patches)))
def create_patch_dataset(patch_dir, dataset_type='train', mask_type='root',
                         batch_size=16, shuffle=True, seed=42, augment=False,
                         augment_config=None):
    """Build a tf.data.Dataset of (image, mask) patch pairs from disk.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        mask_type: Which mask to load ('root', 'shoot', or 'seed'). Default is 'root'.
        batch_size: Batch size for training. Default is 16.
        shuffle: Whether to shuffle the dataset. Default is True.
        seed: Random seed for reproducibility. Default is 42.
        augment: Whether to apply data augmentation. Default is False.
        augment_config: Dictionary of augmentation settings. If None, uses defaults.
            Options:
            - 'flip_left_right': bool, default True
            - 'flip_up_down': bool, default True
            - 'rotate': bool, default True (90 degree rotations only)
            - 'brightness': float, max delta for brightness adjustment, default 0.0
            - 'contrast': tuple, (lower, upper) contrast factor range, default (1.0, 1.0)

    Returns:
        tf.data.Dataset yielding (image_batch, mask_batch) tuples.

    Raises:
        ValueError: If no PNG images are found for the requested split.

    Example:
        >>> config = {'flip_left_right': True, 'rotate': True, 'brightness': 0.1}
        >>> dataset = create_patch_dataset('../../data/patched', 'train',
        ...                                augment=True, augment_config=config)
    """
    root = Path(patch_dir)

    # Fall back to the default augmentation settings when none are given.
    if augment_config is None:
        augment_config = dict(
            flip_left_right=True,
            flip_up_down=True,
            rotate=True,
            brightness=0.0,
            contrast=(1.0, 1.0),
        )

    img_dir = root / f'{dataset_type}_images' / dataset_type
    msk_dir = root / f'{dataset_type}_masks_{mask_type}' / dataset_type

    # Deterministic file order; shuffling (if any) is seeded below.
    img_files = sorted(img_dir.glob('*.png'))
    if not img_files:
        raise ValueError(f"No image files found in {img_dir}")

    # A mask file carries the same name as its image.
    image_paths = [str(f) for f in img_files]
    mask_paths = [str(msk_dir / f.name) for f in img_files]

    print(f"Found {len(image_paths)} patches")
    if augment:
        print(f"Augmentation enabled: {augment_config}")

    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if shuffle:
        ds = ds.shuffle(buffer_size=len(image_paths), seed=seed)

    if augment:
        ds = ds.map(
            lambda ip, mp: _load_and_augment(ip, mp, augment_config, seed),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    else:
        ds = ds.map(_load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch, then prefetch so loading overlaps training.
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
Create a tf.data.Dataset for loading image and mask patches.
Arguments:
- patch_dir: Directory containing saved patches.
- dataset_type: Either 'train' or 'val'. Default is 'train'.
- mask_type: Which mask to load ('root', 'shoot', or 'seed'). Default is 'root'.
- batch_size: Batch size for training. Default is 16.
- shuffle: Whether to shuffle the dataset. Default is True.
- seed: Random seed for reproducibility. Default is 42.
- augment: Whether to apply data augmentation. Default is False.
- augment_config: Dictionary of augmentation settings. If None, uses defaults.
Options:
- 'flip_left_right': bool, default True
- 'flip_up_down': bool, default True
- 'rotate': bool, default True (90 degree rotations only)
- 'brightness': float, max delta for brightness adjustment, default 0.0
- 'contrast': tuple, (lower, upper) contrast factor range, default (1.0, 1.0)
Returns:
tf.data.Dataset yielding (image_batch, mask_batch) tuples.
Example:
>>> config = {'flip_left_right': True, 'rotate': True, 'brightness': 0.1} >>> dataset = create_patch_dataset('../../data/patched', 'train', ... augment=True, augment_config=config)
def create_filtered_patch_dataset(patch_dir, filtered_list_json, dataset_type='train',
                                  batch_size=16, shuffle=True, seed=42,
                                  augment=False, augment_config=None):
    """Build a tf.data.Dataset restricted to a pre-filtered patch list.

    Args:
        patch_dir: Directory containing saved patches.
        filtered_list_json: Path to JSON file with filtered patch list.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        batch_size: Batch size for training. Default is 16.
        shuffle: Whether to shuffle the dataset. Default is True.
        seed: Random seed for reproducibility. Default is 42.
        augment: Whether to apply data augmentation. Default is False.
        augment_config: Dictionary of augmentation settings (same options as
            create_patch_dataset). If None, uses defaults.

    Returns:
        tf.data.Dataset yielding (image_batch, mask_batch) tuples.

    Example:
        >>> dataset = create_filtered_patch_dataset(
        ...     patch_dir='../../data/patched',
        ...     filtered_list_json='filtered_lists/train_filtered_root.json',
        ...     dataset_type='train',
        ...     augment=True
        ... )
    """
    root = Path(patch_dir)

    # The JSON fixes both which mask type to use and the exact patch names.
    with open(filtered_list_json, 'r') as fh:
        listing = json.load(fh)

    mask_type = listing['metadata']['mask_type']
    names = listing['patch_filenames']

    if augment_config is None:
        augment_config = dict(
            flip_left_right=True,
            flip_up_down=True,
            rotate=True,
            brightness=0.0,
            contrast=(1.0, 1.0),
        )

    img_dir = root / f'{dataset_type}_images' / dataset_type
    msk_dir = root / f'{dataset_type}_masks_{mask_type}' / dataset_type

    image_paths = [str(img_dir / name) for name in names]
    mask_paths = [str(msk_dir / name) for name in names]

    print(f"Loaded filtered list for {mask_type}")
    print(f"Total patches: {listing['statistics']['total_patches']}")
    print(f"Non-empty: {listing['statistics']['non_empty_patches']}")
    print(f"Empty: {listing['statistics']['empty_patches']}")
    print(f"Empty ratio: {listing['metadata']['actual_empty_ratio']*100:.1f}%")

    if augment:
        print(f"Augmentation enabled: {augment_config}")

    ds = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if shuffle:
        ds = ds.shuffle(buffer_size=len(image_paths), seed=seed)

    if augment:
        ds = ds.map(
            lambda ip, mp: _load_and_augment(ip, mp, augment_config, seed),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    else:
        ds = ds.map(_load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch, then prefetch so loading overlaps training.
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
Create a tf.data.Dataset using a filtered patch list.
Arguments:
- patch_dir: Directory containing saved patches.
- filtered_list_json: Path to JSON file with filtered patch list.
- dataset_type: Either 'train' or 'val'. Default is 'train'.
- batch_size: Batch size for training. Default is 16.
- shuffle: Whether to shuffle the dataset. Default is True.
- seed: Random seed for reproducibility. Default is 42.
- augment: Whether to apply data augmentation. Default is False.
- augment_config: Dictionary of augmentation settings (same options as create_patch_dataset).
Returns:
tf.data.Dataset yielding (image_batch, mask_batch) tuples.
Example:
>>> dataset = create_filtered_patch_dataset( ... patch_dir='../../data/patched', ... filtered_list_json='filtered_lists/train_filtered_root.json', ... dataset_type='train', ... augment=True ... )
def get_content_rich_patches(patch_dir, dataset_type='train', mask_type='root',
                             n_samples=10, min_mask_coverage=0.01):
    """Get a list of patch filenames that contain meaningful mask content.

    Args:
        patch_dir: Directory containing saved patches.
        dataset_type: Either 'train' or 'val'. Default is 'train'.
        mask_type: Which mask type to check. Default is 'root'.
        n_samples: Number of content-rich patches to return. Default is 10.
        min_mask_coverage: Minimum fraction of mask pixels (0-1). Default is 0.01 (1%).

    Returns:
        List of patch filenames that have sufficient mask content (a random
        sample of at most n_samples names).
    """
    # Local imports: cv2/random are only needed for this utility.
    import cv2
    import random

    patch_dir = Path(patch_dir)
    mask_dir = patch_dir / f'{dataset_type}_masks_{mask_type}' / dataset_type

    mask_files = list(mask_dir.glob('*.png'))

    # Keep only patches whose mask coverage meets the threshold.
    content_patches = []
    for mask_file in mask_files:
        mask = cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread returns None (not an exception) for unreadable files;
            # skip them instead of crashing on `mask > 0`.
            continue
        mask_coverage = (mask > 0).sum() / mask.size

        if mask_coverage >= min_mask_coverage:
            content_patches.append(mask_file.name)

    print(f"Found {len(content_patches)} patches with >{min_mask_coverage*100:.1f}% mask coverage")

    # Sample without replacement, capped at the number of matches found.
    return random.sample(content_patches, min(n_samples, len(content_patches)))
Get a list of patch filenames that contain meaningful mask content.
Arguments:
- patch_dir: Directory containing saved patches.
- dataset_type: Either 'train' or 'val'. Default is 'train'.
- mask_type: Which mask type to check. Default is 'root'.
- n_samples: Number of content-rich patches to return. Default is 10.
- min_mask_coverage: Minimum fraction of mask pixels, in the range 0-1. Default is 0.01 (i.e. 1%).
Returns:
List of patch filenames that have sufficient mask content.