diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 723015c1f99..0ddaf6640dd 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -302,6 +302,7 @@ set(STANDARD_PACKAGES
   MISC
   ML-HDNNP
   ML-IAP
+  ML-MACE
   ML-PACE
   ML-POD
   ML-QUIP
@@ -580,7 +581,7 @@ else()
 endif()
 
 foreach(PKG_WITH_INCL KSPACE PYTHON ML-IAP VORONOI COLVARS ML-HDNNP MDI MOLFILE NETCDF
-  PLUMED QMMM ML-QUIP SCAFACOS MACHDYN VTK KIM COMPRESS ML-PACE LEPTON RHEO)
+  PLUMED QMMM ML-QUIP SCAFACOS MACHDYN VTK KIM COMPRESS ML-PACE ML-MACE LEPTON RHEO)
   if(PKG_${PKG_WITH_INCL})
     include(Packages/${PKG_WITH_INCL})
   endif()
diff --git a/cmake/Modules/Packages/ML-MACE.cmake b/cmake/Modules/Packages/ML-MACE.cmake
new file mode 100644
index 00000000000..e2956ae5984
--- /dev/null
+++ b/cmake/Modules/Packages/ML-MACE.cmake
@@ -0,0 +1,6 @@
+cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
+
+find_package(Torch REQUIRED)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+
+target_link_libraries(lammps PRIVATE "${TORCH_LIBRARIES}")
diff --git a/src/KOKKOS/pair_mace_kokkos.cpp b/src/KOKKOS/pair_mace_kokkos.cpp
new file mode 100644
index 00000000000..b984b5fc210
--- /dev/null
+++ b/src/KOKKOS/pair_mace_kokkos.cpp
@@ -0,0 +1,421 @@
+/* ----------------------------------------------------------------------
+   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
+   https://www.lammps.org/, Sandia National Laboratories
+   Steve Plimpton, sjplimp@sandia.gov
+
+   Copyright (2003) Sandia Corporation.  Under the terms of Contract
+   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
+   certain rights in this software.  This software is distributed under
+   the GNU General Public License.
+
+   See the README file in the top-level LAMMPS directory.
+------------------------------------------------------------------------- */
+
+/* ----------------------------------------------------------------------
+   Contributors
+      William C Witt (University of Cambridge)
+------------------------------------------------------------------------- */
+
+#include "pair_mace_kokkos.h"
+
+#include "atom_kokkos.h"
+#include "atom_masks.h"
+#include "domain.h"
+#include "error.h"
+#include "force.h"
+#include "kokkos.h"
+#include "memory_kokkos.h"
+#include "neigh_list.h"
+#include "neighbor_kokkos.h"
+#include "neigh_request.h"
+#include "neighbor.h"
+#include "update.h"
+
+#include <cmath>
+#include <type_traits>
+
+using namespace LAMMPS_NS;
+
+/* ---------------------------------------------------------------------- */
+
+template<class DeviceType>
+PairMACEKokkos<DeviceType>::PairMACEKokkos(LAMMPS *lmp) : PairMACE(lmp)
+{
+  no_virial_fdotr_compute = 1;
+
+  kokkosable = 1;
+  atomKK = (AtomKokkos *) atom;
+  execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
+  datamask_read = EMPTY_MASK;
+  datamask_modify = EMPTY_MASK;
+
+  host_flag = (execution_space == Host);
+}
+
+/* ---------------------------------------------------------------------- */
+
+template<class DeviceType>
+PairMACEKokkos<DeviceType>::~PairMACEKokkos()
+{
+  if (copymode) return;
+}
+
+/* ---------------------------------------------------------------------- */
+
+template<class DeviceType>
+void PairMACEKokkos<DeviceType>::compute(int eflag, int vflag)
+{
+  ev_init(eflag,vflag,0);
+
+  atomKK->sync(execution_space,X_MASK|F_MASK|TYPE_MASK|TAG_MASK);
+
+  NeighListKokkos<DeviceType>* k_list = static_cast<NeighListKokkos<DeviceType>*>(list);
+  auto d_numneigh = k_list->d_numneigh;
+  auto d_neighbors = k_list->d_neighbors;
+  auto d_ilist = k_list->d_ilist;
+
+  if (atom->nlocal != list->inum)
+    error->all(FLERR, "ERROR: nlocal != inum.");
+  if (domain_decomposition && (atom->nghost != list->gnum))
+    error->all(FLERR,
"ERROR: nghost != gnum."); + if (eflag_atom || vflag_atom) + error->all(FLERR, "ERROR: mace/kokkos eflag_atom and/or vflag_atom not implemented."); + + int nlocal = atom->nlocal; + auto r_max_squared = this->r_max_squared; + auto h0 = domain->h[0]; + auto h1 = domain->h[1]; + auto h2 = domain->h[2]; + auto h3 = domain->h[3]; + auto h4 = domain->h[4]; + auto h5 = domain->h[5]; + auto hinv0 = domain->h_inv[0]; + auto hinv1 = domain->h_inv[1]; + auto hinv2 = domain->h_inv[2]; + auto hinv3 = domain->h_inv[3]; + auto hinv4 = domain->h_inv[4]; + auto hinv5 = domain->h_inv[5]; + + auto _k_lammps_atomic_numbers = k_lammps_atomic_numbers; + auto _k_mace_atomic_numbers = k_mace_atomic_numbers; + auto _mace_atomic_numbers_size = mace_atomic_numbers_size; + + // atom map + auto map_style = atom->map_style; + auto k_map_array = atomKK->k_map_array; + auto k_map_hash = atomKK->k_map_hash; + k_map_array.template sync(); + + auto x = atomKK->k_x.view(); +// c_x = atomKK->k_x.view(); + auto f = atomKK->k_f.view(); + auto tag = atomKK->k_tag.view(); + auto type = atomKK->k_type.view(); + + + // ----- positions ----- + int n_nodes; + if (domain_decomposition) { + n_nodes = atom->nlocal + atom->nghost; + } else { + // normally, ghost atoms are included in the graph as independent + // nodes, as required when the local domain does not have PBC. + // however, in no_domain_decomposition mode, ghost atoms are simply + // shifted versions of local atoms. + n_nodes = atom->nlocal; + } + auto k_positions = Kokkos::View("k_positions", n_nodes); + Kokkos::parallel_for("PairMACEKokkos: Fill k_positions.", n_nodes, KOKKOS_LAMBDA (const int i) { + k_positions(i,0) = x(i,0); + k_positions(i,1) = x(i,1); + k_positions(i,2) = x(i,2); + }); + auto positions = torch::from_blob( + k_positions.data(), + {n_nodes,3}, + torch::TensorOptions().dtype(torch_float_dtype).device(device)); + + // ----- cell ----- + // TODO: how to use kokkos here? + auto cell = torch::zeros({3,3}, torch::TensorOptions().dtype(torch_float_dtype).device(device)); + cell[0][0] = h0; + cell[0][1] = 0.0; + cell[0][2] = 0.0; + cell[1][0] = h5; + cell[1][1] = h1; + cell[1][2] = 0.0; + cell[2][0] = h4; + cell[2][1] = h3; + cell[2][2] = h2; + + // ----- edge_index and unit_shifts ----- + // count total number of edges + auto k_n_edges_vec = Kokkos::View("k_n_edges_vec", n_nodes); + Kokkos::parallel_for("PairMACEKokkos: Fill k_n_edges_vec.", n_nodes, KOKKOS_LAMBDA (const int ii) { + const int i = d_ilist(ii); + const double xtmp = x(i,0); + const double ytmp = x(i,1); + const double ztmp = x(i,2); + for (int jj=0; jj("k_first_edge", n_nodes); // initialized to zero + // TODO: this is serial to avoid race ... is there something better? 
+ Kokkos::parallel_for("PairMACEKokkos: Fill k_first_edge.", 1, KOKKOS_LAMBDA(const int i) { + for (int ii=0; ii("k_edge_index", 2, n_edges); + auto k_unit_shifts = Kokkos::View("k_unit_shifts", n_edges); + auto k_shifts = Kokkos::View("k_shifts", n_edges); + + if (domain_decomposition) { + + Kokkos::parallel_for("PairMACEKokkos: Fill edge_index (using domain decomposition).", n_nodes, KOKKOS_LAMBDA(const int ii) { + const int i = d_ilist(ii); + const double xtmp = x(i,0); + const double ytmp = x(i,1); + const double ztmp = x(i,2); + int k = k_first_edge(ii); + for (int jj=0; jj(tag(j),map_style,k_map_array,k_map_hash); + k_edge_index(1,k) = j_local; + double shiftx = x(j,0) - x(j_local,0); + double shifty = x(j,1) - x(j_local,1); + double shiftz = x(j,2) - x(j_local,2); + double shiftxs = std::round(hinv0*shiftx + hinv5*shifty + hinv4*shiftz); + double shiftys = std::round(hinv1*shifty + hinv3*shiftz); + double shiftzs = std::round(hinv2*shiftz); + k_unit_shifts(k,0) = shiftxs; + k_unit_shifts(k,1) = shiftys; + k_unit_shifts(k,2) = shiftzs; + k_shifts(k,0) = h0*shiftxs + h5*shiftys + h4*shiftzs; + k_shifts(k,1) = h1*shiftys + h3*shiftzs; + k_shifts(k,2) = h2*shiftzs; + k++; + } + } + }); + } + auto edge_index = torch::from_blob( + k_edge_index.data(), + {2,n_edges}, + torch::TensorOptions().dtype(torch::kInt64).device(device)); + auto unit_shifts = torch::from_blob( + k_unit_shifts.data(), + {n_edges,3}, + torch::TensorOptions().dtype(torch_float_dtype).device(device)); + auto shifts = torch::from_blob( + k_shifts.data(), + {n_edges,3}, + torch::TensorOptions().dtype(torch_float_dtype).device(device)); + + // ----- node_attrs ----- + // node_attrs is one-hot encoding for atomic numbers + int n_node_feats = _mace_atomic_numbers_size; + auto k_node_attrs = Kokkos::View("k_node_attrs", n_nodes, n_node_feats); + Kokkos::parallel_for("PairMACEKokkos: Fill k_node_attrs.", n_nodes, KOKKOS_LAMBDA(const int ii) { + const int i = d_ilist(ii); + const int lammps_type = type(i); + int t = -1; + for (int j=0; j<_mace_atomic_numbers_size; ++j) { + if (_k_mace_atomic_numbers(j)==_k_lammps_atomic_numbers(lammps_type-1)) { + t = j+1; + } + } + k_node_attrs(i,t-1) = 1.0; + }); + auto node_attrs = torch::from_blob( + k_node_attrs.data(), + {n_nodes, n_node_feats}, + torch::TensorOptions().dtype(torch_float_dtype).device(device)); + + // ----- mask for ghost ----- + Kokkos::View k_mask("k_mask", n_nodes); + Kokkos::parallel_for("PairMACEKokkos: Fill k_mask.", nlocal, KOKKOS_LAMBDA(const int ii) { + const int i = d_ilist(ii); + k_mask(i) = true; + }); + auto mask = torch::from_blob( + k_mask.data(), + n_nodes, + torch::TensorOptions().dtype(torch::kBool).device(device)); + + // TODO: why is batch of size n_nodes? 
+  auto batch = torch::zeros({n_nodes}, torch::TensorOptions().dtype(torch::kInt64).device(device));
+  auto energy = torch::empty({1}, torch::TensorOptions().dtype(torch_float_dtype).device(device));
+  auto forces = torch::empty({n_nodes,3}, torch::TensorOptions().dtype(torch_float_dtype).device(device));
+  auto ptr = torch::empty({2}, torch::TensorOptions().dtype(torch::kInt64).device(device));
+  auto weight = torch::empty({1}, torch::TensorOptions().dtype(torch_float_dtype).device(device));
+  ptr[0] = 0;
+  ptr[1] = n_nodes;
+  weight[0] = 1.0;
+
+  // pack the input, call the model, extract the output
+  c10::Dict<std::string, torch::Tensor> input;
+  input.insert("batch", batch);
+  input.insert("cell", cell);
+  input.insert("edge_index", edge_index);
+  input.insert("energy", energy);
+  input.insert("forces", forces);
+  input.insert("node_attrs", node_attrs);
+  input.insert("positions", positions);
+  input.insert("ptr", ptr);
+  input.insert("shifts", shifts);
+  input.insert("unit_shifts", unit_shifts);
+  input.insert("weight", weight);
+  auto output = model.forward({input, mask, bool(vflag_global)}).toGenericDict();
+
+  // mace energy
+  //   -> sum of site energies of local atoms
+  if (eflag_global) {
+    auto node_energy = output.at("node_energy").toTensor();
+    auto node_energy_ptr = static_cast<double*>(node_energy.data_ptr());
+    auto k_node_energy = Kokkos::View<double*,Kokkos::MemoryTraits<Kokkos::Unmanaged>>(node_energy_ptr,n_nodes);
+    eng_vdwl = 0.0;
+    Kokkos::parallel_reduce("PairMACEKokkos: Accumulate site energies.", nlocal, KOKKOS_LAMBDA(const int ii, double &eng_vdwl) {
+      const int i = d_ilist(ii);
+      eng_vdwl += k_node_energy(i);
+    }, eng_vdwl);
+  }
+
+  // mace forces
+  //   -> derivatives of total mace energy
+  forces = output.at("forces").toTensor();
+  auto forces_ptr = static_cast<double*>(forces.data_ptr());
+  auto k_forces = Kokkos::View<double*[3],Kokkos::LayoutRight,Kokkos::MemoryTraits<Kokkos::Unmanaged>>(forces_ptr,n_nodes);
+  Kokkos::parallel_for("PairMACEKokkos: Extract k_forces.", n_nodes, KOKKOS_LAMBDA(const int ii) {
+    const int i = d_ilist(ii);
+    f(i,0) += k_forces(i,0);
+    f(i,1) += k_forces(i,1);
+    f(i,2) += k_forces(i,2);
+  });
+
+  // mace virials (local atoms only)
+  //   -> derivatives of sum of site energies of local atoms
+  if (vflag_global) {
+    // TODO: is this cpu transfer necessary?
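+    // Likely yes in some form (annotation, not from the original author):
+    // each item() call on a CUDA tensor triggers its own device-to-host
+    // copy, so moving the small 3x3 virial tensor to the CPU once and
+    // reading its six components from host memory is the cheaper option.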
+    auto vir = output.at("virials").toTensor().to("cpu");
+    // caution: lammps does not use voigt ordering
+    // also: it would be nice to get rid of the 'template item' stuff,
+    //       but some compilers seem to require it
+    virial[0] += vir[0][0][0].template item<double>();
+    virial[1] += vir[0][1][1].template item<double>();
+    virial[2] += vir[0][2][2].template item<double>();
+    virial[3] += 0.5*(vir[0][1][0].template item<double>() + vir[0][0][1].template item<double>());
+    virial[4] += 0.5*(vir[0][2][0].template item<double>() + vir[0][0][2].template item<double>());
+    virial[5] += 0.5*(vir[0][2][1].template item<double>() + vir[0][1][2].template item<double>());
+  }
+
+  // TODO: investigate this
+  //       Appears to be important for dumps and probably more
+  atomKK->modified(execution_space,F_MASK);
+}
+
+/* ---------------------------------------------------------------------- */
+
+template<class DeviceType>
+void PairMACEKokkos<DeviceType>::coeff(int narg, char **arg)
+{
+  if (!allocated) allocate();
+  PairMACE::coeff(narg,arg);
+
+  // new
+  k_lammps_atomic_numbers = Kokkos::View<int64_t*>("k_lammps_atomic_numbers",lammps_atomic_numbers.size());
+  auto k_lammps_atomic_numbers_mirror = Kokkos::create_mirror_view(k_lammps_atomic_numbers);
+  for (int i=0; i<lammps_atomic_numbers.size(); ++i) {
+    k_lammps_atomic_numbers_mirror(i) = lammps_atomic_numbers[i];
+  }
+  Kokkos::deep_copy(k_lammps_atomic_numbers, k_lammps_atomic_numbers_mirror);
+
+  k_mace_atomic_numbers = Kokkos::View<int64_t*>("k_mace_atomic_numbers",mace_atomic_numbers.size());
+  auto k_mace_atomic_numbers_mirror = Kokkos::create_mirror_view(k_mace_atomic_numbers);
+  for (int i=0; i<mace_atomic_numbers.size(); ++i) {
+    k_mace_atomic_numbers_mirror(i) = mace_atomic_numbers[i];
+  }
+  Kokkos::deep_copy(k_mace_atomic_numbers, k_mace_atomic_numbers_mirror);
+
+  mace_atomic_numbers_size = mace_atomic_numbers.size();
+}
+
+/* ---------------------------------------------------------------------- */
+
+template<class DeviceType>
+void PairMACEKokkos<DeviceType>::init_style()
+{
+  PairMACE::init_style();
+  auto request = neighbor->find_request(this);
+  request->set_kokkos_host(std::is_same<DeviceType,LMPHostType>::value &&
+                           !std::is_same<DeviceType,LMPDeviceType>::value);
+  request->set_kokkos_device(std::is_same<DeviceType,LMPDeviceType>::value);
+}
+
+template<class DeviceType>
+double PairMACEKokkos<DeviceType>::init_one(int i, int j)
+{
+  double cutone = PairMACE::init_one(i,j);
+  k_cutsq.h_view(i,j) = k_cutsq.h_view(j,i) = cutone*cutone;
+  k_cutsq.template modify<LMPHostType>();
+  return cutone;
+}
+
+template<class DeviceType>
+void PairMACEKokkos<DeviceType>::allocate()
+{
+  PairMACE::allocate();
+  int n = atom->ntypes + 1;
+  MemKK::realloc_kokkos(k_cutsq, "mace:cutsq", n, n);
+  d_cutsq = k_cutsq.template view<DeviceType>();
+}
+
+namespace LAMMPS_NS {
+template class PairMACEKokkos<LMPDeviceType>;
+#ifdef LMP_KOKKOS_GPU
+template class PairMACEKokkos<LMPHostType>;
+#endif
+}
diff --git a/src/KOKKOS/pair_mace_kokkos.h b/src/KOKKOS/pair_mace_kokkos.h
new file mode 100644
index 00000000000..8aa0a8e62b4
--- /dev/null
+++ b/src/KOKKOS/pair_mace_kokkos.h
@@ -0,0 +1,69 @@
+/* -*- c++ -*- ----------------------------------------------------------
+   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
+   https://www.lammps.org/, Sandia National Laboratories
+   Steve Plimpton, sjplimp@sandia.gov
+
+   Copyright (2003) Sandia Corporation.  Under the terms of Contract
+   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
+   certain rights in this software.  This software is distributed under
+   the GNU General Public License.
+
+   See the README file in the top-level LAMMPS directory.
+------------------------------------------------------------------------- */
+
+/* ----------------------------------------------------------------------
+   Contributors
+      William C Witt (University of Cambridge)
+------------------------------------------------------------------------- */
+
+#ifdef PAIR_CLASS
+// clang-format off
+PairStyle(mace/kk,PairMACEKokkos<LMPDeviceType>);
+PairStyle(mace/kk/device,PairMACEKokkos<LMPDeviceType>);
+PairStyle(mace/kk/host,PairMACEKokkos<LMPHostType>);
+// clang-format on
+#else
+
+#ifndef LMP_PAIR_MACE_KOKKOS_H
+#define LMP_PAIR_MACE_KOKKOS_H
+
+#include "pair_mace.h"
+#include "kokkos_type.h"
+#include "pair_kokkos.h"
+#include "neigh_list_kokkos.h"
+
+namespace LAMMPS_NS {
+
+template<class DeviceType>
+class PairMACEKokkos : public PairMACE {
+
+ public:
+
+  typedef DeviceType device_type;
+  typedef ArrayTypes<DeviceType> AT;
+  PairMACEKokkos(class LAMMPS *);
+  ~PairMACEKokkos() override;
+  void compute(int, int) override;
+  void coeff(int, char **) override;
+  void init_style() override;
+  double init_one(int, int) override;
+  void allocate();
+
+ protected:
+
+  int host_flag;
+  typedef Kokkos::DualView<F_FLOAT**,Kokkos::LayoutRight,DeviceType> tdual_fparams;
+  tdual_fparams k_cutsq;
+  typedef Kokkos::View<F_FLOAT**,Kokkos::LayoutRight,DeviceType> t_fparams;
+  t_fparams d_cutsq;
+
+  // new
+  Kokkos::View<int64_t*> k_lammps_atomic_numbers;
+  Kokkos::View<int64_t*> k_mace_atomic_numbers;
+  int mace_atomic_numbers_size;
+
+};
+} // namespace LAMMPS_NS
+
+#endif
+#endif
diff --git a/src/ML-MACE/pair_mace.cpp b/src/ML-MACE/pair_mace.cpp
new file mode 100644
index 00000000000..9b7547653fc
--- /dev/null
+++ b/src/ML-MACE/pair_mace.cpp
@@ -0,0 +1,402 @@
+/* ----------------------------------------------------------------------
+   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
+   https://www.lammps.org/, Sandia National Laboratories
+   Steve Plimpton, sjplimp@sandia.gov
+
+   Copyright (2003) Sandia Corporation.  Under the terms of Contract
+   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
+   certain rights in this software.  This software is distributed under
+   the GNU General Public License.
+
+   See the README file in the top-level LAMMPS directory.
+------------------------------------------------------------------------- */
+
+/* ----------------------------------------------------------------------
+   Contributors
+      William C Witt (University of Cambridge)
+------------------------------------------------------------------------- */
+
+#include "pair_mace.h"
+
+#include "atom.h"
+#include "domain.h"
+#include "error.h"
+#include "force.h"
+#include "memory.h"
+#include "neigh_list.h"
+#include "neighbor.h"
+
+#include <cmath>
+#include <cstring>
+#include <iostream>
+
+using namespace LAMMPS_NS;
+
+/* ---------------------------------------------------------------------- */
+
+PairMACE::PairMACE(LAMMPS *lmp) : Pair(lmp)
+{
+  no_virial_fdotr_compute = 1;
+}
+
+/* ---------------------------------------------------------------------- */
+
+PairMACE::~PairMACE()
+{
+}
+
+/* ---------------------------------------------------------------------- */
+
+void PairMACE::compute(int eflag, int vflag)
+{
+  ev_init(eflag, vflag);
+
+  if (atom->nlocal != list->inum) error->all(FLERR, "ERROR: nlocal != inum.");
+  if (domain_decomposition) {
+    if (atom->nghost != list->gnum) error->all(FLERR, "ERROR: nghost != gnum.");
+  }
+
+  // ----- positions -----
+  int n_nodes;
+  if (domain_decomposition) {
+    n_nodes = atom->nlocal + atom->nghost;
+  } else {
+    // normally, ghost atoms are included in the graph as independent
+    // nodes, as required when the local domain does not have PBC.
+    // however, in no_domain_decomposition mode, ghost atoms are known to
+    // be shifted versions of local atoms.
+    n_nodes = atom->nlocal;
+  }
+  auto positions = torch::empty({n_nodes,3}, torch_float_dtype);
+  #pragma omp parallel for
+  for (int ii=0; ii<n_nodes; ++ii) {
+    int i = list->ilist[ii];
+    positions[i][0] = atom->x[i][0];
+    positions[i][1] = atom->x[i][1];
+    positions[i][2] = atom->x[i][2];
+  }
+
+  // ----- cell -----
+  auto cell = torch::zeros({3,3}, torch_float_dtype);
+  cell[0][0] = domain->h[0];
+  cell[0][1] = 0.0;
+  cell[0][2] = 0.0;
+  cell[1][0] = domain->h[5];
+  cell[1][1] = domain->h[1];
+  cell[1][2] = 0.0;
+  cell[2][0] = domain->h[4];
+  cell[2][1] = domain->h[3];
+  cell[2][2] = domain->h[2];
+
+  // ----- edge_index and unit_shifts -----
+  // count total number of edges
+  int n_edges = 0;
+  std::vector<int> n_edges_vec(n_nodes, 0);
+  #pragma omp parallel for reduction(+:n_edges)
+  for (int ii=0; ii<n_nodes; ++ii) {
+    int i = list->ilist[ii];
+    double xtmp = atom->x[i][0];
+    double ytmp = atom->x[i][1];
+    double ztmp = atom->x[i][2];
+    int *jlist = list->firstneigh[i];
+    int jnum = list->numneigh[i];
+    for (int jj=0; jj<jnum; ++jj) {
+      int j = jlist[jj];
+      j &= NEIGHMASK;
+      double delx = xtmp - atom->x[j][0];
+      double dely = ytmp - atom->x[j][1];
+      double delz = ztmp - atom->x[j][2];
+      double rsq = delx * delx + dely * dely + delz * delz;
+      if (rsq < r_max_squared) {
+        n_edges += 1;
+        n_edges_vec[ii] += 1;
+      }
+    }
+  }
+
+  // make first_edge vector to help with parallelizing following loop
+  std::vector<int> first_edge(n_nodes);
+  first_edge[0] = 0;
+  for (int ii=0; ii<n_nodes-1; ++ii) {
+    first_edge[ii+1] = first_edge[ii] + n_edges_vec[ii];
+  }
+
+  // fill edge_index, unit_shifts, and shifts
+  auto edge_index = torch::empty({2,n_edges}, torch::dtype(torch::kInt64));
+  auto unit_shifts = torch::zeros({n_edges,3}, torch_float_dtype);
+  auto shifts = torch::zeros({n_edges,3}, torch_float_dtype);
+  #pragma omp parallel for
+  for (int ii=0; ii<n_nodes; ++ii) {
+    int i = list->ilist[ii];
+    double xtmp = atom->x[i][0];
+    double ytmp = atom->x[i][1];
+    double ztmp = atom->x[i][2];
+    int *jlist = list->firstneigh[i];
+    int jnum = list->numneigh[i];
+    int k = first_edge[ii];
+    for (int jj=0; jj<jnum; ++jj) {
+      int j = jlist[jj];
+      j &= NEIGHMASK;
+      double delx = xtmp - atom->x[j][0];
+      double dely = ytmp - atom->x[j][1];
+      double delz = ztmp - atom->x[j][2];
+      double rsq = delx * delx + dely * dely + delz * delz;
+      if (rsq < r_max_squared) {
+        edge_index[0][k] = i;
+        if (domain_decomposition) {
+          edge_index[1][k] = j;
+        } else {
+          int j_local = atom->map(atom->tag[j]);
+          edge_index[1][k] = j_local;
+          double shiftx = atom->x[j][0] - atom->x[j_local][0];
+          double shifty = atom->x[j][1] - atom->x[j_local][1];
+          double shiftz = atom->x[j][2] - atom->x[j_local][2];
+          double shiftxs = std::round(domain->h_inv[0]*shiftx + domain->h_inv[5]*shifty + domain->h_inv[4]*shiftz);
+          double shiftys = std::round(domain->h_inv[1]*shifty + domain->h_inv[3]*shiftz);
+          double shiftzs = std::round(domain->h_inv[2]*shiftz);
+          unit_shifts[k][0] = shiftxs;
+          unit_shifts[k][1] = shiftys;
+          unit_shifts[k][2] = shiftzs;
+          shifts[k][0] = domain->h[0]*shiftxs + domain->h[5]*shiftys + domain->h[4]*shiftzs;
+          shifts[k][1] = domain->h[1]*shiftys + domain->h[3]*shiftzs;
+          shifts[k][2] = domain->h[2]*shiftzs;
+        }
+        k++;
+      }
+    }
+  }
+
+  // ----- node_attrs -----
+  int n_node_feats = mace_atomic_numbers.size();
+  auto node_attrs = torch::zeros({n_nodes,n_node_feats}, torch_float_dtype);
+  #pragma omp parallel for
+  for (int ii=0; ii<n_nodes; ++ii) {
+    int i = list->ilist[ii];
+    node_attrs[i][mace_type(atom->type[i])-1] = 1.0;
+  }
+
+  // ----- mask for ghost -----
+  auto mask = torch::zeros(n_nodes, torch::dtype(torch::kBool));
+  #pragma omp parallel for
+  for (int ii=0; ii<atom->nlocal; ++ii) {
+    int i = list->ilist[ii];
+    mask[i] = true;
+  }
+
+  auto batch = torch::zeros({n_nodes}, torch::dtype(torch::kInt64));
+  auto energy = torch::empty({1}, torch_float_dtype);
+  auto forces = torch::empty({n_nodes,3}, torch_float_dtype);
+  auto ptr = torch::empty({2}, torch::dtype(torch::kInt64));
+  auto weight = torch::empty({1}, torch_float_dtype);
+  ptr[0] = 0;
+  ptr[1] = n_nodes;
+  weight[0] = 1.0;
+
+  // transfer data to device
+  batch = batch.to(device);
+  cell = cell.to(device);
+  edge_index = edge_index.to(device);
+  energy = energy.to(device);
+  forces = forces.to(device);
+  node_attrs = node_attrs.to(device);
+  positions = positions.to(device);
+  ptr = ptr.to(device);
+  shifts = shifts.to(device);
+  unit_shifts = unit_shifts.to(device);
+  weight = weight.to(device);
+
+  // pack the input, call the model
+  c10::Dict<std::string, torch::Tensor> input;
+  input.insert("batch", batch);
+  input.insert("cell", cell);
+  input.insert("edge_index", edge_index);
+  input.insert("energy", energy);
+  input.insert("forces", forces);
+  input.insert("node_attrs", node_attrs);
+  input.insert("positions", positions);
+  input.insert("ptr", ptr);
+  input.insert("shifts", shifts);
+  input.insert("unit_shifts", unit_shifts);
+  input.insert("weight", weight);
+  auto output = model.forward({input, mask.to(device), bool(vflag_global)}).toGenericDict();
+
+  // mace energy
+  //   -> sum of site energies of local atoms
+  if (eflag_global) {
+    energy = output.at("total_energy_local").toTensor().cpu();
+    eng_vdwl += energy.item<double>();
+  }
+
+  // mace forces
+  //   -> derivatives of total mace energy
+  forces = output.at("forces").toTensor().cpu();
+  #pragma omp parallel for
+  for (int ii=0; ii<n_nodes; ++ii) {
+    int i = list->ilist[ii];
+    atom->f[i][0] += forces[i][0].item<double>();
+    atom->f[i][1] += forces[i][1].item<double>();
+    atom->f[i][2] += forces[i][2].item<double>();
+  }
+
+  // mace site energies
+  //   -> local atoms only
+  if (eflag_atom) {
+    auto node_energy = output.at("node_energy").toTensor().cpu();
+    #pragma omp parallel for
+    for (int ii=0; ii<list->inum; ++ii) {
+      int i = list->ilist[ii];
+      eatom[i] = node_energy[i].item<double>();
+    }
+  }
+
+  // mace virials (local atoms only)
+  //   -> derivatives of sum of site energies of local atoms
+  if (vflag_global) {
+    auto vir = output.at("virials").toTensor().cpu();
+    virial[0] += vir[0][0][0].item<double>();
+    virial[1] += vir[0][1][1].item<double>();
+    virial[2] += vir[0][2][2].item<double>();
+    virial[3] += 0.5*(vir[0][1][0].item<double>() + vir[0][0][1].item<double>());
+    virial[4] += 0.5*(vir[0][2][0].item<double>() + vir[0][0][2].item<double>());
+    virial[5] += 0.5*(vir[0][2][1].item<double>() + vir[0][1][2].item<double>());
+  }
+
+  // mace site virials
+  //   -> not available
+  if (vflag_atom) {
+    error->all(FLERR, "ERROR: pair_mace does not support vflag_atom.");
+  }
+}
+
+/* ---------------------------------------------------------------------- */
+
+void PairMACE::settings(int narg, char **arg)
+{
+  if (narg > 1) {
+    error->all(FLERR, "Too many pair_style arguments for pair_style mace.");
+  }
+
+  if (narg == 1) {
+    if (strcmp(arg[0], "no_domain_decomposition") == 0) {
+      domain_decomposition = false;
+      // TODO: add check against MPI rank
+    } else {
+      error->all(FLERR, "Unrecognized argument for pair_style mace.");
+    }
+  }
+}
+
+/* ---------------------------------------------------------------------- */
+
+void PairMACE::coeff(int narg, char **arg)
+{
+  // TODO: remove print statements from this routine, or have a single proc print
+
+  if (!allocated) allocate();
+
+  if (!torch::cuda::is_available()) {
+    std::cout << "CUDA unavailable, setting device type to torch::kCPU." << std::endl;
+    device = c10::Device(torch::kCPU);
+  } else {
+    std::cout << "CUDA found, setting device type to torch::kCUDA." << std::endl;
+    int localrank = 0;  // Assume GPU pinning occurs outside of LAMMPS
+    device = c10::Device(torch::kCUDA, localrank);
+  }
+
+  std::cout << "Loading MACE model from \"" << arg[2] << "\" ...";
+  model = torch::jit::load(arg[2], device);
+  std::cout << " finished." << std::endl;
+
+  // extract default dtype from mace model
+  for (auto p: model.named_attributes()) {
+    // this is a somewhat random choice of variable to check. could it be improved?
+    if (p.name == "model.node_embedding.linear.weight") {
+      if (p.value.toTensor().dtype() == caffe2::TypeMeta::Make<float>()) {
+        torch_float_dtype = torch::kFloat32;
+      } else if (p.value.toTensor().dtype() == caffe2::TypeMeta::Make<double>()) {
+        torch_float_dtype = torch::kFloat64;
+      }
+    }
+  }
+  std::cout << " - The torch_float_dtype is: " << torch_float_dtype << std::endl;
+
+  // extract r_max from mace model
+  r_max = model.attr("r_max").toTensor().item<double>();
+  r_max_squared = r_max*r_max;
+  std::cout << " - The r_max is: " << r_max << "." << std::endl;
+  num_interactions = model.attr("num_interactions").toTensor().item<int64_t>();
+  std::cout << " - The model has: " << num_interactions << " layers." << std::endl;
+
+  // extract atomic numbers from mace model
+  auto a_n = model.attr("atomic_numbers").toTensor();
+  for (int i=0; i<a_n.size(0); ++i) {
+    mace_atomic_numbers.push_back(a_n[i].item<int64_t>());
+  }
+  std::cout << " - The MACE model atomic numbers are: " << mace_atomic_numbers << "." << std::endl;
+
+  // extract atomic numbers from pair_coeff
+  for (int i=3; i<narg; ++i) {
+    for (int j=0; j<periodic_table.size(); ++j) {
+      if (periodic_table[j] == arg[i]) {
+        lammps_atomic_numbers.push_back(j+1);
+      }
+    }
+  }
+
+  // set setflag
+  for (int i=1; i<atom->ntypes+1; i++)
+    for (int j=i; j<atom->ntypes+1; j++)
+      setflag[i][j] = 1;
+}
+
+void PairMACE::init_style()
+{
+  if (force->newton_pair == 0) error->all(FLERR, "ERROR: Pair style mace requires newton pair on.");
+
+  /*
+  MACE requires the full neighbor list AND neighbors of ghost atoms
+  it appears that:
+    * without REQ_GHOST
+        list->gnum == 0
+        list->ilist does not include ghost atoms, but the jlists do
+    * with REQ_GHOST
+        list->gnum == atom->nghost
+        list->ilist includes ghost atoms
+  */
+  if (domain_decomposition) {
+    neighbor->add_request(this, NeighConst::REQ_FULL | NeighConst::REQ_GHOST);
+  } else {
+    neighbor->add_request(this, NeighConst::REQ_FULL);
+  }
+}
+
+double PairMACE::init_one(int i, int j)
+{
+  // to account for message passing, require cutoff of n_layers * r_max
+  return num_interactions*model.attr("r_max").toTensor().item<double>();
+}
+
+void PairMACE::allocate()
+{
+  allocated = 1;
+
+  memory->create(setflag, atom->ntypes+1, atom->ntypes+1, "pair:setflag");
+  for (int i=1; i<atom->ntypes+1; i++)
+    for (int j=i; j<atom->ntypes+1; j++)
+      setflag[i][j] = 0;
+
+  memory->create(cutsq, atom->ntypes+1, atom->ntypes+1, "pair:cutsq");
+}
+
+int PairMACE::mace_type(int lammps_type)
+{
+  for (int i=0; i<mace_atomic_numbers.size(); ++i) {
+    if (lammps_atomic_numbers[lammps_type-1] == mace_atomic_numbers[i]) {
+      return i+1;
+    }
+  }
+  error->all(FLERR, "Problem converting lammps_type to mace_type.");
+  return -1;
+}
diff --git a/src/ML-MACE/pair_mace.h b/src/ML-MACE/pair_mace.h
new file mode 100644
index 00000000000..4c8eadf86d6
--- /dev/null
+++ b/src/ML-MACE/pair_mace.h
@@ -0,0 +1,75 @@
+/* -*- c++ -*- ----------------------------------------------------------
+   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
+   https://www.lammps.org/, Sandia National Laboratories
+   Steve Plimpton, sjplimp@sandia.gov
+
+   Copyright (2003) Sandia Corporation.  Under the terms of Contract
+   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
+   certain rights in this software.  This software is distributed under
+   the GNU General Public License.
+
+   See the README file in the top-level LAMMPS directory.
+------------------------------------------------------------------------- */
+
+/* ----------------------------------------------------------------------
+   Contributors
+      William C Witt (University of Cambridge)
+------------------------------------------------------------------------- */
+
+#ifdef PAIR_CLASS
+// clang-format off
+PairStyle(mace,PairMACE);
+// clang-format on
+#else
+
+#ifndef LMP_PAIR_MACE_H
+#define LMP_PAIR_MACE_H
+
+#include "pair.h"
+
+#include <torch/script.h>
+#include <torch/torch.h>
+
+namespace LAMMPS_NS {
+
+class PairMACE : public Pair {
+
+ public:
+
+  PairMACE(class LAMMPS *);
+  ~PairMACE() override;
+  void compute(int, int) override;
+  void settings(int, char **) override;
+  void coeff(int, char **) override;
+  void init_style() override;
+  double init_one(int, int) override;
+  void allocate();
+
+ protected:
+
+  bool domain_decomposition = true;
+  torch::Device device = torch::kCPU;
+  torch::jit::script::Module model;
+  torch::ScalarType torch_float_dtype;
+  double r_max;
+  double r_max_squared;
+  int64_t num_interactions;
+  std::vector<int64_t> mace_atomic_numbers;
+  std::vector<int64_t> lammps_atomic_numbers;
+  int mace_type(int lammps_type);
+  const std::array<std::string,118> periodic_table =
+    { "H",                                                 "He",
+     "Li", "Be",  "B",  "C",  "N",  "O",  "F", "Ne",
+     "Na", "Mg", "Al", "Si",  "P",  "S", "Cl", "Ar",
+      "K", "Ca", "Sc", "Ti",  "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr",
+     "Rb", "Sr",  "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te",  "I", "Xe",
+     "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu",
+     "Hf", "Ta",  "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn",
+     "Fr", "Ra", "Ac", "Th", "Pa",  "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr",
+     "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og"};
+
+};
+} // namespace LAMMPS_NS
+
+#endif
+#endif
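Example usage (a sketch inferred from the coeff() parsing above, not taken
from the patch itself; the model file name is a placeholder): pair_coeff
reads the TorchScript model path as its third argument and one element
symbol per LAMMPS atom type thereafter.

  pair_style    mace
  pair_coeff    * * mace_model.pt O H

With the KOKKOS package enabled, pair_style mace/kk (or mace/kk/device,
mace/kk/host) selects the Kokkos variant registered in pair_mace_kokkos.h;
the optional no_domain_decomposition keyword to pair_style switches to the
single-domain graph construction handled by the else branches above.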