Refactor GLSL type determination code.

This Cl cleanups and simplifies the type determination for the GLSL
imports.

Change-Id: I9dd85ac390ef37c91d9493f840f81ceb6736fc06
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22820
Reviewed-by: David Neto <dneto@google.com>
diff --git a/src/type_determiner.cc b/src/type_determiner.cc
index 8d455bd..34fc922 100644
--- a/src/type_determiner.cc
+++ b/src/type_determiner.cc
@@ -617,6 +617,76 @@
   return true;
 }
 
+// Most of these are floating-point general except the below which are only
+// FP16 and FP32. We only have FP32 at this point so the below works, if we
+// get FP64 support or otherwise we'll need to differentiate.
+//   * radians
+//   * degrees
+//   * sin, cos, tan
+//   * asin, acos, atan
+//   * sinh, cosh, tanh
+//   * asinh, acosh, atanh
+//   * exp, exp2
+//   * log, log2
+enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector };
+struct GlslData {
+  const char* name;
+  uint8_t param_count;
+  uint32_t op_id;
+  GlslDataType type;
+};
+
+constexpr const GlslData kGlslData[] = {
+    {"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector},
+    {"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector},
+    {"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector},
+    {"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector},
+    {"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector},
+    {"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector},
+    {"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector},
+    {"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector},
+    {"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector},
+    {"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector},
+    {"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector},
+    {"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector},
+    {"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector},
+    {"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector},
+    {"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector},
+    {"faceforward", 3, GLSLstd450FaceForward,
+     GlslDataType::kFloatScalarOrVector},
+    {"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector},
+    {"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector},
+    {"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector},
+    {"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector},
+    {"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector},
+    {"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector},
+    {"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector},
+    {"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector},
+    {"inversesqrt", 1, GLSLstd450InverseSqrt,
+     GlslDataType::kFloatScalarOrVector},
+    {"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector},
+    {"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector},
+    {"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector},
+    {"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector},
+    {"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector},
+    {"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector},
+    {"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector},
+    {"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector},
+    {"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector},
+    {"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector},
+    {"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector},
+    {"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector},
+    {"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector},
+    {"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector},
+    {"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector},
+    {"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector},
+    {"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector},
+    {"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector},
+    {"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector},
+    {"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector},
+};
+constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData);
+
 ast::type::Type* TypeDeterminer::GetImportData(
     const Source& source,
     const std::string& path,
@@ -627,197 +697,59 @@
     return nullptr;
   }
 
-  // Most of these are floating-point general except the below which are only
-  // FP16 and FP32. We only have FP32 at this point so the below works, if we
-  // get FP64 support or otherwise we'll need to differentiate.
-  //   * radians
-  //   * degrees
-  //   * sin, cos, tan
-  //   * asin, acos, atan
-  //   * sinh, cosh, tanh
-  //   * asinh, acosh, atanh
-  //   * exp, exp2
-  //   * log, log2
-
-  if (name == "round" || name == "roundeven" || name == "trunc" ||
-      name == "fabs" || name == "fsign" || name == "floor" || name == "ceil" ||
-      name == "fract" || name == "radians" || name == "degrees" ||
-      name == "sin" || name == "cos" || name == "tan" || name == "asin" ||
-      name == "acos" || name == "atan" || name == "sinh" || name == "cosh" ||
-      name == "tanh" || name == "asinh" || name == "acosh" || name == "atanh" ||
-      name == "exp" || name == "log" || name == "exp2" || name == "log2" ||
-      name == "sqrt" || name == "inversesqrt" || name == "normalize" ||
-      name == "length") {
-    if (params.size() != 1) {
-      set_error(source, "incorrect number of parameters for " + name +
-                            ". Expected 1 got " +
-                            std::to_string(params.size()));
-      return nullptr;
+  const GlslData* data = nullptr;
+  for (uint32_t i = 0; i < kGlslDataCount; ++i) {
+    if (name == kGlslData[i].name) {
+      data = &kGlslData[i];
+      break;
     }
-
-    auto* result_type = params[0]->result_type()->UnwrapPtrIfNeeded();
-    if (!result_type->is_float_scalar_or_vector()) {
-      set_error(source, "incorrect type for " + name +
-                            ". Requires a float scalar or a float vector");
-      return nullptr;
-    }
-
-    if (name == "round") {
-      *id = GLSLstd450Round;
-    } else if (name == "roundeven") {
-      *id = GLSLstd450RoundEven;
-    } else if (name == "trunc") {
-      *id = GLSLstd450Trunc;
-    } else if (name == "fabs") {
-      *id = GLSLstd450FAbs;
-    } else if (name == "fsign") {
-      *id = GLSLstd450FSign;
-    } else if (name == "floor") {
-      *id = GLSLstd450Floor;
-    } else if (name == "ceil") {
-      *id = GLSLstd450Ceil;
-    } else if (name == "fract") {
-      *id = GLSLstd450Fract;
-    } else if (name == "radians") {
-      *id = GLSLstd450Radians;
-    } else if (name == "degrees") {
-      *id = GLSLstd450Degrees;
-    } else if (name == "sin") {
-      *id = GLSLstd450Sin;
-    } else if (name == "cos") {
-      *id = GLSLstd450Cos;
-    } else if (name == "tan") {
-      *id = GLSLstd450Tan;
-    } else if (name == "asin") {
-      *id = GLSLstd450Asin;
-    } else if (name == "acos") {
-      *id = GLSLstd450Acos;
-    } else if (name == "atan") {
-      *id = GLSLstd450Atan;
-    } else if (name == "sinh") {
-      *id = GLSLstd450Sinh;
-    } else if (name == "cosh") {
-      *id = GLSLstd450Cosh;
-    } else if (name == "tanh") {
-      *id = GLSLstd450Tanh;
-    } else if (name == "asinh") {
-      *id = GLSLstd450Asinh;
-    } else if (name == "acosh") {
-      *id = GLSLstd450Acosh;
-    } else if (name == "atanh") {
-      *id = GLSLstd450Atanh;
-    } else if (name == "exp") {
-      *id = GLSLstd450Exp;
-    } else if (name == "log") {
-      *id = GLSLstd450Log;
-    } else if (name == "exp2") {
-      *id = GLSLstd450Exp2;
-    } else if (name == "log2") {
-      *id = GLSLstd450Log2;
-    } else if (name == "sqrt") {
-      *id = GLSLstd450Sqrt;
-    } else if (name == "inversesqrt") {
-      *id = GLSLstd450InverseSqrt;
-    } else if (name == "normalize") {
-      *id = GLSLstd450Normalize;
-    } else if (name == "length") {
-      *id = GLSLstd450Length;
-
-      // Length returns a scalar of the same type as the parameter.
-      return result_type->is_float_scalar() ? result_type
-                                            : result_type->AsVector()->type();
-    }
-
-    return result_type;
-  } else if (name == "atan2" || name == "pow" || name == "fmin" ||
-             name == "fmax" || name == "step" || name == "reflect" ||
-             name == "nmin" || name == "nmax" || name == "distance") {
-    if (params.size() != 2) {
-      error_ = "incorrect number of parameters for " + name +
-               ". Expected 2 got " + std::to_string(params.size());
-      return nullptr;
-    }
-
-    auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
-    auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
-    if (!result_type_0->is_float_scalar_or_vector() ||
-        !result_type_1->is_float_scalar_or_vector()) {
-      error_ = "incorrect type for " + name +
-               ". Requires float scalar or a float vector values";
-      return nullptr;
-    }
-    if (result_type_0 != result_type_1) {
-      error_ = "mismatched parameter types for " + name;
-      return nullptr;
-    }
-
-    if (name == "atan2") {
-      *id = GLSLstd450Atan2;
-    } else if (name == "pow") {
-      *id = GLSLstd450Pow;
-    } else if (name == "fmin") {
-      *id = GLSLstd450FMin;
-    } else if (name == "fmax") {
-      *id = GLSLstd450FMax;
-    } else if (name == "step") {
-      *id = GLSLstd450Step;
-    } else if (name == "reflect") {
-      *id = GLSLstd450Reflect;
-    } else if (name == "nmin") {
-      *id = GLSLstd450NMin;
-    } else if (name == "nmax") {
-      *id = GLSLstd450NMax;
-    } else if (name == "distance") {
-      *id = GLSLstd450Distance;
-
-      // Distance returns a scalar of the same type as the parameter.
-      return result_type_0->is_float_scalar()
-                 ? result_type_0
-                 : result_type_0->AsVector()->type();
-    }
-
-    return result_type_0;
-  } else if (name == "fclamp" || name == "fmix" || name == "smoothstep" ||
-             name == "fma" || name == "nclamp" || name == "faceforward") {
-    if (params.size() != 3) {
-      error_ = "incorrect number of parameters for " + name +
-               ". Expected 3 got " + std::to_string(params.size());
-      return nullptr;
-    }
-
-    auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
-    auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
-    auto* result_type_2 = params[2]->result_type()->UnwrapPtrIfNeeded();
-    if (!result_type_0->is_float_scalar_or_vector() ||
-        !result_type_1->is_float_scalar_or_vector() ||
-        !result_type_2->is_float_scalar_or_vector()) {
-      error_ = "incorrect type for " + name +
-               ". Requires float scalar or a float vector values";
-      return nullptr;
-    }
-    if (result_type_0 != result_type_1 || result_type_0 != result_type_2) {
-      error_ = "mismatched parameter types for " + name;
-      return nullptr;
-    }
-
-    if (name == "fclamp") {
-      *id = GLSLstd450FClamp;
-    } else if (name == "fmix") {
-      *id = GLSLstd450FMix;
-    } else if (name == "smoothstep") {
-      *id = GLSLstd450SmoothStep;
-    } else if (name == "fma") {
-      *id = GLSLstd450Fma;
-    } else if (name == "nclamp") {
-      *id = GLSLstd450NClamp;
-    } else if (name == "faceforward") {
-      *id = GLSLstd450FaceForward;
-    }
-
-    return result_type_0;
+  }
+  if (data == nullptr) {
+    return nullptr;
   }
 
-  return nullptr;
+  if (params.size() != data->param_count) {
+    set_error(source, "incorrect number of parameters for " + name +
+                          ". Expected " + std::to_string(data->param_count) +
+                          " got " + std::to_string(params.size()));
+    return nullptr;
+  }
+
+  std::vector<ast::type::Type*> result_types;
+  for (uint32_t i = 0; i < data->param_count; ++i) {
+    result_types.push_back(params[i]->result_type()->UnwrapPtrIfNeeded());
+
+    switch (data->type) {
+      case GlslDataType::kFloatScalarOrVector:
+        if (!result_types.back()->is_float_scalar_or_vector()) {
+          set_error(source, "incorrect type for " + name + ". " +
+                                "Requires float scalar or float vector values");
+          return nullptr;
+        }
+
+        break;
+      case GlslDataType::kIntScalarOrVector:
+        break;
+    }
+  }
+
+  // Verify all the parameter types match
+  for (size_t i = 1; i < data->param_count; ++i) {
+    if (result_types[0] != result_types[i]) {
+      error_ = "mismatched parameter types for " + name;
+      return nullptr;
+    }
+  }
+
+  *id = data->op_id;
+
+  // Handle functions which aways return the type, even if a vector is provided.
+  if (name == "length" || name == "distance") {
+    return result_types[0]->is_float_scalar()
+               ? result_types[0]
+               : result_types[0]->AsVector()->type();
+  }
+  return result_types[0];
 }
 
 }  // namespace tint
diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc
index 1cdca1a..4c24c90 100644
--- a/src/type_determiner_test.cc
+++ b/src/type_determiner_test.cc
@@ -1787,8 +1787,9 @@
   auto* type =
       td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id);
   ASSERT_EQ(type, nullptr);
-  EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name +
-                               ". Requires a float scalar or a float vector");
+  EXPECT_EQ(td()->error(),
+            std::string("incorrect type for ") + param.name +
+                ". Requires float scalar or float vector values");
 }
 
 TEST_P(ImportData_SingleParamTest, Error_NoParams) {
@@ -1914,9 +1915,9 @@
   auto* type =
       td()->GetImportData({0, 0}, "GLSL.std.450", "length", params, &id);
   ASSERT_EQ(type, nullptr);
-  EXPECT_EQ(
-      td()->error(),
-      "incorrect type for length. Requires a float scalar or a float vector");
+  EXPECT_EQ(td()->error(),
+            "incorrect type for length. Requires float scalar or float vector "
+            "values");
 }
 
 TEST_F(TypeDeterminerTest, ImportData_Length_Error_NoParams) {
@@ -2030,7 +2031,7 @@
   ASSERT_EQ(type, nullptr);
   EXPECT_EQ(td()->error(),
             std::string("incorrect type for ") + param.name +
-                ". Requires float scalar or a float vector values");
+                ". Requires float scalar or float vector values");
 }
 
 TEST_P(ImportData_TwoParamTest, Error_NoParams) {
@@ -2234,7 +2235,7 @@
       td()->GetImportData({0, 0}, "GLSL.std.450", "distance", params, &id);
   ASSERT_EQ(type, nullptr);
   EXPECT_EQ(td()->error(),
-            "incorrect type for distance. Requires float scalar or a float "
+            "incorrect type for distance. Requires float scalar or float "
             "vector values");
 }
 
@@ -2440,7 +2441,7 @@
   ASSERT_EQ(type, nullptr);
   EXPECT_EQ(td()->error(),
             std::string("incorrect type for ") + param.name +
-                ". Requires float scalar or a float vector values");
+                ". Requires float scalar or float vector values");
 }
 
 TEST_P(ImportData_ThreeParamTest, Error_NoParams) {