[spirv-writer] Add elseif support.

This CL adds support for having elseif statements after an if statement.

Bug: tint:5
Change-Id: I3cd3c5bddaa57c998b1a3fbee7bd87536533301d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19500
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc
index 14bcb19..8bf9da4 100644
--- a/src/writer/spirv/builder.cc
+++ b/src/writer/spirv/builder.cc
@@ -630,8 +630,12 @@
   return result_id;
 }
 
-bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
-  auto cond_id = GenerateExpression(stmt->condition());
+bool Builder::GenerateConditionalBlock(
+    ast::Expression* cond,
+    const ast::StatementList& true_body,
+    size_t cur_else_idx,
+    const ast::ElseStatementList& else_stmts) {
+  auto cond_id = GenerateExpression(cond);
   if (cond_id == 0) {
     return false;
   }
@@ -646,10 +650,10 @@
   auto true_block = result_op();
   auto true_block_id = true_block.to_i();
 
-  // if there are no else statements we branch on false to the merge block
+  // if there are no more else statements we branch on false to the merge block
   // otherwise we branch to the false block
   auto false_block_id =
-      stmt->has_else_statements() ? next_id() : merge_block_id;
+      cur_else_idx < else_stmts.size() ? next_id() : merge_block_id;
 
   push_function_inst(spv::Op::OpBranchConditional,
                      {Operand::Int(cond_id), Operand::Int(true_block_id),
@@ -657,28 +661,33 @@
 
   // Output true block
   push_function_inst(spv::Op::OpLabel, {true_block});
-  for (const auto& inst : stmt->body()) {
-    if (!GenerateStatement(inst.get())) {
-      return false;
-    }
+  if (!GenerateStatementList(true_body)) {
+    return false;
   }
-
   // TODO(dsinclair): The branch should be optional based on how the
   // StatementList ended ...
   push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
 
+  // Start the false block if needed
   if (false_block_id != merge_block_id) {
     push_function_inst(spv::Op::OpLabel, {Operand::Int(false_block_id)});
 
-    for (const auto& else_stmt : stmt->else_statements()) {
-      if (!GenerateElseStatement(else_stmt.get())) {
+    auto* else_stmt = else_stmts[cur_else_idx].get();
+    // Handle the else case by just outputting the statements.
+    if (!else_stmt->HasCondition()) {
+      if (!GenerateStatementList(else_stmt->body())) {
         return false;
       }
+      // TODO(dsinclair): The branch should be optional based on how the
+      // StatementList ended ...
+      push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
+    } else {
+      if (!GenerateConditionalBlock(else_stmt->condition(), else_stmt->body(),
+                                    cur_else_idx + 1, else_stmts)) {
+        return false;
+      }
+      push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
     }
-
-    // TODO(dsinclair): The branch should be optional based on how the
-    // StatementList ended ...
-    push_function_inst(spv::Op::OpBranch, {Operand::Int(merge_block_id)});
   }
 
   // Output the merge block
@@ -687,18 +696,11 @@
   return true;
 }
 
-bool Builder::GenerateElseStatement(ast::ElseStatement* stmt) {
-  // TODO(dsinclair): handle else if
-  if (stmt->HasCondition()) {
-    error_ = "else if not handled yet";
+bool Builder::GenerateIfStatement(ast::IfStatement* stmt) {
+  if (!GenerateConditionalBlock(stmt->condition(), stmt->body(), 0,
+                                stmt->else_statements())) {
     return false;
   }
-
-  for (const auto& inst : stmt->body()) {
-    if (!GenerateStatement(inst.get())) {
-      return false;
-    }
-  }
   return true;
 }
 
@@ -716,6 +718,15 @@
   return true;
 }
 
+bool Builder::GenerateStatementList(const ast::StatementList& list) {
+  for (const auto& inst : list) {
+    if (!GenerateStatement(inst.get())) {
+      return false;
+    }
+  }
+  return true;
+}
+
 bool Builder::GenerateStatement(ast::Statement* stmt) {
   if (stmt->IsAssign()) {
     return GenerateAssignStatement(stmt->AsAssign());
diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h
index dab1e3e..0faf52d 100644
--- a/src/writer/spirv/builder.h
+++ b/src/writer/spirv/builder.h
@@ -22,6 +22,7 @@
 
 #include "spirv/unified1/spirv.h"
 #include "src/ast/builtin.h"
+#include "src/ast/else_statement.h"
 #include "src/ast/literal.h"
 #include "src/ast/module.h"
 #include "src/ast/struct_member.h"
@@ -148,10 +149,6 @@
   /// @param assign the statement to generate
   /// @returns true if the statement was successfully generated
   bool GenerateAssignStatement(ast::AssignmentStatement* assign);
-  /// Generates an else statement
-  /// @param stmt the statement to generate
-  /// @returns true on successfull generation
-  bool GenerateElseStatement(ast::ElseStatement* stmt);
   /// Generates an entry point instruction
   /// @param ep the entry point
   /// @returns true if the instruction was generated, false otherwise
@@ -209,10 +206,24 @@
   /// @param stmt the statement to generate
   /// @returns true on success, false otherwise
   bool GenerateReturnStatement(ast::ReturnStatement* stmt);
+  /// Generates a conditional section merge block
+  /// @param cond the condition
+  /// @param true_body the statements making up the true block
+  /// @param cur_else_idx the index of the current else statement to process
+  /// @param else_stmts the list of all else statements
+  /// @returns true on success, false on failure
+  bool GenerateConditionalBlock(ast::Expression* cond,
+                                const ast::StatementList& true_body,
+                                size_t cur_else_idx,
+                                const ast::ElseStatementList& else_stmts);
   /// Generates a statement
   /// @param stmt the statement to generate
   /// @returns true if the statement was generated
   bool GenerateStatement(ast::Statement* stmt);
+  /// Generates a list of statements
+  /// @param list the statement list to generate
+  /// @returns true on successful generation
+  bool GenerateStatementList(const ast::StatementList& list);
   /// Geneates an OpStore
   /// @param to the ID to store too
   /// @param from the ID to store from
diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc
index 15d6809..e7f8bfb 100644
--- a/src/writer/spirv/builder_if_test.cc
+++ b/src/writer/spirv/builder_if_test.cc
@@ -39,6 +39,8 @@
 TEST_F(BuilderTest, If_Empty) {
   ast::type::BoolType bool_type;
 
+  // if (true) {
+  // }
   auto cond = std::make_unique<ast::ScalarConstructorExpression>(
       std::make_unique<ast::BoolLiteral>(&bool_type, true));
 
@@ -68,6 +70,9 @@
   ast::type::BoolType bool_type;
   ast::type::I32Type i32;
 
+  // if (true) {
+  //   v = 2;
+  // }
   auto var =
       std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
 
@@ -114,6 +119,11 @@
   ast::type::BoolType bool_type;
   ast::type::I32Type i32;
 
+  // if (true) {
+  //   v = 2;
+  // } else {
+  //   v = 3;
+  // }
   auto var =
       std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
 
@@ -171,9 +181,186 @@
 )");
 }
 
-TEST_F(BuilderTest, DISABLED_If_WithElseIf) {}
+TEST_F(BuilderTest, If_WithElseIf) {
+  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
 
-TEST_F(BuilderTest, DISABLED_If_WithMultiple) {}
+  // if (true) {
+  //   v = 2;
+  // } elseif (true) {
+  //   v = 3;
+  // }
+  auto var =
+      std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 2))));
+
+  ast::StatementList else_body;
+  else_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 3))));
+
+  auto else_cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::ElseStatementList else_stmts;
+  else_stmts.push_back(std::make_unique<ast::ElseStatement>(
+      std::move(else_cond), std::move(else_body)));
+
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::IfStatement expr(std::move(cond), std::move(body));
+  expr.set_else_statements(std::move(else_stmts));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%1 = OpVariable %2 Private
+%4 = OpTypeBool
+%5 = OpConstantTrue %4
+%9 = OpConstant %3 2
+%12 = OpConstant %3 3
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpSelectionMerge %6 None
+OpBranchConditional %5 %7 %8
+%7 = OpLabel
+OpStore %1 %9
+OpBranch %6
+%8 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %5 %11 %10
+%11 = OpLabel
+OpStore %1 %12
+OpBranch %10
+%10 = OpLabel
+OpBranch %6
+%6 = OpLabel
+)");
+}
+
+TEST_F(BuilderTest, If_WithMultiple) {
+  ast::type::BoolType bool_type;
+  ast::type::I32Type i32;
+
+  // if (true) {
+  //   v = 2;
+  // } elseif (true) {
+  //   v = 3;
+  // } elseif (false) {
+  //   v = 4;
+  // } else {
+  //   v = 5;
+  // }
+  auto var =
+      std::make_unique<ast::Variable>("v", ast::StorageClass::kPrivate, &i32);
+
+  ast::StatementList body;
+  body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 2))));
+  ast::StatementList elseif_1_body;
+  elseif_1_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 3))));
+  ast::StatementList elseif_2_body;
+  elseif_2_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 4))));
+  ast::StatementList else_body;
+  else_body.push_back(std::make_unique<ast::AssignmentStatement>(
+      std::make_unique<ast::IdentifierExpression>("v"),
+      std::make_unique<ast::ScalarConstructorExpression>(
+          std::make_unique<ast::IntLiteral>(&i32, 5))));
+
+  auto elseif_1_cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+  auto elseif_2_cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, false));
+
+  ast::ElseStatementList else_stmts;
+  else_stmts.push_back(std::make_unique<ast::ElseStatement>(
+      std::move(elseif_1_cond), std::move(elseif_1_body)));
+  else_stmts.push_back(std::make_unique<ast::ElseStatement>(
+      std::move(elseif_2_cond), std::move(elseif_2_body)));
+  else_stmts.push_back(
+      std::make_unique<ast::ElseStatement>(std::move(else_body)));
+
+  auto cond = std::make_unique<ast::ScalarConstructorExpression>(
+      std::make_unique<ast::BoolLiteral>(&bool_type, true));
+
+  ast::IfStatement expr(std::move(cond), std::move(body));
+  expr.set_else_statements(std::move(else_stmts));
+
+  Context ctx;
+  TypeDeterminer td(&ctx);
+  td.RegisterVariableForTesting(var.get());
+
+  ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
+
+  Builder b;
+  b.push_function(Function{});
+  ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error();
+
+  EXPECT_TRUE(b.GenerateIfStatement(&expr)) << b.error();
+  EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeInt 32 1
+%2 = OpTypePointer Private %3
+%1 = OpVariable %2 Private
+%4 = OpTypeBool
+%5 = OpConstantTrue %4
+%9 = OpConstant %3 2
+%13 = OpConstant %3 3
+%14 = OpConstantFalse %4
+%18 = OpConstant %3 4
+%19 = OpConstant %3 5
+)");
+  EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
+            R"(OpSelectionMerge %6 None
+OpBranchConditional %5 %7 %8
+%7 = OpLabel
+OpStore %1 %9
+OpBranch %6
+%8 = OpLabel
+OpSelectionMerge %10 None
+OpBranchConditional %5 %11 %12
+%11 = OpLabel
+OpStore %1 %13
+OpBranch %10
+%12 = OpLabel
+OpSelectionMerge %15 None
+OpBranchConditional %14 %16 %17
+%16 = OpLabel
+OpStore %1 %18
+OpBranch %15
+%17 = OpLabel
+OpStore %1 %19
+OpBranch %15
+%15 = OpLabel
+OpBranch %10
+%10 = OpLabel
+OpBranch %6
+%6 = OpLabel
+)");
+}
 
 TEST_F(BuilderTest, DISABLED_If_WithBreak) {
   // if (a) {