spirv-reader: polyfill scalar faceForward
Bug: tint:1018
Change-Id: I912c6deaed4e3c3f4c5dfb76e7ed7e917b4c6498
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58820
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc
index 819b9a3..43401f5 100644
--- a/src/reader/spirv/function.cc
+++ b/src/reader/spirv/function.cc
@@ -3988,43 +3988,65 @@
auto* result_type = parser_impl_.ConvertType(inst.type_id());
- if ((ext_opcode == GLSLstd450Normalize) && result_type->IsScalar()) {
- // WGSL does not have scalar form of the normalize builtin.
- // The answer would be 1 anyway, so return that directly.
- return {ty_.F32(),
- create<ast::ScalarConstructorExpression>(
- Source{}, create<ast::FloatLiteral>(Source{}, 1.0f))};
- }
- if ((ext_opcode == GLSLstd450Refract) && result_type->IsScalar()) {
- // WGSL does not have scalar form of the refract builtin.
- // It's a complicated expression. Implement it by /computing it in two
- // dimensions, but with a 0-valued y component in both the incident and
- // normal vectors, then take the x component of that result.
- auto incident = MakeOperand(inst, 2);
- auto normal = MakeOperand(inst, 3);
- auto eta = MakeOperand(inst, 4);
- TINT_ASSERT(Reader, incident.type->Is<F32>());
- TINT_ASSERT(Reader, normal.type->Is<F32>());
- TINT_ASSERT(Reader, eta.type->Is<F32>());
- if (!success()) {
- return {};
+ if (result_type->IsScalar()) {
+ // Some GLSLstd450 builtins have scalar forms not supported by WGSL.
+ // Emulate them.
+ switch (ext_opcode) {
+ case GLSLstd450Normalize:
+ // WGSL does not have scalar form of the normalize builtin.
+ // The answer would be 1 anyway, so return that directly.
+ return {ty_.F32(), builder_.Expr(1.0f)};
+ case GLSLstd450FaceForward: {
+ // If dot(Nref, Incident) < 0, the result is Normal, otherwise -Normal.
+ // Also: select(-normal,normal, Incident*Nref < 0)
+ // (The dot product of scalars is their product.)
+ // Use a multiply instead of comparing floating point signs. It should
+ // be among the fastest operations on a GPU.
+ auto normal = MakeOperand(inst, 2);
+ auto incident = MakeOperand(inst, 3);
+ auto nref = MakeOperand(inst, 4);
+ TINT_ASSERT(Reader, normal.type->Is<F32>());
+ TINT_ASSERT(Reader, incident.type->Is<F32>());
+ TINT_ASSERT(Reader, nref.type->Is<F32>());
+ return {ty_.F32(),
+ builder_.Call(
+ Source{}, "select",
+ ast::ExpressionList{
+ create<ast::UnaryOpExpression>(
+ Source{}, ast::UnaryOp::kNegation, normal.expr),
+ normal.expr,
+ create<ast::BinaryExpression>(
+ Source{}, ast::BinaryOp::kLessThan,
+ builder_.Mul({}, incident.expr, nref.expr),
+ builder_.Expr(0.0f))})};
+ }
+
+ case GLSLstd450Refract: {
+ // It's a complicated expression. Compute it in two dimensions, but
+ // with a 0-valued y component in both the incident and normal vectors,
+ // then take the x component of that result.
+ auto incident = MakeOperand(inst, 2);
+ auto normal = MakeOperand(inst, 3);
+ auto eta = MakeOperand(inst, 4);
+ TINT_ASSERT(Reader, incident.type->Is<F32>());
+ TINT_ASSERT(Reader, normal.type->Is<F32>());
+ TINT_ASSERT(Reader, eta.type->Is<F32>());
+ if (!success()) {
+ return {};
+ }
+ const Type* f32 = eta.type;
+ return {f32,
+ builder_.MemberAccessor(
+ builder_.Call(
+ Source{}, "refract",
+ ast::ExpressionList{
+ builder_.vec2<float>(incident.expr, 0.0f),
+ builder_.vec2<float>(normal.expr, 0.0f), eta.expr}),
+ "x")};
+ }
+ default:
+ break;
}
- const Type* f32 = eta.type;
- const Type* vec2 = ty_.Vector(f32, 2);
- return {
- f32,
- builder_.MemberAccessor(
- builder_.Call(
- Source{}, "refract",
- ast::ExpressionList{
- builder_.Construct(vec2->Build(builder_),
- ast::ExpressionList{
- incident.expr, builder_.Expr(0.0f)}),
- builder_.Construct(
- vec2->Build(builder_),
- ast::ExpressionList{normal.expr, builder_.Expr(0.0f)}),
- eta.expr}),
- "x")};
}
const auto name = GetGlslStd450FuncName(ext_opcode);
@@ -4654,18 +4676,33 @@
void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
// Mark vector operands of OpVectorShuffle as needing a named definition,
// but only if they are defined in this function as well.
+ auto require_named_const_def = [&](const spvtools::opt::Instruction& inst,
+ int in_operand_index) {
+ const auto id = inst.GetSingleWordInOperand(in_operand_index);
+ auto* const operand_def = GetDefInfo(id);
+ if (operand_def) {
+ operand_def->requires_named_const_def = true;
+ }
+ };
for (auto& id_def_info_pair : def_info_) {
const auto& inst = id_def_info_pair.second->inst;
const auto opcode = inst.opcode();
if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
// We might access the vector operands multiple times. Make sure they
// are evaluated only once.
- for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
- auto id = inst.GetSingleWordInOperand(vector_arg);
- auto* operand_def = GetDefInfo(id);
- if (operand_def) {
- operand_def->requires_named_const_def = true;
- }
+ require_named_const_def(inst, 0);
+ require_named_const_def(inst, 1);
+ }
+ if (parser_impl_.IsGlslExtendedInstruction(inst)) {
+ // Some emulations of GLSLstd450 instructions evaluate certain operands
+ // multiple times. Ensure their expressions are evaluated only once.
+ switch (inst.GetSingleWordInOperand(1)) {
+ case GLSLstd450FaceForward:
+ // The "normal" operand expression is used twice in code generation.
+ require_named_const_def(inst, 2);
+ break;
+ default:
+ break;
}
}
}
diff --git a/src/reader/spirv/function_glsl_std_450_test.cc b/src/reader/spirv/function_glsl_std_450_test.cc
index 599a016..0699005 100644
--- a/src/reader/spirv/function_glsl_std_450_test.cc
+++ b/src/reader/spirv/function_glsl_std_450_test.cc
@@ -739,7 +739,6 @@
::testing::ValuesIn(std::vector<GlslStd450Case>{
{"NClamp", "clamp"},
{"FClamp", "clamp"}, // WGSL FClamp promises more for NaN
- {"FaceForward", "faceForward"},
{"Fma", "fma"},
{"FMix", "mix"},
{"SmoothStep", "smoothStep"}}));
@@ -1820,6 +1819,100 @@
EXPECT_THAT(body, HasSubstr(expected));
}
+TEST_F(SpvParserTest, GlslStd450_FaceForward_Scalar) {
+ const auto assembly = Preamble() + R"(
+ %99 = OpFAdd %float %f1 %f1 ; normal operand has only one use
+ %1 = OpExtInst %float %glsl FaceForward %99 %f2 %f3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(p->builder(), fe.ast_body());
+ // The %99 sum only has one use. Ensure it is evaluated only once by
+ // making a let-declaration for it, since it is the normal operand to
+ // the builtin function, and code generation uses it twice.
+ const auto* expected = R"(VariableDeclStatement{
+ VariableConst{
+ x_99
+ none
+ undefined
+ __f32
+ {
+ Binary[not set]{
+ Identifier[not set]{f1}
+ add
+ Identifier[not set]{f1}
+ }
+ }
+ }
+}
+VariableDeclStatement{
+ VariableConst{
+ x_1
+ none
+ undefined
+ __f32
+ {
+ Call[not set]{
+ Identifier[not set]{select}
+ (
+ UnaryOp[not set]{
+ negation
+ Identifier[not set]{x_99}
+ }
+ Identifier[not set]{x_99}
+ Binary[not set]{
+ Binary[not set]{
+ Identifier[not set]{f2}
+ multiply
+ Identifier[not set]{f3}
+ }
+ less_than
+ ScalarConstructor[not set]{0.000000}
+ }
+ )
+ }
+ }
+ }
+})";
+
+ EXPECT_THAT(body, HasSubstr(expected)) << body;
+}
+
+TEST_F(SpvParserTest, GlslStd450_FaceForward_Vector) {
+ const auto assembly = Preamble() + R"(
+ %99 = OpFAdd %v2float %v2f1 %v2f1
+ %1 = OpExtInst %v2float %glsl FaceForward %v2f1 %v2f2 %v2f3
+ OpReturn
+ OpFunctionEnd
+ )";
+ auto p = parser(test::Assemble(assembly));
+ ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
+ auto fe = p->function_emitter(100);
+ EXPECT_TRUE(fe.EmitBody()) << p->error();
+ const auto body = ToString(p->builder(), fe.ast_body());
+ const auto* expected = R"(VariableConst{
+ x_1
+ none
+ undefined
+ __vec_2__f32
+ {
+ Call[not set]{
+ Identifier[not set]{faceForward}
+ (
+ Identifier[not set]{v2f1}
+ Identifier[not set]{v2f2}
+ Identifier[not set]{v2f3}
+ )
+ }
+ })";
+
+ EXPECT_THAT(body, HasSubstr(expected));
+}
+
} // namespace
} // namespace spirv
} // namespace reader