Add instance step mode to vertex pulling transform

Bug: dawn:480
Change-Id: Icf650b7f340528e6a49d68d155fd9becc212e623
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/26440
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: dan sinclair <dsinclair@chromium.org>
diff --git a/src/ast/transform/vertex_pulling_transform.cc b/src/ast/transform/vertex_pulling_transform.cc
index fc571d5..528c7e5 100644
--- a/src/ast/transform/vertex_pulling_transform.cc
+++ b/src/ast/transform/vertex_pulling_transform.cc
@@ -48,6 +48,7 @@
 static const char kStructBufferName[] = "data";
 static const char kPullingPosVarName[] = "tint_pulling_pos";
 static const char kDefaultVertexIndexName[] = "tint_pulling_vertex_index";
+static const char kDefaultInstanceIndexName[] = "tint_pulling_instance_index";
 }  // namespace
 
 VertexPullingTransform::VertexPullingTransform(Context* ctx, Module* mod)
@@ -100,7 +101,8 @@
   // TODO(idanr): Make sure we covered all error cases, to guarantee the
   // following stages will pass
 
-  FindOrInsertVertexIndex();
+  FindOrInsertVertexIndexIfUsed();
+  FindOrInsertInstanceIndexIfUsed();
   ConvertVertexInputVariablesToPrivate();
   AddVertexStorageBuffers();
   AddVertexPullingPreamble(vertex_func);
@@ -116,7 +118,19 @@
   return kVertexBufferNamePrefix + std::to_string(index);
 }
 
-void VertexPullingTransform::FindOrInsertVertexIndex() {
+void VertexPullingTransform::FindOrInsertVertexIndexIfUsed() {
+  bool uses_vertex_step_mode = false;
+  for (const VertexBufferLayoutDescriptor& buffer_layout :
+       vertex_state_->vertex_buffers) {
+    if (buffer_layout.step_mode == InputStepMode::kVertex) {
+      uses_vertex_step_mode = true;
+      break;
+    }
+  }
+  if (!uses_vertex_step_mode) {
+    return;
+  }
+
   // Look for an existing vertex index builtin
   for (auto& v : mod_->global_variables()) {
     if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) {
@@ -145,6 +159,47 @@
   mod_->AddGlobalVariable(std::move(var));
 }
 
+void VertexPullingTransform::FindOrInsertInstanceIndexIfUsed() {
+  bool uses_instance_step_mode = false;
+  for (const VertexBufferLayoutDescriptor& buffer_layout :
+       vertex_state_->vertex_buffers) {
+    if (buffer_layout.step_mode == InputStepMode::kInstance) {
+      uses_instance_step_mode = true;
+      break;
+    }
+  }
+  if (!uses_instance_step_mode) {
+    return;
+  }
+
+  // Look for an existing instance index builtin
+  for (auto& v : mod_->global_variables()) {
+    if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) {
+      continue;
+    }
+
+    for (auto& d : v->AsDecorated()->decorations()) {
+      if (d->IsBuiltin() && d->AsBuiltin()->value() == Builtin::kInstanceIdx) {
+        instance_index_name_ = v->name();
+        return;
+      }
+    }
+  }
+
+  // We didn't find an instance index builtin, so create one
+  instance_index_name_ = kDefaultInstanceIndexName;
+
+  auto var = std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
+      instance_index_name_, StorageClass::kInput, GetI32Type()));
+
+  VariableDecorationList decorations;
+  decorations.push_back(
+      std::make_unique<BuiltinDecoration>(Builtin::kInstanceIdx));
+
+  var->set_decorations(std::move(decorations));
+  mod_->AddGlobalVariable(std::move(var));
+}
+
 void VertexPullingTransform::ConvertVertexInputVariablesToPrivate() {
   for (auto& v : mod_->global_variables()) {
     if (!v->IsDecorated() || v->storage_class() != StorageClass::kInput) {
@@ -228,12 +283,17 @@
       }
       auto* v = it->second;
 
+      // Identifier to index by
+      auto index_identifier = std::make_unique<IdentifierExpression>(
+          buffer_layout.step_mode == InputStepMode::kVertex
+              ? vertex_index_name_
+              : instance_index_name_);
+
       // An expression for the start of the read in the buffer in bytes
       auto pos_value = std::make_unique<BinaryExpression>(
           BinaryOp::kAdd,
           std::make_unique<BinaryExpression>(
-              BinaryOp::kMultiply,
-              std::make_unique<IdentifierExpression>(vertex_index_name_),
+              BinaryOp::kMultiply, std::move(index_identifier),
               GenUint(static_cast<uint32_t>(buffer_layout.array_stride))),
           GenUint(static_cast<uint32_t>(attribute_desc.offset)));
 
diff --git a/src/ast/transform/vertex_pulling_transform.h b/src/ast/transform/vertex_pulling_transform.h
index 636ddc4..4ad6321 100644
--- a/src/ast/transform/vertex_pulling_transform.h
+++ b/src/ast/transform/vertex_pulling_transform.h
@@ -159,7 +159,10 @@
   std::string GetVertexBufferName(uint32_t index);
 
   /// Inserts vertex_idx binding, or finds the existing one
-  void FindOrInsertVertexIndex();
+  void FindOrInsertVertexIndexIfUsed();
+
+  /// Inserts instance_idx binding, or finds the existing one
+  void FindOrInsertInstanceIndexIfUsed();
 
   /// Converts var<in> with a location decoration to var<private>
   void ConvertVertexInputVariablesToPrivate();
@@ -237,6 +240,7 @@
   std::string error_;
 
   std::string vertex_index_name_;
+  std::string instance_index_name_;
 
   std::unordered_map<uint32_t, Variable*> location_to_var_;
   std::unique_ptr<VertexStateDescriptor> vertex_state_;
diff --git a/src/ast/transform/vertex_pulling_transform_test.cc b/src/ast/transform/vertex_pulling_transform_test.cc
index c0cc841..cafefda 100644
--- a/src/ast/transform/vertex_pulling_transform_test.cc
+++ b/src/ast/transform/vertex_pulling_transform_test.cc
@@ -201,27 +201,14 @@
             mod()->to_str());
 }
 
-// We expect the transform to use an existing vertex_idx builtin variable if it
-// finds one
-TEST_F(VertexPullingTransformTest, ExistingVertexIndex) {
+TEST_F(VertexPullingTransformTest, OneInstancedAttribute) {
   InitBasicModule();
 
   type::F32Type f32;
   AddVertexInputVariable(0, "var_a", &f32);
 
-  type::I32Type i32;
-  auto vertex_index_var =
-      std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
-          "custom_vertex_index", StorageClass::kInput, &i32));
-
-  VariableDecorationList decorations;
-  decorations.push_back(
-      std::make_unique<BuiltinDecoration>(Builtin::kVertexIdx));
-
-  vertex_index_var->set_decorations(std::move(decorations));
-  mod()->AddGlobalVariable(std::move(vertex_index_var));
-
-  InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
+  InitTransform(
+      {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}});
 
   EXPECT_TRUE(transform()->Run());
 
@@ -233,6 +220,122 @@
   }
   DecoratedVariable{
     Decorations{
+      BuiltinDecoration{instance_idx}
+    }
+    tint_pulling_instance_index
+    in
+    __i32
+  }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{0}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_0
+    storage_buffer
+    __struct_
+  }
+  EntryPoint{vertex as main = vtx_main}
+  Function vtx_main -> __void
+  ()
+  {
+    Block{
+      VariableDeclStatement{
+        Variable{
+          tint_pulling_pos
+          function
+          __i32
+        }
+      }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{tint_pulling_instance_index}
+            multiply
+            ScalarConstructor{4}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_a}
+        As<__f32>{
+          ArrayAccessor{
+            MemberAccessor{
+              Identifier{tint_pulling_vertex_buffer_0}
+              Identifier{data}
+            }
+            Binary{
+              Identifier{tint_pulling_pos}
+              divide
+              ScalarConstructor{4}
+            }
+          }
+        }
+      }
+    }
+  }
+}
+)",
+            mod()->to_str());
+}
+
+// We expect the transform to use an existing builtin variables if it finds them
+TEST_F(VertexPullingTransformTest, ExistingVertexIndexAndInstanceIndex) {
+  InitBasicModule();
+
+  type::F32Type f32;
+  AddVertexInputVariable(0, "var_a", &f32);
+  AddVertexInputVariable(1, "var_b", &f32);
+
+  type::I32Type i32;
+  {
+    auto vertex_index_var =
+        std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
+            "custom_vertex_index", StorageClass::kInput, &i32));
+
+    VariableDecorationList decorations;
+    decorations.push_back(
+        std::make_unique<BuiltinDecoration>(Builtin::kVertexIdx));
+
+    vertex_index_var->set_decorations(std::move(decorations));
+    mod()->AddGlobalVariable(std::move(vertex_index_var));
+  }
+
+  {
+    auto instance_index_var =
+        std::make_unique<DecoratedVariable>(std::make_unique<Variable>(
+            "custom_instance_index", StorageClass::kInput, &i32));
+
+    VariableDecorationList decorations;
+    decorations.push_back(
+        std::make_unique<BuiltinDecoration>(Builtin::kInstanceIdx));
+
+    instance_index_var->set_decorations(std::move(decorations));
+    mod()->AddGlobalVariable(std::move(instance_index_var));
+  }
+
+  InitTransform(
+      {{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}},
+        {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}});
+
+  EXPECT_TRUE(transform()->Run());
+
+  EXPECT_EQ(R"(Module{
+  Variable{
+    var_a
+    private
+    __f32
+  }
+  Variable{
+    var_b
+    private
+    __f32
+  }
+  DecoratedVariable{
+    Decorations{
       BuiltinDecoration{vertex_idx}
     }
     custom_vertex_index
@@ -241,6 +344,14 @@
   }
   DecoratedVariable{
     Decorations{
+      BuiltinDecoration{instance_idx}
+    }
+    custom_instance_index
+    in
+    __i32
+  }
+  DecoratedVariable{
+    Decorations{
       BindingDecoration{0}
       SetDecoration{0}
     }
@@ -248,6 +359,15 @@
     storage_buffer
     __struct_
   }
+  DecoratedVariable{
+    Decorations{
+      BindingDecoration{1}
+      SetDecoration{0}
+    }
+    tint_pulling_vertex_buffer_1
+    storage_buffer
+    __struct_
+  }
   EntryPoint{vertex as main = vtx_main}
   Function vtx_main -> __void
   ()
@@ -288,6 +408,34 @@
           }
         }
       }
+      Assignment{
+        Identifier{tint_pulling_pos}
+        Binary{
+          Binary{
+            Identifier{custom_instance_index}
+            multiply
+            ScalarConstructor{4}
+          }
+          add
+          ScalarConstructor{0}
+        }
+      }
+      Assignment{
+        Identifier{var_b}
+        As<__f32>{
+          ArrayAccessor{
+            MemberAccessor{
+              Identifier{tint_pulling_vertex_buffer_1}
+              Identifier{data}
+            }
+            Binary{
+              Identifier{tint_pulling_pos}
+              divide
+              ScalarConstructor{4}
+            }
+          }
+        }
+      }
     }
   }
 }