spirv-reader: polyfill scalar refract

Compute it in 2 dimensions, with a 0-valued y component,
then extract the x component of that result.

Fixed: tint:974
Change-Id: Ie23668d3403e68be14f34da9540f27f6f3c3aca2
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58782
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: James Price <jrprice@google.com>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index c618d3a..7f302c3 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -3989,6 +3989,37 @@
             create<ast::ScalarConstructorExpression>(
                 Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))};
   }
+  if ((ext_opcode == GLSLstd450Refract) && result_type->IsScalar()) {
+    // WGSL does not have scalar form of the refract builtin.
+    // It's a complicated expression.  Implement it by /computing it in two
+    // dimensions, but with a 0-valued y component in both the incident and
+    // normal vectors, then take the x component of that result.
+    auto incident = MakeOperand(inst, 2);
+    auto normal = MakeOperand(inst, 3);
+    auto eta = MakeOperand(inst, 4);
+    TINT_ASSERT(Reader, incident.type->Is<F32>());
+    TINT_ASSERT(Reader, normal.type->Is<F32>());
+    TINT_ASSERT(Reader, eta.type->Is<F32>());
+    if (!success()) {
+      return {};
+    }
+    const Type* f32 = eta.type;
+    const Type* vec2 = ty_.Vector(f32, 2);
+    return {
+        f32,
+        builder_.MemberAccessor(
+            builder_.Call(
+                Source{}, "refract",
+                ast::ExpressionList{
+                    builder_.Construct(vec2->Build(builder_),
+                                       ast::ExpressionList{
+                                           incident.expr, builder_.Expr(0.0f)}),
+                    builder_.Construct(
+                        vec2->Build(builder_),
+                        ast::ExpressionList{normal.expr, builder_.Expr(0.0f)}),
+                    eta.expr}),
+            "x")};
+  }
 
   const auto name = GetGlslStd450FuncName(ext_opcode);
   if (name.empty()) {
diff --git a/src/reader/spirv/function_glsl_std_450_test.cc b/src/reader/spirv/function_glsl_std_450_test.cc
index 3e5b5e4..599a016 100644
--- a/src/reader/spirv/function_glsl_std_450_test.cc
+++ b/src/reader/spirv/function_glsl_std_450_test.cc
@@ -53,6 +53,7 @@
   OpName %v3f1 "v3f1"
   OpName %v3f2 "v3f2"
   OpName %v4f1 "v4f1"
+  OpName %v4f2 "v4f2"
 
   %void = OpTypeVoid
   %voidfn = OpTypeFunction %void
@@ -123,6 +124,7 @@
   %v3f2 = OpCopyObject %v3float %v3float_60_70_50
 
   %v4f1 = OpCopyObject %v4float %v4float_50_50_50_50
+  %v4f2 = OpCopyObject %v4float %v4f1
 )";
 }
 
@@ -1746,6 +1748,78 @@
                              {"UnpackUnorm2x16", "unpack2x16unorm", 2},
                              {"UnpackHalf2x16", "unpack2x16float", 2}}));
 
+TEST_F(SpvParserTest, GlslStd450_Refract_Scalar) {
+  const auto assembly = Preamble() + R"(
+     %1 = OpExtInst %float %glsl Refract %f1 %f2 %f3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  auto fe = p->function_emitter(100);
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(p->builder(), fe.ast_body());
+  const auto* expected = R"(VariableConst{
+    x_1
+    none
+    undefined
+    __f32
+    {
+      MemberAccessor[not set]{
+        Call[not set]{
+          Identifier[not set]{refract}
+          (
+            TypeConstructor[not set]{
+              __vec_2__f32
+              Identifier[not set]{f1}
+              ScalarConstructor[not set]{0.000000}
+            }
+            TypeConstructor[not set]{
+              __vec_2__f32
+              Identifier[not set]{f2}
+              ScalarConstructor[not set]{0.000000}
+            }
+            Identifier[not set]{f3}
+          )
+        }
+        Identifier[not set]{x}
+      }
+    }
+  })";
+
+  EXPECT_THAT(body, HasSubstr(expected)) << body;
+}
+
+TEST_F(SpvParserTest, GlslStd450_Refract_Vector) {
+  const auto assembly = Preamble() + R"(
+     %1 = OpExtInst %v2float %glsl Refract %v2f1 %v2f2 %f3
+     OpReturn
+     OpFunctionEnd
+  )";
+  auto p = parser(test::Assemble(assembly));
+  ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+  auto fe = p->function_emitter(100);
+  EXPECT_TRUE(fe.EmitBody()) << p->error();
+  const auto body = ToString(p->builder(), fe.ast_body());
+  const auto* expected = R"(VariableConst{
+    x_1
+    none
+    undefined
+    __vec_2__f32
+    {
+      Call[not set]{
+        Identifier[not set]{refract}
+        (
+          Identifier[not set]{v2f1}
+          Identifier[not set]{v2f2}
+          Identifier[not set]{f3}
+        )
+      }
+    })";
+
+  EXPECT_THAT(body, HasSubstr(expected));
+}
+
 }  // namespace
 }  // namespace spirv
 }  // namespace reader