Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : * -------------------------------------------------------------------------- *
3 : * Lepton *
4 : * -------------------------------------------------------------------------- *
5 : * This is part of the Lepton expression parser originating from *
6 : * Simbios, the NIH National Center for Physics-Based Simulation of *
7 : * Biological Structures at Stanford, funded under the NIH Roadmap for *
8 : * Medical Research, grant U54 GM072970. See https://simtk.org. *
9 : * *
10 : * Portions copyright (c) 2013-2016 Stanford University and the Authors. *
11 : * Authors: Peter Eastman *
12 : * Contributors: *
13 : * *
14 : * Permission is hereby granted, free of charge, to any person obtaining a *
15 : * copy of this software and associated documentation files (the "Software"), *
16 : * to deal in the Software without restriction, including without limitation *
17 : * the rights to use, copy, modify, merge, publish, distribute, sublicense, *
18 : * and/or sell copies of the Software, and to permit persons to whom the *
19 : * Software is furnished to do so, subject to the following conditions: *
20 : * *
21 : * The above copyright notice and this permission notice shall be included in *
22 : * all copies or substantial portions of the Software. *
23 : * *
24 : * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
25 : * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
26 : * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
27 : * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
28 : * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
29 : * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
30 : * USE OR OTHER DEALINGS IN THE SOFTWARE. *
31 : * -------------------------------------------------------------------------- *
32 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
33 : /* -------------------------------------------------------------------------- *
34 : * lepton *
35 : * -------------------------------------------------------------------------- *
36 : * This is part of the lepton expression parser originating from *
37 : * Simbios, the NIH National Center for Physics-Based Simulation of *
38 : * Biological Structures at Stanford, funded under the NIH Roadmap for *
39 : * Medical Research, grant U54 GM072970. See https://simtk.org. *
40 : * *
41 : * Portions copyright (c) 2013-2016 Stanford University and the Authors. *
42 : * Authors: Peter Eastman *
43 : * Contributors: *
44 : * *
45 : * Permission is hereby granted, free of charge, to any person obtaining a *
46 : * copy of this software and associated documentation files (the "Software"), *
47 : * to deal in the Software without restriction, including without limitation *
48 : * the rights to use, copy, modify, merge, publish, distribute, sublicense, *
49 : * and/or sell copies of the Software, and to permit persons to whom the *
50 : * Software is furnished to do so, subject to the following conditions: *
51 : * *
52 : * The above copyright notice and this permission notice shall be included in *
53 : * all copies or substantial portions of the Software. *
54 : * *
55 : * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
56 : * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
57 : * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
58 : * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
59 : * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
60 : * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
61 : * USE OR OTHER DEALINGS IN THE SOFTWARE. *
62 : * -------------------------------------------------------------------------- */
63 :
#include "CompiledExpression.h"
#include "Operation.h"
#include "ParsedExpression.h"
#ifdef __PLUMED_HAS_ASMJIT
#include "asmjit/asmjit.h"
#endif
#include <limits>
#include <utility>
71 :
72 : namespace PLMD {
73 : using namespace lepton;
74 : using namespace std;
75 : #ifdef __PLUMED_HAS_ASMJIT
76 : using namespace asmjit;
77 : #endif
78 :
79 80 : bool lepton::useAsmJit() {
80 : #ifdef __PLUMED_HAS_ASMJIT
81 : static const bool use=[](){
82 : if(auto s=std::getenv("PLUMED_USE_ASMJIT")) {
83 : auto ss=std::string(s);
84 : if(ss=="yes") return true;
85 : if(ss=="no") return false;
86 : throw Exception("PLUMED_USE_ASMJIT variable is set to " + ss + "; should be yes or no");
87 : }
88 : return true; // by default use asmjit
89 : }();
90 : return use;
91 : #else
92 80 : return false;
93 : #endif
94 : }
95 :
96 1668 : AsmJitRuntimePtr::AsmJitRuntimePtr()
97 : #ifdef __PLUMED_HAS_ASMJIT
98 : : ptr(useAsmJit()?new asmjit::JitRuntime:nullptr)
99 : #endif
100 1668 : {}
101 :
102 1668 : AsmJitRuntimePtr::~AsmJitRuntimePtr()
103 : {
104 : #ifdef __PLUMED_HAS_ASMJIT
105 : if(useAsmJit()) delete static_cast<asmjit::JitRuntime*>(ptr);
106 : #endif
107 1668 : }
108 :
109 2544 : CompiledExpression::CompiledExpression() : jitCode(NULL) {
110 848 : }
111 :
112 2460 : CompiledExpression::CompiledExpression(const ParsedExpression& expression) : jitCode(NULL) {
113 820 : ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized.
114 820 : vector<pair<ExpressionTreeNode, int> > temps;
115 820 : compileExpression(expr.getRootNode(), temps);
116 : int maxArguments = 1;
117 10754 : for (int i = 0; i < (int) operation.size(); i++)
118 9934 : if (operation[i]->getNumArguments() > maxArguments)
119 540 : maxArguments = operation[i]->getNumArguments();
120 820 : argValues.resize(maxArguments);
121 : #ifdef __PLUMED_HAS_ASMJIT
122 : if(useAsmJit()) generateJitCode();
123 : #endif
124 820 : }
125 :
126 5004 : CompiledExpression::~CompiledExpression() {
127 23204 : for (int i = 0; i < (int) operation.size(); i++)
128 19868 : if (operation[i] != NULL)
129 9934 : delete operation[i];
130 1668 : }
131 :
132 0 : CompiledExpression::CompiledExpression(const CompiledExpression& expression) : jitCode(NULL) {
133 0 : *this = expression;
134 0 : }
135 :
136 820 : CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expression) {
137 820 : arguments = expression.arguments;
138 820 : target = expression.target;
139 : variableIndices = expression.variableIndices;
140 : variableNames = expression.variableNames;
141 820 : workspace.resize(expression.workspace.size());
142 820 : argValues.resize(expression.argValues.size());
143 820 : operation.resize(expression.operation.size());
144 11574 : for (int i = 0; i < (int) operation.size(); i++)
145 9934 : operation[i] = expression.operation[i]->clone();
146 820 : setVariableLocations(variablePointers);
147 820 : return *this;
148 : }
149 :
150 7552 : void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
151 7552 : if (findTempIndex(node, temps) != -1)
152 1663 : return; // We have already processed a node identical to this one.
153 :
154 : // Process the child nodes.
155 :
156 : vector<int> args;
157 25242 : for (int i = 0; i < node.getChildren().size(); i++) {
158 13464 : compileExpression(node.getChildren()[i], temps);
159 20196 : args.push_back(findTempIndex(node.getChildren()[i], temps));
160 : }
161 :
162 : // Process this node.
163 :
164 5889 : if (node.getOperation().getId() == Operation::VARIABLE) {
165 2766 : variableIndices[node.getOperation().getName()] = (int) workspace.size();
166 1844 : variableNames.insert(node.getOperation().getName());
167 : }
168 : else {
169 4967 : int stepIndex = (int) arguments.size();
170 9934 : arguments.push_back(vector<int>());
171 9934 : target.push_back((int) workspace.size());
172 9934 : operation.push_back(node.getOperation().clone());
173 4967 : if (args.size() == 0)
174 158 : arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there.
175 : else {
176 : // If the arguments are sequential, we can just pass a pointer to the first one.
177 :
178 : bool sequential = true;
179 8576 : for (int i = 1; i < args.size(); i++)
180 3688 : if (args[i] != args[i-1]+1)
181 : sequential = false;
182 4888 : if (sequential)
183 7596 : arguments[stepIndex].push_back(args[0]);
184 : else
185 2180 : arguments[stepIndex] = args;
186 : }
187 : }
188 11778 : temps.push_back(make_pair(node, (int) workspace.size()));
189 11778 : workspace.push_back(0.0);
190 : }
191 :
192 14284 : int CompiledExpression::findTempIndex(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
193 110034 : for (int i = 0; i < (int) temps.size(); i++)
194 98256 : if (temps[i].first == node)
195 : return i;
196 : return -1;
197 : }
198 :
199 348 : const set<string>& CompiledExpression::getVariables() const {
200 348 : return variableNames;
201 : }
202 :
203 1391 : double& CompiledExpression::getVariableReference(const string& name) {
204 : map<string, double*>::iterator pointer = variablePointers.find(name);
205 1391 : if (pointer != variablePointers.end())
206 0 : return *pointer->second;
207 : map<string, int>::iterator index = variableIndices.find(name);
208 1391 : if (index == variableIndices.end())
209 1876 : throw Exception("getVariableReference: Unknown variable '"+name+"'");
210 1844 : return workspace[index->second];
211 : }
212 :
213 820 : void CompiledExpression::setVariableLocations(map<string, double*>& variableLocations) {
214 : variablePointers = variableLocations;
215 820 : static const bool asmjit=useAsmJit();
216 820 : if(asmjit) {
217 : #ifdef __PLUMED_HAS_ASMJIT
218 : // Rebuild the JIT code.
219 :
220 : if (workspace.size() > 0)
221 : generateJitCode();
222 : #endif
223 : } else {
224 : // Make a list of all variables we will need to copy before evaluating the expression.
225 :
226 : variablesToCopy.clear();
227 2562 : for (map<string, int>::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) {
228 922 : map<string, double*>::iterator pointer = variablePointers.find(iter->first);
229 922 : if (pointer != variablePointers.end())
230 0 : variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
231 : }
232 : }
233 820 : }
234 :
235 16263599 : double CompiledExpression::evaluate() const {
236 16263599 : static const bool asmjit=useAsmJit();
237 : #ifdef __PLUMED_HAS_ASMJIT
238 : if(asmjit) return ((double (*)()) jitCode)();
239 : #endif
240 16474954 : for (int i = 0; i < variablesToCopy.size(); i++)
241 0 : *variablesToCopy[i].first = *variablesToCopy[i].second;
242 :
243 : // Loop over the operations and evaluate each one.
244 :
245 154793088 : for (int step = 0; step < operation.size(); step++) {
246 : const vector<int>& args = arguments[step];
247 69363684 : if (args.size() == 1)
248 187044087 : workspace[target[step]] = operation[step]->evaluate(&workspace[args[0]], dummyVariables);
249 : else {
250 35036101 : for (int i = 0; i < args.size(); i++)
251 28020446 : argValues[i] = workspace[args[i]];
252 21046965 : workspace[target[step]] = operation[step]->evaluate(&argValues[0], dummyVariables);
253 : }
254 : }
255 32540674 : return workspace[workspace.size()-1];
256 : }
257 :
258 : #ifdef __PLUMED_HAS_ASMJIT
259 : static double evaluateOperation(Operation* op, double* args) {
260 : static map<string, double> dummyVariables;
261 : return op->evaluate(args, dummyVariables);
262 : }
263 :
264 : static void generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double));
265 :
266 : void CompiledExpression::generateJitCode() {
267 : CodeHolder code;
268 : auto & runtime(*static_cast<asmjit::JitRuntime*>(runtimeptr.get()));
269 : code.init(runtime.getCodeInfo());
270 : X86Assembler a(&code);
271 : X86Compiler c(&code);
272 : c.addFunc(FuncSignature0<double>());
273 : vector<X86Xmm> workspaceVar(workspace.size());
274 : for (int i = 0; i < (int) workspaceVar.size(); i++)
275 : workspaceVar[i] = c.newXmmSd();
276 : X86Gp argsPointer = c.newIntPtr();
277 : c.mov(argsPointer, imm_ptr(&argValues[0]));
278 :
279 : // Load the arguments into variables.
280 :
281 : for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
282 : map<string, int>::iterator index = variableIndices.find(*iter);
283 : X86Gp variablePointer = c.newIntPtr();
284 : c.mov(variablePointer, imm_ptr(&getVariableReference(index->first)));
285 : c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
286 : }
287 :
288 : // Make a list of all constants that will be needed for evaluation.
289 :
290 : vector<int> operationConstantIndex(operation.size(), -1);
291 : for (int step = 0; step < (int) operation.size(); step++) {
292 : // Find the constant value (if any) used by this operation.
293 :
294 : Operation& op = *operation[step];
295 : double value;
296 : if (op.getId() == Operation::CONSTANT)
297 : value = dynamic_cast<Operation::Constant&>(op).getValue();
298 : else if (op.getId() == Operation::ADD_CONSTANT)
299 : value = dynamic_cast<Operation::AddConstant&>(op).getValue();
300 : else if (op.getId() == Operation::MULTIPLY_CONSTANT)
301 : value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue();
302 : else if (op.getId() == Operation::RECIPROCAL)
303 : value = 1.0;
304 : else if (op.getId() == Operation::STEP)
305 : value = 1.0;
306 : else if (op.getId() == Operation::DELTA)
307 : value = 1.0/0.0;
308 : else
309 : continue;
310 :
311 : // See if we already have a variable for this constant.
312 :
313 : for (int i = 0; i < (int) constants.size(); i++)
314 : if (value == constants[i]) {
315 : operationConstantIndex[step] = i;
316 : break;
317 : }
318 : if (operationConstantIndex[step] == -1) {
319 : operationConstantIndex[step] = constants.size();
320 : constants.push_back(value);
321 : }
322 : }
323 :
324 : // Load constants into variables.
325 :
326 : vector<X86Xmm> constantVar(constants.size());
327 : if (constants.size() > 0) {
328 : X86Gp constantsPointer = c.newIntPtr();
329 : c.mov(constantsPointer, imm_ptr(&constants[0]));
330 : for (int i = 0; i < (int) constants.size(); i++) {
331 : constantVar[i] = c.newXmmSd();
332 : c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
333 : }
334 : }
335 :
336 : // Evaluate the operations.
337 :
338 : for (int step = 0; step < (int) operation.size(); step++) {
339 : Operation& op = *operation[step];
340 : vector<int> args = arguments[step];
341 : if (args.size() == 1) {
342 : // One or more sequential arguments. Fill out the list.
343 :
344 : for (int i = 1; i < op.getNumArguments(); i++)
345 : args.push_back(args[0]+i);
346 : }
347 :
348 : // Generate instructions to execute this operation.
349 :
350 : switch (op.getId()) {
351 : case Operation::CONSTANT:
352 : c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
353 : break;
354 : case Operation::ADD:
355 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
356 : c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
357 : break;
358 : case Operation::SUBTRACT:
359 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
360 : c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
361 : break;
362 : case Operation::MULTIPLY:
363 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
364 : c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
365 : break;
366 : case Operation::DIVIDE:
367 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
368 : c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
369 : break;
370 : case Operation::NEGATE:
371 : c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
372 : c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
373 : break;
374 : case Operation::SQRT:
375 : c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
376 : break;
377 : case Operation::EXP:
378 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
379 : break;
380 : case Operation::LOG:
381 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log);
382 : break;
383 : case Operation::SIN:
384 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin);
385 : break;
386 : case Operation::COS:
387 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos);
388 : break;
389 : case Operation::TAN:
390 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan);
391 : break;
392 : case Operation::ASIN:
393 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin);
394 : break;
395 : case Operation::ACOS:
396 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos);
397 : break;
398 : case Operation::ATAN:
399 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
400 : break;
401 : case Operation::SINH:
402 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
403 : break;
404 : case Operation::COSH:
405 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh);
406 : break;
407 : case Operation::TANH:
408 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
409 : break;
410 : case Operation::STEP:
411 : c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
412 : c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
413 : c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
414 : break;
415 : case Operation::DELTA:
416 : c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
417 : c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
418 : c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
419 : break;
420 : case Operation::SQUARE:
421 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
422 : c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
423 : break;
424 : case Operation::CUBE:
425 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
426 : c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
427 : c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
428 : break;
429 : case Operation::RECIPROCAL:
430 : c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
431 : c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
432 : break;
433 : case Operation::ADD_CONSTANT:
434 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
435 : c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
436 : break;
437 : case Operation::MULTIPLY_CONSTANT:
438 : c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
439 : c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
440 : break;
441 : case Operation::ABS:
442 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
443 : break;
444 : case Operation::FLOOR:
445 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
446 : break;
447 : case Operation::CEIL:
448 : generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
449 : break;
450 : default:
451 : // Just invoke evaluateOperation().
452 :
453 : for (int i = 0; i < (int) args.size(); i++)
454 : c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
455 : X86Gp fn = c.newIntPtr();
456 : c.mov(fn, imm_ptr((void*) evaluateOperation));
457 : CCFuncCall* call = c.call(fn, FuncSignature2<double, Operation*, double*>(CallConv::kIdHost));
458 : call->setArg(0, imm_ptr(&op));
459 : call->setArg(1, imm_ptr(&argValues[0]));
460 : call->setRet(0, workspaceVar[target[step]]);
461 : }
462 : }
463 : c.ret(workspaceVar[workspace.size()-1]);
464 : c.endFunc();
465 : c.finalize();
466 : typedef double (*Func0)(void);
467 : Func0 func0;
468 : Error err = runtime.add(&func0,&code);
469 : if(err) return;
470 : jitCode = (void*) func0;
471 : }
472 :
473 : void generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) {
474 : X86Gp fn = c.newIntPtr();
475 : c.mov(fn, imm_ptr((void*) function));
476 : CCFuncCall* call = c.call(fn, FuncSignature1<double, double>(CallConv::kIdHost));
477 : call->setArg(0, arg);
478 : call->setRet(0, dest);
479 : }
480 : #endif
481 : }
|