library.run_inference
Inference module for segmentation model predictions.
This module provides functions to run inference on images using trained segmentation models. It handles patch-based prediction with overlap, ROI detection, and reconstruction of full-resolution masks.
"""Inference module for segmentation model predictions.

This module provides functions to run inference on images using trained
segmentation models. It handles patch-based prediction with overlap,
ROI detection, and reconstruction of full-resolution masks.
"""

import platform
from pathlib import Path
from typing import Union, List, Tuple, Dict

import numpy as np
import cv2
import tensorflow as tf
from tensorflow.keras.models import load_model
import keras.backend as K
from patchify import patchify

from library.roi import detect_roi, crop_to_roi
from library.patch_dataset import padder, restore_mask_to_original


def f1(y_true, y_pred):
    """Calculate F1 score metric for binary segmentation.

    Computes F1 score as the harmonic mean of precision and recall.

    Args:
        y_true: Ground truth binary labels.
        y_pred: Predicted probabilities or binary labels.

    Returns:
        F1 score as a scalar tensor.
    """
    def recall_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        return TP / (Positives + K.epsilon())

    def precision_m(y_true, y_pred):
        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        Pred_Positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        return TP / (Pred_Positives + K.epsilon())

    precision, recall = precision_m(y_true, y_pred), recall_m(y_true, y_pred)

    # epsilon() guards against division by zero when both terms are 0.
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))


def configure_tensorflow_for_platform():
    """Configure TensorFlow based on the current platform.

    Detects if running on Mac with Metal support, Linux/Windows with CUDA,
    or CPU-only. Enables memory growth for GPU devices on macOS so TF does
    not grab all GPU memory up front.

    Returns:
        str: Device type - 'metal', 'cuda', or 'cpu'.
    """
    system = platform.system()

    if system == 'Darwin':
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                for gpu in gpus:
                    tf.config.experimental.set_memory_growth(gpu, True)
                return 'metal'
            except RuntimeError as e:
                # set_memory_growth raises if devices are already initialized.
                print(f"GPU configuration error: {e}")
                return 'cpu'
        return 'cpu'
    else:
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            return 'cuda'
        return 'cpu'


def load_segmentation_model(model_path, verbose=True):
    """Load a Keras segmentation model with custom F1 metric.

    Automatically detects platform and configures TensorFlow before loading
    the model.

    Args:
        model_path: Path to the saved model file (.h5 or .keras).
        verbose: If True, print device information.

    Returns:
        Loaded Keras model ready for inference.
    """
    device = configure_tensorflow_for_platform()

    if verbose:
        print(f"Using device: {device}")

    # compile=False: metrics/optimizer are not needed for pure inference,
    # but the custom "f1" symbol must still resolve during deserialization.
    model = load_model(model_path, custom_objects={"f1": f1}, compile=False)

    return model


def predict_patches_batched(model, patches, patch_size, batch_size=8, verbose=True):
    """Predict on all patches using batched processing.

    Processes patches in batches to manage memory usage effectively.

    Args:
        model: Loaded Keras model.
        patches: Patchified image array with shape (n_rows, n_cols, 1, H, W, C).
        patch_size: Size of each square patch.
        batch_size: Number of patches to process at once.
        verbose: If True, print progress information.

    Returns:
        Array of predicted patches with shape (n_rows, n_cols, 1, H, W, 1),
        with values as uint8 (0 or 255).
    """
    n_rows, n_cols = patches.shape[0], patches.shape[1]

    predicted_patches = np.zeros((n_rows, n_cols, 1, patch_size, patch_size, 1), dtype=np.uint8)

    all_patches = []
    patch_positions = []

    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            patch = patches[row_idx, col_idx, 0]
            # Scale uint8 pixel values into [0, 1] as expected by the model.
            patch_normalized = patch.astype(np.float32) / 255.0
            all_patches.append(patch_normalized)
            patch_positions.append((row_idx, col_idx))

    all_patches = np.array(all_patches)
    total_patches = len(all_patches)

    if verbose:
        print(f"Total patches: {total_patches}")

    # Ceiling division so a partial final batch is still processed.
    num_batches = (total_patches + batch_size - 1) // batch_size

    if verbose:
        print(f"Processing in {num_batches} batches of size {batch_size}")

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, total_patches)

        batch_patches = all_patches[start_idx:end_idx]
        batch_predictions = model.predict(batch_patches, verbose=0)

        for i, pred in enumerate(batch_predictions):
            patch_idx = start_idx + i
            row_idx, col_idx = patch_positions[patch_idx]
            # Threshold sigmoid output at 0.5 and map to {0, 255}.
            pred_binary = (pred > 0.5).astype(np.uint8) * 255
            predicted_patches[row_idx, col_idx, 0] = pred_binary

        # NOTE(review): clear_session() resets Keras global state every 5
        # batches to cap memory growth; presumably safe with this model/TF
        # version, but clearing while a loaded model is in use can invalidate
        # it on some TF versions — confirm before upgrading TF.
        if (batch_idx + 1) % 5 == 0:
            tf.keras.backend.clear_session()

        if verbose:
            print(f"Batch {batch_idx + 1}/{num_batches} ({end_idx}/{total_patches} patches)", end='\r')

    if verbose:
        print()

    return predicted_patches


def prepare_image_for_prediction(image, patch_size, step_size):
    """Prepare an image for patch-based prediction.

    Performs preprocessing including ROI detection, cropping, padding,
    and patch extraction.

    Args:
        image: Input grayscale image as numpy array with shape (H, W).
        patch_size: Size of square patches.
        step_size: Step size for patch extraction (controls overlap).

    Returns:
        Dictionary containing:
            - 'patches': Patchified array with shape (n_rows, n_cols, 1, H, W, C)
            - 'roi_box': ROI coordinates as ((x1, y1), (x2, y2))
            - 'padded_shape': Shape of padded image (H, W, C)
            - 'original_shape': Shape of original input image (H, W)
    """
    original_shape = image.shape

    roi_box = detect_roi(image)
    cropped_image = crop_to_roi(image, roi_box)

    # Pad so the cropped image tiles exactly into patch_size/step_size grid.
    padded_image = padder(cropped_image, patch_size, step_size)

    if len(padded_image.shape) == 2:
        # patchify below requires an explicit channel axis.
        padded_image = np.expand_dims(padded_image, axis=-1)

    padded_shape = padded_image.shape

    patches = patchify(padded_image, (patch_size, patch_size, 1), step=step_size)

    return {
        'patches': patches,
        'roi_box': roi_box,
        'padded_shape': padded_shape,
        'original_shape': original_shape
    }


def unpatchify_with_overlap(patches, target_shape, patch_size, step_size):
    """Reconstruct full image from overlapping patches.

    Uses averaging for overlapping regions to create smooth transitions.

    Args:
        patches: Predicted patches with shape (n_rows, n_cols, 1, H, W, 1).
        target_shape: Target shape for reconstructed image (H, W, C).
        patch_size: Size of square patches.
        step_size: Step size used during patch extraction.

    Returns:
        Reconstructed image as numpy array with shape (H, W).
    """
    n_rows, n_cols = patches.shape[0], patches.shape[1]

    reconstructed = np.zeros((target_shape[0], target_shape[1]), dtype=np.float32)
    count_map = np.zeros((target_shape[0], target_shape[1]), dtype=np.float32)

    for row_idx in range(n_rows):
        for col_idx in range(n_cols):
            patch = patches[row_idx, col_idx, 0, :, :, 0]

            start_row = row_idx * step_size
            start_col = col_idx * step_size
            # Clamp to the target so edge patches do not write out of bounds.
            end_row = min(start_row + patch_size, target_shape[0])
            end_col = min(start_col + patch_size, target_shape[1])

            patch_height = end_row - start_row
            patch_width = end_col - start_col

            reconstructed[start_row:end_row, start_col:end_col] += patch[:patch_height, :patch_width]
            count_map[start_row:end_row, start_col:end_col] += 1

    # Avoid division by zero for any pixel no patch covered.
    count_map[count_map == 0] = 1
    reconstructed = reconstructed / count_map

    return reconstructed.astype(np.uint8)


def predict_single_image(model, image, patch_size, step_size, batch_size=8, verbose=True):
    """Run complete inference pipeline on a single image.

    Args:
        model: Loaded Keras model.
        image: Input grayscale image as numpy array with shape (H, W).
        patch_size: Size of square patches for prediction.
        step_size: Step size for patch extraction (controls overlap).
        batch_size: Number of patches to process at once.
        verbose: If True, print progress information.

    Returns:
        Binary mask as numpy array with shape (H, W) matching input image,
        with values 0 (background) and 255 (foreground).
    """
    if verbose:
        print(f"Processing image with shape {image.shape}")

    result = prepare_image_for_prediction(image, patch_size, step_size)
    if verbose:
        print(f"Created {result['patches'].shape[0]}x{result['patches'].shape[1]} patch grid")

    predicted_patches = predict_patches_batched(
        model, result['patches'], patch_size, batch_size, verbose
    )

    if verbose:
        print("Reconstructing image from patches...")
    reconstructed = unpatchify_with_overlap(
        predicted_patches, result['padded_shape'], patch_size, step_size
    )

    # Map the ROI-cropped mask back into full original-image coordinates.
    full_mask = restore_mask_to_original(
        reconstructed, result['original_shape'], result['roi_box']
    )

    if verbose:
        print(f"Final mask shape: {full_mask.shape}")

    return full_mask


def run_inference(
    model_path: Union[str, Path],
    image_paths: List[Union[str, Path]],
    output_dir: Union[str, Path],
    mask_type: str,
    patch_size: int = 256,
    step_size: int = 128,
    batch_size: int = 8,
    verbose: bool = True
) -> Tuple[int, Path]:
    """Run inference on multiple images using a trained segmentation model.

    This is the main convenience function that handles model loading, image
    processing, and saving predictions. Output files are saved to a directory
    named after the model stem.

    Args:
        model_path: Path to the trained model file (.h5 or .keras).
        image_paths: List of paths to input images (as Path objects or strings).
        output_dir: Base directory where outputs will be saved.
        mask_type: Type of mask being predicted (e.g., 'root', 'shoot').
        patch_size: Size of square patches for prediction. Default is 256.
        step_size: Step size for patch extraction. Default is 128.
        batch_size: Number of patches to process at once. Default is 8.
        verbose: If True, print progress information. Default is True.

    Returns:
        Tuple of (number of images successfully processed, path to the
        directory where the predicted masks were written).

    Output structure:
        output_dir/
            model_name/
                image1_mask_type.png
                image2_mask_type.png

    Example:
        >>> image_list = ['img1.png', 'img2.png', 'img3.png']
        >>> n, out_dir = run_inference(
        ...     model_path='shoots.h5',
        ...     image_paths=image_list,
        ...     output_dir='./predictions',
        ...     mask_type='shoot',
        ...     patch_size=256,
        ...     step_size=128
        ... )
        >>> print(f"Processed {n} images")
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)
    image_paths = [Path(p) for p in image_paths]

    model_name = model_path.stem

    model_output_dir = output_dir / model_name
    model_output_dir.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"Loading model from {model_path}")
    model = load_segmentation_model(model_path, verbose=verbose)

    if verbose:
        print(f"Processing {len(image_paths)} images")
        print(f"Output directory: {model_output_dir}")

    processed_count = 0

    for i, img_path in enumerate(image_paths, 1):
        if not img_path.exists():
            print(f"Warning: {img_path} does not exist, skipping")
            continue

        if verbose:
            print(f"\n[{i}/{len(image_paths)}] Processing {img_path.name}")

        # Flag 0 -> load as single-channel grayscale.
        image = cv2.imread(str(img_path), 0)

        if image is None:
            print(f"Warning: Failed to load {img_path}, skipping")
            continue

        mask = predict_single_image(model, image, patch_size, step_size, batch_size, verbose)

        output_filename = f"{img_path.stem}_{mask_type}.png"
        output_path = model_output_dir / output_filename

        cv2.imwrite(str(output_path), mask)

        if verbose:
            print(f"Saved mask to {output_path}")

        processed_count += 1

    if verbose:
        print(f"\nCompleted processing {processed_count}/{len(image_paths)} images")

    return processed_count, model_output_dir
def f1(y_true, y_pred):
    """Compute the F1 score for binary segmentation.

    The score is the harmonic mean of precision and recall, each computed
    from rounded, clipped predictions.

    Args:
        y_true: Ground truth binary labels.
        y_pred: Predicted probabilities or binary labels.

    Returns:
        F1 score as a scalar tensor.
    """
    def _precision(truth, pred):
        true_pos = K.sum(K.round(K.clip(truth * pred, 0, 1)))
        predicted_pos = K.sum(K.round(K.clip(pred, 0, 1)))
        return true_pos / (predicted_pos + K.epsilon())

    def _recall(truth, pred):
        true_pos = K.sum(K.round(K.clip(truth * pred, 0, 1)))
        actual_pos = K.sum(K.round(K.clip(truth, 0, 1)))
        return true_pos / (actual_pos + K.epsilon())

    p = _precision(y_true, y_pred)
    r = _recall(y_true, y_pred)

    # epsilon() keeps the denominator non-zero when p == r == 0.
    return 2 * ((p * r) / (p + r + K.epsilon()))
Calculate F1 score metric for binary segmentation.
Computes F1 score as the harmonic mean of precision and recall.
Arguments:
- y_true: Ground truth binary labels.
- y_pred: Predicted probabilities or binary labels.
Returns:
F1 score as a scalar tensor.
def configure_tensorflow_for_platform():
    """Configure TensorFlow based on the current platform.

    Detects whether we are on a Mac (Metal), on Linux/Windows with a CUDA
    GPU, or CPU-only, and enables memory growth for GPUs on macOS.

    Returns:
        str: Device type - 'metal', 'cuda', or 'cpu'.
    """
    gpus = tf.config.list_physical_devices('GPU')

    if platform.system() != 'Darwin':
        # Linux/Windows: any visible GPU is assumed to be CUDA-capable.
        return 'cuda' if gpus else 'cpu'

    if not gpus:
        return 'cpu'

    try:
        for gpu in gpus:
            # Grow GPU memory on demand instead of reserving it all upfront.
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(f"GPU configuration error: {e}")
        return 'cpu'
    return 'metal'
Configure TensorFlow based on the current platform.
Detects if running on Mac with Metal support, Linux/Windows with CUDA, or CPU-only. Configures TensorFlow accordingly and enables memory growth for GPU devices.
Returns:
str: Device type - 'metal', 'cuda', or 'cpu'.
def load_segmentation_model(model_path, verbose=True):
    """Load a Keras segmentation model with custom F1 metric.

    Configures TensorFlow for the current platform before loading.

    Args:
        model_path: Path to the saved model file (.h5 or .keras).
        verbose: If True, print device information.

    Returns:
        Loaded Keras model ready for inference.
    """
    device = configure_tensorflow_for_platform()

    if verbose:
        print(f"Using device: {device}")

    # compile=False: no training state needed; "f1" must still resolve
    # so deserialization of the saved model succeeds.
    return load_model(model_path, custom_objects={"f1": f1}, compile=False)
Load a Keras segmentation model with custom F1 metric.
Automatically detects platform and configures TensorFlow before loading the model.
Arguments:
- model_path: Path to the saved model file (.h5 or .keras).
- verbose: If True, print device information.
Returns:
Loaded Keras model ready for inference.
def predict_patches_batched(model, patches, patch_size, batch_size=8, verbose=True):
    """Predict on all patches using batched processing.

    Flattens the patch grid, runs the model batch by batch to bound memory
    use, and scatters the thresholded predictions back into grid order.

    Args:
        model: Loaded Keras model.
        patches: Patchified image array with shape (n_rows, n_cols, 1, H, W, C).
        patch_size: Size of each square patch.
        batch_size: Number of patches to process at once.
        verbose: If True, print progress information.

    Returns:
        Array of predicted patches with shape (n_rows, n_cols, 1, H, W, 1),
        with values as uint8 (0 or 255).
    """
    n_rows, n_cols = patches.shape[:2]

    output = np.zeros((n_rows, n_cols, 1, patch_size, patch_size, 1), dtype=np.uint8)

    # Flatten grid into (index -> (row, col)) order, normalizing to [0, 1].
    positions = [(r, c) for r in range(n_rows) for c in range(n_cols)]
    stacked = np.array([patches[r, c, 0].astype(np.float32) / 255.0
                        for r, c in positions])
    total = len(stacked)

    if verbose:
        print(f"Total patches: {total}")

    # Ceiling division so a short final batch is still handled.
    n_batches = (total + batch_size - 1) // batch_size

    if verbose:
        print(f"Processing in {n_batches} batches of size {batch_size}")

    for b in range(n_batches):
        lo = b * batch_size
        hi = min(lo + batch_size, total)

        preds = model.predict(stacked[lo:hi], verbose=0)

        for offset, pred in enumerate(preds):
            r, c = positions[lo + offset]
            # Threshold at 0.5 and map {0, 1} -> {0, 255}.
            output[r, c, 0] = (pred > 0.5).astype(np.uint8) * 255

        # Periodically reset Keras global state to cap memory growth.
        if (b + 1) % 5 == 0:
            tf.keras.backend.clear_session()

        if verbose:
            print(f"Batch {b + 1}/{n_batches} ({hi}/{total} patches)", end='\r')

    if verbose:
        print()

    return output
Predict on all patches using batched processing.
Processes patches in batches to manage memory usage effectively.
Arguments:
- model: Loaded Keras model.
- patches: Patchified image array with shape (n_rows, n_cols, 1, H, W, C).
- patch_size: Size of each square patch.
- batch_size: Number of patches to process at once.
- verbose: If True, print progress information.
Returns:
Array of predicted patches with shape (n_rows, n_cols, 1, H, W, 1), with values as uint8 (0 or 255).
def prepare_image_for_prediction(image, patch_size, step_size):
    """Prepare an image for patch-based prediction.

    Detects the ROI, crops to it, pads so the crop tiles evenly, and
    extracts overlapping patches.

    Args:
        image: Input grayscale image as numpy array with shape (H, W).
        patch_size: Size of square patches.
        step_size: Step size for patch extraction (controls overlap).

    Returns:
        Dictionary containing:
            - 'patches': Patchified array with shape (n_rows, n_cols, 1, H, W, C)
            - 'roi_box': ROI coordinates as ((x1, y1), (x2, y2))
            - 'padded_shape': Shape of padded image (H, W, C)
            - 'original_shape': Shape of original input image (H, W)
    """
    roi_box = detect_roi(image)
    padded = padder(crop_to_roi(image, roi_box), patch_size, step_size)

    # patchify needs an explicit channel axis on grayscale input.
    if padded.ndim == 2:
        padded = padded[:, :, np.newaxis]

    return {
        'patches': patchify(padded, (patch_size, patch_size, 1), step=step_size),
        'roi_box': roi_box,
        'padded_shape': padded.shape,
        'original_shape': image.shape,
    }
Prepare an image for patch-based prediction.
Performs preprocessing including ROI detection, cropping, padding, and patch extraction.
Arguments:
- image: Input grayscale image as numpy array with shape (H, W).
- patch_size: Size of square patches.
- step_size: Step size for patch extraction (controls overlap).
Returns:
Dictionary containing: 'patches' (patchified array with shape (n_rows, n_cols, 1, H, W, C)), 'roi_box' (ROI coordinates as ((x1, y1), (x2, y2))), 'padded_shape' (shape of the padded image (H, W, C)), and 'original_shape' (shape of the original input image (H, W)).
def unpatchify_with_overlap(patches, target_shape, patch_size, step_size):
    """Rebuild a full-size image from an overlapping patch grid.

    Pixels covered by several patches are averaged so seams blend smoothly.

    Args:
        patches: Predicted patches with shape (n_rows, n_cols, 1, H, W, 1).
        target_shape: Target shape (H, W, C); only H and W are used.
        patch_size: Edge length of each square patch.
        step_size: Stride used when the patches were extracted.

    Returns:
        uint8 array of shape (H, W) holding the averaged reconstruction.
    """
    height, width = target_shape[0], target_shape[1]

    accum = np.zeros((height, width), dtype=np.float32)
    weights = np.zeros((height, width), dtype=np.float32)

    grid_rows, grid_cols = patches.shape[:2]
    for r in range(grid_rows):
        top = r * step_size
        # Clamp so edge patches do not write past the target bounds.
        bottom = min(top + patch_size, height)
        for c in range(grid_cols):
            left = c * step_size
            right = min(left + patch_size, width)

            tile = patches[r, c, 0, :, :, 0]
            accum[top:bottom, left:right] += tile[:bottom - top, :right - left]
            weights[top:bottom, left:right] += 1

    # Uncovered pixels keep 0 but must not divide by zero.
    weights[weights == 0] = 1
    return (accum / weights).astype(np.uint8)
Reconstruct full image from overlapping patches.
Uses averaging for overlapping regions to create smooth transitions.
Arguments:
- patches: Predicted patches with shape (n_rows, n_cols, 1, H, W, 1).
- target_shape: Target shape for reconstructed image (H, W, C).
- patch_size: Size of square patches.
- step_size: Step size used during patch extraction.
Returns:
Reconstructed image as numpy array with shape (H, W).
def predict_single_image(model, image, patch_size, step_size, batch_size=8, verbose=True):
    """Run the complete inference pipeline on a single image.

    Args:
        model: Loaded Keras model.
        image: Input grayscale image as numpy array with shape (H, W).
        patch_size: Size of square patches for prediction.
        step_size: Step size for patch extraction (controls overlap).
        batch_size: Number of patches to process at once.
        verbose: If True, print progress information.

    Returns:
        Binary mask as numpy array with shape (H, W) matching input image,
        with values 0 (background) and 255 (foreground).
    """
    if verbose:
        print(f"Processing image with shape {image.shape}")

    prepared = prepare_image_for_prediction(image, patch_size, step_size)
    if verbose:
        grid = prepared['patches'].shape
        print(f"Created {grid[0]}x{grid[1]} patch grid")

    patch_preds = predict_patches_batched(
        model, prepared['patches'], patch_size, batch_size, verbose
    )

    if verbose:
        print("Reconstructing image from patches...")
    stitched = unpatchify_with_overlap(
        patch_preds, prepared['padded_shape'], patch_size, step_size
    )

    # Project the ROI-cropped mask back into original-image coordinates.
    mask = restore_mask_to_original(
        stitched, prepared['original_shape'], prepared['roi_box']
    )

    if verbose:
        print(f"Final mask shape: {mask.shape}")

    return mask
Run complete inference pipeline on a single image.
Arguments:
- model: Loaded Keras model.
- image: Input grayscale image as numpy array with shape (H, W).
- patch_size: Size of square patches for prediction.
- step_size: Step size for patch extraction (controls overlap).
- batch_size: Number of patches to process at once.
- verbose: If True, print progress information.
Returns:
Binary mask as numpy array with shape (H, W) matching input image, with values 0 (background) and 255 (foreground).
def run_inference(
    model_path: Union[str, Path],
    image_paths: List[Union[str, Path]],
    output_dir: Union[str, Path],
    mask_type: str,
    patch_size: int = 256,
    step_size: int = 128,
    batch_size: int = 8,
    verbose: bool = True
) -> Tuple[int, Path]:
    """Run inference on multiple images using a trained segmentation model.

    This is the main convenience function that handles model loading, image
    processing, and saving predictions. Output files are saved to a directory
    named after the model stem.

    Args:
        model_path: Path to the trained model file (.h5 or .keras).
        image_paths: List of paths to input images (as Path objects or strings).
        output_dir: Base directory where outputs will be saved.
        mask_type: Type of mask being predicted (e.g., 'root', 'shoot').
        patch_size: Size of square patches for prediction. Default is 256.
        step_size: Step size for patch extraction. Default is 128.
        batch_size: Number of patches to process at once. Default is 8.
        verbose: If True, print progress information. Default is True.

    Returns:
        Tuple of (number of images successfully processed, path to the
        directory where the predicted masks were written).

    Output structure:
        output_dir/
            model_name/
                image1_mask_type.png
                image2_mask_type.png

    Example:
        >>> image_list = ['img1.png', 'img2.png', 'img3.png']
        >>> n, out_dir = run_inference(
        ...     model_path='shoots.h5',
        ...     image_paths=image_list,
        ...     output_dir='./predictions',
        ...     mask_type='shoot',
        ...     patch_size=256,
        ...     step_size=128
        ... )
        >>> print(f"Processed {n} images")
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)
    image_paths = [Path(p) for p in image_paths]

    model_name = model_path.stem

    model_output_dir = output_dir / model_name
    model_output_dir.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"Loading model from {model_path}")
    model = load_segmentation_model(model_path, verbose=verbose)

    if verbose:
        print(f"Processing {len(image_paths)} images")
        print(f"Output directory: {model_output_dir}")

    processed_count = 0

    for i, img_path in enumerate(image_paths, 1):
        if not img_path.exists():
            print(f"Warning: {img_path} does not exist, skipping")
            continue

        if verbose:
            print(f"\n[{i}/{len(image_paths)}] Processing {img_path.name}")

        # Flag 0 -> load as single-channel grayscale.
        image = cv2.imread(str(img_path), 0)

        if image is None:
            print(f"Warning: Failed to load {img_path}, skipping")
            continue

        mask = predict_single_image(model, image, patch_size, step_size, batch_size, verbose)

        output_filename = f"{img_path.stem}_{mask_type}.png"
        output_path = model_output_dir / output_filename

        cv2.imwrite(str(output_path), mask)

        if verbose:
            print(f"Saved mask to {output_path}")

        processed_count += 1

    if verbose:
        print(f"\nCompleted processing {processed_count}/{len(image_paths)} images")

    return processed_count, model_output_dir
Run inference on multiple images using a trained segmentation model.
This is the main convenience function that handles model loading, image processing, and saving predictions. Output files are saved to a directory named after the model stem.
Arguments:
- model_path: Path to the trained model file (.h5 or .keras).
- image_paths: List of paths to input images (as Path objects or strings).
- output_dir: Base directory where outputs will be saved.
- mask_type: Type of mask being predicted (e.g., 'root', 'shoot').
- patch_size: Size of square patches for prediction. Default is 256.
- step_size: Step size for patch extraction. Default is 128.
- batch_size: Number of patches to process at once. Default is 8.
- verbose: If True, print progress information. Default is True.
Returns:
Tuple of (number of images successfully processed, path to the directory where the predicted masks were written).
Output structure:
output_dir/model_name/image1_mask_type.png and output_dir/model_name/image2_mask_type.png (one mask per input image, inside a subdirectory named after the model stem).
Example:
>>> image_list = ['img1.png', 'img2.png', 'img3.png'] >>> n = run_inference( ... model_path='shoots.h5', ... image_paths=image_list, ... output_dir='./predictions', ... mask_type='shoot', ... patch_size=256, ... step_size=128 ... ) >>> print(f"Processed {n} images")