# library/model_manager.py
"""Class for managing Keras model training, saving, and loading with metadata.

This manager handles the complete lifecycle of Keras models including training,
saving, loading, and metadata management. It works with any Keras model architecture
including CNNs, U-Nets, ResNets, and custom architectures.

Examples:
    U-Net for image segmentation (Task 5)::

        from library.model_manager import KerasModelManager
        from tensorflow.keras.callbacks import EarlyStopping

        # Initialize with Task 5 naming convention
        manager = KerasModelManager(
            model_name='smith_123456_unet_model_128px',
            student_name='John Smith',
            student_number='123456',
            patch_size=128,
            architecture='U-Net',
            task='root_segmentation'
        )

        # Set model with custom metrics
        manager.set_model(unet_model, custom_objects={'f1': f1})

        # Train with early stopping
        early_stop = EarlyStopping(monitor='val_loss', patience=3,
                                   restore_best_weights=True)
        manager.train(X_train, y_train, X_val, y_val,
                      epochs=50, callbacks=[early_stop])

        # Save everything
        manager.save('models')

        # Get best metrics
        metrics = manager.get_best_metrics()
        print(f"Best F1: {metrics['best_val_f1']:.4f}")

    CNN classifier for MNIST::

        manager = KerasModelManager(
            model_name='mnist_cnn',
            version='1.0',
            dataset='MNIST',
            input_shape=(28, 28, 1),
            num_classes=10,
            architecture='CNN'
        )

        manager.set_model(cnn_model)
        manager.train(X_train, y_train, X_val, y_val, epochs=20)
        manager.save('models')

    ResNet50 for plant classification::

        manager = KerasModelManager(
            model_name='resnet50_plants',
            version='2.1',
            architecture='ResNet50',
            dataset='PlantCLEF',
            pretrained='ImageNet',
            fine_tuned_layers=10,
            input_size=224
        )

        manager.set_model(resnet_model)
        manager.train(train_generator, steps_per_epoch=100,
                      validation_data=val_generator, validation_steps=20,
                      epochs=30)
        manager.save('models')

    LSTM for time series prediction::

        manager = KerasModelManager(
            model_name='stock_lstm',
            version='1.0',
            architecture='LSTM',
            sequence_length=60,
            features=5,
            prediction_horizon=1
        )

        manager.set_model(lstm_model)
        manager.train(X_train, y_train, X_val, y_val, epochs=50, batch_size=32)
        manager.save('models')

    Loading a saved model::

        # Initialize manager with same naming
        manager = KerasModelManager(
            model_name='smith_123456_unet_model_128px'
        )

        # Set custom objects before loading
        manager.set_model(None, custom_objects={'f1': f1})

        # Load model and history
        model, history = manager.load('models/smith_123456_unet_model_128px.h5')

        # Use loaded model
        predictions = model.predict(X_test)

Attributes:
    model_name (str): Base name for the model files.
    version (str): Optional version string or number.
    metadata (dict): Additional metadata stored with the model.
    model (keras.Model): The Keras model being managed.
    history (keras.callbacks.History): Training history from model.fit().
    custom_objects (dict): Custom objects needed for model loading.
    base_name (str): Generated base filename.
    model_path (str): Path to saved model file.
    history_path (str): Path to saved history JSON.
    summary_path (str): Path to saved summary text file.
"""

import json
import numpy as np
from pathlib import Path
from datetime import datetime
from tensorflow import keras


class KerasModelManager:
    """Manages Keras model lifecycle: training, saving, loading, and metadata."""

    def __init__(self, model_name, version=None, notes=None, **metadata):
        """
        Initialize the model manager.

        Args:
            model_name: Base name for the model files. For Task 5, use format:
                'studentname_studentnumber_modeltype_patchsizepx'
            version: Optional version string or number (e.g., '1.0', 'v2', 2).
            notes: Optional initial note/hypothesis for this model.
            **metadata: Additional metadata to store with the model. Common examples:
                - student_name (str): Student name for academic projects
                - student_number (str): Student number for academic projects
                - architecture (str): Model architecture (e.g., 'U-Net', 'ResNet50')
                - patch_size (int): Patch size for segmentation models
                - input_shape (tuple): Input shape for the model
                - num_classes (int): Number of output classes
                - dataset (str): Dataset name
                - pretrained (str): Pretrained weights source
                - task (str): Task description

        Examples:
            >>> # U-Net for segmentation with notes
            >>> manager = KerasModelManager(
            ...     'smith_123456_unet_model_128px',
            ...     notes="Hypothesis: 128px patches will achieve F1 > 0.90",
            ...     patch_size=128,
            ...     architecture='U-Net'
            ... )

            >>> # Versioned CNN classifier
            >>> manager = KerasModelManager(
            ...     'mnist_cnn',
            ...     version='1.0',
            ...     dataset='MNIST',
            ...     input_shape=(28, 28, 1)
            ... )
        """
        self.model_name = model_name
        self.version = version
        self.metadata = metadata
        self.model = None
        self.history = None
        self.custom_objects = {}
        self._callbacks = None
        self._notes = []

        # Add initial note if provided
        if notes:
            self._add_note(notes)

        # Generate filenames
        version_str = f"_v{version}" if version else ""
        self.base_name = f"{model_name}{version_str}"
        self.model_path = f"{self.base_name}.h5"
        self.history_path = f"{self.base_name}_history.json"
        self.summary_path = f"{self.base_name}_summary.txt"

    def set_model(self, model, custom_objects=None):
        """
        Set the model to be managed.

        Args:
            model: Compiled Keras model, or None if only loading.
            custom_objects: Dict of custom objects needed for model loading
                (e.g., {'f1': f1_function, 'dice_loss': dice_loss_fn}).
        Examples:
            >>> # With custom metric
            >>> manager.set_model(unet_model, custom_objects={'f1': f1})

            >>> # Standard model, no custom objects
            >>> manager.set_model(resnet_model)

            >>> # Preparing to load (no model yet)
            >>> manager.set_model(None, custom_objects={'f1': f1})
        """
        self.model = model
        if custom_objects:
            self.custom_objects = custom_objects

    def _add_note(self, note):
        """Internal method to add a timestamped note."""
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        self._notes.append({
            'timestamp': timestamp,
            'note': note
        })

    @property
    def notes(self):
        """Get all notes as a formatted string."""
        if not self._notes:
            return ""
        return "\n".join([f"[{n['timestamp']}] {n['note']}" for n in self._notes])

    @notes.setter
    def notes(self, note):
        """Append a new note with automatic timestamp."""
        self._add_note(note)

    def set_notes(self, note, replace=False):
        """
        Set notes with option to replace all existing notes.

        Args:
            note: The note text to add or set.
            replace: If True, replaces all existing notes. If False, appends.

        Examples:
            >>> # Append (default)
            >>> manager.set_notes("Added dropout layers")

            >>> # Replace all notes
            >>> manager.set_notes("Starting fresh experiment", replace=True)
        """
        if replace:
            self._notes = []
        self._add_note(note)

    def clear_notes(self):
        """Clear all notes."""
        self._notes = []

    def get_notes_list(self):
        """
        Get notes as a list of dictionaries.

        Returns:
            List of dicts with 'timestamp' and 'note' keys.
        """
        return self._notes.copy()

    def train(self, X_train, y_train=None, X_val=None, y_val=None, epochs=50,
              batch_size=16, callbacks=None, verbose=1, **fit_kwargs):
        """
        Train the model and store history.

        Args:
            X_train: Training data (numpy array, generator, or dataset).
            y_train: Training labels (numpy array, or None if using generator).
            X_val: Optional validation data (or None if using validation_data kwarg).
            y_val: Optional validation labels (or None if using validation_data kwarg).
            epochs: Maximum number of epochs.
            batch_size: Batch size for training (omitted automatically when
                training from a generator or dataset, which already yield batches).
            callbacks: List of Keras callbacks (e.g., EarlyStopping, ModelCheckpoint).
            verbose: Verbosity level (0=silent, 1=progress bar, 2=one line per epoch).
            **fit_kwargs: Additional arguments passed to model.fit()
                (e.g., steps_per_epoch, validation_steps, validation_data).

        Returns:
            Training history object.
        """
        if self.model is None:
            raise ValueError("Model not set. Call set_model() first.")

        # Store callbacks for later serialization
        self._callbacks = callbacks

        # Handle validation_data - check if it's in fit_kwargs first
        validation_data = fit_kwargs.pop('validation_data', None)

        # If not in fit_kwargs and we have X_val/y_val, construct it
        if validation_data is None and X_val is not None and y_val is not None:
            validation_data = (X_val, y_val)

        # Keras rejects an explicit batch_size alongside generator/dataset
        # inputs (they already produce batches), so only forward it for array
        # data. A present y_train is used as the array-input heuristic.
        if y_train is not None:
            fit_kwargs.setdefault('batch_size', batch_size)

        self.history = self.model.fit(
            X_train, y_train,
            validation_data=validation_data,
            epochs=epochs,
            callbacks=callbacks,
            verbose=verbose,
            **fit_kwargs
        )

        return self.history

    def save(self, output_dir='.'):
        """
        Save model, history, and summary to disk.

        Saves three files:
        1. .h5 file - Complete model (architecture, weights, optimizer state)
        2. _history.json - Training metrics for all epochs
        3. _summary.txt - Human-readable summary

        Args:
            output_dir: Directory to save files. Created if it doesn't exist.
                Default is current directory.
        Examples:
            >>> # Save to current directory
            >>> manager.save()

            >>> # Save to specific directory
            >>> manager.save('trained_models')

            >>> # Save to nested path
            >>> manager.save('models/unet/version_1')
        """
        if self.model is None:
            raise ValueError("No model to save. Train or set a model first.")

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Save model
        model_file = output_dir / self.model_path
        self.model.save(str(model_file))
        print(f"Model saved: {model_file}")

        # Save history if available
        if self.history is not None:
            self._save_history(output_dir)
            self._save_summary(output_dir)

    def _serialize_custom_objects(self):
        """
        Serialize information about custom objects for documentation.

        Returns:
            List of dictionaries with custom object metadata.
        """
        import inspect
        import hashlib

        if not self.custom_objects:
            return []

        serialized = []

        for name, obj in self.custom_objects.items():
            obj_info = {
                'name': name,
                'type': type(obj).__name__
            }

            # Try to get useful information about the object
            try:
                # Get module
                if hasattr(obj, '__module__'):
                    obj_info['module'] = obj.__module__

                # Get docstring
                if hasattr(obj, '__doc__') and obj.__doc__:
                    obj_info['docstring'] = obj.__doc__.strip()

                # For functions, get signature
                if callable(obj):
                    try:
                        sig = inspect.signature(obj)
                        obj_info['signature'] = str(sig)
                    except (ValueError, TypeError):
                        pass

                # Get source code hash (for version tracking)
                try:
                    source = inspect.getsource(obj)
                    source_hash = hashlib.md5(source.encode()).hexdigest()
                    obj_info['source_hash'] = source_hash
                except (OSError, TypeError):
                    pass

                serialized.append(obj_info)

            except Exception as e:
                # If we can't serialize, at least save the name and type
                obj_info['error'] = str(e)
                serialized.append(obj_info)

        return serialized

    def _serialize_callbacks(self, callbacks):
        """
        Serialize callback configurations.

        Args:
            callbacks: List of Keras callbacks or None.

        Returns:
            List of serialized callback dictionaries.
        """
        import warnings

        if callbacks is None:
            return []

        serialized = []

        for callback in callbacks:
            callback_info = {
                'class': callback.__class__.__name__,
                'module': callback.__class__.__module__
            }

            # Try to extract configuration
            try:
                # Get callback config if available
                if hasattr(callback, 'get_config'):
                    callback_info['config'] = callback.get_config()
                else:
                    # Manually extract common attributes
                    config = {}

                    # EarlyStopping attributes
                    if hasattr(callback, 'monitor'):
                        config['monitor'] = callback.monitor
                    if hasattr(callback, 'patience'):
                        config['patience'] = callback.patience
                    if hasattr(callback, 'restore_best_weights'):
                        config['restore_best_weights'] = callback.restore_best_weights
                    if hasattr(callback, 'mode'):
                        config['mode'] = callback.mode
                    if hasattr(callback, 'min_delta'):
                        config['min_delta'] = callback.min_delta

                    # ReduceLROnPlateau attributes
                    if hasattr(callback, 'factor'):
                        config['factor'] = callback.factor
                    if hasattr(callback, 'cooldown'):
                        config['cooldown'] = callback.cooldown
                    if hasattr(callback, 'min_lr'):
                        config['min_lr'] = callback.min_lr

                    # ModelCheckpoint attributes
                    if hasattr(callback, 'filepath'):
                        config['filepath'] = callback.filepath
                    if hasattr(callback, 'save_best_only'):
                        config['save_best_only'] = callback.save_best_only
                    if hasattr(callback, 'save_weights_only'):
                        config['save_weights_only'] = callback.save_weights_only

                    # LearningRateScheduler
                    if hasattr(callback, 'schedule'):
                        warnings.warn(
                            f"Callback {callback.__class__.__name__} has a 'schedule' "
                            f"function that cannot be serialized. Saving class name only.",
                            UserWarning
                        )
                        config['schedule'] = '<function>'

                    if config:
                        callback_info['config'] = config
                    else:
                        warnings.warn(
                            f"Could not extract configuration from callback "
                            f"{callback.__class__.__name__}. Saving class name only.",
                            UserWarning
                        )

                # Check if config is JSON serializable
                json.dumps(callback_info)
                serialized.append(callback_info)

            except (TypeError, ValueError) as e:
                warnings.warn(
                    f"Callback {callback.__class__.__name__} could not be fully "
                    f"serialized: {str(e)}. Saving class name only.",
                    UserWarning
                )
                serialized.append({
                    'class': callback.__class__.__name__,
                    'module': callback.__class__.__module__,
                    'error': str(e)
                })

        return serialized

    def _save_history(self, output_dir):
        """Save training history to JSON."""
        history_dict = {}

        for key, values in self.history.history.items():
            # Convert to list and ensure JSON serializable
            if isinstance(values, np.ndarray):
                history_dict[key] = values.tolist()
            elif isinstance(values, list):
                history_dict[key] = [float(x) for x in values]
            else:
                history_dict[key] = values

        # Serialize callbacks
        callbacks_info = self._serialize_callbacks(self._callbacks)

        # Serialize custom objects
        custom_objects_info = self._serialize_custom_objects()

        # Add metadata
        history_dict['metadata'] = {
            'model_name': self.model_name,
            'version': self.version,
            'epochs_trained': len(self.history.history['loss']),
            'saved_at': datetime.now().isoformat(),
            'callbacks': callbacks_info,
            'custom_objects': custom_objects_info,
            'notes': self._notes,
            **self.metadata
        }

        history_file = output_dir / self.history_path
        with open(history_file, 'w') as f:
            json.dump(history_dict, f, indent=4)
        print(f"History saved: {history_file}")

    def _save_summary(self, output_dir):
        """Save human-readable summary."""
        summary_file = output_dir / self.summary_path

        with open(summary_file, 'w') as f:
            f.write(f"{self.model_name} Model Summary\n")
            f.write("=" * 60 + "\n\n")

            # Write metadata
            if self.version:
                f.write(f"Version: {self.version}\n")
            for key, value in self.metadata.items():
                f.write(f"{key}: {value}\n")

            f.write(f"Model file: {self.model_path}\n")
            f.write(f"Saved: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            # Write notes
            if self._notes:
                f.write("Notes / Hypothesis:\n")
                f.write("-" * 60 + "\n")
                for note_entry in self._notes:
                    f.write(f"[{note_entry['timestamp']}] {note_entry['note']}\n")
                f.write("\n")

            # Write custom objects information
            if self.custom_objects:
                f.write("Custom Objects:\n")
                f.write("-" * 60 + "\n")
                custom_objs_info = self._serialize_custom_objects()
                for obj_info in custom_objs_info:
                    f.write(f"  {obj_info['name']} ({obj_info['type']}):\n")
                    if 'module' in obj_info:
                        f.write(f"    module: {obj_info['module']}\n")
                    if 'signature' in obj_info:
                        f.write(f"    signature: {obj_info['signature']}\n")
                    if 'source_hash' in obj_info:
                        f.write(f"    source_hash: {obj_info['source_hash']}\n")
                    if 'docstring' in obj_info:
                        # First line of docstring only
                        first_line = obj_info['docstring'].split('\n')[0]
                        f.write(f"    docstring: {first_line}\n")
                f.write("\n")

            # Write callback information
            if self._callbacks:
                f.write("Callbacks Used:\n")
                f.write("-" * 60 + "\n")
                callbacks_info = self._serialize_callbacks(self._callbacks)
                for cb_info in callbacks_info:
                    f.write(f"  {cb_info['class']}:\n")
                    if 'config' in cb_info:
                        for key, value in cb_info['config'].items():
                            f.write(f"    {key}: {value}\n")
                    elif 'error' in cb_info:
                        f.write(f"    (serialization error: {cb_info['error']})\n")
                f.write("\n")

            if self.history:
                history = self.history.history
                f.write("Training Results\n")
                f.write("-" * 60 + "\n")
                f.write(f"Epochs trained: {len(history['loss'])}\n\n")

                # Final metrics
                f.write("Final Metrics:\n")
                for key in sorted(history.keys()):
                    if not key.startswith('val_'):
                        val_key = f'val_{key}'
                        train_val = history[key][-1]
                        f.write(f"  {key}: {train_val:.4f}")
                        if val_key in history:
                            val_val = history[val_key][-1]
                            f.write(f" | {val_key}: {val_val:.4f}")
                        f.write("\n")

                # Best metrics (find all val_ metrics and report best)
                f.write("\nBest Validation Metrics:\n")
                for key in sorted(history.keys()):
                    if key.startswith('val_'):
                        # Determine if higher or lower is better
                        if 'loss' in key or 'error' in key:
                            best_idx = np.argmin(history[key])
                            best_val = history[key][best_idx]
                            f.write(f"  {key}: {best_val:.4f} (epoch {best_idx + 1}, min)\n")
                        else:
                            best_idx = np.argmax(history[key])
                            best_val = history[key][best_idx]
                            f.write(f"  {key}: {best_val:.4f} (epoch {best_idx + 1}, max)\n")

        print(f"Summary saved: {summary_file}")

    def load(self, model_path=None, load_history=True):
        """
        Load a saved model and optionally its history.

        Args:
            model_path: Path to model file. If None, uses default naming convention.
            load_history: Whether to load training history JSON. Default is True.

        Returns:
            Tuple of (loaded_model, history_dict or None).
        Examples:
            >>> # Load with default path
            >>> manager = KerasModelManager('smith_123456_unet_model_128px')
            >>> manager.set_model(None, custom_objects={'f1': f1})
            >>> model, history = manager.load()

            >>> # Load from specific path
            >>> model, history = manager.load('models/backup/model.h5')

            >>> # Load model only, skip history
            >>> model, _ = manager.load(load_history=False)
        """
        if model_path is None:
            model_path = self.model_path

        model_path = Path(model_path)

        if not model_path.exists():
            raise FileNotFoundError(f"Model file not found: {model_path}")

        # Load model with custom objects
        self.model = keras.models.load_model(
            str(model_path),
            custom_objects=self.custom_objects
        )
        print(f"Model loaded: {model_path}")

        loaded_history = None

        # Load history if available
        if load_history:
            history_path = model_path.parent / self.history_path
            if history_path.exists():
                with open(history_path, 'r') as f:
                    loaded_history = json.load(f)
                print(f"History loaded: {history_path}")

                # Print summary
                if 'metadata' in loaded_history:
                    meta = loaded_history['metadata']
                    print(f"Trained for {meta.get('epochs_trained', 'unknown')} epochs")

                    # Print notes if available
                    if 'notes' in meta and meta['notes']:
                        print("\nNotes:")
                        for note_entry in meta['notes']:
                            print(f"  [{note_entry['timestamp']}] {note_entry['note']}")

                    # Print custom objects if available
                    if 'custom_objects' in meta and meta['custom_objects']:
                        print("\nCustom objects used:")
                        for obj_info in meta['custom_objects']:
                            print(f"  - {obj_info['name']} ({obj_info['type']})")

                    # Print other metadata
                    for key, val in meta.items():
                        if key not in ['model_name', 'version', 'epochs_trained',
                                       'saved_at', 'callbacks', 'custom_objects', 'notes']:
                            print(f"{key}: {val}")

                # Print best validation metrics
                for key in loaded_history.keys():
                    if key.startswith('val_') and key != 'val_loss':
                        best_val = max(loaded_history[key])
                        print(f"Best {key}: {best_val:.4f}")

        return self.model, loaded_history

    def get_best_metrics(self, metric_preferences=None):
        """
        Get best metrics from training history.

        Args:
            metric_preferences: Dict mapping metric names to 'min' or 'max'.
                If None, uses 'min' for loss/error metrics,
                'max' for all others.

        Returns:
            Dictionary with best metrics, their epochs, and final values.
            Keys include 'best_{metric}', 'best_{metric}_epoch', 'final_{metric}'.

        Examples:
            >>> metrics = manager.get_best_metrics()
            >>> print(f"Best F1: {metrics['best_val_f1']:.4f}")
            >>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}")

            >>> # Custom metric preferences
            >>> metrics = manager.get_best_metrics(
            ...     metric_preferences={'val_custom': 'min'}
            ... )
        """
        if self.history is None:
            raise ValueError("No training history available.")

        history = self.history.history
        metrics = {}

        # Default preferences
        if metric_preferences is None:
            metric_preferences = {}

        for key in history.keys():
            if key.startswith('val_'):
                # Determine optimization direction
                if key in metric_preferences:
                    direction = metric_preferences[key]
                elif 'loss' in key or 'error' in key:
                    direction = 'min'
                else:
                    direction = 'max'

                # Find best value
                if direction == 'min':
                    best_idx = np.argmin(history[key])
                else:
                    best_idx = np.argmax(history[key])

                metrics[f'best_{key}'] = history[key][best_idx]
                metrics[f'best_{key}_epoch'] = best_idx + 1
                metrics[f'final_{key}'] = history[key][-1]

        return metrics

    def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)):
        """
        Plot training history.

        Args:
            history_dict: History dict from load(). If None, uses self.history.
            metrics: List of metrics to plot. If None, plots all metrics.
            figsize: Figure size tuple.

        Examples:
            >>> # After training
            >>> manager.plot_history()

            >>> # After loading
            >>> model, history = manager.load()
            >>> manager.plot_history(history)

            >>> # Plot specific metrics
            >>> manager.plot_history(metrics=['loss', 'f1'])
        """
        import matplotlib.pyplot as plt

        if history_dict is None:
            if self.history is None:
                raise ValueError("No history available. Train or load a model first.")
            history_dict = self.history.history

        # Filter out metadata
        history_dict = {k: v for k, v in history_dict.items() if k != 'metadata'}

        # Determine metrics to plot
        if metrics is None:
            # Plot all non-validation metrics
            metrics = [k for k in history_dict.keys() if not k.startswith('val_')]

        n_metrics = len(metrics)
        fig, axes = plt.subplots(1, n_metrics, figsize=figsize)

        # Handle single metric case
        if n_metrics == 1:
            axes = [axes]

        for ax, metric in zip(axes, metrics):
            val_metric = f'val_{metric}'

            ax.plot(history_dict[metric], label=f'Training {metric}')
            if val_metric in history_dict:
                ax.plot(history_dict[val_metric], label=f'Validation {metric}')

            ax.set_xlabel('Epoch')
            ax.set_ylabel(metric.replace('_', ' ').title())
            ax.set_title(f'{metric.replace("_", " ").title()} over Epochs')
            ax.legend()
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()
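The filename scheme set up in `__init__` (an optional `_v{version}` suffix, then fixed extensions for model, history, and summary) can be checked in isolation. The sketch below is illustrative only: `derive_paths` is a hypothetical standalone helper that mirrors that logic, not part of the module.

```python
# Illustrative helper (not part of the module): mirrors the filename
# generation in KerasModelManager.__init__.
def derive_paths(model_name, version=None):
    # A "_vX" suffix is appended only when a version is given.
    version_str = f"_v{version}" if version else ""
    base_name = f"{model_name}{version_str}"
    return {
        "model": f"{base_name}.h5",
        "history": f"{base_name}_history.json",
        "summary": f"{base_name}_summary.txt",
    }

print(derive_paths("mnist_cnn", version="1.0")["model"])  # mnist_cnn_v1.0.h5
print(derive_paths("stock_lstm")["history"])              # stock_lstm_history.json
```

A manager created as `KerasModelManager('mnist_cnn', version='1.0')` would accordingly write `mnist_cnn_v1.0.h5`, `mnist_cnn_v1.0_history.json`, and `mnist_cnn_v1.0_summary.txt`.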
409 """ 410 import warnings 411 412 if callbacks is None: 413 return [] 414 415 serialized = [] 416 417 for callback in callbacks: 418 callback_info = { 419 'class': callback.__class__.__name__, 420 'module': callback.__class__.__module__ 421 } 422 423 # Try to extract configuration 424 try: 425 # Get callback config if available 426 if hasattr(callback, 'get_config'): 427 callback_info['config'] = callback.get_config() 428 else: 429 # Manually extract common attributes 430 config = {} 431 432 # EarlyStopping attributes 433 if hasattr(callback, 'monitor'): 434 config['monitor'] = callback.monitor 435 if hasattr(callback, 'patience'): 436 config['patience'] = callback.patience 437 if hasattr(callback, 'restore_best_weights'): 438 config['restore_best_weights'] = callback.restore_best_weights 439 if hasattr(callback, 'mode'): 440 config['mode'] = callback.mode 441 if hasattr(callback, 'min_delta'): 442 config['min_delta'] = callback.min_delta 443 444 # ReduceLROnPlateau attributes 445 if hasattr(callback, 'factor'): 446 config['factor'] = callback.factor 447 if hasattr(callback, 'cooldown'): 448 config['cooldown'] = callback.cooldown 449 if hasattr(callback, 'min_lr'): 450 config['min_lr'] = callback.min_lr 451 452 # ModelCheckpoint attributes 453 if hasattr(callback, 'filepath'): 454 config['filepath'] = callback.filepath 455 if hasattr(callback, 'save_best_only'): 456 config['save_best_only'] = callback.save_best_only 457 if hasattr(callback, 'save_weights_only'): 458 config['save_weights_only'] = callback.save_weights_only 459 460 # LearningRateScheduler 461 if hasattr(callback, 'schedule'): 462 warnings.warn( 463 f"Callback {callback.__class__.__name__} has a 'schedule' " 464 f"function that cannot be serialized. 
Saving class name only.", 465 UserWarning 466 ) 467 config['schedule'] = '<function>' 468 469 if config: 470 callback_info['config'] = config 471 else: 472 warnings.warn( 473 f"Could not extract configuration from callback " 474 f"{callback.__class__.__name__}. Saving class name only.", 475 UserWarning 476 ) 477 478 # Check if config is JSON serializable 479 json.dumps(callback_info) 480 serialized.append(callback_info) 481 482 except (TypeError, ValueError) as e: 483 warnings.warn( 484 f"Callback {callback.__class__.__name__} could not be fully " 485 f"serialized: {str(e)}. Saving class name only.", 486 UserWarning 487 ) 488 serialized.append({ 489 'class': callback.__class__.__name__, 490 'module': callback.__class__.__module__, 491 'error': str(e) 492 }) 493 494 return serialized 495 496 def _save_history(self, output_dir): 497 """Save training history to JSON.""" 498 history_dict = {} 499 500 for key, values in self.history.history.items(): 501 # Convert to list and ensure JSON serializable 502 if isinstance(values, np.ndarray): 503 history_dict[key] = values.tolist() 504 elif isinstance(values, list): 505 history_dict[key] = [float(x) for x in values] 506 else: 507 history_dict[key] = values 508 509 # Serialize callbacks 510 callbacks_info = self._serialize_callbacks(self._callbacks) 511 512 # Serialize custom objects 513 custom_objects_info = self._serialize_custom_objects() 514 515 # Add metadata 516 history_dict['metadata'] = { 517 'model_name': self.model_name, 518 'version': self.version, 519 'epochs_trained': len(self.history.history['loss']), 520 'saved_at': datetime.now().isoformat(), 521 'callbacks': callbacks_info, 522 'custom_objects': custom_objects_info, 523 'notes': self._notes, 524 **self.metadata 525 } 526 527 history_file = output_dir / self.history_path 528 with open(history_file, 'w') as f: 529 json.dump(history_dict, f, indent=4) 530 print(f"History saved: {history_file}") 531 532 def _save_summary(self, output_dir): 533 """Save 
human-readable summary.""" 534 summary_file = output_dir / self.summary_path 535 536 with open(summary_file, 'w') as f: 537 f.write(f"{self.model_name} Model Summary\n") 538 f.write("=" * 60 + "\n\n") 539 540 # Write metadata 541 if self.version: 542 f.write(f"Version: {self.version}\n") 543 for key, value in self.metadata.items(): 544 f.write(f"{key}: {value}\n") 545 546 f.write(f"Model file: {self.model_path}\n") 547 f.write(f"Saved: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") 548 549 # Write notes 550 if self._notes: 551 f.write("Notes / Hypothesis:\n") 552 f.write("-" * 60 + "\n") 553 for note_entry in self._notes: 554 f.write(f"[{note_entry['timestamp']}] {note_entry['note']}\n") 555 f.write("\n") 556 557 # Write custom objects information 558 if self.custom_objects: 559 f.write("Custom Objects:\n") 560 f.write("-" * 60 + "\n") 561 custom_objs_info = self._serialize_custom_objects() 562 for obj_info in custom_objs_info: 563 f.write(f" {obj_info['name']} ({obj_info['type']}):\n") 564 if 'module' in obj_info: 565 f.write(f" module: {obj_info['module']}\n") 566 if 'signature' in obj_info: 567 f.write(f" signature: {obj_info['signature']}\n") 568 if 'source_hash' in obj_info: 569 f.write(f" source_hash: {obj_info['source_hash']}\n") 570 if 'docstring' in obj_info: 571 # First line of docstring only 572 first_line = obj_info['docstring'].split('\n')[0] 573 f.write(f" docstring: {first_line}\n") 574 f.write("\n") 575 576 # Write callback information 577 if self._callbacks: 578 f.write("Callbacks Used:\n") 579 f.write("-" * 60 + "\n") 580 callbacks_info = self._serialize_callbacks(self._callbacks) 581 for cb_info in callbacks_info: 582 f.write(f" {cb_info['class']}:\n") 583 if 'config' in cb_info: 584 for key, value in cb_info['config'].items(): 585 f.write(f" {key}: {value}\n") 586 elif 'error' in cb_info: 587 f.write(f" (serialization error: {cb_info['error']})\n") 588 f.write("\n") 589 590 if self.history: 591 history = self.history.history 592 
f.write("Training Results\n") 593 f.write("-" * 60 + "\n") 594 f.write(f"Epochs trained: {len(history['loss'])}\n\n") 595 596 # Final metrics 597 f.write("Final Metrics:\n") 598 for key in sorted(history.keys()): 599 if not key.startswith('val_'): 600 val_key = f'val_{key}' 601 train_val = history[key][-1] 602 f.write(f" {key}: {train_val:.4f}") 603 if val_key in history: 604 val_val = history[val_key][-1] 605 f.write(f" | {val_key}: {val_val:.4f}") 606 f.write("\n") 607 608 # Best metrics (find all val_ metrics and report best) 609 f.write("\nBest Validation Metrics:\n") 610 for key in sorted(history.keys()): 611 if key.startswith('val_'): 612 # Determine if higher or lower is better 613 if 'loss' in key or 'error' in key: 614 best_idx = np.argmin(history[key]) 615 best_val = history[key][best_idx] 616 f.write(f" {key}: {best_val:.4f} (epoch {best_idx + 1}, min)\n") 617 else: 618 best_idx = np.argmax(history[key]) 619 best_val = history[key][best_idx] 620 f.write(f" {key}: {best_val:.4f} (epoch {best_idx + 1}, max)\n") 621 622 print(f"Summary saved: {summary_file}") 623 624 def load(self, model_path=None, load_history=True): 625 """ 626 Load a saved model and optionally its history. 627 628 Args: 629 model_path: Path to model file. If None, uses default naming convention. 630 load_history: Whether to load training history JSON. Default is True. 631 632 Returns: 633 Tuple of (loaded_model, history_dict or None). 
634 635 Examples: 636 >>> # Load with default path 637 >>> manager = KerasModelManager('smith_123456_unet_model_128px') 638 >>> manager.set_model(None, custom_objects={'f1': f1}) 639 >>> model, history = manager.load() 640 641 >>> # Load from specific path 642 >>> model, history = manager.load('models/backup/model.h5') 643 644 >>> # Load model only, skip history 645 >>> model, _ = manager.load(load_history=False) 646 """ 647 if model_path is None: 648 model_path = self.model_path 649 650 model_path = Path(model_path) 651 652 if not model_path.exists(): 653 raise FileNotFoundError(f"Model file not found: {model_path}") 654 655 # Load model with custom objects 656 self.model = keras.models.load_model( 657 str(model_path), 658 custom_objects=self.custom_objects 659 ) 660 print(f"Model loaded: {model_path}") 661 662 loaded_history = None 663 664 # Load history if available 665 if load_history: 666 history_path = model_path.parent / self.history_path 667 if history_path.exists(): 668 with open(history_path, 'r') as f: 669 loaded_history = json.load(f) 670 print(f"History loaded: {history_path}") 671 672 # Print summary 673 if 'metadata' in loaded_history: 674 meta = loaded_history['metadata'] 675 print(f"Trained for {meta.get('epochs_trained', 'unknown')} epochs") 676 677 # Print notes if available 678 if 'notes' in meta and meta['notes']: 679 print("\nNotes:") 680 for note_entry in meta['notes']: 681 print(f" [{note_entry['timestamp']}] {note_entry['note']}") 682 683 # Print custom objects if available 684 if 'custom_objects' in meta and meta['custom_objects']: 685 print("\nCustom objects used:") 686 for obj_info in meta['custom_objects']: 687 print(f" - {obj_info['name']} ({obj_info['type']})") 688 689 # Print other metadata 690 for key, val in meta.items(): 691 if key not in ['model_name', 'version', 'epochs_trained', 'saved_at', 'callbacks', 'custom_objects', 'notes']: 692 print(f"{key}: {val}") 693 694 # Print best validation metrics 695 for key in 
loaded_history.keys(): 696 if key.startswith('val_') and key != 'val_loss': 697 best_val = max(loaded_history[key]) 698 print(f"Best {key}: {best_val:.4f}") 699 700 return self.model, loaded_history 701 702 def get_best_metrics(self, metric_preferences=None): 703 """ 704 Get best metrics from training history. 705 706 Args: 707 metric_preferences: Dict mapping metric names to 'min' or 'max'. 708 If None, uses 'min' for loss/error metrics, 709 'max' for all others. 710 711 Returns: 712 Dictionary with best metrics, their epochs, and final values. 713 Keys include 'best_{metric}', 'best_{metric}_epoch', 'final_{metric}'. 714 715 Examples: 716 >>> metrics = manager.get_best_metrics() 717 >>> print(f"Best F1: {metrics['best_val_f1']:.4f}") 718 >>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}") 719 720 >>> # Custom metric preferences 721 >>> metrics = manager.get_best_metrics( 722 ... metric_preferences={'val_custom': 'min'} 723 ... ) 724 """ 725 if self.history is None: 726 raise ValueError("No training history available.") 727 728 history = self.history.history 729 metrics = {} 730 731 # Default preferences 732 if metric_preferences is None: 733 metric_preferences = {} 734 735 for key in history.keys(): 736 if key.startswith('val_'): 737 # Determine optimization direction 738 if key in metric_preferences: 739 direction = metric_preferences[key] 740 elif 'loss' in key or 'error' in key: 741 direction = 'min' 742 else: 743 direction = 'max' 744 745 # Find best value 746 if direction == 'min': 747 best_idx = np.argmin(history[key]) 748 else: 749 best_idx = np.argmax(history[key]) 750 751 metrics[f'best_{key}'] = history[key][best_idx] 752 metrics[f'best_{key}_epoch'] = best_idx + 1 753 metrics[f'final_{key}'] = history[key][-1] 754 755 return metrics 756 757 def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)): 758 """ 759 Plot training history. 760 761 Args: 762 history_dict: History dict from load(). If None, uses self.history. 
763 metrics: List of metrics to plot. If None, plots all metrics. 764 figsize: Figure size tuple. 765 766 Examples: 767 >>> # After training 768 >>> manager.plot_history() 769 770 >>> # After loading 771 >>> model, history = manager.load() 772 >>> manager.plot_history(history) 773 774 >>> # Plot specific metrics 775 >>> manager.plot_history(metrics=['loss', 'f1']) 776 """ 777 import matplotlib.pyplot as plt 778 779 if history_dict is None: 780 if self.history is None: 781 raise ValueError("No history available. Train or load a model first.") 782 history_dict = self.history.history 783 784 # Filter out metadata 785 history_dict = {k: v for k, v in history_dict.items() if k != 'metadata'} 786 787 # Determine metrics to plot 788 if metrics is None: 789 # Plot all non-validation metrics 790 metrics = [k for k in history_dict.keys() if not k.startswith('val_')] 791 792 n_metrics = len(metrics) 793 fig, axes = plt.subplots(1, n_metrics, figsize=figsize) 794 795 # Handle single metric case 796 if n_metrics == 1: 797 axes = [axes] 798 799 for ax, metric in zip(axes, metrics): 800 val_metric = f'val_{metric}' 801 802 ax.plot(history_dict[metric], label=f'Training {metric}') 803 if val_metric in history_dict: 804 ax.plot(history_dict[val_metric], label=f'Validation {metric}') 805 806 ax.set_xlabel('Epoch') 807 ax.set_ylabel(metric.replace('_', ' ').title()) 808 ax.set_title(f'{metric.replace("_", " ").title()} over Epochs') 809 ax.legend() 810 ax.grid(True, alpha=0.3) 811 812 plt.tight_layout() 813 plt.show()
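Since several methods read the `_history.json` file back, it helps to see its shape in one place. The sketch below mirrors the structure `_save_history` writes (per-epoch metric lists at the top level, plus a `metadata` entry that merges in the `**metadata` kwargs); all concrete values are illustrative, not real training output.

```python
import json
from datetime import datetime

# Illustrative layout of <base_name>_history.json as written by _save_history.
history_dict = {
    "loss": [0.62, 0.41, 0.33],
    "val_loss": [0.58, 0.45, 0.39],
    "val_f1": [0.71, 0.78, 0.82],
    "metadata": {
        "model_name": "smith_123456_unet_model_128px",
        "version": None,
        "epochs_trained": 3,
        "saved_at": datetime(2024, 1, 1, 12, 0, 0).isoformat(),
        "callbacks": [{"class": "EarlyStopping",
                       "module": "keras.callbacks",
                       "config": {"monitor": "val_loss", "patience": 3}}],
        "custom_objects": [{"name": "f1", "type": "function"}],
        "notes": [{"timestamp": "2024-01-01 12:00:00", "note": "baseline run"}],
        "patch_size": 128,  # **metadata kwargs are merged in at this level
    },
}

# The whole structure must survive a JSON round trip, as in _save_history.
round_tripped = json.loads(json.dumps(history_dict, indent=4))
print(round_tripped["metadata"]["epochs_trained"])  # -> 3
```

Note that `plot_history` filters the `metadata` key out before plotting, and `load` reads the same structure back verbatim.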
Manages Keras model lifecycle: training, saving, loading, and metadata.
def __init__(self, model_name, version=None, notes=None, **metadata):
Initialize the model manager.
Arguments:
- model_name: Base name for the model files. For Task 5, use format: 'studentname_studentnumber_modeltype_patchsizepx'
- version: Optional version string or number (e.g., '1.0', 'v2', 2).
- notes: Optional initial note/hypothesis for this model.
- **metadata: Additional metadata to store with the model. Common examples:
- student_name (str): Student name for academic projects
- student_number (str): Student number for academic projects
- architecture (str): Model architecture (e.g., 'U-Net', 'ResNet50')
- patch_size (int): Patch size for segmentation models
- input_shape (tuple): Input shape for the model
- num_classes (int): Number of output classes
- dataset (str): Dataset name
- pretrained (str): Pretrained weights source
- task (str): Task description
Examples:
>>> # U-Net for segmentation with notes
>>> manager = KerasModelManager(
...     'smith_123456_unet_model_128px',
...     notes="Hypothesis: 128px patches will achieve F1 > 0.90",
...     patch_size=128,
...     architecture='U-Net'
... )

>>> # Versioned CNN classifier
>>> manager = KerasModelManager(
...     'mnist_cnn',
...     version='1.0',
...     dataset='MNIST',
...     input_shape=(28, 28, 1)
... )
def set_model(self, model, custom_objects=None):
Set the model to be managed.
Arguments:
- model: Compiled Keras model, or None if only loading.
- custom_objects: Dict of custom objects needed for model loading (e.g., {'f1': f1_function, 'dice_loss': dice_loss_fn}).
Examples:
>>> # With custom metric
>>> manager.set_model(unet_model, custom_objects={'f1': f1})

>>> # Standard model, no custom objects
>>> manager.set_model(resnet_model)

>>> # Preparing to load (no model yet)
>>> manager.set_model(None, custom_objects={'f1': f1})
@property
def notes(self):
Get all notes as a formatted string.
def set_notes(self, note, replace=False):
Set notes with option to replace all existing notes.
Arguments:
- note: The note text to add or set.
- replace: If True, replaces all existing notes. If False, appends.
Examples:
>>> # Append (default)
>>> manager.set_notes("Added dropout layers")

>>> # Replace all notes
>>> manager.set_notes("Starting fresh experiment", replace=True)
def get_notes_list(self):
Get notes as a list of dictionaries.
Returns:
List of dicts with 'timestamp' and 'note' keys.
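The returned list is a shallow copy, so appending to or removing from it leaves the manager's notes untouched (though the dict entries themselves are shared, so in-place edits to an entry would still be visible). A quick sketch of that contract with hand-written entries:

```python
# `internal` stands in for the manager's _notes; the values are made up.
internal = [{'timestamp': '2024-01-01 12:00:00', 'note': 'baseline'}]

snapshot = internal.copy()  # what get_notes_list() returns
snapshot.append({'timestamp': '2024-01-01 13:00:00', 'note': 'extra'})

print(len(internal), len(snapshot))  # -> 1 2
```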
def train(self, X_train, y_train=None, X_val=None, y_val=None, epochs=50,
          batch_size=16, callbacks=None, verbose=1, **fit_kwargs):
Train the model and store history.
Arguments:
- X_train: Training data (numpy array, generator, or dataset).
- y_train: Training labels (numpy array, or None if using generator).
- X_val: Optional validation data (or None if using validation_data kwarg).
- y_val: Optional validation labels (or None if using validation_data kwarg).
- epochs: Maximum number of epochs.
- batch_size: Batch size for training (ignored if using generator).
- callbacks: List of Keras callbacks (e.g., EarlyStopping, ModelCheckpoint).
- verbose: Verbosity level (0=silent, 1=progress bar, 2=one line per epoch).
- **fit_kwargs: Additional arguments passed to model.fit() (e.g., steps_per_epoch, validation_steps, validation_data).
Returns:
Training history object.
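Because validation data can arrive two ways (positional `X_val`/`y_val` or a `validation_data` kwarg forwarded to `model.fit()`), `train` resolves them with a simple precedence rule: an explicit `validation_data` kwarg wins, otherwise `(X_val, y_val)` is used when both are given. A standalone sketch of just that logic (`resolve_validation_data` is a hypothetical helper name, not part of the class):

```python
def resolve_validation_data(X_val, y_val, **fit_kwargs):
    # Explicit validation_data kwarg takes precedence (popped so it is
    # not passed to model.fit() twice), mirroring train() above.
    validation_data = fit_kwargs.pop('validation_data', None)
    if validation_data is None and X_val is not None and y_val is not None:
        validation_data = (X_val, y_val)
    return validation_data

print(resolve_validation_data([1], [2]))                           # -> ([1], [2])
print(resolve_validation_data(None, None, validation_data='gen'))  # -> gen
print(resolve_validation_data(None, None))                         # -> None
```

This is why generator-based workflows pass `validation_data=val_generator` through `**fit_kwargs` instead of `X_val`/`y_val`.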
def save(self, output_dir='.'):
Save model, history, and summary to disk.
Saves three files:
- .h5 file - Complete model (architecture, weights, optimizer state)
- _history.json - Training metrics for all epochs
- _summary.txt - Human-readable summary
Arguments:
- output_dir: Directory to save files. Created if it doesn't exist. Default is current directory.
Examples:
>>> # Save to current directory
>>> manager.save()

>>> # Save to specific directory
>>> manager.save('trained_models')

>>> # Save to nested path
>>> manager.save('models/unet/version_1')
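The three file names are derived from `model_name` and `version` exactly as in `__init__` (a `_v` prefix is added when a version is set). A sketch of that naming rule (`derived_paths` is a hypothetical helper, not part of the class):

```python
def derived_paths(model_name, version=None):
    # Mirrors the filename generation in __init__.
    version_str = f"_v{version}" if version else ""
    base_name = f"{model_name}{version_str}"
    return (f"{base_name}.h5",
            f"{base_name}_history.json",
            f"{base_name}_summary.txt")

print(derived_paths('mnist_cnn', version='1.0'))
# -> ('mnist_cnn_v1.0.h5', 'mnist_cnn_v1.0_history.json', 'mnist_cnn_v1.0_summary.txt')
```

Knowing this convention matters when calling `load()` without an explicit path: it looks for these exact names in the model file's directory.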
def load(self, model_path=None, load_history=True):
Load a saved model and optionally its history.
Arguments:
- model_path: Path to model file. If None, uses default naming convention.
- load_history: Whether to load training history JSON. Default is True.
Returns:
Tuple of (loaded_model, history_dict or None).
Examples:
>>> # Load with default path
>>> manager = KerasModelManager('smith_123456_unet_model_128px')
>>> manager.set_model(None, custom_objects={'f1': f1})
>>> model, history = manager.load()

>>> # Load from specific path
>>> model, history = manager.load('models/backup/model.h5')

>>> # Load model only, skip history
>>> model, _ = manager.load(load_history=False)
def get_best_metrics(self, metric_preferences=None):
Get best metrics from training history.
Arguments:
- metric_preferences: Dict mapping metric names to 'min' or 'max'. If None, uses 'min' for loss/error metrics, 'max' for all others.
Returns:
Dictionary with best metrics, their epochs, and final values. Keys include 'best_{metric}', 'best_{metric}_epoch', 'final_{metric}'.
Examples:
>>> metrics = manager.get_best_metrics()
>>> print(f"Best F1: {metrics['best_val_f1']:.4f}")
>>> print(f"Achieved at epoch: {metrics['best_val_f1_epoch']}")

>>> # Custom metric preferences
>>> metrics = manager.get_best_metrics(
...     metric_preferences={'val_custom': 'min'}
... )
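The selection rule applied to each `val_` series is: an entry in `metric_preferences` wins; otherwise metrics whose name contains `loss` or `error` are minimized and everything else is maximized, with the 1-based epoch of the best value reported. A standalone sketch on toy history values (`best_of` is a hypothetical helper, not part of the class):

```python
def best_of(key, values, metric_preferences=None):
    # Mirrors the direction heuristic in get_best_metrics().
    prefs = metric_preferences or {}
    direction = prefs.get(key, 'min' if ('loss' in key or 'error' in key) else 'max')
    pick = min if direction == 'min' else max
    best_idx = values.index(pick(values))  # first occurrence, like np.argmin/argmax
    return values[best_idx], best_idx + 1  # (best value, 1-based epoch)

history = {'val_loss': [0.58, 0.45, 0.49], 'val_f1': [0.71, 0.82, 0.78]}
print(best_of('val_loss', history['val_loss']))  # -> (0.45, 2)
print(best_of('val_f1', history['val_f1']))      # -> (0.82, 2)
```

Pass `metric_preferences` for any metric whose name does not signal its direction, e.g. `{'val_custom': 'min'}` as in the example above.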
def plot_history(self, history_dict=None, metrics=None, figsize=(14, 5)):
Plot training history.
Arguments:
- history_dict: History dict from load(). If None, uses self.history.
- metrics: List of metrics to plot. If None, plots all metrics.
- figsize: Figure size tuple.
Examples:
>>> # After training
>>> manager.plot_history()

>>> # After loading
>>> model, history = manager.load()
>>> manager.plot_history(history)

>>> # Plot specific metrics
>>> manager.plot_history(metrics=['loss', 'f1'])