Update for reading pypolymlp MLPs from file

This commit is contained in:
Atsushi Togo 2024-09-03 20:12:59 +09:00
parent f38a704c40
commit b7fea1a068
5 changed files with 26 additions and 13 deletions

View File

@ -60,6 +60,7 @@ from phonopy.interface.pypolymlp import (
PypolymlpParams,
develop_polymlp,
evalulate_polymlp,
load_polymlp,
parse_mlp_params,
)
from phonopy.structure.atoms import PhonopyAtoms
@ -2226,8 +2227,12 @@ class Phono3py:
verbose=self._log_level - 1 > 0,
)
def load_mlp(self, filename: str = "pypolymlp.mlp"):
"""Load machine learning potential of pypolymlp."""
self._mlp = load_polymlp(filename=filename)
def evaluate_mlp(self):
"""Evaluate the machine learning potential of pypolymlp.
"""Evaluate machine learning potential of pypolymlp.
This method calculates the supercell energies and forces from the MLP
for the displacements in self._dataset of type 2. The results are stored

View File

@ -502,7 +502,8 @@ def _read_dataset_fc3(
file_exists(e.filename, log_level=log_level)
if use_pypolymlp:
phono3py.mlp_dataset = dataset
if forces_in_dataset(dataset):
phono3py.mlp_dataset = dataset
run_pypolymlp_to_compute_forces(
phono3py,
mlp_params,
@ -521,6 +522,7 @@ def run_pypolymlp_to_compute_forces(
displacement_distance: Optional[float] = None,
number_of_snapshots: Optional[int] = None,
random_seed: Optional[int] = None,
mlp_filename: str = "pypolymlp.mlp",
log_level: int = 0,
):
"""Run pypolymlp to compute forces."""
@ -536,10 +538,18 @@ def run_pypolymlp_to_compute_forces(
print(f" {k}: {v}")
if log_level > 1:
print("")
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
if forces_in_dataset(ph3py.mlp_dataset):
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)
ph3py.develop_mlp(params=mlp_params)
else:
if pathlib.Path(mlp_filename).exists():
if log_level:
print(f'Load MLPs from "{mlp_filename}".')
ph3py.load_mlp(mlp_filename)
else:
raise RuntimeError(f'"{mlp_filename}" is not found.')
if log_level:
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)
@ -577,9 +587,6 @@ def run_pypolymlp_to_compute_forces(
flush=True,
)
if ph3py.mlp_dataset is None:
msg = "mlp_dataset has to be set before calling this method."
raise RuntimeError(msg)
if ph3py.supercells_with_displacements is None:
raise RuntimeError("Displacements are not set. Run generate_displacements.")

View File

@ -416,7 +416,8 @@ def set_dataset_and_force_constants(
)
if not read_fc["fc3"]:
if use_pypolymlp:
ph3py.mlp_dataset = dataset
if forces_in_dataset(dataset):
ph3py.mlp_dataset = dataset
else:
ph3py.dataset = dataset
read_fc["fc2"], phonon_dataset = _get_dataset_phonon_dataset_or_fc2(
@ -461,8 +462,8 @@ def compute_force_constants_from_datasets(
"""
fc3_calculator = extract_fc2_fc3_calculators(fc_calculator, 3)
fc2_calculator = extract_fc2_fc3_calculators(fc_calculator, 2)
if not read_fc["fc3"] and (ph3py.dataset or ph3py.mlp_dataset):
if use_pypolymlp and forces_in_dataset(ph3py.mlp_dataset):
if not read_fc["fc3"]:
if use_pypolymlp:
run_pypolymlp_to_compute_forces(
ph3py,
mlp_params=mlp_params,

View File

@ -272,7 +272,7 @@ def get_bond_symmetry(
def get_least_orbits(atom_index, cell, site_symmetry, symprec=1e-5):
"""Find least orbits for a centering atom."""
orbits = _get_orbits(atom_index, cell, site_symmetry, symprec)
mapping = np.arange(cell.get_number_of_atoms())
mapping = np.arange(len(cell))
for i, orb in enumerate(orbits):
for num in np.unique(orb):

View File

@ -88,7 +88,7 @@ class ReciprocalToNormal:
self._fc3_normal[i, j, k] = fc3_elem / fff
def _sum_in_atoms(self, band_indices, eigvecs):
num_atom = self._primitive.get_number_of_atoms()
num_atom = len(self._primitive)
(e1, e2, e3) = eigvecs
(b1, b2, b3) = band_indices