[spirv-writer] Emit logical and and logical or

This CL adds support for the &&  and || operators to the SPIR-V backend.

Bug: tint:5
Change-Id: I63b23d9904b5b8027e189034d24949df71cbbe42
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23501
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index c0dbf31..3f2d289 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -226,6 +226,11 @@
       Instruction{spv::Op::OpCapability, {Operand::Int(cap)}});
 }
 
+void Builder::GenerateLabel(uint32_t id) {
+  push_function_inst(spv::Op::OpLabel, {Operand::Int(id)});
+  current_label_id_ = id;
+}
+
 uint32_t Builder::GenerateU32Literal(uint32_t val) {
   ast::type::U32Type u32;
   ast::SintLiteral lit(&u32, val);
@@ -1083,7 +1088,72 @@
   return result_id;
 }
 
+uint32_t Builder::GenerateShortCircuitBinaryExpression(
+    ast::BinaryExpression* expr) {
+  auto lhs_id = GenerateExpression(expr->lhs());
+  if (lhs_id == 0) {
+    return false;
+  }
+  lhs_id = GenerateLoadIfNeeded(expr->lhs()->result_type(), lhs_id);
+
+  auto original_label_id = current_label_id_;
+
+  auto type_id = GenerateTypeIfNeeded(expr->result_type());
+  if (type_id == 0) {
+    return 0;
+  }
+
+  auto merge_block = result_op();
+  auto merge_block_id = merge_block.to_i();
+
+  auto block = result_op();
+  auto block_id = block.to_i();
+
+  auto true_block_id = block_id;
+  auto false_block_id = merge_block_id;
+
+  // For a logical or we want to only check the RHS if the LHS is failed.
+  if (expr->IsLogicalOr()) {
+    std::swap(true_block_id, false_block_id);
+  }
+
+  push_function_inst(spv::Op::OpSelectionMerge,
+                     {Operand::Int(merge_block_id),
+                      Operand::Int(SpvSelectionControlMaskNone)});
+  push_function_inst(spv::Op::OpBranchConditional,
+                     {Operand::Int(lhs_id), Operand::Int(true_block_id),
+                      Operand::Int(false_block_id)});
+
+  // Output block to check the RHS
+  GenerateLabel(block_id);
+  auto rhs_id = GenerateExpression(expr->rhs());
+  if (rhs_id == 0) {
+    return 0;
+  }
+  rhs_id = GenerateLoadIfNeeded(expr->rhs()->result_type(), rhs_id);
+
+  push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
+
+  // Output the merge block
+  GenerateLabel(merge_block_id);
+
+  auto result = result_op();
+  auto result_id = result.to_i();
+
+  push_function_inst(spv::Op::OpPhi,
+                     {Operand::Int(type_id), result, Operand::Int(lhs_id),
+                      Operand::Int(original_label_id), Operand::Int(rhs_id),
+                      Operand::Int(block_id)});
+
+  return result_id;
+}
+
 uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
+  // There is special logic for short circuiting operators.
+  if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
+    return GenerateShortCircuitBinaryExpression(expr);
+  }
+
   auto lhs_id = GenerateExpression(expr->lhs());
   if (lhs_id == 0) {
     return 0;
@@ -1466,7 +1536,7 @@
                       Operand::Int(false_block_id)});
 
   // Output true block
-  push_function_inst(spv::Op::OpLabel, {true_block});
+  GenerateLabel(true_block_id);
   if (!GenerateStatementList(true_body)) {
     return false;
   }
@@ -1477,7 +1547,7 @@
 
   // Start the false block if needed
   if (false_block_id != merge_block_id) {
-    push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
+    GenerateLabel(false_block_id);
 
     auto* else_stmt = else_stmts[cur_else_idx].get();
     // Handle the else case by just outputting the statements.
@@ -1497,7 +1567,7 @@
   }
 
   // Output the merge block
-  push_function_inst(spv::Op::OpLabel, {merge_block});
+  GenerateLabel(merge_block_id);
 
   return true;
 }
@@ -1568,7 +1638,7 @@
       generated_default = true;
     }
 
-    push_function_inst(spv::Op::OpLabel, {Operand::Int(case_ids[i])});
+    GenerateLabel(case_ids[i]);
     if (!GenerateStatementList(item->body())) {
       return false;
     }
@@ -1585,13 +1655,13 @@
   }
 
   if (!generated_default) {
-    push_function_inst(spv::Op::OpLabel, {Operand::Int(default_block_id)});
+    GenerateLabel(default_block_id);
     push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
   }
 
   merge_stack_.pop_back();
 
-  push_function_inst(spv::Op::OpLabel, {Operand::Int(merge_block_id)});
+  GenerateLabel(merge_block_id);
   return true;
 }
 
@@ -1613,7 +1683,7 @@
   auto loop_header = result_op();
   auto loop_header_id = loop_header.to_i();
   push_function_inst(spv::Op::OpBranch, {Operand::Int(loop_header_id)});
-  push_function_inst(spv::Op::OpLabel, {loop_header});
+  GenerateLabel(loop_header_id);
 
   auto merge_block = result_op();
   auto merge_block_id = merge_block.to_i();
@@ -1632,7 +1702,7 @@
   merge_stack_.push_back(merge_block_id);
 
   push_function_inst(spv::Op::OpBranch, {Operand::Int(body_block_id)});
-  push_function_inst(spv::Op::OpLabel, {body_block});
+  GenerateLabel(body_block_id);
   if (!GenerateStatementList(stmt->body())) {
     return false;
   }
@@ -1642,7 +1712,7 @@
     push_function_inst(spv::Op::OpBranch, {Operand::Int(continue_block_id)});
   }
 
-  push_function_inst(spv::Op::OpLabel, {continue_block});
+  GenerateLabel(continue_block_id);
   if (!GenerateStatementList(stmt->continuing())) {
     return false;
   }
@@ -1651,7 +1721,7 @@
   merge_stack_.pop_back();
   continue_stack_.pop_back();
 
-  push_function_inst(spv::Op::OpLabel, {merge_block});
+  GenerateLabel(merge_block_id);
 
   return true;
 }
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index 094add4..a111094 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -159,7 +159,10 @@
 
   /// Adds a function to the builder
   /// @param func the function to add
-  void push_function(const Function& func) { functions_.push_back(func); }
+  void push_function(const Function& func) {
+    functions_.push_back(func);
+    current_label_id_ = func.label_id();
+  }
   /// @returns the functions
   const std::vector<Function>& functions() const { return functions_; }
   /// Pushes an instruction to the current function
@@ -183,6 +186,9 @@
   /// @returns the SPIR-V builtin or SpvBuiltInMax on error.
   SpvBuiltIn ConvertBuiltin(ast::Builtin builtin) const;
 
+  /// Generates a label for the given id
+  /// @param id the id to use for the label
+  void GenerateLabel(uint32_t id);
   /// Generates a uint32_t literal.
   /// @param val the value to generate
   /// @returns the ID of the generated literal
@@ -291,6 +297,10 @@
   /// @param expr the expression to generate
   /// @returns the expression ID on success or 0 otherwise
   uint32_t GenerateBinaryExpression(ast::BinaryExpression* expr);
+  /// Generates a short circuting binary expression
+  /// @param expr the expression to generate
+  /// @returns teh expression ID on success or 0 otherwise
+  uint32_t GenerateShortCircuitBinaryExpression(ast::BinaryExpression* expr);
   /// Generates a call expression
   /// @param expr the expression to generate
   /// @returns the expression ID on success or 0 otherwise
@@ -395,6 +405,7 @@
   ast::Module* mod_;
   std::string error_;
   uint32_t next_id_ = 1;
+  uint32_t current_label_id_ = 0;
   std::vector<Instruction> capabilities_;
   std::vector<Instruction> preamble_;
   std::vector<Instruction> debug_;
diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc
index 3d4899f..62986ec 100644
--- a/src/writer/spirv/builder_binary_expression_test.cc
+++ b/src/writer/spirv/builder_binary_expression_test.cc
@@ -16,10 +16,12 @@
 
 #include "gtest/gtest.h"
 #include "src/ast/binary_expression.h"
+#include "src/ast/bool_literal.h"
 #include "src/ast/float_literal.h"
 #include "src/ast/identifier_expression.h"
 #include "src/ast/scalar_constructor_expression.h"
 #include "src/ast/sint_literal.h"
+#include "src/ast/type/bool_type.h"
 #include "src/ast/type/f32_type.h"
 #include "src/ast/type/i32_type.h"
 #include "src/ast/type/matrix_type.h"
@@ -866,6 +868,216 @@
 )");
 }
 
+TEST_F(BuilderTest, Binary_LogicalAnd) {
+  ast::type::I32Type i32;
+
+  auto lhs = std::make_unique<ast::BinaryExpression>(
+      ast::BinaryOp::kEqual,
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 1)),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 2)));
+
+  auto rhs = std::make_unique<ast::BinaryExpression>(
+      ast::BinaryOp::kEqual,
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 3)),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 4)));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  b.GenerateLabel(b.next_id());
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
+%3 = OpConstant %2 1
+%4 = OpConstant %2 2
+%6 = OpTypeBool
+%9 = OpConstant %2 3
+%10 = OpConstant %2 4
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%1 = OpLabel
+%5 = OpIEqual %6 %3 %4
+OpSelectionMerge %7 None
+OpBranchConditional %5 %8 %7
+%8 = OpLabel
+%11 = OpIEqual %6 %9 %10
+OpBranch %7
+%7 = OpLabel
+%12 = OpPhi %6 %5 %1 %11 %8
+)");
+}
+
+TEST_F(BuilderTest, Binary_LogicalAnd_WithLoads) {
+  ast::type::BoolType bool_type;
+
+  auto a_var = std::make_unique<ast::Variable>(
+      "a", ast::StorageClass::kFunction, &bool_type);
+  a_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true)));
+  auto b_var = std::make_unique<ast::Variable>(
+      "b", ast::StorageClass::kFunction, &bool_type);
+  b_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, false)));
+
+  auto lhs = std::make_unique<ast::IdentifierExpression>("a");
+  auto rhs = std::make_unique<ast::IdentifierExpression>("b");
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(a_var.get());
+  td.RegisterVariableForTesting(b_var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kLogicalAnd, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  b.GenerateLabel(b.next_id());
+
+  ASSERT_TRUE(b.GenerateGlobalVariable(a_var.get())) << b.error();
+  ASSERT_TRUE(b.GenerateGlobalVariable(b_var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
+%3 = OpConstantTrue %2
+%5 = OpTypePointer Function %2
+%4 = OpVariable %5 Function %3
+%6 = OpConstantFalse %2
+%7 = OpVariable %5 Function %6
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%1 = OpLabel
+%8 = OpLoad %2 %4
+OpSelectionMerge %9 None
+OpBranchConditional %8 %10 %9
+%10 = OpLabel
+%11 = OpLoad %2 %7
+OpBranch %9
+%9 = OpLabel
+%12 = OpPhi %2 %8 %1 %11 %10
+)");
+}
+
+TEST_F(BuilderTest, Binary_LogicalOr) {
+  ast::type::I32Type i32;
+
+  auto lhs = std::make_unique<ast::BinaryExpression>(
+      ast::BinaryOp::kEqual,
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 1)),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 2)));
+
+  auto rhs = std::make_unique<ast::BinaryExpression>(
+      ast::BinaryOp::kEqual,
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 3)),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::SintLiteral>(&i32, 4)));
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+
+  ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  b.GenerateLabel(b.next_id());
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeInt 32 1
+%3 = OpConstant %2 1
+%4 = OpConstant %2 2
+%6 = OpTypeBool
+%9 = OpConstant %2 3
+%10 = OpConstant %2 4
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%1 = OpLabel
+%5 = OpIEqual %6 %3 %4
+OpSelectionMerge %7 None
+OpBranchConditional %5 %7 %8
+%8 = OpLabel
+%11 = OpIEqual %6 %9 %10
+OpBranch %7
+%7 = OpLabel
+%12 = OpPhi %6 %5 %1 %11 %8
+)");
+}
+
+TEST_F(BuilderTest, Binary_LogicalOr_WithLoads) {
+  ast::type::BoolType bool_type;
+
+  auto a_var = std::make_unique<ast::Variable>(
+      "a", ast::StorageClass::kFunction, &bool_type);
+  a_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true)));
+  auto b_var = std::make_unique<ast::Variable>(
+      "b", ast::StorageClass::kFunction, &bool_type);
+  b_var->set_constructor(std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, false)));
+
+  auto lhs = std::make_unique<ast::IdentifierExpression>("a");
+  auto rhs = std::make_unique<ast::IdentifierExpression>("b");
+
+  Context ctx;
+  ast::Module mod;
+  TypeDeterminer td(&ctx, &mod);
+  td.RegisterVariableForTesting(a_var.get());
+  td.RegisterVariableForTesting(b_var.get());
+
+  ast::BinaryExpression expr(ast::BinaryOp::kLogicalOr, std::move(lhs),
+                             std::move(rhs));
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b(&mod);
+  b.push_function(Function{});
+  b.GenerateLabel(b.next_id());
+
+  ASSERT_TRUE(b.GenerateGlobalVariable(a_var.get())) << b.error();
+  ASSERT_TRUE(b.GenerateGlobalVariable(b_var.get())) << b.error();
+
+  EXPECT_EQ(b.GenerateBinaryExpression(&expr), 12u) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
+%3 = OpConstantTrue %2
+%5 = OpTypePointer Function %2
+%4 = OpVariable %5 Function %3
+%6 = OpConstantFalse %2
+%7 = OpVariable %5 Function %6
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(%1 = OpLabel
+%8 = OpLoad %2 %4
+OpSelectionMerge %9 None
+OpBranchConditional %8 %9 %10
+%10 = OpLabel
+%11 = OpLoad %2 %7
+OpBranch %9
+%9 = OpLabel
+%12 = OpPhi %2 %8 %1 %11 %10
+)");
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace writer
diff --git a/src/writer/spirv/function.h b/src/writer/spirv/function.h
index 26ddfa7..ddc2d0d 100644
--- a/src/writer/spirv/function.h
+++ b/src/writer/spirv/function.h
@@ -52,6 +52,9 @@
   /// @returns the declaration
   const Instruction& declaration() const { return declaration_; }
 
+  /// @returns the function label id
+  uint32_t label_id() const { return label_op_.to_i(); }
+
   /// Adds an instruction to the instruction list
   /// @param op the op to set
   /// @param operands the operands for the instruction