Manual cherry-pick of https://dawn-review.googlesource.com/c/tint/+/71681
Original commit message (modified):
resolver: Fixes for bitcasts
Add missing validation for bitcasts. We were permitting any bitcast that wasn't being cast to a pointer type, when the spec only allows:
* numeric_scalar to numeric_scalar
* vecN<numeric_scalar> to vecN<numeric_scalar>
Add lots of tests.
Fixed: chromium:1276320
Change-Id: Iaaed4759234be1ed739e3c016d27679bde081ddc
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72800
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 97e92f6..c92e231 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -666,6 +666,7 @@
resolver/assignment_validation_test.cc
resolver/atomics_test.cc
resolver/atomics_validation_test.cc
+ resolver/bitcast_validation_test.cc
resolver/builtins_validation_test.cc
resolver/call_test.cc
resolver/call_validation_test.cc
diff --git a/src/resolver/bitcast_validation_test.cc b/src/resolver/bitcast_validation_test.cc
new file mode 100644
index 0000000..d4ce082
--- /dev/null
+++ b/src/resolver/bitcast_validation_test.cc
@@ -0,0 +1,228 @@
+// Copyright 2021 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "src/ast/bitcast_expression.h"
+#include "src/resolver/resolver.h"
+#include "src/resolver/resolver_test_helper.h"
+
+#include "gmock/gmock.h"
+
+namespace tint {
+namespace resolver {
+namespace {
+// Bundles the builder::DataType<T> factory functions for a single type.
+struct Type {
+ template <typename T>
+ static constexpr Type Create() {
+ return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
+ builder::DataType<T>::Expr};
+ }
+
+ builder::ast_type_func_ptr ast;  // builds the AST type
+ builder::sem_type_func_ptr sem;  // builds the semantic type
+ builder::ast_expr_func_ptr expr;  // builds an expression of the type
+};
+// Valid bitcast types: numeric scalars, and vectors of them grouped by width.
+static constexpr Type kNumericScalars[] = {
+ Type::Create<builder::f32>(),
+ Type::Create<builder::i32>(),
+ Type::Create<builder::u32>(),
+};
+static constexpr Type kVec2NumericScalars[] = {
+ Type::Create<builder::vec2<builder::f32>>(),
+ Type::Create<builder::vec2<builder::i32>>(),
+ Type::Create<builder::vec2<builder::u32>>(),
+};
+static constexpr Type kVec3NumericScalars[] = {
+ Type::Create<builder::vec3<builder::f32>>(),
+ Type::Create<builder::vec3<builder::i32>>(),
+ Type::Create<builder::vec3<builder::u32>>(),
+};
+static constexpr Type kVec4NumericScalars[] = {
+ Type::Create<builder::vec4<builder::f32>>(),
+ Type::Create<builder::vec4<builder::i32>>(),
+ Type::Create<builder::vec4<builder::u32>>(),
+};
+static constexpr Type kInvalid[] = {
+ // A non-exhaustive selection of uncastable types
+ Type::Create<bool>(),
+ Type::Create<builder::vec2<bool>>(),
+ Type::Create<builder::vec3<bool>>(),
+ Type::Create<builder::vec4<bool>>(),
+ Type::Create<builder::array<2, builder::i32>>(),
+ Type::Create<builder::array<3, builder::u32>>(),
+ Type::Create<builder::array<4, builder::f32>>(),
+ Type::Create<builder::array<5, bool>>(),
+ Type::Create<builder::mat2x2<builder::f32>>(),
+ Type::Create<builder::mat3x3<builder::f32>>(),
+ Type::Create<builder::mat4x4<builder::f32>>(),
+ Type::Create<builder::ptr<builder::i32>>(),
+ Type::Create<builder::ptr<builder::array<2, builder::i32>>>(),
+ Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(),
+};
+// Test parameter: (source type, destination type).
+using ResolverBitcastValidationTest =
+ ResolverTestWithParam<std::tuple<Type, Type>>;
+
+////////////////////////////////////////////////////////////////////////////////
+// Valid bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestPass, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+ // Bitcasting between types of equal width must resolve to type dst.
+ auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0));
+ WrapInFunction(cast);
+
+ ASSERT_TRUE(r()->Resolve()) << r()->error();
+ EXPECT_EQ(TypeOf(cast), dst.sem(*this));
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4,
+ ResolverBitcastValidationTestPass,
+ testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid source type for bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+ // Read the operand through a named const so it carries a known Source.
+ auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src"));
+ WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast);
+
+ auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) +
+ "' cannot be bitcast";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kNumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4,
+ ResolverBitcastValidationTestInvalidSrcTy,
+ testing::Combine(testing::ValuesIn(kInvalid),
+ testing::ValuesIn(kVec4NumericScalars)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Invalid target type for bitcasts
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+
+ // Use an alias so we can put a Source on the bitcast type
+ Alias("T", dst.ast(*this));
+ WrapInFunction(
+ Bitcast(ty.type_name(Source{{12, 34}}, "T"), src.expr(*this, 0)));
+
+ auto expected = "12:34 error: cannot bitcast to '" +
+ dst.sem(*this)->FriendlyName(Symbols()) + "'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(Scalars,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::ValuesIn(kInvalid)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4,
+ ResolverBitcastValidationTestInvalidDstTy,
+ testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::ValuesIn(kInvalid)));
+
+////////////////////////////////////////////////////////////////////////////////
+// Incompatible bitcast, but both src and dst types are valid
+////////////////////////////////////////////////////////////////////////////////
+using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;
+TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
+ auto src = std::get<0>(GetParam());
+ auto dst = std::get<1>(GetParam());
+ // src and dst are each valid bitcast types, but their widths differ.
+ WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
+
+ auto expected = "12:34 error: cannot bitcast from '" +
+ src.sem(*this)->FriendlyName(Symbols()) + "' to '" +
+ dst.sem(*this)->FriendlyName(Symbols()) + "'";
+
+ EXPECT_FALSE(r()->Resolve());
+ EXPECT_EQ(r()->error(), expected);
+}
+INSTANTIATE_TEST_SUITE_P(
+ ScalarToVec2,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kNumericScalars),
+ testing::ValuesIn(kVec2NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec2ToVec3,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kVec2NumericScalars),
+ testing::ValuesIn(kVec3NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec3ToVec4,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kVec3NumericScalars),
+ testing::ValuesIn(kVec4NumericScalars)));
+INSTANTIATE_TEST_SUITE_P(
+ Vec4ToScalar,
+ ResolverBitcastValidationTestIncompatible,
+ testing::Combine(testing::ValuesIn(kVec4NumericScalars),
+ testing::ValuesIn(kNumericScalars)));
+
+} // namespace
+} // namespace resolver
+} // namespace tint
diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc
index 59e7b44..a1c6c10 100644
--- a/src/resolver/ptr_ref_validation_test.cc
+++ b/src/resolver/ptr_ref_validation_test.cc
@@ -143,19 +143,6 @@
"'ptr<storage, i32, read_write>'");
}
-TEST_F(ResolverTest, Expr_Bitcast_ptr) {
- auto* vf = Var("vf", ty.f32());
- auto* bitcast = create<ast::BitcastExpression>(
- Source{{12, 34}}, ty.pointer<i32>(ast::StorageClass::kFunction),
- Expr("vf"));
- auto* ip =
- Const("ip", ty.pointer<i32>(ast::StorageClass::kFunction), bitcast);
- WrapInFunction(Decl(vf), Decl(ip));
-
- EXPECT_FALSE(r()->Resolve());
- EXPECT_EQ(r()->error(), "12:34 error: cannot cast to a pointer");
-}
-
} // namespace
} // namespace resolver
} // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc
index 3c52f45..44854b5 100644
--- a/src/resolver/resolver.cc
+++ b/src/resolver/resolver.cc
@@ -2408,11 +2408,47 @@
if (!ty) {
return false;
}
- if (ty->Is<sem::Pointer>()) {
- AddError("cannot cast to a pointer", expr->source);
+
+ SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols()));
+
+ if (!ValidateBitcast(expr, ty)) {
return false;
}
- SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols()));
+
+ return true;
+}
+// Validates a bitcast of cast->expr to `to`; on failure adds an error, returns false.
+bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast,
+ const sem::Type* to) {
+ auto TypeNameOf = [this](const sem::Type* ty) {
+ return ty->UnwrapRef()->FriendlyName(builder_->Symbols());
+ };
+ // Source and target must each be a numeric scalar or a numeric vector.
+ auto* from = TypeOf(cast->expr)->UnwrapRef();
+ if (!from->is_numeric_scalar_or_vector()) {
+ AddError("'" + TypeNameOf(from) + "' cannot be bitcast",
+ cast->expr->source);
+ return false;
+ }
+ if (!to->is_numeric_scalar_or_vector()) {
+ AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source);
+ return false;
+ }
+ // Scalars have width 1, so scalar <-> vector bitcasts are rejected below.
+ auto width = [&](const sem::Type* ty) {
+ if (auto* vec = ty->As<sem::Vector>()) {
+ return vec->Width();
+ }
+ return 1u;
+ };
+
+ if (width(from) != width(to)) {
+ AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" +
+ TypeNameOf(to) + "'",
+ cast->source);
+ return false;
+ }
+
+ return true;
+}
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h
index 212360b..0827da8 100644
--- a/src/resolver/resolver.h
+++ b/src/resolver/resolver.h
@@ -275,6 +275,7 @@
bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco,
const sem::Type* storage_type,
const bool is_input);
+ bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to);
bool ValidateCall(const ast::CallExpression* call);
bool ValidateCallStatement(const ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h
index 9f5ff52..bc50d93 100644
--- a/src/resolver/resolver_test_helper.h
+++ b/src/resolver/resolver_test_helper.h
@@ -171,6 +171,9 @@
template <typename TO>
using alias3 = alias<TO, 3>;
+template <typename TO>
+struct ptr {};  // Tag type: selects the DataType<ptr<T>> pointer helper.
+
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
int elem_value);
@@ -387,6 +390,36 @@
}
};
+/// Helper for building `ptr<private, T, read_write>` types and expressions
+template <typename T>
+struct DataType<ptr<T>> {
+ /// true if the pointer type is a composite type
+ static constexpr bool is_composite = false;
+
+ /// @param b the ProgramBuilder
+ /// @return a new AST pointer type to T
+ static inline const ast::Type* AST(ProgramBuilder& b) {
+ return b.create<ast::Pointer>(DataType<T>::AST(b),
+ ast::StorageClass::kPrivate,
+ ast::Access::kReadWrite);
+ }
+ /// @param b the ProgramBuilder
+ /// @return the semantic pointer type to T
+ static inline const sem::Type* Sem(ProgramBuilder& b) {
+ return b.create<sem::Pointer>(DataType<T>::Sem(b),
+ ast::StorageClass::kPrivate,
+ ast::Access::kReadWrite);
+ }
+
+ /// @param b the ProgramBuilder
+ /// @return an AddressOf expression of a new private global of type T
+ static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) {
+ auto sym = b.Symbols().New("global_for_ptr");
+ b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate);
+ return b.AddressOf(sym);
+ }
+};
+
/// Helper for building array types and expressions
template <int N, typename T>
struct DataType<array<N, T>> {
@@ -401,7 +434,14 @@
/// @param b the ProgramBuilder
/// @return the semantic array type
static inline const sem::Type* Sem(ProgramBuilder& b) {
- return b.create<sem::Array>(DataType<T>::Sem(b), N);
+ auto* el = DataType<T>::Sem(b);
+ return b.create<sem::Array>(
+ /* element */ el,
+ /* count */ N,
+ /* align */ el->Align(),
+ /* size */ N * el->Size(),
+ /* stride */ el->Align(),
+ /* implicit_stride */ el->Align());
}
/// @param b the ProgramBuilder
/// @param elem_value the value each element in the array will be initialized
diff --git a/src/utils/unique_vector.h b/src/utils/unique_vector.h
index 1ae751a..f2c880b 100644
--- a/src/utils/unique_vector.h
+++ b/src/utils/unique_vector.h
@@ -15,6 +15,7 @@
#ifndef SRC_UTILS_UNIQUE_VECTOR_H_
#define SRC_UTILS_UNIQUE_VECTOR_H_
+#include <cstddef>
#include <unordered_set>
#include <vector>
diff --git a/src/writer/glsl/generator_impl_binary_test.cc b/src/writer/glsl/generator_impl_binary_test.cc
index c1b6724..6f9554d 100644
--- a/src/writer/glsl/generator_impl_binary_test.cc
+++ b/src/writer/glsl/generator_impl_binary_test.cc
@@ -464,36 +464,6 @@
)");
}
-TEST_F(GlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
- // as<i32>(a && (b || c))
-
- Global("a", ty.bool_(), ast::StorageClass::kPrivate);
- Global("b", ty.bool_(), ast::StorageClass::kPrivate);
- Global("c", ty.bool_(), ast::StorageClass::kPrivate);
-
- auto* expr = create<ast::BitcastExpression>(
- ty.i32(), create<ast::BinaryExpression>(
- ast::BinaryOp::kLogicalAnd, Expr("a"),
- create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
- Expr("b"), Expr("c"))));
- WrapInFunction(expr);
-
- GeneratorImpl& gen = Build();
-
- std::stringstream out;
- ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
-if (tint_tmp) {
- bool tint_tmp_1 = b;
- if (!tint_tmp_1) {
- tint_tmp_1 = c;
- }
- tint_tmp = (tint_tmp_1);
-}
-)");
- EXPECT_EQ(out.str(), R"(asint((tint_tmp)))");
-}
-
TEST_F(GlslGeneratorImplTest_Binary, Call_WithLogical) {
// foo(a && b, c || d, (a || c) && (b || d))
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc
index f7184ac..3e9b7f1 100644
--- a/src/writer/hlsl/generator_impl_binary_test.cc
+++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -464,36 +464,6 @@
)");
}
-TEST_F(HlslGeneratorImplTest_Binary, Bitcast_WithLogical) {
- // as<i32>(a && (b || c))
-
- Global("a", ty.bool_(), ast::StorageClass::kPrivate);
- Global("b", ty.bool_(), ast::StorageClass::kPrivate);
- Global("c", ty.bool_(), ast::StorageClass::kPrivate);
-
- auto* expr = create<ast::BitcastExpression>(
- ty.i32(), create<ast::BinaryExpression>(
- ast::BinaryOp::kLogicalAnd, Expr("a"),
- create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr,
- Expr("b"), Expr("c"))));
- WrapInFunction(expr);
-
- GeneratorImpl& gen = Build();
-
- std::stringstream out;
- ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
- EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
-if (tint_tmp) {
- bool tint_tmp_1 = b;
- if (!tint_tmp_1) {
- tint_tmp_1 = c;
- }
- tint_tmp = (tint_tmp_1);
-}
-)");
- EXPECT_EQ(out.str(), R"(asint((tint_tmp)))");
-}
-
TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
// foo(a && b, c || d, (a || c) && (b || d))
diff --git a/test/BUILD.gn b/test/BUILD.gn
index edd66b9..f5a6ad4 100644
--- a/test/BUILD.gn
+++ b/test/BUILD.gn
@@ -235,6 +235,7 @@
"../src/resolver/assignment_validation_test.cc",
"../src/resolver/atomics_test.cc",
"../src/resolver/atomics_validation_test.cc",
+ "../src/resolver/bitcast_validation_test.cc",
"../src/resolver/builtins_validation_test.cc",
"../src/resolver/call_test.cc",
"../src/resolver/call_validation_test.cc",