#######################################################
## .0. Load Libraries !!! ##
#######################################################
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from ruamel.yaml import YAML
from rxnDB.data.mapping import MINERAL_ABBREV_MAP
from rxnDB.utils import app_dir
#######################################################
## .1. RxnDB !!! ##
#######################################################
@dataclass
class RxnDBLoader:
    """Load reaction YAML entries from a directory into pandas DataFrames.

    Every ``*.yml`` file in ``in_dir`` is parsed and flattened into one row
    per listed data point, or per sampled point along a polynomial boundary
    curve when the entry has no points.

    Attributes:
        in_dir: Directory containing the YAML entries. It must exist when the
            loader is constructed.
    """

    in_dir: Path

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def __post_init__(self) -> None:
        """Check that ``in_dir`` exists and create the YAML parser.

        Raises:
            FileNotFoundError: If ``in_dir`` does not exist.
        """
        if not self.in_dir.exists():
            raise FileNotFoundError(f"Directory {self.in_dir} not found!")
        self.yaml = YAML()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def load_all(self) -> pd.DataFrame:
        """Load and concatenate all YAML entries in the directory.

        Returns:
            A single DataFrame with the rows of every ``*.yml`` file (files
            are processed in sorted path order), or an empty DataFrame when
            the directory contains no ``*.yml`` files.
        """
        in_paths: list[Path] = sorted(self.in_dir.glob("*.yml"))
        dfs: list[pd.DataFrame] = [self.load_entry(path) for path in in_paths]

        if not dfs:
            # pd.concat raises ValueError when given nothing to concatenate.
            return pd.DataFrame()

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=FutureWarning)
            return pd.concat(dfs, ignore_index=True)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def load_entry(self, filepath: Path) -> pd.DataFrame:
        """Load a single YAML file and convert it into a DataFrame.

        Missing or null ``data``, ``metadata``, ``method``, ``calibration``
        and ``reference`` sections are treated as empty, so the corresponding
        columns are filled with ``None`` instead of raising.

        Args:
            filepath: Path of the YAML entry to read.

        Returns:
            One row per data point (``plot_type == "point"``) or per sampled
            curve position (``plot_type == "curve"``). The frame is empty when
            the entry has neither points nor a boundary-curve polynomial.
        """
        print(f"Loading {filepath.name} ...", end="\r", flush=True)
        parsed_yml = self._read_yml(filepath)

        reactants = self._convert_to_str_list(parsed_yml.get("reactants"))
        products = self._convert_to_str_list(parsed_yml.get("products"))

        reactant_names, reactant_groups, reactant_formulas = self._lookup_phases(
            reactants
        )
        product_names, product_groups, product_formulas = self._lookup_phases(
            products
        )

        # `or {}` also covers keys that are present but null in the YAML.
        data = parsed_yml.get("data") or {}
        units = data.get("units") or {}
        metadata = parsed_yml.get("metadata") or {}
        method = metadata.get("method") or {}
        calib = method.get("calibration") or {}
        reference = metadata.get("reference") or {}

        # Columns shared by every row of this entry, split in two so that the
        # per-row T/P/lnK columns keep their position in the column order.
        head: dict[str, Any] = {
            "unique_id": metadata.get("unique_id"),
            "reaction": "+".join(reactants) + "<=>" + "+".join(products),
            "reaction_names": (
                "+".join(reactant_names) + "<=>" + "+".join(product_names)
            ),
            "reactants": reactants,
            "reactant_names": reactant_names,
            "reactant_groups": reactant_groups,
            "reactant_formulas": reactant_formulas,
            "products": products,
            "product_names": product_names,
            "product_groups": product_groups,
            "product_formulas": product_formulas,
            "type": data.get("type"),
            "units_T": units.get("T"),
            "units_P": units.get("P"),
        }
        tail: dict[str, Any] = {
            "reference": reference.get("short_cite"),
            "method": method.get("name"),
            "calib_P": calib.get("P"),
            "calib_T": calib.get("T"),
            "comments": metadata.get("comments"),
        }

        points = data.get("points") or {}
        boundary = data.get("boundary_curve") or {}
        curve = boundary.get("polynomial")

        # Explicit points take precedence over a boundary-curve polynomial.
        if points:
            rows = self._point_rows(points, head, tail)
        elif curve:
            rows = self._curve_rows(curve, boundary.get("limits") or {}, head, tail)
        else:
            print(f" No point or curve data found in {filepath.name}!")
            rows = []

        df = pd.DataFrame(rows)
        if "comments" in df.columns:
            df["comments"] = df["comments"].fillna("").astype(str)
        return df

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def save_as_parquet(df: pd.DataFrame, filepath: Path) -> None:
        """Save a DataFrame as a Parquet file, creating parent directories.

        Args:
            df: Frame to write; the index is not stored.
            filepath: Destination file.
        """
        filepath.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(filepath, index=False)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def load_parquet(filepath: Path) -> pd.DataFrame:
        """Load a DataFrame from a Parquet file.

        Args:
            filepath: Parquet file to read.
        """
        print(f"Loading data from {filepath.name} ...")
        return pd.read_parquet(filepath)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _read_yml(self, filepath: Path) -> dict[str, Any]:
        """Read and parse a YAML file.

        The file is decoded as UTF-8 regardless of the platform locale. An
        empty document, which the parser returns as ``None``, is reported as
        an empty mapping.
        """
        with open(filepath, "r", encoding="utf-8") as file:
            return self.yaml.load(file) or {}

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    def _convert_to_str_list(self, data: Any) -> list[str]:
        """Convert ``data`` to a list of strings.

        Lists are stringified element-wise, a string becomes a one-element
        list, a mapping contributes its keys, and ``None`` yields an empty
        list. Any other value is wrapped as ``[str(data)]``.
        """
        if data is None:
            # A missing or null YAML key must not become the phase "None".
            return []
        if isinstance(data, list):
            return [str(item) for item in data]
        if isinstance(data, str):
            return [data]
        if isinstance(data, dict):
            return [str(key) for key in data]
        return [str(data)]

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _lookup_phases(
        phases: list[str],
    ) -> tuple[list[str], list[str], list[str]]:
        """Return the (names, groups, formulas) of ``phases`` from the abbreviation map.

        A phase missing from ``MINERAL_ABBREV_MAP`` triggers a printed warning
        and falls back to its abbreviation as name, with empty group/formula.
        """
        names: list[str] = []
        groups: list[str] = []
        formulas: list[str] = []
        for phase in phases:
            if phase in MINERAL_ABBREV_MAP:
                mineral_info = MINERAL_ABBREV_MAP[phase]
                names.append(mineral_info["name"])
                groups.append(mineral_info["group"])
                formulas.append(mineral_info["formula"])
            else:
                print(f" !! Warning: phase {phase} not in map!")
                names.append(phase)
                groups.append("")
                formulas.append("")
        return names, groups, formulas

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _point_rows(
        points: dict[str, Any], head: dict[str, Any], tail: dict[str, Any]
    ) -> list[dict[str, Any]]:
        """Build one row per listed data point."""

        def column(key: str, field: str) -> Any:
            # A missing or null section/field acts as a single unknown value.
            values = (points.get(key) or {}).get(field)
            return [None] if values is None else values

        def at(values: Any, i: int) -> Any:
            # Columns shorter than the row count are padded with NaN.
            return values[i] if i < len(values) else np.nan

        T_vals, T_uncs = column("T", "value"), column("T", "uncertainty")
        P_vals, P_uncs = column("P", "value"), column("P", "uncertainty")
        lnK_vals, lnK_uncs = column("lnK", "value"), column("lnK", "uncertainty")

        n_rows = max(len(T_vals), len(P_vals))
        return [
            {
                **head,
                "T": at(T_vals, i),
                "T_uncertainty": at(T_uncs, i),
                "P": at(P_vals, i),
                "P_uncertainty": at(P_uncs, i),
                "lnK": at(lnK_vals, i),
                "lnK_uncertainty": at(lnK_uncs, i),
                "plot_type": "point",
                **tail,
            }
            for i in range(n_rows)
        ]

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    @staticmethod
    def _curve_rows(
        curve: dict[str, Any],
        limits: dict[str, Any],
        head: dict[str, Any],
        tail: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """Build rows by sampling the cubic ``P(T)`` polynomial at 20 T values."""
        intercept = curve.get("intercept", 0.0)
        x1 = curve.get("x1", 0.0)
        x2 = curve.get("x2", 0.0)
        x3 = curve.get("x3", 0.0)

        P_min = limits.get("P_min")
        P_max = limits.get("P_max")

        # NOTE(review): T_min/T_max are handed to np.linspace unchecked, so an
        # entry that omits either limit fails here.
        rows: list[dict[str, Any]] = []
        for T in np.linspace(limits.get("T_min"), limits.get("T_max"), num=20):
            P = intercept + x1 * T + x2 * T**2 + x3 * T**3
            too_low = P_min is not None and P < P_min
            too_high = P_max is not None and P > P_max
            if too_low or too_high:
                # Samples outside the P limits are kept as all-NaN rows rather
                # than dropped (presumably to break the plotted line).
                T = np.nan
                P = np.nan
            rows.append(
                {
                    **head,
                    "T": T,
                    "T_uncertainty": np.nan,
                    "P": P,
                    "P_uncertainty": np.nan,
                    "lnK": np.nan,
                    "lnK_uncertainty": np.nan,
                    "plot_type": "curve",
                    **tail,
                }
            )
        return rows
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def main() -> None:
    """Build the combined reaction database and cache it as a Parquet file.

    Loads the ``jimmy`` and ``hp11`` YAML sets found under
    ``app_dir / "data" / "sets"``, concatenates them (``hp11`` rows first),
    writes the result to ``app_dir / "data" / "cache" / "rxnDB.parquet"`` and
    prints a summary of the combined frame.
    """
    jimmy_loader = RxnDBLoader(app_dir / "data" / "sets" / "jimmy")
    jimmy_data = jimmy_loader.load_all()
    hp11_loader = RxnDBLoader(app_dir / "data" / "sets" / "hp11")
    hp11_data = hp11_loader.load_all()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        rxnDB = pd.concat([hp11_data, jimmy_data], ignore_index=True)

    out_data = app_dir / "data" / "cache" / "rxnDB.parquet"
    RxnDBLoader.save_as_parquet(rxnDB, out_data)

    print("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print(f"Data saved to {out_data.name}!")
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    print("Summary:")
    # DataFrame.info() writes the summary to stdout itself and returns None,
    # so wrapping it in print() would emit a stray "None" line.
    rxnDB.info()


if __name__ == "__main__":
    main()