llvm-journey

LLVM Journey
git clone git://0xff.ir/g/llvm-journey.git
Log | Files | Refs | README | LICENSE

commit a6d9d5d3c4b6ac3c256361d919a3c3975378321c
parent 73987f2986ad6118435bd30250cd2117dfe31b05
Author: Mohammad-Reza Nabipoor <m.nabipoor@yahoo.com>
Date:   Sun, 13 Sep 2020 05:09:24 +0430

Add codegen module (+ tests)

Diffstat:
MMakefile | 11++++++++++-
Mkaleidoscope_codegen.cpp | 293+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
Mkaleidoscope_codegen.hpp | 15++++++++++++++-
Atests/kaleidoscope_codegen.test.cpp | 280+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 583 insertions(+), 16 deletions(-)

diff --git a/Makefile b/Makefile @@ -15,7 +15,8 @@ eflags += -I. # flags for examples catch2 = tests/catch2/catch.hpp # tests -tbin = kaleidoscope_lexer.test kaleidoscope_ast.test kaleidoscope_parser.test +tbin = kaleidoscope_lexer.test kaleidoscope_ast.test kaleidoscope_parser.test \ + kaleidoscope_codegen.test # examples ebin = kaleidoscope_lexer.ex @@ -62,6 +63,14 @@ kaleidoscope_parser.test: kaleidoscope_ast.o kaleidoscope_parser.test: kaleidoscope_parser.test.o $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) +kaleidoscope_codegen.test.o: CXXFLAGS += $(tflags) +kaleidoscope_codegen.test.o: kaleidoscope_codegen.hpp kaleidoscope_parser.hpp +kaleidoscope_codegen.test.o: tests/kaleidoscope_codegen.test.cpp $(catch2) + $(CXX) $(CXXFLAGS) -c -o $@ $< +kaleidoscope_codegen.test: kaleidoscope_codegen.o kaleidoscope_ast.o +kaleidoscope_codegen.test: kaleidoscope_codegen.test.o + $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS) $(LLVM_LIBS) + #--- examples kaleidoscope_lexer.ex.o: CXXFLAGS += $(eflags) diff --git a/kaleidoscope_codegen.cpp b/kaleidoscope_codegen.cpp @@ -3,41 +3,306 @@ #include "kaleidoscope_ast.hpp" -#include <map> +#include <algorithm> +#include <cstdio> #include <string> +#include <utility> +#include <vector> -llvm::Value* -kal::codegen(const Number&) +#include <llvm/ADT/APFloat.h> +#include <llvm/IR/Constants.h> +#include <llvm/IR/IRBuilder.h> +#include <llvm/IR/LLVMContext.h> +#include <llvm/IR/Module.h> +#include <llvm/IR/Value.h> +#include <llvm/IR/Verifier.h> + +namespace { + +int +assert_fail(const char* expr, const char* file, int line, const char* func) { - return {}; + fprintf( + stderr, "Assertion failed: %s (%s: %s: %d)\n", expr, file, func, line); + fflush(NULL); + abort(); + + return 0; // dummy } +#define ASSERT_ALWAYS(x) ((x) || assert_fail(#x, __FILE__, __LINE__, __func__)) +#ifndef NDEBUG +#define ASSERT(x) +#else +#define ASSERT(x) ASSERT_ALWAYS(x) +#endif + +inline llvm::LLVMContext& +context() +{ + static llvm::LLVMContext ctx; + + return ctx; +} + +inline llvm::IRBuilder<>& +builder() +{ + static llvm::IRBuilder<> b(context()); + + return b; +} + +inline llvm::Module* +module() +{ + static auto m = std::make_unique<llvm::Module>("Kaleidoscope JIT", context()); + + return m.get(); +} + +template<typename K, typename V> +class Map +{ +private: + std::vector<std::pair<K, V>> d_; // data + + auto find_(const K& k) + { + const auto& f = d_.begin(); + const auto& l = d_.end(); + auto r = std::lower_bound(f, l, k, [](const auto& pair, const auto& key) { + return pair.first < key; + }); + + return std::make_pair(r, l); + } + +public: + V find(const K& k) + { + auto r = find_(k); + + if (r.first == r.second) + return nullptr; + return k == r.first->first ? r.first->second : nullptr; + } + + void insert(const K& k, const V& v) + { + auto r = find_(k); + + if (r.first != r.second && k == r.first->first) + assert(v == r.first->second); + else + d_.insert(r.first, std::make_pair(k, v)); + } + + template<typename F> + void for_each(F f) + { + for (const auto& e : d_) + f(e.first, e.second); + } + + void clear() { d_.clear(); } +}; + +Map<std::string, llvm::Function*> funcs; +Map<std::string, llvm::Value*> names; + +} // anonymous namespace llvm::Value* -kal::codegen(const Variable&) +kal::codegen(const Number& n) { - return {}; + return llvm::ConstantFP::get(context(), llvm::APFloat{ n.value }); } llvm::Value* -kal::codegen(const BinaryOp&) +kal::codegen(const Variable& v) { - return {}; + return names.find(v.name); } llvm::Value* -kal::codegen(const Call&) +kal::codegen(const BinaryOp& b) { - return {}; + auto l = kal::codegen(b.lhs); + auto r = kal::codegen(b.rhs); + + ASSERT_ALWAYS(l != nullptr); + ASSERT_ALWAYS(r != nullptr); + + switch (b.op) { + case '*': + return builder().CreateFMul(l, r, "multmp"); + + case '/': + return builder().CreateFDiv(l, r, "divtmp"); + + case '+': + return builder().CreateFAdd(l, r, "addtmp"); + + case '-': + return builder().CreateFSub(l, r, "subtmp"); + + case '<': + l = builder().CreateFCmpULT(l, r, "cmptmp"); + return builder().CreateUIToFP( + l, llvm::Type::getDoubleTy(context()), "booltmp"); + + default: + ASSERT_ALWAYS(false && "unknown operator"); + } + + return nullptr; } llvm::Value* -kal::codegen(const Prototype&) +kal::codegen(const Call& c) { - return {}; + auto fun = funcs.find(c.callee); + auto n = c.args.size(); + + ASSERT_ALWAYS(fun != nullptr); + ASSERT_ALWAYS(n == fun->arg_size()); + + std::vector<llvm::Value*> args(n); + + for (auto i = 0u; i < n; i++) { + args[i] = kal::codegen(c.args[i]); + ASSERT_ALWAYS(args[i] != nullptr); + } + + return builder().CreateCall(fun, args, "calltmp"); } llvm::Value* -kal::codegen(const Function&) +kal::codegen(const kal::ASTNode& n) +{ + using kal::node_type; + + switch (node_type(n)) { + case kal::NodeType::Number: { + kal::Number num; + + cast(n, num); + return kal::codegen(num); + } + + case kal::NodeType::Variable: { + kal::Variable v; + + cast(n, v); + return kal::codegen(v); + } + + case kal::NodeType::BinaryOp: { + kal::BinaryOp op; + + cast(n, op); + return kal::codegen(op); + } + + case kal::NodeType::Call: { + kal::Call c; + + cast(n, c); + return kal::codegen(c); + } + + case kal::NodeType::None: + case kal::NodeType::Prototype: + case kal::NodeType::Function: + break; + } + + ASSERT_ALWAYS(false && "unknown NodeType"); + return nullptr; +} + +llvm::Function* +kal::codegen(const Prototype& p) +{ + std::vector<llvm::Type*> dbls(p.params.size(), + llvm::Type::getDoubleTy(context())); + auto ftype = + llvm::FunctionType::get(llvm::Type::getDoubleTy(context()), dbls, false); + auto func = llvm::Function::Create( + ftype, llvm::Function::ExternalLinkage, p.name, module()); + + { + auto i = 0; + + for (auto& arg : func->args()) + arg.setName(p.params[i++]); + } + + funcs.insert(p.name, func); + return func; +} + +llvm::Function* +kal::codegen(const Function& f) +{ + auto func = funcs.find(f.proto.name); + + if (func == nullptr) { + func = kal::codegen(f.proto); + if (func != nullptr) + funcs.insert(f.proto.name, func); + } + + ASSERT_ALWAYS(func != nullptr); + + auto bb = llvm::BasicBlock::Create(context(), "entry", func); + + builder().SetInsertPoint(bb); + + names.clear(); + for (auto& arg : func->args()) + names.insert(arg.getName(), &arg); + + auto body = kal::codegen(f.body); + + ASSERT_ALWAYS(body != nullptr); + + builder().CreateRet(body); + llvm::verifyFunction(*func); + + return func; +} + +llvm::Function* +kal::mkfunc(ASTNode* f, ASTNode* l) +{ + static std::size_t i; + char fname[1024]; + + if (f == l) + return nullptr; + + snprintf(fname, sizeof(fname), "kaleidoscope_body_%zu__", i++); + + auto func = codegen(kal::Prototype{ fname, {} }); + + ASSERT_ALWAYS(func != nullptr); + + auto bb = llvm::BasicBlock::Create(context(), "entry", func); + + builder().SetInsertPoint(bb); + while (f != l) + kal::codegen(*f++); + builder().CreateRet(kal::codegen(kal::Number{ 0 })); + + llvm::verifyFunction(*func); + return func; +} + +void +kal::codegen_reset(void) { - return {}; + funcs.for_each([](const auto&, const auto& v) { v->eraseFromParent(); }); + funcs.clear(); + names.clear(); } diff --git a/kaleidoscope_codegen.hpp b/kaleidoscope_codegen.hpp @@ -3,6 +3,7 @@ namespace llvm { class Value; +class Function; } namespace kal { @@ -22,9 +23,21 @@ llvm::Value* codegen(const BinaryOp&); llvm::Value* codegen(const Call&); + +class ASTNode; + llvm::Value* +codegen(const ASTNode&); + +llvm::Function* codegen(const Prototype&); -llvm::Value* +llvm::Function* codegen(const Function&); +llvm::Function* +mkfunc(kal::ASTNode* f, kal::ASTNode* l); + +void +codegen_reset(void); + } // namespace kal diff --git a/tests/kaleidoscope_codegen.test.cpp b/tests/kaleidoscope_codegen.test.cpp @@ -0,0 +1,280 @@ + +#include "kaleidoscope_codegen.hpp" + +#define CATCH_CONFIG_MAIN +#include <catch2/catch.hpp> + +#include <sstream> +#include <string> + +#include <llvm/IR/Function.h> +#include <llvm/Support/raw_ostream.h> + +#include "kaleidoscope_parser.hpp" + +using kal::cast; +using kal::to_string; + +struct Parsed +{ + std::vector<kal::Prototype> decls; + std::vector<kal::Function> defs; + std::vector<kal::ASTNode> stmts; +}; + +namespace { + +std::string +to_string(llvm::Function* f) +{ + std::string code; + llvm::raw_string_ostream out{ code }; + + f->print(out); + return code; +} + +} // anonymous namespace + +TEST_CASE("Code generation for simple programs", "[simple]") +{ + // FIXME move to kaleidoscope_parser.hpp + auto p = [](const std::string& s) { + std::vector<kal::Token> tk; + Parsed parsed; + + kal::tokenize(s.cbegin(), s.cend(), std::back_inserter(tk)); + tk.erase(std::remove_if( + tk.begin(), + tk.end(), + [](const auto& t) { return t.type == kal::TkType::Comment; }), + tk.end()); + + auto tkb = tk.cbegin(); + auto tke = tk.cend(); + std::vector<kal::ParseError<decltype(tkb)>> errs; + + kal::parse( + tkb, + tke, + [&](auto type, const auto& node) { + switch (type) { + case kal::ParsedEntityType::FuncDecl: { + kal::Prototype proto; + auto ok = cast(node, proto); + + assert(ok && "invalid cast: expects kal::Prototype"); + + parsed.decls.emplace_back(std::move(proto)); + } break; + + case kal::ParsedEntityType::FuncDef: { + kal::Function func; + auto ok = cast(node, func); + + assert(ok && "invalid cast: expects kal::Function"); + + parsed.defs.emplace_back(std::move(func)); + } break; + + case kal::ParsedEntityType::Stmt: + parsed.stmts.emplace_back(std::move(node)); + break; + } + }, + std::back_inserter(errs)); + + if (errs.empty()) + return parsed; + + std::ostringstream oss; + + for (auto& e : errs) { + oss << " Error@" << std::distance(tkb, e.pos) << ' ' << e.msg; + + if (e.pos != tke) + oss << '\n' << kal::to_string(*e.pos) << "\n"; + } + + FAIL("PARSER ERROR\n" << oss.str()); + return parsed; // dummy + }; + + SECTION("extern") + { + { + auto pd = p("extern sin(x)"); + + REQUIRE(pd.decls.size() == 1); + REQUIRE(to_string(kal::codegen(pd.decls[0])) == + "declare double @sin(double)\n"); + + kal::codegen_reset(); + } + + { + auto pd = p("extern tan2 ( arg0 arg1 )"); + + REQUIRE(pd.decls.size() == 1); + REQUIRE(to_string(kal::codegen(pd.decls[0])) == + "declare double @tan2(double, double)\n"); + + kal::codegen_reset(); + } + + { + auto pd = p("extern cos(realInput);extern atan2(arg0 arg1);"); + + REQUIRE(pd.decls.size() == 2); + REQUIRE(to_string(kal::codegen(pd.decls[0])) == + "declare double @cos(double)\n"); + REQUIRE(to_string(kal::codegen(pd.decls[1])) == + "declare double @atan2(double, double)\n"); + + kal::codegen_reset(); + } + } + + SECTION("def") + { + { + auto pd = p(R"( +def one(x) + 1 +)"); + + REQUIRE(pd.defs.size() == 1); + + auto c = kal::codegen(pd.defs[0]); + auto cstr = to_string(c); + + REQUIRE(cstr == + "define double @one(double %x) {\n" + "entry:\n" + " ret double 1.000000e+00\n" + "}\n"); + + kal::codegen_reset(); + } + + { + auto pd = p("extern sin(x) def pi2() 1.5708"); + + REQUIRE(pd.decls.size() == 1); + REQUIRE(pd.defs.size() == 1); + + auto cdecl = kal::codegen(pd.decls[0]); + auto cdeclstr = to_string(cdecl); + auto cdef = kal::codegen(pd.defs[0]); + auto cdefstr = to_string(cdef); + + REQUIRE(cdeclstr == "declare double @sin(double)\n"); + REQUIRE(cdefstr == + "define double @pi2() {\n" + "entry:\n" + " ret double 1.570800e+00\n" + "}\n"); + + kal::codegen_reset(); + } + + { + auto pd = + p("extern tan2(x y);extern sin(x);def fun()tan2(sin(1.5708), 1);" + "def gun(x y z)tan2(0.1,gun(1,sin(1.5),1.0));"); + std::vector<std::string> defstrs; + + for (const auto& d : pd.decls) + kal::codegen(d); + + REQUIRE(pd.defs.size() == 2); + + for (const auto& d : pd.defs) + defstrs.emplace_back(to_string(kal::codegen(d))); + + REQUIRE(defstrs[0] == + "define double @fun() {\n" + "entry:\n" + " %calltmp = call double @sin(double 1.570800e+00)\n" + " %calltmp1 = call double @tan2(double %calltmp, double " + "1.000000e+00)\n" + " ret double %calltmp1\n" + "}\n"); + REQUIRE(defstrs[1] == + "define double @gun(double %x, double %y, double %z) {\n" + "entry:\n" + " %calltmp = call double @sin(double 1.500000e+00)\n" + " %calltmp1 = call double @gun(double 1.000000e+00, double " + "%calltmp, double 1.000000e+00)\n" + " %calltmp2 = call double @tan2(double 1.000000e-01, double " + "%calltmp1)\n" + " ret double %calltmp2\n" + "}\n"); + + kal::codegen_reset(); + } + } + + SECTION("binary operation") + { + { + auto pd = p("extern sin(x); def f() sin(3.14/4)-2.3-3.4-(3*4+1)*3"); + + REQUIRE(pd.decls.size() == 1); + REQUIRE(pd.defs.size() == 1); + + kal::codegen(pd.decls[0]); + REQUIRE(to_string(kal::codegen(pd.defs[0])) == + "define double @f() {\n" + "entry:\n" + " %calltmp = call double @sin(double 7.850000e-01)\n" + " %subtmp = fsub double %calltmp, 2.300000e+00\n" + " %subtmp1 = fsub double %subtmp, 3.400000e+00\n" + " %subtmp2 = fsub double %subtmp1, 3.900000e+01\n" + " ret double %subtmp2\n" + "}\n"); + + kal::codegen_reset(); + } + } + + SECTION("extern, def, call, stmts") + { + { + auto pd = p(R"( +extern tan2(x y) + +def formula(x y z) + 3.14 + 2 * (tan2(x, y) + x*y) / z + +extern sin(x) + +sin(2.72) + formula(1, 2, 3) +)"); + + REQUIRE(pd.decls.size() == 2); + REQUIRE(pd.defs.size() == 1); + REQUIRE(pd.stmts.size() == 1); + + for (const auto& d : pd.decls) + kal::codegen(d); + for (const auto& d : pd.defs) + kal::codegen(d); + + { + auto f = pd.stmts.data(); + auto l = f + pd.stmts.size(); + + REQUIRE(to_string(kal::mkfunc(f, l)) == + "define double @kaleidoscope_body_0__() {\n" + "entry:\n" + " %calltmp = call double @sin(double 2.720000e+00)\n" + " %calltmp1 = call double @formula(double 1.000000e+00, " + "double 2.000000e+00, double 3.000000e+00)\n" + " %addtmp = fadd double %calltmp, %calltmp1\n" + " ret double 0.000000e+00\n" + "}\n"); + } + } + } +}