19#include <llvm/IR/PassManager.h>
20#include <llvm/Analysis/LoopAnalysisManager.h>
21#include <llvm/Analysis/CGSCCPassManager.h>
22#include <llvm/Passes/PassBuilder.h>
29 : llvmContext(llvmContext), llvmIR(this->llvmContext) {
30 this->llvmModule = std::make_shared<llvm::Module>(
"tritonModule", this->llvmContext);
31 if (llvmModule ==
nullptr) {
43 auto varA = reinterpret_cast<triton::ast::VariableNode*>(a.get())->getSymbolicVariable();
44 auto varB = reinterpret_cast<triton::ast::VariableNode*>(b.get())->getSymbolicVariable();
49 std::vector<llvm::Type*> argsType;
50 argsType.resize(vars.size());
51 for (
triton::usize index = 0 ; index < vars.size() ; index++) {
52 switch (vars[index]->getBitvectorSize()) {
54 argsType[index] = llvm::Type::getInt8Ty(this->llvmContext);
57 argsType[index] = llvm::Type::getInt16Ty(this->llvmContext);
60 argsType[index] = llvm::Type::getInt32Ty(this->llvmContext);
63 argsType[index] = llvm::Type::getInt64Ty(this->llvmContext);
71 auto retSize = node->getBitvectorSize();
72 auto* retType = llvm::IntegerType::get(this->llvmContext, retSize);
73 auto* funcType = llvm::FunctionType::get(retType, argsType,
false );
74 auto* llvmFunc = llvm::Function::Create(funcType, llvm::Function::ExternalLinkage, fname, this->llvmModule.get());
77 llvm::Function::arg_iterator params = llvmFunc->arg_begin();
78 for (
const auto& node : vars) {
80 auto* param = params++;
81 param->setName(var->getName());
82 this->llvmVars[node] = param;
86 auto* llvmBasicBlock = llvm::BasicBlock::Create(this->llvmContext,
"entry", llvmFunc);
87 this->llvmIR.SetInsertPoint(llvmBasicBlock);
92 std::unordered_map<triton::ast::SharedAbstractNode, llvm::Value*> results;
95 this->createFunction(node, fname);
99 for (
const auto& node : nodes) {
100 if (node->getBitvectorSize()) {
101 results.insert(std::make_pair(node, this->do_convert(node, &results)));
106 this->llvmIR.CreateRet(results.at(node));
110 llvm::LoopAnalysisManager lam;
111 llvm::FunctionAnalysisManager fam;
112 llvm::CGSCCAnalysisManager cgam;
113 llvm::ModuleAnalysisManager mam;
115 llvm::PassBuilder pb;
117 pb.registerModuleAnalyses(mam);
118 pb.registerCGSCCAnalyses(cgam);
119 pb.registerFunctionAnalyses(fam);
120 pb.registerLoopAnalyses(lam);
121 pb.crossRegisterProxies(lam, fam, cgam, mam);
123 llvm::ModulePassManager pm = pb.buildPerModuleDefaultPipeline(llvm::OptimizationLevel::O3);
124 pm.run(*this->llvmModule, mam);
127 return this->llvmModule;
131 llvm::Value* TritonToLLVM::do_convert(
const triton::ast::SharedAbstractNode& node, std::unordered_map<triton::ast::SharedAbstractNode, llvm::Value*>* results) {
136 std::vector<llvm::Value*> children;
137 for (
auto&& n : node->getChildren()) {
139 if (n->getBitvectorSize() == 0) {
140 children.emplace_back(
nullptr);
143 children.emplace_back(results->at(n));
147 switch (node->getType()) {
150 llvm::Function* bswap =
nullptr;
151 switch (node->getBitvectorSize()) {
152 case triton::bitsize::byte: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt8Ty(this->llvmContext));
break;
153 case triton::bitsize::word: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt16Ty(this->llvmContext));
break;
154 case triton::bitsize::dword: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt32Ty(this->llvmContext));
break;
155 case triton::bitsize::qword: bswap = llvm::Intrinsic::getDeclaration(this->llvmModule.get(), llvm::Intrinsic::bswap, llvm::Type::getInt64Ty(this->llvmContext));
break;
159 return this->llvmIR.CreateCall(bswap, children[0]);
163 return this->llvmIR.CreateAdd(children[0], children[1]);
166 return this->llvmIR.CreateAnd(children[0], children[1]);
169 return this->llvmIR.CreateAShr(children[0], children[1]);
172 return this->llvmIR.CreateLShr(children[0], children[1]);
175 return this->llvmIR.CreateMul(children[0], children[1]);
178 return this->llvmIR.CreateNot(this->llvmIR.CreateAnd(children[0], children[1]));
181 return this->llvmIR.CreateNeg(children[0]);
184 return this->llvmIR.CreateNot(this->llvmIR.CreateOr(children[0], children[1]));
187 return this->llvmIR.CreateNot(children[0]);
190 return this->llvmIR.CreateOr(children[0], children[1]);
194 auto rot = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
195 auto size = node->getBitvectorSize();
196 return this->llvmIR.CreateOr(this->llvmIR.CreateShl(children[0], rot % size), this->llvmIR.CreateLShr(children[0], (size - (rot % size))));
201 auto rot = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
202 auto size = node->getBitvectorSize();
203 return this->llvmIR.CreateOr(this->llvmIR.CreateLShr(children[0], rot % size), this->llvmIR.CreateShl(children[0], (size - (rot % size))));
207 return this->llvmIR.CreateSDiv(children[0], children[1]);
210 return this->llvmIR.CreateICmpSGE(children[0], children[1]);
213 return this->llvmIR.CreateICmpSGT(children[0], children[1]);
216 return this->llvmIR.CreateShl(children[0], children[1]);
219 return this->llvmIR.CreateICmpSLE(children[0], children[1]);
222 return this->llvmIR.CreateICmpSLT(children[0], children[1]);
225 auto* LHS = children[0];
226 auto* RHS = children[1];
227 return this->llvmIR.CreateSRem(this->llvmIR.CreateAdd(this->llvmIR.CreateSRem(LHS, RHS), RHS), RHS);
231 return this->llvmIR.CreateSRem(children[0], children[1]);
234 return this->llvmIR.CreateSub(children[0], children[1]);
237 return this->llvmIR.CreateUDiv(children[0], children[1]);
240 return this->llvmIR.CreateICmpUGE(children[0], children[1]);
243 return this->llvmIR.CreateICmpUGT(children[0], children[1]);
246 return this->llvmIR.CreateICmpULE(children[0], children[1]);
249 return this->llvmIR.CreateICmpULT(children[0], children[1]);
252 return this->llvmIR.CreateURem(children[0], children[1]);
255 return this->llvmIR.CreateNot(this->llvmIR.CreateXor(children[0], children[1]));
258 return this->llvmIR.CreateXor(children[0], children[1]);
261 return llvm::ConstantInt::get(this->llvmContext, llvm::APInt(node->getBitvectorSize(),
static_cast<uint64_t
>(node->evaluate()),
false));
264 auto dstSize = node->getBitvectorSize();
265 auto finalNode = this->llvmIR.CreateZExt(children[0], llvm::IntegerType::get(this->llvmContext, dstSize));
267 for (
triton::usize index = 1; index < children.size(); index++) {
268 finalNode = this->llvmIR.CreateShl(finalNode, node->getChildren()[index]->getBitvectorSize());
269 auto* n = this->llvmIR.CreateZExt(children[index], llvm::IntegerType::get(this->llvmContext, dstSize));
270 finalNode = this->llvmIR.CreateOr(finalNode, n);
277 return this->llvmIR.CreateICmpNE(children[0], children[1]);
280 return this->llvmIR.CreateICmpEQ(children[0], children[1]);
283 auto low = triton::ast::getInteger<triton::uint64>(node->getChildren()[1]);
284 auto dstSize = node->getChildren()[2]->getBitvectorSize();
285 auto* value = children[2];
288 return this->llvmIR.CreateTrunc(value, llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
291 return this->llvmIR.CreateTrunc(this->llvmIR.CreateLShr(value, low), llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
295 return this->llvmIR.CreateSelect(children[0], children[1], children[2]);
298 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
299 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateAnd(children), truenode);
303 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
304 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateNot(children[0]), truenode);
308 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
309 return this->llvmIR.CreateICmpEQ(this->llvmIR.CreateOr(children), truenode);
313 auto* child0 = children[0];
314 auto* child1 = children[1];
315 auto* current = this->llvmIR.CreateXor(child0, child1);
317 for (
triton::usize index = 2; index < children.size(); index++) {
318 current = this->llvmIR.CreateXor(current, children[index]);
321 auto* truenode = llvm::ConstantInt::get(this->llvmContext, llvm::APInt(1, 1));
322 return this->llvmIR.CreateICmpEQ(current, truenode);
329 auto* ptr = this->llvmIR.CreateIntToPtr(children[1], llvm::Type::getInt8Ty(this->llvmContext)->getPointerTo());
330 return this->llvmIR.CreateLoad(llvm::Type::getInt8Ty(this->llvmContext), ptr);
334 auto* ptr = this->llvmIR.CreateIntToPtr(children[1], llvm::Type::getInt8Ty(this->llvmContext)->getPointerTo());
335 return this->llvmIR.CreateStore(children[2], ptr);
339 return this->llvmIR.CreateSExt(children[1], llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
342 return this->llvmVars.at(node);
345 return this->llvmIR.CreateZExt(children[1], llvm::IntegerType::get(this->llvmContext, node->getBitvectorSize()));
TRITON_EXPORT std::shared_ptr< llvm::Module > convert(const triton::ast::SharedAbstractNode &node, const char *fname="__triton", bool optimize=false)
Lifts a symbolic expression and all its references to LLVM format. fname is the name given to the generated LLVM function.
TRITON_EXPORT TritonToLLVM(llvm::LLVMContext &llvmContext)
Constructor.
The exception class used by all AST lifting (e.g. z3 <-> triton).
The exception class used by the lifting engine.
std::vector< SharedAbstractNode > childrenExtraction(const SharedAbstractNode &node, bool unroll, bool revert)
Returns node and all its children of an AST sorted topologically. If unroll is true,...
std::shared_ptr< triton::ast::AbstractNode > SharedAbstractNode
Shared Abstract Node.
std::deque< SharedAbstractNode > search(const SharedAbstractNode &node, triton::ast::ast_e match)
Returns a deque of collected matched nodes via a depth-first pre-order traversal.
constexpr triton::uint32 byte
byte size in bits
constexpr triton::uint32 dword
dword size in bits
constexpr triton::uint32 qword
qword size in bits
constexpr triton::uint32 word
word size in bits
std::size_t usize
Unsigned integer type, 32 or 64 bits wide depending on the CPU.