writer/hlsl: Use vector write helper for dynamic indices

This uses FXC compilation failure mitigation for _any_ vector index assignment that has a non-constant index. FXC can still fall over if the loop calls a function that performs the dynamic index.

Use some vector swizzle logic to avoid branches in the helper.

Fixed: tint:980
Change-Id: I2a759d88a7d884bc61b4631cf57feb4acc8178de
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/57882
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc
index 3f03dc7..5bb015b 100644
--- a/src/writer/hlsl/generator_impl.cc
+++ b/src/writer/hlsl/generator_impl.cc
@@ -126,10 +126,6 @@
   const TypeInfo* last_kind = nullptr;
   size_t last_padding_line = 0;
 
-  if (!FindAndEmitVectorAssignmentInLoopFunctions()) {
-    return false;
-  }
-
   for (auto* decl : builder_.AST().GlobalDeclarations()) {
     if (decl->Is<ast::Alias>()) {
       continue;  // Ignore aliases.
@@ -178,96 +174,66 @@
   return true;
 }
 
-bool GeneratorImpl::FindAndEmitVectorAssignmentInLoopFunctions() {
-  auto is_in_loop = [](const sem::Expression* expr) {
-    auto* block = expr->Stmt()->Block();
-    if (!block) {
-      return false;
-    }
-    return block->FindFirstParent<sem::LoopBlockStatement>() != nullptr;
-  };
-
-  auto emit_function_once = [&](const sem::Vector* vec) {
-    utils::GetOrCreate(vector_assignment_in_loop_funcs_, vec, [&] {
-      std::ostringstream ss;
-      EmitType(ss, vec, tint::ast::StorageClass::kInvalid,
-               ast::Access::kUndefined, "");
-      auto func_name = UniqueIdentifier("Set_" + ss.str());
-      {
-        auto out = line();
-        out << "void " << func_name << "(inout ";
-        EmitType(out, vec, ast::StorageClass::kInvalid, ast::Access::kUndefined,
-                 "");
-        out << " vec, int idx, ";
-        EmitType(out, vec->type(), ast::StorageClass::kInvalid,
-                 ast::Access::kUndefined, "");
-        out << " val) {";
-      }
-      {
-        ScopedIndent si(this);
-        line() << "switch(idx) {";
-        {
-          ScopedIndent si2(this);
-          for (size_t i = 0; i < vec->size(); ++i) {
-            auto sidx = std::to_string(i);
-            line() << "case " + sidx + ": vec[" + sidx + "] = val; break;";
-          }
-        }
-        line() << "}";
-      }
-      line() << "}";
-      return func_name;
-    });
-  };
-
-  // Find vector assignments via an accessor expression (index) within loops so
-  // that we can replace them later with calls to setter functions. Also emit
-  // the setter functions per vector type as we find them. We do this to avoid
-  // having FCX fail to unroll loops with "error X3511: forced to unroll loop,
-  // but unrolling failed." See crbug.com/tint/534.
-
-  for (auto* ast_node : program_->ASTNodes().Objects()) {
-    auto* ast_assign = ast_node->As<ast::AssignmentStatement>();
-    if (!ast_assign) {
-      continue;
-    }
-
-    auto* ast_access_expr =
-        ast_assign->lhs()->As<ast::ArrayAccessorExpression>();
-    if (!ast_access_expr) {
-      continue;
-    }
-
-    auto* array_expr = builder_.Sem().Get(ast_access_expr->array());
-    auto* vec = array_expr->Type()->UnwrapRef()->As<sem::Vector>();
-
-    // Skip non-vectors
-    if (!vec) {
-      continue;
-    }
-
-    // Skip if not part of a loop
-    if (!is_in_loop(array_expr)) {
-      continue;
-    }
-
-    // Save this assignment along with the vector type
-    vector_assignments_in_loops_.emplace(ast_assign, vec);
-
-    // Emit the function if it hasn't already
-    emit_function_once(vec);
-  }
-
-  return true;
-}
-
-bool GeneratorImpl::EmitVectorAssignmentInLoopCall(
+bool GeneratorImpl::EmitDynamicVectorAssignment(
     const ast::AssignmentStatement* stmt,
     const sem::Vector* vec) {
+  auto name =
+      utils::GetOrCreate(dynamic_vector_write_, vec, [&]() -> std::string {
+        std::string fn;
+        {
+          std::ostringstream ss;
+          if (!EmitType(ss, vec, tint::ast::StorageClass::kInvalid,
+                        ast::Access::kUndefined, "")) {
+            return "";
+          }
+          fn = UniqueIdentifier("set_" + ss.str());
+        }
+        {
+          auto out = line(&helpers_);
+          out << "void " << fn << "(inout ";
+          if (!EmitTypeAndName(out, vec, ast::StorageClass::kInvalid,
+                               ast::Access::kUndefined, "vec")) {
+            return "";
+          }
+          out << ", int idx, ";
+          if (!EmitTypeAndName(out, vec->type(), ast::StorageClass::kInvalid,
+                               ast::Access::kUndefined, "val")) {
+            return "";
+          }
+          out << ") {";
+        }
+        {
+          ScopedIndent si(&helpers_);
+          auto out = line(&helpers_);
+          switch (vec->size()) {
+            case 2:
+              out << "vec = (idx.xx == int2(0, 1)) ? val.xx : vec;";
+              break;
+            case 3:
+              out << "vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;";
+              break;
+            case 4:
+              out << "vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;";
+              break;
+            default:
+              TINT_UNREACHABLE(Writer, builder_.Diagnostics())
+                  << "invalid vector size " << vec->size();
+              break;
+          }
+        }
+        line(&helpers_) << "}";
+        line(&helpers_);
+        return fn;
+      });
+
+  if (name.empty()) {
+    return false;
+  }
+
   auto* ast_access_expr = stmt->lhs()->As<ast::ArrayAccessorExpression>();
 
   auto out = line();
-  out << vector_assignment_in_loop_funcs_.at(vec) << "(";
+  out << name << "(";
   if (!EmitExpression(out, ast_access_expr->array())) {
     return false;
   }
@@ -322,9 +288,13 @@
 }
 
 bool GeneratorImpl::EmitAssign(ast::AssignmentStatement* stmt) {
-  auto iter = vector_assignments_in_loops_.find(stmt);
-  if (iter != vector_assignments_in_loops_.end()) {
-    return EmitVectorAssignmentInLoopCall(iter->first, iter->second);
+  if (auto* idx = stmt->lhs()->As<ast::ArrayAccessorExpression>()) {
+    if (auto* vec = TypeOf(idx->array())->UnwrapRef()->As<sem::Vector>()) {
+      auto* rhs_sem = builder_.Sem().Get(idx->idx_expr());
+      if (!rhs_sem->ConstantValue().IsValid()) {
+        return EmitDynamicVectorAssignment(stmt, vec);
+      }
+    }
   }
 
   auto out = line();
diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h
index 842cc5f..2a39a5a 100644
--- a/src/writer/hlsl/generator_impl.h
+++ b/src/writer/hlsl/generator_impl.h
@@ -349,21 +349,15 @@
   /// @param var the variable to emit
   /// @returns true if the variable was emitted
   bool EmitProgramConstVariable(const ast::Variable* var);
-
-  /// Finds vector assignments via an accessor expression within loops, storing
-  /// the assignment/vector node pair in `vector_assignments_in_loops`, and
-  /// emits function definitions per vector type found. Required to work around
-  /// an FXC bug, see crbug.com/tint/534.
-  /// @returns true on success
-  bool FindAndEmitVectorAssignmentInLoopFunctions();
-  /// Emits call to vector assignment function for the input assignment
-  /// statement and vector type.
-  /// @param stmt assignment statement that corresponds to a vector assingment
+  /// Emits call to a helper vector assignment function for the input assignment
+  /// statement and vector type. This is used to work around FXC issues where
+  /// assignments to vectors with dynamic indices cause compilation failures.
+  /// @param stmt assignment statement that corresponds to a vector assignment
   /// via an accessor expression
   /// @param vec the vector type being assigned to
   /// @returns true on success
-  bool EmitVectorAssignmentInLoopCall(const ast::AssignmentStatement* stmt,
-                                      const sem::Vector* vec);
+  bool EmitDynamicVectorAssignment(const ast::AssignmentStatement* stmt,
+                                   const sem::Vector* vec);
 
   /// Handles generating a builtin method name
   /// @param intrinsic the semantic info for the intrinsic
@@ -413,10 +407,7 @@
   std::unordered_map<DMAIntrinsic, std::string, DMAIntrinsic::Hasher>
       dma_intrinsics_;
   std::unordered_map<const sem::Struct*, std::string> structure_builders_;
-  std::unordered_map<const ast::AssignmentStatement*, const sem::Vector*>
-      vector_assignments_in_loops_;
-  std::unordered_map<const sem::Vector*, std::string>
-      vector_assignment_in_loop_funcs_;
+  std::unordered_map<const sem::Vector*, std::string> dynamic_vector_write_;
 };
 
 }  // namespace hlsl
diff --git a/test/bug/tint/534.wgsl.expected.hlsl b/test/bug/tint/534.wgsl.expected.hlsl
index da61ba4..9c8466c 100644
--- a/test/bug/tint/534.wgsl.expected.hlsl
+++ b/test/bug/tint/534.wgsl.expected.hlsl
@@ -1,10 +1,5 @@
-void Set_uint4(inout uint4 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
 
 Texture2D<float4> src : register(t0, space0);
@@ -40,7 +35,7 @@
   uint4 dstColorBits = uint4(dstColor);
   {
     for(uint i = 0u; (i < uniforms[0].w); i = (i + 1u)) {
-      Set_uint4(srcColorBits, i, ConvertToFp16FloatValue(srcColor[i]));
+      set_uint4(srcColorBits, i, ConvertToFp16FloatValue(srcColor[i]));
       bool tint_tmp_1 = success;
       if (tint_tmp_1) {
         tint_tmp_1 = (srcColorBits[i] == dstColorBits[i]);
diff --git a/test/bug/tint/980.wgsl b/test/bug/tint/980.wgsl
new file mode 100644
index 0000000..28d7d73
--- /dev/null
+++ b/test/bug/tint/980.wgsl
@@ -0,0 +1,13 @@
+// Fails with "D3D compile failed with value cannot be NaN, isnan() may not be necessary.  /Gis may force isnan() to be performed"
+fn Bad(index: u32, rd: vec3<f32>) -> vec3<f32> {
+  var normal: vec3<f32> = vec3<f32>(0.0);
+  normal[index] = -sign(rd[index]);
+  return normalize(normal);
+}
+
+[[block]] struct S { v : vec3<f32>; i : u32; };
+[[binding(0), group(0)]] var<storage, read_write> io : S;
+[[stage(compute), workgroup_size(1)]]
+fn main([[builtin(local_invocation_index)]] idx : u32) {
+    io.v = Bad(io.i, io.v);
+}
diff --git a/test/bug/tint/980.wgsl.expected.hlsl b/test/bug/tint/980.wgsl.expected.hlsl
new file mode 100644
index 0000000..cbd622f
--- /dev/null
+++ b/test/bug/tint/980.wgsl.expected.hlsl
@@ -0,0 +1,22 @@
+void set_float3(inout float3 vec, int idx, float val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+float3 Bad(uint index, float3 rd) {
+  float3 normal = float3((0.0f).xxx);
+  set_float3(normal, index, -(sign(rd[index])));
+  return normalize(normal);
+}
+
+RWByteAddressBuffer io : register(u0, space0);
+
+struct tint_symbol_1 {
+  uint idx : SV_GroupIndex;
+};
+
+[numthreads(1, 1, 1)]
+void main(tint_symbol_1 tint_symbol) {
+  const uint idx = tint_symbol.idx;
+  io.Store3(0u, asuint(Bad(io.Load(12u), asfloat(io.Load3(0u)))));
+  return;
+}
diff --git a/test/bug/tint/980.wgsl.expected.msl b/test/bug/tint/980.wgsl.expected.msl
new file mode 100644
index 0000000..340d3eb
--- /dev/null
+++ b/test/bug/tint/980.wgsl.expected.msl
@@ -0,0 +1,19 @@
+#include <metal_stdlib>
+
+using namespace metal;
+struct S {
+  /* 0x0000 */ packed_float3 v;
+  /* 0x000c */ uint i;
+};
+
+float3 Bad(uint index, float3 rd) {
+  float3 normal = float3(0.0f);
+  normal[index] = -(sign(rd[index]));
+  return normalize(normal);
+}
+
+kernel void tint_symbol(uint idx [[thread_index_in_threadgroup]], device S& io [[buffer(0)]]) {
+  io.v = Bad(io.i, io.v);
+  return;
+}
+
diff --git a/test/bug/tint/980.wgsl.expected.spvasm b/test/bug/tint/980.wgsl.expected.spvasm
new file mode 100644
index 0000000..2323ab3
--- /dev/null
+++ b/test/bug/tint/980.wgsl.expected.spvasm
@@ -0,0 +1,72 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 41
+; Schema: 0
+               OpCapability Shader
+         %23 = OpExtInstImport "GLSL.std.450"
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %main "main"
+               OpExecutionMode %main LocalSize 1 1 1
+               OpName %S "S"
+               OpMemberName %S 0 "v"
+               OpMemberName %S 1 "i"
+               OpName %io "io"
+               OpName %tint_symbol "tint_symbol"
+               OpName %Bad "Bad"
+               OpName %index "index"
+               OpName %rd "rd"
+               OpName %normal "normal"
+               OpName %main "main"
+               OpDecorate %S Block
+               OpMemberDecorate %S 0 Offset 0
+               OpMemberDecorate %S 1 Offset 12
+               OpDecorate %io Binding 0
+               OpDecorate %io DescriptorSet 0
+               OpDecorate %tint_symbol BuiltIn LocalInvocationIndex
+      %float = OpTypeFloat 32
+    %v3float = OpTypeVector %float 3
+       %uint = OpTypeInt 32 0
+          %S = OpTypeStruct %v3float %uint
+%_ptr_StorageBuffer_S = OpTypePointer StorageBuffer %S
+         %io = OpVariable %_ptr_StorageBuffer_S StorageBuffer
+%_ptr_Input_uint = OpTypePointer Input %uint
+%tint_symbol = OpVariable %_ptr_Input_uint Input
+          %9 = OpTypeFunction %v3float %uint %v3float
+    %float_0 = OpConstant %float 0
+         %15 = OpConstantComposite %v3float %float_0 %float_0 %float_0
+%_ptr_Function_v3float = OpTypePointer Function %v3float
+         %18 = OpConstantNull %v3float
+%_ptr_Function_float = OpTypePointer Function %float
+       %void = OpTypeVoid
+         %27 = OpTypeFunction %void
+     %uint_0 = OpConstant %uint 0
+%_ptr_StorageBuffer_v3float = OpTypePointer StorageBuffer %v3float
+     %uint_1 = OpConstant %uint 1
+%_ptr_StorageBuffer_uint = OpTypePointer StorageBuffer %uint
+        %Bad = OpFunction %v3float None %9
+      %index = OpFunctionParameter %uint
+         %rd = OpFunctionParameter %v3float
+         %13 = OpLabel
+     %normal = OpVariable %_ptr_Function_v3float Function %18
+               OpStore %normal %15
+         %20 = OpAccessChain %_ptr_Function_float %normal %index
+         %24 = OpVectorExtractDynamic %float %rd %index
+         %22 = OpExtInst %float %23 FSign %24
+         %21 = OpFNegate %float %22
+               OpStore %20 %21
+         %26 = OpLoad %v3float %normal
+         %25 = OpExtInst %v3float %23 Normalize %26
+               OpReturnValue %25
+               OpFunctionEnd
+       %main = OpFunction %void None %27
+         %30 = OpLabel
+         %33 = OpAccessChain %_ptr_StorageBuffer_v3float %io %uint_0
+         %37 = OpAccessChain %_ptr_StorageBuffer_uint %io %uint_1
+         %38 = OpLoad %uint %37
+         %39 = OpAccessChain %_ptr_StorageBuffer_v3float %io %uint_0
+         %40 = OpLoad %v3float %39
+         %34 = OpFunctionCall %v3float %Bad %38 %40
+               OpStore %33 %34
+               OpReturn
+               OpFunctionEnd
diff --git a/test/bug/tint/980.wgsl.expected.wgsl b/test/bug/tint/980.wgsl.expected.wgsl
new file mode 100644
index 0000000..acc7ee6
--- /dev/null
+++ b/test/bug/tint/980.wgsl.expected.wgsl
@@ -0,0 +1,18 @@
+fn Bad(index : u32, rd : vec3<f32>) -> vec3<f32> {
+  var normal : vec3<f32> = vec3<f32>(0.0);
+  normal[index] = -(sign(rd[index]));
+  return normalize(normal);
+}
+
+[[block]]
+struct S {
+  v : vec3<f32>;
+  i : u32;
+};
+
+[[binding(0), group(0)]] var<storage, read_write> io : S;
+
+[[stage(compute), workgroup_size(1)]]
+fn main([[builtin(local_invocation_index)]] idx : u32) {
+  io.v = Bad(io.i, io.v);
+}
diff --git a/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_loop.wgsl.expected.hlsl b/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_loop.wgsl.expected.hlsl
index a222773..7b10259 100644
--- a/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_loop.wgsl.expected.hlsl
+++ b/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_loop.wgsl.expected.hlsl
@@ -1,30 +1,19 @@
-void Set_float2(inout float2 vec, int idx, float val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+void set_float2(inout float2 vec, int idx, float val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_int3(inout int3 vec, int idx, int val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-  }
+
+void set_int3(inout int3 vec, int idx, int val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
 }
-void Set_uint4(inout uint4 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
-void Set_bool2(inout bool2 vec, int idx, bool val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_bool2(inout bool2 vec, int idx, bool val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
+
 static float2 v2f = float2(0.0f, 0.0f);
 static int3 v3i = int3(0, 0, 0);
 static uint4 v4u = uint4(0u, 0u, 0u, 0u);
@@ -33,10 +22,10 @@
 void foo() {
   {
     for(int i = 0; (i < 2); i = (i + 1)) {
-      Set_float2(v2f, i, 1.0f);
-      Set_int3(v3i, i, 1);
-      Set_uint4(v4u, i, 1u);
-      Set_bool2(v2b, i, true);
+      set_float2(v2f, i, 1.0f);
+      set_int3(v3i, i, 1);
+      set_uint4(v4u, i, 1u);
+      set_bool2(v2b, i, true);
     }
   }
 }
diff --git a/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_no_loop.wgsl.expected.hlsl b/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_no_loop.wgsl.expected.hlsl
index ea05754..64b4fea 100644
--- a/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_no_loop.wgsl.expected.hlsl
+++ b/test/fxc_bugs/vector_assignment_in_loop/loop_call_with_no_loop.wgsl.expected.hlsl
@@ -1,3 +1,19 @@
+void set_float2(inout float2 vec, int idx, float val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
+}
+
+void set_int3(inout int3 vec, int idx, int val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_bool2(inout bool2 vec, int idx, bool val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
+}
+
 static float2 v2f = float2(0.0f, 0.0f);
 static int3 v3i = int3(0, 0, 0);
 static uint4 v4u = uint4(0u, 0u, 0u, 0u);
@@ -5,10 +21,10 @@
 
 void foo() {
   int i = 0;
-  v2f[i] = 1.0f;
-  v3i[i] = 1;
-  v4u[i] = 1u;
-  v2b[i] = true;
+  set_float2(v2f, i, 1.0f);
+  set_int3(v3i, i, 1);
+  set_uint4(v4u, i, 1u);
+  set_bool2(v2b, i, true);
 }
 
 [numthreads(1, 1, 1)]
diff --git a/test/fxc_bugs/vector_assignment_in_loop/loop_types_all.wgsl.expected.hlsl b/test/fxc_bugs/vector_assignment_in_loop/loop_types_all.wgsl.expected.hlsl
index 85ffc25..158e0e1 100644
--- a/test/fxc_bugs/vector_assignment_in_loop/loop_types_all.wgsl.expected.hlsl
+++ b/test/fxc_bugs/vector_assignment_in_loop/loop_types_all.wgsl.expected.hlsl
@@ -1,87 +1,51 @@
-void Set_float2(inout float2 vec, int idx, float val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+void set_float2(inout float2 vec, int idx, float val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_float3(inout float3 vec, int idx, float val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-  }
+
+void set_float3(inout float3 vec, int idx, float val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
 }
-void Set_float4(inout float4 vec, int idx, float val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+
+void set_float4(inout float4 vec, int idx, float val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
-void Set_int2(inout int2 vec, int idx, int val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_int2(inout int2 vec, int idx, int val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_int3(inout int3 vec, int idx, int val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-  }
+
+void set_int3(inout int3 vec, int idx, int val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
 }
-void Set_int4(inout int4 vec, int idx, int val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+
+void set_int4(inout int4 vec, int idx, int val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
-void Set_uint2(inout uint2 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_uint2(inout uint2 vec, int idx, uint val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_uint3(inout uint3 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-  }
+
+void set_uint3(inout uint3 vec, int idx, uint val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
 }
-void Set_uint4(inout uint4 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
-void Set_bool2(inout bool2 vec, int idx, bool val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_bool2(inout bool2 vec, int idx, bool val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_bool3(inout bool3 vec, int idx, bool val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-  }
+
+void set_bool3(inout bool3 vec, int idx, bool val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
 }
-void Set_bool4(inout bool4 vec, int idx, bool val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+
+void set_bool4(inout bool4 vec, int idx, bool val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
+
 [numthreads(1, 1, 1)]
 void main() {
   float2 v2f = float2(0.0f, 0.0f);
@@ -98,18 +62,18 @@
   bool4 v4b = bool4(false, false, false, false);
   {
     for(int i = 0; (i < 2); i = (i + 1)) {
-      Set_float2(v2f, i, 1.0f);
-      Set_float3(v3f, i, 1.0f);
-      Set_float4(v4f, i, 1.0f);
-      Set_int2(v2i, i, 1);
-      Set_int3(v3i, i, 1);
-      Set_int4(v4i, i, 1);
-      Set_uint2(v2u, i, 1u);
-      Set_uint3(v3u, i, 1u);
-      Set_uint4(v4u, i, 1u);
-      Set_bool2(v2b, i, true);
-      Set_bool3(v3b, i, true);
-      Set_bool4(v4b, i, true);
+      set_float2(v2f, i, 1.0f);
+      set_float3(v3f, i, 1.0f);
+      set_float4(v4f, i, 1.0f);
+      set_int2(v2i, i, 1);
+      set_int3(v3i, i, 1);
+      set_int4(v4i, i, 1);
+      set_uint2(v2u, i, 1u);
+      set_uint3(v3u, i, 1u);
+      set_uint4(v4u, i, 1u);
+      set_bool2(v2b, i, true);
+      set_bool3(v3b, i, true);
+      set_bool4(v4b, i, true);
     }
   }
   return;
diff --git a/test/fxc_bugs/vector_assignment_in_loop/loop_types_repeated.wgsl.expected.hlsl b/test/fxc_bugs/vector_assignment_in_loop/loop_types_repeated.wgsl.expected.hlsl
index 9c60097..6cabffe 100644
--- a/test/fxc_bugs/vector_assignment_in_loop/loop_types_repeated.wgsl.expected.hlsl
+++ b/test/fxc_bugs/vector_assignment_in_loop/loop_types_repeated.wgsl.expected.hlsl
@@ -1,30 +1,19 @@
-void Set_float2(inout float2 vec, int idx, float val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+void set_float2(inout float2 vec, int idx, float val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_int3(inout int3 vec, int idx, int val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-  }
+
+void set_int3(inout int3 vec, int idx, int val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
 }
-void Set_uint4(inout uint4 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-    case 2: vec[2] = val; break;
-    case 3: vec[3] = val; break;
-  }
+
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
 }
-void Set_bool2(inout bool2 vec, int idx, bool val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_bool2(inout bool2 vec, int idx, bool val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
+
 [numthreads(1, 1, 1)]
 void main() {
   float2 v2f = float2(0.0f, 0.0f);
@@ -37,14 +26,14 @@
   bool2 v2b_2 = bool2(false, false);
   {
     for(int i = 0; (i < 2); i = (i + 1)) {
-      Set_float2(v2f, i, 1.0f);
-      Set_int3(v3i, i, 1);
-      Set_uint4(v4u, i, 1u);
-      Set_bool2(v2b, i, true);
-      Set_float2(v2f_2, i, 1.0f);
-      Set_int3(v3i_2, i, 1);
-      Set_uint4(v4u_2, i, 1u);
-      Set_bool2(v2b_2, i, true);
+      set_float2(v2f, i, 1.0f);
+      set_int3(v3i, i, 1);
+      set_uint4(v4u, i, 1u);
+      set_bool2(v2b, i, true);
+      set_float2(v2f_2, i, 1.0f);
+      set_int3(v3i_2, i, 1);
+      set_uint4(v4u_2, i, 1u);
+      set_bool2(v2b_2, i, true);
     }
   }
   return;
diff --git a/test/fxc_bugs/vector_assignment_in_loop/loop_types_some.wgsl.expected.hlsl b/test/fxc_bugs/vector_assignment_in_loop/loop_types_some.wgsl.expected.hlsl
index 2e6de37..1461d03 100644
--- a/test/fxc_bugs/vector_assignment_in_loop/loop_types_some.wgsl.expected.hlsl
+++ b/test/fxc_bugs/vector_assignment_in_loop/loop_types_some.wgsl.expected.hlsl
@@ -1,27 +1,51 @@
-void Set_float2(inout float2 vec, int idx, float val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+void set_float2(inout float2 vec, int idx, float val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_int2(inout int2 vec, int idx, int val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_int2(inout int2 vec, int idx, int val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_uint2(inout uint2 vec, int idx, uint val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_uint2(inout uint2 vec, int idx, uint val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
-void Set_bool2(inout bool2 vec, int idx, bool val) {
-  switch(idx) {
-    case 0: vec[0] = val; break;
-    case 1: vec[1] = val; break;
-  }
+
+void set_bool2(inout bool2 vec, int idx, bool val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
 }
+
+void set_float3(inout float3 vec, int idx, float val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_float4(inout float4 vec, int idx, float val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_int3(inout int3 vec, int idx, int val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_int4(inout int4 vec, int idx, int val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_uint3(inout uint3 vec, int idx, uint val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_bool3(inout bool3 vec, int idx, bool val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_bool4(inout bool4 vec, int idx, bool val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
 [numthreads(1, 1, 1)]
 void main() {
   float2 v2f = float2(0.0f, 0.0f);
@@ -38,20 +62,20 @@
   bool4 v4b = bool4(false, false, false, false);
   {
     for(int i = 0; (i < 2); i = (i + 1)) {
-      Set_float2(v2f, i, 1.0f);
-      Set_int2(v2i, i, 1);
-      Set_uint2(v2u, i, 1u);
-      Set_bool2(v2b, i, true);
+      set_float2(v2f, i, 1.0f);
+      set_int2(v2i, i, 1);
+      set_uint2(v2u, i, 1u);
+      set_bool2(v2b, i, true);
     }
   }
   int i = 0;
-  v3f[i] = 1.0f;
-  v4f[i] = 1.0f;
-  v3i[i] = 1;
-  v4i[i] = 1;
-  v3u[i] = 1u;
-  v4u[i] = 1u;
-  v3b[i] = true;
-  v4b[i] = true;
+  set_float3(v3f, i, 1.0f);
+  set_float4(v4f, i, 1.0f);
+  set_int3(v3i, i, 1);
+  set_int4(v4i, i, 1);
+  set_uint3(v3u, i, 1u);
+  set_uint4(v4u, i, 1u);
+  set_bool3(v3b, i, true);
+  set_bool4(v4b, i, true);
   return;
 }
diff --git a/test/fxc_bugs/vector_assignment_in_loop/no_loop.wgsl.expected.hlsl b/test/fxc_bugs/vector_assignment_in_loop/no_loop.wgsl.expected.hlsl
index f836c66..69170f0 100644
--- a/test/fxc_bugs/vector_assignment_in_loop/no_loop.wgsl.expected.hlsl
+++ b/test/fxc_bugs/vector_assignment_in_loop/no_loop.wgsl.expected.hlsl
@@ -1,3 +1,51 @@
+void set_float2(inout float2 vec, int idx, float val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
+}
+
+void set_float3(inout float3 vec, int idx, float val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_float4(inout float4 vec, int idx, float val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_int2(inout int2 vec, int idx, int val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
+}
+
+void set_int3(inout int3 vec, int idx, int val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_int4(inout int4 vec, int idx, int val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_uint2(inout uint2 vec, int idx, uint val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
+}
+
+void set_uint3(inout uint3 vec, int idx, uint val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_uint4(inout uint4 vec, int idx, uint val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
+void set_bool2(inout bool2 vec, int idx, bool val) {
+  vec = (idx.xx == int2(0, 1)) ? val.xx : vec;
+}
+
+void set_bool3(inout bool3 vec, int idx, bool val) {
+  vec = (idx.xxx == int3(0, 1, 2)) ? val.xxx : vec;
+}
+
+void set_bool4(inout bool4 vec, int idx, bool val) {
+  vec = (idx.xxxx == int4(0, 1, 2, 3)) ? val.xxxx : vec;
+}
+
 [numthreads(1, 1, 1)]
 void main() {
   float2 v2f = float2(0.0f, 0.0f);
@@ -13,17 +61,17 @@
   bool3 v3b = bool3(false, false, false);
   bool4 v4b = bool4(false, false, false, false);
   int i = 0;
-  v2f[i] = 1.0f;
-  v3f[i] = 1.0f;
-  v4f[i] = 1.0f;
-  v2i[i] = 1;
-  v3i[i] = 1;
-  v4i[i] = 1;
-  v2u[i] = 1u;
-  v3u[i] = 1u;
-  v4u[i] = 1u;
-  v2b[i] = true;
-  v3b[i] = true;
-  v4b[i] = true;
+  set_float2(v2f, i, 1.0f);
+  set_float3(v3f, i, 1.0f);
+  set_float4(v4f, i, 1.0f);
+  set_int2(v2i, i, 1);
+  set_int3(v3i, i, 1);
+  set_int4(v4i, i, 1);
+  set_uint2(v2u, i, 1u);
+  set_uint3(v3u, i, 1u);
+  set_uint4(v4u, i, 1u);
+  set_bool2(v2b, i, true);
+  set_bool3(v3b, i, true);
+  set_bool4(v4b, i, true);
   return;
 }