llvm-journey

LLVM Journey
git clone git://0xff.ir/g/llvm-journey.git
Log | Files | Refs | README | LICENSE

kaleidoscope_parser.test.cpp (7734B)


      1 
#include "kaleidoscope_parser.hpp"

#define CATCH_CONFIG_MAIN
#define CATCH_CONFIG_COLOUR_NONE
#include <catch2/catch.hpp>

#include <algorithm>
#include <iterator>
#include <sstream>
#include <string>
#include <vector>

#include "kaleidoscope_ast.hpp"
#include "kaleidoscope_tokens.hpp"

using kal::cast;
using kal::to_string;
     17 
     18 struct Parsed
     19 {
     20   std::vector<kal::Prototype> decls;
     21   std::vector<kal::Function> defs;
     22   std::vector<kal::ASTNode> stmts;
     23 };
     24 
     25 TEST_CASE("parse simple programs", "[simple]")
     26 {
     27   auto p = [](const std::string& s) {
     28     std::vector<kal::Token> tk;
     29     Parsed parsed;
     30 
     31     kal::tokenize(s.cbegin(), s.cend(), std::back_inserter(tk));
     32     tk.erase(std::remove_if(
     33                tk.begin(),
     34                tk.end(),
     35                [](const auto& t) { return t.type == kal::TkType::Comment; }),
     36              tk.end());
     37 
     38     auto tkb = tk.cbegin();
     39     auto tke = tk.cend();
     40     std::vector<kal::ParseError<decltype(tkb)>> errs;
     41 
     42     kal::parse(
     43       tkb,
     44       tke,
     45       [&](auto type, const auto& node) {
     46         switch (type) {
     47           case kal::ParsedEntityType::FuncDecl: {
     48             kal::Prototype proto;
     49             auto ok = cast(node, proto);
     50 
     51             assert(ok && "invalid cast: expects kal::Prototype");
     52 
     53             parsed.decls.emplace_back(std::move(proto));
     54           } break;
     55 
     56           case kal::ParsedEntityType::FuncDef: {
     57             kal::Function func;
     58             auto ok = cast(node, func);
     59 
     60             assert(ok && "invalid cast: expects kal::Function");
     61 
     62             parsed.defs.emplace_back(std::move(func));
     63           } break;
     64 
     65           case kal::ParsedEntityType::Stmt:
     66             parsed.stmts.emplace_back(std::move(node));
     67             break;
     68         }
     69       },
     70       std::back_inserter(errs));
     71 
     72     if (errs.empty())
     73       return parsed;
     74 
     75     std::ostringstream oss;
     76 
     77     for (auto& e : errs) {
     78       oss << "  Error@" << std::distance(tkb, e.pos) << ' ' << e.msg;
     79 
     80       if (e.pos != tke)
     81         oss << '\n' << kal::to_string(*e.pos) << "\n";
     82     }
     83 
     84     FAIL("PARSER ERROR\n" << oss.str());
     85     return parsed; // dummy
     86   };
     87 
     88   SECTION("extern")
     89   {
     90     {
     91       auto pd = p("extern sin(x)");
     92 
     93       REQUIRE(pd.decls.size() == 1);
     94       REQUIRE(pd.defs.empty());
     95       REQUIRE(pd.stmts.empty());
     96       REQUIRE(to_string(pd.decls[0]) == "(prototype sin x)");
     97     }
     98 
     99     {
    100       auto pd = p("extern tan2 ( arg0 arg1 )");
    101 
    102       REQUIRE(pd.decls.size() == 1);
    103       REQUIRE(pd.defs.empty());
    104       REQUIRE(pd.stmts.empty());
    105       REQUIRE(to_string(pd.decls[0]) == "(prototype tan2 arg0 arg1)");
    106     }
    107 
    108     {
    109       auto pd = p("extern cos(realInput);extern atan2(arg0 arg1);");
    110 
    111       REQUIRE(pd.decls.size() == 2);
    112       REQUIRE(pd.defs.empty());
    113       REQUIRE(pd.stmts.empty());
    114       REQUIRE(to_string(pd.decls[0]) == "(prototype cos realInput)");
    115       REQUIRE(to_string(pd.decls[1]) == "(prototype atan2 arg0 arg1)");
    116     }
    117   }
    118 
    119   SECTION("def")
    120   {
    121     {
    122       auto pd = p(R"(
    123 def one(x)
    124   1
    125 )");
    126 
    127       REQUIRE(pd.decls.empty());
    128       REQUIRE(pd.defs.size() == 1);
    129       REQUIRE(to_string(pd.defs[0]) ==
    130               "(function (prototype one x) (number 1))");
    131     }
    132 
    133     {
    134       auto pd = p("extern sin(x) def pi2() 1.5708");
    135 
    136       REQUIRE(pd.decls.size() == 1);
    137       REQUIRE(pd.defs.size() == 1);
    138       REQUIRE(to_string(pd.decls[0]) == "(prototype sin x)");
    139       REQUIRE(to_string(pd.defs[0]) ==
    140               "(function (prototype pi2) (number 1.5708))");
    141     }
    142 
    143     {
    144       auto pd = p("extern sin(x) def sinPi2() sin(1.5708)");
    145 
    146       REQUIRE(pd.decls.size() == 1);
    147       REQUIRE(pd.defs.size() == 1);
    148       REQUIRE(to_string(pd.decls[0]) == "(prototype sin x)");
    149       REQUIRE(to_string(pd.defs[0]) ==
    150               "(function (prototype sinPi2) (call sin (number 1.5708)))");
    151     }
    152 
    153     {
    154       auto pd =
    155         p("extern tan2(x y);extern sin(x);def fun()tan2(sin(1.5708), 1);"
    156           "def gun(x y z)tan2(0.1,gun(1,sin(1.5),1.0));");
    157 
    158       REQUIRE(pd.decls.size() == 2);
    159       REQUIRE(pd.defs.size() == 2);
    160       REQUIRE(to_string(pd.decls[0]) == "(prototype tan2 x y)");
    161       REQUIRE(to_string(pd.decls[1]) == "(prototype sin x)");
    162       REQUIRE(to_string(pd.defs[0]) ==
    163               "(function (prototype fun) (call tan2 (call sin (number 1.5708)) "
    164               "(number 1)))");
    165       REQUIRE(to_string(pd.defs[1]) ==
    166               "(function (prototype gun x y z) (call tan2 (number 0.1) (call "
    167               "gun (number 1) (call sin (number 1.5)) (number 1))))");
    168     }
    169   }
    170 
    171   SECTION("binary operation")
    172   {
    173     using Num = kal::Number;
    174     using BOp = kal::BinaryOp;
    175     using Var = kal::Variable;
    176 
    177     {
    178       auto pd = p("2 + 3");
    179 
    180       REQUIRE(pd.decls.empty());
    181       REQUIRE(pd.defs.empty());
    182       REQUIRE(pd.stmts.size() == 1);
    183       REQUIRE(pd.stmts[0] == BOp{ '+', Num{ 2 }, Num{ 3 } });
    184       REQUIRE(to_string(pd.stmts[0]) == "(binop '+' (number 2) (number 3))");
    185     }
    186 
    187     {
    188       auto pd = p("x * y");
    189 
    190       REQUIRE(pd.decls.empty());
    191       REQUIRE(pd.defs.empty());
    192       REQUIRE(pd.stmts.size() == 1);
    193       REQUIRE(pd.stmts[0] == BOp{ '*', Var{ "x" }, Var{ "y" } });
    194     }
    195 
    196     {
    197       auto pd = p("x * y * z");
    198 
    199       REQUIRE(pd.decls.empty());
    200       REQUIRE(pd.defs.empty());
    201       REQUIRE(pd.stmts.size() == 1);
    202       REQUIRE(pd.stmts[0] ==
    203               BOp{ '*', BOp{ '*', Var{ "x" }, Var{ "y" } }, Var{ "z" } });
    204     }
    205 
    206     {
    207       auto pd = p("4 * 5 + 3");
    208 
    209       REQUIRE(pd.decls.empty());
    210       REQUIRE(pd.defs.empty());
    211       REQUIRE(pd.stmts.size() == 1);
    212       REQUIRE(pd.stmts[0] ==
    213               BOp{ '+', BOp{ '*', Num{ 4 }, Num{ 5 } }, Num{ 3 } });
    214     }
    215 
    216     {
    217       auto pd = p("4+2.15*3");
    218 
    219       REQUIRE(pd.decls.empty());
    220       REQUIRE(pd.defs.empty());
    221       REQUIRE(pd.stmts.size() == 1);
    222       REQUIRE(pd.stmts[0] == BOp{
    223                                '+',
    224                                Num{ 4 },
    225                                BOp{ '*', Num{ 2.15 }, Num{ 3 } },
    226                              });
    227     }
    228 
    229     {
    230       auto pd = p("4+2.15*3/2-4.4/2/3#expr");
    231 
    232       REQUIRE(pd.decls.empty());
    233       REQUIRE(pd.defs.empty());
    234       REQUIRE(pd.stmts.size() == 1);
    235       REQUIRE(pd.stmts[0] ==
    236               BOp{ '-',
    237                    BOp{
    238                      '+',
    239                      Num{ 4 },
    240                      BOp{ '/', BOp{ '*', Num{ 2.15 }, Num{ 3 } }, Num{ 2 } },
    241                    },
    242                    BOp{ '/', BOp{ '/', Num{ 4.4 }, Num{ 2 } }, Num{ 3 } } });
    243     }
    244 
    245     {
    246       auto pd = p("sin(3.14/4)-2.3-3.4-(3*4+1)*3");
    247 
    248       REQUIRE(pd.decls.empty());
    249       REQUIRE(pd.defs.empty());
    250       REQUIRE(pd.stmts.size() == 1);
    251       REQUIRE(
    252         to_string(pd.stmts[0]) ==
    253         "(binop '-' (binop '-' (binop '-' (call sin (binop '/' (number 3.14) "
    254         "(number 4))) (number 2.3)) (number 3.4)) (binop '*' (binop '+' (binop "
    255         "'*' (number 3) (number 4)) (number 1)) (number 3)))");
    256     }
    257   }
    258 
    259   SECTION("extern, def, call")
    260   {
    261     {
    262       auto pd = p(R"(
    263 extern tan2(x y)
    264 
    265 def formula(x y z)
    266   3.14 + 2 * (tan2(x, y) + x*y) / z
    267 
    268 extern sin(x)
    269 
    270 sin(2.72) + formula(1, 2, 3)
    271 )");
    272 
    273       REQUIRE(pd.decls.size() == 2);
    274       REQUIRE(pd.defs.size() == 1);
    275       REQUIRE(pd.stmts.size() == 1);
    276 
    277       REQUIRE(to_string(pd.decls[0]) == "(prototype tan2 x y)");
    278       REQUIRE(to_string(pd.decls[1]) == "(prototype sin x)");
    279 
    280       REQUIRE(to_string(pd.defs[0]) ==
    281               "(function (prototype formula x y z) (binop '+' (number 3.14) "
    282               "(binop '/' (binop '*' (number 2) (binop '+' (call tan2 "
    283               "(variable x) (variable y)) (binop '*' (variable x) (variable "
    284               "y)))) (variable z))))");
    285 
    286       REQUIRE(to_string(pd.stmts[0]) ==
    287               "(binop '+' (call sin (number 2.72)) (call formula (number 1) "
    288               "(number 2) (number 3)))");
    289     }
    290   }
    291 }