kaleidoscope_codegen.test.cpp (8601B)
1 2 #include "kaleidoscope_codegen.hpp" 3 4 #define CATCH_CONFIG_MAIN 5 #include <catch2/catch.hpp> 6 7 #include <sstream> 8 #include <string> 9 10 #include <llvm/IR/Function.h> 11 #include <llvm/Support/raw_ostream.h> 12 13 #include "kaleidoscope_parser.hpp" 14 15 using kal::cast; 16 using kal::to_string; 17 18 struct Parsed 19 { 20 std::vector<kal::Prototype> decls; 21 std::vector<kal::Function> defs; 22 std::vector<kal::ASTNode> stmts; 23 }; 24 25 namespace { 26 27 std::string 28 to_string(llvm::Function* f) 29 { 30 std::string code; 31 llvm::raw_string_ostream out{ code }; 32 33 f->print(out); 34 return code; 35 } 36 37 } // anonymous namespace 38 39 TEST_CASE("Code generation for simple programs", "[simple]") 40 { 41 // FIXME move to kaleidoscope_parser.hpp 42 auto p = [](const std::string& s) { 43 std::vector<kal::Token> tk; 44 Parsed parsed; 45 46 kal::tokenize(s.cbegin(), s.cend(), std::back_inserter(tk)); 47 tk.erase(std::remove_if( 48 tk.begin(), 49 tk.end(), 50 [](const auto& t) { return t.type == kal::TkType::Comment; }), 51 tk.end()); 52 53 auto tkb = tk.cbegin(); 54 auto tke = tk.cend(); 55 std::vector<kal::ParseError<decltype(tkb)>> errs; 56 57 kal::parse( 58 tkb, 59 tke, 60 [&](auto type, const auto& node) { 61 switch (type) { 62 case kal::ParsedEntityType::FuncDecl: { 63 kal::Prototype proto; 64 auto ok = cast(node, proto); 65 66 assert(ok && "invalid cast: expects kal::Prototype"); 67 68 parsed.decls.emplace_back(std::move(proto)); 69 } break; 70 71 case kal::ParsedEntityType::FuncDef: { 72 kal::Function func; 73 auto ok = cast(node, func); 74 75 assert(ok && "invalid cast: expects kal::Function"); 76 77 parsed.defs.emplace_back(std::move(func)); 78 } break; 79 80 case kal::ParsedEntityType::Stmt: 81 parsed.stmts.emplace_back(std::move(node)); 82 break; 83 } 84 }, 85 std::back_inserter(errs)); 86 87 if (errs.empty()) 88 return parsed; 89 90 std::ostringstream oss; 91 92 for (auto& e : errs) { 93 oss << " Error@" << std::distance(tkb, e.pos) << ' ' << e.msg; 94 95 if (e.pos != tke) 96 oss << '\n' << kal::to_string(*e.pos) << "\n"; 97 } 98 99 FAIL("PARSER ERROR\n" << oss.str()); 100 return parsed; // dummy 101 }; 102 103 SECTION("extern") 104 { 105 { 106 auto pd = p("extern sin(x)"); 107 108 REQUIRE(pd.decls.size() == 1); 109 REQUIRE(to_string(kal::codegen(pd.decls[0])) == 110 "declare double @sin(double)\n"); 111 112 kal::codegen_reset(); 113 } 114 115 { 116 auto pd = p("extern tan2 ( arg0 arg1 )"); 117 118 REQUIRE(pd.decls.size() == 1); 119 REQUIRE(to_string(kal::codegen(pd.decls[0])) == 120 "declare double @tan2(double, double)\n"); 121 122 kal::codegen_reset(); 123 } 124 125 { 126 auto pd = p("extern cos(realInput);extern atan2(arg0 arg1);"); 127 128 REQUIRE(pd.decls.size() == 2); 129 REQUIRE(to_string(kal::codegen(pd.decls[0])) == 130 "declare double @cos(double)\n"); 131 REQUIRE(to_string(kal::codegen(pd.decls[1])) == 132 "declare double @atan2(double, double)\n"); 133 134 kal::codegen_reset(); 135 } 136 } 137 138 SECTION("def") 139 { 140 { 141 auto pd = p(R"( 142 def one(x) 143 1 144 )"); 145 146 REQUIRE(pd.defs.size() == 1); 147 148 auto c = kal::codegen(pd.defs[0]); 149 auto cstr = to_string(c); 150 151 REQUIRE(cstr == 152 "define double @one(double %x) {\n" 153 "entry:\n" 154 " ret double 1.000000e+00\n" 155 "}\n"); 156 157 kal::codegen_reset(); 158 } 159 160 { 161 auto pd = p("extern sin(x) def pi2() 1.5708"); 162 163 REQUIRE(pd.decls.size() == 1); 164 REQUIRE(pd.defs.size() == 1); 165 166 auto cdecl = kal::codegen(pd.decls[0]); 167 auto cdeclstr = to_string(cdecl); 168 auto cdef = kal::codegen(pd.defs[0]); 169 auto cdefstr = to_string(cdef); 170 171 REQUIRE(cdeclstr == "declare double @sin(double)\n"); 172 REQUIRE(cdefstr == 173 "define double @pi2() {\n" 174 "entry:\n" 175 " ret double 1.570800e+00\n" 176 "}\n"); 177 178 kal::codegen_reset(); 179 } 180 181 { 182 auto pd = 183 p("extern tan2(x y);extern sin(x);def fun()tan2(sin(1.5708), 1);" 184 "def gun(x y z)tan2(0.1,gun(1,sin(1.5),1.0));"); 185 std::vector<std::string> defstrs; 186 187 for (const auto& d : pd.decls) 188 kal::codegen(d); 189 190 REQUIRE(pd.defs.size() == 2); 191 192 for (const auto& d : pd.defs) 193 defstrs.emplace_back(to_string(kal::codegen(d))); 194 195 REQUIRE(defstrs[0] == 196 "define double @fun() {\n" 197 "entry:\n" 198 " %calltmp = call double @sin(double 1.570800e+00)\n" 199 " %calltmp1 = call double @tan2(double %calltmp, double " 200 "1.000000e+00)\n" 201 " ret double %calltmp1\n" 202 "}\n"); 203 REQUIRE(defstrs[1] == 204 "define double @gun(double %x, double %y, double %z) {\n" 205 "entry:\n" 206 " %calltmp = call double @sin(double 1.500000e+00)\n" 207 " %calltmp1 = call double @gun(double 1.000000e+00, double " 208 "%calltmp, double 1.000000e+00)\n" 209 " %calltmp2 = call double @tan2(double 1.000000e-01, double " 210 "%calltmp1)\n" 211 " ret double %calltmp2\n" 212 "}\n"); 213 214 kal::codegen_reset(); 215 defstrs.clear(); 216 217 for (const auto& d : pd.decls) 218 kal::codegen(d); 219 for (const auto& d : pd.defs) 220 defstrs.emplace_back(to_string(kal::codegen(d, true))); 221 222 REQUIRE( 223 defstrs[0] == 224 "define double @fun() {\n" 225 "entry:\n" 226 " %calltmp1 = call double @tan2(double 0x3FEFFFFFFFFF12A3, double " 227 "1.000000e+00)\n" 228 " ret double %calltmp1\n" 229 "}\n"); 230 REQUIRE(defstrs[1] == 231 "define double @gun(double %x, double %y, double %z) {\n" 232 "entry:\n" 233 " %calltmp1 = call double @gun(double 1.000000e+00, double " 234 "0x3FEFEB7A9B2C6D8B, double 1.000000e+00)\n" 235 " %calltmp2 = call double @tan2(double 1.000000e-01, double " 236 "%calltmp1)\n" 237 " ret double %calltmp2\n" 238 "}\n"); 239 240 kal::codegen_reset(); 241 } 242 } 243 244 SECTION("binary operation") 245 { 246 { 247 auto pd = p("extern sin(x); def f() sin(3.14/4)-2.3-3.4-(3*4+1)*3"); 248 249 REQUIRE(pd.decls.size() == 1); 250 REQUIRE(pd.defs.size() == 1); 251 252 kal::codegen(pd.decls[0]); 253 REQUIRE(to_string(kal::codegen(pd.defs[0])) == 254 "define double @f() {\n" 255 "entry:\n" 256 " %calltmp = call double @sin(double 7.850000e-01)\n" 257 " %subtmp = fsub double %calltmp, 2.300000e+00\n" 258 " %subtmp1 = fsub double %subtmp, 3.400000e+00\n" 259 " %subtmp2 = fsub double %subtmp1, 3.900000e+01\n" 260 " ret double %subtmp2\n" 261 "}\n"); 262 263 kal::codegen_reset(); 264 265 kal::codegen(pd.decls[0]); 266 REQUIRE(to_string(kal::codegen(pd.defs[0], true)) == 267 "define double @f() {\n" 268 "entry:\n" 269 " ret double 0xC045FF205A3B2E7C\n" 270 "}\n"); 271 272 kal::codegen_reset(); 273 } 274 } 275 276 SECTION("extern, def, call, stmts") 277 { 278 { 279 auto pd = p(R"( 280 extern tan2(x y) 281 282 def formula(x y z) 283 3.14 + 2 * (tan2(x, y) + x*y) / z 284 285 extern sin(x) 286 287 sin(2.72) + formula(1, 2, 3) 288 )"); 289 290 REQUIRE(pd.decls.size() == 2); 291 REQUIRE(pd.defs.size() == 1); 292 REQUIRE(pd.stmts.size() == 1); 293 294 for (const auto& d : pd.decls) 295 kal::codegen(d); 296 for (const auto& d : pd.defs) 297 kal::codegen(d); 298 299 { 300 auto f = pd.stmts.data(); 301 auto l = f + pd.stmts.size(); 302 303 REQUIRE(to_string(kal::mkfunc("kaleidoscope_body_0__", f, l)) == 304 "define double @kaleidoscope_body_0__() {\n" 305 "entry:\n" 306 " %calltmp = call double @sin(double 2.720000e+00)\n" 307 " %calltmp1 = call double @formula(double 1.000000e+00, " 308 "double 2.000000e+00, double 3.000000e+00)\n" 309 " %addtmp = fadd double %calltmp, %calltmp1\n" 310 " ret double 0.000000e+00\n" 311 "}\n"); 312 313 REQUIRE(to_string(kal::mkfunc("kaleidoscope_body_1__", f, l, true)) == 314 "define double @kaleidoscope_body_1__() {\n" 315 "entry:\n" 316 " %calltmp1 = call double @formula(double 1.000000e+00, " 317 "double 2.000000e+00, double 3.000000e+00)\n" 318 " ret double 0.000000e+00\n" 319 "}\n"); 320 } 321 322 kal::codegen_reset(); 323 } 324 } 325 }