kaleidoscope_codegen.cpp (6500B)
1 2 #include "kaleidoscope_codegen.hpp" 3 4 #include "kaleidoscope_ast.hpp" 5 6 #include <algorithm> 7 #include <string> 8 #include <utility> 9 #include <vector> 10 11 // clang-format off 12 #include <llvm/IR/LLVMContext.h> 13 #include <llvm/IR/IRBuilder.h> 14 #include <llvm/IR/Module.h> 15 // clang-format on 16 17 #include <llvm/ADT/APFloat.h> 18 #include <llvm/IR/Constants.h> 19 #include <llvm/IR/Value.h> 20 #include <llvm/IR/Verifier.h> 21 22 #include <llvm/IR/LegacyPassManager.h> 23 #include <llvm/Transforms/InstCombine/InstCombine.h> 24 #include <llvm/Transforms/Scalar.h> 25 #include <llvm/Transforms/Scalar/GVN.h> 26 27 namespace { 28 29 int 30 assert_fail(const char* expr, const char* file, int line, const char* func) 31 { 32 fprintf( 33 stderr, "Assertion failed: %s (%s: %s: %d)\n", expr, file, func, line); 34 fflush(NULL); 35 abort(); 36 37 return 0; // dummy 38 } 39 #define ASSERT_ALWAYS(x) ((x) || assert_fail(#x, __FILE__, __LINE__, __func__)) 40 #ifndef NDEBUG 41 #define ASSERT(x) 42 #else 43 #define ASSERT(x) ASSERT_ALWAYS(x) 44 #endif 45 46 inline llvm::LLVMContext& 47 context() 48 { 49 static llvm::LLVMContext ctx; 50 51 return ctx; 52 } 53 54 inline llvm::IRBuilder<>& 55 builder() 56 { 57 static llvm::IRBuilder<> b(context()); 58 59 return b; 60 } 61 62 inline llvm::Module* 63 module() 64 { 65 static auto m = std::make_unique<llvm::Module>("Kaleidoscope JIT", context()); 66 67 return m.get(); 68 } 69 70 inline llvm::legacy::FunctionPassManager* 71 fpm() 72 { 73 static auto f = [] { 74 auto f = std::make_unique<llvm::legacy::FunctionPassManager>(module()); 75 76 f->add(llvm::createInstructionCombiningPass()); 77 f->add(llvm::createReassociatePass()); 78 f->add(llvm::createGVNPass()); 79 f->add(llvm::createCFGSimplificationPass()); 80 81 f->doInitialization(); 82 return f; 83 }(); 84 85 return f.get(); 86 } 87 88 template<typename K, typename V> 89 class Map 90 { 91 private: 92 std::vector<std::pair<K, V>> d_; // data 93 94 auto find_(const K& k) 95 { 96 const auto& f = d_.begin(); 97 const auto& l = d_.end(); 98 auto r = std::lower_bound(f, l, k, [](const auto& pair, const auto& key) { 99 return pair.first < key; 100 }); 101 102 return std::make_pair(r, l); 103 } 104 105 public: 106 V find(const K& k) 107 { 108 auto r = find_(k); 109 110 if (r.first == r.second) 111 return nullptr; 112 return k == r.first->first ? r.first->second : nullptr; 113 } 114 115 void insert(const K& k, const V& v) 116 { 117 auto r = find_(k); 118 119 if (r.first != r.second && k == r.first->first) 120 assert(v == r.first->second); 121 else 122 d_.insert(r.first, std::make_pair(k, v)); 123 } 124 125 template<typename F> 126 void for_each(F f) 127 { 128 for (const auto& e : d_) 129 f(e.first, e.second); 130 } 131 132 void clear() { d_.clear(); } 133 }; 134 135 Map<std::string, llvm::Function*> funcs; 136 Map<std::string, llvm::Value*> names; 137 138 } // anonymous namespace 139 140 llvm::Value* 141 kal::codegen(const Number& n) 142 { 143 return llvm::ConstantFP::get(context(), llvm::APFloat{ n.value }); 144 } 145 146 llvm::Value* 147 kal::codegen(const Variable& v) 148 { 149 return names.find(v.name); 150 } 151 152 llvm::Value* 153 kal::codegen(const BinaryOp& b) 154 { 155 auto l = kal::codegen(b.lhs); 156 auto r = kal::codegen(b.rhs); 157 158 ASSERT_ALWAYS(l != nullptr); 159 ASSERT_ALWAYS(r != nullptr); 160 161 switch (b.op) { 162 case '*': 163 return builder().CreateFMul(l, r, "multmp"); 164 165 case '/': 166 return builder().CreateFDiv(l, r, "divtmp"); 167 168 case '+': 169 return builder().CreateFAdd(l, r, "addtmp"); 170 171 case '-': 172 return builder().CreateFSub(l, r, "subtmp"); 173 174 case '<': 175 l = builder().CreateFCmpULT(l, r, "cmptmp"); 176 return builder().CreateUIToFP( 177 l, llvm::Type::getDoubleTy(context()), "booltmp"); 178 179 default: 180 ASSERT_ALWAYS(false && "unknown operator"); 181 } 182 183 return nullptr; 184 } 185 186 llvm::Value* 187 kal::codegen(const Call& c) 188 { 189 auto fun = funcs.find(c.callee); 190 auto n = c.args.size(); 191 192 ASSERT_ALWAYS(fun != nullptr); 193 ASSERT_ALWAYS(n == fun->arg_size()); 194 195 std::vector<llvm::Value*> args(n); 196 197 for (auto i = 0u; i < n; i++) { 198 args[i] = kal::codegen(c.args[i]); 199 ASSERT_ALWAYS(args[i] != nullptr); 200 } 201 202 return builder().CreateCall(fun, args, "calltmp"); 203 } 204 205 llvm::Value* 206 kal::codegen(const kal::ASTNode& n) 207 { 208 using kal::node_type; 209 210 switch (node_type(n)) { 211 case kal::NodeType::Number: { 212 kal::Number num; 213 214 cast(n, num); 215 return kal::codegen(num); 216 } 217 218 case kal::NodeType::Variable: { 219 kal::Variable v; 220 221 cast(n, v); 222 return kal::codegen(v); 223 } 224 225 case kal::NodeType::BinaryOp: { 226 kal::BinaryOp op; 227 228 cast(n, op); 229 return kal::codegen(op); 230 } 231 232 case kal::NodeType::Call: { 233 kal::Call c; 234 235 cast(n, c); 236 return kal::codegen(c); 237 } 238 239 case kal::NodeType::None: 240 case kal::NodeType::Prototype: 241 case kal::NodeType::Function: 242 break; 243 } 244 245 ASSERT_ALWAYS(false && "unknown NodeType"); 246 return nullptr; 247 } 248 249 llvm::Function* 250 kal::codegen(const Prototype& p) 251 { 252 std::vector<llvm::Type*> dbls(p.params.size(), 253 llvm::Type::getDoubleTy(context())); 254 auto ftype = 255 llvm::FunctionType::get(llvm::Type::getDoubleTy(context()), dbls, false); 256 auto func = llvm::Function::Create( 257 ftype, llvm::Function::ExternalLinkage, p.name, module()); 258 259 { 260 auto i = 0; 261 262 for (auto& arg : func->args()) 263 arg.setName(p.params[i++]); 264 } 265 266 funcs.insert(p.name, func); 267 return func; 268 } 269 270 llvm::Function* 271 kal::codegen(const Function& f, bool optimize) 272 { 273 auto func = funcs.find(f.proto.name); 274 275 if (func == nullptr) { 276 func = kal::codegen(f.proto); 277 if (func != nullptr) 278 funcs.insert(f.proto.name, func); 279 } 280 281 ASSERT_ALWAYS(func != nullptr); 282 283 auto bb = llvm::BasicBlock::Create(context(), "entry", func); 284 285 builder().SetInsertPoint(bb); 286 287 names.clear(); 288 for (auto& arg : func->args()) 289 names.insert(arg.getName(), &arg); 290 291 auto body = kal::codegen(f.body); 292 293 ASSERT_ALWAYS(body != nullptr); 294 295 builder().CreateRet(body); 296 llvm::verifyFunction(*func); 297 if (optimize) 298 fpm()->run(*func); 299 300 return func; 301 } 302 303 llvm::Function* 304 kal::mkfunc(const std::string& name, ASTNode* f, ASTNode* l, bool optimize) 305 { 306 if (f == l) 307 return nullptr; 308 309 auto func = codegen(kal::Prototype{ name, {} }); 310 311 ASSERT_ALWAYS(func != nullptr); 312 313 auto bb = llvm::BasicBlock::Create(context(), "entry", func); 314 315 builder().SetInsertPoint(bb); 316 while (f != l) 317 kal::codegen(*f++); 318 builder().CreateRet(kal::codegen(kal::Number{ 0 })); 319 320 llvm::verifyFunction(*func); 321 if (optimize) 322 fpm()->run(*func); 323 return func; 324 } 325 326 void 327 kal::codegen_reset(void) 328 { 329 funcs.for_each([](const auto&, const auto& v) { v->eraseFromParent(); }); 330 funcs.clear(); 331 names.clear(); 332 }