commit 410499486ed716eb3dc3dff97d26f2a75c8b7988
parent 28f174d300700d5a8b1ca105d197e092af2dc4f0
Author: Mohammad-Reza Nabipoor <m.nabipoor@yahoo.com>
Date: Fri, 2 Oct 2020 02:16:57 +0330
kaleidoscope_codegen: Add optimizer
Diffstat:
3 files changed, 85 insertions(+), 7 deletions(-)
diff --git a/kaleidoscope_codegen.cpp b/kaleidoscope_codegen.cpp
@@ -8,14 +8,22 @@
#include <utility>
#include <vector>
-#include <llvm/ADT/APFloat.h>
-#include <llvm/IR/Constants.h>
-#include <llvm/IR/IRBuilder.h>
+// clang-format off
#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
+// clang-format on
+
+#include <llvm/ADT/APFloat.h>
+#include <llvm/IR/Constants.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Verifier.h>
+#include <llvm/IR/LegacyPassManager.h>
+#include <llvm/Transforms/InstCombine/InstCombine.h>
+#include <llvm/Transforms/Scalar.h>
+#include <llvm/Transforms/Scalar/GVN.h>
+
namespace {
int
@@ -59,6 +67,24 @@ module()
return m.get();
}
+inline llvm::legacy::FunctionPassManager*
+fpm()
+{
+ static auto f = [] {
+ auto f = std::make_unique<llvm::legacy::FunctionPassManager>(module());
+
+ f->add(llvm::createInstructionCombiningPass());
+ f->add(llvm::createReassociatePass());
+ f->add(llvm::createGVNPass());
+ f->add(llvm::createCFGSimplificationPass());
+
+ f->doInitialization();
+ return f;
+ }();
+
+ return f.get();
+}
+
template<typename K, typename V>
class Map
{
@@ -242,7 +268,7 @@ kal::codegen(const Prototype& p)
}
llvm::Function*
-kal::codegen(const Function& f)
+kal::codegen(const Function& f, bool optimize)
{
auto func = funcs.find(f.proto.name);
@@ -268,12 +294,14 @@ kal::codegen(const Function& f)
builder().CreateRet(body);
llvm::verifyFunction(*func);
+ if (optimize)
+ fpm()->run(*func);
return func;
}
llvm::Function*
-kal::mkfunc(const std::string& name, ASTNode* f, ASTNode* l)
+kal::mkfunc(const std::string& name, ASTNode* f, ASTNode* l, bool optimize)
{
if (f == l)
return nullptr;
@@ -290,6 +318,8 @@ kal::mkfunc(const std::string& name, ASTNode* f, ASTNode* l)
builder().CreateRet(kal::codegen(kal::Number{ 0 }));
llvm::verifyFunction(*func);
+ if (optimize)
+ fpm()->run(*func);
return func;
}
diff --git a/kaleidoscope_codegen.hpp b/kaleidoscope_codegen.hpp
@@ -28,7 +28,7 @@ codegen(const Call&);
llvm::Function*
codegen(const Prototype&);
llvm::Function*
-codegen(const Function&);
+codegen(const Function&, bool optimize = false);
class ASTNode;
@@ -39,7 +39,10 @@ llvm::Value*
codegen(const ASTNode& n);
llvm::Function*
-mkfunc(const std::string& name, kal::ASTNode* f, kal::ASTNode* l);
+mkfunc(const std::string& name,
+ kal::ASTNode* f,
+ kal::ASTNode* l,
+ bool optimize = false);
void
codegen_reset(void);
diff --git a/tests/kaleidoscope_codegen.test.cpp b/tests/kaleidoscope_codegen.test.cpp
@@ -212,6 +212,32 @@ def one(x)
"}\n");
kal::codegen_reset();
+ defstrs.clear();
+
+ for (const auto& d : pd.decls)
+ kal::codegen(d);
+ for (const auto& d : pd.defs)
+ defstrs.emplace_back(to_string(kal::codegen(d, true)));
+
+ REQUIRE(
+ defstrs[0] ==
+ "define double @fun() {\n"
+ "entry:\n"
+ " %calltmp1 = call double @tan2(double 0x3FEFFFFFFFFF12A3, double "
+ "1.000000e+00)\n"
+ " ret double %calltmp1\n"
+ "}\n");
+ REQUIRE(defstrs[1] ==
+ "define double @gun(double %x, double %y, double %z) {\n"
+ "entry:\n"
+ " %calltmp1 = call double @gun(double 1.000000e+00, double "
+ "0x3FEFEB7A9B2C6D8B, double 1.000000e+00)\n"
+ " %calltmp2 = call double @tan2(double 1.000000e-01, double "
+ "%calltmp1)\n"
+ " ret double %calltmp2\n"
+ "}\n");
+
+ kal::codegen_reset();
}
}
@@ -235,6 +261,15 @@ def one(x)
"}\n");
kal::codegen_reset();
+
+ kal::codegen(pd.decls[0]);
+ REQUIRE(to_string(kal::codegen(pd.defs[0], true)) ==
+ "define double @f() {\n"
+ "entry:\n"
+ " ret double 0xC045FF205A3B2E7C\n"
+ "}\n");
+
+ kal::codegen_reset();
}
}
@@ -274,7 +309,17 @@ sin(2.72) + formula(1, 2, 3)
" %addtmp = fadd double %calltmp, %calltmp1\n"
" ret double 0.000000e+00\n"
"}\n");
+
+ REQUIRE(to_string(kal::mkfunc("kaleidoscope_body_1__", f, l, true)) ==
+ "define double @kaleidoscope_body_1__() {\n"
+ "entry:\n"
+ " %calltmp1 = call double @formula(double 1.000000e+00, "
+ "double 2.000000e+00, double 3.000000e+00)\n"
+ " ret double 0.000000e+00\n"
+ "}\n");
}
+
+ kal::codegen_reset();
}
}
}