Skip to content

Commit 4da55e1

Browse files
authored
Feature: Add support for ML EXX in training script. (#6479)
* Feature: Support ML EXX for training script. * Update the interface to libnpy
1 parent 4e44caa commit 4da55e1

File tree

7 files changed

+102
-63
lines changed

7 files changed

+102
-63
lines changed

source/source_pw/module_ofdft/kedf_ml.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,15 +328,24 @@ void KEDF_ML::NN_forward(const double * const * prho, ModulePW::PW_Basis *pw_rho
328328

329329
void KEDF_ML::loadVector(std::string filename, std::vector<double> &data)
330330
{
331-
std::vector<long unsigned int> cshape = {(long unsigned) this->cal_tool->nx};
332-
bool fortran_order = false;
333-
npy::LoadArrayFromNumpy(filename, cshape, fortran_order, data);
331+
npy::npy_data<double> d = npy::read_npy<double>(filename);
332+
data = d.data;
333+
// ========== For old version of npy.hpp ==========
334+
// std::vector<long unsigned int> cshape = {(long unsigned) this->cal_tool->nx};
335+
// bool fortran_order = false;
336+
// npy::LoadArrayFromNumpy(filename, cshape, fortran_order, data);
334337
}
335338

336339
void KEDF_ML::dumpVector(std::string filename, const std::vector<double> &data)
337340
{
338-
const long unsigned cshape[] = {(long unsigned) this->cal_tool->nx}; // shape
339-
npy::SaveArrayAsNumpy(filename, false, 1, cshape, data);
341+
npy::npy_data_ptr<double> d;
342+
d.data_ptr = data.data();
343+
d.shape = {(long unsigned) this->cal_tool->nx};
344+
d.fortran_order = false; // optional
345+
npy::write_npy(filename, d);
346+
// ========== For old version of npy.hpp ==========
347+
// const long unsigned cshape[] = {(long unsigned) this->cal_tool->nx}; // shape
348+
// npy::SaveArrayAsNumpy(filename, false, 1, cshape, data);
340349
}
341350

342351
/**

source/source_pw/module_ofdft/ml_tools/data.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
208208
if (this->load_tanhxi[ik]){
209209
this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
210210
}
211-
if (this->load_tanhxi_nl[ik{
211+
if (this->load_tanhxi_nl[ik]){
212212
this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
213213
}
214214
if (this->load_tanh_pnl[ik]){
@@ -319,7 +319,16 @@ void Data::load_data_(
319319
enhancement.resize_({this->nx_tot, 1});
320320
pauli.resize_({nx_tot, 1});
321321

322-
this->tau_tf = this->cTF * torch::pow(this->rho, 5./3.);
322+
if (input.energy_type == "kedf")
323+
{
324+
this->tau_exp = 5. / 3.;
325+
this->tau_lda = this->cTF * torch::pow(this->rho, this->tau_exp);
326+
}
327+
else if (input.energy_type == "exx")
328+
{
329+
this->tau_exp = 4. / 3.;
330+
this->tau_lda = this->cDirac * torch::pow(this->rho, this->tau_exp);
331+
}
323332
// Input::print("load_data done");
324333
}
325334

source/source_pw/module_ofdft/ml_tools/data.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Data
1717
// =========== data ===========
1818
torch::Tensor rho;
1919
torch::Tensor nablaRho;
20-
torch::Tensor tau_tf;
20+
torch::Tensor tau_lda; // energy density of LDA, i.e. TF for KEDF, Dirac term for EXX
2121
// semi-local descriptors
2222
torch::Tensor gamma;
2323
torch::Tensor p;
@@ -67,7 +67,9 @@ class Data
6767
void init_data(const int nkernel, const int ndata, const int fftdim, const torch::Device device);
6868
void load_data_(Input &input, const int ndata, const int fftdim, std::string *dir);
6969

70-
const double cTF = 3.0/10.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0) * 2; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
70+
const double cTF = 3. /10. * std::pow(3. * std::pow(M_PI, 2.), 2. / 3.) * 2.; // 10/3*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
71+
const double cDirac = - 3. /4. * std::pow(3. / M_PI, 1./3.) * 2.; // -3/4*(3/pi)^{1/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
72+
double tau_exp = 5. / 3.; // 5/3 for TF KEDF, and 4/3 for Dirac term
7173

7274
public:
7375
void loadTensor(std::string file,

source/source_pw/module_ofdft/ml_tools/input.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,10 @@ void Input::readInput()
277277
{
278278
this->read_value(ifs, this->device_type);
279279
}
280+
else if (strcmp("energy_type", word) == 0)
281+
{
282+
this->read_value(ifs, this->energy_type);
283+
}
280284
}
281285

282286
std::cout << "Read nnINPUT done" << std::endl;

source/source_pw/module_ofdft/ml_tools/input.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class Input
7373
double lr_end = 1e-4;
7474
int lr_fre = 5000;
7575
double exponent = 5.; // exponent of weight rho^{exponent/3.}
76+
std::string energy_type = "kedf"; // kedf or exx
7677

7778
// output
7879
int dump_fre = 1;

0 commit comments

Comments
 (0)