# library/pid_tester.py — PID controller testing, experiment tracking, and experiment logging.
import numpy as np
from datetime import datetime
from pathlib import Path
import json
import uuid
import matplotlib.pyplot as plt
from library.robot_control import get_position, find_workspace, get_velocities, get_torques

from library.pid_controller import PID


def calculate_metrics(position_data, time_data, target):
    """Calculate all PID performance metrics.

    Args:
        position_data (list): Position values over time.
        time_data (list): Time step values.
        target (float): Target position.

    Returns:
        dict: Dictionary containing rise_time, settling_time, overshoot, steady_state_error.
    """
    if not position_data:
        return {
            'rise_time': None,
            'settling_time': None,
            'overshoot': 0.0,
            'steady_state_error': None
        }

    position_array = np.array(position_data)
    start_val = position_data[0]

    # Degenerate step: start already equals target, so the percentage-based
    # metrics below are undefined. Avoid division by zero.
    if abs(target - start_val) < 1e-12:
        return {
            'rise_time': 0.0,
            'settling_time': 0.0,
            'overshoot': 0.0,
            'steady_state_error': abs(position_array[-1] - target)
        }

    # Calculate rise time (10% to 90% of the commanded step)
    ten_percent = start_val + 0.1 * (target - start_val)
    ninety_percent = start_val + 0.9 * (target - start_val)
    moving_positive = target > start_val

    ten_idx = None
    ninety_idx = None

    for i, pos in enumerate(position_data):
        if ten_idx is None:
            if (moving_positive and pos >= ten_percent) or (not moving_positive and pos <= ten_percent):
                ten_idx = i
        if ninety_idx is None:
            if (moving_positive and pos >= ninety_percent) or (not moving_positive and pos <= ninety_percent):
                ninety_idx = i
                break

    rise_time = time_data[ninety_idx] - time_data[ten_idx] if (ten_idx is not None and ninety_idx is not None) else None

    # Calculate settling time (2% band): scan backwards for the last sample
    # outside the band; the response is settled from the *next* sample on.
    # BUG FIX: previously `time_data[i]` was used, which reported the time of
    # the last excursion itself and made the `i < len - 1` guard dead code.
    settling_band = abs(target - start_val) * 0.02
    settling_time = None

    for i in range(len(position_data) - 1, -1, -1):
        if abs(position_data[i] - target) > settling_band:
            settling_time = time_data[i + 1] if i < len(time_data) - 1 else time_data[-1]
            break

    if settling_time is None:
        # Response never left the band.
        settling_time = 0

    # Calculate overshoot as a percentage of the commanded step size.
    if target > start_val:
        max_val = np.max(position_array)
        overshoot = (max_val - target) / abs(target - start_val) * 100 if max_val > target else 0
    else:
        min_val = np.min(position_array)
        overshoot = (target - min_val) / abs(target - start_val) * 100 if min_val < target else 0

    # Calculate steady state error
    steady_state_error = abs(position_array[-1] - target)

    return {
        'rise_time': rise_time,
        'settling_time': settling_time,
        'overshoot': overshoot,
        'steady_state_error': steady_state_error
    }


class PIDTester:
    """PID controller testing class.

    Args:
        sim (Simulation): The simulation object.
        axis (str): Which axis to test ('x', 'y', or 'z'). Defaults to 'x'.
        limits (dict, optional): Workspace limits (auto-calculated if not provided).
        invert_output (bool): Whether to invert the PID output for this axis. Defaults to False.
    """

    def __init__(self, sim, axis='x', limits=None, invert_output=False):
        self.sim = sim
        self.axis = axis.lower()
        self.axis_map = {'x': 0, 'y': 1, 'z': 2}
        self.limits = limits  # goes through the property setter below
        self.invert_output = invert_output

    @property
    def limits(self):
        # Lazily discover the workspace on first access when none was supplied.
        if self._limits is None:
            w = find_workspace(self.sim)
            self._limits = {
                'x': [w['x_min'], w['x_max']],
                'y': [w['y_min'], w['y_max']],
                'z': [w['z_min'], w['z_max']],
                'center': w['center']
            }
        return self._limits

    @limits.setter
    def limits(self, workspace=None):
        if workspace is not None:
            self._limits = {
                'x': [workspace['x_min'], workspace['x_max']],
                'y': [workspace['y_min'], workspace['y_max']],
                'z': [workspace['z_min'], workspace['z_max']],
                'center': workspace['center']
            }
        else:
            # Defer to the getter's lazy auto-calculation.
            self._limits = None

    def _run_single_trial(self, kp, ki, kd, dt, target_pos,
                          max_steps, tolerance, settle_steps,
                          reset=True, verbose=False):
        """Run a single trial of PID test.

        Args:
            kp (float): Proportional gain.
            ki (float): Integral gain.
            kd (float): Derivative gain.
            dt (float): Time step.
            target_pos (list): Target position [x, y, z].
            max_steps (int): Maximum number of simulation steps.
            tolerance (float): Position tolerance for settling.
            settle_steps (int): Number of steps within tolerance to consider settled.
            reset (bool): Whether to reset simulation before trial. Defaults to True.
            verbose (bool): Print debug information. Defaults to False.

        Returns:
            dict: Single trial results containing metrics, position data, and metadata.
        """
        if reset:
            self.sim.reset()

        axis_idx = self.axis_map[self.axis.lower()]
        target = target_pos[axis_idx]
        start_pos = get_position(self.sim)

        # Check target is within limits
        axis_limits = self.limits[self.axis]
        if target < axis_limits[0] or target > axis_limits[1]:
            raise ValueError(f'Target {target:.4f} outside of {self.axis}-axis range [{axis_limits[0]:.4f}, {axis_limits[1]:.4f}]')

        # Move to starting position: drive hard in the negative direction for
        # 100 steps — presumably saturates against the workspace limit so every
        # trial starts from the same place (TODO confirm against sim behavior).
        actions = [[0, 0, 0, 0]]
        actions[0][axis_idx] = -1000
        self.sim.run(actions, 100)

        pid = PID(kp=kp, ki=ki, kd=kd, setpoint=target, invert_output=self.invert_output)

        # Data collection
        position_data = []
        time_data = []
        error_data = []
        control_data = []
        velocity_data = []
        torque_data = []

        # Settling detection
        steps_in_tolerance = 0

        for step in range(max_steps):
            current_pos = get_position(self.sim)
            current_val = current_pos[axis_idx]
            current_vel = get_velocities(self.sim)
            current_torque = get_torques(self.sim)

            if verbose:
                print(f'current_val {self.axis} axis: {current_val:.5f} -> {target:.5f}')

            position_data.append(current_val)
            time_data.append(step)
            error_data.append(target - current_val)
            velocity_data.append(current_vel[axis_idx])
            torque_data.append(current_torque[axis_idx])

            # Check position remains near target
            if abs(current_val - target) <= tolerance:
                steps_in_tolerance += 1
                if verbose:
                    print(f'settling {steps_in_tolerance} of {settle_steps}')
            else:
                steps_in_tolerance = 0

            # Stop execution if position remains near target
            if steps_in_tolerance >= settle_steps:
                break

            # Calculate PID output and apply it on the tested axis only
            vx = pid(current_val, dt)
            control_data.append(vx)

            actions = [[0, 0, 0, 0]]
            actions[0][axis_idx] = vx
            self.sim.run(actions, num_steps=1)

        # Calculate metrics
        metrics = calculate_metrics(position_data, time_data, target)

        # Store result
        result = {
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics,
            'position_data': position_data,
            'time_data': time_data,
            'error_data': error_data,
            'control_data': control_data,
            'velocity_data': velocity_data,
            'torque_data': torque_data,
            'target': target,
            'start_position': start_pos[axis_idx],
            'settled': steps_in_tolerance >= settle_steps,
            'invert_output': self.invert_output,
            # Number of recorded steps (one entry per simulation step)
            'total_time': len(time_data),
        }

        return result

    def test_gains(self, kp, ki, kd, dt, target_pos,
                   max_steps, tolerance, settle_steps, num_trials,
                   reset=True, verbose=False):
        """Test a set of PID gains over multiple trials.

        Args:
            kp (float): Proportional gain.
            ki (float): Integral gain.
            kd (float): Derivative gain.
            dt (float): Time step.
            target_pos (list): Target position [x, y, z].
            max_steps (int): Maximum number of simulation steps.
            tolerance (float): Position tolerance for settling.
            settle_steps (int): Number of steps within tolerance to consider settled.
            num_trials (int): Number of trials to run.
            reset (bool): Whether to reset simulation before each trial. Defaults to True.
            verbose (bool): Print debug information. Defaults to False.

        Returns:
            dict: Test results with gains, timestamp, and trial data.
        """
        trials = []

        for trial in range(num_trials):
            trial_result = self._run_single_trial(
                kp, ki, kd, dt, target_pos,
                max_steps, tolerance, settle_steps,
                reset=reset, verbose=verbose
            )
            trials.append(trial_result)

        result = {
            'gains': {'kp': kp, 'ki': ki, 'kd': kd},
            'timestamp': datetime.now().isoformat(),
            'num_trials': num_trials,
            'trials': trials
        }

        return result


class Experiment:
    """Create a new experiment to track PID testing.

    Args:
        title (str): Title of the experiment.
        hypothesis (str): Hypothesis being tested.
        axis (str): Axis being tested ('x', 'y', or 'z').
        invert_output (bool): Whether output was inverted for this axis. Defaults to False.
    """

    def __init__(self, title, hypothesis, axis, invert_output=False):
        self.id = str(uuid.uuid4())
        self.title = title
        self.hypothesis = hypothesis
        self.axis = axis.lower()
        self.timestamp = datetime.now().isoformat()
        self.tests = []
        self.invert_output = invert_output

    def add_test(self, test_result):
        """Add a test result from PIDTester.test_gains().

        Args:
            test_result (dict): Result dictionary from test_gains().

        Raises:
            ValueError: If test_result doesn't contain required keys.
        """
        # Basic validation
        if 'gains' not in test_result or 'trials' not in test_result:
            raise ValueError("Invalid test_result: must contain 'gains' and 'trials'")

        self.tests.append(test_result)

    def get_averaged_metrics(self, test_index):
        """Get averaged metrics across trials for a specific test.

        Args:
            test_index (int): Index of the test in self.tests.

        Returns:
            dict: Averaged metrics with mean, std, min, max for each metric.
                For a single-trial test the raw metrics dict is returned unchanged.

        Raises:
            IndexError: If test_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        test = self.tests[test_index]
        trials = test['trials']

        # Single trial: return scalars as-is. Callers distinguish the two
        # shapes by checking for dict-valued entries.
        if len(trials) == 1:
            return trials[0]['metrics']

        metric_keys = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error']
        averaged = {}

        for key in metric_keys:
            # Skip None entries (e.g. rise_time when the step never reached 90%)
            values = [t['metrics'][key] for t in trials if t['metrics'][key] is not None]
            if values:
                averaged[key] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }
            else:
                averaged[key] = None

        return averaged

    @staticmethod
    def _format_gain(value):
        """Format a gain compactly: no decimals for values >= 1, two decimals otherwise."""
        return f"{value:.0f}" if value >= 1 else f"{value:.2f}"

    def print_summary(self):
        """Print summary table of all tests in this experiment."""
        print(f"\nExperiment: {self.title}")
        print(f"Hypothesis: {self.hypothesis}")
        print(f"Axis: {self.axis}")
        print(f"Inverted Output: {self.invert_output}")
        print(f"Tests: {len(self.tests)}")
        print("\n" + "="*90)
        print(f"{'Kp':>8} {'Ki':>8} {'Kd':>8} {'Rise':>10} {'Settling':>10} {'Overshoot':>10} {'SS Error':>10} {'Total Steps':>8}")
        print("="*90)

        for i, test in enumerate(self.tests):
            gains = test['gains']
            metrics = self.get_averaged_metrics(i)

            # Handle both single trial (dict of scalars) and averaged (dict with 'mean')
            if isinstance(metrics['rise_time'], dict):
                rise = metrics['rise_time']['mean']
                settling = metrics['settling_time']['mean']
                overshoot = metrics['overshoot']['mean']
                ss_error = metrics['steady_state_error']['mean']
            else:
                rise = metrics['rise_time']
                settling = metrics['settling_time']
                overshoot = metrics['overshoot']
                ss_error = metrics['steady_state_error']

            # Get total time from first trial
            total_time = test['trials'][0]['total_time']

            # Format None values as 'N/A' instead of trying to format as float
            rise_str = f"{rise:>10.2f}" if rise is not None else f"{'N/A':>10}"
            settling_str = f"{settling:>10.2f}" if settling is not None else f"{'N/A':>10}"
            overshoot_str = f"{overshoot:>10.2f}" if overshoot is not None else f"{'N/A':>10}"
            ss_error_str = f"{ss_error:>10.6f}" if ss_error is not None else f"{'N/A':>10}"

            print(f"{gains['kp']:>8.2f} {gains['ki']:>8.2f} {gains['kd']:>8.2f} "
                  f"{rise_str} {settling_str} {overshoot_str} {ss_error_str} {total_time:>8}")

        print("="*90)

    def _determine_varying_gain(self):
        """Determine which gain parameter varies across tests.

        Returns:
            tuple: (gain_key, x_values, x_labels) where:
                - gain_key (str or None): 'Kp', 'Ki', 'Kd', or None if multiple vary.
                - x_values (list): Numeric values for x-axis.
                - x_labels (list): String labels for x-axis.
        """
        gains_list = [test['gains'] for test in self.tests]

        kp_values = [g['kp'] for g in gains_list]
        ki_values = [g['ki'] for g in gains_list]
        kd_values = [g['kd'] for g in gains_list]

        kp_varies = len(set(kp_values)) > 1
        ki_varies = len(set(ki_values)) > 1
        kd_varies = len(set(kd_values)) > 1

        varies_count = sum([kp_varies, ki_varies, kd_varies])

        if varies_count == 0:
            # All tests have same gains
            return None, list(range(len(self.tests))), [str(i) for i in range(len(self.tests))]
        elif varies_count == 1:
            # Exactly one gain varies; format labels to avoid overlapping decimals
            if kp_varies:
                return 'Kp', kp_values, [self._format_gain(v) for v in kp_values]
            elif ki_varies:
                return 'Ki', ki_values, [self._format_gain(v) for v in ki_values]
            else:
                return 'Kd', kd_values, [self._format_gain(v) for v in kd_values]
        else:
            # Multiple gains vary - use test index with composite labels.
            # BUG FIX: the conditional was previously embedded *inside* the
            # f-string format spec (f"{g['kp']:.0f if g['kp']>=1 else ...}"),
            # which is an invalid format specification and raised ValueError
            # at runtime; format each value outside the f-string instead
            # (the same pattern already used in _plot_speed_vs_stability).
            x_values = list(range(len(self.tests)))
            x_labels = [
                f"P={self._format_gain(g['kp'])},I={self._format_gain(g['ki'])},D={self._format_gain(g['kd'])}"
                for g in gains_list
            ]
            return None, x_values, x_labels

    def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'):
        """Plot averaged metrics across all tests.

        Args:
            metric_names (list of str or 'all', optional): Which metrics to plot for line plots.
                If 'all' or None, plots all metrics. Ignored for scatter plots.
                Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'.
            separate_subplots (bool): If True, each metric gets its own subplot.
                If False, all on one plot. Only applies to line plots. Defaults to True.
            plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for
                speed vs stability tradeoff. Defaults to 'lines'.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests.
        """
        if not self.tests:
            print("No tests to plot")
            return None, None

        # Scatter plot: speed vs stability
        if plot_type == 'scatter':
            return self._plot_speed_vs_stability()

        # Line plot: existing behavior
        if metric_names == 'all' or metric_names is None:
            metric_names = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error']

        # Determine which gain varies
        gain_key, x_values, x_labels = self._determine_varying_gain()

        # Collect data; track which test indices had valid values per metric
        data = {metric: [] for metric in metric_names}
        errors = {metric: [] for metric in metric_names}
        valid_indices = {metric: [] for metric in metric_names}

        for i, test in enumerate(self.tests):
            metrics = self.get_averaged_metrics(i)
            for metric in metric_names:
                if isinstance(metrics[metric], dict):
                    value = metrics[metric]['mean']
                    error = metrics[metric]['std']
                else:
                    value = metrics[metric]
                    error = 0

                if value is not None:
                    data[metric].append(value)
                    errors[metric].append(error)
                    valid_indices[metric].append(i)

        # Create plots
        if separate_subplots:
            fig, axes = plt.subplots(len(metric_names), 1, figsize=(10, 4*len(metric_names)))
            if len(metric_names) == 1:
                # plt.subplots returns a bare Axes for a single row; normalize to a list
                axes = [axes]

            for ax, metric in zip(axes, metric_names):
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
                    valid_x_labels = [x_labels[i] for i in valid_indices[metric]]

                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2)
                    ax.set_xlabel(gain_key if gain_key else 'Test Index')
                    ax.set_ylabel(metric.replace('_', ' ').title())
                    ax.set_xscale('log')
                    ax.set_xticks(valid_x_values)
                    ax.set_xticklabels(valid_x_labels)
                    ax.grid(True, alpha=0.3, which='both')
                else:
                    ax.text(0.5, 0.5, f'No valid data for {metric}',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.set_xlabel(gain_key if gain_key else 'Test Index')
                    ax.set_ylabel(metric.replace('_', ' ').title())
        else:
            fig, ax = plt.subplots(figsize=(10, 6))
            for metric in metric_names:
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]

                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2, label=metric.replace('_', ' ').title())

            ax.set_xlabel(gain_key if gain_key else 'Test Index')
            ax.set_ylabel('Value')
            ax.set_xscale('log')
            if x_values:
                ax.set_xticks(x_values)
                ax.set_xticklabels(x_labels)
            ax.legend()
            ax.grid(True, alpha=0.3, which='both')

        fig.suptitle(f"{self.title}\n{self.hypothesis}", fontsize=12)
        fig.tight_layout()

        return fig, axes if separate_subplots else ax

    def _plot_speed_vs_stability(self):
        """Create scatter plot of speed (total time) vs stability (overshoot).

        Returns:
            tuple: (fig, ax) Matplotlib figure and axis objects.
        """
        # Collect data
        total_times = []
        overshoots = []
        gain_values = []

        for i, test in enumerate(self.tests):
            metrics = self.get_averaged_metrics(i)

            # Get overshoot
            if isinstance(metrics['overshoot'], dict):
                overshoot = metrics['overshoot']['mean']
            else:
                overshoot = metrics['overshoot']

            # Get total time from first trial
            total_time = test['trials'][0]['total_time']

            if overshoot is not None and total_time is not None:
                total_times.append(total_time)
                overshoots.append(overshoot)
                gain_values.append(i)

        if not total_times:
            print("No valid data for scatter plot")
            return None, None

        # Determine which gain varies for coloring
        gain_key, varying_gains, _ = self._determine_varying_gain()

        # Create scatter plot
        fig, ax = plt.subplots(figsize=(10, 8))

        # Use colormap based on varying gain values
        if gain_key:
            # Extract the actual varying gain values for valid indices
            color_values = [varying_gains[i] for i in gain_values]
            scatter = ax.scatter(total_times, overshoots, c=color_values,
                                 cmap='viridis', s=100, edgecolors='black', linewidth=1.5)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(gain_key, rotation=270, labelpad=20)
        else:
            # No single varying gain, use uniform color
            ax.scatter(total_times, overshoots, s=100, edgecolors='black', linewidth=1.5)

        # Annotate points with gain values
        gains_list = [self.tests[i]['gains'] for i in gain_values]
        for i, (x, y, gains) in enumerate(zip(total_times, overshoots, gains_list)):
            # Format gain values outside f-string
            kp_str = self._format_gain(gains['kp'])
            ki_str = self._format_gain(gains['ki'])
            kd_str = self._format_gain(gains['kd'])

            label = f"P={kp_str}\nI={ki_str}\nD={kd_str}"
            ax.annotate(label, (x, y), xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.7)

        ax.set_xlabel('Total Time (steps)', fontsize=12)
        ax.set_ylabel('Overshoot (%)', fontsize=12)
        ax.set_title(f"{self.title}\nSpeed vs Stability Tradeoff", fontsize=14)
        ax.grid(True, alpha=0.3)

        # Add arrow pointing to ideal region (bottom-left)
        ax.annotate('Ideal\n(Fast & Stable)', xy=(min(total_times), min(overshoots)),
                    xytext=(0.05, 0.95), textcoords='axes fraction',
                    fontsize=10, color='green', weight='bold',
                    ha='left', va='top',
                    arrowprops=dict(arrowstyle='->', color='green', lw=2))

        fig.tight_layout()

        return fig, ax

    def plot_trajectory(self, test_index, trial_index=0):
        """Plot detailed trajectory data for a specific trial.

        Shows position, error, control signal, velocity, and torque over time.

        Args:
            test_index (int): Index of the test in self.tests.
            trial_index (int): Index of the trial within that test. Defaults to 0.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects (5 subplots).

        Raises:
            IndexError: If test_index or trial_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        test = self.tests[test_index]

        if trial_index >= len(test['trials']):
            raise IndexError(f"Trial index {trial_index} out of range")

        trial = test['trials'][trial_index]
        gains = test['gains']

        # Extract data
        time_data = trial['time_data']
        velocity_data = trial['velocity_data']
        torque_data = trial['torque_data']
        position_data = trial['position_data']
        error_data = trial['error_data']
        control_data = trial['control_data']
        target = trial['target']

        # Create figure with 5 subplots
        fig, axes = plt.subplots(5, 1, figsize=(12, 10))

        # Plot 1: Position vs Time
        axes[0].plot(time_data, position_data, 'b-', linewidth=2, label='Actual Position')
        axes[0].axhline(y=target, color='r', linestyle='--', linewidth=1, label='Target')
        axes[0].set_xlabel('Time (steps)')
        axes[0].set_ylabel('Position')
        axes[0].set_title('Position vs Time')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot 2: Error vs Time
        axes[1].plot(time_data, error_data, 'g-', linewidth=2)
        axes[1].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[1].set_xlabel('Time (steps)')
        axes[1].set_ylabel('Error')
        axes[1].set_title('Error vs Time')
        axes[1].grid(True, alpha=0.3)

        # Plot 3: Control Signal vs Time
        # control_data can be one shorter than time_data (no output is computed
        # on the final, settled step), so trim the time axis to match.
        axes[2].plot(time_data[:len(control_data)], control_data, 'orange', linewidth=2)
        axes[2].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[2].set_xlabel('Time (steps)')
        axes[2].set_ylabel('Control Signal')
        axes[2].set_title('Control Signal vs Time')
        axes[2].grid(True, alpha=0.3)

        # Plot 4: Velocity vs Time
        axes[3].plot(time_data, velocity_data, 'purple', linewidth=2)
        axes[3].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[3].set_xlabel('Time (steps)')
        axes[3].set_ylabel('Velocity')
        axes[3].set_title('Velocity vs Time')
        axes[3].grid(True, alpha=0.3)

        # Plot 5: Torque vs Time
        axes[4].plot(time_data[:len(torque_data)], torque_data, 'brown', linewidth=2)
        axes[4].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[4].set_xlabel('Time (steps)')
        axes[4].set_ylabel('Torque')
        axes[4].set_title('Torque vs Time')
        axes[4].grid(True, alpha=0.3)

        # Overall title
        fig.suptitle(
            f"{self.title}\nKp={gains['kp']}, Ki={gains['ki']}, Kd={gains['kd']} | Trial {trial_index}",
            fontsize=12
        )
        fig.tight_layout()

        return fig, axes

    def to_dict(self):
        """Convert experiment to dictionary for JSON serialization.

        Returns:
            dict: Dictionary representation of the experiment.
        """
        return {
            'id': self.id,
            'title': self.title,
            'hypothesis': self.hypothesis,
            'axis': self.axis,
            'timestamp': self.timestamp,
            'tests': self.tests,
            'invert_output': self.invert_output
        }


class ExperimentLog:
    """Manage a log of PID experiments.

    Args:
        filepath (str or Path): Path to the JSON log file.
    """

    def __init__(self, filepath):
        self.filepath = Path(filepath)
        self.experiments = []

    def add_experiment(self, experiment):
        """Add an Experiment to the log.

        Args:
            experiment (Experiment): Experiment object to add.

        Raises:
            TypeError: If experiment is not an Experiment object.
        """
        if not isinstance(experiment, Experiment):
            raise TypeError("Must add an Experiment object")
        self.experiments.append(experiment)

    def save(self):
        """Save experiments to JSON file, preserving existing experiments.

        Checks for duplicate IDs and updates existing experiments.
        """
        # Load existing data if file exists
        if self.filepath.exists():
            with open(self.filepath, 'r') as f:
                data = json.load(f)
            existing_experiments = data.get('experiments', [])
        else:
            existing_experiments = []

        # Create dict of existing experiments by ID
        existing_by_id = {exp['id']: exp for exp in existing_experiments}

        # Add or update with current experiments (same ID overwrites)
        for exp in self.experiments:
            existing_by_id[exp.id] = exp.to_dict()

        # Save all experiments
        data = {'experiments': list(existing_by_id.values())}

        with open(self.filepath, 'w') as f:
            json.dump(data, f, indent=2)

        print(f"Saved {len(self.experiments)} experiments to {self.filepath}")

    @classmethod
    def load(cls, filepath):
        """Load experiments from JSON file.

        Args:
            filepath (str or Path): Path to the JSON log file.

        Returns:
            ExperimentLog: Log with loaded Experiment objects.
        """
        log = cls(filepath)

        if not log.filepath.exists():
            print(f"No existing file at {filepath}, starting new log")
            return log

        with open(log.filepath, 'r') as f:
            data = json.load(f)

        for exp_dict in data.get('experiments', []):
            exp = cls._from_dict(exp_dict)
            log.experiments.append(exp)

        print(f"Loaded {len(log.experiments)} experiments from {filepath}")
        return log

    @staticmethod
    def _from_dict(exp_dict):
        """Reconstruct an Experiment object from a dictionary.

        Args:
            exp_dict (dict): Dictionary representation of an experiment.

        Returns:
            Experiment: Reconstructed Experiment object.
        """
        exp = Experiment(
            title=exp_dict['title'],
            hypothesis=exp_dict['hypothesis'],
            axis=exp_dict['axis'],
            invert_output=exp_dict.get('invert_output', False)
        )
        # Restore identity fields the constructor would otherwise regenerate
        exp.id = exp_dict['id']
        exp.timestamp = exp_dict['timestamp']
        exp.tests = exp_dict['tests']
        return exp

    def get_experiment(self, identifier):
        """Get experiment by title or short hash (first 6 characters of ID).

        Args:
            identifier (str): Either the full/partial title or first 6 chars of ID.

        Returns:
            Experiment or None: Matching experiment, or None if not found.
        """
        # Try short hash first
        if len(identifier) == 6:
            for exp in self.experiments:
                if exp.id[:6] == identifier:
                    return exp

        # Try title (case-insensitive partial match)
        identifier_lower = identifier.lower()
        for exp in self.experiments:
            if identifier_lower in exp.title.lower():
                return exp

        return None

    def list_experiments(self):
        """Print summary table of all experiments."""
        if not self.experiments:
            print("No experiments in log")
            return

        print("\n" + "="*100)
        print(f"{'ID':>8} {'Title':<30} {'Hypothesis':<40} {'Tests':>6}")
        print("="*100)

        for exp in self.experiments:
            short_id = exp.id[:6]
            title = exp.title[:28] + '..' if len(exp.title) > 30 else exp.title
            hypothesis = exp.hypothesis[:38] + '..' if len(exp.hypothesis) > 40 else exp.hypothesis
            num_tests = len(exp.tests)

            print(f"{short_id:>8} {title:<30} {hypothesis:<40} {num_tests:>6}")

        print("="*100)
def calculate_metrics(position_data, time_data, target):
    """Calculate all PID performance metrics.

    Args:
        position_data (list): Position values over time.
        time_data (list): Time step values.
        target (float): Target position.

    Returns:
        dict: Dictionary containing rise_time, settling_time, overshoot, steady_state_error.
    """
    if not position_data:
        return {
            'rise_time': None,
            'settling_time': None,
            'overshoot': 0.0,
            'steady_state_error': None
        }

    position_array = np.array(position_data)
    start_val = position_data[0]

    # Degenerate step: start already equals target, so the percentage-based
    # metrics below are undefined. Avoid division by zero.
    if abs(target - start_val) < 1e-12:
        return {
            'rise_time': 0.0,
            'settling_time': 0.0,
            'overshoot': 0.0,
            'steady_state_error': abs(position_array[-1] - target)
        }

    # Calculate rise time (10% to 90% of the commanded step)
    ten_percent = start_val + 0.1 * (target - start_val)
    ninety_percent = start_val + 0.9 * (target - start_val)
    moving_positive = target > start_val

    ten_idx = None
    ninety_idx = None

    for i, pos in enumerate(position_data):
        if ten_idx is None:
            if (moving_positive and pos >= ten_percent) or (not moving_positive and pos <= ten_percent):
                ten_idx = i
        if ninety_idx is None:
            if (moving_positive and pos >= ninety_percent) or (not moving_positive and pos <= ninety_percent):
                ninety_idx = i
                break

    rise_time = time_data[ninety_idx] - time_data[ten_idx] if (ten_idx is not None and ninety_idx is not None) else None

    # Calculate settling time (2% band): scan backwards for the last sample
    # outside the band; the response is settled from the *next* sample on.
    # BUG FIX: previously `time_data[i]` was used, which reported the time of
    # the last excursion itself and made the `i < len - 1` guard dead code.
    settling_band = abs(target - start_val) * 0.02
    settling_time = None

    for i in range(len(position_data) - 1, -1, -1):
        if abs(position_data[i] - target) > settling_band:
            settling_time = time_data[i + 1] if i < len(time_data) - 1 else time_data[-1]
            break

    if settling_time is None:
        # Response never left the band.
        settling_time = 0

    # Calculate overshoot as a percentage of the commanded step size.
    if target > start_val:
        max_val = np.max(position_array)
        overshoot = (max_val - target) / abs(target - start_val) * 100 if max_val > target else 0
    else:
        min_val = np.min(position_array)
        overshoot = (target - min_val) / abs(target - start_val) * 100 if min_val < target else 0

    # Calculate steady state error
    steady_state_error = abs(position_array[-1] - target)

    return {
        'rise_time': rise_time,
        'settling_time': settling_time,
        'overshoot': overshoot,
        'steady_state_error': steady_state_error
    }
Calculate all PID performance metrics.
Args:
- position_data (list): Position values over time.
- time_data (list): Time step values.
- target (float): Target position.
Returns:
dict: Dictionary containing rise_time, settling_time, overshoot, steady_state_error.
class PIDTester:
    """PID controller testing class for a single Cartesian axis.

    Args:
        sim (Simulation): The simulation object.
        axis (str): Which axis to test ('x', 'y', or 'z'). Defaults to 'x'.
        limits (dict, optional): Workspace limits (auto-calculated if not provided).
        invert_output (bool): Whether to invert the PID output for this axis. Defaults to False.
    """

    def __init__(self, sim, axis='x', limits=None, invert_output=False):
        self.sim = sim
        # Normalized once here; later lookups re-apply .lower() defensively.
        self.axis = axis.lower()
        # Maps axis name to its index in position/velocity/torque vectors.
        self.axis_map = {'x': 0, 'y': 1, 'z': 2}
        # Routed through the property setter below (stores into self._limits).
        self.limits = limits
        self.invert_output = invert_output

    @property
    def limits(self):
        # Lazily probe the simulation workspace on first access and cache it.
        if self._limits is None:
            w = find_workspace(self.sim)
            self._limits = {
                'x': [w['x_min'], w['x_max']],
                'y': [w['y_min'], w['y_max']],
                'z': [w['z_min'], w['z_max']],
                'center': w['center']
            }
        return self._limits

    @limits.setter
    def limits(self, workspace=None):
        # Accepts a find_workspace()-style dict or None; None defers to the
        # lazy getter above. (The default value is never used by property
        # assignment syntax.)
        if workspace is not None:
            self._limits = {
                'x': [workspace['x_min'], workspace['x_max']],
                'y': [workspace['y_min'], workspace['y_max']],
                'z': [workspace['z_min'], workspace['z_max']],
                'center': workspace['center']
            }
        else:
            self._limits = None

    def _run_single_trial(self, kp, ki, kd, dt, target_pos,
                          max_steps, tolerance, settle_steps,
                          reset=True, verbose=False):
        """Run a single closed-loop PID trial on the configured axis.

        Args:
            kp (float): Proportional gain.
            ki (float): Integral gain.
            kd (float): Derivative gain.
            dt (float): Time step.
            target_pos (list): Target position [x, y, z].
            max_steps (int): Maximum number of simulation steps.
            tolerance (float): Position tolerance for settling.
            settle_steps (int): Number of steps within tolerance to consider settled.
            reset (bool): Whether to reset simulation before trial. Defaults to True.
            verbose (bool): Print debug information. Defaults to False.

        Returns:
            dict: Single trial results containing metrics, per-step data, and metadata.

        Raises:
            ValueError: If the target lies outside the axis workspace limits.
        """
        if reset:
            self.sim.reset()

        axis_idx = self.axis_map[self.axis.lower()]
        target = target_pos[axis_idx]
        # NOTE(review): captured BEFORE the homing push below, so the
        # reported 'start_position' is the pre-push pose — confirm intended.
        start_pos = get_position(self.sim)

        # Check target is within limits
        axis_limits = self.limits[self.axis]
        if target < axis_limits[0] or target > axis_limits[1]:
            raise ValueError(f'Target {target:.4f} outside of {self.axis}-axis range [{axis_limits[0]:.4f}, {axis_limits[1]:.4f}]')

        # Drive the axis hard toward one end for 100 steps so every trial
        # starts from a repeatable pose (units of -1000 depend on sim.run —
        # TODO confirm).
        actions = [[0, 0, 0, 0]]
        actions[0][axis_idx] = -1000
        self.sim.run(actions, 100)

        pid = PID(kp=kp, ki=ki, kd=kd, setpoint=target, invert_output=self.invert_output)

        # Per-step data collection
        position_data = []
        time_data = []
        error_data = []
        control_data = []
        velocity_data = []
        torque_data = []

        # Consecutive steps spent within tolerance of the target.
        steps_in_tolerance = 0

        for step in range(max_steps):
            current_pos = get_position(self.sim)
            current_val = current_pos[axis_idx]
            current_vel = get_velocities(self.sim)
            current_torque = get_torques(self.sim)

            if verbose:
                print(f'current_val {self.axis} axis: {current_val:.5f} -> {target:.5f}')

            position_data.append(current_val)
            time_data.append(step)
            error_data.append(target - current_val)
            velocity_data.append(current_vel[axis_idx])
            torque_data.append(current_torque[axis_idx])

            # Count consecutive in-tolerance steps; any excursion resets it.
            if abs(current_val - target) <= tolerance:
                steps_in_tolerance += 1
                if verbose:
                    print(f'settling {steps_in_tolerance} of {settle_steps}')
            else:
                steps_in_tolerance = 0

            # Stop once the position has held near the target long enough.
            # Breaking before the PID update means control_data can be one
            # element shorter than the other series.
            if steps_in_tolerance >= settle_steps:
                break

            # Calculate PID output and advance the simulation by one step.
            vx = pid(current_val, dt)
            control_data.append(vx)

            actions = [[0, 0, 0, 0]]
            actions[0][axis_idx] = vx
            self.sim.run(actions, num_steps=1)

        # Summarize the trajectory (rise time, settling time, overshoot, ...).
        metrics = calculate_metrics(position_data, time_data, target)

        # Store result
        result = {
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics,
            'position_data': position_data,
            'time_data': time_data,
            'error_data': error_data,
            'control_data': control_data,
            'velocity_data': velocity_data,
            'torque_data': torque_data,
            'target': target,
            'start_position': start_pos[self.axis_map[self.axis.lower()]],
            'settled': steps_in_tolerance >= settle_steps,
            'invert_output': self.invert_output,
            # Step count, not seconds.
            'total_time': len(time_data),
        }

        return result

    def test_gains(self, kp, ki, kd, dt, target_pos,
                   max_steps, tolerance, settle_steps, num_trials,
                   reset=True, verbose=False):
        """Test a set of PID gains over multiple trials.

        Args:
            kp (float): Proportional gain.
            ki (float): Integral gain.
            kd (float): Derivative gain.
            dt (float): Time step.
            target_pos (list): Target position [x, y, z].
            max_steps (int): Maximum number of simulation steps.
            tolerance (float): Position tolerance for settling.
            settle_steps (int): Number of steps within tolerance to consider settled.
            num_trials (int): Number of trials to run.
            reset (bool): Whether to reset simulation before each trial. Defaults to True.
            verbose (bool): Print debug information. Defaults to False.

        Returns:
            dict: Test results with gains, timestamp, and trial data.
        """
        trials = []

        for trial in range(num_trials):
            trial_result = self._run_single_trial(
                kp, ki, kd, dt, target_pos,
                max_steps, tolerance, settle_steps,
                reset=reset, verbose=verbose
            )
            trials.append(trial_result)

        result = {
            'gains': {'kp': kp, 'ki': ki, 'kd': kd},
            'timestamp': datetime.now().isoformat(),
            'num_trials': num_trials,
            'trials': trials
        }

        return result
PID controller testing class.
Arguments:
- sim (Simulation): The simulation object.
- axis (str): Which axis to test ('x', 'y', or 'z'). Defaults to 'x'.
- limits (dict, optional): Workspace limits (auto-calculated if not provided).
- invert_output (bool): Whether to invert the PID output for this axis. Defaults to False.
243 def test_gains(self, kp, ki, kd, dt, target_pos, 244 max_steps, tolerance, settle_steps, num_trials, 245 reset=True, verbose=False): 246 """Test a set of PID gains over multiple trials. 247 248 Args: 249 kp (float): Proportional gain. 250 ki (float): Integral gain. 251 kd (float): Derivative gain. 252 dt (float): Time step. 253 target_pos (list): Target position [x, y, z]. 254 max_steps (int): Maximum number of simulation steps. 255 tolerance (float): Position tolerance for settling. 256 settle_steps (int): Number of steps within tolerance to consider settled. 257 num_trials (int): Number of trials to run. 258 reset (bool): Whether to reset simulation before each trial. Defaults to True. 259 verbose (bool): Print debug information. Defaults to False. 260 261 Returns: 262 dict: Test results with gains, timestamp, and trial data. 263 """ 264 trials = [] 265 266 for trial in range(num_trials): 267 trial_result = self._run_single_trial( 268 kp, ki, kd, dt, target_pos, 269 max_steps, tolerance, settle_steps, 270 reset=reset, verbose=verbose 271 ) 272 trials.append(trial_result) 273 274 result = { 275 'gains': {'kp': kp, 'ki': ki, 'kd': kd}, 276 'timestamp': datetime.now().isoformat(), 277 'num_trials': num_trials, 278 'trials': trials 279 } 280 281 return result
Test a set of PID gains over multiple trials.
Arguments:
- kp (float): Proportional gain.
- ki (float): Integral gain.
- kd (float): Derivative gain.
- dt (float): Time step.
- target_pos (list): Target position [x, y, z].
- max_steps (int): Maximum number of simulation steps.
- tolerance (float): Position tolerance for settling.
- settle_steps (int): Number of steps within tolerance to consider settled.
- num_trials (int): Number of trials to run.
- reset (bool): Whether to reset simulation before each trial. Defaults to True.
- verbose (bool): Print debug information. Defaults to False.
Returns:
dict: Test results with gains, timestamp, and trial data.
class Experiment:
    """Create a new experiment to track PID testing.

    Args:
        title (str): Title of the experiment.
        hypothesis (str): Hypothesis being tested.
        axis (str): Axis being tested ('x', 'y', or 'z').
        invert_output (bool): Whether output was inverted for this axis. Defaults to False.
    """

    def __init__(self, title, hypothesis, axis, invert_output=False):
        self.id = str(uuid.uuid4())  # unique id used by ExperimentLog dedup
        self.title = title
        self.hypothesis = hypothesis
        self.axis = axis.lower()
        self.timestamp = datetime.now().isoformat()
        self.tests = []  # result dicts from PIDTester.test_gains()
        self.invert_output = invert_output

    @staticmethod
    def _fmt_gain(value):
        """Format a gain compactly: no decimals at >= 1, two decimals below."""
        return f"{value:.0f}" if value >= 1 else f"{value:.2f}"

    def add_test(self, test_result):
        """Add a test result from PIDTester.test_gains().

        Args:
            test_result (dict): Result dictionary from test_gains().

        Raises:
            ValueError: If test_result doesn't contain required keys.
        """
        # Basic validation
        if 'gains' not in test_result or 'trials' not in test_result:
            raise ValueError("Invalid test_result: must contain 'gains' and 'trials'")

        self.tests.append(test_result)

    def get_averaged_metrics(self, test_index):
        """Get averaged metrics across trials for a specific test.

        Args:
            test_index (int): Index of the test in self.tests.

        Returns:
            dict: For a single-trial test, that trial's raw metrics dict;
                otherwise mean/std/min/max per metric (None for a metric no
                trial reported).

        Raises:
            IndexError: If test_index is out of range.
        """
        # Explicit range check (also rejects negative indices, which would
        # otherwise silently index from the end of the list).
        if not 0 <= test_index < len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        trials = self.tests[test_index]['trials']

        # Nothing to average for a single trial.
        if len(trials) == 1:
            return trials[0]['metrics']

        metric_keys = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error']
        averaged = {}

        for key in metric_keys:
            values = [t['metrics'][key] for t in trials if t['metrics'][key] is not None]
            if values:
                averaged[key] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }
            else:
                averaged[key] = None

        return averaged

    def print_summary(self):
        """Print summary table of all tests in this experiment."""
        print(f"\nExperiment: {self.title}")
        print(f"Hypothesis: {self.hypothesis}")
        print(f"Axis: {self.axis}")
        print(f"Inverted Output: {self.invert_output}")
        print(f"Tests: {len(self.tests)}")
        print("\n" + "=" * 90)
        print(f"{'Kp':>8} {'Ki':>8} {'Kd':>8} {'Rise':>10} {'Settling':>10} {'Overshoot':>10} {'SS Error':>10} {'Total Steps':>8}")
        print("=" * 90)

        for i, test in enumerate(self.tests):
            gains = test['gains']
            metrics = self.get_averaged_metrics(i)

            # BUG FIX: decide dict-vs-scalar per metric. The original only
            # inspected 'rise_time', which broke formatting when some metrics
            # were averaged dicts while rise_time was None/scalar.
            def _mean_or_value(entry):
                return entry['mean'] if isinstance(entry, dict) else entry

            rise = _mean_or_value(metrics['rise_time'])
            settling = _mean_or_value(metrics['settling_time'])
            overshoot = _mean_or_value(metrics['overshoot'])
            ss_error = _mean_or_value(metrics['steady_state_error'])

            # Get total time from first trial
            total_time = test['trials'][0]['total_time']

            # Format None values as 'N/A' instead of failing float formatting
            rise_str = f"{rise:>10.2f}" if rise is not None else f"{'N/A':>10}"
            settling_str = f"{settling:>10.2f}" if settling is not None else f"{'N/A':>10}"
            overshoot_str = f"{overshoot:>10.2f}" if overshoot is not None else f"{'N/A':>10}"
            ss_error_str = f"{ss_error:>10.6f}" if ss_error is not None else f"{'N/A':>10}"

            print(f"{gains['kp']:>8.2f} {gains['ki']:>8.2f} {gains['kd']:>8.2f} "
                  f"{rise_str} {settling_str} {overshoot_str} {ss_error_str} {total_time:>8}")

        print("=" * 90)

    def _determine_varying_gain(self):
        """Determine which gain parameter varies across tests.

        Returns:
            tuple: (gain_key, x_values, x_labels) where:
                - gain_key (str or None): 'Kp', 'Ki', 'Kd', or None if zero or multiple vary.
                - x_values (list): Numeric values for x-axis.
                - x_labels (list): String labels for x-axis.
        """
        gains_list = [test['gains'] for test in self.tests]

        kp_values = [g['kp'] for g in gains_list]
        ki_values = [g['ki'] for g in gains_list]
        kd_values = [g['kd'] for g in gains_list]

        kp_varies = len(set(kp_values)) > 1
        ki_varies = len(set(ki_values)) > 1
        kd_varies = len(set(kd_values)) > 1

        varies_count = sum([kp_varies, ki_varies, kd_varies])

        if varies_count == 0:
            # All tests have the same gains - fall back to test index.
            indices = list(range(len(self.tests)))
            return None, indices, [str(i) for i in indices]

        if varies_count == 1:
            # Exactly one gain varies - plot against that gain.
            if kp_varies:
                return 'Kp', kp_values, [self._fmt_gain(v) for v in kp_values]
            if ki_varies:
                return 'Ki', ki_values, [self._fmt_gain(v) for v in ki_values]
            return 'Kd', kd_values, [self._fmt_gain(v) for v in kd_values]

        # Multiple gains vary - use test index with composite labels.
        # BUG FIX: the original embedded a conditional inside the f-string
        # format spec (f"{v:.0f if v >= 1 else ...}"), which raises
        # ValueError('Invalid format specifier') at runtime.
        x_values = list(range(len(self.tests)))
        x_labels = [
            f"P={self._fmt_gain(g['kp'])},I={self._fmt_gain(g['ki'])},D={self._fmt_gain(g['kd'])}"
            for g in gains_list
        ]
        return None, x_values, x_labels

    def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'):
        """Plot averaged metrics across all tests.

        Args:
            metric_names (list of str or 'all', optional): Which metrics to plot for line plots.
                If 'all' or None, plots all metrics. Ignored for scatter plots.
                Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'.
            separate_subplots (bool): If True, each metric gets its own subplot.
                If False, all on one plot. Only applies to line plots. Defaults to True.
            plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for
                speed vs stability tradeoff. Defaults to 'lines'.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests.
        """
        if not self.tests:
            print("No tests to plot")
            return None, None

        # Scatter plot: speed vs stability
        if plot_type == 'scatter':
            return self._plot_speed_vs_stability()

        # Line plot
        if metric_names == 'all' or metric_names is None:
            metric_names = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error']

        # Determine which gain varies
        gain_key, x_values, x_labels = self._determine_varying_gain()

        # Collect per-metric series, remembering which test indices had data.
        data = {metric: [] for metric in metric_names}
        errors = {metric: [] for metric in metric_names}
        valid_indices = {metric: [] for metric in metric_names}

        for i in range(len(self.tests)):
            metrics = self.get_averaged_metrics(i)
            for metric in metric_names:
                entry = metrics[metric]
                if isinstance(entry, dict):
                    value, error = entry['mean'], entry['std']
                else:
                    value, error = entry, 0

                if value is not None:
                    data[metric].append(value)
                    errors[metric].append(error)
                    valid_indices[metric].append(i)

        # Create plots
        if separate_subplots:
            fig, axes = plt.subplots(len(metric_names), 1, figsize=(10, 4 * len(metric_names)))
            if len(metric_names) == 1:
                axes = [axes]

            for ax, metric in zip(axes, metric_names):
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
                    valid_x_labels = [x_labels[i] for i in valid_indices[metric]]

                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2)
                    ax.set_xlabel(gain_key if gain_key else 'Test Index')
                    ax.set_ylabel(metric.replace('_', ' ').title())
                    ax.set_xscale('log')
                    ax.set_xticks(valid_x_values)
                    ax.set_xticklabels(valid_x_labels)
                    ax.grid(True, alpha=0.3, which='both')
                else:
                    ax.text(0.5, 0.5, f'No valid data for {metric}',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.set_xlabel(gain_key if gain_key else 'Test Index')
                    ax.set_ylabel(metric.replace('_', ' ').title())
        else:
            fig, ax = plt.subplots(figsize=(10, 6))
            for metric in metric_names:
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]

                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2, label=metric.replace('_', ' ').title())

            ax.set_xlabel(gain_key if gain_key else 'Test Index')
            ax.set_ylabel('Value')
            ax.set_xscale('log')
            if x_values:
                ax.set_xticks(x_values)
                ax.set_xticklabels(x_labels)
            ax.legend()
            ax.grid(True, alpha=0.3, which='both')

        fig.suptitle(f"{self.title}\n{self.hypothesis}", fontsize=12)
        fig.tight_layout()

        return fig, axes if separate_subplots else ax

    def _plot_speed_vs_stability(self):
        """Create scatter plot of speed (total time) vs stability (overshoot).

        Returns:
            tuple: (fig, ax) Matplotlib figure and axis objects, or
                (None, None) if no test has both values.
        """
        # Collect data
        total_times = []
        overshoots = []
        test_indices = []

        for i, test in enumerate(self.tests):
            metrics = self.get_averaged_metrics(i)

            # Get overshoot (mean if averaged across trials)
            if isinstance(metrics['overshoot'], dict):
                overshoot = metrics['overshoot']['mean']
            else:
                overshoot = metrics['overshoot']

            # Get total time from first trial
            total_time = test['trials'][0]['total_time']

            if overshoot is not None and total_time is not None:
                total_times.append(total_time)
                overshoots.append(overshoot)
                test_indices.append(i)

        if not total_times:
            print("No valid data for scatter plot")
            return None, None

        # Determine which gain varies for coloring
        gain_key, varying_gains, _ = self._determine_varying_gain()

        # Create scatter plot
        fig, ax = plt.subplots(figsize=(10, 8))

        if gain_key:
            # Color points by the value of the single varying gain.
            color_values = [varying_gains[i] for i in test_indices]
            scatter = ax.scatter(total_times, overshoots, c=color_values,
                                 cmap='viridis', s=100, edgecolors='black', linewidth=1.5)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(gain_key, rotation=270, labelpad=20)
        else:
            # No single varying gain, use uniform color
            ax.scatter(total_times, overshoots, s=100, edgecolors='black', linewidth=1.5)

        # Annotate each point with its full gain triple.
        gains_list = [self.tests[i]['gains'] for i in test_indices]
        for x, y, gains in zip(total_times, overshoots, gains_list):
            label = (f"P={self._fmt_gain(gains['kp'])}\n"
                     f"I={self._fmt_gain(gains['ki'])}\n"
                     f"D={self._fmt_gain(gains['kd'])}")
            ax.annotate(label, (x, y), xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.7)

        ax.set_xlabel('Total Time (steps)', fontsize=12)
        ax.set_ylabel('Overshoot (%)', fontsize=12)
        ax.set_title(f"{self.title}\nSpeed vs Stability Tradeoff", fontsize=14)
        ax.grid(True, alpha=0.3)

        # Add arrow pointing to ideal region (bottom-left)
        ax.annotate('Ideal\n(Fast & Stable)', xy=(min(total_times), min(overshoots)),
                    xytext=(0.05, 0.95), textcoords='axes fraction',
                    fontsize=10, color='green', weight='bold',
                    ha='left', va='top',
                    arrowprops=dict(arrowstyle='->', color='green', lw=2))

        fig.tight_layout()

        return fig, ax

    def plot_trajectory(self, test_index, trial_index=0):
        """Plot detailed trajectory data for a specific trial.

        Shows position, error, control signal, velocity, and torque over time.

        Args:
            test_index (int): Index of the test in self.tests.
            trial_index (int): Index of the trial within that test. Defaults to 0.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects (5 subplots).

        Raises:
            IndexError: If test_index or trial_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        test = self.tests[test_index]

        if trial_index >= len(test['trials']):
            raise IndexError(f"Trial index {trial_index} out of range")

        trial = test['trials'][trial_index]
        gains = test['gains']

        # Extract data
        time_data = trial['time_data']
        velocity_data = trial['velocity_data']
        torque_data = trial['torque_data']
        position_data = trial['position_data']
        error_data = trial['error_data']
        control_data = trial['control_data']
        target = trial['target']

        # Create figure with 5 stacked subplots (docstring previously said 3).
        fig, axes = plt.subplots(5, 1, figsize=(12, 10))

        # Plot 1: Position vs Time
        axes[0].plot(time_data, position_data, 'b-', linewidth=2, label='Actual Position')
        axes[0].axhline(y=target, color='r', linestyle='--', linewidth=1, label='Target')
        axes[0].set_xlabel('Time (steps)')
        axes[0].set_ylabel('Position')
        axes[0].set_title('Position vs Time')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot 2: Error vs Time
        axes[1].plot(time_data, error_data, 'g-', linewidth=2)
        axes[1].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[1].set_xlabel('Time (steps)')
        axes[1].set_ylabel('Error')
        axes[1].set_title('Error vs Time')
        axes[1].grid(True, alpha=0.3)

        # Plot 3: Control Signal vs Time (control series can be shorter than
        # time_data when the trial settled early, hence the slicing)
        axes[2].plot(time_data[:len(control_data)], control_data, 'orange', linewidth=2)
        axes[2].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[2].set_xlabel('Time (steps)')
        axes[2].set_ylabel('Control Signal')
        axes[2].set_title('Control Signal vs Time')
        axes[2].grid(True, alpha=0.3)

        # Plot 4: Velocity vs Time
        axes[3].plot(time_data, velocity_data, 'purple', linewidth=2)
        axes[3].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[3].set_xlabel('Time (steps)')
        axes[3].set_ylabel('Velocity')
        axes[3].set_title('Velocity vs Time')
        axes[3].grid(True, alpha=0.3)

        # Plot 5: Torque vs Time
        axes[4].plot(time_data[:len(torque_data)], torque_data, 'brown', linewidth=2)
        axes[4].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[4].set_xlabel('Time (steps)')
        axes[4].set_ylabel('Torque')
        axes[4].set_title('Torque vs Time')
        axes[4].grid(True, alpha=0.3)

        # Overall title
        fig.suptitle(
            f"{self.title}\nKp={gains['kp']}, Ki={gains['ki']}, Kd={gains['kd']} | Trial {trial_index}",
            fontsize=12
        )
        fig.tight_layout()

        return fig, axes

    def to_dict(self):
        """Convert experiment to dictionary for JSON serialization.

        Returns:
            dict: Dictionary representation of the experiment.
        """
        return {
            'id': self.id,
            'title': self.title,
            'hypothesis': self.hypothesis,
            'axis': self.axis,
            'timestamp': self.timestamp,
            'tests': self.tests,
            'invert_output': self.invert_output
        }
Create a new experiment to track PID testing.
Arguments:
- title (str): Title of the experiment.
- hypothesis (str): Hypothesis being tested.
- axis (str): Axis being tested ('x', 'y', or 'z').
- invert_output (bool): Whether output was inverted for this axis. Defaults to False.
303 def add_test(self, test_result): 304 """Add a test result from PIDTester.test_gains(). 305 306 Args: 307 test_result (dict): Result dictionary from test_gains(). 308 309 Raises: 310 ValueError: If test_result doesn't contain required keys. 311 """ 312 # Basic validation 313 if 'gains' not in test_result or 'trials' not in test_result: 314 raise ValueError("Invalid test_result: must contain 'gains' and 'trials'") 315 316 self.tests.append(test_result)
Add a test result from PIDTester.test_gains().
Arguments:
- test_result (dict): Result dictionary from test_gains().
Raises:
- ValueError: If test_result doesn't contain required keys.
318 def get_averaged_metrics(self, test_index): 319 """Get averaged metrics across trials for a specific test. 320 321 Args: 322 test_index (int): Index of the test in self.tests. 323 324 Returns: 325 dict: Averaged metrics with mean, std, min, max for each metric. 326 327 Raises: 328 IndexError: If test_index is out of range. 329 """ 330 if test_index >= len(self.tests): 331 raise IndexError(f"Test index {test_index} out of range") 332 333 test = self.tests[test_index] 334 trials = test['trials'] 335 336 if len(trials) == 1: 337 return trials[0]['metrics'] 338 339 metric_keys = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error'] 340 averaged = {} 341 342 for key in metric_keys: 343 values = [t['metrics'][key] for t in trials if t['metrics'][key] is not None] 344 if values: 345 averaged[key] = { 346 'mean': np.mean(values), 347 'std': np.std(values), 348 'min': np.min(values), 349 'max': np.max(values) 350 } 351 else: 352 averaged[key] = None 353 354 return averaged
Get averaged metrics across trials for a specific test.
Arguments:
- test_index (int): Index of the test in self.tests.
Returns:
dict: Averaged metrics with mean, std, min, max for each metric.
Raises:
- IndexError: If test_index is out of range.
356 def print_summary(self): 357 """Print summary table of all tests in this experiment.""" 358 print(f"\nExperiment: {self.title}") 359 print(f"Hypothesis: {self.hypothesis}") 360 print(f"Axis: {self.axis}") 361 print(f"Inverted Output: {self.invert_output}") 362 print(f"Tests: {len(self.tests)}") 363 print("\n" + "="*90) # <-- MAKE WIDER 364 print(f"{'Kp':>8} {'Ki':>8} {'Kd':>8} {'Rise':>10} {'Settling':>10} {'Overshoot':>10} {'SS Error':>10} {'Total Steps':>8}") # <-- ADD COLUMN 365 print("="*90) 366 367 for i, test in enumerate(self.tests): 368 gains = test['gains'] 369 metrics = self.get_averaged_metrics(i) 370 371 # Handle both single trial (dict) and averaged (dict with 'mean') 372 if isinstance(metrics['rise_time'], dict): 373 rise = metrics['rise_time']['mean'] 374 settling = metrics['settling_time']['mean'] 375 overshoot = metrics['overshoot']['mean'] 376 ss_error = metrics['steady_state_error']['mean'] 377 else: 378 rise = metrics['rise_time'] 379 settling = metrics['settling_time'] 380 overshoot = metrics['overshoot'] 381 ss_error = metrics['steady_state_error'] 382 383 # Get total time from first trial 384 total_time = test['trials'][0]['total_time'] 385 386 # Format None values as 'N/A' instead of trying to format as float 387 rise_str = f"{rise:>10.2f}" if rise is not None else f"{'N/A':>10}" 388 settling_str = f"{settling:>10.2f}" if settling is not None else f"{'N/A':>10}" 389 overshoot_str = f"{overshoot:>10.2f}" if overshoot is not None else f"{'N/A':>10}" 390 ss_error_str = f"{ss_error:>10.6f}" if ss_error is not None else f"{'N/A':>10}" 391 392 print(f"{gains['kp']:>8.2f} {gains['ki']:>8.2f} {gains['kd']:>8.2f} " 393 f"{rise_str} {settling_str} {overshoot_str} {ss_error_str} {total_time:>8}") # <-- ADD TOTAL 394 395 print("="*90)
Print summary table of all tests in this experiment.
439 def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'): 440 """Plot averaged metrics across all tests. 441 442 Args: 443 metric_names (list of str or 'all', optional): Which metrics to plot for line plots. 444 If 'all' or None, plots all metrics. Ignored for scatter plots. 445 Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'. 446 separate_subplots (bool): If True, each metric gets its own subplot. 447 If False, all on one plot. Only applies to line plots. Defaults to True. 448 plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for 449 speed vs stability tradeoff. Defaults to 'lines'. 450 451 Returns: 452 tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests. 453 """ 454 if not self.tests: 455 print("No tests to plot") 456 return None, None 457 458 # Scatter plot: speed vs stability 459 if plot_type == 'scatter': 460 return self._plot_speed_vs_stability() 461 462 # Line plot: existing behavior 463 if metric_names == 'all' or metric_names is None: 464 metric_names = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error'] 465 466 # Determine which gain varies 467 gain_key, x_values, x_labels = self._determine_varying_gain() 468 469 # Collect data 470 data = {metric: [] for metric in metric_names} 471 errors = {metric: [] for metric in metric_names} 472 valid_indices = {metric: [] for metric in metric_names} 473 474 for i, test in enumerate(self.tests): 475 metrics = self.get_averaged_metrics(i) 476 for metric in metric_names: 477 if isinstance(metrics[metric], dict): 478 value = metrics[metric]['mean'] 479 error = metrics[metric]['std'] 480 else: 481 value = metrics[metric] 482 error = 0 483 484 if value is not None: 485 data[metric].append(value) 486 errors[metric].append(error) 487 valid_indices[metric].append(i) 488 489 # Create plots 490 if separate_subplots: 491 fig, axes = plt.subplots(len(metric_names), 1, figsize=(10, 
4*len(metric_names))) 492 if len(metric_names) == 1: 493 axes = [axes] 494 495 for ax, metric in zip(axes, metric_names): 496 if data[metric]: 497 valid_x_values = [x_values[i] for i in valid_indices[metric]] 498 valid_x_labels = [x_labels[i] for i in valid_indices[metric]] 499 500 ax.errorbar(valid_x_values, data[metric], yerr=errors[metric], 501 marker='o', capsize=5, capthick=2) 502 ax.set_xlabel(gain_key if gain_key else 'Test Index') 503 ax.set_ylabel(metric.replace('_', ' ').title()) 504 ax.set_xscale('log') 505 ax.set_xticks(valid_x_values) 506 ax.set_xticklabels(valid_x_labels) 507 ax.grid(True, alpha=0.3, which='both') 508 else: 509 ax.text(0.5, 0.5, f'No valid data for {metric}', 510 ha='center', va='center', transform=ax.transAxes) 511 ax.set_xlabel(gain_key if gain_key else 'Test Index') 512 ax.set_ylabel(metric.replace('_', ' ').title()) 513 else: 514 fig, ax = plt.subplots(figsize=(10, 6)) 515 for metric in metric_names: 516 if data[metric]: 517 valid_x_values = [x_values[i] for i in valid_indices[metric]] 518 519 ax.errorbar(valid_x_values, data[metric], yerr=errors[metric], 520 marker='o', capsize=5, capthick=2, label=metric.replace('_', ' ').title()) 521 522 ax.set_xlabel(gain_key if gain_key else 'Test Index') 523 ax.set_ylabel('Value') 524 ax.set_xscale('log') 525 if x_values: 526 ax.set_xticks(x_values) 527 ax.set_xticklabels(x_labels) 528 ax.legend() 529 ax.grid(True, alpha=0.3, which='both') 530 531 fig.suptitle(f"{self.title}\n{self.hypothesis}", fontsize=12) 532 fig.tight_layout() 533 534 return fig, axes if separate_subplots else ax
Plot averaged metrics across all tests.
Arguments:
- metric_names (list of str or 'all', optional): Which metrics to plot for line plots. If 'all' or None, plots all metrics. Ignored for scatter plots. Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'.
- separate_subplots (bool): If True, each metric gets its own subplot. If False, all on one plot. Only applies to line plots. Defaults to True.
- plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for speed vs stability tradeoff. Defaults to 'lines'.
Returns:
tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests.
    def plot_trajectory(self, test_index, trial_index=0):
        """Plot detailed trajectory data for a specific trial.

        Shows position, error, control signal, velocity, and torque over time.

        Args:
            test_index (int): Index of the test in self.tests.
            trial_index (int): Index of the trial within that test. Defaults to 0.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects (5 subplots).

        Raises:
            IndexError: If test_index or trial_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        test = self.tests[test_index]

        if trial_index >= len(test['trials']):
            raise IndexError(f"Trial index {trial_index} out of range")

        trial = test['trials'][trial_index]
        gains = test['gains']

        # Extract the per-step series recorded by _run_single_trial.
        time_data = trial['time_data']
        velocity_data = trial['velocity_data']
        torque_data = trial['torque_data']
        position_data = trial['position_data']
        error_data = trial['error_data']
        control_data = trial['control_data']
        target = trial['target']

        # Create figure with 5 stacked subplots (not 3 as previously stated).
        fig, axes = plt.subplots(5, 1, figsize=(12, 10))

        # Plot 1: Position vs Time
        axes[0].plot(time_data, position_data, 'b-', linewidth=2, label='Actual Position')
        axes[0].axhline(y=target, color='r', linestyle='--', linewidth=1, label='Target')
        axes[0].set_xlabel('Time (steps)')
        axes[0].set_ylabel('Position')
        axes[0].set_title('Position vs Time')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot 2: Error vs Time
        axes[1].plot(time_data, error_data, 'g-', linewidth=2)
        axes[1].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[1].set_xlabel('Time (steps)')
        axes[1].set_ylabel('Error')
        axes[1].set_title('Error vs Time')
        axes[1].grid(True, alpha=0.3)

        # Plot 3: Control Signal vs Time. The control series can be shorter
        # than time_data (no PID update on the final settled step), hence the
        # slice to keep x and y the same length.
        axes[2].plot(time_data[:len(control_data)], control_data, 'orange', linewidth=2)
        axes[2].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[2].set_xlabel('Time (steps)')
        axes[2].set_ylabel('Control Signal')
        axes[2].set_title('Control Signal vs Time')
        axes[2].grid(True, alpha=0.3)

        # Plot 4: Velocity vs Time
        axes[3].plot(time_data, velocity_data, 'purple', linewidth=2)
        axes[3].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[3].set_xlabel('Time (steps)')
        axes[3].set_ylabel('Velocity')
        axes[3].set_title('Velocity vs Time')
        axes[3].grid(True, alpha=0.3)

        # Plot 5: Torque vs Time (sliced defensively like control_data)
        axes[4].plot(time_data[:len(torque_data)], torque_data, 'brown', linewidth=2)
        axes[4].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[4].set_xlabel('Time (steps)')
        axes[4].set_ylabel('Torque')
        axes[4].set_title('Torque vs Time')
        axes[4].grid(True, alpha=0.3)
        # Overall title
        fig.suptitle(
            f"{self.title}\nKp={gains['kp']}, Ki={gains['ki']}, Kd={gains['kd']} | Trial {trial_index}",
            fontsize=12
        )
        fig.tight_layout()

        return fig, axes
Plot detailed trajectory data for a specific trial.
Shows position, error, control signal, velocity, and torque over time.
Arguments:
- test_index (int): Index of the test in self.tests.
- trial_index (int): Index of the trial within that test. Defaults to 0.
Returns:
tuple: (fig, axes) Matplotlib figure and axes objects (5 subplots).
Raises:
- IndexError: If test_index or trial_index is out of range.
def to_dict(self):
    """Serialize this experiment into a plain dictionary.

    Returns:
        dict: JSON-serializable representation of the experiment.
    """
    return dict(
        id=self.id,
        title=self.title,
        hypothesis=self.hypothesis,
        axis=self.axis,
        timestamp=self.timestamp,
        tests=self.tests,
        invert_output=self.invert_output,
    )
Convert experiment to dictionary for JSON serialization.
Returns:
dict: Dictionary representation of the experiment.
class ExperimentLog:
    """Manage a persistent log of PID experiments backed by a JSON file.

    Args:
        filepath (str or Path): Path to the JSON log file.
    """

    def __init__(self, filepath):
        # Normalize to Path so .exists() and open() work uniformly
        self.filepath = Path(filepath)
        # In-memory Experiment objects added this session
        self.experiments = []

    def add_experiment(self, experiment):
        """Add an Experiment to the log.

        Args:
            experiment (Experiment): Experiment object to add.

        Raises:
            TypeError: If experiment is not an Experiment object.
        """
        if not isinstance(experiment, Experiment):
            raise TypeError("Must add an Experiment object")
        self.experiments.append(experiment)

    def save(self):
        """Save experiments to JSON file, preserving existing experiments.

        Experiments already stored on disk are kept; an in-memory
        experiment with the same ID replaces the stored copy.
        """
        # Load existing data if the file already exists
        if self.filepath.exists():
            with open(self.filepath, 'r') as f:
                data = json.load(f)
            existing_experiments = data.get('experiments', [])
        else:
            existing_experiments = []

        # Index stored experiments by ID so duplicates are updated in place
        existing_by_id = {exp['id']: exp for exp in existing_experiments}

        # Add or update with the current in-memory experiments
        for exp in self.experiments:
            existing_by_id[exp.id] = exp.to_dict()

        # Save the merged set of experiments
        data = {'experiments': list(existing_by_id.values())}

        with open(self.filepath, 'w') as f:
            json.dump(data, f, indent=2)

        # BUG FIX: previously reported len(self.experiments), which ignored
        # the experiments merged in from disk; report what was written.
        print(f"Saved {len(existing_by_id)} experiments to {self.filepath}")

    @classmethod
    def load(cls, filepath):
        """Load experiments from JSON file.

        Args:
            filepath (str or Path): Path to the JSON log file.

        Returns:
            ExperimentLog: Log with loaded Experiment objects.
        """
        log = cls(filepath)

        if not log.filepath.exists():
            print(f"No existing file at {filepath}, starting new log")
            return log

        with open(log.filepath, 'r') as f:
            data = json.load(f)

        for exp_dict in data.get('experiments', []):
            exp = cls._from_dict(exp_dict)
            log.experiments.append(exp)

        print(f"Loaded {len(log.experiments)} experiments from {filepath}")
        return log

    @staticmethod
    def _from_dict(exp_dict):
        """Reconstruct an Experiment object from a dictionary.

        Args:
            exp_dict (dict): Dictionary representation of an experiment.

        Returns:
            Experiment: Reconstructed Experiment object.
        """
        exp = Experiment(
            title=exp_dict['title'],
            hypothesis=exp_dict['hypothesis'],
            axis=exp_dict['axis'],
            # Older logs may predate invert_output; default to False
            invert_output=exp_dict.get('invert_output', False)
        )
        # Restore identity fields rather than generating fresh ones
        exp.id = exp_dict['id']
        exp.timestamp = exp_dict['timestamp']
        exp.tests = exp_dict['tests']
        return exp

    def get_experiment(self, identifier):
        """Get experiment by title or short hash (first 6 characters of ID).

        Args:
            identifier (str): Either the full/partial title or first 6 chars of ID.

        Returns:
            Experiment or None: Matching experiment, or None if not found.
        """
        # BUG FIX: an empty identifier used to substring-match every title
        # and silently return the first experiment.
        if not identifier:
            return None

        # Try short hash first
        if len(identifier) == 6:
            for exp in self.experiments:
                if exp.id[:6] == identifier:
                    return exp

        # Try title (case-insensitive partial match)
        identifier_lower = identifier.lower()
        for exp in self.experiments:
            if identifier_lower in exp.title.lower():
                return exp

        return None

    def list_experiments(self):
        """Print a summary table of all experiments (ID, title, hypothesis, test count)."""
        if not self.experiments:
            print("No experiments in log")
            return

        print("\n" + "="*100)
        print(f"{'ID':>8} {'Title':<30} {'Hypothesis':<40} {'Tests':>6}")
        print("="*100)

        for exp in self.experiments:
            short_id = exp.id[:6]
            # Truncate long fields so columns stay aligned
            title = exp.title[:28] + '..' if len(exp.title) > 30 else exp.title
            hypothesis = exp.hypothesis[:38] + '..' if len(exp.hypothesis) > 40 else exp.hypothesis
            num_tests = len(exp.tests)

            print(f"{short_id:>8} {title:<30} {hypothesis:<40} {num_tests:>6}")

        print("="*100)
Manage a log of PID experiments.
Arguments:
- filepath (str or Path): Path to the JSON log file.
def add_experiment(self, experiment):
    """Append an Experiment to this log.

    Args:
        experiment (Experiment): Experiment object to add.

    Raises:
        TypeError: If the argument is not an Experiment instance.
    """
    if isinstance(experiment, Experiment):
        self.experiments.append(experiment)
    else:
        raise TypeError("Must add an Experiment object")
Add an Experiment to the log.
Arguments:
- experiment (Experiment): Experiment object to add.
Raises:
- TypeError: If experiment is not an Experiment object.
def save(self):
    """Save experiments to JSON file, preserving existing experiments.

    Records already stored on disk are merged with the in-memory ones;
    when IDs collide, the in-memory version wins.
    """
    merged = {}

    # Start from whatever is already on disk, if anything.
    if self.filepath.exists():
        with open(self.filepath, 'r') as f:
            on_disk = json.load(f).get('experiments', [])
        for record in on_disk:
            merged[record['id']] = record

    # Layer the in-memory experiments on top (same ID -> update in place).
    for experiment in self.experiments:
        merged[experiment.id] = experiment.to_dict()

    with open(self.filepath, 'w') as f:
        json.dump({'experiments': list(merged.values())}, f, indent=2)

    print(f"Saved {len(self.experiments)} experiments to {self.filepath}")
Save experiments to JSON file, preserving existing experiments.
Checks for duplicate IDs and updates existing experiments.
@classmethod
def load(cls, filepath):
    """Load experiments from JSON file.

    Args:
        filepath (str or Path): Path to the JSON log file.

    Returns:
        ExperimentLog: Log populated with reconstructed Experiment objects.
    """
    log = cls(filepath)

    # A missing file is not an error: start a fresh, empty log.
    if not log.filepath.exists():
        print(f"No existing file at {filepath}, starting new log")
        return log

    with open(log.filepath, 'r') as f:
        payload = json.load(f)

    log.experiments.extend(
        cls._from_dict(entry) for entry in payload.get('experiments', [])
    )

    print(f"Loaded {len(log.experiments)} experiments from {filepath}")
    return log
Load experiments from JSON file.
Arguments:
- filepath (str or Path): Path to the JSON log file.
Returns:
ExperimentLog: Log with loaded Experiment objects.
def get_experiment(self, identifier):
    """Look up an experiment by short ID hash or (partial) title.

    Args:
        identifier (str): First 6 characters of the ID, or a
            case-insensitive substring of the title.

    Returns:
        Experiment or None: First matching experiment, or None if not found.
    """
    # Exactly six characters is treated as a short ID hash first.
    if len(identifier) == 6:
        match = next((e for e in self.experiments if e.id[:6] == identifier), None)
        if match is not None:
            return match

    # Fall back to a case-insensitive substring match on the title.
    needle = identifier.lower()
    return next((e for e in self.experiments if needle in e.title.lower()), None)
Get experiment by title or short hash (first 6 characters of ID).
Arguments:
- identifier (str): Either the full/partial title or first 6 chars of ID.
Returns:
Experiment or None: Matching experiment, or None if not found.
def list_experiments(self):
    """Print a summary table (short ID, title, hypothesis, test count)."""
    if not self.experiments:
        print("No experiments in log")
        return

    divider = "=" * 100
    print("\n" + divider)
    print(f"{'ID':>8} {'Title':<30} {'Hypothesis':<40} {'Tests':>6}")
    print(divider)

    for exp in self.experiments:
        # Truncate long fields so the columns stay aligned.
        title = exp.title if len(exp.title) <= 30 else exp.title[:28] + '..'
        hypothesis = (exp.hypothesis if len(exp.hypothesis) <= 40
                      else exp.hypothesis[:38] + '..')
        print(f"{exp.id[:6]:>8} {title:<30} {hypothesis:<40} {len(exp.tests):>6}")

    print(divider)
Print summary table of all experiments.