LCOV - code coverage report
Current view: top level - maze - Optimizer.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 162 182 89.0 %
Date: 2019-08-13 10:15:31 Functions: 12 13 92.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2019 Jakub Rydzewski (jr@fizyka.umk.pl). All rights reserved.
       3             : 
       4             : See http://www.maze-code.github.io for more information.
       5             : 
       6             : This file is part of maze.
       7             : 
       8             : maze is free software: you can redistribute it and/or modify it under the
       9             : terms of the GNU Lesser General Public License as published by the Free
      10             : Software Foundation, either version 3 of the License, or (at your option)
      11             : any later version.
      12             : 
      13             : maze is distributed in the hope that it will be useful, but WITHOUT ANY
      14             : WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
      15             : FOR A PARTICULAR PURPOSE.
      16             : 
      17             : See the GNU Lesser General Public License for more details.
      18             : 
      19             : You should have received a copy of the GNU Lesser General Public License
      20             : along with maze. If not, see <https://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : 
      23             : /**
      24             :  * @file Optimizer.cpp
      25             :  *
      26             :  * @author J. Rydzewski (jr@fizyka.umk.pl)
      27             :  */
      28             : 
      29             : #include "Optimizer.h"
      30             : #include "core/PlumedMain.h"
      31             : 
      32             : namespace PLMD {
      33             : namespace maze {
      34             : 
      35          12 : void Optimizer::registerKeywords(Keywords& keys) {
      36          12 :   Colvar::registerKeywords(keys);
      37             : 
      38             :   keys.addFlag(
      39             :     "SERIAL",
      40             :     false,
      41             :     "Perform the simulation in serial -- used only for debugging purposes, "
      42             :     "should not be used otherwise."
      43          36 :   );
      44             : 
      45             :   keys.addFlag(
      46             :     "PAIR",
      47             :     false,
      48             :     "Pair only the 1st element of the 1st group with the 1st element in the "
      49             :     "second, etc."
      50          36 :   );
      51             : 
      52             :   keys.addFlag(
      53             :     "NLIST",
      54             :     true,
      55             :     "Use a neighbor list of ligand-protein atom pairs to speed up the "
      56             :     "calculating of the distances."
      57          36 :   );
      58             : 
      59             :   keys.add(
      60             :     "optional",
      61             :     "NL_CUTOFF",
      62             :     "Neighbor list cut-off for the distances of ligand-protein atom pairs."
      63          48 :   );
      64             : 
      65             :   keys.add(
      66             :     "optional",
      67             :     "NL_STRIDE",
      68             :     "Update stride for the ligand-protein atom pairs in the neighbor list."
      69          48 :   );
      70             : 
      71             :   keys.add(
      72             :     "compulsory",
      73             :     "N_ITER",
      74             :     "Number of optimization steps. Required only for optimizers, do not pass "
      75             :     "this keyword to the fake optimizers (results in crash) , e.g., random "
      76             :     "walk, steered MD, or random acceleration MD."
      77          48 :   );
      78             : 
      79             :   keys.add(
      80             :     "optional",
      81             :     "LOSS",
      82             :     "Loss function describing ligand-protein interactions required by every "
      83             :     "optimizer."
      84          48 :   );
      85             : 
      86             :   keys.add(
      87             :     "atoms",
      88             :     "LIGAND",
      89             :     "Indices of ligand atoms."
      90          48 :   );
      91             : 
      92             :   keys.add(
      93             :     "atoms",
      94             :     "PROTEIN",
      95             :     "Indices of protein atoms."
      96          48 :   );
      97             : 
      98             :   keys.add(
      99             :     "compulsory",
     100             :     "OPTIMIZER_STRIDE",
     101             :     "Optimizer stride. Sets up a callback function that launches the "
     102             :     "optimization process every OPTIMIZER_STRIDE."
     103          48 :   );
     104             : 
     105          12 :   componentsAreNotOptional(keys);
     106             : 
     107             :   keys.addOutputComponent(
     108             :     "x",
     109             :     "default",
     110             :     "Optimal biasing direction; x component."
     111          48 :   );
     112             : 
     113             :   keys.addOutputComponent(
     114             :     "y",
     115             :     "default",
     116             :     "Optimal biasing direction; y component."
     117          48 :   );
     118             : 
     119             :   keys.addOutputComponent(
     120             :     "z",
     121             :     "default",
     122             :     "Optimal biasing direction; z component."
     123          48 :   );
     124             : 
     125             :   keys.addOutputComponent(
     126             :     "loss",
     127             :     "default",
     128             :     "Loss function value defined by the provided pairing function."
     129          48 :   );
     130             : 
     131             :   keys.addOutputComponent(
     132             :     "sr",
     133             :     "default",
     134             :     "Sampling radius. Reduces sampling to the local proximity of the ligand "
     135             :     "position."
     136          48 :   );
     137          12 : }
     138             : 
     139           7 : Optimizer::Optimizer(const ActionOptions& ao)
     140             :   : PLUMED_COLVAR_INIT(ao),
     141             :     first_step_(true),
     142             :     opt_value_(0.0),
     143             :     pbc_(true),
     144             :     sampling_r_(0.0),
     145             :     serial_(false),
     146             :     validate_list_(true),
     147          14 :     first_time_(true)
     148             : {
     149          14 :   parseFlag("SERIAL", serial_);
     150             : 
     151          14 :   if (keywords.exists("LOSS")) {
     152           7 :     std::vector<std::string> loss_labels(0);
     153          14 :     parseVector("LOSS", loss_labels);
     154             : 
     155           7 :     plumed_massert(
     156             :       loss_labels.size() > 0,
     157             :       "maze> Something went wrong with the LOSS keyword.\n"
     158           0 :     );
     159             : 
     160           7 :     std::string error_msg = "";
     161          14 :     vec_loss_ = tls::get_pointers_labels<Loss*>(
     162             :                   loss_labels,
     163           7 :                   plumed.getActionSet(),
     164             :                   error_msg
     165             :                 );
     166             : 
     167           7 :     if (error_msg.size() > 0) {
     168           0 :       plumed_merror(
     169             :         "maze> Error in the LOSS keyword " + getName() + ": " + error_msg
     170           0 :       );
     171             :     }
     172             : 
     173           7 :     loss_ = vec_loss_[0];
     174          14 :     log.printf("maze> Loss function linked to the optimizer.\n");
     175             :   }
     176             : 
     177          14 :   if (keywords.exists("N_ITER")) {
     178           6 :     parse("N_ITER", n_iter_);
     179             : 
     180           3 :     plumed_massert(
     181             :       n_iter_ > 0,
     182             :       "maze> N_ITER should be explicitly specified and positive.\n"
     183           0 :     );
     184             : 
     185             :     log.printf(
     186             :       "maze> Optimizer will run %u iterations once launched.\n",
     187             :       n_iter_
     188           3 :     );
     189             :   }
     190             : 
     191             :   std::vector<AtomNumber> ga_list, gb_list;
     192          14 :   parseAtomList("LIGAND", ga_list);
     193          14 :   parseAtomList("PROTEIN", gb_list);
     194             : 
     195           7 :   bool nopbc = !pbc_;
     196          14 :   parseFlag("NOPBC", nopbc);
     197             : 
     198           7 :   bool do_pair = false;
     199          14 :   parseFlag("PAIR", do_pair);
     200             : 
     201           7 :   nl_stride_ = 0;
     202           7 :   bool do_neigh = false;
     203          14 :   parseFlag("NLIST", do_neigh);
     204             : 
     205           7 :   if (do_neigh) {
     206          14 :     if (keywords.exists("NL_CUTOFF")) {
     207          14 :       parse("NL_CUTOFF", nl_cutoff_);
     208             : 
     209           7 :       plumed_massert(
     210             :         nl_cutoff_ > 0,
     211             :         "maze> NL_CUTOFF should be explicitly specified and positive.\n"
     212           0 :       );
     213             :     }
     214             : 
     215          14 :     if (keywords.exists("NL_STRIDE")) {
     216          14 :       parse("NL_STRIDE", nl_stride_);
     217             : 
     218           7 :       plumed_massert(
     219             :         nl_stride_ > 0,
     220             :         "maze> NL_STRIDE should be explicitly specified and positive.\n"
     221           0 :       );
     222             :     }
     223             :   }
     224             : 
     225           7 :   if (gb_list.size() > 0) {
     226           7 :     if (do_neigh) {
     227             :       neighbor_list_ = new NeighborList(
     228             :         ga_list,
     229             :         gb_list,
     230             :         do_pair,
     231             :         pbc_,
     232             :         getPbc(),
     233             :         nl_cutoff_,
     234             :         nl_stride_
     235           7 :       );
     236             :     }
     237             :     else {
     238             :       neighbor_list_=new NeighborList(
     239             :         ga_list,
     240             :         gb_list,
     241             :         do_pair,
     242             :         pbc_,
     243             :         getPbc()
     244           0 :       );
     245             :     }
     246             :   }
     247             :   else {
     248           0 :     if (do_neigh) {
     249             :       neighbor_list_ = new NeighborList(
     250             :         ga_list,
     251             :         pbc_,
     252             :         getPbc(),
     253             :         nl_cutoff_,
     254             :         nl_stride_
     255           0 :       );
     256             :     }
     257             :     else {
     258             :       neighbor_list_=new NeighborList(
     259             :         ga_list,
     260             :         pbc_,
     261             :         getPbc()
     262           0 :       );
     263             :     }
     264             :   }
     265             : 
     266           7 :   requestAtoms(neighbor_list_->getFullAtomList());
     267             : 
     268             :   log.printf(
     269             :     "maze> Loss will be calculated between two groups of %u and %u atoms.\n",
     270             :     static_cast<unsigned>(ga_list.size()),
     271             :     static_cast<unsigned>(gb_list.size())
     272          14 :   );
     273             : 
     274             :   log.printf(
     275             :     "maze> First group (LIGAND): from %d to %d.\n",
     276             :     ga_list[0].serial(),
     277           7 :     ga_list[ga_list.size()-1].serial()
     278           7 :   );
     279             : 
     280           7 :   if (gb_list.size() > 0) {
     281             :     log.printf(
     282             :       "maze> Second group (PROTEIN): from %d to %d.\n",
     283             :       gb_list[0].serial(),
     284           7 :       gb_list[gb_list.size()-1].serial()
     285           7 :     );
     286             :   }
     287             : 
     288           7 :   if (pbc_) {
     289           7 :     log.printf("maze> Using periodic boundary conditions.\n");
     290             :   }
     291             :   else {
     292           0 :     log.printf("maze> Without periodic boundary conditions.\n");
     293             :   }
     294             : 
     295           7 :   if (do_pair) {
     296           0 :     log.printf("maze> With PAIR option.\n");
     297             :   }
     298             : 
     299           7 :   if (do_neigh) {
     300             :     log.printf(
     301             :       "maze> Using neighbor lists updated every %d steps and cutoff %f.\n",
     302             :       nl_stride_,
     303             :       nl_cutoff_
     304           7 :     );
     305             :   }
     306             : 
     307             :   // OpenMP
     308           7 :   stride_ = comm.Get_size();
     309           7 :   rank_ = comm.Get_rank();
     310             : 
     311           7 :   n_threads_ = OpenMP::getNumThreads();
     312           7 :   unsigned int nn = neighbor_list_->size();
     313             : 
     314           7 :   if (n_threads_ * stride_ * 10 > nn) {
     315           0 :     n_threads_ = nn / stride_ / 10;
     316             :   }
     317             : 
     318           7 :   if (n_threads_ == 0) {
     319           0 :     n_threads_ = 1;
     320             :   }
     321             : 
     322          14 :   if (keywords.exists("OPTIMIZER_STRIDE")) {
     323          14 :     parse("OPTIMIZER_STRIDE", optimizer_stride_);
     324             : 
     325           7 :     plumed_massert(
     326             :       optimizer_stride_,
     327             :       "maze> OPTIMIZER_STRIDE should be explicitly specified and positive.\n"
     328           0 :     );
     329             : 
     330             :     log.printf(
     331             :       "maze> Launching optimization every %u steps.\n",
     332             :       optimizer_stride_
     333           7 :     );
     334             :   }
     335             : 
     336           7 :   rnd::randomize();
     337             : 
     338           7 :   opt_.zero();
     339             : 
     340          14 :   addComponentWithDerivatives("x");
     341          14 :   componentIsNotPeriodic("x");
     342             : 
     343          14 :   addComponentWithDerivatives("y");
     344          14 :   componentIsNotPeriodic("y");
     345             : 
     346          14 :   addComponentWithDerivatives("z");
     347          14 :   componentIsNotPeriodic("z");
     348             : 
     349          14 :   addComponent("loss");
     350          14 :   componentIsNotPeriodic("loss");
     351             : 
     352          14 :   addComponent("sr");
     353          14 :   componentIsNotPeriodic("sr");
     354             : 
     355          14 :   value_x_ = getPntrToComponent("x");
     356          14 :   value_y_ = getPntrToComponent("y");
     357          14 :   value_z_ = getPntrToComponent("z");
     358          14 :   value_action_ = getPntrToComponent("loss");
     359          14 :   value_sampling_radius_ = getPntrToComponent("sr");
     360           7 : }
     361             : 
     362    15921922 : double Optimizer::pairing(double distance) const {
     363    15921922 :   return loss_->pairing(distance);
     364             : }
     365             : 
     366           6 : Vector Optimizer::center_of_mass() const {
     367           6 :   const unsigned nl_size = neighbor_list_->size();
     368             : 
     369           6 :   Vector center_of_mass;
     370           6 :   center_of_mass.zero();
     371             :   double mass = 0;
     372             : 
     373      189654 :   for (unsigned int i = 0; i < nl_size; ++i) {
     374      189648 :     unsigned int i0 = neighbor_list_->getClosePair(i).first;
     375      379296 :     center_of_mass += getPosition(i0) * getMass(i0);
     376      189648 :     mass += getMass(i0);
     377             :   }
     378             : 
     379           6 :   return center_of_mass / mass;
     380             : }
     381             : 
     382         210 : void Optimizer::prepare() {
     383         210 :   if (neighbor_list_->getStride() > 0) {
     384         210 :     if (first_time_ || (getStep() % neighbor_list_->getStride() == 0)) {
     385           7 :       requestAtoms(neighbor_list_->getFullAtomList());
     386             : 
     387           7 :       validate_list_ = true;
     388           7 :       first_time_ = false;
     389             :     }
     390             :     else {
     391         203 :       requestAtoms(neighbor_list_->getReducedAtomList());
     392             : 
     393         203 :       validate_list_ = false;
     394             : 
     395         203 :       if (getExchangeStep()) {
     396           0 :         plumed_merror(
     397             :           "maze> Neighbor lists should be updated on exchange steps -- choose "
     398           0 :           "an NL_STRIDE which divides the exchange stride.\n");
     399             :       }
     400             :     }
     401             : 
     402         210 :     if (getExchangeStep()) {
     403           0 :       first_time_ = true;
     404             :     }
     405             :   }
     406         210 : }
     407             : 
     408         226 : double Optimizer::score() {
     409         226 :   const unsigned nl_size = neighbor_list_->size();
     410         226 :   Vector distance;
     411             :   double function = 0;
     412             : 
     413         584 :   #pragma omp parallel num_threads(n_threads_)
     414             :   {
     415         452 :     #pragma omp for reduction(+:function)
     416             :     for(unsigned int i = 0; i < nl_size; i++) {
     417     5171859 :       unsigned i0 = neighbor_list_->getClosePair(i).first;
     418     5173332 :       unsigned i1 = neighbor_list_->getClosePair(i).second;
     419             : 
     420    15527247 :       if (getAbsoluteIndex(i0) == getAbsoluteIndex(i1)) {
     421             :         continue;
     422             :       }
     423             : 
     424     5174939 :       if (pbc_) {
     425     5174939 :         distance = pbcDistance(getPosition(i0), getPosition(i1));
     426             :       }
     427             :       else {
     428           0 :         distance = delta(getPosition(i0), getPosition(i1));
     429             :       }
     430             : 
     431     5175961 :       function += pairing(distance.modulo());
     432             :     }
     433             :   }
     434             : 
     435         226 :   return function;
     436             : }
     437             : 
     438         210 : void Optimizer::update_nl() {
     439         210 :   if (neighbor_list_->getStride() > 0 && validate_list_) {
     440           7 :     neighbor_list_->update(getPositions());
     441             :   }
     442         210 : }
     443             : 
     444         363 : double Optimizer::sampling_radius()
     445             : {
     446         363 :   const unsigned nl_size=neighbor_list_->size();
     447         363 :   Vector d;
     448             :   double min=std::numeric_limits<int>::max();
     449             : 
     450     9685887 :   for (unsigned int i = 0; i < nl_size; ++i) {
     451     9685524 :     unsigned i0 = neighbor_list_->getClosePair(i).first;
     452     9685524 :     unsigned i1 = neighbor_list_->getClosePair(i).second;
     453             : 
     454    29056572 :     if (getAbsoluteIndex(i0) == getAbsoluteIndex(i1)) {
     455             :       continue;
     456             :     }
     457             : 
     458     9685524 :     if (pbc_) {
     459     9685524 :       d = pbcDistance(getPosition(i0), getPosition(i1));
     460             :     }
     461             :     else {
     462           0 :       d = delta(getPosition(i0), getPosition(i1));
     463             :     }
     464             : 
     465     9685524 :     double dist = d.modulo();
     466             : 
     467     9685524 :     if(dist < min) {
     468             :       min = dist;
     469             :     }
     470             :   }
     471             : 
     472         363 :   return min;
     473             : }
     474             : 
     475         210 : void Optimizer::calculate() {
     476         210 :   update_nl();
     477             : 
     478         210 :   if (getStep() % optimizer_stride_ == 0 && !first_step_) {
     479          19 :     optimize();
     480             : 
     481          19 :     value_x_->set(opt_[0]);
     482          19 :     value_y_->set(opt_[1]);
     483          19 :     value_z_->set(opt_[2]);
     484             : 
     485          19 :     value_action_->set(score());
     486          19 :     value_sampling_radius_->set(sampling_radius());
     487             :   }
     488             :   else {
     489         191 :     first_step_=false;
     490             : 
     491         191 :     value_x_->set(opt_[0]);
     492         191 :     value_y_->set(opt_[1]);
     493         191 :     value_z_->set(opt_[2]);
     494             : 
     495         191 :     value_action_->set(score());
     496         191 :     value_sampling_radius_->set(sampling_radius());
     497             :   }
     498         210 : }
     499             : 
     500             : } // namespace maze
     501        5874 : } // namespace PLMD

Generated by: LCOV version 1.14