Manual cherry-pick of https://dawn-review.googlesource.com/c/tint/+/71681 Original commit message (modified): resolver: Fixes for bitcasts Add missing validation for bitcasts. We were permitting any bitcast whose target type wasn't a pointer type, when the spec only allows: * numeric_scalar to numeric_scalar * vecN<numeric_scalar> to vecN<numeric_scalar> Add lots of tests. Fixed: chromium:1276320 Change-Id: Iaaed4759234be1ed739e3c016d27679bde081ddc Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/72800 Reviewed-by: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 97e92f6..c92e231 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt
@@ -666,6 +666,7 @@ resolver/assignment_validation_test.cc resolver/atomics_test.cc resolver/atomics_validation_test.cc + resolver/bitcast_validation_test.cc resolver/builtins_validation_test.cc resolver/call_test.cc resolver/call_validation_test.cc
diff --git a/src/resolver/bitcast_validation_test.cc b/src/resolver/bitcast_validation_test.cc new file mode 100644 index 0000000..d4ce082 --- /dev/null +++ b/src/resolver/bitcast_validation_test.cc
@@ -0,0 +1,228 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/bitcast_expression.h" +#include "src/resolver/resolver.h" +#include "src/resolver/resolver_test_helper.h" + +#include "gmock/gmock.h" + +namespace tint { +namespace resolver { +namespace { + +struct Type {  // Builders for the AST type, sem type and a value of one type + template <typename T> + static constexpr Type Create() { + return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem, + builder::DataType<T>::Expr}; + } + + builder::ast_type_func_ptr ast; + builder::sem_type_func_ptr sem; + builder::ast_expr_func_ptr expr; +}; + +static constexpr Type kNumericScalars[] = { + Type::Create<builder::f32>(), + Type::Create<builder::i32>(), + Type::Create<builder::u32>(), +}; +static constexpr Type kVec2NumericScalars[] = { + Type::Create<builder::vec2<builder::f32>>(), + Type::Create<builder::vec2<builder::i32>>(), + Type::Create<builder::vec2<builder::u32>>(), +}; +static constexpr Type kVec3NumericScalars[] = { + Type::Create<builder::vec3<builder::f32>>(), + Type::Create<builder::vec3<builder::i32>>(), + Type::Create<builder::vec3<builder::u32>>(), +}; +static constexpr Type kVec4NumericScalars[] = { + Type::Create<builder::vec4<builder::f32>>(), + Type::Create<builder::vec4<builder::i32>>(), + Type::Create<builder::vec4<builder::u32>>(), +}; +static constexpr Type kInvalid[] = { + // A non-exhaustive selection of uncastable types (not numeric scalars/vectors) + Type::Create<bool>(), + 
Type::Create<builder::vec2<bool>>(), + Type::Create<builder::vec3<bool>>(), + Type::Create<builder::vec4<bool>>(), + Type::Create<builder::array<2, builder::i32>>(), + Type::Create<builder::array<3, builder::u32>>(), + Type::Create<builder::array<4, builder::f32>>(), + Type::Create<builder::array<5, bool>>(), + Type::Create<builder::mat2x2<builder::f32>>(), + Type::Create<builder::mat3x3<builder::f32>>(), + Type::Create<builder::mat4x4<builder::f32>>(), + Type::Create<builder::ptr<builder::i32>>(), + Type::Create<builder::ptr<builder::array<2, builder::i32>>>(), + Type::Create<builder::ptr<builder::mat2x2<builder::f32>>>(), +}; + +using ResolverBitcastValidationTest = + ResolverTestWithParam<std::tuple<Type, Type>>; + +//////////////////////////////////////////////////////////////////////////////// +// Valid bitcasts: numeric scalar to scalar, vecN to vecN of the same N +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestPass, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + auto* cast = Bitcast(dst.ast(*this), src.expr(*this, 0)); + WrapInFunction(cast); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + EXPECT_EQ(TypeOf(cast), dst.sem(*this)); +} +INSTANTIATE_TEST_SUITE_P(Scalars, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kNumericScalars), + testing::ValuesIn(kNumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec2, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kVec2NumericScalars), + testing::ValuesIn(kVec2NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec3, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kVec3NumericScalars), + testing::ValuesIn(kVec3NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec4, + ResolverBitcastValidationTestPass, + testing::Combine(testing::ValuesIn(kVec4NumericScalars), + testing::ValuesIn(kVec4NumericScalars))); + 
+//////////////////////////////////////////////////////////////////////////////// +// Invalid source type for bitcasts (not a numeric scalar or vector) +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + auto* cast = Bitcast(dst.ast(*this), Expr(Source{{12, 34}}, "src")); + WrapInFunction(Const("src", nullptr, src.expr(*this, 0)), cast); + + auto expected = "12:34 error: '" + src.sem(*this)->FriendlyName(Symbols()) + + "' cannot be bitcast"; + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), expected); +} +INSTANTIATE_TEST_SUITE_P(Scalars, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kNumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec2, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kVec2NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec3, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kVec3NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec4, + ResolverBitcastValidationTestInvalidSrcTy, + testing::Combine(testing::ValuesIn(kInvalid), + testing::ValuesIn(kVec4NumericScalars))); + +//////////////////////////////////////////////////////////////////////////////// +// Invalid target type for bitcasts (not a numeric scalar or vector) +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + // Use an alias so we can put a Source on the bitcast type + Alias("T", dst.ast(*this)); + WrapInFunction( + Bitcast(ty.type_name(Source{{12, 34}}, "T"), 
src.expr(*this, 0))); + + auto expected = "12:34 error: cannot bitcast to '" + + dst.sem(*this)->FriendlyName(Symbols()) + "'"; + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), expected); +} +INSTANTIATE_TEST_SUITE_P(Scalars, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kNumericScalars), + testing::ValuesIn(kInvalid))); +INSTANTIATE_TEST_SUITE_P( + Vec2, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kVec2NumericScalars), + testing::ValuesIn(kInvalid))); +INSTANTIATE_TEST_SUITE_P( + Vec3, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kVec3NumericScalars), + testing::ValuesIn(kInvalid))); +INSTANTIATE_TEST_SUITE_P( + Vec4, + ResolverBitcastValidationTestInvalidDstTy, + testing::Combine(testing::ValuesIn(kVec4NumericScalars), + testing::ValuesIn(kInvalid))); + +//////////////////////////////////////////////////////////////////////////////// +// Incompatible bitcast: src and dst types are valid, but widths differ +//////////////////////////////////////////////////////////////////////////////// +using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest; +TEST_P(ResolverBitcastValidationTestIncompatible, Test) { + auto src = std::get<0>(GetParam()); + auto dst = std::get<1>(GetParam()); + + WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0))); + + auto expected = "12:34 error: cannot bitcast from '" + + src.sem(*this)->FriendlyName(Symbols()) + "' to '" + + dst.sem(*this)->FriendlyName(Symbols()) + "'"; + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), expected); +} +INSTANTIATE_TEST_SUITE_P( + ScalarToVec2, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kNumericScalars), + testing::ValuesIn(kVec2NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec2ToVec3, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kVec2NumericScalars), + 
testing::ValuesIn(kVec3NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec3ToVec4, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kVec3NumericScalars), + testing::ValuesIn(kVec4NumericScalars))); +INSTANTIATE_TEST_SUITE_P( + Vec4ToScalar, + ResolverBitcastValidationTestIncompatible, + testing::Combine(testing::ValuesIn(kVec4NumericScalars), + testing::ValuesIn(kNumericScalars))); + +} // namespace +} // namespace resolver +} // namespace tint
diff --git a/src/resolver/ptr_ref_validation_test.cc b/src/resolver/ptr_ref_validation_test.cc index 59e7b44..a1c6c10 100644 --- a/src/resolver/ptr_ref_validation_test.cc +++ b/src/resolver/ptr_ref_validation_test.cc
@@ -143,19 +143,6 @@ "'ptr<storage, i32, read_write>'"); } -TEST_F(ResolverTest, Expr_Bitcast_ptr) { - auto* vf = Var("vf", ty.f32()); - auto* bitcast = create<ast::BitcastExpression>( - Source{{12, 34}}, ty.pointer<i32>(ast::StorageClass::kFunction), - Expr("vf")); - auto* ip = - Const("ip", ty.pointer<i32>(ast::StorageClass::kFunction), bitcast); - WrapInFunction(Decl(vf), Decl(ip)); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: cannot cast to a pointer"); -} - } // namespace } // namespace resolver } // namespace tint
diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 3c52f45..44854b5 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc
@@ -2408,11 +2408,47 @@ if (!ty) { return false; } - if (ty->Is<sem::Pointer>()) { - AddError("cannot cast to a pointer", expr->source); + + SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols())); + + if (!ValidateBitcast(expr, ty)) { return false; } - SetExprInfo(expr, ty, expr->type->FriendlyName(builder_->Symbols())); + + return true; +} + +bool Resolver::ValidateBitcast(const ast::BitcastExpression* cast, + const sem::Type* to) { /* Adds an error and returns false for an invalid bitcast. */ + auto TypeNameOf = [this](const sem::Type* ty) { /* name used in diagnostics */ + return ty->UnwrapRef()->FriendlyName(builder_->Symbols()); + }; + + auto* from = TypeOf(cast->expr)->UnwrapRef(); /* source type, reference stripped */ + if (!from->is_numeric_scalar_or_vector()) { /* bad source type */ + AddError("'" + TypeNameOf(from) + "' cannot be bitcast", + cast->expr->source); + return false; + } + if (!to->is_numeric_scalar_or_vector()) { /* bad target type */ + AddError("cannot bitcast to '" + TypeNameOf(to) + "'", cast->type->source); + return false; + } + + auto width = [&](const sem::Type* ty) { /* vector width, or 1 for a scalar */ + if (auto* vec = ty->As<sem::Vector>()) { + return vec->Width(); + } + return 1u; + }; + + if (width(from) != width(to)) { /* only scalar<->scalar or vecN<->vecN */ + AddError("cannot bitcast from '" + TypeNameOf(from) + "' to '" + + TypeNameOf(to) + "'", + cast->source); + return false; + } + return true; }
diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 212360b..0827da8 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h
@@ -275,6 +275,7 @@ bool ValidateBuiltinDecoration(const ast::BuiltinDecoration* deco, const sem::Type* storage_type, const bool is_input); + bool ValidateBitcast(const ast::BitcastExpression* cast, const sem::Type* to); bool ValidateCall(const ast::CallExpression* call); bool ValidateCallStatement(const ast::CallStatement* stmt); bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index 9f5ff52..bc50d93 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h
@@ -171,6 +171,9 @@ template <typename TO> using alias3 = alias<TO, 3>; +template <typename TO> +struct ptr {}; + using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, int elem_value); @@ -387,6 +390,36 @@ } }; +/// Helper for building pointer types and expressions +template <typename T> +struct DataType<ptr<T>> { + /// true if the pointer type is a composite type + static constexpr bool is_composite = false; + + /// @param b the ProgramBuilder + /// @return a new AST pointer type: ptr<private, T, read_write> + static inline const ast::Type* AST(ProgramBuilder& b) { + return b.create<ast::Pointer>(DataType<T>::AST(b), + ast::StorageClass::kPrivate, + ast::Access::kReadWrite); + } + /// @param b the ProgramBuilder + /// @return the semantic pointer type: ptr<private, T, read_write> + static inline const sem::Type* Sem(ProgramBuilder& b) { + return b.create<sem::Pointer>(DataType<T>::Sem(b), + ast::StorageClass::kPrivate, + ast::Access::kReadWrite); + } + + /// @param b the ProgramBuilder + /// @return an AddressOf expression of a new private global of type T + static inline const ast::Expression* Expr(ProgramBuilder& b, int /*unused*/) { + auto sym = b.Symbols().New("global_for_ptr"); + b.Global(sym, DataType<T>::AST(b), ast::StorageClass::kPrivate); + return b.AddressOf(sym); + } +}; + /// Helper for building array types and expressions template <int N, typename T> struct DataType<array<N, T>> { @@ -401,7 +434,14 @@ /// @param b the ProgramBuilder /// @return the semantic array type static inline const sem::Type* Sem(ProgramBuilder& b) { - return b.create<sem::Array>(DataType<T>::Sem(b), N); + auto* el = DataType<T>::Sem(b); + return b.create<sem::Array>( + /* element */ el, + /* count */ N, + /* align */ el->Align(), + /* size = count * stride */ N * el->Align(), + /* stride */ el->Align(), + /* implicit_stride */ el->Align()); } /// @param b the ProgramBuilder /// @param elem_value the value each element in the array will be initialized
diff --git a/src/utils/unique_vector.h b/src/utils/unique_vector.h index 1ae751a..f2c880b 100644 --- a/src/utils/unique_vector.h +++ b/src/utils/unique_vector.h
@@ -15,6 +15,7 @@ #ifndef SRC_UTILS_UNIQUE_VECTOR_H_ #define SRC_UTILS_UNIQUE_VECTOR_H_ +#include <cstddef> #include <unordered_set> #include <vector>
diff --git a/src/writer/glsl/generator_impl_binary_test.cc b/src/writer/glsl/generator_impl_binary_test.cc index c1b6724..6f9554d 100644 --- a/src/writer/glsl/generator_impl_binary_test.cc +++ b/src/writer/glsl/generator_impl_binary_test.cc
@@ -464,36 +464,6 @@ )"); } -TEST_F(GlslGeneratorImplTest_Binary, Bitcast_WithLogical) { - // as<i32>(a && (b || c)) - - Global("a", ty.bool_(), ast::StorageClass::kPrivate); - Global("b", ty.bool_(), ast::StorageClass::kPrivate); - Global("c", ty.bool_(), ast::StorageClass::kPrivate); - - auto* expr = create<ast::BitcastExpression>( - ty.i32(), create<ast::BinaryExpression>( - ast::BinaryOp::kLogicalAnd, Expr("a"), - create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, - Expr("b"), Expr("c")))); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - std::stringstream out; - ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(gen.result(), R"(bool tint_tmp = a; -if (tint_tmp) { - bool tint_tmp_1 = b; - if (!tint_tmp_1) { - tint_tmp_1 = c; - } - tint_tmp = (tint_tmp_1); -} -)"); - EXPECT_EQ(out.str(), R"(asint((tint_tmp)))"); -} - TEST_F(GlslGeneratorImplTest_Binary, Call_WithLogical) { // foo(a && b, c || d, (a || c) && (b || d))
diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index f7184ac..3e9b7f1 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc
@@ -464,36 +464,6 @@ )"); } -TEST_F(HlslGeneratorImplTest_Binary, Bitcast_WithLogical) { - // as<i32>(a && (b || c)) - - Global("a", ty.bool_(), ast::StorageClass::kPrivate); - Global("b", ty.bool_(), ast::StorageClass::kPrivate); - Global("c", ty.bool_(), ast::StorageClass::kPrivate); - - auto* expr = create<ast::BitcastExpression>( - ty.i32(), create<ast::BinaryExpression>( - ast::BinaryOp::kLogicalAnd, Expr("a"), - create<ast::BinaryExpression>(ast::BinaryOp::kLogicalOr, - Expr("b"), Expr("c")))); - WrapInFunction(expr); - - GeneratorImpl& gen = Build(); - - std::stringstream out; - ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error(); - EXPECT_EQ(gen.result(), R"(bool tint_tmp = a; -if (tint_tmp) { - bool tint_tmp_1 = b; - if (!tint_tmp_1) { - tint_tmp_1 = c; - } - tint_tmp = (tint_tmp_1); -} -)"); - EXPECT_EQ(out.str(), R"(asint((tint_tmp)))"); -} - TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) { // foo(a && b, c || d, (a || c) && (b || d))
diff --git a/test/BUILD.gn b/test/BUILD.gn index edd66b9..f5a6ad4 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn
@@ -235,6 +235,7 @@ "../src/resolver/assignment_validation_test.cc", "../src/resolver/atomics_test.cc", "../src/resolver/atomics_validation_test.cc", + "../src/resolver/bitcast_validation_test.cc", "../src/resolver/builtins_validation_test.cc", "../src/resolver/call_test.cc", "../src/resolver/call_validation_test.cc",