Add support for GLSL cross.
This CL adds support for the GLSL cross method.
Change-Id: Ib2e83a2ef2e580c6ca257851a76f3f66fa377d6f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22842
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 7a65f68..9b1e3c9 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -628,68 +628,77 @@
// * asinh, acosh, atanh
// * exp, exp2
// * log, log2
-enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector };
+enum class GlslDataType {
+ kFloatScalarOrVector,
+ kIntScalarOrVector,
+ kFloatVector
+};
struct GlslData {
const char* name;
uint8_t param_count;
uint32_t op_id;
GlslDataType type;
+ uint8_t vector_count;
};
constexpr const GlslData kGlslData[] = {
- {"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector},
- {"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector},
- {"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector},
- {"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector},
- {"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector},
- {"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector},
- {"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector},
- {"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector},
- {"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector},
- {"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector},
- {"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector},
- {"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector},
- {"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector},
- {"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector},
- {"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector},
+ {"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector, 0},
+ {"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector, 0},
+ {"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector, 0},
+ {"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector, 0},
+ {"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector, 0},
+ {"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector, 0},
+ {"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector, 0},
+ {"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector, 0},
+ {"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector, 0},
+ {"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector, 0},
+ {"cross", 2, GLSLstd450Cross, GlslDataType::kFloatVector, 3},
+ {"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector, 0},
+ {"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector, 0},
+ {"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector, 0},
+ {"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector, 0},
+ {"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector, 0},
{"faceforward", 3, GLSLstd450FaceForward,
- GlslDataType::kFloatScalarOrVector},
- {"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector},
- {"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector},
- {"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector},
- {"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector},
- {"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector},
- {"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector},
- {"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector},
- {"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector},
+ GlslDataType::kFloatScalarOrVector, 0},
+ {"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector, 0},
+ {"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector, 0},
+ {"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector, 0},
+ {"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector, 0},
+ {"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector, 0},
+ {"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector, 0},
+ {"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector, 0},
+ {"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector, 0},
{"inversesqrt", 1, GLSLstd450InverseSqrt,
- GlslDataType::kFloatScalarOrVector},
- {"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector},
- {"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector},
- {"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector},
- {"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector},
- {"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector},
- {"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector},
- {"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector},
- {"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector},
- {"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector},
- {"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector},
- {"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector},
- {"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector},
- {"sabs", 1, GLSLstd450SAbs, GlslDataType::kIntScalarOrVector},
- {"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector},
- {"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector},
- {"smax", 2, GLSLstd450SMax, GlslDataType::kIntScalarOrVector},
- {"smin", 2, GLSLstd450SMin, GlslDataType::kIntScalarOrVector},
- {"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector},
- {"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector},
- {"ssign", 1, GLSLstd450SSign, GlslDataType::kIntScalarOrVector},
- {"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector},
- {"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector},
- {"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector},
- {"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector},
- {"umax", 2, GLSLstd450UMax, GlslDataType::kIntScalarOrVector},
- {"umin", 2, GLSLstd450UMin, GlslDataType::kIntScalarOrVector},
+ GlslDataType::kFloatScalarOrVector, 0},
+ {"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector, 0},
+ {"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector, 0},
+ {"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector, 0},
+ {"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector, 0},
+ {"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector, 0},
+ {"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector, 0},
+ {"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector,
+ 0},
+ {"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector, 0},
+ {"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector, 0},
+ {"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector, 0},
+ {"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector, 0},
+ {"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector,
+ 0},
+ {"sabs", 1, GLSLstd450SAbs, GlslDataType::kIntScalarOrVector, 0},
+ {"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector, 0},
+ {"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector, 0},
+ {"smax", 2, GLSLstd450SMax, GlslDataType::kIntScalarOrVector, 0},
+ {"smin", 2, GLSLstd450SMin, GlslDataType::kIntScalarOrVector, 0},
+ {"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector,
+ 0},
+ {"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector, 0},
+ {"ssign", 1, GLSLstd450SSign, GlslDataType::kIntScalarOrVector, 0},
+ {"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector, 0},
+ {"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector, 0},
+ {"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector, 0},
+ {"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector, 0},
+ {"umax", 2, GLSLstd450UMax, GlslDataType::kIntScalarOrVector, 0},
+ {"umin", 2, GLSLstd450UMin, GlslDataType::kIntScalarOrVector, 0},
};
constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData);
@@ -742,6 +751,20 @@
return nullptr;
}
break;
+ case GlslDataType::kFloatVector:
+ if (!result_types.back()->is_float_vector()) {
+ set_error(source, "incorrect type for " + name + ". " +
+ "Requires float vector values");
+ return nullptr;
+ }
+ if (data->vector_count > 0 &&
+ result_types.back()->AsVector()->size() != data->vector_count) {
+ set_error(source,
+ "incorrect vector size for " + name + ". " + "Requires " +
+ std::to_string(data->vector_count) + " elements");
+ return nullptr;
+ }
+ break;
}
}
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index e39996c..5871be4 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -2348,8 +2348,186 @@
"incorrect number of parameters for distance. Expected 2 got 3");
}
-using ImportData_ThreeParamTest = TypeDeterminerTestWithParam<GLSLData>;
+TEST_F(TypeDeterminerTest, ImportData_Cross) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec(&f32, 3);
+ ast::ExpressionList vals_1;
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ ast::ExpressionList vals_2;
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_1)));
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_2)));
+
+ ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
+
+ uint32_t id = 0;
+ auto* type =
+ td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
+ ASSERT_NE(type, nullptr);
+ EXPECT_TRUE(type->is_float_vector());
+ EXPECT_EQ(type->AsVector()->size(), 3u);
+ EXPECT_EQ(id, GLSLstd450Cross);
+}
+
+TEST_F(TypeDeterminerTest, ImportData_Cross_Error_Scalar) {
+ ast::type::F32Type f32;
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+
+ ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
+
+ uint32_t id = 0;
+ auto* type =
+ td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_EQ(td()->error(),
+ "incorrect type for cross. Requires float vector values");
+}
+
+TEST_F(TypeDeterminerTest, ImportData_Cross_Error_IntType) {
+ ast::type::I32Type i32;
+ ast::type::VectorType vec(&i32, 3);
+
+ ast::ExpressionList vals_1;
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 1)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 1)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 3)));
+
+ ast::ExpressionList vals_2;
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 1)));
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 1)));
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::SintLiteral>(&i32, 3)));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_1)));
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_2)));
+
+ ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
+
+ uint32_t id = 0;
+ auto* type =
+ td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_EQ(td()->error(),
+ "incorrect type for cross. Requires float vector values");
+}
+
+TEST_F(TypeDeterminerTest, ImportData_Cross_Error_MissingParams) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec(&f32, 3);
+
+ ast::ExpressionList params;
+ ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
+
+ uint32_t id = 0;
+ auto* type =
+ td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_EQ(td()->error(),
+ "incorrect number of parameters for cross. Expected 2 got 0");
+}
+
+TEST_F(TypeDeterminerTest, ImportData_Cross_Error_TooFewParams) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec(&f32, 3);
+
+ ast::ExpressionList vals_1;
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_1)));
+
+ ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
+
+ uint32_t id = 0;
+ auto* type =
+ td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_EQ(td()->error(),
+ "incorrect number of parameters for cross. Expected 2 got 1");
+}
+
+TEST_F(TypeDeterminerTest, ImportData_Cross_Error_TooManyParams) {
+ ast::type::F32Type f32;
+ ast::type::VectorType vec(&f32, 3);
+
+ ast::ExpressionList vals_1;
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_1.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ ast::ExpressionList vals_2;
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_2.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ ast::ExpressionList vals_3;
+ vals_3.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_3.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 1.0f)));
+ vals_3.push_back(std::make_unique<ast::ScalarConstructorExpression>(
+ std::make_unique<ast::FloatLiteral>(&f32, 3.0f)));
+
+ ast::ExpressionList params;
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_1)));
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_2)));
+ params.push_back(std::make_unique<ast::TypeConstructorExpression>(
+ &vec, std::move(vals_3)));
+
+ ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error();
+
+ uint32_t id = 0;
+ auto* type =
+ td()->GetImportData({0, 0}, "GLSL.std.450", "cross", params, &id);
+ ASSERT_EQ(type, nullptr);
+ EXPECT_EQ(td()->error(),
+ "incorrect number of parameters for cross. Expected 2 got 3");
+}
+
+using ImportData_ThreeParamTest = TypeDeterminerTestWithParam<GLSLData>;
TEST_P(ImportData_ThreeParamTest, Scalar) {
auto param = GetParam();