transform/InlinePtrLets: Fix ICE for lets in for-loops

For loop initializers and continuing statements do not have a BlockStatement as their parent.
Handle removal of these statements with a new Transform::RemoveStatement() helper

Fixed: tint:990
Change-Id: I24e7b18dcf71d3ef0a4d3ee68b9f68518e0eb5e8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58063
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/clone_context.h b/src/clone_context.h
index 8be0b9f..c5437a5 100644
--- a/src/clone_context.h
+++ b/src/clone_context.h
@@ -513,12 +513,15 @@
   /// Reports an internal compiler error if the cast failed.
   template <typename TO, typename FROM>
   TO* CheckedCast(FROM* obj) {
-    if (TO* cast = As<TO>(obj)) {
+    if (obj == nullptr) {
+      return nullptr;
+    }
+    if (TO* cast = obj->template As<TO>()) {
       return cast;
     }
     TINT_ICE(Clone, Diagnostics())
         << "Cloned object was not of the expected type\n"
-        << "got:      " << (obj ? obj->TypeInfo().name : "<null>") << "\n"
+        << "got:      " << obj->TypeInfo().name << "\n"
         << "expected: " << TypeInfo::Of<TO>().name;
     return nullptr;
   }
diff --git a/src/transform/inline_pointer_lets.cc b/src/transform/inline_pointer_lets.cc
index 6b18f75..38d905c 100644
--- a/src/transform/inline_pointer_lets.cc
+++ b/src/transform/inline_pointer_lets.cc
@@ -177,7 +177,7 @@
       ptr_lets.emplace(let->variable(), std::move(ptr_let));
       // As the original `let` declaration will be fully inlined, there's no
       // need for the original declaration to exist. Remove it.
-      ctx.Remove(block->statements(), let);
+      RemoveStatement(ctx, let);
     }
   }
 
diff --git a/src/transform/transform.cc b/src/transform/transform.cc
index 8342d1f..cd788c5 100644
--- a/src/transform/transform.cc
+++ b/src/transform/transform.cc
@@ -19,6 +19,8 @@
 
 #include "src/program_builder.h"
 #include "src/sem/atomic_type.h"
+#include "src/sem/block_statement.h"
+#include "src/sem/for_loop_statement.h"
 #include "src/sem/reference_type.h"
 
 TINT_INSTANTIATE_TYPEINFO(tint::transform::Transform);
@@ -84,6 +86,21 @@
   return new_decorations;
 }
 
+void Transform::RemoveStatement(CloneContext& ctx, ast::Statement* stmt) {
+  auto* sem = ctx.src->Sem().Get(stmt);
+  if (auto* block = tint::As<sem::BlockStatement>(sem->Parent())) {
+    ctx.Remove(block->Declaration()->statements(), stmt);
+    return;
+  }
+  if (tint::Is<sem::ForLoopStatement>(sem->Parent())) {
+    ctx.Replace(stmt, static_cast<ast::Expression*>(nullptr));
+    return;
+  }
+  TINT_ICE(Transform, ctx.dst->Diagnostics())
+      << "unable to remove statement from parent of type "
+      << sem->TypeInfo().name;
+}
+
 ast::Type* Transform::CreateASTTypeFor(CloneContext& ctx, const sem::Type* ty) {
   if (ty->Is<sem::Void>()) {
     return ctx.dst->create<ast::Void>();
diff --git a/src/transform/transform.h b/src/transform/transform.h
index 08bf945..9f9e6d3 100644
--- a/src/transform/transform.h
+++ b/src/transform/transform.h
@@ -195,6 +195,13 @@
       const ast::DecorationList& in,
       std::function<bool(const ast::Decoration*)> should_remove);
 
+  /// Removes the statement `stmt` from the transformed program.
+  /// RemoveStatement handles edge cases, like statements in the initializer and
+  /// continuing of for-loops.
+  /// @param ctx the clone context
+  /// @param stmt the statement to remove when the program is cloned
+  static void RemoveStatement(CloneContext& ctx, ast::Statement* stmt);
+
   /// CreateASTTypeFor constructs new ast::Type nodes that reconstructs the
   /// semantic type `ty`.
   /// @param ctx the clone context
diff --git a/test/bug/tint/990.wgsl b/test/bug/tint/990.wgsl
new file mode 100644
index 0000000..3f31b56
--- /dev/null
+++ b/test/bug/tint/990.wgsl
@@ -0,0 +1,4 @@
+fn f() {
+    var i : i32;
+    for (let p = &i;;) {}
+}
diff --git a/test/bug/tint/990.wgsl.expected.hlsl b/test/bug/tint/990.wgsl.expected.hlsl
new file mode 100644
index 0000000..80751f8
--- /dev/null
+++ b/test/bug/tint/990.wgsl.expected.hlsl
@@ -0,0 +1,12 @@
+[numthreads(1, 1, 1)]
+void unused_entry_point() {
+  return;
+}
+
+void f() {
+  int i = 0;
+  {
+    for(; ; ) {
+    }
+  }
+}
diff --git a/test/bug/tint/990.wgsl.expected.msl b/test/bug/tint/990.wgsl.expected.msl
new file mode 100644
index 0000000..62e37ce
--- /dev/null
+++ b/test/bug/tint/990.wgsl.expected.msl
@@ -0,0 +1,9 @@
+#include <metal_stdlib>
+
+using namespace metal;
+void f() {
+  int i = 0;
+  for(; ; ) {
+  }
+}
+
diff --git a/test/bug/tint/990.wgsl.expected.spvasm b/test/bug/tint/990.wgsl.expected.spvasm
new file mode 100644
index 0000000..a63370b
--- /dev/null
+++ b/test/bug/tint/990.wgsl.expected.spvasm
@@ -0,0 +1,35 @@
+; SPIR-V
+; Version: 1.3
+; Generator: Google Tint Compiler; 0
+; Bound: 15
+; Schema: 0
+               OpCapability Shader
+               OpMemoryModel Logical GLSL450
+               OpEntryPoint GLCompute %unused_entry_point "unused_entry_point"
+               OpExecutionMode %unused_entry_point LocalSize 1 1 1
+               OpName %unused_entry_point "unused_entry_point"
+               OpName %f "f"
+               OpName %i "i"
+       %void = OpTypeVoid
+          %1 = OpTypeFunction %void
+        %int = OpTypeInt 32 1
+%_ptr_Function_int = OpTypePointer Function %int
+         %10 = OpConstantNull %int
+%unused_entry_point = OpFunction %void None %1
+          %4 = OpLabel
+               OpReturn
+               OpFunctionEnd
+          %f = OpFunction %void None %1
+          %6 = OpLabel
+          %i = OpVariable %_ptr_Function_int Function %10
+               OpBranch %11
+         %11 = OpLabel
+               OpLoopMerge %12 %13 None
+               OpBranch %14
+         %14 = OpLabel
+               OpBranch %13
+         %13 = OpLabel
+               OpBranch %11
+         %12 = OpLabel
+               OpReturn
+               OpFunctionEnd
diff --git a/test/bug/tint/990.wgsl.expected.wgsl b/test/bug/tint/990.wgsl.expected.wgsl
new file mode 100644
index 0000000..e7341b6
--- /dev/null
+++ b/test/bug/tint/990.wgsl.expected.wgsl
@@ -0,0 +1,5 @@
+fn f() {
+  var i : i32;
+  for(let p = &(i); ; ;) {
+  }
+}