libTriton version 1.0 build 1592
Loading...
Searching...
No Matches
tritonToLLVM.cpp
Go to the documentation of this file.
1
2/*
3** Copyright (C) - Triton
4**
5** This program is under the terms of the Apache License 2.0.
6*/
7
8#include <algorithm>
9#include <list>
10#include <map>
11#include <vector>
12
13#include <triton/astEnums.hpp>
14#include <triton/exceptions.hpp>
19
20
21
22namespace triton {
23 namespace ast {
24
25 TritonToLLVM::TritonToLLVM(llvm::LLVMContext& llvmContext)
26 : llvmContext(llvmContext), llvmIR(this->llvmContext) {
27 this->llvmModule = std::make_shared<llvm::Module>("tritonModule", this->llvmContext);
28 if (llvmModule == nullptr) {
29 triton::exceptions::LiftingEngine("TritonToLLVM::TritonToLLVM: Failed to allocate the LLVM Module");
30 }
31 }
32
33
// Declares, inside this->llvmModule, the LLVM function named `fname` that will
// hold the lifted expression `node`:
//   - one integer parameter per entry of `vars` (8, 16, 32 or 64 bits wide),
//   - an integer return type as wide as `node`,
//   - a single "entry" basic block, set as the IR builder's insertion point.
// Each parameter is recorded in this->llvmVars, keyed by its variable node.
//
// Throws triton::exceptions::AstLifting when a variable has any other width.
//
// NOTE(review): the declaration of `vars` is not present in this listing (the
// listing numbering jumps 35 -> 37 -> 39); presumably it holds the symbolic
// variable nodes found in `node` -- confirm against the original file.
34 void TritonToLLVM::createFunction(const triton::ast::SharedAbstractNode& node, const char* fname) {
35 // Collect used symbolic variables.
37
// Sorting fixes the order in which the variables become function parameters.
39 std::sort(vars.begin(), vars.end());
40
41 // Each symbolic variable is a function argument
42 std::vector<llvm::Type*> argsType;
43 argsType.resize(vars.size());
44 for (triton::usize index = 0 ; index < vars.size() ; index++) {
45 switch (vars[index]->getBitvectorSize()) {
46 case 8:
47 argsType[index] = llvm::Type::getInt8Ty(this->llvmContext);
48 break;
49 case 16:
50 argsType[index] = llvm::Type::getInt16Ty(this->llvmContext);
51 break;
52 case 32:
53 argsType[index] = llvm::Type::getInt32Ty(this->llvmContext);
54 break;
55 case 64:
56 argsType[index] = llvm::Type::getInt64Ty(this->llvmContext);
57 break;
58 default:
// NOTE(review): the message names do_convert() although it is raised from createFunction().
59 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): Symbolic variables must be aligned on 8, 16, 32 or 64 bit.");
60 }
61 }
62
63 /* Declare LLVM function */
64 auto retSize = node->getBitvectorSize();
65 auto* retType = llvm::IntegerType::get(this->llvmContext, retSize);
66 auto* funcType = llvm::FunctionType::get(retType, argsType, false /* isVarArg */);
67 auto* llvmFunc = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, fname, this->llvmModule.get());
68
69 /* Rename parameters */
// Parameters are walked in the same order as `vars`, so the i-th argument
// takes the name of the i-th variable. The loop variable shadows the `node`
// parameter of this function.
70 llvm::Function::arg_iterator params = llvmFunc->arg_begin();
71 for (const auto& node : vars) {
72 auto var = reinterpret_cast<triton::ast::VariableNode*>(node.get())->getSymbolicVariable();
73 auto* param = params++;
74 param->setName(var->getName());
75 this->llvmVars[node] = param;
76 }
77
78 // A Triton expression is represented as one basic block
79 auto* llvmBasicBlock = llvm::BasicBlock::Create(this->llvmContext, "entry", llvmFunc);
80 this->llvmIR.SetInsertPoint(llvmBasicBlock);
81 }
82
83
84 std::shared_ptr<llvm::Module> TritonToLLVM::convert(const triton::ast::SharedAbstractNode& node, const char* fname, bool optimize) {
85 std::unordered_map<triton::ast::SharedAbstractNode, llvm::Value*> results;
86
87 /* Create the LLVM function */
88 this->createFunction(node, fname);
89
90 /* Lift Triton AST to LLVM IR */
91 auto nodes = triton::ast::childrenExtraction(node, true /* unroll*/, true /* revert */);
92 for (const auto& node : nodes) {
93 if (node->getBitvectorSize()) {
94 results.insert(std::make_pair(node, this->do_convert(node, &results)));
95 }
96 }
97
98 /* Create the return instruction */
99 this->llvmIR.CreateRet(results.at(node));
100
101 /* Apply LLVM optimizations (-03 -Oz) if enabled */
102 if (optimize) {
103 llvm::legacy::PassManager pm;
104 llvm::PassManagerBuilder pmb;
105 pmb.OptLevel = 3;
106 pmb.SizeLevel = 2;
107 pmb.populateModulePassManager(pm);
108 pm.run(*this->llvmModule);
109 }
110
111 return this->llvmModule;
112 }
113
114
// Translates a single Triton AST node into an LLVM value, emitting any needed
// instruction through this->llvmIR.
//
//   node    : the node to translate; must not be null.
//   results : LLVM values already produced for other nodes. Every child of
//             `node` with a non-zero bitvector size is looked up here with
//             at(), so it must have been translated beforehand.
//
// Returns the LLVM value for `node`.
// Throws triton::exceptions::AstLifting for a null node, an unsupported bswap
// width or an unhandled node type.
//
// NOTE(review): the `case` labels of the switch below are not present in this
// listing (the listing numbering skips one line before each branch). The
// comments added here describe each branch from the instruction it emits.
115 llvm::Value* TritonToLLVM::do_convert(const triton::ast::SharedAbstractNode& node, std::unordered_map<triton::ast::SharedAbstractNode, llvm::Value*>* results) {
116 if (node == nullptr)
117 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): node cannot be null.");
118
119 /* Prepare llvm's children */
// children[i] mirrors node->getChildren()[i]; zero-sized children get a
// nullptr placeholder so the indices stay aligned.
120 std::vector<llvm::Value*> children;
121 for (auto&& n : node->getChildren()) {
122 /* Ignore children like INTEGER_NODE */
123 if (n->getBitvectorSize() == 0) {
124 children.emplace_back(nullptr);
125 }
126 else {
127 children.emplace_back(results->at(n));
128 }
129 }
130
131 switch (node->getType()) {
132
// Byte swap: picks the llvm.bswap declaration matching the node width and calls it on child 0.
// NOTE(review): confirm that llvm.bswap accepts an 8-bit operand (the byte case below).
134 llvm::Function* bswap = nullptr;
135 switch (node->getBitvectorSize()) {
136 case triton::bitsize::byte: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt8Ty(this->llvmContext)); break;
137 case triton::bitsize::word: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt16Ty(this->llvmContext)); break;
138 case triton::bitsize::dword: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt32Ty(this->llvmContext)); break;
139 case triton::bitsize::qword: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt64Ty(this->llvmContext)); break;
140 default:
141 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): Invalid bswap size.");
142 }
143 return this->llvmIR.CreateCall(bswap, children[0]);
144 }
145
// Binary and unary bitvector operators: each branch maps directly onto one IR
// builder call; nand, nor and xnor are built as Not(And), Not(Or) and Not(Xor).
147 return this->llvmIR.CreateAdd(children[0], children[1]);
148
150 return this->llvmIR.CreateAnd(children[0], children[1]);
151
153 return this->llvmIR.CreateAShr(children[0], children[1]);
154
156 return this->llvmIR.CreateLShr(children[0], children[1]);
157
159 return this->llvmIR.CreateMul(children[0], children[1]);
160
162 return this->llvmIR.CreateNot(this->llvmIR.CreateAnd(children[0], children[1]));
163
165 return this->llvmIR.CreateNeg(children[0]);
166
168 return this->llvmIR.CreateNot(this->llvmIR.CreateOr(children[0], children[1]));
169
171 return this->llvmIR.CreateNot(children[0]);
172
174 return this->llvmIR.CreateOr(children[0], children[1]);
175
// Rotations: the rotation count is read as a constant from the second Triton
// child, so the shift amounts are computed here rather than in the emitted IR.
// NOTE(review): when rot % size == 0 the second shift amount equals `size`, the
// full bit width -- confirm this is well defined for the LLVM shift emitted.
176 // bvrol(expr, rot) = ((expr << (rot % size)) | (expr >> (size - (rot % size))))
178 auto rot = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
179 auto size = node->getBitvectorSize();
180 return this->llvmIR.CreateOr(this->llvmIR.CreateShl(children[0], rot % size), this->llvmIR.CreateLShr(children[0], (size - (rot % size))));
181 }
182
183 // bvror(expr, rot) = ((expr >> (rot % size)) | (expr << (size - (rot % size))))
185 auto rot = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
186 auto size = node->getBitvectorSize();
187 return this->llvmIR.CreateOr(this->llvmIR.CreateLShr(children[0], rot % size), this->llvmIR.CreateShl(children[0], (size - (rot % size))));
188 }
189
// Signed division, signed comparisons and left shift.
191 return this->llvmIR.CreateSDiv(children[0], children[1]);
192
194 return this->llvmIR.CreateICmpSGE(children[0], children[1]);
195
197 return this->llvmIR.CreateICmpSGT(children[0], children[1]);
198
200 return this->llvmIR.CreateShl(children[0], children[1]);
201
203 return this->llvmIR.CreateICmpSLE(children[0], children[1]);
204
206 return this->llvmIR.CreateICmpSLT(children[0], children[1]);
207
// Emits ((LHS srem RHS) + RHS) srem RHS.
209 auto* LHS = children[0];
210 auto* RHS = children[1];
211 return this->llvmIR.CreateSRem(this->llvmIR.CreateAdd(this->llvmIR.CreateSRem(LHS, RHS), RHS), RHS);
212 }
213
// Signed remainder, subtraction, unsigned division, unsigned comparisons,
// unsigned remainder, xnor and xor.
215 return this->llvmIR.CreateSRem(children[0], children[1]);
216
218 return this->llvmIR.CreateSub(children[0], children[1]);
219
221 return this->llvmIR.CreateUDiv(children[0], children[1]);
222
224 return this->llvmIR.CreateICmpUGE(children[0], children[1]);
225
227 return this->llvmIR.CreateICmpUGT(children[0], children[1]);
228
230 return this->llvmIR.CreateICmpULE(children[0], children[1]);
231
233 return this->llvmIR.CreateICmpULT(children[0], children[1]);
234
236 return this->llvmIR.CreateURem(children[0], children[1]);
237
239 return this->llvmIR.CreateNot(this->llvmIR.CreateXor(children[0], children[1]));
240
242 return this->llvmIR.CreateXor(children[0], children[1]);
243
// Constant: built from the node's evaluated value at the node's width.
// NOTE(review): the value is cast to uint64_t first -- confirm the behaviour
// for nodes wider than 64 bits.
245 return llvm::ConstantInt::get(this->llvmContext, llvm::APInt(node->getBitvectorSize(), static_cast<uint64_t>(node->evaluate()), false));
246
// Concatenation: zero-extends child 0 to the result width, then for each later
// child shifts the accumulator left by that child's width and ORs the
// zero-extended child into the freed low bits.
248 auto dstSize = node->getBitvectorSize();
249 auto finalNode = this->llvmIR.CreateZExt(children[0], llvm::IntegerType::get(this->llvmContext, dstSize));
250
251 for (triton::usize index = 1; index < children.size(); index++) {
252 finalNode = this->llvmIR.CreateShl(finalNode, node->getChildren()[index]->getBitvectorSize());
253 auto* n = this->llvmIR.CreateZExt(children[index], llvm::IntegerType::get(this->llvmContext, dstSize));
254 finalNode = this->llvmIR.CreateOr(finalNode, n);
255 }
256
257 return finalNode;
258 }
259
// Inequality, then equality.
261 return this->llvmIR.CreateICmpNE(children[0], children[1]);
262
264 return this->llvmIR.CreateICmpEQ(children[0], children[1]);
265
// Extraction: child 1 gives the low bit index and child 2 the source value.
// The value is shifted right by `low` (skipped when low is 0) and truncated
// to the node width.
// NOTE(review): `dstSize` is computed but not used in this branch.
267 auto low = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
268 auto dstSize = node->getChildren()[2]->getBitvectorSize();
269 auto* value = children[2];
270
271 if (low == 0) {
272 return this->llvmIR.CreateTrunc(value, llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
273 }
274
275 return this->llvmIR.CreateTrunc(this->llvmIR.CreateLShr(value, low), llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
276 }
277
// Conditional: child 0 selects between child 1 and child 2.
279 return this->llvmIR.CreateSelect(children[0], children[1], children[2]);
280
// Logical connectives: the operands are combined (and / not / or / xor) and
// the result is compared for equality with a 1-bit constant of value 1.
282 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
283 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateAnd(children), truenode);
284 }
285
287 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
288 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateNot(children[0]), truenode);
289 }
290
292 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
293 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateOr(children), truenode);
294 }
295
297 auto* child0 = children[0];
298 auto* child1 = children[1];
299 auto* current = this->llvmIR.CreateXor(child0, child1);
300
301 for (triton::usize index = 2; index < children.size(); index++) {
302 current = this->llvmIR.CreateXor(current, children[index]);
303 }
304
305 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
306 return this->llvmIR.CreateICmpEQ(current, truenode);
307 }
308
// Reference: reuses the value already computed for the AST of the referenced
// symbolic expression.
310 return results->at(reinterpret_cast<triton::ast::ReferenceNode*>(node.get())->getSymbolicExpression()->getAst());
311
// Memory read: child 1 is converted to an i8 pointer and one i8 is loaded.
// NOTE(review): the load is always 8 bits wide, whatever the node's size -- confirm.
313 auto* ptr = this->llvmIR.CreateIntToPtr(children[1], llvm::Type::getInt8PtrTy(this->llvmContext));
314 return this->llvmIR.CreateLoad(llvm::Type::getInt8Ty(this->llvmContext), ptr);
315 }
316
// Memory write: child 2 is stored at the address given by child 1.
318 auto* ptr = this->llvmIR.CreateIntToPtr(children[1], llvm::Type::getInt8PtrTy(this->llvmContext));
319 return this->llvmIR.CreateStore(children[2], ptr);
320 }
321
// Sign extension of child 1 to the node width.
323 return this->llvmIR.CreateSExt(children[1], llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
324
// Variable: returns the value stored in this->llvmVars for this node
// (createFunction stores the function parameters there).
326 return this->llvmVars.at(node);
327
// Zero extension of child 1 to the node width.
329 return this->llvmIR.CreateZExt(children[1], llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
330
331 default:
332 throw triton::exceptions::AstLifting("TritonToLLVM::do_convert(): Invalid kind of node.");
333 }
334 }
335
336 }; /* ast namespace */
337}; /* triton namespace */
Reference node.
Definition ast.hpp:789
TRITON_EXPORT std::shared_ptr< llvm::Module > convert(const triton::ast::SharedAbstractNode &node, const char *fname="__triton", bool optimize=false)
Lifts a symbolic expression and all its references to LLVM format. fname represents the name of the LLVM function.
TRITON_EXPORT TritonToLLVM(llvm::LLVMContext &llvmContext)
Constructor.
Variable node.
Definition ast.hpp:878
The exception class used by all AST lifting (e.g., z3 <-> triton).
The exception class used by the lifting engine.
std::vector< SharedAbstractNode > childrenExtraction(const SharedAbstractNode &node, bool unroll, bool revert)
Returns node and all its children of an AST sorted topologically. If unroll is true,...
Definition ast.cpp:3700
std::shared_ptr< triton::ast::AbstractNode > SharedAbstractNode
Shared Abstract Node.
Definition ast.hpp:59
std::deque< SharedAbstractNode > search(const SharedAbstractNode &node, triton::ast::ast_e match)
Returns a deque of collected matched nodes via a depth-first pre order traversal.
Definition ast.cpp:3710
constexpr triton::uint32 byte
byte size in bit
Definition cpuSize.hpp:60
constexpr triton::uint32 dword
dword size in bit
Definition cpuSize.hpp:64
constexpr triton::uint32 qword
qword size in bit
Definition cpuSize.hpp:66
constexpr triton::uint32 word
word size in bit
Definition cpuSize.hpp:62
std::size_t usize
Unsigned integer type whose width (32 or 64 bits) depends on the CPU.
The Triton namespace.