llvm-journey

LLVM Journey
git clone git://0xff.ir/g/llvm-journey.git
Log | Files | Refs | README | LICENSE

kaleidoscope_codegen.cpp (6500B)


      1 
      2 #include "kaleidoscope_codegen.hpp"
      3 
      4 #include "kaleidoscope_ast.hpp"
      5 
      6 #include <algorithm>
      7 #include <string>
      8 #include <utility>
      9 #include <vector>
     10 
     11 // clang-format off
     12 #include <llvm/IR/LLVMContext.h>
     13 #include <llvm/IR/IRBuilder.h>
     14 #include <llvm/IR/Module.h>
     15 // clang-format on
     16 
     17 #include <llvm/ADT/APFloat.h>
     18 #include <llvm/IR/Constants.h>
     19 #include <llvm/IR/Value.h>
     20 #include <llvm/IR/Verifier.h>
     21 
     22 #include <llvm/IR/LegacyPassManager.h>
     23 #include <llvm/Transforms/InstCombine/InstCombine.h>
     24 #include <llvm/Transforms/Scalar.h>
     25 #include <llvm/Transforms/Scalar/GVN.h>
     26 
     27 namespace {
     28 
     29 int
     30 assert_fail(const char* expr, const char* file, int line, const char* func)
     31 {
     32   fprintf(
     33     stderr, "Assertion failed: %s (%s: %s: %d)\n", expr, file, func, line);
     34   fflush(NULL);
     35   abort();
     36 
     37   return 0; // dummy
     38 }
     39 #define ASSERT_ALWAYS(x) ((x) || assert_fail(#x, __FILE__, __LINE__, __func__))
     40 #ifndef NDEBUG
     41 #define ASSERT(x)
     42 #else
     43 #define ASSERT(x) ASSERT_ALWAYS(x)
     44 #endif
     45 
     46 inline llvm::LLVMContext&
     47 context()
     48 {
     49   static llvm::LLVMContext ctx;
     50 
     51   return ctx;
     52 }
     53 
     54 inline llvm::IRBuilder<>&
     55 builder()
     56 {
     57   static llvm::IRBuilder<> b(context());
     58 
     59   return b;
     60 }
     61 
     62 inline llvm::Module*
     63 module()
     64 {
     65   static auto m = std::make_unique<llvm::Module>("Kaleidoscope JIT", context());
     66 
     67   return m.get();
     68 }
     69 
     70 inline llvm::legacy::FunctionPassManager*
     71 fpm()
     72 {
     73   static auto f = [] {
     74     auto f = std::make_unique<llvm::legacy::FunctionPassManager>(module());
     75 
     76     f->add(llvm::createInstructionCombiningPass());
     77     f->add(llvm::createReassociatePass());
     78     f->add(llvm::createGVNPass());
     79     f->add(llvm::createCFGSimplificationPass());
     80 
     81     f->doInitialization();
     82     return f;
     83   }();
     84 
     85   return f.get();
     86 }
     87 
     88 template<typename K, typename V>
     89 class Map
     90 {
     91 private:
     92   std::vector<std::pair<K, V>> d_; // data
     93 
     94   auto find_(const K& k)
     95   {
     96     const auto& f = d_.begin();
     97     const auto& l = d_.end();
     98     auto r = std::lower_bound(f, l, k, [](const auto& pair, const auto& key) {
     99       return pair.first < key;
    100     });
    101 
    102     return std::make_pair(r, l);
    103   }
    104 
    105 public:
    106   V find(const K& k)
    107   {
    108     auto r = find_(k);
    109 
    110     if (r.first == r.second)
    111       return nullptr;
    112     return k == r.first->first ? r.first->second : nullptr;
    113   }
    114 
    115   void insert(const K& k, const V& v)
    116   {
    117     auto r = find_(k);
    118 
    119     if (r.first != r.second && k == r.first->first)
    120       assert(v == r.first->second);
    121     else
    122       d_.insert(r.first, std::make_pair(k, v));
    123   }
    124 
    125   template<typename F>
    126   void for_each(F f)
    127   {
    128     for (const auto& e : d_)
    129       f(e.first, e.second);
    130   }
    131 
    132   void clear() { d_.clear(); }
    133 };
    134 
// Functions created by codegen(Prototype), keyed by function name. Calls are
// resolved against this table, and codegen_reset() erases everything in it.
Map<std::string, llvm::Function*> funcs;
// Values that the body currently being generated can reference by name.
// codegen(Function) clears it and refills it with that function's arguments.
Map<std::string, llvm::Value*> names;
    137 
    138 } // anonymous namespace
    139 
    140 llvm::Value*
    141 kal::codegen(const Number& n)
    142 {
    143   return llvm::ConstantFP::get(context(), llvm::APFloat{ n.value });
    144 }
    145 
    146 llvm::Value*
    147 kal::codegen(const Variable& v)
    148 {
    149   return names.find(v.name);
    150 }
    151 
    152 llvm::Value*
    153 kal::codegen(const BinaryOp& b)
    154 {
    155   auto l = kal::codegen(b.lhs);
    156   auto r = kal::codegen(b.rhs);
    157 
    158   ASSERT_ALWAYS(l != nullptr);
    159   ASSERT_ALWAYS(r != nullptr);
    160 
    161   switch (b.op) {
    162     case '*':
    163       return builder().CreateFMul(l, r, "multmp");
    164 
    165     case '/':
    166       return builder().CreateFDiv(l, r, "divtmp");
    167 
    168     case '+':
    169       return builder().CreateFAdd(l, r, "addtmp");
    170 
    171     case '-':
    172       return builder().CreateFSub(l, r, "subtmp");
    173 
    174     case '<':
    175       l = builder().CreateFCmpULT(l, r, "cmptmp");
    176       return builder().CreateUIToFP(
    177         l, llvm::Type::getDoubleTy(context()), "booltmp");
    178 
    179     default:
    180       ASSERT_ALWAYS(false && "unknown operator");
    181   }
    182 
    183   return nullptr;
    184 }
    185 
    186 llvm::Value*
    187 kal::codegen(const Call& c)
    188 {
    189   auto fun = funcs.find(c.callee);
    190   auto n = c.args.size();
    191 
    192   ASSERT_ALWAYS(fun != nullptr);
    193   ASSERT_ALWAYS(n == fun->arg_size());
    194 
    195   std::vector<llvm::Value*> args(n);
    196 
    197   for (auto i = 0u; i < n; i++) {
    198     args[i] = kal::codegen(c.args[i]);
    199     ASSERT_ALWAYS(args[i] != nullptr);
    200   }
    201 
    202   return builder().CreateCall(fun, args, "calltmp");
    203 }
    204 
    205 llvm::Value*
    206 kal::codegen(const kal::ASTNode& n)
    207 {
    208   using kal::node_type;
    209 
    210   switch (node_type(n)) {
    211     case kal::NodeType::Number: {
    212       kal::Number num;
    213 
    214       cast(n, num);
    215       return kal::codegen(num);
    216     }
    217 
    218     case kal::NodeType::Variable: {
    219       kal::Variable v;
    220 
    221       cast(n, v);
    222       return kal::codegen(v);
    223     }
    224 
    225     case kal::NodeType::BinaryOp: {
    226       kal::BinaryOp op;
    227 
    228       cast(n, op);
    229       return kal::codegen(op);
    230     }
    231 
    232     case kal::NodeType::Call: {
    233       kal::Call c;
    234 
    235       cast(n, c);
    236       return kal::codegen(c);
    237     }
    238 
    239     case kal::NodeType::None:
    240     case kal::NodeType::Prototype:
    241     case kal::NodeType::Function:
    242       break;
    243   }
    244 
    245   ASSERT_ALWAYS(false && "unknown NodeType");
    246   return nullptr;
    247 }
    248 
    249 llvm::Function*
    250 kal::codegen(const Prototype& p)
    251 {
    252   std::vector<llvm::Type*> dbls(p.params.size(),
    253                                 llvm::Type::getDoubleTy(context()));
    254   auto ftype =
    255     llvm::FunctionType::get(llvm::Type::getDoubleTy(context()), dbls, false);
    256   auto func = llvm::Function::Create(
    257     ftype, llvm::Function::ExternalLinkage, p.name, module());
    258 
    259   {
    260     auto i = 0;
    261 
    262     for (auto& arg : func->args())
    263       arg.setName(p.params[i++]);
    264   }
    265 
    266   funcs.insert(p.name, func);
    267   return func;
    268 }
    269 
    270 llvm::Function*
    271 kal::codegen(const Function& f, bool optimize)
    272 {
    273   auto func = funcs.find(f.proto.name);
    274 
    275   if (func == nullptr) {
    276     func = kal::codegen(f.proto);
    277     if (func != nullptr)
    278       funcs.insert(f.proto.name, func);
    279   }
    280 
    281   ASSERT_ALWAYS(func != nullptr);
    282 
    283   auto bb = llvm::BasicBlock::Create(context(), "entry", func);
    284 
    285   builder().SetInsertPoint(bb);
    286 
    287   names.clear();
    288   for (auto& arg : func->args())
    289     names.insert(arg.getName(), &arg);
    290 
    291   auto body = kal::codegen(f.body);
    292 
    293   ASSERT_ALWAYS(body != nullptr);
    294 
    295   builder().CreateRet(body);
    296   llvm::verifyFunction(*func);
    297   if (optimize)
    298     fpm()->run(*func);
    299 
    300   return func;
    301 }
    302 
    303 llvm::Function*
    304 kal::mkfunc(const std::string& name, ASTNode* f, ASTNode* l, bool optimize)
    305 {
    306   if (f == l)
    307     return nullptr;
    308 
    309   auto func = codegen(kal::Prototype{ name, {} });
    310 
    311   ASSERT_ALWAYS(func != nullptr);
    312 
    313   auto bb = llvm::BasicBlock::Create(context(), "entry", func);
    314 
    315   builder().SetInsertPoint(bb);
    316   while (f != l)
    317     kal::codegen(*f++);
    318   builder().CreateRet(kal::codegen(kal::Number{ 0 }));
    319 
    320   llvm::verifyFunction(*func);
    321   if (optimize)
    322     fpm()->run(*func);
    323   return func;
    324 }
    325 
    326 void
    327 kal::codegen_reset(void)
    328 {
    329   funcs.for_each([](const auto&, const auto& v) { v->eraseFromParent(); });
    330   funcs.clear();
    331   names.clear();
    332 }