library.root_analysis
import numpy as np
import pandas as pd
import cv2
from skimage.morphology import skeletonize, remove_small_objects
from skimage.measure import label
from skan import Skeleton, summarize
from skan.csr import skeleton_to_csgraph
from scipy.spatial.distance import cdist
from scipy import ndimage
from pathlib import Path


def find_most_vertical_branch(branch_data, current_node, direction='up',
                              verticality_weight=0.7, height_weight=0.3,
                              verbose=True):
    """
    From current_node, find the branch that balances verticality AND height gained.

    Only branches whose other end lies in ``direction`` are candidates
    ('up' means towards smaller row indices, 'down' towards larger ones).
    For each candidate, slope = |vertical change| / horizontal change
    (``inf`` for a purely vertical branch). Slopes and vertical changes are
    divided by the largest (finite) value among the candidates, an infinite
    slope counting as twice the largest finite slope (0 if there is none),
    and combined as
    ``verticality_weight * slope_norm + height_weight * vertical_norm``.
    The highest-scoring candidate wins; ties go to the earliest row.

    Args:
        branch_data: DataFrame with branch information; the columns
            'node-id-src', 'node-id-dst', 'coord-src-0', 'coord-src-1',
            'coord-dst-0' and 'coord-dst-1' are used.
        current_node: The node we're starting from.
        direction: 'up' or 'down'.
        verticality_weight: Weight for the normalised slope.
        height_weight: Weight for the normalised vertical distance gained.
        verbose: Print the score of every candidate (default True, as before).

    Returns:
        tuple: (next_node, branch_index, slope, vertical_change), where
        vertical_change is the absolute row difference, or
        (None, None, None, None) if no branch leads in ``direction``.

    Raises:
        ValueError: If ``direction`` is neither 'up' nor 'down'.
    """
    # Validate first so a bad direction is reported even for isolated nodes.
    if direction not in ('up', 'down'):
        raise ValueError("direction must be 'up' or 'down'")

    inf = float('inf')
    connected_branches = branch_data[
        (branch_data['node-id-src'] == current_node) |
        (branch_data['node-id-dst'] == current_node)
    ]

    candidates = []
    for idx, row in connected_branches.iterrows():
        # Orient the branch so it runs from current_node to its other end.
        if row['node-id-src'] == current_node:
            next_node = row['node-id-dst']
            vertical_change = row['coord-src-0'] - row['coord-dst-0']
            horizontal_change = abs(row['coord-src-1'] - row['coord-dst-1'])
        else:
            next_node = row['node-id-src']
            vertical_change = row['coord-dst-0'] - row['coord-src-0']
            horizontal_change = abs(row['coord-dst-1'] - row['coord-src-1'])

        # Row indices grow downwards, so a positive change means moving up.
        moves_our_way = vertical_change > 0 if direction == 'up' else vertical_change < 0
        if not moves_our_way:
            continue

        slope = abs(vertical_change) / horizontal_change if horizontal_change > 0 else inf
        candidates.append({
            'idx': idx,
            'next_node': next_node,
            'slope': slope,
            'vertical_change': abs(vertical_change),
            'horizontal_change': horizontal_change,
        })

    if not candidates:
        return None, None, None, None

    # Normalise by the largest finite slope / vertical change among candidates.
    max_slope = max(c['slope'] if c['slope'] != inf else 0 for c in candidates)
    max_vertical = max(c['vertical_change'] for c in candidates)
    slope_scale = max_slope if max_slope > 0 else 1
    vertical_scale = max_vertical if max_vertical > 0 else 1

    # Start below any finite score so a candidate is always selected, even
    # when the weights make every combined score zero or negative.
    best_score = -inf
    best = None
    for c in candidates:
        slope_value = c['slope'] if c['slope'] != inf else max_slope * 2
        slope_norm = slope_value / slope_scale
        vertical_norm = c['vertical_change'] / vertical_scale
        combined_score = verticality_weight * slope_norm + height_weight * vertical_norm

        if verbose:
            print(f" B{c['idx']} → {int(c['next_node'])}: slope={c['slope']:.2f}, vert={c['vertical_change']:.1f}, score={combined_score:.3f}")

        if combined_score > best_score:
            best_score = combined_score
            best = c

    if best is None:  # only possible if every score is NaN
        return None, None, None, None
    return best['next_node'], best['idx'], best['slope'], best['vertical_change']


def trace_vertical_path(branch_data, start_node, direction='down',
                        verticality_weight=0.5, height_weight=0.5, max_steps=100):
    """
    Trace the most vertical path from start_node.

    Repeatedly follows the branch chosen by ``find_most_vertical_branch``
    until no branch leads further in ``direction``, a node is revisited, or
    ``max_steps`` steps have been taken. Progress is printed.

    Args:
        branch_data: DataFrame with branch information.
        start_node: Node to start tracing from.
        direction: 'up' or 'down'.
        verticality_weight: Weight for slope passed to find_most_vertical_branch.
        height_weight: Weight for vertical distance passed to find_most_vertical_branch.
        max_steps: Maximum number of branches to follow.

    Returns:
        list of (node, branch_idx, slope, vertical_change) tuples, one per
        step taken, followed by a final (last_node, None, None, None) entry.
    """
    path = []
    visited_nodes = set()
    current_node = start_node

    for step in range(max_steps):
        if current_node in visited_nodes:
            print(f" Loop detected at node {int(current_node)}, stopping")
            break
        visited_nodes.add(current_node)

        next_node, branch_idx, slope, vertical = find_most_vertical_branch(
            branch_data, current_node, direction, verticality_weight, height_weight
        )
        if next_node is None:
            print(f" Reached endpoint at node {int(current_node)}")
            break

        path.append((current_node, branch_idx, slope, vertical))
        print(f" Step {step}: Node {int(current_node)} → {int(next_node)} via B{branch_idx}")
        current_node = next_node

    path.append((current_node, None, None, None))
    return path


def label_skeleton(binary_mask):
    """
    Skeletonize a binary mask and label the skeleton's connected components.

    Args:
        binary_mask: numpy array; any non-zero value (e.g. True or 255) is
            treated as foreground.

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels)
            - skeleton: binary skeleton array from skimage ``skeletonize``.
            - labeled_skeleton: integer array from skimage ``label`` with one
              ID per connected component; background is 0.
            - unique_labels: sorted component IDs, excluding background 0.
    """
    mask = binary_mask if binary_mask.dtype == bool else binary_mask.astype(bool)

    skeleton = skeletonize(mask)
    labeled_skeleton = label(skeleton)

    unique_labels = np.unique(labeled_skeleton)
    unique_labels = unique_labels[unique_labels != 0]

    return skeleton, labeled_skeleton, unique_labels


def find_top_nodes(branch_data, n_nodes=1, threshold=None, verbose=True):
    """
    Find the node(s) closest to the top of the image (lowest row values).

    Args:
        branch_data: DataFrame with branch information; the columns
            'node-id-src', 'node-id-dst', 'coord-src-0' and 'coord-dst-0'
            are used.
        n_nodes: Number of top nodes to return (ignored when threshold is set).
        threshold: Optional - return all nodes within this many pixels of the
            top-most node instead.
        verbose: Print a summary of the selection (default True, as before).

    Returns:
        list of node IDs ordered from top to bottom; empty if branch_data has
        no rows.
    """
    # Row coordinate of every node. A node shared by several branches is
    # simply overwritten (presumably with the same coordinate).
    node_coords = {}
    for _, row in branch_data.iterrows():
        node_coords[row['node-id-src']] = row['coord-src-0']
        node_coords[row['node-id-dst']] = row['coord-dst-0']

    if not node_coords:
        if verbose:
            print("No nodes found in branch data")
        return []

    # Lowest row value = closest to the top of the image.
    sorted_nodes = sorted(node_coords.items(), key=lambda item: item[1])

    if threshold is not None:
        min_row = sorted_nodes[0][1]
        top_nodes = [node_id for node_id, row in sorted_nodes if row <= min_row + threshold]
        if verbose:
            print(f"Top node at row {min_row:.0f}")
            print(f"Found {len(top_nodes)} nodes within {threshold} pixels of top")
    else:
        selected = sorted_nodes[:n_nodes]
        top_nodes = [node_id for node_id, _ in selected]
        if verbose:
            print(f"Top {n_nodes} node(s):")
            for node_id, row in selected:
                print(f" Node {int(node_id)} at row {row:.0f}")

    return top_nodes


def extract_root_structures(binary_mask, verbose=False):
    """
    Analyze all root structures in a binary mask.

    The mask is skeletonized, the skeleton is split into connected
    components, and each component is summarised with skan.

    Args:
        binary_mask: Binary numpy array (bool or 0/255), or a path to an image
            file, which is read as grayscale with OpenCV.
        verbose: Print per-root progress and a final summary.

    Returns:
        dict: {
            'skeleton': full skeleton array,
            'labeled_skeleton': labeled skeleton array,
            'unique_labels': array of label IDs,
            'roots': {
                label_id: {
                    'mask': skeleton pixels of this root only (same array as 'skeleton'),
                    'skeleton': skeleton for this root only,
                    'branch_data': skan summary DataFrame, or None if analysis failed,
                    'num_nodes': number of nodes,
                    'num_branches': number of branches,
                    'total_pixels': total skeleton pixels,
                    'error': (only when analysis failed) the error message
                },
                ...
            }
        }

    Raises:
        FileNotFoundError: If a path was given and OpenCV could not read it.
    """
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure by returning None rather than raising.
        if binary_mask is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")

    if binary_mask.dtype != bool:
        binary_mask = binary_mask.astype(bool)

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    results = {
        'skeleton': skeleton,
        'labeled_skeleton': labeled_skeleton,
        'unique_labels': unique_labels,
        'roots': {},
    }
    if verbose:
        print(f"Analyzing {len(unique_labels)} root structure(s)...")

    for label_id in unique_labels:
        root_id = int(label_id)
        if verbose:
            print(f"\n--- Root {root_id} ---")

        # Skeleton pixels belonging to this component only.
        single_root_mask = labeled_skeleton == label_id
        total_pixels = np.sum(single_root_mask)

        try:
            branch_data = summarize(Skeleton(single_root_mask))
            unique_nodes = set(branch_data['node-id-src']) | set(branch_data['node-id-dst'])
            num_nodes = len(unique_nodes)
            num_branches = len(branch_data)
        except Exception as e:
            # Best effort: record the failure for this root and keep going.
            if verbose:
                print(f" ERROR analyzing root {root_id}: {e}")
            results['roots'][root_id] = {
                'mask': single_root_mask,
                'skeleton': single_root_mask,
                'branch_data': None,
                'num_nodes': 0,
                'num_branches': 0,
                'total_pixels': total_pixels,
                'error': str(e),
            }
            continue

        results['roots'][root_id] = {
            'mask': single_root_mask,
            'skeleton': single_root_mask,
            'branch_data': branch_data,
            'num_nodes': num_nodes,
            'num_branches': num_branches,
            'total_pixels': total_pixels,
        }
        if verbose:
            print(f" Nodes: {num_nodes}, Branches: {num_branches}, Pixels: {total_pixels}")

    if verbose:
        print("\n=== Summary ===")
        print(f"Total roots analyzed: {len(results['roots'])}")
        successful = sum(1 for r in results['roots'].values() if r['branch_data'] is not None)
        print(f"Successfully analyzed: {successful}")

    return results


def find_farthest_endpoint_path(branch_data, start_node, direction='down',
                                use_smart_scoring=False,
                                horizontal_penalty=0.5,
                                straightness_weight=1.0,
                                verbose=True):
    """
    Find the path from start_node to the best-scoring skeleton endpoint.

    Every degree-1 node other than start_node that lies in ``direction``
    relative to start_node is a candidate. In simple mode candidates are
    ranked by straight-line (Euclidean) distance; with ``use_smart_scoring``
    the score is
    ``|vertical| * straightness_weight * straightness - horizontal_penalty * horizontal``,
    where straightness = Euclidean distance / shortest skeleton path length.
    The returned path is the shortest skeleton route (weighted by
    'branch-distance') to the winning endpoint.

    Args:
        branch_data: DataFrame with branch information ('node-id-src',
            'node-id-dst', 'coord-src-0/1', 'coord-dst-0/1', 'branch-distance').
        start_node: Node to start from.
        direction: 'down' or 'up' (for filtering endpoints); any other value
            applies no direction filter.
        use_smart_scoring: If True, penalize horizontal deviation and reward straightness.
        horizontal_penalty: Weight for horizontal deviation penalty (higher = more penalty).
        straightness_weight: Weight for straightness ratio (higher = more reward for straight paths).
        verbose: Print progress.

    Returns:
        list of (node, branch_idx, branch_distance, vertical_dist) tuples, one
        per step, where vertical_dist is the row change to the next node
        (positive = downwards), followed by (last_node, None, None, None).
        Returns [(start_node, None, None, None)] if no path can be built.
    """
    import networkx as nx

    G = nx.Graph()
    node_coords = {}
    for idx, row in branch_data.iterrows():
        src = row['node-id-src']
        dst = row['node-id-dst']
        distance = row['branch-distance']

        node_coords[src] = (row['coord-src-0'], row['coord-src-1'])
        node_coords[dst] = (row['coord-dst-0'], row['coord-dst-1'])

        # nx.Graph keeps a single edge per node pair. When several skeleton
        # branches join the same two nodes, keep the shortest so weighted
        # shortest paths reflect the true minimum skeleton length.
        if G.has_edge(src, dst) and G[src][dst]['distance'] <= distance:
            continue
        G.add_edge(src, dst, branch_idx=idx, distance=distance)

    if start_node not in node_coords:
        print(f"Error: Start node {int(start_node)} not found in graph")
        return [(start_node, None, None, None)]

    start_row, start_col = node_coords[start_node]

    if verbose:
        print(f"Start node {int(start_node)} at (row={start_row:.0f}, col={start_col:.0f})")
        print(f"Scoring mode: {'SMART (considering straightness & horizontal deviation)' if use_smart_scoring else 'SIMPLE (just Euclidean distance)'}")

    endpoints = [node for node in G.nodes() if G.degree(node) == 1 and node != start_node]

    if verbose:
        print(f"Found {len(endpoints)} candidate endpoints")

    # A single Dijkstra run yields the shortest skeleton length and route to
    # every reachable node, replacing one search per endpoint.
    path_lengths, node_paths = nx.single_source_dijkstra(G, start_node, weight='distance')

    endpoint_scores = []
    for endpoint in endpoints:
        end_row, end_col = node_coords[endpoint]

        euclidean_dist = np.sqrt((end_row - start_row)**2 + (end_col - start_col)**2)
        vertical_dist = end_row - start_row  # positive = going down
        horizontal_dist = abs(end_col - start_col)

        if direction == 'down' and vertical_dist <= 0:
            continue
        elif direction == 'up' and vertical_dist >= 0:
            continue

        if endpoint not in node_paths:
            continue  # not connected to start_node

        skeleton_length = path_lengths[endpoint]
        # 1.0 = perfectly straight, < 1.0 = wiggly
        straightness = euclidean_dist / skeleton_length if skeleton_length > 0 else 0

        if use_smart_scoring:
            score = (abs(vertical_dist) * straightness_weight * straightness
                     - horizontal_penalty * horizontal_dist)
        else:
            score = euclidean_dist

        endpoint_scores.append({
            'endpoint': endpoint,
            'score': score,
            'euclidean_dist': euclidean_dist,
            'vertical_dist': abs(vertical_dist),
            'horizontal_dist': horizontal_dist,
            'straightness': straightness,
            'skeleton_length': skeleton_length,
            'coords': (end_row, end_col),
            'node_path': node_paths[endpoint],
        })

    if not endpoint_scores:
        print("No valid endpoints found in specified direction!")
        return [(start_node, None, None, None)]

    endpoint_scores.sort(key=lambda x: x['score'], reverse=True)

    if verbose:
        print("\nTop 5 candidates:")
        for i, ep_info in enumerate(endpoint_scores[:5]):
            print(f" {i+1}. Node {int(ep_info['endpoint'])}: "
                  f"score={ep_info['score']:.1f}, "
                  f"euclidean={ep_info['euclidean_dist']:.1f}px, "
                  f"vertical={ep_info['vertical_dist']:.1f}px, "
                  f"horizontal={ep_info['horizontal_dist']:.1f}px, "
                  f"straightness={ep_info['straightness']:.2f}")

    best = endpoint_scores[0]
    target_endpoint = best['endpoint']
    node_path = best['node_path']

    if verbose:
        print(f"\nSelected: Node {int(target_endpoint)} (score: {best['score']:.1f})")

    detailed_path = []
    total_skeleton_length = 0
    for current, next_node in zip(node_path[:-1], node_path[1:]):
        edge_data = G[current][next_node]
        distance = edge_data['distance']
        total_skeleton_length += distance
        vertical_dist = node_coords[next_node][0] - node_coords[current][0]
        detailed_path.append((current, edge_data['branch_idx'], distance, vertical_dist))
    detailed_path.append((node_path[-1], None, None, None))

    if verbose:
        node_sequence = [int(node) for node, _, _, _ in detailed_path]
        print(f"\nPath: {' → '.join(map(str, node_sequence[:15]))}")
        if len(node_sequence) > 15:
            print(f" ... (total {len(node_sequence)} nodes)")
        print("Metrics:")
        print(f" Euclidean: {best['euclidean_dist']:.1f}px")
        print(f" Skeleton: {total_skeleton_length:.1f}px")
        print(f" Straightness: {best['straightness']:.2f}")
        print(f" Horizontal deviation: {best['horizontal_dist']:.1f}px")

    return detailed_path


def calculate_skeleton_length_px(detailed_path):
    """
    Calculate the total skeleton path length in pixels.

    Args:
        detailed_path: List of (node, branch_idx, distance, vertical_dist)
            tuples from find_farthest_endpoint_path(); entries whose distance
            is None (the terminal entry) are skipped.

    Returns:
        float: Total path length in pixels (0.0 for an empty path).
    """
    return sum((distance for _, _, distance, _ in detailed_path if distance is not None), 0.0)


def process_matched_roots_to_lengths(structures, top_node_results, labeled_shoots, num_shoots):
    """
    Process matched roots and return root lengths ordered by shoot position (left to right).

    Args:
        structures: Dict from extract_root_structures (currently unused).
        top_node_results: Dict mapping root label -> result dict. Each result
            must provide 'shoot_label', 'branch_data' and 'top_nodes'; the
            first item of the first 'top_nodes' entry is used as start node.
        labeled_shoots: Labeled shoot array (shoot labels 1..num_shoots).
        num_shoots: Number of shoots (typically 5).

    Returns:
        np.array: Root lengths in pixels, ordered by shoot centroid x-position
        (left to right); index 0 is the leftmost shoot. Shoots without a
        matched root, or whose root could not be processed, get 0.0. Shoot
        labels with no pixels are placed last.
    """
    # Centroid x-position of each shoot label for left-to-right ordering.
    shoot_positions = {}
    for shoot_label in range(1, num_shoots + 1):
        x_coords = np.nonzero(labeled_shoots == shoot_label)[1]
        shoot_positions[shoot_label] = np.mean(x_coords) if len(x_coords) > 0 else float('inf')

    shoot_order = sorted(shoot_positions, key=shoot_positions.get)

    shoot_to_length = {}
    for root_label, result in top_node_results.items():
        shoot_label = result['shoot_label']
        try:
            # Inside the try so one malformed result (e.g. empty 'top_nodes')
            # does not abort the whole batch.
            top_node = result['top_nodes'][0][0]
            path = find_farthest_endpoint_path(
                result['branch_data'],
                top_node,
                direction='down',
                use_smart_scoring=True,
                verbose=False,
            )
            root_length = calculate_skeleton_length_px(path)
        except Exception as e:
            print(f"Warning: Failed to process root {root_label} for shoot {shoot_label}: {e}")
            root_length = 0.0
        # NOTE: if several roots map to the same shoot, the last one processed wins.
        shoot_to_length[shoot_label] = root_length

    return np.array([shoot_to_length.get(shoot_label, 0.0) for shoot_label in shoot_order])
def find_most_vertical_branch(branch_data, current_node, direction='up',
                              verticality_weight=0.7, height_weight=0.3,
                              verbose=True):
    """
    From current_node, find the branch that balances verticality AND height gained.

    Only branches whose other end lies in ``direction`` are candidates
    ('up' means towards smaller row indices, 'down' towards larger ones).
    For each candidate, slope = |vertical change| / horizontal change
    (``inf`` for a purely vertical branch). Slopes and vertical changes are
    divided by the largest (finite) value among the candidates, an infinite
    slope counting as twice the largest finite slope (0 if there is none),
    and combined as
    ``verticality_weight * slope_norm + height_weight * vertical_norm``.
    The highest-scoring candidate wins; ties go to the earliest row.

    Args:
        branch_data: DataFrame with branch information; the columns
            'node-id-src', 'node-id-dst', 'coord-src-0', 'coord-src-1',
            'coord-dst-0' and 'coord-dst-1' are used.
        current_node: The node we're starting from.
        direction: 'up' or 'down'.
        verticality_weight: Weight for the normalised slope.
        height_weight: Weight for the normalised vertical distance gained.
        verbose: Print the score of every candidate (default True, as before).

    Returns:
        tuple: (next_node, branch_index, slope, vertical_change), where
        vertical_change is the absolute row difference, or
        (None, None, None, None) if no branch leads in ``direction``.

    Raises:
        ValueError: If ``direction`` is neither 'up' nor 'down'.
    """
    # Validate first so a bad direction is reported even for isolated nodes.
    if direction not in ('up', 'down'):
        raise ValueError("direction must be 'up' or 'down'")

    inf = float('inf')
    connected_branches = branch_data[
        (branch_data['node-id-src'] == current_node) |
        (branch_data['node-id-dst'] == current_node)
    ]

    candidates = []
    for idx, row in connected_branches.iterrows():
        # Orient the branch so it runs from current_node to its other end.
        if row['node-id-src'] == current_node:
            next_node = row['node-id-dst']
            vertical_change = row['coord-src-0'] - row['coord-dst-0']
            horizontal_change = abs(row['coord-src-1'] - row['coord-dst-1'])
        else:
            next_node = row['node-id-src']
            vertical_change = row['coord-dst-0'] - row['coord-src-0']
            horizontal_change = abs(row['coord-dst-1'] - row['coord-src-1'])

        # Row indices grow downwards, so a positive change means moving up.
        moves_our_way = vertical_change > 0 if direction == 'up' else vertical_change < 0
        if not moves_our_way:
            continue

        slope = abs(vertical_change) / horizontal_change if horizontal_change > 0 else inf
        candidates.append({
            'idx': idx,
            'next_node': next_node,
            'slope': slope,
            'vertical_change': abs(vertical_change),
            'horizontal_change': horizontal_change,
        })

    if not candidates:
        return None, None, None, None

    # Normalise by the largest finite slope / vertical change among candidates.
    max_slope = max(c['slope'] if c['slope'] != inf else 0 for c in candidates)
    max_vertical = max(c['vertical_change'] for c in candidates)
    slope_scale = max_slope if max_slope > 0 else 1
    vertical_scale = max_vertical if max_vertical > 0 else 1

    # Start below any finite score so a candidate is always selected, even
    # when the weights make every combined score zero or negative.
    best_score = -inf
    best = None
    for c in candidates:
        slope_value = c['slope'] if c['slope'] != inf else max_slope * 2
        slope_norm = slope_value / slope_scale
        vertical_norm = c['vertical_change'] / vertical_scale
        combined_score = verticality_weight * slope_norm + height_weight * vertical_norm

        if verbose:
            print(f" B{c['idx']} → {int(c['next_node'])}: slope={c['slope']:.2f}, vert={c['vertical_change']:.1f}, score={combined_score:.3f}")

        if combined_score > best_score:
            best_score = combined_score
            best = c

    if best is None:  # only possible if every score is NaN
        return None, None, None, None
    return best['next_node'], best['idx'], best['slope'], best['vertical_change']
From current_node, find the branch that balances verticality AND height gained.
Arguments:
- branch_data: DataFrame with branch information
- current_node: The node we're starting from
- direction: 'up' or 'down'
- verticality_weight: Weight for slope (0-1)
- height_weight: Weight for vertical distance gained (0-1)
Returns:
tuple: (next_node, branch_index, slope, vertical_change) or (None, None, None, None)
def trace_vertical_path(branch_data, start_node, direction='down',
                        verticality_weight=0.5, height_weight=0.5, max_steps=100):
    """
    Trace the most vertical path from start_node.

    Repeatedly follows the branch chosen by ``find_most_vertical_branch``
    until no branch leads further in ``direction``, a node is revisited, or
    ``max_steps`` steps have been taken. Progress is printed.

    Args:
        branch_data: DataFrame with branch information.
        start_node: Node to start tracing from.
        direction: 'up' or 'down'.
        verticality_weight: Weight for slope passed to find_most_vertical_branch.
        height_weight: Weight for vertical distance passed to find_most_vertical_branch.
        max_steps: Maximum number of branches to follow.

    Returns:
        list of (node, branch_idx, slope, vertical_change) tuples, one per
        step taken, followed by a final (last_node, None, None, None) entry.
    """
    path = []
    visited_nodes = set()
    current_node = start_node

    for step in range(max_steps):
        if current_node in visited_nodes:
            print(f" Loop detected at node {int(current_node)}, stopping")
            break
        visited_nodes.add(current_node)

        next_node, branch_idx, slope, vertical = find_most_vertical_branch(
            branch_data, current_node, direction, verticality_weight, height_weight
        )
        if next_node is None:
            print(f" Reached endpoint at node {int(current_node)}")
            break

        path.append((current_node, branch_idx, slope, vertical))
        print(f" Step {step}: Node {int(current_node)} → {int(next_node)} via B{branch_idx}")
        current_node = next_node

    path.append((current_node, None, None, None))
    return path
Trace the most vertical path from start_node.
def label_skeleton(binary_mask):
    """
    Skeletonize a binary mask and label the skeleton's connected components.

    Args:
        binary_mask: numpy array; any non-zero value (e.g. True or 255) is
            treated as foreground.

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels)
            - skeleton: binary skeleton array from skimage ``skeletonize``.
            - labeled_skeleton: integer array from skimage ``label`` with one
              ID per connected component; background is 0.
            - unique_labels: sorted component IDs, excluding background 0.
    """
    mask = binary_mask if binary_mask.dtype == bool else binary_mask.astype(bool)

    skeleton = skeletonize(mask)
    labeled_skeleton = label(skeleton)

    unique_labels = np.unique(labeled_skeleton)
    unique_labels = unique_labels[unique_labels != 0]

    return skeleton, labeled_skeleton, unique_labels
Take a binary mask, skeletonize it, and return labeled connected components.
Arguments:
- binary_mask: Binary numpy array (bool or 0/255)
Returns:
tuple: (skeleton, labeled_skeleton, unique_labels), where skeleton is the binary skeleton array, labeled_skeleton assigns a unique ID to each connected component, and unique_labels holds the component IDs (excluding background 0).
def find_top_nodes(branch_data, n_nodes=1, threshold=None, verbose=True):
    """
    Find the node(s) closest to the top of the image (lowest row values).

    Args:
        branch_data: DataFrame with branch information; the columns
            'node-id-src', 'node-id-dst', 'coord-src-0' and 'coord-dst-0'
            are used.
        n_nodes: Number of top nodes to return (ignored when threshold is set).
        threshold: Optional - return all nodes within this many pixels of the
            top-most node instead.
        verbose: Print a summary of the selection (default True, as before).

    Returns:
        list of node IDs ordered from top to bottom; empty if branch_data has
        no rows.
    """
    # Row coordinate of every node. A node shared by several branches is
    # simply overwritten (presumably with the same coordinate).
    node_coords = {}
    for _, row in branch_data.iterrows():
        node_coords[row['node-id-src']] = row['coord-src-0']
        node_coords[row['node-id-dst']] = row['coord-dst-0']

    if not node_coords:
        if verbose:
            print("No nodes found in branch data")
        return []

    # Lowest row value = closest to the top of the image.
    sorted_nodes = sorted(node_coords.items(), key=lambda item: item[1])

    if threshold is not None:
        min_row = sorted_nodes[0][1]
        top_nodes = [node_id for node_id, row in sorted_nodes if row <= min_row + threshold]
        if verbose:
            print(f"Top node at row {min_row:.0f}")
            print(f"Found {len(top_nodes)} nodes within {threshold} pixels of top")
    else:
        selected = sorted_nodes[:n_nodes]
        top_nodes = [node_id for node_id, _ in selected]
        if verbose:
            print(f"Top {n_nodes} node(s):")
            for node_id, row in selected:
                print(f" Node {int(node_id)} at row {row:.0f}")

    return top_nodes
Find the node(s) closest to the top of the image (lowest row values).
Arguments:
- branch_data: DataFrame with branch information
- n_nodes: Number of top nodes to return
- threshold: Optional - return all nodes within this many pixels of the top
Returns:
list of node IDs at the top
def extract_root_structures(binary_mask, verbose=False):
    """
    Analyze all root structures in a binary mask.

    The mask is skeletonized, the skeleton is split into connected
    components, and each component is summarised with skan.

    Args:
        binary_mask: Binary numpy array (bool or 0/255), or a path to an image
            file, which is read as grayscale with OpenCV.
        verbose: Print per-root progress and a final summary.

    Returns:
        dict: {
            'skeleton': full skeleton array,
            'labeled_skeleton': labeled skeleton array,
            'unique_labels': array of label IDs,
            'roots': {
                label_id: {
                    'mask': skeleton pixels of this root only (same array as 'skeleton'),
                    'skeleton': skeleton for this root only,
                    'branch_data': skan summary DataFrame, or None if analysis failed,
                    'num_nodes': number of nodes,
                    'num_branches': number of branches,
                    'total_pixels': total skeleton pixels,
                    'error': (only when analysis failed) the error message
                },
                ...
            }
        }

    Raises:
        FileNotFoundError: If a path was given and OpenCV could not read it.
    """
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure by returning None rather than raising.
        if binary_mask is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")

    if binary_mask.dtype != bool:
        binary_mask = binary_mask.astype(bool)

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    results = {
        'skeleton': skeleton,
        'labeled_skeleton': labeled_skeleton,
        'unique_labels': unique_labels,
        'roots': {},
    }
    if verbose:
        print(f"Analyzing {len(unique_labels)} root structure(s)...")

    for label_id in unique_labels:
        root_id = int(label_id)
        if verbose:
            print(f"\n--- Root {root_id} ---")

        # Skeleton pixels belonging to this component only.
        single_root_mask = labeled_skeleton == label_id
        total_pixels = np.sum(single_root_mask)

        try:
            branch_data = summarize(Skeleton(single_root_mask))
            unique_nodes = set(branch_data['node-id-src']) | set(branch_data['node-id-dst'])
            num_nodes = len(unique_nodes)
            num_branches = len(branch_data)
        except Exception as e:
            # Best effort: record the failure for this root and keep going.
            if verbose:
                print(f" ERROR analyzing root {root_id}: {e}")
            results['roots'][root_id] = {
                'mask': single_root_mask,
                'skeleton': single_root_mask,
                'branch_data': None,
                'num_nodes': 0,
                'num_branches': 0,
                'total_pixels': total_pixels,
                'error': str(e),
            }
            continue

        results['roots'][root_id] = {
            'mask': single_root_mask,
            'skeleton': single_root_mask,
            'branch_data': branch_data,
            'num_nodes': num_nodes,
            'num_branches': num_branches,
            'total_pixels': total_pixels,
        }
        if verbose:
            print(f" Nodes: {num_nodes}, Branches: {num_branches}, Pixels: {total_pixels}")

    if verbose:
        print("\n=== Summary ===")
        print(f"Total roots analyzed: {len(results['roots'])}")
        successful = sum(1 for r in results['roots'].values() if r['branch_data'] is not None)
        print(f"Successfully analyzed: {successful}")

    return results
Analyze all root structures in a binary mask.
Arguments:
- binary_mask: Binary numpy array (bool or 0/255), or a path to an image file (read as grayscale)
Returns:
dict with the keys 'skeleton' (full skeleton array), 'labeled_skeleton' (labeled skeleton array), 'unique_labels' (array of label IDs) and 'roots', a mapping from each label_id to a dict with: 'mask' and 'skeleton' (both hold the skeleton pixels of this root only), 'branch_data' (DataFrame with branch information, or None if the analysis of this root failed), 'num_nodes' (number of nodes), 'num_branches' (number of branches), 'total_pixels' (total skeleton pixels), and, for failed roots only, 'error' (the error message).
def find_farthest_endpoint_path(branch_data, start_node, direction='down',
                                use_smart_scoring=False,
                                horizontal_penalty=0.5,
                                straightness_weight=1.0,
                                verbose=True):
    """
    Find the path from start_node to the best-scoring skeleton endpoint.

    Every degree-1 node other than start_node that lies in ``direction``
    relative to start_node is a candidate. In simple mode candidates are
    ranked by straight-line (Euclidean) distance; with ``use_smart_scoring``
    the score is
    ``|vertical| * straightness_weight * straightness - horizontal_penalty * horizontal``,
    where straightness = Euclidean distance / shortest skeleton path length.
    The returned path is the shortest skeleton route (weighted by
    'branch-distance') to the winning endpoint.

    Args:
        branch_data: DataFrame with branch information ('node-id-src',
            'node-id-dst', 'coord-src-0/1', 'coord-dst-0/1', 'branch-distance').
        start_node: Node to start from.
        direction: 'down' or 'up' (for filtering endpoints); any other value
            applies no direction filter.
        use_smart_scoring: If True, penalize horizontal deviation and reward straightness.
        horizontal_penalty: Weight for horizontal deviation penalty (higher = more penalty).
        straightness_weight: Weight for straightness ratio (higher = more reward for straight paths).
        verbose: Print progress.

    Returns:
        list of (node, branch_idx, branch_distance, vertical_dist) tuples, one
        per step, where vertical_dist is the row change to the next node
        (positive = downwards), followed by (last_node, None, None, None).
        Returns [(start_node, None, None, None)] if no path can be built.
    """
    import networkx as nx

    G = nx.Graph()
    node_coords = {}
    for idx, row in branch_data.iterrows():
        src = row['node-id-src']
        dst = row['node-id-dst']
        distance = row['branch-distance']

        node_coords[src] = (row['coord-src-0'], row['coord-src-1'])
        node_coords[dst] = (row['coord-dst-0'], row['coord-dst-1'])

        # nx.Graph keeps a single edge per node pair. When several skeleton
        # branches join the same two nodes, keep the shortest so weighted
        # shortest paths reflect the true minimum skeleton length.
        if G.has_edge(src, dst) and G[src][dst]['distance'] <= distance:
            continue
        G.add_edge(src, dst, branch_idx=idx, distance=distance)

    if start_node not in node_coords:
        print(f"Error: Start node {int(start_node)} not found in graph")
        return [(start_node, None, None, None)]

    start_row, start_col = node_coords[start_node]

    if verbose:
        print(f"Start node {int(start_node)} at (row={start_row:.0f}, col={start_col:.0f})")
        print(f"Scoring mode: {'SMART (considering straightness & horizontal deviation)' if use_smart_scoring else 'SIMPLE (just Euclidean distance)'}")

    endpoints = [node for node in G.nodes() if G.degree(node) == 1 and node != start_node]

    if verbose:
        print(f"Found {len(endpoints)} candidate endpoints")

    # A single Dijkstra run yields the shortest skeleton length and route to
    # every reachable node, replacing one search per endpoint.
    path_lengths, node_paths = nx.single_source_dijkstra(G, start_node, weight='distance')

    endpoint_scores = []
    for endpoint in endpoints:
        end_row, end_col = node_coords[endpoint]

        euclidean_dist = np.sqrt((end_row - start_row)**2 + (end_col - start_col)**2)
        vertical_dist = end_row - start_row  # positive = going down
        horizontal_dist = abs(end_col - start_col)

        if direction == 'down' and vertical_dist <= 0:
            continue
        elif direction == 'up' and vertical_dist >= 0:
            continue

        if endpoint not in node_paths:
            continue  # not connected to start_node

        skeleton_length = path_lengths[endpoint]
        # 1.0 = perfectly straight, < 1.0 = wiggly
        straightness = euclidean_dist / skeleton_length if skeleton_length > 0 else 0

        if use_smart_scoring:
            score = (abs(vertical_dist) * straightness_weight * straightness
                     - horizontal_penalty * horizontal_dist)
        else:
            score = euclidean_dist

        endpoint_scores.append({
            'endpoint': endpoint,
            'score': score,
            'euclidean_dist': euclidean_dist,
            'vertical_dist': abs(vertical_dist),
            'horizontal_dist': horizontal_dist,
            'straightness': straightness,
            'skeleton_length': skeleton_length,
            'coords': (end_row, end_col),
            'node_path': node_paths[endpoint],
        })

    if not endpoint_scores:
        print("No valid endpoints found in specified direction!")
        return [(start_node, None, None, None)]

    endpoint_scores.sort(key=lambda x: x['score'], reverse=True)

    if verbose:
        print("\nTop 5 candidates:")
        for i, ep_info in enumerate(endpoint_scores[:5]):
            print(f" {i+1}. Node {int(ep_info['endpoint'])}: "
                  f"score={ep_info['score']:.1f}, "
                  f"euclidean={ep_info['euclidean_dist']:.1f}px, "
                  f"vertical={ep_info['vertical_dist']:.1f}px, "
                  f"horizontal={ep_info['horizontal_dist']:.1f}px, "
                  f"straightness={ep_info['straightness']:.2f}")

    best = endpoint_scores[0]
    target_endpoint = best['endpoint']
    node_path = best['node_path']

    if verbose:
        print(f"\nSelected: Node {int(target_endpoint)} (score: {best['score']:.1f})")

    detailed_path = []
    total_skeleton_length = 0
    for current, next_node in zip(node_path[:-1], node_path[1:]):
        edge_data = G[current][next_node]
        distance = edge_data['distance']
        total_skeleton_length += distance
        vertical_dist = node_coords[next_node][0] - node_coords[current][0]
        detailed_path.append((current, edge_data['branch_idx'], distance, vertical_dist))
    detailed_path.append((node_path[-1], None, None, None))

    if verbose:
        node_sequence = [int(node) for node, _, _, _ in detailed_path]
        print(f"\nPath: {' → '.join(map(str, node_sequence[:15]))}")
        if len(node_sequence) > 15:
            print(f" ... (total {len(node_sequence)} nodes)")
        print("Metrics:")
        print(f" Euclidean: {best['euclidean_dist']:.1f}px")
        print(f" Skeleton: {total_skeleton_length:.1f}px")
        print(f" Straightness: {best['straightness']:.2f}")
        print(f" Horizontal deviation: {best['horizontal_dist']:.1f}px")

    return detailed_path
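Find the path from the start node to the best-scoring endpoint: by default the endpoint farthest away in straight-line distance, or, with smart scoring, the one that best balances vertical reach, path straightness and horizontal deviation.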
Arguments:
- branch_data: DataFrame with branch information
- start_node: Node to start from
- direction: 'down' or 'up' (for filtering endpoints)
- use_smart_scoring: If True, use scoring that penalizes horizontal deviation and rewards straightness
- horizontal_penalty: Weight for horizontal deviation penalty (higher = more penalty)
- straightness_weight: Weight for straightness ratio (higher = more reward for straight paths)
- verbose: Print progress
Returns:
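Best path as a list of (node, branch_idx, distance, vertical_dist) tuples; the final tuple is (end_node, None, None, None). Returns [(start_node, None, None, None)] when no valid path exists.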
def calculate_skeleton_length_px(detailed_path):
    """
    Sum the branch distances along a path from find_farthest_endpoint_path().

    Args:
        detailed_path: Sequence of (node, branch_idx, distance, vertical_dist)
            tuples. Entries whose distance is None, such as the terminal
            (end_node, None, None, None) entry, contribute nothing.

    Returns:
        float: Total path length in pixels (0.0 for an empty path).
    """
    step_lengths = (
        step_distance
        for _node, _branch, step_distance, _vertical in detailed_path
        if step_distance is not None
    )
    return sum(step_lengths, 0.0)
Calculate the total skeleton path length in pixels.
Arguments:
- detailed_path: List of tuples (node, branch_idx, distance, vertical_dist) from find_farthest_endpoint_path()
Returns:
float: Total path length in pixels
def process_matched_roots_to_lengths(structures, top_node_results, labeled_shoots, num_shoots):
    """
    Measure each matched root and return the lengths ordered by shoot position.

    Shoot labels 1..num_shoots are ordered left to right by the mean column of
    their pixels in ``labeled_shoots``; labels with no pixels sort last. For each
    entry of ``top_node_results``, the path from its first top node is traced with
    ``find_farthest_endpoint_path(..., direction='down', use_smart_scoring=True)``
    and its skeleton length is summed.

    Args:
        structures: Dict from extract_root_structures (not read here; kept for
            interface compatibility).
        top_node_results: Dict from find_top_nodes_from_shoot, mapping a root
            label to a dict with 'branch_data', 'shoot_label' and 'top_nodes'
            (the start node is ``top_nodes[0][0]``).
        labeled_shoots: Labeled shoot array; only labels 1..num_shoots are used.
        num_shoots: Number of shoots (typically 5).

    Returns:
        np.array: Root lengths in pixels, ordered by shoot x-position (left to
        right); index 0 is the leftmost shoot. Shoots without a matched root, or
        whose root fails to process, get 0.0. If several roots share a shoot
        label, the last one processed wins.
    """
    # Mean column (x) of each shoot; shoots with no pixels sort to the right end.
    shoot_positions = {}
    for shoot_label in range(1, num_shoots + 1):
        _, x_coords = np.where(labeled_shoots == shoot_label)
        if len(x_coords) > 0:
            shoot_positions[shoot_label] = np.mean(x_coords)
        else:
            shoot_positions[shoot_label] = float('inf')

    # Stable sort, so ties keep ascending label order.
    shoot_order = sorted(shoot_positions, key=shoot_positions.get)

    shoot_to_length = {}
    for root_label, result in top_node_results.items():
        shoot_label = result['shoot_label']
        try:
            # Look up the inputs inside the try so a root with no top node (or a
            # missing field) is reported and scored 0.0 rather than aborting the
            # whole batch.
            branch_data = result['branch_data']
            top_node = result['top_nodes'][0][0]

            path = find_farthest_endpoint_path(
                branch_data,
                top_node,
                direction='down',
                use_smart_scoring=True,
                verbose=False,
            )
            shoot_to_length[shoot_label] = calculate_skeleton_length_px(path)
        except Exception as e:
            print(f"Warning: Failed to process root {root_label} for shoot {shoot_label}: {e}")
            shoot_to_length[shoot_label] = 0.0

    return np.array([shoot_to_length.get(s, 0.0) for s in shoot_order])
Process matched roots and return array of lengths ordered by shoot position (left to right).
Arguments:
- structures: Dict from extract_root_structures
- top_node_results: Dict from find_top_nodes_from_shoot
- labeled_shoots: Labeled shoot array
- num_shoots: Number of shoots (typically 5)
Returns:
np.array: Root lengths in pixels, ordered by shoot x-position (left to right); index 0 is the leftmost shoot. Shoots without a matched root, or whose root fails to process, get 0.0.