Refactor GLSL type determination code.
This Cl cleanups and simplifies the type determination for the GLSL
imports.
Change-Id: I9dd85ac390ef37c91d9493f840f81ceb6736fc06
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22820
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 8d455bd..34fc922 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -617,6 +617,76 @@
return true;
}
+// Most of these are floating-point general except the below which are only
+// FP16 and FP32. We only have FP32 at this point so the below works, if we
+// get FP64 support or otherwise we'll need to differentiate.
+// * radians
+// * degrees
+// * sin, cos, tan
+// * asin, acos, atan
+// * sinh, cosh, tanh
+// * asinh, acosh, atanh
+// * exp, exp2
+// * log, log2
+enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector };
+struct GlslData {
+ const char* name;
+ uint8_t param_count;
+ uint32_t op_id;
+ GlslDataType type;
+};
+
+constexpr const GlslData kGlslData[] = {
+ {"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector},
+ {"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector},
+ {"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector},
+ {"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector},
+ {"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector},
+ {"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector},
+ {"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector},
+ {"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector},
+ {"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector},
+ {"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector},
+ {"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector},
+ {"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector},
+ {"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector},
+ {"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector},
+ {"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector},
+ {"faceforward", 3, GLSLstd450FaceForward,
+ GlslDataType::kFloatScalarOrVector},
+ {"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector},
+ {"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector},
+ {"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector},
+ {"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector},
+ {"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector},
+ {"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector},
+ {"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector},
+ {"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector},
+ {"inversesqrt", 1, GLSLstd450InverseSqrt,
+ GlslDataType::kFloatScalarOrVector},
+ {"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector},
+ {"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector},
+ {"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector},
+ {"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector},
+ {"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector},
+ {"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector},
+ {"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector},
+ {"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector},
+ {"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector},
+ {"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector},
+ {"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector},
+ {"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector},
+ {"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector},
+ {"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector},
+ {"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector},
+ {"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector},
+ {"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector},
+ {"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector},
+ {"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector},
+ {"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector},
+};
+constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData);
+
ast::type::Type* TypeDeterminer::GetImportData(
const Source& source,
const std::string& path,
@@ -627,197 +697,59 @@
return nullptr;
}
- // Most of these are floating-point general except the below which are only
- // FP16 and FP32. We only have FP32 at this point so the below works, if we
- // get FP64 support or otherwise we'll need to differentiate.
- // * radians
- // * degrees
- // * sin, cos, tan
- // * asin, acos, atan
- // * sinh, cosh, tanh
- // * asinh, acosh, atanh
- // * exp, exp2
- // * log, log2
-
- if (name == "round" || name == "roundeven" || name == "trunc" ||
- name == "fabs" || name == "fsign" || name == "floor" || name == "ceil" ||
- name == "fract" || name == "radians" || name == "degrees" ||
- name == "sin" || name == "cos" || name == "tan" || name == "asin" ||
- name == "acos" || name == "atan" || name == "sinh" || name == "cosh" ||
- name == "tanh" || name == "asinh" || name == "acosh" || name == "atanh" ||
- name == "exp" || name == "log" || name == "exp2" || name == "log2" ||
- name == "sqrt" || name == "inversesqrt" || name == "normalize" ||
- name == "length") {
- if (params.size() != 1) {
- set_error(source, "incorrect number of parameters for " + name +
- ". Expected 1 got " +
- std::to_string(params.size()));
- return nullptr;
+ const GlslData* data = nullptr;
+ for (uint32_t i = 0; i < kGlslDataCount; ++i) {
+ if (name == kGlslData[i].name) {
+ data = &kGlslData[i];
+ break;
}
-
- auto* result_type = params[0]->result_type()->UnwrapPtrIfNeeded();
- if (!result_type->is_float_scalar_or_vector()) {
- set_error(source, "incorrect type for " + name +
- ". Requires a float scalar or a float vector");
- return nullptr;
- }
-
- if (name == "round") {
- *id = GLSLstd450Round;
- } else if (name == "roundeven") {
- *id = GLSLstd450RoundEven;
- } else if (name == "trunc") {
- *id = GLSLstd450Trunc;
- } else if (name == "fabs") {
- *id = GLSLstd450FAbs;
- } else if (name == "fsign") {
- *id = GLSLstd450FSign;
- } else if (name == "floor") {
- *id = GLSLstd450Floor;
- } else if (name == "ceil") {
- *id = GLSLstd450Ceil;
- } else if (name == "fract") {
- *id = GLSLstd450Fract;
- } else if (name == "radians") {
- *id = GLSLstd450Radians;
- } else if (name == "degrees") {
- *id = GLSLstd450Degrees;
- } else if (name == "sin") {
- *id = GLSLstd450Sin;
- } else if (name == "cos") {
- *id = GLSLstd450Cos;
- } else if (name == "tan") {
- *id = GLSLstd450Tan;
- } else if (name == "asin") {
- *id = GLSLstd450Asin;
- } else if (name == "acos") {
- *id = GLSLstd450Acos;
- } else if (name == "atan") {
- *id = GLSLstd450Atan;
- } else if (name == "sinh") {
- *id = GLSLstd450Sinh;
- } else if (name == "cosh") {
- *id = GLSLstd450Cosh;
- } else if (name == "tanh") {
- *id = GLSLstd450Tanh;
- } else if (name == "asinh") {
- *id = GLSLstd450Asinh;
- } else if (name == "acosh") {
- *id = GLSLstd450Acosh;
- } else if (name == "atanh") {
- *id = GLSLstd450Atanh;
- } else if (name == "exp") {
- *id = GLSLstd450Exp;
- } else if (name == "log") {
- *id = GLSLstd450Log;
- } else if (name == "exp2") {
- *id = GLSLstd450Exp2;
- } else if (name == "log2") {
- *id = GLSLstd450Log2;
- } else if (name == "sqrt") {
- *id = GLSLstd450Sqrt;
- } else if (name == "inversesqrt") {
- *id = GLSLstd450InverseSqrt;
- } else if (name == "normalize") {
- *id = GLSLstd450Normalize;
- } else if (name == "length") {
- *id = GLSLstd450Length;
-
- // Length returns a scalar of the same type as the parameter.
- return result_type->is_float_scalar() ? result_type
- : result_type->AsVector()->type();
- }
-
- return result_type;
- } else if (name == "atan2" || name == "pow" || name == "fmin" ||
- name == "fmax" || name == "step" || name == "reflect" ||
- name == "nmin" || name == "nmax" || name == "distance") {
- if (params.size() != 2) {
- error_ = "incorrect number of parameters for " + name +
- ". Expected 2 got " + std::to_string(params.size());
- return nullptr;
- }
-
- auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
- auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
- if (!result_type_0->is_float_scalar_or_vector() ||
- !result_type_1->is_float_scalar_or_vector()) {
- error_ = "incorrect type for " + name +
- ". Requires float scalar or a float vector values";
- return nullptr;
- }
- if (result_type_0 != result_type_1) {
- error_ = "mismatched parameter types for " + name;
- return nullptr;
- }
-
- if (name == "atan2") {
- *id = GLSLstd450Atan2;
- } else if (name == "pow") {
- *id = GLSLstd450Pow;
- } else if (name == "fmin") {
- *id = GLSLstd450FMin;
- } else if (name == "fmax") {
- *id = GLSLstd450FMax;
- } else if (name == "step") {
- *id = GLSLstd450Step;
- } else if (name == "reflect") {
- *id = GLSLstd450Reflect;
- } else if (name == "nmin") {
- *id = GLSLstd450NMin;
- } else if (name == "nmax") {
- *id = GLSLstd450NMax;
- } else if (name == "distance") {
- *id = GLSLstd450Distance;
-
- // Distance returns a scalar of the same type as the parameter.
- return result_type_0->is_float_scalar()
- ? result_type_0
- : result_type_0->AsVector()->type();
- }
-
- return result_type_0;
- } else if (name == "fclamp" || name == "fmix" || name == "smoothstep" ||
- name == "fma" || name == "nclamp" || name == "faceforward") {
- if (params.size() != 3) {
- error_ = "incorrect number of parameters for " + name +
- ". Expected 3 got " + std::to_string(params.size());
- return nullptr;
- }
-
- auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
- auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
- auto* result_type_2 = params[2]->result_type()->UnwrapPtrIfNeeded();
- if (!result_type_0->is_float_scalar_or_vector() ||
- !result_type_1->is_float_scalar_or_vector() ||
- !result_type_2->is_float_scalar_or_vector()) {
- error_ = "incorrect type for " + name +
- ". Requires float scalar or a float vector values";
- return nullptr;
- }
- if (result_type_0 != result_type_1 || result_type_0 != result_type_2) {
- error_ = "mismatched parameter types for " + name;
- return nullptr;
- }
-
- if (name == "fclamp") {
- *id = GLSLstd450FClamp;
- } else if (name == "fmix") {
- *id = GLSLstd450FMix;
- } else if (name == "smoothstep") {
- *id = GLSLstd450SmoothStep;
- } else if (name == "fma") {
- *id = GLSLstd450Fma;
- } else if (name == "nclamp") {
- *id = GLSLstd450NClamp;
- } else if (name == "faceforward") {
- *id = GLSLstd450FaceForward;
- }
-
- return result_type_0;
+ }
+ if (data == nullptr) {
+ return nullptr;
}
- return nullptr;
+ if (params.size() != data->param_count) {
+ set_error(source, "incorrect number of parameters for " + name +
+ ". Expected " + std::to_string(data->param_count) +
+ " got " + std::to_string(params.size()));
+ return nullptr;
+ }
+
+ std::vector<ast::type::Type*> result_types;
+ for (uint32_t i = 0; i < data->param_count; ++i) {
+ result_types.push_back(params[i]->result_type()->UnwrapPtrIfNeeded());
+
+ switch (data->type) {
+ case GlslDataType::kFloatScalarOrVector:
+ if (!result_types.back()->is_float_scalar_or_vector()) {
+ set_error(source, "incorrect type for " + name + ". " +
+ "Requires float scalar or float vector values");
+ return nullptr;
+ }
+
+ break;
+ case GlslDataType::kIntScalarOrVector:
+ break;
+ }
+ }
+
+ // Verify all the parameter types match
+ for (size_t i = 1; i < data->param_count; ++i) {
+ if (result_types[0] != result_types[i]) {
+ error_ = "mismatched parameter types for " + name;
+ return nullptr;
+ }
+ }
+
+ *id = data->op_id;
+
+ // Handle functions which aways return the type, even if a vector is provided.
+ if (name == "length" || name == "distance") {
+ return result_types[0]->is_float_scalar()
+ ? result_types[0]
+ : result_types[0]->AsVector()->type();
+ }
+ return result_types[0];
}
} // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 1cdca1a..4c24c90 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -1787,8 +1787,9 @@
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id);
ASSERT_EQ(type, nullptr);
- EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name +
- ". Requires a float scalar or a float vector");
+ EXPECT_EQ(td()->error(),
+ std::string("incorrect type for ") + param.name +
+ ". Requires float scalar or float vector values");
}
TEST_P(ImportData_SingleParamTest, Error_NoParams) {
@@ -1914,9 +1915,9 @@
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "length", params, &id);
ASSERT_EQ(type, nullptr);
- EXPECT_EQ(
- td()->error(),
- "incorrect type for length. Requires a float scalar or a float vector");
+ EXPECT_EQ(td()->error(),
+ "incorrect type for length. Requires float scalar or float vector "
+ "values");
}
TEST_F(TypeDeterminerTest, ImportData_Length_Error_NoParams) {
@@ -2030,7 +2031,7 @@
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
std::string("incorrect type for ") + param.name +
- ". Requires float scalar or a float vector values");
+ ". Requires float scalar or float vector values");
}
TEST_P(ImportData_TwoParamTest, Error_NoParams) {
@@ -2234,7 +2235,7 @@
td()->GetImportData({0, 0}, "GLSL.std.450", "distance", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
- "incorrect type for distance. Requires float scalar or a float "
+ "incorrect type for distance. Requires float scalar or float "
"vector values");
}
@@ -2440,7 +2441,7 @@
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
std::string("incorrect type for ") + param.name +
- ". Requires float scalar or a float vector values");
+ ". Requires float scalar or float vector values");
}
TEST_P(ImportData_ThreeParamTest, Error_NoParams) {