library.model_manager

Class for managing Keras model training, saving, and loading with metadata.

This manager handles the complete lifecycle of Keras models including training, saving, loading, and metadata management. It works with any Keras model architecture including CNNs, U-Nets, ResNets, and custom architectures.

Examples:

U-Net for image segmentation (Task 5):

from library.model_manager import KerasModelManager
from tensorflow.keras.callbacks import EarlyStopping

# Initialize with Task 5 naming convention
manager = KerasModelManager(
    model_name='smith_123456_unet_model_128px',
    student_name='John Smith',
    student_number='123456',
    patch_size=128,
    architecture='U-Net',
    task='root_segmentation'
)

# Set model with custom metrics
manager.set_model(unet_model, custom_objects={'f1': f1})

# Train with early stopping
early_stop = EarlyStopping(monitor='val_loss', patience=3, 
                           restore_best_weights=True)
manager.train(X_train, y_train, X_val, y_val, 
             epochs=50, callbacks=[early_stop])

# Save everything
manager.save('models')

# Get best metrics
metrics = manager.get_best_metrics()
print(f"Best F1: {metrics['best_val_f1']:.4f}")

CNN classifier for MNIST:

manager = KerasModelManager(
    model_name='mnist_cnn',
    version='1.0',
    dataset='MNIST',
    input_shape=(28, 28, 1),
    num_classes=10,
    architecture='CNN'
)

manager.set_model(cnn_model)
manager.train(X_train, y_train, X_val, y_val, epochs=20)
manager.save('models')

ResNet50 for plant classification:

manager = KerasModelManager(
    model_name='resnet50_plants',
    version='2.1',
    architecture='ResNet50',
    dataset='PlantCLEF',
    pretrained='ImageNet',
    fine_tuned_layers=10,
    input_size=224
)

manager.set_model(resnet_model)
manager.train(train_generator, steps_per_epoch=100,
             validation_data=val_generator, validation_steps=20,
             epochs=30)
manager.save('models')

LSTM for time series prediction:

manager = KerasModelManager(
    model_name='stock_lstm',
    version='1.0',
    architecture='LSTM',
    sequence_length=60,
    features=5,
    prediction_horizon=1
)

manager.set_model(lstm_model)
manager.train(X_train, y_train, X_val, y_val, epochs=50, batch_size=32)
manager.save('models')

Loading a saved model:

# Initialize manager with same naming
manager = KerasModelManager(
    model_name='smith_123456_unet_model_128px'
)

# Set custom objects before loading
manager.set_model(None, custom_objects={'f1': f1})

# Load model and history
model, history = manager.load('models/smith_123456_unet_model_128px.h5')

# Use loaded model
predictions = model.predict(X_test)

Attributes:
  • model_name (str): Base name for the model files.
  • version (str): Optional version string or number.
  • metadata (dict): Additional metadata stored with the model.
  • model (keras.Model): The Keras model being managed.
  • history (keras.callbacks.History): Training history from model.fit().
  • custom_objects (dict): Custom objects needed for model loading.
  • base_name (str): Generated base filename.
  • model_path (str): Path to saved model file.
  • history_path (str): Path to saved history JSON.
  • summary_path (str): Path to saved summary text file.
  1# library/model_manager.py
  2"""Class for managing Keras model training, saving, and loading with metadata.
  3
  4This manager handles the complete lifecycle of Keras models including training,
  5saving, loading, and metadata management. It works with any Keras model architecture
  6including CNNs, U-Nets, ResNets, and custom architectures.
  7
  8Examples:
  9    U-Net for image segmentation (Task 5)::
 10    
 11        from library.model_manager import KerasModelManager
 12        from tensorflow.keras.callbacks import EarlyStopping
 13        
 14        # Initialize with Task 5 naming convention
 15        manager = KerasModelManager(
 16            model_name='smith_123456_unet_model_128px',
 17            student_name='John Smith',
 18            student_number='123456',
 19            patch_size=128,
 20            architecture='U-Net',
 21            task='root_segmentation'
 22        )
 23        
 24        # Set model with custom metrics
 25        manager.set_model(unet_model, custom_objects={'f1': f1})
 26        
 27        # Train with early stopping
 28        early_stop = EarlyStopping(monitor='val_loss', patience=3, 
 29                                   restore_best_weights=True)
 30        manager.train(X_train, y_train, X_val, y_val, 
 31                     epochs=50, callbacks=[early_stop])
 32        
 33        # Save everything
 34        manager.save('models')
 35        
 36        # Get best metrics
 37        metrics = manager.get_best_metrics()
 38        print(f"Best F1: {metrics['best_val_f1']:.4f}")
 39    
 40    CNN classifier for MNIST::
 41    
 42        manager = KerasModelManager(
 43            model_name='mnist_cnn',
 44            version='1.0',
 45            dataset='MNIST',
 46            input_shape=(28, 28, 1),
 47            num_classes=10,
 48            architecture='CNN'
 49        )
 50        
 51        manager.set_model(cnn_model)
 52        manager.train(X_train, y_train, X_val, y_val, epochs=20)
 53        manager.save('models')
 54    
 55    ResNet50 for plant classification::
 56    
 57        manager = KerasModelManager(
 58            model_name='resnet50_plants',
 59            version='2.1',
 60            architecture='ResNet50',
 61            dataset='PlantCLEF',
 62            pretrained='ImageNet',
 63            fine_tuned_layers=10,
 64            input_size=224
 65        )
 66        
 67        manager.set_model(resnet_model)
 68        manager.train(train_generator, steps_per_epoch=100,
 69                     validation_data=val_generator, validation_steps=20,
 70                     epochs=30)
 71        manager.save('models')
 72    
 73    LSTM for time series prediction::
 74    
 75        manager = KerasModelManager(
 76            model_name='stock_lstm',
 77            version='1.0',
 78            architecture='LSTM',
 79            sequence_length=60,
 80            features=5,
 81            prediction_horizon=1
 82        )
 83        
 84        manager.set_model(lstm_model)
 85        manager.train(X_train, y_train, X_val, y_val, epochs=50, batch_size=32)
 86        manager.save('models')
 87    
 88    Loading a saved model::
 89    
 90        # Initialize manager with same naming
 91        manager = KerasModelManager(
 92            model_name='smith_123456_unet_model_128px'
 93        )
 94        
 95        # Set custom objects before loading
 96        manager.set_model(None, custom_objects={'f1': f1})
 97        
 98        # Load model and history
 99        model, history = manager.load('models/smith_123456_unet_model_128px.h5')
100        
101        # Use loaded model
102        predictions = model.predict(X_test)
103
104Attributes:
105    model_name (str): Base name for the model files.
106    version (str): Optional version string or number.
107    metadata (dict): Additional metadata stored with the model.
108    model (keras.Model): The Keras model being managed.
109    history (keras.callbacks.History): Training history from model.fit().
110    custom_objects (dict): Custom objects needed for model loading.
111    base_name (str): Generated base filename.
112    model_path (str): Path to saved model file.
113    history_path (str): Path to saved history JSON.
114    summary_path (str): Path to saved summary text file.
115"""
116
117import json
118import numpy as np
119from pathlib import Path
120from datetime import datetime
121from tensorflow import keras
122
123
124class KerasModelManager:
125    """Manages Keras model lifecycle: training, saving, loading, and metadata."""
126    
127    def __init__(self, model_name, version=None, notes=None, **metadata):
128        """
129        Initialize the model manager.
130        
131        Args:
132            model_name: Base name for the model files. For Task 5, use format:
133                       'studentname_studentnumber_modeltype_patchsizepx'
134            version: Optional version string or number (e.g., '1.0', 'v2', 2).
135            notes: Optional initial note/hypothesis for this model.
136            **metadata: Additional metadata to store with the model. Common examples:
137                       - student_name (str): Student name for academic projects
138                       - student_number (str): Student number for academic projects
139                       - architecture (str): Model architecture (e.g., 'U-Net', 'ResNet50')
140                       - patch_size (int): Patch size for segmentation models
141                       - input_shape (tuple): Input shape for the model
142                       - num_classes (int): Number of output classes
143                       - dataset (str): Dataset name
144                       - pretrained (str): Pretrained weights source
145                       - task (str): Task description
146        
147        Examples:
148            >>> # U-Net for segmentation with notes
149            >>> manager = KerasModelManager(
150            ...     'smith_123456_unet_model_128px',
151            ...     notes="Hypothesis: 128px patches will achieve F1 > 0.90",
152            ...     patch_size=128,
153            ...     architecture='U-Net'
154            ... )
155            
156            >>> # Versioned CNN classifier
157            >>> manager = KerasModelManager(
158            ...     'mnist_cnn',
159            ...     version='1.0',
160            ...     dataset='MNIST',
161            ...     input_shape=(28, 28, 1)
162            ... )
163        """
164        self.model_name = model_name
165        self.version = version
166        self.metadata = metadata
167        self.model = None
168        self.history = None
169        self.custom_objects = {}
170        self._callbacks = None
171        self._notes = []
172        
173        # Add initial note if provided
174        if notes:
175            self._add_note(notes)
176        
177        # Generate filenames
178        version_str = f"_v{version}" if version else ""
179        self.base_name = f"{model_name}{version_str}"
180        self.model_path = f"{self.base_name}.h5"
181        self.history_path = f"{self.base_name}_history.json"
182        self.summary_path = f"{self.base_name}_summary.txt"
183    
184    def set_model(self, model, custom_objects=None):
185        """
186        Set the model to be managed.
187        
188        Args:
189            model: Compiled Keras model, or None if only loading.
190            custom_objects: Dict of custom objects needed for model loading
191                           (e.g., {'f1': f1_function, 'dice_loss': dice_loss_fn}).
192        
193        Examples:
194            >>> # With custom metric
195            >>> manager.set_model(unet_model, custom_objects={'f1': f1})
196            
197            >>> # Standard model, no custom objects
198            >>> manager.set_model(resnet_model)
199            
200            >>> # Preparing to load (no model yet)
201            >>> manager.set_model(None, custom_objects={'f1': f1})
202        """
203        self.model = model
204        if custom_objects:
205            self.custom_objects = custom_objects
206    
207    def _add_note(self, note):
208        """Internal method to add a timestamped note."""
209        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
210        self._notes.append({
211            'timestamp': timestamp,
212            'note': note
213        })
214    
215    @property
216    def notes(self):
217        """Get all notes as a formatted string."""
218        if not self._notes:
219            return ""
220        return "\n".join([f"[{n['timestamp']}] {n['note']}" for n in self._notes])
221    
222    @notes.setter
223    def notes(self, note):
224        """Append a new note with automatic timestamp."""
225        self._add_note(note)
226    
227    def set_notes(self, note, replace=False):
228        """
229        Set notes with option to replace all existing notes.
230        
231        Args:
232            note: The note text to add or set.
233            replace: If True, replaces all existing notes. If False, appends.
234        
235        Examples:
236            >>> # Append (default)
237            >>> manager.set_notes("Added dropout layers")
238            
239            >>> # Replace all notes
240            >>> manager.set_notes("Starting fresh experiment", replace=True)
241        """
242        if replace:
243            self._notes = []
244        self._add_note(note)
245    
246    def clear_notes(self):
247        """Clear all notes."""
248        self._notes = []
249    
250    def get_notes_list(self):
251        """
252        Get notes as a list of dictionaries.
253        
254        Returns:
255            List of dicts with 'timestamp' and 'note' keys.
256        """
257        return self._notes.copy()
258    
259    def train(self, X_train, y_train=None, X_val=None, y_val=None, epochs=50, 
260            batch_size=16, callbacks=None, verbose=1, **fit_kwargs):
261        """
262        Train the model and store history.
263        
264        Args:
265            X_train: Training data (numpy array, generator, or dataset).
266            y_train: Training labels (numpy array, or None if using generator).
267            X_val: Optional validation data (or None if using validation_data kwarg).
268            y_val: Optional validation labels (or None if using validation_data kwarg).
269            epochs: Maximum number of epochs.
270            batch_size: Batch size for training (ignored if using generator).
271            callbacks: List of Keras callbacks (e.g., EarlyStopping, ModelCheckpoint).
272            verbose: Verbosity level (0=silent, 1=progress bar, 2=one line per epoch).
273            **fit_kwargs: Additional arguments passed to model.fit()
274                        (e.g., steps_per_epoch, validation_steps, validation_data).
275        
276        Returns:
277            Training history object.
278        """
279        if self.model is None:
280            raise ValueError("Model not set. Call set_model() first.")
281        
282        # Store callbacks for later serialization
283        self._callbacks = callbacks
284        
285        # Handle validation_data - check if it's in fit_kwargs first
286        validation_data = fit_kwargs.pop('validation_data', None)
287        
288        # If not in fit_kwargs and we have X_val/y_val, construct it
289        if validation_data is None and X_val is not None and y_val is not None:
290            validation_data = (X_val, y_val)
291        
292        self.history = self.model.fit(
293            X_train, y_train,
294            validation_data=validation_data,
295            batch_size=batch_size,
296            epochs=epochs,
297            callbacks=callbacks,
298            verbose=verbose,
299            **fit_kwargs
300        )
301        
302        return self.history
303    
304    def save(self, output_dir='.'):
305        """
306        Save model, history, and summary to disk.
307        
308        Saves three files:
309        1. .h5 file - Complete model (architecture, weights, optimizer state)
310        2. _history.json - Training metrics for all epochs
311        3. _summary.txt - Human-readable summary
312        
313        Args:
314            output_dir: Directory to save files. Created if it doesn't exist.
315                       Default is current directory.
316        
317        Examples:
318            >>> # Save to current directory
319            >>> manager.save()
320            
321            >>> # Save to specific directory
322            >>> manager.save('trained_models')
323            
324            >>> # Save to nested path
325            >>> manager.save('models/unet/version_1')
326        """
327        if self.model is None:
328            raise ValueError("No model to save. Train or set a model first.")
329        
330        output_dir = Path(output_dir)
331        output_dir.mkdir(parents=True, exist_ok=True)
332        
333        # Save model
334        model_file = output_dir / self.model_path
335        self.model.save(str(model_file))
336        print(f"Model saved: {model_file}")
337        
338        # Save history if available
339        if self.history is not None:
340            self._save_history(output_dir)
341            self._save_summary(output_dir)
342    
343    def _serialize_custom_objects(self):
344        """
345        Serialize information about custom objects for documentation.
346        
347        Returns:
348            List of dictionaries with custom object metadata.
349        """
350        import inspect
351        import hashlib
352        
353        if not self.custom_objects:
354            return []
355        
356        serialized = []
357        
358        for name, obj in self.custom_objects.items():
359            obj_info = {
360                'name': name,
361                'type': type(obj).__name__
362            }
363            
364            # Try to get useful information about the object
365            try:
366                # Get module
367                if hasattr(obj, '__module__'):
368                    obj_info['module'] = obj.__module__
369                
370                # Get docstring
371                if hasattr(obj, '__doc__') and obj.__doc__:
372                    obj_info['docstring'] = obj.__doc__.strip()
373                
374                # For functions, get signature
375                if callable(obj):
376                    try:
377                        sig = inspect.signature(obj)
378                        obj_info['signature'] = str(sig)
379                    except (ValueError, TypeError):
380                        pass
381                    
382                    # Get source code hash (for version tracking)
383                    try:
384                        source = inspect.getsource(obj)
385                        source_hash = hashlib.md5(source.encode()).hexdigest()
386                        obj_info['source_hash'] = source_hash
387                    except (OSError, TypeError):
388                        pass
389                
390                serialized.append(obj_info)
391                
392            except Exception as e:
393                # If we can't serialize, at least save the name and type
394                obj_info['error'] = str(e)
395                serialized.append(obj_info)
396        
397        return serialized
398    
399    def _serialize_callbacks(self, callbacks):
400        """
401        Serialize callback configurations.
402        
403        Args:
404            callbacks: List of Keras callbacks or None.
405        
406        Returns:
407            List of serialized callback dictionaries.
408        """
409        import warnings
410        
411        if callbacks is None:
412            return []
413        
414        serialized = []
415        
416        for callback in callbacks:
417            callback_info = {
418                'class': callback.__class__.__name__,
419                'module': callback.__class__.__module__
420            }
421            
422            # Try to extract configuration
423            try:
424                # Get callback config if available
425                if hasattr(callback, 'get_config'):
426                    callback_info['config'] = callback.get_config()
427                else:
428                    # Manually extract common attributes
429                    config = {}
430                    
431                    # EarlyStopping attributes
432                    if hasattr(callback, 'monitor'):
433                        config['monitor'] = callback.monitor
434                    if hasattr(callback, 'patience'):
435                        config['patience'] = callback.patience
436                    if hasattr(callback, 'restore_best_weights'):
437                        config['restore_best_weights'] = callback.restore_best_weights
438                    if hasattr(callback, 'mode'):
439                        config['mode'] = callback.mode
440                    if hasattr(callback, 'min_delta'):
441                        config['min_delta'] = callback.min_delta
442                    
443                    # ReduceLROnPlateau attributes
444                    if hasattr(callback, 'factor'):
445                        config['factor'] = callback.factor
446                    if hasattr(callback, 'cooldown'):
447                        config['cooldown'] = callback.cooldown
448                    if hasattr(callback, 'min_lr'):
449                        config['min_lr'] = callback.min_lr
450                    
451                    # ModelCheckpoint attributes
452                    if hasattr(callback, 'filepath'):
453                        config['filepath'] = callback.filepath
454                    if hasattr(callback, 'save_best_only'):
455                        config['save_best_only'] = callback.save_best_only
456                    if hasattr(callback, 'save_weights_only'):
457                        config['save_weights_only'] = callback.save_weights_only
458                    
459                    # LearningRateScheduler
460                    if hasattr(callback, 'schedule'):
461                        warnings.warn(
462                            f"Callback {callback.__class__.__name__} has a 'schedule' "
463                            f"function that cannot be serialized. Saving class name only.",
464                            UserWarning
465                        )
466                        config['schedule'] = '<function>'
467                    
468                    if config:
469                        callback_info['config'] = config
470                    else:
471                        warnings.warn(
472                            f"Could not extract configuration from callback "
473                            f"{callback.__class__.__name__}. Saving class name only.",
474                            UserWarning
475                        )
476                
477                # Check if config is JSON serializable
478                json.dumps(callback_info)
479                serialized.append(callback_info)
480                
481            except (TypeError, ValueError) as e:
482                warnings.warn(
483                    f"Callback {callback.__class__.__name__} could not be fully "
484                    f"serialized: {str(e)}. Saving class name only.",
485                    UserWarning
486                )
487                serialized.append({
488                    'class': callback.__class__.__name__,
489                    'module': callback.__class__.__module__,
490                    'error': str(e)
491                })
492        
493        return serialized
494    
495    def _save_history(self, output_dir):
496        """Save training history to JSON."""
497        history_dict = {}
498        
499        for key, values in self.history.history.items():
500            # Convert to list and ensure JSON serializable
501            if isinstance(values, np.ndarray):
502                history_dict[key] = values.tolist()
503            elif isinstance(values, list):
504                history_dict[key] = [float(x) for x in values]
505            else:
506                history_dict[key] = values
507        
508        # Serialize callbacks
509        callbacks_info = self._serialize_callbacks(self._callbacks)
510        
511        # Serialize custom objects
512        custom_objects_info = self._serialize_custom_objects()
513        
514        # Add metadata
515        history_dict['metadata'] = {
516            'model_name': self.model_name,
517            'version': self.version,
518            'epochs_trained': len(self.history.history['loss']),
519            'saved_at': datetime.now().isoformat(),
520            'callbacks': callbacks_info,
521            'custom_objects': custom_objects_info,
522            'notes': self._notes,
523            **self.metadata
524        }
525        
526        history_file = output_dir / self.history_path
527        with open(history_file, 'w') as f:
528            json.dump(history_dict, f, indent=4)
529        print(f"History saved: {history_file}")
530    
531    def _save_summary(self, output_dir):
532        """Save human-readable summary."""
533        summary_file = output_dir / self.summary_path
534        
535        with open(summary_file, 'w') as f:
536            f.write(f"{self.model_name} Model Summary\n")
537            f.write("=" * 60 + "\n\n")
538            
539            # Write metadata
540            if self.version:
541                f.write(f"Version: {self.version}\n")
542            for key, value in self.metadata.items():
543                f.write(f"{key}: {value}\n")
544            
545            f.write(f"Model file: {self.model_path}\n")
546            f.write(f"Saved: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
547            
548            # Write notes
549            if self._notes:
550                f.write("Notes / Hypothesis:\n")
551                f.write("-" * 60 + "\n")
552                for note_entry in self._notes:
553                    f.write(f"[{note_entry['timestamp']}] {note_entry['note']}\n")
554                f.write("\n")
555            
556            # Write custom objects information
557            if self.custom_objects:
558                f.write("Custom Objects:\n")
559                f.write("-" * 60 + "\n")
560                custom_objs_info = self._serialize_custom_objects()
561                for obj_info in custom_objs_info:
562                    f.write(f"  {obj_info['name']} ({obj_info['type']}):\n")
563                    if 'module' in obj_info:
564                        f.write(f"    module: {obj_info['module']}\n")
565                    if 'signature' in obj_info:
566                        f.write(f"    signature: {obj_info['signature']}\n")
567                    if 'source_hash' in obj_info:
568                        f.write(f"    source_hash: {obj_info['source_hash']}\n")
569                    if 'docstring' in obj_info:
570                        # First line of docstring only
571                        first_line = obj_info['docstring'].split('\n')[0]
572                        f.write(f"    docstring: {first_line}\n")
573                    f.write("\n")
574            
575            # Write callback information
576            if self._callbacks:
577                f.write("Callbacks Used:\n")
578                f.write("-" * 60 + "\n")
579                callbacks_info = self._serialize_callbacks(self._callbacks)
580                for cb_info in callbacks_info:
581                    f.write(f"  {cb_info['class']}:\n")
582                    if 'config' in cb_info:
583                        for key, value in cb_info['config'].items():
584                            f.write(f"    {key}: {value}\n")
585                    elif 'error' in cb_info:
586                        f.write(f"    (serialization error: {cb_info['error']})\n")
587                    f.write("\n")
588            
589            if self.history:
590                history = self.history.history
591                f.write("Training Results\n")
592                f.write("-" * 60 + "\n")
593                f.write(f"Epochs trained: {len(history['loss'])}\n\n")
594                
595                # Final metrics
596                f.write("Final Metrics:\n")
597                for key in sorted(history.keys()):
598                    if not key.startswith('val_'):
599                        val_key = f'val_{key}'
600                        train_val = history[key][-1]
601                        f.write(f"  {key}: {train_val:.4f}")
602                        if val_key in history:
603                            val_val = history[val_key][-1]
604                            f.write(f" | {val_key}: {val_val:.4f}")
605                        f.write("\n")
606                
607                # Best metrics (find all val_ metrics and report best)
608                f.write("\nBest Validation Metrics:\n")
609                for key in sorted(history.keys()):
610                    if key.startswith('val_'):
611                        # Determine if higher or lower is better
612                        if 'loss' in key or 'error' in key:
613                            best_idx = np.argmin(history[key])
614                            best_val = history[key][best_idx]
615                            f.write(f"  {key}: {best_val:.4f} (epoch {best_idx + 1}, min)\n")
616                        else:
617                            best_idx = np.argmax(history[key])
618                            best_val = history[key][best_idx]
619                            f.write(f"  {key}: {best_val:.4f} (epoch {best_idx + 1}, max)\n")
620        
621        print(f"Summary saved: {summary_file}")
622    
623    def load(self, model_path=None, load_history=True):
624        """
625        Load a saved model and optionally its history.
626        
627        Args:
628            model_path: Path to model file. If None, uses default naming convention.
629            load_history: Whether to load training history JSON. Default is True.
630        
631        Returns:
632            Tuple of (loaded_model, history_dict or None).
633        
634        Examples:
635            >>> # Load with default path
636            >>> manager = KerasModelManager('smith_123456_unet_model_128px')
637            >>> manager.set_model(None, custom_objects={'f1': f1})
638            >>> model, history = manager.load()
639            
640            >>> # Load from specific path
641            >>> model, history = manager.load('models/backup/model.h5')
642            
643            >>> # Load model only, skip history
644            >>> model, _ = manager.load(load_history=False)
645        """
646        if model_path is None:
647            model_path = self.model_path
648        
649        model_path = Path(model_path)
650        
651        if not model_path.exists():
652            raise FileNotFoundError(f"Model file not found: {model_path}")
653        
654        # Load model with custom objects
655        self.model = keras.models.load_model(
656            str(model_path),
657            custom_objects=self.custom_objects
658        )
659        print(f"Model loaded: {model_path}")
660        
661        loaded_history = None
662        
663        # Load history if available
664        if load_history:
665            history_path = model_path.parent / self.history_path
666            if history_path.exists():
667                with open(history_path, 'r') as f:
668                    loaded_history = json.load(f)
669                    print(f"History loaded: {history_path}")
670                    
671                    # Print summary
672                    if 'metadata' in loaded_history:
673                        meta = loaded_history['metadata']
674                        print(f"Trained for {meta.get('epochs_trained', 'unknown')} epochs")
675                        
676                        # Print notes if available
677                        if 'notes' in meta and meta['notes']:
678                            print("\nNotes:")
679                            for note_entry in meta['notes']:
680                                print(f"  [{note_entry['timestamp']}] {note_entry['note']}")
681                        
682                        # Print custom objects if available
683                        if 'custom_objects' in meta and meta['custom_objects']:
684                            print("\nCustom objects used:")
685                            for obj_info in meta['custom_objects']:
686                                print(f"  - {obj_info['name']} ({obj_info['type']})")
687                        
688                        # Print other metadata
689                        for key, val in meta.items():
690                            if key not in ['model_name', 'version', 'epochs_trained', 'saved_at', 'callbacks', 'custom_objects', 'notes']:
691                                print(f"{key}: {val}")
692                    
693                    # Print best validation metrics
694                    for key in loaded_history.keys():
695                        if key.startswith('val_') and key != 'val_loss':
696                            best_val = max(loaded_history[key])
697                            print(f"Best {key}: {best_val:.4f}")
698        
699        return self.model, loaded_history
700    
701    def get_best_metrics(self, metric_preferences=None):
702        """
703        Get best metrics from training history.
704        
705        Args:
706            metric_preferences: Dict mapping metric names to 'min' or 'max'.
707                               If None, uses 'min' for loss/error metrics,
708                               'max' for all others.
709        
710        Returns:
711            Dictionary with best metrics, their epochs, and final values.
712            Keys include 'best_{metric}', 'best_{metric}_epoch', 'final_{metric}'.
713        
714        Examples:
715            >>> metrics = manager.get_best_metrics()
716            >>> print(f"Best F1: {metrics['best_val_f1']:.4f}")
717            >>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}")
718            
719            >>> # Custom metric preferences
720            >>> metrics = manager.get_best_metrics(
721            ...     metric_preferences={'val_custom': 'min'}
722            ... )
723        """
724        if self.history is None:
725            raise ValueError("No training history available.")
726        
727        history = self.history.history
728        metrics = {}
729        
730        # Default preferences
731        if metric_preferences is None:
732            metric_preferences = {}
733        
734        for key in history.keys():
735            if key.startswith('val_'):
736                # Determine optimization direction
737                if key in metric_preferences:
738                    direction = metric_preferences[key]
739                elif 'loss' in key or 'error' in key:
740                    direction = 'min'
741                else:
742                    direction = 'max'
743                
744                # Find best value
745                if direction == 'min':
746                    best_idx = np.argmin(history[key])
747                else:
748                    best_idx = np.argmax(history[key])
749                
750                metrics[f'best_{key}'] = history[key][best_idx]
751                metrics[f'best_{key}_epoch'] = best_idx + 1
752                metrics[f'final_{key}'] = history[key][-1]
753        
754        return metrics
755    
756    def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)):
757        """
758        Plot training history.
759        
760        Args:
761            history_dict: History dict from load(). If None, uses self.history.
762            metrics: List of metrics to plot. If None, plots all metrics.
763            figsize: Figure size tuple.
764        
765        Examples:
766            >>> # After training
767            >>> manager.plot_history()
768            
769            >>> # After loading
770            >>> model, history = manager.load()
771            >>> manager.plot_history(history)
772            
773            >>> # Plot specific metrics
774            >>> manager.plot_history(metrics=['loss', 'f1'])
775        """
776        import matplotlib.pyplot as plt
777        
778        if history_dict is None:
779            if self.history is None:
780                raise ValueError("No history available. Train or load a model first.")
781            history_dict = self.history.history
782        
783        # Filter out metadata
784        history_dict = {k: v for k, v in history_dict.items() if k != 'metadata'}
785        
786        # Determine metrics to plot
787        if metrics is None:
788            # Plot all non-validation metrics
789            metrics = [k for k in history_dict.keys() if not k.startswith('val_')]
790        
791        n_metrics = len(metrics)
792        fig, axes = plt.subplots(1, n_metrics, figsize=figsize)
793        
794        # Handle single metric case
795        if n_metrics == 1:
796            axes = [axes]
797        
798        for ax, metric in zip(axes, metrics):
799            val_metric = f'val_{metric}'
800            
801            ax.plot(history_dict[metric], label=f'Training {metric}')
802            if val_metric in history_dict:
803                ax.plot(history_dict[val_metric], label=f'Validation {metric}')
804            
805            ax.set_xlabel('Epoch')
806            ax.set_ylabel(metric.replace('_', ' ').title())
807            ax.set_title(f'{metric.replace("_", " ").title()} over Epochs')
808            ax.legend()
809            ax.grid(True, alpha=0.3)
810        
811        plt.tight_layout()
812        plt.show()
class KerasModelManager:
125class KerasModelManager:
126    """Manages Keras model lifecycle: training, saving, loading, and metadata."""
127    
128    def __init__(self, model_name, version=None, notes=None, **metadata):
129        """
130        Initialize the model manager.
131        
132        Args:
133            model_name: Base name for the model files. For Task 5, use format:
134                       'studentname_studentnumber_modeltype_patchsizepx'
135            version: Optional version string or number (e.g., '1.0', 'v2', 2).
136            notes: Optional initial note/hypothesis for this model.
137            **metadata: Additional metadata to store with the model. Common examples:
138                       - student_name (str): Student name for academic projects
139                       - student_number (str): Student number for academic projects
140                       - architecture (str): Model architecture (e.g., 'U-Net', 'ResNet50')
141                       - patch_size (int): Patch size for segmentation models
142                       - input_shape (tuple): Input shape for the model
143                       - num_classes (int): Number of output classes
144                       - dataset (str): Dataset name
145                       - pretrained (str): Pretrained weights source
146                       - task (str): Task description
147        
148        Examples:
149            >>> # U-Net for segmentation with notes
150            >>> manager = KerasModelManager(
151            ...     'smith_123456_unet_model_128px',
152            ...     notes="Hypothesis: 128px patches will achieve F1 > 0.90",
153            ...     patch_size=128,
154            ...     architecture='U-Net'
155            ... )
156            
157            >>> # Versioned CNN classifier
158            >>> manager = KerasModelManager(
159            ...     'mnist_cnn',
160            ...     version='1.0',
161            ...     dataset='MNIST',
162            ...     input_shape=(28, 28, 1)
163            ... )
164        """
165        self.model_name = model_name
166        self.version = version
167        self.metadata = metadata
168        self.model = None
169        self.history = None
170        self.custom_objects = {}
171        self._callbacks = None
172        self._notes = []
173        
174        # Add initial note if provided
175        if notes:
176            self._add_note(notes)
177        
178        # Generate filenames
179        version_str = f"_v{version}" if version else ""
180        self.base_name = f"{model_name}{version_str}"
181        self.model_path = f"{self.base_name}.h5"
182        self.history_path = f"{self.base_name}_history.json"
183        self.summary_path = f"{self.base_name}_summary.txt"
184    
185    def set_model(self, model, custom_objects=None):
186        """
187        Set the model to be managed.
188        
189        Args:
190            model: Compiled Keras model, or None if only loading.
191            custom_objects: Dict of custom objects needed for model loading
192                           (e.g., {'f1': f1_function, 'dice_loss': dice_loss_fn}).
193        
194        Examples:
195            >>> # With custom metric
196            >>> manager.set_model(unet_model, custom_objects={'f1': f1})
197            
198            >>> # Standard model, no custom objects
199            >>> manager.set_model(resnet_model)
200            
201            >>> # Preparing to load (no model yet)
202            >>> manager.set_model(None, custom_objects={'f1': f1})
203        """
204        self.model = model
205        if custom_objects:
206            self.custom_objects = custom_objects
207    
208    def _add_note(self, note):
209        """Internal method to add a timestamped note."""
210        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
211        self._notes.append({
212            'timestamp': timestamp,
213            'note': note
214        })
215    
216    @property
217    def notes(self):
218        """Get all notes as a formatted string."""
219        if not self._notes:
220            return ""
221        return "\n".join([f"[{n['timestamp']}] {n['note']}" for n in self._notes])
222    
223    @notes.setter
224    def notes(self, note):
225        """Append a new note with automatic timestamp."""
226        self._add_note(note)
227    
228    def set_notes(self, note, replace=False):
229        """
230        Set notes with option to replace all existing notes.
231        
232        Args:
233            note: The note text to add or set.
234            replace: If True, replaces all existing notes. If False, appends.
235        
236        Examples:
237            >>> # Append (default)
238            >>> manager.set_notes("Added dropout layers")
239            
240            >>> # Replace all notes
241            >>> manager.set_notes("Starting fresh experiment", replace=True)
242        """
243        if replace:
244            self._notes = []
245        self._add_note(note)
246    
247    def clear_notes(self):
248        """Clear all notes."""
249        self._notes = []
250    
251    def get_notes_list(self):
252        """
253        Get notes as a list of dictionaries.
254        
255        Returns:
256            List of dicts with 'timestamp' and 'note' keys.
257        """
258        return self._notes.copy()
259    
260    def train(self, X_train, y_train=None, X_val=None, y_val=None, epochs=50, 
261            batch_size=16, callbacks=None, verbose=1, **fit_kwargs):
262        """
263        Train the model and store history.
264        
265        Args:
266            X_train: Training data (numpy array, generator, or dataset).
267            y_train: Training labels (numpy array, or None if using generator).
268            X_val: Optional validation data (or None if using validation_data kwarg).
269            y_val: Optional validation labels (or None if using validation_data kwarg).
270            epochs: Maximum number of epochs.
271            batch_size: Batch size for training (ignored if using generator).
272            callbacks: List of Keras callbacks (e.g., EarlyStopping, ModelCheckpoint).
273            verbose: Verbosity level (0=silent, 1=progress bar, 2=one line per epoch).
274            **fit_kwargs: Additional arguments passed to model.fit()
275                        (e.g., steps_per_epoch, validation_steps, validation_data).
276        
277        Returns:
278            Training history object.
279        """
280        if self.model is None:
281            raise ValueError("Model not set. Call set_model() first.")
282        
283        # Store callbacks for later serialization
284        self._callbacks = callbacks
285        
286        # Handle validation_data - check if it's in fit_kwargs first
287        validation_data = fit_kwargs.pop('validation_data', None)
288        
289        # If not in fit_kwargs and we have X_val/y_val, construct it
290        if validation_data is None and X_val is not None and y_val is not None:
291            validation_data = (X_val, y_val)
292        
293        self.history = self.model.fit(
294            X_train, y_train,
295            validation_data=validation_data,
296            batch_size=batch_size,
297            epochs=epochs,
298            callbacks=callbacks,
299            verbose=verbose,
300            **fit_kwargs
301        )
302        
303        return self.history
304    
305    def save(self, output_dir='.'):
306        """
307        Save model, history, and summary to disk.
308        
309        Saves three files:
310        1. .h5 file - Complete model (architecture, weights, optimizer state)
311        2. _history.json - Training metrics for all epochs
312        3. _summary.txt - Human-readable summary
313        
314        Args:
315            output_dir: Directory to save files. Created if it doesn't exist.
316                       Default is current directory.
317        
318        Examples:
319            >>> # Save to current directory
320            >>> manager.save()
321            
322            >>> # Save to specific directory
323            >>> manager.save('trained_models')
324            
325            >>> # Save to nested path
326            >>> manager.save('models/unet/version_1')
327        """
328        if self.model is None:
329            raise ValueError("No model to save. Train or set a model first.")
330        
331        output_dir = Path(output_dir)
332        output_dir.mkdir(parents=True, exist_ok=True)
333        
334        # Save model
335        model_file = output_dir / self.model_path
336        self.model.save(str(model_file))
337        print(f"Model saved: {model_file}")
338        
339        # Save history if available
340        if self.history is not None:
341            self._save_history(output_dir)
342            self._save_summary(output_dir)
343    
344    def _serialize_custom_objects(self):
345        """
346        Serialize information about custom objects for documentation.
347        
348        Returns:
349            List of dictionaries with custom object metadata.
350        """
351        import inspect
352        import hashlib
353        
354        if not self.custom_objects:
355            return []
356        
357        serialized = []
358        
359        for name, obj in self.custom_objects.items():
360            obj_info = {
361                'name': name,
362                'type': type(obj).__name__
363            }
364            
365            # Try to get useful information about the object
366            try:
367                # Get module
368                if hasattr(obj, '__module__'):
369                    obj_info['module'] = obj.__module__
370                
371                # Get docstring
372                if hasattr(obj, '__doc__') and obj.__doc__:
373                    obj_info['docstring'] = obj.__doc__.strip()
374                
375                # For functions, get signature
376                if callable(obj):
377                    try:
378                        sig = inspect.signature(obj)
379                        obj_info['signature'] = str(sig)
380                    except (ValueError, TypeError):
381                        pass
382                    
383                    # Get source code hash (for version tracking)
384                    try:
385                        source = inspect.getsource(obj)
386                        source_hash = hashlib.md5(source.encode()).hexdigest()
387                        obj_info['source_hash'] = source_hash
388                    except (OSError, TypeError):
389                        pass
390                
391                serialized.append(obj_info)
392                
393            except Exception as e:
394                # If we can't serialize, at least save the name and type
395                obj_info['error'] = str(e)
396                serialized.append(obj_info)
397        
398        return serialized
399    
400    def _serialize_callbacks(self, callbacks):
401        """
402        Serialize callback configurations.
403        
404        Args:
405            callbacks: List of Keras callbacks or None.
406        
407        Returns:
408            List of serialized callback dictionaries.
409        """
410        import warnings
411        
412        if callbacks is None:
413            return []
414        
415        serialized = []
416        
417        for callback in callbacks:
418            callback_info = {
419                'class': callback.__class__.__name__,
420                'module': callback.__class__.__module__
421            }
422            
423            # Try to extract configuration
424            try:
425                # Get callback config if available
426                if hasattr(callback, 'get_config'):
427                    callback_info['config'] = callback.get_config()
428                else:
429                    # Manually extract common attributes
430                    config = {}
431                    
432                    # EarlyStopping attributes
433                    if hasattr(callback, 'monitor'):
434                        config['monitor'] = callback.monitor
435                    if hasattr(callback, 'patience'):
436                        config['patience'] = callback.patience
437                    if hasattr(callback, 'restore_best_weights'):
438                        config['restore_best_weights'] = callback.restore_best_weights
439                    if hasattr(callback, 'mode'):
440                        config['mode'] = callback.mode
441                    if hasattr(callback, 'min_delta'):
442                        config['min_delta'] = callback.min_delta
443                    
444                    # ReduceLROnPlateau attributes
445                    if hasattr(callback, 'factor'):
446                        config['factor'] = callback.factor
447                    if hasattr(callback, 'cooldown'):
448                        config['cooldown'] = callback.cooldown
449                    if hasattr(callback, 'min_lr'):
450                        config['min_lr'] = callback.min_lr
451                    
452                    # ModelCheckpoint attributes
453                    if hasattr(callback, 'filepath'):
454                        config['filepath'] = callback.filepath
455                    if hasattr(callback, 'save_best_only'):
456                        config['save_best_only'] = callback.save_best_only
457                    if hasattr(callback, 'save_weights_only'):
458                        config['save_weights_only'] = callback.save_weights_only
459                    
460                    # LearningRateScheduler
461                    if hasattr(callback, 'schedule'):
462                        warnings.warn(
463                            f"Callback {callback.__class__.__name__} has a 'schedule' "
464                            f"function that cannot be serialized. Saving class name only.",
465                            UserWarning
466                        )
467                        config['schedule'] = '<function>'
468                    
469                    if config:
470                        callback_info['config'] = config
471                    else:
472                        warnings.warn(
473                            f"Could not extract configuration from callback "
474                            f"{callback.__class__.__name__}. Saving class name only.",
475                            UserWarning
476                        )
477                
478                # Check if config is JSON serializable
479                json.dumps(callback_info)
480                serialized.append(callback_info)
481                
482            except (TypeError, ValueError) as e:
483                warnings.warn(
484                    f"Callback {callback.__class__.__name__} could not be fully "
485                    f"serialized: {str(e)}. Saving class name only.",
486                    UserWarning
487                )
488                serialized.append({
489                    'class': callback.__class__.__name__,
490                    'module': callback.__class__.__module__,
491                    'error': str(e)
492                })
493        
494        return serialized
495    
496    def _save_history(self, output_dir):
497        """Save training history to JSON."""
498        history_dict = {}
499        
500        for key, values in self.history.history.items():
501            # Convert to list and ensure JSON serializable
502            if isinstance(values, np.ndarray):
503                history_dict[key] = values.tolist()
504            elif isinstance(values, list):
505                history_dict[key] = [float(x) for x in values]
506            else:
507                history_dict[key] = values
508        
509        # Serialize callbacks
510        callbacks_info = self._serialize_callbacks(self._callbacks)
511        
512        # Serialize custom objects
513        custom_objects_info = self._serialize_custom_objects()
514        
515        # Add metadata
516        history_dict['metadata'] = {
517            'model_name': self.model_name,
518            'version': self.version,
519            'epochs_trained': len(self.history.history['loss']),
520            'saved_at': datetime.now().isoformat(),
521            'callbacks': callbacks_info,
522            'custom_objects': custom_objects_info,
523            'notes': self._notes,
524            **self.metadata
525        }
526        
527        history_file = output_dir / self.history_path
528        with open(history_file, 'w') as f:
529            json.dump(history_dict, f, indent=4)
530        print(f"History saved: {history_file}")
531    
532    def _save_summary(self, output_dir):
533        """Save human-readable summary."""
534        summary_file = output_dir / self.summary_path
535        
536        with open(summary_file, 'w') as f:
537            f.write(f"{self.model_name} Model Summary\n")
538            f.write("=" * 60 + "\n\n")
539            
540            # Write metadata
541            if self.version:
542                f.write(f"Version: {self.version}\n")
543            for key, value in self.metadata.items():
544                f.write(f"{key}: {value}\n")
545            
546            f.write(f"Model file: {self.model_path}\n")
547            f.write(f"Saved: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
548            
549            # Write notes
550            if self._notes:
551                f.write("Notes / Hypothesis:\n")
552                f.write("-" * 60 + "\n")
553                for note_entry in self._notes:
554                    f.write(f"[{note_entry['timestamp']}] {note_entry['note']}\n")
555                f.write("\n")
556            
557            # Write custom objects information
558            if self.custom_objects:
559                f.write("Custom Objects:\n")
560                f.write("-" * 60 + "\n")
561                custom_objs_info = self._serialize_custom_objects()
562                for obj_info in custom_objs_info:
563                    f.write(f"  {obj_info['name']} ({obj_info['type']}):\n")
564                    if 'module' in obj_info:
565                        f.write(f"    module: {obj_info['module']}\n")
566                    if 'signature' in obj_info:
567                        f.write(f"    signature: {obj_info['signature']}\n")
568                    if 'source_hash' in obj_info:
569                        f.write(f"    source_hash: {obj_info['source_hash']}\n")
570                    if 'docstring' in obj_info:
571                        # First line of docstring only
572                        first_line = obj_info['docstring'].split('\n')[0]
573                        f.write(f"    docstring: {first_line}\n")
574                    f.write("\n")
575            
576            # Write callback information
577            if self._callbacks:
578                f.write("Callbacks Used:\n")
579                f.write("-" * 60 + "\n")
580                callbacks_info = self._serialize_callbacks(self._callbacks)
581                for cb_info in callbacks_info:
582                    f.write(f"  {cb_info['class']}:\n")
583                    if 'config' in cb_info:
584                        for key, value in cb_info['config'].items():
585                            f.write(f"    {key}: {value}\n")
586                    elif 'error' in cb_info:
587                        f.write(f"    (serialization error: {cb_info['error']})\n")
588                    f.write("\n")
589            
590            if self.history:
591                history = self.history.history
592                f.write("Training Results\n")
593                f.write("-" * 60 + "\n")
594                f.write(f"Epochs trained: {len(history['loss'])}\n\n")
595                
596                # Final metrics
597                f.write("Final Metrics:\n")
598                for key in sorted(history.keys()):
599                    if not key.startswith('val_'):
600                        val_key = f'val_{key}'
601                        train_val = history[key][-1]
602                        f.write(f"  {key}: {train_val:.4f}")
603                        if val_key in history:
604                            val_val = history[val_key][-1]
605                            f.write(f" | {val_key}: {val_val:.4f}")
606                        f.write("\n")
607                
608                # Best metrics (find all val_ metrics and report best)
609                f.write("\nBest Validation Metrics:\n")
610                for key in sorted(history.keys()):
611                    if key.startswith('val_'):
612                        # Determine if higher or lower is better
613                        if 'loss' in key or 'error' in key:
614                            best_idx = np.argmin(history[key])
615                            best_val = history[key][best_idx]
616                            f.write(f"  {key}: {best_val:.4f} (epoch {best_idx + 1}, min)\n")
617                        else:
618                            best_idx = np.argmax(history[key])
619                            best_val = history[key][best_idx]
620                            f.write(f"  {key}: {best_val:.4f} (epoch {best_idx + 1}, max)\n")
621        
622        print(f"Summary saved: {summary_file}")
623    
624    def load(self, model_path=None, load_history=True):
625        """
626        Load a saved model and optionally its history.
627        
628        Args:
629            model_path: Path to model file. If None, uses default naming convention.
630            load_history: Whether to load training history JSON. Default is True.
631        
632        Returns:
633            Tuple of (loaded_model, history_dict or None).
634        
635        Examples:
636            >>> # Load with default path
637            >>> manager = KerasModelManager('smith_123456_unet_model_128px')
638            >>> manager.set_model(None, custom_objects={'f1': f1})
639            >>> model, history = manager.load()
640            
641            >>> # Load from specific path
642            >>> model, history = manager.load('models/backup/model.h5')
643            
644            >>> # Load model only, skip history
645            >>> model, _ = manager.load(load_history=False)
646        """
647        if model_path is None:
648            model_path = self.model_path
649        
650        model_path = Path(model_path)
651        
652        if not model_path.exists():
653            raise FileNotFoundError(f"Model file not found: {model_path}")
654        
655        # Load model with custom objects
656        self.model = keras.models.load_model(
657            str(model_path),
658            custom_objects=self.custom_objects
659        )
660        print(f"Model loaded: {model_path}")
661        
662        loaded_history = None
663        
664        # Load history if available
665        if load_history:
666            history_path = model_path.parent / self.history_path
667            if history_path.exists():
668                with open(history_path, 'r') as f:
669                    loaded_history = json.load(f)
670                    print(f"History loaded: {history_path}")
671                    
672                    # Print summary
673                    if 'metadata' in loaded_history:
674                        meta = loaded_history['metadata']
675                        print(f"Trained for {meta.get('epochs_trained', 'unknown')} epochs")
676                        
677                        # Print notes if available
678                        if 'notes' in meta and meta['notes']:
679                            print("\nNotes:")
680                            for note_entry in meta['notes']:
681                                print(f"  [{note_entry['timestamp']}] {note_entry['note']}")
682                        
683                        # Print custom objects if available
684                        if 'custom_objects' in meta and meta['custom_objects']:
685                            print("\nCustom objects used:")
686                            for obj_info in meta['custom_objects']:
687                                print(f"  - {obj_info['name']} ({obj_info['type']})")
688                        
689                        # Print other metadata
690                        for key, val in meta.items():
691                            if key not in ['model_name', 'version', 'epochs_trained', 'saved_at', 'callbacks', 'custom_objects', 'notes']:
692                                print(f"{key}: {val}")
693                    
694                    # Print best validation metrics
695                    for key in loaded_history.keys():
696                        if key.startswith('val_') and key != 'val_loss':
697                            best_val = max(loaded_history[key])
698                            print(f"Best {key}: {best_val:.4f}")
699        
700        return self.model, loaded_history
701    
702    def get_best_metrics(self, metric_preferences=None):
703        """
704        Get best metrics from training history.
705        
706        Args:
707            metric_preferences: Dict mapping metric names to 'min' or 'max'.
708                               If None, uses 'min' for loss/error metrics,
709                               'max' for all others.
710        
711        Returns:
712            Dictionary with best metrics, their epochs, and final values.
713            Keys include 'best_{metric}', 'best_{metric}_epoch', 'final_{metric}'.
714        
715        Examples:
716            >>> metrics = manager.get_best_metrics()
717            >>> print(f"Best F1: {metrics['best_val_f1']:.4f}")
718            >>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}")
719            
720            >>> # Custom metric preferences
721            >>> metrics = manager.get_best_metrics(
722            ...     metric_preferences={'val_custom': 'min'}
723            ... )
724        """
725        if self.history is None:
726            raise ValueError("No training history available.")
727        
728        history = self.history.history
729        metrics = {}
730        
731        # Default preferences
732        if metric_preferences is None:
733            metric_preferences = {}
734        
735        for key in history.keys():
736            if key.startswith('val_'):
737                # Determine optimization direction
738                if key in metric_preferences:
739                    direction = metric_preferences[key]
740                elif 'loss' in key or 'error' in key:
741                    direction = 'min'
742                else:
743                    direction = 'max'
744                
745                # Find best value
746                if direction == 'min':
747                    best_idx = np.argmin(history[key])
748                else:
749                    best_idx = np.argmax(history[key])
750                
751                metrics[f'best_{key}'] = history[key][best_idx]
752                metrics[f'best_{key}_epoch'] = best_idx + 1
753                metrics[f'final_{key}'] = history[key][-1]
754        
755        return metrics
756    
757    def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)):
758        """
759        Plot training history.
760        
761        Args:
762            history_dict: History dict from load(). If None, uses self.history.
763            metrics: List of metrics to plot. If None, plots all metrics.
764            figsize: Figure size tuple.
765        
766        Examples:
767            >>> # After training
768            >>> manager.plot_history()
769            
770            >>> # After loading
771            >>> model, history = manager.load()
772            >>> manager.plot_history(history)
773            
774            >>> # Plot specific metrics
775            >>> manager.plot_history(metrics=['loss', 'f1'])
776        """
777        import matplotlib.pyplot as plt
778        
779        if history_dict is None:
780            if self.history is None:
781                raise ValueError("No history available. Train or load a model first.")
782            history_dict = self.history.history
783        
784        # Filter out metadata
785        history_dict = {k: v for k, v in history_dict.items() if k != 'metadata'}
786        
787        # Determine metrics to plot
788        if metrics is None:
789            # Plot all non-validation metrics
790            metrics = [k for k in history_dict.keys() if not k.startswith('val_')]
791        
792        n_metrics = len(metrics)
793        fig, axes = plt.subplots(1, n_metrics, figsize=figsize)
794        
795        # Handle single metric case
796        if n_metrics == 1:
797            axes = [axes]
798        
799        for ax, metric in zip(axes, metrics):
800            val_metric = f'val_{metric}'
801            
802            ax.plot(history_dict[metric], label=f'Training {metric}')
803            if val_metric in history_dict:
804                ax.plot(history_dict[val_metric], label=f'Validation {metric}')
805            
806            ax.set_xlabel('Epoch')
807            ax.set_ylabel(metric.replace('_', ' ').title())
808            ax.set_title(f'{metric.replace("_", " ").title()} over Epochs')
809            ax.legend()
810            ax.grid(True, alpha=0.3)
811        
812        plt.tight_layout()
813        plt.show()

Manages Keras model lifecycle: training, saving, loading, and metadata.

KerasModelManager(model_name, version=None, notes=None, **metadata)
128    def __init__(self, model_name, version=None, notes=None, **metadata):
129        """
130        Initialize the model manager.
131        
132        Args:
133            model_name: Base name for the model files. For Task 5, use format:
134                       'studentname_studentnumber_modeltype_patchsizepx'
135            version: Optional version string or number (e.g., '1.0', 'v2', 2).
136            notes: Optional initial note/hypothesis for this model.
137            **metadata: Additional metadata to store with the model. Common examples:
138                       - student_name (str): Student name for academic projects
139                       - student_number (str): Student number for academic projects
140                       - architecture (str): Model architecture (e.g., 'U-Net', 'ResNet50')
141                       - patch_size (int): Patch size for segmentation models
142                       - input_shape (tuple): Input shape for the model
143                       - num_classes (int): Number of output classes
144                       - dataset (str): Dataset name
145                       - pretrained (str): Pretrained weights source
146                       - task (str): Task description
147        
148        Examples:
149            >>> # U-Net for segmentation with notes
150            >>> manager = KerasModelManager(
151            ...     'smith_123456_unet_model_128px',
152            ...     notes="Hypothesis: 128px patches will achieve F1 > 0.90",
153            ...     patch_size=128,
154            ...     architecture='U-Net'
155            ... )
156            
157            >>> # Versioned CNN classifier
158            >>> manager = KerasModelManager(
159            ...     'mnist_cnn',
160            ...     version='1.0',
161            ...     dataset='MNIST',
162            ...     input_shape=(28, 28, 1)
163            ... )
164        """
165        self.model_name = model_name
166        self.version = version
167        self.metadata = metadata
168        self.model = None
169        self.history = None
170        self.custom_objects = {}
171        self._callbacks = None
172        self._notes = []
173        
174        # Add initial note if provided
175        if notes:
176            self._add_note(notes)
177        
178        # Generate filenames
179        version_str = f"_v{version}" if version else ""
180        self.base_name = f"{model_name}{version_str}"
181        self.model_path = f"{self.base_name}.h5"
182        self.history_path = f"{self.base_name}_history.json"
183        self.summary_path = f"{self.base_name}_summary.txt"

Initialize the model manager.

Arguments:
  • model_name: Base name for the model files. For Task 5, use format: 'studentname_studentnumber_modeltype_patchsizepx'
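  • version: Optional version string or number (e.g., '1.0', '2', 2). It is appended to all filenames as '_v{version}', so pass '2' rather than 'v2' (which would produce '_vv2'). A falsy value (None, '', 0) adds no version suffix.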
  • notes: Optional initial note/hypothesis for this model.
  • **metadata: Additional metadata to store with the model. Common examples:
    • student_name (str): Student name for academic projects
    • student_number (str): Student number for academic projects
    • architecture (str): Model architecture (e.g., 'U-Net', 'ResNet50')
    • patch_size (int): Patch size for segmentation models
    • input_shape (tuple): Input shape for the model
    • num_classes (int): Number of output classes
    • dataset (str): Dataset name
    • pretrained (str): Pretrained weights source
    • task (str): Task description
Examples:
>>> # U-Net for segmentation with notes
>>> manager = KerasModelManager(
...     'smith_123456_unet_model_128px',
...     notes="Hypothesis: 128px patches will achieve F1 > 0.90",
...     patch_size=128,
...     architecture='U-Net'
... )
>>> # Versioned CNN classifier
>>> manager = KerasModelManager(
...     'mnist_cnn',
...     version='1.0',
...     dataset='MNIST',
...     input_shape=(28, 28, 1)
... )
model_name
version
metadata
model
history
custom_objects
base_name
model_path
history_path
summary_path
def set_model(self, model, custom_objects=None):
185    def set_model(self, model, custom_objects=None):
186        """
187        Set the model to be managed.
188        
189        Args:
190            model: Compiled Keras model, or None if only loading.
191            custom_objects: Dict of custom objects needed for model loading
192                           (e.g., {'f1': f1_function, 'dice_loss': dice_loss_fn}).
193        
194        Examples:
195            >>> # With custom metric
196            >>> manager.set_model(unet_model, custom_objects={'f1': f1})
197            
198            >>> # Standard model, no custom objects
199            >>> manager.set_model(resnet_model)
200            
201            >>> # Preparing to load (no model yet)
202            >>> manager.set_model(None, custom_objects={'f1': f1})
203        """
204        self.model = model
205        if custom_objects:
206            self.custom_objects = custom_objects

Set the model to be managed.

Arguments:
  • model: Compiled Keras model, or None if only loading.
  • custom_objects: Dict of custom objects needed for model loading (e.g., {'f1': f1_function, 'dice_loss': dice_loss_fn}). If None or empty, any previously set custom objects are kept unchanged.
Examples:
>>> # With custom metric
>>> manager.set_model(unet_model, custom_objects={'f1': f1})
>>> # Standard model, no custom objects
>>> manager.set_model(resnet_model)
>>> # Preparing to load (no model yet)
>>> manager.set_model(None, custom_objects={'f1': f1})
notes
216    @property
217    def notes(self):
218        """Get all notes as a formatted string."""
219        if not self._notes:
220            return ""
221        return "\n".join([f"[{n['timestamp']}] {n['note']}" for n in self._notes])

Get all notes as a formatted string, one '[timestamp] note' line per entry. Returns an empty string if there are no notes.

def set_notes(self, note, replace=False):
228    def set_notes(self, note, replace=False):
229        """
230        Set notes with option to replace all existing notes.
231        
232        Args:
233            note: The note text to add or set.
234            replace: If True, replaces all existing notes. If False, appends.
235        
236        Examples:
237            >>> # Append (default)
238            >>> manager.set_notes("Added dropout layers")
239            
240            >>> # Replace all notes
241            >>> manager.set_notes("Starting fresh experiment", replace=True)
242        """
243        if replace:
244            self._notes = []
245        self._add_note(note)

Set notes with option to replace all existing notes.

Arguments:
  • note: The note text to add or set.
  • replace: If True, replaces all existing notes. If False, appends.
Examples:
>>> # Append (default)
>>> manager.set_notes("Added dropout layers")
>>> # Replace all notes
>>> manager.set_notes("Starting fresh experiment", replace=True)
def clear_notes(self):
247    def clear_notes(self):
248        """Clear all notes."""
249        self._notes = []

Clear all notes.

def get_notes_list(self):
251    def get_notes_list(self):
252        """
253        Get notes as a list of dictionaries.
254        
255        Returns:
256            List of dicts with 'timestamp' and 'note' keys.
257        """
258        return self._notes.copy()

Get notes as a list of dictionaries.

Returns:

List of dicts with 'timestamp' and 'note' keys.

def train( self, X_train, y_train=None, X_val=None, y_val=None, epochs=50, batch_size=16, callbacks=None, verbose=1, **fit_kwargs):
260    def train(self, X_train, y_train=None, X_val=None, y_val=None, epochs=50, 
261            batch_size=16, callbacks=None, verbose=1, **fit_kwargs):
262        """
263        Train the model and store history.
264        
265        Args:
266            X_train: Training data (numpy array, generator, or dataset).
267            y_train: Training labels (numpy array, or None if using generator).
268            X_val: Optional validation data (or None if using validation_data kwarg).
269            y_val: Optional validation labels (or None if using validation_data kwarg).
270            epochs: Maximum number of epochs.
271            batch_size: Batch size for training (ignored if using generator).
272            callbacks: List of Keras callbacks (e.g., EarlyStopping, ModelCheckpoint).
273            verbose: Verbosity level (0=silent, 1=progress bar, 2=one line per epoch).
274            **fit_kwargs: Additional arguments passed to model.fit()
275                        (e.g., steps_per_epoch, validation_steps, validation_data).
276        
277        Returns:
278            Training history object.
279        """
280        if self.model is None:
281            raise ValueError("Model not set. Call set_model() first.")
282        
283        # Store callbacks for later serialization
284        self._callbacks = callbacks
285        
286        # Handle validation_data - check if it's in fit_kwargs first
287        validation_data = fit_kwargs.pop('validation_data', None)
288        
289        # If not in fit_kwargs and we have X_val/y_val, construct it
290        if validation_data is None and X_val is not None and y_val is not None:
291            validation_data = (X_val, y_val)
292        
293        self.history = self.model.fit(
294            X_train, y_train,
295            validation_data=validation_data,
296            batch_size=batch_size,
297            epochs=epochs,
298            callbacks=callbacks,
299            verbose=verbose,
300            **fit_kwargs
301        )
302        
303        return self.history

Train the model and store history.

Arguments:
  • X_train: Training data (numpy array, generator, or dataset).
  • y_train: Training labels (numpy array, or None if using generator).
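  • X_val: Optional validation data (or None if using the validation_data kwarg). Used only when y_val is also given and no validation_data kwarg is passed; a validation_data kwarg always takes precedence.
  • y_val: Optional validation labels (or None if using the validation_data kwarg). Ignored unless X_val is also given.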
  • epochs: Maximum number of epochs.
  • batch_size: Batch size for training (ignored if using generator).
  • callbacks: List of Keras callbacks (e.g., EarlyStopping, ModelCheckpoint).
  • verbose: Verbosity level (0=silent, 1=progress bar, 2=one line per epoch).
  • **fit_kwargs: Additional arguments passed to model.fit() (e.g., steps_per_epoch, validation_steps, validation_data).
Returns:

The Keras History object returned by model.fit(). It is also stored in self.history.

def save(self, output_dir='.'):
305    def save(self, output_dir='.'):
306        """
307        Save model, history, and summary to disk.
308        
309        Saves three files:
310        1. .h5 file - Complete model (architecture, weights, optimizer state)
311        2. _history.json - Training metrics for all epochs
312        3. _summary.txt - Human-readable summary
313        
314        Args:
315            output_dir: Directory to save files. Created if it doesn't exist.
316                       Default is current directory.
317        
318        Examples:
319            >>> # Save to current directory
320            >>> manager.save()
321            
322            >>> # Save to specific directory
323            >>> manager.save('trained_models')
324            
325            >>> # Save to nested path
326            >>> manager.save('models/unet/version_1')
327        """
328        if self.model is None:
329            raise ValueError("No model to save. Train or set a model first.")
330        
331        output_dir = Path(output_dir)
332        output_dir.mkdir(parents=True, exist_ok=True)
333        
334        # Save model
335        model_file = output_dir / self.model_path
336        self.model.save(str(model_file))
337        print(f"Model saved: {model_file}")
338        
339        # Save history if available
340        if self.history is not None:
341            self._save_history(output_dir)
342            self._save_summary(output_dir)

Save model, history, and summary to disk.

Saves up to three files (the history and summary files are written only if training history is available):

  1. .h5 file - Complete model (architecture, weights, optimizer state)
  2. _history.json - Training metrics for all epochs
  3. _summary.txt - Human-readable summary
Arguments:
  • output_dir: Directory to save files. Created if it doesn't exist. Default is current directory.
Examples:
>>> # Save to current directory
>>> manager.save()
>>> # Save to specific directory
>>> manager.save('trained_models')
>>> # Save to nested path
>>> manager.save('models/unet/version_1')
def load(self, model_path=None, load_history=True):
624    def load(self, model_path=None, load_history=True):
625        """
626        Load a saved model and optionally its history.
627        
628        Args:
629            model_path: Path to model file. If None, uses default naming convention.
630            load_history: Whether to load training history JSON. Default is True.
631        
632        Returns:
633            Tuple of (loaded_model, history_dict or None).
634        
635        Examples:
636            >>> # Load with default path
637            >>> manager = KerasModelManager('smith_123456_unet_model_128px')
638            >>> manager.set_model(None, custom_objects={'f1': f1})
639            >>> model, history = manager.load()
640            
641            >>> # Load from specific path
642            >>> model, history = manager.load('models/backup/model.h5')
643            
644            >>> # Load model only, skip history
645            >>> model, _ = manager.load(load_history=False)
646        """
647        if model_path is None:
648            model_path = self.model_path
649        
650        model_path = Path(model_path)
651        
652        if not model_path.exists():
653            raise FileNotFoundError(f"Model file not found: {model_path}")
654        
655        # Load model with custom objects
656        self.model = keras.models.load_model(
657            str(model_path),
658            custom_objects=self.custom_objects
659        )
660        print(f"Model loaded: {model_path}")
661        
662        loaded_history = None
663        
664        # Load history if available
665        if load_history:
666            history_path = model_path.parent / self.history_path
667            if history_path.exists():
668                with open(history_path, 'r') as f:
669                    loaded_history = json.load(f)
670                    print(f"History loaded: {history_path}")
671                    
672                    # Print summary
673                    if 'metadata' in loaded_history:
674                        meta = loaded_history['metadata']
675                        print(f"Trained for {meta.get('epochs_trained', 'unknown')} epochs")
676                        
677                        # Print notes if available
678                        if 'notes' in meta and meta['notes']:
679                            print("\nNotes:")
680                            for note_entry in meta['notes']:
681                                print(f"  [{note_entry['timestamp']}] {note_entry['note']}")
682                        
683                        # Print custom objects if available
684                        if 'custom_objects' in meta and meta['custom_objects']:
685                            print("\nCustom objects used:")
686                            for obj_info in meta['custom_objects']:
687                                print(f"  - {obj_info['name']} ({obj_info['type']})")
688                        
689                        # Print other metadata
690                        for key, val in meta.items():
691                            if key not in ['model_name', 'version', 'epochs_trained', 'saved_at', 'callbacks', 'custom_objects', 'notes']:
692                                print(f"{key}: {val}")
693                    
694                    # Print best validation metrics
695                    for key in loaded_history.keys():
696                        if key.startswith('val_') and key != 'val_loss':
697                            best_val = max(loaded_history[key])
698                            print(f"Best {key}: {best_val:.4f}")
699        
700        return self.model, loaded_history

Load a saved model and optionally its history.

Arguments:
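  • model_path: Path to the model file. If None, uses the default filename ('{model_name}.h5', or '{model_name}_v{version}.h5' when a version is set). This filename is resolved relative to the current working directory, not the directory passed to save().
  • load_history: Whether to load the training history JSON, which is looked up in the same directory as the model file under the default history filename. Default is True. The loaded history is returned but is not stored in self.history.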
Returns:

Tuple of (loaded_model, history_dict or None).

Examples:
>>> # Load with default path
>>> manager = KerasModelManager('smith_123456_unet_model_128px')
>>> manager.set_model(None, custom_objects={'f1': f1})
>>> model, history = manager.load()
>>> # Load from specific path
>>> model, history = manager.load('models/backup/model.h5')
>>> # Load model only, skip history
>>> model, _ = manager.load(load_history=False)
def get_best_metrics(self, metric_preferences=None):
702    def get_best_metrics(self, metric_preferences=None):
703        """
704        Get best metrics from training history.
705        
706        Args:
707            metric_preferences: Dict mapping metric names to 'min' or 'max'.
708                               If None, uses 'min' for loss/error metrics,
709                               'max' for all others.
710        
711        Returns:
712            Dictionary with best metrics, their epochs, and final values.
713            Keys include 'best_{metric}', 'best_{metric}_epoch', 'final_{metric}'.
714        
715        Examples:
716            >>> metrics = manager.get_best_metrics()
717            >>> print(f"Best F1: {metrics['best_val_f1']:.4f}")
718            >>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}")
719            
720            >>> # Custom metric preferences
721            >>> metrics = manager.get_best_metrics(
722            ...     metric_preferences={'val_custom': 'min'}
723            ... )
724        """
725        if self.history is None:
726            raise ValueError("No training history available.")
727        
728        history = self.history.history
729        metrics = {}
730        
731        # Default preferences
732        if metric_preferences is None:
733            metric_preferences = {}
734        
735        for key in history.keys():
736            if key.startswith('val_'):
737                # Determine optimization direction
738                if key in metric_preferences:
739                    direction = metric_preferences[key]
740                elif 'loss' in key or 'error' in key:
741                    direction = 'min'
742                else:
743                    direction = 'max'
744                
745                # Find best value
746                if direction == 'min':
747                    best_idx = np.argmin(history[key])
748                else:
749                    best_idx = np.argmax(history[key])
750                
751                metrics[f'best_{key}'] = history[key][best_idx]
752                metrics[f'best_{key}_epoch'] = best_idx + 1
753                metrics[f'final_{key}'] = history[key][-1]
754        
755        return metrics

Get best metrics from training history.

Arguments:
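  • metric_preferences: Optional dict mapping validation metric names (including the 'val_' prefix, e.g. 'val_f1') to 'min' or 'max'. Metrics not listed, or all metrics if None, default to 'min' when the name contains 'loss' or 'error', and to 'max' otherwise.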
Returns:

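Dictionary with, for each validation ('val_'-prefixed) metric in the training history, its best value, the 1-based epoch at which that value occurred, and its final value. Keys are 'best_{metric}', 'best_{metric}_epoch' and 'final_{metric}', where {metric} keeps the 'val_' prefix (e.g. 'best_val_f1'). Raises ValueError if no training history is available (load() does not set it).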

Examples:
>>> metrics = manager.get_best_metrics()
>>> print(f"Best F1: {metrics['best_val_f1']:.4f}")
>>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}")
>>> # Custom metric preferences
>>> metrics = manager.get_best_metrics(
...     metric_preferences={'val_custom': 'min'}
... )
def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)):
757    def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)):
758        """
759        Plot training history.
760        
761        Args:
762            history_dict: History dict from load(). If None, uses self.history.
763            metrics: List of metrics to plot. If None, plots all metrics.
764            figsize: Figure size tuple.
765        
766        Examples:
767            >>> # After training
768            >>> manager.plot_history()
769            
770            >>> # After loading
771            >>> model, history = manager.load()
772            >>> manager.plot_history(history)
773            
774            >>> # Plot specific metrics
775            >>> manager.plot_history(metrics=['loss', 'f1'])
776        """
777        import matplotlib.pyplot as plt
778        
779        if history_dict is None:
780            if self.history is None:
781                raise ValueError("No history available. Train or load a model first.")
782            history_dict = self.history.history
783        
784        # Filter out metadata
785        history_dict = {k: v for k, v in history_dict.items() if k != 'metadata'}
786        
787        # Determine metrics to plot
788        if metrics is None:
789            # Plot all non-validation metrics
790            metrics = [k for k in history_dict.keys() if not k.startswith('val_')]
791        
792        n_metrics = len(metrics)
793        fig, axes = plt.subplots(1, n_metrics, figsize=figsize)
794        
795        # Handle single metric case
796        if n_metrics == 1:
797            axes = [axes]
798        
799        for ax, metric in zip(axes, metrics):
800            val_metric = f'val_{metric}'
801            
802            ax.plot(history_dict[metric], label=f'Training {metric}')
803            if val_metric in history_dict:
804                ax.plot(history_dict[val_metric], label=f'Validation {metric}')
805            
806            ax.set_xlabel('Epoch')
807            ax.set_ylabel(metric.replace('_', ' ').title())
808            ax.set_title(f'{metric.replace("_", " ").title()} over Epochs')
809            ax.legend()
810            ax.grid(True, alpha=0.3)
811        
812        plt.tight_layout()
813        plt.show()

Plot training history.

Arguments:
  • history_dict: History dict from load(). If None, uses self.history.
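  • metrics: List of training metric names to plot (e.g. ['loss', 'f1']). Each gets its own subplot, with its 'val_' counterpart overlaid when present. If None, plots every non-validation metric in the history.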
  • figsize: Figure size tuple.
Examples:
>>> # After training
>>> manager.plot_history()
>>> # After loading
>>> model, history = manager.load()
>>> manager.plot_history(history)
>>> # Plot specific metrics
>>> manager.plot_history(metrics=['loss', 'f1'])