Fix Undefined Behaviour

All caused by calling Castable::As<> on nullptr objects.

Bug: tint:760
Change-Id: I0a408b3cd58086cfeab5a1af34d643f50f304948
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49523
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/castable.h b/src/castable.h
index 884ba5f..5bb8fdb 100644
--- a/src/castable.h
+++ b/src/castable.h
@@ -173,13 +173,20 @@
 /// @see CastFlags
 template <typename TO, int FLAGS = 0, typename FROM = detail::Infer>
 inline TO* As(FROM* obj) {
-  using castable =
-      typename std::conditional<std::is_const<FROM>::value, const CastableBase,
-                                CastableBase>::type;
-  auto* as_castable = static_cast<castable*>(obj);
+  auto* as_castable = static_cast<CastableBase*>(obj);
   return Is<TO, FLAGS>(obj) ? static_cast<TO*>(as_castable) : nullptr;
 }
 
+/// @returns obj dynamically cast to the type `TO` or `nullptr` if
+/// this object does not derive from `TO`.
+/// @param obj the object to cast from
+/// @see CastFlags
+template <typename TO, int FLAGS = 0, typename FROM = detail::Infer>
+inline const TO* As(const FROM* obj) {
+  auto* as_castable = static_cast<const CastableBase*>(obj);
+  return Is<TO, FLAGS>(obj) ? static_cast<const TO*>(as_castable) : nullptr;
+}
+
 /// CastableBase is the base class for all Castable objects.
 /// It is not encouraged to directly derive from CastableBase without using the
 /// Castable helper template.
diff --git a/src/program_builder.cc b/src/program_builder.cc
index 7689f98..c16f878 100644
--- a/src/program_builder.cc
+++ b/src/program_builder.cc
@@ -132,10 +132,10 @@
 
 typ::Type ProgramBuilder::TypesBuilder::MaybeCreateTypename(
     typ::Type type) const {
-  if (auto* alias = type.ast->As<ast::Alias>()) {
+  if (auto* alias = As<ast::Alias>(type.ast)) {
     return {builder->create<ast::TypeName>(alias->symbol()), type.sem};
   }
-  if (auto* str = type.ast->As<ast::Struct>()) {
+  if (auto* str = As<ast::Struct>(type.ast)) {
     return {builder->create<ast::TypeName>(str->name()), type.sem};
   }
   return type;
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index d303610..99d7548 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1396,7 +1396,7 @@
       return_type = type.value;
     }
 
-    if (return_type.ast->Is<ast::Void>()) {
+    if (Is<ast::Void>(return_type.ast)) {
       // crbug.com/tint/677: void has been removed from the language
       deprecated(tok.source(),
                  "omit '-> void' for functions that do not return a value");
diff --git a/src/sem/info.h b/src/sem/info.h
index 412e3ef..b442df2 100644
--- a/src/sem/info.h
+++ b/src/sem/info.h
@@ -15,6 +15,7 @@
 #ifndef SRC_SEM_INFO_H_
 #define SRC_SEM_INFO_H_
 
+#include <type_traits>
 #include <unordered_map>
 
 #include "src/debug.h"
@@ -26,6 +27,9 @@
 
 /// Info holds all the resolved semantic information for a Program.
 class Info {
+  /// Placeholder type used by Get() to provide a default value for EXPLICIT_SEM
+  using InferFromAST = std::nullptr_t;
+
  public:
   /// Constructor
   Info();
@@ -44,14 +48,18 @@
   /// Get looks up the semantic information for the AST or type node `node`.
   /// @param node the AST or type node
   /// @returns a pointer to the semantic node if found, otherwise nullptr
-  template <typename AST_OR_TYPE,
-            typename SEM = SemanticNodeTypeFor<AST_OR_TYPE>>
-  const SEM* Get(const AST_OR_TYPE* node) const {
+  template <typename SEM = InferFromAST,
+            typename AST_OR_TYPE = CastableBase,
+            typename RESULT =
+                std::conditional_t<std::is_same<SEM, InferFromAST>::value,
+                                   SemanticNodeTypeFor<AST_OR_TYPE>,
+                                   SEM>>
+  const RESULT* Get(const AST_OR_TYPE* node) const {
     auto it = map.find(node);
     if (it == map.end()) {
       return nullptr;
     }
-    return it->second->template As<SEM>();
+    return As<RESULT>(it->second);
   }
 
   /// Add registers the semantic node `sem_node` for the AST or type node
diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc
index 44a768b..50b51a5 100644
--- a/src/transform/decompose_storage_access.cc
+++ b/src/transform/decompose_storage_access.cc
@@ -627,8 +627,7 @@
   for (auto* node : ctx.src->ASTNodes().Objects()) {
     if (auto* ident = node->As<ast::IdentifierExpression>()) {
       // X
-      auto* expr = sem.Get(ident);
-      if (auto* var = expr->As<sem::VariableUser>()) {
+      if (auto* var = sem.Get<sem::VariableUser>(ident)) {
         if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) {
           // Variable to a storage buffer
           state.AddAccesss(ident, {
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index 20e869d..1ef89e1 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -148,19 +148,22 @@
 
     // Fix up all references to the builtins with the offsets
     ctx.ReplaceAll([=, &ctx](ast::Expression* expr) -> ast::Expression* {
-      auto* sem = ctx.src->Sem().Get(expr);
-      if (auto* user = sem->As<sem::VariableUser>()) {
-        auto it = builtin_vars.find(user->Variable());
-        if (it != builtin_vars.end()) {
-          return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
-                              ctx.dst->MemberAccessor(buffer_name, it->second));
+      if (auto* sem = ctx.src->Sem().Get(expr)) {
+        if (auto* user = sem->As<sem::VariableUser>()) {
+          auto it = builtin_vars.find(user->Variable());
+          if (it != builtin_vars.end()) {
+            return ctx.dst->Add(
+                ctx.CloneWithoutTransform(expr),
+                ctx.dst->MemberAccessor(buffer_name, it->second));
+          }
         }
-      }
-      if (auto* access = sem->As<sem::StructMemberAccess>()) {
-        auto it = builtin_members.find(access->Member());
-        if (it != builtin_members.end()) {
-          return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
-                              ctx.dst->MemberAccessor(buffer_name, it->second));
+        if (auto* access = sem->As<sem::StructMemberAccess>()) {
+          auto it = builtin_members.find(access->Member());
+          if (it != builtin_members.end()) {
+            return ctx.dst->Add(
+                ctx.CloneWithoutTransform(expr),
+                ctx.dst->MemberAccessor(buffer_name, it->second));
+          }
         }
       }
       // Not interested in this experssion. Just clone.