Loading [MathJax]/extensions/tex2jax.js
libTriton version 1.0 build 1599
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Modules Pages
tritonToLLVM.cpp
Go to the documentation of this file.
1
2/*
3** Copyright (C) - Triton
4**
5** This program is under the terms of the Apache License 2.0.
6*/
7
8#include <algorithm>
9#include <map>
10#include <vector>
11
12#include <triton/astEnums.hpp>
13#include <triton/exceptions.hpp>
18
19#include <llvm/IR/PassManager.h>
20#include <llvm/Analysis/LoopAnalysisManager.h>
21#include <llvm/Analysis/CGSCCPassManager.h>
22#include <llvm/Passes/PassBuilder.h>
23
24
25namespace triton {
26 namespace ast {
27
28 TritonToLLVM::TritonToLLVM(llvm::LLVMContext& llvmContext)
29 : llvmContext(llvmContext), llvmIR(this->llvmContext) {
30 this->llvmModule = std::make_shared<llvm::Module>("tritonModule", this->llvmContext);
31 if (llvmModule == nullptr) {
32 triton::exceptions::LiftingEngine("TritonToLLVM::TritonToLLVM: Failed to allocate the LLVM Module");
33 }
34 }
35
36
/**
 * Declares the LLVM function that will hold the lifted expression and positions
 * the IR builder at the start of its (single) "entry" basic block.
 *
 * - The return type is an integer as wide as node's bitvector size.
 * - There is one parameter per symbolic variable, sorted by the SymbolicVariable
 *   ordering so parameter order does not depend on traversal order. Each parameter
 *   is named after its variable and recorded in llvmVars for later lookup by
 *   do_convert().
 *
 * @param node  Root of the expression to lift.
 * @param fname Name given to the LLVM function, created with external linkage in llvmModule.
 * @throws triton::exceptions::AstLifting if a symbolic variable is not 8, 16, 32 or 64 bits wide.
 */
37 void TritonToLLVM::createFunction(const triton::ast::SharedAbstractNode& node, const char* fname) {
38 // Collect used symbolic variables.
// NOTE(review): the line defining `vars` (original line 39) is missing from this rendering.
// Presumably `auto vars = triton::ast::search(node, triton::ast::VARIABLE_NODE);` (search() is the
// only collector referenced on this page) — confirm against the real source.
40
// Sort the variables so the function signature is stable across runs.
// NOTE(review): VariableNode presumably derives from AbstractNode, so static_cast would be the safer cast here.
42 std::sort(vars.begin(), vars.end(), [](const triton::ast::SharedAbstractNode& a, const triton::ast::SharedAbstractNode& b) {
43 auto varA = reinterpret_cast<triton::ast::VariableNode*>(a.get())->getSymbolicVariable();
44 auto varB = reinterpret_cast<triton::ast::VariableNode*>(b.get())->getSymbolicVariable();
45 return *varA < *varB;
46 });
47
48 // Each symbolic variable is a function argument
// Only byte-aligned widths up to 64 bits are accepted for parameters, although the return
// type below is built with llvm::IntegerType::get and may have any width.
49 std::vector<llvm::Type*> argsType;
50 argsType.resize(vars.size());
51 for (triton::usize index = 0 ; index < vars.size() ; index++) {
52 switch (vars[index]->getBitvectorSize()) {
53 case 8:
54 argsType[index] = llvm::Type::getInt8Ty(this->llvmContext);
55 break;
56 case 16:
57 argsType[index] = llvm::Type::getInt16Ty(this->llvmContext);
58 break;
59 case 32:
60 argsType[index] = llvm::Type::getInt32Ty(this->llvmContext);
61 break;
62 case 64:
63 argsType[index] = llvm::Type::getInt64Ty(this->llvmContext);
64 break;
65 default:
// NOTE(review): the message names do_convert() although it is raised from createFunction().
66 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): Symbolic variables must be aligned on 8, 16, 32 or 64 bit.");
67 }
68 }
69
70 /* Declare LLVM function */
71 auto retSize = node->getBitvectorSize();
72 auto* retType = llvm::IntegerType::get(this->llvmContext, retSize);
73 auto* funcType = llvm::FunctionType::get(retType, argsType, false /* isVarArg */);
74 auto* llvmFunc = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, fname, this->llvmModule.get());
75
76 /* Rename parameters */
// Parameters are walked in the same order as `vars`, which matches argsType above.
// Note: the loop variable `node` shadows the function parameter `node` inside this loop.
77 llvm::Function::arg_iterator params = llvmFunc->arg_begin();
78 for (const auto& node : vars) {
79 auto var = reinterpret_cast<triton::ast::VariableNode*>(node.get())->getSymbolicVariable();
80 auto* param = params++;
81 param->setName(var->getName());
// Remember which LLVM argument stands for this variable node (read back by do_convert()).
82 this->llvmVars[node] = param;
83 }
84
85 // A Triton expression is represented as one basic block
86 auto* llvmBasicBlock = llvm::BasicBlock::Create(this->llvmContext, "entry", llvmFunc);
// Subsequent llvmIR.Create* calls append instructions to this block.
87 this->llvmIR.SetInsertPoint(llvmBasicBlock);
88 }
89
90
91 std::shared_ptr<llvm::Module> TritonToLLVM::convert(const triton::ast::SharedAbstractNode& node, const char* fname, bool optimize) {
92 std::unordered_map<triton::ast::SharedAbstractNode, llvm::Value*> results;
93
94 /* Create the LLVM function */
95 this->createFunction(node, fname);
96
97 /* Lift Triton AST to LLVM IR */
98 auto nodes = triton::ast::childrenExtraction(node, true /* unroll*/, true /* revert */);
99 for (const auto& node : nodes) {
100 if (node->getBitvectorSize()) {
101 results.insert(std::make_pair(node, this->do_convert(node, &results)));
102 }
103 }
104
105 /* Create the return instruction */
106 this->llvmIR.CreateRet(results.at(node));
107
108 /* Apply LLVM optimizations (-03 -Oz) if enabled */
109 if (optimize) {
110 llvm::LoopAnalysisManager lam;
111 llvm::FunctionAnalysisManager fam;
112 llvm::CGSCCAnalysisManager cgam;
113 llvm::ModuleAnalysisManager mam;
114
115 llvm::PassBuilder pb;
116
117 pb.registerModuleAnalyses(mam);
118 pb.registerCGSCCAnalyses(cgam);
119 pb.registerFunctionAnalyses(fam);
120 pb.registerLoopAnalyses(lam);
121 pb.crossRegisterProxies(lam, fam, cgam, mam);
122
123 llvm::ModulePassManager pm = pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
124 pm.run(*this->llvmModule, mam);
125 }
126
127 return this->llvmModule;
128 }
129
130
/**
 * Lowers a single AST node to an LLVM value using the IR builder.
 *
 * Every child with a non-zero bitvector size must already have been lowered and
 * stored in `results`; otherwise results->at() throws std::out_of_range. Size-0
 * children (e.g. INTEGER_NODE) are represented by a nullptr placeholder, so that
 * children[i] stays aligned with node->getChildren()[i].
 *
 * @param node    Node to lower; must not be null.
 * @param results Values already produced for previously lowered nodes.
 * @return The LLVM value computing `node`.
 * @throws triton::exceptions::AstLifting on a null node, an unsupported bswap
 *         width, or an unsupported node kind.
 *
 * NOTE(review): this rendering of the file lost the `case XXX_NODE:` labels of the
 * switch below (they were hyperlinks in the generated documentation). The
 * per-branch comments name the presumable node kind, inferred from the emitted
 * operation and the alphabetical order of the branches. Confirm against the real source.
 *
 * NOTE(review): several lowerings differ from SMT-LIB bit-vector semantics
 * (presumably the semantics Triton's AST follows):
 *  - LLVM shl/lshr/ashr by an amount >= the bit width yield poison;
 *  - LLVM udiv/sdiv/urem/srem by zero are undefined behavior.
 */
131 llvm::Value* TritonToLLVM::do_convert(const triton::ast::SharedAbstractNode& node, std::unordered_map<triton::ast::SharedAbstractNode, llvm::Value*>* results) {
132 if (node == nullptr)
133 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): node cannot be null.");
134
135 /* Prepare llvm's children */
136 std::vector<llvm::Value*> children;
137 for (auto&& n : node->getChildren()) {
138 /* Ignore children like INTEGER_NODE */
139 if (n->getBitvectorSize() == 0) {
140 children.emplace_back(nullptr);
141 }
142 else {
143 children.emplace_back(results->at(n));
144 }
145 }
146
147 switch (node->getType()) {
148
// presumably BSWAP_NODE: byte swap through the llvm.bswap intrinsic.
// NOTE(review): LLVM defines llvm.bswap only for widths that are a multiple of 16 bits, so the
// 8-bit declaration below yields invalid IR; a byte swap of a single byte is the identity.
150 llvm::Function* bswap = nullptr;
151 switch (node->getBitvectorSize()) {
152 case triton::bitsize::byte: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt8Ty(this->llvmContext)); break;
153 case triton::bitsize::word: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt16Ty(this->llvmContext)); break;
154 case triton::bitsize::dword: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt32Ty(this->llvmContext)); break;
155 case triton::bitsize::qword: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt64Ty(this->llvmContext)); break;
156 default:
157 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): Invalid bswap size.");
158 }
159 return this->llvmIR.CreateCall(bswap, children[0]);
160 }
161
// presumably BVADD_NODE: wrapping addition.
163 return this->llvmIR.CreateAdd(children[0], children[1]);
164
// presumably BVAND_NODE: bitwise and.
166 return this->llvmIR.CreateAnd(children[0], children[1]);
167
// presumably BVASHR_NODE: arithmetic shift right (poison if the amount >= width).
169 return this->llvmIR.CreateAShr(children[0], children[1]);
170
// presumably BVLSHR_NODE: logical shift right (poison if the amount >= width).
172 return this->llvmIR.CreateLShr(children[0], children[1]);
173
// presumably BVMUL_NODE: wrapping multiplication.
175 return this->llvmIR.CreateMul(children[0], children[1]);
176
// presumably BVNAND_NODE: not(a & b).
178 return this->llvmIR.CreateNot(this->llvmIR.CreateAnd(children[0], children[1]));
179
// presumably BVNEG_NODE: two's-complement negation.
181 return this->llvmIR.CreateNeg(children[0]);
182
// presumably BVNOR_NODE: not(a | b).
184 return this->llvmIR.CreateNot(this->llvmIR.CreateOr(children[0], children[1]));
185
// presumably BVNOT_NODE: bitwise complement.
187 return this->llvmIR.CreateNot(children[0]);
188
// presumably BVOR_NODE: bitwise or.
190 return this->llvmIR.CreateOr(children[0], children[1]);
191
// presumably BVROL_NODE: the rotation amount is a constant read from the (size-0) second child.
// NOTE(review): when rot % size == 0 the right shift below is by exactly `size`, which yields
// poison in LLVM instead of returning the input unchanged; the llvm.fshl intrinsic avoids this.
192 // bvrol(expr, rot) = ((expr << (rot % size)) | (expr >> (size - (rot % size))))
194 auto rot = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
195 auto size = node->getBitvectorSize();
196 return this->llvmIR.CreateOr(this->llvmIR.CreateShl(children[0], rot % size), this->llvmIR.CreateLShr(children[0], (size - (rot % size))));
197 }
198
// presumably BVROR_NODE: same structure as the rotate-left branch, with the same poison
// issue when rot % size == 0 (llvm.fshr would avoid it).
199 // bvror(expr, rot) = ((expr >> (rot % size)) | (expr << (size - (rot % size))))
201 auto rot = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
202 auto size = node->getBitvectorSize();
203 return this->llvmIR.CreateOr(this->llvmIR.CreateLShr(children[0], rot % size), this->llvmIR.CreateShl(children[0], (size - (rot % size))));
204 }
205
// presumably BVSDIV_NODE: signed division (UB in LLVM for a zero divisor or INT_MIN / -1).
207 return this->llvmIR.CreateSDiv(children[0], children[1]);
208
// presumably BVSGE_NODE: signed >=, produces an i1.
210 return this->llvmIR.CreateICmpSGE(children[0], children[1]);
211
// presumably BVSGT_NODE: signed >, produces an i1.
213 return this->llvmIR.CreateICmpSGT(children[0], children[1]);
214
// presumably BVSHL_NODE: shift left (poison if the amount >= width).
216 return this->llvmIR.CreateShl(children[0], children[1]);
217
// presumably BVSLE_NODE: signed <=, produces an i1.
219 return this->llvmIR.CreateICmpSLE(children[0], children[1]);
220
// presumably BVSLT_NODE: signed <, produces an i1.
222 return this->llvmIR.CreateICmpSLT(children[0], children[1]);
223
// presumably BVSMOD_NODE: signed modulo whose sign follows the divisor, computed as
// srem(srem(a, b) + b, b).
// NOTE(review): the intermediate add can overflow when |b| is large (e.g. b == INT_MIN),
// which makes the result wrong for those inputs.
225 auto* LHS = children[0];
226 auto* RHS = children[1];
227 return this->llvmIR.CreateSRem(this->llvmIR.CreateAdd(this->llvmIR.CreateSRem(LHS, RHS), RHS), RHS);
228 }
229
// presumably BVSREM_NODE: signed remainder, sign follows the dividend.
231 return this->llvmIR.CreateSRem(children[0], children[1]);
232
// presumably BVSUB_NODE: wrapping subtraction.
234 return this->llvmIR.CreateSub(children[0], children[1]);
235
// presumably BVUDIV_NODE: unsigned division (UB in LLVM for a zero divisor).
237 return this->llvmIR.CreateUDiv(children[0], children[1]);
238
// presumably BVUGE_NODE: unsigned >=, produces an i1.
240 return this->llvmIR.CreateICmpUGE(children[0], children[1]);
241
// presumably BVUGT_NODE: unsigned >, produces an i1.
243 return this->llvmIR.CreateICmpUGT(children[0], children[1]);
244
// presumably BVULE_NODE: unsigned <=, produces an i1.
246 return this->llvmIR.CreateICmpULE(children[0], children[1]);
247
// presumably BVULT_NODE: unsigned <, produces an i1.
249 return this->llvmIR.CreateICmpULT(children[0], children[1]);
250
// presumably BVUREM_NODE: unsigned remainder (UB in LLVM for a zero divisor).
252 return this->llvmIR.CreateURem(children[0], children[1]);
253
// presumably BVXNOR_NODE: not(a ^ b).
255 return this->llvmIR.CreateNot(this->llvmIR.CreateXor(children[0], children[1]));
256
// presumably BVXOR_NODE: bitwise xor.
258 return this->llvmIR.CreateXor(children[0], children[1]);
259
// presumably BV_NODE: an integer constant of the node's width.
// NOTE(review): the value is cast to uint64_t first, so if evaluate() returns a wider integer,
// constants wider than 64 bits lose their upper bits — confirm.
261 return llvm::ConstantInt::get(this->llvmContext, llvm::APInt(node->getBitvectorSize(), static_cast<uint64_t>(node->evaluate()), false));
262
// presumably CONCAT_NODE: children[0] ends up in the most significant bits; each following
// child is zero-extended to the result width and ORed in after shifting the accumulator left
// by that child's width.
264 auto dstSize = node->getBitvectorSize();
265 auto finalNode = this->llvmIR.CreateZExt(children[0], llvm::IntegerType::get(this->llvmContext, dstSize));
266
267 for (triton::usize index = 1; index < children.size(); index++) {
268 finalNode = this->llvmIR.CreateShl(finalNode, node->getChildren()[index]->getBitvectorSize());
269 auto* n = this->llvmIR.CreateZExt(children[index], llvm::IntegerType::get(this->llvmContext, dstSize));
270 finalNode = this->llvmIR.CreateOr(finalNode, n);
271 }
272
273 return finalNode;
274 }
275
// presumably DISTINCT_NODE: a != b, produces an i1.
277 return this->llvmIR.CreateICmpNE(children[0], children[1]);
278
// presumably EQUAL_NODE: a == b, produces an i1.
280 return this->llvmIR.CreateICmpEQ(children[0], children[1]);
281
// presumably EXTRACT_NODE: child 1 holds the constant low bit index and child 2 the source
// value. The result is (value >> low) truncated to the node's width. Child 0 is not read.
// NOTE(review): `dstSize` below is never used.
283 auto low = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
284 auto dstSize = node->getChildren()[2]->getBitvectorSize();
285 auto* value = children[2];
286
287 if (low == 0) {
288 return this->llvmIR.CreateTrunc(value, llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
289 }
290
291 return this->llvmIR.CreateTrunc(this->llvmIR.CreateLShr(value, low), llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
292 }
293
// presumably ITE_NODE: children[0] ? children[1] : children[2].
295 return this->llvmIR.CreateSelect(children[0], children[1], children[2]);
296
// presumably LAND_NODE: AND of all (i1) children, compared against true.
298 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
299 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateAnd(children), truenode);
300 }
301
// presumably LNOT_NODE: logical negation of an i1.
303 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
304 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateNot(children[0]), truenode);
305 }
306
// presumably LOR_NODE: OR of all (i1) children, compared against true.
308 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
309 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateOr(children), truenode);
310 }
311
// presumably LXOR_NODE: XOR-fold of all children (parity), compared against true.
// Assumes at least two children.
313 auto* child0 = children[0];
314 auto* child1 = children[1];
315 auto* current = this->llvmIR.CreateXor(child0, child1);
316
317 for (triton::usize index = 2; index < children.size(); index++) {
318 current = this->llvmIR.CreateXor(current, children[index]);
319 }
320
321 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
322 return this->llvmIR.CreateICmpEQ(current, truenode);
323 }
324
// presumably REFERENCE_NODE: reuses the value already lifted for the referenced expression's AST.
326 return results->at(reinterpret_cast<triton::ast::ReferenceNode*>(node.get())->getSymbolicExpression()->getAst());
327
// presumably SELECT_NODE: loads one byte from the address given by children[1].
// children[0] (presumably the memory array) is not used.
329 auto* ptr = this->llvmIR.CreateIntToPtr(children[1], llvm::Type::getInt8Ty(this->llvmContext)->getPointerTo());
330 return this->llvmIR.CreateLoad(llvm::Type::getInt8Ty(this->llvmContext), ptr);
331 }
332
// presumably STORE_NODE: stores children[2] at the address given by children[1].
// NOTE(review): the returned value is the store instruction itself, whose LLVM type is void,
// so it cannot be used as an operand by a parent node.
334 auto* ptr = this->llvmIR.CreateIntToPtr(children[1], llvm::Type::getInt8Ty(this->llvmContext)->getPointerTo());
335 return this->llvmIR.CreateStore(children[2], ptr);
336 }
337
// presumably SX_NODE: children[0] is the (size-0) extension amount, children[1] the value
// sign-extended to the node's width.
339 return this->llvmIR.CreateSExt(children[1], llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
340
// presumably VARIABLE_NODE: the function argument recorded for this node by createFunction().
342 return this->llvmVars.at(node);
343
// presumably ZX_NODE: same layout as the sign-extension branch, but zero-extends.
345 return this->llvmIR.CreateZExt(children[1], llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
346
347 default:
348 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): Invalid kind of node.");
349 }
350 }
351
352 }; /* ast namespace */
353}; /* triton namespace */
Reference node.
Definition ast.hpp:789
TRITON_EXPORT std::shared_ptr< llvm::Module > convert(const triton::ast::SharedAbstractNode &node, const char *fname="__triton", bool optimize=false)
Lifts a symbolic expression and all its references to LLVM format. fname is the name given to the generated LLVM function; if optimize is true, LLVM's default O3 pipeline is run on the resulting module.
TRITON_EXPORT TritonToLLVM(llvm::LLVMContext &llvmContext)
Constructor.
Variable node.
Definition ast.hpp:878
The exception class used by all AST lifting (e.g., z3 <-> triton).
The exception class used by the lifting engine.
std::vector< SharedAbstractNode > childrenExtraction(const SharedAbstractNode &node, bool unroll, bool revert)
Returns the node and all of its children in an AST, sorted topologically. If unroll is true,...
Definition ast.cpp:3700
std::shared_ptr< triton::ast::AbstractNode > SharedAbstractNode
Shared Abstract Node.
Definition ast.hpp:59
std::deque< SharedAbstractNode > search(const SharedAbstractNode &node, triton::ast::ast_e match)
Returns a deque of the matching nodes, collected via a depth-first pre-order traversal.
Definition ast.cpp:3710
constexpr triton::uint32 byte
byte size in bits
Definition cpuSize.hpp:60
constexpr triton::uint32 dword
dword size in bits
Definition cpuSize.hpp:64
constexpr triton::uint32 qword
qword size in bits
Definition cpuSize.hpp:66
constexpr triton::uint32 word
word size in bits
Definition cpuSize.hpp:62
std::size_t usize
Largest native unsigned integer type: 32 or 64 bits, depending on the CPU.
The Triton namespace.