library.root_analysis_visualization

  1from scipy.ndimage import binary_dilation
  2from scipy import ndimage
  3
  4import numpy as np
  5import matplotlib.pyplot as plt
  6from scipy import ndimage
  7import matplotlib.patches as mpatches
  8from library.root_analysis import *
  9
 10
 11def visualize_skeleton_with_nodes(skeleton, show_branch_labels=False, row_range=None, padding=50, dilate_iterations=3):
 12    """
 13    Visualize skeleton with nodes and optionally branch labels.
 14
 15    Args:
 16        skeleton (skeleton numpy array): object to show
 17        show_branch_labels (bool): add labels
 18        row_range(tuple): rows to zoom in on
 19        padding(int): add additional pixels to make skeleton easier to view
 20        dilate_iterations(int): number of rounds to dilate and increase size
 21
 22    Returns:
 23        branch_data (pandas table)
 24    """
 25    
 26    # Thicken skeleton for visibility
 27    if dilate_iterations > 0:
 28        thick_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
 29        skeleton_uint8 = (thick_skeleton * 255).astype(np.uint8)
 30    else:
 31        skeleton_uint8 = (skeleton * 255).astype(np.uint8)
 32    
 33    # Find bounding box
 34    rows, cols = np.where(skeleton)
 35    
 36    if row_range is not None:
 37        row_min, row_max = row_range
 38        mask = (rows >= row_min) & (rows <= row_max)
 39        if np.any(mask):
 40            col_min, col_max = cols[mask].min(), cols[mask].max()
 41        else:
 42            col_min, col_max = cols.min(), cols.max()
 43    else:
 44        row_min, row_max = rows.min(), rows.max()
 45        col_min, col_max = cols.min(), cols.max()
 46    
 47    # Get branch data and nodes
 48    skeleton_obj = Skeleton(skeleton)
 49    branch_data = summarize(skeleton_obj)
 50    
 51    # Get node coordinates
 52    node_coords = {}
 53    for idx, row in branch_data.iterrows():
 54        node_coords[row['node-id-src']] = (row['coord-src-0'], row['coord-src-1'])
 55        node_coords[row['node-id-dst']] = (row['coord-dst-0'], row['coord-dst-1'])
 56    
 57    # Plot
 58    fig, ax = plt.subplots(figsize=(14, 12))
 59    ax.imshow(skeleton_uint8, cmap='gray')
 60    
 61    # Track occupied regions (for collision detection)
 62    occupied_regions = []  # List of (row, col, radius) tuples
 63    
 64    # First pass: plot node markers
 65    visible_nodes = []
 66    for node_id, (node_row, node_col) in node_coords.items():
 67        if (row_min - padding <= node_row <= row_max + padding and 
 68            col_min - padding <= node_col <= col_max + padding):
 69            visible_nodes.append((node_id, node_row, node_col))
 70            ax.plot(node_col, node_row, 'ro', markersize=12, zorder=3)
 71            # Mark node position as occupied
 72            occupied_regions.append((node_row, node_col, 15))  # Small radius for node dot itself
 73    
 74    # Second pass: place node labels with collision avoidance
 75    for node_id, node_row, node_col in visible_nodes:
 76        # Try different offset positions for node label
 77        offsets = [(40, 0), (-40, 0), (0, 40), (0, -40), (30, 30), (-30, -30), (30, -30), (-30, 30)]
 78        best_offset = (40, 0)
 79        min_collision = float('inf')
 80        
 81        for offset_col, offset_row in offsets:
 82            test_col = node_col + offset_col
 83            test_row = node_row + offset_row
 84            
 85            # Check collision with all occupied regions
 86            max_collision = 0
 87            for occ_row, occ_col, radius in occupied_regions:
 88                dist = np.sqrt((test_row - occ_row)**2 + (test_col - occ_col)**2)
 89                if dist < radius:
 90                    max_collision = max(max_collision, radius - dist)
 91            
 92            if max_collision < min_collision:
 93                min_collision = max_collision
 94                best_offset = (offset_col, offset_row)
 95                if max_collision == 0:
 96                    break
 97        
 98        label_col = node_col + best_offset[0]
 99        label_row = node_row + best_offset[1]
100        
101        # Draw leader line
102        ax.plot([node_col, label_col], [node_row, label_row], 
103               'y--', linewidth=1, alpha=0.6, zorder=2)
104        
105        # Draw label
106        ax.text(label_col, label_row, str(int(node_id)), color='yellow', fontsize=11, 
107               bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
108               ha='center', zorder=4)
109        
110        # Mark label area as occupied
111        occupied_regions.append((label_row, label_col, 45))
112    
113    # Branch labels with collision avoidance
114    if show_branch_labels:
115        branch_positions = []
116        for idx, row in branch_data.iterrows():
117            mid_row = (row['coord-src-0'] + row['coord-dst-0']) / 2
118            mid_col = (row['coord-src-1'] + row['coord-dst-1']) / 2
119            
120            if (row_min - padding <= mid_row <= row_max + padding and 
121                col_min - padding <= mid_col <= col_max + padding):
122                branch_positions.append((idx, mid_row, mid_col))
123        
124        for idx, mid_row, mid_col in branch_positions:
125            # Try different offset positions
126            offsets = [(60, 0), (-60, 0), (0, 60), (0, -60), (45, 45), (-45, -45), 
127                      (120, 0), (-120, 0), (0, 120), (0, -120)]
128            best_offset = (60, 0)
129            min_collision = float('inf')
130            
131            for offset_col, offset_row in offsets:
132                test_col = mid_col + offset_col
133                test_row = mid_row + offset_row
134                
135                max_collision = 0
136                for occ_row, occ_col, radius in occupied_regions:
137                    dist = np.sqrt((test_row - occ_row)**2 + (test_col - occ_col)**2)
138                    if dist < radius:
139                        max_collision = max(max_collision, radius - dist)
140                
141                if max_collision < min_collision:
142                    min_collision = max_collision
143                    best_offset = (offset_col, offset_row)
144                    if max_collision == 0:
145                        break
146            
147            label_col = mid_col + best_offset[0]
148            label_row = mid_row + best_offset[1]
149            
150            # Draw leader line
151            ax.plot([mid_col, label_col], [mid_row, label_row], 
152                   'c--', linewidth=1, alpha=0.5, zorder=1)
153            
154            # Draw label
155            ax.text(label_col, label_row, f"B{idx}", color='cyan', fontsize=9, 
156                   bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7),
157                   ha='center', zorder=2)
158            
159            occupied_regions.append((label_row, label_col, 40))
160    
161    # Set zoom limits
162    ax.set_xlim(col_min - padding, col_max + padding)
163    ax.set_ylim(row_max + padding, row_min - padding)
164    
165    title = f'Skeleton with {len(node_coords)} nodes, {branch_data.shape[0]} branches'
166    if row_range:
167        title += f' (rows {row_range[0]}-{row_range[1]})'
168    ax.set_title(title)
169    
170    plt.show()
171    
172    return branch_data
173
def visualize_trunk_path(skeleton, branch_data, trunk_path, title='Trunk Path',
                         dilate_iterations=3, zoom_to_path=True, padding=100):
    """
    Visualize the trunk path overlaid on the skeleton image and report its length.

    All nodes in ``branch_data`` are drawn as faint gray dots. Path nodes are
    drawn large, colored with the autumn colormap from first to last, and
    connected by straight red segments.

    Args:
        skeleton: Binary skeleton array.
        branch_data: DataFrame with branch information (columns
            'node-id-src/dst', 'coord-src/dst-0/1', 'branch-distance').
        trunk_path: List of tuples from trace_vertical_path
            [(node, branch_idx, slope, vertical), ...]. ``branch_idx`` may be None.
        title: Plot title.
        dilate_iterations: Number of iterations to thicken skeleton (0 = none).
        zoom_to_path: Whether to zoom to the trunk path region. Ignored when no
            path node has known coordinates.
        padding: Padding around zoom region.

    Returns:
        Total trunk length in pixels: the sum of 'branch-distance' over the
        non-None branch indices in ``trunk_path`` (0 if there are none).
    """
    # Thicken skeleton for visibility. scipy's binary_dilation repeats until
    # convergence when iterations < 1, so only dilate for positive counts.
    if dilate_iterations > 0:
        display_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_skeleton = skeleton
    skeleton_uint8 = (display_skeleton * 255).astype(np.uint8)

    # Nodes and branches along the path
    path_nodes = [int(node) for node, _, _, _ in trunk_path]
    path_branches = [idx for _, idx, _, _ in trunk_path if idx is not None]

    # Coordinates for every node mentioned in branch_data
    all_node_coords = {}
    for _, row in branch_data.iterrows():
        all_node_coords[row['node-id-src']] = (row['coord-src-0'], row['coord-src-1'])
        all_node_coords[row['node-id-dst']] = (row['coord-dst-0'], row['coord-dst-1'])

    # Bounding box of the path nodes that have known coordinates. Zooming is
    # disabled when there are none, so the bounds are never left undefined.
    path_coords = [all_node_coords[n] for n in path_nodes if n in all_node_coords]
    if zoom_to_path and path_coords:
        path_rows = [r for r, _ in path_coords]
        path_cols = [c for _, c in path_coords]
        row_min, row_max = min(path_rows), max(path_rows)
        col_min, col_max = min(path_cols), max(path_cols)
    else:
        zoom_to_path = False

    _, ax = plt.subplots(figsize=(14, 14))
    ax.imshow(skeleton_uint8, cmap='gray')

    # All nodes (small, gray, semi-transparent)
    for row, col in all_node_coords.values():
        ax.plot(col, row, 'o', color='gray', markersize=4, alpha=0.3)

    # Trunk path connections first, so they sit behind the node markers
    for node1, node2 in zip(path_nodes, path_nodes[1:]):
        if node1 in all_node_coords and node2 in all_node_coords:
            r1, c1 = all_node_coords[node1]
            r2, c2 = all_node_coords[node2]
            ax.plot([c1, c2], [r1, r2], 'r-', linewidth=4, alpha=0.8, zorder=2)

    # Trunk path nodes (large, colored from the start to the end of the path)
    for i, node_id in enumerate(path_nodes):
        if node_id in all_node_coords:
            row, col = all_node_coords[node_id]
            color = plt.cm.autumn(i / max(len(path_nodes) - 1, 1))
            ax.plot(col, row, 'o', color=color, markersize=16, zorder=3,
                    markeredgecolor='white', markeredgewidth=2)
            ax.text(col + 15, row, str(node_id), color='cyan', fontsize=12,
                    fontweight='bold',
                    bbox=dict(boxstyle='round', facecolor='black', alpha=0.8),
                    zorder=4)

    # Zoom; y-axis inverted so row 0 stays at the top like an image
    if zoom_to_path:
        ax.set_xlim(col_min - padding, col_max + padding)
        ax.set_ylim(row_max + padding, row_min - padding)

    # Title with the first 10 path nodes
    path_str = " → ".join(map(str, path_nodes[:10]))
    if len(path_nodes) > 10:
        path_str += "..."
    ax.set_title(f'{title}\nPath: {path_str}', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    # Summary
    print(f"Trunk path: {len(path_nodes)} nodes, {len(path_branches)} branches")
    if path_nodes:
        print(f"Start node: {path_nodes[0]}, End node: {path_nodes[-1]}")

    # NOTE(review): branch indices are looked up positionally via .iloc. This
    # assumes trunk_path's branch_idx values are row positions in branch_data;
    # confirm against trace_vertical_path.
    total_length = sum(branch_data.iloc[idx]['branch-distance']
                       for idx in path_branches)
    print(f"Total trunk length: {total_length:.2f} pixels")

    return total_length
266
def visualize_thickened_skeleton(binary_mask, dilate_iterations=3, figsize=(12, 8), show_labels=True):
    """
    Label the skeleton of a binary mask, thicken it, and display it.

    The skeleton and its connected-component labels come from
    ``label_skeleton`` (expected to be provided by ``library.root_analysis``).

    Args:
        binary_mask: Binary numpy array (bool or 0/255) or path to image file.
        dilate_iterations: Number of dilation iterations to thicken skeleton
            for display (0 = no thickening).
        figsize: Figure size tuple.
        show_labels: Whether to draw each component's label id 50 pixels below
            its bottommost pixel, horizontally centred on its mean column.

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels) as returned by
        ``label_skeleton``.

    Raises:
        FileNotFoundError: if ``binary_mask`` is a path that cannot be read.
    """
    # Load from file if a path was provided
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        # NOTE(review): cv2 is not imported in this module's header; it
        # presumably arrives via `from library.root_analysis import *` — confirm.
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if binary_mask is None:
            # cv2.imread reports a missing/unreadable file by returning None
            raise FileNotFoundError(f"Could not read image: {image_path}")

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    # Thicken skeleton for visibility. scipy's binary_dilation repeats until
    # convergence when iterations < 1, so only dilate for positive counts.
    if dilate_iterations > 0:
        thick_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        thick_skeleton = skeleton
    thick_skeleton_uint8 = (thick_skeleton * 255).astype(np.uint8)

    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(thick_skeleton_uint8, cmap='gray')
    ax.set_title(f'{len(unique_labels)} skeleton(s) (thickened for visibility)')
    ax.axis('off')

    if show_labels and len(unique_labels) > 0:
        for label_id in unique_labels:
            rows, cols = np.where(labeled_skeleton == label_id)
            if len(rows) == 0:
                continue

            # Place the label 50 px below the component's bottommost pixel
            # (largest row), centred on its mean column.
            label_row = rows.max() + 50
            label_col = int(np.mean(cols))

            ax.text(label_col, label_row, str(int(label_id)),
                    color='yellow', fontsize=14, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='red',
                              edgecolor='white', linewidth=2, alpha=0.8),
                    ha='center', va='center')

    plt.show()

    print(f"Found {len(unique_labels)} connected structure(s)")
    print(f"Total skeleton pixels: {np.sum(skeleton)}")

    return skeleton, labeled_skeleton, unique_labels
329
330
def visualize_assignments(shoot_mask, root_mask, structures, assignments, ref_points, size=(16, 8)):
    """Visualize shoot-root assignments with ordered shoot labels and root labels.

    Each assigned shoot and its root (dilated by 2 px) are painted in the same
    color. Colors cycle through a 5-color palette by shoot order, so any
    number of shoots is supported. A circled order number (1-based) is drawn
    300 px below the bottommost pixel of the assigned root. For shoots without
    a root, it is drawn at the shoot centroid.

    Args:
        shoot_mask: Binary mask of shoot regions.
        root_mask: Binary mask of root regions (kept for API compatibility; not used).
        structures: Dictionary from extract_root_structures containing 'roots' with root data.
        assignments: Dictionary from assign_roots_to_shoots_greedy with shoot assignments.
        ref_points: Dictionary from get_shoot_reference_points_multi with shoot centroids.
        size: Figure size tuple (width, height).

    Returns:
        None: Displays the visualization.
    """
    labeled_shoots, _ = ndimage.label(shoot_mask)

    h, w = shoot_mask.shape
    rgb_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Color palette, cycled by shoot order
    colors = [
        [255, 0, 0],      # Red
        [255, 165, 0],    # Orange
        [255, 255, 0],    # Yellow
        [0, 255, 255],    # Cyan
        [255, 0, 255]     # Magenta
    ]

    legend_elements = []

    # Paint each shoot and its assigned root
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']
        color = colors[order % len(colors)]

        rgb_img[labeled_shoots == shoot_label] = color

        if root_label is not None:
            # Local name avoids shadowing the root_mask parameter
            assigned_root = structures['roots'][root_label]['mask']
            rgb_img[ndimage.binary_dilation(assigned_root, iterations=2)] = color
            score = info['score_dict']['score']
            legend_label = f"S{order+1}->R{root_label} ({score:.2f})"
        else:
            legend_label = f"S{order+1}->None"

        legend_elements.append(
            mpatches.Patch(color=np.array(color)/255, label=legend_label)
        )

    _, ax = plt.subplots(figsize=size)
    ax.imshow(rgb_img)

    # Shoot order labels
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']

        if root_label is not None:
            # Root endpoint = bottommost pixel (largest row index)
            assigned_root = structures['roots'][root_label]['mask']
            y_coords, x_coords = np.where(assigned_root)

            if len(y_coords) > 0:
                bottom_idx = np.argmax(y_coords)
                endpoint_y = y_coords[bottom_idx]
                endpoint_x = x_coords[bottom_idx]

                # Shoot order label 300 px below the root endpoint
                ax.text(endpoint_x, endpoint_y + 300, str(order + 1),
                        color='white', fontsize=20, fontweight='bold',
                        ha='center', va='center',
                        bbox=dict(boxstyle='circle', facecolor='black', alpha=0.7))
        else:
            # Ungerminated shoot: label at the shoot centroid
            cy, cx = ref_points[shoot_label]['centroid']
            ax.text(cx, cy, str(order + 1),
                    color='white', fontsize=20, fontweight='bold',
                    ha='center', va='center',
                    bbox=dict(boxstyle='circle', facecolor='black', alpha=0.7))

    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
    ax.set_title('Shoot-Root Assignments (Shoot# = Left->Right Order)', fontsize=12)
    ax.axis('off')

    plt.tight_layout()
    plt.show()
430
431
def visualize_root_shoot_matches(structures, labeled_shoots, root_matches,
                                 zoom_to_content=True, padding=50,
                                 root_thickness=3):
    """
    Visualize root-to-shoot matches with color coding.

    Shoots are painted by label, cycling through an 8-color palette. Each
    root skeleton in ``structures['unique_labels']`` is painted with its
    matched shoot's color, or gray when ``root_matches`` has no truthy entry
    for it. Legend and printed counts use the same matched/unmatched
    classification as the drawing.

    Args:
        structures: Dict from extract_root_structures
            (uses 'labeled_skeleton' and 'unique_labels').
        labeled_shoots: Labeled shoot array.
        root_matches: Dict from match_roots_to_shoots mapping root label to a
            dict with 'shoot_label', or None when unmatched.
        zoom_to_content: If True, crop to the bounding box of all painted pixels.
        padding: Pixels to add around bounding box.
        root_thickness: Dilation iterations applied to root skeletons for
            visibility (0 = no thickening).
    """
    labeled_skeleton = structures['labeled_skeleton']
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Distinct shoot colors, cycled by shoot label
    colors = [
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [255, 128, 0],    # Orange
        [128, 0, 255],    # Purple
    ]

    def shoot_color(shoot_label):
        """Palette color for a 1-based shoot label."""
        return colors[(shoot_label - 1) % len(colors)]

    for shoot_label in range(1, labeled_shoots.max() + 1):
        vis_img[labeled_shoots == shoot_label] = shoot_color(shoot_label)

    # Roots: matched shoot color, or gray when unmatched
    cross = ndimage.generate_binary_structure(2, 1)
    matched_shoots = set()
    unmatched_count = 0
    for root_label in structures['unique_labels']:
        root_pixels = labeled_skeleton == root_label
        if root_thickness > 0:
            root_pixels = ndimage.binary_dilation(
                root_pixels, structure=cross, iterations=root_thickness
            )

        match = root_matches.get(root_label)
        if match:
            matched_shoots.add(match['shoot_label'])
            color = shoot_color(match['shoot_label'])
        else:
            unmatched_count += 1
            color = [128, 128, 128]  # Gray for unmatched

        vis_img[root_pixels] = color

    # Crop to the bounding box of all painted pixels
    if zoom_to_content:
        rows, cols = np.where(np.any(vis_img > 0, axis=2))
        if len(rows) > 0:
            y_min = max(0, rows.min() - padding)
            # +1: slice ends are exclusive, so include the last content row/col
            y_max = min(h, rows.max() + padding + 1)
            x_min = max(0, cols.min() - padding)
            x_max = min(w, cols.max() + padding + 1)
            vis_img = vis_img[y_min:y_max, x_min:x_max]

    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)
    ax.set_title('Root-Shoot Matches (same color = matched)', fontsize=14)
    ax.axis('off')

    legend_elements = [
        mpatches.Patch(color=np.array(shoot_color(shoot_label)) / 255.0,
                       label=f'Shoot/Root {shoot_label}')
        for shoot_label in sorted(matched_shoots)
    ]
    if unmatched_count > 0:
        legend_elements.append(
            mpatches.Patch(color=[0.5, 0.5, 0.5],
                           label=f'Unmatched ({unmatched_count})')
        )
    if legend_elements:
        ax.legend(handles=legend_elements, loc='lower right', fontsize=10)

    plt.tight_layout()
    plt.show()

    total_roots = len(structures['unique_labels'])
    print(f"\nMatch Summary:")
    print(f"Total root structures: {total_roots}")
    print(f"Matched: {total_roots - unmatched_count}")
    print(f"Unmatched (noise): {unmatched_count}")
540
541
def get_skeleton_path_from_branch(skeleton, branch_row):
    """
    Trace the skeleton pixels connecting a branch's two endpoints.

    The route is a minimum-cost path from ``skimage.graph.route_through_array``
    over a cost image. The cost is 1 on skeleton pixels and 10000 elsewhere,
    so the route follows the skeleton wherever it can.

    Args:
        skeleton: 2-D binary skeleton array.
        branch_row: Row from branch_data DataFrame providing
            'image-coord-src-0/1' and 'image-coord-dst-0/1'.

    Returns:
        numpy.ndarray: (N, 2) array of (row, col) coordinates from source to
        destination. If routing raises, the straight 2-point segment
        [start, end] is returned instead.
    """
    from skimage.graph import route_through_array

    # Round rather than truncate, so float noise such as 12.9999 maps to the
    # nearest pixel instead of the one before it.
    start_coord = (int(round(branch_row['image-coord-src-0'])),
                   int(round(branch_row['image-coord-src-1'])))
    end_coord = (int(round(branch_row['image-coord-dst-0'])),
                 int(round(branch_row['image-coord-dst-1'])))

    # Low cost on the skeleton, prohibitive elsewhere
    cost = np.where(skeleton, 1, 10000)

    try:
        indices, _ = route_through_array(
            cost, start_coord, end_coord, fully_connected=True
        )
    except Exception:
        # Best-effort fallback (e.g. endpoints outside the image): a straight
        # segment. Catching Exception, not a bare except, lets
        # KeyboardInterrupt/SystemExit propagate.
        return np.array([start_coord, end_coord])
    return np.array(indices)
572
573def visualize_single_root_structure(skeleton, branch_data, title='Root Structure',
574                                   dilate_iterations=2, zoom_to_content=True, 
575                                   padding=50, trunk_path=None,
576                                   path_edge_width=3, path_node_size=10,
577                                   show_edge_dots=True, edge_dot_spacing=50,
578                                   show_cumulative_length=True,
579                                   output_path=None):
580    """Visualize a single root structure with trunk path and cumulative length measurements.
581    
582    Displays a root skeleton with the main trunk path highlighted in blue, showing nodes
583    along the path with gradient coloring (red to yellow) and optional cumulative length
584    labels at regular intervals. The path follows the actual skeleton pixels rather than
585    straight-line connections between nodes.
586    
587    Args:
588        skeleton (np.ndarray): Binary skeleton array (H, W) for a single root structure.
589        branch_data (pd.DataFrame): DataFrame containing branch information with columns:
590            'node-id-src', 'node-id-dst', 'image-coord-src-0', 'image-coord-src-1',
591            'image-coord-dst-0', 'image-coord-dst-1', 'branch-distance'.
592        title (str, optional): Title for the visualization. Defaults to 'Root Structure'.
593        dilate_iterations (int, optional): Number of morphological dilation iterations 
594            to thicken skeleton for visibility. Defaults to 2.
595        zoom_to_content (bool, optional): Whether to zoom to the bounding box of the 
596            root structure. Defaults to True.
597        padding (int, optional): Padding in pixels around the zoomed content. 
598            Defaults to 50.
599        trunk_path (list, optional): List of tuples from find_farthest_endpoint_path:
600            [(node_id, branch_idx, distance, vertical), ...]. If None, only skeleton
601            is shown. Defaults to None.
602        path_edge_width (float, optional): Line width for trunk path edges. 
603            Defaults to 3.
604        path_node_size (float, optional): Marker size for nodes along trunk path. 
605            Defaults to 10.
606        show_edge_dots (bool, optional): Whether to show cyan dots along the trunk path.
607            Defaults to True.
608        edge_dot_spacing (int, optional): Spacing in pixels between edge dots. Minimum
609            spacing is 50 pixels to prevent overcrowding. Defaults to 50.
610        show_cumulative_length (bool, optional): Whether to display yellow labels showing
611            cumulative length at regular intervals along the path. Defaults to True.
612        output_path (str or Path): path to save image
613    
614    Returns:
615        None: Displays matplotlib figure showing the root structure visualization.
616    
617    Notes:
618        - Trunk path is drawn in blue following actual skeleton pixels
619        - Cyan dots mark regular intervals along the path
620        - Nodes are colored with autumn colormap (red=start, yellow=end)
621        - Yellow labels show cumulative length every other dot
622        - Lime green label shows total length at the final node
623        - Node IDs are displayed in white text boxes offset from nodes
624    
625    Examples:
626        >>> skeleton = structures['labeled_skeleton'] == root_label
627        >>> branch_data = structures['roots'][root_label]['branch_data']
628        >>> path = find_farthest_endpoint_path(branch_data, top_node, direction='down')
629        >>> visualize_single_root_structure(
630        ...     skeleton, branch_data, 
631        ...     title='Root 35 Structure',
632        ...     trunk_path=path,
633        ...     edge_dot_spacing=100
634        ... )
635    """
636    from scipy.ndimage import binary_dilation
637    
638    # Thicken skeleton for display
639    thick_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
640    skeleton_uint8 = (thick_skeleton * 255).astype(np.uint8)
641    
642    # Get all node coordinates
643    all_node_coords = {}
644    for idx, row in branch_data.iterrows():
645        all_node_coords[row['node-id-src']] = (row['image-coord-src-0'], 
646                                                row['image-coord-src-1'])
647        all_node_coords[row['node-id-dst']] = (row['image-coord-dst-0'], 
648                                                row['image-coord-dst-1'])
649    
650    # Get trunk path nodes
651    path_nodes = []
652    path_branch_indices = []
653    if trunk_path:
654        for node, branch_idx, _, _ in trunk_path:
655            path_nodes.append(int(node))
656            if branch_idx is not None:
657                path_branch_indices.append(branch_idx)
658    
659    # Bounding box
660    if zoom_to_content and all_node_coords:
661        rows = [coord[0] for coord in all_node_coords.values()]
662        cols = [coord[1] for coord in all_node_coords.values()]
663        row_min, row_max = min(rows), max(rows)
664        col_min, col_max = min(cols), max(cols)
665    else:
666        zoom_to_content = False
667    
668    # Plot
669    fig, ax = plt.subplots(figsize=(12, 12))
670    ax.imshow(skeleton_uint8, cmap='gray')
671    
672    # Draw trunk path and collect cumulative lengths
673    all_path_pixels = []
674    cumulative_lengths = []  # Store (pixel_index, cumulative_length)
675    current_length = 0.0
676    total_length = 0.0
677    
678    if trunk_path:
679        for i, (node, branch_idx, distance, _) in enumerate(trunk_path):
680            if branch_idx is not None:
681                branch_row = branch_data.iloc[branch_idx]
682                
683                # Get actual skeleton pixels for this branch
684                path_pixels = get_skeleton_path_from_branch(skeleton, branch_row)
685                
686                # Track cumulative length at start of this segment
687                pixel_start_idx = len(np.vstack(all_path_pixels)) if all_path_pixels else 0
688                cumulative_lengths.append((pixel_start_idx, current_length))
689                
690                all_path_pixels.append(path_pixels)
691                
692                # Add this segment's length
693                current_length += distance
694                total_length += distance
695                
696                # Draw the path in blue
697                ax.plot(path_pixels[:, 1], path_pixels[:, 0], '-',
698                       color='blue', linewidth=path_edge_width,
699                       alpha=0.8, zorder=2)
700        
701        # Add final length
702        if all_path_pixels:
703            final_idx = len(np.vstack(all_path_pixels))
704            cumulative_lengths.append((final_idx, current_length))
705        
706        # Draw dots and labels
707        if show_edge_dots and all_path_pixels:
708            # Concatenate all segments
709            combined_path = np.vstack(all_path_pixels)
710            
711            # Ensure minimum spacing of 50px
712            actual_spacing = max(edge_dot_spacing, 50)
713            sampled_indices = np.arange(0, len(combined_path), actual_spacing)
714            
715            # Draw dots
716            sampled_pixels = combined_path[sampled_indices]
717            ax.plot(sampled_pixels[:, 1], sampled_pixels[:, 0], 'o',
718                   color='cyan', markersize=4, alpha=0.9, zorder=2.5,
719                   markeredgecolor='blue', markeredgewidth=0.5)
720            
721            # Add cumulative length labels
722            if show_cumulative_length:
723                for sample_idx in sampled_indices[::2]:  # Label every other dot to reduce clutter
724                    if sample_idx < len(combined_path):
725                        # Find cumulative length at this pixel
726                        cum_length = 0.0
727                        for seg_idx, (pixel_idx, length) in enumerate(cumulative_lengths[:-1]):
728                            next_pixel_idx = cumulative_lengths[seg_idx + 1][0]
729                            if pixel_idx <= sample_idx < next_pixel_idx:
730                                # Interpolate within this segment
731                                segment_progress = (sample_idx - pixel_idx) / max(next_pixel_idx - pixel_idx, 1)
732                                next_length = cumulative_lengths[seg_idx + 1][1]
733                                cum_length = length + segment_progress * (next_length - length)
734                                break
735                        else:
736                            cum_length = current_length
737                        
738                        row, col = combined_path[sample_idx]
739                        ax.text(col + 15, row, f'{cum_length:.0f}px',
740                               color='yellow', fontsize=8, fontweight='bold',
741                               ha='left', va='center',
742                               bbox=dict(boxstyle='round,pad=0.2',
743                                       facecolor='black', alpha=0.6),
744                               zorder=4)
745    
746    # Draw path nodes
747    if trunk_path:
748        for i, node_id in enumerate(path_nodes):
749            if node_id in all_node_coords:
750                row, col = all_node_coords[node_id]
751                color = plt.cm.autumn(i / max(len(path_nodes)-1, 1))
752                ax.plot(col, row, 'o', color=color, markersize=path_node_size,
753                       zorder=3, markeredgecolor='white', markeredgewidth=1.5)
754                
755                # Add node label
756                offset_x = 20
757                offset_y = -5
758                ax.text(col + offset_x, row + offset_y, str(node_id),
759                       color='white', fontsize=9, fontweight='bold',
760                       ha='left', va='center',
761                       bbox=dict(boxstyle='round,pad=0.3',
762                                facecolor='black', alpha=0.7),
763                       zorder=4)
764                
765                # Add total length at final node
766                if i == len(path_nodes) - 1 and show_cumulative_length:
767                    ax.text(col + 15, row + 20, f'{total_length:.1f}px',
768                           color='lime', fontsize=10, fontweight='bold',
769                           ha='left', va='top',
770                           bbox=dict(boxstyle='round,pad=0.3',
771                                   facecolor='black', alpha=0.8),
772                           zorder=5)
773    
774    if zoom_to_content:
775        ax.set_xlim(col_min - padding, col_max + padding)
776        ax.set_ylim(row_max + padding, row_min - padding)
777    
778    ax.set_title(title, fontsize=14)
779    ax.axis('off')
780    plt.tight_layout()
781    if output_path:
782        plt.savefig(str(output_path))
783    plt.show()
784
785
def visualize_root_lengths(structures, top_node_results, labeled_shoots,
                           zoom_to_content=True, padding=50,
                           show_detailed_roots=False, detail_viz_kwargs=None,
                           root_thickness=3):
    """Visualize matched roots with their calculated lengths displayed.

    Creates an overview visualization showing all matched root structures
    color-coded by their associated shoot regions, with length measurements
    labeled. Optionally displays detailed individual root visualizations
    below the overview.

    Args:
        structures (dict): Dictionary from extract_root_structures containing:
            - 'labeled_skeleton': Array with labeled skeleton structures
            - 'roots': Dictionary of root data by label
        top_node_results (dict): Dictionary from find_top_nodes_from_shoot
            mapping root_label -> dict with 'branch_data', 'shoot_label' and
            'top_nodes'; ``top_nodes[0][0]`` is used as the start node of the
            length path.
        labeled_shoots (np.ndarray): Labeled array from label_shoot_regions
            where each shoot region has a unique integer label.
        zoom_to_content (bool, optional): If True, crop to the bounding box of
            all drawn content. Defaults to True.
        padding (int, optional): Pixels to add around the bounding box when
            zooming. Defaults to 50.
        show_detailed_roots (bool, optional): If True, call
            visualize_single_root_structure for every successfully measured
            root after the overview. Defaults to False.
        detail_viz_kwargs (dict, optional): Extra keyword arguments forwarded
            to visualize_single_root_structure, e.g. 'dilate_iterations',
            'path_edge_width', 'path_node_size', 'show_edge_dots',
            'edge_dot_spacing', 'show_cumulative_length' (see that function
            for their defaults). Defaults to None (no extra arguments).
        root_thickness (int, optional): Number of binary-dilation passes used
            to thicken each root skeleton in the overview; values < 1 draw
            the raw one-pixel skeleton. Defaults to 3.

    Returns:
        None: Displays matplotlib figures showing root length visualizations.

    Notes:
        - Each root is drawn in the color of its shoot.
        - Roots are thickened with binary dilation for visibility.
        - Length labels are placed at the centroid of the (thickened) root,
          as white text on a semi-transparent black box.
        - Detailed views are indexed (idx=0, idx=1, etc.) in display order.
        - Roots whose length measurement raises an exception are drawn in
          gray with no length label, and a warning is emitted.

    Examples:
        >>> # Basic usage - overview only
        >>> visualize_root_lengths(structures, top_node_results, labeled_shoots)

        >>> # With detailed individual visualizations
        >>> visualize_root_lengths(
        ...     structures, top_node_results, labeled_shoots,
        ...     show_detailed_roots=True,
        ...     detail_viz_kwargs={
        ...         'edge_dot_spacing': 100,
        ...         'path_edge_width': 2,
        ...         'show_cumulative_length': True
        ...     }
        ... )
    """
    import warnings

    if detail_viz_kwargs is None:
        detail_viz_kwargs = {}

    labeled_skeleton = structures['labeled_skeleton']
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    colors = [
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [255, 128, 0],    # Orange
        [128, 0, 255],    # Purple
    ]
    failed_color = [128, 128, 128]  # gray for roots whose length could not be measured

    # Draw shoots
    for shoot_label in range(1, labeled_shoots.max() + 1):
        vis_img[labeled_shoots == shoot_label] = colors[(shoot_label - 1) % len(colors)]

    root_info = []     # (x, y, root_label, shoot_label, length) for each measured root
    root_details = []  # inputs for the optional per-root detail figures
    dilation_structure = ndimage.generate_binary_structure(2, 1)

    for root_label, result in top_node_results.items():
        branch_data = result['branch_data']
        shoot_label = result['shoot_label']
        color = colors[(shoot_label - 1) % len(colors)]

        raw_skeleton = labeled_skeleton == root_label
        # binary_dilation treats iterations < 1 as "repeat until nothing
        # changes", which would flood the image, so only dilate when positive.
        if root_thickness > 0:
            root_skeleton = ndimage.binary_dilation(
                raw_skeleton,
                structure=dilation_structure,
                iterations=root_thickness
            )
        else:
            root_skeleton = raw_skeleton

        try:
            # Indexing top_nodes lives inside the try so a root without a top
            # node is drawn as failed instead of aborting the whole figure.
            top_node = result['top_nodes'][0][0]
            path = find_farthest_endpoint_path(
                branch_data, top_node,
                direction='down', use_smart_scoring=True,
                verbose=False
            )
            root_length = calculate_skeleton_length_px(path)
        except Exception as exc:  # best-effort: mark the root as failed and keep going
            warnings.warn(
                f"Root {root_label}: length measurement failed ({exc!r}); drawn in gray.",
                stacklevel=2,
            )
            vis_img[root_skeleton] = failed_color
            continue

        vis_img[root_skeleton] = color

        # Label position: centroid of the drawn root (skip empty masks to avoid NaN)
        y_coords, x_coords = np.where(root_skeleton)
        if y_coords.size:
            root_info.append((x_coords.mean(), y_coords.mean(),
                              root_label, shoot_label, root_length))

        if show_detailed_roots:
            root_details.append({
                'root_label': root_label,
                'skeleton': raw_skeleton,
                'branch_data': branch_data,
                'path': path,
                'shoot_label': shoot_label
            })

    # Zoom if requested
    if zoom_to_content:
        rows, cols = np.where(np.any(vis_img > 0, axis=2))
        if rows.size:
            # +1 because slice ends are exclusive; otherwise padding=0 would
            # drop the last row/column of content.
            y_min = max(0, rows.min() - padding)
            y_max = min(h, rows.max() + padding + 1)
            x_min = max(0, cols.min() - padding)
            x_max = min(w, cols.max() + padding + 1)
            vis_img = vis_img[y_min:y_max, x_min:x_max]
            # Shift label coordinates into the cropped frame
            root_info = [(x - x_min, y - y_min, rl, sl, length)
                         for x, y, rl, sl, length in root_info]

    # Plot overview
    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)

    for x, y, _, _, length in root_info:
        ax.text(x, y, f'{length:.1f}px',
                color='white', fontsize=10, weight='bold',
                ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

    ax.set_title('Root Lengths (pixels)', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    # Display detailed visualizations if requested
    if show_detailed_roots and root_details:
        print(f"\n{'='*60}")
        print(f"Detailed Root Visualizations ({len(root_details)} roots)")
        print(f"{'='*60}\n")

        for idx, detail in enumerate(root_details):
            print(f"Displaying Root {detail['root_label']} (idx={idx})...")

            visualize_single_root_structure(
                skeleton=detail['skeleton'],
                branch_data=detail['branch_data'],
                title=f"Root {detail['root_label']} (idx={idx}) - Shoot {detail['shoot_label']}",
                trunk_path=detail['path'],
                **detail_viz_kwargs
            )
def visualize_skeleton_with_nodes( skeleton, show_branch_labels=False, row_range=None, padding=50, dilate_iterations=3):
 12def visualize_skeleton_with_nodes(skeleton, show_branch_labels=False, row_range=None, padding=50, dilate_iterations=3):
 13    """
 14    Visualize skeleton with nodes and optionally branch labels.
 15
 16    Args:
 17        skeleton (skeleton numpy array): object to show
 18        show_branch_labels (bool): add labels
 19        row_range(tuple): rows to zoom in on
 20        padding(int): add additional pixels to make skeleton easier to view
 21        dilate_iterations(int): number of rounds to dilate and increase size
 22
 23    Returns:
 24        branch_data (pandas table)
 25    """
 26    
 27    # Thicken skeleton for visibility
 28    if dilate_iterations > 0:
 29        thick_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
 30        skeleton_uint8 = (thick_skeleton * 255).astype(np.uint8)
 31    else:
 32        skeleton_uint8 = (skeleton * 255).astype(np.uint8)
 33    
 34    # Find bounding box
 35    rows, cols = np.where(skeleton)
 36    
 37    if row_range is not None:
 38        row_min, row_max = row_range
 39        mask = (rows >= row_min) & (rows <= row_max)
 40        if np.any(mask):
 41            col_min, col_max = cols[mask].min(), cols[mask].max()
 42        else:
 43            col_min, col_max = cols.min(), cols.max()
 44    else:
 45        row_min, row_max = rows.min(), rows.max()
 46        col_min, col_max = cols.min(), cols.max()
 47    
 48    # Get branch data and nodes
 49    skeleton_obj = Skeleton(skeleton)
 50    branch_data = summarize(skeleton_obj)
 51    
 52    # Get node coordinates
 53    node_coords = {}
 54    for idx, row in branch_data.iterrows():
 55        node_coords[row['node-id-src']] = (row['coord-src-0'], row['coord-src-1'])
 56        node_coords[row['node-id-dst']] = (row['coord-dst-0'], row['coord-dst-1'])
 57    
 58    # Plot
 59    fig, ax = plt.subplots(figsize=(14, 12))
 60    ax.imshow(skeleton_uint8, cmap='gray')
 61    
 62    # Track occupied regions (for collision detection)
 63    occupied_regions = []  # List of (row, col, radius) tuples
 64    
 65    # First pass: plot node markers
 66    visible_nodes = []
 67    for node_id, (node_row, node_col) in node_coords.items():
 68        if (row_min - padding <= node_row <= row_max + padding and 
 69            col_min - padding <= node_col <= col_max + padding):
 70            visible_nodes.append((node_id, node_row, node_col))
 71            ax.plot(node_col, node_row, 'ro', markersize=12, zorder=3)
 72            # Mark node position as occupied
 73            occupied_regions.append((node_row, node_col, 15))  # Small radius for node dot itself
 74    
 75    # Second pass: place node labels with collision avoidance
 76    for node_id, node_row, node_col in visible_nodes:
 77        # Try different offset positions for node label
 78        offsets = [(40, 0), (-40, 0), (0, 40), (0, -40), (30, 30), (-30, -30), (30, -30), (-30, 30)]
 79        best_offset = (40, 0)
 80        min_collision = float('inf')
 81        
 82        for offset_col, offset_row in offsets:
 83            test_col = node_col + offset_col
 84            test_row = node_row + offset_row
 85            
 86            # Check collision with all occupied regions
 87            max_collision = 0
 88            for occ_row, occ_col, radius in occupied_regions:
 89                dist = np.sqrt((test_row - occ_row)**2 + (test_col - occ_col)**2)
 90                if dist < radius:
 91                    max_collision = max(max_collision, radius - dist)
 92            
 93            if max_collision < min_collision:
 94                min_collision = max_collision
 95                best_offset = (offset_col, offset_row)
 96                if max_collision == 0:
 97                    break
 98        
 99        label_col = node_col + best_offset[0]
100        label_row = node_row + best_offset[1]
101        
102        # Draw leader line
103        ax.plot([node_col, label_col], [node_row, label_row], 
104               'y--', linewidth=1, alpha=0.6, zorder=2)
105        
106        # Draw label
107        ax.text(label_col, label_row, str(int(node_id)), color='yellow', fontsize=11, 
108               bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
109               ha='center', zorder=4)
110        
111        # Mark label area as occupied
112        occupied_regions.append((label_row, label_col, 45))
113    
114    # Branch labels with collision avoidance
115    if show_branch_labels:
116        branch_positions = []
117        for idx, row in branch_data.iterrows():
118            mid_row = (row['coord-src-0'] + row['coord-dst-0']) / 2
119            mid_col = (row['coord-src-1'] + row['coord-dst-1']) / 2
120            
121            if (row_min - padding <= mid_row <= row_max + padding and 
122                col_min - padding <= mid_col <= col_max + padding):
123                branch_positions.append((idx, mid_row, mid_col))
124        
125        for idx, mid_row, mid_col in branch_positions:
126            # Try different offset positions
127            offsets = [(60, 0), (-60, 0), (0, 60), (0, -60), (45, 45), (-45, -45), 
128                      (120, 0), (-120, 0), (0, 120), (0, -120)]
129            best_offset = (60, 0)
130            min_collision = float('inf')
131            
132            for offset_col, offset_row in offsets:
133                test_col = mid_col + offset_col
134                test_row = mid_row + offset_row
135                
136                max_collision = 0
137                for occ_row, occ_col, radius in occupied_regions:
138                    dist = np.sqrt((test_row - occ_row)**2 + (test_col - occ_col)**2)
139                    if dist < radius:
140                        max_collision = max(max_collision, radius - dist)
141                
142                if max_collision < min_collision:
143                    min_collision = max_collision
144                    best_offset = (offset_col, offset_row)
145                    if max_collision == 0:
146                        break
147            
148            label_col = mid_col + best_offset[0]
149            label_row = mid_row + best_offset[1]
150            
151            # Draw leader line
152            ax.plot([mid_col, label_col], [mid_row, label_row], 
153                   'c--', linewidth=1, alpha=0.5, zorder=1)
154            
155            # Draw label
156            ax.text(label_col, label_row, f"B{idx}", color='cyan', fontsize=9, 
157                   bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7),
158                   ha='center', zorder=2)
159            
160            occupied_regions.append((label_row, label_col, 40))
161    
162    # Set zoom limits
163    ax.set_xlim(col_min - padding, col_max + padding)
164    ax.set_ylim(row_max + padding, row_min - padding)
165    
166    title = f'Skeleton with {len(node_coords)} nodes, {branch_data.shape[0]} branches'
167    if row_range:
168        title += f' (rows {row_range[0]}-{row_range[1]})'
169    ax.set_title(title)
170    
171    plt.show()
172    
173    return branch_data

Visualize skeleton with nodes and optionally branch labels.

Arguments:
  • skeleton (numpy array): binary skeleton image to display and analyze
  • show_branch_labels (bool): if True, also label each branch midpoint as "B<index>"
  • row_range (tuple): (row_min, row_max) range of rows to zoom in on
  • padding (int): extra pixels of margin around the zoomed region
  • dilate_iterations (int): number of dilation passes used to thicken the skeleton for display (0 = no thickening)
Returns:

branch_data (pandas table)

def visualize_trunk_path( skeleton, branch_data, trunk_path, title='Trunk Path', dilate_iterations=3, zoom_to_path=True, padding=100):
def visualize_trunk_path(skeleton, branch_data, trunk_path, title='Trunk Path',
                         dilate_iterations=3, zoom_to_path=True, padding=100):
    """
    Visualize the trunk path overlaid on the skeleton image.

    Draws every skeleton node as a faint gray dot, the trunk path as red
    segments between consecutive path nodes, and the path nodes themselves
    colored along the ``autumn`` colormap (first to last) with ID labels.
    A short summary is printed afterwards.

    Args:
        skeleton: Binary skeleton array
        branch_data: DataFrame with branch information
        trunk_path: List of tuples from trace_vertical_path [(node, branch_idx, slope, vertical), ...]
        title: Plot title
        dilate_iterations: Number of dilation passes used to thicken the
            skeleton for display; values < 1 show it unthickened
        zoom_to_path: Whether to zoom to the trunk path region (ignored when
            no path node has known coordinates)
        padding: Padding around zoom region

    Returns:
        Sum of 'branch-distance' over the path's branch indices (0 for a
        path without branches).
    """
    # binary_dilation treats iterations < 1 as "repeat until stable", which
    # would flood the image, so only thicken for a positive count.
    if dilate_iterations > 0:
        display_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_skeleton = np.asarray(skeleton, dtype=bool)
    skeleton_uint8 = (display_skeleton * 255).astype(np.uint8)

    # Nodes and branch indices along the path
    path_nodes = [int(node) for node, _, _, _ in trunk_path]
    path_branches = [idx for _, idx, _, _ in trunk_path if idx is not None]

    # Coordinates of every node in branch_data
    all_node_coords = {}
    for _, row in branch_data.iterrows():
        all_node_coords[row['node-id-src']] = (row['coord-src-0'], row['coord-src-1'])
        all_node_coords[row['node-id-dst']] = (row['coord-dst-0'], row['coord-dst-1'])

    path_coords = [all_node_coords[n] for n in path_nodes if n in all_node_coords]

    _, ax = plt.subplots(figsize=(14, 14))
    ax.imshow(skeleton_uint8, cmap='gray')

    # All nodes (small, gray, semi-transparent)
    for row, col in all_node_coords.values():
        ax.plot(col, row, 'o', color='gray', markersize=4, alpha=0.3)

    # Trunk path connections first, so they sit behind the node markers
    for node1, node2 in zip(path_nodes, path_nodes[1:]):
        if node1 in all_node_coords and node2 in all_node_coords:
            r1, c1 = all_node_coords[node1]
            r2, c2 = all_node_coords[node2]
            ax.plot([c1, c2], [r1, r2], 'r-', linewidth=4, alpha=0.8, zorder=2)

    # Trunk path nodes (large, colored from start to end)
    for i, node_id in enumerate(path_nodes):
        if node_id in all_node_coords:
            row, col = all_node_coords[node_id]
            color = plt.cm.autumn(i / max(len(path_nodes) - 1, 1))
            ax.plot(col, row, 'o', color=color, markersize=16, zorder=3,
                    markeredgecolor='white', markeredgewidth=2)
            ax.text(col + 15, row, str(node_id), color='cyan', fontsize=12,
                    fontweight='bold',
                    bbox=dict(boxstyle='round', facecolor='black', alpha=0.8),
                    zorder=4)

    # Zoom only when some path node has coordinates; this also covers an
    # empty trunk_path, which previously referenced unset bounds.
    if zoom_to_path and path_coords:
        path_rows = [r for r, _ in path_coords]
        path_cols = [c for _, c in path_coords]
        ax.set_xlim(min(path_cols) - padding, max(path_cols) + padding)
        ax.set_ylim(max(path_rows) + padding, min(path_rows) - padding)

    # Title with (at most the first 10) path nodes
    path_str = " → ".join(map(str, path_nodes[:10]))
    if len(path_nodes) > 10:
        path_str += "..."
    ax.set_title(f'{title}\nPath: {path_str}', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    # Summary
    print(f"Trunk path: {len(path_nodes)} nodes, {len(path_branches)} branches")
    if path_nodes:
        print(f"Start node: {path_nodes[0]}, End node: {path_nodes[-1]}")

    # NOTE(review): iloc treats each branch index as a row *position*; this
    # equals the DataFrame label only for a default RangeIndex. Confirm what
    # trace_vertical_path returns.
    total_length = sum(branch_data.iloc[idx]['branch-distance']
                       for idx in path_branches)
    print(f"Total trunk length: {total_length:.2f} pixels")

    return total_length

Visualize the trunk path overlaid on the skeleton image.

Arguments:
  • skeleton: Binary skeleton array
  • branch_data: DataFrame with branch information
  • trunk_path: List of tuples from trace_vertical_path [(node, branch_idx, slope, vertical), ...]
  • title: Plot title
  • dilate_iterations: Number of iterations to thicken skeleton
  • zoom_to_path: Whether to zoom to the trunk path region
  • padding: Padding around zoom region
def visualize_thickened_skeleton(binary_mask, dilate_iterations=3, figsize=(12, 8), show_labels=True):
def visualize_thickened_skeleton(binary_mask, dilate_iterations=3, figsize=(12, 8), show_labels=True):
    """
    Skeletonize, label, thicken, and visualize a binary mask.

    Args:
        binary_mask: Binary numpy array (bool or 0/255) or path to an image
            file, which is read as grayscale
        dilate_iterations: Number of dilation passes used to thicken the
            skeleton for display; values < 1 show it unthickened
        figsize: Figure size tuple
        show_labels: Whether to draw each component's label ID 50 px below
            its bottom-most pixel

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels) as returned by
            label_skeleton

    Raises:
        FileNotFoundError: If binary_mask is a path that cannot be read.
    """
    # Load from file if a path was provided
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        # NOTE(review): cv2 is not imported explicitly in this module; it
        # presumably arrives via `from library.root_analysis import *`. Verify.
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising
        if binary_mask is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    # Thicken for visibility; binary_dilation treats iterations < 1 as
    # "repeat until stable", which would flood the image.
    if dilate_iterations > 0:
        display_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_skeleton = np.asarray(skeleton, dtype=bool)
    thick_skeleton_uint8 = (display_skeleton * 255).astype(np.uint8)

    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(thick_skeleton_uint8, cmap='gray')
    ax.set_title(f'{len(unique_labels)} skeleton(s) (thickened for visibility)')
    ax.axis('off')

    # Label each structure below its bottom-most pixel, centered on its mean column
    if show_labels and len(unique_labels) > 0:
        for label_id in unique_labels:
            rows, cols = np.where(labeled_skeleton == label_id)
            if len(rows) == 0:
                continue

            label_row = rows.max() + 50  # 50 pixels below the bottom-most point
            label_col = int(np.mean(cols))

            ax.text(label_col, label_row, str(int(label_id)),
                    color='yellow', fontsize=14, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='red',
                              edgecolor='white', linewidth=2, alpha=0.8),
                    ha='center', va='center')

    plt.show()

    print(f"Found {len(unique_labels)} connected structure(s)")
    print(f"Total skeleton pixels: {np.sum(skeleton)}")

    return skeleton, labeled_skeleton, unique_labels

Skeletonize, label, thicken, and visualize a binary mask.

Arguments:
  • binary_mask: Binary numpy array (bool or 0/255) or path to image file
  • dilate_iterations: Number of dilation iterations to thicken skeleton
  • figsize: Figure size tuple
  • show_labels: Whether to show component labels on the image
Returns:

tuple: (skeleton, labeled_skeleton, unique_labels)

def visualize_assignments( shoot_mask, root_mask, structures, assignments, ref_points, size=(16, 8)):
def visualize_assignments(shoot_mask, root_mask, structures, assignments, ref_points, size=(16, 8),
                          order_label_offset=300):
    """Visualize shoot-root assignments with ordered labels and root labels.

    Each shoot and its assigned root (dilated twice for visibility) are drawn
    in the same color. A circled order number (``order + 1``) is drawn
    ``order_label_offset`` pixels below the bottom-most pixel of the assigned
    root, or at the shoot centroid when no root is assigned. The legend lists
    ``S<order+1>->R<root_label> (<score>)`` or ``S<order+1>->None``.

    Args:
        shoot_mask: Binary mask of shoot regions.
        root_mask: Binary mask of root regions (unused; root pixels are taken
            from ``structures['roots']``).
        structures: Dictionary from extract_root_structures containing
            'roots', where each root entry has a 'mask'.
        assignments: Dictionary from assign_roots_to_shoots_greedy mapping
            shoot label -> dict with 'order', 'root_label' and, when a root is
            assigned, 'score_dict' (holding 'score').
        ref_points: Dictionary from get_shoot_reference_points_multi;
            ``ref_points[shoot_label]['centroid']`` is a (row, col) pair.
        size: Figure size tuple (width, height).
        order_label_offset: Pixels between the bottom of an assigned root and
            its order label. Defaults to 300.

    Returns:
        None: Displays the visualization.
    """
    labeled_shoots, _ = ndimage.label(shoot_mask)

    h, w = shoot_mask.shape
    rgb_img = np.zeros((h, w, 3), dtype=np.uint8)

    colors = [
        [255, 0, 0],      # Red
        [255, 165, 0],    # Orange
        [255, 255, 0],    # Yellow
        [0, 255, 255],    # Cyan
        [255, 0, 255]     # Magenta
    ]

    legend_elements = []

    # Draw each shoot and its assigned root in a shared color
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']
        # Cycle the palette so more than len(colors) shoots don't raise IndexError
        color = colors[order % len(colors)]

        rgb_img[labeled_shoots == shoot_label] = color

        if root_label is not None:
            assigned_root = structures['roots'][root_label]['mask']
            rgb_img[ndimage.binary_dilation(assigned_root, iterations=2)] = color
            score = info['score_dict']['score']
            legend_label = f"S{order+1}->R{root_label} ({score:.2f})"
        else:
            legend_label = f"S{order+1}->None"

        legend_elements.append(
            mpatches.Patch(color=np.array(color)/255, label=legend_label)
        )

    _, ax = plt.subplots(figsize=size)
    ax.imshow(rgb_img)

    def draw_order_label(x, y, order):
        """Draw the circled 1-based order number at image position (x, y)."""
        ax.text(x, y, str(order + 1),
                color='white', fontsize=20, fontweight='bold',
                ha='center', va='center',
                bbox=dict(boxstyle='circle', facecolor='black', alpha=0.7))

    # Order labels: below the root's bottom-most pixel (largest row index),
    # or at the shoot centroid when the shoot has no assigned root
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']

        if root_label is None:
            cy, cx = ref_points[shoot_label]['centroid']
            draw_order_label(cx, cy, order)
            continue

        y_coords, x_coords = np.where(structures['roots'][root_label]['mask'])
        if len(y_coords) == 0:
            continue
        bottom_idx = np.argmax(y_coords)
        draw_order_label(x_coords[bottom_idx],
                         y_coords[bottom_idx] + order_label_offset, order)

    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
    ax.set_title('Shoot-Root Assignments (Shoot# = Left->Right Order)', fontsize=12)
    ax.axis('off')

    plt.tight_layout()
    plt.show()

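Visualize shoot-root assignments: each shoot and its assigned root share a color, are numbered by left-to-right order (1-5), and are listed with their root labels in the legend.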

Arguments:
  • shoot_mask: Binary mask of shoot regions
  • root_mask: Binary mask of root regions (unused; root pixels are taken from structures['roots'])
  • structures: Dictionary from extract_root_structures containing 'roots' with root data
  • assignments: Dictionary from assign_roots_to_shoots_greedy with shoot assignments
  • ref_points: Dictionary from get_shoot_reference_points_multi with shoot centroids
  • size: Figure size tuple (width, height)
Returns:

None: Displays the visualization

def visualize_root_shoot_matches( structures, labeled_shoots, root_matches, zoom_to_content=True, padding=50, root_thickness=3):
def visualize_root_shoot_matches(structures, labeled_shoots, root_matches,
                                 zoom_to_content=True, padding=50,
                                 root_thickness=3):
    """
    Visualize root-to-shoot matches with color coding.

    Shoot regions are painted from a fixed 8-color palette (labels beyond 8
    wrap around), and every root in ``structures['unique_labels']`` is painted
    in the color of the shoot it was matched to. Roots whose match is falsy
    (``None``) or that are missing from ``root_matches`` are painted gray and
    counted as unmatched, so the legend and the printed summary agree with
    what is drawn.

    Args:
        structures: Dict from extract_root_structures. Reads
            'labeled_skeleton' (2-D labeled array) and 'unique_labels'
            (root labels to draw).
        labeled_shoots: Labeled shoot array with the same shape as
            'labeled_skeleton'; 0 is background, 1..K are shoots.
        root_matches: Dict from match_roots_to_shoots mapping a root label to
            None (unmatched) or a dict with a 'shoot_label' key.
        zoom_to_content: If True, crop to the bounding box of all content.
        padding: Pixels to add around the bounding box when cropping.
        root_thickness: Dilation iterations (4-connected structuring element)
            applied to root skeletons for visibility; 0 or less draws the
            raw skeleton.

    Returns:
        None. Shows a matplotlib figure and prints a match summary.
    """
    labeled_skeleton = structures['labeled_skeleton']
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Distinct colors for shoots; shoot labels beyond 8 wrap around.
    palette = np.array([
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [255, 128, 0],    # Orange
        [128, 0, 255],    # Purple
    ], dtype=np.uint8)
    unmatched_color = np.array([128, 128, 128], dtype=np.uint8)

    # Paint every shoot in one vectorized pass instead of one full-image mask
    # per label (label 0 is background). The int cast keeps the "- 1" valid
    # for boolean or unsigned label arrays.
    shoot_pixels = labeled_shoots > 0
    shoot_ids = labeled_shoots[shoot_pixels].astype(np.int64)
    vis_img[shoot_pixels] = palette[(shoot_ids - 1) % len(palette)]

    # Draw roots in their matched shoot's color, gray when unmatched.
    structure = ndimage.generate_binary_structure(2, 1)
    matched_shoots = set()
    matched_count = 0
    unmatched_count = 0
    for root_label in structures['unique_labels']:
        root_pixels = labeled_skeleton == root_label
        if root_thickness > 0:
            root_pixels = ndimage.binary_dilation(
                root_pixels,
                structure=structure,
                iterations=root_thickness,
            )

        match = root_matches.get(root_label)
        if match:
            shoot_label = match['shoot_label']
            matched_shoots.add(shoot_label)
            matched_count += 1
            vis_img[root_pixels] = palette[(shoot_label - 1) % len(palette)]
        else:
            # Covers both an explicit None and a root absent from the dict.
            unmatched_count += 1
            vis_img[root_pixels] = unmatched_color

    if zoom_to_content:
        rows, cols = np.nonzero(vis_img.any(axis=2))
        if rows.size:
            # Slice ends are exclusive: +1 keeps the last content row/column.
            y_min = max(0, rows.min() - padding)
            y_max = min(h, rows.max() + 1 + padding)
            x_min = max(0, cols.min() - padding)
            x_max = min(w, cols.max() + 1 + padding)
            vis_img = vis_img[y_min:y_max, x_min:x_max]

    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)
    ax.set_title('Root-Shoot Matches (same color = matched)', fontsize=14)
    ax.axis('off')

    # Legend reflects exactly the roots that were drawn.
    legend_elements = [
        mpatches.Patch(color=palette[(label - 1) % len(palette)] / 255.0,
                       label=f'Shoot/Root {label}')
        for label in sorted(matched_shoots)
    ]
    if unmatched_count > 0:
        legend_elements.append(
            mpatches.Patch(color=[0.5, 0.5, 0.5],
                           label=f'Unmatched ({unmatched_count})')
        )
    if legend_elements:
        ax.legend(handles=legend_elements, loc='lower right', fontsize=10)

    plt.tight_layout()
    plt.show()

    print("\nMatch Summary:")
    print(f"Total root structures: {matched_count + unmatched_count}")
    print(f"Matched: {matched_count}")
    print(f"Unmatched (noise): {unmatched_count}")

Visualize root-to-shoot matches with color coding.

Arguments:
  • structures: Dict from extract_root_structures
  • labeled_shoots: Labeled shoot array
  • root_matches: Dict from match_roots_to_shoots
  • zoom_to_content: If True, zoom to bounding box of all content
  • padding: Pixels to add around bounding box
  • root_thickness: Pixels to dilate root skeletons for visibility
def get_skeleton_path_from_branch(skeleton, branch_row):
def get_skeleton_path_from_branch(skeleton, branch_row):
    """
    Trace the pixel path of one branch along a binary skeleton.

    Routes a minimum-cost path between the branch's source and destination
    image coordinates. Skeleton pixels cost 1 and all other pixels cost
    10000, so the route follows the skeleton wherever it is connected but can
    still bridge small gaps.

    Args:
        skeleton: 2-D binary skeleton array (truthy = skeleton pixel).
        branch_row: Row from a branch_data DataFrame. Reads
            'image-coord-src-0', 'image-coord-src-1', 'image-coord-dst-0'
            and 'image-coord-dst-1'.

    Returns:
        np.ndarray of (row, col) coordinates from source to destination.
        If routing raises, returns only the two endpoints, which plots as a
        straight line.
    """
    from skimage.graph import route_through_array

    start_coord = (int(branch_row['image-coord-src-0']),
                   int(branch_row['image-coord-src-1']))
    end_coord = (int(branch_row['image-coord-dst-0']),
                 int(branch_row['image-coord-dst-1']))

    # Low cost on the skeleton, high (but finite) cost elsewhere.
    cost = np.where(skeleton, 1, 10000)

    try:
        indices, _ = route_through_array(
            cost, start_coord, end_coord, fully_connected=True
        )
    except Exception:
        # Best-effort fallback; unlike a bare except this still lets
        # KeyboardInterrupt/SystemExit propagate.
        return np.array([start_coord, end_coord])
    return np.array(indices)

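Extract the skeleton pixels along a branch by routing a minimum-cost path over the binary skeleton, falling back to the two branch endpoints (a straight line) if routing fails.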

Arguments:
  • skeleton: Binary skeleton array
  • branch_row: Row from branch_data DataFrame
Returns:

Array of (row, col) coordinates along the skeleton path

def visualize_single_root_structure( skeleton, branch_data, title='Root Structure', dilate_iterations=2, zoom_to_content=True, padding=50, trunk_path=None, path_edge_width=3, path_node_size=10, show_edge_dots=True, edge_dot_spacing=50, show_cumulative_length=True, output_path=None):
def _trace_trunk_segments(skeleton, branch_data, trunk_path):
    """Trace each trunk-path branch along the skeleton.

    Args:
        skeleton: Binary skeleton array passed to get_skeleton_path_from_branch.
        branch_data: DataFrame indexed positionally (``iloc``) by each step's
            branch_idx.
        trunk_path: Sequence of (node_id, branch_idx, distance, vertical)
            tuples. Steps whose branch_idx is None contribute no segment.

    Returns:
        (segments, breakpoints). ``segments`` is a list of (row, col) pixel
        arrays in path order. ``breakpoints`` lists
        (pixel_index, cumulative_length) at the start of each segment within
        the concatenated path, followed by a final
        (total_pixels, total_length) entry. It is empty when there are no
        segments.
    """
    segments = []
    breakpoints = []
    pixel_count = 0
    length_so_far = 0.0
    for _node, branch_idx, distance, _vertical in trunk_path:
        if branch_idx is None:
            continue
        pixels = get_skeleton_path_from_branch(skeleton,
                                               branch_data.iloc[branch_idx])
        breakpoints.append((pixel_count, length_so_far))
        segments.append(pixels)
        # A running count avoids re-stacking all previous segments each step.
        pixel_count += len(pixels)
        length_so_far += distance
    if segments:
        breakpoints.append((pixel_count, length_so_far))
    return segments, breakpoints


def visualize_single_root_structure(skeleton, branch_data, title='Root Structure',
                                    dilate_iterations=2, zoom_to_content=True,
                                    padding=50, trunk_path=None,
                                    path_edge_width=3, path_node_size=10,
                                    show_edge_dots=True, edge_dot_spacing=50,
                                    show_cumulative_length=True,
                                    output_path=None):
    """Visualize a single root structure with trunk path and cumulative length measurements.

    Displays a root skeleton with the main trunk path highlighted in blue,
    showing nodes along the path with gradient coloring (red to yellow) and
    optional cumulative length labels at regular intervals. The path follows
    the actual skeleton pixels rather than straight-line connections between
    nodes.

    Args:
        skeleton (np.ndarray): Binary skeleton array (H, W) for a single root structure.
        branch_data (pd.DataFrame): DataFrame containing branch information with columns:
            'node-id-src', 'node-id-dst', 'image-coord-src-0', 'image-coord-src-1',
            'image-coord-dst-0', 'image-coord-dst-1', 'branch-distance'.
        title (str, optional): Title for the visualization. Defaults to 'Root Structure'.
        dilate_iterations (int, optional): Number of morphological dilation iterations
            to thicken the skeleton for display. 0 or less shows the raw skeleton.
            Defaults to 2.
        zoom_to_content (bool, optional): Whether to zoom to the bounding box of the
            root structure's nodes. Defaults to True.
        padding (int, optional): Padding in pixels around the zoomed content.
            Defaults to 50.
        trunk_path (list, optional): List of tuples from find_farthest_endpoint_path:
            [(node_id, branch_idx, distance, vertical), ...]. If None, only the
            skeleton is shown. Defaults to None.
        path_edge_width (float, optional): Line width for trunk path edges.
            Defaults to 3.
        path_node_size (float, optional): Marker size for nodes along trunk path.
            Defaults to 10.
        show_edge_dots (bool, optional): Whether to show cyan dots along the trunk path.
            Defaults to True.
        edge_dot_spacing (int, optional): Spacing between edge dots, counted in
            pixels along the traced path. Values below 50 are raised to 50 to
            prevent overcrowding. Defaults to 50.
        show_cumulative_length (bool, optional): Whether to display yellow labels showing
            cumulative length at regular intervals along the path. Defaults to True.
        output_path (str or Path, optional): If given, the figure is saved here
            before being shown.

    Returns:
        None: Displays matplotlib figure showing the root structure visualization.

    Notes:
        - Trunk path is drawn in blue following actual skeleton pixels
        - Cyan dots mark regular intervals along the path
        - Nodes are colored with autumn colormap (red=start, yellow=end)
        - Yellow labels show cumulative length every other dot, linearly
          interpolated within each branch
        - Lime green label shows total length at the final node
        - Node IDs are displayed in white text boxes offset from nodes

    Examples:
        >>> skeleton = structures['labeled_skeleton'] == root_label
        >>> branch_data = structures['roots'][root_label]['branch_data']
        >>> path = find_farthest_endpoint_path(branch_data, top_node, direction='down')
        >>> visualize_single_root_structure(
        ...     skeleton, branch_data,
        ...     title='Root 35 Structure',
        ...     trunk_path=path,
        ...     edge_dot_spacing=100
        ... )
    """
    # scipy repeats dilation until the result stops changing when
    # iterations < 1, which would fill the whole image; only dilate for a
    # positive count.
    if dilate_iterations > 0:
        display_mask = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_mask = skeleton
    skeleton_uint8 = (display_mask * 255).astype(np.uint8)

    # Map each node id to its (row, col) image coordinate.
    node_coords = {}
    for _, row in branch_data.iterrows():
        node_coords[row['node-id-src']] = (row['image-coord-src-0'],
                                           row['image-coord-src-1'])
        node_coords[row['node-id-dst']] = (row['image-coord-dst-0'],
                                           row['image-coord-dst-1'])

    path_nodes = [int(step[0]) for step in trunk_path] if trunk_path else []

    # Bounding box from node coordinates (only used when zooming).
    if zoom_to_content and node_coords:
        node_rows = [r for r, _ in node_coords.values()]
        node_cols = [c for _, c in node_coords.values()]
        row_min, row_max = min(node_rows), max(node_rows)
        col_min, col_max = min(node_cols), max(node_cols)
    else:
        zoom_to_content = False

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(skeleton_uint8, cmap='gray')

    total_length = 0.0
    if trunk_path:
        segments, breakpoints = _trace_trunk_segments(skeleton, branch_data,
                                                      trunk_path)
        for pixels in segments:
            ax.plot(pixels[:, 1], pixels[:, 0], '-',
                    color='blue', linewidth=path_edge_width,
                    alpha=0.8, zorder=2)
        if breakpoints:
            total_length = breakpoints[-1][1]

        if show_edge_dots and segments:
            combined_path = np.vstack(segments)
            sampled_indices = np.arange(0, len(combined_path),
                                        max(edge_dot_spacing, 50))
            sampled_pixels = combined_path[sampled_indices]
            ax.plot(sampled_pixels[:, 1], sampled_pixels[:, 0], 'o',
                    color='cyan', markersize=4, alpha=0.9, zorder=2.5,
                    markeredgecolor='blue', markeredgewidth=0.5)

            if show_cumulative_length:
                # Piecewise-linear map from path-pixel index to length.
                seg_starts = [idx for idx, _ in breakpoints]
                seg_lengths = [length for _, length in breakpoints]
                # Label every other dot to reduce clutter.
                for sample_idx in sampled_indices[::2]:
                    cum_length = np.interp(sample_idx, seg_starts, seg_lengths)
                    row, col = combined_path[sample_idx]
                    ax.text(col + 15, row, f'{cum_length:.0f}px',
                            color='yellow', fontsize=8, fontweight='bold',
                            ha='left', va='center',
                            bbox=dict(boxstyle='round,pad=0.2',
                                      facecolor='black', alpha=0.6),
                            zorder=4)

    # Draw trunk nodes with a red->yellow gradient plus id labels.
    last_index = len(path_nodes) - 1
    for i, node_id in enumerate(path_nodes):
        if node_id not in node_coords:
            continue
        row, col = node_coords[node_id]
        color = plt.cm.autumn(i / max(last_index, 1))
        ax.plot(col, row, 'o', color=color, markersize=path_node_size,
                zorder=3, markeredgecolor='white', markeredgewidth=1.5)
        ax.text(col + 20, row - 5, str(node_id),
                color='white', fontsize=9, fontweight='bold',
                ha='left', va='center',
                bbox=dict(boxstyle='round,pad=0.3',
                          facecolor='black', alpha=0.7),
                zorder=4)

        if i == last_index and show_cumulative_length:
            ax.text(col + 15, row + 20, f'{total_length:.1f}px',
                    color='lime', fontsize=10, fontweight='bold',
                    ha='left', va='top',
                    bbox=dict(boxstyle='round,pad=0.3',
                              facecolor='black', alpha=0.8),
                    zorder=5)

    if zoom_to_content:
        ax.set_xlim(col_min - padding, col_max + padding)
        # Inverted y-limits keep image orientation (row 0 at the top).
        ax.set_ylim(row_max + padding, row_min - padding)

    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    if output_path:
        fig.savefig(str(output_path))
    plt.show()

Visualize a single root structure with trunk path and cumulative length measurements.

Displays a root skeleton with the main trunk path highlighted in blue, showing nodes along the path with gradient coloring (red to yellow) and optional cumulative length labels at regular intervals. The path follows the actual skeleton pixels rather than straight-line connections between nodes.

Arguments:
  • skeleton (np.ndarray): Binary skeleton array (H, W) for a single root structure.
  • branch_data (pd.DataFrame): DataFrame containing branch information with columns: 'node-id-src', 'node-id-dst', 'image-coord-src-0', 'image-coord-src-1', 'image-coord-dst-0', 'image-coord-dst-1', 'branch-distance'.
  • title (str, optional): Title for the visualization. Defaults to 'Root Structure'.
  • dilate_iterations (int, optional): Number of morphological dilation iterations to thicken skeleton for visibility. Defaults to 2.
  • zoom_to_content (bool, optional): Whether to zoom to the bounding box of the root structure. Defaults to True.
  • padding (int, optional): Padding in pixels around the zoomed content. Defaults to 50.
  • trunk_path (list, optional): List of tuples from find_farthest_endpoint_path: [(node_id, branch_idx, distance, vertical), ...]. If None, only skeleton is shown. Defaults to None.
  • path_edge_width (float, optional): Line width for trunk path edges. Defaults to 3.
  • path_node_size (float, optional): Marker size for nodes along trunk path. Defaults to 10.
  • show_edge_dots (bool, optional): Whether to show cyan dots along the trunk path. Defaults to True.
  • edge_dot_spacing (int, optional): Spacing between edge dots, counted in pixels along the traced path. Minimum spacing is 50 pixels to prevent overcrowding. Defaults to 50.
  • show_cumulative_length (bool, optional): Whether to display yellow labels showing cumulative length at regular intervals along the path. Defaults to True.
  • output_path (str or Path): path to save image
Returns:

None: Displays matplotlib figure showing the root structure visualization.

Notes:
  • Trunk path is drawn in blue following actual skeleton pixels
  • Cyan dots mark regular intervals along the path
  • Nodes are colored with autumn colormap (red=start, yellow=end)
  • Yellow labels show cumulative length every other dot
  • Lime green label shows total length at the final node
  • Node IDs are displayed in white text boxes offset from nodes
Examples:
>>> skeleton = structures['labeled_skeleton'] == root_label
>>> branch_data = structures['roots'][root_label]['branch_data']
>>> path = find_farthest_endpoint_path(branch_data, top_node, direction='down')
>>> visualize_single_root_structure(
...     skeleton, branch_data, 
...     title='Root 35 Structure',
...     trunk_path=path,
...     edge_dot_spacing=100
... )
def visualize_root_lengths( structures, top_node_results, labeled_shoots, zoom_to_content=True, padding=50, show_detailed_roots=False, detail_viz_kwargs=None):
def visualize_root_lengths(structures, top_node_results, labeled_shoots,
                           zoom_to_content=True, padding=50,
                           show_detailed_roots=False, detail_viz_kwargs=None,
                           root_thickness=3):
    """Visualize matched roots with their calculated lengths displayed.

    Creates an overview visualization showing all matched root structures color-coded
    by their associated shoot regions, with length measurements labeled. Optionally
    displays detailed individual root visualizations below the overview.

    Args:
        structures (dict): Dictionary from extract_root_structures containing:
            - 'labeled_skeleton': Array with labeled skeleton structures
            - 'roots': Dictionary of root data by label
        top_node_results (dict): Dictionary from find_top_nodes_from_shoot with keys:
            - root_label: Dict containing 'branch_data', 'shoot_label', 'top_nodes'
        labeled_shoots (np.ndarray): Labeled array from label_shoot_regions where each
            shoot region has a unique integer label (0 = background).
        zoom_to_content (bool, optional): If True, zoom to bounding box of all content.
            Defaults to True.
        padding (int, optional): Pixels to add around bounding box when zooming.
            Defaults to 50.
        show_detailed_roots (bool, optional): If True, display detailed visualization
            for each individual root below the overview. Defaults to False.
        detail_viz_kwargs (dict, optional): Keyword arguments to pass to
            visualize_single_root_structure for detailed views. Common options:
            - 'dilate_iterations': int, skeleton thickness (default: 2)
            - 'path_edge_width': float, trunk path line width (default: 3)
            - 'path_node_size': float, node marker size (default: 10)
            - 'show_edge_dots': bool, show dots along path (default: True)
            - 'edge_dot_spacing': int, spacing between dots (default: 50)
            - 'show_cumulative_length': bool, show length labels (default: True)
            If None, uses default values. Defaults to None.
        root_thickness (int, optional): Dilation iterations (4-connected
            structuring element) used to thicken roots in the overview;
            0 or less draws the raw skeleton. Defaults to 3.

    Returns:
        None: Displays matplotlib figures showing root length visualizations.

    Notes:
        - Overview uses color-coding: each root matches its shoot color
        - Roots are thickened with binary dilation for visibility
        - Length labels appear at root centroids as white text on black boxes
        - Detailed views are indexed (idx=0, idx=1, etc.) in display order
        - Roots whose measurement fails (including a missing or empty
          'top_nodes') appear in gray with no length label, and a warning
          is emitted

    Examples:
        >>> # Basic usage - overview only
        >>> visualize_root_lengths(structures, top_node_results, labeled_shoots)

        >>> # With detailed individual visualizations
        >>> visualize_root_lengths(
        ...     structures, top_node_results, labeled_shoots,
        ...     show_detailed_roots=True,
        ...     detail_viz_kwargs={
        ...         'edge_dot_spacing': 100,
        ...         'path_edge_width': 2,
        ...         'show_cumulative_length': True
        ...     }
        ... )
    """
    import warnings

    if detail_viz_kwargs is None:
        detail_viz_kwargs = {}

    labeled_skeleton = structures['labeled_skeleton']
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Distinct colors for shoots; shoot labels beyond 8 wrap around.
    palette = np.array([
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [255, 128, 0],    # Orange
        [128, 0, 255],    # Purple
    ], dtype=np.uint8)
    failed_color = np.array([128, 128, 128], dtype=np.uint8)

    # Paint every shoot in one vectorized pass (label 0 is background).
    shoot_pixels = labeled_shoots > 0
    shoot_ids = labeled_shoots[shoot_pixels].astype(np.int64)
    vis_img[shoot_pixels] = palette[(shoot_ids - 1) % len(palette)]

    structure = ndimage.generate_binary_structure(2, 1)
    root_info = []      # (x, y, root_label, shoot_label, length) per label
    root_details = []   # inputs for the optional per-root detail figures

    for root_label, result in top_node_results.items():
        branch_data = result['branch_data']
        shoot_label = result['shoot_label']

        raw_skeleton = labeled_skeleton == root_label
        if root_thickness > 0:
            root_pixels = ndimage.binary_dilation(
                raw_skeleton,
                structure=structure,
                iterations=root_thickness,
            )
        else:
            root_pixels = raw_skeleton

        try:
            # Inside the try so a root without top nodes is drawn gray
            # instead of aborting the whole overview.
            top_node = result['top_nodes'][0][0]
            path = find_farthest_endpoint_path(
                branch_data, top_node,
                direction='down', use_smart_scoring=True,
                verbose=False
            )
            root_length = calculate_skeleton_length_px(path)
        except Exception as exc:
            # Best effort: keep going, but report the failure instead of
            # swallowing it silently.
            warnings.warn(
                f"Root {root_label}: length measurement failed ({exc!r}); "
                "drawing it in gray.",
                stacklevel=2,
            )
            vis_img[root_pixels] = failed_color
            continue

        vis_img[root_pixels] = palette[(shoot_label - 1) % len(palette)]

        # Label at the centroid of the drawn (thickened) root.
        ys, xs = np.nonzero(root_pixels)
        if ys.size:
            root_info.append((xs.mean(), ys.mean(), root_label,
                              shoot_label, root_length))

        if show_detailed_roots:
            root_details.append({
                'root_label': root_label,
                'skeleton': raw_skeleton,
                'branch_data': branch_data,
                'path': path,
                'shoot_label': shoot_label
            })

    if zoom_to_content:
        rows, cols = np.nonzero(vis_img.any(axis=2))
        if rows.size:
            # Slice ends are exclusive: +1 keeps the last content row/column.
            y_min = max(0, rows.min() - padding)
            y_max = min(h, rows.max() + 1 + padding)
            x_min = max(0, cols.min() - padding)
            x_max = min(w, cols.max() + 1 + padding)
            vis_img = vis_img[y_min:y_max, x_min:x_max]
            # Shift label positions into the cropped frame.
            root_info = [(x - x_min, y - y_min, rl, sl, length)
                         for x, y, rl, sl, length in root_info]

    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)

    for x, y, _root_label, _shoot_label, length in root_info:
        ax.text(x, y, f'{length:.1f}px',
                color='white', fontsize=10, weight='bold',
                ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

    ax.set_title('Root Lengths (pixels)', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    if show_detailed_roots and root_details:
        print(f"\n{'='*60}")
        print(f"Detailed Root Visualizations ({len(root_details)} roots)")
        print(f"{'='*60}\n")

        for idx, detail in enumerate(root_details):
            print(f"Displaying Root {detail['root_label']} (idx={idx})...")

            visualize_single_root_structure(
                skeleton=detail['skeleton'],
                branch_data=detail['branch_data'],
                title=f"Root {detail['root_label']} (idx={idx}) - Shoot {detail['shoot_label']}",
                trunk_path=detail['path'],
                **detail_viz_kwargs
            )

Visualize matched roots with their calculated lengths displayed.

Creates an overview visualization showing all matched root structures color-coded by their associated shoot regions, with length measurements labeled. Optionally displays detailed individual root visualizations below the overview.

Arguments:
  • structures (dict): Dictionary from extract_root_structures containing:
    • 'labeled_skeleton': Array with labeled skeleton structures
    • 'roots': Dictionary of root data by label
  • top_node_results (dict): Dictionary from find_top_nodes_from_shoot with keys:
    • root_label: Dict containing 'branch_data', 'shoot_label', 'top_nodes'
  • labeled_shoots (np.ndarray): Labeled array from label_shoot_regions where each shoot region has a unique integer label.
  • zoom_to_content (bool, optional): If True, zoom to bounding box of all content. Defaults to True.
  • padding (int, optional): Pixels to add around bounding box when zooming. Defaults to 50.
  • show_detailed_roots (bool, optional): If True, display detailed visualization for each individual root below the overview. Defaults to False.
  • detail_viz_kwargs (dict, optional): Keyword arguments to pass to visualize_single_root_structure for detailed views. Common options:
    • 'dilate_iterations': int, skeleton thickness (default: 2)
    • 'path_edge_width': float, trunk path line width (default: 3)
    • 'path_node_size': float, node marker size (default: 10)
    • 'show_edge_dots': bool, show dots along path (default: True)
    • 'edge_dot_spacing': int, spacing between dots (default: 50)
    • 'show_cumulative_length': bool, show length labels (default: True).
  If detail_viz_kwargs is None, default values are used. Defaults to None.
Returns:

None: Displays matplotlib figures showing root length visualizations.

Notes:
  • Overview uses color-coding: each root matches its shoot color
  • Roots are thickened with binary dilation for visibility
  • Length labels appear at root centroids as white text on black boxes
  • Detailed views are indexed (idx=0, idx=1, etc.) in display order
  • Failed root measurements appear in gray with no length label
Examples:
>>> # Basic usage - overview only
>>> visualize_root_lengths(structures, top_node_results, labeled_shoots)
>>> # With detailed individual visualizations
>>> visualize_root_lengths(
...     structures, top_node_results, labeled_shoots,
...     show_detailed_roots=True,
...     detail_viz_kwargs={
...         'edge_dot_spacing': 100,
...         'path_edge_width': 2,
...         'show_cumulative_length': True
...     }
... )