blob: c979c55747e4597bf69d6b6349756a8575d8510e [file] [log] [blame]
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/ast/bitcast_expression.h"
#include "src/tint/resolver/resolver.h"
#include "src/tint/resolver/resolver_test_helper.h"
#include "gmock/gmock.h"
namespace tint::resolver {
namespace {
using namespace tint::builtin::fluent_types; // NOLINT
struct Type {
template <typename T>
static constexpr Type Create() {
return Type{builder::DataType<T>::AST, builder::DataType<T>::Sem,
builder::DataType<T>::ExprFromDouble};
}
builder::ast_type_func_ptr ast;
builder::sem_type_func_ptr sem;
builder::ast_expr_from_double_func_ptr expr;
};
// Scalar f32, i32 and u32. Any ordered pair of these is expected to bitcast
// successfully (see the Pass "Scalars" instantiation).
static constexpr Type kNumericScalars[] = {
Type::Create<f32>(),
Type::Create<i32>(),
Type::Create<u32>(),
};
// 2-component vectors of f32, i32 and u32. Any ordered pair is a valid bitcast.
static constexpr Type kVec2NumericScalars[] = {
Type::Create<vec2<f32>>(),
Type::Create<vec2<i32>>(),
Type::Create<vec2<u32>>(),
};
// 3-component vectors of f32, i32 and u32. Any ordered pair is a valid bitcast.
static constexpr Type kVec3NumericScalars[] = {
Type::Create<vec3<f32>>(),
Type::Create<vec3<i32>>(),
Type::Create<vec3<u32>>(),
};
// 4-component vectors of f32, i32 and u32. Any ordered pair is a valid bitcast.
static constexpr Type kVec4NumericScalars[] = {
Type::Create<vec4<f32>>(),
Type::Create<vec4<i32>>(),
Type::Create<vec4<u32>>(),
};
// Types that must be rejected both as a bitcast source and as a bitcast target:
// bool scalars/vectors, arrays, matrices and pointers.
static constexpr Type kInvalid[] = {
// A non-exhaustive selection of uncastable types
Type::Create<bool>(),
Type::Create<vec2<bool>>(),
Type::Create<vec3<bool>>(),
Type::Create<vec4<bool>>(),
Type::Create<array<i32, 2>>(),
Type::Create<array<u32, 3>>(),
Type::Create<array<f32, 4>>(),
Type::Create<array<bool, 5>>(),
Type::Create<mat2x2<f32>>(),
Type::Create<mat3x3<f32>>(),
Type::Create<mat4x4<f32>>(),
Type::Create<ptr<private_, i32>>(),
Type::Create<ptr<private_, array<i32, 2>>>(),
Type::Create<ptr<private_, mat2x2<f32>>>(),
};
// Shared fixture for every bitcast test in this file.
// The parameter tuple is (source type, destination type).
using ResolverBitcastValidationTest = ResolverTestWithParam<std::tuple<Type, Type>>;

////////////////////////////////////////////////////////////////////////////////
// Valid bitcasts
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestPass = ResolverBitcastValidationTest;

// A bitcast between two numeric types of the same shape must resolve, and the
// resulting expression must have the destination type.
TEST_P(ResolverBitcastValidationTestPass, Test) {
    const auto& [src, dst] = GetParam();

    auto* const bitcast = Bitcast(dst.ast(*this), src.expr(*this, 0));
    WrapInFunction(bitcast);

    ASSERT_TRUE(r()->Resolve()) << r()->error();
    EXPECT_EQ(TypeOf(bitcast), dst.sem(*this));
}

INSTANTIATE_TEST_SUITE_P(Scalars,
                         ResolverBitcastValidationTestPass,
                         testing::Combine(testing::ValuesIn(kNumericScalars),
                                          testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec2,
                         ResolverBitcastValidationTestPass,
                         testing::Combine(testing::ValuesIn(kVec2NumericScalars),
                                          testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec3,
                         ResolverBitcastValidationTestPass,
                         testing::Combine(testing::ValuesIn(kVec3NumericScalars),
                                          testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec4,
                         ResolverBitcastValidationTestPass,
                         testing::Combine(testing::ValuesIn(kVec4NumericScalars),
                                          testing::ValuesIn(kVec4NumericScalars)));
////////////////////////////////////////////////////////////////////////////////
// Invalid source type for bitcasts
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestInvalidSrcTy = ResolverBitcastValidationTest;

// Bitcasting a value whose type comes from kInvalid must fail to resolve.
// The diagnostic must be reported at the operand's source location.
TEST_P(ResolverBitcastValidationTestInvalidSrcTy, Test) {
    const auto& [src, dst] = GetParam();

    // The value is bound to a `let` and referenced by name, so the operand
    // expression can be given an explicit Source for the expected error.
    auto* const operand = Expr(Source{{12, 34}}, "src");
    auto* const bitcast = Bitcast(dst.ast(*this), operand);
    WrapInFunction(Let("src", src.expr(*this, 0)), bitcast);

    const auto expected =
        "12:34 error: '" + src.sem(*this)->FriendlyName() + "' cannot be bitcast";
    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), expected);
}

INSTANTIATE_TEST_SUITE_P(Scalars,
                         ResolverBitcastValidationTestInvalidSrcTy,
                         testing::Combine(testing::ValuesIn(kInvalid),
                                          testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec2,
                         ResolverBitcastValidationTestInvalidSrcTy,
                         testing::Combine(testing::ValuesIn(kInvalid),
                                          testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec3,
                         ResolverBitcastValidationTestInvalidSrcTy,
                         testing::Combine(testing::ValuesIn(kInvalid),
                                          testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec4,
                         ResolverBitcastValidationTestInvalidSrcTy,
                         testing::Combine(testing::ValuesIn(kInvalid),
                                          testing::ValuesIn(kVec4NumericScalars)));
////////////////////////////////////////////////////////////////////////////////
// Invalid target type for bitcasts
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestInvalidDstTy = ResolverBitcastValidationTest;

// Bitcasting a numeric value to a type from kInvalid must fail to resolve.
// The diagnostic must be reported at the target type's source location.
TEST_P(ResolverBitcastValidationTestInvalidDstTy, Test) {
    const auto& [src, dst] = GetParam();

    // The target is named through the alias `T`, so the type reference can
    // carry an explicit Source for the expected error.
    Alias("T", dst.ast(*this));
    const auto target = ty(Source{{12, 34}}, "T");
    WrapInFunction(Bitcast(target, src.expr(*this, 0)));

    const auto expected =
        "12:34 error: cannot bitcast to '" + dst.sem(*this)->FriendlyName() + "'";
    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), expected);
}

INSTANTIATE_TEST_SUITE_P(Scalars,
                         ResolverBitcastValidationTestInvalidDstTy,
                         testing::Combine(testing::ValuesIn(kNumericScalars),
                                          testing::ValuesIn(kInvalid)));
INSTANTIATE_TEST_SUITE_P(Vec2,
                         ResolverBitcastValidationTestInvalidDstTy,
                         testing::Combine(testing::ValuesIn(kVec2NumericScalars),
                                          testing::ValuesIn(kInvalid)));
INSTANTIATE_TEST_SUITE_P(Vec3,
                         ResolverBitcastValidationTestInvalidDstTy,
                         testing::Combine(testing::ValuesIn(kVec3NumericScalars),
                                          testing::ValuesIn(kInvalid)));
INSTANTIATE_TEST_SUITE_P(Vec4,
                         ResolverBitcastValidationTestInvalidDstTy,
                         testing::Combine(testing::ValuesIn(kVec4NumericScalars),
                                          testing::ValuesIn(kInvalid)));
////////////////////////////////////////////////////////////////////////////////
// Incompatible bitcast, but both src and dst types are valid
////////////////////////////////////////////////////////////////////////////////
using ResolverBitcastValidationTestIncompatible = ResolverBitcastValidationTest;

// Bitcasting between two individually valid numeric types must fail to
// resolve when their shapes (scalar vs. vector, or vector sizes) differ.
// The diagnostic is reported at the bitcast expression and names both types.
TEST_P(ResolverBitcastValidationTestIncompatible, Test) {
    auto src = std::get<0>(GetParam());
    auto dst = std::get<1>(GetParam());
    WrapInFunction(Bitcast(Source{{12, 34}}, dst.ast(*this), src.expr(*this, 0)));
    auto expected = "12:34 error: cannot bitcast from '" + src.sem(*this)->FriendlyName() +
                    "' to '" + dst.sem(*this)->FriendlyName() + "'";
    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), expected);
}

// The instantiations below cover every ordered pair of distinct shapes
// (scalar, vec2, vec3, vec4) in both the widening and the narrowing directions.
INSTANTIATE_TEST_SUITE_P(ScalarToVec2,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kNumericScalars),
                                          testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec2ToVec3,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec2NumericScalars),
                                          testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec3ToVec4,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec3NumericScalars),
                                          testing::ValuesIn(kVec4NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec4ToScalar,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec4NumericScalars),
                                          testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec2ToScalar,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec2NumericScalars),
                                          testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec3ToVec2,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec3NumericScalars),
                                          testing::ValuesIn(kVec2NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec4ToVec3,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec4NumericScalars),
                                          testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(ScalarToVec4,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kNumericScalars),
                                          testing::ValuesIn(kVec4NumericScalars)));
INSTANTIATE_TEST_SUITE_P(ScalarToVec3,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kNumericScalars),
                                          testing::ValuesIn(kVec3NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec3ToScalar,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec3NumericScalars),
                                          testing::ValuesIn(kNumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec2ToVec4,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec2NumericScalars),
                                          testing::ValuesIn(kVec4NumericScalars)));
INSTANTIATE_TEST_SUITE_P(Vec4ToVec2,
                         ResolverBitcastValidationTestIncompatible,
                         testing::Combine(testing::ValuesIn(kVec4NumericScalars),
                                          testing::ValuesIn(kVec2NumericScalars)));
} // namespace
} // namespace tint::resolver