library.run_inference

Inference module for segmentation model predictions.

This module provides functions to run inference on images using trained segmentation models. It handles patch-based prediction with overlap, ROI detection, and reconstruction of full-resolution masks.

  1"""Inference module for segmentation model predictions.
  2
  3This module provides functions to run inference on images using trained
  4segmentation models. It handles patch-based prediction with overlap,
  5ROI detection, and reconstruction of full-resolution masks.
  6"""
  7
  8import platform
  9from pathlib import Path
 10from typing import Union, List, Tuple, Dict
 11
 12import numpy as np
 13import cv2
 14import tensorflow as tf
 15from tensorflow.keras.models import load_model
 16import keras.backend as K
 17from patchify import patchify
 18
 19from library.roi import detect_roi, crop_to_roi
 20from library.patch_dataset import padder, restore_mask_to_original
 21
 22
 23def f1(y_true, y_pred):
 24    """Calculate F1 score metric for binary segmentation.
 25    
 26    Computes F1 score as the harmonic mean of precision and recall.
 27    
 28    Args:
 29        y_true: Ground truth binary labels.
 30        y_pred: Predicted probabilities or binary labels.
 31    
 32    Returns:
 33        F1 score as a scalar tensor.
 34    """
 35    def recall_m(y_true, y_pred):
 36        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
 37        Positives = K.sum(K.round(K.clip(y_true, 0, 1)))
 38        recall = TP / (Positives + K.epsilon())
 39        return recall
 40    
 41    def precision_m(y_true, y_pred):
 42        TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
 43        Pred_Positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
 44        precision = TP / (Pred_Positives + K.epsilon())
 45        return precision
 46    
 47    precision, recall = precision_m(y_true, y_pred), recall_m(y_true, y_pred)
 48    
 49    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))
 50
 51
 52def configure_tensorflow_for_platform():
 53    """Configure TensorFlow based on the current platform.
 54    
 55    Detects if running on Mac with Metal support, Linux/Windows with CUDA,
 56    or CPU-only. Configures TensorFlow accordingly and enables memory growth
 57    for GPU devices.
 58    
 59    Returns:
 60        str: Device type - 'metal', 'cuda', or 'cpu'.
 61    """
 62    system = platform.system()
 63    
 64    if system == 'Darwin':
 65        gpus = tf.config.list_physical_devices('GPU')
 66        if gpus:
 67            try:
 68                for gpu in gpus:
 69                    tf.config.experimental.set_memory_growth(gpu, True)
 70                return 'metal'
 71            except RuntimeError as e:
 72                print(f"GPU configuration error: {e}")
 73                return 'cpu'
 74        return 'cpu'
 75    else:
 76        gpus = tf.config.list_physical_devices('GPU')
 77        if gpus:
 78            return 'cuda'
 79        return 'cpu'
 80
 81
 82def load_segmentation_model(model_path, verbose=True):
 83    """Load a Keras segmentation model with custom F1 metric.
 84    
 85    Automatically detects platform and configures TensorFlow before loading
 86    the model.
 87    
 88    Args:
 89        model_path: Path to the saved model file (.h5 or .keras).
 90        verbose: If True, print device information.
 91    
 92    Returns:
 93        Loaded Keras model ready for inference.
 94    """
 95    device = configure_tensorflow_for_platform()
 96    
 97    if verbose:
 98        print(f"Using device: {device}")
 99    
100    model = load_model(model_path, custom_objects={"f1": f1}, compile=False)
101    
102    return model
103
104
def predict_patches_batched(model, patches, patch_size, batch_size=8, verbose=True):
    """Predict on all patches using batched processing.

    Patches are converted to float32 and scaled to [0, 1] one batch at a time.
    Only a single batch of float data is held in memory, not a float copy of
    every patch. Each batch goes through ``model.predict_on_batch`` and the
    output is binarized with a 0.5 threshold.

    Args:
        model: Loaded Keras model (anything exposing ``predict_on_batch``).
        patches: Patchified image array with shape (n_rows, n_cols, 1, H, W, C);
            only index 0 of axis 2 is used.
        patch_size: Size of each square patch.
        batch_size: Number of patches to process at once; must be >= 1.
        verbose: If True, print progress information.

    Returns:
        uint8 array of shape (n_rows, n_cols, 1, patch_size, patch_size, 1)
        with values 0 or 255.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_rows, n_cols = patches.shape[0], patches.shape[1]
    total_patches = n_rows * n_cols

    # Flatten the grid in row-major order: flat index i <-> (i // n_cols, i % n_cols).
    # The copy (if any) stays uint8, 4x smaller than a float32 copy.
    flat_patches = patches[:, :, 0].reshape((total_patches,) + patches.shape[3:])
    flat_predictions = np.zeros((total_patches, patch_size, patch_size, 1), dtype=np.uint8)

    num_batches = (total_patches + batch_size - 1) // batch_size
    if verbose:
        print(f"Total patches: {total_patches}")
        print(f"Processing in {num_batches} batches of size {batch_size}")

    for batch_idx, start_idx in enumerate(range(0, total_patches, batch_size)):
        end_idx = min(start_idx + batch_size, total_patches)
        # Normalize per batch to keep peak memory bounded.
        batch = flat_patches[start_idx:end_idx].astype(np.float32) / 255.0
        # Keras documents model.predict() as not intended for loops over small
        # inputs. predict_on_batch runs one forward pass per call. This replaces
        # the old periodic clear_session(), which reset global Keras state
        # while the model was still in use.
        batch_predictions = np.asarray(model.predict_on_batch(batch))
        flat_predictions[start_idx:end_idx] = (batch_predictions > 0.5).astype(np.uint8) * 255

        if verbose:
            print(f"Batch {batch_idx + 1}/{num_batches} ({end_idx}/{total_patches} patches)", end='\r')

    if verbose:
        print()

    return flat_predictions.reshape((n_rows, n_cols, 1, patch_size, patch_size, 1))
169
170
def prepare_image_for_prediction(image, patch_size, step_size):
    """Crop, pad and tile an image into (possibly overlapping) patches.

    Steps:
    1. Detect the ROI and crop to it.
    2. Pad the crop with ``padder`` for the given patch/step geometry.
    3. Add a trailing channel axis if the padded result is 2-D.
    4. Tile it with ``patchify`` using a (patch_size, patch_size, 1) window.

    Args:
        image: Grayscale input image as a numpy array (H, W).
        patch_size: Side length of the square patches.
        step_size: Stride between patch origins; a value smaller than
            ``patch_size`` produces overlapping patches.

    Returns:
        dict with keys:
            'patches': patchify output; (n_rows, n_cols, 1, patch_size,
                patch_size, 1) for a single-channel padded image.
            'roi_box': the value returned by ``detect_roi`` (documented as
                ((x1, y1), (x2, y2))).
            'padded_shape': shape of the padded, channel-expanded image.
            'original_shape': shape of ``image`` as passed in.
    """
    source_shape = image.shape

    roi = detect_roi(image)
    padded = padder(crop_to_roi(image, roi), patch_size, step_size)
    if len(padded.shape) == 2:
        # The patchify window below is 3-D, so a 2-D image needs a channel axis.
        padded = padded[..., np.newaxis]
    padded_dims = padded.shape

    return {
        'patches': patchify(padded, (patch_size, patch_size, 1), step=step_size),
        'roi_box': roi,
        'padded_shape': padded_dims,
        'original_shape': source_shape,
    }
209
210
def unpatchify_with_overlap(patches, target_shape, patch_size, step_size):
    """Stitch a grid of patches back into one 2-D image, averaging overlaps.

    Placement:
    - Patch (r, c) has its top-left corner at (r * step_size, c * step_size)
      and is clipped to ``target_shape``.
    - Each output pixel is the mean of all patches covering it.
    - Pixels no patch covers stay 0.

    The mean is cast to uint8 by truncation. For example, a pixel covered by
    one 255-valued and one 0-valued patch becomes 127. The result is therefore
    not strictly binary where overlapping patches disagree.

    Args:
        patches: Array of shape (n_rows, n_cols, 1, H, W, 1); only index 0 of
            axes 2 and 5 is read.
        target_shape: Output size; only the first two entries (H, W) are used.
        patch_size: Side length of the square patches.
        step_size: Stride used when the patches were extracted.

    Returns:
        uint8 array of shape (target_shape[0], target_shape[1]).
    """
    height, width = target_shape[0], target_shape[1]
    total = np.zeros((height, width), dtype=np.float32)
    hits = np.zeros((height, width), dtype=np.float32)

    for r in range(patches.shape[0]):
        top = r * step_size
        bottom = min(top + patch_size, height)
        for c in range(patches.shape[1]):
            left = c * step_size
            right = min(left + patch_size, width)
            tile = patches[r, c, 0, :, :, 0]
            total[top:bottom, left:right] += tile[:bottom - top, :right - left]
            hits[top:bottom, left:right] += 1

    # Uncovered pixels have zero hits; divide by 1 there so they stay 0.
    hits[hits == 0] = 1
    return (total / hits).astype(np.uint8)
249
250
def predict_single_image(model, image, patch_size, step_size, batch_size=8, verbose=True):
    """Segment one grayscale image end to end.

    Pipeline:
    1. ``prepare_image_for_prediction``: ROI crop, padding and tiling.
    2. ``predict_patches_batched``: batched inference, producing 0/255 patches.
    3. ``unpatchify_with_overlap``: overlap-averaged stitching.
    4. ``restore_mask_to_original``: presumably places the ROI result back
       into the original frame.

    Args:
        model: Loaded Keras model.
        image: Grayscale input image as a numpy array (H, W).
        patch_size: Side length of the square prediction patches.
        step_size: Stride between patches (controls overlap).
        batch_size: Number of patches sent to the model per batch.
        verbose: When True, print progress messages.

    Returns:
        The mask returned by ``restore_mask_to_original``, intended to match
        the input's (H, W).

        Patch predictions are 0/255, but pixels where overlapping patches
        disagree are averaged during stitching and may hold intermediate
        values. NOTE(review): confirm whether restore_mask_to_original
        re-binarizes them.
    """
    if verbose:
        print(f"Processing image with shape {image.shape}")

    prepared = prepare_image_for_prediction(image, patch_size, step_size)
    grid = prepared['patches']
    if verbose:
        print(f"Created {grid.shape[0]}x{grid.shape[1]} patch grid")

    patch_preds = predict_patches_batched(model, grid, patch_size, batch_size, verbose)

    if verbose:
        print("Reconstructing image from patches...")
    stitched = unpatchify_with_overlap(patch_preds, prepared['padded_shape'], patch_size, step_size)

    full_mask = restore_mask_to_original(stitched, prepared['original_shape'], prepared['roi_box'])
    if verbose:
        print(f"Final mask shape: {full_mask.shape}")
    return full_mask
291
292
def run_inference(
    model_path: Union[str, Path],
    image_paths: List[Union[str, Path]],
    output_dir: Union[str, Path],
    mask_type: str,
    patch_size: int = 256,
    step_size: int = 128,
    batch_size: int = 8,
    verbose: bool = True
) -> Tuple[int, Path]:
    """Run inference on multiple images using a trained segmentation model.

    Loads the model once, then predicts a mask for every readable image in
    ``image_paths``. Each mask is written as a PNG into a sub-directory of
    ``output_dir`` named after the model file's stem. Missing, unreadable,
    or unwritable images trigger a warning and are skipped.

    Args:
        model_path: Path to the trained model file (.h5 or .keras).
        image_paths: List of paths to input images (as Path objects or strings).
        output_dir: Base directory where outputs will be saved.
        mask_type: Label appended to output file names (e.g., 'root', 'shoot').
        patch_size: Size of square patches for prediction. Default is 256.
        step_size: Step size for patch extraction. Default is 128.
        batch_size: Number of patches to process at once. Default is 8.
        verbose: If True, print progress information. Default is True.

    Returns:
        Tuple ``(processed_count, model_output_dir)``:
        - ``processed_count``: number of masks that were predicted and
          written successfully.
        - ``model_output_dir``: directory they were written to.

    Output structure:
        output_dir/
            <model_stem>/
                <image_stem>_<mask_type>.png

    Example:
        >>> image_list = ['img1.png', 'img2.png', 'img3.png']
        >>> n, out_dir = run_inference(
        ...     model_path='shoots.h5',
        ...     image_paths=image_list,
        ...     output_dir='./predictions',
        ...     mask_type='shoot',
        ...     patch_size=256,
        ...     step_size=128
        ... )
        >>> print(f"Processed {n} images into {out_dir}")
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)
    image_paths = [Path(p) for p in image_paths]

    model_output_dir = output_dir / model_path.stem
    model_output_dir.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"Loading model from {model_path}")
    model = load_segmentation_model(model_path, verbose=verbose)

    if verbose:
        print(f"Processing {len(image_paths)} images")
        print(f"Output directory: {model_output_dir}")

    processed_count = 0

    for i, img_path in enumerate(image_paths, 1):
        if not img_path.exists():
            print(f"Warning: {img_path} does not exist, skipping")
            continue

        if verbose:
            print(f"\n[{i}/{len(image_paths)}] Processing {img_path.name}")

        # Flag 0 reads the image as single-channel grayscale.
        image = cv2.imread(str(img_path), 0)

        if image is None:
            print(f"Warning: Failed to load {img_path}, skipping")
            continue

        mask = predict_single_image(model, image, patch_size, step_size, batch_size, verbose)

        output_path = model_output_dir / f"{img_path.stem}_{mask_type}.png"

        # cv2.imwrite signals failure by returning False rather than raising;
        # a mask that was not written must not be counted as processed.
        if not cv2.imwrite(str(output_path), mask):
            print(f"Warning: Failed to write {output_path}, skipping")
            continue

        if verbose:
            print(f"Saved mask to {output_path}")

        processed_count += 1

    if verbose:
        print(f"\nCompleted processing {processed_count}/{len(image_paths)} images")

    return processed_count, model_output_dir
def f1(y_true, y_pred):
def f1(y_true, y_pred):
    """Binary-segmentation F1 score built from Keras backend ops.

    Every quantity is counted after clipping into [0, 1] and rounding, so
    predicted probabilities above 0.5 count as positives.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted probabilities or binary labels.

    Returns:
        Scalar tensor ``2 * P * R / (P + R + eps)`` where P is precision and
        R is recall.
    """
    eps = K.epsilon()

    def _count(values):
        # Number of entries that round to 1 after clipping into [0, 1].
        return K.sum(K.round(K.clip(values, 0, 1)))

    true_pos = _count(y_true * y_pred)
    precision = true_pos / (_count(y_pred) + eps)
    recall = true_pos / (_count(y_true) + eps)
    return 2 * ((precision * recall) / (precision + recall + eps))

Calculate F1 score metric for binary segmentation.

Computes F1 score as the harmonic mean of precision and recall.

Arguments:
  • y_true: Ground truth binary labels.
  • y_pred: Predicted probabilities or binary labels.
Returns:

F1 score as a scalar tensor.

def configure_tensorflow_for_platform():
def configure_tensorflow_for_platform():
    """Configure TensorFlow's GPU devices and report which backend is in use.

    Memory growth is requested for every visible GPU, on every platform, so
    TensorFlow allocates device memory on demand instead of reserving it all
    up front. A GPU on macOS (Darwin) is reported as ``'metal'``; a GPU on any
    other OS is reported as ``'cuda'``.

    If memory growth cannot be set, the error is printed and the GPU is still
    reported. TensorFlow raises RuntimeError for this when the runtime has
    already been initialized, e.g. when a model was loaded earlier in the same
    process, and the GPU is still used in that case.

    Returns:
        str: Device type - 'metal', 'cuda', or 'cpu' when no GPU is visible.
    """
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        return 'cpu'

    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            # Runtime already initialized: the setting can no longer change,
            # but the GPU remains the active device, so do not report 'cpu'.
            print(f"GPU configuration error: {e}")

    return 'metal' if platform.system() == 'Darwin' else 'cuda'

Configure TensorFlow based on the current platform.

Detects if running on Mac with Metal support, Linux/Windows with CUDA, or CPU-only. Configures TensorFlow accordingly and enables memory growth for GPU devices.

Returns:

str: Device type - 'metal', 'cuda', or 'cpu'.

def load_segmentation_model(model_path, verbose=True):
 83def load_segmentation_model(model_path, verbose=True):
 84    """Load a Keras segmentation model with custom F1 metric.
 85    
 86    Automatically detects platform and configures TensorFlow before loading
 87    the model.
 88    
 89    Args:
 90        model_path: Path to the saved model file (.h5 or .keras).
 91        verbose: If True, print device information.
 92    
 93    Returns:
 94        Loaded Keras model ready for inference.
 95    """
 96    device = configure_tensorflow_for_platform()
 97    
 98    if verbose:
 99        print(f"Using device: {device}")
100    
101    model = load_model(model_path, custom_objects={"f1": f1}, compile=False)
102    
103    return model

Load a Keras segmentation model with custom F1 metric.

Automatically detects platform and configures TensorFlow before loading the model.

Arguments:
  • model_path: Path to the saved model file (.h5 or .keras).
  • verbose: If True, print device information.
Returns:

Loaded Keras model ready for inference.

def predict_patches_batched(model, patches, patch_size, batch_size=8, verbose=True):
def predict_patches_batched(model, patches, patch_size, batch_size=8, verbose=True):
    """Predict on all patches using batched processing.

    Patches are converted to float32 and scaled to [0, 1] one batch at a time.
    Only a single batch of float data is held in memory, not a float copy of
    every patch. Each batch goes through ``model.predict_on_batch`` and the
    output is binarized with a 0.5 threshold.

    Args:
        model: Loaded Keras model (anything exposing ``predict_on_batch``).
        patches: Patchified image array with shape (n_rows, n_cols, 1, H, W, C);
            only index 0 of axis 2 is used.
        patch_size: Size of each square patch.
        batch_size: Number of patches to process at once; must be >= 1.
        verbose: If True, print progress information.

    Returns:
        uint8 array of shape (n_rows, n_cols, 1, patch_size, patch_size, 1)
        with values 0 or 255.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    n_rows, n_cols = patches.shape[0], patches.shape[1]
    total_patches = n_rows * n_cols

    # Flatten the grid in row-major order: flat index i <-> (i // n_cols, i % n_cols).
    # The copy (if any) stays uint8, 4x smaller than a float32 copy.
    flat_patches = patches[:, :, 0].reshape((total_patches,) + patches.shape[3:])
    flat_predictions = np.zeros((total_patches, patch_size, patch_size, 1), dtype=np.uint8)

    num_batches = (total_patches + batch_size - 1) // batch_size
    if verbose:
        print(f"Total patches: {total_patches}")
        print(f"Processing in {num_batches} batches of size {batch_size}")

    for batch_idx, start_idx in enumerate(range(0, total_patches, batch_size)):
        end_idx = min(start_idx + batch_size, total_patches)
        # Normalize per batch to keep peak memory bounded.
        batch = flat_patches[start_idx:end_idx].astype(np.float32) / 255.0
        # Keras documents model.predict() as not intended for loops over small
        # inputs. predict_on_batch runs one forward pass per call. This replaces
        # the old periodic clear_session(), which reset global Keras state
        # while the model was still in use.
        batch_predictions = np.asarray(model.predict_on_batch(batch))
        flat_predictions[start_idx:end_idx] = (batch_predictions > 0.5).astype(np.uint8) * 255

        if verbose:
            print(f"Batch {batch_idx + 1}/{num_batches} ({end_idx}/{total_patches} patches)", end='\r')

    if verbose:
        print()

    return flat_predictions.reshape((n_rows, n_cols, 1, patch_size, patch_size, 1))

Predict on all patches using batched processing.

Processes patches in batches to manage memory usage effectively.

Arguments:
  • model: Loaded Keras model.
  • patches: Patchified image array with shape (n_rows, n_cols, 1, H, W, C).
  • patch_size: Size of each square patch.
  • batch_size: Number of patches to process at once.
  • verbose: If True, print progress information.
Returns:

Array of predicted patches with shape (n_rows, n_cols, 1, H, W, 1), with values as uint8 (0 or 255).

def prepare_image_for_prediction(image, patch_size, step_size):
def prepare_image_for_prediction(image, patch_size, step_size):
    """Crop, pad and tile an image into (possibly overlapping) patches.

    Steps:
    1. Detect the ROI and crop to it.
    2. Pad the crop with ``padder`` for the given patch/step geometry.
    3. Add a trailing channel axis if the padded result is 2-D.
    4. Tile it with ``patchify`` using a (patch_size, patch_size, 1) window.

    Args:
        image: Grayscale input image as a numpy array (H, W).
        patch_size: Side length of the square patches.
        step_size: Stride between patch origins; a value smaller than
            ``patch_size`` produces overlapping patches.

    Returns:
        dict with keys:
            'patches': patchify output; (n_rows, n_cols, 1, patch_size,
                patch_size, 1) for a single-channel padded image.
            'roi_box': the value returned by ``detect_roi`` (documented as
                ((x1, y1), (x2, y2))).
            'padded_shape': shape of the padded, channel-expanded image.
            'original_shape': shape of ``image`` as passed in.
    """
    source_shape = image.shape

    roi = detect_roi(image)
    padded = padder(crop_to_roi(image, roi), patch_size, step_size)
    if len(padded.shape) == 2:
        # The patchify window below is 3-D, so a 2-D image needs a channel axis.
        padded = padded[..., np.newaxis]
    padded_dims = padded.shape

    return {
        'patches': patchify(padded, (patch_size, patch_size, 1), step=step_size),
        'roi_box': roi,
        'padded_shape': padded_dims,
        'original_shape': source_shape,
    }

Prepare an image for patch-based prediction.

Performs preprocessing including ROI detection, cropping, padding, and patch extraction.

Arguments:
  • image: Input grayscale image as numpy array with shape (H, W).
  • patch_size: Size of square patches.
  • step_size: Step size for patch extraction (controls overlap).
Returns:

Dictionary containing:
- 'patches': Patchified array with shape (n_rows, n_cols, 1, H, W, C)
- 'roi_box': ROI coordinates as ((x1, y1), (x2, y2))
- 'padded_shape': Shape of padded image (H, W, C)
- 'original_shape': Shape of original input image (H, W)

def unpatchify_with_overlap(patches, target_shape, patch_size, step_size):
def unpatchify_with_overlap(patches, target_shape, patch_size, step_size):
    """Stitch a grid of patches back into one 2-D image, averaging overlaps.

    Placement:
    - Patch (r, c) has its top-left corner at (r * step_size, c * step_size)
      and is clipped to ``target_shape``.
    - Each output pixel is the mean of all patches covering it.
    - Pixels no patch covers stay 0.

    The mean is cast to uint8 by truncation. For example, a pixel covered by
    one 255-valued and one 0-valued patch becomes 127. The result is therefore
    not strictly binary where overlapping patches disagree.

    Args:
        patches: Array of shape (n_rows, n_cols, 1, H, W, 1); only index 0 of
            axes 2 and 5 is read.
        target_shape: Output size; only the first two entries (H, W) are used.
        patch_size: Side length of the square patches.
        step_size: Stride used when the patches were extracted.

    Returns:
        uint8 array of shape (target_shape[0], target_shape[1]).
    """
    height, width = target_shape[0], target_shape[1]
    total = np.zeros((height, width), dtype=np.float32)
    hits = np.zeros((height, width), dtype=np.float32)

    for r in range(patches.shape[0]):
        top = r * step_size
        bottom = min(top + patch_size, height)
        for c in range(patches.shape[1]):
            left = c * step_size
            right = min(left + patch_size, width)
            tile = patches[r, c, 0, :, :, 0]
            total[top:bottom, left:right] += tile[:bottom - top, :right - left]
            hits[top:bottom, left:right] += 1

    # Uncovered pixels have zero hits; divide by 1 there so they stay 0.
    hits[hits == 0] = 1
    return (total / hits).astype(np.uint8)

Reconstruct full image from overlapping patches.

Uses averaging for overlapping regions to create smooth transitions.

Arguments:
  • patches: Predicted patches with shape (n_rows, n_cols, 1, H, W, 1).
  • target_shape: Target shape for reconstructed image (H, W, C).
  • patch_size: Size of square patches.
  • step_size: Step size used during patch extraction.
Returns:

Reconstructed image as numpy array with shape (H, W).

def predict_single_image(model, image, patch_size, step_size, batch_size=8, verbose=True):
def predict_single_image(model, image, patch_size, step_size, batch_size=8, verbose=True):
    """Segment one grayscale image end to end.

    Pipeline:
    1. ``prepare_image_for_prediction``: ROI crop, padding and tiling.
    2. ``predict_patches_batched``: batched inference, producing 0/255 patches.
    3. ``unpatchify_with_overlap``: overlap-averaged stitching.
    4. ``restore_mask_to_original``: presumably places the ROI result back
       into the original frame.

    Args:
        model: Loaded Keras model.
        image: Grayscale input image as a numpy array (H, W).
        patch_size: Side length of the square prediction patches.
        step_size: Stride between patches (controls overlap).
        batch_size: Number of patches sent to the model per batch.
        verbose: When True, print progress messages.

    Returns:
        The mask returned by ``restore_mask_to_original``, intended to match
        the input's (H, W).

        Patch predictions are 0/255, but pixels where overlapping patches
        disagree are averaged during stitching and may hold intermediate
        values. NOTE(review): confirm whether restore_mask_to_original
        re-binarizes them.
    """
    if verbose:
        print(f"Processing image with shape {image.shape}")

    prepared = prepare_image_for_prediction(image, patch_size, step_size)
    grid = prepared['patches']
    if verbose:
        print(f"Created {grid.shape[0]}x{grid.shape[1]} patch grid")

    patch_preds = predict_patches_batched(model, grid, patch_size, batch_size, verbose)

    if verbose:
        print("Reconstructing image from patches...")
    stitched = unpatchify_with_overlap(patch_preds, prepared['padded_shape'], patch_size, step_size)

    full_mask = restore_mask_to_original(stitched, prepared['original_shape'], prepared['roi_box'])
    if verbose:
        print(f"Final mask shape: {full_mask.shape}")
    return full_mask

Run complete inference pipeline on a single image.

Arguments:
  • model: Loaded Keras model.
  • image: Input grayscale image as numpy array with shape (H, W).
  • patch_size: Size of square patches for prediction.
  • step_size: Step size for patch extraction (controls overlap).
  • batch_size: Number of patches to process at once.
  • verbose: If True, print progress information.
Returns:

Binary mask as numpy array with shape (H, W) matching input image, with values 0 (background) and 255 (foreground).

def run_inference( model_path: Union[str, pathlib.Path], image_paths: List[Union[str, pathlib.Path]], output_dir: Union[str, pathlib.Path], mask_type: str, patch_size: int = 256, step_size: int = 128, batch_size: int = 8, verbose: bool = True) -> int:
def run_inference(
    model_path: Union[str, Path],
    image_paths: List[Union[str, Path]],
    output_dir: Union[str, Path],
    mask_type: str,
    patch_size: int = 256,
    step_size: int = 128,
    batch_size: int = 8,
    verbose: bool = True
) -> Tuple[int, Path]:
    """Run inference on multiple images using a trained segmentation model.

    Loads the model once, then predicts a mask for every readable image in
    ``image_paths``. Each mask is written as a PNG into a sub-directory of
    ``output_dir`` named after the model file's stem. Missing, unreadable,
    or unwritable images trigger a warning and are skipped.

    Args:
        model_path: Path to the trained model file (.h5 or .keras).
        image_paths: List of paths to input images (as Path objects or strings).
        output_dir: Base directory where outputs will be saved.
        mask_type: Label appended to output file names (e.g., 'root', 'shoot').
        patch_size: Size of square patches for prediction. Default is 256.
        step_size: Step size for patch extraction. Default is 128.
        batch_size: Number of patches to process at once. Default is 8.
        verbose: If True, print progress information. Default is True.

    Returns:
        Tuple ``(processed_count, model_output_dir)``:
        - ``processed_count``: number of masks that were predicted and
          written successfully.
        - ``model_output_dir``: directory they were written to.

    Output structure:
        output_dir/
            <model_stem>/
                <image_stem>_<mask_type>.png

    Example:
        >>> image_list = ['img1.png', 'img2.png', 'img3.png']
        >>> n, out_dir = run_inference(
        ...     model_path='shoots.h5',
        ...     image_paths=image_list,
        ...     output_dir='./predictions',
        ...     mask_type='shoot',
        ...     patch_size=256,
        ...     step_size=128
        ... )
        >>> print(f"Processed {n} images into {out_dir}")
    """
    model_path = Path(model_path)
    output_dir = Path(output_dir)
    image_paths = [Path(p) for p in image_paths]

    model_output_dir = output_dir / model_path.stem
    model_output_dir.mkdir(parents=True, exist_ok=True)

    if verbose:
        print(f"Loading model from {model_path}")
    model = load_segmentation_model(model_path, verbose=verbose)

    if verbose:
        print(f"Processing {len(image_paths)} images")
        print(f"Output directory: {model_output_dir}")

    processed_count = 0

    for i, img_path in enumerate(image_paths, 1):
        if not img_path.exists():
            print(f"Warning: {img_path} does not exist, skipping")
            continue

        if verbose:
            print(f"\n[{i}/{len(image_paths)}] Processing {img_path.name}")

        # Flag 0 reads the image as single-channel grayscale.
        image = cv2.imread(str(img_path), 0)

        if image is None:
            print(f"Warning: Failed to load {img_path}, skipping")
            continue

        mask = predict_single_image(model, image, patch_size, step_size, batch_size, verbose)

        output_path = model_output_dir / f"{img_path.stem}_{mask_type}.png"

        # cv2.imwrite signals failure by returning False rather than raising;
        # a mask that was not written must not be counted as processed.
        if not cv2.imwrite(str(output_path), mask):
            print(f"Warning: Failed to write {output_path}, skipping")
            continue

        if verbose:
            print(f"Saved mask to {output_path}")

        processed_count += 1

    if verbose:
        print(f"\nCompleted processing {processed_count}/{len(image_paths)} images")

    return processed_count, model_output_dir

Run inference on multiple images using a trained segmentation model.

This is the main convenience function that handles model loading, image processing, and saving predictions. Output files are saved to a directory named after the model stem.

Arguments:
  • model_path: Path to the trained model file (.h5 or .keras).
  • image_paths: List of paths to input images (as Path objects or strings).
  • output_dir: Base directory where outputs will be saved.
  • mask_type: Type of mask being predicted (e.g., 'root', 'shoot').
  • patch_size: Size of square patches for prediction. Default is 256.
  • step_size: Step size for patch extraction. Default is 128.
  • batch_size: Number of patches to process at once. Default is 8.
  • verbose: If True, print progress information. Default is True.
Returns:

Tuple of (number of images processed, path of the model-specific output directory).

Output structure:

output_dir/
    model_name/
        image1_mask_type.png
        image2_mask_type.png

Example:
>>> image_list = ['img1.png', 'img2.png', 'img3.png']
>>> n, out_dir = run_inference(
...     model_path='shoots.h5',
...     image_paths=image_list,
...     output_dir='./predictions',
...     mask_type='shoot',
...     patch_size=256,
...     step_size=128
... )
>>> print(f"Processed {n} images into {out_dir}")