commit a6d9d5d3c4b6ac3c256361d919a3c3975378321c
parent 73987f2986ad6118435bd30250cd2117dfe31b05
Author: Mohammad-Reza Nabipoor <m.nabipoor@yahoo.com>
Date: Sun, 13 Sep 2020 05:09:24 +0430
Add codegen module (+ tests)
Diffstat:
4 files changed, 583 insertions(+), 16 deletions(-)
diff --git a/Makefile b/Makefile
@@ -15,7 +15,8 @@ eflags += -I. # flags for examples
catch2 = tests/catch2/catch.hpp
# tests
-tbin = kaleidoscope_lexer.test kaleidoscope_ast.test kaleidoscope_parser.test
+tbin = kaleidoscope_lexer.test kaleidoscope_ast.test kaleidoscope_parser.test \
+ kaleidoscope_codegen.test
# examples
ebin = kaleidoscope_lexer.ex
@@ -62,6 +63,14 @@ kaleidoscope_parser.test: kaleidoscope_ast.o
kaleidoscope_parser.test: kaleidoscope_parser.test.o
$(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
+# codegen test object: built with the test flags from the test source; it is
+# rebuilt when the codegen/parser headers or the catch2 header change.
+kaleidoscope_codegen.test.o: CXXFLAGS += $(tflags)
+kaleidoscope_codegen.test.o: kaleidoscope_codegen.hpp kaleidoscope_parser.hpp
+kaleidoscope_codegen.test.o: tests/kaleidoscope_codegen.test.cpp $(catch2)
+ $(CXX) $(CXXFLAGS) -c -o $@ $<
+# codegen test binary: links the codegen and AST objects with the test object
+# and additionally passes $(LLVM_LIBS) to the linker.
+kaleidoscope_codegen.test: kaleidoscope_codegen.o kaleidoscope_ast.o
+kaleidoscope_codegen.test: kaleidoscope_codegen.test.o
+ $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) $(LLVM_LIBS)
+
#--- examples
kaleidoscope_lexer.ex.o: CXXFLAGS += $(eflags)
diff --git a/kaleidoscope_codegen.cpp b/kaleidoscope_codegen.cpp
@@ -3,41 +3,306 @@
#include "kaleidoscope_ast.hpp"
-#include <map>
+#include <algorithm>
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>
+#include <memory>
#include <string>
+#include <utility>
+#include <vector>
-llvm::Value*
-kal::codegen(const Number&)
+#include <llvm/ADT/APFloat.h>
+#include <llvm/IR/Constants.h>
+#include <llvm/IR/IRBuilder.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/Module.h>
+#include <llvm/IR/Value.h>
+#include <llvm/IR/Verifier.h>
+
+namespace {
+
+// Prints an "Assertion failed" line (expression text, file, function, line)
+// to stderr, flushes the open output streams and aborts the process.
+// The return type is int only so that the call can be the right-hand operand
+// of `||` inside ASSERT_ALWAYS; the return statement is never reached.
+int
+assert_fail(const char* expr, const char* file, int line, const char* func)
{
- return {};
+ fprintf(
+ stderr, "Assertion failed: %s (%s: %s: %d)\n", expr, file, func, line);
+ fflush(NULL);
+ abort();
+
+ return 0; // unreachable: abort() does not return
}
+// ASSERT_ALWAYS(x): checked in every build. When x is false, assert_fail()
+// reports the expression text and source location and aborts.
+#define ASSERT_ALWAYS(x) ((x) || assert_fail(#x, __FILE__, __LINE__, __func__))
+// ASSERT(x): debug-only counterpart. It is active unless NDEBUG is defined,
+// matching the convention of the standard assert() macro; with NDEBUG it
+// expands to a no-op expression and x is not evaluated.
+#ifdef NDEBUG
+#define ASSERT(x) ((void)0)
+#else
+#define ASSERT(x) ASSERT_ALWAYS(x)
+#endif
+
+// Returns the LLVMContext shared by the helpers in this file. The context is
+// a function-local static, so it is constructed on the first call.
+inline llvm::LLVMContext&
+context()
+{
+ static llvm::LLVMContext ctx;
+
+ return ctx;
+}
+
+// Returns the IRBuilder through which this file emits instructions. It is a
+// function-local static constructed from context() on the first call.
+inline llvm::IRBuilder<>&
+builder()
+{
+ static llvm::IRBuilder<> b(context());
+
+ return b;
+}
+
+// Returns a non-owning pointer to the module named "Kaleidoscope JIT". The
+// module is held by a function-local static std::unique_ptr that is created
+// from context() on the first call.
+inline llvm::Module*
+module()
+{
+ static auto m = std::make_unique<llvm::Module>("Kaleidoscope JIT", context());
+
+ return m.get();
+}
+
+// Minimal associative container: a vector of (key, value) pairs kept sorted
+// by key and searched with std::lower_bound.
+//
+// K needs operator< and operator==; V needs operator== and must be
+// value-initializable, because a failed lookup returns V{} (nullptr for the
+// pointer types this file stores).
+template<typename K, typename V>
+class Map
+{
+private:
+  std::vector<std::pair<K, V>> d_; // entries, ascending by key
+
+  // Locates the first entry whose key is not less than `k`.
+  // Returns {position, end}; position == end means every stored key is < k.
+  auto find_(const K& k) const
+  {
+    const auto f = d_.begin();
+    const auto l = d_.end();
+    auto r = std::lower_bound(f, l, k, [](const auto& pair, const auto& key) {
+      return pair.first < key;
+    });
+
+    return std::make_pair(r, l);
+  }
+
+public:
+  // Returns the value stored under `k`, or V{} when `k` is absent.
+  V find(const K& k) const
+  {
+    const auto r = find_(k);
+
+    if (r.first != r.second && k == r.first->first)
+      return r.first->second;
+    return V{};
+  }
+
+  // Inserts (k, v) at its sorted position. An existing key is never
+  // overwritten: the stored value is kept, and builds with assertions
+  // enabled additionally check that `v` equals it.
+  void insert(const K& k, const V& v)
+  {
+    const auto r = find_(k);
+
+    if (r.first != r.second && k == r.first->first)
+      assert(v == r.first->second);
+    else
+      d_.insert(r.first, std::make_pair(k, v));
+  }
+
+  // Calls f(key, value) for every entry, in ascending key order.
+  template<typename F>
+  void for_each(F f) const
+  {
+    for (const auto& e : d_)
+      f(e.first, e.second);
+  }
+
+  // Removes every entry.
+  void clear() { d_.clear(); }
+};
+
+// Symbol tables used by the codegen functions below.
+// funcs: function name -> llvm::Function*, filled by codegen(const Prototype&).
+// names: argument name -> llvm::Value* of the function whose body is being
+//        generated; cleared and refilled by codegen(const Function&).
+Map<std::string, llvm::Function*> funcs;
+Map<std::string, llvm::Value*> names;
+
+} // anonymous namespace
+// Returns the floating-point constant built from n.value.
llvm::Value*
-kal::codegen(const Variable&)
+kal::codegen(const Number& n)
{
- return {};
+ return llvm::ConstantFP::get(context(), llvm::APFloat{ n.value });
}
+// Looks v.name up in `names`; returns nullptr when the name is not bound.
llvm::Value*
-kal::codegen(const BinaryOp&)
+kal::codegen(const Variable& v)
{
- return {};
+ return names.find(v.name);
}
+// Generates code for both operands, then emits the floating-point
+// instruction selected by b.op ('*', '/', '+', '-' or '<').
+// Aborts through ASSERT_ALWAYS when an operand yields nullptr or when b.op is
+// none of those five characters.
llvm::Value*
-kal::codegen(const Call&)
+kal::codegen(const BinaryOp& b)
{
- return {};
+ auto l = kal::codegen(b.lhs);
+ auto r = kal::codegen(b.rhs);
+
+ ASSERT_ALWAYS(l != nullptr);
+ ASSERT_ALWAYS(r != nullptr);
+
+ switch (b.op) {
+ case '*':
+ return builder().CreateFMul(l, r, "multmp");
+
+ case '/':
+ return builder().CreateFDiv(l, r, "divtmp");
+
+ case '+':
+ return builder().CreateFAdd(l, r, "addtmp");
+
+ case '-':
+ return builder().CreateFSub(l, r, "subtmp");
+
+ case '<':
+ // Compare first, then convert the comparison result to double so that
+ // every expression in the language produces a double.
+ l = builder().CreateFCmpULT(l, r, "cmptmp");
+ return builder().CreateUIToFP(
+ l, llvm::Type::getDoubleTy(context()), "booltmp");
+
+ default:
+ ASSERT_ALWAYS(false && "unknown operator");
+ }
+
+ // Not reached: the default case aborts.
+ return nullptr;
}
+// Emits a call to the function registered in `funcs` under c.callee.
+// Aborts through ASSERT_ALWAYS when the callee is unknown, when the number of
+// call arguments differs from the callee's arg_size(), or when generating an
+// argument yields nullptr.
llvm::Value*
-kal::codegen(const Prototype&)
+kal::codegen(const Call& c)
{
- return {};
+ auto fun = funcs.find(c.callee);
+ auto n = c.args.size();
+
+ ASSERT_ALWAYS(fun != nullptr);
+ ASSERT_ALWAYS(n == fun->arg_size());
+
+ std::vector<llvm::Value*> args(n);
+
+ // Arguments are generated left to right.
+ for (auto i = 0u; i < n; i++) {
+ args[i] = kal::codegen(c.args[i]);
+ ASSERT_ALWAYS(args[i] != nullptr);
+ }
+
+ return builder().CreateCall(fun, args, "calltmp");
}
+// Dispatches on node_type(n): Number, Variable, BinaryOp and Call nodes are
+// cast to their concrete type and forwarded to the matching overload.
+// None, Prototype and Function nodes fall through to the ASSERT_ALWAYS after
+// the switch, which aborts.
llvm::Value*
-kal::codegen(const Function&)
+kal::codegen(const kal::ASTNode& n)
+{
+ using kal::node_type;
+
+ switch (node_type(n)) {
+ case kal::NodeType::Number: {
+ kal::Number num;
+
+ cast(n, num);
+ return kal::codegen(num);
+ }
+
+ case kal::NodeType::Variable: {
+ kal::Variable v;
+
+ cast(n, v);
+ return kal::codegen(v);
+ }
+
+ case kal::NodeType::BinaryOp: {
+ kal::BinaryOp op;
+
+ cast(n, op);
+ return kal::codegen(op);
+ }
+
+ case kal::NodeType::Call: {
+ kal::Call c;
+
+ cast(n, c);
+ return kal::codegen(c);
+ }
+
+ case kal::NodeType::None:
+ case kal::NodeType::Prototype:
+ case kal::NodeType::Function:
+ break;
+ }
+
+ ASSERT_ALWAYS(false && "unknown NodeType");
+ // Not reached: the assertion above aborts.
+ return nullptr;
+}
+
+// Declares the function described by `p` in module(): every parameter and
+// the return value are double, linkage is external, and the arguments are
+// named after p.params. The function is registered in `funcs` under p.name.
+//
+// A name that is already registered is reused instead of being declared a
+// second time; the arities must agree, otherwise ASSERT_ALWAYS aborts.
+llvm::Function*
+kal::codegen(const Prototype& p)
+{
+  // Creating another llvm::Function under a name the module already holds
+  // would leave `funcs` pointing at the first one while the caller receives
+  // an unregistered duplicate, so hand back the registered function.
+  if (auto known = funcs.find(p.name)) {
+    ASSERT_ALWAYS(known->arg_size() == p.params.size());
+    return known;
+  }
+
+  auto dblty = llvm::Type::getDoubleTy(context());
+  std::vector<llvm::Type*> dbls(p.params.size(), dblty);
+  auto ftype = llvm::FunctionType::get(dblty, dbls, false);
+  auto func = llvm::Function::Create(
+    ftype, llvm::Function::ExternalLinkage, p.name, module());
+
+  std::size_t i = 0;
+
+  for (auto& arg : func->args())
+    arg.setName(p.params[i++]);
+
+  funcs.insert(p.name, func);
+  return func;
+}
+
+// Generates the definition of `f`: obtains (or creates) the llvm::Function
+// for f.proto, binds its arguments in `names`, emits f.body into a fresh
+// "entry" block and returns the body's value from the function.
+//
+// Aborts through ASSERT_ALWAYS when the prototype cannot be created, when an
+// earlier declaration has a different arity or already has a body, or when
+// the body fails to generate.
+llvm::Function*
+kal::codegen(const Function& f)
+{
+  auto func = funcs.find(f.proto.name);
+
+  if (func == nullptr) {
+    // codegen(const Prototype&) registers the new function in `funcs`.
+    func = kal::codegen(f.proto);
+    ASSERT_ALWAYS(func != nullptr);
+  } else {
+    // Reusing an earlier declaration (e.g. from `extern`).
+    ASSERT_ALWAYS(func->arg_size() == f.proto.params.size());
+    ASSERT_ALWAYS(func->empty() && "function redefinition");
+
+    // The declaration may have used other parameter names; the body refers
+    // to the names of this definition, so rename the arguments to match.
+    std::size_t i = 0;
+
+    for (auto& arg : func->args())
+      arg.setName(f.proto.params[i++]);
+  }
+
+  auto bb = llvm::BasicBlock::Create(context(), "entry", func);
+
+  builder().SetInsertPoint(bb);
+
+  // Only the arguments of the current function are visible to its body.
+  names.clear();
+  for (auto& arg : func->args())
+    names.insert(arg.getName().str(), &arg);
+
+  auto body = kal::codegen(f.body);
+
+  ASSERT_ALWAYS(body != nullptr);
+
+  builder().CreateRet(body);
+  llvm::verifyFunction(*func);
+
+  return func;
+}
+
+// Wraps the statements in [f, l) in a new zero-argument function named
+// "kaleidoscope_body_<n>__", where <n> counts the wrappers generated so far.
+// Each statement is generated in order into the "entry" block and the
+// function then returns the constant 0. Returns nullptr, without consuming a
+// number, when the range is empty.
+llvm::Function*
+kal::mkfunc(ASTNode* f, ASTNode* l)
+{
+  static std::size_t serial; // zero-initialized; bumped once per wrapper
+
+  if (f == l)
+    return nullptr;
+
+  char label[1024];
+
+  snprintf(label, sizeof(label), "kaleidoscope_body_%zu__", serial);
+  ++serial;
+
+  const auto wrapper = codegen(kal::Prototype{ label, {} });
+
+  ASSERT_ALWAYS(wrapper != nullptr);
+
+  builder().SetInsertPoint(
+    llvm::BasicBlock::Create(context(), "entry", wrapper));
+
+  for (auto stmt = f; stmt != l; ++stmt)
+    kal::codegen(*stmt);
+
+  builder().CreateRet(kal::codegen(kal::Number{ 0 }));
+
+  llvm::verifyFunction(*wrapper);
+  return wrapper;
+}
+
+// Removes every function registered in `funcs` from its module and empties
+// both symbol tables.
+void
+kal::codegen_reset(void)
{
- return {};
+  // Pass 1: delete all bodies. `funcs` is iterated in name order, so a
+  // function may come before one of its callers; erasing it while call
+  // instructions in other bodies still reference it would destroy a value
+  // that is still in use.
+  funcs.for_each([](const auto&, const auto& v) { v->deleteBody(); });
+  // Pass 2: no function is referenced any more, so each can be erased.
+  funcs.for_each([](const auto&, const auto& v) { v->eraseFromParent(); });
+  funcs.clear();
+  names.clear();
}
diff --git a/kaleidoscope_codegen.hpp b/kaleidoscope_codegen.hpp
@@ -3,6 +3,7 @@
namespace llvm {
class Value;
+class Function;
}
namespace kal {
@@ -22,9 +23,21 @@ llvm::Value*
codegen(const BinaryOp&);
llvm::Value*
codegen(const Call&);
+
+class ASTNode;
+
+// Generates code for an expression node by forwarding to the overload that
+// matches the node's type.
llvm::Value*
+codegen(const ASTNode&);
+
+// Declares the function described by the prototype in the module.
+llvm::Function*
codegen(const Prototype&);
+// Generates the definition (prototype plus body) of a function.
-llvm::Value*
+llvm::Function*
codegen(const Function&);
+// Wraps the statements in [f, l) in a generated zero-argument function;
+// returns nullptr when the range is empty.
+llvm::Function*
+mkfunc(kal::ASTNode* f, kal::ASTNode* l);
+
+// Erases every function created so far and clears the symbol tables.
+void
+codegen_reset(void);
+
} // namespace kal
diff --git a/tests/kaleidoscope_codegen.test.cpp b/tests/kaleidoscope_codegen.test.cpp
@@ -0,0 +1,280 @@
+
+#include "kaleidoscope_codegen.hpp"
+
+#define CATCH_CONFIG_MAIN
+#include <catch2/catch.hpp>
+
+#include <sstream>
+#include <string>
+
+#include <llvm/IR/Function.h>
+#include <llvm/Support/raw_ostream.h>
+
+#include "kaleidoscope_parser.hpp"
+
+using kal::cast;
+using kal::to_string;
+
+// Parser output collected by the test helper, grouped by entity kind.
+struct Parsed
+{
+ std::vector<kal::Prototype> decls; // FuncDecl entities
+ std::vector<kal::Function> defs; // FuncDef entities
+ std::vector<kal::ASTNode> stmts; // Stmt entities
+};
+
+namespace {
+
+// Renders `f` as textual LLVM IR and returns it as a std::string.
+std::string
+to_string(llvm::Function* f)
+{
+  std::string code;
+  llvm::raw_string_ostream out{ code };
+
+  f->print(out);
+  // The stream may buffer its output; `code` is only guaranteed to hold the
+  // text after a flush. Without it the result is complete only if the
+  // compiler elides the copy of `code` (so that the stream's destructor
+  // flushes into the returned object), which is not guaranteed.
+  out.flush();
+  return code;
+}
+
+} // anonymous namespace
+
+TEST_CASE("Code generation for simple programs", "[simple]")
+{
+  // FIXME move to kaleidoscope_parser.hpp
+  // Test helper: tokenizes `s`, drops comment tokens, parses the remaining
+  // tokens and sorts the parsed entities into a Parsed. Any parser error
+  // fails the test with the collected messages.
+  auto p = [](const std::string& s) {
+    std::vector<kal::Token> tk;
+    Parsed parsed;
+
+    kal::tokenize(s.cbegin(), s.cend(), std::back_inserter(tk));
+    tk.erase(std::remove_if(
+               tk.begin(),
+               tk.end(),
+               [](const auto& t) { return t.type == kal::TkType::Comment; }),
+             tk.end());
+
+    auto tkb = tk.cbegin();
+    auto tke = tk.cend();
+    std::vector<kal::ParseError<decltype(tkb)>> errs;
+
+    kal::parse(
+      tkb,
+      tke,
+      [&](auto type, const auto& node) {
+        switch (type) {
+          case kal::ParsedEntityType::FuncDecl: {
+            kal::Prototype proto;
+            auto ok = cast(node, proto);
+
+            assert(ok && "invalid cast: expects kal::Prototype");
+
+            parsed.decls.emplace_back(std::move(proto));
+          } break;
+
+          case kal::ParsedEntityType::FuncDef: {
+            kal::Function func;
+            auto ok = cast(node, func);
+
+            assert(ok && "invalid cast: expects kal::Function");
+
+            parsed.defs.emplace_back(std::move(func));
+          } break;
+
+          case kal::ParsedEntityType::Stmt:
+            // `node` is a const reference, so it can only be copied; a
+            // std::move here would still copy while suggesting otherwise.
+            parsed.stmts.emplace_back(node);
+            break;
+        }
+      },
+      std::back_inserter(errs));
+
+    if (errs.empty())
+      return parsed;
+
+    std::ostringstream oss;
+
+    for (auto& e : errs) {
+      oss << " Error@" << std::distance(tkb, e.pos) << ' ' << e.msg;
+
+      if (e.pos != tke)
+        oss << '\n' << kal::to_string(*e.pos) << "\n";
+    }
+
+    FAIL("PARSER ERROR\n" << oss.str());
+    return parsed; // dummy
+  };
+
+ // `extern` declarations: each prototype must print as a `declare` line with
+ // one double parameter per declared argument. State is reset after each case.
+ SECTION("extern")
+ {
+ {
+ auto pd = p("extern sin(x)");
+
+ REQUIRE(pd.decls.size() == 1);
+ REQUIRE(to_string(kal::codegen(pd.decls[0])) ==
+ "declare double @sin(double)\n");
+
+ kal::codegen_reset();
+ }
+
+ {
+ auto pd = p("extern tan2 ( arg0 arg1 )");
+
+ REQUIRE(pd.decls.size() == 1);
+ REQUIRE(to_string(kal::codegen(pd.decls[0])) ==
+ "declare double @tan2(double, double)\n");
+
+ kal::codegen_reset();
+ }
+
+ {
+ auto pd = p("extern cos(realInput);extern atan2(arg0 arg1);");
+
+ REQUIRE(pd.decls.size() == 2);
+ REQUIRE(to_string(kal::codegen(pd.decls[0])) ==
+ "declare double @cos(double)\n");
+ REQUIRE(to_string(kal::codegen(pd.decls[1])) ==
+ "declare double @atan2(double, double)\n");
+
+ kal::codegen_reset();
+ }
+ }
+
+ // Function definitions: a constant body, a definition next to an unrelated
+ // declaration, and definitions whose bodies call declared functions and
+ // (for `gun`) the function being defined.
+ SECTION("def")
+ {
+ {
+ auto pd = p(R"(
+def one(x)
+ 1
+)");
+
+ REQUIRE(pd.defs.size() == 1);
+
+ auto c = kal::codegen(pd.defs[0]);
+ auto cstr = to_string(c);
+
+ REQUIRE(cstr ==
+ "define double @one(double %x) {\n"
+ "entry:\n"
+ " ret double 1.000000e+00\n"
+ "}\n");
+
+ kal::codegen_reset();
+ }
+
+ {
+ auto pd = p("extern sin(x) def pi2() 1.5708");
+
+ REQUIRE(pd.decls.size() == 1);
+ REQUIRE(pd.defs.size() == 1);
+
+ auto cdecl = kal::codegen(pd.decls[0]);
+ auto cdeclstr = to_string(cdecl);
+ auto cdef = kal::codegen(pd.defs[0]);
+ auto cdefstr = to_string(cdef);
+
+ REQUIRE(cdeclstr == "declare double @sin(double)\n");
+ REQUIRE(cdefstr ==
+ "define double @pi2() {\n"
+ "entry:\n"
+ " ret double 1.570800e+00\n"
+ "}\n");
+
+ kal::codegen_reset();
+ }
+
+ {
+ auto pd =
+ p("extern tan2(x y);extern sin(x);def fun()tan2(sin(1.5708), 1);"
+ "def gun(x y z)tan2(0.1,gun(1,sin(1.5),1.0));");
+ std::vector<std::string> defstrs;
+
+ // Declarations must be generated first so the calls below can resolve.
+ for (const auto& d : pd.decls)
+ kal::codegen(d);
+
+ REQUIRE(pd.defs.size() == 2);
+
+ for (const auto& d : pd.defs)
+ defstrs.emplace_back(to_string(kal::codegen(d)));
+
+ REQUIRE(defstrs[0] ==
+ "define double @fun() {\n"
+ "entry:\n"
+ " %calltmp = call double @sin(double 1.570800e+00)\n"
+ " %calltmp1 = call double @tan2(double %calltmp, double "
+ "1.000000e+00)\n"
+ " ret double %calltmp1\n"
+ "}\n");
+ REQUIRE(defstrs[1] ==
+ "define double @gun(double %x, double %y, double %z) {\n"
+ "entry:\n"
+ " %calltmp = call double @sin(double 1.500000e+00)\n"
+ " %calltmp1 = call double @gun(double 1.000000e+00, double "
+ "%calltmp, double 1.000000e+00)\n"
+ " %calltmp2 = call double @tan2(double 1.000000e-01, double "
+ "%calltmp1)\n"
+ " ret double %calltmp2\n"
+ "}\n");
+
+ kal::codegen_reset();
+ }
+ }
+
+ // Arithmetic: the expected IR shows the constant sub-expressions already
+ // folded (3.14/4 as 7.85e-01, (3*4+1)*3 as 3.9e+01) and the subtractions
+ // applied left to right.
+ SECTION("binary operation")
+ {
+ {
+ auto pd = p("extern sin(x); def f() sin(3.14/4)-2.3-3.4-(3*4+1)*3");
+
+ REQUIRE(pd.decls.size() == 1);
+ REQUIRE(pd.defs.size() == 1);
+
+ kal::codegen(pd.decls[0]);
+ REQUIRE(to_string(kal::codegen(pd.defs[0])) ==
+ "define double @f() {\n"
+ "entry:\n"
+ " %calltmp = call double @sin(double 7.850000e-01)\n"
+ " %subtmp = fsub double %calltmp, 2.300000e+00\n"
+ " %subtmp1 = fsub double %subtmp, 3.400000e+00\n"
+ " %subtmp2 = fsub double %subtmp1, 3.900000e+01\n"
+ " ret double %subtmp2\n"
+ "}\n");
+
+ kal::codegen_reset();
+ }
+ }
+
+  // Whole program: declarations, a definition and a top-level statement.
+  // The statement is wrapped by kal::mkfunc into a generated function whose
+  // IR is checked; the wrapper evaluates the statement and returns 0.
+  SECTION("extern, def, call, stmts")
+  {
+    {
+      auto pd = p(R"(
+extern tan2(x y)
+
+def formula(x y z)
+ 3.14 + 2 * (tan2(x, y) + x*y) / z
+
+extern sin(x)
+
+sin(2.72) + formula(1, 2, 3)
+)");
+
+      REQUIRE(pd.decls.size() == 2);
+      REQUIRE(pd.defs.size() == 1);
+      REQUIRE(pd.stmts.size() == 1);
+
+      for (const auto& d : pd.decls)
+        kal::codegen(d);
+      for (const auto& d : pd.defs)
+        kal::codegen(d);
+
+      {
+        auto f = pd.stmts.data();
+        auto l = f + pd.stmts.size();
+
+        REQUIRE(to_string(kal::mkfunc(f, l)) ==
+                "define double @kaleidoscope_body_0__() {\n"
+                "entry:\n"
+                " %calltmp = call double @sin(double 2.720000e+00)\n"
+                " %calltmp1 = call double @formula(double 1.000000e+00, "
+                "double 2.000000e+00, double 3.000000e+00)\n"
+                " %addtmp = fadd double %calltmp, %calltmp1\n"
+                " ret double 0.000000e+00\n"
+                "}\n");
+      }
+
+      // Like every other case in this file: leave no generated functions
+      // behind for whatever runs next.
+      kal::codegen_reset();
+    }
+  }
+}