Separate detail of Phono3py.develop_mlp() into a function in phonopy

This commit is contained in:
Atsushi Togo 2024-09-05 16:38:13 +09:00
parent 853e3ce02b
commit 9b23427cbf
1 changed files with 6 additions and 45 deletions

View File

@ -58,6 +58,7 @@ from phonopy.interface.fc_calculator import get_fc2
from phonopy.interface.pypolymlp import (
PypolymlpData,
PypolymlpParams,
develop_mlp_by_pypolymlp,
develop_polymlp,
evalulate_polymlp,
load_polymlp,
@ -2204,52 +2205,12 @@ class Phono3py:
if self._mlp_dataset is None:
raise RuntimeError("MLP dataset is not set.")
if params is not None:
_params = parse_mlp_params(params)
else:
_params = params
if (
_params is not None
and _params.ntrain is not None
and _params.ntest is not None
):
ntrain = _params.ntrain
ntest = _params.ntest
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
train_data = PypolymlpData(
displacements=disps[:ntrain],
forces=forces[:ntrain],
supercell_energies=energies[:ntrain],
)
test_data = PypolymlpData(
displacements=disps[-ntest:],
forces=forces[-ntest:],
supercell_energies=energies[-ntest:],
)
else:
disps = self._mlp_dataset["displacements"]
forces = self._mlp_dataset["forces"]
energies = self._mlp_dataset["supercell_energies"]
n = int(len(disps) * (1 - test_size))
train_data = PypolymlpData(
displacements=disps[:n],
forces=forces[:n],
supercell_energies=energies[:n],
)
test_data = PypolymlpData(
displacements=disps[n:],
forces=forces[n:],
supercell_energies=energies[n:],
)
self._mlp = develop_polymlp(
self._mlp = develop_mlp_by_pypolymlp(
self._mlp_dataset,
self._supercell,
train_data,
test_data,
params=_params,
verbose=self._log_level - 1 > 0,
params=params,
test_size=test_size,
log_level=self._log_level,
)
def load_mlp(self, filename: str = "phono3py.pmlp"):