Source code for PlanetAlign.logger.plot

from typing import Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np


[docs] def plot_loss_curve(loss_logs: dict, save_path: Optional[Union[str, Path]] = None, title: str = 'Loss vs Epochs', xlabel: str = 'Epochs', ylabel: str = 'Loss'): """ Plot the loss curve along training (optimization). Parameters ---------- loss_logs : list[float] List of loss values for each epoch. save_path : str or pathlib.Path, optional Path to save the plot. If None, the plot will not be saved. Default is None. title : str, optional Title of the plot. Default is 'Training Curve'. xlabel : str, optional Label for the x-axis. Default is 'Epochs'. ylabel : str, optional Label for the y-axis. Default is 'Loss'. Returns ------- None """ plt.figure(figsize=(8, 6)) for key, values in loss_logs.items(): plt.plot(range(1, len(values) + 1), values, label=key, linewidth=2) plt.title(title) plt.xlabel(xlabel) plt.ylabel(ylabel) # Calculate y-limits with margin all_losses = [value for values in loss_logs.values() for value in values] ymin, ymax = min(all_losses), max(all_losses) margin = (ymax - ymin) * 0.05 ymin = max(0, ymin - margin) # Ensure ymin is non-negative ymax += margin plt.ylim(ymin, ymax) max_epochs = max(len(values) for values in loss_logs.values()) tick_spacing = max(1, max_epochs // 10) # Show at most 10 major ticks epochs = list(range(0, max_epochs + 1, tick_spacing)) plt.xticks(epochs, [str(e) for e in epochs]) # Round tick interval to nearest 0.05, 0.1, 0.5, or 1 def round_tick_interval(value): if value >= 1: return round(value) # Round to whole numbers elif value >= 0.5: return round(value * 2) / 2 # Round to nearest 0.5 elif value >= 0.1: return round(value * 10) / 10 # Round to nearest 0.1 else: return round(value * 20) / 20 # Round to nearest 0.05 y_range = ymax - ymin base_tick_interval = y_range / 10 # Approximate tick interval for ~10 ticks tick_interval = round_tick_interval(base_tick_interval) # Generate y-ticks that end with 0 or 5 y_ticks = np.arange(np.ceil(ymin / tick_interval) * tick_interval, np.floor(ymax / tick_interval) * tick_interval + tick_interval, tick_interval) # Ensure ticks end in 0 or 5 by rounding y_ticks = np.array([round(y_tick * 20) / 20 for y_tick in y_ticks]) # Ensure all y-tick labels have the same decimal places decimal_places = 0 if tick_interval < 1: decimal_places = abs(int(np.log10(tick_interval))) # Count decimal places y_tick_labels = [f"{tick:.{decimal_places}f}" for tick in y_ticks] plt.yticks(y_ticks, y_tick_labels) plt.legend() plt.grid(True, linestyle='dashed', linewidth=0.5, alpha=0.7) if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.show()
[docs] def plot_hits_curve(hits_logs: dict, save_path: Optional[Union[str, Path]] = None, title: str = 'Hits@K Curve', xlabel: str = 'Epochs', ylabel: str = 'Hits@K'): """ Plot the Hits@K curve along training (optimization). Parameters ---------- hits_logs : list[dict] List of Hits@K scores for each epoch. save_path : str or Path, optional Path to save the plot. If None, the plot will not be saved. Default is None. title : str, optional Title of the plot. Default is 'Hits@K Curve'. xlabel : str, optional Label for the x-axis. Default is 'Epochs'. ylabel : str, optional Label for the y-axis. Default is 'Hits@K'. Returns ------- None """ plt.figure(figsize=(8, 6)) for key, values in hits_logs.items(): plt.plot(range(1, len(values) + 1), values, label=f"Hits@{key}", linewidth=2) plt.title(title) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.ylim(-0.05, 1.05) # Set y-limits to be between -0.05 and 1.05 max_epochs = max(len(values) for values in hits_logs.values()) tick_spacing = max(1, max_epochs // 10) # Show at most 10 major ticks epochs = list(range(0, max_epochs + 1, tick_spacing)) plt.xticks(epochs, [str(e) for e in epochs]) plt.yticks(np.arange(0, 1.1, 0.1), [f"{i/10:.1f}" for i in range(11)]) # Set y-ticks from 0 to 1 with step of 0.1 plt.legend() plt.grid(True, linestyle='dashed', linewidth=0.5, alpha=0.7) if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.show()
[docs] def plot_mrr_curve(mrr_logs: Union[list, np.ndarray], save_path: Optional[Union[str, Path]] = None, title: str = 'MRR Curve', xlabel: str = 'Epochs', ylabel: str = 'MRR', xlim=None, ylim=None, grid: bool = True): """ Plot the MRR curve along training (optimization). Parameters ---------- mrr_logs : list or np.ndarray List of MRR values for each epoch. save_path : str or Path, optional Path to save the plot. If None, the plot will not be saved. Default is None. title : str, optional Title of the plot. Default is 'MRR Curve'. xlabel : str, optional Label for the x-axis. Default is 'Epochs'. ylabel : str, optional Label for the y-axis. Default is 'MRR'. xlim : tuple[float, float], optional Limits for the x-axis. Default is None. ylim : tuple[float, float], optional Limits for the y-axis. Default is None. grid : bool, optional Whether to show grid lines. Default is True. Returns ------- None """ plt.plot(range(1, len(mrr_logs) + 1), mrr_logs, linewidth=2) plt.title(title) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.ylim(-0.05, 1.05) # Set y-limits to be between -0.05 and 1.05 max_epochs = len(mrr_logs) tick_spacing = max(1, max_epochs // 10) # Show at most 10 major ticks epochs = list(range(0, max_epochs + 1, tick_spacing)) plt.xticks(epochs, [str(e) for e in epochs]) plt.yticks(np.arange(0, 1.1, 0.1), [f"{i / 10:.1f}" for i in range(11)]) # Set y-ticks from 0 to 1 with step of 0.1 plt.grid(True, linestyle='dashed', linewidth=0.5, alpha=0.7) if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.show()
if __name__ == '__main__': # Example usage example_loss_logs = { '1': np.sort(np.random.rand(200)), '10': np.sort(np.random.rand(200)), } example_mrr_logs = np.sort(np.random.rand(200)) # plot_hits_curve(example_loss_logs) plot_mrr_curve(example_mrr_logs)