llvm-journey

LLVM Journey
git clone git://0xff.ir/g/llvm-journey.git
Log | Files | Refs | README | LICENSE

kaleidoscope_codegen.test.cpp (8601B)


      1 
      2 #include "kaleidoscope_codegen.hpp"
      3 
      4 #define CATCH_CONFIG_MAIN
      5 #include <catch2/catch.hpp>
      6 
      7 #include <sstream>
      8 #include <string>
      9 
     10 #include <llvm/IR/Function.h>
     11 #include <llvm/Support/raw_ostream.h>
     12 
     13 #include "kaleidoscope_parser.hpp"
     14 
     15 using kal::cast;
     16 using kal::to_string;
     17 
     18 struct Parsed
     19 {
     20   std::vector<kal::Prototype> decls;
     21   std::vector<kal::Function> defs;
     22   std::vector<kal::ASTNode> stmts;
     23 };
     24 
     25 namespace {
     26 
     27 std::string
     28 to_string(llvm::Function* f)
     29 {
     30   std::string code;
     31   llvm::raw_string_ostream out{ code };
     32 
     33   f->print(out);
     34   return code;
     35 }
     36 
     37 } // anonymous namespace
     38 
     39 TEST_CASE("Code generation for simple programs", "[simple]")
     40 {
     41   // FIXME move to kaleidoscope_parser.hpp
     42   auto p = [](const std::string& s) {
     43     std::vector<kal::Token> tk;
     44     Parsed parsed;
     45 
     46     kal::tokenize(s.cbegin(), s.cend(), std::back_inserter(tk));
     47     tk.erase(std::remove_if(
     48                tk.begin(),
     49                tk.end(),
     50                [](const auto& t) { return t.type == kal::TkType::Comment; }),
     51              tk.end());
     52 
     53     auto tkb = tk.cbegin();
     54     auto tke = tk.cend();
     55     std::vector<kal::ParseError<decltype(tkb)>> errs;
     56 
     57     kal::parse(
     58       tkb,
     59       tke,
     60       [&](auto type, const auto& node) {
     61         switch (type) {
     62           case kal::ParsedEntityType::FuncDecl: {
     63             kal::Prototype proto;
     64             auto ok = cast(node, proto);
     65 
     66             assert(ok && "invalid cast: expects kal::Prototype");
     67 
     68             parsed.decls.emplace_back(std::move(proto));
     69           } break;
     70 
     71           case kal::ParsedEntityType::FuncDef: {
     72             kal::Function func;
     73             auto ok = cast(node, func);
     74 
     75             assert(ok && "invalid cast: expects kal::Function");
     76 
     77             parsed.defs.emplace_back(std::move(func));
     78           } break;
     79 
     80           case kal::ParsedEntityType::Stmt:
     81             parsed.stmts.emplace_back(std::move(node));
     82             break;
     83         }
     84       },
     85       std::back_inserter(errs));
     86 
     87     if (errs.empty())
     88       return parsed;
     89 
     90     std::ostringstream oss;
     91 
     92     for (auto& e : errs) {
     93       oss << "  Error@" << std::distance(tkb, e.pos) << ' ' << e.msg;
     94 
     95       if (e.pos != tke)
     96         oss << '\n' << kal::to_string(*e.pos) << "\n";
     97     }
     98 
     99     FAIL("PARSER ERROR\n" << oss.str());
    100     return parsed; // dummy
    101   };
    102 
    103   SECTION("extern")
    104   {
    105     {
    106       auto pd = p("extern sin(x)");
    107 
    108       REQUIRE(pd.decls.size() == 1);
    109       REQUIRE(to_string(kal::codegen(pd.decls[0])) ==
    110               "declare double @sin(double)\n");
    111 
    112       kal::codegen_reset();
    113     }
    114 
    115     {
    116       auto pd = p("extern tan2 ( arg0 arg1 )");
    117 
    118       REQUIRE(pd.decls.size() == 1);
    119       REQUIRE(to_string(kal::codegen(pd.decls[0])) ==
    120               "declare double @tan2(double, double)\n");
    121 
    122       kal::codegen_reset();
    123     }
    124 
    125     {
    126       auto pd = p("extern cos(realInput);extern atan2(arg0 arg1);");
    127 
    128       REQUIRE(pd.decls.size() == 2);
    129       REQUIRE(to_string(kal::codegen(pd.decls[0])) ==
    130               "declare double @cos(double)\n");
    131       REQUIRE(to_string(kal::codegen(pd.decls[1])) ==
    132               "declare double @atan2(double, double)\n");
    133 
    134       kal::codegen_reset();
    135     }
    136   }
    137 
    138   SECTION("def")
    139   {
    140     {
    141       auto pd = p(R"(
    142 def one(x)
    143   1
    144 )");
    145 
    146       REQUIRE(pd.defs.size() == 1);
    147 
    148       auto c = kal::codegen(pd.defs[0]);
    149       auto cstr = to_string(c);
    150 
    151       REQUIRE(cstr ==
    152               "define double @one(double %x) {\n"
    153               "entry:\n"
    154               "  ret double 1.000000e+00\n"
    155               "}\n");
    156 
    157       kal::codegen_reset();
    158     }
    159 
    160     {
    161       auto pd = p("extern sin(x) def pi2() 1.5708");
    162 
    163       REQUIRE(pd.decls.size() == 1);
    164       REQUIRE(pd.defs.size() == 1);
    165 
    166       auto cdecl = kal::codegen(pd.decls[0]);
    167       auto cdeclstr = to_string(cdecl);
    168       auto cdef = kal::codegen(pd.defs[0]);
    169       auto cdefstr = to_string(cdef);
    170 
    171       REQUIRE(cdeclstr == "declare double @sin(double)\n");
    172       REQUIRE(cdefstr ==
    173               "define double @pi2() {\n"
    174               "entry:\n"
    175               "  ret double 1.570800e+00\n"
    176               "}\n");
    177 
    178       kal::codegen_reset();
    179     }
    180 
    181     {
    182       auto pd =
    183         p("extern tan2(x y);extern sin(x);def fun()tan2(sin(1.5708), 1);"
    184           "def gun(x y z)tan2(0.1,gun(1,sin(1.5),1.0));");
    185       std::vector<std::string> defstrs;
    186 
    187       for (const auto& d : pd.decls)
    188         kal::codegen(d);
    189 
    190       REQUIRE(pd.defs.size() == 2);
    191 
    192       for (const auto& d : pd.defs)
    193         defstrs.emplace_back(to_string(kal::codegen(d)));
    194 
    195       REQUIRE(defstrs[0] ==
    196               "define double @fun() {\n"
    197               "entry:\n"
    198               "  %calltmp = call double @sin(double 1.570800e+00)\n"
    199               "  %calltmp1 = call double @tan2(double %calltmp, double "
    200               "1.000000e+00)\n"
    201               "  ret double %calltmp1\n"
    202               "}\n");
    203       REQUIRE(defstrs[1] ==
    204               "define double @gun(double %x, double %y, double %z) {\n"
    205               "entry:\n"
    206               "  %calltmp = call double @sin(double 1.500000e+00)\n"
    207               "  %calltmp1 = call double @gun(double 1.000000e+00, double "
    208               "%calltmp, double 1.000000e+00)\n"
    209               "  %calltmp2 = call double @tan2(double 1.000000e-01, double "
    210               "%calltmp1)\n"
    211               "  ret double %calltmp2\n"
    212               "}\n");
    213 
    214       kal::codegen_reset();
    215       defstrs.clear();
    216 
    217       for (const auto& d : pd.decls)
    218         kal::codegen(d);
    219       for (const auto& d : pd.defs)
    220         defstrs.emplace_back(to_string(kal::codegen(d, true)));
    221 
    222       REQUIRE(
    223         defstrs[0] ==
    224         "define double @fun() {\n"
    225         "entry:\n"
    226         "  %calltmp1 = call double @tan2(double 0x3FEFFFFFFFFF12A3, double "
    227         "1.000000e+00)\n"
    228         "  ret double %calltmp1\n"
    229         "}\n");
    230       REQUIRE(defstrs[1] ==
    231               "define double @gun(double %x, double %y, double %z) {\n"
    232               "entry:\n"
    233               "  %calltmp1 = call double @gun(double 1.000000e+00, double "
    234               "0x3FEFEB7A9B2C6D8B, double 1.000000e+00)\n"
    235               "  %calltmp2 = call double @tan2(double 1.000000e-01, double "
    236               "%calltmp1)\n"
    237               "  ret double %calltmp2\n"
    238               "}\n");
    239 
    240       kal::codegen_reset();
    241     }
    242   }
    243 
    244   SECTION("binary operation")
    245   {
    246     {
    247       auto pd = p("extern sin(x); def f() sin(3.14/4)-2.3-3.4-(3*4+1)*3");
    248 
    249       REQUIRE(pd.decls.size() == 1);
    250       REQUIRE(pd.defs.size() == 1);
    251 
    252       kal::codegen(pd.decls[0]);
    253       REQUIRE(to_string(kal::codegen(pd.defs[0])) ==
    254               "define double @f() {\n"
    255               "entry:\n"
    256               "  %calltmp = call double @sin(double 7.850000e-01)\n"
    257               "  %subtmp = fsub double %calltmp, 2.300000e+00\n"
    258               "  %subtmp1 = fsub double %subtmp, 3.400000e+00\n"
    259               "  %subtmp2 = fsub double %subtmp1, 3.900000e+01\n"
    260               "  ret double %subtmp2\n"
    261               "}\n");
    262 
    263       kal::codegen_reset();
    264 
    265       kal::codegen(pd.decls[0]);
    266       REQUIRE(to_string(kal::codegen(pd.defs[0], true)) ==
    267               "define double @f() {\n"
    268               "entry:\n"
    269               "  ret double 0xC045FF205A3B2E7C\n"
    270               "}\n");
    271 
    272       kal::codegen_reset();
    273     }
    274   }
    275 
    276   SECTION("extern, def, call, stmts")
    277   {
    278     {
    279       auto pd = p(R"(
    280 extern tan2(x y)
    281 
    282 def formula(x y z)
    283   3.14 + 2 * (tan2(x, y) + x*y) / z
    284 
    285 extern sin(x)
    286 
    287 sin(2.72) + formula(1, 2, 3)
    288 )");
    289 
    290       REQUIRE(pd.decls.size() == 2);
    291       REQUIRE(pd.defs.size() == 1);
    292       REQUIRE(pd.stmts.size() == 1);
    293 
    294       for (const auto& d : pd.decls)
    295         kal::codegen(d);
    296       for (const auto& d : pd.defs)
    297         kal::codegen(d);
    298 
    299       {
    300         auto f = pd.stmts.data();
    301         auto l = f + pd.stmts.size();
    302 
    303         REQUIRE(to_string(kal::mkfunc("kaleidoscope_body_0__", f, l)) ==
    304                 "define double @kaleidoscope_body_0__() {\n"
    305                 "entry:\n"
    306                 "  %calltmp = call double @sin(double 2.720000e+00)\n"
    307                 "  %calltmp1 = call double @formula(double 1.000000e+00, "
    308                 "double 2.000000e+00, double 3.000000e+00)\n"
    309                 "  %addtmp = fadd double %calltmp, %calltmp1\n"
    310                 "  ret double 0.000000e+00\n"
    311                 "}\n");
    312 
    313         REQUIRE(to_string(kal::mkfunc("kaleidoscope_body_1__", f, l, true)) ==
    314                 "define double @kaleidoscope_body_1__() {\n"
    315                 "entry:\n"
    316                 "  %calltmp1 = call double @formula(double 1.000000e+00, "
    317                 "double 2.000000e+00, double 3.000000e+00)\n"
    318                 "  ret double 0.000000e+00\n"
    319                 "}\n");
    320       }
    321 
    322       kal::codegen_reset();
    323     }
    324   }
    325 }