| // Copyright 2021 The Tint Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "src/tint/ast/bitcast_expression.h" |
| #include "src/tint/resolver/resolver.h" |
| #include "src/tint/resolver/resolver_test_helper.h" |
| |
| #include "gmock/gmock.h" |
| |
| namespace tint::resolver { |
| namespace { |
| |
| /// Bundles the builder callbacks for one test type: building its AST type |
| /// node, building its semantic type, and building an expression of that type. |
| struct Type { |
|     /// @returns a Type whose callbacks are taken from builder::DataType<T> |
|     template <typename T> |
|     static constexpr Type Create() { |
|         using DT = builder::DataType<T>; |
|         return {DT::AST, DT::Sem, DT::ExprFromDouble}; |
|     } |
| |
|     /// Builds the AST type node for this type. |
|     builder::ast_type_func_ptr ast; |
|     /// Builds the semantic type for this type. |
|     builder::sem_type_func_ptr sem; |
|     /// Builds an expression of this type from a double (the tests here pass 0). |
|     builder::ast_expr_from_double_func_ptr expr; |
| }; |
| |
| // Type tables used to parameterize the tests below. The element order of each |
| // table determines the numeric index of every generated gtest instance. |
| |
| /// Numeric scalar types. Every pair of these is expected to bitcast successfully. |
| static constexpr Type kNumericScalars[] = { |
|     Type::Create<f32>(), |
|     Type::Create<i32>(), |
|     Type::Create<u32>(), |
| }; |
| /// 2-element vectors of the numeric scalar types. Every pair of these is |
| /// expected to bitcast successfully. |
| static constexpr Type kVec2NumericScalars[] = { |
|     Type::Create<builder::vec2<f32>>(), |
|     Type::Create<builder::vec2<i32>>(), |
|     Type::Create<builder::vec2<u32>>(), |
| }; |
| /// 3-element vectors of the numeric scalar types. Every pair of these is |
| /// expected to bitcast successfully. |
| static constexpr Type kVec3NumericScalars[] = { |
|     Type::Create<builder::vec3<f32>>(), |
|     Type::Create<builder::vec3<i32>>(), |
|     Type::Create<builder::vec3<u32>>(), |
| }; |
| /// 4-element vectors of the numeric scalar types. Every pair of these is |
| /// expected to bitcast successfully. |
| static constexpr Type kVec4NumericScalars[] = { |
|     Type::Create<builder::vec4<f32>>(), |
|     Type::Create<builder::vec4<i32>>(), |
|     Type::Create<builder::vec4<u32>>(), |
| }; |
| /// Types that must be rejected when used as either the source or the target |
| /// of a bitcast. |
| static constexpr Type kInvalid[] = { |
|     // A non-exhaustive selection of uncastable types: bool scalars and vectors, |
|     // arrays, matrices, and pointers. |
|     Type::Create<bool>(), |
|     Type::Create<builder::vec2<bool>>(), |
|     Type::Create<builder::vec3<bool>>(), |
|     Type::Create<builder::vec4<bool>>(), |
|     Type::Create<builder::array<2, i32>>(), |
|     Type::Create<builder::array<3, u32>>(), |
|     Type::Create<builder::array<4, f32>>(), |
|     Type::Create<builder::array<5, bool>>(), |
|     Type::Create<builder::mat2x2<f32>>(), |
|     Type::Create<builder::mat3x3<f32>>(), |
|     Type::Create<builder::mat4x4<f32>>(), |
|     Type::Create<builder::ptr<i32>>(), |
|     Type::Create<builder::ptr<builder::array<2, i32>>>(), |
|     Type::Create<builder::ptr<builder::mat2x2<f32>>>(), |
| }; |
| |
| /// Base fixture for all bitcast validation tests. The parameter is a |
| /// (source type, destination type) pair: element 0 is the type of the bitcast |
| /// operand and element 1 is the bitcast's target type. |
| using ResolverBitcastValidationTest = ResolverTestWithParam<std::tuple<Type, Type>>; |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Valid bitcasts |
| //////////////////////////////////////////////////////////////////////////////// |
| using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest; |
| // Bitcasting between two numeric types of the same shape (scalar to scalar, or |
| // vecN to vecN) must resolve. The bitcast expression must take the destination |
| // type. |
| TEST_P(ResolverBitcastValidationTestPass, Test) { |
|     const auto& [from, to] = GetParam(); |
| |
|     auto* bitcast = Bitcast(to.ast(*this), from.expr(*this, 0)); |
|     WrapInFunction(bitcast); |
| |
|     // Resolution must succeed; print the diagnostic if it does not. |
|     ASSERT_TRUE(r()->Resolve()) << r()->error(); |
|     EXPECT_EQ(TypeOf(bitcast), to.sem(*this)); |
| } |
| INSTANTIATE_TEST_SUITE_P(Scalars, |
|                          ResolverBitcastValidationTestPass, |
|                          testing::Combine(testing::ValuesIn(kNumericScalars), |
|                                           testing::ValuesIn(kNumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec2, |
|                          ResolverBitcastValidationTestPass, |
|                          testing::Combine(testing::ValuesIn(kVec2NumericScalars), |
|                                           testing::ValuesIn(kVec2NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec3, |
|                          ResolverBitcastValidationTestPass, |
|                          testing::Combine(testing::ValuesIn(kVec3NumericScalars), |
|                                           testing::ValuesIn(kVec3NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec4, |
|                          ResolverBitcastValidationTestPass, |
|                          testing::Combine(testing::ValuesIn(kVec4NumericScalars), |
|                                           testing::ValuesIn(kVec4NumericScalars))); |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Invalid source type for bitcasts |
| //////////////////////////////////////////////////////////////////////////////// |
| using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest; |
| // Bitcasting an operand whose type comes from kInvalid must fail. The error |
| // must be reported at the operand's source location. |
| TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) { |
|     const auto& [from, to] = GetParam(); |
| |
|     // The operand is an identifier naming a `let`, so it can carry a Source; |
|     // the `let` holds the value of the uncastable type. |
|     const Source operand_source{{12, 34}}; |
|     auto* bitcast = Bitcast(to.ast(*this), Expr(operand_source, "src")); |
|     auto* decl = Let("src", from.expr(*this, 0)); |
|     WrapInFunction(decl, bitcast); |
| |
|     const auto expected_error = |
|         "12:34 error: '" + from.sem(*this)->FriendlyName(Symbols()) + "' cannot be bitcast"; |
| |
|     EXPECT_FALSE(r()->Resolve()); |
|     EXPECT_EQ(r()->error(), expected_error); |
| } |
| INSTANTIATE_TEST_SUITE_P(Scalars, |
|                          ResolverBitcastValidationTestInvalidSrcTy, |
|                          testing::Combine(testing::ValuesIn(kInvalid), |
|                                           testing::ValuesIn(kNumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec2, |
|                          ResolverBitcastValidationTestInvalidSrcTy, |
|                          testing::Combine(testing::ValuesIn(kInvalid), |
|                                           testing::ValuesIn(kVec2NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec3, |
|                          ResolverBitcastValidationTestInvalidSrcTy, |
|                          testing::Combine(testing::ValuesIn(kInvalid), |
|                                           testing::ValuesIn(kVec3NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec4, |
|                          ResolverBitcastValidationTestInvalidSrcTy, |
|                          testing::Combine(testing::ValuesIn(kInvalid), |
|                                           testing::ValuesIn(kVec4NumericScalars))); |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Invalid target type for bitcasts |
| //////////////////////////////////////////////////////////////////////////////// |
| using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest; |
| // Bitcasting a numeric value to a target type from kInvalid must fail. The |
| // error must be reported at the target type's source location. |
| TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) { |
|     const auto& [from, to] = GetParam(); |
| |
|     // The `ast` callback takes no Source, so declare an alias `T` for the |
|     // target type and name it through a type_name that does carry a Source. |
|     Alias("T", to.ast(*this)); |
|     auto* target = ty.type_name(Source{{12, 34}}, "T"); |
|     WrapInFunction(Bitcast(target, from.expr(*this, 0))); |
| |
|     const auto expected_error = |
|         "12:34 error: cannot bitcast to '" + to.sem(*this)->FriendlyName(Symbols()) + "'"; |
| |
|     EXPECT_FALSE(r()->Resolve()); |
|     EXPECT_EQ(r()->error(), expected_error); |
| } |
| INSTANTIATE_TEST_SUITE_P(Scalars, |
|                          ResolverBitcastValidationTestInvalidDstTy, |
|                          testing::Combine(testing::ValuesIn(kNumericScalars), |
|                                           testing::ValuesIn(kInvalid))); |
| INSTANTIATE_TEST_SUITE_P(Vec2, |
|                          ResolverBitcastValidationTestInvalidDstTy, |
|                          testing::Combine(testing::ValuesIn(kVec2NumericScalars), |
|                                           testing::ValuesIn(kInvalid))); |
| INSTANTIATE_TEST_SUITE_P(Vec3, |
|                          ResolverBitcastValidationTestInvalidDstTy, |
|                          testing::Combine(testing::ValuesIn(kVec3NumericScalars), |
|                                           testing::ValuesIn(kInvalid))); |
| INSTANTIATE_TEST_SUITE_P(Vec4, |
|                          ResolverBitcastValidationTestInvalidDstTy, |
|                          testing::Combine(testing::ValuesIn(kVec4NumericScalars), |
|                                           testing::ValuesIn(kInvalid))); |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Incompatible bitcast, but both src and dst types are valid |
| //////////////////////////////////////////////////////////////////////////////// |
| using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest; |
| // Bitcasting between two numeric types that are each castable but have a |
| // different number of elements must fail. This covers scalar vs. vecN, and |
| // vecN vs. vecM with N != M. The error must be reported at the source location |
| // of the bitcast expression itself. |
| TEST_P(ResolverBitcastValidationTestIncompatible, Test) { |
|     auto src = std::get<0>(GetParam()); |
|     auto dst = std::get<1>(GetParam()); |
| |
|     WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0))); |
| |
|     auto expected = "12:34 error: cannot bitcast from '" + src.sem(*this)->FriendlyName(Symbols()) + |
|                     "' to '" + dst.sem(*this)->FriendlyName(Symbols()) + "'"; |
| |
|     EXPECT_FALSE(r()->Resolve()); |
|     EXPECT_EQ(r()->error(), expected); |
| } |
| // Instantiate every ordered pair of distinct element counts (scalar, vec2, |
| // vec3, vec4). A check that rejected only some mismatches would otherwise go |
| // unnoticed. |
| INSTANTIATE_TEST_SUITE_P(ScalarToVec2, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kNumericScalars), |
|                                           testing::ValuesIn(kVec2NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(ScalarToVec3, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kNumericScalars), |
|                                           testing::ValuesIn(kVec3NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(ScalarToVec4, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kNumericScalars), |
|                                           testing::ValuesIn(kVec4NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec2ToScalar, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec2NumericScalars), |
|                                           testing::ValuesIn(kNumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec2ToVec3, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec2NumericScalars), |
|                                           testing::ValuesIn(kVec3NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec2ToVec4, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec2NumericScalars), |
|                                           testing::ValuesIn(kVec4NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec3ToScalar, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec3NumericScalars), |
|                                           testing::ValuesIn(kNumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec3ToVec2, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec3NumericScalars), |
|                                           testing::ValuesIn(kVec2NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec3ToVec4, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec3NumericScalars), |
|                                           testing::ValuesIn(kVec4NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec4ToScalar, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec4NumericScalars), |
|                                           testing::ValuesIn(kNumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec4ToVec2, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec4NumericScalars), |
|                                           testing::ValuesIn(kVec2NumericScalars))); |
| INSTANTIATE_TEST_SUITE_P(Vec4ToVec3, |
|                          ResolverBitcastValidationTestIncompatible, |
|                          testing::Combine(testing::ValuesIn(kVec4NumericScalars), |
|                                           testing::ValuesIn(kVec3NumericScalars))); |
| |
| } // namespace |
| } // namespace tint::resolver |