Source code for rxnDB.visualize

#######################################################
## .0. Load Libraries                            !!! ##
#######################################################
from dataclasses import dataclass
from typing import Any

import numpy as np
import pandas as pd
import plotly.graph_objects as go


#######################################################
## .1. Plotly                                    !!! ##
#######################################################
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@dataclass
class RxnDBPlotter:
    """Draw reaction curves and points on a plotly pressure-temperature figure.

    Attributes:
        df: Reaction data. Must contain an ``rxn_color_key`` column (checked
            on construction) plus the columns listed in ``_REQUIRED_COLUMNS``
            (checked when :meth:`plot` is called).
        ids: ``unique_id`` values to draw; one trace is added per id, in this
            order.
        dark_mode: Select the dark colour palette and ``plotly_dark`` template
            instead of the light ones.
        font_size: Base font size passed to the figure layout.
    """

    df: pd.DataFrame
    ids: list[str]
    dark_mode: bool = False
    font_size: float = 20

    # Deliberately un-annotated: @dataclass only turns annotated names into
    # fields, so these stay plain class-level constants.
    _REQUIRED_COLUMNS = frozenset(
        {
            "unique_id",
            "reaction",
            "reaction_names",
            "reactants",
            "reactant_names",
            "reactant_groups",
            "reactant_formulas",
            "products",
            "product_names",
            "product_groups",
            "product_formulas",
            "type",
            "units_P",
            "units_T",
            "T",
            "T_uncertainty",
            "P",
            "P_uncertainty",
            "plot_type",
            "reference",
        }
    )
    # "celcius" is the historical (misspelled) key and is kept so existing
    # callers keep working; the correct spelling is accepted as well.
    _TEMPERATURE_LABELS = {"celcius": "˚C", "celsius": "˚C", "kelvin": "K"}
    _PRESSURE_LABELS = {"gigapascal": "GPa", "kilobar": "kbar"}

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __post_init__(self):
        """Validate that ``df`` carries the per-reaction colour column.

        Raises:
            ValueError: If ``df`` has no ``rxn_color_key`` column.
        """
        if "rxn_color_key" not in self.df.columns:
            raise ValueError(
                "DataFrame must contain 'rxn_color_key' column. Did you use the processor's get_colors_for_filtered_df method?"
            )

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _unit_label(unit: str, labels: dict[str, str], kind: str) -> str:
        """Return the axis label for ``unit`` or raise ``ValueError`` if unknown."""
        try:
            return labels[unit]
        except (KeyError, TypeError):
            # TypeError covers unhashable input so callers always see ValueError.
            raise ValueError(f"Unknown {kind} unit: {unit}") from None

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def plot(self, temperature_units: str, pressure_units: str) -> go.Figure:
        """Plot reaction lines (phase diagram) using plotly.

        For every id in ``self.ids`` the matching rows of ``self.df`` become
        one trace: a line when the first row's ``plot_type`` is ``"curve"``,
        or markers with T/P error bars when it is ``"point"``. Ids with no
        matching rows, and rows with any other ``plot_type``, are skipped.
        The trace colour is taken from the first row's ``rxn_color_key``.

        Args:
            temperature_units: ``"celsius"`` (the legacy spelling ``"celcius"``
                is also accepted) or ``"kelvin"``. Only used for axis and
                hover labels; the data are not converted here.
            pressure_units: ``"gigapascal"`` or ``"kilobar"``. Only used for
                axis and hover labels; the data are not converted here.

        Returns:
            The assembled figure with the layout from
            :meth:`_configure_layout` applied.

        Raises:
            ValueError: If required columns are missing from ``self.df`` or a
                unit name is not recognised.
        """
        missing = self._REQUIRED_COLUMNS.difference(self.df.columns)
        if missing:
            # Sorted so the message is stable from run to run.
            raise ValueError(
                f"Missing required columns in DataFrame: {sorted(missing)}"
            )

        temperature_units_label = self._unit_label(
            temperature_units, self._TEMPERATURE_LABELS, "temperature"
        )
        pressure_units_label = self._unit_label(
            pressure_units, self._PRESSURE_LABELS, "pressure"
        )

        fig = go.Figure()

        hovertemplate = (
            "%{customdata[0]}<br>"
            "%{customdata[1]}<br>"
            f"(%{{x:.1f}} {temperature_units_label}, %{{y:.2f}} {pressure_units_label})<br>"
            "%{customdata[2]}<extra></extra>"
        )

        unique_ids = self.df["unique_id"]  # hoisted: invariant across the loop

        for rid in self.ids:
            d = self.df[unique_ids == rid]
            if d.empty:
                continue

            color = d["rxn_color_key"].iloc[0]
            plot_type = d["plot_type"].iloc[0]

            if plot_type == "curve":
                style = dict(mode="lines", line=dict(width=2, color=color))
            elif plot_type == "point":
                style = dict(
                    mode="markers",
                    marker=dict(size=8, color=color),
                    error_x=dict(
                        type="data", array=d["T_uncertainty"], visible=True
                    ),
                    error_y=dict(
                        type="data", array=d["P_uncertainty"], visible=True
                    ),
                )
            else:
                # Unrecognised plot types are silently left off the figure.
                continue

            fig.add_trace(
                go.Scatter(
                    x=d["T"],
                    y=d["P"],
                    hovertemplate=hovertemplate,
                    # One row per data point: (reaction, unique_id, type),
                    # indexed by the hovertemplate above.
                    customdata=np.stack(
                        (d["reaction"], d["unique_id"], d["type"]), axis=-1
                    ),
                    **style,
                )
            )

        layout_settings = self._configure_layout()

        fig.update_layout(
            xaxis_title=f"Temperature ({temperature_units_label})",
            yaxis_title=f"Pressure ({pressure_units_label})",
            showlegend=False,
            autosize=True,
            **layout_settings,
        )

        return fig

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _configure_layout(self) -> dict[str, Any]:
        """Return plotly layout keyword arguments for the current theme.

        The colours and template depend on ``self.dark_mode``; the font size
        comes from ``self.font_size``. Both axes receive identical settings,
        each as its own dict.
        """
        if self.dark_mode:
            template = "plotly_dark"
            foreground = "#E5E5E5"  # borders, ticks, labels and text
            grid_color = "#999999"
            background = "#1D1F21"  # plot and paper background
            legend_bgcolor = "#404040"
        else:
            template = "plotly_white"
            foreground = "black"
            grid_color = "#E5E5E5"
            background = "#FFF"
            legend_bgcolor = "#FFF"

        def axis_settings() -> dict[str, Any]:
            # Built fresh per axis so the x and y dicts are never shared.
            return {
                "gridcolor": grid_color,
                "title_font": {"color": foreground},
                "tickfont": {"color": foreground},
                "showline": True,
                "linecolor": foreground,
                "linewidth": 2,
                "mirror": True,
                "constrain": "range",
            }

        return {
            "template": template,
            "font": {"size": self.font_size, "color": foreground},
            "plot_bgcolor": background,
            "paper_bgcolor": background,
            "xaxis": axis_settings(),
            "yaxis": axis_settings(),
            "legend": {
                "font": {"color": foreground},
                "bgcolor": legend_bgcolor,
            },
        }