Fix Undefined Behaviour
All caused by calling Castable::As<> on nullptr objects.
Bug: tint:760
Change-Id: I0a408b3cd58086cfeab5a1af34d643f50f304948
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49523
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
diff --git a/src/castable.h b/src/castable.h
index 884ba5f..5bb8fdb 100644
--- a/src/castable.h
+++ b/src/castable.h
@@ -173,13 +173,20 @@
/// @see CastFlags
template <typename TO, int FLAGS = 0, typename FROM = detail::Infer>
inline TO* As(FROM* obj) {
- using castable =
- typename std::conditional<std::is_const<FROM>::value, const CastableBase,
- CastableBase>::type;
- auto* as_castable = static_cast<castable*>(obj);
+ auto* as_castable = static_cast<CastableBase*>(obj);
return Is<TO, FLAGS>(obj) ? static_cast<TO*>(as_castable) : nullptr;
}
+/// @returns obj dynamically cast to the type `TO` or `nullptr` if
+/// this object does not derive from `TO`.
+/// @param obj the object to cast from
+/// @see CastFlags
+template <typename TO, int FLAGS = 0, typename FROM = detail::Infer>
+inline const TO* As(const FROM* obj) {
+ auto* as_castable = static_cast<const CastableBase*>(obj);
+ return Is<TO, FLAGS>(obj) ? static_cast<const TO*>(as_castable) : nullptr;
+}
+
/// CastableBase is the base class for all Castable objects.
/// It is not encouraged to directly derive from CastableBase without using the
/// Castable helper template.
diff --git a/src/program_builder.cc b/src/program_builder.cc
index 7689f98..c16f878 100644
--- a/src/program_builder.cc
+++ b/src/program_builder.cc
@@ -132,10 +132,10 @@
typ::Type ProgramBuilder::TypesBuilder::MaybeCreateTypename(
typ::Type type) const {
- if (auto* alias = type.ast->As<ast::Alias>()) {
+ if (auto* alias = As<ast::Alias>(type.ast)) {
return {builder->create<ast::TypeName>(alias->symbol()), type.sem};
}
- if (auto* str = type.ast->As<ast::Struct>()) {
+ if (auto* str = As<ast::Struct>(type.ast)) {
return {builder->create<ast::TypeName>(str->name()), type.sem};
}
return type;
diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc
index d303610..99d7548 100644
--- a/src/reader/wgsl/parser_impl.cc
+++ b/src/reader/wgsl/parser_impl.cc
@@ -1396,7 +1396,7 @@
return_type = type.value;
}
- if (return_type.ast->Is<ast::Void>()) {
+ if (Is<ast::Void>(return_type.ast)) {
// crbug.com/tint/677: void has been removed from the language
deprecated(tok.source(),
"omit '-> void' for functions that do not return a value");
diff --git a/src/sem/info.h b/src/sem/info.h
index 412e3ef..b442df2 100644
--- a/src/sem/info.h
+++ b/src/sem/info.h
@@ -15,6 +15,7 @@
#ifndef SRC_SEM_INFO_H_
#define SRC_SEM_INFO_H_
+#include <type_traits>
#include <unordered_map>
#include "src/debug.h"
@@ -26,6 +27,9 @@
/// Info holds all the resolved semantic information for a Program.
class Info {
+ /// Placeholder type used by Get() to provide a default value for EXPLICIT_SEM
+ using InferFromAST = std::nullptr_t;
+
public:
/// Constructor
Info();
@@ -44,14 +48,18 @@
/// Get looks up the semantic information for the AST or type node `node`.
/// @param node the AST or type node
/// @returns a pointer to the semantic node if found, otherwise nullptr
- template <typename AST_OR_TYPE,
- typename SEM = SemanticNodeTypeFor<AST_OR_TYPE>>
- const SEM* Get(const AST_OR_TYPE* node) const {
+ template <typename SEM = InferFromAST,
+ typename AST_OR_TYPE = CastableBase,
+ typename RESULT =
+ std::conditional_t<std::is_same<SEM, InferFromAST>::value,
+ SemanticNodeTypeFor<AST_OR_TYPE>,
+ SEM>>
+ const RESULT* Get(const AST_OR_TYPE* node) const {
auto it = map.find(node);
if (it == map.end()) {
return nullptr;
}
- return it->second->template As<SEM>();
+ return As<RESULT>(it->second);
}
/// Add registers the semantic node `sem_node` for the AST or type node
diff --git a/src/transform/decompose_storage_access.cc b/src/transform/decompose_storage_access.cc
index 44a768b..50b51a5 100644
--- a/src/transform/decompose_storage_access.cc
+++ b/src/transform/decompose_storage_access.cc
@@ -627,8 +627,7 @@
for (auto* node : ctx.src->ASTNodes().Objects()) {
if (auto* ident = node->As<ast::IdentifierExpression>()) {
// X
- auto* expr = sem.Get(ident);
- if (auto* var = expr->As<sem::VariableUser>()) {
+ if (auto* var = sem.Get<sem::VariableUser>(ident)) {
if (var->Variable()->StorageClass() == ast::StorageClass::kStorage) {
// Variable to a storage buffer
state.AddAccesss(ident, {
diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc
index 20e869d..1ef89e1 100644
--- a/src/transform/first_index_offset.cc
+++ b/src/transform/first_index_offset.cc
@@ -148,19 +148,22 @@
// Fix up all references to the builtins with the offsets
ctx.ReplaceAll([=, &ctx](ast::Expression* expr) -> ast::Expression* {
- auto* sem = ctx.src->Sem().Get(expr);
- if (auto* user = sem->As<sem::VariableUser>()) {
- auto it = builtin_vars.find(user->Variable());
- if (it != builtin_vars.end()) {
- return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
- ctx.dst->MemberAccessor(buffer_name, it->second));
+ if (auto* sem = ctx.src->Sem().Get(expr)) {
+ if (auto* user = sem->As<sem::VariableUser>()) {
+ auto it = builtin_vars.find(user->Variable());
+ if (it != builtin_vars.end()) {
+ return ctx.dst->Add(
+ ctx.CloneWithoutTransform(expr),
+ ctx.dst->MemberAccessor(buffer_name, it->second));
+ }
}
- }
- if (auto* access = sem->As<sem::StructMemberAccess>()) {
- auto it = builtin_members.find(access->Member());
- if (it != builtin_members.end()) {
- return ctx.dst->Add(ctx.CloneWithoutTransform(expr),
- ctx.dst->MemberAccessor(buffer_name, it->second));
+ if (auto* access = sem->As<sem::StructMemberAccess>()) {
+ auto it = builtin_members.find(access->Member());
+ if (it != builtin_members.end()) {
+ return ctx.dst->Add(
+ ctx.CloneWithoutTransform(expr),
+ ctx.dst->MemberAccessor(buffer_name, it->second));
+ }
}
}
// Not interested in this experssion. Just clone.