library.pid_tester

  1import numpy as np
  2from datetime import datetime
  3from pathlib import Path
  4import json
  5import uuid
  6import matplotlib.pyplot as plt
  7from library.robot_control import get_position, find_workspace, get_velocities, get_torques
  8
  9from library.pid_controller import PID
 10
 11
 12def calculate_metrics(position_data, time_data, target):
 13    """Calculate all PID performance metrics.
 14    
 15    Args:
 16        position_data (list): Position values over time.
 17        time_data (list): Time step values.
 18        target (float): Target position.
 19        
 20    Returns:
 21        dict: Dictionary containing rise_time, settling_time, overshoot, steady_state_error.
 22    """
 23    if not position_data:
 24        return {
 25            'rise_time': None,
 26            'settling_time': None,
 27            'overshoot': 0.0,
 28            'steady_state_error': None
 29        }
 30    
 31    position_array = np.array(position_data)
 32    start_val = position_data[0]
 33    
 34    # Avoid division by zero
 35    if abs(target - start_val) < 1e-12:
 36        return {
 37            'rise_time': 0.0,
 38            'settling_time': 0.0,
 39            'overshoot': 0.0,
 40            'steady_state_error': abs(position_array[-1] - target)
 41        }
 42    
 43    # Calculate rise time (10% to 90%)
 44    ten_percent = start_val + 0.1 * (target - start_val)
 45    ninety_percent = start_val + 0.9 * (target - start_val)
 46    moving_positive = target > start_val
 47    
 48    ten_idx = None
 49    ninety_idx = None
 50    
 51    for i, pos in enumerate(position_data):
 52        if ten_idx is None:
 53            if (moving_positive and pos >= ten_percent) or (not moving_positive and pos <= ten_percent):
 54                ten_idx = i
 55        if ninety_idx is None:
 56            if (moving_positive and pos >= ninety_percent) or (not moving_positive and pos <= ninety_percent):
 57                ninety_idx = i
 58                break
 59    
 60    rise_time = time_data[ninety_idx] - time_data[ten_idx] if (ten_idx is not None and ninety_idx is not None) else None
 61    
 62    # Calculate settling time (2% band)
 63    settling_band = abs(target - start_val) * 0.02
 64    settling_time = None
 65    
 66    for i in range(len(position_data) - 1, -1, -1):
 67        if abs(position_data[i] - target) > settling_band:
 68            settling_time = time_data[i] if i < len(time_data) - 1 else time_data[-1]
 69            break
 70    
 71    if settling_time is None:
 72        settling_time = 0
 73    
 74    # Calculate overshoot
 75    if target > start_val:
 76        max_val = np.max(position_array)
 77        overshoot = (max_val - target) / abs(target - start_val) * 100 if max_val > target else 0
 78    else:
 79        min_val = np.min(position_array)
 80        overshoot = (target - min_val) / abs(target - start_val) * 100 if min_val < target else 0
 81    
 82    # Calculate steady state error
 83    steady_state_error = abs(position_array[-1] - target)
 84    
 85    return {
 86        'rise_time': rise_time,
 87        'settling_time': settling_time,
 88        'overshoot': overshoot,
 89        'steady_state_error': steady_state_error
 90    }
 91
 92
 93class PIDTester:
 94    """PID controller testing class.
 95    
 96    Args:
 97        sim (Simulation): The simulation object.
 98        axis (str): Which axis to test ('x', 'y', or 'z'). Defaults to 'x'.
 99        limits (dict, optional): Workspace limits (auto-calculated if not provided).
100        invert_output (bool): Whether to invert the PID output for this axis. Defaults to False.
101    """
102    
103    def __init__(self, sim, axis='x', limits=None, invert_output=False):
104        self.sim = sim
105        self.axis = axis.lower()
106        self.axis_map = {'x': 0, 'y': 1, 'z': 2}
107        self.limits = limits
108        self.invert_output = invert_output
109                
110    @property
111    def limits(self):
112        if self._limits is None:
113            w = find_workspace(self.sim)
114            self._limits = {
115                'x': [w['x_min'], w['x_max']],
116                'y': [w['y_min'], w['y_max']],
117                'z': [w['z_min'], w['z_max']],
118                'center': w['center']
119            }
120        return self._limits
121
122    @limits.setter
123    def limits(self, workspace=None):
124        if workspace is not None:
125            self._limits = {
126                'x': [workspace['x_min'], workspace['x_max']],
127                'y': [workspace['y_min'], workspace['y_max']],
128                'z': [workspace['z_min'], workspace['z_max']],
129                'center': workspace['center']
130            }
131        else:
132            self._limits = None
133
134    def _run_single_trial(self, kp, ki, kd, dt, target_pos,
135                          max_steps, tolerance, settle_steps,
136                          reset=True, verbose=False):
137        """Run a single trial of PID test.
138
139        Args:
140            kp (float): Proportional gain.
141            ki (float): Integral gain.
142            kd (float): Derivative gain.
143            dt (float): Time step.
144            target_pos (list): Target position [x, y, z].
145            max_steps (int): Maximum number of simulation steps.
146            tolerance (float): Position tolerance for settling.
147            settle_steps (int): Number of steps within tolerance to consider settled.
148            reset (bool): Whether to reset simulation before trial. Defaults to True.
149            verbose (bool): Print debug information. Defaults to False.
150
151        Returns:
152            dict: Single trial results containing metrics, position data, and metadata.
153        """
154        if reset:
155            self.sim.reset()
156        
157        axis_idx = self.axis_map[self.axis.lower()]
158        target = target_pos[axis_idx]
159        start_pos = get_position(self.sim)
160
161        # Check target is within limits
162        axis_limits = self.limits[self.axis]
163        if target < axis_limits[0] or target > axis_limits[1]:
164            raise ValueError(f'Target {target:.4f} outside of {self.axis}-axis range [{axis_limits[0]:.4f}, {axis_limits[1]:.4f}]')
165
166        # Move to starting position
167        actions = [[0, 0, 0, 0]]
168        actions[0][axis_idx] = -1000
169        self.sim.run(actions, 100)
170
171        pid = PID(kp=kp, ki=ki, kd=kd, setpoint=target, invert_output=self.invert_output)
172
173        # Data collection
174        position_data = []
175        time_data = []
176        error_data = []
177        control_data = []
178        velocity_data = []
179        torque_data = []
180        
181        # Settling detection
182        steps_in_tolerance = 0
183
184        for step in range(max_steps):
185            current_pos = get_position(self.sim)
186            current_val = current_pos[axis_idx]
187            current_vel = get_velocities(self.sim)
188            current_torque = get_torques(self.sim)
189            
190            if verbose:
191                print(f'current_val {self.axis} axis: {current_val:.5f} -> {target:.5f}')
192            
193            position_data.append(current_val)
194            time_data.append(step)
195            error_data.append(target - current_val)
196            velocity_data.append(current_vel[axis_idx])
197            torque_data.append(current_torque[axis_idx])
198            
199            # Check position remains near target
200            if abs(current_val - target) <= tolerance:
201                steps_in_tolerance += 1
202                if verbose:
203                    print(f'settling {steps_in_tolerance} of {settle_steps}')
204            else:
205                steps_in_tolerance = 0
206
207            # Stop execution if position remains near target
208            if steps_in_tolerance >= settle_steps:
209                break
210
211            # Calculate PID output 
212            vx = pid(current_val, dt)
213            control_data.append(vx)
214
215            actions = [[0, 0, 0, 0]]
216            actions[0][axis_idx] = vx
217            self.sim.run(actions, num_steps=1)
218
219        # Calculate metrics
220        metrics = calculate_metrics(position_data, time_data, target)
221        
222        # Store result
223        result = {
224            'timestamp': datetime.now().isoformat(),
225            'metrics': metrics,
226            'position_data': position_data,
227            'time_data': time_data,
228            'error_data': error_data,
229            'control_data': control_data,
230            'velocity_data': velocity_data,
231            'torque_data': torque_data,
232            'target': target,
233            'start_position': start_pos[self.axis_map[self.axis.lower()]],
234            'settled': steps_in_tolerance >= settle_steps,
235            'invert_output': self.invert_output,
236            'total_time': len(time_data),
237
238        }
239        
240        return result
241    
242    def test_gains(self, kp, ki, kd, dt, target_pos,
243                   max_steps, tolerance, settle_steps, num_trials,
244                   reset=True, verbose=False):
245        """Test a set of PID gains over multiple trials.
246        
247        Args:
248            kp (float): Proportional gain.
249            ki (float): Integral gain.
250            kd (float): Derivative gain.
251            dt (float): Time step.
252            target_pos (list): Target position [x, y, z].
253            max_steps (int): Maximum number of simulation steps.
254            tolerance (float): Position tolerance for settling.
255            settle_steps (int): Number of steps within tolerance to consider settled.
256            num_trials (int): Number of trials to run.
257            reset (bool): Whether to reset simulation before each trial. Defaults to True.
258            verbose (bool): Print debug information. Defaults to False.
259        
260        Returns:
261            dict: Test results with gains, timestamp, and trial data.
262        """
263        trials = []
264
265        for trial in range(num_trials):
266            trial_result = self._run_single_trial(
267                kp, ki, kd, dt, target_pos,
268                max_steps, tolerance, settle_steps,
269                reset=reset, verbose=verbose
270            )
271            trials.append(trial_result)
272
273        result = {
274            'gains': {'kp': kp, 'ki': ki, 'kd': kd},
275            'timestamp': datetime.now().isoformat(),
276            'num_trials': num_trials,
277            'trials': trials
278        }
279
280        return result
281    
282
class Experiment:
    """Track a series of PID gain tests run to evaluate one hypothesis.

    Each entry in ``tests`` is a result dict from ``PIDTester.test_gains()``
    (keys ``'gains'``, ``'timestamp'``, ``'num_trials'``, ``'trials'``).

    Args:
        title (str): Title of the experiment.
        hypothesis (str): Hypothesis being tested.
        axis (str): Axis being tested ('x', 'y', or 'z'); stored lower-cased.
        invert_output (bool): Whether output was inverted for this axis. Defaults to False.
    """

    # Metric names produced by calculate_metrics(), in display order.
    METRIC_KEYS = ('rise_time', 'settling_time', 'overshoot', 'steady_state_error')

    def __init__(self, title, hypothesis, axis, invert_output=False):
        self.id = str(uuid.uuid4())
        self.title = title
        self.hypothesis = hypothesis
        self.axis = axis.lower()
        self.timestamp = datetime.now().isoformat()
        self.tests = []
        self.invert_output = invert_output

    @staticmethod
    def _mean_of(metric):
        """Return the central value of one metric entry.

        ``get_averaged_metrics`` yields a raw scalar for single-trial tests
        and a ``{'mean', 'std', 'min', 'max'}`` dict otherwise; None passes through.
        """
        return metric['mean'] if isinstance(metric, dict) else metric

    @staticmethod
    def _spread_of(metric):
        """Return the trial-to-trial standard deviation of a metric entry (0 for scalars)."""
        return metric['std'] if isinstance(metric, dict) else 0

    @staticmethod
    def _format_gain(value):
        """Format a gain compactly: no decimals for values >= 1, two decimals below."""
        return f"{value:.0f}" if value >= 1 else f"{value:.2f}"

    @staticmethod
    def _first_total_time(test):
        """Return ``total_time`` of the test's first trial, or None if it has no trials."""
        trials = test['trials']
        return trials[0]['total_time'] if trials else None

    def add_test(self, test_result):
        """Add a test result from PIDTester.test_gains().

        Args:
            test_result (dict): Result dictionary from test_gains().

        Raises:
            ValueError: If test_result doesn't contain required keys.
        """
        if 'gains' not in test_result or 'trials' not in test_result:
            raise ValueError("Invalid test_result: must contain 'gains' and 'trials'")

        self.tests.append(test_result)

    def get_averaged_metrics(self, test_index):
        """Get averaged metrics across trials for a specific test.

        Args:
            test_index (int): Index of the test in self.tests.

        Returns:
            dict: For a single-trial test, that trial's raw metrics dict
            (scalar values). Otherwise, one entry per metric: a dict with
            'mean', 'std', 'min', 'max' over the trials where the metric is
            not None, or None if no trial produced a value.

        Raises:
            IndexError: If test_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        trials = self.tests[test_index]['trials']

        # A single trial has no spread, so its raw metrics are returned as-is.
        if len(trials) == 1:
            return trials[0]['metrics']

        averaged = {}
        for key in self.METRIC_KEYS:
            # Trials where the metric could not be computed (None) are skipped.
            values = [t['metrics'][key] for t in trials if t['metrics'][key] is not None]
            averaged[key] = {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values),
            } if values else None

        return averaged

    def print_summary(self):
        """Print a summary table of all tests in this experiment.

        Multi-trial tests show the mean of each metric. Metrics that could
        not be computed are shown as 'N/A'. 'Total Steps' is taken from the
        first trial of each test.
        """
        print(f"\nExperiment: {self.title}")
        print(f"Hypothesis: {self.hypothesis}")
        print(f"Axis: {self.axis}")
        print(f"Inverted Output: {self.invert_output}")
        print(f"Tests: {len(self.tests)}")
        print("\n" + "=" * 90)
        print(f"{'Kp':>8} {'Ki':>8} {'Kd':>8} {'Rise':>10} {'Settling':>10} "
              f"{'Overshoot':>10} {'SS Error':>10} {'Total Steps':>11}")
        print("=" * 90)

        for i, test in enumerate(self.tests):
            gains = test['gains']
            metrics = self.get_averaged_metrics(i)

            # Extract each metric independently: in a multi-trial test one
            # metric can be None (never computed) while the others are dicts.
            rise, settling, overshoot, ss_error = (
                self._mean_of(metrics[key]) for key in self.METRIC_KEYS
            )
            total_time = self._first_total_time(test)

            rise_str = f"{rise:>10.2f}" if rise is not None else f"{'N/A':>10}"
            settling_str = f"{settling:>10.2f}" if settling is not None else f"{'N/A':>10}"
            overshoot_str = f"{overshoot:>10.2f}" if overshoot is not None else f"{'N/A':>10}"
            ss_error_str = f"{ss_error:>10.6f}" if ss_error is not None else f"{'N/A':>10}"
            total_str = f"{total_time:>11}" if total_time is not None else f"{'N/A':>11}"

            print(f"{gains['kp']:>8.2f} {gains['ki']:>8.2f} {gains['kd']:>8.2f} "
                  f"{rise_str} {settling_str} {overshoot_str} {ss_error_str} {total_str}")

        print("=" * 90)

    def _determine_varying_gain(self):
        """Determine which gain parameter varies across tests.

        Returns:
            tuple: (gain_key, x_values, x_labels) where:
                - gain_key (str or None): 'Kp', 'Ki', 'Kd', or None if no gain
                  or more than one gain varies.
                - x_values (list): The varying gain's values, or test indices.
                - x_labels (list): String labels for x-axis.
        """
        gains_list = [test['gains'] for test in self.tests]
        test_indices = list(range(len(self.tests)))
        varying = [
            (name, key)
            for name, key in (('Kp', 'kp'), ('Ki', 'ki'), ('Kd', 'kd'))
            if len({g[key] for g in gains_list}) > 1
        ]

        if not varying:
            # All tests share the same gains: fall back to test index.
            return None, test_indices, [str(i) for i in test_indices]

        if len(varying) == 1:
            name, key = varying[0]
            values = [g[key] for g in gains_list]
            return name, values, [self._format_gain(v) for v in values]

        # Several gains vary: plot against test index with composite labels.
        fmt = self._format_gain
        labels = [f"P={fmt(g['kp'])},I={fmt(g['ki'])},D={fmt(g['kd'])}" for g in gains_list]
        return None, test_indices, labels

    def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'):
        """Plot averaged metrics across all tests.

        Args:
            metric_names (list of str, str, or 'all', optional): Which metrics
                to plot for line plots. If 'all' or None, plots all metrics.
                A single metric name is also accepted. Ignored for scatter plots.
                Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'.
            separate_subplots (bool): If True, each metric gets its own subplot.
                If False, all on one plot. Only applies to line plots. Defaults to True.
            plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for
                speed vs stability tradeoff. Defaults to 'lines'.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests.

        Note:
            The x-axis is logarithmic only when exactly one gain varies and all
            of its values are positive. Otherwise it is linear, since zero
            (a zero gain, or test index 0) cannot be drawn on a log axis.
        """
        if not self.tests:
            print("No tests to plot")
            return None, None

        if plot_type == 'scatter':
            return self._plot_speed_vs_stability()

        if metric_names == 'all' or metric_names is None:
            metric_names = list(self.METRIC_KEYS)
        elif isinstance(metric_names, str):
            metric_names = [metric_names]

        gain_key, x_values, x_labels = self._determine_varying_gain()
        x_label = gain_key if gain_key else 'Test Index'
        use_log = gain_key is not None and all(x > 0 for x in x_values)

        # Collect mean, spread and test index per metric, skipping None values.
        data = {metric: [] for metric in metric_names}
        errors = {metric: [] for metric in metric_names}
        valid_indices = {metric: [] for metric in metric_names}

        for i in range(len(self.tests)):
            metrics = self.get_averaged_metrics(i)
            for metric in metric_names:
                value = self._mean_of(metrics[metric])
                if value is not None:
                    data[metric].append(value)
                    errors[metric].append(self._spread_of(metrics[metric]))
                    valid_indices[metric].append(i)

        if separate_subplots:
            fig, axes = plt.subplots(len(metric_names), 1, figsize=(10, 4 * len(metric_names)))
            if len(metric_names) == 1:
                axes = [axes]  # plt.subplots returns a bare Axes for a 1x1 grid

            for ax, metric in zip(axes, metric_names):
                ax.set_xlabel(x_label)
                ax.set_ylabel(metric.replace('_', ' ').title())
                if not data[metric]:
                    ax.text(0.5, 0.5, f'No valid data for {metric}',
                            ha='center', va='center', transform=ax.transAxes)
                    continue

                valid_x_values = [x_values[i] for i in valid_indices[metric]]
                valid_x_labels = [x_labels[i] for i in valid_indices[metric]]
                ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                            marker='o', capsize=5, capthick=2)
                if use_log:
                    ax.set_xscale('log')
                ax.set_xticks(valid_x_values)
                ax.set_xticklabels(valid_x_labels)
                ax.grid(True, alpha=0.3, which='both')
        else:
            fig, ax = plt.subplots(figsize=(10, 6))
            for metric in metric_names:
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2,
                                label=metric.replace('_', ' ').title())

            ax.set_xlabel(x_label)
            ax.set_ylabel('Value')
            if use_log:
                ax.set_xscale('log')
            if x_values:
                ax.set_xticks(x_values)
                ax.set_xticklabels(x_labels)
            ax.legend()
            ax.grid(True, alpha=0.3, which='both')

        fig.suptitle(f"{self.title}\n{self.hypothesis}", fontsize=12)
        fig.tight_layout()

        return fig, (axes if separate_subplots else ax)

    def _plot_speed_vs_stability(self):
        """Create scatter plot of speed (total time) vs stability (overshoot).

        Total time is taken from each test's first trial; overshoot is the
        mean across trials. Tests missing either value are skipped.

        Returns:
            tuple: (fig, ax) Matplotlib figure and axis objects, or (None, None)
            if no test has both values.
        """
        total_times = []
        overshoots = []
        test_indices = []

        for i, test in enumerate(self.tests):
            overshoot = self._mean_of(self.get_averaged_metrics(i)['overshoot'])
            total_time = self._first_total_time(test)

            if overshoot is not None and total_time is not None:
                total_times.append(total_time)
                overshoots.append(overshoot)
                test_indices.append(i)

        if not total_times:
            print("No valid data for scatter plot")
            return None, None

        # Determine which gain varies for coloring
        gain_key, varying_gains, _ = self._determine_varying_gain()

        fig, ax = plt.subplots(figsize=(10, 8))

        if gain_key:
            # Color each point by the value of the single varying gain.
            color_values = [varying_gains[i] for i in test_indices]
            scatter = ax.scatter(total_times, overshoots, c=color_values,
                                 cmap='viridis', s=100, edgecolors='black', linewidth=1.5)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(gain_key, rotation=270, labelpad=20)
        else:
            # No single varying gain, use uniform color
            ax.scatter(total_times, overshoots, s=100, edgecolors='black', linewidth=1.5)

        # Annotate points with gain values
        fmt = self._format_gain
        for x, y, i in zip(total_times, overshoots, test_indices):
            gains = self.tests[i]['gains']
            label = f"P={fmt(gains['kp'])}\nI={fmt(gains['ki'])}\nD={fmt(gains['kd'])}"
            ax.annotate(label, (x, y), xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.7)

        ax.set_xlabel('Total Time (steps)', fontsize=12)
        ax.set_ylabel('Overshoot (%)', fontsize=12)
        ax.set_title(f"{self.title}\nSpeed vs Stability Tradeoff", fontsize=14)
        ax.grid(True, alpha=0.3)

        # Add arrow pointing to ideal region (bottom-left)
        ax.annotate('Ideal\n(Fast & Stable)', xy=(min(total_times), min(overshoots)),
                    xytext=(0.05, 0.95), textcoords='axes fraction',
                    fontsize=10, color='green', weight='bold',
                    ha='left', va='top',
                    arrowprops=dict(arrowstyle='->', color='green', lw=2))

        fig.tight_layout()

        return fig, ax

    def plot_trajectory(self, test_index, trial_index=0):
        """Plot detailed time histories for one trial.

        Shows five stacked panels against time in steps: position (with the
        target), error, control signal, velocity and torque.

        Args:
            test_index (int): Index of the test in self.tests.
            trial_index (int): Index of the trial within that test. Defaults to 0.

        Returns:
            tuple: (fig, axes) Matplotlib figure and its 5 axes.

        Raises:
            IndexError: If test_index or trial_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        test = self.tests[test_index]

        if trial_index >= len(test['trials']):
            raise IndexError(f"Trial index {trial_index} out of range")

        trial = test['trials'][trial_index]
        gains = test['gains']
        time_data = trial['time_data']
        target = trial['target']

        fig, axes = plt.subplots(5, 1, figsize=(12, 10))

        def style_panel(ax, ylabel):
            """Apply the common labels/grid to one time-series panel."""
            ax.set_xlabel('Time (steps)')
            ax.set_ylabel(ylabel)
            ax.set_title(f'{ylabel} vs Time')
            ax.grid(True, alpha=0.3)

        # Panel 1: position against the target.
        axes[0].plot(time_data, trial['position_data'], 'b-', linewidth=2, label='Actual Position')
        axes[0].axhline(y=target, color='r', linestyle='--', linewidth=1, label='Target')
        style_panel(axes[0], 'Position')
        axes[0].legend()

        # Panels 2-5: zero-referenced signals.
        panels = [
            (axes[1], trial['error_data'], 'g-', 'Error'),
            (axes[2], trial['control_data'], 'orange', 'Control Signal'),
            (axes[3], trial['velocity_data'], 'purple', 'Velocity'),
            (axes[4], trial['torque_data'], 'brown', 'Torque'),
        ]
        for ax, series, color, name in panels:
            # A series may be shorter than time_data (e.g. control_data when
            # the trial stopped before issuing a final command), so trim x.
            ax.plot(time_data[:len(series)], series, color, linewidth=2)
            ax.axhline(y=0, color='k', linestyle='--', linewidth=1)
            style_panel(ax, name)

        fig.suptitle(
            f"{self.title}\nKp={gains['kp']}, Ki={gains['ki']}, Kd={gains['kd']} | Trial {trial_index}",
            fontsize=12
        )
        fig.tight_layout()

        return fig, axes

    def to_dict(self):
        """Convert experiment to dictionary for JSON serialization.

        Returns:
            dict: Dictionary representation of the experiment.
        """
        return {
            'id': self.id,
            'title': self.title,
            'hypothesis': self.hypothesis,
            'axis': self.axis,
            'timestamp': self.timestamp,
            'tests': self.tests,
            'invert_output': self.invert_output,
        }
715
716
class ExperimentLog:
    """Manage a JSON-backed log of PID experiments.

    The file has the shape ``{'experiments': [<Experiment.to_dict()>, ...]}``.

    Args:
        filepath (str or Path): Path to the JSON log file.
    """

    def __init__(self, filepath):
        self.filepath = Path(filepath)
        self.experiments = []

    def add_experiment(self, experiment):
        """Add an Experiment to the log.

        Args:
            experiment (Experiment): Experiment object to add.

        Raises:
            TypeError: If experiment is not an Experiment object.
        """
        if not isinstance(experiment, Experiment):
            raise TypeError("Must add an Experiment object")
        self.experiments.append(experiment)

    @staticmethod
    def _json_default(obj):
        """Serialize NumPy values that the ``json`` module cannot handle natively.

        NumPy scalars become plain Python numbers and arrays become lists.
        Anything else still raises TypeError, exactly as ``json`` would.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _read_existing(self):
        """Return the experiment dicts currently stored in the log file ([] if no file)."""
        if not self.filepath.exists():
            return []
        with open(self.filepath, 'r', encoding='utf-8') as f:
            return json.load(f).get('experiments', [])

    def save(self):
        """Save experiments to JSON file, preserving existing experiments.

        Experiments already in the file are kept. Those with the same ID as
        an in-memory experiment are replaced by the in-memory version. The
        new content is fully serialized before the file is touched and is
        then swapped in atomically, so a failure cannot truncate the
        existing log.

        Raises:
            TypeError: If experiment data contains values that are not JSON
                serializable (NumPy scalars/arrays are converted).
        """
        existing_by_id = {exp['id']: exp for exp in self._read_existing()}

        # Add or update with current experiments
        for exp in self.experiments:
            existing_by_id[exp.id] = exp.to_dict()

        data = {'experiments': list(existing_by_id.values())}

        # Serialize first: an error here must not clobber the existing file.
        payload = json.dumps(data, indent=2, default=self._json_default)

        # Write a sibling temp file, then replace, so a crash mid-write cannot
        # leave a half-written log behind.
        tmp_path = self.filepath.with_name(self.filepath.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        tmp_path.replace(self.filepath)

        print(f"Saved {len(self.experiments)} experiments to {self.filepath}")

    @classmethod
    def load(cls, filepath):
        """Load experiments from JSON file.

        Args:
            filepath (str or Path): Path to the JSON log file.

        Returns:
            ExperimentLog: Log with loaded Experiment objects (empty if the
            file does not exist).
        """
        log = cls(filepath)

        if not log.filepath.exists():
            print(f"No existing file at {filepath}, starting new log")
            return log

        log.experiments.extend(cls._from_dict(exp_dict) for exp_dict in log._read_existing())

        print(f"Loaded {len(log.experiments)} experiments from {filepath}")
        return log

    @staticmethod
    def _from_dict(exp_dict):
        """Reconstruct an Experiment object from a dictionary.

        Args:
            exp_dict (dict): Dictionary representation of an experiment.

        Returns:
            Experiment: Reconstructed Experiment object with its original
            ID and timestamp.
        """
        exp = Experiment(
            title=exp_dict['title'],
            hypothesis=exp_dict['hypothesis'],
            axis=exp_dict['axis'],
            invert_output=exp_dict.get('invert_output', False)
        )
        # Restore identity fields overwritten by Experiment.__init__.
        exp.id = exp_dict['id']
        exp.timestamp = exp_dict['timestamp']
        exp.tests = exp_dict.get('tests', [])
        return exp

    def get_experiment(self, identifier):
        """Get experiment by title or short hash (first 6 characters of ID).

        Args:
            identifier (str): Either the full/partial title or first 6 chars of ID.

        Returns:
            Experiment or None: The first match, trying the short ID before
            the title (case-insensitive substring). None if nothing matches.
        """
        if len(identifier) == 6:
            for exp in self.experiments:
                if exp.id[:6] == identifier:
                    return exp

        identifier_lower = identifier.lower()
        for exp in self.experiments:
            if identifier_lower in exp.title.lower():
                return exp

        return None

    def list_experiments(self):
        """Print summary table of all experiments."""
        if not self.experiments:
            print("No experiments in log")
            return

        print("\n" + "=" * 100)
        print(f"{'ID':>8} {'Title':<30} {'Hypothesis':<40} {'Tests':>6}")
        print("=" * 100)

        for exp in self.experiments:
            short_id = exp.id[:6]
            # Truncate long text to the column width, marking the cut with '..'.
            title = exp.title[:28] + '..' if len(exp.title) > 30 else exp.title
            hypothesis = exp.hypothesis[:38] + '..' if len(exp.hypothesis) > 40 else exp.hypothesis
            num_tests = len(exp.tests)

            print(f"{short_id:>8} {title:<30} {hypothesis:<40} {num_tests:>6}")

        print("=" * 100)
def calculate_metrics(position_data, time_data, target):
    """Calculate step-response performance metrics for a PID trial.

    Args:
        position_data (list): Position samples over time.
        time_data (list): Time value of each sample in ``position_data``
            (same length).
        target (float): Target position.

    Returns:
        dict: With keys
            ``rise_time``: time taken to go from 10% to 90% of the step, or
                None if the 90% level is never reached.
            ``settling_time``: time of the first sample after which the
                response stays inside a band of +/-2% of the step size around
                ``target``; if the final sample is still outside the band the
                last time value is reported.
            ``overshoot``: peak excursion past ``target`` as a percentage of
                the step size (0.0 when there is none).
            ``steady_state_error``: absolute error of the last sample.
        Empty ``position_data`` yields None times/error; a zero-size step
        (start already at ``target``) yields 0.0 times and overshoot.
    """
    if not position_data:
        return {
            'rise_time': None,
            'settling_time': None,
            'overshoot': 0.0,
            'steady_state_error': None
        }

    position_array = np.array(position_data)
    start_val = position_data[0]
    step = target - start_val
    steady_state_error = abs(position_array[-1] - target)

    # Zero-size step: nothing to rise or settle to (and avoids dividing by 0).
    if abs(step) < 1e-12:
        return {
            'rise_time': 0.0,
            'settling_time': 0.0,
            'overshoot': 0.0,
            'steady_state_error': steady_state_error
        }

    moving_positive = step > 0

    def reached(pos, level):
        # "Reached" depends on the direction of the step.
        return pos >= level if moving_positive else pos <= level

    # Rise time: first crossing of the 10% level to first crossing of 90%.
    ten_percent = start_val + 0.1 * step
    ninety_percent = start_val + 0.9 * step
    ten_idx = next((i for i, pos in enumerate(position_data) if reached(pos, ten_percent)), None)
    ninety_idx = next((i for i, pos in enumerate(position_data) if reached(pos, ninety_percent)), None)
    rise_time = (time_data[ninety_idx] - time_data[ten_idx]
                 if ten_idx is not None and ninety_idx is not None else None)

    # Settling time (2% band): the response is settled from the sample
    # *after* the last one that lies outside the band.
    settling_band = abs(step) * 0.02
    outside = np.flatnonzero(np.abs(position_array - target) > settling_band)
    if outside.size == 0:
        settling_time = 0
    else:
        last_out = int(outside[-1])
        if last_out + 1 < len(position_data):
            settling_time = time_data[last_out + 1]
        else:
            # Still outside the band at the end: never settled within the record.
            settling_time = time_data[last_out]

    # Overshoot: how far the response went past the target, in % of the step.
    if moving_positive:
        peak_excess = np.max(position_array) - target
    else:
        peak_excess = target - np.min(position_array)
    overshoot = peak_excess / abs(step) * 100 if peak_excess > 0 else 0.0

    return {
        'rise_time': rise_time,
        'settling_time': settling_time,
        'overshoot': overshoot,
        'steady_state_error': steady_state_error
    }

Calculate all PID performance metrics.

Arguments:
  • position_data (list): Position values over time.
  • time_data (list): Time value for each sample in position_data (same length).
  • target (float): Target position.
Returns:

dict: Dictionary containing rise_time, settling_time, overshoot, steady_state_error.

class PIDTester:
    """Run step-response trials of a PID controller on one robot axis.

    Args:
        sim (Simulation): The simulation object.
        axis (str): Which axis to test ('x', 'y', or 'z'). Defaults to 'x'.
        limits (dict, optional): Workspace description with 'x_min', 'x_max',
            'y_min', 'y_max', 'z_min', 'z_max' and 'center' keys (the shape
            returned by find_workspace). When omitted, it is discovered with
            find_workspace() on first use.
        invert_output (bool): Whether to invert the PID output for this axis.
            Defaults to False.
    """

    # Command applied along the tested axis, and for how many simulation
    # steps, to bring the robot to a repeatable start before each trial.
    _HOME_COMMAND = -1000
    _HOME_STEPS = 100

    def __init__(self, sim, axis='x', limits=None, invert_output=False):
        self.sim = sim
        self.axis = axis.lower()
        self.axis_map = {'x': 0, 'y': 1, 'z': 2}
        # Routed through the ``limits`` property setter below.
        self.limits = limits
        self.invert_output = invert_output

    @staticmethod
    def _workspace_to_limits(workspace):
        """Convert a find_workspace()-style dict into per-axis [min, max] pairs."""
        return {
            'x': [workspace['x_min'], workspace['x_max']],
            'y': [workspace['y_min'], workspace['y_max']],
            'z': [workspace['z_min'], workspace['z_max']],
            'center': workspace['center']
        }

    @property
    def limits(self):
        """dict: Per-axis [min, max] limits plus 'center', discovered lazily."""
        if self._limits is None:
            self._limits = self._workspace_to_limits(find_workspace(self.sim))
        return self._limits

    @limits.setter
    def limits(self, workspace=None):
        # None defers discovery to the first read of ``limits``.
        self._limits = None if workspace is None else self._workspace_to_limits(workspace)

    def _run_single_trial(self, kp, ki, kd, dt, target_pos,
                          max_steps, tolerance, settle_steps,
                          reset=True, verbose=False):
        """Run a single trial of PID test.

        The trial stops early once the position has stayed within
        ``tolerance`` of the target for ``settle_steps`` consecutive samples.

        Args:
            kp (float): Proportional gain.
            ki (float): Integral gain.
            kd (float): Derivative gain.
            dt (float): Time step.
            target_pos (list): Target position [x, y, z].
            max_steps (int): Maximum number of simulation steps.
            tolerance (float): Position tolerance for settling.
            settle_steps (int): Number of steps within tolerance to consider settled.
            reset (bool): Whether to reset simulation before trial. Defaults to True.
            verbose (bool): Print debug information. Defaults to False.

        Returns:
            dict: Single trial results containing metrics, position data, and
            metadata. ``start_position`` is the tested-axis position where the
            PID loop begins (after homing); ``reset_position`` is the position
            just after the optional reset, before homing.

        Raises:
            ValueError: If the target lies outside this axis' limits.
            KeyError: If ``self.axis`` is not 'x', 'y' or 'z'.
        """
        axis_idx = self.axis_map[self.axis]
        target = target_pos[axis_idx]

        # Validate before touching the simulation. Reading ``self.limits`` may
        # call find_workspace(self.sim), which presumably interacts with the
        # sim, so it must not run between the reset and the trial.
        axis_limits = self.limits[self.axis]
        if target < axis_limits[0] or target > axis_limits[1]:
            raise ValueError(f'Target {target:.4f} outside of {self.axis}-axis range [{axis_limits[0]:.4f}, {axis_limits[1]:.4f}]')

        if reset:
            self.sim.reset()
        reset_pos = get_position(self.sim)

        # Move to starting position
        actions = [[0, 0, 0, 0]]
        actions[0][axis_idx] = self._HOME_COMMAND
        self.sim.run(actions, self._HOME_STEPS)

        # The step response starts here, after the homing move.
        start_pos = get_position(self.sim)

        pid = PID(kp=kp, ki=ki, kd=kd, setpoint=target, invert_output=self.invert_output)

        # Data collection
        position_data = []
        time_data = []
        error_data = []
        control_data = []
        velocity_data = []
        torque_data = []

        # Consecutive samples within tolerance of the target.
        steps_in_tolerance = 0

        for step in range(max_steps):
            current_pos = get_position(self.sim)
            current_val = current_pos[axis_idx]
            current_vel = get_velocities(self.sim)
            current_torque = get_torques(self.sim)

            if verbose:
                print(f'current_val {self.axis} axis: {current_val:.5f} -> {target:.5f}')

            position_data.append(current_val)
            time_data.append(step)
            error_data.append(target - current_val)
            velocity_data.append(current_vel[axis_idx])
            torque_data.append(current_torque[axis_idx])

            # Check position remains near target
            if abs(current_val - target) <= tolerance:
                steps_in_tolerance += 1
                if verbose:
                    print(f'settling {steps_in_tolerance} of {settle_steps}')
            else:
                steps_in_tolerance = 0

            # Stop before issuing another command once settled; this is why
            # control_data can be one sample shorter than position_data.
            if steps_in_tolerance >= settle_steps:
                break

            # Calculate PID output
            vx = pid(current_val, dt)
            control_data.append(vx)

            actions = [[0, 0, 0, 0]]
            actions[0][axis_idx] = vx
            self.sim.run(actions, num_steps=1)

        metrics = calculate_metrics(position_data, time_data, target)

        return {
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics,
            'position_data': position_data,
            'time_data': time_data,
            'error_data': error_data,
            'control_data': control_data,
            'velocity_data': velocity_data,
            'torque_data': torque_data,
            'target': target,
            'start_position': start_pos[axis_idx],
            'reset_position': reset_pos[axis_idx],
            'settled': steps_in_tolerance >= settle_steps,
            'invert_output': self.invert_output,
            'total_time': len(time_data),
        }

    def test_gains(self, kp, ki, kd, dt, target_pos,
                   max_steps, tolerance, settle_steps, num_trials,
                   reset=True, verbose=False):
        """Test a set of PID gains over multiple trials.

        Args:
            kp (float): Proportional gain.
            ki (float): Integral gain.
            kd (float): Derivative gain.
            dt (float): Time step.
            target_pos (list): Target position [x, y, z].
            max_steps (int): Maximum number of simulation steps.
            tolerance (float): Position tolerance for settling.
            settle_steps (int): Number of steps within tolerance to consider settled.
            num_trials (int): Number of trials to run.
            reset (bool): Whether to reset simulation before each trial. Defaults to True.
            verbose (bool): Print debug information. Defaults to False.

        Returns:
            dict: Test results with gains, timestamp, and trial data.
        """
        trials = [
            self._run_single_trial(
                kp, ki, kd, dt, target_pos,
                max_steps, tolerance, settle_steps,
                reset=reset, verbose=verbose
            )
            for _ in range(num_trials)
        ]

        return {
            'gains': {'kp': kp, 'ki': ki, 'kd': kd},
            'timestamp': datetime.now().isoformat(),
            'num_trials': num_trials,
            'trials': trials
        }

PID controller testing class.

Arguments:
  • sim (Simulation): The simulation object.
  • axis (str): Which axis to test ('x', 'y', or 'z'). Defaults to 'x'.
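  • limits (dict, optional): Workspace description with x_min/x_max, y_min/y_max, z_min/z_max and center keys, as returned by find_workspace; discovered automatically via find_workspace on first use if not provided.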
  • invert_output (bool): Whether to invert the PID output for this axis. Defaults to False.
PIDTester(sim, axis='x', limits=None, invert_output=False)
def __init__(self, sim, axis='x', limits=None, invert_output=False):
    # Keep the simulation handle and the PID output-inversion flag as given.
    self.sim = sim
    self.invert_output = invert_output
    # Axis names are case-insensitive; axis_map turns a name into the index
    # used for position/velocity/torque vectors and action rows.
    self.axis_map = dict(x=0, y=1, z=2)
    self.axis = axis.lower()
    # On the class this assignment goes through the ``limits`` property
    # setter, which converts a workspace dict or stores None for lazy lookup.
    self.limits = limits
sim
axis
axis_map
limits
@property
def limits(self):
    """Workspace limits for every axis, discovered lazily.

    While ``self._limits`` is None, the workspace is measured with
    ``find_workspace(self.sim)`` and cached as
    ``{'x': [min, max], 'y': [min, max], 'z': [min, max], 'center': ...}``.
    Later reads return the cached dict.
    """
    if self._limits is not None:
        return self._limits
    measured = find_workspace(self.sim)
    bounds = {name: [measured[f'{name}_min'], measured[f'{name}_max']]
              for name in ('x', 'y', 'z')}
    bounds['center'] = measured['center']
    # Cache only once the whole dict has been built successfully.
    self._limits = bounds
    return self._limits
invert_output
def test_gains(self, kp, ki, kd, dt, target_pos,
               max_steps, tolerance, settle_steps, num_trials,
               reset=True, verbose=False):
    """Run ``num_trials`` identical trials for one (kp, ki, kd) combination.

    Args:
        kp (float): Proportional gain.
        ki (float): Integral gain.
        kd (float): Derivative gain.
        dt (float): Time step.
        target_pos (list): Target position [x, y, z].
        max_steps (int): Maximum number of simulation steps per trial.
        tolerance (float): Position tolerance for settling.
        settle_steps (int): Consecutive in-tolerance steps that count as settled.
        num_trials (int): Number of trials to run.
        reset (bool): Whether to reset the simulation before each trial.
            Defaults to True.
        verbose (bool): Print debug information. Defaults to False.

    Returns:
        dict: ``gains``, ``timestamp``, ``num_trials`` and the list of
        per-trial results under ``trials``.
    """
    trial_results = [
        self._run_single_trial(kp, ki, kd, dt, target_pos,
                               max_steps, tolerance, settle_steps,
                               reset=reset, verbose=verbose)
        for _ in range(num_trials)
    ]
    return {
        'gains': {'kp': kp, 'ki': ki, 'kd': kd},
        'timestamp': datetime.now().isoformat(),
        'num_trials': num_trials,
        'trials': trial_results,
    }

Test a set of PID gains over multiple trials.

Arguments:
  • kp (float): Proportional gain.
  • ki (float): Integral gain.
  • kd (float): Derivative gain.
  • dt (float): Time step.
  • target_pos (list): Target position [x, y, z].
  • max_steps (int): Maximum number of simulation steps.
  • tolerance (float): Position tolerance for settling.
  • settle_steps (int): Number of steps within tolerance to consider settled.
  • num_trials (int): Number of trials to run.
  • reset (bool): Whether to reset simulation before each trial. Defaults to True.
  • verbose (bool): Print debug information. Defaults to False.
Returns:

dict: Test results with gains, timestamp, and trial data.

class Experiment:
    """Track a series of PID gain tests run on a single axis.

    Args:
        title (str): Title of the experiment.
        hypothesis (str): Hypothesis being tested.
        axis (str): Axis being tested ('x', 'y', or 'z'); stored lower-cased.
        invert_output (bool): Whether output was inverted for this axis. Defaults to False.
    """

    # Metric names produced by calculate_metrics(), in display order.
    _METRIC_KEYS = ('rise_time', 'settling_time', 'overshoot', 'steady_state_error')

    def __init__(self, title, hypothesis, axis, invert_output=False):
        self.id = str(uuid.uuid4())
        self.title = title
        self.hypothesis = hypothesis
        self.axis = axis.lower()
        self.timestamp = datetime.now().isoformat()
        self.tests = []
        self.invert_output = invert_output

    @staticmethod
    def _format_gain(value):
        """Format a gain compactly: no decimals when >= 1, two decimals otherwise."""
        return f"{value:.0f}" if value >= 1 else f"{value:.2f}"

    @staticmethod
    def _metric_value(metrics, key):
        """Return (value, spread) for one metric of get_averaged_metrics() output.

        Multi-trial entries are dicts with 'mean'/'std'; single-trial entries
        are plain numbers (spread 0); a metric with no valid trials is None.
        """
        entry = metrics[key]
        if isinstance(entry, dict):
            return entry['mean'], entry['std']
        return entry, 0

    def add_test(self, test_result):
        """Add a test result from PIDTester.test_gains().

        Args:
            test_result (dict): Result dictionary from test_gains().

        Raises:
            ValueError: If test_result doesn't contain required keys.
        """
        if 'gains' not in test_result or 'trials' not in test_result:
            raise ValueError("Invalid test_result: must contain 'gains' and 'trials'")

        self.tests.append(test_result)

    def get_averaged_metrics(self, test_index):
        """Get averaged metrics across trials for a specific test.

        Args:
            test_index (int): Index of the test in self.tests.

        Returns:
            dict: For a single-trial test, that trial's own metrics dict (plain
            values). Otherwise, for each metric, a dict with mean, std, min and
            max over the trials whose value is not None, or None if no trial
            has a value.

        Raises:
            IndexError: If test_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        trials = self.tests[test_index]['trials']

        if len(trials) == 1:
            return trials[0]['metrics']

        averaged = {}
        for key in self._METRIC_KEYS:
            values = [t['metrics'][key] for t in trials if t['metrics'][key] is not None]
            if values:
                averaged[key] = {
                    'mean': np.mean(values),
                    'std': np.std(values),
                    'min': np.min(values),
                    'max': np.max(values)
                }
            else:
                averaged[key] = None

        return averaged

    def print_summary(self):
        """Print a table of averaged metrics for every test in this experiment.

        Each row shows the gains, the mean (across trials) rise time, settling
        time, overshoot and steady-state error, and the step count of the
        first trial. Missing values are shown as 'N/A'.
        """
        rule = "=" * 90
        print(f"\nExperiment: {self.title}")
        print(f"Hypothesis: {self.hypothesis}")
        print(f"Axis: {self.axis}")
        print(f"Inverted Output: {self.invert_output}")
        print(f"Tests: {len(self.tests)}")
        print("\n" + rule)
        print(f"{'Kp':>8} {'Ki':>8} {'Kd':>8} {'Rise':>10} {'Settling':>10} "
              f"{'Overshoot':>10} {'SS Error':>10} {'Total Steps':>11}")
        print(rule)

        def cell(value, spec, width):
            # None (e.g. no trial produced this metric) renders as 'N/A'.
            return f"{'N/A':>{width}}" if value is None else format(value, spec)

        for i, test in enumerate(self.tests):
            gains = test['gains']
            metrics = self.get_averaged_metrics(i)
            # Extract each metric independently: some may be averaged dicts
            # while others are None when no trial produced a value.
            rise, _ = self._metric_value(metrics, 'rise_time')
            settling, _ = self._metric_value(metrics, 'settling_time')
            overshoot, _ = self._metric_value(metrics, 'overshoot')
            ss_error, _ = self._metric_value(metrics, 'steady_state_error')

            trials = test['trials']
            total_time = trials[0]['total_time'] if trials else None

            print(f"{gains['kp']:>8.2f} {gains['ki']:>8.2f} {gains['kd']:>8.2f} "
                  f"{cell(rise, '>10.2f', 10)} {cell(settling, '>10.2f', 10)} "
                  f"{cell(overshoot, '>10.2f', 10)} {cell(ss_error, '>10.6f', 10)} "
                  f"{cell(total_time, '>11', 11)}")

        print(rule)

    def _determine_varying_gain(self):
        """Determine which gain parameter varies across tests.

        Returns:
            tuple: (gain_key, x_values, x_labels) where:
                - gain_key (str or None): 'Kp', 'Ki', 'Kd', or None if none or
                  several vary.
                - x_values (list): The varying gain's values, or test indices.
                - x_labels (list): String labels for x-axis.
        """
        gains_list = [test['gains'] for test in self.tests]
        indices = list(range(len(self.tests)))

        series = {name: [g[name.lower()] for g in gains_list] for name in ('Kp', 'Ki', 'Kd')}
        varying = [name for name, values in series.items() if len(set(values)) > 1]

        if not varying:
            # All tests have the same gains.
            return None, indices, [str(i) for i in indices]
        if len(varying) == 1:
            name = varying[0]
            values = series[name]
            return name, values, [self._format_gain(v) for v in values]

        # Several gains vary: use test index with composite labels.
        labels = [
            f"P={self._format_gain(g['kp'])},"
            f"I={self._format_gain(g['ki'])},"
            f"D={self._format_gain(g['kd'])}"
            for g in gains_list
        ]
        return None, indices, labels

    def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'):
        """Plot averaged metrics across all tests.

        Args:
            metric_names (list of str, str or 'all', optional): Which metrics to
                plot for line plots. 'all' or None plots every metric; a single
                metric name may be given as a plain string. Ignored for scatter
                plots. Options: 'rise_time', 'settling_time', 'overshoot',
                'steady_state_error'.
            separate_subplots (bool): If True, each metric gets its own subplot.
                If False, all on one plot. Only applies to line plots. Defaults to True.
            plot_type (str): 'lines' for metric trends or 'scatter' for the
                speed vs stability tradeoff. Defaults to 'lines'.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests.
        """
        if not self.tests:
            print("No tests to plot")
            return None, None

        if plot_type == 'scatter':
            return self._plot_speed_vs_stability()

        if metric_names == 'all' or metric_names is None:
            metric_names = list(self._METRIC_KEYS)
        elif isinstance(metric_names, str):
            metric_names = [metric_names]

        gain_key, x_values, x_labels = self._determine_varying_gain()
        x_label = gain_key if gain_key else 'Test Index'
        # A log axis cannot show non-positive positions (test index 0, a zero
        # gain), so only use it for a varying gain whose values are all > 0.
        use_log = gain_key is not None and all(x > 0 for x in x_values)

        data = {metric: [] for metric in metric_names}
        errors = {metric: [] for metric in metric_names}
        valid_indices = {metric: [] for metric in metric_names}

        for i in range(len(self.tests)):
            metrics = self.get_averaged_metrics(i)
            for metric in metric_names:
                value, spread = self._metric_value(metrics, metric)
                if value is not None:
                    data[metric].append(value)
                    errors[metric].append(spread)
                    valid_indices[metric].append(i)

        if separate_subplots:
            fig, axes = plt.subplots(len(metric_names), 1, figsize=(10, 4 * len(metric_names)))
            if len(metric_names) == 1:
                axes = [axes]

            for ax, metric in zip(axes, metric_names):
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
                    valid_x_labels = [x_labels[i] for i in valid_indices[metric]]
                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2)
                    if use_log:
                        ax.set_xscale('log')
                    ax.set_xticks(valid_x_values)
                    ax.set_xticklabels(valid_x_labels)
                    ax.grid(True, alpha=0.3, which='both')
                else:
                    ax.text(0.5, 0.5, f'No valid data for {metric}',
                            ha='center', va='center', transform=ax.transAxes)
                ax.set_xlabel(x_label)
                ax.set_ylabel(metric.replace('_', ' ').title())
            result_axes = axes
        else:
            fig, ax = plt.subplots(figsize=(10, 6))
            plotted = False
            for metric in metric_names:
                if data[metric]:
                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
                                marker='o', capsize=5, capthick=2,
                                label=metric.replace('_', ' ').title())
                    plotted = True

            ax.set_xlabel(x_label)
            ax.set_ylabel('Value')
            if use_log:
                ax.set_xscale('log')
            if x_values:
                ax.set_xticks(x_values)
                ax.set_xticklabels(x_labels)
            if plotted:
                # Avoid an empty-legend warning when nothing was drawn.
                ax.legend()
            ax.grid(True, alpha=0.3, which='both')
            result_axes = ax

        fig.suptitle(f"{self.title}\n{self.hypothesis}", fontsize=12)
        fig.tight_layout()

        return fig, result_axes

    def _plot_speed_vs_stability(self):
        """Create scatter plot of speed (total time) vs stability (overshoot).

        Returns:
            tuple: (fig, ax) Matplotlib figure and axis objects, or (None, None)
            if no test has both values.
        """
        total_times = []
        overshoots = []
        test_indices = []

        for i, test in enumerate(self.tests):
            overshoot, _ = self._metric_value(self.get_averaged_metrics(i), 'overshoot')
            # Total time is taken from the first trial only.
            trials = test['trials']
            total_time = trials[0]['total_time'] if trials else None

            if overshoot is not None and total_time is not None:
                total_times.append(total_time)
                overshoots.append(overshoot)
                test_indices.append(i)

        if not total_times:
            print("No valid data for scatter plot")
            return None, None

        gain_key, varying_gains, _ = self._determine_varying_gain()

        fig, ax = plt.subplots(figsize=(10, 8))

        if gain_key:
            # Color each point by the value of the single varying gain.
            color_values = [varying_gains[i] for i in test_indices]
            scatter = ax.scatter(total_times, overshoots, c=color_values,
                                 cmap='viridis', s=100, edgecolors='black', linewidth=1.5)
            cbar = plt.colorbar(scatter, ax=ax)
            cbar.set_label(gain_key, rotation=270, labelpad=20)
        else:
            ax.scatter(total_times, overshoots, s=100, edgecolors='black', linewidth=1.5)

        for x, y, i in zip(total_times, overshoots, test_indices):
            gains = self.tests[i]['gains']
            label = (f"P={self._format_gain(gains['kp'])}\n"
                     f"I={self._format_gain(gains['ki'])}\n"
                     f"D={self._format_gain(gains['kd'])}")
            ax.annotate(label, (x, y), xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.7)

        ax.set_xlabel('Total Time (steps)', fontsize=12)
        ax.set_ylabel('Overshoot (%)', fontsize=12)
        ax.set_title(f"{self.title}\nSpeed vs Stability Tradeoff", fontsize=14)
        ax.grid(True, alpha=0.3)

        # Arrow pointing to the ideal (bottom-left) region.
        ax.annotate('Ideal\n(Fast & Stable)', xy=(min(total_times), min(overshoots)),
                    xytext=(0.05, 0.95), textcoords='axes fraction',
                    fontsize=10, color='green', weight='bold',
                    ha='left', va='top',
                    arrowprops=dict(arrowstyle='->', color='green', lw=2))

        fig.tight_layout()

        return fig, ax

    def plot_trajectory(self, test_index, trial_index=0):
        """Plot detailed trajectory data for a specific trial.

        Shows position, error, control signal, velocity and torque over time.

        Args:
            test_index (int): Index of the test in self.tests.
            trial_index (int): Index of the trial within that test. Defaults to 0.

        Returns:
            tuple: (fig, axes) Matplotlib figure and axes objects (5 subplots).

        Raises:
            IndexError: If test_index or trial_index is out of range.
        """
        if test_index >= len(self.tests):
            raise IndexError(f"Test index {test_index} out of range")

        test = self.tests[test_index]

        if trial_index >= len(test['trials']):
            raise IndexError(f"Trial index {trial_index} out of range")

        trial = test['trials'][trial_index]
        gains = test['gains']

        time_data = trial['time_data']
        velocity_data = trial['velocity_data']
        torque_data = trial['torque_data']
        position_data = trial['position_data']
        error_data = trial['error_data']
        control_data = trial['control_data']
        target = trial['target']

        # Five stacked subplots sharing the step axis.
        fig, axes = plt.subplots(5, 1, figsize=(12, 15))

        # Plot 1: Position vs Time
        axes[0].plot(time_data, position_data, 'b-', linewidth=2, label='Actual Position')
        axes[0].axhline(y=target, color='r', linestyle='--', linewidth=1, label='Target')
        axes[0].set_xlabel('Time (steps)')
        axes[0].set_ylabel('Position')
        axes[0].set_title('Position vs Time')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Plot 2: Error vs Time
        axes[1].plot(time_data, error_data, 'g-', linewidth=2)
        axes[1].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[1].set_xlabel('Time (steps)')
        axes[1].set_ylabel('Error')
        axes[1].set_title('Error vs Time')
        axes[1].grid(True, alpha=0.3)

        # Plot 3: Control Signal vs Time (may be one sample shorter than the
        # position record when the trial stopped after settling).
        axes[2].plot(time_data[:len(control_data)], control_data, 'orange', linewidth=2)
        axes[2].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[2].set_xlabel('Time (steps)')
        axes[2].set_ylabel('Control Signal')
        axes[2].set_title('Control Signal vs Time')
        axes[2].grid(True, alpha=0.3)

        # Plot 4: Velocity vs Time
        axes[3].plot(time_data, velocity_data, 'purple', linewidth=2)
        axes[3].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[3].set_xlabel('Time (steps)')
        axes[3].set_ylabel('Velocity')
        axes[3].set_title('Velocity vs Time')
        axes[3].grid(True, alpha=0.3)

        # Plot 5: Torque vs Time
        axes[4].plot(time_data[:len(torque_data)], torque_data, 'brown', linewidth=2)
        axes[4].axhline(y=0, color='k', linestyle='--', linewidth=1)
        axes[4].set_xlabel('Time (steps)')
        axes[4].set_ylabel('Torque')
        axes[4].set_title('Torque vs Time')
        axes[4].grid(True, alpha=0.3)

        fig.suptitle(
            f"{self.title}\nKp={gains['kp']}, Ki={gains['ki']}, Kd={gains['kd']} | Trial {trial_index}",
            fontsize=12
        )
        fig.tight_layout()

        return fig, axes

    def to_dict(self):
        """Convert experiment to dictionary for JSON serialization.

        Returns:
            dict: Dictionary representation of the experiment.
        """
        return {
            'id': self.id,
            'title': self.title,
            'hypothesis': self.hypothesis,
            'axis': self.axis,
            'timestamp': self.timestamp,
            'tests': self.tests,
            'invert_output': self.invert_output
        }

Create a new experiment to track PID testing.

Arguments:
  • title (str): Title of the experiment.
  • hypothesis (str): Hypothesis being tested.
  • axis (str): Axis being tested ('x', 'y', or 'z').
  • invert_output (bool): Whether output was inverted for this axis. Defaults to False.
Experiment(title, hypothesis, axis, invert_output=False)
def __init__(self, title, hypothesis, axis, invert_output=False):
    # Unique identity: a random UUID4 string plus the local creation time
    # in ISO-8601 form.
    self.id = str(uuid.uuid4())
    self.title, self.hypothesis = title, hypothesis
    self.axis = axis.lower()
    self.timestamp = datetime.now().isoformat()
    # Test results (dicts from PIDTester.test_gains()) are appended by add_test().
    self.tests = []
    self.invert_output = invert_output
id
title
hypothesis
axis
timestamp
tests
invert_output
def add_test(self, test_result):
303    def add_test(self, test_result):
304        """Add a test result from PIDTester.test_gains().
305        
306        Args:
307            test_result (dict): Result dictionary from test_gains().
308        
309        Raises:
310            ValueError: If test_result doesn't contain required keys.
311        """
312        # Basic validation
313        if 'gains' not in test_result or 'trials' not in test_result:
314            raise ValueError("Invalid test_result: must contain 'gains' and 'trials'")
315        
316        self.tests.append(test_result)

Add a test result from PIDTester.test_gains().

Arguments:
  • test_result (dict): Result dictionary from test_gains().
Raises:
  • ValueError: If test_result doesn't contain required keys.
def get_averaged_metrics(self, test_index):
318    def get_averaged_metrics(self, test_index):
319        """Get averaged metrics across trials for a specific test.
320        
321        Args:
322            test_index (int): Index of the test in self.tests.
323        
324        Returns:
325            dict: Averaged metrics with mean, std, min, max for each metric.
326        
327        Raises:
328            IndexError: If test_index is out of range.
329        """
330        if test_index >= len(self.tests):
331            raise IndexError(f"Test index {test_index} out of range")
332        
333        test = self.tests[test_index]
334        trials = test['trials']
335        
336        if len(trials) == 1:
337            return trials[0]['metrics']
338        
339        metric_keys = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error']
340        averaged = {}
341        
342        for key in metric_keys:
343            values = [t['metrics'][key] for t in trials if t['metrics'][key] is not None]
344            if values:
345                averaged[key] = {
346                    'mean': np.mean(values),
347                    'std': np.std(values),
348                    'min': np.min(values),
349                    'max': np.max(values)
350                }
351            else:
352                averaged[key] = None
353        
354        return averaged

Get averaged metrics across trials for a specific test.

Arguments:
  • test_index (int): Index of the test in self.tests.
Returns:

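dict: If the test has more than one trial, maps each metric (rise_time, settling_time, overshoot, steady_state_error) to a dict with mean, std, min, and max computed over the trials that produced a value, or to None if no trial did. If the test has exactly one trial, that trial's raw metrics dict is returned unchanged.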

Raises:
  • IndexError: If test_index is out of range.
def print_summary(self):
356    def print_summary(self):
357        """Print summary table of all tests in this experiment."""
358        print(f"\nExperiment: {self.title}")
359        print(f"Hypothesis: {self.hypothesis}")
360        print(f"Axis: {self.axis}")
361        print(f"Inverted Output: {self.invert_output}")
362        print(f"Tests: {len(self.tests)}")
363        print("\n" + "="*90)  # <-- MAKE WIDER
364        print(f"{'Kp':>8} {'Ki':>8} {'Kd':>8} {'Rise':>10} {'Settling':>10} {'Overshoot':>10} {'SS Error':>10} {'Total Steps':>8}")  # <-- ADD COLUMN
365        print("="*90)
366        
367        for i, test in enumerate(self.tests):
368            gains = test['gains']
369            metrics = self.get_averaged_metrics(i)
370            
371            # Handle both single trial (dict) and averaged (dict with 'mean')
372            if isinstance(metrics['rise_time'], dict):
373                rise = metrics['rise_time']['mean']
374                settling = metrics['settling_time']['mean']
375                overshoot = metrics['overshoot']['mean']
376                ss_error = metrics['steady_state_error']['mean']
377            else:
378                rise = metrics['rise_time']
379                settling = metrics['settling_time']
380                overshoot = metrics['overshoot']
381                ss_error = metrics['steady_state_error']
382            
383            # Get total time from first trial
384            total_time = test['trials'][0]['total_time']
385            
386            # Format None values as 'N/A' instead of trying to format as float
387            rise_str = f"{rise:>10.2f}" if rise is not None else f"{'N/A':>10}"
388            settling_str = f"{settling:>10.2f}" if settling is not None else f"{'N/A':>10}"
389            overshoot_str = f"{overshoot:>10.2f}" if overshoot is not None else f"{'N/A':>10}"
390            ss_error_str = f"{ss_error:>10.6f}" if ss_error is not None else f"{'N/A':>10}"
391            
392            print(f"{gains['kp']:>8.2f} {gains['ki']:>8.2f} {gains['kd']:>8.2f} "
393                f"{rise_str} {settling_str} {overshoot_str} {ss_error_str} {total_time:>8}")  # <-- ADD TOTAL
394        
395        print("="*90)

Print summary table of all tests in this experiment.

def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'):
439    def plot_metrics_summary(self, metric_names='all', separate_subplots=True, plot_type='lines'):
440        """Plot averaged metrics across all tests.
441        
442        Args:
443            metric_names (list of str or 'all', optional): Which metrics to plot for line plots.
444                If 'all' or None, plots all metrics. Ignored for scatter plots.
445                Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'.
446            separate_subplots (bool): If True, each metric gets its own subplot. 
447                If False, all on one plot. Only applies to line plots. Defaults to True.
448            plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for 
449                speed vs stability tradeoff. Defaults to 'lines'.
450        
451        Returns:
452            tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests.
453        """
454        if not self.tests:
455            print("No tests to plot")
456            return None, None
457        
458        # Scatter plot: speed vs stability
459        if plot_type == 'scatter':
460            return self._plot_speed_vs_stability()
461        
462        # Line plot: existing behavior
463        if metric_names == 'all' or metric_names is None:
464            metric_names = ['rise_time', 'settling_time', 'overshoot', 'steady_state_error']
465        
466        # Determine which gain varies
467        gain_key, x_values, x_labels = self._determine_varying_gain()
468        
469        # Collect data
470        data = {metric: [] for metric in metric_names}
471        errors = {metric: [] for metric in metric_names}
472        valid_indices = {metric: [] for metric in metric_names}
473        
474        for i, test in enumerate(self.tests):
475            metrics = self.get_averaged_metrics(i)
476            for metric in metric_names:
477                if isinstance(metrics[metric], dict):
478                    value = metrics[metric]['mean']
479                    error = metrics[metric]['std']
480                else:
481                    value = metrics[metric]
482                    error = 0
483                
484                if value is not None:
485                    data[metric].append(value)
486                    errors[metric].append(error)
487                    valid_indices[metric].append(i)
488        
489        # Create plots
490        if separate_subplots:
491            fig, axes = plt.subplots(len(metric_names), 1, figsize=(10, 4*len(metric_names)))
492            if len(metric_names) == 1:
493                axes = [axes]
494            
495            for ax, metric in zip(axes, metric_names):
496                if data[metric]:
497                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
498                    valid_x_labels = [x_labels[i] for i in valid_indices[metric]]
499                    
500                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric], 
501                            marker='o', capsize=5, capthick=2)
502                    ax.set_xlabel(gain_key if gain_key else 'Test Index')
503                    ax.set_ylabel(metric.replace('_', ' ').title())
504                    ax.set_xscale('log')
505                    ax.set_xticks(valid_x_values)
506                    ax.set_xticklabels(valid_x_labels)
507                    ax.grid(True, alpha=0.3, which='both')
508                else:
509                    ax.text(0.5, 0.5, f'No valid data for {metric}', 
510                        ha='center', va='center', transform=ax.transAxes)
511                    ax.set_xlabel(gain_key if gain_key else 'Test Index')
512                    ax.set_ylabel(metric.replace('_', ' ').title())
513        else:
514            fig, ax = plt.subplots(figsize=(10, 6))
515            for metric in metric_names:
516                if data[metric]:
517                    valid_x_values = [x_values[i] for i in valid_indices[metric]]
518                    
519                    ax.errorbar(valid_x_values, data[metric], yerr=errors[metric],
520                            marker='o', capsize=5, capthick=2, label=metric.replace('_', ' ').title())
521            
522            ax.set_xlabel(gain_key if gain_key else 'Test Index')
523            ax.set_ylabel('Value')
524            ax.set_xscale('log')
525            if x_values:
526                ax.set_xticks(x_values)
527                ax.set_xticklabels(x_labels)
528            ax.legend()
529            ax.grid(True, alpha=0.3, which='both')
530        
531        fig.suptitle(f"{self.title}\n{self.hypothesis}", fontsize=12)
532        fig.tight_layout()
533        
534        return fig, axes if separate_subplots else ax

Plot averaged metrics across all tests.

Arguments:
  • metric_names (list of str or 'all', optional): Which metrics to plot for line plots. If 'all' or None, plots all metrics. Ignored for scatter plots. Options: 'rise_time', 'settling_time', 'overshoot', 'steady_state_error'.
  • separate_subplots (bool): If True, each metric gets its own subplot. If False, all on one plot. Only applies to line plots. Defaults to True.
  • plot_type (str): Type of plot - 'lines' for metric trends or 'scatter' for speed vs stability tradeoff. Defaults to 'lines'.
Returns:

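tuple: (fig, axes) Matplotlib figure and axes objects, or (None, None) if no tests. For line plots, the second element holds one Axes per metric when separate_subplots is True, or a single Axes when separate_subplots is False.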

def plot_trajectory(self, test_index, trial_index=0):
614    def plot_trajectory(self, test_index, trial_index=0):
615        """Plot detailed trajectory data for a specific trial.
616        
617        Shows position, error, and control signal over time.
618        
619        Args:
620            test_index (int): Index of the test in self.tests.
621            trial_index (int): Index of the trial within that test. Defaults to 0.
622        
623        Returns:
624            tuple: (fig, axes) Matplotlib figure and axes objects (3 subplots).
625        
626        Raises:
627            IndexError: If test_index or trial_index is out of range.
628        """
629        if test_index >= len(self.tests):
630            raise IndexError(f"Test index {test_index} out of range")
631        
632        test = self.tests[test_index]
633        
634        if trial_index >= len(test['trials']):
635            raise IndexError(f"Trial index {trial_index} out of range")
636        
637        trial = test['trials'][trial_index]
638        gains = test['gains']
639        
640        # Extract data
641        time_data = trial['time_data']
642        velocity_data = trial['velocity_data']
643        torque_data = trial['torque_data']
644        position_data = trial['position_data']
645        error_data = trial['error_data']
646        control_data = trial['control_data']
647        target = trial['target']
648        
649        # Create figure with 3 subplots
650        fig, axes = plt.subplots(5, 1, figsize=(12, 10))
651        
652        # Plot 1: Position vs Time
653        axes[0].plot(time_data, position_data, 'b-', linewidth=2, label='Actual Position')
654        axes[0].axhline(y=target, color='r', linestyle='--', linewidth=1, label='Target')
655        axes[0].set_xlabel('Time (steps)')
656        axes[0].set_ylabel('Position')
657        axes[0].set_title('Position vs Time')
658        axes[0].legend()
659        axes[0].grid(True, alpha=0.3)
660        
661        # Plot 2: Error vs Time
662        axes[1].plot(time_data, error_data, 'g-', linewidth=2)
663        axes[1].axhline(y=0, color='k', linestyle='--', linewidth=1)
664        axes[1].set_xlabel('Time (steps)')
665        axes[1].set_ylabel('Error')
666        axes[1].set_title('Error vs Time')
667        axes[1].grid(True, alpha=0.3)
668        
669        # Plot 3: Control Signal vs Time
670        axes[2].plot(time_data[:len(control_data)], control_data, 'orange', linewidth=2)
671        axes[2].axhline(y=0, color='k', linestyle='--', linewidth=1)
672        axes[2].set_xlabel('Time (steps)')
673        axes[2].set_ylabel('Control Signal')
674        axes[2].set_title('Control Signal vs Time')
675        axes[2].grid(True, alpha=0.3)
676
677        # Plot 4: Velocity vs Time
678        axes[3].plot(time_data, velocity_data, 'purple', linewidth=2)
679        axes[3].axhline(y=0, color='k', linestyle='--', linewidth=1)
680        axes[3].set_xlabel('Time (steps)')
681        axes[3].set_ylabel('Velocity')
682        axes[3].set_title('Velocity vs Time')
683        axes[3].grid(True, alpha=0.3)
684        
685        # Plot 5: Torque vs Time
686        axes[4].plot(time_data[:len(torque_data)], torque_data, 'brown', linewidth=2)
687        axes[4].axhline(y=0, color='k', linestyle='--', linewidth=1)
688        axes[4].set_xlabel('Time (steps)')
689        axes[4].set_ylabel('Torque')
690        axes[4].set_title('Torque vs Time')
691        axes[4].grid(True, alpha=0.3)
692        # Overall title
693        fig.suptitle(
694            f"{self.title}\nKp={gains['kp']}, Ki={gains['ki']}, Kd={gains['kd']} | Trial {trial_index}",
695            fontsize=12
696        )
697        fig.tight_layout()
698        
699        return fig, axes

Plot detailed trajectory data for a specific trial.

Shows position, error, control signal, velocity, and torque over time.

Arguments:
  • test_index (int): Index of the test in self.tests.
  • trial_index (int): Index of the trial within that test. Defaults to 0.
Returns:

tuple: (fig, axes) Matplotlib figure and axes objects (5 stacked subplots: position, error, control signal, velocity, and torque).

Raises:
  • IndexError: If test_index or trial_index is out of range.
def to_dict(self):
701    def to_dict(self):
702        """Convert experiment to dictionary for JSON serialization.
703        
704        Returns:
705            dict: Dictionary representation of the experiment.
706        """
707        return {
708            'id': self.id,
709            'title': self.title,
710            'hypothesis': self.hypothesis,
711            'axis': self.axis,
712            'timestamp': self.timestamp,
713            'tests': self.tests,
714            'invert_output': self.invert_output
715        }

Convert experiment to dictionary for JSON serialization.

Returns:

dict: Dictionary representation of the experiment.

class ExperimentLog:
718class ExperimentLog:
719    """Manage a log of PID experiments.
720    
721    Args:
722        filepath (str or Path): Path to the JSON log file.
723    """
724    
725    def __init__(self, filepath):
726        self.filepath = Path(filepath)
727        self.experiments = []
728    
729    def add_experiment(self, experiment):
730        """Add an Experiment to the log.
731        
732        Args:
733            experiment (Experiment): Experiment object to add.
734        
735        Raises:
736            TypeError: If experiment is not an Experiment object.
737        """
738        if not isinstance(experiment, Experiment):
739            raise TypeError("Must add an Experiment object")
740        self.experiments.append(experiment)
741    
742    def save(self):
743        """Save experiments to JSON file, preserving existing experiments.
744        
745        Checks for duplicate IDs and updates existing experiments.
746        """
747        # Load existing data if file exists
748        if self.filepath.exists():
749            with open(self.filepath, 'r') as f:
750                data = json.load(f)
751            existing_experiments = data.get('experiments', [])
752        else:
753            existing_experiments = []
754        
755        # Create dict of existing experiments by ID
756        existing_by_id = {exp['id']: exp for exp in existing_experiments}
757        
758        # Add or update with current experiments
759        for exp in self.experiments:
760            existing_by_id[exp.id] = exp.to_dict()
761        
762        # Save all experiments
763        data = {'experiments': list(existing_by_id.values())}
764        
765        with open(self.filepath, 'w') as f:
766            json.dump(data, f, indent=2)
767        
768        print(f"Saved {len(self.experiments)} experiments to {self.filepath}")
769    
770    @classmethod
771    def load(cls, filepath):
772        """Load experiments from JSON file.
773        
774        Args:
775            filepath (str or Path): Path to the JSON log file.
776        
777        Returns:
778            ExperimentLog: Log with loaded Experiment objects.
779        """
780        log = cls(filepath)
781        
782        if not log.filepath.exists():
783            print(f"No existing file at {filepath}, starting new log")
784            return log
785        
786        with open(log.filepath, 'r') as f:
787            data = json.load(f)
788        
789        for exp_dict in data.get('experiments', []):
790            exp = cls._from_dict(exp_dict)
791            log.experiments.append(exp)
792        
793        print(f"Loaded {len(log.experiments)} experiments from {filepath}")
794        return log
795    
796    @staticmethod
797    def _from_dict(exp_dict):
798        """Reconstruct an Experiment object from a dictionary.
799        
800        Args:
801            exp_dict (dict): Dictionary representation of an experiment.
802        
803        Returns:
804            Experiment: Reconstructed Experiment object.
805        """
806        exp = Experiment(
807            title=exp_dict['title'],
808            hypothesis=exp_dict['hypothesis'],
809            axis=exp_dict['axis'],
810            invert_output=exp_dict.get('invert_output', False)
811        )
812        exp.id = exp_dict['id']
813        exp.timestamp = exp_dict['timestamp']
814        exp.tests = exp_dict['tests']
815        return exp
816    
817    def get_experiment(self, identifier):
818        """Get experiment by title or short hash (first 6 characters of ID).
819        
820        Args:
821            identifier (str): Either the full/partial title or first 6 chars of ID.
822        
823        Returns:
824            Experiment or None: Matching experiment, or None if not found.
825        """
826        # Try short hash first
827        if len(identifier) == 6:
828            for exp in self.experiments:
829                if exp.id[:6] == identifier:
830                    return exp
831        
832        # Try title (case-insensitive partial match)
833        identifier_lower = identifier.lower()
834        for exp in self.experiments:
835            if identifier_lower in exp.title.lower():
836                return exp
837        
838        return None
839    
840    def list_experiments(self):
841        """Print summary table of all experiments."""
842        if not self.experiments:
843            print("No experiments in log")
844            return
845        
846        print("\n" + "="*100)
847        print(f"{'ID':>8} {'Title':<30} {'Hypothesis':<40} {'Tests':>6}")
848        print("="*100)
849        
850        for exp in self.experiments:
851            short_id = exp.id[:6]
852            title = exp.title[:28] + '..' if len(exp.title) > 30 else exp.title
853            hypothesis = exp.hypothesis[:38] + '..' if len(exp.hypothesis) > 40 else exp.hypothesis
854            num_tests = len(exp.tests)
855            
856            print(f"{short_id:>8} {title:<30} {hypothesis:<40} {num_tests:>6}")
857        
858        print("="*100)

Manage a log of PID experiments.

Arguments:
  • filepath (str or Path): Path to the JSON log file.
ExperimentLog(filepath)
725    def __init__(self, filepath):
726        self.filepath = Path(filepath)
727        self.experiments = []
filepath
experiments
def add_experiment(self, experiment):
729    def add_experiment(self, experiment):
730        """Add an Experiment to the log.
731        
732        Args:
733            experiment (Experiment): Experiment object to add.
734        
735        Raises:
736            TypeError: If experiment is not an Experiment object.
737        """
738        if not isinstance(experiment, Experiment):
739            raise TypeError("Must add an Experiment object")
740        self.experiments.append(experiment)

Add an Experiment to the log.

Arguments:
  • experiment (Experiment): Experiment object to add.
Raises:
  • TypeError: If experiment is not an Experiment object.
def save(self):
742    def save(self):
743        """Save experiments to JSON file, preserving existing experiments.
744        
745        Checks for duplicate IDs and updates existing experiments.
746        """
747        # Load existing data if file exists
748        if self.filepath.exists():
749            with open(self.filepath, 'r') as f:
750                data = json.load(f)
751            existing_experiments = data.get('experiments', [])
752        else:
753            existing_experiments = []
754        
755        # Create dict of existing experiments by ID
756        existing_by_id = {exp['id']: exp for exp in existing_experiments}
757        
758        # Add or update with current experiments
759        for exp in self.experiments:
760            existing_by_id[exp.id] = exp.to_dict()
761        
762        # Save all experiments
763        data = {'experiments': list(existing_by_id.values())}
764        
765        with open(self.filepath, 'w') as f:
766            json.dump(data, f, indent=2)
767        
768        print(f"Saved {len(self.experiments)} experiments to {self.filepath}")

Save experiments to JSON file, preserving existing experiments.

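Experiments already in the file are merged by ID: an entry whose ID matches an experiment in this log is replaced with the log's current version, and all other entries are kept. The printed count reflects only the experiments held in this log, not the total written to the file.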

@classmethod
def load(cls, filepath):
770    @classmethod
771    def load(cls, filepath):
772        """Load experiments from JSON file.
773        
774        Args:
775            filepath (str or Path): Path to the JSON log file.
776        
777        Returns:
778            ExperimentLog: Log with loaded Experiment objects.
779        """
780        log = cls(filepath)
781        
782        if not log.filepath.exists():
783            print(f"No existing file at {filepath}, starting new log")
784            return log
785        
786        with open(log.filepath, 'r') as f:
787            data = json.load(f)
788        
789        for exp_dict in data.get('experiments', []):
790            exp = cls._from_dict(exp_dict)
791            log.experiments.append(exp)
792        
793        print(f"Loaded {len(log.experiments)} experiments from {filepath}")
794        return log

Load experiments from JSON file. If the file does not exist, an empty log bound to that path is returned.

Arguments:
  • filepath (str or Path): Path to the JSON log file.
Returns:

ExperimentLog: Log with loaded Experiment objects.

def get_experiment(self, identifier):
817    def get_experiment(self, identifier):
818        """Get experiment by title or short hash (first 6 characters of ID).
819        
820        Args:
821            identifier (str): Either the full/partial title or first 6 chars of ID.
822        
823        Returns:
824            Experiment or None: Matching experiment, or None if not found.
825        """
826        # Try short hash first
827        if len(identifier) == 6:
828            for exp in self.experiments:
829                if exp.id[:6] == identifier:
830                    return exp
831        
832        # Try title (case-insensitive partial match)
833        identifier_lower = identifier.lower()
834        for exp in self.experiments:
835            if identifier_lower in exp.title.lower():
836                return exp
837        
838        return None

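Get an experiment by short hash (first 6 characters of its ID) or by title. The short-hash lookup is tried only when the identifier is exactly 6 characters long. Otherwise, or if no ID matches, the first experiment whose title contains the identifier (case-insensitive) is returned.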

Arguments:
  • identifier (str): Either the full/partial title or first 6 chars of ID.
Returns:

Experiment or None: Matching experiment, or None if not found.

def list_experiments(self):
840    def list_experiments(self):
841        """Print summary table of all experiments."""
842        if not self.experiments:
843            print("No experiments in log")
844            return
845        
846        print("\n" + "="*100)
847        print(f"{'ID':>8} {'Title':<30} {'Hypothesis':<40} {'Tests':>6}")
848        print("="*100)
849        
850        for exp in self.experiments:
851            short_id = exp.id[:6]
852            title = exp.title[:28] + '..' if len(exp.title) > 30 else exp.title
853            hypothesis = exp.hypothesis[:38] + '..' if len(exp.hypothesis) > 40 else exp.hypothesis
854            num_tests = len(exp.tests)
855            
856            print(f"{short_id:>8} {title:<30} {hypothesis:<40} {num_tests:>6}")
857        
858        print("="*100)

Print summary table of all experiments.