mirror of https://github.com/phonopy/phono3py.git
Update for reading pypolymlp MLPs from file
This commit is contained in:
parent
f38a704c40
commit
b7fea1a068
|
@ -60,6 +60,7 @@ from phonopy.interface.pypolymlp import (
|
|||
PypolymlpParams,
|
||||
develop_polymlp,
|
||||
evalulate_polymlp,
|
||||
load_polymlp,
|
||||
parse_mlp_params,
|
||||
)
|
||||
from phonopy.structure.atoms import PhonopyAtoms
|
||||
|
@ -2226,8 +2227,12 @@ class Phono3py:
|
|||
verbose=self._log_level - 1 > 0,
|
||||
)
|
||||
|
||||
def load_mlp(self, filename: str = "pypolymlp.mlp"):
|
||||
"""Load machine learning potential of pypolymlp."""
|
||||
self._mlp = load_polymlp(filename=filename)
|
||||
|
||||
def evaluate_mlp(self):
|
||||
"""Evaluate the machine learning potential of pypolymlp.
|
||||
"""Evaluate machine learning potential of pypolymlp.
|
||||
|
||||
This method calculates the supercell energies and forces from the MLP
|
||||
for the displacements in self._dataset of type 2. The results are stored
|
||||
|
|
|
@ -502,7 +502,8 @@ def _read_dataset_fc3(
|
|||
file_exists(e.filename, log_level=log_level)
|
||||
|
||||
if use_pypolymlp:
|
||||
phono3py.mlp_dataset = dataset
|
||||
if forces_in_dataset(dataset):
|
||||
phono3py.mlp_dataset = dataset
|
||||
run_pypolymlp_to_compute_forces(
|
||||
phono3py,
|
||||
mlp_params,
|
||||
|
@ -521,6 +522,7 @@ def run_pypolymlp_to_compute_forces(
|
|||
displacement_distance: Optional[float] = None,
|
||||
number_of_snapshots: Optional[int] = None,
|
||||
random_seed: Optional[int] = None,
|
||||
mlp_filename: str = "pypolymlp.mlp",
|
||||
log_level: int = 0,
|
||||
):
|
||||
"""Run pypolymlp to compute forces."""
|
||||
|
@ -536,10 +538,18 @@ def run_pypolymlp_to_compute_forces(
|
|||
print(f" {k}: {v}")
|
||||
if log_level > 1:
|
||||
print("")
|
||||
if log_level:
|
||||
print("Developing MLPs by pypolymlp...", flush=True)
|
||||
|
||||
ph3py.develop_mlp(params=mlp_params)
|
||||
if forces_in_dataset(ph3py.mlp_dataset):
|
||||
if log_level:
|
||||
print("Developing MLPs by pypolymlp...", flush=True)
|
||||
ph3py.develop_mlp(params=mlp_params)
|
||||
else:
|
||||
if pathlib.Path(mlp_filename).exists():
|
||||
if log_level:
|
||||
print(f'Load MLPs from "{mlp_filename}".')
|
||||
ph3py.load_mlp(mlp_filename)
|
||||
else:
|
||||
raise RuntimeError(f'"{mlp_filename}" is not found.')
|
||||
|
||||
if log_level:
|
||||
print("-" * 30 + " pypolymlp end " + "-" * 31, flush=True)
|
||||
|
@ -577,9 +587,6 @@ def run_pypolymlp_to_compute_forces(
|
|||
flush=True,
|
||||
)
|
||||
|
||||
if ph3py.mlp_dataset is None:
|
||||
msg = "mlp_dataset has to be set before calling this method."
|
||||
raise RuntimeError(msg)
|
||||
if ph3py.supercells_with_displacements is None:
|
||||
raise RuntimeError("Displacements are not set. Run generate_displacements.")
|
||||
|
||||
|
|
|
@ -416,7 +416,8 @@ def set_dataset_and_force_constants(
|
|||
)
|
||||
if not read_fc["fc3"]:
|
||||
if use_pypolymlp:
|
||||
ph3py.mlp_dataset = dataset
|
||||
if forces_in_dataset(dataset):
|
||||
ph3py.mlp_dataset = dataset
|
||||
else:
|
||||
ph3py.dataset = dataset
|
||||
read_fc["fc2"], phonon_dataset = _get_dataset_phonon_dataset_or_fc2(
|
||||
|
@ -461,8 +462,8 @@ def compute_force_constants_from_datasets(
|
|||
"""
|
||||
fc3_calculator = extract_fc2_fc3_calculators(fc_calculator, 3)
|
||||
fc2_calculator = extract_fc2_fc3_calculators(fc_calculator, 2)
|
||||
if not read_fc["fc3"] and (ph3py.dataset or ph3py.mlp_dataset):
|
||||
if use_pypolymlp and forces_in_dataset(ph3py.mlp_dataset):
|
||||
if not read_fc["fc3"]:
|
||||
if use_pypolymlp:
|
||||
run_pypolymlp_to_compute_forces(
|
||||
ph3py,
|
||||
mlp_params=mlp_params,
|
||||
|
|
|
@ -272,7 +272,7 @@ def get_bond_symmetry(
|
|||
def get_least_orbits(atom_index, cell, site_symmetry, symprec=1e-5):
|
||||
"""Find least orbits for a centering atom."""
|
||||
orbits = _get_orbits(atom_index, cell, site_symmetry, symprec)
|
||||
mapping = np.arange(cell.get_number_of_atoms())
|
||||
mapping = np.arange(len(cell))
|
||||
|
||||
for i, orb in enumerate(orbits):
|
||||
for num in np.unique(orb):
|
||||
|
|
|
@ -88,7 +88,7 @@ class ReciprocalToNormal:
|
|||
self._fc3_normal[i, j, k] = fc3_elem / fff
|
||||
|
||||
def _sum_in_atoms(self, band_indices, eigvecs):
|
||||
num_atom = self._primitive.get_number_of_atoms()
|
||||
num_atom = len(self._primitive)
|
||||
(e1, e2, e3) = eigvecs
|
||||
(b1, b2, b3) = band_indices
|
||||
|
||||
|
|
Loading…
Reference in New Issue