library.root_analysis

  1import numpy as np
  2import pandas as pd
  3import cv2
  4from skimage.morphology import skeletonize, remove_small_objects
  5from skimage.measure import label
  6from skan import Skeleton, summarize
  7from skan.csr import skeleton_to_csgraph
  8from skimage.morphology import skeletonize
  9from skimage.measure import label
 10from scipy.spatial.distance import cdist
 11from scipy import ndimage
 12from pathlib import Path
 13
 14
 15# import matplotlib.pyplot as plt
 16
 17def find_most_vertical_branch(branch_data, current_node, direction='up', 
 18                               verticality_weight=0.7, height_weight=0.3):
 19    """
 20    From current_node, find the branch that balances verticality AND height gained.
 21    
 22    Args:
 23        branch_data: DataFrame with branch information
 24        current_node: The node we're starting from
 25        direction: 'up' or 'down'
 26        verticality_weight: Weight for slope (0-1)
 27        height_weight: Weight for vertical distance gained (0-1)
 28    
 29    Returns:
 30        tuple: (next_node, branch_index, slope, vertical_change) or (None, None, None, None)
 31    """
 32    connected_branches = branch_data[
 33        (branch_data['node-id-src'] == current_node) | 
 34        (branch_data['node-id-dst'] == current_node)
 35    ]
 36    
 37    best_score = 0
 38    best_next_node = None
 39    best_branch_idx = None
 40    best_slope = None
 41    best_vertical = None
 42    
 43    candidates = []
 44    
 45    for idx, row in connected_branches.iterrows():
 46        if row['node-id-src'] == current_node:
 47            next_node = row['node-id-dst']
 48            vertical_change = row['coord-src-0'] - row['coord-dst-0']
 49            horizontal_change = abs(row['coord-src-1'] - row['coord-dst-1'])
 50        else:
 51            next_node = row['node-id-src']
 52            vertical_change = row['coord-dst-0'] - row['coord-src-0']
 53            horizontal_change = abs(row['coord-dst-1'] - row['coord-src-1'])
 54        
 55        # Check direction
 56        if direction == 'up':
 57            condition = vertical_change > 0
 58        elif direction == 'down':
 59            condition = vertical_change < 0
 60        else:
 61            raise ValueError("direction must be 'up' or 'down'")
 62        
 63        if condition:
 64            # Calculate slope
 65            if horizontal_change > 0:
 66                slope = abs(vertical_change) / horizontal_change
 67            else:
 68                slope = float('inf')
 69            
 70            # Normalize scores (slope can be 0-inf, vertical_change is in pixels)
 71            # Use ratio to max for normalization
 72            candidates.append({
 73                'idx': idx,
 74                'next_node': next_node,
 75                'slope': slope,
 76                'vertical_change': abs(vertical_change),
 77                'horizontal_change': horizontal_change
 78            })
 79    
 80    if not candidates:
 81        return None, None, None, None
 82    
 83    # Normalize scores
 84    max_slope = max(c['slope'] if c['slope'] != float('inf') else 0 for c in candidates)
 85    max_vertical = max(c['vertical_change'] for c in candidates)
 86    
 87    # Calculate combined scores
 88    for c in candidates:
 89        slope_norm = (c['slope'] if c['slope'] != float('inf') else max_slope * 2) / (max_slope if max_slope > 0 else 1)
 90        vertical_norm = c['vertical_change'] / (max_vertical if max_vertical > 0 else 1)
 91        
 92        combined_score = verticality_weight * slope_norm + height_weight * vertical_norm
 93        
 94        print(f"    B{c['idx']}{int(c['next_node'])}: slope={c['slope']:.2f}, vert={c['vertical_change']:.1f}, score={combined_score:.3f}")
 95        
 96        if combined_score > best_score:
 97            best_score = combined_score
 98            best_next_node = c['next_node']
 99            best_branch_idx = c['idx']
100            best_slope = c['slope']
101            best_vertical = c['vertical_change']
102    
103    return best_next_node, best_branch_idx, best_slope, best_vertical
104
105
def trace_vertical_path(branch_data, start_node, direction='down',
                        verticality_weight=0.5, height_weight=0.5, max_steps=100):
    """
    Trace the most vertical path from start_node.

    Repeatedly follows the branch picked by find_most_vertical_branch until
    no branch leads further in ``direction``, a node is revisited, or
    ``max_steps`` steps have been taken. Progress is printed at each step.

    Returns:
        list of (node, branch_idx, slope, vertical_change) tuples, one per
        step taken, followed by a final (last_node, None, None, None) entry.
    """
    trail = []
    seen = set()
    node = start_node

    step_no = 0
    while step_no < max_steps:
        # Guard against cycles in the skeleton graph.
        if node in seen:
            print(f"  Loop detected at node {int(node)}, stopping")
            break
        seen.add(node)

        following, branch_id, slope, vertical = find_most_vertical_branch(
            branch_data, node, direction, verticality_weight, height_weight
        )
        if following is None:
            print(f"  Reached endpoint at node {int(node)}")
            break

        trail.append((node, branch_id, slope, vertical))
        print(f"  Step {step_no}: Node {int(node)}{int(following)} via B{branch_id}")

        node = following
        step_no += 1

    trail.append((node, None, None, None))
    return trail
135
136
137
def label_skeleton(binary_mask):
    """
    Take a binary mask, skeletonize it, and return labeled connected components.

    Args:
        binary_mask: Binary numpy array (bool or 0/255); any non-zero value
            is treated as foreground.

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels)
            - skeleton: Binary skeleton array
            - labeled_skeleton: Labeled skeleton with unique IDs for each component
            - unique_labels: Array of unique label IDs (excluding background 0)
    """
    mask = binary_mask if binary_mask.dtype == bool else binary_mask.astype(bool)

    skeleton = skeletonize(mask)
    labeled_skeleton = label(skeleton)

    # Background pixels carry label 0, so only labelled pixels yield component ids.
    component_ids = np.unique(labeled_skeleton[labeled_skeleton != 0])

    return skeleton, labeled_skeleton, component_ids
167
def find_top_nodes(branch_data, n_nodes=1, threshold=None):
    """
    Find the node(s) closest to the top of the image (lowest row values).

    Args:
        branch_data: DataFrame with branch information (reads 'node-id-src',
            'node-id-dst', 'coord-src-0', 'coord-dst-0')
        n_nodes: Number of top nodes to return (ignored when ``threshold``
            is given)
        threshold: Optional - return all nodes within this many pixels of the top

    Returns:
        list of node IDs at the top, ordered by increasing row. Empty when
        ``branch_data`` is None or has no rows.
    """
    # Without this guard the threshold path indexes sorted_nodes[0] and raises
    # IndexError on an empty table (or AttributeError for a failed root's None).
    if branch_data is None or len(branch_data) == 0:
        return []

    # Map node id -> row coordinate. A node shared by several branches has the
    # same coordinates each time, so overwriting is harmless; dict insertion
    # order keeps tie-breaking in the stable sort below deterministic.
    node_rows = {}
    for src, src_row, dst, dst_row in zip(
        branch_data['node-id-src'], branch_data['coord-src-0'],
        branch_data['node-id-dst'], branch_data['coord-dst-0'],
    ):
        node_rows[src] = src_row
        node_rows[dst] = dst_row

    # Sort by row coordinate (lowest = top)
    sorted_nodes = sorted(node_rows.items(), key=lambda item: item[1])

    if threshold is not None:
        min_row = sorted_nodes[0][1]
        top_nodes = [node_id for node_id, row in sorted_nodes if row <= min_row + threshold]
        print(f"Top node at row {min_row:.0f}")
        print(f"Found {len(top_nodes)} nodes within {threshold} pixels of top")
    else:
        selected = sorted_nodes[:n_nodes]
        top_nodes = [node_id for node_id, _ in selected]
        print(f"Top {n_nodes} node(s):")
        for node_id, row in selected:
            print(f"  Node {int(node_id)} at row {row:.0f}")

    return top_nodes
203
def extract_root_structures(binary_mask, verbose=False):
    """
    Analyze all root structures in a binary mask.

    The mask is skeletonized and split into connected skeleton components.
    Each component is summarized with skan. A component whose analysis
    raises is still recorded, with ``branch_data`` set to None and the
    exception text stored under ``'error'``.

    Args:
        binary_mask: Binary numpy array (bool or 0/255) or path-like object to
            an image file (read as grayscale; any non-zero pixel is foreground)
        verbose: If True, print per-root progress and a final summary

    Returns:
        dict: {
            'skeleton': full skeleton array,
            'labeled_skeleton': labeled skeleton array,
            'unique_labels': array of label IDs,
            'roots': {
                label_id (int): {
                    'mask': boolean skeleton pixels of this component (the
                            same array as 'skeleton', not a filled root mask),
                    'skeleton': skeleton for this root only,
                    'branch_data': DataFrame from skan.summarize, or None on failure,
                    'num_nodes': number of distinct nodes,
                    'num_branches': number of branches,
                    'total_pixels': total skeleton pixels,
                    'error': exception message (present only on failure)
                },
                ...
            }
        }

    Raises:
        FileNotFoundError: If a path is given and the image cannot be read.
    """
    # Load from file if path provided
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising for a missing/unreadable
        # file; without this check it surfaces later as an AttributeError.
        if binary_mask is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")

    # Ensure binary
    if binary_mask.dtype != bool:
        binary_mask = binary_mask.astype(bool)

    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    results = {
        'skeleton': skeleton,
        'labeled_skeleton': labeled_skeleton,
        'unique_labels': unique_labels,
        'roots': {}
    }
    if verbose:
        print(f"Analyzing {len(unique_labels)} root structure(s)...")

    for label_id in unique_labels:
        if verbose:
            print(f"\n--- Root {int(label_id)} ---")

        # Extract this root's skeleton
        single_root_mask = (labeled_skeleton == label_id)

        # Best-effort per root: one bad component (skan can fail on degenerate
        # skeletons) must not abort the analysis of the others.
        try:
            skeleton_obj = Skeleton(single_root_mask)
            branch_data = summarize(skeleton_obj)

            unique_nodes = set(branch_data['node-id-src']).union(
                set(branch_data['node-id-dst']))
            num_nodes = len(unique_nodes)
            num_branches = len(branch_data)
            total_pixels = np.sum(single_root_mask)

            results['roots'][int(label_id)] = {
                'mask': single_root_mask,
                'skeleton': single_root_mask,
                'branch_data': branch_data,
                'num_nodes': num_nodes,
                'num_branches': num_branches,
                'total_pixels': total_pixels
            }
            if verbose:
                print(f"  Nodes: {num_nodes}, Branches: {num_branches}, Pixels: {total_pixels}")

        except Exception as e:
            if verbose:
                print(f"  ERROR analyzing root {int(label_id)}: {e}")
            results['roots'][int(label_id)] = {
                'mask': single_root_mask,
                'skeleton': single_root_mask,
                'branch_data': None,
                'num_nodes': 0,
                'num_branches': 0,
                'total_pixels': np.sum(single_root_mask),
                'error': str(e)
            }

    if verbose:
        print(f"\n=== Summary ===")
        print(f"Total roots analyzed: {len(results['roots'])}")
        successful = sum(1 for r in results['roots'].values() if r['branch_data'] is not None)
        print(f"Successfully analyzed: {successful}")

    return results
302
def find_farthest_endpoint_path(branch_data, start_node, direction='down',
                                use_smart_scoring=False,
                                horizontal_penalty=0.5,
                                straightness_weight=1.0,
                                verbose=True):
    """
    Find the path to the endpoint that is farthest away in straight-line distance.

    Builds an undirected graph of the skeleton (nodes = skan node ids, edges =
    branches weighted by 'branch-distance'). It then scores every degree-1
    node other than ``start_node`` that lies in ``direction``, and returns
    the shortest skeleton path to the best-scoring one.

    Args:
        branch_data: DataFrame with branch information (reads node ids,
            coordinates and 'branch-distance')
        start_node: Node to start from
        direction: 'down' or 'up' (for filtering endpoints); any other value
            disables the direction filter
        use_smart_scoring: If True, score = |vertical| * straightness_weight *
            straightness - horizontal_penalty * |horizontal|; otherwise the
            score is the Euclidean distance to the endpoint
        horizontal_penalty: Weight for horizontal deviation penalty (higher = more penalty)
        straightness_weight: Weight for straightness ratio (higher = more reward for straight paths)
        verbose: Print progress (error messages are printed regardless)

    Returns:
        list of (node, branch_idx, branch_distance, vertical_change) tuples,
        one per branch along the path (vertical_change = next row - current
        row, positive going down), followed by (last_node, None, None, None).
        If the start node is missing or no endpoint qualifies, returns
        [(start_node, None, None, None)].
    """
    import networkx as nx

    G = nx.Graph()
    node_coords = {}
    for idx, row in branch_data.iterrows():
        src = row['node-id-src']
        dst = row['node-id-dst']

        node_coords[src] = (row['coord-src-0'], row['coord-src-1'])
        node_coords[dst] = (row['coord-dst-0'], row['coord-dst-1'])

        distance = row['branch-distance']
        # A skeleton cycle can yield several branches joining the same pair of
        # nodes, but nx.Graph stores one edge per pair (a later add_edge
        # overwrites the earlier attributes). Keep the shortest branch so path
        # lengths are true shortest skeleton lengths.
        if G.has_edge(src, dst) and G[src][dst]['distance'] <= distance:
            continue
        G.add_edge(src, dst, branch_idx=idx, distance=distance)

    if start_node not in node_coords:
        print(f"Error: Start node {int(start_node)} not found in graph")
        return [(start_node, None, None, None)]

    start_row, start_col = node_coords[start_node]

    if verbose:
        print(f"Start node {int(start_node)} at (row={start_row:.0f}, col={start_col:.0f})")
        print(f"Scoring mode: {'SMART (considering straightness & horizontal deviation)' if use_smart_scoring else 'SIMPLE (just Euclidean distance)'}")

    # Endpoints are leaf nodes (degree 1) of the skeleton graph.
    endpoints = [node for node in G.nodes() if G.degree(node) == 1 and node != start_node]

    if verbose:
        print(f"Found {len(endpoints)} candidate endpoints")

    endpoint_scores = []
    for endpoint in endpoints:
        end_row, end_col = node_coords[endpoint]

        euclidean_dist = np.sqrt((end_row - start_row)**2 + (end_col - start_col)**2)
        vertical_dist = end_row - start_row  # Positive = going down
        horizontal_dist = abs(end_col - start_col)

        if direction == 'down' and vertical_dist <= 0:
            continue
        if direction == 'up' and vertical_dist >= 0:
            continue

        try:
            node_path = nx.shortest_path(G, start_node, endpoint, weight='distance')
        except nx.NetworkXNoPath:
            continue

        skeleton_length = sum(
            G[u][v]['distance'] for u, v in zip(node_path, node_path[1:])
        )

        # 1.0 = perfectly straight, <1.0 = wiggly
        straightness = euclidean_dist / skeleton_length if skeleton_length > 0 else 0

        if use_smart_scoring:
            # Reward going far in the vertical direction along a straight
            # path; penalize horizontal drift.
            score = (abs(vertical_dist) * straightness_weight * straightness
                     - horizontal_penalty * horizontal_dist)
        else:
            score = euclidean_dist

        endpoint_scores.append({
            'endpoint': endpoint,
            'score': score,
            'euclidean_dist': euclidean_dist,
            'vertical_dist': abs(vertical_dist),
            'horizontal_dist': horizontal_dist,
            'straightness': straightness,
            'skeleton_length': skeleton_length,
            'coords': (end_row, end_col),
            'node_path': node_path,
        })

    if not endpoint_scores:
        print("No valid endpoints found in specified direction!")
        return [(start_node, None, None, None)]

    # Highest score first (stable sort keeps discovery order on ties).
    endpoint_scores.sort(key=lambda x: x['score'], reverse=True)

    if verbose:
        print(f"\nTop 5 candidates:")
        for i, ep_info in enumerate(endpoint_scores[:5]):
            print(f"  {i+1}. Node {int(ep_info['endpoint'])}: "
                  f"score={ep_info['score']:.1f}, "
                  f"euclidean={ep_info['euclidean_dist']:.1f}px, "
                  f"vertical={ep_info['vertical_dist']:.1f}px, "
                  f"horizontal={ep_info['horizontal_dist']:.1f}px, "
                  f"straightness={ep_info['straightness']:.2f}")

    best = endpoint_scores[0]
    target_endpoint = best['endpoint']

    if verbose:
        print(f"\nSelected: Node {int(target_endpoint)} (score: {best['score']:.1f})")

    # Reuse the path found while scoring instead of recomputing it.
    node_path = best['node_path']

    detailed_path = []
    total_skeleton_length = 0
    for current, next_node in zip(node_path, node_path[1:]):
        edge_data = G[current][next_node]
        distance = edge_data['distance']
        total_skeleton_length += distance

        vertical_dist = node_coords[next_node][0] - node_coords[current][0]
        detailed_path.append((current, edge_data['branch_idx'], distance, vertical_dist))

    detailed_path.append((node_path[-1], None, None, None))

    if verbose:
        node_sequence = [int(node) for node, _, _, _ in detailed_path]
        print(f"\nPath: {' → '.join(map(str, node_sequence[:15]))}")
        if len(node_sequence) > 15:
            print(f"  ... (total {len(node_sequence)} nodes)")
        print(f"Metrics:")
        print(f"  Euclidean: {best['euclidean_dist']:.1f}px")
        print(f"  Skeleton: {total_skeleton_length:.1f}px")
        print(f"  Straightness: {best['straightness']:.2f}")
        print(f"  Horizontal deviation: {best['horizontal_dist']:.1f}px")

    return detailed_path
483
484
def calculate_skeleton_length_px(detailed_path):
    """
    Return the total skeleton path length in pixels.

    Args:
        detailed_path: Sequence of (node, branch_idx, distance, vertical_dist)
            tuples as produced by find_farthest_endpoint_path(); entries whose
            distance is None (the terminal node) contribute nothing.

    Returns:
        float: Sum of all non-None distances (0.0 for an empty path).
    """
    return sum(
        (dist for _node, _branch, dist, _vertical in detailed_path if dist is not None),
        0.0,
    )
504
505
506
507
508
509
def process_matched_roots_to_lengths(structures, top_node_results, labeled_shoots, num_shoots):
    """
    Process matched roots and return array of lengths ordered by shoot position (left to right).

    For each matched root, the path from its top node to the best downward
    endpoint (find_farthest_endpoint_path with smart scoring) is measured
    with calculate_skeleton_length_px.

    Args:
        structures: Dict from extract_root_structures (not read by this
            function; accepted for interface compatibility)
        top_node_results: Dict from find_top_nodes_from_shoot, mapping root
            label -> dict with 'branch_data', 'shoot_label' and 'top_nodes';
            ``top_nodes[0][0]`` is used as the start node
        labeled_shoots: 2-D labeled shoot array
        num_shoots: Number of shoots (typically 5); labels 1..num_shoots are read

    Returns:
        np.array: Array of root lengths in pixels, ordered by shoot x-position (left to right).
                 Zero-indexed where [0] is leftmost shoot. Returns 0.0 for shoots without
                 matched roots or whose root could not be processed.
    """
    # Centroid x-position of each shoot; shoots with no pixels sort last.
    shoot_positions = {}
    for shoot_label in range(1, num_shoots + 1):
        _, x_coords = np.where(labeled_shoots == shoot_label)
        shoot_positions[shoot_label] = np.mean(x_coords) if len(x_coords) > 0 else float('inf')

    # Named shoot_id (not `label`) to avoid shadowing skimage's label().
    shoot_order = [
        shoot_id
        for shoot_id, _ in sorted(shoot_positions.items(), key=lambda item: item[1])
    ]

    # NOTE(review): if several roots map to the same shoot, the last one
    # processed overwrites the others - confirm this is intended.
    shoot_to_length = {}
    for root_label, result in top_node_results.items():
        shoot_label = result['shoot_label']
        try:
            # Reading the top node inside the try means a root with no top
            # node (empty list) or missing branch data is reported and scored
            # 0.0 instead of aborting the whole batch.
            branch_data = result['branch_data']
            top_node = result['top_nodes'][0][0]

            path = find_farthest_endpoint_path(
                branch_data,
                top_node,
                direction='down',
                use_smart_scoring=True,
                verbose=False
            )
            shoot_to_length[shoot_label] = calculate_skeleton_length_px(path)

        except Exception as e:
            print(f"Warning: Failed to process root {root_label} for shoot {shoot_label}: {e}")
            shoot_to_length[shoot_label] = 0.0

    return np.array([
        shoot_to_length.get(shoot_id, 0.0)
        for shoot_id in shoot_order
    ])
def find_most_vertical_branch( branch_data, current_node, direction='up', verticality_weight=0.7, height_weight=0.3):
 18def find_most_vertical_branch(branch_data, current_node, direction='up', 
 19                               verticality_weight=0.7, height_weight=0.3):
 20    """
 21    From current_node, find the branch that balances verticality AND height gained.
 22    
 23    Args:
 24        branch_data: DataFrame with branch information
 25        current_node: The node we're starting from
 26        direction: 'up' or 'down'
 27        verticality_weight: Weight for slope (0-1)
 28        height_weight: Weight for vertical distance gained (0-1)
 29    
 30    Returns:
 31        tuple: (next_node, branch_index, slope, vertical_change) or (None, None, None, None)
 32    """
 33    connected_branches = branch_data[
 34        (branch_data['node-id-src'] == current_node) | 
 35        (branch_data['node-id-dst'] == current_node)
 36    ]
 37    
 38    best_score = 0
 39    best_next_node = None
 40    best_branch_idx = None
 41    best_slope = None
 42    best_vertical = None
 43    
 44    candidates = []
 45    
 46    for idx, row in connected_branches.iterrows():
 47        if row['node-id-src'] == current_node:
 48            next_node = row['node-id-dst']
 49            vertical_change = row['coord-src-0'] - row['coord-dst-0']
 50            horizontal_change = abs(row['coord-src-1'] - row['coord-dst-1'])
 51        else:
 52            next_node = row['node-id-src']
 53            vertical_change = row['coord-dst-0'] - row['coord-src-0']
 54            horizontal_change = abs(row['coord-dst-1'] - row['coord-src-1'])
 55        
 56        # Check direction
 57        if direction == 'up':
 58            condition = vertical_change > 0
 59        elif direction == 'down':
 60            condition = vertical_change < 0
 61        else:
 62            raise ValueError("direction must be 'up' or 'down'")
 63        
 64        if condition:
 65            # Calculate slope
 66            if horizontal_change > 0:
 67                slope = abs(vertical_change) / horizontal_change
 68            else:
 69                slope = float('inf')
 70            
 71            # Normalize scores (slope can be 0-inf, vertical_change is in pixels)
 72            # Use ratio to max for normalization
 73            candidates.append({
 74                'idx': idx,
 75                'next_node': next_node,
 76                'slope': slope,
 77                'vertical_change': abs(vertical_change),
 78                'horizontal_change': horizontal_change
 79            })
 80    
 81    if not candidates:
 82        return None, None, None, None
 83    
 84    # Normalize scores
 85    max_slope = max(c['slope'] if c['slope'] != float('inf') else 0 for c in candidates)
 86    max_vertical = max(c['vertical_change'] for c in candidates)
 87    
 88    # Calculate combined scores
 89    for c in candidates:
 90        slope_norm = (c['slope'] if c['slope'] != float('inf') else max_slope * 2) / (max_slope if max_slope > 0 else 1)
 91        vertical_norm = c['vertical_change'] / (max_vertical if max_vertical > 0 else 1)
 92        
 93        combined_score = verticality_weight * slope_norm + height_weight * vertical_norm
 94        
 95        print(f"    B{c['idx']}{int(c['next_node'])}: slope={c['slope']:.2f}, vert={c['vertical_change']:.1f}, score={combined_score:.3f}")
 96        
 97        if combined_score > best_score:
 98            best_score = combined_score
 99            best_next_node = c['next_node']
100            best_branch_idx = c['idx']
101            best_slope = c['slope']
102            best_vertical = c['vertical_change']
103    
104    return best_next_node, best_branch_idx, best_slope, best_vertical

From current_node, find the branch that balances verticality AND height gained.

Arguments:
  • branch_data: DataFrame with branch information
  • current_node: The node we're starting from
  • direction: 'up' or 'down'
  • verticality_weight: Weight for slope (0-1)
  • height_weight: Weight for vertical distance gained (0-1)
Returns:

tuple: (next_node, branch_index, slope, vertical_change) or (None, None, None, None)

def trace_vertical_path( branch_data, start_node, direction='down', verticality_weight=0.5, height_weight=0.5, max_steps=100):
def trace_vertical_path(branch_data, start_node, direction='down',
                        verticality_weight=0.5, height_weight=0.5, max_steps=100):
    """
    Trace the most vertical path from start_node.

    Repeatedly follows the branch picked by find_most_vertical_branch until
    no branch leads further in ``direction``, a node is revisited, or
    ``max_steps`` steps have been taken. Progress is printed at each step.

    Returns:
        list of (node, branch_idx, slope, vertical_change) tuples, one per
        step taken, followed by a final (last_node, None, None, None) entry.
    """
    trail = []
    seen = set()
    node = start_node

    step_no = 0
    while step_no < max_steps:
        # Guard against cycles in the skeleton graph.
        if node in seen:
            print(f"  Loop detected at node {int(node)}, stopping")
            break
        seen.add(node)

        following, branch_id, slope, vertical = find_most_vertical_branch(
            branch_data, node, direction, verticality_weight, height_weight
        )
        if following is None:
            print(f"  Reached endpoint at node {int(node)}")
            break

        trail.append((node, branch_id, slope, vertical))
        print(f"  Step {step_no}: Node {int(node)}{int(following)} via B{branch_id}")

        node = following
        step_no += 1

    trail.append((node, None, None, None))
    return trail

Trace the most vertical path from start_node.

def label_skeleton(binary_mask):
def label_skeleton(binary_mask):
    """
    Take a binary mask, skeletonize it, and return labeled connected components.

    Args:
        binary_mask: Binary numpy array (bool or 0/255); any non-zero value
            is treated as foreground.

    Returns:
        tuple: (skeleton, labeled_skeleton, unique_labels)
            - skeleton: Binary skeleton array
            - labeled_skeleton: Labeled skeleton with unique IDs for each component
            - unique_labels: Array of unique label IDs (excluding background 0)
    """
    mask = binary_mask if binary_mask.dtype == bool else binary_mask.astype(bool)

    skeleton = skeletonize(mask)
    labeled_skeleton = label(skeleton)

    # Background pixels carry label 0, so only labelled pixels yield component ids.
    component_ids = np.unique(labeled_skeleton[labeled_skeleton != 0])

    return skeleton, labeled_skeleton, component_ids

Take a binary mask, skeletonize it, and return labeled connected components.

Arguments:
  • binary_mask: Binary numpy array (bool or 0/255)
Returns:

tuple: (skeleton, labeled_skeleton, unique_labels)
  • skeleton: Binary skeleton array
  • labeled_skeleton: Labeled skeleton with unique IDs for each component
  • unique_labels: Array of unique label IDs (excluding background 0)

def find_top_nodes(branch_data, n_nodes=1, threshold=None):
def find_top_nodes(branch_data, n_nodes=1, threshold=None):
    """
    Find the node(s) closest to the top of the image (lowest row values).

    Args:
        branch_data: DataFrame with branch information (reads 'node-id-src',
            'node-id-dst', 'coord-src-0', 'coord-dst-0')
        n_nodes: Number of top nodes to return (ignored when ``threshold``
            is given)
        threshold: Optional - return all nodes within this many pixels of the top

    Returns:
        list of node IDs at the top, ordered by increasing row. Empty when
        ``branch_data`` is None or has no rows.
    """
    # Without this guard the threshold path indexes sorted_nodes[0] and raises
    # IndexError on an empty table (or AttributeError for a failed root's None).
    if branch_data is None or len(branch_data) == 0:
        return []

    # Map node id -> row coordinate. A node shared by several branches has the
    # same coordinates each time, so overwriting is harmless; dict insertion
    # order keeps tie-breaking in the stable sort below deterministic.
    node_rows = {}
    for src, src_row, dst, dst_row in zip(
        branch_data['node-id-src'], branch_data['coord-src-0'],
        branch_data['node-id-dst'], branch_data['coord-dst-0'],
    ):
        node_rows[src] = src_row
        node_rows[dst] = dst_row

    # Sort by row coordinate (lowest = top)
    sorted_nodes = sorted(node_rows.items(), key=lambda item: item[1])

    if threshold is not None:
        min_row = sorted_nodes[0][1]
        top_nodes = [node_id for node_id, row in sorted_nodes if row <= min_row + threshold]
        print(f"Top node at row {min_row:.0f}")
        print(f"Found {len(top_nodes)} nodes within {threshold} pixels of top")
    else:
        selected = sorted_nodes[:n_nodes]
        top_nodes = [node_id for node_id, _ in selected]
        print(f"Top {n_nodes} node(s):")
        for node_id, row in selected:
            print(f"  Node {int(node_id)} at row {row:.0f}")

    return top_nodes

Find the node(s) closest to the top of the image (lowest row values).

Arguments:
  • branch_data: DataFrame with branch information
  • n_nodes: Number of top nodes to return
  • threshold: Optional - return all nodes within this many pixels of the top
Returns:

list of node IDs at the top

def extract_root_structures(binary_mask, verbose=False):
    """
    Analyze all root structures in a binary mask.

    The mask is skeletonized and labeled by label_skeleton(); each labeled
    skeleton component is then summarized with skan on its own.

    Args:
        binary_mask: Binary numpy array (bool or 0/255) or path-like object
            to an image file, which is read as grayscale with OpenCV.
        verbose: If True, print per-root progress and a final summary.

    Returns:
        dict: {
            'skeleton': full skeleton array,
            'labeled_skeleton': labeled skeleton array,
            'unique_labels': array of label IDs,
            'roots': {
                int(label_id): {
                    'mask': boolean array selecting this root's skeleton
                            pixels (the same array as 'skeleton'),
                    'skeleton': boolean array selecting this root's skeleton pixels,
                    'branch_data': skan summary DataFrame, or None on failure,
                    'num_nodes': number of distinct nodes in branch_data,
                    'num_branches': number of rows in branch_data,
                    'total_pixels': number of skeleton pixels of this root,
                    'error': error message (present only on failure)
                },
                ...
            }
        }

    Raises:
        FileNotFoundError: If binary_mask is a path that cv2.imread cannot read.
    """
    # Load from file if path provided
    if not isinstance(binary_mask, np.ndarray):
        image_path = str(binary_mask)
        binary_mask = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising for a missing or
        # unreadable file; fail here with a clear message rather than with an
        # AttributeError on `.dtype` below.
        if binary_mask is None:
            raise FileNotFoundError(f"Could not read mask image: {image_path}")

    # Ensure binary: any nonzero pixel counts as foreground
    if binary_mask.dtype != bool:
        binary_mask = binary_mask.astype(bool)

    # Get labeled skeleton
    skeleton, labeled_skeleton, unique_labels = label_skeleton(binary_mask)

    results = {
        'skeleton': skeleton,
        'labeled_skeleton': labeled_skeleton,
        'unique_labels': unique_labels,
        'roots': {},
    }
    if verbose:
        print(f"Analyzing {len(unique_labels)} root structure(s)...")

    # Analyze each root separately
    for label_id in unique_labels:
        root_id = int(label_id)
        if verbose:
            print(f"\n--- Root {root_id} ---")

        # Boolean selection of this component's skeleton pixels
        single_root_mask = (labeled_skeleton == label_id)
        total_pixels = np.sum(single_root_mask)

        try:
            branch_data = summarize(Skeleton(single_root_mask))
            unique_nodes = set(branch_data['node-id-src']).union(branch_data['node-id-dst'])
            num_nodes = len(unique_nodes)
            num_branches = len(branch_data)
        except Exception as e:
            # Best effort: record the failure for this root and keep going
            if verbose:
                print(f"  ERROR analyzing root {root_id}: {e}")
            results['roots'][root_id] = {
                'mask': single_root_mask,
                'skeleton': single_root_mask,
                'branch_data': None,
                'num_nodes': 0,
                'num_branches': 0,
                'total_pixels': total_pixels,
                'error': str(e),
            }
            continue

        results['roots'][root_id] = {
            'mask': single_root_mask,
            'skeleton': single_root_mask,
            'branch_data': branch_data,
            'num_nodes': num_nodes,
            'num_branches': num_branches,
            'total_pixels': total_pixels,
        }
        if verbose:
            print(f"  Nodes: {num_nodes}, Branches: {num_branches}, Pixels: {total_pixels}")

    if verbose:
        print("\n=== Summary ===")
        print(f"Total roots analyzed: {len(results['roots'])}")
        successful = sum(1 for r in results['roots'].values() if r['branch_data'] is not None)
        print(f"Successfully analyzed: {successful}")

    return results

Analyze all root structures in a binary mask.

Arguments:
  • binary_mask: Binary numpy array (bool or 0/255) or path-like object to image file
Returns:

dict with the keys:
  • 'skeleton': full skeleton array
  • 'labeled_skeleton': labeled skeleton array
  • 'unique_labels': array of label IDs
  • 'roots': dict mapping each label_id to a dict with 'mask' (boolean selection of this root's skeleton pixels), 'skeleton' (skeleton for this root only), 'branch_data' (DataFrame with branch information), 'num_nodes' (number of nodes), 'num_branches' (number of branches) and 'total_pixels' (total skeleton pixels). If analysis of a root fails, its 'branch_data' is None, its counts are 0, and an 'error' key holds the error message.

def find_farthest_endpoint_path(branch_data, start_node, direction='down',
                                use_smart_scoring=False,
                                horizontal_penalty=0.5,
                                straightness_weight=1.0,
                                verbose=True):
    """
    Find the skeleton path from start_node to the best-scoring endpoint.

    Candidate endpoints are the degree-1 nodes of the branch graph (other
    than start_node). Only those below start_node (direction='down') or
    above it (direction='up') are kept; any other direction value keeps all
    of them. Unreachable endpoints are skipped. Each candidate is scored, and
    the shortest skeleton path to the highest-scoring one is returned.

    Args:
        branch_data: DataFrame with branch information. Reads the
            'node-id-src', 'node-id-dst', 'coord-src-0', 'coord-src-1',
            'coord-dst-0', 'coord-dst-1' and 'branch-distance' columns.
        start_node: Node to start from
        direction: 'down' or 'up' (for filtering endpoints by row coordinate)
        use_smart_scoring: If False, the score is the straight-line (Euclidean)
            distance from start_node. If True, the score is
            |vertical_dist| * straightness_weight * straightness
            - horizontal_penalty * horizontal_dist,
            where straightness = euclidean_dist / skeleton_path_length.
        horizontal_penalty: Weight for horizontal deviation penalty (higher = more penalty)
        straightness_weight: Weight for straightness ratio (higher = more reward for straight paths)
        verbose: Print progress (error messages are printed regardless)

    Returns:
        list of (node, branch_idx, distance, vertical_dist) tuples, one per
        step from start_node. branch_idx is the branch_data index of the
        branch taken, distance is its 'branch-distance', and vertical_dist is
        the row change to the next node. The final tuple is
        (endpoint, None, None, None). If start_node is unknown or no valid
        endpoint exists, returns [(start_node, None, None, None)].
    """
    import networkx as nx

    # Undirected graph of the skeleton: nodes are branch end points, edges
    # are branches weighted by their skeleton length.
    G = nx.Graph()
    node_coords = {}
    for idx, row in branch_data.iterrows():
        src = row['node-id-src']
        dst = row['node-id-dst']
        node_coords[src] = (row['coord-src-0'], row['coord-src-1'])
        node_coords[dst] = (row['coord-dst-0'], row['coord-dst-1'])

        distance = row['branch-distance']
        # nx.Graph stores a single edge per node pair, and add_edge would
        # overwrite an existing one. When several branches join the same two
        # nodes, keep the shortest so path lengths are true shortest paths.
        if G.has_edge(src, dst) and G[src][dst]['distance'] <= distance:
            continue
        G.add_edge(src, dst, branch_idx=idx, distance=distance)

    if start_node not in node_coords:
        print(f"Error: Start node {int(start_node)} not found in graph")
        return [(start_node, None, None, None)]

    start_row, start_col = node_coords[start_node]

    if verbose:
        print(f"Start node {int(start_node)} at (row={start_row:.0f}, col={start_col:.0f})")
        print(f"Scoring mode: {'SMART (considering straightness & horizontal deviation)' if use_smart_scoring else 'SIMPLE (just Euclidean distance)'}")

    # One Dijkstra pass yields the shortest skeleton length and path to every
    # reachable node, instead of a separate search per endpoint.
    path_lengths, node_paths = nx.single_source_dijkstra(G, start_node, weight='distance')

    # Endpoints = degree-1 nodes, excluding the start itself
    endpoints = [node for node in G.nodes() if G.degree(node) == 1 and node != start_node]

    if verbose:
        print(f"Found {len(endpoints)} candidate endpoints")

    endpoint_scores = []
    for endpoint in endpoints:
        end_row, end_col = node_coords[endpoint]

        euclidean_dist = np.sqrt((end_row - start_row)**2 + (end_col - start_col)**2)
        vertical_dist = end_row - start_row  # Positive = going down
        horizontal_dist = abs(end_col - start_col)

        # Filter by direction
        if direction == 'down' and vertical_dist <= 0:
            continue
        if direction == 'up' and vertical_dist >= 0:
            continue
        # Endpoint lies in a different connected component
        if endpoint not in path_lengths:
            continue

        skeleton_length = path_lengths[endpoint]
        # 1.0 = perfectly straight, <1.0 = wiggly
        straightness = euclidean_dist / skeleton_length if skeleton_length > 0 else 0

        if use_smart_scoring:
            # Reward going far vertically along a straight path; penalize
            # horizontal drift.
            score = (abs(vertical_dist) * straightness_weight * straightness
                     - horizontal_penalty * horizontal_dist)
        else:
            score = euclidean_dist

        endpoint_scores.append({
            'endpoint': endpoint,
            'score': score,
            'euclidean_dist': euclidean_dist,
            'vertical_dist': abs(vertical_dist),
            'horizontal_dist': horizontal_dist,
            'straightness': straightness,
            'skeleton_length': skeleton_length,
            'coords': (end_row, end_col),
        })

    if not endpoint_scores:
        print("No valid endpoints found in specified direction!")
        return [(start_node, None, None, None)]

    # Sort by score (highest first); the sort is stable for ties
    endpoint_scores.sort(key=lambda x: x['score'], reverse=True)

    if verbose:
        print(f"\nTop 5 candidates:")
        for i, ep_info in enumerate(endpoint_scores[:5]):
            print(f"  {i+1}. Node {int(ep_info['endpoint'])}: "
                  f"score={ep_info['score']:.1f}, "
                  f"euclidean={ep_info['euclidean_dist']:.1f}px, "
                  f"vertical={ep_info['vertical_dist']:.1f}px, "
                  f"horizontal={ep_info['horizontal_dist']:.1f}px, "
                  f"straightness={ep_info['straightness']:.2f}")

    best = endpoint_scores[0]
    target_endpoint = best['endpoint']

    if verbose:
        print(f"\nSelected: Node {int(target_endpoint)} (score: {best['score']:.1f})")

    # Reachability was checked above, so the path is always available
    node_path = node_paths[target_endpoint]

    # Convert the node sequence into per-step records
    detailed_path = []
    total_skeleton_length = 0
    for current, next_node in zip(node_path, node_path[1:]):
        edge_data = G[current][next_node]
        distance = edge_data['distance']
        total_skeleton_length += distance
        vertical_dist = node_coords[next_node][0] - node_coords[current][0]
        detailed_path.append((current, edge_data['branch_idx'], distance, vertical_dist))

    detailed_path.append((node_path[-1], None, None, None))

    if verbose:
        node_sequence = [int(node) for node, _, _, _ in detailed_path]
        print(f"\nPath: {' → '.join(map(str, node_sequence[:15]))}")
        if len(node_sequence) > 15:
            print(f"  ... (total {len(node_sequence)} nodes)")
        print(f"Metrics:")
        print(f"  Euclidean: {best['euclidean_dist']:.1f}px")
        print(f"  Skeleton: {total_skeleton_length:.1f}px")
        print(f"  Straightness: {best['straightness']:.2f}")
        print(f"  Horizontal deviation: {best['horizontal_dist']:.1f}px")

    return detailed_path

Find the skeleton path from start_node to the best-scoring endpoint. With the default simple scoring, this is the endpoint farthest away in straight-line (Euclidean) distance.

Arguments:
  • branch_data: DataFrame with branch information
  • start_node: Node to start from
  • direction: 'down' or 'up' (for filtering endpoints)
  • use_smart_scoring: If True, use scoring that penalizes horizontal deviation and rewards straightness
  • horizontal_penalty: Weight for horizontal deviation penalty (higher = more penalty)
  • straightness_weight: Weight for straightness ratio (higher = more reward for straight paths)
  • verbose: Print progress
Returns:

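Best path as a list of (node, branch_idx, distance, vertical_dist) tuples; the last tuple is (endpoint, None, None, None). If the start node is unknown or no valid endpoint exists, returns [(start_node, None, None, None)].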

def calculate_skeleton_length_px(detailed_path):
    """
    Sum the branch distances along a detailed path.

    Args:
        detailed_path: Sequence of 4-tuples (node, branch_idx, distance,
            vertical_dist), as produced by find_farthest_endpoint_path().
            Entries whose distance is None (such as the terminal node)
            contribute nothing.

    Returns:
        float: Total path length in pixels (0.0 for an empty path).
    """
    # The 0.0 start value makes the sum a float even for integer distances.
    step_lengths = (step for _, _, step, _ in detailed_path if step is not None)
    return sum(step_lengths, 0.0)

Calculate the total skeleton path length in pixels.

Arguments:
  • detailed_path: List of tuples (node, branch_idx, distance, vertical_dist) from find_farthest_endpoint_path()
Returns:

float: Total path length in pixels

def process_matched_roots_to_lengths(structures, top_node_results, labeled_shoots, num_shoots):
    """
    Process matched roots and return array of lengths ordered by shoot position (left to right).

    For each matched root, the longest downward path from its top node is
    found with find_farthest_endpoint_path (smart scoring). Its skeleton
    length is then recorded for the root's shoot.

    Args:
        structures: Dict from extract_root_structures. Currently unused; it is
            kept for interface compatibility.
        top_node_results: Dict from find_top_nodes_from_shoot, keyed by root
            label. Each value must provide 'branch_data', 'shoot_label' and
            'top_nodes'; the start node is read as top_nodes[0][0].
        labeled_shoots: Labeled shoot array; labels 1..num_shoots are treated
            as shoots.
        num_shoots: Number of shoots (typically 5)

    Returns:
        np.array: Array of root lengths in pixels, ordered by shoot x-position (left to right).
                 Zero-indexed where [0] is leftmost shoot. Returns 0.0 for shoots without
                 matched roots or whose root could not be processed. Shoots with no pixels
                 in labeled_shoots are placed at the end.
    """
    # Horizontal centroid of each shoot; absent shoots sort to the far right
    shoot_positions = {}
    for shoot_label in range(1, num_shoots + 1):
        _, x_coords = np.where(labeled_shoots == shoot_label)
        shoot_positions[shoot_label] = np.mean(x_coords) if len(x_coords) > 0 else float('inf')

    # Shoot labels sorted left to right (stable for equal positions)
    shoot_order = sorted(shoot_positions, key=shoot_positions.get)

    shoot_to_length = {}
    for root_label, result in top_node_results.items():
        shoot_label = result['shoot_label']
        try:
            # Inside the try so that one root with no top nodes (or missing
            # data) is reported and skipped instead of aborting the batch.
            top_node = result['top_nodes'][0][0]
            path = find_farthest_endpoint_path(
                result['branch_data'],
                top_node,
                direction='down',
                use_smart_scoring=True,
                verbose=False
            )
            shoot_to_length[shoot_label] = calculate_skeleton_length_px(path)
        except Exception as e:
            print(f"Warning: Failed to process root {root_label} for shoot {shoot_label}: {e}")
            # Don't let a failed root wipe out a length already measured for
            # the same shoot.
            shoot_to_length.setdefault(shoot_label, 0.0)

    return np.array([shoot_to_length.get(label_id, 0.0) for label_id in shoot_order])

Process matched roots and return array of lengths ordered by shoot position (left to right).

Arguments:
  • structures: Dict from extract_root_structures
  • top_node_results: Dict from find_top_nodes_from_shoot
  • labeled_shoots: Labeled shoot array
  • num_shoots: Number of shoots (typically 5)
Returns:

np.array: Array of root lengths in pixels, ordered by shoot x-position (left to right). Zero-indexed where [0] is leftmost shoot. Returns 0.0 for shoots without matched roots.