commit e323acc67f313d3a5ecf7b9f23b46f550b5394c1
parent 6b606fbfcb6332ea426ff171f6526ecb863cff6f
Author: Mohammad-Reza Nabipoor <m.nabipoor@yahoo.com>
Date: Mon, 10 Aug 2020 22:33:59 +0430
Add data types for AST nodes
ASTNode is a polymorphic type (Sean Parent's style of runtime
polymorphism).
Other nodes: Number, Variable, BinaryOp, Call, Prototype, Function
Diffstat:
6 files changed, 443 insertions(+), 1 deletion(-)
diff --git a/Makefile b/Makefile
@@ -6,7 +6,7 @@ eflags += -I. # flags for examples
catch2 = tests/catch2/catch.hpp
# tests
-tbin = kaleidoscope_lexer.test
+tbin = kaleidoscope_lexer.test kaleidoscope_ast.test
# examples
ebin = kaleidoscope_lexer.ex
@@ -21,6 +21,12 @@ tests: $(tbin)
.PHONY: examples
examples: $(ebin)
+kaleidoscope_ast.o: kaleidoscope_ast.cpp kaleidoscope_ast.hpp \
+ kaleidoscope_codegen.hpp
+
+kaleidoscope_codegen.o: kaleidoscope_codegen.cpp kaleidoscope_codegen.hpp \
+ kaleidoscope_ast.hpp
+
#--- tests
kaleidoscope_lexer.test.o: CXXFLAGS += $(tflags)
@@ -30,6 +36,14 @@ kaleidoscope_lexer.test.o: tests/kaleidoscope_lexer.test.cpp
kaleidoscope_lexer.test: kaleidoscope_lexer.test.o
$(CXX) $(CXXFLAGS) -o $@ $< $(LDFLAGS)
+kaleidoscope_ast.test.o: CXXFLAGS += $(tflags)
+kaleidoscope_ast.test.o: kaleidoscope_ast.hpp $(catch2)
+kaleidoscope_ast.test.o: tests/kaleidoscope_ast.test.cpp
+ $(CXX) $(CXXFLAGS) -c -o $@ $<
+kaleidoscope_ast.test: kaleidoscope_ast.o kaleidoscope_codegen.o
+kaleidoscope_ast.test: kaleidoscope_ast.test.o
+ $(CXX) $(CXXFLAGS) -o $@ $^ $(LDFLAGS)
+
#--- examples
kaleidoscope_lexer.ex.o: CXXFLAGS += $(eflags)
diff --git a/kaleidoscope_ast.cpp b/kaleidoscope_ast.cpp
@@ -0,0 +1,71 @@
+
+#include "kaleidoscope_ast.hpp"
+
+#include <sstream>
+
+namespace kal {
+
+// Renders a Number as the s-expression "(number <value>)".
+// The value is written with the stream's default floating-point formatting.
+std::string
+to_string(const Number& n)
+{
+  std::ostringstream out;
+  out << "(number ";
+  out << n.value;
+  out << ')';
+
+  return out.str();
+}
+
+// Renders a Variable as the s-expression "(variable <name>)".
+std::string
+to_string(const Variable& v)
+{
+  std::string text = "(variable ";
+  text += v.name;
+  text += ')';
+
+  return text;
+}
+
+// Renders a BinaryOp as "(binop '<op>' <lhs> <rhs>)". Both operands are
+// ASTNode values and are rendered through the ASTNode to_string overload.
+std::string
+to_string(const BinaryOp& op)
+{
+  const std::string left = to_string(op.lhs);
+  const std::string right = to_string(op.rhs);
+
+  std::ostringstream out;
+  out << "(binop '" << op.op << "' " << left << ' ' << right << ')';
+
+  return out.str();
+}
+
+// Renders a Call as "(call <callee> <arg>...)". Each argument is preceded by
+// one space, so a call without arguments renders as "(call <callee>)".
+std::string
+to_string(const Call& c)
+{
+  std::string text = "(call ";
+  text += c.callee;
+  for (const ASTNode& arg : c.args) {
+    text += ' ';
+    text += to_string(arg);
+  }
+  text += ')';
+
+  return text;
+}
+
+// Renders a Prototype as "(prototype <name> <param>...)". Each parameter
+// name is preceded by one space.
+std::string
+to_string(const Prototype& p)
+{
+  std::string text = "(prototype ";
+  text += p.name;
+  for (const std::string& param : p.params) {
+    text += ' ';
+    text += param;
+  }
+  text += ')';
+
+  return text;
+}
+
+// Renders a Function as "(function <prototype> <body>)", reusing the
+// Prototype and ASTNode renderings for the two parts.
+std::string
+to_string(const Function& f)
+{
+  const std::string proto = to_string(f.proto);
+  const std::string body = to_string(f.body);
+
+  return "(function " + proto + ' ' + body + ')';
+}
+
+} // namespace kal
diff --git a/kaleidoscope_ast.hpp b/kaleidoscope_ast.hpp
@@ -0,0 +1,178 @@
+
+#pragma once
+
+#include "kaleidoscope_codegen.hpp"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Value;
+}
+
+namespace kal {
+
+// Textual rendering of each concrete AST node type (definitions live in
+// kaleidoscope_ast.cpp). They are declared ahead of ASTNode because
+// ASTNode::Model names kal::to_string in a using-declaration.
+std::string
+to_string(const Number&);
+
+std::string
+to_string(const Variable&);
+
+std::string
+to_string(const BinaryOp&);
+
+std::string
+to_string(const Call&);
+
+std::string
+to_string(const Prototype&);
+
+std::string
+to_string(const Function&);
+
+// Type-erased, value-semantic wrapper around any AST node type.
+// Technique: https://youtu.be/QGcVXgEVMJg?t=49m13s
+//
+// A type T can be stored when to_string(const T&) and codegen(const T&) are
+// callable for it. The wrapped value is reached only through a
+// shared_ptr-to-const, so copying an ASTNode copies just the pointer.
+class ASTNode
+{
+public:
+  // Takes `value` by value and moves it into a heap-allocated Model<T>.
+  // The constructor is implicit, which lets aggregates such as BinaryOp be
+  // initialised directly from Number, Variable, ... objects.
+  template<typename T>
+  ASTNode(T value)
+    : self_(std::make_shared<Model<T>>(std::move(value)))
+  {}
+
+  // Forwards to the to_string overload of the wrapped type.
+  friend std::string to_string(const ASTNode& node)
+  {
+    return node.self_->to_string();
+  }
+
+  // Forwards to the codegen overload of the wrapped type.
+  friend llvm::Value* codegen(const ASTNode& node)
+  {
+    return node.self_->codegen();
+  }
+
+  // Equality is currently defined as "both nodes render to the same text".
+  friend bool operator==(const ASTNode& a, const ASTNode& b)
+  {
+    const std::string left = a.self_->to_string();
+    const std::string right = b.self_->to_string();
+
+    return left == right; // FIXME
+  }
+
+  friend bool operator!=(const ASTNode& a, const ASTNode& b)
+  {
+    const bool same = (a == b);
+
+    return !same;
+  }
+
+private:
+  // Internal interface every wrapped node is adapted to.
+  struct Concept
+  {
+    virtual ~Concept() = default;
+    virtual std::string to_string() const = 0;
+    virtual llvm::Value* codegen() const = 0;
+  };
+
+  // Adapter that owns a T and routes the Concept interface to T's free
+  // functions.
+  template<typename T>
+  struct Model final : Concept
+  {
+    Model(T value)
+      : payload_(std::move(value))
+    {}
+
+    std::string to_string() const override
+    {
+      // Inside this member, the name to_string would otherwise resolve to
+      // the member itself; the using-declaration brings the free overloads
+      // back into scope (argument-dependent lookup still applies).
+      using kal::to_string; // FIXME
+
+      return to_string(payload_);
+    }
+
+    llvm::Value* codegen() const override
+    {
+      // Same name-hiding workaround as in to_string() above.
+      using kal::codegen; // FIXME
+
+      return codegen(payload_);
+    }
+
+    T payload_;
+  };
+
+  std::shared_ptr<const Concept> self_;
+};
+
+// AST node carrying a double value.
+struct Number
+{
+  double value;
+
+  // Exact comparison of the two stored doubles.
+  friend bool operator==(const Number& lhs, const Number& rhs)
+  {
+    return lhs.value == rhs.value; // CHECKME floating-point eq?
+  }
+
+  friend bool operator!=(const Number& lhs, const Number& rhs)
+  {
+    return !(lhs == rhs);
+  }
+};
+
+// AST node referring to a variable by name.
+struct Variable
+{
+  std::string name;
+
+  // Two Variable nodes are equal when their names are equal.
+  friend bool operator==(const Variable& lhs, const Variable& rhs)
+  {
+    return lhs.name == rhs.name;
+  }
+
+  friend bool operator!=(const Variable& lhs, const Variable& rhs)
+  {
+    return !(lhs == rhs);
+  }
+};
+
+// AST node for a binary operation: an operator character applied to a
+// left-hand and a right-hand operand.
+struct BinaryOp
+{
+  char op;     // operator character, e.g. '+'
+  ASTNode lhs; // left operand
+  ASTNode rhs; // right operand
+
+  // Two BinaryOp nodes are equal when the operator and both operands match.
+  friend bool operator==(const BinaryOp& x, const BinaryOp& y)
+  {
+    // Each operand of x is compared with the corresponding operand of y.
+    // (Comparing x.lhs with x.rhs would ignore y's operands entirely and
+    // make `op == op` false whenever the two operands differ.)
+    return x.op == y.op && x.lhs == y.lhs && x.rhs == y.rhs;
+  }
+
+  friend bool operator!=(const BinaryOp& x, const BinaryOp& y)
+  {
+    return !(x == y);
+  }
+};
+
+// AST node for a call: the callee's name plus the argument expressions.
+struct Call
+{
+  std::string callee;
+  std::vector<ASTNode> args;
+
+  // Equal when the callee names match and the argument lists compare equal
+  // element by element.
+  friend bool operator==(const Call& lhs, const Call& rhs)
+  {
+    if (lhs.callee != rhs.callee)
+      return false;
+
+    return lhs.args == rhs.args;
+  }
+
+  friend bool operator!=(const Call& lhs, const Call& rhs)
+  {
+    return !(lhs == rhs);
+  }
+};
+
+// AST node for a prototype: a name plus its parameter names.
+struct Prototype
+{
+  std::string name;
+  std::vector<std::string> params;
+
+  // Equal when both the name and the parameter name list match.
+  friend bool operator==(const Prototype& lhs, const Prototype& rhs)
+  {
+    if (lhs.name != rhs.name)
+      return false;
+
+    return lhs.params == rhs.params;
+  }
+
+  friend bool operator!=(const Prototype& lhs, const Prototype& rhs)
+  {
+    return !(lhs == rhs);
+  }
+};
+
+// AST node for a function definition: a prototype and a body expression.
+struct Function
+{
+  Prototype proto;
+  ASTNode body;
+
+  // Equal when the prototypes are equal and the bodies compare equal.
+  friend bool operator==(const Function& lhs, const Function& rhs)
+  {
+    if (lhs.proto != rhs.proto)
+      return false;
+
+    return lhs.body == rhs.body;
+  }
+
+  friend bool operator!=(const Function& lhs, const Function& rhs)
+  {
+    return !(lhs == rhs);
+  }
+};
+
+} // namespace kal
diff --git a/kaleidoscope_codegen.cpp b/kaleidoscope_codegen.cpp
@@ -0,0 +1,43 @@
+
+#include "kaleidoscope_codegen.hpp"
+
+#include "kaleidoscope_ast.hpp"
+
+#include <map>
+#include <string>
+
+// Stub: emits nothing for a Number and always returns a null pointer.
+llvm::Value*
+kal::codegen(const Number&)
+{
+  return nullptr;
+}
+
+// Stub: emits nothing for a Variable and always returns a null pointer.
+llvm::Value*
+kal::codegen(const Variable&)
+{
+  return nullptr;
+}
+
+// Stub: emits nothing for a BinaryOp and always returns a null pointer.
+llvm::Value*
+kal::codegen(const BinaryOp&)
+{
+  return nullptr;
+}
+
+// Stub: emits nothing for a Call and always returns a null pointer.
+llvm::Value*
+kal::codegen(const Call&)
+{
+  return nullptr;
+}
+
+// Stub: emits nothing for a Prototype and always returns a null pointer.
+llvm::Value*
+kal::codegen(const Prototype&)
+{
+  return nullptr;
+}
+
+// Stub: emits nothing for a Function and always returns a null pointer.
+llvm::Value*
+kal::codegen(const Function&)
+{
+  return nullptr;
+}
diff --git a/kaleidoscope_codegen.hpp b/kaleidoscope_codegen.hpp
@@ -0,0 +1,30 @@
+
+#pragma once
+
+namespace llvm {
+class Value;
+}
+
+namespace kal {
+
+// Forward declarations of the AST node types. Their definitions are in
+// kaleidoscope_ast.hpp, which includes this header, so only the names are
+// introduced here.
+struct Number;
+struct Variable;
+struct BinaryOp;
+struct Call;
+struct Prototype;
+struct Function;
+
+// Code generation entry points, one overload per AST node type; defined in
+// kaleidoscope_codegen.cpp. ASTNode dispatches to these through its
+// internal Model<T>::codegen().
+llvm::Value*
+codegen(const Number&);
+llvm::Value*
+codegen(const Variable&);
+llvm::Value*
+codegen(const BinaryOp&);
+llvm::Value*
+codegen(const Call&);
+llvm::Value*
+codegen(const Prototype&);
+llvm::Value*
+codegen(const Function&);
+
+} // namespace kal
diff --git a/tests/kaleidoscope_ast.test.cpp b/tests/kaleidoscope_ast.test.cpp
@@ -0,0 +1,106 @@
+
+#include "kaleidoscope_ast.hpp"
+
+#define CATCH_CONFIG_MAIN
+#define CATCH_CONFIG_COLOUR_NONE
+#include <catch2/catch.hpp>
+
+#include <array>
+
+// Exercises construction of every AST node type, field access, comparison
+// through ASTNode and the to_string renderings.
+TEST_CASE("to_string", "[str]")
+{
+ // ASTNode must stay exactly as large as one shared_ptr.
+ SECTION("static_assert the size of ASTNode")
+ {
+ static_assert(sizeof(kal::ASTNode) == sizeof(std::shared_ptr<int>), "");
+ }
+
+ SECTION("Number")
+ {
+ std::array<kal::Number, 5> n{ 0, 1, -1, 3.14159265, -3.1415926535 };
+
+ REQUIRE(n[0].value == 0);
+ REQUIRE(n[1].value == 1);
+ REQUIRE(n[2].value == -1);
+ REQUIRE(n[3].value == 3.14159265);
+ REQUIRE(n[4].value == -3.1415926535);
+
+ REQUIRE(kal::to_string(n[0]) == "(number 0)");
+ REQUIRE(kal::to_string(n[1]) == "(number 1)");
+ REQUIRE(kal::to_string(n[2]) == "(number -1)");
+ // FIXME: to_string(Number) streams the double with the default stream
+ // precision, so only 6 significant digits survive in the text.
+ REQUIRE(kal::to_string(n[3]) == "(number 3.14159)");
+ REQUIRE(kal::to_string(n[4]) == "(number -3.14159)"); // FIXME same truncation
+ }
+
+ SECTION("Variable")
+ {
+ kal::Variable v{ "var" };
+
+ REQUIRE(v.name == "var");
+ }
+
+ SECTION("BinaryOp")
+ {
+ kal::Number n0{ 9.8 };
+ kal::Number n1{ -2.72 };
+ kal::BinaryOp op01{ '+', n0, n1 };
+
+ // lhs/rhs are ASTNode values: the right-hand side of each comparison is
+ // implicitly converted to ASTNode and compared via ASTNode's operator==.
+ REQUIRE(op01.op == '+');
+ REQUIRE(op01.lhs == n0);
+ REQUIRE(op01.rhs == n1);
+
+ kal::Variable v0{ "var0" };
+ kal::BinaryOp op02{ '*', v0, op01 };
+
+ REQUIRE(op02.op == '*');
+ REQUIRE(op02.lhs == v0);
+ REQUIRE(op02.rhs == op01);
+
+ // Operands may themselves be binary operations.
+ kal::BinaryOp op03{ '/', op01, op02 };
+
+ REQUIRE(op03.op == '/');
+ REQUIRE(op03.lhs == op01);
+ REQUIRE(op03.rhs == op02);
+ }
+
+ SECTION("Call")
+ {
+ kal::Call c0{ "pi", {} };
+ kal::Call c1{ "tan2", { kal::Number{ 3.14 / 6 }, kal::Variable{ "y" } } };
+
+ REQUIRE(kal::to_string(c0) == "(call pi)");
+ // FIXME: the expected text shows only 6 significant digits of 3.14 / 6
+ // because to_string(Number) uses the default stream precision.
+ REQUIRE(kal::to_string(c1) == "(call tan2 (number 0.523333) (variable y))");
+ }
+
+ SECTION("Prototype")
+ {
+ kal::Prototype p{ "memcmp", { "dest", "src", "n" } };
+
+ REQUIRE(p.name == "memcmp");
+ REQUIRE(p.params.size() == 3);
+ REQUIRE(p.params.at(0) == "dest");
+ REQUIRE(p.params.at(1) == "src");
+ REQUIRE(p.params.at(2) == "n");
+ }
+
+ SECTION("Function")
+ {
+ kal::Function fun{
+ kal::Prototype{ "fun", { "x", "y", "z" } },
+ // Body in prefix form: (- (* b b) (* (* 4 a) c))
+ kal::BinaryOp{
+ '-',
+ kal::BinaryOp{ '*', kal::Variable{ "b" }, kal::Variable{ "b" } },
+ kal::BinaryOp{
+ '*',
+ kal::BinaryOp{ '*', kal::Number{ 4 }, kal::Variable{ "a" } },
+ kal::Variable{ "c" },
+ } }
+ };
+
+ // The expected string is split across adjacent literals only for line
+ // length; it is a single rendering of the whole function.
+ REQUIRE(kal::to_string(fun) ==
+ "(function (prototype fun x y z) (binop '-' (binop '*' (variable "
+ "b) (variable b)) (binop '*' (binop '*' (number 4) (variable a)) "
+ "(variable c))))");
+ }
+}