Refactoring of calculation using pypolymlp

This commit is contained in:
Atsushi Togo 2024-09-04 17:29:31 +09:00
parent dd4c43ab10
commit 9fa805b1b5
3 changed files with 52 additions and 44 deletions

View File

@ -2208,16 +2208,37 @@ class Phono3py:
else:
_params = params
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
n = int(len(disps) * (1 - test_size))
train_data = PypolymlpData(
displacements=disps[:n], forces=forces[:n], supercell_energies=energies[:n]
)
test_data = PypolymlpData(
displacements=disps[n:], forces=forces[n:], supercell_energies=energies[n:]
)
if _params.ntrain is not None and _params.ntest is not None:
ntrain = _params.ntrain
ntest = _params.ntest
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
train_data = PypolymlpData(
displacements=disps[:ntrain],
forces=forces[:ntrain],
supercell_energies=energies[:ntrain],
)
test_data = PypolymlpData(
displacements=disps[-ntest:],
forces=forces[-ntest:],
supercell_energies=energies[-ntest:],
)
else:
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
n = int(len(disps) * (1 - test_size))
train_data = PypolymlpData(
displacements=disps[:n],
forces=forces[:n],
supercell_energies=energies[:n],
)
test_data = PypolymlpData(
displacements=disps[n:],
forces=forces[n:],
supercell_energies=energies[n:],
)
self._mlp = develop_polymlp(
self._supercell,
train_data,

View File

@ -118,19 +118,28 @@ def create_phono3py_force_constants(
if settings.read_fc3:
_read_phono3py_fc3(phono3py, symmetrize_fc3r, input_filename, log_level)
else: # fc3 from FORCES_FC3 or ph3py_yaml
_read_dataset_fc3(
dataset = _read_dataset_fc3(
phono3py,
ph3py_yaml,
phono3py_yaml_filename,
settings.cutoff_pair_distance,
calculator,
settings.use_pypolymlp,
settings.mlp_params,
settings.displacement_distance,
settings.random_displacements,
settings.random_seed,
log_level,
)
if settings.use_pypolymlp:
phono3py.mlp_dataset = dataset
run_pypolymlp_to_compute_forces(
phono3py,
settings.mlp_params,
displacement_distance=settings.displacement_distance,
number_of_snapshots=settings.random_displacements,
random_seed=settings.random_seed,
log_level=log_level,
)
else:
phono3py.dataset = dataset
phono3py.produce_fc3(
symmetrize_fc3r=symmetrize_fc3r,
is_compact_fc=settings.is_compact_fc,
@ -214,7 +223,7 @@ def parse_forces(
fc_type: Literal["fc3", "phonon_fc2"] = "fc3",
calculator: Optional[str] = None,
log_level=0,
):
) -> dict:
"""Read displacements and forces.
Physical units of displacements and forces are converted following the
@ -454,13 +463,8 @@ def _read_dataset_fc3(
phono3py_yaml_filename: Optional[str],
cutoff_pair_distance: Optional[float],
calculator: Optional[str],
use_pypolymlp: bool,
mlp_params: Union[str, dict, PypolymlpParams],
displacement_distance: Optional[float],
number_of_snapshots: Optional[int],
random_seed: Optional[int],
log_level: int,
):
) -> dict:
"""Read or calculate fc3.
Note
@ -496,18 +500,7 @@ def _read_dataset_fc3(
# from _get_type2_dataset
file_exists(e.filename, log_level=log_level)
if use_pypolymlp:
phono3py.mlp_dataset = dataset
run_pypolymlp_to_compute_forces(
phono3py,
mlp_params,
displacement_distance=displacement_distance,
number_of_snapshots=number_of_snapshots,
random_seed=random_seed,
log_level=log_level,
)
else:
phono3py.dataset = dataset
return dataset
def run_pypolymlp_to_compute_forces(
@ -529,8 +522,6 @@ def run_pypolymlp_to_compute_forces(
for k, v in asdict(parse_mlp_params(mlp_params)).items():
if v is not None:
print(f" {k}: {v}")
if log_level > 1:
print("")
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)
@ -579,6 +570,7 @@ def run_pypolymlp_to_compute_forces(
raise RuntimeError("Displacements are not set. Run generate_displacements.")
ph3py.evaluate_mlp()
ph3py.save("phono3py_mlp_eval_dataset.yaml")
def run_pypolymlp_to_compute_phonon_forces(
@ -601,8 +593,6 @@ def run_pypolymlp_to_compute_phonon_forces(
for k, v in asdict(parse_mlp_params(mlp_params)).items():
if v is not None:
print(f" {k}: {v}")
if log_level > 1:
print("")
if log_level:
print("Developing MLPs by pypolymlp...", flush=True)

View File

@ -165,16 +165,13 @@ def finalize_phono3py(
_physical_units = get_default_physical_units(phono3py.calculator)
write_force_sets = phono3py.mlp is not None
_write_displacements = write_displacements or phono3py.mlp is not None
ph3py_yaml = Phono3pyYaml(
configuration=confs_dict,
calculator=phono3py.calculator,
physical_units=_physical_units,
settings={
"force_sets": write_force_sets,
"displacements": _write_displacements,
"force_sets": False,
"displacements": write_displacements,
},
)
ph3py_yaml.set_phonon_info(phono3py)