library.root_analysis_visualization
import warnings

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
from scipy.ndimage import binary_dilation

# Expected to supply the analysis helpers used below but not defined here
# (Skeleton, summarize, label_skeleton, find_farthest_endpoint_path,
# calculate_skeleton_length_px and, presumably, cv2 -- TODO confirm).
# Kept last, as in the original, so its names take precedence.
from library.root_analysis import *


# Candidate (col_offset, row_offset) label positions, tried in order.
_NODE_LABEL_OFFSETS = ((40, 0), (-40, 0), (0, 40), (0, -40),
                       (30, 30), (-30, -30), (30, -30), (-30, 30))
_BRANCH_LABEL_OFFSETS = ((60, 0), (-60, 0), (0, 60), (0, -60), (45, 45), (-45, -45),
                         (120, 0), (-120, 0), (0, 120), (0, -120))

# RGB palette indexed by (shoot_label - 1), wrapping around when labels exceed it.
_SHOOT_COLORS = (
    (255, 0, 0),    # Red
    (0, 255, 0),    # Green
    (0, 0, 255),    # Blue
    (255, 255, 0),  # Yellow
    (255, 0, 255),  # Magenta
    (0, 255, 255),  # Cyan
    (255, 128, 0),  # Orange
    (128, 0, 255),  # Purple
)
_UNMATCHED_GRAY = (128, 128, 128)

# RGB palette for visualize_assignments, indexed by shoot order (wraps around).
_ASSIGNMENT_COLORS = (
    (255, 0, 0),    # Red
    (255, 165, 0),  # Orange
    (255, 255, 0),  # Yellow
    (0, 255, 255),  # Cyan
    (255, 0, 255),  # Magenta
)


def _skeleton_display_image(skeleton, dilate_iterations):
    """Return a uint8 (0/255) image of *skeleton*, optionally thickened for display.

    ``binary_dilation`` interprets ``iterations < 1`` as "repeat until the result
    stops changing", which would flood the whole image, so dilation is applied
    only when ``dilate_iterations > 0``.
    """
    mask = np.asarray(skeleton) != 0
    if dilate_iterations > 0:
        mask = binary_dilation(mask, iterations=dilate_iterations)
    return mask.astype(np.uint8) * 255


def _node_coords(branch_data, coord_prefix):
    """Map each node id in *branch_data* to its ``(row, col)`` coordinate.

    Reads ``node-id-src``/``node-id-dst`` and the matching
    ``{coord_prefix}-src-0/1`` and ``{coord_prefix}-dst-0/1`` columns; a node
    seen in several rows keeps the coordinate from the last row that mentions it.
    """
    coords = {}
    for _, row in branch_data.iterrows():
        for end in ('src', 'dst'):
            coords[row[f'node-id-{end}']] = (row[f'{coord_prefix}-{end}-0'],
                                              row[f'{coord_prefix}-{end}-1'])
    return coords


def _pick_label_offset(anchor_row, anchor_col, offsets, occupied_regions):
    """Return the ``(col_offset, row_offset)`` whose label spot overlaps least.

    A candidate's overlap is the largest ``radius - distance`` over all occupied
    ``(row, col, radius)`` circles that contain it (0 if none). The first
    collision-free candidate wins; otherwise the least-overlapping one.
    """
    best_offset = offsets[0]
    min_collision = float('inf')
    for offset_col, offset_row in offsets:
        test_row = anchor_row + offset_row
        test_col = anchor_col + offset_col
        max_collision = 0
        for occ_row, occ_col, radius in occupied_regions:
            dist = np.hypot(test_row - occ_row, test_col - occ_col)
            if dist < radius:
                max_collision = max(max_collision, radius - dist)
        if max_collision < min_collision:
            min_collision = max_collision
            best_offset = (offset_col, offset_row)
            if max_collision == 0:
                break
    return best_offset


def _set_zoom(ax, row_min, row_max, col_min, col_max, padding):
    """Limit *ax* to the padded window, keeping the row axis inverted (image orientation)."""
    ax.set_xlim(col_min - padding, col_max + padding)
    ax.set_ylim(row_max + padding, row_min - padding)


def _shoot_color(shoot_label):
    """Palette colour for a 1-based shoot label (wraps around the palette)."""
    return _SHOOT_COLORS[(shoot_label - 1) % len(_SHOOT_COLORS)]


def _draw_shoots(rgb_img, labeled_shoots):
    """Paint every labelled shoot region (labels 1..max) in its palette colour, in place."""
    for shoot_label in range(1, int(labeled_shoots.max()) + 1):
        rgb_img[labeled_shoots == shoot_label] = _shoot_color(shoot_label)


def _crop_to_content(rgb_img, padding):
    """Crop an RGB image to the bounding box of its non-black pixels plus *padding*.

    Returns ``(cropped, y_offset, x_offset)``. The image is returned unchanged
    (offsets 0) when it has no non-black pixels.
    """
    h, w = rgb_img.shape[:2]
    rows, cols = np.nonzero(np.any(rgb_img > 0, axis=2))
    if len(rows) == 0:
        return rgb_img, 0, 0
    y_min = max(0, rows.min() - padding)
    y_max = min(h, rows.max() + padding + 1)  # +1: slice end is exclusive
    x_min = max(0, cols.min() - padding)
    x_max = min(w, cols.max() + padding + 1)
    return rgb_img[y_min:y_max, x_min:x_max], y_min, x_min


def visualize_skeleton_with_nodes(skeleton, show_branch_labels=False, row_range=None,
                                  padding=50, dilate_iterations=3):
    """Visualize a skeleton with its graph nodes and, optionally, branch labels.

    Nodes come from ``summarize(Skeleton(skeleton))``. Labels are placed with a
    simple collision-avoidance scheme: each label tries a fixed list of offsets,
    takes the first one outside every already-occupied circle (or the
    least-overlapping one) and is tied to its anchor by a dashed leader line.

    Args:
        skeleton (np.ndarray): 2-D skeleton image; nonzero pixels are skeleton.
        show_branch_labels (bool): Also label each branch as ``B<index>`` near
            the midpoint between its two endpoint nodes.
        row_range (tuple, optional): ``(row_min, row_max)`` rows to zoom in on.
            The column window is the extent of skeleton pixels inside those rows,
            or of the whole skeleton if none fall inside.
        padding (int): Extra pixels added around the zoom window.
        dilate_iterations (int): Dilation rounds used to thicken the skeleton
            for display; ``0`` draws it unthickened.

    Returns:
        pandas.DataFrame: The branch table returned by ``summarize``.

    Raises:
        ValueError: If ``skeleton`` contains no nonzero pixels.
    """
    skeleton_img = _skeleton_display_image(skeleton, dilate_iterations)

    rows, cols = np.nonzero(skeleton)
    if rows.size == 0:
        raise ValueError("skeleton contains no foreground pixels")

    if row_range is not None:
        row_min, row_max = row_range
        in_rows = (rows >= row_min) & (rows <= row_max)
        # Fall back to the full column extent when no pixel lies in the requested rows.
        window_cols = cols[in_rows] if np.any(in_rows) else cols
        col_min, col_max = window_cols.min(), window_cols.max()
    else:
        row_min, row_max = rows.min(), rows.max()
        col_min, col_max = cols.min(), cols.max()

    branch_data = summarize(Skeleton(skeleton))
    node_coords = _node_coords(branch_data, 'coord')

    def in_view(r, c):
        return (row_min - padding <= r <= row_max + padding
                and col_min - padding <= c <= col_max + padding)

    _, ax = plt.subplots(figsize=(14, 12))
    ax.imshow(skeleton_img, cmap='gray')

    # (row, col, radius) circles already covered by markers or labels.
    occupied_regions = []

    # First pass: draw all visible node markers so every label can avoid them.
    visible_nodes = []
    for node_id, (node_row, node_col) in node_coords.items():
        if in_view(node_row, node_col):
            visible_nodes.append((node_id, node_row, node_col))
            ax.plot(node_col, node_row, 'ro', markersize=12, zorder=3)
            occupied_regions.append((node_row, node_col, 15))  # the dot itself

    # Second pass: node-id labels.
    for node_id, node_row, node_col in visible_nodes:
        d_col, d_row = _pick_label_offset(node_row, node_col,
                                          _NODE_LABEL_OFFSETS, occupied_regions)
        label_row, label_col = node_row + d_row, node_col + d_col
        ax.plot([node_col, label_col], [node_row, label_row],
                'y--', linewidth=1, alpha=0.6, zorder=2)
        ax.text(label_col, label_row, str(int(node_id)), color='yellow', fontsize=11,
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
                ha='center', zorder=4)
        occupied_regions.append((label_row, label_col, 45))

    if show_branch_labels:
        for idx, row in branch_data.iterrows():
            mid_row = (row['coord-src-0'] + row['coord-dst-0']) / 2
            mid_col = (row['coord-src-1'] + row['coord-dst-1']) / 2
            if not in_view(mid_row, mid_col):
                continue
            d_col, d_row = _pick_label_offset(mid_row, mid_col,
                                              _BRANCH_LABEL_OFFSETS, occupied_regions)
            label_row, label_col = mid_row + d_row, mid_col + d_col
            ax.plot([mid_col, label_col], [mid_row, label_row],
                    'c--', linewidth=1, alpha=0.5, zorder=1)
            ax.text(label_col, label_row, f"B{idx}", color='cyan', fontsize=9,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7),
                    ha='center', zorder=2)
            occupied_regions.append((label_row, label_col, 40))

    _set_zoom(ax, row_min, row_max, col_min, col_max, padding)

    title = f'Skeleton with {len(node_coords)} nodes, {branch_data.shape[0]} branches'
    if row_range is not None:
        title += f' (rows {row_range[0]}-{row_range[1]})'
    ax.set_title(title)

    plt.show()

    return branch_data


def visualize_trunk_path(skeleton, branch_data, trunk_path, title='Trunk Path',
                         dilate_iterations=3, zoom_to_path=True, padding=100):
    """Visualize the trunk path overlaid on the skeleton image.

    Args:
        skeleton: Binary skeleton array.
        branch_data: DataFrame with branch information (``node-id-*``,
            ``coord-*`` and ``branch-distance`` columns are read).
        trunk_path: List of tuples from trace_vertical_path
            ``[(node, branch_idx, slope, vertical), ...]``; ``branch_idx`` may be
            None and is otherwise used positionally via ``branch_data.iloc``.
        title: Plot title.
        dilate_iterations: Dilation rounds to thicken the skeleton (0 = none).
        zoom_to_path: Whether to zoom to the trunk path region. Ignored when no
            path node has known coordinates.
        padding: Padding around the zoom region.

    Returns:
        Sum of ``branch-distance`` over the path's branches (0 if there are none).
    """
    skeleton_img = _skeleton_display_image(skeleton, dilate_iterations)

    path_nodes = [int(node) for node, _, _, _ in trunk_path]
    path_branches = [idx for _, idx, _, _ in trunk_path if idx is not None]

    all_node_coords = _node_coords(branch_data, 'coord')

    if zoom_to_path:
        path_coords = [all_node_coords[n] for n in path_nodes if n in all_node_coords]
        if path_coords:
            path_rows, path_cols = zip(*path_coords)
            row_min, row_max = min(path_rows), max(path_rows)
            col_min, col_max = min(path_cols), max(path_cols)
        else:
            # Nothing to zoom on (empty path or unknown nodes): show the full image.
            zoom_to_path = False

    _, ax = plt.subplots(figsize=(14, 14))
    ax.imshow(skeleton_img, cmap='gray')

    # All nodes as faint context markers.
    for row, col in all_node_coords.values():
        ax.plot(col, row, 'o', color='gray', markersize=4, alpha=0.3)

    # Straight connections between consecutive path nodes, behind the node markers.
    for node1, node2 in zip(path_nodes, path_nodes[1:]):
        if node1 in all_node_coords and node2 in all_node_coords:
            r1, c1 = all_node_coords[node1]
            r2, c2 = all_node_coords[node2]
            ax.plot([c1, c2], [r1, r2], 'r-', linewidth=4, alpha=0.8, zorder=2)

    # Path nodes, coloured dark red (start) -> yellow (end).
    gradient_span = max(len(path_nodes) - 1, 1)
    for i, node_id in enumerate(path_nodes):
        if node_id not in all_node_coords:
            continue
        row, col = all_node_coords[node_id]
        color = plt.cm.autumn(i / gradient_span)
        ax.plot(col, row, 'o', color=color, markersize=16, zorder=3,
                markeredgecolor='white', markeredgewidth=2)
        ax.text(col + 15, row, str(node_id), color='cyan', fontsize=12,
                fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.8),
                zorder=4)

    if zoom_to_path:
        _set_zoom(ax, row_min, row_max, col_min, col_max, padding)

    path_str = " → ".join(map(str, path_nodes[:10]))  # first 10 nodes only
    if len(path_nodes) > 10:
        path_str += "..."
    ax.set_title(f'{title}\nPath: {path_str}', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    print(f"Trunk path: {len(path_nodes)} nodes, {len(path_branches)} branches")
    if path_nodes:
        print(f"Start node: {path_nodes[0]}, End node: {path_nodes[-1]}")

    total_length = sum(branch_data.iloc[idx]['branch-distance']
                       for idx in path_branches)
    print(f"Total trunk length: {total_length:.2f} pixels")

    return total_length


def visualize_thickened_skeleton(binary_mask, dilate_iterations=3, figsize=(12, 8), show_labels=True):
    """Skeletonize, label, thicken, and visualize a binary mask.

    Args:
        binary_mask: Binary numpy array (bool or 0/255) or path to an image file
            (read as grayscale).
        dilate_iterations: Dilation rounds used to thicken the skeleton (0 = none).
        figsize: Figure size tuple.
        show_labels: Whether to draw each component's label below its lowest pixel.

    Returns:
        tuple: ``(skeleton, labeled_skeleton, unique_labels)`` from ``label_skeleton``.

    Raises:
        FileNotFoundError: If a path is given and the image cannot be read.
    """
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        # NOTE(review): cv2 is not imported here; presumably re-exported by
        # `from library.root_analysis import *` -- confirm.
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if binary_mask is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {image_path}")

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    thick_skeleton_img = _skeleton_display_image(skeleton, dilate_iterations)

    _, ax = plt.subplots(figsize=figsize)
    ax.imshow(thick_skeleton_img, cmap='gray')
    ax.set_title(f'{len(unique_labels)} skeleton(s) (thickened for visibility)')
    ax.axis('off')

    if show_labels and len(unique_labels) > 0:
        for label_id in unique_labels:
            rows, cols = np.nonzero(labeled_skeleton == label_id)
            if len(rows) == 0:
                continue
            # Horizontally centred on the component, 50 px below its bottommost pixel.
            label_col = int(np.mean(cols))
            label_row = rows.max() + 50
            ax.text(label_col, label_row, str(int(label_id)),
                    color='yellow', fontsize=14, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='red',
                              edgecolor='white', linewidth=2, alpha=0.8),
                    ha='center', va='center')

    plt.show()

    print(f"Found {len(unique_labels)} connected structure(s)")
    print(f"Total skeleton pixels: {np.sum(skeleton)}")

    return skeleton, labeled_skeleton, unique_labels


def visualize_assignments(shoot_mask, root_mask, structures, assignments, ref_points, size=(16, 8)):
    """Visualize shoot-root assignments with ordered shoot labels and root labels.

    Each assigned shoot and its root are painted in the same colour; the 1-based
    shoot order number is drawn 300 px below the root's bottommost pixel, or at
    the shoot centroid when the shoot has no root.

    Args:
        shoot_mask: Binary mask of shoot regions (relabelled with ``ndimage.label``).
        root_mask: Binary mask of root regions (unused; kept for API compatibility).
        structures: Dictionary from extract_root_structures containing 'roots'
            with per-root data (each entry's 'mask' is read).
        assignments: Dictionary from assign_roots_to_shoots_greedy mapping shoot
            label -> dict with 'order', 'root_label' and, when a root is
            assigned, 'score_dict'['score'].
        ref_points: Dictionary from get_shoot_reference_points_multi with shoot
            centroids (used for shoots without a root).
        size: Figure size tuple (width, height).

    Returns:
        None: Displays the visualization.
    """
    labeled_shoots, _ = ndimage.label(shoot_mask)
    h, w = shoot_mask.shape
    rgb_img = np.zeros((h, w, 3), dtype=np.uint8)

    legend_elements = []
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']
        # Wrap around so more shoots than palette colours do not raise IndexError.
        color = _ASSIGNMENT_COLORS[order % len(_ASSIGNMENT_COLORS)]

        rgb_img[labeled_shoots == shoot_label] = color

        if root_label is not None:
            assigned_root = structures['roots'][root_label]['mask']
            rgb_img[ndimage.binary_dilation(assigned_root, iterations=2)] = color
            score = info['score_dict']['score']
            legend_label = f"S{order+1}->R{root_label} ({score:.2f})"
        else:
            legend_label = f"S{order+1}->None"
        legend_elements.append(
            mpatches.Patch(color=np.array(color) / 255, label=legend_label)
        )

    _, ax = plt.subplots(figsize=size)
    ax.imshow(rgb_img)

    circle_box = dict(boxstyle='circle', facecolor='black', alpha=0.7)
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']

        if root_label is not None:
            # Root tip = bottommost pixel (largest row index).
            y_coords, x_coords = np.nonzero(structures['roots'][root_label]['mask'])
            if len(y_coords) > 0:
                tip_idx = np.argmax(y_coords)
                ax.text(x_coords[tip_idx], y_coords[tip_idx] + 300, str(order + 1),
                        color='white', fontsize=20, fontweight='bold',
                        ha='center', va='center', bbox=circle_box)
        else:
            # Ungerminated shoot: label at the shoot centroid.
            cy, cx = ref_points[shoot_label]['centroid']
            ax.text(cx, cy, str(order + 1),
                    color='white', fontsize=20, fontweight='bold',
                    ha='center', va='center', bbox=circle_box)

    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
    ax.set_title('Shoot-Root Assignments (Shoot# = Left->Right Order)', fontsize=12)
    ax.axis('off')

    plt.tight_layout()
    plt.show()


def visualize_root_shoot_matches(structures, labeled_shoots, root_matches,
                                 zoom_to_content=True, padding=50,
                                 root_thickness=3):
    """Visualize root-to-shoot matches with color coding.

    Shoots are painted with the shoot palette; each root takes its matched
    shoot's colour, or gray when its entry in ``root_matches`` is missing or
    falsy. The printed counts use the same matched/unmatched rule.

    Args:
        structures: Dict from extract_root_structures ('labeled_skeleton' and
            'unique_labels' are read).
        labeled_shoots: Labeled shoot array.
        root_matches: Dict from match_roots_to_shoots, root label -> match dict
            (with 'shoot_label') or None.
        zoom_to_content: If True, crop to the bounding box of all painted pixels.
        padding: Pixels to add around the bounding box.
        root_thickness: Dilation rounds applied to root skeletons (0 = none).
    """
    labeled_skeleton = structures['labeled_skeleton']
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    _draw_shoots(vis_img, labeled_shoots)

    cross = ndimage.generate_binary_structure(2, 1)
    for root_label in structures['unique_labels']:
        root_pixels = labeled_skeleton == root_label
        if root_thickness > 0:
            root_pixels = ndimage.binary_dilation(root_pixels, structure=cross,
                                                  iterations=root_thickness)
        match = root_matches.get(root_label)
        vis_img[root_pixels] = _shoot_color(match['shoot_label']) if match else _UNMATCHED_GRAY

    if zoom_to_content:
        vis_img, _, _ = _crop_to_content(vis_img, padding)

    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)
    ax.set_title('Root-Shoot Matches (same color = matched)', fontsize=14)
    ax.axis('off')

    matched = [m for m in root_matches.values() if m]
    matched_count = len(matched)
    unmatched_count = len(root_matches) - matched_count

    legend_elements = [
        mpatches.Patch(color=np.array(_shoot_color(shoot_label)) / 255.0,
                       label=f'Shoot/Root {shoot_label}')
        for shoot_label in sorted({m['shoot_label'] for m in matched})
    ]
    if unmatched_count > 0:
        legend_elements.append(
            mpatches.Patch(color=[0.5, 0.5, 0.5],
                           label=f'Unmatched ({unmatched_count})')
        )
    ax.legend(handles=legend_elements, loc='lower right', fontsize=10)

    plt.tight_layout()
    plt.show()

    print("\nMatch Summary:")
    print(f"Total root structures: {len(structures['unique_labels'])}")
    print(f"Matched: {matched_count}")
    print(f"Unmatched (noise): {unmatched_count}")


def get_skeleton_path_from_branch(skeleton, branch_row):
    """Trace the pixels connecting a branch's two endpoints along the skeleton.

    Uses a minimum-cost route (scikit-image ``route_through_array``, 8-connected)
    over a cost image that is cheap on skeleton pixels and very expensive
    elsewhere, so the route follows the skeleton wherever it can.

    Args:
        skeleton: Binary skeleton array.
        branch_row: Row from branch_data DataFrame (the ``image-coord-src-0/1``
            and ``image-coord-dst-0/1`` entries are read).

    Returns:
        np.ndarray of shape (N, 2) with ``(row, col)`` coordinates along the path,
        or just the two endpoints if routing fails.
    """
    from skimage.graph import route_through_array

    start_coord = (int(branch_row['image-coord-src-0']),
                   int(branch_row['image-coord-src-1']))
    end_coord = (int(branch_row['image-coord-dst-0']),
                 int(branch_row['image-coord-dst-1']))

    # Cost 1 on the skeleton, 10000 off it: off-skeleton steps only bridge gaps.
    cost = np.where(skeleton, 1, 10000)

    try:
        indices, _ = route_through_array(
            cost, start_coord, end_coord, fully_connected=True
        )
    except Exception:
        # Best effort: fall back to a straight segment between the endpoints
        # (no longer a bare `except:`, so Ctrl-C/SystemExit still propagate).
        return np.array([start_coord, end_coord])
    return np.array(indices)


def visualize_single_root_structure(skeleton, branch_data, title='Root Structure',
                                    dilate_iterations=2, zoom_to_content=True,
                                    padding=50, trunk_path=None,
                                    path_edge_width=3, path_node_size=10,
                                    show_edge_dots=True, edge_dot_spacing=50,
                                    show_cumulative_length=True,
                                    output_path=None):
    """Visualize a single root structure with trunk path and cumulative length measurements.

    Displays a root skeleton with the main trunk path highlighted in blue, showing nodes
    along the path with gradient coloring (red to yellow) and optional cumulative length
    labels at regular intervals. The path follows the actual skeleton pixels rather than
    straight-line connections between nodes.

    Args:
        skeleton (np.ndarray): Binary skeleton array (H, W) for a single root structure.
        branch_data (pd.DataFrame): DataFrame containing branch information with columns:
            'node-id-src', 'node-id-dst', 'image-coord-src-0', 'image-coord-src-1',
            'image-coord-dst-0', 'image-coord-dst-1'. Branch indices in ``trunk_path``
            are used positionally via ``branch_data.iloc``.
        title (str, optional): Title for the visualization. Defaults to 'Root Structure'.
        dilate_iterations (int, optional): Number of morphological dilation iterations
            to thicken skeleton for visibility (0 = none). Defaults to 2.
        zoom_to_content (bool, optional): Whether to zoom to the bounding box of the
            root's nodes. Defaults to True.
        padding (int, optional): Padding in pixels around the zoomed content.
            Defaults to 50.
        trunk_path (list, optional): List of tuples from find_farthest_endpoint_path:
            [(node_id, branch_idx, distance, vertical), ...]. Each non-None
            ``branch_idx`` contributes ``distance`` to the length. If None, only the
            skeleton is shown. Defaults to None.
        path_edge_width (float, optional): Line width for trunk path edges.
            Defaults to 3.
        path_node_size (float, optional): Marker size for nodes along trunk path.
            Defaults to 10.
        show_edge_dots (bool, optional): Whether to show cyan dots along the trunk path.
            Defaults to True.
        edge_dot_spacing (int, optional): Spacing in path pixels between edge dots.
            Minimum spacing is 50 to prevent overcrowding. Defaults to 50.
        show_cumulative_length (bool, optional): Whether to display yellow labels showing
            cumulative length at regular intervals along the path. Defaults to True.
        output_path (str or Path, optional): If given, the figure is saved there
            before being shown.

    Returns:
        None: Displays matplotlib figure showing the root structure visualization.

    Notes:
        - Trunk path is drawn in blue following actual skeleton pixels
        - Cyan dots mark regular intervals along the path
        - Nodes are colored with autumn colormap (red=start, yellow=end)
        - Yellow labels show cumulative length every other dot, linearly
          interpolated by pixel position within each branch
        - Lime green label shows total length at the final node
        - Node IDs are displayed in white text boxes offset from nodes

    Examples:
        >>> skeleton = structures['labeled_skeleton'] == root_label
        >>> branch_data = structures['roots'][root_label]['branch_data']
        >>> path = find_farthest_endpoint_path(branch_data, top_node, direction='down')
        >>> visualize_single_root_structure(
        ...     skeleton, branch_data,
        ...     title='Root 35 Structure',
        ...     trunk_path=path,
        ...     edge_dot_spacing=100
        ... )
    """
    skeleton_img = _skeleton_display_image(skeleton, dilate_iterations)
    all_node_coords = _node_coords(branch_data, 'image-coord')

    path_nodes = [int(node) for node, _, _, _ in trunk_path] if trunk_path else []

    if zoom_to_content and all_node_coords:
        node_rows, node_cols = zip(*all_node_coords.values())
        row_min, row_max = min(node_rows), max(node_rows)
        col_min, col_max = min(node_cols), max(node_cols)
    else:
        zoom_to_content = False

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(skeleton_img, cmap='gray')

    # Trace each trunk branch along the skeleton, recording where each segment
    # starts in the concatenated path and the cumulative length at that point.
    segments = []
    knot_pixel_idx = []
    knot_length = []
    n_pixels = 0
    total_length = 0.0
    for _, branch_idx, distance, _ in (trunk_path or []):
        if branch_idx is None:
            continue
        path_pixels = get_skeleton_path_from_branch(skeleton, branch_data.iloc[branch_idx])

        knot_pixel_idx.append(n_pixels)
        knot_length.append(total_length)
        segments.append(path_pixels)
        n_pixels += len(path_pixels)
        total_length += distance

        ax.plot(path_pixels[:, 1], path_pixels[:, 0], '-',
                color='blue', linewidth=path_edge_width,
                alpha=0.8, zorder=2)

    if show_edge_dots and segments:
        combined_path = np.vstack(segments)

        actual_spacing = max(edge_dot_spacing, 50)  # enforce minimum spacing
        sampled_indices = np.arange(0, len(combined_path), actual_spacing)

        sampled_pixels = combined_path[sampled_indices]
        ax.plot(sampled_pixels[:, 1], sampled_pixels[:, 0], 'o',
                color='cyan', markersize=4, alpha=0.9, zorder=2.5,
                markeredgecolor='blue', markeredgewidth=0.5)

        if show_cumulative_length:
            # Close the knot list with the path end, then interpolate lengths
            # linearly by pixel index within each segment.
            xp = knot_pixel_idx + [n_pixels]
            fp = knot_length + [total_length]
            for sample_idx in sampled_indices[::2]:  # every other dot, to reduce clutter
                cum_length = np.interp(sample_idx, xp, fp)
                row, col = combined_path[sample_idx]
                ax.text(col + 15, row, f'{cum_length:.0f}px',
                        color='yellow', fontsize=8, fontweight='bold',
                        ha='left', va='center',
                        bbox=dict(boxstyle='round,pad=0.2',
                                  facecolor='black', alpha=0.6),
                        zorder=4)

    gradient_span = max(len(path_nodes) - 1, 1)
    for i, node_id in enumerate(path_nodes):
        if node_id not in all_node_coords:
            continue
        row, col = all_node_coords[node_id]
        color = plt.cm.autumn(i / gradient_span)
        ax.plot(col, row, 'o', color=color, markersize=path_node_size,
                zorder=3, markeredgecolor='white', markeredgewidth=1.5)

        ax.text(col + 20, row - 5, str(node_id),
                color='white', fontsize=9, fontweight='bold',
                ha='left', va='center',
                bbox=dict(boxstyle='round,pad=0.3',
                          facecolor='black', alpha=0.7),
                zorder=4)

        if i == len(path_nodes) - 1 and show_cumulative_length:
            ax.text(col + 15, row + 20, f'{total_length:.1f}px',
                    color='lime', fontsize=10, fontweight='bold',
                    ha='left', va='top',
                    bbox=dict(boxstyle='round,pad=0.3',
                              facecolor='black', alpha=0.8),
                    zorder=5)

    if zoom_to_content:
        _set_zoom(ax, row_min, row_max, col_min, col_max, padding)

    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    if output_path:
        fig.savefig(str(output_path))
    plt.show()


def visualize_root_lengths(structures, top_node_results, labeled_shoots,
                           zoom_to_content=True, padding=50,
                           show_detailed_roots=False, detail_viz_kwargs=None):
    """Visualize matched roots with their calculated lengths displayed.

    Creates an overview visualization showing all matched root structures color-coded
    by their associated shoot regions, with length measurements labeled. Optionally
    displays detailed individual root visualizations below the overview.

    Args:
        structures (dict): Dictionary from extract_root_structures containing:
            - 'labeled_skeleton': Array with labeled skeleton structures
            - 'roots': Dictionary of root data by label
        top_node_results (dict): Dictionary from find_top_nodes_from_shoot with keys:
            - root_label: Dict containing 'branch_data', 'shoot_label', 'top_nodes'
        labeled_shoots (np.ndarray): Labeled array from label_shoot_regions where each
            shoot region has a unique integer label.
        zoom_to_content (bool, optional): If True, zoom to bounding box of all content.
            Defaults to True.
        padding (int, optional): Pixels to add around bounding box when zooming.
            Defaults to 50.
        show_detailed_roots (bool, optional): If True, display detailed visualization
            for each individual root below the overview. Defaults to False.
        detail_viz_kwargs (dict, optional): Keyword arguments to pass to
            visualize_single_root_structure for detailed views. Common options:
            - 'dilate_iterations': int, skeleton thickness (default: 2)
            - 'path_edge_width': float, trunk path line width (default: 3)
            - 'path_node_size': float, node marker size (default: 10)
            - 'show_edge_dots': bool, show dots along path (default: True)
            - 'edge_dot_spacing': int, spacing between dots (default: 50)
            - 'show_cumulative_length': bool, show length labels (default: True)
            If None, uses default values. Defaults to None.

    Returns:
        None: Displays matplotlib figures showing root length visualizations.

    Notes:
        - Overview uses color-coding: each root matches its shoot color
        - Roots are thickened with binary dilation for visibility
        - Length labels appear at root centroids in white text boxes
        - Detailed views are indexed (idx=0, idx=1, etc.) in display order
        - Roots whose measurement fails are drawn in gray with no length label,
          and a warning naming the root is emitted

    Examples:
        >>> # Basic usage - overview only
        >>> visualize_root_lengths(structures, top_node_results, labeled_shoots)

        >>> # With detailed individual visualizations
        >>> visualize_root_lengths(
        ...     structures, top_node_results, labeled_shoots,
        ...     show_detailed_roots=True,
        ...     detail_viz_kwargs={
        ...         'edge_dot_spacing': 100,
        ...         'path_edge_width': 2,
        ...         'show_cumulative_length': True
        ...     }
        ... )
    """
    if detail_viz_kwargs is None:
        detail_viz_kwargs = {}

    labeled_skeleton = structures['labeled_skeleton']
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    _draw_shoots(vis_img, labeled_shoots)

    root_info = []     # (x, y, root_label, shoot_label, length) for overview labels
    root_details = []  # data for the optional per-root figures

    cross = ndimage.generate_binary_structure(2, 1)
    for root_label, result in top_node_results.items():
        branch_data = result['branch_data']
        shoot_label = result['shoot_label']

        root_pixels = labeled_skeleton == root_label
        thick_root = ndimage.binary_dilation(root_pixels, structure=cross, iterations=3)

        try:
            top_node = result['top_nodes'][0][0]
            path = find_farthest_endpoint_path(
                branch_data, top_node,
                direction='down', use_smart_scoring=True,
                verbose=False
            )
            root_length = calculate_skeleton_length_px(path)
        except Exception as exc:
            # Best effort: keep drawing the overview, but don't hide the failure.
            warnings.warn(f"Root {root_label}: length measurement failed ({exc!r})")
            vis_img[thick_root] = _UNMATCHED_GRAY
            continue

        vis_img[thick_root] = _shoot_color(shoot_label)

        y_coords, x_coords = np.nonzero(thick_root)
        root_info.append((np.mean(x_coords), np.mean(y_coords),
                          root_label, shoot_label, root_length))

        if show_detailed_roots:
            root_details.append({
                'root_label': root_label,
                'skeleton': root_pixels,
                'branch_data': branch_data,
                'path': path,
                'shoot_label': shoot_label
            })

    if zoom_to_content:
        vis_img, y_off, x_off = _crop_to_content(vis_img, padding)
        # Shift label positions into the cropped image's frame.
        root_info = [(x - x_off, y - y_off, rl, sl, length)
                     for x, y, rl, sl, length in root_info]

    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)

    for x, y, _, _, length in root_info:
        ax.text(x, y, f'{length:.1f}px',
                color='white', fontsize=10, weight='bold',
                ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

    ax.set_title('Root Lengths (pixels)', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    if show_detailed_roots and root_details:
        print(f"\n{'='*60}")
        print(f"Detailed Root Visualizations ({len(root_details)} roots)")
        print(f"{'='*60}\n")

        for idx, detail in enumerate(root_details):
            print(f"Displaying Root {detail['root_label']} (idx={idx})...")

            visualize_single_root_structure(
                skeleton=detail['skeleton'],
                branch_data=detail['branch_data'],
                title=f"Root {detail['root_label']} (idx={idx}) - Shoot {detail['shoot_label']}",
                trunk_path=detail['path'],
                **detail_viz_kwargs
            )
def _pick_label_offset(anchor_row, anchor_col, offsets, occupied_regions):
    """Choose the label offset that overlaps already-occupied regions the least.

    Args:
        anchor_row (float): Row of the point being labeled.
        anchor_col (float): Column of the point being labeled.
        offsets (list of tuple): Candidate ``(col_offset, row_offset)`` pairs,
            tried in order.
        occupied_regions (list of tuple): ``(row, col, radius)`` circles that
            labels should avoid.

    Returns:
        tuple: The ``(col_offset, row_offset)`` whose test point has the
        smallest maximum penetration depth into any occupied circle. The first
        collision-free candidate is returned immediately.
    """
    best_offset = offsets[0]
    min_collision = float('inf')

    for offset_col, offset_row in offsets:
        test_col = anchor_col + offset_col
        test_row = anchor_row + offset_row

        # Deepest overlap of the candidate point with any occupied circle
        max_collision = 0
        for occ_row, occ_col, radius in occupied_regions:
            dist = np.sqrt((test_row - occ_row) ** 2 + (test_col - occ_col) ** 2)
            if dist < radius:
                max_collision = max(max_collision, radius - dist)

        if max_collision < min_collision:
            min_collision = max_collision
            best_offset = (offset_col, offset_row)
            if max_collision == 0:
                break

    return best_offset


def visualize_skeleton_with_nodes(skeleton, show_branch_labels=False, row_range=None, padding=50, dilate_iterations=3):
    """
    Visualize skeleton with nodes and optionally branch labels.

    Node markers are drawn in red with yellow ID labels. Optional branch labels
    ("B<row index>") are drawn in cyan near branch midpoints. Label positions
    are picked from a fixed list of offsets to minimise overlap with node
    markers and previously placed labels.

    Args:
        skeleton (numpy array): Binary skeleton image to show.
        show_branch_labels (bool): If True, also label branches.
        row_range (tuple): Optional (row_min, row_max) rows to zoom in on.
        padding (int): Extra pixels of margin around the zoomed region.
        dilate_iterations (int): Number of dilation rounds used to thicken the
            skeleton for display; values <= 0 disable thickening.

    Returns:
        branch_data (pandas table): Result of ``summarize(Skeleton(skeleton))``.

    Raises:
        ValueError: If ``skeleton`` contains no foreground pixels.
    """
    rows, cols = np.where(skeleton)
    if rows.size == 0:
        raise ValueError("skeleton contains no foreground pixels")

    # Thicken skeleton for visibility. binary_dilation repeats until the result
    # stops changing when iterations < 1, so only dilate for a positive count.
    if dilate_iterations > 0:
        display_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_skeleton = skeleton
    skeleton_uint8 = (display_skeleton * 255).astype(np.uint8)

    # Bounding box of the region to display
    if row_range is not None:
        row_min, row_max = row_range
        in_range = (rows >= row_min) & (rows <= row_max)
        # Fall back to the full column extent if no pixel lies in the requested rows
        if np.any(in_range):
            col_min, col_max = cols[in_range].min(), cols[in_range].max()
        else:
            col_min, col_max = cols.min(), cols.max()
    else:
        row_min, row_max = rows.min(), rows.max()
        col_min, col_max = cols.min(), cols.max()

    # Get branch data and node coordinates
    branch_data = summarize(Skeleton(skeleton))

    node_coords = {}
    for _, row in branch_data.iterrows():
        node_coords[row['node-id-src']] = (row['coord-src-0'], row['coord-src-1'])
        node_coords[row['node-id-dst']] = (row['coord-dst-0'], row['coord-dst-1'])

    def _in_view(r, c):
        """True if (r, c) lies inside the padded display window."""
        return (row_min - padding <= r <= row_max + padding and
                col_min - padding <= c <= col_max + padding)

    fig, ax = plt.subplots(figsize=(14, 12))
    ax.imshow(skeleton_uint8, cmap='gray')

    # (row, col, radius) circles that later labels should avoid
    occupied_regions = []

    # First pass: node markers (so every node is "occupied" before labels are placed)
    visible_nodes = []
    for node_id, (node_row, node_col) in node_coords.items():
        if _in_view(node_row, node_col):
            visible_nodes.append((node_id, node_row, node_col))
            ax.plot(node_col, node_row, 'ro', markersize=12, zorder=3)
            occupied_regions.append((node_row, node_col, 15))  # small radius for the dot itself

    # Second pass: node labels with collision avoidance
    node_offsets = [(40, 0), (-40, 0), (0, 40), (0, -40),
                    (30, 30), (-30, -30), (30, -30), (-30, 30)]
    for node_id, node_row, node_col in visible_nodes:
        offset_col, offset_row = _pick_label_offset(node_row, node_col, node_offsets, occupied_regions)
        label_col = node_col + offset_col
        label_row = node_row + offset_row

        # Leader line from node to its label
        ax.plot([node_col, label_col], [node_row, label_row],
                'y--', linewidth=1, alpha=0.6, zorder=2)
        ax.text(label_col, label_row, str(int(node_id)), color='yellow', fontsize=11,
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
                ha='center', zorder=4)
        occupied_regions.append((label_row, label_col, 45))

    # Branch labels with collision avoidance
    if show_branch_labels:
        branch_positions = []
        for idx, row in branch_data.iterrows():
            mid_row = (row['coord-src-0'] + row['coord-dst-0']) / 2
            mid_col = (row['coord-src-1'] + row['coord-dst-1']) / 2
            if _in_view(mid_row, mid_col):
                branch_positions.append((idx, mid_row, mid_col))

        branch_offsets = [(60, 0), (-60, 0), (0, 60), (0, -60), (45, 45), (-45, -45),
                          (120, 0), (-120, 0), (0, 120), (0, -120)]
        for idx, mid_row, mid_col in branch_positions:
            offset_col, offset_row = _pick_label_offset(mid_row, mid_col, branch_offsets, occupied_regions)
            label_col = mid_col + offset_col
            label_row = mid_row + offset_row

            ax.plot([mid_col, label_col], [mid_row, label_row],
                    'c--', linewidth=1, alpha=0.5, zorder=1)
            ax.text(label_col, label_row, f"B{idx}", color='cyan', fontsize=9,
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7),
                    ha='center', zorder=2)
            occupied_regions.append((label_row, label_col, 40))

    # Set zoom limits (y axis inverted to keep image orientation)
    ax.set_xlim(col_min - padding, col_max + padding)
    ax.set_ylim(row_max + padding, row_min - padding)

    title = f'Skeleton with {len(node_coords)} nodes, {branch_data.shape[0]} branches'
    if row_range is not None:
        title += f' (rows {row_range[0]}-{row_range[1]})'
    ax.set_title(title)

    plt.show()

    return branch_data
Visualize skeleton with nodes and optionally branch labels.
Arguments:
- skeleton (skeleton numpy array): object to show
- show_branch_labels (bool): add labels
- row_range(tuple): rows to zoom in on
- padding (int): extra pixels of margin added around the zoomed region so the skeleton is easier to view
- dilate_iterations(int): number of rounds to dilate and increase size
Returns:
branch_data (pandas table)
def visualize_trunk_path(skeleton, branch_data, trunk_path, title='Trunk Path',
                         dilate_iterations=3, zoom_to_path=True, padding=100):
    """
    Visualize the trunk path overlaid on the skeleton image.

    Args:
        skeleton: Binary skeleton array
        branch_data: DataFrame with branch information; reads the 'node-id-src',
            'node-id-dst', 'coord-src-0/1', 'coord-dst-0/1' and 'branch-distance' columns
        trunk_path: List of tuples from trace_vertical_path [(node, branch_idx, slope, vertical), ...]
        title: Plot title
        dilate_iterations: Number of iterations to thicken skeleton (values <= 0 disable thickening)
        zoom_to_path: Whether to zoom to the trunk path region
        padding: Padding around zoom region

    Returns:
        Total trunk length: sum of 'branch-distance' (looked up by row position)
        over the path's non-None branch indices; 0 for an empty path.
    """
    if trunk_path is None:
        trunk_path = []

    # Thicken skeleton for visibility. binary_dilation repeats until the result
    # stops changing when iterations < 1, so only dilate for a positive count.
    if dilate_iterations > 0:
        display_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_skeleton = skeleton
    skeleton_uint8 = (display_skeleton * 255).astype(np.uint8)

    # Nodes visited by the path, and the branch row positions it traverses
    path_nodes = [int(node) for node, _, _, _ in trunk_path]
    path_branches = [idx for _, idx, _, _ in trunk_path if idx is not None]

    # Coordinates for all nodes
    all_node_coords = {}
    for _, row in branch_data.iterrows():
        all_node_coords[row['node-id-src']] = (row['coord-src-0'], row['coord-src-1'])
        all_node_coords[row['node-id-dst']] = (row['coord-dst-0'], row['coord-dst-1'])

    # Bounding box of the path nodes; disable zoom if none have known coordinates
    # (otherwise the limits below would reference undefined bounds)
    known_path_coords = [all_node_coords[n] for n in path_nodes if n in all_node_coords]
    if zoom_to_path and known_path_coords:
        path_rows = [r for r, _ in known_path_coords]
        path_cols = [c for _, c in known_path_coords]
        row_min, row_max = min(path_rows), max(path_rows)
        col_min, col_max = min(path_cols), max(path_cols)
    else:
        zoom_to_path = False

    fig, ax = plt.subplots(figsize=(14, 14))
    ax.imshow(skeleton_uint8, cmap='gray')

    # All nodes: small, gray, semi-transparent
    for row, col in all_node_coords.values():
        ax.plot(col, row, 'o', color='gray', markersize=4, alpha=0.3)

    # Trunk path connections first, so they sit behind the node markers
    for node1, node2 in zip(path_nodes, path_nodes[1:]):
        if node1 in all_node_coords and node2 in all_node_coords:
            r1, c1 = all_node_coords[node1]
            r2, c2 = all_node_coords[node2]
            ax.plot([c1, c2], [r1, r2], 'r-', linewidth=4, alpha=0.8, zorder=2)

    # Trunk path nodes: large, colored dark red (start) to bright yellow (end)
    gradient_steps = max(len(path_nodes) - 1, 1)
    for i, node_id in enumerate(path_nodes):
        if node_id not in all_node_coords:
            continue
        row, col = all_node_coords[node_id]
        color = plt.cm.autumn(i / gradient_steps)
        ax.plot(col, row, 'o', color=color, markersize=16, zorder=3,
                markeredgecolor='white', markeredgewidth=2)
        ax.text(col + 15, row, str(node_id), color='cyan', fontsize=12,
                fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='black', alpha=0.8),
                zorder=4)

    if zoom_to_path:
        ax.set_xlim(col_min - padding, col_max + padding)
        ax.set_ylim(row_max + padding, row_min - padding)

    # Title with path info (first 10 nodes only)
    path_str = " → ".join(map(str, path_nodes[:10]))
    if len(path_nodes) > 10:
        path_str += "..."
    ax.set_title(f'{title}\nPath: {path_str}', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    # Print summary
    print(f"Trunk path: {len(path_nodes)} nodes, {len(path_branches)} branches")
    if path_nodes:
        print(f"Start node: {path_nodes[0]}, End node: {path_nodes[-1]}")

    total_length = sum(branch_data.iloc[idx]['branch-distance']
                       for idx in path_branches)
    print(f"Total trunk length: {total_length:.2f} pixels")

    return total_length
Visualize the trunk path overlaid on the skeleton image.
Arguments:
- skeleton: Binary skeleton array
- branch_data: DataFrame with branch information
- trunk_path: List of tuples from trace_vertical_path [(node, branch_idx, slope, vertical), ...]
- title: Plot title
- dilate_iterations: Number of iterations to thicken skeleton
- zoom_to_path: Whether to zoom to the trunk path region
- padding: Padding around zoom region
def visualize_thickened_skeleton(binary_mask, dilate_iterations=3, figsize=(12, 8), show_labels=True):
    """
    Skeletonize, label, thicken, and visualize a binary mask.

    Args:
        binary_mask: Binary numpy array (bool or 0/255) or path to image file
        dilate_iterations: Number of dilation iterations to thicken skeleton
            (values <= 0 disable thickening)
        figsize: Figure size tuple
        show_labels: Whether to show component labels on the image

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels) as returned by label_skeleton

    Raises:
        FileNotFoundError: If binary_mask is a path that cv2.imread cannot read.
    """
    # Load from file if a path was provided
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing/unreadable file by returning None, not raising
        if binary_mask is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    # Thicken skeleton for visibility. binary_dilation repeats until the result
    # stops changing when iterations < 1, so only dilate for a positive count.
    if dilate_iterations > 0:
        thick_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        thick_skeleton = skeleton
    thick_skeleton_uint8 = (thick_skeleton * 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(thick_skeleton_uint8, cmap='gray')
    ax.set_title(f'{len(unique_labels)} skeleton(s) (thickened for visibility)')
    ax.axis('off')

    # Label each structure, centred horizontally and 50 px below its bottom-most pixel
    if show_labels and len(unique_labels) > 0:
        for label_id in unique_labels:
            rows, cols = np.where(labeled_skeleton == label_id)
            if len(rows) == 0:
                continue

            label_row = rows.max() + 50
            label_col = int(np.mean(cols))

            ax.text(label_col, label_row, str(int(label_id)),
                    color='yellow', fontsize=14, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='red',
                              edgecolor='white', linewidth=2, alpha=0.8),
                    ha='center', va='center')

    plt.show()

    print(f"Found {len(unique_labels)} connected structure(s)")
    print(f"Total skeleton pixels: {np.sum(skeleton)}")

    return skeleton, labeled_skeleton, unique_labels
Skeletonize, label, thicken, and visualize a binary mask.
Arguments:
- binary_mask: Binary numpy array (bool or 0/255) or path to image file
- dilate_iterations: Number of dilation iterations to thicken skeleton
- figsize: Figure size tuple
- show_labels: Whether to show component labels on the image
Returns:
tuple: (skeleton, labeled_skeleton, unique_labels)
def visualize_assignments(shoot_mask, root_mask, structures, assignments, ref_points, size=(16, 8)):
    """Visualize shoot-root assignments with ordered labels 1-5 and root labels.

    Each shoot, and its assigned root (dilated by 2 iterations), is filled with
    the color for its order. The palette has 5 colors and cycles if there are
    more shoots. The legend lists "S<order>->R<root> (<score>)". Order numbers
    are drawn 300 px below the assigned root's bottom-most pixel, or at the
    shoot centroid when no root is assigned.

    Args:
        shoot_mask: Binary mask of shoot regions
        root_mask: Binary mask of root regions (accepted for API compatibility; not used)
        structures: Dictionary from extract_root_structures containing 'roots' with root data
        assignments: Dictionary from assign_roots_to_shoots_greedy with shoot assignments
        ref_points: Dictionary from get_shoot_reference_points_multi with shoot centroids
        size: Figure size tuple (width, height)

    Returns:
        None: Displays the visualization
    """
    labeled_shoots, _ = ndimage.label(shoot_mask)

    h, w = shoot_mask.shape
    rgb_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Color palette (5 distinct colors)
    colors = [
        [255, 0, 0],    # Red
        [255, 165, 0],  # Orange
        [255, 255, 0],  # Yellow
        [0, 255, 255],  # Cyan
        [255, 0, 255]   # Magenta
    ]

    legend_elements = []

    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']
        # Cycle the palette so more than len(colors) shoots cannot raise IndexError
        color = colors[order % len(colors)]

        rgb_img[labeled_shoots == shoot_label] = color

        if root_label is not None:
            assigned_root = structures['roots'][root_label]['mask']
            rgb_img[ndimage.binary_dilation(assigned_root, iterations=2)] = color

            score = info['score_dict']['score']
            legend_elements.append(
                mpatches.Patch(color=np.array(color)/255,
                               label=f"S{order+1}->R{root_label} ({score:.2f})")
            )
        else:
            legend_elements.append(
                mpatches.Patch(color=np.array(color)/255,
                               label=f"S{order+1}->None")
            )

    fig, ax = plt.subplots(figsize=size)
    ax.imshow(rgb_img)

    # Shoot order labels
    for shoot_label, info in assignments.items():
        order = info['order']
        root_label = info['root_label']

        if root_label is not None:
            assigned_root = structures['roots'][root_label]['mask']
            y_coords, x_coords = np.where(assigned_root)

            if len(y_coords) > 0:
                # Bottom-most root pixel: largest row index, i.e. lowest in the image
                bottom_idx = np.argmax(y_coords)
                max_y = y_coords[bottom_idx]
                endpoint_x = x_coords[bottom_idx]

                # Shoot order label 300 px below the root's bottom-most pixel
                ax.text(endpoint_x, max_y + 300, str(order + 1),
                        color='white', fontsize=20, fontweight='bold',
                        ha='center', va='center',
                        bbox=dict(boxstyle='circle', facecolor='black', alpha=0.7))
        else:
            # No root assigned (ungerminated): place label at the shoot centroid
            cy, cx = ref_points[shoot_label]['centroid']
            ax.text(cx, cy, str(order + 1),
                    color='white', fontsize=20, fontweight='bold',
                    ha='center', va='center',
                    bbox=dict(boxstyle='circle', facecolor='black', alpha=0.7))

    ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
    ax.set_title('Shoot-Root Assignments (Shoot# = Left->Right Order)', fontsize=12)
    ax.axis('off')

    plt.tight_layout()
    plt.show()
Visualize shoot-root assignments with ordered labels 1-5 and root labels.
Arguments:
- shoot_mask: Binary mask of shoot regions
- root_mask: Binary mask of root regions (not directly used)
- structures: Dictionary from extract_root_structures containing 'roots' with root data
- assignments: Dictionary from assign_roots_to_shoots_greedy with shoot assignments
- ref_points: Dictionary from get_shoot_reference_points_multi with shoot centroids
- size: Figure size tuple (width, height)
Returns:
None: Displays the visualization
def visualize_root_shoot_matches(structures, labeled_shoots, root_matches,
                                 zoom_to_content=True, padding=50,
                                 root_thickness=3):
    """
    Visualize root-to-shoot matches with color coding.

    Each shoot region is filled with a palette color. Each root skeleton listed
    in structures['unique_labels'] is drawn in its matched shoot's color, or in
    gray when root_matches has no truthy entry for it. The legend and printed
    summary count exactly the roots that were drawn, so gray roots and the
    "Unmatched" count always agree.

    Args:
        structures: Dict from extract_root_structures
        labeled_shoots: Labeled shoot array
        root_matches: Dict from match_roots_to_shoots
        zoom_to_content: If True, zoom to bounding box of all content
        padding: Pixels to add around bounding box
        root_thickness: Pixels to dilate root skeletons for visibility
    """
    h, w = structures['labeled_skeleton'].shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Color map for shoots (distinct colors)
    colors = [
        [255, 0, 0],    # Red
        [0, 255, 0],    # Green
        [0, 0, 255],    # Blue
        [255, 255, 0],  # Yellow
        [255, 0, 255],  # Magenta
        [0, 255, 255],  # Cyan
        [255, 128, 0],  # Orange
        [128, 0, 255],  # Purple
    ]
    unmatched_color = [128, 128, 128]  # Gray for unmatched

    for shoot_label in range(1, labeled_shoots.max() + 1):
        vis_img[labeled_shoots == shoot_label] = colors[(shoot_label - 1) % len(colors)]

    # Draw roots, tallying matches from the same decision used for coloring
    matched_shoots = set()
    matched_count = 0
    unmatched_count = 0
    for root_label in structures['unique_labels']:
        root_skeleton = structures['labeled_skeleton'] == root_label

        if root_thickness > 0:
            root_skeleton = ndimage.binary_dilation(
                root_skeleton,
                structure=ndimage.generate_binary_structure(2, 1),
                iterations=root_thickness
            )

        root_match = root_matches.get(root_label)
        if root_match:
            shoot_label = root_match['shoot_label']
            color = colors[(shoot_label - 1) % len(colors)]
            matched_shoots.add(shoot_label)
            matched_count += 1
        else:
            color = unmatched_color
            unmatched_count += 1

        vis_img[root_skeleton] = color

    # Crop to the bounding box of all non-black pixels
    if zoom_to_content:
        rows, cols = np.where(np.any(vis_img > 0, axis=2))
        if len(rows) > 0:
            y_min = max(0, rows.min() - padding)
            y_max = min(h, rows.max() + padding)
            x_min = max(0, cols.min() - padding)
            x_max = min(w, cols.max() + padding)
            vis_img = vis_img[y_min:y_max, x_min:x_max]

    fig, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)
    ax.set_title('Root-Shoot Matches (same color = matched)', fontsize=14)
    ax.axis('off')

    legend_elements = [
        mpatches.Patch(color=np.array(colors[(shoot_label - 1) % len(colors)]) / 255.0,
                       label=f'Shoot/Root {shoot_label}')
        for shoot_label in sorted(matched_shoots)
    ]
    if unmatched_count > 0:
        legend_elements.append(
            mpatches.Patch(color=[0.5, 0.5, 0.5],
                           label=f'Unmatched ({unmatched_count})')
        )

    ax.legend(handles=legend_elements, loc='lower right', fontsize=10)

    plt.tight_layout()
    plt.show()

    print("\nMatch Summary:")
    print(f"Total root structures: {len(structures['unique_labels'])}")
    print(f"Matched: {matched_count}")
    print(f"Unmatched (noise): {unmatched_count}")
Visualize root-to-shoot matches with color coding.
Arguments:
- structures: Dict from extract_root_structures
- labeled_shoots: Labeled shoot array
- root_matches: Dict from match_roots_to_shoots
- zoom_to_content: If True, zoom to bounding box of all content
- padding: Pixels to add around bounding box
- root_thickness: Pixels to dilate root skeletons for visibility
def get_skeleton_path_from_branch(skeleton, branch_row):
    """
    Extract the skeleton pixels for a branch by routing a minimum-cost path.

    The route runs from the branch's source to its destination image
    coordinates over a cost array. Skeleton pixels cost 1 and all other pixels
    cost 10000, so the route follows the skeleton wherever it connects the two
    endpoints.

    Args:
        skeleton: Binary skeleton array
        branch_row: Row from branch_data DataFrame (must provide
            'image-coord-src-0', 'image-coord-src-1',
            'image-coord-dst-0', 'image-coord-dst-1')

    Returns:
        Array of (row, col) coordinates along the skeleton path, ordered from
        source to destination. If routing fails, returns the two endpoints
        [start, end] as a straight-line fallback.
    """
    from skimage.graph import route_through_array

    start_coord = (int(branch_row['image-coord-src-0']),
                   int(branch_row['image-coord-src-1']))
    end_coord = (int(branch_row['image-coord-dst-0']),
                 int(branch_row['image-coord-dst-1']))

    # Low cost on skeleton, high elsewhere
    cost = np.where(skeleton, 1, 10000)

    try:
        indices, _ = route_through_array(
            cost, start_coord, end_coord, fully_connected=True
        )
    except Exception:
        # Best-effort fallback to a straight line. Catch Exception rather than
        # using a bare except, so KeyboardInterrupt/SystemExit still propagate.
        return np.array([start_coord, end_coord])
    return np.array(indices)
Extract the skeleton pixels for a branch by routing a minimum-cost path through the binary skeleton between the branch's source and destination coordinates.
Arguments:
- skeleton: Binary skeleton array
- branch_row: Row from branch_data DataFrame
Returns:
Array of (row, col) coordinates along the skeleton path (just the two endpoints if routing fails)
def _orient_segment(path_pixels, anchor):
    """Return path_pixels ordered so its first point is the end nearer to anchor.

    Args:
        path_pixels (np.ndarray): (N, 2) array of (row, col) points.
        anchor (sequence): (row, col) point the segment should start from.

    Returns:
        np.ndarray: path_pixels unchanged, or a reversed view of it.
    """
    anchor = np.asarray(anchor, dtype=float)
    dist_first = np.sum((path_pixels[0] - anchor) ** 2)
    dist_last = np.sum((path_pixels[-1] - anchor) ** 2)
    return path_pixels[::-1] if dist_last < dist_first else path_pixels


def visualize_single_root_structure(skeleton, branch_data, title='Root Structure',
                                    dilate_iterations=2, zoom_to_content=True,
                                    padding=50, trunk_path=None,
                                    path_edge_width=3, path_node_size=10,
                                    show_edge_dots=True, edge_dot_spacing=50,
                                    show_cumulative_length=True,
                                    output_path=None):
    """Visualize a single root structure with trunk path and cumulative length measurements.

    Displays a root skeleton with the main trunk path highlighted in blue, showing nodes
    along the path with gradient coloring (red to yellow) and optional cumulative length
    labels at regular intervals. The path follows the actual skeleton pixels rather than
    straight-line connections between nodes.

    Args:
        skeleton (np.ndarray): Binary skeleton array (H, W) for a single root structure.
        branch_data (pd.DataFrame): DataFrame containing branch information with columns:
            'node-id-src', 'node-id-dst', 'image-coord-src-0', 'image-coord-src-1',
            'image-coord-dst-0', 'image-coord-dst-1', 'branch-distance'.
        title (str, optional): Title for the visualization. Defaults to 'Root Structure'.
        dilate_iterations (int, optional): Number of morphological dilation iterations
            to thicken skeleton for visibility; values <= 0 disable thickening.
            Defaults to 2.
        zoom_to_content (bool, optional): Whether to zoom to the bounding box of the
            root structure. Defaults to True.
        padding (int, optional): Padding in pixels around the zoomed content.
            Defaults to 50.
        trunk_path (list, optional): List of tuples from find_farthest_endpoint_path:
            [(node_id, branch_idx, distance, vertical), ...]. If None, only skeleton
            is shown. Defaults to None.
        path_edge_width (float, optional): Line width for trunk path edges.
            Defaults to 3.
        path_node_size (float, optional): Marker size for nodes along trunk path.
            Defaults to 10.
        show_edge_dots (bool, optional): Whether to show cyan dots along the trunk path.
            Defaults to True.
        edge_dot_spacing (int, optional): Spacing, in path pixels, between edge dots.
            Minimum spacing is 50 to prevent overcrowding. Defaults to 50.
        show_cumulative_length (bool, optional): Whether to display yellow labels showing
            cumulative length at regular intervals along the path. Defaults to True.
        output_path (str or Path, optional): If given, the figure is saved to this
            path before it is displayed. Defaults to None.

    Returns:
        None: Displays (and optionally saves) a matplotlib figure showing the root
        structure visualization.

    Notes:
        - Trunk path is drawn in blue following actual skeleton pixels
        - Each segment's pixels are oriented along the direction of travel (starting
          at the first path node), so cumulative lengths increase along the path
        - Cyan dots mark regular intervals along the path
        - Nodes are colored with autumn colormap (red=start, yellow=end)
        - Yellow labels show cumulative length every other dot
        - Lime green label shows total length at the final node
        - Node IDs are displayed in white text boxes offset from nodes

    Examples:
        >>> skeleton = structures['labeled_skeleton'] == root_label
        >>> branch_data = structures['roots'][root_label]['branch_data']
        >>> path = find_farthest_endpoint_path(branch_data, top_node, direction='down')
        >>> visualize_single_root_structure(
        ...     skeleton, branch_data,
        ...     title='Root 35 Structure',
        ...     trunk_path=path,
        ...     edge_dot_spacing=100
        ... )
    """
    # Thicken skeleton for display. binary_dilation repeats until the result
    # stops changing when iterations < 1, so only dilate for a positive count.
    if dilate_iterations > 0:
        display_skeleton = binary_dilation(skeleton, iterations=dilate_iterations)
    else:
        display_skeleton = skeleton
    skeleton_uint8 = (display_skeleton * 255).astype(np.uint8)

    # Image coordinates of every node
    all_node_coords = {}
    for _, row in branch_data.iterrows():
        all_node_coords[row['node-id-src']] = (row['image-coord-src-0'],
                                               row['image-coord-src-1'])
        all_node_coords[row['node-id-dst']] = (row['image-coord-dst-0'],
                                               row['image-coord-dst-1'])

    path_nodes = [int(node) for node, _, _, _ in trunk_path] if trunk_path else []

    # Bounding box
    if zoom_to_content and all_node_coords:
        rows = [coord[0] for coord in all_node_coords.values()]
        cols = [coord[1] for coord in all_node_coords.values()]
        row_min, row_max = min(rows), max(rows)
        col_min, col_max = min(cols), max(cols)
    else:
        zoom_to_content = False

    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(skeleton_uint8, cmap='gray')

    # Draw trunk path and record (pixel_index, cumulative_length) at each segment start
    all_path_pixels = []
    cumulative_lengths = []
    n_path_pixels = 0       # running count instead of re-stacking all segments each time
    current_length = 0.0

    if trunk_path:
        # First segment should start at the first path node; later ones continue
        # from wherever the previous segment ended.
        anchor = all_node_coords.get(path_nodes[0])
        for _, branch_idx, distance, _ in trunk_path:
            if branch_idx is None:
                continue
            branch_row = branch_data.iloc[branch_idx]

            # Actual skeleton pixels for this branch, oriented along the path
            path_pixels = get_skeleton_path_from_branch(skeleton, branch_row)
            if anchor is not None:
                path_pixels = _orient_segment(path_pixels, anchor)
            anchor = path_pixels[-1]

            cumulative_lengths.append((n_path_pixels, current_length))
            all_path_pixels.append(path_pixels)
            n_path_pixels += len(path_pixels)
            current_length += distance

            ax.plot(path_pixels[:, 1], path_pixels[:, 0], '-',
                    color='blue', linewidth=path_edge_width,
                    alpha=0.8, zorder=2)

        # Sentinel entry marking the end of the last segment
        if all_path_pixels:
            cumulative_lengths.append((n_path_pixels, current_length))

    # Draw dots and labels
    if show_edge_dots and all_path_pixels:
        combined_path = np.vstack(all_path_pixels)

        # Ensure minimum spacing of 50px
        actual_spacing = max(edge_dot_spacing, 50)
        sampled_indices = np.arange(0, len(combined_path), actual_spacing)

        sampled_pixels = combined_path[sampled_indices]
        ax.plot(sampled_pixels[:, 1], sampled_pixels[:, 0], 'o',
                color='cyan', markersize=4, alpha=0.9, zorder=2.5,
                markeredgecolor='blue', markeredgewidth=0.5)

        if show_cumulative_length:
            # Label every other dot to reduce clutter
            for sample_idx in sampled_indices[::2]:
                # Linearly interpolate cumulative length within the containing segment
                cum_length = current_length
                for (seg_start, seg_length), (next_start, next_length) in zip(
                        cumulative_lengths, cumulative_lengths[1:]):
                    if seg_start <= sample_idx < next_start:
                        progress = (sample_idx - seg_start) / max(next_start - seg_start, 1)
                        cum_length = seg_length + progress * (next_length - seg_length)
                        break

                row, col = combined_path[sample_idx]
                ax.text(col + 15, row, f'{cum_length:.0f}px',
                        color='yellow', fontsize=8, fontweight='bold',
                        ha='left', va='center',
                        bbox=dict(boxstyle='round,pad=0.2',
                                  facecolor='black', alpha=0.6),
                        zorder=4)

    # Draw path nodes
    gradient_steps = max(len(path_nodes) - 1, 1)
    last_index = len(path_nodes) - 1
    for i, node_id in enumerate(path_nodes):
        if node_id not in all_node_coords:
            continue
        row, col = all_node_coords[node_id]
        color = plt.cm.autumn(i / gradient_steps)
        ax.plot(col, row, 'o', color=color, markersize=path_node_size,
                zorder=3, markeredgecolor='white', markeredgewidth=1.5)

        # Node ID label, offset right and slightly up
        ax.text(col + 20, row - 5, str(node_id),
                color='white', fontsize=9, fontweight='bold',
                ha='left', va='center',
                bbox=dict(boxstyle='round,pad=0.3',
                          facecolor='black', alpha=0.7),
                zorder=4)

        # Total length at the final node
        if i == last_index and show_cumulative_length:
            ax.text(col + 15, row + 20, f'{current_length:.1f}px',
                    color='lime', fontsize=10, fontweight='bold',
                    ha='left', va='top',
                    bbox=dict(boxstyle='round,pad=0.3',
                              facecolor='black', alpha=0.8),
                    zorder=5)

    if zoom_to_content:
        ax.set_xlim(col_min - padding, col_max + padding)
        ax.set_ylim(row_max + padding, row_min - padding)

    ax.set_title(title, fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    if output_path:
        fig.savefig(str(output_path))
    plt.show()
Visualize a single root structure with trunk path and cumulative length measurements.
Displays a root skeleton with the main trunk path highlighted in blue, showing nodes along the path with gradient coloring (red to yellow) and optional cumulative length labels at regular intervals. The path follows the actual skeleton pixels rather than straight-line connections between nodes.
Arguments:
- skeleton (np.ndarray): Binary skeleton array (H, W) for a single root structure.
- branch_data (pd.DataFrame): DataFrame containing branch information with columns: 'node-id-src', 'node-id-dst', 'image-coord-src-0', 'image-coord-src-1', 'image-coord-dst-0', 'image-coord-dst-1', 'branch-distance'.
- title (str, optional): Title for the visualization. Defaults to 'Root Structure'.
- dilate_iterations (int, optional): Number of morphological dilation iterations to thicken skeleton for visibility. Defaults to 2.
- zoom_to_content (bool, optional): Whether to zoom to the bounding box of the root structure. Defaults to True.
- padding (int, optional): Padding in pixels around the zoomed content. Defaults to 50.
- trunk_path (list, optional): List of tuples from find_farthest_endpoint_path: [(node_id, branch_idx, distance, vertical), ...]. If None, only skeleton is shown. Defaults to None.
- path_edge_width (float, optional): Line width for trunk path edges. Defaults to 3.
- path_node_size (float, optional): Marker size for nodes along trunk path. Defaults to 10.
- show_edge_dots (bool, optional): Whether to show cyan dots along the trunk path. Defaults to True.
- edge_dot_spacing (int, optional): Spacing between edge dots, counted in path pixels (a dot is placed every N pixels along the traced path). Values below 50 are raised to 50 to prevent overcrowding. Defaults to 50.
- show_cumulative_length (bool, optional): Whether to display yellow labels showing cumulative length at regular intervals along the path. Defaults to True.
- output_path (str or Path, optional): If provided, the figure is saved to this path before it is displayed.
Returns:
None: Displays matplotlib figure showing the root structure visualization.
Notes:
- Trunk path is drawn in blue following actual skeleton pixels
- Cyan dots mark regular intervals along the path
- Nodes are colored with autumn colormap (red=start, yellow=end)
- Yellow labels show cumulative length every other dot
- Lime green label shows total length at the final node
- Node IDs are displayed in white text boxes offset from nodes
Examples:
```python
>>> skeleton = structures['labeled_skeleton'] == root_label
>>> branch_data = structures['roots'][root_label]['branch_data']
>>> path = find_farthest_endpoint_path(branch_data, top_node, direction='down')
>>> visualize_single_root_structure(
...     skeleton, branch_data,
...     title='Root 35 Structure',
...     trunk_path=path,
...     edge_dot_spacing=100
... )
```
def visualize_root_lengths(structures, top_node_results, labeled_shoots,
                           zoom_to_content=True, padding=50,
                           show_detailed_roots=False, detail_viz_kwargs=None):
    """Visualize matched roots with their calculated lengths displayed.

    Creates an overview visualization showing all matched root structures
    color-coded by their associated shoot regions, with length measurements
    labeled. Optionally displays detailed individual root visualizations
    below the overview.

    Args:
        structures (dict): Dictionary from extract_root_structures containing:
            - 'labeled_skeleton': Array with labeled skeleton structures
            - 'roots': Dictionary of root data by label
        top_node_results (dict): Dictionary from find_top_nodes_from_shoot,
            keyed by root_label; each value is a dict containing
            'branch_data', 'shoot_label' and 'top_nodes'. The first element
            of the first entry of 'top_nodes' is used as the path start node.
        labeled_shoots (np.ndarray): Labeled array from label_shoot_regions
            where each shoot region has a unique integer label. Presumably
            the same shape as structures['labeled_skeleton'] (it is used to
            index an image of that shape) -- confirm against callers.
        zoom_to_content (bool, optional): If True, crop to the bounding box of
            all drawn content (shoots and roots) plus `padding`.
            Defaults to True.
        padding (int, optional): Pixels to add around the bounding box when
            zooming; the crop is clamped to the image bounds. Defaults to 50.
        show_detailed_roots (bool, optional): If True, display a detailed
            visualization for each successfully measured root after the
            overview. Defaults to False.
        detail_viz_kwargs (dict, optional): Keyword arguments passed to
            visualize_single_root_structure for detailed views. Must not
            contain 'skeleton', 'branch_data', 'title' or 'trunk_path',
            which are set here. Common options:
            - 'dilate_iterations': int, skeleton thickness (default: 2)
            - 'path_edge_width': float, trunk path line width (default: 3)
            - 'path_node_size': float, node marker size (default: 10)
            - 'show_edge_dots': bool, show dots along path (default: True)
            - 'edge_dot_spacing': int, spacing between dots (default: 50)
            - 'show_cumulative_length': bool, show length labels (default: True)
            If None, an empty dict is used. Defaults to None.

    Returns:
        None: Displays matplotlib figures showing root length visualizations.

    Notes:
        - Overview uses color-coding: each root matches its shoot color
          (the 8-color palette cycles when there are more shoots).
        - Roots are thickened with 3 iterations of binary dilation for
          visibility.
        - Length labels appear at the centroid of each thickened root, in
          white text on a black box.
        - Detailed views are indexed (idx=0, idx=1, etc.) in display order.
        - Roots whose measurement fails (including an empty 'top_nodes'
          list) are drawn in gray with no length label, and a RuntimeWarning
          naming the root is issued.

    Examples:
        >>> # Basic usage - overview only
        >>> visualize_root_lengths(structures, top_node_results, labeled_shoots)

        >>> # With detailed individual visualizations
        >>> visualize_root_lengths(
        ...     structures, top_node_results, labeled_shoots,
        ...     show_detailed_roots=True,
        ...     detail_viz_kwargs={
        ...         'edge_dot_spacing': 100,
        ...         'path_edge_width': 2,
        ...         'show_cumulative_length': True
        ...     }
        ... )
    """
    import warnings

    if detail_viz_kwargs is None:
        detail_viz_kwargs = {}

    labeled_skeleton = structures['labeled_skeleton']

    # Create RGB canvas
    h, w = labeled_skeleton.shape
    vis_img = np.zeros((h, w, 3), dtype=np.uint8)

    # Color palette, indexed by shoot label (cycles past 8 shoots)
    colors = [
        [255, 0, 0],      # Red
        [0, 255, 0],      # Green
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [255, 128, 0],    # Orange
        [128, 0, 255],    # Purple
    ]

    # Draw shoots
    for shoot_label in range(1, labeled_shoots.max() + 1):
        vis_img[labeled_shoots == shoot_label] = colors[(shoot_label - 1) % len(colors)]

    # 4-connected (cross-shaped) structuring element used to thicken skeletons
    thicken_structure = ndimage.generate_binary_structure(2, 1)

    root_info = []     # (centroid_x, centroid_y, root_label, shoot_label, length)
    root_details = []  # data for the optional detailed visualizations

    for root_label, result in top_node_results.items():
        branch_data = result['branch_data']
        shoot_label = result['shoot_label']

        raw_skeleton = labeled_skeleton == root_label
        root_skeleton = ndimage.binary_dilation(
            raw_skeleton,
            structure=thicken_structure,
            iterations=3
        )

        try:
            # Extracting the start node is inside the try so a root with an
            # empty 'top_nodes' list is drawn gray instead of aborting the
            # whole overview.
            top_node = result['top_nodes'][0][0]
            path = find_farthest_endpoint_path(
                branch_data, top_node,
                direction='down', use_smart_scoring=True,
                verbose=False
            )
            root_length = calculate_skeleton_length_px(path)
        except Exception as exc:  # best-effort: keep drawing the other roots
            warnings.warn(
                f"Root {root_label}: length measurement failed ({exc!r}); "
                f"drawing it in gray.",
                RuntimeWarning,
                stacklevel=2,
            )
            vis_img[root_skeleton] = [128, 128, 128]
            continue

        # Draw root in its shoot's color
        vis_img[root_skeleton] = colors[(shoot_label - 1) % len(colors)]

        # Label position: centroid of the thickened root
        y_coords, x_coords = np.where(root_skeleton)
        root_info.append((np.mean(x_coords), np.mean(y_coords),
                          root_label, shoot_label, root_length))

        if show_detailed_roots:
            root_details.append({
                'root_label': root_label,
                'skeleton': raw_skeleton,
                'branch_data': branch_data,
                'path': path,
                'shoot_label': shoot_label
            })

    # Zoom if requested
    if zoom_to_content:
        rows, cols = np.where(np.any(vis_img > 0, axis=2))
        if len(rows) > 0:
            y_min = max(0, rows.min() - padding)
            # +1 because the slice end is exclusive; without it the last
            # content row/column is lost when padding == 0.
            y_max = min(h, rows.max() + padding + 1)
            x_min = max(0, cols.min() - padding)
            x_max = min(w, cols.max() + padding + 1)
            vis_img = vis_img[y_min:y_max, x_min:x_max]
            # Shift label coordinates into the cropped frame
            root_info = [(x - x_min, y - y_min, rl, sl, length)
                         for x, y, rl, sl, length in root_info]

    # Plot overview
    _, ax = plt.subplots(figsize=(15, 8))
    ax.imshow(vis_img)

    for x, y, _root_label, _shoot_label, length in root_info:
        ax.text(x, y, f'{length:.1f}px',
                color='white', fontsize=10, weight='bold',
                ha='center', va='center',
                bbox=dict(boxstyle='round,pad=0.3',
                          facecolor='black', alpha=0.7))

    ax.set_title('Root Lengths (pixels)', fontsize=14)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    # Display detailed visualizations if requested
    if show_detailed_roots and root_details:
        print(f"\n{'='*60}")
        print(f"Detailed Root Visualizations ({len(root_details)} roots)")
        print(f"{'='*60}\n")

        for idx, detail in enumerate(root_details):
            print(f"Displaying Root {detail['root_label']} (idx={idx})...")

            visualize_single_root_structure(
                skeleton=detail['skeleton'],
                branch_data=detail['branch_data'],
                title=f"Root {detail['root_label']} (idx={idx}) - Shoot {detail['shoot_label']}",
                trunk_path=detail['path'],
                **detail_viz_kwargs
            )
Visualize matched roots with their calculated lengths displayed.
Creates an overview visualization showing all matched root structures color-coded by their associated shoot regions, with length measurements labeled. Optionally displays detailed individual root visualizations below the overview.
Arguments:
- structures (dict): Dictionary from extract_root_structures containing:
  - 'labeled_skeleton': Array with labeled skeleton structures
  - 'roots': Dictionary of root data by label
- top_node_results (dict): Dictionary from find_top_nodes_from_shoot, keyed by root_label; each value is a dict containing 'branch_data', 'shoot_label', and 'top_nodes'.
- labeled_shoots (np.ndarray): Labeled array from label_shoot_regions where each shoot region has a unique integer label.
- zoom_to_content (bool, optional): If True, zoom to bounding box of all content. Defaults to True.
- padding (int, optional): Pixels to add around bounding box when zooming. Defaults to 50.
- show_detailed_roots (bool, optional): If True, display detailed visualization for each individual root below the overview. Defaults to False.
- detail_viz_kwargs (dict, optional): Keyword arguments to pass to visualize_single_root_structure for detailed views. If None, default values are used. Defaults to None. Common options:
  - 'dilate_iterations': int, skeleton thickness (default: 2)
  - 'path_edge_width': float, trunk path line width (default: 3)
  - 'path_node_size': float, node marker size (default: 10)
  - 'show_edge_dots': bool, show dots along path (default: True)
  - 'edge_dot_spacing': int, spacing between dots (default: 50)
  - 'show_cumulative_length': bool, show length labels (default: True)
Returns:
None: Displays matplotlib figures showing root length visualizations.
Notes:
- Overview uses color-coding: each root matches its shoot color
- Roots are thickened with binary dilation for visibility
- Length labels appear at root centroids in white text boxes
- Detailed views are indexed (idx=0, idx=1, etc.) in display order
- Failed root measurements appear in gray with no length label
Examples:
```python
>>> # Basic usage - overview only
>>> visualize_root_lengths(structures, top_node_results, labeled_shoots)

>>> # With detailed individual visualizations
>>> visualize_root_lengths(
...     structures, top_node_results, labeled_shoots,
...     show_detailed_roots=True,
...     detail_viz_kwargs={
...         'edge_dot_spacing': 100,
...         'path_edge_width': 2,
...         'show_cumulative_length': True
...     }
... )
```